Skip to content

Commit

Permalink
change
Browse files Browse the repository at this point in the history
  • Loading branch information
rogeriobonatti committed Aug 12, 2022
1 parent 577dd6f commit 8315dcf
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions mushr_rhc_ros/src/eval_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
times = times[condition]

# clear data above episode max
max_dist = 1500
max_dist = 1000
condition = distances<max_dist
distances = distances[condition]
times = times[condition]
Expand Down Expand Up @@ -93,7 +93,7 @@
'12L': all_vals_mean[2::4],
'24L': all_vals_mean[3::4],
# 'Dataset fraction': [0.0, 0.01, 0.1, 0.5, 1.0],
'Dataset fraction': ['540', '30K', '300K', '1.5M', '3M']}
'Number of Tokens Processed': ['540', '30K', '300K', '1.5M', '3M']}

# data = {'3L': all_vals_median[::4],
# '6L': all_vals_median[1::4],
Expand All @@ -106,6 +106,8 @@

df = pd.DataFrame(data)
print(df)
dfm = df.melt('Dataset fraction', var_name='cols', value_name='Average meters traveled')
sns.catplot(x="Dataset fraction", y="Average meters traveled", hue='cols', data=dfm, kind='point')
dfm = df.melt('Number of Tokens Processed', var_name='cols', value_name='Average meters traveled [m]')
sns.catplot(x="Number of Tokens Processed", y="Average meters traveled [m]", hue='cols', data=dfm, kind='point')
plt.grid()
plt.gcf().subplots_adjust(bottom=0.15)
plt.savefig('/home/azureuser/hackathon_data_premium/e2e_eval_models4/model_test/all_plots.png')

0 comments on commit 8315dcf

Please sign in to comment.