7
7
import torch
8
8
import torch .nn as nn
9
9
import yaml
10
- from torch .autograd import Variable
11
10
from torch .backends import cudnn
12
11
from torch .utils .data import DataLoader
13
12
from torchvision .transforms import transforms
@@ -120,7 +119,6 @@ def apply_ctc_loss(floss, output, target: List[List[int]]):
120
119
target = concat (target )
121
120
target = torch .Tensor (target )
122
121
target = target .long ()
123
- target = Variable (target )
124
122
target = target .view ((- 1 ,))
125
123
target = target .to (device )
126
124
@@ -173,7 +171,7 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
173
171
model = build_model (11 , seq_length , args .batch_size ).to (device )
174
172
175
173
optimizer = torch .optim .Adam (model .parameters (), lr = args .lr )
176
- floss = nn .CTCLoss (blank = 10 )
174
+ floss = nn .CTCLoss (blank = 10 , reduction = 'mean' )
177
175
178
176
# Train here
179
177
history = []
@@ -195,7 +193,6 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
195
193
int_targets = [[int (c ) for c in gt ] for gt in str_targets ]
196
194
197
195
# Prepare image
198
- image = Variable (image )
199
196
image = image .to (device )
200
197
201
198
# Forward
@@ -204,7 +201,7 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
204
201
loss = apply_ctc_loss (floss , output , int_targets )
205
202
206
203
# Backward
207
- loss .backward ()
204
+ loss .sum (). backward ()
208
205
209
206
# Update
210
207
optimizer .step ()
@@ -213,7 +210,7 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
213
210
total_distance += sum (distances )
214
211
accuracy = calc_acc (output , str_targets )
215
212
total_accuracy += sum (accuracy )
216
- total_loss += loss .item ()
213
+ total_loss += loss .sum (). item ()
217
214
num_samples += len (str_targets )
218
215
219
216
if verbose :
@@ -264,7 +261,6 @@ def test(model: nn.Module, dataloader: DataLoader, verbose: bool = False) -> Dic
264
261
int_targets = [[int (c ) for c in gt ] for gt in str_targets ]
265
262
266
263
# Prepare image
267
- image = Variable (image )
268
264
image = image .to (device )
269
265
270
266
# Forward
0 commit comments