Skip to content

Commit 0b33180

Browse files
authored
Update train_scratch.py
1 parent 39ae9ba commit 0b33180

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

lesson63-迁移学习-自定义数据集实战/train_scratch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
viz = visdom.Visdom()
3030

3131
def evalute(model, loader):
32+
model.eval()
33+
3234
correct = 0
3335
total = len(loader.dataset)
3436

@@ -58,7 +60,8 @@ def main():
5860

5961
# x: [b, 3, 224, 224], y: [b]
6062
x, y = x.to(device), y.to(device)
61-
63+
64+
model.train()
6265
logits = model(x)
6366
loss = criteon(logits, y)
6467

@@ -94,4 +97,4 @@ def main():
9497

9598

9699
if __name__ == '__main__':
97-
main()
100+
main()

0 commit comments

Comments
 (0)