| | """Modified from https://github.com/kijai/ComfyUI-MochiWrapper |
| | """ |
| | import importlib.util |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
def replace_parameters_by_name(module, name_keywords, device):
    """Turn matching ``nn.Parameter`` attributes into plain tensors on ``device``.

    Visits ``module`` and every submodule. For each parameter registered
    directly on a (sub)module whose attribute name contains any of
    ``name_keywords`` as a substring, the parameter is unregistered
    (``delattr``) and the same attribute name is re-bound to its underlying
    ``.data`` tensor moved with ``Tensor.to(device=device)``. The re-bound
    tensor is an ordinary attribute, not a registered parameter.

    Args:
        module: root ``nn.Module``; modified in place.
        name_keywords: iterable of substrings matched against each
            parameter's local attribute name (for example ``"weight"``).
        device: device passed to ``Tensor.to``.
    """
    for submodule in module.modules():
        # Snapshot first: the loop body unregisters entries from this module.
        local_params = list(submodule.named_parameters(recurse=False))
        for attr_name, parameter in local_params:
            wanted = any(keyword in attr_name for keyword in name_keywords)
            if not wanted or not isinstance(parameter, nn.Parameter):
                continue
            plain_tensor = parameter.data
            delattr(submodule, attr_name)
            setattr(submodule, attr_name, plain_tensor.to(device=device))
| |
|
| | def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens'], device=None): |
| | for name, module in model.named_modules(): |
| | flag = False |
| | for _exclude_module_name in exclude_module_name: |
| | if _exclude_module_name in name: |
| | flag = True |
| | if flag: |
| | continue |
| | for param_name, param in module.named_parameters(): |
| | flag = False |
| | for _exclude_module_name in exclude_module_name: |
| | if _exclude_module_name in param_name: |
| | flag = True |
| | if flag: |
| | continue |
| | param.data = param.data.to(torch.float8_e4m3fn) |
| |
|
def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
    """Run ``cls.original_forward`` with the module temporarily cast to ``origin_dtype``.

    The dtype of ``cls.weight`` is recorded, and the whole module is cast to
    ``origin_dtype`` with ``Module.to``. Tensor positional inputs are cast to
    ``origin_dtype`` before calling the saved ``original_forward``. Afterwards
    the module is cast back to the recorded dtype. This also happens when the
    forward raises, so a failure cannot leave the (e.g. float8) weights
    permanently upcast.

    Args:
        cls: the module being executed (an ``nn.Module`` instance despite the
            name); must expose ``weight`` and ``original_forward``.
        origin_dtype: dtype the computation runs in.
        *inputs: positional arguments. Tensors are cast to ``origin_dtype``;
            other values are passed through unchanged.
        **kwargs: forwarded to ``original_forward`` unchanged.

    Returns:
        Whatever ``cls.original_forward`` returns.
    """
    weight_dtype = cls.weight.dtype
    cls.to(origin_dtype)
    try:
        cast_inputs = [
            arg.to(origin_dtype) if isinstance(arg, torch.Tensor) else arg
            for arg in inputs
        ]
        return cls.original_forward(*cast_inputs, **kwargs)
    finally:
        # Restore the storage dtype even if casting or the forward pass failed.
        cls.to(weight_dtype)
| |
|
def convert_weight_dtype_wrapper(module, origin_dtype, exclude_module_name=None):
    """Patch submodules so their forward pass runs via ``autocast_model_forward``.

    Every named submodule except the root qualifies if it has a non-None
    ``weight`` attribute and its qualified name contains no exclusion
    pattern. For each one, the current ``forward`` is saved as
    ``original_forward`` and replaced by a wrapper that calls
    ``autocast_model_forward(submodule, origin_dtype, ...)``.

    Calling this again on an already-patched model only installs a new
    wrapper (e.g. with a different ``origin_dtype``). The first saved
    ``original_forward`` is kept, because re-saving would capture the
    previous wrapper, which would then call itself recursively.

    Args:
        module: root ``nn.Module``; patched in place.
        origin_dtype: dtype handed to ``autocast_model_forward``.
        exclude_module_name: substring pattern(s) matched against each
            submodule's qualified name; matches are left unpatched.
            Defaults to ``["embed_tokens"]``. A single string is treated as
            one pattern.
    """
    if exclude_module_name is None:
        exclude_module_name = ["embed_tokens"]
    elif isinstance(exclude_module_name, str):
        exclude_module_name = [exclude_module_name]
    patterns = tuple(exclude_module_name)

    for name, submodule in module.named_modules():
        if name == "" or any(pattern in name for pattern in patterns):
            continue
        if getattr(submodule, "weight", None) is None:
            continue
        # Only save the genuine forward once; a second save would store our wrapper.
        if "original_forward" not in vars(submodule):
            submodule.original_forward = submodule.forward
        # m=submodule binds the current module now (avoids late-binding closures).
        submodule.forward = (
            lambda *inputs, m=submodule, **kwargs: autocast_model_forward(
                m, origin_dtype, *inputs, **kwargs
            )
        )