memos/plugin/idp/oauth2/oauth2_test.go
2023-02-18 10:50:13 +08:00

163 lines
4.2 KiB
Go

package oauth2
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/store"
)
// TestNewIdentityProvider checks that NewIdentityProvider rejects a config
// with an empty required field and that the returned error names that field.
func TestNewIdentityProvider(t *testing.T) {
	// newConfig builds a config whose only varying parts are the fields
	// under validation; client credentials and AuthURL are fixed.
	newConfig := func(tokenURL, userInfoURL, identifier string) *store.IdentityProviderOAuth2Config {
		return &store.IdentityProviderOAuth2Config{
			ClientID:     "test-client-id",
			ClientSecret: "test-client-secret",
			AuthURL:      "",
			TokenURL:     tokenURL,
			UserInfoURL:  userInfoURL,
			FieldMapping: &store.FieldMapping{
				Identifier: identifier,
			},
		}
	}

	cases := []struct {
		name    string
		config  *store.IdentityProviderOAuth2Config
		wantErr string
	}{
		{
			name:    "no tokenUrl",
			config:  newConfig("", "https://example.com/api/user", "login"),
			wantErr: `the field "tokenUrl" is empty but required`,
		},
		{
			name:    "no userInfoUrl",
			config:  newConfig("https://example.com/token", "", "login"),
			wantErr: `the field "userInfoUrl" is empty but required`,
		},
		{
			name:    "no field mapping identifier",
			config:  newConfig("https://example.com/token", "https://example.com/api/user", ""),
			wantErr: `the field "fieldMapping.identifier" is empty but required`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewIdentityProvider(tc.config)
			assert.ErrorContains(t, err, tc.wantErr)
		})
	}
}
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
mux := http.NewServeMux()
var rawIDToken string
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
vals, err := url.ParseQuery(string(body))
require.NoError(t, err)
require.Equal(t, code, vals.Get("code"))
require.Equal(t, "authorization_code", vals.Get("grant_type"))
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(map[string]any{
"access_token": accessToken,
"token_type": "Bearer",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": rawIDToken,
})
require.NoError(t, err)
})
mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err := w.Write(userinfo)
require.NoError(t, err)
})
s := httptest.NewServer(mux)
return s
}
// TestIdentityProvider drives the provider against a local mock server: it
// exchanges an authorization code for an access token, fetches the user
// info with that token, and checks the result against the configured field
// mapping (sub -> Identifier, name -> DisplayName, email -> Email).
func TestIdentityProvider(t *testing.T) {
	const (
		clientID    = "test-client-id"
		authCode    = "test-code"
		accessToken = "test-access-token"
		subject     = "123456789"
		displayName = "John Doe"
		email       = "john.doe@example.com"
		callbackURL = "https://example.com/oauth/callback"
	)

	// Payload the mock userinfo endpoint will return.
	payload, err := json.Marshal(map[string]any{
		"sub":   subject,
		"name":  displayName,
		"email": email,
	})
	require.NoError(t, err)

	server := newMockServer(t, authCode, accessToken, payload)

	config := &store.IdentityProviderOAuth2Config{
		ClientID:     clientID,
		ClientSecret: "test-client-secret",
		TokenURL:     fmt.Sprintf("%s/oauth2/token", server.URL),
		UserInfoURL:  fmt.Sprintf("%s/oauth2/userinfo", server.URL),
		FieldMapping: &store.FieldMapping{
			Identifier:  "sub",
			DisplayName: "name",
			Email:       "email",
		},
	}
	provider, err := NewIdentityProvider(config)
	require.NoError(t, err)

	// Step 1: trade the authorization code for an access token.
	token, err := provider.ExchangeToken(context.Background(), callbackURL, authCode)
	require.NoError(t, err)
	require.Equal(t, accessToken, token)

	// Step 2: resolve the user behind that token.
	got, err := provider.UserInfo(token)
	require.NoError(t, err)

	want := &idp.IdentityProviderUserInfo{
		Identifier:  subject,
		DisplayName: displayName,
		Email:       email,
	}
	assert.Equal(t, want, got)
}