9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -153,6 +153,15 @@ if (LLAMA_BUILD)
         add_compile_definitions(GGML_USE_METAL)
     endif()
 
+    # Set version for mtmd (required by upstream CMakeLists.txt)
+    # NOTE: This is a workaround for mtmd build requirements.
+    # Version is set to 0.0.0 for local builds. If upstream adds version
+    # compatibility checks, this may need to match llama.cpp version.
+    if (NOT DEFINED LLAMA_BUILD_NUMBER)
+        set(LLAMA_BUILD_NUMBER 0)
+    endif()
+    set(LLAMA_INSTALL_VERSION 0.0.${LLAMA_BUILD_NUMBER})
+
     # Building llava
     add_subdirectory(vendor/llama.cpp/tools/mtmd)
 
2 changes: 1 addition & 1 deletion llama_cpp/__init__.py
@@ -1,4 +1,4 @@
 from .llama_cpp import *
 from .llama import *
 
-__version__ = "0.3.16"
+__version__ = "0.4.0"
53 changes: 37 additions & 16 deletions llama_cpp/llama.py
@@ -91,9 +91,9 @@ def __init__(
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
-        flash_attn: bool = False,
         op_offload: Optional[bool] = None,
         swa_full: Optional[bool] = None,
+        flash_attn: Optional[bool] = None,
         # Sampling Params
         no_perf: bool = False,
         last_n_tokens_size: int = 64,
@@ -173,7 +173,7 @@ def __init__(
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
-            flash_attn: Use flash attention.
+            flash_attn: Use flash attention. None = auto, True = enabled, False = disabled.
             op_offload: offload host tensor operations to device
             swa_full: use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
             no_perf: Measure performance timings.
@@ -341,7 +341,16 @@ def __init__(
         self._logits_all = logits_all if draft_model is None else True
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-        self.context_params.flash_attn = flash_attn
+        if flash_attn is None:
+            self.context_params.flash_attn_type = llama_cpp.LLAMA_FLASH_ATTN_TYPE_AUTO
+        elif flash_attn:
+            self.context_params.flash_attn_type = (
+                llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
+            )
+        else:
+            self.context_params.flash_attn_type = (
+                llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED
+            )
 
         if op_offload is not None:
             self.context_params.op_offload = op_offload
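For context, a minimal usage sketch of the new tri-state parameter; the model path is illustrative:

```python
from llama_cpp import Llama

# flash_attn maps onto llama.cpp's flash-attention type:
#   None  -> LLAMA_FLASH_ATTN_TYPE_AUTO     (let llama.cpp decide)
#   True  -> LLAMA_FLASH_ATTN_TYPE_ENABLED
#   False -> LLAMA_FLASH_ATTN_TYPE_DISABLED
llm = Llama(model_path="./model.gguf", flash_attn=None)
```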
@@ -934,7 +943,8 @@ def generate(
 
             sample_idx += 1
             if stopping_criteria is not None and stopping_criteria(
-                self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :]
+                self._input_ids[:sample_idx],
+                self._scores[sample_idx - self.n_tokens, :],
             ):
                 return
             tokens_or_none = yield token
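A sketch of a criterion compatible with this call: it receives the token ids seen so far and the logits row for the current position, and returns True to end generation. The helper name is illustrative, not part of the library API:

```python
import numpy as np

def make_length_stop(max_tokens: int):
    # input_ids: 1-D array of prompt + generated token ids so far
    # logits: scores for the current sampling position
    def criteria(input_ids: np.ndarray, logits: np.ndarray) -> bool:
        return input_ids.shape[0] >= max_tokens
    return criteria

# e.g. llm.generate(tokens, stopping_criteria=make_length_stop(256))
```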
@@ -1041,7 +1051,9 @@ def embed(
         data: Union[List[List[float]], List[List[List[float]]]] = []
 
         def decode_batch(seq_sizes: List[int]):
-            llama_cpp.llama_kv_self_clear(self._ctx.ctx)
+            mem = llama_cpp.llama_get_memory(self._ctx.ctx)
+            if mem is not None:
+                llama_cpp.llama_memory_clear(mem, True)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
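Both this call site and the one in the next hunk use the same replacement for the removed `llama_kv_self_clear`; a condensed sketch, assuming the `llama_get_memory`/`llama_memory_clear` bindings are exposed by the low-level module:

```python
import llama_cpp.llama_cpp as llama_cpp  # low-level bindings (assumed layout)

def clear_context_memory(ctx) -> None:
    # Fetch the context's memory handle; may be NULL for memory-less models.
    mem = llama_cpp.llama_get_memory(ctx)
    if mem is not None:
        # True also clears the data buffers, not just the metadata.
        llama_cpp.llama_memory_clear(mem, True)
```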
@@ -1112,7 +1124,9 @@ def decode_batch(seq_sizes: List[int]):
 
         output = data[0] if isinstance(input, str) else data
 
-        llama_cpp.llama_kv_self_clear(self._ctx.ctx)
+        mem = llama_cpp.llama_get_memory(self._ctx.ctx)
+        if mem is not None:
+            llama_cpp.llama_memory_clear(mem, True)
         self.reset()
 
         if return_count:
@@ -1157,9 +1171,9 @@ def _create_completion(
         bos_token_id: int = self.token_bos()
         cls_token_id: int = self._model.token_cls()
         sep_token_id: int = self._model.token_sep()
-        prefix_token_id: int = 0  # self._model.token_prefix() # TODO: Fix
-        middle_token_id: int = 0  # self._model.token_middle() # TODO: Fix
-        suffix_token_id: int = 0  # self._model.token_suffix() # TODO: Fix
+        prefix_token_id: int = self._model.token_prefix()
+        middle_token_id: int = self._model.token_middle()
+        suffix_token_id: int = self._model.token_suffix()
         add_space_prefix: bool = (
             self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
         )
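With the fill-in-the-middle ids now read from the model's vocabulary instead of hard-coded zeros, an infill prompt can be assembled from them; a sketch assuming a model that defines these tokens and the conventional prefix/suffix/middle layout:

```python
def build_infill_prompt(llm, prefix_token_id: int, suffix_token_id: int,
                        middle_token_id: int, before: bytes, after: bytes) -> list:
    # Conventional FIM layout: <PRE> before <SUF> after <MID>; the model
    # then generates the missing middle. Exact ordering varies by model family.
    return (
        [prefix_token_id]
        + llm.tokenize(before, add_bos=False, special=False)
        + [suffix_token_id]
        + llm.tokenize(after, add_bos=False, special=False)
        + [middle_token_id]
    )
```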
@@ -1315,7 +1329,7 @@ def logit_bias_processor(
         if seed is not None:
             self.set_seed(seed)
         else:
-            self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
+            self.set_seed(random.Random(self._seed).randint(0, 2**32))
 
         finish_reason = "length"
         multibyte_fix = 0
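Beyond the spacing fix, this line derives a fresh 32-bit seed from the stored base seed; since `random.Random(self._seed)` is a reproducible stream, the derived seed is deterministic for a fixed base seed. A self-contained illustration:

```python
import random

base_seed = 1234
a = random.Random(base_seed).randint(0, 2**32)
b = random.Random(base_seed).randint(0, 2**32)
assert a == b  # same base seed -> same derived seed
```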
@@ -2056,7 +2070,10 @@ def create_chat_completion_openai_v1(
             stream = kwargs.get("stream", False)  # type: ignore
             assert isinstance(stream, bool)
             if stream:
-                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
+                return (
+                    ChatCompletionChunk(**chunk)
+                    for chunk in self.create_chat_completion(*args, **kwargs)
+                )  # type: ignore
             else:
                 return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
         except ImportError:
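A usage sketch of the reformatted wrapper (`llm` is an initialized `Llama` instance; messages are illustrative): non-streaming calls return an OpenAI-typed `ChatCompletion`, while `stream=True` yields `ChatCompletionChunk` objects:

```python
resp = llm.create_chat_completion_openai_v1(
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)

for chunk in llm.create_chat_completion_openai_v1(
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
):
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="")
```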
@@ -2096,7 +2113,7 @@ def __getstate__(self):
             logits_all=self._logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
-            flash_attn=self.context_params.flash_attn,
+            flash_attn=self.context_params.flash_attn_type,
             op_offload=self.context_params.op_offload,
             swa_full=self.context_params.swa_full,
             # Sampling Params
@@ -2316,19 +2333,23 @@ def from_pretrained(
         )
 
         if additional_files:
-            for additonal_file_name in additional_files:
+            for additional_file_name in additional_files:
                 # find the additional shard file:
-                matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)]
+                matching_additional_files = [
+                    file
+                    for file in file_list
+                    if fnmatch.fnmatch(file, additional_file_name)
+                ]
 
                 if len(matching_additional_files) == 0:
                     raise ValueError(
-                        f"No file found in {repo_id} that match {additonal_file_name}\n\n"
+                        f"No file found in {repo_id} that match {additional_file_name}\n\n"
                         f"Available Files:\n{json.dumps(file_list)}"
                     )
 
                 if len(matching_additional_files) > 1:
                     raise ValueError(
-                        f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
+                        f"Multiple files found in {repo_id} matching {additional_file_name}\n\n"
                         f"Available Files:\n{json.dumps(files)}"
                     )
 
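A sketch of the matching logic exercised above; the repo id and patterns are illustrative. Each entry in `additional_files` is an `fnmatch` glob that must match exactly one file in the repo listing:

```python
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="org/model-GGUF",            # illustrative repo
    filename="*Q4_K_M.gguf",             # main model file
    additional_files=["*mmproj*.gguf"],  # e.g. a multimodal projector shard
)
```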