diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c2aa01..bbf77ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -109,6 +109,17 @@ jobs: - name: Build run: cargo build --features=${{env.features}} --verbose --target ${{ matrix.platform.rust-target }} + # trio on Windows depends on cffi, which doesn't support + # free-threaded CPython 3.13 (only 3.14t+). PyPy is excluded + # because tests are skipped on PyPy anyway. + # trio>=0.23 for spawn_system_task(context=...). + - if: ${{ !startsWith(matrix.python-version, 'pypy') && !(matrix.platform.os == 'windows-latest' && matrix.python-version == '3.13t') }} + name: Install trio and sniffio + run: python -m pip install -U 'trio>=0.23' sniffio + - if: ${{ matrix.platform.os == 'windows-latest' && matrix.python-version == '3.13t' }} + name: Skip trio (cffi unsupported on free-threaded 3.13) + run: echo "PYO3_ASYNC_TEST_TRIO_OPTIONAL=1" >> $env:GITHUB_ENV + # uvloop doesn't compile under # Windows, https://github.com/MagicStack/uvloop/issues/536, # nor PyPy, https://github.com/MagicStack/uvloop/issues/537 @@ -136,7 +147,7 @@ jobs: - uses: taiki-e/install-action@cargo-llvm-cov - name: Install pyo3-asyncio test dependencies - run: python -m pip install -U uvloop + run: python -m pip install -U uvloop 'trio>=0.23' sniffio - run: cargo llvm-cov --all-features --codecov --output-path coverage.json - uses: codecov/codecov-action@v5 diff --git a/CHANGELOG.md b/CHANGELOG.md index 49b90e8..4989227 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ To see unreleased changes, please see the CHANGELOG on the main branch. ## [Unreleased] +- Add trio support: `TaskLocals`, `into_future_with_locals`, `generic::future_into_py_with_locals`, `into_stream_with_locals_v1` and `into_stream_with_locals_v2` now detect the running Python async library via `sniffio` and dispatch accordingly, so the existing `tokio::future_into_py`/`tokio::into_future` (and `async_std` equivalents) work unchanged when called from `trio`. No new public API or feature flag is required; the asyncio code path is unchanged. New `RuntimeKind` enum and `TaskLocals::{trio, current, kind, token}` are exposed for explicit control. `local_future_into_py_with_locals` returns `NotImplementedError` under trio (`spawn_local` requires a `LocalSet` incompatible with `trio.run`); `run`/`run_until_complete` remain asyncio-only. Requires `trio >= 0.23`. +- `into_stream_v2`: a raising async generator now closes the stream promptly under asyncio (previously relied on `SenderGlue` GC), and the captured `TaskLocals` `contextvars.Context` is now propagated into the forwarding task under both asyncio and trio. + ## [0.28.0] - 2026-02-03 - Bump to pyo3 0.28. [#76](https://github.com/PyO3/pyo3-async-runtimes/pull/76) diff --git a/Cargo.toml b/Cargo.toml index 139af78..f6af5ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,6 +113,18 @@ harness = false required-features = ["tokio-runtime", "testing"] +[[test]] +name = "test_trio" +path = "pytests/test_trio.rs" +harness = false +required-features = ["tokio-runtime"] + +[[test]] +name = "test_async_std_trio" +path = "pytests/test_async_std_trio.rs" +harness = false +required-features = ["async-std-runtime"] + [[test]] name = "test_race_condition_regression" path = "pytests/test_race_condition_regression.rs" diff --git a/Contributing.md b/Contributing.md index 2dc01cb..c7aae6d 100644 --- a/Contributing.md +++ b/Contributing.md @@ -49,4 +49,14 @@ Using the project's githooks are recommended to prevent CI from failing for triv ``` git config core.hookspath .githooks -``` \ No newline at end of file +``` + +## Running the trio tests + +The trio integration tests need the `trio` and `sniffio` Python packages installed +(`pip install trio sniffio`). Add `unstable-streams` for the stream-conversion tests: + +``` +cargo test --features tokio-runtime --test test_trio +cargo test --features 'tokio-runtime unstable-streams' --test test_trio +``` diff --git a/README.md b/README.md index c75bc11..56bcd7a 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ***Forked from [`pyo3-asyncio`](https://github.com/awestlake87/pyo3-asyncio/) to deliver compatibility for PyO3 0.21+.*** -[Rust](http://www.rust-lang.org/) bindings for [Python](https://www.python.org/)'s [Asyncio Library](https://docs.python.org/3/library/asyncio.html). This crate facilitates interactions between Rust Futures and Python Coroutines and manages the lifecycle of their corresponding event loops. +[Rust](http://www.rust-lang.org/) bindings for [Python](https://www.python.org/)'s [Asyncio Library](https://docs.python.org/3/library/asyncio.html) and [trio](https://trio.readthedocs.io/). This crate facilitates interactions between Rust Futures and Python Coroutines and manages the lifecycle of their corresponding event loops. - PyO3 Project: [Homepage](https://pyo3.rs/) | [GitHub](https://github.com/PyO3/pyo3) @@ -30,9 +30,10 @@ If you are working with a Python library that makes use of async functions or wi Python bindings for an async Rust library, [`pyo3-async-runtimes`](https://github.com/PyO3/pyo3-async-runtimes) likely has the tools you need. It provides conversions between async functions in both Python and Rust and was designed with first-class support for popular Rust runtimes such as -[`tokio`](https://tokio.rs/) and [`async-std`](https://async.rs/). In addition, all async Python -code runs on the default `asyncio` event loop, so `pyo3-async-runtimes` should work just fine with existing -Python libraries. +[`tokio`](https://tokio.rs/) and [`async-std`](https://async.rs/). By default, async Python +code runs on the `asyncio` event loop, so `pyo3-async-runtimes` should work just fine with existing +Python libraries. The same conversions also work transparently under [`trio`](https://trio.readthedocs.io) +— the running Python async library is detected at call time via `sniffio`, with no extra feature flags. In the following sections, we'll give a general overview of `pyo3-async-runtimes` explaining how to call async Python functions with PyO3, how to call async Rust functions from Python, and how to configure @@ -529,6 +530,61 @@ fn main() -> PyResult<()> { } ``` +#### Using `trio` + +Unlike `uvloop`, [`trio`](https://trio.readthedocs.io) is not a drop-in +`asyncio` event loop — it is a separate async library with its own +scheduler and primitives. `pyo3-async-runtimes` detects the running +Python async library at call time (via +[`sniffio`](https://sniffio.readthedocs.io)) and uses the appropriate +park/wake primitives, so the same compiled extension works under both +`asyncio` and `trio` with no Python-side shim and no extra Cargo +features: + +```rust +//! lib.rs + +use pyo3::{prelude::*, wrap_pyfunction}; + +#[pyfunction] +fn rust_sleep(py: Python) -> PyResult> { + pyo3_async_runtimes::tokio::future_into_py(py, async { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + Ok(()) + }) +} + +#[pymodule] +fn my_async_module(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(rust_sleep, m)?)?; + Ok(()) +} +``` + +```python +import trio +from my_async_module import rust_sleep + +async def main(): + await rust_sleep() + +trio.run(main) +``` + +The same `rust_sleep` can be awaited unchanged from `asyncio.run(main())`. +Under `asyncio` the existing code path is taken and an `asyncio.Future` +is returned exactly as before, so existing users see no behavior change. + +`into_future`, `future_into_py`, and `into_stream_v2` all dispatch this +way. `local_future_into_py` (the `!Send` variant) and the +`run`/`run_until_complete` helpers remain asyncio-only — +`local_future_into_py` returns `NotImplementedError` under trio because +`spawn_local` requires a `LocalSet` that cannot share a thread with +`trio.run`, and `run_until_complete` is inherently tied to asyncio's +loop-creation API. + +Requires `trio >= 0.23`. + ### Additional Information - Managing event loop references can be tricky with `pyo3-async-runtimes`. See [Event Loop References and ContextVars](https://docs.rs/pyo3-async-runtimes/latest/pyo3_async_runtimes/#event-loop-references-and-contextvars) in the API docs to get a better intuition for how event loop references are managed in this library. diff --git a/pytests/test_async_std_trio.rs b/pytests/test_async_std_trio.rs new file mode 100644 index 0000000..32510d0 --- /dev/null +++ b/pytests/test_async_std_trio.rs @@ -0,0 +1,36 @@ +use pyo3::prelude::*; + +#[pyfunction] +fn rust_sleep(py: Python<'_>) -> PyResult> { + pyo3_async_runtimes::async_std::future_into_py(py, async move { + async_std::task::sleep(std::time::Duration::from_millis(50)).await; + Ok(42i64) + }) +} + +fn main() -> PyResult<()> { + Python::initialize(); + Python::attach(|py| { + if py.import("trio").is_err() { + if std::env::var_os("CI").is_some() + && std::env::var_os("PYO3_ASYNC_TEST_TRIO_OPTIONAL").is_none() + { + eprintln!("error: trio is not installed but CI is set"); + std::process::exit(1); + } + println!("test test_async_std_trio ... skipped (trio not available)"); + return Ok(()); + } + let driver = PyModule::from_code( + py, + c"import trio\nasync def main(f):\n return await f()\ndef drive(f):\n return trio.run(main, f)\n", + c"trio_driver.py", + c"trio_driver", + )?; + let f = wrap_pyfunction!(rust_sleep, py)?; + let r: i64 = driver.getattr("drive")?.call1((f,))?.extract()?; + assert_eq!(r, 42); + println!("test test_async_std_trio ... ok"); + Ok(()) + }) +} diff --git a/pytests/test_trio.rs b/pytests/test_trio.rs new file mode 100644 index 0000000..a1b36b5 --- /dev/null +++ b/pytests/test_trio.rs @@ -0,0 +1,797 @@ +//! Integration tests for trio support. +//! +//! These mirror the asyncio integration tests in `tokio_asyncio/mod.rs`: every +//! probe goes through the public `pyo3_async_runtimes::tokio::*` conversion +//! API, exactly as a user would call it. +//! +//! The harness is `harness = false` because the crate's `testing::main()` runs +//! tests inside `tokio::run` → `asyncio.run`; trio tests need `trio.run` per +//! case instead. Filtering matches `testing::parse_args()`: pass a substring as +//! the first positional argument. + +use std::ffi::CString; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + +#[cfg(feature = "unstable-streams")] +use futures_util::stream::StreamExt; +use pyo3::prelude::*; +use pyo3::types::PyDict; +use pyo3::IntoPyObjectExt; +use pyo3_async_runtimes::generic::{self, ContextExt, JoinError, Runtime}; +use pyo3_async_runtimes::{tokio as pyo3_tokio, RuntimeKind, TaskLocals}; + +// --------------------------------------------------------------------------- +// `NoSpawnRuntime` doesn't spawn — it immediately drops the future, to test the +// dropped-tx path. All other probes use the crate's real `tokio` runtime. +// --------------------------------------------------------------------------- + +struct NeverJoinError; +impl JoinError for NeverJoinError { + fn is_panic(&self) -> bool { + false + } + fn into_panic(self) -> Box { + unreachable!() + } +} + +struct NoSpawnRuntime; +impl Runtime for NoSpawnRuntime { + type JoinError = NeverJoinError; + type JoinHandle = std::future::Pending>; + fn spawn(_fut: F) -> Self::JoinHandle + where + F: Future + Send + 'static, + { + std::future::pending() + } + fn spawn_blocking(_f: F) -> Self::JoinHandle + where + F: FnOnce() + Send + 'static, + { + std::future::pending() + } +} + +impl ContextExt for NoSpawnRuntime { + fn scope(_locals: TaskLocals, fut: F) -> Pin + Send>> + where + F: Future + Send + 'static, + { + Box::pin(fut) + } + fn get_task_locals() -> Option { + None + } +} + +// --------------------------------------------------------------------------- +// Probe pyfunctions — all via the public conversion API. +// --------------------------------------------------------------------------- + +#[pyfunction] +fn rust_sleep(py: Python<'_>) -> PyResult> { + pyo3_tokio::future_into_py(py, async move { + tokio::time::sleep(Duration::from_millis(50)).await; + Ok(42i64) + }) +} + +#[pyfunction] +fn rust_never(py: Python<'_>) -> PyResult> { + pyo3_tokio::future_into_py(py, async move { + std::future::pending::<()>().await; + Ok(0i64) + }) +} + +#[pyfunction] +fn rust_panic(py: Python<'_>) -> PyResult> { + pyo3_tokio::future_into_py::<_, ()>(py, async { panic!("this panic was intentional!") }) +} + +/// Convert `awaitable` to a Rust future, then expose that future back to +/// Python — round-tripping `tokio::into_future` ↔ `tokio::future_into_py`. +#[pyfunction] +fn roundtrip_awaitable<'py>( + py: Python<'py>, + awaitable: Bound<'py, PyAny>, +) -> PyResult> { + let fut = pyo3_tokio::into_future(awaitable)?; + pyo3_tokio::future_into_py(py, fut) +} + +#[pyfunction] +fn snapshot_locals(py: Python<'_>) -> PyResult> { + fn kind_str(kind: RuntimeKind) -> &'static str { + match kind { + RuntimeKind::Asyncio => "asyncio", + RuntimeKind::Trio => "trio", + _ => unreachable!(), + } + } + let locals = TaskLocals::current(py)?; + let d = PyDict::new(py); + d.set_item("kind", kind_str(locals.kind()))?; + d.set_item("token", locals.event_loop(py))?; + d.set_item("context", locals.context(py))?; + let copied = locals.copy_context(py)?; + d.set_item("copied_kind", kind_str(copied.kind()))?; + d.set_item("copied_token", copied.event_loop(py))?; + d.set_item("copied_context", copied.context(py))?; + d.into_py_any(py) +} + +#[pyfunction] +fn local_probe(py: Python<'_>) -> PyResult> { + #[allow(deprecated)] + pyo3_tokio::local_future_into_py(py, async { Ok(0i64) }) +} + +static CANCEL_PROBE_DROPPED: AtomicBool = AtomicBool::new(false); +static CANCEL_PROBE_COMPLETED: AtomicBool = AtomicBool::new(false); + +struct DropGuard; +impl Drop for DropGuard { + fn drop(&mut self) { + CANCEL_PROBE_DROPPED.store(true, Ordering::Release); + } +} + +#[pyfunction] +fn cancel_probe(py: Python<'_>) -> PyResult> { + CANCEL_PROBE_DROPPED.store(false, Ordering::Release); + CANCEL_PROBE_COMPLETED.store(false, Ordering::Release); + let guard = DropGuard; + pyo3_tokio::future_into_py(py, async move { + let _guard = guard; + std::future::pending::<()>().await; + CANCEL_PROBE_COMPLETED.store(true, Ordering::Release); + Ok(0i64) + }) +} + +#[pyfunction] +fn cancel_probe_state() -> (bool, bool) { + ( + CANCEL_PROBE_DROPPED.load(Ordering::Acquire), + CANCEL_PROBE_COMPLETED.load(Ordering::Acquire), + ) +} + +#[pyfunction] +fn tx_dropped_probe(py: Python<'_>) -> PyResult> { + generic::future_into_py_with_locals::( + py, + TaskLocals::current(py)?, + async move { Ok(0i64) }, + ) +} + +#[cfg(feature = "unstable-streams")] +#[pyfunction] +fn stream_v1_probe<'py>(py: Python<'py>, gen: Bound<'py, PyAny>) -> PyResult> { + let stream = pyo3_tokio::into_stream_v1(gen)?; + pyo3_tokio::future_into_py(py, async move { + let mut items = Vec::new(); + futures_util::pin_mut!(stream); + while let Some(item) = stream.next().await { + let v = Python::attach(|py| item?.bind(py).extract::())?; + items.push(v); + } + Ok(items) + }) +} + +#[cfg(feature = "unstable-streams")] +#[pyfunction] +fn stream_v2_probe<'py>(py: Python<'py>, gen: Bound<'py, PyAny>) -> PyResult> { + let stream = pyo3_tokio::into_stream_v2(gen)?; + pyo3_tokio::future_into_py(py, async move { + let items: Vec> = stream.collect().await; + Python::attach(|py| { + items + .into_iter() + .map(|v| v.bind(py).extract::()) + .collect::>>() + }) + }) +} + +// --------------------------------------------------------------------------- +// Driver helpers +// --------------------------------------------------------------------------- + +fn run_driver<'py, A>(py: Python<'py>, src: &str, args: A) -> PyResult> +where + A: pyo3::call::PyCallArgs<'py>, +{ + let module = PyModule::from_code( + py, + &CString::new(src).unwrap(), + &CString::new("trio_test_driver.py").unwrap(), + &CString::new("trio_test_driver").unwrap(), + )?; + module.getattr("drive")?.call1(args).map(Bound::unbind) +} + +fn assert_ok_with<'py, A>(py: Python<'py>, src: &str, args: A) +where + A: pyo3::call::PyCallArgs<'py>, +{ + let r: String = run_driver(py, src, args) + .unwrap_or_else(|e| { + e.print_and_set_sys_last_vars(py); + panic!("driver failed") + }) + .bind(py) + .extract() + .unwrap(); + assert_eq!(r, "ok"); +} + +const ASYNCIO_DRIVER: &str = r#" +import asyncio +async def main(f): + return await f() +def drive(f): + return asyncio.run(main(f)) +"#; + +const TRIO_DRIVER: &str = r#" +import trio +async def main(f): + return await f() +def drive(f): + return trio.run(main, f) +"#; + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +fn test_asyncio_roundtrip(py: Python<'_>) { + let f = wrap_pyfunction!(rust_sleep, py).unwrap(); + let r: i64 = run_driver(py, ASYNCIO_DRIVER, (f,)) + .unwrap() + .bind(py) + .extract() + .unwrap(); + assert_eq!(r, 42); +} + +fn test_trio_roundtrip(py: Python<'_>) { + let f = wrap_pyfunction!(rust_sleep, py).unwrap(); + let r: i64 = run_driver(py, TRIO_DRIVER, (f,)) + .unwrap() + .bind(py) + .extract() + .unwrap(); + assert_eq!(r, 42); +} + +fn test_trio_cancel_scope_propagates(py: Python<'_>) { + let f = wrap_pyfunction!(rust_never, py).unwrap(); + let src = r#" +import time +import trio +async def main(f): + start = time.monotonic() + with trio.move_on_after(0.05): + await f() + return time.monotonic() - start +def drive(f): + return trio.run(main, f) +"#; + let elapsed: f64 = run_driver(py, src, (f,)) + .unwrap() + .bind(py) + .extract() + .unwrap(); + assert!( + elapsed < 1.0, + "move_on_after did not propagate to RustCoroutine (elapsed {elapsed}s)" + ); +} + +fn test_trio_contextvars(py: Python<'_>) { + let f = wrap_pyfunction!(roundtrip_awaitable, py).unwrap(); + let src = r#" +import contextvars, trio +cx = contextvars.ContextVar("cx") +async def reader(): + return cx.get() +async def main(roundtrip): + cx.set("foobar") + return await roundtrip(reader()) +def drive(roundtrip): + return trio.run(main, roundtrip) +"#; + let r: String = run_driver(py, src, (f,)) + .unwrap() + .bind(py) + .extract() + .unwrap(); + assert_eq!(r, "foobar"); +} + +fn test_tasklocals_asyncio(py: Python<'_>) { + let f = wrap_pyfunction!(snapshot_locals, py).unwrap(); + let src = r#" +import asyncio, contextvars +async def main(snap): + d = snap() + assert d["kind"] == "asyncio", d["kind"] + assert d["token"] is asyncio.get_running_loop(), d["token"] + assert d["context"] is None, d["context"] + assert d["copied_kind"] == "asyncio" + assert d["copied_token"] is asyncio.get_running_loop() + assert isinstance(d["copied_context"], contextvars.Context), d["copied_context"] + return "ok" +def drive(snap): + return asyncio.run(main(snap)) +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_tasklocals_trio(py: Python<'_>) { + let f = wrap_pyfunction!(snapshot_locals, py).unwrap(); + let src = r#" +import trio, contextvars +async def main(snap): + d = snap() + assert d["kind"] == "trio", d["kind"] + assert isinstance(d["token"], trio.lowlevel.TrioToken), d["token"] + assert d["context"] is None + assert d["copied_kind"] == "trio" + assert isinstance(d["copied_token"], trio.lowlevel.TrioToken) + assert isinstance(d["copied_context"], contextvars.Context) + return "ok" +def drive(snap): + return trio.run(main, snap) +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_tasklocals_unsupported_library(py: Python<'_>) { + let f = wrap_pyfunction!(snapshot_locals, py).unwrap(); + let src = r#" +import sniffio +def drive(snap): + sniffio.thread_local.name = "curio" + try: + try: + snap() + except RuntimeError as e: + assert "unsupported Python async library: curio" in str(e), str(e) + return "ok" + raise AssertionError("expected RuntimeError") + finally: + sniffio.thread_local.name = None +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_tasklocals_no_loop(py: Python<'_>) { + let f = wrap_pyfunction!(snapshot_locals, py).unwrap(); + let src = r#" +def drive(snap): + try: + snap() + except RuntimeError as e: + assert "no running event loop" in str(e), str(e) + return "ok" + raise AssertionError("expected RuntimeError") +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_into_future_asyncio_delegates(py: Python<'_>) { + let f = wrap_pyfunction!(roundtrip_awaitable, py).unwrap(); + let src = r#" +import asyncio +async def aw(): + await asyncio.sleep(0) + return 9 +async def main(roundtrip): + return await roundtrip(aw()) +def drive(roundtrip): + return asyncio.run(main(roundtrip)) +"#; + let r: i64 = run_driver(py, src, (f,)) + .unwrap() + .bind(py) + .extract() + .unwrap(); + assert_eq!(r, 9); +} + +fn test_into_future_trio_ok(py: Python<'_>) { + let f = wrap_pyfunction!(roundtrip_awaitable, py).unwrap(); + let src = r#" +import trio +async def aw(): + await trio.sleep(0) + return 7 +async def main(roundtrip): + return await roundtrip(aw()) +def drive(roundtrip): + return trio.run(main, roundtrip) +"#; + let r: i64 = run_driver(py, src, (f,)) + .unwrap() + .bind(py) + .extract() + .unwrap(); + assert_eq!(r, 7); +} + +fn test_into_future_trio_error(py: Python<'_>) { + let f = wrap_pyfunction!(roundtrip_awaitable, py).unwrap(); + let src = r#" +import trio +async def aw(): + raise ValueError("boom") +async def main(roundtrip): + try: + await roundtrip(aw()) + except ValueError as e: + assert str(e) == "boom" + return "ok" + raise AssertionError("expected ValueError") +def drive(roundtrip): + return trio.run(main, roundtrip) +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_into_future_trio_base_exception(py: Python<'_>) { + let f = wrap_pyfunction!(roundtrip_awaitable, py).unwrap(); + let src = r#" +import trio +class MyBase(BaseException): + pass +async def aw(): + raise MyBase("boom") +async def main(roundtrip): + await roundtrip(aw()) + return "unreachable" +def find(exc, target): + if isinstance(exc, target): + return True + for attr in ("__cause__", "__context__"): + nxt = getattr(exc, attr, None) + if nxt is not None and find(nxt, target): + return True + if isinstance(exc, BaseExceptionGroup): + return any(find(e, target) for e in exc.exceptions) + return False +def drive(roundtrip): + try: + r = trio.run(main, roundtrip) + except BaseException as e: + assert find(e, MyBase), f"MyBase not found in {e!r}" + return "ok" + raise AssertionError(f"expected trio.run to raise; got {r!r}") +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_future_into_py_dispatch_asyncio(py: Python<'_>) { + let f = wrap_pyfunction!(rust_never, py).unwrap(); + let src = r#" +import asyncio +async def main(never): + obj = never() + assert asyncio.isfuture(obj), type(obj) + obj.cancel() + return "ok" +def drive(never): + return asyncio.run(main(never)) +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_future_into_py_dispatch_trio(py: Python<'_>) { + let f = wrap_pyfunction!(rust_never, py).unwrap(); + let src = r#" +import trio, asyncio +async def main(never): + obj = never() + assert not asyncio.isfuture(obj), type(obj) + assert type(obj).__name__ == "RustCoroutine", type(obj).__name__ + obj.close() + return "ok" +def drive(never): + return trio.run(main, never) +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_trio_panic(py: Python<'_>) { + let f = wrap_pyfunction!(rust_panic, py).unwrap(); + let src = r#" +import trio +async def main(f): + try: + await f() + except Exception as e: + assert "this panic was intentional!" in str(e), str(e) + assert "RustPanic" in type(e).__name__, type(e).__name__ + return "ok" + raise AssertionError("expected RustPanic") +def drive(f): + return trio.run(main, f) +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_trio_local_future_into_py_not_implemented(py: Python<'_>) { + let f = wrap_pyfunction!(local_probe, py).unwrap(); + let src = r#" +import trio +async def main(probe): + try: + probe() + except NotImplementedError as e: + assert "local_future_into_py" in str(e), str(e) + return "ok" + raise AssertionError("expected NotImplementedError") +def drive(probe): + return trio.run(main, probe) +"#; + assert_ok_with(py, src, (f,)); +} + +fn test_future_into_py_close_cancels_rust(py: Python<'_>) { + let probe = wrap_pyfunction!(cancel_probe, py).unwrap(); + let state = wrap_pyfunction!(cancel_probe_state, py).unwrap(); + let src = r#" +import trio +async def main(probe, state): + c = probe() + c.close() # drops cancel_tx -> spawned select() resolves Left -> drops user fut + for _ in range(200): + dropped, completed = state() + if dropped: + break + await trio.sleep(0.005) + dropped, completed = state() + assert dropped, "DropGuard never fired" + assert not completed, "future ran to completion despite cancel" + return "ok" +def drive(probe, state): + return trio.run(main, probe, state) +"#; + assert_ok_with(py, src, (probe, state)); +} + +fn test_future_into_py_tx_dropped_error(py: Python<'_>) { + let probe = wrap_pyfunction!(tx_dropped_probe, py).unwrap(); + let src = r#" +import trio +async def main(probe): + try: + await probe() + except RuntimeError as e: + assert "Rust task was dropped before completion" in str(e), str(e) + return "ok" + raise AssertionError("expected RuntimeError") +def drive(probe): + return trio.run(main, probe) +"#; + assert_ok_with(py, src, (probe,)); +} + +fn test_trio_run_finished_error_swallowed(py: Python<'_>) { + let f = wrap_pyfunction!(rust_sleep, py).unwrap(); + let src = r#" +import trio, io, sys, time +async def main(f): + async def child(): + await f() + async with trio.open_nursery() as n: + n.start_soon(child) + await trio.sleep(0) + n.cancel_scope.cancel() +def drive(f): + buf = io.StringIO() + old = sys.stderr + sys.stderr = buf + try: + trio.run(main, f) + time.sleep(0.15) # let the rust thread fire its wake against dead token + finally: + sys.stderr = old + return "ok" +"#; + assert_ok_with(py, src, (f,)); +} + +#[cfg(feature = "unstable-streams")] +fn test_trio_into_stream_v1(py: Python<'_>) { + let probe = wrap_pyfunction!(stream_v1_probe, py).unwrap(); + let src = r#" +import trio +async def gen(): + for i in range(5): + yield i + await trio.sleep(0) +async def main(probe): + items = await probe(gen()) + assert items == [0, 1, 2, 3, 4], items + return "ok" +def drive(probe): + return trio.run(main, probe) +"#; + assert_ok_with(py, src, (probe,)); +} + +#[cfg(feature = "unstable-streams")] +fn test_trio_into_stream_v2(py: Python<'_>) { + let probe = wrap_pyfunction!(stream_v2_probe, py).unwrap(); + let src = r#" +import trio +async def gen(): + for i in range(5): + yield i + await trio.sleep(0) +async def main(probe): + items = await probe(gen()) + assert items == [0, 1, 2, 3, 4], items + return "ok" +def drive(probe): + return trio.run(main, probe) +"#; + assert_ok_with(py, src, (probe,)); +} + +#[cfg(feature = "unstable-streams")] +fn test_trio_into_stream_v2_backpressure(py: Python<'_>) { + let probe = wrap_pyfunction!(stream_v2_probe, py).unwrap(); + let src = r#" +import trio +async def gen(): + for i in range(50): + yield i +async def main(probe): + items = await probe(gen()) + assert items == list(range(50)), items + return "ok" +def drive(probe): + return trio.run(main, probe) +"#; + assert_ok_with(py, src, (probe,)); +} + +#[cfg(feature = "unstable-streams")] +fn test_trio_into_stream_v2_gen_raises(py: Python<'_>) { + let probe = wrap_pyfunction!(stream_v2_probe, py).unwrap(); + let src = r#" +import trio +async def gen(): + yield 1 + yield 2 + raise ValueError("boom") +async def main(probe): + items = await probe(gen()) + assert items == [1, 2], items + return "ok" +def drive(probe): + return trio.run(main, probe) +"#; + assert_ok_with(py, src, (probe,)); +} + +#[cfg(feature = "unstable-streams")] +fn test_trio_into_stream_v2_gen_raises_base_exception(py: Python<'_>) { + let probe = wrap_pyfunction!(stream_v2_probe, py).unwrap(); + let src = r#" +import trio +async def gen(): + yield 1 + raise SystemExit(2) +async def main(probe): + items = await probe(gen()) + assert items == [1], items + return "ok" +def drive(probe): + return trio.run(main, probe) +"#; + assert_ok_with(py, src, (probe,)); +} + +// --------------------------------------------------------------------------- +// Harness +// --------------------------------------------------------------------------- + +type TestFn = fn(Python<'_>); + +fn main() -> PyResult<()> { + Python::initialize(); + + let trio_available = Python::attach(|py| py.import("trio").is_ok()); + if !trio_available + && std::env::var_os("CI").is_some() + && std::env::var_os("PYO3_ASYNC_TEST_TRIO_OPTIONAL").is_none() + { + eprintln!("error: trio is not installed but CI is set; refusing to skip trio tests"); + std::process::exit(1); + } + let filter = std::env::args().nth(1); + + #[allow(unused_mut)] + #[rustfmt::skip] + let mut tests: Vec<(&str, TestFn, bool)> = vec![ + ("asyncio_roundtrip", test_asyncio_roundtrip, false), + ("tasklocals_asyncio", test_tasklocals_asyncio, false), + ("tasklocals_no_loop", test_tasklocals_no_loop, false), + ("into_future_asyncio_delegates", test_into_future_asyncio_delegates, false), + ("future_into_py_dispatch_asyncio", test_future_into_py_dispatch_asyncio, false), + ("trio_roundtrip", test_trio_roundtrip, true), + ("trio_cancel_scope_propagates", test_trio_cancel_scope_propagates, true), + ("trio_contextvars", test_trio_contextvars, true), + ("tasklocals_trio", test_tasklocals_trio, true), + ("tasklocals_unsupported_library", test_tasklocals_unsupported_library, true), + ("into_future_trio_ok", test_into_future_trio_ok, true), + ("into_future_trio_error", test_into_future_trio_error, true), + ("into_future_trio_base_exception", test_into_future_trio_base_exception, true), + ("future_into_py_dispatch_trio", test_future_into_py_dispatch_trio, true), + ("trio_panic", test_trio_panic, true), + ("trio_local_future_into_py_not_implemented", test_trio_local_future_into_py_not_implemented, true), + ("future_into_py_close_cancels_rust", test_future_into_py_close_cancels_rust, true), + ("future_into_py_tx_dropped_error", test_future_into_py_tx_dropped_error, true), + ("trio_run_finished_error_swallowed", test_trio_run_finished_error_swallowed, true), + ]; + #[cfg(feature = "unstable-streams")] + #[rustfmt::skip] + tests.extend([ + ("trio_into_stream_v1", test_trio_into_stream_v1 as TestFn, true), + ("trio_into_stream_v2", test_trio_into_stream_v2 as TestFn, true), + ("trio_into_stream_v2_backpressure", test_trio_into_stream_v2_backpressure as TestFn, true), + ("trio_into_stream_v2_gen_raises", test_trio_into_stream_v2_gen_raises as TestFn, true), + ("trio_into_stream_v2_gen_raises_base_exception", test_trio_into_stream_v2_gen_raises_base_exception as TestFn, true), + ]); + + let mut passed = 0usize; + let mut failed = 0usize; + let mut skipped = 0usize; + + for (name, f, needs_trio) in tests { + if let Some(ref filter) = filter { + if !name.contains(filter.as_str()) { + continue; + } + } + if needs_trio && !trio_available { + println!("test trio::{name} ... skipped (trio not installed)"); + skipped += 1; + continue; + } + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Python::attach(f))); + match result { + Ok(()) => { + println!("test trio::{name} ... ok"); + passed += 1; + } + Err(e) => { + let msg = e + .downcast_ref::() + .map(String::as_str) + .or_else(|| e.downcast_ref::<&str>().copied()) + .unwrap_or(""); + println!("test trio::{name} ... FAILED: {msg}"); + failed += 1; + } + } + } + + println!("\n{passed} passed; {failed} failed; {skipped} skipped"); + if failed > 0 { + std::process::exit(1); + } + Ok(()) +} diff --git a/src/async_std.rs b/src/async_std.rs index 9824eab..696250b 100644 --- a/src/async_std.rs +++ b/src/async_std.rs @@ -147,11 +147,12 @@ where AsyncStdRuntime::scope_local(locals, fut).await } -/// Get the current event loop from either Python or Rust async task local context +/// Get the current asyncio event loop from either Python or Rust async task local context /// /// This function first checks if the runtime has a task-local reference to the Python event loop. -/// If not, it calls [`get_running_loop`](`crate::get_running_loop`) to get the event loop -/// associated with the current OS thread. +/// If not, it falls back to [`TaskLocals::current`](crate::TaskLocals::current) (which detects the +/// running async library via `sniffio`). Under `trio` this returns a `RuntimeError` since there is +/// no asyncio event loop; use [`get_current_locals`] instead. pub fn get_current_loop(py: Python) -> PyResult> { generic::get_current_loop::(py) } diff --git a/src/generic.rs b/src/generic.rs index d999e0a..361a0b7 100644 --- a/src/generic.rs +++ b/src/generic.rs @@ -88,20 +88,32 @@ pub trait LocalContextExt: Runtime { F: Future + 'static; } -/// Get the current event loop from either Python or Rust async task local context +/// Get the current asyncio event loop from either Python or Rust async task local context /// /// This function first checks if the runtime has a task-local reference to the Python event loop. -/// If not, it calls [`get_running_loop`](crate::get_running_loop`) to get the event loop associated -/// with the current OS thread. +/// If not, it falls back to [`TaskLocals::current`](crate::TaskLocals::current) (which detects the +/// running async library via `sniffio`). Under `trio` this returns a `RuntimeError` since there is +/// no asyncio event loop; use [`get_current_locals`] and inspect [`TaskLocals::kind`] instead. pub fn get_current_loop(py: Python) -> PyResult> where R: ContextExt, { if let Some(locals) = R::get_task_locals() { - Ok(locals.0.event_loop.clone_ref(py).into_bound(py)) - } else { - get_running_loop(py) + return match locals.kind() { + crate::RuntimeKind::Asyncio => Ok(locals.0.event_loop.clone_ref(py).into_bound(py)), + crate::RuntimeKind::Trio => Err(pyo3::exceptions::PyRuntimeError::new_err( + "get_current_loop is asyncio-specific; under trio use \ + get_current_locals().event_loop() to obtain the TrioToken", + )), + }; + } + if crate::trio::sniff(py).as_deref() == Some("trio") { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + "get_current_loop is asyncio-specific; under trio use \ + get_current_locals().event_loop() to obtain the TrioToken", + )); } + get_running_loop(py) } /// Either copy the task locals from the current task OR get the current running loop and @@ -113,7 +125,7 @@ where if let Some(locals) = R::get_task_locals() { Ok(locals) } else { - Ok(TaskLocals::with_running_loop(py)?.copy_context(py)?) + Ok(TaskLocals::current(py)?.copy_context(py)?) } } @@ -615,6 +627,12 @@ where F: Future> + Send + 'static, T: for<'py> IntoPyObject<'py> + Send + 'static, { + match locals.kind() { + crate::RuntimeKind::Trio => { + return crate::trio::future_into_coroutine::(py, locals, fut); + } + crate::RuntimeKind::Asyncio => {} + } let (cancel_tx, cancel_rx) = oneshot::channel(); let py_fut = create_future(locals.0.event_loop.bind(py).clone())?; @@ -689,7 +707,7 @@ where Ok(py_fut) } -fn get_panic_message(any: &dyn std::any::Any) -> &str { +pub(crate) fn get_panic_message(any: &dyn std::any::Any) -> &str { if let Some(str_slice) = any.downcast_ref::<&str>() { str_slice } else if let Some(string) = any.downcast_ref::() { @@ -1034,6 +1052,17 @@ where F: Future> + 'static, T: for<'py> IntoPyObject<'py>, { + match locals.kind() { + crate::RuntimeKind::Trio => { + return Err(pyo3::exceptions::PyNotImplementedError::new_err( + "local_future_into_py is not supported under trio: spawn_local requires \ + a thread-local executor (e.g. tokio LocalSet) on the current thread, \ + which the trio event loop thread does not provide. Use future_into_py \ + with a Send future instead.", + )); + } + crate::RuntimeKind::Asyncio => {} + } let (cancel_tx, cancel_rx) = oneshot::channel(); let py_fut = create_future(locals.0.event_loop.clone_ref(py).into_bound(py))?; @@ -1576,19 +1605,31 @@ impl SenderGlue { const STREAM_GLUE: &str = r#" import inspect -async def forward(gen, sender): - async for item in gen: - should_continue = sender.send(item) - - if inspect.isawaitable(should_continue): - should_continue = await should_continue - - if should_continue: - continue - else: - break - - sender.close() +async def forward(gen, sender, swallow): + try: + async for item in gen: + should_continue = sender.send(item) + + if inspect.isawaitable(should_continue): + should_continue = await should_continue + + if should_continue: + continue + else: + break + except BaseException as e: + if not swallow: + raise + # trio system task: re-raising would crash trio.run with + # TrioInternalError. Swallow everything; only print Exception + # subclasses (real errors). Cancelled / KeyboardInterrupt / + # SystemExit are BaseException-only and indicate control flow, + # not failure, so stay silent for those (KI is restricted to + # trio's main task and won't reach a system task in practice). + if isinstance(e, Exception): + import traceback; traceback.print_exc() + finally: + sender.close() "#; /// unstable-streams Convert an async generator into a stream @@ -1727,27 +1768,47 @@ where let (tx, rx) = mpsc::channel(10); - locals.event_loop(py).call_method1( - pyo3::intern!(py, "call_soon_threadsafe"), - ( - locals - .event_loop(py) - .getattr(pyo3::intern!(py, "create_task"))?, - glue.call_method1( - pyo3::intern!(py, "forward"), + let sender = SenderGlue { + locals: locals.clone(), + tx: Arc::new(Mutex::new(GenericSender { + runtime: PhantomData::, + tx, + })), + }; + match locals.kind() { + crate::RuntimeKind::Asyncio => { + call_soon_threadsafe( + &locals.event_loop(py), + &locals.context(py), ( - gen, - SenderGlue { - locals, - tx: Arc::new(Mutex::new(GenericSender { - runtime: PhantomData::, - tx, - })), - }, + locals + .event_loop(py) + .getattr(pyo3::intern!(py, "create_task"))?, + glue.call_method1(pyo3::intern!(py, "forward"), (gen, sender, false))?, ), - )?, - ), - )?; + )?; + } + crate::RuntimeKind::Trio => { + let spawn = crate::trio::trio_spawn_system_task(py)?; + // Bind args via functools.partial so a future positional change to + // spawn_system_task's signature can't silently shift `True`. + let bound_forward = crate::trio::functools_partial(py)?.call1(( + glue.getattr(pyo3::intern!(py, "forward"))?, + gen, + sender, + true, + ))?; + // Propagate contextvars into the system task, matching + // schedule_awaitable. + let kwargs = pyo3::types::PyDict::new(py); + kwargs.set_item(pyo3::intern!(py, "context"), locals.context(py))?; + let bound_spawn = + crate::trio::functools_partial(py)?.call((spawn, bound_forward), Some(&kwargs))?; + locals + .event_loop(py) + .call_method1(pyo3::intern!(py, "run_sync_soon"), (bound_spawn,))?; + } + } Ok(rx) } diff --git a/src/lib.rs b/src/lib.rs index 4ba5961..9ca942d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,14 @@ //! the event loops for both languages. Python's threading model and GIL can make this interop a bit //! trickier than one might expect, so there are a few caveats that users should be aware of. //! +//! ## Beyond Asyncio +//! +//! The conversions in this crate also work under `trio`. [`TaskLocals::current`] detects the +//! running Python async library via `sniffio`, and [`into_future_with_locals`] / +//! [`generic::future_into_py_with_locals`] dispatch on [`TaskLocals::kind`] — so the same +//! `#[pyfunction]` can be awaited from `asyncio.run` or `trio.run` with no code changes. +//! The asyncio code paths are unchanged when running under asyncio. +//! //! ## Why Two Event Loops //! //! Currently, we don't have a way to run Rust futures directly on Python's event loop. Likewise, @@ -137,9 +145,9 @@ //! Python event loop and contextvars associated with the current Rust _task_. //! //! Enter `pyo3_async_runtimes::::get_current_locals`. This function first checks task-local data -//! for the `TaskLocals`, then falls back on `asyncio.get_running_loop` and -//! `contextvars.copy_context` if no task locals are found. This way both bases are -//! covered. +//! for the `TaskLocals`, then falls back on detecting the running Python async library via `sniffio` +//! (`asyncio.get_running_loop` or `trio.lowlevel.current_trio_token`) plus `contextvars.copy_context` +//! if no task locals are found. This way both bases are covered. //! //! Now, all we need is a way to store the `TaskLocals` for the Rust future. Since this is a //! runtime-specific feature, you can find the following functions in each runtime module: @@ -362,6 +370,8 @@ pub mod err; pub mod generic; +mod trio; + #[pymodule] fn pyo3_async_runtimes(py: Python, m: &Bound) -> PyResult<()> { m.add("RustPanic", py.get_type::())?; @@ -470,13 +480,24 @@ fn contextvars(py: Python<'_>) -> PyResult<&Bound<'_, PyAny>> { .bind(py)) } -fn copy_context(py: Python) -> PyResult> { +pub(crate) fn copy_context(py: Python) -> PyResult> { contextvars(py)?.call_method0(pyo3::intern!(py, "copy_context")) } +/// Which Python async library a [`TaskLocals`] was captured under. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum RuntimeKind { + /// Python's built-in `asyncio`. + Asyncio, + /// The `trio` structured-concurrency library. + Trio, +} + /// Task-local inner structure. #[derive(Debug)] struct TaskLocalsInner { + kind: RuntimeKind, /// Track the event loop of the Python task event_loop: Py, /// Track the contextvars of the Python task @@ -491,19 +512,45 @@ impl TaskLocals { /// At a minimum, TaskLocals must store the event loop. pub fn new(event_loop: Bound) -> Self { Self(Arc::new(TaskLocalsInner { + kind: RuntimeKind::Asyncio, context: event_loop.py().None(), event_loop: event_loop.into(), })) } + /// Construct TaskLocals for trio with the given `trio.lowlevel.TrioToken`. + pub fn trio(token: Bound) -> Self { + Self(Arc::new(TaskLocalsInner { + kind: RuntimeKind::Trio, + context: token.py().None(), + event_loop: token.into(), + })) + } + /// Construct TaskLocals with the event loop returned by `get_running_loop` pub fn with_running_loop(py: Python) -> PyResult { Ok(Self::new(get_running_loop(py)?)) } + /// Detect the current Python async library via `sniffio` and capture its + /// loop or token. Falls back to [`with_running_loop`](Self::with_running_loop) + /// if `sniffio` is not installed. + pub fn current(py: Python) -> PyResult { + match crate::trio::sniff(py).as_deref() { + Some("trio") => Ok(Self::trio( + crate::trio::current_trio_token(py)?.into_bound(py), + )), + Some("asyncio") | None => Self::with_running_loop(py), + Some(other) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "unsupported Python async library: {other}" + ))), + } + } + /// Manually provide the contextvars for the current task. pub fn with_context(self, context: Bound) -> Self { Self(Arc::new(TaskLocalsInner { + kind: self.0.kind, event_loop: self.0.event_loop.clone_ref(context.py()), context: context.into(), })) @@ -514,11 +561,27 @@ impl TaskLocals { Ok(self.with_context(copy_context(py)?)) } - /// Get a reference to the event loop + /// Which Python async library these locals were captured under. + pub fn kind(&self) -> RuntimeKind { + self.0.kind + } + + /// Get a reference to the event loop. + /// + /// Under [`RuntimeKind::Trio`] this returns the captured + /// `trio.lowlevel.TrioToken` rather than an asyncio event loop; prefer + /// [`token`](Self::token) when writing runtime-neutral code. pub fn event_loop<'p>(&self, py: Python<'p>) -> Bound<'p, PyAny> { self.0.event_loop.clone_ref(py).into_bound(py) } + /// Get the captured loop handle — the asyncio event loop under + /// [`RuntimeKind::Asyncio`] or the `trio.lowlevel.TrioToken` under + /// [`RuntimeKind::Trio`]. Runtime-neutral alias for [`event_loop`](Self::event_loop). + pub fn token<'p>(&self, py: Python<'p>) -> Bound<'p, PyAny> { + self.event_loop(py) + } + /// Get a reference to the python context pub fn context<'p>(&self, py: Python<'p>) -> Bound<'p, PyAny> { self.0.context.clone_ref(py).into_bound(py) @@ -665,23 +728,42 @@ pub fn into_future_with_locals( let py = awaitable.py(); let (tx, rx) = oneshot::channel(); - call_soon_threadsafe( - &locals.event_loop(py), - &locals.context(py), - (PyEnsureFuture { - awaitable: awaitable.into(), - tx: Some(tx), - },), - )?; + match locals.kind() { + RuntimeKind::Asyncio => { + call_soon_threadsafe( + &locals.event_loop(py), + &locals.context(py), + (PyEnsureFuture { + awaitable: awaitable.into(), + tx: Some(tx), + },), + )?; + } + RuntimeKind::Trio => { + crate::trio::schedule_awaitable( + py, + &locals.event_loop(py), + &locals.context(py), + awaitable, + tx, + )?; + } + } + let kind = locals.kind(); Ok(async move { match rx.await { Ok(item) => item, - Err(_) => Python::attach(|py| { - Err(PyErr::from_value( - asyncio(py)?.call_method0(pyo3::intern!(py, "CancelledError"))?, - )) - }), + Err(_) => match kind { + RuntimeKind::Asyncio => Python::attach(|py| { + Err(PyErr::from_value( + asyncio(py)?.call_method0(pyo3::intern!(py, "CancelledError"))?, + )) + }), + RuntimeKind::Trio => Err(pyo3::exceptions::PyRuntimeError::new_err( + "Python awaitable runner dropped before completion", + )), + }, } }) } diff --git a/src/tokio.rs b/src/tokio.rs index 9b01e5f..dbe494c 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -158,11 +158,12 @@ where TokioRuntime::scope_local(locals, fut).await } -/// Get the current event loop from either Python or Rust async task local context +/// Get the current asyncio event loop from either Python or Rust async task local context /// /// This function first checks if the runtime has a task-local reference to the Python event loop. -/// If not, it calls [`get_running_loop`](`crate::get_running_loop`) to get the event loop -/// associated with the current OS thread. +/// If not, it falls back to [`TaskLocals::current`](crate::TaskLocals::current) (which detects the +/// running async library via `sniffio`). Under `trio` this returns a `RuntimeError` since there is +/// no asyncio event loop; use [`get_current_locals`] instead. pub fn get_current_loop(py: Python) -> PyResult> { generic::get_current_loop::(py) } diff --git a/src/trio.rs b/src/trio.rs new file mode 100644 index 0000000..2738d53 --- /dev/null +++ b/src/trio.rs @@ -0,0 +1,607 @@ +//! `trio` support — Python-side park/wake primitives and a coroutine wrapper +//! that lets Rust futures be awaited under trio. +//! +//! Nothing here is needed for the asyncio path; this module is the +//! implementation detail behind [`RuntimeKind::Trio`](crate::RuntimeKind). +//! +//! `trio` and `sniffio` are imported lazily — if neither is installed the +//! module is inert and the rest of the crate behaves exactly as before. + +use std::future::Future; +use std::panic::AssertUnwindSafe; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +use futures_channel::oneshot; +use futures_util::future::{select, Either, FutureExt}; +use futures_util::task::{waker_ref, ArcWake}; +use pyo3::exceptions::{PyRuntimeError, PyStopIteration}; +use pyo3::prelude::*; +use pyo3::sync::PyOnceLock; +use pyo3::IntoPyObjectExt; + +use crate::err::RustPanic; +use crate::generic::{get_panic_message, ContextExt, Runtime}; + +// --------------------------------------------------------------------------- +// sniffio +// --------------------------------------------------------------------------- + +static SNIFFIO_CURRENT: PyOnceLock>> = PyOnceLock::new(); + +/// Returns `Some("asyncio" | "trio" | ...)` if `sniffio` is importable and a +/// library is running, otherwise `None`. +pub(crate) fn sniff(py: Python<'_>) -> Option { + let current = SNIFFIO_CURRENT + .get_or_init(py, || { + py.import("sniffio") + .and_then(|m| m.getattr(pyo3::intern!(py, "current_async_library"))) + .map(Bound::unbind) + .ok() + }) + .as_ref()?; + current.bind(py).call0().ok()?.extract().ok() +} + +// --------------------------------------------------------------------------- +// trio.lowlevel handles +// --------------------------------------------------------------------------- + +static TRIO_LOWLEVEL: PyOnceLock> = PyOnceLock::new(); +static TRIO_ABORT: PyOnceLock> = PyOnceLock::new(); +static TRIO_CURRENT_TASK: PyOnceLock> = PyOnceLock::new(); +static TRIO_CURRENT_TOKEN: PyOnceLock> = PyOnceLock::new(); +static TRIO_RESCHEDULE: PyOnceLock> = PyOnceLock::new(); +static TRIO_WAIT_TASK_RESCHEDULED: PyOnceLock> = PyOnceLock::new(); +static TRIO_SPAWN_SYSTEM_TASK: PyOnceLock> = PyOnceLock::new(); +static FUNCTOOLS_PARTIAL: PyOnceLock> = PyOnceLock::new(); +static OUTCOME_OUTCOME: PyOnceLock> = PyOnceLock::new(); + +fn trio_lowlevel(py: Python<'_>) -> PyResult<&Bound<'_, PyAny>> { + TRIO_LOWLEVEL + .get_or_try_init(py, || Ok(py.import("trio.lowlevel")?.into())) + .map(|m| m.bind(py)) +} + +fn trio_attr<'py>( + py: Python<'py>, + cell: &'static PyOnceLock>, + name: &'static str, +) -> PyResult<&'py Bound<'py, PyAny>> { + cell.get_or_try_init(py, || Ok(trio_lowlevel(py)?.getattr(name)?.unbind())) + .map(|o| o.bind(py)) +} + +/// `trio.lowlevel.current_trio_token()` — the handle used to schedule callbacks +/// onto the trio run loop from any thread. +pub(crate) fn current_trio_token(py: Python<'_>) -> PyResult> { + trio_attr(py, &TRIO_CURRENT_TOKEN, "current_trio_token")? + .call0() + .map(Bound::unbind) +} + +fn trio_reschedule<'py>(py: Python<'py>) -> PyResult<&'py Bound<'py, PyAny>> { + trio_attr(py, &TRIO_RESCHEDULE, "reschedule") +} + +pub(crate) fn trio_spawn_system_task<'py>(py: Python<'py>) -> PyResult> { + trio_attr(py, &TRIO_SPAWN_SYSTEM_TASK, "spawn_system_task").cloned() +} + +pub(crate) fn functools_partial<'py>(py: Python<'py>) -> PyResult<&'py Bound<'py, PyAny>> { + FUNCTOOLS_PARTIAL + .get_or_try_init(py, || { + Ok(py + .import("functools")? + .getattr(pyo3::intern!(py, "partial"))? + .unbind()) + }) + .map(|o| o.bind(py)) +} + +fn outcome_type<'py>(py: Python<'py>) -> PyResult<&'py Bound<'py, PyAny>> { + OUTCOME_OUTCOME + .get_or_try_init(py, || { + Ok(py + .import("outcome")? + .getattr(pyo3::intern!(py, "Outcome"))? + .unbind()) + }) + .map(|o| o.bind(py)) +} + +// --------------------------------------------------------------------------- +// TrioWaker + WakerCell +// --------------------------------------------------------------------------- + +/// Abort callback passed to `wait_task_rescheduled`. Returns `Abort.FAILED` if +/// a Rust wake is already in flight (so that wake performs the single permitted +/// reschedule); otherwise claims the slot and returns `Abort.SUCCEEDED`. +#[pyclass] +struct TrioAbortFunc { + woken: Arc, +} + +#[pymethods] +impl TrioAbortFunc { + fn __call__(&self, py: Python<'_>, _raise_cancel: Py) -> PyResult> { + let abort = trio_attr(py, &TRIO_ABORT, "Abort")?; + let variant = if self.woken.swap(true, Ordering::AcqRel) { + "FAILED" + } else { + "SUCCEEDED" + }; + abort.getattr(variant).map(Bound::unbind) + } +} + +/// Park/wake primitive for trio: parks via `wait_task_rescheduled`, wakes via +/// `TrioToken.run_sync_soon(reschedule, task)`. +/// +/// Correctness: trio requires **exactly one** `reschedule` per +/// `wait_task_rescheduled` (an `Abort.SUCCEEDED` return counts as that one). +/// An `AtomicBool` guard ensures that a Rust-side wake racing with a trio +/// cancellation cannot produce a double-reschedule. +struct TrioWaker { + task: Py, + token: Py, + woken: Arc, +} + +impl TrioWaker { + fn new(py: Python<'_>) -> PyResult { + Ok(Self { + task: trio_attr(py, &TRIO_CURRENT_TASK, "current_task")? + .call0()? + .unbind(), + token: current_trio_token(py)?, + // Starts "woken" (not armed) so a synchronous self-wake before the + // first `yield_` cannot queue a spurious reschedule that would + // later fire against an unrelated park. + woken: Arc::new(AtomicBool::new(true)), + }) + } + + fn yield_(&self, py: Python<'_>) -> PyResult> { + self.woken.store(false, Ordering::Release); + let abort = TrioAbortFunc { + woken: self.woken.clone(), + }; + // We extract the trap object from the `wait_task_rescheduled` coroutine + // and discard the coroutine itself. This relies on that coroutine having + // no `finally:`/cleanup body — true of trio's implementation today + // (https://trio.readthedocs.io/en/stable/reference-lowlevel.html#trio.lowlevel.wait_task_rescheduled), + // but an implementation detail rather than a stability guarantee. + let result = trio_attr(py, &TRIO_WAIT_TASK_RESCHEDULED, "wait_task_rescheduled") + .and_then(|f| f.call1((abort,))) + .and_then(|c| c.call_method0(pyo3::intern!(py, "__await__"))) + .and_then(|i| i.call_method0(pyo3::intern!(py, "__next__"))) + .map(Bound::unbind); + if result.is_err() { + self.woken.store(true, Ordering::Release); + } + result + } + + fn traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> { + visit.call(&self.task)?; + visit.call(&self.token) + } + + fn wake_threadsafe(&self, py: Python<'_>) { + if self.woken.swap(true, Ordering::AcqRel) { + return; + } + let reschedule = match trio_reschedule(py) { + Ok(r) => r, + Err(e) => { + e.print_and_set_sys_last_vars(py); + return; + } + }; + // `run_sync_soon` may raise `RunFinishedError` if `trio.run` already + // exited; that is benign during shutdown. + if let Err(e) = self.token.bind(py).call_method1( + pyo3::intern!(py, "run_sync_soon"), + (reschedule, self.task.bind(py)), + ) { + e.print_and_set_sys_last_vars(py); + } + } +} + +/// `Arc`-able cell that adapts a `TrioWaker` into a `std::task::Waker` via +/// `ArcWake`. +pub(crate) struct WakerCell { + inner: Mutex>>, + /// Set by `wake_by_ref` and consumed by `yield_`; lets a wake that arrives + /// before the `TrioWaker` is installed (or between polls) trigger an + /// immediate re-poll instead of being lost. + pending_wake: AtomicBool, +} + +impl WakerCell { + fn new() -> Arc { + Arc::new(Self { + inner: Mutex::new(None), + pending_wake: AtomicBool::new(false), + }) + } + + /// Ensure a `TrioWaker` is installed, creating one if absent, then return + /// the value to yield from `__next__`. Returns `None` if a wake was + /// recorded while no waker was armed — the caller should re-poll instead + /// of yielding. + fn yield_(&self, py: Python<'_>) -> PyResult>> { + // Acquire (or lazily create) the waker. The mutex is never held across + // a Python FFI call, avoiding GIL/mutex lock-order inversion. + let waker: Arc = { + let existing = self.inner.lock().unwrap().clone(); + match existing { + Some(w) => w, + None => { + // Defensive cross-check: this path is only reached when + // `RuntimeKind::Trio` was detected at TaskLocals + // construction time, but if the coroutine is then awaited + // under a different runtime, fail clearly rather than with + // an opaque `current_task()` error. `None` (sniffio absent) + // is treated as trio since detection already succeeded. + if !matches!(sniff(py).as_deref(), Some("trio") | None) { + return Err(PyRuntimeError::new_err( + "RustCoroutine awaited outside trio; use future_into_py with \ + the running library's TaskLocals (or no explicit locals) instead", + )); + } + let w = Arc::new(TrioWaker::new(py)?); + *self.inner.lock().unwrap() = Some(w.clone()); + w + } + } + }; + if self.pending_wake.swap(false, Ordering::AcqRel) { + return Ok(None); + } + let yielded = waker.yield_(py)?; + // A wake_by_ref that lands while yield_ is arming (e.g., between the + // swap above and TrioWaker setting woken=false) would otherwise be a + // lost wake. Re-check and self-trigger so the just-armed park + // resolves immediately instead of deadlocking. + if self.pending_wake.swap(false, Ordering::AcqRel) { + waker.wake_threadsafe(py); + } + Ok(Some(yielded)) + } + + fn clear(&self) { + *self.inner.lock().unwrap() = None; + self.pending_wake.store(false, Ordering::Release); + } + + fn traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> { + // GC traverse must not block; on free-threaded builds GC may run + // concurrently with `wake_by_ref`. If contended, skip — missed this + // collection; the cycle will be caught on a later one. + let waker = match self.inner.try_lock() { + Ok(guard) => guard.clone(), + Err(_) => return Ok(()), + }; + if let Some(w) = waker { + w.traverse(visit)?; + } + Ok(()) + } +} + +impl ArcWake for WakerCell { + fn wake_by_ref(arc_self: &Arc) { + arc_self.pending_wake.store(true, Ordering::Release); + // Clone the waker out so the mutex is released before any Python call; + // `Python::attach` may block on the GIL and must not happen while + // holding the mutex. + let waker = arc_self.inner.lock().unwrap().clone(); + if let Some(w) = waker { + Python::attach(|py| w.wake_threadsafe(py)); + } + } +} + +// --------------------------------------------------------------------------- +// Coroutine pyclass +// --------------------------------------------------------------------------- + +type BoxFut = Pin>> + Send>>; + +/// Python awaitable backed by a Rust future. Parks the awaiting trio task via +/// `trio.lowlevel.wait_task_rescheduled` and wakes via `reschedule`. +/// +/// This is distinct from `pyo3::coroutine::Coroutine` (the `experimental-async` +/// feature). That pyclass polls the user's future inline from `__next__`; this +/// one wraps a `oneshot::Receiver` whose sender is fed by the user's future +/// running on a separate Rust executor via `R::spawn`, so the only park/wake +/// state needed is the trio reschedule. Converging the two would mean teaching +/// pyo3-core's waker about trio, which is tracked upstream rather than here. +#[pyclass(name = "RustCoroutine")] +pub(crate) struct Coroutine { + fut: Mutex>, + waker_cell: Arc, + /// Dropping this signals the spawned Rust task (if any) to abort. + cancel_tx: Mutex>>, +} + +impl Coroutine { + fn with_cancel(fut: BoxFut, cancel_tx: oneshot::Sender<()>) -> Self { + Self { + fut: Mutex::new(Some(fut)), + waker_cell: WakerCell::new(), + cancel_tx: Mutex::new(Some(cancel_tx)), + } + } + + fn finish(&mut self) { + *self.fut.get_mut().unwrap() = None; + self.waker_cell.clear(); + *self.cancel_tx.get_mut().unwrap() = None; + } + + fn poll(&mut self, py: Python<'_>) -> PyResult> { + let cell = self.waker_cell.clone(); + let std_waker = waker_ref(&cell); + let mut cx = Context::from_waker(&std_waker); + // The boxed `fut` is always `rx.await` for a oneshot receiver (see + // `future_into_coroutine`), woken exactly once when the spawned Rust + // task sends. The loop only iterates more than once if that wake + // races in between `poll` and `yield_`, in which case the second poll + // is `Ready`. + loop { + let fut_slot = self.fut.get_mut().unwrap(); + let fut = match fut_slot.as_mut() { + Some(f) => f, + None => { + return Err(PyRuntimeError::new_err( + "cannot reuse already awaited RustCoroutine", + )) + } + }; + match fut.as_mut().poll(&mut cx) { + Poll::Ready(res) => { + self.finish(); + return match res { + Ok(v) => Err(PyStopIteration::new_err(v)), + Err(e) if e.is_instance_of::(py) => { + let wrapped = PyRuntimeError::new_err("coroutine raised StopIteration"); + wrapped.set_cause(py, Some(e)); + Err(wrapped) + } + Err(e) => Err(e), + }; + } + Poll::Pending => match cell.yield_(py) { + Ok(Some(yielded)) => return Ok(yielded), + Ok(None) => continue, + Err(e) => { + self.finish(); + return Err(e); + } + }, + } + } + } +} + +#[pymethods] +impl Coroutine { + #[classattr] + #[pyo3(name = "__name__")] + fn name() -> &'static str { + "RustCoroutine" + } + + #[classattr] + #[pyo3(name = "__qualname__")] + fn qualname() -> &'static str { + "RustCoroutine" + } + + fn __await__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(&mut self, py: Python<'_>) -> PyResult> { + self.poll(py) + } + + /// trio's runner drives us via `coro.send(outcome)` (see + /// `trio/_core/_run.py` — it uses `send` not `outcome.send(coro)` to + /// work around CPython `throw()` bugs). After a successful abort the + /// outcome is `Error(Cancelled)`; unwrap it so cancellation propagates + /// instead of being silently dropped into a re-park busy loop. + #[pyo3(signature = (value = None))] + fn send(&mut self, py: Python<'_>, value: Option>) -> PyResult> { + if let Some(v) = value { + let v = v.bind(py); + // trio hard-depends on `outcome`; surface ImportError loudly + // rather than caching None and silently re-polling. + let outcome_ty = outcome_type(py)?; + if v.is_instance(outcome_ty).unwrap_or(false) { + if let Err(e) = v.call_method0(pyo3::intern!(py, "unwrap")) { + self.finish(); + return Err(e); + } + // Value(x).unwrap() returns x — discarded; trio only sends + // Value(None) on resume or Error(Cancelled) on abort. + } + } + self.poll(py) + } + + /// Single-argument `throw(exc)` only — the 3-arg `throw(type, value, tb)` + /// form is deprecated since CPython 3.12 and trio's runner never uses + /// `throw()` (it sends `outcome.Error` via `send()` instead). + fn throw(&mut self, exc: Bound<'_, PyAny>) -> PyResult<()> { + let err = if let Ok(ty) = exc.cast::() { + PyErr::from_type(ty.clone(), ()) + } else { + PyErr::from_value(exc) + }; + self.finish(); + Err(err) + } + + fn close(&mut self) { + self.finish(); + } + + fn __traverse__(&self, visit: pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> { + self.waker_cell.traverse(&visit) + } + + fn __clear__(&mut self) { + self.finish(); + } +} + +// --------------------------------------------------------------------------- +// into_future_with_locals — trio arm +// --------------------------------------------------------------------------- + +static TRIO_RUNNER: PyOnceLock> = PyOnceLock::new(); + +const TRIO_RUNNER_SRC: &str = r#" +async def _runner(awaitable, completer): + try: + result = await awaitable + except BaseException as exc: + completer.set_error(exc) + else: + completer.set_result(result) +"#; + +fn trio_runner(py: Python<'_>) -> PyResult<&Bound<'_, PyAny>> { + TRIO_RUNNER + .get_or_try_init(py, || -> PyResult> { + let module = PyModule::from_code( + py, + &std::ffi::CString::new(TRIO_RUNNER_SRC).unwrap(), + &std::ffi::CString::new("pyo3_async_runtimes/_trio_runner.py").unwrap(), + &std::ffi::CString::new("pyo3_async_runtimes_trio_runner").unwrap(), + )?; + Ok(module.getattr(pyo3::intern!(py, "_runner"))?.unbind()) + }) + .map(|o| o.bind(py)) +} + +/// Receives the outcome of a trio system task and forwards it to a Rust +/// oneshot channel. +#[pyclass] +struct OneshotSender { + tx: Mutex>>>>, +} + +#[pymethods] +impl OneshotSender { + fn set_result(&self, value: Py) { + if let Some(tx) = self.tx.lock().unwrap().take() { + let _ = tx.send(Ok(value)); + } + } + + fn set_error(&self, exc: Bound<'_, PyAny>) { + if let Some(tx) = self.tx.lock().unwrap().take() { + let _ = tx.send(Err(PyErr::from_value(exc))); + } + } +} + +/// Schedule `awaitable` as a trio system task via `token.run_sync_soon`, in +/// the captured `context`, and send its outcome through `tx`. +pub(crate) fn schedule_awaitable( + py: Python<'_>, + token: &Bound<'_, PyAny>, + context: &Bound<'_, PyAny>, + awaitable: Bound<'_, PyAny>, + tx: oneshot::Sender>>, +) -> PyResult<()> { + let completer = Bound::new( + py, + OneshotSender { + tx: Mutex::new(Some(tx)), + }, + )?; + let runner = trio_runner(py)?; + let spawn = trio_spawn_system_task(py)?; + // run_sync_soon only forwards positional args, so bind context= via + // functools.partial. spawn_system_task(context=...) requires trio>=0.23. + let kwargs = pyo3::types::PyDict::new(py); + kwargs.set_item(pyo3::intern!(py, "context"), context)?; + let bound_spawn = + functools_partial(py)?.call((spawn, runner, awaitable, completer), Some(&kwargs))?; + token.call_method1(pyo3::intern!(py, "run_sync_soon"), (bound_spawn,))?; + Ok(()) +} + +// --------------------------------------------------------------------------- +// future_into_py_with_locals — trio arm +// --------------------------------------------------------------------------- + +/// Spawn `fut` on `R` (with task-local propagation) and return a [`Coroutine`] +/// that awaits its result. Called from +/// [`generic::future_into_py_with_locals`](crate::generic::future_into_py_with_locals) +/// when `locals.kind() == Trio`. +#[allow(unused_must_use)] // R::spawn / R::spawn_blocking JoinHandles intentionally fire-and-forget +pub(crate) fn future_into_coroutine( + py: Python<'_>, + locals: crate::TaskLocals, + fut: F, +) -> PyResult> +where + R: Runtime + ContextExt, + F: Future> + Send + 'static, + T: for<'py> IntoPyObject<'py> + Send + 'static, +{ + let (tx, rx) = oneshot::channel(); + let (cancel_tx, cancel_rx) = oneshot::channel::<()>(); + R::spawn(async move { + let scoped = R::scope(locals, async move { + let fut = std::pin::pin!(fut); + match select(cancel_rx, fut).await { + Either::Left(_) => None, + Either::Right((result, _)) => Some(result), + } + }); + match AssertUnwindSafe(scoped).catch_unwind().await { + Err(payload) => { + let msg = get_panic_message(&*payload).to_owned(); + // Same GIL rationale as the success path below: tx.send wakes + // the Python-side receiver, which acquires the GIL. + R::spawn_blocking(move || { + let _ = tx.send(Err(RustPanic::new_err(format!( + "rust future panicked: {msg}" + )))); + }); + } + Ok(None) => {} + Ok(Some(result)) => { + // Do not block a tokio worker thread on the GIL — same rationale + // as `generic::future_into_py_with_locals`. + R::spawn_blocking(move || { + let py_result = Python::attach(|py| result.and_then(|v| v.into_py_any(py))); + let _ = tx.send(py_result); + }); + } + } + }); + let boxed: BoxFut = Box::pin(async move { + rx.await.unwrap_or_else(|_| { + Err(PyRuntimeError::new_err( + "Rust task was dropped before completion", + )) + }) + }); + Coroutine::with_cancel(boxed, cancel_tx).into_bound_py_any(py) +}