Skip to content

Commit

Permalink
Various cleanup (#13)
Browse files Browse the repository at this point in the history
Formatting, PEP8, ...
  • Loading branch information
dweindl authored Jan 29, 2021
1 parent 8e155fc commit 9caf1ea
Show file tree
Hide file tree
Showing 9 changed files with 359 additions and 214 deletions.
43 changes: 27 additions & 16 deletions petabvis/bar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,20 @@ class BarPlot(plot_class.PlotClass):
visualization_df: PEtab visualization table
simulation_df: PEtab simulation table
condition_df: PEtab condition table
plotId: Id of the plot (has to in the visualization_df aswell)
plot_id: Id of the plot (has to in the visualization_df as well)
Attributes:
bar_rows: A list of BarRows (one for each visualization df row)
overview_df: A df containing the information of each bar
"""

def __init__(self, measurement_df: pd.DataFrame = None,
visualization_df: pd.DataFrame = None,
simulation_df: pd.DataFrame = None,
condition_df: pd.DataFrame = None,
plotId: str = ""):
plot_id: str = ""):
super().__init__(measurement_df, visualization_df, simulation_df,
condition_df, plotId)
condition_df, plot_id)

self.bar_width = 0.4
# bar_rows also contains simulation bars
Expand All @@ -39,7 +38,9 @@ def __init__(self, measurement_df: pd.DataFrame = None,
self.add_bar_rows(self.simulation_df)

# A df containing the information needed to plot the bars
self.overview_df = pd.DataFrame(columns=["x", "y", "name", "sd", "sem", "provided_noise", "is_simulation", "tick_pos"])
self.overview_df = pd.DataFrame(
columns=["x", "y", "name", "sd", "sem", "provided_noise",
"is_simulation", "tick_pos"])

self.plot_everything()

Expand Down Expand Up @@ -81,7 +82,8 @@ def get_bars_df(self, bar_rows):
df: A dataframe with information relevant
for plotting a bar (x, y, sd, etc.)
"""
bar_rows = [bar_row for bar_row in bar_rows if bar_row.dataset_id not in self.disabled_rows]
bar_rows = [bar_row for bar_row in bar_rows if
bar_row.dataset_id not in self.disabled_rows]

x = range(len(bar_rows))
tick_pos = range(len(bar_rows))
Expand All @@ -92,8 +94,10 @@ def get_bars_df(self, bar_rows):
noise = [bar.provided_noise for bar in bar_rows]
is_simulation = [bar.is_simulation for bar in bar_rows]

df = pd.DataFrame(list(zip(x, y, names, sd, sem, noise, is_simulation, tick_pos)),
columns=["x", "y", "name", "sd", "sem", "provided_noise", "is_simulation", "tick_pos"])
df = pd.DataFrame(
list(zip(x, y, names, sd, sem, noise, is_simulation, tick_pos)),
columns=["x", "y", "name", "sd", "sem", "provided_noise",
"is_simulation", "tick_pos"])

# Adjust x and tick_pos of the bars when simulation bars are plotted
# such that they are next to each other
Expand All @@ -108,7 +112,7 @@ def get_bars_df(self, bar_rows):
df.loc[index, "tick_pos"] = i

# separate measurement and simulation bars
bar_separation_shift = self.bar_width/2
bar_separation_shift = self.bar_width / 2
df.loc[~df["is_simulation"], "x"] -= bar_separation_shift
df.loc[df["is_simulation"], "x"] += bar_separation_shift

Expand All @@ -129,10 +133,13 @@ def generate_plot(self):
# Add bars
simu_rows = self.overview_df["is_simulation"]
bar_item = pg.BarGraphItem(x=self.overview_df[~simu_rows]["x"],
height=self.overview_df[~simu_rows]["y"], width=self.bar_width)
height=self.overview_df[~simu_rows][
"y"], width=self.bar_width)
self.plot.addItem(bar_item) # measurement bars
bar_item = pg.BarGraphItem(x=self.overview_df[simu_rows]["x"], brush="w",
height=self.overview_df[simu_rows]["y"], width=self.bar_width)
bar_item = pg.BarGraphItem(x=self.overview_df[simu_rows]["x"],
brush="w",
height=self.overview_df[simu_rows]["y"],
width=self.bar_width)
self.plot.addItem(bar_item) # simulation bars

# Add error bars
Expand All @@ -141,20 +148,24 @@ def generate_plot(self):
error_length = self.overview_df["sem"]
if self.bar_rows[0].plot_type_data == ptc.PROVIDED:
error_length = self.overview_df["provided_noise"]
error = pg.ErrorBarItem(x=self.overview_df["x"], y=self.overview_df["y"],
top=error_length, bottom=error_length, beam=0.1)
error = pg.ErrorBarItem(x=self.overview_df["x"],
y=self.overview_df["y"],
top=error_length, bottom=error_length,
beam=0.1)
self.plot.addItem(error)

# set tick names to the legend entry of the bars
xax = self.plot.getAxis("bottom")
ticks = [list(zip(self.overview_df["tick_pos"], self.overview_df["name"]))]
ticks = [list(
zip(self.overview_df["tick_pos"], self.overview_df["name"]))]
xax.setTicks(ticks)

# set y-scale to log if necessary
if "log" in self.bar_rows[0].y_scale:
self.plot.setLogMode(y=True)
if self.plot_rows[0].x_scale == "log":
self.add_warning("log not supported, using log10 instead (in " + self.plot_title + ")")
self.add_warning(
"log not supported, using log10 instead (in " + self.plot_title + ")")

def add_or_remove_line(self, dataset_id):
"""
Expand Down
1 change: 1 addition & 0 deletions petabvis/bar_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class BarRow(row_class.RowClass):
sem: Standard error of the mean of the replicates
provided noise: Noise of the measurements
"""

def __init__(self, exp_data: pd.DataFrame,
plot_spec: pd.Series, condition_df: pd.DataFrame, ):
super().__init__(exp_data, plot_spec, condition_df)
Expand Down
120 changes: 67 additions & 53 deletions petabvis/main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import argparse
import sys # We need sys so that we can pass argv to QApplication
import warnings

import numpy as np
import pandas as pd
import warnings
import petab.C as ptc
import pyqtgraph as pg
from PySide2 import QtWidgets, QtCore, QtGui
from PySide2.QtWidgets import QVBoxLayout, QComboBox, QWidget, QLabel
from petab import measurements, core
import pyqtgraph as pg

from . import utils
from . import visuSpec_plot
from . import bar_plot
from . import vis_spec_plot
from . import window_functionality
from .bar_plot import BarPlot


class MainWindow(QtWidgets.QMainWindow):
Expand All @@ -36,6 +36,7 @@ class MainWindow(QtWidgets.QMainWindow):
current_list_index: List index of the currently displayed plot
wid: QSplitter between main plot and correlation plot
"""

def __init__(self, exp_data: pd.DataFrame,
visualization_df: pd.DataFrame = None,
simulation_df: pd.DataFrame = None,
Expand All @@ -55,7 +56,7 @@ def __init__(self, exp_data: pd.DataFrame,
self.condition_df = condition_df
self.observable_df = observable_df
self.exp_data = exp_data
self.visu_spec_plots = []
self.vis_spec_plots = []
self.wid = QtWidgets.QSplitter()
self.plot1_widget = pg.GraphicsLayoutWidget(show=True)
self.plot2_widget = pg.GraphicsLayoutWidget(show=False)
Expand Down Expand Up @@ -105,26 +106,30 @@ def add_plots(self):
Returns:
List of PlotItem
"""
self.clear_QSplitter()
self.visu_spec_plots.clear()
self.clear_qsplitter()
self.vis_spec_plots.clear()

if self.visualization_df is not None:
# to keep the order of plots consistent with names from the plot selection
indexes = np.unique(self.visualization_df[ptc.PLOT_ID], return_index=True)[1]
plot_ids = [self.visualization_df[ptc.PLOT_ID][index] for index in sorted(indexes)]
indexes = \
np.unique(self.visualization_df[ptc.PLOT_ID], return_index=True)[1]
plot_ids = [self.visualization_df[ptc.PLOT_ID][index] for index in
sorted(indexes)]
for plot_id in plot_ids:
self.create_and_add_visuPlot(plot_id)
self.create_and_add_vis_plot(plot_id)

else: # default plot when no visu_df is provided
self.create_and_add_visuPlot()
self.create_and_add_vis_plot()

plots = [visuPlot.getPlot() for visuPlot in self.visu_spec_plots]
plots = [vis_spec_plot.get_plot() for vis_spec_plot in
self.vis_spec_plots]

# update the cbox
self.cbox.clear()
# calling this method sets the index of the cbox to 0
# and thus displays the first plot
utils.add_plotnames_to_cbox(self.exp_data, self.visualization_df, self.cbox)
utils.add_plotnames_to_cbox(self.exp_data, self.visualization_df,
self.cbox)

return plots

Expand All @@ -135,13 +140,15 @@ def index_changed(self, i: int):
Arguments:
i: index of the selected plot
"""
if 0 <= i < len(self.visu_spec_plots): # i is -1 when the cbox is cleared
self.clear_QSplitter()
self.plot1_widget.addItem(self.visu_spec_plots[i].getPlot())
if 0 <= i < len(
self.vis_spec_plots): # i is -1 when the cbox is cleared
self.clear_qsplitter()
self.plot1_widget.addItem(self.vis_spec_plots[i].get_plot())
self.plot2_widget.hide()
if self.simulation_df is not None:
self.plot2_widget.show()
self.plot2_widget.addItem(self.visu_spec_plots[i].correlation_plot)
self.plot2_widget.addItem(
self.vis_spec_plots[i].correlation_plot)
self.current_list_index = i

def keyPressEvent(self, ev):
Expand All @@ -153,18 +160,18 @@ def keyPressEvent(self, ev):
"""
# Exit when pressing ctrl + Q
ctrl = False
if (ev.modifiers() & QtCore.Qt.ControlModifier):
if ev.modifiers() & QtCore.Qt.ControlModifier:
ctrl = True
if ctrl and ev.key() == QtCore.Qt.Key_Q:
sys.exit()

if(ev.key() == QtCore.Qt.Key_Up):
if ev.key() == QtCore.Qt.Key_Up:
self.index_changed(self.current_list_index - 1)
if(ev.key() == QtCore.Qt.Key_Down):
if ev.key() == QtCore.Qt.Key_Down:
self.index_changed(self.current_list_index + 1)
if(ev.key() == QtCore.Qt.Key_Left):
if ev.key() == QtCore.Qt.Key_Left:
self.index_changed(self.current_list_index - 1)
if(ev.key() == QtCore.Qt.Key_Right):
if ev.key() == QtCore.Qt.Key_Right:
self.index_changed(self.current_list_index + 1)

def add_warning(self, message: str):
Expand All @@ -177,7 +184,8 @@ def add_warning(self, message: str):
if message not in self.warn_msg.text():
self.warn_msg.setText(self.warn_msg.text() + message + "\n")

def redirect_warning(self, message, category, filename=None, lineno=None, file=None, line=None):
def redirect_warning(self, message, category, filename=None, lineno=None,
file=None, line=None):
"""
Redirect all warning messages and display them in the window.
Expand All @@ -187,11 +195,11 @@ def redirect_warning(self, message, category, filename=None, lineno=None, file=N
print("Warning redirected: " + str(message))
self.add_warning(str(message))

def create_and_add_visuPlot(self, plot_id=""):
def create_and_add_vis_plot(self, plot_id=""):
"""
Create a visuSpec_plot object based on the given plot_id.
Create a vis_spec_plot object based on the given plot_id.
If no plot_it is provided the default will be plotted.
Add all the warnings of the visuPlot object to the warning text box.
Add all the warnings of the vis_plot object to the warning text box.
The actual plotting happens in the index_changed method
Expand All @@ -201,41 +209,46 @@ def create_and_add_visuPlot(self, plot_id=""):
# split the measurement df by observable when using default plots
if self.visualization_df is None:
# to keep the order of plots consistent with names from the plot selection
indexes = np.unique(self.exp_data[ptc.OBSERVABLE_ID], return_index=True)[1]
observable_ids = [self.exp_data[ptc.OBSERVABLE_ID][index] for index in sorted(indexes)]
indexes = \
np.unique(self.exp_data[ptc.OBSERVABLE_ID], return_index=True)[1]
observable_ids = [self.exp_data[ptc.OBSERVABLE_ID][index] for index
in sorted(indexes)]
for observable_id in observable_ids:
rows = self.exp_data[ptc.OBSERVABLE_ID] == observable_id
data = self.exp_data[rows]
visuPlot = visuSpec_plot.VisuSpecPlot(measurement_df=data, visualization_df=None,
condition_df=self.condition_df,
simulation_df=self.simulation_df, plotId=plot_id)
self.visu_spec_plots.append(visuPlot)
if visuPlot.warnings:
self.add_warning(visuPlot.warnings)
vis_plot = vis_spec_plot.VisSpecPlot(
measurement_df=data, visualization_df=None,
condition_df=self.condition_df,
simulation_df=self.simulation_df, plot_id=plot_id)
self.vis_spec_plots.append(vis_plot)
if vis_plot.warnings:
self.add_warning(vis_plot.warnings)

else:
# reduce the visualization df to the relevant rows (by plotId)
rows = self.visualization_df[ptc.PLOT_ID] == plot_id
visu_df = self.visualization_df[rows]
if ptc.PLOT_TYPE_SIMULATION in visu_df.columns and\
visu_df.iloc[0][ptc.PLOT_TYPE_SIMULATION] == ptc.BAR_PLOT:
barPlot = bar_plot.BarPlot(measurement_df=self.exp_data,
visualization_df=visu_df,
vis_df = self.visualization_df[rows]
if ptc.PLOT_TYPE_SIMULATION in vis_df.columns and \
vis_df.iloc[0][ptc.PLOT_TYPE_SIMULATION] == ptc.BAR_PLOT:
bar_plot = BarPlot(measurement_df=self.exp_data,
visualization_df=vis_df,
condition_df=self.condition_df,
simulation_df=self.simulation_df, plotId=plot_id)
simulation_df=self.simulation_df,
plot_id=plot_id)
# might want to change the name of visu_spec_plots to clarify that
# it can also include bar plots (maybe to plots?)
self.visu_spec_plots.append(barPlot)
self.vis_spec_plots.append(bar_plot)
else:
visuPlot = visuSpec_plot.VisuSpecPlot(measurement_df=self.exp_data,
visualization_df=visu_df,
condition_df=self.condition_df,
simulation_df=self.simulation_df, plotId=plot_id)
self.visu_spec_plots.append(visuPlot)
if visuPlot.warnings:
self.add_warning(visuPlot.warnings)

def clear_QSplitter(self):
vis_plot = vis_spec_plot.VisSpecPlot(
measurement_df=self.exp_data,
visualization_df=vis_df,
condition_df=self.condition_df,
simulation_df=self.simulation_df, plot_id=plot_id)
self.vis_spec_plots.append(vis_plot)
if vis_plot.warnings:
self.add_warning(vis_plot.warnings)

def clear_qsplitter(self):
"""
Clear the GraphicsLayoutWidgets for the
measurement and correlation plot
Expand All @@ -258,11 +271,12 @@ def main():

visualization_df = None
if args.visualization is not None:
visualization_df = core.concat_tables(args.visualization, core.get_visualization_df)
visualization_df = core.concat_tables(args.visualization,
core.get_visualization_df)

app = QtWidgets.QApplication(sys.argv)
main = MainWindow(exp_data, visualization_df)
main.show()
main_window = MainWindow(exp_data, visualization_df)
main_window.show()
sys.exit(app.exec_())


Expand Down
Loading

0 comments on commit 9caf1ea

Please sign in to comment.