
Commit

modify code format
justld committed Nov 10, 2021
1 parent ecd0830 commit 916d06a
Showing 3 changed files with 68 additions and 70 deletions.
@@ -25,6 +25,5 @@ model:
  type: HRNet_W48
  pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w48_ssld.tar.gz
  in_channels: 720
  num_classes: 19
  drop_prob: 0.1
  proj_dim: 720
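
Note: these keys configure HRNetW48Contrast and its HRNet_W48 backbone. As a minimal sketch of the model they describe (the import paths and keyword signature are assumed from this diff, not guaranteed by it; in practice PaddleSeg's component manager builds the model from the YAML):

    import paddle
    from paddleseg.models import HRNetW48Contrast
    from paddleseg.models.backbones import HRNet_W48

    # Hypothetical manual wiring; the trainer normally does this from the YAML
    # and loads the backbone weights from the bcebos URL above.
    model = HRNetW48Contrast(
        in_channels=720,    # total channels of the concatenated HRNet-W48 branches
        num_classes=19,     # Cityscapes
        backbone=HRNet_W48(),
        drop_prob=0.1,
        proj_dim=720,
    )
    model.eval()
    logits = model(paddle.randn([1, 3, 512, 1024]))[0]  # [1, 19, 512, 1024] in eval mode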
79 changes: 38 additions & 41 deletions paddleseg/models/hrnet_contrast.py
@@ -21,31 +21,6 @@
 from paddleseg.utils import utils


-class ProjectionHead(nn.Layer):
-    """
-    The projection head used by contrast learning.
-    Args:
-        dim_in (int): The dimensions of input features.
-        proj_dim (int|optional): The output dimensions of projection head. Default: 256.
-        proj (str|optional): The type of projection head, only support 'linear' and 'convmlp'. Default: 'convmlp'
-    """
-    def __init__(self, dim_in, proj_dim=256, proj='convmlp'):
-        super(ProjectionHead, self).__init__()
-
-        if proj == 'linear':
-            self.proj = nn.Conv2D(dim_in, proj_dim, kernel_size=1)
-        elif proj == 'convmlp':
-            self.proj = nn.Sequential(
-                nn.Conv2D(dim_in, dim_in, kernel_size=1),
-                layers.SyncBatchNorm(dim_in),
-                nn.ReLU(),
-                nn.Conv2D(dim_in, proj_dim, kernel_size=1),
-            )
-
-    def forward(self, x):
-        return F.normalize(self.proj(x), p=2, axis=1)
-
-
 @manager.MODELS.add_component
 class HRNetW48Contrast(nn.Layer):
     """
@@ -82,13 +57,11 @@ def __init__(self,
         self.pretrained = pretrained

         self.cls_head = nn.Sequential(
-            nn.Conv2D(in_channels,
-                      in_channels,
-                      kernel_size=3,
-                      stride=1,
-                      padding=1),
-            layers.SyncBatchNorm(in_channels),
-            nn.ReLU(),
+            layers.ConvBNReLU(in_channels,
+                              in_channels,
+                              kernel_size=3,
+                              stride=1,
+                              padding=1),
             nn.Dropout2D(drop_prob),
             nn.Conv2D(in_channels,
                       num_classes,
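
Note: this swap replaces the hand-rolled Conv2D + SyncBatchNorm + ReLU stack with PaddleSeg's fused layers.ConvBNReLU helper; behavior should be unchanged. A minimal sketch of the equivalence, assuming 720 channels as in the config above (shapes only):

    import paddle
    import paddle.nn as nn
    from paddleseg.models import layers

    # The fused block used in the new code...
    fused = layers.ConvBNReLU(720, 720, kernel_size=3, stride=1, padding=1)

    # ...stands in for the stack the old code spelled out by hand:
    unfused = nn.Sequential(
        nn.Conv2D(720, 720, kernel_size=3, stride=1, padding=1),
        layers.SyncBatchNorm(720),
        nn.ReLU(),
    )

    x = paddle.randn([2, 720, 64, 128])
    assert fused(x).shape == unfused(x).shape  # [2, 720, 64, 128]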
@@ -106,21 +79,45 @@ def init_weight(self):
     def forward(self, x):
         feats = self.backbone(x)[0]
         out = self.cls_head(feats)
+        logit_list = []
         if self.training:
             emb = self.proj_head(feats)
-            return [
+            logit_list.append(
                 F.interpolate(out,
                               x.shape[2:],
                               mode='bilinear',
-                              align_corners=self.align_corners), {
-                                  'seg': out,
-                                  'embed': emb
-                              }
-            ]
+                              align_corners=self.align_corners))
+            logit_list.append({'seg': out, 'embed': emb})
         else:
-            return [
+            logit_list.append(
                 F.interpolate(out,
                               x.shape[2:],
                               mode='bilinear',
-                              align_corners=self.align_corners)
-            ]
+                              align_corners=self.align_corners))
+        return logit_list
+
+
+class ProjectionHead(nn.Layer):
+    """
+    The projection head used by contrast learning.
+    Args:
+        dim_in (int): The dimensions of input features.
+        proj_dim (int, optional): The output dimensions of projection head. Default: 256.
+        proj (str, optional): The type of projection head, only support 'linear' and 'convmlp'. Default: 'convmlp'.
+    """
+    def __init__(self, dim_in, proj_dim=256, proj='convmlp'):
+        super(ProjectionHead, self).__init__()
+        if proj == 'linear':
+            self.proj = nn.Conv2D(dim_in, proj_dim, kernel_size=1)
+        elif proj == 'convmlp':
+            self.proj = nn.Sequential(
+                layers.ConvBNReLU(dim_in, dim_in, kernel_size=1),
+                nn.Conv2D(dim_in, proj_dim, kernel_size=1),
+            )
+        else:
+            raise ValueError(
+                "The type of project head only support 'linear' and 'convmlp', but got {}."
+                .format(proj))
+
+    def forward(self, x):
+        return F.normalize(self.proj(x), p=2, axis=1)
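
Note: moving ProjectionHead below the registered model is cosmetic; the functional change is the explicit ValueError for unsupported head types. As a quick sketch of what the head guarantees, every pixel embedding it returns is L2-normalized along the channel axis (tensor sizes here are illustrative):

    import paddle

    head = ProjectionHead(dim_in=720, proj_dim=720)   # dims as configured above
    feats = paddle.randn([2, 720, 64, 128])           # toy backbone features
    emb = head(feats)                                 # same spatial size, proj_dim channels
    norms = paddle.norm(emb, p=2, axis=1)             # per-pixel embedding norms
    print(paddle.allclose(norms, paddle.ones_like(norms), atol=1e-5))  # True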
58 changes: 30 additions & 28 deletions paddleseg/models/losses/pixel_contrast_cross_entropy_loss.py
@@ -29,11 +29,11 @@ class PixelContrastCrossEntropyLoss(nn.Layer):
     (https://arxiv.org/abs/2101.11939).
     Args:
-        temperature (float|optional): Controling the numerical similarity of features. Default: 0.1.
-        base_temperature (float|optional): Controling the numerical range of contrast loss. Default: 0.07.
-        ignore_index (int|optional): Specifies a target value that is ignored
+        temperature (float, optional): Controling the numerical similarity of features. Default: 0.1.
+        base_temperature (float, optional): Controling the numerical range of contrast loss. Default: 0.07.
+        ignore_index (int, optional): Specifies a target value that is ignored
             and does not contribute to the input gradient. Default 255.
-        max_samples (int|optional): Max sampling anchors. Default: 1024.
+        max_samples (int, optional): Max sampling anchors. Default: 1024.
         max_views (int): Sampled samplers of a class. Default: 100.
     """
     def __init__(self,
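
Note: the docstring change only normalizes `(float|optional)` to the `(float, optional)` convention used elsewhere in PaddleSeg. For orientation, a hedged instantiation with the documented defaults (the forward call against the model's {'seg', 'embed'} outputs is not shown in this hunk):

    from paddleseg.models.losses import PixelContrastCrossEntropyLoss

    loss_fn = PixelContrastCrossEntropyLoss(
        temperature=0.1,        # sharpness of the feature-similarity softmax
        base_temperature=0.07,  # final loss is rescaled by temperature / base_temperature
        ignore_index=255,       # pixels with this label contribute no gradient
        max_samples=1024,       # anchor budget across the whole batch
        max_views=100,          # cap on sampled anchors per class
    )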
@@ -42,7 +42,7 @@ def __init__(self,
                  ignore_index=255,
                  max_samples=1024,
                  max_views=100):
-        super(PixelContrastCrossEntropyLoss, self).__init__()
+        super().__init__()
         self.temperature = temperature
         self.base_temperature = base_temperature
         self.ignore_index = ignore_index
@@ -52,24 +52,26 @@ def __init__(self,
     def _hard_anchor_sampling(self, X, y_hat, y):
         """
         Args:
-            X(Tensor): reshaped feats, shape = [N, H*W, feat_channels]
-            y_hat(Tensor): reshaped label, shape = [N, H*W]
-            y(Tensor): reshaped predict, shape = [N, H*W]
+            X (Tensor): reshaped feats, shape = [N, H * W, feat_channels]
+            y_hat (Tensor): reshaped label, shape = [N, H * W]
+            y (Tensor): reshaped predict, shape = [N, H * W]
         """
-        batch_size, feat_dim = X.shape[0], X.shape[-1]
+        batch_size, feat_dim = paddle.shape(X)[0], paddle.shape(X)[-1]
         classes = []
         total_classes = 0
-        for ii in range(batch_size):
-            this_y = y_hat[ii]
-            this_classes = paddle.unique(this_y)
-            this_classes = [x for x in this_classes if x != self.ignore_index]
-            this_classes = [
-                x for x in this_classes
-                if (this_y == x).nonzero().shape[0] > self.max_views
-            ]
+        for i in range(batch_size):
+            current_y = y_hat[i]
+            current_classes = paddle.unique(current_y)
+            current_classes = [
+                x for x in current_classes if x != self.ignore_index
+            ]
+            current_classes = [
+                x for x in current_classes
+                if (current_y == x).nonzero().shape[0] > self.max_views
+            ]

-            classes.append(this_classes)
-            total_classes += len(this_classes)
+            classes.append(current_classes)
+            total_classes += len(current_classes)

         n_view = self.max_samples // total_classes
         n_view = min(n_view, self.max_views)
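
Note: the per-class sampling budget in the two context lines above is easy to trace with concrete numbers (hypothetical counts):

    # Hypothetical numbers tracing the two lines above:
    max_samples, max_views = 1024, 100
    total_classes = 6                        # classes that survived the filters batch-wide
    n_view = max_samples // total_classes    # 1024 // 6 = 170
    n_view = min(n_view, max_views)          # capped to 100 anchors per class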
@@ -78,16 +80,16 @@ def _hard_anchor_sampling(self, X, y_hat, y):
         y_ = paddle.zeros([total_classes], dtype='float32')

         X_ptr = 0
-        for ii in range(batch_size):
-            this_y_hat = y_hat[ii]
-            this_y = y[ii]
-            this_classes = classes[ii]
+        for i in range(batch_size):
+            this_y_hat = y_hat[i]
+            current_y = y[i]
+            current_classes = classes[i]

-            for cls_id in this_classes:
-                hard_indices = paddle.logical_and((this_y_hat == cls_id),
-                                                  (this_y != cls_id)).nonzero()
-                easy_indices = paddle.logical_and((this_y_hat == cls_id),
-                                                  (this_y == cls_id)).nonzero()
+            for cls_id in current_classes:
+                hard_indices = paddle.logical_and(
+                    (this_y_hat == cls_id), (current_y != cls_id)).nonzero()
+                easy_indices = paddle.logical_and(
+                    (this_y_hat == cls_id), (current_y == cls_id)).nonzero()

                 num_hard = hard_indices.shape[0]
                 num_easy = easy_indices.shape[0]
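
Note: with the renames, this_y_hat holds labels and current_y holds predictions, so "hard" anchors are misclassified pixels of a class and "easy" anchors are correctly classified ones. A toy trace (illustrative tensors):

    import paddle

    cls_id = 1
    y_hat = paddle.to_tensor([1, 1, 2, 1])  # labels for four pixels
    y = paddle.to_tensor([1, 0, 2, 1])      # predictions for the same pixels
    hard = paddle.logical_and(y_hat == cls_id, y != cls_id).nonzero()  # [[1]]
    easy = paddle.logical_and(y_hat == cls_id, y == cls_id).nonzero()  # [[0], [3]]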
@@ -119,7 +121,7 @@ def _hard_anchor_sampling(self, X, y_hat, y):
                 if indices is None:
                     raise UserWarning('hard sampling indice error')

-                X_.append(paddle.index_select(X[ii, :, :], indices.squeeze(1)))
+                X_.append(paddle.index_select(X[i, :, :], indices.squeeze(1)))
                 y_[X_ptr] = float(cls_id)
                 X_ptr += 1
         X_ = paddle.stack(X_, axis=0)
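
Note: paddle.index_select gathers the chosen pixel rows from image i's flattened feature map. A shape-only sketch (toy sizes):

    import paddle

    X_i = paddle.randn([8, 4])                   # one image: [H*W = 8, feat_channels = 4]
    indices = paddle.to_tensor([[2], [5], [7]])  # as returned by nonzero(): shape [3, 1]
    rows = paddle.index_select(X_i, indices.squeeze(1))  # -> shape [3, 4]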
