Skip to content

Commit 010fd80

Browse files
Merge branch 'develop' of https://github.com/HydrogenSulfate/PaddleScience into develop
2 parents 835ea6a + f2d4c91 commit 010fd80

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

examples/epnn/conf/epnn.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,11 @@ TRAIN:
4141
iters_per_epoch: 1
4242
save_freq: 50
4343
eval_during_train: true
44-
eval_with_no_grad: true
4544
lr_scheduler:
4645
epochs: ${TRAIN.epochs}
4746
iters_per_epoch: ${TRAIN.iters_per_epoch}
4847
gamma: 0.97
49-
decay_steps: 1
48+
decay_steps: 10000000
5049
pretrained_model_path: null
5150
checkpoint_path: null
5251

examples/epnn/epnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _transform_in_stress(_in):
114114
save_freq=cfg.TRAIN.save_freq,
115115
eval_during_train=cfg.TRAIN.eval_during_train,
116116
validator=validator_pde,
117-
eval_with_no_grad=cfg.TRAIN.eval_with_no_grad,
117+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
118118
)
119119

120120
# train model

examples/epnn/functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ def get(self, epochs=1):
248248
label_dict_train = {"dummy_loss": []}
249249
label_dict_val = {"dummy_loss": []}
250250
for i in range(epochs):
251-
shuffled_indices = np.random.permutation(self.data_state.x_train.shape[0])
251+
shuffled_indices = paddle.randperm(
252+
n=self.data_state.x_train.shape[0]
253+
).numpy()
252254
input_dict_train["state_x"].append(
253255
self.data_state.x_train[shuffled_indices[0 : self.itrain]]
254256
)
@@ -263,7 +265,7 @@ def get(self, epochs=1):
263265
)
264266
label_dict_train["dummy_loss"].append(0.0)
265267

266-
shuffled_indices = np.random.permutation(self.data_state.x_valid.shape[0])
268+
shuffled_indices = paddle.randperm(n=self.data_state.x_valid.shape[0]).numpy()
267269
input_dict_val["state_x"].append(
268270
self.data_state.x_valid[shuffled_indices[0 : self.itrain]]
269271
)
@@ -296,7 +298,7 @@ def __init__(self, dataset_path, train_p=0.6, cross_valid_p=0.2, test_p=0.2):
296298
def get_shuffled_data(self):
297299
# Need to set the seed, otherwise the loss will not match the precision
298300
ppsci.utils.misc.set_random_seed(seed=10)
299-
shuffled_indices = np.random.permutation(self.x.shape[0])
301+
shuffled_indices = paddle.randperm(n=self.x.shape[0]).numpy()
300302
n_train = math.floor(self.train_p * self.x.shape[0])
301303
n_cross_valid = math.floor(self.cross_valid_p * self.x.shape[0])
302304
n_test = math.floor(self.test_p * self.x.shape[0])

ppsci/data/dataset/vtu_dataset.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ class VtuDataset(io.Dataset):
3737
labels (Optional[Dict[str, float]]): Temporary variable for [load_vtk_with_time_file].
3838
transforms (vision.Compose, optional): Compose object contains sample wise.
3939
transform(s).
40+
41+
Examples:
42+
>>> from ppsci.data.dataset import VtuDataset
43+
44+
>>> dataset = VtuDataset(file_path='example.vtu') # doctest: +SKIP
45+
46+
>>> # get the length of the dataset
47+
>>> dataset_size = len(dataset) # doctest: +SKIP
48+
>>> # get the first sample of the data
49+
>>> first_sample = dataset[0] # doctest: +SKIP
50+
>>> print("First sample:", first_sample) # doctest: +SKIP
4051
"""
4152

4253
# Whether support batch indexing for speeding up fetching process.

0 commit comments

Comments
 (0)