2023-09-28 22:09:52 +08:00
package mysql
import (
"context"
"strings"
"github.com/pkg/errors"
2024-04-13 10:50:25 +08:00
storepb "github.com/usememos/memos/proto/gen/store"
2023-09-28 22:09:52 +08:00
"github.com/usememos/memos/store"
)
2023-10-05 23:11:29 +08:00
func ( d * DB ) CreateIdentityProvider ( ctx context . Context , create * store . IdentityProvider ) ( * store . IdentityProvider , error ) {
2023-10-08 18:28:22 +08:00
placeholders := [ ] string { "?" , "?" , "?" , "?" }
fields := [ ] string { "`name`" , "`type`" , "`identifier_filter`" , "`config`" }
2024-04-13 10:50:25 +08:00
args := [ ] any { create . Name , create . Type . String ( ) , create . IdentifierFilter , create . Config }
2023-10-08 18:28:22 +08:00
stmt := "INSERT INTO `idp` (" + strings . Join ( fields , ", " ) + ") VALUES (" + strings . Join ( placeholders , ", " ) + ")"
result , err := d . db . ExecContext ( ctx , stmt , args ... )
2023-09-28 22:09:52 +08:00
if err != nil {
return nil , err
}
id , err := result . LastInsertId ( )
if err != nil {
return nil , err
}
create . ID = int32 ( id )
return create , nil
}
2023-10-05 23:11:29 +08:00
func ( d * DB ) ListIdentityProviders ( ctx context . Context , find * store . FindIdentityProvider ) ( [ ] * store . IdentityProvider , error ) {
2023-09-28 22:09:52 +08:00
where , args := [ ] string { "1 = 1" } , [ ] any { }
if v := find . ID ; v != nil {
2023-10-07 22:56:12 +08:00
where , args = append ( where , "`id` = ?" ) , append ( args , * v )
}
rows , err := d . db . QueryContext ( ctx , "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE " + strings . Join ( where , " AND " ) + " ORDER BY `id` ASC" ,
2023-09-28 22:09:52 +08:00
args ... ,
)
if err != nil {
return nil , err
}
defer rows . Close ( )
var identityProviders [ ] * store . IdentityProvider
for rows . Next ( ) {
var identityProvider store . IdentityProvider
2024-04-13 10:50:25 +08:00
var typeString string
2023-09-28 22:09:52 +08:00
if err := rows . Scan (
& identityProvider . ID ,
& identityProvider . Name ,
2024-04-13 10:50:25 +08:00
& typeString ,
2023-09-28 22:09:52 +08:00
& identityProvider . IdentifierFilter ,
2024-04-13 10:50:25 +08:00
& identityProvider . Config ,
2023-09-28 22:09:52 +08:00
) ; err != nil {
return nil , err
}
2024-04-13 10:50:25 +08:00
identityProvider . Type = storepb . IdentityProvider_Type ( storepb . IdentityProvider_Type_value [ typeString ] )
2023-09-28 22:09:52 +08:00
identityProviders = append ( identityProviders , & identityProvider )
}
if err := rows . Err ( ) ; err != nil {
return nil , err
}
return identityProviders , nil
}
2023-10-05 23:11:29 +08:00
func ( d * DB ) GetIdentityProvider ( ctx context . Context , find * store . FindIdentityProvider ) ( * store . IdentityProvider , error ) {
2023-09-28 22:09:52 +08:00
list , err := d . ListIdentityProviders ( ctx , find )
if err != nil {
return nil , err
}
if len ( list ) == 0 {
return nil , nil
}
identityProvider := list [ 0 ]
return identityProvider , nil
}
2023-10-05 23:11:29 +08:00
func ( d * DB ) UpdateIdentityProvider ( ctx context . Context , update * store . UpdateIdentityProvider ) ( * store . IdentityProvider , error ) {
2023-09-28 22:09:52 +08:00
set , args := [ ] string { } , [ ] any { }
if v := update . Name ; v != nil {
2023-10-07 22:56:12 +08:00
set , args = append ( set , "`name` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := update . IdentifierFilter ; v != nil {
2023-10-07 22:56:12 +08:00
set , args = append ( set , "`identifier_filter` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := update . Config ; v != nil {
2024-04-13 10:50:25 +08:00
set , args = append ( set , "`config` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
args = append ( args , update . ID )
2023-10-07 22:56:12 +08:00
stmt := "UPDATE `idp` SET " + strings . Join ( set , ", " ) + " WHERE `id` = ?"
2023-09-28 22:09:52 +08:00
_ , err := d . db . ExecContext ( ctx , stmt , args ... )
if err != nil {
return nil , err
}
2023-09-29 09:15:54 +08:00
identityProvider , err := d . GetIdentityProvider ( ctx , & store . FindIdentityProvider {
ID : & update . ID ,
} )
if err != nil {
2023-09-28 22:09:52 +08:00
return nil , err
}
2023-09-29 09:15:54 +08:00
if identityProvider == nil {
return nil , errors . Errorf ( "idp %d not found" , update . ID )
2023-09-28 22:09:52 +08:00
}
2023-09-29 09:15:54 +08:00
return identityProvider , nil
2023-09-28 22:09:52 +08:00
}
2023-10-05 23:11:29 +08:00
func ( d * DB ) DeleteIdentityProvider ( ctx context . Context , delete * store . DeleteIdentityProvider ) error {
2023-10-07 22:56:12 +08:00
where , args := [ ] string { "`id` = ?" } , [ ] any { delete . ID }
stmt := "DELETE FROM `idp` WHERE " + strings . Join ( where , " AND " )
2023-09-28 22:09:52 +08:00
result , err := d . db . ExecContext ( ctx , stmt , args ... )
if err != nil {
return err
}
if _ , err = result . RowsAffected ( ) ; err != nil {
return err
}
return nil
}