mirror of
https://github.com/usememos/memos.git
synced 2025-01-08 21:33:38 +08:00
117 lines
3.4 KiB
Go
117 lines
3.4 KiB
Go
// Package oauth2 is the plugin for OAuth2 Identity Provider.
|
|
package oauth2
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
|
|
"github.com/pkg/errors"
|
|
"golang.org/x/oauth2"
|
|
|
|
"github.com/usememos/memos/plugin/idp"
|
|
storepb "github.com/usememos/memos/proto/gen/store"
|
|
)
|
|
|
|
// IdentityProvider represents an OAuth2 Identity Provider.
|
|
type IdentityProvider struct {
|
|
config *storepb.OAuth2Config
|
|
}
|
|
|
|
// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration.
|
|
func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error) {
|
|
for v, field := range map[string]string{
|
|
config.ClientId: "clientId",
|
|
config.ClientSecret: "clientSecret",
|
|
config.TokenUrl: "tokenUrl",
|
|
config.UserInfoUrl: "userInfoUrl",
|
|
config.FieldMapping.Identifier: "fieldMapping.identifier",
|
|
} {
|
|
if v == "" {
|
|
return nil, errors.Errorf(`the field "%s" is empty but required`, field)
|
|
}
|
|
}
|
|
|
|
return &IdentityProvider{
|
|
config: config,
|
|
}, nil
|
|
}
|
|
|
|
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
|
|
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) {
|
|
conf := &oauth2.Config{
|
|
ClientID: p.config.ClientId,
|
|
ClientSecret: p.config.ClientSecret,
|
|
RedirectURL: redirectURL,
|
|
Scopes: p.config.Scopes,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: p.config.AuthUrl,
|
|
TokenURL: p.config.TokenUrl,
|
|
AuthStyle: oauth2.AuthStyleInParams,
|
|
},
|
|
}
|
|
|
|
token, err := conf.Exchange(ctx, code)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, "failed to exchange access token")
|
|
}
|
|
|
|
accessToken, ok := token.Extra("access_token").(string)
|
|
if !ok {
|
|
return "", errors.New(`missing "access_token" from authorization response`)
|
|
}
|
|
|
|
return accessToken, nil
|
|
}
|
|
|
|
// UserInfo returns the parsed user information using the given OAuth2 token.
|
|
func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) {
|
|
client := &http.Client{}
|
|
req, err := http.NewRequest(http.MethodGet, p.config.UserInfoUrl, nil)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to new http request")
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get user information")
|
|
}
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to read response body")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var claims map[string]any
|
|
err = json.Unmarshal(body, &claims)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
|
}
|
|
|
|
userInfo := &idp.IdentityProviderUserInfo{}
|
|
if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok {
|
|
userInfo.Identifier = v
|
|
}
|
|
if userInfo.Identifier == "" {
|
|
return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier)
|
|
}
|
|
|
|
// Best effort to map optional fields
|
|
if p.config.FieldMapping.DisplayName != "" {
|
|
if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok {
|
|
userInfo.DisplayName = v
|
|
}
|
|
}
|
|
if userInfo.DisplayName == "" {
|
|
userInfo.DisplayName = userInfo.Identifier
|
|
}
|
|
if p.config.FieldMapping.Email != "" {
|
|
if v, ok := claims[p.config.FieldMapping.Email].(string); ok {
|
|
userInfo.Email = v
|
|
}
|
|
}
|
|
return userInfo, nil
|
|
}
|