| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class ConvBnRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> optional ReLU -> optional 2x upsample.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution (and normalized by
            the batch norm).
        kernel_size, stride, padding, dilation, groups, bias: Forwarded
            unchanged to ``nn.Conv2d``.
        add_relu: When True, apply an in-place ReLU after the batch norm.
        interpolate: When True, upsample the result by a factor of 2 with
            bilinear interpolation (``align_corners=True``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        add_relu: bool = True,
        interpolate: bool = False,
    ):
        super().__init__()
        # Submodules are registered in the order conv, bn, activation so that
        # parameter ordering and state_dict keys stay stable.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.add_relu = add_relu
        self.interpolate = interpolate
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run conv + bn, then the optional ReLU and optional 2x upsample."""
        out = self.bn(self.conv(x))
        out = self.activation(out) if self.add_relu else out
        if not self.interpolate:
            return out
        return F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=True)
|
|
|
|
class FPABlock(nn.Module):
    """Pyramid attention block applied to a single feature map.

    Three paths are computed from the input ``x`` and merged:

    * ``branch1``: global average pool -> 1x1 ConvBnRelu, upsampled back to
      the input resolution (a global-context term).
    * ``mid``: 1x1 ConvBnRelu at full resolution with ``out_channels``
      channels.
    * A single-channel pyramid: ``down1``/``down2``/``down3`` each halve the
      resolution with a 2x2 max pool and apply 7x7, 5x5 and two 3x3
      convolutions respectively. The levels are merged coarse-to-fine
      (``conv2``/``conv1`` refine the 1/4 and 1/2 levels before each sum)
      and upsampled back to the input resolution.

    The output is ``pyramid * mid + branch1``. The one-channel pyramid map
    broadcasts across the ``out_channels`` channels of ``mid``.

    Because of the three stride-2 max pools, the input height and width must
    be at least 8.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        upscale_mode: ``mode`` used for every ``F.interpolate`` call.
            ``align_corners`` is ``True`` for ``'bilinear'`` and ``None``
            otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        upscale_mode: str = 'bilinear',
    ):
        super(FPABlock, self).__init__()

        self.upscale_mode = upscale_mode
        # Use None (not False) for non-bilinear modes. F.interpolate raises
        # ValueError if align_corners is set for 'nearest'/'area'/'nearest-exact',
        # and it treats None as False for the other interpolating modes.
        # This matches GAUBlock.
        self.align_corners = True if upscale_mode == 'bilinear' else None

        # Global-context branch: (N, out_channels, 1, 1) after pooling.
        self.branch1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        )

        # Full-resolution projection that the pyramid map will scale.
        self.mid = nn.Sequential(
            ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        )
        # Single-channel pyramid at 1/2, 1/4 and 1/8 of the input resolution.
        self.down1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(in_channels=in_channels, out_channels=1, kernel_size=7, stride=1, padding=3)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
        )
        # Refinement convs applied to the 1/4 (conv2) and 1/2 (conv1) levels
        # before they are merged with the coarser level.
        self.conv2 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2)
        self.conv1 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        """Return ``pyramid * mid + global_context`` at the input resolution."""
        h, w = x.size(2), x.size(3)
        upscale_parameters = dict(
            mode=self.upscale_mode,
            align_corners=self.align_corners,
        )

        # Global context: pooled to 1x1, then broadcast back to (h, w).
        b1 = self.branch1(x)
        b1 = F.interpolate(b1, size=(h, w), **upscale_parameters)

        mid = self.mid(x)

        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        # Merge coarse-to-fine. MaxPool2d(2, 2) floors, so x2 is exactly
        # (h // 4, w // 4) and x1 is (h // 2, w // 2); the sums line up for any h, w.
        x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters)
        x = self.conv2(x2) + x3
        x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters)

        x = x + self.conv1(x1)
        x = F.interpolate(x, size=(h, w), **upscale_parameters)

        # The single-channel pyramid map broadcasts over mid's channels.
        x = torch.mul(x, mid)
        return x + b1
|
|
|
|
class GAUBlock(nn.Module):
    """Fuse a low-level feature map with a high-level one via a channel gate.

    The high-level feature ``y`` is used twice:

    * It is upsampled to the spatial size of ``x``.
    * It is squeezed into a per-channel gate: global average pool -> 1x1
      ConvBnRelu (no ReLU) -> sigmoid.

    The low-level feature ``x`` is projected to ``out_channels`` by a 3x3
    ConvBnRelu and multiplied by the gate. The result is added to the
    upsampled ``y``.

    Args:
        in_channels: Channels of the low-level feature ``x``.
        out_channels: Channels of the high-level feature ``y`` and of the
            output.
        upscale_mode: ``mode`` for ``F.interpolate``. ``align_corners`` is
            ``True`` for ``'bilinear'`` and ``None`` otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        upscale_mode: str = 'bilinear'
    ):
        super().__init__()
        self.upscale_mode = upscale_mode
        if upscale_mode == 'bilinear':
            self.align_corners = True
        else:
            self.align_corners = None

        # conv1 is registered before conv2 so parameter order and
        # state_dict keys are unchanged.
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                add_relu=False,
            ),
            nn.Sigmoid(),
        )
        self.conv2 = ConvBnRelu(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x, y):
        """
        Args:
            x: low level feature
            y: high level feature
        """
        height, width = x.size(2), x.size(3)
        gate = self.conv1(y)
        upsampled = F.interpolate(
            y,
            size=(height, width),
            mode=self.upscale_mode,
            align_corners=self.align_corners,
        )
        return upsampled + self.conv2(x) * gate
|
|
|
|
class PANDecoder(nn.Module):
    """Decoder that fuses the last four encoder feature maps.

    The deepest map (``features[-1]``) goes through an ``FPABlock``. The
    result is then refined by three ``GAUBlock``s, each fusing it with the
    next shallower map (``features[-2]``, ``[-3]``, ``[-4]``). Only the last
    four entries of ``encoder_channels`` / ``features`` are read.

    Args:
        encoder_channels: Indexable channel counts of the encoder outputs,
            with the deepest last. At least four entries are required.
        decoder_channels: Output channels of every decoder block.
        upscale_mode: Interpolation mode for the three GAU blocks. The FPA
            block does not receive this value and always uses its own
            default (``'bilinear'``).
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        upscale_mode: str = 'bilinear'
    ):
        super().__init__()

        self.fpa = FPABlock(
            in_channels=encoder_channels[-1],
            out_channels=decoder_channels,
        )
        gau_options = dict(out_channels=decoder_channels, upscale_mode=upscale_mode)
        self.gau3 = GAUBlock(in_channels=encoder_channels[-2], **gau_options)
        self.gau2 = GAUBlock(in_channels=encoder_channels[-3], **gau_options)
        self.gau1 = GAUBlock(in_channels=encoder_channels[-4], **gau_options)

    def forward(self, *features):
        """Decode from the deepest feature up to ``features[-4]``'s resolution."""
        decoded = self.fpa(features[-1])
        decoded = self.gau3(features[-2], decoded)
        decoded = self.gau2(features[-3], decoded)
        return self.gau1(features[-4], decoded)
|
|