Skip to content

Commit

Permalink
Show values inside heatmap if graph_show_values is given
Browse files Browse the repository at this point in the history
  • Loading branch information
tbarbette committed Nov 16, 2024
1 parent ef5226c commit 1a84ddd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
26 changes: 20 additions & 6 deletions npf/output/graph/plots/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
from npf.types.dataset import XYEB


def do_heatmap(graph, axis, key, result_type, data : XYEB, xdata : XYEB, vars_values: dict, shift=0, idx=0, sparse=False):
def do_heatmap(graph, axis, key, result_type, data : XYEB, xdata : XYEB, vars_values: dict, shift=0, idx=0, sparse=False,show_values=False):
graph.format_figure(axis, result_type, shift, key=key)
nseries = 0
yvals = []
for x,y,e,build in data:
nseries = max(len(y), nseries)
y = get_numeric(build._pretty_name)
yvals.append(y)


if not key in vars_values:
print("WARNING: Heatmap with an axis of size 1")
xvals = [1]
Expand All @@ -32,18 +32,18 @@ def do_heatmap(graph, axis, key, result_type, data : XYEB, xdata : XYEB, vars_va
xmax=len(xvals) - 1
ymin=0
ymax=len(yvals) - 1

data = [data[i] for i in np.argsort(yvals)]
yvals = [yvals[i] for i in np.argsort(yvals)]


matrix = np.empty(tuple((ymax-ymin + 1,xmax-xmin + 1)))
matrix[:] = np.NaN

if len(data) <= 1 or nseries <= 1:
print("WARNING: Heatmap needs two dynamic variables. The map will have a weird ratio")


for i, (x, ys, e, build) in enumerate(data): #X index
assert(isinstance(build,Build))
for yi in range(nseries): #index in the array of Y, so it is the index of X
Expand All @@ -59,6 +59,20 @@ def do_heatmap(graph, axis, key, result_type, data : XYEB, xdata : XYEB, vars_va
pos = axis.imshow(matrix)
axis.figure.colorbar(pos, ax=axis)

if show_values:
mean = np.mean(matrix)
for i in range(len(data)):
for j in range(nseries):
v = matrix[i, j]
text = axis.text(
j,
i,
f'%0.{str(show_values - 1)}f' %v,
ha="center",
va="center",
color="w" if v< mean else "black",
)

if sparse:
prop = xmax-xmin / ymax-ymin
if prop < 0:
Expand Down
32 changes: 19 additions & 13 deletions npf/output/grapher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,12 +1195,12 @@ def generate_plot_for_graph(self, i, i_subplot, figure, n_cols, n_lines, vars_va
xname=self.var_name(result_type)
elif graph_type == "heatmap":
"""Heatmap"""
r, ndata = do_heatmap(self, axis, key, result_type, data, xdata, vars_values, shift, ISUBPLOT, sparse = False)
r, ndata = do_heatmap(self, axis, key, result_type, data, xdata, vars_values, shift, ISUBPLOT, sparse = False, show_values=self.get_show_values())
default_add_legend = False
barplot = True
elif graph_type == "sparse_heatmap":
"""sparse Heatmap"""
r, ndata = do_heatmap(self, axis, key, result_type, data, xdata, vars_values, shift, ISUBPLOT, sparse = True)
r, ndata = do_heatmap(self, axis, key, result_type, data, xdata, vars_values, shift, ISUBPLOT, sparse = True, show_values=self.get_show_values())
default_add_legend = False
barplot = True

Expand Down Expand Up @@ -1551,15 +1551,21 @@ def generate_plot_for_graph(self, i, i_subplot, figure, n_cols, n_lines, vars_va
def reject_outliers(self, result, test):
return test.reject_outliers(result)

def write_labels(self, rects, plt, color, idx = 0, each=False):
if self.config('graph_show_values',False):
prec = self.config('graph_show_values',False)
if is_numeric(prec):
prec = get_numeric(prec)
elif type(prec) is list and is_numeric(prec[idx]):
prec = get_numeric(prec[idx])
else:
prec = 2
def get_show_values(self):
prec = self.config('graph_show_values',False)

if not prec:
return False
if is_numeric(prec):
prec = get_numeric(prec)
elif type(prec) is list and is_numeric(prec[idx]):
prec = get_numeric(prec[idx])
else:
prec = 2
return prec

def write_labels(self, prec, rects, plt, color, idx = 0, each=False):
if prec:
def autolabel(rects, ax):
for rect in rects:
if hasattr(rect, 'get_ydata'):
Expand Down Expand Up @@ -1616,7 +1622,7 @@ def do_simple_barplot(self,axis, result_type, data,shift=0,isubplot=0):
c = graphcolorseries[gcolor[isubplot % len(gcolor)]][0]
rects = plt.bar(ticks, y, label=x, color=c, width=width, yerr=( y - mean + std, mean - y + std))

self.write_labels(rects, plt,c)
self.write_labels(self.get_show_values(), rects, plt,c)

plt.xticks(ticks, x)
plt.gca().set_xlim(0, len(x))
Expand Down Expand Up @@ -1855,7 +1861,7 @@ def do_line_plot(self, axis, key, result_type, data : XYEB, data_types, shift,id
allmin = min(allmin, np.min(ax))
allmax = max(allmax, np.max(ax))

self.write_labels(rects, plt, build._color, idx, True)
self.write_labels(self.get_show_values(), rects, plt, build._color, idx, True)

if xmin == float('inf'):
return False, len(data)
Expand Down

0 comments on commit 1a84ddd

Please sign in to comment.