| |
| |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import argparse |
| import json |
| import safetensors.torch |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Any, ContextManager, cast |
| from torch import Tensor |
|
|
| import numpy as np |
| import torch |
| import gguf |
|
|
| |
# Architecture identifiers accepted by the ``--arch`` command-line option.
SUPPORTED_ARCHS = [
    "flux",
    "sd3",
    "ltxv",
    "hyvid",
    "wan",
    "hidream",
    "qwen",
]

# Module-level logger used by the conversion code in this file.
logger = logging.getLogger(__name__)
|
|
|
|
class QuantConfig:
    """Bundle the two GGUF enums that describe one output quantization scheme.

    The values are stored as given; no validation is performed.
    """

    # File-level type tag (handed to the GGUF writer's ``add_file_type``).
    ftype: gguf.LlamaFileType
    # Tensor-level quantization type associated with this scheme.
    qtype: gguf.GGMLQuantizationType

    def __init__(self, ftype: gguf.LlamaFileType, qtype: gguf.GGMLQuantizationType):
        """Store ``ftype`` and ``qtype`` on the instance."""
        self.ftype, self.qtype = ftype, qtype
|
|
|
|
# Maps each ``--outtype`` name to its (file type, tensor quantization type) pair.
# Insertion order is kept because it is the order argparse lists the choices in.
qconfig_map: dict[str, QuantConfig] = {
    label: QuantConfig(file_type, quant_type)
    for label, file_type, quant_type in (
        ("F16", gguf.LlamaFileType.MOSTLY_F16, gguf.GGMLQuantizationType.F16),
        ("BF16", gguf.LlamaFileType.MOSTLY_BF16, gguf.GGMLQuantizationType.BF16),
        ("Q8_0", gguf.LlamaFileType.MOSTLY_Q8_0, gguf.GGMLQuantizationType.Q8_0),
        ("Q6_K", gguf.LlamaFileType.MOSTLY_Q6_K, gguf.GGMLQuantizationType.Q6_K),
        ("Q5_K_S", gguf.LlamaFileType.MOSTLY_Q5_K_S, gguf.GGMLQuantizationType.Q5_K),
        ("Q5_1", gguf.LlamaFileType.MOSTLY_Q5_1, gguf.GGMLQuantizationType.Q5_1),
        ("Q5_0", gguf.LlamaFileType.MOSTLY_Q5_0, gguf.GGMLQuantizationType.Q5_0),
        ("Q4_K_S", gguf.LlamaFileType.MOSTLY_Q4_K_S, gguf.GGMLQuantizationType.Q4_K),
        ("Q4_1", gguf.LlamaFileType.MOSTLY_Q4_1, gguf.GGMLQuantizationType.Q4_1),
        ("Q4_0", gguf.LlamaFileType.MOSTLY_Q4_0, gguf.GGMLQuantizationType.Q4_0),
        ("Q3_K_S", gguf.LlamaFileType.MOSTLY_Q3_K_S, gguf.GGMLQuantizationType.Q3_K),
    )
}
|
|
|
|
| |
class LazyTorchTensor(gguf.LazyBase):
    """Torch-specific subclass of ``gguf.LazyBase`` for checkpoint tensors.

    NOTE(review): the deferral mechanics themselves live in ``gguf.LazyBase``
    (not shown in this file); this class only supplies the torch-side pieces:
    dtype tables, meta-tensor construction and the ``__torch_function__`` hook.
    """

    _tensor_type = torch.Tensor

    # Bare annotations for type checkers; no values are assigned here.
    dtype: torch.dtype
    shape: torch.Size

    # Torch dtypes that ``numpy()`` can translate; any other dtype raises KeyError there.
    _dtype_map: dict[torch.dtype, type] = {
        torch.float16: np.float16,
        torch.float32: np.float32,
    }

    # dtype strings reported by a safetensors slice (``get_dtype()``) -> torch dtype.
    _dtype_str_map: dict[str, torch.dtype] = {
        "F64": torch.float64,
        "F32": torch.float32,
        "BF16": torch.bfloat16,
        "F16": torch.float16,
        "I64": torch.int64,
        "I32": torch.int32,
        "I16": torch.int16,
        "U8": torch.uint8,
        "I8": torch.int8,
        "BOOL": torch.bool,
        "F8_E4M3": torch.float8_e4m3fn,
        "F8_E5M2": torch.float8_e5m2,
    }

    def numpy(self) -> gguf.LazyNumpyTensor:
        """Return a ``gguf.LazyNumpyTensor`` that wraps this tensor.

        Raises:
            KeyError: If ``self.dtype`` has no entry in ``_dtype_map``.
        """
        np_dtype = self._dtype_map[self.dtype]
        np_meta = gguf.LazyNumpyTensor.meta_with_dtype_and_shape(np_dtype, self.shape)
        return gguf.LazyNumpyTensor(
            meta=np_meta,
            args=(self,),
            func=lambda tensor: tensor.numpy(),
        )

    @classmethod
    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
        """Create an empty tensor on the ``meta`` device with the given dtype and shape."""
        return torch.empty(size=shape, dtype=dtype, device="meta")

    @classmethod
    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
        """Wrap a safetensors slice in a lazy tensor.

        The slice's dtype string and shape are used to build the meta tensor;
        the stored function reads the slice with ``[:]`` when invoked.

        Raises:
            KeyError: If the slice reports a dtype string missing from ``_dtype_str_map``.
        """
        torch_dtype = cls._dtype_str_map[st_slice.get_dtype()]
        dims: tuple[int, ...] = tuple(st_slice.get_shape())
        placeholder = cls.meta_with_dtype_and_shape(torch_dtype, dims)
        lazy_tensor = cls(meta=placeholder, args=(st_slice,), func=lambda piece: piece[:])
        # The cast only informs type checkers; the object is still an instance of ``cls``.
        return cast(torch.Tensor, lazy_tensor)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch torch functions called with instances of this class.

        ``Tensor.numpy`` is routed to the ``numpy`` method of the first
        argument; every other function goes through ``cls._wrap_fn``.
        """
        del types  # part of the hook's signature but not used here

        call_kwargs = {} if kwargs is None else kwargs

        if func is torch.Tensor.numpy:
            return args[0].numpy()

        wrapped = cls._wrap_fn(func)
        return wrapped(*args, **call_kwargs)
|
|
|
|
class Converter:
    """Build a GGUF file from a single safetensors checkpoint.

    The constructor records the model metadata and passes every tensor of the
    checkpoint through :meth:`process_tensor`; :meth:`write` then emits the
    file at ``outfile``.
    """

    path_safetensors: Path
    endianess: gguf.GGUFEndian
    outtype: QuantConfig
    outfile: Path
    gguf_writer: gguf.GGUFWriter

    def __init__(
        self,
        arch: str,
        path_safetensors: Path,
        endianess: gguf.GGUFEndian,
        outtype: QuantConfig,
        outfile: Path,
        subfolder: str | None = None,
        repo_id: str | None = None,
        is_diffusers: bool = False,
    ):
        """Create the GGUF writer and register all tensors of the checkpoint.

        Args:
            arch: Architecture name given to the GGUF writer.
            path_safetensors: Checkpoint file to read tensors from.
            endianess: Byte order given to the GGUF writer.
            outtype: Target file type / quantization type pair.
            outfile: Destination path used later by :meth:`write`.
            subfolder: Stored under the ``subfolder`` key when truthy.
            repo_id: Stored under the ``repo_id`` key when truthy.
            is_diffusers: When true, the ``is_diffusers`` flag is stored.
        """
        self.path_safetensors = path_safetensors
        self.endianess = endianess
        self.outtype = outtype
        self.outfile = outfile

        # path=None: the destination is supplied later, in write().
        self.gguf_writer = gguf.GGUFWriter(path=None, arch=arch, endianess=self.endianess)
        self.gguf_writer.add_file_type(self.outtype.ftype)
        self.gguf_writer.add_type("diffusion")
        # Optional metadata keys are only written when a value was supplied.
        if repo_id:
            self.gguf_writer.add_string("repo_id", repo_id)
        if subfolder:
            self.gguf_writer.add_string("subfolder", subfolder)
        if is_diffusers:
            self.gguf_writer.add_bool("is_diffusers", True)

        from safetensors import safe_open

        ctx = cast(ContextManager[Any], safe_open(path_safetensors, framework="pt", device="cpu"))
        with ctx as model_part:
            for name in model_part.keys():
                tensor_slice = model_part.get_slice(name)
                self.process_tensor(name, LazyTorchTensor.from_safetensors_slice(tensor_slice))

    def process_tensor(self, name: str, data_torch: LazyTorchTensor) -> None:
        """Convert one tensor to its target type and add it to the writer.

        1-D tensors target F32; all others target ``self.outtype.qtype``.
        Tensors that are neither float16 nor float32 are cast to float32
        before the NumPy conversion.  If quantization raises ``QuantError``
        the tensor falls back to F16.
        """
        is_1d = len(data_torch.shape) == 1
        current_dtype = data_torch.dtype
        target_dtype = gguf.GGMLQuantizationType.F32 if is_1d else self.outtype.qtype

        if data_torch.dtype not in (torch.float16, torch.float32):
            data_torch = data_torch.to(torch.float32)

        data = data_torch.numpy()

        # NOTE(review): this compares a torch dtype with a gguf enum member, so it
        # presumably always holds and every tensor goes through custom_quantize.
        # Left unchanged because custom_quants is not visible here — confirm.
        if current_dtype != target_dtype:
            from custom_quants import quantize as custom_quantize, QuantError

            try:
                data = custom_quantize(data, target_dtype)
            except QuantError as e:
                logger.warning("%s, %s", e, "falling back to F16")
                target_dtype = gguf.GGMLQuantizationType.F16
                data = custom_quantize(data, target_dtype)

        # uint8 arrays are treated as packed bytes: recover the logical shape for the log.
        shape = gguf.quant_shape_from_byte_shape(data.shape, target_dtype) if data.dtype == np.uint8 else data.shape
        shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
        # Lazy %-style arguments; renders exactly as the previous nested f-string did.
        logger.info("%-32s %s --> %s, shape = %s", f"{name},", current_dtype, target_dtype.name, shape_str)

        self.gguf_writer.add_tensor(name, data, raw_dtype=target_dtype)

    def write(self) -> None:
        """Write header, key/value data and tensors to ``self.outfile``.

        The writer is closed even when one of the write steps raises, so the
        output handle is not leaked on failure.
        """
        try:
            self.gguf_writer.write_header_to_file(path=self.outfile)
            self.gguf_writer.write_kv_data_to_file()
            self.gguf_writer.write_tensors_to_file(progress=True)
        finally:
            self.gguf_writer.close()
|
|
|
|
| |
def _merge_sharded_checkpoints(
    folder: Path,
    index_filename: str = "diffusion_pytorch_model.safetensors.index.json",
) -> dict[str, Tensor]:
    """Load every shard listed in a shard index into one state dict.

    Args:
        folder: Directory holding the index file and the shard files.
        index_filename: Name of the JSON index inside ``folder``.

    Returns:
        Mapping of tensor name to tensor for every tensor that is both present
        in a shard and listed in the index's ``weight_map``.

    Raises:
        KeyError: If the index has no ``weight_map`` entry.
        FileNotFoundError: If the index or a referenced shard file is missing.
    """
    with open(folder / index_filename, "r", encoding="utf-8") as index_file:
        ckpt_metadata = json.load(index_file)
    weight_map = ckpt_metadata.get("weight_map")
    if weight_map is None:
        raise KeyError("'weight_map' key not found in the shard index file.")

    merged_state_dict: dict[str, Tensor] = {}

    # Several tensors share one shard, so de-duplicate; sort so shards are
    # always visited in the same order regardless of set iteration order.
    for file_name in sorted(set(weight_map.values())):
        part_file_path = folder / file_name
        if not part_file_path.exists():
            raise FileNotFoundError(f"Part file {file_name} not found.")

        with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as shard:
            for tensor_key in shard.keys():
                if tensor_key in weight_map:
                    merged_state_dict[tensor_key] = shard.get_tensor(tensor_key)

    return merged_state_dict
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the converter.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The parsed namespace.  ``model`` and ``arch`` are guaranteed to be set;
        the parser exits with a usage error otherwise.
    """
    parser = argparse.ArgumentParser(description="Convert a flux model to GGUF")
    parser.add_argument(
        "--outfile",
        type=Path,
        default=Path("model-{ftype}.gguf"),
        help="path to write to; default: 'model-{ftype}.gguf' ; note: {ftype} will be replaced by the outtype",
    )
    parser.add_argument(
        "--outtype",
        type=str,
        choices=list(qconfig_map),
        default="F16",
        help="output quantization scheme",
    )
    parser.add_argument(
        "--arch",
        type=str,
        choices=SUPPORTED_ARCHS,
        help="output model architecture",
    )
    parser.add_argument(
        "--bigendian",
        action="store_true",
        help="model is executed on big endian machine",
    )
    parser.add_argument(
        "model",
        type=Path,
        help="directory containing safetensors model file",
        nargs="?",
    )
    parser.add_argument("--cache_dir", type=Path, help="Directory to store the intermediate files when needed.")
    parser.add_argument(
        "--subfolder", type=Path, default=None, help="Subfolder on the HF Hub to load checkpoints from."
    )
    # Read by convert() when downloading from the Hub; without this option the
    # namespace had no ``hf_token`` attribute and that path raised AttributeError.
    parser.add_argument(
        "--hf_token",
        type=str,
        default=None,
        help="Hugging Face access token used when downloading from the Hub.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="increase output verbosity",
    )

    args = parser.parse_args(argv)
    if args.model is None:
        parser.error("the following arguments are required: model")
    if args.arch is None:
        parser.error("the following arguments are required: --arch")
    # No further membership check on ``arch``: ``choices`` already rejects
    # values outside SUPPORTED_ARCHS during parsing.
    return args
|
|
|
|
def _serialize_merged_checkpoint(shard_dir: Path, cache_dir: Path | None) -> Path:
    """Merge the shards in ``shard_dir`` and save them as one safetensors file.

    The file goes into ``cache_dir`` (created if needed), or into the current
    directory when ``cache_dir`` is falsy.  The merged dict is local to this
    helper, so it is released before the conversion itself starts.
    """
    merged_state_dict = _merge_sharded_checkpoints(shard_dir)
    if cache_dir:
        cache_dir.mkdir(parents=True, exist_ok=True)
        filepath = cache_dir / "merged_state_dict.safetensors"
    else:
        filepath = Path("merged_state_dict.safetensors")
    safetensors.torch.save_file(merged_state_dict, filepath)
    logging.info(f"Serialized merged state dict to {filepath}")
    return filepath


def convert(args):
    """Run a full conversion described by a parsed argument namespace.

    ``args.model`` may be a safetensors file, a directory holding one
    safetensors file or a sharded checkpoint, or a Hub repo ID of the form
    ``owner/name``.  Sharded and Hub checkpoints are merged into a temporary
    safetensors file that is deleted afterwards, even when conversion fails.

    Args:
        args: Namespace with ``model``, ``outfile``, ``outtype``, ``arch``,
            ``bigendian``, ``cache_dir``, ``subfolder``, ``verbose`` and,
            optionally, ``hf_token``.  ``args.model`` is reassigned to the
            file that is actually converted.

    Raises:
        SystemExit: With code 1 when the inputs fail validation.
    """
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    # as_posix() keeps the "owner/name" form intact on platforms whose native
    # path separator is not "/".
    looks_like_repo_id = len(args.model.as_posix().split("/")) == 2
    if not args.model.is_dir() and not args.model.is_file() and not looks_like_repo_id:
        logging.error(f"Model path {args.model} does not exist.")
        sys.exit(1)

    # Checked before any merging/downloading so a bad output name fails fast
    # instead of after an expensive intermediate file has been produced.
    if args.outfile.suffix != ".gguf":
        logging.error("Output file must have .gguf extension.")
        sys.exit(1)

    is_diffusers = False
    repo_id = None
    intermediate_path = None  # merged checkpoint to delete once we are done

    if args.model.is_dir():
        logging.info("Supplied a directory.")
        files = list(args.model.glob("*.safetensors"))
        if not files:
            logging.error("No safetensors files found.")
            sys.exit(1)
        if len(files) == 1:
            logging.info(f"Assinging {files[0]} to `args.model`")
            args.model = files[0]
        else:
            # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
            index_path = args.model / "diffusion_pytorch_model.safetensors.index.json"
            if not index_path.exists():
                logging.error(f"Shard index {index_path} not found.")
                sys.exit(1)
            if not args.cache_dir:
                logging.error("--cache_dir is required to merge a sharded checkpoint.")
                sys.exit(1)
            intermediate_path = _serialize_merged_checkpoint(args.model, args.cache_dir)
            args.model = intermediate_path

    # An existing local file such as "dir/model.safetensors" also has two path
    # components; it must not be mistaken for a Hub repo ID.
    elif not args.model.is_file() and looks_like_repo_id:
        from huggingface_hub import snapshot_download

        logging.info("Hub repo ID detected.")
        repo_id = args.model.as_posix()
        allow_patterns = f"{Path(args.subfolder).as_posix()}/*.*" if args.subfolder else None
        local_dir = snapshot_download(
            repo_id=repo_id,
            local_dir=args.cache_dir,
            allow_patterns=allow_patterns,
            # The namespace may lack ``hf_token``; fall back to None rather
            # than raising AttributeError.
            token=getattr(args, "hf_token", None),
        )
        local_dir = Path(local_dir)
        local_dir = local_dir / args.subfolder if args.subfolder else local_dir
        intermediate_path = _serialize_merged_checkpoint(local_dir, args.cache_dir)
        args.model = intermediate_path
        is_diffusers = True

    try:
        if args.model.suffix != ".safetensors":
            logging.error(f"Model path {args.model} is not a safetensors file.")
            sys.exit(1)

        qconfig = qconfig_map[args.outtype]
        outfile = Path(str(args.outfile).format(ftype=args.outtype.upper()))

        logger.info(f"Converting model in {args.model} to {outfile} with quantization {args.outtype}")
        converter = Converter(
            arch=args.arch,
            path_safetensors=args.model,
            endianess=gguf.GGUFEndian.BIG if args.bigendian else gguf.GGUFEndian.LITTLE,
            outtype=qconfig,
            outfile=outfile,
            repo_id=repo_id,
            subfolder=str(args.subfolder) if args.subfolder else None,
            is_diffusers=is_diffusers,
        )
        converter.write()
        logger.info(
            f"Conversion complete. Output written to {outfile}, architecture: {args.arch}, quantization: {qconfig.qtype.name}"
        )
    finally:
        # Runs on success, on failure and on sys.exit(), so the (potentially
        # very large) merged checkpoint is never left behind.
        if intermediate_path is not None:
            os.remove(intermediate_path)
            logging.info(f"Removed the intermediate {intermediate_path}.")
|
|