| | import os |
| | import json |
| | import shutil |
| |
|
| | from tqdm import tqdm |
| | from PIL import Image |
| |
|
| | import natsort |
| |
|
| | import numpy as np |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| |
|
| | from config import config |
| | from src.open_clip import create_model_and_transforms |
| |
|
| |
|
class loading_img(Dataset):
    """Dataset view over an in-memory list of already-preprocessed items.

    Each stored item is returned with its leading axis squeezed away, so the
    items are expected to expose a ``squeeze`` method (e.g. tensors).
    """

    def __init__(self, img_list):
        # Keep a reference only; nothing is copied or converted here.
        self.img_list = img_list

    def __len__(self):
        n_items = len(self.img_list)
        return n_items

    def __getitem__(self, idx):
        item = self.img_list[idx]
        # Drop axis 0 if it has size 1 (squeeze is a no-op otherwise).
        return item.squeeze(0)
| |
|
class CustomDataset(Dataset):
    """Dataset that loads every frame and the keyword list of one question.

    Each element of ``questions`` is a mapping that provides a ``"filepath"``
    list of frame paths and, optionally, an ``"Activity"`` entry holding the
    keywords (either a list or a ``", "``-separated string).
    """

    def __init__(self, questions, clippy, preprocess_val, clip_size, base_dir):
        """Store the question records and the preprocessing configuration.

        Args:
            questions: Sequence of question records (one per sample).
            clippy: Model whose first parameter determines ``self.device``.
            preprocess_val: Callable applied to every resized PIL image.
            clip_size: Size passed to ``PIL.Image.Image.resize``.
            base_dir: Directory that replaces the prefix of every frame path.
        """
        self.questions = questions
        self.clippy = clippy
        self.clip_size = clip_size
        self.preprocess_val = preprocess_val
        self.device = next(clippy.parameters()).device
        self.base_dir = base_dir

    @staticmethod
    def _extract_keywords(line):
        """Return the keyword list stored under ``"Activity"`` (``[]`` if absent/empty)."""
        # .get() avoids the KeyError the previous `line["Activity"] == "" or
        # "Activity" not in line` check raised for records without the key.
        activity = line.get("Activity", "")
        if not activity:
            return []
        if isinstance(activity, list):
            return activity
        return activity.split(', ')

    def __getitem__(self, index):
        """Load all frames of question ``index``.

        Returns:
            A tuple ``(image_list, img_paths, timelines, timelines_int,
            keywords, img_names)`` where the first four and the last list are
            aligned per frame, in natural-sort order of the original paths.
        """
        line = self.questions[index]
        keywords = self._extract_keywords(line)

        image_list = []
        img_paths = []
        img_names = []
        timelines = []
        timelines_int = []

        for raw_path in natsort.natsorted(line["filepath"]):
            # Re-root the frame under base_dir, keeping its last four path components.
            img_path = self.base_dir + "/" + "/".join(raw_path.split("/")[-4:])
            img_paths.append(img_path)

            # File name up to the first dot.
            img_name = img_path.split('/')[-1].split('.')[0]
            img_names.append(img_name)

            # resize() returns an independent image, so the file can be closed
            # as soon as the resized copy exists.
            with Image.open(img_path) as raw_img:
                resized = raw_img.resize(self.clip_size)
            image_list.append(self.preprocess_val(resized))

            # The last two "_"-separated fields of the name form "<a>.<b>",
            # which is used as the frame's time in seconds.
            name_parts = img_name.split('_')
            stamp = f"{name_parts[-2]}.{name_parts[-1]}"
            timelines.append(f"{stamp} seconds")
            timelines_int.append(float(stamp))

        return image_list, img_paths, timelines, timelines_int, keywords, img_names

    def __len__(self):
        return len(self.questions)
| |
|
| |
|
def disable_torch_init():
    """Replace ``reset_parameters`` on Linear and LayerNorm with a no-op.

    After this call, constructing or resetting these layers no longer runs
    their default parameter initialisation.
    """
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
| |
|
def SortSimilarity(q_uid, simmat, keywords, nimgtokens, nframes_paths, maximgslen):
    """Pick frames round-robin over keywords by descending similarity.

    Every row of ``simmat`` holds one keyword's similarity to each image
    token; consecutive groups of ``nimgtokens`` columns belong to one frame.
    In each pass, every keyword in turn claims the best-scoring frame that no
    keyword has claimed yet. Passes repeat until ``maximgslen`` frames are
    selected or nothing more can be selected.

    Args:
        q_uid: Question identifier (unused; kept for interface compatibility).
        simmat: 2-D tensor, one row per keyword, one column per image token.
        keywords: Keyword for each row of ``simmat``.
        nimgtokens: Number of token columns per frame.
        nframes_paths: Frame path for each frame index.
        maximgslen: Maximum number of frames to select.

    Returns:
        dict mapping frame index to
        ``{"kw", "simvalue", "kf_path", "kw_others"}``, in selection order.
        The dict is always returned, even if fewer than ``maximgslen`` frames
        could be selected (previously the function returned ``None`` or never
        terminated in that case).
    """
    sort_simmat, sort_idx = torch.sort(simmat, dim=-1, descending=True)
    # Token column -> frame index. Integer floor division avoids the rounding
    # error of the float floor(idx / n) form for very large indices.
    sort_idx = torch.div(sort_idx, nimgtokens, rounding_mode="floor")

    imgidx_kw_dict = {}
    numrow, numcol = sort_simmat.shape

    # Per keyword: first sorted column that has not been consumed yet.
    row_col_list = [0] * numrow

    while True:
        progressed = False

        for j in range(numrow):
            for col_idx in range(row_col_list[j], numcol):
                img_idx = sort_idx[j, col_idx].item()
                if img_idx in imgidx_kw_dict:
                    continue  # frame already claimed; try this keyword's next best

                imgidx_kw_dict[img_idx] = {
                    "kw": keywords[j],
                    "simvalue": sort_simmat[j, col_idx].item(),
                    "kf_path": nframes_paths[img_idx],
                    "kw_others": [],
                }
                row_col_list[j] = col_idx + 1
                progressed = True

                if len(imgidx_kw_dict) == maximgslen:
                    return imgidx_kw_dict
                break  # one frame per keyword per pass

        # Stop when a full pass added nothing (every remaining column points to
        # an already selected frame) or when the original cursor-sum criterion
        # is met; without the first check the loop could spin forever.
        if not progressed or sum(row_col_list) >= numrow * (numcol - 1):
            break

    return imgidx_kw_dict
| |
|
def create_data_loader(questions, clippy, preprocess_val, clip_size, base_dir, batch_size=1, num_workers=16):
    """Wrap the questions in a ``CustomDataset`` and return an unshuffled DataLoader.

    Only ``batch_size == 1`` is accepted; any other value trips the assertion.
    """
    assert batch_size == 1, "batch_size must be 1"
    return DataLoader(
        CustomDataset(questions, clippy, preprocess_val, clip_size, base_dir),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
| |
|
def _frame_timestamp(path):
    """Parse ``<x>_<a>_<b>[...].jpg`` into the float ``a.b`` using name fields 1 and 2."""
    parts = path.replace('.jpg', '').split('/')[-1].split('_')
    return float(f"{parts[1]}.{parts[2]}")


def eval_model():
    """Select keyword-matched frames for every question and write the results.

    For each question read from ``config.question_path`` (one JSON object per
    line) the function:

    1. encodes all frames and the question's keywords with the CLIP model,
    2. computes the cosine similarity between every keyword and every image
       token and selects up to ``config.maximgslen`` frames via
       ``SortSimilarity``,
    3. orders the selected frames by ascending similarity and re-orders the
       first 16 of them by timestamp,
    4. saves 4x2 contact sheets of the selected frames under
       ``config.concatdir/<q_uid>``,
    5. appends the enriched question record as one JSON line to
       ``config.answerpath``.

    Requires a CUDA device; all paths and limits come from ``config``.
    """
    disable_torch_init()
    question_path = config.question_path
    maximgslen = config.maximgslen
    base_dir = config.base_dir
    modelpath = config.modelpath
    answerpath = config.answerpath
    concatdir = config.concatdir
    limit_keywords = config.limit_keywords

    clippy, _preprocess_train, preprocess_val = create_model_and_transforms(
        "clippy-B-16",
        device="cuda",
        pretrained=f"{modelpath}",
    )
    clip_size = (224, 224)
    device = next(clippy.parameters()).device

    # Close the question file as soon as it is parsed (it was left open before).
    with open(os.path.expanduser(question_path), "r") as question_file:
        questions = [json.loads(q) for q in question_file]

    answer_path = f"{answerpath}"
    print(f"\nquestion_path:{question_path}\nanswer_path:{answer_path}")
    # dirname is "" for a bare file name, and os.makedirs("") raises.
    answer_dir = os.path.dirname(answer_path)
    if answer_dir:
        os.makedirs(answer_dir, exist_ok=True)

    # Layout of one contact sheet: 8 frames as 2 rows x 4 columns, separated
    # and framed by a red border of `redwidth` pixels.
    whole_frame = 8
    frames_per_row = whole_frame // 2
    redwidth = 20
    # Number of leading entries (in ascending-similarity order) that get
    # re-sorted chronologically.
    n_time_sorted = 16

    with open(answer_path, "w") as ans_file:
        data_loader = create_data_loader(questions, clippy, preprocess_val, clip_size, base_dir)
        concatimg_dir_base = f"{concatdir}"

        with torch.no_grad():
            # The loader is unshuffled with batch_size 1, so zipping it with
            # `questions` pairs each batch with its source record.
            for (image_list, nframes_paths, timelines, _, keywords, _), line in tqdm(zip(data_loader, questions), total=len(questions)):
                q_uid = line["q_uid"]

                # Default collation wraps every string in a length-1 container.
                nframes_paths = [e[0] for e in nframes_paths]

                image_loader = DataLoader(loading_img(image_list), batch_size=64, shuffle=False, num_workers=16)
                # [:, 1:] drops the first token of every frame
                # (presumably the class token - TODO confirm).
                img_embed = torch.concat(
                    [clippy.encode_image(batch.to(device), pool=False)[:, 1:] for batch in image_loader],
                    dim=0,
                )

                keywords = [e[0] for e in keywords][:limit_keywords]
                keyword_embed = clippy.text.encode(keywords, convert_to_tensor=True)

                _, nimgtokens, _ = img_embed.shape
                # Broadcast (K, 1, C) against (1, F * T, C) -> similarity (K, F * T).
                keyword_embed = keyword_embed.unsqueeze(1)
                img_embed = img_embed.flatten(0, 1).unsqueeze(0)

                simmat = F.cosine_similarity(keyword_embed, img_embed, dim=-1).to(torch.float)
                imgidx_kw_dict = SortSimilarity(q_uid, simmat, keywords, nimgtokens, nframes_paths, maximgslen=maximgslen)

                # Order the selection by ascending similarity.
                selected = list(imgidx_kw_dict.values())
                simvalue = np.array([e["simvalue"] for e in selected])
                ordered_idx = np.argsort(simvalue)
                simvalue = simvalue[ordered_idx]
                kf_paths = np.array([e["kf_path"] for e in selected])[ordered_idx]
                matchingkw = np.array([e["kw"] for e in selected])[ordered_idx]

                # Re-order the first `n_time_sorted` entries chronologically,
                # keeping paths, similarities and keywords aligned.
                time_kf_paths = np.array(kf_paths[:n_time_sorted])
                timelines_int = np.array([_frame_timestamp(e) for e in time_kf_paths])
                time_ordered_idx = np.argsort(timelines_int)

                timelines_int = timelines_int[time_ordered_idx]
                time_simvalue = np.array(simvalue[:n_time_sorted])[time_ordered_idx]
                time_kf_paths = np.array(time_kf_paths)[time_ordered_idx]
                time_matchingkw = np.array(matchingkw[:n_time_sorted])[time_ordered_idx]

                simvalue[:n_time_sorted] = time_simvalue
                kf_paths[:n_time_sorted] = time_kf_paths
                matchingkw[:n_time_sorted] = time_matchingkw

                # "<first frame time>-<last frame time>" of the whole clip.
                segment_timeline = f"{timelines[0][0].split(' seconds')[0]}-{timelines[-1][0].split(' seconds')[0]}"

                # All frames are assumed to share the size of the first one.
                imgw, imgh = Image.open(kf_paths[0]).size
                newimgw, newimgh = (imgw+redwidth) * 4 + redwidth, (imgh+redwidth) * 2 + redwidth
                concatimg = np.zeros((newimgh, newimgw, 3), dtype=np.uint8)
                concatimg[:, :, 0] = 255  # red background shows through as the border
                concatimglist = []
                concatimg_dir = f"{concatimg_dir_base}/{q_uid}"

                for i, cpath in enumerate(kf_paths):
                    cur_img = np.array(Image.open(cpath))
                    remainder = i % whole_frame
                    rowremainder = i % frames_per_row
                    startwidth = redwidth + (imgw + redwidth)*rowremainder
                    endwidth = startwidth + imgw

                    # First half of each group of 8 fills the top row, second half the bottom row.
                    if remainder < frames_per_row:
                        concatimg[redwidth:redwidth+imgh, startwidth:endwidth, :] = cur_img
                    else:
                        concatimg[redwidth+imgh+redwidth:newimgh-redwidth, startwidth:endwidth, :] = cur_img

                    # A sheet is emitted only once all 8 slots are filled;
                    # a trailing partial group is not saved.
                    if remainder == whole_frame - 1:
                        concatimglist.append(Image.fromarray(concatimg))

                # Start from an empty output directory for this question.
                if os.path.exists(concatimg_dir):
                    shutil.rmtree(concatimg_dir)
                os.makedirs(f"{concatimg_dir}", exist_ok=True)
                for i, img in enumerate(concatimglist):
                    img.save(f"{concatimg_dir}/concat_{i}.jpg")

                line["kf_paths"] = kf_paths.tolist()
                line["keywords"] = matchingkw.tolist()
                line["simvalue"] = simvalue.tolist()
                line["imgidx_kw_dict"] = imgidx_kw_dict
                line["segment_timeline"] = segment_timeline
                line["concatimg_dir"] = concatimg_dir

                ans_file.write(json.dumps(line) + "\n")

    print(f"question_path:{question_path}\nanswer_path:{answer_path}")
| |
|
| |
|
# Run eval_model() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    eval_model()