2023-09-28 22:09:52 +08:00
package mysql
import (
"context"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
2023-10-05 23:11:29 +08:00
func ( d * DB ) CreateUser ( ctx context . Context , create * store . User ) ( * store . User , error ) {
2023-10-08 18:29:32 +08:00
fields := [ ] string { "`username`" , "`role`" , "`email`" , "`nickname`" , "`password_hash`" , "`avatar_url`" }
placeholder := [ ] string { "?" , "?" , "?" , "?" , "?" , "?" }
args := [ ] any { create . Username , create . Role , create . Email , create . Nickname , create . PasswordHash , create . AvatarURL }
stmt := "INSERT INTO user (" + strings . Join ( fields , ", " ) + ") VALUES (" + strings . Join ( placeholder , ", " ) + ")"
result , err := d . db . ExecContext ( ctx , stmt , args ... )
2023-09-28 22:09:52 +08:00
if err != nil {
return nil , err
}
id , err := result . LastInsertId ( )
if err != nil {
return nil , err
}
2023-10-28 09:44:52 +08:00
id32 := int32 ( id )
list , err := d . ListUsers ( ctx , & store . FindUser { ID : & id32 } )
2023-09-28 22:09:52 +08:00
if err != nil {
return nil , err
}
if len ( list ) != 1 {
return nil , errors . Wrapf ( nil , "unexpected user count: %d" , len ( list ) )
}
return list [ 0 ] , nil
}
2023-10-05 23:11:29 +08:00
func ( d * DB ) UpdateUser ( ctx context . Context , update * store . UpdateUser ) ( * store . User , error ) {
2023-09-28 22:09:52 +08:00
set , args := [ ] string { } , [ ] any { }
if v := update . UpdatedTs ; v != nil {
2023-10-20 19:10:38 +08:00
set , args = append ( set , "`updated_ts` = FROM_UNIXTIME(?)" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := update . RowStatus ; v != nil {
2023-10-07 22:56:12 +08:00
set , args = append ( set , "`row_status` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := update . Username ; v != nil {
2023-10-07 22:56:12 +08:00
set , args = append ( set , "`username` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := update . Email ; v != nil {
2023-10-07 22:56:12 +08:00
set , args = append ( set , "`email` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := update . Nickname ; v != nil {
2023-10-07 22:56:12 +08:00
set , args = append ( set , "`nickname` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := update . AvatarURL ; v != nil {
2023-10-07 22:56:12 +08:00
set , args = append ( set , "`avatar_url` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := update . PasswordHash ; v != nil {
2023-10-07 22:56:12 +08:00
set , args = append ( set , "`password_hash` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
args = append ( args , update . ID )
2023-10-07 22:56:12 +08:00
query := "UPDATE `user` SET " + strings . Join ( set , ", " ) + " WHERE `id` = ?"
2023-09-28 22:09:52 +08:00
if _ , err := d . db . ExecContext ( ctx , query , args ... ) ; err != nil {
return nil , err
}
2023-10-28 09:44:52 +08:00
user , err := d . GetUser ( ctx , & store . FindUser { ID : & update . ID } )
if err != nil {
2023-09-28 22:09:52 +08:00
return nil , err
}
return user , nil
}
2023-10-05 23:11:29 +08:00
func ( d * DB ) ListUsers ( ctx context . Context , find * store . FindUser ) ( [ ] * store . User , error ) {
2023-09-28 22:09:52 +08:00
where , args := [ ] string { "1 = 1" } , [ ] any { }
if v := find . ID ; v != nil {
2023-10-07 22:56:12 +08:00
where , args = append ( where , "`id` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := find . Username ; v != nil {
2023-10-07 22:56:12 +08:00
where , args = append ( where , "`username` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := find . Role ; v != nil {
2023-10-07 22:56:12 +08:00
where , args = append ( where , "`role` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := find . Email ; v != nil {
2023-10-07 22:56:12 +08:00
where , args = append ( where , "`email` = ?" ) , append ( args , * v )
2023-09-28 22:09:52 +08:00
}
if v := find . Nickname ; v != nil {
2023-10-07 22:56:12 +08:00
where , args = append ( where , "`nickname` = ?" ) , append ( args , * v )
}
query := "SELECT `id`, `username`, `role`, `email`, `nickname`, `password_hash`, `avatar_url`, UNIX_TIMESTAMP(`created_ts`), UNIX_TIMESTAMP(`updated_ts`), `row_status` FROM `user` WHERE " + strings . Join ( where , " AND " ) + " ORDER BY `created_ts` DESC, `row_status` DESC"
2023-09-28 22:09:52 +08:00
rows , err := d . db . QueryContext ( ctx , query , args ... )
if err != nil {
return nil , err
}
defer rows . Close ( )
list := make ( [ ] * store . User , 0 )
for rows . Next ( ) {
var user store . User
if err := rows . Scan (
& user . ID ,
& user . Username ,
& user . Role ,
& user . Email ,
& user . Nickname ,
& user . PasswordHash ,
& user . AvatarURL ,
& user . CreatedTs ,
& user . UpdatedTs ,
& user . RowStatus ,
) ; err != nil {
return nil , err
}
list = append ( list , & user )
}
if err := rows . Err ( ) ; err != nil {
return nil , err
}
return list , nil
}
2023-10-28 09:44:52 +08:00
// GetUser returns the single user matching find.
//
// It returns an error if the underlying query fails, or if find matches
// anything other than exactly one user (including no match at all).
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
	list, err := d.ListUsers(ctx, find)
	if err != nil {
		return nil, err
	}
	if len(list) != 1 {
		// errors.Wrapf(nil, ...) returns nil, which made this path return
		// (nil, nil); construct a real error instead.
		return nil, errors.Errorf("unexpected user count: %d", len(list))
	}
	return list[0], nil
}
2023-10-05 23:11:29 +08:00
func ( d * DB ) DeleteUser ( ctx context . Context , delete * store . DeleteUser ) error {
2023-10-07 22:56:12 +08:00
result , err := d . db . ExecContext ( ctx , "DELETE FROM `user` WHERE `id` = ?" , delete . ID )
2023-09-28 22:09:52 +08:00
if err != nil {
return err
}
if _ , err := result . RowsAffected ( ) ; err != nil {
return err
}
if err := d . Vacuum ( ctx ) ; err != nil {
// Prevent linter warning.
return err
}
return nil
}