major changes related to bots

This commit is contained in:
divyam234 2023-09-20 00:50:44 +05:30
parent fd29819b0d
commit b75f9bb22c
43 changed files with 1702 additions and 1141 deletions

View file

@ -31,15 +31,17 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Image Name
id: imagename
run: echo "name=ghcr.io/${GITHUB_REPOSITORY,,}/server" >> $GITHUB_OUTPUT
- name: Set Vars
id: vars
run: |
echo "TAG=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
echo "IMAGE_NAME=ghcr.io/${GITHUB_REPOSITORY,,}/server" >> $GITHUB_ENV
- name: Build Image
uses: docker/build-push-action@v3
with:
context: ./
platforms: linux/amd64,linux/arm64,linux/arm/v7
platforms: linux/amd64,linux/arm64
pull: true
push: true
tags: ${{ steps.imagename.outputs.name }}:latest
tags: ${{ env.IMAGE_NAME }}:${{ env.TAG }} , ${{ env.IMAGE_NAME }}:latest

4
.gitignore vendored
View file

@ -13,7 +13,6 @@
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
@ -27,4 +26,5 @@ sslcerts
*.env
*.env.example
*.env.local
*.env.staging
*.env.staging
*.db

View file

@ -3,7 +3,9 @@
Telegram Drive is a powerful utility that enables you to create your own cloud storage service using Telegram as the backend.
[![Discord](https://img.shields.io/discord/1142377485737148479?label=discord&logo=discord&style=flat-square&logoColor=white)](https://discord.gg/J2gVAZnHfP)
[![Discord](https://img.shields.io/discord/1142377485737148479?label=discord&logo=discord&style=flat-square&logoColor=white)](https://discord.gg/J2gVAZnHfP)
**Click on icon to join Discord Server for better support**
<details open="open">
@ -47,15 +49,15 @@ cd teldrive
**Follow Below Steps**
- Create the `.env` or `teldrive.env` file with your variables and start your container.
- Create the `teldrive.env` file with your variables and start your container.
```sh
docker compose up -d
```
- **Go to http://localhost:8080**
- **Uploads from UI will be slower due to limitations of browser use [Teldrive Uploader](https://github.com/divyam234/teldrive-upload) for faster uploads.Make sure to use Multi Client mode if you are using uploader.**
- **Uploads from UI will be slower due to limitations of browser use [Teldrive Uploader](https://github.com/divyam234/teldrive-upload) for faster uploads.Make sure to use Multi Bots mode if you are using uploader.**
- **If you intend to share download links with others, ensure that you enable multi-client mode with bots.**
- **If you intend to share download links with others, ensure that you enable multi bots mode with bots.**
### Use without docker
@ -69,8 +71,8 @@ docker compose up -d
## Setting up things
If you're locally or remotely hosting, create a file named `.env` or `teldrive.env` in the root directory and add all the variables there.
An example of `.env` file:
If you're locally or remotely hosting, create a file named `teldrive.env` in the root directory and add all the variables there.
An example of `teldrive.env` file:
```sh
APP_ID=1234
@ -81,60 +83,45 @@ COOKIE_SAME_SITE=true
JWT_SECRET=abc
DATABASE_URL=abc
RATE_LIMIT=true
TG_CLIENT_DEVICE_MODEL="Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0" # Any valid browser user agent here
MULTI_CLIENT=false
MULTI_TOKEN1=""
MULTI_TOKEN2=""
MULTI_TOKEN3=""
```
**Use strong JWT secret instead of pure guessable string.You can use openssl to generate it.**
```
> **Warning**
>Default Channel can be selected through UI make sure to set it from account settings on first login.<br>
>Use strong JWT secret instead of pure guessable string.You can use openssl to generate it.<br>
```bash
$ openssl rand -base64 32
$ openssl rand -hex 32
```
**Multi Bots Mode is recommended to avoid flood errors and enable maximum download speed, especially if you are using downloaders like IDM and aria2c which use multiple connections for downloads.**
> **Note**
> What it multi bots feature and what it does? <br>
> This feature shares the Telegram API requests between other bots to avoid getting floodwaited (A kind of rate limiting that Telegram does in the backend to avoid flooding their servers) and to make the server handle more requests. <br>
To enable multi bots, generate new bot tokens from BotFather and add it through UI on first login.
**Multi-Client Mode is recommended to avoid flood errors and enable maximum download speed, especially if you are using downloaders like IDM and aria2c which use multiple connections for downloads.**
### Mandatory Vars
Before running the bot, you will need to set up the following mandatory variables:
- `APP_ID` : Use official ones as mentioned above.
- `APP_ID` : This is the API ID for your Telegram account, which can be obtained from my.telegram.org.
- `APP_HASH` : Use official ones as mentioned above.
- `APP_HASH` : This is the API HASH for your Telegram account, which can be obtained from my.telegram.org.
- `JWT_SECRET` : Used for signing jwt tokens
- `DATABASE_URL` : Connection String obtained from Postgres DB (you can use Neon db as free alternative fro postgres)
- `CHANNEL_ID` : This is the channel ID for the log channel where app will store files . To obtain a channel ID, create a new telegram channel (public or private), post something in the channel, forward the message to [@JsonDumpBot](https://t.me/JsonDumpBot) . Copy the forwarded channel ID and paste it into the this field and remove -100 from the start.
### Optional Vars
In addition to the mandatory variables, you can also set the following optional variables:
- `HTTPS` : Only needed when frontend is deployed on vercel.
- `HTTPS` : Only needed when frontend is on other domain.
- `PORT` : Change listen port default is 8080
- `ALLOWED_USERS` : Allow certain telegram usernames including yours to access the app.Enter comma seperated telegram usernames here.Its needed when your instance is on public cloud and you want to restrict other people to access you app.
- `COOKIE_SAME_SITE` : Only needed when frontend is deployed on vercel.
- `MULTI_CLIENT` : Enable or Disable Multi Token mode. If true you have pass atleast one Multi Token
- `MULTI_TOKEN[1....]` : Recommended to add atleast 10-12 tokens
### For making use of Multi-Client support
> **Note**
> What it multi-client feature and what it does? <br>
> This feature shares the Telegram API requests between other bots to avoid getting floodwaited (A kind of rate limiting that Telegram does in the backend to avoid flooding their servers) and to make the server handle more requests. <br>
To enable multi-client, generate new bot tokens and add it as your environmental variables with the following key names.
`MULTI_TOKEN1`: Add your first bot token here.
`MULTI_TOKEN2`: Add your second bot token here.
you may also add as many as bots you want. (max limit is not tested yet)
`MULTI_TOKEN3`, `MULTI_TOKEN4`, etc.
### For making use of Multi Bots support
> **Warning**
> Don't forget to add all these bots to the `CHANNEL_ID` as admin for the proper functioning
>Bots will be auto added as admin in channel if you set them from UI if it fails somehow add it manually.
## FAQ
- How to get Postgres DB url ?

99
cache/bigcache.go vendored
View file

@ -1,99 +0,0 @@
package cache
import (
"bytes"
"context"
"encoding/gob"
"errors"
"github.com/allegro/bigcache/v3"
)
type bigCache struct {
cache *bigcache.BigCache
}
func newBigCache(cacheConfig *cacheConfig) (*bigCache, error) {
cache, err := bigcache.New(context.Background(), bigcache.Config{
Shards: 16,
LifeWindow: cacheConfig.ttl,
CleanWindow: cacheConfig.cleanFreq,
MaxEntriesInWindow: 1000 * 10 * 60,
MaxEntrySize: 500,
Verbose: false,
HardMaxCacheSize: cacheConfig.size,
StatsEnabled: true,
})
if err != nil {
return nil, err
}
return &bigCache{
cache: cache,
}, nil
}
// Set inserts the key/value pair into the cache.
// Only the exported fields of the given struct will be
// serialized and stored
func (c *bigCache) Set(key, value interface{}) error {
keyString, ok := key.(string)
if !ok {
return errors.New("a cache key must be a string")
}
valueBytes, err := serializeGOB(value)
if err != nil {
return err
}
return c.cache.Set(keyString, valueBytes)
}
// Get returns the value correlating to the key in the cache
func (c *bigCache) Get(key interface{}) (interface{}, error) {
// Assert the key is of string type
keyString, ok := key.(string)
if !ok {
return nil, errors.New("a cache key must be a string")
}
// Get the value in the byte format it is stored in
valueBytes, err := c.cache.Get(keyString)
if err != nil {
return nil, err
}
// Deserialize the bytes of the value
value, err := deserializeGOB(valueBytes)
if err != nil {
return nil, err
}
return value, nil
}
func serializeGOB(value interface{}) ([]byte, error) {
buf := bytes.Buffer{}
enc := gob.NewEncoder(&buf)
gob.Register(value)
err := enc.Encode(&value)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func deserializeGOB(valueBytes []byte) (interface{}, error) {
var value interface{}
buf := bytes.NewBuffer(valueBytes)
dec := gob.NewDecoder(buf)
err := dec.Decode(&value)
if err != nil {
return nil, err
}
return value, nil
}

72
cache/cache.go vendored
View file

@ -1,72 +0,0 @@
package cache
import (
"time"
)
type cacheConfig struct {
size int // Size in MB
ttl time.Duration
cleanFreq time.Duration
}
// Interface to wrap any caching implementation
type Cache interface {
Set(key, value interface{}) error // Only exported fields in struct will be stored
Get(key interface{}) (interface{}, error)
}
// New builds a new default cache. You may pass options to modify the default values
func New(opts ...Option) (Cache, error) {
cacheConfig := &cacheConfig{
size: 1,
ttl: 60 * time.Second,
cleanFreq: 30 * time.Second,
}
for _, opt := range opts {
opt.apply(cacheConfig)
}
cache, err := newBigCache(cacheConfig)
if err != nil {
return nil, err
}
return cache, nil
}
type Option interface {
apply(cacheConfig *cacheConfig)
}
type optionFunc func(*cacheConfig)
func (opt optionFunc) apply(cacheConfig *cacheConfig) {
opt(cacheConfig)
}
// WithSizeInMB sets the size of the cache in MBs
// The minimum size of the cache is 1 MB
// If a size of 0 or less is passed the cache will have unlimited size
func WithSizeInMB(size int) Option {
return optionFunc(func(cacheConfig *cacheConfig) {
cacheConfig.size = size
})
}
// WithTTL will cause the cache to expire any item that lives longer
// than the given ttl
func WithTTL(ttl time.Duration) Option {
return optionFunc(func(cacheConfig *cacheConfig) {
cacheConfig.ttl = ttl
})
}
// WithCleanFrequency sets how often the cache will clean out expired items
// The lowest the frequency may be is 1 second
// If the time is 0 then no cleaning will happen and items will never be removed
func WithCleanFrequency(cleanFreq time.Duration) Option {
return optionFunc(func(cacheConfig *cacheConfig) {
cacheConfig.cleanFreq = cleanFreq
})
}

63
cache/cachutil.go vendored
View file

@ -1,63 +0,0 @@
package cache
import (
"reflect"
"time"
)
var globalCache Cache
func CacheInit() {
var err error
globalCache, err = New(
WithSizeInMB(10),
WithTTL(12*time.Hour),
WithCleanFrequency(24*time.Hour),
)
if err != nil {
panic("Failed to initialize global cache: " + err.Error())
}
}
func GetCache() Cache {
return globalCache
}
func CachedFunction(fn interface{}, key string) func(...interface{}) (interface{}, error) {
return func(args ...interface{}) (interface{}, error) {
// Check if the result is already cached
if cachedResult, err := globalCache.Get(key); err == nil {
return cachedResult, nil
}
// If not cached, call the original function to get the result
f := reflect.ValueOf(fn)
if len(args) == 0 {
args = nil // Ensure nil is passed when there are no arguments.
}
result := f.Call(getArgs(args))
// Check if the function returned an error as the last return value
if err, ok := result[len(result)-1].Interface().(error); ok && err != nil {
return nil, err
}
// Extract the result from the function call
finalResult := result[0].Interface()
// Cache the result with a default TTL (time-to-live)
globalCache.Set(key, finalResult)
return finalResult, nil
}
}
func getArgs(args []interface{}) []reflect.Value {
var values []reflect.Value
for _, arg := range args {
values = append(values, reflect.ValueOf(arg))
}
return values
}

View file

@ -4,10 +4,13 @@ import (
"embed"
"log"
"os"
"path/filepath"
"time"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/kv"
"github.com/pressly/goose/v3"
"go.etcd.io/bbolt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@ -17,6 +20,8 @@ import (
//go:embed migrations/*.sql
var embedMigrations embed.FS
var DB *gorm.DB
var BoltDB *bbolt.DB
var KV kv.KV
func InitDB() {
@ -63,6 +68,19 @@ func InitDB() {
}
}()
config := utils.GetConfig()
BoltDB, err = bbolt.Open(filepath.Join(config.ExecDir, "teldrive.db"), 0666, &bbolt.Options{
Timeout: time.Second,
NoGrowSync: false,
})
if err != nil {
panic(err)
}
KV, err = kv.New(kv.Options{Bucket: "teldrive", DB: BoltDB})
if err != nil {
panic(err)
}
}
func migrate() {

View file

@ -0,0 +1,3 @@
-- +goose Up
ALTER TABLE teldrive.users DROP COLUMN settings;

View file

@ -0,0 +1,27 @@
-- +goose Up
-- +goose StatementBegin
CREATE TABLE teldrive.bots (
user_id bigint NOT NULL,
token text NOT NULL,
bot_user_name text NOT NULL,
bot_id bigint NOT NULL,
FOREIGN KEY (user_id) REFERENCES teldrive.users(user_id),
CONSTRAINT btoken_user_un UNIQUE (user_id,token)
);
CREATE TABLE teldrive.channels (
channel_id bigint NOT NULL PRIMARY KEY,
channel_name text NOT NULL,
user_id bigint NOT NULL,
selected boolean DEFAULT false,
FOREIGN KEY (user_id) REFERENCES teldrive.users(user_id)
);
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP TABLE IF EXISTS teldrive.bots;
DROP TABLE IF EXISTS teldrive.channels;
-- +goose StatementEnd

View file

@ -0,0 +1,22 @@
-- +goose Up
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION teldrive.account_stats(
IN u_id BIGINT
) RETURNS TABLE (total_size BIGINT, total_files BIGINT, ch_id BIGINT,ch_name TEXT ) AS $$
DECLARE
total_size BIGINT;
total_files BIGINT;
ch_id BIGINT;
ch_name TEXT;
BEGIN
SELECT COUNT(*), SUM(size) into total_files,total_size FROM teldrive.files WHERE user_id=u_id AND type= 'file' and status='active';
SELECT channel_id ,channel_name into ch_id,ch_name FROM teldrive.channels WHERE selected=TRUE AND user_id=u_id;
RETURN QUERY SELECT total_size,total_files,ch_id,ch_name;
END;
$$ LANGUAGE plpgsql;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP FUNCTION IF EXISTS teldrive.account_stats;
-- +goose StatementEnd

View file

@ -7,6 +7,7 @@ services:
container_name: teldrive
volumes:
- ./sessions:/app/sessions:rw
env_file: .env
- ./teldrive.db:/app/teldrive.db:rw
env_file: teldrive.env
ports:
- 8080:8080

23
go.mod
View file

@ -3,19 +3,20 @@ module github.com/divyam234/teldrive
go 1.21
require (
github.com/allegro/bigcache/v3 v3.1.0
github.com/divyam234/cors v1.4.2
github.com/gin-gonic/gin v1.9.1
github.com/go-co-op/gocron v1.32.1
github.com/go-co-op/gocron v1.34.0
github.com/go-jose/go-jose/v3 v3.0.0
github.com/gotd/contrib v0.19.0
github.com/gotd/td v0.85.0
github.com/gotd/td v0.87.0
github.com/joho/godotenv v1.5.1
github.com/kelseyhightower/envconfig v1.4.0
github.com/mitchellh/mapstructure v1.5.0
github.com/pkg/errors v0.9.1
github.com/quantumsheep/range-parser v1.1.0
go.uber.org/zap v1.25.0
github.com/thoas/go-funk v0.9.3
go.etcd.io/bbolt v1.3.7
go.uber.org/zap v1.26.0
golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb
golang.org/x/time v0.3.0
gorm.io/driver/postgres v1.5.2
@ -23,7 +24,7 @@ require (
)
require (
github.com/google/uuid v1.3.0 // indirect
github.com/google/uuid v1.3.1 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
)
@ -62,16 +63,16 @@ require (
github.com/segmentio/asm v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
go.opentelemetry.io/otel v1.16.0 // indirect
go.opentelemetry.io/otel/trace v1.16.0 // indirect
go.opentelemetry.io/otel v1.18.0 // indirect
go.opentelemetry.io/otel/trace v1.18.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.12.0 // indirect
golang.org/x/net v0.14.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
nhooyr.io/websocket v1.8.7 // indirect

52
go.sum
View file

@ -1,7 +1,3 @@
github.com/allegro/bigcache/v3 v3.1.0 h1:H2Vp8VOvxcrB91o86fUSVJFqeuz8kpyyB02eH3bSzwk=
github.com/allegro/bigcache/v3 v3.1.0/go.mod h1:aPyh7jEvrog9zAwx5N7+JUQX5dZTSGpxF1LAR4dr35I=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@ -30,8 +26,8 @@ github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwv
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-co-op/gocron v1.32.1 h1:h+StA6Qzlv+ImlCaLfA26rLN9eS/l4sO7oWmPUbRVIY=
github.com/go-co-op/gocron v1.32.1/go.mod h1:UGz2oYvVS6PsqlwuOdo5L1Djsg/cQjxJ6T5ntkhp9Bg=
github.com/go-co-op/gocron v1.34.0 h1:/rcOZjJWUYnGR0ZDKozPXEnJ+wJt220FSLo2/hxZvV0=
github.com/go-co-op/gocron v1.34.0/go.mod h1:NLi+bkm4rRSy1F8U7iacZOz0xPseMoIOnvabGoSe/no=
github.com/go-faster/errors v0.6.1 h1:nNIPOBkprlKzkThvS/0YaX8Zs9KewLCOSFQS5BU06FI=
github.com/go-faster/errors v0.6.1/go.mod h1:5MGV2/2T9yvlrbhe9pD9LO5Z/2zCSq2T8j+Jpi2LAyY=
github.com/go-faster/jx v1.1.0 h1:ZsW3wD+snOdmTDy9eIVgQdjUpXRRV4rqW8NS3t+20bg=
@ -76,8 +72,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
@ -87,8 +83,8 @@ github.com/gotd/ige v0.2.2 h1:XQ9dJZwBfDnOGSTxKXBGP4gMud3Qku2ekScRjDWWfEk=
github.com/gotd/ige v0.2.2/go.mod h1:tuCRb+Y5Y3eNTo3ypIfNpQ4MFjrnONiL2jN2AKZXmb0=
github.com/gotd/neo v0.1.5 h1:oj0iQfMbGClP8xI59x7fE/uHoTJD7NZH9oV1WNuPukQ=
github.com/gotd/neo v0.1.5/go.mod h1:9A2a4bn9zL6FADufBdt7tZt+WMhvZoc5gWXihOPoiBQ=
github.com/gotd/td v0.85.0 h1:yDKBAdNwcNuICOqhlXFfadZ//Loanylzmqx807Q3LGI=
github.com/gotd/td v0.85.0/go.mod h1:zEftpfZa/x7OC70BHmgqqa8gC/OHwkKis8CpcnxKx9o=
github.com/gotd/td v0.87.0 h1:WdsNdx+GZdUae2ow2wVfawghW4if9wBbFmbRcssLf24=
github.com/gotd/td v0.87.0/go.mod h1:ZtLcOeBvDSs8g17vZ2nvISnL0RP9cTITCaUBHohN+KI=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
@ -176,6 +172,8 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@ -184,10 +182,12 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s=
go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4=
go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs=
go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0=
go.etcd.io/bbolt v1.3.7 h1:j+zJOnnEjF/kyHlDDgGnVL/AIqIJPq8UoB2GSNfkUfQ=
go.etcd.io/bbolt v1.3.7/go.mod h1:N9Mkw9X8x5fupy0IKsmuqVtoGDyxsaDlbk4Rd05IAQw=
go.opentelemetry.io/otel v1.18.0 h1:TgVozPGZ01nHyDZxK5WGPFB9QexeTMXEH7+tIClWfzs=
go.opentelemetry.io/otel v1.18.0/go.mod h1:9lWqYO0Db579XzVuCKFNPDl4s73Voa+zEck3wHaAYQI=
go.opentelemetry.io/otel/trace v1.18.0 h1:NY+czwbHbmndxojTEKiSMHkG2ClNH2PwmcHrdo0JY10=
go.opentelemetry.io/otel/trace v1.18.0/go.mod h1:T2+SGJGuYZY3bjj5rgh/hN7KIrlpWC5nS8Mjvzckz+0=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
@ -195,24 +195,24 @@ go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.25.0 h1:4Hvk6GtkucQ790dqmj7l1eEnRdKm3k3ZUrUMS2d5+5c=
go.uber.org/zap v1.25.0/go.mod h1:JIAUzQIH94IC4fOJQm7gMmBJP5k7wQfdcnYdPoEXJYk=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb h1:mIKbk8weKhSeLH2GmUTrvx8CjkyJmnU1wFmg59CUjFA=
golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
@ -225,21 +225,21 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss=
golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

View file

@ -6,12 +6,10 @@ import (
"path/filepath"
"time"
"github.com/divyam234/teldrive/cache"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/routes"
"github.com/divyam234/teldrive/ui"
"github.com/divyam234/teldrive/utils"
"github.com/gin-contrib/gzip"
"github.com/divyam234/cors"
"github.com/divyam234/teldrive/utils/cron"
@ -25,18 +23,12 @@ func main() {
router := gin.Default()
router.Use(gzip.Gzip(gzip.DefaultCompression))
utils.InitConfig()
utils.InitializeLogger()
database.InitDB()
cache.CacheInit()
utils.InitBotClients()
scheduler := gocron.NewScheduler(time.UTC)
scheduler.Every(1).Hour().Do(cron.FilesDeleteJob)

57
mapper/main.go Normal file
View file

@ -0,0 +1,57 @@
package mapper
import (
"github.com/divyam234/teldrive/models"
"github.com/divyam234/teldrive/schemas"
)
func MapFileToFileOut(file models.File) schemas.FileOut {
return schemas.FileOut{
ID: file.ID,
Name: file.Name,
Type: file.Type,
MimeType: file.MimeType,
Path: file.Path,
Size: file.Size,
Starred: file.Starred,
ParentID: file.ParentID,
UpdatedAt: file.UpdatedAt,
}
}
func MapFileInToFile(file schemas.FileIn) models.File {
return models.File{
Name: file.Name,
Type: file.Type,
MimeType: file.MimeType,
Path: file.Path,
Size: file.Size,
Starred: file.Starred,
Depth: file.Depth,
UserID: file.UserID,
ParentID: file.ParentID,
Parts: file.Parts,
ChannelID: file.ChannelID,
Status: file.Status,
}
}
func MapFileToFileOutFull(file models.File) *schemas.FileOutFull {
return &schemas.FileOutFull{
FileOut: MapFileToFileOut(file),
Parts: file.Parts, ChannelID: file.ChannelID,
}
}
func MapUploadSchema(in *models.Upload) *schemas.UploadPartOut {
out := &schemas.UploadPartOut{
ID: in.ID,
Name: in.Name,
PartId: in.PartId,
ChannelID: in.ChannelID,
PartNo: in.PartNo,
TotalParts: in.TotalParts,
Size: in.Size,
}
return out
}

8
models/bot.model.go Normal file
View file

@ -0,0 +1,8 @@
package models
type Bot struct {
Token string `gorm:"type:text;primaryKey"`
UserID int64 `gorm:"type:bigint"`
BotID int64 `gorm:"type:bigint"`
BotUserName string `gorm:"type:text"`
}

8
models/channel.go Normal file
View file

@ -0,0 +1,8 @@
package models
type Channel struct {
ChannelID int64 `gorm:"type:bigint;primaryKey"`
ChannelName string `gorm:"type:text"`
UserID int64 `gorm:"type:bigint;"`
Selected bool `gorm:"type:boolean;"`
}

View file

@ -5,7 +5,6 @@ import (
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/services"
"github.com/divyam234/teldrive/utils"
"github.com/gin-gonic/gin"
)
@ -13,10 +12,9 @@ import (
func addFileRoutes(rg *gin.RouterGroup) {
r := rg.Group("/files")
r.Use(Authmiddleware)
fileService := services.FileService{Db: database.DB, ChannelID: utils.GetConfig().ChannelID}
fileService := services.FileService{Db: database.DB}
r.GET("", func(c *gin.Context) {
r.GET("", Authmiddleware, func(c *gin.Context) {
res, err := fileService.ListFiles(c)
if err != nil {
@ -27,7 +25,7 @@ func addFileRoutes(rg *gin.RouterGroup) {
c.JSON(http.StatusOK, res)
})
r.POST("", func(c *gin.Context) {
r.POST("", Authmiddleware, func(c *gin.Context) {
res, err := fileService.CreateFile(c)
@ -39,7 +37,7 @@ func addFileRoutes(rg *gin.RouterGroup) {
c.JSON(http.StatusOK, res)
})
r.GET("/:fileID", func(c *gin.Context) {
r.GET("/:fileID", Authmiddleware, func(c *gin.Context) {
res, err := fileService.GetFileByID(c)
@ -51,7 +49,7 @@ func addFileRoutes(rg *gin.RouterGroup) {
c.JSON(http.StatusOK, res)
})
r.PATCH("/:fileID", func(c *gin.Context) {
r.PATCH("/:fileID", Authmiddleware, func(c *gin.Context) {
res, err := fileService.UpdateFile(c)
@ -68,7 +66,7 @@ func addFileRoutes(rg *gin.RouterGroup) {
fileService.GetFileStream(c)
})
r.POST("/movefiles", func(c *gin.Context) {
r.POST("/movefiles", Authmiddleware, func(c *gin.Context) {
res, err := fileService.MoveFiles(c)
@ -80,7 +78,7 @@ func addFileRoutes(rg *gin.RouterGroup) {
c.JSON(http.StatusOK, res)
})
r.POST("/makedir", func(c *gin.Context) {
r.POST("/makedir", Authmiddleware, func(c *gin.Context) {
res, err := fileService.MakeDirectory(c)
@ -92,7 +90,7 @@ func addFileRoutes(rg *gin.RouterGroup) {
c.JSON(http.StatusOK, res)
})
r.POST("/deletefiles", func(c *gin.Context) {
r.POST("/deletefiles", Authmiddleware, func(c *gin.Context) {
res, err := fileService.DeleteFiles(c)

View file

@ -4,7 +4,6 @@ import (
"net/http"
"time"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/auth"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v3/jwt"
@ -12,9 +11,6 @@ import (
func Authmiddleware(c *gin.Context) {
if c.FullPath() == "/api/files/:fileID/:fileName" && utils.GetConfig().MultiClient {
c.Next()
}
cookie, err := c.Request.Cookie("user-session")
if err != nil {

View file

@ -5,7 +5,6 @@ import (
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/services"
"github.com/divyam234/teldrive/utils"
"github.com/gin-gonic/gin"
)
@ -15,7 +14,7 @@ func addUploadRoutes(rg *gin.RouterGroup) {
r := rg.Group("/uploads")
r.Use(Authmiddleware)
uploadService := services.UploadService{Db: database.DB, ChannelID: utils.GetConfig().ChannelID}
uploadService := services.UploadService{Db: database.DB}
r.GET("/:id", func(c *gin.Context) {

View file

@ -1,6 +1,9 @@
package routes
import (
"net/http"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/services"
"github.com/gin-gonic/gin"
@ -9,11 +12,87 @@ import (
func addUserRoutes(rg *gin.RouterGroup) {
r := rg.Group("/users")
r.Use(Authmiddleware)
userService := services.UserService{}
userService := services.UserService{Db: database.DB}
r.GET("/profile", func(c *gin.Context) {
if c.Query("photo") != "" {
userService.GetProfilePhoto(c)
}
})
r.GET("/stats", func(c *gin.Context) {
res, err := userService.Stats(c)
if err != nil {
c.AbortWithError(err.Code, err.Error)
return
}
c.JSON(http.StatusOK, res)
})
r.GET("/bots", func(c *gin.Context) {
res, err := userService.GetBots(c)
if err != nil {
c.AbortWithError(err.Code, err.Error)
return
}
c.JSON(http.StatusOK, res)
})
r.GET("/channels", func(c *gin.Context) {
res, err := userService.ListChannels(c)
if err != nil {
c.AbortWithError(err.Code, err.Error)
return
}
c.JSON(http.StatusOK, res)
})
r.PATCH("/channels", func(c *gin.Context) {
res, err := userService.UpdateChannel(c)
if err != nil {
c.AbortWithError(err.Code, err.Error)
return
}
c.JSON(http.StatusOK, res)
})
r.POST("/bots", func(c *gin.Context) {
res, err := userService.AddBots(c)
if err != nil {
c.AbortWithError(err.Code, err.Error)
return
}
c.JSON(http.StatusCreated, res)
})
r.GET("/bots/revoke", func(c *gin.Context) {
res, err := userService.RevokeBotSession(c)
if err != nil {
c.AbortWithError(err.Code, err.Error)
return
}
c.JSON(http.StatusOK, res)
})
r.DELETE("/cache", func(c *gin.Context) {
res, err := userService.ClearCache(c)
if err != nil {
c.AbortWithError(err.Code, err.Error)
return
}
c.JSON(http.StatusOK, res)
})
}

13
schemas/user.schema.go Normal file
View file

@ -0,0 +1,13 @@
package schemas
type AccountStats struct {
TotalSize int64 `json:"totalSize"`
TotalFiles int64 `json:"totalFiles"`
ChId int64 `json:"channelId,omitempty"`
ChName string `json:"channelName,omitempty"`
}
type Channel struct {
ChannelID int64 `json:"channelId"`
ChannelName string `json:"channelName"`
}

View file

@ -3,23 +3,27 @@ package services
import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"math/big"
"net"
"net/http"
"strconv"
"time"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/models"
"github.com/divyam234/teldrive/schemas"
"github.com/divyam234/teldrive/types"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/auth"
"github.com/divyam234/teldrive/utils/kv"
"github.com/divyam234/teldrive/utils/tgc"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/gorilla/websocket"
@ -141,6 +145,11 @@ func (as *AuthService) LogIn(c *gin.Context) (*schemas.Message, *types.AppError)
IsPremium: session.IsPremium,
}
tokenBytes, _ := json.Marshal(jwtClaims)
md5hash := md5.Sum(tokenBytes)
hexToken := hex.EncodeToString(md5hash[:])
jwtClaims.Hash = hexToken
jweToken, err := auth.Encode(jwtClaims)
if err != nil {
@ -157,12 +166,15 @@ func (as *AuthService) LogIn(c *gin.Context) (*schemas.Message, *types.AppError)
var result []models.User
if err := as.Db.Model(&models.User{}).Where("user_id = ?", session.UserID).Find(&result).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to create or update user"), Code: http.StatusInternalServerError}
if err := as.Db.Model(&models.User{}).Where("user_id = ?", session.UserID).
Find(&result).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to create or update user"),
Code: http.StatusInternalServerError}
}
if len(result) == 0 {
if err := as.Db.Create(&user).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to create or update user"), Code: http.StatusInternalServerError}
return nil, &types.AppError{Error: errors.New("failed to create or update user"),
Code: http.StatusInternalServerError}
}
//Create root folder on first login
@ -177,14 +189,21 @@ func (as *AuthService) LogIn(c *gin.Context) (*schemas.Message, *types.AppError)
ParentID: "root",
}
if err := as.Db.Create(file).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to create or update user"), Code: http.StatusInternalServerError}
return nil, &types.AppError{Error: errors.New("failed to create or update user"),
Code: http.StatusInternalServerError}
}
} else {
if err := as.Db.Model(&models.User{}).Where("user_id = ?", session.UserID).Update("tg_session", session.Sesssion).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to create or update user"), Code: http.StatusInternalServerError}
if err := as.Db.Model(&models.User{}).Where("user_id = ?", session.UserID).
Update("tg_session", session.Sesssion).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to create or update user"),
Code: http.StatusInternalServerError}
}
}
setCookie(c, as.SessionCookieName, jweToken, as.SessionMaxAge)
database.KV.Set(kv.Key("sessions", hexToken), tokenBytes)
return &schemas.Message{Status: true, Message: "login success"}, nil
}
@ -206,7 +225,10 @@ func (as *AuthService) GetSession(c *gin.Context) *types.Session {
newExpires := now.Add(time.Duration(as.SessionMaxAge) * time.Second)
session := &types.Session{Name: jwePayload.Name, UserName: jwePayload.UserName, Expires: newExpires.Format(time.RFC3339)}
session := &types.Session{Name: jwePayload.Name,
UserName: jwePayload.UserName,
Hash: jwePayload.Hash,
Expires: newExpires.Format(time.RFC3339)}
jwePayload.IssuedAt = jwt.NewNumericDate(now)
@ -224,16 +246,15 @@ func (as *AuthService) GetSession(c *gin.Context) *types.Session {
func (as *AuthService) Logout(c *gin.Context) (*schemas.Message, *types.AppError) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
client, _ := tgc.UserLogin(jwtUser.TgSession)
client, _ := utils.GetAuthClient(c, jwtUser.TgSession, userId)
client.Run(c, func(ctx context.Context) error {
tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
_, err := client.API().AuthLogOut(c)
return err
})
setCookie(c, as.SessionCookieName, "", -1)
database.KV.Delete(kv.Key("sessions", jwtUser.Hash))
return &schemas.Message{Status: true, Message: "logout success"}, nil
}
@ -258,7 +279,6 @@ func (as *AuthService) HandleMultipleLogin(c *gin.Context) {
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Println(err)
return
}
defer conn.Close()
@ -266,117 +286,119 @@ func (as *AuthService) HandleMultipleLogin(c *gin.Context) {
dispatcher := tg.NewUpdateDispatcher()
loggedIn := qrlogin.OnLoginToken(dispatcher)
sessionStorage := &session.StorageMemory{}
tgClient, stop, _ := utils.StartNonAuthClient(dispatcher, sessionStorage)
tgClient := tgc.NoLogin(dispatcher, sessionStorage)
defer stop()
err = tgClient.Run(c, func(ctx context.Context) error {
for {
message := &SocketMessage{}
err := conn.ReadJSON(message)
for {
message := &SocketMessage{}
err := conn.ReadJSON(message)
if err != nil {
return err
}
if message.AuthType == "qr" {
go func() {
authorization, err := tgClient.QR().Auth(c, loggedIn, func(ctx context.Context, token qrlogin.Token) error {
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"token": token.URL()}})
return nil
})
if err != nil {
log.Println(err)
return
if tgerr.Is(err, "SESSION_PASSWORD_NEEDED") {
conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"})
return
}
if err != nil {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()})
return
}
user, ok := authorization.User.AsNotEmpty()
if !ok {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"})
return
}
if !checkUserIsAllowed(user.Username) {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"})
tgClient.API().AuthLogOut(c)
return
}
res, _ := sessionStorage.LoadSession(c)
sessionData := &SessionData{}
json.Unmarshal(res, sessionData)
session := prepareSession(user, &sessionData.Data)
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"})
}()
}
if message.AuthType == "phone" && message.Message == "sendcode" {
go func() {
res, err := tgClient.Auth().SendCode(c, message.PhoneNo, tgauth.SendCodeOptions{})
if err != nil {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()})
return
}
code := res.(*tg.AuthSentCode)
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"phoneCodeHash": code.PhoneCodeHash}})
}()
}
if message.AuthType == "phone" && message.Message == "signin" {
go func() {
auth, err := tgClient.Auth().SignIn(c, message.PhoneNo, message.PhoneCode, message.PhoneCodeHash)
if errors.Is(err, tgauth.ErrPasswordAuthNeeded) {
conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"})
return
}
if err != nil {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()})
return
}
user, ok := auth.User.AsNotEmpty()
if !ok {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"})
return
}
if !checkUserIsAllowed(user.Username) {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"})
tgClient.API().AuthLogOut(c)
return
}
res, _ := sessionStorage.LoadSession(c)
sessionData := &SessionData{}
json.Unmarshal(res, sessionData)
session := prepareSession(user, &sessionData.Data)
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"})
}()
}
if message.AuthType == "2fa" && message.Password != "" {
go func() {
auth, err := tgClient.Auth().Password(c, message.Password)
if err != nil {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()})
return
}
user, ok := auth.User.AsNotEmpty()
if !ok {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"})
return
}
if !checkUserIsAllowed(user.Username) {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"})
tgClient.API().AuthLogOut(c)
return
}
res, _ := sessionStorage.LoadSession(c)
sessionData := &SessionData{}
json.Unmarshal(res, sessionData)
session := prepareSession(user, &sessionData.Data)
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"})
}()
}
}
if message.AuthType == "qr" {
go func() {
authorization, err := tgClient.QR().Auth(c, loggedIn, func(ctx context.Context, token qrlogin.Token) error {
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"token": token.URL()}})
return nil
})
})
if tgerr.Is(err, "SESSION_PASSWORD_NEEDED") {
conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"})
return
}
if err != nil {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()})
return
}
user, ok := authorization.User.AsNotEmpty()
if !ok {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"})
return
}
if !checkUserIsAllowed(user.Username) {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"})
tgClient.API().AuthLogOut(c)
return
}
res, _ := sessionStorage.LoadSession(c)
sessionData := &SessionData{}
json.Unmarshal(res, sessionData)
session := prepareSession(user, &sessionData.Data)
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"})
}()
}
if message.AuthType == "phone" && message.Message == "sendcode" {
go func() {
res, err := tgClient.Auth().SendCode(c, message.PhoneNo, tgauth.SendCodeOptions{})
if err != nil {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()})
return
}
code := res.(*tg.AuthSentCode)
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": map[string]string{"phoneCodeHash": code.PhoneCodeHash}})
}()
}
if message.AuthType == "phone" && message.Message == "signin" {
go func() {
auth, err := tgClient.Auth().SignIn(c, message.PhoneNo, message.PhoneCode, message.PhoneCodeHash)
if errors.Is(err, tgauth.ErrPasswordAuthNeeded) {
conn.WriteJSON(map[string]interface{}{"type": "auth", "message": "2FA required"})
return
}
if err != nil {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()})
return
}
user, ok := auth.User.AsNotEmpty()
if !ok {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"})
return
}
if !checkUserIsAllowed(user.Username) {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"})
tgClient.API().AuthLogOut(c)
return
}
res, _ := sessionStorage.LoadSession(c)
sessionData := &SessionData{}
json.Unmarshal(res, sessionData)
session := prepareSession(user, &sessionData.Data)
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"})
}()
}
if message.AuthType == "2fa" && message.Password != "" {
go func() {
auth, err := tgClient.Auth().Password(c, message.Password)
if err != nil {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": err.Error()})
return
}
user, ok := auth.User.AsNotEmpty()
if !ok {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "auth failed"})
return
}
if !checkUserIsAllowed(user.Username) {
conn.WriteJSON(map[string]interface{}{"type": "error", "message": "user not allowed"})
tgClient.API().AuthLogOut(c)
return
}
res, _ := sessionStorage.LoadSession(c)
sessionData := &SessionData{}
json.Unmarshal(res, sessionData)
session := prepareSession(user, &sessionData.Data)
conn.WriteJSON(map[string]interface{}{"type": "auth", "payload": session, "message": "success"})
}()
}
if err != nil {
return
}
}

209
services/common.go Normal file
View file

@ -0,0 +1,209 @@
package services
import (
"bytes"
"context"
"fmt"
"math"
"strconv"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/models"
"github.com/divyam234/teldrive/schemas"
"github.com/divyam234/teldrive/types"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/kv"
"github.com/divyam234/teldrive/utils/tgc"
"github.com/gin-gonic/gin"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/pkg/errors"
"github.com/thoas/go-funk"
)
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: location,
}
r, err := tgClient.API().UploadGetFile(ctx, req)
if err != nil {
return nil, err
}
switch result := r.(type) {
case *tg.UploadFile:
return result.Bytes, nil
default:
return nil, fmt.Errorf("unexpected type %T", r)
}
}
func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
offset := int64(0)
limit := int64(1024 * 1024)
buff := &bytes.Buffer{}
for {
r, err := getChunk(ctx, tgClient, location, offset, limit)
if err != nil {
return buff, err
}
if len(r) == 0 {
break
}
buff.Write(r)
offset += int64(limit)
}
return buff, nil
}
func getUserAuth(c *gin.Context) (int64, string) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
return userId, jwtUser.TgSession
}
func getBotInfo(ctx context.Context, token string) (*BotInfo, error) {
client, _ := tgc.BotLogin(token)
var user *tg.User
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
user, _ = client.Self(ctx)
return nil
})
if err != nil {
return nil, err
}
return &BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
}
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
ids := funk.Map(*file.Parts, func(part models.Part) tg.InputMessageClass {
return tg.InputMessageClass(&tg.InputMessageID{ID: int(part.ID)})
})
channel, err := GetChannelById(ctx, client, *file.ChannelID, userID)
if err != nil {
return nil, err
}
messageRequest := tg.ChannelsGetMessagesRequest{Channel: channel, ID: ids.([]tg.InputMessageClass)}
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
if err != nil {
return nil, err
}
messages := res.(*tg.MessagesChannelMessages)
parts := []types.Part{}
for _, message := range messages.Messages {
item := message.(*tg.Message)
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location := document.AsInputDocumentFileLocation()
parts = append(parts, types.Part{Location: location, Start: 0, End: document.Size - 1, Size: document.Size})
}
return parts, nil
}
func rangedParts(parts []types.Part, start, end int64) []types.Part {
chunkSize := parts[0].Size
startPartNumber := utils.Max(int64(math.Ceil(float64(start)/float64(chunkSize)))-1, 0)
endPartNumber := int64(math.Ceil(float64(end) / float64(chunkSize)))
partsToDownload := parts[startPartNumber:endPartNumber]
partsToDownload[0].Start = start % chunkSize
partsToDownload[len(partsToDownload)-1].End = end % chunkSize
for i, part := range partsToDownload {
partsToDownload[i].Length = part.End - part.Start + 1
}
return partsToDownload
}
func GetChannelById(ctx context.Context, client *telegram.Client, channelID int64, userID string) (*tg.InputChannel, error) {
channel := &tg.InputChannel{}
key := kv.Key("channels", strconv.FormatInt(channelID, 10), userID)
err := kv.GetValue(database.KV, key, channel)
if err != nil {
inputChannel := &tg.InputChannel{
ChannelID: channelID,
}
channels, err := client.API().ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
if err != nil {
return nil, err
}
if len(channels.GetChats()) == 0 {
return nil, errors.New("no channels found")
}
channel = channels.GetChats()[0].(*tg.Channel).AsInput()
kv.SetValue(database.KV, key, channel)
}
return channel, nil
}
func GetDefaultChannel(ctx context.Context, userID int64) (int64, error) {
var channelID int64
key := kv.Key("users", strconv.FormatInt(userID, 10), "channel")
err := kv.GetValue(database.KV, key, &channelID)
if err != nil {
var channelIds []int64
database.DB.Model(&models.Channel{}).Where("user_id = ?", userID).Where("selected = ?", true).
Pluck("channel_id", &channelIds)
if len(channelIds) == 1 {
channelID = channelIds[0]
kv.SetValue(database.KV, key, &channelID)
}
}
if channelID == 0 {
return channelID, errors.New("default channel not set")
}
return channelID, nil
}
func GetBotsToken(userID int64) ([]string, error) {
var bots []string
key := kv.Key("users", strconv.FormatInt(userID, 10), "bots")
err := kv.GetValue(database.KV, key, &bots)
if err != nil {
if err := database.DB.Model(&models.Bot{}).Where("user_id = ?", userID).Pluck("token", &bots).Error; err != nil {
return nil, err
}
kv.SetValue(database.KV, key, &bots)
}
return bots, nil
}

View file

@ -3,26 +3,29 @@ package services
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"strconv"
"strings"
"time"
"github.com/divyam234/teldrive/cache"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/mapper"
"github.com/divyam234/teldrive/models"
"github.com/divyam234/teldrive/schemas"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/kv"
"github.com/divyam234/teldrive/utils/md5"
"github.com/divyam234/teldrive/utils/reader"
"github.com/divyam234/teldrive/utils/tgc"
"github.com/gotd/td/telegram"
"github.com/divyam234/teldrive/types"
"github.com/gin-gonic/gin"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/jackc/pgx/v5/pgconn"
"github.com/mitchellh/mapstructure"
range_parser "github.com/quantumsheep/range-parser"
@ -31,19 +34,11 @@ import (
)
type FileService struct {
Db *gorm.DB
ChannelID int64
}
func getAuthUserId(c *gin.Context) int64 {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
return userId
Db *gorm.DB
}
func (fs *FileService) CreateFile(c *gin.Context) (*schemas.FileOut, *types.AppError) {
userId := getAuthUserId(c)
userId, _ := getUserAuth(c)
var fileIn schemas.FileIn
if err := c.ShouldBindJSON(&fileIn); err != nil {
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
@ -71,14 +66,21 @@ func (fs *FileService) CreateFile(c *gin.Context) (*schemas.FileOut, *types.AppE
fileIn.Depth = utils.IntPointer(len(strings.Split(fileIn.Path, "/")) - 1)
} else if fileIn.Type == "file" {
fileIn.Path = ""
fileIn.ChannelID = &fs.ChannelID
channelId, err := GetDefaultChannel(c, userId)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
fileIn.ChannelID = &channelId
}
fileIn.UserID = userId
fileIn.Starred = utils.BoolPointer(false)
fileIn.Status = "active"
fileDb := mapFileInToFile(fileIn)
fileDb := mapper.MapFileInToFile(fileIn)
if err := fs.Db.Create(&fileDb).Error; err != nil {
pgErr := err.(*pgconn.PgError)
@ -89,7 +91,7 @@ func (fs *FileService) CreateFile(c *gin.Context) (*schemas.FileOut, *types.AppE
}
res := mapFileToFileOut(fileDb)
res := mapper.MapFileToFileOut(fileDb)
return &res, nil
}
@ -111,7 +113,7 @@ func (fs *FileService) UpdateFile(c *gin.Context) (*schemas.FileOut, *types.AppE
return nil, &types.AppError{Error: errors.New("failed to update the file"), Code: http.StatusInternalServerError}
}
} else {
fileDb := mapFileInToFile(fileUpdate)
fileDb := mapper.MapFileInToFile(fileUpdate)
if err := fs.Db.Model(&files).Clauses(clause.Returning{}).Where("id = ?", fileID).Updates(fileDb).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to update the file"), Code: http.StatusInternalServerError}
}
@ -121,7 +123,10 @@ func (fs *FileService) UpdateFile(c *gin.Context) (*schemas.FileOut, *types.AppE
return nil, &types.AppError{Error: errors.New("file not updated"), Code: http.StatusNotFound}
}
file := mapFileToFileOut(files[0])
file := mapper.MapFileToFileOut(files[0])
key := kv.Key("files", fileID)
database.KV.Delete(key)
return &file, nil
@ -139,24 +144,24 @@ func (fs *FileService) GetFileByID(c *gin.Context) (*schemas.FileOutFull, error)
return nil, errors.New("file not found")
}
return mapFileToFileOutFull(file[0]), nil
return mapper.MapFileToFileOutFull(file[0]), nil
}
func (fs *FileService) ListFiles(c *gin.Context) (*schemas.FileResponse, *types.AppError) {
userId := getAuthUserId(c)
userId, _ := getUserAuth(c)
var pagingParams schemas.PaginationQuery
pagingParams.PerPage = 200
if err := c.ShouldBindQuery(&pagingParams); err != nil {
return nil, &types.AppError{Error: errors.New(""), Code: http.StatusBadRequest}
return nil, &types.AppError{Error: errors.New("invalid params"), Code: http.StatusBadRequest}
}
var sortingParams schemas.SortingQuery
sortingParams.Order = "asc"
sortingParams.Sort = "name"
if err := c.ShouldBindQuery(&sortingParams); err != nil {
return nil, &types.AppError{Error: errors.New(""), Code: http.StatusBadRequest}
return nil, &types.AppError{Error: errors.New("invalid params"), Code: http.StatusBadRequest}
}
var fileQuery schemas.FileQuery
@ -164,7 +169,7 @@ func (fs *FileService) ListFiles(c *gin.Context) (*schemas.FileResponse, *types.
fileQuery.Status = "active"
fileQuery.UserID = userId
if err := c.ShouldBindQuery(&fileQuery); err != nil {
return nil, &types.AppError{Error: errors.New(""), Code: http.StatusBadRequest}
return nil, &types.AppError{Error: errors.New("invalid params"), Code: http.StatusBadRequest}
}
query := fs.Db.Model(&models.File{}).Limit(pagingParams.PerPage).
@ -250,12 +255,12 @@ func (fs *FileService) MakeDirectory(c *gin.Context) (*schemas.FileOut, *types.A
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
}
userId := getAuthUserId(c)
userId, _ := getUserAuth(c)
if err := fs.Db.Raw("select * from teldrive.create_directories(?, ?)", userId, payload.Path).Scan(&files).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to create directories"), Code: http.StatusInternalServerError}
}
file := mapFileToFileOut(files[0])
file := mapper.MapFileToFileOut(files[0])
return &file, nil
@ -309,16 +314,42 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
fileID := c.Param("fileID")
var err error
authHash := c.Query("hash")
res, err := cache.CachedFunction(fs.GetFileByID, fmt.Sprintf("files:%s", fileID))(c)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
if authHash == "" {
http.Error(w, "misssing hash", http.StatusBadRequest)
return
}
file := res.(*schemas.FileOutFull)
data, err := database.KV.Get(kv.Key("sessions", authHash))
if err != nil {
http.Error(w, "hash missing relogin", http.StatusBadRequest)
return
}
jwtUser := &types.JWTClaims{}
err = json.Unmarshal(data, jwtUser)
if err != nil {
http.Error(w, "invalid hash", http.StatusBadRequest)
return
}
file := &schemas.FileOutFull{}
key := kv.Key("files", fileID)
err = kv.GetValue(database.KV, key, file)
if err != nil {
file, err = fs.GetFileByID(c)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
kv.SetValue(database.KV, key, file)
}
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
if ifModifiedSinceHeader != "" {
@ -373,113 +404,41 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=\"%s\"", disposition, file.Name))
client, idx := utils.GetDownloadClient(c)
userID, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
defer func() {
utils.GetClientWorkload().Dec(idx)
}()
tokens, err := GetBotsToken(userID)
ir, iw := io.Pipe()
parts, err := fs.getParts(c, client, file)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
parts = rangedParts(parts, int64(start), int64(end))
var client *telegram.Client
var token string
var channelUser string
if len(tokens) == 0 {
client, _ = tgc.UserLogin(jwtUser.TgSession)
channelUser = jwtUser.Subject
} else {
tgc.Workers.Set(tokens)
token = tgc.Workers.Next()
client, _ = tgc.BotLogin(token)
channelUser = strings.Split(token, ":")[0]
}
if r.Method != "HEAD" {
go func() {
defer iw.Close()
for _, part := range parts {
streamFilePart(c, client, iw, &part, part.Start, part.End, 1024*1024)
tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
parts, err := getParts(c, client, file, channelUser)
if err != nil {
return err
}
}()
io.CopyN(w, ir, contentLength)
}
}
func (fs *FileService) getParts(ctx context.Context, tgClient *telegram.Client, file *schemas.FileOutFull) ([]types.Part, error) {
ids := []tg.InputMessageID{}
for _, part := range *file.Parts {
ids = append(ids, tg.InputMessageID{ID: int(part.ID)})
}
s := make([]tg.InputMessageClass, len(ids))
for i := range ids {
s[i] = &ids[i]
}
api := tgClient.API()
res, err := cache.CachedFunction(utils.GetChannelById, fmt.Sprintf("channels:%d", fs.ChannelID))(ctx, api, fs.ChannelID)
if err != nil {
return nil, err
}
channel := res.(*tg.Channel)
messageRequest := tg.ChannelsGetMessagesRequest{Channel: &tg.InputChannel{ChannelID: fs.ChannelID, AccessHash: channel.AccessHash},
ID: s}
res, err = cache.CachedFunction(api.ChannelsGetMessages, fmt.Sprintf("messages:%s", file.ID))(ctx, &messageRequest)
if err != nil {
return nil, err
}
messages := res.(*tg.MessagesChannelMessages)
parts := []types.Part{}
for _, message := range messages.Messages {
item := message.(*tg.Message)
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location := document.AsInputDocumentFileLocation()
parts = append(parts, types.Part{Location: location, Start: 0, End: document.Size - 1, Size: document.Size})
}
return parts, nil
}
func mapFileToFileOut(file models.File) schemas.FileOut {
return schemas.FileOut{
ID: file.ID,
Name: file.Name,
Type: file.Type,
MimeType: file.MimeType,
Path: file.Path,
Size: file.Size,
Starred: file.Starred,
ParentID: file.ParentID,
UpdatedAt: file.UpdatedAt,
}
}
func mapFileInToFile(file schemas.FileIn) models.File {
return models.File{
Name: file.Name,
Type: file.Type,
MimeType: file.MimeType,
Path: file.Path,
Size: file.Size,
Starred: file.Starred,
Depth: file.Depth,
UserID: file.UserID,
ParentID: file.ParentID,
Parts: file.Parts,
ChannelID: file.ChannelID,
Status: file.Status,
}
}
func mapFileToFileOutFull(file models.File) *schemas.FileOutFull {
return &schemas.FileOutFull{
FileOut: mapFileToFileOut(file),
Parts: file.Parts, ChannelID: file.ChannelID,
parts = rangedParts(parts, start, end)
r, _ := reader.NewLinearReader(c, client, parts)
io.CopyN(w, r, contentLength)
return nil
})
}
}
@ -512,81 +471,3 @@ func getOrder(sortingParams schemas.SortingQuery) string {
return fmt.Sprintf("%s %s", sortColumn, strings.ToUpper(sortingParams.Order))
}
func chunk(ctx context.Context, tgClient *telegram.Client, part *types.Part, offset int64, limit int64) ([]byte, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: part.Location,
}
r, err := tgClient.API().UploadGetFile(ctx, req)
if err != nil {
return nil, err
}
switch result := r.(type) {
case *tg.UploadFile:
return result.Bytes, nil
default:
return nil, fmt.Errorf("unexpected type %T", r)
}
}
func streamFilePart(ctx context.Context, tgClient *telegram.Client, writer *io.PipeWriter, part *types.Part, start, end, chunkSize int64) error {
offset := start - (start % chunkSize)
firstPartCut := start - offset
lastPartCut := (end % chunkSize) + 1
partCount := int(math.Ceil(float64(end+1)/float64(chunkSize))) - int(math.Floor(float64(offset)/float64(chunkSize)))
currentPart := 1
for {
r, _ := chunk(ctx, tgClient, part, offset, chunkSize)
if len(r) == 0 {
break
} else if partCount == 1 {
r = r[firstPartCut:lastPartCut]
} else if currentPart == 1 {
r = r[firstPartCut:]
} else if currentPart == partCount {
r = r[:lastPartCut]
}
writer.Write(r)
currentPart++
offset += chunkSize
if currentPart > partCount {
break
}
}
return nil
}
func rangedParts(parts []types.Part, start, end int64) []types.Part {
chunkSize := parts[0].Size
startPartNumber := utils.Max(int64(math.Ceil(float64(start)/float64(chunkSize)))-1, 0)
endPartNumber := int64(math.Ceil(float64(end) / float64(chunkSize)))
partsToDownload := parts[startPartNumber:endPartNumber]
partsToDownload[0].Start = start % chunkSize
partsToDownload[len(partsToDownload)-1].End = end % chunkSize
return partsToDownload
}

View file

@ -3,17 +3,19 @@ package services
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/divyam234/teldrive/cache"
"github.com/divyam234/teldrive/mapper"
"github.com/divyam234/teldrive/schemas"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/tgc"
"github.com/divyam234/teldrive/types"
"github.com/divyam234/teldrive/models"
"github.com/gin-gonic/gin"
"github.com/gotd/td/telegram"
"github.com/gotd/td/telegram/message"
"github.com/gotd/td/telegram/uploader"
"github.com/gotd/td/tg"
@ -21,8 +23,7 @@ import (
)
type UploadService struct {
Db *gorm.DB
ChannelID int64
Db *gorm.DB
}
func (us *UploadService) GetUploadFileById(c *gin.Context) (*schemas.UploadOut, *types.AppError) {
@ -66,12 +67,10 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
us.Db.Model(&models.Upload{}).Where("upload_id = ?", uploadId).Where("part_no = ?", uploadQuery.PartNo).Find(&uploadPart)
if len(uploadPart) == 1 {
out := mapSchema(&uploadPart[0])
out := mapper.MapUploadSchema(&uploadPart[0])
return out, nil
}
client, idx := utils.GetUploadClient(c)
file := c.Request.Body
fileSize := c.Request.ContentLength
@ -85,17 +84,52 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
ctx, cancel := context.WithCancel(ctx)
defer func() {
if idx != -1 {
utils.GetClientWorkload().Dec(idx)
}
cancel()
}()
err := client.Run(ctx, func(ctx context.Context) error {
userId, session := getUserAuth(c)
tokens, err := GetBotsToken(userId)
if err != nil {
return nil, &types.AppError{Error: errors.New("failed to fetch bots"), Code: http.StatusInternalServerError}
}
var client *telegram.Client
var token string
var channelUser string
if len(tokens) == 0 {
client, _ = tgc.UserLogin(session)
channelUser = strconv.FormatInt(userId, 10)
} else {
tgc.Workers.Set(tokens)
token = tgc.Workers.Next()
client, _ = tgc.BotLogin(token)
channelUser = strings.Split(token, ":")[0]
}
var out *schemas.UploadPartOut
err = tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
channelId, err := GetDefaultChannel(ctx, userId)
if err != nil {
return err
}
channel, err := GetChannelById(ctx, client, channelId, channelUser)
if err != nil {
return err
}
api := client.API()
u := uploader.NewUploader(api).WithThreads(10).WithPartSize(512 * 1024)
u := uploader.NewUploader(api).WithThreads(16).WithPartSize(512 * 1024)
upload, err := u.Upload(c, uploader.NewUpload(fileName, file, fileSize))
@ -105,19 +139,12 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
document := message.UploadedDocument(upload).Filename(fileName).ForceFile(true)
res, err := cache.CachedFunction(utils.GetChannelById, fmt.Sprintf("channels:%d", us.ChannelID))(c, client.API(), us.ChannelID)
if err != nil {
return err
}
channel := res.(*tg.Channel)
sender := message.NewSender(client.API())
target := sender.To(&tg.InputPeerChannel{ChannelID: channel.ID, AccessHash: channel.AccessHash})
target := sender.To(&tg.InputPeerChannel{ChannelID: channel.ChannelID,
AccessHash: channel.AccessHash})
res, err = target.Media(c, document)
res, err := target.Media(c, document)
if err != nil {
return err
@ -127,6 +154,26 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
msgId = updates.Updates[0].(*tg.UpdateMessageID).ID
if msgId == 0 {
return errors.New("failed to upload part")
}
partUpload := &models.Upload{
Name: fileName,
UploadId: uploadId,
PartId: msgId,
ChannelID: channelId,
Size: fileSize,
PartNo: uploadQuery.PartNo,
TotalParts: uploadQuery.TotalParts,
}
if err := us.Db.Create(partUpload).Error; err != nil {
return errors.New("failed to upload part")
}
out = mapper.MapUploadSchema(partUpload)
return nil
})
@ -134,38 +181,5 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
if msgId == 0 {
return nil, &types.AppError{Error: errors.New("failed to upload part"), Code: http.StatusInternalServerError}
}
partUpload := &models.Upload{
Name: fileName,
UploadId: uploadId,
PartId: msgId,
ChannelID: us.ChannelID,
Size: fileSize,
PartNo: uploadQuery.PartNo,
TotalParts: uploadQuery.TotalParts,
}
if err := us.Db.Create(partUpload).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to upload part"), Code: http.StatusInternalServerError}
}
out := mapSchema(partUpload)
return out, nil
}
func mapSchema(in *models.Upload) *schemas.UploadPartOut {
out := &schemas.UploadPartOut{
ID: in.ID,
Name: in.Name,
PartId: in.PartId,
ChannelID: in.ChannelID,
PartNo: in.PartNo,
TotalParts: in.TotalParts,
Size: in.Size,
}
return out
}

View file

@ -7,65 +7,48 @@ import (
"fmt"
"net/http"
"strconv"
"sync"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/models"
"github.com/divyam234/teldrive/schemas"
"github.com/divyam234/teldrive/types"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/kv"
"github.com/divyam234/teldrive/utils/tgc"
"github.com/gotd/td/telegram"
"github.com/gotd/td/telegram/message/peer"
"github.com/gotd/td/telegram/query"
"github.com/gotd/td/tg"
"github.com/thoas/go-funk"
"go.etcd.io/bbolt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type UserService struct {
Db *gorm.DB
}
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int) ([]byte, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: location,
}
r, err := tgClient.API().UploadGetFile(ctx, req)
if err != nil {
return nil, err
}
switch result := r.(type) {
case *tg.UploadFile:
return result.Bytes, nil
default:
return nil, fmt.Errorf("unexpected type %T", r)
}
}
func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
offset := int64(0)
limit := 1024 * 1024
buff := &bytes.Buffer{}
for {
r, err := getChunk(ctx, tgClient, location, offset, limit)
if err != nil {
return buff, err
}
if len(r) == 0 {
break
}
buff.Write(r)
offset += int64(limit)
}
return buff, nil
type BotInfo struct {
Id int64
UserName string
AccessHash int64
Token string
}
func (us *UserService) GetProfilePhoto(c *gin.Context) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
client, _ := utils.GetAuthClient(c, jwtUser.TgSession, userId)
_, session := getUserAuth(c)
err := client.Run(c, func(ctx context.Context) error {
client, err := tgc.UserLogin(session)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
err = tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
self, err := client.Self(c)
if err != nil {
return err
@ -96,3 +79,283 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) {
return
}
}
func (us *UserService) Stats(c *gin.Context) (*schemas.AccountStats, *types.AppError) {
userId, _ := getUserAuth(c)
var res []schemas.AccountStats
if err := us.Db.Raw("select * from teldrive.account_stats(?);", userId).Scan(&res).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to get stats"), Code: http.StatusInternalServerError}
}
return &res[0], nil
}
func (us *UserService) GetBots(c *gin.Context) ([]string, *types.AppError) {
userID, _ := getUserAuth(c)
tokens, err := GetBotsToken(userID)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
return tokens, nil
}
func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.AppError) {
userId, _ := getUserAuth(c)
var payload schemas.Channel
if err := c.ShouldBindJSON(&payload); err != nil {
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
}
channel := &models.Channel{ChannelID: payload.ChannelID, ChannelName: payload.ChannelName, UserID: userId,
Selected: true}
if err := us.Db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "channel_id"}},
DoUpdates: clause.Assignments(map[string]interface{}{"selected": true}),
}).Create(channel).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to update channel"),
Code: http.StatusInternalServerError}
}
us.Db.Model(&models.Channel{}).Where("channel_id != ?", payload.ChannelID).
Where("user_id = ?", userId).Update("selected", false)
key := kv.Key("users", strconv.FormatInt(userId, 10), "channel")
database.KV.Delete(key)
kv.SetValue(database.KV, key, payload.ChannelID)
//add bots as admin if channel is changed in background
go func() {
userId, session := getUserAuth(c)
client, _ := tgc.UserLogin(session)
var botsTokens []string
us.Db.Model(&models.Bot{}).Where("user_id = ?", userId).Pluck("token", &botsTokens)
if len(botsTokens) > 0 {
us.addBots(c, client, userId, payload.ChannelID, botsTokens)
}
}()
return &schemas.Message{Status: true, Message: "channel updated"}, nil
}
func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppError) {
_, session := getUserAuth(c)
client, _ := tgc.UserLogin(session)
channels := make(map[int64]*schemas.Channel)
client.Run(c, func(ctx context.Context) error {
dialogs, _ := query.GetDialogs(client.API()).BatchSize(100).Collect(ctx)
for _, dialog := range dialogs {
if !dialog.Deleted() {
for _, channel := range dialog.Entities.Channels() {
_, exists := channels[channel.ID]
if !exists && channel.Creator {
channels[channel.ID] = &schemas.Channel{ChannelID: channel.ID, ChannelName: channel.Title}
}
}
}
}
return nil
})
return funk.Values(channels), nil
}
func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppError) {
userId, session := getUserAuth(c)
client, _ := tgc.UserLogin(session)
var botsTokens []string
if err := c.ShouldBindJSON(&botsTokens); err != nil {
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
}
if len(botsTokens) == 0 {
return &schemas.Message{Status: true, Message: "no bots to add"}, nil
}
channelId, err := GetDefaultChannel(c, userId)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
return us.addBots(c, client, userId, channelId, botsTokens)
}
func (us *UserService) RevokeBotSession(c *gin.Context) (*schemas.Message, *types.AppError) {
pattern := []byte("botsession:")
err := database.BoltDB.Update(func(tx *bbolt.Tx) error {
bucket := tx.Bucket([]byte("teldrive"))
if bucket == nil {
return errors.New("bucket not found")
}
c := bucket.Cursor()
for key, _ := c.First(); key != nil; key, _ = c.Next() {
if bytes.HasPrefix(key, pattern) {
if err := c.Delete(); err != nil {
return err
}
}
}
return nil
})
if err != nil {
return nil, &types.AppError{Error: errors.New("failed to revoke session"),
Code: http.StatusInternalServerError}
}
return &schemas.Message{Status: true, Message: "session revoked"}, nil
}
func (us *UserService) ClearCache(c *gin.Context) (*schemas.Message, *types.AppError) {
pattern := []byte("users")
err := database.BoltDB.Update(func(tx *bbolt.Tx) error {
bucket := tx.Bucket([]byte("teldrive"))
if bucket == nil {
return errors.New("bucket not found")
}
c := bucket.Cursor()
for key, _ := c.First(); key != nil; key, _ = c.Next() {
if bytes.HasPrefix(key, pattern) {
if err := c.Delete(); err != nil {
return err
}
}
}
return nil
})
if err != nil {
return nil, &types.AppError{Error: errors.New("failed to clear cache"),
Code: http.StatusInternalServerError}
}
return &schemas.Message{Status: true, Message: "cache cleared"}, nil
}
func (us *UserService) addBots(c context.Context, client *telegram.Client, userId int64, channelId int64, botsTokens []string) (*schemas.Message, *types.AppError) {
botInfo := []BotInfo{}
var wg sync.WaitGroup
err := tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10))
if err != nil {
return err
}
if err != nil {
return err
}
botInfoChannel := make(chan *BotInfo, len(botsTokens))
waitChan := make(chan struct{}, 6)
for _, token := range botsTokens {
waitChan <- struct{}{}
wg.Add(1)
go func(t string) {
info, err := getBotInfo(c, t)
if err != nil {
return
}
botPeerClass, err := peer.DefaultResolver(client.API()).ResolveDomain(ctx, info.UserName)
if err != nil {
return
}
botPeer := botPeerClass.(*tg.InputPeerUser)
info.AccessHash = botPeer.AccessHash
defer func() {
<-waitChan
wg.Done()
}()
if err == nil {
botInfoChannel <- info
}
}(token)
}
wg.Wait()
close(botInfoChannel)
for result := range botInfoChannel {
botInfo = append(botInfo, *result)
}
if len(botsTokens) == len(botInfo) {
users := funk.Map(botInfo, func(info BotInfo) tg.InputUser {
return tg.InputUser{UserID: info.Id, AccessHash: info.AccessHash}
})
botsToAdd := users.([]tg.InputUser)
for _, user := range botsToAdd {
payload := &tg.ChannelsEditAdminRequest{
Channel: channel,
UserID: tg.InputUserClass(&user),
AdminRights: tg.ChatAdminRights{
ChangeInfo: true,
PostMessages: true,
EditMessages: true,
DeleteMessages: true,
BanUsers: true,
InviteUsers: true,
PinMessages: true,
ManageCall: true,
Other: true,
ManageTopics: true,
},
Rank: "bot",
}
client.API().ChannelsEditAdmin(ctx, payload)
}
} else {
return errors.New("failed to fetch bots")
}
return nil
})
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
payload := []models.Bot{}
for _, info := range botInfo {
payload = append(payload, models.Bot{UserID: userId, Token: info.Token, BotID: info.Id,
BotUserName: info.UserName,
})
}
key := kv.Key("users", strconv.FormatInt(userId, 10), "bots")
database.KV.Delete(key)
if err := us.Db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to add bots"), Code: http.StatusInternalServerError}
}
return &schemas.Message{Status: true, Message: "bots added"}, nil
}

View file

@ -15,6 +15,7 @@ type Part struct {
Size int64
Start int64
End int64
Length int64
}
type JWTClaims struct {
@ -24,6 +25,7 @@ type JWTClaims struct {
UserName string `json:"userName"`
Bot bool `json:"bot"`
IsPremium bool `json:"isPremium"`
Hash string `json:"hash"`
}
type TgSession struct {
@ -39,5 +41,6 @@ type Session struct {
Name string `json:"name"`
UserName string `json:"userName"`
IsPremium bool `json:"isPremium"`
Hash string `json:"hash"`
Expires string `json:"expires"`
}

@ -1 +1 @@
Subproject commit 848127bfd0f1ac563080a34d9eef414f04d14ec0
Subproject commit 4c77129367f4830d868a1e9211ac588d9802972f

View file

@ -8,6 +8,7 @@ import (
"path"
"strings"
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/contrib/static"
"github.com/gin-gonic/gin"
)
@ -24,6 +25,7 @@ func AddRoutes(router gin.IRouter) {
isImg, _ := path.Match("/img/*", c.Request.URL.Path)
if isStatic || isImg {
c.Writer.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
gzip.Gzip(gzip.DefaultCompression)(c)
} else {
c.Writer.Header().Set("Cache-Control", "public, max-age=0, s-maxage=0, must-revalidate")
}

View file

@ -12,9 +12,7 @@ type MultiToken string
type Config struct {
AppId int `envconfig:"APP_ID" required:"true"`
AppHash string `envconfig:"APP_HASH" required:"true"`
ChannelID int64 `envconfig:"CHANNEL_ID" required:"true"`
JwtSecret string `envconfig:"JWT_SECRET" required:"true"`
MultiClient bool `envconfig:"MULTI_CLIENT" default:"false"`
Https bool `envconfig:"HTTPS" default:"false"`
CookieSameSite bool `envconfig:"COOKIE_SAME_SITE" default:"true"`
AllowedUsers []string `envconfig:"ALLOWED_USERS"`
@ -22,7 +20,7 @@ type Config struct {
RateLimit bool `envconfig:"RATE_LIMIT" default:"true"`
RateBurst int `envconfig:"RATE_BURST" default:"5"`
Rate int `envconfig:"RATE" default:"100"`
TgClientDeviceModel string `envconfig:"TG_CLIENT_DEVICE_MODEL" required:"true"`
TgClientDeviceModel string `envconfig:"TG_CLIENT_DEVICE_MODEL" default:"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0"`
TgClientSystemVersion string `envconfig:"TG_CLIENT_SYSTEM_VERSION" default:"Win32"`
TgClientAppVersion string `envconfig:"TG_CLIENT_APP_VERSION" default:"2.1.9 K"`
TgClientLangCode string `envconfig:"TG_CLIENT_LANG_CODE" default:"en"`
@ -39,8 +37,6 @@ func InitConfig() {
execDir := getExecutableDir()
godotenv.Load(filepath.Join(execDir, ".env"))
godotenv.Load(filepath.Join(execDir, "teldrive.env"))
err := envconfig.Process("", &config)
if err != nil {

View file

@ -2,71 +2,104 @@ package cron
import (
"context"
"fmt"
"database/sql/driver"
"encoding/json"
"strconv"
"github.com/divyam234/teldrive/cache"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/models"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/services"
"github.com/divyam234/teldrive/utils/tgc"
"github.com/gotd/td/tg"
)
type Result struct {
ID string
Parts models.Parts
TgSession string
UserId int64
ChannelId int64
type Files []File
type File struct {
ID string `json:"id"`
Parts []models.Part `json:"parts"`
}
func deleteTGMessage(ctx context.Context, client *tg.Client, result Result) error {
func (a Files) Value() (driver.Value, error) {
return json.Marshal(a)
}
ids := make([]int, len(result.Parts))
for i, part := range result.Parts {
ids[i] = int(part.ID)
}
res, err := cache.CachedFunction(utils.GetChannelById, fmt.Sprintf("channels:%d", result.ChannelId))(ctx, client, result.ChannelId)
if err != nil {
return err
}
channel := res.(*tg.Channel)
messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: &tg.InputChannel{ChannelID: result.ChannelId, AccessHash: channel.AccessHash},
ID: ids}
_, err = client.ChannelsDeleteMessages(ctx, &messageDeleteRequest)
if err != nil {
func (a *Files) Scan(value interface{}) error {
if err := json.Unmarshal(value.([]byte), &a); err != nil {
return err
}
return nil
}
type Result struct {
Files Files
TgSession string
UserId int64
ChannelId int64
}
func deleteTGMessages(ctx context.Context, result Result) error {
db := database.DB
client, err := tgc.UserLogin(result.TgSession)
if err != nil {
return err
}
ids := []int{}
fileIds := []string{}
for _, file := range result.Files {
fileIds = append(fileIds, file.ID)
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
}
}
err = tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error {
channel, err := services.GetChannelById(ctx, client, result.ChannelId, strconv.FormatInt(result.UserId, 10))
if err != nil {
return err
}
messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: ids}
_, err = client.API().ChannelsDeleteMessages(ctx, &messageDeleteRequest)
if err != nil {
return err
}
return nil
})
if err == nil {
db.Where("id = any($1)", fileIds).Delete(&models.File{})
}
return nil
}
func FilesDeleteJob() {
db := database.DB
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var results []Result
if err := db.Model(&models.File{}).Select("files.id", "files.parts", "files.user_id", "files.channel_id", "u.tg_session").
if err := db.Model(&models.File{}).
Select("JSONB_AGG(jsonb_build_object('id',files.id, 'parts',files.parts)) as files", "files.channel_id", "files.user_id", "u.tg_session").
Joins("left join teldrive.users as u on u.user_id = files.user_id").
Where("status = ?", "pending_deletion").Scan(&results).Error; err != nil {
Where("type = ?", "file").
Where("status = ?", "pending_deletion").
Group("files.channel_id").Group("files.user_id").Group("u.tg_session").
Scan(&results).Error; err != nil {
return
}
for _, file := range results {
client, err := utils.GetAuthClient(ctx, file.TgSession, file.UserId)
if err != nil {
break
}
err = client.Run(ctx, func(ctx context.Context) error {
err = deleteTGMessage(ctx, client.API(), file)
return err
})
if err == nil {
db.Where("id = ?", file.ID).Delete(&models.File{})
}
for _, row := range results {
deleteTGMessages(ctx, row)
}
}

38
utils/kv/bolt.go Normal file
View file

@ -0,0 +1,38 @@
package kv
import (
"go.etcd.io/bbolt"
)
type Bolt struct {
bucket []byte
db *bbolt.DB
}
func (b *Bolt) Get(key string) ([]byte, error) {
var val []byte
if err := b.db.View(func(tx *bbolt.Tx) error {
val = tx.Bucket(b.bucket).Get([]byte(key))
return nil
}); err != nil {
return nil, err
}
if val == nil {
return nil, ErrNotFound
}
return val, nil
}
func (b *Bolt) Set(key string, val []byte) error {
return b.db.Update(func(tx *bbolt.Tx) error {
return tx.Bucket(b.bucket).Put([]byte(key), val)
})
}
func (b *Bolt) Delete(key string) error {
return b.db.Update(func(tx *bbolt.Tx) error {
return tx.Bucket(b.bucket).Delete([]byte(key))
})
}

27
utils/kv/cache.go Normal file
View file

@ -0,0 +1,27 @@
package kv
import (
"encoding/json"
)
func GetValue(kv KV, key string, target interface{}) error {
data, err := kv.Get(key)
if err != nil {
return err
}
if err := json.Unmarshal(data, target); err != nil {
return err
}
return nil
}
func SetValue(kv KV, key string, value interface{}) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
return kv.Set(key, data)
}

9
utils/kv/key.go Normal file
View file

@ -0,0 +1,9 @@
package kv
import (
"strings"
)
func Key(indexes ...string) string {
return strings.Join(indexes, ":")
}

32
utils/kv/kv.go Normal file
View file

@ -0,0 +1,32 @@
package kv
import (
"errors"
"go.etcd.io/bbolt"
)
var ErrNotFound = errors.New("key not found")
type KV interface {
Get(key string) ([]byte, error)
Set(key string, value []byte) error
Delete(key string) error
}
type Options struct {
Bucket string
DB *bbolt.DB
}
func New(opts Options) (KV, error) {
if err := opts.DB.Update(func(tx *bbolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(opts.Bucket))
return err
}); err != nil {
return nil, err
}
return &Bolt{db: opts.DB, bucket: []byte(opts.Bucket)}, nil
}

33
utils/kv/session.go Normal file
View file

@ -0,0 +1,33 @@
package kv
import (
"context"
"errors"
"github.com/gotd/td/telegram"
)
type Session struct {
kv KV
key string
}
func NewSession(kv KV, key string) telegram.SessionStorage {
return &Session{kv: kv, key: key}
}
func (s *Session) LoadSession(_ context.Context) ([]byte, error) {
b, err := s.kv.Get(s.key)
if err != nil {
if errors.Is(err, ErrNotFound) {
return nil, nil
}
return nil, err
}
return b, nil
}
func (s *Session) StoreSession(_ context.Context, data []byte) error {
return s.kv.Set(s.key, data)
}

View file

@ -1,8 +1,6 @@
package utils
import (
"context"
"fmt"
"os"
"path/filepath"
"regexp"
@ -11,7 +9,6 @@ import (
"reflect"
"github.com/gotd/td/tg"
"golang.org/x/exp/constraints"
"unicode"
@ -65,25 +62,6 @@ func GetField(v interface{}, field string) string {
}
}
func GetChannelById(ctx context.Context, client *tg.Client, channelID int64) (*tg.Channel, error) {
inputChannel := &tg.InputChannel{
ChannelID: channelID,
AccessHash: 0,
}
channels, err := client.ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
if err != nil {
return nil, fmt.Errorf("failed to fetch channel: %w", err)
}
if len(channels.GetChats()) == 0 {
return nil, fmt.Errorf("no channels found")
}
channel := channels.GetChats()[0].(*tg.Channel)
return channel, nil
}
func BoolPointer(b bool) *bool {
return &b
}

172
utils/reader/lr.go Normal file
View file

@ -0,0 +1,172 @@
package reader
import (
"context"
"fmt"
"io"
"sync"
"github.com/divyam234/teldrive/types"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
)
type linearReader struct {
ctx context.Context
parts []types.Part
pos int
reader io.ReadCloser
client *telegram.Client
bytesread int64
chunkSize int64
sync.Mutex
}
func NewLinearReader(ctx context.Context, client *telegram.Client, parts []types.Part) (io.ReadCloser, error) {
r := &linearReader{
ctx: ctx,
parts: parts,
client: client,
chunkSize: int64(1024 * 1024),
}
res, err := r.nextPart()
if err != nil {
return nil, err
}
r.reader = res
return r, nil
}
func (r *linearReader) Read(p []byte) (n int, err error) {
r.Lock()
defer r.Unlock()
n, err = r.reader.Read(p)
if err != nil {
return 0, err
}
r.bytesread += int64(n)
if r.bytesread == r.parts[r.pos].Length && r.pos < len(r.parts)-1 {
r.pos++
r.reader, err = r.nextPart()
if err != nil {
return 0, err
}
r.bytesread = 0
}
return n, nil
}
func (r *linearReader) Close() (err error) {
if r.reader != nil {
err = r.reader.Close()
r.reader = nil
}
return
}
func (r *linearReader) chunk(offset int64, limit int64) ([]byte, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: r.parts[r.pos].Location,
}
res, err := r.client.API().UploadGetFile(r.ctx, req)
if err != nil {
return nil, err
}
switch result := res.(type) {
case *tg.UploadFile:
return result.Bytes, nil
default:
return nil, fmt.Errorf("unexpected type %T", r)
}
}
func (r *linearReader) nextPart() (io.ReadCloser, error) {
stream := r.tgRangeStream()
ir, iw := io.Pipe()
go func() {
defer iw.Close()
for {
data, more := <-stream
if !more {
return
}
_, err := iw.Write(data)
if err != nil {
return
}
}
}()
return ir, nil
}
func (r *linearReader) tgRangeStream() chan []byte {
start := r.parts[r.pos].Start
end := r.parts[r.pos].End
offset := start - (start % r.chunkSize)
firstPartCut := start - offset
lastPartCut := (end % r.chunkSize) + 1
partCount := int((end - offset + r.chunkSize) / r.chunkSize)
currentPart := 1
stream := make(chan []byte)
go func() {
defer close(stream)
for {
res, _ := r.chunk(offset, r.chunkSize)
if len(res) == 0 {
return
} else if partCount == 1 {
res = res[firstPartCut:lastPartCut]
} else if currentPart == 1 {
res = res[firstPartCut:]
} else if currentPart == partCount {
res = res[:lastPartCut]
}
stream <- res
currentPart++
offset += r.chunkSize
if currentPart > partCount {
return
}
}
}()
return stream
}

42
utils/tgc/run.go Normal file
View file

@ -0,0 +1,42 @@
package tgc
import (
"context"
"github.com/divyam234/teldrive/utils"
"github.com/gotd/td/telegram"
"github.com/pkg/errors"
"go.uber.org/zap"
)
func RunWithAuth(ctx context.Context, client *telegram.Client, token string, f func(ctx context.Context) error) error {
return client.Run(ctx, func(ctx context.Context) error {
status, err := client.Auth().Status(ctx)
if err != nil {
return err
}
if token == "" {
if !status.Authorized {
return errors.Errorf("not authorized. please login first")
}
utils.Logger.Info("User Session",
zap.Int64("id", status.User.ID),
zap.String("username", status.User.Username))
} else {
if !status.Authorized {
utils.Logger.Info("creating bot session")
_, err := client.Auth().Bot(ctx, token)
if err != nil {
return err
}
status, _ = client.Auth().Status(ctx)
utils.Logger.Info("Bot Session",
zap.Int64("id", status.User.ID),
zap.String("username", status.User.Username))
}
}
return f(ctx)
})
}

98
utils/tgc/tgc.go Normal file
View file

@ -0,0 +1,98 @@
package tgc
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/kv"
"github.com/gotd/contrib/middleware/floodwait"
"github.com/gotd/contrib/middleware/ratelimit"
tdclock "github.com/gotd/td/clock"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram"
"golang.org/x/time/rate"
)
func deviceConfig(appConfig *utils.Config) telegram.DeviceConfig {
config := telegram.DeviceConfig{
DeviceModel: appConfig.TgClientDeviceModel,
SystemVersion: appConfig.TgClientSystemVersion,
AppVersion: appConfig.TgClientAppVersion,
SystemLangCode: appConfig.TgClientSystemLangCode,
LangPack: appConfig.TgClientLangPack,
LangCode: appConfig.TgClientLangCode,
}
return config
}
func New(handler telegram.UpdateHandler, storage session.Storage, middlewares ...telegram.Middleware) *telegram.Client {
_clock := tdclock.System
config := utils.GetConfig()
noUpdates := true
if handler != nil {
noUpdates = false
}
opts := telegram.Options{
ReconnectionBackoff: func() backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.Multiplier = 1.1
b.MaxElapsedTime = time.Duration(120) * time.Second
b.Clock = _clock
return b
},
Device: deviceConfig(config),
SessionStorage: storage,
RetryInterval: 5 * time.Second,
MaxRetries: 5,
DialTimeout: 10 * time.Second,
Middlewares: middlewares,
Clock: _clock,
NoUpdates: noUpdates,
UpdateHandler: handler,
}
return telegram.NewClient(config.AppId, config.AppHash, opts)
}
func NoLogin(handler telegram.UpdateHandler, storage session.Storage) *telegram.Client {
middlewares := []telegram.Middleware{floodwait.NewSimpleWaiter()}
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*100), 5))
return New(handler, storage, middlewares...)
}
func UserLogin(sessionStr string) (*telegram.Client, error) {
data, err := session.TelethonSession(sessionStr)
if err != nil {
return nil, err
}
var (
storage = new(session.StorageMemory)
loader = session.Loader{Storage: storage}
)
if err := loader.Save(context.TODO(), data); err != nil {
return nil, err
}
middlewares := []telegram.Middleware{floodwait.NewSimpleWaiter()}
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*100), 5))
return New(nil, storage, middlewares...), nil
}
func BotLogin(token string) (*telegram.Client, error) {
config := utils.GetConfig()
storage := kv.NewSession(database.KV, kv.Key("botsession", token))
middlewares := []telegram.Middleware{floodwait.NewSimpleWaiter()}
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*time.Duration(config.Rate)), config.RateBurst))
return New(nil, storage, middlewares...), nil
}

29
utils/tgc/workers.go Normal file
View file

@ -0,0 +1,29 @@
package tgc
import (
"sync"
)
type BotWorkers struct {
sync.Mutex
bots []string
index int
}
func (w *BotWorkers) Set(bots []string) {
w.Lock()
defer w.Unlock()
if len(w.bots) == 0 {
w.bots = bots
}
}
func (w *BotWorkers) Next() string {
w.Lock()
defer w.Unlock()
item := w.bots[w.index]
w.index = (w.index + 1) % len(w.bots)
return item
}
var Workers = &BotWorkers{}

View file

@ -1,297 +0,0 @@
package utils
import (
"context"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/divyam234/teldrive/types"
"github.com/gin-gonic/gin"
"github.com/gotd/contrib/bg"
"github.com/gotd/contrib/middleware/floodwait"
"github.com/gotd/contrib/middleware/ratelimit"
tdclock "github.com/gotd/td/clock"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram"
"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
var clients map[int64]*telegram.Client
type Workload struct {
mu sync.Mutex
workloads map[int]int
}
func (w *Workload) Set(key int, value int) {
w.mu.Lock()
defer w.mu.Unlock()
w.workloads[key] = value
}
func (w *Workload) Get(key int) int {
w.mu.Lock()
defer w.mu.Unlock()
return w.workloads[key]
}
func (w *Workload) Inc(key int) {
w.mu.Lock()
defer w.mu.Unlock()
w.workloads[key]++
}
func (w *Workload) Dec(key int) {
w.mu.Lock()
defer w.mu.Unlock()
w.workloads[key]--
}
func (w *Workload) GetMinIndex() int {
w.mu.Lock()
defer w.mu.Unlock()
smallest := w.workloads[0]
idx := 0
for i, workload := range clientWorkload.workloads {
if workload < smallest {
smallest = workload
idx = i
}
}
return idx
}
var clientWorkload *Workload
func GetClientWorkload() *Workload {
return clientWorkload
}
func getDeviceConfig() telegram.DeviceConfig {
appConfig := GetConfig()
config := telegram.DeviceConfig{
DeviceModel: appConfig.TgClientDeviceModel,
SystemVersion: appConfig.TgClientSystemVersion,
AppVersion: appConfig.TgClientAppVersion,
SystemLangCode: appConfig.TgClientSystemLangCode,
LangPack: appConfig.TgClientLangPack,
LangCode: appConfig.TgClientLangCode,
}
return config
}
func reconnectionBackoff() backoff.BackOff {
_clock := tdclock.System
b := backoff.NewExponentialBackOff()
b.Multiplier = 1.1
b.MaxElapsedTime = time.Duration(120) * time.Second
b.Clock = _clock
return b
}
func GetBotClient(clientName string) *telegram.Client {
config := GetConfig()
sessionStorage := &telegram.FileSessionStorage{
Path: filepath.Join(config.ExecDir, "sessions", clientName+".json"),
}
middlewares := []telegram.Middleware{floodwait.NewSimpleWaiter()}
if config.RateLimit {
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*time.Duration(config.Rate)), config.RateBurst))
}
options := telegram.Options{
SessionStorage: sessionStorage,
Middlewares: middlewares,
ReconnectionBackoff: reconnectionBackoff,
RetryInterval: 5 * time.Second,
MaxRetries: 5,
Device: getDeviceConfig(),
Clock: tdclock.System,
}
client := telegram.NewClient(config.AppId, config.AppHash, options)
return client
}
func GetAuthClient(ctx context.Context, sessionStr string, userId int64) (*telegram.Client, error) {
data, err := session.TelethonSession(sessionStr)
if err != nil {
return nil, err
}
var (
storage = new(session.StorageMemory)
loader = session.Loader{Storage: storage}
)
if err := loader.Save(ctx, data); err != nil {
return nil, err
}
middlewares := []telegram.Middleware{floodwait.NewSimpleWaiter()}
if config.RateLimit {
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*100), 5))
}
client := telegram.NewClient(config.AppId, config.AppHash, telegram.Options{
SessionStorage: storage,
Middlewares: middlewares,
ReconnectionBackoff: reconnectionBackoff,
RetryInterval: 5 * time.Second,
MaxRetries: 5,
Device: getDeviceConfig(),
Clock: tdclock.System,
})
return client, nil
}
func StartNonAuthClient(handler telegram.UpdateHandler, storage telegram.SessionStorage) (*telegram.Client, bg.StopFunc, error) {
middlewares := []telegram.Middleware{}
if config.RateLimit {
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*100), 5))
}
client := telegram.NewClient(config.AppId, config.AppHash, telegram.Options{
SessionStorage: storage,
Middlewares: middlewares,
Device: getDeviceConfig(),
UpdateHandler: handler,
})
stop, err := bg.Connect(client)
if err != nil {
return nil, nil, err
}
return client, stop, nil
}
func startBotClient(ctx context.Context, client *telegram.Client, token string) (bg.StopFunc, error) {
stop, err := bg.Connect(client)
if err != nil {
return nil, errors.Wrap(err, "failed to start client")
}
tguser, err := client.Self(ctx)
if err != nil {
if _, err := client.Auth().Bot(ctx, token); err != nil {
return nil, err
}
tguser, _ = client.Self(ctx)
}
Logger.Info("started Client", zap.String("user", tguser.Username))
return stop, nil
}
func startAuthClient(c *gin.Context, client *telegram.Client) (bg.StopFunc, error) {
stop, err := bg.Connect(client)
if err != nil {
return nil, err
}
tguser, err := client.Self(c)
if err != nil {
return nil, err
}
Logger.Info("started Client", zap.String("user", tguser.Username))
clients[tguser.GetID()] = client
return stop, nil
}
func InitBotClients() {
ctx := context.Background()
clients = make(map[int64]*telegram.Client)
clientWorkload = &Workload{workloads: make(map[int]int)}
if config.MultiClient {
if err := os.MkdirAll(filepath.Join(config.ExecDir, "sessions"), 0700); err != nil {
return
}
var keysToSort []string
for _, e := range os.Environ() {
if strings.HasPrefix(e, "MULTI_TOKEN") {
if i := strings.Index(e, "="); i >= 0 {
keysToSort = append(keysToSort, e[:i])
}
}
}
sort.Strings(keysToSort)
for idx, key := range keysToSort {
client := GetBotClient(fmt.Sprintf("client%d", idx))
clientWorkload.Set(idx, 0)
clients[int64(idx)] = client
go func(k string) {
startBotClient(ctx, client, os.Getenv(k))
}(key)
}
}
}
func GetUploadClient(c *gin.Context) (*telegram.Client, int) {
if config.MultiClient {
idx := clientWorkload.GetMinIndex()
clientWorkload.Inc(idx)
return GetBotClient(fmt.Sprintf("client%d", idx)), idx
} else {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
client, _ := GetAuthClient(c, jwtUser.TgSession, userId)
return client, -1
}
}
func GetDownloadClient(c *gin.Context) (*telegram.Client, int) {
if config.MultiClient {
idx := clientWorkload.GetMinIndex()
clientWorkload.Inc(idx)
return clients[int64(idx)], idx
} else {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
if client, ok := clients[userId]; ok {
return client, -1
}
client, _ := GetAuthClient(c, jwtUser.TgSession, userId)
startAuthClient(c, client)
return client, -1
}
}