Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Melika-Ayoughi authored Jul 20, 2022
1 parent 80f44db commit a9aa1bb
Show file tree
Hide file tree
Showing 10 changed files with 4,151 additions and 0 deletions.
81 changes: 81 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from fvcore.common.config import CfgNode as _CfgNode
import argparse


class CfgNode(_CfgNode):
    """
    Project-local subclass of `fvcore.common.config.CfgNode`.

    It differs from the base class in that `merge_from_file` loads YAML in
    unsafe mode by default. Unsafe loading may lead to arbitrary code
    execution: never merge a config file that comes from an untrusted source
    before manually inspecting its content.
    """

    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
        """
        Load a YAML config file and merge its content into this node.

        Args:
            cfg_filename (str): path of the YAML file to load.
            allow_unsafe (bool): forwarded to `load_yaml_with_base`.
                Defaults to True in this subclass.
        """
        raw_cfg = _CfgNode.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
        # Re-wrap the loaded data in the caller's own class before merging it in.
        self.merge_from_other_cfg(type(self)(raw_cfg))

    def dump(self, *args, **kwargs):
        """
        Returns:
            str: a yaml string representation of the config
        """
        # Plain pass-through to the parent; overridden so it shows up in docs.
        return super().dump(*args, **kwargs)


# Module-level config object; starts empty and is refilled in place by
# set_global_cfg(), which clears it and copies the given config's entries in.
global_cfg = CfgNode()


def get_cfg() -> CfgNode:
    """
    Get a copy of the default config.

    Returns:
        CfgNode: a clone of the `_C` node defined in `defaults`.
    """
    # Imported inside the function: `defaults` imports this module, so a
    # top-level import here would be circular.
    from defaults import _C as default_cfg

    return default_cfg.clone()


def set_global_cfg(cfg: CfgNode) -> None:
    """
    Make the module-level `global_cfg` mirror the given cfg.

    The existing `global_cfg` object is emptied and then filled with the
    entries of `cfg`. If `cfg` has the key "KEY", then after calling
    `set_global_cfg(cfg)` the key can be read with:

    .. code-block:: python

        from config import global_cfg
        print(global_cfg.KEY)

    This is a hacky convenience for quick prototyping / research exploration:
    it lets code read config values anywhere without passing the config
    object (or its values) down through every call.

    Args:
        cfg (CfgNode): the config whose entries are copied into `global_cfg`.
    """
    # The shared object is mutated in place and the name is never rebound,
    # so no `global` declaration is required to reach it.
    shared_cfg = global_cfg
    shared_cfg.clear()
    shared_cfg.update(cfg)


def default_argument_parser():
    """
    Build the command-line parser used to select and tweak a config.

    Recognised arguments:
        --config-file FILE: path to a config file (default: empty string).
        opts: every remaining command-line token, collected as a list and
            meant for overriding config options.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser(description="tvqa config file")

    config_file_spec = {
        "default": "",
        "metavar": "FILE",
        "help": "path to config file",
    }
    cli.add_argument("--config-file", **config_file_spec)

    # REMAINDER gathers everything left on the command line into `opts`.
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    return cli
3 changes: 3 additions & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
TRAINING:
project_dir: "./output/default/"

54 changes: 54 additions & 0 deletions defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from config import CfgNode as CN

# Default configuration tree. `config.get_cfg()` returns clones of `_C`.
# Nested dicts are turned into nested config nodes by the CN constructor.
_C = CN(
    {
        "TRAINING": {
            "project_dir": "./output/default/",
            "data_path": "/home/mayoughi/tvqa_experiment/dataset/friends_frames/",
            "epochs": 100,
            "lr_decay_epoch": 50,
            "lr_decay_epochs": [50, 75, 90],
            "pretrained": True,
            "lr": 0.12,
            "batch_size": 256,
            "lr_decay_rate": 0.1,
            "momentum": 0.9,
            "weight_decay": 0.0001,
            "last_commit": "unknown",
            "supervised": False,
            # one of: cleansed, correct_target_id, weak_label
            "data_mode": "correct_target_id",
            # one of: friends, bbt
            "series": "friends",
            # one of: AgglomerativeClustering, KMeans, MiniBatchKMeans
            "clustering": "KMeans",
            # default
            "kmeans_batch_size": 100,
            # one of: normal, oracle
            "exp_type": "normal",
            # one of: ours, baseline
            "ours_or_baseline": "ours",
        },
        # self-supervised parameters
        "SSL": {
            "align_alpha": 2,
            "unif_t": 2,
            "align_w": 1,
            "unif_w": 1,
            "random_crop": 100,
            "include_unknowns": True,
            "joint": False,
            "face_layer": False,
            "sub_layer": False,
            "mix_layer": True,
            "face_layer_out_features": 512,
            "sub_layer_out_features": 768,
            "mix_layer_out_features": 1280,
            "mix_layer_in_features": 1280,
            "supervised": False,
            # with probability epsilon pick from closest cluster
            "epsilon": 0.1,
        },
        "MODEL": {
            "out_features_1": 512,
            "out_features_2": 512,
            "out_features_3": 512,
        },
        # Intentionally empty section.
        "GLOBAL": {},
    }
)
30 changes: 30 additions & 0 deletions prepare_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
import os
from torchvision import datasets
from PIL import Image
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image, ImageDraw
from facenet_pytorch import MTCNN, extract_face
import random

# Script body: runs MTCNN face detection over images found in the
# sub-entries of the current working directory, saves a crop for each
# detected face and an annotated copy of each image.
# NOTE(review): leading indentation was lost in this copy of the file, so the
# loop nesting below -- in particular which loop the final `break` exits --
# cannot be determined from this text. Restore it from the original file
# before running or refactoring.
mtcnn = MTCNN(keep_all=True)
# Every entry of the current working directory is treated as a clip folder.
clip_name = sorted(os.listdir(os.getcwd()))


for i in clip_name:
clip_dir = os.path.join(os.getcwd(),i)
for image in sorted(os.listdir(clip_dir)):
img = Image.open(os.path.join(clip_dir, image))
# NOTE(review): the id used in output filenames is a random integer in
# [1, 10000]; two images can draw the same id, in which case the later
# output files would reuse the same paths -- confirm this is acceptable.
img_id = random.randint(1,10000)
# With landmarks=True the detector result is unpacked into boxes, probs, points.
boxes, probs, points = mtcnn.detect(img, landmarks=True)
# Draw on a copy so `img` itself stays unmodified for extract_face().
img_draw = img.copy()
draw = ImageDraw.Draw(img_draw)
# `boxes` is None when nothing was detected; skip this image.
if boxes is None:
continue
for f, (box, point) in enumerate(zip(boxes, points)):
# Outline the face bounding box.
draw.rectangle(box.tolist(), width=5)
for p in point:
# Mark each landmark with a square spanning p-10 .. p+10.
draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10)
# Writes the face crop to save_path; the returned value is not used afterwards.
face = extract_face(img, box, save_path='/home/mayoughi/outputs/detected_face_{}_{}.png'.format(img_id, f))
# Save the image with the drawn boxes and landmark markers.
img_draw.save('/home/mayoughi/outputs/annotated_faces_{}.png'.format(img_id))
# NOTE(review): scope of this `break` is ambiguous here (see note at top).
break
27 changes: 27 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
apex==0.9.10dev
classy_vision==0.6.0
facenet_pytorch==2.5.2
fuzzywuzzy==0.18.0
fvcore==0.1.5.post20220512
matplotlib==3.5.2
numpy==1.22.4
omegaconf==2.2.2
openpyxl==3.0.10
openpyxl_image_loader==1.0.5
pandas==1.4.2
Pillow==9.1.1
pysrt==1.1.2
PyYAML==6.0
scikit_learn==1.1.1
scipy==1.8.1
seaborn==0.11.2
sentence_transformers==2.2.0
simcse==0.4
tensorboard==2.9.1
torch==1.11.0
torchvision==0.12.0
tqdm==4.64.0
transformers==4.19.4
umap==0.1.1
vissl==0.1.6
xlsxwriter==3.0.3
Loading

0 comments on commit a9aa1bb

Please sign in to comment.