Skip to content

Commit

Permalink
支持0维Tensor需要的修改 (#2621)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zth9730 authored Nov 4, 2022
1 parent 114f380 commit e6d2088
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/cls/custom_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ for epoch in range(1, epochs + 1):
optimizer.clear_grad()

# Calculate loss
avg_loss = loss.numpy()[0]
avg_loss = float(loss)

# Calculate metrics
preds = paddle.argmax(logits, axis=1)
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorial/cls/cls_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@
" optimizer.clear_grad()\n",
"\n",
" # Calculate loss\n",
" avg_loss += loss.numpy()[0]\n",
" avg_loss += float(loss)\n",
"\n",
" # Calculate metrics\n",
" preds = paddle.argmax(logits, axis=1)\n",
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/cls/exps/panns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
optimizer.clear_grad()

# Calculate loss
avg_loss += loss.numpy()[0]
avg_loss += float(loss)

# Calculate metrics
preds = paddle.argmax(logits, axis=1)
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/kws/exps/mdtc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
optimizer.clear_grad()

# Calculate loss
avg_loss += loss.numpy()[0]
avg_loss += float(loss)

# Calculate metrics
num_corrects += corrects
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tts/test_pwg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import torch
from paddle.device.cuda import synchronize
from parallel_wavegan.layers import residual_block
from parallel_wavegan.layers import upsample
from parallel_wavegan.models import parallel_wavegan as pwgan
Expand All @@ -24,7 +25,6 @@
from paddlespeech.t2s.models.parallel_wavegan import ResidualBlock
from paddlespeech.t2s.models.parallel_wavegan import ResidualPWGDiscriminator
from paddlespeech.t2s.utils.layer_tools import summary
from paddlespeech.t2s.utils.profile import synchronize

paddle.set_device("gpu:0")
device = torch.device("cuda:0")
Expand Down

0 comments on commit e6d2088

Please sign in to comment.