
Modernize type hints and clean up unused code #1928


Closed
wants to merge 9 commits into from
3 changes: 2 additions & 1 deletion graphrag/config/enums.py
@@ -93,10 +93,12 @@ class ModelType(str, Enum):
# Embeddings
OpenAIEmbedding = "openai_embedding"
AzureOpenAIEmbedding = "azure_openai_embedding"
GeminiEmbedding = "gemini_embedding" # New entry for Gemini Embedding

# Chat Completion
OpenAIChat = "openai_chat"
AzureOpenAIChat = "azure_openai_chat"
GeminiChat = "gemini_chat" # New entry for Gemini Chat Completion

# Debug
MockChat = "mock_chat"
@@ -106,7 +108,6 @@ def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'


class AuthType(str, Enum):
"""AuthType enum class definition."""

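For context, ModelType subclasses str, so raw strings from a settings file coerce directly to the new members. A minimal sketch of that behavior (standard Python enum semantics, assuming the patched enums.py above):

from graphrag.config.enums import ModelType

# str-based enum members compare equal to their raw values,
# so "gemini_chat" read from a config file resolves to the new member.
assert ModelType("gemini_chat") is ModelType.GeminiChat
assert ModelType.GeminiEmbedding == "gemini_embedding"
print(repr(ModelType.GeminiChat))  # "gemini_chat", via the custom __repr__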
121 changes: 62 additions & 59 deletions graphrag/config/models/language_model_config.py
@@ -63,23 +63,24 @@ def _validate_api_key(self) -> None:
)

def _validate_auth_type(self) -> None:
"""Validate the authentication type.

auth_type must be api_key when using OpenAI and
can be either api_key or azure_managed_identity when using AOI.

Raises
------
ConflictingSettingsError
If the Azure authentication type conflicts with the model being used.
"""
if self.auth_type == AuthType.AzureManagedIdentity and (
self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
):
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
raise ConflictingSettingsError(msg)

type: ModelType | str = Field(description="The type of LLM model to use.")
"""Validate the authentication type.

auth_type must be api_key when using OpenAI and Gemini,
and can be either api_key or azure_managed_identity when using AOI.

Raises
------
ConflictingSettingsError
If the authentication type conflicts with the model being used.
"""
if self.auth_type == AuthType.AzureManagedIdentity and (
self.type in [ModelType.OpenAIChat, ModelType.OpenAIEmbedding, ModelType.GeminiChat, ModelType.GeminiEmbedding]
):
msg = (
f"auth_type of azure_managed_identity is not supported for model type {self.type}. "
"Please rerun `graphrag init` and set the auth_type to api_key."
)
raise ConflictingSettingsError(msg)

def _validate_type(self) -> None:
"""Validate the model type.
@@ -117,46 +118,44 @@ def _validate_encoding_model(self) -> None:
)

def _validate_api_base(self) -> None:
"""Validate the API base.

Required when using AOI.

Raises
------
AzureApiBaseMissingError
If the API base is missing and is required.
"""
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
) and (self.api_base is None or self.api_base.strip() == ""):
raise AzureApiBaseMissingError(self.type)

api_version: str | None = Field(
description="The version of the LLM API to use.",
default=language_model_defaults.api_version,
)
"""Validate the API base.

Required when using AOI or Gemini.

Raises
------
AzureApiBaseMissingError
If the API base is missing and is required.
"""
if (
self.type in [
ModelType.AzureOpenAIChat,
ModelType.AzureOpenAIEmbedding,
ModelType.GeminiChat,
ModelType.GeminiEmbedding, # Added Gemini models
]
) and (self.api_base is None or self.api_base.strip() == ""):
raise AzureApiBaseMissingError(self.type)

def _validate_api_version(self) -> None:
"""Validate the API version.

Required when using AOI.

Raises
------
AzureApiBaseMissingError
If the API base is missing and is required.
"""
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
) and (self.api_version is None or self.api_version.strip() == ""):
raise AzureApiVersionMissingError(self.type)

deployment_name: str | None = Field(
description="The deployment name to use for the LLM service.",
default=language_model_defaults.deployment_name,
)
"""Validate the API version.

Required when using AOI or Gemini.

Raises
------
AzureApiVersionMissingError
If the API version is missing and is required.
"""
if (
self.type in [
ModelType.AzureOpenAIChat,
ModelType.AzureOpenAIEmbedding,
ModelType.GeminiChat, # Added GeminiChat
ModelType.GeminiEmbedding, # Added GeminiEmbedding
]
) and (self.api_version is None or self.api_version.strip() == ""):
raise AzureApiVersionMissingError(self.type)

def _validate_deployment_name(self) -> None:
"""Validate the deployment name.
@@ -169,10 +168,14 @@ def _validate_deployment_name(self) -> None:
If the deployment name is missing and is required.
"""
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
raise AzureDeploymentNameMissingError(self.type)
self.type in [
ModelType.AzureOpenAIChat,
ModelType.AzureOpenAIEmbedding,
ModelType.GeminiChat, # Added support for GeminiChat
ModelType.GeminiEmbedding, # Added support for GeminiEmbedding
]
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
raise AzureDeploymentNameMissingError(self.type)

organization: str | None = Field(
description="The organization to use for the LLM service.",
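Taken together, the patched validators make Gemini subject to the same required fields as Azure OpenAI. A minimal sketch of a config shaped to satisfy the four validators shown in this diff (field names follow LanguageModelConfig; the model name and endpoint are illustrative assumptions, and validators not shown here may impose further requirements):

from graphrag.config.enums import AuthType, ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig

config = LanguageModelConfig(
    type=ModelType.GeminiChat,
    model="gemini-1.5-pro",  # hypothetical model name
    auth_type=AuthType.APIKey,  # azure_managed_identity raises ConflictingSettingsError
    api_key="<GEMINI_API_KEY>",
    api_base="https://generativelanguage.googleapis.com",  # assumed endpoint
    api_version="v1",  # now required for Gemini by the patched validator
    deployment_name="gemini-chat",  # likewise required, though Azure-specific
)

Note that a missing field still surfaces as AzureApiBaseMissingError, AzureApiVersionMissingError, or AzureDeploymentNameMissingError even though Gemini is not an Azure service.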
10 changes: 9 additions & 1 deletion graphrag/language_model/factory.py
@@ -13,6 +13,8 @@
AzureOpenAIEmbeddingFNLLM,
OpenAIChatFNLLM,
OpenAIEmbeddingFNLLM,
GeminiChatFNLLM, # Import Gemini Chat model
GeminiEmbeddingFNLLM, # Import Gemini Embedding model
)


@@ -48,7 +50,7 @@ def create_chat_model(cls, model_type: str, **kwargs: Any) -> ChatModel:
A ChatModel instance.
"""
if model_type not in cls._chat_registry:
msg = f"ChatMOdel implementation '{model_type}' is not registered."
msg = f"ChatModel implementation '{model_type}' is not registered."
raise ValueError(msg)
return cls._chat_registry[model_type](**kwargs)

@@ -105,10 +107,16 @@ def is_supported_model(cls, model_type: str) -> bool:
ModelFactory.register_chat(
ModelType.OpenAIChat, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
)
ModelFactory.register_chat(
ModelType.GeminiChat, lambda **kwargs: GeminiChatFNLLM(**kwargs) # Register Gemini Chat
)

ModelFactory.register_embedding(
ModelType.AzureOpenAIEmbedding, lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs)
)
ModelFactory.register_embedding(
ModelType.OpenAIEmbedding, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
)
ModelFactory.register_embedding(
ModelType.GeminiEmbedding, lambda **kwargs: GeminiEmbeddingFNLLM(**kwargs) # Register Gemini Embedding
)
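With these registrations in place, callers resolve Gemini models through the factory exactly as they do the existing providers. A usage sketch (kwargs are forwarded untouched, so they must match GeminiChatFNLLM's constructor; config is the LanguageModelConfig sketched earlier):

from graphrag.config.enums import ModelType
from graphrag.language_model.factory import ModelFactory

# ModelType subclasses str, so the enum member doubles as the registry key
# expected by the str-typed model_type parameter.
chat = ModelFactory.create_chat_model(
    ModelType.GeminiChat,
    name="default_chat",  # forwarded to GeminiChatFNLLM.__init__
    config=config,
)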
79 changes: 79 additions & 0 deletions graphrag/language_model/providers/fnllm/models.py
@@ -12,6 +12,11 @@
create_openai_client,
create_openai_embeddings_llm,
)
from fnllm.gemini import (
    create_gemini_chat_llm,
    create_gemini_embeddings_llm,
    create_gemini_client,
)

Collaborator (review comment on the import above):

Is this intended to be included in your PR?
Gemini support is not available in fnllm, this will just break
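If Gemini support ever ships in fnllm as an optional extra, one defensive pattern would be to guard the import so the remaining providers still load. A sketch only, not part of this PR:

try:
    from fnllm.gemini import (  # not available in current fnllm releases
        create_gemini_chat_llm,
        create_gemini_client,
        create_gemini_embeddings_llm,
    )
except ImportError:  # degrade gracefully instead of breaking every import
    create_gemini_chat_llm = None
    create_gemini_client = None
    create_gemini_embeddings_llm = None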

from graphrag.language_model.providers.fnllm.events import FNLLMEvents
from graphrag.language_model.providers.fnllm.utils import (
@@ -441,3 +446,77 @@ def embed(self, text: str, **kwargs) -> list[float]:
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed(text, **kwargs))

class GeminiChatFNLLM:
    """A Gemini Chat Model provider using the fnllm library."""

    model: FNLLMChatLLM

    def __init__(
        self,
        *,
        name: str,
        config: LanguageModelConfig,
        callbacks: WorkflowCallbacks | None = None,
        cache: PipelineCache | None = None,
    ) -> None:
        model_config = _create_openai_config(config, azure=False)
        error_handler = _create_error_handler(callbacks) if callbacks else None
        model_cache = _create_cache(cache, name)
        client = create_gemini_client(model_config)
        self.model = create_gemini_chat_llm(
            model_config,
            client=client,
            cache=model_cache,
            events=FNLLMEvents(error_handler) if error_handler else None,
        )
        self.config = config

    async def achat(
        self, prompt: str, history: list | None = None, **kwargs
    ) -> ModelResponse:
        """Chat with the Gemini model and return a standardized response."""
        if history is None:
            response = await self.model(prompt, **kwargs)
        else:
            response = await self.model(prompt, history=history, **kwargs)
        return BaseModelResponse(
            output=BaseModelOutput(content=response.output.content),
            parsed_response=response.parsed_json,
            history=response.history,
            cache_hit=response.cache_hit,
            tool_calls=response.tool_calls,
            metrics=response.metrics,
        )

class GeminiEmbeddingFNLLM:
    """A Gemini Embedding Model provider using the fnllm library."""

    model: FNLLMEmbeddingLLM

    def __init__(
        self,
        *,
        name: str,
        config: LanguageModelConfig,
        callbacks: WorkflowCallbacks | None = None,
        cache: PipelineCache | None = None,
    ) -> None:
        model_config = _create_openai_config(config, azure=False)
        error_handler = _create_error_handler(callbacks) if callbacks else None
        model_cache = _create_cache(cache, name)
        client = create_gemini_client(model_config)
        self.model = create_gemini_embeddings_llm(
            model_config,
            client=client,
            cache=model_cache,
            events=FNLLMEvents(error_handler) if error_handler else None,
        )
        self.config = config

    async def aembed(self, text: str, **kwargs) -> list[float]:
        """Embed a single text and return its embedding vector."""
        response = await self.model([text], **kwargs)
        if response.output.embeddings is None:
            msg = "No embeddings found in response"
            raise ValueError(msg)
        embeddings: list[float] = response.output.embeddings[0]
        return embeddings
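Assuming the fnllm pieces existed, the new providers would be driven through their async methods. Note that, unlike the OpenAI classes above, they define no sync chat/embed wrappers via run_coroutine_sync. A usage sketch, reusing the config object from earlier:

import asyncio

async def main() -> None:
    chat = GeminiChatFNLLM(name="default_chat", config=config)
    response = await chat.achat("Summarize the indexed corpus.")
    print(response.output.content)

    embedder = GeminiEmbeddingFNLLM(name="default_embedding", config=config)
    vector = await embedder.aembed("hello world")
    print(len(vector))

asyncio.run(main())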
22 changes: 18 additions & 4 deletions graphrag/logger/base.py
@@ -4,7 +4,7 @@
"""Base classes for logging and progress reporting."""

from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict, Optional

from graphrag.logger.progress import Progress

@@ -13,15 +13,15 @@ class StatusLogger(ABC):
"""Provides a way to log status updates from the pipeline."""

@abstractmethod
def error(self, message: str, details: dict[str, Any] | None = None):
def error(self, message: str, details: Optional[Dict[str, Any]] = None):
"""Log an error."""

@abstractmethod
def warning(self, message: str, details: dict[str, Any] | None = None):
def warning(self, message: str, details: Optional[Dict[str, Any]] = None):
"""Log a warning."""

@abstractmethod
def log(self, message: str, details: dict[str, Any] | None = None):
def log(self, message: str, details: Optional[Dict[str, Any]] = None):
"""Report a log."""


@@ -67,3 +67,17 @@ def info(self, message: str) -> None:
@abstractmethod
def success(self, message: str) -> None:
"""Log success."""


# Optional: default implementation of StatusLogger for basic logging functionality
class DefaultLogger(StatusLogger):
    """Default implementation of StatusLogger that logs to standard output."""

    def error(self, message: str, details: Optional[Dict[str, Any]] = None):
        """Log an error to standard output."""
        print(f"ERROR: {message}", details if details else "")

    def warning(self, message: str, details: Optional[Dict[str, Any]] = None):
        """Log a warning to standard output."""
        print(f"WARNING: {message}", details if details else "")

    def log(self, message: str, details: Optional[Dict[str, Any]] = None):
        """Log a message to standard output."""
        print(f"LOG: {message}", details if details else "")