Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenGarnets committed Jan 20, 2020
1 parent fd99fe5 commit c903713
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(self):


def forward(self, x):
batch = x.size(0)

out = self.conv1(x)
out = self.conv2(out)
Expand All @@ -68,7 +67,7 @@ def forward(self, x):
out = self.conv7(out)
out = self.conv8(out)

out = out.reshape(batch,out.size(1)* out.size(2))
out = out.reshape(out.size(0), out.size(1)* out.size(2))
#print(out.shape)

out = self.fc(out)
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def main():
if i+batch < input.shape[0] :
optimizer.zero_grad()

pred = model(input[i:i+batch],batch)
pred = model(input[i:i+batch])


loss = criterion(pred.squeeze(), target[i:i+batch].squeeze())
Expand Down Expand Up @@ -173,7 +173,7 @@ def main():
output = []
for i in range(0,input.shape[0], batch):
if i+batch < input.shape[0] :
pred = model(input[i:i+batch],batch)
pred = model(input[i:i+batch])
pred = pred.to(torch.device("cpu"))
if i == 0 :
output = pred.tolist()
Expand Down

0 comments on commit c903713

Please sign in to comment.