Skip to content

Commit

Permalink
fixed imports
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Aug 2, 2024
1 parent 79d58f3 commit d1b9e29
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from math import gcd, log
from time import time
from typing import Literal, Tuple
import multiprocessing

import torch
from einops import rearrange
Expand All @@ -21,12 +22,7 @@
from gfn.preprocessors import EnumPreprocessor, IdentityPreprocessor
from gfn.states import DiscreteStates


def import_multiprocessing_if_required():
"""Imports and configures multiprocessing if it isn't already imported."""
if "multiprocessing" not in sys.modules:
import multiprocessing
multiprocessing.set_start_method("fork") # multiprocessing-torch compatibility.
multiprocessing.set_start_method("fork") # multiprocessing-torch compatibility.


def lcm(a, b):
Expand Down Expand Up @@ -284,7 +280,6 @@ def _calculate_all_states_tensor(self, batch_size: int = 20_000):
Args:
batch_size: Compute this number of hypergrid indices in parallel.
"""
import_multiprocessing_if_required()
if self._all_states is None and self.calculate_all_states:
start_time = time()
all_states = []
Expand Down Expand Up @@ -357,7 +352,6 @@ def _generate_combinations_in_batches(self, n, k, batch_size):
for i in range(0, total_combinations, batch_size)
]

import_multiprocessing_if_required()
with multiprocessing.Pool() as pool:
for result in pool.imap(self._worker, tasks):
yield result

0 comments on commit d1b9e29

Please sign in to comment.