| | from loguru import logger |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import init |
| | import math |
| | from torch.compiler import is_compiling |
| | from torch import __version__ |
| | from torch.version import cuda |
| |
|
| | from modules.flux_model import Modulation |
| |
|
# True for torch <= 2.4.x.  F8Linear.forward uses this to unwrap the
# (out, amax) tuple that torch._scaled_mm returned in those releases.
# NOTE(review): torch.__version__ is a TorchVersion object, which supports
# comparison against tuples — confirm minimum torch version in CI.
IS_TORCH_2_4 = __version__ < (2, 4, 9)
LT_TORCH_2_4 = __version__ < (2, 4)
if LT_TORCH_2_4:
    # torch._scaled_mm is required for the fp8 matmul path; an older build
    # that happens to ship it is tolerated.
    if not hasattr(torch, "_scaled_mm"):
        raise RuntimeError(
            "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later."
        )
# torch.version.cuda is a string like "12.4", or None on CPU-only builds.
# NOTE(review): float() comparison of version strings is fragile for
# two-digit minors (e.g. "12.10" < 12.4) — verify against supported range.
CUDA_VERSION = float(cuda) if cuda else 0
if CUDA_VERSION < 12.4:
    raise RuntimeError(
        f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later got torch version {__version__} and CUDA version {cuda}."
    )
try:
    from cublas_ops import CublasLinear
except ImportError:
    # Sentinel when cublas_ops is unavailable: isinstance(x, type(None))
    # matches no module, so the swap helpers become no-ops.
    CublasLinear = type(None)
| |
|
| |
|
| | class F8Linear(nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | in_features: int, |
| | out_features: int, |
| | bias: bool = True, |
| | device=None, |
| | dtype=torch.float16, |
| | float8_dtype=torch.float8_e4m3fn, |
| | float_weight: torch.Tensor = None, |
| | float_bias: torch.Tensor = None, |
| | num_scale_trials: int = 12, |
| | input_float8_dtype=torch.float8_e5m2, |
| | ) -> None: |
| | super().__init__() |
| | self.in_features = in_features |
| | self.out_features = out_features |
| | self.float8_dtype = float8_dtype |
| | self.input_float8_dtype = input_float8_dtype |
| | self.input_scale_initialized = False |
| | self.weight_initialized = False |
| | self.max_value = torch.finfo(self.float8_dtype).max |
| | self.input_max_value = torch.finfo(self.input_float8_dtype).max |
| | factory_kwargs = {"dtype": dtype, "device": device} |
| | if float_weight is None: |
| | self.weight = nn.Parameter( |
| | torch.empty((out_features, in_features), **factory_kwargs) |
| | ) |
| | else: |
| | self.weight = nn.Parameter( |
| | float_weight, requires_grad=float_weight.requires_grad |
| | ) |
| | if float_bias is None: |
| | if bias: |
| | self.bias = nn.Parameter( |
| | torch.empty(out_features, **factory_kwargs), |
| | ) |
| | else: |
| | self.register_parameter("bias", None) |
| | else: |
| | self.bias = nn.Parameter(float_bias, requires_grad=float_bias.requires_grad) |
| | self.num_scale_trials = num_scale_trials |
| | self.input_amax_trials = torch.zeros( |
| | num_scale_trials, requires_grad=False, device=device, dtype=torch.float32 |
| | ) |
| | self.trial_index = 0 |
| | self.register_buffer("scale", None) |
| | self.register_buffer( |
| | "input_scale", |
| | None, |
| | ) |
| | self.register_buffer( |
| | "float8_data", |
| | None, |
| | ) |
| | self.scale_reciprocal = self.register_buffer("scale_reciprocal", None) |
| | self.input_scale_reciprocal = self.register_buffer( |
| | "input_scale_reciprocal", None |
| | ) |
| |
|
| | def _load_from_state_dict( |
| | self, |
| | state_dict, |
| | prefix, |
| | local_metadata, |
| | strict, |
| | missing_keys, |
| | unexpected_keys, |
| | error_msgs, |
| | ): |
| | sd = {k.replace(prefix, ""): v for k, v in state_dict.items()} |
| | if "weight" in sd: |
| | if ( |
| | "float8_data" not in sd |
| | or sd["float8_data"] is None |
| | and sd["weight"].shape == (self.out_features, self.in_features) |
| | ): |
| | |
| | self._parameters["weight"] = nn.Parameter( |
| | sd["weight"], requires_grad=False |
| | ) |
| | if "bias" in sd: |
| | self._parameters["bias"] = nn.Parameter( |
| | sd["bias"], requires_grad=False |
| | ) |
| | self.quantize_weight() |
| | elif sd["float8_data"].shape == ( |
| | self.out_features, |
| | self.in_features, |
| | ) and sd["weight"] == torch.zeros_like(sd["weight"]): |
| | w = sd["weight"] |
| | |
| | self._buffers["float8_data"] = sd["float8_data"] |
| | self._parameters["weight"] = nn.Parameter( |
| | torch.zeros( |
| | 1, |
| | dtype=w.dtype, |
| | device=w.device, |
| | requires_grad=False, |
| | ) |
| | ) |
| | if "bias" in sd: |
| | self._parameters["bias"] = nn.Parameter( |
| | sd["bias"], requires_grad=False |
| | ) |
| | self.weight_initialized = True |
| |
|
| | |
| | if all( |
| | key in sd |
| | for key in [ |
| | "scale", |
| | "input_scale", |
| | "scale_reciprocal", |
| | "input_scale_reciprocal", |
| | ] |
| | ): |
| | self.scale = sd["scale"].float() |
| | self.input_scale = sd["input_scale"].float() |
| | self.scale_reciprocal = sd["scale_reciprocal"].float() |
| | self.input_scale_reciprocal = sd["input_scale_reciprocal"].float() |
| | self.input_scale_initialized = True |
| | self.trial_index = self.num_scale_trials |
| | elif "scale" in sd and "scale_reciprocal" in sd: |
| | self.scale = sd["scale"].float() |
| | self.input_scale = ( |
| | sd["input_scale"].float() if "input_scale" in sd else None |
| | ) |
| | self.scale_reciprocal = sd["scale_reciprocal"].float() |
| | self.input_scale_reciprocal = ( |
| | sd["input_scale_reciprocal"].float() |
| | if "input_scale_reciprocal" in sd |
| | else None |
| | ) |
| | self.input_scale_initialized = ( |
| | True if "input_scale" in sd else False |
| | ) |
| | self.trial_index = ( |
| | self.num_scale_trials if "input_scale" in sd else 0 |
| | ) |
| | self.input_amax_trials = torch.zeros( |
| | self.num_scale_trials, |
| | requires_grad=False, |
| | dtype=torch.float32, |
| | device=self.weight.device, |
| | ) |
| | self.input_scale_initialized = False |
| | self.trial_index = 0 |
| | else: |
| | |
| | self.input_scale_initialized = False |
| | self.trial_index = 0 |
| | self.input_amax_trials = torch.zeros( |
| | self.num_scale_trials, requires_grad=False, dtype=torch.float32 |
| | ) |
| | else: |
| | raise RuntimeError( |
| | f"Weight tensor not found or has incorrect shape in state dict: {sd.keys()}" |
| | ) |
| | else: |
| | raise RuntimeError( |
| | "Weight tensor not found or has incorrect shape in state dict" |
| | ) |
| |
|
| | def quantize_weight(self): |
| | if self.weight_initialized: |
| | return |
| | amax = torch.max(torch.abs(self.weight.data)).float() |
| | self.scale = self.amax_to_scale(amax, self.max_value) |
| | self.float8_data = self.to_fp8_saturated( |
| | self.weight.data, self.scale, self.max_value |
| | ).to(self.float8_dtype) |
| | self.scale_reciprocal = self.scale.reciprocal() |
| | self.weight.data = torch.zeros( |
| | 1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False |
| | ) |
| | self.weight_initialized = True |
| |
|
| | def set_weight_tensor(self, tensor: torch.Tensor): |
| | self.weight.data = tensor |
| | self.weight_initialized = False |
| | self.quantize_weight() |
| |
|
| | def amax_to_scale(self, amax, max_val): |
| | return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val) |
| |
|
| | def to_fp8_saturated(self, x, scale, max_val): |
| | return (x * scale).clamp(-max_val, max_val) |
| |
|
| | def quantize_input(self, x: torch.Tensor): |
| | if self.input_scale_initialized: |
| | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( |
| | self.input_float8_dtype |
| | ) |
| | elif self.trial_index < self.num_scale_trials: |
| |
|
| | amax = torch.max(torch.abs(x)).float() |
| |
|
| | self.input_amax_trials[self.trial_index] = amax |
| | self.trial_index += 1 |
| | self.input_scale = self.amax_to_scale( |
| | self.input_amax_trials[: self.trial_index].max(), self.input_max_value |
| | ) |
| | self.input_scale_reciprocal = self.input_scale.reciprocal() |
| | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( |
| | self.input_float8_dtype |
| | ) |
| | else: |
| | self.input_scale = self.amax_to_scale( |
| | self.input_amax_trials.max(), self.input_max_value |
| | ) |
| | self.input_scale_reciprocal = self.input_scale.reciprocal() |
| | self.input_scale_initialized = True |
| | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( |
| | self.input_float8_dtype |
| | ) |
| |
|
| | def reset_parameters(self) -> None: |
| | if self.weight_initialized: |
| | self.weight = nn.Parameter( |
| | torch.empty( |
| | (self.out_features, self.in_features), |
| | **{ |
| | "dtype": self.weight.dtype, |
| | "device": self.weight.device, |
| | }, |
| | ) |
| | ) |
| | self.weight_initialized = False |
| | self.input_scale_initialized = False |
| | self.trial_index = 0 |
| | self.input_amax_trials.zero_() |
| | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) |
| | if self.bias is not None: |
| | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) |
| | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
| | init.uniform_(self.bias, -bound, bound) |
| | self.quantize_weight() |
| | self.max_value = torch.finfo(self.float8_dtype).max |
| | self.input_max_value = torch.finfo(self.input_float8_dtype).max |
| |
|
| | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| | if self.input_scale_initialized or is_compiling(): |
| | x = self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( |
| | self.input_float8_dtype |
| | ) |
| | else: |
| | x = self.quantize_input(x) |
| |
|
| | prev_dims = x.shape[:-1] |
| | x = x.view(-1, self.in_features) |
| |
|
| | |
| | out = torch._scaled_mm( |
| | x, |
| | self.float8_data.T, |
| | scale_a=self.input_scale_reciprocal, |
| | scale_b=self.scale_reciprocal, |
| | bias=self.bias, |
| | out_dtype=self.weight.dtype, |
| | use_fast_accum=True, |
| | ) |
| | if IS_TORCH_2_4: |
| | out = out[0] |
| | out = out.view(*prev_dims, self.out_features) |
| | return out |
| |
|
| | @classmethod |
| | def from_linear( |
| | cls, |
| | linear: nn.Linear, |
| | float8_dtype=torch.float8_e4m3fn, |
| | input_float8_dtype=torch.float8_e5m2, |
| | ) -> "F8Linear": |
| | f8_lin = cls( |
| | in_features=linear.in_features, |
| | out_features=linear.out_features, |
| | bias=linear.bias is not None, |
| | device=linear.weight.device, |
| | dtype=linear.weight.dtype, |
| | float8_dtype=float8_dtype, |
| | float_weight=linear.weight.data, |
| | float_bias=(linear.bias.data if linear.bias is not None else None), |
| | input_float8_dtype=input_float8_dtype, |
| | ) |
| | f8_lin.quantize_weight() |
| | return f8_lin |
| |
|
| |
|
@torch.inference_mode()
def recursive_swap_linears(
    model: nn.Module,
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    quantize_modulation: bool = True,
    ignore_keys: "list[str] | None" = None,
) -> None:
    """
    Recursively swaps all nn.Linear modules in the given model with F8Linear modules.

    This function traverses the model's structure and replaces each nn.Linear
    instance with an F8Linear instance, which uses 8-bit floating point
    quantization for weights. The original linear layer's weights are deleted
    after conversion to save memory.

    Args:
        model (nn.Module): The PyTorch model to modify.
        float8_dtype: fp8 dtype used for the quantized weights.
        input_float8_dtype: fp8 dtype used for the quantized activations.
        quantize_modulation: when False, Modulation submodules are skipped.
        ignore_keys: child names (checked at every recursion level) to skip.

    Note:
        This function modifies the model in-place. After calling this function,
        all linear layers in the model will be using 8-bit quantization.
    """
    # Fix: the original used a mutable default argument (`ignore_keys=[]`);
    # a None sentinel keeps the signature backward compatible without the
    # shared-default pitfall.
    if ignore_keys is None:
        ignore_keys = []
    for name, child in model.named_children():
        if name in ignore_keys:
            continue
        if isinstance(child, Modulation) and not quantize_modulation:
            continue
        if isinstance(child, nn.Linear) and not isinstance(
            child, (F8Linear, CublasLinear)
        ):
            # Replace the plain linear with a quantized equivalent.
            setattr(
                model,
                name,
                F8Linear.from_linear(
                    child,
                    float8_dtype=float8_dtype,
                    input_float8_dtype=input_float8_dtype,
                ),
            )
            del child
        else:
            recursive_swap_linears(
                child,
                float8_dtype=float8_dtype,
                input_float8_dtype=input_float8_dtype,
                quantize_modulation=quantize_modulation,
                ignore_keys=ignore_keys,
            )
| |
|
| |
|
@torch.inference_mode()
def swap_to_cublaslinear(model: nn.Module):
    """Recursively replace plain nn.Linear layers with CublasLinear.

    No-op when the optional ``cublas_ops`` package is not installed (in that
    case ``CublasLinear`` is the ``type(None)`` sentinel). Already-swapped
    F8Linear/CublasLinear layers are left untouched.
    """
    if CublasLinear == type(None):
        return
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and not isinstance(
            child, (F8Linear, CublasLinear)
        ):
            cublas_lin = CublasLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            cublas_lin.weight.data = child.weight.clone().detach()
            # Fix: the original copied the bias unconditionally and raised
            # AttributeError for bias-free linears (bias is None on both sides).
            if child.bias is not None:
                cublas_lin.bias.data = child.bias.clone().detach()
            setattr(model, name, cublas_lin)
            del child
        else:
            swap_to_cublaslinear(child)
| |
|
| |
|
@torch.inference_mode()
def quantize_flow_transformer_and_dispatch_float8(
    flow_model: nn.Module,
    device=torch.device("cuda"),
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    offload_flow=False,
    swap_linears_with_cublaslinear=True,
    flow_dtype=torch.float16,
    quantize_modulation: bool = True,
    quantize_flow_embedder_layers: bool = True,
) -> nn.Module:
    """
    Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.

    Iteratively pushes each module to device, evals, replaces linear layers with F8Linear except for final_layer, and quantizes.

    Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory.

    After dispatching, if offload_flow is True, offloads the model to cpu.

    if swap_linears_with_cublaslinear is true, and flow_dtype == torch.float16, then swap all linears with cublaslinears for 2x performance boost on consumer GPUs.
    Otherwise will skip the cublaslinear swap.

    For added extra precision, you can set quantize_flow_embedder_layers to False,
    this helps maintain the output quality of the flow transformer moreso than fully quantizing,
    at the expense of ~512MB more VRAM usage.

    For added extra precision, you can set quantize_modulation to False,
    this helps maintain the output quality of the flow transformer moreso than fully quantizing,
    at the expense of ~2GB more VRAM usage, but- has a much higher impact on image quality than the embedder layers.
    """
    # Move + quantize one block at a time (rather than model.to(device) up
    # front) so peak VRAM stays low: each block's float weights are freed as
    # soon as F8Linear quantizes them.
    for module in flow_model.double_blocks:
        module.to(device)
        module.eval()
        recursive_swap_linears(
            module,
            float8_dtype=float8_dtype,
            input_float8_dtype=input_float8_dtype,
            quantize_modulation=quantize_modulation,
        )
        torch.cuda.empty_cache()
    for module in flow_model.single_blocks:
        module.to(device)
        module.eval()
        recursive_swap_linears(
            module,
            float8_dtype=float8_dtype,
            input_float8_dtype=input_float8_dtype,
            quantize_modulation=quantize_modulation,
        )
        torch.cuda.empty_cache()
    # Remaining top-level submodules outside the block stacks.
    to_gpu_extras = [
        "vector_in",
        "img_in",
        "txt_in",
        "time_in",
        "guidance_in",
        "final_layer",
        "pe_embedder",
    ]
    for module in to_gpu_extras:
        # NOTE(review): raises AttributeError if an attribute is absent
        # entirely; assumes all listed names exist (possibly as None).
        m_extra = getattr(flow_model, module)
        if m_extra is None:
            continue
        m_extra.to(device)
        m_extra.eval()
        if isinstance(m_extra, nn.Linear) and not isinstance(
            m_extra, (F8Linear, CublasLinear)
        ):
            if quantize_flow_embedder_layers:
                # The extra is itself a bare Linear: quantize it directly.
                setattr(
                    flow_model,
                    module,
                    F8Linear.from_linear(
                        m_extra,
                        float8_dtype=float8_dtype,
                        input_float8_dtype=input_float8_dtype,
                    ),
                )
                del m_extra
        elif module != "final_layer":
            # Composite embedder modules: quantize their inner linears;
            # final_layer is always kept in full precision.
            if quantize_flow_embedder_layers:
                recursive_swap_linears(
                    m_extra,
                    float8_dtype=float8_dtype,
                    input_float8_dtype=input_float8_dtype,
                    quantize_modulation=quantize_modulation,
                )
        torch.cuda.empty_cache()
    if (
        swap_linears_with_cublaslinear
        and flow_dtype == torch.float16
        and CublasLinear != type(None)
    ):
        # Remaining (un-quantized) fp16 linears get the cublas fast path.
        swap_to_cublaslinear(flow_model)
    elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
        logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
    if offload_flow:
        flow_model.to("cpu")
        torch.cuda.empty_cache()
    return flow_model
| |
|