|
import os |
|
import json |
|
import shutil |
|
|
|
from tqdm import tqdm |
|
from PIL import Image |
|
|
|
import natsort |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
from config import config |
|
from src.open_clip import create_model_and_transforms |
|
|
|
|
|
class loading_img(Dataset):
    """Map-style dataset over an in-memory list of image tensors.

    Every stored tensor is handed back with its leading dimension squeezed
    (``tensor.squeeze(0)``), so items carrying a size-1 leading axis come out
    without it and can be re-batched by a ``DataLoader``.
    """

    def __init__(self, img_list):
        # Held by reference; items are only read.
        self.img_list = img_list

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        item = self.img_list[idx]
        return item.squeeze(0)
|
|
|
class CustomDataset(Dataset):
    """Per-question dataset that loads and preprocesses every frame of a question.

    Each item corresponds to one entry of ``questions`` (a dict parsed from a
    JSON line). Frame paths are re-rooted under ``base_dir``, resized to
    ``clip_size`` and passed through ``preprocess_val``.
    """

    def __init__(self, questions, clippy, preprocess_val, clip_size, base_dir):
        """
        Args:
            questions: list of question dicts; each needs ``"filepath"`` (a list
                of frame paths) and may carry ``"Activity"``.
            clippy: the model; only used here to record its device.
            preprocess_val: callable applied to each resized PIL image.
            clip_size: size tuple passed to ``PIL.Image.Image.resize``.
            base_dir: directory the last four path components of each frame
                path are re-rooted under.
        """
        self.questions = questions
        self.clippy = clippy
        self.clip_size = clip_size
        self.preprocess_val = preprocess_val
        # Recorded for callers; __getitem__ itself does no device work.
        self.device = next(clippy.parameters()).device
        self.base_dir = base_dir

    def __getitem__(self, index):
        """Load every frame of question ``index``.

        Returns:
            ``(image_list, img_paths, timelines, timelines_int, keywords, img_names)``:
              image_list    -- ``preprocess_val`` output per frame, natural-sort order
              img_paths     -- re-rooted frame paths, same order
              timelines     -- ``"<a>.<b> seconds"`` built from the last two
                               ``_``-separated fields of each file stem
              timelines_int -- the same value parsed as a float
              keywords      -- activity keywords for the question (possibly empty)
              img_names     -- file stems (text before the first ``.``)
        """
        line = self.questions[index]

        # Look the key up safely: checking `line["Activity"]` before the
        # membership test raised KeyError for questions without "Activity".
        activity = line.get("Activity", "")
        if activity == "":
            keywords = []
        elif isinstance(activity, list):
            keywords = activity
        else:
            keywords = activity.split(', ')

        image_list, img_paths, timelines, timelines_int, img_names = [], [], [], [], []
        for frame_path in natsort.natsorted(line["filepath"]):
            # Re-root the frame under base_dir, keeping its last four path components.
            img_path = self.base_dir + "/" + "/".join(frame_path.split("/")[-4:])
            img_paths.append(img_path)

            img_name = img_path.split('/')[-1].split('.')[0]
            img_names.append(img_name)

            # Close the file handle immediately rather than leaving it to the
            # garbage collector (many frames per item, many loader workers).
            with Image.open(img_path) as img:
                image_list.append(self.preprocess_val(img.resize(self.clip_size)))

            # Assumes stems end in "..._<a>_<b>" encoding the timestamp -- TODO confirm.
            parts = img_name.split('_')
            stamp = f"{parts[-2]}.{parts[-1]}"
            timelines.append(f"{stamp} seconds")
            timelines_int.append(float(stamp))

        return image_list, img_paths, timelines, timelines_int, keywords, img_names

    def __len__(self):
        return len(self.questions)
|
|
|
|
|
def disable_torch_init():
    """Turn ``reset_parameters`` on ``nn.Linear`` and ``nn.LayerNorm`` into a no-op.

    The patch is applied to the classes themselves, so it affects every layer
    of those types constructed afterwards in this process. Presumably this is
    done to skip random initialisation of weights that are loaded from a
    checkpoint right after -- confirm.
    """
    def _skip_init(self):
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = _skip_init
|
|
|
def SortSimilarity(q_uid, simmat, keywords, nimgtokens, nframes_paths, maximgslen):
    """Greedily select up to ``maximgslen`` distinct frames, round-robin over keywords.

    ``simmat`` has one row per keyword and one column per image token. Column
    ``c`` is attributed to frame ``c // nimgtokens``, i.e. tokens are assumed to
    be laid out frame-major (confirm against the caller's flatten order). Each
    row is ranked by descending similarity. Keywords then take turns, and on
    its turn a keyword claims its best-ranked frame not yet claimed by anyone.

    Args:
        q_uid: question id; unused, kept for call compatibility.
        simmat: 2-D tensor of shape ``(len(keywords), n_columns)``.
        keywords: keyword strings, indexed by ``simmat`` row.
        nimgtokens: number of columns belonging to each frame.
        nframes_paths: frame paths, indexed by frame number.
        maximgslen: maximum number of frames to select.

    Returns:
        dict mapping frame index -> ``{"kw", "simvalue", "kf_path", "kw_others"}``
        in selection order. If the frames run out before ``maximgslen`` is
        reached, the frames found so far are returned. Previously that case
        returned ``None`` or never terminated.
    """
    imgidx_kw_dict = {}
    if maximgslen <= 0:
        return imgidx_kw_dict

    sort_simmat, sort_idx = torch.sort(simmat, dim=-1, descending=True)
    # Exact integer floor division maps token columns to frame indices.
    sort_idx = torch.div(sort_idx, nimgtokens, rounding_mode="floor")

    # Move to Python once instead of a device sync per `.item()` call.
    sims = sort_simmat.tolist()
    frames = sort_idx.tolist()

    numrow, numcol = sort_simmat.shape
    row_col_list = [0] * numrow  # next column to inspect for each keyword row

    while True:
        progressed = False
        for j in range(numrow):
            for col_idx in range(row_col_list[j], numcol):
                img_idx = frames[j][col_idx]
                if img_idx in imgidx_kw_dict:
                    continue
                imgidx_kw_dict[img_idx] = {
                    "kw": keywords[j],
                    "simvalue": sims[j][col_idx],
                    "kf_path": nframes_paths[img_idx],
                    "kw_others": [],
                }
                row_col_list[j] = col_idx + 1
                progressed = True
                if len(imgidx_kw_dict) == maximgslen:
                    return imgidx_kw_dict
                break  # one claim per keyword per round
        # A full round with no new claim means every row is exhausted.
        if not progressed:
            return imgidx_kw_dict
|
|
|
def create_data_loader(questions, clippy, preprocess_val, clip_size, base_dir, batch_size=1, num_workers=16):
    """Wrap ``questions`` in a ``CustomDataset`` and return an ordered ``DataLoader``.

    ``batch_size`` must be 1. Presumably this is because each item holds a
    per-question number of frames and callers unpack the collated batch as a
    single question.
    """
    assert batch_size == 1, "batch_size must be 1"
    return DataLoader(
        CustomDataset(questions, clippy, preprocess_val, clip_size, base_dir),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,  # keep item order identical to `questions`
    )
|
|
|
def _encode_frames(clippy, image_list, device, batch_size=64):
    """Encode one question's frames; return image tokens with the first token of each frame dropped.

    The dropped token (index 0, removed via ``[:, 1:]``) is presumably the class
    token -- confirm against ``encode_image(pool=False)``.
    """
    # num_workers=0: the items are tensors already in memory, so worker
    # processes would only add per-question startup cost.
    image_loader = DataLoader(loading_img(image_list), batch_size=batch_size, shuffle=False, num_workers=0)
    chunks = [clippy.encode_image(batch.to(device), pool=False)[:, 1:] for batch in image_loader]
    return torch.concat(chunks, dim=0)


def _parse_kf_time(path):
    """Return the timestamp encoded in fields 1 and 2 of the ``_``-split file stem as a float."""
    stem_parts = path.replace('.jpg', '').split('/')[-1].split('_')
    return float(f"{stem_parts[1]}" + "." + f"{stem_parts[2]}")


def _order_keyframes(imgidx_kw_dict, n_time_sorted=16):
    """Order selected key frames for output.

    All frames are sorted by ascending similarity. The first ``n_time_sorted``
    of that order are then re-sorted chronologically in place.

    Returns:
        ``(kf_paths, matchingkw, simvalue)`` as numpy arrays.
    """
    entries = list(imgidx_kw_dict.values())
    simvalue = np.array([e["simvalue"] for e in entries])
    # NOTE(review): ascending order puts the least similar frames first, so the
    # chronological head below holds the weakest matches -- confirm intended.
    ordered_idx = np.argsort(simvalue)
    simvalue = simvalue[ordered_idx]
    kf_paths = np.array([e["kf_path"] for e in entries])[ordered_idx]
    matchingkw = np.array([e["kw"] for e in entries])[ordered_idx]

    head = kf_paths[:n_time_sorted]
    time_order = np.argsort(np.array([_parse_kf_time(p) for p in head]))
    # Right-hand sides are fancy-indexed copies, so in-place slice assignment is safe.
    simvalue[:n_time_sorted] = simvalue[:n_time_sorted][time_order]
    kf_paths[:n_time_sorted] = head[time_order]
    matchingkw[:n_time_sorted] = matchingkw[:n_time_sorted][time_order]
    return kf_paths, matchingkw, simvalue


def _save_concat_grids(kf_paths, concatimg_dir):
    """Tile key frames into 4x2 grids with 20-px red borders and save them as JPEGs.

    Cells are filled in ``kf_paths`` order: frames 0-3 of each group of 8 go in
    the top row and 4-7 in the bottom row. Every frame must match the size of
    ``kf_paths[0]``.

    NOTE: only complete groups of 8 are saved; a trailing partial group is
    dropped. ``concatimg_dir`` is wiped and recreated, and files are written as
    ``concat_<i>.jpg``.
    """
    whole_frame = 8          # cells per grid
    cols = whole_frame // 2  # 4 cells per row, 2 rows
    redwidth = 20            # border width in pixels

    concatimglist = []
    if len(kf_paths) > 0:
        with Image.open(kf_paths[0]) as first:
            imgw, imgh = first.size
        newimgw = (imgw + redwidth) * cols + redwidth
        newimgh = (imgh + redwidth) * 2 + redwidth
        concatimg = np.zeros((newimgh, newimgw, 3), dtype=np.uint8)
        concatimg[:, :, 0] = 255  # red background shows through as the borders

        for i, cpath in enumerate(kf_paths):
            with Image.open(cpath) as img:
                cur_img = np.array(img)
            remainder = i % whole_frame
            startwidth = redwidth + (imgw + redwidth) * (i % cols)
            endwidth = startwidth + imgw
            if remainder < cols:
                concatimg[redwidth:redwidth + imgh, startwidth:endwidth, :] = cur_img
            else:
                concatimg[redwidth + imgh + redwidth:newimgh - redwidth, startwidth:endwidth, :] = cur_img
            if remainder == whole_frame - 1:
                concatimglist.append(Image.fromarray(concatimg))

    if os.path.exists(concatimg_dir):
        shutil.rmtree(concatimg_dir)
    os.makedirs(concatimg_dir, exist_ok=True)
    for i, img in enumerate(concatimglist):
        img.save(f"{concatimg_dir}/concat_{i}.jpg")


def eval_model():
    """Select keyword-matched key frames for every question and write them out.

    Reads settings from ``config``: question_path, maximgslen, base_dir,
    modelpath, answerpath, concatdir, limit_keywords.

    For each JSON-lines question it:
      * encodes all frames with the ``clippy-B-16`` model;
      * scores every image token against each activity keyword by cosine
        similarity;
      * picks up to ``maximgslen`` frames with ``SortSimilarity``;
      * saves 4x2 concat grids under ``<concatdir>/<q_uid>``;
      * appends the question to ``answerpath`` (JSON lines), augmented with
        ``kf_paths``, ``keywords``, ``simvalue``, ``imgidx_kw_dict``,
        ``segment_timeline`` and ``concatimg_dir``.
    """
    disable_torch_init()
    question_path = config.question_path
    maximgslen = config.maximgslen
    base_dir = config.base_dir
    answer_path = f"{config.answerpath}"
    concatimg_dir_base = f"{config.concatdir}"

    clippy, _preprocess_train, preprocess_val = create_model_and_transforms(
        "clippy-B-16",
        device="cuda",
        pretrained=f"{config.modelpath}",
    )
    clip_size = (224, 224)
    device = next(clippy.parameters()).device

    # `with` closes the questions file instead of leaking the handle.
    with open(os.path.expanduser(question_path), "r") as question_file:
        questions = [json.loads(q) for q in question_file]

    print(f"\nquestion_path:{question_path}\nanswer_path:{answer_path}")
    # os.makedirs("") raises, and dirname is "" for a bare file name.
    answer_dir = os.path.dirname(answer_path)
    if answer_dir:
        os.makedirs(answer_dir, exist_ok=True)

    limit_keywords = config.limit_keywords

    with open(answer_path, "w") as ans_file:
        # The loader does not shuffle, so batches stay aligned with `questions`.
        data_loader = create_data_loader(questions, clippy, preprocess_val, clip_size, base_dir)

        with torch.no_grad():
            for (image_list, nframes_paths, timelines, _, keywords, _), line in tqdm(
                zip(data_loader, questions), total=len(questions)
            ):
                q_uid = line["q_uid"]

                # With batch_size=1, default collation wraps each string in a 1-tuple.
                nframes_paths = [e[0] for e in nframes_paths]

                img_embed = _encode_frames(clippy, image_list, device)

                keywords = [e[0] for e in keywords][:limit_keywords]
                keyword_embed = clippy.text.encode(keywords, convert_to_tensor=True)

                _, nimgtokens, _ = img_embed.shape
                # Broadcast keywords (assumed (K, C) -- confirm) against all tokens
                # flattened frame-major to (1, nframe*nimgtokens, C) -> (K, nframe*nimgtokens).
                simmat = F.cosine_similarity(
                    keyword_embed.unsqueeze(1), img_embed.flatten(0, 1).unsqueeze(0), dim=-1
                ).to(torch.float)

                imgidx_kw_dict = SortSimilarity(
                    q_uid, simmat, keywords, nimgtokens, nframes_paths, maximgslen=maximgslen
                )
                kf_paths, matchingkw, simvalue = _order_keyframes(imgidx_kw_dict)

                segment_timeline = (
                    f"{timelines[0][0].split(' seconds')[0]}-"
                    f"{timelines[-1][0].split(' seconds')[0]}"
                )

                concatimg_dir = f"{concatimg_dir_base}/{q_uid}"
                _save_concat_grids(kf_paths, concatimg_dir)

                line["kf_paths"] = kf_paths.tolist()
                line["keywords"] = matchingkw.tolist()
                line["simvalue"] = simvalue.tolist()
                line["imgidx_kw_dict"] = imgidx_kw_dict
                line["segment_timeline"] = segment_timeline
                line["concatimg_dir"] = concatimg_dir

                ans_file.write(json.dumps(line) + "\n")

    print(f"question_path:{question_path}\nanswer_path:{answer_path}")
|
|
|
|
|
if __name__ == "__main__":
    # Run key-frame extraction only when executed as a script, not on import.
    eval_model()