diff --git a/api/auth.go b/api/auth.go index d8289c106..946cda966 100644 --- a/api/auth.go +++ b/api/auth.go @@ -213,9 +213,25 @@ func authenticationHandler(w http.ResponseWriter, r *http.Request) (ok bool, req req = r + var jwtConfig *util.JWTAuthConfig + if util.Config.Auth != nil { + jwtConfig = util.Config.Auth.JWT + } authHeader := strings.ToLower(r.Header.Get("authorization")) - if len(authHeader) > 0 && strings.Contains(authHeader, "bearer") { + if jwtConfig != nil && jwtConfig.Enabled && r.Header.Get(jwtConfig.GetHeader()) != "" { + // JWT proxy auth: if the header is present, commit to this path. + var err error + userID, err = authenticateByJWT(r) + if err != nil { + log.WithFields(log.Fields{ + "path": r.URL.Path, + "remote": r.RemoteAddr, + }).Warn("JWT auth failed: ", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + } else if len(authHeader) > 0 && strings.Contains(authHeader, "bearer") { token, err := helpers.Store(r).GetAPIToken(strings.Replace(authHeader, "bearer ", "", 1)) if err != nil { diff --git a/api/jwt_auth.go b/api/jwt_auth.go new file mode 100644 index 000000000..35e2d00ee --- /dev/null +++ b/api/jwt_auth.go @@ -0,0 +1,130 @@ +package api + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/MicahParks/keyfunc/v3" + "github.com/golang-jwt/jwt/v5" + log "github.com/sirupsen/logrus" + + "github.com/semaphoreui/semaphore/api/helpers" + "github.com/semaphoreui/semaphore/db" + "github.com/semaphoreui/semaphore/util" +) + +var ( + globalKeyfunc keyfunc.Keyfunc + globalJWTParser *jwt.Parser +) + +// initJWKSCache creates the JWT parser and starts keyfunc's JWKS client. +// keyfunc.NewDefaultCtx performs an initial HTTP fetch (up to 1 min timeout) +// but with NoErrorReturnFirstHTTPReq=true it returns successfully even if the +// endpoint is unreachable. Its built-in refresh goroutine retries hourly. +func initJWKSCache(jwksURL string) { + if !strings.HasPrefix(jwksURL, "https://") { + log.Warn("JWT JWKS URL is not HTTPS: ", jwksURL) + } + + globalJWTParser = newJWTParser(util.Config.Auth.JWT) + + kf, err := keyfunc.NewDefaultCtx(context.Background(), []string{jwksURL}) + if err != nil { + log.Errorf("JWKS setup for %s failed: %v — JWT auth will not work", jwksURL, err) + return + } + + globalKeyfunc = kf + log.Info("JWKS initialized from ", jwksURL) +} + +func newJWTParser(config *util.JWTAuthConfig) *jwt.Parser { + opts := []jwt.ParserOption{ + jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512"}), + jwt.WithExpirationRequired(), + } + + if config.Audience != "" { + opts = append(opts, jwt.WithAudience(config.Audience)) + } + if config.Issuer != "" { + opts = append(opts, jwt.WithIssuer(config.Issuer)) + } + + return jwt.NewParser(opts...) +} + +func validateProxyJWT(tokenString string) (map[string]any, error) { + if globalKeyfunc == nil { + return nil, fmt.Errorf("JWKS not available — JWT auth is not configured") + } + + token, err := globalJWTParser.Parse(tokenString, globalKeyfunc.Keyfunc) + if err != nil { + // Parse without verification solely to extract iss/aud for operator-facing + // log messages. The token has already been rejected above. + unverified, _, parseErr := jwt.NewParser(jwt.WithoutClaimsValidation()).ParseUnverified(tokenString, jwt.MapClaims{}) + if parseErr == nil { + if claims, ok := unverified.Claims.(jwt.MapClaims); ok { + return nil, fmt.Errorf("JWT validation failed (iss=%v aud=%v): %w", + claims["iss"], claims["aud"], err) + } + } + return nil, fmt.Errorf("JWT validation failed: %w", err) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("unexpected claims type") + } + + return claims, nil +} + +func authenticateByJWT(r *http.Request) (int, error) { + config := util.Config.Auth.JWT + + tokenString := r.Header.Get(config.GetHeader()) + if tokenString == "" { + return 0, fmt.Errorf("no JWT in header %s", config.GetHeader()) + } + + claims, err := validateProxyJWT(tokenString) + if err != nil { + return 0, err + } + + prepareClaims(claims) + parsed, err := parseClaims(claims, config) + if err != nil { + return 0, fmt.Errorf("extract claims: %w", err) + } + + store := helpers.Store(r) + + user, err := store.GetUserByLoginOrEmail("", parsed.email) + + if errors.Is(err, db.ErrNotFound) { + user = db.User{ + Username: parsed.username, + Name: parsed.name, + Email: parsed.email, + External: true, + } + user, err = store.CreateUserWithoutPassword(user) + } + + if err != nil { + return 0, fmt.Errorf("JWT user lookup/creation: %w", err) + } + + if !user.External { + return 0, fmt.Errorf("JWT user %q conflicts with local user", user.Email) + } + + return user.ID, nil +} diff --git a/api/jwt_auth_test.go b/api/jwt_auth_test.go new file mode 100644 index 000000000..314f7c6e8 --- /dev/null +++ b/api/jwt_auth_test.go @@ -0,0 +1,287 @@ +package api + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + "time" + + "github.com/MicahParks/jwkset" + "github.com/MicahParks/keyfunc/v3" + "github.com/golang-jwt/jwt/v5" + "github.com/semaphoreui/semaphore/util" +) + +func generateTestKey(t *testing.T) *ecdsa.PrivateKey { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + return key +} + +func signTestJWT(t *testing.T, key *ecdsa.PrivateKey, claims jwt.MapClaims) string { + t.Helper() + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + token.Header["kid"] = "test-key-1" + + signed, err := token.SignedString(key) + if err != nil { + t.Fatal(err) + } + return signed +} + +func setupTestJWTAuth(t *testing.T, key *ecdsa.PrivateKey, config *util.JWTAuthConfig) { + t.Helper() + + jwk, err := jwkset.NewJWKFromKey(&key.PublicKey, jwkset.JWKOptions{ + Metadata: jwkset.JWKMetadataOptions{ + KID: "test-key-1", + ALG: jwkset.AlgES256, + USE: jwkset.UseSig, + }, + }) + if err != nil { + t.Fatal(err) + } + + storage := jwkset.NewMemoryStorage() + if err := storage.KeyWrite(t.Context(), jwk); err != nil { + t.Fatal(err) + } + + kf, err := keyfunc.New(keyfunc.Options{ + Storage: storage, + }) + if err != nil { + t.Fatal(err) + } + + oldKF := globalKeyfunc + oldParser := globalJWTParser + t.Cleanup(func() { + globalKeyfunc = oldKF + globalJWTParser = oldParser + }) + globalKeyfunc = kf + globalJWTParser = newJWTParser(config) +} + +func setupEmptyJWTAuth(t *testing.T, config *util.JWTAuthConfig) { + t.Helper() + + kf, err := keyfunc.New(keyfunc.Options{ + Storage: jwkset.NewMemoryStorage(), + }) + if err != nil { + t.Fatal(err) + } + + oldKF := globalKeyfunc + oldParser := globalJWTParser + t.Cleanup(func() { + globalKeyfunc = oldKF + globalJWTParser = oldParser + }) + globalKeyfunc = kf + globalJWTParser = newJWTParser(config) +} + +func TestValidateProxyJWT_Valid(t *testing.T) { + key := generateTestKey(t) + config := &util.JWTAuthConfig{ + Enabled: true, + Audience: "https://semaphore.example.com", + Issuer: "https://auth.example.com", + } + setupTestJWTAuth(t, key, config) + + token := signTestJWT(t, key, jwt.MapClaims{ + "sub": "user123", + "email": "user@example.com", + "name": "Test User", + "aud": "https://semaphore.example.com", + "iss": "https://auth.example.com", + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + "iat": jwt.NewNumericDate(time.Now()), + }) + + claims, err := validateProxyJWT(token) + if err != nil { + t.Fatal("expected valid token, got error:", err) + } + + if claims["email"] != "user@example.com" { + t.Errorf("expected email user@example.com, got %v", claims["email"]) + } +} + +func TestValidateProxyJWT_Expired(t *testing.T) { + key := generateTestKey(t) + setupTestJWTAuth(t, key, &util.JWTAuthConfig{Enabled: true}) + + token := signTestJWT(t, key, jwt.MapClaims{ + "sub": "user123", + "email": "user@example.com", + "exp": jwt.NewNumericDate(time.Now().Add(-time.Hour)), + }) + + _, err := validateProxyJWT(token) + if err == nil { + t.Fatal("expected error for expired token") + } +} + +func TestValidateProxyJWT_WrongAudience(t *testing.T) { + key := generateTestKey(t) + setupTestJWTAuth(t, key, &util.JWTAuthConfig{ + Enabled: true, + Audience: "https://semaphore.example.com", + }) + + token := signTestJWT(t, key, jwt.MapClaims{ + "sub": "user123", + "email": "user@example.com", + "aud": "https://wrong.example.com", + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + + _, err := validateProxyJWT(token) + if err == nil { + t.Fatal("expected error for wrong audience") + } +} + +func TestValidateProxyJWT_WrongIssuer(t *testing.T) { + key := generateTestKey(t) + setupTestJWTAuth(t, key, &util.JWTAuthConfig{ + Enabled: true, + Issuer: "https://auth.example.com", + }) + + token := signTestJWT(t, key, jwt.MapClaims{ + "sub": "user123", + "email": "user@example.com", + "iss": "https://evil.example.com", + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + + _, err := validateProxyJWT(token) + if err == nil { + t.Fatal("expected error for wrong issuer") + } +} + +func TestValidateProxyJWT_UnknownKey(t *testing.T) { + key := generateTestKey(t) + setupEmptyJWTAuth(t, &util.JWTAuthConfig{Enabled: true}) + + token := signTestJWT(t, key, jwt.MapClaims{ + "sub": "user123", + "email": "user@example.com", + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + + _, err := validateProxyJWT(token) + if err == nil { + t.Fatal("expected error for unknown key") + } +} + +func TestValidateProxyJWT_MissingExp(t *testing.T) { + key := generateTestKey(t) + setupTestJWTAuth(t, key, &util.JWTAuthConfig{Enabled: true}) + + token := signTestJWT(t, key, jwt.MapClaims{ + "sub": "user123", + "email": "user@example.com", + }) + + _, err := validateProxyJWT(token) + if err == nil { + t.Fatal("expected error for missing exp") + } +} + +func TestValidateProxyJWT_AudienceArray(t *testing.T) { + key := generateTestKey(t) + setupTestJWTAuth(t, key, &util.JWTAuthConfig{ + Enabled: true, + Audience: "https://semaphore.example.com", + }) + + token := signTestJWT(t, key, jwt.MapClaims{ + "sub": "user123", + "email": "user@example.com", + "aud": jwt.ClaimStrings{"https://other.example.com", "https://semaphore.example.com"}, + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + + claims, err := validateProxyJWT(token) + if err != nil { + t.Fatal("expected valid token with audience array, got error:", err) + } + + if claims["email"] != "user@example.com" { + t.Errorf("expected email user@example.com, got %v", claims["email"]) + } +} + +func TestValidateProxyJWT_NoAudienceValidation(t *testing.T) { + key := generateTestKey(t) + setupTestJWTAuth(t, key, &util.JWTAuthConfig{Enabled: true}) + + token := signTestJWT(t, key, jwt.MapClaims{ + "sub": "user123", + "email": "user@example.com", + "aud": "https://anything.example.com", + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + + _, err := validateProxyJWT(token) + if err != nil { + t.Fatal("expected valid token when no audience configured, got error:", err) + } +} + +func TestJWTAuthConfig_Defaults(t *testing.T) { + config := &util.JWTAuthConfig{} + + if config.GetHeader() != "" { + t.Errorf("expected empty default header, got %s", config.GetHeader()) + } + if config.GetEmailClaim() != "email" { + t.Errorf("expected default email claim 'email', got %s", config.GetEmailClaim()) + } + if config.GetNameClaim() != "name" { + t.Errorf("expected default name claim 'name', got %s", config.GetNameClaim()) + } + if config.GetUsernameClaim() != "email" { + t.Errorf("expected default username claim 'email', got %s", config.GetUsernameClaim()) + } +} + +func TestJWTAuthConfig_CustomValues(t *testing.T) { + config := &util.JWTAuthConfig{ + Header: "X-Custom-JWT", + EmailClaim: "mail", + NameClaim: "display_name", + UsernameClaim: "preferred_username", + } + + if config.GetHeader() != "X-Custom-JWT" { + t.Errorf("expected header X-Custom-JWT, got %s", config.GetHeader()) + } + if config.GetEmailClaim() != "mail" { + t.Errorf("expected email claim 'mail', got %s", config.GetEmailClaim()) + } + if config.GetNameClaim() != "display_name" { + t.Errorf("expected name claim 'display_name', got %s", config.GetNameClaim()) + } + if config.GetUsernameClaim() != "preferred_username" { + t.Errorf("expected username claim 'preferred_username', got %s", config.GetUsernameClaim()) + } +} diff --git a/api/router.go b/api/router.go index 67d850d1f..4a40cbf3c 100644 --- a/api/router.go +++ b/api/router.go @@ -95,6 +95,9 @@ func Route( environmentService server.EnvironmentService, subscriptionService pro_interfaces.SubscriptionService, ) *mux.Router { + if util.Config.Auth != nil && util.Config.Auth.JWT != nil && util.Config.Auth.JWT.Enabled { + initJWKSCache(util.Config.Auth.JWT.JWKSURL) + } projectController := &projects.ProjectController{ProjectService: projectService} runnerController := runners.NewRunnerController(store, taskPool, encryptionService) diff --git a/go.mod b/go.mod index 3abc5ee26..33dcad443 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,15 @@ go 1.24.6 require ( github.com/Masterminds/squirrel v1.5.4 + github.com/MicahParks/jwkset v0.11.0 + github.com/MicahParks/keyfunc/v3 v3.8.0 github.com/coreos/go-oidc/v3 v3.17.0 github.com/creack/pty v1.1.24 github.com/go-git/go-git/v5 v5.16.5 github.com/go-gorp/gorp/v3 v3.1.0 github.com/go-ldap/ldap/v3 v3.4.12 github.com/go-sql-driver/mysql v1.9.3 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/go-github v17.0.0+incompatible github.com/gorilla/handlers v1.5.2 github.com/gorilla/mux v1.8.1 @@ -72,6 +75,7 @@ require ( golang.org/x/net v0.49.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/term v0.40.0 // indirect + golang.org/x/time v0.9.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.66.10 // indirect diff --git a/go.sum b/go.sum index 01a393bf1..f2a2605d0 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,10 @@ github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+ github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM= github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.8.0 h1:Hx2dgIjAXGk9slakM6rV9BOeaWDPEXXZ4Us8guNBfds= +github.com/MicahParks/keyfunc/v3 v3.8.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= @@ -59,6 +63,8 @@ github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -197,6 +203,8 @@ golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= diff --git a/util/config.go b/util/config.go index 3508b8fb5..a6f483d37 100644 --- a/util/config.go +++ b/util/config.go @@ -996,6 +996,15 @@ func validateConfig() { panic(err) } + if Config.Auth != nil && Config.Auth.JWT != nil && Config.Auth.JWT.Enabled { + if Config.Auth.JWT.Header == "" { + panic("jwt auth is enabled but header is not configured (set auth.jwt.header or SEMAPHORE_JWT_AUTH_HEADER)") + } + if Config.Auth.JWT.JWKSURL == "" { + panic("jwt auth is enabled but jwks_url is not configured (set auth.jwt.jwks_url or SEMAPHORE_JWT_AUTH_JWKS_URL)") + } + } + if err := validateAccessKeyEncryption(Config.AccessKeyEncryption); err != nil { panic(err) } diff --git a/util/config_auth.go b/util/config_auth.go index c5f4ef577..3794be83f 100644 --- a/util/config_auth.go +++ b/util/config_auth.go @@ -13,7 +13,44 @@ type EmailAuthConfig struct { DisableForOidc bool `json:"disable_for_oidc" env:"SEMAPHORE_EMAIL_2TP_DISABLE_FOR_OIDC"` } +type JWTAuthConfig struct { + Enabled bool `json:"enabled" env:"SEMAPHORE_JWT_AUTH_ENABLED"` + Header string `json:"header" env:"SEMAPHORE_JWT_AUTH_HEADER"` + JWKSURL string `json:"jwks_url" env:"SEMAPHORE_JWT_AUTH_JWKS_URL"` + Audience string `json:"audience" env:"SEMAPHORE_JWT_AUTH_AUDIENCE"` + Issuer string `json:"issuer" env:"SEMAPHORE_JWT_AUTH_ISSUER"` + UsernameClaim string `json:"username_claim" env:"SEMAPHORE_JWT_AUTH_USERNAME_CLAIM"` + NameClaim string `json:"name_claim" env:"SEMAPHORE_JWT_AUTH_NAME_CLAIM"` + EmailClaim string `json:"email_claim" env:"SEMAPHORE_JWT_AUTH_EMAIL_CLAIM"` +} + +func (j *JWTAuthConfig) GetUsernameClaim() string { + if j.UsernameClaim == "" { + return "email" + } + return j.UsernameClaim +} + +func (j *JWTAuthConfig) GetEmailClaim() string { + if j.EmailClaim == "" { + return "email" + } + return j.EmailClaim +} + +func (j *JWTAuthConfig) GetNameClaim() string { + if j.NameClaim == "" { + return "name" + } + return j.NameClaim +} + +func (j *JWTAuthConfig) GetHeader() string { + return j.Header +} + type AuthConfig struct { Totp *TotpConfig `json:"totp,omitempty"` Email *EmailAuthConfig `json:"email,omitempty"` + JWT *JWTAuthConfig `json:"jwt,omitempty"` }