-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanalyze_different_configs.py
176 lines (137 loc) · 5.48 KB
/
analyze_different_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
import helpers
from corcol_params.sim_params import sim_dict
# Analyze results for different stimulation configurations
stim_currents = [0.5, 1, 1.5, 2, 2.5]  # stimulation currents (uA), one output folder each
N_GROUPS_LIST = [1, 2, 4, 8, 16, 32]  # number of stim electrode groups

RESULT_COLUMNS = ["current", "condition", "n_groups", "num_components", "overlap"]


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    # NOTE: unpickling can execute arbitrary code; only load files that were
    # written by this project's own simulation runs.
    with open(path, "rb") as f:
        return pickle.load(f)


# Rows are collected as plain dicts and turned into a DataFrame once at the
# end. Growing a frame with pd.concat (starting from an empty frame and
# all-None columns) is deprecated in pandas and yields version-dependent dtypes.
result_rows = []

for current in stim_currents:
    base_path = os.path.join(
        os.getcwd(), "outputs", f"data_8Hz_k10_scale005_[{current}]uA"
    )

    # Baseline
    sim_dict["data_path"] = os.path.join(base_path, "data_baseline")
    baseline_spike_rates = _load_pickle(
        os.path.join(sim_dict["data_path"], "baseline_spike_rates.pkl")
    )

    # The PCA basis is fitted on baseline activity only; the stimulation data
    # below is projected into this same basis.
    pca = PCA(n_components=3)
    pca.fit(baseline_spike_rates)
    baseline_pca = pca.transform(baseline_spike_rates)

    baseline_num_components = helpers.get_dimensionality(
        baseline_spike_rates, variance_threshold=0.95
    )
    result_rows.append(
        {
            "current": current,
            "condition": "baseline",
            "n_groups": np.nan,  # not applicable to the baseline condition
            "num_components": baseline_num_components,
            "overlap": np.nan,  # not applicable to the baseline condition
        }
    )

    # Iterate through stim channel groups
    stim_projected_list = []
    stim_num_components_list = []
    for n_groups in N_GROUPS_LIST:
        sim_dict["data_path"] = os.path.join(
            base_path, f"data_randstim_{n_groups}groups/"
        )
        stim_evoked_spike_rates = _load_pickle(
            os.path.join(
                sim_dict["data_path"], f"{n_groups}groups_stim_spike_rates.pkl"
            )
        )
        stim_projected_list.append(pca.transform(stim_evoked_spike_rates))

        # NOTE(review): the baseline dimensionality above uses a 0.95 variance
        # threshold while this one uses 0.85 - confirm the difference is
        # intended before comparing the two "num_components" values.
        stim_num_components_list.append(
            helpers.get_dimensionality(
                stim_evoked_spike_rates, variance_threshold=0.85
            )
        )

    overlap_list, _, _ = helpers.compute_all_overlaps(
        baseline_pca, stim_projected_list
    )
    # Fail loudly on a length mismatch instead of letting zip() truncate.
    if len(overlap_list) != len(N_GROUPS_LIST):
        raise ValueError(
            f"expected {len(N_GROUPS_LIST)} overlap values, got {len(overlap_list)}"
        )

    result_rows.extend(
        {
            "current": current,
            "condition": "stim",
            "n_groups": group_count,
            "num_components": num_components,
            "overlap": overlap_value,
        }
        for group_count, num_components, overlap_value in zip(
            N_GROUPS_LIST, stim_num_components_list, overlap_list
        )
    )

results_df = pd.DataFrame(result_rows, columns=RESULT_COLUMNS)
# %% Plot the results for 32 channel group for all currents
# Baseline rows have no group count, so this keeps only stimulation rows.
df_32 = results_df[results_df["n_groups"] == 32]

# Positions and tick labels come from the selected rows instead of being
# hard-coded, so the plot stays valid if the list of currents changes.
x_positions = np.arange(len(df_32))
plt.plot(x_positions, df_32["overlap"].astype(float).to_numpy())
plt.xticks(x_positions, df_32["current"].astype(float).to_numpy())
plt.xlabel("Stimulation Current (uA)")
plt.ylabel("Overlap Score (Jaccard Index)")
plt.title("Overlap Comparison for 32 Stim Groups at Different Currents")
# %% Plot all overlaps for different currents
n_cols = 2
n_rows = -(-len(stim_currents) // n_cols)  # ceiling division
# squeeze=False always returns a 2-D array of axes, so flatten() is safe
# even when the grid has a single row.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(8, 6), constrained_layout=True, squeeze=False
)
axes = axes.flatten()

# Loop through each stimulation current and create a subplot
for i, current in enumerate(stim_currents):
    df_current = results_df[results_df["current"] == current]
    # Baseline rows have no group count / overlap and are dropped here.
    df_current = df_current.dropna(subset=["n_groups", "overlap"])

    # Convert without assigning back into the filtered frame, which avoids
    # pandas' SettingWithCopy warnings.
    n_groups = pd.to_numeric(df_current["n_groups"], errors="coerce").to_numpy()
    overlap = pd.to_numeric(df_current["overlap"], errors="coerce").to_numpy()
    if len(overlap) == 0:
        continue

    x_positions = range(len(overlap))
    ax = axes[i]
    ax.plot(x_positions, overlap, marker="o", linestyle="-", color="C0")
    ax.set_xticks(x_positions)
    # ":g" prints 1.0 and 1 identically, so labels do not depend on whether
    # the column was stored as float or int.
    ax.set_xticklabels([f"{g:g}" for g in n_groups], rotation=45)
    ax.set_title(f"Current = {current} µA")
    ax.set_xlabel("N Groups")
    ax.set_ylabel("Overlap")

# Remove grid cells that have no current assigned to them
for unused_ax in axes[len(stim_currents):]:
    fig.delaxes(unused_ax)

# No tight_layout() here: the figure uses constrained layout, which already
# reserves room for the suptitle; tight_layout() would replace that engine.
fig.suptitle("Overlap vs N Groups for Different Currents", fontsize=16)
# %%
# Create a single figure and axis for plotting
fig, ax = plt.subplots(figsize=(8, 6))

for i, current in enumerate(stim_currents):
    # Filter the DataFrame for the current stimulation value
    df_current = results_df[results_df["current"] == current]
    # Drop rows with NaN values in 'n_groups' and 'overlap' columns
    # (these are the baseline rows)
    df_current = df_current.dropna(subset=["n_groups", "overlap"])
    # Ensure values are numeric without writing back into the filtered frame
    overlap = pd.to_numeric(df_current["overlap"], errors="coerce").to_numpy()

    # Plot if data is available
    if len(overlap) == 0:
        continue
    ax.plot(
        range(len(overlap)),
        overlap,
        marker="o",
        linestyle="-",
        color=f"C{i}",
        label=f"{current} µA",
    )

# Set axis labels and title. Tick labels come from N_GROUPS_LIST rather than
# from a variable left over from the last loop iteration, so they are correct
# even if the last current has no data.
ax.set_xlabel("N Groups")
ax.set_xticks(range(len(N_GROUPS_LIST)))
ax.set_xticklabels(N_GROUPS_LIST, rotation=45)
ax.set_ylabel("Overlap")
ax.set_title("Overlap vs N Groups for Different Currents")
# Add legend
ax.legend(title="Current (µA)")
# Adjust layout for better visualization
plt.tight_layout()
plt.show()