@@ -1,9 +1,13 @@
 from __future__ import annotations

-import os
 import ctypes
+import os
 import pathlib

+from ._ggml import (
+    ggml_opt_get_optimizer_params
+)
+
 from typing import (
     Callable,
     Union,
@@ -171,6 +175,10 @@
 # llama_sampler_p = NewType("llama_sampler_p", int)
 # llama_sampler_p_ctypes = ctypes.c_void_p

+# struct llama_opt_params;
+llama_opt_params_p = NewType("llama_opt_params_p", int)
+llama_opt_params_p_ctypes = ctypes.c_void_p
+
 # struct llama_kv_cache;
 llama_kv_cache_p = NewType("llama_kv_cache_p", int)
 llama_kv_cache_p_ctypes = ctypes.c_void_p
@@ -243,6 +251,7 @@
 #     LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
 #     LLAMA_VOCAB_PRE_TYPE_LLAMA4     = 33,
 #     LLAMA_VOCAB_PRE_TYPE_PIXTRAL    = 34,
+#     LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
 # };
 LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0
 LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1
@@ -279,6 +288,7 @@
 LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32
 LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33
 LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34
+LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35


 # // note: these values should be synchronized with ggml_rope
@@ -790,6 +800,7 @@ class llama_model_params(ctypes.Structure):
 #     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
 #     bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
 #     bool no_perf;     // whether to measure performance timings
+#     bool op_offload;  // whether to offload host tensor operations to device
 # };
 class llama_context_params(ctypes.Structure):
     """Parameters for llama_context
@@ -811,7 +822,7 @@ class llama_context_params(ctypes.Structure):
         yarn_beta_fast (float): YaRN low correction dim
         yarn_beta_slow (float): YaRN high correction dim
         yarn_orig_ctx (int): YaRN original context size
-        defrag_thold (float): defragment the KV cache if holes/size > thold, < 0 disabled (default)
+        defrag_thold (float): defragment the KV cache if holes/size > thold, <= 0 disabled (default)
         cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval
         cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval
         type_k (int): data type for K cache
@@ -822,6 +833,7 @@ class llama_context_params(ctypes.Structure):
         offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
         flash_attn (bool): whether to use flash attention
         no_perf (bool): whether to measure performance timings
+        op_offload (bool): whether to offload host tensor operations to device
     """

     if TYPE_CHECKING:
@@ -852,6 +864,7 @@ class llama_context_params(ctypes.Structure):
         offload_kqv: bool
         flash_attn: bool
         no_perf: bool
+        op_offload: bool

     _fields_ = [
         ("n_ctx", ctypes.c_uint32),
@@ -881,6 +894,7 @@ class llama_context_params(ctypes.Structure):
         ("offload_kqv", ctypes.c_bool),
         ("flash_attn", ctypes.c_bool),
         ("no_perf", ctypes.c_bool),
+        ("op_offload", ctypes.c_bool),
     ]


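
The new `op_offload` field can be toggled like any other `llama_context_params` member. A minimal sketch, assuming the usual llama-cpp-python entry points (`llama_context_default_params`, `llama_init_from_model`) and a `model` handle loaded elsewhere:

```python
import llama_cpp

# Start from the library's default context parameters.
cparams = llama_cpp.llama_context_default_params()

# Field added in this diff: keep host tensor operations on the CPU
# instead of offloading them to the device.
cparams.op_offload = False

# Hypothetical: `model` would come from a model-loading call elsewhere.
# ctx = llama_cpp.llama_init_from_model(model, cparams)
```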
@@ -1193,7 +1207,20 @@ def llama_model_load_from_splits(
     ...


-# LLAMA_API void llama_free_model(struct llama_model * model);
+# LLAMA_API void llama_model_save_to_file(
+#     const struct llama_model * model,
+#     const char * path_model);
+@ctypes_function(
+    "llama_model_save_to_file",
+    [llama_model_p_ctypes, ctypes.c_char_p],
+    None,
+)
+def llama_model_save_to_file(model: llama_model_p, path_model: bytes, /):
+    ...
+
+
+# DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
+#            "use llama_model_free instead");
 @ctypes_function(
     "llama_free_model",
     [llama_model_p_ctypes],
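
A minimal usage sketch for the new `llama_model_save_to_file` binding. The loading calls and paths are assumptions for illustration; the path is passed as `bytes` to match the `ctypes.c_char_p` argtype above:

```python
import llama_cpp

# Assumed setup: load a model with the existing bindings (paths are placeholders).
mparams = llama_cpp.llama_model_default_params()
model = llama_cpp.llama_model_load_from_file(b"models/in.gguf", mparams)

# New binding from this diff: write the model back out to a GGUF file.
llama_cpp.llama_model_save_to_file(model, b"models/out.gguf")

llama_cpp.llama_model_free(model)
```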
@@ -4128,8 +4155,8 @@ def llama_sampler_get_seed(smpl: llama_sampler_p, /) -> int:
     llama_token,
 )
 def llama_sampler_sample(
-    smpl: llama_sampler_p, ctx: llama_context_p, idx: int, /
-) -> int:
+    smpl: llama_sampler_p, ctx: llama_context_p, idx: ctypes.c_int32, /
+) -> ctypes.c_int32:
     ...


@@ -4306,3 +4333,85 @@ def llama_perf_sampler_reset(chain: llama_sampler_p, /):
     ...


+# //
+# // training
+# //
+
+# // function that returns whether or not a given tensor contains trainable parameters
+# typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
+llama_opt_param_filter = ctypes.CFUNCTYPE(
+    ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p
+)
+
+
+# // always returns true
+# LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
+@ctypes_function("llama_opt_param_filter_all", [ctypes.c_void_p, ctypes.c_void_p], ctypes.c_bool)
+def llama_opt_param_filter_all(
+    tensor: llama_model_p,
+    userdata: ctypes.c_void_p, /
+) -> bool:
+    ...
+
+# struct llama_opt_params {
+#     uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
+
+#     llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
+#     void * param_filter_ud;              // userdata for determining which tensors contain trainable parameters
+
+#     ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+#     void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+# };
+class llama_opt_params(ctypes.Structure):
+    _fields_ = [
+        ("n_ctx_train", ctypes.c_uint32),
+        ("param_filter", llama_opt_param_filter),
+        ("param_filter_ud", ctypes.c_void_p),
+        ("get_opt_pars", ggml_opt_get_optimizer_params),
+        ("get_opt_pars_ud", ctypes.c_void_p),
+    ]
+
+
+# LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
+@ctypes_function(
+    "llama_opt_init",
+    [llama_context_p_ctypes, llama_model_p_ctypes, llama_opt_params_p_ctypes],
+    None,
+)
+def llama_opt_init(
+    lctx: llama_context_p,
+    model: llama_model_p,
+    lopt_params: llama_opt_params_p, /
+):
+    ...
+
+# LLAMA_API void llama_opt_epoch(
+#     struct llama_context * lctx,
+#     ggml_opt_dataset_t dataset,
+#     ggml_opt_result_t result_train,
+#     ggml_opt_result_t result_eval,
+#     int64_t idata_split,
+#     ggml_opt_epoch_callback callback_train,
+#     ggml_opt_epoch_callback callback_eval);
+@ctypes_function(
+    "llama_opt_epoch", [
+        llama_context_p_ctypes,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_int64,
+        ctypes.c_void_p,
+        ctypes.c_void_p
+    ],
+    None,
+)
+def llama_opt_epoch(
+    lctx: llama_context_p,
+    dataset: ctypes.c_void_p,
+    result_train: ctypes.c_void_p,
+    result_eval: ctypes.c_void_p,
+    idata_split: ctypes.c_int64,
+    callback_train: ctypes.c_void_p,
+    callback_eval: ctypes.c_void_p, /
+):
+    ...
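
A hedged sketch of how the new training bindings might be wired together. The context/model handles, the decision to leave `get_opt_pars` unset, and the pointer-style call into `llama_opt_init` (declared above with a `c_void_p`-based parameter) are assumptions for illustration, not part of this diff:

```python
import ctypes
import llama_cpp

# Assumed: ctx and model were created elsewhere with the existing
# model-loading and context-creation bindings.

# Keep a reference to the callback object so it is not garbage-collected
# while the C side may still call it.
param_filter = llama_cpp.llama_opt_param_filter(
    lambda tensor, userdata: True  # treat every tensor as trainable
)

opt_params = llama_cpp.llama_opt_params(
    n_ctx_train=0,            # 0: reuse the context size of the llama_context
    param_filter=param_filter,
    param_filter_ud=None,
    # get_opt_pars is left NULL in this sketch; a real run would point it at
    # ggml's optimizer-parameter callback (ggml_opt_get_optimizer_params).
    get_opt_pars_ud=None,
)

# Since the binding above takes a void pointer, the struct is passed by
# address in this sketch.
# llama_cpp.llama_opt_init(
#     ctx, model, ctypes.cast(ctypes.byref(opt_params), ctypes.c_void_p)
# )
```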