Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/ell/decorators/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,16 @@ def model_call(
return tracked_str, api_params, metadata

# TODO: # we'll deal with type safety here later
model_call.__ell_lm_kwargs__ = lm_kwargs
# XXX: Do we need intermediate params?
model_call.__ell_func__ = prompt
model_call.__ell_type__ = LMPType.LM
model_call.__ell_exempt_from_tracking = exempt_from_tracking

if exempt_from_tracking:
return model_call
else:
return track(model_call, forced_dependencies=dict(tools=tools))
return track(model_call, forced_dependencies=dict(tools=tools), lmp_type=LMPType.LM, lm_kwargs=lm_kwargs)


return parameterized_lm_decorator



return parameterized_lm_decorator
36 changes: 17 additions & 19 deletions src/ell/decorators/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,15 @@ def exclude_var(v):
# is module or is immutable
return inspect.ismodule(v)

def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None) -> Callable:

lmp_type = getattr(func_to_track, "__ell_type__", LMPType.OTHER)
_has_serialized_lmp = {}
_lmp_hash = {}

def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None, lm_kwargs: Optional[Dict[str, Any]] = None, lmp_type: Optional[LMPType] = LMPType.OTHER) -> Callable:

# see if it exists
if not hasattr(func_to_track, "_has_serialized_lmp"):
func_to_track._has_serialized_lmp = False

if not hasattr(func_to_track, "__ell_hash__") and not config.lazy_versioning:

if not ell.util.closure.has_closured_function(func_to_track) and not config.lazy_versioning:
ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)


@wraps(func_to_track)
def tracked_func(*fn_args, **fn_kwargs) -> str:
# XXX: Cache keys and global variable binding is not thread safe.
Expand All @@ -76,14 +72,15 @@ def tracked_func(*fn_args, **fn_kwargs) -> str:

if try_use_cache:
# Todo: add nice logging if verbose for when using a cahced invocaiton. IN a different color with thar args..
if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning:
if not ell.util.closure.has_closured_function(func_to_track) and config.lazy_versioning:
fn_closure, _ = ell.util.closure.lexically_closured_source(func_to_track)

# compute the state cachekey
state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
lexical_closure = ell.util.closure.get_lexical_closure(func_to_track)
state_cache_key = compute_state_cache_key(ipstr, lexical_closure.closure)

cache_store = func_to_track.__wrapper__.__ell_use_cache__
cached_invocations = cache_store.get_cached_invocations(func_to_track.__ell_hash__, state_cache_key)
cached_invocations = cache_store.get_cached_invocations(lexical_closure.hash, state_cache_key)


if len(cached_invocations) > 0:
Expand Down Expand Up @@ -115,12 +112,13 @@ def tracked_func(*fn_args, **fn_kwargs) -> str:
prompt_tokens=usage.get("prompt_tokens", 0)
completion_tokens=usage.get("completion_tokens", 0)

if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning:
if not ell.util.closure.has_closured_function(func_to_track) and config.lazy_versioning:
ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)
_serialize_lmp(func_to_track)

lexical_closure = ell.util.closure.get_lexical_closure(func_to_track)
if not state_cache_key:
state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
state_cache_key = compute_state_cache_key(ipstr, lexical_closure.closure)

_write_invocation(func_to_track, invocation_id, latency_ms, prompt_tokens, completion_tokens,
state_cache_key, invocation_kwargs, cleaned_invocation_params, consumes, result, parent_invocation_id)
Expand All @@ -131,8 +129,7 @@ def tracked_func(*fn_args, **fn_kwargs) -> str:


func_to_track.__wrapper__ = tracked_func
if hasattr(func_to_track, "__ell_lm_kwargs__"):
tracked_func.__ell_lm_kwargs__ = func_to_track.__ell_lm_kwargs__
# XXX: Move away from __ private declarations this should be object oriented.
if hasattr(func_to_track, "__ell_params_model__"):
tracked_func.__ell_params_model__ = func_to_track.__ell_params_model__
tracked_func.__ell_func__ = func_to_track
Expand All @@ -142,12 +139,13 @@ def tracked_func(*fn_args, **fn_kwargs) -> str:

def _serialize_lmp(func):
# Serialize deptjh first all fo the used lmps.
for f in func.__ell_uses__:
lexical_closure = ell.util.closure.get_lexical_closure(func)
for f in lexical_closure.uses:
_serialize_lmp(f)

if getattr(func, "_has_serialized_lmp", False):
if getattr(func, _has_serialized_lmp[func], False):
return
func._has_serialized_lmp = False
_has_serialized_lmp[func] = True
fn_closure = func.__ell_closure__
lmp_type = func.__ell_type__
name = func.__qualname__
Expand Down
32 changes: 24 additions & 8 deletions src/ell/util/closure.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def xD():
"""
import collections
import ast
from dataclasses import dataclass
import hashlib
import itertools
import os
Expand Down Expand Up @@ -279,14 +280,7 @@ def _generate_function_hash(source, dsrc, qualname):
"""Generate a hash for the function."""
return "lmp-" + hashlib.md5("\n".join((source, dsrc, qualname)).encode()).hexdigest()

def _update_ell_func(outer_ell_func, source, dsrc, globals_dict, frees_dict, fn_hash, uses):
"""Update the ell function attributes."""
formatted_source = _format_source(source)
formatted_dsrc = _format_source(dsrc)
if hasattr(outer_ell_func, "__ell_func__"):
outer_ell_func.__ell_closure__ = (formatted_source, formatted_dsrc, globals_dict, frees_dict)
outer_ell_func.__ell_hash__ = fn_hash
outer_ell_func.__ell_uses__ = uses


def _raise_error(message, exception, recursion_stack):
"""Raise an error with detailed information."""
Expand Down Expand Up @@ -505,3 +499,25 @@ def globalvars(func, recurse=True, builtin=False):
#NOTE: if name not in __globals__, then we skip it...
return dict((name,globs[name]) for name in func if name in globs)


def _update_ell_func(outer_ell_func, source, dsrc, globals_dict, frees_dict, fn_hash, uses):
"""Update the ell function attributes."""
formatted_source = _format_source(source)
formatted_dsrc = _format_source(dsrc)
if hasattr(outer_ell_func, "__ell_func__"):
function_closures[outer_ell_func] = LexicalClosure(hash=fn_hash, closure=(formatted_source, formatted_dsrc, globals_dict, frees_dict), uses=uses)

@dataclass
class LexicalClosure:
hash : str
closure : Tuple[str, str, Dict[str, Any], Dict[str, Any]]
uses : Set[str]

# cache of all the closured funciton closures.
function_closures : Dict[Callable, LexicalClosure] = {}

def has_closured_function(func : Callable) -> bool:
return func in function_closures

def get_lexical_closure(func : Callable) -> LexicalClosure | None:
return function_closures.get(func)