5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,5 @@
# Changelog

Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).

- **`model.lora`**: Moved from `model.experimental.lora` to `model.lora` (no longer experimental) (#1440, 2025-12-16)
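
For reference, a minimal before/after sketch of the renamed config section (the `rank` value is illustrative, mirroring the CI configs updated below):

```toml
# Before #1440: LoRA settings lived under the experimental namespace
[trainer.model.experimental.lora]
rank = 8

# After #1440: LoRA is a top-level model field
[trainer.model.lora]
rank = 8
```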
2 changes: 1 addition & 1 deletion configs/ci/integration/rl_lora/resume.toml
@@ -12,7 +12,7 @@ name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT"
[trainer.optim]
lr = 1e-4

-[trainer.model.experimental.lora]
+[trainer.model.lora]
rank = 8

[trainer.ckpt.weights]
2 changes: 1 addition & 1 deletion configs/ci/integration/rl_lora/start.toml
@@ -11,7 +11,7 @@ name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT"
[trainer.optim]
lr = 1e-4

-[trainer.model.experimental.lora]
+[trainer.model.lora]
rank = 8

[trainer.ckpt.weights]
2 changes: 1 addition & 1 deletion configs/wiki_search/rl.toml
@@ -14,7 +14,7 @@ name = "wiki-search-4b"
lr = 1e-5
weight_decay = 0.0

-[trainer.model.experimental.lora]
+[trainer.model.lora]
rank = 8
alpha = 32
dropout = 0.0
2 changes: 1 addition & 1 deletion examples/alphabet_sort/rl.toml
@@ -16,7 +16,7 @@ impl = "liger_kernel"
[trainer.model.ac]
freq = 1

-[trainer.model.experimental.lora]
+[trainer.model.lora]
rank = 32
alpha = 64

2 changes: 1 addition & 1 deletion examples/wiki_search/rl.toml
@@ -14,7 +14,7 @@ name = "wiki-search"
lr = 1e-5
weight_decay = 0.0

-[trainer.model.experimental.lora]
+[trainer.model.lora]
rank = 8
alpha = 32
dropout = 0.0
6 changes: 3 additions & 3 deletions src/prime_rl/rl.py
@@ -368,18 +368,18 @@ def auto_setup_weight_broadcast(self):

    @model_validator(mode="after")
    def auto_setup_lora(self):
-        if self.trainer.model.experimental.lora is not None:
+        if self.trainer.model.lora is not None:
            if self.trainer.weight_broadcast.type == "nccl":
                raise ValueError("NCCL weight broadcast does not support LoRA yet.")
            self.trainer.weight_broadcast.adapter_only = True
            if self.orchestrator.lora_name is None:
                lora_name = (
-                    f"r{self.trainer.model.experimental.lora.rank}-a{self.trainer.model.experimental.lora.alpha}"
+                    f"r{self.trainer.model.lora.rank}-a{self.trainer.model.lora.alpha}"
                )
                self.orchestrator.lora_name = lora_name
            if self.inference is not None:
                self.inference.enable_lora = True
-                self.inference.max_lora_rank = self.trainer.model.experimental.lora.rank
+                self.inference.max_lora_rank = self.trainer.model.lora.rank
            else:
                warnings.warn(
                    "LoRA is enabled, but inference is not configured. When manually starting the inference server, make sure to set `--enable_lora` and `--max-lora-rank`."
25 changes: 7 additions & 18 deletions src/prime_rl/trainer/config.py
@@ -114,17 +114,6 @@ class LoRAConfig(BaseConfig):
    ] = []


-class ExperimentalConfig(BaseConfig):
-    """Experimental modeling features."""
-
-    lora: Annotated[
-        LoRAConfig | None,
-        Field(
-            description="Whether to apply LoRA to the model. If None, will not apply LoRA.",
-        ),
-    ] = None
-
-
class ModelConfig(BaseConfig):
    """Configures the model for training."""

@@ -234,20 +223,20 @@ class ModelConfig(BaseConfig):
        ),
    ] = True

+    lora: Annotated[
+        LoRAConfig | None,
+        Field(
+            description="Whether to apply LoRA to the model. If None, will not apply LoRA.",
+        ),
+    ] = None
+
    debug: Annotated[
        DebugModelConfig,
        Field(
            description="Debugging feature around model and distributed training.",
        ),
    ] = DebugModelConfig()

-    experimental: Annotated[
-        ExperimentalConfig,
-        Field(
-            description="Experimental modeling features.",
-        ),
-    ] = ExperimentalConfig()
-
    @model_validator(mode="after")
    def _map_model_name_for_moe(self):
        """Map model name if it exists in MOE_MODEL_MAPS."""
4 changes: 2 additions & 2 deletions src/prime_rl/trainer/model.py
@@ -345,8 +345,8 @@ def setup_model(
    model = get_model(config, device=torch.device("cpu"), dtype=DTYPE_MAP[config.optimization_dtype])

    # Apply LoRA before FSDP setup
-    if config.experimental.lora is not None:
-        apply_lora_to_model(model, config.experimental.lora)
+    if config.lora is not None:
+        apply_lora_to_model(model, config.lora)

    # the right order is AC -> Compile -> FSDP
    if config.ac is not None:
6 changes: 3 additions & 3 deletions src/prime_rl/trainer/rl/config.py
@@ -195,11 +195,11 @@ def dont_do_massive_traces(self):
    @model_validator(mode="after")
    def validate_lora_adapter_saving(self):
        if self.ckpt and self.ckpt.weights and self.ckpt.weights.save_adapter_separately:
-            lora_enabled = self.model and self.model.experimental and self.model.experimental.lora
+            lora_enabled = self.model and self.model.lora
            if not lora_enabled:
                raise ValueError(
                    "save_adapter_separately=True requires LoRA to be enabled. "
-                    "Set model.experimental.lora or disable save_adapter_separately."
+                    "Set model.lora or disable save_adapter_separately."
                )
        return self

@@ -217,7 +217,7 @@ def validate_opt_and_fsdp_offload(self):

    @model_validator(mode="after")
    def validate_lora_broadcast(self):
-        if self.weight_broadcast.adapter_only and not self.model.experimental.lora:
+        if self.weight_broadcast.adapter_only and not self.model.lora:
            raise ValueError("Adapter only weight broadcast requires LoRA to be enabled.")
        if self.weight_broadcast.type == "nccl" and self.weight_broadcast.adapter_only:
            # TODO: Support this
4 changes: 2 additions & 2 deletions src/prime_rl/trainer/rl/train.py
@@ -95,7 +95,7 @@ def train(config: RLTrainerConfig):
    # Set up checkpoint manager
    logger.info(f"Initializing checkpoint managers ({config.ckpt})")
    ckpt_manager, weight_ckpt_manager = setup_ckpt_managers(
-        config.output_dir, config.ckpt, config.model.experimental.lora
+        config.output_dir, config.ckpt, config.model.lora
    )

    # get the checkpoint step to load from
@@ -127,7 +127,7 @@ def train(config: RLTrainerConfig):
    # Set up weight broadcast
    logger.info(f"Initializing weight broadcast ({config.weight_broadcast})")
    weight_broadcast = setup_weight_broadcast(
-        config.output_dir, config.weight_broadcast, config.model.experimental.lora
+        config.output_dir, config.weight_broadcast, config.model.lora
    )

    if parallel_dims.cp_enabled:
4 changes: 2 additions & 2 deletions src/prime_rl/trainer/sft/config.py
@@ -216,11 +216,11 @@ def dont_do_massive_traces(self):
    @model_validator(mode="after")
    def validate_lora_adapter_saving(self):
        if self.ckpt and self.ckpt.weights and self.ckpt.weights.save_adapter_separately:
-            lora_enabled = self.model and self.model.experimental and self.model.experimental.lora
+            lora_enabled = self.model and self.model.lora
            if not lora_enabled:
                raise ValueError(
                    "save_adapter_separately=True requires LoRA to be enabled. "
-                    "Set model.experimental.lora or disable save_adapter_separately."
+                    "Set model.lora or disable save_adapter_separately."
                )
        return self

4 changes: 1 addition & 3 deletions src/prime_rl/trainer/sft/train.py
@@ -95,9 +95,7 @@ def train(config: SFTTrainerConfig):

    # Set up checkpoint manager
    logger.info(f"Initializing checkpoint managers ({config.ckpt})")
-    ckpt_manager, weight_ckpt_manager = setup_ckpt_managers(
-        config.output_dir, config.ckpt, config.model.experimental.lora
-    )
+    ckpt_manager, weight_ckpt_manager = setup_ckpt_managers(config.output_dir, config.ckpt, config.model.lora)

    checkpoint_step = None
    if config.ckpt and config.ckpt.resume_step is not None and ckpt_manager is not None: