mirror of
https://github.com/usememos/memos.git
synced 2025-01-07 20:58:19 +08:00
163 lines
4.2 KiB
Go
163 lines
4.2 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/usememos/memos/plugin/idp"
|
|
"github.com/usememos/memos/store"
|
|
)
|
|
|
|
func TestNewIdentityProvider(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *store.IdentityProviderOAuth2Config
|
|
containsErr string
|
|
}{
|
|
{
|
|
name: "no tokenUrl",
|
|
config: &store.IdentityProviderOAuth2Config{
|
|
ClientID: "test-client-id",
|
|
ClientSecret: "test-client-secret",
|
|
AuthURL: "",
|
|
TokenURL: "",
|
|
UserInfoURL: "https://example.com/api/user",
|
|
FieldMapping: &store.FieldMapping{
|
|
Identifier: "login",
|
|
},
|
|
},
|
|
containsErr: `the field "tokenUrl" is empty but required`,
|
|
},
|
|
{
|
|
name: "no userInfoUrl",
|
|
config: &store.IdentityProviderOAuth2Config{
|
|
ClientID: "test-client-id",
|
|
ClientSecret: "test-client-secret",
|
|
AuthURL: "",
|
|
TokenURL: "https://example.com/token",
|
|
UserInfoURL: "",
|
|
FieldMapping: &store.FieldMapping{
|
|
Identifier: "login",
|
|
},
|
|
},
|
|
containsErr: `the field "userInfoUrl" is empty but required`,
|
|
},
|
|
{
|
|
name: "no field mapping identifier",
|
|
config: &store.IdentityProviderOAuth2Config{
|
|
ClientID: "test-client-id",
|
|
ClientSecret: "test-client-secret",
|
|
AuthURL: "",
|
|
TokenURL: "https://example.com/token",
|
|
UserInfoURL: "https://example.com/api/user",
|
|
FieldMapping: &store.FieldMapping{
|
|
Identifier: "",
|
|
},
|
|
},
|
|
containsErr: `the field "fieldMapping.identifier" is empty but required`,
|
|
},
|
|
}
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
_, err := NewIdentityProvider(test.config)
|
|
assert.ErrorContains(t, err, test.containsErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
// newMockServer starts an httptest server that imitates the two OAuth2
// endpoints the identity provider talks to:
//
//   - POST /oauth2/token expects a form-encoded body carrying the given code
//     and grant_type=authorization_code, and replies with a JSON token
//     response whose access_token is accessToken.
//   - /oauth2/userinfo replies with the raw userinfo bytes as JSON.
//
// The handlers run on server goroutines, not the test goroutine, so they must
// not call t.FailNow (or testify's require). Problems are reported with
// t.Errorf, and the handler replies with an HTTP error so the client side
// fails as well. The server is closed automatically when the test finishes.
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
	t.Helper()

	mux := http.NewServeMux()

	mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("token endpoint: got method %q, want %q", r.Method, http.MethodPost)
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return
		}

		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("token endpoint: reading request body: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		vals, err := url.ParseQuery(string(body))
		if err != nil {
			t.Errorf("token endpoint: parsing request body %q: %v", body, err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}

		if got := vals.Get("code"); got != code {
			t.Errorf("token endpoint: got code %q, want %q", got, code)
			http.Error(w, "invalid code", http.StatusBadRequest)
			return
		}
		if got := vals.Get("grant_type"); got != "authorization_code" {
			t.Errorf("token endpoint: got grant_type %q, want %q", got, "authorization_code")
			http.Error(w, "unsupported grant_type", http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]any{
			"access_token":  accessToken,
			"token_type":    "Bearer",
			"refresh_token": "test-refresh-token",
			"expires_in":    3600,
		}); err != nil {
			t.Errorf("token endpoint: writing response: %v", err)
		}
	})

	mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if _, err := w.Write(userinfo); err != nil {
			t.Errorf("userinfo endpoint: writing response: %v", err)
		}
	})

	s := httptest.NewServer(mux)
	// Release the listener and handler goroutines once the test is done.
	t.Cleanup(s.Close)
	return s
}
|
|
|
|
func TestIdentityProvider(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
const (
|
|
testClientID = "test-client-id"
|
|
testCode = "test-code"
|
|
testAccessToken = "test-access-token"
|
|
testSubject = "123456789"
|
|
testName = "John Doe"
|
|
testEmail = "john.doe@example.com"
|
|
)
|
|
userInfo, err := json.Marshal(
|
|
map[string]any{
|
|
"sub": testSubject,
|
|
"name": testName,
|
|
"email": testEmail,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
s := newMockServer(t, testCode, testAccessToken, userInfo)
|
|
|
|
oauth2, err := NewIdentityProvider(
|
|
&store.IdentityProviderOAuth2Config{
|
|
ClientID: testClientID,
|
|
ClientSecret: "test-client-secret",
|
|
TokenURL: fmt.Sprintf("%s/oauth2/token", s.URL),
|
|
UserInfoURL: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
|
|
FieldMapping: &store.FieldMapping{
|
|
Identifier: "sub",
|
|
DisplayName: "name",
|
|
Email: "email",
|
|
},
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
redirectURL := "https://example.com/oauth/callback"
|
|
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
|
|
require.NoError(t, err)
|
|
require.Equal(t, testAccessToken, oauthToken)
|
|
|
|
userInfoResult, err := oauth2.UserInfo(oauthToken)
|
|
require.NoError(t, err)
|
|
|
|
wantUserInfo := &idp.IdentityProviderUserInfo{
|
|
Identifier: testSubject,
|
|
DisplayName: testName,
|
|
Email: testEmail,
|
|
}
|
|
assert.Equal(t, wantUserInfo, userInfoResult)
|
|
}
|