
feat: Streamline the addition of new standards (#88)
roquelopez authored Jan 13, 2025
1 parent eaa4a5d commit 03c8edb
Showing 16 changed files with 20,028 additions and 161,245 deletions.
17 changes: 15 additions & 2 deletions CONTRIBUTING.md
@@ -39,12 +39,25 @@ Contributors can add new methods for schema and value matching by following these steps:

 2. Define a class in the module that implements either `BaseValueMatcher` (for value matching) or `BaseSchemaMatcher` (for schema matching).

-3. Add a new entry in `matcher_factory.py` (e.g., `bdikit/value_matching/matcher_factory.py`). Make sure to add the correct import path for your
-   module to ensure it can be accessed without errors.
+3. Add a new entry to the Enum class (e.g. `ValueMatchers`) in `matcher_factory.py` (e.g., `bdikit/value_matching/matcher_factory.py`).
+   Make sure to add the correct import path for your module to ensure it can be accessed without errors.


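For illustration, a minimal value matcher following these steps might look like the sketch below. The class name, module path, and `match()` signature are hypothetical; only `BaseValueMatcher`, `ValueMatch`, and the `ValueMatchers` Enum are names that appear in this commit, so check `bdikit/value_matching/base.py` for the real interface.

```python
# Hypothetical module: bdikit/value_matching/exact.py
# Assumes BaseValueMatcher expects a match(source_values, target_values)
# method returning ValueMatch entries of (source, target, similarity);
# verify against bdikit/value_matching/base.py before using.
from typing import List

from bdikit.value_matching.base import BaseValueMatcher, ValueMatch


class ExactValueMatcher(BaseValueMatcher):
    """Toy matcher that pairs values that are equal ignoring case."""

    def match(
        self, source_values: List[str], target_values: List[str]
    ) -> List[ValueMatch]:
        targets = {t.lower(): t for t in target_values}
        matches = []
        for source in source_values:
            target = targets.get(source.lower())
            if target is not None:
                # An exact match gets the maximum similarity score of 1.0.
                matches.append(ValueMatch(source, target, 1.0))
        return matches
```

Registering the matcher would then be one new Enum entry in `bdikit/value_matching/matcher_factory.py`, e.g. something like `EXACT = ("exact", "bdikit.value_matching.exact.ExactValueMatcher")` (the entry layout is illustrative).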
+Adding New Standards
+--------------------
+
+Contributors can extend bdi-kit to additional standards by following these steps:
+
+1. Create a Python module inside the "standards" folder (`bdikit/standards`).
+
+2. Define a class in the module that implements `BaseStandard`.
+
+3. Add a new entry to the class `Standards(Enum)` in `bdikit/standards/standard_factory.py`. Make sure to add the correct import path for your
+   module to ensure it can be accessed without errors.
+
+
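A matching sketch for a new standard, with the same caveats: the three methods below are the ones this commit calls on a standard object in `bdikit/api.py` (`get_dataframe_rep`, `get_column_values`, `get_column_metadata`), but the module name, class name, and the `bdikit.standards.base` import path are assumptions, and `BaseStandard` may require additional methods.

```python
# Hypothetical module: bdikit/standards/my_standard.py
# Only the three methods used by bdikit/api.py in this commit are shown;
# the BaseStandard import path is an assumption.
from typing import Dict, List

import pandas as pd

from bdikit.standards.base import BaseStandard


class MyStandard(BaseStandard):
    """Toy standard backed by a small in-memory vocabulary."""

    def __init__(self):
        self._df = pd.DataFrame({"gender": ["female", "male", "unknown"]})

    def get_dataframe_rep(self) -> pd.DataFrame:
        # Tabular representation of the standard's vocabulary.
        return self._df

    def get_column_values(self, column_names: List[str]) -> Dict[str, List[str]]:
        # Allowed values for each requested column.
        return {name: self._df[name].unique().tolist() for name in column_names}

    def get_column_metadata(self, column_names: List[str]) -> Dict[str, Dict]:
        # Mirrors the keys that preview_domain reads: description,
        # value_names, and value_descriptions.
        return {
            name: {
                "description": f"Allowed values for '{name}'.",
                "value_names": self._df[name].unique().tolist(),
                "value_descriptions": [""] * self._df[name].nunique(),
            }
            for name in column_names
        }
```

The new entry in `Standards(Enum)` would then point at `bdikit.standards.my_standard.MyStandard`.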
 Code of Conduct
 ---------------

 We abide by the principles of openness, respect, and consideration of others
 of the Python Software Foundation: https://www.python.org/psf/codeofconduct/.
54 changes: 27 additions & 27 deletions bdikit/api.py
@@ -1,31 +1,19 @@
 from __future__ import annotations
 import logging

-from collections import defaultdict
-from os.path import join, dirname
-from typing import (
-    Union,
-    List,
-    Dict,
-    TypedDict,
-    Optional,
-    Tuple,
-    Callable,
-    Any,
-)
 import itertools
 import pandas as pd
 import numpy as np
 import panel as pn
 from IPython.display import display, Markdown
-from bdikit.utils import get_gdc_data, get_gdc_metadata

 from bdikit.schema_matching.one2one.base import BaseSchemaMatcher
 from bdikit.schema_matching.one2one.matcher_factory import SchemaMatchers
 from bdikit.schema_matching.topk.base import BaseTopkSchemaMatcher
 from bdikit.schema_matching.topk.matcher_factory import TopkMatchers
 from bdikit.value_matching.base import BaseValueMatcher, ValueMatch, ValueMatchingResult
 from bdikit.value_matching.matcher_factory import ValueMatchers
+from bdikit.standards.standard_factory import Standards

 from bdikit.mapping_functions import (
     ValueMapper,
@@ -34,11 +22,21 @@
     IdentityValueMapper,
 )

+from typing import (
+    Union,
+    List,
+    Dict,
+    TypedDict,
+    Optional,
+    Tuple,
+    Callable,
+    Any,
+)
+
+from bdikit.config import DEFAULT_SCHEMA_MATCHING_METHOD, DEFAULT_VALUE_MATCHING_METHOD
+
 pn.extension("tabulator")

-GDC_DATA_PATH = join(dirname(__file__), "./resource/gdc_table.csv")
-DEFAULT_VALUE_MATCHING_METHOD = "tfidf"
-DEFAULT_SCHEMA_MATCHING_METHOD = "coma"
 logger = logging.getLogger(__name__)


@@ -92,10 +90,10 @@ def _load_table_for_standard(name: str) -> pd.DataFrame:
     Load the table for the given standard data vocabulary. Currently, only the
     GDC standard is supported.
     """
-    if name == "gdc":
-        return pd.read_csv(GDC_DATA_PATH)
-    else:
-        raise ValueError(f"The {name} standard is not supported")
+    standard = Standards.get_standard(name)
+    df = standard.get_dataframe_rep()
+
+    return df


 def top_matches(
@@ -439,9 +437,10 @@ def _format_value_matching_input(
             f"The source column '{source_column}' is not present in the source dataset."
         )

-    if isinstance(target, str) and target == "gdc":
+    if isinstance(target, str):
         column_names = mapping_df["target"].unique().tolist()
-        target_domain = get_gdc_data(column_names)
+        standard = Standards.get_standard(target)
+        target_domain = standard.get_column_values(column_names)
     elif isinstance(target, pd.DataFrame):
         target_domain = {
             column_name: target[column_name].unique().tolist()
@@ -518,11 +517,12 @@ def preview_domain(
     (if applicable).
     """

-    if isinstance(dataset, str) and dataset == "gdc":
-        gdc_metadata = get_gdc_metadata()
-        value_names = gdc_metadata[column]["value_names"]
-        value_descriptions = gdc_metadata[column]["value_descriptions"]
-        column_description = gdc_metadata[column]["description"]
+    if isinstance(dataset, str):
+        standard = Standards.get_standard(dataset)
+        column_metadata = standard.get_column_metadata([column])
+        value_names = column_metadata[column]["value_names"]
+        value_descriptions = column_metadata[column]["value_descriptions"]
+        column_description = column_metadata[column]["description"]
         assert len(value_names) == len(value_descriptions)
     elif isinstance(dataset, pd.DataFrame):
         value_names = dataset[column].unique()
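Taken together, the `bdikit/api.py` changes route every standard-specific lookup through the factory instead of hard-coded GDC branches. A usage sketch, assuming "gdc" remains a registered standard name (the column names here are illustrative):

```python
from bdikit.standards.standard_factory import Standards

# Resolve a standard by name instead of special-casing "gdc".
gdc = Standards.get_standard("gdc")

# Full tabular representation (replaces pd.read_csv(GDC_DATA_PATH)).
table = gdc.get_dataframe_rep()

# Allowed values for selected columns (replaces get_gdc_data).
values = gdc.get_column_values(["gender", "race"])

# Description and value metadata for one column (replaces get_gdc_metadata).
metadata = gdc.get_column_metadata(["gender"])
print(metadata["gender"]["description"])
```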
2 changes: 2 additions & 0 deletions bdikit/config.py
@@ -6,6 +6,8 @@

 BDIKIT_DEVICE: str = os.getenv("BDIKIT_DEVICE", default="cpu")
 VALUE_MATCHING_THRESHOLD = 0.3
+DEFAULT_VALUE_MATCHING_METHOD = "tfidf"
+DEFAULT_SCHEMA_MATCHING_METHOD = "coma"


 def get_device() -> str:
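With the defaults now defined once in `bdikit/config.py`, callers import them instead of redefining them, as the `bdikit/api.py` hunk above does:

```python
from bdikit.config import (
    DEFAULT_SCHEMA_MATCHING_METHOD,  # "coma"
    DEFAULT_VALUE_MATCHING_METHOD,  # "tfidf"
)

print(DEFAULT_SCHEMA_MATCHING_METHOD, DEFAULT_VALUE_MATCHING_METHOD)
```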
5 changes: 2 additions & 3 deletions bdikit/models/contrastive_learning/cl_api.py
@@ -1,4 +1,3 @@
-import os
 from typing import List, Dict, Tuple, Optional
 from bdikit.config import get_device
 import numpy as np
@@ -13,7 +12,7 @@
 from sklearn.metrics.pairwise import cosine_similarity
 from tqdm.auto import tqdm
 from bdikit.download import get_cached_model_or_download
-from bdikit.utils import check_gdc_cache, write_embeddings_to_cache
+from bdikit.utils import check_embedding_cache, write_embeddings_to_cache
 from bdikit.models import ColumnEmbedder


@@ -108,7 +107,7 @@ def _sample_to_15_rows(self, table: pd.DataFrame):

     def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:

-        embedding_file, embeddings = check_gdc_cache(table, self.model_path)
+        embedding_file, embeddings = check_embedding_cache(table, self.model_path)

         if embeddings != None:
             print(f"Table features loaded for {len(table.columns)} columns")
(Diffs for the remaining 12 changed files are not shown here.)
