From d1b9e29a7cf22d4f84f4c5a38118c5fabff38887 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 2 Aug 2024 00:17:53 -0400 Subject: [PATCH] fixed imports --- src/gfn/gym/hypergrid.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 0f96b08..a43f382 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -10,6 +10,7 @@ from math import gcd, log from time import time from typing import Literal, Tuple +import multiprocessing import torch from einops import rearrange @@ -21,12 +22,7 @@ from gfn.preprocessors import EnumPreprocessor, IdentityPreprocessor from gfn.states import DiscreteStates - -def import_multiprocessing_if_required(): - """Imports and configures multiprocessing if it isn't already imported.""" - if "multiprocessing" not in sys.modules: - import multiprocessing - multiprocessing.set_start_method("fork") # multiprocessing-torch compatibility. +multiprocessing.set_start_method("fork") # multiprocessing-torch compatibility. def lcm(a, b): @@ -284,7 +280,6 @@ def _calculate_all_states_tensor(self, batch_size: int = 20_000): Args: batch_size: Compute this number of hypergrid indices in parallel. """ - import_multiprocessing_if_required() if self._all_states is None and self.calculate_all_states: start_time = time() all_states = [] @@ -357,7 +352,6 @@ def _generate_combinations_in_batches(self, n, k, batch_size): for i in range(0, total_combinations, batch_size) ] - import_multiprocessing_if_required() with multiprocessing.Pool() as pool: for result in pool.imap(self._worker, tasks): yield result