diff --git a/README.md b/README.md index 066723d3a6..3efb843b75 100644 --- a/README.md +++ b/README.md @@ -28,22 +28,22 @@ build: - "libglib2.0-0" python_version: "3.13" python_requirements: requirements.txt -predict: "predict.py:Predictor" +run: "run.py:Runner" ``` -Define how predictions are run on your model with `predict.py`: +Define how predictions are run on your model with `run.py`: ```python -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path import torch -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("./weights.pth") # The arguments and types the model takes as input - def predict(self, + def run(self, image: Path = Input(description="Grayscale input image") ) -> Path: """Run a single prediction on the model""" @@ -57,7 +57,7 @@ In the above we accept a path to the image as an input, and return a path to our Now, you can run predictions on this model: ```console -$ cog predict -i image=@input.jpg +$ cog run -i image=@input.jpg --> Building Docker image... --> Running Prediction... 
--> Output written to output.jpg @@ -180,7 +180,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for how to set up a development environme - [Take a look at some examples of using Cog](https://github.com/replicate/cog-examples) - [Deploy models with Cog](docs/deploy.md) - [`cog.yaml` reference](docs/yaml.md) to learn how to define your model's environment -- [Prediction interface reference](docs/python.md) to learn how the `Predictor` interface works +- [Run interface reference](docs/python.md) to learn how the `Runner` interface works - [Training interface reference](docs/training.md) to learn how to add a fine-tuning API to your model - [HTTP API reference](docs/http.md) to learn how to use the HTTP API that models serve diff --git a/architecture/00-overview.md b/architecture/00-overview.md index 695f603d85..0b0a164bd7 100644 --- a/architecture/00-overview.md +++ b/architecture/00-overview.md @@ -33,7 +33,7 @@ flowchart LR ### Model Source -What the model author provides: `cog.yaml` for environment config, a Predictor class with `setup()` and `predict()` methods, and optionally model weights. +What the model author provides: `cog.yaml` for environment config, a Runner class with `setup()` and `run()` methods, and optionally model weights. **Deep dive**: [Model Source](./01-model-source.md) @@ -41,7 +41,7 @@ What the model author provides: `cog.yaml` for environment config, a Predictor c ### Python SDK -The `cog` Python package that model authors import. Provides `BasePredictor`, the type system (`Input`, `Path`, `Secret`, `ConcatenateIterator`), and the thin server entry point that launches coglet. Installed inside every Cog container as a wheel. +The `cog` Python package that model authors import. Provides `BaseRunner`, the type system (`Input`, `Path`, `Secret`, `ConcatenateIterator`), and the thin server entry point that launches coglet. Installed inside every Cog container as a wheel. 
**Deep dive**: [Model Source](./01-model-source.md) (covers the SDK's public API) @@ -93,7 +93,7 @@ The command-line tool for building, testing, and deploying models. flowchart TB subgraph source["Model Source"] yaml["cog.yaml"] - code["predict.py"] + code["run.py"] weights["weights"] end @@ -111,7 +111,7 @@ flowchart TB subgraph runtime["Runtime"] server["HTTP Server
(Rust/Axum)"] worker["Worker Subprocess
(Python)"] - predictor["Predictor"] + predictor["Runner"] end yaml --> config @@ -130,16 +130,16 @@ flowchart TB ## Terminology -| Term | Meaning | -| ------------- | ------------------------------------------------------------------------- | -| **SDK** | The `cog` Python package -- the framework users build models on | -| **Predictor** | User's model class with `setup()` and `predict()` methods | -| **Schema** | OpenAPI spec describing the model's input/output interface | -| **Envelope** | Fixed request/response structure wrapping model-specific data | -| **Worker** | Isolated subprocess running user code | -| **Setup** | One-time model initialization at container start | -| **Coglet** | Rust-based prediction server that runs inside containers | -| **Slot** | A concurrency unit -- one Unix socket connection to the worker subprocess | +| Term | Meaning | +| ------------ | ------------------------------------------------------------------------- | +| **SDK** | The `cog` Python package -- the framework users build models on | +| **Runner** | User's model class with `setup()` and `run()` methods | +| **Schema** | OpenAPI spec describing the model's input/output interface | +| **Envelope** | Fixed request/response structure wrapping model-specific data | +| **Worker** | Isolated subprocess running user code | +| **Setup** | One-time model initialization at container start | +| **Coglet** | Rust-based prediction server that runs inside containers | +| **Slot** | A concurrency unit -- one Unix socket connection to the worker subprocess | ## Reading Order diff --git a/architecture/01-model-source.md b/architecture/01-model-source.md index 75e43b3d36..9c15d8c73e 100644 --- a/architecture/01-model-source.md +++ b/architecture/01-model-source.md @@ -9,7 +9,7 @@ A Cog model consists of: ``` my-model/ ├── cog.yaml # Environment configuration -├── predict.py # Predictor class +├── run.py # Runner class └── weights/ # Model weights (optional, can be downloaded) ``` @@ -29,37 
+29,37 @@ build: run: - curl -o /src/model.bin https://example.com/model.bin -predict: "predict.py:Predictor" +run: "run.py:Runner" concurrency: max: 1 ``` -| Field | Purpose | -| ----------------------- | -------------------------------------------- | -| `build.python_version` | Python interpreter version (3.10-3.13) | -| `build.gpu` | Enable CUDA support | -| `build.python_packages` | pip packages to install | -| `build.system_packages` | apt packages to install | -| `build.run` | Arbitrary shell commands during build | -| `predict` | Path to predictor class (`module:ClassName`) | -| `concurrency.max` | Max concurrent predictions (requires async) | +| Field | Purpose | +| ----------------------- | ------------------------------------------- | +| `build.python_version` | Python interpreter version (3.10-3.13) | +| `build.gpu` | Enable CUDA support | +| `build.python_packages` | pip packages to install | +| `build.system_packages` | apt packages to install | +| `build.run` | Arbitrary shell commands during build | +| `run` | Path to runner class (`module:ClassName`) | +| `concurrency.max` | Max concurrent predictions (requires async) | The [Build System](./05-build-system.md) uses this configuration to produce an image containing all necessary dependencies, libraries, and the correct Python/CUDA versions. -## The Predictor Class +## The Runner Class -A predictor is a Python class with two methods: +A runner is a Python class with two methods: ```python -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load model into memory. Called once at container start.""" self.model = load_model("./weights") - def predict(self, prompt: str, steps: int = 50) -> Path: + def run(self, prompt: str, steps: int = 50) -> Path: """Run inference. 
Called for each prediction request.""" output = self.model.generate(prompt, steps=steps) output.save("/tmp/output.png") @@ -74,7 +74,7 @@ class Predictor(BasePredictor): - Optional: if omitted, Cog proceeds directly to serving - See [Container Runtime: Predictor Lifecycle](./04-container-runtime.md#predictor-lifecycle) for details on instance lifetime, concurrency, crash recovery, and shutdown -### predict() +### run() - Called **for each prediction request** - Signature defines the model's input schema (via type hints) @@ -84,12 +84,12 @@ class Predictor(BasePredictor): ## Input Types -The types used in `predict()` parameters become the model's input schema. +The types used in `run()` parameters become the model's input schema. ### Basic Types ```python -def predict( +def run( self, text: str, # String input count: int, # Integer @@ -105,7 +105,7 @@ URLs are automatically downloaded to local files: ```python from cog import Path -def predict(self, image: Path) -> Path: +def run(self, image: Path) -> Path: # Client sends: {"input": {"image": "https://example.com/photo.jpg"}} # Cog downloads the URL, `image` is a local path like /tmp/inputabc123.jpg img = PIL.Image.open(image) @@ -125,7 +125,7 @@ For sensitive values that shouldn't appear in logs: ```python from cog import Secret -def predict(self, api_key: Secret) -> str: +def run(self, api_key: Secret) -> str: # Value is masked in logs and webhooks client = SomeAPI(api_key.get_secret_value()) ... 
@@ -138,7 +138,7 @@ Use `Input()` to add metadata and validation: ```python from cog import Input -def predict( +def run( self, prompt: str = Input(description="The text prompt"), steps: int = Input(default=50, ge=1, le=100, description="Inference steps"), @@ -159,7 +159,7 @@ def predict( ```python from typing import Literal -def predict( +def run( self, size: Literal["small", "medium", "large"] = "medium", ) -> str: @@ -171,7 +171,7 @@ def predict( from typing import List from cog import Path -def predict( +def run( self, images: List[Path], # Multiple file inputs tags: List[str], # Multiple strings @@ -183,7 +183,7 @@ def predict( ```python from typing import Optional -def predict( +def run( self, seed: Optional[int] = None, # Can be omitted or null ) -> str: @@ -196,7 +196,7 @@ The return type annotation defines what the model produces. ### Basic Types ```python -def predict(self, prompt: str) -> str: +def run(self, prompt: str) -> str: return "Generated text..." ``` @@ -207,7 +207,7 @@ Return `cog.Path` pointing to a generated file: ```python from cog import Path -def predict(self, prompt: str) -> Path: +def run(self, prompt: str) -> Path: # Generate file output_path = "/tmp/output.png" self.model.generate(prompt).save(output_path) @@ -224,7 +224,7 @@ Return a list: from typing import List from cog import Path -def predict(self, prompt: str) -> List[Path]: +def run(self, prompt: str) -> List[Path]: paths = [] for i in range(4): path = f"/tmp/output_{i}.png" @@ -240,7 +240,7 @@ Yield values progressively: ```python from typing import Iterator -def predict(self, prompt: str) -> Iterator[str]: +def run(self, prompt: str) -> Iterator[str]: for token in self.model.generate_stream(prompt): yield token ``` @@ -254,7 +254,7 @@ For LLM-style token streaming where outputs should be concatenated: ```python from cog import ConcatenateIterator -def predict(self, prompt: str) -> ConcatenateIterator[str]: +def run(self, prompt: str) -> ConcatenateIterator[str]: for token in 
self.model.generate(prompt): yield token # "Hello", " ", "world", "!" # Client sees progressive: "Hello" -> "Hello " -> "Hello world" -> "Hello world!" @@ -273,7 +273,7 @@ Include weights in your source directory - they're copied into the image during ``` my-model/ ├── cog.yaml -├── predict.py +├── run.py └── weights/ └── model.safetensors ``` @@ -313,11 +313,11 @@ The choice depends on your deployment needs - bundled weights make images larger For concurrent predictions, use async: ```python -class Predictor(BasePredictor): +class Runner(BaseRunner): async def setup(self): self.model = await load_model_async() - async def predict(self, prompt: str) -> str: + async def run(self, prompt: str) -> str: return await self.model.generate(prompt) ``` @@ -330,10 +330,10 @@ See [Container Runtime](./04-container-runtime.md) for concurrency details. ## Code References -| File | Purpose | -| ------------------------- | --------------------------------------------------------- | -| `python/cog/__init__.py` | Public API exports | -| `python/cog/predictor.py` | BasePredictor class, type introspection, weights handling | -| `python/cog/types.py` | Path, Secret, ConcatenateIterator | -| `python/cog/input.py` | `Input()` function and field metadata | -| `pkg/config/config.go` | cog.yaml parsing | +| File | Purpose | +| ------------------------- | ------------------------------------------------------ | +| `python/cog/__init__.py` | Public API exports | +| `python/cog/predictor.py` | BaseRunner class, type introspection, weights handling | +| `python/cog/types.py` | Path, Secret, ConcatenateIterator | +| `python/cog/input.py` | `Input()` function and field metadata | +| `pkg/config/config.go` | cog.yaml parsing | diff --git a/architecture/02-schema.md b/architecture/02-schema.md index bd217199f3..b02ddfc050 100644 --- a/architecture/02-schema.md +++ b/architecture/02-schema.md @@ -27,7 +27,7 @@ Without the schema, consumers would have no way to know: | ------------------------ | 
------------------------------------------------------------------------------ | | **Replicate platform** | Generate input forms in the web UI, validate requests before routing to models | | **HTTP server (coglet)** | Validate incoming JSON, reject malformed requests before they reach user code | -| **CLI (`cog predict`)** | Parse `-i key=value` flags into correctly-typed Python objects | +| **CLI (`cog run`)** | Parse `-i key=value` flags into correctly-typed Python objects | | **Docker label** | Extract model interface without running the container | | **API clients** | Know what to send and what to expect back without reading source code | @@ -40,7 +40,7 @@ If the static parser encounters a type it can't resolve, the build fails with a ```mermaid flowchart LR subgraph source["Model Source"] - predict["predict.py"] + predict["run.py"] types["output_types.py"] end @@ -68,7 +68,7 @@ flowchart LR 3. **Collect module scope** -- resolve module-level variable assignments (for default values, choices lists) 4. **Collect BaseModel subclasses** -- find all classes that inherit from `BaseModel` (cog's dataclass-based version; pydantic BaseModel also supported for compatibility) 5. **Resolve cross-file models** — for imported names not found locally, find the `.py` file on disk, parse it, and extract its BaseModel definitions -6. **Extract inputs** — walk the `predict()` method parameters, resolve types, defaults, and `Input()` metadata +6. **Extract inputs** — walk the `run()` method parameters, resolve types, defaults, and `Input()` metadata. Legacy class `predict()` methods are still accepted with a warning. 7. **Resolve output type** — recursively resolve the return type annotation into a `SchemaType` 8. 
**Generate OpenAPI** — convert the extracted `PredictorInfo` into a full OpenAPI 3.0.2 JSON document @@ -89,12 +89,12 @@ class Prediction(BaseModel): ``` ```python -# predict.py -from cog import BasePredictor +# run.py +from cog import BaseRunner from output_types import Prediction -class Predictor(BasePredictor): - def predict(self, prompt: str) -> Prediction: +class Runner(BaseRunner): + def run(self, prompt: str) -> Prediction: ... ``` @@ -186,30 +186,30 @@ Each `SchemaType` produces its JSON Schema fragment via `JSONSchema()`: ### Output Types -| Python | SchemaType | JSON Schema | -| -------------------------- | ------------------------ | --------------------------------------------------------------- | -| `str` | `SchemaPrimitive` | `{"type": "string"}` | -| `int` | `SchemaPrimitive` | `{"type": "integer"}` | -| `float` | `SchemaPrimitive` | `{"type": "number"}` | -| `bool` | `SchemaPrimitive` | `{"type": "boolean"}` | -| `Path` | `SchemaPrimitive` | `{"type": "string", "format": "uri"}` | -| `dict` (bare) | `SchemaAny` | `{"type": "object"}` | -| `dict[str, V]` | `SchemaDict` | `{"type": "object", "additionalProperties": V}` | -| `list` (bare) | `SchemaArray(SchemaAny)` | `{"type": "array", "items": {"type": "object"}}` | -| `list[T]` | `SchemaArray` | `{"type": "array", "items": T}` | -| `Annotated[T, cog.Opaque]` | `SchemaPrimitive(TypeAny)` | `{"type": "object"}` | +| Python | SchemaType | JSON Schema | +| -------------------------------- | --------------------------------------- | --------------------------------------------------------------- | +| `str` | `SchemaPrimitive` | `{"type": "string"}` | +| `int` | `SchemaPrimitive` | `{"type": "integer"}` | +| `float` | `SchemaPrimitive` | `{"type": "number"}` | +| `bool` | `SchemaPrimitive` | `{"type": "boolean"}` | +| `Path` | `SchemaPrimitive` | `{"type": "string", "format": "uri"}` | +| `dict` (bare) | `SchemaAny` | `{"type": "object"}` | +| `dict[str, V]` | `SchemaDict` | `{"type": "object", 
"additionalProperties": V}` | +| `list` (bare) | `SchemaArray(SchemaAny)` | `{"type": "array", "items": {"type": "object"}}` | +| `list[T]` | `SchemaArray` | `{"type": "array", "items": T}` | +| `Annotated[T, cog.Opaque]` | `SchemaPrimitive(TypeAny)` | `{"type": "object"}` | | `Annotated[list[T], cog.Opaque]` | `SchemaArray(SchemaPrimitive(TypeAny))` | `{"type": "array", "items": {"type": "object"}}` | -| `BaseModel` subclass | `SchemaObject` | `{"type": "object", "properties": {...}}` | -| `Iterator[T]` | `SchemaIterator` | `{"type": "array", "items": T, "x-cog-array-type": "iterator"}` | -| `ConcatenateIterator[str]` | `SchemaConcatIterator` | Streaming token output | -| Nested types | Recursive | `dict[str, list[dict[str, int]]]` fully supported | +| `BaseModel` subclass | `SchemaObject` | `{"type": "object", "properties": {...}}` | +| `Iterator[T]` | `SchemaIterator` | `{"type": "array", "items": T, "x-cog-array-type": "iterator"}` | +| `ConcatenateIterator[str]` | `SchemaConcatIterator` | Streaming token output | +| Nested types | Recursive | `dict[str, list[dict[str, int]]]` fully supported | ### Unsupported Output Types -| Python | Error | -| --------------------------- | -------------------------------------------------------------------- | -| `Optional[T]` / `T \| None` | Predictions must succeed with a value or fail with an error | -| `Union[A, B]` | Ambiguous for downstream consumers | +| Python | Error | +| --------------------------- | -------------------------------------------------------------------------------------------------------------------------------- | +| `Optional[T]` / `T \| None` | Predictions must succeed with a value or fail with an error | +| `Union[A, B]` | Ambiguous for downstream consumers | | External package types | Cannot be statically analyzed — define as BaseModel, use .pyi stub, or mark JSON-shaped values with `Annotated[..., cog.Opaque]` | ## Cog-Specific Extensions @@ -311,12 +311,12 @@ A simplified example showing a 
multi-file predictor with structured output: ## Code References -| File | Purpose | -| ----------------------------- | -------------------------------------------------------------------- | -| `pkg/schema/schema_type.go` | `SchemaType` ADT, `ResolveSchemaType()`, `JSONSchema()` generation | -| `pkg/schema/types.go` | `PredictorInfo`, `PrimitiveType`, `FieldType`, `InputField`, imports | -| `pkg/schema/python/` | Tree-sitter Python parser and cross-file resolution | -| `pkg/schema/openapi.go` | OpenAPI document assembly from `PredictorInfo` | -| `pkg/schema/generator.go` | Top-level `Generate()`, `GenerateCombined()`, `Parser` type | -| `pkg/schema/errors.go` | Typed schema error kinds | -| `pkg/image/build.go` | Build-time schema generation entry point and schema file validation | +| File | Purpose | +| --------------------------- | -------------------------------------------------------------------- | +| `pkg/schema/schema_type.go` | `SchemaType` ADT, `ResolveSchemaType()`, `JSONSchema()` generation | +| `pkg/schema/types.go` | `PredictorInfo`, `PrimitiveType`, `FieldType`, `InputField`, imports | +| `pkg/schema/python/` | Tree-sitter Python parser and cross-file resolution | +| `pkg/schema/openapi.go` | OpenAPI document assembly from `PredictorInfo` | +| `pkg/schema/generator.go` | Top-level `Generate()`, `GenerateCombined()`, `Parser` type | +| `pkg/schema/errors.go` | Typed schema error kinds | +| `pkg/image/build.go` | Build-time schema generation entry point and schema file validation | diff --git a/architecture/03-prediction-api.md b/architecture/03-prediction-api.md index 45279a914b..057c83cee8 100644 --- a/architecture/03-prediction-api.md +++ b/architecture/03-prediction-api.md @@ -59,7 +59,7 @@ What clients send to start a prediction: | `webhook_events_filter` | array (optional) | Which events to send | | `created_at` | datetime (optional) | Client-provided timestamp | -The `input` object is validated against the `Input` schema generated from the 
predictor's `predict()` signature. Unknown fields are rejected; missing required fields raise validation errors. +The `input` object is validated against the `Input` schema generated from the runner's `run()` signature. Unknown fields are rejected; missing required fields raise validation errors. ## PredictionResponse @@ -91,7 +91,7 @@ What comes back from the API: | `status` | enum | `starting`, `processing`, `succeeded`, `canceled`, `failed` | | `input` | object | Echo of the input (for reference) | | `output` | any | **Model-specific** -- type defined by schema | -| `logs` | string | Captured stdout/stderr from predict() | +| `logs` | string | Captured stdout/stderr from run() | | `error` | string | Error message if status is `failed` | | `metrics` | object | Timing and other metrics | | `created_at` | datetime | When request was received | @@ -103,9 +103,9 @@ What comes back from the API: ```mermaid stateDiagram-v2 [*] --> starting: Request received - starting --> processing: predict() called - processing --> succeeded: predict() returns - processing --> failed: predict() raises exception + starting --> processing: run() called + processing --> succeeded: run() returns + processing --> failed: run() raises exception processing --> canceled: Cancel requested succeeded --> [*] failed --> [*] @@ -204,7 +204,7 @@ flowchart LR coerce["Type Coercion"] end - subgraph predict["predict()"] + subgraph predict["run()"] kwargs["**kwargs"] end @@ -219,13 +219,13 @@ flowchart LR 2. **Validate against schema** -- Coglet validates types, required fields, and constraints at the HTTP edge using the OpenAPI schema 3. **Download files** -- URLs in `cog.Path` fields are fetched to local temp files 4. **Coerce types** -- Strings become Paths, etc. -5. **Call predict()** -- Validated input passed as `**kwargs` +5. 
**Call run()** -- Validated input passed as `**kwargs` ### Output Handling Flow ```mermaid flowchart LR - subgraph predict["predict()"] + subgraph predict["run()"] result["Return value / yields"] end @@ -243,7 +243,7 @@ flowchart LR serialize --> output ``` -1. **Capture output** -- Return value or yielded values from predict() +1. **Capture output** -- Return value or yielded values from run() 2. **Upload files** -- `cog.Path` outputs are uploaded, replaced with URLs 3. **Serialize** -- Convert to JSON-compatible format 4. **Return** -- Place in `output` field of response @@ -255,13 +255,13 @@ Input files (cog.Path): ``` Client sends: {"input": {"image": "https://example.com/photo.jpg"}} Server downloads: /tmp/inputabc123.jpg -predict() sees: image = Path("/tmp/inputabc123.jpg") +run() sees: image = Path("/tmp/inputabc123.jpg") ``` Output files (cog.Path): ``` -predict() returns: Path("/tmp/output.png") +run() returns: Path("/tmp/output.png") Server uploads: https://storage.example.com/output-xyz.png Client receives: {"output": "https://storage.example.com/output-xyz.png"} ``` @@ -306,7 +306,7 @@ sequenceDiagram Cog-->>Client: 202 {status: "starting"} Cog->>Webhook: {status: "starting"} - Note over Cog: predict() starts + Note over Cog: run() starts Cog->>Webhook: {status: "processing"} loop Output yields @@ -342,7 +342,7 @@ Webhook delivery includes structured retry with exponential backoff and automati For models that yield output progressively: ```python -def predict(self, prompt: str) -> Iterator[str]: +def run(self, prompt: str) -> Iterator[str]: for token in generate(prompt): yield token ``` diff --git a/architecture/04-container-runtime.md b/architecture/04-container-runtime.md index b486e84fde..8758609780 100644 --- a/architecture/04-container-runtime.md +++ b/architecture/04-container-runtime.md @@ -39,7 +39,7 @@ flowchart TB subgraph worker["Worker Subprocess (Python)"] subgraph predictor["Predictor"] setup["setup() → runs once at startup"] - 
predict["predict() → handles SlotRequest#colon;#colon;Predict"] + predict["run() → handles SlotRequest#colon;#colon;Predict"] end end @@ -85,7 +85,7 @@ The `Prediction` struct is itself a state machine -- its mutation methods (`set_ - **Responsibilities**: - Load user's predictor module - Run `setup()` once at startup - - Execute `predict()` method + - Execute selected `run()` method, or legacy `predict()` method for older models - Capture stdout/stderr via ContextVar-based log routing - Send events back to parent via slot sockets @@ -110,7 +110,7 @@ stateDiagram-v2 setup --> dead: setup() raises idle --> predicting: SlotRequest#colon;#colon;Predict - predicting --> idle: predict() returns/raises + predicting --> idle: run() returns/raises idle --> dead: Shutdown / crash @@ -121,13 +121,13 @@ stateDiagram-v2 - **`setup()` runs exactly once**, before any prediction is accepted. Use it to load weights, initialize GPU contexts, and warm caches. If it raises an exception, the worker exits and health becomes `SETUP_FAILED` -- there is no retry. -- **`self` state persists across all `predict()` calls.** Storing your loaded model on `self.model` in `setup()` and using it in every `predict()` call is the intended pattern. +- **`self` state persists across all `run()` calls.** Storing your loaded model on `self.model` in `setup()` and using it in every `run()` call is the intended pattern. - **No teardown hook.** There is no `teardown()`, `cleanup()`, or `__del__` contract. When the container shuts down, the process exits. If you need cleanup (e.g., flushing a log buffer), use `atexit`. -- **`predict()` is sequential by default.** With `COG_MAX_CONCURRENCY=1` (the default), `predict()` is never called concurrently -- each call completes before the next begins. +- **`run()` is sequential by default.** With `COG_MAX_CONCURRENCY=1` (the default), `run()` is never called concurrently -- each call completes before the next begins. 
-- **With `COG_MAX_CONCURRENCY > 1`, concurrent `predict()` calls share `self`.** Async predictors run multiple coroutines on a shared asyncio event loop -- not truly parallel, but interleaved at `await` points. If your model stores mutable state on `self` that could be accessed across `await` boundaries, take care. If your model isn't safe to call concurrently, leave concurrency at 1. +- **With `COG_MAX_CONCURRENCY > 1`, concurrent `run()` calls share `self`.** Async runners run multiple coroutines on a shared asyncio event loop -- not truly parallel, but interleaved at `await` points. If your model stores mutable state on `self` that could be accessed across `await` boundaries, take care. If your model isn't safe to call concurrently, leave concurrency at 1. - **A worker crash is terminal.** If the worker process crashes (segfault, OOM kill), the runtime fails all in-flight predictions and stops accepting new ones. The HTTP server stays up (health endpoints still respond) but the container must be restarted externally -- there is no automatic worker respawn. @@ -175,15 +175,15 @@ Per-prediction data. 
Using separate sockets per slot avoids head-of-line blockin **Worker → Parent:** -| Message | Purpose | -| ---------------------------------------------- | -------------------------------------------------------------------- | -| `Log { source, data }` | Log line from predict() | -| `Output { output }` | Yielded output value (for generators/streaming) | -| `FileOutput { filename, kind, mime_type }` | File produced by predict() -- referenced by path, uploaded by parent | -| `Metric { name, value, mode }` | Custom metric (mode: `replace`, `increment`, or `append`) | -| `Done { id, output, predict_time, is_stream }` | Prediction completed successfully | -| `Failed { id, error }` | Prediction failed | -| `Cancelled { id }` | Prediction was cancelled | +| Message | Purpose | +| ---------------------------------------------- | ---------------------------------------------------------------- | +| `Log { source, data }` | Log line from run() | +| `Output { output }` | Yielded output value (for generators/streaming) | +| `FileOutput { filename, kind, mime_type }` | File produced by run() -- referenced by path, uploaded by parent | +| `Metric { name, value, mode }` | Custom metric (mode: `replace`, `increment`, or `append`) | +| `Done { id, output, predict_time, is_stream }` | Prediction completed successfully | +| `Failed { id, error }` | Prediction failed | +| `Cancelled { id }` | Prediction was cancelled | ## Health State Machine @@ -311,7 +311,7 @@ Following a single prediction from HTTP request to response: 9. **Response assembled.** The `Prediction` state machine transitions to `succeeded`, the slot permit is released, and the response is returned to the client (or delivered via webhook for async requests). -**On error:** If `predict()` raises an exception, the worker sends a `Failed` message. The prediction is marked `failed`, the slot returns to idle, and the predictor instance survives -- it handles the next request normally. 
Only a process-level crash (segfault, OOM kill) destroys the instance; see [Predictor Lifecycle](#predictor-lifecycle) for what happens then. +**On error:** If `run()` raises an exception, the worker sends a `Failed` message. The prediction is marked `failed`, the slot returns to idle, and the runner instance survives -- it handles the next request normally. Only a process-level crash (segfault, OOM kill) destroys the instance; see [Predictor Lifecycle](#predictor-lifecycle) for what happens then. ## Invocation Path @@ -319,7 +319,7 @@ How coglet gets invoked when running a Cog container: ```mermaid flowchart TB - cli["cog predict / cog exec\n(CLI)"] + cli["cog run / cog exec\n(CLI)"] launcher["python -m cog.server.http\nimport coglet\ncoglet.server.serve(predictor_ref, port=5000)"] @@ -327,7 +327,7 @@ flowchart TB direction TB axum["HTTP Server (axum) #colon;5000\n/predictions, /health-check, etc."] svc["PredictionService\n(state, webhooks, permits)"] - worker_sub["Worker subprocess (Python)\n- loads predictor_ref\n- runs setup()\n- handles predict() requests"] + worker_sub["Worker subprocess (Python)\n- loads runner ref\n- runs setup()\n- handles run() requests"] axum --> svc svc -- "Unix socket + pipes" --> worker_sub @@ -371,7 +371,7 @@ When a prediction input exceeds 6MiB, it's too large to send inline through the ## File Outputs -When predict() produces file outputs (`cog.Path`), the worker sends a `FileOutput` message with the filename and MIME type. The parent handles uploading the file (or base64-encoding it for inline responses). The `output_dir` field in the `Predict` request tells the worker where to write output files. +When run() produces file outputs (`cog.Path`), the worker sends a `FileOutput` message with the filename and MIME type. The parent handles uploading the file (or base64-encoding it for inline responses). The `output_dir` field in the `Predict` request tells the worker where to write output files. 
`FileOutputKind` distinguishes between normal file outputs (`FileType`) and oversized outputs (`Oversized`) that exceeded an inline size limit. diff --git a/architecture/05-build-system.md b/architecture/05-build-system.md index 98f4dc42a4..8a79198d75 100644 --- a/architecture/05-build-system.md +++ b/architecture/05-build-system.md @@ -1,6 +1,6 @@ # Build System -The build system transforms [Model Source](./01-model-source.md) (cog.yaml + predict.py + weights) into a production-ready OCI image containing the [Container Runtime](./04-container-runtime.md). +The build system transforms [Model Source](./01-model-source.md) (cog.yaml + run.py + weights) into a production-ready OCI image containing the [Container Runtime](./04-container-runtime.md). ## Build Flow @@ -8,7 +8,7 @@ The build system transforms [Model Source](./01-model-source.md) (cog.yaml + pre flowchart TB subgraph input["Inputs"] yaml["cog.yaml"] - code["predict.py"] + code["run.py"] weights["weights"] end @@ -230,12 +230,12 @@ After the main build, Cog: #### Image Labels -| Label | Content | -| ------------------------ | ---------------------------- | -| `run.cog.version` | Cog CLI version | -| `run.cog.config` | Serialized cog.yaml | +| Label | Content | +| ------------------------ | ------------------------------------------ | +| `run.cog.version` | Cog CLI version | +| `run.cog.config` | Serialized cog.yaml | | `run.cog.openapi_schema` | OpenAPI spec from static schema generation | -| `run.cog.pip_freeze` | Installed package versions | +| `run.cog.pip_freeze` | Installed package versions | These labels can be fetched from a remote registry or local image store (like containerd) without pulling the full image. This allows tooling - both the Cog CLI during development and production infrastructure - to inspect model metadata and make decisions about how to run a model before booting it. 
diff --git a/architecture/06-cli.md b/architecture/06-cli.md index ffb1567931..b47a67ee61 100644 --- a/architecture/06-cli.md +++ b/architecture/06-cli.md @@ -2,25 +2,25 @@ The Cog CLI is a Go binary that provides commands for the full model lifecycle: development, building, testing, and deployment. This document covers what each command does and how it connects to the systems described in previous docs. -**Important**: Model code always runs inside a container, never on the host machine. Commands like `cog predict` and `cog serve` build an image, start a container, and interact with it via the [Prediction API](./03-prediction-api.md). The CLI orchestrates this, but the model execution happens in the containerized [Container Runtime](./04-container-runtime.md). +**Important**: Model code always runs inside a container, never on the host machine. Commands like `cog run` and `cog serve` build an image, start a container, and interact with it via the [Prediction API](./03-prediction-api.md). The CLI orchestrates this, but the model execution happens in the containerized [Container Runtime](./04-container-runtime.md). 
## Commands Overview -| Command | Job To Be Done | -| ------------- | ------------------------------------- | -| `cog init` | Bootstrap a new model project | -| `cog build` | Create a container image | -| `cog predict` | Run a prediction in a container | -| `cog exec` | Run arbitrary commands in a container | -| `cog serve` | Start HTTP server in a container | -| `cog push` | Deploy to Replicate | -| `cog login` | Authenticate with Replicate | +| Command | Job To Be Done | +| ----------- | ------------------------------------- | +| `cog init` | Bootstrap a new model project | +| `cog build` | Create a container image | +| `cog run` | Run a prediction in a container | +| `cog exec` | Run arbitrary commands in a container | +| `cog serve` | Start HTTP server in a container | +| `cog push` | Deploy to Replicate | +| `cog login` | Authenticate with Replicate | ## Development Commands ### cog init -**Job**: Create a starter `cog.yaml` and `predict.py` for a new model. +**Job**: Create a starter `cog.yaml` and `run.py` for a new model. ```bash cog init @@ -29,18 +29,18 @@ cog init Creates: - `cog.yaml` with sensible defaults -- `predict.py` with a skeleton Predictor class +- `run.py` with a skeleton Runner class **Code**: `pkg/cli/init.go` --- -### cog predict +### cog run **Job**: Run a prediction in a container. ```bash -cog predict -i prompt="A photo of a cat" -i steps=50 +cog run -i prompt="A photo of a cat" -i steps=50 ``` What happens: @@ -58,7 +58,7 @@ Input types are inferred from the schema: - Files: `-i image=@photo.jpg` (uploaded to container) - URLs: `-i image=https://example.com/photo.jpg` -**Code**: `pkg/cli/predict.go` +**Code**: `pkg/cli/run.go` dispatches the public command; prediction execution is implemented in `pkg/cli/predict.go`. 
--- @@ -196,7 +196,7 @@ sequenceDiagram end CLI->>Container: POST /predictions - Container->>Container: Run predict() + Container->>Container: Run run() Container-->>CLI: Response JSON CLI->>Docker: Stop container @@ -215,7 +215,8 @@ cmd/cog/ pkg/cli/ ├── root.go # Root command, subcommand registration ├── build.go # cog build -├── predict.go # cog predict +├── run.go # cog run dispatch +├── predict.go # prediction execution and legacy cog predict ├── exec.go # cog exec ├── serve.go # cog serve ├── push.go # cog push diff --git a/crates/coglet-python/src/predictor.rs b/crates/coglet-python/src/predictor.rs index daa206d0cf..5b04b20bcc 100644 --- a/crates/coglet-python/src/predictor.rs +++ b/crates/coglet-python/src/predictor.rs @@ -2,8 +2,9 @@ use std::sync::{Arc, OnceLock}; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::PyDict; +use pyo3::types::{PyDict, PyTuple}; use coglet_core::worker::SlotSender; use coglet_core::{PredictionError, PredictionOutput, PredictionResult}; @@ -283,6 +284,7 @@ pub enum TrainKind { pub enum PredictorKind { /// Class instance with predict() method, optionally train() Class { + method_name: String, predict: PredictKind, train: TrainKind, }, @@ -339,16 +341,17 @@ impl PythonPredictor { }; PredictorKind::StandaloneFunction(predict_kind) } else { - // Class instance - detect predict() and train() methods - let (is_async, is_async_gen) = Self::detect_async(py, &instance, "predict")?; + // Class instance - detect run()/predict() and train() methods + let method_name = Self::selected_predict_method_name(py, &instance)?; + let (is_async, is_async_gen) = Self::detect_async(py, &instance, &method_name)?; let predict_kind = if is_async_gen { - tracing::info!("Detected async generator predict()"); + tracing::info!("Detected async generator {}()", method_name); PredictKind::AsyncGen } else if is_async { - tracing::info!("Detected async predict()"); + tracing::info!("Detected async {}()", method_name); 
PredictKind::Async } else { - tracing::info!("Detected sync predict()"); + tracing::info!("Detected sync {}()", method_name); PredictKind::Sync }; @@ -367,6 +370,7 @@ }; PredictorKind::Class { + method_name, predict: predict_kind, train: train_kind, } @@ -396,7 +400,9 @@ if is_function { Self::unwrap_field_info_defaults(py, &predictor.instance, "")?; } else { - Self::unwrap_field_info_defaults(py, &predictor.instance, "predict")?; + if let PredictorKind::Class { method_name, .. } = &predictor.kind { + Self::unwrap_field_info_defaults(py, &predictor.instance, method_name)?; + } if matches!(predictor.kind, PredictorKind::Class { train, .. } if train != TrainKind::None) { Self::unwrap_field_info_defaults(py, &predictor.instance, "train")?; @@ -406,6 +412,44 @@ Ok(predictor) } + fn selected_predict_method_name(py: Python<'_>, instance: &PyObject) -> PyResult<String> { + let class = instance.bind(py).getattr("__class__")?; + let mro = class.getattr("__mro__")?.cast_into::<PyTuple>()?; + let cog_predictor = py.import("cog.predictor")?; + let base_runner = cog_predictor.getattr("BaseRunner")?; + let base_predictor = cog_predictor.getattr("BasePredictor")?; + let object = py.import("builtins")?.getattr("object")?; + let callable = py.import("builtins")?.getattr("callable")?; + + let mut has_run = false; + let mut has_predict = false; + for owner in mro.iter() { + if owner.is(&base_runner) || owner.is(&base_predictor) || owner.is(&object) { + break; + } + let dict = owner.getattr("__dict__")?; + let run_value = dict.call_method1("get", ("run",))?; + if !run_value.is_none() && callable.call1((&run_value,))?.extract()? { + has_run = true; + } + let predict_value = dict.call_method1("get", ("predict",))?; + if !predict_value.is_none() && callable.call1((&predict_value,))?.extract()?
{ + has_predict = true; + } + } + + match (has_run, has_predict) { + (true, true) => Err(PyValueError::new_err( + "predictor must define either run() or predict(), not both", + )), + (true, false) => Ok("run".to_string()), + (false, true) => Ok("predict".to_string()), + (false, false) => Err(PyValueError::new_err( + "run() or predict() method not found on predictor", + )), + } + } + /// Replace FieldInfo defaults with their `.default` values on a method's signature. /// /// When users write `def predict(self, seed: int = Input(default=42, description="..."))`, @@ -603,7 +647,7 @@ pub fn predict_func<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> { let instance = self.instance.bind(py); match &self.kind { - PredictorKind::Class { .. } => instance.getattr("predict"), + PredictorKind::Class { method_name, .. } => instance.getattr(method_name), PredictorKind::StandaloneFunction(_) => Ok(instance.clone()), } } @@ -622,8 +666,12 @@ /// For standalone functions, calls the function directly. pub fn predict_raw(&self, py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult<PyObject> { let (method_name, is_async) = match &self.kind { - PredictorKind::Class { predict, .. } => ( - "predict", + PredictorKind::Class { + method_name, + predict, + .. + } => ( + method_name.as_str(), matches!(predict, PredictKind::Async | PredictKind::AsyncGen), ), PredictorKind::StandaloneFunction(predict_kind) => ( @@ -991,11 +1039,18 @@ .map_err(|e| PredictionError::InvalidInput(format_validation_error(py, &e)))?; let input_dict = prepared.dict(py); - // Call predict - returns coroutine + // Call run()/predict() - returns coroutine let instance = self.instance.bind(py); - let coro = instance - .call_method("predict", (), Some(&input_dict)) - .map_err(|e| PredictionError::Failed(format!("Failed to call predict: {}", e)))?; + let method_name = match &self.kind { + PredictorKind::Class { method_name, ..
} => method_name.as_str(), + PredictorKind::StandaloneFunction(_) => "", + }; + let coro = if method_name.is_empty() { + instance.call((), Some(&input_dict)) + } else { + instance.call_method(method_name, (), Some(&input_dict)) + } + .map_err(|e| PredictionError::Failed(format!("Failed to call predict: {}", e)))?; // For async generators, wrap to collect all values let is_async_gen = matches!( @@ -1290,3 +1345,175 @@ } } } + +#[cfg(test)] +mod tests { + use super::*; + + use std::fs; + use std::path::PathBuf; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use pyo3::types::PyList; + + static TEST_FILE_COUNTER: AtomicUsize = AtomicUsize::new(0); + + fn add_python_sdk_path(py: Python<'_>) { + py.run( + c"\
+import sys, types
+coglet = types.ModuleType('coglet')
+coglet.CancelationException = Exception
+sys.modules.setdefault('coglet', coglet)
+requests = types.ModuleType('requests')
+sys.modules.setdefault('requests', requests)
+", + None, + None, + ) + .expect("failed to install coglet test stub"); + + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let sdk_path = manifest_dir + .parent() + .and_then(|p| p.parent()) + .expect("crate should live under crates/coglet-python") + .join("python"); + let sys = py.import("sys").expect("sys should import"); + let path = sys + .getattr("path") + .expect("sys.path should exist") + .cast_into::<PyList>() + .expect("sys.path should be a list"); + path.insert(0, sdk_path.to_string_lossy().as_ref()) + .expect("failed to prepend SDK path"); + } + + fn write_predictor_source(source: &str) -> PathBuf { + let counter = TEST_FILE_COUNTER.fetch_add(1, Ordering::SeqCst); + let path = std::env::temp_dir().join(format!( + "coglet_predictor_test_{}_{}.py", + std::process::id(), + counter + )); + fs::write(&path, source).expect("failed to write test predictor"); + path + } + + fn load_predictor_source(source: &str) -> PyResult<PythonPredictor> { + pyo3::Python::initialize(); + let path = write_predictor_source(source); + 
Python::attach(|py| { + add_python_sdk_path(py); + let predictor_ref = format!("{}:Predictor", path.display()); + let result = PythonPredictor::load(py, &predictor_ref); + let _ = fs::remove_file(&path); + result + }) + } + + fn selected_predict_method_name(predictor: &PythonPredictor) -> String { + Python::attach(|py| { + predictor + .predict_func(py) + .expect("predict function should exist") + .getattr("__name__") + .expect("predict function should have __name__") + .extract() + .expect("__name__ should be a string") + }) + } + + #[test] + fn class_with_run_loads() { + let predictor = load_predictor_source( + r#" +from cog import BaseRunner + +class Predictor(BaseRunner): + def run(self) -> str: + return "ok" +"#, + ) + .expect("predictor with run should load"); + + assert_eq!(selected_predict_method_name(&predictor), "run"); + } + + #[test] + fn class_with_run_and_predict_errors() { + let err = match load_predictor_source( + r#" +from cog import BaseRunner + +class Predictor(BaseRunner): + def run(self) -> str: + return "run" + + def predict(self) -> str: + return "predict" +"#, + ) { + Ok(_) => panic!("predictor with run and predict should error"), + Err(err) => err, + }; + + let message = err.to_string(); + assert!(message.contains("run"), "unexpected error: {message}"); + assert!(message.contains("predict"), "unexpected error: {message}"); + } + + #[test] + fn inherited_user_run_loads() { + let predictor = load_predictor_source( + r#" +from cog import BaseRunner + +class Parent(BaseRunner): + def run(self) -> str: + return "ok" + +class Predictor(Parent): + pass +"#, + ) + .expect("predictor with inherited user run should load"); + + assert_eq!(selected_predict_method_name(&predictor), "run"); + } + + #[test] + fn no_user_run_or_predict_errors() { + let err = match load_predictor_source( + r#" +from cog import BaseRunner + +class Predictor(BaseRunner): + pass +"#, + ) { + Ok(_) => panic!("predictor without run or predict should error"), + Err(err) => err, + 
}; + + let message = err.to_string(); + assert!(message.contains("run"), "unexpected error: {message}"); + assert!(message.contains("predict"), "unexpected error: {message}"); + } + + #[test] + fn legacy_predict_loads_with_fallback() { + let predictor = load_predictor_source( + r#" +from cog import BaseRunner + +class Predictor(BaseRunner): + def predict(self) -> str: + return "ok" +"#, + ) + .expect("predictor with legacy predict should load"); + + assert_eq!(selected_predict_method_name(&predictor), "predict"); + } +} diff --git a/docs/cli.md b/docs/cli.md index c7bdac7ca8..9d99680cb6 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -30,7 +30,7 @@ https://github.com/replicate/cog Build a Docker image from the cog.yaml in the current directory. The generated image contains your model code, dependencies, and the Cog -runtime. It can be run locally with 'cog predict' or pushed to a registry +runtime. It can be run locally with 'cog run' or pushed to a registry with 'cog push'. ``` @@ -75,7 +75,7 @@ Diagnose and fix common issues in your Cog project. NOTE: cog doctor is experimental. Behavior and checks may change in future versions. By default, cog doctor reports problems without modifying any files. -Pass --fix to automatically apply safe fixes. +Pass --fix to automatically apply safe fixes, including migrating deprecated predict names to run names when no file collision exists. ``` cog doctor [flags] @@ -132,10 +132,10 @@ cog exec [arg...] [flags] ## `cog init` -Create a cog.yaml and predict.py in the current directory. +Create a cog.yaml and run.py in the current directory. These files provide a starting template for defining your model's environment -and prediction interface. Edit them to match your model's requirements. +and run interface. Edit them to match your model's requirements. ``` cog init [flags] @@ -175,95 +175,95 @@ cog login [flags] --token-stdin Pass login token on stdin instead of opening a browser. 
You can find your Replicate login token at https://replicate.com/auth/token ``` -## `cog predict` - -Run a prediction. +## `cog push` -If 'image' is passed, it will run the prediction on that Docker image. -It must be an image that has been built by Cog. +Build a Docker image from cog.yaml and push it to a container registry. -Otherwise, it will build the model in the current directory and run -the prediction on that. +Cog can push to any OCI-compliant registry. When pushing to Replicate's +registry (r8.im), run 'cog login' first to authenticate. ``` -cog predict [image] [flags] +cog push [IMAGE] [flags] ``` **Examples** ``` - # Run a prediction with named inputs - cog predict -i prompt="a photo of a cat" - - # Pass a file as input - cog predict -i image=@photo.jpg - - # Save output to a file - cog predict -i image=@input.jpg -o output.png - - # Pass multiple inputs - cog predict -i prompt="sunset" -i width=1024 -i height=768 + # Push to Replicate + cog push r8.im/your-username/my-model - # Run against a pre-built image - cog predict r8.im/your-username/my-model -i prompt="hello" + # Push to any OCI registry + cog push registry.example.com/your-username/model-name - # Pass inputs as JSON - echo '{"prompt": "a cat"}' | cog predict --json @- + # Push with model weights in a separate layer (Replicate only) + cog push r8.im/your-username/my-model --separate-weights ``` **Options** ``` - -e, --env stringArray Environment variables, in the form name=value -f, --file string The name of the config file. (default "cog.yaml") - --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. - -h, --help help for predict - -i, --input stringArray Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. 
-i path=@image.jpg - --json string Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-) - -o, --output string Output path + -h, --help help for push + --no-cache Do not use cache when building the image + --openapi-schema string Load OpenAPI schema from a file --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") - --setup-timeout uint32 The timeout for a container to setup (in seconds). (default 300) + --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' + --separate-weights Separate model weights from code in image layers --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") - --use-replicate-token Pass REPLICATE_API_TOKEN from local environment into the model context ``` -## `cog push` +## `cog run` -Build a Docker image from cog.yaml and push it to a container registry. +Run a prediction. -Cog can push to any OCI-compliant registry. When pushing to Replicate's -registry (r8.im), run 'cog login' first to authenticate. +If 'image' is passed, it will run the prediction on that Docker image. +It must be an image that has been built by Cog. + +Otherwise, it will build the model in the current directory and run +the prediction on that. 
``` -cog push [IMAGE] [flags] +cog run [image] [flags] ``` **Examples** ``` - # Push to Replicate - cog push r8.im/your-username/my-model + # Run a prediction with named inputs + cog run -i prompt="a photo of a cat" - # Push to any OCI registry - cog push registry.example.com/your-username/model-name + # Pass a file as input + cog run -i image=@photo.jpg - # Push with model weights in a separate layer (Replicate only) - cog push r8.im/your-username/my-model --separate-weights + # Save output to a file + cog run -i image=@input.jpg -o output.png + + # Pass multiple inputs + cog run -i prompt="sunset" -i width=1024 -i height=768 + + # Run against a pre-built image + cog run r8.im/your-username/my-model -i prompt="hello" + + # Pass inputs as JSON + echo '{"prompt": "a cat"}' | cog run --json @- ``` **Options** ``` + -e, --env stringArray Environment variables, in the form name=value -f, --file string The name of the config file. (default "cog.yaml") - -h, --help help for push - --no-cache Do not use cache when building the image - --openapi-schema string Load OpenAPI schema from a file + --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. + -h, --help help for run + -i, --input stringArray Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg + --json string Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-) + -o, --output string Output path --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") - --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' - --separate-weights Separate model weights from code in image layers + --setup-timeout uint32 The timeout for a container to setup (in seconds). 
(default 300) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") + --use-replicate-token Pass REPLICATE_API_TOKEN from local environment into the model context ``` ## `cog serve` diff --git a/docs/getting-started-own-model.md b/docs/getting-started-own-model.md index eb2bb55f83..220509c4c3 100644 --- a/docs/getting-started-own-model.md +++ b/docs/getting-started-own-model.md @@ -27,7 +27,7 @@ sudo chmod +x /usr/local/bin/cog To configure your project for use with Cog, you'll need to add two files: - [`cog.yaml`](yaml.md) defines system requirements, Python package dependencies, etc -- [`predict.py`](python.md) describes the prediction interface for your model +- [`run.py`](python.md) describes the prediction interface for your model Use the `cog init` command to generate these files in your project: @@ -76,18 +76,18 @@ With `cog.yaml`, you can also install system packages and other things. [Take a ## Define how to run predictions -The next step is to update `predict.py` to define the interface for running predictions on your model. The `predict.py` generated by `cog init` looks something like this: +The next step is to update `run.py` to define the interface for running predictions on your model. 
The `run.py` generated by `cog init` looks something like this: ```python -from cog import BasePredictor, Path, Input +from cog import BaseRunner, Path, Input import torch -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.net = torch.load("weights.pth") - def predict(self, + def run(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5) ) -> Path: @@ -98,9 +98,9 @@ class Predictor(BasePredictor): return output ``` -Edit your `predict.py` file and fill in the functions with your own model's setup and prediction code. You might need to import parts of your model from another file. +Edit your `run.py` file and fill in the functions with your own model's setup and prediction code. You might need to import parts of your model from another file. -You also need to define the inputs to your model as arguments to the `predict()` function, as demonstrated above. For each argument, you need to annotate with a type. The supported types are: +You also need to define the inputs to your model as arguments to the `run()` function, as demonstrated above. For each argument, you need to annotate with a type. The supported types are: - `str`: a string - `int`: an integer @@ -123,19 +123,19 @@ You can provide more information about the input with the `Input()` function, as There are some more advanced options you can pass, too. For more details, [take a look at the prediction interface documentation](python.md). -Next, add the line `predict: "predict.py:Predictor"` to your `cog.yaml`, so it looks something like this: +Next, add the line `run: "run.py:Runner"` to your `cog.yaml`, so it looks something like this: ```yaml build: python_version: "3.13" python_requirements: requirements.txt -predict: "predict.py:Predictor" +run: "run.py:Runner" ``` That's it! 
To test this works, try running a prediction on the model: ``` -$ cog predict -i image=@input.jpg +$ cog run -i image=@input.jpg ✓ Building Docker image from cog.yaml... Successfully built 664ef88bc1f4 ✓ Model running in Docker image 664ef88bc1f4 @@ -145,7 +145,7 @@ Written output to output.png To pass more inputs to the model, you can add more `-i` options: ``` -$ cog predict -i image=@image.jpg -i scale=2.0 +$ cog run -i image=@image.jpg -i scale=2.0 ``` In this case it is just a number, not a file, so you don't need the `@` prefix. diff --git a/docs/getting-started.md b/docs/getting-started.md index ea1252b0b5..04fd72acd3 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -91,28 +91,28 @@ Let's pretend we've trained a model. With Cog, we can define how to run predicti We need to write some code to describe how predictions are run on the model. -Save this to `predict.py`: +Save this to `run.py`: ```python import os os.environ["TORCH_HOME"] = "." import torch -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path from PIL import Image from torchvision import models WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V1 -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = models.resnet50(weights=WEIGHTS).to(self.device) self.model.eval() - def predict(self, image: Path = Input(description="Image to classify")) -> dict: + def run(self, image: Path = Input(description="Image to classify")) -> dict: """Run a single prediction on the model""" img = Image.open(image).convert("RGB") preds = self.model(WEIGHTS.transforms()(img).unsqueeze(0).to(self.device)) @@ -137,7 +137,7 @@ Then update `cog.yaml` to look like this: build: python_version: "3.13" python_requirements: requirements.txt -predict: "predict.py:Predictor" +run: "run.py:Runner" ``` 
> [!TIP] @@ -154,7 +154,7 @@ curl $IMAGE_URL > input.jpg Now, let's run the model using Cog: ```bash -cog predict -i image=@input.jpg +cog run -i image=@input.jpg ``` @@ -170,7 +170,7 @@ If you see the following output then it worked! -Note: The first time you run `cog predict`, the build process will be triggered to generate a Docker container that can run your model. The next time you run `cog predict` the pre-built container will be used. +Note: The first time you run `cog run`, the build process will be triggered to generate a Docker container that can run your model. The next time you run `cog run` the pre-built container will be used. ## Build an image @@ -183,10 +183,10 @@ cog build -t resnet ``` -You can run this image with `cog predict` by passing the filename as an argument: +You can run this image with `cog run` by passing the filename as an argument: ```bash -cog predict resnet -i image=@input.jpg +cog run resnet -i image=@input.jpg ``` diff --git a/docs/http.md b/docs/http.md index bc13d55741..2a8e5b09ea 100644 --- a/docs/http.md +++ b/docs/http.md @@ -135,7 +135,7 @@ This produces a random identifier that is 26 ASCII characters long. ## File uploads -A model's `predict` function can produce file output by yielding or returning +A model's `run` function can produce file output by yielding or returning a `cog.Path` or `cog.File` value. By default, @@ -306,13 +306,13 @@ The request body is a JSON object with the following fields: - `input`: A JSON object with the same keys as the - [arguments to the `predict()` function](python.md). + [arguments to the `run()` function](python.md). Any `File` or `Path` inputs are passed as URLs. The response body is a JSON object with the following fields: - `status`: Either `succeeded` or `failed`. -- `output`: The return value of the `predict()` function. +- `output`: The return value of the `run()` function. - `error`: If `status` is `failed`, the error message. - `metrics`: An object containing prediction metrics. 
Always includes `predict_time` (elapsed seconds). @@ -458,10 +458,10 @@ After cleanup, the exception must be re-raised using a bare `raise` statement. Failure to re-raise the exception may result in the termination of the container. ```python -from cog import BasePredictor, CancelationException, Input, Path +from cog import BaseRunner, CancelationException, Input, Path -class Predictor(BasePredictor): - def predict(self, image: Path = Input(description="Image to process")) -> Path: +class Runner(BaseRunner): + def run(self, image: Path = Input(description="Image to process")) -> Path: try: return self.process(image) except CancelationException: diff --git a/docs/llms.txt b/docs/llms.txt index b38aa99586..d492435f85 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -28,22 +28,22 @@ build: - "libglib2.0-0" python_version: "3.13" python_requirements: requirements.txt -predict: "predict.py:Predictor" +run: "run.py:Runner" ``` -Define how predictions are run on your model with `predict.py`: +Define how predictions are run on your model with `run.py`: ```python -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path import torch -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("./weights.pth") # The arguments and types the model takes as input - def predict(self, + def run(self, image: Path = Input(description="Grayscale input image") ) -> Path: """Run a single prediction on the model""" @@ -57,7 +57,7 @@ In the above we accept a path to the image as an input, and return a path to our Now, you can run predictions on this model: ```console -$ cog predict -i image=@input.jpg +$ cog run -i image=@input.jpg --> Building Docker image... --> Running Prediction... 
--> Output written to output.jpg @@ -180,7 +180,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for how to set up a development environme - [Take a look at some examples of using Cog](https://github.com/replicate/cog-examples) - [Deploy models with Cog](docs/deploy.md) - [`cog.yaml` reference](docs/yaml.md) to learn how to define your model's environment -- [Prediction interface reference](docs/python.md) to learn how the `Predictor` interface works +- [Run interface reference](docs/python.md) to learn how the `Runner` interface works - [Training interface reference](docs/training.md) to learn how to add a fine-tuning API to your model - [HTTP API reference](docs/http.md) to learn how to use the HTTP API that models serve @@ -226,7 +226,7 @@ https://github.com/replicate/cog Build a Docker image from the cog.yaml in the current directory. The generated image contains your model code, dependencies, and the Cog -runtime. It can be run locally with 'cog predict' or pushed to a registry +runtime. It can be run locally with 'cog run' or pushed to a registry with 'cog push'. ``` @@ -271,7 +271,7 @@ Diagnose and fix common issues in your Cog project. NOTE: cog doctor is experimental. Behavior and checks may change in future versions. By default, cog doctor reports problems without modifying any files. -Pass --fix to automatically apply safe fixes. +Pass --fix to automatically apply safe fixes, including migrating deprecated predict names to run names when no file collision exists. ``` cog doctor [flags] @@ -328,10 +328,10 @@ cog exec [arg...] [flags] ## `cog init` -Create a cog.yaml and predict.py in the current directory. +Create a cog.yaml and run.py in the current directory. These files provide a starting template for defining your model's environment -and prediction interface. Edit them to match your model's requirements. +and run interface. Edit them to match your model's requirements. 
``` cog init [flags] @@ -371,95 +371,95 @@ cog login [flags] --token-stdin Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token ``` -## `cog predict` - -Run a prediction. +## `cog push` -If 'image' is passed, it will run the prediction on that Docker image. -It must be an image that has been built by Cog. +Build a Docker image from cog.yaml and push it to a container registry. -Otherwise, it will build the model in the current directory and run -the prediction on that. +Cog can push to any OCI-compliant registry. When pushing to Replicate's +registry (r8.im), run 'cog login' first to authenticate. ``` -cog predict [image] [flags] +cog push [IMAGE] [flags] ``` **Examples** ``` - # Run a prediction with named inputs - cog predict -i prompt="a photo of a cat" - - # Pass a file as input - cog predict -i image=@photo.jpg - - # Save output to a file - cog predict -i image=@input.jpg -o output.png - - # Pass multiple inputs - cog predict -i prompt="sunset" -i width=1024 -i height=768 + # Push to Replicate + cog push r8.im/your-username/my-model - # Run against a pre-built image - cog predict r8.im/your-username/my-model -i prompt="hello" + # Push to any OCI registry + cog push registry.example.com/your-username/model-name - # Pass inputs as JSON - echo '{"prompt": "a cat"}' | cog predict --json @- + # Push with model weights in a separate layer (Replicate only) + cog push r8.im/your-username/my-model --separate-weights ``` **Options** ``` - -e, --env stringArray Environment variables, in the form name=value -f, --file string The name of the config file. (default "cog.yaml") - --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. - -h, --help help for predict - -i, --input stringArray Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. 
-i path=@image.jpg - --json string Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-) - -o, --output string Output path + -h, --help help for push + --no-cache Do not use cache when building the image + --openapi-schema string Load OpenAPI schema from a file --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") - --setup-timeout uint32 The timeout for a container to setup (in seconds). (default 300) + --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' + --separate-weights Separate model weights from code in image layers --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") - --use-replicate-token Pass REPLICATE_API_TOKEN from local environment into the model context ``` -## `cog push` +## `cog run` -Build a Docker image from cog.yaml and push it to a container registry. +Run a prediction. -Cog can push to any OCI-compliant registry. When pushing to Replicate's -registry (r8.im), run 'cog login' first to authenticate. +If 'image' is passed, it will run the prediction on that Docker image. +It must be an image that has been built by Cog. + +Otherwise, it will build the model in the current directory and run +the prediction on that. 
``` -cog push [IMAGE] [flags] +cog run [image] [flags] ``` **Examples** ``` - # Push to Replicate - cog push r8.im/your-username/my-model + # Run a prediction with named inputs + cog run -i prompt="a photo of a cat" - # Push to any OCI registry - cog push registry.example.com/your-username/model-name + # Pass a file as input + cog run -i image=@photo.jpg - # Push with model weights in a separate layer (Replicate only) - cog push r8.im/your-username/my-model --separate-weights + # Save output to a file + cog run -i image=@input.jpg -o output.png + + # Pass multiple inputs + cog run -i prompt="sunset" -i width=1024 -i height=768 + + # Run against a pre-built image + cog run r8.im/your-username/my-model -i prompt="hello" + + # Pass inputs as JSON + echo '{"prompt": "a cat"}' | cog run --json @- ``` **Options** ``` + -e, --env stringArray Environment variables, in the form name=value -f, --file string The name of the config file. (default "cog.yaml") - -h, --help help for push - --no-cache Do not use cache when building the image - --openapi-schema string Load OpenAPI schema from a file + --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. + -h, --help help for run + -i, --input stringArray Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg + --json string Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-) + -o, --output string Output path --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") - --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' - --separate-weights Separate model weights from code in image layers + --setup-timeout uint32 The timeout for a container to setup (in seconds). 
(default 300) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") + --use-replicate-token Pass REPLICATE_API_TOKEN from local environment into the model context ``` ## `cog serve` @@ -754,7 +754,7 @@ sudo chmod +x /usr/local/bin/cog To configure your project for use with Cog, you'll need to add two files: - [`cog.yaml`](yaml.md) defines system requirements, Python package dependencies, etc -- [`predict.py`](python.md) describes the prediction interface for your model +- [`run.py`](python.md) describes the prediction interface for your model Use the `cog init` command to generate these files in your project: @@ -803,18 +803,18 @@ With `cog.yaml`, you can also install system packages and other things. [Take a ## Define how to run predictions -The next step is to update `predict.py` to define the interface for running predictions on your model. The `predict.py` generated by `cog init` looks something like this: +The next step is to update `run.py` to define the interface for running predictions on your model. The `run.py` generated by `cog init` looks something like this: ```python -from cog import BasePredictor, Path, Input +from cog import BaseRunner, Path, Input import torch -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.net = torch.load("weights.pth") - def predict(self, + def run(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5) ) -> Path: @@ -825,9 +825,9 @@ class Predictor(BasePredictor): return output ``` -Edit your `predict.py` file and fill in the functions with your own model's setup and prediction code. 
You might need to import parts of your model from another file. +Edit your `run.py` file and fill in the functions with your own model's setup and prediction code. You might need to import parts of your model from another file. -You also need to define the inputs to your model as arguments to the `predict()` function, as demonstrated above. For each argument, you need to annotate with a type. The supported types are: +You also need to define the inputs to your model as arguments to the `run()` function, as demonstrated above. For each argument, you need to annotate with a type. The supported types are: - `str`: a string - `int`: an integer @@ -850,19 +850,19 @@ You can provide more information about the input with the `Input()` function, as There are some more advanced options you can pass, too. For more details, [take a look at the prediction interface documentation](python.md). -Next, add the line `predict: "predict.py:Predictor"` to your `cog.yaml`, so it looks something like this: +Next, add the line `run: "run.py:Runner"` to your `cog.yaml`, so it looks something like this: ```yaml build: python_version: "3.13" python_requirements: requirements.txt -predict: "predict.py:Predictor" +run: "run.py:Runner" ``` That's it! To test this works, try running a prediction on the model: ``` -$ cog predict -i image=@input.jpg +$ cog run -i image=@input.jpg ✓ Building Docker image from cog.yaml... Successfully built 664ef88bc1f4 ✓ Model running in Docker image 664ef88bc1f4 @@ -872,7 +872,7 @@ Written output to output.png To pass more inputs to the model, you can add more `-i` options: ``` -$ cog predict -i image=@image.jpg -i scale=2.0 +$ cog run -i image=@image.jpg -i scale=2.0 ``` In this case it is just a number, not a file, so you don't need the `@` prefix. @@ -995,28 +995,28 @@ Let's pretend we've trained a model. With Cog, we can define how to run predicti We need to write some code to describe how predictions are run on the model. 
-Save this to `predict.py`: +Save this to `run.py`: ```python import os os.environ["TORCH_HOME"] = "." import torch -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path from PIL import Image from torchvision import models WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V1 -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = models.resnet50(weights=WEIGHTS).to(self.device) self.model.eval() - def predict(self, image: Path = Input(description="Image to classify")) -> dict: + def run(self, image: Path = Input(description="Image to classify")) -> dict: """Run a single prediction on the model""" img = Image.open(image).convert("RGB") preds = self.model(WEIGHTS.transforms()(img).unsqueeze(0).to(self.device)) @@ -1041,7 +1041,7 @@ Then update `cog.yaml` to look like this: build: python_version: "3.13" python_requirements: requirements.txt -predict: "predict.py:Predictor" +run: "run.py:Runner" ``` > [!TIP] @@ -1058,7 +1058,7 @@ curl $IMAGE_URL > input.jpg Now, let's run the model using Cog: ```bash -cog predict -i image=@input.jpg +cog run -i image=@input.jpg ``` @@ -1074,7 +1074,7 @@ If you see the following output then it worked! -Note: The first time you run `cog predict`, the build process will be triggered to generate a Docker container that can run your model. The next time you run `cog predict` the pre-built container will be used. +Note: The first time you run `cog run`, the build process will be triggered to generate a Docker container that can run your model. The next time you run `cog run` the pre-built container will be used. 
## Build an image @@ -1087,10 +1087,10 @@ cog build -t resnet ``` -You can run this image with `cog predict` by passing the filename as an argument: +You can run this image with `cog run` by passing the filename as an argument: ```bash -cog predict resnet -i image=@input.jpg +cog run resnet -i image=@input.jpg ``` @@ -1276,7 +1276,7 @@ This produces a random identifier that is 26 ASCII characters long. ## File uploads -A model's `predict` function can produce file output by yielding or returning +A model's `run` function can produce file output by yielding or returning a `cog.Path` or `cog.File` value. By default, @@ -1447,13 +1447,13 @@ The request body is a JSON object with the following fields: - `input`: A JSON object with the same keys as the - [arguments to the `predict()` function](python.md). + [arguments to the `run()` function](python.md). Any `File` or `Path` inputs are passed as URLs. The response body is a JSON object with the following fields: - `status`: Either `succeeded` or `failed`. -- `output`: The return value of the `predict()` function. +- `output`: The return value of the `run()` function. - `error`: If `status` is `failed`, the error message. - `metrics`: An object containing prediction metrics. Always includes `predict_time` (elapsed seconds). @@ -1599,10 +1599,10 @@ After cleanup, the exception must be re-raised using a bare `raise` statement. Failure to re-raise the exception may result in the termination of the container. 
```python -from cog import BasePredictor, CancelationException, Input, Path +from cog import BaseRunner, CancelationException, Input, Path -class Predictor(BasePredictor): - def predict(self, image: Path = Input(description="Image to process")) -> Path: +class Runner(BaseRunner): + def run(self, image: Path = Input(description="Image to process")) -> Path: try: return self.process(image) except CancelationException: @@ -1642,9 +1642,9 @@ Cog can run notebooks in the environment you've defined in `cog.yaml` with the f cog exec -p 8888 jupyter lab --allow-root --ip=0.0.0.0 ``` -## Use notebook code in your predictor +## Use notebook code in your runner -You can also import a notebook into your Cog [Predictor](python.md) file. +You can also import a notebook into your Cog [Runner](python.md) file. First, export your notebook to a Python file: @@ -1652,15 +1652,15 @@ First, export your notebook to a Python file: jupyter nbconvert --to script my_notebook.ipynb # creates my_notebook.py ``` -Then import the exported Python script into your `predict.py` file. Any functions or variables defined in your notebook will be available to your predictor: +Then import the exported Python script into your `run.py` file. Any functions or variables defined in your notebook will be available to your runner: ```python -from cog import BasePredictor, Input +from cog import BaseRunner, Input import my_notebook -class Predictor(BasePredictor): - def predict(self, prompt: str = Input(description="string prompt")) -> str: +class Runner(BaseRunner): + def run(self, prompt: str = Input(description="string prompt")) -> str: output = my_notebook.do_stuff(prompt) return output ``` @@ -1714,12 +1714,12 @@ Using a secret mount allows the private registry credentials to be securely pass --- -# Prediction interface reference +# Run interface reference This document defines the API of the `cog` Python module, which is used to define the interface for running predictions on your model. 
> [!TIP] -> Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `predict.py` file that can be used as a starting point for setting up your model. +> Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `run.py` file that can be used as a starting point for setting up your model. > [!TIP] > Using a language model to help you write the code for your new Cog model? @@ -1729,10 +1729,10 @@ This document defines the API of the `cog` Python module, which is used to defin ## Contents - [Contents](#contents) -- [`BasePredictor`](#basepredictor) - - [`Predictor.setup()`](#predictorsetup) - - [`Predictor.predict(**kwargs)`](#predictorpredictkwargs) -- [`async` predictors and concurrency](#async-predictors-and-concurrency) +- [`BaseRunner`](#baserunner) + - [`Runner.setup()`](#runnersetup) + - [`Runner.run(**kwargs)`](#runnerrunkwargs) +- [`async` runners and concurrency](#async-runners-and-concurrency) - [`Input(**kwargs)`](#inputkwargs) - [Deprecating inputs](#deprecating-inputs) - [Output](#output) @@ -1763,20 +1763,20 @@ This document defines the API of the `cog` Python module, which is used to defin - [`BaseModel` field types](#basemodel-field-types) - [Type limitations](#type-limitations) -## `BasePredictor` +## `BaseRunner` -You define how Cog runs predictions on your model by defining a class that inherits from `BasePredictor`. It looks something like this: +You define how Cog runs predictions on your model by defining a class that inherits from `BaseRunner`. 
It looks something like this: ```python -from cog import BasePredictor, Path, Input +from cog import BaseRunner, Path, Input import torch -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("weights.pth") - def predict(self, + def run(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5) ) -> Path: @@ -1787,9 +1787,11 @@ class Predictor(BasePredictor): return output ``` -Your Predictor class should define two methods: `setup()` and `predict()`. +Your Runner class should define two methods: `setup()` and `run()`. + +`BasePredictor`, `Predictor`, and `predict()` still work for existing models, but they are deprecated. Cog warns when it loads or inspects those legacy names. Use `BaseRunner`, `Runner`, and `run()` for new code. -### `Predictor.setup()` +### `Runner.setup()` Prepare the model so multiple predictions run efficiently. @@ -1813,43 +1815,43 @@ While this will increase your image size and build time, it offers other advanta > When using this method, you should use the `--separate-weights` flag on `cog build` to store weights in a [separate layer](https://github.com/replicate/cog/blob/12ac02091d93beebebed037f38a0c99cd8749806/docs/getting-started.md?plain=1#L219). -### `Predictor.predict(**kwargs)` +### `Runner.run(**kwargs)` Run a single prediction. This _required_ method is where you call the model that was loaded during `setup()`, but you may also want to add pre- and post-processing code here. -The `predict()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation. +The `run()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation. 
-`predict()` can return strings, numbers, [`cog.Path`](#cogpath) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`BaseModel`](#structured-output-with-basemodel) for structured return types. See [Input and output types](#input-and-output-types) for the full list of supported types. +`run()` can return strings, numbers, [`cog.Path`](#cogpath) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`BaseModel`](#structured-output-with-basemodel) for structured return types. See [Input and output types](#input-and-output-types) for the full list of supported types. -## `async` predictors and concurrency +## `async` runners and concurrency > Added in cog 0.14.0. -You may specify your `predict()` method as `async def predict(...)`. In -addition, if you have an async `predict()` function you may also have an async +You may specify your `run()` method as `async def run(...)`. In +addition, if you have an async `run()` function you may also have an async `setup()` function: ```py -class Predictor(BasePredictor): +class Runner(BaseRunner): async def setup(self) -> None: print("async setup is also supported...") - async def predict(self) -> str: - print("async predict"); + async def run(self) -> str: + print("async run"); return "hello world"; ``` -Models that have an async `predict()` function can run predictions concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. +Models that have an async `run()` function can run predictions concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. 
## `Input(**kwargs)` -Use cog's `Input()` function to define each of the parameters in your `predict()` method: +Use cog's `Input()` function to define each of the parameters in your `run()` method: ```py -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5, ge=1.0, le=10.0) ) -> Path: @@ -1867,13 +1869,13 @@ The `Input()` function takes these keyword arguments: - `choices`: For `str` or `int` types, a list of possible values for this input. - `deprecated`: (optional) If set to `True`, marks this input as deprecated. Deprecated inputs will still be accepted, but tools and UIs may warn users that the input is deprecated and may be removed in the future. See [Deprecating inputs](#deprecating-inputs). -Each parameter of the `predict()` method must be annotated with a type like `str`, `int`, `float`, `bool`, etc. See [Input and output types](#input-and-output-types) for the full list of supported types. +Each parameter of the `run()` method must be annotated with a type like `str`, `int`, `float`, `bool`, etc. See [Input and output types](#input-and-output-types) for the full list of supported types. Using the `Input` function provides better documentation and validation constraints to the users of your model, but it is not strictly required. 
You can also specify default values for your parameters using plain Python, or omit default assignment entirely: ```py -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, prompt: str = "default prompt", # this is valid iterations: int # also valid ) -> str: @@ -1887,10 +1889,10 @@ You can mark an input as deprecated by passing `deprecated=True` to the `Input() This is useful when you want to phase out an input without breaking existing clients immediately: ```py -from cog import BasePredictor, Input +from cog import BaseRunner, Input -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, text: str = Input(description="Some deprecated text", deprecated=True), prompt: str = Input(description="Prompt for the model") ) -> str: @@ -1905,26 +1907,26 @@ Cog predictors can return a simple data type like a string, number, float, or bo Here's an example of a predictor that returns a string: ```py -from cog import BasePredictor +from cog import BaseRunner -class Predictor(BasePredictor): - def predict(self) -> str: +class Runner(BaseRunner): + def run(self) -> str: return "hello" ``` ### Returning an object -To return a complex object with multiple values, define an `Output` object with multiple fields to return from your `predict()` method: +To return a complex object with multiple values, define an `Output` object with multiple fields to return from your `run()` method: ```py -from cog import BasePredictor, BaseModel, File +from cog import BaseRunner, BaseModel, File class Output(BaseModel): file: File text: str -class Predictor(BasePredictor): - def predict(self) -> Output: +class Runner(BaseRunner): + def run(self) -> Output: return Output(text="hello", file=io.StringIO("hello")) ``` @@ -1932,13 +1934,13 @@ Each of the output object's properties must be one of the supported output types ### Returning a list -The `predict()` method can return a list of any of the supported output types. 
Here's an example that outputs multiple files: +The `run()` method can return a list of any of the supported output types. Here's an example that outputs multiple files: ```py -from cog import BasePredictor, Path +from cog import BaseRunner, Path -class Predictor(BasePredictor): - def predict(self) -> list[Path]: +class Runner(BaseRunner): + def run(self) -> list[Path]: predictions = ["foo", "bar", "baz"] output = [] for i, prediction in enumerate(predictions): @@ -1956,15 +1958,15 @@ Files are named in the format `output..`, e.g. `output.0.txt`, To conditionally omit properties from the Output object, define them using `typing.Optional`: ```py -from cog import BaseModel, BasePredictor, Path +from cog import BaseModel, BaseRunner, Path from typing import Optional class Output(BaseModel): score: Optional[float] file: Optional[Path] -class Predictor(BasePredictor): - def predict(self) -> Output: +class Runner(BaseRunner): + def run(self) -> Output: if condition: return Output(score=1.5) else: @@ -1973,30 +1975,30 @@ class Predictor(BasePredictor): ### Streaming output -Cog models can stream output as the `predict()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. +Cog models can stream output as the `run()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. -To support streaming output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. +To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. 
The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[<type>]` where `<type>` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`.

```py
-from cog import BasePredictor, Path
+from cog import BaseRunner, Path
 from typing import Iterator

-class Predictor(BasePredictor):
-    def predict(self) -> Iterator[Path]:
+class Runner(BaseRunner):
+    def run(self) -> Iterator[Path]:
         done = False
         while not done:
             output_path, done = do_stuff()
             yield Path(output_path)
 ```

-If you have an [async `predict()` method](#async-predictors-and-concurrency), use `AsyncIterator` from the `typing` module:
+If you have an [async `run()` method](#async-runners-and-concurrency), use `AsyncIterator` from the `typing` module:

 ```py
 from typing import AsyncIterator
-from cog import BasePredictor, Path
+from cog import BaseRunner, Path

-class Predictor(BasePredictor):
-    async def predict(self) -> AsyncIterator[Path]:
+class Runner(BaseRunner):
+    async def run(self) -> AsyncIterator[Path]:
         done = False
         while not done:
             output_path, done = do_stuff()
@@ -2006,22 +2008,22 @@ class Predictor(BasePredictor):

 If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings.
```py -from cog import BasePredictor, Path, ConcatenateIterator +from cog import BaseRunner, Path, ConcatenateIterator -class Predictor(BasePredictor): - def predict(self) -> ConcatenateIterator[str]: +class Runner(BaseRunner): + def run(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: yield token + " " ``` -Or for async `predict()` methods, use `AsyncConcatenateIterator`: +Or for async `run()` methods, use `AsyncConcatenateIterator`: ```py -from cog import BasePredictor, Path, AsyncConcatenateIterator +from cog import BaseRunner, Path, AsyncConcatenateIterator -class Predictor(BasePredictor): - async def predict(self) -> AsyncConcatenateIterator[str]: +class Runner(BaseRunner): + async def run(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: yield token + " " @@ -2029,17 +2031,17 @@ class Predictor(BasePredictor): ## Metrics -You can record custom metrics from your `predict()` function to track model-specific data like token counts, timing breakdowns, or confidence scores. Metrics are included in the prediction response alongside the output. +You can record custom metrics from your `run()` function to track model-specific data like token counts, timing breakdowns, or confidence scores. Metrics are included in the prediction response alongside the output. 
### Recording metrics -Use `self.record_metric()` inside your `predict()` method: +Use `self.record_metric()` inside your `run()` method: ```python -from cog import BasePredictor +from cog import BaseRunner -class Predictor(BasePredictor): - def predict(self, prompt: str) -> str: +class Runner(BaseRunner): + def run(self, prompt: str) -> str: self.record_metric("temperature", 0.7) self.record_metric("token_count", 42) @@ -2155,12 +2157,12 @@ Outside an active prediction, `self.record_metric()` and `self.scope` are silent ## Cancellation -When a prediction is canceled (via the [cancel HTTP endpoint](http.md#post-predictionsprediction_idcancel) or a dropped connection), the Cog runtime interrupts the running `predict()` function. The exception raised depends on whether the predictor is sync or async: +When a prediction is canceled (via the [cancel HTTP endpoint](http.md#post-predictionsprediction_idcancel) or a dropped connection), the Cog runtime interrupts the running `run()` function. The exception raised depends on whether the runner is sync or async: -| Predictor type | Exception raised | -| --------------------------- | ------------------------ | -| Sync (`def predict`) | `CancelationException` | -| Async (`async def predict`) | `asyncio.CancelledError` | +| Runner type | Exception raised | +| ----------------------- | ------------------------ | +| Sync (`def run`) | `CancelationException` | +| Async (`async def run`) | `asyncio.CancelledError` | ### `CancelationException` @@ -2173,10 +2175,10 @@ from cog import CancelationException You do **not** need to handle this exception in normal predictor code — the runtime manages cancellation automatically. 
However, if you need to run cleanup logic when a prediction is cancelled, you can catch it explicitly: ```python -from cog import BasePredictor, CancelationException, Path +from cog import BaseRunner, CancelationException, Path -class Predictor(BasePredictor): - def predict(self, image: Path) -> Path: +class Runner(BaseRunner): + def run(self, image: Path) -> Path: try: return self.process(image) except CancelationException: @@ -2196,7 +2198,7 @@ For **async** predictors, cancellation follows standard Python async conventions ## Input and output types -Each parameter of the `predict()` method must be annotated with a type. The method's return type must also be annotated. +Each parameter of the `run()` method must be annotated with a type. The method's return type must also be annotated. ### Primitive types @@ -2224,10 +2226,10 @@ This example takes an input file, resizes it, and returns the resized image: ```python import tempfile -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path -class Predictor(BasePredictor): - def predict(self, image: Path = Input(description="Image to enlarge")) -> Path: +class Runner(BaseRunner): + def run(self, image: Path = Input(description="Image to enlarge")) -> Path: upscaled_image = do_some_processing(image) # To output cog.Path objects the file needs to exist, so create a temporary file first. @@ -2245,11 +2247,11 @@ class Predictor(BasePredictor): `cog.File` represents a _file handle_. For models that return a `cog.File` object, the prediction output returned by Cog's built-in HTTP server will be a URL. 
```python -from cog import BasePredictor, File, Input +from cog import BaseRunner, File, Input from PIL import Image -class Predictor(BasePredictor): - def predict(self, source_image: File = Input(description="Image to enlarge")) -> File: +class Runner(BaseRunner): + def run(self, source_image: File = Input(description="Image to enlarge")) -> File: pillow_img = Image.open(source_image) upscaled_image = do_some_processing(pillow_img) return File(upscaled_image) @@ -2262,10 +2264,10 @@ class Predictor(BasePredictor): `cog.Secret` redacts its contents in string representations to prevent accidental disclosure. Access the underlying value with `get_secret_value()`. ```python -from cog import BasePredictor, Secret +from cog import BaseRunner, Secret -class Predictor(BasePredictor): - def predict(self, api_token: Secret) -> None: +class Runner(BaseRunner): + def run(self, api_token: Secret) -> None: # Prints '**********' print(api_token) @@ -2299,10 +2301,10 @@ Use `Optional[T]` or `T | None` (Python 3.10+) to mark an input as optional. 
Opt ```python from typing import Optional -from cog import BasePredictor, Input +from cog import BaseRunner, Input -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, prompt: Optional[str] = Input(description="Input prompt"), seed: int | None = Input(description="Random seed", default=None), ) -> str: @@ -2315,11 +2317,11 @@ Prefer `Optional[T]` or `T | None` over `str = Input(default=None)` for inputs t ```python # Bad: type annotation says str but value can be None -def predict(self, prompt: str = Input(default=None)) -> str: +def run(self, prompt: str = Input(default=None)) -> str: return "hello" + prompt # TypeError at runtime if prompt is None # Good: type annotation matches actual behavior -def predict(self, prompt: Optional[str] = Input(description="prompt")) -> str: +def run(self, prompt: Optional[str] = Input(description="prompt")) -> str: if prompt is None: return "hello" return "hello " + prompt @@ -2335,10 +2337,10 @@ Use `list[T]` or `List[T]` to accept or return a list of values. 
`T` can be a su **As an input type:** ```py -from cog import BasePredictor, Path +from cog import BaseRunner, Path -class Predictor(BasePredictor): - def predict(self, paths: list[Path]) -> str: +class Runner(BaseRunner): + def run(self, paths: list[Path]) -> str: output_parts = [] for path in paths: with open(path) as f: @@ -2346,21 +2348,21 @@ class Predictor(BasePredictor): return "".join(output_parts) ``` -With `cog predict`, repeat the input name to pass multiple values: +With `cog run`, repeat the input name to pass multiple values: ```bash $ echo test1 > 1.txt $ echo test2 > 2.txt -$ cog predict -i paths=@1.txt -i paths=@2.txt +$ cog run -i paths=@1.txt -i paths=@2.txt ``` **As an output type:** ```py -from cog import BasePredictor, Path +from cog import BaseRunner, Path -class Predictor(BasePredictor): - def predict(self) -> list[Path]: +class Runner(BaseRunner): + def run(self) -> list[Path]: predictions = ["foo", "bar", "baz"] output = [] for i, prediction in enumerate(predictions): @@ -2378,10 +2380,10 @@ Files are named in the format `output..`, e.g. `output.0.txt`, Use `dict` to accept or return an opaque JSON object. The value is passed through as-is without type validation. ```python -from cog import BasePredictor, Input +from cog import BaseRunner, Input -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, params: dict = Input(description="Arbitrary JSON parameters"), ) -> dict: return {"greeting": "hello", "params": params} @@ -2392,19 +2394,19 @@ class Predictor(BasePredictor): #### `cog.Opaque` -Cog statically analyzes `predict()` type annotations to generate schemas. Some third-party package types, such as vLLM `TypedDict` definitions, may not be visible to that static analyzer even though they represent JSON-shaped object values at runtime. +Cog statically analyzes `run()` type annotations to generate schemas. 
Some third-party package types, such as vLLM `TypedDict` definitions, may not be visible to that static analyzer even though they represent JSON-shaped object values at runtime. Use `typing.Annotated` with `cog.Opaque` when you want Cog to accept or return those third-party object values without inspecting their fields: ```python from typing import Annotated -from cog import BasePredictor, Opaque +from cog import BaseRunner, Opaque from vllm.entrypoints.chat_utils import CustomChatCompletionMessageParam -class Predictor(BasePredictor): - def predict( +class Runner(BaseRunner): + def run( self, messages: Annotated[list[CustomChatCompletionMessageParam], Opaque], ) -> str: @@ -2425,15 +2427,15 @@ To return a complex object with multiple typed fields, define a class that inher ```python from typing import Optional -from cog import BasePredictor, BaseModel, Path +from cog import BaseRunner, BaseModel, Path class Output(BaseModel): text: str confidence: float image: Optional[Path] -class Predictor(BasePredictor): - def predict(self, prompt: str) -> Output: +class Runner(BaseRunner): + def run(self, prompt: str) -> Output: result = self.model.generate(prompt) return Output( text=result.text, @@ -2459,15 +2461,15 @@ If you already use Pydantic v2 in your model, you can use a Pydantic `BaseModel` ```python from pydantic import BaseModel as PydanticBaseModel -from cog import BasePredictor +from cog import BaseRunner class Result(PydanticBaseModel): name: str score: float tags: list[str] -class Predictor(BasePredictor): - def predict(self, prompt: str) -> Result: +class Runner(BaseRunner): + def run(self, prompt: str) -> Result: return Result(name="example", score=0.95, tags=["fast", "accurate"]) ``` @@ -2506,7 +2508,7 @@ Cog's training API allows you to define a fine-tuning interface for an existing ## How it works -If you've used Cog before, you've probably seen the [Predictor](./python.md) class, which defines the interface for creating predictions against your model. 
Cog's training API works similarly: You define a Python function that describes the inputs and outputs of the training process. The inputs are things like training data, epochs, batch size, seed, etc. The output is typically a file with the fine-tuned weights. +If you've used Cog before, you've probably seen the [Runner](./python.md) class, which defines the interface for creating predictions against your model. Cog's training API works similarly: You define a Python function that describes the inputs and outputs of the training process. The inputs are things like training data, epochs, batch size, seed, etc. The output is typically a file with the fine-tuned weights. `cog.yaml`: @@ -2519,7 +2521,7 @@ train: "train.py:train" `train.py`: ```python -from cog import BasePredictor, File +from cog import File import io def train(param: str) -> File: @@ -2536,7 +2538,7 @@ $ cat weights hello train ``` -You can also use classes if you want to run many model trainings and save on setup time. This works the same way as the [Predictor](./python.md) class with the only difference being the `train` method. +You can also use classes if you want to run many model trainings and save on setup time. This works the same way as the [Runner](./python.md) class with the only difference being the `train` method. 
`cog.yaml`: @@ -2549,7 +2551,7 @@ train: "train.py:Trainer" `train.py`: ```python -from cog import BasePredictor, File +from cog import File import io class Trainer: @@ -2619,10 +2621,10 @@ def train( ## Testing -If you are doing development of a Cog model like Llama or SDXL, you can test that the fine-tuned code path works before pushing by specifying a `COG_WEIGHTS` environment variable when running `predict`: +If you are doing development of a Cog model like Llama or SDXL, you can test that the fine-tuned code path works before pushing by specifying a `COG_WEIGHTS` environment variable when running `cog run`: ```console -cog predict -e COG_WEIGHTS=https://replicate.delivery/pbxt/xyz/weights.tar -i prompt="a photo of TOK" +cog run -e COG_WEIGHTS=https://replicate.delivery/pbxt/xyz/weights.tar -i prompt="a photo of TOK" ``` @@ -2805,7 +2807,7 @@ cog --version # should output the cog version number. Finally, make sure it works. Let's try running `afiaka87/glid-3-xl` locally: ```bash -cog predict 'r8.im/afiaka87/glid-3-xl' -i prompt="a fresh avocado floating in the water" -o prediction.json +cog run 'r8.im/afiaka87/glid-3-xl' -i prompt="a fresh avocado floating in the water" -o prediction.json ``` ![Output from a running cog prediction in Windows Terminal](images/cog_model_output.png) @@ -2853,7 +2855,7 @@ explorer.exe prediction.png `cog.yaml` defines how to build a Docker image and how to run predictions on your model inside that image. -It has three keys: [`build`](#build), [`image`](#image), and [`predict`](#predict). It looks a bit like this: +It has three keys: [`build`](#build), [`image`](#image), and [`run`](#run). It looks a bit like this: ```yaml build: @@ -2862,7 +2864,7 @@ build: system_packages: - "ffmpeg" - "git" -predict: "predict.py:Predictor" +run: "run.py:Runner" ``` Tip: Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `cog.yaml` file that can be used as a starting point for setting up your model.
@@ -2895,7 +2897,7 @@ build: gpu: true ``` -When you use `cog exec` or `cog predict`, Cog will automatically pass the `--gpus=all` flag to Docker. When you run a Docker image built with Cog, you'll need to pass this option to `docker run`. +When you use `cog exec` or `cog run`, Cog will automatically pass the `--gpus=all` flag to Docker. When you run a Docker image built with Cog, you'll need to pass this option to `docker run`. ### `python_requirements` @@ -3045,7 +3047,7 @@ This stanza describes the concurrency capabilities of the model. It has one opti ### `max` -The maximum number of concurrent predictions the model can process. If this is set, the model must specify an [async `predict()` method](python.md#async-predictors-and-concurrency). +The maximum number of concurrent predictions the model can process. If this is set, the model must specify an [async `run()` method](python.md#async-runners-and-concurrency). For example: @@ -3072,9 +3074,23 @@ If you set this, then you can run `cog push` without specifying the model name. If you specify an image name argument when pushing (like `cog push your-username/custom-model-name`), the argument will be used and the value of `image` in cog.yaml will be ignored. +## `run` + +The pointer to the `Runner` object in your code, which defines how predictions are run on your model. + +For example: + +```yaml +run: "run.py:Runner" +``` + +`predict:` is still accepted for existing projects, but it is deprecated. New projects should use `run:`. + +See [the Python API documentation for more information](python.md). + ## `predict` -The pointer to the `Predictor` object in your code, which defines how predictions are run on your model. +Deprecated compatibility field for [`run`](#run). Existing projects can continue using it, but Cog will warn when it is used, and `cog doctor --fix` can migrate most projects to `run:` automatically.
For example: diff --git a/docs/notebooks.md b/docs/notebooks.md index 2f1f4a496c..c93fc4ee14 100644 --- a/docs/notebooks.md +++ b/docs/notebooks.md @@ -27,9 +27,9 @@ Cog can run notebooks in the environment you've defined in `cog.yaml` with the f cog exec -p 8888 jupyter lab --allow-root --ip=0.0.0.0 ``` -## Use notebook code in your predictor +## Use notebook code in your runner -You can also import a notebook into your Cog [Predictor](python.md) file. +You can also import a notebook into your Cog [Runner](python.md) file. First, export your notebook to a Python file: @@ -37,15 +37,15 @@ First, export your notebook to a Python file: jupyter nbconvert --to script my_notebook.ipynb # creates my_notebook.py ``` -Then import the exported Python script into your `predict.py` file. Any functions or variables defined in your notebook will be available to your predictor: +Then import the exported Python script into your `run.py` file. Any functions or variables defined in your notebook will be available to your runner: ```python -from cog import BasePredictor, Input +from cog import BaseRunner, Input import my_notebook -class Predictor(BasePredictor): - def predict(self, prompt: str = Input(description="string prompt")) -> str: +class Runner(BaseRunner): + def run(self, prompt: str = Input(description="string prompt")) -> str: output = my_notebook.do_stuff(prompt) return output ``` diff --git a/docs/python.md b/docs/python.md index 06d2a8adc3..99889ee023 100644 --- a/docs/python.md +++ b/docs/python.md @@ -1,9 +1,9 @@ -# Prediction interface reference +# Run interface reference This document defines the API of the `cog` Python module, which is used to define the interface for running predictions on your model. > [!TIP] -> Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `predict.py` file that can be used as a starting point for setting up your model. 
+> Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `run.py` file that can be used as a starting point for setting up your model. > [!TIP] > Using a language model to help you write the code for your new Cog model? @@ -13,10 +13,10 @@ This document defines the API of the `cog` Python module, which is used to defin ## Contents - [Contents](#contents) -- [`BasePredictor`](#basepredictor) - - [`Predictor.setup()`](#predictorsetup) - - [`Predictor.predict(**kwargs)`](#predictorpredictkwargs) -- [`async` predictors and concurrency](#async-predictors-and-concurrency) +- [`BaseRunner`](#baserunner) + - [`Runner.setup()`](#runnersetup) + - [`Runner.run(**kwargs)`](#runnerrunkwargs) +- [`async` runners and concurrency](#async-runners-and-concurrency) - [`Input(**kwargs)`](#inputkwargs) - [Deprecating inputs](#deprecating-inputs) - [Output](#output) @@ -47,20 +47,20 @@ This document defines the API of the `cog` Python module, which is used to defin - [`BaseModel` field types](#basemodel-field-types) - [Type limitations](#type-limitations) -## `BasePredictor` +## `BaseRunner` -You define how Cog runs predictions on your model by defining a class that inherits from `BasePredictor`. It looks something like this: +You define how Cog runs predictions on your model by defining a class that inherits from `BaseRunner`. 
It looks something like this: ```python -from cog import BasePredictor, Path, Input +from cog import BaseRunner, Path, Input import torch -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("weights.pth") - def predict(self, + def run(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5) ) -> Path: @@ -71,9 +71,11 @@ class Predictor(BasePredictor): return output ``` -Your Predictor class should define two methods: `setup()` and `predict()`. +Your Runner class should define two methods: `setup()` and `run()`. -### `Predictor.setup()` +`BasePredictor`, `Predictor`, and `predict()` still work for existing models, but they are deprecated. Cog warns when it loads or inspects those legacy names. Use `BaseRunner`, `Runner`, and `run()` for new code. + +### `Runner.setup()` Prepare the model so multiple predictions run efficiently. @@ -97,43 +99,43 @@ While this will increase your image size and build time, it offers other advanta > When using this method, you should use the `--separate-weights` flag on `cog build` to store weights in a [separate layer](https://github.com/replicate/cog/blob/12ac02091d93beebebed037f38a0c99cd8749806/docs/getting-started.md?plain=1#L219). -### `Predictor.predict(**kwargs)` +### `Runner.run(**kwargs)` Run a single prediction. This _required_ method is where you call the model that was loaded during `setup()`, but you may also want to add pre- and post-processing code here. -The `predict()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation. +The `run()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation. 
-`predict()` can return strings, numbers, [`cog.Path`](#cogpath) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`BaseModel`](#structured-output-with-basemodel) for structured return types. See [Input and output types](#input-and-output-types) for the full list of supported types. +`run()` can return strings, numbers, [`cog.Path`](#cogpath) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`BaseModel`](#structured-output-with-basemodel) for structured return types. See [Input and output types](#input-and-output-types) for the full list of supported types. -## `async` predictors and concurrency +## `async` runners and concurrency > Added in cog 0.14.0. -You may specify your `predict()` method as `async def predict(...)`. In -addition, if you have an async `predict()` function you may also have an async +You may specify your `run()` method as `async def run(...)`. In +addition, if you have an async `run()` function you may also have an async `setup()` function: ```py -class Predictor(BasePredictor): +class Runner(BaseRunner): async def setup(self) -> None: print("async setup is also supported...") - async def predict(self) -> str: - print("async predict"); + async def run(self) -> str: + print("async run"); return "hello world"; ``` -Models that have an async `predict()` function can run predictions concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. +Models that have an async `run()` function can run predictions concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. 
## `Input(**kwargs)` -Use cog's `Input()` function to define each of the parameters in your `predict()` method: +Use cog's `Input()` function to define each of the parameters in your `run()` method: ```py -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5, ge=1.0, le=10.0) ) -> Path: @@ -151,13 +153,13 @@ The `Input()` function takes these keyword arguments: - `choices`: For `str` or `int` types, a list of possible values for this input. - `deprecated`: (optional) If set to `True`, marks this input as deprecated. Deprecated inputs will still be accepted, but tools and UIs may warn users that the input is deprecated and may be removed in the future. See [Deprecating inputs](#deprecating-inputs). -Each parameter of the `predict()` method must be annotated with a type like `str`, `int`, `float`, `bool`, etc. See [Input and output types](#input-and-output-types) for the full list of supported types. +Each parameter of the `run()` method must be annotated with a type like `str`, `int`, `float`, `bool`, etc. See [Input and output types](#input-and-output-types) for the full list of supported types. Using the `Input` function provides better documentation and validation constraints to the users of your model, but it is not strictly required. 
You can also specify default values for your parameters using plain Python, or omit default assignment entirely: ```py -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, prompt: str = "default prompt", # this is valid iterations: int # also valid ) -> str: @@ -171,10 +173,10 @@ You can mark an input as deprecated by passing `deprecated=True` to the `Input() This is useful when you want to phase out an input without breaking existing clients immediately: ```py -from cog import BasePredictor, Input +from cog import BaseRunner, Input -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, text: str = Input(description="Some deprecated text", deprecated=True), prompt: str = Input(description="Prompt for the model") ) -> str: @@ -189,26 +191,26 @@ Cog predictors can return a simple data type like a string, number, float, or bo Here's an example of a predictor that returns a string: ```py -from cog import BasePredictor +from cog import BaseRunner -class Predictor(BasePredictor): - def predict(self) -> str: +class Runner(BaseRunner): + def run(self) -> str: return "hello" ``` ### Returning an object -To return a complex object with multiple values, define an `Output` object with multiple fields to return from your `predict()` method: +To return a complex object with multiple values, define an `Output` object with multiple fields to return from your `run()` method: ```py -from cog import BasePredictor, BaseModel, File +from cog import BaseRunner, BaseModel, File class Output(BaseModel): file: File text: str -class Predictor(BasePredictor): - def predict(self) -> Output: +class Runner(BaseRunner): + def run(self) -> Output: return Output(text="hello", file=io.StringIO("hello")) ``` @@ -216,13 +218,13 @@ Each of the output object's properties must be one of the supported output types ### Returning a list -The `predict()` method can return a list of any of the supported output types. 
Here's an example that outputs multiple files: +The `run()` method can return a list of any of the supported output types. Here's an example that outputs multiple files: ```py -from cog import BasePredictor, Path +from cog import BaseRunner, Path -class Predictor(BasePredictor): - def predict(self) -> list[Path]: +class Runner(BaseRunner): + def run(self) -> list[Path]: predictions = ["foo", "bar", "baz"] output = [] for i, prediction in enumerate(predictions): @@ -240,15 +242,15 @@ Files are named in the format `output..`, e.g. `output.0.txt`, To conditionally omit properties from the Output object, define them using `typing.Optional`: ```py -from cog import BaseModel, BasePredictor, Path +from cog import BaseModel, BaseRunner, Path from typing import Optional class Output(BaseModel): score: Optional[float] file: Optional[Path] -class Predictor(BasePredictor): - def predict(self) -> Output: +class Runner(BaseRunner): + def run(self) -> Output: if condition: return Output(score=1.5) else: @@ -257,30 +259,30 @@ class Predictor(BasePredictor): ### Streaming output -Cog models can stream output as the `predict()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. +Cog models can stream output as the `run()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. -To support streaming output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. +To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. 
The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. ```py -from cog import BasePredictor, Path +from cog import BaseRunner, Path from typing import Iterator -class Predictor(BasePredictor): - def predict(self) -> Iterator[Path]: +class Runner(BaseRunner): + def run(self) -> Iterator[Path]: done = False while not done: output_path, done = do_stuff() yield Path(output_path) ``` -If you have an [async `predict()` method](#async-predictors-and-concurrency), use `AsyncIterator` from the `typing` module: +If you have an [async `run()` method](#async-runners-and-concurrency), use `AsyncIterator` from the `typing` module: ```py from typing import AsyncIterator -from cog import BasePredictor, Path +from cog import BaseRunner, Path -class Predictor(BasePredictor): - async def predict(self) -> AsyncIterator[Path]: +class Runner(BaseRunner): + async def run(self) -> AsyncIterator[Path]: done = False while not done: output_path, done = do_stuff() @@ -290,22 +292,22 @@ class Predictor(BasePredictor): If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings. 
```py -from cog import BasePredictor, Path, ConcatenateIterator +from cog import BaseRunner, Path, ConcatenateIterator -class Predictor(BasePredictor): - def predict(self) -> ConcatenateIterator[str]: +class Runner(BaseRunner): + def run(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: yield token + " " ``` -Or for async `predict()` methods, use `AsyncConcatenateIterator`: +Or for async `run()` methods, use `AsyncConcatenateIterator`: ```py -from cog import BasePredictor, Path, AsyncConcatenateIterator +from cog import BaseRunner, Path, AsyncConcatenateIterator -class Predictor(BasePredictor): - async def predict(self) -> AsyncConcatenateIterator[str]: +class Runner(BaseRunner): + async def run(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: yield token + " " @@ -313,17 +315,17 @@ class Predictor(BasePredictor): ## Metrics -You can record custom metrics from your `predict()` function to track model-specific data like token counts, timing breakdowns, or confidence scores. Metrics are included in the prediction response alongside the output. +You can record custom metrics from your `run()` function to track model-specific data like token counts, timing breakdowns, or confidence scores. Metrics are included in the prediction response alongside the output. 
### Recording metrics -Use `self.record_metric()` inside your `predict()` method: +Use `self.record_metric()` inside your `run()` method: ```python -from cog import BasePredictor +from cog import BaseRunner -class Predictor(BasePredictor): - def predict(self, prompt: str) -> str: +class Runner(BaseRunner): + def run(self, prompt: str) -> str: self.record_metric("temperature", 0.7) self.record_metric("token_count", 42) @@ -439,12 +441,12 @@ Outside an active prediction, `self.record_metric()` and `self.scope` are silent ## Cancellation -When a prediction is canceled (via the [cancel HTTP endpoint](http.md#post-predictionsprediction_idcancel) or a dropped connection), the Cog runtime interrupts the running `predict()` function. The exception raised depends on whether the predictor is sync or async: +When a prediction is canceled (via the [cancel HTTP endpoint](http.md#post-predictionsprediction_idcancel) or a dropped connection), the Cog runtime interrupts the running `run()` function. The exception raised depends on whether the runner is sync or async: -| Predictor type | Exception raised | -| --------------------------- | ------------------------ | -| Sync (`def predict`) | `CancelationException` | -| Async (`async def predict`) | `asyncio.CancelledError` | +| Runner type | Exception raised | +| ----------------------- | ------------------------ | +| Sync (`def run`) | `CancelationException` | +| Async (`async def run`) | `asyncio.CancelledError` | ### `CancelationException` @@ -457,10 +459,10 @@ from cog import CancelationException You do **not** need to handle this exception in normal predictor code — the runtime manages cancellation automatically. 
However, if you need to run cleanup logic when a prediction is cancelled, you can catch it explicitly: ```python -from cog import BasePredictor, CancelationException, Path +from cog import BaseRunner, CancelationException, Path -class Predictor(BasePredictor): - def predict(self, image: Path) -> Path: +class Runner(BaseRunner): + def run(self, image: Path) -> Path: try: return self.process(image) except CancelationException: @@ -480,7 +482,7 @@ For **async** predictors, cancellation follows standard Python async conventions ## Input and output types -Each parameter of the `predict()` method must be annotated with a type. The method's return type must also be annotated. +Each parameter of the `run()` method must be annotated with a type. The method's return type must also be annotated. ### Primitive types @@ -508,10 +510,10 @@ This example takes an input file, resizes it, and returns the resized image: ```python import tempfile -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path -class Predictor(BasePredictor): - def predict(self, image: Path = Input(description="Image to enlarge")) -> Path: +class Runner(BaseRunner): + def run(self, image: Path = Input(description="Image to enlarge")) -> Path: upscaled_image = do_some_processing(image) # To output cog.Path objects the file needs to exist, so create a temporary file first. @@ -529,11 +531,11 @@ class Predictor(BasePredictor): `cog.File` represents a _file handle_. For models that return a `cog.File` object, the prediction output returned by Cog's built-in HTTP server will be a URL. 
```python -from cog import BasePredictor, File, Input +from cog import BaseRunner, File, Input from PIL import Image -class Predictor(BasePredictor): - def predict(self, source_image: File = Input(description="Image to enlarge")) -> File: +class Runner(BaseRunner): + def run(self, source_image: File = Input(description="Image to enlarge")) -> File: pillow_img = Image.open(source_image) upscaled_image = do_some_processing(pillow_img) return File(upscaled_image) @@ -546,10 +548,10 @@ class Predictor(BasePredictor): `cog.Secret` redacts its contents in string representations to prevent accidental disclosure. Access the underlying value with `get_secret_value()`. ```python -from cog import BasePredictor, Secret +from cog import BaseRunner, Secret -class Predictor(BasePredictor): - def predict(self, api_token: Secret) -> None: +class Runner(BaseRunner): + def run(self, api_token: Secret) -> None: # Prints '**********' print(api_token) @@ -583,10 +585,10 @@ Use `Optional[T]` or `T | None` (Python 3.10+) to mark an input as optional. 
Opt ```python from typing import Optional -from cog import BasePredictor, Input +from cog import BaseRunner, Input -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, prompt: Optional[str] = Input(description="Input prompt"), seed: int | None = Input(description="Random seed", default=None), ) -> str: @@ -599,11 +601,11 @@ Prefer `Optional[T]` or `T | None` over `str = Input(default=None)` for inputs t ```python # Bad: type annotation says str but value can be None -def predict(self, prompt: str = Input(default=None)) -> str: +def run(self, prompt: str = Input(default=None)) -> str: return "hello" + prompt # TypeError at runtime if prompt is None # Good: type annotation matches actual behavior -def predict(self, prompt: Optional[str] = Input(description="prompt")) -> str: +def run(self, prompt: Optional[str] = Input(description="prompt")) -> str: if prompt is None: return "hello" return "hello " + prompt @@ -619,10 +621,10 @@ Use `list[T]` or `List[T]` to accept or return a list of values. 
`T` can be a su **As an input type:** ```py -from cog import BasePredictor, Path +from cog import BaseRunner, Path -class Predictor(BasePredictor): - def predict(self, paths: list[Path]) -> str: +class Runner(BaseRunner): + def run(self, paths: list[Path]) -> str: output_parts = [] for path in paths: with open(path) as f: @@ -630,21 +632,21 @@ class Predictor(BasePredictor): return "".join(output_parts) ``` -With `cog predict`, repeat the input name to pass multiple values: +With `cog run`, repeat the input name to pass multiple values: ```bash $ echo test1 > 1.txt $ echo test2 > 2.txt -$ cog predict -i paths=@1.txt -i paths=@2.txt +$ cog run -i paths=@1.txt -i paths=@2.txt ``` **As an output type:** ```py -from cog import BasePredictor, Path +from cog import BaseRunner, Path -class Predictor(BasePredictor): - def predict(self) -> list[Path]: +class Runner(BaseRunner): + def run(self) -> list[Path]: predictions = ["foo", "bar", "baz"] output = [] for i, prediction in enumerate(predictions): @@ -662,10 +664,10 @@ Files are named in the format `output..`, e.g. `output.0.txt`, Use `dict` to accept or return an opaque JSON object. The value is passed through as-is without type validation. ```python -from cog import BasePredictor, Input +from cog import BaseRunner, Input -class Predictor(BasePredictor): - def predict(self, +class Runner(BaseRunner): + def run(self, params: dict = Input(description="Arbitrary JSON parameters"), ) -> dict: return {"greeting": "hello", "params": params} @@ -676,19 +678,19 @@ class Predictor(BasePredictor): #### `cog.Opaque` -Cog statically analyzes `predict()` type annotations to generate schemas. Some third-party package types, such as vLLM `TypedDict` definitions, may not be visible to that static analyzer even though they represent JSON-shaped object values at runtime. +Cog statically analyzes `run()` type annotations to generate schemas. 
Some third-party package types, such as vLLM `TypedDict` definitions, may not be visible to that static analyzer even though they represent JSON-shaped object values at runtime. Use `typing.Annotated` with `cog.Opaque` when you want Cog to accept or return those third-party object values without inspecting their fields: ```python from typing import Annotated -from cog import BasePredictor, Opaque +from cog import BaseRunner, Opaque from vllm.entrypoints.chat_utils import CustomChatCompletionMessageParam -class Predictor(BasePredictor): - def predict( +class Runner(BaseRunner): + def run( self, messages: Annotated[list[CustomChatCompletionMessageParam], Opaque], ) -> str: @@ -709,15 +711,15 @@ To return a complex object with multiple typed fields, define a class that inher ```python from typing import Optional -from cog import BasePredictor, BaseModel, Path +from cog import BaseRunner, BaseModel, Path class Output(BaseModel): text: str confidence: float image: Optional[Path] -class Predictor(BasePredictor): - def predict(self, prompt: str) -> Output: +class Runner(BaseRunner): + def run(self, prompt: str) -> Output: result = self.model.generate(prompt) return Output( text=result.text, @@ -743,15 +745,15 @@ If you already use Pydantic v2 in your model, you can use a Pydantic `BaseModel` ```python from pydantic import BaseModel as PydanticBaseModel -from cog import BasePredictor +from cog import BaseRunner class Result(PydanticBaseModel): name: str score: float tags: list[str] -class Predictor(BasePredictor): - def predict(self, prompt: str) -> Result: +class Runner(BaseRunner): + def run(self, prompt: str) -> Result: return Result(name="example", score=0.95, tags=["fast", "accurate"]) ``` diff --git a/docs/training.md b/docs/training.md index c7d2363458..026f6dc52a 100644 --- a/docs/training.md +++ b/docs/training.md @@ -7,7 +7,7 @@ Cog's training API allows you to define a fine-tuning interface for an existing ## How it works -If you've used Cog before, you've 
probably seen the [Predictor](./python.md) class, which defines the interface for creating predictions against your model. Cog's training API works similarly: You define a Python function that describes the inputs and outputs of the training process. The inputs are things like training data, epochs, batch size, seed, etc. The output is typically a file with the fine-tuned weights. +If you've used Cog before, you've probably seen the [Runner](./python.md) class, which defines the interface for creating predictions against your model. Cog's training API works similarly: You define a Python function that describes the inputs and outputs of the training process. The inputs are things like training data, epochs, batch size, seed, etc. The output is typically a file with the fine-tuned weights. `cog.yaml`: @@ -20,7 +20,7 @@ train: "train.py:train" `train.py`: ```python -from cog import BasePredictor, File +from cog import File import io def train(param: str) -> File: @@ -37,7 +37,7 @@ $ cat weights hello train ``` -You can also use classes if you want to run many model trainings and save on setup time. This works the same way as the [Predictor](./python.md) class with the only difference being the `train` method. +You can also use classes if you want to run many model trainings and save on setup time. This works the same way as the [Runner](./python.md) class with the only difference being the `train` method. 
`cog.yaml`: @@ -50,7 +50,7 @@ train: "train.py:Trainer" `train.py`: ```python -from cog import BasePredictor, File +from cog import File import io class Trainer: @@ -120,8 +120,8 @@ def train( ## Testing -If you are doing development of a Cog model like Llama or SDXL, you can test that the fine-tuned code path works before pushing by specifying a `COG_WEIGHTS` environment variable when running `predict`: +If you are doing development of a Cog model like Llama or SDXL, you can test that the fine-tuned code path works before pushing by specifying a `COG_WEIGHTS` environment variable when running `run`: ```console -cog predict -e COG_WEIGHTS=https://replicate.delivery/pbxt/xyz/weights.tar -i prompt="a photo of TOK" +cog run -e COG_WEIGHTS=https://replicate.delivery/pbxt/xyz/weights.tar -i prompt="a photo of TOK" ``` diff --git a/docs/wsl2/wsl2.md b/docs/wsl2/wsl2.md index f86d279881..bc3316d58a 100644 --- a/docs/wsl2/wsl2.md +++ b/docs/wsl2/wsl2.md @@ -175,7 +175,7 @@ cog --version # should output the cog version number. Finally, make sure it works. Let's try running `afiaka87/glid-3-xl` locally: ```bash -cog predict 'r8.im/afiaka87/glid-3-xl' -i prompt="a fresh avocado floating in the water" -o prediction.json +cog run 'r8.im/afiaka87/glid-3-xl' -i prompt="a fresh avocado floating in the water" -o prediction.json ``` ![Output from a running cog prediction in Windows Terminal](images/cog_model_output.png) diff --git a/docs/yaml.md b/docs/yaml.md index 1b57161b1b..07be9aaa96 100644 --- a/docs/yaml.md +++ b/docs/yaml.md @@ -2,7 +2,7 @@ `cog.yaml` defines how to build a Docker image and how to run predictions on your model inside that image. -It has three keys: [`build`](#build), [`image`](#image), and [`predict`](#predict). It looks a bit like this: +It has three keys: [`build`](#build), [`image`](#image), and [`run`](#run). 
It looks a bit like this: ```yaml build: @@ -11,7 +11,7 @@ build: system_packages: - "ffmpeg" - "git" -predict: "predict.py:Predictor" +run: "run.py:Runner" ``` Tip: Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `cog.yaml` file that can be used as a starting point for setting up your model. @@ -44,7 +44,7 @@ build: gpu: true ``` -When you use `cog exec` or `cog predict`, Cog will automatically pass the `--gpus=all` flag to Docker. When you run a Docker image built with Cog, you'll need to pass this option to `docker run`. +When you use `cog exec` or `cog run`, Cog will automatically pass the `--gpus=all` flag to Docker. When you run a Docker image built with Cog, you'll need to pass this option to `docker run`. ### `python_requirements` @@ -194,7 +194,7 @@ This stanza describes the concurrency capabilities of the model. It has one opti ### `max` -The maximum number of concurrent predictions the model can process. If this is set, the model must specify an [async `predict()` method](python.md#async-predictors-and-concurrency). +The maximum number of concurrent predictions the model can process. If this is set, the model must specify an [async `run()` method](python.md#async-runners-and-concurrency). For example: @@ -221,9 +221,23 @@ If you set this, then you can run `cog push` without specifying the model name. If you specify an image name argument when pushing (like `cog push your-username/custom-model-name`), the argument will be used and the value of `image` in cog.yaml will be ignored. +## `run` + +The pointer to the `Runner` object in your code, which defines how predictions are run on your model. + +For example: + +```yaml +run: "run.py:Runner" +``` + +`predict:` is still accepted for existing projects, but it is deprecated. New projects should use `run:`. + +See [the Python API documentation for more information](python.md). 
+ ## `predict` -The pointer to the `Predictor` object in your code, which defines how predictions are run on your model. +Deprecated compatibility field for [`run`](#run). Existing projects can continue using it, but Cog will warn and `cog doctor --fix` can migrate common projects to `run:`. For example: diff --git a/examples/managed-weights/README.md b/examples/managed-weights/README.md index b9a99d6708..7a70f213c2 100644 --- a/examples/managed-weights/README.md +++ b/examples/managed-weights/README.md @@ -9,9 +9,9 @@ If you're looking for a starting point for a real model, see ## What this does -The predictor doesn't do inference. Instead, it reads `weights.lock` at setup, +The runner doesn't do inference. Instead, it reads `weights.lock` at setup, validates that every expected file exists on disk with the correct size and -digest, and returns a per-weight status summary from `predict()`. It's a +digest, and returns a per-weight status summary from `run()`. It's a smoke test for the weight pipeline. The `cog.yaml` declares two weight sources to exercise both code paths: diff --git a/examples/managed-weights/cog.yaml b/examples/managed-weights/cog.yaml index 2086ac55b3..cfbde00fa6 100644 --- a/examples/managed-weights/cog.yaml +++ b/examples/managed-weights/cog.yaml @@ -17,7 +17,7 @@ build: python_version: "3.12" python_requirements: requirements.txt -predict: "predict.py:Predictor" +run: "run.py:Runner" weights: - name: parakeet diff --git a/examples/managed-weights/predict.py b/examples/managed-weights/run.py similarity index 96% rename from examples/managed-weights/predict.py rename to examples/managed-weights/run.py index fc9a6d407f..6c1c7bf14c 100644 --- a/examples/managed-weights/predict.py +++ b/examples/managed-weights/run.py @@ -1,5 +1,5 @@ -# Infra verification predictor for the v1 managed-weights OCI pipeline. -# Validates weight files on disk against weights.lock at setup; predict() +# Infra verification runner for the v1 managed-weights OCI pipeline. 
+# Validates weight files on disk against weights.lock at setup; run() # returns a per-weight status summary. import hashlib @@ -8,7 +8,7 @@ from pathlib import Path from typing import Any -from cog import BasePredictor +from cog import BaseRunner LOCK_PATH = Path("/src/weights.lock") @@ -99,7 +99,7 @@ def _validate_weight( } -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self) -> None: if not LOCK_PATH.exists(): raise RuntimeError(f"{LOCK_PATH} not found — cannot validate weights") @@ -152,7 +152,7 @@ def setup(self) -> None: print("all weights validated", file=sys.stderr) - def predict(self) -> str: + def run(self) -> str: summary = [] for r in self.results: entry: dict[str, Any] = { diff --git a/examples/resnet/README.md b/examples/resnet/README.md index 0791ae5c09..faf973b874 100644 --- a/examples/resnet/README.md +++ b/examples/resnet/README.md @@ -11,7 +11,7 @@ Managed weights separate your model weights from your model image. Instead of baking multi-GB weight files into the Docker image (slow builds, huge layers), cog packs them into dedicated OCI layers that get mounted at runtime. -The key idea: your `predict.py` reads weights from a path like +The key idea: your `run.py` reads weights from a path like `/src/weights/resnet50`, but those files don't live inside the Docker image -- they arrive separately and get overlaid at that path when the container starts. @@ -20,7 +20,7 @@ they arrive separately and get overlaid at that path when the container starts. ``` examples/resnet/ ├── cog.yaml # model config -- declares weights, build settings -├── predict.py # predictor -- loads weights from target path +├── run.py # runner -- loads weights from target path ├── requirements.txt # python deps ├── weights.lock # generated by `cog weights import` -- don't hand-edit ├── .dockerignore # keeps local weight dirs out of the Docker build context @@ -31,7 +31,7 @@ examples/resnet/ Weight files themselves don't live in the project directory. 
`cog weights import` downloads them into a content-addressed store at `~/.cache/cog/weights/` (override -with `$COG_CACHE_DIR`). When you run `cog predict`, cog assembles a temporary +with `$COG_CACHE_DIR`). When you run `cog run`, cog assembles a temporary directory under `.cog/mounts/` using hardlinks from the store and bind-mounts it into the container at the `target` path. The mount dir is cleaned up when the container stops. @@ -42,20 +42,21 @@ container stops. weights: - name: resnet50 source: - uri: hf://microsoft/resnet-50 # where to fetch from - exclude: # files to skip + uri: hf://microsoft/resnet-50 # where to fetch from + exclude: # files to skip - "pytorch_model.bin" - "flax_model.msgpack" - "tf_model.h5" - "README.md" - ".gitattributes" - target: /src/weights/resnet50 # where files appear in the container + target: /src/weights/resnet50 # where files appear in the container ``` **`name`** -- an identifier for this weight set. Used in lockfile entries and OCI tags. Pick something short and descriptive. **`source.uri`** -- where the weights come from. Two formats: + - `hf:///` -- pulls from HuggingFace Hub - A local directory path (e.g. `weights/`) -- uses files already on disk @@ -64,7 +65,7 @@ weights in multiple formats (PyTorch, TF, Flax, ONNX). Exclude the ones you don't need -- it'll save gigabytes. **`target`** -- the absolute path where weight files land inside the container. -Your `predict.py` loads from this path. Must start with `/`. +Your `run.py` loads from this path. Must start with `/`. ## Getting started @@ -85,7 +86,7 @@ to version control. ### 2. Run a prediction locally ```sh -cog predict -i image=@hotdog.png +cog run -i image=@hotdog.png ``` Locally, cog assembles the weight files from the cache and bind-mounts them @@ -117,10 +118,10 @@ build daemon on every `cog build`. 
- Adjust `exclude` patterns for the formats you don't need - Set `target` to wherever your code expects to find the weights - Set `image` to your registry destination (required for `cog push`) -3. Edit `predict.py` to load your model from `WEIGHTS_DIR` +3. Edit `run.py` to load your model from `WEIGHTS_DIR` 4. Update `requirements.txt` with your dependencies 5. Run `cog weights import` to fetch weights and generate the lockfile -6. Test with `cog predict` +6. Test with `cog run` 7. Push with `cog push` ### Using local weights instead of HuggingFace diff --git a/examples/resnet/cog.yaml b/examples/resnet/cog.yaml index 413d3ee998..f2fa21b3aa 100644 --- a/examples/resnet/cog.yaml +++ b/examples/resnet/cog.yaml @@ -13,7 +13,7 @@ build: python_version: "3.13" python_requirements: requirements.txt -predict: "predict.py:Predictor" +run: "run.py:Runner" weights: - name: resnet50 diff --git a/examples/resnet/predict.py b/examples/resnet/run.py similarity index 84% rename from examples/resnet/predict.py rename to examples/resnet/run.py index 7c5555a8e1..5d86d0edfc 100644 --- a/examples/resnet/predict.py +++ b/examples/resnet/run.py @@ -2,12 +2,12 @@ from PIL import Image from transformers import AutoImageProcessor, ResNetForImageClassification -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path WEIGHTS_DIR = "/src/weights/resnet50" -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self) -> None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.processor = AutoImageProcessor.from_pretrained(WEIGHTS_DIR) @@ -15,7 +15,7 @@ def setup(self) -> None: self.model = self.model.to(self.device) self.model.eval() - def predict(self, image: Path = Input(description="Image to classify")) -> dict: + def run(self, image: Path = Input(description="Image to classify")) -> dict: img = Image.open(image).convert("RGB") inputs = self.processor(img, return_tensors="pt").to(self.device) diff --git 
a/integration-tests/tests/build_cog_init.txtar b/integration-tests/tests/build_cog_init.txtar index fb3f4bd406..0852134359 100644 --- a/integration-tests/tests/build_cog_init.txtar +++ b/integration-tests/tests/build_cog_init.txtar @@ -10,4 +10,4 @@ stderr 'Image built as' # Verify the expected files were created exists cog.yaml -exists predict.py +exists run.py diff --git a/integration-tests/tests/doctor_clean_project.txtar b/integration-tests/tests/doctor_clean_project.txtar index edf462c739..57cb7b5a11 100644 --- a/integration-tests/tests/doctor_clean_project.txtar +++ b/integration-tests/tests/doctor_clean_project.txtar @@ -7,12 +7,12 @@ stderr 'no issues found' -- cog.yaml -- build: python_version: "3.12" -predict: "predict.py:Predictor" +run: "run.py:Runner" --- predict.py -- -from cog import BasePredictor +-- run.py -- +from cog import BaseRunner -class Predictor(BasePredictor): - def predict(self, text: str) -> str: +class Runner(BaseRunner): + def run(self, text: str) -> str: return "hello " + text diff --git a/integration-tests/tests/doctor_fix_deprecated_imports.txtar b/integration-tests/tests/doctor_fix_deprecated_imports.txtar index 4e652c7a80..71c4f52f60 100644 --- a/integration-tests/tests/doctor_fix_deprecated_imports.txtar +++ b/integration-tests/tests/doctor_fix_deprecated_imports.txtar @@ -9,7 +9,7 @@ cog doctor --fix stderr 'Fixed' # Verify the import was removed from the file -exec cat predict.py +exec cat run.py ! stdout 'ExperimentalFeatureWarning' ! 
stdout 'cog.types' @@ -20,16 +20,16 @@ stderr 'no issues found' -- cog.yaml -- build: python_version: "3.12" -predict: "predict.py:Predictor" +run: "run.py:Runner" --- predict.py -- +-- run.py -- import warnings -from cog import BasePredictor +from cog import BaseRunner from cog.types import ExperimentalFeatureWarning warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) -class Predictor(BasePredictor): - def predict(self, text: str) -> str: +class Runner(BaseRunner): + def run(self, text: str) -> str: return "hello " + text diff --git a/integration-tests/tests/doctor_fix_pydantic.txtar b/integration-tests/tests/doctor_fix_pydantic.txtar index 198d80bb21..4aecb0a991 100644 --- a/integration-tests/tests/doctor_fix_pydantic.txtar +++ b/integration-tests/tests/doctor_fix_pydantic.txtar @@ -10,8 +10,8 @@ cog doctor --fix stderr 'Fixed' # Verify the file was modified: pydantic.BaseModel replaced with cog.BaseModel -exec cat predict.py -stdout 'from cog import BasePredictor, Path, BaseModel' +exec cat run.py +stdout 'from cog import BaseRunner, Path, BaseModel' ! stdout 'from pydantic import BaseModel' ! stdout 'ConfigDict' ! 
stdout 'arbitrary_types_allowed' @@ -23,10 +23,10 @@ stderr 'no issues found' -- cog.yaml -- build: python_version: "3.12" -predict: "predict.py:Predictor" +run: "run.py:Runner" --- predict.py -- -from cog import BasePredictor, Path +-- run.py -- +from cog import BaseRunner, Path from pydantic import BaseModel, ConfigDict @@ -36,6 +36,6 @@ class VoiceCloningOutputs(BaseModel): spectrogram: Path -class Predictor(BasePredictor): - def predict(self, text: str) -> VoiceCloningOutputs: +class Runner(BaseRunner): + def run(self, text: str) -> VoiceCloningOutputs: return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") diff --git a/integration-tests/tests/doctor_predict_to_run_migration.txtar b/integration-tests/tests/doctor_predict_to_run_migration.txtar new file mode 100644 index 0000000000..5fcac2112a --- /dev/null +++ b/integration-tests/tests/doctor_predict_to_run_migration.txtar @@ -0,0 +1,23 @@ +cog doctor +stderr 'Deprecated predict interface names' +stderr 'predict in cog.yaml is deprecated' + +cog doctor --fix +stderr 'Fixed: Deprecated predict interface names' + +grep 'run: "run.py:Runner"' cog.yaml +grep 'class Runner\(BaseRunner\):' run.py +grep 'def run\(' run.py +! 
exists predict.py + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + return "hello " + s diff --git a/integration-tests/tests/predict_deprecation_warning.txtar b/integration-tests/tests/predict_deprecation_warning.txtar new file mode 100644 index 0000000000..404087eb11 --- /dev/null +++ b/integration-tests/tests/predict_deprecation_warning.txtar @@ -0,0 +1,16 @@ +cog build -t $TEST_IMAGE +cog predict $TEST_IMAGE -i s=world +stderr '"cog predict" is deprecated, use "cog run"' +stdout 'hello world' + +-- cog.yaml -- +build: + python_version: "3.12" +run: "run.py:Runner" + +-- run.py -- +from cog import BaseRunner + +class Runner(BaseRunner): + def run(self, s: str) -> str: + return "hello " + s diff --git a/integration-tests/tests/run_exec_deprecation_warning.txtar b/integration-tests/tests/run_exec_deprecation_warning.txtar new file mode 100644 index 0000000000..ce2ba2de4c --- /dev/null +++ b/integration-tests/tests/run_exec_deprecation_warning.txtar @@ -0,0 +1,15 @@ +cog run python -c 'print("hello")' +stderr '"cog run " is deprecated, use "cog exec "' +stdout 'hello' + +-- cog.yaml -- +build: + python_version: "3.12" +run: "run.py:Runner" + +-- run.py -- +from cog import BaseRunner + +class Runner(BaseRunner): + def run(self) -> str: + return "unused" diff --git a/integration-tests/tests/run_legacy_predictor.txtar b/integration-tests/tests/run_legacy_predictor.txtar new file mode 100644 index 0000000000..ec1da124be --- /dev/null +++ b/integration-tests/tests/run_legacy_predictor.txtar @@ -0,0 +1,16 @@ +cog build -t $TEST_IMAGE +stderr 'predict' +cog run $TEST_IMAGE -i s=world +stdout 'hello world' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + return 
"hello " + s diff --git a/integration-tests/tests/run_new_runner.txtar b/integration-tests/tests/run_new_runner.txtar new file mode 100644 index 0000000000..c583c39273 --- /dev/null +++ b/integration-tests/tests/run_new_runner.txtar @@ -0,0 +1,16 @@ +# Build and run a new-style model +cog build -t $TEST_IMAGE +cog run $TEST_IMAGE -i s=world +stdout 'hello world' + +-- cog.yaml -- +build: + python_version: "3.12" +run: "run.py:Runner" + +-- run.py -- +from cog import BaseRunner + +class Runner(BaseRunner): + def run(self, s: str) -> str: + return "hello " + s diff --git a/pkg/cli/build.go b/pkg/cli/build.go index c279ae6669..9f08aff2e6 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -41,7 +41,7 @@ func newBuildCommand() *cobra.Command { Long: `Build a Docker image from the cog.yaml in the current directory. The generated image contains your model code, dependencies, and the Cog -runtime. It can be run locally with 'cog predict' or pushed to a registry +runtime. It can be run locally with 'cog run' or pushed to a registry with 'cog push'.`, Example: ` # Build with default settings cog build diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index b367c53d53..d23ccee612 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -30,7 +30,7 @@ func newDoctorCommand() *cobra.Command { NOTE: cog doctor is experimental. Behavior and checks may change in future versions. By default, cog doctor reports problems without modifying any files. 
-Pass --fix to automatically apply safe fixes.`, +Pass --fix to automatically apply safe fixes, including migrating deprecated predict names to run names when no file collision exists.`, RunE: func(cmd *cobra.Command, args []string) error { return runDoctor(cmd.Context(), fix) }, diff --git a/pkg/cli/exec.go b/pkg/cli/exec.go index a081e6ca80..4d14d5b320 100644 --- a/pkg/cli/exec.go +++ b/pkg/cli/exec.go @@ -26,9 +26,8 @@ func addGpusFlag(cmd *cobra.Command) { func newExecCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "exec [arg...]", - Aliases: []string{"run"}, - Short: "Execute a command inside a Docker environment", + Use: "exec [arg...]", + Short: "Execute a command inside a Docker environment", Long: `Execute a command inside a Docker environment defined by cog.yaml. Cog builds a temporary image from your cog.yaml configuration and runs the @@ -69,10 +68,6 @@ exploring the environment your model will run in.`, } func execCmd(cmd *cobra.Command, args []string) error { - if cmd.CalledAs() == "run" { - console.Warn(`"cog run " is deprecated, use "cog exec "`) - } - ctx := cmd.Context() dockerClient, err := docker.NewClient(ctx) diff --git a/pkg/cli/init-templates/base/cog.yaml b/pkg/cli/init-templates/base/cog.yaml index 0451b77e43..7fdbc0d4f7 100644 --- a/pkg/cli/init-templates/base/cog.yaml +++ b/pkg/cli/init-templates/base/cog.yaml @@ -21,5 +21,5 @@ build: # - "echo env is ready!" 
# - "echo another command if needed" -# predict.py defines how predictions are run on your model -predict: "predict.py:Predictor" +# run.py defines how runs are handled by your model +run: "run.py:Runner" diff --git a/pkg/cli/init-templates/base/predict.py b/pkg/cli/init-templates/base/run.py similarity index 55% rename from pkg/cli/init-templates/base/predict.py rename to pkg/cli/init-templates/base/run.py index 89dfdac167..9986ca35db 100644 --- a/pkg/cli/init-templates/base/predict.py +++ b/pkg/cli/init-templates/base/run.py @@ -1,22 +1,22 @@ -# Prediction interface for Cog ⚙️ +# Run interface for Cog # https://cog.run/python -from cog import BasePredictor, Input, Path +from cog import BaseRunner, Input, Path -class Predictor(BasePredictor): +class Runner(BaseRunner): def setup(self) -> None: - """Load the model into memory to make running multiple predictions efficient""" + """Load the model into memory to make running multiple requests efficient""" # self.model = torch.load("./weights.pth") - def predict( + def run( self, image: Path = Input(description="Grayscale input image"), scale: float = Input( description="Factor to scale image by", ge=0, le=10, default=1.5 ), ) -> Path: - """Run a single prediction on the model""" + """Run the model on a single input""" # processed_input = preprocess(image) - # output = self.model(processed_image, scale) + # output = self.model(processed_input, scale) # return postprocess(output) diff --git a/pkg/cli/init.go b/pkg/cli/init.go index da90028235..de2d7b08b3 100644 --- a/pkg/cli/init.go +++ b/pkg/cli/init.go @@ -23,10 +23,10 @@ func newInitCommand() *cobra.Command { Use: "init", SuggestFor: []string{"new", "start"}, Short: "Configure your project for use with Cog", - Long: `Create a cog.yaml and predict.py in the current directory. + Long: `Create a cog.yaml and run.py in the current directory. These files provide a starting template for defining your model's environment -and prediction interface. 
Edit them to match your model's requirements.`, +and run interface. Edit them to match your model's requirements.`, Example: ` # Set up a new Cog project in the current directory cog init`, RunE: initCommand, diff --git a/pkg/cli/init_test.go b/pkg/cli/init_test.go index 3648394226..7ed9066bc4 100644 --- a/pkg/cli/init_test.go +++ b/pkg/cli/init_test.go @@ -18,8 +18,23 @@ func TestInit(t *testing.T) { require.FileExists(t, path.Join(dir, ".dockerignore")) require.FileExists(t, path.Join(dir, "cog.yaml")) - require.FileExists(t, path.Join(dir, "predict.py")) + require.FileExists(t, path.Join(dir, "run.py")) require.FileExists(t, path.Join(dir, "requirements.txt")) + require.NoFileExists(t, path.Join(dir, "predict.py")) + + cogYaml, err := os.ReadFile(path.Join(dir, "cog.yaml")) + require.NoError(t, err) + require.Contains(t, string(cogYaml), `run: "run.py:Runner"`) + require.NotContains(t, string(cogYaml), `predict: "predict.py:Predictor"`) + + runPy, err := os.ReadFile(path.Join(dir, "run.py")) + require.NoError(t, err) + require.Contains(t, string(runPy), "from cog import BaseRunner, Input, Path") + require.Contains(t, string(runPy), "class Runner(BaseRunner):") + require.Contains(t, string(runPy), "def run(") + require.NotContains(t, string(runPy), "BasePredictor") + require.NotContains(t, string(runPy), "class Predictor") + require.NotContains(t, string(runPy), "def predict(") } func TestInitSkipExisting(t *testing.T) { @@ -33,11 +48,11 @@ func TestInitSkipExisting(t *testing.T) { require.FileExists(t, path.Join(dir, ".dockerignore")) require.FileExists(t, path.Join(dir, "cog.yaml")) - require.FileExists(t, path.Join(dir, "predict.py")) + require.FileExists(t, path.Join(dir, "run.py")) // update the file to show that its the same file after the second run require.NoError(t, os.WriteFile(path.Join(dir, "cog.yaml"), []byte("test123"), 0o644)) - require.NoError(t, os.WriteFile(path.Join(dir, "predict.py"), []byte("test456"), 0o644)) + require.NoError(t, 
os.WriteFile(path.Join(dir, "run.py"), []byte("test456"), 0o644)) require.NoError(t, os.WriteFile(path.Join(dir, ".dockerignore"), []byte("test789"), 0o644)) // Second run should skip the files that already exist @@ -46,14 +61,14 @@ func TestInitSkipExisting(t *testing.T) { require.FileExists(t, path.Join(dir, ".dockerignore")) require.FileExists(t, path.Join(dir, "cog.yaml")) - require.FileExists(t, path.Join(dir, "predict.py")) + require.FileExists(t, path.Join(dir, "run.py")) // check that the files are the same as the first run content, err := os.ReadFile(path.Join(dir, "cog.yaml")) require.NoError(t, err) require.Equal(t, []byte("test123"), content) - content, err = os.ReadFile(path.Join(dir, "predict.py")) + content, err = os.ReadFile(path.Join(dir, "run.py")) require.NoError(t, err) require.Equal(t, []byte("test456"), content) diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 6cef3f3b91..48d1fb1bb7 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -44,18 +44,7 @@ var ( inputJSON string ) -func newPredictCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "predict [image]", - Short: "Run a prediction", - Long: `Run a prediction. - -If 'image' is passed, it will run the prediction on that Docker image. -It must be an image that has been built by Cog. 
- -Otherwise, it will build the model in the current directory and run -the prediction on that.`, - Example: ` # Run a prediction with named inputs +const existingPredictExamples = ` # Run a prediction with named inputs cog predict -i prompt="a photo of a cat" # Pass a file as input @@ -71,11 +60,32 @@ the prediction on that.`, cog predict r8.im/your-username/my-model -i prompt="hello" # Pass inputs as JSON - echo '{"prompt": "a cat"}' | cog predict --json @-`, + echo '{"prompt": "a cat"}' | cog predict --json @-` + +func newPredictCommand() *cobra.Command { + return newPredictionCommand("predict", true) +} + +func newPredictionCommand(use string, hidden bool) *cobra.Command { + cmd := &cobra.Command{ + Use: use + " [image]", + Short: "Run a prediction", + Long: `Run a prediction. + +If 'image' is passed, it will run the prediction on that Docker image. +It must be an image that has been built by Cog. + +Otherwise, it will build the model in the current directory and run +the prediction on that.`, + Example: strings.ReplaceAll(existingPredictExamples, "cog predict", "cog "+use), RunE: cmdPredict, Args: cobra.MaximumNArgs(1), + Hidden: hidden, SuggestFor: []string{"infer"}, } + if hidden { + cmd.Short = "Run a prediction (deprecated, use cog run)" + } addUseCudaBaseImageFlag(cmd) addUseCogBaseImageFlag(cmd) @@ -178,6 +188,10 @@ func transformPathsToBase64URLs(inputs map[string]any) (map[string]any, error) { } func cmdPredict(cmd *cobra.Command, args []string) error { + if cmd.CalledAs() == "predict" || cmd.Name() == "predict" { + console.Warn(`"cog predict" is deprecated, use "cog run"`) + } + ctx, stop := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) defer stop() diff --git a/pkg/cli/root.go b/pkg/cli/root.go index ce337c86d8..f69f369d83 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -48,6 +48,7 @@ https://github.com/replicate/cog`, newDoctorCommand(), newInitCommand(), newLoginCommand(), + newRunCommand(), newPredictCommand(), 
newPushCommand(), newExecCommand(), diff --git a/pkg/cli/run.go b/pkg/cli/run.go new file mode 100644 index 0000000000..1d7e0b4adf --- /dev/null +++ b/pkg/cli/run.go @@ -0,0 +1,138 @@ +package cli + +import ( + "strings" + + "github.com/spf13/cobra" + + "github.com/replicate/cog/pkg/util/console" +) + +type runDispatchMode int + +const ( + runDispatchPredict runDispatchMode = iota + runDispatchExec +) + +func runDispatchModeForArgs(args []string) runDispatchMode { + remaining, hasUnknownFlag := runArgsAfterPredictionFlags(args) + if hasUnknownFlag { + return runDispatchExec + } + if len(remaining) == 1 && isLikelyRunCommand(remaining[0]) { + return runDispatchExec + } + if len(remaining) <= 1 { + return runDispatchPredict + } + return runDispatchExec +} + +func newRunCommand() *cobra.Command { + cmd := newPredictionCommand("run", false) + cmd.DisableFlagParsing = true + cmd.Args = cobra.ArbitraryArgs + cmd.RunE = cmdRun + cmd.PreRunE = checkMutuallyExclusiveFlags + cmd.Flags().StringArrayVarP(&execPorts, "publish", "p", []string{}, "Publish a container's port to the host, e.g. 
-p 8000") + _ = cmd.Flags().MarkHidden("publish") + return cmd +} + +func cmdRun(cmd *cobra.Command, args []string) error { + mode := runDispatchModeForArgs(args) + if mode == runDispatchPredict && runArgsContainHelp(args) { + return cmd.Help() + } + if mode == runDispatchExec { + cmd.Flags().SetInterspersed(false) + if err := cmd.Flags().Parse(args); err != nil { + return err + } + if err := checkMutuallyExclusiveFlags(cmd, cmd.Flags().Args()); err != nil { + return err + } + if len(cmd.Flags().Args()) == 0 { + return cobra.MinimumNArgs(1)(cmd, cmd.Flags().Args()) + } + console.Warn(`"cog run " is deprecated, use "cog exec "`) + return execCmd(cmd, cmd.Flags().Args()) + } + cmd.Flags().SetInterspersed(true) + if err := cmd.Flags().Parse(args); err != nil { + return err + } + if err := checkMutuallyExclusiveFlags(cmd, cmd.Flags().Args()); err != nil { + return err + } + return cmdPredict(cmd, cmd.Flags().Args()) +} + +func isLikelyRunCommand(arg string) bool { + switch arg { + case "bash", "sh", "zsh", "python", "python3", "ipython", "jupyter", "pip", "pip3", "uv": + return true + default: + return false + } +} + +func runArgsContainHelp(args []string) bool { + for _, arg := range args { + if arg == "--help" || arg == "-h" { + return true + } + } + return false +} + +func runArgsAfterPredictionFlags(args []string) ([]string, bool) { + remaining := []string{} + for i := 0; i < len(args); i++ { + arg := args[i] + if !strings.HasPrefix(arg, "-") { + remaining = append(remaining, arg) + continue + } + if (arg == "--help" || arg == "-h") && len(remaining) > 0 { + return remaining, true + } + if !isRunPredictionFlag(arg) { + return remaining, true + } + if runPredictionFlagTakesValue(arg) && !strings.Contains(arg, "=") && i+1 < len(args) { + i++ + } + } + return remaining, false +} + +func isRunPredictionFlag(arg string) bool { + name := strings.TrimLeft(arg, "-") + if before, _, ok := strings.Cut(name, "="); ok { + name = before + } + switch name { + case "h", "help", 
"i", "input", "o", "output", "e", "env", "use-replicate-token", "json", + "use-cuda-base-image", "use-cog-base-image", "progress", "dockerfile", "gpus", + "setup-timeout", "f", "file": + return true + default: + return false + } +} + +func runPredictionFlagTakesValue(arg string) bool { + name := strings.TrimLeft(arg, "-") + if before, _, ok := strings.Cut(name, "="); ok { + name = before + } + switch name { + case "i", "input", "o", "output", "e", "env", "json", "progress", "dockerfile", + "gpus", "setup-timeout", "f", "file", "use-cuda-base-image": + return true + default: + return false + } +} diff --git a/pkg/cli/run_test.go b/pkg/cli/run_test.go new file mode 100644 index 0000000000..89f4392185 --- /dev/null +++ b/pkg/cli/run_test.go @@ -0,0 +1,67 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRunCommandDispatchMode(t *testing.T) { + tests := []struct { + name string + args []string + want runDispatchMode + }{ + {name: "no args predicts local project", args: nil, want: runDispatchPredict}, + {name: "one arg predicts image", args: []string{"r8.im/acme/model"}, want: runDispatchPredict}, + {name: "one likely command forwards to exec", args: []string{"python"}, want: runDispatchExec}, + {name: "run flag before likely command forwards to exec", args: []string{"--gpus", "all", "python"}, want: runDispatchExec}, + {name: "one image plus input flag predicts image", args: []string{"r8.im/acme/model", "-i", "prompt=hello"}, want: runDispatchPredict}, + {name: "two args forwards to exec", args: []string{"python", "script.py"}, want: runDispatchExec}, + {name: "command flag forwards to exec", args: []string{"python", "-m", "http.server"}, want: runDispatchExec}, + {name: "command args before input-like flag forwards to exec", args: []string{"python", "script.py", "-i", "input.txt"}, want: runDispatchExec}, + {name: "run flag before command forwards to exec", args: []string{"--gpus", "all", "python", "script.py"}, want: 
runDispatchExec}, + {name: "run flag before image predicts image", args: []string{"--gpus", "all", "r8.im/acme/model"}, want: runDispatchPredict}, + {name: "config file before image predicts image", args: []string{"--file", "custom.yaml", "r8.im/acme/model"}, want: runDispatchPredict}, + {name: "cuda base flag before image predicts image", args: []string{"--use-cuda-base-image", "false", "r8.im/acme/model"}, want: runDispatchPredict}, + {name: "publish before command forwards to exec", args: []string{"-p", "8888", "jupyter", "notebook"}, want: runDispatchExec}, + {name: "publish without command forwards to exec for arg validation", args: []string{"-p", "8888"}, want: runDispatchExec}, + {name: "help alone is prediction help", args: []string{"--help"}, want: runDispatchPredict}, + {name: "command help forwards to exec", args: []string{"python", "script.py", "--help"}, want: runDispatchExec}, + {name: "single command help forwards to exec", args: []string{"python", "--help"}, want: runDispatchExec}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, runDispatchModeForArgs(tt.args)) + }) + } +} + +func TestRunCommandDispatchModeAfterFlagParsing(t *testing.T) { + cmd := newRunCommand() + require.True(t, cmd.DisableFlagParsing) + require.Equal(t, runDispatchPredict, runDispatchModeForArgs([]string{"r8.im/acme/model", "-i", "prompt=hello"})) +} + +func TestCommandVisibilityForRunPredictAndExec(t *testing.T) { + root, err := NewRootCommand() + require.NoError(t, err) + + runCmd, _, err := root.Find([]string{"run"}) + require.NoError(t, err) + require.Equal(t, "run [image]", runCmd.Use) + require.False(t, runCmd.Hidden) + require.Contains(t, runCmd.Example, "cog run -i prompt") + publishFlag := runCmd.Flags().Lookup("publish") + require.NotNil(t, publishFlag) + require.True(t, publishFlag.Hidden) + + predictCmd, _, err := root.Find([]string{"predict"}) + require.NoError(t, err) + require.True(t, predictCmd.Hidden) + 
require.Contains(t, predictCmd.Short, "deprecated") + + execCmd, _, err := root.Find([]string{"exec"}) + require.NoError(t, err) + require.NotContains(t, execCmd.Aliases, "run") +} diff --git a/pkg/cli/weights.go b/pkg/cli/weights.go index d6aebad39d..6cba3ff55a 100644 --- a/pkg/cli/weights.go +++ b/pkg/cli/weights.go @@ -46,7 +46,7 @@ func newWeightsImportCommand() *cobra.Command { and pushes the layers to a registry. Import also warms the local content-addressed weight store as a side -effect, so 'cog predict' can mount the weights immediately without a +effect, so 'cog run' can mount the weights immediately without a separate 'cog weights pull'. Pull is still useful when someone clones a repo with a checked-in weights.lock but a cold local cache. diff --git a/pkg/cli/weights_pull.go b/pkg/cli/weights_pull.go index fe43f27955..e37f0891ad 100644 --- a/pkg/cli/weights_pull.go +++ b/pkg/cli/weights_pull.go @@ -23,7 +23,7 @@ func newWeightsPullCommand() *cobra.Command { Use: "pull [NAME...]", Short: "Populate the local weight cache from the registry", Long: `Downloads weight files from the registry into the local content-addressed -cache so 'cog predict' and 'cog run' can mount them at runtime. +cache so 'cog run' can mount them at runtime. You don't need to run 'cog weights pull' after 'cog weights import' — import already warms the local cache. 
Pull is for the case where diff --git a/pkg/config/config_file.go b/pkg/config/config_file.go index 84d50c0443..84606f1ae5 100644 --- a/pkg/config/config_file.go +++ b/pkg/config/config_file.go @@ -14,6 +14,7 @@ import ( type configFile struct { Build *buildFile `json:"build,omitempty" yaml:"build,omitempty"` Image *string `json:"image,omitempty" yaml:"image,omitempty"` + Run *string `json:"run,omitempty" yaml:"run,omitempty"` Predict *string `json:"predict,omitempty" yaml:"predict,omitempty"` Train *string `json:"train,omitempty" yaml:"train,omitempty"` Concurrency *concurrencyFile `json:"concurrency,omitempty" yaml:"concurrency,omitempty"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 93f12c9b59..2c0239b42d 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -660,6 +660,26 @@ predict: "" `, string(data)) } +func TestFromYAMLRunPopulatesPredict(t *testing.T) { + t.Run("non-empty run wins", func(t *testing.T) { + cfg, err := FromYAML([]byte("build:\n python_version: \"3.13\"\nrun: \"run.py:Runner\"\n")) + require.NoError(t, err) + require.Equal(t, "run.py:Runner", cfg.Predict) + }) + + t.Run("empty run falls back to predict", func(t *testing.T) { + cfg, err := FromYAML([]byte("build:\n python_version: \"3.13\"\nrun: \"\"\npredict: \"predict.py:Predictor\"\n")) + require.NoError(t, err) + require.Equal(t, "predict.py:Predictor", cfg.Predict) + }) +} + +func TestFromYAMLRunAndPredictConflict(t *testing.T) { + _, err := FromYAML([]byte("build:\n python_version: \"3.13\"\nrun: \"run.py:Runner\"\npredict: \"predict.py:Predictor\"\n")) + require.Error(t, err) + require.Contains(t, err.Error(), "only one of run or predict can be set") +} + func TestAbsolutePathInPythonRequirements(t *testing.T) { dir := t.TempDir() requirementsFilePath := filepath.Join(dir, "requirements.txt") diff --git a/pkg/config/data/config_schema_v1.0.json b/pkg/config/data/config_schema_v1.0.json index b42cc28296..8bf92cbcfa 100644 --- 
a/pkg/config/data/config_schema_v1.0.json +++ b/pkg/config/data/config_schema_v1.0.json @@ -160,10 +160,15 @@ "type": "string", "description": "The name given to built Docker images. If you want to push to a registry, this should also include the registry name." }, + "run": { + "$id": "#/properties/run", + "type": "string", + "description": "The pointer to the Runner object in your code, which defines how runs are executed on your model." + }, "predict": { "$id": "#/properties/predict", "type": "string", - "description": "The pointer to the `Predictor` object in your code, which defines how predictions are run on your model." + "description": "Deprecated compatibility field for run. Use run instead." }, "train": { "$id": "#/properties/train", diff --git a/pkg/config/parse.go b/pkg/config/parse.go index 9990c74306..4714e14464 100644 --- a/pkg/config/parse.go +++ b/pkg/config/parse.go @@ -114,7 +114,12 @@ func configFileToConfig(cfg *configFile) (*Config, error) { if cfg.Image != nil { config.Image = *cfg.Image } - if cfg.Predict != nil { + if cfg.Run != nil && cfg.Predict != nil && *cfg.Run != "" && *cfg.Predict != "" { + return nil, &ValidationError{Field: "run", Message: "only one of run or predict can be set"} + } + if cfg.Run != nil && *cfg.Run != "" { + config.Predict = *cfg.Run + } else if cfg.Predict != nil { config.Predict = *cfg.Predict } if cfg.Train != nil { diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 60f00f872e..8e11169213 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -96,16 +96,25 @@ func validateSchema(cfg *configFile) error { // validatePredict validates the predict field. 
func validatePredict(cfg *configFile, result *ValidationResult) { - if cfg.Predict == nil || *cfg.Predict == "" { + if cfg.Run != nil && *cfg.Run != "" && cfg.Predict != nil && *cfg.Predict != "" { + result.AddError(&ValidationError{Field: "run", Message: "only one of run or predict can be set"}) return } - predict := *cfg.Predict - if len(strings.Split(predict, ".py:")) != 2 { + if cfg.Run != nil && *cfg.Run != "" { + validatePredictRef("run", *cfg.Run, "run.py:Runner", result) + } + if cfg.Predict != nil && *cfg.Predict != "" { + validatePredictRef("predict", *cfg.Predict, "predict.py:Predictor", result) + } +} + +func validatePredictRef(field string, ref string, example string, result *ValidationResult) { + if len(strings.Split(ref, ".py:")) != 2 { result.AddError(&ValidationError{ - Field: "predict", - Value: predict, - Message: "must be in the form 'predict.py:Predictor'", + Field: field, + Value: ref, + Message: fmt.Sprintf("must be in the form '%s'", example), }) } } @@ -612,6 +621,14 @@ func isSubpath(child, parent string) bool { // checkDeprecatedFields checks for deprecated fields and adds warnings. 
func checkDeprecatedFields(cfg *configFile, result *ValidationResult) { + if cfg.Predict != nil && *cfg.Predict != "" { + result.AddWarning(DeprecationWarning{ + Field: "predict", + Replacement: "run", + Message: "use run to point at run.py:Runner", + }) + } + if cfg.Build == nil { return } diff --git a/pkg/config/validate_test.go b/pkg/config/validate_test.go index 809aec55cd..534863ca5c 100644 --- a/pkg/config/validate_test.go +++ b/pkg/config/validate_test.go @@ -94,6 +94,47 @@ func TestValidateConfigFilePredictFormat(t *testing.T) { require.Contains(t, result.Err().Error(), "predict.py:Predictor") } +func TestValidateConfigFileRunPredictCompatibility(t *testing.T) { + t.Run("run is valid", func(t *testing.T) { + cfg := &configFile{Build: &buildFile{PythonVersion: ptr("3.10")}, Run: ptr("run.py:Runner")} + result := ValidateConfigFile(cfg) + require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors) + require.Empty(t, result.Warnings) + }) + + t.Run("run validates reference format", func(t *testing.T) { + cfg := &configFile{Build: &buildFile{PythonVersion: ptr("3.10")}, Run: ptr("invalid_format")} + result := ValidateConfigFile(cfg) + require.True(t, result.HasErrors()) + require.Contains(t, result.Err().Error(), "run.py:Runner") + }) + + t.Run("predict warns", func(t *testing.T) { + cfg := &configFile{Build: &buildFile{PythonVersion: ptr("3.10")}, Predict: ptr("predict.py:Predictor")} + result := ValidateConfigFile(cfg) + require.False(t, result.HasErrors()) + require.Len(t, result.Warnings, 1) + require.Equal(t, "predict", result.Warnings[0].Field) + require.Equal(t, "run", result.Warnings[0].Replacement) + }) + + t.Run("predict warns without build", func(t *testing.T) { + cfg := &configFile{Predict: ptr("predict.py:Predictor")} + result := ValidateConfigFile(cfg) + require.False(t, result.HasErrors()) + require.Len(t, result.Warnings, 1) + require.Equal(t, "predict", result.Warnings[0].Field) + require.Equal(t, "run", 
result.Warnings[0].Replacement) + }) + + t.Run("both fields error", func(t *testing.T) { + cfg := &configFile{Build: &buildFile{PythonVersion: ptr("3.10")}, Run: ptr("run.py:Runner"), Predict: ptr("predict.py:Predictor")} + result := ValidateConfigFile(cfg) + require.True(t, result.HasErrors()) + require.Contains(t, result.Err().Error(), "only one of run or predict can be set") + }) +} + func TestValidateConfigFileConcurrencyType(t *testing.T) { cfg := &configFile{ Build: &buildFile{ diff --git a/pkg/doctor/check_config_test.go b/pkg/doctor/check_config_test.go index 3e52993fa0..3ad2cefe40 100644 --- a/pkg/doctor/check_config_test.go +++ b/pkg/doctor/check_config_test.go @@ -55,7 +55,7 @@ func TestConfigDeprecatedFieldsCheck_Clean(t *testing.T) { writeFile(t, dir, "cog.yaml", `build: python_version: "3.12" python_requirements: "requirements.txt" -predict: "predict.py:Predictor" +run: "run.py:Runner" `) writeFile(t, dir, "requirements.txt", "torch==2.0.0\n") @@ -72,7 +72,7 @@ func TestConfigDeprecatedFieldsCheck_PythonPackages(t *testing.T) { python_version: "3.12" python_packages: - torch==2.0.0 -predict: "predict.py:Predictor" +run: "run.py:Runner" `) ctx := buildTestCheckContext(t, dir) @@ -90,7 +90,7 @@ func TestConfigDeprecatedFieldsCheck_PreInstall(t *testing.T) { python_version: "3.12" pre_install: - pip install something -predict: "predict.py:Predictor" +run: "run.py:Runner" `) ctx := buildTestCheckContext(t, dir) @@ -203,6 +203,624 @@ predict: "predict.py:Predictor" require.NoError(t, err) } +func TestPredictToRunMigrationCheck_CleanRunProject(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +run: "run.py:Runner" +`) + writeFile(t, dir, "run.py", `from cog import BaseRunner +class Runner(BaseRunner): + def run(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + 
require.Empty(t, findings) +} + +func TestPredictToRunMigrationCheck_DetectsLegacyPredictUsage(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "deprecated") + require.Contains(t, findings[0].Remediation, "cog doctor --fix") +} + +func TestPredictToRunMigrationCheck_FixMigratesConfigAndPython(t *testing.T) { + tests := []struct { + name string + classLine string + wantLine string + }{ + {name: "class with base", classLine: "class Predictor(BasePredictor):", wantLine: "class Runner(BaseRunner):"}, + {name: "class without base", classLine: "class Predictor:", wantLine: "class Runner:"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor + +`+tt.classLine+` + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.NoError(t, check.Fix(ctx, findings)) + + cogYAML, err := os.ReadFile(filepath.Join(dir, "cog.yaml")) + require.NoError(t, err) + require.Contains(t, string(cogYAML), `run: "run.py:Runner"`) + require.NotContains(t, string(cogYAML), `predict: "predict.py:Predictor"`) + + _, err = os.Stat(filepath.Join(dir, "predict.py")) + require.ErrorIs(t, err, 
os.ErrNotExist) + + runPy, err := os.ReadFile(filepath.Join(dir, "run.py")) + require.NoError(t, err) + require.Contains(t, string(runPy), "from cog import BaseRunner") + require.Contains(t, string(runPy), tt.wantLine) + require.Contains(t, string(runPy), "def run(") + require.NotContains(t, string(runPy), "BasePredictor") + require.NotContains(t, string(runPy), "class Predictor") + require.NotContains(t, string(runPy), "def predict(") + }) + } +} + +func TestPredictToRunMigrationCheck_FixRefusesFileCollision(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + writeFile(t, dir, "run.py", `class Runner: + def run(self) -> str: + return "existing" +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Error(t, check.Fix(ctx, findings)) + + cogYAML, err := os.ReadFile(filepath.Join(dir, "cog.yaml")) + require.NoError(t, err) + require.Contains(t, string(cogYAML), `predict: "predict.py:Predictor"`) + require.NotContains(t, string(cogYAML), `run: "run.py:Runner"`) +} + +func TestPredictToRunMigrationCheck_FixRefusesConfigOnlyFileCollision(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "run.py", `class Runner: + def run(self) -> str: + return "existing" +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Error(t, check.Fix(ctx, findings)) + + cogYAML, err := os.ReadFile(filepath.Join(dir, "cog.yaml")) + require.NoError(t, err) + require.Contains(t, string(cogYAML), `predict: "predict.py:Predictor"`) + 
require.NotContains(t, string(cogYAML), `run: "run.py:Runner"`) +} + +func TestPredictToRunMigrationCheck_FixRefusesMissingPredictFile(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Error(t, check.Fix(ctx, findings)) + + cogYAML, err := os.ReadFile(filepath.Join(dir, "cog.yaml")) + require.NoError(t, err) + require.Contains(t, string(cogYAML), `predict: "predict.py:Predictor"`) + require.NotContains(t, string(cogYAML), `run: "run.py:Runner"`) +} + +func TestPredictToRunMigrationCheck_DoesNotRewriteTrainFile(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +run: "run.py:Runner" +train: "train.py:Trainer" +`) + writeFile(t, dir, "run.py", `from cog import BaseRunner +class Runner(BaseRunner): + def run(self, text: str) -> str: + return text +`) + writeFile(t, dir, "train.py", `class Trainer: + def predict(self) -> str: + return "helper" +`) + + ctx := buildTestCheckContext(t, dir) + parsePythonRef(ctx, "train.py:Trainer") + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestPredictToRunMigrationCheck_IgnoresStrayPredictFileWithoutLegacyConfig(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +run: "run.py:Runner" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestPredictToRunMigrationCheck_FixRefusesCustomPredictRef(t *testing.T) { + dir := 
t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "foo.py:Predictor" +`) + writeFile(t, dir, "foo.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + err = check.Fix(ctx, findings) + require.Error(t, err) + require.Contains(t, err.Error(), "Manual migration required") + + cogYAML, err := os.ReadFile(filepath.Join(dir, "cog.yaml")) + require.NoError(t, err) + require.Contains(t, string(cogYAML), `predict: "foo.py:Predictor"`) +} + +func TestPredictToRunMigrationCheck_FixMigratesTargetPredictMethodOnly(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +class Helper(BasePredictor): + def predict(self) -> str: + return "helper" +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.NoError(t, check.Fix(ctx, findings)) + + cogYAML, err := os.ReadFile(filepath.Join(dir, "cog.yaml")) + require.NoError(t, err) + require.Contains(t, string(cogYAML), `run: "run.py:Runner"`) + require.NotContains(t, string(cogYAML), `predict: "predict.py:Predictor"`) + + runPy, err := os.ReadFile(filepath.Join(dir, "run.py")) + require.NoError(t, err) + require.Contains(t, string(runPy), "from cog import BasePredictor, BaseRunner") + require.Contains(t, string(runPy), "class Runner(BaseRunner):") + require.Contains(t, string(runPy), "def run(self, text: str) -> str:") + require.Contains(t, string(runPy), "class Helper(BasePredictor):") + 
require.Contains(t, string(runPy), "def predict(self) -> str:") +} + +func TestPredictToRunMigrationCheck_CheckIgnoresHelperOnlyPredictMethods(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BaseRunner +class Runner(BaseRunner): + def run(self, text: str) -> str: + return text +class Helper: + def predict(self) -> str: + return "helper" +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + for _, finding := range findings { + require.NotEqual(t, "predict.py", finding.File) + } +} + +func TestPredictToRunMigrationCheck_FixMigratesCogPredictorImport(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog.predictor import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.NoError(t, check.Fix(ctx, findings)) + + runPy, err := os.ReadFile(filepath.Join(dir, "run.py")) + require.NoError(t, err) + require.Contains(t, string(runPy), "from cog.predictor import BaseRunner") + require.Contains(t, string(runPy), "class Runner(BaseRunner):") +} + +func TestPredictToRunMigrationCheck_FixRefusesAliasedBasePredictor(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor as CogBasePredictor +class Predictor(CogBasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := 
&PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + err = check.Fix(ctx, findings) + require.Error(t, err) + require.Contains(t, err.Error(), "Manual migration required") + + cogYAML, err := os.ReadFile(filepath.Join(dir, "cog.yaml")) + require.NoError(t, err) + require.Contains(t, string(cogYAML), `predict: "predict.py:Predictor"`) +} + +func TestPredictToRunMigrationCheck_FixMigratesNestedBasePredictorImport(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `try: + from cog import BasePredictor +except ImportError: + from cog.predictor import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.NoError(t, check.Fix(ctx, findings)) + + runPy, err := os.ReadFile(filepath.Join(dir, "run.py")) + require.NoError(t, err) + require.Contains(t, string(runPy), "from cog import BaseRunner") + require.Contains(t, string(runPy), "from cog.predictor import BaseRunner") + require.Contains(t, string(runPy), "class Runner(BaseRunner):") +} + +func TestPredictToRunMigrationCheck_FixMigratesImportListBasePredictor(t *testing.T) { + tests := []struct { + name string + importCode string + wantImport string + }{ + { + name: "multi-name", + importCode: `from cog import BasePredictor, Input`, + wantImport: `from cog import BaseRunner, Input`, + }, + { + name: "parenthesized", + importCode: `from cog import ( + BasePredictor, + Input, +)`, + wantImport: "BaseRunner,", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" 
+`) + writeFile(t, dir, "predict.py", tt.importCode+` +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.NoError(t, check.Fix(ctx, findings)) + + runPy, err := os.ReadFile(filepath.Join(dir, "run.py")) + require.NoError(t, err) + require.Contains(t, string(runPy), tt.wantImport) + require.Contains(t, string(runPy), "class Runner(BaseRunner):") + }) + } +} + +func TestPredictToRunMigrationCheck_FixRefusesNestedAliasedBasePredictor(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `try: + from cog import BasePredictor as CogBasePredictor +except ImportError: + from cog.predictor import BasePredictor as CogBasePredictor +class Predictor(CogBasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + err = check.Fix(ctx, findings) + require.Error(t, err) + require.Contains(t, err.Error(), "Manual migration required") +} + +func TestPredictToRunMigrationCheck_FixMigratesBasePredictorWhenBaseRunnerImportedElsewhere(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +if TYPE_CHECKING: + from cog import BaseRunner +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.NoError(t, 
check.Fix(ctx, findings)) + + runPy, err := os.ReadFile(filepath.Join(dir, "run.py")) + require.NoError(t, err) + require.Contains(t, string(runPy), "from cog import BaseRunner") + require.NotContains(t, string(runPy), "from cog import BasePredictor") + require.Contains(t, string(runPy), "class Runner(BaseRunner):") +} + +func TestPredictToRunMigrationCheck_FixRefusesNonCogBasePredictor(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from other_pkg import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + err = check.Fix(ctx, findings) + require.Error(t, err) + require.Contains(t, err.Error(), "Manual migration required") +} + +func TestPredictToRunMigrationCheck_FixRefusesShadowedBasePredictorImport(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +from other_pkg import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + err = check.Fix(ctx, findings) + require.Error(t, err) + require.Contains(t, err.Error(), "Manual migration required") +} + +func TestPredictToRunMigrationCheck_FixAddsBaseRunnerToAllFallbackImports(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `try: + from cog import BasePredictor +except 
ImportError: + from cog.predictor import BasePredictor +class Helper(BasePredictor): + def predict(self) -> str: + return "helper" +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.NoError(t, check.Fix(ctx, findings)) + + runPy, err := os.ReadFile(filepath.Join(dir, "run.py")) + require.NoError(t, err) + require.Contains(t, string(runPy), "from cog import BasePredictor, BaseRunner") + require.Contains(t, string(runPy), "from cog.predictor import BasePredictor, BaseRunner") + require.Contains(t, string(runPy), "class Helper(BasePredictor):") + require.Contains(t, string(runPy), "class Runner(BaseRunner):") +} + +func TestPredictToRunMigrationCheck_FixRefusesExistingRunField(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +run: "run.py:Runner" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &PredictToRunMigrationCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.NotEmpty(t, findings) + require.Error(t, check.Fix(ctx, findings)) +} + +func TestPredictToRunMigrationCheck_RunFixesContextForFollowingChecks(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + result, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: dir}, []Check{ + &PredictToRunMigrationCheck{}, + &ConfigPredictRefCheck{}, + }) + require.NoError(t, 
err) + require.Len(t, result.Results, 2) + require.True(t, result.Results[0].Fixed) + require.NoError(t, result.Results[1].Err) + require.Empty(t, result.Results[1].Findings) +} + +func TestPredictToRunMigrationCheck_RunFixSuppressesStalePredictDeprecation(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + result, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: dir}, []Check{ + &PredictToRunMigrationCheck{}, + &ConfigDeprecatedFieldsCheck{}, + }) + require.NoError(t, err) + require.Len(t, result.Results, 2) + require.True(t, result.Results[0].Fixed) + require.Empty(t, result.Results[1].Findings) +} + // buildTestCheckContext creates a CheckContext by loading the cog.yaml in the given dir. func buildTestCheckContext(t *testing.T, dir string) *CheckContext { t.Helper() diff --git a/pkg/doctor/check_predict_to_run_migration.go b/pkg/doctor/check_predict_to_run_migration.go new file mode 100644 index 0000000000..99d44b4343 --- /dev/null +++ b/pkg/doctor/check_predict_to_run_migration.go @@ -0,0 +1,422 @@ +package doctor + +import ( + "fmt" + "os" + "path/filepath" + "regexp" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +var ( + topLevelPredictKeyPattern = regexp.MustCompile(`(?m)^predict\s*:`) + topLevelRunRefPattern = regexp.MustCompile(`(?m)^run:\s*["'][^"']+["']\s*$`) + predictRefPattern = regexp.MustCompile(`(?m)^predict:\s*["']predict\.py:Predictor["']\s*$`) +) + +// PredictToRunMigrationCheck detects deprecated predict interface names and +// migrates the common starter-project shape to run interface names. 
+type PredictToRunMigrationCheck struct{} + +func (c *PredictToRunMigrationCheck) Name() string { return "predict-to-run-migration" } +func (c *PredictToRunMigrationCheck) Group() Group { return GroupConfig } +func (c *PredictToRunMigrationCheck) Description() string { + return "Deprecated predict interface names" +} + +func (c *PredictToRunMigrationCheck) Check(ctx *CheckContext) ([]Finding, error) { + var findings []Finding + if topLevelPredictKeyPattern.Match(ctx.ConfigFile) || predictRefPattern.Match(ctx.ConfigFile) { + findings = append(findings, Finding{ + Severity: SeverityWarning, + Message: "predict in cog.yaml is deprecated; use run with run.py:Runner", + Remediation: "Run cog doctor --fix to migrate predict: to run:", + File: ctx.ConfigFilename, + }) + } + + if predictRefPattern.Match(ctx.ConfigFile) { + if pf, ok := predictMigrationFile(ctx); ok && hasLegacyPredictPythonNames(pf) { + findings = append(findings, Finding{ + Severity: SeverityWarning, + Message: "predict.py uses deprecated Predictor/BasePredictor/predict() names", + Remediation: "Run cog doctor --fix to migrate to Runner/BaseRunner/run()", + File: "predict.py", + }) + } + } + + return findings, nil +} + +func (c *PredictToRunMigrationCheck) Fix(ctx *CheckContext, findings []Finding) error { + if len(findings) == 0 { + return nil + } + + if err := preflightPredictToRunCollisions(ctx); err != nil { + return err + } + pf, ok := predictMigrationFile(ctx) + if !ok { + return fmt.Errorf("cannot migrate predict.py to run.py because predict.py was not found") + } + if !hasLegacyPredictPythonNames(pf) { + return fmt.Errorf("cannot migrate predict.py because no legacy Predictor/BasePredictor/predict() names were found") + } + edits, err := predictToRunMigrationEdits(pf) + if err != nil { + return err + } + source := applyEdits(pf.Source, edits) + + configPath := filepath.Join(ctx.ProjectDir, ctx.ConfigFilename) + configBytes, err := os.ReadFile(configPath) + if err != nil { + return err + } + 
configBytes = predictRefPattern.ReplaceAll(configBytes, []byte(`run: "run.py:Runner"`)) + if err := os.WriteFile(configPath, configBytes, 0o644); err != nil { + return err + } + + oldPath := filepath.Join(ctx.ProjectDir, "predict.py") + newPath := filepath.Join(ctx.ProjectDir, "run.py") + if err := os.WriteFile(newPath, source, 0o644); err != nil { + return err + } + if err := os.Remove(oldPath); err != nil { + return err + } + + ctx.ConfigFile = configBytes + if ctx.Config != nil { + ctx.Config.Predict = "run.py:Runner" + } + if ctx.LoadResult != nil && ctx.LoadResult.Config != nil { + ctx.LoadResult.Config.Predict = "run.py:Runner" + warnings := ctx.LoadResult.Warnings[:0] + for _, warning := range ctx.LoadResult.Warnings { + if warning.Field != "predict" { + warnings = append(warnings, warning) + } + } + ctx.LoadResult.Warnings = warnings + } + delete(ctx.PythonFiles, "predict.py") + parsePythonRef(ctx, "run.py:Runner") + + return nil +} + +func hasLegacyPredictPythonNames(pf *ParsedFile) bool { + root := pf.Tree.RootNode() + return findMigrationClassByName(root, pf.Source, "Predictor") != nil +} + +func preflightPredictToRunCollisions(ctx *CheckContext) error { + if topLevelRunRefPattern.Match(ctx.ConfigFile) { + return fmt.Errorf("automatic migration cannot run when run is already set") + } + if !predictRefPattern.Match(ctx.ConfigFile) { + return fmt.Errorf("Manual migration required: automatic migration only supports predict.py:Predictor") + } + candidate := filepath.Join(ctx.ProjectDir, "run.py") + if _, err := os.Stat(candidate); err == nil { + return fmt.Errorf("cannot migrate predict.py to run.py because run.py already exists") + } else if !os.IsNotExist(err) { + return err + } + return nil +} + +func predictMigrationFile(ctx *CheckContext) (*ParsedFile, bool) { + if pf, ok := ctx.PythonFiles["predict.py"]; ok && pf != nil { + return pf, true + } + path := "predict.py" + source, err := os.ReadFile(filepath.Join(ctx.ProjectDir, path)) + if err != nil { + 
return nil, false + } + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + tree, err := parser.ParseCtx(ctx.ctx, nil, source) + if err != nil { + return nil, false + } + pf := &ParsedFile{ + Path: path, + Source: source, + Tree: tree, + Imports: schemaPython.CollectImports(tree.RootNode(), source), + } + ctx.PythonFiles[path] = pf + return pf, true +} + +func predictToRunMigrationEdits(pf *ParsedFile) ([]byteEdit, error) { + root := pf.Tree.RootNode() + classes := findMigrationClassesByName(root, pf.Source, "Predictor") + if len(classes) != 1 { + return nil, fmt.Errorf("Manual migration required: automatic migration only supports a single Predictor class") + } + classNode := classes[0] + if targetClassUsesAliasedBasePredictor(classNode, pf) { + return nil, fmt.Errorf("Manual migration required: automatic migration does not support aliased BasePredictor inheritance") + } + predictMethods := findMigrationMethodsByName(classNode, pf.Source, "predict") + if len(predictMethods) != 1 || findMigrationMethodByName(classNode, pf.Source, "run") != nil { + return nil, fmt.Errorf("Manual migration required: automatic migration only supports a single Predictor class with a single predict method") + } + + var edits []byteEdit + if nameNode := classNode.ChildByFieldName("name"); nameNode != nil { + edits = append(edits, replaceNode(nameNode, []byte("Runner"))) + } + if nameNode := predictMethods[0].ChildByFieldName("name"); nameNode != nil { + edits = append(edits, replaceNode(nameNode, []byte("run"))) + } + targetBaseNodes := collectTargetBasePredictorNodes(classNode, pf.Source) + if len(targetBaseNodes) > 0 && len(collectCogImportIdentifiers(root, pf.Source, "BasePredictor")) == 0 { + return nil, fmt.Errorf("Manual migration required: automatic migration only supports BasePredictor imported from cog") + } + if len(targetBaseNodes) > 0 && hasUnsupportedBasePredictorImport(root, pf.Source) { + return nil, fmt.Errorf("Manual migration required: automatic 
migration only supports unambiguous BasePredictor imports from cog") + } + for _, node := range targetBaseNodes { + edits = append(edits, replaceNode(node, []byte("BaseRunner"))) + } + edits = append(edits, baseRunnerImportEdits(root, pf.Source, targetBaseNodes)...) + return edits, nil +} + +func replaceNode(node *sitter.Node, replacement []byte) byteEdit { + return byteEdit{start: node.StartByte(), end: node.EndByte(), replacement: replacement} +} + +func findMigrationClassByName(root *sitter.Node, source []byte, name string) *sitter.Node { + classes := findMigrationClassesByName(root, source, name) + if len(classes) == 0 { + return nil + } + return classes[0] +} + +func findMigrationClassesByName(root *sitter.Node, source []byte, name string) []*sitter.Node { + var classes []*sitter.Node + for _, child := range schemaPython.NamedChildren(root) { + classNode := schemaPython.UnwrapClass(child) + if classNode == nil { + continue + } + nameNode := classNode.ChildByFieldName("name") + if nameNode != nil && schemaPython.Content(nameNode, source) == name { + classes = append(classes, classNode) + } + } + return classes +} + +func findMigrationMethodByName(classNode *sitter.Node, source []byte, name string) *sitter.Node { + methods := findMigrationMethodsByName(classNode, source, name) + if len(methods) == 0 { + return nil + } + return methods[0] +} + +func findMigrationMethodsByName(classNode *sitter.Node, source []byte, name string) []*sitter.Node { + body := classNode.ChildByFieldName("body") + if body == nil { + return nil + } + var methods []*sitter.Node + for _, child := range schemaPython.NamedChildren(body) { + funcNode := schemaPython.UnwrapFunction(child) + if funcNode == nil { + continue + } + nameNode := funcNode.ChildByFieldName("name") + if nameNode != nil && schemaPython.Content(nameNode, source) == name { + methods = append(methods, funcNode) + } + } + return methods +} + +func collectTargetBasePredictorNodes(classNode *sitter.Node, source []byte) 
[]*sitter.Node { + superclasses := classNode.ChildByFieldName("superclasses") + if superclasses == nil { + return nil + } + return collectMigrationIdentifiers(superclasses, source, "BasePredictor") +} + +func targetClassUsesAliasedBasePredictor(classNode *sitter.Node, pf *ParsedFile) bool { + superclasses := classNode.ChildByFieldName("superclasses") + if superclasses == nil { + return false + } + basePredictorAliases := collectBasePredictorImportAliases(pf.Tree.RootNode(), pf.Source) + for _, node := range collectIdentifiers(superclasses) { + name := schemaPython.Content(node, pf.Source) + if basePredictorAliases[name] { + return true + } + } + return false +} + +func baseRunnerImportEdits(root *sitter.Node, source []byte, targetBaseNodes []*sitter.Node) []byteEdit { + basePredictorImports := collectCogImportIdentifiers(root, source, "BasePredictor") + if len(basePredictorImports) == 0 { + return nil + } + basePredictorUsedOutsideImportAndTargetBase := basePredictorUsedOutside(root, source, append(targetBaseNodes, basePredictorImports...)) + if len(targetBaseNodes) == 0 && basePredictorUsedOutsideImportAndTargetBase { + return nil + } + if basePredictorUsedOutsideImportAndTargetBase { + edits := make([]byteEdit, 0, len(basePredictorImports)) + for _, node := range basePredictorImports { + edits = append(edits, byteEdit{start: node.EndByte(), end: node.EndByte(), replacement: []byte(", BaseRunner")}) + } + return edits + } + + edits := make([]byteEdit, 0, len(basePredictorImports)) + for _, node := range basePredictorImports { + edits = append(edits, replaceNode(node, []byte("BaseRunner"))) + } + return edits +} + +func collectCogImportIdentifiers(root *sitter.Node, source []byte, name string) []*sitter.Node { + var nodes []*sitter.Node + if root.Type() == "import_from_statement" { + moduleNode := root.ChildByFieldName("module_name") + if moduleNode != nil && isCogBaseModule(schemaPython.Content(moduleNode, source)) { + nodes = append(nodes, 
collectImportIdentifierNodes(root, moduleNode, source, name)...) + } + } + for _, child := range schemaPython.NamedChildren(root) { + nodes = append(nodes, collectCogImportIdentifiers(child, source, name)...) + } + return nodes +} + +func hasUnsupportedBasePredictorImport(root *sitter.Node, source []byte) bool { + if root.Type() == "import_from_statement" { + moduleNode := root.ChildByFieldName("module_name") + if moduleNode != nil && !isCogBaseModule(schemaPython.Content(moduleNode, source)) { + if len(collectImportIdentifierNodes(root, moduleNode, source, "BasePredictor")) > 0 { + return true + } + } + } + for _, child := range schemaPython.NamedChildren(root) { + if hasUnsupportedBasePredictorImport(child, source) { + return true + } + } + return false +} + +func collectImportIdentifierNodes(node *sitter.Node, moduleNode *sitter.Node, source []byte, name string) []*sitter.Node { + if node.StartByte() == moduleNode.StartByte() && node.EndByte() == moduleNode.EndByte() { + return nil + } + var nodes []*sitter.Node + if node.Type() == "dotted_name" && schemaPython.Content(node, source) == name { + nodes = append(nodes, node) + } + for _, child := range schemaPython.AllChildren(node) { + nodes = append(nodes, collectImportIdentifierNodes(child, moduleNode, source, name)...) 
+ } + return nodes +} + +func isCogBaseModule(module string) bool { + return module == "cog" || module == "cog.predictor" +} + +func collectBasePredictorImportAliases(root *sitter.Node, source []byte) map[string]bool { + aliases := make(map[string]bool) + if root.Type() == "import_from_statement" { + moduleNode := root.ChildByFieldName("module_name") + if moduleNode != nil && isCogBaseModule(schemaPython.Content(moduleNode, source)) { + for _, node := range collectAliasedImportNodes(root) { + nameNode := node.ChildByFieldName("name") + aliasNode := node.ChildByFieldName("alias") + if nameNode != nil && aliasNode != nil && schemaPython.Content(nameNode, source) == "BasePredictor" { + aliases[schemaPython.Content(aliasNode, source)] = true + } + } + } + } + for _, child := range schemaPython.NamedChildren(root) { + for alias := range collectBasePredictorImportAliases(child, source) { + aliases[alias] = true + } + } + return aliases +} + +func collectAliasedImportNodes(node *sitter.Node) []*sitter.Node { + var nodes []*sitter.Node + if node.Type() == "aliased_import" { + nodes = append(nodes, node) + } + for _, child := range schemaPython.NamedChildren(node) { + nodes = append(nodes, collectAliasedImportNodes(child)...) 
+ } + return nodes +} + +func basePredictorUsedOutside(node *sitter.Node, source []byte, ignored []*sitter.Node) bool { + if node.Type() == "identifier" && schemaPython.Content(node, source) == "BasePredictor" && !nodeInList(node, ignored) { + return true + } + for _, child := range schemaPython.NamedChildren(node) { + if basePredictorUsedOutside(child, source, ignored) { + return true + } + } + return false +} + +func nodeInList(node *sitter.Node, nodes []*sitter.Node) bool { + for _, candidate := range nodes { + if node.StartByte() == candidate.StartByte() && node.EndByte() == candidate.EndByte() { + return true + } + } + return false +} + +func collectMigrationIdentifiers(node *sitter.Node, source []byte, name string) []*sitter.Node { + var nodes []*sitter.Node + if node.Type() == "identifier" && schemaPython.Content(node, source) == name { + nodes = append(nodes, node) + } + for _, child := range schemaPython.NamedChildren(node) { + nodes = append(nodes, collectMigrationIdentifiers(child, source, name)...) + } + return nodes +} + +func collectIdentifiers(node *sitter.Node) []*sitter.Node { + var nodes []*sitter.Node + if node.Type() == "identifier" { + nodes = append(nodes, node) + } + for _, child := range schemaPython.NamedChildren(node) { + nodes = append(nodes, collectIdentifiers(child)...) 
+ } + return nodes +} diff --git a/pkg/doctor/check_python_test.go b/pkg/doctor/check_python_test.go index 8dd79ad489..f136646e47 100644 --- a/pkg/doctor/check_python_test.go +++ b/pkg/doctor/check_python_test.go @@ -38,6 +38,50 @@ class Predictor(BasePredictor): require.Empty(t, findings) } +func TestMissingTypeAnnotationsCheck_DetectsRunMissingReturn(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "run.py", `from cog import BaseRunner + +class Runner(BaseRunner): + def run(self, text: str): + return text +`) + ctx := &CheckContext{ + ctx: context.Background(), + ProjectDir: dir, + Config: &config.Config{Predict: "run.py:Runner"}, + PythonFiles: parsePythonFiles(t, dir, "run.py"), + } + + check := &MissingTypeAnnotationsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Contains(t, findings[0].Message, "Runner.run()") +} + +func TestMissingTypeAnnotationsCheck_DetectsLegacyPredictMissingReturn(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, text: str): + return text +`) + ctx := &CheckContext{ + ctx: context.Background(), + ProjectDir: dir, + Config: &config.Config{Predict: "predict.py:Predictor"}, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &MissingTypeAnnotationsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Contains(t, findings[0].Message, "Predictor.predict()") +} + func TestPydanticBaseModelCheck_Detects(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path diff --git a/pkg/doctor/check_python_type_annotations.go b/pkg/doctor/check_python_type_annotations.go index 511cfcf64d..637d6f6146 100644 --- a/pkg/doctor/check_python_type_annotations.go +++ b/pkg/doctor/check_python_type_annotations.go @@ -24,7 +24,7 @@ func (c *MissingTypeAnnotationsCheck) 
Check(ctx *CheckContext) ([]Finding, error var findings []Finding if ctx.Config.Predict != "" { - f := checkMethodAnnotations(ctx, ctx.Config.Predict, "predict") + f := checkPredictMethodAnnotations(ctx, ctx.Config.Predict) findings = append(findings, f...) } @@ -40,6 +40,25 @@ func (c *MissingTypeAnnotationsCheck) Fix(_ *CheckContext, _ []Finding) error { return ErrNoAutoFix } +func checkPredictMethodAnnotations(ctx *CheckContext, ref string) []Finding { + fileName, className := splitPredictRef(ref) + if fileName == "" || className == "" { + return nil + } + pf, ok := ctx.PythonFiles[fileName] + if !ok { + return nil + } + classNode := findClass(pf.Tree.RootNode(), pf.Source, className) + if classNode == nil { + return nil + } + if findMethod(classNode, pf.Source, "run") != nil { + return checkMethodAnnotations(ctx, ref, "run") + } + return checkMethodAnnotations(ctx, ref, "predict") +} + // checkMethodAnnotations checks that the given method has a return type annotation. func checkMethodAnnotations(ctx *CheckContext, ref string, methodName string) []Finding { fileName, className := splitPredictRef(ref) diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index 5a2873138b..95038f3817 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -7,6 +7,7 @@ func AllChecks() []Check { // Config checks &ConfigParseCheck{}, &ConfigSchemaCheck{}, + &PredictToRunMigrationCheck{}, &ConfigDeprecatedFieldsCheck{}, &ConfigPredictRefCheck{}, diff --git a/pkg/schema/errors.go b/pkg/schema/errors.go index 2052da239a..40f7b595ee 100644 --- a/pkg/schema/errors.go +++ b/pkg/schema/errors.go @@ -23,6 +23,7 @@ const ( ErrDefaultFactoryNotSupported ErrInvalidConstraint ErrInvalidPredictRef + ErrMethodConflict ErrOptionalOutput ErrConcatIteratorNotStr ErrChoicesNotResolvable diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index c711ea0960..8d11928626 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -64,6 +64,10 @@ 
func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi if err != nil { return nil, err } + actualMethodName := methodName + if nameNode := funcNode.ChildByFieldName("name"); nameNode != nil { + actualMethodName = Content(nameNode, source) + } // 6. Check if method (has self first param) paramsNode := funcNode.ChildByFieldName("parameters") @@ -74,7 +78,7 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi // 7. Extract parameters paramCtx := &inputParseContext{ - methodName: methodName, + methodName: actualMethodName, imports: imports, registry: inputRegistry, scope: moduleScope, @@ -88,7 +92,7 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi // 8. Extract return type returnAnn := funcNode.ChildByFieldName("return_type") if returnAnn == nil { - return nil, schema.WrapError(schema.ErrMissingReturnType, methodName, nil) + return nil, schema.WrapError(schema.ErrMissingReturnType, actualMethodName, nil) } returnTypeAnn, err := parseTypeAnnotation(returnAnn, source) if err != nil { @@ -1191,6 +1195,9 @@ func findTargetFunction(root *sitter.Node, source []byte, predictRef, methodName } nameNode := classNode.ChildByFieldName("name") if nameNode != nil && Content(nameNode, source) == predictRef { + if methodName == "predict" { + return findPredictMethodInClass(root, classNode, source, predictRef) + } return findMethodInClass(classNode, source, predictRef, methodName) } } @@ -1213,6 +1220,95 @@ func findTargetFunction(root *sitter.Node, source []byte, predictRef, methodName return nil, schema.WrapError(schema.ErrPredictorNotFound, predictRef, nil) } +func findPredictMethodInClass(root, classNode *sitter.Node, source []byte, className string) (*sitter.Node, error) { + runNode, predictNode := collectPredictMethods(root, classNode, source, className, map[string]bool{}) + if runNode != nil && predictNode != nil { + return nil, schema.WrapError(schema.ErrMethodConflict, fmt.Sprintf("%s 
must define either run() or predict(), not both", className), nil) + } + if runNode != nil { + return runNode, nil + } + if predictNode != nil { + fmt.Fprintf(os.Stderr, "cog: warning: %s.predict() is deprecated; use run() instead\n", className) + return predictNode, nil + } + return nil, schema.WrapError(schema.ErrMethodNotFound, fmt.Sprintf("%s must define run() or predict()", className), nil) +} + +func collectPredictMethods(root, classNode *sitter.Node, source []byte, className string, seen map[string]bool) (*sitter.Node, *sitter.Node) { + if seen[className] { + return nil, nil + } + seen[className] = true + + body := classNode.ChildByFieldName("body") + if body == nil { + return nil, nil + } + + var runNode *sitter.Node + var predictNode *sitter.Node + for _, child := range NamedChildren(body) { + funcNode := UnwrapFunction(child) + if funcNode == nil { + continue + } + nameNode := funcNode.ChildByFieldName("name") + if nameNode == nil { + continue + } + switch Content(nameNode, source) { + case "run": + runNode = funcNode + case "predict": + predictNode = funcNode + } + } + + for _, parent := range classParentNames(classNode, source) { + parentNode := findClassByName(root, source, parent) + if parentNode == nil { + continue + } + parentRun, parentPredict := collectPredictMethods(root, parentNode, source, parent, seen) + if runNode == nil { + runNode = parentRun + } + if predictNode == nil { + predictNode = parentPredict + } + } + return runNode, predictNode +} + +func findClassByName(root *sitter.Node, source []byte, name string) *sitter.Node { + for _, child := range NamedChildren(root) { + classNode := UnwrapClass(child) + if classNode == nil { + continue + } + nameNode := classNode.ChildByFieldName("name") + if nameNode != nil && Content(nameNode, source) == name { + return classNode + } + } + return nil +} + +func classParentNames(classNode *sitter.Node, source []byte) []string { + supers := classNode.ChildByFieldName("superclasses") + if supers == nil { 
+ return nil + } + parents := []string{} + for _, child := range NamedChildren(supers) { + if child.Type() == "identifier" { + parents = append(parents, Content(child, source)) + } + } + return parents +} + func findMethodInClass(classNode *sitter.Node, source []byte, className, methodName string) (*sitter.Node, error) { body := classNode.ChildByFieldName("body") if body == nil { diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 2b3d59445a..3745bdcace 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -67,6 +67,88 @@ class Predictor(BasePredictor): require.Equal(t, schema.TypeString, info.Output.Primitive) } +func TestSimpleStringRunner(t *testing.T) { + source := ` +from cog import BasePredictor + +class Predictor(BasePredictor): + def run(self, s: str) -> str: + return "hello " + s +` + info := parse(t, source, "Predictor") + require.Equal(t, 1, info.Inputs.Len()) + + s, ok := info.Inputs.Get("s") + require.True(t, ok) + require.Equal(t, schema.TypeString, s.FieldType.Primitive) + require.Equal(t, schema.Required, s.FieldType.Repetition) + require.Nil(t, s.Default) + require.True(t, s.IsRequired()) + + require.Equal(t, schema.SchemaPrimitive, info.Output.Kind) + require.Equal(t, schema.TypeString, info.Output.Primitive) +} + +func TestRunnerWithBothRunAndPredictErrors(t *testing.T) { + source := ` +from cog import BasePredictor + +class Predictor(BasePredictor): + def run(self, s: str) -> str: + return s + + def predict(self, s: str) -> str: + return s +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrMethodConflict, se.Kind) + require.Contains(t, se.Message, "Predictor must define either run() or predict(), not both") +} + +func TestRunnerWithNoRunOrPredictErrors(t *testing.T) { + source := ` +from cog import BasePredictor + +class Predictor(BasePredictor): + def setup(self) -> None: + pass +` + se := parseErr(t, source, "Predictor", 
schema.ModePredict) + require.Equal(t, schema.ErrMethodNotFound, se.Kind) + require.Contains(t, se.Message, "Predictor must define run() or predict()") +} + +func TestRunnerWithInheritedRun(t *testing.T) { + source := ` +from cog import BaseRunner + +class Shared(BaseRunner): + def run(self, s: str) -> str: + return "hello " + s + +class Runner(Shared): + pass +` + info := parse(t, source, "Runner") + require.Equal(t, 1, info.Inputs.Len()) +} + +func TestRunnerWithInheritedRunAndDirectPredictErrors(t *testing.T) { + source := ` +from cog import BaseRunner + +class Shared(BaseRunner): + def run(self, s: str) -> str: + return s + +class Runner(Shared): + def predict(self, s: str) -> str: + return s +` + se := parseErr(t, source, "Runner", schema.ModePredict) + require.Equal(t, schema.ErrMethodConflict, se.Kind) +} + func TestMultipleInputsWithDefaults(t *testing.T) { source := ` from cog import BasePredictor, Input, Path diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 6c1fb8ce44..e3ae649fa6 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -4,14 +4,14 @@ This package provides the core types and classes for building Cog predictors. 
Example: - from cog import BasePredictor, Input, Path + from cog import BaseRunner, Input, Path - class Predictor(BasePredictor): + class Runner(BaseRunner): def setup(self): # Load model weights self.model = load_model() - def predict( + def run( self, prompt: str = Input(description="Input prompt"), image: Path = Input(description="Input image"), @@ -27,7 +27,7 @@ def predict( from ._version import __version__ from .input import FieldInfo, Input from .model import BaseModel -from .predictor import BasePredictor +from .predictor import BasePredictor, BaseRunner from .types import ( AsyncConcatenateIterator, ConcatenateIterator, @@ -115,6 +115,7 @@ def current_scope() -> object: # Version "__version__", # Core classes + "BaseRunner", "BasePredictor", "BaseModel", "Opaque", diff --git a/python/cog/_inspector.py b/python/cog/_inspector.py index fd062930f1..c7a7e0d3cc 100644 --- a/python/cog/_inspector.py +++ b/python/cog/_inspector.py @@ -10,6 +10,7 @@ import re import sys import typing +import warnings from dataclasses import MISSING, Field from enum import Enum from types import ModuleType, UnionType @@ -19,6 +20,7 @@ from .coder import Coder from .input import FieldInfo from .model import BaseModel +from .predictor import BasePredictor, _user_method_owner from .types import AsyncConcatenateIterator, ConcatenateIterator try: @@ -93,6 +95,36 @@ def _validate_predict(f: Callable[..., Any], f_name: str, is_class_fn: bool) -> raise ValueError(f"{f_name}() must have a return type annotation") +def _selected_predict_method( + cls: type[Any], fullname: str +) -> tuple[str, Callable[..., Any]]: + run_owner = _user_method_owner(cls, "run") + predict_owner = _user_method_owner(cls, "predict") + defines_run = run_owner is not None + defines_predict = predict_owner is not None + if defines_run and defines_predict: + raise ValueError(f"{fullname} must define either run() or predict(), not both") + if defines_run: + return "run", cls.run + if defines_predict: + warnings.warn( + 
f"{fullname}.predict() is deprecated; use run() instead", + DeprecationWarning, + stacklevel=3, + ) + return "predict", cls.predict + raise ValueError(f"run or predict method not found: {fullname}") + + +def _warn_if_base_predictor_in_mro(cls: type[Any]) -> None: + if any(base is BasePredictor for base in inspect.getmro(cls)[1:]): + warnings.warn( + "BasePredictor is deprecated; use BaseRunner instead", + DeprecationWarning, + stacklevel=3, + ) + + def _validate_input_constraints( name: str, ft: adt.FieldType, field_info: FieldInfo ) -> None: @@ -483,14 +515,12 @@ def create_predictor(module_name: str, predictor_name: str) -> adt.PredictorInfo p = getattr(module, predictor_name) if inspect.isclass(p): - if not hasattr(p, "predict"): - raise ValueError(f"predict method not found: {fullname}") - if hasattr(p, "setup"): _validate_setup(_unwrap(p.setup)) - predict_fn_name = "predict" - predict_fn = _unwrap(getattr(p, predict_fn_name)) + _warn_if_base_predictor_in_mro(p) + predict_fn_name, selected_predict_fn = _selected_predict_method(p, fullname) + predict_fn = _unwrap(selected_predict_fn) is_class_fn = True elif inspect.isfunction(p): diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 12fa6fde8d..c564064bfc 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -10,27 +10,27 @@ import inspect import os import sys +import warnings from typing import Any, Optional, Union from .types import Path -class BasePredictor: +class BaseRunner: """ - Base class for Cog predictors. + Base class for Cog runners. - Subclass this to define your model's prediction interface. Override - the `setup` method to load your model, and the `predict` method to - run predictions. + Subclass this to define your model's run interface. Override the `setup` + method to load your model, and the `run` method to execute it. 
Example: - from cog import BasePredictor, Input, Path + from cog import BaseRunner, Input, Path - class Predictor(BasePredictor): + class Runner(BaseRunner): def setup(self): self.model = load_model() - def predict(self, prompt: str = Input(description="Input text")) -> str: + def run(self, prompt: str = Input(description="Input text")) -> str: self.record_metric("temperature", 0.7) return self.model.generate(prompt) """ @@ -40,9 +40,9 @@ def setup( weights: Optional[Union[Path, str]] = None, ) -> None: """ - Prepare the model for predictions. + Prepare the model for runs. - This method is called once when the predictor is initialized. Use it + This method is called once when the runner is initialized. Use it to load model weights and do any other one-time setup. Args: @@ -50,24 +50,33 @@ def setup( """ pass - def predict(self, **kwargs: Any) -> Any: + def run(self, *args: Any, **kwargs: Any) -> Any: """ - Run a single prediction. + Run the model once. Override this method to implement your model's prediction logic. Input parameters should be annotated with types and optionally use Input() for additional metadata. Args: - **kwargs: Prediction inputs as defined by the method signature. + *args: Positional run inputs as defined by the method signature. + **kwargs: Keyword run inputs as defined by the method signature. Returns: The prediction output. Raises: - NotImplementedError: If predict is not implemented. + NotImplementedError: If run is not implemented. 
""" - raise NotImplementedError("predict has not been implemented by parent class.") + run_owner = _user_method_owner(self.__class__, "run") + predict_owner = _user_method_owner(self.__class__, "predict") + if predict_owner is not None and run_owner is None: + return self.predict(*args, **kwargs) + raise NotImplementedError("run has not been implemented by parent class.") + + def predict(self, *args: Any, **kwargs: Any) -> Any: + """Deprecated compatibility bridge to run().""" + return self.run(*args, **kwargs) @property def scope(self) -> Any: @@ -106,8 +115,8 @@ def record_metric(self, key: str, value: Any, mode: str = "replace") -> None: Example:: - class Predictor(BasePredictor): - def predict(self, prompt: str) -> str: + class Runner(BaseRunner): + def run(self, prompt: str) -> str: self.record_metric("temperature", 0.7) self.record_metric("token_count", 1, mode="incr") return self.model.generate(prompt) @@ -115,9 +124,48 @@ def predict(self, prompt: str) -> str: self.scope.record_metric(key, value, mode=mode) -def load_predictor_from_ref(ref: str) -> BasePredictor: +class BasePredictor(BaseRunner): + """Deprecated compatibility alias for BaseRunner.""" + + +def _user_method_owner(cls: type[Any], method_name: str) -> type[Any] | None: + for owner in inspect.getmro(cls): + if owner in {BaseRunner, BasePredictor, object}: + break + value = owner.__dict__.get(method_name) + if callable(value): + return owner + return None + + +def _validate_runner_class(cls: type[Any], class_name: str) -> None: + run_owner = _user_method_owner(cls, "run") + predict_owner = _user_method_owner(cls, "predict") + defines_run = run_owner is not None + defines_predict = predict_owner is not None + if defines_run and defines_predict: + raise ValueError( + f"{class_name} must define either run() or predict(), not both" + ) + if not defines_run and not defines_predict: + raise ValueError(f"run or predict method not found: {class_name}") + if defines_predict: + warnings.warn( + 
f"{class_name}.predict() is deprecated; use run() instead", + DeprecationWarning, + stacklevel=3, + ) + if any(base is BasePredictor for base in inspect.getmro(cls)[1:]): + warnings.warn( + "BasePredictor is deprecated; use BaseRunner instead", + DeprecationWarning, + stacklevel=3, + ) + + +def load_predictor_from_ref(ref: str) -> BaseRunner: """Load a predictor from a module:class reference (e.g. 'predict.py:Predictor').""" - module_path, class_name = ref.split(":", 1) if ":" in ref else (ref, "Predictor") + module_path, explicit_class_name = ref.split(":", 1) if ":" in ref else (ref, None) module_name = os.path.basename(module_path).replace(".py", "") # Use spec_from_file_location to load from file path @@ -129,14 +177,37 @@ def load_predictor_from_ref(ref: str) -> BasePredictor: sys.modules[module_name] = module spec.loader.exec_module(module) + if explicit_class_name is None: + if hasattr(module, "Runner"): + if hasattr(module, "Predictor"): + warnings.warn( + "Both Runner and Predictor are defined; using Runner. 
Specify a class " + "name explicitly if this is not intended.", + UserWarning, + stacklevel=2, + ) + class_name = "Runner" + elif hasattr(module, "Predictor"): + warnings.warn( + "Predictor is deprecated; use Runner instead", + DeprecationWarning, + stacklevel=2, + ) + class_name = "Predictor" + else: + raise AttributeError(f"module {module_name!r} has no Runner or Predictor") + else: + class_name = explicit_class_name + predictor = getattr(module, class_name) # It could be a class or a function (for training) if inspect.isclass(predictor): + _validate_runner_class(predictor, class_name) return predictor() return predictor -def has_setup_weights(predictor: BasePredictor) -> bool: +def has_setup_weights(predictor: BaseRunner) -> bool: """Check if predictor's setup accepts a weights parameter.""" if not hasattr(predictor, "setup"): return False @@ -144,7 +215,7 @@ def has_setup_weights(predictor: BasePredictor) -> bool: return "weights" in sig.parameters -def extract_setup_weights(predictor: BasePredictor) -> Optional[Union[Path, str]]: +def extract_setup_weights(predictor: BaseRunner) -> Optional[Union[Path, str]]: """Extract weights from environment for setup.""" weights = os.environ.get("COG_WEIGHTS") if weights: diff --git a/python/tests/test_inspector.py b/python/tests/test_inspector.py index a205a660c3..784d6e2d06 100644 --- a/python/tests/test_inspector.py +++ b/python/tests/test_inspector.py @@ -1,10 +1,143 @@ +from pathlib import Path from typing import Annotated, List, Optional import pytest from cog import BaseModel, Opaque from cog import _adt as adt -from cog._inspector import _create_predictor_info +from cog._inspector import _create_predictor_info, create_predictor + + +def test_inspector_uses_run_method_for_classes( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module_name = "runner_module_run_method" + (tmp_path / f"{module_name}.py").write_text( + "class Runner:\n def run(self, value: str) -> str:\n return value\n" + ) + 
monkeypatch.syspath_prepend(str(tmp_path)) + + info = create_predictor(module_name, "Runner") + assert "value" in info.inputs + + +def test_inspector_warns_for_legacy_predict_method( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module_name = "runner_module_predict_method" + (tmp_path / f"{module_name}.py").write_text( + "class Runner:\n" + " def predict(self, value: str) -> str:\n" + " return value\n" + ) + monkeypatch.syspath_prepend(str(tmp_path)) + + with pytest.warns(DeprecationWarning, match=r"Runner\.predict\(\) is deprecated"): + info = create_predictor(module_name, "Runner") + assert "value" in info.inputs + + +def test_inspector_warns_for_base_predictor_inheritance( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module_name = "runner_module_base_predictor" + (tmp_path / f"{module_name}.py").write_text( + "from cog import BasePredictor\n" + "class Runner(BasePredictor):\n" + " def run(self, value: str) -> str:\n" + " return value\n" + ) + monkeypatch.syspath_prepend(str(tmp_path)) + + with pytest.warns(DeprecationWarning, match="BasePredictor is deprecated"): + info = create_predictor(module_name, "Runner") + assert "value" in info.inputs + + +def test_inspector_supports_inherited_run( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module_name = "runner_module_inherited_run" + (tmp_path / f"{module_name}.py").write_text( + "from cog import BaseRunner\n" + "class Shared(BaseRunner):\n" + " def run(self, value: str) -> str:\n" + " return value\n" + "class Runner(Shared):\n" + " pass\n" + ) + monkeypatch.syspath_prepend(str(tmp_path)) + + info = create_predictor(module_name, "Runner") + assert "value" in info.inputs + + +def test_inspector_rejects_inherited_run_and_direct_predict( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module_name = "runner_module_inherited_conflict" + (tmp_path / f"{module_name}.py").write_text( + "from cog import BaseRunner\n" + "class Shared(BaseRunner):\n" + " 
def run(self, value: str) -> str:\n" + " return value\n" + "class Runner(Shared):\n" + " def predict(self, value: str) -> str:\n" + " return value\n" + ) + monkeypatch.syspath_prepend(str(tmp_path)) + + with pytest.raises(ValueError, match=r"either run\(\) or predict\(\)"): + create_predictor(module_name, "Runner") + + +def test_inspector_warns_for_inherited_legacy_predict( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module_name = "runner_module_inherited_predict" + (tmp_path / f"{module_name}.py").write_text( + "from cog import BasePredictor\n" + "class Shared(BasePredictor):\n" + " def predict(self, value: str) -> str:\n" + " return value\n" + "class Predictor(Shared):\n" + " pass\n" + ) + monkeypatch.syspath_prepend(str(tmp_path)) + + with pytest.warns(DeprecationWarning, match=r"predict\(\) is deprecated"): + info = create_predictor(module_name, "Predictor") + assert "value" in info.inputs + + +def test_inspector_rejects_class_with_run_and_predict( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module_name = "runner_module_conflict" + (tmp_path / f"{module_name}.py").write_text( + "class Runner:\n" + " def run(self, value: str) -> str:\n" + " return value\n" + " def predict(self, value: str) -> str:\n" + " return value\n" + ) + monkeypatch.syspath_prepend(str(tmp_path)) + + with pytest.raises(ValueError, match=r"either run\(\) or predict\(\)"): + create_predictor(module_name, "Runner") + + +def test_inspector_errors_when_class_has_no_run_or_predict( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module_name = "runner_module_missing_method" + (tmp_path / f"{module_name}.py").write_text( + "class Runner:\n def setup(self) -> None:\n pass\n" + ) + monkeypatch.syspath_prepend(str(tmp_path)) + + with pytest.raises(ValueError, match="run.*predict|predict.*run"): + create_predictor(module_name, "Runner") class ExternalObject: diff --git a/python/tests/test_predictor.py b/python/tests/test_predictor.py index 
76340ce359..ebde3294d9 100644 --- a/python/tests/test_predictor.py +++ b/python/tests/test_predictor.py @@ -1,8 +1,220 @@ """Tests for cog.predictor module (BasePredictor).""" +from pathlib import Path as FilePath from typing import Optional -from cog import BasePredictor, Path +import pytest + +from cog import BasePredictor, BaseRunner, Path + + +def test_base_runner_run_and_predict_bridge() -> None: + class MyRunner(BaseRunner): + def run(self, text: str) -> str: + return text.upper() + + runner = MyRunner() + assert runner.run(text="hello") == "HELLO" + assert runner.predict("hello") == "HELLO" + assert runner.predict(text="hello") == "HELLO" + + +def test_base_runner_run_delegates_to_legacy_predict_with_positional_args() -> None: + class MyRunner(BaseRunner): + def predict(self, text: str) -> str: + return text.upper() + + runner = MyRunner() + assert runner.run("hello") == "HELLO" + assert runner.run(text="hello") == "HELLO" + + +def test_base_predictor_is_legacy_subclass() -> None: + assert issubclass(BasePredictor, BaseRunner) + + +def test_load_predictor_from_ref_defaults_to_runner(tmp_path: FilePath) -> None: + model = tmp_path / "run.py" + model.write_text( + "from cog import BaseRunner\n" + "class Runner(BaseRunner):\n" + " def run(self, text: str) -> str:\n" + " return text.upper()\n" + ) + + from cog.predictor import load_predictor_from_ref + + runner = load_predictor_from_ref(str(model)) + assert runner.run(text="hello") == "HELLO" + + +def test_load_predictor_from_ref_warns_for_legacy_predictor_class( + tmp_path: FilePath, +) -> None: + model = tmp_path / "predict.py" + model.write_text( + "from cog import BaseRunner\n" + "class Predictor(BaseRunner):\n" + " def run(self, text: str) -> str:\n" + " return text.upper()\n" + ) + + from cog.predictor import load_predictor_from_ref + + with pytest.warns(DeprecationWarning, match="Predictor is deprecated"): + runner = load_predictor_from_ref(str(model)) + assert runner.run(text="hello") == "HELLO" + + +def 
test_load_predictor_from_ref_prefers_runner_when_both_default_classes_exist( + tmp_path: FilePath, +) -> None: + model = tmp_path / "run.py" + model.write_text( + "from cog import BaseRunner\n" + "class Runner(BaseRunner):\n" + " def run(self, text: str) -> str:\n" + " return 'runner:' + text\n" + "class Predictor(BaseRunner):\n" + " def run(self, text: str) -> str:\n" + " return 'predictor:' + text\n" + ) + + from cog.predictor import load_predictor_from_ref + + with pytest.warns(UserWarning, match="Both Runner and Predictor"): + runner = load_predictor_from_ref(str(model)) + assert runner.run(text="hello") == "runner:hello" + + +def test_load_predictor_from_ref_rejects_run_and_predict( + tmp_path: FilePath, +) -> None: + model = tmp_path / "run.py" + model.write_text( + "from cog import BaseRunner\n" + "class Runner(BaseRunner):\n" + " def run(self, text: str) -> str:\n" + " return text\n" + " def predict(self, text: str) -> str:\n" + " return text\n" + ) + + from cog.predictor import load_predictor_from_ref + + with pytest.raises( + ValueError, match=r"define either run\(\) or predict\(\), not both" + ): + load_predictor_from_ref(str(model)) + + +def test_load_predictor_from_ref_rejects_missing_run_or_predict( + tmp_path: FilePath, +) -> None: + model = tmp_path / "run.py" + model.write_text( + "from cog import BaseRunner\nclass Runner(BaseRunner):\n pass\n" + ) + + from cog.predictor import load_predictor_from_ref + + with pytest.raises(ValueError, match="run or predict"): + load_predictor_from_ref(str(model)) + + +def test_load_predictor_from_ref_warns_for_class_predict_method( + tmp_path: FilePath, +) -> None: + model = tmp_path / "run.py" + model.write_text( + "from cog import BaseRunner\n" + "class Runner(BaseRunner):\n" + " def predict(self, text: str) -> str:\n" + " return text.upper()\n" + ) + + from cog.predictor import load_predictor_from_ref + + with pytest.warns(DeprecationWarning, match=r"Runner\.predict\(\) is deprecated"): + runner = 
load_predictor_from_ref(str(model)) + assert runner.predict(text="hello") == "HELLO" + + +def test_load_predictor_from_ref_warns_for_base_predictor_inheritance( + tmp_path: FilePath, +) -> None: + model = tmp_path / "run.py" + model.write_text( + "from cog import BasePredictor\n" + "class Runner(BasePredictor):\n" + " def run(self, text: str) -> str:\n" + " return text.upper()\n" + ) + + from cog.predictor import load_predictor_from_ref + + with pytest.warns(DeprecationWarning, match="BasePredictor is deprecated"): + runner = load_predictor_from_ref(str(model)) + assert runner.run(text="hello") == "HELLO" + + +def test_load_predictor_from_ref_supports_inherited_run( + tmp_path: FilePath, +) -> None: + model = tmp_path / "run.py" + model.write_text( + "from cog import BaseRunner\n" + "class Shared(BaseRunner):\n" + " def run(self, text: str) -> str:\n" + " return text.upper()\n" + "class Runner(Shared):\n" + " pass\n" + ) + + from cog.predictor import load_predictor_from_ref + + runner = load_predictor_from_ref(str(model)) + assert runner.run(text="hello") == "HELLO" + + +def test_load_predictor_from_ref_rejects_inherited_run_and_direct_predict( + tmp_path: FilePath, +) -> None: + model = tmp_path / "run.py" + model.write_text( + "from cog import BaseRunner\n" + "class Shared(BaseRunner):\n" + " def run(self, text: str) -> str:\n" + " return text\n" + "class Runner(Shared):\n" + " def predict(self, text: str) -> str:\n" + " return text\n" + ) + + from cog.predictor import load_predictor_from_ref + + with pytest.raises(ValueError, match=r"either run\(\) or predict\(\)"): + load_predictor_from_ref(str(model)) + + +def test_load_predictor_from_ref_warns_for_inherited_legacy_predict( + tmp_path: FilePath, +) -> None: + model = tmp_path / "predict.py" + model.write_text( + "from cog import BasePredictor\n" + "class Shared(BasePredictor):\n" + " def predict(self, text: str) -> str:\n" + " return text.upper()\n" + "class Predictor(Shared):\n" + " pass\n" + ) + + from 
cog.predictor import load_predictor_from_ref + + with pytest.warns(DeprecationWarning, match=r"predict\(\) is deprecated"): + runner = load_predictor_from_ref(str(model)) + assert runner.predict(text="hello") == "HELLO" class TestBasePredictor: @@ -23,7 +235,7 @@ def test_default_predict_raises(self) -> None: predictor.predict() assert False, "Should have raised NotImplementedError" except NotImplementedError as e: - assert "predict has not been implemented" in str(e) + assert "run has not been implemented" in str(e) def test_setup_is_optional(self) -> None: class MyPredictor(BasePredictor):