| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from psalm.train.train_datasets import * |
| | from psalm.eval.eval_davis import DAVIS_Dataset, Ego_Train_Dataset, Multicondition_Dataset |
| | from psalm.mask_config.config import Config |
| |
|
| | |
| | from psalm.model.language_model.llava_phi_SSL_MultiCondition import PSALM_SSL_MultiCondition |
| |
|
| | from psalm.train.llava_trainer_SSL import LLaVATrainerSSL |
| |
|
| | from fvcore.common.config import CfgNode |
| | import warnings |
| |
|
# Module-level state: assigned from TrainingArguments.local_rank inside train()
# and consulted by rank0_print().
local_rank = None

# Identify which variant of the training script is running.
print('Version: SSL_MultiCondition!')

# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings('ignore')
| |
|
def print_trainable_parm(model, prefix):
    """Print the name of every submodule that owns at least one trainable parameter.

    Each hit is printed as ``"<prefix>: <module name>"`` (the root module has
    the empty name). ``module.parameters()`` recurses into children, so a parent
    module is reported whenever any of its descendants has a parameter with
    ``requires_grad`` set.

    Args:
        model: a ``torch.nn.Module`` to inspect.
        prefix: label prepended to every printed line.
    """
    for name, module in model.named_modules():
        # any() stops at the first trainable parameter, like the original break.
        if any(p.requires_grad for p in module.parameters()):
            print(f'{prefix}: {name}')
def get_mask_config(config='./psalm/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml'):
    """Load the mask-decoder YAML config at *config* and return it as a ``Config``.

    The file is read twice: once through ``Config.fromfile`` and once through
    fvcore's ``CfgNode.load_yaml_with_base`` (which resolves ``_BASE_``
    includes). The ``Config`` object's attributes are then merged over the
    fvcore node, and the merged mapping is wrapped in a new ``Config``.

    Note: ``allow_unsafe=True`` permits fvcore to fall back to an unsafe YAML
    loader, so only pass trusted config files.
    """
    coco_cfg = Config.fromfile(config)
    merged = CfgNode.load_yaml_with_base(config, allow_unsafe=True)
    merged.update(coco_cfg.__dict__.items())
    return Config(merged)
| |
|
def print_dtype(model, prefix, dtype):
    """Report every parameter of *model* whose dtype differs from *dtype*.

    For each mismatch, prints ``"<prefix>: <parameter name>"`` followed by the
    parameter's actual dtype on its own line.
    """
    mismatched = (
        (param_name, weight)
        for param_name, weight in model.named_parameters()
        if weight.dtype != dtype
    )
    for param_name, weight in mismatched:
        print(f'{prefix}: {param_name}')
        print(weight.dtype)
| |
|
def rank0_print(*args):
    """Forward *args* to ``print`` only when the module-global ``local_rank`` is 0."""
    if local_rank != 0:
        return
    print(*args)
| |
|
| |
|
@dataclass
class ModelArguments:
    """Model / backbone options parsed from the command line by train()."""

    # Used by train() as the tokenizer source (AutoTokenizer.from_pretrained).
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Key looked up in conversation_lib.conv_templates; falls back to "vicuna_v1".
    version: Optional[str] = "v0"
    # When True, train() calls model.model.requires_grad_(False).
    freeze_backbone: bool = False
    # When False, train() freezes model.model.vision_tower.
    train_backbone: bool = False
    # When True, train() freezes everything except the mm_projector.
    tune_mm_mlp_adapter: bool = False
    # train() only initializes vision modules when this is not None.
    vision_tower: Optional[str] = None
    # The options below are not read directly in this file; presumably consumed
    # by the model / vision-module code — confirm there.
    mm_vision_select_layer: Optional[int] = -1
    pretrain_mm_mlp_adapter: Optional[str] = None
    # Copied to model.config / data_args in train().
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_vision_select_feature: Optional[str] = "patch"
    with_norm: bool = True
    with_layernorm: bool = False
    skip_init_vision: bool = False
    with_sam: bool = False
    with_swin: bool = False
    with_teacher: bool = False
    swin_type: Optional[str] = "base"
    projector_outdim: Optional[int] = 2048
    mm_projector_type: Optional[str] = "swin_conv"
    model_version: Optional[str] = "v1"
    load_mask2former: bool = True
    # Written into mask_cfg.MODEL.MASK_FORMER.SEG_TASK by train().
    seg_task: Optional[str] = "panoptic"
    # YAML path handed to get_mask_config() by train().
    mask_config: Optional[str] = "./psalm/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml"
    dino_path: Optional[str] = None
| |
|
@dataclass
class DataArguments:
    """Dataset options parsed from the command line by train()."""

    # Defaults to None, hence Optional (it was previously annotated plain ``str``).
    data_path: Optional[str] = field(default=None,
                                     metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    # Set to True by train() once a vision tower is initialized.
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    refcoco_image_folder: Optional[str] = "/path/to/refer_seg/images/mscoco/images/train2014"
    image_first: bool = field(default=True)
    seg_last: bool = field(default=True)
    instruction_version: str = 'v1'
    # Both copied onto model.config by train().
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = field(default=None)
    # Placeholder paths for other dataset sources; make_unify_datamodule() in
    # this file does not read them.
    json_path: str = '/path/to/instruction_segmentation_train.json'
    instance_json_path: str = '/path/to/instruction_segmentation_train.json'
    lvis_json_path: str = '/path/to/lvis_instance_train.json'
    lvis_categories_path: str = '/path/to/lvis_instance_categories.json'
    # Annotation JSON passed to Multicondition_Dataset by make_unify_datamodule().
    region_json_path: str = '/path/to/visual_prompt_segmentation_train.json'
    panoptic_json_path: str = "/path/to/coco"
    ref_coco_path: str = '/path/to/refcoco/refcoco_train.json'
    ref_coco_plus_path: str = '/path/to/refcoco+/refcoco+_train.json'
    ref_coco_g_path: str = '/path/to/refcocog/refcocog_train.json'
    mmconv_path: str = '/path/to/llava_1_5'
    # '||'-separated integer sampling ratios, parsed by make_unify_datamodule().
    data_ratio: str = '1||1||1||1'
    # Forwarded as fix_dataset_len to UnifyDatasetSingleDatasetForBatch.
    fix_dataset_len: int = 0
    segmentation: bool = True
| |
|
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face ``TrainingArguments`` extended with this project's options."""

    # --- general options / overrides of upstream defaults ---
    cache_dir: Optional[str] = None  # passed as cache_dir to from_pretrained() in train()
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    freeze_mm_mlp_adapter: bool = False  # when True, train() freezes mm_projector parameters
    mpt_attn_impl: Optional[str] = "triton"  # not read in this file
    # Used by train() as the tokenizer's model_max_length.
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )

    # --- quantization options (not read by train() in this file) ---
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )

    # --- LoRA options (not read by train() in this file) ---
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"

    # train() also forces this to True after building the data module.
    dataloader_drop_last: bool = True
| |
|
| |
|
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of *param*, gathering it first if ZeRO-partitioned.

    Parameters carrying a ``ds_id`` attribute (DeepSpeed ZeRO-3 partitioned
    parameters) are materialised with ``zero.GatheredParameters`` before
    copying. Any other tensor is simply detached, moved to CPU and cloned.

    Args:
        param: a tensor / parameter to copy.
        ignore_status: suppress the warning logged when the partitioned
            parameter's status is ``NOT_AVAILABLE``.
        name: label used in that warning.

    Returns:
        A new CPU tensor that does not require grad.
    """
    if hasattr(param, "ds_id"):
        # Import lazily so plain tensors can be copied without DeepSpeed installed.
        from deepspeed import zero
        from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
| |
|
| |
|
| | |
def get_peft_state_maybe_zero_3(named_params, bias):
    """Select LoRA parameters (plus the requested biases) and return CPU copies.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g.
            ``model.named_parameters()``.
        bias: which bias terms to keep alongside ``lora_`` parameters:
            ``"none"`` keeps none; ``"all"`` keeps every name containing
            ``"bias"``; ``"lora_only"`` keeps only the bias of modules that
            also have LoRA weights.

    Returns:
        dict mapping parameter name to the tensor copied via ``maybe_zero_3``.

    Raises:
        NotImplementedError: for any other *bias* value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # e.g. "x.q_proj.lora_A.weight" -> "x.q_proj.bias"
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep a bias only if it belongs to a LoRA-adapted module. (Previously
        # this iterated dict keys and tested a stale loop variable.)
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    return {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
| |
|
| |
|
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Return CPU copies of every parameter whose name does not contain ``lora_``.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs.
        require_grad_only: when True, additionally drop tensors with
            ``requires_grad == False``.
    """
    selected = {}
    for key, tensor in named_params:
        if "lora_" in key:
            continue
        selected[key] = tensor
    if require_grad_only:
        selected = {key: tensor for key, tensor in selected.items() if tensor.requires_grad}
    return {key: maybe_zero_3(tensor, ignore_status=True).cpu() for key, tensor in selected.items()}
| |
|
| |
|
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Return CPU copies of parameters whose name contains any of *keys_to_match*.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs.
        keys_to_match: substrings; a parameter is kept if any one occurs in its name.
    """
    def _wanted(param_name):
        return any(fragment in param_name for fragment in keys_to_match)

    picked = {param_name: tensor for param_name, tensor in named_params if _wanted(param_name)}
    return {
        param_name: maybe_zero_3(tensor, ignore_status=True).cpu()
        for param_name, tensor in picked.items()
    }
| |
|
| |
|
def find_all_linear_names(model):
    """Return the sorted, de-duplicated leaf names of every ``torch.nn.Linear`` in *model*.

    Only the last dotted component of each module path is kept
    (``layers.0.self_attn.q_proj`` -> ``q_proj``), and ``lm_head`` is excluded.
    The variable naming suggests the result is meant as LoRA target module
    names. Sorting makes the result deterministic across runs.
    """
    lora_module_names = {
        name.rsplit('.', 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }
    lora_module_names.discard('lm_head')
    return sorted(lora_module_names)
| |
|
| |
|
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collect the state dict and dump it to disk.

    Three modes:

    * ``trainer.args.tune_mm_mlp_adapter`` set: save only parameters whose
      names contain ``mm_projector`` (plus ``embed_tokens``/``embed_in`` when
      ``use_im_start_end`` is set), together with the model config. For a
      ``checkpoint-*`` folder the weights go to
      ``<parent>/mm_projector/<folder>.bin``, otherwise to
      ``<output_dir>/mm_projector.bin``. Only local rank 0 (or -1) writes.
    * DeepSpeed enabled: delegate to ``trainer.save_model``.
    * Otherwise: copy the full state dict to CPU and save it via
      ``trainer._save`` when ``trainer.args.should_save``.
    """
    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        keys_to_match = ['mm_projector']
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(['embed_tokens', 'embed_in'])

        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        # normpath strips a trailing separator so "run/checkpoint-100/" is still
        # recognised as a checkpoint folder; split('/')[-1] would yield "".
        normalized_dir = os.path.normpath(output_dir)
        current_folder = os.path.basename(normalized_dir)
        parent_folder = os.path.dirname(normalized_dir)
        if trainer.args.local_rank in (0, -1):
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        # Release the on-device references before serialising.
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)
| |
|
| |
|
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens, resize the embeddings, and initialise the new rows.

    Each newly added row of the input and output embedding matrices is set to
    the mean of that matrix's pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    tables = [
        embedding.weight.data
        for embedding in (model.get_input_embeddings(), model.get_output_embeddings())
    ]
    # Compute both means before writing, so tied embeddings see unmodified rows.
    means = [table[:-added].mean(dim=0, keepdim=True) for table in tables]
    for table, mean_row in zip(tables, means):
        table[-added:] = mean_row
| |
|
| |
|
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string separately and report the non-pad token counts.

    Returns a dict with ``input_ids``/``labels`` (the first row of each
    encoding) and ``input_ids_lens``/``labels_lens`` (count of tokens not equal
    to ``tokenizer.pad_token_id``). ``input_ids`` and ``labels`` are the SAME
    list object, as are the two length lists.
    """
    encodings = []
    for text in strings:
        encodings.append(
            tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=tokenizer.model_max_length,
                truncation=True,
            )
        )
    ids = [enc.input_ids[0] for enc in encodings]
    lengths = [enc.input_ids.ne(tokenizer.pad_token_id).sum().item() for enc in encodings]
    return dict(
        input_ids=ids,
        labels=ids,
        input_ids_lens=lengths,
        labels_lens=lengths,
    )
| |
|
| |
|
| | def _mask_targets(target, tokenized_lens, speakers): |
| | |
| | cur_idx = tokenized_lens[0] |
| | tokenized_lens = tokenized_lens[1:] |
| | target[:cur_idx] = IGNORE_INDEX |
| | for tokenized_len, speaker in zip(tokenized_lens, speakers): |
| | if speaker == "human": |
| | target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX |
| | cur_idx += tokenized_len |
| |
|
| |
|
| | def _add_speaker_and_signal(header, source, get_conversation=True): |
| | """Add speaker and start/end signal on each round.""" |
| | BEGIN_SIGNAL = "### " |
| | END_SIGNAL = "\n" |
| | conversation = header |
| | for sentence in source: |
| | from_str = sentence["from"] |
| | if from_str.lower() == "human": |
| | from_str = conversation_lib.default_conversation.roles[0] |
| | elif from_str.lower() == "gpt": |
| | from_str = conversation_lib.default_conversation.roles[1] |
| | else: |
| | from_str = 'unknown' |
| | sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + |
| | sentence["value"] + END_SIGNAL) |
| | if get_conversation: |
| | conversation += sentence["value"] |
| | conversation += BEGIN_SIGNAL |
| | return conversation |
| |
|
| |
|
def make_unify_datamodule(tokenizer, data_args, training_args):
    """Build the training dataset and collator passed to the trainer.

    Only the multi-condition dataset (annotations from
    ``data_args.region_json_path``) is used. It is wrapped in
    ``UnifyDatasetSingleDatasetForBatch`` with the integer ratios parsed from
    ``data_args.data_ratio`` ('||'-separated). *training_args* is accepted for
    interface compatibility but is not used here.

    Returns:
        dict with ``train_dataset``, ``eval_dataset`` (always None) and
        ``data_collator``.
    """
    ratios = list(map(int, data_args.data_ratio.split('||')))

    datasets = [
        Multicondition_Dataset(json_path=data_args.region_json_path, tokenizer=tokenizer, data_args=data_args),
    ]
    print(f'the dataset ratio is: {ratios}')

    # NOTE(review): 16 is passed positionally; presumably a per-batch grouping
    # size — confirm against UnifyDatasetSingleDatasetForBatch.
    train_dataset = UnifyDatasetSingleDatasetForBatch(datasets, ratios, 16, fix_dataset_len=data_args.fix_dataset_len)
    print(f'total unify dataset number is {len(train_dataset)}')
    return {
        "train_dataset": train_dataset,
        "eval_dataset": None,
        "data_collator": DataCollatorForCOCODatasetV2(tokenizer=tokenizer),
    }
| |
|
| |
|
# Checkpoints compared by train(). These are absolute paths on the original
# author's machine; adjust before running elsewhere.
_STAGE1_CHECKPOINT = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/OursMultiCondition_EgoQuery_SmallJson_1101_CAwithlearnableweight_1Head_TwoStageS1/checkpoint-152"
_STAGE2_CHECKPOINT = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/OursMultiCondition_EgoQuery_SmallJson_1101_CAwithlearnableweight_1Head_TwoStageS2/checkpoint-3056"


def _load_multicondition_model(checkpoint, mask_cfg, cache_dir, extra_kwargs):
    """Load a PSALM_SSL_MultiCondition checkpoint with the shared mask-decoder config."""
    return PSALM_SSL_MultiCondition.from_pretrained(
        checkpoint,
        mask_decoder_cfg=mask_cfg,
        add_cross_attn=True,
        cache_dir=cache_dir,
        **extra_kwargs
    )


def _print_fuse_params(model, tag):
    """Print ``tag, name, param`` for every parameter whose name contains ``fuse_model``."""
    for name, param in model.named_parameters():
        if "fuse_model" in name:
            print(tag, name, param)


def train():
    """Parse CLI arguments, set up the stage-1 model and data, then dump fuse weights.

    Despite its name, this function does not run a training loop. It:

    1. loads the stage-1 and stage-2 checkpoints;
    2. configures the stage-1 model (tokenizer, vision modules, ``[SEG]`` token);
    3. builds the data module;
    4. prints every ``fuse_model`` parameter of both models, presumably so they
       can be compared by hand.
    """
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    mask_cfg = get_mask_config(config=model_args.mask_config)
    mask_cfg.MODEL.MASK_FORMER.SEG_TASK = model_args.seg_task
    bnb_model_from_pretrained_args = {}

    print('using model PSALM SSL Multicondtion')

    # Both models come from fixed checkpoint paths; model_args.model_name_or_path
    # is only used for the tokenizer below.
    model = _load_multicondition_model(
        _STAGE1_CHECKPOINT, mask_cfg, training_args.cache_dir, bnb_model_from_pretrained_args)
    model2 = _load_multicondition_model(
        _STAGE2_CHECKPOINT, mask_cfg, training_args.cache_dir, bnb_model_from_pretrained_args)

    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token="[PAD]"),
            tokenizer=tokenizer,
            model=model,
        )
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    if model_args.vision_tower is not None:
        model.get_model().initialize_vision_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        vision_tower = model.get_vision_tower()
        # Same fp16 > bf16 > fp32 selection as compute_dtype.
        vision_tower.to(dtype=compute_dtype, device=training_args.device)
        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.image_grid_pinpoints = data_args.image_grid_pinpoints

        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
        if model_args.tune_mm_mlp_adapter:
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True
        # NOTE(review): original indentation was lost. This is placed outside the
        # tune_mm_mlp_adapter branch because inside it the call would be
        # redundant with model.requires_grad_(False) — confirm intent.
        if not model_args.train_backbone:
            model.model.vision_tower.requires_grad_(False)

        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
        if training_args.freeze_mm_mlp_adapter:
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False

        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
        training_args.use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

    # Register the [SEG] token and tell the model its id (plus EOS).
    tokenizer.add_tokens("[SEG]")
    model.resize_token_embeddings(len(tokenizer))
    model.get_special_token(SEG=tokenizer("[SEG]", return_tensors='pt', add_special_tokens=False)['input_ids'], EOS=tokenizer.eos_token_id)

    # Built for its side effects (dataset construction / size printout); the
    # result is not used further in this function.
    data_module = make_unify_datamodule(tokenizer=tokenizer, data_args=data_args, training_args=training_args)
    training_args.dataloader_drop_last = True

    _print_fuse_params(model, "model1")
    _print_fuse_params(model2, "model2")
| |
|
| | |
| |
|
if __name__ == "__main__":
    # Entry point when executed as a script (not on import).
    train()
| |
|