| |
| |
| |
| |
| import os |
| import json |
| import copy |
| import argparse |
|
|
| import torch |
|
|
| from llava.model.builder import load_pretrained_model |
| from llava.utils import disable_torch_init |
| from llava.mm_utils import get_model_name_from_path |
|
|
|
|
_IMAGE_TOKEN = "<image>"


def _read_json(path):
    """Return the JSON document stored at ``path``.

    The file is closed by the ``with`` block instead of being left for the
    garbage collector.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(obj, path):
    """Write ``obj`` to ``path`` as JSON indented by two spaces.

    The ``with`` block guarantees the data is flushed and the handle closed
    before the function returns.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)


def _save_processor_configs(image_processor, model_path):
    """Write ``preprocessor_config.json`` and ``processor_config.json`` into ``model_path``."""
    # Tag the image processor so its serialized config names LlavaProcessor.
    image_processor.processor_class = "LlavaProcessor"
    image_processor.to_json_file(os.path.join(model_path, "preprocessor_config.json"))

    processor_config = {
        "image_token": _IMAGE_TOKEN,
        "num_additional_image_tokens": 0,
        "processor_class": "LlavaProcessor",
        # NOTE(review): written verbatim; presumably the vision encoder's total
        # downsampling factor — confirm against the encoder before changing.
        "patch_size": 64,
    }
    _write_json(processor_config, os.path.join(model_path, "processor_config.json"))


def _ensure_image_token(tokenizer_config):
    """Return ``(token_id, added)`` for the ``<image>`` entry of ``added_tokens_decoder``.

    If an entry whose ``content`` is ``<image>`` already exists, its own id is
    returned and ``added`` is False. Otherwise a new entry is inserted in place
    under id ``max(existing ids) + 1``. The new entry is a deep copy of the first
    existing entry with only ``content`` replaced, and ``added`` is True.

    Raises:
        ValueError: if ``<image>`` is missing and the decoder has no entry to
            use as a template.
    """
    decoder = tokenizer_config["added_tokens_decoder"]
    other_ids = []
    template = None
    for key, entry in decoder.items():
        if entry["content"] == _IMAGE_TOKEN:
            # Use the id actually assigned to <image>, not one guessed from the others.
            return int(key), False
        other_ids.append(int(key))
        if template is None:
            template = entry

    if template is None:
        raise ValueError(
            "tokenizer_config.json has no added tokens to use as a template "
            "for the <image> token")

    new_id = max(other_ids) + 1
    new_entry = copy.deepcopy(template)
    new_entry["content"] = _IMAGE_TOKEN
    decoder[str(new_id)] = new_entry
    return new_id, True


def _register_image_token(model_path):
    """Ensure the tokenizer config has ``<image>`` and store its id in ``config.json``.

    ``tokenizer_config.json`` is rewritten only when the token had to be added.
    ``config.json`` always gets ``image_token_index`` set to the token's id,
    which is also returned.
    """
    tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
    tokenizer_config = _read_json(tokenizer_config_path)
    image_token_id, added = _ensure_image_token(tokenizer_config)
    if added:
        _write_json(tokenizer_config, tokenizer_config_path)

    config_path = os.path.join(model_path, "config.json")
    model_config = _read_json(config_path)
    model_config["image_token_index"] = image_token_id
    _write_json(model_config, config_path)
    return image_token_id


def _export_onnx(module, dummy_input, output_path, input_name, output_name):
    """Trace ``module`` on ``dummy_input`` and save it to ``output_path`` as ONNX.

    Uses opset 17, embeds the parameters and enables constant folding. No
    ``dynamic_axes`` are passed, so the graph's input shape is that of
    ``dummy_input``.
    """
    torch.onnx.export(
        module,
        dummy_input,
        output_path,
        input_names=[input_name],
        output_names=[output_name],
        opset_version=17,
        export_params=True,
        do_constant_folding=True,
    )


def export(args):
    """Write processor metadata into the model directory and export ONNX graphs.

    All outputs go under ``os.path.expanduser(args.model_path)``:

    - ``preprocessor_config.json`` and ``processor_config.json`` are written.
    - ``tokenizer_config.json`` gains an ``<image>`` added token if it lacks one.
    - ``config.json`` gets ``image_token_index`` set to that token's id.
    - ``fastvithd.onnx`` (vision tower) and ``mm_projector.onnx`` are exported.

    Args:
        args: parsed CLI namespace; ``model_path`` and ``model_base`` are read.
    """
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    # Only the model and the image processor are needed here.
    _tokenizer, model, image_processor, _context_len = load_pretrained_model(
        model_path, args.model_base, model_name, device="cpu")

    _save_processor_configs(image_processor, model_path)
    _register_image_token(model_path)

    # Random input of shape (1, 3, res, res), using the processor's shortest_edge size.
    image_res = image_processor.to_dict()['size']['shortest_edge']
    dummy_vision_input = torch.rand(1, 3, image_res, image_res, dtype=torch.float32)

    vision_model = model.get_vision_tower().cpu().float().eval()
    onnx_vision_model_path = os.path.join(model_path, "fastvithd.onnx")
    print(f"Exporting vision model to {onnx_vision_model_path}...")
    _export_onnx(vision_model, dummy_vision_input, onnx_vision_model_path,
                 'pixel_values', 'last_hidden_state')
    print(f"Vision model ONNX export complete: {onnx_vision_model_path}")

    # Run the encoder once so the projector is traced on an input of the right shape.
    with torch.no_grad():
        dummy_mm_projector_input = vision_model(dummy_vision_input).cpu().float()

    mm_projector = model.get_model().mm_projector.cpu().float().eval()
    onnx_mm_projector_path = os.path.join(model_path, "mm_projector.onnx")
    print(f"Exporting mm_projector to {onnx_mm_projector_path}...")
    _export_onnx(mm_projector, dummy_mm_projector_input, onnx_mm_projector_path,
                 'last_hidden_state', 'projected_image_features')
    print(f"mm_projector ONNX export complete: {onnx_mm_projector_path}")
|
|
| |
| |
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: every option is a string; only --model-path is mandatory.
    cli = argparse.ArgumentParser()
    for flag, extra in (
        ("--model-path", {"required": True}),
        ("--model-base", {"default": None}),
        ("--conv-mode", {"default": "qwen_2"}),  # parsed but never read by export()
    ):
        cli.add_argument(flag, type=str, **extra)
    export(cli.parse_args())
|
|