Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/flyte/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def _to_cache_lookup_scope(scope: CacheLookupScope | None = None) -> run_pb2.Cac
notification_rule_name, notification_rules = resolve_notification_settings(self._notifications)

try:
from flyteidl2.cluster.payload_pb2 import SelectClusterRequest
from flyteidl2.dataproxy import dataproxy_service_pb2

upload_req = dataproxy_service_pb2.UploadInputsRequest(
Expand All @@ -400,7 +401,11 @@ def _to_cache_lookup_scope(scope: CacheLookupScope | None = None) -> run_pb2.Cac
else:
upload_req.project_id.CopyFrom(project_id)

upload_resp = await get_client().dataproxy_service.upload_inputs(upload_req)
resource = run_id if run_id is not None else project_id
dataproxy = await get_client().get_dataproxy_for_resource(
SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS, resource
)
upload_resp = await dataproxy.upload_inputs(upload_req)

resp = await get_client().run_service.create_run(
run_service_pb2.CreateRunRequest(
Expand Down
7 changes: 7 additions & 0 deletions src/flyte/remote/_client/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from flyteidl2.app import app_payload_pb2
from flyteidl2.auth import identity_pb2
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2
from flyteidl2.dataproxy import dataproxy_service_pb2
from flyteidl2.project import project_service_pb2
from flyteidl2.secret import payload_pb2
Expand Down Expand Up @@ -134,6 +135,12 @@ class IdentityService(Protocol):
async def user_info(self, request: identity_pb2.UserInfoRequest) -> identity_pb2.UserInfoResponse: ...


class ClusterService(Protocol):
    """Structural (duck-typed) interface for the cluster-selection RPC service.

    Implementations resolve which backend cluster should handle a given
    operation/resource pair; callers use the response to route follow-up
    requests (e.g. data-proxy uploads) to that cluster's endpoint.
    """

    async def select_cluster(
        self, request: cluster_payload_pb2.SelectClusterRequest
    ) -> cluster_payload_pb2.SelectClusterResponse: ...

class TriggerService(Protocol):
async def deploy_trigger(
self, request: trigger_service_pb2.DeployTriggerRequest
Expand Down
57 changes: 57 additions & 0 deletions src/flyte/remote/_client/controlplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from flyteidl2.app.app_service_connect import AppServiceClient
from flyteidl2.auth.identity_connect import IdentityServiceClient
from flyteidl2.cluster.service_connect import ClusterServiceClient
from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient
from flyteidl2.project.project_service_connect import ProjectServiceClient
from flyteidl2.secret.secret_connect import SecretServiceClient
Expand All @@ -14,6 +15,7 @@

from ._protocols import (
AppService,
ClusterService,
DataProxyService,
IdentityService,
ProjectDomainService,
Expand Down Expand Up @@ -175,6 +177,8 @@ def __init__(self, session_cfg: SessionConfig):
self._secrets_service = SecretServiceClient(**shared)
self._identity_service = IdentityServiceClient(**shared)
self._trigger_service = TriggerServiceClient(**shared)
self._cluster_service = ClusterServiceClient(**shared)
self._dataproxy_cache: dict[str, DataProxyServiceClient] = {}

@classmethod
async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
Expand Down Expand Up @@ -232,6 +236,59 @@ def identity_service(self) -> IdentityService:
def trigger_service(self) -> TriggerService:
return self._trigger_service

@property
def cluster_service(self) -> ClusterService:
    """Client for the cluster-selection service, created in ``__init__``.

    Exposed via the ``ClusterService`` protocol rather than the concrete
    generated client type, matching the other service properties on this class.
    """
    return self._cluster_service

async def get_dataproxy_for_resource(self, operation: int, resource: object) -> DataProxyService:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this operation be an enum instead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""Get a DataProxy client routed to the correct cluster for the given resource.

Calls SelectCluster to discover the cluster endpoint, then creates (or reuses
from cache) a DataProxyServiceClient pointing at that endpoint.

Args:
operation: SelectClusterRequest.Operation enum value.
resource: A protobuf identifier (e.g. RunIdentifier, ProjectIdentifier).
"""
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2

cache_key = f"{operation}:{type(resource).__name__}:{resource}"

if cache_key in self._dataproxy_cache:
return self._dataproxy_cache[cache_key]

# Build the SelectClusterRequest with the right oneof field
req = cluster_payload_pb2.SelectClusterRequest(operation=operation)
if hasattr(resource, "DESCRIPTOR"):
field_map = {
"OrgIdentifier": "org_id",
"ProjectIdentifier": "project_id",
"TaskIdentifier": "task_id",
"ActionIdentifier": "action_id",
"ActionAttemptIdentifier": "action_attempt_id",
}
field_name = field_map.get(type(resource).__name__)
if field_name:
getattr(req, field_name).CopyFrom(resource)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the best way to create the one of 😬, @wild-endeavor @pingsutw ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can do something like

  req = SelectClusterRequest(operation=operation)                  
  if hasattr(resource, "DESCRIPTOR"):                                                                    
      oneof = req.DESCRIPTOR.oneofs_by_name["resource"]  # replace with actual oneof name
      for field in oneof.fields:                                                                         
          if field.message_type is resource.DESCRIPTOR:                                                  
              getattr(req, field.name).CopyFrom(resource)                                                
              break    

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice thanks. done


resp = await self._cluster_service.select_cluster(req)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure we are throwing informative errors here in the infamous "no healthy clusters" error.. we now have a good place to catch that and raise a good error to the user

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

cluster_endpoint = resp.cluster_endpoint
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you going to normalize this here? stripping/adding http/s or dns:/// ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not necessary! create_session_config already calls normalize_rpc_endpoint


# If the cluster endpoint matches our own endpoint, reuse the default client
if not cluster_endpoint or cluster_endpoint == self._session_config.endpoint:
self._dataproxy_cache[cache_key] = self._dataproxy
return self._dataproxy

# Create a new session config for the cluster endpoint
new_session_cfg = await create_session_config(
cluster_endpoint,
insecure=self._session_config.insecure,
insecure_skip_verify=self._session_config.insecure_skip_verify,
)
client = DataProxyServiceClient(**new_session_cfg.connect_kwargs())
self._dataproxy_cache[cache_key] = client
return client

@property
def endpoint(self) -> str:
return self._session_config.endpoint
Expand Down
20 changes: 16 additions & 4 deletions src/flyte/remote/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,25 @@ async def _upload_single_file(
:return: Tuple of (MD5 digest hex string, remote native URL).
"""
md5_bytes, str_digest, _ = hash_file(fp)
from flyteidl2.cluster.payload_pb2 import SelectClusterRequest
from flyteidl2.common import identifier_pb2

from flyte._logging import logger

expires_in_pb = duration_pb2.Duration()
expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN)
client = get_client()
project_id = identifier_pb2.ProjectIdentifier(
name=cfg.project,
domain=cfg.domain,
organization=cfg.org or "",
)
dataproxy = await client.get_dataproxy_for_resource(
SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION, project_id
)

try:
expires_in_pb = duration_pb2.Duration()
expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN)
client = get_client()
resp = await client.dataproxy_service.create_upload_location( # type: ignore
resp = await dataproxy.create_upload_location( # type: ignore
dataproxy_service_pb2.CreateUploadLocationRequest(
project=cfg.project,
domain=cfg.domain,
Expand Down
17 changes: 17 additions & 0 deletions tests/flyte/test_union_run_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import mock
import pytest
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2
from flyteidl2.common import run_pb2 as common_run_pb2
from flyteidl2.core import literals_pb2
from flyteidl2.dataproxy import dataproxy_service_pb2
Expand Down Expand Up @@ -58,6 +59,7 @@ async def test_task1_remote_union_sync(
offloaded_input_data=mock_offloaded,
)
mock_client.dataproxy_service = mock_dataproxy_service
mock_client.get_dataproxy_for_resource = AsyncMock(return_value=mock_dataproxy_service)

inputs = "say test"

Expand Down Expand Up @@ -95,6 +97,11 @@ async def test_task1_remote_union_sync(
assert upload_req.WhichOneof("task") == "task_spec"
assert upload_req.task_spec.task_template.id.name == "test.task1"

# Ensure get_dataproxy_for_resource was called with OPERATION_UPLOAD_INPUTS
mock_client.get_dataproxy_for_resource.assert_called_once()
dp_call_args = mock_client.get_dataproxy_for_resource.call_args[0]
assert dp_call_args[0] == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS

# Ensure create_run uses offloaded_input_data instead of inline inputs
mock_build_image_bg.assert_called_once()
mock_run_service.create_run.assert_called_once()
Expand Down Expand Up @@ -160,6 +167,7 @@ async def test_upload_inputs_with_run_id(
offloaded_input_data=mock_offloaded,
)
mock_client.dataproxy_service = mock_dataproxy_service
mock_client.get_dataproxy_for_resource = AsyncMock(return_value=mock_dataproxy_service)

mock_code_bundler.return_value = CodeBundle(
computed_version="v1",
Expand All @@ -185,6 +193,15 @@ async def test_upload_inputs_with_run_id(
assert upload_req.WhichOneof("task") == "task_spec"
assert upload_req.task_spec.task_template.id.name == "test.task1"

# Ensure get_dataproxy_for_resource was called with OPERATION_UPLOAD_INPUTS and run_id
mock_client.get_dataproxy_for_resource.assert_called_once()
dp_call_args = mock_client.get_dataproxy_for_resource.call_args[0]
assert dp_call_args[0] == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS
# The resource should be a RunIdentifier
assert dp_call_args[1].name == "my-run"
assert dp_call_args[1].project == "testproject"
assert dp_call_args[1].domain == "development"

# create_run should use offloaded_input_data
req: run_service_pb2.CreateRunRequest = mock_run_service.create_run.call_args[0][0]
assert req.offloaded_input_data == mock_offloaded
Expand Down
Loading