diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py index 3da0e0b5e4..ccca07dee2 100644 --- a/monai/networks/nets/cell_sam_wrapper.py +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -33,10 +33,10 @@ class CellSamWrapper(torch.nn.Module): auto_resize_inputs: whether to resize inputs before passing to the network. network_resize_roi: expected input size for the network. checkpoint: checkpoint file to load the SAM weights from. - return_features: whether to return features + return_features: whether to return features """ - + def __init__( self, auto_resize_inputs=True, @@ -77,7 +77,7 @@ def forward(self, x): if self.auto_resize_inputs: x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear") - x = self.model.image_encoder(x) + x = self.model.image_encoder(x) if not self.return_features: x = self.model.mask_decoder(x)