feat: add custom duration parser

This commit is contained in:
divyam234 2024-02-19 21:32:25 +05:30
parent 22f2fef688
commit c5019f8a67
4 changed files with 35 additions and 55 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/divyam234/teldrive/api"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/database"
"github.com/divyam234/teldrive/internal/duration"
"github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/internal/middleware"
"github.com/divyam234/teldrive/internal/tgc"
@ -49,14 +50,14 @@ func NewRun() *cobra.Command {
runCmd.Flags().StringP("config", "c", "", "config file (default is $HOME/.teldrive/config.toml)")
runCmd.Flags().IntVarP(&config.Server.Port, "server-port", "p", 8080, "Server port")
runCmd.Flags().DurationVar(&config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout")
duration.DurationVar(runCmd.Flags(), &config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout")
runCmd.Flags().IntVarP(&config.Log.Level, "log-level", "", -1, "Logging level")
runCmd.Flags().StringVar(&config.Log.File, "log-file", "", "Logging file path")
runCmd.Flags().BoolVar(&config.Log.Development, "log-development", false, "Enable development mode")
runCmd.Flags().StringVar(&config.JWT.Secret, "jwt-secret", "", "JWT secret key")
runCmd.Flags().DurationVar(&config.JWT.SessionTime, "jwt-session-time", (30*24)*time.Hour, "JWT session duration")
duration.DurationVar(runCmd.Flags(), &config.JWT.SessionTime, "jwt-session-time", (30*24)*time.Hour, "JWT session duration")
runCmd.Flags().StringSliceVar(&config.JWT.AllowedUsers, "jwt-allowed-users", []string{}, "Allowed users")
runCmd.Flags().StringVar(&config.DB.DataSource, "db-data-source", "", "Database connection string")
@ -64,7 +65,7 @@ func NewRun() *cobra.Command {
runCmd.Flags().BoolVar(&config.DB.Migrate.Enable, "db-migrate-enable", true, "Enable database migration")
runCmd.Flags().IntVar(&config.DB.Pool.MaxIdleConnections, "db-pool-max-open-connections", 25, "Database max open connections")
runCmd.Flags().IntVar(&config.DB.Pool.MaxIdleConnections, "db-pool-max-idle-connections", 25, "Database max idle connections")
runCmd.Flags().DurationVar(&config.DB.Pool.MaxLifetime, "db-pool-max-lifetime", 10*time.Minute, "Database max connection lifetime")
duration.DurationVar(runCmd.Flags(), &config.DB.Pool.MaxLifetime, "db-pool-max-lifetime", 10*time.Minute, "Database max connection lifetime")
runCmd.Flags().IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID")
runCmd.Flags().StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash")
@ -83,12 +84,14 @@ func NewRun() *cobra.Command {
runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable stream bots")
runCmd.Flags().StringVar(&config.TG.Uploads.EncryptionKey, "tg-uploads-encryption-key", "", "Uploads encryption key")
runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 16, "Uploads threads")
runCmd.Flags().DurationVar(&config.TG.Uploads.Retention, "tg-uploads-retention", (24*15)*time.Hour, "Uploads retention duration")
duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour,
"Uploads retention duration")
runCmd.MarkFlagRequired("tg-app-id")
runCmd.MarkFlagRequired("tg-app-hash")
runCmd.MarkFlagRequired("db-data-source")
runCmd.MarkFlagRequired("jwt-secret")
return runCmd
}

View file

@ -13,7 +13,7 @@
[jwt]
allowed-users = [""]
secret = ""
session-time = "720h"
session-time = "30d"
[log]
development = true
@ -41,5 +41,5 @@
[tg.uploads]
encryption-key = ""
retention = "360h"
retention = "7d"
threads = 16

View file

@ -1,41 +1,37 @@
package duration
import (
"encoding/json"
"errors"
"math"
"strconv"
"strings"
"time"
"github.com/spf13/pflag"
)
// Duration is a time.Duration with some more parsing options
type Duration time.Duration
// DurationOff is the default value for flags which can be turned off
const DurationOff = Duration((1 << 63) - 1)
// Turn Duration into a string
func (d Duration) String() string {
if d == DurationOff {
func (d *Duration) String() string {
if *d == DurationOff {
return "off"
}
for i := len(ageSuffixes) - 2; i >= 0; i-- {
ageSuffix := &ageSuffixes[i]
if math.Abs(float64(d)) >= float64(ageSuffix.Multiplier) {
timeUnits := float64(d) / float64(ageSuffix.Multiplier)
return strconv.FormatFloat(timeUnits, 'f', -1, 64) + ageSuffix.Suffix
}
ageSuffix := &ageSuffixes[0]
if math.Abs(float64(*d)) >= float64(ageSuffix.Multiplier) {
timeUnits := float64(*d) / float64(ageSuffix.Multiplier)
return strconv.FormatFloat(timeUnits, 'f', -1, 64) + ageSuffix.Suffix
}
return time.Duration(d).String()
return time.Duration(*d).String()
}
// IsSet returns if the duration is != DurationOff
func (d Duration) IsSet() bool {
return d != DurationOff
func (d *Duration) Set(s string) error {
v, err := parseDuration(s)
*d = Duration(v)
return err
}
// We use time conventions
var ageSuffixes = []struct {
Suffix string
Multiplier time.Duration
@ -44,12 +40,9 @@ var ageSuffixes = []struct {
{Suffix: "w", Multiplier: time.Hour * 24 * 7},
{Suffix: "M", Multiplier: time.Hour * 24 * 30},
{Suffix: "y", Multiplier: time.Hour * 24 * 365},
// Default to second
{Suffix: "", Multiplier: time.Second},
}
// parse the age as suffixed ages
func parseDurationSuffixes(age string) (time.Duration, error) {
var period float64
@ -69,14 +62,11 @@ func parseDurationSuffixes(age string) (time.Duration, error) {
return time.Duration(period), nil
}
// parseDurationFromNow parses a duration string. Allows ParseDuration to match the time
// package and easier testing within the fs package.
func parseDurationFromNow(age string) (d time.Duration, err error) {
if age == "off" {
return time.Duration(DurationOff), nil
}
// Attempt to parse as a time.Duration first
d, err = time.ParseDuration(age)
if err == nil {
return d, nil
@ -90,32 +80,19 @@ func parseDurationFromNow(age string) (d time.Duration, err error) {
return d, err
}
func ParseDuration(age string) (time.Duration, error) {
func newDurationValue(val time.Duration, p *time.Duration) *Duration {
*p = val
return (*Duration)(p)
}
func DurationVar(f *pflag.FlagSet, p *time.Duration, name string, value time.Duration, usage string) {
f.VarP(newDurationValue(value, p), name, "", usage)
}
func parseDuration(age string) (time.Duration, error) {
return parseDurationFromNow(age)
}
func (d Duration) Type() string {
func (d *Duration) Type() string {
return "Duration"
}
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
*d = Duration(value)
return nil
case string:
var err error
dur, err := ParseDuration(value)
*d = Duration(dur)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}

View file

@ -3,6 +3,6 @@ package duration
import "testing"
func TestDate(t *testing.T) {
res, _ := ParseDuration("15h2m10s")
res, _ := parseDuration("15h2m10s")
_ = res
}