forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathmulti_output_model.py
43 lines (32 loc) · 1.43 KB
/
multi_output_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from .common import preferred_dtype
class MultiOutputModel(torch.nn.Module):
    """Toy module that yields one cross-entropy loss per (input, target) pair.

    Every pair is pushed through the same bias-free ``hidden_dim x hidden_dim``
    linear layer, whose weights are all initialised to ``weight_value``.
    """

    def __init__(self, hidden_dim, weight_value):
        """Create the shared linear layer and the loss criterion.

        Args:
            hidden_dim: Input and output feature size of the linear layer.
            weight_value: Constant written into every element of the
                layer's weight matrix.
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Constant weights make the layer's output fully determined by the input.
        weight = self.linear.weight
        weight.data.fill_(weight_value)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Return a tuple holding the loss of each (input, target) pair.

        Args:
            inputs: Iterable of tensors fed to the linear layer.
            targets: Iterable of label tensors, paired with ``inputs`` by
                position via ``zip`` (extra items on the longer side are
                ignored).

        Returns:
            Tuple of cross-entropy loss tensors, in pair order.
        """
        return tuple(
            self.cross_entropy_loss(self.linear(sample), label)
            for sample, label in zip(inputs, targets)
        )
def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, targets):
    """Build a DataLoader over constant-valued feature and label tensors.

    Args:
        model: Object exposing ``train_micro_batch_size_per_gpu()``; its
            return value is used as the loader's batch size (presumably a
            DeepSpeed engine -- confirm against callers).
        total_samples: Number of rows in every generated tensor.
        hidden_dim: Number of columns in every feature tensor.
        device: Device on which the tensors are allocated.
        inputs: Fill values, one per feature tensor.
        targets: Fill values, one per label tensor; must be the same length
            as ``inputs``.

    Returns:
        A ``torch.utils.data.DataLoader`` whose samples are tuples made of
        all feature tensors' rows followed by all label tensors' entries.
    """
    assert len(inputs) == len(targets)

    micro_batch = model.train_micro_batch_size_per_gpu()

    # One (total_samples, hidden_dim) tensor per input value, with gradients enabled.
    features = []
    for fill in inputs:
        feature = torch.full(size=(total_samples, hidden_dim),
                             fill_value=fill,
                             device=device,
                             dtype=preferred_dtype(),
                             requires_grad=True)
        features.append(feature)

    # One 1-D int64 label tensor per target value.
    labels = []
    for label_value in targets:
        label = torch.empty(total_samples, device=device, dtype=torch.long)
        labels.append(label.fill_(label_value))

    dataset = torch.utils.data.TensorDataset(*features, *labels)
    return torch.utils.data.DataLoader(dataset, batch_size=micro_batch)