From 9caf1ea92779ba48f5b90cfe8fa4135eeaa3fa86 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Fri, 29 Jan 2021 16:10:56 +0100 Subject: [PATCH] Various cleanup (#13) Formatting, PEP8, ... --- petabvis/bar_plot.py | 43 ++++--- petabvis/bar_row.py | 1 + petabvis/main.py | 120 +++++++++-------- petabvis/plot_class.py | 65 ++++++---- petabvis/plot_row.py | 20 ++- petabvis/row_class.py | 33 +++-- petabvis/utils.py | 51 ++++++-- .../{visuSpec_plot.py => vis_spec_plot.py} | 121 +++++++++++------- petabvis/window_functionality.py | 119 ++++++++++------- 9 files changed, 359 insertions(+), 214 deletions(-) rename petabvis/{visuSpec_plot.py => vis_spec_plot.py} (82%) diff --git a/petabvis/bar_plot.py b/petabvis/bar_plot.py index 7a031e3..e235882 100644 --- a/petabvis/bar_plot.py +++ b/petabvis/bar_plot.py @@ -16,21 +16,20 @@ class BarPlot(plot_class.PlotClass): visualization_df: PEtab visualization table simulation_df: PEtab simulation table condition_df: PEtab condition table - plotId: Id of the plot (has to in the visualization_df aswell) + plot_id: Id of the plot (has to in the visualization_df as well) Attributes: bar_rows: A list of BarRows (one for each visualization df row) overview_df: A df containing the information of each bar - """ def __init__(self, measurement_df: pd.DataFrame = None, visualization_df: pd.DataFrame = None, simulation_df: pd.DataFrame = None, condition_df: pd.DataFrame = None, - plotId: str = ""): + plot_id: str = ""): super().__init__(measurement_df, visualization_df, simulation_df, - condition_df, plotId) + condition_df, plot_id) self.bar_width = 0.4 # bar_rows also contains simulation bars @@ -39,7 +38,9 @@ def __init__(self, measurement_df: pd.DataFrame = None, self.add_bar_rows(self.simulation_df) # A df containing the information needed to plot the bars - self.overview_df = pd.DataFrame(columns=["x", "y", "name", "sd", "sem", "provided_noise", "is_simulation", "tick_pos"]) + self.overview_df = pd.DataFrame( + columns=["x", "y", "name", "sd", "sem", "provided_noise", + "is_simulation", "tick_pos"]) self.plot_everything() @@ -81,7 +82,8 @@ def get_bars_df(self, bar_rows): df: A dataframe with information relevant for plotting a bar (x, y, sd, etc.) """ - bar_rows = [bar_row for bar_row in bar_rows if bar_row.dataset_id not in self.disabled_rows] + bar_rows = [bar_row for bar_row in bar_rows if + bar_row.dataset_id not in self.disabled_rows] x = range(len(bar_rows)) tick_pos = range(len(bar_rows)) @@ -92,8 +94,10 @@ def get_bars_df(self, bar_rows): noise = [bar.provided_noise for bar in bar_rows] is_simulation = [bar.is_simulation for bar in bar_rows] - df = pd.DataFrame(list(zip(x, y, names, sd, sem, noise, is_simulation, tick_pos)), - columns=["x", "y", "name", "sd", "sem", "provided_noise", "is_simulation", "tick_pos"]) + df = pd.DataFrame( + list(zip(x, y, names, sd, sem, noise, is_simulation, tick_pos)), + columns=["x", "y", "name", "sd", "sem", "provided_noise", + "is_simulation", "tick_pos"]) # Adjust x and tick_pos of the bars when simulation bars are plotted # such that they are next to each other @@ -108,7 +112,7 @@ def get_bars_df(self, bar_rows): df.loc[index, "tick_pos"] = i # separate measurement and simulation bars - bar_separation_shift = self.bar_width/2 + bar_separation_shift = self.bar_width / 2 df.loc[~df["is_simulation"], "x"] -= bar_separation_shift df.loc[df["is_simulation"], "x"] += bar_separation_shift @@ -129,10 +133,13 @@ def generate_plot(self): # Add bars simu_rows = self.overview_df["is_simulation"] bar_item = pg.BarGraphItem(x=self.overview_df[~simu_rows]["x"], - height=self.overview_df[~simu_rows]["y"], width=self.bar_width) + height=self.overview_df[~simu_rows][ + "y"], width=self.bar_width) self.plot.addItem(bar_item) # measurement bars - bar_item = pg.BarGraphItem(x=self.overview_df[simu_rows]["x"], brush="w", - height=self.overview_df[simu_rows]["y"], width=self.bar_width) + bar_item = pg.BarGraphItem(x=self.overview_df[simu_rows]["x"], + brush="w", + height=self.overview_df[simu_rows]["y"], + width=self.bar_width) self.plot.addItem(bar_item) # simulation bars # Add error bars @@ -141,20 +148,24 @@ def generate_plot(self): error_length = self.overview_df["sem"] if self.bar_rows[0].plot_type_data == ptc.PROVIDED: error_length = self.overview_df["provided_noise"] - error = pg.ErrorBarItem(x=self.overview_df["x"], y=self.overview_df["y"], - top=error_length, bottom=error_length, beam=0.1) + error = pg.ErrorBarItem(x=self.overview_df["x"], + y=self.overview_df["y"], + top=error_length, bottom=error_length, + beam=0.1) self.plot.addItem(error) # set tick names to the legend entry of the bars xax = self.plot.getAxis("bottom") - ticks = [list(zip(self.overview_df["tick_pos"], self.overview_df["name"]))] + ticks = [list( + zip(self.overview_df["tick_pos"], self.overview_df["name"]))] xax.setTicks(ticks) # set y-scale to log if necessary if "log" in self.bar_rows[0].y_scale: self.plot.setLogMode(y=True) if self.plot_rows[0].x_scale == "log": - self.add_warning("log not supported, using log10 instead (in " + self.plot_title + ")") + self.add_warning( + "log not supported, using log10 instead (in " + self.plot_title + ")") def add_or_remove_line(self, dataset_id): """ diff --git a/petabvis/bar_row.py b/petabvis/bar_row.py index 13b1d85..f075cc2 100644 --- a/petabvis/bar_row.py +++ b/petabvis/bar_row.py @@ -15,6 +15,7 @@ class BarRow(row_class.RowClass): sem: Standard error of the mean of the replicates provided noise: Noise of the measurements """ + def __init__(self, exp_data: pd.DataFrame, plot_spec: pd.Series, condition_df: pd.DataFrame, ): super().__init__(exp_data, plot_spec, condition_df) diff --git a/petabvis/main.py b/petabvis/main.py index 9b2b0b8..22caf24 100644 --- a/petabvis/main.py +++ b/petabvis/main.py @@ -1,19 +1,19 @@ import argparse import sys # We need sys so that we can pass argv to QApplication +import warnings import numpy as np import pandas as pd -import warnings import petab.C as ptc +import pyqtgraph as pg from PySide2 import QtWidgets, QtCore, QtGui from PySide2.QtWidgets import QVBoxLayout, QComboBox, QWidget, QLabel from petab import measurements, core -import pyqtgraph as pg from . import utils -from . import visuSpec_plot -from . import bar_plot +from . import vis_spec_plot from . import window_functionality +from .bar_plot import BarPlot class MainWindow(QtWidgets.QMainWindow): @@ -36,6 +36,7 @@ class MainWindow(QtWidgets.QMainWindow): current_list_index: List index of the currently displayed plot wid: QSplitter between main plot and correlation plot """ + def __init__(self, exp_data: pd.DataFrame, visualization_df: pd.DataFrame = None, simulation_df: pd.DataFrame = None, @@ -55,7 +56,7 @@ def __init__(self, exp_data: pd.DataFrame, self.condition_df = condition_df self.observable_df = observable_df self.exp_data = exp_data - self.visu_spec_plots = [] + self.vis_spec_plots = [] self.wid = QtWidgets.QSplitter() self.plot1_widget = pg.GraphicsLayoutWidget(show=True) self.plot2_widget = pg.GraphicsLayoutWidget(show=False) @@ -105,26 +106,30 @@ def add_plots(self): Returns: List of PlotItem """ - self.clear_QSplitter() - self.visu_spec_plots.clear() + self.clear_qsplitter() + self.vis_spec_plots.clear() if self.visualization_df is not None: # to keep the order of plots consistent with names from the plot selection - indexes = np.unique(self.visualization_df[ptc.PLOT_ID], return_index=True)[1] - plot_ids = [self.visualization_df[ptc.PLOT_ID][index] for index in sorted(indexes)] + indexes = \ + np.unique(self.visualization_df[ptc.PLOT_ID], return_index=True)[1] + plot_ids = [self.visualization_df[ptc.PLOT_ID][index] for index in + sorted(indexes)] for plot_id in plot_ids: - self.create_and_add_visuPlot(plot_id) + self.create_and_add_vis_plot(plot_id) else: # default plot when no visu_df is provided - self.create_and_add_visuPlot() + self.create_and_add_vis_plot() - plots = [visuPlot.getPlot() for visuPlot in self.visu_spec_plots] + plots = [vis_spec_plot.get_plot() for vis_spec_plot in + self.vis_spec_plots] # update the cbox self.cbox.clear() # calling this method sets the index of the cbox to 0 # and thus displays the first plot - utils.add_plotnames_to_cbox(self.exp_data, self.visualization_df, self.cbox) + utils.add_plotnames_to_cbox(self.exp_data, self.visualization_df, + self.cbox) return plots @@ -135,13 +140,15 @@ def index_changed(self, i: int): Arguments: i: index of the selected plot """ - if 0 <= i < len(self.visu_spec_plots): # i is -1 when the cbox is cleared - self.clear_QSplitter() - self.plot1_widget.addItem(self.visu_spec_plots[i].getPlot()) + if 0 <= i < len( + self.vis_spec_plots): # i is -1 when the cbox is cleared + self.clear_qsplitter() + self.plot1_widget.addItem(self.vis_spec_plots[i].get_plot()) self.plot2_widget.hide() if self.simulation_df is not None: self.plot2_widget.show() - self.plot2_widget.addItem(self.visu_spec_plots[i].correlation_plot) + self.plot2_widget.addItem( + self.vis_spec_plots[i].correlation_plot) self.current_list_index = i def keyPressEvent(self, ev): @@ -153,18 +160,18 @@ def keyPressEvent(self, ev): """ # Exit when pressing ctrl + Q ctrl = False - if (ev.modifiers() & QtCore.Qt.ControlModifier): + if ev.modifiers() & QtCore.Qt.ControlModifier: ctrl = True if ctrl and ev.key() == QtCore.Qt.Key_Q: sys.exit() - if(ev.key() == QtCore.Qt.Key_Up): + if ev.key() == QtCore.Qt.Key_Up: self.index_changed(self.current_list_index - 1) - if(ev.key() == QtCore.Qt.Key_Down): + if ev.key() == QtCore.Qt.Key_Down: self.index_changed(self.current_list_index + 1) - if(ev.key() == QtCore.Qt.Key_Left): + if ev.key() == QtCore.Qt.Key_Left: self.index_changed(self.current_list_index - 1) - if(ev.key() == QtCore.Qt.Key_Right): + if ev.key() == QtCore.Qt.Key_Right: self.index_changed(self.current_list_index + 1) def add_warning(self, message: str): @@ -177,7 +184,8 @@ def add_warning(self, message: str): if message not in self.warn_msg.text(): self.warn_msg.setText(self.warn_msg.text() + message + "\n") - def redirect_warning(self, message, category, filename=None, lineno=None, file=None, line=None): + def redirect_warning(self, message, category, filename=None, lineno=None, + file=None, line=None): """ Redirect all warning messages and display them in the window. @@ -187,11 +195,11 @@ def redirect_warning(self, message, category, filename=None, lineno=None, file=N print("Warning redirected: " + str(message)) self.add_warning(str(message)) - def create_and_add_visuPlot(self, plot_id=""): + def create_and_add_vis_plot(self, plot_id=""): """ - Create a visuSpec_plot object based on the given plot_id. + Create a vis_spec_plot object based on the given plot_id. If no plot_it is provided the default will be plotted. - Add all the warnings of the visuPlot object to the warning text box. + Add all the warnings of the vis_plot object to the warning text box. The actual plotting happens in the index_changed method @@ -201,41 +209,46 @@ def create_and_add_visuPlot(self, plot_id=""): # split the measurement df by observable when using default plots if self.visualization_df is None: # to keep the order of plots consistent with names from the plot selection - indexes = np.unique(self.exp_data[ptc.OBSERVABLE_ID], return_index=True)[1] - observable_ids = [self.exp_data[ptc.OBSERVABLE_ID][index] for index in sorted(indexes)] + indexes = \ + np.unique(self.exp_data[ptc.OBSERVABLE_ID], return_index=True)[1] + observable_ids = [self.exp_data[ptc.OBSERVABLE_ID][index] for index + in sorted(indexes)] for observable_id in observable_ids: rows = self.exp_data[ptc.OBSERVABLE_ID] == observable_id data = self.exp_data[rows] - visuPlot = visuSpec_plot.VisuSpecPlot(measurement_df=data, visualization_df=None, - condition_df=self.condition_df, - simulation_df=self.simulation_df, plotId=plot_id) - self.visu_spec_plots.append(visuPlot) - if visuPlot.warnings: - self.add_warning(visuPlot.warnings) + vis_plot = vis_spec_plot.VisSpecPlot( + measurement_df=data, visualization_df=None, + condition_df=self.condition_df, + simulation_df=self.simulation_df, plot_id=plot_id) + self.vis_spec_plots.append(vis_plot) + if vis_plot.warnings: + self.add_warning(vis_plot.warnings) else: # reduce the visualization df to the relevant rows (by plotId) rows = self.visualization_df[ptc.PLOT_ID] == plot_id - visu_df = self.visualization_df[rows] - if ptc.PLOT_TYPE_SIMULATION in visu_df.columns and\ - visu_df.iloc[0][ptc.PLOT_TYPE_SIMULATION] == ptc.BAR_PLOT: - barPlot = bar_plot.BarPlot(measurement_df=self.exp_data, - visualization_df=visu_df, + vis_df = self.visualization_df[rows] + if ptc.PLOT_TYPE_SIMULATION in vis_df.columns and \ + vis_df.iloc[0][ptc.PLOT_TYPE_SIMULATION] == ptc.BAR_PLOT: + bar_plot = BarPlot(measurement_df=self.exp_data, + visualization_df=vis_df, condition_df=self.condition_df, - simulation_df=self.simulation_df, plotId=plot_id) + simulation_df=self.simulation_df, + plot_id=plot_id) # might want to change the name of visu_spec_plots to clarify that # it can also include bar plots (maybe to plots?) - self.visu_spec_plots.append(barPlot) + self.vis_spec_plots.append(bar_plot) else: - visuPlot = visuSpec_plot.VisuSpecPlot(measurement_df=self.exp_data, - visualization_df=visu_df, - condition_df=self.condition_df, - simulation_df=self.simulation_df, plotId=plot_id) - self.visu_spec_plots.append(visuPlot) - if visuPlot.warnings: - self.add_warning(visuPlot.warnings) - - def clear_QSplitter(self): + vis_plot = vis_spec_plot.VisSpecPlot( + measurement_df=self.exp_data, + visualization_df=vis_df, + condition_df=self.condition_df, + simulation_df=self.simulation_df, plot_id=plot_id) + self.vis_spec_plots.append(vis_plot) + if vis_plot.warnings: + self.add_warning(vis_plot.warnings) + + def clear_qsplitter(self): """ Clear the GraphicsLayoutWidgets for the measurement and correlation plot @@ -258,11 +271,12 @@ def main(): visualization_df = None if args.visualization is not None: - visualization_df = core.concat_tables(args.visualization, core.get_visualization_df) + visualization_df = core.concat_tables(args.visualization, + core.get_visualization_df) app = QtWidgets.QApplication(sys.argv) - main = MainWindow(exp_data, visualization_df) - main.show() + main_window = MainWindow(exp_data, visualization_df) + main_window.show() sys.exit(app.exec_()) diff --git a/petabvis/plot_class.py b/petabvis/plot_class.py index a42e012..33d41ba 100644 --- a/petabvis/plot_class.py +++ b/petabvis/plot_class.py @@ -13,14 +13,14 @@ class PlotClass: visualization_df: PEtab visualization table simulation_df: PEtab simulation table condition_df: PEtab condition table - plotId: Id of the plot (has to in the visualization_df aswell) + plot_id: Id of the plot (has to in the visualization_df aswell) Attributes: measurement_df: PEtab measurement table visualization_df: PEtab visualization table simulation_df: PEtab simulation table condition_df: PEtab condition table - plotId: Id of the plot (has to in the visualization_df aswell) + plot_id: Id of the plot (has to in the visualization_df aswell) error_bars: A list of pg.ErrorBarItems warnings: String of warning messages if the input is incorrect or not supported @@ -35,24 +35,25 @@ def __init__(self, measurement_df: pd.DataFrame = None, visualization_df: pd.DataFrame = None, simulation_df: pd.DataFrame = None, condition_df: pd.DataFrame = None, - plotId: str = ""): + plot_id: str = ""): self.measurement_df = measurement_df self.visualization_df = visualization_df self.simulation_df = simulation_df self.condition_df = condition_df - self.plotId = plotId + self.plot_id = plot_id self.error_bars = [] self.disabled_rows = set() # set of plot_ids that are disabled self.warnings = "" - self.has_replicates = petab.measurements.measurements_have_replicates(self.measurement_df) + self.has_replicates = petab.measurements.measurements_have_replicates( + self.measurement_df) self.plot_title = utils.get_plot_title(self.visualization_df) self.plot = pg.PlotItem(title=self.plot_title) self.correlation_plot = pg.PlotItem(title="Correlation") def generate_correlation_plot(self, overview_df): """ - Generate the scatterplot between the + Generate the scatter plot between the measurement and simulation values. Arguments: @@ -62,8 +63,10 @@ def generate_correlation_plot(self, overview_df): self.correlation_plot.clear() if not overview_df.empty: - measurements = overview_df[~overview_df["is_simulation"]]["y"].tolist() - simulations = overview_df[overview_df["is_simulation"]]["y"].tolist() + measurements = overview_df[~overview_df["is_simulation"]][ + "y"].tolist() + simulations = overview_df[overview_df["is_simulation"]][ + "y"].tolist() self.add_points(overview_df) self.correlation_plot.setLabel("left", "Simulation") @@ -71,19 +74,22 @@ def generate_correlation_plot(self, overview_df): min_value = min(measurements + simulations) max_value = max(measurements + simulations) - self.correlation_plot.setRange(xRange=(min_value, max_value), yRange=(min_value, max_value)) + self.correlation_plot.setRange(xRange=(min_value, max_value), + yRange=(min_value, max_value)) self.correlation_plot.addItem(pg.InfiniteLine([0, 0], angle=45)) # calculate and add the r_squared value - self.r_squared = self.get_R_squared(measurements, simulations) + self.r_squared = self.get_r_squared(measurements, simulations) r_squared_text = "R Squared:\n" + str(self.r_squared)[0:5] - r_squared_text = pg.TextItem(str(r_squared_text), anchor=(0, 0), color="k") + r_squared_text = pg.TextItem(str(r_squared_text), anchor=(0, 0), + color="k") r_squared_text.setPos(min_value, max_value) - self.correlation_plot.addItem(r_squared_text, anchor=(0, 0), color="k") + self.correlation_plot.addItem(r_squared_text, anchor=(0, 0), + color="k") def add_points(self, overview_df: pd.DataFrame): """ - Add the points to the scatterplot and + Add the points to the scatter plot and display an info text when clicking on a point. Arguments: @@ -93,18 +99,23 @@ def add_points(self, overview_df: pd.DataFrame): measurements = overview_df[~overview_df["is_simulation"]]["y"].tolist() simulations = overview_df[overview_df["is_simulation"]]["y"].tolist() names = overview_df[~overview_df["is_simulation"]]["name"].tolist() - point_descriptions = [(names[i] + "\nmeasurement: " + str(measurements[i]) + - "\nsimulation: " + str(simulations[i])) - for i in range(len(measurements))] + point_descriptions = [ + (names[i] + "\nmeasurement: " + str(measurements[i]) + + "\nsimulation: " + str(simulations[i])) + for i in range(len(measurements))] # only line plots have x-values, barplots do not if "x_label" in overview_df.columns: x = overview_df[~overview_df["is_simulation"]]["x"].tolist() - x_label = overview_df[~overview_df["is_simulation"]]["x_label"].tolist() - point_descriptions = [(point_descriptions[i] + "\n" + str(x_label[i])) + ": " + - str(x[i]) for i in range(len(point_descriptions))] + x_label = overview_df[~overview_df["is_simulation"]][ + "x_label"].tolist() + point_descriptions = [ + (point_descriptions[i] + "\n" + str(x_label[i])) + ": " + + str(x[i]) for i in range(len(point_descriptions))] # create the scatterplot - scatter_plot = pg.ScatterPlotItem(pen=pg.mkPen(None), brush=pg.mkBrush(0, 0, 0)) - spots = [{'pos': [m, s], 'data': idx} for m, s, idx in zip(measurements, simulations, point_descriptions)] + scatter_plot = pg.ScatterPlotItem(pen=pg.mkPen(None), + brush=pg.mkBrush(0, 0, 0)) + spots = [{'pos': [m, s], 'data': idx} for m, s, idx in + zip(measurements, simulations, point_descriptions)] scatter_plot.addPoints(spots) self.correlation_plot.addItem(scatter_plot) @@ -119,7 +130,8 @@ def clicked(plot, points): if last_clicked is not None: last_clicked.resetPen() # remove the text when the same point is clicked twice - if last_clicked == points[0] and info_text.textItem.toPlainText() != "": + if (last_clicked == points[0] + and info_text.textItem.toPlainText() != ""): info_text.setText("") else: points[0].setPen('b', width=2) @@ -129,7 +141,7 @@ def clicked(plot, points): scatter_plot.sigClicked.connect(clicked) - def get_R_squared(self, measurements, simulations): + def get_r_squared(self, measurements, simulations): """ Calculate the R^2 value between the measurement and simulation values. @@ -140,13 +152,14 @@ def get_R_squared(self, measurements, simulations): Returns: The R^2 value """ - slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(measurements, simulations) + slope, intercept, r_value, p_value, std_err = scipy.stats.linregress( + measurements, simulations) print("Linear Regression Statistics for " + self.plot_title + ":") print("Slope: " + str(slope) + ", Intercept: " + str(intercept) + ", R-value: " + str(r_value) + ", p-value: " + str(p_value) + ", Std Err: " + str(std_err)) - return r_value**2 + return r_value ** 2 def add_warning(self, message: str): """ @@ -159,5 +172,5 @@ def add_warning(self, message: str): if message not in self.warnings: self.warnings = self.warnings + message + "\n" - def getPlot(self): + def get_plot(self): return self.plot diff --git a/petabvis/plot_row.py b/petabvis/plot_row.py index 5ccbfc8..4eaaa8e 100644 --- a/petabvis/plot_row.py +++ b/petabvis/plot_row.py @@ -2,8 +2,8 @@ import pandas as pd import petab.C as ptc -from . import utils from . import row_class +from . import utils class PlotRow(row_class.RowClass): @@ -18,6 +18,7 @@ class PlotRow(row_class.RowClass): sem: Standard error of the mean of the replicates provided noise: Noise of the measurements """ + def __init__(self, exp_data: pd.DataFrame, plot_spec: pd.Series, condition_df: pd.DataFrame, ): @@ -26,8 +27,10 @@ def __init__(self, exp_data: pd.DataFrame, # calculate new attributes self.x_data = self.get_x_data() self.y_data = self.get_y_data() - self.sd = utils.sd_replicates(self.line_data, self.x_var, self.is_simulation) - self.sem = utils.sem_replicates(self.line_data, self.x_var, self.is_simulation) + self.sd = utils.sd_replicates(self.line_data, self.x_var, + self.is_simulation) + self.sem = utils.sem_replicates(self.line_data, self.x_var, + self.is_simulation) self.provided_noise = self.get_provided_noise() def get_x_data(self): @@ -43,7 +46,7 @@ def get_x_data(self): else: # for time plots if self.has_replicates and self.plot_type_data != ptc.REPLICATE: # to keep the order intact (only needed if no replicate id col is provided) - x_data = np.asarray([x_values for x_values, df in self.replicates[0].groupby(self.x_var, sort=True)]) + x_data = np.array(sorted(set(self.replicates[0][self.x_var]))) else: x_data = np.asarray(self.replicates[0][self.x_var]) x_data = x_data + self.x_offset @@ -53,7 +56,7 @@ def get_x_data(self): def get_y_data(self): """ Return the mean of the y-values that should be plotted if - the plottype is not ptc.REPLICATE. + the plot type is not ptc.REPLICATE. Otherwise, return the y-values of the first replicate. @@ -63,7 +66,8 @@ def get_y_data(self): variable = self.get_y_variable_name() # either measurement or simulation y_data = np.asarray(self.replicates[0][variable]) if self.plot_type_data != ptc.REPLICATE: - y_data = utils.mean_replicates(self.line_data, self.x_var, variable) + y_data = utils.mean_replicates(self.line_data, self.x_var, + variable) y_data = y_data + self.y_offset return y_data @@ -71,11 +75,13 @@ def get_y_data(self): def get_replicate_x_data(self): """ Return the x-values of each replicate as a list of lists + Returns: x_data_replicates: The y-values for each replicate """ x_data = [] - default_x_values = [df for _, df in self.replicates[0].groupby(self.x_var, sort=True)] + default_x_values = [df for _, df in + self.replicates[0].groupby(self.x_var, sort=True)] for replicate in self.replicates: if ptc.REPLICATE_ID in self.line_data.columns: x_values = np.asarray(replicate[self.x_var]) diff --git a/petabvis/row_class.py b/petabvis/row_class.py index 63346f0..96dc876 100644 --- a/petabvis/row_class.py +++ b/petabvis/row_class.py @@ -1,7 +1,7 @@ import numpy as np import pandas as pd -import petab.C as ptc import petab +import petab.C as ptc from . import utils @@ -15,6 +15,7 @@ class RowClass: exp_data: PEtab measurement table plot_spec: A single row of a PEtab visualization table condition_df: PEtab condition table + Attributes: line_data: PEtab measurement or simulation table reduced to relevant rows plot_spec: A single row of a PEtab visualization table @@ -33,15 +34,14 @@ class RowClass: plot_type_data: The type how replicates should be handled, can be MeanAndSD, MeanAndSEM, replicate or provided is_simulation: Boolean, True if exp_data is a simulation df - has_replicates: Booelean, True if replicates are in line_data + has_replicates: Boolean, True if replicates are in line_data replicates: List of line_data subsets, divided by replicateId - """ def __init__(self, exp_data: pd.DataFrame, plot_spec: pd.Series, condition_df: pd.DataFrame, ): - self.x_data = [] # placeholder value, will be overwritten by plot_row - self.y_data = [] # placeholder value, will be overwritten by plot_row/bar_row + self.x_data = [] # placeholder value, will be overwritten by plot_row + self.y_data = [] # placeholder value, will be overwritten by plot_row/bar_row # set attributes self.plot_spec = plot_spec @@ -64,15 +64,19 @@ def __init__(self, exp_data: pd.DataFrame, # reduce dfs to relevant rows self.line_data = exp_data if self.dataset_id and ptc.DATASET_ID in self.line_data: # != "" - self.line_data = self.line_data[self.line_data[ptc.DATASET_ID] == self.dataset_id] + self.line_data = self.line_data[ + self.line_data[ptc.DATASET_ID] == self.dataset_id] if self.y_var: # != "" # filter by y-values if specified - self.line_data = self.line_data[self.line_data[ptc.OBSERVABLE_ID] == self.y_var] + self.line_data = self.line_data[ + self.line_data[ptc.OBSERVABLE_ID] == self.y_var] if self.condition_df is not None and self.x_var != ptc.TIME: # reduce the condition df to the relevant rows (by condition id) - self.condition_df = utils.reduce_condition_df(self.line_data, self.condition_df) + self.condition_df = utils.reduce_condition_df(self.line_data, + self.condition_df) - self.has_replicates = petab.measurements.measurements_have_replicates(self.line_data) + self.has_replicates = petab.measurements.measurements_have_replicates( + self.line_data) self.replicates = utils.split_replicates(self.line_data) def get_data_df(self): @@ -86,12 +90,15 @@ def get_data_df(self): df: The dataframe containing the row information. """ if len(self.x_data) == len(self.y_data): - df = pd.DataFrame({"x": self.x_data, "y": self.y_data, "name": self.legend_name, - "is_simulation": self.is_simulation, "dataset_id": self.dataset_id, - "x_label": self.x_label}) + df = pd.DataFrame( + {"x": self.x_data, "y": self.y_data, "name": self.legend_name, + "is_simulation": self.is_simulation, + "dataset_id": self.dataset_id, + "x_label": self.x_label}) return df else: - raise Exception("Error: The number of x- and y-values are different") + raise Exception( + "Error: The number of x- and y-values are different") def get_provided_noise(self): """ diff --git a/petabvis/utils.py b/petabvis/utils.py index 041b4e0..f25029c 100644 --- a/petabvis/utils.py +++ b/petabvis/utils.py @@ -1,7 +1,8 @@ -import petab.C as ptc +import warnings + import numpy as np import pandas as pd -import warnings +import petab.C as ptc from PySide2 import QtCore from PySide2.QtWidgets import QComboBox @@ -26,6 +27,7 @@ def get_legend_name(plot_spec: pd.Series): def get_x_var(plot_spec: pd.Series): """ Returns the name of the x variable of the plot specification + Arguments: plot_spec: A single row of a visualization df Returns: @@ -41,6 +43,7 @@ def get_x_var(plot_spec: pd.Series): def get_y_var(plot_spec: pd.Series): """ Returns the observable which should be plotted on the y-axis + Arguments: plot_spec: A single row of a visualization df Returns: @@ -56,8 +59,10 @@ def get_y_var(plot_spec: pd.Series): def get_x_offset(plot_spec: pd.Series): """ Returns the x offset + Arguments: plot_spec: A single row of a visualization df + Returns: The x offset """ @@ -71,8 +76,10 @@ def get_x_offset(plot_spec: pd.Series): def get_x_scale(plot_spec: pd.Series): """ Returns the scale of the x axis (lin, log or order) + Arguments: plot_spec: A single row of a visualization df + Returns: The x scale """ @@ -87,8 +94,10 @@ def get_x_scale(plot_spec: pd.Series): def get_y_scale(plot_spec: pd.Series): """ Returns the scale of the y axis (lin, log or order) + Arguments: plot_spec: A single row of a visualization df + Returns: The x offset """ @@ -102,8 +111,10 @@ def get_y_scale(plot_spec: pd.Series): def get_y_offset(plot_spec: pd.Series): """ Returns the y offset + Arguments: plot_spec: A single row of a visualization df + Returns: The y offset """ @@ -117,8 +128,10 @@ def get_y_offset(plot_spec: pd.Series): def get_x_label(plot_spec: pd.Series): """ Returns the label of the x axis + Arguments: plot_spec: A single row of a visualization df + Returns: The label of the x axis """ @@ -132,8 +145,10 @@ def get_x_label(plot_spec: pd.Series): def get_y_label(plot_spec: pd.Series): """ Returns the label of the y axis + Arguments: plot_spec: A single row of a visualization df + Returns: The label of the y axis """ @@ -147,8 +162,10 @@ def get_y_label(plot_spec: pd.Series): def get_dataset_id(plot_spec: pd.Series): """ Returns the dataset id + Arguments: plot_spec: A single row of a visualization df + Returns: The dataset id """ @@ -182,6 +199,7 @@ def reduce_condition_df(line_data, condition_df): Arguments: line_data: A subset of a measurement df condition_df: The condition df + Returns: The reduced condition df """ @@ -197,11 +215,14 @@ def reduce_condition_df(line_data, condition_df): condition_df = condition_df[ind_cond] return condition_df + def get_plot_title(visualization_df_rows: pd.DataFrame): """ Returns the title of the plot + Arguments: - plot_spec: A single row of a visualization df + visualization_df_rows: A single row of a visualization df + Returns: The plot title """ @@ -219,6 +240,7 @@ def mean_replicates(line_data: pd.DataFrame, x_var: str = ptc.TIME, y_var: str = ptc.MEASUREMENT): """ Calculate the mean of the replicates. + Note: The line_data already has to be reduced to the relevant simulationConditionIds for concentration plots @@ -226,6 +248,7 @@ def mean_replicates(line_data: pd.DataFrame, x_var: str = ptc.TIME, line_data: A subset of the measurement file x_var: Name of the x-variable y_var: Name of the y-variable (measurement or simulation) + Returns: The mean grouped by x_var """ @@ -250,6 +273,7 @@ def sd_replicates(line_data: pd.DataFrame, x_var: str, is_simulation: bool): x_var: Name of the x-variable is_simulation: Boolean to check if the y variable is measurement or simulation + Returns: The std grouped by x_var """ @@ -280,6 +304,7 @@ def sem_replicates(line_data: pd.DataFrame, x_var: str, is_simulation: bool): x_var: Name of the x-variable is_simulation: Boolean to check if the y variable is measurement or simulation + Returns: The std grouped by x_var """ @@ -290,7 +315,8 @@ def sem_replicates(line_data: pd.DataFrame, x_var: str, is_simulation: bool): grouping = ptc.SIMULATION_CONDITION_ID sd = sd_replicates(line_data, x_var, is_simulation) - n_replicates = [len(replicates) for replicates in line_data.groupby(grouping)] + n_replicates = [len(replicates) for replicates in + line_data.groupby(grouping)] sem = sd / np.sqrt(n_replicates) return sem @@ -299,11 +325,13 @@ def split_replicates(line_data: pd.DataFrame): """ Split the line_data df into replicate dfs based on their replicate Id. + If no replicateId column is in the line_data, line_data will be returned. Arguments: line_data: A subset of the measurement file + Returns: The std grouped by x_var """ @@ -317,7 +345,8 @@ def split_replicates(line_data: pd.DataFrame): return replicates -def add_plotnames_to_cbox(exp_data: pd.DataFrame, visualization_df: pd.DataFrame, cbox: QComboBox): +def add_plotnames_to_cbox(exp_data: pd.DataFrame, + visualization_df: pd.DataFrame, cbox: QComboBox): """ Add the name of every plot in the visualization df to the cbox @@ -332,10 +361,13 @@ def add_plotnames_to_cbox(exp_data: pd.DataFrame, visualization_df: pd.DataFrame # to keep the order of plotnames consistent with the plots that are shown # for every identical plot_id, the plot_name has to be the same - indexes = np.unique(visualization_df[ptc.PLOT_ID], return_index=True)[1] - plot_names = [visualization_df[ptc.PLOT_NAME][index] for index in sorted(indexes)] + indexes = \ + np.unique(visualization_df[ptc.PLOT_ID], return_index=True)[1] + plot_names = [visualization_df[ptc.PLOT_NAME][index] for index in + sorted(indexes)] if len(plot_ids) != len(plot_names): - warnings.warn("The number of plot ids should be the same as the number of plot names") + warnings.warn( + "The number of plot ids should be the same as the number of plot names") for name in plot_names: cbox.addItem(name) @@ -346,7 +378,8 @@ def add_plotnames_to_cbox(exp_data: pd.DataFrame, visualization_df: pd.DataFrame # the default plots are grouped by observable ID # to keep the order of plots consistent with names from the plot selection indexes = np.unique(exp_data[ptc.OBSERVABLE_ID], return_index=True)[1] - observable_ids = [exp_data[ptc.OBSERVABLE_ID][index] for index in sorted(indexes)] + observable_ids = [exp_data[ptc.OBSERVABLE_ID][index] for index in + sorted(indexes)] for observable_id in observable_ids: cbox.addItem(observable_id) diff --git a/petabvis/visuSpec_plot.py b/petabvis/vis_spec_plot.py similarity index 82% rename from petabvis/visuSpec_plot.py rename to petabvis/vis_spec_plot.py index c91353c..3f89923 100644 --- a/petabvis/visuSpec_plot.py +++ b/petabvis/vis_spec_plot.py @@ -1,15 +1,15 @@ import numpy as np import pandas as pd import petab.C as ptc -from PySide2 import QtCore import pyqtgraph as pg +from PySide2 import QtCore +from . import plot_class from . import plot_row from . import utils -from . import plot_class -class VisuSpecPlot(plot_class.PlotClass): +class VisSpecPlot(plot_class.PlotClass): """ Can generate a line plot based on the given specifications @@ -18,7 +18,7 @@ class VisuSpecPlot(plot_class.PlotClass): visualization_df: PEtab visualization table simulation_df: PEtab simulation table condition_df: PEtab condition table - plotId: Id of the plot (has to in the visualization_df aswell) + plot_id: Id of the plot (has to in the visualization_df aswell) Attributes: scatter_points: A dictionary containing 2 lists for @@ -30,13 +30,14 @@ class VisuSpecPlot(plot_class.PlotClass): exp_lines: A list of PlotDataItems simu_lines: A list of PlotDataItems for simulation data """ + def __init__(self, measurement_df: pd.DataFrame = None, visualization_df: pd.DataFrame = None, simulation_df: pd.DataFrame = None, condition_df: pd.DataFrame = None, - plotId: str = ""): + plot_id: str = ""): super().__init__(measurement_df, visualization_df, simulation_df, - condition_df, plotId) + condition_df, plot_id) # reduce the visualization_df to the relevant rows (by plotId) if self.visualization_df is not None: @@ -53,7 +54,8 @@ def __init__(self, measurement_df: pd.DataFrame = None, self.plot_rows = [] # list of plot_rows self.plot_rows_simulation = [] - self.overview_df = pd.DataFrame(columns=["x", "y", "name", "is_simulation", "dataset_id", "x_var"]) + self.overview_df = pd.DataFrame( + columns=["x", "y", "name", "is_simulation", "dataset_id", "x_var"]) self.exp_lines = [] # list of PlotDataItems (measurements) self.simu_lines = [] # (simulations) @@ -69,17 +71,20 @@ def plot_everything(self): """ self.plot.clear() self.error_bars = [] - self.plot_rows = self.generate_plot_rows(self.measurement_df) # list of plot_rows + self.plot_rows = self.generate_plot_rows( + self.measurement_df) # list of plot_rows self.plot_rows_simulation = self.generate_plot_rows(self.simulation_df) self.overview_df = self.generate_overview_df() self.exp_lines = self.generate_plot_data_items(self.plot_rows, is_simulation=False) # list of PlotDataItems (measurements) - self.simu_lines = self.generate_plot_data_items(self.plot_rows_simulation, is_simulation=True) # (simulations) + self.simu_lines = self.generate_plot_data_items( + self.plot_rows_simulation, is_simulation=True) # (simulations) # make sure the is_simulation column is really boolean because otherwise # the logical not operator ~ causes problems - self.overview_df["is_simulation"] = self.overview_df["is_simulation"].astype("bool") + self.overview_df["is_simulation"] = self.overview_df[ + "is_simulation"].astype("bool") self.generate_plot() if self.simulation_df is not None: @@ -96,9 +101,12 @@ def generate_overview_df(self): Returns: overview_df: A dataframe containing an overview of the plotRows """ - overview_df = pd.DataFrame(columns=["x", "y", "name", "is_simulation", "dataset_id", "x_label"]) + overview_df = pd.DataFrame( + columns=["x", "y", "name", "is_simulation", "dataset_id", + "x_label"]) if self.visualization_df is not None: - dfs = [p_row.get_data_df() for p_row in (self.plot_rows + self.plot_rows_simulation) + dfs = [p_row.get_data_df() for p_row in + (self.plot_rows + self.plot_rows_simulation) if p_row.dataset_id not in self.disabled_rows] if dfs: overview_df = pd.concat(dfs, ignore_index=True) @@ -115,7 +123,8 @@ def generate_plot_rows(self, df): if self.visualization_df is not None: for _, plot_spec in self.visualization_df.iterrows(): if df is not None: - plot_line = plot_row.PlotRow(df, plot_spec, self.condition_df) + plot_line = plot_row.PlotRow(df, plot_spec, + self.condition_df) plot_rows.append(plot_line) return plot_rows @@ -127,13 +136,15 @@ def generate_plot_data_items(self, plot_rows, is_simulation: bool = False): Arguments: plot_rows: A list of PlotRow objects is_simulation: True plot_rows belong to a simulation df + Returns: pdis: A list of PlotDataItems """ pdis = [] # list of PlotDataItems for line in plot_rows: if line.dataset_id == "": - plot_lines = self.default_plot(line, is_simulation=is_simulation) + plot_lines = self.default_plot(line, + is_simulation=is_simulation) pdis = pdis + plot_lines else: if line.dataset_id not in self.disabled_rows: @@ -169,13 +180,16 @@ def generate_plot(self): line.setPen(color, style=QtCore.Qt.DashDotLine, width=2) self.plot.addItem(line) if len(self.simu_lines) > 0: - self.simu_lines[i].setPen(color, style=QtCore.Qt.SolidLine, width=2) + self.simu_lines[i].setPen(color, style=QtCore.Qt.SolidLine, + width=2) self.plot.addItem(self.simu_lines[i]) self.add_measurements_points() # Errorbars do not support log scales - if self.plot_rows and ("log" in self.plot_rows[0].x_scale or "log" in self.plot_rows[0].y_scale): + if self.plot_rows and ( + "log" in self.plot_rows[0].x_scale or "log" in self.plot_rows[ + 0].y_scale): if len(self.error_bars) > 0: self.warnings = self.warnings + "Errorbars are not supported with log scales (in " \ + self.plot_title + ")\n" @@ -199,13 +213,15 @@ def add_measurements_points(self): measurements = df[~df["is_simulation"]]["y"].tolist() points = self.plot.plot(x, measurements, pen=None, symbol='o', - symbolBrush=pg.mkBrush(0, 0, 0), symbolSize=6) + symbolBrush=pg.mkBrush(0, 0, 0), + symbolSize=6) self.datasetId_to_points[id] = points x_simulation = df[df["is_simulation"]]["x"].tolist() simulations = df[df["is_simulation"]]["y"].tolist() points = self.plot.plot(x_simulation, simulations, pen=None, symbol='o', - symbolBrush=pg.mkBrush(255, 255, 255), symbolSize=6) + symbolBrush=pg.mkBrush(255, 255, 255), + symbolSize=6) self.datasetId_to_points[id + "_simulation"] = points def plot_row_to_plot_data_item(self, p_row: plot_row.PlotRow): @@ -223,12 +239,13 @@ def plot_row_to_plot_data_item(self, p_row: plot_row.PlotRow): pdi = pg.PlotDataItem(p_row.x_data, p_row.y_data, name=legend_name) # add it to the dict (used for disabling rows by dataset_id) if p_row.is_simulation: - self.datasetId_to_plotDataItem[p_row.dataset_id + "_simulation"] = pdi + self.datasetId_to_plotDataItem[ + p_row.dataset_id + "_simulation"] = pdi else: self.datasetId_to_plotDataItem[p_row.dataset_id] = pdi # Only add error bars when needed - if (p_row.has_replicates or p_row.plot_type_data == ptc.PROVIDED)\ + if (p_row.has_replicates or p_row.plot_type_data == ptc.PROVIDED) \ and p_row.plot_type_data != ptc.REPLICATE: error_length = p_row.sd if p_row.plot_type_data == ptc.MEAN_AND_SEM: @@ -238,23 +255,26 @@ def plot_row_to_plot_data_item(self, p_row: plot_row.PlotRow): beam_width = 0 if len(p_row.x_data) > 0: # p_row.x_data could be empty beam_width = np.max(p_row.x_data) / 100 - error = pg.ErrorBarItem(x=p_row.x_data, y=p_row.y_data, top=error_length, bottom=error_length, beam=beam_width) + error = pg.ErrorBarItem(x=p_row.x_data, y=p_row.y_data, + top=error_length, bottom=error_length, + beam=beam_width) self.error_bars.append(error) # add it to the dict (used for disabling rows by dataset_id) if p_row.is_simulation: - self.datasetId_to_errorbar[p_row.dataset_id + "_simulation"] = error + self.datasetId_to_errorbar[ + p_row.dataset_id + "_simulation"] = error else: self.datasetId_to_errorbar[p_row.dataset_id] = error - return(pdi) + return pdi def default_plot(self, p_row: plot_row.PlotRow, is_simulation=False): """ This method is used when the p_row contains no dataset_id - or no visualization file was provided - Therefore, the whole dataset will be visualized - in a single plot - The plotDataItems created here will be added to self.exp_lines + or no visualization file was provided. + + Therefore, the whole dataset will be visualized in a single plot. + The plotDataItems created here will be added to self.exp_lines. Arguments: p_row: The PlotRow object that contains the information @@ -270,8 +290,9 @@ def default_plot(self, p_row: plot_row.PlotRow, is_simulation=False): if ptc.DATASET_ID in self.measurement_df.columns: grouping = ptc.DATASET_ID else: - self.add_warning("Grouped by observable. If you want to specify another grouping option" - ", please add \"datasetID\" columns.") + self.add_warning( + "Grouped by observable. If you want to specify another " + "grouping option, please add \"datasetID\" columns.") df = self.measurement_df y_var = ptc.MEASUREMENT if is_simulation: @@ -296,10 +317,11 @@ def default_plot(self, p_row: plot_row.PlotRow, is_simulation=False): # add points if is_simulation: line_name = line_name + " simulation" - line_df = pd.DataFrame({"x": x_data.tolist(), "y": y_data.tolist(), "name": group_id, "is_simulation": True}) - else: - line_df = pd.DataFrame({"x": x_data.tolist(), "y": y_data.tolist(), "name": group_id, "is_simulation": False}) - self.overview_df = self.overview_df.append(line_df, ignore_index=True) + line_df = pd.DataFrame( + {"x": x_data.tolist(), "y": y_data.tolist(), + "name": group_id, "is_simulation": is_simulation}) + self.overview_df = self.overview_df.append(line_df, + ignore_index=True) plot_lines.append(pg.PlotDataItem(x_data, y_data, name=line_name)) @@ -314,17 +336,20 @@ def set_scales(self): if "log" in self.plot_rows[0].x_scale: self.plot.setLogMode(x=True) if self.plot_rows[0].x_scale == "log": - self.add_warning("log not supported, using log10 instead (in " + self.plot_title + ")") + self.add_warning( + "log not supported, using log10 instead (in " + self.plot_title + ")") if "log" in self.plot_rows[0].y_scale: self.plot.setLogMode(y=True) if self.plot_rows[0].y_scale == "log": - self.add_warning("log not supported, using log10 instead (in " + self.plot_title + ")") + self.add_warning( + "log not supported, using log10 instead (in " + self.plot_title + ")") def check_log_for_zeros(self): """ Add an offset to values if they contain a zero and will be plotted on log-scale. - The offset is calculated as the smalles nonzero value times 0.001 + + The offset is calculated as the smallest nonzero value times 0.001 (Also adds the offset to the simulation values). """ x_var = utils.get_x_var(self.visualization_df.iloc[0]) @@ -342,29 +367,37 @@ def check_log_for_zeros(self): y_values = np.asarray(self.measurement_df[y_var]) if ptc.X_SCALE in self.visualization_df.columns: - if 0 in x_values and "log" in self.visualization_df.iloc[0][ptc.X_SCALE]: + if 0 in x_values and "log" in self.visualization_df.iloc[0][ + ptc.X_SCALE]: offset = np.min(x_values[np.nonzero(x_values)]) * 0.001 if x_var == ptc.TIME: x_values = x_values + offset self.measurement_df[x_var] = x_values else: for variable in self.visualization_df[ptc.X_VALUES]: - self.condition_df[variable] = np.asarray(self.condition_df[variable]) + offset - self.add_warning("Unable to take log of 0, added offset of " + str(offset) + " to x-values") + self.condition_df[variable] = np.asarray( + self.condition_df[variable]) + offset + self.add_warning( + "Unable to take log of 0, added offset of " + str( + offset) + " to x-values") if self.simulation_df is not None: x_simulation = np.asarray(self.simulation_df[x_var]) self.simulation_df[x_var] = x_simulation + offset if ptc.Y_SCALE in self.visualization_df.columns: - if 0 in y_values and "log" in self.visualization_df.iloc[0][ptc.Y_SCALE]: + if 0 in y_values and "log" in self.visualization_df.iloc[0][ + ptc.Y_SCALE]: offset = np.min(y_values[np.nonzero(y_values)]) * 0.001 y_values = y_values + offset self.measurement_df[y_var] = y_values - self.add_warning("Unable to take log of 0, added offset of " + str(offset) + " to y-values") + self.add_warning( + "Unable to take log of 0, added offset of " + str( + offset) + " to y-values") if self.simulation_df is not None: - y_simulation = np.asarray(self.simulation_df[ptc.SIMULATION]) + y_simulation = np.asarray( + self.simulation_df[ptc.SIMULATION]) self.simulation_df[ptc.SIMULATION] = y_simulation + offset def add_or_remove_line(self, dataset_id): @@ -404,7 +437,7 @@ def disable_line(self, dataset_id): dataset_id: The dataset id of the line that should be removed. """ self.plot.removeItem(self.datasetId_to_plotDataItem[dataset_id]) - if self.datasetId_to_errorbar: # The plot may not have errorbars + if self.datasetId_to_errorbar: # The plot may not have error bars self.plot.removeItem(self.datasetId_to_errorbar[dataset_id]) self.plot.removeItem(self.datasetId_to_points[dataset_id]) @@ -417,6 +450,6 @@ def enable_line(self, dataset_id): dataset_id: The dataset id of the line that should be added. """ self.plot.addItem(self.datasetId_to_plotDataItem[dataset_id]) - if self.datasetId_to_errorbar: # The plot may not have errorbars + if self.datasetId_to_errorbar: # The plot may not have error bars self.plot.addItem(self.datasetId_to_errorbar[dataset_id]) self.plot.addItem(self.datasetId_to_points[dataset_id]) diff --git a/petabvis/window_functionality.py b/petabvis/window_functionality.py index 29b7c0d..fa4e6c5 100644 --- a/petabvis/window_functionality.py +++ b/petabvis/window_functionality.py @@ -1,9 +1,10 @@ -# Import after PySide2 to ensure usage of correct Qt library import os import sys from pathlib import Path import pandas as pd +import petab +import petab.C as ptc from PySide2 import QtWidgets, QtCore, QtGui from PySide2.QtCore import (QAbstractTableModel, QModelIndex, Qt, QSortFilterProxyModel) @@ -11,9 +12,6 @@ from PySide2.QtGui import QIcon from PySide2.QtWidgets import (QAction, QVBoxLayout, QHeaderView, QSizePolicy, QTableView, QWidget, QFileDialog) - -import petab -import petab.C as ptc from petab import core from petab.visualize.helper_functions import check_ex_exp_columns @@ -75,10 +73,11 @@ def get_value(self, row, column): class VisualizaionTableModel(CustomTableModel): """ Special table model for visualization files. - Make the first column of the table editable for - the checkbox column. + + Make the first column of the table editable for the checkbox column. Highlight the rows of the currently displayed plot. """ + def __init__(self, df=None, window=None): CustomTableModel.__init__(self, df) self.window = window @@ -88,14 +87,16 @@ def flags(self, index): return 0 if index.column() == 0: - return Qt.ItemIsEditable | Qt.ItemIsEnabled | Qt.ItemIsSelectable | Qt.ItemIsUserCheckable + return Qt.ItemIsEditable | Qt.ItemIsEnabled \ + | Qt.ItemIsSelectable | Qt.ItemIsUserCheckable return Qt.ItemIsSelectable | Qt.ItemIsEnabled def data(self, index, role=Qt.DisplayRole): if role == Qt.BackgroundRole: - current_plot = self.window.visu_spec_plots[self.window.current_list_index] - current_plot_id = current_plot.plotId + current_plot = self.window.vis_spec_plots[ + self.window.current_list_index] + current_plot_id = current_plot.plot_id if self.df[ptc.PLOT_ID][index.row()] == current_plot_id: return QtGui.QColor("yellow") else: @@ -107,15 +108,20 @@ def get_window(self): class CheckBoxDelegate(QtWidgets.QItemDelegate): """ - A delegate that places a fully functioning QCheckBox cell to the column to which it's applied. - Used for the visualization table to add the checkbox column and provide it's functionality. + A delegate that places a fully functioning QCheckBox cell to the column to + which it is applied. + + Used for the visualization table to add the checkbox column and provide + its functionality. """ + def __init__(self, parent): QtWidgets.QItemDelegate.__init__(self, parent) def createEditor(self, parent, option, index): """ - Important, otherwise an editor is created if the user clicks in this cell. + Important, otherwise an editor is created if the user clicks in this + cell. """ return None @@ -123,23 +129,31 @@ def paint(self, painter, option, index): """ Paint a checkbox without the label. """ - self.drawCheck(painter, option, option.rect, QtCore.Qt.Unchecked if int(index.data()) == 0 else QtCore.Qt.Checked) + self.drawCheck(painter, option, option.rect, + QtCore.Qt.Unchecked if int( + index.data()) == 0 else QtCore.Qt.Checked) def editorEvent(self, event, model, option, index): - ''' + """ Change the data in the model and the state of the checkbox - if the user presses the left mousebutton and this cell is editable. Otherwise do nothing. - ''' + if the user presses the left mousebutton and this cell is editable. + Otherwise do nothing. + """ if not int(index.flags() & QtCore.Qt.ItemIsEditable) > 0: return False - if event.type() == QtCore.QEvent.MouseButtonRelease and event.button() == QtCore.Qt.LeftButton: + if event.type() == QtCore.QEvent.MouseButtonRelease \ + and event.button() == QtCore.Qt.LeftButton: # Change the checkbox-state - plotId = model.sourceModel().get_value(index.row(), ptc.PLOT_ID) - datasetId = model.sourceModel().get_value(index.row(), ptc.DATASET_ID) + plot_id = model.sourceModel().get_value(index.row(), ptc.PLOT_ID) + dataset_id = model.sourceModel().get_value(index.row(), + ptc.DATASET_ID) window = model.sourceModel().get_window() - visu_spec_plot = [visu_spec_plot for visu_spec_plot in window.visu_spec_plots if visu_spec_plot.plotId == plotId][0] - visu_spec_plot.add_or_remove_line(datasetId) + # Set `vis_spec_plot` to the one that matches `plot_id` + for vis_spec_plot in window.vis_spec_plots: + if vis_spec_plot.plot_id == plot_id: + break + vis_spec_plot.add_or_remove_line(dataset_id) self.setModelData(None, model, index) return True @@ -149,7 +163,8 @@ def setModelData(self, editor, model, index): """ Change the state of the checkbox after it was clicked. """ - model.setData(index, 1 if int(index.data()) == 0 else 0, QtCore.Qt.EditRole) + model.setData(index, 1 if int(index.data()) == 0 else 0, + QtCore.Qt.EditRole) class TableWidget(QWidget): @@ -180,7 +195,8 @@ def __init__(self, data: pd.DataFrame, add_checkbox_col: bool, window): self.horizontal_header = self.table_view.horizontalHeader() self.horizontal_header.setSortIndicator(-1, Qt.DescendingOrder) self.vertical_header = self.table_view.verticalHeader() - self.horizontal_header.setSectionResizeMode(QHeaderView.ResizeToContents) + self.horizontal_header.setSectionResizeMode( + QHeaderView.ResizeToContents) # QWidget Layout self.main_layout = QVBoxLayout() @@ -204,7 +220,9 @@ def pop_up_table_view(window: QtWidgets.QMainWindow, df: pd.DataFrame): add_checkbox_col = False if window.visualization_df.equals(df): add_checkbox_col = True - window.table_window = TableWidget(data=df, add_checkbox_col=add_checkbox_col, window=window) + window.table_window = TableWidget(data=df, + add_checkbox_col=add_checkbox_col, + window=window) window.table_window.setGeometry(QtCore.QRect(100, 100, 800, 400)) window.table_window.show() @@ -262,8 +280,10 @@ def table_tree_view(window: QtWidgets.QMainWindow, folder_path): root_node.appendRow(branch) tree_view.setModel(model) - reconnect(tree_view.clicked, lambda i: exchange_dataframe_on_click(i, model, window)) - reconnect(tree_view.doubleClicked, lambda i: display_table_on_doubleclick(i, model, window)) + reconnect(tree_view.clicked, + lambda i: exchange_dataframe_on_click(i, model, window)) + reconnect(tree_view.doubleClicked, + lambda i: display_table_on_doubleclick(i, model, window)) def reconnect(signal, new_function=None): @@ -287,9 +307,8 @@ def exchange_dataframe_on_click(index: QtCore.QModelIndex, model: QtGui.QStandardItemModel, window: QtWidgets.QMainWindow): """ - Changes the currently plotted dataframe with the one - that gets clicked on and replot the data, - e.g. switch the measurement or visualization df. + Changes the currently plotted dataframe with the one that gets clicked on + and replot the data, e.g. switch the measurement or visualization df. Arguments: index: index of the clicked dataframe @@ -353,16 +372,18 @@ def add_file_selector(window: QtWidgets.QMainWindow): """ open_yaml_file = QAction(QIcon('open.png'), 'Open YAML file...', window) open_yaml_file.triggered.connect(lambda x: show_yaml_dialog(x, window)) - open_simulation_file = QAction(QIcon('open.png'), 'Open simulation file...', window) - open_simulation_file.triggered.connect(lambda x: show_simulation_dialog(x, window)) + open_simulation_file = QAction(QIcon('open.png'), + 'Open simulation file...', window) + open_simulation_file.triggered.connect( + lambda x: show_simulation_dialog(x, window)) quit = QAction("Quit", window) quit.triggered.connect(sys.exit) menubar = window.menuBar() - fileMenu = menubar.addMenu('&Select File') - fileMenu.addAction(open_yaml_file) - fileMenu.addAction(open_simulation_file) - fileMenu.addAction(quit) + file_menu = menubar.addMenu('&Select File') + file_menu.addAction(open_yaml_file) + file_menu.addAction(open_simulation_file) + file_menu.addAction(quit) def show_yaml_dialog(self, window: QtWidgets.QMainWindow): @@ -391,12 +412,13 @@ def show_yaml_dialog(self, window: QtWidgets.QMainWindow): window.yaml_dict = yaml_dict if ptc.VISUALIZATION_FILES not in yaml_dict: window.visualization_df = None - window.add_warning("The YAML file contains no visualization file (default plotted)") + window.add_warning( + "The YAML file contains no visualization file (default plotted)") window.simulation_df = None # table_tree_view sets the df attributes of the window # equal to the first file of each branch (measurement, visualization, ...) - window.listWidget = table_tree_view(window, last_dir) + table_tree_view(window, last_dir) window.add_plots() @@ -412,7 +434,8 @@ def show_simulation_dialog(self, window: QtWidgets.QMainWindow): settings = QtCore.QSettings("petab", "petabvis") if settings.value("last_dir") is not None: home_dir = settings.value("last_dir") - file_name = QFileDialog.getOpenFileName(window, 'Open simulation file', home_dir)[0] + file_name = QFileDialog.getOpenFileName( + window, 'Open simulation file', home_dir)[0] if file_name: # if a file was selected if window.exp_data is None: window.add_warning("Please open a YAML file first.") @@ -420,23 +443,27 @@ def show_simulation_dialog(self, window: QtWidgets.QMainWindow): window.warn_msg.setText("") sim_data = core.get_simulation_df(file_name) # check columns, and add non-mandatory default columns - sim_data, _, _ = check_ex_exp_columns(sim_data, None, None, None, None, None, - window.condition_df, sim=True) - # delete the replicateId column if it gets added to the simulation table - # but is not in exp_data because it causes problems when splitting the replicates - if ptc.REPLICATE_ID not in window.exp_data.columns and ptc.REPLICATE_ID in sim_data.columns: + sim_data, _, _ = check_ex_exp_columns( + sim_data, None, None, None, None, None, + window.condition_df, sim=True) + # delete the replicateId column if it gets added to the simulation + # table but is not in exp_data because it causes problems when + # splitting the replicates + if ptc.REPLICATE_ID not in window.exp_data.columns \ + and ptc.REPLICATE_ID in sim_data.columns: sim_data.drop(ptc.REPLICATE_ID, axis=1, inplace=True) if len(window.yaml_dict[ptc.MEASUREMENT_FILES]) > 1: - window.add_warning("Not Implemented Error: Loading a simulation file with " - "multiple measurement files is currently not supported.") + window.add_warning( + "Not Implemented Error: Loading a simulation file with " + "multiple measurement files is currently not supported.") else: window.simulation_df = sim_data window.add_plots() # insert correlation plot at position 1 window.wid.insertWidget(1, window.plot2_widget) - window.listWidget = table_tree_view(window, os.path.dirname(file_name)) + table_tree_view(window, os.path.dirname(file_name)) # save the directory for the next use last_dir = os.path.dirname(file_name) + "/"