| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from ..base import modules |
|
|
|
|
class PSPBlock(nn.Module):
    """Single pyramid-pooling branch: adaptive average pool, 1x1 conv, resize.

    The input is average-pooled to a ``pool_size x pool_size`` grid, passed
    through a ``modules.Conv2dReLU`` with a ``(1, 1)`` kernel, and bilinearly
    resized back to the spatial size of the input.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count requested from the 1x1 convolution.
        pool_size: Side length of the square grid the input is pooled to.
        use_bathcnorm: Forwarded to ``Conv2dReLU`` as ``use_batchnorm``;
            replaced by ``False`` whenever ``pool_size`` is 1.
    """

    def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True):
        super().__init__()
        # The 1x1 (global) branch never uses batch norm.
        # NOTE(review): presumably because a single spatial value per channel
        # is too little for batch statistics -- confirm.
        norm_flag = False if pool_size == 1 else use_bathcnorm

        pooling = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
        projection = modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_batchnorm=norm_flag)
        self.pool = nn.Sequential(pooling, projection)

    def forward(self, x):
        """Pool and project ``x``, then resize to the input's dims 2 and 3."""
        # Capture the target size before ``x`` is replaced by the pooled map.
        target_size = (x.size(2), x.size(3))
        pooled = self.pool(x)
        return F.interpolate(pooled, size=target_size, mode='bilinear', align_corners=True)
|
|
|
|
class PSPModule(nn.Module):
    """Run several ``PSPBlock`` branches on one input and concatenate them.

    One ``PSPBlock`` is built per entry of ``sizes``; each is asked for
    ``in_channels // len(sizes)`` output channels. The forward pass
    concatenates every branch output, followed by the unmodified input,
    along dimension 1.

    Args:
        in_channels: Channel count of the incoming feature map.
        sizes: Pooling grid sizes, one per branch.
        use_bathcnorm: Passed through to every ``PSPBlock``.
    """

    def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True):
        super().__init__()

        self.blocks = nn.ModuleList()
        for pool_size in sizes:
            branch = PSPBlock(
                in_channels,
                in_channels // len(sizes),
                pool_size,
                use_bathcnorm=use_bathcnorm,
            )
            self.blocks.append(branch)

    def forward(self, x):
        """Return the branch outputs and ``x`` itself stacked on dim 1."""
        pyramid = [branch(x) for branch in self.blocks]
        # The raw input goes last, after all pooled branches.
        pyramid.append(x)
        return torch.cat(pyramid, dim=1)
|
|
|
|
class PSPDecoder(nn.Module):
    """Decoder head: pyramid pooling, 1x1 fusing conv, then ``nn.Dropout2d``.

    Only the last entry of the encoder feature list is consumed.

    Args:
        encoder_channels: Sequence of encoder channel counts; only the last
            entry is read.
        use_batchnorm: Forwarded to the pyramid module and to the fusing
            ``Conv2dReLU``.
        out_channels: Channel count produced by the fusing convolution.
        dropout: Probability handed to ``nn.Dropout2d``.
    """

    @staticmethod
    def _psp_out_channels(in_channels, sizes):
        """Return the channel count of ``PSPModule``'s concatenated output.

        Each of the ``len(sizes)`` branches emits ``in_channels // len(sizes)``
        channels and the raw input adds ``in_channels`` more. This equals
        ``2 * in_channels`` only when ``in_channels`` is divisible by
        ``len(sizes)``.
        """
        return in_channels + len(sizes) * (in_channels // len(sizes))

    def __init__(
        self,
        encoder_channels,
        use_batchnorm=True,
        out_channels=512,
        dropout=0.2,
    ):
        super().__init__()

        in_channels = encoder_channels[-1]
        sizes = (1, 2, 3, 6)

        self.psp = PSPModule(
            in_channels=in_channels,
            sizes=sizes,
            use_bathcnorm=use_batchnorm,
        )

        # Size the fusing conv from the real concatenated width rather than
        # assuming ``2 * in_channels``; the two differ when ``in_channels`` is
        # not a multiple of ``len(sizes)``, which made the forward pass fail
        # with a channel mismatch.
        self.conv = modules.Conv2dReLU(
            in_channels=self._psp_out_channels(in_channels, sizes),
            out_channels=out_channels,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )

        self.dropout = nn.Dropout2d(p=dropout)

    def forward(self, *features):
        """Decode the last tensor in ``features``; earlier ones are ignored."""
        x = features[-1]
        x = self.psp(x)
        x = self.conv(x)
        x = self.dropout(x)

        return x
|
|