| | from itertools import repeat |
| | import collections.abc |
| |
|
| | from torch import nn as nn |
| | from torchvision.ops.misc import FrozenBatchNorm2d |
| |
|
| |
|
def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (container of str, optional): Full (dot-separated) module names to freeze. Only membership
            (`name in module_match`) is used, so a dict, set or list all work. `None` or empty freezes every
            `BatchNorm2d` / `SyncBatchNorm` layer.
        name (str): Full module name (prefix) of `module`; child names are built as `'<name>.<child>'`.

    Returns:
        torch.nn.Module: A new `FrozenBatchNorm2d` if `module` itself is a matched batch-norm layer, otherwise
        `module` itself (whose matching descendants have been replaced in place).

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    # An empty/None filter means "freeze everything"; otherwise only the listed full names qualify.
    is_match = not module_match or name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        # Copy (not alias) the running statistics, consistent with weight/bias above, so later in-place
        # updates to the source module's buffers cannot change the frozen layer.
        res.running_mean.data = module.running_mean.data.clone().detach()
        res.running_var.data = module.running_var.data.clone().detach()
        res.eps = module.eps
        return res

    # Not a (matched) batch-norm layer: recurse and swap converted children into this module in place.
    for child_name, child in module.named_children():
        full_child_name = '.'.join([name, child_name]) if name else child_name
        new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
        if new_child is not child:
            module.add_module(child_name, new_child)
    return module
| |
|
| |
|
| | |
def _ntuple(n):
    """Build a converter that broadcasts a scalar to an `n`-tuple.

    The returned `parse(x)` gives back `x` unchanged when it is already a non-string iterable (e.g. a tuple or
    list), and otherwise returns `x` repeated `n` times as a tuple. Strings are treated as scalars, so
    `to_2tuple('same') == ('same', 'same')` rather than the string itself.
    """
    def parse(x):
        # `str` is Iterable, but a string is a single value, not a sequence of n values.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)


def to_ntuple(n, x):
    """Broadcast `x` to an `n`-tuple using the same rules as `_ntuple(n)`."""
    return _ntuple(n)(x)
| |
|