Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/rl/rl_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,5 +254,52 @@ def _create_mock_train_example(
np.testing.assert_array_equal(pack2.completion_mask, expected_mask_2)


class IsPositiveIntegerTest(absltest.TestCase):
"""Tests for `utils.is_positive_integer`."""

def test_accepts_python_int(self):
utils.is_positive_integer(1, 'x')
utils.is_positive_integer(100, 'x')

def test_accepts_numpy_int(self):
utils.is_positive_integer(np.int32(5), 'x')
utils.is_positive_integer(np.int64(5), 'x')
utils.is_positive_integer(np.uint8(5), 'x')

def test_accepts_none(self):
utils.is_positive_integer(None, 'x')

def test_rejects_zero(self):
with self.assertRaises(ValueError):
utils.is_positive_integer(0, 'x')

def test_rejects_negative(self):
with self.assertRaises(ValueError):
utils.is_positive_integer(-1, 'x')
with self.assertRaises(ValueError):
utils.is_positive_integer(np.int64(-5), 'x')

def test_rejects_bool(self):
with self.assertRaises(ValueError):
utils.is_positive_integer(True, 'x')
with self.assertRaises(ValueError):
utils.is_positive_integer(False, 'x')

def test_rejects_float(self):
with self.assertRaises(ValueError):
utils.is_positive_integer(1.0, 'x')
with self.assertRaises(ValueError):
utils.is_positive_integer(1.5, 'x')

def test_rejects_string(self):
with self.assertRaises(ValueError):
utils.is_positive_integer('5', 'x')

def test_error_message_includes_name(self):
with self.assertRaises(ValueError) as ctx:
utils.is_positive_integer(-1, 'max_steps')
self.assertIn('max_steps', str(ctx.exception))


if __name__ == '__main__':
absltest.main()
13 changes: 11 additions & 2 deletions tunix/rl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,17 @@


def is_positive_integer(value: int | None, name: str):
"""Checks if the value is positive."""
if value is not None and (not isinstance(value, int) or value <= 0):
"""Checks if the value is a positive integer.

Accepts Python ``int`` and numpy integer scalars (e.g. ``np.int64``).
Explicitly rejects ``bool``, which is a subclass of ``int`` in Python but
is not semantically an integer in this context.
"""
if value is None:
return
if isinstance(value, bool) or not isinstance(value, (int, np.integer)):
raise ValueError(f"{name} must be a positive integer. Got: {value}")
if value <= 0:
raise ValueError(f"{name} must be a positive integer. Got: {value}")


Expand Down