Option to disable unwrapping model for generation in PPO/RLOO/OnlineDPO #2529

@dawidm

Description

Feature request

Add a trainer option to disable model unwrapping (unwrap_model_for_generation()) during generation in the online methods (PPO, RLOO, OnlineDPO).
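To make the request concrete, here is a minimal sketch of how such a switch might behave. It is not TRL code: `generation_context`, its `unwrap` flag, and the `fake_unwrap` stand-in are hypothetical names used only for illustration, and the real signature of unwrap_model_for_generation() is not reproduced here.

```python
from contextlib import contextmanager


@contextmanager
def generation_context(model, unwrap_cm, unwrap=True):
    """Yield the model object that generation should use.

    unwrap_cm: callable returning a context manager that yields the
        unwrapped model (a stand-in for unwrap_model_for_generation()).
    unwrap: hypothetical trainer option; when False, the wrapped model
        is used as-is, so nothing is unwrapped before generation.
    """
    if unwrap:
        with unwrap_cm(model) as unwrapped:
            yield unwrapped
    else:
        yield model


@contextmanager
def fake_unwrap(model):
    # Stand-in for the real unwrapping step, for demonstration only.
    yield ("unwrapped", model)


if __name__ == "__main__":
    with generation_context("policy", fake_unwrap, unwrap=False) as m:
        print(m)
    with generation_context("policy", fake_unwrap, unwrap=True) as m:
        print(m)
```

With the flag off, the generation code path receives the wrapped model unchanged, which is the behavior this issue asks for.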

Motivation

As discussed in #2250, when DeepSpeed stage 3 is used and the policy model is larger than a single GPU's VRAM, unwrap_model_for_generation() causes an out-of-memory (OOM) error. An option to disable unwrapping would make such training scenarios possible, although slow.

Your contribution

I will likely create a PR once this has been discussed and approved.