diff --git a/store/db/mysql/idp.go b/store/db/mysql/idp.go index 8f84df29..c5fcc7f3 100644 --- a/store/db/mysql/idp.go +++ b/store/db/mysql/idp.go @@ -22,15 +22,16 @@ func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityP return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) } - stmt := "INSERT INTO `idp` (`name`, `type`, `identifier_filter`, `config`) VALUES (?, ?, ?, ?)" - result, err := d.db.ExecContext( - ctx, - stmt, - create.Name, - create.Type, - create.IdentifierFilter, - string(configBytes), - ) + placeholders := []string{"?", "?", "?", "?"} + fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"} + args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} + + if create.ID != 0 { + fields, placeholders, args = append(fields, "`id`"), append(placeholders, "?"), append(args, create.ID) + } + + stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")" + result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { return nil, err } diff --git a/store/db/sqlite/idp.go b/store/db/sqlite/idp.go index 8d67d112..27b864b7 100644 --- a/store/db/sqlite/idp.go +++ b/store/db/sqlite/idp.go @@ -23,26 +23,16 @@ func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityP return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) } - stmt := ` - INSERT INTO idp ( - name, - type, - identifier_filter, - config - ) - VALUES (?, ?, ?, ?) - RETURNING id - ` - if err := d.db.QueryRowContext( - ctx, - stmt, - create.Name, - create.Type, - create.IdentifierFilter, - string(configBytes), - ).Scan( - &create.ID, - ); err != nil { + placeholders := []string{"?", "?", "?", "?"} + fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"} + args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} + + if create.ID != 0 { + fields, placeholders, args = append(fields, "`id`"), append(placeholders, "?"), append(args, create.ID) + } + + stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`" + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil { return nil, err }