-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy path__init__.py
41 lines (34 loc) · 1.2 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
def xavier(m: nn.Module, gain: float = 1.0) -> None:
    """
    Applies Xavier (Glorot) uniform initialization to linear modules.

    Weights are drawn from U(-a, a) with
    ``a = gain * sqrt(3) * sqrt(2 / (fan_in + fan_out))``; biases, when
    present, are set to zero. Modules whose class is not named exactly
    ``'Linear'`` are left untouched, so the function can be passed directly
    to :meth:`nn.Module.apply`.

    :param m: the module to be initialized
    :param gain: scaling factor applied to the standard deviation
        (default 1.0, identical to the previous fixed behaviour)
    Example::
        >>> net = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        >>> net.apply(xavier)
    """
    # Exact class-name match: subclasses of nn.Linear are NOT initialized here.
    if m.__class__.__name__ != 'Linear':
        return
    # nn.Linear stores its weight as (out_features, in_features).
    fan_out = m.weight.size(0)
    fan_in = m.weight.size(1)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    # A uniform distribution U(-a, a) has standard deviation a / sqrt(3).
    a = math.sqrt(3.0) * std
    # Mutate parameters in place without recording the ops in autograd.
    with torch.no_grad():
        m.weight.uniform_(-a, a)
        if m.bias is not None:
            m.bias.fill_(0.0)
def num_flat_features(x: torch.Tensor) -> int:
    """
    Counts how many elements each leading-dimension slice of ``x`` holds,
    i.e. the product of every dimension after the first.

    :param x: input tensor
    :return: product of ``x.size()[1:]``; 1 when ``x`` has fewer than two
        dimensions (the empty product)
    """
    trailing_dims = x.size()[1:]
    # math.prod of an empty sequence is 1, matching the empty-loop case.
    return math.prod(trailing_dims)