Skip to content

Commit a19e7ed

Browse files
committed
Fix: convert norm_x to float64 to deal with a precision issue
1 parent 877ade9 commit a19e7ed

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

netshare/pre_post_processors/netshare/netshare_pre_post_processor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,12 +343,12 @@ def _post_process(self, input_folder, output_folder,
343343
print(f"{self.__class__.__name__}.{inspect.stack()[0][3]}")
344344

345345
# Denormalize the fields (e.g. int to IP, vector to word, etc.)
346-
# denormalize_fields(
347-
# config_pre_post_processor=self._config,
348-
# pre_processed_data_folder=pre_processed_data_folder,
349-
# generated_data_folder=input_folder,
350-
# post_processed_data_folder=output_folder
351-
# )
346+
denormalize_fields(
347+
config_pre_post_processor=self._config,
348+
pre_processed_data_folder=pre_processed_data_folder,
349+
generated_data_folder=input_folder,
350+
post_processed_data_folder=output_folder
351+
)
352352

353353
# Choose the best generated data across different hyperparameters/checkpoints
354354
choose_best_model(

netshare/utils/field.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,16 @@ def denormalize(self, norm_x):
6464
if norm_x.shape[-1] != self.dim_x:
6565
raise ValueError(f"Dimension is {norm_x.shape[-1]}. "
6666
f"Expected dimension is {self.dim_x}")
67+
norm_x = norm_x.astype(np.float64) # Convert to float64 for precision
68+
6769
# [0, 1] normalization
6870
if self.norm_option == Normalization.ZERO_ONE:
69-
to_return = np.asarray(
70-
norm_x * float(self.max_x - self.min_x) + self.min_x,
71-
dtype=np.float64)
71+
to_return = norm_x * float(self.max_x - self.min_x) + self.min_x
7272

7373
# [-1, 1] normalization
7474
elif self.norm_option == Normalization.MINUSONE_ONE:
75-
to_return = np.asarray((norm_x + 1) / 2.0 *
76-
float(self.max_x - self.min_x) + self.min_x,
77-
dtype=np.float64)
75+
to_return = (norm_x + 1) / 2.0 * \
76+
float(self.max_x - self.min_x) + self.min_x
7877

7978
else:
8079
raise Exception("Not valid normalization option!")

0 commit comments

Comments (0)