Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Imlemented set_config and added test cases #20

Merged
merged 14 commits into from
Nov 14, 2023
Merged
89 changes: 73 additions & 16 deletions nbs/00_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
Expand All @@ -38,15 +47,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using JAX backend.\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from __future__ import annotations\n",
Expand Down Expand Up @@ -263,11 +264,11 @@
"name": "stdout",
"output_type": "stream",
"text": [
"data: [array([[-0.4150979 ],\n",
" [-0.59805975],\n",
" [-0.59252158],\n",
" [-0.88781678],\n",
" [ 0.08100867]]), 1, True, 'Hello', array(['a', 'b', 'c'], dtype='<U1')]\n",
"data: [array([[ 1.10284602],\n",
" [-0.23078658],\n",
" [-1.27621715],\n",
" [ 0.99138736],\n",
" [-0.01633328]]), 1, True, 'Hello', array(['a', 'b', 'c'], dtype='<U1')]\n",
"treedef: {'a': True, 'b': False, 'c': {'d': False, 'e': False, 'f': True}}\n"
]
}
Expand Down Expand Up @@ -580,6 +581,62 @@
"def get_config() -> Config: \n",
" return main_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# | export\n",
"def set_config(\n",
" *,\n",
" rng_reserve_size: int = None, # The number of random number generators to reserve.\n",
" global_seed: int = None, # The global seed for random number generators.\n",
" **kwargs\n",
") -> None:\n",
" \"\"\"Sets the global configurations.\"\"\"\n",
"\n",
" def set_val(\n",
" arg_name: str, # The name of the argument.\n",
" arg_value: int, # The value of the argument.\n",
" arg_min: int # The minimum value of the argument.\n",
" ) -> None:\n",
" \"\"\"Checks the validity of the argument and sets the value.\"\"\"\n",
" \n",
" if arg_value is None or not hasattr(main_config, arg_name):\n",
" return\n",
" \n",
" if not isinstance(arg_value, int):\n",
" raise TypeError(f\"`{arg_name}` must be an integer, but got {type(arg_value).__name__}.\")\n",
" if arg_value < arg_min:\n",
" raise ValueError(f\"`{arg_name}` must be non-negative, but got {arg_value}.\")\n",
" setattr(main_config, arg_name, arg_value)\n",
"\n",
" set_val('rng_reserve_size', rng_reserve_size, 1)\n",
" set_val('global_seed', global_seed, 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generic Test cases\n",
"set_config()\n",
"assert get_config().rng_reserve_size == 1 and get_config().global_seed == 42\n",
"set_config(rng_reserve_size=100)\n",
"assert get_config().rng_reserve_size == 100\n",
"set_config(global_seed=1234)\n",
"assert get_config().global_seed == 1234\n",
"set_config(rng_reserve_size=2, global_seed=234)\n",
"assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n",
"set_config()\n",
"assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n",
"set_config(invalid_key = 80)\n",
"assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234"
]
}
],
"metadata": {
Expand All @@ -590,5 +647,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
1 change: 1 addition & 0 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,4 +736,5 @@
'relax.utils.load_json': ('utils.html#load_json', 'relax/utils.py'),
'relax.utils.load_pytree': ('utils.html#load_pytree', 'relax/utils.py'),
'relax.utils.save_pytree': ('utils.html#save_pytree', 'relax/utils.py'),
'relax.utils.set_config': ('utils.html#set_config', 'relax/utils.py'),
'relax.utils.validate_configs': ('utils.html#validate_configs', 'relax/utils.py')}}}
31 changes: 30 additions & 1 deletion relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from jax.core import InconclusiveDimensionOperation

# %% auto 0
__all__ = ['validate_configs', 'save_pytree', 'load_pytree', 'auto_reshaping', 'grad_update', 'load_json', 'get_config']
__all__ = ['validate_configs', 'save_pytree', 'load_pytree', 'auto_reshaping', 'grad_update', 'load_json', 'get_config',
'set_config']

# %% ../nbs/00_utils.ipynb 5
def validate_configs(
Expand Down Expand Up @@ -138,3 +139,31 @@ def default(cls) -> Config:
# %% ../nbs/00_utils.ipynb 36
def get_config() -> Config:
return main_config

# %% ../nbs/00_utils.ipynb 37
def set_config(
BirkhoffG marked this conversation as resolved.
Show resolved Hide resolved
*,
rng_reserve_size: int = None, # The number of random number generators to reserve.
global_seed: int = None, # The global seed for random number generators.
**kwargs
) -> None:
"""Sets the global configurations."""

def set_val(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you fix the indent for the arguments? It should be exactly one tab.

arg_name: str, # The name of the argument.
arg_value: int, # The value of the argument.
arg_min: int # The minimum value of the argument.
) -> None:
"""Checks the validity of the argument and sets the value."""

if arg_value is None or not hasattr(main_config, arg_name):
return

if not isinstance(arg_value, int):
raise TypeError(f"`{arg_name}` must be an integer, but got {type(arg_value).__name__}.")
if arg_value < arg_min:
raise ValueError(f"`{arg_name}` must be non-negative, but got {arg_value}.")
setattr(main_config, arg_name, arg_value)

set_val('rng_reserve_size', rng_reserve_size, 1)
set_val('global_seed', global_seed, 0)