@@ -1,14 +1,10 @@
-import torch
 import os
-import random
 import logging
-from contextlib import suppress
+import torch
 from pathlib import Path
 
 # TorchBench imports
 from torchbenchmark.util.model import BenchmarkModel
-from torchbenchmark.util.jit import jit_model
-from torchbenchmark.util.torch_feature_checker import check_native_amp
 from torchbenchmark.tasks import COMPUTER_VISION
 
 # effdet imports
@@ -18,7 +14,7 @@
 # timm imports
 from timm.models.layers import set_layer_config
 from timm.optim import create_optimizer
-from timm.utils import ModelEmaV2, NativeScaler
+from timm.utils import ModelEmaV2
 from timm.scheduler import create_scheduler
 
 # local imports
@@ -30,129 +26,108 @@
 CURRENT_DIR = Path(os.path.dirname(os.path.realpath(__file__)))
 DATA_DIR = os.path.join(CURRENT_DIR.parent.parent, "data", ".data", "coco2017-minimal", "coco")
 
-torch.manual_seed(1337)
-random.seed(1337)
-torch.backends.cudnn.deterministic = False
-torch.backends.cudnn.benchmark = True
-
 class Model(BenchmarkModel):
     task = COMPUTER_VISION.DETECTION
-
     # Original Train batch size 32 on 2x RTX 3090 (24 GB cards)
     # Downscale to batch size 16 on single GPU
-    def __init__(self, device=None, jit=False, train_bs=16, eval_bs=128):
-        super().__init__()
-        self.device = device
-        self.jit = jit
+    DEFAULT_TRAIN_BSIZE = 16
+    DEFAULT_EVAL_BSIZE = 128
+
+    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
+        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
         # generate arguments
         args = get_args()
         # setup train and eval batch size
-        args.batch_size = train_bs
-        args.eval_batch_size = eval_bs
-        # Use native amp if possible
-        args.native_amp = check_native_amp()
+        args.batch_size = self.batch_size
         # Disable distributed
        args.distributed = False
-        args.device = device
-        args.torchscript = jit
+        args.device = self.device
+        args.torchscript = self.jit
         args.world_size = 1
         args.rank = 0
         args.pretrained_backbone = not args.no_pretrained_backbone
         args.prefetcher = not args.no_prefetcher
         args.root = DATA_DIR
 
-        if not self.device == "cuda":
-            raise NotImplementedError("Only CUDA is supported by this model")
-
         with set_layer_config(scriptable=args.torchscript):
-            extra_args = {}
+            timm_extra_args = {}
             if args.img_size is not None:
-                extra_args = dict(image_size=(args.img_size, args.img_size))
-            model = create_model(
-                model_name=args.model,
-                bench_task='train',
-                num_classes=args.num_classes,
-                pretrained=args.pretrained,
-                pretrained_backbone=args.pretrained_backbone,
-                redundant_bias=args.redundant_bias,
-                label_smoothing=args.smoothing,
-                legacy_focal=args.legacy_focal,
-                jit_loss=args.jit_loss,
-                soft_nms=args.soft_nms,
-                bench_labeler=args.bench_labeler,
-                checkpoint_path=args.initial_checkpoint,
-            )
-            eval_model = create_model(
-                model_name=args.model,
-                bench_task='predict',
-                num_classes=args.num_classes,
-                pretrained=args.pretrained,
-                redundant_bias=args.redundant_bias,
-                soft_nms=args.soft_nms,
-                checkpoint_path=args.checkpoint,
-                checkpoint_ema=args.use_ema,
-                **extra_args,
-            )
+                timm_extra_args = dict(image_size=(args.img_size, args.img_size))
+            if test == "train":
+                model = create_model(
+                    model_name=args.model,
+                    bench_task='train',
+                    num_classes=args.num_classes,
+                    pretrained=args.pretrained,
+                    pretrained_backbone=args.pretrained_backbone,
+                    redundant_bias=args.redundant_bias,
+                    label_smoothing=args.smoothing,
+                    legacy_focal=args.legacy_focal,
+                    jit_loss=args.jit_loss,
+                    soft_nms=args.soft_nms,
+                    bench_labeler=args.bench_labeler,
+                    checkpoint_path=args.initial_checkpoint,
+                )
+            elif test == "eval":
+                model = create_model(
+                    model_name=args.model,
+                    bench_task='predict',
+                    num_classes=args.num_classes,
+                    pretrained=args.pretrained,
+                    redundant_bias=args.redundant_bias,
+                    soft_nms=args.soft_nms,
+                    checkpoint_path=args.checkpoint,
+                    checkpoint_ema=args.use_ema,
+                    **timm_extra_args,
+                )
         model_config = model.config  # grab before we obscure with DP/DDP wrappers
-        model = model.to(device)
+        self.model = model.to(device)
         if args.channels_last:
-            model = model.to(memory_format=torch.channels_last)
-        eval_model = eval_model.to(device)
-
-        self.model, self.eval_model = jit_model(model, eval_model, jit=jit)
-        self.optimizer = create_optimizer(args, model)
-        self.amp_autocast = suppress
-        if args.native_amp:
-            self.amp_autocast = torch.cuda.amp.autocast
-            self.loss_scaler = NativeScaler()
-        self.model_ema = None
-        if args.model_ema:
-            # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-            self.model_ema = ModelEmaV2(model, decay=args.model_ema_decay)
-        self.lr_scheduler, self.num_epochs = create_scheduler(args, self.optimizer)
-
-        self.loader_train, self.loader_eval, self.evaluator, _, dataset_eval = create_datasets_and_loaders(args, model_config)
-        if model_config.num_classes < self.loader_train.dataset.parser.max_label:
-            logging.error(
-                f'Model {model_config.num_classes} has fewer classes than dataset {self.loader_train.dataset.parser.max_label}.')
-            exit(1)
-        if model_config.num_classes > self.loader_train.dataset.parser.max_label:
-            logging.warning(
-                f'Model {model_config.num_classes} has more classes than dataset {self.loader_train.dataset.parser.max_label}.')
-        self.train_num_batch = 1
-
-        # Create eval loader
-        input_config = resolve_input_config(args, model_config)
-        self.eval_loader = create_loader(
-            dataset_eval,
-            input_size=input_config['input_size'],
-            batch_size=args.eval_batch_size,
-            use_prefetcher=args.prefetcher,
-            interpolation=args.eval_interpolation,
-            fill_color=input_config['fill_color'],
-            mean=input_config['mean'],
-            std=input_config['std'],
-            num_workers=args.workers,
-            pin_mem=args.pin_mem)
-        self.eval_num_batch = 1
+            self.model = self.model.to(memory_format=torch.channels_last)
+
+        if test == "train":
+            self.optimizer = create_optimizer(args, model)
+            self.model_ema = None
+            if args.model_ema:
+                # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
+                self.model_ema = ModelEmaV2(model, decay=args.model_ema_decay)
+            self.lr_scheduler, self.num_epochs = create_scheduler(args, self.optimizer)
+
+            self.loader_train, self.loader_eval, self.evaluator, _, dataset_eval = create_datasets_and_loaders(args, model_config)
+            if model_config.num_classes < self.loader_train.dataset.parser.max_label:
+                logging.error(
+                    f'Model {model_config.num_classes} has fewer classes than dataset {self.loader_train.dataset.parser.max_label}.')
+                exit(1)
+            if model_config.num_classes > self.loader_train.dataset.parser.max_label:
+                logging.warning(
+                    f'Model {model_config.num_classes} has more classes than dataset {self.loader_train.dataset.parser.max_label}.')
+        elif test == "eval":
+            # Create eval loader; the evaluator and eval dataset are also needed on this path
+            _, self.loader_eval, self.evaluator, _, dataset_eval = create_datasets_and_loaders(args, model_config)
+            input_config = resolve_input_config(args, model_config)
+            self.loader = create_loader(
+                dataset_eval,
+                input_size=input_config['input_size'],
+                batch_size=args.batch_size,
+                use_prefetcher=args.prefetcher,
+                interpolation=args.eval_interpolation,
+                fill_color=input_config['fill_color'],
+                mean=input_config['mean'],
+                std=input_config['std'],
+                num_workers=args.workers,
+                pin_mem=args.pin_mem)
         self.args = args
+        # Only run 1 batch in 1 epoch
+        self.num_batches = 1
+        self.num_epochs = 1
 
     def get_module(self):
-        self.eval_model.eval()
-        for _, (input, target) in zip(range(self.eval_num_batch), self.loader_eval):
-            return (self.eval_model, (input, target))
+        for _, (input, target) in zip(range(self.num_batches), self.loader_eval):
+            return (self.model, (input, target))
 
-    # Temporarily disable training because this will cause CUDA OOM in CI
-    # TODO: re-enable this test when better hardware is available
     def train(self, niter=1):
-        raise NotImplementedError("Disable this test because it causes CUDA OOM on Nvidia T4 GPU")
-        if not self.device == "cuda":
-            raise NotImplementedError("Only CUDA is supported by this model")
-        if self.jit:
-            raise NotImplementedError("JIT is not supported by this model")
         eval_metric = self.args.eval_metric
-        self.model.train()
-        for epoch in range(niter):
+        for epoch in range(self.num_epochs):
             train_metrics = train_epoch(
                 epoch, self.model, self.loader_train,
                 self.optimizer, self.args,
@@ -170,14 +144,8 @@ def train(self, niter=1):
                 self.lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
 
     def eval(self, niter=1):
-        if not self.device == "cuda":
-            raise NotImplementedError("Only CUDA is supported by this model")
-        if self.jit:
-            raise NotImplementedError("JIT is not supported by this model")
-        self.eval_model.eval()
         for _ in range(niter):
             with torch.no_grad():
-                for _, (input, target) in zip(range(self.eval_num_batch), self.eval_loader):
-                    with self.amp_autocast():
-                        output = self.eval_model(input, img_info=target)
+                for _, (input, target) in zip(range(self.num_batches), self.loader):
+                    output = self.model(input, img_info=target)
                     self.evaluator.add_predictions(output, target)
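
For reference, here is a minimal sketch of how a harness script might drive the constructor after this change. The import path below is an assumption (the model's directory under torchbenchmark/models is not visible in this diff), and the fallback from batch_size=None to the DEFAULT_TRAIN_BSIZE/DEFAULT_EVAL_BSIZE class attributes is inferred from the super().__init__() call; running it also presumes the coco2017-minimal data under DATA_DIR and a CUDA device. Note the design shift the diff makes: branching on test builds only one network per run, where the old constructor always created both the train bench and the predict bench.

# Hypothetical driver for the unified BenchmarkModel API shown above.
import torch

from torchbenchmark.models.effdet import Model  # assumed module path

if __name__ == "__main__":
    if not torch.cuda.is_available():
        raise SystemExit("this benchmark expects a CUDA device")

    # Eval: batch_size=None presumably falls back to DEFAULT_EVAL_BSIZE (128).
    m = Model(test="eval", device="cuda", jit=False, batch_size=None, extra_args=[])
    module, example_inputs = m.get_module()  # (model, (input, target)) for one batch
    m.eval(niter=1)                          # one no_grad pass over num_batches == 1

    # Train: builds the optimizer, LR scheduler, and COCO train loader instead.
    trainer = Model(test="train", device="cuda", batch_size=16)
    trainer.train()                          # one epoch over a single batch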