@@ -36,10 +36,11 @@ def forward(self, x):
3636 return x , out1 , out2 , out3 , out4 , out5 , out6 , out7 , out8 , out9
3737
3838 model = NoPytorchAutocastModel ().cuda ().eval ()
39- inputs = (torch .randn ((8 , 3 , 32 , 32 ), dtype = torch .float32 , device = "cuda" ),)
39+ BS = 8
40+ inputs = (torch .randn ((BS , 3 , 32 , 32 ), dtype = torch .float32 , device = "cuda" ),)
4041 ep = torch .export .export (model , inputs )
4142 calibration_dataloader = torch .utils .data .DataLoader (
42- torch .utils .data .TensorDataset (* inputs ), batch_size = 2 , shuffle = False
43+ torch .utils .data .TensorDataset (* inputs ), batch_size = BS , shuffle = False
4344 )
4445
4546 with torch_tensorrt .dynamo .Debugger (
@@ -126,10 +127,11 @@ def forward(self, x):
126127 return x , out1 , out2 , out3 , out4 , out5 , out6 , out7 , out8 , out9
127128
128129 model = WholePytorchAutocastModel ().cuda ().eval ()
129- inputs = (torch .randn ((8 , 3 , 32 , 32 ), dtype = torch .float32 , device = "cuda" ),)
130+ BS = 4
131+ inputs = (torch .randn ((BS , 3 , 32 , 32 ), dtype = torch .float32 , device = "cuda" ),)
130132 ep = torch .export .export (model , inputs )
131133 calibration_dataloader = torch .utils .data .DataLoader (
132- torch .utils .data .TensorDataset (* inputs ), batch_size = 2 , shuffle = False
134+ torch .utils .data .TensorDataset (* inputs ), batch_size = BS , shuffle = False
133135 )
134136
135137 with torch_tensorrt .dynamo .Debugger (
@@ -205,10 +207,11 @@ def forward(self, x):
205207 return x , out1 , out2 , out3 , out4 , out5 , out6 , out7 , out8 , out9
206208
207209 model = MixedPytorchAutocastModel ().cuda ().eval ()
208- inputs = (torch .randn ((8 , 3 , 32 , 32 ), dtype = torch .float32 , device = "cuda" ),)
210+ BS = 2
211+ inputs = (torch .randn ((BS , 3 , 32 , 32 ), dtype = torch .float32 , device = "cuda" ),)
209212 ep = torch .export .export (model , inputs )
210213 calibration_dataloader = torch .utils .data .DataLoader (
211- torch .utils .data .TensorDataset (* inputs ), batch_size = 2 , shuffle = False
214+ torch .utils .data .TensorDataset (* inputs ), batch_size = BS , shuffle = False
212215 )
213216
214217 with torch_tensorrt .dynamo .Debugger (
0 commit comments