Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,8 @@ def __init__(
conn_options: APIConnectOptions,
) -> None:
super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
self._stt: STT = stt
self._opts = opts
self._session = stt._ensure_session()
self._request_id = str(utils.shortuuid("stt_request_"))

self._speaking = False
Expand Down Expand Up @@ -632,7 +632,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws or self._session.closed:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we have the same pattern in other files as well.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this back using a local `http_session` variable.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we fix all "or self._session.closed:" usage?

if closing_ws:
return
raise APIStatusError(
message="LiveKit Inference STT connection closed unexpectedly"
Expand Down Expand Up @@ -722,7 +722,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
}
try:
ws = await asyncio.wait_for(
self._session.ws_connect(
self._stt._ensure_session().ws_connect(
f"{base_url}/stt?model={self._opts.model}", headers=headers
),
self._conn_options.timeout,
Expand Down
5 changes: 5 additions & 0 deletions livekit-agents/livekit/agents/utils/http_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,8 @@ async def _close_http_ctx() -> None:
logger.debug("http_session(): closing the httpclient ctx")
await val().close()
_ContextVar.set(None)


def _is_http_session_ctx_set() -> bool:
    """Tell whether an http session factory is bound to the current context.

    The context variable is read with ``None`` as the fallback value, so a
    context holding no binding -- or one whose binding has been reset to
    ``None`` -- reports ``False``.
    """
    bound_factory = _ContextVar.get(None)
    return bound_factory is not None
22 changes: 22 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ def __init__(
self._closing_task: asyncio.Task[None] | None = None
self._closing: bool = False
self._job_context_cb_registered: bool = False
self._owned_http_session_ctx: bool = False

self._global_run_state: RunResult | None = None
# TODO(theomonnom): need a better way to expose early assistant metrics
Expand Down Expand Up @@ -624,6 +625,18 @@ async def start(
# configure observability first
record_is_given = is_given(record)
job_ctx = get_job_context(required=False)

# Outside a job context (tests, scripts, ad-hoc usage) there's no
# http session bound to the event loop. Create one scoped to this
# session so STT/TTS can use http_context.http_session()
if (
job_ctx is None
and not self._owned_http_session_ctx
and not utils.http_context._is_http_session_ctx_set()
):
utils.http_context._new_session_ctx()
self._owned_http_session_ctx = True

if not is_given(record):
# defer to server-side setting for recording
record = job_ctx.job.enable_recording if job_ctx else False
Expand Down Expand Up @@ -906,6 +919,11 @@ async def _aclose_impl(

async with self._lock:
if not self._started:
# start() may have set up the http session ctx before failing —
# clean it up so we don't leak the factory on a failed start.
if self._owned_http_session_ctx:
await utils.http_context._close_http_ctx()
self._owned_http_session_ctx = False
return

self._closing = True
Expand Down Expand Up @@ -1008,6 +1026,10 @@ async def _aclose_impl(
await self._room_io.aclose()
self._room_io = None

if self._owned_http_session_ctx:
await utils.http_context._close_http_ctx()
self._owned_http_session_ctx = False

logger.debug("session closed", extra={"reason": reason.value, "error": error})

async def aclose(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ unit-tests:
tests/test_tool_search.py \
tests/test_tool_proxy.py \
tests/test_endpointing.py \
tests/test_session_host.py
tests/test_session_host.py \
tests/test_http_session_lifecycle.py

# ============================================
# Development Workflows
Expand Down
272 changes: 272 additions & 0 deletions tests/test_http_session_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
"""
Tests for AgentSession-owned http_context lifecycle.

When running outside a job context (tests, scripts, ad-hoc usage) there is no
process-level http_session bound to the event loop. AgentSession sets one up in
start() and tears it down in aclose() so that STT/TTS/etc. can call
``utils.http_context.http_session()`` without a job process running.
"""

from __future__ import annotations

import asyncio
from pathlib import Path
from unittest.mock import MagicMock, patch

import aiohttp
import pytest

from livekit.agents import (
NOT_GIVEN,
Agent,
AgentSession,
NotGivenOr,
stt as stt_module,
tts as tts_module,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from livekit.agents.utils import http_context

from .fake_io import FakeAudioInput, FakeAudioOutput, FakeTextOutput
from .fake_llm import FakeLLM
from .fake_vad import FakeVAD

# Dotted path of the module whose ``get_job_context`` lookup is patched by the
# job-context test in this file.
_AGENT_SESSION_MOD = "livekit.agents.voice.agent_session"


class _CapturingSTT(stt_module.STT):
    """Network-free STT double that remembers the http session seen by stream()."""

    def __init__(self) -> None:
        caps = stt_module.STTCapabilities(streaming=True, interim_results=False)
        super().__init__(capabilities=caps)
        # Populated by stream(); remains None until a stream is requested.
        self.captured_session: aiohttp.ClientSession | None = None

    async def _recognize_impl(self, *args, **kwargs):  # pragma: no cover - unused
        # Only the streaming path is exercised by these tests.
        raise NotImplementedError

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> _NoopSTTStream:
        """Record the current http session, then hand back a do-nothing stream."""
        # What the tests check: resolving the session has to work whenever an
        # AgentSession is active, with or without a job context.
        current = http_context.http_session()
        self.captured_session = current
        return _NoopSTTStream(stt=self, conn_options=conn_options)


class _NoopSTTStream(stt_module.RecognizeStream):
    """Recognize stream that reads its input and produces nothing."""

    async def _run(self) -> None:
        # Consume and discard every item pushed onto the input channel.
        async for _item in self._input_ch:
            continue


class _CapturingTTS(tts_module.TTS):
    """Network-free TTS double that remembers the http session seen by synthesize()."""

    def __init__(self) -> None:
        caps = tts_module.TTSCapabilities(streaming=False)
        super().__init__(capabilities=caps, sample_rate=24000, num_channels=1)
        # Populated by synthesize(); remains None until synthesis is requested.
        self.captured_session: aiohttp.ClientSession | None = None

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> _NoopChunkedStream:
        """Record the current http session, then hand back a do-nothing chunked stream."""
        current = http_context.http_session()
        self.captured_session = current
        return _NoopChunkedStream(tts=self, input_text=text, conn_options=conn_options)


class _NoopChunkedStream(tts_module.ChunkedStream):
    """Chunked stream that initializes its emitter and flushes without pushing audio."""

    async def _run(self, output_emitter: tts_module.AudioEmitter) -> None:
        emitter_config = {
            "request_id": "noop",
            "sample_rate": 24000,
            "num_channels": 1,
            "mime_type": "audio/pcm",
        }
        output_emitter.initialize(**emitter_config)
        output_emitter.flush()


class _NoopAgent(Agent):
    """Agent carrying placeholder instructions and no behaviour of its own."""

    def __init__(self) -> None:
        placeholder_instructions = "noop"
        super().__init__(instructions=placeholder_instructions)


def _make_session(
    stt: _CapturingSTT | None = None, tts: _CapturingTTS | None = None
) -> AgentSession:
    """Build an AgentSession wired entirely to in-process fakes.

    Args:
        stt: STT double to use; a fresh ``_CapturingSTT`` is created when omitted.
        tts: TTS double to use; a fresh ``_CapturingTTS`` is created when omitted.

    Returns:
        A not-yet-started session whose audio input, audio output and
        transcription output are replaced by fakes.
    """
    vad = FakeVAD(fake_user_speeches=[], min_silence_duration=0.5, min_speech_duration=0.05)
    stt_impl = stt or _CapturingSTT()
    llm = FakeLLM(fake_responses=[])
    tts_impl = tts or _CapturingTTS()

    agent_session = AgentSession[None](
        vad=vad,
        stt=stt_impl,
        llm=llm,
        tts=tts_impl,
        # AEC warmup is switched off so its timer is not leaked by the test
        aec_warmup_duration=None,
    )

    agent_session.input.audio = FakeAudioInput()
    agent_session.output.audio = FakeAudioOutput()
    agent_session.output.transcription = FakeTextOutput()
    return agent_session


async def test_http_session_available_during_agent_session() -> None:
    """A started AgentSession makes http_context.http_session() usable.

    Once the session is closed, the underlying client session is closed too
    and http_session() goes back to raising.
    """
    # Precondition: this task has no http session bound before start().
    with pytest.raises(RuntimeError):
        http_context.http_session()

    recorder = _CapturingSTT()
    agent_session = _make_session(stt=recorder)
    await agent_session.start(_NoopAgent())

    # start() was awaited from this task, so its binding is visible here.
    live = http_context.http_session()
    assert isinstance(live, aiohttp.ClientSession)
    assert not live.closed

    # STT.stream(), invoked while the activity started, observed the same object.
    assert recorder.captured_session is live

    await agent_session.aclose()

    # Closing tears down the client session and clears the context binding.
    assert live.closed
    with pytest.raises(RuntimeError):
        http_context.http_session()


async def test_concurrent_sessions_in_separate_tasks_are_isolated() -> None:
    """Two AgentSessions started inside their own asyncio.Task each get their own
    http session. Closing one does not affect the other.

    The workers rendezvous on an event that is only set once *every* worker has
    a started session, so the two sessions are guaranteed to be alive at the
    same time. A fixed sleep cannot guarantee that on a slow machine.
    """
    worker_count = 2
    started_count = 0
    # Set by whichever worker finishes start() last.
    all_started = asyncio.Event()

    async def session_worker() -> tuple[aiohttp.ClientSession, aiohttp.ClientSession]:
        nonlocal started_count
        capturing_stt = _CapturingSTT()
        session = _make_session(stt=capturing_stt)

        await session.start(_NoopAgent())
        try:
            seen = http_context.http_session()
            # Single event loop, no await between read and write: no race here.
            started_count += 1
            if started_count == worker_count:
                all_started.set()
            # wait so both tasks are alive simultaneously — proves isolation
            await all_started.wait()
            # session is still live and accessible from this task's context
            still_seen = http_context.http_session()
        finally:
            # Close even on failure/cancellation so no client session is leaked.
            await session.aclose()
        return seen, still_seen

    tasks = [asyncio.create_task(session_worker()) for _ in range(worker_count)]
    try:
        # Bounded wait: if one worker dies before the rendezvous, the other would
        # otherwise block forever on the event.
        (a_first, a_second), (b_first, b_second) = await asyncio.wait_for(
            asyncio.gather(*tasks), timeout=10
        )
    finally:
        stragglers = [task for task in tasks if not task.done()]
        for task in stragglers:
            task.cancel()
        if stragglers:
            await asyncio.gather(*stragglers, return_exceptions=True)

    # each task sees a stable session before close
    assert a_first is a_second
    assert b_first is b_second

    # tasks see different sessions — not a single global one
    assert a_first is not b_first

    # both got closed independently
    assert a_first.closed
    assert b_first.closed


def _mock_job_ctx() -> MagicMock:
"""Build the minimal JobContext mock that AgentSession.start() reads from."""
mock = MagicMock()
mock.job.enable_recording = False
mock.job.id = "test-job-id"
mock.job.agent_name = "test-agent"
mock.room.name = "test-room"
mock._primary_agent_session = None
mock.session_directory = Path("/tmp/test-session")
return mock


async def test_session_does_not_own_http_ctx_inside_job_context(
job_process: None, # fixture sets up http_context for the test
) -> None:
"""When AgentSession runs inside a real job context, it must not overwrite or
close the process-level http_context on aclose.

The http session that exists before the AgentSession starts must be the same
object, and still open, both while the session runs and after it is closed.
"""
# Capture the pre-existing session so identity can be compared later.
outer_session = http_context.http_session()
assert not outer_session.closed

session = _make_session()

# Make start() see a (mocked) job context instead of none at all.
# NOTE(review): indentation was lost in this copy, so it is unclear which of
# the following statements belong inside the patch — confirm against the repo.
with patch(f"{_AGENT_SESSION_MOD}.get_job_context", return_value=_mock_job_ctx()):
await session.start(_NoopAgent())

# AgentSession reuses the existing context — same ClientSession surfaces.
assert http_context.http_session() is outer_session

await session.aclose()

# The job-context session is still alive — only the job_process fixture closes it.
assert not outer_session.closed
assert http_context.http_session() is outer_session


async def test_nested_sessions_in_same_task_share_http_ctx() -> None:
    """An AgentSession started while another one is still running in the same
    task borrows the first one's http session and leaves it open when it closes.
    """
    parent = _make_session()
    await parent.start(_NoopAgent())
    shared = http_context.http_session()
    assert parent._owned_http_session_ctx is True

    child = _make_session()
    await child.start(_NoopAgent())

    # The context variable was already bound, so the child takes no ownership.
    assert child._owned_http_session_ctx is False
    assert http_context.http_session() is shared

    await child.aclose()

    # Closing the child leaves the parent's session untouched.
    assert not shared.closed
    assert http_context.http_session() is shared

    await parent.aclose()
    assert shared.closed


async def test_start_failure_cleans_up_http_ctx() -> None:
    """A start() that fails after binding the http session must still be cleaned
    up by aclose(); otherwise the factory would be leaked on a failed start.
    """
    agent_session = _make_session()

    # Force start() to blow up after the http session ctx has been set up.
    failing_update = patch.object(
        AgentSession, "_update_activity", side_effect=RuntimeError("boom")
    )
    with failing_update, pytest.raises(BaseException):  # noqa: B017,PT011 - want any failure
        await agent_session.start(_NoopAgent())

    # _started was never flipped, yet the http session ctx is already owned.
    assert agent_session._started is False
    assert agent_session._owned_http_session_ctx is True

    await agent_session.aclose()

    # aclose() has to release the ctx even though start() failed.
    assert agent_session._owned_http_session_ctx is False
    with pytest.raises(RuntimeError):
        http_context.http_session()
Loading