-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathLS_AMEGO.py
184 lines (156 loc) · 6.86 KB
/
LS_AMEGO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
import json
import tqdm
import numpy as np
import torch
import torch.nn as nn
from epic_kitchens.hoa import load_detections
from tools.data import FrameDsetSubsampled, FlowDataset
from torchvision import transforms
import argparse
from omegaconf import OmegaConf
class OnlineClusteringTrack:
    """Assign tracks to clusters one at a time using cosine similarity.

    A track joins the existing cluster with the highest mean similarity to
    its members when that score reaches ``clustering_threshold``; otherwise
    it starts a new cluster.
    """

    def __init__(self, clustering_threshold):
        """Create an empty clustering with the given acceptance threshold."""
        self.cosine = torch.nn.CosineSimilarity(dim=-1)
        self.threshold = clustering_threshold
        # cluster id -> list of track ids, and the reverse lookup.
        self.clusters = {}
        self.track2cluster = {}

    def score_track(self, track, features_track):
        """Return ``{cluster_id: mean cosine similarity}`` for ``track``.

        ``features_track`` maps track ids to feature tensors. With no clusters
        yet, a single placeholder score just below the threshold is returned
        so that the caller opens a new cluster.
        """
        if not self.clusters:
            return {0: self.threshold - 1}

        query = features_track[track]
        scores = {}
        for cluster_id, members in self.clusters.items():
            similarities = [
                self.cosine(features_track[member], query).item()
                for member in members
            ]
            scores[cluster_id] = np.mean(similarities)
        return scores

    def step(self, track, features_track):
        """Cluster ``track`` and return the id of the cluster it landed in."""
        scores = self.score_track(track, features_track)
        # max() keeps the first maximal entry, matching dict insertion order.
        best_cluster, best_score = max(scores.items(), key=lambda item: item[1])

        if best_score < self.threshold:
            fresh_id = len(self.clusters)
            self.clusters[fresh_id] = [track]
            self.track2cluster[track] = fresh_id
        else:
            self.clusters[best_cluster].append(track)
            self.track2cluster[track] = best_cluster
        return self.track2cluster[track]
class LS_AMEGO:
    """Group the frames of one video into tracks and cluster those tracks.

    Each subsampled frame is tested against a hand-detection filter and an
    optical-flow filter. Runs of accepted frames with at least
    ``config.consecutive`` entries become tracks. Every track is embedded
    with the ``vit_l16_in1k`` model loaded from the ``facebookresearch/swag``
    torch hub repo, clustered online, and the result is written as JSON.
    """

    def __init__(self, dset, root, config):
        """Build datasets, load detections and the feature network.

        Args:
            dset: Dataset object; only ``dset.name`` is read here.
            root: Data root passed to ``FrameDsetSubsampled`` and ``FlowDataset``.
            config: Configuration providing ``output_dir``, ``fps``, ``v_id``,
                ``consecutive``, ``flow_threshold``, ``hand_det_score``,
                ``no_filter_flow``, ``no_filter_hands``, ``video_fps`` and
                ``clustering_threshold_ls``.
        """
        self.dset = dset
        self.root = root
        self.config = config
        self.output_dir = config.output_dir
        self.fps = config.fps
        self.v_id = config.v_id
        self.consecutive = config.consecutive
        self.flow_threshold = config.flow_threshold
        self.hand_det_score = config.hand_det_score
        self.no_filter_flow = config.no_filter_flow
        self.no_filter_hands = config.no_filter_hands
        self.dataset = FrameDsetSubsampled(
            root, self.fps, self.v_id, dset.name, video_fps=config.video_fps)
        self.flow_dataset = FlowDataset(
            root, self.fps, self.v_id, dset.name, video_fps=config.video_fps)
        self.detections = load_detections(
            self.dataset.dset.detections_path(self.v_id))
        # _initialize_network also installs the image transform on the dataset,
        # so it must run after self.dataset exists.
        net = self._initialize_network()
        self.net = nn.DataParallel(net)
        self.net.eval().cuda()
        self.grouped_tracks = []  # finished tracks as (first_frame, last_frame)
        self.group = []  # frame numbers of the run currently being built
        self.final_results = {}
        self.features_track = {}
        self.clustering = OnlineClusteringTrack(config.clustering_threshold_ls)

    def _initialize_network(self):
        """Load the backbone without its head and set the dataset transform.

        Returns:
            The network with ``head`` replaced by ``nn.Identity``.
        """
        net = torch.hub.load("facebookresearch/swag", model="vit_l16_in1k")
        net.head = nn.Identity()
        resolution = 512
        model_transforms = transforms.Compose([
            transforms.Resize(
                resolution,
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ])
        self.dataset.dset.transform = model_transforms
        return net

    def step(self, frame_idx):
        """Add frame ``frame_idx`` to the current run, or close the run.

        ``frame_idx`` indexes ``self.dataset.frames``; element ``[1]`` of each
        entry is used as the frame number (detections are looked up at that
        number minus one).
        """
        frame = self.dataset.frames[frame_idx]
        flow = self._get_flow(frame_idx)
        hands = self._get_hands(frame[1] - 1)
        if self._should_process_frame(hands, flow):
            self.group.append(frame[1])
        else:
            self._process_group()

    def _get_flow(self, frame_idx):
        """Return the flow magnitude for a frame (0 for the first frame)."""
        if frame_idx == 0:
            return 0
        flow = self.flow_dataset[self.dataset.frames[frame_idx][1] - 1]
        # NOTE(review): reducing over dims 0-2 assumes a 3-D flow tensor - confirm.
        return torch.norm(flow, 2, dim=[0, 1, 2]).item()

    def _get_hands(self, frame_num):
        """Return the hand detections of ``frame_num`` scoring at least ``hand_det_score``."""
        detection = self.detections[frame_num]
        return [hand for hand in detection.hands if hand.score >= self.hand_det_score]

    def _should_process_frame(self, hands, flow):
        """Return whether a frame passes both the hand and the flow filter.

        Each filter is bypassed when its ``no_filter_*`` flag is set.
        """
        return (self.no_filter_hands or len(hands) != 0) and (
            self.no_filter_flow or flow <= self.flow_threshold)

    def _process_group(self):
        """Close the current run, keeping it as a track if it is long enough."""
        if len(self.group) >= self.consecutive:
            self.grouped_tracks.append((self.group[0], self.group[-1]))
        self.group = []

    @staticmethod
    def _track_frames(start, end):
        """Return the frame numbers of a track; the result is never empty.

        The range stays half-open (``end`` excluded). A single-frame track
        (``start == end``) would otherwise produce no frames, so it falls back
        to ``[start]``.
        """
        return list(range(start, end)) or [start]

    def extract_feat(self, track):
        """Return the mean network feature over the frames of ``track``.

        Args:
            track: ``(track_num, start_frame, end_frame)``.

        Returns:
            The per-frame network outputs stacked and averaged over frames,
            with the frame dimension kept (size 1).
        """
        features = []
        # Inference only: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            for frame in self._track_frames(track[1], track[2]):
                image = self.dataset.dset.load_image((self.v_id, frame))
                image = self.dataset.dset.transform(image)
                features.append(self.net(image.unsqueeze(0).cuda()))
        return torch.stack(features).mean(0, keepdim=True)

    def process(self):
        """Run grouping, feature extraction and clustering, then save results."""
        for frame_idx in tqdm.tqdm(range(len(self.dataset.frames))):
            self.step(frame_idx)
        self._process_group()  # Process any remaining frames in the group
        # Track ids are 1-based.
        for track_num, track in enumerate(self.grouped_tracks, start=1):
            start, end = track[0], track[-1]
            self.features_track[track_num] = self.extract_feat((track_num, start, end))
            cluster = self.clustering.step(track_num, self.features_track)
            self.final_results[track_num] = {
                'cluster': cluster,
                'features': self.features_track[track_num].cpu().numpy().tolist(),
                'num_frame': self._track_frames(start, end),
            }
        self._save_results()

    def _save_results(self):
        """Write the per-track results to ``<output_dir>/LS_AMEGO/<v_id>.json``."""
        output_dir = os.path.join(self.output_dir, 'LS_AMEGO')
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, self.v_id + ".json")
        with open(output_path, 'w', encoding='utf-8') as file:
            json.dump(list(self.final_results.values()), file, indent=4)
def parse_args(config_keys):
    """Parse command-line overrides for the given configuration keys.

    One optional string flag ``--<key>`` is registered per entry of
    ``config_keys``, plus a float ``--video_fps`` flag. Flags that are not
    supplied come back as ``None``.
    """
    cli = argparse.ArgumentParser(description='Modify configuration parameters.')
    for name in config_keys:
        flag = '--{}'.format(name)
        cli.add_argument(flag, type=str, help='Override {}'.format(name))
    cli.add_argument(
        '--video_fps',
        type=float,
        help='FPS of the video to be processed',
    )
    parsed = cli.parse_args()
    return parsed
if __name__ == '__main__':
    # Start from the YAML defaults, then apply any command-line overrides.
    config = OmegaConf.load('configs/default.yaml')
    config_keys = list(config.keys())
    args = parse_args(config_keys)

    for key in config_keys:
        override = getattr(args, key, None)
        if override is None:
            continue
        default = config[key]
        # Overrides arrive as strings; coerce them to the type of the default.
        # bool is tested before int because bool is a subclass of int.
        if isinstance(default, bool):
            override = override.lower() in ['true', '1']
        elif isinstance(default, int):
            override = int(override)
        elif isinstance(default, float):
            override = float(override)
        config[key] = override

    # video_fps is always taken from the command line (None when not given).
    config.video_fps = args.video_fps

    if config.dset != 'epic':
        from tools.data import SingleVideoDataset
        dset = SingleVideoDataset(config.root, config.v_id, config.video_fps)
    else:
        from tools.data import EPICDataset
        dset = EPICDataset(config.root)

    processor = LS_AMEGO(dset, config.root, config)
    processor.process()