-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
56b6441
commit 81263ba
Showing
1 changed file
with
95 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import pandas as pd | ||
from typing import Dict, Any | ||
from magneto import Magneto as Magneto_Lib | ||
from bdikit.schema_matching.one2one.base import BaseSchemaMatcher | ||
from bdikit.download import get_cached_model_or_download | ||
|
||
DEFAULT_MAGNETO_MODEL = "magneto-v0.1" | ||
|
||
|
||
class MagnetoBase(BaseSchemaMatcher): | ||
def __init__(self, kwargs: Dict[str, Any] = None): | ||
if kwargs is None: | ||
kwargs = {} | ||
self.magneto = Magneto_Lib(**kwargs) | ||
|
||
def map( | ||
self, | ||
source: pd.DataFrame, | ||
target: pd.DataFrame, | ||
): | ||
raw_matches = self.magneto.get_matches(source, target) | ||
# Initialize result dictionary | ||
result = {} | ||
|
||
# Iterate through the input dictionary | ||
for (source, target), score in raw_matches.items(): | ||
source_column = source[1] | ||
target_column = target[1] | ||
|
||
# Update the result if it's a new source or has a higher score | ||
if ( | ||
source_column not in result | ||
or raw_matches[ | ||
(("source", source_column), ("target", result[source_column])) | ||
] | ||
< score | ||
): | ||
result[source_column] = target_column | ||
|
||
return result | ||
|
||
|
||
class Magneto(MagnetoBase): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
|
||
class MagnetoFT(MagnetoBase): | ||
def __init__( | ||
self, | ||
encoding_mode: str = "header_values_verbose", | ||
model_name: str = DEFAULT_MAGNETO_MODEL, | ||
model_path: str = None, | ||
): | ||
embedding_model = check_magneto_model(model_name, model_path) | ||
kwargs = {"encoding_mode": encoding_mode, "embedding_model": embedding_model} | ||
super().__init__(kwargs) | ||
|
||
|
||
class MagnetoGPT(MagnetoBase): | ||
def __init__(self): | ||
kwargs = {"use_bp_reranker": False, "use_gpt_reranker": True} | ||
super().__init__(kwargs) | ||
|
||
|
||
class MagnetoFTGPT(MagnetoBase): | ||
def __init__( | ||
self, | ||
encoding_mode: str = "header_values_verbose", | ||
model_name: str = DEFAULT_MAGNETO_MODEL, | ||
model_path: str = None, | ||
): | ||
embedding_model = check_magneto_model(model_name, model_path) | ||
kwargs = { | ||
"encoding_mode": encoding_mode, | ||
"embedding_model": embedding_model, | ||
"use_bp_reranker": False, | ||
"use_gpt_reranker": True, | ||
} | ||
super().__init__(kwargs) | ||
|
||
|
||
def check_magneto_model(model_name: str, model_path: str): | ||
if model_name and model_path: | ||
raise ValueError( | ||
"Only one of model_name or model_path should be provided " | ||
"(they are mutually exclusive)" | ||
) | ||
|
||
if model_path: | ||
return model_path | ||
elif model_name: | ||
return get_cached_model_or_download(model_name) | ||
else: | ||
raise ValueError("Either model_name or model_path must be provided") |