From 77c5b668f30f3e19f9ff28dfd9e3401553bf8e25 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 1 Dec 2023 20:42:24 +0000 Subject: [PATCH] fix mix transformer tests --- keras_cv/backend/config.py | 1 + .../backbones/mix_transformer/mix_transformer_backbone.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_cv/backend/config.py b/keras_cv/backend/config.py index 27b84e68eb..781e59ed72 100644 --- a/keras_cv/backend/config.py +++ b/keras_cv/backend/config.py @@ -23,6 +23,7 @@ if hasattr(keras, "version") and keras.version().startswith("3."): _USE_KERAS_3 = True + def detect_if_tensorflow_uses_keras_3(): # We follow the version of keras that tensorflow is configured to use. try: diff --git a/keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py b/keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py index bf6a1a6ec2..9cd2496d2a 100644 --- a/keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py +++ b/keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py @@ -28,7 +28,6 @@ from keras_cv import layers as cv_layers from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras -from keras_cv.backend import ops from keras_cv.models import utils from keras_cv.models.backbones.backbone import Backbone from keras_cv.models.backbones.mix_transformer.mix_transformer_backbone_presets import ( # noqa: E501 @@ -140,8 +139,8 @@ def __init__( # call in `OverlappingPatchingAndEmbedding` stride = 4 if i == 0 else 2 new_height, new_width = ( - int(ops.shape(x)[1] / stride), - int(ops.shape(x)[2] / stride), + int(x.shape[1] / stride), + int(x.shape[2] / stride), ) x = patch_embedding_layers[i](x)