Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/flyte/remote/_client/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from flyteidl2.app import app_payload_pb2
from flyteidl2.auth import identity_pb2
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2
from flyteidl2.dataproxy import dataproxy_service_pb2
from flyteidl2.project import project_service_pb2
from flyteidl2.secret import payload_pb2
Expand Down Expand Up @@ -134,6 +135,12 @@ class IdentityService(Protocol):
async def user_info(self, request: identity_pb2.UserInfoRequest) -> identity_pb2.UserInfoResponse: ...


class ClusterService(Protocol):
    """Structural interface for the cluster-selection RPC service."""

    async def select_cluster(
        self,
        request: cluster_payload_pb2.SelectClusterRequest,
    ) -> cluster_payload_pb2.SelectClusterResponse: ...


class TriggerService(Protocol):
async def deploy_trigger(
self, request: trigger_service_pb2.DeployTriggerRequest
Expand Down
114 changes: 113 additions & 1 deletion src/flyte/remote/_client/controlplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

from flyteidl2.app.app_service_connect import AppServiceClient
from flyteidl2.auth.identity_connect import IdentityServiceClient
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2
from flyteidl2.cluster.service_connect import ClusterServiceClient
from flyteidl2.common import identifier_pb2
from flyteidl2.dataproxy import dataproxy_service_pb2
from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient
from flyteidl2.project.project_service_connect import ProjectServiceClient
from flyteidl2.secret.secret_connect import SecretServiceClient
Expand All @@ -14,6 +18,7 @@

from ._protocols import (
AppService,
ClusterService,
DataProxyService,
IdentityService,
ProjectDomainService,
Expand Down Expand Up @@ -161,6 +166,99 @@ def insecure(self) -> bool:
return self._insecure


class ClusterAwareDataProxy:
    """DataProxy client that routes each call to the correct cluster.

    Implements the DataProxyService protocol. For every RPC, extracts the target
    resource from the request, calls ClusterService.SelectCluster to discover
    the cluster endpoint, and dispatches to a DataProxyServiceClient pointing at
    that endpoint. Cluster resolutions are cached by (operation, resource) so
    repeated calls against the same resource skip the SelectCluster round trip,
    and clients are additionally de-duplicated by endpoint so distinct resources
    that live on the same cluster share one underlying connection.
    """

    def __init__(
        self,
        cluster_service: ClusterService,
        session_config: SessionConfig,
        default_client: DataProxyServiceClient,
    ):
        self._cluster_service = cluster_service
        self._session_config = session_config
        self._default_client = default_client
        # (operation, serialized resource) -> resolved client. Entries may alias
        # each other (or the default client) through the per-endpoint cache below.
        self._cache: dict[bytes, DataProxyService] = {}
        # endpoint -> client, so resources co-located on one cluster reuse a connection.
        self._clients_by_endpoint: dict[str, DataProxyService] = {}

    async def create_upload_location(
        self, request: dataproxy_service_pb2.CreateUploadLocationRequest
    ) -> dataproxy_service_pb2.CreateUploadLocationResponse:
        """Route CreateUploadLocation to the cluster owning (org, project, domain)."""
        project_id = identifier_pb2.ProjectIdentifier(
            name=request.project, domain=request.domain, organization=request.org
        )
        client = await self._resolve(
            cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION,
            project_id,
        )
        return await client.create_upload_location(request)

    async def upload_inputs(
        self, request: dataproxy_service_pb2.UploadInputsRequest
    ) -> dataproxy_service_pb2.UploadInputsResponse:
        """Route UploadInputs by its project, derived from the run id if needed.

        Raises:
            ValueError: if the request sets neither ``run_id`` nor ``project_id``.
        """
        which = request.WhichOneof("id")
        if which == "run_id":
            # SelectClusterRequest.resource doesn't include RunIdentifier; route by project.
            project_id = identifier_pb2.ProjectIdentifier(
                name=request.run_id.project,
                domain=request.run_id.domain,
                organization=request.run_id.org,
            )
        elif which == "project_id":
            project_id = request.project_id
        else:
            raise ValueError("UploadInputsRequest must set either run_id or project_id")
        client = await self._resolve(
            cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS,
            project_id,
        )
        return await client.upload_inputs(request)

    async def _resolve(
        self,
        operation: cluster_payload_pb2.SelectClusterRequest.Operation,
        project_id: identifier_pb2.ProjectIdentifier,
    ) -> DataProxyService:
        """Return the DataProxy client for (operation, project), resolving and caching it.

        Falls back to the default client when SelectCluster returns no endpoint,
        or an endpoint equal to this session's own.

        Raises:
            RuntimeError: if SelectCluster fails or the cluster session cannot be created.
        """
        from flyte._logging import logger

        # Deterministic serialization keeps the key byte-stable for equal messages.
        cache_key = int(operation).to_bytes(4, "little") + project_id.SerializeToString(deterministic=True)
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        req = cluster_payload_pb2.SelectClusterRequest(operation=operation)
        req.project_id.CopyFrom(project_id)
        try:
            resp = await self._cluster_service.select_cluster(req)
        except Exception as e:
            raise RuntimeError(f"SelectCluster failed for operation={operation}: {e}") from e

        endpoint = resp.cluster_endpoint
        if not endpoint or endpoint == self._session_config.endpoint:
            self._cache[cache_key] = self._default_client
            return self._default_client

        # Reuse an already-built client for this endpoint instead of opening a
        # new connection for every (operation, resource) that maps to it.
        existing = self._clients_by_endpoint.get(endpoint)
        if existing is not None:
            self._cache[cache_key] = existing
            return existing

        try:
            new_cfg = await create_session_config(
                endpoint,
                insecure=self._session_config.insecure,
                insecure_skip_verify=self._session_config.insecure_skip_verify,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to create session for cluster endpoint '{endpoint}': {e}") from e

        logger.debug(f"Created DataProxy client for cluster endpoint: {endpoint}")
        client = DataProxyServiceClient(**new_cfg.connect_kwargs())
        self._clients_by_endpoint[endpoint] = client
        self._cache[cache_key] = client
        return client


class ClientSet:
def __init__(self, session_cfg: SessionConfig):
self._console = Console(session_cfg.endpoint, session_cfg.insecure)
Expand All @@ -170,11 +268,16 @@ def __init__(self, session_cfg: SessionConfig):
self._task_service = TaskServiceClient(**shared)
self._app_service = AppServiceClient(**shared)
self._run_service = RunServiceClient(**shared)
self._dataproxy = DataProxyServiceClient(**shared)
self._log_service = RunLogsServiceClient(**shared)
self._secrets_service = SecretServiceClient(**shared)
self._identity_service = IdentityServiceClient(**shared)
self._trigger_service = TriggerServiceClient(**shared)
self._cluster_service = ClusterServiceClient(**shared)
self._dataproxy = ClusterAwareDataProxy(
cluster_service=self._cluster_service,
session_config=session_cfg,
default_client=DataProxyServiceClient(**shared),
)

@classmethod
async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
Expand Down Expand Up @@ -214,6 +317,11 @@ def run_service(self) -> RunService:

@property
def dataproxy_service(self) -> DataProxyService:
    """DataProxy client with cluster-aware routing.

    Each RPC is dispatched to whichever cluster ClusterService.SelectCluster
    picks for the request's target resource, with per-cluster clients cached.
    """
    return self._dataproxy

@property
Expand All @@ -232,6 +340,10 @@ def identity_service(self) -> IdentityService:
def trigger_service(self) -> TriggerService:
return self._trigger_service

@property
def cluster_service(self) -> ClusterService:
    """Client for the cluster-selection RPC service."""
    return self._cluster_service

@property
def endpoint(self) -> str:
    """Control-plane endpoint this client set was configured against."""
    return self._session_config.endpoint
Expand Down
1 change: 1 addition & 0 deletions src/flyte/remote/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ async def _upload_single_file(
dataproxy_service_pb2.CreateUploadLocationRequest(
project=cfg.project,
domain=cfg.domain,
org=cfg.org or "",
content_md5=md5_bytes,
filename=fname or fp.name,
expires_in=expires_in_pb,
Expand Down
154 changes: 154 additions & 0 deletions tests/flyte/remote/test_cluster_aware_dataproxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Tests for the ClusterAwareDataProxy wrapper in flyte.remote._client.controlplane."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2
from flyteidl2.common import identifier_pb2
from flyteidl2.dataproxy import dataproxy_service_pb2

from flyte.remote._client.controlplane import ClusterAwareDataProxy


def _make_wrapper(
    cluster_endpoint: str = "",
    own_endpoint: str = "dns:///localhost:8090",
):
    """Build a ClusterAwareDataProxy wired to mocks.

    Returns a (wrapper, cluster_service_mock, default_client_mock) triple.
    """
    cluster_svc = MagicMock()
    cluster_svc.select_cluster = AsyncMock(
        return_value=cluster_payload_pb2.SelectClusterResponse(cluster_endpoint=cluster_endpoint)
    )

    cfg = MagicMock()
    cfg.endpoint = own_endpoint
    cfg.insecure = True
    cfg.insecure_skip_verify = False

    fallback = MagicMock()
    fallback.create_upload_location = AsyncMock(
        return_value=dataproxy_service_pb2.CreateUploadLocationResponse(signed_url="https://signed/")
    )
    fallback.upload_inputs = AsyncMock(return_value=dataproxy_service_pb2.UploadInputsResponse())

    wrapper = ClusterAwareDataProxy(
        cluster_service=cluster_svc,
        session_config=cfg,
        default_client=fallback,
    )
    return wrapper, cluster_svc, fallback


@pytest.mark.asyncio
async def test_create_upload_location_routes_by_project():
    """CreateUploadLocation selects a cluster by the request's (org, project, domain)."""
    proxy, cluster_svc, fallback = _make_wrapper()
    request = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o", filename="f")

    await proxy.create_upload_location(request)

    cluster_svc.select_cluster.assert_awaited_once()
    (select_req,) = cluster_svc.select_cluster.await_args.args
    assert select_req.operation == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_CREATE_UPLOAD_LOCATION
    assert select_req.WhichOneof("resource") == "project_id"
    assert select_req.project_id.name == "p"
    assert select_req.project_id.domain == "d"
    assert select_req.project_id.organization == "o"
    fallback.create_upload_location.assert_awaited_once_with(request)


@pytest.mark.asyncio
async def test_upload_inputs_with_project_id_routes_by_project():
    """UploadInputs with an explicit project_id forwards it verbatim to SelectCluster."""
    proxy, cluster_svc, fallback = _make_wrapper()
    project = identifier_pb2.ProjectIdentifier(name="p", domain="d", organization="o")
    request = dataproxy_service_pb2.UploadInputsRequest(project_id=project)

    await proxy.upload_inputs(request)

    (select_req,) = cluster_svc.select_cluster.await_args.args
    assert select_req.operation == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_UPLOAD_INPUTS
    assert select_req.project_id == project
    fallback.upload_inputs.assert_awaited_once_with(request)


@pytest.mark.asyncio
async def test_upload_inputs_with_run_id_routes_by_derived_project():
    """UploadInputs with only a run_id derives a ProjectIdentifier for routing."""
    proxy, cluster_svc, fallback = _make_wrapper()
    rid = identifier_pb2.RunIdentifier(org="o", project="p", domain="d", name="r")
    request = dataproxy_service_pb2.UploadInputsRequest(run_id=rid)

    await proxy.upload_inputs(request)

    (select_req,) = cluster_svc.select_cluster.await_args.args
    assert select_req.WhichOneof("resource") == "project_id"
    assert select_req.project_id.name == "p"
    assert select_req.project_id.domain == "d"
    assert select_req.project_id.organization == "o"
    fallback.upload_inputs.assert_awaited_once_with(request)


@pytest.mark.asyncio
async def test_upload_inputs_requires_id_oneof():
    """A request with neither run_id nor project_id set is rejected."""
    proxy, _, _ = _make_wrapper()
    empty_request = dataproxy_service_pb2.UploadInputsRequest()
    with pytest.raises(ValueError):
        await proxy.upload_inputs(empty_request)


@pytest.mark.asyncio
async def test_cache_hits_reuse_selected_client():
    """A repeated call for the same (operation, resource) skips SelectCluster."""
    proxy, cluster_svc, fallback = _make_wrapper()
    request = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o", filename="f")

    for _ in range(2):
        await proxy.create_upload_location(request)

    assert cluster_svc.select_cluster.await_count == 1
    assert fallback.create_upload_location.await_count == 2


@pytest.mark.asyncio
async def test_cache_keyed_on_operation_and_resource():
    """Distinct operations or resources each trigger their own SelectCluster call."""
    proxy, cluster_svc, _ = _make_wrapper()
    same_project = identifier_pb2.ProjectIdentifier(name="p", domain="d", organization="o")

    # Same resource, two different operations, then a different resource.
    await proxy.create_upload_location(
        dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o")
    )
    await proxy.upload_inputs(dataproxy_service_pb2.UploadInputsRequest(project_id=same_project))
    await proxy.create_upload_location(
        dataproxy_service_pb2.CreateUploadLocationRequest(project="p2", domain="d", org="o")
    )

    assert cluster_svc.select_cluster.await_count == 3


@pytest.mark.asyncio
async def test_remote_cluster_endpoint_creates_new_client():
    """A non-local cluster endpoint gets a dedicated client, cached for reuse."""
    proxy, cluster_svc, fallback = _make_wrapper(cluster_endpoint="dns:///other:8090")

    remote_client = MagicMock()
    remote_client.create_upload_location = AsyncMock(
        return_value=dataproxy_service_pb2.CreateUploadLocationResponse(signed_url="https://remote/")
    )
    remote_cfg = MagicMock()
    remote_cfg.connect_kwargs.return_value = {}

    patched_cfg = patch(
        "flyte.remote._client.controlplane.create_session_config",
        new=AsyncMock(return_value=remote_cfg),
    )
    patched_client = patch(
        "flyte.remote._client.controlplane.DataProxyServiceClient",
        return_value=remote_client,
    )
    with patched_cfg, patched_client:
        request = dataproxy_service_pb2.CreateUploadLocationRequest(project="p", domain="d", org="o")
        await proxy.create_upload_location(request)
        # Second identical call must hit the cache rather than re-resolving.
        await proxy.create_upload_location(request)

    assert cluster_svc.select_cluster.await_count == 1
    assert remote_client.create_upload_location.await_count == 2
    fallback.create_upload_location.assert_not_awaited()
Loading