golang-jwt / jwt

Go implementation of JSON Web Tokens (JWT).

Home Page:https://golang-jwt.github.io/jwt/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

I've made a small library to help with JWT

ivanjaros opened this issue · comments

This library is lacking in the "batteries included" department: a lot of manual work has to be done to make it work. So I made a small library to help with this. Consider adopting whichever parts you think are suitable.

// Package jwtool adds "batteries included" helpers on top of golang-jwt:
// building registered claims with sensible defaults, signing and parsing
// tokens with ECDSA or HMAC secrets, carrying an arbitrary payload, and
// validating registered claims after parsing.
package jwtool

// NOTE(review): the import block is not part of this snippet. Besides the
// standard library (errors, time, crypto/ecdsa, encoding/json) and jwt, the
// code depends on two project-local packages:
//   - random: random string generator (used by MakeTokenId)
//   - protoj: JSON encoder/decoder (used by NewToken and ParseToken)

// Validatable is implemented by payloads that can check their own contents.
// NewToken calls Validate before encoding and signing the payload, and
// ParseToken calls it after decoding; a non-nil error aborts the operation
// and is returned to the caller.
type Validatable interface {
	Validate() error
}

// TokenIdLength is the number of characters in IDs returned by MakeTokenId.
const TokenIdLength = 15

// MakeTokenId generates a token ID of TokenIdLength characters drawn from the
// a-z, A-Z and 0-9 alphabet (no special characters). The generator is not
// cryptographically secure, so the ID must not be treated as an unguessable
// secret.
func MakeTokenId() string {
	id := random.StandardString(TokenIdLength)
	return id
}

// Sign signs claims and returns the compact serialized token. The signing
// method is chosen from the type of secret:
//
//   - ecdsa.PrivateKey or *ecdsa.PrivateKey: ES256, ES384 or ES512, picked
//     from the key's curve size (256, 384 or 521 bits);
//   - string or []byte: HS384, using the bytes as the HMAC secret.
//
// An error is returned for an unsupported secret type, a nil or curve-less
// ECDSA key, an unsupported curve, or an empty HMAC secret.
func Sign(secret any, claims jwt.Claims) (string, error) {
	switch key := secret.(type) {
	case ecdsa.PrivateKey:
		return signES(&key, claims)
	case *ecdsa.PrivateKey:
		return signES(key, claims)
	case string:
		return signHS([]byte(key), claims)
	case []byte:
		return signHS(key, claims)
	default:
		return "", errors.New("unknown secret type")
	}
}

// signES signs claims with the ES* method that matches key's curve.
func signES(key *ecdsa.PrivateKey, claims jwt.Claims) (string, error) {
	// A zero-value key has a nil Curve; calling Params on it would panic.
	if key == nil || key.Curve == nil {
		return "", errors.New("invalid ecdsa private key")
	}
	var method jwt.SigningMethod
	switch key.Curve.Params().BitSize {
	case 256:
		method = jwt.SigningMethodES256
	case 384:
		method = jwt.SigningMethodES384
	case 521:
		method = jwt.SigningMethodES512
	default:
		return "", errors.New("unknown elliptic curve")
	}
	return jwt.NewWithClaims(method, claims).SignedString(key)
}

// signHS signs claims with HS384 using key as the HMAC secret.
func signHS(key []byte, claims jwt.Claims) (string, error) {
	// An empty HMAC secret produces signatures anyone can forge.
	if len(key) == 0 {
		return "", errors.New("empty hmac secret")
	}
	return jwt.NewWithClaims(jwt.SigningMethodHS384, claims).SignedString(key)
}

// Parse verifies str and decodes its claims into claims, which should be a
// pointer to a jwt.Claims implementation.
//
// The only accepted algorithm is fixed by the type of secret, so a token
// cannot select a different method than the caller expects:
//
//   - ecdsa.PublicKey or *ecdsa.PublicKey: the ES* method matching the key's
//     curve size (256, 384 or 521 bits);
//   - string or []byte: HS384.
//
// It returns the parsed token, or an error if the secret is unsupported,
// nil/curve-less or empty, or if parsing or validation of the token fails.
func Parse(secret any, str string, claims jwt.Claims) (*jwt.Token, error) {
	// Normalize the accepted secret forms to what the jwt verifiers expect.
	var key any
	switch k := secret.(type) {
	case ecdsa.PublicKey:
		key = &k
	case *ecdsa.PublicKey:
		key = k
	case string:
		key = []byte(k)
	case []byte:
		key = k
	default:
		return nil, errors.New("unknown secret type")
	}

	// Pin the algorithm to the key so WithValidMethods rejects anything else.
	var method string
	switch k := key.(type) {
	case *ecdsa.PublicKey:
		// A zero-value key has a nil Curve; calling Params on it would panic.
		if k == nil || k.Curve == nil {
			return nil, errors.New("invalid ecdsa public key")
		}
		switch k.Curve.Params().BitSize {
		case 256:
			method = jwt.SigningMethodES256.Alg()
		case 384:
			method = jwt.SigningMethodES384.Alg()
		case 521:
			method = jwt.SigningMethodES512.Alg()
		default:
			return nil, errors.New("unknown elliptic curve")
		}
	case []byte:
		if len(k) == 0 {
			return nil, errors.New("empty hmac secret")
		}
		method = jwt.SigningMethodHS384.Alg()
	}

	tok, err := jwt.ParseWithClaims(str, claims,
		func(*jwt.Token) (any, error) { return key, nil },
		jwt.WithValidMethods([]string{method}))
	if err != nil {
		return nil, err
	}
	if !tok.Valid {
		return nil, errors.New("invalid token")
	}
	return tok, nil
}

// MakeClaims builds registered claims by first applying these defaults and
// then the caller's options, in order, so options can override any default:
//
//   - exp: 15 minutes from now
//   - jti: a fresh MakeTokenId value
//   - iat: now
//   - nbf: one minute before now
func MakeClaims(options ...MakeOption) jwt.RegisteredClaims {
	issued := time.Now()
	all := append([]MakeOption{
		WithExpiration(issued.Add(15 * time.Minute)),
		WithId(MakeTokenId()),
		WithIssuedAt(issued),
		// Backdated to compensate possible clock differences between client and server.
		WithNotBefore(issued.Add(-time.Minute)),
	}, options...)

	var claims jwt.RegisteredClaims
	for _, apply := range all {
		apply(&claims)
	}
	return claims
}

// MakeOption modifies the registered claims being built by MakeClaims (and by
// NewToken, which forwards its options to MakeClaims).
type MakeOption func(*jwt.RegisteredClaims)

// WithAudience adds aud to the "aud" claim. Repeated uses accumulate
// audiences instead of replacing them.
func WithAudience(aud string) MakeOption {
	addAudience := func(rc *jwt.RegisteredClaims) {
		rc.Audience = append(rc.Audience, aud)
	}
	return addAudience
}

// WithExpiration sets the "exp" claim to the absolute time exp.
func WithExpiration(exp time.Time) MakeOption {
	setExpiry := func(rc *jwt.RegisteredClaims) {
		rc.ExpiresAt = jwt.NewNumericDate(exp)
	}
	return setExpiry
}

// WithLifespan sets the "exp" claim to l after the moment the option is
// applied (when MakeClaims/NewToken runs), not when WithLifespan is called.
// This lets a single option value be built once and reused, for example in a
// package-level option slice, without its tokens being born with a stale,
// already-passed expiration.
func WithLifespan(l time.Duration) MakeOption {
	return func(c *jwt.RegisteredClaims) {
		c.ExpiresAt = jwt.NewNumericDate(time.Now().Add(l))
	}
}

// WithId sets the "jti" (token ID) claim, replacing the MakeTokenId default.
func WithId(id string) MakeOption {
	return func(rc *jwt.RegisteredClaims) { rc.ID = id }
}

// WithIssuedAt sets the "iat" claim to ia.
func WithIssuedAt(ia time.Time) MakeOption {
	setIssued := func(rc *jwt.RegisteredClaims) {
		rc.IssuedAt = jwt.NewNumericDate(ia)
	}
	return setIssued
}

// WithIssuer sets the "iss" claim.
func WithIssuer(iss string) MakeOption {
	return func(rc *jwt.RegisteredClaims) { rc.Issuer = iss }
}

// WithNotBefore sets the "nbf" claim; the token is not accepted before nbf.
func WithNotBefore(nbf time.Time) MakeOption {
	setNotBefore := func(rc *jwt.RegisteredClaims) {
		rc.NotBefore = jwt.NewNumericDate(nbf)
	}
	return setNotBefore
}

// WithSubject sets the "sub" claim.
func WithSubject(sub string) MakeOption {
	return func(rc *jwt.RegisteredClaims) { rc.Subject = sub }
}

// UniversalClaims pairs the registered JWT claims with an application payload
// stored as already-encoded JSON under the "pld" key. NewToken fills Payload
// from the caller's value and ParseToken decodes it back; both pass this type
// to Sign/Parse as the token's jwt.Claims.
type UniversalClaims struct {
	jwt.RegisteredClaims
	// Payload holds the encoded application data; it is decoded lazily by ParseToken.
	Payload json.RawMessage `json:"pld"`
}

// NewToken creates a signed token carrying payload.
//
// If payload implements Validatable it is validated first. The payload is
// then encoded with protoj, wrapped in UniversalClaims together with
// MakeClaims(options...), and signed with secret via Sign (see Sign for the
// accepted secret types).
func NewToken(payload any, secret any, options ...MakeOption) (string, error) {
	if v, isValidatable := payload.(Validatable); isValidatable {
		if err := v.Validate(); err != nil {
			return "", err
		}
	}

	encoded, err := protoj.MarshalDenseData(payload)
	if err != nil {
		return "", err
	}

	return Sign(secret, UniversalClaims{
		RegisteredClaims: MakeClaims(options...),
		Payload:          encoded,
	})
}

// ParseToken verifies str with secret (see Parse), runs every validator
// against the token's registered claims in order, and, when payload is
// non-nil, decodes the "pld" claim into it.
//
// payload should be a pointer to the destination value (e.g. *FooPayload) so
// the decoded data reaches the caller. If the decoded payload implements
// Validatable, the result of its Validate method is returned.
func ParseToken(secret any, str string, payload any, validators ...Validator) error {
	var c UniversalClaims
	if _, err := Parse(secret, str, &c); err != nil {
		return err
	}
	for _, validate := range validators {
		if err := validate(c.RegisteredClaims); err != nil {
			return err
		}
	}
	if payload == nil {
		return nil
	}
	// Decode into the caller's pointer itself. Passing &payload (a *any) lets a
	// JSON-style decoder overwrite the interface's contents instead of filling
	// the caller's value, silently losing the data when payload is not a
	// pointer.
	if err := protoj.UnmarshalData(c.Payload, payload); err != nil {
		return err
	}
	if v, ok := payload.(Validatable); ok {
		return v.Validate()
	}
	return nil
}

// Validator checks the registered claims of a token whose signature has
// already been verified by Parse. ParseToken runs validators in order and
// returns the first non-nil error.
type Validator func(jwt.RegisteredClaims) error

func VerifySubject(sub string) Validator {
	return func(c jwt.RegisteredClaims) error {
		if c.Subject != sub {
			return errors.New("expected subject '"+sub"', got '"+c.Subject+"'")
		}
		return nil
	}
}

// VerifyIssuer returns a Validator that accepts the token only when its "iss"
// claim equals one of iss. A token with no issuer is rejected, and so is every
// token when iss is empty.
func VerifyIssuer(iss ...string) Validator {
	return func(c jwt.RegisteredClaims) error {
		for _, want := range iss {
			// required=true: with false, jwt reports a missing "iss" as a match,
			// so a token that simply omits the claim would pass this check.
			if c.VerifyIssuer(want, true) {
				return nil
			}
		}
		return errors.New("no issuer match")
	}
}

// VerifyAudience returns a Validator that accepts the token only when its
// "aud" claim contains at least one of aud. A token with no audience is
// rejected, and so is every token when aud is empty.
func VerifyAudience(aud ...string) Validator {
	return func(c jwt.RegisteredClaims) error {
		for _, want := range aud {
			// required=true: with false, jwt reports an empty "aud" as a match,
			// so a token that simply omits the claim would pass this check.
			if c.VerifyAudience(want, true) {
				return nil
			}
		}
		return errors.New("no audience match")
	}
}

Usage example:

// NewFooToken signs p into a token issued by "Foo" for subject "foo_bar" and
// audience "baz", valid for five minutes instead of the 15-minute default.
// Any issuer other than "Foo" is refused.
func NewFooToken(secret ecdsa.PrivateKey, issuer string, p FooPayload) (string, error) {
	if issuer == "Foo" {
		opts := []jwtool.MakeOption{
			jwtool.WithIssuer(issuer),
			jwtool.WithSubject("foo_bar"),
			jwtool.WithAudience("baz"),
			jwtool.WithLifespan(5 * time.Minute),
		}
		return jwtool.NewToken(p, secret, opts...)
	}
	return "", errors.New("invalid issuer")
}

// ParseFooToken verifies str with the public key and decodes its payload,
// requiring issuer "Foo", subject "foo_bar" and audience "baz".
func ParseFooToken(secret ecdsa.PublicKey, str string) (FooPayload, error) {
	checks := []jwtool.Validator{
		jwtool.VerifyIssuer("Foo"),
		jwtool.VerifySubject("foo_bar"),
		jwtool.VerifyAudience("baz"),
	}
	var p FooPayload
	err := jwtool.ParseToken(secret, str, &p, checks...)
	return p, err
}