diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py index 8fe7f4d1b..a9d70d874 100644 --- a/src/ell/providers/openai.py +++ b/src/ell/providers/openai.py @@ -83,7 +83,8 @@ def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: role=message.role, content=[_content_block_to_openai_format(c) for c in message.content] if message.role != "system" - else message.text_only + else message.text_only, + tool_call_id=message.tool_call_id ))) final_call_params["messages"] = openai_messages @@ -159,7 +160,7 @@ def translate_from_provider( ) ) if logger: logger(repr(tool_call)) - messages.append(Message(role=role, content=content_blocks)) + messages.append(Message(role=role, content=content_blocks, tool_call_id=message.tool_call_id)) return messages, metadata @@ -181,4 +182,4 @@ def _content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, An elif ((text := content_block.text) is not None): return dict(type="text", text=text) elif (parsed := content_block.parsed): return dict(type="text", text=parsed.model_dump_json()) else: - raise ValueError(f"Unsupported content block type for openai: {content_block}") \ No newline at end of file + raise ValueError(f"Unsupported content block type for openai: {content_block}") diff --git a/src/ell/types/message.py b/src/ell/types/message.py index 93b8fc5c0..92e6bc2a0 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -312,12 +312,11 @@ def to_content_blocks( class Message(BaseModel): role: str content: List[ContentBlock] - + tool_call_id: Optional[_lstr_generic] = Field(default=None) - def __init__(self, role: str, content: Union[AnyContent, List[AnyContent], None] = None, **content_block_kwargs): + def __init__(self, role: str, content: Union[AnyContent, List[AnyContent], None] = None, tool_call_id: Optional[_lstr_generic] = None, **content_block_kwargs): content_blocks = to_content_blocks(content, **content_block_kwargs) - - super().__init__(role=role, content=content_blocks) + super().__init__(role=role, content=content_blocks, tool_call_id=tool_call_id) # XXX: This choice of naming is unfortunate, but it is what it is. @property @@ -452,6 +451,8 @@ def model_validate(cls, obj: Any) -> 'Message': else: content_blocks.append(ContentBlock.coerce(block)) obj['content'] = content_blocks + if 'tool_call_id' in obj: + obj['tool_call_id'] = _lstr(obj['tool_call_id']) return super().model_validate(obj) @classmethod