community[patch]: add detailed paragraph and example for BaichuanTextEmbeddings #22031

Merged
merged 5 commits into from
Jun 5, 2024
Changes from 4 commits
28 changes: 21 additions & 7 deletions libs/community/langchain_community/embeddings/baichuan.py
@@ -4,6 +4,7 @@
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from requests import RequestException

 BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings"

@@ -22,11 +23,23 @@
 # NOTE!! BaichuanTextEmbeddings only supports Chinese text embedding.
 # Multi-language support is coming soon.
 class BaichuanTextEmbeddings(BaseModel, Embeddings):
-    """Baichuan Text Embedding models."""
+    """Baichuan Text Embedding models.
+
+    To use, you should set the environment variable ``BAICHUAN_API_KEY`` to
+    your API key or pass it as a named parameter to the constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.embeddings import BaichuanTextEmbeddings
+
+            baichuan = BaichuanTextEmbeddings(baichuan_api_key="my-api-key")
+    """

     session: Any  #: :meta private:
     model_name: str = "Baichuan-Text-Embedding"
     baichuan_api_key: Optional[SecretStr] = None
+    """Automatically inferred from env var `BAICHUAN_API_KEY` if not provided."""

     @root_validator(allow_reuse=True)
     def validate_environment(cls, values: Dict) -> Dict:
@@ -69,6 +82,8 @@ def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
             response = self.session.post(
                 BAICHUAN_API_URL, json={"input": texts, "model": self.model_name}
             )
+            # Raise exception if response status code from 400 to 600
+            response.raise_for_status()
             # Check if the response status code indicates success
             if response.status_code == 200:
                 resp = response.json()
@@ -79,15 +94,14 @@ def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
                 return [result.get("embedding", []) for result in sorted_embeddings]
             else:
                 # Log error or handle unsuccessful response appropriately
-                print(  # noqa: T201
+                # Handle 100 <= status_code < 400, not include 200
+                raise RequestException(
Collaborator


This will also raise an exception for every 2xx code other than 200, which isn't correct behavior.

Why not remove this exception entirely since response.raise_for_status() already raises?

Contributor Author


The method response.raise_for_status() only handles status codes from 400 to 599:

# requests -> models.py -> Response

def raise_for_status(self):
    """Raises :class:`HTTPError`, if one occurred."""

    http_error_msg = ""
    if isinstance(self.reason, bytes):
        # We attempt to decode utf-8 first because some servers
        # choose to localize their reason strings. If the string
        # isn't utf-8, we fall back to iso-8859-1 for all other
        # encodings. (See PR #3538)
        try:
            reason = self.reason.decode("utf-8")
        except UnicodeDecodeError:
            reason = self.reason.decode("iso-8859-1")
    else:
        reason = self.reason

    if 400 <= self.status_code < 500:
        http_error_msg = (
            f"{self.status_code} Client Error: {reason} for url: {self.url}"
        )

    elif 500 <= self.status_code < 600:
        http_error_msg = (
            f"{self.status_code} Server Error: {reason} for url: {self.url}"
        )

    if http_error_msg:
        raise HTTPError(http_error_msg, response=self)
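
The disagreement above comes down to which status codes reach the else branch once raise_for_status() has run. The following is a minimal sketch of that control flow, assuming the checks happen in the order shown in the diff. The helper name branch_for_status is hypothetical and exists only for illustration; it is not part of the PR or of requests.

```python
def branch_for_status(status_code: int) -> str:
    """Name the branch that handles a status code in the diff's control flow.

    Hypothetical helper for illustration; it is not part of the PR.
    """
    # Step 1: raise_for_status() raises HTTPError for 4xx and 5xx only.
    if 400 <= status_code < 600:
        return "raise_for_status: HTTPError"
    # Step 2: only an exact 200 is parsed as a successful response.
    if status_code == 200:
        return "parse embeddings"
    # Step 3: everything else (1xx, other 2xx, 3xx) reaches the else branch.
    return "else: RequestException"


if __name__ == "__main__":
    for code in (200, 201, 204, 302, 404, 503):
        print(code, "->", branch_for_status(code))
```

Under these assumptions both points in the thread hold: raise_for_status() does leave codes below 400 unhandled, and the else branch does treat successful codes such as 201 or 204 as errors.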

Contributor Author


@eyurtsev Can you give me some suggestions for improvement?

Collaborator


I'm merging this version. If you want, feel free to follow up with a PR. The conditional if response.status_code == 200 can be changed to accept any 2xx code. Not super critical here, but the code looks a bit odd when an error is raised for a 2xx code that isn't 200.
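
One possible shape for such a follow-up is sketched below. It is not the merged code. The function names is_success and extract_embeddings are hypothetical, the payload keys "data" and "index" are assumed from the part of _embed that the diff collapses, and RuntimeError stands in for RequestException so the sketch runs without requests installed.

```python
from typing import Any, Dict, List


def is_success(status_code: int) -> bool:
    """Return True for any 2xx status code, not only 200."""
    return 200 <= status_code < 300


def extract_embeddings(status_code: int, payload: Dict[str, Any]) -> List[List[float]]:
    """Return embeddings ordered by index, or raise on a non-2xx status.

    Hypothetical helper; the "data" and "index" keys are assumptions.
    """
    if not is_success(status_code):
        # The PR raises requests.RequestException at this point.
        raise RuntimeError(
            f"Error: Received status code {status_code} from "
            "`BaichuanEmbedding` API"
        )
    items = payload.get("data", [])
    # Sort by the index reported for each item so output order matches input order.
    sorted_items = sorted(items, key=lambda item: item.get("index", 0))
    return [item.get("embedding", []) for item in sorted_items]


if __name__ == "__main__":
    sample = {
        "data": [
            {"index": 1, "embedding": [0.3]},
            {"index": 0, "embedding": [0.1, 0.2]},
        ]
    }
    print(extract_embeddings(201, sample))
```

With a range check in place of the equality test, only 1xx and 3xx codes fall through to the error path, which is closer to what the reviewer describes.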

                     f"Error: Received status code {response.status_code} from "
-                    "embedding API"
+                    "`BaichuanEmbedding` API"
                 )
-                return None
-        except Exception as e:
+        except Exception:
maang-h marked this conversation as resolved.
             # Log the exception or handle it as needed
-            print(f"Exception occurred while trying to get embeddings: {str(e)}")  # noqa: T201
-            return None
+            raise

     def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:  # type: ignore[override]
         """Public method to get embeddings for a list of documents.