Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tunix/rl/rl_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from tunix.rl import utils
from absl.testing import absltest
import numpy as np

class UtilsTest(absltest.TestCase):

def test_is_positive_integer(self):
# 1. Standard Ints (Should Pass)
try:
utils.is_positive_integer(5, "test_var")
utils.is_positive_integer(1, "test_var")
except ValueError:
self.fail("is_positive_integer raised ValueError unexpectedly for int!")

# 2. Numpy Types (THIS IS THE BUG - It will fail now)
try:
utils.is_positive_integer(np.int64(5), "numpy_int")
utils.is_positive_integer(np.float32(5.0), "numpy_float")
except AttributeError:
# We catch AttributeError specifically because that's what we are fixing
self.fail("CRASH: AttributeError detected! The bug exists.")
except ValueError:
self.fail("is_positive_integer failed on Numpy types!")

# 3. Fail Cases (Should raise ValueError)
with self.assertRaisesRegex(ValueError, "positive integer"):
utils.is_positive_integer(5.5, "float_var")

with self.assertRaisesRegex(ValueError, "positive integer"):
utils.is_positive_integer(-5, "neg_var")

if __name__ == '__main__':
absltest.main()
24 changes: 20 additions & 4 deletions tunix/rl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,27 @@
NamedSharding = jax.sharding.NamedSharding


def is_positive_integer(value: int | None, name: str):
"""Checks if the value is positive."""
if value is not None and (not value.is_integer() or value <= 0):
raise ValueError(f"{name} must be a positive integer. Got: {value}")
def is_integer_value(x):
"""Checks if a value is effectively an integer (safe for all types)."""
# 1. Booleans are effectively integers (0/1) but we don't want them here.
if isinstance(x, bool):
return False

# 2. Native Ints and Numpy Ints (The fix for #903/#953)
if isinstance(x, (int, np.integer)):
return True

# 3. Floats that are actually integers (e.g. 5.0)
if isinstance(x, (float, np.floating)):
return x.is_integer()

return False

def is_positive_integer(value, name: str):
"""Checks if the value is positive and integer-like."""
# Use the new helper instead of calling .is_integer() directly
if value is not None and (not is_integer_value(value) or value <= 0):
raise ValueError(f"{name} must be a positive integer. Got: {value}")

def check_divisibility(
small_size,
Expand Down