Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio

### New

- TODO ([#XXX](https://github.com/kedacore/keda/issues/XXX))
- **Temporal Scaler**: Add composite metric support (backlog + running workflow count) ([#7459](https://github.com/kedacore/keda/issues/7459))

#### Experimental

Expand Down
75 changes: 61 additions & 14 deletions pkg/scalers/temporal_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"crypto/tls"
"fmt"
"log/slog"
"regexp"
"time"

"github.com/go-logr/logr"
workflowservice "go.temporal.io/api/workflowservice/v1"
sdk "go.temporal.io/sdk/client"
sdklog "go.temporal.io/sdk/log"
"google.golang.org/grpc"
Expand Down Expand Up @@ -35,17 +37,19 @@ type temporalScaler struct {
}

type temporalMetadata struct {
Endpoint string `keda:"name=endpoint, order=triggerMetadata;resolvedEnv"`
Namespace string `keda:"name=namespace, order=triggerMetadata;resolvedEnv, default=default"`
ActivationTargetQueueSize int64 `keda:"name=activationTargetQueueSize, order=triggerMetadata, default=0"`
TargetQueueSize int64 `keda:"name=targetQueueSize, order=triggerMetadata, default=5"`
TaskQueue string `keda:"name=taskQueue, order=triggerMetadata;resolvedEnv"`
QueueTypes []string `keda:"name=queueTypes, order=triggerMetadata, optional"`
BuildID string `keda:"name=buildId, order=triggerMetadata;resolvedEnv, optional"`
AllActive bool `keda:"name=selectAllActive, order=triggerMetadata, default=false"`
Unversioned bool `keda:"name=selectUnversioned, order=triggerMetadata, default=false"`
APIKey string `keda:"name=apiKey, order=authParams;resolvedEnv, optional"`
MinConnectTimeout int `keda:"name=minConnectTimeout, order=triggerMetadata, default=5"`
Endpoint string `keda:"name=endpoint, order=triggerMetadata;resolvedEnv"`
Namespace string `keda:"name=namespace, order=triggerMetadata;resolvedEnv, default=default"`
ActivationTargetQueueSize int64 `keda:"name=activationTargetQueueSize, order=triggerMetadata, default=0"`
TargetQueueSize int64 `keda:"name=targetQueueSize, order=triggerMetadata, default=5"`
TaskQueue string `keda:"name=taskQueue, order=triggerMetadata;resolvedEnv"`
QueueTypes []string `keda:"name=queueTypes, order=triggerMetadata, optional"`
BuildID string `keda:"name=buildId, order=triggerMetadata;resolvedEnv, optional"`
AllActive bool `keda:"name=selectAllActive, order=triggerMetadata, default=false"`
Unversioned bool `keda:"name=selectUnversioned, order=triggerMetadata, default=false"`
IncludeRunningWorkflowCount bool `keda:"name=includeRunningWorkflowCount, order=triggerMetadata, default=false"`
WorkflowTaskQueueForCount string `keda:"name=workflowTaskQueueForCount, order=triggerMetadata;resolvedEnv, optional"`
APIKey string `keda:"name=apiKey, order=authParams;resolvedEnv, optional"`
MinConnectTimeout int `keda:"name=minConnectTimeout, order=triggerMetadata, default=5"`

UnsafeSsl bool `keda:"name=unsafeSsl, order=triggerMetadata, optional"`
Cert string `keda:"name=cert, order=authParams, optional"`
Expand Down Expand Up @@ -73,6 +77,10 @@ func (a *temporalMetadata) Validate() error {
return fmt.Errorf("minConnectTimeout must be a positive number")
}

if a.WorkflowTaskQueueForCount != "" && !a.IncludeRunningWorkflowCount {
return fmt.Errorf("workflowTaskQueueForCount has no effect unless includeRunningWorkflowCount is true")
}

return nil
}

Expand Down Expand Up @@ -127,14 +135,24 @@ func (s *temporalScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpe
}

func (s *temporalScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
queueSize, err := s.getQueueSize(ctx)
backlog, err := s.getQueueSize(ctx)
if err != nil {
return nil, false, fmt.Errorf("failed to get Temporal queue size: %w", err)
}

metric := GenerateMetricInMili(metricName, float64(queueSize))
metric := GenerateMetricInMili(metricName, float64(backlog))

return []external_metrics.ExternalMetricValue{metric}, queueSize > s.metadata.ActivationTargetQueueSize, nil
isActive := backlog > s.metadata.ActivationTargetQueueSize
if !isActive && s.metadata.IncludeRunningWorkflowCount {
runningCount, err := s.getRunningWorkflowCount(ctx)
if err != nil {
s.logger.V(1).Info("failed to get running workflow count, skipping for activity check", "error", err)
} else {
isActive = runningCount > 0
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this intentionally 0 or should it be > s.metadata.ActivationTargetQueueSize?

}
}

return []external_metrics.ExternalMetricValue{metric}, isActive, nil
}

func (s *temporalScaler) getQueueSize(ctx context.Context) (int64, error) {
Expand Down Expand Up @@ -162,6 +180,35 @@ func (s *temporalScaler) getQueueSize(ctx context.Context) (int64, error) {
return getCombinedBacklogCount(resp), nil
}

// validTaskQueueName matches task queue names containing only safe characters.
// Temporal task queue names support alphanumerics, hyphens, underscores, dots, forward slashes, and colons.
// Rejecting anything outside this set prevents query injection since the SDK offers no parameterized queries.
var validTaskQueueName = regexp.MustCompile(`^[a-zA-Z0-9\-_./:]+$`)

// getRunningWorkflowCount returns the approximate number of running workflow executions
// for the task queue (or workflowTaskQueueForCount if set). Used to avoid premature
// scale-down when workers are fast and backlog is often zero.
func (s *temporalScaler) getRunningWorkflowCount(ctx context.Context) (int64, error) {
taskQueue := s.metadata.WorkflowTaskQueueForCount
if taskQueue == "" {
taskQueue = s.metadata.TaskQueue
}
if !validTaskQueueName.MatchString(taskQueue) {
return 0, fmt.Errorf("task queue name %q contains characters not allowed in visibility queries", taskQueue)
}
query := fmt.Sprintf("ExecutionStatus = 'Running' AND TaskQueue = '%s'", taskQueue)

req := &workflowservice.CountWorkflowExecutionsRequest{
Namespace: s.metadata.Namespace,
Query: query,
}
resp, err := s.tcl.CountWorkflow(ctx, req)
Comment thread
rickbrouwer marked this conversation as resolved.
if err != nil {
return 0, fmt.Errorf("count workflow: %w", err)
}
return resp.GetCount(), nil
}

func getQueueTypes(queueTypes []string) []sdk.TaskQueueType {
var taskQueueTypes []sdk.TaskQueueType
for _, t := range queueTypes {
Expand Down
Loading
Loading