Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
yil8 authored Jun 17, 2018
1 parent 471918a commit 243e3c0
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,17 @@ The core idea of NCRF is taking a grid of patches as input, e.g. 3x3, using CNN
```python
def forward(self, x):
"""
x here is assumed to be a 5-D tensor with shape of
Args:
x: 5D tensor with shape of
[batch_size, grid_size, 3, crop_size, crop_size],
where grid_size is the number of patches within a grid (e.g. 9 for
a 3x3 grid); crop_size is 224 by default for ResNet input;
Returns:
logits, 2D tensor with shape of [batch_size, grid_size], the logit
of each patch within the grid being tumor
"""
batch_size, grid_size, _, crop_size = x.shape[0:4]
# flatten grid_size dimension and combine it into batch dimension
x = x.view(-1, 3, crop_size, crop_size)

x = self.conv1(x)
Expand All @@ -122,9 +127,11 @@ def forward(self, x):
x = self.layer4(x)

x = self.avgpool(x)
# feats means features, i.e. patch embeddings from ResNet
feats = x.view(x.size(0), -1)
logits = self.fc(feats)

# restore grid_size dimension for CRF
feats = feats.view((batch_size, grid_size, -1))
logits = logits.view((batch_size, grid_size, -1))

Expand Down

0 comments on commit 243e3c0

Please sign in to comment.