Skip to content

Commit

Permalink
Merge branch 'inference' of github.com:TimoStoff/event_cnn_minimal in…
Browse files Browse the repository at this point in the history
…to inference
  • Loading branch information
TimoStoff committed May 12, 2020
2 parents 659010b + fd376cb commit a9e2159
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 4 additions & 0 deletions data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ def __getitem__(self, i, seed=None):
events_end_idx = np.searchsorted(self.np_ts,
self.voxel_method['t'] + i * self.voxel_method['sliding_window_t'] + self.t0)

if events_end_idx == events_start_idx:
events_end_idx = events_start_idx + 1
print('WARNING! Set events_end_idx to events_start_idx + 1 at i={}'.format(i))

timestamp = self.h5_file['events/ts'][events_start_idx]
else:
raise Exception("Unsupported voxel_method")
Expand Down
6 changes: 1 addition & 5 deletions model/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def build_prediction_layer(self, num_output_channels, norm=None):
return ConvLayer(self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels,
num_output_channels, 1, activation=None, norm=norm)


class WNet(BaseUNet):
"""
Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block,
Expand Down Expand Up @@ -94,7 +95,6 @@ def __init__(self, unet_kwargs):
def forward(self, x):
"""
:param x: N x num_input_channels x H x W
:param prev_states: previous LSTM states for every encoder layer
:return: N x num_output_channels x H x W
"""

Expand Down Expand Up @@ -159,7 +159,6 @@ def __init__(self, unet_kwargs):
def forward(self, x):
"""
:param x: N x num_input_channels x H x W
:param prev_states: previous LSTM states for every encoder layer
:return: N x num_output_channels x H x W
"""

Expand Down Expand Up @@ -277,7 +276,6 @@ def __init__(self, unet_kwargs):
def forward(self, x):
"""
:param x: N x num_input_channels x H x W
:param prev_states: previous LSTM states for every encoder layer
:return: N x num_output_channels x H x W
"""

Expand Down Expand Up @@ -326,11 +324,9 @@ def __init__(self, unet_kwargs):
self.decoders = self.build_decoders()
self.pred = ConvLayer(self.base_num_channels, self.num_output_channels, kernel_size=1, activation=None)


def forward(self, x):
"""
:param x: N x num_input_channels x H x W
:param prev_states: previous LSTM states for every encoder layer
:return: N x num_output_channels x H x W
"""

Expand Down

0 comments on commit a9e2159

Please sign in to comment.