From 69726c3925390e71145c21953a1c34e6ab17b36f Mon Sep 17 00:00:00 2001 From: boojack Date: Sat, 18 Feb 2023 10:50:13 +0800 Subject: [PATCH] feat: implement oauth2 plugin (#1110) --- go.mod | 10 +- go.sum | 32 ++++-- plugin/idp/idp.go | 7 ++ plugin/idp/oauth2/oauth2.go | 115 ++++++++++++++++++++++ plugin/idp/oauth2/oauth2_test.go | 163 +++++++++++++++++++++++++++++++ 5 files changed, 318 insertions(+), 9 deletions(-) create mode 100644 plugin/idp/idp.go create mode 100644 plugin/idp/oauth2/oauth2.go create mode 100644 plugin/idp/oauth2/oauth2_test.go diff --git a/go.mod b/go.mod index 047e825b..8f9cf246 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require github.com/google/uuid v1.3.0 require ( golang.org/x/crypto v0.1.0 - golang.org/x/net v0.1.0 + golang.org/x/net v0.6.0 ) require github.com/labstack/echo/v4 v4.9.0 @@ -37,6 +37,7 @@ require ( github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/golang/protobuf v1.5.2 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect @@ -51,9 +52,11 @@ require ( github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.6.0 // indirect - golang.org/x/sys v0.1.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect + google.golang.org/appengine v1.6.7 // indirect + google.golang.org/protobuf v1.28.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) @@ -70,4 +73,5 @@ require ( go.uber.org/zap v1.24.0 golang.org/x/exp v0.0.0-20230111222715-75897c7a292a golang.org/x/mod v0.6.0 + golang.org/x/oauth2 v0.5.0 ) diff --git a/go.sum b/go.sum index bbc17136..b10c1e4f 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -115,23 +120,38 @@ go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/exp v0.0.0-20230111222715-75897c7a292a h1:/YWeLOBWYV5WAQORVPkZF3Pq9IppkcT72GKnWjNf5W8= golang.org/x/exp v0.0.0-20230111222715-75897c7a292a/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.6.0 h1:L4ZwwTvKW9gr0ZMS1yrHD9GZhIuVjOBBnaKH+SPQK0Q= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/oauth2 v0.5.0 h1:HuArIo48skDwlrvM3sEdHXElYslAMsf3KwRkkW4MC4s= +golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 h1:ftMN5LMiBFjbzleLqtoBZk7KdJwhuybIU+FckUHgoyQ= golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/plugin/idp/idp.go b/plugin/idp/idp.go new file mode 100644 index 00000000..c83bccfd --- /dev/null +++ b/plugin/idp/idp.go @@ -0,0 +1,7 @@ +package idp + +type IdentityProviderUserInfo struct { + Identifier string + DisplayName string + Email string +} diff --git a/plugin/idp/oauth2/oauth2.go b/plugin/idp/oauth2/oauth2.go new file mode 100644 index 00000000..ac31a7b1 --- /dev/null +++ b/plugin/idp/oauth2/oauth2.go @@ -0,0 +1,115 @@ +// Package oauth2 is the plugin for OAuth2 Identity Provider. +package oauth2 + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/pkg/errors" + "github.com/usememos/memos/plugin/idp" + "github.com/usememos/memos/store" + "golang.org/x/oauth2" +) + +// IdentityProvider represents an OAuth2 Identity Provider. +type IdentityProvider struct { + config *store.IdentityProviderOAuth2Config +} + +// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration. +func NewIdentityProvider(config *store.IdentityProviderOAuth2Config) (*IdentityProvider, error) { + for v, field := range map[string]string{ + config.ClientID: "clientId", + config.ClientSecret: "clientSecret", + config.TokenURL: "tokenUrl", + config.UserInfoURL: "userInfoUrl", + config.FieldMapping.Identifier: "fieldMapping.identifier", + } { + if v == "" { + return nil, errors.Errorf(`the field "%s" is empty but required`, field) + } + } + + return &IdentityProvider{ + config: config, + }, nil +} + +// ExchangeToken returns the exchanged OAuth2 token using the given authorization code. +func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) { + conf := &oauth2.Config{ + ClientID: p.config.ClientID, + ClientSecret: p.config.ClientSecret, + RedirectURL: redirectURL, + Scopes: p.config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: p.config.AuthURL, + TokenURL: p.config.TokenURL, + AuthStyle: oauth2.AuthStyleInParams, + }, + } + + token, err := conf.Exchange(ctx, code) + if err != nil { + return "", errors.Wrap(err, "failed to exchange access token") + } + + accessToken, ok := token.Extra("access_token").(string) + if !ok { + return "", errors.New(`missing "access_token" from authorization response`) + } + + return accessToken, nil +} + +// UserInfo returns the parsed user information using the given OAuth2 token. +func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) { + client := &http.Client{} + req, err := http.NewRequest(http.MethodGet, p.config.UserInfoURL, nil) + if err != nil { + return nil, errors.Wrap(err, "failed to new http request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + resp, err := client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to get user information") + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + var claims map[string]any + err = json.Unmarshal(body, &claims) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal response body") + } + + userInfo := &idp.IdentityProviderUserInfo{} + if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok { + userInfo.Identifier = v + } + if userInfo.Identifier == "" { + return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier) + } + + // Best effort to map optional fields + if p.config.FieldMapping.DisplayName != "" { + if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok { + userInfo.DisplayName = v + } + } + if userInfo.DisplayName == "" { + userInfo.DisplayName = userInfo.Identifier + } + if p.config.FieldMapping.Email != "" { + if v, ok := claims[p.config.FieldMapping.Email].(string); ok { + userInfo.Email = v + } + } + return userInfo, nil +} diff --git a/plugin/idp/oauth2/oauth2_test.go b/plugin/idp/oauth2/oauth2_test.go new file mode 100644 index 00000000..e14e6168 --- /dev/null +++ b/plugin/idp/oauth2/oauth2_test.go @@ -0,0 +1,163 @@ +package oauth2 + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/usememos/memos/plugin/idp" + "github.com/usememos/memos/store" +) + +func TestNewIdentityProvider(t *testing.T) { + tests := []struct { + name string + config *store.IdentityProviderOAuth2Config + containsErr string + }{ + { + name: "no tokenUrl", + config: &store.IdentityProviderOAuth2Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + AuthURL: "", + TokenURL: "", + UserInfoURL: "https://example.com/api/user", + FieldMapping: &store.FieldMapping{ + Identifier: "login", + }, + }, + containsErr: `the field "tokenUrl" is empty but required`, + }, + { + name: "no userInfoUrl", + config: &store.IdentityProviderOAuth2Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + AuthURL: "", + TokenURL: "https://example.com/token", + UserInfoURL: "", + FieldMapping: &store.FieldMapping{ + Identifier: "login", + }, + }, + containsErr: `the field "userInfoUrl" is empty but required`, + }, + { + name: "no field mapping identifier", + config: &store.IdentityProviderOAuth2Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + AuthURL: "", + TokenURL: "https://example.com/token", + UserInfoURL: "https://example.com/api/user", + FieldMapping: &store.FieldMapping{ + Identifier: "", + }, + }, + containsErr: `the field "fieldMapping.identifier" is empty but required`, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := NewIdentityProvider(test.config) + assert.ErrorContains(t, err, test.containsErr) + }) + } +} + +func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server { + mux := http.NewServeMux() + + var rawIDToken string + mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + vals, err := url.ParseQuery(string(body)) + require.NoError(t, err) + + require.Equal(t, code, vals.Get("code")) + require.Equal(t, "authorization_code", vals.Get("grant_type")) + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(map[string]any{ + "access_token": accessToken, + "token_type": "Bearer", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": rawIDToken, + }) + require.NoError(t, err) + }) + mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, err := w.Write(userinfo) + require.NoError(t, err) + }) + + s := httptest.NewServer(mux) + + return s +} + +func TestIdentityProvider(t *testing.T) { + ctx := context.Background() + + const ( + testClientID = "test-client-id" + testCode = "test-code" + testAccessToken = "test-access-token" + testSubject = "123456789" + testName = "John Doe" + testEmail = "john.doe@example.com" + ) + userInfo, err := json.Marshal( + map[string]any{ + "sub": testSubject, + "name": testName, + "email": testEmail, + }, + ) + require.NoError(t, err) + + s := newMockServer(t, testCode, testAccessToken, userInfo) + + oauth2, err := NewIdentityProvider( + &store.IdentityProviderOAuth2Config{ + ClientID: testClientID, + ClientSecret: "test-client-secret", + TokenURL: fmt.Sprintf("%s/oauth2/token", s.URL), + UserInfoURL: fmt.Sprintf("%s/oauth2/userinfo", s.URL), + FieldMapping: &store.FieldMapping{ + Identifier: "sub", + DisplayName: "name", + Email: "email", + }, + }, + ) + require.NoError(t, err) + + redirectURL := "https://example.com/oauth/callback" + oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode) + require.NoError(t, err) + require.Equal(t, testAccessToken, oauthToken) + + userInfoResult, err := oauth2.UserInfo(oauthToken) + require.NoError(t, err) + + wantUserInfo := &idp.IdentityProviderUserInfo{ + Identifier: testSubject, + DisplayName: testName, + Email: testEmail, + } + assert.Equal(t, wantUserInfo, userInfoResult) +}