Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New linking algorithm combining place of publication & REL scores #290

Open
wants to merge 8 commits into
base: 276-refactor
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 139 additions & 41 deletions t_res/geoparser/linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
from haversine import haversine
from math import exp
from tqdm import tqdm

tqdm.pandas()
Expand All @@ -16,7 +17,7 @@
from ..utils import rel_utils
from ..utils.REL import entity_disambiguation
from . import ranking
from ..utils.dataclasses import Mention, MentionCandidates, StringMatchLinks, WikidataLink, MostPopularLink, ByDistanceLink, RelDisambLink, CandidateMatches, CandidateLinks, SentenceCandidates, Predictions
from ..utils.dataclasses import Mention, MentionCandidates, StringMatchLinks, WikidataLink, MostPopularLink, ByDistanceLink, RelDisambLink, CandidateMatches, CandidateLinks, SentenceCandidates, Predictions, RelPredictions

class Linker:
"""
Expand Down Expand Up @@ -106,6 +107,30 @@ def wkdt_coords(self, wqid: str) -> Optional[Tuple[float, float]]:
"""
return self.resources["wqid_to_coords"].get(wqid, None)

def haversine(self,
              origin_coords: Optional[Tuple[float, float]],
              coords: Optional[Tuple[float, float]]) -> Optional[float]:
    """
    Calculates the great circle distance between two points on Earth's surface.

    Args:
        origin_coords (Optional[Tuple[float, float]]): coordinates of the
            origin point (latitude, longitude), if available.
        coords (Optional[Tuple[float, float]]): coordinates of the other
            point, if available.

    Returns:
        The great circle distance between the points, or `None` if the origin
        coordinates are unavailable or the distance cannot be computed.
    """
    if not origin_coords:
        print("Missing place of publication coordinates.")
        return None
    try:
        distance = haversine(origin_coords, coords, normalize=True)
    except ValueError:
        # We have one candidate with coordinates in Venus!
        print(f"Failed to compute haversine distance from {origin_coords} to {coords}")
        return None
    return distance

def empty_candidates(self,
mention: Mention,
ranking_method: str,
Expand Down Expand Up @@ -270,7 +295,7 @@ def disambiguation_scores(self, links: List[WikidataLink], string_similarity: fl
overriding the `disambiguation_scores` method.
"""
raise NotImplementedError("Subclass implementation required.")

class MostPopularLinker(Linker):
"""
An entity linking method that selects the candidate that is most
Expand Down Expand Up @@ -376,29 +401,6 @@ def wikidata_links(
]) for wqid in match.wqid_links]
return links

def haversine(self,
              origin_coords: Optional[Tuple[float, float]],
              coords: Optional[Tuple[float, float]]) -> Optional[float]:
    """
    Calculates the great circle distance between two points on Earth's surface.

    Args:
        origin_coords (Optional[Tuple[float, float]]): coordinates of the
            origin point (latitude, longitude), typically the place of
            publication.
        coords (Optional[Tuple[float, float]]): coordinates of the other point.

    Returns:
        The great circle distance between the points, or `None` if either pair
        of coordinates is unavailable or the distance cannot be computed.
    """
    if not origin_coords:
        # Without an origin there is nothing to measure from: report and bail out.
        print("Missing place of publication coordinates.")
        return None
    try:
        return haversine(origin_coords, coords)
    except ValueError:
        # We have one candidate with coordinates in Venus!
        # The haversine library rejects out-of-range lat/long values.
        return None

def disambiguation_scores(self,
wikidata_links: List[ByDistanceLink],
string_similarity: float) -> Dict[str, float]:
Expand Down Expand Up @@ -432,11 +434,15 @@ def disambiguation_scores(self,
ret[link.wqid] = final_score
return ret

class RelDisambLinker(Linker):
class RelDisambLinker(MostPopularLinker):
"""
An entity linking method that selects the candidate using the [Radboud
Entity Linker](https://github.com/informagi/REL/) (REL) model.

This is a subclass of the MostPopularLinker so that the disambiguation
score based on Wikidata popularity may be used to compute a combined
disambiguation score (if configured to do so).

Arguments:
resources_path (str): The path to the linking resources.
ranker (Ranker): A `Ranker` instance.
Expand Down Expand Up @@ -507,6 +513,7 @@ class RelDisambLinker(Linker):
"do_test": False,
"default_publname": "United Kingdom",
"default_publwqid": "Q145",
"reference_separation": ((49.956739, -8.17751), (60.87, 1.762973))
}
```
"""
Expand All @@ -526,23 +533,35 @@ def __init__(
super().__init__(resources_path, experiments_path, linking_resources)

self.overwrite_training = overwrite_training
if rel_params is None:
rel_params = {
"model_path": os.path.join(resources_path, "models/disambiguation/"),
"data_path": os.path.join(experiments_path, "outputs/data/lwm/"),
"training_split": "originalsplit",
"db_embeddings": None, # The cursor to the embeddings database.
"with_publication": True,
"without_microtoponyms": True,
"do_test": False,
"default_publname": "United Kingdom",
"default_publwqid": "Q145",
}

self.rel_params = rel_params
# Default linking parameters:
params = {
"model_path": os.path.join(resources_path, "models/disambiguation/"),
"data_path": os.path.join(experiments_path, "outputs/data/lwm/"),
"training_split": "originalsplit",
"db_embeddings": None, # The cursor to the embeddings database.
"with_publication": True,
"predict_place_of_publication": True,
"combined_score": True,
"without_microtoponyms": True,
"do_test": False,
"default_publname": "United Kingdom",
"default_publwqid": "Q145",
"reference_separation": ((49.956739, -8.17751), (60.87, 1.762973)),
}
if not rel_params is None:
if not set(rel_params) <= set(params):
raise ValueError("Invalid REL config parameters.")
# Update the default parameters with any given parameters.
params.update(rel_params)

self.rel_params = params
self.ranker = ranker
self.entity_disambiguation_model = None

reference_separation = self.rel_params['reference_separation']
self.reference_distance = self.haversine(reference_separation[0], reference_separation[1])

def __str__(self) -> str:
"""
Returns a string representation of the Linker object.
Expand Down Expand Up @@ -656,11 +675,21 @@ class implementation to include REL model linking.
ValueError("Entity disambiguation model not yet loaded. Call `load` method.")

# Apply the REL model to the interim predictions.
rel_predictions = self.entity_disambiguation_model.predict(
rel_predictions_dict = self.entity_disambiguation_model.predict(
predictions.as_dict(self.rel_params["with_publication"]))

# Incorporate the REL model predictions.
return predictions.apply_rel_disambiguation(rel_predictions, self.rel_params["with_publication"])
rel_predictions = predictions.apply_rel_disambiguation(rel_predictions_dict, self.rel_params["with_publication"])

# Take into account the `predict_place_of_pub` config parameter.
if self.rel_params['predict_place_of_publication']:
self.predict_place_of_publication(rel_predictions)

# Take into account the `combined_score` config parameter.
if self.rel_params['combined_score']:
self.apply_combined_score(rel_predictions)

return rel_predictions

# Computes disambiguation scores for a collection of potential Wikidata links.
# (Note: this replaces the rank_candidates function from rel_utils.py)
Expand Down Expand Up @@ -696,6 +725,75 @@ def disambiguation_scores(self,

return ret

def predict_place_of_publication(self, rel_predictions: RelPredictions):
    """
    Sets the disambiguation score for the place of publication to 1.0 inside the
    given REL predictions, provided the place of publication is known and exists
    as a candidate link.

    Bug fix: the previous version used ``return`` when one mention's candidates
    did not include the place of publication, which aborted processing of all
    *remaining* mentions in the sentence. It now skips only that mention.

    Arguments:
        rel_predictions: An instance of the `RelPredictions` dataclass,
            mutated in place.
    """
    place_of_pub_wqid = rel_predictions.place_of_pub_wqid()
    if not place_of_pub_wqid:
        # Place of publication unknown: nothing to boost.
        return
    for rs in rel_predictions.rel_scores:
        # If the place of publication is not in this mention's list of scored
        # candidates, leave this mention untouched and continue with the next.
        if place_of_pub_wqid not in rs.scores:
            continue
        rs.scores[place_of_pub_wqid] = 1.0

def apply_combined_score(self, rel_predictions: RelPredictions):
    """
    Updates all disambiguation scores in the given REL predictions by
    combining the REL score with place of publication information, if known.

    Arguments:
        rel_predictions: An instance of the `RelPredictions` dataclass,
            mutated in place.
    """
    place_of_pub_wqid = rel_predictions.place_of_pub_wqid()
    if not place_of_pub_wqid:
        # No known place of publication: leave the REL scores untouched.
        return

    def _combine(rel_score, popularity, proximity):
        # Fall back to the plain REL score when no proximity is available.
        if not proximity:
            return rel_score
        return rel_score * max(popularity, proximity)

    mention_candidates = rel_predictions.candidates(ignore_empty_candidates=False)
    # Walk the mention candidates alongside their corresponding REL scores.
    for candidates, rel_scores in zip(mention_candidates, rel_predictions.rel_scores):
        for candidate_links in candidates.links:
            wikidata_links = candidate_links.wikidata_links
            # Popularity scores come from the MostPopularLinker superclass.
            # NOTE(review): the base `disambiguation_scores` appears to take a
            # `string_similarity` argument as well — confirm this call site.
            popularity = super().disambiguation_scores(wikidata_links)
            updated = {}
            for link in wikidata_links:
                wqid = link.wqid
                proximity = self.proximity(
                    origin_coords=self.wkdt_coords(place_of_pub_wqid),
                    coords=self.wkdt_coords(wqid))
                updated[wqid] = _combine(
                    rel_scores.scores[wqid], popularity[wqid], proximity)
            # Update the REL scores in place.
            rel_scores.scores.update(updated)

def proximity(self,
              origin_coords: Optional[Tuple[float, float]],
              coords: Optional[Tuple[float, float]]) -> Optional[float]:
    """Computes a proximity measure between two pairs of lat-long coordinates.

    The measure is a Gaussian kernel of the great circle distance,
    ``exp(-(d / reference_distance) ** 2)``: it equals 1.0 at zero distance
    and decays towards 0.0 as the distance grows.

    Args:
        origin_coords (Optional[Tuple[float, float]]): coordinates of the
            origin point (e.g. the place of publication).
        coords (Optional[Tuple[float, float]]): coordinates of the other point.

    Returns:
        Optional[float]: the proximity score in (0.0, 1.0], or `None` if
        either pair of coordinates is unavailable or the distance cannot
        be computed.
    """
    if not coords:
        return None
    distance = self.haversine(origin_coords, coords)
    # The haversine method returns None on failure. Test identity rather than
    # truthiness: a distance of 0.0 (coincident points, e.g. the place of
    # publication itself) is a valid result and must yield the maximum
    # proximity of 1.0, not None.
    if distance is None:
        return None
    return exp(-(distance / self.reference_distance) ** 2)

def train_load_model(self, split: Optional[str] = "originalsplit"):
"""
Trains or loads the entity disambiguation model and assigns to the
Expand Down
9 changes: 7 additions & 2 deletions t_res/utils/batch_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,14 @@ def results_file(self) -> str:
suffix += '-' + self.config[LINKER_KEY]['method_name']
if self.config[LINKER_KEY]['method_name'] == 'reldisamb':
if self.config[LINKER_KEY]['rel_params']['with_publication']:
suffix += '-withpub'
if self.config[LINKER_KEY]['rel_params']['predict_place_of_publication']:
suffix += '-predictpub'
else:
suffix += '-withpub'
else:
suffix += '-nopub'
if self.config[LINKER_KEY]['rel_params']['combined_score']:
suffix += '-combined'
if self.config[LINKER_KEY]['rel_params']['without_microtoponyms']:
suffix += '-nomicro'
else:
Expand Down Expand Up @@ -450,7 +455,7 @@ def run(row):
self.logger.debug(f'Running pipeline on text:\n{row[self.text_colname]}')
self.logger.debug(f'Place of publication ID:{self.place_of_pub_wqid(row.name)}')
self.logger.debug(f'Place of publication:\n{self.place_of_pub(row.name)}')
self.pipe.run(
return self.pipe.run(
row[self.text_colname],
place_of_pub_wqid=self.place_of_pub_wqid(row.name),
place_of_pub=self.place_of_pub(row.name),
Expand Down
9 changes: 3 additions & 6 deletions t_res/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,20 +348,17 @@ def __post_init__(self):
raise ValueError("normalized_score must be an float.")

@pdataclass(frozen=True)
class RelDisambLink(WikidataLink):
class RelDisambLink(MostPopularLink):
"""Dataclass representing a string match and potential links in
Wikidata under the `reldisamb` linking method.

Attributes:
freq (int): The mention-to-wikidata link frequency.
normalized_score (float): The normalized score from resource `mentions_to_wikidata_normalized.json`.
"""
freq: int
normalized_score: float

def __post_init__(self):
if not isinstance(self.freq, int):
raise ValueError("freq must be an integer.")
super().__post_init__()
if not isinstance(self.normalized_score, float):
raise ValueError("normalized_score must be an float.")

Expand Down Expand Up @@ -479,7 +476,7 @@ def scores_as_list(self) -> list:
Helper method for the Predictions as_dict method."""
ret = [[k, round(v, 3)] for k, v in self.disambiguation_scores.items()]
return sorted(ret, key=lambda x: (x[1], x[0]), reverse=True)

# Linker::run method output type.
@pdataclass(order=True, frozen=True)
class MentionCandidates:
Expand Down
2 changes: 2 additions & 0 deletions tests/sample_files/batch_jobs/sample_batch_job_deezyrel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ linker:
method_name: reldisamb
rel_params:
with_publication: True
predict_place_of_publication: True
combined_score: True
without_microtoponyms: True
default_publname: United Kingdom
default_publwqid: Q145
Expand Down
6 changes: 6 additions & 0 deletions tests/test_disambiguation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def test_train(tmp_path):
"training_split": "originalsplit",
"db_embeddings": cursor,
"with_publication": False,
"predict_place_of_publication": False,
"combined_score": False,
"without_microtoponyms": True,
"do_test": True,
},
Expand Down Expand Up @@ -202,6 +204,8 @@ def test_load_eval_model(tmp_path):
"training_split": "originalsplit",
"db_embeddings": cursor,
"with_publication": False,
"predict_place_of_publication": False,
"combined_score": False,
"without_microtoponyms": False,
"do_test": True,
},
Expand Down Expand Up @@ -291,6 +295,8 @@ def test_predict(tmp_path):
"training_split": "originalsplit",
"db_embeddings": cursor,
"with_publication": True,
"predict_place_of_publication": False,
"combined_score": False,
"without_microtoponyms": True,
"do_test": False,
},
Expand Down
Loading
Loading