Skip to content

Commit

Permalink
fix: add missing load_state_dict for base ColPaliDuo model
Browse files (browse the repository at this point in the history)
  • Loading branch information
tonywu71 authored and ManuelFay committed Oct 21, 2024
1 parent 444ef54 commit fd81e02
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch
import typer
from transformers.models.paligemma import PaliGemmaForConditionalGeneration

from colpali_engine.models.paligemma.colpali.modeling_colpali import ColPali
from colpali_engine.models.paligemma.colpali_duo.configuration_colpali_duo import ColPaliDuoConfig
from colpali_engine.models.paligemma.colpali_duo.modeling_colpali_duo import ColPaliDuo
from colpali_engine.utils.torch_utils import get_torch_device
Expand All @@ -13,27 +13,32 @@ def main():
"""
Publish the base ColPaliDuo model to the hub.
"""
base_colpali_duo_name = "vidore/colpali-duo-base-0.1"
base_colpali_duo_name = "vidore/colpali-duo-base"
device = get_torch_device("auto")

# Load old model
old_model = cast(
ColPali,
ColPali.from_pretrained(
"vidore/colpali-v1.2",
PaliGemmaForConditionalGeneration,
PaliGemmaForConditionalGeneration.from_pretrained(
"google/paligemma-3b-mix-448",
torch_dtype=torch.bfloat16,
device_map=device, # or "mps" if on Apple Silicon
device_map=device,
),
).eval()

# Load new model
model_config = ColPaliDuoConfig(
**old_model.config.to_dict(),
single_vector_projector_dim=1024,
single_vector_pool_strategy="mean",
multi_vector_projector_dim=128,
)
model = ColPaliDuo(config=model_config).to(device).to(old_model.dtype).eval()

model = ColPaliDuo(config=model_config).to(device).to(torch.bfloat16).eval()
# Copy pre-trained weights from old model
model.load_state_dict(old_model.state_dict(), strict=False)

# Push to hub
model.push_to_hub(base_colpali_duo_name, private=True)

return
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_load_colpali_duo_from_pretrained(colpali_duo_config: ColPaliDuoConfig)

@pytest.fixture(scope="module")
def colpali_duo_model_path() -> str:
return "vidore/colpali-duo-base-0.1"
return "vidore/colpali-duo-base"


@pytest.fixture(scope="module")
Expand Down

0 comments on commit fd81e02

Please sign in to comment.