From a9aa1bbbd307189e05834d6ee07661a2909ec714 Mon Sep 17 00:00:00 2001 From: MelikaAyoughi Date: Wed, 20 Jul 2022 16:57:53 +0200 Subject: [PATCH] Add files via upload --- config.py | 81 ++ configs/config.yaml | 3 + defaults.py | 54 ++ prepare_dataset.py | 30 + requirements.txt | 27 + scene_recognition.py | 272 +++++++ train.py | 256 ++++++ tvqa_dataset.py | 1634 ++++++++++++++++++++++++++++++++++++++ utils.py | 24 + visualize.py | 1770 ++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 4151 insertions(+) create mode 100644 config.py create mode 100644 configs/config.yaml create mode 100644 defaults.py create mode 100644 prepare_dataset.py create mode 100644 requirements.txt create mode 100644 scene_recognition.py create mode 100644 train.py create mode 100644 tvqa_dataset.py create mode 100644 utils.py create mode 100644 visualize.py diff --git a/config.py b/config.py new file mode 100644 index 0000000..f493b43 --- /dev/null +++ b/config.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from fvcore.common.config import CfgNode as _CfgNode +import argparse + + +class CfgNode(_CfgNode): + """ + The same as `fvcore.common.config.CfgNode`, but different in: + + 1. Use unsafe yaml loading by default. + Note that this may lead to arbitrary code execution: you must not + load a config file from untrusted sources before manually inspecting + the content of the file. + 2. Support config versioning. + When attempting to merge an old config, it will convert the old config automatically. + """ + + # Note that the default value of allow_unsafe is changed to True + def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None: + loaded_cfg = _CfgNode.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe) + loaded_cfg = type(self)(loaded_cfg) + + self.merge_from_other_cfg(loaded_cfg) + + def dump(self, *args, **kwargs): + """ + Returns: + str: a yaml string representation of the config + """ + # to make it show up in docs + return super().dump(*args, **kwargs) + + +global_cfg = CfgNode() + + +def get_cfg() -> CfgNode: + """ + Get a copy of the default config. + + Returns: + a detectron2 CfgNode instance. + """ + from defaults import _C + + return _C.clone() + + +def set_global_cfg(cfg: CfgNode) -> None: + """ + Let the global config point to the given cfg. + + Assume that the given "cfg" has the key "KEY", after calling + `set_global_cfg(cfg)`, the key can be accessed by: + + .. code-block:: python + + from detectron2.config import global_cfg + print(global_cfg.KEY) + + By using a hacky global config, you can access these configs anywhere, + without having to pass the config object or the values deep into the code. + This is a hacky feature introduced for quick prototyping / research exploration. 
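+ In this repository the typical flow (roughly what ``train.py`` does) is:
+ .. code-block:: python
+ cfg = get_cfg()                                # defaults from defaults.py
+ cfg.merge_from_file("configs/config.yaml")     # yaml file overrides the defaults
+ cfg.merge_from_list(["TRAINING.lr", "0.01"])   # command-line opts override both (value here is illustrative)
+ set_global_cfg(cfg)                            # global_cfg.TRAINING.lr is now 0.01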
+ """ + global global_cfg + global_cfg.clear() + global_cfg.update(cfg) + + +def default_argument_parser(): + parser = argparse.ArgumentParser(description="tvqa config file") + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + return parser \ No newline at end of file diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..8fa84fd --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,3 @@ +TRAINING: + project_dir: "./output/default/" + diff --git a/defaults.py b/defaults.py new file mode 100644 index 0000000..d5e1e25 --- /dev/null +++ b/defaults.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from config import CfgNode as CN + +_C = CN() + +_C.TRAINING = CN() + +_C.TRAINING.project_dir = "./output/default/" +_C.TRAINING.data_path = "/home/mayoughi/tvqa_experiment/dataset/friends_frames/" +_C.TRAINING.epochs = 100 +_C.TRAINING.lr_decay_epoch = 50 +_C.TRAINING.lr_decay_epochs = [50, 75, 90] +_C.TRAINING.pretrained = True +_C.TRAINING.lr = 0.12 +_C.TRAINING.batch_size = 256 +_C.TRAINING.lr_decay_rate = 0.1 +_C.TRAINING.momentum = 0.9 +_C.TRAINING.weight_decay = 0.0001 +_C.TRAINING.last_commit = "unknown" +_C.TRAINING.supervised = False +_C.TRAINING.data_mode = "correct_target_id" # cleansed, correct_target_id, weak_label +_C.TRAINING.series = "friends" # friends, bbt +_C.TRAINING.clustering = "KMeans" # AgglomerativeClustering, KMeans, MiniBatchKMeans +_C.TRAINING.kmeans_batch_size = 100 # default +_C.TRAINING.exp_type = "normal" # normal, oracle +_C.TRAINING.ours_or_baseline = "ours" # ours, baseline + +# self-supervised parameters +_C.SSL = CN() +_C.SSL.align_alpha = 2 +_C.SSL.unif_t = 2 +_C.SSL.align_w = 1 +_C.SSL.unif_w = 1 +_C.SSL.random_crop = 100 +_C.SSL.include_unknowns = True +_C.SSL.joint = False +_C.SSL.face_layer = False +_C.SSL.sub_layer = False +_C.SSL.mix_layer = True +_C.SSL.face_layer_out_features = 512 +_C.SSL.sub_layer_out_features = 768 +_C.SSL.mix_layer_out_features = 1280 +_C.SSL.mix_layer_in_features = 1280 + +_C.SSL.supervised = False +_C.SSL.epsilon = 0.1 #with probability epsilon pick from closest cluster + + +_C.MODEL = CN() +_C.MODEL.out_features_1 = 512 +_C.MODEL.out_features_2 = 512 +_C.MODEL.out_features_3 = 512 + +_C.GLOBAL = CN() diff --git a/prepare_dataset.py b/prepare_dataset.py new file mode 100644 index 0000000..5ab9015 --- /dev/null +++ b/prepare_dataset.py @@ -0,0 +1,30 @@ +import torch +import os +from torchvision import datasets +from PIL import Image +from facenet_pytorch import MTCNN, InceptionResnetV1 +from PIL import Image, ImageDraw +from facenet_pytorch import MTCNN, extract_face +import random + +mtcnn = MTCNN(keep_all=True) +clip_name = sorted(os.listdir(os.getcwd())) + + +for i in clip_name: + clip_dir = os.path.join(os.getcwd(),i) + for image in sorted(os.listdir(clip_dir)): + img = Image.open(os.path.join(clip_dir, image)) + img_id = random.randint(1,10000) + boxes, probs, points = mtcnn.detect(img, landmarks=True) + img_draw = img.copy() + draw = ImageDraw.Draw(img_draw) + if boxes is None: + continue + for f, (box, point) in enumerate(zip(boxes, points)): + draw.rectangle(box.tolist(), width=5) + for p in point: + draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10) + face = extract_face(img, box, save_path='/home/mayoughi/outputs/detected_face_{}_{}.png'.format(img_id, f)) 
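+ # boxes returned by mtcnn.detect are [x1, y1, x2, y2] pixel coordinates;
+ # extract_face crops that region from the original frame and, since save_path
+ # is given, also writes the crop to disk (160x160 output by default in facenet_pytorch).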
+ img_draw.save('/home/mayoughi/outputs/annotated_faces_{}.png'.format(img_id)) + break diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..eabb5b0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,27 @@ +apex==0.9.10dev +classy_vision==0.6.0 +facenet_pytorch==2.5.2 +fuzzywuzzy==0.18.0 +fvcore==0.1.5.post20220512 +matplotlib==3.5.2 +numpy==1.22.4 +omegaconf==2.2.2 +openpyxl==3.0.10 +openpyxl_image_loader==1.0.5 +pandas==1.4.2 +Pillow==9.1.1 +pysrt==1.1.2 +PyYAML==6.0 +scikit_learn==1.1.1 +scipy==1.8.1 +seaborn==0.11.2 +sentence_transformers==2.2.0 +simcse==0.4 +tensorboard==2.9.1 +torch==1.11.0 +torchvision==0.12.0 +tqdm==4.64.0 +transformers==4.19.4 +umap==0.1.1 +vissl==0.1.6 +xlsxwriter==3.0.3 diff --git a/scene_recognition.py b/scene_recognition.py new file mode 100644 index 0000000..855adf1 --- /dev/null +++ b/scene_recognition.py @@ -0,0 +1,272 @@ +import vissl +import tensorboard +import apex +import torch +import json +from omegaconf import OmegaConf +from vissl.utils.hydra_config import AttrDict +from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict +from vissl.models import build_model +from classy_vision.generic.util import load_checkpoint +from vissl.utils.checkpoint import init_model_from_consolidated_weights +from PIL import Image +import torchvision.transforms as transforms +import glob, os +from tqdm import tqdm +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +# Config is located at vissl/configs/config/pretrain/simclr/simclr_8node_resnet.yaml. +# All other options override the simclr_8node_resnet.yaml config. + +cfg = [ + 'config=benchmark/linear_image_classification/places205/models/regnet32Gf.yaml', + 'config.MODEL.WEIGHTS_INIT.PARAMS_FILE=./content/regnet_seer.torch', # Specify path for the model weights. + 'config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True', # Turn on model evaluation mode. + 'config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=True', # Freeze trunk. + 'config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=True', # Extract the trunk features, as opposed to the HEAD. + 'config.MODEL.FEATURE_EVAL_SETTINGS.SHOULD_FLATTEN_FEATS=True', # Do not flatten features. + 'config.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP=[["res5", ["Identity", []]]]' # Extract only the res5avg features. +] + +# Compose the hydra configuration. +cfg = compose_hydra_configuration(cfg) +# Convert to AttrDict. This method will also infer certain config options +# and validate the config is valid. +_, cfg = convert_to_attrdict(cfg) + + + +model = build_model(cfg.MODEL, cfg.OPTIMIZER) +# Load the checkpoint weights. +weights = load_checkpoint(checkpoint_path=cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE) + + +# Initializei the model with the simclr model weights. 
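+# (The checkpoint loaded here is the SEER RegNet-32Gf file pointed to by
+# MODEL.WEIGHTS_INIT.PARAMS_FILE above; only the frozen trunk is used, since the
+# FEATURE_EVAL_SETTINGS extract res5 trunk features rather than a classification head.)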
+model = init_model_from_consolidated_weights( + config=cfg, + model=model, + state_dict=weights, + state_dict_key_name="classy_state_dict", + skip_layers=[], # Use this if you do not want to load all layers +) + +print("Weights have loaded") + + +def extract_features(sampling_rate=None, batch_size=10): + # episode ='01' + # imgs_dir = glob.glob(f"./dataset/frames_hq/friends_frames/friends_s01e{episode}_seg0*") + imgs_dir = glob.glob(f"./dataset/frames_hq/bbt_frames/*") + i = 0 + batch = 0 + first_batch = True + for img_dir in tqdm(imgs_dir): + imgs_path = glob.glob(os.path.join(img_dir, "*.jpg")) + for img in imgs_path: + if i % sampling_rate != 0: + i += 1 + continue + i += 1 + image = Image.open(img) + # Convert images to RGB. This is important + # as the model was trained on RGB images. + image = image.convert("RGB") + # Image transformation pipeline. + pipeline = transforms.Compose([ + transforms.Resize(size=(224, 224)), + # transforms.CenterCrop(224), + transforms.ToTensor(), + ]) + x = pipeline(image) + x = x.unsqueeze(0) + if batch == 0: + img_batch = x + else: + img_batch = torch.cat((img_batch, x), 0) + batch += 1 + if batch == batch_size: + print(f"extracting features for i {i} . . .") + features = model(img_batch) + # features = features[0] + if first_batch: + all_embeddings = features[0] + print(f"size of features: {features[0].shape}") + first_batch = False + else: + all_embeddings = torch.cat((all_embeddings, features[0]), 0) + batch = 0 + + if batch != batch_size: #the last batch + features = model(img_batch) + # features = features[0] + all_embeddings = torch.cat((all_embeddings, features[0]), 0) + + # print("extracting features . . .") + # features = model(img_batch) + # features = features[0] + # print(f"Features extracted have the shape: { features.shape }") + torch.save(all_embeddings, f"./dataset/bbt_scene_embeddings_rate_{sampling_rate}.pt") + return all_embeddings + + +def save_json(data, file_path): + with open(file_path, "w+") as f: + json.dump(data, f) + + +def load_json(file_path): + with open(file_path, "r") as f: + return json.load(f) + + +def get_embeddings(): + dataset_dict = load_json(f"./dataset/mentions/friends_dict.json") + + first_img = True + for i, data in enumerate(dataset_dict.values()): + if data['first_frame']: + image = Image.open(os.path.join(data['clip_dir'], data['img'] + ".jpg")) + image = image.convert("RGB") + pipeline = transforms.Compose([ + transforms.Resize(size=(224, 224)), + transforms.ToTensor(), + ]) + x = pipeline(image) + x = x.unsqueeze(0) + features = model(x) + dataset_dict[str(i)]['embedding'] = features[0] + dataset_dict[str(i+1)]['embedding'] = features[0] + dataset_dict[str(i+2)]['embedding'] = features[0] + dataset_dict[str(i+3)]['embedding'] = features[0] + + if first_img: + all_embeddings = features[0] + first_img = False + else: + all_embeddings = torch.cat((all_embeddings, features[0]), 0) + + # dataset_dict[i] = {"series": self.series, "clip_dir": clip_dir, "clip": clip, + # "img": img_str, + # "subtitle": srt_data['sub_text'][clip][t], + # "mentions": list(mentions), + # "filtered_mentions": list(intersect), + # "first_frame": True if frame_num == frame_num_begin else False} + + torch.save(all_embeddings, f"./dataset/friends_scenes_with_mention_embs.pt") + return all_embeddings + + +def visualize_tsne(tsne_grid, image_loader=None): + + # num_classes = len(id_to_lbl) + # convert to pandas + # label_ids = pd.DataFrame(label_ids, columns=['label'])['label'] + # create a scatter plot. 
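+ # visualize_tsne places one image thumbnail per point: each first frame is resized
+ # so its longest side is at most max_dim pixels and drawn at its 2-D t-SNE
+ # coordinate with OffsetImage/AnnotationBbox.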
+ fig = plt.figure(figsize=(8, 8)) + ax = plt.subplot(aspect='equal') + # if not label_ids.isnull().values.any(): + # plt.scatter(tsne_grid[:, 0], tsne_grid[:, 1], lw=0, s=40, c=np.asarray(label_ids), + # cmap=discrete_cmap(num_classes, "tab10")) + # # , c = palette[np.asarray([lbl_to_id[lbl] for lbl in colors])] + # # c = np.random.randint(num_classes, size=len(tsne_grid[:, 1])) + # else: + # plt.scatter(tsne_grid[:, 0], tsne_grid[:, 1], lw=0, s=40) + plt.xlim(-114, 178) + plt.ylim(-117, 268) + # cbar = plt.colorbar(ticks=range(num_classes)) + # cbar.set_ticklabels(list(id_to_lbl.values())) + # plt.clim(-0.5, num_classes - 0.5) + # ax.axis('off') + # ax.axis('tight') + dataset_dict = load_json(f"./dataset/mentions/friends_dict.json") + + max_dim = 16 + max_x, max_y, min_x, min_y = 0, 0, 100, 100 + for i, data in enumerate(dataset_dict.values()): + if data['first_frame']: + x,y = tsne_grid[i//4] + if x > max_x: + max_x = x + if x < min_x: + min_x = x + if y > max_y: + max_y = y + if y < min_y: + min_y = y + tile = Image.open(os.path.join(data['clip_dir'], data['img'] + ".jpg")) + # data['embedding'] + # tile = image_loader.get(os.path.join(data['clip_dir'], data['img'] + ".jpg")) + # tile = Image.open(img) + rs = max(1, tile.width/max_dim, tile.height/max_dim) + tile = tile.resize((int(tile.width/rs), int(tile.height/rs)), Image.ANTIALIAS) + imagebox = OffsetImage(tile) #, zoom=0.2) + ab = AnnotationBbox(imagebox, (x, y), pad=0.1) + ax.add_artist(ab) + print(f"max x: {max_x}, min x {min_x}, max y {max_y}, min y {min_y}") + return fig, plt + + # max_dim = 16 + # imgs = glob.glob('./dataset/frames_hq/friends_frames/friends_s01e01_seg02_clip_17/*.jpg') + # for i, ((x, y), img) in enumerate(zip(tsne_grid, imgs)): + # # print(i, x, y, img) + # # tile = image_loader.get(img) + # tile = Image.open(img) + # rs = max(1, tile.width/max_dim, tile.height/max_dim) + # tile = tile.resize((int(tile.width/rs), int(tile.height/rs)), Image.ANTIALIAS) + # imagebox = OffsetImage(tile) #, zoom=0.2) + # ab = AnnotationBbox(imagebox, (x, y), pad=0.1) + # ax.add_artist(ab) + # + # return fig, plt + + # max_dim = 16 + # imgs_dir = glob.glob(f"./dataset/frames_hq/bbt_frames/*") + # i = 0 + # tsne_counter = 0 + # for img_dir in tqdm(imgs_dir): + # imgs_path = glob.glob(os.path.join(img_dir, "*.jpg")) + # for img in imgs_path: + # if i % sampling_rate != 0: + # i += 1 + # continue + # i += 1 + # tile = Image.open(img) + # x, y = tsne_grid[tsne_counter] + # # print(x, y) + # rs = max(1, tile.width / max_dim, tile.height / max_dim) + # tile = tile.resize((int(tile.width / rs), int(tile.height / rs)), Image.ANTIALIAS) + # imagebox = OffsetImage(tile) # , zoom=0.2) + # ab = AnnotationBbox(imagebox, (x, y), pad=0.1) + # ax.add_artist(ab) + # tsne_counter += 1 + # + # return fig, plt + + +# get_embeddings() + +all_embeddings = torch.load(f"./dataset/friends_scenes_with_mention_embs.pt") +tsne_grid = TSNE(random_state=10, n_iter=4000).fit_transform(all_embeddings.detach().numpy()) +fig, plt = visualize_tsne(tsne_grid) +fig.savefig(os.path.join(f"./output/scene_recognition/", f"friends_scene.pdf")) +plt.clf() + +# sampling_rate = 1000 +# # scene_embeddings = extract_features(sampling_rate=sampling_rate, batch_size=10) +# # scene_embeddings = torch.load("./dataset/scene_embeddings_all.pt") +# scene_embeddings = torch.load(f"./dataset/bbt_scene_embeddings_rate_{sampling_rate}.pt") +# tsne_grid = TSNE(random_state=10, n_iter=4000).fit_transform(scene_embeddings.detach().numpy()) +# fig, plt = visualize_tsne(tsne_grid) +# 
fig.savefig(os.path.join(f"./output/scene_recognition/", f"bbt_scene_rate_{sampling_rate}.pdf")) +# plt.clf() + + +# sampling_rate = 100 +# scene_embeddings = extract_features(sampling_rate=sampling_rate, batch_size=100) +# # scene_embeddings = torch.load("./dataset/scene_embeddings_all.pt") +# # scene_embeddings = torch.load(f"./dataset/bbt_scene_embeddings_rate_{sampling_rate}.pt") +# tsne_grid = TSNE(random_state=10, n_iter=2000).fit_transform(scene_embeddings.detach().numpy()) +# fig, plt = visualize_tsne(tsne_grid) #size of features: +# fig.savefig(os.path.join(f"./output/scene_recognition/", f"bbt_scene_rate_{sampling_rate}.pdf")) +# plt.clf() \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..e30ef48 --- /dev/null +++ b/train.py @@ -0,0 +1,256 @@ +from torch.utils.data import DataLoader +import torchvision.models as models +from tvqa_dataset import TVQADataset, get_train_transforms, get_test_transforms, TwoAugUnsupervisedTVQADataset, TwoWeakOrAugTVQADataset, OnlyUnknownsTVQADataset +from torch.nn import CrossEntropyLoss +from torch.utils.tensorboard import SummaryWriter +import torch +from tqdm import tqdm +import os +from pathlib import Path +from datetime import datetime +from facenet_pytorch import InceptionResnetV1 +from config import default_argument_parser, get_cfg, set_global_cfg, global_cfg +from fvcore.common.file_io import PathManager +from sentence_transformers import SentenceTransformer +import torch.nn.functional as F + + +def align_loss(x, y, alpha=2): + return (x - y).norm(p=2, dim=1).pow(alpha).mean() + + +def uniform_loss(x, t=2): + return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log() + + +class L2Norm(torch.nn.Module): + def forward(self, x): + return x / x.norm(p=2, dim=1, keepdim=True) + + +class VGGSupervised(torch.nn.Module): + + def __init__(self, cfg, num_classes): + super().__init__() + self.resnet = InceptionResnetV1( + classify=True, + pretrained='vggface2', + num_classes=num_classes) + + for param in self.resnet.parameters(): + param.requires_grad = False # fix the encoder part + + self.resnet.logits = torch.nn.Linear(in_features=512, out_features=num_classes, bias=True) + + def forward(self, x): + x = self.resnet(x) + return x + + +class VGGFaceSupervised(torch.nn.Module): + + def __init__(self, cfg, num_classes): + super().__init__() + self.resnet = InceptionResnetV1( + classify=True, + pretrained='vggface2', + num_classes=num_classes) + + self.resnet.last_bn = torch.nn.Identity() + for param in self.resnet.parameters(): + param.requires_grad = False # fix the encoder part + + self.resnet.logits = torch.nn.Sequential(torch.nn.Linear(in_features=512, + out_features=cfg.MODEL.out_features_1, bias=True), + torch.nn.ReLU(inplace=False), + torch.nn.Linear(in_features=cfg.MODEL.out_features_1, + out_features=cfg.MODEL.out_features_2, bias=True), + torch.nn.BatchNorm1d(cfg.MODEL.out_features_2, + eps=0.001, momentum=0.1, affine=True, + track_running_stats=True), + torch.nn.Linear(in_features=cfg.MODEL.out_features_2, + out_features=cfg.MODEL.out_features_3, bias=True), + ) + + def forward(self, x): + x = self.resnet(x) + return x + + +class VGGFacePlus(torch.nn.Module): + + def __init__(self, cfg, num_classes): + super().__init__() + self.resnet = InceptionResnetV1( + classify=True, + pretrained='vggface2', + num_classes=num_classes) + + self.resnet.last_bn = torch.nn.Identity() + for param in self.resnet.parameters(): + param.requires_grad = False # fix the encoder part + + self.resnet.logits = 
torch.nn.Sequential(torch.nn.Linear(in_features=512, + out_features=cfg.MODEL.out_features_1, bias=True), + torch.nn.ReLU(inplace=False), + torch.nn.Linear(in_features=cfg.MODEL.out_features_1, + out_features=cfg.MODEL.out_features_2, bias=True), + # torch.nn.ReLU(inplace=False), + torch.nn.BatchNorm1d(cfg.MODEL.out_features_2, + eps=0.001, momentum=0.1, affine=True, + track_running_stats=True), + torch.nn.Linear(in_features=cfg.MODEL.out_features_2, + out_features=cfg.MODEL.out_features_3, bias=True), + # torch.nn.ReLU(inplace=False), + ) + self.l2norm = L2Norm() + + def forward(self, x): + x = self.resnet(x) + return self.l2norm(x) + + +class VGGFaceSubtitle(torch.nn.Module): + + def __init__(self, cfg, num_classes): + super().__init__() + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.text_model = SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2') + + self.resnet = InceptionResnetV1( + classify=True, + pretrained='vggface2', + num_classes=num_classes) + + self.resnet.last_bn = torch.nn.Identity() + self.resnet.logits = torch.nn.Identity() + for param in self.resnet.parameters(): + param.requires_grad = False # fix the encoder part + + for param in self.text_model.parameters(): + param.requires_grad = False # fix the encoder part + + self.face_layer = torch.nn.Linear(in_features=512, out_features=cfg.SSL.face_layer_out_features, bias=True) + self.sub_layer = torch.nn.Linear(in_features=768, out_features=cfg.SSL.sub_layer_out_features, bias=True) + self.mix_layer = torch.nn.Linear(in_features=cfg.SSL.mix_layer_in_features, + out_features=cfg.SSL.mix_layer_out_features, bias=True) + + self.l2norm = L2Norm() + + def forward(self, face1, face2, sub1, sub2): + if face2 is None and sub2 is None: + x1 = self.resnet(face1) + sub1 = torch.from_numpy(self.text_model.encode(sub1)).to(self.device) + if len(sub1.shape) == 1: + sub1 = sub1[None, :] + mix = torch.cat([x1, sub1], dim=1) + if global_cfg.SSL.mix_layer: + mix = self.mix_layer(mix) + return self.l2norm(mix) + else: + x = torch.cat([face1, face2]) + x = self.resnet(x) + + sub1 = torch.from_numpy(self.text_model.encode(sub1)).to(self.device) # sub1:torch.Size([256, 768]) + sub2 = torch.from_numpy(self.text_model.encode(sub2)).to(self.device) + + #1 layer for image and one for text + if global_cfg.SSL.face_layer: + x = self.face_layer(x) + if global_cfg.SSL.sub_layer: + sub1 = self.sub_layer(sub1) + sub2 = self.sub_layer(sub2) + + x1, x2 = x.chunk(2) # x1: torch.Size([256, 512]) + # concat(x1, sub1) concat(x2, sub2) + y1 = torch.cat([x1, sub1], dim=1) + y2 = torch.cat([x2, sub2], dim=1) + # opposite of chunk + mix = torch.cat([y1, y2]) # mix: torch.Size([256, 1280]) + # 1/2 layers for both to mix + if global_cfg.SSL.mix_layer: + mix = self.mix_layer(mix) + + return self.l2norm(mix).chunk(2) + + +def train(cfg): + print("start running: ", datetime.now().strftime("%d/%m/%Y %H:%M:%S")) + print(f"cuda is available: {torch.cuda.is_available()}") + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + writer = SummaryWriter(f"{cfg.TRAINING.project_dir}") + tvqa_train = TVQADataset(series=cfg.TRAINING.series, split="train", transform=get_test_transforms(series=cfg.TRAINING.series)) + # tvqa_train = TwoAugUnsupervisedTVQADataset(split="train", transform=get_train_transforms()) + train_loader = DataLoader(tvqa_train, batch_size=cfg.TRAINING.batch_size, shuffle=True, num_workers=0) #, collate_fn=lambda x: x) + + if cfg.TRAINING.pretrained is True: + model = VGGFacePlus(cfg, 
len(tvqa_train.lbl_to_id)) + if cfg.SSL.joint is True: + model = VGGFaceSubtitle(cfg, len(tvqa_train.lbl_to_id)) + if cfg.SSL.supervised is True: + model = VGGFaceSupervised(cfg, len(tvqa_train.lbl_to_id)) + if cfg.TRAINING.supervised is True: + model = VGGSupervised(cfg, len(tvqa_train.lbl_to_id)) + + model.to(device) + optimizer = torch.optim.Adam(model.parameters(), + lr=cfg.TRAINING.lr) + # optimizer = torch.optim.SGD(model.parameters(), + # lr=cfg.TRAINING.lr, + # momentum=cfg.TRAINING.momentum, + # weight_decay=cfg.TRAINING.weight_decay) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, + gamma=cfg.TRAINING.lr_decay_rate, + milestones=[cfg.TRAINING.lr_decay_epoch]) + criterion = CrossEntropyLoss() + for epoch in tqdm(range(cfg.TRAINING.epochs)): + for iteration, data in enumerate(train_loader): + predictions = model(data['image'].to(device)) + data_mode = cfg.TRAINING.data_mode + # one_hot = F.one_hot(data["cleansed"].to(device), num_classes=len(tvqa_train.lbl_to_id)) + # loss = criterion(F.softmax(predictions), one_hot) + loss = criterion(predictions, data[data_mode].to(device)) + # loss = criterion(predictions, data['correct_target_id'].to(device)) + writer.add_scalar("train_loss", loss, epoch*len(train_loader)+iteration) + for param_group in optimizer.param_groups: + writer.add_scalar("lr", param_group['lr'], epoch * len(train_loader) + iteration) + break + optimizer.zero_grad() + loss.backward() + optimizer.step() + scheduler.step() + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'train_loss': loss, + }, f"{cfg.TRAINING.project_dir}model/epoch_{epoch}.tar") + print("time: ", datetime.now().strftime("%d/%m/%Y %H:%M:%S")) + if epoch > 0: + os.remove(f"{cfg.TRAINING.project_dir}model/epoch_{epoch-1}.tar") + print("end running: ", datetime.now().strftime("%d/%m/%Y %H:%M:%S")) + + +if __name__ == "__main__": + # config priority: + # 1. arguments + # 2. config.yaml file + # 3. 
defaults.py file + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + + print("Command line arguments: " + str(args)) + print("Running with full config:\n{}".format(cfg)) + Path(cfg.TRAINING.project_dir).mkdir(parents=True, exist_ok=True) + Path(os.path.join(cfg.TRAINING.project_dir, "model")).mkdir(parents=True, exist_ok=True) + config_path = os.path.join(cfg.TRAINING.project_dir, "config.yaml") + with PathManager.open(config_path, "w") as f: + f.write(cfg.dump()) + print("Full config saved to {}".format(os.path.abspath(config_path))) + set_global_cfg(cfg) + train(cfg) diff --git a/tvqa_dataset.py b/tvqa_dataset.py new file mode 100644 index 0000000..3d5d148 --- /dev/null +++ b/tvqa_dataset.py @@ -0,0 +1,1634 @@ +import os +import glob + +from torchvision import transforms +from tqdm import tqdm +import json +import pysrt +from transformers import AutoModelForTokenClassification, AutoTokenizer +import torch +import math +from pathlib import Path +from PIL import Image, ImageDraw +from facenet_pytorch import MTCNN, extract_face +import torch.utils.data as data +import numpy as np +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split +import argparse +import random +import os +import glob +import math +from config import global_cfg +from fuzzywuzzy import fuzz +import pandas as pd +from PIL import Image + + +def save_json(data, file_path): + with open(file_path, "w+") as f: + json.dump(data, f) + + +def load_json(file_path): + with open(file_path, "r") as f: + return json.load(f) + + +def get_train_transforms(series): + if series == "friends": + transform = transforms.Compose([ + transforms.RandomResizedCrop(global_cfg.SSL.random_crop, scale=(0.08, 1)), + transforms.RandomHorizontalFlip(), + # transforms.RandomVerticalFlip(), + transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), + transforms.RandomGrayscale(p=0.2), + transforms.ToTensor(), + transforms.Normalize( + (0.4313, 0.2598, 0.2205), + (0.1468, 0.1143, 0.1127), + ), + ]) + elif series == "bbt": + transform = transforms.Compose([ + transforms.RandomResizedCrop(global_cfg.SSL.random_crop, scale=(0.08, 1)), + transforms.RandomHorizontalFlip(), + # transforms.RandomVerticalFlip(), + transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), + transforms.RandomGrayscale(p=0.2), + transforms.ToTensor(), + transforms.Normalize( + (0.3186, 0.2131, 0.1926), + (0.1695, 0.1326, 0.1284), + ), + ]) + return transform + + +def get_test_transforms(series): + if series == "friends": + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + (0.4313, 0.2598, 0.2205), + (0.1468, 0.1143, 0.1127), + ), + ]) + elif series == "bbt": + transform = transforms.Compose([ + transforms.Resize(size=(160, 160)), + transforms.ToTensor(), + transforms.Normalize( + (0.3186, 0.2131, 0.1926), + (0.1695, 0.1326, 0.1284), + ), + ]) + return transform + + +def find_norm_data(dataset): + # friends mean: tensor([0.4313, 0.2598, 0.2205]) std: tensor([0.1468, 0.1143, 0.1127]) + # bbt mean: tensor([0.3186, 0.2131, 0.1926]) tensor([0.1695, 0.1326, 0.1284]) + from torch.utils.data import DataLoader + loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0) + + mean = 0. + std = 0. + nb_samples = 0. 
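+ # Accumulate per-image channel means/stds over the whole dataset and average them;
+ # the resulting values are the hard-coded Normalize constants used in the transforms above.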
+ for data in loader: + img = data['image'] + batch_samples = img.size(0) + img = img.view(batch_samples, img.size(1), -1) + mean += img.mean(2).sum(0) + std += img.std(2).sum(0) + nb_samples += batch_samples + + mean /= nb_samples + std /= nb_samples + + return mean, std + + +class TVQADataset(data.Dataset): + def __init__(self, series, split="train", transform=None): + self.entity_recognition_model = AutoModelForTokenClassification.from_pretrained( + "dbmdz/bert-large-cased-finetuned-conll03-english") + # self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") #todo: preload the model + self.label_list = [ + "O", # Outside of a named entity + "B-MISC", # Beginning of a miscellaneous entity right after another miscellaneous entity + "I-MISC", # Miscellaneous entity + "B-PER", # Beginning of a person's name right after another person's name + "I-PER", # Person's name + "B-ORG", # Beginning of an organisation right after another organisation + "I-ORG", # Organisation + "B-LOC", # Beginning of a location right after another location + "I-LOC" # Location + ] + self.split = split + self.series = series + self.transform = transform + self.project_dir = "./" + self.all_subtitles_loc = self.project_dir + "dataset/tvqa_subtitles/" + + # self.subtitle_json = self.all_subtitles_loc + f"subtitle_cache_friends_s01e{episode}.json" + # for i in range(1, 6): + # srt_data = self.load_srt(self.all_subtitles_loc, self.all_subtitles_loc + f"subtitle_cache_friends_s01e0{i}.json") + # self.dataset_dict = self.prepare_tvqa_json(srt_data, dataset_path=self.project_dir + f"dataset/train_episodes/friends_dict_s01e0{i}_extendedbb.json") + # self.dataset_dict = self.clean_dict(self.dataset_dict, path=self.project_dir + f"dataset/friends_dict.json") + if global_cfg.TRAINING.exp_type == "normal": + + if self.series == "bbt": + self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/{self.series}_{self.split}_annotations.json") + elif self.series == "friends": + self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/{self.split}_annotations.json") + # if self.split == "train": + # self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/new_mix_{self.series}_annotations.json") + # elif self.split == "test": + # self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/new_tvqa_plus_test_annotations.json") + # elif self.series == "friends": + # self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/{self.split}_annotations.json") + # self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/weak_all_{self.series}_{self.split}_annotations.json") + elif global_cfg.TRAINING.exp_type == "oracle": + if self.series == "bbt": + if global_cfg.TRAINING.ours_or_baseline == "baseline": + if self.split == "train": + self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/bbt_1_8.json") + # self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/weak_bbt_train_annotations.json") + elif self.split == "test": + self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/bbt_9_10.json") + elif global_cfg.TRAINING.ours_or_baseline == "ours": + if self.split == "train": + self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/{self.series}_{self.split}_annotations.json") + elif self.split == "test": + self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/bbt_9_10.json") + self.lbl_to_id = self.load_labels(self.anns, lbl_path=self.project_dir + 
f"dataset/{self.series}_lbl_to_id.json") + self.id_to_lbs = {id: lbl for (lbl, id) in self.lbl_to_id.items()} + + def is_unknown(self, ann): + if not ann['name']: + return True + if len(set(ann['name']).intersection(set(self.lbl_to_id.keys()))) > 0: + return False + else: + return True + + def stitch_tokens_in_ann(self, anns): + for ann in anns.values(): + if ann["name"]: + i = 0 + while i < len(ann["name"]): + weak_lbl = ann["name"][i] + if weak_lbl.startswith("##"): + if i != 0: + print(weak_lbl, ann["name"][i - 1]) + ann["name"][i - 1] = ann["name"][i - 1] + weak_lbl[2:] + ann["name"].remove(weak_lbl) + print(ann["name"][i - 1]) + else: + i += 1 + else: + print(weak_lbl) + i += 1 + return anns + + def stitch_tokens_in_dict(self, dataset_dict): + for data in dataset_dict.values(): + if data["names"]: + i = 0 + while i < len(data["names"]): + weak_lbl = data["names"][i] + if weak_lbl.startswith("##"): + if i != 0: + print(weak_lbl, data["names"][i - 1]) + data["names"][i - 1] = data["names"][i - 1] + weak_lbl[2:] + data["names"].remove(weak_lbl) + print(data["names"][i - 1]) + else: + i += 1 + else: + print(weak_lbl) + i += 1 + return dataset_dict + + def build_tvqa_data(self): + # self.lbl_to_id = {} + # self.build_tvqa_data() + srt_data = self.load_srt(self.all_subtitles_loc, self.subtitle_json) + dataset_dict = self.prepare_tvqa_json(srt_data) + clean_dataset_dict = self.clean_dict(dataset_dict, path=self.project_dir + f"dataset/clean_friends_dict.json") + self.anns = self.load_anns(clean_dataset_dict, ann_path=self.project_dir + f"dataset/clean_annotations.json") + self.lbl_to_id = self.load_labels(self.anns, lbl_path=self.project_dir + "dataset/lbl_to_id.json") + self.split_train_val_test(self.anns, path=self.project_dir + "dataset/") + self.create_hist(file_name="all_annotations_hist") + + def clean_anns(self, anns, path, lbl_type): + if os.path.exists(path): + print("Found clean dictionary cached, loading ...") + return load_json(path) + all_labels = [] + lbl_to_freq = {} + to_be_deleted_lbls = set([]) + for ann in tqdm(anns.values()): + all_labels.append(ann[lbl_type]) + + print(all_labels) + for lbl in set(all_labels): + lbl_to_freq[lbl] = 0 + + # calculate label frequencies: + for lbl in all_labels: + lbl_to_freq[lbl] += 1 + + for lbl in lbl_to_freq.keys(): + print(f"label: {lbl} frequency:{lbl_to_freq[lbl]}") + + # sort based on frequencies + sort_dict = dict(sorted(lbl_to_freq.items(), key=lambda item: item[1], reverse=True)) + for i, lbl in enumerate(sort_dict.keys()): + if lbl_type == 'target_name': + if i >= 7: # includes "Unknown" + to_be_deleted_lbls.add(lbl) + elif lbl_type == 'name': + if i >= 6: # doesn't include "Unknown" + to_be_deleted_lbls.add(lbl) + print(f"deleted labels: {to_be_deleted_lbls}") + + # delete corresponding annotations + for ann in tqdm(anns.values()): + print(f"before: {ann[lbl_type]}") + # [x for x in array1 if x not in array2] + # delete less frequents: + # dataset[img]['names'] = [x for x in dataset[img]['names'] if x not in to_be_deleted_lbls] + # keeping less frequent as unknowns: + ann[lbl_type] = ann[lbl_type] if ann[lbl_type] not in to_be_deleted_lbls else 'Unknown' + # dataset[img]['names'] = list(set(dataset[img]['names']) - to_be_deleted_lbls) + print(f"after: {ann[lbl_type]}") + + save_json(anns, path) + return anns + + def clean_dict(self, dataset, path): + if os.path.exists(path): + print("Found clean dictionary cached, loading ...") + return load_json(path) + + all_labels = [] + lbl_to_freq = {} + to_be_deleted_lbls = set([]) + + for 
img in tqdm(dataset.keys()): + # if list is not empty + if dataset[img]['names']: + all_labels.extend(dataset[img]['names']) # add a list to list + + print(all_labels) + for lbl in set(all_labels): + lbl_to_freq[lbl] = 0 + + # calculate label frequencies: + for lbl in all_labels: + lbl_to_freq[lbl] += 1 + + for lbl in lbl_to_freq.keys(): + print(f"label: {lbl} frequency:{lbl_to_freq[lbl]}") + + for lbl, freq in lbl_to_freq.items(): + if freq < 580: + to_be_deleted_lbls.add(lbl) + print(f"deleted labels: {to_be_deleted_lbls}") + + # delete corresponding annotations + for img in tqdm(dataset.keys()): + print(f"before: {dataset[img]['names']}") + # [x for x in array1 if x not in array2] + # delete less frequents: + # dataset[img]['names'] = [x for x in dataset[img]['names'] if x not in to_be_deleted_lbls] + # keeping less frequent as unknowns: + dataset[img]['names'] = [x if x not in to_be_deleted_lbls else 'Unknown' for x in dataset[img]['names']] + # dataset[img]['names'] = list(set(dataset[img]['names']) - to_be_deleted_lbls) + print(f"after: {dataset[img]['names']}") + save_json(dataset, path) + return dataset + + def __getitem__(self, index): + ann = self.anns[str(index)] + if global_cfg.TRAINING.exp_type == "oracle": + dir = f"./dataset/new_tvqa_plus_{self.series}_frames_onlyfaces/" + elif global_cfg.TRAINING.exp_type == "normal": + dir = "./dataset/" + ann['series'] + "_frames/" + image = Image.open(os.path.join(dir, ann['face'])) + if self.transform is not None: + image = self.transform(image) + if global_cfg.TRAINING.exp_type == "oracle": + dict = { + "image": image, + "clip": ann["clip"], + "series": ann["series"], + "face": ann["face"], + "correct_target_name": ann["name"], + "correct_target_id": self.lbl_to_id[ann["name"]], + "bbox": ann["bbox"], + # "weak_label": self.lbl_to_id[ann['weak_lbls'][0]], + } + elif global_cfg.TRAINING.exp_type == "normal": + dict = { + # "image": image, + "clip": ann["clip"], + "series": ann["series"], + "face": ann["face"], + "subtitle": ann["subtitle"], + } + + if self.split == "train": + # if global_cfg.TRAINING.series == "friends": + target = [] + if ann['name']: + for name in ann['name']: + if name in self.lbl_to_id.keys(): + target.append(self.lbl_to_id[name]) + else: + target.append(self.lbl_to_id["Unknown"]) + # else: + # target.append(self.lbl_to_id["Unknown"]) + # else: + # target = self.lbl_to_id[ann['name']] + dict["weak_id"] = target + # dict["weak_id"] = target[0] + dict["weak_name"] = ann['name'] + dict["cleansed"] = ann["cleansed"] + if self.series == "bbt": + dict["bbox"] = ann["bbox"] + dict["face_points"] = ann["face_points"] + + elif self.split == "test": + if self.series == "friends": + dict["correct_target_name"] = ann["target_name"] + dict["correct_target_id"] = self.lbl_to_id[ann["target_name"]] + elif self.series == "bbt": + dict["correct_target_name"] = ann["name"] + dict["correct_target_id"] = self.lbl_to_id[ann["name"]] + + return dict + # target = self.lbl_to_id[ann['name']] + # target = [] + # if ann['name']: + # for name in ann['name']: + # if name in self.lbl_to_id.keys(): + # target.append(self.lbl_to_id[name]) + # else: + # target.append(self.lbl_to_id["Unknown"]) + + # return {"image": image, + # "weak_label": target, + # "correct_target_name": ann["target_name"], + # "correct_target_id": self.lbl_to_id[ann["target_name"]], + # "weak_label": [self.lbl_to_id[name] for name in ann['name'] if ann['name']], + # } + # "weak_label": [self.lbl_to_id[name] for name in ann['name'] if ann['name']]} + # , "weak_label": target, 
"correct_target_id": ann["target_id"]} + + def __len__(self): + return len(list(self.anns.keys())) + + def load_labels(self, anns, lbl_path, lbl_type="name"): + + if os.path.exists(lbl_path): + print("Found labels cache, loading ...") + return load_json(lbl_path) + + labels = set([]) + lbl_to_id = {} + if self.split == "test" or self.split == "dev": + if lbl_type == "target_name": + for ann in anns.values(): + labels.add(ann[lbl_type]) + elif lbl_type == "name": + for ann in anns.values(): + if ann[lbl_type]: + for lbl in ann[lbl_type]: + labels.add(lbl) + elif self.split == "train": + for ann in anns.values(): + labels.add(ann[lbl_type]) + + for idx, lbl in enumerate(labels): + lbl_to_id[lbl] = idx + + save_json(lbl_to_id, lbl_path) + return lbl_to_id + + def load_anns(self, dataset, ann_path): + + if os.path.exists(ann_path): + print("Found annotation cache, loading ...") + return load_json(ann_path) + + anns = {} + all_faces = [] + annid = 0 + for img in tqdm(dataset.keys()): + # if there is at least one face and one name (not empty) + if dataset[img]['faces'] and dataset[img]['names']: + for face, bbox, face_landmark in zip(dataset[img]['faces'], dataset[img]['bbox'], dataset[img]['face_points']): + for name in dataset[img]['names']: + anns[annid] = {'face': face, + 'name': name, + 'img': dataset[img]['img'], + 'subtitle': dataset[img]['subtitle'], + 'clip': dataset[img]['clip'], + 'series': dataset[img]['series'], + 'bbox': bbox, + 'face_points': face_landmark} + annid += 1 + save_json(anns, ann_path) + return load_json(ann_path) + + def load_anns_test(self, dataset, ann_path): + + if os.path.exists(ann_path): + print("Found annotation cache, loading ...") + return load_json(ann_path) + + anns = {} + annid = 0 + for img in tqdm(dataset.keys()): + # if there is at least one face and one name (not empty) + if dataset[img]['faces']: + for face in dataset[img]['faces']: + anns[annid] = {'face': face, 'name': dataset[img]['names'], 'img': dataset[img]['img'], + 'subtitle': dataset[img]['subtitle'], + 'clip': dataset[img]['clip'], + 'series': dataset[img]['series']} + annid += 1 + save_json(anns, ann_path) + return load_json(ann_path) + + def find_names(self, sequence): + # Bit of a hack to get the tokens with the special tokens + tokens = self.tokenizer.tokenize(self.tokenizer.decode(self.tokenizer.encode(sequence))) + inputs = self.tokenizer.encode(sequence, return_tensors="pt") + + outputs = self.entity_recognition_model(inputs)[0] # .logits + predictions = torch.argmax(outputs, dim=2) + # todo: should also include location names and other things + # name_to_label = [(token, self.label_list[prediction]) for token, prediction in + # zip(tokens, predictions[0].numpy()) if + # self.label_list[prediction] == "B-PER" or self.label_list[prediction] == "I-PER"] + name_to_other_labels = [] + name_to_label = [] + for token, prediction in zip(tokens, predictions[0].numpy()): + if self.label_list[prediction] == "B-PER" or self.label_list[prediction] == "I-PER": + name_to_label.append((token, self.label_list[prediction])) + elif self.label_list[prediction] in ["B-MISC", "I-MISC", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]: + name_to_other_labels.append((token, self.label_list[prediction])) + + return name_to_label, name_to_other_labels + + def load_srt(self, srt_dir, srt_cache_path): + """ + return: A python dict, the keys are the video names, the entries are lists, + each contains all the text from a .srt file + sub_times are the start time of the sentences. 
+ """ + if os.path.exists(srt_cache_path): + print("Found srt data cache, loading ...") + return load_json(srt_cache_path) + print("Loading srt files from %s ..." % srt_dir) + # srt_paths = glob.glob(os.path.join(srt_dir, "friends_s01e02_seg02_clip_*.srt")) + + srt_paths = glob.glob(os.path.join(srt_dir, f"friends_s01e{self.episode}_seg0*.srt")) + # srt_paths = glob.glob(os.path.join(srt_dir, f".srt")) + name2sub_text = {} + name2sub_face = {} + name2sub_time = {} + name2sub_time_end = {} + for i in tqdm(range(len(srt_paths))): + subs = pysrt.open(srt_paths[i], encoding="iso-8859-1") + if len(subs) == 0: + subs = pysrt.open(srt_paths[i]) + + text_list = [] + name_list = [] + sub_time_list = [] + sub_time_list_end = [] + for j in range(len(subs)): + cur_sub = subs[j] + cur_str = cur_sub.text + cur_str = "(:)" + cur_str if cur_str[0] != "(" else cur_str + cur_str = cur_str.replace("\n", " ") + names = self.find_names(cur_str) + text_list.append(cur_str) + name_list.append(names) + sub_time_list.append( + 60 * cur_sub.start.minutes + cur_sub.start.seconds + 0.001 * cur_sub.start.milliseconds) + sub_time_list_end.append( + 60 * cur_sub.end.minutes + cur_sub.end.seconds + 0.001 * cur_sub.end.milliseconds) + + key_str = os.path.splitext(os.path.basename(srt_paths[i]))[0] + name2sub_text[key_str] = text_list + name2sub_face[key_str] = name_list + name2sub_time[key_str] = sub_time_list + name2sub_time_end[key_str] = sub_time_list_end + + srt_data = {"sub_text": name2sub_text, + "sub_face": name2sub_face, + "sub_time": name2sub_time, + "sub_time_end": name2sub_time_end} + save_json(srt_data, srt_cache_path) + return load_json(srt_cache_path) # we do this because the ints will turn to string 0 -> '0' + + def prepare_tvqa_json(self, srt_data, dataset_path): + if srt_data is None: + return + series_list = ["castle", "friends", "grey", "house", "met", "bbt"] + series_directory = self.project_dir + f"dataset/frames_hq/{series_list[5]}_frames/" + save_directory = self.project_dir + f"/dataset/{series_list[5]}_frames/" + + Path(save_directory).mkdir(parents=True, exist_ok=True) + + if os.path.exists(dataset_path): + print("Found dataset cache, loading ...") + return load_json(dataset_path) + + mtcnn = MTCNN(keep_all=True) + i = 0 + dataset_dict = {} + for clip in tqdm(srt_data['sub_text'].keys()): + if clip.startswith(""): + for t, names in enumerate(srt_data['sub_face'][clip]): + # print(t, text) + # print(t, srt_data['sub_time'][clip][t]) + # matching subtitles to all frames in that time frame + frame_num_begin = math.ceil(srt_data['sub_time'][clip][t] * 3) + frame_num_end = math.ceil(srt_data['sub_time_end'][clip][t] * 3) + frame_num = frame_num_begin + # for frame_num in range(frame_num_begin, frame_num_end): + # print(t, clip, srt_data['sub_time'][clip][t], frame_num, str(frame_num).zfill(5)+".jpg") + img_str = str(frame_num).zfill(5) + clip_dir = series_directory + clip + if not names: + names = [] + else: + names = [name_lbl[0] for name_lbl in names] + + # this is in case the frame number does not exist(usually happens for the last frames in folder) + try: + # make a function to draw + img = Image.open(os.path.join(clip_dir, img_str + ".jpg")) + boxes, probs, points = mtcnn.detect(img, landmarks=True) + + faces = [] + if boxes is not None: + img_draw = img.copy() + draw = ImageDraw.Draw(img_draw) + for f, (box, point) in enumerate(zip(boxes, points)): + draw.rectangle(box.tolist(), width=5) + faces.append("{}_{}_{}.png".format(clip, img_str, f)) + box[0] = box[0] - 20 + box[1] = box[1] - 20 + 
box[2] = box[2] + 20 + box[3] = box[3] + 20 + # for p in point: + # draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10) + extract_face(img, box, + save_path=save_directory + "{}_{}_{}.png".format(clip, img_str, f)) + if points is not None: + points = points.tolist() + if boxes is not None: + boxes = boxes.tolist() + dataset_dict[i] = {"series": "bbt", "clip": clip, "img": img_str, "names": names, + "subtitle": srt_data['sub_text'][clip][t], "faces": faces, + "face_points": points, "bbox": boxes} + i += 1 + except OSError as e: + print(e) + continue + dataset_dict = self.stitch_tokens_in_dict(dataset_dict) + save_json(dataset_dict, dataset_path) + return load_json(dataset_path) + + def create_hist(self, file_name, lbl_type): + # lbl_type can be target_name(correct) or name(weak) + num_classes = len(self.lbl_to_id.keys()) + hist_bins = np.arange(num_classes + 1) + histogram = np.zeros((num_classes,), dtype=np.int) + if self.split == "test" or self.split == "dev": + if lbl_type == "target_name": + classes = [self.lbl_to_id[ann[lbl_type]] for ann in self.anns.values()] + elif lbl_type == "name": + classes = [] + for ann in self.anns.values(): + if ann[lbl_type]: + for lbl in ann[lbl_type]: + classes.append(self.lbl_to_id[lbl]) + elif self.split == "train": + classes = [self.lbl_to_id[ann[lbl_type]] for ann in self.anns.values()] + histogram += np.histogram(classes, bins=hist_bins)[0] + ind_sorted = np.argsort(histogram)[::-1] + bins = range(num_classes) + fig = plt.figure(figsize=(10, 8)) + + plt.bar(bins, height=histogram[ind_sorted], color='#3DA4AB') + # plt.yscale("log") + plt.ylabel("#instances", rotation=90) + id_to_lbl = {id: lbl for (lbl, id) in self.lbl_to_id.items()} + class_names = [id_to_lbl[ind] for ind in ind_sorted] + plt.xticks(bins, np.array(class_names), rotation=90, fontsize=10) + fig.savefig(os.path.join("", f"{file_name}.pdf")) + plt.clf() + + def create_cooccurance_matrix(self): + coocur_matrix = np.zeros((len(self.lbl_to_id), len(self.lbl_to_id)), np.float64) + for ann in self.dataset_dict.values(): + for i, n1 in enumerate(ann["names"]): + for j, n2 in enumerate(ann["names"]): + if i != j: + coocur_matrix[self.lbl_to_id[n1]][self.lbl_to_id[n2]] += 1 + + id_to_lbl = {id: lbl for (lbl, id) in self.lbl_to_id.items()} + num_classes = len(self.lbl_to_id.keys()) + class_names = [id_to_lbl[ind] for ind in range(num_classes)] + plt.imshow(coocur_matrix, cmap='plasma', interpolation='nearest') + + plt.xticks(range(num_classes), np.array(class_names), rotation=90, fontsize=6) + plt.yticks(range(num_classes), np.array(class_names), rotation=0, fontsize=6) + + # Plot a colorbar with label. + cb = plt.colorbar() + cb.set_label("Number of co-occurrences") + + # Add title and labels to plot. 
+ plt.title("Co-occurrence of named entities in subtitles only") + plt.xlabel('Named Entities') + plt.ylabel('Named Entities') + plt.savefig('cooccurance_matrix_plasma.pdf') + plt.clf() + + def split_train_val_test(self, anns, path): + # want to use targets to do stratified split + y = [] + for ann in anns.values(): + y.append(self.lbl_to_id[ann["name"]]) + X = list(anns.keys()) + X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.90, random_state=49, shuffle=True, + stratify=y) + X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, train_size=0.88, random_state=50, + shuffle=True, stratify=y_train) + + # use these three subset indices to save the files for train, val, test + train_anns = {str(i): anns[idx] for i, idx in enumerate(X_train)} + val_anns = {str(i): anns[idx] for i, idx in enumerate(X_val)} + test_anns = {str(i): anns[idx] for i, idx in enumerate(X_test)} + save_json(train_anns, os.path.join(path, "train_annotations.json")) + save_json(val_anns, os.path.join(path, "val_annotations.json")) + save_json(test_anns, os.path.join(path, "test_annotations.json")) + + def split_test_dev_sets(self): + annotation_paths = glob.glob(os.path.join("./dataset/test_episodes/", f"labeled_annotations_test_s01e*.json")) + all_test_anns = {} + i = 0 + for ann_path in annotation_paths: + per_episode_ann = load_json(ann_path) + print(len(per_episode_ann.keys())) + for ann in per_episode_ann.values(): + all_test_anns[i] = ann + i += 1 + + all_inds = list(range(len(all_test_anns))) + random.shuffle(all_inds) + test_inds = all_inds[:math.ceil(len(all_inds) / 2)] + dev_inds = all_inds[math.ceil(len(all_inds) / 2) + 1:] + test_set = {i: all_test_anns[ind] for i, ind in enumerate(test_inds)} + dev_set = {i: all_test_anns[ind] for i, ind in enumerate(dev_inds)} + save_json(test_set, "./dataset/test.json") + save_json(dev_set, "./dataset/dev.json") + + def split_train_set(self, n=10): + annotation_paths = glob.glob(os.path.join("./dataset/train_episodes/", f"all_annotations_s01e*.json")) + all_train_anns = {} + i = 0 + for ann_path in annotation_paths: + per_episode_ann = load_json(ann_path) + print(len(per_episode_ann.keys())) + for ann in per_episode_ann.values(): + all_train_anns[i] = ann + i += 1 + save_json(all_train_anns, "./dataset/train.json") + + splitted_faces = {} + for ep in range(6, 25): + if ep < 10: + episode = "0" + str(ep) + elif ep == 16: + episode = "16-17" + elif ep == 17: + continue + else: + episode = str(ep) + faces = load_json(f"./dataset/train_episodes/faces_s01e{episode}.json") + random.shuffle(faces) + for i, face_split in enumerate(np.array_split(np.array(faces), n)): + if episode == "06": + splitted_faces[i] = [] + splitted_faces[i].extend(face_split) + + train_anns = {} + for i in range(n): + train_anns[i] = {} + for i, face_chunk in enumerate(splitted_faces.values()): + j = 0 + for face in face_chunk: + for ann in all_train_anns.values(): + if ann["face"] == face: + train_anns[i][j] = ann + j += 1 + for i in range(n): + save_json(train_anns[i], f"./dataset/part{i}_train.json") + + def create_confusion_matrix(self): + import pandas as pd + self.lbl_to_id["Unknown"] = len(self.lbl_to_id) + confusion_matrix = np.zeros((len(self.lbl_to_id), len(self.lbl_to_id)), np.float64) + exp = "tvqa_exp_overfit_notshuffle_split" + df = pd.read_excel(f"./output/{exp}/test_data.xls", usecols="A,B,C") # weak label, correct label, prediction + for i in range(len(df['correct label']) - 1): # last row is nothing + confusion_matrix[self.lbl_to_id[df['correct 
label'][i]]][self.lbl_to_id[df['prediction'][i]]] += 1 + print(f"label: {df['correct label'][i]}, predicitons: {df['prediction'][i]}") + + id_to_lbl = {id: lbl for (lbl, id) in self.lbl_to_id.items()} + num_classes = len(self.lbl_to_id.keys()) + class_names = [id_to_lbl[ind] for ind in range(num_classes)] + plt.imshow(confusion_matrix, cmap='plasma', interpolation='nearest') + + plt.xticks(range(num_classes), np.array(class_names), rotation=90, fontsize=6) + plt.yticks(range(num_classes), np.array(class_names), rotation=0, fontsize=6) + + # Plot a colorbar with label. + cb = plt.colorbar() + cb.set_label("Number of predictions") + + # Add title and labels to plot. + plt.title("Confusion Matrix for predictions and correct labels") + plt.xlabel('Correct Label') + plt.ylabel('Predicted Label') + plt.savefig('confusion_matrix.pdf') + plt.clf() + + def label_annotations(self): + import pandas as pd + dataset_loc = "./dataset/excel" + df = pd.read_excel(f"{dataset_loc}/new_all_data.xls", usecols="A,B") + self.lbl_to_id["Unknown"] = len(self.lbl_to_id) + for i, ann in enumerate(self.anns.values()): + print(df['correct label'][i], self.lbl_to_id[df['correct label'][i]]) + ann["target_name"] = df['correct label'][i] + ann["target_id"] = self.lbl_to_id[df['correct label'][i]] + save_json(self.anns, file_path=self.project_dir + f"dataset/labeled_clean_annotations.json") + + def label_annotations_for_test(self, episode): + import pandas as pd + dataset_loc = "./dataset/excel" + df = pd.read_excel(f"{dataset_loc}/s01e{episode}_new.xls", usecols="A,B,I") + # df = pd.read_excel(f"{dataset_loc}/missing_faces.xls", usecols="A,E") + # self.lbl_to_id["Unknown"] = len(self.lbl_to_id) + for ann in self.anns.values(): + # if ann["target_name"] == "Dunno": + filtered_df = df.loc[df['face_loc'] == ann['face']] + if not filtered_df['correct label'].empty: + # print(filtered_df['correct label'].iloc[0]) + # might have more matchings but the correct label is the same for all + ann["target_name"] = filtered_df['correct label'].iloc[0] + # ann["target_id"] = self.lbl_to_id[df['correct label'][i]] + else: + print(f"did not find face {ann['face']} in excel") + ann["target_name"] = "Dunno" + save_json(self.anns, + file_path=self.project_dir + f"dataset/test_episodes/really_labeled_annotations_test_s01e{episode}.json") + + +class OnlyUnknownsTVQADataset(TVQADataset): + + def __len__(self): + return len(self.anns_no_unknowns) + + def __getitem__(self, index): + ann1 = self.anns_no_unknowns[index] + dir = "./dataset/" + ann1['series'] + "_frames/" + print(f"directory:{dir}, cwd: {os.getcwd()}") + image1 = Image.open(os.path.join(dir, ann1['face'])) + + dataset_len = len(self) + random_index = int(np.random.random() * dataset_len) + ann2 = self.anns_no_unknowns[random_index] + + + image2 = Image.open(os.path.join(dir, ann2['face'])) + + img1, img2 = self.transform(image1), self.transform(image2) + dict1, dict2 = {}, {} + + if self.split == "train": + # target = [] + # if ann1['name']: + # for name in ann1['name']: + # if name in self.lbl_to_id.keys(): + # target.append(self.lbl_to_id[name]) + # else: + # target.append(self.lbl_to_id["Unknown"]) + # target = self.lbl_to_id[ann['name']] + dict1 = {"image": img1, + "subtitle": ann1["subtitle"], + } + dict2 = {"image": img2, + "subtitle": ann2["subtitle"], + } + elif self.split == "test": + dict1 = {"image": img1, + "correct_target_name": ann1["target_name"], + "correct_target_id": self.lbl_to_id[ann1["target_name"]], + "subtitle": ann1["subtitle"], + } + dict2 = {"image": img2, 
+ "correct_target_name": ann2["target_name"], + "correct_target_id": self.lbl_to_id[ann2["target_name"]], + "subtitle": ann2["subtitle"], + } + + return dict1, dict2 + + +class TwoWeakOrAugTVQADataset(TVQADataset): + def get_random_sample(self, include_unknowns=True): + dataset_len = len(self) + while True: + random_index = int(np.random.random() * dataset_len) + ann = self.anns[str(random_index)] + if include_unknowns is False: + if self.is_unknown(ann): + continue + dir = "./dataset/" + ann['series'] + "_frames/" + image = Image.open(os.path.join(dir, ann['face'])) + break + return ann, image + + def __getitem__(self, index): + ann1 = self.anns[str(index)] + dir = "./dataset/" + ann1['series'] + "_frames/" + image1 = Image.open(os.path.join(dir, ann1['face'])) + if global_cfg.SSL.include_unknowns is False: + if self.is_unknown(ann1): + ann1, image1 = self.get_random_sample(include_unknowns=global_cfg.SSL.include_unknowns) + + # Get a random sample that has the same weak label + while True: + #todo: unknowns together + if self.is_unknown(ann1): # if no weak label exists, do augmentation + ann2 = ann1 + image2 = Image.open(os.path.join(dir, ann2['face'])) + break + ann2, image2 = self.get_random_sample(include_unknowns=global_cfg.SSL.include_unknowns) + if len(set(ann1['name']).intersection(set(ann2['name']))) > 0: + break + + img1, img2 = self.transform(image1), self.transform(image2) + dict1, dict2 = {}, {} + + if self.split == "train": + # target = [] + # if ann1['name']: + # for name in ann1['name']: + # if name in self.lbl_to_id.keys(): + # target.append(self.lbl_to_id[name]) + # else: + # target.append(self.lbl_to_id["Unknown"]) + # target = self.lbl_to_id[ann['name']] + dict1 = {"image": img1, + "subtitle": ann1["subtitle"], + } + dict2 = {"image": img2, + "subtitle": ann2["subtitle"], + } + elif self.split == "test": + dict1 = {"image": img1, + "correct_target_name": ann1["target_name"], + "correct_target_id": self.lbl_to_id[ann1["target_name"]], + "subtitle": ann1["subtitle"], + } + dict2 = {"image": img2, + "correct_target_name": ann2["target_name"], + "correct_target_id": self.lbl_to_id[ann2["target_name"]], + "subtitle": ann2["subtitle"], + } + + return dict1, dict2 + + +class TwoAugUnsupervisedTVQADataset(TVQADataset): + def __getitem__(self, index): + ann = self.anns[str(index)] + dir = "./dataset/" + ann['series'] + "_frames/" + image = Image.open(os.path.join(dir, ann['face'])) + + img1, img2 = self.transform(image), self.transform(image) + dict1, dict2 = {}, {} + + if self.split == "train": + target = [] + if ann['name']: + for name in ann['name']: + if name in self.lbl_to_id.keys(): + target.append(self.lbl_to_id[name]) + else: + target.append(self.lbl_to_id["Unknown"]) + # target = self.lbl_to_id[ann['name']] + dict1 = {"image": img1, + # "weak_label": target, + "subtitle": ann["subtitle"], + # "name": ann['name'], + } + dict2 = {"image": img2, + # "weak_label": target, + "subtitle": ann["subtitle"], + # "name": ann['name'], + } + elif self.split == "test": + dict1 = {"image": img1, + "correct_target_name": ann["target_name"], + "correct_target_id": self.lbl_to_id[ann["target_name"]], + "subtitle": ann["subtitle"], + } + dict2 = {"image": img2, + "correct_target_name": ann["target_name"], + "correct_target_id": self.lbl_to_id[ann["target_name"]], + "subtitle": ann["subtitle"], + } + + return dict1, dict2 + + +class BuildTVQADataset(TVQADataset): + def __init__(self): + self.entity_recognition_model = AutoModelForTokenClassification.from_pretrained( + 
"dbmdz/bert-large-cased-finetuned-conll03-english") + self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + self.label_list = [ + "O", # Outside of a named entity + "B-MISC", # Beginning of a miscellaneous entity right after another miscellaneous entity + "I-MISC", # Miscellaneous entity + "B-PER", # Beginning of a person's name right after another person's name + "I-PER", # Person's name + "B-ORG", # Beginning of an organisation right after another organisation + "I-ORG", # Organisation + "B-LOC", # Beginning of a location right after another location + "I-LOC" # Location + ] + self.project_dir = "./" + self.series = "friends" + # self.all_subtitles_loc = self.project_dir + "dataset/tvqa_subtitles/" + self.all_subtitles_loc = self.project_dir + f"dataset/hmtl/{self.series}/" + self.season = None + self.episode = None + self.subtitle_json = None + + self.build_dict() + self.anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/{self.series}_annotations.json") + # self.split = "train" + # self.lbl_to_id = self.load_labels(self.anns, lbl_path=self.project_dir + f"dataset/{self.series}_lbl_to_id.json") + # self.create_hist("bbt_train_hist_plusunknowns", lbl_type="name") + + # self.build_evaluation_dict(self.project_dir + f"dataset/new_tvqa_plus_test_annotations.json") + # dataset_dict = self.add_missing_frames_from_test_set_to_trainset(dataset_path=self.project_dir + f"dataset/new_new_dict_with_names.json") + # dataset_dict = self.stitch_tokens_in_dict(dataset_dict) + # save_json(dataset_dict, f"./dataset/new_stitched_new_dict_with_names.json") + # dataset_dict = load_json(os.path.join(self.project_dir, f"dataset/new_stitched_new_dict_with_names.json")) + # # clean_dataset_dict = self.clean_dict(dataset_dict, path=self.project_dir + f"dataset/clean_new_dict_with_names.json") + # # new_anns = self.load_anns(dataset_dict, ann_path=self.project_dir + f"dataset/new_new_dict_with_names_{self.series}_annotations.json") + # new_anns = self.load_anns(dataset_dict, ann_path=self.project_dir + f"dataset/__new_anns_{self.series}_annotations.json") + # old_dict = load_json("./dataset/bbt_dict.json") + # old_anns = self.load_anns(old_dict, ann_path=self.project_dir + f"dataset/__old_anns_{self.series}_annotations.json") + # # old_anns = self.load_anns(None, ann_path=self.project_dir + f"dataset/{self.series}_annotations.json") + # start = len(old_anns) + # for ann in new_anns.values(): + # old_anns[str(start)] = ann + # start += 1 + # save_json(old_anns, "./dataset/compare_with_bbt_train_annotations.json") + # save_json(old_anns, "./dataset/new_mix_bbt_annotations.json") + + def load_srt(self, srt_dir, srt_cache_path): + """ + return: A python dict, the keys are the video names, the entries are lists, + each contains all the text from a .srt file + sub_times are the start time of the sentences. + """ + if os.path.exists(srt_cache_path): + print("Found srt data cache, loading ...") + return load_json(srt_cache_path) + print("Loading srt files from %s ..." 
% srt_dir) + # srt_paths = glob.glob(os.path.join(srt_dir, "friends_s01e02_seg02_clip_*.srt")) + + srt_paths = glob.glob(os.path.join(srt_dir, f"friends_s{self.season}e{self.episode}_seg0*.srt")) + # srt_paths = glob.glob(os.path.join(srt_dir, f".srt")) + name2sub_text = {} + name2sub_face = {} + name2sub_other = {} + name2sub_time = {} + name2sub_time_end = {} + for i in tqdm(range(len(srt_paths))): + subs = pysrt.open(srt_paths[i], encoding="iso-8859-1") + if len(subs) == 0: + subs = pysrt.open(srt_paths[i]) + + text_list = [] + name_list = [] + other_list = [] + sub_time_list = [] + sub_time_list_end = [] + for j in range(len(subs)): + cur_sub = subs[j] + cur_str = cur_sub.text + cur_str = "(:)" + cur_str if cur_str[0] != "(" else cur_str + cur_str = cur_str.replace("\n", " ") + names, other_labels = self.find_names(cur_str) + text_list.append(cur_str) + name_list.append(names) + other_list.append(other_labels) + sub_time_list.append( + 60 * cur_sub.start.minutes + cur_sub.start.seconds + 0.001 * cur_sub.start.milliseconds) + sub_time_list_end.append( + 60 * cur_sub.end.minutes + cur_sub.end.seconds + 0.001 * cur_sub.end.milliseconds) + + key_str = os.path.splitext(os.path.basename(srt_paths[i]))[0] + name2sub_text[key_str] = text_list + name2sub_face[key_str] = name_list + name2sub_other[key_str] = other_list + name2sub_time[key_str] = sub_time_list + name2sub_time_end[key_str] = sub_time_list_end + if len(srt_paths) == 0: + print("season and episode combination doesn't exist.") + return + srt_data = {"sub_text": name2sub_text, + "sub_face": name2sub_face, + "sub_other": name2sub_other, + "sub_time": name2sub_time, + "sub_time_end": name2sub_time_end} + save_json(srt_data, srt_cache_path) + return load_json(srt_cache_path) # we do this because the ints will turn to string 0 -> '0' + + def load_anns(self, dataset, ann_path): + + if os.path.exists(ann_path): + print("Found annotation cache, loading ...") + return load_json(ann_path) + + anns = {} + all_faces = [] + annid = 0 + for img in tqdm(dataset.keys()): + # if there is at least one face: + if dataset[img]['faces']: + for face, bbox, face_landmark in zip(dataset[img]['faces'], dataset[img]['bbox'], dataset[img]['face_points']): + anns[annid] = {'face': face, + 'name': dataset[img]['names'], + 'img': dataset[img]['img'], + 'subtitle': dataset[img]['subtitle'], + 'clip': dataset[img]['clip'], + 'series': dataset[img]['series'], + 'bbox': bbox, + 'face_points': face_landmark} + annid += 1 + # if dataset[img]['names']: + # for name in dataset[img]['names']: + # anns[annid] = {'face': face, + # 'name': name, + # 'img': dataset[img]['img'], + # 'subtitle': dataset[img]['subtitle'], + # 'clip': dataset[img]['clip'], + # 'series': dataset[img]['series'], + # 'bbox': bbox, + # 'face_points': face_landmark} + # annid += 1 + # else: + # anns[annid] = {'face': face, + # 'name': "Unknown", + # 'img': dataset[img]['img'], + # 'subtitle': dataset[img]['subtitle'], + # 'clip': dataset[img]['clip'], + # 'series': dataset[img]['series'], + # 'bbox': bbox, + # 'face_points': face_landmark} + # annid += 1 + save_json(anns, ann_path) + return load_json(ann_path) + + + def load_ann_mentions(self, dataset, ann_path): + + if os.path.exists(ann_path): + print("Found annotation cache, loading ...") + return load_json(ann_path) + + anns = {} + all_faces = [] + annid = 0 + for img in tqdm(dataset.keys()): + # if there is at least one face: + if dataset[img]['faces']: + for face, bbox, face_landmark in zip(dataset[img]['faces'], dataset[img]['bbox'], 
dataset[img]['face_points']): + anns[annid] = {'face': face, + 'name': dataset[img]['names'], + 'img': dataset[img]['img'], + 'subtitle': dataset[img]['subtitle'], + 'clip': dataset[img]['clip'], + 'series': dataset[img]['series'], + 'bbox': bbox, + 'face_points': face_landmark} + annid += 1 + # if dataset[img]['names']: + # for name in dataset[img]['names']: + # anns[annid] = {'face': face, + # 'name': name, + # 'img': dataset[img]['img'], + # 'subtitle': dataset[img]['subtitle'], + # 'clip': dataset[img]['clip'], + # 'series': dataset[img]['series'], + # 'bbox': bbox, + # 'face_points': face_landmark} + # annid += 1 + # else: + # anns[annid] = {'face': face, + # 'name': "Unknown", + # 'img': dataset[img]['img'], + # 'subtitle': dataset[img]['subtitle'], + # 'clip': dataset[img]['clip'], + # 'series': dataset[img]['series'], + # 'bbox': bbox, + # 'face_points': face_landmark} + # annid += 1 + save_json(anns, ann_path) + return load_json(ann_path) + + + def build_dataset_mentions(self, srt_data, dataset_path): + if srt_data is None: + return + series_list = ["castle", "friends", "grey", "house", "met", "bbt"] + series_directory = self.project_dir + f"dataset/frames_hq/{series_list[1]}_frames/" + save_directory = self.project_dir + f"/dataset/{series_list[1]}_frames/" + all_mentions = set(load_json(self.project_dir + f"/dataset/hmtl/mentions/list_mentions_filtered_{self.series}.json")) # todo list or set + Path(save_directory).mkdir(parents=True, exist_ok=True) + + if os.path.exists(dataset_path): + print("Found dataset cache, loading ...") + return load_json(dataset_path) + + # mtcnn = MTCNN(keep_all=True) + i = 0 + dataset_dict = {} + for clip in tqdm(srt_data['mentions'].keys()): + for t, mentions in enumerate(srt_data['mentions'][clip]): + + intersect = set(mentions).intersection(all_mentions) + if intersect: + # print(intersect) + frame_num_begin = math.ceil(srt_data['sub_time'][clip][t] * 3) + frame_num_end = math.ceil(srt_data['sub_time_end'][clip][t] * 3) + clip_dir = series_directory + clip + + for frame_num in range(frame_num_begin, frame_num_begin + 4): + img_str = str(frame_num).zfill(5) + if os.path.exists(os.path.join(clip_dir, img_str + ".jpg")): + # make a function to draw + # img = Image.open(os.path.join(clip_dir, img_str + ".jpg")) + + dataset_dict[i] = {"series": self.series, "clip_dir": clip_dir, "clip": clip, + "img": img_str, + "subtitle": srt_data['sub_text'][clip][t], + "mentions": list(mentions), + "filtered_mentions": list(intersect), + "first_frame": True if frame_num == frame_num_begin else False} + i += 1 + + # dataset_dict = self.stitch_tokens_in_dict(dataset_dict) + save_json(dataset_dict, dataset_path) + return load_json(dataset_path) + + def build_dict(self): + ''' + This function extracts all the information needed from the "tvqa dataset" and builds the annotation files + so we could later run TVQADataset on it and load the data + :return: + ''' + for s in range(1, 11): + for i in range(1, 26): + if i < 10: + self.episode = f"0{i}" + else: + self.episode = f"{i}" + if s < 10: + self.season = f"0{s}" + else: + self.season = f"{s}" + print(f"season: {self.season}, episode: {self.episode}") + + self.subtitle_json = self.all_subtitles_loc + f"subtitle_cache_{self.series}_s{self.season}e{self.episode}.json" + srt_data = self.load_srt(self.all_subtitles_loc, self.subtitle_json) + # dataset_dict = self.prepare_tvqa_json(srt_data, dataset_path=self.project_dir + f"dataset/{self.series}_dict_s{self.season}e{self.episode}.json") + dataset_dict = 
self.build_dataset_mentions(srt_data, dataset_path=self.project_dir + f"dataset/mentions/{self.series}_dict_s{self.season}e{self.episode}.json") + dict_paths = glob.glob(os.path.join("./", f"dataset/mentions/{self.series}_dict_*.json")) + stitched_dict = {} + for i in tqdm(range(len(dict_paths))): + start = len(stitched_dict.keys()) + dict_i = load_json(dict_paths[i]) + stitched_dict.update({start+key: value for key, value in enumerate(dict_i.values())}) + + save_json(stitched_dict, f"./dataset/mentions/{self.series}_dict.json") + dataset_dict = load_json(os.path.join(self.project_dir, f"dataset/mentions/{self.series}_dict.json")) + # + # # clean_dataset_dict = self.clean_dict(dataset_dict, path=self.project_dir + f"dataset/clean_{self.series}_dict.json") + # self.anns = self.load_anns(dataset_dict, ann_path=self.project_dir + f"dataset/{self.series}_annotations.json") + + def add_missing_frames_from_test_set_to_trainset(self, dataset_path): + if os.path.exists(dataset_path): + return load_json(dataset_path) + + save_directory = self.project_dir + f"dataset/bbt_frames/" + # test_anns = load_json(f"./dataset/new_tvqa_plus_test_annotations.json") + # images_in_test = set() + # for ann in test_anns.values(): + # images_in_test.add(f"{ann['clip']}_{ann['img']}") + # save_json(list(images_in_test), "./dataset/images_in_test.json") + + images_in_test = set(load_json("./dataset/images_in_test.json")) + train_images_with_face = set(load_json("./dataset/images_with_face.json")) + mtcnn = MTCNN(keep_all=True) + i = 0 + new_dict = {} + + for test_img in tqdm(list(images_in_test)): + if test_img not in train_images_with_face: + clip, img_name = "_".join(test_img.split('_')[:-1]), test_img.split('_')[-1] + img_path = os.path.join("./", f"dataset/frames_hq/bbt_frames/{clip}/{img_name}.jpg") + + subtitle, names = find_corresponding_subtitle(clip, int(img_name)) + + img = Image.open(img_path) + boxes, probs, points = mtcnn.detect(img, landmarks=True) + faces = [] + if boxes is not None: + img_draw = img.copy() + draw = ImageDraw.Draw(img_draw) + for f, (box, point) in enumerate(zip(boxes, points)): + draw.rectangle(box.tolist(), width=5) + faces.append("{}_{}_{}.png".format(clip, img_name, f)) + box[0] = box[0] - 20 + box[1] = box[1] - 20 + box[2] = box[2] + 20 + box[3] = box[3] + 20 + extract_face(img, box, + save_path=save_directory + "{}_{}_{}.png".format(clip, img_name, f)) + if points is not None: + points = points.tolist() + if boxes is not None: + boxes = boxes.tolist() + + new_dict[i] = {"series": "bbt", + "clip": clip, + "img": img_name, + "names": names, + "subtitle": subtitle, + "faces": faces, + "face_points": points, + "bbox": boxes} + + i += 1 + + save_json(new_dict, dataset_path) + return load_json(dataset_path) + + def add_missing_frames(self, dataset_path): + save_directory = self.project_dir + f"dataset/bbt_frames/" + images_with_face = load_json("./dataset/images_with_face.json") + images_with_face = set(images_with_face) + + mtcnn = MTCNN(keep_all=True) + i = 0 + new_dict = {} + + clip_paths = glob.glob(os.path.join("./", f"dataset/frames_hq/bbt_frames/*")) + for clip_path in clip_paths: + clip = clip_path.split('/')[-1] + img_paths = glob.glob(os.path.join("./", f"dataset/frames_hq/bbt_frames/{clip}/*")) + for img_path in img_paths: + img_name = img_path.split('/')[-1] + img_name = img_name.split('.')[0] + + if clip+'_'+img_name not in images_with_face: # we found a missing face + img = Image.open(img_path) + boxes, probs, points = mtcnn.detect(img, landmarks=True) + faces = [] + if 
boxes is not None: + img_draw = img.copy() + draw = ImageDraw.Draw(img_draw) + for f, (box, point) in enumerate(zip(boxes, points)): + draw.rectangle(box.tolist(), width=5) + faces.append("{}_{}_{}.png".format(clip, img_name, f)) + box[0] = box[0] - 20 + box[1] = box[1] - 20 + box[2] = box[2] + 20 + box[3] = box[3] + 20 + extract_face(img, box, + save_path=save_directory + "{}_{}_{}.png".format(clip, img_name, f)) + if points is not None: + points = points.tolist() + if boxes is not None: + boxes = boxes.tolist() + new_dict[i] = {"series": "bbt", + "clip": clip, + "img": img_name, + "names": [], + "subtitle": "", + "faces": faces, + "face_points": points, + "bbox": boxes} + i += 1 + + save_json(new_dict, dataset_path) + return load_json(dataset_path) + + def build_evaluation_dict(self, dataset_path): + ''' + step 1: if no subtitle, discard + step 2: if obj['label'] in objects.json then it's an object -> discard + step 3: iou -> not implemented here + step 4: if obj['label'] not in fuzzy_name_matchings then it's unknown + This function loads the annotatations of "tvqa+" dataset for evaluation + :return: + ''' + series_list = ["castle", "friends", "grey", "house", "met", "bbt"] + series_directory = self.project_dir + f"dataset/frames_hq/{series_list[5]}_frames/" + save_directory = self.project_dir + f"dataset/new_tvqa_plus_{series_list[5]}_frames/" + objects = set(load_json(self.project_dir + f"dataset/objects.json")) + + Path(save_directory).mkdir(parents=True, exist_ok=True) + fuzzy_name_matchings = load_json("./dataset/fuzzy_name_matchings.json") + main_characters = [item for sublist in fuzzy_name_matchings.values() for item in sublist] + + if os.path.exists(dataset_path): + print("Found dataset cache, loading ...") + return load_json(dataset_path) + + def complete_dict(tvqa_plus_dict, i, anns): + for ann in tqdm(anns): + clip = ann['vid_name'] + for img_id, image in ann['bbox'].items(): + # step 1: if no subtitle, discard + subtitle, names = find_corresponding_subtitle(clip, int(img_id)) + if not subtitle: + continue + for obj_id, obj in enumerate(image): + # step 2: if obj['label'] in objects.json then it's an object -> discard + if obj['label'] in objects: + continue + # step 3: if iou = 0 then we assume it's an object -> discard -> this is done later + + # step 4: if obj['label'] not in fuzzy_name_matchings then it's unknown + if obj['label'] not in main_characters: + char_name = "Unknown" + else: + for char_name in fuzzy_name_matchings.keys(): + if obj['label'] in fuzzy_name_matchings[char_name]: + break + img_str = str(obj['img_id']).zfill(5) + clip_dir = series_directory + ann['vid_name'] + img = Image.open(os.path.join(clip_dir, img_str + ".jpg")) + img = img.crop((obj['left'], obj['top'], obj['left'] + obj['width'], obj['top'] + obj['height'])) + img.save(save_directory + "{}_{}_{}.png".format(ann['vid_name'], img_str, obj_id)) + tvqa_plus_dict[i] = {"series": "bbt", + "clip": ann['vid_name'], + "subtitle": subtitle, + "face": "{}_{}_{}.png".format(ann['vid_name'], img_str, obj_id), + "img": img_str, + "weak_lbls": names, + "name": char_name, #gt_label + "bbox": [obj['left'], obj['top'], obj['left'] + obj['width'], + obj['top'] + obj['height']]} + i += 1 + return tvqa_plus_dict, i + + tvqa_plus_dict = {} + i = 0 + train_anns = load_json(self.project_dir + f"dataset/tvqa+/tvqa_plus_annotations/tvqa_plus_train.json") + tvqa_plus_dict, i = complete_dict(tvqa_plus_dict, i, train_anns) + + val_anns = load_json(self.project_dir + 
f"dataset/tvqa+/tvqa_plus_annotations/tvqa_plus_val.json") + tvqa_plus_dict, i = complete_dict(tvqa_plus_dict, i, val_anns) + + save_json(tvqa_plus_dict, dataset_path) + return load_json(dataset_path) + + +def find_corresponding_subtitle(clip, img_number): + srt_data = load_json(f"./dataset/tvqa_subtitles/subtitle_cache_bbt_{clip.split('_')[0]}.json") + subtitle = "" + for t, (begin_time, end_time) in enumerate( + zip(srt_data['sub_time'][clip], srt_data['sub_time_end'][clip])): + current_time = math.floor(img_number / 3) + names = [] + if current_time >= begin_time and current_time <= end_time: + subtitle = srt_data['sub_text'][clip][t] + names = srt_data['sub_face'][clip][t] + if names: + names = [name_lbl[0] for name_lbl in names] + break + + return subtitle, names + + +def fuzzy_matching(): + + train = load_json("./tvqa+/tvqa_plus_annotations/tvqa_plus_train.json") + gt_labels = [] + for ann in train: + for image in ann['bbox'].values(): + for obj in image: + gt_labels.append(obj['label']) + + val = load_json("./tvqa+/tvqa_plus_annotations/tvqa_plus_val.json") + gt_labels_val = [] + for ann in val: + for image in ann['bbox'].values(): + for obj in image: + gt_labels_val.append(obj['label']) + + def create_hist(obj_list): + labels = set(obj_list) + num_classes = len(labels) + # print(f"number of labels: {num_classes}, The unique labels are: {labels}") + lbl_to_id = {} + for idx, lbl in enumerate(labels): + lbl_to_id[lbl] = idx + hist_bins = np.arange(num_classes + 1) + histogram = np.zeros((num_classes,), dtype=np.int) + classes = [lbl_to_id[obj] for obj in obj_list] + histogram += np.histogram(classes, bins=hist_bins)[0] + ind_sorted = np.argsort(histogram)[::-1] + id_to_lbl = {id: lbl for (lbl, id) in lbl_to_id.items()} + class_names_sorted = [id_to_lbl[ind] for ind in ind_sorted] + return class_names_sorted + + cls_names = create_hist(gt_labels) + + scores = torch.zeros([len(cls_names), len(cls_names)]) + for i1, n1 in enumerate(cls_names): + for i2, n2 in enumerate(cls_names): + scores[i1, i2] = fuzz.ratio(n1, n2) + + cls_names_pd = pd.DataFrame(cls_names, columns=['cls_names']) + scores_df = pd.DataFrame(scores) + + fuzzy_name_matchings = {} + # indices corresponding to these classes Sheldon, Leonard, Penny, Howard, Raj, Amy, Bernadette, Stuart + for i, name in zip([0, 1, 2, 3, 4, 5, 6, 15],["Sheldon", "Leonard", "Penny", "Howard", "Raj", "Amy", "Bernadette", "Stuart"]): + fuzzy_name_matchings[name] = set(cls_names_pd.loc[scores_df[i] > 70, "cls_names"]) + + cls_names_val = create_hist(gt_labels_val) + scores_val = torch.zeros([len(cls_names_val), len(cls_names_val)]) + for i1, n1 in enumerate(cls_names_val): + for i2, n2 in enumerate(cls_names_val): + scores_val[i1, i2] = fuzz.ratio(n1, n2) + + cls_names_val_pd = pd.DataFrame(cls_names_val, columns=['cls_names_val']) + scores_val_df = pd.DataFrame(scores_val) + + # indices corresponding to these classes Sheldon, Leonard, Penny, Howard, Raj, Amy, Bernadette, Stuart + for i, name in zip([0, 1, 2, 3, 4, 5, 6, 15], ["Sheldon", "Leonard", "Penny", "Howard", "Raj", "Amy", "Bernadette", "Stuart"]): + fuzzy_name_matchings[name].update(set(cls_names_val_pd.loc[scores_val_df[0] > 70, "cls_names_val"])) + + # you need to check them by hand if they make sense + save_json({key: list(value) for key, value in fuzzy_name_matchings.items()}, "./dataset/fuzzy_name_matchings.json") + + +def split_train_test_episodes(): + train_dataset = TVQADataset(series="friends", split="train", transform=get_test_transforms(series="friends")) + test_dataset = 
TVQADataset(series="friends", split="test", transform=get_test_transforms(series="friends")) + train_friends = {} + test_friends = {} + i, j = 0, 0 + for train_ann, test_ann in zip(train_dataset.anns.values(), test_dataset.anns.values()): + if train_ann['clip'].split("_")[1].endswith('5'): + test_friends[i] = {**train_ann, **test_ann} + i += 1 + else: + train_friends[j] = {**train_ann, **test_ann} + j += 1 + save_json(train_friends, "./dataset/friends_train_annotations.json") + save_json(test_friends, "./dataset/friends_test_annotations.json") + + +if __name__ == "__main__": + # transform = transforms.Compose([transforms.ToTensor()]) + # transform = get_transforms() + # dataset = TwoAugUnsupervisedTVQADataset(split="train", transform=transform) + # dataset = TVQADataset(split="test", transform=transforms.Compose([transforms.ToTensor()])) + + # dataset = TVQADataset(series="bbt", split="train", transform=transforms.Compose([transforms.ToTensor()])) + # mean, std = find_norm_data(dataset) + # print(mean, std) + # dataset = BuildTVQADataset() + import shutil + img_paths = glob.glob(os.path.join("./", "dataset/new_tvqa_plus_bbt_frames/*")) + mtcnn = MTCNN(keep_all=True) + for img_path in tqdm(img_paths): + img = Image.open(img_path) + try: + boxes, probs, points = mtcnn.detect(img, landmarks=True) + if boxes is not None: + if len(boxes) != 1: + print(f"image {img_path} has {len(boxes)} faces") + img_draw = img.copy() + draw = ImageDraw.Draw(img_draw) + for f, (box, point) in enumerate(zip(boxes, points)): + draw.rectangle(box.tolist(), width=5) + box[0] = box[0] - 20 + box[1] = box[1] - 20 + box[2] = box[2] + 20 + box[3] = box[3] + 20 + extract_face(img, box, save_path=f"./dataset/new_tvqa_plus_bbt_frames_onlyfaces/{img_path.split('/')[-1]}") + else: + print(f"image{img_path} has no faces") + shutil.copy(img_path, f"./dataset/new_tvqa_plus_bbt_frames_onlyfaces/{img_path.split('/')[-1]}") + except Exception as e: + print(e) + print(f"image {img_path} is not valid") + shutil.copy(img_path, f"./dataset/new_tvqa_plus_bbt_frames_onlyfaces/{img_path.split('/')[-1]}") + continue + + + + # face_to_train_idx = {} + # for idx, ann in train_dataset.anns.items(): + # if ann['face'] in face_to_train_idx: + # print("face already existed") + # face_to_train_idx[ann['face']] = idx + + + # old = load_json("./dataset/bbt_train_annotations.json") + # new = load_json("./dataset/compare_with_bbt_train_annotations.json") + # for (o_i, o) in old.items(): + # if int(o_i) <= 112672: + # o['name'] = new[o_i]['name'] + # else: + # for (n_i, n) in new.items(): + # if int(n_i) <= 112672: + # continue + # if o['face'] not in ["s08e20_seg02_clip_08_00108_0.png", "s08e23_seg01_clip_00_00153_0.png", + # "s08e23_seg01_clip_00_00153_1.png", "s08e23_seg01_clip_00_00153_2.png"]: + # if o['face'] == n['face']: + # o['name'] = n['name'] + # break + # old['114324']['name'] = [] + # old['114325']['name'] = [] + # old['114326']['name'] = [] + # print(len(old)) + # save_json(old, "./dataset/bbt_train_annotations__.json") + # import torch + # new_train_anns = load_json("./dataset/bbt_train_annotations.json") + # old_embs = torch.load("./output/evaluate_bbt/model/new_bbt_face_embeddings.pt") + # for iteration, ann in tqdm(enumerate(new_train_anns.values())): + # if iteration == 0: + # new_embs = old_embs[int(ann['old_train_indices'][0])] + # else: + # new_embs = torch.cat((new_embs, old_embs[int(ann['old_train_indices'][0])]), 0) + # torch.save(new_embs, "./dataset/bbt_faceeee_embeddings.pt") + # em = 
torch.load("./dataset/bbt_faceeee_embeddings.pt") + # em = em.reshape((272603, 512)) + # torch.save(em, "./dataset/bbt_face_embeddings_hopefully_correct.pt") + # splits = ["test", "dev"] + # splits = ["train"] + # lbl_types = ["name", "target_name"] + # for lbl_type in lbl_types: + # for split in splits: + # tvqa = TVQADataset(lbl_type=lbl_type, split=split, episode=None, transform=transform) + # tvqa.create_hist(f"{split}_hist_{lbl_type}", lbl_type=lbl_type) + + # train_tvqa = TVQADataset(split="train", transform=transform) + # train_tvqa = TVQADataset(lbl_type="name", split="train", episode=None, transform=transform) + # train_tvqa.create_hist("train_hist", lbl_type="name") + # dev_tvqa = TVQADataset(split="dev", transform=transform) + # dev_tvqa.create_hist("dev_hist", lbl_type="target_name") + # test_tvqa = TVQADataset(split="test", transform=transform) + # test_tvqa.create_hist("test_hist", lbl_type="target_name") + + # function: copy old target label annotations into new annonation file (also had to annotate some with hand) + # old_ann_loc = "./dataset/2/2_all_annotations.json" + # old_anns = load_json(old_ann_loc) + # + # for i, new_ann in enumerate(tvqa.anns.values()): + # for old_ann in old_anns.values(): + # if old_ann['face'] == new_ann['face']: + # new_ann["target_name"] = old_ann["target_name"] + # new_ann["target_id"] = old_ann["target_id"] + # break + # save_json(tvqa.anns, "./dataset/blabla.json") diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..edb2e91 --- /dev/null +++ b/utils.py @@ -0,0 +1,24 @@ +import os +import yaml + + +def yaml_config_hook(config_file): + """ + Custom YAML config loader, which can include other yaml files (I like using config files + insteaad of using argparser) + """ + + # load yaml files in the nested 'defaults' section, which include defaults for experiments + with open(config_file) as f: + cfg = yaml.safe_load(f) + for d in cfg.get("defaults", []): + config_dir, cf = d.popitem() + cf = os.path.join(os.path.dirname(config_file), config_dir, cf + ".yaml") + with open(cf) as f: + l = yaml.safe_load(f) + cfg.update(l) + + if "defaults" in cfg.keys(): + del cfg["defaults"] + + return cfg diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000..ed4f848 --- /dev/null +++ b/visualize.py @@ -0,0 +1,1770 @@ +import torch +import os +from PIL import Image, ImageDraw +from facenet_pytorch import MTCNN, extract_face, InceptionResnetV1 +import random +import torchvision.models as models +from torchvision.models.resnet import BasicBlock +import torch.optim as optim +from torchvision import transforms +from tvqa_dataset import TVQADataset, load_json, save_json, get_test_transforms +from train import VGGFacePlus, VGGFaceSubtitle, VGGSupervised +import json +from torch.utils.data import DataLoader +from pathlib import Path +from tqdm import tqdm +import xlsxwriter +import pandas as pd +import numpy as np +from sklearn.manifold import TSNE +from sklearn.cluster import AgglomerativeClustering, KMeans, MiniBatchKMeans +import matplotlib.pyplot as plt +import seaborn as sns +import matplotlib.patheffects as PathEffects +import openpyxl +from openpyxl_image_loader import SheetImageLoader +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import math +import torch.nn.functional as F +import umap +import umap.plot +from matplotlib.colors import LinearSegmentedColormap +from sklearn.metrics import confusion_matrix, classification_report +import logging +import argparse +from simcse import SimCSE +from train import 
L2Norm +from config import default_argument_parser, get_cfg, set_global_cfg, global_cfg +from fvcore.common.file_io import PathManager +from sentence_transformers import SentenceTransformer +from matplotlib.patches import Ellipse + + +def weighted_purity(Y, C): + """Computes weighted purity of HAC at one particular clustering "C". + Y, C: np.array([...]) containing unique cluster indices (need not be same!) + Note: purity --> 1 as the number of clusters increase, so don't look at this number alone! + """ + + purity = 0. + uniq_clid, clustering_skew = np.unique(C, return_counts=True) + num_samples = np.zeros(uniq_clid.shape) + # loop over all predicted clusters in C, and measure each one's cardinality and purity + for k in uniq_clid: + # gt labels for samples in this cluster + k_gt = Y[np.where(C == k)[0]] + values, counts = np.unique(k_gt, return_counts=True) + # technically purity = max(counts) / sum(counts), but in WCP, the sum(counts) multiplies to "weight" the clusters + purity += max(counts) + + purity /= Y.shape[0] + return purity, clustering_skew + + +def NMI(Y, C): + """Normalized Mutual Information: Clustering performance between ground-truth Y and prediction C + Based on https://course.ccs.neu.edu/cs6140sp15/7_locality_cluster/Assignment-6/NMI.pdf + Result matches examples on pdf + Example: + Y = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) + C = np.array([1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2]) + NMI(Y, C) = 0.1089 + C = np.array([1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]) + NMI(Y, C) = 0.2533 + """ + + def entropy(labels): + # H(Y) and H(C) + H = 0. + for k in np.unique(labels): + p = (labels == k).sum() / labels.size + H -= p * np.log2(p) + return H + + def h_y_given_c(labels, pred): + # H(Y | C) + H = 0. + for c in np.unique(pred): + p_c = (pred == c).sum() / pred.size + labels_c = labels[pred == c] + for k in np.unique(labels_c): + p = (labels_c == k).sum() / labels_c.size + H -= p_c * p * np.log2(p) + return H + + h_Y = entropy(Y) + h_C = entropy(C) + h_Y_C = h_y_given_c(Y, C) + # I(Y; C) = H(Y) - H(Y|C) + mi = h_Y - h_Y_C + # NMI = 2 * MI / (H(Y) + H(C)) + nmi = 2 * mi / (h_Y + h_C) + return nmi + + +def to_1D(series): + return pd.Series([x.item() for _list in series for x in _list]) + + +class SaveOutput: + """ + Utility function to visualize the outputs of PCA and t-SNE + """ + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_in, module_out): + self.outputs.append(module_out) + + def clear(self): + self.outputs = [] + + +def get_children(model: torch.nn.Module): + # get children form model! + children = list(model.children()) + flatt_children = [] + if children == []: + # if model has no children; model is last child! :O + return model + else: + # look for children from children... to the last child! 
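+            # Recurse into every child; leaf modules are appended directly, so the
+            # result is a flat list of leaf layers. calculate_embeddings() later
+            # registers a forward hook on each of these leaves (resnet_pretrained /
+            # facenet_reclassified modes) to capture intermediate activations.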
+ for child in children: + try: + flatt_children.extend(get_children(child)) + except TypeError: + flatt_children.append(get_children(child)) + return flatt_children + + +def discrete_cmap(N, base_cmap=None): + """Create an N-bin discrete colormap from the specified input map""" + + # Note that if base_cmap is a string or None, you can simply do + # return plt.cm.get_cmap(base_cmap, N) + # The following works for string, None, or a colormap instance: + if base_cmap is None: + return plt.cm.get_cmap(base_cmap, N) + + base = plt.cm.get_cmap(base_cmap) + color_list = base(np.linspace(0, 1, N)) + cmap_name = base.name + str(N) + return LinearSegmentedColormap.from_list(cmap_name, color_list, N) + + +def make_excel(model, dataset, source, data_path, path_to_faces, file_name): + """ + make_excel(model=model, dataset=tvqa_all, source="dataloader", + # data_path="./dataset/frames_hq/friends_frames/", + # path_to_faces="./dataset/friends_frames/", file_name="new_all_data.xlsx") + make_excel(model=model, dataset=tvqa_test, source="dataloader", file_name="test_data.xlsx") + make_excel(model=model, dataset=tvqa_all, source="dataloader", file_name="all_data.xlsx") + :param model: + :param dataset: + :param source: + :param file_name: + :return: + """ + workbook = xlsxwriter.Workbook(f"./dataset/excel/{file_name}") + worksheet = workbook.add_worksheet() + + if source == "dataloader": + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + worksheet.write('A1', 'weak label') + worksheet.write('B1', 'correct label') + worksheet.write('C1', 'prediction') + worksheet.write('D1', 'week=correct') + worksheet.write('E1', 'predict=correct') + worksheet.write('F1', 'image') + worksheet.write('G1', 'face') + worksheet.write('H1', 'image_loc') + worksheet.write('I1', 'face_loc') + + for iteration, data in tqdm(enumerate(data_loader)): + # predictions = model(data['image']) + ann = dataset.anns[str(iteration)] + # print(os.path.join(path_to_faces, ann["face"])) + img = Image.open(os.path.join(path_to_faces, ann["face"])) + draw = ImageDraw.Draw(img) + draw.text((0, 0), f'label: {ann["name"]}', (255, 255, 255)) + # draw.text((0, 10), f'prediction: {dataset.id_to_lbs[int(predictions.argmax())]}', (255, 255, 255)) + # img.save(os.path.join(vis_path, ann["face"])) + worksheet.write(f'A{iteration+2}', ann["name"]) #weak label + # worksheet.write(f'C{iteration+2}', dataset.id_to_lbs[int(predictions.argmax())]) #prediction + face = ann["face"] + worksheet.insert_image(f'F{iteration+2}', os.path.join(data_path, face[0:28]+f"/{face.split('_')[-2]}.jpg")) #image + # face: friends_s01e01_seg02_clip_03_00108_0.png + # tvqa_experiment/dataset/frames_hq/friends_frames/friends_s01e01_seg02_clip_03/00108.jpg + worksheet.insert_image(f'G{iteration + 2}', os.path.join(path_to_faces, ann["face"])) # face + worksheet.write(f'H{iteration + 2}', face[0:28]+f"/{face.split('_')[-2]}.jpg") # image_loc + worksheet.write(f'I{iteration + 2}', ann["face"]) # face_loc + + workbook.close() + + +def make_excel_for_test(new_anns): + import xlsxwriter + data_path = "./dataset/frames_hq/friends_frames/" + path_to_faces = "./dataset/friends_frames/" + workbook = xlsxwriter.Workbook(f"./dataset/excel/s01e01_new_new.xlsx") + worksheet = workbook.add_worksheet() + worksheet.write('A1', 'weak label') + worksheet.write('B1', 'correct label') + worksheet.write('C1', 'prediction') + worksheet.write('D1', 'week=correct') + worksheet.write('E1', 'predict=correct') + worksheet.write('F1', 'image') + worksheet.write('G1', 'face') + 
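# Sheet layout: columns A-E hold the weak label, the hand-corrected label, the
+    # model prediction and the two agreement flags ('week=correct' presumably
+    # means weak label == correct label); F-G embed the frame and the face crop,
+    # and H-I store their relative paths so rows can be matched back to files. +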
worksheet.write('H1', 'image_loc') + worksheet.write('I1', 'face_loc') + old_anns = load_json("./dataset/all_annotations.json") + for iteration, ann in enumerate(new_anns.values()): + worksheet.write(f'A{iteration + 2}', "".join(ann["name"])) # weak label + face = ann["face"] + worksheet.insert_image(f'F{iteration + 2}', + os.path.join(data_path, face[0:28] + f"/{face.split('_')[-2]}.jpg")) + worksheet.insert_image(f'G{iteration + 2}', os.path.join(path_to_faces, ann["face"])) # face + worksheet.write(f'H{iteration + 2}', face[0:28] + f"/{face.split('_')[-2]}.jpg") # image_loc + worksheet.write(f'I{iteration + 2}', ann["face"]) # face_loc + for old_ann in old_anns.values(): + if ann["face"] == old_ann["face"]: + worksheet.write(f'B{iteration + 2}', old_ann["target_name"]) # face_loc + workbook.close() + + +def visualize_embeddings(model, dataset, file_name, method="tsne", model_mode="facenet_pretrained", normalize=False, mode="nothing", preload=False): + """ + # visualize_embeddings(model=model, dataset=tvqa_all, file_name="all_data", method="tsne", model_mode=model_mode, normalize=False, mode="nothing", preload=True) + build_embeddings(model=model, dataset=tvqa_all, exp=exp, file_name="all_data", method="umap") + :param model: pretrained model that generates face embeddings + :param dataset: pointing to the data : tvqa_all, tvqa_train, tvqa_test, tvqa_val + :param file_name: excel file name -> all_data + :param method: tsne or umap + :param model_mode: facenet_pretrained or resnet_pretrained + :param normalize: whether to normalize embeddings or not + :param mode: pictures, text or nothing + :return: saves files of such visualizations + """ + dataset_loc = "./dataset/excel" + + embeddings = calculate_embeddings(model, dataset, model_mode, normalize=normalize, preload=preload) + if method == "tsne": + # dists = [[(e1 - e2).norm().item() for e2 in embeddings] for e1 in embeddings] + # print(pd.DataFrame(dists)) #, columns=names, index=names?? 
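+        # Sketch of the t-SNE branch: the face embeddings are projected to a 2-D
+        # grid with t-SNE, an 8-cluster AgglomerativeClustering over the raw
+        # embeddings gives one colouring of the points, and the same grid is then
+        # re-plotted with the corrected labels and the model predictions read from
+        # the Excel sheet. With mode="picture", face thumbnails are pulled from
+        # column G of the sheet via SheetImageLoader.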
+ print("TSNE is being calculated...") + tsne_grid = TSNE(random_state=10, n_iter=2000).fit_transform(embeddings.detach().numpy()) + print("TSNE is calculated!") + df = pd.read_excel(f"{dataset_loc}/{file_name}.xls", usecols="A,B,C") # weak label, correct label, prediction + workbook = openpyxl.load_workbook(f"{dataset_loc}/{file_name}.xlsx") + sheet = workbook['Sheet1'] + image_loader = SheetImageLoader(sheet) + hac8 = AgglomerativeClustering(n_clusters=8).fit_predict(embeddings.detach().numpy()) + + dataset.lbl_to_id['Unknown'] = 50 + dataset.id_to_lbs[50] ='Unknown' + + # hac8 = [dataset.id_to_lbs[predicted_lbl] for predicted_lbl in hac8] + fig, plt = visualize_tsne(tsne_grid, hac8, dataset.id_to_lbs, image_loader, mode=mode) + fig.savefig(os.path.join("", f"hac8_clusteringpredictions_{file_name}_{mode}pictures_{model_mode}.pdf")) + plt.clf() + + correct_ids = [dataset.lbl_to_id[lbl] for lbl in df['correct label'][:-1].values.tolist()] + fig, plt = visualize_tsne(tsne_grid, correct_ids, dataset.id_to_lbs, image_loader, mode=mode) + fig.savefig(os.path.join("", f"tsne_corrects_{file_name}_{mode}pictures_{model_mode}.pdf")) + plt.clf() + + prediction_ids = [dataset.lbl_to_id[lbl] for lbl in df['prediction'][:-1].values.tolist()] + fig, plt = visualize_tsne(tsne_grid, prediction_ids, dataset.id_to_lbs, image_loader, mode=mode) + fig.savefig(os.path.join("", f"tsne_predictions_{file_name}_{mode}pictures_{model_mode}.pdf")) + plt.clf() + + elif method == "umap": + import matplotlib.pyplot as plt + df = pd.read_excel(f"{dataset_loc}/{file_name}.xls", usecols="A,B,C") + mapper = umap.UMAP().fit(embeddings.detach().numpy()) + colors = df['correct label'][:-1] + targets = np.asarray([lbl for lbl in colors]) + fig, ax = plt.subplots() + umap.plot.points(mapper, labels=targets, ax=ax) + fig.savefig(f"umap_{file_name}_{model_mode}.pdf") + + +def calculate_embeddings(model, dataset, emb_path, model_mode="facenet_pretrained", normalize=False, preload=False): + # emb_path = "./dataset/embeddings.pt" + if preload is True: + print("Preloading from existing embedding.pt file!") + return torch.load(emb_path) + print("Calculating the embeddings and saving them in embedding.pt file!") + # data_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, num_workers=0) # todo batch size + data_loader = DataLoader(dataset, batch_size=4096, shuffle=False, num_workers=0) + flattened_model = get_children(model) + if model_mode == "resnet_pretrained" or model_mode == "facenet_reclassified": + save_output = SaveOutput() + hook_handles = [] + for layer in flattened_model: + handle = layer.register_forward_hook(save_output) + hook_handles.append(handle) + with torch.no_grad(): + for iteration, data in tqdm(enumerate(data_loader)): + model(data[0]['image'][None, :, :, :]) + + if model_mode == "resnet_pretrained": + temp_emb = save_output.outputs[len(save_output.outputs) - 2] # embeddings from the layer before logits + elif model_mode == "facenet_reclassified": + temp_emb = save_output.outputs[-1] # embeddings from the layer before loss layer + if iteration == 0: + embeddings = temp_emb + else: + embeddings = torch.cat((embeddings, temp_emb), 0) + save_output.clear() + + + elif model_mode == "facenet_pretrained": + with torch.no_grad(): + for iteration, data in tqdm(enumerate(data_loader)): + if iteration == 0: + # embeddings = model(data[0]['image'][None, :, :, :]) + embeddings = model(data['image'][:, :, :]) + else: + # embeddings = torch.cat((embeddings, model(data[0]['image'][None, :, :, :])), 0) + 
embeddings = torch.cat((embeddings, model(data['image'][:, :, :])), 0) + + if normalize: + embeddings = F.normalize(embeddings, p=2, dim=1) + + torch.save(embeddings, emb_path) + return embeddings + + +def calculate_sentence_embeddings(dataset, emb_path, normalize=False, preload=False): + if preload is True: + print(f"Preloading from existing {emb_path} file!") + return torch.load(emb_path) + print("Calculating the embeddings and saving them in sentence_embedding.pt file!") + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x: x) + model = SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2') + # model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased") + + with torch.no_grad(): + for iteration, data in tqdm(enumerate(data_loader)): + # weak_lbls = data[0]['name'] + + sub_emb = torch.from_numpy(model.encode(data[0]['subtitle'])[None, :]) + # if weak_lbls: + # weak_lbls_embedding = model.encode(' '.join(weak_lbls))[None, :] + # else: + # weak_lbls_embedding = sub_emb + # + # sub_emb = torch.cat((sub_emb, weak_lbls_embedding), 1) + + if iteration == 0: + all_embeddings = sub_emb + else: + all_embeddings = torch.cat((all_embeddings, sub_emb), 0) + + if normalize: + all_embeddings = F.normalize(all_embeddings, p=2, dim=1) + + torch.save(all_embeddings, emb_path) + return all_embeddings + + +def calculate_joint_embeddings(model, dataset, emb_path, model_mode="facenet_pretrained", normalize=False, preload=False): + # emb_path = "./dataset/embeddings.pt" + if preload is True: + print("Preloading from existing embedding.pt file!") + return torch.load(emb_path) + print("Calculating the embeddings and saving them in embedding.pt file!") + # data_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, num_workers=0) # todo batch size + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x: x) + flattened_model = get_children(model) + + save_output = SaveOutput() + hook_handles = [] + for layer in flattened_model: + handle = layer.register_forward_hook(save_output) + hook_handles.append(handle) + with torch.no_grad(): + for iteration, data in tqdm(enumerate(data_loader)): + model(data[0]['image'][None, :, :, :], None, data[0]['subtitle'], None) + + temp_emb = save_output.outputs[-1] # embeddings from the layer before loss layer + if iteration == 0: + embeddings = temp_emb + else: + embeddings = torch.cat((embeddings, temp_emb), 0) + save_output.clear() + + if normalize: + embeddings = F.normalize(embeddings, p=2, dim=1) + + torch.save(embeddings, emb_path) + return embeddings + + +#todo: https://jakevdp.github.io/PythonDataScienceHandbook/05.12-gaussian-mixtures.html +def draw_ellipse(position, covariance, ax=None, **kwargs): + """Draw an ellipse with a given position and covariance""" + # ax = ax or plt.gca() + + # Convert covariance to principal axes + if covariance.shape == (2, 2): + U, s, Vt = np.linalg.svd(covariance) + angle = np.degrees(np.arctan2(U[1, 0], U[0, 0])) + width, height = 2 * np.sqrt(s) + else: + angle = 0 + # width, height = 2 * np.sqrt(covariance) + width = 2 * np.sqrt(covariance) + height = width + + # Draw the Ellipse + # for w, h in zip(width, height): + # ax.add_patch(Ellipse(position, w, h, angle, **kwargs)) + for nsig in range(1, 4): + ax.add_patch(Ellipse(position, nsig * width, nsig * height, angle, **kwargs)) + + +def plot_gmm(gmm, labels, tsne_grid, label=True): + fig = plt.figure(figsize=(8, 8)) + ax = plt.subplot(aspect='equal') + + # ax = ax or 
plt.gca() + # labels = gmm.fit(X).predict(X) + # if label: + # ax.scatter(tsne_grid[:, 0], tsne_grid[:, 1], c=labels, s=40, cmap='viridis', zorder=2) + # else: + # ax.scatter(tsne_grid[:, 0], tsne_grid[:, 1], s=40, zorder=2) + # ax.axis('equal') + + w_factor = 0.2 / gmm.weights_.max() + for pos, covar, w in zip(gmm.means_, gmm.covariances_, gmm.weights_): + draw_ellipse(pos, covar, ax=ax, alpha=w * w_factor) + + return fig, plt + + +def visualize_tsne(tsne_grid, label_ids, id_to_lbl, image_loader=None, mode="nothing"): + + num_classes = len(id_to_lbl) + # convert to pandas + label_ids = pd.DataFrame(label_ids, columns=['label'])['label'] + # create a scatter plot. + fig = plt.figure(figsize=(8, 8)) + ax = plt.subplot(aspect='equal') + if not label_ids.isnull().values.any(): + plt.scatter(tsne_grid[:, 0], tsne_grid[:, 1], lw=0, s=40, c=np.asarray(label_ids), + cmap=discrete_cmap(num_classes, "tab10")) + # , c = palette[np.asarray([lbl_to_id[lbl] for lbl in colors])] + # c = np.random.randint(num_classes, size=len(tsne_grid[:, 1])) + else: + plt.scatter(tsne_grid[:, 0], tsne_grid[:, 1], lw=0, s=40) + plt.xlim(-25, 25) + plt.ylim(-25, 25) + cbar = plt.colorbar(ticks=range(num_classes)) + cbar.set_ticklabels(list(id_to_lbl.values())) + plt.clim(-0.5, num_classes - 0.5) + ax.axis('off') + ax.axis('tight') + + if mode == "picture": + max_dim = 16 + for i, (x, y) in enumerate(tsne_grid): + print(i, x, y) + tile = image_loader.get(f'G{i+2}') + rs = max(1, tile.width/max_dim, tile.height/max_dim) + tile = tile.resize((int(tile.width/rs), int(tile.height/rs)), Image.ANTIALIAS) + imagebox = OffsetImage(tile) #, zoom=0.2) + ab = AnnotationBbox(imagebox, (x, y), pad=0.1) + ax.add_artist(ab) + + if mode == "text": + # add the labels for each digit corresponding to the label + if not label_ids.isnull().values.any(): + txts = [] + for id, lbl in id_to_lbl.items(): + # Position of each label at median of data points. 
+ xtext, ytext = np.median(tsne_grid[np.asarray(label_ids) == id, :], axis=0) + if math.isnan(xtext) or math.isnan(ytext): # this label does not exist in this set + continue + txt = ax.text(xtext, ytext, lbl, fontsize=10, zorder=100) + txt.set_path_effects([ + PathEffects.Stroke(linewidth=2, foreground="w"), + PathEffects.Normal()]) + txts.append(txt) + + return fig, plt + + +def evaluate(model, dataset, model_mode): + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + + with torch.no_grad(): + correct_target_names = [] + correct_target_ids = [] + target_ids = [] + predictions = [] + for data in tqdm(data_loader): + # target_ids.append(data['weak_label']) + correct_target_names.extend(data['correct_target_name']) + correct_target_ids.append(data['correct_target_id']) + if model_mode == "facenet_pretrained": + predictions = np.nan + continue + else: + prediction = model(data['image']) + predictions.append(int(prediction.argmax())) + + # flatten list of lists + # target_ids = [item.item() for sublist in target_ids for item in sublist] + correct_target_ids = [item.item() for sublist in correct_target_ids for item in sublist] + + results = pd.DataFrame({'correct_target_name': correct_target_names, + 'correct_target_id': correct_target_ids, + 'model_prediction': predictions, + # 'weak_label': target_ids, + # 'max_cluster_prediction': np.nan, + }) + + return results + + +def clustering_with_gmm(results, face_embeddings, tsne_grid, n_clusters=7): + from sklearn import mixture + # results["gmm"] = np.nan + model = mixture.GaussianMixture(n_components=n_clusters, covariance_type='spherical') + # a = model.fit(face_embeddings) + # labels = a.predict(face_embeddings) + labels = model.fit(tsne_grid).predict(tsne_grid) + results["gmm"] = labels + # probs = model.predict_proba(face_embeddings) + plt, fig = plot_gmm(model, labels, tsne_grid) + return results, plt, fig + + +def build_graph(embeddings, train_dataset, neighbours=3385, no_edge=True): + from sklearn.neighbors import kneighbors_graph + train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0) + knn_graph = kneighbors_graph(embeddings, neighbours, include_self=True) # make the full neighbour graph + knn_graph = knn_graph.toarray() + + if no_edge: + images = [] + with torch.no_grad(): + for i, data in enumerate(train_loader): + img = '_'.join(data["face"][0].split("_")[:-1]) + images.append(img) + + for i1, img1 in enumerate(images): + for i2, img2 in enumerate(images): + if img1 == img2 and i1 != i2: + knn_graph[i1][i2] = 0.0 + + return knn_graph + + +def predict_with_clustering(results, embeddings, n_clusters, knn_graph=None, pred_mode="max_cluster_prediction"): + # from sklearn import mixture + # model = mixture.GaussianMixture(n_components=n_clusters, covariance_type='spherical') + # labels = model.fit(embeddings).predict(embeddings) + # results["gmm"] = labels + # results[pred_mode] = np.nan + embeddings = embeddings.detach().numpy() #.astype('float32') + # print(embeddings.dtype) + # logger.info(f"type of embeddings is {embeddings.dtype}") + if global_cfg.TRAINING.clustering == "AgglomerativeClustering": + hac8_id = AgglomerativeClustering(n_clusters=n_clusters).fit_predict(embeddings) + elif global_cfg.TRAINING.clustering == "KMeans": + hac8_id = KMeans(n_clusters=n_clusters).fit_predict(embeddings) + elif global_cfg.TRAINING.clustering == "MiniBatchKMeans": + hac8_id = MiniBatchKMeans(batch_size=global_cfg.TRAINING.kmeans_batch_size, n_clusters=n_clusters).fit_predict(embeddings) + 
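+    # Majority vote (below): every sample in cluster k is assigned the weak label
+    # id that occurs most often among k's members, e.g. a cluster whose faces
+    # carry weak labels [Ross, Ross, Rachel] is predicted "Ross" for all three.
+    # 'direct' keeps samples that have exactly one weak label, and 'M2' overrides
+    # the cluster vote with that direct label wherever it exists.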
+ # import pdb + # pdb.set_trace() + # if knn_graph is None: + # hac8_id = AgglomerativeClustering(n_clusters=n_clusters).fit_predict(embeddings) + # else: + # hac8_id = AgglomerativeClustering(n_clusters=n_clusters, connectivity=knn_graph).fit_predict(embeddings.detach().numpy()) + hac8_id = pd.DataFrame(hac8_id, columns=['label']) + + for i in range(n_clusters): + # true for all indices equal to that cluster + # if np.all(results.loc[hac8_id['label'] == 18, "weak_label_ids"]) == []: #todo: what have i done here?? + # results.loc[hac8_id['label'] == i, pred_mode] = 0.6 #unknown + # else: + results.loc[hac8_id['label'] == i, pred_mode] = to_1D(results.loc[hac8_id['label'] == i, "weak_label_ids"]).value_counts().idxmax() + + results[pred_mode] = pd.to_numeric(results[pred_mode], downcast='integer') + + results['direct'] = results['weak_label_ids'] + results['direct'] = results['direct'].apply(lambda x: x[0].item() if len(x) == 1 else np.nan) + results['M2'] = results[pred_mode] + results.loc[results['direct'].notnull(), 'M2'] = results.loc[results['direct'].notnull(), 'direct'] + results['M2'] = pd.to_numeric(results['M2'], downcast='integer') + + clustering_ids = hac8_id + + return results, clustering_ids + + +def calc_accuracies(results, mode="max_cluster_prediction"): + correct = (results[mode] == results["correct_target_id"]).value_counts().loc[True] + incorrect = (results[mode] == results["correct_target_id"]).value_counts().loc[False] + accuracy = correct / (correct + incorrect) + # from sklearn.metrics import classification_report, accuracy_score + # or accuracy_score(results['correct_target_id'], results['max_cluster_prediction']) + return accuracy + + +def calc_accuracies_bbt(gt_id, predictions_id): + correct = np.count_nonzero(predictions_id == gt_id) + incorrect = len(predictions_id) - np.count_nonzero(predictions_id == gt_id) + accuracy = correct / (correct + incorrect) + # from sklearn.metrics import classification_report, accuracy_score + # or accuracy_score(results['correct_target_id'], results['max_cluster_prediction']) + return accuracy + + +def calc_per_class_prec_recall(results, mode="max_cluster_prediction"): + return classification_report(results['correct_target_id'], results[mode], digits=3) + + +def calc_per_class_prec_recall_bbt(train_dataset, gt_id, predictions_id): + return classification_report(gt_id, predictions_id, labels=list(train_dataset.lbl_to_id.keys()), digits=3) + + +def calc_per_class_accuracy(dataset, results, mode="max_cluster_prediction"): + cm = confusion_matrix(results['correct_target_id'], results[mode]) + num_classes = len(dataset.lbl_to_id.keys()) + plt.imshow(cm, cmap='plasma', interpolation='nearest') + + plt.xticks(range(num_classes), np.array(list(dataset.lbl_to_id.keys())), rotation=90, fontsize=6) + plt.yticks(range(num_classes), np.array(list(dataset.lbl_to_id.keys())), rotation=0, fontsize=6) + + # Plot a colorbar with label. + cb = plt.colorbar() + cb.set_label("Number of predictions") + + # Add title and labels to plot. 
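+    # The raw-count confusion matrix is plotted and saved first; further below it
+    # is row-normalised (cm / cm.sum(axis=1)), so the diagonal that is returned is
+    # the per-class accuracy (recall) rather than absolute counts.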
+ plt.title("Confusion Matrix for predictions and correct labels") + plt.xlabel('Correct Label') + plt.ylabel('Predicted Label') + plt.savefig('confusion_matrix_upperbound.pdf') + plt.clf() + + + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + return cm.diagonal(), dataset.id_to_lbs + + +def calc_per_class_accuracy_bbt(dataset, gt_names, predictions_names): + cm = confusion_matrix(gt_names, predictions_names, labels=list(dataset.lbl_to_id.keys())) + num_classes = len(dataset.lbl_to_id.keys()) + plt.imshow(cm, cmap='plasma', interpolation='nearest') + + plt.xticks(range(num_classes), np.array(list(dataset.lbl_to_id.keys())), rotation=90, fontsize=6) + plt.yticks(range(num_classes), np.array(list(dataset.lbl_to_id.keys())), rotation=0, fontsize=6) + + # Plot a colorbar with label. + cb = plt.colorbar() + cb.set_label("Number of predictions") + + # Add title and labels to plot. + plt.title("Confusion Matrix for predictions and correct labels") + plt.xlabel('Correct Label') + plt.ylabel('Predicted Label') + plt.savefig('confusion_matrix_upperbound.pdf') + plt.clf() + + + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + return cm.diagonal(), dataset.id_to_lbs + + +def exp_num_clusters(cluster_range, results, embeddings, test_dataset): + accuracies_per_cluster = [] + clusters = [] + for n in cluster_range: + results, clustering_ids = predict_with_clustering(results, embeddings, n_clusters=n) + prediction_mode = "max_cluster_prediction" + + print( + f"per sample accuracy is {calc_accuracies(results, mode=prediction_mode)}") + accuracies = calc_per_class_accuracy(test_dataset, results, mode=prediction_mode) + print(f"mean per class accuracy: {accuracies[0].mean()}") + print(f"per class accuracies: {accuracies}") + accuracies_per_cluster.append(accuracies[0].mean()) + print(f"per class precision and recalls: {calc_per_class_prec_recall(results, mode=prediction_mode)}") + + clusters.append(n) + + fig = plt.figure(figsize=(8, 6)) + ax = plt.subplot(aspect='equal') + # plt.style.use('seaborn-darkgrid') + # plt.yticks(np.arange(0.0, 1.1, 0.1)) + ax.axis('tight') + plt.ylim(0.0, 1.0) + plt.xlim(6, 31) + plt.bar(x=clusters, height=accuracies_per_cluster, width=0.4, color='#c3abd0') # width + plt.plot(clusters, accuracies_per_cluster, color='#815f76') + plt.ylabel("Accuracy", rotation=90) + plt.xlabel("Number of Clusters") + fig.savefig(os.path.join(f"./output/ablation_num_clusters_friends/number_of_clusters_2.pdf")) + plt.clf() + + +def cleanse_labels(dataset, results, file_path): + anns = dataset.anns + for ann, max_prediction in zip(anns.values(), results['max_cluster_prediction']): + ann['cleansed'] = dataset.id_to_lbs[max_prediction] + save_json(anns, file_path=file_path) + + +def save_predictions(id_to_lbs, results, file_path, prediction_dict, prediction_mode='cleansed'): + for ann, max_prediction in zip(prediction_dict.values(), results[prediction_mode]): + ann[prediction_mode] = id_to_lbs[max_prediction] + save_json(prediction_dict, file_path=file_path) + return prediction_dict + + +def evaluate_self_supervised(train_dataset, test_dataset, face_embeddings, mode="baseline1"): + train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0) + test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0) + + with torch.no_grad(): + correct_target_names = [] + correct_target_ids = [] + weak_ids = [] + cleansed_ids = [] + face_emb = [] + unknown_id = torch.tensor([train_dataset.lbl_to_id["Unknown"]]) + for train_data, test_data, 
face_embedding in tqdm(zip(train_loader, test_loader, face_embeddings)): + if mode == "baseline0": + if train_data['weak_id']: + choice = random.choice(train_data['weak_id']) + weak_ids.append(choice if choice in list(train_dataset.id_to_lbs.keys()) else unknown_id) + else: + weak_ids.append(unknown_id) + elif mode == "baseline1": + weak_ids.append(train_data['weak_id']) + correct_target_names.extend(test_data['correct_target_name']) + correct_target_ids.append(test_data['correct_target_id']) + cleansed_ids.append(train_dataset.lbl_to_id[train_data['cleansed'][0]]) + face_emb.append(face_embedding) + + # flatten list of lists + # weak_ids = [item.item() for sublist in weak_ids for item in sublist] + correct_target_ids = [item.item() for sublist in correct_target_ids for item in sublist] + + if mode == "baseline0": + results = pd.DataFrame({'correct_target_name': correct_target_names, + 'correct_target_id': correct_target_ids, + 'random_weak_label': weak_ids, + }) + elif mode == "baseline1": + results = pd.DataFrame({'correct_target_name': correct_target_names, + 'correct_target_id': correct_target_ids, + 'weak_label_ids': weak_ids, + 'max_cluster_prediction': np.nan, + 'cleansed': cleansed_ids, + 'subtitle_prediction': np.nan, + 'face_embedding': face_emb, + '0': np.nan, + '1': np.nan, + '2': np.nan, + '3': np.nan, + '4': np.nan, + '5': np.nan, + '6': np.nan, + 'min_distance': np.nan, + 'closest_cluster': np.nan, + }) + + return results + + +def prepare_result(train_dataset, face_embeddings): + train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0) + + with torch.no_grad(): + weak_ids = [] + cleansed_ids = [] + face_emb = [] + unknown_id = torch.tensor([train_dataset.lbl_to_id["Unknown"]]) + for train_data, face_embedding in tqdm(zip(train_loader, face_embeddings)): + weak_ids.append(train_data['weak_id']) + cleansed_ids.append(train_dataset.lbl_to_id[train_data['cleansed'][0]]) + face_emb.append(face_embedding) + + # flatten list of lists + # correct_target_ids = [item.item() for sublist in correct_target_ids for item in sublist] + + results = pd.DataFrame({'face_embedding': face_emb, + 'weak_label_ids': weak_ids, + 'correct_target_name': np.nan, + 'correct_target_id': np.nan, + 'max_cluster_prediction': np.nan, + 'cleansed': cleansed_ids, + '0': np.nan, + '1': np.nan, + '2': np.nan, + '3': np.nan, + '4': np.nan, + '5': np.nan, + '6': np.nan, + '7': np.nan, + '8': np.nan, + 'min_distance': np.nan, + 'closest_cluster': np.nan, + }) + + return results + + +def evaluate_ep_1_5(train_dataset, face_embeddings, mode="baseline1"): + train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0) + # test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0) + + with torch.no_grad(): + correct_target_names = [] + correct_target_ids = [] + weak_ids = [] + cleansed_ids = [] + face_emb = [] + unknown_id = torch.tensor([train_dataset.lbl_to_id["Unknown"]]) + for train_data, face_embedding in tqdm(zip(train_loader, face_embeddings)): + if mode == "baseline0": + if train_data['weak_id']: + choice = random.choice(train_data['weak_id']) + weak_ids.append(choice if choice in list(train_dataset.id_to_lbs.keys()) else unknown_id) + else: + weak_ids.append(unknown_id) + elif mode == "baseline1": + weak_ids.append(train_data['weak_id']) + correct_target_names.extend(train_data['correct_target_name']) + correct_target_ids.append(train_data['correct_target_id']) + cleansed_ids.append(train_dataset.lbl_to_id[train_data['cleansed'][0]]) + 
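# These lists are filled in lock-step, one entry per face, so each row of the
+            # results frame lines up ground truth, weak ids, the cleansed label and
+            # the corresponding face embedding. +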
face_emb.append(face_embedding) + + # flatten list of lists + # weak_ids = [item.item() for sublist in weak_ids for item in sublist] + correct_target_ids = [item.item() for sublist in correct_target_ids for item in sublist] + + if mode == "baseline0": + results = pd.DataFrame({'correct_target_name': correct_target_names, + 'correct_target_id': correct_target_ids, + 'random_weak_label': weak_ids, + }) + elif mode == "baseline1": + results = pd.DataFrame({'correct_target_name': correct_target_names, + 'correct_target_id': correct_target_ids, + 'weak_label_ids': weak_ids, + 'max_cluster_prediction': np.nan, + 'cleansed': cleansed_ids, + 'face_embedding': face_emb, + '0': np.nan, + '1': np.nan, + '2': np.nan, + '3': np.nan, + '4': np.nan, + '5': np.nan, + '6': np.nan, + 'min_distance': np.nan, + 'closest_cluster': np.nan, + }) + + return results + + +def evaluate_oracle_supervised(model, test_dataset, face_embeddings, mode="baseline1"): + # train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0) + test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0) + + with torch.no_grad(): + correct_target_names = [] + correct_target_ids = [] + weak_ids = [] + cleansed_ids = [] + face_emb = [] + predictions = [] + unknown_id = torch.tensor([test_dataset.lbl_to_id["Unknown"]]) + for test_data, face_embedding in tqdm(zip(test_loader, face_embeddings)): + if mode == "baseline0": + if test_data['weak_label']: + choice = random.choice(test_data['weak_id']) + weak_ids.append(choice if choice in list(test_dataset.id_to_lbs.keys()) else unknown_id) + else: + weak_ids.append(unknown_id) + elif mode == "baseline1": + weak_ids.append(test_data['weak_id']) + correct_target_names.extend(test_data['correct_target_name']) + correct_target_ids.append(test_data['correct_target_id']) + cleansed_ids.append(test_dataset.lbl_to_id[test_data['cleansed'][0]]) + face_emb.append(face_embedding) + prediction = model(test_data['image']) + predictions.append(int(prediction.argmax())) + + # flatten list of lists + # weak_ids = [item.item() for sublist in weak_ids for item in sublist] + correct_target_ids = [item.item() for sublist in correct_target_ids for item in sublist] + + if mode == "baseline0": + results = pd.DataFrame({'correct_target_name': correct_target_names, + 'correct_target_id': correct_target_ids, + 'random_weak_label': weak_ids, + }) + elif mode == "baseline1": + results = pd.DataFrame({ + 'correct_target_name': correct_target_names, + 'correct_target_id': correct_target_ids, + 'weak_label_ids': weak_ids, + 'max_cluster_prediction': np.nan, + 'cleansed': cleansed_ids, + 'face_embedding': face_emb, + 'model_prediction': predictions, + '0': np.nan, + '1': np.nan, + '2': np.nan, + '3': np.nan, + '4': np.nan, + '5': np.nan, + '6': np.nan, + 'min_distance': np.nan, + 'closest_cluster': np.nan, + }) + + return results + + +def match_bbox(train_dataset, test_dataset): + matchings = {str(i): {"matching": [], "iou": []} for i in range(len(test_dataset))} + for i_test, s_test in tqdm(test_dataset.anns.items()): + for i_train, s_train in train_dataset.anns.items(): + if s_test['clip'] + '_' + s_test['img'] == s_train['clip'] + '_' + s_train['img']: + iou = calc_iou(s_test['bbox'], s_train['bbox']) + # if iou > 0.0: #0.7 + if matchings[i_test]['matching']: + # print("more matchings!") + matchings[i_test]['matching'].append(i_train) + matchings[i_test]['iou'].append(iou) + else: + matchings[i_test]['matching'] = [i_train] + matchings[i_test]['iou'] = [iou] + 
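+    # Every test annotation is compared against every train annotation (IoU is only computed
+    # when both boxes come from the same clip/frame), so this pass is O(len(test) * len(train));
+    # the matchings are written to JSON below so it only has to run once.
+    # Resulting schema (sketch): {"<test_idx>": {"matching": [train_idx, ...], "iou": [float, ...]}}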
save_json(matchings, "./dataset/box_matchings_iou0.json") + return matchings + + +def match_bbox_2(test_dataset, prediction_dict): + + ########################calculate matchings######################## + if os.path.exists("./dataset/box_matchings_with_iou.json"): + print("Found matches json, loading ...") + return load_json("./dataset/box_matchings_with_iou.json") + + else: + print("No matches.json, creating ...") + matchings = {str(i): {"matching": [], "iou": [], "prediction_baseline_0": []} for i in range(len(test_dataset.anns))} + for i_test, s_test in tqdm(test_dataset.anns.items()): + for train_face, train_ann in prediction_dict.items(): + if s_test['clip'] + '_' + s_test['img'] == train_ann['clip'] + '_' + train_ann['img']: + iou = calc_iou(s_test['bbox'], train_ann['bbox']) + # more than one match + if matchings[i_test]['matching']: + matchings[i_test]['matching'].append(train_face) + matchings[i_test]['iou'].append(iou) + matchings[i_test]['prediction_baseline_0'].append(train_ann['prediction_baseline_0']) + else: + matchings[i_test]['matching'] = [train_face] #matched faces + matchings[i_test]['iou'] = [iou] #iou of the matched faces + matchings[i_test]['prediction_baseline_0'] = [train_ann['prediction_baseline_0']] #predictions of the matched faces + save_json(matchings, "./dataset/box_matchings.json") + + + #add here the iou code: + to_be_deleted_ind = [] + to_be_deleted_face = [] + for i_test, s_test in tqdm(test_dataset.anns.items()): + if not matchings[i_test]["iou"]: + # this face is not matched with anything -> discard + to_be_deleted_ind.append(i_test) + to_be_deleted_face.append(s_test['face']) + else: + max_iou = np.max(np.array(matchings[i_test]["iou"])) + if max_iou == 0: + # this face is not matched with anything -> discard + to_be_deleted_ind.append(i_test) + to_be_deleted_face.append(s_test['face']) + save_json(to_be_deleted_ind, "./dataset/to_be_deleted_test_index.json") + save_json(to_be_deleted_face, "./dataset/to_be_deleted_test_face.json") + # update matchings: + valid_test_anns = {} + to_be_deleted_test_ind = load_json("./dataset/to_be_deleted_test_index.json") + j = 0 + for i_test, ann_test in tqdm(test_dataset.anns.items()): + if i_test not in to_be_deleted_test_ind: + valid_test_anns[str(j)] = ann_test + j += 1 + + save_json(valid_test_anns, "./dataset/bbt_test_annotations.json") + + new_matchings = {} + to_be_deleted_test_ind = load_json("./dataset/to_be_deleted_test_index.json") + j = 0 + for i_match, match in tqdm(matchings.items()): + if i_match not in to_be_deleted_test_ind: + new_matchings[str(j)] = match + j += 1 + save_json(new_matchings, "./dataset/box_matchings_with_iou.json") + return new_matchings + + +def calculate_bbt(matchings, prediction_dict, test_dataset, prediction_mode='closest_cluster'): + # train_dataset is basically prediction_dict, it is the file that is exactly like train annotations but with preedictions + # prediction_dict = train_dataset.anns + face_to_train_idx = load_json("./dataset/face_to_train_idx.json") + test_9_10_idx_to_test_all_idx = load_json("./dataset/test_9_10_idx_to_test_all_idx.json") + + gt_names = [] + gt_ids = [] + predictions = [] + for i_test, s_test in tqdm(test_dataset.anns.items()): + # if s_test['face'].startswith(('s09', 's10')): #todo + gt_names.append(s_test['name']) + gt_ids.append(test_dataset.lbl_to_id[s_test['name']]) + if global_cfg.TRAINING.exp_type == "oracle" and global_cfg.TRAINING.ours_or_baseline == "ours": + new_test_idx = test_9_10_idx_to_test_all_idx[i_test] + i_test = new_test_idx + + if 
global_cfg.TRAINING.ours_or_baseline == "baseline": + predictions.append(prediction_dict[i_test][prediction_mode]) + elif global_cfg.TRAINING.ours_or_baseline == "ours": + iou_idx = np.array(matchings[i_test]["iou"]).argmax() + matched_face = matchings[i_test]['matching'][iou_idx] + predictions.append(prediction_dict[face_to_train_idx[matched_face]][prediction_mode]) + + + # gt_names = [] + # gt_ids = [] + # predictions = [] + # for i_test, s_test in tqdm(test_dataset.anns.items()): + # if matchings[i_test]["iou"]: + # if np.max(np.array(matchings[i_test]["iou"])) != 0: + # gt_names.append(s_test['name']) + # gt_ids.append(test_dataset.lbl_to_id[s_test['name']]) + # predictions.append('None') + # + # iou_idx = np.array(matchings[i_test]["iou"]).argmax() + # if matchings[i_test]["iou"][iou_idx] > threshold: + # matched_face = matchings[i_test]['matching'][iou_idx] + # predictions[-1] = prediction_dict[matched_face]['prediction_baseline_0'] + + return np.array(gt_names), np.array(predictions) + + +def predict_bbt(train_dataset, evaluation_path="./dataset/evaluation_dict.json", prediction_path="./dataset/evaluation_dict_baseline_0.json", test_dataset=None, face_embeddings=None, mode="baseline1"): + # train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0) + # test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0) + # train_dataset_anns = load_json("./dataset/new_mix_bbt_annotations.json") + # train_dataset_lbl_to_id = load_json("./dataset/bbt_lbl_to_id.json") + + if os.path.exists(evaluation_path): + print("Found evaluation dict in predict_bbt, loading ...") + evaluation_dict = load_json(evaluation_path) + else: + print("No evaluation dict in predict_bbt, creating ...") + evaluation_dict = {} + for train_idx, train_ann in tqdm(train_dataset.anns.items()): + if train_ann['face'] in evaluation_dict: + evaluation_dict[train_ann['face']]["train_index"].append(train_idx) + evaluation_dict[train_ann['face']]["weak_id"].append(train_dataset.lbl_to_id[train_ann['name']]) + evaluation_dict[train_ann['face']]["weak_name"].append(train_ann["name"]) + else: + evaluation_dict[train_ann['face']] = { + "train_index": [train_idx], + "img": train_ann["img"], + "clip": train_ann["clip"], + "series": train_ann["series"], + "face": train_ann["face"], + "weak_id": [train_dataset.lbl_to_id[train_ann['name']]], + "weak_name": [train_ann["name"]], + "bbox": train_ann["bbox"], + "face_points": train_ann["face_points"], + "subtitle": train_ann["subtitle"], + } + save_json(evaluation_dict, evaluation_path) + evaluation_dict = load_json(evaluation_path) + + if os.path.exists(prediction_path): + print("Found prediction_path dict in predict_bbt, loading ...") + prediction_dict = load_json(prediction_path) + else: + with torch.no_grad(): + if mode == "baseline0": + for ann in tqdm(evaluation_dict.values()): + # if only one weak_name, then the choice is obvious + if ann['weak_name']: + choice = random.choice(ann['weak_name']) + ann["prediction_baseline_0"] = choice + else: + ann["prediction_baseline_0"] = "Unknown" + save_json(evaluation_dict, prediction_path) + prediction_dict = load_json(prediction_path) + + return prediction_dict + + ''' + elif mode == "baseline1": + face_emb = [] + for train_data, face_embedding in tqdm(zip(train_loader, face_embeddings)): + faces.append(train_data['face']) + weak_ids.append(train_data['weak_label']) + face_emb.append(face_embedding) + # correct_target_names.extend(test_data['correct_target_name']) + # 
correct_target_ids.append(test_data['correct_target_id']) + # cleansed_ids.append(train_dataset.lbl_to_id[train_data['cleansed'][0]]) + + # flatten list of lists + # weak_ids = [item.item() for sublist in weak_ids for item in sublist] + # correct_target_ids = [item.item() for sublist in correct_target_ids for item in sublist] + + if mode == "baseline0": + results = pd.DataFrame({'faces': faces, + 'correct_target_name': correct_target_names, + 'correct_target_id': correct_target_ids, + 'random_weak_label_id': weak_ids, + 'random_weak_label_name': weak_names, + }) + elif mode == "baseline1": + results = pd.DataFrame({'correct_target_name': np.nan, + 'correct_target_id': np.nan, + 'weak_label_ids': weak_ids, + 'max_cluster_prediction': np.nan, + 'cleansed': np.nan, + 'subtitle_prediction': np.nan, + 'face_embedding': face_emb, + '0': np.nan, + '1': np.nan, + '2': np.nan, + '3': np.nan, + '4': np.nan, + '5': np.nan, + '6': np.nan, + 'min_distance': np.nan, + 'closest_cluster': np.nan, + }) + + return results +''' + + +# if __name__ == "__main__": +# parser = argparse.ArgumentParser() +# parser.add_argument('--episode', type=str, help='The episode number') +# ARGS, unparsed = parser.parse_known_args() +# +# dataset = TVQADataset(episode=ARGS.episode, split="all", transform=transforms.Compose([transforms.ToTensor()])) +# make_excel(model=None, +# dataset=dataset, +# source="dataloader", +# data_path="./dataset/frames_hq/friends_frames/", +# path_to_faces="./dataset/friends_frames/", +# file_name=f"s01e{ARGS.episode}.xlsx") + + +def visualize_cluster_distances(results): + from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances + + sorted_results = results.sort_values("cleansed") # sort based on the cleansed label + embeddings = torch.empty(size=(3386, 512)) + for i, emb in enumerate(sorted_results["face_embedding"]): + embeddings[i] = emb + dists = euclidean_distances(embeddings) + + plt.imshow(dists, cmap='plasma', interpolation='nearest') + # plt.xticks(range(len(sorted_results["cleansed"].values)), sorted_results["cleansed"].values, rotation=90, fontsize=6) + # plt.yticks(range(len(sorted_results["cleansed"].values), sorted_results["cleansed"].values, rotation=0, fontsize=6) + + cb = plt.colorbar() + cb.set_label("distances") + plt.title("distances of points of each cluster to other cluster") + + plt.savefig('small__distances.pdf') + plt.clf() + + +def convert_dfoftensors_to_tensor(df): + new_tensor = torch.empty(size=(len(df), len(df.iloc[0]))) + for i, item in enumerate(df): + new_tensor[i] = item + return new_tensor + + +def calc_distances(train_dataset, results, num_classes, alpha=None): + unknown_id = train_dataset.lbl_to_id["Unknown"] + from scipy.spatial import distance + for i, unknown in tqdm(results[results["cleansed"] == unknown_id].iterrows()): + for cluster in range(num_classes): + unknown_cluster_emb = unknown["face_embedding"] + other_cluster_embs = results.loc[results["cleansed"] == cluster, "face_embedding"] + if not other_cluster_embs.empty: #if this is not an empty series + other_cluster_embs_torch = convert_dfoftensors_to_tensor(other_cluster_embs) + c_dists = distance.cdist(unknown_cluster_emb[None, :], other_cluster_embs_torch, 'euclidean') + else: + # print(f"cluster {cluster} is empty!") + c_dists = np.array([np.inf]) + if alpha is not None: + if cluster == unknown_id: + results.loc[i, (f'{cluster}')] = c_dists.mean() * alpha + else: + results.loc[i, (f'{cluster}')] = c_dists.mean() * (1-alpha) + else: + results.loc[i, (f'{cluster}')] = 
c_dists.mean() + results["min_distance"] = results[[str(x) for x in range(num_classes)]].min(axis=1) + results["closest_cluster"] = results[[str(x) for x in range(num_classes)]].idxmin(axis=1) + results.loc[results["cleansed"] != unknown_id, "closest_cluster"] = results.loc[results["cleansed"] != unknown_id, "cleansed"] + results["closest_cluster"] = results["closest_cluster"].astype(np.int64) + return results + + +def calc_distances_with_prototypes(lbl_to_id, results, num_classes, alpha=None): + # unknown_id = lbl_to_id["Unknown"] + unknown_id = lbl_to_id["Sheldon"] + + from scipy.spatial import distance + prototype_embeddings = {} + for cluster in tqdm(range(num_classes)): + cluster_embs = results.loc[results["cleansed"] == cluster, "face_embedding"] + if cluster_embs.empty: + prototype_embeddings[cluster] = np.ones((1, 512)) * np.inf + else: + cluster_embs = convert_dfoftensors_to_tensor(cluster_embs) + prototype_embeddings[cluster] = cluster_embs.mean(axis=0, keepdim=True) + + unknown_cluster_embs = convert_dfoftensors_to_tensor(results.loc[results["cleansed"] == unknown_id, "face_embedding"]) + for prototype_idx, prototype_emb in tqdm(prototype_embeddings.items()): + c_dists = distance.cdist(unknown_cluster_embs, prototype_emb, 'euclidean') + if alpha is not None: + if prototype_idx == unknown_id: + c_dists = c_dists * alpha + else: + c_dists = c_dists * (1 - alpha) + results.loc[results["cleansed"] == unknown_id, (f'{prototype_idx}')] = c_dists + + results["min_distance"] = results[[str(x) for x in range(num_classes)]].min(axis=1) + results["closest_cluster"] = results[[str(x) for x in range(num_classes)]].idxmin(axis=1) + results.loc[results["cleansed"] != unknown_id, "closest_cluster"] = results.loc[results["cleansed"] != unknown_id, "cleansed"] + results["closest_cluster"] = results["closest_cluster"].astype(np.int64) + return results + + +def recluster_unknowns(results, num_classes=7, num_clusters=7): + from scipy.spatial import distance + results['recluster_unk'] = np.nan + results['hac7_id'] = np.nan + embeddings = convert_dfoftensors_to_tensor(results.loc[results["cleansed"] == 6, 'face_embedding']) + hac7_id = AgglomerativeClustering(n_clusters=num_clusters).fit_predict(embeddings.numpy()) + results.loc[results["cleansed"] == 6, 'hac7_id'] = hac7_id + unknowns = results[results["cleansed"] == 6] + + for hacid in range(num_clusters): + for cluster in range(num_clusters): + new_cluster_embs = unknowns.loc[unknowns["hac7_id"] == hacid, 'face_embedding'] + other_cluster_embs = results.loc[results["cleansed"] == cluster, 'face_embedding'] + + new_cluster_embs_torch = convert_dfoftensors_to_tensor(new_cluster_embs) + other_cluster_embs_torch = convert_dfoftensors_to_tensor(other_cluster_embs) + c_dists = distance.cdist(new_cluster_embs_torch, other_cluster_embs_torch, 'euclidean') + results.loc[results["hac7_id"] == hacid, f'{cluster}'] = c_dists.mean() + + results["min_distance"] = results[[str(x) for x in range(num_classes)]].min(axis=1) + results["closest_cluster"] = results[[str(x) for x in range(num_classes)]].idxmin(axis=1) + results.loc[results["cleansed"] != 6, "closest_cluster"] = results.loc[results["cleansed"] != 6, "cleansed"] + results["closest_cluster"] = results["closest_cluster"].astype(np.int64) + return results + + +def calc_iou(bbox1, bbox2): + x_left = max(bbox1[0], bbox2[0]) + y_top = max(bbox1[1], bbox2[1]) + x_right = min(bbox1[2], bbox2[2]) + y_bottom = min(bbox1[3], bbox2[3]) + + # no intersection + if x_right < x_left or y_bottom < y_top: + return 0.0 + 
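+    # Boxes are [x1, y1, x2, y2]; IoU = intersection / (area1 + area2 - intersection).
+    # e.g. calc_iou([0, 0, 10, 10], [5, 5, 15, 15]) = 25 / (100 + 100 - 25) ~= 0.14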
intersection_area = (x_right - x_left) * (y_bottom - y_top) + bb1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + bb2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + + iou = intersection_area / float(bb1_area + bb2_area - intersection_area) + return iou + +def visualize(cfg): + seed = 10 + np.random.seed(seed) + model_mode = "facenet_pretrained" + # model_mode = "facenet_reclassified" + # model_mode = "nothing" + data_mode = "test" + if cfg.TRAINING.series == "bbt": + num_classes = 9 + elif cfg.TRAINING.series == "friends": + num_classes = 7 + num_clusters = 30 + exp = cfg.TRAINING.project_dir.split('/')[-2] + print(f"this is experiment {exp}") + epoch = 99 + print(f"cuda is available: {torch.cuda.is_available()}") + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + vis_path = f"./output/visualize/{exp}" + Path(vis_path).mkdir(parents=True, exist_ok=True) + logging.basicConfig(filename=f'{vis_path}/log.out', filemode='a', level=logging.INFO, + format='%(asctime)s %(name)s - %(levelname)s - %(message)s') + logger = logging.getLogger('visualize') + + if model_mode == "resnet_pretrained": + model = models.ResNet(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes) + # loading model checkpoint + checkpoint = torch.load(f"./output/{exp}/model/epoch_{epoch}.tar", map_location=torch.device(device)) + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() + elif model_mode == "facenet_pretrained": + model = InceptionResnetV1(pretrained='vggface2').eval() + elif model_mode == "facenet_reclassified": + model = VGGSupervised(cfg, num_classes) + # model = VGGFacePlus(cfg, num_classes) + checkpoint = torch.load(f"./output/{exp}/model/epoch_{epoch}.tar", map_location=torch.device(device)) + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() + elif model_mode == "facenet_reclassified_VGGFaceSubtitle": + model = VGGFaceSubtitle(cfg, num_classes) + checkpoint = torch.load(f"./output/{exp}/model/epoch_{epoch}.tar", map_location=torch.device(device)) + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() + + if exp.startswith("oracle_bbt") or exp == 'multilabel_bbt': + logging.info("bbt oracle...") + test_dataset = TVQADataset(series=cfg.TRAINING.series, split="test", transform=get_test_transforms(series=cfg.TRAINING.series)) + results = evaluate(model, test_dataset, model_mode) + prediction_dict = load_json(f"./dataset/bbt_9_10.json") + prediction_dict = save_predictions(test_dataset.id_to_lbs, results, + f"./dataset/bbt_9_10_multilabel_prediction.json", + prediction_dict, prediction_mode='model_prediction') + # prediction_dict = load_json("./dataset/bbt_9_10_oracle_prediction.json") + matchings = load_json("./dataset/box_matchings_with_iou.json") + gt_names, predictions_names = calculate_bbt(matchings, prediction_dict, test_dataset, prediction_mode='model_prediction') + logging.info(f"number of predictions is {predictions_names.shape, gt_names.shape}") + logging.info(f"per sample accuracy of model_prediction in oracle is {calc_accuracies_bbt(gt_names, predictions_names)}") + accuracies = calc_per_class_accuracy_bbt(test_dataset, gt_names, predictions_names) + logging.info(f"mean per class accuracy with unknown: {accuracies[0].mean()}") + # logging.info(f"mean per class accuracy without unknown: {accuracies[0][:-1].mean()}") + logging.info(f"per class accuracies: {accuracies}") + logging.info(f"per class precision and recalls: {calc_per_class_prec_recall_bbt(test_dataset, gt_names, predictions_names)}") + 
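+        # accuracies[0] is the per-class accuracy vector (Unknown appears to be the last entry,
+        # as in the other branches); only the with-unknown mean is logged here, the
+        # without-unknown variant is kept commented out above.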
print("koko") + # prediction_mode = 'model_prediction' + # print( + # f"per sample accuracy of model_prediction in {data_mode} dataset and model {model_mode} is {calc_accuracies(results, mode=prediction_mode)}") + # accuracies = calc_per_class_accuracy(test_dataset, results, mode=prediction_mode) + # print(f"mean per class accuracy: {accuracies[0].mean()}") + # print(f"per class accuracies: {accuracies}") + # print(f"per class precision and recalls: {calc_per_class_prec_recall(results, mode=prediction_mode)}") + elif exp == "bbt_baseline0": + logging.info("bbt baseline 0...") + train_dataset = TVQADataset(series=cfg.TRAINING.series, split="train", transform=get_test_transforms(series=cfg.TRAINING.series)) + test_dataset = TVQADataset(series=cfg.TRAINING.series, split="test", transform=get_test_transforms(series=cfg.TRAINING.series)) + prediction_dict = predict_bbt(train_dataset, mode="baseline0") + matchings = match_bbox_2(test_dataset, prediction_dict) + gt_names, predictions_names = calculate_bbt(matchings, prediction_dict, test_dataset) + + logging.info(f"per sample accuracy of model_prediction in baseline 0 is {calc_accuracies_bbt(gt_names, predictions_names)}") + accuracies = calc_per_class_accuracy_bbt(train_dataset, gt_names, predictions_names) + logging.info(f"mean per class accuracy with unknown: {accuracies[0].mean()}") + logging.info(f"mean per class accuracy without unknown: {accuracies[0][:-1].mean()}") + logging.info(f"per class accuracies: {accuracies}") + logging.info(f"per class precision and recalls: {calc_per_class_prec_recall_bbt(train_dataset, gt_names, predictions_names)}") + elif exp == "iou_matchings": + logging.info("iou_matchings...") + train_dataset = TVQADataset(series=cfg.TRAINING.series, split="train", transform=get_test_transforms(series=cfg.TRAINING.series)) + test_dataset = TVQADataset(series=cfg.TRAINING.series, split="test", transform=get_test_transforms(series=cfg.TRAINING.series)) + prediction_dict = predict_bbt(train_dataset, mode="baseline0") + matchings = match_bbox_2(test_dataset, prediction_dict) + elif exp == "calculate_embeddings": + logging.info("calculating embeddings...") + train_dataset = TVQADataset(series=cfg.TRAINING.series, split="train", transform=get_test_transforms(series=cfg.TRAINING.series)) + face_embeddings = calculate_embeddings(model, emb_path=f"./output/evaluate_bbt/model/new_bbt_face_embeddings.pt", + dataset=train_dataset, model_mode=model_mode, preload=False) + elif exp == "ablation_num_clusters_friends": + train_dataset = TVQADataset(series=cfg.TRAINING.series, split="train", transform=get_test_transforms(series=cfg.TRAINING.series)) + test_dataset = TVQADataset(series=cfg.TRAINING.series, split="test", transform=get_test_transforms(series=cfg.TRAINING.series)) + face_embeddings = calculate_embeddings(model, emb_path=f"./dataset/bigger_bb/face_embeddings_friends.pt",dataset=train_dataset, model_mode=model_mode, preload=True) + results = evaluate_self_supervised(train_dataset, test_dataset, face_embeddings, mode="baseline1") + exp_num_clusters(range(7, 31), results, face_embeddings, test_dataset) + elif exp == "clustering_and_cleansing_and_closest_kmeans" or exp=="test" or exp=="cleansing_sheldon": + logging.info("clustering_and_cleansing ...") + train_dataset = TVQADataset(series=cfg.TRAINING.series, split="train", transform=get_test_transforms(series=cfg.TRAINING.series)) + test_dataset = TVQADataset(series=cfg.TRAINING.series, split="test", transform=get_test_transforms(series=cfg.TRAINING.series)) + face_embeddings = 
calculate_embeddings(model, emb_path=f"./dataset/bbt_face_embeddings.pt", dataset=train_dataset, model_mode=model_mode, preload=True) + results = prepare_result(train_dataset, face_embeddings) + prediction_mode = "max_cluster_prediction" + results, clustering_ids = predict_with_clustering(results, face_embeddings, n_clusters=num_clusters, knn_graph=None, pred_mode=prediction_mode) + # cleanse_labels(train_dataset, results, file_path=f"./dataset/prediction.json") + results = calc_distances_with_prototypes(train_dataset.lbl_to_id, results, num_classes) + # results = calc_distances(train_dataset, results, num_classes, alpha=None) + prediction_dict = load_json(f"./dataset/bbt_train_annotations.json") + + prediction_dict = save_predictions(train_dataset.id_to_lbs, results, f"./dataset/bbt_train_annotations_mediaeval.json", prediction_dict, prediction_mode='M2') + + + prediction_dict = save_predictions(train_dataset.id_to_lbs, results, f"./dataset/bbt_train_annotations_sheldon.json", prediction_dict, prediction_mode='closest_cluster') + matchings = load_json("./dataset/box_matchings_with_iou.json") + gt_names, predictions_names = calculate_bbt(matchings, prediction_dict, test_dataset, prediction_mode='closest_cluster') + + logging.info(f"per sample accuracy of model_prediction in baseline 0 is {calc_accuracies_bbt(gt_names, predictions_names)}") + accuracies = calc_per_class_accuracy_bbt(train_dataset, gt_names, predictions_names) + logging.info(f"mean per class accuracy with unknown: {accuracies[0].mean()}") + logging.info(f"mean per class accuracy without unknown: {accuracies[0][:-1].mean()}") + logging.info(f"per class accuracies: {accuracies}") + logging.info(f"per class precision and recalls: {calc_per_class_prec_recall_bbt(train_dataset, gt_names, predictions_names)}") + + gt_names, predictions_names = calculate_bbt(matchings, prediction_dict, test_dataset, prediction_mode='cleansed') + logging.info( + f"per sample accuracy of model_prediction in baseline 0 is {calc_accuracies_bbt(gt_names, predictions_names)}") + accuracies = calc_per_class_accuracy_bbt(train_dataset, gt_names, predictions_names) + logging.info(f"mean per class accuracy with unknown: {accuracies[0].mean()}") + logging.info(f"mean per class accuracy without unknown: {accuracies[0][:-1].mean()}") + logging.info(f"per class accuracies: {accuracies}") + logging.info( + f"per class precision and recalls: {calc_per_class_prec_recall_bbt(train_dataset, gt_names, predictions_names)}") + elif exp == "ours_s09_s10": + logging.info("ours_s09_s10 ...") + train_dataset = TVQADataset(series=cfg.TRAINING.series, split="train", transform=get_test_transforms(series=cfg.TRAINING.series)) + test_dataset = TVQADataset(series=cfg.TRAINING.series, split="test", transform=get_test_transforms(series=cfg.TRAINING.series)) + + prediction_dict = load_json(f"./dataset/bbt_train_annotations_faster_closestcluster.json") + matchings = load_json("./dataset/box_matchings_with_iou.json") + gt_names, predictions_names = calculate_bbt(matchings, prediction_dict, test_dataset, prediction_mode='closest_cluster') + + logging.info( + f"per sample accuracy of model_prediction in baseline 0 is {calc_accuracies_bbt(gt_names, predictions_names)}") + accuracies = calc_per_class_accuracy_bbt(train_dataset, gt_names, predictions_names) + logging.info(f"mean per class accuracy with unknown: {accuracies[0].mean()}") + logging.info(f"per class accuracies: {accuracies}") + logging.info( + f"per class precision and recalls: {calc_per_class_prec_recall_bbt(train_dataset, 
gt_names, predictions_names)}") + print("booboo") + elif exp == "friends_prediction": + train_dataset = TVQADataset(series="friends", split="train", transform=get_test_transforms(series="friends")) + test_dataset = TVQADataset(series="friends", split="test", transform=get_test_transforms(series="friends")) + # face_embeddings = calculate_embeddings(model, emb_path=f"./dataset/bigger_bb/face_embeddings_friends.pt", + # dataset=train_dataset, model_mode=model_mode, preload=True) + face_embeddings = calculate_embeddings(model, emb_path=f"./output/distances/model/face_embeddings.pt",dataset=train_dataset, model_mode=model_mode, preload=True) + + results = evaluate_self_supervised(train_dataset, test_dataset, face_embeddings, mode="baseline1") + # results = calc_distances(train_dataset, results, num_classes, alpha=None) + + # prediction_dict = load_json(f"./dataset/train_annotations.json") + # test_dict = load_json(f"./dataset/test_annotations.json") + # prediction_dict = save_predictions(train_dataset.id_to_lbs, results, f"./dataset/friends_prediction_dict.json", + # prediction_dict, prediction_mode='closest_cluster') + # + # for ann, ann_test in zip(prediction_dict.values(), test_dict.values()): + # ann["correct_target_name"] = ann_test["target_name"] + # ann["correct_target_id"] = train_dataset.lbl_to_id[ann_test["target_name"]] + # save_json(prediction_dict, file_path=f"./dataset/friends_prediction_dict.json") + + results, clustering_ids = predict_with_clustering(results, face_embeddings, n_clusters=num_clusters, knn_graph=None, pred_mode="max_cluster_prediction") + prediction_mode = "M2" + print( + f"per sample accuracy of model_prediction in {data_mode} dataset and model {model_mode} is {calc_accuracies(results, mode=prediction_mode)}") + accuracies = calc_per_class_accuracy(test_dataset, results, mode=prediction_mode) + print(f"mean per class accuracy: {accuracies[0].mean()}") + print(f"per class accuracies: {accuracies}") + print(f"per class precision and recalls: {calc_per_class_prec_recall(results, mode=prediction_mode)}") + visualize_cluster_distances(results) + # result_matching = match_bbox(train_dataset, test_dataset) + + # results = evaluate(model, test_dataset, model_mode) + # face_embeddings = calculate_embeddings(model, emb_path=f"./output/evaluate_bbt/model/bbt_face_embeddings.pt", dataset=train_dataset, model_mode=model_mode, preload=True) + # face_embeddings = calculate_embeddings(model, emb_path=f"./output/evaluate_bbt/model/face_embeddings.pt", + # dataset=train_dataset, model_mode=model_mode, preload=True) + # face_embeddings_train = calculate_embeddings(model, emb_path=f"./dataset/bigger_bb/face_embeddings_train_{cfg.TRAINING.series}.pt", + # dataset=train_dataset, model_mode="facenet_pretrained", preload=True) + # face_embeddings_test = calculate_embeddings(model, emb_path=f"./dataset/bigger_bb/face_embeddings_test_{cfg.TRAINING.series}.pt", + # dataset=test_dataset, model_mode="facenet_pretrained", preload=True) + # tsne_grid = TSNE(random_state=seed, n_iter=2000).fit_transform(face_embeddings_train.detach().numpy()) + # torch.save(tsne_grid, "./output/evaluate_bbt/model/tsne_grid.pt") + # tsne_grid = torch.load("./output/evaluate_bbt/model/tsne_grid.pt") +# a == tsne_grid +# predict_bbt(train_dataset, mode="baseline0") + + # gt_names, predictions_names = match_bbox_2(test_dataset, threshold=0) +# print(f"per sample accuracy of model_prediction in baseline 0 is {calc_accuracies_bbt(gt_names, predictions_names)}") +# accuracies = 
calc_per_class_accuracy_bbt(train_dataset, gt_names, predictions_names) +# print(f"mean per class accuracy with unknown: {accuracies[0].mean()}") +# print(f"mean per class accuracy without unknown: {accuracies[0][:-1].mean()}") +# print(f"per class accuracies: {accuracies}") +# print(f"per class precision and recalls: {calc_per_class_prec_recall_bbt(train_dataset, gt_names, predictions_names)}") + + # results.to_pickle("./output/evaluate_bbt_baseline0/results_baseline0.pkl") + # results = pd.read_pickle("./output/evaluate_bbt_baseline0/results_baseline0.pkl") + # results = evaluate_self_supervised(train_dataset, test_dataset, face_embeddings_all) + + # results = evaluate_oracle_supervised(model, test_dataset, face_embeddings) +# # +# # # knn_graph = build_graph(face_embeddings, train_dataset, neighbours=3385, no_edge=True) +# # num_clusters = 10 +# results = prepare_result(train_dataset, face_embeddings) +# prediction_mode = "max_cluster_prediction" +# results, clustering_ids = predict_with_clustering(results, face_embeddings, n_clusters=num_clusters, knn_graph=None, pred_mode=prediction_mode) +# cleanse_labels(train_dataset, results, file_path=f"./dataset/cleansed_bbt_annotations_{global_cfg.TRAINING.clustering}_{global_cfg.TRAINING.kmeans_batch_size}.json") +# +# # visualize_cluster_distances(results) + + + # results = evaluate_ep_1_5(train_dataset, face_embeddings_all) + # alpha = 1.0 + # results = calc_distances(results, num_classes, alpha=alpha) + + + + + # results = recluster_unknowns(results, num_classes, num_clusters=7) +# # results, plt, fig = clustering_with_gmm(results, face_embeddings, tsne_grid) +# # fig.savefig(os.path.join(f"{vis_path}", f"lala.pdf")) +# # plt.clf() +# # exp_num_clusters(range(7, 20), results, embeddings) +# # results, clustering_ids = predict_with_clustering(results, face_embeddings, n_clusters=num_clusters, pred_mode="max_cluster_prediction") +# # results, sub_clustering_ids = predict_with_clustering(results, subtitle_embeddings, n_clusters=num_clusters, pred_mode="subtitle_prediction") +# +# + # prediction_mode = 'model_prediction' + # print(f"per sample accuracy of model_prediction in {data_mode} dataset and model {model_mode} is {calc_accuracies(results, mode=prediction_mode)}") + # accuracies = calc_per_class_accuracy(test_dataset, results, mode=prediction_mode) + # print(f"mean per class accuracy: {accuracies[0].mean()}") + # print(f"per class accuracies: {accuracies}") + # print(f"per class precision and recalls: {calc_per_class_prec_recall(results, mode=prediction_mode)}") +# +# fig, plt = visualize_tsne(tsne_grid, results["closest_cluster"].values.tolist(), train_dataset.id_to_lbs) +# fig.savefig(os.path.join(f"{vis_path}", f"{data_mode}_{model_mode}_closest_cluster.pdf")) +# plt.clf() +# +# fig, plt = visualize_tsne(tsne_grid, results["cleansed"].values.tolist(), train_dataset.id_to_lbs) +# fig.savefig(os.path.join(f"{vis_path}", f"{data_mode}_{model_mode}_cleansed.pdf")) +# plt.clf() +# + + # fig, plt = visualize_tsne(tsne_grid, results.loc[len(train_dataset) - len(test_dataset):len(train_dataset)-1,"correct_target_id"].values.tolist(), train_dataset.id_to_lbs) + # fig.savefig(os.path.join(f"{vis_path}", f"{data_mode}_{model_mode}_groundtruth.pdf")) + # plt.clf() + # fig, plt = visualize_tsne(tsne_grid, results.loc[len(train_dataset) - len(test_dataset):len(train_dataset)-1,"closest_cluster"].values.tolist(), train_dataset.id_to_lbs) + # fig.savefig(os.path.join(f"{vis_path}", f"{data_mode}_{model_mode}_closest_cluster.pdf")) + # plt.clf() + # 
print(f"accuracy of model_prediction in {data_mode} dataset and model {model_mode} is" + # f" {calc_accuracies(results.loc[len(train_dataset) - len(test_dataset) :len(train_dataset)-1], mode='closest_cluster')}") + # print(f"per class accuracies: {calc_per_class_accuracy(train_dataset, results.loc[len(train_dataset) - len(test_dataset) :len(train_dataset)-1], mode='closest_cluster')}") + # print(f"per class precision and recalls: {calc_per_class_prec_recall(results.loc[len(train_dataset) - len(test_dataset):len(train_dataset)-1], mode='closest_cluster')}") +# +# ###########################################visualize################################################# +# +# +# # fig, plt = visualize_tsne(tsne_grid, results["weak_label"].values.tolist(), dataset.id_to_lbs) +# # fig.savefig(os.path.join(f"{vis_path}", f"{data_mode}_{model_mode}_weaklabels.pdf")) +# # plt.clf() +# +# # +# fig, plt = visualize_tsne(tsne_grid, clustering_ids, train_dataset.id_to_lbs) +# fig.savefig(os.path.join(f"{vis_path}", f"{data_mode}_{model_mode}_hac8clusters.pdf")) +# plt.clf() +# +# +# # # logging.info(f"accuracy of maximum clustering in {data_mode} dataset and model {model_mode} is " +# # # f"{calc_accuracies(results, mode='max_cluster_prediction')}") +# # # print(f"accuracy of maximum clustering in {data_mode} dataset and model {model_mode} is " +# # # f"{calc_accuracies(results, mode='max_cluster_prediction')}") +# # # logging.info(f"per class accuracies: {calc_per_class_accuracy(results, mode='max_cluster_prediction')}") +# # print(f"per class accuracies: {calc_per_class_accuracy(results, mode='max_cluster_prediction')}") +# logger.info(f"accuracy of model_prediction in {data_mode} dataset and model {model_mode} is" +# f" {calc_accuracies(results, mode='max_cluster_prediction')}") +# print(f"accuracy of model_prediction in {data_mode} dataset and model {model_mode} is" +# f" {calc_accuracies(results, mode='max_cluster_prediction')}") +# logger.info( +# f"per class accuracies: {calc_per_class_accuracy(test_dataset, results, mode='max_cluster_prediction')}") +# print(f"per class accuracies: {calc_per_class_accuracy(test_dataset, results, mode='max_cluster_prediction')}") +# +# logger.info( +# f"per class precision and recalls: {calc_per_class_prec_recall(results, mode='max_cluster_prediction')}") +# print(f"per class precision and recalls: {calc_per_class_prec_recall(results, mode='max_cluster_prediction')}") +# +# +# # fig = plt.figure(figsize=(8, 8)) +# # plt.style.use('seaborn-darkgrid') +# # ax = plt.subplot(aspect='equal') +# # plt.plot(np.array([0.1, 0.4, 0.5, 0.52, 0.55, 0.6, 0.9]), np.array( +# # [0.696776646489589, 0.696776646489589, 0.7631103930418455, 0.7740073757551987, 0.7277433039443212, +# # 0.7218788892129113, 0.7218788892129113]), '-o', label='mean per class accuracy') +# # plt.plot(np.array([0.1, 0.4, 0.5, 0.52, 0.55, 0.6, 0.9]), np.array( +# # [0.7409923213230951, 0.7409923213230951, 0.7359716479621973, 0.7002362669816893, 0.5593620791494389, +# # 0.544595392793857, 0.544595392793857]), '-o', label='sample-level accuracy') +# # plt.xlim(0.0, 1.0) +# # plt.ylim(0.6, 0.8) +# # plt.legend() +# # plt.xlabel("alpha") +# # ax.axis('tight') +# # fig.savefig(os.path.join(f"{vis_path}", f"accuracy_alphas.pdf")) +# # plt.clf() +# ######################### +# # fig = plt.figure(figsize=(8, 8)) +# # plt.style.use('seaborn-darkgrid') +# # ax = plt.subplot(aspect='equal') +# # plt.plot(np.arange(7), np.array([0.54754098, 0.67146283, 0.82484725, 0.71489362, 0.64166667,0.63055556, 0.84646962]), '-o', 
label='0.1') +# # plt.plot(np.arange(7), np.array([0.54754098, 0.67146283, 0.82484725, 0.71489362, 0.64166667,0.63055556, 0.84646962]), '-o', label='0.4') +# # plt.plot(np.arange(7), np.array([0.7147541 , 0.83453237, 0.91242363, 0.80425532, 0.73888889,0.70555556, 0.63136289]), '-o', label='0.5') +# # plt.plot(np.arange(7), np.array([0.79672131, 0.86091127, 0.95315682, 0.85531915, 0.78055556,0.70833333, 0.46305419]), '-o', label='0.52') +# # plt.plot(np.arange(7), np.array([0.81311475, 0.882494 , 0.96741344, 0.85957447, 0.82222222,0.70833333, 0.0410509 ]), '-o', label='0.55') +# # plt.plot(np.arange(7), np.array([0.81311475, 0.882494 , 0.96741344, 0.85957447, 0.82222222,0.70833333, 0. ]), '-o', label='0.6') +# # plt.plot(np.arange(7), np.array([0.81311475, 0.882494 , 0.96741344, 0.85957447, 0.82222222,0.70833333, 0. ]), '-o', label='0.9') +# # plt.xlim(0, 6) +# # plt.ylim(0.2, 1) +# # plt.legend() +# # ax.axis('tight') +# # fig.savefig(os.path.join(f"{vis_path}", f"alphas.pdf")) +# # plt.clf() +# ############# +# # fig = plt.figure(figsize=(8, 8)) +# # plt.style.use('seaborn-darkgrid') +# # ax = plt.subplot(aspect='equal') +# # plt.plot(np.arange(4), np.array([0, 0.6967766471428571, 0.7631103942857143, 0.7740073757551987]), '-o',label='mean per class accuracy') +# # plt.plot(np.arange(4), np.array([0, 0.7409923213230951, 0.7359716479621973, 0.7002DataLoader362669816893]), '-o',label='sample accuracy') +# # plt.xlim(0, 4) +# # plt.ylim(0.6, 0.8) +# # plt.legend() +# # plt.xticks(range(4), np.array(["random from weaklabel", "max cluster", "closest cluster", "alpha 0.52"]), rotation=45, fontsize=6) +# # plt.xlabel("methods") +# # ax.axis('tight') +# # fig.savefig(os.path.join(f"{vis_path}", f"accuracy_methods.pdf")) +# # plt.clf() +# ############## +# # fig = plt.figure(figsize=(8, 8)) +# # plt.style.use('seaborn-darkgrid') +# # ax = plt.subplot(aspect='equal') +# # plt.plot(np.arange(7), np.array([0.17704918, 0.17026379, 0.27494908, 0.14468085, 0.2 ,0.19166667, 0.62068966]), '-o', label='random from weaklabel') +# # plt.plot(np.arange(7), np.array([0.54754098, 0.67146283, 0.82484725, 0.71489362, 0.64166667, 0.63055556, 0.84646962]), '-o', label='max cluster') +# # plt.plot(np.arange(7), np.array([0.7147541 , 0.83453237, 0.91242363, 0.80425532, 0.73888889, 0.70555556, 0.63136289]), '-o', label='closest cluster') +# # plt.plot(np.arange(7), np.array([0.79672131, 0.86091127, 0.95315682, 0.85531915, 0.78055556, 0.70833333, 0.46305419]), '-o', label='alpha 0.52') +# # plt.xlim(0, 6) +# # plt.ylim(0.2, 1) +# # plt.legend() +# # ax.axis('tight') +# # fig.savefig(os.path.join(f"{vis_path}", f"per_class_accuracies_methods.pdf")) +# # plt.clf() + + +def self_supervised_train_test_split(): + test = load_json("./dataset/self_supervised_dataset/3_test_annotations.json") + dev = load_json("./dataset/self_supervised_dataset/3_dev_annotations.json") + all = {} + i = 0 + while i < len(test.keys()) + len(dev.keys()): + for ann in test.values(): + all[i] = ann + i += 1 + for ann in dev.values(): + all[i] = ann + i += 1 + print(len(test),len(dev), len(all)) + save_json(all, "./dataset/self_supervised_dataset/all_annotations.json") + train = {} + test = {} + for i, ann in all.items(): + train[i] = {key: ann[key] for key in ann.keys() if key in ['face', 'name', 'img', 'subtitle', 'clip', 'series']} + test[i] = {key: ann[key] for key in ann.keys() if key in ['face', 'img', 'subtitle', 'clip', 'series', 'target_name']} + save_json(train, "./dataset/self_supervised_dataset/train_annotations.json") + save_json(test, 
"./dataset/self_supervised_dataset/test_annotations.json") + + + +if __name__ == "__main__": + # config priority: + # 1. arguments + # 2. config.yaml file + # 3. defaults.py file + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + + print("Command line arguments: " + str(args)) + print("Running with full config:\n{}".format(cfg)) + Path(cfg.TRAINING.project_dir).mkdir(parents=True, exist_ok=True) + Path(os.path.join(cfg.TRAINING.project_dir, "model")).mkdir(parents=True, exist_ok=True) + config_path = os.path.join(cfg.TRAINING.project_dir, "config.yaml") + with PathManager.open(config_path, "w") as f: + f.write(cfg.dump()) + print("Full config saved to {}".format(os.path.abspath(config_path))) + set_global_cfg(cfg) + visualize(cfg) + +def qualitative_results(): + import xlsxwriter + import os + from tvqa_dataset import save_json, load_json + import numpy as np + prediction_dict = load_json("./dataset/bbt_train_annotations_faster_closestcluster.json") + img_to_faces = {} + for i, ann in prediction_dict.items(): + img = ann['clip']+'_'+ann['img'] + if img in img_to_faces: + img_to_faces[img]['train_id'].append(i) + img_to_faces[img]['face'].append(ann['face']) + else: + img_to_faces[img] = {'train_id': [i], 'face': [ann['face']]} + + + + face_path = "./dataset/bbt_frames/" + img_path = "./dataset/frames_hq/bbt_frames/" + workbook = xlsxwriter.Workbook(f"./dataset/friends_qualitative_results.xlsx") + worksheet = workbook.add_worksheet() + worksheet.write('A1', 'img') + worksheet.write('B1', 'subtitle') + worksheet.write('C1', 'faces') + worksheet.write('D1', 'stage1') + worksheet.write('E1', 'stage2') + worksheet.write('F1', 'stage3') + worksheet.write('G1', 'face1') + worksheet.write('H1', 'face2') + worksheet.write('I1', 'face3') + worksheet.write('J1', 'face4') + + idx = 1 + for img, values in img_to_faces.items(): + train_ids = values['train_id'] + faces = values['face'] + if len(faces) != 3: + continue + ann = prediction_dict[train_ids[0]] + # if len(ann['name']) < 5: + # continue + image = "_".join(ann['face'].split('_')[:-1]) + if image not in ['s03e01_seg02_clip_15_00032', 's01e13_seg02_clip_11_00170','s04e06_seg02_clip_16_00018', 's04e11_seg02_clip_02_00057','s04e11_seg02_clip_02_00061','s04e13_seg02_clip_00_00076','s04e15_seg02_clip_01_00147','s05e08_seg01_clip_00_00099','s05e10_seg02_clip_04_00038','s06e23_seg02_clip_07_00135','s07e11_seg02_clip_18_00029','s08e03_seg02_clip_05_00082','s08e10_seg02_clip_08_00102','s08e23_seg02_clip_04_00057','s09e17_seg02_clip_08_00060','s09e22_seg02_clip_07_00075','s09e24_seg02_clip_09_00071','s09e24_seg02_clip_12_00055','s10e06_seg02_clip_12_00123','s03e21_seg02_clip_13_00094','s10e01_seg02_clip_13_00019','s04e04_seg02_clip_04_00044','s03e01_seg02_clip_15_00051','s04e24_seg01_clip_01_00225']: + continue + worksheet.insert_image(idx, 0, os.path.join(img_path+ann['clip'], ann['img']+'.jpg')) + worksheet.write(idx, 1, ann['subtitle']) + worksheet.write(idx, 2, " ".join(faces)) + stage1, stage2, stage3 = ann['name'], [], [] + for i, (train_id, face) in enumerate(zip(train_ids, faces)): + ann = prediction_dict[train_id] + stage2.append(ann['cleansed']) + stage3.append(ann['closest_cluster']) + worksheet.insert_image(idx, i+6, os.path.join(face_path, face)) + # if set(stage2) == set(stage3): + # continue + print(f"image: {ann['img']+'.jpg'}, subtitle: {ann['subtitle']}, faces: {faces}, stage1: {stage1}, stage2: 
{stage2}, stage3: {stage3}")
+        # Assumption: stage1/stage2/stage3 are lists whose entries may be label names or ids
+        # depending on the prediction file, so cast to str before joining.
+        worksheet.write(idx, 3, " ".join(map(str, stage1)))
+        worksheet.write(idx, 4, " ".join(map(str, stage2)))
+        worksheet.write(idx, 5, " ".join(map(str, stage3)))
+        idx += 1
+    workbook.close() \ No newline at end of file