diff --git a/Cargo.lock b/Cargo.lock index 8d9451f6..fc5855b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -308,6 +308,14 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "example-tdl-package-complex" +version = "0.1.0" +dependencies = [ + "serde", + "spider-tdl", +] + [[package]] name = "flume" version = "0.11.1" @@ -685,6 +693,16 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.16" @@ -1305,6 +1323,37 @@ dependencies = [ "tokio", ] +[[package]] +name = "spider-task-executor" +version = "0.1.0" +dependencies = [ + "libloading", + "rmp-serde", + "spider-tdl", + "thiserror", +] + +[[package]] +name = "spider-tdl" +version = "0.1.0" +dependencies = [ + "anyhow", + "rmp-serde", + "serde", + "spider-core", + "spider-tdl-derive", + "thiserror", +] + +[[package]] +name = "spider-tdl-derive" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "spin" version = "0.9.8" @@ -1610,6 +1659,18 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "tdl-integration-tests" +version = "0.1.0" +dependencies = [ + "anyhow", + "rmp-serde", + "serde", + "spider-core", + "spider-task-executor", + "spider-tdl", +] + [[package]] name = "testing_table" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index b19839b5..3aa1db4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,9 @@ members = [ "components/spider-core", "components/spider-derive", "components/spider-storage", + "components/spider-task-executor", + "components/spider-tdl", + "components/spider-tdl-derive", + "examples/example-tdl-package-complex", + "tests/tdl-integration", ] diff --git a/claude/task-exec-prototyping/context.md b/claude/task-exec-prototyping/context.md new file mode 100644 index 00000000..6667d3c7 --- /dev/null +++ b/claude/task-exec-prototyping/context.md @@ -0,0 +1,350 @@ +I'm designing a task execution package that allows users to write their own task functions with +custom function signatures using specific types. The high-level goal is to organize these +user-defined functions into a single package, exposing C-ffi APIs to call these methods with runtime +inputs and outputs. I will work through the design in detail. Your job is to help me implement such +a prototype. This package will be implemented in Rust. + +1. Type system + +The package would only support the following types, with the base types aliased to Rust's built-in +types: + +- Primitive types: + - `int8`: `i8` + - `int16`: `i16` + - `int32`: `i32` + - `int64`: `i64` + - `double`: `f64` + - `float`: `f32` + - `Bytes`: `Vec` + - `boolean`: `bool` +- `List`: `Vec`, where `T` is any supported type. +- `Map`: `HashMap`, where `K` must be one of {`int8`, `int16`, `int32`, `int64`, + `Bytes`} and `V` is any supported type. +- User-defined structs, where the struct contains only supported types, with each field uniquely + identified by a name. + +You might want to create these type aliases in the `tdl_types` module. + +User defined task function is expected to have the following signature: + +```rust +fn my_task(args...) -> Result; +``` + +Where: + +- `args...` are positional arguements. Each must be one of the supported types. +- `T` is the return type of the function. It can be any supported type, or a tuple of supported + types. +- `Error` is the error type of the function, defined in the package. It contains a custom variant + that allows users to return a custom error message as a `String`. + +All types, including the result type and the arguments, must support serialization and +deserialization using `serde`. This is because the task function will receive inputs as serialized +msgpack bytes from the C-ffi layer, and return outputs (the result) as serialized msgpack bytes. +We will go through this in detail in a later section. + +2. `task` proc-macro + +First, we will define a type trait for user-defined types. + +```rust +trait Task { + /// The name of the task. + const NAME: &'static str; + + /// The parameters of the task. + type Params: for<'de> Deserialize<'de>; + + /// The return type of the task. + type Return: Serialize + for<'de> Deserialize<'de>; + + fn execute(args: Self::Params) -> Result; +} +``` + +We need a `task` proc-macro that automatically generates the task execution driver code on top of +the user-defined task function. With the following task: + +```rust +fn my_task(a: int32, b: MyStruct1, ...) -> Result<(List), Error> { + // User implementation. +} +``` + +The proc-macro will geneerate the following code: + +```rust +/// An empty struct as a type marker, with the same name as the task. +pub(crate) struct my_task {} + +impl my_task { + /// User implementation + fn __my_task(a: int32, b: MyStruct1, ...) -> Result<(List, int64), Error> { + // User implementation. + } +} + +/// The parameters of the task, mirrored in the function signature. +#[derive(Deserialize)] +struct __my_task_params { + a: int32, + b: MyStruct1, + ... +} + +impl Task for my_task { + // The exact name of the task defined by the user, if not specified. + const NAME: &'static str = "my_task"; + + // The parameters of the task + type Params = __my_task_params; + + // The return type of the task + type Return = (List, int64); + + fn execute(args: Self::Params) -> Result { + // Call the user implementation with the deserialized parameters. + Self::__my_task(args.a, args.b, ...) + } +} +``` + +Requirements: + +- Invalid types should be rejected, including both the args and the return type. +- The proc-macro should accept a `name` argument, allowing users to specify a custom name for the + task that may contain namespace information. +- The return type should always be a tuple, even it only returns a single value. This is to simplify + the result deserialization logic. + +3. Task execution handler + +On top of `Task`, we need a `TaskHandler` trait that wraps the input/output serialization and +deserialization. + +```rust +enum ExecutionResult { + Outputs(Vec), + Error(Vec), +} + +trait TaskHandler: { + fn execute(&self, serialized_inputs: &[u8]) -> Vec; + + fn name(&self) -> &'static str; +} + +struct TaskHandlerImpl { + _marker: std::marker::PhantomData, +} + +impl TaskHandlerImpl { + fn new() -> Self { + Self { _marker: std::marker::PhantomData } + } +} + +impl TaskHandler for TaskHandlerImpl +{ + fn execute_raw(&self, raw_args: &[u8]) -> ExecutionResult { + // 1. Deserialize the input bytes into the parameters. + // The input bytes are serialized `TaskInput`. For this deserialization, check the doc under + // `claude/task-exec-prototyping/struct-serde.md`. + let params: T::Params = ...; + + // 2. Execute the task + let result = T::default().execute(params)?; + + // 3. Serialize the result into `ExecutionResult`. + // If the result is Ok, serialize the output value into bytes and return + // `ExecutionResult::Outputs`. Since the return type is always a tuple, we want to serialize + // each element inside the tuple into msgpack bytes independently, and then serialize them + // using wire format. Please come up with a design that can avoid double memory copying: the + // serialization should be streamingly appended into the output buffer, similar to the input + // deserialization as mentioned in previous. + // On error, serialize the error message into bytes using msgpack and return + // `ExecutionResult::Error`. + } + + fn name(&self) -> &'static str { + T::NAME + } +} +``` + +Read `exection_raw` carefully to generate a correct implementation. + +4. Task registration + +We need a macro to register the task functions into a package. The task package should accept a name +and the task objects (converted by the proc macro in step 2, which has the same name as the +user-defined task function). + +Under the hood, the macro should generate a global hashmap that maps the task name (`Task::NAME`) to +a `dyn TaskHandler` for later access. The registration should also generate C-ffi APIs to access +task functions by given task name and inputs. + +In a library, there can be only one task package. + +The C-ffi APIs required: + +(1). `__spider_tdl_package_get_name` which returns a byte view of the package name. +(2). `__spider_tdl_package_execute` which takes in a task name and serialized input bytes, and +returns serialized output bytes. + +Specification for `__spider_tdl_package_execute`: + +(1) Inputs: + +Inputs are given with the following C-ffi type: + +```rust +/// Represents a C `T const*` pointer + `size_t` length as a single ABI-stable value. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CArray<'lifetime, T> { + pointer: *const T, + length: usize, + _lifetime: PhantomData<&'lifetime [T]>, +} + +pub type CCharArray<'lifetime> = CArray<'lifetime, c_char>; + +pub type CByteArray<'lifetime> = CArray<'lifetime, u8>; +``` + +The name is given as `CCharArray`, and the input bytes are given as `CByteArray`. + +The API may convert the input bytes into the corresponding Rust reference, as `&str` and `&[u8]`. + +(2) Outputs: + +Outputs are given with the following C-ffi type: + +```rust +#[repr(C)] +struct TaskExecutionResult { + is_error: bool, + pointer: *mut u8, + length: usize, +} + +impl TaskExecutionResult { + pub fn new(result: ExecutionResult) { + /* + Something like: + let is_error = match result { + ExecutionResult::Outputs(_) => false, + ExecutionResult::Error(_) => true, + }; + let buffer = match result { + ExecutionResult::Outputs(bytes) => bytes, + ExecutionResult::Error(bytes) => bytes, + }; + You may need to implement them properly + */ + let buffer: *mut [u8] = Box::into_raw(buffer); + Self { + is_error, + pointer: buffer as *mut u8, // use https://doc.rust-lang.org/nightly/std/primitive.pointer.html#method.as_mut_ptr-1 when stabilized -__- + length: buffer.len(), + } + } + + // by not using `self`, forces callers to do `OwnedSlice::into_raw(buffer)` instead of `buffer.into_raw()`; + // up to preference, this below follows rust std convention + pub fn into_raw(this: Self) -> Box<[u8]> { + unsafe { + Box::from_raw(std::slice::from_raw_parts_mut(this.pointer, this.length)) + } + } +} + +impl Drop for TaskExecutionResult { + fn drop(&mut self) { + unsafe { + Box::from_raw(std::slice::from_raw_parts_mut(this.pointer, this.length)) + } + } +} +``` + +The API should convert the return properly. + +The C-ffi API may call the task table's method directly, which abstracts the situation where the +task is not found (and thus returned as an error of `ExecutionResult::Error`). + +5. End-to-end flow + +Let's now review the end-to-end flow. In this session, you only need to focus on a small set of this +flow, but I'd like to list it here for you to have a better understanding. + +The overall system has three components: the storage, the execution manager, and the actual +execution process. Each is an individual process. + +Step 1: The execution manager requests a task from the storage, with the task name and its inputs. +Step 2: The execution manager passes the task into the execution process (through OS PIPE). +Step 3: The execution process loads the shared library that contains the task function (identified +by the task package name), and calls the task function with the inputs within the same process +through C-ffi APIs. +Step 4: The execution process retrieves the outputs from the task function, and returns them to the +execution manager (through OS PIPE). +Step 5: Depending on the error or success, the execution manager updates the task status in the +storage. + +Note: +1. The execution results are generated in the task package lib. The ownership transfers from the lib +to the execution process across C-ffi. The execution process dispatches the results to see if it's +an error or not, and only notify the execution manager with the payload. The lifetime of the result +ends inside the execution process. This means the C-ffi `TaskExecutionResult` definition must be +visible in both the execution process and the separately compiled task package lib. +2. The error type, `Error`, will be deserialized inside the execution process. This means the error +must be defined in the task package lib, the execution process, and the execution manager. +3. The serialized output bytes are not deserialized until it reaches the execution manager. +4. The execution manager may need to maintain multiple TDL packages (loaded at runtime), but it +should only be asked to execute one task at a time. + +In this session, you only need to implement the task package lib and the driver code in the +execution process that interacts with the lib. You don't need to worry about the rest, but you +should generate the overall flow diagram and make sure the design doesn't violate with the large +picture. + +--- + +Your first task will be to come up with a formal design doc with the plan for how to implement the +package lib and the driver code inside the execution process. Put the design doc in the +`claude/task-exec-prototyping/design-doc.md` file. + +We should refer the package lib as "TDL package", and the execution process as "task executor". + +We will ask you to implement a prototype later. + +--- + +We want to enforce the first parameter to be a `TaskContext` struct, containing the runtime metadata +of the task execution. It should contain: + +```rust +struct TaskContext { + job_id: JobId, + task_id: TaskId, // Currently defined in `spider-storage`, but needed to be in `spider-core` + task_instance_id: TaskInstanceId, +} +``` + +This struct is constructed in the execution manager process, and passed all the way down to the task +package lib, serialized and deserialized using msgpack. + +Help me update the design to include this new requirement. Make sure that: + +* The proc macro should check if the first parameter is a `TaskContext` struct, and only the first + parameter. +* It is visible in the execution manager, the task executor, and the TDL package lib. + +--- + +Besides this new feature, can you also make sure and clearly document what is the behavior if a task +function contains no other parameters than the first one? This would make the param struct an empty +struct, and how would the serde work in this case? diff --git a/claude/task-exec-prototyping/design-doc.md b/claude/task-exec-prototyping/design-doc.md new file mode 100644 index 00000000..54e73e9a --- /dev/null +++ b/claude/task-exec-prototyping/design-doc.md @@ -0,0 +1,962 @@ +# TDL Package & Task Executor: Design Document + +## Context + +The spider system needs a way for users to write custom task functions in Rust, package them into +a shared library (cdylib), and have the task executor load and invoke these functions at runtime +via C-FFI. This design covers the two components we are prototyping: + +1. **TDL Package** -- the shared library containing user-defined tasks, a `#[task]` proc-macro, + a registration macro, and C-FFI entry points. +2. **Task Executor** -- driver code in the execution process that loads a TDL package via `dlopen` + and calls its C-FFI APIs. + +The executor receives a task name and serialized inputs (wire format), dispatches into the loaded +library, and receives serialized outputs or an error. Deserialization of outputs happens upstream +in the execution manager (out of scope for this prototype). + +--- + +## 1. Crate Organization + +Four new crates, added to the workspace: + +| Crate | Path | Type | Purpose | +|---|---|---|---| +| `spider-tdl` | `components/spider-tdl` | `lib` | Shared types, traits, wire serde, registration macro | +| `spider-tdl-derive` | `components/spider-tdl-derive` | `proc-macro` | `#[task]` attribute macro | +| `spider-executor` | `components/spider-executor` | `lib` | Loads cdylib, calls C-FFI | +| `example-tdl-package` | `examples/example-tdl-package` | `cdylib` | Sample TDL package | + +Dependency graph: + +``` +spider-tdl-derive (proc-macro: syn, quote, proc-macro2) + │ + ▼ + spider-tdl (rmp-serde, serde, thiserror; re-exports spider-tdl-derive) + │ + ├──► spider-executor (libloading, thiserror) + │ + └──► example-tdl-package (cdylib; serde, rmp-serde) +``` + +`spider-tdl` is the single source of truth for all shared types (C-FFI structs, error type, +`Task` trait, `TaskHandler`, wire format). Both the TDL package and the executor depend on it. + +Workspace root `Cargo.toml` change: + +```toml +members = [ + "components/spider-core", + "components/spider-derive", + "components/spider-storage", + "components/spider-tdl", + "components/spider-tdl-derive", + "components/spider-executor", + "examples/example-tdl-package", +] +``` + +--- + +## 2. Module Structure + +### 2.1 `spider-tdl` + +``` +components/spider-tdl/ + Cargo.toml + src/ + lib.rs # Re-exports all public API + tdl_types.rs # Type aliases (int8=i8, List=Vec, etc.) + error.rs # TdlError enum (shared across all components) + task_context.rs # TaskContext struct (runtime metadata, shared across all components) + ffi.rs # CArray, CCharArray, CByteArray, TaskExecutionResult (#[repr(C)]) + wire.rs # Wire format serde (adapted from claude/struct-serde/example) + task.rs # Task trait, TaskHandler trait, TaskHandlerImpl, ExecutionResult + register.rs # register_tasks! macro_rules +``` + +Dependencies: + +```toml +[dependencies] +spider-core = { path = "../spider-core" } +spider-tdl-derive = { path = "../spider-tdl-derive" } +rmp-serde = "1.3.1" +serde = { version = "1.0.228", features = ["derive"] } +thiserror = "2.0.18" +``` + +### 2.2 `spider-tdl-derive` + +``` +components/spider-tdl-derive/ + Cargo.toml + src/ + lib.rs # #[proc_macro_attribute] pub fn task(...) + task_macro.rs # Core code generation logic + validation.rs # Type validation rules +``` + +Dependencies: + +```toml +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.106" +quote = "1.0.45" +syn = { version = "2.0.117", features = ["full"] } +``` + +### 2.3 `spider-executor` + +``` +components/spider-executor/ + Cargo.toml + src/ + lib.rs + loader.rs # TdlPackageLoader: dlopen, symbol lookup, safe wrappers + error.rs # Executor-specific error type +``` + +Dependencies: + +```toml +[dependencies] +spider-tdl = { path = "../spider-tdl" } +libloading = "0.8" +thiserror = "2.0.18" +``` + +### 2.4 `example-tdl-package` + +``` +examples/example-tdl-package/ + Cargo.toml + src/ + lib.rs # User structs, #[task] functions, register_tasks! invocation +``` + +```toml +[lib] +crate-type = ["cdylib"] + +[dependencies] +spider-tdl = { path = "../../components/spider-tdl" } +rmp-serde = "1.3.1" +serde = { version = "1.0.228", features = ["derive"] } +``` + +--- + +## 3. Type Aliases (`tdl_types`) + +```rust +// components/spider-tdl/src/tdl_types.rs + +pub type int8 = i8; +pub type int16 = i16; +pub type int32 = i32; +pub type int64 = i64; +pub type float = f32; +pub type double = f64; +pub type boolean = bool; +pub type Bytes = Vec; +pub type List = Vec; +pub type Map = std::collections::HashMap where K: MapKey; + +// --- Sealed marker trait restricting Map key types --- + +mod private { + pub trait Sealed {} +} + +/// Marker trait for types allowed as `Map` keys. +/// Sealed — users cannot implement this for their own types. +pub trait MapKey: Eq + std::hash::Hash + private::Sealed {} + +impl private::Sealed for i8 {} +impl private::Sealed for i16 {} +impl private::Sealed for i32 {} +impl private::Sealed for i64 {} +impl private::Sealed for Vec {} + +impl MapKey for i8 {} +impl MapKey for i16 {} +impl MapKey for i32 {} +impl MapKey for i64 {} +impl MapKey for Vec {} +``` + +This provides two layers of key-type enforcement: +1. **Type-level** -- `Map` fails to compile (no `MapKey` impl for `String`). +2. **Proc-macro** -- catches it earlier with a clearer error message naming the offending parameter. + +--- + +## 4. Error Type + +```rust +// components/spider-tdl/src/error.rs + +#[derive(Debug, thiserror::Error, serde::Serialize, serde::Deserialize)] +pub enum TdlError { + #[error("task not found: {0}")] + TaskNotFound(String), + + #[error("deserialization error: {0}")] + DeserializationError(String), + + #[error("serialization error: {0}")] + SerializationError(String), + + #[error("execution error: {0}")] + ExecutionError(String), + + #[error("{0}")] + Custom(String), +} +``` + +`TdlError` derives `Serialize + Deserialize` so it can be msgpack-encoded into +`ExecutionResult::Error(Vec)` and decoded on the executor side. + +This is the error type that user task functions return: `fn my_task(...) -> Result`. + +--- + +## 5. `TaskContext` + +Every task function must accept `TaskContext` as its **first** parameter. It carries runtime +metadata about the current task execution, constructed by the execution manager. + +```rust +// components/spider-tdl/src/task_context.rs + +use spider_core::types::id::{JobId, TaskId, TaskInstanceId}; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TaskContext { + pub job_id: JobId, + pub task_id: TaskId, + pub task_instance_id: TaskInstanceId, +} +``` + +**Visibility:** Defined in `spider-tdl`, re-exported publicly. Visible to: +- **Execution manager** -- constructs and msgpack-serializes it. +- **Task executor** -- passes serialized bytes through to the cdylib (opaque). +- **TDL package** -- deserializes it from msgpack and passes to the user's task function. + +**Serialization:** Plain msgpack (not wire format). A single `rmp_serde::to_vec(&ctx)` / +`rmp_serde::from_slice(&bytes)`. This is separate from the task inputs wire stream. + +**Why separate from inputs:** `TaskContext` is runtime metadata owned by the execution manager, +not user-supplied data from the storage layer. It travels a different path -- the execution +manager constructs it and attaches it alongside the input wire bytes. + +**Note on `TaskId`:** `TaskId` is already defined in `spider-core::types::id` (as +`Id`). No relocation from `spider-storage` is needed. + +--- + +## 6. C-FFI Types (`ffi`) + +All `#[repr(C)]` types shared between the TDL package and the executor. + +```rust +// components/spider-tdl/src/ffi.rs + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CArray<'a, T> { + pointer: *const T, + length: usize, + _lifetime: PhantomData<&'a [T]>, +} + +impl<'a, T> CArray<'a, T> { + pub fn from_slice(slice: &'a [T]) -> Self { ... } + /// # Safety + /// The pointer must be valid for `length` elements. + pub unsafe fn as_slice(&self) -> &'a [T] { ... } +} + +pub type CCharArray<'a> = CArray<'a, c_char>; +pub type CByteArray<'a> = CArray<'a, u8>; +``` + +For `CCharArray`, add a convenience: + +```rust +impl<'a> CCharArray<'a> { + pub fn from_str(s: &'a str) -> Self { ... } // cast *const u8 to *const c_char + /// # Safety + /// The pointer must be valid UTF-8. + pub unsafe fn as_str(&self) -> &'a str { ... } +} +``` + +### `TaskExecutionResult` + +```rust +#[repr(C)] +pub struct TaskExecutionResult { + is_error: bool, + pointer: *mut u8, + length: usize, +} +``` + +Construction (inside TDL package, leaks the Vec): + +```rust +impl TaskExecutionResult { + pub fn from_execution_result(result: ExecutionResult) -> Self { + let (is_error, buffer) = match result { + ExecutionResult::Outputs(bytes) => (false, bytes), + ExecutionResult::Error(bytes) => (true, bytes), + }; + let boxed: Box<[u8]> = buffer.into_boxed_slice(); + let len = boxed.len(); + let ptr = Box::into_raw(boxed) as *mut u8; + Self { is_error, pointer: ptr, length: len } + } +} +``` + +Consumption (inside executor, reclaims ownership): + +```rust +impl TaskExecutionResult { + /// Reclaim ownership of the byte buffer. Must only be called once. + /// # Safety + /// The pointer must have been produced by `from_execution_result`. + pub unsafe fn into_result(self) -> Result, Vec> { + let boxed = unsafe { + Box::from_raw(std::slice::from_raw_parts_mut(self.pointer, self.length)) + }; + let vec = boxed.into_vec(); + if self.is_error { Err(vec) } else { Ok(vec) } + // Note: must prevent Drop from double-freeing. + // Use ManuallyDrop or mem::forget on self. + } +} +``` + +Memory safety note: this works because both the cdylib and the executor are in the same process +and share the same Rust global allocator. The allocation is created on one side +(`Box::into_raw`) and freed on the other (`Box::from_raw`). + +--- + +## 6. Serialization Formats + +Implemented in `components/spider-tdl/src/wire.rs`. + +There are two serialization layers: + +| Layer | Format | Purpose | +|-------|--------|---------| +| **Wire format** | Custom length-prefixed framing (u32 LE) | Frames a sequence of opaque payloads into a flat byte stream. Handles element boundaries. | +| **Payload format** | MessagePack (`rmp-serde`) | Encodes each individual value (input arg, output element, or error). Self-describing, compact binary. | + +The wire format never interprets payload bytes -- it reads/writes them as opaque +`[len][data]` chunks. All type-aware serialization is done by msgpack at the payload layer. + +**Wire format:** + +``` +[count: u32 LE] [len₀: u32 LE][payload₀ …] [len₁: u32 LE][payload₁ …] … +``` + +Both inputs and outputs share the same wire format. The count header is `u32` LE, and each +payload is prefixed by its `u32` LE byte length. + +### 6.1 `TaskInputs` — streaming serializer + struct deserializer + +**Serialization (storage layer):** `TaskInputs` is a streaming appender. Each `TaskInput` is +appended one at a time; the count header is deferred and patched on `release()`. + +```rust +let mut inputs = TaskInputs::new(); +inputs.append(TaskInput::ValuePayload(msgpack_bytes_0))?; +inputs.append(TaskInput::ValuePayload(msgpack_bytes_1))?; +let wire: Vec = inputs.release(); +``` + +No intermediate `Vec<&[u8]>` allocation. The internal `WireFrameBuilder` reserves 4 bytes for +the count header upfront and appends each payload's length + data inline. + +**Deserialization (TDL package):** `TaskInputs::deserialize()` zero-copy-deserializes the +wire buffer directly into a struct. Each field positionally consumes one payload: + +```rust +let params: MyParams = TaskInputs::deserialize(&wire)?; +``` + +Internally, a `StreamDeserializer` holds a borrowed `&[u8]` reference to the wire buffer. For +each struct field, it reads the length prefix and yields a `&[u8]` slice pointing directly into +the original buffer (zero-copy). That slice is fed to `rmp_serde::from_read_ref()`, which +deserializes the field value in one step. Total copies per field: one (from the msgpack payload +into the target value). + +``` +wire buffer (borrowed &[u8]): +┌───────┬──────┬───────────┬──────┬───────────┬───┐ +│ count │ len₀ │ payload₀ │ len₁ │ payload₁ │...│ +└───────┴──────┴─────┬─────┴──────┴─────┬─────┴───┘ + │ │ + &[u8] slice &[u8] slice ← zero-copy borrows + │ │ + ▼ ▼ + rmp_serde deser rmp_serde deser ← one deser per field + │ │ + ▼ ▼ + field₀ value field₁ value ← target Params struct +``` + +### 6.2 `TaskOutputs` — streaming serializer + Vec deserializer + +**Serialization (TDL package):** `TaskOutputs` is a streaming appender. Each tuple element is +serialized into msgpack and appended directly into the wire buffer in-place — no intermediate +`Vec` per element. + +```rust +// Generated by #[task] for return type (T0, T1): +fn serialize_return(result: &(T0, T1)) -> Result, TdlError> { + let mut outputs = TaskOutputs::new(); + outputs.append(&result.0)?; + outputs.append(&result.1)?; + Ok(outputs.release()) +} +``` + +`TaskOutputs::append(value: &V)` writes a placeholder `u32` length prefix, +serializes the value via `rmp_serde::encode::write` directly into the buffer, then back-patches +the length. This avoids allocating a temporary `Vec` per element — each element requires +exactly one copy (from the value into the wire buffer). + +**Deserialization (storage layer):** `TaskOutputs::deserialize()` extracts each payload as an +opaque `Vec` without decoding the msgpack contents: + +```rust +let outputs: Vec = TaskOutputs::deserialize(&wire)?; +// Each output is a raw msgpack blob, decoded downstream. +``` + +### 6.3 `WireFrameBuilder` — shared core + +Both `TaskInputs` and `TaskOutputs` are thin wrappers around a private `WireFrameBuilder`: + +- `new()` — allocates a buffer with 4 zero bytes (placeholder count header). +- `append_payload(&[u8])` — writes length prefix + raw bytes (used by `TaskInputs`). +- `append_serialize(&V)` — writes placeholder length, serializes in-place via + `rmp_serde::encode::write`, back-patches the length (used by `TaskOutputs`). +- `release()` — patches the count header and returns the buffer. + +Overflow is checked on every append: payload count and individual payload lengths must fit in +`u32`. Returns `WireError::Overflow` on failure. + +### 6.4 Errors + +**Format:** A single msgpack-encoded `TdlError` value (no wire framing). + +``` +[msgpack-encoded TdlError] +``` + +- No length-prefixed framing -- the entire byte buffer is one msgpack blob. +- `TdlError` derives `Serialize + Deserialize`, so `rmp_serde::to_vec(&err)` / + `rmp_serde::from_slice(&bytes)` is all that is needed. + +**Where serialized:** TDL package (cdylib), inside `TaskHandlerImpl::execute_raw()`, when the +task returns `Err(TdlError)` or deserialization/serialization fails. + +**Where deserialized:** Task executor, after reclaiming the `TaskExecutionResult` buffer, to +determine the failure reason. The executor may also forward the raw error bytes to the execution +manager if needed. + +### 6.5 Summary: what each component sees + +``` + TaskContext Inputs Outputs Errors + ─────────── ────── ─────── ────── +Storage — TaskInputs TaskOutputs — + → append+release ← deserialize + → Vec + +Exec Manager TaskContext raw bytes raw bytes raw bytes + → serialize (passthrough) (passthrough) (passthrough) + (msgpack) + +Task Executor raw bytes raw bytes raw bytes TdlError + (passthrough) (passthrough) (passthrough) ← deserialize + (msgpack) + +TDL Package TaskContext Params struct TaskOutputs TdlError + ← deserialize ← TaskInputs:: → append+release → serialize + (msgpack) deserialize (msgpack) +``` + +The wire error type (`WireError`) is specific to the wire module. It is converted to `TdlError` +at the `TaskHandlerImpl` boundary before crossing the C-FFI edge. + +--- + +## 7. `Task` Trait and `TaskHandler` + +```rust +// components/spider-tdl/src/task.rs + +pub enum ExecutionResult { + Outputs(Vec), + Error(Vec), +} + +pub trait Task { + const NAME: &'static str; + type Params: for<'de> serde::Deserialize<'de>; + type Return; + + fn execute(ctx: TaskContext, params: Self::Params) -> Result; + + /// Serialize the return tuple into wire format bytes. + /// Generated by the #[task] proc-macro. + fn serialize_return(result: &Self::Return) -> Result, TdlError>; +} + +pub trait TaskHandler: Send + Sync { + /// `raw_ctx` is msgpack-encoded `TaskContext`. + /// `raw_args` is wire-format-encoded task inputs. + fn execute_raw(&self, raw_ctx: &[u8], raw_args: &[u8]) -> ExecutionResult; + fn name(&self) -> &'static str; +} + +pub struct TaskHandlerImpl { + _marker: PhantomData, +} + +impl TaskHandlerImpl { + pub fn new() -> Self { + Self { _marker: PhantomData } + } +} + +impl TaskHandler for TaskHandlerImpl { + fn execute_raw(&self, raw_ctx: &[u8], raw_args: &[u8]) -> ExecutionResult { + // 1. Deserialize TaskContext (msgpack) + let ctx: TaskContext = match rmp_serde::from_slice(raw_ctx) { + Ok(c) => c, + Err(e) => { + let err = TdlError::DeserializationError( + format!("failed to deserialize TaskContext: {e}") + ); + return ExecutionResult::Error(rmp_serde::to_vec(&err).unwrap()); + } + }; + + // 2. Deserialize task inputs (wire format) + let params: T::Params = match wire::deserialize_task_inputs(raw_args) { + Ok(p) => p, + Err(e) => { + let err = TdlError::DeserializationError(e.to_string()); + return ExecutionResult::Error(rmp_serde::to_vec(&err).unwrap()); + } + }; + + // 3. Execute task with context + match T::execute(ctx, params) { + Ok(result) => { + // 4. Serialize outputs + match T::serialize_return(&result) { + Ok(bytes) => ExecutionResult::Outputs(bytes), + Err(e) => ExecutionResult::Error(rmp_serde::to_vec(&e).unwrap()), + } + } + Err(e) => ExecutionResult::Error(rmp_serde::to_vec(&e).unwrap()), + } + } + + fn name(&self) -> &'static str { + T::NAME + } +} +``` + +--- + +## 8. `#[task]` Proc-Macro + +### Input + +The first parameter must always be `ctx: TaskContext`. Remaining parameters are the task's +user-supplied inputs, deserialized from the wire format. + +```rust +#[task] +fn my_task(ctx: TaskContext, a: int32, b: MyStruct1) -> Result<(List, int64), TdlError> { + // user body +} +``` + +Or with a custom name: + +```rust +#[task(name = "my_namespace::my_task")] +fn my_task(ctx: TaskContext, a: int32, b: MyStruct1) -> Result<(List, int64), TdlError> { ... } +``` + +### Generated output + +The proc-macro strips `ctx: TaskContext` from the params struct (it is not part of the wire +inputs) and threads it through to the user function separately. + +```rust +/// Marker struct. +pub struct my_task; + +impl my_task { + /// The original user function, renamed. + fn __my_task(ctx: TaskContext, a: int32, b: MyStruct1) -> Result<(List, int64), TdlError> { + // original body + } +} + +/// Params struct for deserialization -- only the wire-format inputs, NOT TaskContext. +#[derive(serde::Deserialize)] +struct __my_task_params { + a: int32, + b: MyStruct1, +} + +impl spider_tdl::Task for my_task { + const NAME: &'static str = "my_task"; // or "my_namespace::my_task" + type Params = __my_task_params; + type Return = (List, int64); + + fn execute(ctx: spider_tdl::TaskContext, params: Self::Params) -> Result { + Self::__my_task(ctx, params.a, params.b) + } + + fn serialize_return(result: &Self::Return) -> Result, spider_tdl::TdlError> { + let mut outputs = spider_tdl::wire::TaskOutputs::new(); + outputs.append(&result.0).map_err(|e| spider_tdl::TdlError::SerializationError(e.to_string()))?; + outputs.append(&result.1).map_err(|e| spider_tdl::TdlError::SerializationError(e.to_string()))?; + Ok(outputs.release()) + } +} +``` + +### No-input tasks (empty params) + +When a task has no user-supplied inputs (only `TaskContext`): + +```rust +#[task] +fn my_context_only_task(ctx: TaskContext) -> Result<(int32,), TdlError> { + Ok((42,)) +} +``` + +The generated params struct is empty: + +```rust +#[derive(serde::Deserialize)] +struct __my_context_only_task_params {} + +impl spider_tdl::Task for my_context_only_task { + type Params = __my_context_only_task_params; + // ... + fn execute(ctx: spider_tdl::TaskContext, params: Self::Params) -> Result { + Self::__my_context_only_task(ctx) // no params to unpack + } +} +``` + +**Wire format for empty params:** The input wire bytes must still be a valid wire frame with +`count = 0`: + +``` +[0x00, 0x00, 0x00, 0x00] ← count: u32 LE = 0, no field entries +``` + +This is 4 bytes total. The `StreamDeserializer` reads `count = 0`, validates +`0 == fields.len()` (the empty struct has 0 fields), and calls `visitor.visit_seq()` which +immediately returns `None` for `next_element_seed` -- producing the empty struct with no +deserialization work. The storage layer must produce this 4-byte wire frame even when +`Vec` is empty. + +### Validation rules (compile-time errors) + +1. **First parameter must be `ctx: TaskContext`** -- the macro checks that the first argument's + type is `TaskContext`. Compile error if missing or if `TaskContext` appears at any other + position. +2. **Return type must be `Result<(...), TdlError>`** -- the Ok type must be a parenthesized + tuple (even for single values: `Result<(int32,), TdlError>`). +3. **Argument types must be supported types** -- primitives, aliases (`int32`, `Bytes`, etc.), + `Vec`/`List`, `HashMap`/`Map`, or user-defined structs (single-segment + identifiers not in the primitive set, assumed valid). This applies to all parameters after + `TaskContext`. +4. **Map key restriction** -- K must be one of `{i8, i16, i32, i64, int8, int16, int32, int64, + Vec, Bytes}`. +5. **No `self` parameter** -- must be a free function. +6. **Tuple element types** follow the same validation as argument types. + +Type alias resolution (e.g., `type MyInt = int32;`) is not possible at the syntactic level. +User-defined struct names pass validation and fail at serde time if incorrect. + +--- + +## 9. `register_tasks!` Macro + +A `macro_rules!` macro in `spider-tdl/src/register.rs`. + +### Usage + +```rust +spider_tdl::register_tasks! { + package_name: "my_package", + tasks: [my_task, another_task] +} +``` + +### Generated code + +```rust +static __SPIDER_TDL_REGISTRY: std::sync::LazyLock< + std::collections::HashMap<&'static str, Box> +> = std::sync::LazyLock::new(|| { + let mut map = std::collections::HashMap::new(); + map.insert( + ::NAME, + Box::new(spider_tdl::TaskHandlerImpl::::new()) as Box, + ); + map.insert( + ::NAME, + Box::new(spider_tdl::TaskHandlerImpl::::new()) as Box, + ); + map +}); + +static __SPIDER_TDL_PACKAGE_NAME: &str = "my_package"; + +#[unsafe(no_mangle)] +pub extern "C" fn __spider_tdl_package_get_name<'a>() -> spider_tdl::ffi::CCharArray<'a> { + spider_tdl::ffi::CCharArray::from_str(__SPIDER_TDL_PACKAGE_NAME) +} + +#[unsafe(no_mangle)] +pub extern "C" fn __spider_tdl_package_execute( + name: spider_tdl::ffi::CCharArray<'_>, + ctx: spider_tdl::ffi::CByteArray<'_>, + inputs: spider_tdl::ffi::CByteArray<'_>, +) -> spider_tdl::ffi::TaskExecutionResult { + let task_name: &str = unsafe { name.as_str() }; + let raw_ctx: &[u8] = unsafe { ctx.as_slice() }; + let raw_inputs: &[u8] = unsafe { inputs.as_slice() }; + + let result = match __SPIDER_TDL_REGISTRY.get(task_name) { + Some(handler) => handler.execute_raw(raw_ctx, raw_inputs), + None => { + let err = spider_tdl::TdlError::TaskNotFound(task_name.to_string()); + spider_tdl::ExecutionResult::Error(rmp_serde::to_vec(&err).unwrap()) + } + }; + + spider_tdl::ffi::TaskExecutionResult::from_execution_result(result) +} +``` + +Uses `LazyLock` (stable since Rust 1.80) -- no extra dependency needed. +Uses `#[unsafe(no_mangle)]` per Rust 2024 edition. + +--- + +## 10. Task Executor Driver + +### `TdlPackageLoader` + +```rust +// components/spider-executor/src/loader.rs + +pub struct TdlPackageLoader { + library: libloading::Library, +} + +type GetNameFn = unsafe extern "C" fn() -> CCharArray<'static>; +type ExecuteFn = unsafe extern "C" fn(CCharArray<'_>, CByteArray<'_>, CByteArray<'_>) -> TaskExecutionResult; + +impl TdlPackageLoader { + pub fn load(path: impl AsRef) -> Result { + let library = unsafe { libloading::Library::new(path.as_ref()) }?; + Ok(Self { library }) + } + + pub fn package_name(&self) -> Result<&str, ExecutorError> { + unsafe { + let func: libloading::Symbol = + self.library.get(b"__spider_tdl_package_get_name")?; + let name_arr = func(); + let bytes = name_arr.as_slice(); + std::str::from_utf8(bytes).map_err(|e| ExecutorError::InvalidUtf8(e)) + } + } + + /// Execute a task by name. + /// - `raw_ctx` is msgpack-encoded `TaskContext`, constructed by the execution manager. + /// - `raw_inputs` is wire-format-encoded task inputs, produced by the storage layer. + /// Both are passed through opaquely. + pub fn execute_task( + &self, + task_name: &str, + raw_ctx: &[u8], + raw_inputs: &[u8], + ) -> Result, TdlError> { + unsafe { + let func: libloading::Symbol = + self.library.get(b"__spider_tdl_package_execute")?; + let name_arr = CCharArray::from_str(task_name); + let ctx_arr = CByteArray::from_slice(raw_ctx); + let input_arr = CByteArray::from_slice(raw_inputs); + let result = func(name_arr, ctx_arr, input_arr); + + // Reclaim ownership and interpret + match result.into_result() { + Ok(output_bytes) => Ok(output_bytes), + Err(error_bytes) => { + let tdl_error: TdlError = rmp_serde::from_slice(&error_bytes)?; + Err(tdl_error) + } + } + } + } +} +``` + +The executor does NOT deserialize the output bytes -- it passes them back to the execution +manager. It only deserializes error bytes to determine the failure reason. + +--- + +## 11. End-to-End Data Flow + +`Vec` only exists in the storage layer. The storage layer serializes it into wire +bytes (`serialize_task_inputs`), and from that point on, only raw bytes flow through the system. + +`TaskContext` is constructed and serialized by the execution manager. It travels separately +from the task inputs. + +``` +Storage Execution Manager Task Executor TDL Package (cdylib) + │ │ │ │ + │ Vec │ │ │ + │ serialize_task_inputs │ │ │ + │ → input wire bytes │ │ │ + │ │ │ │ + │ task_name + │ │ │ + │ input wire bytes │ │ │ + ├──────────────────────►│ │ │ + │ │ │ │ + │ │ construct TaskContext │ │ + │ │ rmp_serde::to_vec() │ │ + │ │ → ctx bytes │ │ + │ │ │ │ + │ │ task_name + │ │ + │ │ ctx bytes + │ │ + │ │ input wire bytes │ │ + │ ├──────────────────────►│ │ + │ │ (via OS pipe) │ │ + │ │ │ │ + │ │ │ dlopen + symbol lookup │ + │ │ ├───────────────────────►│ + │ │ │ __spider_tdl_package_execute + │ │ │ (CCharArray, │ + │ │ │ CByteArray[ctx], │ + │ │ │ CByteArray[inputs]) │ + │ │ │ │ + │ │ │ │ rmp_serde::from_slice() + │ │ │ │ ctx bytes → TaskContext + │ │ │ │ + │ │ │ │ deserialize_task_inputs() + │ │ │ │ wire bytes → Params struct + │ │ │ │ + │ │ │ │ Task::execute(ctx, params) + │ │ │ │ → Result + │ │ │ │ + │ │ │ │ serialize_return() or + │ │ │ │ serialize error + │ │ │ │ → ExecutionResult + │ │ │ │ + │ │ │ TaskExecutionResult │ + │ │ │◄───────────────────────┤ + │ │ │ (repr(C), leaked Vec) │ + │ │ │ │ + │ │ │ into_result() │ + │ │ │ → reclaim Vec │ + │ │ │ │ + │ │ raw output/error bytes│ │ + │ │◄──────────────────────┤ │ + │ │ (via OS pipe) │ │ + │ │ │ │ + │ raw output/error bytes│ │ │ + │◄──────────────────────┤ │ │ + │ (passthru from mgr) │ │ │ + │ │ │ │ + │ deserialize outputs │ │ │ + │ wire bytes → Vec> │ │ + │ (each element is one │ │ │ + │ msgpack-encoded │ │ │ + │ output value) │ │ │ +``` + +Key points: +- `serialize_task_inputs()` is called in the **storage layer**, not the executor. +- `TaskContext` is constructed and msgpack-serialized by the **execution manager**. It is + passed as a separate byte stream alongside the input wire bytes. +- The task executor treats all three byte streams (ctx, inputs, outputs) as opaque passthrough. +- `deserialize_task_inputs()` is called in the **TDL package** (cdylib) only. +- Output bytes flow back opaquely to the **storage layer**, which deserializes the wire frame + into `Vec>` (each element is one msgpack-encoded tuple output value). + +--- + +## 12. Implementation Order + +1. **`spider-tdl` foundation** -- `tdl_types.rs`, `error.rs`, `ffi.rs` +2. **Wire format** -- implement `wire.rs` with `TaskInputs`, `TaskOutputs`, `WireFrameBuilder`, + `StreamDeserializer`, and `WireError` +3. **Task traits** -- `task.rs` with `Task`, `TaskHandler`, `TaskHandlerImpl`, `ExecutionResult` +4. **Registration macro** -- `register.rs` +5. **Proc-macro** -- `spider-tdl-derive` with `task_macro.rs` and `validation.rs` +6. **Executor** -- `spider-executor` with `loader.rs` +7. **Example package** -- `examples/example-tdl-package` +8. **Integration test** -- build example cdylib, load it from executor, run a task end-to-end + +--- + +## 13. Key Files to Reference + +| File | Purpose | +|---|---| +| `claude/struct-serde/example/src/lib.rs` | Wire format serde to adapt into `wire.rs` | +| `components/spider-derive/src/lib.rs` | Proc-macro entry point pattern | +| `components/spider-derive/src/mysql.rs` | Proc-macro codegen pattern with syn/quote | +| `components/spider-core/src/types/io.rs` | `TaskInput`, `TaskOutput` definitions | +| `components/spider-core/src/task.rs` | Error enum pattern, module structure | + +--- + +## 14. Verification Plan + +1. **Unit tests in `spider-tdl`**: wire format round-trip, error serialization round-trip, + `TaskExecutionResult` construction and reclamation. +2. **Proc-macro compile tests**: valid tasks compile; invalid return types, invalid map keys, + self parameters produce compile errors. +3. **Integration test**: build `example-tdl-package` as cdylib, use `TdlPackageLoader` to load + it, call `package_name()`, call `execute_task()` with serialized inputs, verify outputs + deserialize correctly. Also test error paths (unknown task, deserialization failure). diff --git a/claude/task-exec-prototyping/struct-serde.md b/claude/task-exec-prototyping/struct-serde.md new file mode 100644 index 00000000..4d5b48be --- /dev/null +++ b/claude/task-exec-prototyping/struct-serde.md @@ -0,0 +1,216 @@ +# Positional Deserialization: `Vec` Wire Format to Struct + +## Problem + +Given a `Vec` (defined in `components/spider-core/src/types/io.rs`): + +```rust +pub enum TaskInput { + ValuePayload(Vec), // each Vec is a msgpack-encoded field value +} +``` + +We want to: + +1. **Serialize** the `Vec` into a flat byte stream at the source. +2. **Deserialize** that byte stream directly into an arbitrary user-defined struct at the sink, + where each struct field positionally consumes one `TaskInput`'s payload. + +The sink struct only needs standard `#[derive(Deserialize)]` -- no custom derive macros. + +## Two-Layer Design + +There are two distinct serialization layers: + +| Layer | Format | Purpose | +|-------|--------|---------| +| **Wire format** | Custom length-prefixed framing (u32 LE) | Encodes the `Vec` sequence into a flat byte stream. Handles field boundaries. | +| **Payload format** | MessagePack (`rmp-serde`) | Encodes each individual field value inside its `ValuePayload`. Self-describing, compact binary. | + +The wire format only frames the sequence of payloads. It never interprets the payload bytes -- +it writes and reads them as opaque `[len][data]` chunks. All type-aware serialization is done +by msgpack at the payload layer. + +## Wire Format + +```text +[count: u32 LE] [len₀: u32 LE][data₀ …] [len₁: u32 LE][data₁ …] … +``` + +- `count` -- number of fields (= number of `TaskInput` elements). +- Each field is a `[len][data]` pair. `data` is the raw `Vec` from `ValuePayload`, written + verbatim (it is already msgpack-encoded by whatever produced the `TaskInput`). +- Fixed-width u32 LE lengths. Faster to parse than varints; 4 bytes of overhead per field is + negligible vs. actual payload. + +## Data Flow + +```text +Source Sink +------ ---- +Vec &[u8] (the wire buffer) + │ │ + ▼ ▼ +serialize_task_inputs() deserialize_task_inputs::() + │ │ + ▼ ▼ +flat byte stream ───── network/disk ─────► StreamDeserializer + │ + ┌─────┴──────────────┐ + │ For each field: │ + │ read [len][data] │ + │ data is &'de [u8] │ ← zero-copy slice + │ into wire buffer │ + │ rmp_serde:: │ + │ Deserializer │ + │ ::from_read_ref()│ ← one deser per field + └────────────────────┘ + │ + ▼ + T (the struct) +``` + +**Key property:** No intermediate `Vec` is constructed at the sink. Each field's bytes +are a borrowed `&[u8]` slice into the original wire buffer, and rmp_serde deserializes from that +slice directly. One deserialization step per field, one memory copy per field. + +## Implementation + +### Serialization (source side) + +Straightforward -- iterate and write length-prefixed chunks: + +```rust +pub fn serialize_task_inputs(inputs: &[TaskInput]) -> Vec { + let total: usize = 4 + inputs.iter().map(|i| { + let TaskInput::ValuePayload(b) = i; + 4 + b.len() + }).sum::(); + + let mut buf = Vec::with_capacity(total); + buf.extend_from_slice(&(inputs.len() as u32).to_le_bytes()); + for input in inputs { + let TaskInput::ValuePayload(bytes) = input; + buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + buf.extend_from_slice(bytes); + } + buf +} +``` + +### Deserialization (sink side) -- custom serde Deserializer + +The entry point: + +```rust +pub fn deserialize_task_inputs<'de, T: Deserialize<'de>>(data: &'de [u8]) -> Result { + let mut de = StreamDeserializer::new(data)?; + T::deserialize(&mut de) +} +``` + +#### `StreamDeserializer` + +Holds the wire buffer and a cursor. Only implements `deserialize_struct`; all other +`deserialize_*` methods forward to an error via `forward_to_deserialize_any!`. + +```rust +struct StreamDeserializer<'de> { + data: &'de [u8], + pos: usize, + count: usize, // from the wire header + current_field: usize, + type_name: &'static str, + field_names: &'static [&'static str], +} +``` + +`deserialize_struct` validates `count == fields.len()` (producing a `LengthMismatch` error on +mismatch), then calls `visitor.visit_seq(FieldSeqAccess { ... })`. + +#### `FieldSeqAccess` (the core) + +Implements `serde::de::SeqAccess`. Each call to `next_element_seed`: + +1. Reads the next `[len: u32 LE][data: &'de [u8]]` from the buffer. +2. Creates a `rmp_serde::Deserializer::from_read_ref(data)`. +3. Calls `seed.deserialize(&mut rmp_de)` -- serde routes this to rmp_serde, which + deserializes the field value from the borrowed slice. +4. Maps `rmp_serde::decode::Error` to our `Error::FieldDeserialization { type_name, field, position, .. }`. + +```rust +impl<'a, 'de> SeqAccess<'de> for FieldSeqAccess<'a, 'de> { + type Error = Error; + + fn next_element_seed>( + &mut self, seed: T, + ) -> Result, Error> { + if self.de.current_field >= self.de.count { + return Ok(None); + } + let idx = self.de.current_field; + let field_name = self.de.field_names.get(idx).copied().unwrap_or(""); + let type_name = self.de.type_name; + + let bytes = self.de.next_field_bytes()?; // &'de [u8] into wire buffer + self.de.current_field += 1; + + let mut rmp_de = rmp_serde::Deserializer::from_read_ref(bytes); + seed.deserialize(&mut rmp_de) + .map(Some) + .map_err(|e| Error::FieldDeserialization { + type_name, field: field_name, position: idx, source: e, + }) + } +} +``` + +**Why `seed.deserialize(&mut rmp_de)` works across error types:** +`DeserializeSeed::deserialize` is generic over `D: Deserializer<'de>` and returns +`Result`. Here `D` is `&mut rmp_serde::Deserializer<...>`, so it returns +`Result`. We `.map_err()` that into our `Error` at the +`SeqAccess` boundary. The error types don't need to match inside `seed.deserialize` -- they +only need to match the `SeqAccess::Error` associated type on the way out. + +### Error Type + +```rust +pub enum Error { + LengthMismatch { type_name, expected, actual }, + FieldDeserialization { type_name, field, position, source: rmp_serde::decode::Error }, + InvalidFormat(&'static str), // wire buffer corruption + Custom(String), // required by serde::de::Error +} +``` + +Must implement `serde::de::Error` (for the `custom()` constructor) and `std::fmt::Display`. + +## Usage + +```rust +use serde::Deserialize; + +// Sink defines its own struct -- only needs standard serde. +#[derive(Deserialize)] +struct Job { + name: String, // consumes inputs[0] + priority: u32, // consumes inputs[1] + payload: Vec, // consumes inputs[2] +} + +// Source side: each TaskInput::ValuePayload contains rmp_serde::to_vec()-encoded bytes. +let wire: Vec = serialize_task_inputs(&task_inputs); + +// Sink side (e.g. after receiving `wire` over the network): +let job: Job = deserialize_task_inputs(&wire)?; +``` + +Fields with complex types (nested structs, enums, `Option`, etc.) work automatically -- each +field's `ValuePayload` is a self-contained msgpack blob, and rmp_serde handles the inner +structure. Because msgpack is self-describing, the payload layer is more resilient to type +mismatches than non-self-describing formats. + +## Working Example + +A compilable and tested example crate lives at `claude/struct-serde/example/`. +Dependencies: `rmp-serde = "1"`, `serde = "1"` (with `derive` feature). diff --git a/components/spider-task-executor/Cargo.toml b/components/spider-task-executor/Cargo.toml new file mode 100644 index 00000000..a2525fb4 --- /dev/null +++ b/components/spider-task-executor/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "spider-task-executor" +version = "0.1.0" +edition = "2024" + +[lib] +name = "spider_task_executor" +path = "src/lib.rs" + +[dependencies] +libloading = "0.9.0" +rmp-serde = "1.3.1" +spider-tdl = { path = "../spider-tdl" } +thiserror = "2.0.18" diff --git a/components/spider-task-executor/src/error.rs b/components/spider-task-executor/src/error.rs new file mode 100644 index 00000000..827a0790 --- /dev/null +++ b/components/spider-task-executor/src/error.rs @@ -0,0 +1,23 @@ +//! Error types for the task executor. + +use spider_tdl::TdlError; + +/// Errors produced by the task executor when loading or invoking a TDL package. +#[derive(Debug, thiserror::Error)] +pub enum ExecutorError { + /// The shared library could not be loaded via `dlopen`. + #[error("failed to load TDL package library: {0}")] + LibraryLoad(#[from] libloading::Error), + + /// A package with the same name is already loaded. + #[error("duplicate package name: {0}")] + DuplicatePackage(String), + + /// The task returned an error, deserialized from the msgpack error payload. + #[error("task execution failed: {0}")] + TaskError(#[from] TdlError), + + /// The error payload returned by the task could not be deserialized. + #[error("failed to deserialize error payload: {0}")] + ErrorPayloadDeserialization(#[from] rmp_serde::decode::Error), +} diff --git a/components/spider-task-executor/src/lib.rs b/components/spider-task-executor/src/lib.rs new file mode 100644 index 00000000..0da1d29e --- /dev/null +++ b/components/spider-task-executor/src/lib.rs @@ -0,0 +1,5 @@ +pub mod error; +pub mod loader; + +pub use error::ExecutorError; +pub use loader::{TdlPackage, TdlPackageLoader}; diff --git a/components/spider-task-executor/src/loader.rs b/components/spider-task-executor/src/loader.rs new file mode 100644 index 00000000..de968a01 --- /dev/null +++ b/components/spider-task-executor/src/loader.rs @@ -0,0 +1,167 @@ +//! TDL package loader. +//! +//! [`TdlPackageLoader`] manages a registry of loaded TDL packages (cdylibs), indexed by package +//! name. Each package is loaded via `dlopen` and its name is discovered by calling the +//! `__spider_tdl_package_get_name` C-FFI entry point. Callers look up a package by name via +//! [`TdlPackageLoader::get`] and invoke tasks directly on the returned [`TdlPackage`] reference. + +use std::{collections::HashMap, path::Path}; + +use spider_tdl::{ + TdlError, + ffi::{CByteArray, CCharArray, TaskExecutionResult}, +}; + +use crate::error::ExecutorError; + +type GetNameFn = unsafe extern "C" fn() -> CCharArray<'static>; +type ExecuteFn = + unsafe extern "C" fn(CCharArray<'_>, CByteArray<'_>, CByteArray<'_>) -> TaskExecutionResult; + +/// A single loaded TDL package backed by a `dlopen`-ed shared library. +/// +/// Obtained from [`TdlPackageLoader::get`]. Provides [`Self::execute_task`] to invoke a task +/// inside the package by name. +pub struct TdlPackage { + library: libloading::Library, +} + +impl TdlPackage { + /// Returns the package name declared by the loaded library. + /// + /// # Returns + /// + /// The package name as a `&str` on success. + /// + /// # Errors + /// + /// Returns [`ExecutorError::LibraryLoad`] if the `__spider_tdl_package_get_name` symbol is + /// not found. + pub fn get_name(&self) -> Result<&str, ExecutorError> { + // SAFETY: the library is a valid TDL package produced by `register_tasks!`, so the + // symbol exists and returns a CCharArray pointing to a static string inside the library. + unsafe { + let func: libloading::Symbol = + self.library.get(b"__spider_tdl_package_get_name")?; + let name_arr = func(); + // SAFETY: the package name is a Rust `&'static str` produced by `register_tasks!`, + // so it is guaranteed to be valid UTF-8. + Ok(name_arr.as_str()) + } + } + + /// Executes a task by name with raw serialized context and inputs. + /// + /// Both `raw_ctx` (msgpack-encoded `TaskContext`) and `raw_inputs` (wire-format task inputs) + /// are passed through opaquely to the C-FFI entry point. The executor does not interpret + /// their contents. + /// + /// # Returns + /// + /// The wire-format output bytes on success (opaque to the executor, decoded by the storage + /// layer). + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`ExecutorError::LibraryLoad`] if the `__spider_tdl_package_execute` symbol is not found. + /// * [`ExecutorError::TaskError`] if the task returned a [`TdlError`]. + /// * [`ExecutorError::ErrorPayloadDeserialization`] if the error payload cannot be decoded. + pub fn execute_task( + &self, + task_name: &str, + raw_ctx: &[u8], + raw_inputs: &[u8], + ) -> Result, ExecutorError> { + // SAFETY: the library is a valid TDL package produced by `register_tasks!`. The + // CCharArray/CByteArray values borrow from the caller's stack and remain valid for the + // duration of the synchronous FFI call. The returned TaskExecutionResult owns a buffer + // allocated by the same global allocator. + unsafe { + let func: libloading::Symbol = + self.library.get(b"__spider_tdl_package_execute")?; + let name_arr = CCharArray::from_str(task_name); + let ctx_arr = CByteArray::from_slice(raw_ctx); + let input_arr = CByteArray::from_slice(raw_inputs); + let result = func(name_arr, ctx_arr, input_arr); + + match result.into_result() { + Ok(output_bytes) => Ok(output_bytes), + Err(error_bytes) => { + let tdl_error: TdlError = rmp_serde::from_slice(&error_bytes)?; + Err(ExecutorError::TaskError(tdl_error)) + } + } + } + } +} + +/// Registry of loaded TDL packages, keyed by package name. +/// +/// Each package is loaded from a cdylib at runtime. The loader discovers the package name by +/// calling the library's `__spider_tdl_package_get_name` entry point and rejects duplicates. +/// Callers look up a package by name via [`Self::get`] and invoke tasks on the returned +/// [`TdlPackage`] reference. +pub struct TdlPackageLoader { + packages: HashMap, +} + +impl Default for TdlPackageLoader { + fn default() -> Self { + Self::new() + } +} + +impl TdlPackageLoader { + /// Creates an empty loader with no packages. + #[must_use] + pub fn new() -> Self { + Self { + packages: HashMap::new(), + } + } + + /// Loads a TDL package from the shared library at `path` and registers it by its declared + /// package name. + /// + /// The package name is discovered by calling `__spider_tdl_package_get_name` inside the + /// loaded library. + /// + /// # Returns + /// + /// The package name on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`ExecutorError::LibraryLoad`] if `dlopen` fails or the name symbol is missing. + /// * [`ExecutorError::DuplicatePackage`] if a package with the same name is already loaded. + /// + /// # Panics + /// + /// Panics if the just-inserted package cannot be found in the internal map -- this indicates + /// a logic error and cannot occur in practice. + pub fn load(&mut self, path: &Path) -> Result<&str, ExecutorError> { + // SAFETY: loading a shared library runs its init routines, but the safety contract is + // between the deployment environment and the library. + let library = unsafe { libloading::Library::new(path) }?; + let package = TdlPackage { library }; + + let name = package.get_name()?.to_owned(); + if self.packages.contains_key(&name) { + return Err(ExecutorError::DuplicatePackage(name)); + } + self.packages.insert(name.clone(), package); + + Ok(self.packages.get_key_value(&name).expect("just inserted").0) + } + + /// Returns a reference to the loaded package with the given name, or `None` if no such + /// package is loaded. + #[must_use] + pub fn get(&self, package_name: &str) -> Option<&TdlPackage> { + self.packages.get(package_name) + } +} diff --git a/components/spider-tdl-derive/Cargo.toml b/components/spider-tdl-derive/Cargo.toml new file mode 100644 index 00000000..39a6452c --- /dev/null +++ b/components/spider-tdl-derive/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "spider-tdl-derive" +version = "0.1.0" +edition = "2024" + +[lib] +name = "spider_tdl_derive" +path = "src/lib.rs" +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.106" +quote = "1.0.45" +syn = { version = "2.0.117", features = ["full"] } diff --git a/components/spider-tdl-derive/src/lib.rs b/components/spider-tdl-derive/src/lib.rs new file mode 100644 index 00000000..99b1f4e3 --- /dev/null +++ b/components/spider-tdl-derive/src/lib.rs @@ -0,0 +1,32 @@ +mod task_macro; + +use proc_macro::TokenStream; +use syn::{ItemFn, parse_macro_input}; +use task_macro::TaskAttr; + +/// Attribute macro that transforms a task function into a marker struct with a [`spider_tdl::Task`] +/// implementation. +/// +/// # Usage +/// +/// ```ignore +/// #[task] +/// fn my_task(ctx: TaskContext, a: int32, b: int32) -> Result<(int32,), TdlError> { +/// Ok((a + b,)) +/// } +/// ``` +/// +/// An optional `name` argument overrides the registered task name: +/// +/// ```ignore +/// #[task(name = "my_namespace::my_task")] +/// fn my_task(ctx: TaskContext) -> Result<(int32,), TdlError> { ... } +/// ``` +#[proc_macro_attribute] +pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr = parse_macro_input!(attr as TaskAttr); + let item = parse_macro_input!(item as ItemFn); + task_macro::expand(&attr, &item) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/components/spider-tdl-derive/src/task_macro.rs b/components/spider-tdl-derive/src/task_macro.rs new file mode 100644 index 00000000..54aae227 --- /dev/null +++ b/components/spider-tdl-derive/src/task_macro.rs @@ -0,0 +1,450 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{ + FnArg, + GenericArgument, + ItemFn, + LitStr, + Pat, + PathArguments, + ReturnType, + Token, + Type, + parse::{Parse, ParseStream}, +}; + +/// Parsed representation of the `#[task(...)]` attribute arguments. +/// +/// Supports an optional `name = "..."` argument. When omitted, the function name is used as the +/// task name. +pub struct TaskAttr { + name: Option, +} + +impl Parse for TaskAttr { + fn parse(input: ParseStream) -> syn::Result { + if input.is_empty() { + return Ok(Self { name: None }); + } + + let ident: syn::Ident = input.parse()?; + if ident != "name" { + return Err(syn::Error::new_spanned(&ident, "expected `name = \"...\"`")); + } + input.parse::()?; + let name: LitStr = input.parse()?; + Ok(Self { name: Some(name) }) + } +} + +/// Expands a `#[task]` annotated function into a marker struct, params struct, and `Task` trait +/// implementation. +pub fn expand(attr: &TaskAttr, func: &ItemFn) -> syn::Result { + validate_no_self(func)?; + validate_has_parameters(func)?; + + let func_name = &func.sig.ident; + let private_method_name = format_ident!("__{func_name}"); + let params_struct_name = format_ident!("__{func_name}_params"); + + let task_name_str = attr + .name + .as_ref() + .map_or_else(|| func_name.to_string(), LitStr::value); + + let first_param = func + .sig + .inputs + .first() + .expect("validated that function has at least one parameter"); + validate_first_param_is_task_context(first_param)?; + + let (return_type_tokens, needs_return_wrapping) = extract_return_tuple(&func.sig.output)?; + + let task_params: Vec<_> = func.sig.inputs.iter().skip(1).collect(); + + let param_fields: Vec = task_params + .iter() + .map(|arg| { + let FnArg::Typed(pat_type) = arg else { + unreachable!("self parameters are rejected by validation"); + }; + let pat = &pat_type.pat; + let ty = &pat_type.ty; + quote! { #pat: #ty } + }) + .collect(); + + let param_field_names: Vec<&Box> = task_params + .iter() + .map(|arg| { + let FnArg::Typed(pat_type) = arg else { + unreachable!("self parameters are rejected by validation"); + }; + &pat_type.pat + }) + .collect(); + + let original_params = &func.sig.inputs; + let original_output = &func.sig.output; + let original_body = &func.block; + let vis = &func.vis; + + let execute_call = if param_field_names.is_empty() { + quote! { Self::#private_method_name(ctx) } + } else { + let field_accesses = param_field_names.iter().map(|name| { + quote! { params.#name } + }); + quote! { Self::#private_method_name(ctx, #(#field_accesses),*) } + }; + + let params_arg = if param_field_names.is_empty() { + quote! { _params } + } else { + quote! { params } + }; + + let private_method = if needs_return_wrapping { + let ReturnType::Type(_, original_return_type) = original_output else { + unreachable!("validated that function has a return type"); + }; + let wrapped_return = quote! { + Result<#return_type_tokens, spider_tdl::TdlError> + }; + quote! { + #[allow(clippy::redundant_closure_call)] + fn #private_method_name(#original_params) -> #wrapped_return { + (|| -> #original_return_type #original_body)().map(|__v| (__v,)) + } + } + } else { + quote! { + fn #private_method_name(#original_params) #original_output + #original_body + } + }; + + let expanded = quote! { + #[allow(non_camel_case_types)] + #vis struct #func_name; + + impl #func_name { + #private_method + } + + #[derive(serde::Deserialize)] + struct #params_struct_name { + #(#param_fields,)* + } + + impl spider_tdl::Task for #func_name { + const NAME: &'static str = #task_name_str; + type Params = #params_struct_name; + type Return = #return_type_tokens; + + fn execute( + ctx: spider_tdl::TaskContext, + #params_arg: Self::Params, + ) -> Result { + #execute_call + } + } + }; + + Ok(expanded) +} + +fn validate_no_self(func: &ItemFn) -> syn::Result<()> { + for arg in &func.sig.inputs { + if let FnArg::Receiver(receiver) = arg { + return Err(syn::Error::new_spanned( + receiver, + "task functions must not have a `self` parameter", + )); + } + } + Ok(()) +} + +fn validate_has_parameters(func: &ItemFn) -> syn::Result<()> { + if func.sig.inputs.is_empty() { + return Err(syn::Error::new_spanned( + &func.sig, + "task functions must have at least one parameter (TaskContext)", + )); + } + Ok(()) +} + +fn validate_first_param_is_task_context(param: &FnArg) -> syn::Result<()> { + let FnArg::Typed(pat_type) = param else { + return Err(syn::Error::new_spanned( + param, + "first parameter must be `TaskContext`, not `self`", + )); + }; + + let Type::Path(type_path) = pat_type.ty.as_ref() else { + return Err(syn::Error::new_spanned( + &pat_type.ty, + "first parameter must have type `TaskContext`", + )); + }; + + let last_segment = type_path + .path + .segments + .last() + .expect("type path should have at least one segment"); + + if last_segment.ident != "TaskContext" { + return Err(syn::Error::new_spanned( + &pat_type.ty, + "first parameter must have type `TaskContext`", + )); + } + + Ok(()) +} + +/// Returns `(return_type_tokens, needs_wrapping)` where `needs_wrapping` is `true` when the +/// user wrote a bare type (e.g., `int32`) that was auto-wrapped into `(int32,)`. +fn extract_return_tuple(output: &ReturnType) -> syn::Result<(TokenStream, bool)> { + let ReturnType::Type(_, return_type) = output else { + return Err(syn::Error::new_spanned( + output, + "task functions must return `Result<(T, ...), TdlError>`", + )); + }; + + let Type::Path(type_path) = return_type.as_ref() else { + return Err(syn::Error::new_spanned( + return_type, + "task functions must return `Result<(T, ...), TdlError>`", + )); + }; + + let last_segment = type_path + .path + .segments + .last() + .expect("return type path should have at least one segment"); + + if last_segment.ident != "Result" { + return Err(syn::Error::new_spanned( + return_type, + "task functions must return `Result<(T, ...), TdlError>`", + )); + } + + let PathArguments::AngleBracketed(angle_args) = &last_segment.arguments else { + return Err(syn::Error::new_spanned( + &last_segment.arguments, + "expected generic arguments on `Result`", + )); + }; + + let first_arg = angle_args.args.first().ok_or_else(|| { + syn::Error::new_spanned( + angle_args, + "expected at least one generic argument on `Result`", + ) + })?; + + let GenericArgument::Type(ok_type) = first_arg else { + return Err(syn::Error::new_spanned( + first_arg, + "expected a type as the first generic argument of `Result`", + )); + }; + + if let Type::Tuple(tuple_type) = ok_type { + Ok((quote! { #tuple_type }, false)) + } else { + Ok((quote! { (#ok_type,) }, true)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Parses `attr_str` as task attributes, `func_str` as a function, expands, and returns the + /// normalized token string. + fn expand_to_string(attr_str: &str, func_str: &str) -> String { + let attr: TaskAttr = syn::parse_str(attr_str).expect("failed to parse task attribute"); + let func: ItemFn = syn::parse_str(func_str).expect("failed to parse function"); + let expanded = expand(&attr, &func).expect("macro expansion failed"); + expanded.to_string() + } + + /// Normalizes a `TokenStream` to a comparable string. + fn normalize(tokens: &TokenStream) -> String { + tokens.to_string() + } + + #[test] + fn expand_task_with_tuple_return() { + let actual = expand_to_string( + "", + r" + fn add(ctx: TaskContext, a: int32, b: int32) -> Result<(int32, int32), TdlError> { + Ok((a + b, a - b)) + } + ", + ); + + let expected = normalize("e! { + #[allow(non_camel_case_types)] + struct add; + + impl add { + fn __add(ctx: TaskContext, a: int32, b: int32) -> Result<(int32, int32), TdlError> { + Ok((a + b, a - b)) + } + } + + #[derive(serde::Deserialize)] + struct __add_params { + a: int32, + b: int32, + } + + impl spider_tdl::Task for add { + const NAME: &'static str = "add"; + type Params = __add_params; + type Return = (int32, int32); + + fn execute( + ctx: spider_tdl::TaskContext, + params: Self::Params, + ) -> Result { + Self::__add(ctx, params.a, params.b) + } + } + }); + + assert_eq!(actual, expected); + } + + #[test] + fn expand_task_with_custom_name() { + let actual = expand_to_string( + r#"name = "my_ns::my_task""#, + r" + fn my_task(ctx: TaskContext, x: int64) -> Result<(int64,), TdlError> { + Ok((x,)) + } + ", + ); + + assert!(actual.contains(r#"const NAME : & 'static str = "my_ns::my_task""#)); + } + + #[test] + fn expand_task_empty_params() { + let actual = expand_to_string( + "", + r" + fn noop(ctx: TaskContext) -> Result<(int32,), TdlError> { + Ok((42,)) + } + ", + ); + + let expected = normalize("e! { + #[allow(non_camel_case_types)] + struct noop; + + impl noop { + fn __noop(ctx: TaskContext) -> Result<(int32,), TdlError> { + Ok((42,)) + } + } + + #[derive(serde::Deserialize)] + struct __noop_params {} + + impl spider_tdl::Task for noop { + const NAME: &'static str = "noop"; + type Params = __noop_params; + type Return = (int32,); + + fn execute( + ctx: spider_tdl::TaskContext, + _params: Self::Params, + ) -> Result { + Self::__noop(ctx) + } + } + }); + + assert_eq!(actual, expected); + } + + #[test] + fn reject_missing_task_context() { + let attr: TaskAttr = syn::parse_str("").expect("failed to parse attribute"); + let func: ItemFn = + syn::parse_str("fn bad(a: int32) -> Result<(int32,), TdlError> { Ok((a,)) }") + .expect("failed to parse function"); + + let err = expand(&attr, &func).expect_err("expected error for missing TaskContext"); + assert!(err.to_string().contains("TaskContext")); + } + + #[test] + fn auto_wrap_single_value_return() { + let actual = expand_to_string( + "", + r" + fn single(ctx: TaskContext, x: int32) -> Result { + Ok(x) + } + ", + ); + + let expected = normalize("e! { + #[allow(non_camel_case_types)] + struct single; + + impl single { + #[allow(clippy::redundant_closure_call)] + fn __single(ctx: TaskContext, x: int32) -> Result<(int32,), spider_tdl::TdlError> { + (|| -> Result { Ok(x) })().map(|__v| (__v,)) + } + } + + #[derive(serde::Deserialize)] + struct __single_params { + x: int32, + } + + impl spider_tdl::Task for single { + const NAME: &'static str = "single"; + type Params = __single_params; + type Return = (int32,); + + fn execute( + ctx: spider_tdl::TaskContext, + params: Self::Params, + ) -> Result { + Self::__single(ctx, params.x) + } + } + }); + + assert_eq!(actual, expected); + } + + #[test] + fn reject_no_parameters() { + let attr: TaskAttr = syn::parse_str("").expect("failed to parse attribute"); + let func: ItemFn = syn::parse_str("fn bad() -> Result<(int32,), TdlError> { Ok((42,)) }") + .expect("failed to parse function"); + + let err = expand(&attr, &func).expect_err("expected error for no parameters"); + assert!(err.to_string().contains("at least one parameter")); + } +} diff --git a/components/spider-tdl/Cargo.toml b/components/spider-tdl/Cargo.toml new file mode 100644 index 00000000..cf5bfefe --- /dev/null +++ b/components/spider-tdl/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "spider-tdl" +version = "0.1.0" +edition = "2024" + +[lib] +name = "spider_tdl" +path = "src/lib.rs" + +[dependencies] +rmp-serde = "1.3.1" +serde = { version = "1.0.228", features = ["derive"] } +spider-core = { path = "../spider-core" } +spider-tdl-derive = { path = "../spider-tdl-derive" } +thiserror = "2.0.18" + +[dev-dependencies] +anyhow = "1.0.98" diff --git a/components/spider-tdl/src/error.rs b/components/spider-tdl/src/error.rs new file mode 100644 index 00000000..6b202921 --- /dev/null +++ b/components/spider-tdl/src/error.rs @@ -0,0 +1,58 @@ +//! Error type returned from user-authored TDL tasks. +//! +//! [`TdlError`] crosses the C-FFI boundary as a msgpack-encoded payload inside +//! `ExecutionResult::Error`, so it derives both `serde::Serialize` and `serde::Deserialize`. + +/// Errors produced while deserializing inputs, executing a user task, or serializing outputs. +/// +/// User task functions return `Result`. The `TaskHandlerImpl` wrapper additionally +/// produces `TdlError` values for framing failures on either side of the wire. +#[derive(Debug, thiserror::Error, serde::Serialize, serde::Deserialize)] +pub enum TdlError { + #[error("task not found: {0}")] + TaskNotFound(String), + + #[error("deserialization error: {0}")] + DeserializationError(String), + + #[error("serialization error: {0}")] + SerializationError(String), + + #[error("execution error: {0}")] + ExecutionError(String), + + #[error("{0}")] + Custom(String), +} + +#[cfg(test)] +mod tests { + use super::TdlError; + + #[test] + fn round_trip_task_not_found() -> anyhow::Result<()> { + let original = TdlError::TaskNotFound("my_task".to_owned()); + let encoded = rmp_serde::to_vec(&original)?; + let decoded: TdlError = rmp_serde::from_slice(&encoded)?; + assert!(matches!(decoded, TdlError::TaskNotFound(ref name) if name == "my_task")); + Ok(()) + } + + #[test] + fn round_trip_all_variants() -> anyhow::Result<()> { + let cases = [ + TdlError::TaskNotFound("t".to_owned()), + TdlError::DeserializationError("d".to_owned()), + TdlError::SerializationError("s".to_owned()), + TdlError::ExecutionError("e".to_owned()), + TdlError::Custom("c".to_owned()), + ]; + for original in cases { + let original_display = original.to_string(); + let encoded = rmp_serde::to_vec(&original)?; + let decoded: TdlError = rmp_serde::from_slice(&encoded)?; + assert_eq!(decoded.to_string(), original_display); + } + Ok(()) + } +} diff --git a/components/spider-tdl/src/ffi.rs b/components/spider-tdl/src/ffi.rs new file mode 100644 index 00000000..635ef04e --- /dev/null +++ b/components/spider-tdl/src/ffi.rs @@ -0,0 +1,280 @@ +//! `#[repr(C)]` types shared across the TDL package / task executor C-FFI boundary. +//! +//! Both sides of the boundary live in the same process and share the same Rust global allocator, +//! so buffers allocated on one side can be reclaimed on the other via `Box::into_raw` / +//! `Box::from_raw`. These types are intentionally thin: they carry pointers and lengths only. + +use std::{ffi::c_char, fmt, marker::PhantomData, mem::ManuallyDrop}; + +/// Borrowed, C-ABI-compatible view of a contiguous slice `&'borrow_lifetime [ElementType]`. +/// +/// The lifetime parameter is tracked via [`PhantomData`] so that a [`CArray`] cannot outlive the +/// slice it was constructed from when it stays on the Rust side. Once the value is passed across +/// the C-FFI boundary, the lifetime is erased and safety falls to the caller. +#[repr(C)] +pub struct CArray<'borrow_lifetime, ElementType> { + pointer: *const ElementType, + length: usize, + _lifetime: PhantomData<&'borrow_lifetime [ElementType]>, +} + +// Manual `Copy`/`Clone` impls avoid the auto-derived `ElementType: Copy` / `ElementType: Clone` +// bounds: a borrowed pointer/length pair is always trivially copyable regardless of the element +// type. +impl Copy for CArray<'_, ElementType> {} + +impl Clone for CArray<'_, ElementType> { + fn clone(&self) -> Self { + *self + } +} + +impl fmt::Debug for CArray<'_, ElementType> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CArray") + .field("pointer", &self.pointer) + .field("length", &self.length) + .finish() + } +} + +impl<'borrow_lifetime, ElementType> CArray<'borrow_lifetime, ElementType> { + /// Borrows `slice` as a C-ABI array view. + /// + /// The returned [`CArray`] is tied to the lifetime of `slice`; the pointer remains valid as + /// long as the original slice is not moved or dropped. + pub const fn from_slice(slice: &'borrow_lifetime [ElementType]) -> Self { + Self { + pointer: slice.as_ptr(), + length: slice.len(), + _lifetime: PhantomData, + } + } + + /// Returns the number of elements in the borrowed view. + #[must_use] + pub const fn len(&self) -> usize { + self.length + } + + /// Returns `true` if the view contains no elements. + #[must_use] + pub const fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Reconstructs a Rust slice from the raw pointer and length. + /// + /// # Returns + /// + /// A slice of `length` elements starting at `pointer`. + /// + /// # Safety + /// + /// The caller must guarantee that: + /// + /// * `pointer` points to a single, contiguous allocation of at least `length` elements of + /// `ElementType`, properly initialized. + /// * The memory remains valid and immutable for the returned lifetime. + /// * `length * size_of::()` does not exceed `isize::MAX`. + #[must_use] + pub const unsafe fn as_slice(&self) -> &'borrow_lifetime [ElementType] { + // SAFETY: the caller upholds the invariants documented above. + unsafe { std::slice::from_raw_parts(self.pointer, self.length) } + } +} + +/// Borrowed view of a UTF-8 string as a `char`-typed C array. +pub type CCharArray<'borrow_lifetime> = CArray<'borrow_lifetime, c_char>; + +/// Borrowed view of a raw byte buffer. +pub type CByteArray<'borrow_lifetime> = CArray<'borrow_lifetime, u8>; + +impl<'borrow_lifetime> CCharArray<'borrow_lifetime> { + /// Borrows `s` as a C-ABI char array view. + /// + /// The returned view is **not** NUL-terminated; both sides of the FFI boundary rely on the + /// explicit `length` field rather than a terminator. + #[must_use] + pub const fn from_str(s: &'borrow_lifetime str) -> Self { + Self { + pointer: s.as_ptr().cast::(), + length: s.len(), + _lifetime: PhantomData, + } + } + + /// Reconstructs a Rust `&str` from the raw pointer and length. + /// + /// # Returns + /// + /// A `&str` view of the underlying bytes. + /// + /// # Safety + /// + /// In addition to the invariants required by [`CArray::as_slice`], the caller must guarantee + /// that the bytes are valid UTF-8. No validation is performed. + #[must_use] + pub const unsafe fn as_str(&self) -> &'borrow_lifetime str { + // SAFETY: the caller guarantees pointer validity and UTF-8 correctness. + let bytes = unsafe { std::slice::from_raw_parts(self.pointer.cast::(), self.length) }; + // SAFETY: the caller guarantees the bytes are valid UTF-8. + unsafe { std::str::from_utf8_unchecked(bytes) } + } +} + +/// Owned, C-ABI-compatible result buffer returned from a TDL package's `execute` entry point. +/// +/// The buffer is allocated on the TDL-package side by leaking a `Box<[u8]>` via [`Box::into_raw`] +/// and reclaimed on the executor side via [`Box::from_raw`]. This only works because both sides +/// share the same global allocator, which is true when the package is loaded via `dlopen` into +/// the executor process. +/// +/// Instances are produced by [`TaskExecutionResult::from_outputs`] / +/// [`TaskExecutionResult::from_error`] and consumed exactly once by +/// [`TaskExecutionResult::into_result`]; using any other lifecycle risks a double free or leak. +#[repr(C)] +pub struct TaskExecutionResult { + is_error: bool, + pointer: *mut u8, + length: usize, +} + +impl TaskExecutionResult { + /// Constructs a successful result wrapping wire-format-encoded output bytes. + #[must_use] + pub fn from_outputs(bytes: Vec) -> Self { + Self::from_buffer(false, bytes) + } + + /// Constructs a failing result wrapping msgpack-encoded `TdlError` bytes. + #[must_use] + pub fn from_error(bytes: Vec) -> Self { + Self::from_buffer(true, bytes) + } + + /// Reclaims ownership of the leaked buffer and returns it. + /// + /// This must be called **exactly once** for every value produced by [`Self::from_outputs`] or + /// [`Self::from_error`]; the reclaimed `Vec` takes back allocator ownership and will free + /// the buffer when dropped. + /// + /// # Returns + /// + /// `Ok(bytes)` on success, where `bytes` is the wire-format output payload produced by the + /// user task. + /// + /// # Errors + /// + /// Returns `Err(bytes)` if the result represented failure, where `bytes` is a + /// msgpack-encoded `TdlError` produced inside the TDL package. The caller is responsible for + /// decoding it via `rmp_serde::from_slice`. + /// + /// # Safety + /// + /// The caller must guarantee that: + /// + /// * `self.pointer` / `self.length` originated from a prior call to [`Self::from_outputs`] or + /// [`Self::from_error`] in a component that shares this process's global allocator. + /// * The buffer has not already been reclaimed. + pub unsafe fn into_result(self) -> Result, Vec> { + // Prevent the destructor from running after we reconstruct the `Box`. + let this = ManuallyDrop::new(self); + // SAFETY: the caller upholds the invariants documented above. The buffer was produced by + // `Box::<[u8]>::into_raw`, and the same length is recorded in `this.length`, so + // reconstructing the slice and boxing it is sound. + let boxed: Box<[u8]> = unsafe { + Box::from_raw(std::ptr::slice_from_raw_parts_mut( + this.pointer, + this.length, + )) + }; + let vec = boxed.into_vec(); + if this.is_error { Err(vec) } else { Ok(vec) } + } + + /// Converts an [`ExecutionResult`] into its C-ABI-compatible form. + /// + /// This is the primary conversion used by the `register_tasks!` macro's generated + /// `__spider_tdl_package_execute` entry point. + #[must_use] + pub fn from_execution_result(result: crate::ExecutionResult) -> Self { + match result { + crate::ExecutionResult::Outputs(bytes) => Self::from_outputs(bytes), + crate::ExecutionResult::Error(bytes) => Self::from_error(bytes), + } + } + + fn from_buffer(is_error: bool, buffer: Vec) -> Self { + let boxed: Box<[u8]> = buffer.into_boxed_slice(); + let length = boxed.len(); + let pointer = Box::into_raw(boxed).cast::(); + Self { + is_error, + pointer, + length, + } + } +} + +#[cfg(test)] +mod tests { + use super::{CByteArray, CCharArray, TaskExecutionResult}; + + #[test] + fn c_byte_array_round_trip() { + let data: [u8; 5] = [1, 2, 3, 4, 5]; + let view = CByteArray::from_slice(&data); + assert_eq!(view.len(), 5); + assert!(!view.is_empty()); + // SAFETY: `data` is still alive, so the borrowed view is valid. + let reconstructed = unsafe { view.as_slice() }; + assert_eq!(reconstructed, &data[..]); + } + + #[test] + fn c_byte_array_empty() { + let data: [u8; 0] = []; + let view = CByteArray::from_slice(&data); + assert_eq!(view.len(), 0); + assert!(view.is_empty()); + } + + #[test] + fn c_char_array_round_trip() { + let original = "hello, TDL"; + let view = CCharArray::from_str(original); + assert_eq!(view.len(), original.len()); + // SAFETY: `original` is still alive and is valid UTF-8. + let reconstructed = unsafe { view.as_str() }; + assert_eq!(reconstructed, original); + } + + #[test] + fn task_execution_result_success_round_trip() { + let payload = vec![10u8, 20, 30, 40]; + let expected = payload.clone(); + let result = TaskExecutionResult::from_outputs(payload); + // SAFETY: the result was freshly produced by `from_outputs` and not reclaimed yet. + let reclaimed = unsafe { result.into_result() }; + assert_eq!(reclaimed, Ok(expected)); + } + + #[test] + fn task_execution_result_error_round_trip() { + let payload = vec![0xdeu8, 0xad, 0xbe, 0xef]; + let expected = payload.clone(); + let result = TaskExecutionResult::from_error(payload); + // SAFETY: the result was freshly produced by `from_error` and not reclaimed yet. + let reclaimed = unsafe { result.into_result() }; + assert_eq!(reclaimed, Err(expected)); + } + + #[test] + fn task_execution_result_empty_buffer() { + let result = TaskExecutionResult::from_outputs(Vec::new()); + // SAFETY: freshly produced, single consumption. + let reclaimed = unsafe { result.into_result() }; + assert_eq!(reclaimed, Ok(Vec::new())); + } +} diff --git a/components/spider-tdl/src/lib.rs b/components/spider-tdl/src/lib.rs new file mode 100644 index 00000000..cfde3cce --- /dev/null +++ b/components/spider-tdl/src/lib.rs @@ -0,0 +1,12 @@ +pub mod error; +pub mod ffi; +pub mod register; +pub mod task; +pub mod task_context; +pub mod tdl_types; +pub mod wire; + +pub use error::TdlError; +pub use spider_tdl_derive::task; +pub use task::{ExecutionResult, Task, TaskHandler, TaskHandlerImpl}; +pub use task_context::TaskContext; diff --git a/components/spider-tdl/src/register.rs b/components/spider-tdl/src/register.rs new file mode 100644 index 00000000..70b3e0df --- /dev/null +++ b/components/spider-tdl/src/register.rs @@ -0,0 +1,75 @@ +//! Registration macro for TDL task packages. +//! +//! The [`register_tasks!`] macro generates the static dispatch table and C-FFI entry points that +//! the task executor uses to discover and invoke tasks within a compiled TDL package. + +/// Generates the TDL package's task registry and C-FFI entry points. +/// +/// # Usage +/// +/// ```ignore +/// spider_tdl::register_tasks! { +/// package_name: "my_package", +/// tasks: [my_task, another_task] +/// } +/// ``` +/// +/// # Generated items +/// +/// * `__SPIDER_TDL_REGISTRY` — a `LazyLock>>` dispatch +/// table populated from the listed task types. +/// * `__SPIDER_TDL_PACKAGE_NAME` — a `&'static str` holding the package name. +/// * `__spider_tdl_package_get_name` — an `extern "C"` function returning the package name as a +/// [`CCharArray`][crate::ffi::CCharArray]. +/// * `__spider_tdl_package_execute` — an `extern "C"` function that looks up a task by name in the +/// registry and executes it, returning a +/// [`TaskExecutionResult`][crate::ffi::TaskExecutionResult]. +#[macro_export] +macro_rules! register_tasks { + ( + package_name: $package_name:expr, + tasks: [$($task:ty),* $(,)?] + ) => { + static __SPIDER_TDL_REGISTRY: std::sync::LazyLock< + std::collections::HashMap<&'static str, Box>, + > = std::sync::LazyLock::new(|| { + let mut map = std::collections::HashMap::new(); + $( + map.insert( + <$task as $crate::Task>::NAME, + Box::new($crate::TaskHandlerImpl::<$task>::new()) + as Box, + ); + )* + map + }); + + static __SPIDER_TDL_PACKAGE_NAME: &str = $package_name; + + #[unsafe(no_mangle)] + pub extern "C" fn __spider_tdl_package_get_name<'a>() -> $crate::ffi::CCharArray<'a> { + $crate::ffi::CCharArray::from_str(__SPIDER_TDL_PACKAGE_NAME) + } + + #[unsafe(no_mangle)] + pub extern "C" fn __spider_tdl_package_execute( + name: $crate::ffi::CCharArray<'_>, + ctx: $crate::ffi::CByteArray<'_>, + inputs: $crate::ffi::CByteArray<'_>, + ) -> $crate::ffi::TaskExecutionResult { + let task_name: &str = unsafe { name.as_str() }; + let raw_ctx: &[u8] = unsafe { ctx.as_slice() }; + let raw_inputs: &[u8] = unsafe { inputs.as_slice() }; + + let result = match __SPIDER_TDL_REGISTRY.get(task_name) { + Some(handler) => handler.execute_raw(raw_ctx, raw_inputs), + None => { + let err = $crate::TdlError::TaskNotFound(task_name.to_string()); + $crate::ExecutionResult::from_tdl_error(&err) + } + }; + + $crate::ffi::TaskExecutionResult::from_execution_result(result) + } + }; +} diff --git a/components/spider-tdl/src/task.rs b/components/spider-tdl/src/task.rs new file mode 100644 index 00000000..111c3fbe --- /dev/null +++ b/components/spider-tdl/src/task.rs @@ -0,0 +1,335 @@ +//! Task execution traits and runtime handler. +//! +//! This module defines the [`Task`] trait (implemented by the `#[task]` proc-macro for each +//! user-defined task function), the [`TaskHandler`] trait (the type-erased interface used by the +//! registration macro's dispatch table), and [`TaskHandlerImpl`] (the bridge between them). + +use std::marker::PhantomData; + +use crate::{ + error::TdlError, + task_context::TaskContext, + wire::{TaskInputs, TaskOutputs}, +}; + +/// The result of executing a task through the [`TaskHandler`] interface. +/// +/// Both variants carry a byte buffer: `Outputs` holds wire-format-encoded output payloads, and +/// `Error` holds a msgpack-encoded [`TdlError`]. +pub enum ExecutionResult { + /// Wire-format-encoded output payloads (one per return-tuple element). + Outputs(Vec), + + /// Msgpack-encoded [`TdlError`]. + Error(Vec), +} + +/// Trait implemented by the `#[task]` proc-macro for each user-defined task function. +/// +/// The proc-macro generates a marker struct, a params struct, and an impl of this trait that +/// wires deserialized parameters into the user's function body. +pub trait Task { + /// The name of the task, used as the lookup key in the registration table. + const NAME: &'static str; + + /// The deserialized parameter struct (generated by the proc-macro, excludes [`TaskContext`]). + type Params: for<'de> serde::Deserialize<'de>; + + /// The return type of the task (always a tuple, even for single values). + type Return: serde::Serialize; + + /// Executes the task with the given context and deserialized parameters. + /// + /// # Returns + /// + /// The task's return tuple on success. + /// + /// # Errors + /// + /// Returns a [`TdlError`] on failure. + fn execute(ctx: TaskContext, params: Self::Params) -> Result; + + /// Serializes the return tuple into wire-format bytes. + /// + /// The default implementation uses [`TaskOutputs::serialize_from`] which drives serde's + /// `Serialize` impl for the tuple, decomposing it into individually msgpack-encoded wire + /// payloads. The proc-macro does not need to generate this method. + /// + /// # Returns + /// + /// The wire-format byte stream on success. + /// + /// # Errors + /// + /// Returns a [`TdlError`] if serialization of any element fails. + fn serialize_return(result: &Self::Return) -> Result, TdlError> { + TaskOutputs::serialize_from(result).map_err(|e| TdlError::SerializationError(e.to_string())) + } +} + +/// Type-erased task execution interface. +/// +/// Used by the registration macro's dispatch table to invoke tasks by name without knowing the +/// concrete [`Task`] type. Implementations must be thread-safe since the dispatch table is a +/// `static` shared across threads. +pub trait TaskHandler: Send + Sync { + /// Executes the task from raw serialized inputs. + /// + /// # Parameters + /// + /// * `raw_ctx` - Msgpack-encoded [`TaskContext`]. + /// * `raw_args` - Wire-format-encoded task inputs. + /// + /// # Returns + /// + /// [`ExecutionResult::Outputs`] on success, [`ExecutionResult::Error`] on failure. + fn execute_raw(&self, raw_ctx: &[u8], raw_args: &[u8]) -> ExecutionResult; + + /// Returns the task's registered name ([`Task::NAME`]). + fn name(&self) -> &'static str; +} + +/// Bridges a concrete [`Task`] implementation to the type-erased [`TaskHandler`] interface. +/// +/// Handles deserialization of [`TaskContext`] (from msgpack) and task parameters (from wire +/// format), delegates to [`Task::execute`], and serializes the result back into bytes. +/// +/// # Type Parameters +/// +/// * `TaskType` - The concrete [`Task`] implementation (typically the marker struct generated by +/// the `#[task]` proc-macro). +pub struct TaskHandlerImpl { + _marker: PhantomData, +} + +impl Default for TaskHandlerImpl { + fn default() -> Self { + Self::new() + } +} + +impl TaskHandlerImpl { + /// Creates a new handler for the given task type. + #[must_use] + pub const fn new() -> Self { + Self { + _marker: PhantomData, + } + } +} + +impl TaskHandler for TaskHandlerImpl { + fn execute_raw(&self, raw_ctx: &[u8], raw_args: &[u8]) -> ExecutionResult { + // 1. Deserialize TaskContext (msgpack). + let ctx: TaskContext = match rmp_serde::from_slice(raw_ctx) { + Ok(c) => c, + Err(e) => { + return serialize_error(&TdlError::DeserializationError(format!( + "failed to deserialize TaskContext: {e}" + ))); + } + }; + + // 2. Deserialize task parameters (wire format). + let params: TaskType::Params = match TaskInputs::deserialize(raw_args) { + Ok(p) => p, + Err(e) => { + return serialize_error(&TdlError::DeserializationError(e.to_string())); + } + }; + + // 3. Execute the user's task function. + let result = match TaskType::execute(ctx, params) { + Ok(r) => r, + Err(ref e) => return serialize_error(e), + }; + + // 4. Serialize the return tuple into wire-format bytes. + match TaskType::serialize_return(&result) { + Ok(bytes) => ExecutionResult::Outputs(bytes), + Err(ref e) => serialize_error(e), + } + } + + fn name(&self) -> &'static str { + TaskType::NAME + } +} + +impl ExecutionResult { + /// Creates an [`ExecutionResult::Error`] by msgpack-serializing the given [`TdlError`]. + /// + /// This is intended for use by the `register_tasks!` macro so that callers do not need a + /// direct dependency on `rmp_serde`. + /// + /// # Panics + /// + /// Panics if msgpack serialization of `err` fails (should never happen for well-formed + /// [`TdlError`] values). + #[must_use] + pub fn from_tdl_error(err: &TdlError) -> Self { + let bytes = rmp_serde::to_vec(err).expect("TdlError msgpack serialization failed"); + Self::Error(bytes) + } +} + +/// Serializes a [`TdlError`] into an [`ExecutionResult::Error`]. +fn serialize_error(err: &TdlError) -> ExecutionResult { + ExecutionResult::from_tdl_error(err) +} + +#[cfg(test)] +mod tests { + use spider_core::types::id::{JobId, TaskId}; + + use super::{ExecutionResult, Task, TaskHandler, TaskHandlerImpl}; + use crate::{ + error::TdlError, + task_context::TaskContext, + tdl_types::int32, + wire::{TaskInputs, TaskOutputs}, + }; + + struct AddTask; + + #[derive(serde::Deserialize)] + struct AddTaskParams { + a: int32, + b: int32, + } + + impl Task for AddTask { + type Params = AddTaskParams; + type Return = (int32,); + + const NAME: &'static str = "test::add"; + + fn execute(_ctx: TaskContext, params: Self::Params) -> Result { + Ok((params.a + params.b,)) + } + } + + struct FailTask; + + #[derive(serde::Deserialize)] + struct EmptyParams {} + + impl Task for FailTask { + type Params = EmptyParams; + type Return = (int32,); + + const NAME: &'static str = "test::fail"; + + fn execute(_ctx: TaskContext, _params: Self::Params) -> Result { + Err(TdlError::ExecutionError("intentional failure".to_owned())) + } + } + + fn make_ctx() -> TaskContext { + TaskContext { + job_id: JobId::new(), + task_id: TaskId::new(), + task_instance_id: 1, + } + } + + fn encode_ctx(ctx: &TaskContext) -> Vec { + rmp_serde::to_vec(ctx).expect("TaskContext serialization failed") + } + + fn make_inputs(a: int32, b: int32) -> Vec { + let mut inputs = TaskInputs::new(); + inputs + .append(spider_core::types::io::TaskInput::ValuePayload( + rmp_serde::to_vec(&a).expect("msgpack encoding failed"), + )) + .expect("append failed"); + inputs + .append(spider_core::types::io::TaskInput::ValuePayload( + rmp_serde::to_vec(&b).expect("msgpack encoding failed"), + )) + .expect("append failed"); + inputs.release() + } + + #[test] + fn handler_success() -> anyhow::Result<()> { + let handler = TaskHandlerImpl::::new(); + assert_eq!(handler.name(), "test::add"); + + let ctx = make_ctx(); + let raw_ctx = encode_ctx(&ctx); + let raw_args = make_inputs(10, 32); + + let result = handler.execute_raw(&raw_ctx, &raw_args); + match result { + ExecutionResult::Outputs(bytes) => { + let outputs = TaskOutputs::deserialize(&bytes)?; + assert_eq!(outputs.len(), 1); + let sum: int32 = rmp_serde::from_slice(&outputs[0])?; + assert_eq!(sum, 42); + } + ExecutionResult::Error(bytes) => { + let err: TdlError = rmp_serde::from_slice(&bytes)?; + panic!("unexpected error: {err}"); + } + } + Ok(()) + } + + #[test] + fn handler_task_error() -> anyhow::Result<()> { + let handler = TaskHandlerImpl::::new(); + + let ctx = make_ctx(); + let raw_ctx = encode_ctx(&ctx); + let raw_args = TaskInputs::new().release(); // empty params + + let result = handler.execute_raw(&raw_ctx, &raw_args); + match result { + ExecutionResult::Outputs(_) => panic!("expected error"), + ExecutionResult::Error(bytes) => { + let err: TdlError = rmp_serde::from_slice(&bytes)?; + assert!( + matches!(err, TdlError::ExecutionError(ref msg) if msg == "intentional failure") + ); + } + } + Ok(()) + } + + #[test] + fn handler_bad_ctx() -> anyhow::Result<()> { + let handler = TaskHandlerImpl::::new(); + let raw_args = make_inputs(1, 2); + + // Pass garbage as TaskContext. + let result = handler.execute_raw(&[0xc1], &raw_args); + match result { + ExecutionResult::Outputs(_) => panic!("expected error"), + ExecutionResult::Error(bytes) => { + let err: TdlError = rmp_serde::from_slice(&bytes)?; + assert!(matches!(err, TdlError::DeserializationError(_))); + } + } + Ok(()) + } + + #[test] + fn handler_bad_inputs() -> anyhow::Result<()> { + let handler = TaskHandlerImpl::::new(); + let ctx = make_ctx(); + let raw_ctx = encode_ctx(&ctx); + + // Pass truncated wire data. + let result = handler.execute_raw(&raw_ctx, &[0x01]); + match result { + ExecutionResult::Outputs(_) => panic!("expected error"), + ExecutionResult::Error(bytes) => { + let err: TdlError = rmp_serde::from_slice(&bytes)?; + assert!(matches!(err, TdlError::DeserializationError(_))); + } + } + Ok(()) + } +} diff --git a/components/spider-tdl/src/task_context.rs b/components/spider-tdl/src/task_context.rs new file mode 100644 index 00000000..9e173b80 --- /dev/null +++ b/components/spider-tdl/src/task_context.rs @@ -0,0 +1,45 @@ +//! Runtime metadata passed to every task function as the first parameter. +//! +//! [`TaskContext`] is constructed by the execution manager, msgpack-serialized, and forwarded +//! through the task executor into the TDL package. It is separate from the task's user-supplied +//! inputs, which travel as a wire-format byte stream. + +use spider_core::types::id::{JobId, TaskId, TaskInstanceId}; + +/// Runtime metadata for a single task execution. +/// +/// Every task function receives a [`TaskContext`] as its first parameter, providing identity +/// information about the job, task, and task instance that triggered the execution. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TaskContext { + pub job_id: JobId, + pub task_id: TaskId, + pub task_instance_id: TaskInstanceId, +} + +#[cfg(test)] +mod tests { + use spider_core::types::id::{JobId, TaskId}; + + use super::TaskContext; + + #[test] + fn round_trip_msgpack() -> anyhow::Result<()> { + let original = TaskContext { + job_id: JobId::new(), + task_id: TaskId::new(), + task_instance_id: 42, + }; + + let encoded = rmp_serde::to_vec(&original)?; + let decoded: TaskContext = rmp_serde::from_slice(&encoded)?; + + assert_eq!(original.job_id.as_uuid_ref(), decoded.job_id.as_uuid_ref()); + assert_eq!( + original.task_id.as_uuid_ref(), + decoded.task_id.as_uuid_ref() + ); + assert_eq!(original.task_instance_id, decoded.task_instance_id); + Ok(()) + } +} diff --git a/components/spider-tdl/src/tdl_types.rs b/components/spider-tdl/src/tdl_types.rs new file mode 100644 index 00000000..638d1201 --- /dev/null +++ b/components/spider-tdl/src/tdl_types.rs @@ -0,0 +1,65 @@ +//! Type aliases and marker traits that spell TDL primitive types in the Rust source. +//! +//! These names (`int8`, `float`, `List`, `Map`, ...) are the user-facing surface for authors of +//! TDL packages: the `#[task]` proc-macro inspects parameter types by name, so users should write +//! the aliases directly in their task signatures rather than the underlying Rust primitives. + +// The lowercase primitive aliases (`int8`, `float`, `boolean`, ...) intentionally mirror the TDL +// language's primitive type spelling, so that TDL package source reads like TDL rather than Rust. +#![allow(non_camel_case_types)] + +/// Signed 8-bit integer. +pub type int8 = i8; + +/// Signed 16-bit integer. +pub type int16 = i16; + +/// Signed 32-bit integer. +pub type int32 = i32; + +/// Signed 64-bit integer. +pub type int64 = i64; + +/// 32-bit IEEE-754 floating-point number. +pub type float = f32; + +/// 64-bit IEEE-754 floating-point number. +pub type double = f64; + +/// Boolean value. +pub type boolean = bool; + +/// Opaque byte string. +pub type Bytes = Vec; + +/// Homogeneous list of values. +pub type List = Vec; + +/// Homogeneous key-value map. +/// +/// The key type must satisfy [`MapKey`]. Because Rust does not enforce `where` clauses on type +/// aliases on stable, the bound is enforced by the `#[task]` proc-macro at parse time rather than +/// by the type alias itself. +pub type Map = std::collections::HashMap; + +mod private { + pub trait Sealed {} +} + +/// Marker trait restricting which types may appear as keys of a TDL [`Map`]. +/// +/// The trait is sealed: downstream crates cannot implement it for their own types. This guarantees +/// that every permitted key type has a stable, well-defined encoding at the wire layer. +pub trait MapKey: Eq + std::hash::Hash + private::Sealed {} + +impl private::Sealed for i8 {} +impl private::Sealed for i16 {} +impl private::Sealed for i32 {} +impl private::Sealed for i64 {} +impl private::Sealed for Vec {} + +impl MapKey for i8 {} +impl MapKey for i16 {} +impl MapKey for i32 {} +impl MapKey for i64 {} +impl MapKey for Vec {} diff --git a/components/spider-tdl/src/wire.rs b/components/spider-tdl/src/wire.rs new file mode 100644 index 00000000..d3158b5b --- /dev/null +++ b/components/spider-tdl/src/wire.rs @@ -0,0 +1,1015 @@ +//! Wire-format framing for TDL task inputs and task outputs. +//! +//! The wire format is a thin, length-prefixed framing layer that wraps an ordered sequence of +//! opaque byte payloads. It is used on both sides of the TDL package boundary: +//! +//! * Task inputs, originating in the storage layer, frame a `Vec` into a byte stream +//! which the TDL package then deserializes directly into the task's parameter struct. +//! * Task outputs, produced inside the TDL package, frame the elements of the return tuple into a +//! byte stream which the storage layer later unframes into a `Vec>` of per-element +//! msgpack payloads. +//! +//! ```text +//! [count: u32 LE] [len_0: u32 LE][payload_0 ...] [len_1: u32 LE][payload_1 ...] ... +//! ``` +//! +//! The wire layer never interprets the payload bytes -- that responsibility belongs to the payload +//! layer (msgpack, via `rmp-serde`). Field-level deserialization is zero-copy: each payload is +//! handed to `rmp_serde` as a borrowed slice into the original wire buffer. + +use std::fmt; + +use serde::de::{self, DeserializeSeed, SeqAccess, Visitor}; +use spider_core::types::io::{TaskInput, TaskOutput}; + +/// Length of the wire header recording the payload count, in bytes. +const COUNT_HEADER_LEN: usize = 4; + +/// Length of the per-payload length prefix, in bytes. +const FIELD_LEN_PREFIX_LEN: usize = 4; + +/// Errors produced while framing or unframing a TDL wire buffer. +/// +/// [`WireError`] is module-local: it describes failures of the wire/payload layer specifically. +/// Higher-level call sites (for example, `TaskHandlerImpl` once it is implemented) translate it +/// into a [`crate::TdlError`] before the error crosses the C-FFI edge. +#[derive(Debug, thiserror::Error)] +pub enum WireError { + /// The encoded payload count does not match the destination struct's field count. + #[error("`{type_name}`: expected {expected} payloads, got {actual}")] + LengthMismatch { + type_name: &'static str, + expected: usize, + actual: usize, + }, + + /// A single payload failed to decode from its msgpack bytes. + #[error( + "`{type_name}::{field}` (position {position}): failed to decode msgpack payload: {source}" + )] + FieldDeserialization { + type_name: &'static str, + field: &'static str, + position: usize, + #[source] + source: rmp_serde::decode::Error, + }, + + /// The wire buffer is malformed -- truncated, corrupted, or otherwise not a valid framing + /// of a payload sequence. + #[error("invalid wire format: {0}")] + InvalidFormat(&'static str), + + /// A value exceeds the wire format's `u32` size limit during serialization. + #[error("wire format overflow: {0}")] + Overflow(String), + + /// Catch-all bucket required by [`serde::de::Error`] and [`serde::ser::Error`] for errors + /// that do not fit any specific variant. + #[error("{0}")] + Custom(String), +} + +impl de::Error for WireError { + fn custom(msg: MessageType) -> Self { + Self::Custom(msg.to_string()) + } +} + +impl serde::ser::Error for WireError { + fn custom(msg: MessageType) -> Self { + Self::Custom(msg.to_string()) + } +} + +/// Streaming wire-format serializer for task inputs. +/// +/// Appends [`TaskInput`] payloads one at a time into an internal buffer, writing each payload's +/// length prefix inline. The count header at the front of the buffer is patched by +/// [`Self::release`] once all inputs have been appended. +/// +/// # Example (conceptual) +/// +/// ```ignore +/// let mut inputs = TaskInputs::new(); +/// inputs.append(TaskInput::ValuePayload(msgpack_bytes_0))?; +/// inputs.append(TaskInput::ValuePayload(msgpack_bytes_1))?; +/// let wire: Vec = inputs.release(); +/// ``` +pub struct TaskInputs { + builder: WireFrameBuilder, +} + +impl Default for TaskInputs { + fn default() -> Self { + Self::new() + } +} + +impl TaskInputs { + /// Creates a new streaming serializer with an empty buffer. + #[must_use] + pub fn new() -> Self { + Self { + builder: WireFrameBuilder::new(), + } + } + + /// Appends a single task input to the wire buffer. + /// + /// The payload bytes inside the [`TaskInput::ValuePayload`] variant are written directly; no + /// re-encoding takes place. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`WireError::Overflow`] if the payload count would exceed [`u32::MAX`]. + /// * [`WireError::Overflow`] if the payload is longer than [`u32::MAX`] bytes. + pub fn append(&mut self, input: TaskInput) -> Result<(), WireError> { + let TaskInput::ValuePayload(bytes) = input; + self.builder.append_payload(&bytes) + } + + /// Finalizes the count header and returns the completed wire-format buffer. + #[must_use] + pub fn release(self) -> Vec { + self.builder.release() + } + + /// Deserializes a wire-format byte stream directly into a struct of type `TargetType`. + /// + /// This is the deserialization counterpart to [`Self::append`] + [`Self::release`]. Each + /// field of `TargetType` positionally consumes one wire-format payload, which is then + /// deserialized from msgpack via a zero-copy borrowed slice into `data`. + /// + /// # Type Parameters + /// + /// * `'de` - The lifetime of the wire buffer `data`. + /// * `TargetType` - The struct to produce. Must implement [`serde::Deserialize`]. + /// + /// # Returns + /// + /// The deserialized struct on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`WireError::InvalidFormat`] if the buffer is truncated or malformed. + /// * [`WireError::LengthMismatch`] if the payload count does not match `TargetType`'s field + /// count. + /// * [`WireError::FieldDeserialization`] if any payload fails to decode. + /// * [`WireError::Custom`] if `TargetType` is not a struct. + pub fn deserialize<'de, TargetType>(data: &'de [u8]) -> Result + where + TargetType: serde::Deserialize<'de>, { + let mut deserializer = StreamDeserializer::new(data)?; + TargetType::deserialize(&mut deserializer) + } +} + +/// Streaming wire-format serializer and deserializer for task outputs. +/// +/// On the serialization side, pre-encoded msgpack payloads (one per tuple element) are appended +/// via [`Self::append`], and the final wire buffer is obtained from [`Self::release`]. +/// +/// On the deserialization side, [`Self::deserialize`] extracts each payload as an opaque +/// `Vec` without decoding the msgpack contents. +pub struct TaskOutputs { + builder: WireFrameBuilder, +} + +impl Default for TaskOutputs { + fn default() -> Self { + Self::new() + } +} + +impl TaskOutputs { + /// Creates a new streaming serializer with an empty buffer. + #[must_use] + pub fn new() -> Self { + Self { + builder: WireFrameBuilder::new(), + } + } + + /// Serializes `value` into msgpack and appends the encoded bytes directly to the wire buffer. + /// + /// The msgpack encoding is written in-place: a placeholder length prefix is reserved, the + /// value is serialized into the buffer, and the prefix is back-patched with the actual + /// payload size. This avoids allocating an intermediate `Vec` per element. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`WireError::Overflow`] if the payload count would exceed [`u32::MAX`]. + /// * [`WireError::Overflow`] if the serialized payload is longer than [`u32::MAX`] bytes. + /// * [`WireError::Custom`] if msgpack serialization of `value` fails. + pub fn append( + &mut self, + value: &ValueType, + ) -> Result<(), WireError> { + self.builder.append_serialize(value) + } + + /// Finalizes the count header and returns the completed wire-format buffer. + #[must_use] + pub fn release(self) -> Vec { + self.builder.release() + } + + /// Deserializes a wire-format byte stream into a vector of [`TaskOutput`] values. + /// + /// Each payload is extracted as an opaque `Vec`. The msgpack contents are **not** + /// decoded here -- the storage layer is responsible for interpreting each output downstream. + /// + /// # Returns + /// + /// A vector of output payloads on success, one per wire-format element. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`WireError::InvalidFormat`] if the buffer is truncated or malformed. + pub fn deserialize(data: &[u8]) -> Result, WireError> { + unframe_payloads(data) + } + + /// Serializes a tuple by decomposing it into individual elements, encoding each element as + /// msgpack, and framing them in the wire format. + /// + /// This drives serde's `Serialize` impl for the tuple: serde calls `serialize_tuple(len)` + /// followed by `serialize_element` for each element. A custom [`TupleOutputSerializer`] + /// intercepts these calls and routes each element through [`TaskOutputs::append`]. + /// + /// # Type Parameters + /// + /// * `TupleType` - The tuple type to serialize. Must implement [`serde::Serialize`]. + /// + /// # Returns + /// + /// The wire-format byte stream on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`WireError::Overflow`] if any payload exceeds [`u32::MAX`] bytes. + /// * [`WireError::Custom`] if msgpack serialization of any element fails. + /// * [`WireError::Custom`] if the value is not a tuple. + pub fn serialize_from( + value: &TupleType, + ) -> Result, WireError> { + value.serialize(TupleOutputSerializer { + outputs: Self::new(), + }) + } +} + +/// Custom serde [`serde::Serializer`] that decomposes a tuple into individually-encoded wire +/// payloads via [`TaskOutputs`]. +/// +/// Only `serialize_tuple` (and `serialize_unit` for the empty-tuple case) are supported. All +/// other serialization methods return an error. +struct TupleOutputSerializer { + outputs: TaskOutputs, +} + +/// Generates `serialize_*` methods on a `serde::Serializer` impl that all return the same +/// error. Covers the three method shapes in the trait: +/// +/// - `primitive`: `fn method(self, _: Type) -> Result` +/// - `compound`: `fn method(self, ...) -> Result` +/// - `generic`: `fn method(self, ...) -> Result` +macro_rules! reject_non_tuple { + // fn method(self, _: PrimType) + (primitive: $($method:ident($prim:ty)),* $(,)?) => { + $( + fn $method(self, _: $prim) -> Result { + Err(unsupported_type_error()) + } + )* + }; + // fn method(self, ...) -> Result (compound starters) + (compound: $($method:ident($($arg:ident: $ty:ty),*) -> $assoc:ty),* $(,)?) => { + $( + fn $method(self, $($arg: $ty),*) -> Result<$assoc, Self::Error> { + Err(unsupported_type_error()) + } + )* + }; +} + +impl serde::Serializer for TupleOutputSerializer { + type Error = WireError; + type Ok = Vec; + type SerializeMap = serde::ser::Impossible; + type SerializeSeq = serde::ser::Impossible; + type SerializeStruct = serde::ser::Impossible; + type SerializeStructVariant = serde::ser::Impossible; + type SerializeTuple = Self; + type SerializeTupleStruct = serde::ser::Impossible; + type SerializeTupleVariant = serde::ser::Impossible; + + reject_non_tuple! { primitive: + serialize_bool(bool), + serialize_i8(i8), serialize_i16(i16), serialize_i32(i32), serialize_i64(i64), + serialize_u8(u8), serialize_u16(u16), serialize_u32(u32), serialize_u64(u64), + serialize_f32(f32), serialize_f64(f64), + serialize_char(char), serialize_str(&str), serialize_bytes(&[u8]), + } + + reject_non_tuple! { compound: + serialize_unit_struct( + _n: &'static str + ) -> Self::Ok, + serialize_unit_variant( + _n: &'static str, _i: u32, _v: &'static str + ) -> Self::Ok, + serialize_seq( + _len: Option + ) -> Self::SerializeSeq, + serialize_tuple_struct( + _n: &'static str, _len: usize + ) -> Self::SerializeTupleStruct, + serialize_tuple_variant( + _n: &'static str, _i: u32, _v: &'static str, _len: usize + ) -> Self::SerializeTupleVariant, + serialize_map( + _len: Option + ) -> Self::SerializeMap, + serialize_struct( + _n: &'static str, _len: usize + ) -> Self::SerializeStruct, + serialize_struct_variant( + _n: &'static str, _i: u32, _v: &'static str, _len: usize + ) -> Self::SerializeStructVariant, + } + + fn serialize_tuple(self, _len: usize) -> Result { + Ok(self) + } + + fn serialize_unit(self) -> Result { + Ok(self.outputs.release()) + } + + fn serialize_none(self) -> Result { + Err(unsupported_type_error()) + } + + fn serialize_some( + self, + _: &ValueType, + ) -> Result { + Err(unsupported_type_error()) + } + + fn serialize_newtype_struct( + self, + _: &'static str, + _: &ValueType, + ) -> Result { + Err(unsupported_type_error()) + } + + fn serialize_newtype_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: &ValueType, + ) -> Result { + Err(unsupported_type_error()) + } +} + +impl serde::ser::SerializeTuple for TupleOutputSerializer { + type Error = WireError; + type Ok = Vec; + + fn serialize_element( + &mut self, + value: &ValueType, + ) -> Result<(), Self::Error> { + self.outputs.append(value) + } + + fn end(self) -> Result { + Ok(self.outputs.release()) + } +} + +fn unsupported_type_error() -> WireError { + WireError::Custom("task output must be a tuple".to_owned()) +} + +/// Streaming wire-format builder shared by [`TaskInputs`] and [`TaskOutputs`]. +/// +/// Reserves space for the `u32` count header upfront and patches it in [`Self::release`] once +/// the final count is known. Each [`Self::append_payload`] call writes a length-prefixed payload +/// directly into the buffer. +struct WireFrameBuilder { + buffer: Vec, + count: u32, +} + +impl WireFrameBuilder { + fn new() -> Self { + let buffer = vec![0u8; COUNT_HEADER_LEN]; + Self { buffer, count: 0 } + } + + fn append_payload(&mut self, payload: &[u8]) -> Result<(), WireError> { + let payload_len = u32::try_from(payload.len()).map_err(|_| { + WireError::Overflow(format!( + "payload length {} bytes exceeds u32::MAX", + payload.len() + )) + })?; + self.increment_count()?; + self.buffer.extend_from_slice(&payload_len.to_le_bytes()); + self.buffer.extend_from_slice(payload); + Ok(()) + } + + /// Serializes `value` into msgpack directly into the buffer with a length prefix. + /// + /// Writes a placeholder `u32` length, serializes the value in-place, then back-patches the + /// length with the actual payload size. + fn append_serialize( + &mut self, + value: &ValueType, + ) -> Result<(), WireError> { + self.increment_count()?; + + // Reserve space for the length prefix. + let len_offset = self.buffer.len(); + self.buffer.extend_from_slice(&0u32.to_le_bytes()); + + // Serialize directly into the buffer. + rmp_serde::encode::write(&mut self.buffer, value) + .map_err(|e| WireError::Custom(format!("msgpack serialization failed: {e}")))?; + + // Back-patch the length prefix. + let payload_len = self.buffer.len() - len_offset - FIELD_LEN_PREFIX_LEN; + let payload_len_u32 = u32::try_from(payload_len).map_err(|_| { + WireError::Overflow(format!( + "payload length {payload_len} bytes exceeds u32::MAX" + )) + })?; + self.buffer[len_offset..len_offset + FIELD_LEN_PREFIX_LEN] + .copy_from_slice(&payload_len_u32.to_le_bytes()); + Ok(()) + } + + fn increment_count(&mut self) -> Result<(), WireError> { + self.count = self + .count + .checked_add(1) + .ok_or_else(|| WireError::Overflow("payload count exceeds u32::MAX".to_owned()))?; + Ok(()) + } + + fn release(mut self) -> Vec { + self.buffer[..COUNT_HEADER_LEN].copy_from_slice(&self.count.to_le_bytes()); + self.buffer + } +} + +/// Parses the wire-format framing and extracts each payload as an owned `Vec`. +/// +/// Shared deserialization core for [`TaskOutputs::deserialize`]. Each payload is copied out of +/// the wire buffer into its own allocation. +fn unframe_payloads(data: &[u8]) -> Result>, WireError> { + if data.len() < COUNT_HEADER_LEN { + return Err(WireError::InvalidFormat( + "buffer too small for the payload count header", + )); + } + let count_bytes: [u8; COUNT_HEADER_LEN] = data[..COUNT_HEADER_LEN] + .try_into() + .expect("slice length checked above"); + let count = u32::from_le_bytes(count_bytes) as usize; + + let mut pos = COUNT_HEADER_LEN; + let mut payloads = Vec::with_capacity(count); + for _ in 0..count { + if pos + FIELD_LEN_PREFIX_LEN > data.len() { + return Err(WireError::InvalidFormat( + "unexpected end of buffer reading payload length", + )); + } + let len_bytes: [u8; FIELD_LEN_PREFIX_LEN] = data[pos..pos + FIELD_LEN_PREFIX_LEN] + .try_into() + .expect("slice length checked above"); + let field_len = u32::from_le_bytes(len_bytes) as usize; + pos += FIELD_LEN_PREFIX_LEN; + + if pos + field_len > data.len() { + return Err(WireError::InvalidFormat( + "unexpected end of buffer reading payload data", + )); + } + payloads.push(data[pos..pos + field_len].to_vec()); + pos += field_len; + } + Ok(payloads) +} + +/// Single-pass, zero-copy cursor over a wire-format byte stream. +/// +/// Holds a borrowed slice of the wire buffer and a position cursor. The cursor advances each +/// time a field is consumed, yielding borrowed slices into the buffer that can be handed to +/// `rmp_serde` for payload deserialization. +struct StreamDeserializer<'de> { + data: &'de [u8], + pos: usize, + count: usize, + current_field: usize, + type_name: &'static str, + field_names: &'static [&'static str], +} + +impl<'de> StreamDeserializer<'de> { + fn new(data: &'de [u8]) -> Result { + if data.len() < COUNT_HEADER_LEN { + return Err(WireError::InvalidFormat( + "buffer too small for the payload count header", + )); + } + let count_bytes: [u8; COUNT_HEADER_LEN] = data[..COUNT_HEADER_LEN] + .try_into() + .expect("slice length checked above"); + let count = u32::from_le_bytes(count_bytes) as usize; + Ok(Self { + data, + pos: COUNT_HEADER_LEN, + count, + current_field: 0, + type_name: "", + field_names: &[], + }) + } + + fn next_field_bytes(&mut self) -> Result<&'de [u8], WireError> { + if self.pos + FIELD_LEN_PREFIX_LEN > self.data.len() { + return Err(WireError::InvalidFormat( + "unexpected end of buffer reading payload length", + )); + } + let len_bytes: [u8; FIELD_LEN_PREFIX_LEN] = self.data + [self.pos..self.pos + FIELD_LEN_PREFIX_LEN] + .try_into() + .expect("slice length checked above"); + let field_len = u32::from_le_bytes(len_bytes) as usize; + self.pos += FIELD_LEN_PREFIX_LEN; + + if self.pos + field_len > self.data.len() { + return Err(WireError::InvalidFormat( + "unexpected end of buffer reading payload data", + )); + } + let bytes = &self.data[self.pos..self.pos + field_len]; + self.pos += field_len; + Ok(bytes) + } +} + +impl<'de> serde::Deserializer<'de> for &mut StreamDeserializer<'de> { + type Error = WireError; + + serde::forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map enum identifier ignored_any + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: VisitorType, + ) -> Result + where + VisitorType: Visitor<'de>, { + if self.count != fields.len() { + return Err(WireError::LengthMismatch { + type_name: name, + expected: fields.len(), + actual: self.count, + }); + } + self.type_name = name; + self.field_names = fields; + visitor.visit_seq(FieldSeqAccess { de: self }) + } + + fn deserialize_any(self, _: VisitorType) -> Result + where + VisitorType: Visitor<'de>, { + Err(WireError::Custom( + "wire stream can only be deserialized into a struct".to_owned(), + )) + } +} + +/// Per-field adapter that hands the next payload of a wire frame to a serde visitor. +struct FieldSeqAccess<'borrow_lifetime, 'de> { + de: &'borrow_lifetime mut StreamDeserializer<'de>, +} + +impl<'de> SeqAccess<'de> for FieldSeqAccess<'_, 'de> { + type Error = WireError; + + fn next_element_seed( + &mut self, + seed: SeedType, + ) -> Result, WireError> + where + SeedType: DeserializeSeed<'de>, { + if self.de.current_field == self.de.count { + return Ok(None); + } + if self.de.current_field > self.de.count { + return Err(WireError::LengthMismatch { + type_name: self.de.type_name, + expected: self.de.count, + actual: self.de.current_field + 1, + }); + } + + let idx = self.de.current_field; + let field_name = self.de.field_names.get(idx).copied().unwrap_or(""); + let type_name = self.de.type_name; + + let bytes = self.de.next_field_bytes()?; + self.de.current_field += 1; + + // `bytes` is a borrowed `&'de [u8]` into the original buffer (zero-copy). + // `rmp_serde` deserializes directly from it in a single step per field. + let mut rmp_de = rmp_serde::Deserializer::from_read_ref(bytes); + seed.deserialize(&mut rmp_de) + .map(Some) + .map_err(|source| WireError::FieldDeserialization { + type_name, + field: field_name, + position: idx, + source, + }) + } + + fn size_hint(&self) -> Option { + Some(self.de.count - self.de.current_field) + } +} + +#[cfg(test)] +mod tests { + use serde::Deserialize; + use spider_core::types::io::{TaskInput, TaskOutput}; + + use super::{TaskInputs, TaskOutputs, WireError}; + use crate::tdl_types::{Bytes, List, Map, int8, int16}; + + /// msgpack-encodes a single value as a payload. + fn encode(value: &ValueType) -> Vec { + rmp_serde::to_vec(value).expect("msgpack encoding failed") + } + + #[derive(Debug, PartialEq, Deserialize)] + struct Job { + name: String, + priority: u32, + payload: Vec, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct Single { + value: i64, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct Empty {} + + #[derive(Debug, PartialEq, serde::Serialize, Deserialize)] + struct Inner { + x: i64, + y: i64, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct Outer { + label: String, + point: Inner, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct Params { + greeting: String, + count: u32, + } + + #[test] + fn wire_frame_byte_layout() -> anyhow::Result<()> { + let mut outputs = TaskOutputs::new(); + outputs.append(&42u8)?; + outputs.append(&"hi")?; + let wire = outputs.release(); + + let encoded_42 = encode(&42u8); + let encoded_hi = encode(&"hi"); + + let mut expected = Vec::new(); + expected.extend_from_slice(&2u32.to_le_bytes()); + expected.extend_from_slice(&u32::try_from(encoded_42.len())?.to_le_bytes()); + expected.extend_from_slice(&encoded_42); + expected.extend_from_slice(&u32::try_from(encoded_hi.len())?.to_le_bytes()); + expected.extend_from_slice(&encoded_hi); + assert_eq!(wire, expected); + Ok(()) + } + + #[test] + fn task_inputs_streaming_round_trip() -> anyhow::Result<()> { + let mut inputs = TaskInputs::new(); + inputs.append(TaskInput::ValuePayload(encode(&"hello".to_owned())))?; + inputs.append(TaskInput::ValuePayload(encode(&42u32)))?; + let wire = inputs.release(); + + let params: Params = TaskInputs::deserialize(&wire)?; + assert_eq!(params.greeting, "hello"); + assert_eq!(params.count, 42); + Ok(()) + } + + #[test] + fn task_inputs_empty() -> anyhow::Result<()> { + let wire = TaskInputs::new().release(); + assert_eq!(wire, 0u32.to_le_bytes()); + + let value: Empty = TaskInputs::deserialize(&wire)?; + assert_eq!(value, Empty {}); + Ok(()) + } + + #[test] + fn task_inputs_nested_struct() -> anyhow::Result<()> { + let mut inputs = TaskInputs::new(); + inputs.append(TaskInput::ValuePayload(encode(&"origin".to_owned())))?; + inputs.append(TaskInput::ValuePayload(encode(&Inner { x: -10, y: 42 })))?; + let wire = inputs.release(); + + let outer: Outer = TaskInputs::deserialize(&wire)?; + assert_eq!( + outer, + Outer { + label: "origin".to_owned(), + point: Inner { x: -10, y: 42 }, + } + ); + Ok(()) + } + + #[test] + fn task_inputs_deserialize_length_mismatch() { + let mut inputs = TaskInputs::new(); + inputs + .append(TaskInput::ValuePayload(encode(&"only-one".to_owned()))) + .expect("append failed"); + let wire = inputs.release(); + + let err = TaskInputs::deserialize::(&wire).expect_err("expected length mismatch"); + match err { + WireError::LengthMismatch { + type_name, + expected, + actual, + } => { + assert_eq!(type_name, "Job"); + assert_eq!(expected, 3); + assert_eq!(actual, 1); + } + other => panic!("unexpected error: {other}"), + } + } + + #[test] + fn task_inputs_deserialize_field_error() { + let mut inputs = TaskInputs::new(); + inputs + .append(TaskInput::ValuePayload(encode(&"name".to_owned()))) + .expect("append failed"); + // 0xC1 is a reserved/invalid msgpack byte. + inputs + .append(TaskInput::ValuePayload(vec![0xc1u8])) + .expect("append failed"); + inputs + .append(TaskInput::ValuePayload(encode(&vec![0u8]))) + .expect("append failed"); + let wire = inputs.release(); + + let err = TaskInputs::deserialize::(&wire) + .expect_err("expected field deserialization error"); + match err { + WireError::FieldDeserialization { + type_name, + field, + position, + .. + } => { + assert_eq!(type_name, "Job"); + assert_eq!(field, "priority"); + assert_eq!(position, 1); + } + other => panic!("unexpected error: {other}"), + } + } + + #[test] + fn task_inputs_deserialize_truncated_header() { + let err = TaskInputs::deserialize::(&[0x01]).expect_err("expected invalid format"); + assert!(matches!(err, WireError::InvalidFormat(_))); + } + + #[test] + fn task_inputs_deserialize_truncated_field() { + // Header declares 1 payload of 100 bytes but supplies only 2 bytes of data. + let mut wire = Vec::new(); + wire.extend_from_slice(&1u32.to_le_bytes()); + wire.extend_from_slice(&100u32.to_le_bytes()); + wire.extend_from_slice(&[0u8, 1]); + + let err = TaskInputs::deserialize::(&wire).expect_err("expected invalid format"); + assert!(matches!(err, WireError::InvalidFormat(_))); + } + + #[test] + fn task_outputs_streaming_round_trip() -> anyhow::Result<()> { + let mut outputs = TaskOutputs::new(); + outputs.append(&"result".to_owned())?; + outputs.append(&99i64)?; + let wire = outputs.release(); + + let decoded: Vec = TaskOutputs::deserialize(&wire)?; + assert_eq!(decoded.len(), 2); + // Each payload is the msgpack encoding of the original value. + assert_eq!(decoded[0], encode(&"result".to_owned())); + assert_eq!(decoded[1], encode(&99i64)); + Ok(()) + } + + #[test] + fn task_outputs_empty() -> anyhow::Result<()> { + let wire = TaskOutputs::new().release(); + assert_eq!(wire, 0u32.to_le_bytes()); + + let decoded: Vec = TaskOutputs::deserialize(&wire)?; + assert!(decoded.is_empty()); + Ok(()) + } + + #[test] + fn task_outputs_deserialize_truncated() { + let err = TaskOutputs::deserialize(&[0x01]).expect_err("expected invalid format"); + assert!(matches!(err, WireError::InvalidFormat(_))); + } + + #[test] + fn error_display_length_mismatch() { + let err = WireError::LengthMismatch { + type_name: "Foo", + expected: 2, + actual: 5, + }; + assert_eq!(err.to_string(), "`Foo`: expected 2 payloads, got 5"); + } + + #[test] + fn error_display_field_deserialization() { + let source = + rmp_serde::from_slice::(&[0xc1u8]).expect_err("expected rmp_serde decode error"); + let err = WireError::FieldDeserialization { + type_name: "Foo", + field: "bar", + position: 1, + source, + }; + let msg = err.to_string(); + assert!(msg.contains("Foo::bar")); + assert!(msg.contains("position 1")); + } + + #[derive(Debug, PartialEq, serde::Serialize, Deserialize)] + struct A { + map: Map, + list: List, + } + + #[derive(Debug, PartialEq, serde::Serialize, Deserialize)] + struct B { + a: A, + value: int16, + list_map: Map>, + } + + #[derive(Debug, PartialEq, serde::Serialize, Deserialize)] + struct EmptyInner {} + + #[derive(Debug, PartialEq, serde::Serialize, Deserialize)] + struct C { + a: A, + empty: EmptyInner, + b: int16, + } + + #[test] + fn compound_type_round_trip() -> anyhow::Result<()> { + let original = B { + a: A { + map: Map::from([(1i8, vec![0xabu8, 0xcd]), (-3i8, vec![])]), + list: vec![10, 20, -1], + }, + value: 1234, + list_map: Map::from([(0i8, vec![1, 2, 3]), (5i8, vec![])]), + }; + + // Serialize each field of B as a separate TaskInput. + let mut inputs = TaskInputs::new(); + inputs.append(TaskInput::ValuePayload(encode(&original.a)))?; + inputs.append(TaskInput::ValuePayload(encode(&original.value)))?; + inputs.append(TaskInput::ValuePayload(encode(&original.list_map)))?; + let wire = inputs.release(); + + let decoded: B = TaskInputs::deserialize(&wire)?; + assert_eq!(decoded, original); + Ok(()) + } + + #[test] + fn compound_type_with_empty_inner_round_trip() -> anyhow::Result<()> { + let original = C { + a: A { + map: Map::from([(42i8, vec![0xffu8])]), + list: vec![-128, 0, 127], + }, + empty: EmptyInner {}, + b: -1, + }; + + let mut inputs = TaskInputs::new(); + inputs.append(TaskInput::ValuePayload(encode(&original.a)))?; + inputs.append(TaskInput::ValuePayload(encode(&original.empty)))?; + inputs.append(TaskInput::ValuePayload(encode(&original.b)))?; + let wire = inputs.release(); + + let decoded: C = TaskInputs::deserialize(&wire)?; + assert_eq!(decoded, original); + Ok(()) + } + + #[test] + fn compound_type_output_round_trip() -> anyhow::Result<()> { + let original = B { + a: A { + map: Map::from([(0i8, vec![1u8, 2, 3])]), + list: vec![], + }, + value: 0, + list_map: Map::new(), + }; + + // Simulate proc-macro: serialize each tuple element directly into TaskOutputs. + let mut outputs = TaskOutputs::new(); + outputs.append(&original.a)?; + outputs.append(&original.value)?; + outputs.append(&original.list_map)?; + let wire = outputs.release(); + + // Storage-layer side: unframe into Vec. + let payloads: Vec = TaskOutputs::deserialize(&wire)?; + assert_eq!(payloads.len(), 3); + + // Verify each payload decodes to the expected value. + let decoded_a: A = rmp_serde::from_slice(&payloads[0])?; + let decoded_value: int16 = rmp_serde::from_slice(&payloads[1])?; + let decoded_list_map: Map> = rmp_serde::from_slice(&payloads[2])?; + assert_eq!(decoded_a, original.a); + assert_eq!(decoded_value, original.value); + assert_eq!(decoded_list_map, original.list_map); + Ok(()) + } + + #[test] + fn empty_params_round_trip() -> anyhow::Result<()> { + // A task with only TaskContext and no user-supplied inputs produces an empty wire frame. + let wire = TaskInputs::new().release(); + let decoded: Empty = TaskInputs::deserialize(&wire)?; + assert_eq!(decoded, Empty {}); + Ok(()) + } +} diff --git a/examples/example-tdl-package-complex/Cargo.toml b/examples/example-tdl-package-complex/Cargo.toml new file mode 100644 index 00000000..87317bb9 --- /dev/null +++ b/examples/example-tdl-package-complex/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-tdl-package-complex" +version = "0.1.0" +edition = "2024" +publish = false + +[lib] +name = "example_tdl_package_complex" +path = "src/lib.rs" +crate-type = ["cdylib"] + +[dependencies] +serde = { version = "1.0.228", features = ["derive"] } +spider-tdl = { path = "../../components/spider-tdl" } diff --git a/examples/example-tdl-package-complex/src/lib.rs b/examples/example-tdl-package-complex/src/lib.rs new file mode 100644 index 00000000..bad1ac28 --- /dev/null +++ b/examples/example-tdl-package-complex/src/lib.rs @@ -0,0 +1,58 @@ +use spider_tdl::{TaskContext, TdlError, task, tdl_types::double}; + +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct Complex { + pub re: double, + pub im: double, +} + +#[task(name = "complex::add")] +fn add(ctx: TaskContext, a: Complex, b: Complex) -> Result { + let _ = ctx; + Ok(Complex { + re: a.re + b.re, + im: a.im + b.im, + }) +} + +#[task(name = "complex::sub")] +fn sub(ctx: TaskContext, a: Complex, b: Complex) -> Result { + let _ = ctx; + Ok(Complex { + re: a.re - b.re, + im: a.im - b.im, + }) +} + +#[task(name = "complex::mul")] +fn mul(ctx: TaskContext, a: Complex, b: Complex) -> Result { + let _ = ctx; + Ok(Complex { + re: a.im.mul_add(-b.im, a.re * b.re), + im: a.im.mul_add(b.re, a.re * b.im), + }) +} + +#[task(name = "complex::div")] +fn div(ctx: TaskContext, a: Complex, b: Complex) -> Result { + let _ = ctx; + let denom = b.im.mul_add(b.im, b.re * b.re); + if denom == 0.0 { + return Err(TdlError::ExecutionError("division by zero".to_owned())); + } + Ok(Complex { + re: a.im.mul_add(b.im, a.re * b.re) / denom, + im: a.re.mul_add(-b.im, a.im * b.re) / denom, + }) +} + +#[task(name = "complex::always_fail")] +fn always_fail(ctx: TaskContext) -> Result<(), TdlError> { + let _ = ctx; + Err(TdlError::Custom("this task always fails".to_owned())) +} + +spider_tdl::register_tasks! { + package_name: "complex", + tasks: [add, sub, mul, div, always_fail] +} diff --git a/taskfiles/test.yaml b/taskfiles/test.yaml index 3b8316f5..45ad533b 100644 --- a/taskfiles/test.yaml +++ b/taskfiles/test.yaml @@ -214,6 +214,7 @@ tasks: MARIADB_DATABASE: "{{.MARIADB_DATABASE}}" MARIADB_USERNAME: "{{.MARIADB_USERNAME}}" MARIADB_PASSWORD: "{{.MARIADB_PASSWORD}}" + SPIDER_TDL_PACKAGE_COMPLEX: "{{.G_RUST_BUILD_DIR}}/release/libexample_tdl_package_complex.so" SPIDER_TEST_INSTRUMENT_OUTPUT_DIR: sh: "echo {{.G_BUILD_DIR}}/spider-instrument-$(uuidgen)" requires: @@ -225,6 +226,7 @@ tasks: - defer: "rm -rf ${SPIDER_TEST_INSTRUMENT_OUTPUT_DIR}" - |- . "{{.G_RUST_TOOLCHAIN_ENV_FILE}}" + cargo build --package example-tdl-package-complex --release cargo nextest run --all --all-features --run-ignored all --release - |- for f in ${SPIDER_TEST_INSTRUMENT_OUTPUT_DIR}/*; do diff --git a/tests/tdl-integration/Cargo.toml b/tests/tdl-integration/Cargo.toml new file mode 100644 index 00000000..3693a669 --- /dev/null +++ b/tests/tdl-integration/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "tdl-integration-tests" +version = "0.1.0" +edition = "2024" +publish = false + +[lib] +name = "tdl_integration_tests" +path = "src/lib.rs" + +[dev-dependencies] +anyhow = "1.0.98" +rmp-serde = "1.3.1" +serde = { version = "1.0.228", features = ["derive"] } +spider-core = { path = "../../components/spider-core" } +spider-task-executor = { path = "../../components/spider-task-executor" } +spider-tdl = { path = "../../components/spider-tdl" } diff --git a/tests/tdl-integration/src/lib.rs b/tests/tdl-integration/src/lib.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/tdl-integration/src/lib.rs @@ -0,0 +1 @@ + diff --git a/tests/tdl-integration/tests/complex.rs b/tests/tdl-integration/tests/complex.rs new file mode 100644 index 00000000..08477250 --- /dev/null +++ b/tests/tdl-integration/tests/complex.rs @@ -0,0 +1,228 @@ +use std::path::Path; + +use spider_core::types::{ + id::{JobId, TaskId}, + io::TaskInput, +}; +use spider_task_executor::TdlPackageLoader; +use spider_tdl::{ + TaskContext, + wire::{TaskInputs, TaskOutputs}, +}; + +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +struct Complex { + re: f64, + im: f64, +} + +fn lib_path() -> Option { + std::env::var("SPIDER_TDL_PACKAGE_COMPLEX").ok() +} + +fn encode_ctx() -> Vec { + let ctx = TaskContext { + job_id: JobId::new(), + task_id: TaskId::new(), + task_instance_id: 1, + }; + rmp_serde::to_vec(&ctx).expect("TaskContext serialization failed") +} + +fn encode_inputs(a: &Complex, b: &Complex) -> Vec { + let mut inputs = TaskInputs::new(); + inputs + .append(TaskInput::ValuePayload( + rmp_serde::to_vec(a).expect("complex serialization failed"), + )) + .expect("append failed"); + inputs + .append(TaskInput::ValuePayload( + rmp_serde::to_vec(b).expect("complex serialization failed"), + )) + .expect("append failed"); + inputs.release() +} + +fn encode_empty_inputs() -> Vec { + TaskInputs::new().release() +} + +fn decode_complex(output_bytes: &[u8]) -> Complex { + let outputs = TaskOutputs::deserialize(output_bytes).expect("output deserialization failed"); + assert_eq!(outputs.len(), 1, "expected exactly one output element"); + rmp_serde::from_slice(&outputs[0]).expect("complex deserialization failed") +} + +#[test] +fn load_and_get_package_name() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + let name = loader.load(Path::new(&path))?; + assert_eq!(name, "complex"); + Ok(()) +} + +#[test] +fn duplicate_load_rejected() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + loader.load(Path::new(&path))?; + + let err = loader + .load(Path::new(&path)) + .expect_err("expected duplicate package error"); + assert!(err.to_string().contains("duplicate")); + Ok(()) +} + +#[test] +fn add() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + loader.load(Path::new(&path))?; + let package = loader.get("complex").expect("package not loaded"); + + let a = Complex { re: 1.0, im: 2.0 }; + let b = Complex { re: 3.0, im: 4.0 }; + let result = decode_complex(&package.execute_task( + "complex::add", + &encode_ctx(), + &encode_inputs(&a, &b), + )?); + + assert_eq!(result, Complex { re: 4.0, im: 6.0 }); + Ok(()) +} + +#[test] +fn sub() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + loader.load(Path::new(&path))?; + let package = loader.get("complex").expect("package not loaded"); + + let a = Complex { re: 5.0, im: 3.0 }; + let b = Complex { re: 2.0, im: 1.0 }; + let result = decode_complex(&package.execute_task( + "complex::sub", + &encode_ctx(), + &encode_inputs(&a, &b), + )?); + + assert_eq!(result, Complex { re: 3.0, im: 2.0 }); + Ok(()) +} + +#[test] +fn mul() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + loader.load(Path::new(&path))?; + let package = loader.get("complex").expect("package not loaded"); + + // (1 + 2i) * (3 + 4i) = (1*3 - 2*4) + (1*4 + 2*3)i = -5 + 10i + let a = Complex { re: 1.0, im: 2.0 }; + let b = Complex { re: 3.0, im: 4.0 }; + let result = decode_complex(&package.execute_task( + "complex::mul", + &encode_ctx(), + &encode_inputs(&a, &b), + )?); + + assert_eq!(result, Complex { re: -5.0, im: 10.0 }); + Ok(()) +} + +#[test] +fn div() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + loader.load(Path::new(&path))?; + let package = loader.get("complex").expect("package not loaded"); + + // (4 + 2i) / (2 + 0i) = (2 + 1i) + let a = Complex { re: 4.0, im: 2.0 }; + let b = Complex { re: 2.0, im: 0.0 }; + let result = decode_complex(&package.execute_task( + "complex::div", + &encode_ctx(), + &encode_inputs(&a, &b), + )?); + + assert_eq!(result, Complex { re: 2.0, im: 1.0 }); + Ok(()) +} + +#[test] +fn div_by_zero() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + loader.load(Path::new(&path))?; + let package = loader.get("complex").expect("package not loaded"); + + let a = Complex { re: 1.0, im: 0.0 }; + let b = Complex { re: 0.0, im: 0.0 }; + let err = package + .execute_task("complex::div", &encode_ctx(), &encode_inputs(&a, &b)) + .expect_err("expected division by zero error"); + + assert!(err.to_string().contains("division by zero")); + Ok(()) +} + +#[test] +fn always_fail() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + loader.load(Path::new(&path))?; + let package = loader.get("complex").expect("package not loaded"); + + let err = package + .execute_task( + "complex::always_fail", + &encode_ctx(), + &encode_empty_inputs(), + ) + .expect_err("expected always_fail error"); + + assert!(err.to_string().contains("this task always fails")); + Ok(()) +} + +#[test] +fn task_not_found() -> anyhow::Result<()> { + let Some(path) = lib_path() else { + return Ok(()); + }; + let mut loader = TdlPackageLoader::new(); + loader.load(Path::new(&path))?; + let package = loader.get("complex").expect("package not loaded"); + + let err = package + .execute_task( + "complex::nonexistent", + &encode_ctx(), + &encode_empty_inputs(), + ) + .expect_err("expected task not found error"); + + assert!(err.to_string().contains("not found")); + Ok(()) +}