src/gfn/env.py — 2 changes: 1 addition & 1 deletion

@@ -271,7 +271,7 @@ def reset(
         assert not (random and sink)

         if random and seed is not None:
-            set_seed(seed, performance_mode=True)
+            set_seed(seed, deterministic_mode=False)  # TODO: configurable?

         if isinstance(batch_shape, int):
             batch_shape = (batch_shape,)
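One way to resolve the `TODO: configurable?` above would be to let callers choose the mode rather than hard-coding `deterministic_mode=False`. A minimal sketch only; the `deterministic_seeding` parameter name is hypothetical and not part of this PR:

    from gfn.utils.common import set_seed

    def reset(self, batch_shape, random=False, sink=False, seed=None,
              deterministic_seeding=False):
        assert not (random and sink)
        if random and seed is not None:
            # Forward the caller's choice instead of hard-coding the mode.
            set_seed(seed, deterministic_mode=deterministic_seeding)
        if isinstance(batch_shape, int):
            batch_shape = (batch_shape,)
        # ... rest of reset() unchanged ...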
src/gfn/utils/common.py — 30 changes: 15 additions & 15 deletions

@@ -145,12 +145,12 @@ def filter_kwargs_for_callable(
 # -----------------------------------------------------------------------------


-def set_seed(seed: int, performance_mode: bool = False) -> None:
+def set_seed(seed: int, deterministic_mode: bool = False) -> None:
     """Used to control randomness for both single and distributed training.

     Args:
         seed: The seed to use for all random number generators
-        performance_mode: If True, disables deterministic behavior for better performance.
+        deterministic_mode: If True, enforces deterministic behavior for reproducibility.
             In multi-GPU settings, this only affects cuDNN. In multi-CPU settings,
-            this allows parallel processing in NumPy.
+            this disables parallel processing in NumPy.
     """
@@ -186,19 +186,22 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:

         # Set device-specific environment variables
         if torch.cuda.is_available():
-            # For GPU training, we can use multiple threads for CPU operations
-            if performance_mode:
-                os.environ["OMP_NUM_THREADS"] = str(num_cpus)
-                os.environ["MKL_NUM_THREADS"] = str(num_cpus)
-            else:
+            if deterministic_mode:
                 # For reproducibility in GPU training, we still want deterministic
                 # CPU operations
                 os.environ["OMP_NUM_THREADS"] = "1"
                 os.environ["MKL_NUM_THREADS"] = "1"
+            else:
+                # For GPU training, we can use multiple threads for CPU operations
+                os.environ["OMP_NUM_THREADS"] = str(num_cpus)
+                os.environ["MKL_NUM_THREADS"] = str(num_cpus)
Collaborator (comment on lines +194 to +197):
    isn't this the default behavior? If yes, I think we can remove this else branch
         else:
             # For CPU-only training, we need to be more careful with threading
-            if performance_mode:
+
+            if deterministic_mode:
+                # For perfect reproducibility in CPU training, disable parallel processing
+                os.environ["OMP_NUM_THREADS"] = "1"
+                os.environ["MKL_NUM_THREADS"] = "1"
+            else:
                 # Allow parallel processing but with controlled number of threads
                 # Different backends might handle threading differently
                 if backend in ["mpi", "ccl"]:
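On the collaborator's question above about the `else` branch being redundant: a quick way to check is shown below. When `OMP_NUM_THREADS` is unset, PyTorch typically sizes its intra-op thread pool from the available cores on its own, so exporting the core count explicitly may change nothing (a sketch; exact defaults vary by PyTorch build and host):

    import os
    import torch

    # With OMP_NUM_THREADS unset, torch usually defaults its intra-op
    # thread pool to the machine's physical core count.
    print(os.environ.get("OMP_NUM_THREADS"))  # None unless something exported it
    print(torch.get_num_threads())            # typically == physical core count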
@@ -211,10 +214,6 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:

                 os.environ["OMP_NUM_THREADS"] = str(num_threads)
                 os.environ["MKL_NUM_THREADS"] = str(num_threads)
-            else:
-                # For perfect reproducibility in CPU training, disable parallel processing
-                os.environ["OMP_NUM_THREADS"] = "1"
-                os.environ["MKL_NUM_THREADS"] = "1"

     else:
         # Non-distributed training - use the global seed
@@ -237,7 +236,7 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:
         threading.current_thread()._seed = seed

     # These are only set when we care about reproducibility over performance
-    if not performance_mode:
+    if deterministic_mode:
         # GPU-specific settings
         if torch.cuda.is_available():
             torch.backends.cudnn.deterministic = True
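The hunk is truncated here; `cudnn.deterministic = True` usually sits alongside a few companion settings. A sketch of the typical full set, where everything beyond the first line is an assumption rather than something this diff shows:

    import os
    import torch

    torch.backends.cudnn.deterministic = True  # shown in the diff above
    torch.backends.cudnn.benchmark = False     # assumed companion setting (disables autotuning)
    # Stricter, opt-in determinism across all ops (assumption, not shown in the diff);
    # on CUDA this also requires a cuBLAS workspace config:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)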
@@ -332,7 +331,8 @@ def make_dataloader_seed_fns(

     def _worker_init_fn(worker_id: int) -> None:  # pragma: no cover
         # Each worker gets a distinct seed in the same pattern used for ranks.
-        set_seed(base_seed + worker_id, performance_mode=False)
+        # TODO: Can this be false?
+        set_seed(base_seed + worker_id, deterministic_mode=True)

     gen = torch.Generator()
     gen.manual_seed(base_seed)
Collaborator (comment on the added lines above):
    I think it should be false by default; as it is the default torch behavior
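For context, how these pieces usually plug into a `torch.utils.data.DataLoader`. The return shape of `make_dataloader_seed_fns` is not visible in this diff, so the unpacking below is an assumption:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(100).float())

    # Assumption: make_dataloader_seed_fns returns (_worker_init_fn, gen) as built above.
    worker_init_fn, gen = make_dataloader_seed_fns(base_seed=42)

    loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=2,
        worker_init_fn=worker_init_fn,  # seeds each worker as base_seed + worker_id
        generator=gen,                  # fixes the shuffle order across runs
    )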