Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/ollama_example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import ell

@ell.lm(model="llama3", temperature=0.1)
@ell.lm(model="llama3.1:8b", temperature=0.1)
def write_a_story():
return "write me a story"


ell.models.ollama.register_models(api_base="http://localhost:11434")
ell.models.ollama.register(base_url="http://localhost:11434")

write_a_story()
print(write_a_story())
21 changes: 10 additions & 11 deletions src/ell/models/ollama.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ell.configurator import config
import openai
import ollama
import requests
import logging

Expand All @@ -8,17 +8,16 @@

def register(base_url):
global client
client = openai.Client(base_url=base_url)
client = ollama.Client(host=base_url)

try:
response = requests.get(f"{base_url}/api/tags")
response.raise_for_status()
models = response.json().get("models", [])
models = client.list()

for model in models:
config.register_model(model["name"], client)
except requests.RequestException as e:
logger.error(f"Failed to fetch models from {base_url}: {e}")
if 'models' in models:
for model in models['models']:
config.register_model(model['name'], client)
logger.info(f"Registered {len(models['models'])} Ollama models")
else:
logger.warning("No models found in Ollama response")
except Exception as e:
logger.error(f"An error occurred: {e}")

logger.error(f"An error occurred while registering Ollama models: {e}")
96 changes: 54 additions & 42 deletions src/ell/util/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start


import ollama

def _get_lm_kwargs(lm_kwargs: Dict[str, Any], lm_params: LMPParams) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -44,56 +44,68 @@ def _run_lm(
lm_kwargs: Dict[str, Any],
_invocation_origin : str,
exempt_from_tracking: bool,
client: Optional[openai.Client] = None,
client: Optional[Any] = None,
_logging_color=None,
) -> Tuple[Union[lstr, Iterable[lstr]], Optional[Dict[str, Any]]]:
"""
Helper function to run the language model with the provided messages and parameters.
"""
# Todo: Decide whether the client specified via the context manager's default registry or the client specified via the LMP invocation args should take precedence.
client = client or config.get_client_for(model)
client = client or config.get_client_for(model)
metadata = dict()
if client is None:
raise ValueError(f"No client found for model '{model}'. Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.")

# todo: add support for streaming APIs that don't give a final usage in the API
model_result = client.chat.completions.create(
model=model, messages=messages, stream=True, stream_options={"include_usage": True}, **lm_kwargs
)

choices_progress = defaultdict(list)
n = lm_kwargs.get("n", 1)

if config.verbose and not exempt_from_tracking:
model_usage_logger_post_start(_logging_color, n)

with model_usage_logger_post_intermediate(_logging_color, n) as _logger:
for chunk in model_result:
if chunk.usage:
# Todo: is this a good decision.
metadata = chunk.to_dict()
continue
for choice in chunk.choices:
choices_progress[choice.index].append(choice)
if config.verbose and choice.index == 0 and not exempt_from_tracking:
_logger(choice.delta.content)

if config.verbose and not exempt_from_tracking:
model_usage_logger_post_end()
n_choices = len(choices_progress)

tracked_results = [
lstr(
content="".join((choice.delta.content or "" for choice in choice_deltas)),
# logits=( #
# np.concatenate([np.array(
# [c.logprob for c in choice.logprobs.content or []]
# ) for choice in choice_deltas]) # mypy type hinting is dogshit.
# ),
# Todo: Properly implement log probs.
_origin_trace=_invocation_origin,
if isinstance(client, ollama.Client):
# Handle Ollama client
prompt = "\n".join([f"{m.role}: {m.content}" for m in messages])
ollama_kwargs = {}
if 'temperature' in lm_kwargs:
ollama_kwargs['options'] = {'temperature': lm_kwargs['temperature']}
response = client.generate(model=model, prompt=prompt, **ollama_kwargs)
result = response['response']
metadata = response
tracked_results = [lstr(content=result, _origin_trace=_invocation_origin)]
else:
# Handle OpenAI and other clients
# todo: add support for streaming APIs that don't give a final usage in the API
model_result = client.chat.completions.create(
model=model, messages=messages, stream=True, stream_options={"include_usage": True}, **lm_kwargs
)
for _, choice_deltas in sorted(choices_progress.items(), key= lambda x: x[0],)
]

return tracked_results[0] if n_choices == 1 else tracked_results, metadata
choices_progress = defaultdict(list)
n = lm_kwargs.get("n", 1)

if config.verbose and not exempt_from_tracking:
model_usage_logger_post_start(_logging_color, n)

with model_usage_logger_post_intermediate(_logging_color, n) as _logger:
for chunk in model_result:
if chunk.usage:
# Todo: is this a good decision.
metadata = chunk.to_dict()
continue
for choice in chunk.choices:
choices_progress[choice.index].append(choice)
if config.verbose and choice.index == 0 and not exempt_from_tracking:
_logger(choice.delta.content)

if config.verbose and not exempt_from_tracking:
model_usage_logger_post_end()
n_choices = len(choices_progress)

tracked_results = [
lstr(
content="".join((choice.delta.content or "" for choice in choice_deltas)),
# logits=( #
# np.concatenate([np.array(
# [c.logprob for c in choice.logprobs.content or []]
# ) for choice in choice_deltas]) # mypy type hinting is dogshit.
# ),
# Todo: Properly implement log probs.
_origin_trace=_invocation_origin,
)
for _, choice_deltas in sorted(choices_progress.items(), key= lambda x: x[0],)
]

return tracked_results[0] if len(tracked_results) == 1 else tracked_results, metadata