Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 1 addition & 109 deletions dataproxy/service/dataproxy_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ import (
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger/triggerconnect"
workflowpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect"
)

type Service struct {
Expand All @@ -39,17 +37,15 @@ type Service struct {
dataStore *storage.DataStore
taskClient taskconnect.TaskServiceClient
triggerClient triggerconnect.TriggerServiceClient
runClient workflowconnect.RunServiceClient
}

// NewService creates a new DataProxyService instance.
func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore, taskClient taskconnect.TaskServiceClient, triggerClient triggerconnect.TriggerServiceClient, runClient workflowconnect.RunServiceClient) *Service {
func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore, taskClient taskconnect.TaskServiceClient, triggerClient triggerconnect.TriggerServiceClient) *Service {
return &Service{
cfg: cfg,
dataStore: dataStore,
taskClient: taskClient,
triggerClient: triggerClient,
runClient: runClient,
}
}

Expand Down Expand Up @@ -274,110 +270,6 @@ func (s *Service) UploadInputs(
}), nil
}

// CreateDownloadLink generates signed URL(s) for downloading an artifact associated with a run action.
//
// The handler validates the request, resolves the artifact's native storage URL
// through resolveArtifactURL, confirms the object exists with a Head call, and
// then asks the data store for a GET-scoped signed URL.
//
// Errors are returned as connect errors:
//   - CodeInvalidArgument: request validation failed, artifact_type is
//     unspecified, or expires_in is malformed or exceeds the configured maximum.
//   - CodeNotFound: nothing exists at the resolved storage location.
//   - CodeInternal: the existence check or the URL signing failed.
//   - any error produced by resolveArtifactURL is returned unchanged.
func (s *Service) CreateDownloadLink(
	ctx context.Context,
	req *connect.Request[dataproxy.CreateDownloadLinkRequest],
) (*connect.Response[dataproxy.CreateDownloadLinkResponse], error) {
	logger.Infof(ctx, "CreateDownloadLink request received")

	if err := req.Msg.Validate(); err != nil {
		logger.Errorf(ctx, "Invalid CreateDownloadLink request: %v", err)
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	if req.Msg.GetArtifactType() == dataproxy.ArtifactType_ARTIFACT_TYPE_UNSPECIFIED {
		return nil, connect.NewError(connect.CodeInvalidArgument,
			fmt.Errorf("artifact_type is required"))
	}

	// expires_in is optional: fall back to the configured maximum when it is
	// absent. The default is kept in a local so the caller's request message is
	// not modified.
	maxExpiresIn := s.cfg.Download.MaxExpiresIn.Duration
	expiresIn := maxExpiresIn
	if requested := req.Msg.GetExpiresIn(); requested != nil {
		// AsDuration silently saturates on out-of-range values, so reject a
		// malformed Duration explicitly before converting it.
		if err := requested.CheckValid(); err != nil {
			return nil, connect.NewError(connect.CodeInvalidArgument,
				fmt.Errorf("invalid expires_in: %w", err))
		}
		expiresIn = requested.AsDuration()
		// The API contract states that a lifetime above the platform maximum
		// is rejected rather than silently honoured or clamped.
		if expiresIn > maxExpiresIn {
			return nil, connect.NewError(connect.CodeInvalidArgument,
				fmt.Errorf("expires_in [%v] cannot exceed max allowed expiration [%v]", expiresIn, maxExpiresIn))
		}
	}

	nativeURL, err := s.resolveArtifactURL(ctx, req.Msg)
	if err != nil {
		return nil, err
	}

	// Check existence first so a missing artifact yields NotFound instead of a
	// signed URL that points at nothing.
	ref := storage.DataReference(nativeURL)
	meta, err := s.dataStore.Head(ctx, ref)
	if err != nil {
		logger.Errorf(ctx, "Failed to head artifact at [%s]: %v", nativeURL, err)
		return nil, connect.NewError(connect.CodeInternal,
			fmt.Errorf("failed to check artifact existence: %w", err))
	}
	if !meta.Exists() {
		return nil, connect.NewError(connect.CodeNotFound,
			fmt.Errorf("artifact not found at [%s]", nativeURL))
	}

	signedResp, err := s.dataStore.CreateSignedURL(ctx, ref, storage.SignedURLProperties{
		Scope:     stow.ClientMethodGet,
		ExpiresIn: expiresIn,
	})
	if err != nil {
		logger.Errorf(ctx, "Failed to create signed URL for [%s]: %v", nativeURL, err)
		return nil, connect.NewError(connect.CodeInternal,
			fmt.Errorf("failed to create signed URL: %w", err))
	}

	// The reported expiry is computed locally from the same duration that was
	// handed to the signer.
	expiresAt := timestamppb.New(time.Now().Add(expiresIn))
	return connect.NewResponse(&dataproxy.CreateDownloadLinkResponse{
		PreSignedUrls: &dataproxy.PreSignedURLs{
			SignedUrl: []string{signedResp.URL.String()},
			ExpiresAt: expiresAt,
		},
	}), nil
}

// resolveArtifactURL resolves the native storage URL for the requested artifact type and source.
//
// Only the action-attempt source is supported: the action's details are fetched
// from the run service, the attempt whose number matches the request is
// located, and the URI for the requested artifact type is read from that
// attempt's outputs.
//
// Errors are returned as connect errors:
//   - CodeInvalidArgument: unsupported source or unsupported artifact type.
//   - CodeNotFound: the attempt does not exist, or it has no URI for the
//     requested artifact type.
//   - the run service's own status code when GetActionDetails fails, so that
//     transient or permission failures are not misreported as a missing action.
func (s *Service) resolveArtifactURL(ctx context.Context, req *dataproxy.CreateDownloadLinkRequest) (string, error) {
	attemptIDEnvelope, ok := req.GetSource().(*dataproxy.CreateDownloadLinkRequest_ActionAttemptId)
	if !ok {
		return "", connect.NewError(connect.CodeInvalidArgument,
			fmt.Errorf("unsupported source type"))
	}

	attemptID := attemptIDEnvelope.ActionAttemptId
	actionResp, err := s.runClient.GetActionDetails(ctx, connect.NewRequest(&workflowpb.GetActionDetailsRequest{
		ActionId: attemptID.GetActionId(),
	}))
	if err != nil {
		logger.Errorf(ctx, "Failed to get action details for %v: %v", attemptID.GetActionId(), err)
		// Propagate the upstream code (NotFound, Unavailable, PermissionDenied, ...)
		// instead of collapsing every failure into NotFound.
		return "", connect.NewError(connect.CodeOf(err),
			fmt.Errorf("failed to get action details: %w", err))
	}

	// Find the matching attempt by attempt number.
	var matchedAttempt *workflowpb.ActionAttempt
	for _, attempt := range actionResp.Msg.GetDetails().GetAttempts() {
		if attempt.GetAttempt() == attemptID.GetAttempt() {
			matchedAttempt = attempt
			break
		}
	}
	if matchedAttempt == nil {
		return "", connect.NewError(connect.CodeNotFound,
			fmt.Errorf("attempt %d not found for action [%v]", attemptID.GetAttempt(), attemptID.GetActionId()))
	}

	switch req.GetArtifactType() {
	case dataproxy.ArtifactType_ARTIFACT_TYPE_REPORT:
		reportURI := matchedAttempt.GetOutputs().GetReportUri()
		// An empty URI is reported as NotFound for this attempt.
		if reportURI == "" {
			return "", connect.NewError(connect.CodeNotFound,
				fmt.Errorf("no report URI found for action [%v] attempt %d", attemptID.GetActionId(), attemptID.GetAttempt()))
		}
		return reportURI, nil
	default:
		return "", connect.NewError(connect.CodeInvalidArgument,
			fmt.Errorf("unsupported artifact type: %v", req.GetArtifactType()))
	}
}

// resolveTaskTemplate resolves the task template from the request's task oneof.
func (s *Service) resolveTaskTemplate(ctx context.Context, req *dataproxy.UploadInputsRequest) (*flyteIdlCore.TaskTemplate, error) {
switch t := req.Task.(type) {
Expand Down
8 changes: 4 additions & 4 deletions dataproxy/service/dataproxy_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestCreateUploadLocation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStore(t)
service := NewService(cfg, mockStore, nil, nil, nil)
service := NewService(cfg, mockStore, nil, nil)

req := &connect.Request[dataproxy.CreateUploadLocationRequest]{
Msg: tt.req,
Expand Down Expand Up @@ -218,7 +218,7 @@ func TestCheckFileExists(t *testing.T) {
mockStore = setupMockDataStoreWithExistingFile(t, tt.existingFileMD5)
}

service := NewService(cfg, mockStore, nil, nil, nil)
service := NewService(cfg, mockStore, nil, nil)
storagePath := storage.DataReference("s3://test-bucket/uploads/test-project/test-domain/test-root/test-file.txt")

err := service.checkFileExists(ctx, storagePath, tt.req)
Expand Down Expand Up @@ -296,7 +296,7 @@ func TestConstructStoragePath(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStore(t)
service := NewService(cfg, mockStore, nil, nil, nil)
service := NewService(cfg, mockStore, nil, nil)

path, err := service.constructStoragePath(ctx, tt.req)

Expand Down Expand Up @@ -453,7 +453,7 @@ func TestUploadInputs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStoreWithWriteProtobuf(t)
svc := NewService(cfg, mockStore, nil, nil, nil)
svc := NewService(cfg, mockStore, nil, nil)

req := &connect.Request[dataproxy.UploadInputsRequest]{
Msg: tt.req,
Expand Down
4 changes: 1 addition & 3 deletions dataproxy/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger/triggerconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect"

"github.com/flyteorg/flyte/v2/flytestdlib/logger"
)
Expand All @@ -25,9 +24,8 @@ func Setup(ctx context.Context, sc *app.SetupContext) error {
baseURL := sc.BaseURL
taskClient := taskconnect.NewTaskServiceClient(http.DefaultClient, baseURL)
triggerClient := triggerconnect.NewTriggerServiceClient(http.DefaultClient, baseURL)
runClient := workflowconnect.NewRunServiceClient(http.DefaultClient, baseURL)

svc := service.NewService(*cfg, sc.DataStore, taskClient, triggerClient, runClient)
svc := service.NewService(*cfg, sc.DataStore, taskClient, triggerClient)

path, handler := dataproxyconnect.NewDataProxyServiceHandler(svc)
sc.Mux.Handle(path, handler)
Expand Down
57 changes: 0 additions & 57 deletions flyteidl2/dataproxy/dataproxy_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ import "protoc-gen-openapiv2/options/annotations.proto";

option go_package = "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy";

// ArtifactType defines the type of artifact to be downloaded.
enum ArtifactType {
// ARTIFACT_TYPE_UNSPECIFIED is the zero value. CreateDownloadLinkRequest excludes it
// through its not_in: [0] validation rule, so it is never a valid choice for a download.
ARTIFACT_TYPE_UNSPECIFIED = 0;
// ARTIFACT_TYPE_REPORT refers to the HTML report file optionally generated after a task finishes executing.
ARTIFACT_TYPE_REPORT = 1;
}

// DataProxyService provides an interface for managing data uploads and downloads.
service DataProxyService {
// CreateUploadLocation generates a signed URL for uploading data to the configured storage backend.
Expand All @@ -47,19 +40,6 @@ service DataProxyService {
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {description: "Uploads inputs for a given run or project and returns a URI and cache key that can be used to reference these inputs at runtime."};
}

// CreateDownloadLink generates signed URL(s) for downloading a given artifact.
// Over HTTP it is exposed as a POST on two paths (a plain one and an org-prefixed one);
// in both bindings the entire request message is taken from the request body.
rpc CreateDownloadLink(CreateDownloadLinkRequest) returns (CreateDownloadLinkResponse) {
option (google.api.http) = {
post: "/api/v1/dataproxy/artifact_urn/download"
body: "*"
additional_bindings: {
post: "/api/v1/org/dataproxy/artifact_urn/download"
body: "*"
}
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {description: "Creates signed URL(s) for downloading an artifact associated with a run action."};
}
}

// CreateUploadLocationRequest specifies the request for the CreateUploadLocation API.
Expand Down Expand Up @@ -160,40 +140,3 @@ message UploadInputsRequest {
message UploadInputsResponse {
common.OffloadedInputData offloaded_input_data = 1;
}

// PreSignedURLs contains a list of signed URLs for downloading artifacts,
// together with the single expiration time that applies to all of them.
message PreSignedURLs {
// SignedUrl holds the pre-signed URLs for downloading the artifact.
repeated string signed_url = 1;

// ExpiresAt defines when the signed URLs will expire.
google.protobuf.Timestamp expires_at = 2;
}

// CreateDownloadLinkRequest specifies the request for the CreateDownloadLink API.
message CreateDownloadLinkRequest {
// ArtifactType is the type of artifact to download. Validation rejects the zero value
// (ARTIFACT_TYPE_UNSPECIFIED), so callers must set it explicitly.
// +required
ArtifactType artifact_type = 1 [(buf.validate.field).enum = {
not_in: [0]
}];

// Source identifies the action attempt whose artifact is to be downloaded.
// The oneof is marked required, so exactly one member must be set; an action
// attempt is the only member defined here.
oneof source {
option (buf.validate.oneof).required = true;

// ActionAttemptId identifies the specific attempt of a run action that produced the artifact.
common.ActionAttemptIdentifier action_attempt_id = 2;
}

// ExpiresIn defines the requested expiration duration for the generated URLs. The request will be
// rejected if this exceeds the platform's configured maximum.
// +optional. When unset, the default value comes from the global config.
google.protobuf.Duration expires_in = 3;
}

// CreateDownloadLinkResponse specifies the response for the CreateDownloadLink API.
message CreateDownloadLinkResponse {
// PreSignedUrls contains the signed URLs and their shared expiration time
// (see PreSignedURLs).
PreSignedURLs pre_signed_urls = 1;
}
Loading
Loading