Skip to content

Commit

Permalink
fixed merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
cameron-a-johnson committed Dec 26, 2024
2 parents 93ea2f2 + 7c524b3 commit 997ac60
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions angel_system/global_step_prediction/run_expirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def run_inference_all_vids(
preds, gt = [], []
avg_smd, avg_smd_normd = np.array([]), np.array([])
mean_F1s = np.array([])
<<<<<<< HEAD
tasks_completed = np.array([])
=======
>>>>>>> 7c524b3c2c8e0d3b3952bf4f2cc97ba28aa00560
for vid_id in all_vid_ids:
print(f"vid_id {vid_id}===========================")

Expand Down Expand Up @@ -144,12 +147,19 @@ def get_unique(activity_ids):
mean_F1s = np.append(mean_F1s, mean_F1)

pred_history = step_predictor.get_single_tracker_pred_history()
<<<<<<< HEAD
smds, smds_normd, task_completed = get_start_moment_distances(pred_history, activity_gt_maxes)
tasks_completed = np.append(tasks_completed, task_completed)

if not math.isnan(compute_mean_smd(smds)):
avg_smd = np.append(avg_smd, compute_mean_smd(smds))
avg_smd_normd = np.append(avg_smd_normd, compute_mean_smd(smds_normd))
=======
smds, smds_normd = get_start_moment_distances(pred_history, activity_gt_maxes)

avg_smd = np.append(avg_smd, np.mean(smds))
avg_smd_normd = np.append(avg_smd_normd, np.mean(smds_normd))
>>>>>>> 7c524b3c2c8e0d3b3952bf4f2cc97ba28aa00560

print(f"smds (# frames): {smds}, normalized:{smds_normd}")
try:
Expand All @@ -165,7 +175,10 @@ def get_unique(activity_ids):
)
print("########## OVERALL")
print(f"Overall average smd: {np.mean(avg_smd)}. Normalized: {np.mean(avg_smd_normd)}")
<<<<<<< HEAD
print(f"tasks completed: {np.sum(tasks_completed)} / {len(tasks_completed)}")
=======
>>>>>>> 7c524b3c2c8e0d3b3952bf4f2cc97ba28aa00560
print(f"overall mean F1: {np.mean(mean_F1s)}")

def get_start_moment_distances(pred_history, activity_gt_maxes):
Expand All @@ -180,12 +193,16 @@ def get_start_moment_distances(pred_history, activity_gt_maxes):
vid_length = len(activity_gt_maxes)
smds = np.zeros(num_classes)
smds_normd = np.zeros(num_classes)
<<<<<<< HEAD
task_completed = 1
=======
>>>>>>> 7c524b3c2c8e0d3b3952bf4f2cc97ba28aa00560
for i in range(num_classes):
if i+1 in pred_history:
smds[i] = abs(np.where(pred_history == i+1)[0][0] - activity_gt_maxes.index(i+1))
smds_normd[i] = smds[i] / vid_length
else:
<<<<<<< HEAD
smds[i] = -1
smds_normd[i] = -1
if i+1 == num_classes:
Expand All @@ -200,6 +217,11 @@ def compute_mean_smd(smd_array):

mask = smd_array >= 0
return np.mean(smd_array[mask])
=======
smds[i] = vid_length
smds_normd[i] = 1
return smds, smds_normd
>>>>>>> 7c524b3c2c8e0d3b3952bf4f2cc97ba28aa00560


def compute_class_f1s_and_mean_f1(TP,FP,FN):
Expand Down

0 comments on commit 997ac60

Please sign in to comment.