We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents ad0db4c + 4f5dbef commit 08bbf12Copy full SHA for 08bbf12
QEfficient/finetune/utils/train_utils.py
@@ -83,6 +83,7 @@ def train(
83
best_val_loss = float("inf")
84
total_train_steps = 0
85
max_steps_reached = False # Flag to indicate max training steps reached
86
+ device_type = device.split(":")[0]
87
88
tensorboard_updates = None
89
if train_config.enable_ddp:
@@ -95,7 +96,7 @@ def train(
95
96
if device.startswith("qaic"):
97
scaler = QAicGradScaler()
98
else:
- scaler = GradScaler()
99
+ scaler = GradScaler(device_type)
100
101
loss_0_counter = torch.tensor([0]).to(device)
102
0 commit comments