quant_train.py: 114 changes (5 additions, 109 deletions)
@@ -174,8 +174,8 @@
 help='path to latest fp32 checkpoint (default: none)')
 best_acc1 = 0
 
-arch_dict = {'q_resmlp': resmlp_24, 'q_resmlp_norm': resmlp_24_norm, 'q_resmlp_v2': resmlp_24, 'q_resmlp_v3': resmlp_24, 'q_resmlp_v4': resmlp_24_v4}
-quantize_arch_dict = {'q_resmlp': q_resmlp, 'q_resmlp_norm': q_resmlp_norm, 'q_resmlp_v2': q_resmlp_v2, 'q_resmlp_v3': q_resmlp_v3, 'q_resmlp_v4': q_resmlp_v4}
+arch_dict = {'q_resmlp': resmlp_24, 'q_resmlp_norm': resmlp_24_norm, 'q_resmlp_v2': resmlp_24, 'q_resmlp_v3': resmlp_24, 'q_resmlp_v4': resmlp_24_v4, 'q_deit_s' : deit_small_patch16_224}
+quantize_arch_dict = {'q_resmlp': q_resmlp, 'q_resmlp_norm': q_resmlp_norm, 'q_resmlp_v2': q_resmlp_v2, 'q_resmlp_v3': q_resmlp_v3, 'q_resmlp_v4': q_resmlp_v4, 'q_deit_s' : q_vit}
 
 args = parser.parse_args()
 if not os.path.exists(args.save_path):
@@ -234,12 +234,12 @@ def main_worker(gpu, ngpus_per_node, args):
 logging.info("=> using pre-trained model '{}'".format(args.arch))
 arch = arch_dict[args.arch]
 model = arch(pretrained=True)
-
+print(model)
 else:
 logging.info("=> creating model '{}'".format(args.arch))
 arch = arch_dict[args.arch]
 model = arch(pretrained=False)
 
 if args.load_pretrain:
 logging.info("=> loading fp32 checkpoint '{}'".format(args.load_pretrain))
 model.load_state_dict(torch.load(args.load_pretrain)['model'])
@@ -284,6 +284,7 @@ def main_worker(gpu, ngpus_per_node, args):
 quantize_arch = quantize_arch_dict[args.arch]
 model = quantize_arch(model)
 
+print(model)
 # if args.arch == "q_resmlp_v4":
 # for i in range(0, 24):
 # model.blocks[i].inner.weight.requires_grad = False
@@ -533,111 +534,6 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
 
 
 
-# def train_kd(train_loader, model, teacher, criterion, optimizer, epoch, val_loader, args, ngpus_per_node,
-# dataset_length):
-# batch_time = AverageMeter('Time', ':6.3f')
-# data_time = AverageMeter('Data', ':6.3f')
-# losses = AverageMeter('Loss', ':.4e')
-# top1 = AverageMeter('Acc@1', ':6.2f')
-# top5 = AverageMeter('Acc@5', ':6.2f')
-# progress = ProgressMeter(
-# len(train_loader),
-# [batch_time, data_time, losses, top1, top5],
-# prefix="Epoch: [{}]".format(epoch))
-
-# # switch to train mode
-# if args.fix_BN == True:
-# model.eval()
-# else:
-# model.train()
-# teacher.eval()
-
-# end = time.time()
-
-# for i, (images, target) in enumerate(train_loader):
-# # measure data loading time
-# data_time.update(time.time() - end)
-
-# if args.gpu is not None:
-# images = images.cuda(args.gpu, non_blocking=True)
-# target = target.cuda(args.gpu, non_blocking=True)
-
-# # compute output
-# output = model(images)
-# if args.distill_method != 'None':
-# with torch.no_grad():
-# teacher_output = teacher(images)
-
-# if args.distill_method == 'None':
-# loss = criterion(output, target)
-# elif args.distill_method == 'KD_naive':
-# loss = loss_kd(output, target, teacher_output, args)
-# else:
-# raise NotImplementedError
-
-# # measure accuracy and record loss
-# acc1, acc5 = accuracy(output, target, topk=(1, 5))
-# losses.update(loss.item(), images.size(0))
-# top1.update(acc1[0], images.size(0))
-# top5.update(acc5[0], images.size(0))
-
-# # compute gradient and do SGD step
-# optimizer.zero_grad()
-# loss.backward()
-# optimizer.step()
-
-# # measure elapsed time
-# batch_time.update(time.time() - end)
-# end = time.time()
-
-# if i % args.print_freq == 0:
-# progress.display(i)
-# if i % args.print_freq == 0 and args.rank == 0:
-# print('Epoch {epoch_} [{iters}] Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(epoch_=epoch, iters=i,
-# top1=top1, top5=top5))
-# if args.wandb:
-# to_log = {
-# "train_loss": loss.item(),
-# "train_acc1": acc1[0],
-# "train_acc5": acc5[0]
-# }
-
-# scales = model.get_scales()
-# # for i, scale in enumerate(scales):
-# # to_log[f"train_quant/scale_{i}"] = scale
-# # wandb.log(to_log)
-
-# if i % ((dataset_length // (
-# args.batch_size * args.evaluate_times)) + 2) == 0 and i > 0 and args.evaluate_times > 0:
-# acc1 = validate(val_loader, model, criterion, args)
-
-# # switch to train mode
-# if args.fix_BN == True:
-# model.eval()
-# else:
-# model.train()
-
-# # remember best acc@1 and save checkpoint
-# global best_acc1
-# is_best = acc1 > best_acc1
-# best_acc1 = max(acc1, best_acc1)
-
-# if not args.multiprocessing_distributed or (args.multiprocessing_distributed
-# and args.rank % ngpus_per_node == 0):
-# if not os.path.exists(args.save_path):
-# os.makedirs(args.save_path)
-
-# save_checkpoint({
-# 'epoch': epoch + 1,
-# 'arch': args.arch,
-# 'state_dict': model.state_dict(),
-# 'best_acc1': best_acc1,
-# 'optimizer': optimizer.state_dict(),
-# }, is_best, args.save_path)
-# # print(model.state_dict())
-# print("Saved checkpoint.")
-
-
 def validate(val_loader, model, criterion, args):
 batch_time = AverageMeter('Time', ':6.3f')
 losses = AverageMeter('Loss', ':.4e')
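For context, here is a minimal sketch of how the newly registered 'q_deit_s' entry would be exercised by the training flow shown above. Only deit_small_patch16_224 and q_vit are taken from this diff; the import path and the exact call pattern are assumptions modelled on the surrounding code, not part of the change itself.

# Assumed usage sketch (not part of the diff).
# deit_small_patch16_224 and q_vit are the names registered in arch_dict /
# quantize_arch_dict above; the wildcard imports added to src/models/__init__.py
# (see the next file) are assumed to expose them at package level.
from src.models import deit_small_patch16_224, q_vit

arch_dict = {'q_deit_s': deit_small_patch16_224}
quantize_arch_dict = {'q_deit_s': q_vit}

model = arch_dict['q_deit_s'](pretrained=True)   # load the FP32 DeiT-S backbone
model = quantize_arch_dict['q_deit_s'](model)    # wrap it with the quantized ViT variant
print(model)                                     # mirrors the print(model) calls added above

In the script itself the same two lookups are driven by args.arch, so selecting the new entry amounts to passing 'q_deit_s' as the architecture name.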
src/models/__init__.py: 5 changes (4 additions, 1 deletion)
@@ -12,4 +12,7 @@
 from .q_resmlp_v2_5 import *
 
 from .resmlp_model_v4 import *
-from .q_resmlp_v4 import *
+from .q_resmlp_v4 import *
+
+from .vit import *
+from .q_vit import *