diff --git a/jwt/base64.go b/jwt/base64.go index 43c7e6e..1a8878e 100644 --- a/jwt/base64.go +++ b/jwt/base64.go @@ -3,30 +3,61 @@ package jwt import ( "encoding/base64" "fmt" - "strings" ) -// Base64Encode encodes a payload of data as a string, like specified -// in RFC 7515 (no trailing '=' -func Base64Encode(b []byte) string { - return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=") +// Base64EncodedBufferLen return the size needed to encode the +// Base64Encode (padding '=' would be added, can be stripped with +// Base64DecodedStrippedLen +func Base64EncodedBufferLen(n int) int { + return base64.URLEncoding.EncodedLen(n) } -// Base64Decode decodes a payload of data as a []byte, like specified -// in RFC 7515 (no trailing '=') -func Base64Decode(s string) ([]byte, error) { - switch len(s) % 4 { +// Base64EncodedStrippedLen return the size of the encoded slice with no +// padding. +func Base64EncodedStrippedLen(n int) int { + eN := Base64EncodedBufferLen(n) + mod := n % 3 + if mod == 0 { + return eN + } + return eN - (3 - mod) +} + +// Base64DecodedStrippedLen returns the size of the data, without 0 +// bit padding, from the stripped length n. +func Base64DecodedStrippedLen(n int) int { + return Base64DecodedLenFromStripped(n) + n - ((n+3)/4)*4 +} + +// Base64DecodedLenFromStripped returns the size of the data, with 0 +// bit padding, from the stripped length n. +func Base64DecodedLenFromStripped(n int) int { + return base64.URLEncoding.DecodedLen(4 * ((n + 3) / 4)) +} + +// Base64Decode decodes a payload of data as a []byte, from a stripped +// '=' encoded string. +func Base64Decode(dst, src []byte) error { + switch len(src) % 4 { + // in this case, we have to copy the src, in order to allow decode part of the input case 2: - s += "==" + oldSrc := src + src = make([]byte, len(src)+2) + copy(src, oldSrc) + src[len(src)-2] = '=' + src[len(src)-1] = '=' case 3: - s += "=" + oldSrc := src + src = make([]byte, len(src)+1) + copy(src, oldSrc) + src[len(src)-1] = '=' case 1: - return nil, fmt.Errorf("jwt: Invalid base64 string (length:%d %% 4 == 1): '%s'", len(s), s) + return fmt.Errorf("jwt: Invalid base64 string (length:%d %% 4 == 1): '%s'", len(src), src) } - d, err := base64.URLEncoding.DecodeString(s) + _, err := base64.URLEncoding.Decode(dst, src) if err != nil { - return nil, fmt.Errorf("jwt: %s", err) + return fmt.Errorf("jwt: %s", err) } - return d, nil + return nil } diff --git a/jwt/base64_test.go b/jwt/base64_test.go index 5bb17d1..867c1ab 100644 --- a/jwt/base64_test.go +++ b/jwt/base64_test.go @@ -1,6 +1,7 @@ package jwt import ( + "encoding/base64" "testing" . "gopkg.in/check.v1" @@ -16,19 +17,20 @@ var _ = Suite(&Base64Suite{}) type dataAndEncode struct { data []byte - encoded string + encoded []byte } func (s *Base64Suite) TestEncodeWithNoTrailing(c *C) { data := []dataAndEncode{ - {[]byte{0, 0, 0}, "AAAA"}, //With trailing Should be AAAA - {[]byte{0, 0}, "AAA"}, //With trailing Should be AAA= - {[]byte{0}, "AA"}, //With trailing Should be AA== + {[]byte{0, 0, 0}, []byte("AAAA")}, //With trailing Should be AAAA + {[]byte{0, 0}, []byte("AAA")}, //With trailing Should be AAA= + {[]byte{0}, []byte("AA")}, //With trailing Should be AA== } for _, d := range data { - res := Base64Encode(d.data) - c.Check(res, Equals, d.encoded) + res := make([]byte, Base64EncodedBufferLen(len(d.data))) + base64.URLEncoding.Encode(res, d.data) + c.Check(res[:Base64EncodedStrippedLen(len(d.data))], DeepEquals, d.encoded) } } @@ -41,11 +43,12 @@ func (s *Base64Suite) TestDecodeWithNoTrailing(c *C) { } for encoded, expected := range data { - res, err := Base64Decode(encoded) + res := make([]byte, Base64DecodedLenFromStripped(len(encoded))) + err := Base64Decode(res, []byte(encoded)) if c.Check(err, IsNil) == false { continue } - c.Check(res, DeepEquals, expected) + c.Check(res[:Base64DecodedStrippedLen(len(encoded))], DeepEquals, expected) } } @@ -56,8 +59,8 @@ func (s *Base64Suite) TestDetectBadFormat(c *C) { } for encoded, errorMatches := range data { - res, err := Base64Decode(encoded) + res := make([]byte, Base64DecodedLenFromStripped(len(encoded))) + err := Base64Decode(res, []byte(encoded)) c.Check(err, ErrorMatches, errorMatches) - c.Check(res, IsNil) } } diff --git a/jwt/jose.go b/jwt/jose.go new file mode 100644 index 0000000..7bc7e22 --- /dev/null +++ b/jwt/jose.go @@ -0,0 +1,147 @@ +package jwt + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "reflect" + "strings" +) + +// JOSE represent a JSON Object Signing and Encryption header as +// defined in RFC 7515. +type JOSE struct { + Algorithm string `json:"alg"` + Type string `json:"typ,omitempty"` + Content string `json:"cty,omitempty"` + Critical string `json:"crit,omitempty"` + JWKSetURL string `json:"jku,omitempty"` + JSONWebKey string `json:"jwk,omitempty"` + KeyID string `json:"kid,omitempty"` + X509URL string `json:"x5u,omitempty"` + X509CertificateChain string `json:"x5c,omitempty"` + X509ThumbprintSha1 string `json:"x5t,omitempty"` + X509ThumbprintSha256 string `json:"x5t#S256,omitempty"` + + //Sets of adfditional headers, private or public + AdditionalHeaders map[string]interface{} `json:"-"` +} + +// EncodeJSON encodes a JOSE header in JSON format. Not using +// MarshallJSON tyo avoid loops +func (j *JOSE) EncodeJSON() ([]byte, error) { + data, err := json.Marshal(j) + if err != nil { + return nil, err + } + + if len(data) == 0 || data[0] != '{' || data[len(data)-1] != '}' { + return nil, fmt.Errorf("jws: Invalid JSON encoding: %s", data) + } + + if len(j.AdditionalHeaders) == 0 { + return data, nil + } + + moreHeader, err := json.Marshal(j.AdditionalHeaders) + if err != nil { + return nil, err + } + + if len(moreHeader) == 0 || moreHeader[0] != '{' || moreHeader[len(moreHeader)-1] != '}' { + return nil, fmt.Errorf("jws: Invalid JSON encoding: %s", data) + } + data[len(data)-1] = ',' + return append(data, moreHeader[1:]...), nil +} + +// EncodeBase64 encodes a JOSE into JSON base64 string. will allocate +// buffer. +func (j *JOSE) EncodeBase64() ([]byte, error) { + dec, err := j.EncodeJSON() + if err != nil { + return nil, err + } + + enc := make([]byte, Base64EncodedBufferLen(len(dec))) + base64.URLEncoding.Encode(enc, dec) + return enc[:Base64EncodedStrippedLen(len(dec))], nil +} + +// DecodeJOSE a JOSE header from a base64 text as defined in the RFC 7515 +func DecodeJOSE(data []byte) (*JOSE, error) { + bData := make([]byte, Base64DecodedLenFromStripped(len(data))) + + if err := Base64Decode(bData, data); err != nil { + return nil, err + } + + j := &JOSE{} + if err := json.NewDecoder(bytes.NewBuffer(bData)).Decode(j); err != nil { + return nil, err + } + + return j, nil +} + +// Validate validates a JOSE header data. Maybe it should disappear +func (j *JOSE) Validate() error { + if len(j.Algorithm) == 0 { + return fmt.Errorf("jwt: missing 'alg' header") + } + + return nil +} + +type fieldSetter func(*JOSE, string) + +var fieldSetters = make(map[string]fieldSetter) + +// UnmarshalJSON is here to satisfy interface json.Unmarshaller. We +// need to provide ou own unmarshaller for the additional header. +func (j *JOSE) UnmarshalJSON(b []byte) error { + raw := make(map[string]interface{}) + + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + + if j.AdditionalHeaders == nil { + j.AdditionalHeaders = make(map[string]interface{}, len(raw)) + } + + for key, data := range raw { + if fSetter, ok := fieldSetters[key]; ok == true { + strData, ok := data.(string) + if ok == false { + return &json.UnmarshalTypeError{ + Value: reflect.TypeOf(data).Kind().String(), + Type: reflect.TypeOf(string("")), + } + } + + fSetter(j, strData) + continue + } + + j.AdditionalHeaders[key] = data + } + return nil +} + +func init() { + tJOSE := reflect.TypeOf(JOSE{}) + for i := 0; i < tJOSE.NumField(); i++ { + jField := tJOSE.Field(i) + jTag := jField.Tag.Get("json") + if len(jTag) == 0 || jTag == "-" { + continue + } + + jName := strings.Split(jTag, ",") + fieldSetters[jName[0]] = func(j *JOSE, data string) { + reflect.ValueOf(j).Elem().FieldByName(jField.Name).SetString(data) + } + } +} diff --git a/jwt/jose_test.go b/jwt/jose_test.go new file mode 100644 index 0000000..cfc809d --- /dev/null +++ b/jwt/jose_test.go @@ -0,0 +1,88 @@ +package jwt + +import ( + "encoding/base64" + + . "gopkg.in/check.v1" +) + +type JOSESuite struct{} + +var _ = Suite(&JOSESuite{}) + +func (s *JOSESuite) TestCheckForDataEncoding(c *C) { + data := map[string]string{ + `"alg":"none"}`: `json: .*`, + `{alg:"none"}`: `invalid character '.' looking for beginning of object key string`, + `{"alg":1234}`: `json: cannot unmarshal float64 into Go value of type string`, + } + + for dec, errMatches := range data { + encoded := make([]byte, Base64EncodedBufferLen(len(dec))) + base64.URLEncoding.Encode(encoded, []byte(dec)) + + j, err := DecodeJOSE(encoded) + c.Check(j, IsNil) + c.Check(err, ErrorMatches, errMatches) + } + +} + +func (s *JOSESuite) TestHandleDuplicateField(c *C) { + + decoded := `{"alg":"foo","alg":"none"}` + encoded := make([]byte, Base64EncodedBufferLen(len(decoded))) + base64.URLEncoding.Encode(encoded, []byte(decoded)) + j, err := DecodeJOSE(encoded) + c.Assert(j, NotNil) + c.Assert(err, IsNil) + c.Assert(j.Algorithm, Equals, "none") + +} + +func (s *JOSESuite) TestHeaderValidation(c *C) { + + data := map[string]string{ + `{"alg":""}`: `jwt: missing 'alg' header`, + `{}`: `jwt: missing 'alg' header`, + } + + for dec, errMatches := range data { + encoded := make([]byte, Base64EncodedBufferLen(len(dec))) + base64.URLEncoding.Encode(encoded, []byte(dec)) + j, err := DecodeJOSE(encoded) + c.Check(j, NotNil) + c.Check(err, IsNil) + c.Check(j.Validate(), ErrorMatches, errMatches) + } + +} + +func (s *JOSESuite) TestHandlesPrivateHeaders(c *C) { + data := []map[string]interface{}{ + map[string]interface{}{ + "foo": true, + "bar": 12.34, + "baz": float64(456), //default of json is to parse float64 + "blah": "omg", + }, + } + + for _, headers := range data { + j := &JOSE{ + Algorithm: "none", + AdditionalHeaders: make(map[string]interface{}), + } + for n, v := range headers { + j.AdditionalHeaders[n] = v + } + enc, err := j.EncodeBase64() + if c.Check(err, IsNil) == false { + continue + } + + res, err := DecodeJOSE(enc) + c.Check(err, IsNil) + c.Assert(res, DeepEquals, j) + } +} diff --git a/jwt/jws.go b/jwt/jws.go new file mode 100644 index 0000000..978cbbf --- /dev/null +++ b/jwt/jws.go @@ -0,0 +1,133 @@ +package jwt + +import ( + "encoding/base64" + "encoding/json" + "fmt" +) + +// EncodeJWS encode a JSON Object with the serialized JWS format as +// specified in RFC 7515 +func EncodeJWS(j *JOSE, v interface{}, s Signer) ([]byte, error) { + payload, err := json.Marshal(v) + if err != nil { + return nil, err + } + + j.Algorithm = s.Algorithm() + + header, err := j.EncodeJSON() + if err != nil { + return nil, err + } + + //Allocate a buffer long enough for all payload + res := make([]byte, Base64EncodedBufferLen(len(header))+Base64EncodedBufferLen(len(payload))+2) + base64.URLEncoding.Encode(res, header) + lengthHeader := Base64EncodedStrippedLen(len(res)) + res[lengthHeader] = '.' + base64.URLEncoding.Encode(res[lengthHeader+1:], payload) + fullPayloadLength := Base64EncodedBufferLen(len(payload)) + 1 + lengthHeader + res[fullPayloadLength] = '.' + + if s == nil { + // unprotected jws, not signing it + return res[:fullPayloadLength+1], nil + } + signature, err := s.Sign(res) + if err != nil { + return nil, err + } + + signedRes := make([]byte, fullPayloadLength+1+Base64EncodedBufferLen(len(signature))) + copy(signedRes, res[:fullPayloadLength+1]) + base64.URLEncoding.Encode(signedRes[fullPayloadLength+1:], signature) + return signedRes[:fullPayloadLength+1+Base64EncodedStrippedLen(len(signature))], nil +} + +// isBase64URLEncoding returns if the byte is a character used for +// base64 URL Encoding +func isBase64URLEncoding(b byte) bool { + if b == '-' || b == '_' { + return true + } + + if b < '0' || b > 'z' { + return false + } + + if b <= '9' { + return true + } + + if b < 'A' { + return false + } + + if b <= 'Z' { + return true + } + + if b < 'a' { + return false + } + + return true +} + +// DecodeJWS decode an object encoded with the JWS serialized fromat +// as specified by RFX 7515. to avoid an attack where the unprotected +// JOSE header would contain a tempered signing algorithm, the Signer +// should also be specified. +func DecodeJWS(data []byte, v interface{}, s Signer) error { + var headerLength, payloadLength int + for i, c := range data { + if isBase64URLEncoding(c) == true { + continue + } + if c != '.' { + return fmt.Errorf("jws: invalid serialization character %v in '%s'", c, data) + } + + if headerLength == 0 { + if i == 0 { + return fmt.Errorf("jws: invalid emtpy JOSE header in %s", data) + } + headerLength = i + continue + } + + if payloadLength == 0 { + if i-headerLength == 1 { + return fmt.Errorf("jws: invalid emtpy payload in %s", data) + } + payloadLength = i - headerLength - 1 + continue + } + + return fmt.Errorf("jws: invalid third '.' in %s", data) + } + signatureLength := len(data) - 2 - headerLength - payloadLength + signature := make([]byte, Base64DecodedLenFromStripped(signatureLength)) + signedLength := headerLength + payloadLength + 1 + Base64Decode(signature, data[signedLength+1:]) + + if err := s.Verify(data[:signedLength], signature); err != nil { + return err + } + + //decode jose + jose, err := DecodeJOSE(data[:headerLength]) + if err != nil { + return err + } + + if jose.Algorithm != s.Algorithm() { + return fmt.Errorf("jws: Mismatched signing algorithm got %s, expected %s", jose.Algorithm, s.Algorithm()) + } + + payload := make([]byte, Base64DecodedLenFromStripped(payloadLength)) + Base64Decode(payload, data[headerLength+1:signatureLength]) + //data is safe, just need to decode it. + return json.Unmarshal(payload, v) +} diff --git a/jwt/signer.go b/jwt/signer.go new file mode 100644 index 0000000..50c3885 --- /dev/null +++ b/jwt/signer.go @@ -0,0 +1,164 @@ +package jwt + +import ( + "crypto" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "fmt" +) + +type hashLinkError struct { + name string +} + +func (e *hashLinkError) Error() string { + return fmt.Sprintf("jws: Hash %s is not linked into binary", e.name) +} + +// A Signer is able to crypto-sign a JWS object. +type Signer interface { + // Sign returns the signature of a binary payload. It should + // returns the signature in binary format + Sign([]byte) ([]byte, error) + // Verify check if the signed binary paylod correspond to the + // provided binary siganture. It assumes a binary format + // signature. + Verify(signed, signature []byte) error + // Algorithm is returning a string idnetifying the algorithm used + // as defined in RFC 7518. + Algorithm() string +} + +// A HMACSigner is a Signer using HMAC based algorithm +type HMACSigner struct { + name string + hash crypto.Hash + key []byte +} + +// NewHMAC256Signer returs a HMACSigner using a SHA256 hash. +func NewHMAC256Signer(key []byte) Signer { + return &HMACSigner{ + name: "HS256", + hash: crypto.SHA256, + key: key, + } +} + +// NewHMAC384Signer returs a HMACSigner using a SHA384 hash. +func NewHMAC384Signer(key []byte) Signer { + return &HMACSigner{ + name: "HS256", + hash: crypto.SHA384, + key: key, + } +} + +// NewHMAC512Signer returs a HMACSigner using a SHA512 hash. +func NewHMAC512Signer(key []byte) Signer { + return &HMACSigner{ + name: "HS256", + hash: crypto.SHA384, + key: key, + } +} + +// Algorithm return 'HSXXX' where XXX is either 256|384|512 depending +// on the SHA Hash used. +func (s *HMACSigner) Algorithm() string { + return s.name +} + +// Sign compute the binary signature for data. +func (s *HMACSigner) Sign(data []byte) ([]byte, error) { + if s.hash.Available() == false { + return nil, &hashLinkError{s.name} + } + + hasher := hmac.New(s.hash.New, s.key) + hasher.Write(data) + return hasher.Sum(nil), nil +} + +// Verify returns an error if the binary provided signature does not +// correspond to signed. It return nil if the signature is correct. +func (s *HMACSigner) Verify(signed, signature []byte) error { + expected, err := s.Sign(signed) + if err != nil { + return err + } + if hmac.Equal(signature, expected) == false { + return fmt.Errorf("jws: Invalid %s signature", s.name) + } + + return nil +} + +// A RSASigner is a Signer using a PKS 1v15 algorithm. +type RSASigner struct { + name string + hash crypto.Hash + key *rsa.PrivateKey +} + +// NewRSA256Signer returns a RSASigner using a SHA 256 hash. +func NewRSA256Signer(key *rsa.PrivateKey) Signer { + return &RSASigner{ + name: "RS256", + hash: crypto.SHA256, + key: key, + } +} + +// NewRSA384Signer returns a RSASigner using a SHA 384 hash. +func NewRSA384Signer(key *rsa.PrivateKey) Signer { + return &RSASigner{ + name: "RS384", + hash: crypto.SHA384, + key: key, + } +} + +// NewRSA512Signer returns a RSASigner using a SHA 512 hash. +func NewRSA512Signer(key *rsa.PrivateKey) Signer { + return &RSASigner{ + name: "RS512", + hash: crypto.SHA512, + key: key, + } +} + +// Sign compute the binary signature for data. +func (s *RSASigner) Sign(data []byte) ([]byte, error) { + if s.hash.Available() == false { + return nil, &hashLinkError{s.name} + } + + hasher := s.hash.New() + hasher.Write(data) + res, err := rsa.SignPKCS1v15(rand.Reader, s.key, s.hash, hasher.Sum(nil)) + if err != nil { + return nil, err + } + return res, nil +} + +// Verify returns an error if the binary provided signature does not +// correspond to signed. It return nil if the signature is correct. +func (s *RSASigner) Verify(signed, signature []byte) error { + if s.hash.Available() == false { + return &hashLinkError{s.name} + } + + hasher := s.hash.New() + hasher.Write(signed) + + return rsa.VerifyPKCS1v15(s.key.Public().(*rsa.PublicKey), s.hash, hasher.Sum(nil), signature) +} + +// Algorithm return 'RSXXX' where XXX is either 256|384|512 depending +// on the SHA Hash used. +func (s *RSASigner) Algorithm() string { + return s.name +}