From 714bad85f90ccdbf603d93d5b902551e1b6947b2 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 10 Apr 2026 10:17:44 +0200 Subject: [PATCH 1/6] Add dataproxy client selector Signed-off-by: Katrina Rogan --- src/flyte/_run.py | 7 ++- src/flyte/remote/_client/_protocols.py | 7 +++ src/flyte/remote/_client/controlplane.py | 59 ++++++++++++++++++++++++ src/flyte/remote/_data.py | 13 +++++- tests/flyte/test_union_run_basic.py | 17 +++++++ 5 files changed, 101 insertions(+), 2 deletions(-) diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 1fafc837a..8993f6fa5 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -389,6 +389,7 @@ def _to_cache_lookup_scope(scope: CacheLookupScope | None = None) -> run_pb2.Cac notification_rule_name, notification_rules = resolve_notification_settings(self._notifications) try: + from flyteidl2.cluster.payload_pb2 import SelectClusterRequest from flyteidl2.dataproxy import dataproxy_service_pb2 upload_req = dataproxy_service_pb2.UploadInputsRequest( @@ -400,7 +401,11 @@ def _to_cache_lookup_scope(scope: CacheLookupScope | None = None) -> run_pb2.Cac else: upload_req.project_id.CopyFrom(project_id) - upload_resp = await get_client().dataproxy_service.upload_inputs(upload_req) + resource = run_id if run_id is not None else project_id + dataproxy = await get_client().get_dataproxy_for_resource( + SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS, resource + ) + upload_resp = await dataproxy.upload_inputs(upload_req) resp = await get_client().run_service.create_run( run_service_pb2.CreateRunRequest( diff --git a/src/flyte/remote/_client/_protocols.py b/src/flyte/remote/_client/_protocols.py index 8400b10f8..97df261d3 100644 --- a/src/flyte/remote/_client/_protocols.py +++ b/src/flyte/remote/_client/_protocols.py @@ -2,6 +2,7 @@ from flyteidl2.app import app_payload_pb2 from flyteidl2.auth import identity_pb2 +from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 from flyteidl2.dataproxy import dataproxy_service_pb2 from flyteidl2.project import project_service_pb2 from flyteidl2.secret import payload_pb2 @@ -134,6 +135,12 @@ class IdentityService(Protocol): async def user_info(self, request: identity_pb2.UserInfoRequest) -> identity_pb2.UserInfoResponse: ... +class ClusterService(Protocol): + async def select_cluster( + self, request: cluster_payload_pb2.SelectClusterRequest + ) -> cluster_payload_pb2.SelectClusterResponse: ... + + class TriggerService(Protocol): async def deploy_trigger( self, request: trigger_service_pb2.DeployTriggerRequest diff --git a/src/flyte/remote/_client/controlplane.py b/src/flyte/remote/_client/controlplane.py index a9240c712..3fba98eea 100644 --- a/src/flyte/remote/_client/controlplane.py +++ b/src/flyte/remote/_client/controlplane.py @@ -4,6 +4,7 @@ from flyteidl2.app.app_service_connect import AppServiceClient from flyteidl2.auth.identity_connect import IdentityServiceClient +from flyteidl2.cluster.service_connect import ClusterServiceClient from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient from flyteidl2.project.project_service_connect import ProjectServiceClient from flyteidl2.secret.secret_connect import SecretServiceClient @@ -14,6 +15,7 @@ from ._protocols import ( AppService, + ClusterService, DataProxyService, IdentityService, ProjectDomainService, @@ -175,6 +177,8 @@ def __init__(self, session_cfg: SessionConfig): self._secrets_service = SecretServiceClient(**shared) self._identity_service = IdentityServiceClient(**shared) self._trigger_service = TriggerServiceClient(**shared) + self._cluster_service = ClusterServiceClient(**shared) + self._dataproxy_cache: dict[str, DataProxyServiceClient] = {} @classmethod async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet: @@ -232,6 +236,61 @@ def identity_service(self) -> IdentityService: def trigger_service(self) -> TriggerService: return self._trigger_service + @property + def cluster_service(self) -> ClusterService: + return self._cluster_service + + async def get_dataproxy_for_resource( + self, operation: int, resource: object + ) -> DataProxyService: + """Get a DataProxy client routed to the correct cluster for the given resource. + + Calls SelectCluster to discover the cluster endpoint, then creates (or reuses + from cache) a DataProxyServiceClient pointing at that endpoint. + + Args: + operation: SelectClusterRequest.Operation enum value. + resource: A protobuf identifier (e.g. RunIdentifier, ProjectIdentifier). + """ + from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 + + cache_key = f"{operation}:{type(resource).__name__}:{resource}" + + if cache_key in self._dataproxy_cache: + return self._dataproxy_cache[cache_key] + + # Build the SelectClusterRequest with the right oneof field + req = cluster_payload_pb2.SelectClusterRequest(operation=operation) + if hasattr(resource, "DESCRIPTOR"): + field_map = { + "OrgIdentifier": "org_id", + "ProjectIdentifier": "project_id", + "TaskIdentifier": "task_id", + "ActionIdentifier": "action_id", + "ActionAttemptIdentifier": "action_attempt_id", + } + field_name = field_map.get(type(resource).__name__) + if field_name: + getattr(req, field_name).CopyFrom(resource) + + resp = await self._cluster_service.select_cluster(req) + cluster_endpoint = resp.cluster_endpoint + + # If the cluster endpoint matches our own endpoint, reuse the default client + if not cluster_endpoint or cluster_endpoint == self._session_config.endpoint: + self._dataproxy_cache[cache_key] = self._dataproxy + return self._dataproxy + + # Create a new session config for the cluster endpoint + new_session_cfg = await create_session_config( + cluster_endpoint, + insecure=self._session_config.insecure, + insecure_skip_verify=self._session_config.insecure_skip_verify, + ) + client = DataProxyServiceClient(**new_session_cfg.connect_kwargs()) + self._dataproxy_cache[cache_key] = client + return client + @property def endpoint(self) -> str: return self._session_config.endpoint diff --git a/src/flyte/remote/_data.py b/src/flyte/remote/_data.py index eb27912ee..a95f1dcfc 100644 --- a/src/flyte/remote/_data.py +++ b/src/flyte/remote/_data.py @@ -156,10 +156,21 @@ async def _upload_single_file( from flyte._logging import logger try: + from flyteidl2.cluster.payload_pb2 import SelectClusterRequest + from flyteidl2.common import identifier_pb2 + expires_in_pb = duration_pb2.Duration() expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN) client = get_client() - resp = await client.dataproxy_service.create_upload_location( # type: ignore + project_id = identifier_pb2.ProjectIdentifier( + name=cfg.project, + domain=cfg.domain, + organization=cfg.org or "", + ) + dataproxy = await client.get_dataproxy_for_resource( + SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION, project_id + ) + resp = await dataproxy.create_upload_location( # type: ignore dataproxy_service_pb2.CreateUploadLocationRequest( project=cfg.project, domain=cfg.domain, diff --git a/tests/flyte/test_union_run_basic.py b/tests/flyte/test_union_run_basic.py index b625d866c..42cf6ac94 100644 --- a/tests/flyte/test_union_run_basic.py +++ b/tests/flyte/test_union_run_basic.py @@ -1,5 +1,6 @@ import mock import pytest +from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 from flyteidl2.common import run_pb2 as common_run_pb2 from flyteidl2.core import literals_pb2 from flyteidl2.dataproxy import dataproxy_service_pb2 @@ -58,6 +59,7 @@ async def test_task1_remote_union_sync( offloaded_input_data=mock_offloaded, ) mock_client.dataproxy_service = mock_dataproxy_service + mock_client.get_dataproxy_for_resource = AsyncMock(return_value=mock_dataproxy_service) inputs = "say test" @@ -95,6 +97,11 @@ async def test_task1_remote_union_sync( assert upload_req.WhichOneof("task") == "task_spec" assert upload_req.task_spec.task_template.id.name == "test.task1" + # Ensure get_dataproxy_for_resource was called with OPERATION_UPLOAD_INPUTS + mock_client.get_dataproxy_for_resource.assert_called_once() + dp_call_args = mock_client.get_dataproxy_for_resource.call_args[0] + assert dp_call_args[0] == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS + # Ensure create_run uses offloaded_input_data instead of inline inputs mock_build_image_bg.assert_called_once() mock_run_service.create_run.assert_called_once() @@ -160,6 +167,7 @@ async def test_upload_inputs_with_run_id( offloaded_input_data=mock_offloaded, ) mock_client.dataproxy_service = mock_dataproxy_service + mock_client.get_dataproxy_for_resource = AsyncMock(return_value=mock_dataproxy_service) mock_code_bundler.return_value = CodeBundle( computed_version="v1", @@ -185,6 +193,15 @@ async def test_upload_inputs_with_run_id( assert upload_req.WhichOneof("task") == "task_spec" assert upload_req.task_spec.task_template.id.name == "test.task1" + # Ensure get_dataproxy_for_resource was called with OPERATION_UPLOAD_INPUTS and run_id + mock_client.get_dataproxy_for_resource.assert_called_once() + dp_call_args = mock_client.get_dataproxy_for_resource.call_args[0] + assert dp_call_args[0] == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS + # The resource should be a RunIdentifier + assert dp_call_args[1].name == "my-run" + assert dp_call_args[1].project == "testproject" + assert dp_call_args[1].domain == "development" + # create_run should use offloaded_input_data req: run_service_pb2.CreateRunRequest = mock_run_service.create_run.call_args[0][0] assert req.offloaded_input_data == mock_offloaded From c7521a64a8b2f04e95aa93a5fa511aba6a8b31c9 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 10 Apr 2026 10:23:24 +0200 Subject: [PATCH 2/6] imports Signed-off-by: Katrina Rogan --- src/flyte/remote/_data.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/flyte/remote/_data.py b/src/flyte/remote/_data.py index a95f1dcfc..2a0511852 100644 --- a/src/flyte/remote/_data.py +++ b/src/flyte/remote/_data.py @@ -153,23 +153,24 @@ async def _upload_single_file( :return: Tuple of (MD5 digest hex string, remote native URL). """ md5_bytes, str_digest, _ = hash_file(fp) + from flyteidl2.cluster.payload_pb2 import SelectClusterRequest + from flyteidl2.common import identifier_pb2 + from flyte._logging import logger + expires_in_pb = duration_pb2.Duration() + expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN) + client = get_client() + project_id = identifier_pb2.ProjectIdentifier( + name=cfg.project, + domain=cfg.domain, + organization=cfg.org or "", + ) + dataproxy = await client.get_dataproxy_for_resource( + SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION, project_id + ) + try: - from flyteidl2.cluster.payload_pb2 import SelectClusterRequest - from flyteidl2.common import identifier_pb2 - - expires_in_pb = duration_pb2.Duration() - expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN) - client = get_client() - project_id = identifier_pb2.ProjectIdentifier( - name=cfg.project, - domain=cfg.domain, - organization=cfg.org or "", - ) - dataproxy = await client.get_dataproxy_for_resource( - SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION, project_id - ) resp = await dataproxy.create_upload_location( # type: ignore dataproxy_service_pb2.CreateUploadLocationRequest( project=cfg.project, From 1a49dc81aac54f0a00dc306f249817bfee99630e Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 10 Apr 2026 10:25:49 +0200 Subject: [PATCH 3/6] lint Signed-off-by: Katrina Rogan --- src/flyte/remote/_client/controlplane.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/flyte/remote/_client/controlplane.py b/src/flyte/remote/_client/controlplane.py index 3fba98eea..e743e695f 100644 --- a/src/flyte/remote/_client/controlplane.py +++ b/src/flyte/remote/_client/controlplane.py @@ -240,9 +240,7 @@ def trigger_service(self) -> TriggerService: def cluster_service(self) -> ClusterService: return self._cluster_service - async def get_dataproxy_for_resource( - self, operation: int, resource: object - ) -> DataProxyService: + async def get_dataproxy_for_resource(self, operation: int, resource: object) -> DataProxyService: """Get a DataProxy client routed to the correct cluster for the given resource. Calls SelectCluster to discover the cluster endpoint, then creates (or reuses From 69dd1fc1154a7b3bfe8b6f20af3628ee01a03c0f Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Mon, 13 Apr 2026 08:49:25 +0200 Subject: [PATCH 4/6] review comments Signed-off-by: Katrina Rogan --- src/flyte/remote/_client/controlplane.py | 60 +++++++++++++++--------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/src/flyte/remote/_client/controlplane.py b/src/flyte/remote/_client/controlplane.py index e743e695f..06d338892 100644 --- a/src/flyte/remote/_client/controlplane.py +++ b/src/flyte/remote/_client/controlplane.py @@ -4,6 +4,7 @@ from flyteidl2.app.app_service_connect import AppServiceClient from flyteidl2.auth.identity_connect import IdentityServiceClient +from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 from flyteidl2.cluster.service_connect import ClusterServiceClient from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient from flyteidl2.project.project_service_connect import ProjectServiceClient @@ -240,38 +241,46 @@ def trigger_service(self) -> TriggerService: def cluster_service(self) -> ClusterService: return self._cluster_service - async def get_dataproxy_for_resource(self, operation: int, resource: object) -> DataProxyService: + async def get_dataproxy_for_resource( + self, operation: cluster_payload_pb2.SelectClusterRequest.Operation, resource: object + ) -> DataProxyService: """Get a DataProxy client routed to the correct cluster for the given resource. Calls SelectCluster to discover the cluster endpoint, then creates (or reuses from cache) a DataProxyServiceClient pointing at that endpoint. Args: - operation: SelectClusterRequest.Operation enum value. + operation: The SelectCluster operation enum. resource: A protobuf identifier (e.g. RunIdentifier, ProjectIdentifier). """ - from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 + from flyte._logging import logger cache_key = f"{operation}:{type(resource).__name__}:{resource}" if cache_key in self._dataproxy_cache: return self._dataproxy_cache[cache_key] - # Build the SelectClusterRequest with the right oneof field + # Build the SelectClusterRequest, matching the resource to the correct oneof field req = cluster_payload_pb2.SelectClusterRequest(operation=operation) - if hasattr(resource, "DESCRIPTOR"): - field_map = { - "OrgIdentifier": "org_id", - "ProjectIdentifier": "project_id", - "TaskIdentifier": "task_id", - "ActionIdentifier": "action_id", - "ActionAttemptIdentifier": "action_attempt_id", - } - field_name = field_map.get(type(resource).__name__) - if field_name: - getattr(req, field_name).CopyFrom(resource) - - resp = await self._cluster_service.select_cluster(req) + oneof = req.DESCRIPTOR.oneofs_by_name["resource"] + matched = False + for field in oneof.fields: + if field.message_type is resource.DESCRIPTOR: + getattr(req, field.name).CopyFrom(resource) + matched = True + break + if not matched: + raise ValueError( + f"Unsupported resource type '{type(resource).__name__}' for SelectCluster. " + f"Expected one of: {', '.join(f.message_type.name for f in oneof.fields)}" + ) + + try: + resp = await self._cluster_service.select_cluster(req) + except Exception as e: + raise RuntimeError( + f"SelectCluster failed for operation={operation} resource={type(resource).__name__}: {e}" + ) from e cluster_endpoint = resp.cluster_endpoint # If the cluster endpoint matches our own endpoint, reuse the default client @@ -280,11 +289,18 @@ async def get_dataproxy_for_resource(self, operation: int, resource: object) -> return self._dataproxy # Create a new session config for the cluster endpoint - new_session_cfg = await create_session_config( - cluster_endpoint, - insecure=self._session_config.insecure, - insecure_skip_verify=self._session_config.insecure_skip_verify, - ) + try: + new_session_cfg = await create_session_config( + cluster_endpoint, + insecure=self._session_config.insecure, + insecure_skip_verify=self._session_config.insecure_skip_verify, + ) + except Exception as e: + raise RuntimeError( + f"Failed to create session for cluster endpoint '{cluster_endpoint}': {e}" + ) from e + + logger.debug(f"Created DataProxy client for cluster endpoint: {cluster_endpoint}") client = DataProxyServiceClient(**new_session_cfg.connect_kwargs()) self._dataproxy_cache[cache_key] = client return client From 54a1a079bfe565e75899d2a65c423cede24135b0 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Mon, 13 Apr 2026 08:50:47 +0200 Subject: [PATCH 5/6] fmt Signed-off-by: Katrina Rogan --- src/flyte/remote/_client/controlplane.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/flyte/remote/_client/controlplane.py b/src/flyte/remote/_client/controlplane.py index 06d338892..5b71f5292 100644 --- a/src/flyte/remote/_client/controlplane.py +++ b/src/flyte/remote/_client/controlplane.py @@ -296,9 +296,7 @@ async def get_dataproxy_for_resource( insecure_skip_verify=self._session_config.insecure_skip_verify, ) except Exception as e: - raise RuntimeError( - f"Failed to create session for cluster endpoint '{cluster_endpoint}': {e}" - ) from e + raise RuntimeError(f"Failed to create session for cluster endpoint '{cluster_endpoint}': {e}") from e logger.debug(f"Created DataProxy client for cluster endpoint: {cluster_endpoint}") client = DataProxyServiceClient(**new_session_cfg.connect_kwargs()) From 7d10be85179f2decf246ebb78d548ff91abdeada Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 16 Apr 2026 10:51:50 +0200 Subject: [PATCH 6/6] Review comments Signed-off-by: Katrina Rogan --- src/flyte/_run.py | 7 +- src/flyte/remote/_client/controlplane.py | 169 +++++++++++------- src/flyte/remote/_data.py | 21 +-- .../remote/test_cluster_aware_dataproxy.py | 154 ++++++++++++++++ tests/flyte/test_union_run_basic.py | 17 -- 5 files changed, 265 insertions(+), 103 deletions(-) create mode 100644 tests/flyte/remote/test_cluster_aware_dataproxy.py diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 8993f6fa5..1fafc837a 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -389,7 +389,6 @@ def _to_cache_lookup_scope(scope: CacheLookupScope | None = None) -> run_pb2.Cac notification_rule_name, notification_rules = resolve_notification_settings(self._notifications) try: - from flyteidl2.cluster.payload_pb2 import SelectClusterRequest from flyteidl2.dataproxy import dataproxy_service_pb2 upload_req = dataproxy_service_pb2.UploadInputsRequest( @@ -401,11 +400,7 @@ def _to_cache_lookup_scope(scope: CacheLookupScope | None = None) -> run_pb2.Cac else: upload_req.project_id.CopyFrom(project_id) - resource = run_id if run_id is not None else project_id - dataproxy = await get_client().get_dataproxy_for_resource( - SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS, resource - ) - upload_resp = await dataproxy.upload_inputs(upload_req) + upload_resp = await get_client().dataproxy_service.upload_inputs(upload_req) resp = await get_client().run_service.create_run( run_service_pb2.CreateRunRequest( diff --git a/src/flyte/remote/_client/controlplane.py b/src/flyte/remote/_client/controlplane.py index 5b71f5292..b23ec692c 100644 --- a/src/flyte/remote/_client/controlplane.py +++ b/src/flyte/remote/_client/controlplane.py @@ -6,6 +6,8 @@ from flyteidl2.auth.identity_connect import IdentityServiceClient from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 from flyteidl2.cluster.service_connect import ClusterServiceClient +from flyteidl2.common import identifier_pb2 +from flyteidl2.dataproxy import dataproxy_service_pb2 from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient from flyteidl2.project.project_service_connect import ProjectServiceClient from flyteidl2.secret.secret_connect import SecretServiceClient @@ -164,6 +166,99 @@ def insecure(self) -> bool: return self._insecure +class ClusterAwareDataProxy: + """DataProxy client that routes each call to the correct cluster. + + Implements the DataProxyService protocol. For every RPC, extracts the target + resource from the request, calls ClusterService.SelectCluster to discover + the cluster endpoint, and dispatches to a DataProxyServiceClient pointing at + that endpoint. Per-cluster clients are cached by (operation, resource) so + repeated calls against the same resource reuse the same connection. + """ + + def __init__( + self, + cluster_service: ClusterService, + session_config: SessionConfig, + default_client: DataProxyServiceClient, + ): + self._cluster_service = cluster_service + self._session_config = session_config + self._default_client = default_client + self._cache: dict[bytes, DataProxyService] = {} + + async def create_upload_location( + self, request: dataproxy_service_pb2.CreateUploadLocationRequest + ) -> dataproxy_service_pb2.CreateUploadLocationResponse: + project_id = identifier_pb2.ProjectIdentifier( + name=request.project, domain=request.domain, organization=request.org + ) + client = await self._resolve( + cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION, + project_id, + ) + return await client.create_upload_location(request) + + async def upload_inputs( + self, request: dataproxy_service_pb2.UploadInputsRequest + ) -> dataproxy_service_pb2.UploadInputsResponse: + which = request.WhichOneof("id") + if which == "run_id": + # SelectClusterRequest.resource doesn't include RunIdentifier; route by project. + project_id = identifier_pb2.ProjectIdentifier( + name=request.run_id.project, + domain=request.run_id.domain, + organization=request.run_id.org, + ) + elif which == "project_id": + project_id = request.project_id + else: + raise ValueError("UploadInputsRequest must set either run_id or project_id") + client = await self._resolve( + cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS, + project_id, + ) + return await client.upload_inputs(request) + + async def _resolve( + self, + operation: cluster_payload_pb2.SelectClusterRequest.Operation, + project_id: identifier_pb2.ProjectIdentifier, + ) -> DataProxyService: + from flyte._logging import logger + + cache_key = int(operation).to_bytes(4, "little") + project_id.SerializeToString(deterministic=True) + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + req = cluster_payload_pb2.SelectClusterRequest(operation=operation) + req.project_id.CopyFrom(project_id) + try: + resp = await self._cluster_service.select_cluster(req) + except Exception as e: + raise RuntimeError(f"SelectCluster failed for operation={operation}: {e}") from e + + endpoint = resp.cluster_endpoint + if not endpoint or endpoint == self._session_config.endpoint: + self._cache[cache_key] = self._default_client + return self._default_client + + try: + new_cfg = await create_session_config( + endpoint, + insecure=self._session_config.insecure, + insecure_skip_verify=self._session_config.insecure_skip_verify, + ) + except Exception as e: + raise RuntimeError(f"Failed to create session for cluster endpoint '{endpoint}': {e}") from e + + logger.debug(f"Created DataProxy client for cluster endpoint: {endpoint}") + client = DataProxyServiceClient(**new_cfg.connect_kwargs()) + self._cache[cache_key] = client + return client + + class ClientSet: def __init__(self, session_cfg: SessionConfig): self._console = Console(session_cfg.endpoint, session_cfg.insecure) @@ -173,13 +268,16 @@ def __init__(self, session_cfg: SessionConfig): self._task_service = TaskServiceClient(**shared) self._app_service = AppServiceClient(**shared) self._run_service = RunServiceClient(**shared) - self._dataproxy = DataProxyServiceClient(**shared) self._log_service = RunLogsServiceClient(**shared) self._secrets_service = SecretServiceClient(**shared) self._identity_service = IdentityServiceClient(**shared) self._trigger_service = TriggerServiceClient(**shared) self._cluster_service = ClusterServiceClient(**shared) - self._dataproxy_cache: dict[str, DataProxyServiceClient] = {} + self._dataproxy = ClusterAwareDataProxy( + cluster_service=self._cluster_service, + session_config=session_cfg, + default_client=DataProxyServiceClient(**shared), + ) @classmethod async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet: @@ -219,6 +317,11 @@ def run_service(self) -> RunService: @property def dataproxy_service(self) -> DataProxyService: + """Cluster-aware DataProxy client. + + Each call routes to the cluster selected by ClusterService.SelectCluster + for the target resource, with per-cluster clients cached. + """ return self._dataproxy @property @@ -241,68 +344,6 @@ def trigger_service(self) -> TriggerService: def cluster_service(self) -> ClusterService: return self._cluster_service - async def get_dataproxy_for_resource( - self, operation: cluster_payload_pb2.SelectClusterRequest.Operation, resource: object - ) -> DataProxyService: - """Get a DataProxy client routed to the correct cluster for the given resource. - - Calls SelectCluster to discover the cluster endpoint, then creates (or reuses - from cache) a DataProxyServiceClient pointing at that endpoint. - - Args: - operation: The SelectCluster operation enum. - resource: A protobuf identifier (e.g. RunIdentifier, ProjectIdentifier). - """ - from flyte._logging import logger - - cache_key = f"{operation}:{type(resource).__name__}:{resource}" - - if cache_key in self._dataproxy_cache: - return self._dataproxy_cache[cache_key] - - # Build the SelectClusterRequest, matching the resource to the correct oneof field - req = cluster_payload_pb2.SelectClusterRequest(operation=operation) - oneof = req.DESCRIPTOR.oneofs_by_name["resource"] - matched = False - for field in oneof.fields: - if field.message_type is resource.DESCRIPTOR: - getattr(req, field.name).CopyFrom(resource) - matched = True - break - if not matched: - raise ValueError( - f"Unsupported resource type '{type(resource).__name__}' for SelectCluster. " - f"Expected one of: {', '.join(f.message_type.name for f in oneof.fields)}" - ) - - try: - resp = await self._cluster_service.select_cluster(req) - except Exception as e: - raise RuntimeError( - f"SelectCluster failed for operation={operation} resource={type(resource).__name__}: {e}" - ) from e - cluster_endpoint = resp.cluster_endpoint - - # If the cluster endpoint matches our own endpoint, reuse the default client - if not cluster_endpoint or cluster_endpoint == self._session_config.endpoint: - self._dataproxy_cache[cache_key] = self._dataproxy - return self._dataproxy - - # Create a new session config for the cluster endpoint - try: - new_session_cfg = await create_session_config( - cluster_endpoint, - insecure=self._session_config.insecure, - insecure_skip_verify=self._session_config.insecure_skip_verify, - ) - except Exception as e: - raise RuntimeError(f"Failed to create session for cluster endpoint '{cluster_endpoint}': {e}") from e - - logger.debug(f"Created DataProxy client for cluster endpoint: {cluster_endpoint}") - client = DataProxyServiceClient(**new_session_cfg.connect_kwargs()) - self._dataproxy_cache[cache_key] = client - return client - @property def endpoint(self) -> str: return self._session_config.endpoint diff --git a/src/flyte/remote/_data.py b/src/flyte/remote/_data.py index 2a0511852..0abc0d0dc 100644 --- a/src/flyte/remote/_data.py +++ b/src/flyte/remote/_data.py @@ -153,28 +153,17 @@ async def _upload_single_file( :return: Tuple of (MD5 digest hex string, remote native URL). """ md5_bytes, str_digest, _ = hash_file(fp) - from flyteidl2.cluster.payload_pb2 import SelectClusterRequest - from flyteidl2.common import identifier_pb2 - from flyte._logging import logger - expires_in_pb = duration_pb2.Duration() - expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN) - client = get_client() - project_id = identifier_pb2.ProjectIdentifier( - name=cfg.project, - domain=cfg.domain, - organization=cfg.org or "", - ) - dataproxy = await client.get_dataproxy_for_resource( - SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION, project_id - ) - try: - resp = await dataproxy.create_upload_location( # type: ignore + expires_in_pb = duration_pb2.Duration() + expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN) + client = get_client() + resp = await client.dataproxy_service.create_upload_location( # type: ignore dataproxy_service_pb2.CreateUploadLocationRequest( project=cfg.project, domain=cfg.domain, + org=cfg.org or "", content_md5=md5_bytes, filename=fname or fp.name, expires_in=expires_in_pb, diff --git a/tests/flyte/remote/test_cluster_aware_dataproxy.py b/tests/flyte/remote/test_cluster_aware_dataproxy.py new file mode 100644 index 000000000..1e23ca03b --- /dev/null +++ b/tests/flyte/remote/test_cluster_aware_dataproxy.py @@ -0,0 +1,154 @@ +"""Tests for the ClusterAwareDataProxy wrapper in flyte.remote._client.controlplane.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 +from flyteidl2.common import identifier_pb2 +from flyteidl2.dataproxy import dataproxy_service_pb2 + +from flyte.remote._client.controlplane import ClusterAwareDataProxy + + +def _make_wrapper( + cluster_endpoint: str = "", + own_endpoint: str = "dns:///localhost:8090", +): + cluster_service = MagicMock() + cluster_service.select_cluster = AsyncMock( + return_value=cluster_payload_pb2.SelectClusterResponse(cluster_endpoint=cluster_endpoint) + ) + session_config = MagicMock() + session_config.endpoint = own_endpoint + session_config.insecure = True + session_config.insecure_skip_verify = False + default_client = MagicMock() + default_client.create_upload_location = AsyncMock( + return_value=dataproxy_service_pb2.CreateUploadLocationResponse(signed_url="https://signed/") + ) + default_client.upload_inputs = AsyncMock(return_value=dataproxy_service_pb2.UploadInputsResponse()) + return ( + ClusterAwareDataProxy( + cluster_service=cluster_service, + session_config=session_config, + default_client=default_client, + ), + cluster_service, + default_client, + ) + + +@pytest.mark.asyncio +async def test_create_upload_location_routes_by_project(): + wrapper, cluster_service, default_client = _make_wrapper() + req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o", filename="f") + + await wrapper.create_upload_location(req) + + cluster_service.select_cluster.assert_awaited_once() + sent = cluster_service.select_cluster.await_args[0][0] + assert sent.operation == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION + assert sent.WhichOneof("resource") == "project_id" + assert sent.project_id.name == "p" + assert sent.project_id.domain == "d" + assert sent.project_id.organization == "o" + default_client.create_upload_location.assert_awaited_once_with(req) + + +@pytest.mark.asyncio +async def test_upload_inputs_with_project_id_routes_by_project(): + wrapper, cluster_service, default_client = _make_wrapper() + pid = identifier_pb2.ProjectIdentifier(name="p", domain="d", organization="o") + req = dataproxy_service_pb2.UploadInputsRequest(project_id=pid) + + await wrapper.upload_inputs(req) + + sent = cluster_service.select_cluster.await_args[0][0] + assert sent.operation == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS + assert sent.project_id == pid + default_client.upload_inputs.assert_awaited_once_with(req) + + +@pytest.mark.asyncio +async def test_upload_inputs_with_run_id_routes_by_derived_project(): + wrapper, cluster_service, default_client = _make_wrapper() + run_id = identifier_pb2.RunIdentifier(org="o", project="p", domain="d", name="r") + req = dataproxy_service_pb2.UploadInputsRequest(run_id=run_id) + + await wrapper.upload_inputs(req) + + sent = cluster_service.select_cluster.await_args[0][0] + assert sent.WhichOneof("resource") == "project_id" + assert sent.project_id.name == "p" + assert sent.project_id.domain == "d" + assert sent.project_id.organization == "o" + default_client.upload_inputs.assert_awaited_once_with(req) + + +@pytest.mark.asyncio +async def test_upload_inputs_requires_id_oneof(): + wrapper, _, _ = _make_wrapper() + with pytest.raises(ValueError): + await wrapper.upload_inputs(dataproxy_service_pb2.UploadInputsRequest()) + + +@pytest.mark.asyncio +async def test_cache_hits_reuse_selected_client(): + wrapper, cluster_service, default_client = _make_wrapper() + req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o", filename="f") + + await wrapper.create_upload_location(req) + await wrapper.create_upload_location(req) + + assert cluster_service.select_cluster.await_count == 1 + assert default_client.create_upload_location.await_count == 2 + + +@pytest.mark.asyncio +async def test_cache_keyed_on_operation_and_resource(): + wrapper, cluster_service, _ = _make_wrapper() + + await wrapper.create_upload_location( + dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o") + ) + await wrapper.upload_inputs( + dataproxy_service_pb2.UploadInputsRequest( + project_id=identifier_pb2.ProjectIdentifier(name="p", domain="d", organization="o"), + ) + ) + await wrapper.create_upload_location( + dataproxy_service_pb2.CreateUploadLocationRequest(project="p2", domain="d", org="o") + ) + + assert cluster_service.select_cluster.await_count == 3 + + +@pytest.mark.asyncio +async def test_remote_cluster_endpoint_creates_new_client(): + wrapper, cluster_service, default_client = _make_wrapper(cluster_endpoint="dns:///other:8090") + + new_client_inst = MagicMock() + new_client_inst.create_upload_location = AsyncMock( + return_value=dataproxy_service_pb2.CreateUploadLocationResponse(signed_url="https://remote/") + ) + new_session_cfg = MagicMock() + new_session_cfg.connect_kwargs.return_value = {} + + with ( + patch( + "flyte.remote._client.controlplane.create_session_config", + new=AsyncMock(return_value=new_session_cfg), + ), + patch( + "flyte.remote._client.controlplane.DataProxyServiceClient", + return_value=new_client_inst, + ), + ): + req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o") + await wrapper.create_upload_location(req) + # Cached for a subsequent call + await wrapper.create_upload_location(req) + + assert cluster_service.select_cluster.await_count == 1 + assert new_client_inst.create_upload_location.await_count == 2 + default_client.create_upload_location.assert_not_awaited() diff --git a/tests/flyte/test_union_run_basic.py b/tests/flyte/test_union_run_basic.py index 42cf6ac94..b625d866c 100644 --- a/tests/flyte/test_union_run_basic.py +++ b/tests/flyte/test_union_run_basic.py @@ -1,6 +1,5 @@ import mock import pytest -from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 from flyteidl2.common import run_pb2 as common_run_pb2 from flyteidl2.core import literals_pb2 from flyteidl2.dataproxy import dataproxy_service_pb2 @@ -59,7 +58,6 @@ async def test_task1_remote_union_sync( offloaded_input_data=mock_offloaded, ) mock_client.dataproxy_service = mock_dataproxy_service - mock_client.get_dataproxy_for_resource = AsyncMock(return_value=mock_dataproxy_service) inputs = "say test" @@ -97,11 +95,6 @@ async def test_task1_remote_union_sync( assert upload_req.WhichOneof("task") == "task_spec" assert upload_req.task_spec.task_template.id.name == "test.task1" - # Ensure get_dataproxy_for_resource was called with OPERATION_UPLOAD_INPUTS - mock_client.get_dataproxy_for_resource.assert_called_once() - dp_call_args = mock_client.get_dataproxy_for_resource.call_args[0] - assert dp_call_args[0] == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS - # Ensure create_run uses offloaded_input_data instead of inline inputs mock_build_image_bg.assert_called_once() mock_run_service.create_run.assert_called_once() @@ -167,7 +160,6 @@ async def test_upload_inputs_with_run_id( offloaded_input_data=mock_offloaded, ) mock_client.dataproxy_service = mock_dataproxy_service - mock_client.get_dataproxy_for_resource = AsyncMock(return_value=mock_dataproxy_service) mock_code_bundler.return_value = CodeBundle( computed_version="v1", @@ -193,15 +185,6 @@ async def test_upload_inputs_with_run_id( assert upload_req.WhichOneof("task") == "task_spec" assert upload_req.task_spec.task_template.id.name == "test.task1" - # Ensure get_dataproxy_for_resource was called with OPERATION_UPLOAD_INPUTS and run_id - mock_client.get_dataproxy_for_resource.assert_called_once() - dp_call_args = mock_client.get_dataproxy_for_resource.call_args[0] - assert dp_call_args[0] == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS - # The resource should be a RunIdentifier - assert dp_call_args[1].name == "my-run" - assert dp_call_args[1].project == "testproject" - assert dp_call_args[1].domain == "development" - # create_run should use offloaded_input_data req: run_service_pb2.CreateRunRequest = mock_run_service.create_run.call_args[0][0] assert req.offloaded_input_data == mock_offloaded