From 71bb6dc3c7c7326703db7c37da62700a41406aa8 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Mon, 4 Jul 2022 17:18:08 +0200 Subject: [PATCH 1/3] added a get config method for wrn cifar in order to be able to deserialize it --- tensorflow2/tf2cv/models/wrn_cifar.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tensorflow2/tf2cv/models/wrn_cifar.py b/tensorflow2/tf2cv/models/wrn_cifar.py index e81328c35..376a38fd6 100644 --- a/tensorflow2/tf2cv/models/wrn_cifar.py +++ b/tensorflow2/tf2cv/models/wrn_cifar.py @@ -87,6 +87,18 @@ def call(self, x, training=None): x = self.output1(x) return x + def get_config(self): + config = super(CIFARWRN, self).get_config() + config.update({ + "channels": self.channels, + "init_block_channels": self.init_block_channels, + "in_channels": self.in_channels, + "in_size": self.in_size, + "classes": self.classes, + "data_format": self.data_format, + }) + return config + def get_wrn_cifar(classes, blocks, From a5c4ce0c714311d0663b07d37a2d43a17445b482 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Mon, 4 Jul 2022 17:22:36 +0200 Subject: [PATCH 2/3] added unsaved variables to wrn in order to make serialization easier --- tensorflow2/tf2cv/models/wrn_cifar.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow2/tf2cv/models/wrn_cifar.py b/tensorflow2/tf2cv/models/wrn_cifar.py index 376a38fd6..ae04bcda4 100644 --- a/tensorflow2/tf2cv/models/wrn_cifar.py +++ b/tensorflow2/tf2cv/models/wrn_cifar.py @@ -44,6 +44,9 @@ def __init__(self, self.in_size = in_size self.classes = classes self.data_format = data_format + self.channels = channels + self.init_block_channels = init_block_channels + self.in_channels = in_channels self.features = SimpleSequential(name="features") self.features.add(conv3x3( From 00e0c018ad840a9d021b450e08790048db6b0dd5 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Mon, 4 Jul 2022 18:03:06 +0200 Subject: [PATCH 3/3] corrected config --- tensorflow2/tf2cv/models/wrn_cifar.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow2/tf2cv/models/wrn_cifar.py b/tensorflow2/tf2cv/models/wrn_cifar.py index ae04bcda4..27a622869 100644 --- a/tensorflow2/tf2cv/models/wrn_cifar.py +++ b/tensorflow2/tf2cv/models/wrn_cifar.py @@ -91,15 +91,14 @@ def call(self, x, training=None): return x def get_config(self): - config = super(CIFARWRN, self).get_config() - config.update({ + config = { "channels": self.channels, "init_block_channels": self.init_block_channels, "in_channels": self.in_channels, "in_size": self.in_size, "classes": self.classes, "data_format": self.data_format, - }) + } return config