Skip to content

Commit

Permalink
Rename kernel to multi_kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen authored and scarlehoff committed Mar 4, 2024
1 parent 1623808 commit 8899f31
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions n3fit/src/n3fit/backends/keras_backend/multi_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(

def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
self.multi_kernel = self.add_weight(
name="kernel",
shape=(self.replicas, input_dim, self.units),
initializer=self.kernel_initializer,
Expand All @@ -80,15 +80,15 @@ def build(self, input_shape):
# TODO: benchmark against the replica-agnostic einsum below and make that default
# see https://github.com/NNPDF/nnpdf/pull/1905#discussion_r1489344081
if self.replicas == 1:
matmul = lambda inputs: tf.tensordot(inputs, self.kernel[0], [[-1], [0]])
matmul = lambda inputs: tf.tensordot(inputs, self.multi_kernel[0], [[-1], [0]])
if self.is_first_layer:
# Manually add replica dimension
self.matmul = lambda x: tf.expand_dims(matmul(x), axis=1)
else:
self.matmul = matmul
else:
einrule = "bnf,rfg->brng" if self.is_first_layer else "brnf,rfg->brng"
self.matmul = lambda inputs: tf.einsum(einrule, inputs, self.kernel)
self.matmul = lambda inputs: tf.einsum(einrule, inputs, self.multi_kernel)

def call(self, inputs):
"""
Expand Down

0 comments on commit 8899f31

Please sign in to comment.