-
Notifications
You must be signed in to change notification settings - Fork 36
Add dataproxy client selector #959
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
714bad8
c7521a6
1a49dc8
a1fcbf0
69dd1fc
54a1a07
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
|
|
||
| from flyteidl2.app.app_service_connect import AppServiceClient | ||
| from flyteidl2.auth.identity_connect import IdentityServiceClient | ||
| from flyteidl2.cluster.service_connect import ClusterServiceClient | ||
| from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient | ||
| from flyteidl2.project.project_service_connect import ProjectServiceClient | ||
| from flyteidl2.secret.secret_connect import SecretServiceClient | ||
|
|
@@ -14,6 +15,7 @@ | |
|
|
||
| from ._protocols import ( | ||
| AppService, | ||
| ClusterService, | ||
| DataProxyService, | ||
| IdentityService, | ||
| ProjectDomainService, | ||
|
|
@@ -175,6 +177,8 @@ def __init__(self, session_cfg: SessionConfig): | |
| self._secrets_service = SecretServiceClient(**shared) | ||
| self._identity_service = IdentityServiceClient(**shared) | ||
| self._trigger_service = TriggerServiceClient(**shared) | ||
| self._cluster_service = ClusterServiceClient(**shared) | ||
| self._dataproxy_cache: dict[str, DataProxyServiceClient] = {} | ||
|
|
||
| @classmethod | ||
| async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet: | ||
|
|
@@ -232,6 +236,59 @@ def identity_service(self) -> IdentityService: | |
| def trigger_service(self) -> TriggerService: | ||
| return self._trigger_service | ||
|
|
||
| @property | ||
| def cluster_service(self) -> ClusterService: | ||
| return self._cluster_service | ||
|
|
||
| async def get_dataproxy_for_resource(self, operation: int, resource: object) -> DataProxyService: | ||
| """Get a DataProxy client routed to the correct cluster for the given resource. | ||
|
|
||
| Calls SelectCluster to discover the cluster endpoint, then creates (or reuses | ||
| from cache) a DataProxyServiceClient pointing at that endpoint. | ||
|
|
||
| Args: | ||
| operation: SelectClusterRequest.Operation enum value. | ||
| resource: A protobuf identifier (e.g. RunIdentifier, ProjectIdentifier). | ||
| """ | ||
| from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2 | ||
|
|
||
| cache_key = f"{operation}:{type(resource).__name__}:{resource}" | ||
|
|
||
| if cache_key in self._dataproxy_cache: | ||
| return self._dataproxy_cache[cache_key] | ||
|
|
||
| # Build the SelectClusterRequest with the right oneof field | ||
| req = cluster_payload_pb2.SelectClusterRequest(operation=operation) | ||
| if hasattr(resource, "DESCRIPTOR"): | ||
| field_map = { | ||
| "OrgIdentifier": "org_id", | ||
| "ProjectIdentifier": "project_id", | ||
| "TaskIdentifier": "task_id", | ||
| "ActionIdentifier": "action_id", | ||
| "ActionAttemptIdentifier": "action_attempt_id", | ||
| } | ||
| field_name = field_map.get(type(resource).__name__) | ||
| if field_name: | ||
| getattr(req, field_name).CopyFrom(resource) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this the best way to populate the oneof field 😬, @wild-endeavor @pingsutw?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can do something like:

    req = SelectClusterRequest(operation=operation)
    if hasattr(resource, "DESCRIPTOR"):
        oneof = req.DESCRIPTOR.oneofs_by_name["resource"]  # replace with actual oneof name
        for field in oneof.fields:
            if field.message_type is resource.DESCRIPTOR:
                getattr(req, field.name).CopyFrom(resource)
                break
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice, thanks. Done. |
||
|
|
||
| resp = await self._cluster_service.select_cluster(req) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's make sure we raise an informative error here for the infamous "no healthy clusters" case. We now have a good place to catch that failure and surface a clear error to the user.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated. |
||
| cluster_endpoint = resp.cluster_endpoint | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are you going to normalize the endpoint here (stripping or adding an `http://`/`https://` or `dns:///` prefix)?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not necessary! `create_session_config` already calls `normalize_rpc_endpoint`. |
||
|
|
||
| # If the cluster endpoint matches our own endpoint, reuse the default client | ||
| if not cluster_endpoint or cluster_endpoint == self._session_config.endpoint: | ||
| self._dataproxy_cache[cache_key] = self._dataproxy | ||
| return self._dataproxy | ||
|
|
||
| # Create a new session config for the cluster endpoint | ||
| new_session_cfg = await create_session_config( | ||
| cluster_endpoint, | ||
| insecure=self._session_config.insecure, | ||
| insecure_skip_verify=self._session_config.insecure_skip_verify, | ||
| ) | ||
| client = DataProxyServiceClient(**new_session_cfg.connect_kwargs()) | ||
| self._dataproxy_cache[cache_key] = client | ||
| return client | ||
|
|
||
| @property | ||
| def endpoint(self) -> str: | ||
| return self._session_config.endpoint | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the `operation` parameter be typed as an enum instead of an `int`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done