Skip to content

Commit

Permalink
more fixes to utility imports (truera#786)
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrm0 authored Jan 9, 2024
1 parent 724eee5 commit e36b997
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 19 deletions.
16 changes: 13 additions & 3 deletions trulens_eval/trulens_eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,9 @@

__version__ = "0.20.2"

from trulens_eval.feedback import Bedrock
from trulens_eval.feedback import Feedback
from trulens_eval.feedback import Huggingface
from trulens_eval.feedback import Langchain
from trulens_eval.feedback import LiteLLM
from trulens_eval.feedback import OpenAI
from trulens_eval.feedback.provider import Provider
from trulens_eval.schema import FeedbackMode
from trulens_eval.schema import Select
Expand All @@ -94,12 +91,25 @@
from trulens_eval.tru_chain import TruChain
from trulens_eval.tru_custom_app import TruCustomApp
from trulens_eval.utils.imports import OptionalImports
from trulens_eval.utils.imports import REQUIREMENT_BEDROCK
from trulens_eval.utils.imports import REQUIREMENT_LITELLM
from trulens_eval.utils.imports import REQUIREMENT_LLAMA
from trulens_eval.utils.imports import REQUIREMENT_OPENAI
from trulens_eval.utils.threading import TP

with OptionalImports(messages=REQUIREMENT_BEDROCK):
from trulens_eval.feedback import Bedrock

with OptionalImports(messages=REQUIREMENT_LLAMA):
from trulens_eval.tru_llama import TruLlama

with OptionalImports(messages=REQUIREMENT_LITELLM):
from trulens_eval.feedback import LiteLLM

with OptionalImports(messages=REQUIREMENT_OPENAI):
from trulens_eval.feedback import OpenAI


__all__ = [
"Tru",
"TruBasicApp",
Expand Down
20 changes: 16 additions & 4 deletions trulens_eval/trulens_eval/feedback/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import logging
from typing import Any, Callable, Dict, Iterable, Tuple, Union

from trulens_eval.utils.imports import OptionalImports
from trulens_eval.utils.imports import REQUIREMENT_BEDROCK
from trulens_eval.utils.imports import REQUIREMENT_LITELLM
from trulens_eval.utils.imports import REQUIREMENT_OPENAI

logger = logging.getLogger(__name__)

# Signature of feedback implementations. Take in any number of arguments
Expand All @@ -16,13 +21,20 @@
from trulens_eval.feedback.feedback import Feedback
from trulens_eval.feedback.groundedness import Groundedness
from trulens_eval.feedback.groundtruth import GroundTruthAgreement
from trulens_eval.feedback.provider.bedrock import Bedrock
# Providers of feedback functions evaluation:
from trulens_eval.feedback.provider.hugs import Huggingface
from trulens_eval.feedback.provider.langchain import Langchain
from trulens_eval.feedback.provider.litellm import LiteLLM
from trulens_eval.feedback.provider.openai import AzureOpenAI
from trulens_eval.feedback.provider.openai import OpenAI

with OptionalImports(messages=REQUIREMENT_BEDROCK):
from trulens_eval.feedback.provider.bedrock import Bedrock

with OptionalImports(messages=REQUIREMENT_LITELLM):
from trulens_eval.feedback.provider.litellm import LiteLLM

with OptionalImports(messages=REQUIREMENT_OPENAI):
from trulens_eval.feedback.provider.openai import AzureOpenAI
from trulens_eval.feedback.provider.openai import OpenAI


__all__ = [
"Feedback",
Expand Down
19 changes: 15 additions & 4 deletions trulens_eval/trulens_eval/feedback/groundedness.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,26 @@

from trulens_eval.feedback import prompts
from trulens_eval.feedback.provider.base import Provider
from trulens_eval.feedback.provider.bedrock import Bedrock
from trulens_eval.feedback.provider.hugs import Huggingface
from trulens_eval.feedback.provider.litellm import LiteLLM
from trulens_eval.feedback.provider.openai import AzureOpenAI
from trulens_eval.feedback.provider.openai import OpenAI
from trulens_eval.utils.generated import re_0_10_rating
from trulens_eval.utils.imports import OptionalImports
from trulens_eval.utils.imports import REQUIREMENT_BEDROCK
from trulens_eval.utils.imports import REQUIREMENT_LITELLM
from trulens_eval.utils.imports import REQUIREMENT_OPENAI
from trulens_eval.utils.pyschema import WithClassInfo
from trulens_eval.utils.serial import SerialModel

with OptionalImports(messages=REQUIREMENT_BEDROCK):
from trulens_eval.feedback.provider.bedrock import Bedrock

with OptionalImports(messages=REQUIREMENT_OPENAI):
from trulens_eval.feedback.provider.openai import AzureOpenAI
from trulens_eval.feedback.provider.openai import OpenAI

with OptionalImports(messages=REQUIREMENT_LITELLM):
from trulens_eval.feedback.provider.litellm import LiteLLM


logger = logging.getLogger(__name__)


Expand Down
30 changes: 22 additions & 8 deletions trulens_eval/trulens_eval/feedback/provider/endpoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,17 @@ def track_all_costs(
# the produced endpoints might be ones that were constructed earlier.
for endpoint in Endpoint.ENDPOINT_SETUPS:
if locals().get(endpoint.arg_flag):
mod = __import__(
endpoint.module_name, fromlist=[endpoint.class_name]
)
cls = safe_getattr(mod, endpoint.class_name)
try:
mod = __import__(
endpoint.module_name, fromlist=[endpoint.class_name]
)
cls = safe_getattr(mod, endpoint.class_name)
except Exception:
# If endpoint uses optional packages, either module not
# found error, or we will have a dummy which will fail at
# getattr.
continue

try:
e = cls()
endpoints.append(e)
Expand Down Expand Up @@ -431,10 +438,17 @@ async def atrack_all_costs(

for endpoint in Endpoint.ENDPOINT_SETUPS:
if locals().get(endpoint.arg_flag):
mod = __import__(
endpoint.module_name, fromlist=[endpoint.class_name]
)
cls = safe_getattr(mod, endpoint.class_name)
try:
mod = __import__(
endpoint.module_name, fromlist=[endpoint.class_name]
)
cls = safe_getattr(mod, endpoint.class_name)
except Exception:
# If endpoint uses optional packages, either module not
# found error, or we will have a dummy which will fail at
# getattr.
continue

try:
e = cls()
endpoints.append(e)
Expand Down

0 comments on commit e36b997

Please sign in to comment.