-
Notifications
You must be signed in to change notification settings - Fork 36
Add dataproxy client selector #959
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
714bad8
c7521a6
1a49dc8
a1fcbf0
69dd1fc
54a1a07
443b729
7d10be8
426b9ba
00c6b3b
6effcfd
90db3cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,10 @@ | |
|
|
||
import asyncio

from flyteidl2.app.app_service_connect import AppServiceClient
from flyteidl2.auth.identity_connect import IdentityServiceClient
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2
from flyteidl2.cluster.service_connect import ClusterServiceClient
from flyteidl2.common import identifier_pb2
from flyteidl2.dataproxy import dataproxy_service_pb2
from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient
from flyteidl2.project.project_service_connect import ProjectServiceClient
from flyteidl2.secret.secret_connect import SecretServiceClient
|
|
@@ -14,6 +18,7 @@ | |
|
|
||
| from ._protocols import ( | ||
| AppService, | ||
| ClusterService, | ||
| DataProxyService, | ||
| IdentityService, | ||
| ProjectDomainService, | ||
|
|
@@ -161,6 +166,99 @@ def insecure(self) -> bool: | |
| return self._insecure | ||
|
|
||
|
|
||
class ClusterAwareDataProxy:
    """DataProxy client that routes each call to the correct cluster.

    Implements the DataProxyService protocol. For every RPC, extracts the target
    resource from the request, calls ClusterService.SelectCluster to discover
    the cluster endpoint, and dispatches to a DataProxyServiceClient pointing at
    that endpoint. Per-cluster clients are cached by (operation, resource) so
    repeated calls against the same resource reuse the same connection.

    Cache resolution is serialized with an asyncio.Lock so concurrent
    coroutines that miss on the same key cannot each perform a SelectCluster
    call and create (and leak) a duplicate client.
    """

    def __init__(
        self,
        cluster_service: ClusterService,
        session_config: SessionConfig,
        default_client: DataProxyServiceClient,
    ):
        """
        :param cluster_service: Used to ask the control plane which cluster
            should handle a given (operation, resource) pair.
        :param session_config: Settings of the current session; its endpoint is
            compared against SelectCluster responses, and its TLS flags are
            reused when dialing a different cluster.
        :param default_client: DataProxy client bound to the session endpoint,
            used when SelectCluster returns no endpoint or our own endpoint.
        """
        self._cluster_service = cluster_service
        self._session_config = session_config
        self._default_client = default_client
        # Maps (operation || serialized project_id) -> resolved client.
        self._cache: dict[bytes, DataProxyService] = {}
        # Guards check-then-insert on self._cache: without it, concurrent
        # coroutines racing on the same key would each call SelectCluster and
        # build redundant clients (race condition flagged in review).
        self._cache_lock = asyncio.Lock()

    async def create_upload_location(
        self, request: dataproxy_service_pb2.CreateUploadLocationRequest
    ) -> dataproxy_service_pb2.CreateUploadLocationResponse:
        """Create an upload location on the cluster owning the request's project."""
        project_id = identifier_pb2.ProjectIdentifier(
            name=request.project, domain=request.domain, organization=request.org
        )
        client = await self._resolve(
            cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION,
            project_id,
        )
        return await client.create_upload_location(request)

    async def upload_inputs(
        self, request: dataproxy_service_pb2.UploadInputsRequest
    ) -> dataproxy_service_pb2.UploadInputsResponse:
        """Upload inputs to the cluster owning the request's project.

        :raises ValueError: if neither ``run_id`` nor ``project_id`` is set on
            the request (the ``id`` oneof is empty).
        """
        which = request.WhichOneof("id")
        if which == "run_id":
            # SelectClusterRequest.resource doesn't include RunIdentifier; route by project.
            project_id = identifier_pb2.ProjectIdentifier(
                name=request.run_id.project,
                domain=request.run_id.domain,
                organization=request.run_id.org,
            )
        elif which == "project_id":
            project_id = request.project_id
        else:
            raise ValueError("UploadInputsRequest must set either run_id or project_id")
        client = await self._resolve(
            cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS,
            project_id,
        )
        return await client.upload_inputs(request)

    async def _resolve(
        self,
        operation: cluster_payload_pb2.SelectClusterRequest.Operation,
        project_id: identifier_pb2.ProjectIdentifier,
    ) -> DataProxyService:
        """Return the DataProxy client for the cluster that owns *project_id*.

        Falls back to the default client when SelectCluster returns an empty
        endpoint or the endpoint of the current session. The result is cached
        under (operation, project_id) so repeated calls reuse one connection.

        :raises RuntimeError: if SelectCluster fails or a session for the
            selected cluster endpoint cannot be created.
        """
        from flyte._logging import logger

        # Deterministic serialization keeps the key stable for equal protos;
        # the operation is prepended so one project can map to different
        # clusters per operation.
        cache_key = int(operation).to_bytes(4, "little") + project_id.SerializeToString(deterministic=True)
        # Hold the lock across the whole miss path so only one coroutine per
        # key performs SelectCluster and client construction.
        async with self._cache_lock:
            cached = self._cache.get(cache_key)
            if cached is not None:
                return cached

            req = cluster_payload_pb2.SelectClusterRequest(operation=operation)
            req.project_id.CopyFrom(project_id)
            try:
                resp = await self._cluster_service.select_cluster(req)
            except Exception as e:
                raise RuntimeError(f"SelectCluster failed for operation={operation}: {e}") from e

            endpoint = resp.cluster_endpoint
            if not endpoint or endpoint == self._session_config.endpoint:
                # Local (or unspecified) cluster: reuse the session's client.
                self._cache[cache_key] = self._default_client
                return self._default_client

            try:
                new_cfg = await create_session_config(
                    endpoint,
                    insecure=self._session_config.insecure,
                    insecure_skip_verify=self._session_config.insecure_skip_verify,
                )
            except Exception as e:
                raise RuntimeError(f"Failed to create session for cluster endpoint '{endpoint}': {e}") from e

            logger.debug(f"Created DataProxy client for cluster endpoint: {endpoint}")
            client = DataProxyServiceClient(**new_cfg.connect_kwargs())
            self._cache[cache_key] = client
            return client
|
|
||
|
|
||
| class ClientSet: | ||
| def __init__(self, session_cfg: SessionConfig): | ||
| self._console = Console(session_cfg.endpoint, session_cfg.insecure) | ||
|
|
@@ -170,11 +268,16 @@ def __init__(self, session_cfg: SessionConfig): | |
| self._task_service = TaskServiceClient(**shared) | ||
| self._app_service = AppServiceClient(**shared) | ||
| self._run_service = RunServiceClient(**shared) | ||
| self._dataproxy = DataProxyServiceClient(**shared) | ||
| self._log_service = RunLogsServiceClient(**shared) | ||
| self._secrets_service = SecretServiceClient(**shared) | ||
| self._identity_service = IdentityServiceClient(**shared) | ||
| self._trigger_service = TriggerServiceClient(**shared) | ||
| self._cluster_service = ClusterServiceClient(**shared) | ||
| self._dataproxy = ClusterAwareDataProxy( | ||
| cluster_service=self._cluster_service, | ||
| session_config=session_cfg, | ||
| default_client=DataProxyServiceClient(**shared), | ||
| ) | ||
|
|
||
| @classmethod | ||
| async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet: | ||
|
|
@@ -214,6 +317,11 @@ def run_service(self) -> RunService: | |
|
|
||
    @property
    def dataproxy_service(self) -> DataProxyService:
        """Cluster-aware DataProxy client.

        Each call routes to the cluster selected by ClusterService.SelectCluster
        for the target resource, with per-cluster clients cached. See
        ClusterAwareDataProxy for the routing and caching details.
        """
        return self._dataproxy
|
|
||
| @property | ||
|
|
@@ -232,6 +340,10 @@ def identity_service(self) -> IdentityService: | |
| def trigger_service(self) -> TriggerService: | ||
| return self._trigger_service | ||
|
|
||
    @property
    def cluster_service(self) -> ClusterService:
        """Raw ClusterService client used for SelectCluster routing.

        NOTE(review): reviewer suggested not exposing this publicly — confirm
        whether any external caller actually needs it before keeping it here.
        """
        return self._cluster_service
|
|
||
    @property
    def endpoint(self) -> str:
        """Endpoint of the session this client set was created against."""
        return self._session_config.endpoint
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| """Tests for the ClusterAwareDataProxy wrapper in flyte.remote._client.controlplane.""" | ||
|
|
||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
|
||
| import pytest | ||
| from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 | ||
| from flyteidl2.common import identifier_pb2 | ||
| from flyteidl2.dataproxy import dataproxy_service_pb2 | ||
|
|
||
| from flyte.remote._client.controlplane import ClusterAwareDataProxy | ||
|
|
||
|
|
||
def _make_wrapper(
    cluster_endpoint: str = "",
    own_endpoint: str = "dns:///localhost:8090",
):
    """Build a ClusterAwareDataProxy wired to mocks.

    :param cluster_endpoint: endpoint SelectCluster will report; "" means
        "stay local" so the wrapper should use the default client.
    :param own_endpoint: endpoint the session config claims as its own.
    :return: (wrapper, cluster_service mock, default dataproxy client mock).
    """
    cluster_service = MagicMock()
    cluster_service.select_cluster = AsyncMock(
        return_value=cluster_payload_pb2.SelectClusterResponse(cluster_endpoint=cluster_endpoint)
    )
    session_config = MagicMock()
    session_config.endpoint = own_endpoint
    session_config.insecure = True
    session_config.insecure_skip_verify = False
    default_client = MagicMock()
    default_client.create_upload_location = AsyncMock(
        return_value=dataproxy_service_pb2.CreateUploadLocationResponse(signed_url="https://signed/")
    )
    default_client.upload_inputs = AsyncMock(return_value=dataproxy_service_pb2.UploadInputsResponse())
    return (
        ClusterAwareDataProxy(
            cluster_service=cluster_service,
            session_config=session_config,
            default_client=default_client,
        ),
        cluster_service,
        default_client,
    )
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_create_upload_location_routes_by_project():
    """CreateUploadLocation routes via SelectCluster using the request's project fields."""
    wrapper, cluster_service, default_client = _make_wrapper()
    req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o", filename="f")

    await wrapper.create_upload_location(req)

    # Exactly one SelectCluster call, carrying the upload-location operation
    # and a project_id assembled from the request's project/domain/org.
    cluster_service.select_cluster.assert_awaited_once()
    sent = cluster_service.select_cluster.await_args[0][0]
    assert sent.operation == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION
    assert sent.WhichOneof("resource") == "project_id"
    assert sent.project_id.name == "p"
    assert sent.project_id.domain == "d"
    assert sent.project_id.organization == "o"
    # Empty cluster_endpoint from the mock means "stay local": the default
    # client must have served the RPC with the original request.
    default_client.create_upload_location.assert_awaited_once_with(req)
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_upload_inputs_with_project_id_routes_by_project():
    """UploadInputs with an explicit project_id forwards that identifier verbatim."""
    wrapper, cluster_service, default_client = _make_wrapper()
    pid = identifier_pb2.ProjectIdentifier(name="p", domain="d", organization="o")
    req = dataproxy_service_pb2.UploadInputsRequest(project_id=pid)

    await wrapper.upload_inputs(req)

    # The SelectCluster request must carry the upload-inputs operation and
    # the exact project identifier from the request.
    sent = cluster_service.select_cluster.await_args[0][0]
    assert sent.operation == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS
    assert sent.project_id == pid
    default_client.upload_inputs.assert_awaited_once_with(req)
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_upload_inputs_with_run_id_routes_by_derived_project():
    """UploadInputs with a run_id routes by a ProjectIdentifier derived from it."""
    wrapper, cluster_service, default_client = _make_wrapper()
    run_id = identifier_pb2.RunIdentifier(org="o", project="p", domain="d", name="r")
    req = dataproxy_service_pb2.UploadInputsRequest(run_id=run_id)

    await wrapper.upload_inputs(req)

    # The wrapper cannot send a RunIdentifier to SelectCluster, so it must
    # have mapped run_id's project/domain/org onto a project_id resource.
    sent = cluster_service.select_cluster.await_args[0][0]
    assert sent.WhichOneof("resource") == "project_id"
    assert sent.project_id.name == "p"
    assert sent.project_id.domain == "d"
    assert sent.project_id.organization == "o"
    default_client.upload_inputs.assert_awaited_once_with(req)
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_upload_inputs_requires_id_oneof():
    """UploadInputs raises ValueError when neither run_id nor project_id is set."""
    wrapper, _, _ = _make_wrapper()
    with pytest.raises(ValueError):
        await wrapper.upload_inputs(dataproxy_service_pb2.UploadInputsRequest())
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_cache_hits_reuse_selected_client():
    """A repeated call for the same (operation, resource) skips SelectCluster."""
    wrapper, cluster_service, default_client = _make_wrapper()
    req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o", filename="f")

    await wrapper.create_upload_location(req)
    await wrapper.create_upload_location(req)

    # One routing decision, two dispatched RPCs: the second call hit the cache.
    assert cluster_service.select_cluster.await_count == 1
    assert default_client.create_upload_location.await_count == 2
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_cache_keyed_on_operation_and_resource():
    """Distinct operations or projects each trigger their own SelectCluster call."""
    wrapper, cluster_service, _ = _make_wrapper()

    # Same project, different operation: separate cache entry.
    await wrapper.create_upload_location(
        dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o")
    )
    await wrapper.upload_inputs(
        dataproxy_service_pb2.UploadInputsRequest(
            project_id=identifier_pb2.ProjectIdentifier(name="p", domain="d", organization="o"),
        )
    )
    # Same operation, different project: separate cache entry again.
    await wrapper.create_upload_location(
        dataproxy_service_pb2.CreateUploadLocationRequest(project="p2", domain="d", org="o")
    )

    assert cluster_service.select_cluster.await_count == 3
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_remote_cluster_endpoint_creates_new_client():
    """A foreign cluster endpoint yields a new, cached client; default is bypassed."""
    wrapper, cluster_service, default_client = _make_wrapper(cluster_endpoint="dns:///other:8090")

    new_client_inst = MagicMock()
    new_client_inst.create_upload_location = AsyncMock(
        return_value=dataproxy_service_pb2.CreateUploadLocationResponse(signed_url="https://remote/")
    )
    new_session_cfg = MagicMock()
    new_session_cfg.connect_kwargs.return_value = {}

    # Patch where the names are looked up (the controlplane module), so the
    # wrapper builds our mock session/client instead of dialing the network.
    with (
        patch(
            "flyte.remote._client.controlplane.create_session_config",
            new=AsyncMock(return_value=new_session_cfg),
        ),
        patch(
            "flyte.remote._client.controlplane.DataProxyServiceClient",
            return_value=new_client_inst,
        ),
    ):
        req = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o")
        await wrapper.create_upload_location(req)
        # Cached for a subsequent call
        await wrapper.create_upload_location(req)

    # One routing decision; both RPCs went to the freshly created remote
    # client and never touched the session's default client.
    assert cluster_service.select_cluster.await_count == 1
    assert new_client_inst.create_upload_location.await_count == 2
    default_client.create_upload_location.assert_not_awaited()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this just use `alru_cache` instead of the hand-rolled cache?