diff --git a/net/model.py b/net/model.py index 20f7d64..34b19a8 100644 --- a/net/model.py +++ b/net/model.py @@ -57,7 +57,6 @@ def __init__(self): def forward(self, x): - batch = x.size(0) out = self.conv1(x) out = self.conv2(out) @@ -68,7 +67,7 @@ def forward(self, x): out = self.conv7(out) out = self.conv8(out) - out = out.reshape(batch,out.size(1)* out.size(2)) + out = out.reshape(out.size(0), out.size(1)* out.size(2)) #print(out.shape) out = self.fc(out) diff --git a/train.py b/train.py index 7efe274..94d973f 100644 --- a/train.py +++ b/train.py @@ -135,7 +135,7 @@ def main(): if i+batch < input.shape[0] : optimizer.zero_grad() - pred = model(input[i:i+batch],batch) + pred = model(input[i:i+batch]) loss = criterion(pred.squeeze(), target[i:i+batch].squeeze()) @@ -173,7 +173,7 @@ def main(): output = [] for i in range(0,input.shape[0], batch): if i+batch < input.shape[0] : - pred = model(input[i:i+batch],batch) + pred = model(input[i:i+batch]) pred = pred.to(torch.device("cpu")) if i == 0 : output = pred.tolist()