mlp.py
# Defined in Section 4.1.6
import torch
from torch import nn
from torch.nn import functional as F


class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_class):
        super(MLP, self).__init__()
        # Linear transformation: input layer -> hidden layer
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        # Use the ReLU activation function
        self.activate = F.relu
        # Linear transformation: hidden layer -> output layer
        self.linear2 = nn.Linear(hidden_dim, num_class)

    def forward(self, inputs):
        hidden = self.linear1(inputs)
        activation = self.activate(hidden)
        outputs = self.linear2(activation)
        # Probability of each input belonging to each class
        probs = F.softmax(outputs, dim=1)
        return probs


mlp = MLP(input_dim=4, hidden_dim=5, num_class=2)
inputs = torch.rand(3, 4)  # a tensor of shape (3, 4): 3 inputs, each of dimension 4
probs = mlp(inputs)  # forward() is called automatically
print(probs)  # the class probabilities for each of the 3 inputs
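# --- Usage note (a minimal sketch, not part of the original file) ---
# The rows of `probs` sum to 1 because of the softmax over dim=1; to turn the
# probabilities into hard class predictions, take the argmax along the class
# dimension. The `predictions` variable below is illustrative only.
predictions = torch.argmax(probs, dim=1)  # shape: (3,), one class index per input
print(predictions)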