-
Notifications
You must be signed in to change notification settings - Fork 36
Add dataproxy client selector #959
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
katrogan
merged 12 commits into
main
from
katrina/eng26-386-sdk-create-and-cache-per-cluster-conns-for-dataproxy-svc
Apr 16, 2026
Merged
Changes from 9 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
714bad8
Add dataproxy client selector
katrogan c7521a6
imports
katrogan 1a49dc8
lint
katrogan a1fcbf0
Merge branch 'main' into katrina/eng26-386-sdk-create-and-cache-per-cluster-conns-for-dataproxy-svc
katrogan 69dd1fc
review comments
katrogan 54a1a07
fmt
katrogan 443b729
Merge branch 'main' into katrina/eng26-386-sdk-create-and-cache-per-cluster-conns-for-dataproxy-svc
katrogan 7d10be8
Review comments
katrogan 426b9ba
Review comments
katrogan 00c6b3b
concurrent
katrogan 6effcfd
uv.locks
katrogan 90db3cb
alru
katrogan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,10 @@ | |
|
|
||
| from flyteidl2.app.app_service_connect import AppServiceClient | ||
| from flyteidl2.auth.identity_connect import IdentityServiceClient | ||
| from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 | ||
| from flyteidl2.cluster.service_connect import ClusterServiceClient | ||
| from flyteidl2.common import identifier_pb2 | ||
| from flyteidl2.dataproxy import dataproxy_service_pb2 | ||
| from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient | ||
| from flyteidl2.project.project_service_connect import ProjectServiceClient | ||
| from flyteidl2.secret.secret_connect import SecretServiceClient | ||
|
|
@@ -14,6 +18,7 @@ | |
|
|
||
| from ._protocols import ( | ||
| AppService, | ||
| ClusterService, | ||
| DataProxyService, | ||
| IdentityService, | ||
| ProjectDomainService, | ||
|
|
@@ -161,6 +166,99 @@ def insecure(self) -> bool: | |
| return self._insecure | ||
|
|
||
|
|
||
class ClusterAwareDataProxy:
    """DataProxy client that routes each call to the correct cluster.

    Implements the DataProxyService protocol. For every RPC, extracts the target
    resource from the request, calls ClusterService.SelectCluster to discover
    the cluster endpoint, and dispatches to a DataProxyServiceClient pointing at
    that endpoint. Per-cluster clients are cached by (operation, resource) so
    repeated calls against the same resource reuse the same connection.

    Cache population is serialized with an asyncio.Lock: without it, two
    coroutines missing on the same key would both call SelectCluster and build
    duplicate clients (the check-then-set race flagged in review).
    """

    def __init__(
        self,
        cluster_service: ClusterService,
        session_config: SessionConfig,
        default_client: DataProxyServiceClient,
    ):
        # Function-scope import matches this module's existing style (see the
        # logger import in _resolve) and keeps the top-level imports unchanged.
        import asyncio

        self._cluster_service = cluster_service
        self._session_config = session_config
        self._default_client = default_client
        # cache_key (bytes) -> resolved client. Reads are lock-free; writes
        # happen only under self._cache_lock.
        self._cache: dict[bytes, DataProxyService] = {}
        # Guards cache misses so at most one coroutine resolves a given key.
        # NOTE(review): assumes all calls run on a single event loop — confirm.
        self._cache_lock = asyncio.Lock()

    async def create_upload_location(
        self, request: dataproxy_service_pb2.CreateUploadLocationRequest
    ) -> dataproxy_service_pb2.CreateUploadLocationResponse:
        """Route CreateUploadLocation to the cluster owning (org, project, domain)."""
        project_id = identifier_pb2.ProjectIdentifier(
            name=request.project, domain=request.domain, organization=request.org
        )
        client = await self._resolve(
            cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION,
            project_id,
        )
        return await client.create_upload_location(request)

    async def upload_inputs(
        self, request: dataproxy_service_pb2.UploadInputsRequest
    ) -> dataproxy_service_pb2.UploadInputsResponse:
        """Route UploadInputs according to the request's ``id`` oneof.

        Raises:
            ValueError: if neither ``run_id`` nor ``project_id`` is set.
        """
        which = request.WhichOneof("id")
        if which == "run_id":
            # SelectClusterRequest.resource doesn't include RunIdentifier; route by project.
            project_id = identifier_pb2.ProjectIdentifier(
                name=request.run_id.project,
                domain=request.run_id.domain,
                organization=request.run_id.org,
            )
        elif which == "project_id":
            project_id = request.project_id
        else:
            raise ValueError("UploadInputsRequest must set either run_id or project_id")
        client = await self._resolve(
            cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS,
            project_id,
        )
        return await client.upload_inputs(request)

    async def _resolve(
        self,
        operation: cluster_payload_pb2.SelectClusterRequest.Operation,
        project_id: identifier_pb2.ProjectIdentifier,
    ) -> DataProxyService:
        """Return the DataProxy client for (operation, project_id), caching per key.

        Raises:
            RuntimeError: if SelectCluster fails or a session for the selected
                endpoint cannot be created.
        """
        from flyte._logging import logger

        # Deterministic serialization so logically-equal protos produce the
        # same cache key regardless of internal field ordering.
        cache_key = int(operation).to_bytes(4, "little") + project_id.SerializeToString(deterministic=True)

        # Fast path: cache hits need no lock; entries are never replaced.
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        async with self._cache_lock:
            # Double-check under the lock: another coroutine may have resolved
            # this key while we awaited the lock.
            cached = self._cache.get(cache_key)
            if cached is not None:
                return cached

            req = cluster_payload_pb2.SelectClusterRequest(operation=operation)
            req.project_id.CopyFrom(project_id)
            try:
                resp = await self._cluster_service.select_cluster(req)
            except Exception as e:
                raise RuntimeError(f"SelectCluster failed for operation={operation}: {e}") from e

            endpoint = resp.cluster_endpoint
            # Empty endpoint, or our own endpoint, means no cross-cluster
            # routing is needed — reuse the default client.
            if not endpoint or endpoint == self._session_config.endpoint:
                self._cache[cache_key] = self._default_client
                return self._default_client

            try:
                new_cfg = await create_session_config(
                    endpoint,
                    insecure=self._session_config.insecure,
                    insecure_skip_verify=self._session_config.insecure_skip_verify,
                )
            except Exception as e:
                raise RuntimeError(f"Failed to create session for cluster endpoint '{endpoint}': {e}") from e

            logger.debug(f"Created DataProxy client for cluster endpoint: {endpoint}")
            client = DataProxyServiceClient(**new_cfg.connect_kwargs())
            self._cache[cache_key] = client
            return client
|
|
||
|
|
||
| class ClientSet: | ||
| def __init__(self, session_cfg: SessionConfig): | ||
| self._console = Console(session_cfg.endpoint, session_cfg.insecure) | ||
|
|
@@ -170,11 +268,16 @@ def __init__(self, session_cfg: SessionConfig): | |
| self._task_service = TaskServiceClient(**shared) | ||
| self._app_service = AppServiceClient(**shared) | ||
| self._run_service = RunServiceClient(**shared) | ||
| self._dataproxy = DataProxyServiceClient(**shared) | ||
| self._log_service = RunLogsServiceClient(**shared) | ||
| self._secrets_service = SecretServiceClient(**shared) | ||
| self._identity_service = IdentityServiceClient(**shared) | ||
| self._trigger_service = TriggerServiceClient(**shared) | ||
| self._cluster_service = ClusterServiceClient(**shared) | ||
| self._dataproxy = ClusterAwareDataProxy( | ||
| cluster_service=self._cluster_service, | ||
| session_config=session_cfg, | ||
| default_client=DataProxyServiceClient(**shared), | ||
| ) | ||
|
|
||
| @classmethod | ||
| async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet: | ||
|
|
@@ -214,6 +317,11 @@ def run_service(self) -> RunService: | |
|
|
||
    @property
    def dataproxy_service(self) -> DataProxyService:
        """Cluster-aware DataProxy client.

        Each call routes to the cluster selected by ClusterService.SelectCluster
        for the target resource, with per-cluster clients cached. Returns the
        ClusterAwareDataProxy instance built in ``__init__``, which satisfies
        the DataProxyService protocol.
        """
        return self._dataproxy
|
|
||
| @property | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| """Tests for the ClusterAwareDataProxy wrapper in flyte.remote._client.controlplane.""" | ||
|
|
||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
|
||
| import pytest | ||
| from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 | ||
| from flyteidl2.common import identifier_pb2 | ||
| from flyteidl2.dataproxy import dataproxy_service_pb2 | ||
|
|
||
| from flyte.remote._client.controlplane import ClusterAwareDataProxy | ||
|
|
||
|
|
||
def _make_wrapper(
    cluster_endpoint: str = "",
    own_endpoint: str = "dns:///localhost:8090",
):
    """Build a ClusterAwareDataProxy wired to mocked collaborators.

    Returns a (wrapper, cluster_service_mock, default_client_mock) triple.
    """
    cluster_svc = MagicMock()
    cluster_svc.select_cluster = AsyncMock(
        return_value=cluster_payload_pb2.SelectClusterResponse(cluster_endpoint=cluster_endpoint)
    )

    cfg = MagicMock()
    cfg.endpoint = own_endpoint
    cfg.insecure = True
    cfg.insecure_skip_verify = False

    proxy_client = MagicMock()
    proxy_client.create_upload_location = AsyncMock(
        return_value=dataproxy_service_pb2.CreateUploadLocationResponse(signed_url="https://signed/")
    )
    proxy_client.upload_inputs = AsyncMock(return_value=dataproxy_service_pb2.UploadInputsResponse())

    wrapper = ClusterAwareDataProxy(
        cluster_service=cluster_svc,
        session_config=cfg,
        default_client=proxy_client,
    )
    return wrapper, cluster_svc, proxy_client
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_create_upload_location_routes_by_project():
    """CreateUploadLocation selects a cluster keyed by the request's project fields."""
    wrapper, cluster_svc, default_client = _make_wrapper()
    req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o", filename="f")

    await wrapper.create_upload_location(req)

    cluster_svc.select_cluster.assert_awaited_once()
    sent = cluster_svc.select_cluster.await_args[0][0]
    expected_op = cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION
    assert sent.operation == expected_op
    assert sent.WhichOneof("resource") == "project_id"
    assert (sent.project_id.name, sent.project_id.domain, sent.project_id.organization) == ("p", "d", "o")
    default_client.create_upload_location.assert_awaited_once_with(req)
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_upload_inputs_with_project_id_routes_by_project():
    """When the ``project_id`` oneof is set, it is forwarded to SelectCluster unchanged."""
    wrapper, cluster_svc, default_client = _make_wrapper()
    pid = identifier_pb2.ProjectIdentifier(name="p", domain="d", organization="o")
    req = dataproxy_service_pb2.UploadInputsRequest(project_id=pid)

    await wrapper.upload_inputs(req)

    sent = cluster_svc.select_cluster.await_args[0][0]
    expected_op = cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS
    assert sent.operation == expected_op
    assert sent.project_id == pid
    default_client.upload_inputs.assert_awaited_once_with(req)
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_upload_inputs_with_run_id_routes_by_derived_project():
    """A ``run_id`` request is routed by a ProjectIdentifier derived from the run."""
    wrapper, cluster_svc, default_client = _make_wrapper()
    run_id = identifier_pb2.RunIdentifier(org="o", project="p", domain="d", name="r")
    req = dataproxy_service_pb2.UploadInputsRequest(run_id=run_id)

    await wrapper.upload_inputs(req)

    sent = cluster_svc.select_cluster.await_args[0][0]
    assert sent.WhichOneof("resource") == "project_id"
    assert (sent.project_id.name, sent.project_id.domain, sent.project_id.organization) == ("p", "d", "o")
    default_client.upload_inputs.assert_awaited_once_with(req)
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_upload_inputs_requires_id_oneof():
    """A request with neither ``run_id`` nor ``project_id`` set is rejected."""
    wrapper, _, _ = _make_wrapper()
    empty_req = dataproxy_service_pb2.UploadInputsRequest()
    with pytest.raises(ValueError):
        await wrapper.upload_inputs(empty_req)
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_cache_hits_reuse_selected_client():
    """Repeat calls with the same key hit the cache: one SelectCluster, two uploads."""
    wrapper, cluster_svc, default_client = _make_wrapper()
    req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o", filename="f")

    for _ in range(2):
        await wrapper.create_upload_location(req)

    assert cluster_svc.select_cluster.await_count == 1
    assert default_client.create_upload_location.await_count == 2
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_cache_keyed_on_operation_and_resource():
    """Each distinct (operation, resource) pair triggers its own SelectCluster call."""
    wrapper, cluster_svc, _ = _make_wrapper()
    pid = identifier_pb2.ProjectIdentifier(name="p", domain="d", organization="o")

    # Same project, different operation than the next call.
    await wrapper.create_upload_location(
        dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o")
    )
    await wrapper.upload_inputs(
        dataproxy_service_pb2.UploadInputsRequest(project_id=pid)
    )
    # Same operation as the first call, different project.
    await wrapper.create_upload_location(
        dataproxy_service_pb2.CreateUploadLocationRequest(project="p2", domain="d", org="o")
    )

    assert cluster_svc.select_cluster.await_count == 3
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_remote_cluster_endpoint_creates_new_client():
    """A foreign cluster endpoint gets a freshly built client, which is then cached."""
    wrapper, cluster_svc, default_client = _make_wrapper(cluster_endpoint="dns:///other:8090")

    remote_client = MagicMock()
    remote_client.create_upload_location = AsyncMock(
        return_value=dataproxy_service_pb2.CreateUploadLocationResponse(signed_url="https://remote/")
    )
    remote_cfg = MagicMock()
    remote_cfg.connect_kwargs.return_value = {}

    patch_session = patch(
        "flyte.remote._client.controlplane.create_session_config",
        new=AsyncMock(return_value=remote_cfg),
    )
    patch_client_cls = patch(
        "flyte.remote._client.controlplane.DataProxyServiceClient",
        return_value=remote_client,
    )
    with patch_session, patch_client_cls:
        req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o")
        await wrapper.create_upload_location(req)
        # Second call must be served from the cache.
        await wrapper.create_upload_location(req)

    assert cluster_svc.select_cluster.await_count == 1
    assert remote_client.create_upload_location.await_count == 2
    default_client.create_upload_location.assert_not_awaited()
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should be just using alru_cache on this right?