diff --git a/docs/Configuration.md b/docs/Configuration.md index 65bfab57..7525e609 100644 --- a/docs/Configuration.md +++ b/docs/Configuration.md @@ -27,21 +27,24 @@ Most configuration can be set directly using environment variables or flags. The ### HTTP configuration variables -| Environment variable | Default | Required | Description | -| ------------------------------------------ | ------- | -------- | ----------------------------------------------------- | -| `SHIORI_HTTP_ENABLED` | True | No | Enable HTTP service | -| `SHIORI_HTTP_PORT` | 8080 | No | Port number for the HTTP service | -| `SHIORI_HTTP_ADDRESS` | : | No | Address for the HTTP service | -| `SHIORI_HTTP_ROOT_PATH` | / | No | Root path for the HTTP service | -| `SHIORI_HTTP_ACCESS_LOG` | True | No | Logging accessibility for HTTP requests | -| `SHIORI_HTTP_SERVE_WEB_UI` | True | No | Serving Web UI via HTTP. Disable serves only the API. | -| `SHIORI_HTTP_SECRET_KEY` | | **Yes** | Secret key for HTTP sessions. | -| `SHIORI_HTTP_BODY_LIMIT` | 1024 | No | Limit for request body size | -| `SHIORI_HTTP_READ_TIMEOUT` | 10s | No | Maximum duration for reading the entire request | -| `SHIORI_HTTP_WRITE_TIMEOUT` | 10s | No | Maximum duration before timing out writes | -| `SHIORI_HTTP_IDLE_TIMEOUT` | 10s | No | Maximum amount of time to wait for the next request | -| `SHIORI_HTTP_DISABLE_KEEP_ALIVE` | true | No | Disable HTTP keep-alive connections | -| `SHIORI_HTTP_DISABLE_PARSE_MULTIPART_FORM` | true | No | Disable pre-parsing of multipart form | +| Environment variable | Default | Required | Description | +| ------------------------------------------ | ------- | -------- | ----------------------------------------------------- | +| `SHIORI_HTTP_ENABLED` | True | No | Enable HTTP service | +| `SHIORI_HTTP_PORT` | 8080 | No | Port number for the HTTP service | +| `SHIORI_HTTP_ADDRESS` | : | No | Address for the HTTP service | +| `SHIORI_HTTP_ROOT_PATH` | / | No | Root path for the HTTP service | +| `SHIORI_HTTP_ACCESS_LOG` | True | No | Logging accessibility for HTTP requests | +| `SHIORI_HTTP_SERVE_WEB_UI` | True | No | Serving Web UI via HTTP. Disable serves only the API. | +| `SHIORI_HTTP_SECRET_KEY` | | **Yes** | Secret key for HTTP sessions. | +| `SHIORI_HTTP_BODY_LIMIT` | 1024 | No | Limit for request body size | +| `SHIORI_HTTP_READ_TIMEOUT` | 10s | No | Maximum duration for reading the entire request | +| `SHIORI_HTTP_WRITE_TIMEOUT` | 10s | No | Maximum duration before timing out writes | +| `SHIORI_HTTP_IDLE_TIMEOUT` | 10s | No | Maximum amount of time to wait for the next request | +| `SHIORI_HTTP_DISABLE_KEEP_ALIVE` | true | No | Disable HTTP keep-alive connections | +| `SHIORI_HTTP_DISABLE_PARSE_MULTIPART_FORM` | true | No | Disable pre-parsing of multipart form | +| `SHIORI_SSO_PROXY_AUTH_ENABLED` | false | No | Enable SSO Auth Proxy Header | +| `SHIORI_SSO_PROXY_AUTH_HEADER_NAME` | Remote-User | No | List of CIDRs of trusted proxies | +| `SHIORI_SSO_PROXY_AUTH_TRUSTED` | 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7 | No | List of CIDRs of trusted proxies | ### Storage Configuration diff --git a/internal/config/config.go b/internal/config/config.go index bc260ce4..90a05999 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -65,6 +65,10 @@ type HttpConfig struct { IDLETimeout time.Duration `env:"HTTP_IDLE_TIMEOUT,default=10s"` DisableKeepAlive bool `env:"HTTP_DISABLE_KEEP_ALIVE,default=true"` DisablePreParseMultipartForm bool `env:"HTTP_DISABLE_PARSE_MULTIPART_FORM,default=true"` + + SSOProxyAuth bool `env:"SSO_PROXY_AUTH_ENABLED,default=false"` + SSOProxyAuthHeaderName string `env:"SSO_PROXY_AUTH_HEADER_NAME,default=Remote-User"` + SSOProxyAuthTrusted []string `env:"SSO_PROXY_AUTH_TRUSTED,default=10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7"` } // SetDefaults sets the default values for the configuration @@ -152,6 +156,9 @@ func (c *Config) DebugConfiguration(logger *logrus.Logger) { logger.Debugf(" SHIORI_HTTP_IDLE_TIMEOUT: %s", c.Http.IDLETimeout) logger.Debugf(" SHIORI_HTTP_DISABLE_KEEP_ALIVE: %t", c.Http.DisableKeepAlive) logger.Debugf(" SHIORI_HTTP_DISABLE_PARSE_MULTIPART_FORM: %t", c.Http.DisablePreParseMultipartForm) + logger.Debugf(" SHIORI_SSO_PROXY_AUTH_ENABLED: %t", c.Http.SSOProxyAuth) + logger.Debugf(" SHIORI_SSO_PROXY_AUTH_HEADER_NAME: %s", c.Http.SSOProxyAuthHeaderName) + logger.Debugf(" SHIORI_SSO_PROXY_AUTH_TRUSTED: %v", c.Http.SSOProxyAuthTrusted) } func (c *Config) IsValid() error { diff --git a/internal/domains/accounts.go b/internal/domains/accounts.go index 1a103c93..1a1f8d63 100644 --- a/internal/domains/accounts.go +++ b/internal/domains/accounts.go @@ -29,6 +29,24 @@ func (d *AccountsDomain) ListAccounts(ctx context.Context) ([]model.AccountDTO, return accountDTOs, nil } +func (d *AccountsDomain) GetAccountByUsername(ctx context.Context, username string) (*model.AccountDTO, error) { + if username == "" { + return nil, errors.New("empty username") + } + + accounts, err := d.deps.Database().ListAccounts(ctx, model.DBListAccountsOptions{ + Username: username, + }) + if err != nil { + return nil, fmt.Errorf("error getting accounts: %v", err) + } + if len(accounts) != 1 { + return nil, fmt.Errorf("got none or more than one account by username: %s", username) + } + + return model.Ptr(accounts[0].ToDTO()), nil +} + func (d *AccountsDomain) CreateAccount(ctx context.Context, account model.AccountDTO) (*model.AccountDTO, error) { if err := account.IsValidCreate(); err != nil { return nil, err diff --git a/internal/domains/accounts_test.go b/internal/domains/accounts_test.go index 116d90f5..fe71f9a7 100644 --- a/internal/domains/accounts_test.go +++ b/internal/domains/accounts_test.go @@ -38,6 +38,30 @@ func TestAccountDomainsListAccounts(t *testing.T) { }) } +func TestAccountDomainsGetAccountByUsername(t *testing.T) { + logger := logrus.New() + _, deps := testutil.GetTestConfigurationAndDependencies(t, context.TODO(), logger) + + t.Run("empty", func(t *testing.T) { + account, err := deps.Domains().Accounts().GetAccountByUsername(context.Background(), "") + require.Error(t, err) + require.Nil(t, account) + }) + + t.Run("account found", func(t *testing.T) { + _, err := deps.Domains().Accounts().CreateAccount(context.TODO(), model.AccountDTO{ + Username: "user1", + Password: "password1", + }) + require.NoError(t, err) + + account, err := deps.Domains().Accounts().GetAccountByUsername(context.Background(), "user1") + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, "user1", account.Username) + }) +} + func TestAccountDomainCreateAccount(t *testing.T) { logger := logrus.New() _, deps := testutil.GetTestConfigurationAndDependencies(t, context.TODO(), logger) diff --git a/internal/http/middleware/auth.go b/internal/http/middleware/auth.go index a45b5be4..8fb0928a 100644 --- a/internal/http/middleware/auth.go +++ b/internal/http/middleware/auth.go @@ -21,6 +21,10 @@ func NewAuthMiddleware(deps model.Dependencies) *AuthMiddleware { } func (m *AuthMiddleware) OnRequest(deps model.Dependencies, c model.WebContext) error { + if c.UserIsLogged() { + return nil + } + token := getTokenFromHeader(c.Request()) if token == "" { token = getTokenFromCookie(c.Request()) diff --git a/internal/http/middleware/auth_sso_proxy.go b/internal/http/middleware/auth_sso_proxy.go new file mode 100644 index 00000000..ac7e6050 --- /dev/null +++ b/internal/http/middleware/auth_sso_proxy.go @@ -0,0 +1,104 @@ +package middleware + +import ( + "errors" + "net" + + "github.com/go-shiori/shiori/internal/model" +) + +// AuthMiddleware handles authentication for incoming request by checking the token +// from the Authorization header or the token cookie and setting the account in the +// request context. +type AuthSSOProxyMiddleware struct { + deps model.Dependencies + + trustedIPs []*net.IPNet +} + +func NewAuthSSOProxyMiddleware(deps model.Dependencies) *AuthSSOProxyMiddleware { + plainIPs := deps.Config().Http.SSOProxyAuthTrusted + trustedIPs := make([]*net.IPNet, len(plainIPs)) + for i, ip := range plainIPs { + _, ipNet, err := net.ParseCIDR(ip) + if err != nil { + deps.Logger().WithError(err).WithField("ip", ip).Error("Failed to parse trusted ip cidr") + continue + } + + trustedIPs[i] = ipNet + } + + return &AuthSSOProxyMiddleware{ + deps: deps, + trustedIPs: trustedIPs, + } +} + +func (m *AuthSSOProxyMiddleware) OnRequest(deps model.Dependencies, c model.WebContext) error { + if c.UserIsLogged() { + return nil + } + + account, err := m.ssoAccount(deps, c) + if err != nil { + deps.Logger(). + WithError(err). + WithField("remote_addr", c.Request().RemoteAddr). + WithField("request_id", c.GetRequestID()). + Error("getting sso account") + return nil + } + if account != nil { + c.SetAccount(account) + return nil + } + + return nil +} + +func (m *AuthSSOProxyMiddleware) ssoAccount(deps model.Dependencies, c model.WebContext) (*model.AccountDTO, error) { + if !deps.Config().Http.SSOProxyAuth { + return nil, nil + } + + remoteAddr := c.Request().RemoteAddr + ip, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + var addrErr *net.AddrError + if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { + ip = remoteAddr + } else { + return nil, err + } + } + requestIP := net.ParseIP(ip) + if !m.isTrustedIP(requestIP) { + return nil, errors.New("remoteAddr is not a trusted ip") + } + + headerName := deps.Config().Http.SSOProxyAuthHeaderName + userName := c.Request().Header.Get(headerName) + if userName == "" { + return nil, nil + } + + account, err := deps.Domains().Accounts().GetAccountByUsername(c.Request().Context(), userName) + if err != nil { + return nil, err + } + + return account, nil +} +func (m *AuthSSOProxyMiddleware) isTrustedIP(ip net.IP) bool { + for _, net := range m.trustedIPs { + if ok := net.Contains(ip); ok { + return true + } + } + return false +} + +func (m *AuthSSOProxyMiddleware) OnResponse(deps model.Dependencies, c model.WebContext) error { + return nil +} diff --git a/internal/http/middleware/auth_sso_proxy_test.go b/internal/http/middleware/auth_sso_proxy_test.go new file mode 100644 index 00000000..adf48f7d --- /dev/null +++ b/internal/http/middleware/auth_sso_proxy_test.go @@ -0,0 +1,101 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-shiori/shiori/internal/http/webcontext" + "github.com/go-shiori/shiori/internal/model" + "github.com/go-shiori/shiori/internal/testutil" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func TestAuthMiddlewareWithSSO(t *testing.T) { + logger := logrus.New() + _, deps := testutil.GetTestConfigurationAndDependencies(t, context.TODO(), logger) + deps.Config().Http.SSOProxyAuth = true + + account, err := deps.Domains().Accounts().CreateAccount(context.TODO(), model.AccountDTO{ + ID: model.DBID(98), + Username: "test_username", + Password: "super_secure_password", + }) + require.NoError(t, err) + + t.Run("test no authorization method", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + c := webcontext.NewWebContext(w, r) + + middleware := NewAuthSSOProxyMiddleware(deps) + err := middleware.OnRequest(deps, c) + require.NoError(t, err) + require.Nil(t, c.GetAccount()) + }) + + t.Run("test untrusted ip", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "invalid-ip" + c := webcontext.NewWebContext(w, r) + + middleware := NewAuthSSOProxyMiddleware(deps) + err := middleware.OnRequest(deps, c) + require.NoError(t, err) + require.Nil(t, c.GetAccount()) + }) + + t.Run("test empty header", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "10.0.0.3" + c := webcontext.NewWebContext(w, r) + + middleware := NewAuthSSOProxyMiddleware(deps) + err := middleware.OnRequest(deps, c) + require.NoError(t, err) + require.Nil(t, c.GetAccount()) + }) + + t.Run("test invalid sso username", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "10.0.0.3" + r.Header.Add("Remote-User", "username") + c := webcontext.NewWebContext(w, r) + + middleware := NewAuthSSOProxyMiddleware(deps) + err := middleware.OnRequest(deps, c) + require.NoError(t, err) + require.Nil(t, c.GetAccount()) + }) + + t.Run("test sso login", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "10.0.0.3" + r.Header.Add("Remote-User", account.Username) + c := webcontext.NewWebContext(w, r) + + middleware := NewAuthSSOProxyMiddleware(deps) + err := middleware.OnRequest(deps, c) + require.NoError(t, err) + require.NotNil(t, c.GetAccount()) + }) + + t.Run("test sso login ip:port", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "10.0.0.3:65342" + r.Header.Add("Remote-User", account.Username) + c := webcontext.NewWebContext(w, r) + + middleware := NewAuthSSOProxyMiddleware(deps) + err := middleware.OnRequest(deps, c) + require.NoError(t, err) + require.NotNil(t, c.GetAccount()) + }) +} diff --git a/internal/http/server.go b/internal/http/server.go index ca70ce5e..8f77f684 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -33,6 +33,10 @@ func (s *HttpServer) Setup(cfg *config.Config, deps *dependencies.Dependencies) globalMiddleware := []model.HttpMiddleware{} + if cfg.Http.SSOProxyAuth { + globalMiddleware = append(globalMiddleware, middleware.NewAuthSSOProxyMiddleware(deps)) + } + // Add message response middleware if legacy message response is enabled globalMiddleware = append(globalMiddleware, []model.HttpMiddleware{ middleware.NewMessageResponseMiddleware(deps), diff --git a/internal/model/domains.go b/internal/model/domains.go index fe8ee91b..cdbf23c5 100644 --- a/internal/model/domains.go +++ b/internal/model/domains.go @@ -31,6 +31,7 @@ type AuthDomain interface { type AccountsDomain interface { ListAccounts(ctx context.Context) ([]AccountDTO, error) + GetAccountByUsername(ctx context.Context, username string) (*AccountDTO, error) CreateAccount(ctx context.Context, account AccountDTO) (*AccountDTO, error) UpdateAccount(ctx context.Context, account AccountDTO) (*AccountDTO, error) DeleteAccount(ctx context.Context, id int) error diff --git a/internal/view/index.html b/internal/view/index.html index fd0b6c0c..bd8dc2e5 100644 --- a/internal/view/index.html +++ b/internal/view/index.html @@ -160,8 +160,8 @@ }, onLoginSuccess() { - this.loadSetting(); this.loadAccount(); + this.loadSetting(); this.isLoggedIn = true; }, @@ -169,10 +169,6 @@ const token = localStorage.getItem("shiori-token"); const account = localStorage.getItem("shiori-account"); - if (!(token && account)) { - return false; - } - try { const response = await fetch(new URL("api/v1/auth/me", document.baseURI), { headers: { @@ -184,6 +180,11 @@ throw new Error('Invalid session'); } + const responseJSON = await response.json(); + localStorage.setItem( + "shiori-account", + JSON.stringify(responseJSON.message), + ); return true; } catch (err) { // Clear invalid session data diff --git a/internal/webserver/handler.go b/internal/webserver/handler.go index ef7adf2c..6a0f97df 100644 --- a/internal/webserver/handler.go +++ b/internal/webserver/handler.go @@ -1,7 +1,9 @@ package webserver import ( + "errors" "fmt" + "net" "net/http" "strings" @@ -21,6 +23,7 @@ type Handler struct { Log bool dependencies model.Dependencies + trustedIPs []*net.IPNet } func (h *Handler) PrepareSessionCache() { @@ -45,30 +48,20 @@ func (h *Handler) PrepareSessionCache() { // validateSession checks whether user session is still valid or not func (h *Handler) validateSession(r *http.Request) error { - authorization := r.Header.Get(model.AuthorizationHeader) - if authorization == "" { - // Get token from cookie - tokenCookie, err := r.Cookie("token") - if err != nil { - return fmt.Errorf("session is not exist") - } + var account *model.AccountDTO + var err error - authorization = tokenCookie.Value + if h.dependencies.Config().Http.SSOProxyAuth { + account, err = h.ssoAccount(r) + if err != nil { + h.dependencies.Logger().WithError(err).Error("getting sso account") + } } - var account *model.AccountDTO - - if authorization != "" { - var err error - - authParts := strings.SplitN(authorization, " ", 2) - if len(authParts) != 2 && authParts[0] != model.AuthorizationTokenType { - return fmt.Errorf("session has been expired") - } - - account, err = h.dependencies.Domains().Auth().CheckToken(r.Context(), authParts[1]) + if account == nil { + account, err = h.tokenAccount(r) if err != nil { - return fmt.Errorf("session has been expired") + return err } } @@ -85,3 +78,70 @@ func (h *Handler) validateSession(r *http.Request) error { return nil } + +func (h *Handler) tokenAccount(r *http.Request) (*model.AccountDTO, error) { + authorization := r.Header.Get(model.AuthorizationHeader) + if authorization == "" { + // Get token from cookie + tokenCookie, err := r.Cookie("token") + if err != nil { + return nil, fmt.Errorf("session is not exist") + } + + authorization = tokenCookie.Value + } + + if authorization != "" { + authParts := strings.SplitN(authorization, " ", 2) + if len(authParts) != 2 || authParts[0] != model.AuthorizationTokenType { + return nil, fmt.Errorf("session has been expired") + } + + account, err := h.dependencies.Domains().Auth().CheckToken(r.Context(), authParts[1]) + if err != nil { + return nil, fmt.Errorf("session has been expired") + } + + return account, nil + } + + return nil, errors.New("session has been expired") +} + +func (h *Handler) ssoAccount(r *http.Request) (*model.AccountDTO, error) { + remoteAddr := r.RemoteAddr + ip, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + var addrErr *net.AddrError + if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { + ip = remoteAddr + } else { + return nil, err + } + } + requestIP := net.ParseIP(ip) + if !h.isTrustedIP(requestIP) { + return nil, fmt.Errorf("'%s' is not a trusted ip", r.RemoteAddr) + } + + headerName := h.dependencies.Config().Http.SSOProxyAuthHeaderName + userName := r.Header.Get(headerName) + if userName == "" { + return nil, nil + } + + account, err := h.dependencies.Domains().Accounts().GetAccountByUsername(r.Context(), userName) + if err != nil { + return nil, err + } + + return account, nil +} +func (h *Handler) isTrustedIP(ip net.IP) bool { + for _, net := range h.trustedIPs { + if ok := net.Contains(ip); ok { + return true + } + } + return false +} diff --git a/internal/webserver/server.go b/internal/webserver/server.go index 4cf59f5f..9875855e 100644 --- a/internal/webserver/server.go +++ b/internal/webserver/server.go @@ -1,6 +1,7 @@ package webserver import ( + "net" "time" "github.com/go-shiori/shiori/internal/model" @@ -19,6 +20,18 @@ type Config struct { // GetLegacyHandler returns a legacy handler to use with the new webserver func GetLegacyHandler(cfg Config, dependencies model.Dependencies) *Handler { + plainIPs := dependencies.Config().Http.SSOProxyAuthTrusted + trustedIPs := make([]*net.IPNet, len(plainIPs)) + for i, ip := range plainIPs { + _, ipNet, err := net.ParseCIDR(ip) + if err != nil { + dependencies.Logger().WithError(err).WithField("ip", ip).Error("Failed to parse trusted ip cidr") + continue + } + + trustedIPs[i] = ipNet + } + return &Handler{ DB: cfg.DB, DataDir: cfg.DataDir, @@ -28,5 +41,6 @@ func GetLegacyHandler(cfg Config, dependencies model.Dependencies) *Handler { RootPath: cfg.RootPath, Log: cfg.Log, dependencies: dependencies, + trustedIPs: trustedIPs, } }