diff --git a/otel/otel.go b/otel/otel.go index ed65bb16d0..90404efc11 100644 --- a/otel/otel.go +++ b/otel/otel.go @@ -8,10 +8,13 @@ import ( "errors" "fmt" "os" + "slices" + "strings" "time" "go.opentelemetry.io/contrib/exporters/autoexport" "go.opentelemetry.io/contrib/propagators/autoprop" + "go.opentelemetry.io/contrib/propagators/aws/xray" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" @@ -23,6 +26,7 @@ import ( "github.com/bombsimon/logrusr/v4" "github.com/sirupsen/logrus" + "github.com/zalando/skipper/otel/xxray" ) var log = logrus.WithField("package", "otel") @@ -53,6 +57,10 @@ type BatchSpanProcessor struct { MaxExportBatchSize int `yaml:"maxExportBatchSize"` } +func init() { + autoprop.RegisterTextMapPropagator("xxray", xxray.NewPropagator()) +} + // Init bootstraps the OpenTelemetry pipeline using environment variables and provided options. // Make sure to call shutdown for proper cleanup if err is nil. // @@ -129,17 +137,22 @@ func Init(ctx context.Context, o *Options) (shutdown func(context.Context) error return handleErr(err) } - tracerProvider := trace.NewTracerProvider(batcherOpt, resourceOpt) - shutdownFuncs = append(shutdownFuncs, tracerProvider.Shutdown) - - otel.SetTracerProvider(tracerProvider) - propagator, err := textMapPropagator(o) if err != nil { return handleErr(err) } otel.SetTextMapPropagator(propagator) + var idGenerator trace.IDGenerator + if hasPropagator("xray", o) || hasPropagator("xxray", o) { + idGenerator = xray.NewIDGenerator() + } + + tracerProvider := trace.NewTracerProvider(batcherOpt, resourceOpt, trace.WithIDGenerator(idGenerator)) + shutdownFuncs = append(shutdownFuncs, tracerProvider.Shutdown) + + otel.SetTracerProvider(tracerProvider) + otel.SetErrorHandler(otel.ErrorHandlerFunc(func(err error) { log.Error(err) })) otel.SetLogger(logrusr.New(log)) @@ -246,6 +259,14 @@ func textMapPropagator(o *Options) (propagation.TextMapPropagator, error) { } } +func hasPropagator(name string, o *Options) bool { + if len(o.Propagators) > 0 { + return slices.Contains(o.Propagators, name) + } else { + return slices.Contains(strings.Split(os.Getenv("OTEL_PROPAGATORS"), ","), name) + } +} + func skipperDebugSpanExporter(ctx context.Context) (trace.SpanExporter, error) { return stdouttrace.New(stdouttrace.WithWriter(writerFunc(func(p []byte) (int, error) { log.Debugf("Span: %s", p) diff --git a/otel/xxray/propagator.go b/otel/xxray/propagator.go new file mode 100644 index 0000000000..1d608904ea --- /dev/null +++ b/otel/xxray/propagator.go @@ -0,0 +1,159 @@ +package xxray + +import ( + "context" + "errors" + "strings" + + "go.opentelemetry.io/contrib/propagators/aws/xray" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" +) + +// Propagator is an AWS X-Ray trace propagator that extends the standard [xray.Propagator]. +// Standard propagator requires both Root and Parent keys to be present in the X-Amzn-Trace-Id header +// to successfully extract span context. +// AWS [ALB request tracing] creates X-Amzn-Trace-Id header with only Root field - this propagator +// can re-use it to obtain trace ID value. +// +// [ALB request tracing]: https://docs.aws.amazon.com/elasticloadbalancing/latest/application/load-balancer-request-tracing.html +type Propagator struct { + xray.Propagator + idGenerator *xray.IDGenerator +} + +func NewPropagator() *Propagator { + return &Propagator{idGenerator: xray.NewIDGenerator()} +} + +func (p *Propagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context { + newCtx := p.Propagator.Extract(ctx, carrier) + // If failed to extract span context, try to re-use trace id + if newCtx == ctx { + if header := carrier.Get(traceHeaderKey); header != "" { + tsc, err := extract(header) + if err == nil && tsc.TraceID().IsValid() { + // Re-use only trace id + return trace.ContextWithRemoteSpanContext(ctx, trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: tsc.TraceID(), + SpanID: p.idGenerator.NewSpanID(ctx, tsc.TraceID()), + })) + } + } + } + return newCtx +} + +// The rest is copied from https://github.com/open-telemetry/opentelemetry-go-contrib/blob/80c9316336ebb4f4c67d8e1011a3add889213fb7/propagators/aws/xray/propagator.go +const ( + traceHeaderKey = "X-Amzn-Trace-Id" + traceHeaderDelimiter = ";" + kvDelimiter = "=" + traceIDKey = "Root" + sampleFlagKey = "Sampled" + parentIDKey = "Parent" + traceIDVersion = "1" + traceIDDelimiter = "-" + isSampled = "1" + notSampled = "0" + + traceFlagNone = 0x0 + traceFlagSampled = 0x1 << 0 + traceIDLength = 35 + traceIDDelimitterIndex1 = 1 + traceIDDelimitterIndex2 = 10 + traceIDFirstPartLength = 8 + sampledFlagLength = 1 +) + +var ( + empty = trace.SpanContext{} + errInvalidTraceHeader = errors.New("invalid X-Amzn-Trace-Id header value, should contain 3 different part separated by ;") + errMalformedTraceID = errors.New("cannot decode trace ID from header") + errLengthTraceIDHeader = errors.New("incorrect length of X-Ray trace ID found, 35 character length expected") + errInvalidTraceIDVersion = errors.New("invalid X-Ray trace ID header found, does not have valid trace ID version") + errInvalidSpanIDLength = errors.New("invalid span ID length, must be 16") +) + +// extract extracts Span Context from context. +func extract(headerVal string) (trace.SpanContext, error) { + var ( + scc = trace.SpanContextConfig{} + err error + delimiterIndex int + part string + ) + pos := 0 + for pos < len(headerVal) { + delimiterIndex = indexOf(headerVal, traceHeaderDelimiter, pos) + if delimiterIndex >= 0 { + part = headerVal[pos:delimiterIndex] + pos = delimiterIndex + 1 + } else { + // last part + part = strings.TrimSpace(headerVal[pos:]) + pos = len(headerVal) + } + equalsIndex := strings.Index(part, kvDelimiter) + if equalsIndex < 0 { + return empty, errInvalidTraceHeader + } + value := part[equalsIndex+1:] + switch { + case strings.HasPrefix(part, traceIDKey): + scc.TraceID, err = parseTraceID(value) + if err != nil { + return empty, err + } + case strings.HasPrefix(part, parentIDKey): + // extract parentId + scc.SpanID, err = trace.SpanIDFromHex(value) + if err != nil { + return empty, errInvalidSpanIDLength + } + case strings.HasPrefix(part, sampleFlagKey): + // extract traceflag + scc.TraceFlags = parseTraceFlag(value) + } + } + return trace.NewSpanContext(scc), nil +} + +// indexOf returns position of the first occurrence of a substr in str starting at pos index. +func indexOf(str, substr string, pos int) int { + index := strings.Index(str[pos:], substr) + if index > -1 { + index += pos + } + return index +} + +// parseTraceID returns trace ID if valid else return invalid trace ID. +func parseTraceID(xrayTraceID string) (trace.TraceID, error) { + if len(xrayTraceID) != traceIDLength { + return empty.TraceID(), errLengthTraceIDHeader + } + if !strings.HasPrefix(xrayTraceID, traceIDVersion) { + return empty.TraceID(), errInvalidTraceIDVersion + } + + if xrayTraceID[traceIDDelimitterIndex1:traceIDDelimitterIndex1+1] != traceIDDelimiter || + xrayTraceID[traceIDDelimitterIndex2:traceIDDelimitterIndex2+1] != traceIDDelimiter { + return empty.TraceID(), errMalformedTraceID + } + + epochPart := xrayTraceID[traceIDDelimitterIndex1+1 : traceIDDelimitterIndex2] + uniquePart := xrayTraceID[traceIDDelimitterIndex2+1 : traceIDLength] + + result := epochPart + uniquePart + return trace.TraceIDFromHex(result) +} + +// parseTraceFlag returns a parsed trace flag. +func parseTraceFlag(xraySampledFlag string) trace.TraceFlags { + // Use a direct comparison here (#7262). + if xraySampledFlag == isSampled { + return trace.FlagsSampled + } + return trace.FlagsSampled.WithSampled(false) +}