| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
class NormalizedMultiScaleAttention(nn.Module):
    """
    Normalized Multi-Scale Attention (Normalized-MSA) module.

    Enhances multi-scale feature representation by balancing computational
    efficiency with representation strength. For every entry ``s`` of
    ``scales`` the input is average-pooled with kernel/stride ``s`` (``s == 1``
    uses the input unchanged), then:

    1. a channel-affinity attention is applied: the channel Gram matrix of the
       spatially flattened features is divided by ``sqrt(H_s * W_s)``,
       soft-maxed over its last dimension and used to mix the channels;
    2. the result is multiplied by a per-scale sigmoid spatial gate;
    3. for ``s != 1`` it is bilinearly resized back to the input height/width.

    The per-scale outputs are weighted by the learnable ``scale_weights`` and
    summed, and ``edge_scale * edge_conv(x)`` is added.

    Args:
        in_channels: Number of input channels; the output has the same count.
        scales: Pooling factors, one branch per entry. Copied into a new list,
            so neither the default nor the caller's sequence is aliased.
            NOTE(review): the input height/width must be at least each scale,
            otherwise average pooling yields an empty map and fails.
        edge_scale: Weight of the ``edge_conv`` branch in the output
            (default 0.1, the value that was previously hard-coded).

    Raises:
        ValueError: If ``scales`` is empty.
    """

    def __init__(self, in_channels, scales=(1, 2, 4), edge_scale=0.1):
        super().__init__()
        # Copy into a fresh list: a shared mutable default (or the caller's own
        # list) must not be aliased by every instance.
        self.scales = list(scales)
        if not self.scales:
            raise ValueError("scales must contain at least one pooling factor")
        self.in_channels = in_channels
        self.edge_scale = edge_scale

        num_scales = len(self.scales)

        # One spatial gate per scale; the Sigmoid maps it into (0, 1).
        self.spatial_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(in_channels),
                nn.Sigmoid(),
            )
            for _ in range(num_scales)
        ])

        # Full-resolution branch added (scaled by edge_scale) to the fused output.
        self.edge_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )

        # Learnable per-scale mixing weights, initialised uniformly to 1/num_scales.
        self.scale_weights = nn.Parameter(torch.ones(num_scales) / num_scales)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise conv weights; set norm-layer weight to 1 and bias to 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Apply multi-scale channel/spatial attention to ``x``.

        Args:
            x: Tensor unpacked as (batch, channels, height, width); the convs
                are built for ``in_channels`` input channels.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, channels, height, width = x.size()

        edge_features = self.edge_conv(x)

        weighted_features = []
        for i, scale in enumerate(self.scales):
            x_s = x if scale == 1 else F.avg_pool2d(x, kernel_size=scale, stride=scale)
            pooled_size = x_s.shape[2:]

            spatial_attn = self.spatial_convs[i](x_s)

            # (B, C, N) with N = H_s * W_s. ``reshape`` rather than ``view`` so
            # non-contiguous inputs (e.g. transposed tensors) are accepted.
            x_flat = x_s.reshape(batch_size, channels, -1)

            # Channel Gram matrix (B, C, C), scaled by sqrt(N), softmax over last dim.
            norm_factor = math.sqrt(x_flat.size(2))
            channel_attn = torch.bmm(x_flat, x_flat.transpose(1, 2)) / norm_factor
            channel_attn = F.softmax(channel_attn, dim=2)

            attended = torch.bmm(channel_attn, x_flat).reshape(
                batch_size, channels, *pooled_size
            )
            attended = attended * spatial_attn

            # Pooled branches are resized back to the input resolution.
            if scale != 1:
                attended = F.interpolate(
                    attended, size=(height, width), mode='bilinear', align_corners=False
                )

            weighted_features.append(attended * self.scale_weights[i])

        output = torch.stack(weighted_features, dim=0).sum(dim=0)

        return output + self.edge_scale * edge_features
|
|
class EntropyOptimizedGating(nn.Module):
    """
    Entropy-Optimized Gating (EOG) module.

    Feature redundancy is adaptively suppressed using a normalized entropy
    function. For every (batch, channel) pair the absolute activations are
    normalized into a distribution over spatial positions. Its Shannon entropy
    is divided by the maximum possible value ``log(H * W)``. Channels whose
    normalized entropy exceeds ``beta`` get a gate of 1, all others 0. The
    output is ``x * gate + residual_weight * x``.

    Args:
        channels: Number of channels (stored; ``forward`` reads the channel
            count from the input instead).
        beta: Initial normalized-entropy threshold.
        epsilon: Small constant guarding the normalization and ``log(0)``.

    Note:
        ``beta`` is an ``nn.Parameter``, but the hard comparison ``>`` is not
        differentiable, so ``beta`` receives no gradient from ``forward``.
        NOTE(review): if a learnable threshold is intended, a soft gate
        (e.g. a sigmoid of ``norm_entropy - beta``) would be required.
    """

    def __init__(self, channels, beta=0.3, epsilon=1e-5):
        super().__init__()
        self.channels = channels
        # Threshold stored as a 1-element Parameter (see class note on gradients).
        self.beta = nn.Parameter(torch.tensor([beta]))
        self.epsilon = epsilon
        # Learnable weight of the always-on residual path, initialised to 0.2.
        self.residual_weight = nn.Parameter(torch.tensor([0.2]))

    def forward(self, x):
        """
        Gate the channels of ``x`` by their normalized spatial entropy.

        Args:
            x: Tensor unpacked as (batch, channels, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, channels, height, width = x.size()

        # Per-(batch, channel) distribution of |x| over spatial positions;
        # computed for all channels at once instead of a Python loop.
        abs_x = torch.abs(x)
        sum_abs = torch.sum(abs_x, dim=(2, 3), keepdim=True) + self.epsilon
        norm_prob = abs_x / sum_abs

        # Shannon entropy per (batch, channel) -> shape (B, C).
        log_prob = torch.log(norm_prob + self.epsilon)
        entropy = -torch.sum(norm_prob * log_prob, dim=(2, 3))

        # Normalize by the maximum entropy log(H*W). With a single spatial
        # position log(1) == 0, which would divide by zero; use 1.0 instead.
        num_positions = height * width
        max_entropy = math.log(num_positions) if num_positions > 1 else 1.0
        norm_entropy = entropy / max_entropy

        # Hard 0/1 gate: keep channels whose normalized entropy exceeds beta.
        gates = (norm_entropy > self.beta).to(x.dtype)
        gated_output = x * gates.view(batch_size, channels, 1, 1)

        return gated_output + self.residual_weight * x
|
|
class EOANetModule(nn.Module):
    """
    Entropy-Optimized Attention Network (EOANet) module.

    Combines Normalized Multi-Scale Attention with Entropy-Optimized Gating:
    ``forward`` runs ``self.msa`` on the input and feeds its output to
    ``self.eog``.

    Args:
        in_channels: Channel count passed to both sub-modules.
        scales: Pooling factors forwarded (as a fresh list) to the multi-scale
            attention. A tuple default avoids a shared mutable default list.
        beta: Initial entropy threshold for the gating stage. This wrapper
            defaults to 0.5, while ``EntropyOptimizedGating`` itself defaults
            to 0.3.
    """

    def __init__(self, in_channels, scales=(1, 2, 4), beta=0.5):
        super().__init__()
        # list(...) copies, so no instance aliases the default or the caller's list.
        self.msa = NormalizedMultiScaleAttention(in_channels, list(scales))
        self.eog = EntropyOptimizedGating(in_channels, beta=beta)

    def forward(self, x):
        """Apply multi-scale attention, then entropy-based channel gating, to ``x``."""
        return self.eog(self.msa(x))