Nunzio
added BiSeNet V2
60fd570
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo
# URL of the pretrained backbone checkpoint; BiSeNetV2.load_pretrain()
# fetches it with torch.utils.model_zoo.load_url.
backbone_url = 'https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/backbone_v2.pth'
class ConvBNReLU(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels (also the BatchNorm width).
        ks: square kernel size of the convolution.
        stride, padding, dilation, groups, bias: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
                 dilation=1, groups=1, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks,
                              stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
class UpSample(nn.Module):
    """Learnable upsampling via a 1x1 conv and ``nn.PixelShuffle``.

    The 1x1 projection expands ``n_chan`` to ``n_chan * factor**2``
    channels, and PixelShuffle folds them back into a grid that is
    ``factor`` times larger in each spatial dimension. The output therefore
    has ``n_chan`` channels again.
    """

    def __init__(self, n_chan, factor=2):
        super().__init__()
        expanded = n_chan * factor ** 2
        self.proj = nn.Conv2d(n_chan, expanded, 1, 1, 0)
        self.up = nn.PixelShuffle(factor)
        self.init_weight()

    def forward(self, x):
        return self.up(self.proj(x))

    def init_weight(self):
        """Xavier-normal init of the projection weight (bias keeps its default)."""
        nn.init.xavier_normal_(self.proj.weight, gain=1.)
class DetailBranch(nn.Module):
    """Detail branch: wide, shallow ConvBNReLU stack for spatial detail.

    There are three stages (S1, S2, S3), each opening with a stride-2 conv,
    so the spatial size is reduced 8x overall. The branch expects 3 input
    channels and produces 128 output channels.
    """

    def __init__(self):
        super().__init__()
        self.S1 = nn.Sequential(
            ConvBNReLU(3, 64, 3, stride=2),
            ConvBNReLU(64, 64, 3, stride=1),
        )
        self.S2 = nn.Sequential(
            *[ConvBNReLU(64, 64, 3, stride=s) for s in (2, 1, 1)]
        )
        self.S3 = nn.Sequential(
            ConvBNReLU(64, 128, 3, stride=2),
            *[ConvBNReLU(128, 128, 3, stride=1) for _ in range(2)],
        )

    def forward(self, x):
        for stage in (self.S1, self.S2, self.S3):
            x = stage(x)
        return x
class StemBlock(nn.Module):
    """Stem of the semantic branch (4x spatial downsampling overall).

    A stride-2 ConvBNReLU (3 -> 16 channels) feeds two parallel paths:
    a 1x1 -> 3x3 (stride 2) conv bottleneck and a 3x3 stride-2 max pool.
    Their outputs are concatenated along channels (16 + 16 = 32) and fused
    back to 16 channels with a 3x3 ConvBNReLU.
    """

    def __init__(self):
        super().__init__()
        self.conv = ConvBNReLU(3, 16, 3, stride=2)
        self.left = nn.Sequential(
            ConvBNReLU(16, 8, 1, stride=1, padding=0),
            ConvBNReLU(8, 16, 3, stride=2),
        )
        self.right = nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                  ceil_mode=False)
        self.fuse = ConvBNReLU(32, 16, 3, stride=1)

    def forward(self, x):
        stem = self.conv(x)
        merged = torch.cat((self.left(stem), self.right(stem)), dim=1)
        return self.fuse(merged)
class CEBlock(nn.Module):
    """Context Embedding block operating on 128-channel features.

    The block applies global average pooling, then BatchNorm, then a 1x1
    ConvBNReLU. This yields a per-channel context vector that is
    broadcast-added back onto the input. The sum is refined by a 3x3
    ConvBNReLU.
    """

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(128)
        self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
        # The paper uses a plain conv here (no BN/ReLU); ConvBNReLU is kept.
        self.conv_last = ConvBNReLU(128, 128, 3, stride=1)

    def forward(self, x):
        context = x.mean(dim=(2, 3), keepdim=True)
        context = self.conv_gap(self.bn(context))
        return self.conv_last(x + context)
class GELayerS1(nn.Module):
    """Gather-and-Expansion layer with stride 1 (spatial size preserved).

    The main path is:
    3x3 ConvBNReLU
    -> depthwise 3x3 expansion to ``in_chan * exp_ratio`` channels
    -> 1x1 projection to ``out_chan``.
    The result is added to the input and passed through a ReLU. Because of
    that identity residual, ``out_chan`` must equal ``in_chan``.
    """

    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super().__init__()
        expanded = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv = nn.Sequential(
            nn.Conv2d(in_chan, expanded, kernel_size=3, stride=1, padding=1,
                      groups=in_chan, bias=False),
            nn.BatchNorm2d(expanded),
            nn.ReLU(inplace=True),  # not shown in paper
        )
        projection_bn = nn.BatchNorm2d(out_chan)
        # Flag read by BiSeNetV2.init_weights, which zero-initialises the
        # weight of BNs marked this way.
        projection_bn.last_bn = True
        self.conv2 = nn.Sequential(
            nn.Conv2d(expanded, out_chan, kernel_size=1, stride=1, padding=0,
                      bias=False),
            projection_bn,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv2(self.dwconv(self.conv1(x)))
        return self.relu(out + x)
class GELayerS2(nn.Module):
    """Gather-and-Expansion layer with stride 2 (spatial downsampling by 2).

    Main path:
    3x3 ConvBNReLU
    -> stride-2 depthwise 3x3 expansion to ``in_chan * exp_ratio``
       channels (BN, no activation)
    -> depthwise 3x3 (BN + ReLU)
    -> 1x1 projection to ``out_chan`` (BN).

    Shortcut:
    stride-2 depthwise 3x3 (BN)
    -> 1x1 conv to ``out_chan`` (BN).

    The two paths are summed and passed through a ReLU.
    """

    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super().__init__()
        expanded = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv1 = nn.Sequential(
            nn.Conv2d(in_chan, expanded, kernel_size=3, stride=2, padding=1,
                      groups=in_chan, bias=False),
            nn.BatchNorm2d(expanded),
        )
        self.dwconv2 = nn.Sequential(
            nn.Conv2d(expanded, expanded, kernel_size=3, stride=1, padding=1,
                      groups=expanded, bias=False),
            nn.BatchNorm2d(expanded),
            nn.ReLU(inplace=True),  # not shown in paper
        )
        projection_bn = nn.BatchNorm2d(out_chan)
        # Flag read by BiSeNetV2.init_weights, which zero-initialises the
        # weight of BNs marked this way.
        projection_bn.last_bn = True
        self.conv2 = nn.Sequential(
            nn.Conv2d(expanded, out_chan, kernel_size=1, stride=1, padding=0,
                      bias=False),
            projection_bn,
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_chan, in_chan, kernel_size=3, stride=2, padding=1,
                      groups=in_chan, bias=False),
            nn.BatchNorm2d(in_chan),
            nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=1, padding=0,
                      bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        main = self.conv1(x)
        main = self.dwconv2(self.dwconv1(main))
        main = self.conv2(main)
        return self.relu(main + self.shortcut(x))
class SegmentBranch(nn.Module):
    """Semantic branch: narrow, deep network capturing high-level context.

    ``forward`` returns the output of every stage, so that auxiliary heads
    can supervise them:
    ``(stem, S3, S4, S5 before CE, S5 after CE)``.
    These have 16, 32, 64, 128 and 128 channels respectively.
    """

    def __init__(self):
        super().__init__()
        self.S1S2 = StemBlock()
        self.S3 = nn.Sequential(GELayerS2(16, 32), GELayerS1(32, 32))
        self.S4 = nn.Sequential(GELayerS2(32, 64), GELayerS1(64, 64))
        self.S5_4 = nn.Sequential(
            GELayerS2(64, 128),
            *[GELayerS1(128, 128) for _ in range(3)],
        )
        self.S5_5 = CEBlock()

    def forward(self, x):
        outputs = []
        feat = x
        for stage in (self.S1S2, self.S3, self.S4, self.S5_4, self.S5_5):
            feat = stage(feat)
            outputs.append(feat)
        return tuple(outputs)
class BGALayer(nn.Module):
    """Bilateral Guided Aggregation layer (128 channels in and out).

    Fuses the detail feature ``x_d`` (high resolution) with the semantic
    feature ``x_s`` (low resolution) through mutual sigmoid gating:

    * ``left1(x_d)`` (depthwise 3x3 + 1x1) is gated by ``right1(x_s)``
      (3x3 conv + BN), upsampled to ``x_d``'s size.
    * ``left2(x_d)`` (stride-2 conv + stride-2 avg-pool) is gated by
      ``right2(x_s)`` (depthwise 3x3 + 1x1); the product is then upsampled
      to ``x_d``'s size.

    The two gated maps are summed and refined by a 3x3 conv + BN + ReLU.
    ``left2(x_d)`` and ``right2(x_s)`` must have the same spatial size.
    """

    def __init__(self):
        super().__init__()
        self.left1 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1,
                      groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0,
                      bias=False),
        )
        self.left2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(128),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
        )
        self.right1 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1,
                      bias=False),
            nn.BatchNorm2d(128),
        )
        self.right2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1,
                      groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0,
                      bias=False),
        )
        # Kept so the module's attributes are unchanged. forward() resizes
        # with F.interpolate to x_d's exact size instead of a fixed 4x.
        self.up1 = nn.Upsample(scale_factor=4)
        self.up2 = nn.Upsample(scale_factor=4)
        # In paper, this may have no relu
        self.conv = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1,
                      bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),  # not shown in paper
        )

    def forward(self, x_d, x_s):
        # Upsampling to x_d's size (rather than a fixed 4x) keeps the
        # element-wise products aligned when x_d is not exactly 4x x_s,
        # e.g. inputs whose size is not a multiple of 32. When x_d is
        # exactly 4x x_s this matches nn.Upsample(scale_factor=4).
        dsize = x_d.size()[2:]
        left1 = self.left1(x_d)
        left2 = self.left2(x_d)
        right1 = self.right1(x_s)
        right2 = self.right2(x_s)
        right1 = F.interpolate(right1, size=dsize, mode='nearest')
        left = left1 * torch.sigmoid(right1)
        right = left2 * torch.sigmoid(right2)
        right = F.interpolate(right, size=dsize, mode='nearest')
        return self.conv(left + right)
class SegmentHead(nn.Module):
    """Segmentation head: 3x3 ConvBNReLU -> Dropout(0.1) -> classifier.

    Args:
        in_chan: channels of the incoming feature map.
        mid_chan: channels produced by the first 3x3 ConvBNReLU.
        n_classes: number of output (logit) channels.
        up_factor: total spatial upsampling applied to the logits.
        aux: controls what precedes the 1x1 classifier and how the final
            upsampling is split.
            If True (auxiliary head), the classifier is preceded by a 2x
            nearest upsample and a ConvBNReLU to ``up_factor**2`` channels.
            The final bilinear upsample then covers ``up_factor // 2``.
            If False, the classifier reads ``mid_chan`` channels directly
            and the bilinear upsample is ``up_factor``.
    """

    def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
        self.drop = nn.Dropout(0.1)
        self.up_factor = up_factor

        if aux:
            cls_in = up_factor * up_factor
            final_scale = up_factor // 2
            pre_classifier = nn.Sequential(
                nn.Upsample(scale_factor=2),
                ConvBNReLU(mid_chan, cls_in, 3, stride=1),
            )
        else:
            cls_in = mid_chan
            final_scale = up_factor
            pre_classifier = nn.Identity()

        self.conv_out = nn.Sequential(
            pre_classifier,
            nn.Conv2d(cls_in, n_classes, 1, 1, 0, bias=True),
            nn.Upsample(scale_factor=final_scale, mode='bilinear',
                        align_corners=False),
        )

    def forward(self, x):
        return self.conv_out(self.drop(self.conv(x)))
class CustomArgMax(torch.autograd.Function):
    """Argmax along ``dim``, returned as an int32 tensor.

    During ONNX export this is emitted as a single custom ``CustomArgMax``
    node carrying ``dim`` as an integer attribute.
    """

    @staticmethod
    def forward(ctx, feat_out, dim):
        indices = torch.argmax(feat_out, dim=dim)
        return indices.to(torch.int32)

    @staticmethod
    def symbolic(g, feat_out, dim: int):
        return g.op('CustomArgMax', feat_out, dim_i=dim)
class BiSeNetV2(nn.Module):
    """BiSeNet V2 semantic-segmentation network.

    A detail branch and a semantic branch are fused by a BGA layer and
    decoded by the main segmentation head. In ``'train'`` mode, four
    auxiliary heads additionally decode intermediate semantic features.

    Args:
        n_classes: number of output classes.
        aux_mode: selects what ``forward`` returns.
            ``'train'``: ``(logits, aux2, aux3, aux4, aux5_4)``.
            ``'eval'``: the 1-tuple ``(logits,)``.
            ``'pred'``: int32 class indices (argmax over dim 1).
            Any other value raises ``NotImplementedError`` in ``forward``.
            Auxiliary heads are only built when this is ``'train'``.
        pretrained: keyword-only. If True (default, same as before),
            backbone weights are downloaded from ``backbone_url`` during
            construction. Pass False to build the network offline, e.g.
            before loading a full checkpoint.
    """

    def __init__(self, n_classes, aux_mode='train', *, pretrained=True):
        super().__init__()
        self.aux_mode = aux_mode
        self.pretrained = pretrained
        self.detail = DetailBranch()
        self.segment = SegmentBranch()
        self.bga = BGALayer()
        # Main segmentation head
        self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False)
        if self.aux_mode == 'train':
            # Auxiliary heads for deep supervision
            self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4)
            self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8)
            self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16)
            self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)
        self.init_weights()

    @staticmethod
    def _match_size(logits, size):
        """Bilinearly resize ``logits`` to spatial ``size``; no-op if already equal."""
        if logits.shape[2:] == size:
            return logits
        return F.interpolate(logits, size=size, mode='bilinear',
                             align_corners=False)

    def forward(self, x):
        # The heads upsample by fixed factors, so for inputs whose size is
        # not a multiple of the network stride their output can differ from
        # the input size. Resize to the input size in that case; for inputs
        # that already match, the logits are returned unchanged.
        size = x.size()[2:]
        feat_d = self.detail(x)
        feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
        feat_head = self.bga(feat_d, feat_s)
        logits = self._match_size(self.head(feat_head), size)
        if self.aux_mode == 'train':
            logits_aux2 = self._match_size(self.aux2(feat2), size)
            logits_aux3 = self._match_size(self.aux3(feat3), size)
            logits_aux4 = self._match_size(self.aux4(feat4), size)
            logits_aux5_4 = self._match_size(self.aux5_4(feat5_4), size)
            return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4
        elif self.aux_mode == 'eval':
            return logits,
        elif self.aux_mode == 'pred':
            # Use custom argmax for ONNX compatibility
            return CustomArgMax.apply(logits, 1)
        raise NotImplementedError(f"unsupported aux_mode: {self.aux_mode!r}")

    def init_weights(self):
        """Initialise parameters, then optionally load the pretrained backbone.

        * Conv2d/Linear: Kaiming-normal weights (``mode='fan_out'``) and
          zero biases.
        * BatchNorm: bias zero; weight one, or zero for BNs whose
          ``last_bn`` attribute is truthy.

        The backbone is then loaded via ``load_pretrain`` unless the model
        was built with ``pretrained=False``.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                if getattr(module, 'last_bn', False):
                    nn.init.zeros_(module.weight)
                else:
                    nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
        if getattr(self, 'pretrained', True):
            self.load_pretrain()

    def load_pretrain(self):
        """Download the backbone checkpoint and load it into matching children.

        Each top-level key of the checkpoint that equals a child module name
        is loaded into that child with ``strict=True``. Presumably the keys
        are the two backbone branches; TODO confirm against the checkpoint.
        Tensors are mapped to CPU first, so a checkpoint saved from GPU
        loads on CPU-only hosts. ``load_state_dict`` then copies the values
        onto each parameter's own device.
        """
        state = modelzoo.load_url(backbone_url, map_location='cpu')
        for name, child in self.named_children():
            if name in state:
                child.load_state_dict(state[name], strict=True)

    def get_params(self):
        """Split parameters into optimizer groups.

        Returns:
            ``(wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)``.
            Parameters of children whose name contains ``'head'`` or
            ``'aux'`` go into the ``lr_mul_*`` lists; all others go into
            the plain lists. Within each pair, 1-D parameters (biases, BN
            affine) go to the ``nowd`` list and every other parameter goes
            to the ``wd`` list. No parameter is left out of all groups.
        """
        wd_params, nowd_params = [], []
        lr_mul_wd_params, lr_mul_nowd_params = [], []
        for name, child in self.named_children():
            if 'head' in name or 'aux' in name:
                wd_list, nowd_list = lr_mul_wd_params, lr_mul_nowd_params
            else:
                wd_list, nowd_list = wd_params, nowd_params
            for param in child.parameters():
                # Previously, params of rank other than 1 or 4 were printed
                # and silently dropped from the optimizer; keep them (with
                # weight decay) instead.
                if param.dim() == 1:
                    nowd_list.append(param)
                else:
                    wd_list.append(param)
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params