Skip to content

Commit 6823a3e

Browse files
authored
[Fix]fix data transform error (#779)
* [Fix]fix data transform error * [Update]add 'auto_collation' for some examples
1 parent 583f2fb commit 6823a3e

File tree

6 files changed

+42
-33
lines changed

6 files changed

+42
-33
lines changed

docs/zh/examples/tempoGAN.md

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ examples/tempoGAN/tempoGAN.py:78:94
154154

155155
本问题采用无监督学习的方式,虽然不是以监督学习方式进行训练,但此处仍然可以采用监督约束 `SupervisedConstraint`,在定义约束之前,需要给监督约束指定文件路径等数据读取配置,因为 tempoGAN 属于自监督学习,数据集中没有标签数据,而是使用一部分输入数据作为 `label`,因此需要设置约束的 `output_expr`
156156

157-
``` py linenums="122"
157+
``` py linenums="123"
158158
--8<--
159-
examples/tempoGAN/tempoGAN.py:122:125
159+
examples/tempoGAN/tempoGAN.py:123:126
160160
--8<--
161161
```
162162

@@ -166,7 +166,7 @@ examples/tempoGAN/tempoGAN.py:122:125
166166

167167
``` py linenums="98"
168168
--8<--
169-
examples/tempoGAN/tempoGAN.py:98:127
169+
examples/tempoGAN/tempoGAN.py:98:129
170170
--8<--
171171
```
172172

@@ -177,7 +177,9 @@ examples/tempoGAN/tempoGAN.py:98:127
177177
3. `label`: Array 类型的标签数据;
178178
4. `transforms`: 所有数据 transform 方法,此处 `FunctionalTransform` 为PaddleScience 预留的自定义数据 transform 类,该类支持编写代码时自定义输入数据的 transform,具体代码请参考 [自定义 loss 和 data transform](#38)
179179

180-
`batch_size` 字段表示 batch的大小;
180+
`auto_collation` 字段表示允许 BatchSampler 自动排序;
181+
182+
`batch_size` 字段表示 batch 的大小;
181183

182184
`sampler` 字段表示采样方法,其中各个字段表示:
183185

@@ -193,27 +195,27 @@ examples/tempoGAN/tempoGAN.py:98:127
193195

194196
在约束构建完毕之后,以我们刚才的命名为关键字,封装到一个字典中,方便后续访问,由于本问题设置了`use_spatialdisc``use_tempodisc`,导致 Generator 的部分约束不一定存在,因此先封装一定存在的约束到字典中,当其余约束存在时,在向字典中添加约束元素。
195197

196-
``` py linenums="129"
198+
``` py linenums="130"
197199
--8<--
198-
examples/tempoGAN/tempoGAN.py:129:160
200+
examples/tempoGAN/tempoGAN.py:130:162
199201
--8<--
200202
```
201203

202204
#### 3.6.2 Discriminator 的约束
203205

204-
``` py linenums="164"
206+
``` py linenums="166"
205207
--8<--
206-
examples/tempoGAN/tempoGAN.py:164:201
208+
examples/tempoGAN/tempoGAN.py:166:204
207209
--8<--
208210
```
209211

210212
各个参数含义与[Generator 的约束](#361)相同。
211213

212214
#### 3.6.3 Discriminator_tempo 的约束
213215

214-
``` py linenums="205"
216+
``` py linenums="208"
215217
--8<--
216-
examples/tempoGAN/tempoGAN.py:205:244
218+
examples/tempoGAN/tempoGAN.py:208:248
217219
--8<--
218220
```
219221

@@ -279,9 +281,9 @@ examples/tempoGAN/functions.py:430:488
279281

280282
完成上述设置之后,首先需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练。
281283

282-
``` py linenums="247"
284+
``` py linenums="251"
283285
--8<--
284-
examples/tempoGAN/tempoGAN.py:247:258
286+
examples/tempoGAN/tempoGAN.py:251:262
285287
--8<--
286288
```
287289

@@ -293,15 +295,15 @@ examples/tempoGAN/tempoGAN.py:247:258
293295

294296
训练中仅在特定 `Epoch` 保存特定图片的目标结果和模型输出结果,训练结束后针对最后一个 `Epoch` 的输出结果进行一次评估,以便直观评价模型优化效果。不使用 PaddleScience 中内置的评估器,也不在训练过程中进行评估:
295297

296-
``` py linenums="287"
298+
``` py linenums="291"
297299
--8<--
298-
examples/tempoGAN/tempoGAN.py:287:293
300+
examples/tempoGAN/tempoGAN.py:291:297
299301
--8<--
300302
```
301303

302-
``` py linenums="307"
304+
``` py linenums="311"
303305
--8<--
304-
examples/tempoGAN/tempoGAN.py:307:323
306+
examples/tempoGAN/tempoGAN.py:311:327
305307
--8<--
306308
```
307309

@@ -311,17 +313,17 @@ examples/tempoGAN/tempoGAN.py:307:323
311313

312314
本问题的评估指标为,将模型输出的超分结果与实际高分辨率图片做对比,使用三个指标 MSE(Mean-Square Error) 、PSNR(Peak Signal-to-Noise Ratio) 、SSIM(Structural SIMilarity) 来评价图片相似度。因此没有使用 PaddleScience 中的内置评估器,也没有 `Solver.eval()` 过程。
313315

314-
``` py linenums="326"
316+
``` py linenums="330"
315317
--8<--
316-
examples/tempoGAN/tempoGAN.py:326:406
318+
examples/tempoGAN/tempoGAN.py:330:410
317319
--8<--
318320
```
319321

320322
另外,其中:
321323

322-
``` py linenums="396"
324+
``` py linenums="400"
323325
--8<--
324-
examples/tempoGAN/tempoGAN.py:396:403
326+
examples/tempoGAN/tempoGAN.py:400:407
325327
--8<--
326328
```
327329

docs/zh/examples/topopt.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ examples/topopt/topopt.py:36:38
115115

116116
``` py linenums="102"
117117
--8<--
118-
examples/topopt/functions.py:102:133
118+
examples/topopt/functions.py:102:135
119119
--8<--
120120
```
121121

@@ -125,7 +125,7 @@ examples/topopt/functions.py:102:133
125125

126126
``` py linenums="50"
127127
--8<--
128-
examples/topopt/topopt.py:50:75
128+
examples/topopt/topopt.py:50:76
129129
--8<--
130130
```
131131

@@ -136,7 +136,7 @@ examples/topopt/topopt.py:50:75
136136
3. `label`: 标签变量字典:`{"label_name": label_dataset}`
137137
4. `transforms`: 数据集预处理配,其中 `"FunctionalTransform"` 为用户自定义的预处理方式。
138138

139-
读取配置中 `"batch_size"` 字段表示训练时指定的批大小,`"sampler"` 字段表示 dataloader 的相关采样配置。
139+
读取配置中 `auto_collation` 字段表示允许 BatchSampler 自动排序, `batch_size` 字段表示训练时指定的批大小,`sampler` 字段表示 dataloader 的相关采样配置。
140140

141141
第二个参数是损失函数,这里使用[自定义损失](#381),通过 `cfg.vol_coeff` 确定损失公式中 $\beta$ 对应的值。
142142

@@ -194,9 +194,9 @@ $$
194194

195195
loss 构建代码如下:
196196

197-
``` py linenums="263"
197+
``` py linenums="264"
198198
--8<--
199-
examples/topopt/topopt.py:263:274
199+
examples/topopt/topopt.py:264:275
200200
--8<--
201201
```
202202

@@ -215,9 +215,9 @@ $$
215215
其中 $n_{0} = w_{00} + w_{01}$ , $n_{1} = w_{10} + w_{11}$ ,$w_{tp}$ 表示实际是 $t$ 类且被预测为 $p$ 类的像素点的数量
216216
metric 构建代码如下:
217217

218-
``` py linenums="277"
218+
``` py linenums="278"
219219
--8<--
220-
examples/topopt/topopt.py:277:317
220+
examples/topopt/topopt.py:278:318
221221
--8<--
222222
```
223223

@@ -233,9 +233,9 @@ examples/topopt/conf/topopt.yaml:29:31
233233

234234
训练代码如下:
235235

236-
``` py linenums="77"
236+
``` py linenums="78"
237237
--8<--
238-
examples/topopt/topopt.py:77:111
238+
examples/topopt/topopt.py:78:111
239239
--8<--
240240
```
241241

@@ -249,7 +249,7 @@ examples/topopt/topopt.py:77:111
249249

250250
``` py linenums="218"
251251
--8<--
252-
examples/topopt/topopt.py:218:245
252+
examples/topopt/topopt.py:218:246
253253
--8<--
254254
```
255255

examples/tempoGAN/tempoGAN.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def train(cfg: DictConfig):
111111
},
112112
),
113113
},
114+
"auto_collation": True,
114115
"batch_size": cfg.TRAIN.batch_size.sup_constraint,
115116
"sampler": {
116117
"name": "BatchSampler",
@@ -143,6 +144,7 @@ def train(cfg: DictConfig):
143144
},
144145
),
145146
},
147+
"auto_collation": True,
146148
"batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
147149
"sampler": {
148150
"name": "BatchSampler",
@@ -188,6 +190,7 @@ def train(cfg: DictConfig):
188190
},
189191
),
190192
},
193+
"auto_collation": True,
191194
"batch_size": cfg.TRAIN.batch_size.sup_constraint,
192195
"sampler": {
193196
"name": "BatchSampler",
@@ -229,6 +232,7 @@ def train(cfg: DictConfig):
229232
},
230233
),
231234
},
235+
"auto_collation": True,
232236
"batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
233237
"sampler": {
234238
"name": "BatchSampler",

examples/topopt/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def augmentation(
113113
"""
114114
inputs = input_dict["input"]
115115
labels = label_dict["output"]
116+
assert len(inputs.shape) == 3
117+
assert len(labels.shape) == 3
116118

117119
# random horizontal flip
118120
if np.random.random() > 0.5:
@@ -125,7 +127,7 @@ def augmentation(
125127
# random 90* rotation
126128
if np.random.random() > 0.5:
127129
new_perm = list(range(len(inputs.shape)))
128-
new_perm[1], new_perm[2] = new_perm[2], new_perm[1]
130+
new_perm[-2], new_perm[-1] = new_perm[-1], new_perm[-2]
129131
inputs = np.transpose(inputs, new_perm)
130132
labels = np.transpose(labels, new_perm)
131133

examples/topopt/topopt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def train(cfg: DictConfig):
6262
},
6363
),
6464
},
65+
"auto_collation": True,
6566
"batch_size": cfg.TRAIN.batch_size,
6667
"sampler": {
6768
"name": "BatchSampler",
@@ -76,7 +77,6 @@ def train(cfg: DictConfig):
7677

7778
# train models for 4 cases
7879
for sampler_key, num in cfg.CASE_PARAM:
79-
8080
# initialize SIMP iteration stop time sampler
8181
SIMP_stop_point_sampler = func_module.generate_sampler(sampler_key, num)
8282

@@ -229,6 +229,7 @@ def evaluate_model(
229229
},
230230
),
231231
},
232+
"auto_collation": True,
232233
"batch_size": cfg.EVAL.batch_size,
233234
"sampler": {
234235
"name": "BatchSampler",

examples/topopt/topoptmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class TopOptNN(ppsci.arch.UNetEx):
3939
4040
Examples:
4141
>>> import ppsci
42-
>>> model = ppsci.arch.ppsci.arch.UNetEx("input", "output", 2, 1, 3, (16, 32, 64), 2, lambda: 1, Flase, False)
42+
>>> model = ppsci.arch.ppsci.arch.TopOptNN("input", "output", 2, 1, 3, (16, 32, 64), 2, lambda: 1, Flase, False)
4343
"""
4444

4545
def __init__(

0 commit comments

Comments
 (0)