| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import re |
| import contextlib |
| import numpy as np |
| import torch |
| import warnings |
| import dnnlib |
|
|
| from guided_diffusion import dist_util, logger |
|
|
| |
| |
| |
|
|
# Cache of previously built constant tensors, keyed by the value's contents
# together with the requested shape, dtype, device and memory format.
_constant_cache = {}


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding `value`.

    The first call for a given combination of arguments builds the tensor;
    later calls with an identical combination return the very same tensor
    object, so callers should treat the result as read-only.

    Args:
        value: Anything accepted by `np.asarray`.
        shape: Optional shape the value is broadcast against.
        dtype: Torch dtype; falls back to `torch.get_default_dtype()`.
        device: Torch device; falls back to the CPU.
        memory_format: Torch memory format; falls back to
            `torch.contiguous_format`.

    Returns:
        The cached `torch.Tensor`.
    """
    array = np.asarray(value)
    target_shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    if memory_format is None:
        memory_format = torch.contiguous_format

    # The raw bytes of the value are part of the key, so equal values share
    # a single cache entry.
    cache_key = (array.shape, array.dtype, array.tobytes(), target_shape,
                 dtype, device, memory_format)
    cached = _constant_cache.get(cache_key, None)
    if cached is not None:
        return cached

    result = torch.as_tensor(array.copy(), dtype=dtype, device=device)
    if target_shape is not None:
        # Only the shape of the placeholder tensor matters here.
        placeholder = torch.empty(target_shape)
        result, _ = torch.broadcast_tensors(result, placeholder)
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result
|
|
|
|
| |
| |
|
|
# Use the native `torch.nan_to_num` when this PyTorch build provides it;
# otherwise install a small stand-in with the same signature.
if hasattr(torch, 'nan_to_num'):
    nan_to_num = torch.nan_to_num
else:

    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
        """Replace NaN with zero and clamp the result to finite bounds.

        Only `nan == 0` is supported. When `posinf` / `neginf` are omitted,
        the largest / smallest finite value of the input dtype is used.
        """
        assert isinstance(input, torch.Tensor)
        upper = torch.finfo(input.dtype).max if posinf is None else posinf
        lower = torch.finfo(input.dtype).min if neginf is None else neginf
        assert nan == 0
        # `nansum` over a freshly added leading axis of size one turns every
        # NaN into zero while leaving all other entries unchanged.
        without_nan = input.unsqueeze(0).nansum(0)
        return torch.clamp(without_nan, min=lower, max=upper, out=out)
|
|
|
|
| |
| |
|
|
# Pick whichever assertion entry point this PyTorch build exposes. The
# fallback attribute is only looked up when `torch._assert` is missing.
if hasattr(torch, '_assert'):
    symbolic_assert = torch._assert
else:
    symbolic_assert = torch.Assert
|
|
| |
| |
| |
|
|
|
|
@contextlib.contextmanager
def suppress_tracer_warnings():
    """Context manager that silences `torch.jit.TracerWarning`.

    An 'ignore' filter is pushed to the front of `warnings.filters` on entry
    and removed again on exit. Removal happens in a `finally` clause so the
    filter does not leak when the managed block raises.
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        try:
            warnings.filters.remove(flt)
        except ValueError:
            # The filter list was reset or replaced inside the managed block,
            # so there is nothing left to clean up. Swallowing this keeps the
            # original exception (if any) from being masked.
            pass
|
|
|
|
| |
| |
| |
| |
|
|
|
|
def assert_shape(tensor, ref_shape):
    """Check that `tensor` has the shape described by `ref_shape`.

    Args:
        tensor: Object exposing `ndim` and `shape`.
        ref_shape: Sequence with one entry per dimension. `None` accepts any
            size for that dimension. When either the actual or the reference
            size is a `torch.Tensor`, the comparison is delegated to
            `symbolic_assert` with tracer warnings suppressed.

    Raises:
        AssertionError: On a rank mismatch or a plain (non-tensor) size
            mismatch.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(
            f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}'
        )

    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        # Wildcard dimension: nothing to verify.
        if ref_size is None:
            continue

        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings():
                matches = torch.equal(torch.as_tensor(size), ref_size)
                symbolic_assert(matches, f'Wrong size for dimension {idx}')
            continue

        if isinstance(size, torch.Tensor):
            with suppress_tracer_warnings():
                matches = torch.equal(size, torch.as_tensor(ref_size))
                symbolic_assert(
                    matches,
                    f'Wrong size for dimension {idx}: expected {ref_size}')
            continue

        if size != ref_size:
            raise AssertionError(
                f'Wrong size for dimension {idx}: got {size}, expected {ref_size}'
            )
|
|
|
|
| |
| |
|
|
|
|
def profiled_function(fn):
    """Decorator that runs `fn` inside a `record_function` profiler scope.

    The scope is labelled with `fn.__name__`. The returned wrapper carries
    over the metadata of `fn` (name, docstring, qualified name, module) and
    exposes the original callable as `__wrapped__`.

    Args:
        fn: Callable to wrap.

    Returns:
        A callable with the same call signature and return value as `fn`.
    """
    # Imported locally so this block does not depend on the file's import
    # section being changed.
    import functools

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)

    return decorator
|
|
|
|
| |
| |
| |
|
|
|
|
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that yields dataset indices forever.

    Indices are visited cyclically. Position `k` of the endless stream is
    handed out only when `k % num_replicas == rank`. With `shuffle` enabled,
    the order starts as a seeded permutation and, after each step, the
    current slot is swapped with a randomly chosen slot at most `window - 1`
    positions behind it, where `window` is `window_size` times the dataset
    length, rounded.
    """

    def __init__(self,
                 dataset,
                 rank=0,
                 num_replicas=1,
                 shuffle=True,
                 seed=0,
                 window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))
        else:
            rnd, window = None, 0
        # Swapping is pointless (and skipped) for windows smaller than two.
        reshuffle = window >= 2

        step = 0
        while True:
            pos = step % order.size
            if step % self.num_replicas == self.rank:
                yield order[pos]
            if reshuffle:
                # The swap is performed on every step, regardless of which
                # rank the step belongs to.
                other = (pos - rnd.randint(window)) % order.size
                order[pos], order[other] = order[other], order[pos]
            step += 1
|
|
|
|
| |
| |
|
|
|
|
def params_and_buffers(module):
    """Return all parameters of `module` followed by all of its buffers."""
    assert isinstance(module, torch.nn.Module)
    tensors = list(module.parameters())
    tensors.extend(module.buffers())
    return tensors
|
|
|
|
def named_params_and_buffers(module):
    """Return `(name, tensor)` pairs: parameters first, then buffers."""
    assert isinstance(module, torch.nn.Module)
    pairs = list(module.named_parameters())
    pairs.extend(module.named_buffers())
    return pairs
|
|
|
|
def copy_params_and_buffers(src_module, dst_module, require_all=False, load_except=(), model_name=''):
    """Copy same-named parameters and buffers from one module into another.

    Args:
        src_module: Module the values are read from.
        dst_module: Module whose tensors are overwritten in place.
        require_all: When true, every tensor of `dst_module` must also exist
            in `src_module`; a missing name trips an assertion.
        load_except: Container of tensor names that are logged and left
            untouched.
        model_name: Accepted for interface compatibility; not used here.

    Copying is best-effort: a tensor whose copy fails with a `RuntimeError`
    (for example because the shapes differ) is reported on stdout and
    skipped, and the remaining tensors are still processed.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all), name
        if name not in src_tensors:
            continue
        if name in load_except:
            logger.log('ignore load_except module: ', name)
            continue
        try:
            # In-place writes to a leaf tensor that requires grad are
            # rejected by autograd unless gradient tracking is disabled, so
            # without `no_grad` trainable parameters would never be copied.
            with torch.no_grad():
                tensor.copy_(src_tensors[name].detach()).requires_grad_(
                    tensor.requires_grad)
        except RuntimeError as err:
            # Keep going, but say which tensor was skipped and why.
            print(f'copy_params_and_buffers: could not copy {name}: {err}')
|
|
| |
| |
| |
|
|
|
|
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that optionally enters `module.no_sync()`.

    `no_sync()` is entered only when `sync` is false and `module` is a
    `DistributedDataParallel` instance; in every other case the managed
    block simply runs as is.
    """
    assert isinstance(module, torch.nn.Module)
    is_ddp = isinstance(module, torch.nn.parallel.DistributedDataParallel)
    if is_ddp and not sync:
        with module.no_sync():
            yield
    else:
        yield
|
|
|
|
| |
| |
|
|
|
|
def check_ddp_consistency(module, ignore_regex=None):
    """Assert that this process holds the same tensor values as rank 0.

    Every parameter and buffer of `module` is compared against the copy
    broadcast from source rank 0. Tensors whose full name
    (`<ClassName>.<tensor name>`) fully matches `ignore_regex` are skipped.
    Floating-point tensors are passed through `nan_to_num` before the
    comparison.

    Raises:
        AssertionError: With the offending full name when a tensor differs.
    """
    assert isinstance(module, torch.nn.Module)
    prefix = type(module).__name__ + '.'
    for name, tensor in named_params_and_buffers(module):
        fullname = prefix + name
        ignored = ignore_regex is not None and re.fullmatch(ignore_regex, fullname)
        if ignored:
            continue
        local = tensor.detach()
        if local.is_floating_point():
            local = nan_to_num(local)
        # Broadcasting into a clone overwrites it with the source rank's data.
        reference = local.clone()
        torch.distributed.broadcast(tensor=reference, src=0)
        assert (local == reference).all(), fullname
|
|
|
|
| |
| |
|
|
|
|
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module` once and print a per-submodule summary table.

    Forward hooks are attached to every submodule, the module is called with
    `inputs`, and a table listing parameter counts, buffer counts, output
    shapes and output dtypes is printed to stdout.

    Args:
        module: A `torch.nn.Module` that is not a `ScriptModule`.
        inputs: Tuple or list of positional arguments for the forward call.
        max_nesting: Submodules nested deeper than this are not listed.
        skip_redundant: Drop rows that contribute no previously unseen
            parameter, buffer or output tensor.

    Returns:
        Whatever `module(*inputs)` returned.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]  # One-element list so the closures below can mutate it.

    def pre_hook(_mod, _inputs):
        nesting[0] += 1

    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))

    # Run module. The hooks are removed in `finally` so they do not stay
    # attached to the module when registration or the forward pass raises.
    hooks = []
    try:
        for mod in module.modules():
            hooks.append(mod.register_forward_pre_hook(pre_hook))
        for mod in module.modules():
            hooks.append(mod.register_forward_hook(post_hook))
        outputs = module(*inputs)
    finally:
        for hook in hooks:
            hook.remove()

    # Identify unique outputs, parameters, and buffers, so that a tensor
    # shared between several entries is attributed to the first one only.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [
            t for t in e.mod.parameters() if id(t) not in tensors_seen
        ]
        e.unique_buffers = [
            t for t in e.mod.buffers() if id(t) not in tensors_seen
        ]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {
            id(t)
            for t in e.unique_params + e.unique_buffers + e.unique_outputs
        }

    # Filter out redundant entries.
    if skip_redundant:
        entries = [
            e for e in entries
            if e.unique_params or e.unique_buffers or e.unique_outputs
        ]

    # Construct table.
    rows = [[
        type(module).__name__, 'Parameters', 'Buffers', 'Output shape',
        'Datatype'
    ]]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        # First row of the entry: counts plus the first output (or '-').
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        # One extra row per additional output.
        for idx in range(1, len(e.outputs)):
            rows += [[
                name + f':{idx}', '-', '-', output_shapes[idx],
                output_dtypes[idx]
            ]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table with each column padded to its widest cell.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell))
                       for cell, width in zip(row, widths)))
    print()
    return outputs
|
|
|
|
| |
|
|