Compare commits
10 Commits
8dccf95fe3
...
91908f15e7
| Author | SHA1 | Date | |
|---|---|---|---|
| 91908f15e7 | |||
| f64f225b18 | |||
| 336987e806 | |||
| 26b7356b79 | |||
| abe3cbab1f | |||
| 00c5853b45 | |||
| 4eb3bdaa13 | |||
| dc107d934d | |||
| 4785e489d5 | |||
| 5a9b19fcb1 |
2
Makefile
2
Makefile
@@ -6,4 +6,4 @@ build:
|
|||||||
check:
|
check:
|
||||||
go test -coverprofile=cover.out -covermode=count
|
go test -coverprofile=cover.out -covermode=count
|
||||||
go vet
|
go vet
|
||||||
# golint
|
golint
|
||||||
|
|||||||
@@ -11,8 +11,13 @@ type narcoKey int
|
|||||||
|
|
||||||
const errorFormatterKey narcoKey = 1
|
const errorFormatterKey narcoKey = 1
|
||||||
|
|
||||||
|
// ErrorFormatter defines a function that can format an HTTP1.1 error
|
||||||
|
// and is but is Context-aware
|
||||||
type ErrorFomatter func(ctx context.Context, rw http.ResponseWriter, message interface{}, status int)
|
type ErrorFomatter func(ctx context.Context, rw http.ResponseWriter, message interface{}, status int)
|
||||||
|
|
||||||
|
// Error is formatting a in a http.Response an error, like
|
||||||
|
// http.Error. However it uses the ErrorFormatter in the current
|
||||||
|
// context to do so, if one is defined with WithErrorFormatter
|
||||||
func Error(ctx context.Context, rw http.ResponseWriter, message interface{}, status int) {
|
func Error(ctx context.Context, rw http.ResponseWriter, message interface{}, status int) {
|
||||||
fmter, ok := ctx.Value(errorFormatterKey).(ErrorFomatter)
|
fmter, ok := ctx.Value(errorFormatterKey).(ErrorFomatter)
|
||||||
if ok == false {
|
if ok == false {
|
||||||
@@ -25,18 +30,25 @@ func Error(ctx context.Context, rw http.ResponseWriter, message interface{}, sta
|
|||||||
|
|
||||||
var narcoDefaultFormatter ErrorFomatter
|
var narcoDefaultFormatter ErrorFomatter
|
||||||
|
|
||||||
|
// TextErrorFormatter is an ErrorFormatter that formats the error by
|
||||||
|
// just putting plain text
|
||||||
func TextErrorFormatter(ctx context.Context, rw http.ResponseWriter, message interface{}, status int) {
|
func TextErrorFormatter(ctx context.Context, rw http.ResponseWriter, message interface{}, status int) {
|
||||||
rw.Header()["Content-Type"] = []string{"text/plain; charset=utf-8"}
|
rw.Header()["Content-Type"] = []string{"text/plain; charset=utf-8"}
|
||||||
rw.WriteHeader(status)
|
rw.WriteHeader(status)
|
||||||
fmt.Fprintf(rw, "%s", message)
|
fmt.Fprintf(rw, "%s", message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BasicHTMLErrorFormatter is an ErrorFormatter that formats the error
|
||||||
|
// in a very basic HTML page (you may want to write your own with a
|
||||||
|
// choosen template.
|
||||||
func BasicHTMLErrorFormatter(ctx context.Context, rw http.ResponseWriter, message interface{}, status int) {
|
func BasicHTMLErrorFormatter(ctx context.Context, rw http.ResponseWriter, message interface{}, status int) {
|
||||||
rw.Header()["Content-Type"] = []string{"text/html; charset=utf-8"}
|
rw.Header()["Content-Type"] = []string{"text/html; charset=utf-8"}
|
||||||
rw.WriteHeader(status)
|
rw.WriteHeader(status)
|
||||||
fmt.Fprintf(rw, `<!doctype html><html><body><h1>Error %d</h1>%s</body></html>`, status, message)
|
fmt.Fprintf(rw, `<!doctype html><html><body><h1>Error %d</h1>%s</body></html>`, status, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithErrorFormatter is returning a context where the user has
|
||||||
|
// defined an ErrorFormatter that Error() should use.
|
||||||
func WithErrorFormatter(ctx context.Context, fmter ErrorFomatter) context.Context {
|
func WithErrorFormatter(ctx context.Context, fmter ErrorFomatter) context.Context {
|
||||||
return context.WithValue(ctx, errorFormatterKey, fmter)
|
return context.WithValue(ctx, errorFormatterKey, fmter)
|
||||||
}
|
}
|
||||||
|
|||||||
80
jwt.go
Normal file
80
jwt.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package narco
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/codemodus/chain"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"ponyo.epfl.ch/gitlab/alexandre.tuleu/narco/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JWT struct {
|
||||||
|
signer jwt.Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
jwtKeyGenerator narcoKey = 2
|
||||||
|
jwtKeyOutput narcoKey = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
func (j *JWT) Wrap() func(chain.Handler) chain.Handler {
|
||||||
|
return func(other chain.Handler) chain.Handler {
|
||||||
|
return chain.HandlerFunc(func(ctx context.Context, rw http.ResponseWriter, req *http.Request) {
|
||||||
|
tokenGenerator, ok := ctx.Value(jwtKeyGenerator).(TokenCreator)
|
||||||
|
|
||||||
|
if tokenGenerator == nil || ok == false {
|
||||||
|
//we did not register a token we ignore all processing
|
||||||
|
other.ServeHTTPContext(ctx, rw, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for the header
|
||||||
|
tokenStr, ok := req.Header["Authorization"]
|
||||||
|
if ok == false {
|
||||||
|
// no Authorization header, we are just not
|
||||||
|
// authentified, we process down, no token added
|
||||||
|
other.ServeHTTPContext(ctx, rw, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokenStr) != 1 || strings.HasPrefix(tokenStr[0], "Bearer ") == false {
|
||||||
|
Error(ctx, rw, fmt.Errorf("Invalid Authorization HTTP Header %v", tokenStr), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tkData := strings.TrimPrefix(tokenStr[0], "Bearer ")
|
||||||
|
|
||||||
|
//parse the desired token, with signature checking
|
||||||
|
token := tokenGenerator()
|
||||||
|
err := jwt.DecodeJWS([]byte(tkData), token, j.signer)
|
||||||
|
if err != nil {
|
||||||
|
Error(ctx, rw, err, http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// we process down the pipeline, with the token added to
|
||||||
|
// context
|
||||||
|
other.ServeHTTPContext(context.WithValue(ctx, jwtKeyOutput, token), rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetJwt(ctx context.Context) interface{} {
|
||||||
|
return ctx.Value(jwtKeyOutput)
|
||||||
|
}
|
||||||
|
|
||||||
|
//A Token Creator should produce a new pointer to a desired token
|
||||||
|
//struct
|
||||||
|
type TokenCreator func() interface{}
|
||||||
|
|
||||||
|
// RegisterJwtType should be used to register the token rtyope
|
||||||
|
// expected in the context. TokenCreator should allocate a new,
|
||||||
|
// unparsed token.
|
||||||
|
func RegisterJwtType(ctx context.Context, generator TokenCreator) context.Context {
|
||||||
|
return context.WithValue(ctx, jwtKeyGenerator, generator)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJWT(s jwt.Signer) *JWT {
|
||||||
|
return &JWT{signer: s}
|
||||||
|
}
|
||||||
1
jwt/.gitignore
vendored
Normal file
1
jwt/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
cover.out
|
||||||
9
jwt/Makefile
Normal file
9
jwt/Makefile
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
all: build check
|
||||||
|
|
||||||
|
build:
|
||||||
|
go build
|
||||||
|
|
||||||
|
check:
|
||||||
|
go test -coverprofile=cover.out -covermode=count
|
||||||
|
go vet
|
||||||
|
golint
|
||||||
63
jwt/base64.go
Normal file
63
jwt/base64.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Base64EncodedBufferLen return the size needed to encode the
|
||||||
|
// Base64Encode (padding '=' would be added, can be stripped with
|
||||||
|
// Base64DecodedStrippedLen
|
||||||
|
func Base64EncodedBufferLen(n int) int {
|
||||||
|
return base64.URLEncoding.EncodedLen(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base64EncodedStrippedLen return the size of the encoded slice with no
|
||||||
|
// padding.
|
||||||
|
func Base64EncodedStrippedLen(n int) int {
|
||||||
|
eN := Base64EncodedBufferLen(n)
|
||||||
|
mod := n % 3
|
||||||
|
if mod == 0 {
|
||||||
|
return eN
|
||||||
|
}
|
||||||
|
return eN - (3 - mod)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base64DecodedStrippedLen returns the size of the data, without 0
|
||||||
|
// bit padding, from the stripped length n.
|
||||||
|
func Base64DecodedStrippedLen(n int) int {
|
||||||
|
return Base64DecodedLenFromStripped(n) + n - ((n+3)/4)*4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base64DecodedLenFromStripped returns the size of the data, with 0
|
||||||
|
// bit padding, from the stripped length n.
|
||||||
|
func Base64DecodedLenFromStripped(n int) int {
|
||||||
|
return base64.URLEncoding.DecodedLen(4 * ((n + 3) / 4))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base64Decode decodes a payload of data as a []byte, from a stripped
|
||||||
|
// '=' encoded string.
|
||||||
|
func Base64Decode(dst, src []byte) error {
|
||||||
|
switch len(src) % 4 {
|
||||||
|
// in this case, we have to copy the src, in order to allow decode part of the input
|
||||||
|
case 2:
|
||||||
|
oldSrc := src
|
||||||
|
src = make([]byte, len(src)+2)
|
||||||
|
copy(src, oldSrc)
|
||||||
|
src[len(src)-2] = '='
|
||||||
|
src[len(src)-1] = '='
|
||||||
|
case 3:
|
||||||
|
oldSrc := src
|
||||||
|
src = make([]byte, len(src)+1)
|
||||||
|
copy(src, oldSrc)
|
||||||
|
src[len(src)-1] = '='
|
||||||
|
case 1:
|
||||||
|
return fmt.Errorf("jwt: Invalid base64 string (length:%d %% 4 == 1): '%s'", len(src), src)
|
||||||
|
}
|
||||||
|
_, err := base64.URLEncoding.Decode(dst, src)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("jwt: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
66
jwt/base64_test.go
Normal file
66
jwt/base64_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "gopkg.in/check.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
TestingT(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Base64Suite struct{}
|
||||||
|
|
||||||
|
var _ = Suite(&Base64Suite{})
|
||||||
|
|
||||||
|
type dataAndEncode struct {
|
||||||
|
data []byte
|
||||||
|
encoded []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Base64Suite) TestEncodeWithNoTrailing(c *C) {
|
||||||
|
data := []dataAndEncode{
|
||||||
|
{[]byte{0, 0, 0}, []byte("AAAA")}, //With trailing Should be AAAA
|
||||||
|
{[]byte{0, 0}, []byte("AAA")}, //With trailing Should be AAA=
|
||||||
|
{[]byte{0}, []byte("AA")}, //With trailing Should be AA==
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, d := range data {
|
||||||
|
res := make([]byte, Base64EncodedBufferLen(len(d.data)))
|
||||||
|
base64.URLEncoding.Encode(res, d.data)
|
||||||
|
c.Check(res[:Base64EncodedStrippedLen(len(d.data))], DeepEquals, d.encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Base64Suite) TestDecodeWithNoTrailing(c *C) {
|
||||||
|
data := map[string][]byte{
|
||||||
|
"AAAA": []byte{0, 0, 0},
|
||||||
|
"AAA": []byte{0, 0},
|
||||||
|
"AA": []byte{0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for encoded, expected := range data {
|
||||||
|
res := make([]byte, Base64DecodedLenFromStripped(len(encoded)))
|
||||||
|
err := Base64Decode(res, []byte(encoded))
|
||||||
|
if c.Check(err, IsNil) == false {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.Check(res[:Base64DecodedStrippedLen(len(encoded))], DeepEquals, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Base64Suite) TestDetectBadFormat(c *C) {
|
||||||
|
data := map[string]string{
|
||||||
|
"A": `jwt: Invalid base64 string \(length:[0-9]+ % 4 == 1\): 'A'`,
|
||||||
|
"ABCD%===": `jwt: illegal base64 data at input byte [0-9]+`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for encoded, errorMatches := range data {
|
||||||
|
res := make([]byte, Base64DecodedLenFromStripped(len(encoded)))
|
||||||
|
err := Base64Decode(res, []byte(encoded))
|
||||||
|
c.Check(err, ErrorMatches, errorMatches)
|
||||||
|
}
|
||||||
|
}
|
||||||
147
jwt/jose.go
Normal file
147
jwt/jose.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JOSE represent a JSON Object Signing and Encryption header as
|
||||||
|
// defined in RFC 7515.
|
||||||
|
type JOSE struct {
|
||||||
|
Algorithm string `json:"alg"`
|
||||||
|
Type string `json:"typ,omitempty"`
|
||||||
|
Content string `json:"cty,omitempty"`
|
||||||
|
Critical []string `json:"crit,omitempty"`
|
||||||
|
JWKSetURL string `json:"jku,omitempty"`
|
||||||
|
JSONWebKey string `json:"jwk,omitempty"`
|
||||||
|
KeyID string `json:"kid,omitempty"`
|
||||||
|
X509URL string `json:"x5u,omitempty"`
|
||||||
|
X509CertificateChain string `json:"x5c,omitempty"`
|
||||||
|
X509ThumbprintSha1 string `json:"x5t,omitempty"`
|
||||||
|
X509ThumbprintSha256 string `json:"x5t#S256,omitempty"`
|
||||||
|
|
||||||
|
//Sets of adfditional headers, private or public
|
||||||
|
AdditionalHeaders map[string]interface{} `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeJSON encodes a JOSE header in JSON format. Not using
|
||||||
|
// MarshallJSON tyo avoid loops
|
||||||
|
func (j *JOSE) EncodeJSON() ([]byte, error) {
|
||||||
|
data, err := json.Marshal(j)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) == 0 || data[0] != '{' || data[len(data)-1] != '}' {
|
||||||
|
return nil, fmt.Errorf("jws: Invalid JSON encoding: %s", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(j.AdditionalHeaders) == 0 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
moreHeader, err := json.Marshal(j.AdditionalHeaders)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(moreHeader) == 0 || moreHeader[0] != '{' || moreHeader[len(moreHeader)-1] != '}' {
|
||||||
|
return nil, fmt.Errorf("jws: Invalid JSON encoding: %s", data)
|
||||||
|
}
|
||||||
|
data[len(data)-1] = ','
|
||||||
|
return append(data, moreHeader[1:]...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeBase64 encodes a JOSE into JSON base64 string. will allocate
|
||||||
|
// buffer.
|
||||||
|
func (j *JOSE) EncodeBase64() ([]byte, error) {
|
||||||
|
dec, err := j.EncodeJSON()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := make([]byte, Base64EncodedBufferLen(len(dec)))
|
||||||
|
base64.URLEncoding.Encode(enc, dec)
|
||||||
|
return enc[:Base64EncodedStrippedLen(len(dec))], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeJOSE a JOSE header from a base64 text as defined in the RFC 7515
|
||||||
|
func DecodeJOSE(data []byte) (*JOSE, error) {
|
||||||
|
bData := make([]byte, Base64DecodedLenFromStripped(len(data)))
|
||||||
|
|
||||||
|
if err := Base64Decode(bData, data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
j := &JOSE{}
|
||||||
|
if err := json.NewDecoder(bytes.NewBuffer(bData)).Decode(j); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return j, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates a JOSE header data. Maybe it should disappear
|
||||||
|
func (j *JOSE) Validate() error {
|
||||||
|
if len(j.Algorithm) == 0 {
|
||||||
|
return fmt.Errorf("jwt: missing 'alg' header")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fieldSetter func(*JOSE, string)
|
||||||
|
|
||||||
|
var fieldSetters = make(map[string]fieldSetter)
|
||||||
|
|
||||||
|
// UnmarshalJSON is here to satisfy interface json.Unmarshaller. We
|
||||||
|
// need to provide ou own unmarshaller for the additional header.
|
||||||
|
func (j *JOSE) UnmarshalJSON(b []byte) error {
|
||||||
|
raw := make(map[string]interface{})
|
||||||
|
|
||||||
|
if err := json.Unmarshal(b, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if j.AdditionalHeaders == nil {
|
||||||
|
j.AdditionalHeaders = make(map[string]interface{}, len(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, data := range raw {
|
||||||
|
if fSetter, ok := fieldSetters[key]; ok == true {
|
||||||
|
strData, ok := data.(string)
|
||||||
|
if ok == false {
|
||||||
|
return &json.UnmarshalTypeError{
|
||||||
|
Value: reflect.TypeOf(data).Kind().String(),
|
||||||
|
Type: reflect.TypeOf(string("")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fSetter(j, strData)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
j.AdditionalHeaders[key] = data
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tJOSE := reflect.TypeOf(JOSE{})
|
||||||
|
for i := 0; i < tJOSE.NumField(); i++ {
|
||||||
|
jField := tJOSE.Field(i)
|
||||||
|
jTag := jField.Tag.Get("json")
|
||||||
|
if len(jTag) == 0 || jTag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
jName := strings.Split(jTag, ",")
|
||||||
|
fieldSetters[jName[0]] = func(j *JOSE, data string) {
|
||||||
|
reflect.ValueOf(j).Elem().FieldByName(jField.Name).SetString(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
103
jwt/jose_test.go
Normal file
103
jwt/jose_test.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
|
. "gopkg.in/check.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JOSESuite struct{}
|
||||||
|
|
||||||
|
var _ = Suite(&JOSESuite{})
|
||||||
|
|
||||||
|
func (s *JOSESuite) TestCheckForDataEncoding(c *C) {
|
||||||
|
data := map[string]string{
|
||||||
|
`"alg":"none"}`: `json: .*`,
|
||||||
|
`{alg:"none"}`: `invalid character '.' looking for beginning of object key string`,
|
||||||
|
`{"alg":1234}`: `json: cannot unmarshal float64 into Go value of type string`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for dec, errMatches := range data {
|
||||||
|
encoded := make([]byte, Base64EncodedBufferLen(len(dec)))
|
||||||
|
base64.URLEncoding.Encode(encoded, []byte(dec))
|
||||||
|
|
||||||
|
j, err := DecodeJOSE(encoded)
|
||||||
|
c.Check(j, IsNil)
|
||||||
|
c.Check(err, ErrorMatches, errMatches)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JOSESuite) TestHandleDuplicateField(c *C) {
|
||||||
|
|
||||||
|
decoded := `{"alg":"foo","alg":"none"}`
|
||||||
|
encoded := make([]byte, Base64EncodedBufferLen(len(decoded)))
|
||||||
|
base64.URLEncoding.Encode(encoded, []byte(decoded))
|
||||||
|
j, err := DecodeJOSE(encoded)
|
||||||
|
c.Assert(j, NotNil)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
c.Assert(j.Algorithm, Equals, "none")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JOSESuite) TestHeaderValidation(c *C) {
|
||||||
|
|
||||||
|
data := map[string]string{
|
||||||
|
`{"alg":""}`: `jwt: missing 'alg' header`,
|
||||||
|
`{}`: `jwt: missing 'alg' header`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for dec, errMatches := range data {
|
||||||
|
encoded := make([]byte, Base64EncodedBufferLen(len(dec)))
|
||||||
|
base64.URLEncoding.Encode(encoded, []byte(dec))
|
||||||
|
j, err := DecodeJOSE(encoded)
|
||||||
|
c.Check(j, NotNil)
|
||||||
|
c.Check(err, IsNil)
|
||||||
|
c.Check(j.Validate(), ErrorMatches, errMatches)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JOSESuite) TestHandlesPrivateHeaders(c *C) {
|
||||||
|
data := []map[string]interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"foo": true,
|
||||||
|
"bar": 12.34,
|
||||||
|
"baz": float64(456), //default of json is to parse float64
|
||||||
|
"blah": "omg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, headers := range data {
|
||||||
|
j := &JOSE{
|
||||||
|
Algorithm: "none",
|
||||||
|
AdditionalHeaders: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
for n, v := range headers {
|
||||||
|
j.AdditionalHeaders[n] = v
|
||||||
|
}
|
||||||
|
enc, err := j.EncodeBase64()
|
||||||
|
if c.Check(err, IsNil) == false {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := DecodeJOSE(enc)
|
||||||
|
c.Check(err, IsNil)
|
||||||
|
c.Assert(res, DeepEquals, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JOSESuite) TestEncodingIsCompact(c *C) {
|
||||||
|
j := &JOSE{
|
||||||
|
Algorithm: "none",
|
||||||
|
Critical: []string{"exp", "foo"},
|
||||||
|
}
|
||||||
|
// it does not print empty fields.
|
||||||
|
data, err := j.EncodeJSON()
|
||||||
|
c.Check(err, IsNil)
|
||||||
|
c.Check(string(data), Equals, `{"alg":"none","crit":["exp","foo"]}`)
|
||||||
|
|
||||||
|
b64, err := j.EncodeBase64()
|
||||||
|
c.Check(err, IsNil)
|
||||||
|
c.Check(len(b64), Equals, 47)
|
||||||
|
}
|
||||||
150
jwt/jws.go
Normal file
150
jwt/jws.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncodeJWS encode a JSON Object with the serialized JWS format as
|
||||||
|
// specified in RFC 7515
|
||||||
|
func EncodeJWS(j *JOSE, v interface{}, s Signer) ([]byte, error) {
|
||||||
|
payload, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s != nil {
|
||||||
|
j.Algorithm = s.Algorithm()
|
||||||
|
} else {
|
||||||
|
j.Algorithm = "none"
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := j.EncodeJSON()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Allocate a buffer long enough for all payload
|
||||||
|
res := make([]byte, Base64EncodedBufferLen(len(header))+Base64EncodedBufferLen(len(payload))+2)
|
||||||
|
base64.URLEncoding.Encode(res, header)
|
||||||
|
lengthHeader := Base64EncodedStrippedLen(len(header))
|
||||||
|
res[lengthHeader] = '.'
|
||||||
|
base64.URLEncoding.Encode(res[lengthHeader+1:], payload)
|
||||||
|
fullPayloadLength := Base64EncodedStrippedLen(len(payload)) + 1 + lengthHeader
|
||||||
|
res[fullPayloadLength] = '.'
|
||||||
|
|
||||||
|
if s == nil {
|
||||||
|
// unprotected jws, not signing it
|
||||||
|
return res[:fullPayloadLength+1], nil
|
||||||
|
}
|
||||||
|
signature, err := s.Sign(res[:fullPayloadLength])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signedRes := make([]byte, fullPayloadLength+1+Base64EncodedBufferLen(len(signature)))
|
||||||
|
copy(signedRes, res[:fullPayloadLength+1])
|
||||||
|
base64.URLEncoding.Encode(signedRes[fullPayloadLength+1:], signature)
|
||||||
|
return signedRes[:fullPayloadLength+1+Base64EncodedStrippedLen(len(signature))], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBase64URLEncoding returns if the byte is a character used for
|
||||||
|
// base64 URL Encoding
|
||||||
|
func isBase64URLEncoding(b byte) bool {
|
||||||
|
if b == '-' || b == '_' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if b < '0' || b > 'z' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if b <= '9' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if b < 'A' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if b <= 'Z' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if b < 'a' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeJWS decode an object encoded with the JWS serialized fromat
|
||||||
|
// as specified by RFC 7515. to avoid an attack where the unprotected
|
||||||
|
// JOSE header would contain a modified alg field, the Signer should
|
||||||
|
// also be specified.
|
||||||
|
func DecodeJWS(data []byte, v interface{}, s Signer) error {
|
||||||
|
var headerLength, payloadLength int
|
||||||
|
for i, c := range data {
|
||||||
|
if isBase64URLEncoding(c) == true {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c != '.' {
|
||||||
|
return fmt.Errorf("jws: invalid serialization character %v in '%s'", c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if headerLength == 0 {
|
||||||
|
if i == 0 {
|
||||||
|
return fmt.Errorf("jws: invalid emtpy JOSE header in %s", data)
|
||||||
|
}
|
||||||
|
headerLength = i
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if payloadLength == 0 {
|
||||||
|
if i-headerLength == 1 {
|
||||||
|
return fmt.Errorf("jws: invalid emtpy payload in %s", data)
|
||||||
|
}
|
||||||
|
payloadLength = i - headerLength - 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("jws: invalid third '.' in %s", data)
|
||||||
|
}
|
||||||
|
signatureLength := len(data) - 2 - headerLength - payloadLength
|
||||||
|
signature := make([]byte, Base64DecodedLenFromStripped(signatureLength))
|
||||||
|
signedLength := headerLength + payloadLength + 1
|
||||||
|
Base64Decode(signature, data[signedLength+1:])
|
||||||
|
signature = signature[:Base64DecodedStrippedLen(signatureLength)]
|
||||||
|
|
||||||
|
if s != nil {
|
||||||
|
if err := s.Verify(data[:signedLength], signature); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if signatureLength != 0 {
|
||||||
|
return fmt.Errorf("jws: Invalid JWS, got a signature, but none expected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//decode jose
|
||||||
|
jose, err := DecodeJOSE(data[:headerLength])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
algo := "none"
|
||||||
|
if s != nil {
|
||||||
|
algo = s.Algorithm()
|
||||||
|
}
|
||||||
|
|
||||||
|
if jose.Algorithm != algo {
|
||||||
|
return fmt.Errorf("jws: Mismatched signing algorithm got %s, expected %s", jose.Algorithm, algo)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, Base64DecodedLenFromStripped(payloadLength))
|
||||||
|
Base64Decode(payload, data[headerLength+1:headerLength+1+payloadLength])
|
||||||
|
payload = payload[:Base64DecodedStrippedLen(payloadLength)]
|
||||||
|
//data is safe, just need to decode it.
|
||||||
|
return json.Unmarshal(payload, v)
|
||||||
|
}
|
||||||
211
jwt/jws_test.go
Normal file
211
jwt/jws_test.go
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "gopkg.in/check.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JWSSuite struct {
|
||||||
|
signers []Signer
|
||||||
|
hmacKey []byte
|
||||||
|
rsaKey *rsa.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Suite(&JWSSuite{})
|
||||||
|
|
||||||
|
func (s *JWSSuite) SetUpSuite(c *C) {
|
||||||
|
var err error
|
||||||
|
s.hmacKey = NewHMACKey(24)
|
||||||
|
|
||||||
|
s.rsaKey, err = CachedRSAkey()
|
||||||
|
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
|
||||||
|
s.signers = []Signer{
|
||||||
|
NewHMAC256Signer(s.hmacKey),
|
||||||
|
NewHMAC384Signer(s.hmacKey),
|
||||||
|
NewHMAC512Signer(s.hmacKey),
|
||||||
|
NewRSA256Signer(s.rsaKey),
|
||||||
|
NewRSA384Signer(s.rsaKey),
|
||||||
|
NewRSA512Signer(s.rsaKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JWSSuite) TestCanEncodeSigned(c *C) {
|
||||||
|
type foo struct {
|
||||||
|
A string
|
||||||
|
B int
|
||||||
|
}
|
||||||
|
|
||||||
|
j := &JOSE{}
|
||||||
|
|
||||||
|
for _, signer := range s.signers {
|
||||||
|
decoded := foo{A: "blah", B: 42}
|
||||||
|
res, err := EncodeJWS(j, decoded, signer)
|
||||||
|
c.Check(string(res), Matches, `\A[0-9A-Za-z_\-]+\.[0-9A-Za-z_\-]+.[0-9A-Za-z_\-]+\z`)
|
||||||
|
if c.Check(err, IsNil) == false {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
redecoded := foo{}
|
||||||
|
err = DecodeJWS(res, &redecoded, signer)
|
||||||
|
if c.Check(err, IsNil, Commentf("Algo is: %s", signer.Algorithm())) == true {
|
||||||
|
c.Check(redecoded, DeepEquals, decoded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JWSSuite) TestCanEncodeUnprotected(c *C) {
|
||||||
|
|
||||||
|
type foo struct {
|
||||||
|
A string
|
||||||
|
B int
|
||||||
|
}
|
||||||
|
|
||||||
|
j := &JOSE{}
|
||||||
|
|
||||||
|
decoded := foo{A: "blah", B: 42}
|
||||||
|
res, err := EncodeJWS(j, decoded, nil)
|
||||||
|
c.Check(string(res), Matches, `\A[0-9A-Za-z_\-]+\.[0-9A-Za-z_\-]+.\z`)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
|
||||||
|
redecoded := foo{}
|
||||||
|
err = DecodeJWS(res, &redecoded, nil)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
c.Check(redecoded, DeepEquals, decoded)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JWSSuite) TestAttackerTriesToForgeAToken(c *C) {
|
||||||
|
type ident struct {
|
||||||
|
Iat int64
|
||||||
|
Adm bool
|
||||||
|
}
|
||||||
|
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
// Server side
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
//we create a token and sign it
|
||||||
|
|
||||||
|
validToken, err := EncodeJWS(&JOSE{}, ident{Iat: time.Now().Unix(), Adm: false}, s.signers[0])
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
// Untrusted side
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// now an attacker takes our valid token, unmarshall it, and create a new one with no signature
|
||||||
|
t := ident{}
|
||||||
|
|
||||||
|
tokenRx := regexp.MustCompile(`\A([a-zA-Z0-9_\-]+)\.([a-zA-Z0-9_\-]+).([a-zA-Z0-9_\-]+)\z`)
|
||||||
|
|
||||||
|
matches := tokenRx.FindSubmatch(validToken)
|
||||||
|
c.Assert(matches, NotNil)
|
||||||
|
|
||||||
|
// extract JOSE
|
||||||
|
j, err := DecodeJOSE(matches[1])
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
|
||||||
|
// extract payload
|
||||||
|
payload := make([]byte, Base64DecodedLenFromStripped(len(matches[2])))
|
||||||
|
Base64Decode(payload, matches[2])
|
||||||
|
err = json.Unmarshal(payload, &t)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
|
||||||
|
//we forge an admin token
|
||||||
|
t.Adm = true
|
||||||
|
|
||||||
|
//we un-sign the token, by using the none algorithm
|
||||||
|
forgedToken, err := EncodeJWS(j, t, nil)
|
||||||
|
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
// Server side
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
//we verify the forgedToken, it should fail
|
||||||
|
received := ident{Adm: false}
|
||||||
|
// please note that we speficy ecplitely the signer, and therefore
|
||||||
|
// the algorithm to use.
|
||||||
|
err = DecodeJWS(forgedToken, &received, s.signers[0])
|
||||||
|
|
||||||
|
//we got a wrong signature for our chosen algorithm
|
||||||
|
c.Assert(err, ErrorMatches, fmt.Sprintf(`jws: .* %s .*`, s.signers[0].Algorithm()))
|
||||||
|
// of course, our token remain Adm:false
|
||||||
|
c.Assert(received.Adm, Equals, false)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JWSSuite) TestAttackerTriesToForgeAHMACToken(c *C) {
|
||||||
|
type ident struct {
|
||||||
|
Iat int64
|
||||||
|
Adm bool
|
||||||
|
}
|
||||||
|
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
// Server side
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
//we create a token and sign it using RSA256
|
||||||
|
serverSigner := NewRSA256Signer(s.rsaKey)
|
||||||
|
validToken, err := EncodeJWS(&JOSE{}, ident{Iat: time.Now().Unix(), Adm: false}, serverSigner)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
// Untrusted side
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// now an attacker takes our valid token, unmarshall it, and create a new one with no signature
|
||||||
|
t := ident{}
|
||||||
|
|
||||||
|
tokenRx := regexp.MustCompile(`\A([a-zA-Z0-9_\-]+)\.([a-zA-Z0-9_\-]+).([a-zA-Z0-9_\-]+)\z`)
|
||||||
|
|
||||||
|
matches := tokenRx.FindSubmatch(validToken)
|
||||||
|
c.Assert(matches, NotNil)
|
||||||
|
|
||||||
|
// extract JOSE
|
||||||
|
j, err := DecodeJOSE(matches[1])
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
|
||||||
|
// extract payload
|
||||||
|
payload := make([]byte, Base64DecodedLenFromStripped(len(matches[2])))
|
||||||
|
Base64Decode(payload, matches[2])
|
||||||
|
err = json.Unmarshal(payload, &t)
|
||||||
|
c.Assert(err, IsNil)
|
||||||
|
|
||||||
|
//we forge an admin token
|
||||||
|
t.Adm = true
|
||||||
|
|
||||||
|
//now here is the trick, we sign an HMAC token using the serverPublicKey
|
||||||
|
|
||||||
|
publicKeyAsByte, err := x509.MarshalPKIXPublicKey(s.rsaKey.Public())
|
||||||
|
c.Assert(err, IsNil, Commentf("While marshalling key"))
|
||||||
|
//we create a new signing algorithm expecting the server will
|
||||||
|
//still use the assymetric key on the HMAC.
|
||||||
|
attackerSigner := NewHMAC256Signer(publicKeyAsByte)
|
||||||
|
//we un-sign the token, by using another symmetric algorithm, and
|
||||||
|
//the publickKey
|
||||||
|
forgedToken, err := EncodeJWS(j, t, attackerSigner)
|
||||||
|
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
// Server side
|
||||||
|
//--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
//we verify the forgedToken, it should fail
|
||||||
|
received := ident{Adm: false}
|
||||||
|
// please note that we speficy ecplitely the signer, and therefore
|
||||||
|
// the algorithm to use.
|
||||||
|
err = DecodeJWS(forgedToken, &received, serverSigner)
|
||||||
|
|
||||||
|
//we got a wrong signature for our chosen algorithm, we still are
|
||||||
|
//using the rsa algorithm
|
||||||
|
c.Assert(err, ErrorMatches, `crypto/rsa: verification error`)
|
||||||
|
// of course, our token remain Adm:false
|
||||||
|
c.Assert(received.Adm, Equals, false)
|
||||||
|
|
||||||
|
}
|
||||||
164
jwt/signer.go
Normal file
164
jwt/signer.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type hashLinkError struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *hashLinkError) Error() string {
|
||||||
|
return fmt.Sprintf("jws: Hash %s is not linked into binary", e.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Signer is able to crypto-sign a JWS object.
|
||||||
|
type Signer interface {
|
||||||
|
// Sign returns the signature of a binary payload. It should
|
||||||
|
// returns the signature in binary format
|
||||||
|
Sign([]byte) ([]byte, error)
|
||||||
|
// Verify check if the signed binary paylod correspond to the
|
||||||
|
// provided binary siganture. It assumes a binary format
|
||||||
|
// signature.
|
||||||
|
Verify(signed, signature []byte) error
|
||||||
|
// Algorithm is returning a string idnetifying the algorithm used
|
||||||
|
// as defined in RFC 7518.
|
||||||
|
Algorithm() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// A HMACSigner is a Signer using HMAC based algorithm
|
||||||
|
type HMACSigner struct {
|
||||||
|
name string
|
||||||
|
hash crypto.Hash
|
||||||
|
key []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHMAC256Signer returs a HMACSigner using a SHA256 hash.
|
||||||
|
func NewHMAC256Signer(key []byte) Signer {
|
||||||
|
return &HMACSigner{
|
||||||
|
name: "HS256",
|
||||||
|
hash: crypto.SHA256,
|
||||||
|
key: key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHMAC384Signer returs a HMACSigner using a SHA384 hash.
|
||||||
|
func NewHMAC384Signer(key []byte) Signer {
|
||||||
|
return &HMACSigner{
|
||||||
|
name: "HS384",
|
||||||
|
hash: crypto.SHA384,
|
||||||
|
key: key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHMAC512Signer returs a HMACSigner using a SHA512 hash.
|
||||||
|
func NewHMAC512Signer(key []byte) Signer {
|
||||||
|
return &HMACSigner{
|
||||||
|
name: "HS512",
|
||||||
|
hash: crypto.SHA512,
|
||||||
|
key: key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm return 'HSXXX' where XXX is either 256|384|512 depending
|
||||||
|
// on the SHA Hash used.
|
||||||
|
func (s *HMACSigner) Algorithm() string {
|
||||||
|
return s.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign compute the binary signature for data.
|
||||||
|
func (s *HMACSigner) Sign(data []byte) ([]byte, error) {
|
||||||
|
if s.hash.Available() == false {
|
||||||
|
return nil, &hashLinkError{s.name}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher := hmac.New(s.hash.New, s.key)
|
||||||
|
hasher.Write(data)
|
||||||
|
return hasher.Sum(nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify returns an error if the binary provided signature does not
|
||||||
|
// correspond to signed. It return nil if the signature is correct.
|
||||||
|
func (s *HMACSigner) Verify(signed, signature []byte) error {
|
||||||
|
expected, err := s.Sign(signed)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if hmac.Equal(signature, expected) == false {
|
||||||
|
return fmt.Errorf("jws: Invalid %s signature", s.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A RSASigner is a Signer using a PKS 1v15 algorithm.
|
||||||
|
type RSASigner struct {
|
||||||
|
name string
|
||||||
|
hash crypto.Hash
|
||||||
|
key *rsa.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRSA256Signer returns a RSASigner using a SHA 256 hash.
|
||||||
|
func NewRSA256Signer(key *rsa.PrivateKey) Signer {
|
||||||
|
return &RSASigner{
|
||||||
|
name: "RS256",
|
||||||
|
hash: crypto.SHA256,
|
||||||
|
key: key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRSA384Signer returns a RSASigner using a SHA 384 hash.
|
||||||
|
func NewRSA384Signer(key *rsa.PrivateKey) Signer {
|
||||||
|
return &RSASigner{
|
||||||
|
name: "RS384",
|
||||||
|
hash: crypto.SHA384,
|
||||||
|
key: key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRSA512Signer returns a RSASigner using a SHA 512 hash.
|
||||||
|
func NewRSA512Signer(key *rsa.PrivateKey) Signer {
|
||||||
|
return &RSASigner{
|
||||||
|
name: "RS512",
|
||||||
|
hash: crypto.SHA512,
|
||||||
|
key: key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign compute the binary signature for data.
|
||||||
|
func (s *RSASigner) Sign(data []byte) ([]byte, error) {
|
||||||
|
if s.hash.Available() == false {
|
||||||
|
return nil, &hashLinkError{s.name}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher := s.hash.New()
|
||||||
|
hasher.Write(data)
|
||||||
|
res, err := rsa.SignPKCS1v15(rand.Reader, s.key, s.hash, hasher.Sum(nil))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify returns an error if the binary provided signature does not
|
||||||
|
// correspond to signed. It return nil if the signature is correct.
|
||||||
|
func (s *RSASigner) Verify(signed, signature []byte) error {
|
||||||
|
if s.hash.Available() == false {
|
||||||
|
return &hashLinkError{s.name}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher := s.hash.New()
|
||||||
|
hasher.Write(signed)
|
||||||
|
|
||||||
|
return rsa.VerifyPKCS1v15(s.key.Public().(*rsa.PublicKey), s.hash, hasher.Sum(nil), signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm return 'RSXXX' where XXX is either 256|384|512 depending
|
||||||
|
// on the SHA Hash used.
|
||||||
|
func (s *RSASigner) Algorithm() string {
|
||||||
|
return s.name
|
||||||
|
}
|
||||||
55
jwt/signer_test.go
Normal file
55
jwt/signer_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewHMACKey(n int) []byte {
|
||||||
|
res := make([]byte, n)
|
||||||
|
rand.Reader.Read(res)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func CachedRSAkey() (*rsa.PrivateKey, error) {
|
||||||
|
keyPath := filepath.Join(os.TempDir(), "narco-jwt-test.key")
|
||||||
|
f, err := os.Open(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) == false {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// generate a key
|
||||||
|
log.Printf("Generating a new key")
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Could not cache the generated key: %s", err)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data := x509.MarshalPKCS1PrivateKey(key)
|
||||||
|
_, err = f.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Could not cache the generated key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keydata, err := ioutil.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not read %s: %s", keyPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return x509.ParsePKCS1PrivateKey(keydata)
|
||||||
|
}
|
||||||
15
logger.go
15
logger.go
@@ -23,21 +23,6 @@ func NewLogger() *Logger {
|
|||||||
return &Logger{log.New(os.Stdout, "[narco] ", log.LstdFlags)}
|
return &Logger{log.New(os.Stdout, "[narco] ", log.LstdFlags)}
|
||||||
}
|
}
|
||||||
|
|
||||||
type FooResponseWriter struct {
|
|
||||||
header http.Header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *FooResponseWriter) Header() http.Header {
|
|
||||||
return f.header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *FooResponseWriter) WriteHeader(status int) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *FooResponseWriter) Write(b []byte) (int, error) {
|
|
||||||
return len(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Wrap() func(chain.Handler) chain.Handler {
|
func (l *Logger) Wrap() func(chain.Handler) chain.Handler {
|
||||||
return func(other chain.Handler) chain.Handler {
|
return func(other chain.Handler) chain.Handler {
|
||||||
return chain.HandlerFunc(func(ctx context.Context, h http.ResponseWriter, r *http.Request) {
|
return chain.HandlerFunc(func(ctx context.Context, h http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
73
recovery_test.go
Normal file
73
recovery_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package narco
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
. "gopkg.in/check.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RecovererSuite struct {
|
||||||
|
*NarcoSuite
|
||||||
|
|
||||||
|
recov *Recoverer
|
||||||
|
buffer bytes.Buffer
|
||||||
|
rec *httptest.ResponseRecorder
|
||||||
|
req *http.Request
|
||||||
|
URL string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Suite(&RecovererSuite{NarcoSuite: NewNarcoSuite()})
|
||||||
|
|
||||||
|
func (s *RecovererSuite) SetUpSuite(c *C) {
|
||||||
|
|
||||||
|
s.recov = NewRecoverer()
|
||||||
|
s.recov.Logger = log.New(&s.buffer, "[narco] ", log.LstdFlags)
|
||||||
|
s.chain = s.chain.Append(s.recov.Wrap())
|
||||||
|
|
||||||
|
c.Assert(s.chain, NotNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RecovererSuite) SetUpTest(c *C) {
|
||||||
|
s.buffer.Reset()
|
||||||
|
s.rec = httptest.NewRecorder()
|
||||||
|
var err error
|
||||||
|
s.URL = "http://" + httptest.DefaultRemoteAddr + "/"
|
||||||
|
s.req, err = http.NewRequest("GET", s.URL, nil)
|
||||||
|
c.Assert(err, IsNil, Commentf("Unexpected error :%s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RecovererSuite) TestRecoversFromPanic(c *C) {
|
||||||
|
// this should not panic
|
||||||
|
defer func() {
|
||||||
|
err := recover()
|
||||||
|
c.Assert(err, IsNil, Commentf("It should not have panic: %s", err))
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.ServeHTTP(s.rec, s.req, func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
panic("foo")
|
||||||
|
})
|
||||||
|
|
||||||
|
lines := strings.Split(s.buffer.String(), "\n")
|
||||||
|
|
||||||
|
c.Assert(lines[0], Matches, `\[narco\] .* `+"PANIC: foo")
|
||||||
|
for i, l := range lines {
|
||||||
|
if i == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
matchString := "foo"
|
||||||
|
if i == 1 {
|
||||||
|
matchString = `goroutine [0-9]+ (\[running\])?:`
|
||||||
|
} else if i%2 == 0 {
|
||||||
|
//either function call, empty or created by line
|
||||||
|
matchString = `(.*\(.*\)|created by .*|)`
|
||||||
|
} else {
|
||||||
|
//function location, or <autogenerated>
|
||||||
|
matchString = `\t(.*\.go|\<autogenerated\>):[0-9]+ .*`
|
||||||
|
}
|
||||||
|
c.Assert(l, Matches, matchString, Commentf("On line %d", i))
|
||||||
|
}
|
||||||
|
}
|
||||||
19
router.go
Normal file
19
router.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package narco
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/codemodus/chain"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandlerFunc func(ctx context.Context, rw http.ResponseWriter, req *http.Request, ps httprouter.Params)
|
||||||
|
|
||||||
|
func EndChain(chain chain.Chain, h HandlerFunc) httprouter.Handle {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request, ps httprouter.Params) {
|
||||||
|
chain.EndFn(func(ctx context.Context, rw_ http.ResponseWriter, req_ *http.Request) {
|
||||||
|
h(ctx, rw_, req_, ps)
|
||||||
|
}).ServeHTTP(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user