
Commit 23c1667

Fix randomness in cocktail ingredients

1 parent b0b67ea

8 files changed, 21 additions and 20 deletions

autoPyTorch/pipeline/components/setup/base_setup.py

Lines changed: 5 additions & 3 deletions
@@ -1,4 +1,6 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+
+import numpy as np
 
 from autoPyTorch.pipeline.components.base_component import autoPyTorchComponent
 
@@ -7,8 +9,8 @@ class autoPyTorchSetupComponent(autoPyTorchComponent):
     """Provide an abstract interface for schedulers
     in Auto-Pytorch"""
 
-    def __init__(self) -> None:
-        super(autoPyTorchSetupComponent, self).__init__()
+    def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None:
+        super(autoPyTorchSetupComponent, self).__init__(random_state=random_state)
 
     def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
         """

autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ class LearnedEntityEmbedding(NetworkEmbeddingComponent):
     Class to learn an embedding for categorical hyperparameters.
     """
 
-    def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None, **kwargs: Any):
+    def __init__(self, random_state: Optional[np.random.RandomState] = None, **kwargs: Any):
         super().__init__(random_state=random_state)
         self.config = kwargs
 

autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ class NoEmbedding(NetworkEmbeddingComponent):
     Class to learn an embedding for categorical hyperparameters.
     """
 
-    def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None):
+    def __init__(self, random_state: Optional[np.random.RandomState] = None):
         super().__init__(random_state=random_state)
 
     def build_embedding(self, num_input_features: np.ndarray, num_numerical_features: int) -> nn.Module:

autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py

Lines changed: 2 additions & 3 deletions
@@ -11,10 +11,9 @@
 
 
 class NetworkEmbeddingComponent(autoPyTorchSetupComponent):
-    def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None):
-        super().__init__()
+    def __init__(self, random_state: Optional[np.random.RandomState] = None):
+        super().__init__(random_state=random_state)
         self.embedding: Optional[nn.Module] = None
-        self.random_state = random_state
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
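
Besides forwarding random_state to the base class (which the old code stored itself after calling super().__init__() with no arguments), the signature narrows from Optional[Union[np.random.RandomState, int]] to Optional[np.random.RandomState], so a caller holding a plain integer seed must wrap it first. One way to do that conversion at the boundary, sketched with sklearn's check_random_state (this diff does not show how autoPyTorch's callers actually convert seeds):

import numpy as np
from sklearn.utils import check_random_state

seed = 1  # hypothetical integer seed arriving from a config file or CLI

# check_random_state turns None or an int into an np.random.RandomState
# and passes an existing RandomState instance through unchanged.
random_state = check_random_state(seed)
assert isinstance(random_state, np.random.RandomState)

# embedding_component = NetworkEmbeddingComponent(random_state=random_state)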

autoPyTorch/pipeline/components/training/trainer/GridCutMixTrainer.py

Lines changed: 4 additions & 4 deletions
@@ -27,11 +27,11 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
         beta = 1.0
-        lam = np.random.beta(beta, beta)
+        lam = self.random_state.beta(beta, beta)
         batch_size, channel, W, H = X.size()
         index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
 
-        r = np.random.rand(1)
+        r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
             return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
 
@@ -40,8 +40,8 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         cut_rat = np.sqrt(1. - lam)
         cut_w = np.int(W * cut_rat)
         cut_h = np.int(H * cut_rat)
-        cx = np.random.randint(W)
-        cy = np.random.randint(H)
+        cx = self.random_state.randint(W)
+        cy = self.random_state.randint(H)
         bbx1 = np.clip(cx - cut_w // 2, 0, W)
         bby1 = np.clip(cy - cut_h // 2, 0, H)
         bbx2 = np.clip(cx + cut_w // 2, 0, W)
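
With the draws routed through self.random_state, the CutMix patch location becomes reproducible for a fixed seed. A standalone sketch of the bounding-box sampling above (rand_bbox is a hypothetical helper name; plain int stands in for np.int, which is deprecated in recent NumPy):

import numpy as np

def rand_bbox(random_state: np.random.RandomState, W: int, H: int, lam: float):
    """Sample a CutMix patch covering roughly (1 - lam) of the image area."""
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # Patch center drawn from the shared state, then clipped to the image.
    cx = random_state.randint(W)
    cy = random_state.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

# The same seed now yields the same patch on every run.
rs = np.random.RandomState(0)
lam = rs.beta(1.0, 1.0)
print(rand_bbox(rs, 32, 32, lam))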

autoPyTorch/pipeline/components/training/trainer/GridCutOutTrainer.py

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             np.ndarray: that processes data
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        r = np.random.rand(1)
+        r = self.random_state.rand(1)
         batch_size, channel, W, H = X.size()
         if r > self.cutout_prob:
             return X, {'y_a': y, 'y_b': y, 'lam': 1}
@@ -34,8 +34,8 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         cut_rat = np.sqrt(1. - self.patch_ratio)
         cut_w = np.int(W * cut_rat)
         cut_h = np.int(H * cut_rat)
-        cx = np.random.randint(W)
-        cy = np.random.randint(H)
+        cx = self.random_state.randint(W)
+        cy = self.random_state.randint(H)
         bbx1 = np.clip(cx - cut_w // 2, 0, W)
         bby1 = np.clip(cy - cut_h // 2, 0, H)
         bbx2 = np.clip(cx + cut_w // 2, 0, W)

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 3 additions & 3 deletions
@@ -28,19 +28,19 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
         beta = 1.0
-        lam = np.random.beta(beta, beta)
+        lam = self.random_state.beta(beta, beta)
         batch_size = X.size()[0]
         index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
 
-        r = np.random.rand(1)
+        r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
             return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
 
         # The mixup component mixes up also on the batch dimension
         # It is unlikely that the batch size is lower than the number of features, but
         # be safe
         size = min(X.shape[0], X.shape[1])
-        indices = torch.tensor(random.sample(range(1, size), max(1, np.int(size * lam))))
+        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * lam))))
 
         X[:, indices] = X[index, :][:, indices]
 
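
One semantic detail of this hunk: random.sample draws without replacement, while RandomState.choice defaults to replace=True, so the new call may pick the same column index more than once. A small illustrative sketch of the difference (standalone, not autoPyTorch code):

import random

import numpy as np

size, lam = 10, 0.5
k = max(1, int(size * lam))

# stdlib: k distinct indices, no replacement
cols_sample = random.sample(range(1, size), k)

# RandomState.choice samples with replacement by default, so duplicates are
# possible; pass replace=False to reproduce random.sample's behaviour.
rs = np.random.RandomState(0)
cols_choice = rs.choice(range(1, size), k)                  # may repeat indices
cols_unique = rs.choice(range(1, size), k, replace=False)   # all distinct

print(sorted(cols_sample), sorted(cols_choice.tolist()), sorted(cols_unique.tolist()))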

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
 
-        r = np.random.rand(1)
+        r = self.random_state.rand(1)
         if r > self.cutout_prob:
             y_a = y
             y_b = y
@@ -39,7 +39,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         # It is unlikely that the batch size is lower than the number of features, but
         # be safe
         size = min(X.shape[0], X.shape[1])
-        indices = torch.tensor(random.sample(range(1, size), max(1, np.int(size * self.patch_ratio))))
+        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * self.patch_ratio))))
 
         # We use an ordinal encoder on the tabular data
         # -1 is the conceptual equivalent to 0 in a image, that does not
