diff --git a/api/storage.go b/api/storage.go deleted file mode 100644 index 5d646fad..00000000 --- a/api/storage.go +++ /dev/null @@ -1,57 +0,0 @@ -package api - -const ( - // LocalStorage means the storage service is local file system. - LocalStorage = -1 - // DatabaseStorage means the storage service is database. - DatabaseStorage = 0 -) - -type StorageType string - -const ( - StorageS3 StorageType = "S3" -) - -type StorageConfig struct { - S3Config *StorageS3Config `json:"s3Config"` -} - -type StorageS3Config struct { - EndPoint string `json:"endPoint"` - Path string `json:"path"` - Region string `json:"region"` - AccessKey string `json:"accessKey"` - SecretKey string `json:"secretKey"` - Bucket string `json:"bucket"` - URLPrefix string `json:"urlPrefix"` - URLSuffix string `json:"urlSuffix"` -} - -type Storage struct { - ID int `json:"id"` - Name string `json:"name"` - Type StorageType `json:"type"` - Config *StorageConfig `json:"config"` -} - -type StorageCreate struct { - Name string `json:"name"` - Type StorageType `json:"type"` - Config *StorageConfig `json:"config"` -} - -type StoragePatch struct { - ID int `json:"id"` - Type StorageType `json:"type"` - Name *string `json:"name"` - Config *StorageConfig `json:"config"` -} - -type StorageFind struct { - ID *int `json:"id"` -} - -type StorageDelete struct { - ID int `json:"id"` -} diff --git a/api/v1/storage.go b/api/v1/storage.go index 9871b90f..186d5687 100644 --- a/api/v1/storage.go +++ b/api/v1/storage.go @@ -1,8 +1,260 @@ package v1 +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + + "github.com/labstack/echo/v4" + "github.com/usememos/memos/store" +) + const ( // LocalStorage means the storage service is local file system. LocalStorage = -1 // DatabaseStorage means the storage service is database. DatabaseStorage = 0 ) + +type StorageType string + +const ( + StorageS3 StorageType = "S3" +) + +func (t StorageType) String() string { + return string(t) +} + +type StorageConfig struct { + S3Config *StorageS3Config `json:"s3Config"` +} + +type StorageS3Config struct { + EndPoint string `json:"endPoint"` + Path string `json:"path"` + Region string `json:"region"` + AccessKey string `json:"accessKey"` + SecretKey string `json:"secretKey"` + Bucket string `json:"bucket"` + URLPrefix string `json:"urlPrefix"` + URLSuffix string `json:"urlSuffix"` +} + +type Storage struct { + ID int `json:"id"` + Name string `json:"name"` + Type StorageType `json:"type"` + Config *StorageConfig `json:"config"` +} + +type CreateStorageRequest struct { + Name string `json:"name"` + Type StorageType `json:"type"` + Config *StorageConfig `json:"config"` +} + +type UpdateStorageRequest struct { + Type StorageType `json:"type"` + Name *string `json:"name"` + Config *StorageConfig `json:"config"` +} + +func (s *APIV1Service) registerStorageRoutes(g *echo.Group) { + g.POST("/storage", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + if user == nil || user.Role != store.RoleHost { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + create := &CreateStorageRequest{} + if err := json.NewDecoder(c.Request().Body).Decode(create); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post storage request").SetInternal(err) + } + + configString := "" + if create.Type == StorageS3 && create.Config.S3Config != nil { + configBytes, err := json.Marshal(create.Config.S3Config) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post storage request").SetInternal(err) + } + configString = string(configBytes) + } + + storage, err := s.Store.CreateStorage(ctx, &store.Storage{ + Name: create.Name, + Type: create.Type.String(), + Config: configString, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create storage").SetInternal(err) + } + storageMessage, err := ConvertStorageFromStore(storage) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err) + } + return c.JSON(http.StatusOK, storageMessage) + }) + + g.PATCH("/storage/:storageId", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + if user == nil || user.Role != store.RoleHost { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + storageID, err := strconv.Atoi(c.Param("storageId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err) + } + + update := &UpdateStorageRequest{} + if err := json.NewDecoder(c.Request().Body).Decode(update); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch storage request").SetInternal(err) + } + storageUpdate := &store.UpdateStorage{ + ID: storageID, + } + if update.Name != nil { + storageUpdate.Name = update.Name + } + if update.Config != nil { + if update.Type == StorageS3 { + configBytes, err := json.Marshal(update.Config.S3Config) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post storage request").SetInternal(err) + } + configString := string(configBytes) + storageUpdate.Config = &configString + } + } + + storage, err := s.Store.UpdateStorage(ctx, storageUpdate) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch storage").SetInternal(err) + } + storageMessage, err := ConvertStorageFromStore(storage) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err) + } + return c.JSON(http.StatusOK, storageMessage) + }) + + g.GET("/storage", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + // We should only show storage list to host user. + if user == nil || user.Role != store.RoleHost { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + list, err := s.Store.ListStorages(ctx, &store.FindStorage{}) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage list").SetInternal(err) + } + + storageList := []*Storage{} + for _, storage := range list { + storageMessage, err := ConvertStorageFromStore(storage) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err) + } + storageList = append(storageList, storageMessage) + } + return c.JSON(http.StatusOK, storageList) + }) + + g.DELETE("/storage/:storageId", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + if user == nil || user.Role != store.RoleHost { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + storageID, err := strconv.Atoi(c.Param("storageId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err) + } + + systemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: SystemSettingStorageServiceIDName.String()}) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) + } + if systemSetting != nil { + storageServiceID := DatabaseStorage + err = json.Unmarshal([]byte(systemSetting.Value), &storageServiceID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal storage service id").SetInternal(err) + } + if storageServiceID == storageID { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Storage service %d is using", storageID)) + } + } + + if err = s.Store.DeleteStorage(ctx, &store.DeleteStorage{ID: storageID}); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete storage").SetInternal(err) + } + return c.JSON(http.StatusOK, true) + }) +} + +func ConvertStorageFromStore(storage *store.Storage) (*Storage, error) { + storageMessage := &Storage{ + ID: storage.ID, + Name: storage.Name, + Type: StorageType(storage.Type), + Config: &StorageConfig{}, + } + if storageMessage.Type == StorageS3 { + s3Config := &StorageS3Config{} + if err := json.Unmarshal([]byte(storage.Config), s3Config); err != nil { + return nil, err + } + storageMessage.Config = &StorageConfig{ + S3Config: s3Config, + } + } + return storageMessage, nil +} diff --git a/api/v1/tag.go b/api/v1/tag.go index 1fbec2a7..0530581b 100644 --- a/api/v1/tag.go +++ b/api/v1/tag.go @@ -42,7 +42,7 @@ func (s *APIV1Service) registerTagRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty") } - tag, err := s.Store.UpsertTagV1(ctx, &store.Tag{ + tag, err := s.Store.UpsertTag(ctx, &store.Tag{ Name: tagUpsert.Name, CreatorID: userID, }) diff --git a/api/v1/v1.go b/api/v1/v1.go index c0c18fff..639ab1be 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -33,4 +33,5 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) { s.registerUserSettingRoutes(apiV1Group) s.registerTagRoutes(apiV1Group) s.registerShortcutRoutes(apiV1Group) + s.registerStorageRoutes(apiV1Group) } diff --git a/server/resource.go b/server/resource.go index 430aaee4..6d160ef0 100644 --- a/server/resource.go +++ b/server/resource.go @@ -155,7 +155,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) } - storageServiceID := api.DatabaseStorage + storageServiceID := apiv1.DatabaseStorage if systemSettingStorageServiceID != nil { err = json.Unmarshal([]byte(systemSettingStorageServiceID.Value), &storageServiceID) if err != nil { @@ -164,7 +164,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { } publicID := common.GenUUID() - if storageServiceID == api.DatabaseStorage { + if storageServiceID == apiv1.DatabaseStorage { fileBytes, err := io.ReadAll(sourceFile) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to read file").SetInternal(err) @@ -176,7 +176,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { Size: size, Blob: fileBytes, } - } else if storageServiceID == api.LocalStorage { + } else if storageServiceID == apiv1.LocalStorage { // filepath.Join() should be used for local file paths, // as it handles the os-specific path separator automatically. // path.Join() always uses '/' as path separator. @@ -219,13 +219,17 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { InternalPath: filePath, } } else { - storage, err := s.Store.FindStorage(ctx, &api.StorageFind{ID: &storageServiceID}) + storage, err := s.Store.GetStorage(ctx, &store.FindStorage{ID: &storageServiceID}) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) } + storageMessage, err := apiv1.ConvertStorageFromStore(storage) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err) + } - if storage.Type == api.StorageS3 { - s3Config := storage.Config.S3Config + if storageMessage.Type == apiv1.StorageS3 { + s3Config := storageMessage.Config.S3Config s3Client, err := s3.NewClient(ctx, &s3.Config{ AccessKey: s3Config.AccessKey, SecretKey: s3Config.SecretKey, diff --git a/server/server.go b/server/server.go index c35f88ef..847932f9 100644 --- a/server/server.go +++ b/server/server.go @@ -101,7 +101,6 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store s.registerMemoRoutes(apiGroup) s.registerMemoResourceRoutes(apiGroup) s.registerResourceRoutes(apiGroup) - s.registerStorageRoutes(apiGroup) s.registerMemoRelationRoutes(apiGroup) apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store) diff --git a/server/storage.go b/server/storage.go deleted file mode 100644 index 455815ce..00000000 --- a/server/storage.go +++ /dev/null @@ -1,152 +0,0 @@ -package server - -import ( - "encoding/json" - "fmt" - "net/http" - "strconv" - - "github.com/labstack/echo/v4" - "github.com/usememos/memos/api" - apiv1 "github.com/usememos/memos/api/v1" - "github.com/usememos/memos/common" - "github.com/usememos/memos/store" -) - -func (s *Server) registerStorageRoutes(g *echo.Group) { - g.POST("/storage", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - user, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) - } - if user == nil || user.Role != store.RoleHost { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") - } - - storageCreate := &api.StorageCreate{} - if err := json.NewDecoder(c.Request().Body).Decode(storageCreate); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post storage request").SetInternal(err) - } - - storage, err := s.Store.CreateStorage(ctx, storageCreate) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create storage").SetInternal(err) - } - return c.JSON(http.StatusOK, composeResponse(storage)) - }) - - g.PATCH("/storage/:storageId", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - user, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) - } - if user == nil || user.Role != store.RoleHost { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") - } - - storageID, err := strconv.Atoi(c.Param("storageId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err) - } - - storagePatch := &api.StoragePatch{ - ID: storageID, - } - if err := json.NewDecoder(c.Request().Body).Decode(storagePatch); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch storage request").SetInternal(err) - } - - storage, err := s.Store.PatchStorage(ctx, storagePatch) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch storage").SetInternal(err) - } - return c.JSON(http.StatusOK, composeResponse(storage)) - }) - - g.GET("/storage", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - user, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) - } - // We should only show storage list to host user. - if user == nil || user.Role != store.RoleHost { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") - } - - storageList, err := s.Store.FindStorageList(ctx, &api.StorageFind{}) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage list").SetInternal(err) - } - return c.JSON(http.StatusOK, composeResponse(storageList)) - }) - - g.DELETE("/storage/:storageId", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - user, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) - } - if user == nil || user.Role != store.RoleHost { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") - } - - storageID, err := strconv.Atoi(c.Param("storageId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err) - } - - systemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: apiv1.SystemSettingStorageServiceIDName.String()}) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) - } - if systemSetting != nil { - storageServiceID := api.DatabaseStorage - err = json.Unmarshal([]byte(systemSetting.Value), &storageServiceID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal storage service id").SetInternal(err) - } - if storageServiceID == storageID { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Storage service %d is using", storageID)) - } - } - - if err = s.Store.DeleteStorage(ctx, &api.StorageDelete{ID: storageID}); err != nil { - if common.ErrorCode(err) == common.NotFound { - return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Storage ID not found: %d", storageID)) - } - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete storage").SetInternal(err) - } - return c.JSON(http.StatusOK, true) - }) -} diff --git a/store/activity.go b/store/activity.go index 3d31ef2c..57b01986 100644 --- a/store/activity.go +++ b/store/activity.go @@ -21,7 +21,7 @@ type ActivityMessage struct { func (s *Store) CreateActivity(ctx context.Context, create *ActivityMessage) (*ActivityMessage, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -39,11 +39,11 @@ func (s *Store) CreateActivity(ctx context.Context, create *ActivityMessage) (*A &create.ID, &create.CreatedTs, ); err != nil { - return nil, FormatError(err) + return nil, err } if err := tx.Commit(); err != nil { - return nil, FormatError(err) + return nil, err } activityMessage := create return activityMessage, nil diff --git a/store/common.go b/store/common.go index 95d0077c..79a63e84 100644 --- a/store/common.go +++ b/store/common.go @@ -11,11 +11,5 @@ const ( ) func (r RowStatus) String() string { - switch r { - case Normal: - return "NORMAL" - case Archived: - return "ARCHIVED" - } - return "" + return string(r) } diff --git a/store/shortcut.go b/store/shortcut.go index 2f75e027..de6e25ca 100644 --- a/store/shortcut.go +++ b/store/shortcut.go @@ -212,7 +212,7 @@ func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shor args..., ) if err != nil { - return nil, FormatError(err) + return nil, err } defer rows.Close() @@ -228,13 +228,13 @@ func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shor &shortcut.UpdatedTs, &shortcut.RowStatus, ); err != nil { - return nil, FormatError(err) + return nil, err } list = append(list, &shortcut) } if err := rows.Err(); err != nil { - return nil, FormatError(err) + return nil, err } return list, nil @@ -253,7 +253,7 @@ func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { )` _, err := tx.ExecContext(ctx, stmt) if err != nil { - return FormatError(err) + return err } return nil diff --git a/store/storage.go b/store/storage.go index 0b560aec..cceee950 100644 --- a/store/storage.go +++ b/store/storage.go @@ -3,284 +3,200 @@ package store import ( "context" "database/sql" - "encoding/json" - "fmt" "strings" - - "github.com/usememos/memos/api" - "github.com/usememos/memos/common" ) -type storageRaw struct { +type Storage struct { ID int Name string - Type api.StorageType - Config *api.StorageConfig + Type string + Config string } -func (raw *storageRaw) toStorage() *api.Storage { - return &api.Storage{ - ID: raw.ID, - Name: raw.Name, - Type: raw.Type, - Config: raw.Config, - } +type FindStorage struct { + ID *int } -func (s *Store) CreateStorage(ctx context.Context, create *api.StorageCreate) (*api.Storage, error) { +type UpdateStorage struct { + ID int + Name *string + Config *string +} + +type DeleteStorage struct { + ID int +} + +func (s *Store) CreateStorage(ctx context.Context, create *Storage) (*Storage, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() - storageRaw, err := createStorageRaw(ctx, tx, create) - if err != nil { + query := ` + INSERT INTO storage ( + name, + type, + config + ) + VALUES (?, ?, ?) + RETURNING id + ` + if err := tx.QueryRowContext(ctx, query, create.Name, create.Type, create.Config).Scan( + &create.ID, + ); err != nil { return nil, err } if err := tx.Commit(); err != nil { - return nil, FormatError(err) - } - - return storageRaw.toStorage(), nil -} - -func (s *Store) PatchStorage(ctx context.Context, patch *api.StoragePatch) (*api.Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, FormatError(err) - } - defer tx.Rollback() - - storageRaw, err := patchStorageRaw(ctx, tx, patch) - if err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, FormatError(err) - } - - return storageRaw.toStorage(), nil + storage := create + return storage, nil } -func (s *Store) FindStorageList(ctx context.Context, find *api.StorageFind) ([]*api.Storage, error) { +func (s *Store) ListStorages(ctx context.Context, find *FindStorage) ([]*Storage, error) { tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, FormatError(err) - } - defer tx.Rollback() - - storageRawList, err := findStorageRawList(ctx, tx, find) if err != nil { return nil, err } + defer tx.Rollback() - list := []*api.Storage{} - for _, raw := range storageRawList { - list = append(list, raw.toStorage()) + list, err := listStorages(ctx, tx, find) + if err != nil { + return nil, err } return list, nil } -func (s *Store) FindStorage(ctx context.Context, find *api.StorageFind) (*api.Storage, error) { +func (s *Store) GetStorage(ctx context.Context, find *FindStorage) (*Storage, error) { tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, FormatError(err) - } - defer tx.Rollback() - - list, err := findStorageRawList(ctx, tx, find) if err != nil { return nil, err } + defer tx.Rollback() + list, err := listStorages(ctx, tx, find) + if err != nil { + return nil, err + } if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} + return nil, nil } - storageRaw := list[0] - return storageRaw.toStorage(), nil + return list[0], nil } -func (s *Store) DeleteStorage(ctx context.Context, delete *api.StorageDelete) error { +func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Storage, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return FormatError(err) + return nil, err } defer tx.Rollback() - if err := deleteStorage(ctx, tx, delete); err != nil { - return FormatError(err) - } - - if err := tx.Commit(); err != nil { - return FormatError(err) - } - - return nil -} - -func createStorageRaw(ctx context.Context, tx *sql.Tx, create *api.StorageCreate) (*storageRaw, error) { - set := []string{"name", "type", "config"} - args := []any{create.Name, create.Type} - placeholder := []string{"?", "?", "?"} - - var configBytes []byte - var err error - if create.Type == api.StorageS3 { - configBytes, err = json.Marshal(create.Config.S3Config) - if err != nil { - return nil, err - } - } else { - return nil, fmt.Errorf("unsupported storage type %s", string(create.Type)) - } - args = append(args, string(configBytes)) - - query := ` - INSERT INTO storage ( - ` + strings.Join(set, ", ") + ` - ) - VALUES (` + strings.Join(placeholder, ",") + `) - RETURNING id - ` - storageRaw := storageRaw{ - Name: create.Name, - Type: create.Type, - Config: create.Config, - } - if err := tx.QueryRowContext(ctx, query, args...).Scan( - &storageRaw.ID, - ); err != nil { - return nil, FormatError(err) - } - - return &storageRaw, nil -} - -func patchStorageRaw(ctx context.Context, tx *sql.Tx, patch *api.StoragePatch) (*storageRaw, error) { set, args := []string{}, []any{} - if v := patch.Name; v != nil { - set, args = append(set, "name = ?"), append(args, *v) + if update.Name != nil { + set = append(set, "name = ?") + args = append(args, *update.Name) } - if v := patch.Config; v != nil { - var configBytes []byte - var err error - if patch.Type == api.StorageS3 { - configBytes, err = json.Marshal(patch.Config.S3Config) - if err != nil { - return nil, err - } - } else { - return nil, fmt.Errorf("unsupported storage type %s", string(patch.Type)) - } - set, args = append(set, "config = ?"), append(args, string(configBytes)) + if update.Config != nil { + set = append(set, "config = ?") + args = append(args, *update.Config) } - args = append(args, patch.ID) + args = append(args, update.ID) query := ` UPDATE storage SET ` + strings.Join(set, ", ") + ` WHERE id = ? - RETURNING id, name, type, config + RETURNING + id, + name, + type, + config ` - var storageRaw storageRaw - var storageConfig string + storage := &Storage{} if err := tx.QueryRowContext(ctx, query, args...).Scan( - &storageRaw.ID, - &storageRaw.Name, - &storageRaw.Type, - &storageConfig, + &storage.ID, + &storage.Name, + &storage.Type, + &storage.Config, ); err != nil { - return nil, FormatError(err) - } - if storageRaw.Type == api.StorageS3 { - s3Config := &api.StorageS3Config{} - if err := json.Unmarshal([]byte(storageConfig), s3Config); err != nil { - return nil, err - } - storageRaw.Config = &api.StorageConfig{ - S3Config: s3Config, - } - } else { - return nil, fmt.Errorf("unsupported storage type %s", string(storageRaw.Type)) + return nil, err } - return &storageRaw, nil + if err := tx.Commit(); err != nil { + return nil, err + } + + return storage, nil } -func findStorageRawList(ctx context.Context, tx *sql.Tx, find *api.StorageFind) ([]*storageRaw, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) +func (s *Store) DeleteStorage(ctx context.Context, delete *DeleteStorage) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err } + defer tx.Rollback() query := ` - SELECT - id, - name, - type, - config - FROM storage - WHERE ` + strings.Join(where, " AND ") + ` - ORDER BY id DESC + DELETE FROM storage + WHERE id = ? ` - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return nil, FormatError(err) - } - defer rows.Close() - - storageRawList := make([]*storageRaw, 0) - for rows.Next() { - var storageRaw storageRaw - var storageConfig string - if err := rows.Scan( - &storageRaw.ID, - &storageRaw.Name, - &storageRaw.Type, - &storageConfig, - ); err != nil { - return nil, FormatError(err) - } - if storageRaw.Type == api.StorageS3 { - s3Config := &api.StorageS3Config{} - if err := json.Unmarshal([]byte(storageConfig), s3Config); err != nil { - return nil, err - } - storageRaw.Config = &api.StorageConfig{ - S3Config: s3Config, - } - } else { - return nil, fmt.Errorf("unsupported storage type %s", string(storageRaw.Type)) - } - storageRawList = append(storageRawList, &storageRaw) + if _, err := tx.ExecContext(ctx, query, delete.ID); err != nil { + return err } - if err := rows.Err(); err != nil { - return nil, FormatError(err) - } - - return storageRawList, nil -} - -func deleteStorage(ctx context.Context, tx *sql.Tx, delete *api.StorageDelete) error { - where, args := []string{"id = ?"}, []any{delete.ID} - - stmt := `DELETE FROM storage WHERE ` + strings.Join(where, " AND ") - result, err := tx.ExecContext(ctx, stmt, args...) - if err != nil { - return FormatError(err) - } - - rows, _ := result.RowsAffected() - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("storage not found")} + if err := tx.Commit(); err != nil { + // Prevent linter warning. + return err } return nil } + +func listStorages(ctx context.Context, tx *sql.Tx, find *FindStorage) ([]*Storage, error) { + where, args := []string{"1 = 1"}, []any{} + if find.ID != nil { + where, args = append(where, "id = ?"), append(args, *find.ID) + } + + rows, err := tx.QueryContext(ctx, ` + SELECT + id, + name, + type, + config + FROM storage + WHERE `+strings.Join(where, " AND ")+` + ORDER BY id DESC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*Storage{} + for rows.Next() { + storage := &Storage{} + if err := rows.Scan( + &storage.ID, + &storage.Name, + &storage.Type, + &storage.Config, + ); err != nil { + return nil, err + } + list = append(list, storage) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/tag.go b/store/tag.go index 924ed62c..9de47807 100644 --- a/store/tag.go +++ b/store/tag.go @@ -21,10 +21,10 @@ type DeleteTag struct { CreatorID int } -func (s *Store) UpsertTagV1(ctx context.Context, upsert *Tag) (*Tag, error) { +func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -52,7 +52,7 @@ func (s *Store) UpsertTagV1(ctx context.Context, upsert *Tag) (*Tag, error) { func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -67,7 +67,7 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { ` rows, err := tx.QueryContext(ctx, query, args...) if err != nil { - return nil, FormatError(err) + return nil, err } defer rows.Close() @@ -78,14 +78,14 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { &tag.Name, &tag.CreatorID, ); err != nil { - return nil, FormatError(err) + return nil, err } list = append(list, tag) } if err := rows.Err(); err != nil { - return nil, FormatError(err) + return nil, err } return list, nil @@ -94,7 +94,7 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { func (s *Store) DeleteTag(ctx context.Context, delete *DeleteTag) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return FormatError(err) + return err } defer tx.Rollback() @@ -102,7 +102,7 @@ func (s *Store) DeleteTag(ctx context.Context, delete *DeleteTag) error { query := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") result, err := tx.ExecContext(ctx, query, args...) if err != nil { - return FormatError(err) + return err } rows, _ := result.RowsAffected() @@ -131,7 +131,7 @@ func vacuumTag(ctx context.Context, tx *sql.Tx) error { )` _, err := tx.ExecContext(ctx, stmt) if err != nil { - return FormatError(err) + return err } return nil diff --git a/test/store/storage_test.go b/test/store/storage_test.go new file mode 100644 index 00000000..23670bec --- /dev/null +++ b/test/store/storage_test.go @@ -0,0 +1,38 @@ +package teststore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/usememos/memos/store" +) + +func TestStorageStore(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + storage, err := ts.CreateStorage(ctx, &store.Storage{ + Name: "test_storage", + Type: "S3", + Config: "{}", + }) + require.NoError(t, err) + newStorageName := "new_storage_name" + updatedStorage, err := ts.UpdateStorage(ctx, &store.UpdateStorage{ + ID: storage.ID, + Name: &newStorageName, + }) + require.NoError(t, err) + require.Equal(t, newStorageName, updatedStorage.Name) + storageList, err := ts.ListStorages(ctx, &store.FindStorage{}) + require.NoError(t, err) + require.Equal(t, 1, len(storageList)) + require.Equal(t, updatedStorage, storageList[0]) + err = ts.DeleteStorage(ctx, &store.DeleteStorage{ + ID: storage.ID, + }) + require.NoError(t, err) + storageList, err = ts.ListStorages(ctx, &store.FindStorage{}) + require.NoError(t, err) + require.Equal(t, 0, len(storageList)) +} diff --git a/web/src/components/Settings/StorageSection.tsx b/web/src/components/Settings/StorageSection.tsx index 3532d164..7d898ba4 100644 --- a/web/src/components/Settings/StorageSection.tsx +++ b/web/src/components/Settings/StorageSection.tsx @@ -22,9 +22,7 @@ const StorageSection = () => { }, []); const fetchStorageList = async () => { - const { - data: { data: storageList }, - } = await api.getStorageList(); + const { data: storageList } = await api.getStorageList(); setStorageList(storageList); }; diff --git a/web/src/helpers/api.ts b/web/src/helpers/api.ts index 0cc1f55c..34a9f1d0 100644 --- a/web/src/helpers/api.ts +++ b/web/src/helpers/api.ts @@ -230,19 +230,19 @@ export function deleteTag(tagName: string) { } export function getStorageList() { - return axios.get>(`/api/storage`); + return axios.get(`/api/v1/storage`); } export function createStorage(storageCreate: StorageCreate) { - return axios.post>(`/api/storage`, storageCreate); + return axios.post(`/api/v1/storage`, storageCreate); } export function patchStorage(storagePatch: StoragePatch) { - return axios.patch>(`/api/storage/${storagePatch.id}`, storagePatch); + return axios.patch(`/api/v1/storage/${storagePatch.id}`, storagePatch); } export function deleteStorage(storageId: StorageId) { - return axios.delete(`/api/storage/${storageId}`); + return axios.delete(`/api/v1/storage/${storageId}`); } export function getIdentityProviderList() {