diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..afa9e06a2 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +**/__pycache__/ +.pytest_cache/ +.git +.github +.vscode +.DS_Store +.env + + +tests \ No newline at end of file diff --git a/.python-version b/.python-version new file mode 100644 index 000000000..8531a3b7e --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12.2 diff --git a/Dockerfile.api b/Dockerfile.api new file mode 100644 index 000000000..fc63626dc --- /dev/null +++ b/Dockerfile.api @@ -0,0 +1,23 @@ +FROM python:3.12 + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install Poetry +RUN pip install --no-cache-dir poetry + +# Copy only requirements to cache them in docker layer +COPY pyproject.toml poetry.lock* ./ + +# Project initialization: +RUN poetry config virtualenvs.create false \ + && poetry install --no-interaction --no-ansi + +# Copy project +COPY src . + +CMD ["python", "-m", "ell.api"] \ No newline at end of file diff --git a/Dockerfile.studio b/Dockerfile.studio new file mode 100644 index 000000000..760a6c84d --- /dev/null +++ b/Dockerfile.studio @@ -0,0 +1,44 @@ +# Start with a Node.js base image for building the React app +FROM node:20 AS client-builder + +WORKDIR /app/ell-studio + +# Copy package.json and package-lock.json (if available) +COPY ell-studio/package.json ell-studio/package-lock.json* ./ + +# Install dependencies +RUN npm ci + +# Copy the rest of the client code +COPY ell-studio . 
+ +# Build the React app +RUN npm run build + +# Now, start with the Python base image +FROM python:3.12 + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install Poetry +RUN pip install --no-cache-dir poetry + +# Copy only requirements to cache them in docker layer +COPY pyproject.toml poetry.lock* ./ + +# Project initialization: +RUN poetry config virtualenvs.create false \ + && poetry install --no-interaction --no-ansi + +# Copy the Python project +COPY src . + +# Copy the built React app from the client-builder stage +COPY --from=client-builder /app/ell-studio/build /app/ell/studio/static + +CMD ["python", "-m", "ell.studio"] \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 000000000..bcd2ee930 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,74 @@ +services: + api: + build: + context: . + dockerfile: Dockerfile.api + tags: + - ell-api + + ports: + - "8081:8081" + environment: + - HOST=0.0.0.0 + - PORT=8081 + - ELL_PG_CONNECTION_STRING=postgresql://ell_user:ell_password@postgres:5432/ell_db + - ELL_MQTT_CONNECTION_STRING=mqtt://mqtt:1883 + - LOG_LEVEL=DEBUG + depends_on: + - postgres + - mqtt + + studio: + build: + context: . 
+ dockerfile: Dockerfile.studio + tags: + - ell-studio + ports: + - "8080:8080" + environment: + - HOST=0.0.0.0 + - PORT=8080 # currently doesn't take effect -- cli defaults it + - ELL_PG_CONNECTION_STRING=postgresql://ell_user:ell_password@postgres:5432/ell_db + - ELL_MQTT_CONNECTION_STRING=mqtt://mqtt:1883 + depends_on: + - postgres + - mqtt + develop: + watch: + - action: sync+restart + path: ./src/ell/studio + target: /app/ell/studio + + mqtt: + image: eclipse-mosquitto:latest + ports: + - "1883:1883" + command: mosquitto -c /mosquitto/config/mosquitto.conf + volumes: + - mosquitto_config:/mosquitto/config + depends_on: + - mqtt-config + + mqtt-config: + image: busybox + volumes: + - mosquitto_config:/mosquitto/config + command: > + sh -c "echo 'listener 1883' > /mosquitto/config/mosquitto.conf && + echo 'allow_anonymous true' >> /mosquitto/config/mosquitto.conf" + + postgres: + image: postgres:16 + environment: + - POSTGRES_USER=ell_user + - POSTGRES_PASSWORD=ell_password + - POSTGRES_DB=ell_db + volumes: + - postgres_data:/var/lib/postgresql/data + ports: + - "5432:5432" + +volumes: + postgres_data: + mosquitto_config: \ No newline at end of file diff --git a/examples/calculator.py b/examples/calculator.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/calculator_structured.py b/examples/calculator_structured.py index 2ee7c9878..fc1623198 100644 --- a/examples/calculator_structured.py +++ b/examples/calculator_structured.py @@ -3,7 +3,6 @@ from typing import Any, Literal, Type, Union import pydantic -from ell.stores.sql import SQLiteStore ell.config.verbose = True @@ -69,5 +68,10 @@ def calc_structured(task: str) -> float: if __name__ == "__main__": - ell.set_store('./logdir', autocommit=True) + # Local + ell.init(storage_dir='./logdir', autocommit=True) + + # API server + # ell.init(base_url="http://localhost:8081", autocommit=True) + print(calc_structured("What is two plus two?")) diff --git a/poetry.lock b/poetry.lock index 
32f7ff3a7..5dcc3cefb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,19 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. + +[[package]] +name = "aiomqtt" +version = "2.3.0" +description = "The idiomatic asyncio MQTT client, wrapped around paho-mqtt" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "aiomqtt-2.3.0-py3-none-any.whl", hash = "sha256:127926717bd6b012d1630f9087f24552eb9c4af58205bc2964f09d6e304f7e63"}, + {file = "aiomqtt-2.3.0.tar.gz", hash = "sha256:312feebe20bc76dc7c20916663011f3bd37aa6f42f9f687a19a1c58308d80d47"}, +] + +[package.dependencies] +paho-mqtt = ">=2.1.0,<3.0.0" +typing-extensions = {version = ">=4.4.0,<5.0.0", markers = "python_version < \"3.10\""} [[package]] name = "alabaster" @@ -369,6 +384,20 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "faker" +version = "28.4.1" +description = "Faker is a Python package that generates fake data for you." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "Faker-28.4.1-py3-none-any.whl", hash = "sha256:e59c01d1e8b8e20a83255ab8232c143cb2af3b4f5ab6a3f5ce495f385ad8ab4c"}, + {file = "faker-28.4.1.tar.gz", hash = "sha256:4294d169255a045990720d6f3fa4134b764a4cdf46ef0d3c7553d2506f1adaa1"}, +] + +[package.dependencies] +python-dateutil = ">=2.4" + [[package]] name = "fastapi" version = "0.111.1" @@ -940,6 +969,20 @@ files = [ {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] +[[package]] +name = "paho-mqtt" +version = "2.1.0" +description = "MQTT version 5.0/3.1.1 client class" +optional = false +python-versions = ">=3.7" +files = [ + {file = "paho_mqtt-2.1.0-py3-none-any.whl", hash = "sha256:6db9ba9b34ed5bc6b6e3812718c7e06e2fd7444540df2455d2c51bd58808feee"}, + {file = "paho_mqtt-2.1.0.tar.gz", hash = "sha256:12d6e7511d4137555a3f6ea167ae846af2c7357b10bc6fa4f7c3968fc1723834"}, +] + +[package.extras] +proxy = ["pysocks"] + [[package]] name = "pathspec" version = "0.12.1" @@ -1079,6 +1122,52 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "polyfactory" +version = "2.16.2" +description = "Mock data generation factories" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "polyfactory-2.16.2-py3-none-any.whl", hash = "sha256:e5eaf97358fee07d0d8de86a93e81dc56e3be1e1514d145fea6c5f486cda6ea1"}, + {file = "polyfactory-2.16.2.tar.gz", hash = "sha256:6d0d90deb85e5bb1733ea8744c2d44eea2b31656e11b4fa73832d2e2ab5422da"}, +] + +[package.dependencies] +faker = "*" +typing-extensions = ">=4.6.0" + +[package.extras] +attrs = ["attrs (>=22.2.0)"] +beanie = ["beanie", "pydantic[email]"] +full = ["attrs", "beanie", "msgspec", "odmantic", "pydantic", "sqlalchemy"] +msgspec = ["msgspec"] +odmantic = ["odmantic (<1.0.0)", "pydantic[email]"] +pydantic = ["pydantic[email]"] +sqlalchemy = ["sqlalchemy (>=1.4.29)"] + +[[package]] +name = "psycopg2" 
+version = "2.9.9" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "psycopg2-2.9.9-cp310-cp310-win32.whl", hash = "sha256:38a8dcc6856f569068b47de286b472b7c473ac7977243593a288ebce0dc89516"}, + {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, + {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, + {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, + {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, + {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, + {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, + {file = "psycopg2-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:bac58c024c9922c23550af2a581998624d6e02350f4ae9c5f0bc642c633a2d5e"}, + {file = "psycopg2-2.9.9-cp39-cp39-win32.whl", hash = "sha256:c92811b2d4c9b6ea0285942b2e7cac98a59e166d59c588fe5cfe1eda58e72d59"}, + {file = "psycopg2-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:de80739447af31525feddeb8effd640782cf5998e1a4e9192ebdf829717e3913"}, + {file = "psycopg2-2.9.9.tar.gz", hash = "sha256:d1454bde93fb1e224166811694d600e746430c006fbb031ea06ecc2ea41bf156"}, +] + [[package]] name = "pydantic" version = "2.8.2" @@ -1238,6 +1327,37 @@ tomli = {version = ">=1", markers = "python_version < 
\"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-mock" +version = "3.14.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, +] + +[package.dependencies] +pytest = ">=6.2.5" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + [[package]] name = "python-dotenv" version = "1.0.1" @@ -1376,6 +1496,17 @@ files = [ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, ] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -2016,5 +2147,5 @@ type = ["pytest-mypy"] [metadata] 
lock-version = "2.0" -python-versions = ">=3.9" -content-hash = "775e2111d96f8a406d1296ae2461ab6f4d51c6cc25410a58d51257b2d2a15500" +python-versions = ">=3.9,<4" +content-hash = "cbf93a0b1bbe4bf381f4dd6a1ac2f70621eabaf102731efcbd43215dbfa5765c" diff --git a/pyproject.toml b/pyproject.toml index 1c60d869e..4e86d76d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ include = [ ] [tool.poetry.dependencies] -python = ">=3.9" +python = ">=3.9,<4" fastapi = "^0.111.1" numpy = "^2.0.1" dill = "^0.3.8" @@ -41,8 +41,12 @@ typing-extensions = "^4.12.2" black = "^24.8.0" json-fix = "^1.0.0" pillow = "^10.4.0" +aiomqtt = "^2.3.0" +psycopg2 = "^2.9.9" [tool.poetry.group.dev.dependencies] pytest = "^8.3.2" +polyfactory = "^2.16.2" +pytest-mock = "^3.14.0" sphinx = "<8.0.0" sphinx-rtd-theme = "^2.0.0" diff --git a/src/ell/__version__.py b/src/ell/__version__.py index ffc218495..5fb3a6916 100644 --- a/src/ell/__version__.py +++ b/src/ell/__version__.py @@ -1,6 +1,6 @@ -try: - from importlib.metadata import version -except ImportError: - from importlib_metadata import version +from importlib.metadata import version, PackageNotFoundError -__version__ = version("ell-ai") +try: + __version__ = version("ell-ai") +except PackageNotFoundError: + __version__ = "unknown" \ No newline at end of file diff --git a/src/ell/api/__init__.py b/src/ell/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/ell/api/__main__.py b/src/ell/api/__main__.py new file mode 100644 index 000000000..4d0065ee9 --- /dev/null +++ b/src/ell/api/__main__.py @@ -0,0 +1,55 @@ +import asyncio +import os +from typing import cast +import uvicorn +from argparse import ArgumentParser +from ell.api.config import Config +from ell.api.server import create_app +from ell.api.logger import setup_logging, LogLevel + + + +def main(): + log_level = cast(LogLevel, os.environ.get("LOG_LEVEL", "INFO")) + setup_logging(level=log_level) + + parser = ArgumentParser(description="ELL API 
Server") + parser.add_argument("--storage-dir", default=None, + help="Storage directory (default: None)") + parser.add_argument("--pg-connection-string", default=None, + help="PostgreSQL connection string (default: None)") + parser.add_argument("--mqtt-connection-string", default=None, + help="MQTT connection string (default: None)") + parser.add_argument("--host", default=None, + help="Host to run the server on") + parser.add_argument("--port", type=int, default=None, + help="Port to run the server on") + parser.add_argument("--dev", action="store_true", + help="Run in development mode") + args = parser.parse_args() + + config = Config( + storage_dir=args.storage_dir, + pg_connection_string=args.pg_connection_string, + mqtt_connection_string=args.mqtt_connection_string, + ) + + app = create_app(config) + + loop = asyncio.new_event_loop() + + config = uvicorn.Config( + app=app, + host=args.host if args.host else os.environ.get("HOST", "0.0.0.0"), + port=args.port if args.port else int(os.environ.get("PORT", 8081)), + loop=loop # type: ignore + ) + server = uvicorn.Server(config) + + loop.create_task(server.serve()) + + loop.run_forever() + + +if __name__ == "__main__": + main() diff --git a/src/ell/api/client.py b/src/ell/api/client.py new file mode 100644 index 000000000..973cd6e75 --- /dev/null +++ b/src/ell/api/client.py @@ -0,0 +1,141 @@ +import httpx +from typing import Any, Dict, Optional, Protocol, List +from ell.api.types import LMP, GetLMPResponse, WriteLMPInput, WriteInvocationInput +from ell.stores.sql import SQLiteStore +from ell.sqlmodels import SerializedLMP +import logging +from httpx import HTTPStatusError + + +class EllClient(Protocol): + async def get_lmp(self, lmp_id: str) -> GetLMPResponse: + ... + + async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None: + ... + + async def write_invocation(self, input: WriteInvocationInput) -> None: + ... 
+ + async def store_blob(self, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: + ... + + async def retrieve_blob(self, blob_id: str) -> bytes: + ... + + async def close(self): + ... + + async def get_lmp_versions(self, fqn: str) -> List[LMP]: + ... + + +class EllAPIClient(EllClient): + def __init__(self, base_url: str): + self.base_url = base_url + self.client = httpx.AsyncClient(base_url=base_url) + + async def get_lmp(self, lmp_id: str) -> GetLMPResponse: + response = await self.client.get(f"/lmp/{lmp_id}") + response.raise_for_status() + data = response.json() + if data is None: + return None + return LMP(**data) + + async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None: + try: + response = await self.client.post("/lmp", json={ + "lmp": lmp.model_dump(mode="json"), + "uses": uses + }) + response.raise_for_status() + except HTTPStatusError as e: + if e.response.status_code == 422: + error_detail = e.response.json().get("detail", "No detailed error message provided") + logging.error(f"Unprocessable Entity (422) Error: {error_detail}") + raise ValueError(f"Invalid input: {error_detail}") from e + raise + + async def write_invocation(self, input: WriteInvocationInput) -> None: + response = await self.client.post( + "/invocation", + json=input.model_dump(mode="json") + ) + response.raise_for_status() + return None + + async def store_blob(self, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: + response = await self.client.post("/blob", data={ + "blob": blob, + "metadata": metadata + }) + response.raise_for_status() + return response.json()["blob_id"] + + async def retrieve_blob(self, blob_id: str) -> bytes: + response = await self.client.get(f"/blob/{blob_id}") + response.raise_for_status() + return response.content + + async def close(self): + await self.client.aclose() + + async def __aenter__(self): + return self + + async def __aexit__(self): + await self.close() + + async def get_lmp_versions(self, fqn: str) -> 
List[LMP]: + response = await self.client.get("/lmp/versions", params={"fqn": fqn}) + response.raise_for_status() + data = response.json() + return [LMP(**lmp_data) for lmp_data in data] + + +class EllSqliteClient(EllClient): + def __init__(self, storage_dir: str): + self.store = SQLiteStore(storage_dir) + + async def get_lmp(self, lmp_id: str): + lmp = self.store.get_lmp(lmp_id) + if lmp: + return LMP(**lmp.model_dump()) + return None + + async def get_lmp_versions(self, fqn: str) -> List[LMP]: + slmps = self.store.get_versions_by_fqn(fqn) + return [LMP(**slmp.model_dump()) for slmp in slmps] + + async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None: + serialized_lmp = SerializedLMP(**lmp.model_dump()) + self.store.write_lmp(serialized_lmp, uses) + + async def write_invocation(self, input: WriteInvocationInput) -> None: + invocation, consumes = input.to_serialized_invocation_input() + self.store.write_invocation( + invocation, + set(consumes) + ) + return None + + async def store_blob(self, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: + if self.store.blob_store is None: + raise ValueError("Blob store is not enabled") + return self.store.blob_store.store_blob(blob, metadata) + + async def retrieve_blob(self, blob_id: str) -> bytes: + if self.store.blob_store is None: + raise ValueError("Blob store is not enabled") + return self.store.blob_store.retrieve_blob(blob_id) + + async def close(self): + # SQLiteStore doesn't have a close method, so this is a no-op + pass + + async def __aenter__(self): + return self + + async def __aexit__(self): + await self.close() diff --git a/src/ell/api/config.py b/src/ell/api/config.py new file mode 100644 index 000000000..37086c14d --- /dev/null +++ b/src/ell/api/config.py @@ -0,0 +1,47 @@ +from functools import lru_cache +import json +import os +from typing import Any, Optional +from pydantic import BaseModel + +import logging + +logger = logging.getLogger(__name__) + + +# todo. 
maybe we default storage dir and other things in the future to a well-known location +# like ~/.ell or something +@lru_cache(maxsize=1) +def ell_home() -> str: + return os.path.join(os.path.expanduser("~"), ".ell") + + +class Config(BaseModel): + storage_dir: Optional[str] = None + pg_connection_string: Optional[str] = None + mqtt_connection_string: Optional[str] = None + log_level: int = logging.INFO + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + def model_post_init(self, __context: Any): + # Storage + self.pg_connection_string = self.pg_connection_string or os.getenv( + "ELL_PG_CONNECTION_STRING") + self.storage_dir = self.storage_dir or os.getenv("ELL_STORAGE_DIR") + + # Enforce that we use either sqlite or postgres, but not both + if self.pg_connection_string is not None and self.storage_dir is not None: + raise ValueError("Cannot use both sqlite and postgres") + + # For now, fall back to sqlite if no PostgreSQL connection string is provided + if self.pg_connection_string is None and self.storage_dir is None: + # This intends to honor the default we had set in the CLI + # todo. better default? 
+ self.storage_dir = os.getcwd() + + # Pubsub + self.mqtt_connection_string = self.mqtt_connection_string or os.getenv("ELL_MQTT_CONNECTION_STRING") + + logger.info(f"Resolved config: {json.dumps(self.model_dump(), indent=2)}") + diff --git a/src/ell/api/logger.py b/src/ell/api/logger.py new file mode 100644 index 000000000..4dac81a29 --- /dev/null +++ b/src/ell/api/logger.py @@ -0,0 +1,46 @@ +import logging +from typing import Literal, Union +from colorama import Fore, Style, init + + +LogLevel = Union[Literal['CRITICAL'], Literal['FATAL'], Literal['ERROR'], Literal['WARN'], + Literal['WARNING'], Literal['INFO'], Literal['DEBUG'], Literal['NOTSET']] + + +initialized = False +def setup_logging(level: LogLevel = "INFO"): + global initialized + if initialized: + return + # Initialize colorama for cross-platform colored output + init(autoreset=True) + + # Create a custom formatter + class ColoredFormatter(logging.Formatter): + FORMATS = { + logging.DEBUG: Fore.CYAN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.INFO: Fore.GREEN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.WARNING: Fore.YELLOW + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.ERROR: Fore.RED + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.CRITICAL: Fore.RED + Style.BRIGHT + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL + } + + def format(self, record: logging.LogRecord) -> str: + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S") + return formatter.format(record) + + # Create and configure the logger + logger = logging.getLogger("ell") + logger.setLevel(logging.getLevelName(level.upper())) + + # Create console handler and set formatter + console_handler = logging.StreamHandler() + console_handler.setFormatter(ColoredFormatter()) + + # Add the handler to the logger + 
logger.addHandler(console_handler) + + initialized = True + + return logger \ No newline at end of file diff --git a/src/ell/api/publisher.py b/src/ell/api/publisher.py new file mode 100644 index 000000000..cace87e4e --- /dev/null +++ b/src/ell/api/publisher.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod + +import aiomqtt + + +class Publisher(ABC): + @abstractmethod + async def publish(self, topic: str, message: str) -> None: + pass + + +class MqttPub(Publisher): + def __init__(self, conn: aiomqtt.Client): + self.mqtt_client = conn + + async def publish(self, topic: str, message: str) -> None: + await self.mqtt_client.publish(topic, message) + + +class NoopPublisher(Publisher): + async def publish(self, topic: str, message: str) -> None: + pass \ No newline at end of file diff --git a/src/ell/api/server.py b/src/ell/api/server.py new file mode 100644 index 000000000..c40b8c49b --- /dev/null +++ b/src/ell/api/server.py @@ -0,0 +1,157 @@ +import asyncio +from contextlib import asynccontextmanager +import json +import logging +from typing import Any, Dict, List, Optional + +import aiomqtt +from fastapi import Depends, FastAPI, HTTPException +from sqlmodel import Session +from ell.api.config import Config +from ell.api.publisher import MqttPub, NoopPublisher, Publisher +from ell.api.types import GetLMPResponse, LMPInvokedEvent, WriteInvocationInput, WriteLMPInput, LMP +from ell.store import Store +from ell.stores.sql import PostgresStore, SQLStore, SQLiteStore + + +logger = logging.getLogger(__name__) + + +publisher: Optional[Publisher] = None + + +async def get_publisher(): + yield publisher + +serializer: Optional[SQLStore] = None + + +def init_serializer(config: Config) -> SQLStore: + global serializer + if serializer is not None: + return serializer + elif config.pg_connection_string: + return PostgresStore(config.pg_connection_string) + elif config.storage_dir: + return SQLiteStore(config.storage_dir) + else: + raise ValueError("No storage configuration 
found") + + +def get_serializer(): + if serializer is None: + raise ValueError("Serializer not initialized") + return serializer + + +def get_session(): + if serializer is None: + raise ValueError("Serializer not initialized") + with Session(serializer.engine) as session: + yield session + + +def create_app(config: Config): + # setup_logging(config.log_level) + + @asynccontextmanager + async def lifespan(app: FastAPI): + global serializer + global publisher + + logger.info("Starting lifespan") + + serializer = init_serializer(config) + + if config.mqtt_connection_string is not None: + + host, port = config.mqtt_connection_string.split("://")[1].split(":") + + logger.info(f"Connecting to MQTT broker at {host}:{port}") + try: + async with aiomqtt.Client(host, int(port) if port else 1883) as mqtt: + logger.info("Connected to MQTT") + publisher = MqttPub(mqtt) + yield # Allow the app to run + except aiomqtt.MqttError as e: + logger.error(f"Failed to connect to MQTT", exc_info=e) + publisher = None + else: + publisher = NoopPublisher() + yield # allow the app to run + + app = FastAPI( + title="ELL API", + description="API server for ELL", + version="0.1.0", + lifespan=lifespan + ) + + @app.get("/lmp/versions", response_model=List[LMP]) + async def get_lmp_versions( + fqn: str, + serializer: Store = Depends(get_serializer)): + slmp = serializer.get_versions_by_fqn(fqn) + return [LMP.from_serialized_lmp(lmp) for lmp in slmp] + + + @app.get("/lmp/{lmp_id}", response_model=GetLMPResponse) + async def get_lmp(lmp_id: str, + serializer: Store = Depends(get_serializer), + session: Session = Depends(get_session)): + lmp = serializer.get_lmp(lmp_id, session=session) + if lmp is None: + raise HTTPException(status_code=404, detail="LMP not found") + + return LMP.from_serialized_lmp(lmp) + + @app.post("/lmp") + async def write_lmp( + lmp: WriteLMPInput, + uses: List[str], # SerializedLMPUses, + publisher: Publisher = Depends(get_publisher), + serializer: Store = 
Depends(get_serializer) + ): + serializer.write_lmp(lmp.to_serialized_lmp(), uses) + + loop = asyncio.get_event_loop() + loop.create_task( + publisher.publish( + f"lmp/{lmp.lmp_id}/created", + json.dumps({ + "lmp": lmp.model_dump(), + "uses": uses + }, default=str) + ) + ) + + @app.post("/invocation", response_model=WriteInvocationInput) + async def write_invocation( + input: WriteInvocationInput, + publisher: Publisher = Depends(get_publisher), + serializer: Store = Depends(get_serializer) + ): + logger.info(f"Writing invocation {input.invocation.lmp_id}") + invocation, consumes = input.to_serialized_invocation_input() + # TODO: return anything this might create like invocation id + _invo = serializer.write_invocation( + invocation, + consumes # type: ignore + ) + + + loop = asyncio.get_event_loop() + loop.create_task( + publisher.publish( + f"lmp/{input.invocation.lmp_id}/invoked", + LMPInvokedEvent( + lmp_id=input.invocation.lmp_id, + # invocation_id=invo.id, + consumes=consumes + ).model_dump_json() + ) + ) + return input + + + return app diff --git a/src/ell/api/types.py b/src/ell/api/types.py new file mode 100644 index 000000000..8d992635f --- /dev/null +++ b/src/ell/api/types.py @@ -0,0 +1,171 @@ +from functools import cached_property +from typing import Any, Dict, List, Optional, Tuple, Union, cast +from datetime import datetime, timezone +import uuid + +from openai import BaseModel +from pydantic import AwareDatetime, Field + +import ell.sqlmodels +from ell.types.message import Message +from ell.types.lmp import LMPType + + +def utc_now() -> datetime: + """ + Returns the current UTC timestamp. + Serializes to ISO-8601. + """ + return datetime.now(tz=timezone.utc) + + +class WriteLMPInput(BaseModel): + """ + Arguments to write a LMP. 
+ """ + lmp_id: str + name: str + source: str + dependencies: str + lmp_type: LMPType + api_params: Optional[Dict[str, Any]] = None + initial_free_vars: Optional[Dict[str, Any]] = None + initial_global_vars: Optional[Dict[str, Any]] = None + + # this is omitted so as to not confuse whether the number should be incremented (should always happen at the db level) + # num_invocations: Optional[int] = None + commit_message: Optional[str] = None + version_number: Optional[int] = None + created_at: Optional[AwareDatetime] = Field(default_factory=utc_now) + + def to_serialized_lmp(self): + return ell.sqlmodels.SerializedLMP( + lmp_id=self.lmp_id, + lmp_type=self.lmp_type, + name=self.name, + source=self.source, + dependencies=self.dependencies, + api_params=self.api_params, + version_number=self.version_number, + initial_global_vars=self.initial_global_vars, + initial_free_vars=self.initial_free_vars, + commit_message=self.commit_message, + created_at=cast(datetime, self.created_at) + ) + + +class LMP(BaseModel): + lmp_id: str + name: str + source: str + dependencies: str + lmp_type: LMPType + api_params: Optional[Dict[str, Any]] + initial_free_vars: Optional[Dict[str, Any]] + initial_global_vars: Optional[Dict[str, Any]] + created_at: AwareDatetime + version_number: int + commit_message: Optional[str] + num_invocations: int + + @staticmethod + def from_serialized_lmp(serialized: ell.sqlmodels.SerializedLMP): + return LMP( + lmp_id=cast(str, serialized.lmp_id), + name=serialized.name, + source=serialized.source, + dependencies=serialized.dependencies, + lmp_type=serialized.lmp_type, + api_params=serialized.api_params, + initial_free_vars=serialized.initial_free_vars, + initial_global_vars=serialized.initial_global_vars, + created_at=serialized.created_at, + version_number=cast(int, serialized.version_number), + commit_message=serialized.commit_message, + num_invocations=cast(int, serialized.num_invocations), + ) + +# class GetLMPResponse(BaseModel): +# lmp: LMP +# uses: 
List[str] + + +GetLMPResponse = Optional[LMP] +# class LMPCreatedEvent(BaseModel): +# lmp: LMP +# uses: List[str] + +InvocationResults = Union[List[Message], Any] + + +class InvocationContents(BaseModel): + invocation_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + params: Optional[Dict[str, Any]] = None + results: Optional[InvocationResults] = None + invocation_api_params: Optional[Dict[str, Any]] = None + global_vars: Optional[Dict[str, Any]] = None + free_vars: Optional[Dict[str, Any]] = None + is_external: bool = Field(default=False) + + def to_serialized_invocation_contents(self): + return ell.sqlmodels.InvocationContents( + **self.model_dump() + ) + + @cached_property + def total_size_bytes(self) -> int: + """ + Returns the total uncompressed size of the invocation contents as JSON in bytes. + """ + import json + json_fields = [ + self.params, + self.results, + self.invocation_api_params, + self.global_vars, + self.free_vars + ] + return sum(len(json.dumps(field, default=(lambda x: x.model_dump_json() if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None) + + @cached_property + def should_externalize(self) -> bool: + return self.total_size_bytes > 102400 # Precisely 100kb in bytes + + +class Invocation(BaseModel): + """ + An invocation of an LMP. + """ + id: Optional[str] = None + lmp_id: str + latency_ms: int + prompt_tokens: Optional[int] = None + completion_tokens: Optional[int] = None + state_cache_key: Optional[str] = None + created_at: AwareDatetime = Field(default_factory=utc_now) + used_by_id: Optional[str] = None + contents: InvocationContents + + def to_serialized_invocation(self): + return ell.sqlmodels.Invocation( + **self.model_dump(exclude={"contents"}), + contents=self.contents.to_serialized_invocation_contents() + ) + + +class WriteInvocationInput(BaseModel): + """ + Arguments to write an invocation. 
+ """ + invocation: Invocation + consumes: List[str] + + def to_serialized_invocation_input(self) -> Tuple[ell.sqlmodels.Invocation, List[str]]: + sinvo = self.invocation.to_serialized_invocation() + return sinvo, list(set(self.consumes)) + + +class LMPInvokedEvent(BaseModel): + lmp_id: str + # invocation_id: str + consumes: List[str] diff --git a/src/ell/configurator.py b/src/ell/configurator.py index 79e6afb5f..5e9ee8c34 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -5,22 +5,29 @@ from contextlib import contextmanager import threading from pydantic import BaseModel, ConfigDict, Field -from ell.store import Store +from ell.api.client import EllAPIClient, EllClient, EllSqliteClient from ell.provider import Provider _config_logger = logging.getLogger(__name__) + + class Config(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) registry: Dict[str, openai.Client] = Field(default_factory=dict, description="A dictionary mapping model names to OpenAI clients.") - verbose: bool = Field(default=False, description="If True, enables verbose logging.") + verbose: bool = Field(default=False, description="If True, enables wrapped logging for better readability.") wrapped_logging: bool = Field(default=True, description="If True, enables wrapped logging for better readability.") override_wrapped_logging_width: Optional[int] = Field(default=None, description="If set, overrides the default width for wrapped logging.") - store: Optional[Store] = Field(default=None, description="An optional Store instance for persistence.") autocommit: bool = Field(default=False, description="If True, enables automatic committing of changes to the store.") lazy_versioning: bool = Field(default=True, description="If True, enables lazy versioning for improved performance.") default_lm_params: Dict[str, Any] = Field(default_factory=dict, description="Default parameters for language models.") + default_system_prompt: str = Field(default="You are a helpful AI 
assistant.", description="The default system prompt used for AI interactions.") default_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.") providers: Dict[Type, Type[Provider]] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.") + default_openai_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.") + _client: Optional[EllClient] = None + store_blobs: bool = Field(default=True, description="If True, enables storing blobs in the store.") + + def __init__(self, **data): super().__init__(**data) @@ -39,15 +46,6 @@ def register_model(self, model_name: str, client: Any) -> None: with self._lock: self.registry[model_name] = client - @property - def has_store(self) -> bool: - """ - Check if a store is set. - - :return: True if a store is set, False otherwise. - :rtype: bool - """ - return self.store is not None @contextmanager def model_registry_override(self, overrides: Dict[str, Any]): @@ -59,12 +57,12 @@ def model_registry_override(self, overrides: Dict[str, Any]): """ if not hasattr(self._local, 'stack'): self._local.stack = [] - + with self._lock: current_registry = self._local.stack[-1] if self._local.stack else self.registry new_registry = current_registry.copy() new_registry.update(overrides) - + self._local.stack.append(new_registry) try: yield @@ -80,17 +78,20 @@ def get_client_for(self, model_name: str) -> Tuple[Optional[Any], bool]: :return: The OpenAI client for the specified model, or None if not found. 
:rtype: Optional[openai.Client] """ - current_registry = self._local.stack[-1] if hasattr(self._local, 'stack') and self._local.stack else self.registry + current_registry = self._local.stack[-1] if hasattr( + self._local, 'stack') and self._local.stack else self.registry client = current_registry.get(model_name) fallback = False if model_name not in current_registry.keys(): - warning_message = f"Warning: A default provider for model '{model_name}' could not be found. Falling back to default OpenAI client from environment variables." + warning_message = (f"Warning: A default provider for model '{model_name}' " + "could not be found. Falling back to default OpenAI " + "client from environment variables.") if self.verbose: from colorama import Fore, Style _config_logger.warning(f"{Fore.LIGHTYELLOW_EX}{warning_message}{Style.RESET_ALL}") else: _config_logger.debug(warning_message) - client = self.default_client + client = self.default_openai_client fallback = True return client, fallback @@ -102,32 +103,7 @@ def reset(self) -> None: self.__init__() if hasattr(self._local, 'stack'): del self._local.stack - - def set_store(self, store: Union[Store, str], autocommit: bool = True) -> None: - """ - Set the store for the configuration. - - :param store: The store to set. Can be a Store instance or a string path for SQLiteStore. - :type store: Union[Store, str] - :param autocommit: Whether to enable autocommit for the store. - :type autocommit: bool - """ - if isinstance(store, str): - from ell.stores.sql import SQLiteStore - self.store = SQLiteStore(store) - else: - self.store = store - self.autocommit = autocommit or self.autocommit - - def get_store(self) -> Store: - """ - Get the current store. - :return: The current store. - :rtype: Store - """ - return self.store - def set_default_lm_params(self, **params: Dict[str, Any]) -> None: """ Set default parameters for language models. 
@@ -136,8 +112,6 @@ def set_default_lm_params(self, **params: Dict[str, Any]) -> None: :type params: Dict[str, Any] """ self.default_lm_params = params - - def set_default_client(self, client: openai.Client) -> None: """ @@ -146,7 +120,10 @@ def set_default_client(self, client: openai.Client) -> None: :param client: The default OpenAI client to set. :type client: openai.Client """ - self.default_client = client + self.default_openai_client = client + + def set_ell_client(self, client: EllClient) -> None: + self._client = client def register_provider(self, provider_class: Type[Provider]) -> None: """ @@ -172,8 +149,12 @@ def get_provider_for(self, client: Any) -> Optional[Type[Provider]]: # Singleton instance config = Config() + def init( - store: Optional[Union[Store, str]] = None, + client: Optional[EllClient] = None, + base_url: Optional[str] = None, + storage_dir: Optional[str] = None, + store_blobs: bool = True, verbose: bool = False, autocommit: bool = True, lazy_versioning: bool = True, @@ -185,8 +166,8 @@ def init( :param verbose: Set verbosity of ELL operations. :type verbose: bool - :param store: Set the store for ELL. Can be a Store instance or a string path for SQLiteStore. - :type store: Union[Store, str], optional + :param storage_dir: Set the storage directory. + :type storage_dir: str :param autocommit: Set autocommit for the store operations. :type autocommit: bool :param lazy_versioning: Enable or disable lazy versioning. 
@@ -197,28 +178,24 @@ def init( :type default_openai_client: openai.Client, optional """ config.verbose = verbose + config.store_blobs = store_blobs config.lazy_versioning = lazy_versioning + config.autocommit = autocommit - if store is not None: - config.set_store(store, autocommit) + if client is not None: + config.set_ell_client(client) + elif base_url is not None: + config.set_ell_client(EllAPIClient(base_url)) + elif storage_dir is not None: + config.set_ell_client(EllSqliteClient(storage_dir)) if default_lm_params is not None: config.set_default_lm_params(**default_lm_params) - - if default_openai_client is not None: config.set_default_client(default_openai_client) # Existing helper functions -@wraps(config.get_store) -def get_store() -> Store: - return config.get_store() - -@wraps(config.set_store) -def set_store(*args, **kwargs) -> None: - return config.set_store(*args, **kwargs) - @wraps(config.set_default_lm_params) def set_default_lm_params(*args, **kwargs) -> None: return config.set_default_lm_params(*args, **kwargs) diff --git a/src/ell/decorators/track.py b/src/ell/decorators/track.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/ell/lmp/_track.py b/src/ell/lmp/_track.py index fb6294e80..3659c460d 100644 --- a/src/ell/lmp/_track.py +++ b/src/ell/lmp/_track.py @@ -1,20 +1,20 @@ +import asyncio import json import logging import threading -from ell.types import SerializedLMP, Invocation, InvocationTrace, InvocationContents -from ell.types.studio import LMPType, utc_now + +from ell.api.types import WriteInvocationInput, WriteLMPInput, Invocation, InvocationContents +from ell.api.types import utc_now +from ell.types.lmp import LMPType from ell.util._warnings import _autocommit_warning import ell.util.closure from ell.configurator import config -from ell.types._lstr import _lstr import inspect import secrets -import time -from datetime import datetime from functools import wraps -from typing import Any, Callable, Dict, Iterable, Optional, 
OrderedDict, Tuple +from typing import Any, Callable, Dict, Optional from ell.util.serialization import get_immutable_vars from ell.util.serialization import compute_state_cache_key @@ -22,36 +22,49 @@ logger = logging.getLogger(__name__) +ell_event_loop = None + + +def get_ell_event_loop(): + global ell_event_loop + if not ell_event_loop: + logger.info(f"Creating new event loop for ell, thread id: {threading.get_ident()}") + ell_event_loop = asyncio.new_event_loop() + return ell_event_loop + + # Thread-local storage for the invocation stack _invocation_stack = threading.local() + def get_current_invocation() -> Optional[str]: if not hasattr(_invocation_stack, 'stack'): _invocation_stack.stack = [] return _invocation_stack.stack[-1] if _invocation_stack.stack else None + def push_invocation(invocation_id: str): if not hasattr(_invocation_stack, 'stack'): _invocation_stack.stack = [] _invocation_stack.stack.append(invocation_id) + def pop_invocation(): if hasattr(_invocation_stack, 'stack') and _invocation_stack.stack: _invocation_stack.stack.pop() def _track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None) -> Callable: - - lmp_type = getattr(func_to_track, "__ell_type__", LMPType.OTHER) + lmp_type = getattr(func_to_track, "__ell_type__", LMPType.OTHER) # see if it exists if not hasattr(func_to_track, "_has_serialized_lmp"): func_to_track._has_serialized_lmp = False if not hasattr(func_to_track, "__ell_hash__") and not config.lazy_versioning: - ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies) - + ell.util.closure.lexically_closured_source( + func_to_track, forced_dependencies) @wraps(func_to_track) def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: @@ -59,43 +72,49 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: # Compute the invocation id and hash the inputs for serialization. 
invocation_id = "invocation-" + secrets.token_hex(16) - state_cache_key : str = None - if not config.store: + state_cache_key: str = None + if not config._client: return func_to_track(*fn_args, **fn_kwargs, _invocation_origin=invocation_id)[0] parent_invocation_id = get_current_invocation() try: push_invocation(invocation_id) - + # Convert all positional arguments to named keyword arguments sig = inspect.signature(func_to_track) # Filter out kwargs that are not in the function signature - filtered_kwargs = {k: v for k, v in fn_kwargs.items() if k in sig.parameters} - + filtered_kwargs = {k: v for k, + v in fn_kwargs.items() if k in sig.parameters} + bound_args = sig.bind(*fn_args, **filtered_kwargs) bound_args.apply_defaults() all_kwargs = dict(bound_args.arguments) # Get the list of consumed lmps and clean the invocation params for serialization. - cleaned_invocation_params, ipstr, consumes = prepare_invocation_params( all_kwargs) + cleaned_invocation_params, ipstr, consumes = prepare_invocation_params( + all_kwargs) - try_use_cache = hasattr(func_to_track.__wrapper__, "__ell_use_cache__") + try_use_cache = hasattr( + func_to_track.__wrapper__, "__ell_use_cache__") - if try_use_cache: + if try_use_cache: # Todo: add nice logging if verbose for when using a cahced invocaiton. IN a different color with thar args.. 
- if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning: - fn_closure, _ = ell.util.closure.lexically_closured_source(func_to_track) - + if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning: + fn_closure, _ = ell.util.closure.lexically_closured_source( + func_to_track) + # compute the state cachekey - state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__) - + state_cache_key = compute_state_cache_key( + ipstr, func_to_track.__ell_closure__) + cache_store = func_to_track.__wrapper__.__ell_use_cache__ - cached_invocations = cache_store.get_cached_invocations(func_to_track.__ell_hash__, state_cache_key) - - + cached_invocations = cache_store.get_cached_invocations( + func_to_track.__ell_hash__, state_cache_key) + if len(cached_invocations) > 0: # TODO THis is bad? - results = [d.deserialize() for d in cached_invocations[0].results] + results = [d.deserialize() + for d in cached_invocations[0].results] logger.info(f"Using cached result for {func_to_track.__qualname__} with state cache key: {state_cache_key}") if len(results) == 1: @@ -105,8 +124,7 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: # Todo: Unfiy this with the non-cached case. We should go through the same code pathway. else: logger.info(f"Attempted to use cache on {func_to_track.__qualname__} but it was not cached, or did not exist in the store. 
Refreshing cache...") - - + _start_time = utc_now() # XXX: thread saftey note, if I prevent yielding right here and get the global context I should be fine re: cache key problem @@ -116,26 +134,28 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: (func_to_track(*fn_args, **fn_kwargs), {}, {}) if lmp_type == LMPType.OTHER else func_to_track(*fn_args, _invocation_origin=invocation_id, **fn_kwargs, ) - ) + ) latency_ms = (utc_now() - _start_time).total_seconds() * 1000 usage = metadata.get("usage", {}) - prompt_tokens=usage.get("prompt_tokens", 0) - completion_tokens=usage.get("completion_tokens", 0) - + prompt_tokens = usage.get("prompt_tokens", 0) + completion_tokens = usage.get("completion_tokens", 0) - #XXX: cattrs add invocation origin here recursively on all pirmitive types within a message. - #XXX: This will allow all objects to be traced automatically irrespective origin rather than relying on the API to do it, it will of vourse be expensive but unify track. - #XXX: No other code will need to consider tracking after this point. + # XXX: cattrs add invocation origin here recursively on all pirmitive types within a message. + # XXX: This will allow all objects to be traced automatically irrespective origin rather than relying on the API to do it, it will of vourse be expensive but unify track. + # XXX: No other code will need to consider tracking after this point. 
if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning: - ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies) - _serialize_lmp(func_to_track) + ell.util.closure.lexically_closured_source( + func_to_track, forced_dependencies) + # _serialize_lmp(func_to_track) + _serialize_lmp_sync(func_to_track) if not state_cache_key: - state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__) + state_cache_key = compute_state_cache_key( + ipstr, func_to_track.__ell_closure__) - _write_invocation(func_to_track, invocation_id, latency_ms, prompt_tokens, completion_tokens, - state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id) + _write_invocation_sync(func_to_track, invocation_id, latency_ms, prompt_tokens, completion_tokens, + state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id) if _get_invocation_id: return result, invocation_id @@ -144,8 +164,7 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: finally: pop_invocation() - - func_to_track.__wrapper__ = tracked_func + func_to_track.__wrapper__ = tracked_func if hasattr(func_to_track, "__ell_api_params__"): tracked_func.__ell_api_params__ = func_to_track.__ell_api_params__ if hasattr(func_to_track, "__ell_params_model__"): @@ -155,11 +174,12 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: return tracked_func -def _serialize_lmp(func): + +async def _serialize_lmp(func): # Serialize deptjh first all fo the used lmps. 
for f in func.__ell_uses__: - _serialize_lmp(f) - + await _serialize_lmp(f) + if getattr(func, "_has_serialized_lmp", False): return func._has_serialized_lmp = False @@ -168,10 +188,11 @@ def _serialize_lmp(func): name = func.__qualname__ api_params = getattr(func, "__ell_api_params__", None) - lmps = config.store.get_versions_by_fqn(fqn=name) + lmps = await config._client.get_lmp_versions(fqn=name) version = 0 already_in_store = any(lmp.lmp_id == func.__ell_hash__ for lmp in lmps) - + commit = None + if not already_in_store: commit = None if lmps: @@ -182,10 +203,10 @@ def _serialize_lmp(func): if not _autocommit_warning(): from ell.util.differ import write_commit_message_for_diff commit = str(write_commit_message_for_diff( - f"{latest_lmp.dependencies}\n\n{latest_lmp.source}", + f"{latest_lmp.dependencies}\n\n{latest_lmp.source}", f"{fn_closure[1]}\n\n{fn_closure[0]}")[0]) - serialized_lmp = SerializedLMP( + input = WriteLMPInput( lmp_id=func.__ell_hash__, name=name, created_at=utc_now(), @@ -198,12 +219,19 @@ def _serialize_lmp(func): api_params=api_params if api_params else None, version_number=version, ) - config.store.write_lmp(serialized_lmp, [f.__ell_hash__ for f in func.__ell_uses__]) + uses = [f.__ell_hash__ for f in func.__ell_uses__] + await config._client.write_lmp(input, uses) func._has_serialized_lmp = True -def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion_tokens, - state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id): - + +def _serialize_lmp_sync(func): + loop = get_ell_event_loop() + return loop.run_until_complete(_serialize_lmp(func)) + + +async def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion_tokens, + state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id): + invocation_contents = InvocationContents( invocation_id=invocation_id, params=cleaned_invocation_params, @@ -213,11 
+241,12 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion free_vars=get_immutable_vars(func.__ell_closure__[3]) ) - if invocation_contents.should_externalize and config.store.has_blob_storage: + if invocation_contents.should_externalize: invocation_contents.is_external = True - - # Write to the blob store - blob_id = config.store.blob_store.store_blob( + + + # Write to the blob store + blob_id = await config._client.store_blob( json.dumps(invocation_contents.model_dump( ), default=str).encode('utf-8'), invocation_id @@ -231,13 +260,37 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion id=invocation_id, lmp_id=func.__ell_hash__, created_at=utc_now(), - latency_ms=latency_ms, + latency_ms=int(latency_ms), prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, state_cache_key=state_cache_key, used_by_id=parent_invocation_id, contents=invocation_contents ) + input = WriteInvocationInput( + invocation=invocation, + consumes=consumes + ) - config.store.write_invocation(invocation, consumes) - + await config._client.write_invocation(input) + + +def _write_invocation_sync(func, invocation_id, latency_ms, prompt_tokens, completion_tokens, + state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id): + loop = get_ell_event_loop() + return loop.run_until_complete( + + _write_invocation( + func, + invocation_id, + latency_ms, + prompt_tokens, + completion_tokens, + state_cache_key, + invocation_api_params, + cleaned_invocation_params, + consumes, + result, + parent_invocation_id + ) + ) diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index f571720cb..75eb729b2 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -3,7 +3,7 @@ from ell.types._lstr import _lstr from ell.types import Message, ContentBlock from ell.types.message import LMP, InvocableLM, LMPParams, MessageOrDict, _lstr_generic -from ell.types.studio import LMPType 
+from ell.types.lmp import LMPType from ell.util._warnings import _warnings from ell.util.api import call from ell.util.verbosity import compute_color, model_usage_logger_pre diff --git a/src/ell/lmp/tool.py b/src/ell/lmp/tool.py index bab0df999..8a8942c21 100644 --- a/src/ell/lmp/tool.py +++ b/src/ell/lmp/tool.py @@ -9,7 +9,7 @@ # from ell.util.verbosity import compute_color, tool_usage_logger_pre from ell.configurator import config from ell.types._lstr import _lstr -from ell.types.studio import LMPType +from ell.types.lmp import LMPType import inspect from ell.types.message import ContentBlock, InvocableTool, ToolResult, coerce_content_list diff --git a/src/ell/models/openai.py b/src/ell/models/openai.py index 676b25b0c..945a2ecac 100644 --- a/src/ell/models/openai.py +++ b/src/ell/models/openai.py @@ -87,4 +87,4 @@ def register(client: openai.Client): pass register(default_client) -config.default_client = default_client \ No newline at end of file +config.set_default_client(default_client) \ No newline at end of file diff --git a/src/ell/types/studio.py b/src/ell/sqlmodels.py similarity index 90% rename from src/ell/types/studio.py rename to src/ell/sqlmodels.py index 4ca0468fa..571417f6e 100644 --- a/src/ell/types/studio.py +++ b/src/ell/sqlmodels.py @@ -1,42 +1,48 @@ from datetime import datetime, timezone -import enum from functools import cached_property +from typing import Any, Dict, List, Optional, Union +from pydantic import BaseModel import sqlalchemy.types as types -from ell.types.message import Any, Any, Field, Message, Optional - -from sqlmodel import Column, Field, SQLModel -from typing import Optional -from dataclasses import dataclass -from typing import Dict, List, Literal, Union, Any, Optional - -from pydantic import BaseModel, field_validator - from datetime import datetime -from typing import Any, List, Optional -from sqlmodel import Field, SQLModel, Relationship, JSON, Column +from sqlmodel import Field, SQLModel, Relationship, JSON, Column, 
Index, func from sqlalchemy import Index, func -from typing import TypeVar, Any +from ell.types.message import Message +from ell.types.lmp import LMPType -def utc_now() -> datetime: - """ - Returns the current UTC timestamp. - Serializes to ISO-8601. - """ - return datetime.now(tz=timezone.utc) +class InvocationContentsBase(SQLModel): + invocation_id: str = Field(foreign_key="invocation.id", index=True, primary_key=True) + params: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) + results: Optional[Union[List[Message], Any]] = Field(default=None, sa_column=Column(JSON)) + invocation_api_params: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) + global_vars: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) + free_vars: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) + is_external : bool = Field(default=False) + @cached_property + def should_externalize(self) -> bool: + import json -class SerializedLMPUses(SQLModel, table=True): - """ - Represents the many-to-many relationship between SerializedLMPs. + json_fields = [ + self.params, + self.results, + self.invocation_api_params, + self.global_vars, + self.free_vars + ] - This class is used to track which LMPs use or are used by other LMPs. 
- """ + total_size = sum( + len(json.dumps(field, default=(lambda x: json.dumps(x.model_dump(), default=str) if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None + ) + # print("total_size", total_size) - lmp_user_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is being used - lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP + return total_size > 102400 # Precisely 100kb in bytes + + +class InvocationContents(InvocationContentsBase, table=True): + invocation: "Invocation" = Relationship(back_populates="contents") class UTCTimestamp(types.TypeDecorator[datetime]): @@ -51,13 +57,26 @@ def UTCTimestampField(index:bool=False, **kwargs:Any): sa_column=Column(UTCTimestamp(timezone=True), index=index, **kwargs)) -class LMPType(str, enum.Enum): - LM = "LM" - TOOL = "TOOL" - MULTIMODAL = "MULTIMODAL" - OTHER = "OTHER" +# Should be subtyped for differnet kidns of LMPS. +# XXX: Move all ofh te binary data out to a different table. +# XXX: Need a flag that says dont store images. 
+# XXX: Deprecate the args columns + # global_vars and free_vars removed from here +class InvocationBase(SQLModel): + id: Optional[str] = Field(default=None, primary_key=True) + lmp_id: str = Field(foreign_key="serializedlmp.lmp_id", index=True) + latency_ms: float + prompt_tokens: Optional[int] = Field(default=None) + completion_tokens: Optional[int] = Field(default=None) + state_cache_key: Optional[str] = Field(default=None) + created_at: datetime = UTCTimestampField(default=func.now(), nullable=False) + used_by_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True) +class InvocationTrace(SQLModel, table=True): + invocation_consumer_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True) + invocation_consuming_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True) + class SerializedLMPBase(SQLModel): lmp_id: Optional[str] = Field(default=None, primary_key=True) @@ -75,6 +94,17 @@ class SerializedLMPBase(SQLModel): version_number: Optional[int] = Field(default=None) +class SerializedLMPUses(SQLModel, table=True): + """ + Represents the many-to-many relationship between SerializedLMPs. + + This class is used to track which LMPs use or are used by other LMPs. 
+ """ + + lmp_user_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is being used + lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP + + class SerializedLMP(SerializedLMPBase, table=True): invocations: List["Invocation"] = Relationship(back_populates="lmp") used_by: Optional[List["SerializedLMP"]] = Relationship( @@ -98,55 +128,6 @@ class Config: table_name = "serializedlmp" unique_together = [("version_number", "name")] -class InvocationTrace(SQLModel, table=True): - invocation_consumer_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True) - invocation_consuming_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True) - -# Should be subtyped for differnet kidns of LMPS. -# XXX: Move all ofh te binary data out to a different table. -# XXX: Need a flag that says dont store images. 
-# XXX: Deprecate the args columns -class InvocationBase(SQLModel): - id: Optional[str] = Field(default=None, primary_key=True) - lmp_id: str = Field(foreign_key="serializedlmp.lmp_id", index=True) - latency_ms: float - prompt_tokens: Optional[int] = Field(default=None) - completion_tokens: Optional[int] = Field(default=None) - state_cache_key: Optional[str] = Field(default=None) - created_at: datetime = UTCTimestampField(default=func.now(), nullable=False) - used_by_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True) - # global_vars and free_vars removed from here - -class InvocationContentsBase(SQLModel): - invocation_id: str = Field(foreign_key="invocation.id", index=True, primary_key=True) - params: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) - results: Optional[Union[List[Message], Any]] = Field(default=None, sa_column=Column(JSON)) - invocation_api_params: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) - global_vars: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) - free_vars: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) - is_external : bool = Field(default=False) - - @cached_property - def should_externalize(self) -> bool: - import json - - json_fields = [ - self.params, - self.results, - self.invocation_api_params, - self.global_vars, - self.free_vars - ] - - total_size = sum( - len(json.dumps(field, default=(lambda x: json.dumps(x.model_dump(), default=str) if isinstance(x, BaseModel) else str(x))).encode('utf-8')) for field in json_fields if field is not None - ) - # print("total_size", total_size) - - return total_size > 102400 # Precisely 100kb in bytes - -class InvocationContents(InvocationContentsBase, table=True): - invocation: "Invocation" = Relationship(back_populates="contents") class Invocation(InvocationBase, table=True): lmp: SerializedLMP = Relationship(back_populates="invocations") @@ -175,3 +156,5 @@ class 
Invocation(InvocationBase, table=True): Index('ix_invocation_created_at_latency_ms', 'created_at', 'latency_ms'), Index('ix_invocation_created_at_tokens', 'created_at', 'prompt_tokens', 'completion_tokens'), ) + + diff --git a/src/ell/store.py b/src/ell/store.py index 841cca242..073877535 100644 --- a/src/ell/store.py +++ b/src/ell/store.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from datetime import datetime -from typing import Any, Optional, Dict, List, Set, Union -from ell.types._lstr import _lstr -from ell.types import SerializedLMP, Invocation +from sqlmodel import Session +from ell.sqlmodels import Invocation, SerializedLMP from ell.types.message import InvocableLM +from typing import Any, Optional, List, Set + + class BlobStore(ABC): @abstractmethod @@ -30,7 +31,17 @@ def has_blob_storage(self) -> bool: return self.blob_store is not None @abstractmethod - def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]: + def get_lmp(self, lmp_id: str, session: Optional[Session] = None) -> Optional[SerializedLMP]: + """ + Get an LMP by its ID. + + :param lmp_id: ID of the LMP to retrieve. + :return: SerializedLMP object containing all LMP details, or None if the LMP does not exist. + """ + pass + + @abstractmethod + def write_lmp(self, serialized_lmp: SerializedLMP, uses: List[str]) -> Optional[Any]: """ Write an LMP (Language Model Package) to the storage. @@ -46,7 +57,6 @@ def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optio Write an invocation of an LMP to the storage. :param invocation: Invocation object containing all invocation details. - :param results: List of SerializedLStr objects representing the results. :param consumes: Set of invocation IDs consumed by this invocation. :return: Optional return value. 
""" diff --git a/src/ell/stores/sql.py b/src/ell/stores/sql.py index c96397e12..5f97294ba 100644 --- a/src/ell/stores/sql.py +++ b/src/ell/stores/sql.py @@ -1,54 +1,76 @@ from datetime import datetime, timedelta import json import os -from typing import Any, Optional, Dict, List, Set, Union -from pydantic import BaseModel +from typing import Any, Optional, Dict, List, Set from sqlmodel import Session, SQLModel, create_engine, select +from ell.sqlmodels import Invocation, InvocationTrace, SerializedLMP import ell.store -import cattrs -import numpy as np from sqlalchemy.sql import text -from ell.types import InvocationTrace, SerializedLMP, Invocation, InvocationContents -from ell.types._lstr import _lstr -from sqlalchemy import or_, func, and_, extract, FromClause -from sqlalchemy.types import TypeDecorator, VARCHAR -from ell.types.studio import SerializedLMPUses, utc_now +from sqlalchemy import func, and_ from ell.util.serialization import pydantic_ltype_aware_cattr import gzip -import json +import logging +from typing import Any, Optional, Dict, List, Set +from datetime import datetime, timedelta +from sqlalchemy import Engine, func, and_ + +logger = logging.getLogger(__name__) class SQLStore(ell.store.Store): - def __init__(self, db_uri: str, blob_store: Optional[ell.store.BlobStore] = None): - self.engine = create_engine(db_uri, - json_serializer=lambda obj: json.dumps(pydantic_ltype_aware_cattr.unstructure(obj), - sort_keys=True, default=repr)) - + def __init__(self, db_uri: Optional[str] = None, blob_store: Optional[ell.store.BlobStore] = None, engine: Optional[Engine] = None): + + if engine is not None: + self.engine = engine + elif db_uri is None: + raise ValueError( + "db_uri cannot be None when engine is not provided as an argument") + else: + self.engine = create_engine( + db_uri, + json_serializer=lambda obj: json.dumps(pydantic_ltype_aware_cattr.unstructure(obj), + sort_keys=True, default=repr)) SQLModel.metadata.create_all(self.engine) - 
self.open_files: Dict[str, Dict[str, Any]] = {} super().__init__(blob_store) - def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]: + def get_lmp(self, lmp_id: str, session: Optional[Session] = None) -> Optional[SerializedLMP]: + if session is None: + with Session(self.engine) as session: + return session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first() + else: + return session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first() + + def write_lmp(self, serialized_lmp: SerializedLMP, uses: List[str]) -> Optional[Any]: + """ + Creates an LMP if it does not exist. + LMPs as entities are not unique by fqn but by lmp_id. + """ with Session(self.engine) as session: + logger.debug(f"Begin writing LMP {serialized_lmp.lmp_id}") # Bind the serialized_lmp to the session - lmp = session.exec(select(SerializedLMP).filter(SerializedLMP.lmp_id == serialized_lmp.lmp_id)).first() - + lmp = None + if serialized_lmp.lmp_id: + lmp = self.get_lmp(serialized_lmp.lmp_id, session) + if lmp: # Already added to the DB. + logger.debug(f"LMP {serialized_lmp.lmp_id} already exists in the DB. Skipping write.") return lmp else: session.add(serialized_lmp) for use_id in uses: - used_lmp = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == use_id)).first() + used_lmp = self.get_lmp(use_id, session) if used_lmp: serialized_lmp.uses.append(used_lmp) session.commit() + logger.debug(f"Wrote new LMP {serialized_lmp.lmp_id} to the DB.") return None def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]: with Session(self.engine) as session: - lmp = session.exec(select(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id)).first() + logger.debug(f"Begin writing invocation {invocation.id}") + lmp = self.get_lmp(invocation.lmp_id, session) assert lmp is not None, f"LMP with id {invocation.lmp_id} not found. 
Writing invocation erroneously" # Increment num_invocations @@ -59,19 +81,26 @@ def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Option # Add the invocation contents session.add(invocation.contents) - + # Add the invocation session.add(invocation) + logger.debug(f"Committing invocation {invocation.id}") + session.commit() + logger.debug(f"Committed invocation {invocation.id}") + # Now create traces. for consumed_id in consumes: + logger.debug(f"Creating trace from {invocation.id} to {consumed_id}") session.add(InvocationTrace( invocation_consumer_id=invocation.id, invocation_consuming_id=consumed_id )) + logger.debug(f"Committing traces for invocation {invocation.id}") session.commit() - return None + logger.debug(f"Committed traces for invocation {invocation.id}") + return invocation def get_cached_invocations(self, lmp_id :str, state_cache_key :str) -> List[Invocation]: with Session(self.engine) as session: @@ -121,7 +150,7 @@ def get_lmps(self, session: Session, skip: int = 0, limit: int = 10, subquery=No return results def get_invocations(self, session: Session, lmp_filters: Dict[str, Any], skip: int = 0, limit: int = 10, filters: Optional[Dict[str, Any]] = None, hierarchical: bool = False) -> List[Dict[str, Any]]: - + query = select(Invocation).join(SerializedLMP) # Apply LMP filters @@ -163,7 +192,7 @@ def get_traces(self, session: Session): }) return traces - + def get_invocations_aggregate(self, session: Session, lmp_filters: Dict[str, Any] = None, filters: Dict[str, Any] = None, days: int = 30) -> Dict[str, Any]: # Calculate the start date for the graph data start_date = datetime.utcnow() - timedelta(days=days) @@ -181,7 +210,7 @@ def get_invocations_aggregate(self, session: Session, lmp_filters: Dict[str, Any if filters: base_subquery = base_subquery.filter(and_(*[getattr(Invocation, k) == v for k, v in filters.items()])) - + data = session.exec(base_subquery).all() # Calculate aggregate metrics @@ -208,16 +237,37 @@ def 
get_invocations_aggregate(self, session: Session, lmp_filters: Dict[str, Any "graph_data": graph_data } + class SQLiteStore(SQLStore): def __init__(self, db_dir: str): assert not db_dir.endswith('.db'), "Create store with a directory not a db." - - os.makedirs(db_dir, exist_ok=True) - self.db_dir = db_dir - db_path = os.path.join(db_dir, 'ell.db') - blob_store = SQLBlobStore(db_dir) - super().__init__(f'sqlite:///{db_path}', blob_store=blob_store) + if ":memory:" not in db_dir: + os.makedirs(db_dir, exist_ok=True) + self.db_dir = db_dir + db_path = os.path.join(db_dir, 'ell.db') + blob_store = SQLBlobStore(db_dir) + super().__init__(f'sqlite:///{db_path}', blob_store=blob_store) + else: + from sqlalchemy.pool import StaticPool + # todo. set up blob store for in-memory + engine = create_engine( + 'sqlite://', + connect_args={'check_same_thread': False}, + poolclass=StaticPool + ) + + return super().__init__(engine=engine) + + + def write_external_blob(self, id: str, json_dump: str, depth: int = 2): + assert self.blob_store is not None, "Blob store is not initialized" + self.blob_store.store_blob(json_dump.encode('utf-8'), metadata={'id': id, 'depth': depth}) + + def read_external_blob(self, id: str, depth: int = 2) -> str: + assert self.blob_store is not None, "Blob store is not initialized" + return self.blob_store.retrieve_blob(id).decode('utf-8') +# todo. 
rename to sqlite blob store or local fs blob store class SQLBlobStore(ell.store.BlobStore): def __init__(self, db_dir: str): self.db_dir = db_dir @@ -242,7 +292,33 @@ def _get_blob_path(self, id: str, depth: int = 2) -> str: file_name = _id[depth*increment:] return os.path.join(self.db_dir, *dirs, file_name) -class PostgresStore(SQLStore): +# i think we should consider for multimedia inputs +# having the image available at a url ahead of time (ie don't call the llm with base64) +# to reduce transactional issues later +# it's recommended to use image urls for long running conversations anyway +# supposedly it's generally bad to store blobs in postgres so i'm not sure i'll implement this +class PostgresBlobStore(ell.store.BlobStore): + """ + Blob store that uses postgres as the backing store. + Not recommended for production use. + + """ def __init__(self, db_uri: str): - super().__init__(db_uri) - + self.db_uri = db_uri + raise NotImplementedError("Not implemented") + + def store_blob(self, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: + raise NotImplementedError("Not implemented") + + def retrieve_blob(self, blob_id: str) -> bytes: + raise NotImplementedError("Not implemented") + +class PostgresStore(SQLStore): + def __init__(self, db_uri: str, blob_store: Optional[ell.store.BlobStore] = None): + super().__init__(db_uri ) + if blob_store is not None: + self.blob_store = blob_store + else: + logger.warning("No blob store provided.") + # raise NotImplementedError("Not implemented") + # self.blob_store = PostgresBlobStore(db_uri) diff --git a/src/ell/studio/__main__.py b/src/ell/studio/__main__.py index 06a0354b3..8c9ea2fa9 100644 --- a/src/ell/studio/__main__.py +++ b/src/ell/studio/__main__.py @@ -1,12 +1,13 @@ import asyncio import os +from fastapi import FastAPI import uvicorn from argparse import ArgumentParser from ell.studio.config import Config +from ell.studio.logger import setup_logging from ell.studio.server import create_app from 
fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse -from watchfiles import awatch import time @@ -16,22 +17,32 @@ def main(): help="Directory for filesystem serializer storage (default: current directory)") parser.add_argument("--pg-connection-string", default=None, help="PostgreSQL connection string (default: None)") - parser.add_argument("--host", default="127.0.0.1", help="Host to run the server on") - parser.add_argument("--port", type=int, default=5000, help="Port to run the server on") - parser.add_argument("--dev", action="store_true", help="Run in development mode") + parser.add_argument("--mqtt-connection-string", default=None, + help="MQTT connection string (default: None)") + parser.add_argument("--host", default="0.0.0.0", + help="Host to run the server on") + parser.add_argument("--port", type=int, default=5000, + help="Port to run the server on") + parser.add_argument("--dev", action="store_true", + help="Run in development mode") args = parser.parse_args() if args.dev: assert args.port == 5000, "Port must be 5000 in development mode" - config = Config.create(storage_dir=args.storage_dir, - pg_connection_string=args.pg_connection_string) + config = Config( + storage_dir=args.storage_dir, + pg_connection_string=args.pg_connection_string, + mqtt_connection_string=args.mqtt_connection_string + ) + app = create_app(config) if not args.dev: # In production mode, serve the built React app static_dir = os.path.join(os.path.dirname(__file__), "static") - # app.mount("/", StaticFiles(directory=static_dir, html=True), name="static") + app.mount("/", StaticFiles(directory=static_dir, + html=True), name="static") @app.get("/{full_path:path}") async def serve_react_app(full_path: str): @@ -41,16 +52,15 @@ async def serve_react_app(full_path: str): else: return FileResponse(os.path.join(static_dir, "index.html")) - db_path = os.path.join(args.storage_dir) - async def db_watcher(db_path, app): + async def db_watcher(db_path: str, app: 
FastAPI): last_stat = None while True: await asyncio.sleep(0.1) # Fixed interval of 0.1 seconds try: current_stat = os.stat(db_path) - + if last_stat is None: print(f"Database file found: {db_path}") await app.notify_clients("database_updated") @@ -83,9 +93,15 @@ async def db_watcher(db_path, app): config = uvicorn.Config(app=app, host=args.host, port=args.port, loop=loop) server = uvicorn.Server(config) - loop.create_task(server.serve()) - loop.create_task(db_watcher(db_path, app)) + + tasks = [] + tasks.append(loop.create_task(server.serve())) + + if args.storage_dir: + tasks.append(loop.create_task(db_watcher(args.storage_dir, app))) + loop.run_forever() + if __name__ == "__main__": main() diff --git a/src/ell/studio/config.py b/src/ell/studio/config.py index 851c9c5e5..2feab899d 100644 --- a/src/ell/studio/config.py +++ b/src/ell/studio/config.py @@ -1,6 +1,7 @@ from functools import lru_cache +import json import os -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel import logging @@ -10,7 +11,7 @@ # todo. 
maybe we default storage dir and other things in the future to a well-known location # like ~/.ell or something -@lru_cache +@lru_cache(maxsize=1) def ell_home() -> str: return os.path.join(os.path.expanduser("~"), ".ell") @@ -18,23 +19,28 @@ def ell_home() -> str: class Config(BaseModel): pg_connection_string: Optional[str] = None storage_dir: Optional[str] = None - - @classmethod - def create( - cls, - storage_dir: Optional[str] = None, - pg_connection_string: Optional[str] = None, - ) -> 'Config': - pg_connection_string = pg_connection_string or os.getenv("ELL_PG_CONNECTION_STRING") - storage_dir = storage_dir or os.getenv("ELL_STORAGE_DIR") - + mqtt_connection_string: Optional[str] = None + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + def model_post_init(self, __context: Any): + # Storage + self.pg_connection_string = self.pg_connection_string or os.getenv( + "ELL_PG_CONNECTION_STRING") + self.storage_dir = self.storage_dir or os.getenv("ELL_STORAGE_DIR") + # Enforce that we use either sqlite or postgres, but not both - if pg_connection_string is not None and storage_dir is not None: + if self.pg_connection_string is not None and self.storage_dir is not None: raise ValueError("Cannot use both sqlite and postgres") - + # For now, fall back to sqlite if no PostgreSQL connection string is provided - if pg_connection_string is None and storage_dir is None: + if self.pg_connection_string is None and self.storage_dir is None: # This intends to honor the default we had set in the CLI - storage_dir = os.getcwd() + # todo. better default? 
+ self.storage_dir = os.getcwd() + + # Pubsub + self.mqtt_connection_string = self.mqtt_connection_string or os.getenv("ELL_MQTT_CONNECTION_STRING") + + logger.info(f"Resolved config: {json.dumps(self.model_dump(), indent=2)}") - return cls(pg_connection_string=pg_connection_string, storage_dir=storage_dir) \ No newline at end of file diff --git a/src/ell/studio/connection_manager.py b/src/ell/studio/connection_manager.py deleted file mode 100644 index 765a45016..000000000 --- a/src/ell/studio/connection_manager.py +++ /dev/null @@ -1,18 +0,0 @@ -from fastapi import WebSocket - - -class ConnectionManager: - def __init__(self): - self.active_connections = [] - - async def connect(self, websocket: WebSocket): - await websocket.accept() - self.active_connections.append(websocket) - - def disconnect(self, websocket: WebSocket): - self.active_connections.remove(websocket) - - async def broadcast(self, message: str): - for connection in self.active_connections: - print(f"Broadcasting message to {connection} {message}") - await connection.send_text(message) \ No newline at end of file diff --git a/src/ell/studio/datamodels.py b/src/ell/studio/datamodels.py index d115d99cc..eb3eaba54 100644 --- a/src/ell/studio/datamodels.py +++ b/src/ell/studio/datamodels.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import List, Optional, Dict, Any -from sqlmodel import SQLModel -from ell.types import SerializedLMPBase, InvocationBase, InvocationContentsBase +from typing import List +from ell.sqlmodels import InvocationBase, InvocationContentsBase, SerializedLMPBase + class SerializedLMPWithUses(SerializedLMPBase): diff --git a/src/ell/studio/logger.py b/src/ell/studio/logger.py new file mode 100644 index 000000000..ca291177b --- /dev/null +++ b/src/ell/studio/logger.py @@ -0,0 +1,40 @@ +import logging +from colorama import Fore, Style, init + +initialized = False +def setup_logging(level: int = logging.INFO): + global initialized + if initialized: + return + # Initialize 
colorama for cross-platform colored output + init(autoreset=True) + + # Create a custom formatter + class ColoredFormatter(logging.Formatter): + FORMATS = { + logging.DEBUG: Fore.CYAN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.INFO: Fore.GREEN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.WARNING: Fore.YELLOW + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.ERROR: Fore.RED + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.CRITICAL: Fore.RED + Style.BRIGHT + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S") + return formatter.format(record) + + # Create and configure the logger + logger = logging.getLogger("ell") + logger.setLevel(level) + + # Create console handler and set formatter + console_handler = logging.StreamHandler() + console_handler.setFormatter(ColoredFormatter()) + + # Add the handler to the logger + logger.addHandler(console_handler) + + initialized = True + + return logger \ No newline at end of file diff --git a/src/ell/studio/pubsub.py b/src/ell/studio/pubsub.py new file mode 100644 index 000000000..0c5c58550 --- /dev/null +++ b/src/ell/studio/pubsub.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +import asyncio +from functools import lru_cache +import json +import logging +from typing import Any + +from aiomqtt import Topic +import aiomqtt +from fastapi import WebSocket + +logger = logging.getLogger(__name__) + +Subscriber = WebSocket + + +class PubSub(ABC): + @abstractmethod + async def publish(self, topic: str, message: str) -> None: + pass + + def subscribe(self, topic: str, subscriber: Subscriber) -> None: + pass + + async def subscribe_async(self, topic: str, subscriber: 
Subscriber) -> None: + pass + + def unsubscribe(self, topic: str, subscriber: Subscriber): + pass + + def unsubscribe_from_all(self, subscriber: Subscriber): + pass + + +@lru_cache(maxsize=128) +def matchable(topic: str) -> Topic: + return Topic(topic) + + +class WebSocketPubSub(PubSub): + def __init__(self): + self.subscriptions: dict[str, list[Subscriber]] = {} + + async def publish(self, topic: str, message: Any): + # Notify all subscribers for the topic + # determine if match based on mqtt wildcard logic + _topic = matchable(topic) + subscriptions = self.subscriptions.copy() # copy to avoid mutating while iterating + logger.info(f"Relaying message to socket {topic} subscribers") + for pattern in subscriptions: + if _topic.matches(pattern): + for subscriber in subscriptions[pattern]: + asyncio.create_task(subscriber.send_json( + {"topic": topic, "message": message})) + + def subscribe(self, topic: str, subscriber: Subscriber) -> None: + logger.info(f"Subscribing ws {subscriber} to {topic}") + # Add the subscriber to the list for the topic + if topic not in self.subscriptions: + self.subscriptions[topic] = [] + self.subscriptions[topic].append(subscriber) + + def unsubscribe(self, topic: str, subscriber: Subscriber): + subscriptions = self.subscriptions.copy() + if topic in subscriptions: + self.subscriptions[topic].remove(subscriber) + if not self.subscriptions[topic]: + del self.subscriptions[topic] + + def unsubscribe_from_all(self, subscriber: Subscriber): + for topic in self.subscriptions.copy(): + self.unsubscribe(topic, subscriber) + + +class MqttWebSocketPubSub(WebSocketPubSub): + mqtt_client: aiomqtt.Client + + def __init__(self, conn: aiomqtt.Client): + super().__init__() + self.mqtt_client = conn + + def listen(self, loop: asyncio.AbstractEventLoop): + self.listener = loop.create_task(self._relay_all()) + return self.listener + + async def publish(self, topic: str, message: str) -> None: + # this is a bit sus + await self.mqtt_client.publish(topic, 
message) + + async def _relay_all(self) -> None: + logger.info("Starting mqtt listener") + async for message in self.mqtt_client.messages: + try: + logger.info(f"Received message on topic {message.topic}: {message.payload}") + await super().publish(str(message.topic), json.loads( + message.payload # type: ignore + )) + except Exception as e: + logger.error(f"Error relaying message: {e}") + + async def subscribe_async(self, topic: str, subscriber: Subscriber) -> None: + await self.mqtt_client.subscribe(topic) + super().subscribe(topic, subscriber) + + +class NoOpPubSub(PubSub): + def subscribe(self, topic: str, subscriber: Subscriber) -> None: + pass + + def unsubscribe(self, topic: str, subscriber: Subscriber) -> None: + pass + + def unsubscribe_from_all(self, subscriber: Subscriber) -> None: + pass + + async def publish(self, topic: str, message: Any) -> None: + pass diff --git a/src/ell/studio/server.py b/src/ell/studio/server.py index 54397c6cd..397446ae0 100644 --- a/src/ell/studio/server.py +++ b/src/ell/studio/server.py @@ -1,45 +1,107 @@ +import asyncio +from contextlib import asynccontextmanager from typing import Optional, Dict, Any +import aiomqtt +import logging +import json + from sqlmodel import Session from ell.stores.sql import PostgresStore, SQLiteStore from ell import __version__ from fastapi import FastAPI, Query, HTTPException, Depends, Response, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware -import logging -import json from ell.studio.config import Config -from ell.studio.connection_manager import ConnectionManager -from ell.studio.datamodels import InvocationPublicWithConsumes, SerializedLMPWithUses +from ell.studio.datamodels import InvocationPublicWithConsumes,SerializedLMPWithUses,InvocationsAggregate +from ell.studio.pubsub import MqttWebSocketPubSub, NoOpPubSub, WebSocketPubSub -from ell.types import SerializedLMP +from ell.sqlmodels import SerializedLMP from datetime import datetime, timedelta from 
sqlmodel import select - logger = logging.getLogger(__name__) -from ell.studio.datamodels import InvocationsAggregate - - def get_serializer(config: Config): if config.pg_connection_string: + logger.info("Initializing Postgres serializer") return PostgresStore(config.pg_connection_string) elif config.storage_dir: + logger.info("Initializing SQLite serializer") return SQLiteStore(config.storage_dir) else: raise ValueError("No storage configuration found") +pubsub = None + +async def get_pubsub(): + yield pubsub + + def create_app(config:Config): + # setup_logging() serializer = get_serializer(config) def get_session(): with Session(serializer.engine) as session: yield session - app = FastAPI(title="ell Studio", version=__version__) + + @asynccontextmanager + async def lifespan(app: FastAPI): + global pubsub + # when we're just using sqlite, handle publishes from db_watcher + if config.storage_dir is not None: + pubsub=WebSocketPubSub() + yield + elif config.mqtt_connection_string is None: + pubsub = NoOpPubSub() + yield + else: + retry_interval_seconds = 1 + retry_max_attempts = 5 + task = None + + for attempt in range(retry_max_attempts): + try: + host, port = config.mqtt_connection_string.split("://")[1].split(":") + + logger.info(f"Connecting to MQTT broker at {host}:{port}") + + async with aiomqtt.Client(hostname=host, port=int(port) if port else 1883) as mqtt: + logger.info("Connected to MQTT") + pubsub = MqttWebSocketPubSub(mqtt) + loop = asyncio.get_event_loop() + task = pubsub.listen(loop) + # await pubsub.mqtt_client.subscribe("#") + # async for message in pubsub.mqtt_client.messages: + # logger.info(f"Received message on topic {message.topic}: {message.payload}") + # logger.info("Subscribed to all topics") + + yield # Allow the app to run + + # Clean up after yield + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + break # Exit the retry loop if successful + except aiomqtt.MqttError as e: + 
logger.error(f"Failed to connect to MQTT [Attempt {attempt + 1}/{retry_max_attempts}]: {e}") + if attempt < retry_max_attempts - 1: + await asyncio.sleep(retry_interval_seconds) + else: + logger.error("Max retry attempts reached. Unable to connect to MQTT.") + raise + + pubsub = None # Reset pubsub after exiting the context + + + app = FastAPI(title="ell Studio", version=__version__, lifespan=lifespan) # Enable CORS for all origins app.add_middleware( @@ -50,17 +112,18 @@ def get_session(): allow_headers=["*"], ) - manager = ConnectionManager() @app.websocket("/ws") - async def websocket_endpoint(websocket: WebSocket): - await manager.connect(websocket) + async def websocket_endpoint(websocket: WebSocket, pubsub: MqttWebSocketPubSub = Depends(get_pubsub)): + await websocket.accept() + await pubsub.subscribe_async("all", websocket) + await pubsub.subscribe_async("lmp/#", websocket) try: while True: data = await websocket.receive_text() # Handle incoming WebSocket messages if needed except WebSocketDisconnect: - manager.disconnect(websocket) + pubsub.unsubscribe_from_all(websocket) @app.get("/api/latest/lmps", response_model=list[SerializedLMPWithUses]) @@ -194,14 +257,18 @@ def get_lmp_history( return history + # used by db_watcher for sqlite async def notify_clients(entity: str, id: Optional[str] = None): + if pubsub is None: + logger.error("Pubsub not ready; cannot notify clients") + return message = json.dumps({"entity": entity, "id": id}) - await manager.broadcast(message) + await pubsub.publish("all", message) # Add this method to the app object app.notify_clients = notify_clients - + @app.get("/api/invocations/aggregate", response_model=InvocationsAggregate) def get_invocations_aggregate( lmp_name: Optional[str] = Query(None), @@ -217,7 +284,7 @@ def get_invocations_aggregate( aggregate_data = serializer.get_invocations_aggregate(session, lmp_filters=lmp_filters, days=days) return InvocationsAggregate(**aggregate_data) - - - + + + return app \ No newline at 
end of file diff --git a/src/ell/types/__init__.py b/src/ell/types/__init__.py index 055adcf2a..8171dffc2 100644 --- a/src/ell/types/__init__.py +++ b/src/ell/types/__init__.py @@ -3,5 +3,5 @@ """ from ell.types.message import * -from ell.types.studio import * from ell.types._lstr import * +from ell.types.lmp import * diff --git a/src/ell/types/lmp.py b/src/ell/types/lmp.py new file mode 100644 index 000000000..9560acf45 --- /dev/null +++ b/src/ell/types/lmp.py @@ -0,0 +1,8 @@ +import enum + + +class LMPType(str, enum.Enum): + LM = "LM" + TOOL = "TOOL" + MULTIMODAL = "MULTIMODAL" + OTHER = "OTHER" \ No newline at end of file diff --git a/tests/api/test_api.py b/tests/api/test_api.py new file mode 100644 index 000000000..92f4a4113 --- /dev/null +++ b/tests/api/test_api.py @@ -0,0 +1,205 @@ +from datetime import timezone +from logging import DEBUG +from uuid import uuid4 +import pytest +from typing import Any, Dict, List +from fastapi.testclient import TestClient +from sqlmodel import Session +from ell.api.server import NoopPublisher, create_app, get_publisher, get_serializer, get_session +from ell.api.config import Config +from ell.api.types import WriteLMPInput + +from ell.sqlmodels import SerializedLMP +from ell.stores.sql import SQLStore, SQLiteStore +from ell.studio.logger import setup_logging +from ell.api.types import utc_now +from ell.types.lmp import LMPType + + +@pytest.fixture +def sql_store() -> SQLStore: + return SQLiteStore(":memory:") + + +def test_construct_serialized_lmp(): + serialized_lmp = SerializedLMP( + lmp_id="test_lmp_id", + name="Test LMP", + source="def test_function(): pass", + dependencies=str(["dep1", "dep2"]), + lmp_type=LMPType.LM, + api_params={"param1": "value1"}, + version_number=1, + # uses={"used_lmp_1": {}, "used_lmp_2": {}}, + initial_global_vars={"global_var1": "value1"}, + initial_free_vars={"free_var1": "value2"}, + commit_message="Initial commit", + created_at=utc_now() + ) + assert serialized_lmp.lmp_id == "test_lmp_id" + 
assert serialized_lmp.name == "Test LMP" + assert serialized_lmp.source == "def test_function(): pass" + assert serialized_lmp.dependencies == str(["dep1", "dep2"]) + assert serialized_lmp.api_params == {"param1": "value1"} + assert serialized_lmp.version_number == 1 + assert serialized_lmp.created_at is not None + + +def test_write_lmp_input(): + # Should be able to construct a WriteLMPInput from data + input = WriteLMPInput( + lmp_id="test_lmp_id", + name="Test LMP", + source="def test_function(): pass", + dependencies=str(["dep1", "dep2"]), + lmp_type=LMPType.LM, + api_params={"param1": "value1"}, + initial_global_vars={"global_var1": "value1"}, + initial_free_vars={"free_var1": "value2"}, + commit_message="Initial commit", + version_number=1, + ) + + # Should default a created_at to utc_now + assert input.created_at is not None + assert input.created_at.tzinfo == timezone.utc + + # Should be able to construct a SerializedLMP from a WriteLMPInput + model = SerializedLMP(**input.model_dump()) + assert model.created_at == input.created_at + + input2 = WriteLMPInput( + lmp_id="test_lmp_id", + name="Test LMP", + source="def test_function(): pass", + dependencies=str(["dep1", "dep2"]), + lmp_type=LMPType.LM, + api_params={"param1": "value1"}, + initial_global_vars={"global_var1": "value1"}, + initial_free_vars={"free_var1": "value2"}, + commit_message="Initial commit", + version_number=1, + # should work with an isoformat string + created_at=utc_now().isoformat() # type: ignore + ) + model2 = SerializedLMP(**input2.model_dump()) + assert model2.created_at == input2.created_at + assert input2.created_at is not None + assert input2.created_at.tzinfo == timezone.utc + + +def create_test_app(sql_store: SQLStore): + setup_logging(DEBUG) + config = Config(storage_dir=":memory:") + app = create_app(config) + + publisher = NoopPublisher() + + async def get_publisher_override(): + yield publisher + + async def get_session_override(): + with Session(sql_store.engine) as 
session: + yield session + + def get_serializer_override(): + return sql_store + + app.dependency_overrides[get_publisher] = get_publisher_override + app.dependency_overrides[get_session] = get_session_override + app.dependency_overrides[get_serializer] = get_serializer_override + + client = TestClient(app) + + return app, client, publisher, config + + +def test_write_lmp(sql_store: SQLStore): + _app, client, *_ = create_test_app(sql_store) + + lmp_data: Dict[str, Any] = { + "lmp_id": uuid4().hex, + "name": "Test LMP", + "source": "def test_function(): pass", + "dependencies": str(["dep1", "dep2"]), + "lmp_type": LMPType.LM, + "api_params": {"param1": "value1"}, + "version_number": 1, + "uses": {"used_lmp_1": {}, "used_lmp_2": {}}, + "initial_global_vars": {"global_var1": "value1"}, + "initial_free_vars": {"free_var1": "value2"}, + "commit_message": "Initial commit", + "created_at": utc_now().isoformat().replace("+00:00", "Z") + } + uses: List[str] = [ + "used_lmp_1", + "used_lmp_2" + ] + + response = client.post( + "/lmp", + json={ + "lmp": lmp_data, + "uses": uses + } + ) + + assert response.status_code == 200 + + lmp = client.get(f"/lmp/{lmp_data['lmp_id']}") + assert lmp.status_code == 200 + del lmp_data["uses"] + assert lmp.json() == {**lmp_data, "num_invocations": 0} + + +def test_write_invocation(sql_store: SQLStore): + _app, client, *_ = create_test_app(sql_store) + + lmp_id = uuid4().hex + lmp_data: Dict[str, Any] = { + "lmp_id": lmp_id, + "name": "Test LMP", + "source": "def test_function(): pass", + "dependencies": str(["dep1", "dep2"]), + "lmp_type": LMPType.LM, + "api_params": {"param1": "value1"}, + } + response = client.post( + "/lmp", + json={'lmp': lmp_data, 'uses': []} + ) + try: + assert response.status_code == 200 + except Exception as e: + print(response.json()) + raise e + + invocation_data = { + "id": uuid4().hex, + "lmp_id": lmp_id, + "args": ["arg1", "arg2"], + "kwargs": {"kwarg1": "value1"}, + "global_vars": {"global_var1": "value1"}, + 
"free_vars": {"free_var1": "value2"}, + "latency_ms": 100.0, + "invocation_kwargs": {"model": "gpt-4o", "messages": [{"role": "system", "content": "You are a JSON parser. You respond only in JSON. Do not format using markdown."}, {"role": "user", "content": "You are given the following task: \"What is two plus two?\"\n Parse the task into the following type:\n {'$defs': {'Add': {'properties': {'op': {'const': '+', 'enum': ['+'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Add', 'type': 'object'}, 'Div': {'properties': {'op': {'const': '/', 'enum': ['/'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Div', 'type': 'object'}, 'Mul': {'properties': {'op': {'const': '*', 'enum': ['*'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Mul', 'type': 'object'}, 'Sub': {'properties': {'op': {'const': '-', 'enum': ['-'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Sub', 'type': 'object'}}, 'anyOf': [{'$ref': '#/$defs/Add'}, {'$ref': '#/$defs/Sub'}, {'$ref': '#/$defs/Mul'}, {'$ref': '#/$defs/Div'}]}\n "}], "lm_kwargs": {"temperature": 0.1}, "client": None}, + "contents": { } + } + consumes_data = [] + + input = { + "invocation": invocation_data, + "consumes": consumes_data + } + response = client.post( + "/invocation", + json=input + ) + + print(response.json()) + assert response.status_code == 200 + # assert response.json() == input + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/test_sql_store.py b/tests/test_sql_store.py index 0555d1d4f..1cda9775c 100644 --- a/tests/test_sql_store.py +++ b/tests/test_sql_store.py @@ -1,11 +1,12 @@ 
import pytest from datetime import datetime, timezone from sqlmodel import Session, select -from ell.stores.sql import SQLStore, SerializedLMP +from ell.sqlmodels import SerializedLMP +from ell.stores.sql import SQLStore from sqlalchemy import Engine, create_engine, func -from ell.types.studio import LMPType -from ell.types.studio import utc_now +from ell.types.lmp import LMPType +from ell.api.types import utc_now @pytest.fixture def in_memory_db(): diff --git a/tests/test_track.py b/tests/test_track.py new file mode 100644 index 000000000..0cc92dc75 --- /dev/null +++ b/tests/test_track.py @@ -0,0 +1,193 @@ +from typing import Any, Dict, List, Optional +import openai +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.completion_usage import CompletionUsage +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture +from pytest_mock import MockerFixture + + +import ell +import pytest +from ell.api.client import EllAPIClient, EllClient, EllSqliteClient +from ell.api.types import LMP, GetLMPResponse, WriteInvocationInput, WriteLMPInput + +from ell.stores.sql import SQLStore, SQLiteStore + +import ell.providers.openai + +# T = TypeVar('T') + +# class SpyWrapper: +# def __init__(self, original_method: Callable, spy_method: Any): +# self.original_method = original_method +# self.spy_method = spy_method + +# def __call__(self, *args: Any, **kwargs: Any) -> Any: +# return self.spy_method(*args, **kwargs) + +# def __getattr__(self, name: str) -> Any: +# return getattr(self.spy_method, name) + +# class SpiedObject(Generic[T]): +# def __init__(self, original: T): +# self._original = original + +# def __getattr__(self, name: str) -> Any: +# return getattr(self._original, name) + +# def spy_all_methods(mocker: MockerFixture, obj: T) -> SpiedObject[T]: +# spied_obj = SpiedObject(obj) +# for attr_name in 
@pytest.fixture
def sql_store() -> SQLStore:
    """Provide a ``SQLiteStore`` created with the ``":memory:"`` path."""
    return SQLiteStore(":memory:")


@pytest.fixture
def mock_openai_chatcompletion(monkeypatch: pytest.MonkeyPatch):
    """Return a minimal stand-in for an OpenAI client.

    The returned object exposes ``api_key`` and ``chat.completions.create``.
    ``create`` ignores its arguments and always returns the same canned
    ``ChatCompletion`` (an assistant message saying "Hello!").

    ``monkeypatch`` is accepted only so the fixture signature stays unchanged;
    this fixture no longer patches anything globally. The previous patch
    targeted ``openai.ChatCompletion`` / ``openai.util``, which belong to the
    pre-1.0 OpenAI API and are never reached by code using the client object.
    """

    def mock_create(*args, **kwargs):
        return ChatCompletion(
            id="chatcmpl-123",
            object="chat.completion",
            created=1677652288,
            model="gpt-3.5-turbo",
            choices=[
                Choice(
                    index=0,
                    finish_reason="stop",
                    message=ChatCompletionMessage(
                        role="assistant", content="Hello!"),
                )
            ],
            usage=CompletionUsage(
                prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )

    class MockCompletions:
        def create(self, *args, **kwargs):
            return mock_create(*args, **kwargs)

    class MockChat:
        def __init__(self, *args, **kwargs):
            self.completions = MockCompletions()

    class MockOpenAPIClient:
        def __init__(self, *args, **kwargs):
            self.api_key = "test-api-key"
            self.chat = MockChat()

    return MockOpenAPIClient(api_key="test-api-key")


def _register_mock_openai(client, *model_names: str) -> None:
    """Register ``client`` for each model name and register the OpenAI provider.

    NOTE(review): ``get_client_type`` is overwritten on the provider class and
    is not restored afterwards, so the override outlives the calling test.
    """
    for model_name in model_names:
        ell.config.register_model(model_name=model_name, client=client)
    ell.providers.openai.OpenAIProvider.get_client_type = lambda: client.__class__
    ell.config.register_provider(ell.providers.openai.OpenAIProvider)


def test_track_decorator_sqlite(mock_openai_chatcompletion: openai.Client):
    """Call an ``@ell.simple`` function with ell initialised on a SQLite client."""
    _register_mock_openai(mock_openai_chatcompletion, "test-model")

    ell.init(client=EllSqliteClient(storage_dir=':memory:'))

    @ell.simple(model="test-model")
    def test_fn():
        return "this is a test"

    test_fn()


@register_fixture
class LMPFactory(ModelFactory[LMP]):
    """Polyfactory factory for ``LMP``; exposed to tests as ``lmp_factory``."""
    ...


def test_track_decorator_api(
    mocker: MockerFixture,
    mock_openai_chatcompletion: openai.Client,
    lmp_factory: LMPFactory
):
    """Call an ``@ell.simple`` function with ell initialised on a stub client.

    Asserts that a single call of the decorated function results in exactly
    one ``get_lmp_versions``, one ``write_lmp`` and one ``write_invocation``
    call on the client.
    """
    _register_mock_openai(
        mock_openai_chatcompletion, "test-model", "gpt-4o-mini")

    class TestAPIClient(EllClient):
        """``EllClient`` double that answers every call from memory."""

        def __init__(self, base_url: str):
            self.base_url = base_url

        async def get_lmp(self, lmp_id: str) -> GetLMPResponse:
            lmp = lmp_factory.build()
            # NOTE(review): confirm that ``LMP`` exposes an ``id`` attribute.
            lmp.id = lmp_id
            return lmp

        async def get_lmp_versions(self, fqn: str) -> List[LMP]:
            lmp = lmp_factory.build()
            lmp.name = fqn
            # Return the instance that carries the requested name; building a
            # second LMP here would discard the name set above.
            return [lmp]

        async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]):
            return None

        async def write_invocation(self, input: WriteInvocationInput):
            return None

        async def store_blob(self, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
            return "foo"

        async def retrieve_blob(self, blob_id: str) -> bytes:
            return b"bar"

        async def close(self):
            pass

    api_client = TestAPIClient(base_url="foo")
    get_lmp_versions = mocker.spy(api_client, "get_lmp_versions")
    write_lmp = mocker.spy(api_client, "write_lmp")
    write_invocation = mocker.spy(api_client, "write_invocation")

    ell.init(client=api_client)

    @ell.simple(model="test-model")
    def test_fn():
        return "this is a test"

    test_fn()

    assert get_lmp_versions.call_count == 1
    assert write_lmp.call_count == 1
    assert write_invocation.call_count == 1


if __name__ == "__main__":
    pytest.main()