diff --git a/cmd/run.go b/cmd/run.go index 085b8fd..e633e8d 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -14,6 +14,7 @@ import ( "github.com/divyam234/teldrive/api" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/database" + "github.com/divyam234/teldrive/internal/duration" "github.com/divyam234/teldrive/internal/kv" "github.com/divyam234/teldrive/internal/middleware" "github.com/divyam234/teldrive/internal/tgc" @@ -49,14 +50,14 @@ func NewRun() *cobra.Command { runCmd.Flags().StringP("config", "c", "", "config file (default is $HOME/.teldrive/config.toml)") runCmd.Flags().IntVarP(&config.Server.Port, "server-port", "p", 8080, "Server port") - runCmd.Flags().DurationVar(&config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout") + duration.DurationVar(runCmd.Flags(), &config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout") runCmd.Flags().IntVarP(&config.Log.Level, "log-level", "", -1, "Logging level") runCmd.Flags().StringVar(&config.Log.File, "log-file", "", "Logging file path") runCmd.Flags().BoolVar(&config.Log.Development, "log-development", false, "Enable development mode") runCmd.Flags().StringVar(&config.JWT.Secret, "jwt-secret", "", "JWT secret key") - runCmd.Flags().DurationVar(&config.JWT.SessionTime, "jwt-session-time", (30*24)*time.Hour, "JWT session duration") + duration.DurationVar(runCmd.Flags(), &config.JWT.SessionTime, "jwt-session-time", (30*24)*time.Hour, "JWT session duration") runCmd.Flags().StringSliceVar(&config.JWT.AllowedUsers, "jwt-allowed-users", []string{}, "Allowed users") runCmd.Flags().StringVar(&config.DB.DataSource, "db-data-source", "", "Database connection string") @@ -64,7 +65,7 @@ func NewRun() *cobra.Command { runCmd.Flags().BoolVar(&config.DB.Migrate.Enable, "db-migrate-enable", true, "Enable database migration") runCmd.Flags().IntVar(&config.DB.Pool.MaxIdleConnections, "db-pool-max-open-connections", 25, "Database max open connections") runCmd.Flags().IntVar(&config.DB.Pool.MaxIdleConnections, "db-pool-max-idle-connections", 25, "Database max idle connections") - runCmd.Flags().DurationVar(&config.DB.Pool.MaxLifetime, "db-pool-max-lifetime", 10*time.Minute, "Database max connection lifetime") + duration.DurationVar(runCmd.Flags(), &config.DB.Pool.MaxLifetime, "db-pool-max-lifetime", 10*time.Minute, "Database max connection lifetime") runCmd.Flags().IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID") runCmd.Flags().StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash") @@ -83,12 +84,14 @@ func NewRun() *cobra.Command { runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable stream bots") runCmd.Flags().StringVar(&config.TG.Uploads.EncryptionKey, "tg-uploads-encryption-key", "", "Uploads encryption key") runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 16, "Uploads threads") - runCmd.Flags().DurationVar(&config.TG.Uploads.Retention, "tg-uploads-retention", (24*15)*time.Hour, "Uploads retention duration") + duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, + "Uploads retention duration") runCmd.MarkFlagRequired("tg-app-id") runCmd.MarkFlagRequired("tg-app-hash") runCmd.MarkFlagRequired("db-data-source") runCmd.MarkFlagRequired("jwt-secret") + return runCmd } diff --git a/config.sample.toml b/config.sample.toml index 629c303..e1f54cc 100644 --- a/config.sample.toml +++ b/config.sample.toml @@ -13,7 +13,7 @@ [jwt] allowed-users = [""] secret = "" - session-time = "720h" + session-time = "30d" [log] development = true @@ -41,5 +41,5 @@ [tg.uploads] encryption-key = "" - retention = "360h" + retention = "7d" threads = 16 diff --git a/internal/duration/duration.go b/internal/duration/duration.go index 56e6a61..bd1fbcb 100644 --- a/internal/duration/duration.go +++ b/internal/duration/duration.go @@ -1,41 +1,37 @@ package duration import ( - "encoding/json" - "errors" "math" "strconv" "strings" "time" + + "github.com/spf13/pflag" ) -// Duration is a time.Duration with some more parsing options type Duration time.Duration -// DurationOff is the default value for flags which can be turned off const DurationOff = Duration((1 << 63) - 1) -// Turn Duration into a string -func (d Duration) String() string { - if d == DurationOff { +func (d *Duration) String() string { + if *d == DurationOff { return "off" } - for i := len(ageSuffixes) - 2; i >= 0; i-- { - ageSuffix := &ageSuffixes[i] - if math.Abs(float64(d)) >= float64(ageSuffix.Multiplier) { - timeUnits := float64(d) / float64(ageSuffix.Multiplier) - return strconv.FormatFloat(timeUnits, 'f', -1, 64) + ageSuffix.Suffix - } + + ageSuffix := &ageSuffixes[0] + if math.Abs(float64(*d)) >= float64(ageSuffix.Multiplier) { + timeUnits := float64(*d) / float64(ageSuffix.Multiplier) + return strconv.FormatFloat(timeUnits, 'f', -1, 64) + ageSuffix.Suffix } - return time.Duration(d).String() + return time.Duration(*d).String() } -// IsSet returns if the duration is != DurationOff -func (d Duration) IsSet() bool { - return d != DurationOff +func (d *Duration) Set(s string) error { + v, err := parseDuration(s) + *d = Duration(v) + return err } -// We use time conventions var ageSuffixes = []struct { Suffix string Multiplier time.Duration @@ -44,12 +40,9 @@ var ageSuffixes = []struct { {Suffix: "w", Multiplier: time.Hour * 24 * 7}, {Suffix: "M", Multiplier: time.Hour * 24 * 30}, {Suffix: "y", Multiplier: time.Hour * 24 * 365}, - - // Default to second {Suffix: "", Multiplier: time.Second}, } -// parse the age as suffixed ages func parseDurationSuffixes(age string) (time.Duration, error) { var period float64 @@ -69,14 +62,11 @@ func parseDurationSuffixes(age string) (time.Duration, error) { return time.Duration(period), nil } -// parseDurationFromNow parses a duration string. Allows ParseDuration to match the time -// package and easier testing within the fs package. func parseDurationFromNow(age string) (d time.Duration, err error) { if age == "off" { return time.Duration(DurationOff), nil } - // Attempt to parse as a time.Duration first d, err = time.ParseDuration(age) if err == nil { return d, nil @@ -90,32 +80,19 @@ func parseDurationFromNow(age string) (d time.Duration, err error) { return d, err } -func ParseDuration(age string) (time.Duration, error) { +func newDurationValue(val time.Duration, p *time.Duration) *Duration { + *p = val + return (*Duration)(p) +} + +func DurationVar(f *pflag.FlagSet, p *time.Duration, name string, value time.Duration, usage string) { + f.VarP(newDurationValue(value, p), name, "", usage) +} + +func parseDuration(age string) (time.Duration, error) { return parseDurationFromNow(age) } -func (d Duration) Type() string { +func (d *Duration) Type() string { return "Duration" } - -func (d *Duration) UnmarshalJSON(b []byte) error { - var v interface{} - if err := json.Unmarshal(b, &v); err != nil { - return err - } - switch value := v.(type) { - case float64: - *d = Duration(value) - return nil - case string: - var err error - dur, err := ParseDuration(value) - *d = Duration(dur) - if err != nil { - return err - } - return nil - default: - return errors.New("invalid duration") - } -} diff --git a/internal/duration/duration_test.go b/internal/duration/duration_test.go index 0341939..b418b17 100644 --- a/internal/duration/duration_test.go +++ b/internal/duration/duration_test.go @@ -3,6 +3,6 @@ package duration import "testing" func TestDate(t *testing.T) { - res, _ := ParseDuration("15h2m10s") + res, _ := parseDuration("15h2m10s") _ = res }