diff --git a/examples/ollama_example.py b/examples/ollama_example.py index 01b12e3d5..c07518a4a 100644 --- a/examples/ollama_example.py +++ b/examples/ollama_example.py @@ -1,10 +1,10 @@ import ell -@ell.lm(model="llama3", temperature=0.1) +@ell.lm(model="llama3.1:8b", temperature=0.1) def write_a_story(): return "write me a story" -ell.models.ollama.register_models(api_base="http://localhost:11434") +ell.models.ollama.register(base_url="http://localhost:11434") -write_a_story() \ No newline at end of file +print(write_a_story()) \ No newline at end of file diff --git a/src/ell/models/ollama.py b/src/ell/models/ollama.py index 438090081..e48d9af14 100644 --- a/src/ell/models/ollama.py +++ b/src/ell/models/ollama.py @@ -1,5 +1,5 @@ from ell.configurator import config -import openai +import ollama import requests import logging @@ -8,17 +8,16 @@ def register(base_url): global client - client = openai.Client(base_url=base_url) + client = ollama.Client(host=base_url) try: - response = requests.get(f"{base_url}/api/tags") - response.raise_for_status() - models = response.json().get("models", []) + models = client.list() - for model in models: - config.register_model(model["name"], client) - except requests.RequestException as e: - logger.error(f"Failed to fetch models from {base_url}: {e}") + if 'models' in models: + for model in models['models']: + config.register_model(model['name'], client) + logger.info(f"Registered {len(models['models'])} Ollama models") + else: + logger.warning("No models found in Ollama response") except Exception as e: - logger.error(f"An error occurred: {e}") - + logger.error(f"An error occurred while registering Ollama models: {e}") \ No newline at end of file diff --git a/src/ell/util/lm.py b/src/ell/util/lm.py index 325a8208b..cee14cf20 100644 --- a/src/ell/util/lm.py +++ b/src/ell/util/lm.py @@ -9,7 +9,7 @@ from ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start - +import ollama def _get_lm_kwargs(lm_kwargs: Dict[str, Any], lm_params: LMPParams) -> Dict[str, Any]: """ @@ -44,56 +44,68 @@ def _run_lm( lm_kwargs: Dict[str, Any], _invocation_origin : str, exempt_from_tracking: bool, - client: Optional[openai.Client] = None, + client: Optional[Any] = None, _logging_color=None, ) -> Tuple[Union[lstr, Iterable[lstr]], Optional[Dict[str, Any]]]: """ Helper function to run the language model with the provided messages and parameters. """ # Todo: Decide if the client specified via the context amanger default registry is the shit or if the cliennt specified via lmp invocation args are the hing. - client = client or config.get_client_for(model) + client = client or config.get_client_for(model) metadata = dict() if client is None: raise ValueError(f"No client found for model '{model}'. Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.") - # todo: add suupport for streaming apis that dont give a final usage in the api - model_result = client.chat.completions.create( - model=model, messages=messages, stream=True, stream_options={"include_usage": True}, **lm_kwargs - ) - - choices_progress = defaultdict(list) - n = lm_kwargs.get("n", 1) - - if config.verbose and not exempt_from_tracking: - model_usage_logger_post_start(_logging_color, n) - - with model_usage_logger_post_intermediate(_logging_color, n) as _logger: - for chunk in model_result: - if chunk.usage: - # Todo: is this a good decision. - metadata = chunk.to_dict() - continue - for choice in chunk.choices: - choices_progress[choice.index].append(choice) - if config.verbose and choice.index == 0 and not exempt_from_tracking: - _logger(choice.delta.content) - - if config.verbose and not exempt_from_tracking: - model_usage_logger_post_end() - n_choices = len(choices_progress) - - tracked_results = [ - lstr( - content="".join((choice.delta.content or "" for choice in choice_deltas)), - # logits=( # - # np.concatenate([np.array( - # [c.logprob for c in choice.logprobs.content or []] - # ) for choice in choice_deltas]) # mypy type hinting is dogshit. - # ), - # Todo: Properly implement log probs. - _origin_trace=_invocation_origin, + if isinstance(client, ollama.Client): + # Handle Ollama client + prompt = "\n".join([f"{m.role}: {m.content}" for m in messages]) + ollama_kwargs = {} + if 'temperature' in lm_kwargs: + ollama_kwargs['options'] = {'temperature': lm_kwargs['temperature']} + response = client.generate(model=model, prompt=prompt, **ollama_kwargs) + result = response['response'] + metadata = response + tracked_results = [lstr(content=result, _origin_trace=_invocation_origin)] + else: + # Handle OpenAI and other clients + # todo: add suupport for streaming apis that dont give a final usage in the api + model_result = client.chat.completions.create( + model=model, messages=messages, stream=True, stream_options={"include_usage": True}, **lm_kwargs ) - for _, choice_deltas in sorted(choices_progress.items(), key= lambda x: x[0],) - ] - return tracked_results[0] if n_choices == 1 else tracked_results, metadata \ No newline at end of file + choices_progress = defaultdict(list) + n = lm_kwargs.get("n", 1) + + if config.verbose and not exempt_from_tracking: + model_usage_logger_post_start(_logging_color, n) + + with model_usage_logger_post_intermediate(_logging_color, n) as _logger: + for chunk in model_result: + if chunk.usage: + # Todo: is this a good decision. + metadata = chunk.to_dict() + continue + for choice in chunk.choices: + choices_progress[choice.index].append(choice) + if config.verbose and choice.index == 0 and not exempt_from_tracking: + _logger(choice.delta.content) + + if config.verbose and not exempt_from_tracking: + model_usage_logger_post_end() + n_choices = len(choices_progress) + + tracked_results = [ + lstr( + content="".join((choice.delta.content or "" for choice in choice_deltas)), + # logits=( # + # np.concatenate([np.array( + # [c.logprob for c in choice.logprobs.content or []] + # ) for choice in choice_deltas]) # mypy type hinting is dogshit. + # ), + # Todo: Properly implement log probs. + _origin_trace=_invocation_origin, + ) + for _, choice_deltas in sorted(choices_progress.items(), key= lambda x: x[0],) + ] + + return tracked_results[0] if len(tracked_results) == 1 else tracked_results, metadata \ No newline at end of file