Skip to content

Commit

Permalink
Update usage.rst - FIx other variables
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk authored Oct 29, 2023
1 parent 628311f commit 74b0b7d
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1336,75 +1336,75 @@ gets initialized which also allows the `prior` to be directly parametrized.
import numpy as np
import cebra.datasets
import torch

if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

neural_data = cebra.load_data(file="neural_data.npz", key="neural")

discrete_label = cebra.load_data(
file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define Cebra Dataset
input_data = cebra.data.TensorDataset(
torch.from_numpy(neural_data).type(torch.FloatTensor),
discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to(device)

# 2. Define Cebra Model
neural_model = cebra.models.init(
name="offset10-model",
num_neurons=InputData.input_dimension,
num_neurons=input_data.input_dimension,
num_units=32,
num_output=2,
).to(device)

InputData.configure_for(neural_model)

input_data.configure_for(neural_model)
# 3. Define Loss Function Criterion and Optimizer
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
crit = cebra.models.criterions.LearnableCosineInfoNCE(
temperature=0.001,
min_temperature=0.0001
).to(device)

Opt = torch.optim.Adam(
list(neural_model.parameters()) + list(Crit.parameters()),
opt = torch.optim.Adam(
list(neural_model.parameters()) + list(crit.parameters()),
lr=0.001,
weight_decay=0,
)

# 4. Initialize Cebra Model
solver = cebra.solver.init(
name="single-session",
model=neural_model,
criterion=Crit,
optimizer=Opt,
criterion=crit,
optimizer=opt,
tqdm_on=True,
).to(device)

# 5. Define Data Loader
loader = cebra.data.single_session.DiscreteDataLoader(
dataset=InputData, num_steps=10, batch_size=200, prior="uniform"
dataset=input_data, num_steps=10, batch_size=200, prior="uniform"
)

# 6. Fit Model
solver.fit(loader=loader)

# 7. Transform Embedding
TrainBatches = np.lib.stride_tricks.sliding_window_view(
train_batches = np.lib.stride_tricks.sliding_window_view(
neural_data, neural_model.get_offset().__len__(), axis=0
)

X_train_emb = solver.transform(
torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to(device)
x_train_emb = solver.transform(
torch.from_numpy(train_batches[:]).type(torch.FloatTensor).to(device)
).to(device)

# 8. Plot Embedding
cebra.plot_embedding(
X_train_emb,
x_train_emb,
discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
markersize=10,
)

0 comments on commit 74b0b7d

Please sign in to comment.