mirror of
https://github.com/bit1001/tdl.git
synced 2025-01-07 20:28:11 +08:00
151 lines
2.9 KiB
Go
151 lines
2.9 KiB
Go
package storage
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"github.com/gotd/td/telegram/updates"
|
|
"github.com/iyear/tdl/pkg/key"
|
|
"github.com/iyear/tdl/pkg/kv"
|
|
)
|
|
|
|
// State persists update-tracking data for the gotd updates engine
// (github.com/gotd/td/telegram/updates) in a kv.KV store. Per user it keeps
// one JSON-encoded updates.State under key.State(userID) and one JSON-encoded
// channelID -> pts map under key.StateChannel(userID).
//
// NOTE(review): the method set appears to mirror updates.StateStorage;
// confirm against the gotd version pinned in go.mod.
type State struct {
	// kv is the backing store; every value goes through Set (json.Marshal)
	// and Get (json.Unmarshal).
	kv *kv.KV
}
|
|
|
|
func NewState(kv *kv.KV) *State {
|
|
return &State{kv: kv}
|
|
}
|
|
|
|
func (s *State) Get(key string, v interface{}) error {
|
|
data, err := s.kv.Get(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return json.Unmarshal(data, v)
|
|
}
|
|
|
|
func (s *State) Set(key string, v interface{}) error {
|
|
data, err := json.Marshal(v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.kv.Set(key, data)
|
|
}
|
|
|
|
func (s *State) GetState(userID int64) (updates.State, bool, error) {
|
|
state := updates.State{}
|
|
|
|
if err := s.Get(key.State(userID), &state); err != nil {
|
|
if errors.Is(err, kv.ErrNotFound) {
|
|
return state, false, nil
|
|
}
|
|
return state, false, err
|
|
}
|
|
|
|
return state, true, nil
|
|
}
|
|
|
|
func (s *State) SetState(userID int64, state updates.State) error {
|
|
if err := s.Set(key.State(userID), state); err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.Set(key.StateChannel(userID), struct{}{})
|
|
}
|
|
|
|
func (s *State) SetPts(userID int64, pts int) error {
|
|
state, k := updates.State{}, key.State(userID)
|
|
|
|
if err := s.Get(k, &state); err != nil {
|
|
return err
|
|
}
|
|
state.Pts = pts
|
|
return s.Set(k, state)
|
|
}
|
|
|
|
func (s *State) SetQts(userID int64, qts int) error {
|
|
state, k := updates.State{}, key.State(userID)
|
|
|
|
if err := s.Get(k, &state); err != nil {
|
|
return err
|
|
}
|
|
state.Qts = qts
|
|
return s.Set(k, state)
|
|
}
|
|
|
|
func (s *State) SetDate(userID int64, date int) error {
|
|
state, k := updates.State{}, key.State(userID)
|
|
|
|
if err := s.Get(k, &state); err != nil {
|
|
return err
|
|
}
|
|
state.Date = date
|
|
return s.Set(k, state)
|
|
}
|
|
|
|
func (s *State) SetSeq(userID int64, seq int) error {
|
|
state, k := updates.State{}, key.State(userID)
|
|
|
|
if err := s.Get(k, &state); err != nil {
|
|
return err
|
|
}
|
|
state.Seq = seq
|
|
return s.Set(k, state)
|
|
}
|
|
|
|
func (s *State) SetDateSeq(userID int64, date, seq int) error {
|
|
state, k := updates.State{}, key.State(userID)
|
|
|
|
if err := s.Get(k, &state); err != nil {
|
|
return err
|
|
}
|
|
state.Date = date
|
|
state.Seq = seq
|
|
return s.Set(k, state)
|
|
}
|
|
|
|
func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) {
|
|
c := make(map[int64]int)
|
|
|
|
if err := s.Get(key.StateChannel(userID), &c); err != nil {
|
|
if errors.Is(err, kv.ErrNotFound) {
|
|
return 0, false, nil
|
|
}
|
|
return 0, false, err
|
|
}
|
|
|
|
pts, ok := c[channelID]
|
|
if !ok {
|
|
return 0, false, nil
|
|
}
|
|
|
|
return pts, true, nil
|
|
}
|
|
|
|
func (s *State) SetChannelPts(userID, channelID int64, pts int) error {
|
|
c, k := make(map[int64]int), key.StateChannel(userID)
|
|
|
|
if err := s.Get(k, &c); err != nil {
|
|
return err
|
|
}
|
|
c[channelID] = pts
|
|
return s.Set(k, c)
|
|
}
|
|
|
|
func (s *State) ForEachChannels(userID int64, f func(channelID int64, pts int) error) error {
|
|
c := make(map[int64]int)
|
|
|
|
if err := s.Get(key.StateChannel(userID), &c); err != nil {
|
|
return err
|
|
}
|
|
|
|
for channelID, pts := range c {
|
|
if err := f(channelID, pts); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|