Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion dataproxy/service/dataproxy_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"connectrpc.com/connect"
"github.com/flyteorg/stow"
"golang.org/x/sync/errgroup"
"github.com/samber/lo"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
Expand All @@ -28,6 +29,8 @@ import (
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger/triggerconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect"
)

type Service struct {
Expand All @@ -37,15 +40,17 @@ type Service struct {
dataStore *storage.DataStore
taskClient taskconnect.TaskServiceClient
triggerClient triggerconnect.TriggerServiceClient
runClient workflowconnect.RunServiceClient
}

// NewService creates a new DataProxyService instance.
func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore, taskClient taskconnect.TaskServiceClient, triggerClient triggerconnect.TriggerServiceClient) *Service {
func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore, taskClient taskconnect.TaskServiceClient, triggerClient triggerconnect.TriggerServiceClient, runClient workflowconnect.RunServiceClient) *Service {
return &Service{
cfg: cfg,
dataStore: dataStore,
taskClient: taskClient,
triggerClient: triggerClient,
runClient: runClient,
}
}

Expand Down Expand Up @@ -327,6 +332,67 @@ func filterInputs(inputs *task.Inputs, ignoreVars []string) *task.Inputs {
return &task.Inputs{Literals: filtered}
}

// GetActionData gets input and output data for an action by calling RunService for URIs
// and reading the data from storage.
func (s *Service) GetActionData(
ctx context.Context,
req *connect.Request[dataproxy.GetActionDataRequest],
) (*connect.Response[dataproxy.GetActionDataResponse], error) {
actionId := req.Msg.GetActionId()

urisResp, err := s.runClient.GetActionDataURIs(ctx, connect.NewRequest(&workflow.GetActionDataURIsRequest{
ActionId: actionId,
}))
if err != nil {
return nil, err
}

resp := &dataproxy.GetActionDataResponse{
Inputs: &task.Inputs{},
Outputs: &task.Outputs{},
}

group, groupCtx := errgroup.WithContext(ctx)

if urisResp.Msg.GetInputsUri() != "" {
group.Go(func() error {
baseRef := storage.DataReference(urisResp.Msg.GetInputsUri())
inputRef, err := s.dataStore.ConstructReference(groupCtx, baseRef, "inputs.pb")
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to construct input ref: %w", err))
}
logger.Infof(groupCtx, "GetActionData: reading inputs from %s", inputRef)
if err := s.dataStore.ReadProtobuf(groupCtx, inputRef, resp.Inputs); err != nil {
logger.Errorf(groupCtx, "GetActionData: failed to read inputs from %s: %v", inputRef, err)
return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read inputs from %s: %w", inputRef, err))
}
return nil
})
}

if urisResp.Msg.GetOutputsUri() != "" {
group.Go(func() error {
outputRef := storage.DataReference(urisResp.Msg.GetOutputsUri())
logger.Infof(groupCtx, "GetActionData: reading outputs from %s", outputRef)
var inputsOrOutputs task.Inputs
if err := s.dataStore.ReadProtobuf(groupCtx, outputRef, &inputsOrOutputs); err != nil {
logger.Errorf(groupCtx, "GetActionData: failed to read outputs from %s: %v", outputRef, err)
return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read outputs from %s: %w", outputRef, err))
}
resp.Outputs = &task.Outputs{
Literals: inputsOrOutputs.GetLiterals(),
}
return nil
})
}

if err := group.Wait(); err != nil {
return nil, err
}

return connect.NewResponse(resp), nil
}

// hashInputsProto computes a deterministic FNV-64a hash of the serialized inputs.
func hashInputsProto(inputs proto.Message) (string, error) {
marshaller := proto.MarshalOptions{Deterministic: true}
Expand Down
8 changes: 4 additions & 4 deletions dataproxy/service/dataproxy_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestCreateUploadLocation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStore(t)
service := NewService(cfg, mockStore, nil, nil)
service := NewService(cfg, mockStore, nil, nil, nil)

req := &connect.Request[dataproxy.CreateUploadLocationRequest]{
Msg: tt.req,
Expand Down Expand Up @@ -218,7 +218,7 @@ func TestCheckFileExists(t *testing.T) {
mockStore = setupMockDataStoreWithExistingFile(t, tt.existingFileMD5)
}

service := NewService(cfg, mockStore, nil, nil)
service := NewService(cfg, mockStore, nil, nil, nil)
storagePath := storage.DataReference("s3://test-bucket/uploads/test-project/test-domain/test-root/test-file.txt")

err := service.checkFileExists(ctx, storagePath, tt.req)
Expand Down Expand Up @@ -296,7 +296,7 @@ func TestConstructStoragePath(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStore(t)
service := NewService(cfg, mockStore, nil, nil)
service := NewService(cfg, mockStore, nil, nil, nil)

path, err := service.constructStoragePath(ctx, tt.req)

Expand Down Expand Up @@ -453,7 +453,7 @@ func TestUploadInputs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStoreWithWriteProtobuf(t)
svc := NewService(cfg, mockStore, nil, nil)
svc := NewService(cfg, mockStore, nil, nil, nil)

req := &connect.Request[dataproxy.UploadInputsRequest]{
Msg: tt.req,
Expand Down
4 changes: 3 additions & 1 deletion dataproxy/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger/triggerconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect"

"github.com/flyteorg/flyte/v2/flytestdlib/logger"
)
Expand All @@ -23,8 +24,9 @@ func Setup(ctx context.Context, sc *app.SetupContext) error {
baseURL := sc.BaseURL
taskClient := taskconnect.NewTaskServiceClient(http.DefaultClient, baseURL)
triggerClient := triggerconnect.NewTriggerServiceClient(http.DefaultClient, baseURL)
runClient := workflowconnect.NewRunServiceClient(http.DefaultClient, baseURL)

svc := service.NewService(*cfg, sc.DataStore, taskClient, triggerClient)
svc := service.NewService(*cfg, sc.DataStore, taskClient, triggerClient, runClient)

path, handler := dataproxyconnect.NewDataProxyServiceHandler(svc)
sc.Mux.Handle(path, handler)
Expand Down
20 changes: 20 additions & 0 deletions flyteidl2/dataproxy/dataproxy_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ service DataProxyService {
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {description: "Uploads inputs for a given run or project and returns a URI and cache key that can be used to reference these inputs at runtime."};
}

// Get input and output data for an action.
rpc GetActionData(GetActionDataRequest) returns (GetActionDataResponse) {
option idempotency_level = NO_SIDE_EFFECTS;
}
}

// CreateUploadLocationRequest specifies the request for the CreateUploadLocation API.
Expand Down Expand Up @@ -140,3 +145,18 @@ message UploadInputsRequest {
message UploadInputsResponse {
common.OffloadedInputData offloaded_input_data = 1;
}

// Request message for querying action data.
message GetActionDataRequest {
// Action to query.
common.ActionIdentifier action_id = 1 [(buf.validate.field).required = true];
}

// Response message for querying action data.
message GetActionDataResponse {
// Inputs for the action.
task.Inputs inputs = 1;

// Outputs for the action.
task.Outputs outputs = 2;
}
23 changes: 22 additions & 1 deletion flyteidl2/workflow/run_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ service RunService {
// Stream detailed information updates about an action. The call will terminate when the action reaches a terminal phase.
rpc WatchActionDetails(WatchActionDetailsRequest) returns (stream WatchActionDetailsResponse) {}

// Get input and output for an action.
// Deprecated: Use DataProxyService.GetActionData instead.
rpc GetActionData(GetActionDataRequest) returns (GetActionDataResponse) {
option deprecated = true;
option idempotency_level = NO_SIDE_EFFECTS;
}

Expand Down Expand Up @@ -69,6 +70,11 @@ service RunService {

// Stream updates for task groups based on the provided filter criteria.
rpc WatchGroups(WatchGroupsRequest) returns (stream WatchGroupsResponse) {}

// Get the storage URIs for an action's input and output data.
rpc GetActionDataURIs(GetActionDataURIsRequest) returns (GetActionDataURIsResponse) {
option idempotency_level = NO_SIDE_EFFECTS;
}
}

// Request message for creating a run.
Expand Down Expand Up @@ -193,6 +199,21 @@ message GetActionDataResponse {
task.Outputs outputs = 2;
}

// Request message for getting action data URIs.
message GetActionDataURIsRequest {
// Action to query.
common.ActionIdentifier action_id = 1 [(buf.validate.field).required = true];
}

// Response message for getting action data URIs.
message GetActionDataURIsResponse {
// URI for the action's input data.
string inputs_uri = 1;

// URI for the action's output data. Empty if action hasn't succeeded or has no outputs.
string outputs_uri = 2;
}

// Request message for listing runs.
message ListRunsRequest {
reserved 3, 5; // Deprecated
Expand Down
Loading
Loading