From abe3cbab1fbd045dafcbb6ab41522a92e63c6637 Mon Sep 17 00:00:00 2001 From: Alexandre Tuleu Date: Tue, 18 Aug 2015 18:34:23 +0200 Subject: [PATCH] Adds unit testing for jws --- jwt/jose.go | 22 ++++++------- jwt/jose_test.go | 15 +++++++++ jwt/jws.go | 35 +++++++++++++++------ jwt/jws_test.go | 77 ++++++++++++++++++++++++++++++++++++++++++++++ jwt/signer.go | 6 ++-- jwt/signer_test.go | 55 +++++++++++++++++++++++++++++++++ 6 files changed, 187 insertions(+), 23 deletions(-) create mode 100644 jwt/jws_test.go create mode 100644 jwt/signer_test.go diff --git a/jwt/jose.go b/jwt/jose.go index 7bc7e22..72e2724 100644 --- a/jwt/jose.go +++ b/jwt/jose.go @@ -12,17 +12,17 @@ import ( // JOSE represent a JSON Object Signing and Encryption header as // defined in RFC 7515. type JOSE struct { - Algorithm string `json:"alg"` - Type string `json:"typ,omitempty"` - Content string `json:"cty,omitempty"` - Critical string `json:"crit,omitempty"` - JWKSetURL string `json:"jku,omitempty"` - JSONWebKey string `json:"jwk,omitempty"` - KeyID string `json:"kid,omitempty"` - X509URL string `json:"x5u,omitempty"` - X509CertificateChain string `json:"x5c,omitempty"` - X509ThumbprintSha1 string `json:"x5t,omitempty"` - X509ThumbprintSha256 string `json:"x5t#S256,omitempty"` + Algorithm string `json:"alg"` + Type string `json:"typ,omitempty"` + Content string `json:"cty,omitempty"` + Critical []string `json:"crit,omitempty"` + JWKSetURL string `json:"jku,omitempty"` + JSONWebKey string `json:"jwk,omitempty"` + KeyID string `json:"kid,omitempty"` + X509URL string `json:"x5u,omitempty"` + X509CertificateChain string `json:"x5c,omitempty"` + X509ThumbprintSha1 string `json:"x5t,omitempty"` + X509ThumbprintSha256 string `json:"x5t#S256,omitempty"` //Sets of adfditional headers, private or public AdditionalHeaders map[string]interface{} `json:"-"` diff --git a/jwt/jose_test.go b/jwt/jose_test.go index cfc809d..269cd6c 100644 --- a/jwt/jose_test.go +++ b/jwt/jose_test.go @@ -86,3 +86,18 @@ func (s *JOSESuite) TestHandlesPrivateHeaders(c *C) { c.Assert(res, DeepEquals, j) } } + +func (s *JOSESuite) TestEncodingIsCompact(c *C) { + j := &JOSE{ + Algorithm: "none", + Critical: []string{"exp", "foo"}, + } + // it does not print empty fields. + data, err := j.EncodeJSON() + c.Check(err, IsNil) + c.Check(string(data), Equals, `{"alg":"none","crit":["exp","foo"]}`) + + b64, err := j.EncodeBase64() + c.Check(err, IsNil) + c.Check(len(b64), Equals, 47) +} diff --git a/jwt/jws.go b/jwt/jws.go index 978cbbf..c562749 100644 --- a/jwt/jws.go +++ b/jwt/jws.go @@ -14,7 +14,11 @@ func EncodeJWS(j *JOSE, v interface{}, s Signer) ([]byte, error) { return nil, err } - j.Algorithm = s.Algorithm() + if s != nil { + j.Algorithm = s.Algorithm() + } else { + j.Algorithm = "none" + } header, err := j.EncodeJSON() if err != nil { @@ -24,17 +28,17 @@ func EncodeJWS(j *JOSE, v interface{}, s Signer) ([]byte, error) { //Allocate a buffer long enough for all payload res := make([]byte, Base64EncodedBufferLen(len(header))+Base64EncodedBufferLen(len(payload))+2) base64.URLEncoding.Encode(res, header) - lengthHeader := Base64EncodedStrippedLen(len(res)) + lengthHeader := Base64EncodedStrippedLen(len(header)) res[lengthHeader] = '.' base64.URLEncoding.Encode(res[lengthHeader+1:], payload) - fullPayloadLength := Base64EncodedBufferLen(len(payload)) + 1 + lengthHeader + fullPayloadLength := Base64EncodedStrippedLen(len(payload)) + 1 + lengthHeader res[fullPayloadLength] = '.' if s == nil { // unprotected jws, not signing it return res[:fullPayloadLength+1], nil } - signature, err := s.Sign(res) + signature, err := s.Sign(res[:fullPayloadLength]) if err != nil { return nil, err } @@ -111,9 +115,16 @@ func DecodeJWS(data []byte, v interface{}, s Signer) error { signature := make([]byte, Base64DecodedLenFromStripped(signatureLength)) signedLength := headerLength + payloadLength + 1 Base64Decode(signature, data[signedLength+1:]) + signature = signature[:Base64DecodedStrippedLen(signatureLength)] - if err := s.Verify(data[:signedLength], signature); err != nil { - return err + if s != nil { + if err := s.Verify(data[:signedLength], signature); err != nil { + return err + } + } else { + if signatureLength != 0 { + return fmt.Errorf("jws: Invalid JWS, got a signature, but none expected") + } } //decode jose @@ -122,12 +133,18 @@ func DecodeJWS(data []byte, v interface{}, s Signer) error { return err } - if jose.Algorithm != s.Algorithm() { - return fmt.Errorf("jws: Mismatched signing algorithm got %s, expected %s", jose.Algorithm, s.Algorithm()) + algo := "none" + if s != nil { + algo = s.Algorithm() + } + + if jose.Algorithm != algo { + return fmt.Errorf("jws: Mismatched signing algorithm got %s, expected %s", jose.Algorithm, algo) } payload := make([]byte, Base64DecodedLenFromStripped(payloadLength)) - Base64Decode(payload, data[headerLength+1:signatureLength]) + Base64Decode(payload, data[headerLength+1:headerLength+1+payloadLength]) + payload = payload[:Base64DecodedStrippedLen(payloadLength)] //data is safe, just need to decode it. return json.Unmarshal(payload, v) } diff --git a/jwt/jws_test.go b/jwt/jws_test.go new file mode 100644 index 0000000..3fdca5c --- /dev/null +++ b/jwt/jws_test.go @@ -0,0 +1,77 @@ +package jwt + +import ( + "crypto/rsa" + + . "gopkg.in/check.v1" +) + +type JWSSuite struct { + signers []Signer + hmacKey []byte + rsaKey *rsa.PrivateKey +} + +var _ = Suite(&JWSSuite{}) + +func (s *JWSSuite) SetUpSuite(c *C) { + var err error + s.hmacKey = NewHMACKey(24) + + s.rsaKey, err = CachedRSAkey() + + c.Assert(err, IsNil) + + s.signers = []Signer{ + NewHMAC256Signer(s.hmacKey), + NewHMAC384Signer(s.hmacKey), + NewHMAC512Signer(s.hmacKey), + NewRSA256Signer(s.rsaKey), + NewRSA384Signer(s.rsaKey), + NewRSA512Signer(s.rsaKey), + } +} + +func (s *JWSSuite) TestCanEncodeSigned(c *C) { + type foo struct { + A string + B int + } + + j := &JOSE{} + + for _, signer := range s.signers { + decoded := foo{A: "blah", B: 42} + res, err := EncodeJWS(j, decoded, signer) + c.Check(string(res), Matches, `\A[0-9A-Za-z_\-]+\.[0-9A-Za-z_\-]+.[0-9A-Za-z_\-]+\z`) + if c.Check(err, IsNil) == false { + continue + } + redecoded := foo{} + err = DecodeJWS(res, &redecoded, signer) + if c.Check(err, IsNil, Commentf("Algo is: %s", signer.Algorithm())) == true { + c.Check(redecoded, DeepEquals, decoded) + } + } +} + +func (s *JWSSuite) TestCanEncodeUnprotected(c *C) { + + type foo struct { + A string + B int + } + + j := &JOSE{} + + decoded := foo{A: "blah", B: 42} + res, err := EncodeJWS(j, decoded, nil) + c.Check(string(res), Matches, `\A[0-9A-Za-z_\-]+\.[0-9A-Za-z_\-]+.\z`) + c.Assert(err, IsNil) + + redecoded := foo{} + err = DecodeJWS(res, &redecoded, nil) + c.Assert(err, IsNil) + c.Check(redecoded, DeepEquals, decoded) + +} diff --git a/jwt/signer.go b/jwt/signer.go index 50c3885..48825e6 100644 --- a/jwt/signer.go +++ b/jwt/signer.go @@ -49,7 +49,7 @@ func NewHMAC256Signer(key []byte) Signer { // NewHMAC384Signer returs a HMACSigner using a SHA384 hash. func NewHMAC384Signer(key []byte) Signer { return &HMACSigner{ - name: "HS256", + name: "HS384", hash: crypto.SHA384, key: key, } @@ -58,8 +58,8 @@ func NewHMAC384Signer(key []byte) Signer { // NewHMAC512Signer returs a HMACSigner using a SHA512 hash. func NewHMAC512Signer(key []byte) Signer { return &HMACSigner{ - name: "HS256", - hash: crypto.SHA384, + name: "HS512", + hash: crypto.SHA512, key: key, } } diff --git a/jwt/signer_test.go b/jwt/signer_test.go new file mode 100644 index 0000000..d62069f --- /dev/null +++ b/jwt/signer_test.go @@ -0,0 +1,55 @@ +package jwt + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" +) + +func NewHMACKey(n int) []byte { + res := make([]byte, n) + rand.Reader.Read(res) + return res +} + +func CachedRSAkey() (*rsa.PrivateKey, error) { + keyPath := filepath.Join(os.TempDir(), "narco-jwt-test.key") + f, err := os.Open(keyPath) + if err != nil { + if os.IsNotExist(err) == false { + return nil, err + } + // generate a key + log.Printf("Generating a new key") + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + return nil, err + } + + f, err := os.Create(keyPath) + if err != nil { + log.Printf("Could not cache the generated key: %s", err) + return key, nil + } + + data := x509.MarshalPKCS1PrivateKey(key) + _, err = f.Write(data) + if err != nil { + log.Printf("Could not cache the generated key: %s", err) + } + + return key, nil + } + + keydata, err := ioutil.ReadAll(f) + if err != nil { + return nil, fmt.Errorf("Could not read %s: %s", keyPath, err) + } + + return x509.ParsePKCS1PrivateKey(keydata) +}