Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 112 additions & 1 deletion src/deisa/ray/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from __future__ import annotations
import copy
import logging
import sys
import time
from typing import Any, Dict, Mapping, Optional, Union
import numpy as np
import ray
Expand All @@ -18,7 +20,6 @@
from deisa.ray.scheduling_actor import SchedulingActor as _RealSchedulingActor
from deisa.ray.types import RayActorHandle
from deisa.ray.utils import get_node_actor_options, get_ray_address
import sys

logger = logging.getLogger(__name__)

Expand All @@ -29,6 +30,99 @@ def _validate_comm(comm: Any) -> None:
raise TypeError("comm must implement deisa.core.ICommunicator")


def _estimate_object_size_bytes(obj: Any) -> int:
    """
    Estimate the number of bytes occupied by ``obj``.

    Parameters
    ----------
    obj : Any
        Object whose size is estimated.

    Returns
    -------
    int
        For a NumPy array, the larger of its ``nbytes`` and its
        ``sys.getsizeof`` value; for anything else, ``sys.getsizeof(obj)``.
        ``sys.getsizeof`` is shallow, so objects referenced by ``obj`` are
        not counted.
    """
    shallow_size = sys.getsizeof(obj)
    if not isinstance(obj, np.ndarray):
        return shallow_size
    # A view reports a small shallow size although its data buffer is large,
    # so take whichever of the two figures is bigger.
    return max(int(obj.nbytes), shallow_size)


def _get_ray_object_store_memory() -> tuple[int, int]:
    """
    Read the current object store counters from Ray.

    Returns
    -------
    tuple[int, int]
        ``(used_bytes, total_bytes)``, taken from the
        ``object_store_bytes_used`` and ``object_store_bytes_avail`` fields
        of the store stats in Ray's memory-info reply.
    """
    # Imported inside the function, as in the original; this is a private Ray module.
    from ray._private import internal_api

    gcs_address = ray.get_runtime_context().gcs_address
    memory_reply = internal_api.get_memory_info_reply(internal_api.get_state_from_address(gcs_address))
    stats = memory_reply.store_stats
    # NOTE(review): ``object_store_bytes_avail`` is treated as the total
    # capacity by the caller -- confirm against the Ray version in use.
    return int(stats.object_store_bytes_used), int(stats.object_store_bytes_avail)


def _wait_for_object_store_memory(
    required_bytes: int,
    poll_interval: float = 0.5,
    threshold: float = 0.6,
    timeout_s: float | None = 100.0,
    bridge_rank: int | None = None,
    array_name: str | None = None,
    timestep: int | None = None,
) -> None:
    """
    Block until Ray's object store has enough room for a new object.

    Both conditions must hold before returning: the projected object store
    usage after inserting the object is below ``threshold`` and currently free
    object store memory can hold ``required_bytes``.

    Parameters
    ----------
    required_bytes : int
        Size of the object about to be inserted.
    poll_interval : float, optional
        Seconds to sleep between two checks of the object store. The last
        sleep is shortened so the wait does not run past ``timeout_s``.
    threshold : float, optional
        Upper bound (exclusive) on ``(used + required) / total``.
    timeout_s : float or None, optional
        Maximum time to wait. ``None`` disables the timeout.
    bridge_rank, array_name, timestep : optional
        Only used to give context in log and error messages.

    Raises
    ------
    MemoryError
        If the conditions are still not met once ``timeout_s`` seconds have
        elapsed.
    """
    start_time = time.monotonic()
    while True:
        used_bytes, total_bytes = _get_ray_object_store_memory()
        free_bytes = max(total_bytes - used_bytes, 0)

        if total_bytes > 0:
            used_pct = 100 * used_bytes / total_bytes
            projected_usage = (used_bytes + required_bytes) / total_bytes
        else:
            # No capacity reported: treat the store as full so we keep waiting.
            used_pct = 100
            projected_usage = 1

        if projected_usage < threshold and free_bytes >= required_bytes:
            return

        elapsed_s = time.monotonic() - start_time
        if timeout_s is not None and elapsed_s >= timeout_s:
            message = (
                "Insufficient Ray object store memory while bridge rank %s was sending chunk "
                "array=%r timestep=%r after %.1fs: available=%.2f GB, used=%.1f%%, "
                "projected_used=%.1f%%, needed=%.2f GB, threshold=%.1f%%"
            )
            # Built once so the log record and the exception text cannot drift apart.
            message_args = (
                bridge_rank,
                array_name,
                timestep,
                elapsed_s,
                free_bytes / 1e9,
                used_pct,
                100 * projected_usage,
                required_bytes / 1e9,
                100 * threshold,
            )
            logger.error(message, *message_args)
            raise MemoryError(message % message_args)

        logger.info(
            "Waiting for Ray object store memory before ray.put from bridge rank %s: "
            "array=%r, timestep=%r, available=%.2f GB, "
            "used=%.1f%%, projected_used=%.1f%%, needed=%.2f GB",
            bridge_rank,
            array_name,
            timestep,
            free_bytes / 1e9,
            used_pct,
            100 * projected_usage,
            required_bytes / 1e9,
        )

        # Never sleep past the deadline: without the clamp the wait could
        # overshoot ``timeout_s`` by up to one full ``poll_interval``.
        sleep_s = poll_interval
        if timeout_s is not None:
            sleep_s = min(poll_interval, timeout_s - elapsed_s)
        time.sleep(sleep_s)


class Bridge(IBridge):
"""
Bridge between MPI ranks and Ray cluster for distributed array processing.
Expand All @@ -51,6 +145,9 @@ class Bridge(IBridge):
:class:`deisa.ray.scheduling_actor.SchedulingActor`.
_init_retries : int, optional
Number of attempts to create and ready the node actor. Defaults to 3.
object_store_memory_timeout_s : float or None, optional
Maximum time to wait for Ray object store memory before raising
``MemoryError``. Defaults to 100 seconds. Set to ``None`` to wait forever.

Attributes
----------
Expand Down Expand Up @@ -109,6 +206,9 @@ def __init__(
Keys represent the name of the array while the values are
dictionaries that must at least declare the metadata expected by
:meth:`validate_arrays_meta`.
object_store_memory_timeout_s : float or None, optional
Maximum time to wait for Ray object store memory before raising
``MemoryError``. Defaults to 100 seconds. Set to ``None`` to wait forever.

Raises
------
Expand All @@ -133,12 +233,14 @@ def __init__(
_node_id: str | None = kwargs.pop("_node_id", None)
scheduling_actor_cls: ActorClass = kwargs.pop("scheduling_actor_cls", _RealSchedulingActor)
_init_retries: int = kwargs.pop("_init_retries", 3)
object_store_memory_timeout_s: float | None = kwargs.pop("object_store_memory_timeout_s", 100.0)
if kwargs:
unexpected = next(iter(kwargs))
raise TypeError(f"Bridge.__init__() got an unexpected keyword argument '{unexpected}'")

self._init_retries = _init_retries
self._closed = False
self.object_store_memory_timeout_s = object_store_memory_timeout_s

self.arrays_metadata = copy.deepcopy(validate_arrays_metadata(arrays_metadata))
if comm is None:
Expand Down Expand Up @@ -268,6 +370,13 @@ def send(
try:
chunk_dtype = chunk.dtype
# Setting the owner allows keeping the reference when the simulation script terminates.
_wait_for_object_store_memory(
_estimate_object_size_bytes(chunk),
bridge_rank=self.bridge_id,
array_name=array_name,
timestep=timestep,
timeout_s=self.object_store_memory_timeout_s,
)
ref = ray.put(chunk, _owner=self.node_actor)
future: ray.ObjectRef = self.node_actor.add_chunk.remote(
bridge_id=self.bridge_id,
Expand All @@ -280,6 +389,8 @@ def send(
ray.get(future)
except ContractError as e:
raise e
except MemoryError:
raise
except Exception as e:
_default_exception_handler(e)

Expand Down
5 changes: 0 additions & 5 deletions src/deisa/ray/head_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def __init__(self, max_simulation_ahead: int = 1, feedback_queue_size: int = 102
# TODO: document what this event signals and update documentation
self.new_array_created: dict[str, asyncio.Lock] = {}
self.max_simulation_ahead = max_simulation_ahead
self.semaphore_per_array = {}

if feedback_queue_size <= 0:
raise ValueError(f"feedback_queue_size must be > 0, got {feedback_queue_size}")
Expand Down Expand Up @@ -175,7 +174,6 @@ def register_partial_array(

if array_name not in self.registered_arrays:
self.registered_arrays[array_name] = DaskArrayData(array_name)
self.semaphore_per_array[array_name] = asyncio.Semaphore(self.max_simulation_ahead)
self.new_array_created[array_name] = asyncio.Lock()

array = self.registered_arrays[array_name]
Expand Down Expand Up @@ -314,7 +312,6 @@ async def chunks_ready(
"""
array = self.registered_arrays[array_name]
array.update_dtype(timestep, dtype)
semaphore = self.semaphore_per_array[array_name]
lock = self.new_array_created[array_name]
creator = False
future = None
Expand All @@ -331,7 +328,6 @@ async def chunks_ready(
future = entry

if creator:
await semaphore.acquire()
async with lock:
array.chunk_refs[timestep] = []
future.set_result(True)
Expand Down Expand Up @@ -405,5 +401,4 @@ async def get_next_array(self) -> tuple[str, Timestep, da.Array]:
components to pull work in order.
"""
array = await self.arrays_ready.get()
self.semaphore_per_array[array[0]].release()
return array
25 changes: 16 additions & 9 deletions src/deisa/ray/window_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import deque, defaultdict
import gc
import logging
import time
from typing import Any, Callable, List, Optional, Literal

import dask
Expand All @@ -20,7 +21,7 @@
CallbackArgs,
_CallbackConfig,
)
from deisa.ray.utils import get_head_actor_options, get_ray_address
from deisa.ray.utils import get_head_actor_options

Callback = IDeisa.Callback
ExceptionHandler = IDeisa.ExceptionHandler
Expand All @@ -33,15 +34,21 @@ def _ray_start_impl() -> None:

Notes
-----
Initializes Ray only once with minimal logging. Used when the caller
does not provide a custom ``ray_start`` hook.
Initializes Ray only once with minimal logging. If the Ray runtime is not
Comment thread
theabm marked this conversation as resolved.
Outdated
ready yet, retries for a short window before surfacing the startup error.
"""
if not ray.is_initialized():
ray.init(
address=get_ray_address() or "auto",
log_to_driver=False,
logging_level=logging.ERROR,
)
timeout_secconds = 10.0
interval_retry = 1.0
deadline = time.monotonic() + timeout_secconds
while not ray.is_initialized():
try:
ray.init(address="auto", log_to_driver=False, logging_level=logging.ERROR)
return
except Exception:
remaining = deadline - time.monotonic()
if remaining <= 0:
raise
time.sleep(min(interval_retry, remaining))


def _with_timestep(array: da.Array, timestep: int) -> DeisaArray:
Expand Down
59 changes: 54 additions & 5 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from deisa.ray import window_handler
from deisa.ray.window_handler import Deisa
from deisa.ray.config import (
DEISA_DISTRIBUTED_SCHEDULING_ENV,
Expand Down Expand Up @@ -59,11 +60,59 @@ def ray_start():
assert d._ray_start is ray_start


# TODO remove when memory handling is done from bridge size checking that
# ray.put can happen because enough memory is available.
def test_max_simulation_ahead_is_read_from_kwargs():
d = Deisa(max_simulation_ahead=2)
assert d.max_simulation_ahead == 2
def test_default_ray_start_retries_until_ray_initializes(monkeypatch):
    """The default Ray start hook keeps calling ``ray.init`` until it stops raising."""
    # Two failures are queued; the third ``ray.init`` call succeeds.
    pending_failures = iter(
        [ConnectionError("ray runtime is not ready"), ConnectionError("still not ready")]
    )
    recorded_kwargs = []
    recorded_sleeps = []

    def fake_init(**kwargs):
        recorded_kwargs.append(kwargs)
        failure = next(pending_failures, None)
        if failure is not None:
            raise failure

    monkeypatch.setattr(window_handler.ray, "is_initialized", lambda: False)
    monkeypatch.setattr(window_handler.time, "sleep", recorded_sleeps.append)
    monkeypatch.setattr(window_handler.ray, "init", fake_init)

    window_handler._ray_start_impl()

    expected_kwargs = {
        "address": "auto",
        "log_to_driver": False,
        "logging_level": window_handler.logging.ERROR,
    }
    assert len(recorded_kwargs) == 3
    assert recorded_sleeps == [1.0, 1.0]
    assert recorded_kwargs[0] == expected_kwargs


def test_default_ray_start_raises_after_retry_timeout(monkeypatch):
    """The default Ray start hook re-raises the init error once its retry window is spent."""
    # Fake clock: time only advances when the code under test sleeps.
    clock = {"now": 0.0}
    attempts = []

    def fake_sleep(seconds):
        clock["now"] += seconds

    def fake_init(**kwargs):
        attempts.append(kwargs)
        raise ConnectionError("ray runtime is not ready")

    monkeypatch.setattr(window_handler.ray, "is_initialized", lambda: False)
    monkeypatch.setattr(window_handler.time, "monotonic", lambda: clock["now"])
    monkeypatch.setattr(window_handler.time, "sleep", fake_sleep)
    monkeypatch.setattr(window_handler.ray, "init", fake_init)

    with pytest.raises(ConnectionError, match="ray runtime is not ready"):
        window_handler._ray_start_impl()

    # Ten one-second sleeps exhaust the window; the eleventh attempt surfaces the error.
    assert clock["now"] == 10.0
    assert len(attempts) == 11


def test_unexpected_init_kwarg_is_rejected():
Expand Down
Loading