Skip to content

Commit

Permalink
Add missing file
Browse files Browse the repository at this point in the history
  • Loading branch information
mvankem committed May 22, 2024
1 parent a4836f3 commit 1737c71
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions training/pt2kerasify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

import sys

import torch
import keras

import kerasify

input_fn = sys.argv[1]
output_fn = sys.argv[2]

# Load PyTorch model
encoder = torch.load(input_fn)

# Convert into keras model
# - only ReLU
# - only fully connected
# - no batch norm etc.
input_shape = [encoder[0].in_features]
n_units_lst = [layer.out_features for layer in encoder if not isinstance(layer, torch.nn.ReLU)]

model = keras.models.Sequential()
model.add(keras.layers.Dense(n_units_lst[0], input_shape=input_shape, activation='relu'))
for n_units in n_units_lst[1:-1]:
model.add(keras.layers.Dense(n_units, activation='relu'))
model.add(keras.layers.Dense(n_units_lst[-1], activation='linear'))

print()
model.summary()

# Copy weights
n_layers = len(n_units_lst)
for i in range(n_layers):
model.layers[i].set_weights([encoder[i * 2].weight.detach().T, encoder[i * 2].bias.detach()])

# Kerasify
kerasify.export_model(model, output_fn)

0 comments on commit 1737c71

Please sign in to comment.