Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion dataproxy/service/dataproxy_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"connectrpc.com/connect"
"github.com/flyteorg/stow"
"github.com/samber/lo"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
Expand All @@ -28,6 +29,8 @@ import (
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger/triggerconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect"
)

type Service struct {
Expand All @@ -37,15 +40,17 @@ type Service struct {
dataStore *storage.DataStore
taskClient taskconnect.TaskServiceClient
triggerClient triggerconnect.TriggerServiceClient
runClient workflowconnect.RunServiceClient
}

// NewService creates a new DataProxyService instance.
func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore, taskClient taskconnect.TaskServiceClient, triggerClient triggerconnect.TriggerServiceClient) *Service {
// NewService wires the supplied configuration, data store and downstream
// service clients into a freshly allocated Service. Arguments are stored as
// given; no validation is performed here, so callers may pass nil clients
// when the corresponding functionality is not exercised.
func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore, taskClient taskconnect.TaskServiceClient, triggerClient triggerconnect.TriggerServiceClient, runClient workflowconnect.RunServiceClient) *Service {
	svc := new(Service)
	svc.cfg = cfg
	svc.dataStore = dataStore
	svc.taskClient = taskClient
	svc.triggerClient = triggerClient
	svc.runClient = runClient
	return svc
}

Expand Down Expand Up @@ -327,6 +332,67 @@ func filterInputs(inputs *task.Inputs, ignoreVars []string) *task.Inputs {
return &task.Inputs{Literals: filtered}
}

// GetActionData returns the inputs and outputs recorded for an action.
//
// It first asks the RunService for the action's data URIs and then loads the
// referenced protobufs from the data store, reading inputs and outputs
// concurrently:
//   - the inputs URI is treated as a base reference to which "inputs.pb" is
//     appended;
//   - the outputs URI is read exactly as returned.
//
// An empty URI skips the corresponding read and leaves that part of the
// response as an empty message. Errors from the RunService are returned
// unchanged; reference-construction and storage read failures are reported as
// connect errors with CodeInternal.
func (s *Service) GetActionData(
	ctx context.Context,
	req *connect.Request[dataproxy.GetActionDataRequest],
) (*connect.Response[dataproxy.GetActionDataResponse], error) {
	uriReq := connect.NewRequest(&workflow.GetActionDataURIsRequest{
		ActionId: req.Msg.GetActionId(),
	})
	uriResp, err := s.runClient.GetActionDataURIs(ctx, uriReq)
	if err != nil {
		return nil, err
	}
	inputsURI := uriResp.Msg.GetInputsUri()
	outputsURI := uriResp.Msg.GetOutputsUri()

	result := &dataproxy.GetActionDataResponse{
		Inputs:  &task.Inputs{},
		Outputs: &task.Outputs{},
	}

	// The two readers touch disjoint parts of result (the Inputs message and
	// the Outputs field respectively), so they can run side by side.
	eg, egCtx := errgroup.WithContext(ctx)

	if inputsURI != "" {
		eg.Go(func() error {
			ref, refErr := s.dataStore.ConstructReference(egCtx, storage.DataReference(inputsURI), "inputs.pb")
			if refErr != nil {
				return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to construct input ref: %w", refErr))
			}
			logger.Infof(egCtx, "GetActionData: reading inputs from %s", ref)
			readErr := s.dataStore.ReadProtobuf(egCtx, ref, result.Inputs)
			if readErr == nil {
				return nil
			}
			logger.Errorf(egCtx, "GetActionData: failed to read inputs from %s: %v", ref, readErr)
			return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read inputs from %s: %w", ref, readErr))
		})
	}

	if outputsURI != "" {
		eg.Go(func() error {
			ref := storage.DataReference(outputsURI)
			logger.Infof(egCtx, "GetActionData: reading outputs from %s", ref)
			// The stored outputs blob is decoded as task.Inputs and its
			// literals are then carried over into task.Outputs.
			// NOTE(review): presumably outputs are persisted with the same
			// wire shape as inputs — confirm against the writer.
			decoded := &task.Inputs{}
			readErr := s.dataStore.ReadProtobuf(egCtx, ref, decoded)
			if readErr != nil {
				logger.Errorf(egCtx, "GetActionData: failed to read outputs from %s: %v", ref, readErr)
				return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read outputs from %s: %w", ref, readErr))
			}
			result.Outputs = &task.Outputs{Literals: decoded.GetLiterals()}
			return nil
		})
	}

	if waitErr := eg.Wait(); waitErr != nil {
		return nil, waitErr
	}
	return connect.NewResponse(result), nil
}

// hashInputsProto computes a deterministic FNV-64a hash of the serialized inputs.
func hashInputsProto(inputs proto.Message) (string, error) {
marshaller := proto.MarshalOptions{Deterministic: true}
Expand Down
173 changes: 166 additions & 7 deletions dataproxy/service/dataproxy_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"k8s.io/apimachinery/pkg/api/resource"

Expand All @@ -21,6 +22,8 @@ import (
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
workflowMocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect/mocks"
)

func TestCreateUploadLocation(t *testing.T) {
Expand Down Expand Up @@ -97,7 +100,7 @@ func TestCreateUploadLocation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStore(t)
service := NewService(cfg, mockStore, nil, nil)
service := NewService(cfg, mockStore, nil, nil, nil)

req := &connect.Request[dataproxy.CreateUploadLocationRequest]{
Msg: tt.req,
Expand Down Expand Up @@ -218,7 +221,7 @@ func TestCheckFileExists(t *testing.T) {
mockStore = setupMockDataStoreWithExistingFile(t, tt.existingFileMD5)
}

service := NewService(cfg, mockStore, nil, nil)
service := NewService(cfg, mockStore, nil, nil, nil)
storagePath := storage.DataReference("s3://test-bucket/uploads/test-project/test-domain/test-root/test-file.txt")

err := service.checkFileExists(ctx, storagePath, tt.req)
Expand Down Expand Up @@ -296,7 +299,7 @@ func TestConstructStoragePath(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStore(t)
service := NewService(cfg, mockStore, nil, nil)
service := NewService(cfg, mockStore, nil, nil, nil)

path, err := service.constructStoragePath(ctx, tt.req)

Expand Down Expand Up @@ -389,7 +392,7 @@ func TestUploadInputs(t *testing.T) {
Name: "test-run",
},
},
Task: &dataproxy.UploadInputsRequest_TaskSpec{TaskSpec: testTaskSpec},
Task: &dataproxy.UploadInputsRequest_TaskSpec{TaskSpec: testTaskSpec},
Inputs: &task.Inputs{
Literals: []*task.NamedLiteral{
{Name: "x", Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 42}}}}}}},
Expand All @@ -414,7 +417,7 @@ func TestUploadInputs(t *testing.T) {
Domain: "test-domain",
},
},
Task: &dataproxy.UploadInputsRequest_TaskSpec{TaskSpec: testTaskSpec},
Task: &dataproxy.UploadInputsRequest_TaskSpec{TaskSpec: testTaskSpec},
Inputs: &task.Inputs{
Literals: []*task.NamedLiteral{
{Name: "y", Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "hello"}}}}}}},
Expand All @@ -435,7 +438,7 @@ func TestUploadInputs(t *testing.T) {
Org: "org", Project: "proj", Domain: "dom", Name: "run1",
},
},
Task: &dataproxy.UploadInputsRequest_TaskSpec{TaskSpec: testTaskSpecWithIgnoredVars},
Task: &dataproxy.UploadInputsRequest_TaskSpec{TaskSpec: testTaskSpecWithIgnoredVars},
Inputs: &task.Inputs{
Literals: []*task.NamedLiteral{
{Name: "x", Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 1}}}}}}},
Expand All @@ -453,7 +456,7 @@ func TestUploadInputs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStoreWithWriteProtobuf(t)
svc := NewService(cfg, mockStore, nil, nil)
svc := NewService(cfg, mockStore, nil, nil, nil)

req := &connect.Request[dataproxy.UploadInputsRequest]{
Msg: tt.req,
Expand Down Expand Up @@ -490,6 +493,162 @@ func setupMockDataStoreWithWriteProtobuf(t *testing.T) *storage.DataStore {
}
}

// TestGetActionData drives Service.GetActionData with a mocked RunService
// client and a mocked protobuf store. The table covers every combination of
// present/absent input and output URIs plus a failure from each dependency
// (RunService lookup, inputs read, outputs read).
func TestGetActionData(t *testing.T) {
	type testCase struct {
		name             string
		inputsURI        string
		outputsURI       string
		runClientErr     error
		readInputsErr    error
		readOutputsErr   error
		wantErr          bool
		expectInputsLen  int
		expectOutputsLen int
	}

	const (
		inputsDir   = "s3://test-bucket/inputs-dir"
		outputsFile = "s3://test-bucket/outputs/outputs.pb"
	)

	ctx := context.Background()
	cfg := config.DataProxyConfig{}

	actionID := &common.ActionIdentifier{
		Name: "a0",
		Run: &common.RunIdentifier{
			Org:     "org",
			Project: "proj",
			Domain:  "dom",
			Name:    "run1",
		},
	}

	// Payloads the mocked store hands back. Outputs are deliberately stored
	// as task.Inputs, matching the type the service decodes them into.
	storedInputs := &task.Inputs{
		Literals: []*task.NamedLiteral{
			{Name: "x", Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 1}}}}}}},
		},
	}
	storedOutputs := &task.Inputs{
		Literals: []*task.NamedLiteral{
			{Name: "o", Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "result"}}}}}}},
		},
	}

	cases := []testCase{
		{
			name:             "success with both inputs and outputs",
			inputsURI:        inputsDir,
			outputsURI:       outputsFile,
			expectInputsLen:  1,
			expectOutputsLen: 1,
		},
		{
			name:             "success with only inputs",
			inputsURI:        inputsDir,
			outputsURI:       "",
			expectInputsLen:  1,
			expectOutputsLen: 0,
		},
		{
			name:             "success with only outputs",
			inputsURI:        "",
			outputsURI:       outputsFile,
			expectInputsLen:  0,
			expectOutputsLen: 1,
		},
		{
			name:             "success with neither inputs nor outputs",
			inputsURI:        "",
			outputsURI:       "",
			expectInputsLen:  0,
			expectOutputsLen: 0,
		},
		{
			name:         "RunService error propagates",
			runClientErr: connect.NewError(connect.CodeNotFound, assertErr("not found")),
			wantErr:      true,
		},
		{
			name:          "read inputs error propagates",
			inputsURI:     inputsDir,
			readInputsErr: assertErr("read failed"),
			wantErr:       true,
		},
		{
			name:           "read outputs error propagates",
			outputsURI:     outputsFile,
			readOutputsErr: assertErr("read failed"),
			wantErr:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// RunService either fails outright or reports the case's URIs.
			runClient := workflowMocks.NewRunServiceClient(t)
			urisCall := runClient.EXPECT().GetActionDataURIs(mock.Anything, mock.Anything)
			if tc.runClientErr == nil {
				urisCall.Return(
					connect.NewResponse(&workflow.GetActionDataURIsResponse{
						InputsUri:  tc.inputsURI,
						OutputsUri: tc.outputsURI,
					}), nil)
			} else {
				urisCall.Return(nil, tc.runClientErr)
			}

			store := storageMocks.NewComposedProtobufStore(t)

			// stubRead registers an optional ReadProtobuf expectation for ref
			// that either fails with failure or fills the destination message
			// with a copy of payload.
			stubRead := func(ref storage.DataReference, failure error, payload proto.Message) {
				call := store.On("ReadProtobuf", mock.Anything, ref, mock.Anything)
				if failure != nil {
					call.Return(failure).Maybe()
					return
				}
				call.Run(func(args mock.Arguments) {
					dst := args.Get(2).(proto.Message)
					proto.Reset(dst)
					proto.Merge(dst, payload)
				}).Return(nil).Maybe()
			}

			if tc.inputsURI != "" {
				// Inputs are expected under "<inputsURI>/inputs.pb".
				stubRead(storage.DataReference(tc.inputsURI+"/inputs.pb"), tc.readInputsErr, storedInputs)
			}
			if tc.outputsURI != "" {
				// Outputs are expected at the URI exactly as reported.
				stubRead(storage.DataReference(tc.outputsURI), tc.readOutputsErr, storedOutputs)
			}

			svc := NewService(cfg, &storage.DataStore{
				ComposedProtobufStore: store,
				ReferenceConstructor:  &simpleRefConstructor{},
			}, nil, nil, runClient)

			request := connect.NewRequest(&dataproxy.GetActionDataRequest{
				ActionId: actionID,
			})
			resp, err := svc.GetActionData(ctx, request)

			if tc.wantErr {
				assert.Error(t, err)
				assert.Nil(t, resp)
				return
			}

			assert.NoError(t, err)
			assert.NotNil(t, resp)

			gotInputs := resp.Msg.GetInputs().GetLiterals()
			gotOutputs := resp.Msg.GetOutputs().GetLiterals()
			assert.Len(t, gotInputs, tc.expectInputsLen)
			assert.Len(t, gotOutputs, tc.expectOutputsLen)
			if tc.expectInputsLen > 0 {
				assert.Equal(t, "x", resp.Msg.GetInputs().GetLiterals()[0].GetName())
			}
			if tc.expectOutputsLen > 0 {
				assert.Equal(t, "o", resp.Msg.GetOutputs().GetLiterals()[0].GetName())
			}
		})
	}
}

// assertErr is a minimal string-backed error type for tests; the string value
// is the error message.
type assertErr string

// Error implements the error interface by returning the underlying string.
func (msg assertErr) Error() string {
	return string(msg)
}

func setupMockDataStoreWithExistingFile(t *testing.T, contentMD5 string) *storage.DataStore {
mockComposedStore := storageMocks.NewComposedProtobufStore(t)

Expand Down
6 changes: 4 additions & 2 deletions dataproxy/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ import (
"fmt"
"net/http"

"github.com/flyteorg/flyte/v2/flytestdlib/app"
"github.com/flyteorg/flyte/v2/dataproxy/config"
"github.com/flyteorg/flyte/v2/dataproxy/service"
"github.com/flyteorg/flyte/v2/flytestdlib/app"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/cluster/clusterconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger/triggerconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect"

"github.com/flyteorg/flyte/v2/flytestdlib/logger"
)
Expand All @@ -24,8 +25,9 @@ func Setup(ctx context.Context, sc *app.SetupContext) error {
baseURL := sc.BaseURL
taskClient := taskconnect.NewTaskServiceClient(http.DefaultClient, baseURL)
triggerClient := triggerconnect.NewTriggerServiceClient(http.DefaultClient, baseURL)
runClient := workflowconnect.NewRunServiceClient(http.DefaultClient, baseURL)

svc := service.NewService(*cfg, sc.DataStore, taskClient, triggerClient)
svc := service.NewService(*cfg, sc.DataStore, taskClient, triggerClient, runClient)

path, handler := dataproxyconnect.NewDataProxyServiceHandler(svc)
sc.Mux.Handle(path, handler)
Expand Down
Loading
Loading