|
40 | 40 | " batch_size=batch_size,\n",
|
41 | 41 | " series_length=series_length,\n",
|
42 | 42 | " truncated_length=BPTT_T,\n",
|
| 43 | + " total_values_in_one_chunck = batch_size * BPTT_T,\n", |
43 | 44 | ")\n",
|
44 | 45 | "train_size = len(train_data)\n",
|
45 | 46 | "\n",
|
|
150 | 151 | " optimizer.step()\n",
|
151 | 152 | " \n",
|
152 | 153 | " pred = (torch.sigmoid(logits) > 0.5)\n",
|
153 |
| - " correct += (pred == target.byte()).int().sum().item()\n", |
| 154 | + " correct += (pred == target.byte()).int().sum().item()/total_values_in_one_chunck\n", |
| 155 | + |
154 | 156 | " \n",
|
155 | 157 | " return correct, loss.item(), hidden"
|
156 | 158 | ]
|
|
171 | 173 | " logits, hidden = model(data, hidden)\n",
|
172 | 174 | " \n",
|
173 | 175 | " pred = (torch.sigmoid(logits) > 0.5)\n",
|
174 |
| - " correct += (pred == target.byte()).int().sum().item()\n", |
| 176 | + " correct += (pred == target.byte()).int().sum().item()/total_values_in_one_chunck\n", |
175 | 177 | "\n",
|
176 | 178 | " return correct"
|
177 | 179 | ]
|
|
208 | 210 | "while epoch < n_epochs:\n",
|
209 | 211 | " correct, loss, hidden = train(hidden)\n",
|
210 | 212 | " epoch += 1\n",
|
211 |
| - " train_accuracy = float(correct) / train_size\n", |
| 213 | + " train_accuracy = float(correct)*100/ train_size\n", |
212 | 214 | " print(f'Train Epoch: {epoch}/{n_epochs}, loss: {loss:.3f}, accuracy {train_accuracy:.1f}%')\n",
|
213 | 215 | "\n",
|
214 | 216 | "#test \n",
|
|
0 commit comments