-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add new script to output the belief file csv
- Loading branch information
1 parent
cc5bfe9
commit fafc277
Showing
2 changed files
with
248 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
""" | ||
This class is meant to create a CSV file in the format that the BBN system expects. | ||
""" | ||
|
||
# Max number of seconds in the first time difference between when | ||
# the node started and when the first update came in. | ||
# if larger, a correction will be made to make up for start up time | ||
START_UP_TIME_MAX = 0.6 | ||
|
||
class BState: | ||
""" | ||
Enum for the state of the belief | ||
""" | ||
def __init__(self): | ||
self.current = "current" | ||
self.unobserved = "unobserved" | ||
self.done = "done" | ||
|
||
|
||
class BeliefFile: | ||
def __init__(self, filename: str, skill: str, labels: list, start_time: int): | ||
self.filename = filename | ||
self.skill = skill | ||
self.labels = labels | ||
self.running_state = {} # keeps track of the steps | ||
# initialize the running states | ||
for label in labels: | ||
self.running_state[label] = BState().unobserved | ||
# set the first step to current | ||
# NOTE: the example files given had this set to current | ||
# from the very beginnning - an assumption we are making here, too | ||
self.running_state[1.0] = BState().current | ||
# this will be used to calculate the current time in the video | ||
self.start_time = start_time | ||
|
||
# initialize the file - in case we need to overwrite it | ||
with open(self.filename, 'w') as f: | ||
f.write("") | ||
|
||
# flag for handling how long it takes to start up the video | ||
self.first_time_diff = True | ||
|
||
def _add_row_to_file(self, row: str) -> None: | ||
# append the row to the file | ||
with open(self.filename, 'a') as f: | ||
f.write(row) | ||
|
||
def _add_rows(self, conf_array: list, ctime: float) -> None: | ||
""" | ||
Add multiple rows to the file based on the labels | ||
""" | ||
# <skill>, <step_num>, <state>, <confidence>, <timestep> | ||
row = self.skill | ||
|
||
# add the rows | ||
for step in self.labels: | ||
_row = row + f",{step},{self.running_state[step]}," | ||
_row = _row + f"{conf_array[int(step)]},{ctime}\n" # _row = _row + f"{conf_array[int(step)]:0.8f},{ctime:0.8f}\n" | ||
self._add_row_to_file(_row) | ||
|
||
def final_step_done(self) -> None: | ||
""" | ||
This method is called when the final step is done. | ||
""" | ||
# set the final step | ||
self.running_state[self.labels[-1]] = BState().done | ||
|
||
def update_values(self, current_step: float, conf_array: list, current_time: int) -> None: | ||
""" | ||
When you provide an update, this method will update internal state | ||
and trigger a write to the file. | ||
""" | ||
curr_time = float(current_time - self.start_time) * 1e-9 # get seconds from nano | ||
|
||
# correction of the starting time if we notice that the first | ||
# time difference is too large | ||
if self.first_time_diff and curr_time > START_UP_TIME_MAX: | ||
self.first_time_diff = False | ||
self.start_time = current_time # save this for the next update | ||
# assume 0 for now | ||
curr_time = 0.0 | ||
|
||
# check the states and see if they changed | ||
if current_step > 0 and self.running_state[current_step] != BState().current: | ||
# set the current step | ||
self.running_state[current_step] = BState().current | ||
|
||
# see if the previous state was current - that means we change it to done | ||
prev_step = current_step - 1.0 | ||
if prev_step > 0 and self.running_state[prev_step] == BState().current: | ||
self.running_state[prev_step] = BState().done | ||
|
||
# write the rows to the file | ||
self._add_rows(conf_array, curr_time) |
154 changes: 154 additions & 0 deletions
154
angel_system/global_step_prediction/get_bbn_belief_file.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
""" | ||
This script will take a kwcoco file that was output from the TCN node (for example) | ||
and output the belief file that is used by the BBN eval_kit. The belief file is a CSV. | ||
""" | ||
from pathlib import Path | ||
|
||
import click | ||
import kwcoco | ||
import numpy as np | ||
import yaml | ||
|
||
from angel_system.global_step_prediction.belief_file import BeliefFile | ||
from angel_system.global_step_prediction.global_step_predictor import ( | ||
GlobalStepPredictor, | ||
) | ||
|
||
# TODO: make these options in the future? | ||
threshold_multiplier_weak = 1.0 | ||
threshold_frame_count = 3 | ||
threshold_frame_count_weak = 8 | ||
deactivate_thresh_frame_count = 8 | ||
|
||
def get_belief_file( | ||
coco_ds: kwcoco.CocoDataset, | ||
medical_task="r18", | ||
code_dir=Path("."), | ||
out_file=Path("./belief_file.csv"), | ||
model_file=Path("./model_files/task_monitor/global_step_predictor_act_avgs_R18.npy"), | ||
) -> None: | ||
""" | ||
Run the inference and create the belief file. | ||
""" | ||
|
||
# path to the medical activity labels | ||
act_path = code_dir / "config/activity_labels/medical" / f"{medical_task}.yaml" | ||
|
||
# load the steps from the activity config file | ||
with open(act_path, "r") as stream: | ||
config = yaml.safe_load(stream) | ||
labels = [] | ||
for lbl in config["labels"]: | ||
id = float(lbl["id"]) # using float based on the belief file format | ||
if id > 0: # skip the background label - not used in belief format | ||
labels.append(id) | ||
print(f"Labels: {labels}") | ||
|
||
start_time = 0 # start of the video | ||
|
||
# setup the belief file | ||
print(f"setting up output: {out_file}") | ||
belief = BeliefFile(out_file, medical_task.upper(), labels, start_time) | ||
|
||
# setup the global step predictor | ||
gsp = GlobalStepPredictor( | ||
threshold_multiplier_weak=threshold_multiplier_weak, | ||
threshold_frame_count=threshold_frame_count, | ||
threshold_frame_count_weak=threshold_frame_count_weak, | ||
deactivate_thresh_frame_count=deactivate_thresh_frame_count, | ||
recipe_types=[f"{medical_task}"], | ||
activity_config_fpath=act_path.as_posix(), | ||
recipe_config_dict={ | ||
f"{medical_task}": code_dir | ||
/ "config/tasks/medical" | ||
/ f"{medical_task}.yaml" | ||
}, | ||
) | ||
# load the model | ||
gsp.get_average_TP_activations_from_file(model_file) | ||
|
||
all_vid_ids = np.unique(np.asarray(coco_ds.images().lookup("video_id"))) | ||
for vid_id in all_vid_ids: | ||
print(f"vid_id {vid_id}===========================") | ||
|
||
image_ids = coco_ds.index.vidid_to_gids[vid_id] | ||
annots_images = coco_ds.subset(gids=image_ids, copy=True) | ||
|
||
# All N activity confs x each video frame | ||
activity_confs = annots_images.annots().get("prob") | ||
|
||
# get the frame_index from the images | ||
ftimes = annots_images.images().lookup("frame_index") | ||
#print(ftimes) | ||
|
||
step_mode = "granular" | ||
for i, conf_array in enumerate(activity_confs): | ||
current_time = ftimes[i] # get the time from the image's frame_index | ||
|
||
if current_time > 0: # skip any 0 index frames | ||
tracker_dict_list = gsp.process_new_confidences(np.array([conf_array])) | ||
for task in tracker_dict_list: | ||
current_step_id = task[f"current_{step_mode}_step"] | ||
|
||
# If we are on the last step and it is not active, mark it as done | ||
if ( | ||
current_step_id == task[f"total_num_{step_mode}_steps"] - 1 | ||
and not task["active"] | ||
): | ||
belief.final_step_done() | ||
|
||
print(f"Updating based on: {current_time}") | ||
belief.update_values(current_step_id, conf_array, current_time) | ||
|
||
print(f"finished writing belief file: {out_file}") | ||
|
||
|
||
@click.command(context_settings={"help_option_names": ["-h", "--help"]}) | ||
@click.argument( | ||
"medical_task", | ||
type=str, | ||
) | ||
@click.argument( | ||
"coco_file", | ||
type=click.Path( | ||
exists=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path | ||
), | ||
default="./stuff/r18_bench1_activity_predictions.kwcoco", | ||
) | ||
@click.option( | ||
"--code_dir", | ||
type=click.Path( | ||
exists=True, file_okay=False, readable=True, resolve_path=True, path_type=Path | ||
), | ||
default=".", | ||
help="The path to the code directory", | ||
) | ||
@click.option( | ||
"--out_file", | ||
type=click.Path(readable=True, resolve_path=True, path_type=Path), | ||
default="./belief_file.csv", | ||
help="The path to where to save the output file", | ||
) | ||
def run_expirement( | ||
medical_task: str, | ||
coco_file: Path, | ||
code_dir: Path, | ||
out_file: Path, | ||
) -> None: | ||
""" | ||
Creates the belief file. | ||
""" | ||
|
||
print(f"Running medical task: {medical_task}") | ||
print(f"coco_file = {coco_file}") | ||
|
||
get_belief_file( | ||
kwcoco.CocoDataset(coco_file), | ||
medical_task=medical_task, | ||
code_dir=code_dir, | ||
out_file=out_file, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_expirement() |