Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from abc import ABC, abstractmethod
from typing import Any

from langfuse.callback.langchain import LangchainCallbackHandler
from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
from pydantic import BaseModel

from gen_ai_orchestrator.models.observability.observability_setting import (
Expand All @@ -34,7 +34,7 @@ class LangChainCallbackHandlerFactory(ABC, BaseModel):
setting: BaseObservabilitySetting

@abstractmethod
def get_callback_handler(self, **kwargs: Any) -> LangchainCallbackHandler:
def get_callback_handler(self, **kwargs: Any) -> LangfuseCallbackHandler:
"""
Fabric a callback handler.
:return: LangchainCallbackHandler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from httpx_auth_awssigv4 import SigV4Auth
from langfuse import Langfuse
from langfuse.api.core import ApiError
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
from pydantic import PrivateAttr

from gen_ai_orchestrator.configurations.environment.settings import (
application_settings,
Expand Down Expand Up @@ -59,30 +60,81 @@ class LangfuseCallbackHandlerFactory(LangChainCallbackHandlerFactory):

setting: ObservabilitySetting

# Internal client cache
_langfuse_client: Optional[Langfuse] = PrivateAttr(default=None)

def _get_langfuse_client(self) -> Langfuse:
"""
Create or return the initialized Langfuse client
"""
if self._langfuse_client is None:
settings = self._fetch_settings()
self._langfuse_client = Langfuse(
public_key=settings['public_key'],
secret_key=settings['secret_key'],
base_url=settings['base_url'],
timeout=settings['timeout'],
httpx_client=self._get_httpx_client(),
)
return self._langfuse_client

def get_callback_handler(self, **kwargs: Any) -> LangfuseCallbackHandler:
return LangfuseCallbackHandler(**self._fetch_settings(), httpx_client=self._get_httpx_client(), **kwargs)
"""
Create Langfuse CallbackHandler
"""
self._get_langfuse_client()

# Ignore Langfuse V2 parameters to stay backward-compatible
if kwargs:
logger.debug(
'Ignoring unsupported Langfuse CallbackHandler kwargs in V3: %s',
list(kwargs.keys()),
)

return LangfuseCallbackHandler(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return LangfuseCallbackHandler(
# Langfuse SDK maintains an internal map / pool of clients based on there public key, that why the client isn't passed to the callbackhandler constructor.
return LangfuseCallbackHandler(

public_key=self.setting.public_key,
)

def check_observability_setting(self) -> bool:
"""Check if the provided credentials (public and secret key) are valid,
while tracing a sample phrase"""
try:
self.get_callback_handler().auth_check()
Langfuse(**self._fetch_settings(), httpx_client=self._get_httpx_client()).trace(
name=ObservabilityTrace.CHECK_OBSERVABILITY_SETTINGS.value, output='Check observability setting trace')
client = self._get_langfuse_client()
logger.debug('Lang')

if not client.auth_check():
logger.error('Langfuse auth_check() returned False')
raise GenAIObservabilityErrorException(
'Langfuse authentication check failed'
)

with client.start_as_current_observation(
as_type='span',
name=ObservabilityTrace.CHECK_OBSERVABILITY_SETTINGS.value,
input={'message': 'Check observability setting'},
) as span:
span.update(output='Check observability setting trace')

client.flush()

except ApiError as exc:
logger.error(exc)
raise GenAIObservabilityErrorException(
create_error_info_langfuse(exc)
)
return True

def _fetch_settings(self):
def _fetch_settings(self) -> dict:
"""
Fetch necessary parameters to initialise Langfuse client.
"""
return {
'host': str(self.setting.url),
'base_url': str(self.setting.url),
'public_key': self.setting.public_key,
'secret_key': fetch_secret_key_value(self.setting.secret_key),
'timeout': application_settings.observability_provider_timeout,
'max_retries': application_settings.observability_provider_max_retries
# kept for backward-compatibility, not used anymore
'max_retries': application_settings.observability_provider_max_retries,
}

def _get_httpx_client(self) -> Optional[Client]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import Any, Optional

from langchain_core.embeddings import Embeddings
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler

from gen_ai_orchestrator.configurations.environment.settings import (
application_settings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
RunnableSerializable,
)
from langchain_core.vectorstores import VectorStoreRetriever
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
from langfuse import get_client, propagate_attributes
from typing_extensions import Any

from gen_ai_orchestrator.errors.exceptions.exceptions import (
Expand Down Expand Up @@ -131,7 +131,7 @@ async def execute_rag_chain(
message_history.add_ai_message(msg.text)
session_id = (request.dialog.dialog_id,)
user_id = (request.dialog.user_id,)
tags = (request.dialog.tags,)
tags = (request.dialog.tags,) or []

logger.debug(
'RAG chain - Use chat history: %s',
Expand Down Expand Up @@ -160,16 +160,23 @@ async def execute_rag_chain(
# Langfuse callback handler
observability_handler = create_observability_callback_handler(
observability_setting=request.observability_setting,
trace_name=ObservabilityTrace.RAG.value,
session_id=session_id,
user_id=user_id,
tags=tags,
)
callback_handlers.append(observability_handler)

metadata = {}
if user_id is not None:
metadata['langfuse_user_id'] = str(user_id)
if session_id is not None:
metadata['langfuse_session_id'] = str(session_id)
if tags:
metadata['langfuse_tags'] = list(tags)

response = await conversational_retrieval_chain.ainvoke(
input=inputs,
config={'callbacks': callback_handlers},
config={
'callbacks': callback_handlers,
'metadata': metadata,
},
)

# RAG Guard
Expand Down Expand Up @@ -204,7 +211,10 @@ async def execute_rag_chain(
)
),
),
observability_info=get_observability_info(observability_handler),
observability_info=get_observability_info(
observability_handler,
ObservabilityTrace.RAG.value if observability_handler is not None else None,
),
debug=get_rag_debug_data(request, records_callback_handler, rag_duration)
if debug
else None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import logging
from typing import Optional

from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
from langfuse import get_client
from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler

from gen_ai_orchestrator.models.observability.observability_type import (
ObservabilitySetting,
Expand Down Expand Up @@ -45,13 +46,20 @@ def check_observability_setting(setting: ObservabilitySetting) -> bool:
return get_callback_handler_factory(setting).check_observability_setting()


def get_observability_info(observability_handler) -> Optional[ObservabilityInfo]:
def get_observability_info(observability_handler, trace_name: Optional[str] = None) -> Optional[ObservabilityInfo]:
"""Get the observability Information"""
if isinstance(observability_handler, LangfuseCallbackHandler):
return ObservabilityInfo(
trace_id=observability_handler.trace.id,
trace_name=observability_handler.trace_name,
trace_url=observability_handler.get_trace_url()
)
else:
return None
if not isinstance(observability_handler, LangfuseCallbackHandler):
return None

trace_id = getattr(observability_handler, 'last_trace_id', None)
if trace_id is None:
return None

langfuse_client = get_client()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure you are getting the right client here ?
You should be using the client public key to be sure at least.
You can access the client using observability_handler.client I think it's the easiest way.

This could lead to rights access errors.

trace_url = langfuse_client.get_trace_url(trace_id=trace_id)

return ObservabilityInfo(
trace_id=trace_id,
trace_name=trace_name,
trace_url=trace_url,
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
#
import os
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import ANY, AsyncMock, MagicMock, patch

import pytest
from langchain_core.documents import Document
Expand Down Expand Up @@ -184,7 +184,10 @@ async def test_rag_chain(
# Assert qa chain is ainvoke()d with the expected settings from request
mocked_chain.ainvoke.assert_called_once_with(
input=inputs,
config={'callbacks': [mocked_callback, mocked_langfuse_callback]},
config={
'callbacks': [mocked_callback, mocked_langfuse_callback],
'metadata': ANY,
},
)
# Assert the response is build using the expected settings
mocked_rag_response.assert_called_once_with(
Expand Down Expand Up @@ -430,7 +433,7 @@ def test_rag_guard_accepts_no_answer_even_with_docs(mocked_log):
'documents': ['a doc as a string'],
}
rag_chain.rag_guard(inputs, response, documents_required=True)
assert response['documents'] == ['a doc as a string']
assert response['documents'] == []


@patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log')
Expand Down