diff --git a/.github/workflows/python-wheels.yml b/.github/workflows/python-wheels.yml new file mode 100644 index 0000000..f2fa2db --- /dev/null +++ b/.github/workflows/python-wheels.yml @@ -0,0 +1,180 @@ +name: Python wheels + +on: + push: + branches: [main] + tags: ["py-v*", "py-test-*"] + pull_request: + branches: [main] + workflow_dispatch: # manual ad-hoc builds from any branch + +concurrency: + # Tag pushes get their own group so publishes never get cancelled. + group: >- + ${{ github.workflow }}-${{ github.ref }}-${{ startsWith(github.ref, 'refs/tags/') && 'publish' || 'branch' }} + cancel-in-progress: ${{ !startsWith(github.ref, 'refs/tags/') }} + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + name: Build wheel (${{ matrix.target.label }}) + runs-on: ${{ matrix.target.runner }} + strategy: + fail-fast: false + matrix: + target: + # `manylinux: "2_28"` makes maturin-action run the build inside the + # official PyPA manylinux_2_28 container (Rocky Linux 8 / glibc 2.28). + # Without this, the build runs on the host (Ubuntu glibc 2.39) and + # produces a wheel that fails the auditwheel manylinux_2_28 check. + - label: linux-x86_64 + runner: ubuntu-latest + target: x86_64-unknown-linux-gnu + manylinux: "2_28" + - label: macos-universal2 + runner: macos-latest + target: universal2-apple-darwin + manylinux: "auto" + - label: windows-x86_64 + runner: windows-latest + target: x86_64-pc-windows-msvc + manylinux: "auto" + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.10" # abi3 — any 3.10+ works for building + + - uses: dtolnay/rust-toolchain@stable + with: + targets: ${{ matrix.target.label == 'macos-universal2' && 'x86_64-apple-darwin,aarch64-apple-darwin' || '' }} + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: bonsai-py + key: ${{ matrix.target.label }} + + - name: Tag/version guard (tag pushes only) + if: startsWith(github.ref, 'refs/tags/') + shell: bash + run: | + tag="${GITHUB_REF#refs/tags/}" + version="${tag#py-v}" + version="${version#py-test-}" + cargo_version=$(grep -m1 '^version' bonsai-py/Cargo.toml | sed -E 's/.*"([^"]+)".*/\1/') + if [ "$version" != "$cargo_version" ]; then + echo "::error::Tag version '$version' does not match Cargo.toml version '$cargo_version'." + exit 1 + fi + + - uses: PyO3/maturin-action@v1 + with: + working-directory: bonsai-py + command: build + target: ${{ matrix.target.target }} + manylinux: ${{ matrix.target.manylinux }} + args: --release --out dist --strip + + - name: Verify wheel (Linux/macOS only — Windows venv quirks) + if: matrix.target.runner != 'windows-latest' + shell: bash + run: | + python -m venv .venv-test + source .venv-test/bin/activate + pip install --upgrade pip + pip install pytest pytest-timeout mypy + pip install bonsai-py/dist/*.whl + python -c "import bonsai_bt; print(bonsai_bt.__version__)" + pytest bonsai-py/tests/ + + - name: Wheel size sanity (Linux/macOS only) + if: matrix.target.runner != 'windows-latest' + shell: bash + run: | + size=$(stat -c%s bonsai-py/dist/*.whl 2>/dev/null || stat -f%z bonsai-py/dist/*.whl) + ceiling=$((5 * 1024 * 1024)) + if [ "$size" -gt "$ceiling" ]; then + echo "::error::Wheel size $size exceeds 5MB ceiling." + exit 1 + fi + echo "Wheel size: $size bytes (under 5MB ceiling)." + + - uses: actions/upload-artifact@v4 + with: + name: wheel-${{ matrix.target.label }} + path: bonsai-py/dist/*.whl + retention-days: ${{ startsWith(github.ref, 'refs/tags/') && 90 || 14 }} + + sdist: + name: Build sdist + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - uses: PyO3/maturin-action@v1 + with: + working-directory: bonsai-py + command: sdist + args: --out dist + - uses: actions/upload-artifact@v4 + with: + name: sdist + path: bonsai-py/dist/*.tar.gz + retention-days: ${{ startsWith(github.ref, 'refs/tags/') && 90 || 14 }} + + verify-sdist: + name: Verify sdist builds from source + needs: sdist + runs-on: ubuntu-latest + steps: + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - uses: dtolnay/rust-toolchain@stable + - uses: actions/download-artifact@v4 + with: + name: sdist + path: dist + - name: Install + smoke-test from sdist + run: | + python -m venv .venv-sdist + source .venv-sdist/bin/activate + pip install --upgrade pip + pip install --no-binary :all: dist/*.tar.gz + python -c "import bonsai_bt; print(bonsai_bt.__version__)" + + publish: + name: Publish to PyPI / TestPyPI + needs: [build, sdist, verify-sdist] + if: startsWith(github.ref, 'refs/tags/py-v') || startsWith(github.ref, 'refs/tags/py-test-') + runs-on: ubuntu-latest + environment: + name: ${{ startsWith(github.ref, 'refs/tags/py-test-') && 'testpypi' || 'pypi' }} + permissions: + id-token: write # OIDC for Trusted Publishing + steps: + - uses: actions/download-artifact@v4 + with: + path: dist + pattern: wheel-* + merge-multiple: true + + - uses: actions/download-artifact@v4 + with: + name: sdist + path: dist + + - name: List artifacts to publish + run: ls -la dist/ + + - uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: >- + ${{ startsWith(github.ref, 'refs/tags/py-test-') && 'https://test.pypi.org/legacy/' || 'https://upload.pypi.org/legacy/' }} + packages-dir: dist + skip-existing: true diff --git a/.github/workflows/rust-pr.yml b/.github/workflows/rust-pr.yml index 5026a5f..0c600a5 100644 --- a/.github/workflows/rust-pr.yml +++ b/.github/workflows/rust-pr.yml @@ -28,3 +28,29 @@ jobs: run: cargo build --examples - name: Run tests run: cargo test --verbose + pytest: + runs-on: ubuntu-latest + strategy: + matrix: + python: ["3.10", "3.13"] + steps: + - uses: actions/checkout@v6.0.2 + - uses: actions/setup-python@v6.2.0 + with: + python-version: ${{ matrix.python }} + - uses: dtolnay/rust-toolchain@stable + - name: Create + activate venv (maturin develop requires one) + # $GITHUB_PATH prepends to PATH for every subsequent step; + # $GITHUB_ENV exports VIRTUAL_ENV (which maturin/pip detect) + # so we don't need to `source venv/bin/activate` in each step. + run: | + python -m venv $HOME/.venv + echo "$HOME/.venv/bin" >> $GITHUB_PATH + echo "VIRTUAL_ENV=$HOME/.venv" >> $GITHUB_ENV + - name: Install maturin and test deps + run: pip install "maturin>=1.7,<2.0" pytest pytest-timeout mypy + - name: Build and install bonsai-py + working-directory: bonsai-py + run: maturin develop --release + - name: Run pytest + run: pytest -v bonsai-py/tests/ diff --git a/.gitignore b/.gitignore index cbed4d7..d028cad 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,16 @@ Cargo.lock **/*.rs.bk .idea/ + +# Python virtual environments +.venv/ +venv/ + +# Python build artifacts (maturin develop output, bytecode caches) +__pycache__/ +*.pyc +*.pyo +*.so +*.abi3.so +*.pyd +*.dylib diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ddbb47..eecd0f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,10 +29,10 @@ repos: - id: end-of-file-fixer - id: file-contents-sorter - id: fix-byte-order-marker - - id: fix-encoding-pragma - id: forbid-new-submodules - id: mixed-line-ending - id: name-tests-test + args: [--pytest-test-first] - id: requirements-txt-fixer - id: sort-simple-yaml - id: trailing-whitespace @@ -50,3 +50,10 @@ repos: pass_filenames: false types: [file, rust] language: system + - id: regen-stubs + name: regenerate bonsai_py type stub + description: Regenerate python/bonsai_py/__init__.pyi from #[gen_stub_*] annotations. If the regenerated stub differs from the committed version, the hook fails so the developer can stage the update. + entry: bash bonsai-py/scripts/regen-stubs.sh + language: system + files: ^(bonsai-py/src/.*\.rs|bonsai-py/python/bonsai_py/__init__\.pyi)$ + pass_filenames: false diff --git a/Cargo.toml b/Cargo.toml index 3dea7bd..cb0fed6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,3 @@ [workspace] resolver = "2" -members = ["bonsai", "examples"] +members = ["bonsai", "examples", "bonsai-py"] diff --git a/README.md b/README.md index 7bea51c..42b6d6b 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,9 @@ * [Honorable Mentions](#similar-crates) ## Using Bonsai + +### Rust + Bonsai is available on crates.io. The recommended way to use it is to add a line into your Cargo.toml such as: ```toml @@ -36,6 +39,10 @@ Bonsai is available on crates.io. The recommended way to use it is to add a line bonsai-bt = "*" ``` +### Python + +Python bindings are available — see [`bonsai-py/`](bonsai-py/) for installation, examples, and a side-by-side comparison of the same BT in Rust and Python. The package wraps the same Rust crate, so the BT semantics are identical; only the API surface differs. + ## What is a Behavior Tree? A _Behavior Tree_ (BT) is a data structure in which we can set the rules of how certain _behavior's_ can occur, and the order in which they would execute. BTs are a very efficient way of creating complex systems that are both modular and reactive. These properties are crucial in many applications, which has led to the spread of BT from computer game programming to many branches of AI and Robotics. diff --git a/bonsai-py/Cargo.toml b/bonsai-py/Cargo.toml new file mode 100644 index 0000000..844f5be --- /dev/null +++ b/bonsai-py/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "bonsai-py" +version = "0.12.0" +edition = "2021" +rust-version = "1.80.0" +description = "Python bindings for the bonsai-bt behavior tree library" +license = "MIT" +authors = ["Kristoffer Solberg Rakstad "] +repository = "https://github.com/sollimann/bonsai.git" +homepage = "https://github.com/sollimann/bonsai" +publish = false + +[lib] +# Internal Rust crate name — kept as `bonsai_py` to avoid colliding with the +# workspace's `bonsai-bt` crate at `bonsai/`, which also produces `libbonsai_bt.rlib`. +# The Python-facing module name is controlled separately by +# `[tool.maturin] module-name = "bonsai_bt"` in pyproject.toml. +name = "bonsai_py" +crate-type = ["cdylib", "rlib"] + +[[bin]] +name = "stub_gen" +path = "src/bin/stub_gen.rs" + +[dependencies] +# Note: `extension-module` is enabled by maturin via pyproject.toml's +# `[tool.maturin].features` setting. Keeping it out of the default feature +# list lets `cargo run --bin stub_gen` link libpython for the regular binary +# path; maturin still activates it for the wheel build. +pyo3 = { version = "0.28", features = ["abi3-py310"] } +bonsai-bt = { path = "../bonsai", version = "0.12", features = ["visualize"] } +pyo3-stub-gen = "0.22.3" diff --git a/bonsai-py/README.md b/bonsai-py/README.md new file mode 100644 index 0000000..d935c9a --- /dev/null +++ b/bonsai-py/README.md @@ -0,0 +1,81 @@ +# bonsai-bt - Python bindings + +Python bindings for the [bonsai-bt](https://github.com/sollimann/bonsai) +behavior-tree library. + +## Installation (dev) + +```bash +python -m venv .venv +source .venv/bin/activate # Windows: .venv\Scripts\Activate.ps1 +pip install maturin +cd bonsai-py +maturin develop +python -c "import bonsai_bt; print(bonsai_bt.__version__)" +``` + +## Same BT in Rust and Python + +A minimal three-node tree (`Hello → Wait(1.0) → Goodbye`) implemented in both languages. Semantics are identical because the Python package is a thin wrapper around the Rust crate; only the API surface differs (Rust requires an `enum` + explicit types; Python uses any hashable object as the action payload). + +### Rust + +```rust +use bonsai_bt::{Behavior, Event, Status, UpdateArgs, BT}; + +#[derive(Clone, Debug)] +enum Greet { Hello, Goodbye } + +fn main() { + let tree = Behavior::Sequence(vec![ + Behavior::Action(Greet::Hello), + Behavior::Wait(1.0), + Behavior::Action(Greet::Goodbye), + ]); + + let mut bt: BT = BT::new(tree, ()); + + for _ in 0..5 { + let e: Event = UpdateArgs { dt: 0.5 }.into(); + bt.tick(&e, &mut |args, _bb| { + match *args.action { + Greet::Hello => println!("hello"), + Greet::Goodbye => println!("goodbye"), + } + (Status::Success, args.dt) + }); + } +} +``` + +### Python + +```python +import bonsai_bt as bt + +tree = bt.Sequence([ + bt.Action("hello"), + bt.Wait(1.0), + bt.Action("goodbye"), +]) + +tree_bt = bt.BT(tree, None) + +def cb(args, _bb): + print(args.action) + return (bt.Status.Success, args.dt) + +for _ in range(5): + tree_bt.tick(0.5, cb) +``` + +Output (both): + + hello + goodbye + +For richer examples — multi-job orchestration, visualizer integration, parallel agents — see [examples/](examples/). + +## License + +MIT - see [LICENSE](../LICENSE). diff --git a/bonsai-py/examples/README.md b/bonsai-py/examples/README.md new file mode 100644 index 0000000..74ec2ae --- /dev/null +++ b/bonsai-py/examples/README.md @@ -0,0 +1,76 @@ +# bonsai-py examples + +Pure-Python examples mirroring `examples/` in the Rust workspace. Each example is a single self-contained `.py` file. + +## Prerequisites + +Create and activate a Python venv (one-time), then build & install the extension: + +```bash +# 1. Create a venv (only needed the first time) +python3 -m venv .venv + +# 2. Activate it (every new shell) +source .venv/bin/activate # macOS / Linux / WSL +# .\.venv\Scripts\Activate.ps1 # Windows PowerShell + +# 3. Install build deps + build the extension into the venv +pip install maturin +cd bonsai-py && maturin develop --release && cd .. +``` + +After that, just `source .venv/bin/activate` + `python bonsai-py/examples/.py` in any new shell. + +## Examples (7) + +### [simple_npc_ai.py](simple_npc_ai.py) — console NPC +NPC runs and shoots until action points are exhausted, then rests and dies. Demonstrates `WhileAll`, blackboard mutation via `@dataclass`, structural-`match` callback. + +```bash +python bonsai-py/examples/simple_npc_ai.py +``` + +### [race_timeout.py](race_timeout.py) — `Race` between work and timeout +A simulated long-running job (random 200–1200 ms on a `threading.Thread`) races a 600 ms timeout. The callback polls the work's `queue.Queue` non-blockingly. Demonstrates `Race`, asyncio main loop + threading worker, the unsendable-BT constraint. + +```bash +python bonsai-py/examples/race_timeout.py +``` + +### [graphviz_demo.py](graphviz_demo.py) — tree visualization +Builds an attack-drone tree (mix of plain-string and `@dataclass(frozen=True)` payload actions) and prints the graphviz DOT representation. Paste the output into to render it. + +```bash +python bonsai-py/examples/graphviz_demo.py +python bonsai-py/examples/graphviz_demo.py > tree.dot +``` + +### [visualizer_smoke.py](visualizer_smoke.py) — live web visualizer +Drives a deliberately rich 27-node tree at ~400 ms/tick with a 5-step status rotation and per-leaf phase offset; the browser shows continuous color animation. Demonstrates `BT.with_telemetry(port)`, `reset_bt()`, and every major factory. + +```bash +python bonsai-py/examples/visualizer_smoke.py +``` + +Then open in a browser. `Ctrl-C` to stop. + +### [boids_console.py](boids_console.py) — shared BT across N agents +Builds **one** `Behavior` tree and binds it to 10 independent `BT` instances (each with its own `Boid` dataclass blackboard). Updates positions every tick for 30 frames. Demonstrates the shared-subtree pattern, real-time-loop dt, `WhenAll` for parallel updates. + +```bash +python bonsai-py/examples/boids_console.py +``` + +### [threaded_drone.py](threaded_drone.py) — multi-job mission (threading) +Drone mission: takeoff → check battery → fly (or fall back to land) → land → repeat. Each long-running step runs on a background `threading.Thread`; the BT polls per-job `queue.Queue`s. Prints the tree's `graphviz()` at the start, then runs the mission for ~8 seconds. **Pick this variant when actions block on hardware or sync IO** — blocking syscalls, vendor SDKs without an async API, `subprocess.run`, etc. + +```bash +python bonsai-py/examples/threaded_drone.py +``` + +### [async_drone.py](async_drone.py) — multi-job mission (asyncio) +Same mission and tree as `threaded_drone.py`, but background jobs are `async def` coroutines on a single asyncio event loop, communicating via `asyncio.Queue`. **Pick this variant when actions are awaitable** (`aiohttp`, async DB drivers, websockets) — N-way concurrency without OS thread overhead. + +```bash +python bonsai-py/examples/async_drone.py +``` diff --git a/bonsai-py/examples/__init__.py b/bonsai-py/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bonsai-py/examples/async_drone.py b/bonsai-py/examples/async_drone.py new file mode 100644 index 0000000..c63ddde --- /dev/null +++ b/bonsai-py/examples/async_drone.py @@ -0,0 +1,217 @@ +""" +Async drone mission with multi-job orchestration (asyncio-native). + +Same mission as `threaded_drone.py`: + + While( + cond = AvoidOthers, + body = [ + TakeOff, + Select([ + Sequence([CheckBattery, FlyToPoint(10, 10, 10)]), + FlyToPoint(0, 0, 0), # dock fallback + ]), + Land, + ], + ) + +…but each long-running action is an `async def` coroutine scheduled on +the same asyncio event loop as the BT tick loop, communicating with the +BT via `asyncio.Queue`. No OS threads — every job is cooperatively +scheduled on a single thread. + +**Pick this variant when actions are awaitable** — `aiohttp`, async DB +drivers (`asyncpg`, `motor`), websockets, async stream readers, async +subprocess. + +How the integration works: + +* `BT.tick()` is synchronous. The callback we pass to it is also + synchronous — it polls each per-job `asyncio.Queue` with the + non-blocking `get_nowait()`. The callback never awaits. +* On the first call for a given action, the callback spawns the job + coroutine via `asyncio.create_task(job(q))`. The task starts running + on the next loop iteration. +* The main loop is `async def` and uses `await asyncio.sleep(0.5)` to + yield control back to the event loop between BT ticks. While main + is sleeping, the job coroutines run and push status updates into + their queues. + +Run: + python bonsai-py/examples/async_drone.py +""" +from __future__ import annotations + +import asyncio +import enum +import random +import time +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Optional + +import bonsai_bt as bt + +MAX_WALL_SECONDS = 8.0 # Demo cap; the BT itself would loop forever. + + +class DroneAction(enum.Enum): + AVOID_OTHERS = enum.auto() + TAKE_OFF = enum.auto() + LAND = enum.auto() + CHECK_BATTERY = enum.auto() + + +@dataclass(frozen=True) +class FlyToPoint: + x: float + y: float + z: float + + +@dataclass +class DroneState: + avoid_others: Optional[asyncio.Queue[bt.Status]] = None + takeoff: Optional[asyncio.Queue[bt.Status]] = None + land: Optional[asyncio.Queue[bt.Status]] = None + fly_to_point: Optional[asyncio.Queue[bt.Status]] = None + + +# ---- Background "jobs": one coroutine per long-running action -------------- +# Each pushes Status.Running every step, then Status.Success when done. + +async def collision_avoidance_task(q: asyncio.Queue[bt.Status]) -> None: + print("collision avoidance task started") + while True: + q.put_nowait(bt.Status.Running) + await asyncio.sleep(0.1) + + +async def takeoff_task(q: asyncio.Queue[bt.Status]) -> None: + print("takeoff task started") + for i in range(3): + q.put_nowait(bt.Status.Running) + await asyncio.sleep(0.3) + print(f"takeoff task running for {(i + 1) * 300} ms") + print("takeoff task finished") + q.put_nowait(bt.Status.Success) + + +async def landing_task(q: asyncio.Queue[bt.Status]) -> None: + print("landing task started") + for i in range(3): + q.put_nowait(bt.Status.Running) + await asyncio.sleep(0.3) + print(f"landing task running for {(i + 1) * 300} ms") + print("landing task finished") + q.put_nowait(bt.Status.Success) + + +async def fly_to_point_task(point: FlyToPoint, q: asyncio.Queue[bt.Status]) -> None: + print(f"flying task started: target=({point.x}, {point.y}, {point.z})") + for i in range(3): + q.put_nowait(bt.Status.Running) + await asyncio.sleep(0.5) + print(f"flying task running for {(i + 1) * 500} ms") + print("flying task finished") + q.put_nowait(bt.Status.Success) + + +# ---- Polling helpers ------------------------------------------------------- + +SpawnCoro = Callable[[asyncio.Queue[bt.Status]], Awaitable[None]] + + +def poll_job( + q_attr: str, + state: DroneState, + spawn: SpawnCoro, + dt: float, +) -> tuple[bt.Status, float]: + """Generic 'schedule on first call, poll thereafter' pattern for any async job. + + Calls `asyncio.create_task(spawn(q))` on first invocation; subsequent + invocations drain the queue non-blockingly. Must be called from inside + a running event loop (the BT tick runs from inside `await asyncio.sleep` + in `main`, so this holds). + """ + q = getattr(state, q_attr) + if q is None: + q = asyncio.Queue() + setattr(state, q_attr, q) + asyncio.create_task(spawn(q)) + try: + status = q.get_nowait() + except asyncio.QueueEmpty: + return bt.RUNNING + if status == bt.Status.Running: + return bt.RUNNING + setattr(state, q_attr, None) + return (status, dt) + + +def make_callback(state: DroneState, rng: random.Random): + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + action = args.action + if action == DroneAction.AVOID_OTHERS: + return poll_job("avoid_others", state, collision_avoidance_task, args.dt) + if action == DroneAction.TAKE_OFF: + return poll_job("takeoff", state, takeoff_task, args.dt) + if action == DroneAction.LAND: + return poll_job("land", state, landing_task, args.dt) + if action == DroneAction.CHECK_BATTERY: + # Fast sync action: 80% chance OK, 20% chance low -> Select fallback fires. + ok = rng.random() < 0.8 + print(f"check battery: {'OK' if ok else 'LOW'}") + return (bt.Status.Success if ok else bt.Status.Failure, args.dt) + if isinstance(action, FlyToPoint): + async def spawn(q: asyncio.Queue[bt.Status], p: FlyToPoint = action) -> None: + await fly_to_point_task(p, q) + return poll_job("fly_to_point", state, spawn, args.dt) + raise ValueError(f"unknown action: {action!r}") + + return cb + + +def build_tree() -> bt.Behavior: + fly_if_healthy = bt.Sequence([ + bt.Action(DroneAction.CHECK_BATTERY), + bt.Action(FlyToPoint(10.0, 10.0, 10.0)), + ]) + fly_to_dock = bt.Action(FlyToPoint(0.0, 0.0, 0.0)) + mission_with_fallback = bt.Select([fly_if_healthy, fly_to_dock]) + return bt.While( + bt.Action(DroneAction.AVOID_OTHERS), + [ + bt.Action(DroneAction.TAKE_OFF), + mission_with_fallback, + bt.Action(DroneAction.LAND), + ], + ) + + +async def main() -> None: + tree_bt = bt.BT(build_tree(), None) + state = DroneState() + rng = random.Random(0) + callback = make_callback(state, rng) + print("=== drone tree ===") + print(tree_bt.graphviz()) + print("=== mission start ===") + + start = time.perf_counter() + last = start + while True: + await asyncio.sleep(0.5) + now = time.perf_counter() + dt = now - last + last = now + result = tree_bt.tick(dt, callback) + if result is None: + break + if now - start > MAX_WALL_SECONDS: + print(f"=== demo cap reached ({MAX_WALL_SECONDS}s) — exiting ===") + break + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/bonsai-py/examples/boids_console.py b/bonsai-py/examples/boids_console.py new file mode 100644 index 0000000..0d8370c --- /dev/null +++ b/bonsai-py/examples/boids_console.py @@ -0,0 +1,179 @@ +""" +Boids flocking with one Behavior shared across many agents. + +Build the `Behavior` ONCE, then construct `bt.BT(shared_tree, boid)` for +each of N agents — every BT instance has its own blackboard (a `Boid` +dataclass with `x, y, dx, dy`) but they all share the same tree +definition. Each tick runs all 5 flocking rules: + + While( + cond = WhenAll([FlyTowardsCenter, AvoidOthers]), + body = [MatchVelocity, LimitSpeed, KeepWithinBounds], + ) + +Both cond actions return `Running`, so the BT stays in the body each +tick and all 5 actions fire — updating each boid's velocity and +position. Output is text-only (no graphics window); first and last +boids are logged each tick. + +Demonstrates a shared `Behavior` across many `BT` instances, +real-time-loop `dt` integration, `WhenAll` for parallel cond updates, +and `While`-body re-execution per tick. + +Run: + python bonsai-py/examples/boids_console.py +""" +from __future__ import annotations + +import enum +import math +import random +import time +from dataclasses import dataclass +from typing import Any + +import bonsai_bt as bt + +NUM_BOIDS = 10 +WIDTH = 1280.0 +HEIGHT = 720.0 +SPEED_LIMIT = 400.0 +VISUAL_RANGE = 32.0 +MIN_DISTANCE = 16.0 +TICKS = 30 +DT_SECONDS = 0.1 + + +class Action(enum.Enum): + AVOID_OTHERS = enum.auto() + FLY_TOWARDS_CENTER = enum.auto() + MATCH_VELOCITY = enum.auto() + LIMIT_SPEED = enum.auto() + KEEP_WITHIN_BOUNDS = enum.auto() + + +@dataclass +class Boid: + x: float + y: float + dx: float + dy: float + + def distance(self, other: Boid) -> float: + return math.hypot(self.x - other.x, self.y - other.y) + + +def build_tree() -> bt.Behavior: + """One Behavior, shared across all N boid BTs (matches the Rust pattern).""" + avoid_and_fly = bt.WhenAll([ + bt.Action(Action.FLY_TOWARDS_CENTER), + bt.Action(Action.AVOID_OTHERS), + ]) + return bt.While( + avoid_and_fly, + [ + bt.Action(Action.MATCH_VELOCITY), + bt.Action(Action.LIMIT_SPEED), + bt.Action(Action.KEEP_WITHIN_BOUNDS), + ], + ) + + +def make_callback(idx: int, all_boids: list[Boid]): + """Build a callback closed over this boid's neighbors (via `all_boids`).""" + + def cb(args: Any, boid: Boid) -> tuple[bt.Status, float]: + others = [b for j, b in enumerate(all_boids) if j != idx] + match args.action: + case Action.AVOID_OTHERS: + move_x = move_y = 0.0 + for other in others: + dist = boid.distance(other) + if 0.0 < dist < MIN_DISTANCE: + move_x += boid.x - other.x + move_y += boid.y - other.y + boid.dx += move_x * 0.5 + boid.dy += move_y * 0.5 + return bt.RUNNING + case Action.FLY_TOWARDS_CENTER: + cx = cy = 0.0 + n = 0 + for other in others: + if boid.distance(other) < VISUAL_RANGE: + cx += other.x + cy += other.y + n += 1 + if n > 0: + boid.dx += (cx / n - boid.x) * 0.05 + boid.dy += (cy / n - boid.y) * 0.05 + return bt.RUNNING + case Action.MATCH_VELOCITY: + avg_dx = avg_dy = 0.0 + n = 0 + for other in others: + if boid.distance(other) < VISUAL_RANGE: + avg_dx += other.dx + avg_dy += other.dy + n += 1 + if n > 0: + boid.dx += (avg_dx / n - boid.dx) * 0.1 + boid.dy += (avg_dy / n - boid.dy) * 0.1 + return (bt.Status.Success, args.dt) + case Action.LIMIT_SPEED: + speed = math.hypot(boid.dx, boid.dy) + if speed > SPEED_LIMIT: + boid.dx = boid.dx / speed * SPEED_LIMIT + boid.dy = boid.dy / speed * SPEED_LIMIT + return (bt.Status.Success, args.dt) + case Action.KEEP_WITHIN_BOUNDS: + edge = 40.0 + turn = 16.0 + if boid.x < edge: + boid.dx += turn + if boid.x > WIDTH - edge: + boid.dx -= turn + if boid.y < edge: + boid.dy += turn + if boid.y > HEIGHT - edge: + boid.dy -= turn + return bt.RUNNING + case _: + raise ValueError(f"unknown action: {args.action!r}") + + return cb + + +def main() -> None: + rng = random.Random(0) # deterministic for reproducible console output + boids = [ + Boid( + x=rng.uniform(WIDTH / 4, 3 * WIDTH / 4), + y=rng.uniform(HEIGHT / 4, 3 * HEIGHT / 4), + dx=(rng.random() - 0.5) * SPEED_LIMIT, + dy=(rng.random() - 0.5) * SPEED_LIMIT, + ) + for _ in range(NUM_BOIDS) + ] + + shared_tree = build_tree() + bts = [bt.BT(shared_tree, boids[i]) for i in range(NUM_BOIDS)] + + print(f"Boids console demo: {NUM_BOIDS} agents sharing one Behavior tree.") + for step in range(TICKS): + for i, tree_bt in enumerate(bts): + tree_bt.tick(DT_SECONDS, make_callback(i, boids)) + boid = boids[i] + boid.x += boid.dx * DT_SECONDS + boid.y += boid.dy * DT_SECONDS + print( + f"[boid {i:2d}] step {step:2d}" + f" pos=({boid.x:7.1f}, {boid.y:7.1f})" + f" vel=({boid.dx:7.1f}, {boid.dy:7.1f})" + ) + time.sleep(DT_SECONDS / 10.0) # tiny pause so output is readable + + print(f"Done after {TICKS} ticks. Each BT instance ticked {bts[0].tick_count()} times.") + + +if __name__ == "__main__": + main() diff --git a/bonsai-py/examples/graphviz_demo.py b/bonsai-py/examples/graphviz_demo.py new file mode 100644 index 0000000..f9745fb --- /dev/null +++ b/bonsai-py/examples/graphviz_demo.py @@ -0,0 +1,68 @@ +""" +Print the graphviz DOT representation of an attack-drone behavior tree. + +Builds an attack-drone tree (circle the target, attack when in range, give +up when too far). Calls `BT.graphviz()` to emit +a DOT string. Paste the output into +to render the tree visually. + +Demonstrates `BT.graphviz()`, and composition with `While` / `Sequence` / `WhenAny` / +`Wait` / `WaitForever` / `Action`. + +Run: + python bonsai-py/examples/graphviz_demo.py +""" +from __future__ import annotations + +from dataclasses import dataclass + +import bonsai_bt as bt + +# Payload-less actions are plain strings; payload variants are frozen +# dataclasses (hashable, immutable, work as bt.Action(...) values). +CIRCLING = "Circling" +FLY_TOWARD_PLAYER = "FlyTowardPlayer" + + +@dataclass(frozen=True) +class PlayerWithinDistance: + distance: float + + +@dataclass(frozen=True) +class PlayerFarAwayFromTarget: + distance: float + + +@dataclass(frozen=True) +class AttackPlayer: + damage: float + + +def build_tree() -> bt.Behavior: + circling = bt.Action(CIRCLING) + circle_until_player_within_distance = bt.Sequence([ + bt.While(bt.Wait(5.0), [circling]), + bt.While(bt.Action(PlayerWithinDistance(50.0)), [circling]), + ]) + give_up_or_attack = bt.WhenAny([ + bt.Action(PlayerFarAwayFromTarget(100.0)), + bt.Sequence([ + bt.Action(PlayerWithinDistance(10.0)), + bt.Action(AttackPlayer(0.1)), + ]), + ]) + attack_attempt = bt.While(give_up_or_attack, [bt.Action(FLY_TOWARD_PLAYER)]) + return bt.While( + bt.WaitForever(), + [circle_until_player_within_distance, attack_attempt], + ) + + +def main() -> None: + tree_bt = bt.BT(build_tree(), {}) + print(tree_bt.graphviz()) + + +if __name__ == "__main__": + main() diff --git a/bonsai-py/examples/race_timeout.py b/bonsai-py/examples/race_timeout.py new file mode 100644 index 0000000..f1682c0 --- /dev/null +++ b/bonsai-py/examples/race_timeout.py @@ -0,0 +1,112 @@ +""" +Race a simulated long-running job against a hard timeout. + +A `Race` runs two arms in parallel; whichever finishes first wins: + + Race([ + DO_WORK, # random 200-1200 ms + Sequence([Wait(0.6), ON_TIMEOUT]), # 600 ms hard deadline + ]) + +`DO_WORK` runs on a background `threading.Thread` and reports its status +through a `queue.Queue`. The BT polls non-blockingly with `get_nowait()`; +an empty queue means "still running, try again next tick." + +Demonstrates `Race` for timeouts, an asyncio main loop alongside a +threading worker (the `BT` itself stays on the main thread — it is +unsendable), the `bt.RUNNING` shorthand, and `time.perf_counter()` for +monotonic dt. + +Run: + python bonsai-py/examples/race_timeout.py +""" +from __future__ import annotations + +import asyncio +import enum +import queue +import random +import threading +import time +from dataclasses import dataclass, field +from typing import Any + +import bonsai_bt as bt + + +class MissionAction(enum.Enum): + DO_WORK = enum.auto() + ON_TIMEOUT = enum.auto() + + +@dataclass +class MissionState: + work: queue.Queue[bt.Status] | None = field(default=None) + + +def do_work_task(q: queue.Queue[bt.Status]) -> None: + """Background thread: sleeps in 100ms steps, sends Status.Running each step, + then Status.Success after a random 200..=1200ms total.""" + work_ms = random.randint(200, 1200) + print(f"[do_work] started; planned duration {work_ms} ms") + elapsed = 0 + step_ms = 100 + while elapsed < work_ms: + q.put(bt.Status.Running) + time.sleep(step_ms / 1000.0) + elapsed += step_ms + print(f"[do_work] finished after {elapsed} ms") + q.put(bt.Status.Success) + + +def make_callback(state: MissionState): + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + match args.action: + case MissionAction.DO_WORK: + if state.work is None: + state.work = queue.Queue() + threading.Thread( + target=do_work_task, args=(state.work,), daemon=True + ).start() + try: + status = state.work.get_nowait() + except queue.Empty: + return bt.RUNNING + if status == bt.Status.Running: + return bt.RUNNING + state.work = None + return (status, args.dt) + case MissionAction.ON_TIMEOUT: + print("do_work timed out!") + return (bt.Status.Failure, args.dt) + case _: + raise ValueError(f"unknown action: {args.action!r}") + + return cb + + +async def main() -> None: + timeout_s = 0.6 + tree = bt.Sequence([ + bt.Race([ + bt.Action(MissionAction.DO_WORK), + bt.Sequence([bt.Wait(timeout_s), bt.Action(MissionAction.ON_TIMEOUT)]), + ]), + ]) + tree_bt = bt.BT(tree, None) + state = MissionState() + callback = make_callback(state) + last = time.perf_counter() + + while True: + await asyncio.sleep(0.05) + now = time.perf_counter() + dt = now - last + last = now + result = tree_bt.tick(dt, callback) + if result is None: + break + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/bonsai-py/examples/simple_npc_ai.py b/bonsai-py/examples/simple_npc_ai.py new file mode 100644 index 0000000..1c858e4 --- /dev/null +++ b/bonsai-py/examples/simple_npc_ai.py @@ -0,0 +1,139 @@ +""" +Console NPC behavior demo. + +An NPC runs and shoots while it has action points. When exhausted, it rests +until fully recovered, then dies. Built from a nested `WhileAll`: + + WhileAll(IsDead, [ + WhileAll(HasActionPointsLeft, [Run, Shoot]), + Rest, + Die, + ]) + +Demonstrates `WhileAll` looping, blackboard mutation through a `@dataclass`, +and an enum action dispatched via Python's structural-match callback. + +Run: + python bonsai-py/examples/simple_npc_ai.py +""" +from __future__ import annotations + +import enum +from dataclasses import dataclass +from typing import Any + +import bonsai_bt as bt + + +class EnemyNPC(enum.Enum): + RUN = enum.auto() + SHOOT = enum.auto() + HAS_ACTION_POINTS_LEFT = enum.auto() + REST = enum.auto() + DIE = enum.auto() + IS_DEAD = enum.auto() + + +@dataclass +class BlackBoard: + times_shot: int = 0 + + +@dataclass +class NPCState: + action_points: int + max_action_points: int + alive: bool + + def consume_action_point(self) -> None: + self.action_points = max(0, self.action_points - 1) + + def rest(self) -> None: + self.action_points = min(self.action_points + 1, self.max_action_points) + print(f"Rested for a while... Action points: {self.action_points}") + + def die(self) -> None: + print("NPC died...") + self.alive = False + + def is_alive(self) -> bool: + print("NPC is alive..." if self.alive else "NPC is dead...") + return self.alive + + def fully_rested(self) -> bool: + return self.action_points == self.max_action_points + + def perform_action(self, action: str) -> None: + if self.action_points > 0: + self.consume_action_point() + print(f"Performing action: {action}. Action points: {self.action_points}") + else: + print(f"Cannot perform action: {action}. Not enough action points.") + + +def make_callback(state: NPCState): + def cb(args: Any, blackboard: BlackBoard) -> tuple[bt.Status, float]: + match args.action: + case EnemyNPC.RUN: + state.perform_action("run") + return (bt.Status.Success, 0.0) + case EnemyNPC.HAS_ACTION_POINTS_LEFT: + if state.action_points == 0: + print("NPC does not have action points left...") + return (bt.Status.Success, 0.0) + print(f"NPC has action points: {state.action_points}") + return (bt.Status.Running, 0.0) + case EnemyNPC.SHOOT: + state.perform_action("shoot") + blackboard.times_shot += 1 + return (bt.Status.Success, 0.0) + case EnemyNPC.REST: + if state.fully_rested(): + return (bt.Status.Success, 0.0) + state.rest() + return (bt.Status.Running, 0.0) + case EnemyNPC.DIE: + state.die() + return (bt.Status.Success, 0.0) + case EnemyNPC.IS_DEAD: + if state.is_alive(): + return (bt.Status.Running, 0.0) + return (bt.Status.Success, 0.0) + case _: + raise ValueError(f"unknown action: {args.action!r}") + + return cb + + +def build_tree() -> bt.Behavior: + run_and_shoot = bt.WhileAll( + bt.Action(EnemyNPC.HAS_ACTION_POINTS_LEFT), + [bt.Action(EnemyNPC.RUN), bt.Action(EnemyNPC.SHOOT)], + ) + return bt.WhileAll( + bt.Action(EnemyNPC.IS_DEAD), + [run_and_shoot, bt.Action(EnemyNPC.REST), bt.Action(EnemyNPC.DIE)], + ) + + +def main() -> None: + max_actions = 3 + blackboard = BlackBoard() + state = NPCState(action_points=max_actions, max_action_points=max_actions, alive=True) + tree_bt = bt.BT(build_tree(), blackboard) + callback = make_callback(state) + + while True: + print("reached main loop...") + result = tree_bt.tick(0.0, callback) + if result is None: + break + status, _ = result + if status != bt.Status.Running: + break + + print(f"NPC shot {blackboard.times_shot} times during the simulation.") + + +if __name__ == "__main__": + main() diff --git a/bonsai-py/examples/threaded_drone.py b/bonsai-py/examples/threaded_drone.py new file mode 100644 index 0000000..21040a8 --- /dev/null +++ b/bonsai-py/examples/threaded_drone.py @@ -0,0 +1,210 @@ +""" +Threaded drone mission with multi-job orchestration. + +A drone takes off, checks battery, flies to a mission point (or falls +back to landing at the dock if the battery is low), then lands — +repeating while a background collision-avoidance task runs: + + While( + cond = AvoidOthers, + body = [ + TakeOff, + Select([ + Sequence([CheckBattery, FlyToPoint(10, 10, 10)]), + FlyToPoint(0, 0, 0), # dock fallback + ]), + Land, + ], + ) + +Each long-running action runs on its own `threading.Thread` and reports +status through a per-job `queue.Queue`. The BT polls non-blockingly with +`get_nowait()`. The script prints the tree's `graphviz()` at the start, +then runs the mission until a wall-clock cap. + +**Pick this variant when actions block on hardware or sync IO** — +blocking syscalls, blocking C extensions, vendor SDKs without an async +API, `subprocess.run`, file IO without `aiofiles`, etc. The OS scheduler +runs the threads in parallel, so a blocking call in one job doesn't +freeze the others. + +See `async_drone.py` for the asyncio-native variant of the same mission +(pick that when actions are awaitable — `aiohttp`, async DB drivers, +websockets, async stream readers). + +Demonstrates `Select` for prioritized fallback, multi-job orchestration +via per-job channels, threading + queue.Queue, and `BT.graphviz()` for +static tree visualization at startup. + +Run: + python bonsai-py/examples/threaded_drone.py +""" +from __future__ import annotations + +import enum +import queue +import random +import threading +import time +from dataclasses import dataclass +from typing import Any, Optional + +import bonsai_bt as bt + +MAX_WALL_SECONDS = 8.0 # Demo cap; the BT itself would loop forever. + + +class DroneAction(enum.Enum): + AVOID_OTHERS = enum.auto() + TAKE_OFF = enum.auto() + LAND = enum.auto() + CHECK_BATTERY = enum.auto() + + +@dataclass(frozen=True) +class FlyToPoint: + x: float + y: float + z: float + + +@dataclass +class DroneState: + avoid_others: Optional[queue.Queue[bt.Status]] = None + takeoff: Optional[queue.Queue[bt.Status]] = None + land: Optional[queue.Queue[bt.Status]] = None + fly_to_point: Optional[queue.Queue[bt.Status]] = None + + +# ---- Background "jobs": one thread per long-running action ----------------- +# Each pushes Status.Running every step, then Status.Success when done. + +def collision_avoidance_task(q: queue.Queue[bt.Status]) -> None: + print("collision avoidance task started") + while True: + try: + q.put(bt.Status.Running, timeout=1.0) + except queue.Full: + return + time.sleep(0.1) + + +def takeoff_task(q: queue.Queue[bt.Status]) -> None: + print("takeoff task started") + for i in range(3): + q.put(bt.Status.Running) + time.sleep(0.3) + print(f"takeoff task running for {(i + 1) * 300} ms") + print("takeoff task finished") + q.put(bt.Status.Success) + + +def landing_task(q: queue.Queue[bt.Status]) -> None: + print("landing task started") + for i in range(3): + q.put(bt.Status.Running) + time.sleep(0.3) + print(f"landing task running for {(i + 1) * 300} ms") + print("landing task finished") + q.put(bt.Status.Success) + + +def fly_to_point_task(point: FlyToPoint, q: queue.Queue[bt.Status]) -> None: + print(f"flying task started: target=({point.x}, {point.y}, {point.z})") + for i in range(3): + q.put(bt.Status.Running) + time.sleep(0.5) + print(f"flying task running for {(i + 1) * 500} ms") + print("flying task finished") + q.put(bt.Status.Success) + + +# ---- Polling helpers ------------------------------------------------------- + +def poll_job( + q_attr: str, + state: DroneState, + spawn: callable, # type: ignore[type-arg] + dt: float, +) -> tuple[bt.Status, float]: + """Generic 'spawn on first call, poll thereafter' pattern for any threaded job.""" + q = getattr(state, q_attr) + if q is None: + q = queue.Queue() + setattr(state, q_attr, q) + threading.Thread(target=spawn, args=(q,), daemon=True).start() + try: + status = q.get_nowait() + except queue.Empty: + return bt.RUNNING + if status == bt.Status.Running: + return bt.RUNNING + setattr(state, q_attr, None) + return (status, dt) + + +def make_callback(state: DroneState, rng: random.Random): + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + action = args.action + if action == DroneAction.AVOID_OTHERS: + return poll_job("avoid_others", state, collision_avoidance_task, args.dt) + if action == DroneAction.TAKE_OFF: + return poll_job("takeoff", state, takeoff_task, args.dt) + if action == DroneAction.LAND: + return poll_job("land", state, landing_task, args.dt) + if action == DroneAction.CHECK_BATTERY: + # Fast sync action: 80% chance OK, 20% chance low -> Select fallback fires. + ok = rng.random() < 0.8 + print(f"check battery: {'OK' if ok else 'LOW'}") + return (bt.Status.Success if ok else bt.Status.Failure, args.dt) + if isinstance(action, FlyToPoint): + spawn = lambda q, p=action: fly_to_point_task(p, q) + return poll_job("fly_to_point", state, spawn, args.dt) + raise ValueError(f"unknown action: {action!r}") + + return cb + + +def build_tree() -> bt.Behavior: + fly_if_healthy = bt.Sequence([ + bt.Action(DroneAction.CHECK_BATTERY), + bt.Action(FlyToPoint(10.0, 10.0, 10.0)), + ]) + fly_to_dock = bt.Action(FlyToPoint(0.0, 0.0, 0.0)) + mission_with_fallback = bt.Select([fly_if_healthy, fly_to_dock]) + return bt.While( + bt.Action(DroneAction.AVOID_OTHERS), + [ + bt.Action(DroneAction.TAKE_OFF), + mission_with_fallback, + bt.Action(DroneAction.LAND), + ], + ) + + +def main() -> None: + tree_bt = bt.BT(build_tree(), None) + state = DroneState() + rng = random.Random(0) + callback = make_callback(state, rng) + print("=== drone tree ===") + print(tree_bt.graphviz()) + print("=== mission start ===") + + start = time.perf_counter() + last = start + while True: + time.sleep(0.5) + now = time.perf_counter() + dt = now - last + last = now + result = tree_bt.tick(dt, callback) + if result is None: + break + if now - start > MAX_WALL_SECONDS: + print(f"=== demo cap reached ({MAX_WALL_SECONDS}s) — exiting ===") + break + + +if __name__ == "__main__": + main() diff --git a/bonsai-py/examples/visualizer_smoke.py b/bonsai-py/examples/visualizer_smoke.py new file mode 100644 index 0000000..be1de5a --- /dev/null +++ b/bonsai-py/examples/visualizer_smoke.py @@ -0,0 +1,118 @@ +""" +End-to-end demo for the WebSocket visualizer. + +Drives a deliberately rich 27-node tree (covering 12 of the 14 Behavior +factories), attaches the visualizer via `BT.with_telemetry(8910)`, and +re-runs the tree every ~400 ms wall tick. Each leaf's status follows a +5-step rotation with a per-action phase offset, so a varied mix of +green / yellow / red is visible at any moment. After each complete run, +`reset_bt()` rewinds the cursor; `tick_count` and the telemetry +connection survive, so the browser sees a continuous TickTrace stream +with monotonic `tick_id`. + +Demonstrates `with_telemetry`, `reset_bt`, every major factory in one +tree, and a deterministic-cycle callback contract. + +Run: + python bonsai-py/examples/visualizer_smoke.py + +Then open in a browser. + 1. Tree renders within ~1 s; status bar reads `connected` and `27 nodes`. + 2. Every ~400 ms leaf colors shift across all subtrees. + 3. `Ctrl-C` and restart -> browser reconnects within <=1 s. + +Port 8910 must be free; if it is busy, the script raises OSError. +""" +from __future__ import annotations + +import time +from typing import Any + +from bonsai_bt import * # noqa: F401,F403 + + +def build_tree() -> Behavior: + return Sequence([ + If( + Action("low_hp"), + AlwaysSucceed(Action("flee")), + Action("regroup"), + ), + Select([ + Sequence([ + Action("acquire_target"), + WhenAll([Action("aim"), Action("track")]), + ]), + Race([Action("dodge"), Wait(2.0)]), + Invert(Action("enemy_visible")), + ]), + While(Action("has_ammo"), [Action("fire"), Wait(0.3)]), + After([Action("cooldown"), Action("ready_signal")]), + WhenAny([Action("victory_check"), Action("retreat_signal")]), + ]) + + +# Five-step status cycle visible across all three colors. Each action has a +# unique phase offset so the same wall tick produces a varied mix of statuses +# across the tree (and yellow-Running shows up). +CYCLE = ( + Status.Success, + Status.Running, + Status.Failure, + Status.Success, + Status.Running, +) + +PHASE_OFFSET = { + "low_hp": 0, + "flee": 1, + "regroup": 2, + "acquire_target": 3, + "aim": 4, + "track": 0, + "dodge": 1, + "enemy_visible": 2, + "has_ammo": 3, + "fire": 4, + "cooldown": 0, + "ready_signal": 1, + "victory_check": 2, + "retreat_signal": 3, +} + +# Four leaves whose Failure would short-circuit the root Sequence before +# downstream subtrees ever render. Substitute Running for Failure on these so +# the chain reaches the bottom branches. They still show Success and Running. +KEEP_ALIVE = {"regroup", "has_ammo", "cooldown", "ready_signal"} + + +def make_callback(tick_n_ref: list[int]): + def cb(args: Any, _bb: Any) -> tuple[Status, float]: + phase = PHASE_OFFSET.get(args.action, 0) + idx = (tick_n_ref[0] + phase) % len(CYCLE) + status = CYCLE[idx] + if args.action in KEEP_ALIVE and status == Status.Failure: + status = Status.Running + return (status, 0.0) + + return cb + + +def main() -> None: + tree_bt = BT(build_tree(), None).with_telemetry(8910) + tick_n_ref = [0] + callback = make_callback(tick_n_ref) + print("bonsai-bt visualizer: open http://127.0.0.1:8910/") + + while True: + tick_n_ref[0] += 1 + result = tree_bt.tick(1.0, callback) + if result is not None: + status, _ = result + if status in (Status.Success, Status.Failure): + tree_bt.reset_bt() + time.sleep(0.4) + + +if __name__ == "__main__": + main() diff --git a/bonsai-py/pyproject.toml b/bonsai-py/pyproject.toml new file mode 100644 index 0000000..2886f74 --- /dev/null +++ b/bonsai-py/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["maturin>=1.7,<2.0"] +build-backend = "maturin" + +[project] +name = "bonsai-bt" +description = "Behavior trees in Python, powered by the bonsai-bt Rust crate." +readme = "README.md" +license = { text = "MIT" } +requires-python = ">=3.10" +authors = [ + { name = "Kristoffer Solberg Rakstad", email = "solkristoffer@gmail.com" }, + { name = "Anmol Kathail", email = "anmolkathail@gmail.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Rust", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Games/Entertainment", + "Typing :: Typed", +] +dynamic = ["version"] + +[project.urls] +Homepage = "https://github.com/sollimann/bonsai" +Repository = "https://github.com/sollimann/bonsai" +Issues = "https://github.com/sollimann/bonsai/issues" + +[tool.maturin] +module-name = "bonsai_bt" +python-source = "python" +features = ["pyo3/extension-module"] +strip = true +# Exclude local bytecode caches from the wheel. +exclude = ["python/**/__pycache__/**", "python/**/*.pyc"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +markers = [ + "perf: performance budget tests (always-on; generous bounds)", + "bench: heavy microbenchmarks (opt-in via `pytest -m bench`)", +] +timeout = 10 diff --git a/bonsai-py/python/bonsai_bt/__init__.py b/bonsai-py/python/bonsai_bt/__init__.py new file mode 100644 index 0000000..f7b06ed --- /dev/null +++ b/bonsai-py/python/bonsai_bt/__init__.py @@ -0,0 +1,22 @@ +"""bonsai-bt - behavior trees in Python, powered by the Rust bonsai-bt crate.""" + +from importlib.metadata import PackageNotFoundError, version as _version + +try: + __version__ = _version("bonsai-bt") +except PackageNotFoundError: # editable install before metadata is in place + __version__ = "0.0.0+unknown" + +from .bonsai_bt import * # noqa: F401,F403 (re-export the compiled module) + +__all__ = [ + # types + "Status", "ActionArgs", "Behavior", "BT", + # factories (leaves, decorators, composites, control flow) + "Action", "Wait", "WaitForever", + "Invert", "AlwaysSucceed", + "Sequence", "Select", "WhenAll", "WhenAny", "After", "Race", + "If", "While", "WhileAll", + # constants + "RUNNING", +] diff --git a/bonsai-py/python/bonsai_bt/__init__.pyi b/bonsai-py/python/bonsai_bt/__init__.pyi new file mode 100644 index 0000000..8e0bc77 --- /dev/null +++ b/bonsai-py/python/bonsai_bt/__init__.pyi @@ -0,0 +1,128 @@ +# This file is automatically generated by pyo3_stub_gen +# ruff: noqa: E501, F401, F403, F405 + +import builtins +import enum +import typing +__all__ = [ + "Action", + "ActionArgs", + "After", + "AlwaysSucceed", + "BT", + "Behavior", + "If", + "Invert", + "Race", + "Select", + "Sequence", + "Status", + "Wait", + "WaitForever", + "WhenAll", + "WhenAny", + "While", + "WhileAll", + "RUNNING", +] + +@typing.final +class ActionArgs: + r""" + Action callback arguments. + + Constructed by the tick bridge and passed to the user's callback. + The Rust `ActionArgs::event` field is intentionally not exposed — + Python users only see `dt` and `action`. + """ + @property + def dt(self) -> builtins.float: + r""" + Remaining delta time in seconds. + """ + @property + def action(self) -> typing.Any: + r""" + The user-supplied action value (whatever was passed to `bt.Action(...)`). + """ + def __new__(cls, dt: builtins.float, action: typing.Any) -> ActionArgs: ... + def __repr__(self) -> builtins.str: ... + +@typing.final +class BT: + r""" + A behavior-tree executor wrapping `bonsai_bt::BT`. + + Construct from a tree and a blackboard, then drive with `.tick(dt, callback)`. + The callback receives `(args, blackboard)` and must return `(Status, float)`. + """ + def __new__(cls, behavior: Behavior, blackboard: typing.Any) -> BT: ... + def tick(self, dt: builtins.float, callback: typing.Any) -> typing.Optional[tuple[Status, builtins.float]]: ... + def blackboard(self) -> typing.Any: ... + def reset_bt(self) -> None: ... + def tick_count(self) -> builtins.int: ... + def is_finished(self) -> builtins.bool: ... + def graphviz(self) -> builtins.str: ... + def with_telemetry(self, port: builtins.int, host: builtins.str = '127.0.0.1') -> BT: ... + +@typing.final +class Behavior: + r""" + An opaque behavior-tree node. + + Construct via the factory functions (`Sequence`, `Action`, `Wait`, ...) + at the module level. Subtrees are reusable - the same `Behavior` + can appear as a child of multiple parents. + """ + def __repr__(self) -> builtins.str: ... + +@typing.final +class Status(enum.Enum): + r""" + Behavior-tree node result. + + Mirrors `bonsai_bt::Status`. Comparable to `int` + (`Status.Success == 0`, `Failure == 1`, `Running == 2`) and usable + as a `dict` key or `set` member. + """ + Success = ... + Failure = ... + Running = ... + + def __reduce__(self) -> tuple[typing.Any, tuple[typing.Any, builtins.str]]: + r""" + Pickle support: name the singleton by class + variant name, since + PyO3 simple enums refuse construction by call (`Status(0)` raises). + """ + +def Action(action: typing.Any) -> Behavior: ... + +def After(children: typing.Sequence[Behavior]) -> Behavior: ... + +def AlwaysSucceed(child: Behavior) -> Behavior: ... + +def If(cond: Behavior, on_success: Behavior, on_failure: Behavior) -> Behavior: ... + +def Invert(child: Behavior) -> Behavior: ... + +def Race(children: typing.Sequence[Behavior]) -> Behavior: ... + +def Select(children: typing.Sequence[Behavior]) -> Behavior: ... + +def Sequence(children: typing.Sequence[Behavior]) -> Behavior: ... + +def Wait(seconds: builtins.float) -> Behavior: ... + +def WaitForever() -> Behavior: ... + +def WhenAll(children: typing.Sequence[Behavior]) -> Behavior: ... + +def WhenAny(children: typing.Sequence[Behavior]) -> Behavior: ... + +def While(cond: Behavior, body: typing.Sequence[Behavior]) -> Behavior: ... + +def WhileAll(cond: Behavior, body: typing.Sequence[Behavior]) -> Behavior: ... + + +RUNNING: typing.Final[tuple[Status, builtins.float]] +r"""Convenience constant: ``(Status.Running, 0.0)`` - return from a tick callback to keep the action running.""" diff --git a/bonsai-py/python/bonsai_bt/py.typed b/bonsai-py/python/bonsai_bt/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/bonsai-py/scripts/README.md b/bonsai-py/scripts/README.md new file mode 100644 index 0000000..fde355a --- /dev/null +++ b/bonsai-py/scripts/README.md @@ -0,0 +1,22 @@ +# bonsai-py scripts + +Developer scripts for working on the `bonsai_bt` extension module. + +## Scripts + +| Script | Purpose | +|---|---| +| [regen-stubs.sh](regen-stubs.sh) | Regenerates `python/bonsai_bt/__init__.pyi` from the `#[gen_stub_*]` annotations on the Rust side. Run after editing any annotated `#[pyclass]` / `#[pyfunction]` / `#[pymethods]`. Also runs automatically via the `regen-stubs` pre-commit hook and is enforced in CI. | + +## Prerequisites + +A Python venv with the `bonsai_bt` extension built in. See [../README.md](../README.md#installation-dev) for the one-time setup (`python -m venv .venv`, activate, `pip install maturin`, `maturin develop --release`). + +## Running + +From the repository root, with the venv activated: + +```bash +# Regenerate stubs after editing Rust annotations +bash bonsai-py/scripts/regen-stubs.sh +``` diff --git a/bonsai-py/scripts/regen-stubs.sh b/bonsai-py/scripts/regen-stubs.sh new file mode 100644 index 0000000..019efaf --- /dev/null +++ b/bonsai-py/scripts/regen-stubs.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Regenerate the type stub and apply manual touch-ups (RUNNING constant). +# Run after editing any #[gen_stub_*] annotation. +# +# Usage: ./bonsai-py/scripts/regen-stubs.sh +set -euo pipefail + +cd "$(dirname "$0")/.." # cd into bonsai-py/ +cargo run --quiet --bin stub_gen -p bonsai-py + +# pyo3-stub-gen 0.22 reads pyproject.toml and writes the stub to +# python/bonsai_bt/__init__.pyi automatically. +STUB=python/bonsai_bt/__init__.pyi +if [ ! -f "$STUB" ]; then + echo "ERROR: $STUB was not generated. Did the binary fail silently?" >&2 + exit 1 +fi + +# pyo3-stub-gen doesn't introspect m.add() module-level constants, so +# the RUNNING declaration must be appended manually (idempotent). +if ! grep -q "^RUNNING: " "$STUB"; then + { + echo "" + echo "RUNNING: typing.Final[tuple[Status, builtins.float]]" + echo 'r"""Convenience constant: ``(Status.Running, 0.0)`` - return from a tick callback to keep the action running."""' + } >> "$STUB" +fi + +# Add RUNNING to __all__ if not already present. +if ! grep -q '"RUNNING"' "$STUB"; then + # Insert before the closing ']' of the __all__ list. + sed -i '/^__all__ = \[/,/^]/{/^]/i\ "RUNNING", +}' "$STUB" +fi + +# Strip trailing whitespace defensively. +sed -i 's/[[:space:]]*$//' "$STUB" +echo "Regenerated $STUB" diff --git a/bonsai-py/src/action_args.rs b/bonsai-py/src/action_args.rs new file mode 100644 index 0000000..63c9619 --- /dev/null +++ b/bonsai-py/src/action_args.rs @@ -0,0 +1,48 @@ +use bonsai_bt::{ActionArgs, Event}; +use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +use crate::behavior::PyAction; + +/// Action callback arguments. +/// +/// Constructed by the tick bridge and passed to the user's callback. +/// The Rust `ActionArgs::event` field is intentionally not exposed — +/// Python users only see `dt` and `action`. +#[gen_stub_pyclass] +#[pyclass(frozen, module = "bonsai_bt", name = "ActionArgs")] +pub struct PyActionArgs { + /// Remaining delta time in seconds. + #[pyo3(get)] + pub dt: f64, + /// The user-supplied action value (whatever was passed to `bt.Action(...)`). + #[pyo3(get)] + pub action: Py, +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyActionArgs { + #[new] + fn py_new(dt: f64, action: Py) -> Self { + Self { dt, action } + } + + fn __repr__(&self, py: Python<'_>) -> PyResult { + let action_repr = self.action.bind(py).repr()?.to_string(); + Ok(format!("ActionArgs(dt={}, action={})", self.dt, action_repr)) + } +} + +impl PyActionArgs { + /// Build a `PyActionArgs` from the Rust `ActionArgs` that the tick + /// callback receives. Hot path — one `clone_ref` plus an `f64` copy. + pub(crate) fn from_rust(args: &ActionArgs, py: Python<'_>) -> Self { + Self { + // `args.dt` is `bonsai_bt::Float` (f32 or f64 per feature). Cast to + // f64 at the Python boundary — Python's `float` is always f64. + dt: args.dt as f64, + action: args.action.0.clone_ref(py), + } + } +} diff --git a/bonsai-py/src/behavior.rs b/bonsai-py/src/behavior.rs new file mode 100644 index 0000000..bf38740 --- /dev/null +++ b/bonsai-py/src/behavior.rs @@ -0,0 +1,204 @@ +use bonsai_bt::Behavior; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods}; + +/// Wrapper around `Py` that satisfies `bonsai_bt::BT`'s +/// `A: Clone + Debug` bounds. +/// +/// `Py` itself is not `Clone` in PyO3 0.28 (the trait was removed +/// because `.clone()` cannot statically prove the GIL is held). We satisfy +/// the bound by acquiring the GIL inside the `Clone` impl via +/// `Python::attach` and forwarding to `clone_ref(py)`. Inside a +/// `#[pymethods]` context the GIL is already held, so re-entry is cheap. +pub(crate) struct PyAction(pub(crate) Py); + +impl Clone for PyAction { + fn clone(&self) -> Self { + Python::attach(|py| PyAction(self.0.clone_ref(py))) + } +} + +impl std::fmt::Debug for PyAction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Show the Python repr in tree-definition output (used by the + // telemetry visualizer). Fall back to a placeholder if repr fails. + Python::attach(|py| match self.0.bind(py).repr() { + Ok(s) => write!(f, "{s}"), + Err(_) => write!(f, ""), + }) + } +} + +/// An opaque behavior-tree node. +/// +/// Construct via the factory functions (`Sequence`, `Action`, `Wait`, ...) +/// at the module level. Subtrees are reusable - the same `Behavior` +/// can appear as a child of multiple parents. +#[gen_stub_pyclass] +#[pyclass(unsendable, frozen, module = "bonsai_bt", name = "Behavior")] +pub struct PyBehavior { + pub(crate) inner: Behavior, +} + +impl PyBehavior { + fn wrap(inner: Behavior) -> Self { + Self { inner } + } +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyBehavior { + fn __repr__(&self) -> String { + match &self.inner { + Behavior::Wait(t) => format!("Wait({t})"), + Behavior::WaitForever => "WaitForever".to_string(), + Behavior::Action(_) => "Action(...)".to_string(), + Behavior::Invert(_) => "Invert(...)".to_string(), + Behavior::AlwaysSucceed(_) => "AlwaysSucceed(...)".to_string(), + Behavior::Select(v) => format!("Select({})", v.len()), + Behavior::If(_, _, _) => "If(...)".to_string(), + Behavior::Sequence(v) => format!("Sequence({})", v.len()), + Behavior::While(_, body) => format!("While({})", body.len()), + Behavior::WhileAll(_, body) => format!("WhileAll({})", body.len()), + Behavior::WhenAll(v) => format!("WhenAll({})", v.len()), + Behavior::WhenAny(v) => format!("WhenAny({})", v.len()), + Behavior::After(v) => format!("After({})", v.len()), + Behavior::Race(v) => format!("Race({})", v.len()), + } + } +} + +fn collect_children(children: Vec>) -> Vec> { + children.iter().map(|c| c.inner.clone()).collect() +} + +// ----- Leaves -------------------------------------------------------------- + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "Action")] +pub fn action_fn(action: Py) -> PyBehavior { + PyBehavior::wrap(Behavior::Action(PyAction(action))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "Wait")] +pub fn wait_fn(seconds: f64) -> PyResult { + if seconds.is_nan() { + return Err(PyValueError::new_err("Wait: seconds must not be NaN")); + } + Ok(PyBehavior::wrap(Behavior::Wait(seconds as bonsai_bt::Float))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "WaitForever")] +pub fn wait_forever_fn() -> PyBehavior { + PyBehavior::wrap(Behavior::WaitForever) +} + +// ----- Decorators ---------------------------------------------------------- + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "Invert")] +pub fn invert_fn(child: PyRef<'_, PyBehavior>) -> PyBehavior { + PyBehavior::wrap(Behavior::Invert(Box::new(child.inner.clone()))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "AlwaysSucceed")] +pub fn always_succeed_fn(child: PyRef<'_, PyBehavior>) -> PyBehavior { + PyBehavior::wrap(Behavior::AlwaysSucceed(Box::new(child.inner.clone()))) +} + +// ----- Composites (Vec) ----------------------------------------- + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "Sequence")] +pub fn sequence_fn(children: Vec>) -> PyBehavior { + PyBehavior::wrap(Behavior::Sequence(collect_children(children))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "Select")] +pub fn select_fn(children: Vec>) -> PyBehavior { + PyBehavior::wrap(Behavior::Select(collect_children(children))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "WhenAll")] +pub fn when_all_fn(children: Vec>) -> PyBehavior { + PyBehavior::wrap(Behavior::WhenAll(collect_children(children))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "WhenAny")] +pub fn when_any_fn(children: Vec>) -> PyBehavior { + PyBehavior::wrap(Behavior::WhenAny(collect_children(children))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "After")] +pub fn after_fn(children: Vec>) -> PyBehavior { + PyBehavior::wrap(Behavior::After(collect_children(children))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "Race")] +pub fn race_fn(children: Vec>) -> PyBehavior { + PyBehavior::wrap(Behavior::Race(collect_children(children))) +} + +// ----- Control flow -------------------------------------------------------- + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "If")] +pub fn if_fn( + cond: PyRef<'_, PyBehavior>, + on_success: PyRef<'_, PyBehavior>, + on_failure: PyRef<'_, PyBehavior>, +) -> PyBehavior { + PyBehavior::wrap(Behavior::If( + Box::new(cond.inner.clone()), + Box::new(on_success.inner.clone()), + Box::new(on_failure.inner.clone()), + )) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "While")] +pub fn while_fn(cond: PyRef<'_, PyBehavior>, body: Vec>) -> PyResult { + if body.is_empty() { + return Err(PyValueError::new_err("While: body must not be empty")); + } + Ok(PyBehavior::wrap(Behavior::While( + Box::new(cond.inner.clone()), + collect_children(body), + ))) +} + +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(name = "WhileAll")] +pub fn while_all_fn(cond: PyRef<'_, PyBehavior>, body: Vec>) -> PyResult { + if body.is_empty() { + return Err(PyValueError::new_err("WhileAll: body must not be empty")); + } + Ok(PyBehavior::wrap(Behavior::WhileAll( + Box::new(cond.inner.clone()), + collect_children(body), + ))) +} diff --git a/bonsai-py/src/bin/stub_gen.rs b/bonsai-py/src/bin/stub_gen.rs new file mode 100644 index 0000000..96f2274 --- /dev/null +++ b/bonsai-py/src/bin/stub_gen.rs @@ -0,0 +1,15 @@ +//! Stub-generator binary. Builds the `.pyi` from the `#[gen_stub_*]` +//! annotations sprinkled across the binding crate. +//! +//! Run with `cargo run --bin stub_gen -p bonsai-py`. pyo3-stub-gen reads +//! `pyproject.toml` to determine the package layout and writes the stub +//! to `python/bonsai_bt/__init__.pyi`. The companion `scripts/regen-stubs.sh` +//! appends the manual `RUNNING` constant declaration afterwards. + +use pyo3_stub_gen::Result; + +fn main() -> Result<()> { + let stub = bonsai_py::stub_info()?; + stub.generate()?; + Ok(()) +} diff --git a/bonsai-py/src/bt.rs b/bonsai-py/src/bt.rs new file mode 100644 index 0000000..7baa048 --- /dev/null +++ b/bonsai-py/src/bt.rs @@ -0,0 +1,111 @@ +use bonsai_bt::{Event, Status, UpdateArgs, BT}; +use pyo3::exceptions::{PyOSError, PyRuntimeError}; +use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +use crate::action_args::PyActionArgs; +use crate::behavior::{PyAction, PyBehavior}; +use crate::status::PyStatus; + +const POISONED_MSG: &str = "BT was invalidated by a failed with_telemetry call; construct a new BT"; + +/// A behavior-tree executor wrapping `bonsai_bt::BT`. +/// +/// Construct from a tree and a blackboard, then drive with `.tick(dt, callback)`. +/// The callback receives `(args, blackboard)` and must return `(Status, float)`. +#[gen_stub_pyclass] +#[pyclass(unsendable, module = "bonsai_bt", name = "BT")] +pub struct PyBT { + inner: Option>>, +} + +impl PyBT { + fn require_inner(&self) -> PyResult<&BT>> { + self.inner.as_ref().ok_or_else(|| PyRuntimeError::new_err(POISONED_MSG)) + } + + fn require_inner_mut(&mut self) -> PyResult<&mut BT>> { + self.inner.as_mut().ok_or_else(|| PyRuntimeError::new_err(POISONED_MSG)) + } +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyBT { + #[new] + fn py_new(behavior: PyRef<'_, PyBehavior>, blackboard: Py) -> Self { + let tree = behavior.inner.clone(); + Self { + inner: Some(BT::new(tree, blackboard)), + } + } + + fn tick(&mut self, py: Python<'_>, dt: f64, callback: Py) -> PyResult> { + let inner = self.require_inner_mut()?; + // `UpdateArgs.dt` is `bonsai_bt::Float`; cast from the f64 Python input. + let event: Event = UpdateArgs { + dt: dt as bonsai_bt::Float, + } + .into(); + let mut cb_err: Option = None; + let result = inner.tick(&event, &mut |args, bb: &mut Py| { + if cb_err.is_some() { + return (Status::Failure, 0.0); + } + let py_args = PyActionArgs::from_rust(&args, py); + let bb_ref = bb.clone_ref(py); + match callback.call1(py, (py_args, bb_ref)) { + // Callback returns Python f64; cast back to `bonsai_bt::Float`. + Ok(ret) => match ret.extract::<(PyStatus, f64)>(py) { + Ok((s, remaining)) => (s.into(), remaining as bonsai_bt::Float), + Err(e) => { + cb_err = Some(e); + (Status::Failure, 0.0) + } + }, + Err(e) => { + cb_err = Some(e); + (Status::Failure, 0.0) + } + } + }); + if let Some(e) = cb_err { + return Err(e); + } + // Tick result's `Float` -> f64 for the Python return tuple. + Ok(result.map(|(s, dt)| (s.into(), dt as f64))) + } + + fn blackboard(&self, py: Python<'_>) -> PyResult> { + Ok(self.require_inner()?.blackboard().clone_ref(py)) + } + + fn reset_bt(&mut self) -> PyResult<()> { + self.require_inner_mut()?.reset_bt(); + Ok(()) + } + + fn tick_count(&self) -> PyResult { + Ok(self.require_inner()?.tick_count()) + } + + fn is_finished(&self) -> PyResult { + Ok(self.require_inner()?.is_finished()) + } + + fn graphviz(&mut self) -> PyResult { + Ok(self.require_inner_mut()?.get_graphviz()) + } + + #[pyo3(signature = (port, host = "127.0.0.1"))] + fn with_telemetry<'py>(mut slf: PyRefMut<'py, Self>, port: u16, host: &str) -> PyResult> { + let inner = slf.inner.take().ok_or_else(|| PyRuntimeError::new_err(POISONED_MSG))?; + match inner.with_telemetry_at(host, port) { + Ok(new_inner) => { + slf.inner = Some(new_inner); + Ok(slf) + } + Err(e) => Err(PyOSError::new_err(format!("with_telemetry({host}:{port}) failed: {e}"))), + } + } +} diff --git a/bonsai-py/src/lib.rs b/bonsai-py/src/lib.rs new file mode 100644 index 0000000..4390fca --- /dev/null +++ b/bonsai-py/src/lib.rs @@ -0,0 +1,50 @@ +use pyo3::prelude::*; +use pyo3_stub_gen::define_stub_info_gatherer; + +mod action_args; +mod behavior; +mod bt; +mod status; + +use action_args::PyActionArgs; +use behavior::{ + action_fn, after_fn, always_succeed_fn, if_fn, invert_fn, race_fn, select_fn, sequence_fn, wait_fn, + wait_forever_fn, when_all_fn, when_any_fn, while_all_fn, while_fn, PyBehavior, +}; +use bt::PyBT; +use status::PyStatus; + +/// Python bindings for the bonsai-bt behavior-tree library. +/// +/// Construct trees with the factory functions (Sequence, Action, Wait, ...), +/// wrap one in `BT(tree, blackboard)`, and drive it with `bt.tick(dt, callback)`. +#[pymodule] +fn bonsai_bt(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(action_fn, m)?)?; + m.add_function(wrap_pyfunction!(wait_fn, m)?)?; + m.add_function(wrap_pyfunction!(wait_forever_fn, m)?)?; + m.add_function(wrap_pyfunction!(invert_fn, m)?)?; + m.add_function(wrap_pyfunction!(always_succeed_fn, m)?)?; + m.add_function(wrap_pyfunction!(sequence_fn, m)?)?; + m.add_function(wrap_pyfunction!(select_fn, m)?)?; + m.add_function(wrap_pyfunction!(when_all_fn, m)?)?; + m.add_function(wrap_pyfunction!(when_any_fn, m)?)?; + m.add_function(wrap_pyfunction!(after_fn, m)?)?; + m.add_function(wrap_pyfunction!(race_fn, m)?)?; + m.add_function(wrap_pyfunction!(if_fn, m)?)?; + m.add_function(wrap_pyfunction!(while_fn, m)?)?; + m.add_function(wrap_pyfunction!(while_all_fn, m)?)?; + + // Convenience constant matching Rust's `bonsai_bt::RUNNING`. + m.add("RUNNING", (PyStatus::Running, 0.0_f64).into_pyobject(py)?)?; + + Ok(()) +} + +// Add pyo3-stub-gen: emits `pub fn stub_info() -> ...` that the +// `stub_gen` binary calls to collect every #[gen_stub_*] annotated item. +define_stub_info_gatherer!(stub_info); diff --git a/bonsai-py/src/status.rs b/bonsai-py/src/status.rs new file mode 100644 index 0000000..cf0c8da --- /dev/null +++ b/bonsai-py/src/status.rs @@ -0,0 +1,72 @@ +use bonsai_bt::Status; +use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass_enum, gen_stub_pymethods}; + +/// Behavior-tree node result. +/// +/// Mirrors `bonsai_bt::Status`. Comparable to `int` +/// (`Status.Success == 0`, `Failure == 1`, `Running == 2`) and usable +/// as a `dict` key or `set` member. +#[gen_stub_pyclass_enum] +#[pyclass(eq, eq_int, hash, frozen, from_py_object, module = "bonsai_bt", name = "Status")] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum PyStatus { + Success, + Failure, + Running, +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyStatus { + /// Pickle support: name the singleton by class + variant name, since + /// PyO3 simple enums refuse construction by call (`Status(0)` raises). + // The nested-tuple return shape is dictated by Python's pickle protocol + // (callable, args-tuple); factoring it into a type alias would obscure + // the contract more than it would clarify. + #[allow(clippy::type_complexity)] + fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyAny>, &'static str))> { + let getattr = py.import("builtins")?.getattr("getattr")?; + let cls = py.get_type::().into_any(); + let name = match self { + PyStatus::Success => "Success", + PyStatus::Failure => "Failure", + PyStatus::Running => "Running", + }; + Ok((getattr, (cls, name))) + } +} + +impl From for PyStatus { + fn from(s: Status) -> Self { + match s { + Status::Success => PyStatus::Success, + Status::Failure => PyStatus::Failure, + Status::Running => PyStatus::Running, + } + } +} + +impl From for Status { + fn from(s: PyStatus) -> Self { + match s { + PyStatus::Success => Status::Success, + PyStatus::Failure => Status::Failure, + PyStatus::Running => Status::Running, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn roundtrip_through_rust() { + for s in [Status::Success, Status::Failure, Status::Running] { + let py: PyStatus = s.into(); + let back: Status = py.into(); + assert_eq!(s, back); + } + } +} diff --git a/bonsai-py/tests/README.md b/bonsai-py/tests/README.md new file mode 100644 index 0000000..a6ca0a0 --- /dev/null +++ b/bonsai-py/tests/README.md @@ -0,0 +1,63 @@ +# bonsai-py tests + +Pytest suite for the `bonsai_bt` extension module. Primarily used to prevent any drift between the Rust bindings (`bonsai-py/src/*.rs`) and the Python surface: every `#[pyclass]`, `#[pymethods]`, and `#[pyfunction]` has at least one Python test exercising it. + +## Test files + +| File | What it covers | +|---|---| +| [conftest.py](conftest.py) | Shared fixtures: `free_port`, `noop_callback`, `counting_callback`, `basic_tree`. | +| [test_status.py](test_status.py) | `Status` enum: variants, `eq_int` discriminants (0/1/2), equality, singleton identity, hash, repr, `__module__`, pickle/copy round-trip preserves singleton identity. | +| [test_action_args.py](test_action_args.py) | `ActionArgs`: construct with str/None/dict/object/int-dt; `dt` and `action` are read-only; repr format; `__module__`. | +| [test_behavior.py](test_behavior.py) | All 14 factories exported and callable; each builds with expected repr; validation guards (empty `While`/`WhileAll` body, NaN `Wait`); empty composites + neg/inf/int `Wait` pass through; subtree reuse; kwargs on `If`; tuple accepted / generator rejected for children; identity-based equality per variant; ported Rust `behavior_tests.rs` cases (immediate termination, wait timing, select short-circuit, if branches, invert, always-succeed, when-all, while-loop). | +| [test_bt.py](test_bt.py) | `BT` construction (dict/None/custom blackboard); doctest-equivalent five-tick port; tick return shape; tick on finished returns None; `tick_count` survives `reset_bt`; callback exceptions propagate; bad return shape rejected; `WhenAll` short-circuits siblings after a raise; NaN dt sanitized to 0.0; blackboard identity + mutation persistence; action identity through callback; `reset_bt` preserves blackboard; `RUNNING` constant value and usage. | +| [test_telemetry.py](test_telemetry.py) | `with_telemetry` is chainable; accepts `host` kwarg; second bind on same port raises `OSError`; binding to unreachable IP raises `OSError`; failed `with_telemetry` poisons the BT (every subsequent method raises `RuntimeError`). | +| [test_module.py](test_module.py) | `__version__ == "0.12.0"`; module docstring present; `__all__` lists exactly the 19 expected names; all names accessible; `RUNNING` is `(Status.Running, 0.0)`; `.pyi` stub ships with the wheel. | +| [test_threading_and_pickle.py](test_threading_and_pickle.py) | `BT` is unsendable across threads (PyO3 `PanicException`); `BT`, `Behavior`, `ActionArgs` are unpicklable; `Status` IS picklable (multiprocessing-friendly). | +| [test_performance.py](test_performance.py) | `@pytest.mark.perf`: 100 ticks under 500 ms, 1000 constructions under 5 s. `@pytest.mark.bench` (opt-in): tick throughput microbenchmark. | +| [test_drift.py](test_drift.py) | Checks that every Rust `#[pyo3(name=...)]` symbol appears in `__all__`; every `#[pyclass(name=..., module='bonsai_bt')]` appears in `__all__`; every name in `__all__` is mentioned in at least one other test file. | + +## Prerequisites + +A Python venv with the `bonsai_bt` extension built in. See [../README.md](../README.md#installation-dev) for the one-time setup (`python -m venv .venv`, activate, `pip install maturin`, `maturin develop --release`). + +## Running + +From the repository root, with the venv activated: + +```bash +# Full suite (default — runs perf budget tests, skips benchmarks) +pytest bonsai-py/tests/ + +# Verbose +pytest -v bonsai-py/tests/ + +# Single file +pytest -v bonsai-py/tests/test_status.py + +# Single test +pytest -v bonsai-py/tests/test_bt.py::TestTick::test_doctest_equivalent + +# Drift gate only +pytest -v bonsai-py/tests/test_drift.py + +# Skip perf budget tests +pytest -v -m "not perf" bonsai-py/tests/ + +# Run microbenchmarks only (prints throughput; no assertions) +pytest -v -m bench bonsai-py/tests/ +``` + +A `pytest-timeout` of 10 seconds per test is configured in [pyproject.toml](../pyproject.toml). If you change a test that legitimately needs longer, bump the per-test timeout with `@pytest.mark.timeout(30)`. + +## Dependencies + +```bash +pip install pytest pytest-timeout +``` + +`mypy` is used by [test_mypy_strict.py](test_mypy_strict.py) — install with `pip install mypy`. The test is skipped if `mypy` isn't installed. + +## CI + +The `pytest` job in [.github/workflows/rust-pr.yml](../../.github/workflows/rust-pr.yml) runs this suite on Python 3.10 and 3.13 (matrix) on every PR, after building the wheel in release mode. diff --git a/bonsai-py/tests/__init__.py b/bonsai-py/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bonsai-py/tests/conftest.py b/bonsai-py/tests/conftest.py new file mode 100644 index 0000000..e828138 --- /dev/null +++ b/bonsai-py/tests/conftest.py @@ -0,0 +1,50 @@ +"""Shared fixtures for bonsai-py tests.""" +from __future__ import annotations + +import socket +from typing import Any, Callable + +import pytest + +import bonsai_bt as bt + + +@pytest.fixture +def free_port() -> int: + """Return a free TCP port (kernel-assigned).""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return int(s.getsockname()[1]) + + +@pytest.fixture +def noop_callback() -> Callable[[Any, Any], tuple[Any, float]]: + """A callback that returns Success immediately with no side effects.""" + + def cb(_args: Any, _bb: Any) -> tuple[Any, float]: + return (bt.Status.Success, 0.0) + + return cb + + +@pytest.fixture +def counting_callback() -> tuple[Callable[[Any, Any], tuple[Any, float]], list[Any]]: + """A callback that records every action it sees and returns Success.""" + calls: list[Any] = [] + + def cb(args: Any, _bb: Any) -> tuple[Any, float]: + calls.append(args.action) + return (bt.Status.Success, args.dt) + + return cb, calls + + +@pytest.fixture +def basic_tree() -> bt.Behavior: + """A small reusable tree: Sequence([Wait(0.5), Action('inc'), Wait(0.5), Action('inc')]).""" + return bt.Sequence([ + bt.Wait(0.5), + bt.Action("inc"), + bt.Wait(0.5), + bt.Action("inc"), + ]) diff --git a/bonsai-py/tests/test_action_args.py b/bonsai-py/tests/test_action_args.py new file mode 100644 index 0000000..33bfb48 --- /dev/null +++ b/bonsai-py/tests/test_action_args.py @@ -0,0 +1,82 @@ +"""ActionArgs: construction, identity, frozen, repr.""" +from __future__ import annotations + +import math + +import pytest + +import bonsai_bt as bt + + +class TestActionArgsConstruction: + def test_construct_with_string_action(self) -> None: + """String action is stored verbatim and accessible via .action.""" + a = bt.ActionArgs(0.5, "inc") + assert a.dt == 0.5 + assert a.action == "inc" + + def test_construct_with_none_action(self) -> None: + """None is a valid action value (preserved by identity).""" + a = bt.ActionArgs(0.0, None) + assert a.action is None + + def test_construct_with_arbitrary_object(self) -> None: + """Any Python object can be an action; identity is preserved.""" + sentinel = object() + a = bt.ActionArgs(0.1, sentinel) + assert a.action is sentinel + + def test_construct_with_dict_action(self) -> None: + """Dict actions are passed by reference — mutations are visible.""" + d = {"type": "fire"} + a = bt.ActionArgs(0.0, d) + assert a.action is d + + def test_construct_with_int_dt(self) -> None: + """int dt coerces to float at the FFI boundary.""" + a = bt.ActionArgs(1, "x") + assert a.dt == 1.0 + assert isinstance(a.dt, float) + + @pytest.mark.parametrize("dt", [float("nan"), float("inf"), float("-inf")]) + def test_construct_with_special_dt(self, dt: float) -> None: + """NaN / +inf / -inf dt are accepted at construction — no guard at the FFI boundary.""" + a = bt.ActionArgs(dt, "x") + if math.isnan(dt): + assert math.isnan(a.dt) + else: + assert a.dt == dt + + +class TestActionArgsImmutability: + def test_dt_readonly(self) -> None: + """ActionArgs.dt is frozen — assignment raises AttributeError.""" + a = bt.ActionArgs(0.5, "x") + with pytest.raises(AttributeError): + a.dt = 0.7 # type: ignore[misc] + + def test_action_readonly(self) -> None: + """ActionArgs.action is frozen — assignment raises AttributeError.""" + a = bt.ActionArgs(0.5, "x") + with pytest.raises(AttributeError): + a.action = "y" # type: ignore[misc] + + +class TestActionArgsRepr: + @pytest.mark.parametrize( + "dt, action, expected", + [ + (0.5, "inc", "ActionArgs(dt=0.5, action='inc')"), + (0.0, None, "ActionArgs(dt=0, action=None)"), + (1.0, 42, "ActionArgs(dt=1, action=42)"), + ], + ) + def test_repr_format(self, dt: float, action: object, expected: str) -> None: + """repr() renders as `ActionArgs(dt=..., action=...)` with Python repr on the action.""" + assert repr(bt.ActionArgs(dt, action)) == expected + + +class TestActionArgsModuleAttribution: + def test_module(self) -> None: + """ActionArgs.__module__ is `bonsai_bt` (required for pickle / introspection).""" + assert bt.ActionArgs.__module__ == "bonsai_bt" diff --git a/bonsai-py/tests/test_behavior.py b/bonsai-py/tests/test_behavior.py new file mode 100644 index 0000000..d2bd32e --- /dev/null +++ b/bonsai-py/tests/test_behavior.py @@ -0,0 +1,360 @@ +"""Behavior class + 14 factory functions + ports of Rust behavior_tests.""" +from __future__ import annotations + +from typing import Any, Callable + +import pytest + +import bonsai_bt as bt + +FACTORY_NAMES = ( + "Action", "Wait", "WaitForever", + "Invert", "AlwaysSucceed", + "Sequence", "Select", "WhenAll", "WhenAny", "After", "Race", + "If", "While", "WhileAll", +) + + +class TestFactoriesPresent: + @pytest.mark.parametrize("name", FACTORY_NAMES) + def test_factory_exported(self, name: str) -> None: + """Each of the 14 factory names is importable and callable.""" + assert hasattr(bt, name), f"missing factory {name}" + assert callable(getattr(bt, name)), f"{name} not callable" + + def test_factory_count(self) -> None: + """Exactly 14 factory names tracked — guards against silent additions.""" + assert len(FACTORY_NAMES) == 14 + + +def _trivial(label: str) -> bt.Behavior: + return bt.Action(label) + + +class TestFactoryConstruction: + @pytest.mark.parametrize( + "build, expected_repr", + [ + (lambda: bt.Action("x"), "Action(...)"), + (lambda: bt.Wait(1.0), "Wait(1)"), + (lambda: bt.WaitForever(), "WaitForever"), + (lambda: bt.Invert(_trivial("c")), "Invert(...)"), + (lambda: bt.AlwaysSucceed(_trivial("c")), "AlwaysSucceed(...)"), + (lambda: bt.Sequence([_trivial("a"), _trivial("b")]), "Sequence(2)"), + (lambda: bt.Select([_trivial("a")]), "Select(1)"), + (lambda: bt.WhenAll([_trivial("a")]), "WhenAll(1)"), + (lambda: bt.WhenAny([_trivial("a"), _trivial("b")]), "WhenAny(2)"), + (lambda: bt.After([_trivial("a")]), "After(1)"), + (lambda: bt.Race([_trivial("a"), _trivial("b")]), "Race(2)"), + (lambda: bt.If(_trivial("c"), _trivial("s"), _trivial("f")), "If(...)"), + (lambda: bt.While(_trivial("c"), [_trivial("b")]), "While(1)"), + (lambda: bt.WhileAll(_trivial("c"), [_trivial("b")]), "WhileAll(1)"), + ], + ) + def test_each_factory_builds( + self, build: Callable[[], bt.Behavior], expected_repr: str + ) -> None: + """Each factory builds a Behavior with the expected bounded repr.""" + node = build() + assert isinstance(node, bt.Behavior) + assert repr(node) == expected_repr + + +class TestValidationGuards: + def test_while_empty_body_raises(self) -> None: + """While with empty body raises ValueError (would panic in Rust).""" + with pytest.raises(ValueError, match="must not be empty"): + bt.While(_trivial("c"), []) + + def test_whileall_empty_body_raises(self) -> None: + """WhileAll with empty body raises ValueError (would panic in Rust).""" + with pytest.raises(ValueError, match="must not be empty"): + bt.WhileAll(_trivial("c"), []) + + def test_wait_nan_raises(self) -> None: + """Wait(NaN) raises ValueError at the Python boundary, never reaches Rust.""" + with pytest.raises(ValueError, match="NaN"): + bt.Wait(float("nan")) + + @pytest.mark.parametrize( + "build", + [ + lambda: bt.Sequence([]), + lambda: bt.Select([]), + lambda: bt.WhenAll([]), + lambda: bt.WhenAny([]), + lambda: bt.After([]), + lambda: bt.Race([]), + ], + ) + def test_other_empty_composites_allowed( + self, build: Callable[[], bt.Behavior] + ) -> None: + """Empty Sequence/Select/etc. are allowed (don't panic in Rust).""" + node = build() + assert isinstance(node, bt.Behavior) + + @pytest.mark.parametrize("value", [-1.0, 0.0, float("inf"), 1]) + def test_wait_passthrough_values(self, value: float) -> None: + """Negative, zero, inf, int -- all accepted at the boundary.""" + assert isinstance(bt.Wait(value), bt.Behavior) + + +class TestSubtreeReuse: + def test_leaf_reuse(self) -> None: + """The same Behavior leaf can appear as a child of one parent multiple times.""" + wait = bt.Wait(1.0) + tree = bt.Sequence([wait, wait, wait]) + assert repr(tree) == "Sequence(3)" + + def test_nested_subtree_reuse(self) -> None: + """Nested subtrees are reusable too — the same composite can be a child multiple times.""" + inner = bt.Sequence([bt.Wait(0.1), bt.Action("x")]) + outer = bt.Sequence([inner, inner, inner]) + assert repr(outer) == "Sequence(3)" + + def test_subtree_reused_across_bts(self) -> None: + """The same Behavior root can drive multiple independent BTs without interference.""" + subtree = bt.Sequence([bt.Action("a"), bt.Action("b")]) + calls1: list[Any] = [] + calls2: list[Any] = [] + + def make_cb(out: list[Any]) -> Callable[[Any, Any], tuple[bt.Status, float]]: + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + out.append(args.action) + return (bt.Status.Success, 0.0) + return cb + + bt.BT(subtree, None).tick(0.0, make_cb(calls1)) + bt.BT(subtree, None).tick(0.0, make_cb(calls2)) + assert calls1 == ["a", "b"] + assert calls2 == ["a", "b"] + + +class TestArgumentParsing: + def test_kwargs_on_if(self) -> None: + """If accepts cond / on_success / on_failure as keyword arguments.""" + tree = bt.If( + cond=bt.Action("c"), + on_success=bt.Action("s"), + on_failure=bt.Action("f"), + ) + assert repr(tree) == "If(...)" + + def test_tuple_accepted_for_children(self) -> None: + """PyO3 Vec extractor accepts indexable sequences, not just list.""" + tree = bt.Sequence((bt.Action("a"), bt.Action("b"))) + assert repr(tree) == "Sequence(2)" + + def test_generator_rejected_for_children(self) -> None: + """Generators are NOT accepted (extractor needs random access).""" + with pytest.raises(TypeError): + bt.Sequence(bt.Action(x) for x in ["a", "b"]) + + +class TestIdentityEquality: + @pytest.mark.parametrize( + "build", + [ + lambda: bt.Action("x"), + lambda: bt.Wait(1.0), + lambda: bt.WaitForever(), + lambda: bt.Invert(bt.Action("x")), + lambda: bt.AlwaysSucceed(bt.Action("x")), + lambda: bt.Sequence([bt.Action("x")]), + lambda: bt.Select([bt.Action("x")]), + lambda: bt.WhenAll([bt.Action("x")]), + lambda: bt.WhenAny([bt.Action("x")]), + lambda: bt.After([bt.Action("x")]), + lambda: bt.Race([bt.Action("x")]), + lambda: bt.If(bt.Action("c"), bt.Action("s"), bt.Action("f")), + lambda: bt.While(bt.Action("c"), [bt.Action("b")]), + lambda: bt.WhileAll(bt.Action("c"), [bt.Action("b")]), + ], + ) + def test_identity_based_eq_per_variant( + self, build: Callable[[], bt.Behavior] + ) -> None: + """Two structurally-identical Behaviors compare unequal — equality is identity.""" + a, b = build(), build() + assert a is not b + assert (a == b) is False + assert a != b + + +class TestBehaviorAttribution: + def test_module(self) -> None: + """Behavior.__module__ is `bonsai_bt` (required for pickle / introspection).""" + assert bt.Behavior.__module__ == "bonsai_bt" + + +# ---------- Ports of Rust behavior_tests.rs (golden-truth equivalence) ---------- + +class TestBehaviorRustParity: + """Tests ported from bonsai/tests/behavior_tests.rs. If a Rust test changes, + its Python counterpart must change too. Main drift gate against Rust semantics.""" + + def test_immediate_termination(self) -> None: + """A 0.0s tick runs all leaves; reset_bt then re-runs the whole sequence.""" + acc = [0] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + if args.action == "inc": + acc[0] += 1 + return (bt.Status.Success, args.dt) + + tree = bt.Sequence([bt.Action("inc"), bt.Action("inc")]) + b = bt.BT(tree, None) + b.tick(0.0, cb) + assert acc[0] == 2 + assert b.is_finished() + b.reset_bt() + b.tick(1.0, cb) + assert acc[0] == 4 + assert b.is_finished() + + def test_sequence_of_wait_then_action(self) -> None: + """A single tick with enough dt completes both Wait and the trailing Action.""" + seen: list[Any] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + seen.append(args.action) + return (bt.Status.Success, args.dt) + + b = bt.BT(bt.Sequence([bt.Wait(1.0), bt.Action("inc")]), None) + b.tick(1.0, cb) + assert seen == ["inc"] + + def test_wait_half_then_half(self) -> None: + """Two 0.5s ticks accumulate to clear a 1.0s Wait before the Action fires.""" + seen: list[Any] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + seen.append(args.action) + return (bt.Status.Success, args.dt) + + b = bt.BT(bt.Sequence([bt.Wait(1.0), bt.Action("inc")]), None) + b.tick(0.5, cb) + assert seen == [] + b.tick(0.5, cb) + assert seen == ["inc"] + + def test_select_succeed_on_first(self) -> None: + """Select short-circuits at the first Success; later siblings are never invoked.""" + calls: list[Any] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + calls.append(args.action) + return (bt.Status.Success, args.dt) + + tree = bt.Select([bt.Action("a"), bt.Action("b"), bt.Action("c")]) + b = bt.BT(tree, None) + b.tick(0.1, cb) + assert calls == ["a"] + + def test_select_first_failure_tries_next(self) -> None: + """Select advances past Failures and reports Success at the first successful child.""" + calls: list[Any] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + calls.append(args.action) + if args.action == "fail": + return (bt.Status.Failure, args.dt) + return (bt.Status.Success, args.dt) + + tree = bt.Select([bt.Action("fail"), bt.Action("ok")]) + b = bt.BT(tree, None) + result = b.tick(0.1, cb) + assert result is not None + status, _ = result + assert status == bt.Status.Success + assert calls == ["fail", "ok"] + + def test_if_true_branch(self) -> None: + """If runs on_success when the condition succeeds; on_failure is not invoked.""" + seen: list[Any] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + seen.append(args.action) + return (bt.Status.Success, args.dt) + + tree = bt.If(bt.Action("cond_true"), bt.Action("yes"), bt.Action("no")) + b = bt.BT(tree, None) + b.tick(0.1, cb) + assert "yes" in seen + assert "no" not in seen + + def test_if_false_branch(self) -> None: + """If runs on_failure when the condition fails; on_success is not invoked.""" + seen: list[Any] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + seen.append(args.action) + if args.action == "cond_false": + return (bt.Status.Failure, args.dt) + return (bt.Status.Success, args.dt) + + tree = bt.If(bt.Action("cond_false"), bt.Action("yes"), bt.Action("no")) + b = bt.BT(tree, None) + b.tick(0.1, cb) + assert "no" in seen + assert "yes" not in seen + + def test_invert_swaps_outcomes(self) -> None: + """Invert flips Success <-> Failure on the child's return status.""" + def yields_success(_a: Any, _b: Any) -> tuple[bt.Status, float]: + return (bt.Status.Success, 0.0) + + def yields_failure(_a: Any, _b: Any) -> tuple[bt.Status, float]: + return (bt.Status.Failure, 0.0) + + b1 = bt.BT(bt.Invert(bt.Action("x")), None) + r = b1.tick(0.0, yields_success) + assert r is not None + assert r[0] == bt.Status.Failure + + b2 = bt.BT(bt.Invert(bt.Action("x")), None) + r = b2.tick(0.0, yields_failure) + assert r is not None + assert r[0] == bt.Status.Success + + def test_always_succeed_swallows_failure(self) -> None: + """AlwaysSucceed coerces a child's Failure into Success.""" + def yields_failure(_a: Any, _b: Any) -> tuple[bt.Status, float]: + return (bt.Status.Failure, 0.0) + + b = bt.BT(bt.AlwaysSucceed(bt.Action("x")), None) + r = b.tick(0.0, yields_failure) + assert r is not None + assert r[0] == bt.Status.Success + + def test_when_all_waits_for_all(self) -> None: + """WhenAll blocks the parent Sequence until both parallel children finish.""" + seen: list[Any] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + seen.append(args.action) + return (bt.Status.Success, args.dt) + + tree = bt.Sequence([ + bt.WhenAll([bt.Wait(0.5), bt.Wait(1.0)]), + bt.Action("after"), + ]) + b = bt.BT(tree, None) + b.tick(0.5, cb) + assert seen == [] + b.tick(0.5, cb) + assert seen == ["after"] + + def test_while_loops_until_cond_fails(self) -> None: + """While re-runs its body each iteration while the cond stays Running.""" + seen: list[str] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + seen.append(args.action) + return (bt.Status.Success, args.dt) + + tree = bt.While(bt.Wait(50.0), [bt.Wait(0.5), bt.Action("tick"), bt.Wait(0.5)]) + b = bt.BT(tree, None) + b.tick(10.0, cb) + assert seen.count("tick") == 10 diff --git a/bonsai-py/tests/test_bt.py b/bonsai-py/tests/test_bt.py new file mode 100644 index 0000000..dd72140 --- /dev/null +++ b/bonsai-py/tests/test_bt.py @@ -0,0 +1,256 @@ +"""BT class: tick mechanics, callback errors, blackboard, reset, finished state.""" +from __future__ import annotations + +from typing import Any, Callable + +import pytest + +import bonsai_bt as bt + + +class TestBTConstruction: + def test_dict_blackboard(self) -> None: + """A dict can serve as the blackboard; round-trips by equality.""" + b = bt.BT(bt.Action("x"), {"k": 1}) + assert b.blackboard() == {"k": 1} + + def test_none_blackboard(self) -> None: + """None is a legal blackboard for trees that don't need shared state.""" + b = bt.BT(bt.Action("x"), None) + assert b.blackboard() is None + + def test_custom_blackboard(self) -> None: + """Any Python object can be the blackboard; identity is preserved.""" + class State: + def __init__(self) -> None: + self.counter = 0 + + s = State() + b = bt.BT(bt.Action("x"), s) + assert b.blackboard() is s + + def test_module(self) -> None: + """BT.__module__ is `bonsai_bt` (required for pickle / introspection).""" + assert bt.BT.__module__ == "bonsai_bt" + + +class TestTick: + def test_doctest_equivalent(self) -> None: + """Line-for-line port of the bonsai/src/lib.rs Rust doctest: 5 ticks at 0.5s land count==1.""" + tree = bt.Sequence([ + bt.Wait(1.0), bt.Action("inc"), + bt.Wait(1.0), bt.Action("inc"), + bt.Wait(0.5), bt.Action("dec"), + ]) + bb = {"count": 0} + b = bt.BT(tree, bb) + acc = 0 + + def cb(args: Any, _blackboard: Any) -> tuple[bt.Status, float]: + nonlocal acc + if args.action == "inc": + acc += 1 + return (bt.Status.Success, args.dt) + if args.action == "dec": + acc -= 1 + return (bt.Status.Success, args.dt) + return bt.RUNNING + + for _ in range(5): + b.tick(0.5, cb) + bb["count"] = acc + assert bb["count"] == 1 + assert b.tick_count() == 5 + + def test_tick_returns_status_and_dt( + self, + basic_tree: bt.Behavior, + noop_callback: Callable[[Any, Any], tuple[Any, float]], + ) -> None: + """tick() returns a (Status, float) tuple when the BT has not yet finished.""" + b = bt.BT(basic_tree, None) + result = b.tick(2.0, noop_callback) + assert result is not None + status, remaining = result + assert isinstance(status, bt.Status) + assert isinstance(remaining, float) + + def test_tick_on_finished_returns_none(self) -> None: + """Once is_finished() is True, every subsequent tick() returns None.""" + def done(_a: Any, _b: Any) -> tuple[bt.Status, float]: + return (bt.Status.Success, 0.0) + + b = bt.BT(bt.Action("x"), None) + b.tick(0.0, done) + assert b.is_finished() + assert b.tick(0.0, done) is None + + def test_reset_bt_on_unstarted_tree(self) -> None: + """reset_bt() on a never-ticked BT is a no-op; subsequent tick still works normally.""" + b = bt.BT(bt.Action("x"), None) + b.reset_bt() + assert not b.is_finished() + assert b.tick_count() == 0 + result = b.tick(0.0, lambda _a, _bb: (bt.Status.Success, 0.0)) + assert result is not None + + def test_tick_count_survives_reset( + self, + noop_callback: Callable[[Any, Any], tuple[Any, float]], + ) -> None: + """tick_count accumulates across ticks and persists across reset_bt (never zeroed).""" + b = bt.BT(bt.Wait(10.0), None) + for _ in range(3): + b.tick(0.1, noop_callback) + assert b.tick_count() == 3 + b.reset_bt() + assert b.tick_count() == 3, "tick_count must survive reset_bt" + assert not b.is_finished() + + +class TestCallbackContract: + def test_callback_exception_propagates_with_message(self) -> None: + """A Python exception raised inside the callback bubbles up through tick() intact.""" + def boom(_a: Any, _b: Any) -> tuple[bt.Status, float]: + raise ValueError("boom") + + b = bt.BT(bt.Action("x"), None) + with pytest.raises(ValueError, match="boom"): + b.tick(0.0, boom) + + def test_callback_wrong_return_shape_rejected(self) -> None: + """A callback that returns a non-(Status, float) value is rejected by the extractor.""" + def bad(_a: Any, _b: Any) -> str: + return "not a tuple" + + b = bt.BT(bt.Action("x"), None) + with pytest.raises(Exception): + b.tick(0.0, bad) + + def test_when_all_short_circuits_on_callback_raise(self) -> None: + """After a callback raises on a WhenAll child, later siblings in the same tick are not invoked.""" + order: list[str] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + order.append(args.action) + if args.action == "b": + raise ValueError("stop") + return (bt.Status.Success, 0.0) + + tree = bt.WhenAll([bt.Action("a"), bt.Action("b"), bt.Action("c")]) + b = bt.BT(tree, None) + with pytest.raises(ValueError): + b.tick(0.0, cb) + assert order == ["a", "b"] + + def test_callback_returns_nan_dt_sanitized(self) -> None: + """A NaN dt returned from the callback is sanitized by upstream BT and never surfaces to Python.""" + import math + + def nan_cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + if args.action == "x": + return (bt.Status.Running, float("nan")) + return (bt.Status.Success, args.dt) + + b = bt.BT(bt.Sequence([bt.Action("x")]), None) + result = b.tick(1.0, nan_cb) + assert result is not None + _, dt = result + assert not math.isnan(dt) + + +class TestBlackboard: + def test_blackboard_identity_preserved(self) -> None: + """blackboard() returns the same Python object passed to BT() — not a copy.""" + bb = {"count": 0} + b = bt.BT(bt.Action("x"), bb) + assert b.blackboard() is bb + + def test_blackboard_mutation_persists_via_callback(self) -> None: + """Mutations done through the callback's blackboard handle persist in the original object.""" + bb = {"count": 0} + + def inc(_args: Any, blackboard: Any) -> tuple[bt.Status, float]: + blackboard["count"] += 1 + return (bt.Status.Success, 0.0) + + b = bt.BT(bt.Sequence([bt.Action("x"), bt.Action("x")]), bb) + b.tick(0.0, inc) + assert bb["count"] == 2 + + def test_action_identity_through_callback(self) -> None: + """The action object passed to bt.Action(...) arrives at the callback by identity, not equality.""" + sentinel = object() + seen: list[bool] = [] + + def cb(args: Any, _bb: Any) -> tuple[bt.Status, float]: + seen.append(args.action is sentinel) + return (bt.Status.Success, 0.0) + + b = bt.BT(bt.Action(sentinel), None) + b.tick(0.0, cb) + assert seen == [True] + + def test_reset_preserves_blackboard(self) -> None: + """reset_bt() rewinds tree state but does NOT touch the blackboard contents.""" + bb = {"count": 5} + b = bt.BT(bt.Action("x"), bb) + + def done(_a: Any, _b: Any) -> tuple[bt.Status, float]: + return (bt.Status.Success, 0.0) + + b.tick(0.0, done) + b.reset_bt() + assert b.blackboard() is bb + assert bb["count"] == 5 + + +class TestRunningConstant: + def test_value(self) -> None: + """bt.RUNNING is the tuple (Status.Running, 0.0).""" + assert bt.RUNNING == (bt.Status.Running, 0.0) + + def test_first_element_is_running(self) -> None: + """Unpacking bt.RUNNING yields Status.Running as the first element.""" + status, _ = bt.RUNNING + assert status is bt.Status.Running + + def test_used_as_callback_return(self) -> None: + """`return bt.RUNNING` from a callback is a valid Running tick — BT stays unfinished.""" + + def keep_running(_a: Any, _b: Any) -> tuple[bt.Status, float]: + return bt.RUNNING + + b = bt.BT(bt.Action("x"), None) + result = b.tick(0.0, keep_running) + assert result is not None + status, _ = result + assert status == bt.Status.Running + assert not b.is_finished() + + +class TestGraphviz: + def test_returns_dot_string(self) -> None: + """BT.graphviz() returns a non-empty graphviz DOT string.""" + b = bt.BT(bt.Sequence([bt.Action("a"), bt.Action("b")]), None) + dot = b.graphviz() + assert isinstance(dot, str) + assert "digraph" in dot.lower() or "graph" in dot.lower() + + def test_graphviz_idempotent_around_tick(self) -> None: + """graphviz() can be called before or after ticks — tree shape is invariant.""" + b = bt.BT(bt.Action("x"), None) + before = b.graphviz() + b.tick(0.0, lambda _a, _bb: (bt.Status.Success, 0.0)) + after = b.graphviz() + assert before == after + + def test_graphviz_on_poisoned_bt_raises(self, free_port: int) -> None: + """graphviz() on a poisoned BT raises RuntimeError (consistent with other methods).""" + holder = bt.BT(bt.Action("x"), None).with_telemetry(free_port) + assert holder is not None + victim = bt.BT(bt.Action("y"), None) + with pytest.raises(OSError): + victim.with_telemetry(free_port) + with pytest.raises(RuntimeError, match="invalidated"): + victim.graphviz() diff --git a/bonsai-py/tests/test_drift.py b/bonsai-py/tests/test_drift.py new file mode 100644 index 0000000..911de4f --- /dev/null +++ b/bonsai-py/tests/test_drift.py @@ -0,0 +1,69 @@ +"""Drift gate: every Rust binding has at least one exercising Python test. + +Two parity checks: +1. Every `#[pyfunction]` in `bonsai-py/src/*.rs` (with a `#[pyo3(name = "X")]` + override) appears in `bonsai_bt.__all__`. +2. Every public `#[pyclass]` (with `module = "bonsai_bt"` set) appears in + `bonsai_bt.__all__`. + +If a Rust contributor adds a binding and forgets to: + - add it to __all__, + - add a Python test exercising it, +then `test_no_unexercised_factories` fails CI. +""" +from __future__ import annotations + +import re +from pathlib import Path + +import bonsai_bt as bt + +SRC_DIR = Path(__file__).resolve().parent.parent / "src" + +RUST_NAME_RE = re.compile(r'#\[pyo3\(\s*name\s*=\s*"([A-Za-z_][A-Za-z0-9_]*)"\s*\)\]') +PYCLASS_NAME_RE = re.compile( + r'#\[pyclass\([^\]]*\bname\s*=\s*"([A-Za-z_][A-Za-z0-9_]*)"[^\]]*\)\]' +) + + +def _names_from_rust(pattern: re.Pattern[str]) -> set[str]: + names: set[str] = set() + for path in SRC_DIR.rglob("*.rs"): + names.update(pattern.findall(path.read_text(encoding="utf-8"))) + return names + + +def test_every_rust_pyfunction_in_all() -> None: + """Every Rust #[pyo3(name=...)] symbol scanned from src/*.rs must appear in bonsai_bt.__all__.""" + rust_names = _names_from_rust(RUST_NAME_RE) + rust_names = {n for n in rust_names if not n.startswith("_")} + missing = rust_names - set(bt.__all__) + assert not missing, ( + f"Rust declares these #[pyo3(name=...)] symbols but they're missing " + f"from bonsai_bt.__all__: {sorted(missing)}." + ) + + +def test_every_rust_pyclass_in_all() -> None: + """Every Rust #[pyclass(name=...)] must appear in bonsai_bt.__all__.""" + rust_classes = _names_from_rust(PYCLASS_NAME_RE) + missing = rust_classes - set(bt.__all__) + assert not missing, ( + f"Rust declares these pyclasses but they're missing from __all__: " + f"{sorted(missing)}" + ) + + +def test_no_unexercised_factories() -> None: + """Every name in __all__ is mentioned in at least one other test file — catches added-but-untested symbols.""" + tests_dir = Path(__file__).resolve().parent + test_text = "" + for p in tests_dir.glob("test_*.py"): + if p.name == "test_drift.py": + continue + test_text += p.read_text(encoding="utf-8") + + unexercised = [name for name in bt.__all__ if name not in test_text] + assert not unexercised, ( + f"These public names have no test mentioning them: {unexercised}." + ) diff --git a/bonsai-py/tests/test_module.py b/bonsai-py/tests/test_module.py new file mode 100644 index 0000000..b03fcda --- /dev/null +++ b/bonsai-py/tests/test_module.py @@ -0,0 +1,47 @@ +"""Module-level surface: __version__, __all__, __doc__, RUNNING.""" +from __future__ import annotations + +import bonsai_bt as bt + + +def test_version_present() -> None: + """bt.__version__ pins the wheel version (0.12.0); bump per release.""" + assert bt.__version__ == "0.12.0" + + +def test_docstring_present() -> None: + """Module docstring is non-empty and mentions behavior trees.""" + assert bt.__doc__ + assert "behavior" in bt.__doc__.lower() + + +def test_all_contents() -> None: + """__all__ contains exactly the 4 types + 14 factories + RUNNING = 19 names.""" + expected = { + "Status", "ActionArgs", "Behavior", "BT", + "Action", "Wait", "WaitForever", + "Invert", "AlwaysSucceed", + "Sequence", "Select", "WhenAll", "WhenAny", "After", "Race", + "If", "While", "WhileAll", + "RUNNING", + } + assert set(bt.__all__) == expected + + +def test_all_names_are_accessible() -> None: + """Every name listed in __all__ is actually attached to the module.""" + for name in bt.__all__: + assert hasattr(bt, name), f"missing {name}" + + +def test_running_constant() -> None: + """bt.RUNNING is the immutable tuple (Status.Running, 0.0).""" + assert bt.RUNNING == (bt.Status.Running, 0.0) + assert isinstance(bt.RUNNING, tuple) + + +def test_stub_present() -> None: + """The auto-generated .pyi stub ships alongside the wheel.""" + from pathlib import Path + stub = Path(bt.__file__).parent / "__init__.pyi" + assert stub.exists() diff --git a/bonsai-py/tests/test_mypy_strict.py b/bonsai-py/tests/test_mypy_strict.py new file mode 100644 index 0000000..438da50 --- /dev/null +++ b/bonsai-py/tests/test_mypy_strict.py @@ -0,0 +1,61 @@ +"""mypy --strict acceptance test for the typed public surface.""" +from __future__ import annotations + +import importlib.util +import os +import subprocess +import sys +import tempfile + +import pytest + + +def test_mypy_strict_accepts_sample_script() -> None: + """mypy --strict accepts a sample script using bonsai_bt's typed surface. + + Catches stub regressions that pytest's runtime tests wouldn't notice — + e.g. a removed annotation, a wrong type in __init__.pyi, or a missing + overload. Runs `mypy --strict` on a generated sample exercising + Status, ActionArgs, Behavior factories, BT.tick, with_telemetry. + """ + if importlib.util.find_spec("mypy") is None: + pytest.skip("mypy not installed (pip install mypy to enable)") + + sample = ( + "import bonsai_bt as bt\n" + "\n" + "def cb(args: bt.ActionArgs, bb: object) -> tuple[bt.Status, float]:\n" + ' if args.action == "inc":\n' + " return (bt.Status.Success, args.dt)\n" + " return bt.RUNNING\n" + "\n" + 'tree = bt.Sequence([bt.Action("inc"), bt.Wait(1)]) # int coerces to float\n' + 'tree_bt = bt.BT(tree, {"count": 0})\n' + "for _ in range(3):\n" + " res: tuple[bt.Status, float] | None = tree_bt.tick(0.5, cb)\n" + " if res is None:\n" + " tree_bt.reset_bt()\n" + "\n" + 'chained: bt.BT = bt.BT(bt.Action("x"), None).with_telemetry(0, host="0.0.0.0")\n' + ) + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False, encoding="utf-8" + ) as f: + f.write(sample) + sample_path = f.name + + try: + result = subprocess.run( + [sys.executable, "-m", "mypy", "--strict", sample_path], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode == 0, ( + f"mypy --strict failed (exit {result.returncode}):\n" + f"--- stdout ---\n{result.stdout}\n" + f"--- stderr ---\n{result.stderr}" + ) + finally: + os.unlink(sample_path) diff --git a/bonsai-py/tests/test_performance.py b/bonsai-py/tests/test_performance.py new file mode 100644 index 0000000..69c689a --- /dev/null +++ b/bonsai-py/tests/test_performance.py @@ -0,0 +1,53 @@ +"""Performance budget -- generous bounds to avoid CI flakiness. Heavy +benchmarks gated behind `@pytest.mark.bench` (run with `pytest -m bench`).""" +from __future__ import annotations + +import time +from typing import Any + +import pytest + +import bonsai_bt as bt + + +@pytest.mark.perf +class TestTickBudget: + def test_simple_tick_100_iterations_under_500ms(self) -> None: + """100 BT(tree).tick(cb) round-trips complete under 500 ms — catches order-of-magnitude regressions.""" + + def cb(_a: Any, _b: Any) -> tuple[bt.Status, float]: + return (bt.Status.Success, 0.0) + + # Hoist tree construction; each iter still needs a fresh BT because + # the tree finishes after one tick and tick() returns None thereafter. + tree = bt.Sequence([bt.Action("x")]) + start = time.perf_counter() + for _ in range(100): + bt.BT(tree, None).tick(0.0, cb) + elapsed = time.perf_counter() - start + assert elapsed < 0.5, f"100 ticks took {elapsed*1000:.0f}ms (budget: 500 ms)" + + def test_construction_under_5_seconds_per_1000(self) -> None: + """1000 Sequence([Action, Wait]) constructions fit in 5 seconds — guards against construction-time blowups.""" + start = time.perf_counter() + for _ in range(1000): + bt.Sequence([bt.Action("a"), bt.Wait(0.1)]) + elapsed = time.perf_counter() - start + assert elapsed < 5.0, f"1000 constructions took {elapsed*1000:.0f}ms" + + +@pytest.mark.bench +class TestBenchmarks: + """Microbenchmarks for local profiling. Run with `pytest -m bench`.""" + + def test_bench_tick_throughput(self) -> None: + """Prints achieved ticks/sec for 10_000 BT.tick round-trips — informational, no assertion.""" + import timeit + + def one_tick() -> None: + b = bt.BT(bt.Action("x"), None) + b.tick(0.0, lambda _a, _bb: (bt.Status.Success, 0.0)) + + n = 10_000 + elapsed = timeit.timeit(one_tick, number=n) + print(f"\nbench: {n} ticks in {elapsed:.3f}s = {n/elapsed:.0f} ticks/sec") diff --git a/bonsai-py/tests/test_status.py b/bonsai-py/tests/test_status.py new file mode 100644 index 0000000..56a13dc --- /dev/null +++ b/bonsai-py/tests/test_status.py @@ -0,0 +1,75 @@ +"""Status enum: semantics, pickle, copy, hash, identity.""" +from __future__ import annotations + +import copy +import pickle + +import pytest + +import bonsai_bt as bt + + +class TestStatusSemantics: + def test_three_variants(self) -> None: + """All three Status variants exist and form a non-empty set.""" + assert {bt.Status.Success, bt.Status.Failure, bt.Status.Running} + + @pytest.mark.parametrize( + "variant, expected_int", + [(bt.Status.Success, 0), (bt.Status.Failure, 1), (bt.Status.Running, 2)], + ) + def test_eq_int_discriminant(self, variant: bt.Status, expected_int: int) -> None: + """Discriminants 0/1/2 are locked; reordering is a breaking change.""" + assert variant == expected_int + assert int(variant) == expected_int + + def test_equality(self) -> None: + """Same variant compares equal; different variants compare unequal.""" + assert bt.Status.Success == bt.Status.Success + assert bt.Status.Success != bt.Status.Failure + assert bt.Status.Success != bt.Status.Running + + def test_identity_singleton(self) -> None: + """PyO3 simple enums are singletons; `is` comparison works.""" + assert bt.Status.Success is bt.Status.Success + + def test_hashable_as_dict_key(self) -> None: + """Status implements __hash__ and is usable as a dict key / set member.""" + d = {bt.Status.Success: "ok", bt.Status.Failure: "no"} + assert d[bt.Status.Success] == "ok" + assert d[bt.Status.Failure] == "no" + + def test_repr(self) -> None: + """repr() returns the dotted variant name (Status.Success / Failure / Running).""" + assert repr(bt.Status.Success) == "Status.Success" + assert repr(bt.Status.Failure) == "Status.Failure" + assert repr(bt.Status.Running) == "Status.Running" + + def test_module_attribution(self) -> None: + """`module = "bonsai_bt"` is set on the pyclass — required for pickle.""" + assert bt.Status.__module__ == "bonsai_bt" + + +class TestStatusPickle: + @pytest.mark.parametrize( + "variant", [bt.Status.Success, bt.Status.Failure, bt.Status.Running] + ) + def test_pickle_preserves_singleton_identity(self, variant: bt.Status) -> None: + """pickle round-trip returns the same singleton (not a copy).""" + roundtripped = pickle.loads(pickle.dumps(variant)) + assert roundtripped is variant + + @pytest.mark.parametrize( + "variant", [bt.Status.Success, bt.Status.Failure, bt.Status.Running] + ) + def test_copy_preserves_singleton_identity(self, variant: bt.Status) -> None: + """copy.copy and copy.deepcopy preserve singleton identity.""" + assert copy.copy(variant) is variant + assert copy.deepcopy(variant) is variant + + def test_dict_with_status_round_trips(self) -> None: + """multiprocessing-style pickle: a dict containing Status survives serialization.""" + data = pickle.dumps({"result": bt.Status.Success, "count": 3}) + out = pickle.loads(data) + assert out["result"] is bt.Status.Success + assert out["count"] == 3 diff --git a/bonsai-py/tests/test_telemetry.py b/bonsai-py/tests/test_telemetry.py new file mode 100644 index 0000000..5fda7fe --- /dev/null +++ b/bonsai-py/tests/test_telemetry.py @@ -0,0 +1,60 @@ +"""with_telemetry: chainable, bind failures, poisoned state, host parameter.""" +from __future__ import annotations + +from typing import Any + +import pytest + +import bonsai_bt as bt + + +class TestWithTelemetry: + def test_chainable(self, free_port: int) -> None: + """with_telemetry returns the same BT instance (PyRefMut self), enabling fluent chaining.""" + b = bt.BT(bt.Action("x"), None) + b_after = b.with_telemetry(free_port) + assert b_after is b + result = b_after.tick(0.0, lambda _a, _bb: (bt.Status.Success, 0.0)) + assert result is not None + + def test_host_parameter(self, free_port: int) -> None: + """The optional `host` kwarg lets the listener bind to a non-loopback interface.""" + b = bt.BT(bt.Action("x"), None).with_telemetry(free_port, host="127.0.0.1") + assert b is not None + + def test_explicit_loopback(self, free_port: int) -> None: + """All-keyword form (port=..., host=...) works for callers that prefer kwargs.""" + b = bt.BT(bt.Action("x"), None).with_telemetry(port=free_port, host="127.0.0.1") + assert b is not None + + def test_bound_port_raises_os_error(self, free_port: int) -> None: + """A second bind on a port held by another BT raises OSError with the bind message.""" + holder = bt.BT(bt.Action("x"), None).with_telemetry(free_port) + assert holder is not None + with pytest.raises(OSError, match="failed"): + bt.BT(bt.Action("y"), None).with_telemetry(free_port) + + def test_unbindable_host_raises_os_error(self) -> None: + """Binding to an RFC-reserved address (240.0.0.0/4) raises OSError without hitting DNS.""" + with pytest.raises(OSError): + bt.BT(bt.Action("x"), None).with_telemetry(0, host="240.0.0.1") + + +class TestPoisonedBT: + def test_failed_with_telemetry_poisons_bt(self, free_port: int) -> None: + """A failed with_telemetry poisons the BT; every subsequent method raises RuntimeError.""" + holder = bt.BT(bt.Action("x"), None).with_telemetry(free_port) + assert holder is not None + + victim = bt.BT(bt.Action("y"), None) + with pytest.raises(OSError): + victim.with_telemetry(free_port) + + with pytest.raises(RuntimeError, match="invalidated"): + victim.tick(0.0, lambda _a, _b: (bt.Status.Success, 0.0)) + + with pytest.raises(RuntimeError, match="invalidated"): + victim.blackboard() + + with pytest.raises(RuntimeError, match="invalidated"): + victim.tick_count() diff --git a/bonsai-py/tests/test_threading_and_pickle.py b/bonsai-py/tests/test_threading_and_pickle.py new file mode 100644 index 0000000..19c2a33 --- /dev/null +++ b/bonsai-py/tests/test_threading_and_pickle.py @@ -0,0 +1,61 @@ +"""Cross-process and cross-thread semantics: BT is unsendable + unpicklable.""" +from __future__ import annotations + +import pickle +import threading + +import pytest + +import bonsai_bt as bt + + +class TestBTUnsendable: + def test_thread_send_raises_panic(self) -> None: + """PyBT is unsendable; touching it from a different thread raises PyO3's PanicException.""" + b = bt.BT(bt.Action("x"), None) + captured: list[BaseException] = [] + + def worker() -> None: + try: + b.tick(0.0, lambda _a, _b: (bt.Status.Success, 0.0)) + except BaseException as e: # PanicException is BaseException-subclass + captured.append(e) + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert len(captured) == 1 + msg = str(captured[0]) + assert "unsendable" in msg or "thread" in msg.lower() + + +class TestBTUnpicklable: + def test_bt_not_picklable(self) -> None: + """BT instances have no __reduce__ — pickling raises TypeError/PicklingError.""" + b = bt.BT(bt.Action("x"), None) + with pytest.raises((TypeError, pickle.PicklingError)): + pickle.dumps(b) + + +class TestBehaviorUnpicklable: + def test_behavior_not_picklable(self) -> None: + """Behavior nodes are not picklable (no __reduce__ implemented).""" + with pytest.raises((TypeError, pickle.PicklingError)): + pickle.dumps(bt.Action("x")) + + +class TestActionArgsUnpicklable: + def test_action_args_not_picklable(self) -> None: + """ActionArgs instances are not picklable (no __reduce__ implemented).""" + with pytest.raises((TypeError, pickle.PicklingError)): + pickle.dumps(bt.ActionArgs(0.5, "x")) + + +class TestStatusPicklableAcrossProcesses: + """Status IS picklable — the one multiprocessing-friendly type in the binding.""" + + def test_status_round_trip_through_pickle(self) -> None: + """Each Status variant round-trips through pickle and returns the same singleton.""" + for s in (bt.Status.Success, bt.Status.Failure, bt.Status.Running): + assert pickle.loads(pickle.dumps(s)) is s diff --git a/bonsai/Cargo.toml b/bonsai/Cargo.toml index a6adbfd..7fb657f 100644 --- a/bonsai/Cargo.toml +++ b/bonsai/Cargo.toml @@ -1,5 +1,5 @@ [package] -authors = ["Kristoffer Solberg Rakstad "] +authors = ["Kristoffer Solberg Rakstad "] autotests = false categories = [] description = "Behavior tree" diff --git a/examples/Cargo.toml b/examples/Cargo.toml index ecf3c01..7421791 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -1,5 +1,5 @@ [package] -authors = ["Kristoffer Solberg Rakstad "] +authors = ["Kristoffer Solberg Rakstad "] description = "Behavior tree examples" edition = "2021" name = "examples"