Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…ience into develop
  • Loading branch information
HydrogenSulfate committed Mar 2, 2024
2 parents 835ea6a + f2d4c91 commit 010fd80
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
3 changes: 1 addition & 2 deletions examples/epnn/conf/epnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ TRAIN:
iters_per_epoch: 1
save_freq: 50
eval_during_train: true
eval_with_no_grad: true
lr_scheduler:
epochs: ${TRAIN.epochs}
iters_per_epoch: ${TRAIN.iters_per_epoch}
gamma: 0.97
decay_steps: 1
decay_steps: 10000000
pretrained_model_path: null
checkpoint_path: null

Expand Down
2 changes: 1 addition & 1 deletion examples/epnn/epnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _transform_in_stress(_in):
save_freq=cfg.TRAIN.save_freq,
eval_during_train=cfg.TRAIN.eval_during_train,
validator=validator_pde,
eval_with_no_grad=cfg.TRAIN.eval_with_no_grad,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)

# train model
Expand Down
8 changes: 5 additions & 3 deletions examples/epnn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ def get(self, epochs=1):
label_dict_train = {"dummy_loss": []}
label_dict_val = {"dummy_loss": []}
for i in range(epochs):
shuffled_indices = np.random.permutation(self.data_state.x_train.shape[0])
shuffled_indices = paddle.randperm(
n=self.data_state.x_train.shape[0]
).numpy()
input_dict_train["state_x"].append(
self.data_state.x_train[shuffled_indices[0 : self.itrain]]
)
Expand All @@ -263,7 +265,7 @@ def get(self, epochs=1):
)
label_dict_train["dummy_loss"].append(0.0)

shuffled_indices = np.random.permutation(self.data_state.x_valid.shape[0])
shuffled_indices = paddle.randperm(n=self.data_state.x_valid.shape[0]).numpy()
input_dict_val["state_x"].append(
self.data_state.x_valid[shuffled_indices[0 : self.itrain]]
)
Expand Down Expand Up @@ -296,7 +298,7 @@ def __init__(self, dataset_path, train_p=0.6, cross_valid_p=0.2, test_p=0.2):
def get_shuffled_data(self):
# Need to set the seed, otherwise the loss will not match the precision
ppsci.utils.misc.set_random_seed(seed=10)
shuffled_indices = np.random.permutation(self.x.shape[0])
shuffled_indices = paddle.randperm(n=self.x.shape[0]).numpy()
n_train = math.floor(self.train_p * self.x.shape[0])
n_cross_valid = math.floor(self.cross_valid_p * self.x.shape[0])
n_test = math.floor(self.test_p * self.x.shape[0])
Expand Down
11 changes: 11 additions & 0 deletions ppsci/data/dataset/vtu_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ class VtuDataset(io.Dataset):
labels (Optional[Dict[str, float]]): Temporary variable for [load_vtk_with_time_file].
transforms (vision.Compose, optional): Compose object contains sample wise.
transform(s).
Examples:
>>> from ppsci.dataset import VtuDataset
>>> dataset = VtuDataset(file_path='example.vtu') # doctest: +SKIP
>>> # get the length of the dataset
>>> dataset_size = len(dataset) # doctest: +SKIP
>>> # get the first sample of the data
>>> first_sample = dataset[0] # doctest: +SKIP
>>> print("First sample:", first_sample)
"""

# Whether support batch indexing for speeding up fetching process.
Expand Down

0 comments on commit 010fd80

Please sign in to comment.