From fd81e02df67166dcd66413a030c2cae297dca5d9 Mon Sep 17 00:00:00 2001 From: Tony Wu <28306721+tonywu71@users.noreply.github.com> Date: Wed, 16 Oct 2024 17:36:43 +0200 Subject: [PATCH] fix: add missing load_state_dict for base ColPaliDuo model --- .../colpali_duo/publish_base_colpali_duo.py | 19 ++++++++++++------- .../colpali_duo/test_modeling_colpali_duo.py | 2 +- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/colpali_engine/models/paligemma/colpali_duo/publish_base_colpali_duo.py b/colpali_engine/models/paligemma/colpali_duo/publish_base_colpali_duo.py index 432b15c1..d7247160 100644 --- a/colpali_engine/models/paligemma/colpali_duo/publish_base_colpali_duo.py +++ b/colpali_engine/models/paligemma/colpali_duo/publish_base_colpali_duo.py @@ -2,8 +2,8 @@ import torch import typer +from transformers.models.paligemma import PaliGemmaForConditionalGeneration -from colpali_engine.models.paligemma.colpali.modeling_colpali import ColPali from colpali_engine.models.paligemma.colpali_duo.configuration_colpali_duo import ColPaliDuoConfig from colpali_engine.models.paligemma.colpali_duo.modeling_colpali_duo import ColPaliDuo from colpali_engine.utils.torch_utils import get_torch_device @@ -13,27 +13,32 @@ def main(): """ Publish the base ColPaliDuo model to the hub. """ - base_colpali_duo_name = "vidore/colpali-duo-base-0.1" + base_colpali_duo_name = "vidore/colpali-duo-base" device = get_torch_device("auto") + # Load old model old_model = cast( - ColPali, - ColPali.from_pretrained( - "vidore/colpali-v1.2", + PaliGemmaForConditionalGeneration, + PaliGemmaForConditionalGeneration.from_pretrained( + "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, - device_map=device, # or "mps" if on Apple Silicon + device_map=device, ), ).eval() + # Load new model model_config = ColPaliDuoConfig( **old_model.config.to_dict(), single_vector_projector_dim=1024, single_vector_pool_strategy="mean", multi_vector_projector_dim=128, ) + model = ColPaliDuo(config=model_config).to(device).to(old_model.dtype).eval() - model = ColPaliDuo(config=model_config).to(device).to(torch.bfloat16).eval() + # Copy pre-trained weights from old model + model.load_state_dict(old_model.state_dict(), strict=False) + # Push to hub model.push_to_hub(base_colpali_duo_name, private=True) return diff --git a/tests/models/paligemma/colpali_duo/test_modeling_colpali_duo.py b/tests/models/paligemma/colpali_duo/test_modeling_colpali_duo.py index 9a8a30a2..11227a16 100644 --- a/tests/models/paligemma/colpali_duo/test_modeling_colpali_duo.py +++ b/tests/models/paligemma/colpali_duo/test_modeling_colpali_duo.py @@ -35,7 +35,7 @@ def test_load_colpali_duo_from_pretrained(colpali_duo_config: ColPaliDuoConfig) @pytest.fixture(scope="module") def colpali_duo_model_path() -> str: - return "vidore/colpali-duo-base-0.1" + return "vidore/colpali-duo-base" @pytest.fixture(scope="module")