style: apply ALL ruff rules #158

Open · wants to merge 3 commits into main
11 changes: 9 additions & 2 deletions pyproject.toml
@@ -142,15 +142,22 @@ target-version = "py310"

[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true

[tool.ruff.lint]
select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"]
ignore = ["D203", "D213", "E501", "RET504", "RUF002", "RUF022", "S101", "S307", "TC004"]
select = ["ALL"]
ignore = ["CPY", "FIX", "ARG001", "COM812", "D203", "D213", "E501", "PD008", "PD009", "RET504", "S101", "TD003"]
unfixable = ["ERA001", "F401", "F841", "T201", "T203"]

[tool.ruff.lint.flake8-annotations]
allow-star-arg-any = true

[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.lint.isort]
split-on-trailing-comma = false

[tool.ruff.lint.pycodestyle]
max-doc-length = 100
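
For context on what the broadened configuration enforces: `select = ["ALL"]` turns on every Ruff rule family and then opts out of the ones listed in `ignore`, while `unfixable` keeps `ruff check --fix` from auto-deleting commented-out code (ERA001), unused imports and variables (F401/F841), and print statements (T201/T203), so those stay visible for review. A minimal, hypothetical sketch of code written to satisfy a few of the newly selected families (assuming standard Ruff rule semantics; not taken from this repository):

```python
"""Illustrative module: passes several of the newly enabled rule families."""

from pathlib import Path  # PTH rules prefer pathlib over os.path / open().


def read_pairs(path: Path, *, strict: bool = True) -> dict[str, str]:
    """Read key=value pairs from a file.

    ANN rules require full annotations, D rules require this docstring, and FBT rules
    are satisfied because the boolean flag is keyword-only.
    """
    pairs: dict[str, str] = {}
    for line in path.read_text().splitlines():
        if "=" in line:
            key, _, value = line.partition("=")
            pairs[key.strip()] = value.strip()
        elif strict:
            msg = f"Malformed line: {line!r}"  # EM rules: bind the message before raising.
            raise ValueError(msg)
    return pairs
```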

2 changes: 1 addition & 1 deletion src/raglite/__init__.py
@@ -17,7 +17,7 @@
vector_search,
)

__all__ = [
__all__ = [ # noqa: RUF022
# Config
"RAGLiteConfig",
# Insert
12 changes: 6 additions & 6 deletions src/raglite/_bench.py
@@ -94,7 +94,7 @@ def __init__(
insert_variant: str | None = None,
search_variant: str | None = None,
config: RAGLiteConfig | None = None,
):
) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -145,7 +145,7 @@ def __init__(
num_results: int = 10,
insert_variant: str | None = None,
search_variant: str | None = None,
):
) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -156,7 +156,7 @@
self.embedder_dim = 3072
self.persist_path = self.cwd / self.insert_id

def insert_documents(self, max_workers: int | None = None) -> None:
def insert_documents(self, max_workers: int | None = None) -> None: # noqa: ARG002
# Adapted from https://docs.llamaindex.ai/en/stable/examples/vector_stores/FaissIndexDemo/.
import faiss
from llama_index.core import Document, StorageContext, VectorStoreIndex
@@ -178,7 +178,7 @@ def insert_documents(self, max_workers: int | None = None) -> None:
index.storage_context.persist(persist_dir=self.persist_path)

@cached_property
def index(self) -> Any:
def index(self) -> Any: # noqa: ANN401
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
@@ -215,7 +215,7 @@ def __init__(
num_results: int = 10,
insert_variant: str | None = None,
search_variant: str | None = None,
):
) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -227,7 +227,7 @@
)

@cached_property
def client(self) -> Any:
def client(self) -> Any: # noqa: ANN401
import openai

return openai.OpenAI()
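
The `# noqa: ANN401` and `# noqa: ARG002` markers in this file suppress two of the newly enabled rules where the benchmark adapters trip them deliberately: `Any` return types for objects that come from optional dependencies, and an unused `max_workers` argument kept for signature compatibility. A small, hypothetical sketch of the same pattern (the class and its members are illustrative, not RAGLite's API):

```python
from functools import cached_property
from typing import Any


class LazyBackendExample:
    """Illustrates the targeted suppressions used by the benchmark adapters."""

    def insert_documents(self, max_workers: int | None = None) -> None:  # noqa: ARG002
        """Ignore `max_workers`; it exists only to match the base class (ARG002)."""

    @cached_property
    def client(self) -> Any:  # noqa: ANN401  # Concrete type lives in an optional dependency.
        # Import lazily so the optional dependency is only needed when the property is used.
        import importlib

        return importlib.import_module("json")  # Stand-in for e.g. `openai.OpenAI()`.
```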
37 changes: 8 additions & 29 deletions src/raglite/_chatml_function_calling.py
@@ -25,24 +25,12 @@

import json
import warnings
from typing import ( # noqa: UP035
Any,
Iterator,
List,
Optional,
Union,
cast,
)
from typing import Any, Iterator, List, Optional, Union, cast # noqa: UP035

import jinja2
from jinja2.sandbox import ImmutableSandboxedEnvironment

from raglite._lazy_llama import (
llama,
llama_chat_format,
llama_grammar,
llama_types,
)
from raglite._lazy_llama import llama, llama_chat_format, llama_grammar, llama_types


def _accumulate_chunks(
@@ -98,7 +86,7 @@ def _convert_chunks_to_completion(
{
"text": text,
"index": 0,
"logprobs": logprobs, # TODO: Improve accumulation of logprobs
"logprobs": logprobs, # TODO(lsorber): Improve accumulation of logprobs
"finish_reason": finish_reason, # type: ignore[typeddict-item]
}
],
@@ -143,12 +131,7 @@ def _stream_tool_calls(
llama_grammar.JSON_GBNF, verbose=llama.verbose
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
**{
**completion_kwargs,
"max_tokens": None,
"grammar": grammar,
},
prompt=prompt, **{**completion_kwargs, "max_tokens": None, "grammar": grammar}
)
chunks: List[llama_types.CreateCompletionResponse] = []
chat_chunks = llama_chat_format._convert_completion_to_chat_function( # noqa: SLF001
@@ -206,11 +189,7 @@ def _convert_text_completion_logprobs_to_chat(
"bytes": None,
"logprob": logprob, # type: ignore[typeddict-item]
"top_logprobs": [
{
"token": top_token,
"logprob": top_logprob,
"bytes": None,
}
{"token": top_token, "logprob": top_logprob, "bytes": None}
for top_token, top_logprob in (top_logprobs or {}).items()
],
}
@@ -318,9 +297,9 @@ def chatml_function_calling_with_streaming(
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
template_renderer = ImmutableSandboxedEnvironment(
undefined=jinja2.StrictUndefined,
).from_string(function_calling_template)
template_renderer = ImmutableSandboxedEnvironment(undefined=jinja2.StrictUndefined).from_string(
function_calling_template
)

# Convert legacy functions to tools
if functions is not None:
2 changes: 1 addition & 1 deletion src/raglite/_cli.py
@@ -88,7 +88,7 @@ def install_mcp_server(
"--python",
"3.11",
"--with",
"numpy<2.0.0", # TODO: Remove this constraint when uv no longer needs it to solve the environment.
"numpy<2.0.0", # TODO(lsorber): Remove this constraint when uv no longer needs it to solve the environment.
"raglite",
"mcp",
"run",
2 changes: 1 addition & 1 deletion src/raglite/_config.py
@@ -24,7 +24,7 @@


# Lazily load the default search method to avoid circular imports.
# TODO: Replace with search_and_rerank_chunk_spans after benchmarking.
# TODO(lsorber): Replace with search_and_rerank_chunk_spans after benchmarking.
def _vector_search(
query: str, *, num_results: int = 8, config: "RAGLiteConfig | None" = None
) -> tuple[list[ChunkId], list[float]]:
4 changes: 2 additions & 2 deletions src/raglite/_embed.py
@@ -78,7 +78,7 @@ def _create_segment(
# Compute the number of tokens per sentence. We use a method based on a sentinel token to
# minimise the number of calls to embedder.tokenize, which incurs a significant overhead
# (presumably to load the tokenizer) [1].
# TODO: Make token counting faster and more robust once [1] is fixed.
# TODO(lsorber): Make token counting faster and more robust once [1] is fixed.
# [1] https://github.com/abetlen/llama-cpp-python/issues/1763
num_tokens_list: list[int] = []
sentence_batch, sentence_batch_len = [], 0
@@ -94,7 +94,7 @@ def _create_segment(
# Compute the maximum number of tokens for each segment's preamble and content.
# Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try
# to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch.
# TODO: Improve the context window size once [1] is fixed.
# TODO(lsorber): Improve the context window size once [1] is fixed.
# [1] https://github.com/abetlen/llama-cpp-python/issues/1762
max_tokens = min(n_ctx, n_batch) - 16
max_tokens_preamble = round(0.382 * max_tokens) # Golden ratio.
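
For the arithmetic behind the golden-ratio comment above: roughly 38.2% of the token budget goes to the preamble (headings) and, presumably, the remaining ~61.8% to the segment content. A quick worked example with made-up model limits:

```python
# Hypothetical values; the real n_ctx and n_batch come from the loaded llama.cpp model.
n_ctx, n_batch = 8192, 2048

max_tokens = min(n_ctx, n_batch) - 16  # 2032 tokens of usable budget.
max_tokens_preamble = round(0.382 * max_tokens)  # 776 tokens for headings (0.382 ≈ 1/φ²).
max_tokens_content = max_tokens - max_tokens_preamble  # 1256 tokens left for the content.
```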
17 changes: 4 additions & 13 deletions src/raglite/_eval.py
@@ -112,11 +112,7 @@ class ContextEvalResponse(BaseModel):

relevant_chunks = []
for candidate_chunk in tqdm(
candidate_chunks,
desc="Evaluating chunks",
unit="chunk",
dynamic_ncols=True,
leave=False,
candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True, leave=False
):
try:
context_eval_response = extract_with_llm(
@@ -139,8 +135,7 @@ class AnswerResponse(BaseModel):
extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode.
)
answer: str = Field(
...,
description="A complete answer to the given question using the provided context.",
..., description="A complete answer to the given question using the provided context."
)
system_prompt: ClassVar[str] = f"""
You are given a set of contexts extracted from a document.
@@ -191,11 +186,7 @@ def insert_evals(
session.execute(text("CHECKPOINT;"))


def answer_evals(
num_evals: int = 100,
*,
config: RAGLiteConfig | None = None,
) -> "pd.DataFrame":
def answer_evals(num_evals: int = 100, *, config: RAGLiteConfig | None = None) -> "pd.DataFrame":
"""Read evals from the database and answer them with RAG."""
try:
import pandas as pd
@@ -251,7 +242,7 @@ def evaluate(
class RAGLiteRagasEmbeddings(BaseRagasEmbeddings):
"""A RAGLite embedder for Ragas."""

def __init__(self, config: RAGLiteConfig | None = None):
def __init__(self, config: RAGLiteConfig | None = None) -> None:
self.config = config or RAGLiteConfig()

def embed_query(self, text: str) -> list[float]:
2 changes: 1 addition & 1 deletion src/raglite/_extract.py
@@ -45,7 +45,7 @@ class MyNameResponse(BaseModel):
# is disabled by default because it only supports a subset of JSON schema features [2].
# [1] https://docs.litellm.ai/docs/completion/json_mode
# [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
# TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
# TODO(lsorber): Fall back to {"type": "json_object"} if JSON schema isn't supported by the LLM.
response_format: dict[str, Any] | None = (
{
"type": "json_schema",
11 changes: 2 additions & 9 deletions src/raglite/_insert.py
@@ -29,9 +29,7 @@ def _create_chunk_records(
chunklets = split_chunklets(sentences, max_size=config.chunk_max_size)
chunklet_embeddings = embed_strings(chunklets, config=config)
chunks, chunk_embeddings = split_chunks(
chunklets=chunklets,
chunklet_embeddings=chunklet_embeddings,
max_size=config.chunk_max_size,
chunklets=chunklets, chunklet_embeddings=chunklet_embeddings, max_size=config.chunk_max_size
)
# Create the chunk records.
chunk_records, headings = [], ""
@@ -79,12 +77,7 @@ def _create_chunk_records(
)
else:
chunk_embedding_records_list.append(
[
ChunkEmbedding(
chunk_id=chunk_record.id,
embedding=full_chunk_embedding,
)
]
[ChunkEmbedding(chunk_id=chunk_record.id, embedding=full_chunk_embedding)]
)
return document, chunk_records, chunk_embedding_records_list
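
Many hunks in this diff, like the `split_chunks(...)` call above, simply collapse multi-line calls onto one line. That is the effect of `skip-magic-trailing-comma = true` (together with ignoring `COM812` and setting `split-on-trailing-comma = false` for isort) in pyproject.toml: the formatter no longer treats a trailing comma as a request to keep one argument per line. A self-contained before/after sketch, assuming Ruff's documented formatter behavior:

```python
def greet(first_name: str, last_name: str, punctuation: str = "!") -> str:
    """Return a greeting."""
    return f"Hello, {first_name} {last_name}{punctuation}"


# Default behavior: the trailing ("magic") comma keeps one argument per line.
message = greet(
    first_name="Ada",
    last_name="Lovelace",
    punctuation="!",
)

# With skip-magic-trailing-comma = true, the same call is collapsed whenever it fits on one line.
message = greet(first_name="Ada", last_name="Lovelace", punctuation="!")
```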

6 changes: 3 additions & 3 deletions src/raglite/_lazy_llama.py
@@ -36,17 +36,17 @@ def __getattr__(name: str) -> object:
class LazyAttributeError:
error_message = "To use llama.cpp models, please install `llama-cpp-python`."

def __init__(self, error: ModuleNotFoundError | None = None):
def __init__(self, error: ModuleNotFoundError | None = None) -> None:
self.error = error

def __getattr__(self, name: str) -> NoReturn:
raise ModuleNotFoundError(self.error_message) from self.error

def __call__(self, *args: Any, **kwargs: Any) -> NoReturn:
def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: # noqa: ARG002
raise ModuleNotFoundError(self.error_message) from self.error

class LazySubmoduleError:
def __init__(self, error: ModuleNotFoundError):
def __init__(self, error: ModuleNotFoundError) -> None:
self.error = error

def __getattr__(self, name: str) -> LazyAttributeError | type[LazyAttributeError]:
9 changes: 3 additions & 6 deletions src/raglite/_litellm.py
@@ -1,5 +1,7 @@
"""Add support for llama-cpp-python models to LiteLLM."""

# ruff: noqa: ANN401, ARG002

import asyncio
import contextlib
import logging
@@ -26,12 +28,7 @@

from raglite._chatml_function_calling import chatml_function_calling_with_streaming
from raglite._config import RAGLiteConfig
from raglite._lazy_llama import (
Llama,
LlamaRAMCache,
llama_supports_gpu_offload,
llama_types,
)
from raglite._lazy_llama import Llama, LlamaRAMCache, llama_supports_gpu_offload, llama_types

# Reduce the logging level for LiteLLM, flashrank, and httpx.
litellm.suppress_debug_info = True
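
Unlike the per-line `# noqa` markers elsewhere in this diff, the `# ruff: noqa: ANN401, ARG002` line near the top of _litellm.py is a file-scoped suppression: every `Any` annotation and unused argument in the module is exempted at once. A hypothetical sketch of the same mechanism (the handler class and its signature are invented, not LiteLLM's API):

```python
"""Hypothetical adapter module whose signatures are dictated by an untyped third-party API."""

# ruff: noqa: ANN401, ARG002

from typing import Any


class CompletionHandler:
    """Adapter with arguments and return types imposed by the external interface."""

    def completion(
        self, model: str, messages: Any, optional_params: Any, client: Any = None
    ) -> Any:
        """Handle a completion call; `client` is unused here but required by the interface."""
        return {"model": model, "messages": messages, "params": optional_params}
```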
2 changes: 1 addition & 1 deletion src/raglite/_query_adapter.py
@@ -1,6 +1,6 @@
"""Compute and update an optimal query adapter."""

# ruff: noqa: N806
# ruff: noqa: N806, RUF002

from dataclasses import replace
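
RUF002 flags characters in docstrings that are easily confused with ASCII (typographic apostrophes and quotes, en dashes, some mathematical symbols). Adding it to this module's file-level suppression presumably keeps such characters in the query-adapter docstrings intact. A minimal, invented illustration:

```python
# Allow "ambiguous" unicode characters in docstrings module-wide (illustration only).
# ruff: noqa: RUF002


def note() -> str:
    """Return a note; the typographic apostrophe in “don’t” would otherwise trip RUF002."""
    return "ok"
```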

4 changes: 2 additions & 2 deletions src/raglite/_rag.py
@@ -119,7 +119,7 @@ def _get_tools(
"The `query` string MUST be a precise single-faceted question in the user's language.\n"
"The `query` string MUST resolve all pronouns to explicit nouns."
),
},
}
},
"required": ["query"],
"additionalProperties": False,
@@ -237,7 +237,7 @@ async def async_rag(
# Add the tool call requests to the message array.
messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
# Run the tool calls to retrieve the RAG context and append the output to the message array.
# TODO: Make this async.
# TODO(lsorber): Make this async.
messages.extend(_run_tools(tool_calls, on_retrieval, config))
# Asynchronously stream the assistant response.
chunks = []
10 changes: 2 additions & 8 deletions src/raglite/_split_chunks.py
@@ -10,9 +10,7 @@


def split_chunks( # noqa: C901, PLR0915
chunklets: list[str],
chunklet_embeddings: FloatMatrix,
max_size: int = 2048,
chunklets: list[str], chunklet_embeddings: FloatMatrix, max_size: int = 2048
) -> tuple[list[str], list[FloatMatrix]]:
"""Split chunklets into optimal semantic chunks with corresponding chunklet embeddings.

@@ -103,11 +101,7 @@ def split_chunks( # noqa: C901, PLR0915
)
b_ub = np.ones(A.shape[0], dtype=np.float32)
res = linprog(
partition_similarity,
A_ub=-A,
b_ub=-b_ub,
bounds=(0, 1),
integrality=[1] * A.shape[1],
partition_similarity, A_ub=-A, b_ub=-b_ub, bounds=(0, 1), integrality=[1] * A.shape[1]
)
if not res.success:
error_message = "Optimization of chunk partitions failed."