diff --git a/t_res/geoparser/linking.py b/t_res/geoparser/linking.py
index b55ffef..7bb6c36 100644
--- a/t_res/geoparser/linking.py
+++ b/t_res/geoparser/linking.py
@@ -6,6 +6,7 @@
 import numpy as np
 import pandas as pd
 from haversine import haversine
+from math import exp
 from tqdm import tqdm
 
 tqdm.pandas()
@@ -16,7 +17,7 @@
 from ..utils import rel_utils
 from ..utils.REL import entity_disambiguation
 from . import ranking
-from ..utils.dataclasses import Mention, MentionCandidates, StringMatchLinks, WikidataLink, MostPopularLink, ByDistanceLink, RelDisambLink, CandidateMatches, CandidateLinks, SentenceCandidates, Predictions
+from ..utils.dataclasses import Mention, MentionCandidates, StringMatchLinks, WikidataLink, MostPopularLink, ByDistanceLink, RelDisambLink, CandidateMatches, CandidateLinks, SentenceCandidates, Predictions, RelPredictions
 
 class Linker:
     """
@@ -106,6 +107,30 @@ def wkdt_coords(self, wqid: str) -> Optional[Tuple[float, float]]:
         """
         return self.resources["wqid_to_coords"].get(wqid, None)
 
+    def haversine(self,
+                  origin_coords: Optional[Tuple[float, float]],
+                  coords: Optional[Tuple[float, float]]) -> Optional[float]:
+        """
+        Calculates the great circle distance between two points on Earth's surface.
+
+        Args:
+            origin_coords (Optional[Tuple[float, float]]): coordinates of the origin
+            coords (Optional[Tuple[float, float]]): coordinates of the other point
+
+        Returns:
+            The great circle distance between the points, or `None` if either pair
+            of coordinates is unavailable.
+        """
+        if not origin_coords:
+            print("Missing place of publication coordinates.")
+            return None
+        try:
+            return haversine(origin_coords, coords, normalize=True)
+        except ValueError:
+            # We have one candidate with coordinates in Venus!
+            print(f"Failed to compute haversine distance from {origin_coords} to {coords}")
+            return None
+
     def empty_candidates(self,
                          mention: Mention,
                          ranking_method: str,
@@ -270,7 +295,7 @@ def disambiguation_scores(self, links: List[WikidataLink], string_similarity: fl
         overriding the `disambiguation_scores` method.
         """
         raise NotImplementedError("Subclass implementation required.")
-    
+
 class MostPopularLinker(Linker):
     """
     An entity linking method that selects the candidate that is most
@@ -376,29 +401,6 @@ def wikidata_links(
                 ]) for wqid in match.wqid_links]
         return links
 
-    def haversine(self,
-                  origin_coords: Optional[Tuple[float, float]],
-                  coords: Optional[Tuple[float, float]]) -> Optional[float]:
-        """
-        Calculates the great circle distance between two points on Earth's surface.
-
-        Args:
-            origin_coords (Optional[Tuple[float, float]]): coordinates of the origin
-            coords (Optional[Tuple[float, float]]): coordinates of the other point
-
-        Returns:
-            The great circle distance between the points, or `None` if either pair
-            of coordinates is unavailable.
-        """
-        if not origin_coords:
-            print("Missing place of publication coordinates.")
-            return None
-        try:
-            return haversine(origin_coords, coords)
-        except ValueError:
-            # We have one candidate with coordinates in Venus!
-            return None
-
     def disambiguation_scores(self,
                               wikidata_links: List[ByDistanceLink],
                               string_similarity: float) -> Dict[str, float]:
@@ -432,11 +434,15 @@ def disambiguation_scores(self,
             ret[link.wqid] = final_score
         return ret
 
-class RelDisambLinker(Linker):
+class RelDisambLinker(MostPopularLinker):
     """
     An entity linking method that selects the candidate using the
     [Radboud Entity Linker](https://github.com/informagi/REL/) (REL) model.
 
+    This is a subclass of the MostPopularLinker so that the disambiguation
+    score based on Wikidata popularity may be used to compute a combined
+    disambiguation score (if configured to do so).
+
     Arguments:
         resources_path (str): The path to the linking resources.
         ranker (Ranker): A `Ranker` instance.
@@ -507,6 +513,7 @@ class RelDisambLinker(Linker):
                 "do_test": False,
                 "default_publname": "United Kingdom",
                 "default_publwqid": "Q145",
+                "reference_separation": ((49.956739, -8.17751), (60.87, 1.762973))
             }
         ```
     """
@@ -526,23 +533,35 @@ def __init__(
         super().__init__(resources_path, experiments_path, linking_resources)
         self.overwrite_training = overwrite_training
 
-        if rel_params is None:
-            rel_params = {
-                "model_path": os.path.join(resources_path, "models/disambiguation/"),
-                "data_path": os.path.join(experiments_path, "outputs/data/lwm/"),
-                "training_split": "originalsplit",
-                "db_embeddings": None,  # The cursor to the embeddings database.
-                "with_publication": True,
-                "without_microtoponyms": True,
-                "do_test": False,
-                "default_publname": "United Kingdom",
-                "default_publwqid": "Q145",
-            }
-        self.rel_params = rel_params
+        # Default linking parameters:
+        params = {
+            "model_path": os.path.join(resources_path, "models/disambiguation/"),
+            "data_path": os.path.join(experiments_path, "outputs/data/lwm/"),
+            "training_split": "originalsplit",
+            "db_embeddings": None,  # The cursor to the embeddings database.
+            "with_publication": True,
+            "predict_place_of_publication": True,
+            "combined_score": True,
+            "without_microtoponyms": True,
+            "do_test": False,
+            "default_publname": "United Kingdom",
+            "default_publwqid": "Q145",
+            "reference_separation": ((49.956739, -8.17751), (60.87, 1.762973)),
+        }
+        if rel_params is not None:
+            if not set(rel_params) <= set(params):
+                raise ValueError("Invalid REL config parameters.")
+            # Update the default parameters with any given parameters.
+            params.update(rel_params)
+
+        self.rel_params = params
         self.ranker = ranker
         self.entity_disambiguation_model = None
 
+        reference_separation = self.rel_params['reference_separation']
+        self.reference_distance = self.haversine(reference_separation[0], reference_separation[1])
+
     def __str__(self) -> str:
         """
         Returns a string representation of the Linker object.
@@ -656,11 +675,21 @@ class implementation to include REL model linking.
             ValueError("Entity disambiguation model not yet loaded. Call `load` method.")
 
         # Apply the REL model to the interim predictions.
-        rel_predictions = self.entity_disambiguation_model.predict(
+        rel_predictions_dict = self.entity_disambiguation_model.predict(
            predictions.as_dict(self.rel_params["with_publication"]))
 
         # Incorporate the REL model predictions.
-        return predictions.apply_rel_disambiguation(rel_predictions, self.rel_params["with_publication"])
+        rel_predictions = predictions.apply_rel_disambiguation(rel_predictions_dict, self.rel_params["with_publication"])
+
+        # Take into account the `predict_place_of_publication` config parameter.
+        if self.rel_params['predict_place_of_publication']:
+            self.predict_place_of_publication(rel_predictions)
+
+        # Take into account the `combined_score` config parameter.
+        if self.rel_params['combined_score']:
+            self.apply_combined_score(rel_predictions)
+
+        return rel_predictions
 
     # Computes disambiguation scores for a collection of potential Wikidata links.
     # (Note: this replaces the rank_candidates function from rel_utils.py)
@@ -696,6 +725,75 @@ def disambiguation_scores(self,
 
         return ret
 
+    def predict_place_of_publication(self, rel_predictions: RelPredictions):
+        """
+        Sets the disambiguation scores for the place of publication to 1.0 inside the given
+        REL predictions, provided the place of publication is known and exists as a candidate link.
+
+        Arguments:
+            rel_predictions: An instance of the `RelPredictions` dataclass.
+        """
+        place_of_pub_wqid = rel_predictions.place_of_pub_wqid()
+        if not place_of_pub_wqid:
+            return
+        for rs in rel_predictions.rel_scores:
+            # If the place of publication is not in the list of scored candidates, do nothing.
+            if place_of_pub_wqid not in rs.scores:
+                return
+            rs.scores[place_of_pub_wqid] = 1.0
+
+    def apply_combined_score(self, rel_predictions: RelPredictions):
+        """
+        Updates all disambiguation scores in the given REL predictions by
+        combining the REL score with place of publication information, if known.
+
+        Arguments:
+            rel_predictions: An instance of the `RelPredictions` dataclass.
+        """
+        place_of_pub_wqid = rel_predictions.place_of_pub_wqid()
+        if not place_of_pub_wqid:
+            return
+
+        def combined_score(rel_score, popularity, proximity):
+            if not proximity:
+                return rel_score
+            return rel_score * max(popularity, proximity)
+
+        # Iterate over the mention candidates and their corresponding REL scores.
+        for mc, rs in zip(rel_predictions.candidates(ignore_empty_candidates=False), rel_predictions.rel_scores):
+            # Iterate over the predicted Wikidata links.
+            for cl in mc.links:
+                # Compute popularity and proximity scores for all Wikidata links.
+                wqids = [wl.wqid for wl in cl.wikidata_links]
+                # Use the MostPopularLinker superclass to compute popularity.
+                popularity = super().disambiguation_scores(cl.wikidata_links)
+                proximity = {wqid: self.proximity(
+                    origin_coords=self.wkdt_coords(place_of_pub_wqid),
+                    coords=self.wkdt_coords(wqid)) for wqid in wqids}
+                combined = {wqid: combined_score(rs.scores[wqid], popularity[wqid], proximity[wqid]) for wqid in wqids}
+                # Update the REL scores.
+                rs.scores.update(combined)
+
+    def proximity(self,
+                  origin_coords: Optional[Tuple[float, float]],
+                  coords: Optional[Tuple[float, float]]) -> Optional[float]:
+        """Computes the proximity measure between pairs of lat-long coordinates.
+
+        Args:
+            origin_coords (Optional[Tuple[float, float]]): coordinates of the place of publication
+            coords (Optional[Tuple[float, float]]): coordinates of the candidate location
+
+        Returns:
+            Optional[float]: the proximity measure, in the interval (0, 1], or `None`
+            if either pair of coordinates (or the distance between them) is unavailable.
+        """
+        if not coords:
+            return None
+        distance = self.haversine(origin_coords, coords)
+        # Handle caught error in the haversine method.
+        if distance is None:
+            return None
+        return exp(-(distance/self.reference_distance)**2)
+
     def train_load_model(self, split: Optional[str] = "originalsplit"):
         """
         Trains or loads the entity disambiguation model and assigns to the
diff --git a/t_res/utils/batch_job.py b/t_res/utils/batch_job.py
index e5ba381..49ff057 100644
--- a/t_res/utils/batch_job.py
+++ b/t_res/utils/batch_job.py
@@ -381,9 +381,14 @@ def results_file(self) -> str:
             suffix += '-' + self.config[LINKER_KEY]['method_name']
         if self.config[LINKER_KEY]['method_name'] == 'reldisamb':
             if self.config[LINKER_KEY]['rel_params']['with_publication']:
-                suffix += '-withpub'
+                if self.config[LINKER_KEY]['rel_params']['predict_place_of_publication']:
+                    suffix += '-predictpub'
+                else:
+                    suffix += '-withpub'
             else:
                 suffix += '-nopub'
+            if self.config[LINKER_KEY]['rel_params']['combined_score']:
+                suffix += '-combined'
             if self.config[LINKER_KEY]['rel_params']['without_microtoponyms']:
                 suffix += '-nomicro'
             else:
@@ -450,7 +455,7 @@ def run(row):
             self.logger.debug(f'Running pipeline on text:\n{row[self.text_colname]}')
             self.logger.debug(f'Place of publication ID:{self.place_of_pub_wqid(row.name)}')
             self.logger.debug(f'Place of publication:\n{self.place_of_pub(row.name)}')
-            self.pipe.run(
+            return self.pipe.run(
                 row[self.text_colname],
                 place_of_pub_wqid=self.place_of_pub_wqid(row.name),
                 place_of_pub=self.place_of_pub(row.name),
diff --git a/t_res/utils/dataclasses.py b/t_res/utils/dataclasses.py
index 1a53fed..dbdd4a0 100644
--- a/t_res/utils/dataclasses.py
+++ b/t_res/utils/dataclasses.py
@@ -348,20 +348,17 @@ def __post_init__(self):
         raise ValueError("normalized_score must be a float.")
 
 @pdataclass(frozen=True)
-class RelDisambLink(WikidataLink):
+class RelDisambLink(MostPopularLink):
     """Dataclass representing a string match and potential links in Wikidata
     under the `reldisamb` linking method.
 
     Attributes:
-        freq (int): The mention-to-wikidata link frequency.
         normalized_score (float): The normalized score from resource
             `mentions_to_wikidata_normalized.json`.
     """
-    freq: int
     normalized_score: float
 
     def __post_init__(self):
-        if not isinstance(self.freq, int):
-            raise ValueError("freq must be an integer.")
+        super().__post_init__()
         if not isinstance(self.normalized_score, float):
             raise ValueError("normalized_score must be a float.")
 
@@ -479,7 +476,7 @@ def scores_as_list(self) -> list:
         Helper method for the Predictions as_dict method."""
         ret = [[k, round(v, 3)] for k, v in self.disambiguation_scores.items()]
         return sorted(ret, key=lambda x: (x[1], x[0]), reverse=True)
-    
+
 # Linker::run method output type.
 @pdataclass(order=True, frozen=True)
 class MentionCandidates:
diff --git a/tests/sample_files/batch_jobs/sample_batch_job_deezyrel.yml b/tests/sample_files/batch_jobs/sample_batch_job_deezyrel.yml
index 74b86e2..4c1693d 100644
--- a/tests/sample_files/batch_jobs/sample_batch_job_deezyrel.yml
+++ b/tests/sample_files/batch_jobs/sample_batch_job_deezyrel.yml
@@ -8,6 +8,8 @@ linker:
   method_name: reldisamb
   rel_params:
     with_publication: True
+    predict_place_of_publication: True
+    combined_score: True
     without_microtoponyms: True
     default_publname: United Kingdom
     default_publwqid: Q145
diff --git a/tests/test_disambiguation.py b/tests/test_disambiguation.py
index 4538a85..1a36b8b 100644
--- a/tests/test_disambiguation.py
+++ b/tests/test_disambiguation.py
@@ -112,6 +112,8 @@ def test_train(tmp_path):
             "training_split": "originalsplit",
             "db_embeddings": cursor,
             "with_publication": False,
+            "predict_place_of_publication": False,
+            "combined_score": False,
             "without_microtoponyms": True,
             "do_test": True,
         },
@@ -202,6 +204,8 @@ def test_load_eval_model(tmp_path):
             "training_split": "originalsplit",
             "db_embeddings": cursor,
             "with_publication": False,
+            "predict_place_of_publication": False,
+            "combined_score": False,
             "without_microtoponyms": False,
             "do_test": True,
         },
@@ -291,6 +295,8 @@ def test_predict(tmp_path):
             "training_split": "originalsplit",
             "db_embeddings": cursor,
             "with_publication": True,
+            "predict_place_of_publication": False,
+            "combined_score": False,
             "without_microtoponyms": True,
             "do_test": False,
         },
diff --git a/tests/test_linking.py b/tests/test_linking.py
index 4054d0e..5750e0f 100644
--- a/tests/test_linking.py
+++ b/tests/test_linking.py
@@ -1,6 +1,7 @@
 import os
 from pathlib import Path
-
+import sqlite3
+from math import exp
 import numpy as np
 import pytest
 
@@ -37,7 +38,7 @@ def test_init():
         ranker=ranking.PerfectMatchRanker("path/to/resources/"),
         experiments_path="path/to/experiments/",
         linking_resources={'resource': 'value'},
-        rel_params={'param': 'value'},
+        rel_params={'with_publication': False},
         overwrite_training=True,
     )
 
@@ -46,19 +47,61 @@
     assert linker.resources_path == "path/to/resources/"
     assert linker.experiments_path == "path/to/experiments/"
     assert linker.resources['resource'] == 'value'
-    assert linker.rel_params['param'] == 'value'
+    assert linker.rel_params['with_publication'] == False
     assert linker.overwrite_training
 
     linker = RelDisambLinker(
         resources_path="path/to/resources/",
         ranker=ranking.PerfectMatchRanker("path/to/resources/"),
         experiments_path="path/to/experiments/",
-        rel_params={'param': 'value'},
+        rel_params={'with_publication': False},
         linking_resources={'resource': 'value'},
     )
 
     assert not linker.overwrite_training
 
+    # Test default REL linker parameters
+
+    # Invalid parameter raises ValueError:
+    with pytest.raises(ValueError):
+        linker = RelDisambLinker(
+            resources_path="path/to/resources/",
+            ranker=ranking.PerfectMatchRanker("path/to/resources/"),
+            experiments_path="path/to/experiments/",
+            rel_params={'invalid_param': 'value'},
+            linking_resources={'resource': 'value'},
+        )
+
+    linker = RelDisambLinker(
+        resources_path="path/to/resources/",
+        ranker=ranking.PerfectMatchRanker("path/to/resources/"),
+        experiments_path="path/to/experiments/",
+        linking_resources={'resource': 'value'},
+    )
+
+    # Expect default parameter values:
+    assert linker.rel_params['with_publication'] == True
+    assert linker.rel_params['do_test'] == False
+    assert linker.rel_params["without_microtoponyms"] == True
+
+    linker = RelDisambLinker(
resources_path="path/to/resources/", + ranker=ranking.PerfectMatchRanker("path/to/resources/"), + experiments_path="path/to/experiments/", + rel_params={ + 'with_publication': False, + 'do_test': True, + }, + linking_resources={'resource': 'value'}, + ) + + # Default parameter values are overridden: + assert linker.rel_params['with_publication'] == False + assert linker.rel_params['do_test'] == True + # Unspecified parameters have default values: + assert linker.rel_params["without_microtoponyms"] == True + def test_new(): # Test Linker construction via string parameters. @@ -366,3 +409,46 @@ def test_linking_by_distance(): assert predictions.is_empty(ignore_empty_candidates=True) # If empty candidates are not ignored, the set of predictions is not empty: assert not predictions.is_empty(ignore_empty_candidates=False) + +@pytest.mark.resources(reason="Needs large resources") +def test_proximity(): + + with sqlite3.connect(os.path.join(current_dir, "../resources/rel_db/embeddings_database.db")) as conn: + cursor = conn.cursor() + linker = RelDisambLinker( + resources_path=os.path.join(current_dir, "../resources/"), + ranker=ranking.PerfectMatchRanker(os.path.join(current_dir, "../resources/")), + linking_resources=dict(), + rel_params={ + "model_path": os.path.join(current_dir, "../resources/models/disambiguation/"), + "data_path": os.path.join(current_dir, "sample_files/experiments/outputs/data/lwm/"), + "training_split": "apply", + "db_embeddings": cursor, + "with_publication": True, + "without_microtoponyms": False, + "do_test": False, + "default_publname": "United Kingdom", + "default_publwqid": "Q145", + "reference_separation": ((49.956739, -8.17751), (60.87, 1.762973)), + }, + ) + linker.load() + + place_of_pub_wqid = "Q203349" # Poole, Doset + wqid = "Q503331" # Dorchester, Dorset + + result = linker.proximity(linker.wkdt_coords(place_of_pub_wqid), linker.wkdt_coords(wqid)) + + # Distance from Poole to Dorchester is ~31km + d = 31.0 + # Reference distance is ~1362km + reference_d = 1362.0 + + assert result == pytest.approx(exp(-(d/reference_d)**2), abs=1e-4) + + # Test with specific coordinates that require normalization. 
+        origin_coords = [53.067, -2.522]
+        coords = [-24.84, 340.47]
+
+        result = linker.proximity(origin_coords, coords)
+        assert result == pytest.approx(0, abs=1e-10)
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index f3ba7c3..5d80768 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -221,6 +221,8 @@ def test_deezy_rel_wpubl_wmtops(tmp_path):
                 "training_split": "originalsplit",
                 "db_embeddings": cursor,
                 "with_publication": True,
+                "predict_place_of_publication": False,
+                "combined_score": False,
                 "without_microtoponyms": True,
                 "do_test": False,
                 "default_publname": "United Kingdom",
@@ -348,6 +350,8 @@ def test_deezy_rel_wpubl(tmp_path):
                 "training_split": "originalsplit",
                 "db_embeddings": cursor,
                 "with_publication": True,
+                "predict_place_of_publication": False,
+                "combined_score": False,
                 "without_microtoponyms": False,
                 "do_test": False,
                 "default_publname": "United Kingdom",
@@ -473,6 +477,173 @@ def test_perfect_rel_wpubl_wmtops():
     assert predictions.rel_scores[2].confidence == pytest.approx(0.178, abs=1e-3)
     assert predictions.rel_scores[2].scores["Q84"] == pytest.approx(0.573, abs=1e-3)
 
+@pytest.mark.resources(reason="Needs large resources")
+def test_perfect_rel_predict_place_of_pub():
+    model_path = os.path.join(current_dir, "../resources/models/")
+    assert os.path.isdir(model_path) is True
+
+    recogniser = ner.CustomRecogniser(
+        model_name="blb_lwm-ner-fine",
+        train_dataset=os.path.join(current_dir, "sample_files/experiments/outputs/data/lwm/ner_fine_train.json"),
+        test_dataset=os.path.join(current_dir, "sample_files/experiments/outputs/data/lwm/ner_fine_dev.json"),
+        pipe=None,
+        base_model="khosseini/bert_1760_1900",  # Base model to fine-tune
+        model_path=model_path,
+        training_args={
+            "batch_size": 8,
+            "num_train_epochs": 1,
+            "learning_rate": 0.00005,
+            "weight_decay": 0.0,
+        },
+        overwrite_training=False,  # Set to True if you want to overwrite model if existing
+        do_test=False,  # Set to True if you want to train on test mode
+    )
+
+    # --------------------------------------
+    # Instantiate the ranker:
+    ranker = ranking.PerfectMatchRanker(
+        resources_path=os.path.join(current_dir, "../resources/"),
+        mentions_to_wikidata=dict(),
+        wikidata_to_mentions=dict(),
+    )
+
+    with sqlite3.connect(os.path.join(current_dir, "../resources/rel_db/embeddings_database.db")) as conn:
+        cursor = conn.cursor()
+        linker = linking.RelDisambLinker(
+            resources_path=os.path.join(current_dir, "../resources/"),
+            ranker=ranker,
+            linking_resources=dict(),
+            rel_params={
+                "model_path": os.path.join(current_dir, "../resources/models/disambiguation/"),
+                "data_path": os.path.join(current_dir, "sample_files/experiments/outputs/data/lwm/"),
+                "training_split": "originalsplit",
+                "db_embeddings": cursor,
+                "with_publication": True,
+                "predict_place_of_publication": False,
+                "combined_score": False,
+                "without_microtoponyms": True,
+                "do_test": True,
+                "default_publname": "United Kingdom",
+                "default_publwqid": "Q145",
+            },
+            overwrite_training=False,
+        )
+
+        geoparser = pipeline.Pipeline(recogniser=recogniser, ranker=ranker, linker=linker)
+
+        predictions = geoparser.run(
+            "A remarkable case of rattening has just occurred in the building trade at Stockton, but also in Leeds. Not in London though.",
+            place_of_pub_wqid="Q989418",
+            place_of_pub="Stockton-on-Tees, Cleveland, England",
+        )
+
+        assert isinstance(predictions, RelPredictions)
+        assert len(predictions.candidates()) == 3
+
+        # With "predict_place_of_publication" set to False, the wrong Stockton is predicted:
+        assert predictions.candidates()[0].best_wqid() != "Q989418"
+
+        geoparser.linker.rel_params["predict_place_of_publication"] = True
+
+        predictions = geoparser.run(
+            "A remarkable case of rattening has just occurred in the building trade at Stockton, but also in Leeds. Not in London though.",
+            place_of_pub_wqid="Q989418",
+            place_of_pub="Stockton-on-Tees, Cleveland, England",
+        )
+
+        assert isinstance(predictions, RelPredictions)
+        assert len(predictions.candidates()) == 3
+
+        # With "predict_place_of_publication" set to True, the correct Stockton is predicted
+        # because the place of publication is the favoured candidate:
+        assert predictions.candidates()[0].best_wqid() == "Q989418"
+
+@pytest.mark.resources(reason="Needs large resources")
+def test_perfect_rel_combined_score():
+    model_path = os.path.join(current_dir, "../resources/models/")
+    assert os.path.isdir(model_path) is True
+
+    recogniser = ner.CustomRecogniser(
+        model_name="blb_lwm-ner-fine",
+        train_dataset=os.path.join(current_dir, "sample_files/experiments/outputs/data/lwm/ner_fine_train.json"),
+        test_dataset=os.path.join(current_dir, "sample_files/experiments/outputs/data/lwm/ner_fine_dev.json"),
+        pipe=None,
+        base_model="khosseini/bert_1760_1900",  # Base model to fine-tune
+        model_path=model_path,
+        training_args={
+            "batch_size": 8,
+            "num_train_epochs": 1,
+            "learning_rate": 0.00005,
+            "weight_decay": 0.0,
+        },
+        overwrite_training=False,  # Set to True if you want to overwrite model if existing
+        do_test=False,  # Set to True if you want to train on test mode
+    )
+
+    # --------------------------------------
+    # Instantiate the ranker:
+    ranker = ranking.PerfectMatchRanker(
+        resources_path=os.path.join(current_dir, "../resources/"),
+        mentions_to_wikidata=dict(),
+        wikidata_to_mentions=dict(),
+    )
+
+    with sqlite3.connect(os.path.join(current_dir, "../resources/rel_db/embeddings_database.db")) as conn:
+        cursor = conn.cursor()
+        linker = linking.RelDisambLinker(
+            resources_path=os.path.join(current_dir, "../resources/"),
+            ranker=ranker,
+            linking_resources=dict(),
+            rel_params={
+                "model_path": os.path.join(current_dir, "../resources/models/disambiguation/"),
+                "data_path": os.path.join(current_dir, "sample_files/experiments/outputs/data/lwm/"),
+                "training_split": "originalsplit",
+                "db_embeddings": cursor,
+                "with_publication": True,
+                "predict_place_of_publication": False,
+                "combined_score": False,
+                "without_microtoponyms": True,
+                "do_test": True,
+                "default_publname": "United Kingdom",
+                "default_publwqid": "Q145",
+                "reference_separation": ((49.956739, -8.17751), (60.87, 1.762973)),
+            },
+            overwrite_training=False,
+        )
+
+        geoparser = pipeline.Pipeline(recogniser=recogniser, ranker=ranker, linker=linker)
+
+        predictions = geoparser.run(
+            "A remarkable case of rattening has just occurred in the building trade at Stockton, but also in Leeds.",
+            place_of_pub_wqid="Q39121",
+            place_of_pub="Leeds, West Yorkshire, England",
+        )
+
+        assert isinstance(predictions, RelPredictions)
+        assert len(predictions.candidates()) == 2
+
+        # With "combined_score" set to False, the wrong Stockton is predicted:
+        assert predictions.candidates()[0].best_wqid() != "Q989418"
+        assert predictions.candidates()[0].best_wqid() == "Q49240"
+        assert predictions.candidates()[0].best_disambiguation_score() == pytest.approx(0.225, abs=1e-3)
+
+        geoparser.linker.rel_params["combined_score"] = True
+
+        predictions = geoparser.run(
+            "A remarkable case of rattening has just occurred in the building trade at Stockton, but also in Leeds.",
+            place_of_pub_wqid="Q39121",
+            place_of_pub="Leeds, West Yorkshire, England",
+        )
+
+        assert isinstance(predictions, RelPredictions)
+        assert len(predictions.candidates()) == 2
+
+        # With "combined_score" set to True, the correct Stockton is predicted
+        # because the disambiguation score for the previous best candidate
+        # is curtailed by the combined score:
+        assert predictions.candidates()[0].best_wqid() == "Q989418"
+        assert predictions.candidates()[0].best_disambiguation_score() == pytest.approx(0.21, abs=1e-3)
+
 @pytest.mark.resources(reason="Needs large resources")
 def test_modular_deezy_rel(tmp_path):
     model_path = os.path.join(current_dir, "../resources/models/")
@@ -540,6 +711,8 @@ def test_modular_deezy_rel(tmp_path):
                 "training_split": "apply",
                 "db_embeddings": cursor,
                 "with_publication": True,
+                "predict_place_of_publication": False,
+                "combined_score": False,
                 "without_microtoponyms": False,
                 "do_test": False,
                 "default_publname": "United Kingdom",
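For context on the scoring logic this patch introduces, here is a minimal, self-contained sketch (not part of the patch itself) of the two new calculations: the proximity decay `exp(-(d/d_ref)**2)` implemented by `proximity`, and the `rel_score * max(popularity, proximity)` rule implemented by `apply_combined_score`. The Leeds and Stockton coordinates, the REL score of 0.5, and the popularity of 0.1 are made-up illustration values; only the `haversine` package call and the default `reference_separation` are taken from the diff.

```python
from math import exp

from haversine import haversine  # same package the patch imports

# Default reference separation from the patch: roughly the SW-NE extent
# of Great Britain, giving a reference distance of ~1362 km.
REFERENCE_SEPARATION = ((49.956739, -8.17751), (60.87, 1.762973))
REFERENCE_DISTANCE = haversine(*REFERENCE_SEPARATION)


def proximity(origin_coords, coords):
    """Gaussian-style decay: 1.0 at zero distance, ~0.37 at the
    reference distance, effectively 0 far beyond it."""
    distance = haversine(origin_coords, coords, normalize=True)
    return exp(-(distance / REFERENCE_DISTANCE) ** 2)


def combined_score(rel_score, popularity, proximity_score):
    """Mirrors the patch's combined_score helper: boost the REL score by
    the larger of the popularity and proximity signals."""
    if not proximity_score:
        return rel_score
    return rel_score * max(popularity, proximity_score)


if __name__ == "__main__":
    leeds = (53.7997, -1.5492)             # hypothetical place of publication
    stockton_on_tees = (54.5705, -1.3187)  # candidate roughly 90 km away
    stockton_ca = (37.9577, -121.2908)     # candidate roughly 7800 km away

    for name, coords in [("Stockton-on-Tees", stockton_on_tees),
                         ("Stockton, CA", stockton_ca)]:
        p = proximity(leeds, coords)
        # Assume both candidates receive REL score 0.5 and popularity 0.1:
        print(f"{name}: proximity={p:.4f}, combined={combined_score(0.5, 0.1, p):.4f}")
```

With these illustrative numbers, the nearby Stockton-on-Tees keeps essentially its full REL score (proximity is close to 1, so the combined score stays near 0.5), while Stockton, California decays to a proximity near 0 and falls back to the popularity-scaled 0.05, which is the curtailing effect that `test_perfect_rel_combined_score` asserts.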