Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions dataproxy/service/dataproxy_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"connectrpc.com/connect"
"github.com/flyteorg/stow"
"github.com/samber/lo"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
Expand Down Expand Up @@ -435,6 +436,67 @@ func filterInputs(inputs *task.Inputs, ignoreVars []string) *task.Inputs {
return &task.Inputs{Literals: filtered}
}

// GetActionData gets input and output data for an action by calling RunService for URIs
// and reading the data from storage.
func (s *Service) GetActionData(
ctx context.Context,
req *connect.Request[dataproxy.GetActionDataRequest],
) (*connect.Response[dataproxy.GetActionDataResponse], error) {
actionId := req.Msg.GetActionId()

urisResp, err := s.runClient.GetActionDataURIs(ctx, connect.NewRequest(&workflowpb.GetActionDataURIsRequest{
ActionId: actionId,
}))
if err != nil {
return nil, err
}

resp := &dataproxy.GetActionDataResponse{
Inputs: &task.Inputs{},
Outputs: &task.Outputs{},
}

group, groupCtx := errgroup.WithContext(ctx)

if urisResp.Msg.GetInputsUri() != "" {
group.Go(func() error {
baseRef := storage.DataReference(urisResp.Msg.GetInputsUri())
inputRef, err := s.dataStore.ConstructReference(groupCtx, baseRef, "inputs.pb")
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to construct input ref: %w", err))
}
logger.Infof(groupCtx, "GetActionData: reading inputs from %s", inputRef)
if err := s.dataStore.ReadProtobuf(groupCtx, inputRef, resp.Inputs); err != nil {
logger.Errorf(groupCtx, "GetActionData: failed to read inputs from %s: %v", inputRef, err)
return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read inputs from %s: %w", inputRef, err))
}
return nil
})
}

if urisResp.Msg.GetOutputsUri() != "" {
group.Go(func() error {
outputRef := storage.DataReference(urisResp.Msg.GetOutputsUri())
logger.Infof(groupCtx, "GetActionData: reading outputs from %s", outputRef)
var inputsOrOutputs task.Inputs
if err := s.dataStore.ReadProtobuf(groupCtx, outputRef, &inputsOrOutputs); err != nil {
logger.Errorf(groupCtx, "GetActionData: failed to read outputs from %s: %v", outputRef, err)
return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read outputs from %s: %w", outputRef, err))
}
resp.Outputs = &task.Outputs{
Literals: inputsOrOutputs.GetLiterals(),
}
return nil
})
}

if err := group.Wait(); err != nil {
return nil, err
}

return connect.NewResponse(resp), nil
}

// hashInputsProto computes a deterministic FNV-64a hash of the serialized inputs.
func hashInputsProto(inputs proto.Message) (string, error) {
marshaller := proto.MarshalOptions{Deterministic: true}
Expand Down
20 changes: 20 additions & 0 deletions flyteidl2/dataproxy/dataproxy_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ service DataProxyService {
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {description: "Creates signed URL(s) for downloading an artifact associated with a run action."};
}

// Get input and output data for an action.
rpc GetActionData(GetActionDataRequest) returns (GetActionDataResponse) {
option idempotency_level = NO_SIDE_EFFECTS;
}
}

// CreateUploadLocationRequest specifies the request for the CreateUploadLocation API.
Expand Down Expand Up @@ -197,3 +202,18 @@ message CreateDownloadLinkResponse {
// PreSignedUrls contains the signed URLs and their expiration time.
PreSignedURLs pre_signed_urls = 1;
}

// Request message for querying action data.
message GetActionDataRequest {
// Action to query.
common.ActionIdentifier action_id = 1 [(buf.validate.field).required = true];
}

// Response message for querying action data.
message GetActionDataResponse {
// Inputs for the action.
task.Inputs inputs = 1;

// Outputs for the action.
task.Outputs outputs = 2;
}
23 changes: 22 additions & 1 deletion flyteidl2/workflow/run_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ service RunService {
// Stream detailed information updates about an action. The call will terminate when the action reaches a terminal phase.
rpc WatchActionDetails(WatchActionDetailsRequest) returns (stream WatchActionDetailsResponse) {}

// Get input and output for an action.
// Deprecated: Use DataProxyService.GetActionData instead.
rpc GetActionData(GetActionDataRequest) returns (GetActionDataResponse) {
option deprecated = true;
option idempotency_level = NO_SIDE_EFFECTS;
}

Expand Down Expand Up @@ -69,6 +70,11 @@ service RunService {

// Stream updates for task groups based on the provided filter criteria.
rpc WatchGroups(WatchGroupsRequest) returns (stream WatchGroupsResponse) {}

// Get the storage URIs for an action's input and output data.
rpc GetActionDataURIs(GetActionDataURIsRequest) returns (GetActionDataURIsResponse) {
option idempotency_level = NO_SIDE_EFFECTS;
}
}

// Request message for creating a run.
Expand Down Expand Up @@ -193,6 +199,21 @@ message GetActionDataResponse {
task.Outputs outputs = 2;
}

// Request message for getting action data URIs.
message GetActionDataURIsRequest {
// Action to query.
common.ActionIdentifier action_id = 1 [(buf.validate.field).required = true];
}

// Response message for getting action data URIs.
message GetActionDataURIsResponse {
// URI for the action's input data.
string inputs_uri = 1;

// URI for the action's output data. Empty if action hasn't succeeded or has no outputs.
string outputs_uri = 2;
}

// Request message for listing runs.
message ListRunsRequest {
reserved 3, 5; // Deprecated
Expand Down
Loading
Loading