-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
80f44db
commit a9aa1bb
Showing
10 changed files
with
4,151 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
|
||
from fvcore.common.config import CfgNode as _CfgNode | ||
import argparse | ||
|
||
|
||
class CfgNode(_CfgNode):
    """
    Project config node built on `fvcore.common.config.CfgNode`.

    The difference from the base class visible here is that
    `merge_from_file` uses unsafe yaml loading by default.
    Note that this may lead to arbitrary code execution: you must not
    load a config file from untrusted sources before manually inspecting
    the content of the file.
    """

    # Note that the default value of allow_unsafe is changed to True
    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
        """
        Load a yaml config file and merge its content into this node.

        Args:
            cfg_filename: path of the yaml file to load.
            allow_unsafe: forwarded to `_CfgNode.load_yaml_with_base`.
                Defaults to True here; pass False for files whose content
                has not been inspected.
        """
        loaded_cfg = _CfgNode.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
        # Re-wrap the loaded content in the same node type as `self`
        # before merging.
        loaded_cfg = type(self)(loaded_cfg)

        self.merge_from_other_cfg(loaded_cfg)

    def dump(self, *args, **kwargs):
        """
        Returns:
            str: a yaml string representation of the config
        """
        # to make it show up in docs
        return super().dump(*args, **kwargs)
|
||
|
||
# Process-wide config object. It starts empty and is filled in place by
# `set_global_cfg`.
global_cfg = CfgNode()
|
||
|
||
def get_cfg() -> CfgNode:
    """
    Get a copy of the default config.

    Returns:
        a `CfgNode` obtained by cloning `defaults._C`.
    """
    # NOTE(review): presumably imported inside the function because
    # `defaults` itself imports `CfgNode` from this module, which would make
    # a top-level import circular -- confirm before moving it.
    from defaults import _C

    return _C.clone()
|
||
|
||
def set_global_cfg(cfg: CfgNode) -> None:
    """
    Let the global config mirror the given cfg.

    Assume that the given "cfg" has the key "KEY", after calling
    `set_global_cfg(cfg)`, the key can be accessed by:

    .. code-block:: python

        from config import global_cfg
        print(global_cfg.KEY)

    By using a hacky global config, you can access these configs anywhere,
    without having to pass the config object or the values deep into the code.
    This is a hacky feature introduced for quick prototyping / research exploration.

    Args:
        cfg: the config whose content replaces the current global config.
    """
    global global_cfg
    # Mutate the existing object instead of rebinding the name, so modules
    # that already did `from config import global_cfg` see the new content.
    global_cfg.clear()
    global_cfg.update(cfg)
|
||
|
||
def default_argument_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser used to select and override a config.

    Returns:
        an `argparse.ArgumentParser` that understands:

        * ``--config-file FILE``: path to a config file (default: ``""``).
        * ``opts``: every remaining command-line token, collected verbatim
          into a list.
    """
    parser = argparse.ArgumentParser(description="tvqa config file")
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        # REMAINDER: swallow everything after the recognized options.
        nargs=argparse.REMAINDER,
    )
    return parser
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Config overrides; keys mirror the nested layout of the defaults module.
TRAINING:
  project_dir: "./output/default/"
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
from config import CfgNode as CN | ||
|
||
# Root of the default configuration tree; every section below hangs off it.
_C = CN()

# ---------------------------------------------------------------------------
# Training options
# ---------------------------------------------------------------------------
_C.TRAINING = CN()

_C.TRAINING.project_dir = "./output/default/"
# NOTE(review): machine-specific absolute path; override it from a config
# file or the command line on any other machine.
_C.TRAINING.data_path = "/home/mayoughi/tvqa_experiment/dataset/friends_frames/"
_C.TRAINING.epochs = 100
# NOTE(review): both a single decay epoch and a list of decay epochs are
# defined; which one is consumed is not visible from this file.
_C.TRAINING.lr_decay_epoch = 50
_C.TRAINING.lr_decay_epochs = [50, 75, 90]
_C.TRAINING.pretrained = True
_C.TRAINING.lr = 0.12
_C.TRAINING.batch_size = 256
_C.TRAINING.lr_decay_rate = 0.1
_C.TRAINING.momentum = 0.9
_C.TRAINING.weight_decay = 0.0001
_C.TRAINING.last_commit = "unknown"
_C.TRAINING.supervised = False
_C.TRAINING.data_mode = "correct_target_id"  # cleansed, correct_target_id, weak_label
_C.TRAINING.series = "friends"  # friends, bbt
_C.TRAINING.clustering = "KMeans"  # AgglomerativeClustering, KMeans, MiniBatchKMeans
_C.TRAINING.kmeans_batch_size = 100  # default
_C.TRAINING.exp_type = "normal"  # normal, oracle
_C.TRAINING.ours_or_baseline = "ours"  # ours, baseline

# ---------------------------------------------------------------------------
# self-supervised parameters
# ---------------------------------------------------------------------------
_C.SSL = CN()
_C.SSL.align_alpha = 2
_C.SSL.unif_t = 2
_C.SSL.align_w = 1
_C.SSL.unif_w = 1
_C.SSL.random_crop = 100
_C.SSL.include_unknowns = True
_C.SSL.joint = False
_C.SSL.face_layer = False
_C.SSL.sub_layer = False
_C.SSL.mix_layer = True
_C.SSL.face_layer_out_features = 512
_C.SSL.sub_layer_out_features = 768
_C.SSL.mix_layer_out_features = 1280
_C.SSL.mix_layer_in_features = 1280

# NOTE(review): a `supervised` flag also exists under TRAINING; confirm
# which of the two the training code reads.
_C.SSL.supervised = False
_C.SSL.epsilon = 0.1  # with probability epsilon pick from closest cluster

# ---------------------------------------------------------------------------
# Model options
# ---------------------------------------------------------------------------
_C.MODEL = CN()
_C.MODEL.out_features_1 = 512
_C.MODEL.out_features_2 = 512
_C.MODEL.out_features_3 = 512

# Empty section; no keys are defined for it in this file.
_C.GLOBAL = CN()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch | ||
import os | ||
from torchvision import datasets | ||
from PIL import Image | ||
from facenet_pytorch import MTCNN, InceptionResnetV1 | ||
from PIL import Image, ImageDraw | ||
from facenet_pytorch import MTCNN, extract_face | ||
import random | ||
|
||
# Script: for each sub-directory ("clip") of the current working directory,
# run MTCNN on its images, save one crop per detected face and one annotated
# copy of the frame (face boxes plus landmark markers).
#
# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed. In particular the trailing `break` is assumed to end the
# per-clip image loop after the first image in which faces were found --
# confirm against the original script.
# NOTE(review): output files are named with a random id in [1, 10000]; two
# images can draw the same id and overwrite each other's outputs.

# Presumably keep_all=True keeps every detected face rather than a single
# one -- confirm against the facenet_pytorch documentation.
mtcnn = MTCNN(keep_all=True)
clip_name = sorted(os.listdir(os.getcwd()))

for i in clip_name:
    clip_dir = os.path.join(os.getcwd(), i)
    if not os.path.isdir(clip_dir):
        # The working directory may also contain plain files; listing one
        # as a directory would raise.
        continue
    for image in sorted(os.listdir(clip_dir)):
        # Context manager so the underlying image file is closed promptly.
        with Image.open(os.path.join(clip_dir, image)) as img:
            img_id = random.randint(1, 10000)
            boxes, probs, points = mtcnn.detect(img, landmarks=True)
            if boxes is None:
                # No detections in this image: try the next one.
                continue
            # Draw on a copy so the crops below come from the clean frame.
            img_draw = img.copy()
            draw = ImageDraw.Draw(img_draw)
            for f, (box, point) in enumerate(zip(boxes, points)):
                draw.rectangle(box.tolist(), width=5)
                for p in point:
                    # Box spanning 10 units either side of the landmark.
                    draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10)
                extract_face(img, box, save_path='/home/mayoughi/outputs/detected_face_{}_{}.png'.format(img_id, f))
            img_draw.save('/home/mayoughi/outputs/annotated_faces_{}.png'.format(img_id))
        break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# NOTE(review): the PyPI project named "apex" is likely not NVIDIA Apex
# (which is installed from its GitHub repository) -- confirm which one the
# code needs.
apex==0.9.10dev
classy_vision==0.6.0
facenet_pytorch==2.5.2
fuzzywuzzy==0.18.0
fvcore==0.1.5.post20220512
matplotlib==3.5.2
numpy==1.22.4
omegaconf==2.2.2
openpyxl==3.0.10
openpyxl_image_loader==1.0.5
pandas==1.4.2
Pillow==9.1.1
pysrt==1.1.2
PyYAML==6.0
scikit_learn==1.1.1
scipy==1.8.1
seaborn==0.11.2
sentence_transformers==2.2.0
simcse==0.4
tensorboard==2.9.1
torch==1.11.0
torchvision==0.12.0
tqdm==4.64.0
transformers==4.19.4
# NOTE(review): the UMAP algorithm's `import umap` is provided by the
# "umap-learn" distribution; the PyPI project named "umap" is likely a
# different package -- confirm.
umap==0.1.1
vissl==0.1.6
xlsxwriter==3.0.3
Oops, something went wrong.