Skip to content

Commit 534f3bc

Browse files
committed
Variable is not necessary anymore, similar to Tensor + loss reduction
1 parent c7283df commit 534f3bc

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

train.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import torch
88
import torch.nn as nn
99
import yaml
10-
from torch.autograd import Variable
1110
from torch.backends import cudnn
1211
from torch.utils.data import DataLoader
1312
from torchvision.transforms import transforms
@@ -120,7 +119,6 @@ def apply_ctc_loss(floss, output, target: List[List[int]]):
120119
target = concat(target)
121120
target = torch.Tensor(target)
122121
target = target.long()
123-
target = Variable(target)
124122
target = target.view((-1,))
125123
target = target.to(device)
126124

@@ -173,7 +171,7 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
173171
model = build_model(11, seq_length, args.batch_size).to(device)
174172

175173
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
176-
floss = nn.CTCLoss(blank=10)
174+
floss = nn.CTCLoss(blank=10, reduction='mean')
177175

178176
# Train here
179177
history = []
@@ -195,7 +193,6 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
195193
int_targets = [[int(c) for c in gt] for gt in str_targets]
196194

197195
# Prepare image
198-
image = Variable(image)
199196
image = image.to(device)
200197

201198
# Forward
@@ -204,7 +201,7 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
204201
loss = apply_ctc_loss(floss, output, int_targets)
205202

206203
# Backward
207-
loss.backward()
204+
loss.sum().backward()
208205

209206
# Update
210207
optimizer.step()
@@ -213,7 +210,7 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
213210
total_distance += sum(distances)
214211
accuracy = calc_acc(output, str_targets)
215212
total_accuracy += sum(accuracy)
216-
total_loss += loss.item()
213+
total_loss += loss.sum().item()
217214
num_samples += len(str_targets)
218215

219216
if verbose:
@@ -264,7 +261,6 @@ def test(model: nn.Module, dataloader: DataLoader, verbose: bool = False) -> Dic
264261
int_targets = [[int(c) for c in gt] for gt in str_targets]
265262

266263
# Prepare image
267-
image = Variable(image)
268264
image = image.to(device)
269265

270266
# Forward

0 commit comments

Comments
 (0)