package jwt

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
)

// EncodeJWS serializes v as the payload of a JWS in the compact
// serialization described by RFC 7515.
//
// j.Algorithm is overwritten before the header is encoded: it is set to
// s.Algorithm() when a Signer is given and to "none" otherwise. With a nil
// Signer the result is an unprotected JWS that ends right after the second
// '.' and carries no signature. Errors from json.Marshal, j.EncodeJSON and
// s.Sign are returned unchanged.
func EncodeJWS(j *JOSE, v interface{}, s Signer) ([]byte, error) {
	body, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}

	// The header has to advertise the algorithm that is actually applied.
	if s == nil {
		j.Algorithm = "none"
	} else {
		j.Algorithm = s.Algorithm()
	}
	head, err := j.EncodeJSON()
	if err != nil {
		return nil, err
	}

	// Scratch buffer for "header.payload.": both encoded segments plus the
	// two separators.
	out := make([]byte, Base64EncodedBufferLen(len(head))+Base64EncodedBufferLen(len(body))+2)

	// Each separator is written at the stripped length of the segment before
	// it, presumably overwriting any '=' padding that base64.URLEncoding
	// emitted — confirm against Base64EncodedStrippedLen.
	base64.URLEncoding.Encode(out, head)
	firstDot := Base64EncodedStrippedLen(len(head))
	out[firstDot] = '.'

	base64.URLEncoding.Encode(out[firstDot+1:], body)
	secondDot := Base64EncodedStrippedLen(len(body)) + 1 + firstDot
	out[secondDot] = '.'

	sigStart := secondDot + 1
	if s == nil {
		// Unprotected JWS: nothing follows the second separator.
		return out[:sigStart], nil
	}

	// The signature covers "header.payload" without the trailing separator.
	sig, err := s.Sign(out[:secondDot])
	if err != nil {
		return nil, err
	}

	// A second buffer is needed because the first one has no room reserved
	// for the encoded signature.
	signed := make([]byte, sigStart+Base64EncodedBufferLen(len(sig)))
	copy(signed, out[:sigStart])
	base64.URLEncoding.Encode(signed[sigStart:], sig)
	return signed[:sigStart+Base64EncodedStrippedLen(len(sig))], nil
}

// isBase64URLEncoding reports whether b belongs to the base64 URL
// alphabet: ASCII letters, digits, '-' and '_'.
func isBase64URLEncoding(b byte) bool {
	switch {
	case b >= 'a' && b <= 'z', b >= 'A' && b <= 'Z', b >= '0' && b <= '9':
		return true
	case b == '-', b == '_':
		return true
	default:
		return false
	}
}

// DecodeJWS decodes an object encoded in the JWS serialized format
// specified by RFC 7515. To avoid an attack where the unprotected
// JOSE header would contain a modified alg field, the Signer should
// also be specified.
func DecodeJWS(data []byte, v interface{}, s Signer) error { var headerLength, payloadLength int for i, c := range data { if isBase64URLEncoding(c) == true { continue } if c != '.' { return fmt.Errorf("jws: invalid serialization character %v in '%s'", c, data) } if headerLength == 0 { if i == 0 { return fmt.Errorf("jws: invalid emtpy JOSE header in %s", data) } headerLength = i continue } if payloadLength == 0 { if i-headerLength == 1 { return fmt.Errorf("jws: invalid emtpy payload in %s", data) } payloadLength = i - headerLength - 1 continue } return fmt.Errorf("jws: invalid third '.' in %s", data) } signatureLength := len(data) - 2 - headerLength - payloadLength signature := make([]byte, Base64DecodedLenFromStripped(signatureLength)) signedLength := headerLength + payloadLength + 1 Base64Decode(signature, data[signedLength+1:]) signature = signature[:Base64DecodedStrippedLen(signatureLength)] if s != nil { if err := s.Verify(data[:signedLength], signature); err != nil { return err } } else { if signatureLength != 0 { return fmt.Errorf("jws: Invalid JWS, got a signature, but none expected") } } //decode jose jose, err := DecodeJOSE(data[:headerLength]) if err != nil { return err } algo := "none" if s != nil { algo = s.Algorithm() } if jose.Algorithm != algo { return fmt.Errorf("jws: Mismatched signing algorithm got %s, expected %s", jose.Algorithm, algo) } payload := make([]byte, Base64DecodedLenFromStripped(payloadLength)) Base64Decode(payload, data[headerLength+1:headerLength+1+payloadLength]) payload = payload[:Base64DecodedStrippedLen(payloadLength)] //data is safe, just need to decode it. return json.Unmarshal(payload, v) }