Skip to content

Commit

Permalink
set numpy random state based on worker_id
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Apr 11, 2021
1 parent 4e97483 commit 499dae8
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pts/model/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def train_model(
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
worker_init_fn=self._worker_init_fn,
**kwargs,
)

Expand All @@ -138,6 +139,7 @@ def train_model(
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
worker_init_fn=self._worker_init_fn,
**kwargs,
)

Expand All @@ -155,6 +157,10 @@ def train_model(
),
)

@staticmethod
def _worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)

def train(
self,
training_data: Dataset,
Expand Down

0 comments on commit 499dae8

Please sign in to comment.