Skip to content

Commit 4cc6a9b

Browse files
committed
pass in device and dtypes explicitly for #102 among others
1 parent dbcb1bd commit 4cc6a9b

File tree

3 files changed

+35
-16
lines changed

3 files changed

+35
-16
lines changed

torchsummary/tests/test_models/test_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,22 @@ def forward(self, x1, x2):
3535
x2 = self.fc2b(x2)
3636
x = torch.cat((x1, x2), 0)
3737
return F.log_softmax(x, dim=1)
38+
39+
class MultipleInputNetDifferentDtypes(nn.Module):
40+
def __init__(self):
41+
super(MultipleInputNetDifferentDtypes, self).__init__()
42+
self.fc1a = nn.Linear(300, 50)
43+
self.fc1b = nn.Linear(50, 10)
44+
45+
self.fc2a = nn.Linear(300, 50)
46+
self.fc2b = nn.Linear(50, 10)
47+
48+
def forward(self, x1, x2):
49+
x1 = F.relu(self.fc1a(x1))
50+
x1 = self.fc1b(x1)
51+
x2 = x2.type(torch.FloatTensor)
52+
x2 = F.relu(self.fc2a(x2))
53+
x2 = self.fc2b(x2)
54+
# set x2 to FloatTensor
55+
x = torch.cat((x1, x2), 0)
56+
return F.log_softmax(x, dim=1)

torchsummary/tests/unit_tests/torchsummary_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
from torchsummary import summary
3-
from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet
3+
from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet, MultipleInputNetDifferentDtypes
44
import torch
55

66
class torchsummaryTests(unittest.TestCase):
@@ -30,9 +30,18 @@ def test_single_layer_network_on_gpu(self):
3030
model = torch.nn.Linear(2, 5)
3131
model.cuda()
3232
input = (1, 2)
33-
total_params, trainable_params = summary(model, input, device="cuda")
33+
total_params, trainable_params = summary(model, input, device="cuda:0")
3434
self.assertEqual(total_params, 15)
3535
self.assertEqual(trainable_params, 15)
3636

37+
def test_multiple_input_types(self):
38+
model = MultipleInputNetDifferentDtypes()
39+
input1 = (1, 300)
40+
input2 = (1, 300)
41+
dtypes = [torch.FloatTensor, torch.LongTensor]
42+
total_params, trainable_params = summary(model, [input1, input2], device="cpu", dtypes=dtypes)
43+
self.assertEqual(total_params, 31120)
44+
self.assertEqual(trainable_params, 31120)
45+
3746
if __name__ == '__main__':
3847
unittest.main(buffer=True)

torchsummary/torchsummary.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import numpy as np
77

88

9-
def summary(model, input_size, batch_size=-1, device="cuda"):
9+
def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
10+
if dtypes == None:
11+
dtypes = [torch.FloatTensor]*len(input_size)
1012

1113
def register_hook(module):
1214

@@ -40,24 +42,13 @@ def hook(module, input, output):
4042
):
4143
hooks.append(module.register_forward_hook(hook))
4244

43-
device = device.lower()
44-
assert device in [
45-
"cuda",
46-
"cpu",
47-
], "Input device is not valid, please specify 'cuda' or 'cpu'"
48-
49-
if device == "cuda" and torch.cuda.is_available():
50-
dtype = torch.cuda.FloatTensor
51-
else:
52-
dtype = torch.FloatTensor
53-
5445
# multiple inputs to the network
5546
if isinstance(input_size, tuple):
5647
input_size = [input_size]
5748

49+
5850
# batch_size of 2 for batchnorm
59-
x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
60-
# print(type(x[0]))
51+
x = [ torch.rand(2, *in_size).type(dtype).to(device=device) for in_size, dtype in zip(input_size, dtypes)]
6152

6253
# create properties
6354
summary = OrderedDict()

0 commit comments

Comments
 (0)