
RL utils.py has an incorrect check for is_positive_integer #902

@Ajay-26

Description


Hi! I'm raising this issue because there appears to be a bug when initializing the RL cluster config: constructing RLTrainingConfig inside ClusterConfig fails.

The following snippet crashes when I run it:

# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=EOS_TOKENS,
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

Expected Behavior
The snippet should construct the cluster config and the GRPO config without raising an error.

Actual Behavior
The snippet raises the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_181/593643255.py in <cell line: 0>()
      8     rollout_engine='vanilla',
      9     offload_to_cpu=False,
---> 10     training_config=rl_cluster_lib.RLTrainingConfig(
     11         actor_optimizer=optimizer,
     12         eval_every_n_steps=EVAL_EVERY_N_STEPS,

/usr/local/lib/python3.11/dist-packages/tunix/rl/rl_cluster.py in __init__(self, eval_every_n_steps, max_steps, gradient_accumulation_steps, checkpoint_root_directory, checkpointing_options, metrics_logging_options, profiler_options, data_sharding_axis, max_inflight_computations, metrics_prefix, pbar_description, actor_optimizer, critic_optimizer, mini_batch_size, train_micro_batch_size, rollout_micro_batch_size, compute_logps_micro_batch_size)

/usr/local/lib/python3.11/dist-packages/tunix/rl/rl_cluster.py in __post_init__(self)
    112         "compute_logps_micro_batch_size",
    113     ]:
--> 114       rl_utils.is_positive_integer(getattr(self, name), name)
    115 
    116     if self.gradient_accumulation_steps is not None:

/usr/local/lib/python3.11/dist-packages/tunix/rl/utils.py in is_positive_integer(value, name)
     35 def is_positive_integer(value: int | None, name: str):
     36   """Checks if the value is positive."""
---> 37   if value is not None and (not value.is_integer() or value <= 0):
     38     raise ValueError(f"{name} must be a positive integer. Got: {value}")
     39 

AttributeError: 'int' object has no attribute 'is_integer'

Steps to Reproduce the Problem

1. Run the snippet above. It comes from the GRPO example linked below.

I am trying to run the provided GRPO example, given here, in a Kaggle notebook.

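I traced the issue to the is_positive_integer() function here, on line 35 of tunix/rl/utils.py. The check calls value.is_integer(), but int.is_integer() was only added in Python 3.12. The Kaggle environment runs Python 3.11 (note the python3.11 paths in the traceback), so any plain int argument raises AttributeError.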
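The failure can be reproduced without tunix. The sketch below copies the check from the traceback into a standalone function (the name is_positive_integer_original and the helper reproduces_bug are hypothetical, for illustration only). Since int.is_integer() only exists on Python 3.12 and later, the AttributeError should appear on 3.11 and older, while float arguments pass on any version.

```python
import sys


def is_positive_integer_original(value, name):
  """Copy of the check shown in the traceback (tunix/rl/utils.py, line 37)."""
  if value is not None and (not value.is_integer() or value <= 0):
    raise ValueError(f"{name} must be a positive integer. Got: {value}")


def reproduces_bug() -> bool:
  """Returns True if passing a plain int raises AttributeError."""
  try:
    is_positive_integer_original(8, "mini_batch_size")
  except AttributeError:
    return True
  return False


# int.is_integer() exists only on Python 3.12+, so older versions hit the bug.
print(sys.version_info[:2], "bug reproduced:", reproduces_bug())

# float has always had is_integer(), so a float argument passes everywhere.
is_positive_integer_original(8.0, "mini_batch_size")
```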

Happy to make a PR with the fix.
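A possible fix, sketched under the assumption that only integer types should be accepted (the int | None annotation suggests this). This is a hypothetical replacement, not the maintainers' implementation. It uses isinstance with numbers.Integral instead of int.is_integer(), so it behaves the same on every Python version and also accepts NumPy integer scalars. Unlike the original check, it rejects integral floats such as 4.0; if those should keep working, float(value).is_integer() would be a version-independent alternative.

```python
import numbers
from typing import Optional


def is_positive_integer(value: Optional[int], name: str) -> None:
  """Raises ValueError unless `value` is None or a positive integer.

  Hypothetical replacement for tunix.rl.utils.is_positive_integer that avoids
  int.is_integer(), which only exists on Python 3.12+.
  """
  if value is None:
    return
  # bool is a subclass of int, so reject it explicitly. numbers.Integral also
  # matches NumPy integer scalars such as np.int64.
  if isinstance(value, bool) or not isinstance(value, numbers.Integral):
    raise ValueError(f"{name} must be a positive integer. Got: {value}")
  if value <= 0:
    raise ValueError(f"{name} must be a positive integer. Got: {value}")


# Usage mirroring the loop in RLTrainingConfig.__post_init__:
for field_name, field_value in [("mini_batch_size", 8),
                                ("rollout_micro_batch_size", None)]:
  is_positive_integer(field_value, field_name)
```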

Metadata


Assignees: No one assigned
Labels: type:bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
