Skip to content

Commit

Permalink
Merge pull request #17 from estefanysuarez/test/tutorial
Browse files Browse the repository at this point in the history
Test/tutorial
  • Loading branch information
estefanysuarez authored Mar 27, 2023
2 parents 823c0d0 + 7d0a32a commit a59c24a
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 165 deletions.
14 changes: 12 additions & 2 deletions conn2res/coding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@

def get_modules(module_assignment):
"""
# TODO
_summary_
Parameters
----------
module_assignment : _type_
_description_
Returns
-------
_type_
_description_
"""
# get module ids
module_ids = np.unique(module_assignment)
Expand Down Expand Up @@ -65,7 +75,7 @@ def encoder(reservoir_states, target, readout_modules=None,
"""

# use multiple subsets of readout nodes designated by readout_modules
# use multiple subsets of readout nodes designated by readout_modules
if readout_modules is not None:

if isinstance(readout_modules, np.ndarray):
Expand Down
55 changes: 3 additions & 52 deletions conn2res/iodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,7 @@ def get_available_tasks():
return NEUROGYM_TASKS + NATIVE_TASKS + RESERVOIRPY_TASKS


def unbatch(x):
"""
Removes batch_size dimension from array
Parameters
----------
x : numpy.ndarray
array with dimensions (seq_len, batch_size, features)
Returns
-------
new_x : numpy.ndarray
new array with dimensions (batch_size*seq_len, features)
"""
# TODO right now it only works when x is (batch_first = False)
return np.concatenate(x, axis=0)


def fetch_dataset(task, **kwargs):
def fetch_dataset(task, report=True, **kwargs):
"""
Fetches inputs and labels for 'task' from the NeuroGym
repository
Expand Down Expand Up @@ -162,29 +143,11 @@ def fetch_dataset(task, **kwargs):


def create_neurogymn_dataset(task, n_trials=100, add_constant=False, **kwargs):
"""
_summary_
Parameters
----------
task : _type_
_description_
n_trials : int, optional
_description_, by default 100
add_constant : bool, optional
_description_, by default False
Returns
-------
_type_
_description_
"""
# create a Dataset object from NeuroGym
dataset = ngym.Dataset(task+'-v0', env_kwargs=kwargs)

# get environment object
env = dataset.env
# print(env.timing)

# generate per trial dataset
_ = env.reset()
Expand Down Expand Up @@ -287,7 +250,7 @@ def create_dataset(task, n_timesteps=1000, horizon=1, **kwargs):
y = np.hstack([x[horizon_max-h:-h] for h in horizon])
x = x[horizon_max:]

get_info_data(task, x, y)
# get_info_data(task, x, y)

if horizon_sign == -1:
return x, y
Expand All @@ -299,19 +262,7 @@ def create_dataset(task, n_timesteps=1000, horizon=1, **kwargs):


def get_n_features(task):
"""
_summary_
Parameters
----------
task : _type_
_description_

Returns
-------
_type_
_description_
"""
x, _ = fetch_dataset(task, n_trials=1)

return x[0].shape[1]
Expand Down Expand Up @@ -411,7 +362,7 @@ def get_info_data(task, x, y):
print(f'\tmodel = {model.__name__}')


def get_sample_weight(inputs, sample_block=None):
def get_sample_weight(inputs, labels, sample_block=None):
"""
Time averages dataset based on sample class and sample weight
Expand Down
2 changes: 1 addition & 1 deletion conn2res/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def mean_absolute_error(
def corrcoef(
y_true, y_pred, multioutput='uniform_average', nonnegative=None,
**kwargs
):
):
"""
Pearson's correlation coefficient.
Expand Down
99 changes: 16 additions & 83 deletions conn2res/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def transform_data(


def plot_iodata(
x, y, n_trials=7, title=None, show=True, savefig=False, fname=None,
**kwargs
x, y, n_instances=7, title=None, show=True, savefig=False, fname=None, **kwargs
):
"""
#TODO
Expand All @@ -136,23 +135,14 @@ def plot_iodata(
_description_, by default False
fname : _type_, optional
_description_, by default None
"""
x = x[:n_trials]
y = y[:n_trials]

# get end points for trials to plot trial separators
end_points = []
tf = 0
for i in range(n_trials):
tf += len(x[i])
end_points.append(tf)

# convert x and y to arrays for visualization
if isinstance(x, list):
x = np.vstack(x)
x = np.vstack(x[:n_instances])
if isinstance(y, list):
y = np.vstack(y).squeeze()
y = np.vstack(y[:n_instances]).squeeze()

# set plotting theme
sns.set(style="ticks", font_scale=1.0)
Expand All @@ -161,11 +151,13 @@ def plot_iodata(
# set color palette
palette = kwargs.pop('palette', None)

# plot inputs (x) and outputs (y)
# plot
sns.lineplot(
data=x, palette=palette, dashes=False, legend=False, ax=ax, **kwargs)
data=x, palette=palette, dashes=False, legend=False,
ax=ax, **kwargs)
sns.lineplot(
data=y, palette=palette, dashes=False, legend=False, ax=ax, **kwargs)
data=y, palette=palette, dashes=False, legend=False,
ax=ax, **kwargs)

# set axis labels
ax.set_xlabel('time steps', fontsize=11)
Expand All @@ -186,11 +178,6 @@ def plot_iodata(
ax.legend(handles=ax.lines, labels=new_labels, loc='best',
fontsize=8)

# plot trial line separators
for tf in end_points:
plt.plot(
tf * np.ones((2)), np.arange(2), c='black', linestyle='--')

# set title
if title is not None:
plt.title(title, fontsize=12)
Expand Down Expand Up @@ -275,7 +262,6 @@ def plot_diagnostics(
axs = axs.ravel()

plt.subplots_adjust(wspace=0.1)

# set color palette
palette = kwargs.pop('palette', None)

Expand All @@ -291,14 +277,13 @@ def plot_diagnostics(
dashes=False, legend=False, ax=axs[2])
sns.lineplot(
data=y_pred[:160], palette=palette,
dashes=False, legend=False, ax=axs[2], linewidth=2.5)
dashes=False, legend=False, ax=axs[2])

# set axis labels
axs[0].set_ylabel('x signal \namplitude', fontsize=11)
axs[1].set_ylabel('decision \nfunction', fontsize=11)
axs[2].set_xlabel('time steps', fontsize=11)
axs[2].set_ylabel('y signal \namplitude', fontsize=11)
# axs[1].set_ylim(0, 5e7)

# set axis limits
for ax in axs:
Expand All @@ -309,16 +294,9 @@ def plot_diagnostics(
x_labels = ['x']
else:
x_labels = [f'x{n+1}' for n in range(x.shape[1])]

if dec_func.ndim == 1:
dec_func_labels = ['decision function']
else:
dec_func_labels = [f'decision function {n+1}' for n in range(dec_func.shape[1])]

# set legend
axs[0].legend(handles=axs[0].lines, labels=x_labels,
loc='upper right', fontsize=8)
axs[1].legend(handles=axs[1].lines, labels=dec_func_labels,
axs[1].legend(handles=axs[1].lines, labels=['decision function'],
loc='upper right', fontsize=8)
axs[2].legend(handles=axs[2].lines, labels=['target', 'predicted target'],
loc='upper right', fontsize=8)
Expand All @@ -340,16 +318,16 @@ def plot_diagnostics(

fig.savefig(fname=os.path.join(FIG_DIR, f'{fname}.png'),
transparent=True, bbox_inches='tight', dpi=300)

plt.close()


def plot_performance(
df, x='alpha', y='score', normalize=False,
df, x='alpha', y='score', norm=False,
title=None, show=True, savefig=False, fname=None, **kwargs
):

if normalize:
if norm:
df[y] = df[y] / max(df[y])

# set plotting theme
Expand All @@ -373,7 +351,7 @@ def plot_performance(
ax.set_xlabel('alpha', fontsize=11)
y_label = ' '.join(y.split('_'))
ax.set_ylabel(y_label, fontsize=11)

# set title
if title is not None:
plt.title(title, fontsize=12)
Expand All @@ -392,49 +370,4 @@ def plot_performance(
fig.savefig(fname=os.path.join(FIG_DIR, f'{fname}.png'),
transparent=True, bbox_inches='tight', dpi=300)

plt.close()


def plot_phase_space(x, y, sample=None, xlim=None, ylim=None, subplot=None, cmap=None,
num=1, figsize=(13, 5), title=None, fname='phase_space', savefig=False, block=False
):
#TODO
# open figure and create subplot
plt.figure(num=num, figsize=figsize)
if subplot is None:
subplot = (1, 1, 1)
plt.subplot(*subplot)

# plot data
if sample is None:
plt.plot(x)
else:
t = np.arange(*sample)
if cmap is None:
plt.plot(t, x[t])
else:
for i, _ in enumerate(t[:-1]):
plt.plot(x[t[i:i+2]], y[t[i:i+2]],
color=getattr(plt.cm, cmap)(255*i//np.diff(sample)))

# add x and y limits
if xlim is not None:
plt.xlim(xlim)
if ylim is not None:
plt.xlim(ylim)

# set xtick/ythick fontsize
plt.xticks(fontsize=22)
plt.yticks(fontsize=22)

# add title
if title is not None:
plt.title(f'{title} phase space', fontsize=22)

# set tight layout in case there are different subplots
plt.tight_layout()

if savefig:
plt.savefig(fname=os.path.join(FIG_DIR, f'{fname}.png'),
transparent=True, bbox_inches='tight', dpi=300)
plt.show(block=block)
plt.close()
Loading

0 comments on commit a59c24a

Please sign in to comment.