diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 44c03858ff..55ebdfa52c 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -3,6 +3,7 @@ import inspect import logging import uuid +import weakref from argparse import Namespace from collections.abc import Callable from contextlib import contextmanager @@ -91,9 +92,18 @@ def __init__(self, args: Namespace) -> None: self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - self.semaphore = asyncio.Semaphore( + # Concurrency semaphore must be created per-event-loop: asyncio + # primitives bind to the running loop on first use (or at construction + # on older Pythons), and Ray actors can serve requests on different loops + # across re-entries (e.g. between rollout and eval). Defer + # construction to the `semaphore` property below, which lazily + # (re)builds on first access and on any subsequent loop change. + self._semaphore_value = ( args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) + self._semaphore: asyncio.Semaphore | None = None + self._semaphore_loop_ref: weakref.ref | None = None + self.sampling_params: dict[str, Any] = dict( temperature=args.rollout_temperature, top_p=args.rollout_top_p, @@ -128,6 +138,32 @@ def dp_rank_context(self): self.dp_counts[dp_rank] -= 1 assert self.dp_counts[dp_rank] >= 0 + @property + def semaphore(self) -> asyncio.Semaphore: + # (Re)bind on event-loop change (see __init__ comment). Dropping the + # old semaphore is safe: asyncio.Semaphore holds no OS resources, and + # coroutines already inside `async with state.semaphore:` on the old + # loop keep a direct reference to the old object for their release; + # only new callers see the rebuilt one. 
+ # + # Loop identity is tracked via weakref rather than id(), because + # id() is a memory address and CPython's allocator can recycle it for + # a new loop object after the old one is GC'd. Weakref goes dead when + # the old loop is collected, forcing a rebuild. + # + # Concurrency note: during a loop transition, in-flight coroutines + # on the old loop still hold the old semaphore, while the new loop's + # first caller rebuilds. Across the handoff window the effective + # in-flight cap is transiently (old_limit + new_limit), not + # sglang_server_concurrency. Loop transitions are rare and SGLang + # handles backpressure, so this is acceptable. + current_loop = asyncio.get_running_loop() + stored_loop = self._semaphore_loop_ref() if self._semaphore_loop_ref is not None else None + if self._semaphore is None or stored_loop is not current_loop: + self._semaphore = asyncio.Semaphore(self._semaphore_value) + self._semaphore_loop_ref = weakref.ref(current_loop) + return self._semaphore + def reset(self) -> None: self.remaining_batch_size = 0 self.pendings = set() diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index ede851f6b0..1f8e7f71d6 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -6,6 +6,7 @@ import os import random import socket +import weakref import httpx @@ -145,6 +146,7 @@ def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> _http_client: httpx.AsyncClient | None = None +_http_client_loop_ref: weakref.ref | None = None _client_concurrency: int = 0 # Optional Ray-based distributed POST dispatch @@ -199,25 +201,70 @@ async def _post(client, url, payload, max_retries=60, headers=None): def init_http_client(args): - """Initialize HTTP client and optionally enable distributed POST via Ray.""" - global _http_client, _client_concurrency, _distributed_post_enabled + """Record HTTP-client concurrency and optionally enable distributed POST. 
+ + The actual ``httpx.AsyncClient`` is built lazily per event loop by + ``_get_http_client`` on first use, so Ray actors that hop loops get a + fresh client bound to the active loop. + """ + global _client_concurrency, _distributed_post_enabled if not args.rollout_num_gpus: return _client_concurrency = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine - if _http_client is None: - _http_client = httpx.AsyncClient( - limits=httpx.Limits(max_connections=_client_concurrency), - timeout=httpx.Timeout(None), - trust_env=False, # internal SGLang comm only — never route through system proxy + if _client_concurrency <= 0: + # Not fatal (we fall back to 1 in _get_http_client) but almost + # certainly a configuration bug; surface it in logs rather than + # silently throttling every rollout to 1 in-flight request. + logger.warning( + "_client_concurrency computed as %d from sglang_server_concurrency=%s, " + "rollout_num_gpus=%s, rollout_num_gpus_per_engine=%s; check your config.", + _client_concurrency, + args.sglang_server_concurrency, + args.rollout_num_gpus, + args.rollout_num_gpus_per_engine, ) - # Optionally initialize distributed POST via Ray without changing interfaces if args.use_distributed_post: _init_ray_distributed_post(args) _distributed_post_enabled = True +def _get_http_client() -> httpx.AsyncClient: + """Get or (re)create the HTTP client for the current event loop. + + ``httpx.AsyncClient`` binds its connection pool (and the internal + asyncio locks it uses) to the loop that created it. Ray actors can + serve calls on different loops across re-entries (e.g. rollout → + eval), so we detect a loop change and rebuild. + + Old-client cleanup: we deliberately do *not* call ``aclose()`` on the + old client. It was bound to a loop that may already be dead; running + ``aclose()`` from a different loop races on the pool's internal + asyncio locks and can leave sockets half-closed. 
We drop the reference + and let GC close the underlying sockets when the client is collected + (httpx emits a ResourceWarning; expected). A socket leak across a + rare loop transition is cheaper than a partial-aclose data race. + + Thread/concurrency safety: safe under asyncio because there is no + ``await`` between the loop-change check and the assignment to + ``_http_client``, so two concurrent callers on the same loop can't + both pass the check and race on construction. Any future refactor + that inserts an ``await`` here must add a per-loop lock. + """ + global _http_client, _http_client_loop_ref + current_loop = asyncio.get_running_loop() + stored_loop = _http_client_loop_ref() if _http_client_loop_ref is not None else None + if _http_client is None or stored_loop is not current_loop: + _http_client = httpx.AsyncClient( + limits=httpx.Limits(max_connections=max(_client_concurrency, 1)), + timeout=httpx.Timeout(None), + trust_env=False, # internal SGLang comm only — never route through system proxy + ) + _http_client_loop_ref = weakref.ref(current_loop) + return _http_client + + def _init_ray_distributed_post(args): """Initialize one or more Ray async actors per node for HTTP POST. 
@@ -287,11 +334,11 @@ async def post(url, payload, max_retries=60, headers=None): logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries, headers=headers) + return await _post(_get_http_client(), url, payload, max_retries, headers=headers) async def get(url): - response = await _http_client.get(url) + response = await _get_http_client().get(url) response.raise_for_status() content = await response.aread() output = json.loads(content) diff --git a/tests/test_http_utils_loop_rebind.py b/tests/test_http_utils_loop_rebind.py new file mode 100644 index 0000000000..e3546999f6 --- /dev/null +++ b/tests/test_http_utils_loop_rebind.py @@ -0,0 +1,148 @@ +"""Regression tests for the Semaphore / HTTP-client event-loop rebind. + +Without the rebind, re-entering the rollout actor on a different asyncio +loop (e.g. rollout -> eval) causes either ``RuntimeError: <...> is bound +to a different event loop`` or a silent hang. These tests exercise that +transition by running two separate ``asyncio.run(...)`` calls (each +spins up a fresh loop) and asserting that: + + 1. The cached primitive is rebuilt on the second call + (identity check: ``is not``). + 2. Pre-patch behavior — the primitive stays pinned to the first loop + — is what the patched code avoids; this is implicit in the test + (if the cache weren't rebuilt, the second-loop user would crash + or hang). + +CPU-only; no Ray, no GPU. 
+""" + +import asyncio +import types + +import pytest + + +# --------------------------------------------------------------------------- +# GenerateState.semaphore +# --------------------------------------------------------------------------- + + +def _make_state(): + """Build a GenerateState with the minimum args it reads during init, + skipping the tokenizer / processor / sampling-params setup that would + require a real HF checkpoint.""" + from slime.rollout.sglang_rollout import GenerateState + from slime.utils.misc import SingletonMeta + + # Reset singleton so each test starts fresh. + SingletonMeta._instances.pop(GenerateState, None) + + args = types.SimpleNamespace( + sglang_server_concurrency=2, + rollout_num_gpus=4, + rollout_num_gpus_per_engine=2, + ) + + state = GenerateState.__new__(GenerateState) + state.args = args + state._semaphore_value = ( + args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + ) + state._semaphore = None + state._semaphore_loop_ref = None + return state + + +def test_semaphore_rebinds_across_fresh_event_loops(): + """Running `touch()` in two separate asyncio.run() calls spins up two + different loops; the second call must see a fresh Semaphore bound to + the new loop.""" + state = _make_state() + + captured = [] + + async def touch(): + # Acquire/release to force the Semaphore to bind to this loop. + async with state.semaphore: + pass + captured.append(state.semaphore) + + asyncio.run(touch()) + first_sem = captured[-1] + + asyncio.run(touch()) + second_sem = captured[-1] + + assert first_sem is not second_sem, ( + "Semaphore was not rebuilt on second asyncio.run(); the event-loop " + "rebind is not firing." 
+ ) + + +def test_semaphore_reused_within_same_loop(): + """Within a single asyncio.run(), repeated property access returns the + same Semaphore — otherwise concurrent callers would see different + objects and the concurrency cap would not hold.""" + state = _make_state() + + async def run(): + sem_a = state.semaphore + async with sem_a: + pass + sem_b = state.semaphore + return sem_a, sem_b + + a, b = asyncio.run(run()) + assert a is b + + +# --------------------------------------------------------------------------- +# _get_http_client +# --------------------------------------------------------------------------- + + +def test_http_client_rebinds_across_fresh_event_loops(): + """Analog of the semaphore test for ``_get_http_client``. The client + is module-global; we reset it per test.""" + import slime.utils.http_utils as http_utils + + http_utils._http_client = None + http_utils._http_client_loop_ref = None + http_utils._client_concurrency = 4 # any positive + + captured = [] + + async def touch(): + client = http_utils._get_http_client() + captured.append(client) + + asyncio.run(touch()) + first_client = captured[-1] + + asyncio.run(touch()) + second_client = captured[-1] + + assert first_client is not second_client, ( + "httpx.AsyncClient was not rebuilt on second asyncio.run(); the " + "event-loop rebind is not firing." + ) + + +def test_http_client_reused_within_same_loop(): + import slime.utils.http_utils as http_utils + + http_utils._http_client = None + http_utils._http_client_loop_ref = None + http_utils._client_concurrency = 4 + + async def run(): + a = http_utils._get_http_client() + b = http_utils._get_http_client() + return a, b + + a, b = asyncio.run(run()) + assert a is b + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])