Skip to content

Commit

Permalink
#12558: TTNN implementation of MNIST model
Browse files Browse the repository at this point in the history
  • Loading branch information
sabira-mcw committed Sep 30, 2024
1 parent b435fce commit ebcfbca
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions models/demos/mnist/tt/tt_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import torch


def mnist(device, batch_size, x, parameters):
x = ttnn.reshape(x, (x.shape[0], 1, 1, 784))

x = ttnn.to_device(x, device=device)
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
x = ttnn.linear(
x, parameters.fc1.weight, bias=parameters.fc1.bias, memory_config=ttnn.L1_MEMORY_CONFIG, activation="relu"
)
x = ttnn.linear(
x,
parameters.fc2.weight,
bias=parameters.fc2.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
activation="relu",
)
x = ttnn.linear(
x,
parameters.fc3.weight,
bias=parameters.fc3.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
activation="relu",
)

x = ttnn.softmax(x)

return x

0 comments on commit ebcfbca

Please sign in to comment.