diff --git a/src/ell/studio/__main__.py b/src/ell/studio/__main__.py index 20d03cd80..cadb584ca 100644 --- a/src/ell/studio/__main__.py +++ b/src/ell/studio/__main__.py @@ -5,7 +5,9 @@ from ell.studio.server import create_app from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse -from watchfiles import awatch +import watchfiles +import importlib +import sys import time def main(): @@ -17,7 +19,8 @@ def main(): parser.add_argument("--dev", action="store_true", help="Run in development mode") args = parser.parse_args() - app = create_app(args.storage_dir) + + app = create_app() if not args.dev: # In production mode, serve the built React app @@ -28,16 +31,13 @@ def main(): async def serve_react_app(full_path: str): return FileResponse(os.path.join(static_dir, "index.html")) - db_path = os.path.join(args.storage_dir, "ell.db") async def db_watcher(db_path, app): + print("Starting db watcher") last_stat = None - while True: - await asyncio.sleep(0.1) # Fixed interval of 0.1 seconds try: current_stat = os.stat(db_path) - if last_stat is None: print(f"Database file found: {db_path}") await app.notify_clients("database_updated") @@ -64,15 +64,61 @@ async def db_watcher(db_path, app): except Exception as e: print(f"Error checking database file: {e}") await asyncio.sleep(1) # Wait a bit longer on errors + finally: + await asyncio.sleep(1) # Use a consistent sleep interval + + def get_dependencies(module_name): + module = importlib.import_module(module_name) + return list(set(sys.modules[name].__file__ for name in sys.modules if name.startswith(module_name.split('.')[0]))) + + def reload_app(): + importlib.reload(sys.modules["ell.studio.server"]) + return create_app() + + async def run_server(server): + await server.serve() + + async def watch_files(dependencies, server, config, loop): + async for changes in watchfiles.awatch(*dependencies): + print(f"Detected changes in {changes}. Reloading...") + new_app = reload_app() + await server.shutdown() + config.app = new_app + server.force_exit = False + loop.create_task(run_server(server)) + + async def main_async(args): + db_path = os.path.join(args.storage_dir, "ell.db") + dependencies = get_dependencies("ell.studio.server") + + config = uvicorn.Config( + app=app, + host=args.host, + port=args.port, + loop=asyncio.get_event_loop(), + ) + server = uvicorn.Server(config) + + tasks = [ + asyncio.create_task(run_server(server)), + ] + # todo. figure out equivalent for other backends + # maybe the server should broadcast a message to all clients on write instead of the db watcher approach + if args.storage_dir: + tasks.append(asyncio.create_task(db_watcher(db_path, app))) + if args.dev: + tasks.append(asyncio.create_task(watch_files(dependencies, server, config, asyncio.get_event_loop()))) - # Start the database watcher - loop = asyncio.new_event_loop() + try: + await asyncio.gather(*tasks) + except asyncio.CancelledError: + pass + finally: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) - config = uvicorn.Config(app=app, port=args.port, loop=loop) - server = uvicorn.Server(config) - loop.create_task(server.serve()) - loop.create_task(db_watcher(db_path, app)) - loop.run_forever() + asyncio.run(main_async(args)) if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/ell/studio/server.py b/src/ell/studio/server.py index 0475494eb..95dab65af 100644 --- a/src/ell/studio/server.py +++ b/src/ell/studio/server.py @@ -10,6 +10,7 @@ import logging import asyncio import json +from argparse import ArgumentParser import ell.studio.connection_manager from ell.studio.connection_manager import ConnectionManager from ell.studio.datamodels import SerializedLMPPublic, SerializedLMPWithUses @@ -18,13 +19,18 @@ from datetime import datetime, timedelta from sqlmodel import select + logger = logging.getLogger(__name__) +def create_app(): + parser = ArgumentParser(description="ELL Studio Data Server") + parser.add_argument("--storage-dir", default=os.getcwd(), + help="Directory for filesystem serializer storage (default: current directory)") + args, _ = parser.parse_known_args() -def create_app(storage_dir: Optional[str] = None): - storage_path = storage_dir or os.environ.get("ELL_STORAGE_DIR") or os.getcwd() + storage_path = args.storage_dir or os.environ.get("ELL_STORAGE_DIR") or os.getcwd() assert storage_path, "ELL_STORAGE_DIR must be set" serializer = SQLiteStore(storage_path) def get_session():