Skip to content

Commit f5a350a

Browse files
committed
Fixed minor argument error.
Signed-off-by: Meet Patel <meetkuma@qti.qualcomm.com>
1 parent 3765183 commit f5a350a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

QEfficient/finetune/utils/train_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def train(
111111
num_classes = model.classifier.out_features
112112
acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
113113

114-
autocast_ctx = get_autocast_ctx(device_type, train_config.use_autocast, dtype=torch.float16)
114+
autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
115115
op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.dump_root_dir)
116116

117117
# Start the training loop
@@ -416,7 +416,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
416416
eval_loss = 0.0 # Initialize evaluation loss
417417
device_type = torch.device(device).type
418418

419-
autocast_ctx = get_autocast_ctx(device_type, train_config.use_autocast, dtype=torch.float16)
419+
autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
420420
for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
421421
# stop when the maximum number of eval steps is reached
422422
if train_config.max_eval_step > 0 and step > train_config.max_eval_step:

0 commit comments

Comments
 (0)