Text Generation
Transformers
Safetensors
deepseek_v3
DeepSeek-R1-0528
GPTQ
Int4-Int8Mix
量化修复
vLLM
conversational
custom_code
text-generation-inference
4-bit precision
gptq
Instructions to use QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)
```

```python
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```
Use Docker
docker model run hf.co/QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact
- SGLang
How to use QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```
Use Docker images
```bash
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```
- Docker Model Runner
How to use QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact with Docker Model Runner:
docker model run hf.co/QuantTrio/DeepSeek-R1-0528-GPTQ-Int4-Int8Mix-Compact
| # SPDX-License-Identifier: Apache-2.0 | |
| from copy import deepcopy | |
| from typing import Any, Callable, Optional, Union | |
| import torch | |
| import vllm.model_executor.layers.fused_moe # noqa | |
| from vllm import _custom_ops as ops | |
| from vllm.logger import init_logger | |
| from vllm.model_executor.layers.fused_moe.layer import ( | |
| FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) | |
| from vllm.model_executor.layers.linear import (LinearMethodBase, | |
| set_weight_attrs) | |
| from vllm.model_executor.layers.quantization import QuantizationMethods | |
| from vllm.model_executor.layers.quantization.base_config import ( | |
| QuantizationConfig, QuantizeMethodBase) | |
| from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( | |
| MPLinearLayerConfig, choose_mp_linear_kernel) | |
| from vllm.model_executor.layers.quantization.utils import replace_parameter | |
| from vllm.model_executor.layers.quantization.utils.gptq_utils import ( | |
| get_linear_quant_method, override_config, get_dynamic_override) | |
| from vllm.model_executor.layers.quantization.utils.marlin_utils import ( | |
| check_marlin_supported, check_moe_marlin_supports_layer, | |
| marlin_make_workspace_new, marlin_moe_permute_scales, | |
| marlin_repeat_scales_on_all_ranks, verify_marlin_supported) | |
| from vllm.model_executor.parameter import (ChannelQuantScaleParameter, | |
| GroupQuantScaleParameter, | |
| PackedColumnParameter, | |
| PackedvLLMParameter, | |
| RowvLLMParameter) | |
| from vllm.platforms import current_platform | |
| from vllm.scalar_type import scalar_types | |
| logger = init_logger(__name__) | |
def get_moe_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    moe_method_cls: type,
):
    """Select the quantization method object for a fused-MoE layer.

    Args:
        config: Base quantization config. It is deep-copied before any
            per-module override is applied, so this function only ever
            hands the copy to ``override_config``.
        layer: Module being configured. Only ``FusedMoE`` instances are
            handled here.
        prefix: Module name used to look up dynamic (per-module) rules.
        moe_method_cls: Class instantiated with the (possibly overridden)
            config copy for a quantized MoE layer.

    Returns:
        ``None`` if ``layer`` is not a ``FusedMoE``;
        ``UnquantizedFusedMoEMethod(layer.moe_config)`` if the dynamic
        lookup for ``prefix`` yields ``False``; otherwise
        ``moe_method_cls(cloned_config)``.
    """
    # Nothing to do for non-MoE layers; bail out before paying for a
    # deep copy of the config that would be thrown away.
    if not isinstance(layer, FusedMoE):
        return None

    cloned_config = deepcopy(config)

    # False = skip module, None = no override, else = Positive match
    if get_dynamic_override(  # noqa: E712
            cloned_config,  # noqa: E712
            layer_name=prefix) == False:  # noqa: E712
        return UnquantizedFusedMoEMethod(layer.moe_config)

    if prefix:
        # Dynamic per module/layer rules may override base config
        override_config(cloned_config, prefix=prefix)

    return moe_method_cls(cloned_config)
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool, lm_head_quantized: bool,
                 dynamic: dict[str, dict[str, Union[int, bool]]],
                 full_config: dict[str, Any]) -> None:
        """Store the quantization settings and resolve the scalar type.

        Args:
            weight_bits: Bit width of the quantized weights.
            group_size: Quantization group size; ``-1`` is treated below as
                "one group per output channel".
            desc_act: Whether activation ordering (act_order) is enabled.
            is_sym: Whether quantization is symmetric.
            lm_head_quantized: Value of the ``lm_head`` config flag.
            dynamic: Per-module override rules (see the comment below).
            full_config: The raw config dict, kept so a fallback method can
                be built from it in ``get_quant_method``.

        Raises:
            ValueError: If ``(weight_bits, is_sym)`` is not in ``TYPE_MAP``.
        """
        super().__init__()
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        # GPTQModel use `dynamic` config property to allow per module
        # quantization config so each module can be individually optimized.
        # Format is dict[str, dict] where key is a regex string that can
        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
        # matching of a module.
        # Default to positive match, override base quant config mode, if no
        # prefix is used. Value is in dict format of field key and override
        # value.
        # Negative matching will skip quantization init for this module
        # entirely:
        # non-quantized inference. More details and quantization examples can be
        # found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        #  # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
        #  # last 1/4 of the layers 16-21 has 8bit and group_size 64
        # dynamic = {
        #  #`.*\.` matches the layers_node prefix
        #  # positive match layer 10-15
        #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
        #  # positive match layer 16-21
        #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
        #  r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
        # }
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.is_sym = is_sym

        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={weight_bits}, sym={is_sym}")

        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        # The closing parenthesis belongs after the last field so that
        # `dynamic` is rendered inside the constructor-style repr.
        return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"lm_head_quantized={self.lm_head_quantized}, "
                f"dynamic={self.dynamic})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value (80).

        NOTE(review): presumably a CUDA compute capability (8.0) - confirm
        against the base class contract.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config file names this method reads."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
        """Build a config instance from a raw quantization config dict.

        ``dynamic`` and ``lm_head`` are optional (defaulting to ``{}`` and
        ``False``); ``bits``, ``group_size``, ``desc_act`` and ``sym`` are
        fetched with ``get_from_keys`` without a default.
        """
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        # An explicit `"dynamic": null` in the config is treated as empty.
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym,
                   lm_head_quantized, dynamic, config)

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Decide whether this method should take over a checkpoint.

        Returns this method's name when the checkpoint config is
        Marlin-compatible and the user either gave no quantization or asked
        for ``"marlin"`` / ``"gptq_marlin"``; otherwise returns ``None``.
        """
        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                               or user_quant == "gptq_marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize-method object for ``layer``.

        ``FusedMoE`` layers go through ``get_moe_quant_method`` unless
        ``check_moe_marlin_supports_layer`` rejects them, in which case a
        ``MoeWNA16Config`` built from ``full_config`` is used instead. Every
        other layer is delegated to ``get_linear_quant_method``.
        """
        if isinstance(layer, FusedMoE):
            # Imported lazily, as in the original code.
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.")
                return MoeWNA16Config.from_config(
                    self.full_config).get_quant_method(layer, prefix)
            return get_moe_quant_method(self, layer, prefix,
                                        GPTQMarlinMoEMethod)
        return get_linear_quant_method(self, layer, prefix,
                                       GPTQMarlinLinearMethod)

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
        """Report whether a raw GPTQ config can be run with Marlin kernels.

        Requires a CUDA platform, ``quant_method == "gptq"``, all of
        ``bits`` / ``group_size`` / ``sym`` / ``desc_act`` present, a
        ``(bits, sym)`` pair known to ``TYPE_MAP``, and a positive answer
        from ``check_marlin_supported``.
        """
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        if not current_platform.is_cuda():
            return False

        if quant_method != "gptq":
            return False

        # Marlin conversion is only valid if required properties are found
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                      group_size=group_size)
class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    # Shared across instances so each kernel backend is logged only once.
    _kernel_backends_being_used: set[str] = set()

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

        # Verify supported on platform.
        verify_marlin_supported(quant_type=quant_config.quant_type,
                                group_size=quant_config.group_size)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Allocate and register the quantized parameters on ``layer``.

        Registers ``qweight``, ``g_idx``, ``scales`` and ``qzeros`` (in that
        order) and stores the selected mixed-precision kernel on
        ``self.kernel``.
        """
        cfg = self.quant_config
        pack_factor = cfg.pack_factor
        weight_loader = extra_weight_attrs.get("weight_loader")

        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(input_size_per_partition,
                                    output_size_per_partition),
            weight_type=cfg.quant_type,
            act_type=params_dtype,
            group_size=cfg.group_size,
            zero_points=False,
            has_g_idx=cfg.desc_act,
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
        kernel_name = kernel_type.__name__

        if kernel_name not in self._kernel_backends_being_used:
            logger.info("Using %s for GPTQMarlinLinearMethod", kernel_name)
            self._kernel_backends_being_used.add(kernel_name)

        # Normalize group_size: -1 means a single group spanning the input.
        group_size = input_size if cfg.group_size == -1 else cfg.group_size

        # Determine sharding
        repeat_scales = marlin_repeat_scales_on_all_ranks(
            cfg.desc_act, cfg.group_size, is_row_parallel)
        if repeat_scales:
            # No input dim is attached to the scales below, so the
            # weight_loader will repeat the scales on each GPU in TP>1 case.
            scales_and_zp_size = input_size // group_size
        else:
            # input_dim=0 is attached to the scales below, so the
            # weight_loader will shard the scales in TP>1 case.
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        # Activation order
        g_idx = RowvLLMParameter(
            data=torch.empty(input_size_per_partition, dtype=torch.int32),
            input_dim=0,
            weight_loader=weight_loader,
        )

        qzeros_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition // pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }

        if repeat_scales:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=pack_factor,
                **qzeros_args,
            )
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=pack_factor,
                **qzeros_args,
            )

        for param_name, param in (
            ("qweight", qweight),
            ("g_idx", g_idx),
            ("scales", scales),
            ("qzeros", qzeros),
        ):
            layer.register_parameter(param_name, param)

        self.kernel = kernel_type(
            mp_linear_kernel_config,
            w_q_param_name="qweight",
            w_s_param_name="scales",
            w_zp_param_name="qzeros",
            w_gidx_param_name="g_idx",
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Delegate post-load weight processing to the selected kernel."""
        self.kernel.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the selected kernel's ``apply_weights`` on ``x``."""
        return self.kernel.apply_weights(layer, x, bias)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    """MoE Marlin method with quantization."""

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

        # Map the configured bit width onto the matching scalar type.
        supported_types = {
            4: scalar_types.uint4b8,
            8: scalar_types.uint8b128,
        }
        size_bits = self.quant_config.quant_type.size_bits
        if size_bits not in supported_types:
            raise ValueError(
                "GPTQMarlinMoEMethod only supports int4 and int8 now.")
        self.quant_type = supported_types[size_bits]

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the per-expert quantized tensors.

        Registers, in order: ``w13_qweight``, ``w2_qweight``,
        ``w13_scales``, ``w2_scales``, ``w13_qzeros``, ``w2_qzeros``,
        ``w13_g_idx``, ``w2_g_idx``, ``w13_g_idx_sort_indices`` and
        ``w2_g_idx_sort_indices``; then attaches a Marlin workspace to
        ``layer`` and records ``self.is_k_full``.
        """
        cfg = self.quant_config
        pack_factor = cfg.pack_factor
        act_order = cfg.desc_act

        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full")

        self.is_k_full = (not act_order) or (
            intermediate_size_per_partition == intermediate_size_full)

        if cfg.group_size == -1:
            scales_size13 = 1
            scales_size2 = 1
            strategy = FusedMoeWeightScaleSupported.CHANNEL.value
        else:
            scales_size13 = hidden_size // cfg.group_size
            if act_order:
                w2_scales_size = intermediate_size_full
            else:
                w2_scales_size = intermediate_size_per_partition
            scales_size2 = w2_scales_size // cfg.group_size
            strategy = FusedMoeWeightScaleSupported.GROUP.value

        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": True
        })

        def _add_param(name: str, shape: tuple,
                       dtype: torch.dtype) -> torch.nn.Parameter:
            # Allocate, register on the layer, then tag with loader attrs.
            param = torch.nn.Parameter(
                torch.empty(*shape, dtype=dtype),
                requires_grad=False,
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)
            return param

        # Fused gate_up_proj (column parallel)
        _add_param(
            "w13_qweight",
            (num_experts, hidden_size // pack_factor,
             2 * intermediate_size_per_partition),
            torch.int32,
        )

        # down_proj (row parallel)
        _add_param(
            "w2_qweight",
            (num_experts, intermediate_size_per_partition // pack_factor,
             hidden_size),
            torch.int32,
        )

        # up_proj scales
        _add_param(
            "w13_scales",
            (num_experts, scales_size13,
             2 * intermediate_size_per_partition),
            params_dtype,
        )

        # down_proj scales
        w2_scales = _add_param(
            "w2_scales",
            (num_experts, scales_size2, hidden_size),
            params_dtype,
        )
        # dont shard the w2 scales when running act order
        set_weight_attrs(w2_scales, {"load_full_w2": act_order})

        # up_proj zero points
        _add_param(
            "w13_qzeros",
            (num_experts, scales_size13,
             2 * intermediate_size_per_partition // pack_factor),
            params_dtype,
        )

        # down_proj zero points
        w2_qzeros = _add_param(
            "w2_qzeros",
            (num_experts, scales_size2, hidden_size // pack_factor),
            params_dtype,
        )
        # dont shard the w2 zero points when running act order
        set_weight_attrs(w2_qzeros, {"load_full_w2": act_order})

        # Activation-order group indices and their sort permutations.
        _add_param("w13_g_idx", (num_experts, hidden_size), torch.int32)
        _add_param("w2_g_idx",
                   (num_experts, intermediate_size_per_partition),
                   torch.int32)
        _add_param("w13_g_idx_sort_indices", (num_experts, hidden_size),
                   torch.int32)
        _add_param("w2_g_idx_sort_indices",
                   (num_experts, intermediate_size_per_partition),
                   torch.int32)

        layer.workspace = marlin_make_workspace_new(
            layer.w13_qweight.device, 4)

    @staticmethod
    def _sort_g_idx_per_expert(g_idx: torch.Tensor, num_experts: int):
        """Return ``(sorted_g_idx, sort_indices)`` computed row by row."""
        sort_indices = torch.empty_like(g_idx)
        sorted_g_idx = torch.empty_like(g_idx)
        for expert in range(num_experts):
            sort_indices[expert] = torch.argsort(g_idx[expert]).to(
                torch.int32)
            sorted_g_idx[expert] = g_idx[expert][sort_indices[expert]]
        return sorted_g_idx, sort_indices

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Sort/reset g_idx tensors, then repack weights and scales."""
        cfg = self.quant_config
        num_experts = layer.w13_g_idx.shape[0]

        # Process act_order
        if cfg.desc_act:
            # Get sorting based on g_idx
            w13_sorted, w13_order = self._sort_g_idx_per_expert(
                layer.w13_g_idx, num_experts)
            w2_sorted, w2_order = self._sort_g_idx_per_expert(
                layer.w2_g_idx, num_experts)

            replace_parameter(layer, "w13_g_idx", w13_sorted)
            replace_parameter(layer, "w2_g_idx", w2_sorted)
            replace_parameter(layer, "w13_g_idx_sort_indices", w13_order)
            replace_parameter(layer, "w2_g_idx_sort_indices", w2_order)
        else:
            # Reset g_idx related tensors
            device = layer.w13_g_idx.device
            for attr_name in (
                    "w13_g_idx",
                    "w2_g_idx",
                    "w13_g_idx_sort_indices",
                    "w2_g_idx_sort_indices",
            ):
                empty_param = torch.nn.Parameter(
                    torch.empty((num_experts, 0),
                                dtype=torch.int32,
                                device=device),
                    requires_grad=False,
                )
                setattr(layer, attr_name, empty_param)

        pack_factor = cfg.pack_factor
        size_bits = cfg.quant_type.size_bits
        group_size = cfg.group_size

        # Repack weights
        w13_qweight = layer.w13_qweight
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            w13_qweight,
            layer.w13_g_idx_sort_indices,
            w13_qweight.shape[1] * pack_factor,
            w13_qweight.shape[2],
            size_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        w2_qweight = layer.w2_qweight
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            w2_qweight,
            layer.w2_g_idx_sort_indices,
            w2_qweight.shape[1] * pack_factor,
            w2_qweight.shape[2],
            size_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        # With group_size == -1 the multiplier falls back to pack_factor.
        w2_k_multiplier = group_size if group_size != -1 else pack_factor
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1] * w2_k_multiplier,
            size_n=layer.w2_scales.shape[2],
            group_size=group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Route tokens to experts and run the fused Marlin MoE op.

        Raises:
            AssertionError: If ``activation`` is not ``"silu"``.
            NotImplementedError: If ``apply_router_weight_on_input`` is set.
        """
        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused Marlin MoE method.")

        routing_kwargs = dict(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
        topk_weights, topk_ids = FusedMoE.select_experts(**routing_kwargs)

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_type.id,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            workspace=layer.workspace,
            is_k_full=self.is_k_full,
        )