@@ -22,6 +22,7 @@ class Linear(nn.Module):
             default = None
         dropout (float): 0. - 1., default = 0.
         bias (bool): to enable bias, default = True
+        out_shape (tuple): a desired output shape as a tuple, without batches

     Return:
         torch.Tensor of shape (B, out_features)
@@ -33,30 +34,45 @@ def __init__(self,
                  activation: str = None,
                  dropout: float = 0.,
                  bias: bool = True,
+                 out_shape: tuple = None,
                  **kwargs):
         super(Linear, self).__init__()
+        self.t_size = (1, tensor_size) if isinstance(tensor_size, int) else tuple(tensor_size)
         # Checks
-        assert type(tensor_size) in [int, list, tuple], \
-            "Linear: tensor_size must tuple/list"
+        if type(tensor_size) not in [int, list, tuple]:
+            raise TypeError("Linear: tensor_size must be int/tuple/list")

         if isinstance(tensor_size, int):
             in_features = tensor_size
         else:
             assert len(tensor_size) >= 2, \
                 "Linear: when tuple/list, tensor_size must be of length 2 or more"
             in_features = np.prod(tensor_size[1:])
-        assert isinstance(out_features, int), "Linear:out_features must be int"
-        assert isinstance(dropout, float), "Linear: dropout must be float"
-        if dropout > 0.:
+
+        if not isinstance(out_features, int):
+            raise TypeError("Linear: out_features must be int")
+
+        if not isinstance(dropout, float):
+            raise TypeError("Linear: dropout must be float")
+        if 1. > dropout > 0.:
             self.dropout = nn.Dropout2d(dropout)
+
         if isinstance(activation, str):
             activation = activation.lower()
         assert activation in [None, "", ] + Activations.available(), \
             "Linear: activation must be None/''/" + \
             "/".join(Activations.available())
-        assert isinstance(bias, bool), "Linear: bias must be bool"
-        multiplier = 2 if activation in ("maxo", "rmxo") else 1
+        self.act = activation
+
+        if not isinstance(bias, bool):
+            raise TypeError("Linear: bias must be bool")

+        if out_shape is not None:
+            assert np.prod(out_shape) == out_features, \
+                "Linear: np.prod(out_shape) != out_features"
+            self.out_shape = out_shape
+
+        multiplier = 2 if activation in ("maxo", "rmxo") else 1
         # get weight and bias
         self.weight = nn.Parameter(torch.rand(out_features * multiplier,
                                               in_features))
@@ -66,9 +82,12 @@ def __init__(self,
             self.bias = nn.Parameter(torch.zeros(out_features * multiplier))
         # get activation function
         if activation is not None:
-            self.activation = Activations(activation)
+            if activation in Activations.available():
+                self.activation = Activations(activation)
         # out tensor size
         self.tensor_size = (1, out_features)
+        if hasattr(self, "out_shape"):
+            self.tensor_size = tuple([1, ] + list(out_shape))

     def forward(self, tensor):
         if tensor.dim() > 2:
@@ -80,13 +99,25 @@ def forward(self, tensor):
             tensor = tensor + self.bias.view(1, -1)
         if hasattr(self, "activation"):
             tensor = self.activation(tensor)
+        if hasattr(self, "out_shape"):
+            tensor = tensor.view(-1, *self.out_shape)
         return tensor

+    def __repr__(self):
+        msg = "x".join(["_"] + [str(x) for x in self.t_size[1:]]) + " -> "
+        if hasattr(self, "dropout"):
+            msg += "dropout -> "
+        msg += "{}({})".format("linear", "x".join([str(x) for x in
+                                                   self.weight.shape])) + " -> "
+        if hasattr(self, "activation"):
+            msg += self.act + " -> "
+        msg += "x".join(["_"] + [str(x) for x in self.tensor_size[1:]])
+        return msg

 # from core.NeuralLayers import Activations
 # tensor_size = (2, 3, 10, 10)
 # x = torch.rand(*tensor_size)
-# test = Linear(tensor_size, 16, "maxo", 0., bias=True)
+# test = Linear(tensor_size, 16, "", 0., True, (1, 4, 4))
 # test(x).size()
 # test.weight.shape
 # test.bias.shape
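
A minimal usage sketch of the new out_shape argument, mirroring the commented test above. It assumes Linear and Activations are importable from core.NeuralLayers (per that comment) and that the parts of forward outside this diff are unchanged; the printed shapes and repr string are illustrative expectations, not output recorded in the commit.

# from core.NeuralLayers import Activations
import torch
tensor_size = (2, 3, 10, 10)              # (B, C, H, W); batch dim is ignored
x = torch.rand(*tensor_size)
# out_shape (1, 4, 4) must satisfy np.prod(out_shape) == out_features (16)
test = Linear(tensor_size, 16, "", 0., True, (1, 4, 4))
y = test(x)                               # 16 features per sample, reshaped to out_shape
print(y.size())                           # expected: torch.Size([2, 1, 4, 4])
print(test)                               # expected: _x3x10x10 -> linear(16x300) -> _x1x4x4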