major changes
ahmedharbii committed Dec 14, 2022
1 parent 6681c64 commit 0925474
Showing 6 changed files with 10,035 additions and 10,015 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -157,3 +157,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.pkl
*.pth
2 changes: 1 addition & 1 deletion Datasets/tartanTrajFlowDataset.py
@@ -13,7 +13,7 @@ def __init__(self, imgfolder , posefile = None, transform = None,

files = listdir(imgfolder)
self.rgbfiles = [(imgfolder +'/'+ ff) for ff in files if (ff.endswith('.png') or ff.endswith('.jpg'))]
self.rgbfiles.sort()
# self.rgbfiles.sort()
self.imgfolder = imgfolder

print('Find {} image files in {}'.format(len(self.rgbfiles), imgfolder))
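With the sort commented out, os.listdir returns files in an arbitrary, filesystem-dependent order, so consecutive dataset items are no longer guaranteed to be temporally adjacent frames. Below is a minimal sketch of a numeric-aware ordering; it assumes filenames that embed a frame index (e.g. 000123.png), which is an assumption and not the repository's actual loader.

import os
import re

def list_frames_in_order(imgfolder):
    """Sketch: return image paths sorted by the integer embedded in each filename."""
    files = [f for f in os.listdir(imgfolder) if f.endswith(('.png', '.jpg'))]

    def frame_index(name):
        # Use the first run of digits as the frame index; non-numeric names sort first.
        match = re.search(r'\d+', name)
        return int(match.group()) if match else -1

    files.sort(key=frame_index)
    return [os.path.join(imgfolder, f) for f in files]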
30 changes: 24 additions & 6 deletions TartanVO.py
@@ -44,11 +44,23 @@ class TartanVO(object):
# Define named constants for default values
DEFAULT_NUM_EPOCHS = 10
SAVE_FREQUENCY = 2
def __init__(self, model_name):
def __init__(self, model_name, lr=0.0001, decay=0.2):
# import ipdb;ipdb.set_trace()
self.vonet = VONet()
# self.optimizer = optimizer
self.lr = lr
self.decay = decay
self.optimizer = torch.optim.Adam(self.vonet.parameters(), lr=self.lr, weight_decay=self.decay)

# load the whole model
if model_name.endswith('.pth'):
model_file_path = os.path.join('models', model_name)
pretrained_model = torch.load(model_file_path)
self.vonet.load_state_dict(pretrained_model['model_state_dict'])
print('Model loaded...')
self.optimizer.load_state_dict(pretrained_model['optimizer_state_dict'])
print('Optimizer loaded...')

if model_name.endswith('.pkl'):
# modelname = 'models/' + model_name
model_file_path = os.path.join('models', model_name)
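One note on the .pth branch above: torch.load restores tensors onto the device they were saved from, so a checkpoint written on a GPU machine can fail to load on CPU-only hardware unless map_location is given. A minimal sketch of the same load pattern with that guard follows; the helper name and the device handling are assumptions, not part of the commit.

import os
import torch

def load_checkpoint(vonet, optimizer, model_name, models_dir='models'):
    """Sketch: restore model and optimizer state from a .pth checkpoint dict."""
    checkpoint_path = os.path.join(models_dir, model_name)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(checkpoint_path, map_location=device)
    vonet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return vonet, optimizer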
@@ -108,7 +120,7 @@ def train_model(self, model, dataloader, optimizer, dataset_len, num_epochs = 10
flow, pose = outputs
#or torch.mul(pose, pose_std)
# pose = pose / pose_std # divide by the standard deviation
pose = torch.div(pose, pose_std)
pose = torch.mul(pose, pose_std) ## or divide, not sure
#Now we have our own loss function (Up to Scale Loss Function)
loss = self.loss_function(GT=motion, Est=pose)
loss.backward()
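On the torch.mul vs torch.div question flagged in the comment above: the two are inverses, so either can be correct, but only in matching pairs. If the network output is normalized by pose_std, multiplying the prediction by pose_std recovers metric units; alternatively, the ground-truth motion can be divided by pose_std and the comparison done in normalized units. A small self-contained sketch of the two consistent pairings (all tensors and values below are illustrative stand-ins, not the project's data):

import torch

# Illustrative values only; the real pose_std comes from the dataset statistics.
pose_std = torch.tensor([0.13, 0.13, 0.13, 0.013, 0.013, 0.013])

pose_pred_normalized = torch.randn(8, 6)   # stand-in for the network output
motion_gt_metric = torch.randn(8, 6)       # stand-in for the ground-truth motion

# Option A: compare in metric units -> scale the prediction up by pose_std.
pose_pred_metric = torch.mul(pose_pred_normalized, pose_std)
err_metric = motion_gt_metric - pose_pred_metric

# Option B: compare in normalized units -> scale the ground truth down by pose_std.
motion_gt_normalized = torch.div(motion_gt_metric, pose_std)
err_normalized = motion_gt_normalized - pose_pred_normalized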
@@ -124,26 +136,32 @@ def train_model(self, model, dataloader, optimizer, dataset_len, num_epochs = 10

if epoch % self.SAVE_FREQUENCY == 1:
torch.save({
'model_state_dict': self.vonet.state_dict(),
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, f'./models/epoch_{epoch}_batch_{i}.pkl')
}, f'./models/epoch_{epoch}_batch_{i}.pth')
print(f'Epoch: {epoch}, Model Saved')

print('Finished Training')
PATH = './models/vo_model_pretrained.pkl'
torch.save(model.state_dict(), PATH)
PATH = './models/vo_model_pretrained.pth'
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, PATH)
# torch.save(model.state_dict(), PATH)

def loss_function(self, GT, Est):
# Define a small constant value for epsilon
epsilon = torch.tensor(1e-6)

# Compute the translation error
trans_diff = GT[:, :3] - Est[:, :3]
# The L2 norm is used to calculate the translation error
trans_norm = torch.norm(trans_diff, p=2) + epsilon
trans_loss = torch.mean(trans_norm)

# Compute the rotation error
rot_diff = GT[:, 3:] - Est[:, 3:]
# The Frobenius norm is used to calculate the rotation error
rot_norm = torch.norm(rot_diff, p='fro') + epsilon
rot_loss = torch.mean(rot_norm)

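The comments in the training hunk mention an up-to-scale loss, while loss_function above compares translations directly and takes torch.norm over the whole batch (no dim argument), so the following mean acts on a scalar. One hedged way to write a per-sample, up-to-scale version is to normalize the translations to unit length before the L2 difference, so the loss is invariant to translation scale. This is a reconstruction of the idea referenced in the comments, not the repository's exact code.

import torch

def up_to_scale_loss(GT, Est, epsilon=1e-6):
    """Sketch: per-sample pose loss that ignores the absolute translation scale.

    GT and Est are (batch, 6) tensors: first 3 columns translation, last 3 rotation.
    """
    # Normalize translations to unit length so only the direction is compared.
    gt_trans = GT[:, :3]
    est_trans = Est[:, :3]
    gt_trans = gt_trans / (torch.norm(gt_trans, p=2, dim=1, keepdim=True) + epsilon)
    est_trans = est_trans / (torch.norm(est_trans, p=2, dim=1, keepdim=True) + epsilon)
    trans_loss = torch.mean(torch.norm(gt_trans - est_trans, p=2, dim=1))

    # The rotation part is compared directly, per sample.
    rot_loss = torch.mean(torch.norm(GT[:, 3:] - Est[:, 3:], p=2, dim=1))

    return trans_loss + rot_loss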
Binary file modified results/unity_vo_model_pretrained.png
