mirror of
https://github.com/go-shiori/shiori.git
synced 2025-03-11 23:34:20 +08:00
fix: override configuration from flags only if set (#865)
* fix: override configuration from flags only if set * use helper func and test it
This commit is contained in:
parent
ce04b106eb
commit
9c7483fd09
2 changed files with 99 additions and 7 deletions
|
@ -4,9 +4,11 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/go-shiori/shiori/internal/config"
|
||||
"github.com/go-shiori/shiori/internal/http"
|
||||
"github.com/go-shiori/shiori/internal/model"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
func newServerCommand() *cobra.Command {
|
||||
|
@ -27,6 +29,12 @@ func newServerCommand() *cobra.Command {
|
|||
return cmd
|
||||
}
|
||||
|
||||
func setIfFlagChanged(flagName string, flags *pflag.FlagSet, cfg *config.Config, fn func(cfg *config.Config)) {
|
||||
if flags.Changed(flagName) {
|
||||
fn(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func newServerCommandHandler() func(cmd *cobra.Command, args []string) {
|
||||
return func(cmd *cobra.Command, args []string) {
|
||||
ctx := context.Background()
|
||||
|
@ -54,13 +62,25 @@ func newServerCommandHandler() func(cmd *cobra.Command, args []string) {
|
|||
rootPath += "/"
|
||||
}
|
||||
|
||||
// Override configuration from flags
|
||||
cfg.Http.Port = port
|
||||
cfg.Http.Address = address + ":"
|
||||
cfg.Http.RootPath = rootPath
|
||||
cfg.Http.AccessLog = accessLog
|
||||
cfg.Http.ServeWebUI = serveWebUI
|
||||
cfg.Http.SecretKey = secretKey
|
||||
// Override configuration from flags if needed
|
||||
setIfFlagChanged("port", cmd.Flags(), cfg, func(cfg *config.Config) {
|
||||
cfg.Http.Port = port
|
||||
})
|
||||
setIfFlagChanged("address", cmd.Flags(), cfg, func(cfg *config.Config) {
|
||||
cfg.Http.Address = address + ":"
|
||||
})
|
||||
setIfFlagChanged("webroot", cmd.Flags(), cfg, func(cfg *config.Config) {
|
||||
cfg.Http.RootPath = rootPath
|
||||
})
|
||||
setIfFlagChanged("access-log", cmd.Flags(), cfg, func(cfg *config.Config) {
|
||||
cfg.Http.AccessLog = accessLog
|
||||
})
|
||||
setIfFlagChanged("serve-web-ui", cmd.Flags(), cfg, func(cfg *config.Config) {
|
||||
cfg.Http.ServeWebUI = serveWebUI
|
||||
})
|
||||
setIfFlagChanged("secret-key", cmd.Flags(), cfg, func(cfg *config.Config) {
|
||||
cfg.Http.SecretKey = secretKey
|
||||
})
|
||||
|
||||
dependencies.Log.Infof("Starting Shiori v%s", model.BuildVersion)
|
||||
|
||||
|
|
72
internal/cmd/server_test.go
Normal file
72
internal/cmd/server_test.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-shiori/shiori/internal/config"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_setIfFlagChanged(t *testing.T) {
|
||||
type args struct {
|
||||
flagName string
|
||||
flags func() *pflag.FlagSet
|
||||
cfg *config.Config
|
||||
fn func(cfg *config.Config)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
assertFn func(t *testing.T, cfg *config.Config)
|
||||
}{
|
||||
{
|
||||
name: "Flag didn't change",
|
||||
args: args{
|
||||
flagName: "port",
|
||||
flags: func() *pflag.FlagSet {
|
||||
return &pflag.FlagSet{}
|
||||
},
|
||||
cfg: &config.Config{
|
||||
Http: &config.HttpConfig{
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
fn: func(cfg *config.Config) {
|
||||
cfg.Http.Port = 9999
|
||||
},
|
||||
},
|
||||
assertFn: func(t *testing.T, cfg *config.Config) {
|
||||
require.Equal(t, cfg.Http.Port, 8080)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Flag changed",
|
||||
args: args{
|
||||
flagName: "port",
|
||||
flags: func() *pflag.FlagSet {
|
||||
pf := &pflag.FlagSet{}
|
||||
pf.IntP("port", "p", 8080, "Port used by the server")
|
||||
pf.Set("port", "9999")
|
||||
return pf
|
||||
},
|
||||
cfg: &config.Config{
|
||||
Http: &config.HttpConfig{
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
fn: func(cfg *config.Config) {
|
||||
cfg.Http.Port = 9999
|
||||
},
|
||||
},
|
||||
assertFn: func(t *testing.T, cfg *config.Config) {
|
||||
require.Equal(t, cfg.Http.Port, 9999)
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setIfFlagChanged(tt.args.flagName, tt.args.flags(), tt.args.cfg, tt.args.fn)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue