Skip to content

Commit 9ded043

Browse files
authored
Revert "[Fix]fix data transform error (#779)" (#781)
This reverts commit 127505a.
1 parent 127505a commit 9ded043

File tree

6 files changed

+33
-42
lines changed

6 files changed

+33
-42
lines changed

docs/zh/examples/tempoGAN.md

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ examples/tempoGAN/tempoGAN.py:78:94
154154

155155
本问题采用无监督学习的方式,虽然不是以监督学习方式进行训练,但此处仍然可以采用监督约束 `SupervisedConstraint`,在定义约束之前,需要给监督约束指定文件路径等数据读取配置,因为 tempoGAN 属于自监督学习,数据集中没有标签数据,而是使用一部分输入数据作为 `label`,因此需要设置约束的 `output_expr`
156156

157-
``` py linenums="123"
157+
``` py linenums="122"
158158
--8<--
159-
examples/tempoGAN/tempoGAN.py:123:126
159+
examples/tempoGAN/tempoGAN.py:122:125
160160
--8<--
161161
```
162162

@@ -166,7 +166,7 @@ examples/tempoGAN/tempoGAN.py:123:126
166166

167167
``` py linenums="98"
168168
--8<--
169-
examples/tempoGAN/tempoGAN.py:98:129
169+
examples/tempoGAN/tempoGAN.py:98:127
170170
--8<--
171171
```
172172

@@ -177,9 +177,7 @@ examples/tempoGAN/tempoGAN.py:98:129
177177
3. `label`: Array 类型的标签数据;
178178
4. `transforms`: 所有数据 transform 方法,此处 `FunctionalTransform` 为PaddleScience 预留的自定义数据 transform 类,该类支持编写代码时自定义输入数据的 transform,具体代码请参考 [自定义 loss 和 data transform](#38)
179179

180-
`auto_collation` 字段表示允许 BatchSampler 自动排序;
181-
182-
`batch_size` 字段表示 batch 的大小;
180+
`batch_size` 字段表示 batch的大小;
183181

184182
`sampler` 字段表示采样方法,其中各个字段表示:
185183

@@ -195,27 +193,27 @@ examples/tempoGAN/tempoGAN.py:98:129
195193

196194
在约束构建完毕之后,以我们刚才的命名为关键字,封装到一个字典中,方便后续访问,由于本问题设置了`use_spatialdisc``use_tempodisc`,导致 Generator 的部分约束不一定存在,因此先封装一定存在的约束到字典中,当其余约束存在时,在向字典中添加约束元素。
197195

198-
``` py linenums="130"
196+
``` py linenums="129"
199197
--8<--
200-
examples/tempoGAN/tempoGAN.py:130:162
198+
examples/tempoGAN/tempoGAN.py:129:160
201199
--8<--
202200
```
203201

204202
#### 3.6.2 Discriminator 的约束
205203

206-
``` py linenums="166"
204+
``` py linenums="164"
207205
--8<--
208-
examples/tempoGAN/tempoGAN.py:166:204
206+
examples/tempoGAN/tempoGAN.py:164:201
209207
--8<--
210208
```
211209

212210
各个参数含义与[Generator 的约束](#361)相同。
213211

214212
#### 3.6.3 Discriminator_tempo 的约束
215213

216-
``` py linenums="208"
214+
``` py linenums="205"
217215
--8<--
218-
examples/tempoGAN/tempoGAN.py:208:248
216+
examples/tempoGAN/tempoGAN.py:205:244
219217
--8<--
220218
```
221219

@@ -281,9 +279,9 @@ examples/tempoGAN/functions.py:430:488
281279

282280
完成上述设置之后,首先需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练。
283281

284-
``` py linenums="251"
282+
``` py linenums="247"
285283
--8<--
286-
examples/tempoGAN/tempoGAN.py:251:262
284+
examples/tempoGAN/tempoGAN.py:247:258
287285
--8<--
288286
```
289287

@@ -295,15 +293,15 @@ examples/tempoGAN/tempoGAN.py:251:262
295293

296294
训练中仅在特定 `Epoch` 保存特定图片的目标结果和模型输出结果,训练结束后针对最后一个 `Epoch` 的输出结果进行一次评估,以便直观评价模型优化效果。不使用 PaddleScience 中内置的评估器,也不在训练过程中进行评估:
297295

298-
``` py linenums="291"
296+
``` py linenums="287"
299297
--8<--
300-
examples/tempoGAN/tempoGAN.py:291:297
298+
examples/tempoGAN/tempoGAN.py:287:293
301299
--8<--
302300
```
303301

304-
``` py linenums="311"
302+
``` py linenums="307"
305303
--8<--
306-
examples/tempoGAN/tempoGAN.py:311:327
304+
examples/tempoGAN/tempoGAN.py:307:323
307305
--8<--
308306
```
309307

@@ -313,17 +311,17 @@ examples/tempoGAN/tempoGAN.py:311:327
313311

314312
本问题的评估指标为,将模型输出的超分结果与实际高分辨率图片做对比,使用三个指标 MSE(Mean-Square Error) 、PSNR(Peak Signal-to-Noise Ratio) 、SSIM(Structural SIMilarity) 来评价图片相似度。因此没有使用 PaddleScience 中的内置评估器,也没有 `Solver.eval()` 过程。
315313

316-
``` py linenums="330"
314+
``` py linenums="326"
317315
--8<--
318-
examples/tempoGAN/tempoGAN.py:330:410
316+
examples/tempoGAN/tempoGAN.py:326:406
319317
--8<--
320318
```
321319

322320
另外,其中:
323321

324-
``` py linenums="400"
322+
``` py linenums="396"
325323
--8<--
326-
examples/tempoGAN/tempoGAN.py:400:407
324+
examples/tempoGAN/tempoGAN.py:396:403
327325
--8<--
328326
```
329327

docs/zh/examples/topopt.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ examples/topopt/topopt.py:36:38
115115

116116
``` py linenums="102"
117117
--8<--
118-
examples/topopt/functions.py:102:135
118+
examples/topopt/functions.py:102:133
119119
--8<--
120120
```
121121

@@ -125,7 +125,7 @@ examples/topopt/functions.py:102:135
125125

126126
``` py linenums="50"
127127
--8<--
128-
examples/topopt/topopt.py:50:76
128+
examples/topopt/topopt.py:50:75
129129
--8<--
130130
```
131131

@@ -136,7 +136,7 @@ examples/topopt/topopt.py:50:76
136136
3. `label`: 标签变量字典:`{"label_name": label_dataset}`
137137
4. `transforms`: 数据集预处理配,其中 `"FunctionalTransform"` 为用户自定义的预处理方式。
138138

139-
读取配置中 `auto_collation` 字段表示允许 BatchSampler 自动排序, `batch_size` 字段表示训练时指定的批大小,`sampler` 字段表示 dataloader 的相关采样配置。
139+
读取配置中 `"batch_size"` 字段表示训练时指定的批大小,`"sampler"` 字段表示 dataloader 的相关采样配置。
140140

141141
第二个参数是损失函数,这里使用[自定义损失](#381),通过 `cfg.vol_coeff` 确定损失公式中 $\beta$ 对应的值。
142142

@@ -194,9 +194,9 @@ $$
194194

195195
loss 构建代码如下:
196196

197-
``` py linenums="264"
197+
``` py linenums="263"
198198
--8<--
199-
examples/topopt/topopt.py:264:275
199+
examples/topopt/topopt.py:263:274
200200
--8<--
201201
```
202202

@@ -215,9 +215,9 @@ $$
215215
其中 $n_{0} = w_{00} + w_{01}$ , $n_{1} = w_{10} + w_{11}$ ,$w_{tp}$ 表示实际是 $t$ 类且被预测为 $p$ 类的像素点的数量
216216
metric 构建代码如下:
217217

218-
``` py linenums="278"
218+
``` py linenums="277"
219219
--8<--
220-
examples/topopt/topopt.py:278:318
220+
examples/topopt/topopt.py:277:317
221221
--8<--
222222
```
223223

@@ -233,9 +233,9 @@ examples/topopt/conf/topopt.yaml:29:31
233233

234234
训练代码如下:
235235

236-
``` py linenums="78"
236+
``` py linenums="77"
237237
--8<--
238-
examples/topopt/topopt.py:78:111
238+
examples/topopt/topopt.py:77:111
239239
--8<--
240240
```
241241

@@ -249,7 +249,7 @@ examples/topopt/topopt.py:78:111
249249

250250
``` py linenums="218"
251251
--8<--
252-
examples/topopt/topopt.py:218:246
252+
examples/topopt/topopt.py:218:245
253253
--8<--
254254
```
255255

examples/tempoGAN/tempoGAN.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ def train(cfg: DictConfig):
111111
},
112112
),
113113
},
114-
"auto_collation": True,
115114
"batch_size": cfg.TRAIN.batch_size.sup_constraint,
116115
"sampler": {
117116
"name": "BatchSampler",
@@ -144,7 +143,6 @@ def train(cfg: DictConfig):
144143
},
145144
),
146145
},
147-
"auto_collation": True,
148146
"batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
149147
"sampler": {
150148
"name": "BatchSampler",
@@ -190,7 +188,6 @@ def train(cfg: DictConfig):
190188
},
191189
),
192190
},
193-
"auto_collation": True,
194191
"batch_size": cfg.TRAIN.batch_size.sup_constraint,
195192
"sampler": {
196193
"name": "BatchSampler",
@@ -232,7 +229,6 @@ def train(cfg: DictConfig):
232229
},
233230
),
234231
},
235-
"auto_collation": True,
236232
"batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
237233
"sampler": {
238234
"name": "BatchSampler",

examples/topopt/functions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ def augmentation(
113113
"""
114114
inputs = input_dict["input"]
115115
labels = label_dict["output"]
116-
assert len(inputs.shape) == 3
117-
assert len(labels.shape) == 3
118116

119117
# random horizontal flip
120118
if np.random.random() > 0.5:
@@ -127,7 +125,7 @@ def augmentation(
127125
# random 90* rotation
128126
if np.random.random() > 0.5:
129127
new_perm = list(range(len(inputs.shape)))
130-
new_perm[-2], new_perm[-1] = new_perm[-1], new_perm[-2]
128+
new_perm[1], new_perm[2] = new_perm[2], new_perm[1]
131129
inputs = np.transpose(inputs, new_perm)
132130
labels = np.transpose(labels, new_perm)
133131

examples/topopt/topopt.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def train(cfg: DictConfig):
6262
},
6363
),
6464
},
65-
"auto_collation": True,
6665
"batch_size": cfg.TRAIN.batch_size,
6766
"sampler": {
6867
"name": "BatchSampler",
@@ -77,6 +76,7 @@ def train(cfg: DictConfig):
7776

7877
# train models for 4 cases
7978
for sampler_key, num in cfg.CASE_PARAM:
79+
8080
# initialize SIMP iteration stop time sampler
8181
SIMP_stop_point_sampler = func_module.generate_sampler(sampler_key, num)
8282

@@ -229,7 +229,6 @@ def evaluate_model(
229229
},
230230
),
231231
},
232-
"auto_collation": True,
233232
"batch_size": cfg.EVAL.batch_size,
234233
"sampler": {
235234
"name": "BatchSampler",

examples/topopt/topoptmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class TopOptNN(ppsci.arch.UNetEx):
3939
4040
Examples:
4141
>>> import ppsci
42-
>>> model = ppsci.arch.ppsci.arch.TopOptNN("input", "output", 2, 1, 3, (16, 32, 64), 2, lambda: 1, Flase, False)
42+
>>> model = ppsci.arch.ppsci.arch.UNetEx("input", "output", 2, 1, 3, (16, 32, 64), 2, lambda: 1, Flase, False)
4343
"""
4444

4545
def __init__(

0 commit comments

Comments
 (0)