From 852903bdbdf57d56ac7afda55d6deea49f3d98d1 Mon Sep 17 00:00:00 2001 From: boojack Date: Sat, 18 Feb 2023 18:31:03 +0800 Subject: [PATCH] fix: idp config definition (#1115) fix: idp definition --- api/idp.go | 4 +- server/idp.go | 101 +++++++++++++++++++++++++++++++++++++------------- store/idp.go | 20 +++++++--- 3 files changed, 94 insertions(+), 31 deletions(-) diff --git a/api/idp.go b/api/idp.go index 8e7753db..bfffbfcf 100644 --- a/api/idp.go +++ b/api/idp.go @@ -6,7 +6,9 @@ const ( IdentityProviderOAuth2 IdentityProviderType = "OAUTH2" ) -type IdentityProviderConfig interface{} +type IdentityProviderConfig struct { + OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"` +} type IdentityProviderOAuth2Config struct { ClientID string `json:"clientId"` diff --git a/server/idp.go b/server/idp.go index d1d87a01..1b369b5c 100644 --- a/server/idp.go +++ b/server/idp.go @@ -35,16 +35,16 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err) } - identityProvider, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{ + identityProviderMessage, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{ Name: identityProviderCreate.Name, Type: store.IdentityProviderType(identityProviderCreate.Type), IdentifierFilter: identityProviderCreate.IdentifierFilter, - Config: (*store.IdentityProviderConfig)(identityProviderCreate.Config), + Config: convertIdentityProviderConfigToStore(identityProviderCreate.Config), }) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err) } - return c.JSON(http.StatusOK, composeResponse(identityProvider)) + return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage))) }) g.PATCH("/idp/:idpId", func(c echo.Context) error { @@ -76,17 +76,17 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err) } - identityProvider, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderMessage{ + identityProviderMessage, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderMessage{ ID: identityProviderPatch.ID, Type: store.IdentityProviderType(identityProviderPatch.Type), Name: identityProviderPatch.Name, IdentifierFilter: identityProviderPatch.IdentifierFilter, - Config: (*store.IdentityProviderConfig)(identityProviderPatch.Config), + Config: convertIdentityProviderConfigToStore(identityProviderPatch.Config), }) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err) } - return c.JSON(http.StatusOK, identityProvider) + return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage))) }) g.GET("/idp", func(c echo.Context) error { @@ -112,13 +112,44 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err) } - var identityProviderList []*api.IdentityProvider + identityProviderList := []*api.IdentityProvider{} for _, identityProviderMessage := range identityProviderMessageList { identityProviderList = append(identityProviderList, convertIdentityProviderFromStore(identityProviderMessage)) } return c.JSON(http.StatusOK, composeResponse(identityProviderList)) }) + g.GET("/idp/:idpId", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.FindUser(ctx, &api.UserFind{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + // We should only show identity provider list to host user. + if user == nil || user.Role != api.Host { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + identityProviderID, err := strconv.Atoi(c.Param("idpId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err) + } + identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{ + ID: &identityProviderID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err) + } + return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage))) + }) + g.DELETE("/idp/:idpId", func(c echo.Context) error { ctx := c.Request().Context() userID, ok := c.Get(getUserIDContextKey()).(int) @@ -152,27 +183,47 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) { } func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *api.IdentityProvider { - identityProvider := &api.IdentityProvider{ + return &api.IdentityProvider{ ID: identityProviderMessage.ID, Name: identityProviderMessage.Name, Type: api.IdentityProviderType(identityProviderMessage.Type), IdentifierFilter: identityProviderMessage.IdentifierFilter, + Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config), + } +} + +func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *api.IdentityProviderConfig { + return &api.IdentityProviderConfig{ + OAuth2Config: &api.IdentityProviderOAuth2Config{ + ClientID: config.OAuth2Config.ClientID, + ClientSecret: config.OAuth2Config.ClientSecret, + AuthURL: config.OAuth2Config.AuthURL, + TokenURL: config.OAuth2Config.TokenURL, + UserInfoURL: config.OAuth2Config.UserInfoURL, + Scopes: config.OAuth2Config.Scopes, + FieldMapping: &api.FieldMapping{ + Identifier: config.OAuth2Config.FieldMapping.Identifier, + DisplayName: config.OAuth2Config.FieldMapping.DisplayName, + Email: config.OAuth2Config.FieldMapping.Email, + }, + }, + } +} + +func convertIdentityProviderConfigToStore(config *api.IdentityProviderConfig) *store.IdentityProviderConfig { + return &store.IdentityProviderConfig{ + OAuth2Config: &store.IdentityProviderOAuth2Config{ + ClientID: config.OAuth2Config.ClientID, + ClientSecret: config.OAuth2Config.ClientSecret, + AuthURL: config.OAuth2Config.AuthURL, + TokenURL: config.OAuth2Config.TokenURL, + UserInfoURL: config.OAuth2Config.UserInfoURL, + Scopes: config.OAuth2Config.Scopes, + FieldMapping: &store.FieldMapping{ + Identifier: config.OAuth2Config.FieldMapping.Identifier, + DisplayName: config.OAuth2Config.FieldMapping.DisplayName, + Email: config.OAuth2Config.FieldMapping.Email, + }, + }, } - if identityProvider.Type == api.IdentityProviderOAuth2 { - configMessage := any(identityProviderMessage.Config).(*store.IdentityProviderOAuth2Config) - identityProvider.Config = any(&api.IdentityProviderOAuth2Config{ - ClientID: configMessage.ClientID, - ClientSecret: configMessage.ClientSecret, - AuthURL: configMessage.AuthURL, - TokenURL: configMessage.TokenURL, - UserInfoURL: configMessage.UserInfoURL, - Scopes: configMessage.Scopes, - FieldMapping: &api.FieldMapping{ - Identifier: configMessage.FieldMapping.Identifier, - DisplayName: configMessage.FieldMapping.DisplayName, - Email: configMessage.FieldMapping.Email, - }, - }).(*api.IdentityProviderConfig) - } - return identityProvider } diff --git a/store/idp.go b/store/idp.go index cf556090..0d6905d6 100644 --- a/store/idp.go +++ b/store/idp.go @@ -16,7 +16,9 @@ const ( IdentityProviderOAuth2 IdentityProviderType = "OAUTH2" ) -type IdentityProviderConfig interface{} +type IdentityProviderConfig struct { + OAuth2Config *IdentityProviderOAuth2Config +} type IdentityProviderOAuth2Config struct { ClientID string `json:"clientId"` @@ -67,7 +69,7 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv var configBytes []byte if create.Type == IdentityProviderOAuth2 { - configBytes, err = json.Marshal(any(create.Config).(*IdentityProviderOAuth2Config)) + configBytes, err = json.Marshal(create.Config.OAuth2Config) if err != nil { return nil, err } @@ -153,7 +155,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti if v := update.Config; v != nil { var configBytes []byte if update.Type == IdentityProviderOAuth2 { - configBytes, err = json.Marshal(any(update.Config).(*IdentityProviderOAuth2Config)) + configBytes, err = json.Marshal(update.Config.OAuth2Config) if err != nil { return nil, err } @@ -182,9 +184,13 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti return nil, FormatError(err) } if identityProviderMessage.Type == IdentityProviderOAuth2 { - if err := json.Unmarshal([]byte(identityProviderConfig), any(identityProviderMessage.Config).(*IdentityProviderOAuth2Config)); err != nil { + oauth2Config := &IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { return nil, err } + identityProviderMessage.Config = &IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } } else { return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type)) } @@ -252,9 +258,13 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr return nil, FormatError(err) } if identityProviderMessage.Type == IdentityProviderOAuth2 { - if err := json.Unmarshal([]byte(identityProviderConfig), any(identityProviderMessage.Config).(*IdentityProviderOAuth2Config)); err != nil { + oauth2Config := &IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { return nil, err } + identityProviderMessage.Config = &IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } } else { return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type)) }