diff --git a/README.md b/README.md
index 77d98d9..318a45e 100644
--- a/README.md
+++ b/README.md
@@ -103,12 +103,17 @@ The core idea of NCRF is taking a grid of patches as input, e.g. 3x3, using CNN
 ```python
 def forward(self, x):
     """
-    x here is assumed to be a 5-D tensor with shape of
+    Args:
+        x: 5D tensor with shape of
     [batch_size, grid_size, 3, crop_size, crop_size],
     where grid_size is the number of patches within a grid (e.g. 9 for
     a 3x3 grid); crop_size is 224 by default for ResNet input;
+    Returns:
+        logits, 2D tensor with shape of [batch_size, grid_size], the logit
+        of each patch within the grid being tumor
     """
     batch_size, grid_size, _, crop_size = x.shape[0:4]
+    # flatten grid_size dimension and combine it into batch dimension
     x = x.view(-1, 3, crop_size, crop_size)
 
     x = self.conv1(x)
@@ -122,9 +127,11 @@ def forward(self, x):
     x = self.layer4(x)
 
     x = self.avgpool(x)
+    # feats means features, i.e. patch embeddings from ResNet
     feats = x.view(x.size(0), -1)
     logits = self.fc(feats)
 
+    # restore grid_size dimension for CRF
     feats = feats.view((batch_size, grid_size, -1))
     logits = logits.view((batch_size, grid_size, -1))
 