Skip to content

Commit 3886f06

Browse files
committed
add checks for scaling factors
1 parent ee2ff5c commit 3886f06

File tree

5 files changed

+67
-17
lines changed

5 files changed

+67
-17
lines changed

autoPyTorch/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@
5858
"forecasting tasks! Please run \n pip install autoPyTorch[forecasting] \n to "\
5959
"install the corresponding dependencies!"
6060

61+
# This value is applied to ensure numerical stability: Sometimes we want to rescale some values: value / scale.
62+
# We set the scale value to 1 if it is smaller than this value to ensure that the scaled value will not result in
63+
# overflow
64+
VERY_SMALL_VALUE = 1e-12
6165

6266
# The constant values for time series forecasting come from
6367
# https://github.com/rakshitha123/TSForecasting/blob/master/experiments/deep_learning_experiments.py

autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/scaling/utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from sklearn.base import BaseEstimator
88

9+
from autoPyTorch.constants import VERY_SMALL_VALUE
10+
911

1012
# Similar to / inspired by
1113
# https://github.com/tslearn-team/tslearn/blob/a3cf3bf/tslearn/preprocessing/preprocessing.py
@@ -41,7 +43,7 @@ def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Any = None) -> "TimeSeriesS
4143
self.loc[self.static_features] = X[self.static_features].mean()
4244

4345
# ensure that if all the values are the same in a group, we could still normalize them correctly
44-
self.scale[self.scale == 0] = 1.
46+
self.scale[self.scale < VERY_SMALL_VALUE] = 1.
4547

4648
elif self.mode == "min_max":
4749
X_grouped = X.groupby(X.index)
@@ -55,14 +57,14 @@ def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Any = None) -> "TimeSeriesS
5557
self.loc = min_
5658
self.scale = diff_
5759
self.scale.mask(self.scale == 0.0, self.loc)
58-
self.scale[self.scale == 0.0] = 1.0
60+
self.scale[self.scale < VERY_SMALL_VALUE] = 1.0
5961

6062
elif self.mode == "max_abs":
6163
X_abs = X.transform("abs")
6264
max_abs_ = X_abs.groupby(X_abs.index).agg("max")
6365
max_abs_[self.static_features] = max_abs_[self.static_features].max()
6466

65-
max_abs_[max_abs_ == 0.0] = 1.0
67+
max_abs_[max_abs_ < VERY_SMALL_VALUE] = 1.0
6668
self.loc = None
6769
self.scale = max_abs_
6870

@@ -73,7 +75,7 @@ def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Any = None) -> "TimeSeriesS
7375
mean_abs_[self.static_features] = mean_abs_[self.static_features].mean()
7476
self.scale = mean_abs_.mask(mean_abs_ == 0.0, X_abs.agg("max"))
7577

76-
self.scale[self.scale == 0] = 1
78+
self.scale[self.scale < VERY_SMALL_VALUE] = 1
7779
self.loc = None
7880

7981
elif self.mode == "none":
@@ -108,7 +110,7 @@ def transform(self, X: Union[pd.DataFrame, np.ndarray]) -> Union[pd.DataFrame, n
108110
loc = X.mean(axis=0, keepdims=True)
109111
scale = np.nan_to_num(X.std(axis=0, ddof=1, keepdims=True))
110112
scale = np.where(scale == 0, loc, scale)
111-
scale[scale == 0] = 1.
113+
scale[scale < VERY_SMALL_VALUE] = 1.
112114
return (X - loc) / scale
113115

114116
elif self.mode == 'min_max':
@@ -119,21 +121,21 @@ def transform(self, X: Union[pd.DataFrame, np.ndarray]) -> Union[pd.DataFrame, n
119121
loc = min_
120122
scale = diff_
121123
scale = np.where(scale == 0., loc, scale)
122-
scale[scale == 0.0] = 1.0
124+
scale[scale < VERY_SMALL_VALUE] = 1.0
123125
return (X - loc) / scale
124126

125127
elif self.mode == "max_abs":
126128
X_abs = np.abs(X)
127129
max_abs_ = X_abs.max(0, keepdims=True)
128-
max_abs_[max_abs_ == 0.0] = 1.0
130+
max_abs_[max_abs_ < VERY_SMALL_VALUE] = 1.0
129131
scale = max_abs_
130132
return X / scale
131133

132134
elif self.mode == 'mean_abs':
133135
X_abs = np.abs(X)
134136
mean_abs_ = X_abs.mean(0, keepdims=True)
135137
scale = np.where(mean_abs_ == 0.0, np.max(X_abs), mean_abs_)
136-
scale[scale == 0] = 1
138+
scale[scale < VERY_SMALL_VALUE] = 1
137139
return X / scale
138140

139141
elif self.mode == "none":

autoPyTorch/pipeline/components/setup/forecasting_target_scaling/utils.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from autoPyTorch.constants import VERY_SMALL_VALUE
78

89
# Similar to / inspired by
910
# https://github.com/tslearn-team/tslearn/blob/a3cf3bf/tslearn/preprocessing/preprocessing.py
@@ -30,7 +31,7 @@ def transform(self,
3031

3132
offset_targets = past_targets - loc
3233
scale = torch.where(torch.logical_or(scale == 0.0, scale == torch.nan), offset_targets[:, [-1]], scale)
33-
scale[scale == 0.0] = 1.0
34+
scale[scale < VERY_SMALL_VALUE] = 1.0
3435
if future_targets is not None:
3536
future_targets = (future_targets - loc) / scale
3637
return (past_targets - loc) / scale, future_targets, loc, scale
@@ -42,14 +43,14 @@ def transform(self,
4243
diff_ = max_ - min_
4344
loc = min_
4445
scale = torch.where(diff_ == 0, past_targets[:, [-1]], diff_)
45-
scale[scale == 0.0] = 1.0
46+
scale[scale < VERY_SMALL_VALUE] = 1.0
4647
if future_targets is not None:
4748
future_targets = (future_targets - loc) / scale
4849
return (past_targets - loc) / scale, future_targets, loc, scale
4950

5051
elif self.mode == "max_abs":
5152
max_abs_ = torch.max(torch.abs(past_targets), dim=1, keepdim=True)[0]
52-
max_abs_[max_abs_ == 0.0] = 1.0
53+
max_abs_[max_abs_ < VERY_SMALL_VALUE] = 1.0
5354
scale = max_abs_
5455
if future_targets is not None:
5556
future_targets = future_targets / scale
@@ -58,7 +59,7 @@ def transform(self,
5859
elif self.mode == 'mean_abs':
5960
mean_abs = torch.mean(torch.abs(past_targets), dim=1, keepdim=True)
6061
scale = torch.where(mean_abs == 0.0, past_targets[:, [-1]], mean_abs)
61-
scale[scale == 0.0] = 1.0
62+
scale[scale < VERY_SMALL_VALUE] = 1.0
6263
if future_targets is not None:
6364
future_targets = future_targets / scale
6465
return past_targets / scale, future_targets, None, scale
@@ -82,7 +83,7 @@ def transform(self,
8283
offset_targets = past_targets - loc
8384
# ensure that all the targets are scaled properly
8485
scale = torch.where(torch.logical_or(scale == 0.0, scale == torch.nan), offset_targets[:, [-1]], scale)
85-
scale[scale == 0.0] = 1.0
86+
scale[scale < VERY_SMALL_VALUE] = 1.0
8687

8788
if future_targets is not None:
8889
future_targets = (future_targets - loc) / scale
@@ -100,7 +101,7 @@ def transform(self,
100101
diff_ = max_ - min_
101102
loc = min_
102103
scale = torch.where(diff_ == 0, past_targets[:, [-1]], diff_)
103-
scale[scale == 0.0] = 1.0
104+
scale[scale < VERY_SMALL_VALUE] = 1.0
104105

105106
if future_targets is not None:
106107
future_targets = (future_targets - loc) / scale
@@ -110,7 +111,7 @@ def transform(self,
110111

111112
elif self.mode == "max_abs":
112113
max_abs_ = torch.max(torch.abs(valid_past_targets), dim=1, keepdim=True)[0]
113-
max_abs_[max_abs_ == 0.0] = 1.0
114+
max_abs_[max_abs_ < VERY_SMALL_VALUE] = 1.0
114115
scale = max_abs_
115116
if future_targets is not None:
116117
future_targets = future_targets / scale
@@ -122,8 +123,8 @@ def transform(self,
122123
elif self.mode == 'mean_abs':
123124
mean_abs = torch.sum(torch.abs(valid_past_targets), dim=1, keepdim=True) / valid_past_obs
124125
scale = torch.where(mean_abs == 0.0, valid_past_targets[:, [-1]], mean_abs)
125-
# in case that all values in the tensor is 0
126-
scale[scale == 0.0] = 1.0
126+
# in case all values in the tensor are too small
127+
scale[scale < VERY_SMALL_VALUE] = 1.0
127128
if future_targets is not None:
128129
future_targets = future_targets / scale
129130

test/test_pipeline/components/preprocessing/forecasting/test_scaling.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def setUp(self) -> None:
2626

2727
columns = ['f1', 's', 'f2']
2828
self.raw_data = [data_seq_1, data_seq_2]
29+
2930
self.data = pd.DataFrame(np.concatenate([data_seq_1, data_seq_2]), columns=columns, index=[0] * 3 + [1] * 4)
3031
self.static_features = ('s',)
3132
self.static_features_column = (1, )
@@ -37,6 +38,12 @@ def setUp(self) -> None:
3738
'numerical_columns': numerical_columns,
3839
'static_features': self.static_features,
3940
'is_small_preprocess': True}
41+
very_small_values = np.array([[1e-10, 0., 1e-15],
42+
[1e-10, 0., 1e-15]])
43+
44+
self.small_data = pd.DataFrame(np.array([[1e-10, 0., 1e-15],
45+
[-1e-10, 0., +1e-15]]), columns=columns, index=[0] * 2)
46+
4047

4148
def test_base_and_standard_scaler(self):
4249
scaler_component = BaseScaler(scaling_mode='standard')
@@ -82,6 +89,10 @@ def test_base_and_standard_scaler(self):
8289
transformed_test = np.concatenate([scaler.transform(raw_data) for raw_data in self.raw_data])
8390
self.assertTrue(np.allclose(transformed_test[:, [0, -1]], transformed_test[:, [0, -1]]))
8491

92+
scaler.dataset_is_small_preprocess = True
93+
scaler.fit(self.small_data)
94+
self.assertTrue(np.allclose(scaler.scale.values.flatten(), np.array([1.41421356e-10, 1., 1.])))
95+
8596
def test_min_max(self):
8697
scaler = TimeSeriesScaler(mode='min_max',
8798
static_features=self.static_features
@@ -109,6 +120,10 @@ def test_min_max(self):
109120
transformed_test = np.concatenate([scaler.transform(raw_data) for raw_data in self.raw_data])
110121
self.assertTrue(np.allclose(transformed_test[:, [0, -1]], transformed_test[:, [0, -1]]))
111122

123+
scaler.dataset_is_small_preprocess = True
124+
scaler.fit(self.small_data)
125+
self.assertTrue(np.all(scaler.scale.values.flatten() == np.array([2e-10, 1., 1.])))
126+
112127
def test_max_abs_scaler(self):
113128
scaler = TimeSeriesScaler(mode='max_abs',
114129
static_features=self.static_features
@@ -136,6 +151,10 @@ def test_max_abs_scaler(self):
136151
transformed_test = np.concatenate([scaler.transform(raw_data) for raw_data in self.raw_data])
137152
self.assertTrue(np.allclose(transformed_test[:, [0, -1]], transformed_test[:, [0, -1]]))
138153

154+
scaler.dataset_is_small_preprocess = True
155+
scaler.fit(self.small_data)
156+
self.assertTrue(np.all(scaler.scale.values.flatten() == np.array([1e-10, 1., 1.])))
157+
139158
def test_mean_abs_scaler(self):
140159
scaler = TimeSeriesScaler(mode='mean_abs',
141160
static_features=self.static_features
@@ -162,6 +181,10 @@ def test_mean_abs_scaler(self):
162181
transformed_test = np.concatenate([scaler.transform(raw_data) for raw_data in self.raw_data])
163182
self.assertTrue(np.allclose(transformed_test[:, [0, -1]], transformed_test[:, [0, -1]]))
164183

184+
scaler.dataset_is_small_preprocess = True
185+
scaler.fit(self.small_data)
186+
self.assertTrue(np.all(scaler.scale.values.flatten() == np.array([1e-10, 1., 1.])))
187+
165188
def test_no_scaler(self):
166189
scaler = TimeSeriesScaler(mode='none',
167190
static_features=self.static_features

test/test_pipeline/components/setup/forecasting/test_forecasting_target_scaling.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ def test_target_mean_abs_scalar(self):
9595

9696
self.assertIsNone(loc_full)
9797

98+
_, _, _, scale = scalar(
99+
torch.Tensor([[1e-10, 1e-10, 1e-10],[1e-15,1e-15, 1e-15]]).reshape([2, 3, 1])
100+
)
101+
self.assertTrue(torch.equal(scale.flatten(), torch.Tensor([1e-10, 1.])))
102+
98103
def test_target_standard_scalar(self):
99104
X = {'dataset_properties': {}}
100105
scalar = BaseTargetScaler(scaling_mode='standard')
@@ -178,6 +183,11 @@ def test_target_standard_scalar(self):
178183
self.assertTrue(torch.equal(loc, loc_full))
179184
self.assertTrue(torch.equal(scale, scale_full))
180185

186+
_, _, _, scale = scalar(
187+
torch.Tensor([[1e-10, -1e-10, 1e-10],[1e-15, -1e-15, 1e-15]]).reshape([2, 3, 1])
188+
)
189+
self.assertTrue(torch.all(torch.isclose(scale.flatten(), torch.Tensor([1.1547e-10, 1.]))))
190+
181191
def test_target_min_max_scalar(self):
182192
X = {'dataset_properties': {}}
183193
scalar = BaseTargetScaler(scaling_mode='min_max')
@@ -245,6 +255,11 @@ def test_target_min_max_scalar(self):
245255
self.assertTrue(torch.equal(transformed_future_targets_full, transformed_future_targets_full))
246256
self.assertTrue(torch.equal(scale, scale_full))
247257

258+
_, _, _, scale = scalar(
259+
torch.Tensor([[1e-10, 1e-10, 1e-10],[1e-15,1e-15, 1e-15]]).reshape([2, 3, 1])
260+
)
261+
self.assertTrue(torch.equal(scale.flatten(), torch.Tensor([1e-10, 1.])))
262+
248263
def test_target_max_abs_scalar(self):
249264
X = {'dataset_properties': {}}
250265
scalar = BaseTargetScaler(scaling_mode='max_abs')
@@ -309,3 +324,8 @@ def test_target_max_abs_scalar(self):
309324
self.assertTrue(torch.equal(transformed_future_targets_full, transformed_future_targets_full))
310325
self.assertIsNone(loc_full)
311326
self.assertTrue(torch.equal(scale, scale_full))
327+
328+
_, _, _, scale = scalar(
329+
torch.Tensor([[1e-10, 1e-10, 1e-10],[1e-15,1e-15, 1e-15]]).reshape([2, 3, 1])
330+
)
331+
self.assertTrue(torch.equal(scale.flatten(), torch.Tensor([1e-10, 1.])))

0 commit comments

Comments
 (0)