diff --git a/.gitignore b/.gitignore
index 3c77eb5..08bb0f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+!*/
 __pycache__
 .vscode
-tmp/*
\ No newline at end of file
+paper/*
+tmp/*
+env_nteEnc.yml
\ No newline at end of file
diff --git a/README.md b/README.md
index dac5b89..dae6c4c 100644
--- a/README.md
+++ b/README.md
@@ -71,3 +71,9 @@ apt-get -y install '^libxcb.*-dev' libx11-xcb-dev libglu1-mesa-dev libxrender-de
 apt packages:
 apt -y install ffmpeg, portaudio19-dev
 
+### Troubleshooting
+When running on an older Python version with the most recent PyQt6 package, you might encounter the following error:
+`symbol lookup error: [your_path]/python3.8/site-packages/PyQt6/Qt6/plugins/platforms/../../lib/libQt6WaylandClient.so.6: undefined symbol: wl_proxy_marshal_flags`
+
+To fix this, you can enforce the use of X11 instead of Wayland by adding the following line to your `~/.bashrc`:
+`export QT_QPA_PLATFORM=xcb`
diff --git a/WiN_GUI.py b/WiN_GUI.py
index 5495af2..308bbe7 100644
--- a/WiN_GUI.py
+++ b/WiN_GUI.py
@@ -1,49 +1,100 @@
 """
-Use this GUI to visualize the encoding from sample-based into event-/spike-based data.
-It only supports datasets with a specific structure. For more details see the README.md.
-It is possible to change the neuron model and its parameters on the flight.
-It also provides auditory feedback on the encoding result.
+WiN_GUI.py
+
+This GUI is designed to visualize the encoding from sample-based data into event-/spike-based data.
+It supports datasets with a specific structure, as detailed in the README.md. The GUI allows users
+to change the neuron model and its parameters on the fly and provides auditory feedback on the
+encoding results.
+
+Features:
+- Visualization of spike patterns
+- Real-time adjustment of neuron model parameters
+- Auditory feedback on encoding results
+- Detailed description available at: https://www.sciencedirect.com/science/article/pii/S2352711024001304
+- And [fill in the link to the v2 update publication]
+
+Authors:
+- Simon F. Muller-Cleve (v1, v2)
+- Fernando M. Quintana (v1)
+- Vittorio Fra (v1, v2)
+
+Dependencies:
+- numpy
+- pandas
+- torch
+- matplotlib
+- pydub
+- PyQt6
+
+Usage:
+Run this script to launch the GUI. Ensure that the required dependencies are installed and the dataset
+is structured as specified in the README.md.
+
+License:
+This project is licensed under the GPL-3.0 License. See the LICENSE file for more details.
 
-You can find a detailed description of the GUI here: https://www.sciencedirect.com/science/article/pii/S2352711024001304
-
-Simon F. Muller-Cleve
-Fernando M. Quintana
-Vittorio Fra
 """
-
 import logging
+import os
+import shutil
+import stat
 import sys
+import tempfile
 from decimal import Decimal
 from random import random
 
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import torch
 from matplotlib.backends.backend_qtagg import FigureCanvas
-from matplotlib.figure import Figure
+from matplotlib.gridspec import GridSpec
 from matplotlib.pyplot import cm
 from pydub import AudioSegment
 from pydub.generators import Sawtooth
 from PyQt6 import QtCore
-from PyQt6.QtCore import QObject, Qt, QThread, QUrl
+from PyQt6.QtCore import QEvent, QObject, Qt, QThread, QUrl
+from PyQt6.QtGui import QColor, QFont
 from PyQt6.QtMultimedia import QAudioOutput, QMediaPlayer
 from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDial,
                              QFileDialog, QGridLayout, QLabel, QMainWindow,
-                             QPushButton, QSlider, QSplitter, QWidget)
+                             QPushButton, QSizePolicy, QSlider, QSplitter,
+                             QTableWidget, QTableWidgetItem, QTabWidget,
+                             QWidget)
 
-from utils.data_management import (create_directory, load_data,
-                                   preprocess_data, split_data)
-from utils.neuron_models import IZ_neuron, LIF_neuron, MN_neuron, RLIF_neuron
+from utils.data_management import load_data, preprocess_data, split_data
+from utils.neuron_models import IZ_neuron, LIF_neuron, MN_neuron, CuBaLIF_neuron
+from utils.spike_pattern_classifier import classifySpikes, prepareDataset
 
 WINDOW_WIDTH, WINDOW_HEIGTH = 1500, 750
 DISPLAY_HEIGHT = 35
+MIDPOINT_LIGHTNESS = 200
+EXTREME_LIGHTNESS = 150
+
+
+class CustomSlider(QSlider):
+    def enterEvent(self, event):
+        self.setCursor(Qt.CursorShape.OpenHandCursor)
+        super().enterEvent(event)
+
+    def leaveEvent(self, event):
+        self.unsetCursor()
+        super().leaveEvent(event)
+
+    def mousePressEvent(self, event):
+        self.setCursor(Qt.CursorShape.ClosedHandCursor)
+        super().mousePressEvent(event)
+
+    def mouseReleaseEvent(self, event):
+        self.setCursor(Qt.CursorShape.OpenHandCursor)
+        super().mouseReleaseEvent(event)
 
 
 class EncodingCalc(QObject):
     """EncodingGUI's controller class."""
 
-    signalData = QtCore.pyqtSignal(np.ndarray, np.ndarray)
+    signalDataEnc = QtCore.pyqtSignal(np.ndarray, np.ndarray)
 
     def __init__(self, parent=None):
         super(self.__class__, self).__init__()
@@ -82,8 +133,8 @@ def simulate(self):
                 dt=self.main_gui.dt / self.main_gui.dt_slider.value(),
             )
 
-        elif self.main_gui.neuron_model_name == "Recurrent leaky integrate-and-fire":
-            self.neurons = RLIF_neuron(
+        elif self.main_gui.neuron_model_name == "Current-based leaky integrate-and-fire":
+            self.neurons = CuBaLIF_neuron(
                 len(self.main_gui.channels),
                 params,
                 dt=self.main_gui.dt / self.main_gui.dt_slider.value(),
@@ -116,31 +167,87 @@ def simulate(self):
 
         if self.main_gui.enable_data_splitting:
             if self.main_gui.data_split is None:
-                self.signalData.emit(
+                self.signalDataEnc.emit(
                     self.main_gui.data[sample].unsqueeze(1).cpu().numpy(), output)
             else:
-                self.signalData.emit(
+                self.signalDataEnc.emit(
                     self.main_gui.data_split[sample].unsqueeze(1).cpu().numpy(), output)
         else:
-            self.signalData.emit(
+            self.signalDataEnc.emit(
                 self.main_gui.data[sample].unsqueeze(1).cpu().numpy(), output)
 
 
+class ClassificationCalc(QObject):
+    """Spike-pattern classification controller class."""
+
+    signalDataClass = QtCore.pyqtSignal(np.ndarray, np.ndarray)
+
+    def __init__(self, parent=None):
+        super(self.__class__, self).__init__()
+        self.main_gui = parent
+        self.main_gui.classify_event.connect(self.classify)
+
+    @torch.no_grad()
+    @QtCore.pyqtSlot()
+    def classify(self):
+        if self.main_gui.output_data is None:
+            return
+        else:
+            
generator = prepareDataset(self.main_gui.output_data) + predictions, softmax = classifySpikes(generator) + self.probs = softmax.copy() + # let us get the most frequent predicted class over all batches + self.finalPredictionList = [] + for sensorId in range(self.main_gui.output_data.shape[-1]): + # nothing to do if no spikes given + uniquePredictions, count = np.unique( + predictions[sensorId, :], return_counts=True) + + # Sort predictions by count in descending order + sorted_indices = np.argsort(count)[::-1] + sorted_predictions = uniquePredictions[sorted_indices] + sorted_counts = count[sorted_indices] + + # Check if 'No spikes' has the highest count + if len(sorted_predictions) > 1 and sorted_predictions[0] == 'No spikes': + # Select the second highest count + self.finalPredictionList.append(sorted_predictions[1]) + else: + # Select the highest count + self.finalPredictionList.append(sorted_predictions[0]) + + mean_softmax = np.mean(softmax, axis=1) + # Calculate the sum along axis 1, keeping the dimensions + sum_mean_softmax = np.sum(mean_softmax, axis=1, keepdims=True) + + # Normalize mean_softmax, avoiding division by zero + self.normalized_softmax = np.divide( + mean_softmax, + sum_mean_softmax, + where=sum_mean_softmax != 0 + ) + + self.signalDataClass.emit( + np.array(self.finalPredictionList), self.normalized_softmax) + + class WiN_GUI_Window(QMainWindow): """EncodingGUI's main window.""" draw_event = QtCore.pyqtSignal() # Signal used to update the plots simulate_event = QtCore.pyqtSignal() # Signal used to trigger a new simulation - # Signal used to trigger generation of new audio - audio_event = QtCore.pyqtSignal() + classify_event = QtCore.pyqtSignal() # Signal used to trigger classification + audio_event = QtCore.pyqtSignal() # Signal used to trigger audio update + write_event = QtCore.pyqtSignal() # Signal used to write the table def __init__(self): super().__init__() # Window creation - self.setWindowTitle("WiN-GUI") + gui_window_title = "WiN-GUI" + self.setWindowTitle(gui_window_title) self.setMinimumSize(WINDOW_WIDTH, WINDOW_HEIGTH) self.setWindowFlags( @@ -152,12 +259,17 @@ def __init__(self): | QtCore.Qt.WindowType.WindowMaximizeButtonHint # | QtCore.Qt.WindowType.WindowStaysOnTopHint # enforce window in forground ) - - # TODO inlcude removing tmp folder when closing GUI # create tmp path to store audio file - self.tmp_path = "./tmp" - # used to store tmp data (audio file, etc.) 
- create_directory(self.tmp_path) + self.tmp_dir = "./" + # Remove old temporary folder if present + for folder in os.listdir(self.tmp_dir): + if folder.startswith("tmp"): + # NOTE: onerror is deprecated as of python 3.12, to be replaced by onexc + shutil.rmtree(folder, onerror=self._removeReadonly) + self.tmp_folder = tempfile.mkdtemp( + dir=self.tmp_dir) # Create a temporary folder + + self.output_data = None # setting defaults self.upsample_fac = 1 @@ -174,51 +286,102 @@ def __init__(self): self.neuron_model_name = "Mihalas-Niebur" self.dataFilename = None self.neuronStateVariables = None + self.calcSpikePatternClassification = False + self.showSubClasses = False + + self.initUI() - # set the main layout to grid + def initUI(self): + # Init the main layout self.generalLayout = QSplitter(Qt.Orientation.Horizontal) self.setCentralWidget(self.generalLayout) - # Canvas pane - self.canvasLayout = QGridLayout() - canvasWidget = QWidget(self) - canvasWidget.setLayout(self.canvasLayout) - self.generalLayout.addWidget(canvasWidget) + # Set the size policy to make the window resizable + self.setSizePolicy(QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.Expanding) + + # Enable hover events + self.setAttribute(Qt.WidgetAttribute.WA_Hover) - # Parameters pane + # Add tabs + self.tabs = QTabWidget() + self.data_tab = QWidget() + self.spike_pattern_tab = QWidget() + + self.tabs.addTab(self.data_tab, "Data Visualization") + self.tabs.addTab(self.spike_pattern_tab, "Spike-Pattern Visualization") + + # Add tabs to the main layout + self.generalLayout.addWidget(self.tabs) + + # Canvas pane in the first tab + self.canvasLayout = QGridLayout(self.data_tab) + self.data_tab.setLayout(self.canvasLayout) + + # Spike Pattern Visualizer in the second tab + self.spikePatternLayout = QGridLayout(self.spike_pattern_tab) + self.spike_pattern_tab.setLayout(self.spikePatternLayout) + + # Parameters pane (always visible on the right side) self.parametersLayout = QGridLayout() parametersWidget = QWidget(self) parametersWidget.setLayout(self.parametersLayout) self.generalLayout.addWidget(parametersWidget) - # init GUI - self.createCanvas() + # Initialize GUI elements + self.createCanvas() # Now in the first tab + self.createSpikePatternVisualizer() self.loadParameter() - self.createModelSelection() - self.createDataSelection() - self.createPreprocessingSelection() - self.createAudioPushButtons() - self.createParamSlider() - self.draw_event.connect(self.draw) - self.audio_event.connect(self.spikeToAudio) - - # Encoding simulator creation + self.createDataSection() + self.createPreprocessingSection() + self.createModelSection() + self.createParamSliderSection() + self.createSpikePatternClassifierSection() + self.createAudioSection() + + # Encoding simulator creation and threading self.encoding_calc = EncodingCalc(self) - self.thread = QThread(parent=self) + self.enc_thread = QThread(parent=self) + self.encoding_calc.moveToThread(self.enc_thread) - # When simulation finished, plot the result - self.encoding_calc.signalData.connect(self._updateCanvas) - self.encoding_calc.signalData.connect(self._updateSpikesToAudio) - self.encoding_calc.moveToThread(self.thread) + self.classification_calc = ClassificationCalc(self) + self.class_thread = QThread(parent=self) + self.classification_calc.moveToThread(self.class_thread) - self.thread.start() + # Connect signals and start the thread + self.draw_event.connect(self.drawCanvas) + self.audio_event.connect(self.spikeToAudio) + self.write_event.connect(self.writeTable) + 
self.encoding_calc.signalDataEnc.connect(self._updateCanvas) + self.encoding_calc.signalDataEnc.connect(self._updateSpikesToAudio) + self.classification_calc.signalDataClass.connect( + self._updateSpikePattern) + + self.enc_thread.start() + self.class_thread.start() + + # def event(self, event): + # if event.type() == QEvent.Type.HoverMove: + # pos = event.position() + # height = self.height() + # width = self.width() + # margin = 5 # Margin for resize area + + # if pos.x() < margin or pos.x() > width - margin: + # self.setCursor(Qt.CursorShape.SizeHorCursor) + # elif pos.y() < margin or pos.y() > height - margin: + # self.setCursor(Qt.CursorShape.SizeVerCursor) + # else: + # self.setCursor(Qt.CursorShape.ArrowCursor) + + # return super().event(event) ###################### # DATA VISUALIZATION # ###################### - def createAudioPushButtons(self): - filename = f"{self.tmp_path}/spikeToAudio.wav" + def createAudioSection(self): + filename = f"{self.tmp_folder}/spikeToAudio.wav" # self.eventsAudioStream = [] # TODO use this variable for the event audio stream self.player = QMediaPlayer() self.audio_output = QAudioOutput() @@ -244,7 +407,7 @@ def createAudioPushButtons(self): self.play_endlessly_button.clicked.connect(self._playEndlessly) self.parametersLayout.addLayout( - pushButtonLayout, 5, 0, Qt.AlignmentFlag.AlignBottom) + pushButtonLayout, 6, 0, Qt.AlignmentFlag.AlignBottom) def createCanvas(self): """ @@ -257,48 +420,50 @@ def createCanvas(self): # dynamically creating the plots for state variables despite V or spk plus raw input if self.neuron_model_name == "Mihalas-Niebur": - num_figures = len(MN_neuron.NeuronState._fields) + 1 + num_figures = len(MN_neuron.NeuronState._fields) elif self.neuron_model_name == "Izhikevich": - num_figures = len(IZ_neuron.NeuronState._fields) + 1 + num_figures = len(IZ_neuron.NeuronState._fields) elif self.neuron_model_name == "Leaky integrate-and-fire": - num_figures = len(LIF_neuron.NeuronState._fields) + 1 - elif self.neuron_model_name == "Recurrent leaky integrate-and-fire": - num_figures = len(RLIF_neuron.NeuronState._fields) + 1 + num_figures = len(LIF_neuron.NeuronState._fields) + elif self.neuron_model_name == "Current-based leaky integrate-and-fire": + num_figures = len(CuBaLIF_neuron.NeuronState._fields) else: ValueError("No neuron model selected.") + num_figures += 1 # add raster plot + + # Create a figure and GridSpec + self.figure = plt.figure(figsize=(10, 6)) + self.gs = GridSpec((num_figures // 2) + 1, 2, figure=self.figure) - # Create lists to store the figures and axes - self.figures = [] + # Create lists to store the axes self.axes = {} self.axis_names = [] for i in range(num_figures): self.axis_names.append(f"_dynamic_ax{i}") - # TODO fix incorrect size of raster plot when having 4 figures only (raster plot to high) - # set last row - last_row = int((num_figures/2) + 0.5) - # Create the figures and axes dynamically + # Create the axes dynamically using GridSpec for i, axis_name in enumerate(self.axis_names): - if i < len(self.axis_names)-1: - # all state variables plus raw data - figure = FigureCanvas(Figure()) # figsize=(5, 5) - self.figures.append(figure) - self.plotLayout.addWidget(figure, i // 2, i % 2, 1, 1) + if i < len(self.axis_names) - 1: + ax = self.figure.add_subplot(self.gs[i // 2, i % 2]) else: - # raster plot at the bottom - # spans over 2 columns and 1/2 row - figure = FigureCanvas(Figure()) # figsize=(10, 5) - self.figures.append(figure) - self.plotLayout.addWidget(figure, last_row, 0, 3, 2) - ax = 
figure.figure.subplots() + ax = self.figure.add_subplot(self.gs[-1, :]) self.axes[axis_name] = ax # Assign the axes to the corresponding variables dynamically for ax_variable_name, ax in self.axes.items(): setattr(self, ax_variable_name, ax) + # Add the figure to the layout + self.canvas = FigureCanvas(self.figure) + self.plotLayout.addWidget(self.canvas, 0, 0) self.canvasLayout.addLayout(self.plotLayout, 0, 0) + # Adjust layout to reduce whitespace + self.figure.tight_layout() + # Alternatively, you can use subplots_adjust for more control: + # self.figure.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, wspace=0.2, hspace=0.4) + self._updateFontSizes() + def createChannelSelection(self): # create the channel selection self.channel_grid = QGridLayout() @@ -307,46 +472,6 @@ def createChannelSelection(self): if self.dataFilename == None: # TODO create textbox print('Show only text.') - - # TODO here we can include a custom layout for the braille data, but have to ensure we can also show splitted data - # elif 'example_braille_data' in self.dataFilename: - # self.dt = 1E-2 - # position_list = [[3, 1], [3, 2], [3, 3], [2, 4, 2, 1], [2, 0, 2, 1], [ - # 2, 3], [2, 2], [2, 1], [1, 1], [1, 2], [1, 3], [0, 2]] - # # TODO should the buttons have the color of the traces? - # for i in range(len(self.channels)): - # checkbox = QPushButton(str(i)) - # checkbox.setCheckable(True) - # checkbox.setChecked(True) - # checkbox.setStyleSheet( - # "background-color : lightgreen;" - # "border-top-left-radius : 25px;" - # "border-top-right-radius : 25px;" - # "border-bottom-left-radius : 25px;" - # "border-bottom-right-radius : 25px" - # ) - # checkbox.clicked.connect( - # lambda value, id=i: self._updateChannelCheckbox(value, id)) - # self.channel_box.append(checkbox) - # # set individual size for the most outer buttons - # if i == 3 or i == 4: - # checkbox.setFixedSize(50, 120) - # self.channel_grid.addWidget( - # checkbox, - # position_list[i][0], - # position_list[i][1], - # position_list[i][2], - # position_list[i][3], - # alignment=Qt.AlignmentFlag.AlignCenter, - # ) - # else: - # checkbox.setFixedSize(50, 60) - # self.channel_grid.addWidget( - # checkbox, - # position_list[i][0], - # position_list[i][1], - # alignment=Qt.AlignmentFlag.AlignCenter, - # ) else: position_list = [] # creating the default layout as a grid (pref. height over width) @@ -380,15 +505,17 @@ def createChannelSelection(self): self.channel_grid.addWidget( checkbox, position_list[i][0], position_list[i][1], alignment=Qt.AlignmentFlag.AlignCenter) - self.parametersLayout.addLayout(self.channel_grid, 2, 0) + self.parametersLayout.addLayout( + self.channel_grid, 2, 0, Qt.AlignmentFlag.AlignTop) - def createDataSelection(self): + def createDataSection(self): """ Used to select the class. 
""" dataSelectionLayout = QGridLayout() title = QLabel("Data management") self.loadButton = QPushButton("Load data") + self.loadButton.setCursor(Qt.CursorShape.PointingHandCursor) self.loadButton.clicked.connect(self.openData) dataSelectionLayout.addWidget(title, 0, 0, 1, 0) @@ -399,15 +526,19 @@ def createDataSelection(self): # create a dial to select the repetition self.selectedRepetition = 0 self.dialRepetition = QDial(self) + self.dialRepetition.setCursor(Qt.CursorShape.OpenHandCursor) self.dialRepetition.setMinimum(0) self.dialRepetition.setMaximum(0) self.dialRepetition.setValue(self.selectedRepetition) - self.dialRepetition.valueChanged.connect(self._updateDialRepetition) + self.dialRepetition.sliderReleased.connect(self._updateDialRepetition) + self.dialRepetition.sliderPressed.connect(self._onDialPressed) + self.dialRepetition.sliderReleased.connect(self._onDialReleased) dataSelectionLayout.addWidget(self.dialRepetition, 2, 0) # TODO only needed if differnet labels given (read from data) # create a combo box with all letters self.comboBoxLetters = QComboBox() + self.comboBoxLetters.setCursor(Qt.CursorShape.PointingHandCursor) self.comboBoxLetters.currentTextChanged.connect( self._updateComboBoxLettersText) dataSelectionLayout.addWidget(self.comboBoxLetters, 2, 1) @@ -417,18 +548,20 @@ def createDataSelection(self): ) self.createChannelSelection() - def createModelSelection(self): + def createModelSection(self): """ Select the neuron model to use. """ modelSelectionLayout = QGridLayout() title = QLabel("Neuron model and parameters") self.combo_box_neuron_model = QComboBox(self) + self.combo_box_neuron_model.setCursor( + Qt.CursorShape.PointingHandCursor) neuron_neuron_model_names = [ "Mihalas-Niebur", "Izhikevich", "Leaky integrate-and-fire", - "Recurrent leaky integrate-and-fire", + "Current-based leaky integrate-and-fire", ] self.combo_box_neuron_model.addItems(neuron_neuron_model_names) self.combo_box_neuron_model.currentTextChanged.connect( @@ -439,7 +572,7 @@ def createModelSelection(self): self.parametersLayout.addLayout( modelSelectionLayout, 3, 0, Qt.AlignmentFlag.AlignBottom) - def createParamSlider(self): + def createParamSliderSection(self): """ Used to select the neuron parameter to change. 
""" @@ -461,7 +594,7 @@ def createParamSlider(self): param_values[-1] * self.factor[id]) # read start value # create a slider for every param - slider = QSlider(Qt.Orientation.Horizontal, self) + slider = CustomSlider(Qt.Orientation.Horizontal, self) slider.setMinimum(int(param_values[0] * self.factor[id])) slider.setMaximum(int(param_values[1] * self.factor[id])) slider.setValue(self.sliderValues[id]) # set start value @@ -470,22 +603,15 @@ def createParamSlider(self): int(abs(np.diff(param_values[:2])[0]) * self.factor[id] / 20) ) # display n steps slider.setSingleStep(self.steps_size[id]) # read step value - slider.valueChanged.connect( - lambda value, id=id: self._updateParamSlider(value, id) + + # Connect the sliderReleased signal to the event handler + slider.sliderReleased.connect( + lambda id=id, slider=slider: self._updateParamSlider( + slider.value(), id) ) + self.sliders.append(slider) self.sliderLayout.addWidget(slider, id + 2, 1) - - # create label for each slider - # # TODO here we assume that the parameter name is always 'x_n' or does not contain '_' at all - # # TODO this is not ideal, but in most cases variables have only a single digit subscript - # # Use regular expression to find the pattern 'x_n' and format it as 'xn' - # import re - # formatted_string = re.sub(r'(\w)_(\w)', r'\1\2', param_key) - # # Creating a QLabel - # sliderLabel = QLabel() - # # Setting HTML content with the formatted string - # sliderLabel.setText(f'
{formatted_string}
') sliderLabel = QLabel(param_key, self) # write parameter name sliderLabel.setAlignment(Qt.AlignmentFlag.AlignLeft) @@ -499,9 +625,9 @@ def createParamSlider(self): self.sliderLayout.addWidget(self.sliderParamLabel[id], id + 2, 2) self.parametersLayout.addLayout( - self.sliderLayout, 4, 0) + self.sliderLayout, 4, 0, Qt.AlignmentFlag.AlignTop) - def createPreprocessingSelection(self): + def createPreprocessingSection(self): """ Creates the preprocessing section. """ @@ -513,6 +639,7 @@ def createPreprocessingSelection(self): # TODO create a 2x2 grid for checkboxes # checkboxes for: normalize, filter, startTrialAtNull, split_data self.normalizeDataCheckbox = QCheckBox("Normalize data") + self.normalizeDataCheckbox.setCursor(Qt.CursorShape.PointingHandCursor) self.normalizeDataCheckbox.setChecked(self.normalizeData) self.normalizeDataCheckbox.stateChanged.connect( self._updateNormalizeData) @@ -520,6 +647,7 @@ def createPreprocessingSelection(self): self.normalizeDataCheckbox, 1, 0, Qt.AlignmentFlag.AlignLeft) self.filterSignalCheckbox = QCheckBox("Filter signal") + self.filterSignalCheckbox.setCursor(Qt.CursorShape.PointingHandCursor) self.filterSignalCheckbox.setChecked(self.filterSignal) self.filterSignalCheckbox.stateChanged.connect( self._updateFilterSignal) @@ -527,6 +655,8 @@ def createPreprocessingSelection(self): self.filterSignalCheckbox, 1, 1, Qt.AlignmentFlag.AlignLeft) self.startTrialAtNullCheckbox = QCheckBox("Start trial at null") + self.startTrialAtNullCheckbox.setCursor( + Qt.CursorShape.PointingHandCursor) self.startTrialAtNullCheckbox.setChecked(self.startTrialAtNull) self.startTrialAtNullCheckbox.stateChanged.connect( self._updateStartTrialAtNull) @@ -534,6 +664,7 @@ def createPreprocessingSelection(self): self.startTrialAtNullCheckbox, 2, 0, Qt.AlignmentFlag.AlignCenter) self.splitDataCheckbox = QCheckBox("Split data") + self.splitDataCheckbox.setCursor(Qt.CursorShape.PointingHandCursor) self.splitDataCheckbox.setChecked(self.enable_data_splitting) self.splitDataCheckbox.stateChanged.connect( self._updateSplitData) @@ -546,6 +677,7 @@ def createPreprocessingSelection(self): dt_label_text.setAlignment(Qt.AlignmentFlag.AlignLeft) self.preprocessingLayout.addWidget(dt_label_text, 3, 0) self.dt_slider = QSlider(Qt.Orientation.Horizontal, self) + self.dt_slider.setCursor(Qt.CursorShape.PointingHandCursor) self.dt_slider.setMinimum(1) self.dt_slider.setMaximum(10) self.dt_slider.setValue(1) @@ -563,6 +695,7 @@ def createPreprocessingSelection(self): scale_label_text.setAlignment(Qt.AlignmentFlag.AlignLeft) self.preprocessingLayout.addWidget(scale_label_text, 4, 0) self.scale_slider = QSlider(Qt.Orientation.Horizontal, self) + self.scale_slider.setCursor(Qt.CursorShape.PointingHandCursor) self.scale_slider.setMinimum(1) self.scale_slider.setMaximum(10) self.scale_slider.setValue(1) @@ -578,11 +711,116 @@ def createPreprocessingSelection(self): # add to overall layout self.parametersLayout.addLayout( - self.preprocessingLayout, 1, 0, Qt.AlignmentFlag.AlignTop - ) + self.preprocessingLayout, 1, 0, Qt.AlignmentFlag.AlignTop) + + def createSpikePatternClassifierSection(self): + # here we need to have two checkboxes, one to activate the calssification and the second to select using super or sub labels + self.spikePatternClassifierLayout = QGridLayout() + title = QLabel("Spike-Pattern Classifier") + self.spikePatternClassifierLayout.addWidget(title, 0, 0) + + self.spikePatternClassifierCheckbox = QCheckBox( + "Pattern classification") + self.spikePatternClassifierCheckbox.setCursor( 
+            Qt.CursorShape.PointingHandCursor)
+        self.spikePatternClassifierCheckbox.setChecked(False)
+        self.spikePatternClassifierCheckbox.stateChanged.connect(
+            self._updateCalculateSpikePatternClassification)
+        self.spikePatternClassifierLayout.addWidget(
+            self.spikePatternClassifierCheckbox, 1, 0, Qt.AlignmentFlag.AlignTop)
+
+        self.superSubLabelCheckbox = QCheckBox("Neuronal behaviours")
+        self.superSubLabelCheckbox.setCursor(Qt.CursorShape.PointingHandCursor)
+        self.superSubLabelCheckbox.setChecked(False)
+        self.superSubLabelCheckbox.stateChanged.connect(
+            self._updateShowSpikePatternSubClasses)
+        self.spikePatternClassifierLayout.addWidget(
+            self.superSubLabelCheckbox, 1, 1, Qt.AlignmentFlag.AlignTop)
+
+        self.parametersLayout.addLayout(
+            self.spikePatternClassifierLayout, 5, 0, Qt.AlignmentFlag.AlignLeft)
+
+    def createSpikePatternVisualizer(self):
+        """Create the spike-pattern table, or a centered hint message, in the second tab."""
+        if self.calcSpikePatternClassification:
+            if self.showSubClasses:
+                # show all 20 classes
+                patternLabels = ["ID",
+                                 "Major",
+                                 "Tonic spiking",  # A
+                                 "Class 1",  # B
+                                 "Spike frequency\nadaptation",  # C
+                                 "Phasic spiking",  # D
+                                 "Accommodation",  # E
+                                 "Threshold\nvariability",  # F
+                                 "Rebound spike",  # G
+                                 "Class 2",  # H
+                                 "Integrator",  # I
+                                 "Input\nbistability",  # J
+                                 "Hyperpolarizing\nspiking",  # K
+                                 "Hyperpolarizing\nbursting",  # L
+                                 "Tonic bursting",  # M
+                                 "Phasic bursting",  # N
+                                 "Rebound burst",  # O
+                                 "Mixed mode",  # P
+                                 "Afterpotentials",  # Q
+                                 "Basal\nbistability",  # R
+                                 "Preferred\nfrequency",  # S
+                                 "Spike latency"  # T
+                                 ]
+            else:
+                # show only major classes
+                """
+                Regular: A, B, K, Q
+                Single burst: N, O
+                Multi-burst: L, M, R, S
+                Mixed: C, D, E, H, J, P
+                Unstructured: F, G, I, T
+                """
+
+                patternLabels = ["ID",
+                                 "Major",
+                                 "Regular",
+                                 "Single burst",
+                                 "Multi-burst",
+                                 "Mixed",
+                                 "Unstructured"
+                                 ]
+
+            self.spikePatternTable = QTableWidget()
+            self.spikePatternTable.setRowCount(1)
+            # One column per header label: ID, Major, and one per pattern class
+            self.spikePatternTable.setColumnCount(len(patternLabels))
+            self.spikePatternTable.setHorizontalHeaderLabels(patternLabels)
+
+            index_major = patternLabels.index("Major")
+            header_major = self.spikePatternTable.horizontalHeaderItem(index_major)
+            # Set the font to bold
+            font_major = QFont()
+            font_major.setBold(True)
+            header_major.setFont(font_major)
+
+            index_italic = [num for num, el in enumerate(patternLabels) if el not in ["ID", "Major"]]
+            for idx in index_italic:
+                header_italic = self.spikePatternTable.horizontalHeaderItem(idx)
+                # Set the font to italic
+                font_italic = QFont()
+                font_italic.setItalic(True)
+                header_italic.setFont(font_italic)
+
+            self.spikePatternLayout.addWidget(self.spikePatternTable, 0, 0)
+        else:
+            # Create a QLabel for the message
+            self.spikePatternTable = QLabel(
+                "To calculate the spike-pattern classification, please activate the 'Pattern classification' checkbox in the parameter section.\nFor the detailed view of all the neuronal behaviours, activate the 'Neuronal behaviours' checkbox.")
+            self.spikePatternTable.setAlignment(Qt.AlignmentFlag.AlignCenter)
+
+            # Add the message label to the layout and center it
+            self.spikePatternLayout.addWidget(
+                self.spikePatternTable, 0, 0, 1, 1, Qt.AlignmentFlag.AlignCenter)
 
     @QtCore.pyqtSlot()
-    def draw(self):
+    def drawCanvas(self):
         """
         Update the plots.
""" @@ -592,9 +830,13 @@ def draw(self): output = self.output_data axes = self.axes axis_names = self.axis_names - figures = self.figures dt = self.dt + # Get height of the GUI window + height = self.generalLayout.height() + tick_font_size = height * 0.01 + title_font_size = height * 0.015 + # some color schemes can be found here # (https://matplotlib.org/stable/tutorials/colors/colormaps.html) @@ -604,7 +846,7 @@ def draw(self): # TODO check if this is wanted (fast plotting vs accuracy) idx = np.arange(0, input.shape[0], upsample_fac) - time = idx*dt/upsample_fac + time = idx * dt / upsample_fac # plot the input data axes[axis_names[0]].clear() @@ -617,7 +859,7 @@ def draw(self): if ymin != ymax: axes[axis_names[0]].set_ylim(ymin, ymax) axes[axis_names[0]].set_title("Input") - figures[0].draw() + self.canvas.draw() axes[axis_names[1]].clear() @@ -628,14 +870,14 @@ def draw(self): self.neuronStateVariables = IZ_neuron.NeuronState._fields elif self.neuron_model_name == "Leaky integrate-and-fire": self.neuronStateVariables = LIF_neuron.NeuronState._fields - elif self.neuron_model_name == "Recurrent leaky integrate-and-fire": - self.neuronStateVariables = RLIF_neuron.NeuronState._fields + elif self.neuron_model_name == "Current-based leaky integrate-and-fire": + self.neuronStateVariables = CuBaLIF_neuron.NeuronState._fields else: ValueError("No neuron model selected.") variable_info = [] for i, name in enumerate(self.neuronStateVariables): - variable_info.append({"index": str(i+1), "title": name}) + variable_info.append({"index": str(i + 1), "title": name}) variables = [] for i, single_variable_info in enumerate(variable_info): @@ -658,8 +900,6 @@ def draw(self): if i < len(variables) - 1: ax.set_prop_cycle("color", colors[::2][channels]) ax.plot(time, output[idx, index, 0, :][:, channels]) - if index in [3, 4]: - ax.plot(time, output[idx, index, 0, :][:, channels]) ymin = np.min(output[idx, index, 0, :][:, channels]) - \ 0.1 * abs(np.min(output[idx, index, 0, :][:, channels])) @@ -667,20 +907,238 @@ def draw(self): 0.1 * abs(np.max(output[idx, index, 0, :][:, channels])) if ymin != ymax: ax.set_ylim(ymin, ymax) - ax.set_title(title) else: # create raster plot t, neuron_idx = np.where(output[:, 0, 0, :]) i = np.where(np.in1d(neuron_idx, np.where(channels)[0])) - t = t[i]*(dt/upsample_fac) + t = t[i] * (dt / upsample_fac) neuron_idx = neuron_idx[i] # remove color argument to get monochrom ax.scatter(x=t, y=neuron_idx, c=colors[::2][neuron_idx], s=5) ax.set_ylim(-1, output.shape[-1]) - ax.set_xlim(0, output.shape[0] / (1/self.dt) / upsample_fac) + ax.set_xlim(0, output.shape[0] / (1 / self.dt) / upsample_fac) - for var in variables: - var["ax"].figure.canvas.draw() + ax.set_title(title, fontsize=title_font_size) + ax.tick_params(axis='both', which='major', + labelsize=tick_font_size) + self.canvas.draw() + + @QtCore.pyqtSlot() + def writeTable(self): + """Update the spike pattern visualizer.""" + if self.output_data is not None and self.calcSpikePatternClassification: + ### REMINDER: + # self.classification_calc.probs has shape (n_channels,1,n_behaviours) + # it contains the n_behaviours probabilities for each channel + # self.finalPredictionList is a list with n_channels elements + # it contains the individual prediction (based on all the classes) for each channel + if self.showSubClasses: + self.patternLabels = [ + "Tonic spiking", # A + "Class 1", # B + "Spike frequency\nadaptation", # C + "Phasic spiking", # D + "Accommodation", # E + "Threshold\nvariability", # F + "Rebound spike", # G + 
"Class 2", # H + "Integrator", # I + "Input\nbistability", # J + "Hyperpolarizing\nspiking", # K + "Hyperpolarizing\nbursting", # L + "Tonic bursting", # M + "Phasic bursting", # N + "Rebound burst", # O + "Mixed mode", # P + "Afterpotentials", # Q + "Basal\nbistability", # R + "Preferred\nfrequency", # S + "Spike latency" # T + ] + else: + self.patternLabelsSubClasses = [ + "Tonic spiking", # A + "Class 1", # B + "Spike frequency\nadaptation", # C + "Phasic spiking", # D + "Accommodation", # E + "Threshold\nvariability", # F + "Rebound spike", # G + "Class 2", # H + "Integrator", # I + "Input\nbistability", # J + "Hyperpolarizing\nspiking", # K + "Hyperpolarizing\nbursting", # L + "Tonic bursting", # M + "Phasic bursting", # N + "Rebound burst", # O + "Mixed mode", # P + "Afterpotentials", # Q + "Basal\nbistability", # R + "Preferred\nfrequency", # S + "Spike latency" # T + ] + # mapping from all 20 to major + self.patternLabels = { + "Regular": ["Tonic spiking", "Class 1", "Hyperpolarizing\nspiking", "Afterpotentials"], + "Single burst": ["Phasic bursting", "Rebound burst"], + "Multi-burst": ["Hyperpolarizing\nbursting", "Tonic bursting", "Basal\nbistability", "Preferred\nfrequency"], + "Mixed": ["Spike frequency\nadaptation", "Phasic spiking", "Accommodation", "Class 2", "Input\nbistability", "Mixed mode"], + "Unstructured": ["Threshold\nvariability", "Rebound spike", "Integrator", "Spike latency"] + } + + # Function to evaluate classification on super-classes + def superclass_probabilities(self): + + superclass_probs = [] + classification = [] + + for ch in self.classification_calc.probs: + probs = ch[0].copy() + superclass_softmax_sum = np.zeros(len(self.patternLabels)) + for num,superclass in enumerate(list(self.patternLabels.keys())): + for subclass in self.patternLabels[superclass]: + idx = self.patternLabelsSubClasses.index(subclass) + superclass_softmax_sum[num] += probs[idx] + superclass_probs.append(superclass_softmax_sum) + if np.sum(probs) == 0: + classification.append("No spikes") + else: + classification.append(np.argmax(superclass_softmax_sum)) + + return classification, superclass_probs + + # Clear the table + self.spikePatternTable.setRowCount(self.output_data.shape[-1]) + # Add new rows + for sensorID in range(self.output_data.shape[-1]): + # ID + self.spikePatternTable.setItem( + sensorID, 0, QTableWidgetItem(str(sensorID))) # ID + + if self.showSubClasses: + # Final prediction + if (len(np.where(np.array(self.classification_calc.probs[sensorID])==np.max(self.classification_calc.probs[sensorID]))[0]) > 1) & (self.finalPredictionList[sensorID] != 'No spikes'): + item = QTableWidgetItem("Class overlap") # Class ambiguity + font = QFont() + font.setItalic(True) + item.setFont(font) + self.spikePatternTable.setItem(sensorID, 1, item) # overwrite prediction if multiple classes with equal probabilities + + else: + self.spikePatternTable.setItem(sensorID, 1, QTableWidgetItem( + self.finalPredictionList[sensorID])) # predicted spike pattern + + for pattern_label_counter in range(len(self.patternLabels)): + if self.finalPredictionList[sensorID] == 'No spikes': + item = QTableWidgetItem("") + + # color the cell white + item.setBackground(QColor(255, 255, 255)) + self.spikePatternTable.setItem( + sensorID, pattern_label_counter + 2, item) + else: + probability = self.normalized_softmax[sensorID, + pattern_label_counter] + percentage = np.round(probability*100,1) + item = QTableWidgetItem(str(percentage) + " %") + font = QFont() + font.setItalic(True) + item.setFont(font) + 
item.setTextAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter)
+
+                            # Calculate color based on probability
+                            red = int(probability * 255)
+                            blue = int((1 - probability) * 255)
+                            green = 5
+                            color = QColor(red, green, blue)
+
+                            # Adjust color lightness based on distance from 0.5
+                            distance_from_mid = abs(probability - 0.5)
+                            lightness_factor = EXTREME_LIGHTNESS + \
+                                int((1 - distance_from_mid * 2) *
+                                    (MIDPOINT_LIGHTNESS - EXTREME_LIGHTNESS))
+                            adjusted_color = color.lighter(lightness_factor)
+
+                            item.setBackground(adjusted_color)
+
+                            # probability of each pattern
+                            self.spikePatternTable.setItem(
+                                sensorID, pattern_label_counter + 2, item)
+
+                else:
+                    # Final prediction
+                    self.finalPredictionList_superclass, probabilities = superclass_probabilities(self)
+                    if (len(np.where(np.array(probabilities[sensorID])==np.max(probabilities[sensorID]))[0]) > 1) & (self.finalPredictionList_superclass[sensorID] != 'No spikes'):
+                        item = QTableWidgetItem("Class overlap")  # Class ambiguity
+                        font = QFont()
+                        font.setItalic(True)
+                        item.setFont(font)
+                        self.spikePatternTable.setItem(sensorID, 1, item)  # overwrite prediction if multiple classes with equal probabilities
+                    else:
+                        item = QTableWidgetItem(
+                            list(self.patternLabels.keys())[self.finalPredictionList_superclass[sensorID]] if self.finalPredictionList_superclass[sensorID] != "No spikes" else self.finalPredictionList_superclass[sensorID])
+                        self.spikePatternTable.setItem(sensorID, 1, item)  # predicted spike pattern
+
+                    for pattern_label_counter, (key, _) in enumerate(self.patternLabels.items()):
+                        if self.finalPredictionList_superclass[sensorID] == 'No spikes':
+                            item = QTableWidgetItem("")
+                            # color the cell white
+                            item.setBackground(QColor(255, 255, 255))
+                            self.spikePatternTable.setItem(
+                                sensorID, pattern_label_counter + 2, item)
+                        else:
+                            probability = probabilities[sensorID][pattern_label_counter]
+                            percentage = np.round(probability*100,1)
+                            item = QTableWidgetItem(str(percentage) + " %")
+                            font = QFont()
+                            font.setItalic(True)
+                            item.setFont(font)
+                            item.setTextAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter)
+
+                            # Calculate color based on probability
+                            red = int(probability * 255)
+                            blue = int((1 - probability) * 255)
+                            green = 5
+                            color = QColor(red, green, blue)
+
+                            # Adjust color lightness based on distance from 0.5
+                            distance_from_mid = abs(probability - 0.5)
+                            lightness_factor = EXTREME_LIGHTNESS + \
+                                int((1 - distance_from_mid * 2) *
+                                    (MIDPOINT_LIGHTNESS - EXTREME_LIGHTNESS))
+                            adjusted_color = color.lighter(lightness_factor)
+
+                            item.setBackground(adjusted_color)
+
+                            # probability of each pattern
+                            self.spikePatternTable.setItem(
+                                sensorID, pattern_label_counter + 2, item)
+
+
+    def resizeEvent(self, event):
+        self._updateFontSizes()
+        super().resizeEvent(event)  # Call the base class implementation
+
+    def _updateCalculateSpikePatternClassification(self):
+        # we only need to calculate the classification if the checkbox is ticked
+        self.calcSpikePatternClassification = self.sender().isChecked()
+        self._resetLayout(None, self.spikePatternLayout)
+        self.createSpikePatternVisualizer()
+        if self.calcSpikePatternClassification:
+            self.classify_event.emit()
+
+    def _updateShowSpikePatternSubClasses(self):
+        # when ticked we have to change the table
+        self.showSubClasses = self.sender().isChecked()
+        self._resetLayout(None, self.spikePatternLayout)
+        self.createSpikePatternVisualizer()
+        self.write_event.emit()
+
+    def _updateSpikePattern(self, predictions, softmax):
+        self.normalized_softmax = 
softmax + self.finalPredictionList = predictions + self.write_event.emit() def _updateCanvas(self, input_data, output_data): self.input_data = input_data @@ -717,6 +1175,19 @@ def _updateComboBoxLettersText(self, s): self.active_class = self.le.transform([s])[0] self.simulate_event.emit() + def _updateFontSizes(self): + # Get height of the GUI window + height = self.generalLayout.height() + tick_font_size = height * 0.01 + title_font_size = height * 0.015 + + # Update the font sizes of your plot ticks and labels + for ax in self.figure.get_axes(): + ax.tick_params(axis='both', which='major', + labelsize=tick_font_size) + ax.title.set_size(title_font_size) + self.canvas.draw_idle() # Redraw the canvas to apply changes + def _resetLayout(self, layout, sublayout): """Remove all the sliders from the interface.""" def deleteItems(layout): @@ -733,8 +1204,6 @@ def deleteItems(layout): if layout != None: layout.removeItem(sublayout) - # END DATA VISUALIZATION - ################# # MODEL HANDLER # ################# @@ -743,14 +1212,16 @@ def changeModel(self, neuron_model_name): # Create widgets for changing neuron model self.neuron_model_name = neuron_model_name self.loadParameter() # load parameters from file - # set parameters to initial values + # reset the parameter layout self._resetLayout(self.parametersLayout, self.sliderLayout) - self.createParamSlider() # create widgets for parameters - # set parameters to initial values + self.createParamSliderSection() + # set canvas according to neuron model self._resetLayout(None, self.canvasLayout) - self.createCanvas() # create widgets for parameters + self.createCanvas() # Emit a signal to update the GUI self.simulate_event.emit() + if self.calcSpikePatternClassification: + self.classify_event.emit() logging.info(f"Neuron model changed to {neuron_model_name} neuron.") def loadParameter(self): @@ -764,14 +1235,12 @@ def loadParameter(self): elif self.neuron_model_name == "Leaky integrate-and-fire": from utils.neuron_parameters import lif_parameter self.parameter = lif_parameter - elif self.neuron_model_name == "Recurrent leaky integrate-and-fire": + elif self.neuron_model_name == "Current-based leaky integrate-and-fire": from utils.neuron_parameters import rlif_parameter self.parameter = rlif_parameter else: raise ValueError("Select a valid neuron model.") - # END MODEL HANDLER - ################## # DATAMANAGEMENT # ################## @@ -780,57 +1249,63 @@ def openData(self): """Load data and set up the data management interface.""" self.dataFilename = QFileDialog.getOpenFileName( self, "Open file", "./", "Pickle file (*.pkl)")[0] - self.data_dict = pd.read_pickle(self.dataFilename) - # check windows vs unix # TODO check if this is actually needed - if len(self.dataFilename.split('/')) > len(self.dataFilename.split('\\')): - newFilename = self.dataFilename.split('/') + if self.dataFilename == "": + return + # load the data else: - newFilename = self.dataFilename.split('\\') - self.loadButton.setText(newFilename[-1]) + self.data_dict = pd.read_pickle(self.dataFilename) + # check windows vs unix # TODO check if this is actually needed + if len(self.dataFilename.split('/')) > len(self.dataFilename.split('\\')): + newFilename = self.dataFilename.split('/') + else: + newFilename = self.dataFilename.split('\\') + self.loadButton.setText(newFilename[-1]) - self._loadData() - self.createChannelSelection() + self._loadData() + self.createChannelSelection() - # Set the class selection - self.comboBoxLetters.clear() - if 'letter' in 
list(self.data_dict.keys()): - logging.info('Found special naming. Gonna use it') - logging.info('Setting up box for class selection.') - self.comboBoxLetters.addItems( - list(np.unique(self.data_dict["letter"]))) - # set first class as default - self.active_class = self.le.transform( - [list(np.unique(self.data_dict["letter"]))[0]])[0] - else: - if 'class' in list(self.data_dict.keys()): - logging.info('Found standart naming. Gonna use it.') + # Set the class selection + self.comboBoxLetters.clear() + if 'letter' in list(self.data_dict.keys()): + logging.info('Found special naming. Gonna use it') logging.info('Setting up box for class selection.') self.comboBoxLetters.addItems( - list(np.unique(self.data_dict["class"]))) + list(np.unique(self.data_dict["letter"]))) # set first class as default self.active_class = self.le.transform( - [list(np.unique(self.data_dict["class"]))[0]])[0] + [list(np.unique(self.data_dict["letter"]))[0]])[0] else: - logging.warning('No classes found. (Remove box?)') - - # only create wheel if multiple repetitions are given - if 'repetition' in list(self.data_dict.keys()) and len(np.unique(self.data_dict["repetition"])) > 1: - logging.info('Setting up dial to select the repetition.') - # Modify the repetition selection - self.selectedRepetition = int( - random() * len(np.unique(self.data_dict["repetition"])) - ) - self.dialRepetition.setMaximum( - len(np.unique(self.data_dict["repetition"])) - 1) - else: - # remove dial - logging.warning('Only single trial per class. (Removing dial?)') - # Remove the dial widget from the layout - self.dialRepetition.setParent(None) - self.dialRepetition.deleteLater() + if 'class' in list(self.data_dict.keys()): + logging.info('Found standart naming. Gonna use it.') + logging.info('Setting up box for class selection.') + self.comboBoxLetters.addItems( + list(np.unique(self.data_dict["class"]))) + # set first class as default + self.active_class = self.le.transform( + [list(np.unique(self.data_dict["class"]))[0]])[0] + else: + logging.warning('No classes found. (Remove box?)') + + # only create wheel if multiple repetitions are given + if 'repetition' in list(self.data_dict.keys()) and len(np.unique(self.data_dict["repetition"])) > 1: + logging.info('Setting up dial to select the repetition.') + # Modify the repetition selection + self.selectedRepetition = int( + random() * len(np.unique(self.data_dict["repetition"])) + ) + self.dialRepetition.setMaximum( + len(np.unique(self.data_dict["repetition"])) - 1) + else: + # remove dial + logging.warning( + 'Only single trial per class. 
(Removing dial?)')
+                    # Remove the dial widget from the layout
+                    self.dialRepetition.setParent(None)
+                    self.dialRepetition.deleteLater()
 
     def _loadData(self):
         """Load the data from the file."""
+        # TODO make sure user can close the load window without triggering any calculation
        self.data_split, self.labels, self.timestamps, self.le, self.data = load_data(
            self.dataFilename,
            upsample_fac=self.upsample_fac,
@@ -840,18 +1315,28 @@ def _loadData(self):
         )
         if 'example_braille_data' in self.dataFilename:
-            self.dt = 1E-2
+            self.dt = 1E-2  # sampling period in seconds
 
         self.channels = np.ones(self.data.shape[-1], dtype=bool)
         self.data_default = self.data.numpy()
         self.timestamps_default = self.timestamps.copy()
 
         self.simulate_event.emit()
+        if self.calcSpikePatternClassification:
+            self.classify_event.emit()
 
-    def _updateDialRepetition(self, value):
+    def _updateDialRepetition(self):
         """Update the repetition according to the dial."""
-        self.selectedRepetition = value
+        self.selectedRepetition = self.dialRepetition.value()
+        if self.output_data is not None:
+            self.simulate_event.emit()
+            if self.calcSpikePatternClassification:
+                self.classify_event.emit()
 
-        self.simulate_event.emit()
+    def _onDialPressed(self):
+        self.dialRepetition.setCursor(Qt.CursorShape.ClosedHandCursor)
+
+    def _onDialReleased(self):
+        self.dialRepetition.setCursor(Qt.CursorShape.OpenHandCursor)
 
     def _updateDt(self):
         """Recalculate the input data and neuron output with new dt."""
@@ -860,23 +1345,30 @@
         value = self.sender().value()
         self.dt_label.setText(str(value))
         self.upsample_fac = value
-        data_steps = len(self.data_default[0])
-        self.data_steps = data_steps * self.upsample_fac
-        # here we change the number of computed time steps according to the upsample factor
-        self._updateData()
-        self.simulate_event.emit()
+        if self.output_data is not None:
+            # here we change the number of computed time steps according to the upsample factor
+            self._updateData()
+            self.simulate_event.emit()
+            if self.calcSpikePatternClassification:
+                self.classify_event.emit()
 
     def _updateFilterSignal(self):
         """Update data according to filter signal checkbox."""
         self.filterSignal = self.sender().isChecked()
-        self._updateData()
-        self.simulate_event.emit()
+        if self.output_data is not None:
+            self._updateData()
+            self.simulate_event.emit()
+            if self.calcSpikePatternClassification:
+                self.classify_event.emit()
 
     def _updateNormalizeData(self):
         """Update data according to normalize data checkbox."""
         self.normalizeData = self.sender().isChecked()
-        self._updateData()
-        self.simulate_event.emit()
+        if self.output_data is not None:
+            self._updateData()
+            self.simulate_event.emit()
+            if self.calcSpikePatternClassification:
+                self.classify_event.emit()
 
     def _updateParamSlider(self, value, id):
         """Update the parameter according to the slider."""
@@ -889,15 +1381,21 @@
         self.sliderValues[id] = value
         self.sliderParamLabel[id].setText(
             str(value / int(self.factor[id])))
-        self.simulate_event.emit()
+        if self.output_data is not None:
+            self.simulate_event.emit()
+            if self.calcSpikePatternClassification:
+                self.classify_event.emit()
 
     def _updateScale(self):
         """Update the data scaling."""
         value = self.sender().value()
         self.scale_label.setText(str(value))
         self.scale = value
-        self._updateData()
-        self.simulate_event.emit()
+        if self.output_data is not None:
+            self._updateData()
+            self.simulate_event.emit()
+            if self.calcSpikePatternClassification:
+                self.classify_event.emit()
 
     def _updateSplitData(self):
         """
@@ -905,14 +1403,17 @@
         If no negative values given, channel will contain zeros only.
         """
         self.enable_data_splitting = self.sender().isChecked()
-        self._updateData()
-        if self.enable_data_splitting:
-            self.channels = np.ones(self.data_split.shape[-1], dtype=bool)
-        else:
-            self.channels = np.ones(self.data.shape[-1], dtype=bool)
-        self._resetLayout(self.parametersLayout, self.channel_grid)
-        self.createChannelSelection()
-        self.simulate_event.emit()
+        if self.output_data is not None:
+            self._updateData()
+            if self.enable_data_splitting:
+                self.channels = np.ones(self.data_split.shape[-1], dtype=bool)
+            else:
+                self.channels = np.ones(self.data.shape[-1], dtype=bool)
+            self._resetLayout(self.parametersLayout, self.channel_grid)
+            self.createChannelSelection()
+            self.simulate_event.emit()
+            if self.calcSpikePatternClassification:
+                self.classify_event.emit()
 
     def _updateData(self):
         timestamps, data = preprocess_data(
@@ -935,10 +1436,11 @@
         Will redraw the plots.
         """
         self.startTrialAtNull = self.sender().isChecked()
-        self._updateData()
-        self.simulate_event.emit()
-
-    # END DATAMANAGEMENT
+        if self.output_data is not None:
+            self._updateData()
+            self.simulate_event.emit()
+            if self.calcSpikePatternClassification:
+                self.classify_event.emit()
 
     ###################
     # SPIKES TO AUDIO #
     ###################
@@ -1000,16 +1502,41 @@
         '''
         Triggers the calculation of the spike to audio conversion.
         '''
-        neuronSpikeTimesSparse = np.reshape(self.output_data[:, 0, :, :], (
+        neuronSpikeTimesDense = np.reshape(self.output_data[:, 0, :, :], (
            self.output_data.shape[0], self.output_data.shape[-1]))  # TODO call spikes by key?
 
         # convert sparse representation to spike times
-        neuronSpikeTimes = np.where(neuronSpikeTimesSparse == 1)[0] * self.dt
+        neuronSpikeTimes = np.where(neuronSpikeTimesDense == 1)[0] * self.dt
 
         audio_duration = len(self.input_data) * self.dt
 
         audio = self.spikeToAudio(
-            out_path='./tmp', neuron_spike_times=neuronSpikeTimes, audio_duration=audio_duration)
+            out_path=self.tmp_folder, neuron_spike_times=neuronSpikeTimes, audio_duration=audio_duration)
+
+    ###################
+    # CLEAN CLOSE APP #
+    ###################
+
+    def closeEvent(self, event):
+        # Stop the main threads
+        self._stopThreads()
+
+        # Remove the temporary folder
+        for folder in os.listdir(self.tmp_dir):
+            if folder.startswith("tmp"):
+                # NOTE: onerror is deprecated as of python 3.12, to be replaced by onexc
+                shutil.rmtree(folder, onerror=self._removeReadonly)
+
+        event.accept()  # Accept the close event
+
+    def _removeReadonly(self, func, path, _):
+        """Clear the read-only bit and reattempt the removal."""
+        os.chmod(path, stat.S_IWRITE)
+        func(path)
 
-    # END SPIKE TO AUDIO
+    def _stopThreads(self):
+        self.enc_thread.quit()
+        self.enc_thread.wait()
+        self.class_thread.quit()
+        self.class_thread.wait()
 
 
 def main():
@@ -1018,7 +1545,7 @@
     winGUIwindow = WiN_GUI_Window()
     winGUIwindow.show()
 
-    sys.exit(WiN_GUI.exec())  # TODO inlcude removing tmp folder
+    sys.exit(WiN_GUI.exec())
 
 
 if __name__ == "__main__":
diff --git a/requirements.txt b/requirements.txt
index 61e8f61..a096cb1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,5 @@ pyside2
 pyqt6
 scikit-learn
 scipy
-torch
\ No newline at end of file
+torch
+tqdm
\ No newline at end of file
diff --git a/utils/neuron_models.py b/utils/neuron_models.py
index bd0b021..5f20055 100644
--- a/utils/neuron_models.py
+++ b/utils/neuron_models.py
@@ -1,49 +1,69 @@ +""" +neuron_models.py + +This module implements various neuron models for use in spiking neural network (SNN) simulations. +The primary focus is on biologically plausible neuron models, such as the Mihalas-Niebur (MN) neuron model. + +Main Components: +- MN_neuron: A class implementing the Mihalas-Niebur neuron model. This model includes parameters for + synaptic weights, membrane potential, and other neuron-specific properties. +- IZ_neuron: A class implementing the Izhikevich neuron model. This model simulates spiking and bursting + behavior of neurons. +- LIF_neuron: A class implementing the Leaky Integrate-and-Fire neuron model. This model is a simple + representation of neuronal activity. +- CuBaLIF_neuron: A class implementing the Current-Based Leaky Integrate-and-Fire neuron model. This model + includes a synaptic current to simulate more complex neuronal dynamics. + +Classes and Functions: +- MN_neuron: A PyTorch module representing the Mihalas-Niebur neuron model. + - NeuronState: A named tuple to store the state variables of the neuron (V, i1, i2, Thr, spk). + - __init__(self, nb_inputs, parameters_combination, dt=1/1000, ...): Initializes the neuron with the given parameters. + - forward(self, input): Defines the forward pass of the neuron model. +- IZ_neuron: A PyTorch module representing the Izhikevich neuron model. + - __init__(self, nb_inputs, parameters_combination, dt=1/1000, ...): Initializes the neuron with the given parameters. + - forward(self, input): Defines the forward pass of the neuron model. +- LIF_neuron: A PyTorch module representing the Leaky Integrate-and-Fire neuron model. + - __init__(self, nb_inputs, parameters_combination, dt=1/1000, ...): Initializes the neuron with the given parameters. + - forward(self, input): Defines the forward pass of the neuron model. +- CuBaLIF_neuron: A PyTorch module representing the Current-Based Leaky Integrate-and-Fire neuron model. + - __init__(self, nb_inputs, parameters_combination, dt=1/1000, ...): Initializes the neuron with the given parameters. + - forward(self, input): Defines the forward pass of the neuron model. + +Dependencies: +- torch: PyTorch library for tensor computations and neural network operations. +- surrogate_gradient: Custom autograd function for spiking nonlinearity with a surrogate gradient. + +Usage: +This module is intended to be used as part of the WiN-GUI project for simulating spiking neural networks. +To use this module, ensure that the required dependencies are installed. + +Example: + import torch + from neuron_models import MN_neuron, IZ_neuron, LIF_neuron, CuBaLIF_neuron + + # Define neuron parameters + nb_inputs = 10 + parameters_combination = {...} + + # Initialize the MN neuron + mn_neuron = MN_neuron(nb_inputs, parameters_combination) + + # Define input tensor + input_tensor = torch.randn(nb_inputs) + + # Run the forward pass + output = mn_neuron(input_tensor) + +License: +This project is licensed under the GPL-3.0 License. See the LICENSE file for more details. + +""" + from collections import namedtuple -import numpy as np import torch import torch.nn as nn - - -class SurrGradSpike(torch.autograd.Function): - """ - Here we implement our spiking nonlinearity which also implements - the surrogate gradient. By subclassing torch.autograd.Function, - we will be able to use all of PyTorch's autograd functionality. - Here we use the normalized negative part of a fast sigmoid - as this was done in Zenke & Ganguli (2018). 
- """ - - scale = 100 - - @staticmethod - def forward(ctx, input): - """ - In the forward pass we compute a step function of the input Tensor - and return it. ctx is a context object that we use to stash information which - we need to later backpropagate our error signals. To achieve this we use the - ctx.save_for_backward method. - """ - ctx.save_for_backward(input) - out = torch.zeros_like(input) - out[input > 0] = 1.0 - return out - - @staticmethod - def backward(ctx, grad_output): - """ - In the backward pass we receive a Tensor we need to compute the - surrogate gradient of the loss with respect to the input. - Here we use the normalized negative part of a fast sigmoid - as this was done in Zenke & Ganguli (2018). - """ - (input,) = ctx.saved_tensors - grad_input = grad_output.clone() - grad = grad_input / (SurrGradSpike.scale * torch.abs(input) + 1.0) ** 2 - return grad - - -activation = SurrGradSpike.apply +from utils.surrogate_gradient import activation # MN neuron @@ -197,7 +217,7 @@ def forward(self, x): u = self.state.u numerical_res = round(self.dt) - if self.dt>1: + if self.dt > 1: output_spike = torch.zeros_like(self.state.spk) for i in range(numerical_res): V = V + (((0.04 * V + 5) * V) + 140 - u + x) @@ -273,7 +293,8 @@ def forward(self, x): V = self.state.V spk = self.state.spk - V = (self.beta * V + (1.0-self.beta) * x * self.R) * (1.0 - spk) + # V = (self.beta * V + (1.0-self.beta) * x * self.R) * (1.0 - spk) + V = (self.beta * V + x * self.R) * (1.0 - spk) # reset mechanism: zero spk = activation(V-self.threshold) self.state = self.NeuronState(V=V, spk=spk) @@ -284,7 +305,7 @@ def reset(self): self.state = None -class RLIF_neuron(nn.Module): +class CuBaLIF_neuron(nn.Module): NeuronState = namedtuple("NeuronState", ["V", "syn", "spk"]) def __init__( @@ -297,7 +318,7 @@ def __init__( thr=1.0, R=1.0, ): - super(RLIF_neuron, self).__init__() + super(CuBaLIF_neuron, self).__init__() self.nb_inputs = nb_inputs self.alpha = alpha @@ -330,9 +351,11 @@ def forward(self, x): spk = self.state.spk syn = self.state.syn - syn = self.alpha*syn + spk - V = (self.beta * V + (1.0-self.beta) * x * - self.R + (1.0-self.beta)*syn) * (1.0 - spk) + # syn = self.alpha*syn + spk + # V = (self.beta * V + (1.0-self.beta) * x * + # self.R + (1.0-self.beta)*syn) * (1.0 - spk) + syn = self.alpha*syn + x*self.R + V = (self.beta * V + syn) * (1.0 - spk) # reset mechanism: zero spk = activation(V-self.threshold) self.state = self.NeuronState(V=V, syn=syn, spk=spk) diff --git a/utils/neuron_parameters.py b/utils/neuron_parameters.py index d818ae4..c05522a 100644 --- a/utils/neuron_parameters.py +++ b/utils/neuron_parameters.py @@ -1,9 +1,22 @@ """ Example structure of parameter list: -name = { - "param_1": [min, max, step, init] + +Each parameter dictionary contains the following keys: +- "param_name": [min, max, step, init] + +Where: +- param_name: The name of the parameter. +- min: The minimum value of the parameter. +- max: The maximum value of the parameter. +- step: The step size for adjusting the parameter. +- init: The initial value of the parameter. + +Example: +neuron_parameters = { + "param_1": [min_value, max_value, step_size, initial_value], + "param_2": [min_value, max_value, step_size, initial_value], ... - "param_n": [min, max, step, init] + "param_n": [min_value, max_value, step_size, initial_value] } """ @@ -21,18 +34,6 @@ "R2": [0, 2, 0.01, 1], # Ohm? 
 }
 
-# mn_parameter = {
-#     "a": [-100, 40, 0.1, 0],  # 1/s
-#     "A_1/C": [-5, 15, 0.1, 0],  # V/s
-#     "A_2/C": [-1, 1, 0.01, 0],  # V/s
-#     "b": [-20, 20, 0.1, 10],  # 1/s
-#     "G/C": [0, 75, 0.1, 50],  # 1/s
-#     "k_1": [0, 300, 1, 200],  # 1/s
-#     "k_2": [0, 30, 0.1, 20],  # 1/s
-#     "R_1": [0, 2, 0.01, 0],  # Ohm?
-#     "R_2": [0, 2, 0.01, 1],  # Ohm?
-# }
-
 # Izhikevich neuron
 iz_parameter = {
     "a": [0, 0.1, 0.01, 0.02],
diff --git a/utils/spike_pattern_classifier.py b/utils/spike_pattern_classifier.py
new file mode 100644
index 0000000..4fb5d2d
--- /dev/null
+++ b/utils/spike_pattern_classifier.py
@@ -0,0 +1,465 @@
+"""
+spike_pattern_classifier.py
+
+This module implements functions and utilities for classifying spike patterns using spiking neural networks (SNNs).
+It includes functions for checking GPU availability, preparing datasets, loading pre-trained weights, and running
+simulations of spiking neural networks. The surrogate-gradient autograd function used by the network is imported
+from utils.surrogate_gradient.
+
+Main Components:
+- checkCuda: A function to check for available GPUs and set up the device for computation.
+- classifySpikes: A function to classify spike patterns from a generator of spike data.
+- prepareDataset: A function to prepare a dataset for spike pattern classification.
+- getFiringPatternLabels: A function to retrieve a mapping of firing pattern labels.
+- loadWeights: A function to load pre-trained weights for the spiking neural network.
+- computeActivity: A function to compute the activity of a feedforward spiking neural network layer.
+- computeRecurrentActivity: A function to compute the activity of a recurrent spiking neural network layer.
+- runSNN: A function to run a spiking neural network simulation.
+
+Dependencies:
+- numpy
+- torch
+- torch.nn
+- torch.utils.data
+- tqdm
+
+Usage:
+This module is intended to be used as part of the WiN-GUI project for spike pattern classification.
+To use this module, ensure that the required dependencies are installed and the dataset is structured
+as specified in the README.md.
+
+Example:
+    from utils.spike_pattern_classifier import classifySpikes, prepareDataset
+
+    # Prepare a DataLoader from spike data of shape [timesteps, internal variables, 1, sensors]
+    generator = prepareDataset(data)
+
+    # Classify the spike pattern of every sensor (the pre-trained weights are loaded internally)
+    predictions, softmax = classifySpikes(generator)
+
+"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm import tqdm
+
+from utils.surrogate_gradient import activation
+
+
+def checkCuda(share_GPU=False, gpu_sel=0, gpu_mem_frac=None):
+    """
+    Check for available GPUs and distribute work (if needed/wanted).
+
+    This function checks for available GPUs and sets up the device for computation.
+    It can distribute the load across multiple GPUs if specified, or use a single GPU or the CPU.
+
+    Args:
+        share_GPU (bool, optional): If True, the load will be shared across multiple GPUs. Defaults to False.
+        gpu_sel (int, optional): The index of the GPU to use if not sharing the load. Defaults to 0.
+        gpu_mem_frac (float, optional): The fraction of GPU memory to allocate for the process.
+            If None, the default allocation is used. Defaults to None.
+
+    Returns:
+        torch.device: The device to be used for computation (either a specific GPU or the CPU).
+    """
+
+    if (torch.cuda.device_count() > 1) and share_GPU:
+        gpu_av = [torch.cuda.is_available()
+                  for ii in range(torch.cuda.device_count())]
+        # print("Detected {} GPUs. The load will be shared.".format(
+        #     torch.cuda.device_count()))
+        if True in gpu_av:
+            if gpu_av[gpu_sel]:
+                device = torch.device("cuda:"+str(gpu_sel))
+                # print("Selected GPU: {}".format("cuda:"+str(gpu_sel)))
+            else:
+                device = torch.device("cuda:"+str(gpu_av.index(True)))
+        else:
+            device = torch.device("cpu")
+            # print("No available GPU detected. Running on CPU.")
+    elif (torch.cuda.device_count() > 1) and not share_GPU:
+        # print("Multiple GPUs detected but single GPU selected. Setting up the simulation on {}".format(
+        #     "cuda:"+str(gpu_sel)))
+        device = torch.device("cuda:"+str(gpu_sel))
+        if gpu_mem_frac is not None:
+            # decrease or comment out the memory fraction if more memory is available (the smaller the better)
+            torch.cuda.set_per_process_memory_fraction(
+                gpu_mem_frac, device=device)
+    else:
+        if torch.cuda.is_available():
+            # print("Single GPU detected. Setting up the simulation there.")
+            device = torch.device("cuda:"+str(torch.cuda.current_device()))
+            if gpu_mem_frac is not None:
+                # decrease or comment out the memory fraction if more memory is available (the smaller the better)
+                torch.cuda.set_per_process_memory_fraction(
+                    gpu_mem_frac, device=device)
+        else:
+            # print("No GPU detected. Running on CPU.")
+            device = torch.device("cpu")
+
+    return device
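+
+
+# Minimal usage sketch (illustrative only; mirrors the functions defined in this module):
+#
+#     device = checkCuda(share_GPU=False, gpu_sel=0)
+#     layers = loadWeights(map_location=device)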
+
+
+def prepareDataset(data):
+    """
+    Prepares a dataset for spike pattern classification.
+
+    This function turns the input spike data into samples for the classifier. Recordings
+    longer than 1000 time steps are split into overlapping 1000-step windows (stride 100);
+    shorter recordings are passed through unchanged as a single sample. The function
+    returns a DataLoader for the processed dataset.
+
+    Args:
+        data (np.ndarray): A 4D NumPy array of shape [timesteps, internal variables, 1, sensors] containing the spike data.
+
+    Returns:
+        DataLoader: A DataLoader object for the processed dataset, with each batch containing the spike data and corresponding sensor indices.
+    """
+
+    neuronSpikeTimesDense = np.reshape(
+        data[:, 0, :, :], (data.shape[0], data.shape[-1]))
+    neuronSpikeTimesDense = torch.as_tensor(
+        neuronSpikeTimesDense, dtype=torch.float32)
+    target_nb_samples = 1000
+    stride = 100
+
+    if neuronSpikeTimesDense.shape[0] > target_nb_samples:
+        nb_splits = (
+            neuronSpikeTimesDense.shape[0] - target_nb_samples) // stride + 1
+
+        # Create the sensor ID list, one row per sliding-window split
+        sensor_idc_init = torch.arange(neuronSpikeTimesDense.shape[-1])
+        sensor_idc = sensor_idc_init.unsqueeze(0).repeat(nb_splits, 1)
+
+        start_points = range(
+            0, neuronSpikeTimesDense.shape[0] - target_nb_samples + 1, stride)
+
+        # Pre-allocate the tensor for the sliced data
+        neuronSpikeTimesDenseRepeated = torch.zeros(
+            (nb_splits, target_nb_samples, neuronSpikeTimesDense.shape[1]), dtype=neuronSpikeTimesDense.dtype)
+
+        # Fill the new tensor using a sliding window
+        for sensor_idx in range(neuronSpikeTimesDense.shape[1]):
+            for split_idx, start in enumerate(start_points):
+                neuronSpikeTimesDenseRepeated[split_idx, :, sensor_idx] = \
+                    neuronSpikeTimesDense[start:start + target_nb_samples, sensor_idx]
+
+        neuronSpikeTimesDense = neuronSpikeTimesDenseRepeated
+
+        batch_size = min(nb_splits, 128)
+
+    else:
+        batch_size = 1
+        neuronSpikeTimesDense = neuronSpikeTimesDense.unsqueeze(0)
+        sensor_idc = torch.arange(neuronSpikeTimesDense.shape[-1])
+        sensor_idc = sensor_idc.unsqueeze(0)
+
+    # Long recordings are split into 1000-step windows; every sensor appears once per window.
+    ds_test = TensorDataset(neuronSpikeTimesDense, sensor_idc)
+    generator = DataLoader(ds_test, batch_size=batch_size,
+                           shuffle=False, num_workers=4, pin_memory=True)
+
+    return generator
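+
+
+# Windowing arithmetic (illustrative): with target_nb_samples = 1000 and stride = 100,
+# a recording of, e.g., 2500 time steps yields (2500 - 1000) // 100 + 1 = 16 windows,
+# so the DataLoader above serves a single batch of shape [16, 1000, sensors].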
+
+
+def getFiringPatternLabels():
+    """
+    Retrieves a mapping of firing pattern labels.
+
+    This function returns a dictionary that maps single-character keys to descriptive labels of various firing patterns.
+    These labels are used to classify different types of neuronal firing behaviors.
+
+    Returns:
+        dict: A dictionary where keys are single-character strings and values are descriptive labels of firing patterns.
+    """
+
+    labels_mapping = {
+        'A': "Tonic spiking",
+        'B': "Class 1",
+        'C': "Spike frequency adaptation",
+        'D': "Phasic spiking",
+        'E': "Accommodation",
+        'F': "Threshold variability",
+        'G': "Rebound spike",
+        'H': "Class 2",
+        'I': "Integrator",
+        'J': "Input bistability",
+        'K': "Hyperpolarizing spiking",
+        'L': "Hyperpolarizing bursting",
+        'M': "Tonic bursting",
+        'N': "Phasic bursting",
+        'O': "Rebound burst",
+        'P': "Mixed mode",
+        'Q': "Afterpotentials",
+        'R': "Basal bistability",
+        'S': "Preferred frequency",
+        'T': "Spike latency",
+    }
+    return labels_mapping
+
+
+def loadWeights(map_location):
+    """
+    Loads pre-trained weights for the spiking neural network.
+
+    This function loads the pre-trained weights from a specified file and maps them to the given device location.
+
+    Args:
+        map_location (str or torch.device): The device location to map the loaded weights to (e.g., 'cpu', 'cuda').
+
+    Returns:
+        dict: A dictionary containing the loaded weights for the spiking neural network.
+    """
+
+    lays = torch.load("./utils/weights.pt",
+                      map_location=map_location)
+    return lays
+
+
+def computeActivity(nb_input, nb_neurons, input_activity, nb_steps, device):
+    """
+    Computes the activity of a feedforward spiking neural network layer.
+
+    This function simulates the activity of a feedforward layer in a spiking neural network over a specified number of time steps.
+    It records the membrane potential and spike activity for each input and neuron.
+
+    Args:
+        nb_input (int): The number of trials within a batch.
+        nb_neurons (int): The number of neurons in the layer.
+        input_activity (torch.Tensor): A tensor of shape [nb_input, nb_steps, nb_neurons] representing the input activity over time.
+        nb_steps (int): The number of time steps to simulate.
+        device (torch.device): The device to perform the computation on (e.g., 'cpu', 'cuda').
+
+    Returns:
+        torch.Tensor: A tensor of shape [nb_input, nb_steps, nb_neurons] representing the spike activity of the neurons over time.
+    """
+
+    syn = torch.zeros((nb_input, nb_neurons), device=device, dtype=torch.float)
+    mem = torch.zeros((nb_input, nb_neurons), device=device, dtype=torch.float)
+
+    # Preallocate memory for recording
+    mem_rec = torch.zeros((nb_steps, nb_input, nb_neurons),
+                          device=device, dtype=torch.float)
+    spk_rec = torch.zeros((nb_steps, nb_input, nb_neurons),
+                          device=device, dtype=torch.float)
+
+    # Compute feedforward layer activity
+    for t in range(nb_steps):
+        mthr = mem - 1.0
+        out = activation(mthr)
+        rst_out = out.detach()
+
+        new_syn = 0.8187 * syn + input_activity[:, t]
+        new_mem = (0.9048 * mem + syn) * (1.0 - rst_out)
+
+        mem_rec[t] = mem
+        spk_rec[t] = out
+
+        mem = new_mem
+        syn = new_syn
+
+    # Transpose spk_rec to match the original output shape
+    spk_rec = spk_rec.transpose(0, 1)
+    return spk_rec
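+
+
+# The hard-coded decay factors in computeActivity and computeRecurrentActivity are
+# consistent with a 1 ms time step and synaptic/membrane time constants of 5 ms and
+# 10 ms (an assumed derivation, not stated in the original code):
+#
+#     import math
+#     alpha = math.exp(-1.0 / 5.0)   # ~0.8187, synaptic decay per step
+#     beta = math.exp(-1.0 / 10.0)   # ~0.9048, membrane decay per step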
+
+
+def computeRecurrentActivity(nb_input, nb_neurons, input_activity, layer, nb_steps, device):
+    """
+    Computes the activity of a recurrent spiking neural network layer.
+
+    This function simulates the activity of a recurrent layer in a spiking neural network over a specified number of time steps.
+    It records the membrane potential and spike activity for each input and neuron.
+
+    Args:
+        nb_input (int): The number of trials within a batch.
+        nb_neurons (int): The number of neurons in the recurrent layer.
+        input_activity (torch.Tensor): A tensor of shape [nb_input, nb_steps, nb_neurons] representing the input activity over time.
+        layer (torch.Tensor): A tensor representing the recurrent weights of the layer.
+        nb_steps (int): The number of time steps to simulate.
+        device (torch.device): The device to perform the computation on (e.g., 'cpu', 'cuda').
+
+    Returns:
+        torch.Tensor: A tensor of shape [nb_input, nb_steps, nb_neurons] representing the spike activity of the neurons over time.
+    """
+
+    out = torch.zeros((nb_input, nb_neurons), device=device, dtype=torch.float)
+    syn = torch.zeros((nb_input, nb_neurons), device=device, dtype=torch.float)
+    mem = torch.zeros((nb_input, nb_neurons), device=device, dtype=torch.float)
+
+    # Preallocate memory for recording
+    mem_rec = torch.zeros((nb_steps, nb_input, nb_neurons),
+                          device=device, dtype=torch.float)
+    spk_rec = torch.zeros((nb_steps, nb_input, nb_neurons),
+                          device=device, dtype=torch.float)
+
+    # Compute recurrent layer activity
+    for t in range(nb_steps):
+        # input activity plus the output activity of the previous step
+        h1 = input_activity[:, t] + torch.einsum("ab,bc->ac", (out, layer))
+        mthr = mem - 1.0
+        out = activation(mthr)
+        rst = out.detach()  # We do not want to backprop through the reset
+
+        new_syn = 0.8187 * syn + h1
+        new_mem = (0.9048 * mem + syn) * (1.0 - rst)
+
+        mem_rec[t] = mem
+        spk_rec[t] = out
+
+        mem = new_mem
+        syn = new_syn
+
+    # Transpose spk_rec to match the original output shape
+    spk_rec = spk_rec.transpose(0, 1)
+    return spk_rec
+
+
+def runSNN(inputs, nb_steps, layers, device):
+    """
+    Runs a spiking neural network (SNN) simulation.
+
+    This function simulates the activity of a spiking neural network over a specified number of time steps.
+    It processes the input through a recurrent layer and a readout layer, and returns the spike activity of the output layer.
+
+    Args:
+        inputs (torch.Tensor): A tensor of shape [batch_size, timesteps, input_dim] representing the input spike data.
+        nb_steps (int): The number of time steps to simulate.
+        layers (tuple): A tuple containing the weight matrices (w1, w2, v1) for the network layers.
+        device (torch.device): The device to perform the computation on (e.g., 'cpu', 'cuda').
+
+    Returns:
+        torch.Tensor: A tensor of shape [batch_size, timesteps, nb_outputs] representing the spike activity of the output layer.
+    """
+
+    w1, w2, v1 = layers
+    bs = inputs.shape[0]
+    nb_outputs = 20  # number of spiking behaviours from the MN paper
+    nb_hidden = 250
+
+    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
+    spk_rec = computeRecurrentActivity(bs, nb_hidden, h1, v1, nb_steps, device)
+
+    # Readout layer
+    h2 = torch.einsum("abc,cd->abd", (spk_rec, w2))
+    s_out_rec = computeActivity(bs, nb_outputs, h2, nb_steps, device)
+
+    return s_out_rec
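+
+
+# Network shape used below (recap of the constants in runSNN): each single-channel
+# spike train is projected through w1 into 250 recurrently connected hidden units
+# (recurrent weights v1), then through w2 into 20 readout units, one per MN firing
+# pattern returned by getFiringPatternLabels().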
+
+
+def classifySpikes(generator):
+    """
+    Classifies spike patterns from a generator of spike data.
+
+    This function processes spike data from multiple sensors, runs a spiking neural network (SNN) to classify the spike patterns,
+    and returns the predicted labels and softmax probabilities for each sensor.
+
+    Args:
+        generator (iterable): An iterable that yields tuples of (spikes, id), where:
+            - spikes (torch.Tensor): A tensor of shape [batch_size, timesteps, sensors] containing the spike data.
+            - id: An identifier for the batch (not used in this function).
+
+    Returns:
+        tuple: A tuple containing:
+            - predictions_out_list (np.ndarray): A 2D array of shape [sensors, total_samples] containing the predicted label
+              for each sensor and sample, where total_samples is accumulated over all batches.
+            - softmax_out_list (np.ndarray): A 3D array of shape [sensors, total_samples, 20] containing the softmax
+              probabilities for each sensor and sample.
+    """
+
+    labels_mapping = getFiringPatternLabels()
+    device = checkCuda()
+
+    # Load the pre-trained weights
+    layers = loadWeights(map_location=device)
+
+    # The log softmax function across output units
+    log_softmax_fn = nn.LogSoftmax(dim=1)
+
+    predictions_out_list = []
+    softmax_out_list = []
+
+    generatorObject = tqdm(generator, desc="Classifying spikes", total=len(
+        generator), position=0, leave=True)
+    for spikes, _ in generatorObject:
+        predictions_list = []
+        softmax_list = []
+
+        channelObject = tqdm(
+            range(spikes.shape[-1]), desc="Processing channels", position=1, leave=False)
+        for sensorId in channelObject:
+            nb_spikes = torch.sum(spikes[:, :, sensorId], dim=1)
+            sensor_spikes = spikes[:, :, sensorId].to(device).unsqueeze(2)
+
+            # Identify trials with 0 spikes
+            zero_spike_trials = (nb_spikes == 0)
+
+            # Filter out trials with 0 spikes
+            non_zero_spike_trials = ~zero_spike_trials
+            filtered_sensor_spikes = sensor_spikes[non_zero_spike_trials]
+
+            # Run the SNN only on trials with non-zero spikes
+            spks_out = runSNN(inputs=filtered_sensor_spikes,
+                              nb_steps=spikes.shape[1], layers=layers, device=device)
+            spks_sum = torch.sum(spks_out, 1)  # sum over time
+            max_activity_idc = torch.argmax(spks_sum, 1)  # argmax over output units
+
+            # MN-defined label of the spiking behaviour
+            prediction = [labels_mapping[list(labels_mapping.keys())[idx.item()]]
+                          for idx in max_activity_idc]
+            softmax = torch.exp(log_softmax_fn(spks_sum))
+
+            # Initialize predictions and softmaxs with default values for zero-spike trials
+            predictions = ["No spikes"] * len(nb_spikes)
+            softmaxs = [np.zeros(20) for _ in range(len(nb_spikes))]
+
+            # Fill in the values for non-zero spike trials
+            non_zero_indices = torch.flatten(
+                torch.nonzero(non_zero_spike_trials)).tolist()
+            for i, idx in enumerate(non_zero_indices):
+                predictions[idx] = prediction[i]
+                softmaxs[idx] = softmax[i].cpu().detach().numpy()
+
+            # Append the results to the final lists
+            predictions_list.append(predictions)
+            softmax_list.append(softmaxs)
+
+        # Convert lists to NumPy arrays
+        predictions_list = np.array(predictions_list)
+        softmax_list = np.array(softmax_list)
+
+        # Concatenate the results for the current batch
+        if len(predictions_out_list) == 0:
+            predictions_out_list = predictions_list
+            softmax_out_list = softmax_list
+        else:
+            predictions_out_list = np.concatenate(
+                (predictions_out_list, predictions_list), axis=1)
+            softmax_out_list = np.concatenate(
+                (softmax_out_list, softmax_list), axis=1)
+
+    return predictions_out_list, softmax_out_list
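+
+
+# End-to-end sketch (illustrative; assumes the repository root as the working
+# directory so that ./utils/weights.pt resolves):
+#
+#     data = ...  # spike array of shape [timesteps, internal variables, 1, sensors]
+#     generator = prepareDataset(data)
+#     predictions, softmax = classifySpikes(generator)
+#     # predictions: [sensors, samples]; softmax: [sensors, samples, 20]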
diff --git a/utils/surrogate_gradient.py b/utils/surrogate_gradient.py
new file mode 100644
index 0000000..51e6586
--- /dev/null
+++ b/utils/surrogate_gradient.py
@@ -0,0 +1,81 @@
+"""
+surrogate_gradient.py
+
+This module implements a custom autograd function for spiking nonlinearity with a surrogate gradient,
+which is essential for training spiking neural networks (SNNs) using backpropagation. The surrogate
+gradient is based on the normalized negative part of a fast sigmoid, as described in Zenke & Ganguli (2018).
+
+Classes and Functions:
+- SurrGradSpike: A custom autograd function that implements the spiking nonlinearity and surrogate gradient.
+  - scale: A class attribute that defines the scaling factor for the surrogate gradient.
+  - forward(ctx, input): Computes a step function of the input tensor and returns it. The input tensor is saved
+    for backward computation.
+  - backward(ctx, grad_output): Computes the surrogate gradient of the loss with respect to the input using the
+    saved input tensor. The gradient is calculated using the normalized negative part of a fast sigmoid.
+- activation: An alias that applies the SurrGradSpike autograd function.
+
+Usage:
+Import `activation` (an alias for SurrGradSpike.apply) and use it as the spiking nonlinearity
+wherever a hard threshold would otherwise block gradient flow.
+
+Example:
+    import torch
+    from utils.surrogate_gradient import activation
+
+    input_tensor = torch.tensor([0.5, -0.5, 0.0, 1.0], requires_grad=True)
+    output_tensor = activation(input_tensor)
+    output_tensor.backward(torch.ones_like(input_tensor))
+
+Dependencies:
+- torch: PyTorch library for tensor computations and neural network operations.
+
+License:
+This project is licensed under the GPL-3.0 License. See the LICENSE file for more details.
+
+"""
+
+import torch
+
+
+class SurrGradSpike(torch.autograd.Function):
+    """
+    Here we implement our spiking nonlinearity which also implements
+    the surrogate gradient. By subclassing torch.autograd.Function,
+    we will be able to use all of PyTorch's autograd functionality.
+    Here we use the normalized negative part of a fast sigmoid
+    as this was done in Zenke & Ganguli (2018).
+    """
+
+    scale = 100
+
+    @staticmethod
+    def forward(ctx, input):
+        """
+        In the forward pass we compute a step function of the input Tensor
+        and return it. ctx is a context object that we use to stash information which
+        we need to later backpropagate our error signals. To achieve this we use the
+        ctx.save_for_backward method.
+        """
+        ctx.save_for_backward(input)
+        out = torch.zeros_like(input)
+        out[input > 0] = 1.0
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        In the backward pass we receive a Tensor containing the gradient of the loss
+        with respect to the output, and we compute the surrogate gradient of the loss
+        with respect to the input. Here we use the normalized negative part of a fast
+        sigmoid as this was done in Zenke & Ganguli (2018).
+        """
+        (input,) = ctx.saved_tensors
+        grad_input = grad_output.clone()
+        grad = grad_input / (SurrGradSpike.scale * torch.abs(input) + 1.0) ** 2
+        return grad
+
+
+activation = SurrGradSpike.apply
diff --git a/utils/weights.pt b/utils/weights.pt
new file mode 100644
index 0000000..ba6269d
Binary files /dev/null and b/utils/weights.pt differ
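
For completeness, a minimal sanity check of the surrogate gradient above (a sketch; it assumes the repository root is on the import path, and `x`/`expected` are illustrative names):

    import torch

    from utils.surrogate_gradient import SurrGradSpike, activation

    x = torch.linspace(-1.0, 1.0, 5, requires_grad=True)
    y = activation(x)                 # forward: hard step, 1.0 where x > 0
    y.backward(torch.ones_like(x))    # backward: surrogate gradient

    # The surrogate replaces the step's zero/undefined derivative with
    # 1 / (scale * |x| + 1)^2, the fast-sigmoid form from Zenke & Ganguli (2018).
    expected = 1.0 / (SurrGradSpike.scale * x.abs() + 1.0) ** 2
    print(torch.allclose(x.grad, expected.detach()))  # True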