feat: support proxy forward headers authentication (#1105)

* feat: Add SSO forward header

* fix: Use domain layer

* test: Some test

* chore: Print new values when debugging

* chore: Rename enabled envvar

* fix: Wrongly parsing remote ip

* fix: Always validate token. NPE on validateSession

* fix: Dont overwrite token when sso

* fix: Best effort to get ip. Parse as ip:port and then as ip

* fix: Forgot to update handler version

* fix: Forgot to commit changes

* test: GetAccountByUsername

* chore: Rename some variables

* chore: return error from ssoAccount

* refactor: Extract sso proxy auth to own middleware

* fix: Dont panic if not sso account on legacy validate session

* ci: gofmt

---------

Co-authored-by: Felipe Martin <812088+fmartingr@users.noreply.github.com>
This commit is contained in:
Federico Scodelaro 2025-07-12 07:11:42 -03:00 committed by GitHub
parent 24e06a5678
commit 9f6a4c39d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 381 additions and 40 deletions

View file

@ -27,21 +27,24 @@ Most configuration can be set directly using environment variables or flags. The
### HTTP configuration variables
| Environment variable | Default | Required | Description |
| ------------------------------------------ | ------- | -------- | ----------------------------------------------------- |
| `SHIORI_HTTP_ENABLED` | True | No | Enable HTTP service |
| `SHIORI_HTTP_PORT` | 8080 | No | Port number for the HTTP service |
| `SHIORI_HTTP_ADDRESS` | : | No | Address for the HTTP service |
| `SHIORI_HTTP_ROOT_PATH` | / | No | Root path for the HTTP service |
| `SHIORI_HTTP_ACCESS_LOG` | True | No | Logging accessibility for HTTP requests |
| `SHIORI_HTTP_SERVE_WEB_UI` | True | No | Serving Web UI via HTTP. Disable serves only the API. |
| `SHIORI_HTTP_SECRET_KEY` | | **Yes** | Secret key for HTTP sessions. |
| `SHIORI_HTTP_BODY_LIMIT` | 1024 | No | Limit for request body size |
| `SHIORI_HTTP_READ_TIMEOUT` | 10s | No | Maximum duration for reading the entire request |
| `SHIORI_HTTP_WRITE_TIMEOUT` | 10s | No | Maximum duration before timing out writes |
| `SHIORI_HTTP_IDLE_TIMEOUT` | 10s | No | Maximum amount of time to wait for the next request |
| `SHIORI_HTTP_DISABLE_KEEP_ALIVE` | true | No | Disable HTTP keep-alive connections |
| `SHIORI_HTTP_DISABLE_PARSE_MULTIPART_FORM` | true | No | Disable pre-parsing of multipart form |
| Environment variable | Default | Required | Description |
| ------------------------------------------ | ------- | -------- | ----------------------------------------------------- |
| `SHIORI_HTTP_ENABLED` | True | No | Enable HTTP service |
| `SHIORI_HTTP_PORT` | 8080 | No | Port number for the HTTP service |
| `SHIORI_HTTP_ADDRESS` | : | No | Address for the HTTP service |
| `SHIORI_HTTP_ROOT_PATH` | / | No | Root path for the HTTP service |
| `SHIORI_HTTP_ACCESS_LOG` | True | No | Logging accessibility for HTTP requests |
| `SHIORI_HTTP_SERVE_WEB_UI` | True | No | Serving Web UI via HTTP. Disable serves only the API. |
| `SHIORI_HTTP_SECRET_KEY` | | **Yes** | Secret key for HTTP sessions. |
| `SHIORI_HTTP_BODY_LIMIT` | 1024 | No | Limit for request body size |
| `SHIORI_HTTP_READ_TIMEOUT` | 10s | No | Maximum duration for reading the entire request |
| `SHIORI_HTTP_WRITE_TIMEOUT` | 10s | No | Maximum duration before timing out writes |
| `SHIORI_HTTP_IDLE_TIMEOUT` | 10s | No | Maximum amount of time to wait for the next request |
| `SHIORI_HTTP_DISABLE_KEEP_ALIVE` | true | No | Disable HTTP keep-alive connections |
| `SHIORI_HTTP_DISABLE_PARSE_MULTIPART_FORM` | true | No | Disable pre-parsing of multipart form |
| `SHIORI_SSO_PROXY_AUTH_ENABLED` | false | No | Enable SSO Auth Proxy Header |
| `SHIORI_SSO_PROXY_AUTH_HEADER_NAME` | Remote-User | No | List of CIDRs of trusted proxies |
| `SHIORI_SSO_PROXY_AUTH_TRUSTED` | 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7 | No | List of CIDRs of trusted proxies |
### Storage Configuration

View file

@ -65,6 +65,10 @@ type HttpConfig struct {
IDLETimeout time.Duration `env:"HTTP_IDLE_TIMEOUT,default=10s"`
DisableKeepAlive bool `env:"HTTP_DISABLE_KEEP_ALIVE,default=true"`
DisablePreParseMultipartForm bool `env:"HTTP_DISABLE_PARSE_MULTIPART_FORM,default=true"`
SSOProxyAuth bool `env:"SSO_PROXY_AUTH_ENABLED,default=false"`
SSOProxyAuthHeaderName string `env:"SSO_PROXY_AUTH_HEADER_NAME,default=Remote-User"`
SSOProxyAuthTrusted []string `env:"SSO_PROXY_AUTH_TRUSTED,default=10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7"`
}
// SetDefaults sets the default values for the configuration
@ -152,6 +156,9 @@ func (c *Config) DebugConfiguration(logger *logrus.Logger) {
logger.Debugf(" SHIORI_HTTP_IDLE_TIMEOUT: %s", c.Http.IDLETimeout)
logger.Debugf(" SHIORI_HTTP_DISABLE_KEEP_ALIVE: %t", c.Http.DisableKeepAlive)
logger.Debugf(" SHIORI_HTTP_DISABLE_PARSE_MULTIPART_FORM: %t", c.Http.DisablePreParseMultipartForm)
logger.Debugf(" SHIORI_SSO_PROXY_AUTH_ENABLED: %t", c.Http.SSOProxyAuth)
logger.Debugf(" SHIORI_SSO_PROXY_AUTH_HEADER_NAME: %s", c.Http.SSOProxyAuthHeaderName)
logger.Debugf(" SHIORI_SSO_PROXY_AUTH_TRUSTED: %v", c.Http.SSOProxyAuthTrusted)
}
func (c *Config) IsValid() error {

View file

@ -29,6 +29,24 @@ func (d *AccountsDomain) ListAccounts(ctx context.Context) ([]model.AccountDTO,
return accountDTOs, nil
}
func (d *AccountsDomain) GetAccountByUsername(ctx context.Context, username string) (*model.AccountDTO, error) {
if username == "" {
return nil, errors.New("empty username")
}
accounts, err := d.deps.Database().ListAccounts(ctx, model.DBListAccountsOptions{
Username: username,
})
if err != nil {
return nil, fmt.Errorf("error getting accounts: %v", err)
}
if len(accounts) != 1 {
return nil, fmt.Errorf("got none or more than one account by username: %s", username)
}
return model.Ptr(accounts[0].ToDTO()), nil
}
func (d *AccountsDomain) CreateAccount(ctx context.Context, account model.AccountDTO) (*model.AccountDTO, error) {
if err := account.IsValidCreate(); err != nil {
return nil, err

View file

@ -38,6 +38,30 @@ func TestAccountDomainsListAccounts(t *testing.T) {
})
}
func TestAccountDomainsGetAccountByUsername(t *testing.T) {
logger := logrus.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, context.TODO(), logger)
t.Run("empty", func(t *testing.T) {
account, err := deps.Domains().Accounts().GetAccountByUsername(context.Background(), "")
require.Error(t, err)
require.Nil(t, account)
})
t.Run("account found", func(t *testing.T) {
_, err := deps.Domains().Accounts().CreateAccount(context.TODO(), model.AccountDTO{
Username: "user1",
Password: "password1",
})
require.NoError(t, err)
account, err := deps.Domains().Accounts().GetAccountByUsername(context.Background(), "user1")
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, "user1", account.Username)
})
}
func TestAccountDomainCreateAccount(t *testing.T) {
logger := logrus.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, context.TODO(), logger)

View file

@ -21,6 +21,10 @@ func NewAuthMiddleware(deps model.Dependencies) *AuthMiddleware {
}
func (m *AuthMiddleware) OnRequest(deps model.Dependencies, c model.WebContext) error {
if c.UserIsLogged() {
return nil
}
token := getTokenFromHeader(c.Request())
if token == "" {
token = getTokenFromCookie(c.Request())

View file

@ -0,0 +1,104 @@
package middleware
import (
"errors"
"net"
"github.com/go-shiori/shiori/internal/model"
)
// AuthMiddleware handles authentication for incoming request by checking the token
// from the Authorization header or the token cookie and setting the account in the
// request context.
type AuthSSOProxyMiddleware struct {
deps model.Dependencies
trustedIPs []*net.IPNet
}
func NewAuthSSOProxyMiddleware(deps model.Dependencies) *AuthSSOProxyMiddleware {
plainIPs := deps.Config().Http.SSOProxyAuthTrusted
trustedIPs := make([]*net.IPNet, len(plainIPs))
for i, ip := range plainIPs {
_, ipNet, err := net.ParseCIDR(ip)
if err != nil {
deps.Logger().WithError(err).WithField("ip", ip).Error("Failed to parse trusted ip cidr")
continue
}
trustedIPs[i] = ipNet
}
return &AuthSSOProxyMiddleware{
deps: deps,
trustedIPs: trustedIPs,
}
}
func (m *AuthSSOProxyMiddleware) OnRequest(deps model.Dependencies, c model.WebContext) error {
if c.UserIsLogged() {
return nil
}
account, err := m.ssoAccount(deps, c)
if err != nil {
deps.Logger().
WithError(err).
WithField("remote_addr", c.Request().RemoteAddr).
WithField("request_id", c.GetRequestID()).
Error("getting sso account")
return nil
}
if account != nil {
c.SetAccount(account)
return nil
}
return nil
}
func (m *AuthSSOProxyMiddleware) ssoAccount(deps model.Dependencies, c model.WebContext) (*model.AccountDTO, error) {
if !deps.Config().Http.SSOProxyAuth {
return nil, nil
}
remoteAddr := c.Request().RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
var addrErr *net.AddrError
if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" {
ip = remoteAddr
} else {
return nil, err
}
}
requestIP := net.ParseIP(ip)
if !m.isTrustedIP(requestIP) {
return nil, errors.New("remoteAddr is not a trusted ip")
}
headerName := deps.Config().Http.SSOProxyAuthHeaderName
userName := c.Request().Header.Get(headerName)
if userName == "" {
return nil, nil
}
account, err := deps.Domains().Accounts().GetAccountByUsername(c.Request().Context(), userName)
if err != nil {
return nil, err
}
return account, nil
}
func (m *AuthSSOProxyMiddleware) isTrustedIP(ip net.IP) bool {
for _, net := range m.trustedIPs {
if ok := net.Contains(ip); ok {
return true
}
}
return false
}
func (m *AuthSSOProxyMiddleware) OnResponse(deps model.Dependencies, c model.WebContext) error {
return nil
}

View file

@ -0,0 +1,101 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-shiori/shiori/internal/http/webcontext"
"github.com/go-shiori/shiori/internal/model"
"github.com/go-shiori/shiori/internal/testutil"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
func TestAuthMiddlewareWithSSO(t *testing.T) {
logger := logrus.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, context.TODO(), logger)
deps.Config().Http.SSOProxyAuth = true
account, err := deps.Domains().Accounts().CreateAccount(context.TODO(), model.AccountDTO{
ID: model.DBID(98),
Username: "test_username",
Password: "super_secure_password",
})
require.NoError(t, err)
t.Run("test no authorization method", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
c := webcontext.NewWebContext(w, r)
middleware := NewAuthSSOProxyMiddleware(deps)
err := middleware.OnRequest(deps, c)
require.NoError(t, err)
require.Nil(t, c.GetAccount())
})
t.Run("test untrusted ip", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.RemoteAddr = "invalid-ip"
c := webcontext.NewWebContext(w, r)
middleware := NewAuthSSOProxyMiddleware(deps)
err := middleware.OnRequest(deps, c)
require.NoError(t, err)
require.Nil(t, c.GetAccount())
})
t.Run("test empty header", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.RemoteAddr = "10.0.0.3"
c := webcontext.NewWebContext(w, r)
middleware := NewAuthSSOProxyMiddleware(deps)
err := middleware.OnRequest(deps, c)
require.NoError(t, err)
require.Nil(t, c.GetAccount())
})
t.Run("test invalid sso username", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.RemoteAddr = "10.0.0.3"
r.Header.Add("Remote-User", "username")
c := webcontext.NewWebContext(w, r)
middleware := NewAuthSSOProxyMiddleware(deps)
err := middleware.OnRequest(deps, c)
require.NoError(t, err)
require.Nil(t, c.GetAccount())
})
t.Run("test sso login", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.RemoteAddr = "10.0.0.3"
r.Header.Add("Remote-User", account.Username)
c := webcontext.NewWebContext(w, r)
middleware := NewAuthSSOProxyMiddleware(deps)
err := middleware.OnRequest(deps, c)
require.NoError(t, err)
require.NotNil(t, c.GetAccount())
})
t.Run("test sso login ip:port", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.RemoteAddr = "10.0.0.3:65342"
r.Header.Add("Remote-User", account.Username)
c := webcontext.NewWebContext(w, r)
middleware := NewAuthSSOProxyMiddleware(deps)
err := middleware.OnRequest(deps, c)
require.NoError(t, err)
require.NotNil(t, c.GetAccount())
})
}

View file

@ -33,6 +33,10 @@ func (s *HttpServer) Setup(cfg *config.Config, deps *dependencies.Dependencies)
globalMiddleware := []model.HttpMiddleware{}
if cfg.Http.SSOProxyAuth {
globalMiddleware = append(globalMiddleware, middleware.NewAuthSSOProxyMiddleware(deps))
}
// Add message response middleware if legacy message response is enabled
globalMiddleware = append(globalMiddleware, []model.HttpMiddleware{
middleware.NewMessageResponseMiddleware(deps),

View file

@ -31,6 +31,7 @@ type AuthDomain interface {
type AccountsDomain interface {
ListAccounts(ctx context.Context) ([]AccountDTO, error)
GetAccountByUsername(ctx context.Context, username string) (*AccountDTO, error)
CreateAccount(ctx context.Context, account AccountDTO) (*AccountDTO, error)
UpdateAccount(ctx context.Context, account AccountDTO) (*AccountDTO, error)
DeleteAccount(ctx context.Context, id int) error

View file

@ -160,8 +160,8 @@
},
onLoginSuccess() {
this.loadSetting();
this.loadAccount();
this.loadSetting();
this.isLoggedIn = true;
},
@ -169,10 +169,6 @@
const token = localStorage.getItem("shiori-token");
const account = localStorage.getItem("shiori-account");
if (!(token && account)) {
return false;
}
try {
const response = await fetch(new URL("api/v1/auth/me", document.baseURI), {
headers: {
@ -184,6 +180,11 @@
throw new Error('Invalid session');
}
const responseJSON = await response.json();
localStorage.setItem(
"shiori-account",
JSON.stringify(responseJSON.message),
);
return true;
} catch (err) {
// Clear invalid session data

View file

@ -1,7 +1,9 @@
package webserver
import (
"errors"
"fmt"
"net"
"net/http"
"strings"
@ -21,6 +23,7 @@ type Handler struct {
Log bool
dependencies model.Dependencies
trustedIPs []*net.IPNet
}
func (h *Handler) PrepareSessionCache() {
@ -45,30 +48,20 @@ func (h *Handler) PrepareSessionCache() {
// validateSession checks whether user session is still valid or not
func (h *Handler) validateSession(r *http.Request) error {
authorization := r.Header.Get(model.AuthorizationHeader)
if authorization == "" {
// Get token from cookie
tokenCookie, err := r.Cookie("token")
if err != nil {
return fmt.Errorf("session is not exist")
}
var account *model.AccountDTO
var err error
authorization = tokenCookie.Value
if h.dependencies.Config().Http.SSOProxyAuth {
account, err = h.ssoAccount(r)
if err != nil {
h.dependencies.Logger().WithError(err).Error("getting sso account")
}
}
var account *model.AccountDTO
if authorization != "" {
var err error
authParts := strings.SplitN(authorization, " ", 2)
if len(authParts) != 2 && authParts[0] != model.AuthorizationTokenType {
return fmt.Errorf("session has been expired")
}
account, err = h.dependencies.Domains().Auth().CheckToken(r.Context(), authParts[1])
if account == nil {
account, err = h.tokenAccount(r)
if err != nil {
return fmt.Errorf("session has been expired")
return err
}
}
@ -85,3 +78,70 @@ func (h *Handler) validateSession(r *http.Request) error {
return nil
}
func (h *Handler) tokenAccount(r *http.Request) (*model.AccountDTO, error) {
authorization := r.Header.Get(model.AuthorizationHeader)
if authorization == "" {
// Get token from cookie
tokenCookie, err := r.Cookie("token")
if err != nil {
return nil, fmt.Errorf("session is not exist")
}
authorization = tokenCookie.Value
}
if authorization != "" {
authParts := strings.SplitN(authorization, " ", 2)
if len(authParts) != 2 || authParts[0] != model.AuthorizationTokenType {
return nil, fmt.Errorf("session has been expired")
}
account, err := h.dependencies.Domains().Auth().CheckToken(r.Context(), authParts[1])
if err != nil {
return nil, fmt.Errorf("session has been expired")
}
return account, nil
}
return nil, errors.New("session has been expired")
}
func (h *Handler) ssoAccount(r *http.Request) (*model.AccountDTO, error) {
remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
var addrErr *net.AddrError
if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" {
ip = remoteAddr
} else {
return nil, err
}
}
requestIP := net.ParseIP(ip)
if !h.isTrustedIP(requestIP) {
return nil, fmt.Errorf("'%s' is not a trusted ip", r.RemoteAddr)
}
headerName := h.dependencies.Config().Http.SSOProxyAuthHeaderName
userName := r.Header.Get(headerName)
if userName == "" {
return nil, nil
}
account, err := h.dependencies.Domains().Accounts().GetAccountByUsername(r.Context(), userName)
if err != nil {
return nil, err
}
return account, nil
}
func (h *Handler) isTrustedIP(ip net.IP) bool {
for _, net := range h.trustedIPs {
if ok := net.Contains(ip); ok {
return true
}
}
return false
}

View file

@ -1,6 +1,7 @@
package webserver
import (
"net"
"time"
"github.com/go-shiori/shiori/internal/model"
@ -19,6 +20,18 @@ type Config struct {
// GetLegacyHandler returns a legacy handler to use with the new webserver
func GetLegacyHandler(cfg Config, dependencies model.Dependencies) *Handler {
plainIPs := dependencies.Config().Http.SSOProxyAuthTrusted
trustedIPs := make([]*net.IPNet, len(plainIPs))
for i, ip := range plainIPs {
_, ipNet, err := net.ParseCIDR(ip)
if err != nil {
dependencies.Logger().WithError(err).WithField("ip", ip).Error("Failed to parse trusted ip cidr")
continue
}
trustedIPs[i] = ipNet
}
return &Handler{
DB: cfg.DB,
DataDir: cfg.DataDir,
@ -28,5 +41,6 @@ func GetLegacyHandler(cfg Config, dependencies model.Dependencies) *Handler {
RootPath: cfg.RootPath,
Log: cfg.Log,
dependencies: dependencies,
trustedIPs: trustedIPs,
}
}