Adds JWS encode and decode

It also adds interfaces Signer and HS256 and RS256 implementation.
This commit is contained in:
2015-08-18 16:11:56 +02:00
parent 4eb3bdaa13
commit 00c5853b45
6 changed files with 591 additions and 25 deletions

View File

@@ -3,30 +3,61 @@ package jwt
import (
"encoding/base64"
"fmt"
"strings"
)
// Base64Encode encodes a payload of data as a string, like specified
// in RFC 7515 (no trailing '='
func Base64Encode(b []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
// Base64EncodedBufferLen return the size needed to encode the
// Base64Encode (padding '=' would be added, can be stripped with
// Base64DecodedStrippedLen
func Base64EncodedBufferLen(n int) int {
return base64.URLEncoding.EncodedLen(n)
}
// Base64Decode decodes a payload of data as a []byte, like specified
// in RFC 7515 (no trailing '=')
func Base64Decode(s string) ([]byte, error) {
switch len(s) % 4 {
// Base64EncodedStrippedLen return the size of the encoded slice with no
// padding.
func Base64EncodedStrippedLen(n int) int {
eN := Base64EncodedBufferLen(n)
mod := n % 3
if mod == 0 {
return eN
}
return eN - (3 - mod)
}
// Base64DecodedStrippedLen returns the size of the data, without 0
// bit padding, from the stripped length n.
func Base64DecodedStrippedLen(n int) int {
return Base64DecodedLenFromStripped(n) + n - ((n+3)/4)*4
}
// Base64DecodedLenFromStripped returns the size of the data, with 0
// bit padding, from the stripped length n.
func Base64DecodedLenFromStripped(n int) int {
return base64.URLEncoding.DecodedLen(4 * ((n + 3) / 4))
}
// Base64Decode decodes a payload of data as a []byte, from a stripped
// '=' encoded string.
func Base64Decode(dst, src []byte) error {
switch len(src) % 4 {
// in this case, we have to copy the src, in order to allow decode part of the input
case 2:
s += "=="
oldSrc := src
src = make([]byte, len(src)+2)
copy(src, oldSrc)
src[len(src)-2] = '='
src[len(src)-1] = '='
case 3:
s += "="
oldSrc := src
src = make([]byte, len(src)+1)
copy(src, oldSrc)
src[len(src)-1] = '='
case 1:
return nil, fmt.Errorf("jwt: Invalid base64 string (length:%d %% 4 == 1): '%s'", len(s), s)
return fmt.Errorf("jwt: Invalid base64 string (length:%d %% 4 == 1): '%s'", len(src), src)
}
d, err := base64.URLEncoding.DecodeString(s)
_, err := base64.URLEncoding.Decode(dst, src)
if err != nil {
return nil, fmt.Errorf("jwt: %s", err)
return fmt.Errorf("jwt: %s", err)
}
return d, nil
return nil
}

View File

@@ -1,6 +1,7 @@
package jwt
import (
"encoding/base64"
"testing"
. "gopkg.in/check.v1"
@@ -16,19 +17,20 @@ var _ = Suite(&Base64Suite{})
type dataAndEncode struct {
data []byte
encoded string
encoded []byte
}
func (s *Base64Suite) TestEncodeWithNoTrailing(c *C) {
data := []dataAndEncode{
{[]byte{0, 0, 0}, "AAAA"}, //With trailing Should be AAAA
{[]byte{0, 0}, "AAA"}, //With trailing Should be AAA=
{[]byte{0}, "AA"}, //With trailing Should be AA==
{[]byte{0, 0, 0}, []byte("AAAA")}, //With trailing Should be AAAA
{[]byte{0, 0}, []byte("AAA")}, //With trailing Should be AAA=
{[]byte{0}, []byte("AA")}, //With trailing Should be AA==
}
for _, d := range data {
res := Base64Encode(d.data)
c.Check(res, Equals, d.encoded)
res := make([]byte, Base64EncodedBufferLen(len(d.data)))
base64.URLEncoding.Encode(res, d.data)
c.Check(res[:Base64EncodedStrippedLen(len(d.data))], DeepEquals, d.encoded)
}
}
@@ -41,11 +43,12 @@ func (s *Base64Suite) TestDecodeWithNoTrailing(c *C) {
}
for encoded, expected := range data {
res, err := Base64Decode(encoded)
res := make([]byte, Base64DecodedLenFromStripped(len(encoded)))
err := Base64Decode(res, []byte(encoded))
if c.Check(err, IsNil) == false {
continue
}
c.Check(res, DeepEquals, expected)
c.Check(res[:Base64DecodedStrippedLen(len(encoded))], DeepEquals, expected)
}
}
@@ -56,8 +59,8 @@ func (s *Base64Suite) TestDetectBadFormat(c *C) {
}
for encoded, errorMatches := range data {
res, err := Base64Decode(encoded)
res := make([]byte, Base64DecodedLenFromStripped(len(encoded)))
err := Base64Decode(res, []byte(encoded))
c.Check(err, ErrorMatches, errorMatches)
c.Check(res, IsNil)
}
}

147
jwt/jose.go Normal file
View File

@@ -0,0 +1,147 @@
package jwt
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"reflect"
"strings"
)
// JOSE represent a JSON Object Signing and Encryption header as
// defined in RFC 7515.
type JOSE struct {
Algorithm string `json:"alg"`
Type string `json:"typ,omitempty"`
Content string `json:"cty,omitempty"`
Critical string `json:"crit,omitempty"`
JWKSetURL string `json:"jku,omitempty"`
JSONWebKey string `json:"jwk,omitempty"`
KeyID string `json:"kid,omitempty"`
X509URL string `json:"x5u,omitempty"`
X509CertificateChain string `json:"x5c,omitempty"`
X509ThumbprintSha1 string `json:"x5t,omitempty"`
X509ThumbprintSha256 string `json:"x5t#S256,omitempty"`
//Sets of adfditional headers, private or public
AdditionalHeaders map[string]interface{} `json:"-"`
}
// EncodeJSON encodes a JOSE header in JSON format. Not using
// MarshallJSON tyo avoid loops
func (j *JOSE) EncodeJSON() ([]byte, error) {
data, err := json.Marshal(j)
if err != nil {
return nil, err
}
if len(data) == 0 || data[0] != '{' || data[len(data)-1] != '}' {
return nil, fmt.Errorf("jws: Invalid JSON encoding: %s", data)
}
if len(j.AdditionalHeaders) == 0 {
return data, nil
}
moreHeader, err := json.Marshal(j.AdditionalHeaders)
if err != nil {
return nil, err
}
if len(moreHeader) == 0 || moreHeader[0] != '{' || moreHeader[len(moreHeader)-1] != '}' {
return nil, fmt.Errorf("jws: Invalid JSON encoding: %s", data)
}
data[len(data)-1] = ','
return append(data, moreHeader[1:]...), nil
}
// EncodeBase64 encodes a JOSE into JSON base64 string. will allocate
// buffer.
func (j *JOSE) EncodeBase64() ([]byte, error) {
dec, err := j.EncodeJSON()
if err != nil {
return nil, err
}
enc := make([]byte, Base64EncodedBufferLen(len(dec)))
base64.URLEncoding.Encode(enc, dec)
return enc[:Base64EncodedStrippedLen(len(dec))], nil
}
// DecodeJOSE a JOSE header from a base64 text as defined in the RFC 7515
func DecodeJOSE(data []byte) (*JOSE, error) {
bData := make([]byte, Base64DecodedLenFromStripped(len(data)))
if err := Base64Decode(bData, data); err != nil {
return nil, err
}
j := &JOSE{}
if err := json.NewDecoder(bytes.NewBuffer(bData)).Decode(j); err != nil {
return nil, err
}
return j, nil
}
// Validate validates a JOSE header data. Maybe it should disappear
func (j *JOSE) Validate() error {
if len(j.Algorithm) == 0 {
return fmt.Errorf("jwt: missing 'alg' header")
}
return nil
}
type fieldSetter func(*JOSE, string)
var fieldSetters = make(map[string]fieldSetter)
// UnmarshalJSON is here to satisfy interface json.Unmarshaller. We
// need to provide ou own unmarshaller for the additional header.
func (j *JOSE) UnmarshalJSON(b []byte) error {
raw := make(map[string]interface{})
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
if j.AdditionalHeaders == nil {
j.AdditionalHeaders = make(map[string]interface{}, len(raw))
}
for key, data := range raw {
if fSetter, ok := fieldSetters[key]; ok == true {
strData, ok := data.(string)
if ok == false {
return &json.UnmarshalTypeError{
Value: reflect.TypeOf(data).Kind().String(),
Type: reflect.TypeOf(string("")),
}
}
fSetter(j, strData)
continue
}
j.AdditionalHeaders[key] = data
}
return nil
}
func init() {
tJOSE := reflect.TypeOf(JOSE{})
for i := 0; i < tJOSE.NumField(); i++ {
jField := tJOSE.Field(i)
jTag := jField.Tag.Get("json")
if len(jTag) == 0 || jTag == "-" {
continue
}
jName := strings.Split(jTag, ",")
fieldSetters[jName[0]] = func(j *JOSE, data string) {
reflect.ValueOf(j).Elem().FieldByName(jField.Name).SetString(data)
}
}
}

88
jwt/jose_test.go Normal file
View File

@@ -0,0 +1,88 @@
package jwt
import (
"encoding/base64"
. "gopkg.in/check.v1"
)
type JOSESuite struct{}
var _ = Suite(&JOSESuite{})
func (s *JOSESuite) TestCheckForDataEncoding(c *C) {
data := map[string]string{
`"alg":"none"}`: `json: .*`,
`{alg:"none"}`: `invalid character '.' looking for beginning of object key string`,
`{"alg":1234}`: `json: cannot unmarshal float64 into Go value of type string`,
}
for dec, errMatches := range data {
encoded := make([]byte, Base64EncodedBufferLen(len(dec)))
base64.URLEncoding.Encode(encoded, []byte(dec))
j, err := DecodeJOSE(encoded)
c.Check(j, IsNil)
c.Check(err, ErrorMatches, errMatches)
}
}
func (s *JOSESuite) TestHandleDuplicateField(c *C) {
decoded := `{"alg":"foo","alg":"none"}`
encoded := make([]byte, Base64EncodedBufferLen(len(decoded)))
base64.URLEncoding.Encode(encoded, []byte(decoded))
j, err := DecodeJOSE(encoded)
c.Assert(j, NotNil)
c.Assert(err, IsNil)
c.Assert(j.Algorithm, Equals, "none")
}
func (s *JOSESuite) TestHeaderValidation(c *C) {
data := map[string]string{
`{"alg":""}`: `jwt: missing 'alg' header`,
`{}`: `jwt: missing 'alg' header`,
}
for dec, errMatches := range data {
encoded := make([]byte, Base64EncodedBufferLen(len(dec)))
base64.URLEncoding.Encode(encoded, []byte(dec))
j, err := DecodeJOSE(encoded)
c.Check(j, NotNil)
c.Check(err, IsNil)
c.Check(j.Validate(), ErrorMatches, errMatches)
}
}
func (s *JOSESuite) TestHandlesPrivateHeaders(c *C) {
data := []map[string]interface{}{
map[string]interface{}{
"foo": true,
"bar": 12.34,
"baz": float64(456), //default of json is to parse float64
"blah": "omg",
},
}
for _, headers := range data {
j := &JOSE{
Algorithm: "none",
AdditionalHeaders: make(map[string]interface{}),
}
for n, v := range headers {
j.AdditionalHeaders[n] = v
}
enc, err := j.EncodeBase64()
if c.Check(err, IsNil) == false {
continue
}
res, err := DecodeJOSE(enc)
c.Check(err, IsNil)
c.Assert(res, DeepEquals, j)
}
}

133
jwt/jws.go Normal file
View File

@@ -0,0 +1,133 @@
package jwt
import (
"encoding/base64"
"encoding/json"
"fmt"
)
// EncodeJWS encode a JSON Object with the serialized JWS format as
// specified in RFC 7515
func EncodeJWS(j *JOSE, v interface{}, s Signer) ([]byte, error) {
payload, err := json.Marshal(v)
if err != nil {
return nil, err
}
j.Algorithm = s.Algorithm()
header, err := j.EncodeJSON()
if err != nil {
return nil, err
}
//Allocate a buffer long enough for all payload
res := make([]byte, Base64EncodedBufferLen(len(header))+Base64EncodedBufferLen(len(payload))+2)
base64.URLEncoding.Encode(res, header)
lengthHeader := Base64EncodedStrippedLen(len(res))
res[lengthHeader] = '.'
base64.URLEncoding.Encode(res[lengthHeader+1:], payload)
fullPayloadLength := Base64EncodedBufferLen(len(payload)) + 1 + lengthHeader
res[fullPayloadLength] = '.'
if s == nil {
// unprotected jws, not signing it
return res[:fullPayloadLength+1], nil
}
signature, err := s.Sign(res)
if err != nil {
return nil, err
}
signedRes := make([]byte, fullPayloadLength+1+Base64EncodedBufferLen(len(signature)))
copy(signedRes, res[:fullPayloadLength+1])
base64.URLEncoding.Encode(signedRes[fullPayloadLength+1:], signature)
return signedRes[:fullPayloadLength+1+Base64EncodedStrippedLen(len(signature))], nil
}
// isBase64URLEncoding returns if the byte is a character used for
// base64 URL Encoding
func isBase64URLEncoding(b byte) bool {
if b == '-' || b == '_' {
return true
}
if b < '0' || b > 'z' {
return false
}
if b <= '9' {
return true
}
if b < 'A' {
return false
}
if b <= 'Z' {
return true
}
if b < 'a' {
return false
}
return true
}
// DecodeJWS decode an object encoded with the JWS serialized fromat
// as specified by RFX 7515. to avoid an attack where the unprotected
// JOSE header would contain a tempered signing algorithm, the Signer
// should also be specified.
func DecodeJWS(data []byte, v interface{}, s Signer) error {
var headerLength, payloadLength int
for i, c := range data {
if isBase64URLEncoding(c) == true {
continue
}
if c != '.' {
return fmt.Errorf("jws: invalid serialization character %v in '%s'", c, data)
}
if headerLength == 0 {
if i == 0 {
return fmt.Errorf("jws: invalid emtpy JOSE header in %s", data)
}
headerLength = i
continue
}
if payloadLength == 0 {
if i-headerLength == 1 {
return fmt.Errorf("jws: invalid emtpy payload in %s", data)
}
payloadLength = i - headerLength - 1
continue
}
return fmt.Errorf("jws: invalid third '.' in %s", data)
}
signatureLength := len(data) - 2 - headerLength - payloadLength
signature := make([]byte, Base64DecodedLenFromStripped(signatureLength))
signedLength := headerLength + payloadLength + 1
Base64Decode(signature, data[signedLength+1:])
if err := s.Verify(data[:signedLength], signature); err != nil {
return err
}
//decode jose
jose, err := DecodeJOSE(data[:headerLength])
if err != nil {
return err
}
if jose.Algorithm != s.Algorithm() {
return fmt.Errorf("jws: Mismatched signing algorithm got %s, expected %s", jose.Algorithm, s.Algorithm())
}
payload := make([]byte, Base64DecodedLenFromStripped(payloadLength))
Base64Decode(payload, data[headerLength+1:signatureLength])
//data is safe, just need to decode it.
return json.Unmarshal(payload, v)
}

164
jwt/signer.go Normal file
View File

@@ -0,0 +1,164 @@
package jwt
import (
"crypto"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"fmt"
)
type hashLinkError struct {
name string
}
func (e *hashLinkError) Error() string {
return fmt.Sprintf("jws: Hash %s is not linked into binary", e.name)
}
// A Signer is able to crypto-sign a JWS object.
type Signer interface {
// Sign returns the signature of a binary payload. It should
// returns the signature in binary format
Sign([]byte) ([]byte, error)
// Verify check if the signed binary paylod correspond to the
// provided binary siganture. It assumes a binary format
// signature.
Verify(signed, signature []byte) error
// Algorithm is returning a string idnetifying the algorithm used
// as defined in RFC 7518.
Algorithm() string
}
// A HMACSigner is a Signer using HMAC based algorithm
type HMACSigner struct {
name string
hash crypto.Hash
key []byte
}
// NewHMAC256Signer returs a HMACSigner using a SHA256 hash.
func NewHMAC256Signer(key []byte) Signer {
return &HMACSigner{
name: "HS256",
hash: crypto.SHA256,
key: key,
}
}
// NewHMAC384Signer returs a HMACSigner using a SHA384 hash.
func NewHMAC384Signer(key []byte) Signer {
return &HMACSigner{
name: "HS256",
hash: crypto.SHA384,
key: key,
}
}
// NewHMAC512Signer returs a HMACSigner using a SHA512 hash.
func NewHMAC512Signer(key []byte) Signer {
return &HMACSigner{
name: "HS256",
hash: crypto.SHA384,
key: key,
}
}
// Algorithm return 'HSXXX' where XXX is either 256|384|512 depending
// on the SHA Hash used.
func (s *HMACSigner) Algorithm() string {
return s.name
}
// Sign compute the binary signature for data.
func (s *HMACSigner) Sign(data []byte) ([]byte, error) {
if s.hash.Available() == false {
return nil, &hashLinkError{s.name}
}
hasher := hmac.New(s.hash.New, s.key)
hasher.Write(data)
return hasher.Sum(nil), nil
}
// Verify returns an error if the binary provided signature does not
// correspond to signed. It return nil if the signature is correct.
func (s *HMACSigner) Verify(signed, signature []byte) error {
expected, err := s.Sign(signed)
if err != nil {
return err
}
if hmac.Equal(signature, expected) == false {
return fmt.Errorf("jws: Invalid %s signature", s.name)
}
return nil
}
// A RSASigner is a Signer using a PKS 1v15 algorithm.
type RSASigner struct {
name string
hash crypto.Hash
key *rsa.PrivateKey
}
// NewRSA256Signer returns a RSASigner using a SHA 256 hash.
func NewRSA256Signer(key *rsa.PrivateKey) Signer {
return &RSASigner{
name: "RS256",
hash: crypto.SHA256,
key: key,
}
}
// NewRSA384Signer returns a RSASigner using a SHA 384 hash.
func NewRSA384Signer(key *rsa.PrivateKey) Signer {
return &RSASigner{
name: "RS384",
hash: crypto.SHA384,
key: key,
}
}
// NewRSA512Signer returns a RSASigner using a SHA 512 hash.
func NewRSA512Signer(key *rsa.PrivateKey) Signer {
return &RSASigner{
name: "RS512",
hash: crypto.SHA512,
key: key,
}
}
// Sign compute the binary signature for data.
func (s *RSASigner) Sign(data []byte) ([]byte, error) {
if s.hash.Available() == false {
return nil, &hashLinkError{s.name}
}
hasher := s.hash.New()
hasher.Write(data)
res, err := rsa.SignPKCS1v15(rand.Reader, s.key, s.hash, hasher.Sum(nil))
if err != nil {
return nil, err
}
return res, nil
}
// Verify returns an error if the binary provided signature does not
// correspond to signed. It return nil if the signature is correct.
func (s *RSASigner) Verify(signed, signature []byte) error {
if s.hash.Available() == false {
return &hashLinkError{s.name}
}
hasher := s.hash.New()
hasher.Write(signed)
return rsa.VerifyPKCS1v15(s.key.Public().(*rsa.PublicKey), s.hash, hasher.Sum(nil), signature)
}
// Algorithm return 'RSXXX' where XXX is either 256|384|512 depending
// on the SHA Hash used.
func (s *RSASigner) Algorithm() string {
return s.name
}