Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion slime/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import logging
import uuid
import weakref
from argparse import Namespace
from collections.abc import Callable
from contextlib import contextmanager
Expand Down Expand Up @@ -91,9 +92,18 @@ def __init__(self, args: Namespace) -> None:
self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True)
self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True)

self.semaphore = asyncio.Semaphore(
# Concurrency semaphore must be created per-event-loop: asyncio
# primitives bind to the loop that was running when they were
# constructed, and Ray actors can serve requests on different loops
# across re-entries (e.g. between rollout and eval). Defer
# construction to the `semaphore` property below, which lazily
# (re)builds on first access and on any subsequent loop change.
self._semaphore_value = (
args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine
)
self._semaphore: asyncio.Semaphore | None = None
self._semaphore_loop_ref: weakref.ref | None = None

self.sampling_params: dict[str, Any] = dict(
temperature=args.rollout_temperature,
top_p=args.rollout_top_p,
Expand Down Expand Up @@ -128,6 +138,32 @@ def dp_rank_context(self):
self.dp_counts[dp_rank] -= 1
assert self.dp_counts[dp_rank] >= 0

@property
def semaphore(self) -> asyncio.Semaphore:
    """Return the concurrency semaphore bound to the *current* event loop.

    asyncio primitives attach to the loop that is running when they are
    constructed, and Ray actors may serve calls on different loops across
    re-entries (e.g. rollout -> eval), so the cached semaphore is rebuilt
    whenever the running loop differs from the one it was built on.

    The previous loop is tracked with a weakref rather than ``id()``:
    ids are addresses and may be recycled by a new loop object once the
    old one is garbage-collected, whereas a dead weakref unambiguously
    forces a rebuild.  Dropping the stale semaphore is safe -- it holds
    no OS resources, and coroutines already inside ``async with`` keep
    their own reference for release; only new callers see the rebuilt
    object.  During a loop handoff the effective in-flight cap is
    transiently old_limit + new_limit rather than the configured value;
    transitions are rare and SGLang applies its own backpressure, so
    this is acceptable.
    """
    running = asyncio.get_running_loop()
    bound = None
    if self._semaphore_loop_ref is not None:
        bound = self._semaphore_loop_ref()
    if self._semaphore is None or bound is not running:
        self._semaphore = asyncio.Semaphore(self._semaphore_value)
        self._semaphore_loop_ref = weakref.ref(running)
    return self._semaphore

def reset(self) -> None:
self.remaining_batch_size = 0
self.pendings = set()
Expand Down
67 changes: 57 additions & 10 deletions slime/utils/http_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import random
import socket
import weakref

import httpx

Expand Down Expand Up @@ -145,6 +146,7 @@ def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) ->


_http_client: httpx.AsyncClient | None = None
_http_client_loop_ref: weakref.ref | None = None
_client_concurrency: int = 0

# Optional Ray-based distributed POST dispatch
Expand Down Expand Up @@ -199,25 +201,70 @@ async def _post(client, url, payload, max_retries=60, headers=None):


def init_http_client(args):
    """Record HTTP-client concurrency and optionally enable distributed POST.

    The actual ``httpx.AsyncClient`` is built lazily per event loop by
    ``_get_http_client`` on first use, so Ray actors that hop loops get a
    fresh client bound to the active loop.

    Args:
        args: parsed CLI namespace; reads ``sglang_server_concurrency``,
            ``rollout_num_gpus``, ``rollout_num_gpus_per_engine`` and
            ``use_distributed_post``.
    """
    global _client_concurrency, _distributed_post_enabled
    if not args.rollout_num_gpus:
        # Nothing to size the connection pool for; leave module state alone.
        return

    _client_concurrency = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine
    if _client_concurrency <= 0:
        # Not fatal (we fall back to 1 in _get_http_client) but almost
        # certainly a configuration bug; surface it in logs rather than
        # silently throttling every rollout to 1 in-flight request.
        logger.warning(
            "_client_concurrency computed as %d from sglang_server_concurrency=%s, "
            "rollout_num_gpus=%s, rollout_num_gpus_per_engine=%s; check your config.",
            _client_concurrency,
            args.sglang_server_concurrency,
            args.rollout_num_gpus,
            args.rollout_num_gpus_per_engine,
        )

    # Optionally initialize distributed POST via Ray without changing interfaces
    if args.use_distributed_post:
        _init_ray_distributed_post(args)
        _distributed_post_enabled = True


def _get_http_client() -> httpx.AsyncClient:
    """Get or (re)create the HTTP client for the current event loop.

    ``httpx.AsyncClient`` binds its connection pool (and the internal
    asyncio locks it uses) to the loop that created it.  Ray actors can
    serve calls on different loops across re-entries (e.g. rollout ->
    eval), so a loop change is detected here and the client rebuilt.

    Old-client cleanup: the old client is deliberately *not* closed via
    ``aclose()``.  It was bound to a loop that may already be dead;
    running ``aclose()`` from a different loop races on the pool's
    internal asyncio locks and can leave sockets half-closed.  The
    reference is dropped and GC closes the underlying sockets when the
    client is collected (httpx emits a ResourceWarning; expected).  A
    socket leak across a rare loop transition is cheaper than a
    partial-aclose data race.

    Thread/concurrency safety: safe under asyncio because there is no
    ``await`` between the loop-change check and the assignment to
    ``_http_client``, so two concurrent callers on the same loop can't
    both pass the check and race on construction.  Any future refactor
    that inserts an ``await`` here must add a per-loop lock.
    """
    global _http_client, _http_client_loop_ref
    running = asyncio.get_running_loop()
    bound = _http_client_loop_ref() if _http_client_loop_ref is not None else None
    if _http_client is not None and bound is running:
        return _http_client
    _http_client = httpx.AsyncClient(
        limits=httpx.Limits(max_connections=max(_client_concurrency, 1)),
        timeout=httpx.Timeout(None),
        trust_env=False,  # internal SGLang comm only — never route through system proxy
    )
    _http_client_loop_ref = weakref.ref(running)
    return _http_client


def _init_ray_distributed_post(args):
"""Initialize one or more Ray async actors per node for HTTP POST.

Expand Down Expand Up @@ -287,11 +334,11 @@ async def post(url, payload, max_retries=60, headers=None):
logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})")
# fall through to local

return await _post(_http_client, url, payload, max_retries, headers=headers)
return await _post(_get_http_client(), url, payload, max_retries, headers=headers)


async def get(url):
response = await _http_client.get(url)
response = await _get_http_client().get(url)
response.raise_for_status()
content = await response.aread()
output = json.loads(content)
Expand Down
148 changes: 148 additions & 0 deletions tests/test_http_utils_loop_rebind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Regression tests for the Semaphore / HTTP-client event-loop rebind.

Without the rebind, re-entering the rollout actor on a different asyncio
loop (e.g. rollout -> eval) causes either ``RuntimeError: <...> is bound
to a different event loop`` or a silent hang. These tests exercise that
transition by running two separate ``asyncio.run(...)`` calls (each
spins up a fresh loop) and asserting that:

1. The cached primitive is rebuilt on the second call
(identity check: ``is not``).
2. Pre-patch behavior — the primitive stays pinned to the first loop
— is what the patched code avoids; this is implicit in the test
(if the cache weren't rebuilt, the second-loop user would crash
or hang).

CPU-only; no Ray, no GPU.
"""

import asyncio
import types

import pytest


# ---------------------------------------------------------------------------
# GenerateState.semaphore
# ---------------------------------------------------------------------------


def _make_state():
    """Construct a GenerateState carrying only the attributes the
    ``semaphore`` property reads, bypassing ``__init__`` (which would
    load a tokenizer / processor from a real HF checkpoint)."""
    from slime.rollout.sglang_rollout import GenerateState
    from slime.utils.misc import SingletonMeta

    # Drop any cached singleton instance so every test starts clean.
    SingletonMeta._instances.pop(GenerateState, None)

    ns = types.SimpleNamespace(
        sglang_server_concurrency=2,
        rollout_num_gpus=4,
        rollout_num_gpus_per_engine=2,
    )

    st = GenerateState.__new__(GenerateState)
    st.args = ns
    st._semaphore_value = (
        ns.sglang_server_concurrency * ns.rollout_num_gpus // ns.rollout_num_gpus_per_engine
    )
    st._semaphore = None
    st._semaphore_loop_ref = None
    return st


def test_semaphore_rebinds_across_fresh_event_loops():
    """Two separate asyncio.run() calls spin up two distinct loops; the
    property must hand the second loop a freshly built Semaphore bound
    to it."""
    state = _make_state()

    seen = []

    async def use_once():
        # Acquire/release so the Semaphore actually binds to this loop.
        async with state.semaphore:
            pass
        seen.append(state.semaphore)

    asyncio.run(use_once())
    asyncio.run(use_once())

    assert seen[0] is not seen[1], (
        "Semaphore was not rebuilt on second asyncio.run(); the event-loop "
        "rebind is not firing."
    )


def test_semaphore_reused_within_same_loop():
    """Repeated property access inside one asyncio.run() must yield the
    same Semaphore — otherwise concurrent callers would see different
    objects and the concurrency cap would not hold."""
    state = _make_state()

    async def grab_twice():
        first = state.semaphore
        async with first:
            pass
        return first, state.semaphore

    one, two = asyncio.run(grab_twice())
    assert one is two


# ---------------------------------------------------------------------------
# _get_http_client
# ---------------------------------------------------------------------------


def test_http_client_rebinds_across_fresh_event_loops():
    """Analog of the semaphore test for ``_get_http_client``.  The client
    cache is module-global, so it is reset before the test runs."""
    import slime.utils.http_utils as http_utils

    http_utils._http_client = None
    http_utils._http_client_loop_ref = None
    http_utils._client_concurrency = 4  # any positive value works here

    async def fetch_client():
        return http_utils._get_http_client()

    first_client = asyncio.run(fetch_client())
    second_client = asyncio.run(fetch_client())

    assert first_client is not second_client, (
        "httpx.AsyncClient was not rebuilt on second asyncio.run(); the "
        "event-loop rebind is not firing."
    )


def test_http_client_reused_within_same_loop():
    """Within one asyncio.run(), ``_get_http_client`` must return the
    same cached client on every call."""
    import slime.utils.http_utils as http_utils

    http_utils._http_client = None
    http_utils._http_client_loop_ref = None
    http_utils._client_concurrency = 4

    async def grab_pair():
        return http_utils._get_http_client(), http_utils._get_http_client()

    first, second = asyncio.run(grab_pair())
    assert first is second


# Allow running this module directly (e.g. `python test_http_utils_loop_rebind.py`)
# outside a pytest invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])