-
Notifications
You must be signed in to change notification settings - Fork 8.2k
Open
Description
# Wrap the held-out arrays in a DataLoader for evaluation; shuffle=False keeps
# batches in the original sample order. X_test / y_test are presumably NumPy
# arrays (they go through torch.from_numpy) — dtype not visible here.
test_input = torch.from_numpy(X_test)
test_label = torch.from_numpy(y_test)
testset = torch.utils.data.TensorDataset(test_input, test_label)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=opt.batch_size, shuffle=False, num_workers=0
)

# train_SCU() returns a (cnn, outs) tuple. If the caller wrote
# `cnn = train_SCU(...)`, `cnn` holds that tuple and `.eval()` raises
# "AttributeError: 'tuple' object has no attribute 'eval'". The proper fix is
# to unpack at the call site (`cnn, _ = train_SCU(...)`); this guard keeps the
# evaluation step working either way.
if isinstance(cnn, tuple):
    cnn = cnn[0]
cnn.eval()
def train_SCU(X_train, y_train):
    """Train a freshly constructed ``SCU`` network and return it.

    Relies on names defined elsewhere in the file: ``opt`` (reads
    ``batch_size``, ``lr``, ``w_decay``, ``n_epochs``), ``SCU``,
    ``num_classes``, ``device``, ``nn`` and ``get_accuracy``.

    Args:
        X_train: Training inputs, converted with ``torch.from_numpy`` (so
            presumably a NumPy array) and cast to float per batch.
        y_train: Training targets, passed to ``nn.CrossEntropyLoss``.
            Presumably integer class indices — TODO confirm the dtype is
            int64, since CrossEntropyLoss rejects int32 index targets.

    Returns:
        A ``(cnn, outs)`` **tuple**:

        - ``cnn`` is the trained model.
        - ``outs`` is the second value returned by the model's forward pass on
          the last training batch, or ``None`` if no batch was processed
          (``opt.n_epochs == 0`` or an empty training set).

        Callers must unpack the result, e.g.
        ``cnn, _ = train_SCU(X_train, y_train)``. Assigning the whole tuple to
        ``cnn`` and then calling ``cnn.eval()`` raises
        ``AttributeError: 'tuple' object has no attribute 'eval'``.
    """
    train_input = torch.from_numpy(X_train)
    train_label = torch.from_numpy(y_train)
    trainset = torch.utils.data.TensorDataset(train_input, train_label)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=opt.batch_size, shuffle=True, num_workers=0
    )

    cnn = SCU(opt, num_classes).to(device)
    cnn.train()
    ce_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        cnn.parameters(), lr=opt.lr, weight_decay=opt.w_decay
    )

    # Pre-bind so the final return cannot raise UnboundLocalError when the
    # epoch loop or the batch loop never executes.
    outs = None
    for _ in range(opt.n_epochs):
        # NOTE(review): accumulated per epoch but never logged or returned —
        # TODO report it or drop it.
        cumulative_accuracy = 0
        for inputs, labels in trainloader:
            inputs = inputs.to(device).float()
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs, outs = cnn(inputs)
            loss = ce_loss(outputs, labels)
            loss.backward()
            optimizer.step()
            # Index of the highest score along dim 1 (same index torch.max
            # returns, first occurrence on ties).
            predicted = outputs.argmax(dim=1)
            cumulative_accuracy += get_accuracy(labels, predicted)
    return cnn, outs
# train_SCU() returns a (cnn, outs) tuple. If the caller wrote
# `cnn = train_SCU(...)`, `cnn` holds that tuple and `.eval()` raises
# "AttributeError: 'tuple' object has no attribute 'eval'". The proper fix is
# to unpack at the call site (`cnn, _ = train_SCU(...)`); this guard switches
# the model to eval mode either way.
if isinstance(cnn, tuple):
    cnn = cnn[0]
cnn.eval()
AttributeError: 'tuple' object has no attribute 'eval'
Metadata
Assignees
Labels
No labels