Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 52 additions & 20 deletions posthog/ai/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def on_chain_end(
**kwargs: Any,
):
"""Capture a completed LangChain chain run as a trace or span."""
self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, outputs)
self._capture_trace_or_span_run(
"on_chain_end", "outputs", outputs, run_id, parent_run_id
)

def on_chain_error(
self,
Expand All @@ -185,8 +186,9 @@ def on_chain_error(
**kwargs: Any,
):
"""Capture a failed LangChain chain run as a trace or span."""
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
self._capture_trace_or_span_run(
"on_chain_error", "error", error, run_id, parent_run_id
)

def on_chat_model_start(
self,
Expand Down Expand Up @@ -243,10 +245,9 @@ def on_llm_end(
"""
The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
"""
self._log_debug_event(
"on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
self._capture_generation_run(
"on_llm_end", "response", response, run_id, parent_run_id, kwargs=kwargs
)
self._pop_run_and_capture_generation(run_id, parent_run_id, response)

def on_llm_error(
self,
Expand All @@ -257,8 +258,9 @@ def on_llm_error(
**kwargs: Any,
):
"""Capture a failed LLM run as a PostHog AI generation event."""
self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
self._pop_run_and_capture_generation(run_id, parent_run_id, error)
self._capture_generation_run(
"on_llm_error", "error", error, run_id, parent_run_id
)

def on_tool_start(
self,
Expand Down Expand Up @@ -288,8 +290,9 @@ def on_tool_end(
**kwargs: Any,
) -> Any:
"""Capture a completed LangChain tool run as a span."""
self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, output)
self._capture_trace_or_span_run(
"on_tool_end", "output", output, run_id, parent_run_id
)

def on_tool_error(
self,
Expand All @@ -301,8 +304,9 @@ def on_tool_error(
**kwargs: Any,
) -> Any:
"""Capture a failed LangChain tool run as a span."""
self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
self._capture_trace_or_span_run(
"on_tool_error", "error", error, run_id, parent_run_id
)

def on_retriever_start(
self,
Expand Down Expand Up @@ -330,10 +334,9 @@ def on_retriever_end(
**kwargs: Any,
):
"""Capture a completed LangChain retriever run as a span."""
self._log_debug_event(
"on_retriever_end", run_id, parent_run_id, documents=documents
self._capture_trace_or_span_run(
"on_retriever_end", "documents", documents, run_id, parent_run_id
)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, documents)

def on_retriever_error(
self,
Expand All @@ -345,8 +348,9 @@ def on_retriever_error(
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
self._capture_trace_or_span_run(
"on_retriever_error", "error", error, run_id, parent_run_id
)

def on_agent_action(
self,
Expand All @@ -370,8 +374,36 @@ def on_agent_finish(
**kwargs: Any,
) -> Any:
"""Capture a completed LangChain agent action as a span."""
self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, finish)
self._capture_trace_or_span_run(
"on_agent_finish", "finish", finish, run_id, parent_run_id
)

def _capture_trace_or_span_run(
self,
event_name: str,
payload_name: str,
payload: Any,
run_id: UUID,
parent_run_id: Optional[UUID],
):
self._log_debug_event(
event_name, run_id, parent_run_id, **{payload_name: payload}
)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, payload)

def _capture_generation_run(
self,
event_name: str,
payload_name: str,
payload: Any,
run_id: UUID,
parent_run_id: Optional[UUID],
**extra: Any,
):
self._log_debug_event(
event_name, run_id, parent_run_id, **{payload_name: payload}, **extra
)
self._pop_run_and_capture_generation(run_id, parent_run_id, payload)

def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
"""
Expand Down
79 changes: 8 additions & 71 deletions posthog/ai/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from posthog.ai.sanitization import sanitize_openai, sanitize_openai_response
from posthog.client import Client as PostHogClient
from posthog import setup
from posthog.ai.openai.wrapper_utils import warn_on_fallback
from posthog.ai.openai.wrapper_utils import OpenAIWrapperResource


class OpenAI(openai.OpenAI):
Expand Down Expand Up @@ -91,18 +91,9 @@ def _parse_and_track(
)


class WrappedResponses:
class WrappedResponses(OpenAIWrapperResource):
"""Wrapper for OpenAI responses that tracks usage in PostHog."""

def __init__(self, client: OpenAI, original_responses):
self._client = client
self._original = original_responses

def __getattr__(self, name):
"""Fallback to original responses object for any methods we don't explicitly handle."""
warn_on_fallback(self.__class__.__name__, name)
return getattr(self._original, name)

def create(
self,
posthog_distinct_id: Optional[str] = None,
Expand Down Expand Up @@ -312,36 +303,18 @@ def parse(
)


class WrappedChat:
class WrappedChat(OpenAIWrapperResource):
"""Wrapper for OpenAI chat that tracks usage in PostHog."""

def __init__(self, client: OpenAI, original_chat):
self._client = client
self._original = original_chat

def __getattr__(self, name):
"""Fallback to original chat object for any methods we don't explicitly handle."""
warn_on_fallback(self.__class__.__name__, name)
return getattr(self._original, name)

@property
def completions(self):
"""Access chat completions with PostHog usage tracking."""
return WrappedCompletions(self._client, self._original.completions)


class WrappedCompletions:
class WrappedCompletions(OpenAIWrapperResource):
"""Wrapper for OpenAI chat completions that tracks usage in PostHog."""

def __init__(self, client: OpenAI, original_completions):
self._client = client
self._original = original_completions

def __getattr__(self, name):
"""Fallback to original completions object for any methods we don't explicitly handle."""
warn_on_fallback(self.__class__.__name__, name)
return getattr(self._original, name)

def parse(
self,
posthog_distinct_id: Optional[str] = None,
Expand Down Expand Up @@ -566,18 +539,9 @@ def _capture_streaming_event(
capture_streaming_event(self._client._ph_client, event_data)


class WrappedEmbeddings:
class WrappedEmbeddings(OpenAIWrapperResource):
"""Wrapper for OpenAI embeddings that tracks usage in PostHog."""

def __init__(self, client: OpenAI, original_embeddings):
self._client = client
self._original = original_embeddings

def __getattr__(self, name):
"""Fallback to original embeddings object for any methods we don't explicitly handle."""
warn_on_fallback(self.__class__.__name__, name)
return getattr(self._original, name)

def create(
self,
posthog_distinct_id: Optional[str] = None,
Expand Down Expand Up @@ -651,54 +615,27 @@ def create(
return response


class WrappedBeta:
class WrappedBeta(OpenAIWrapperResource):
"""Wrapper for OpenAI beta features that tracks usage in PostHog."""

def __init__(self, client: OpenAI, original_beta):
self._client = client
self._original = original_beta

def __getattr__(self, name):
"""Fallback to original beta object for any methods we don't explicitly handle."""
warn_on_fallback(self.__class__.__name__, name)
return getattr(self._original, name)

@property
def chat(self):
"""Access beta chat APIs with PostHog usage tracking."""
return WrappedBetaChat(self._client, self._original.chat)


class WrappedBetaChat:
class WrappedBetaChat(OpenAIWrapperResource):
"""Wrapper for OpenAI beta chat that tracks usage in PostHog."""

def __init__(self, client: OpenAI, original_beta_chat):
self._client = client
self._original = original_beta_chat

def __getattr__(self, name):
"""Fallback to original beta chat object for any methods we don't explicitly handle."""
warn_on_fallback(self.__class__.__name__, name)
return getattr(self._original, name)

@property
def completions(self):
"""Access beta chat completions with PostHog usage tracking."""
return WrappedBetaCompletions(self._client, self._original.completions)


class WrappedBetaCompletions:
class WrappedBetaCompletions(OpenAIWrapperResource):
"""Wrapper for OpenAI beta chat completions that tracks usage in PostHog."""

def __init__(self, client: OpenAI, original_beta_completions):
self._client = client
self._original = original_beta_completions

def __getattr__(self, name):
"""Fallback to original beta completions object for any methods we don't explicitly handle."""
warn_on_fallback(self.__class__.__name__, name)
return getattr(self._original, name)

def parse(
self,
posthog_distinct_id: Optional[str] = None,
Expand Down
Loading
Loading