Skip to content

Commit

Permalink
register buffer in TPS-grid for multi-gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Baek JeongHun committed Apr 11, 2019
1 parent a89c081 commit 7553b47
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, opt):
""" Transformation """
if opt.Transformation == 'TPS':
self.Transformation = TPS_SpatialTransformerNetwork(
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=opt.batch_size, I_channel_num=opt.input_channel)
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=int(opt.batch_size/opt.num_gpu), I_channel_num=opt.input_channel)
else:
print('No Transformation module specified')

Expand Down
8 changes: 4 additions & 4 deletions modules/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ def __init__(self, F, I_r_size, batch_size):
self.F = F
self.C = self._build_C(self.F) # F x 2
self.P = self._build_P(self.I_r_width, self.I_r_height)
self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3
self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3
self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3
self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3

self.batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) # n x F+3 -> batch_size x n x F+3
self.batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) # F+3 x F+3 -> batch_size x F+3 x F+3
self.register_buffer("batch_inv_delta_C", self.inv_delta_C.repeat(batch_size, 1, 1))
self.register_buffer("batch_P_hat", self.P_hat.repeat(batch_size, 1, 1))

def _build_C(self, F):
""" Return coordinates of fiducial points in I_r; C """
Expand Down

0 comments on commit 7553b47

Please sign in to comment.