
Support block-modular architecture #277


Draft: wants to merge 158 commits into base: main
Commits (158)
5137757
stuff
jlamypoirier Mar 26, 2025
f0cb32a
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Mar 26, 2025
f26010e
Update pretrained config
jlamypoirier Mar 27, 2025
b930a39
stuff
jlamypoirier Mar 27, 2025
918a7a8
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Mar 27, 2025
8117c47
fixes
jlamypoirier Mar 27, 2025
1c995d3
fix
jlamypoirier Mar 27, 2025
3f90475
Merge branch 'main' into config_updates
jlamypoirier Mar 27, 2025
e389058
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Mar 27, 2025
506fe92
fixes
jlamypoirier Mar 27, 2025
971d3ef
fixes
jlamypoirier Mar 27, 2025
6bf20cb
Tests wip
jlamypoirier Mar 28, 2025
c13fb19
misc
jlamypoirier Mar 29, 2025
a20fcec
tests
jlamypoirier Apr 1, 2025
9af26a7
Merge branch 'main' into config_updates
jlamypoirier Apr 1, 2025
9af372d
Tests, fixes, remove tuple format
jlamypoirier Apr 1, 2025
dded00a
fix
jlamypoirier Apr 2, 2025
42d5ca4
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Apr 2, 2025
986f9f3
fix
jlamypoirier Apr 2, 2025
5abc087
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 2, 2025
8e3e795
fixes
jlamypoirier Apr 2, 2025
da6eb7b
fixes
jlamypoirier Apr 3, 2025
67e08aa
Merge branch 'main' into config_updates
jlamypoirier Apr 3, 2025
a09e6f3
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 3, 2025
baad705
fix
jlamypoirier Apr 3, 2025
b702837
Test, fixes
jlamypoirier Apr 5, 2025
a8684f8
Knowledge distillation, fix cross-entropy
jlamypoirier Apr 11, 2025
b781729
Fixes, distillation
jlamypoirier Apr 13, 2025
db6504b
fixes
jlamypoirier Apr 14, 2025
7c2933a
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Apr 14, 2025
a017c11
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 14, 2025
368a6bf
Merge remote-tracking branch 'origin/main' into update_pretrained_config
jlamypoirier Apr 14, 2025
e0c82a0
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 14, 2025
16a3dd7
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 14, 2025
cff9892
fixes
jlamypoirier Apr 14, 2025
793ecde
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 14, 2025
b67006a
fixes
jlamypoirier Apr 15, 2025
2014108
Add constraints
jlamypoirier Apr 16, 2025
4fb78e4
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 16, 2025
fa3d556
Add constraints
jlamypoirier Apr 16, 2025
6c2c887
Separate reference model preprocessing
jlamypoirier Apr 16, 2025
67f9db6
fix
jlamypoirier Apr 16, 2025
48141e5
Merge remote-tracking branch 'origin/main' into update_pretrained_config
jlamypoirier Apr 17, 2025
e6e5a32
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 17, 2025
537deca
fix
jlamypoirier Apr 17, 2025
a590e8b
Merge branch 'distillation' into reference_model_preprocessing
jlamypoirier Apr 17, 2025
3d5dc94
Merge commit '6ad0a96c9328234b907d01a82c4c52bd48752b2f' into update_p…
jlamypoirier Apr 18, 2025
2bb0c08
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 18, 2025
067ba97
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 21, 2025
d2b3154
misc
jlamypoirier Apr 21, 2025
2e63d29
Merge branch 'distillation' into reference_model_preprocessing
jlamypoirier Apr 22, 2025
7133e4d
Merge remote-tracking branch 'origin/main' into reference_model_prepr…
jlamypoirier Apr 22, 2025
a0ba051
fixes
jlamypoirier Apr 25, 2025
9ddfb69
add per-layer lr-scale
RaymondLi0 Apr 25, 2025
5e282cc
modeling mtp llamba
oleksost Apr 28, 2025
87b3197
modeling apriel ssm
oleksost Apr 29, 2025
d3e1df2
Apriel to SSM
oleksost Apr 29, 2025
082cf22
Apriel SSM conversion
oleksost Apr 29, 2025
66fb0a2
Merge remote-tracking branch 'origin/main' into reference_model_prepr…
jlamypoirier Apr 29, 2025
0d4d5c5
fix
jlamypoirier Apr 29, 2025
b5ffd26
Merge remote-tracking branch 'origin/reference_model_preprocessing' i…
oleksost Apr 29, 2025
c43e535
wip
oleksost Apr 29, 2025
a1f44d4
conversion apriel ssm
oleksost Apr 29, 2025
fbec02d
config apriel
oleksost Apr 29, 2025
75d6460
temp checkpoint conversion
oleksost Apr 29, 2025
73a4252
block pattern for hybrid conversion
oleksost Apr 30, 2025
5afc7dc
SSMBlockType
oleksost Apr 30, 2025
8e9facf
wip
oleksost Apr 30, 2025
77ad39f
add token-prediction loss coefficients
RaymondLi0 Apr 30, 2025
da9bf1a
eval apriel ssm
oleksost May 1, 2025
ac4a598
fix
jlamypoirier May 1, 2025
0c0e7d9
adding check for missing `rope_type` (#246)
nitsanluke May 1, 2025
97ba9d4
Loss masking for distillation
jlamypoirier May 1, 2025
231d5d8
test, misc
jlamypoirier May 1, 2025
d7922af
Merge branch 'reference_model_preprocessing' into distillation_loss_mask
jlamypoirier May 1, 2025
30a75b0
eval apriel ssm
oleksost May 1, 2025
a50bc2e
cleanup
oleksost May 1, 2025
f8af7be
Merge branch 'oleksiy/apriel-ssm' of https://github.com/ServiceNow/Fa…
oleksost May 1, 2025
6532c5f
hybrid config
oleksost May 2, 2025
2a5646b
Merge remote-tracking branch 'origin/distillation_loss_mask' into ole…
oleksost May 2, 2025
9a678df
sft distill
oleksost May 2, 2025
a7abe53
conversion
oleksost May 2, 2025
a68c0b7
conversion
oleksost May 2, 2025
9cfef44
lr stage definition as string
oleksost May 2, 2025
005e623
fixes
jlamypoirier May 2, 2025
cad951a
fix
jlamypoirier May 2, 2025
40970ec
Merge branch 'reference_model_preprocessing' into distillation_loss_mask
jlamypoirier May 2, 2025
bce916d
loss maks
oleksost May 2, 2025
9d95064
fix
jlamypoirier May 2, 2025
2c96abb
Merge remote-tracking branch 'origin/main' into reference_model_prepr…
jlamypoirier May 2, 2025
935c470
fix
jlamypoirier May 2, 2025
9aff3b7
fix shuffled tokens
oleksost May 2, 2025
d82ddbf
Merge remote-tracking branch 'origin/main' into reference_model_prepr…
jlamypoirier May 2, 2025
6949c49
Merge branch 'reference_model_preprocessing' into distillation_loss_mask
jlamypoirier May 2, 2025
9c105e7
Merge remote-tracking branch 'origin/main' into distillation_loss_mask
jlamypoirier May 2, 2025
ae4d111
fixes
jlamypoirier May 2, 2025
deb7ce6
fixes
jlamypoirier May 2, 2025
eaba34f
innit like in mamba in llama
oleksost May 2, 2025
f8ca122
embeddings_lr_scale
oleksost May 5, 2025
2db740b
fix
jlamypoirier May 5, 2025
41d4da3
disable freezing
RaymondLi0 May 5, 2025
4160b1f
hybrid model loading and exporting
oleksost May 6, 2025
30ad8b8
wip
oleksost May 7, 2025
ea55ae2
Merge branch 'main' into oleksiy/apriel-ssm
oleksost May 7, 2025
cd4edd5
Merge remote-tracking branch 'origin/distillation_loss_mask' into ole…
oleksost May 7, 2025
9c4f38f
layer-lr scale for mlp as well
RaymondLi0 May 7, 2025
1784dca
wip
oleksost May 7, 2025
1e3cc28
nvm
oleksost May 7, 2025
2dc945b
hybrid modeling
oleksost May 9, 2025
4277e67
modeling
oleksost May 9, 2025
6153c33
Merge branch 'main' into oleksiy/apriel-ssm
oleksost May 9, 2025
c71cb16
nvm
oleksost May 9, 2025
be04c19
output lr scale
oleksost May 9, 2025
1311f5b
output_lr_scale
oleksost May 9, 2025
baf4011
nvm
oleksost May 9, 2025
6cf26c5
eval
oleksost May 10, 2025
901d1b6
rename
oleksost May 12, 2025
b5696fb
Merge remote-tracking branch 'origin/raymond/per_layer_lr_scale' into…
oleksost May 12, 2025
616c540
per_layer_lr_scale
oleksost May 12, 2025
9af5ee5
merged also prediction_loss_coefficient from #243
oleksost May 12, 2025
1a7939b
added logging in mamba
oleksost May 12, 2025
532d0d5
no norm layer freezing
oleksost May 12, 2025
8349130
test
oleksost May 12, 2025
023102c
test
oleksost May 12, 2025
865da95
debug
oleksost May 12, 2025
87c93d3
comment
oleksost May 12, 2025
da4977d
Merge remote-tracking branch 'origin/main' into oleksiy/apriel-ssm
oleksost May 12, 2025
a18b80f
debug
oleksost May 12, 2025
40d5437
wip
oleksost May 14, 2025
72ace3b
fix
RaymondLi0 May 14, 2025
121e906
test + comment
oleksost May 14, 2025
aa3bc0b
stuff
jlamypoirier May 14, 2025
28d321e
stuff
jlamypoirier May 14, 2025
1bbd7fb
stuff
jlamypoirier May 14, 2025
3595949
Minimalistic dynamic configs
jlamypoirier May 14, 2025
39b1a04
stuff
jlamypoirier May 15, 2025
8a8fa77
fix
RaymondLi0 May 16, 2025
8e25990
add test with frozen weights
RaymondLi0 May 16, 2025
456a0c5
add description for tests
RaymondLi0 May 16, 2025
87efd45
15b model apriel hybrid
oleksost May 20, 2025
95c7b53
Merge remote-tracking branch 'origin/main' into oleksiy/apriel-ssm
oleksost May 20, 2025
326387d
Merge remote-tracking branch 'origin/raymond/fix-frozen-weight' into …
oleksost May 20, 2025
aafbfb5
nvm
oleksost May 20, 2025
c7fe8d7
nvm
oleksost May 20, 2025
848ef04
Merge remote-tracking branch 'origin/main' into oleksiy/apriel-ssm
oleksost May 20, 2025
c285e8d
nvm
oleksost May 20, 2025
26e4924
Merge remote-tracking branch 'origin/minimalistic_dynamic_classes' in…
oleksost May 20, 2025
3eaa240
modeling
oleksost May 22, 2025
4781d15
wip
oleksost May 26, 2025
ac4bfa9
wip
oleksost May 27, 2025
45008b5
wip
oleksost May 27, 2025
a378954
wip hybrid block architecture
oleksost May 28, 2025
38fc529
wip
oleksost May 29, 2025
852bb92
Merge remote-tracking branch 'origin/main' into modular_hybrids
oleksost May 29, 2025
e5534fd
wip
oleksost Jun 3, 2025
6860c43
added lr scales per block
oleksost Jun 3, 2025
7178407
weight sharing
oleksost Jun 3, 2025
0553a4b
test
oleksost Jun 4, 2025
6 changes: 3 additions & 3 deletions fast_llm/config.py
@@ -380,8 +380,8 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:

         if expected_class is not None:
             # Should be handled in `from_dict`, but can fail if instantiating directly.
-            Assert.is_(self.__class__, expected_class)
-
+            # TODO: is this ok? i.e. we want the assigned class to be a subclass of the expected class, not neccessarily exactly the same class.
+            Assert.custom(issubclass, expected_class, self.__class__)
         if not self._validated:
             try:
                 self._validate()

Collaborator review comment on the TODO line: No, this is handled in `from_dict`. The expected class is not the same as the type hint.
@@ -720,7 +720,7 @@ def _get_class_name(cls) -> str:
     @classmethod
     def from_dict(
         cls,
-        default: "Config| dict[str, typing.Any]]",
+        default: "Config| dict[str, typing.Any]",
         *updates: "Config| dict[str | tuple[str, ...], typing.Any]",
         strict: bool = True,
         update_type: UpdateType = UpdateType.override,
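For context on the check above, here is a minimal standalone sketch (not Fast-LLM code; `BlockConfig` and `SSMBlockConfig` are hypothetical names) of why an exact-class assertion rejects a dynamically substituted subclass while a subclass-aware check accepts it:

    # Hypothetical config classes standing in for dynamically selected config types.
    class BlockConfig:
        pass

    class SSMBlockConfig(BlockConfig):
        pass

    expected_class = BlockConfig   # what the field's type hint expects
    instance = SSMBlockConfig()    # what dynamic selection actually instantiated

    # Exact-class check (the old `Assert.is_` behaviour): rejects the subclass.
    print(instance.__class__ is expected_class)            # False
    # Subclass check: accepts any class derived from the expected one.
    print(issubclass(instance.__class__, expected_class))  # True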
26 changes: 13 additions & 13 deletions fast_llm/engine/optimizer/learning_rate.py
@@ -120,19 +120,19 @@ def create_schedule_from_config(config: LearningRateScheduleConfig) -> LearningR
     begin_step = 0
     for stage_arg_str in config.schedule.split(";"):
         try:
-            for stage_type, num_steps, lr, *stage_args in stage_arg_str.split(","):
-                assert begin_step is not None
-                num_steps = int(num_steps)
-                end_step = None if num_steps < 0 else begin_step + num_steps
-                kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)}
-                if len(stage_args) > 0:
-                    kwargs["end_lr"] = float(stage_args[0])
-                if len(stage_args) > 1:
-                    kwargs["power"] = float(stage_args[1])
-                if len(stage_args) > 2:
-                    raise ValueError(stage_args[2:])
-                stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs))
-                begin_step = end_step
+            stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",")
+            assert begin_step is not None
+            num_steps = int(num_steps)
+            end_step = None if num_steps < 0 else begin_step + num_steps
+            kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)}
+            if len(stage_args) > 0:
+                kwargs["end_lr"] = float(stage_args[0])
+            if len(stage_args) > 1:
+                kwargs["power"] = float(stage_args[1])
+            if len(stage_args) > 2:
+                raise ValueError(stage_args[2:])
+            stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs))
+            begin_step = end_step
         except Exception:
             raise ValueError(f'Cannot parse optimizer stage definition "{stage_arg_str}"')
     return LearningRateSchedule(stages)
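For reference, a minimal sketch of the stage-string format this parser expects: stages are separated by `;`, and each stage is `stage_type,num_steps,lr[,end_lr[,power]]`. The stage names and the plain-dict output below are illustrative assumptions; the real parser maps `stage_type` through `_STAGE_TYPE_MAP` to stage classes.

    def parse_stages(schedule: str) -> list[dict]:
        """Parse "stage_type,num_steps,lr[,end_lr[,power]]" stages separated by ";"."""
        stages = []
        begin_step = 0
        for stage_arg_str in schedule.split(";"):
            # One unpack per stage; the removed code wrapped this line in a `for`,
            # which iterated over the comma-separated fields instead of the stage.
            stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",")
            assert begin_step is not None, "only the last stage may be open-ended"
            num_steps = int(num_steps)
            end_step = None if num_steps < 0 else begin_step + num_steps
            stage = {"type": stage_type, "begin_step": begin_step, "end_step": end_step, "lr": float(lr)}
            if len(stage_args) > 0:
                stage["end_lr"] = float(stage_args[0])
            if len(stage_args) > 1:
                stage["power"] = float(stage_args[1])
            if len(stage_args) > 2:
                raise ValueError(stage_args[2:])
            stages.append(stage)
            begin_step = end_step
        return stages

    # e.g. a 1000-step stage at a fixed lr followed by an open-ended decay stage
    # ("constant" and "cosine" are assumed stage names, not taken from _STAGE_TYPE_MAP):
    print(parse_stages("constant,1000,0.0003;cosine,-1,0.0003,0.00003"))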