# FastVLM-1.5B-RKLLM / export_onnx.py
# Uploaded by happyme531 ("Upload 12 files", commit 6128fc3, verified).
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import os
import json
import copy
import argparse
import torch
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path
def _read_json(path):
    """Return the parsed contents of the JSON file at *path*."""
    # JSON is UTF-8 by specification; don't depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(obj, path):
    """Write *obj* to *path* as JSON indented by 2 spaces.

    The ``with`` block guarantees the file is flushed and closed before the
    function returns, so a later read of the same path sees the new content.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)


def _ensure_image_token(tokenizer_config):
    """Make sure ``"<image>"`` is registered in ``added_tokens_decoder``.

    ``tokenizer_config["added_tokens_decoder"]`` maps stringified token ids to
    token entry dicts that have a ``"content"`` key. If an entry whose content
    is ``"<image>"`` already exists, its id is returned unchanged. Otherwise a
    new entry with id ``max(existing ids) + 1`` is added (in place), cloned
    from the first existing entry, and that new id is returned.

    Returns:
        int: the token id of ``"<image>"``.

    Raises:
        ValueError: if ``"<image>"`` is absent and there is no existing added
            token to clone as a template.
    """
    added_tokens = tokenizer_config['added_tokens_decoder']
    for key, entry in added_tokens.items():
        if entry["content"] == "<image>":
            # Use the real id rather than guessing it from the other ids.
            return int(key)
    if not added_tokens:
        raise ValueError(
            "tokenizer_config.json has no added tokens to use as a template "
            "for the <image> token")
    new_id = max(int(key) for key in added_tokens) + 1
    # Clone an existing entry so the new one carries the same set of fields.
    template_key = next(iter(added_tokens))
    new_entry = copy.deepcopy(added_tokens[template_key])
    new_entry["content"] = "<image>"
    added_tokens[str(new_id)] = new_entry
    return new_id


def _export_onnx(module, dummy_input, onnx_path, input_name, output_name):
    """Trace *module* with *dummy_input* and save it as ONNX at *onnx_path*.

    Uses opset 17, stores the trained weights in the file and applies
    constant folding. No ``dynamic_axes`` are passed, so every input/output
    dimension (including batch) is fixed to the dummy input's shape.
    """
    torch.onnx.export(
        module,
        dummy_input,
        onnx_path,
        input_names=[input_name],     # name of the ONNX graph's input node
        output_names=[output_name],   # name of the ONNX graph's output node
        opset_version=17,             # ONNX opset version
        export_params=True,           # store trained parameters in the file
        do_constant_folding=True,     # run constant-folding optimization
    )


def export(args):
    """Write HF/mlx-vlm metadata for a LLaVA checkpoint and export ONNX parts.

    Everything is written into the (``~``-expanded) ``args.model_path``:

    1. ``preprocessor_config.json`` and ``processor_config.json``;
    2. an ``<image>`` entry in ``tokenizer_config.json`` (added if missing);
    3. that token's id as ``image_token_index`` in ``config.json``;
    4. the vision tower as ``fastvithd.onnx``;
    5. the multimodal projector as ``mm_projector.onnx``.

    Args:
        args: parsed CLI namespace; reads ``model_path`` and ``model_base``.
            ``conv_mode`` is accepted by the CLI but not used here.
    """
    # Load model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    _tokenizer, model, image_processor, _context_len = load_pretrained_model(
        model_path, args.model_base, model_name, device="cpu")

    # Save extra metadata that is not saved during LLaVA training but is
    # required by HF for auto-loading the model and for mlx-vlm preprocessing.
    image_processor.processor_class = "LlavaProcessor"
    image_processor.to_json_file(
        os.path.join(model_path, "preprocessor_config.json"))

    processor_config = {
        "image_token": "<image>",
        "num_additional_image_tokens": 0,
        "processor_class": "LlavaProcessor",
        "patch_size": 64,
    }
    _write_json(processor_config,
                os.path.join(model_path, "processor_config.json"))

    # Register <image> in the tokenizer, then point the model config at the
    # id that token actually has.
    tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
    tokenizer_config = _read_json(tokenizer_config_path)
    image_token_id = _ensure_image_token(tokenizer_config)
    _write_json(tokenizer_config, tokenizer_config_path)

    config_path = os.path.join(model_path, "config.json")
    model_config = _read_json(config_path)
    model_config["image_token_index"] = image_token_id
    _write_json(model_config, config_path)

    # Export the vision encoder, traced with one square random image.
    image_res = image_processor.to_dict()['size']['shortest_edge']
    dummy_vision_input = torch.rand(1, 3, image_res, image_res).float()
    # CPU, float32 and eval mode for a deterministic, portable export.
    vision_model = model.get_vision_tower().cpu().float().eval()
    onnx_vision_model_path = os.path.join(model_path, "fastvithd.onnx")
    print(f"Exporting vision model to {onnx_vision_model_path}...")
    _export_onnx(vision_model, dummy_vision_input, onnx_vision_model_path,
                 'pixel_values', 'last_hidden_state')
    print(f"Vision model ONNX export complete: {onnx_vision_model_path}")

    # Feed the dummy image through the vision tower so the projector is
    # traced with exactly the input shape the vision tower produces.
    with torch.no_grad():
        dummy_mm_projector_input = vision_model(dummy_vision_input).cpu().float()

    # model.get_model() is the underlying base model that owns mm_projector.
    mm_projector = model.get_model().mm_projector.cpu().float().eval()
    onnx_mm_projector_path = os.path.join(model_path, "mm_projector.onnx")
    print(f"Exporting mm_projector to {onnx_mm_projector_path}...")
    _export_onnx(mm_projector, dummy_mm_projector_input, onnx_mm_projector_path,
                 'last_hidden_state', 'projected_image_features')
    print(f"mm_projector ONNX export complete: {onnx_mm_projector_path}")
if __name__ == "__main__":
    # Command-line entry point: parse options and run the export.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model-path", type=str, required=True)
    cli.add_argument("--model-base", type=str, default=None)
    # Accepted for CLI compatibility; export() does not read it.
    cli.add_argument("--conv-mode", type=str, default="qwen_2")
    export(cli.parse_args())