---
base_model: black-forest-labs/FLUX.1-dev
library_name: diffusers
base_model_relation: quantized
tags:
- quantization
---

# Visual comparison of Flux-dev model outputs using BF16 and torchao int4_weight_only quantization

<table>
  <tr>
    <td style="text-align: center;">
      BF16<br>
      <medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_bf16_combined.png" alt="Flux-dev output with BF16: Baroque, Futurist, Noir styles"></medium-zoom>
    </td>
    <td style="text-align: center;">
      torchao int4_weight_only<br>
      <medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_torchao_4bit_combined.png" alt="torchao int4_weight_only Output"></medium-zoom>
    </td>
  </tr>
</table>

# Usage with Diffusers

To use this quantized FLUX.1 [dev] checkpoint, you need to install the 🧨 diffusers and torchao libraries:

```bash
pip install -U torchao
```

For now, you need to install diffusers from a specific branch that fixes an error when loading the model:

```bash
pip install git+https://github.com/huggingface/diffusers.git@torchao-int4-serialization
```

After installing the required libraries, you can run the following script:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "diffusers/FLUX.1-dev-torchao-int4",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
    device_map="balanced"
)

prompt = "Baroque style, a lavish palace interior with ornate gilded ceilings, intricate tapestries, and dramatic lighting over a grand staircase."

pipe_kwargs = {
    "prompt": prompt,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs, generator=torch.manual_seed(0),
).images[0]

image.save("flux.png")
```

# How to generate this quantized checkpoint?

This checkpoint was created with the following script, starting from the "black-forest-labs/FLUX.1-dev" checkpoint:

```python
import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import TorchAoConfig as DiffusersTorchAoConfig
from transformers import TorchAoConfig as TransformersTorchAoConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersTorchAoConfig("int4_weight_only"),
        "text_encoder_2": TransformersTorchAoConfig("int4_weight_only"),
    }
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced"
)

# safe_serialization is set to `False` because torchao-quantized models cannot be saved in safetensors format
pipe.save_pretrained("FLUX.1-dev-torchao-int4", safe_serialization=False)
```
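
For intuition, `int4_weight_only` stores each weight as a 4-bit signed integer together with a floating-point scale, rather than as a 16-bit float. The toy sketch below (pure Python; not torchao's actual implementation, which uses packed 4-bit storage, per-group scales, and optimized kernels) illustrates the quantize/dequantize round-trip:

```python
# Toy illustration of symmetric int4 weight-only quantization.
# NOT torchao's implementation: real kernels pack two 4-bit values
# per byte and use per-group scales, this is just the core idea.

def quantize_int4(weights):
    """Map float weights to integers in [-8, 7] plus one scale."""
    scale = max(abs(w) for w in weights) / 7  # signed 4-bit range: -8..7
    q = [max(-8, min(7, round(w / scale))) for w in weights]
    return q, scale

def dequantize_int4(q, scale):
    """Recover approximate float weights from int4 values and the scale."""
    return [v * scale for v in q]

weights = [0.42, -1.3, 0.07, 0.9]
q, scale = quantize_int4(weights)
approx = dequantize_int4(q, scale)
print(q)       # small integers in [-8, 7]
print(approx)  # close to the original weights, within half a scale step
```

Each weight is recovered only approximately, which is the source of the small quality differences visible in the side-by-side comparison above.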