Commit b0f9951: Merge branch 'develop'

Demirrr committed Nov 10, 2023
2 parents 0e12910 + a290424
Showing 9 changed files with 183 additions and 62 deletions.
23 changes: 3 additions & 20 deletions README.md
@@ -214,34 +214,17 @@ pre_trained_kge.predict_topk(r=[".."],t=[".."],topk=10)

</details>

## Using Large Pre-trained Embedding Models
## Downloading Pretrained Models

<details> <summary> To see a code snippet </summary>

**Stay tuned for Keci with >10B parameters on DBpedia!**
```bash
# To download a pretrained ConEx on DBpedia 03-2022
mkdir ConEx && cd ConEx && wget -r -nd -np https://hobbitdata.informatik.uni-leipzig.de/KGE/DBpedia/ConEx/ && cd ..
```
```python
from dicee import KGE
# (1) Load a pretrained ConEx on DBpedia
pre_trained_kge = KGE(path='ConEx')
pre_trained_kge.triple_score(h=["http://dbpedia.org/resource/Albert_Einstein"],r=["http://dbpedia.org/ontology/birthPlace"],t=["http://dbpedia.org/resource/Ulm"]) # tensor([0.9309])
pre_trained_kge.triple_score(h=["http://dbpedia.org/resource/Albert_Einstein"],r=["http://dbpedia.org/ontology/birthPlace"],t=["http://dbpedia.org/resource/German_Empire"]) # tensor([0.9981])
pre_trained_kge.triple_score(h=["http://dbpedia.org/resource/Albert_Einstein"],r=["http://dbpedia.org/ontology/birthPlace"],t=["http://dbpedia.org/resource/Kingdom_of_Württemberg"]) # tensor([0.9994])
pre_trained_kge.triple_score(h=["http://dbpedia.org/resource/Albert_Einstein"],r=["http://dbpedia.org/ontology/birthPlace"],t=["http://dbpedia.org/resource/Germany"]) # tensor([0.9498])
pre_trained_kge.triple_score(h=["http://dbpedia.org/resource/Albert_Einstein"],r=["http://dbpedia.org/ontology/birthPlace"],t=["http://dbpedia.org/resource/France"]) # very low
pre_trained_kge.triple_score(h=["http://dbpedia.org/resource/Albert_Einstein"],r=["http://dbpedia.org/ontology/birthPlace"],t=["http://dbpedia.org/resource/Italy"]) # very low
model = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll")
```

Please contact `[email protected]` or `[email protected]` if you lack the hardware resources to obtain embeddings of a specific knowledge graph.
- [DBpedia version: 06-2022 Embeddings](https://hobbitdata.informatik.uni-leipzig.de/KGE/DBpediaQMultEmbeddings_03_07):
  - Models: ConEx, QMult
- [YAGO3-10 ConEx embeddings](https://hobbitdata.informatik.uni-leipzig.de/KGE/conex/YAGO3-10.zip)
- [FB15K-237 ConEx embeddings](https://hobbitdata.informatik.uni-leipzig.de/KGE/conex/FB15K-237.zip)
- [WN18RR ConEx embeddings](https://hobbitdata.informatik.uni-leipzig.de/KGE/conex/WN18RR.zip)
- For more, see [Hobbit Data](https://files.dice-research.org/projects/DiceEmbeddings/)
- For more, see [dice-research.org/projects/DiceEmbeddings/](https://files.dice-research.org/projects/DiceEmbeddings/)
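
A minimal usage sketch of the model downloaded above, mirroring `examples/download_pretrained_model.py`:

```python
from dicee import KGE
# Downloads the model files into ./KINSHIP-Keci-dim128-epoch256-KvsAll on first use
model = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll")
print(model.predict(h="person49", r="term12", t="person39", logits=False))
```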

</details>

42 changes: 32 additions & 10 deletions dicee/abstracts.py
@@ -1,6 +1,6 @@
import os
import datetime
from .static_funcs import load_model_ensemble, load_model, save_checkpoint_model, load_json
from .static_funcs import load_model_ensemble, load_model, save_checkpoint_model, load_json, download_pretrained_model
import torch
from typing import List, Tuple, Union
import random
@@ -139,18 +139,23 @@ class BaseInteractiveKGE:
apply_semantic_constraint : boolean
"""

def __init__(self, path: str, construct_ensemble: bool = False, model_name: str = None,
    def __init__(self, path: str = None, url: str = None, construct_ensemble: bool = False, model_name: str = None,
apply_semantic_constraint: bool = False):
if url is not None:
assert path is None
self.path = download_pretrained_model(url)
else:
self.path = path
try:
assert os.path.isdir(path)
assert os.path.isdir(self.path)
except AssertionError:
raise AssertionError(f'Could not find a directory {path}')
self.path = path
raise AssertionError(f'Could not find a directory {self.path}')

# (1) Load model...
self.construct_ensemble = construct_ensemble
self.apply_semantic_constraint = apply_semantic_constraint
self.configs = load_json(path + '/configuration.json')
self.configs.update(load_json(path + '/report.json'))
self.configs = load_json(self.path + '/configuration.json')
self.configs.update(load_json(self.path + '/report.json'))

if construct_ensemble:
self.model, tuple_of_entity_relation_idx = load_model_ensemble(self.path)
@@ -159,12 +164,10 @@ def __init__(self, path: str, construct_ensemble: bool = False, model_name: str
self.model, tuple_of_entity_relation_idx = load_model(self.path, model_name=model_name)
else:
self.model, tuple_of_entity_relation_idx = load_model(self.path)

if self.configs["byte_pair_encoding"]:
if self.configs.get("byte_pair_encoding", None):
self.enc = tiktoken.get_encoding("gpt2")
self.dummy_id = tiktoken.get_encoding("gpt2").encode(" ")[0]
self.max_length_subword_tokens = self.configs["max_length_subword_tokens"]

else:
assert len(tuple_of_entity_relation_idx) == 2

@@ -179,6 +182,20 @@ def __init__(self, path: str, construct_ensemble: bool = False, model_name: str
self.idx_to_entity = {v: k for k, v in self.entity_to_idx.items()}
self.idx_to_relations = {v: k for k, v in self.relation_to_idx.items()}



# See https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
# @TODO: Temporarily ignored. If the file exists:
#if os.path.exists(self.path + '/train_set.npy'):
# self.train_set = np.load(file=self.path + '/train_set.npy', mmap_mode='r')

#if apply_semantic_constraint:
# (self.domain_constraints_per_rel, self.range_constraints_per_rel,
# self.domain_per_rel, self.range_per_rel) = create_constraints(self.train_set)

def get_eval_report(self) -> dict:
return load_json(self.path + "/eval_report.json")

def get_bpe_token_representation(self, str_entity_or_relation: Union[List[str], str]) -> Union[
List[List[int]], List[int]]:
"""
@@ -572,3 +589,8 @@ def on_train_epoch_end(self, trainer, model):

def on_train_batch_end(self, *args, **kwargs):
return
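
# Usage sketch (hedged; the local path below is a hypothetical placeholder):
#   from dicee import KGE
#   model = KGE(path="Experiments/KINSHIP-Keci-dim128-epoch256-KvsAll")  # load a local experiment folder
#   model = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll")  # download; path must stay None
#   model.get_eval_report()  # reads eval_report.json from the experiment folder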





3 changes: 1 addition & 2 deletions dicee/eval_static_funcs.py
@@ -114,8 +114,7 @@ def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tup

@torch.no_grad()
def evaluate_link_prediction_performance_with_reciprocals(model: KGE, triples,
er_vocab: Dict[Tuple, List],
re_vocab: Dict[Tuple, List]):
er_vocab: Dict[Tuple, List]):
model.model.eval()
entity_to_idx = model.entity_to_idx
relation_to_idx = model.relation_to_idx
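# Note (hedged): with reciprocal triples a head query (?, r, t) is rewritten as a
# tail query (t, r_inverse, ?), so a single er_vocab suffices for filtered ranking
# and re_vocab can be dropped. Sketch of the rewrite (suffix as in create_recipriocal_triples):
#   triples = [("person49", "term12", "person39")]
#   reciprocals = [(t, r + "_inverse", h) for (h, r, t) in triples]
#   # -> [('person39', 'term12_inverse', 'person49')]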
22 changes: 6 additions & 16 deletions dicee/knowledge_graph_embeddings.py
@@ -1,4 +1,3 @@
import os.path
from typing import List, Tuple, Set, Iterable, Dict, Union
import torch
from torch import optim
@@ -8,7 +7,6 @@
from .static_funcs import random_prediction, deploy_triple_prediction, deploy_tail_entity_prediction, \
deploy_relation_prediction, deploy_head_entity_prediction, load_pickle
from .static_funcs_training import evaluate_lp
from .static_preprocess_funcs import create_constraints
import numpy as np
import sys
import gradio as gr
@@ -17,18 +15,10 @@
class KGE(BaseInteractiveKGE):
""" Knowledge Graph Embedding Class for interactive usage of pre-trained models"""

def __init__(self, path, construct_ensemble=False,
def __init__(self, path=None, url=None, construct_ensemble=False,
model_name=None,
apply_semantic_constraint=False):
super().__init__(path=path, construct_ensemble=construct_ensemble, model_name=model_name)
# See https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
# If file exists
if os.path.exists(path + '/train_set.npy'):
self.train_set = np.load(file=path + '/train_set.npy', mmap_mode='r')

if apply_semantic_constraint:
(self.domain_constraints_per_rel, self.range_constraints_per_rel,
self.domain_per_rel, self.range_per_rel) = create_constraints(self.train_set)
super().__init__(path=path, url=url, construct_ensemble=construct_ensemble, model_name=model_name)

def __str__(self):
return "KGE | " + str(self.model)
@@ -243,7 +233,7 @@ def predict(self, *, h: Union[List[str], str] = None, r: Union[List[str], str] =
# h r ?
scores = self.predict_missing_tail_entity(h, r, within)
else:
scores=self.triple_score(h, r, t, logits=True)
scores = self.triple_score(h, r, t, logits=True)

if logits:
return scores
@@ -359,7 +349,7 @@ def triple_score(self, h: Union[List[str], str] = None, r: Union[List[str], str]
pytorch tensor of triple score
"""

if self.configs["byte_pair_encoding"]:
if self.configs.get("byte_pair_encoding", None):
h_encode = self.enc.encode(h)
r_encode = self.enc.encode(r)
t_encode = self.enc.encode(t)
@@ -396,9 +386,9 @@ def triple_score(self, h: Union[List[str], str] = None, r: Union[List[str], str]
else:
with torch.no_grad():
if logits:
return self.model(x)
return self.model(x)
else:
return torch.sigmoid(self.model(x))
return torch.sigmoid(self.model(x))
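
# Illustration (hedged): logits=True returns the raw model output; otherwise it is
# passed through a sigmoid, so for any loaded model the two outputs are related by:
#   torch.sigmoid(model.triple_score(h, r, t, logits=True)) == model.triple_score(h, r, t)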

def t_norm(self, tens_1: torch.Tensor, tens_2: torch.Tensor, tnorm: str = 'min') -> torch.Tensor:
if 'min' in tnorm:
61 changes: 54 additions & 7 deletions dicee/static_funcs.py
@@ -16,6 +16,10 @@
import pickle
from collections import defaultdict

import requests
from urllib.parse import urlparse
from bs4 import BeautifulSoup

def create_recipriocal_triples(x):
"""
Add inverse triples into the triples dataframe
@@ -25,6 +29,8 @@ def create_recipriocal_triples(x):
return pd.concat([x, x['object'].to_frame(name='subject').join(
x['relation'].map(lambda x: x + '_inverse').to_frame(name='relation')).join(
x['subject'].to_frame(name='object'))], ignore_index=True)
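
# Illustration (hedged): a frame holding the single triple (a, r, b) gains (b, r_inverse, a):
#   df = pd.DataFrame([["a", "r", "b"]], columns=["subject", "relation", "object"])
#   create_recipriocal_triples(df).values.tolist()
#   # -> [['a', 'r', 'b'], ['b', 'r_inverse', 'a']]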


def get_er_vocab(data, file_path: str = None):
# head entity and relation
er_vocab = defaultdict(list)
@@ -110,19 +116,19 @@ def load_model(path_of_experiment_folder: str, model_name='model.pt') -> Tuple[o
weights = torch.load(path_of_experiment_folder + f'/{model_name}', torch.device('cpu'))
configs = load_json(path_of_experiment_folder + '/configuration.json')

if configs["byte_pair_encoding"]:
if configs.get("byte_pair_encoding", None):
num_tokens, ent_dim = weights['token_embeddings.weight'].shape
# (2) Loading input configuration.
configs = load_json(path_of_experiment_folder + '/configuration.json')
report = load_json(path_of_experiment_folder + '/report.json')
# Load ordered_bpe_entities.p
configs["ordered_bpe_entities"]=load_pickle(file_path=path_of_experiment_folder+"/ordered_bpe_entities.p")
configs["ordered_bpe_entities"] = load_pickle(file_path=path_of_experiment_folder + "/ordered_bpe_entities.p")
configs["num_tokens"] = num_tokens
configs["max_length_subword_tokens"] = report["max_length_subword_tokens"]
else:
num_ent, ent_dim = weights['entity_embeddings.weight'].shape
num_rel, rel_dim = weights['relation_embeddings.weight'].shape
assert ent_dim==rel_dim
assert ent_dim == rel_dim
# Update the training configuration
configs["num_entities"] = num_ent
configs["num_relations"] = num_rel
@@ -136,7 +142,7 @@ def load_model(path_of_experiment_folder: str, model_name='model.pt') -> Tuple[o
parameter.requires_grad = False
model.eval()
start_time = time.time()
if configs["byte_pair_encoding"]:
if configs.get("byte_pair_encoding", None):
return model, None
else:
print('Loading entity and relation indexes...', end=' ')
@@ -146,13 +152,13 @@ def load_model(path_of_experiment_folder: str, model_name='model.pt') -> Tuple[o
entity_to_idx = pickle.load(f)
except FileNotFoundError:
print("entity_to_idx.p not found")
entity_to_idx=dict()
entity_to_idx = dict()
try:
with open(path_of_experiment_folder + '/relation_to_idx.p', 'rb') as f:
relation_to_idx = pickle.load(f)
except FileNotFoundError:
print("relation_to_idx.p not found")
relation_to_idx=dict()
relation_to_idx = dict()
print(f'Done! It took {time.time() - start_time:.4f}')
return model, (entity_to_idx, relation_to_idx)

@@ -407,7 +413,7 @@ def intialize_model(args: dict) -> Tuple[object, str]:

def load_json(p: str) -> dict:
with open(p, 'r') as r:
args = json.load(r)
args = json.load(r)
return args


@@ -525,6 +531,7 @@ def load_numpy(path) -> np.ndarray:
data = np.load(f)
return data


def evaluate(entity_to_idx, scores, easy_answers, hard_answers):
"""
# @TODO: CD: Renamed this function
@@ -583,3 +590,43 @@ def evaluate(entity_to_idx, scores, easy_answers, hard_answers):
avg_h10 = total_h10 / num_queries

return avg_mrr, avg_h1, avg_h3, avg_h10
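
# For reference (hedged): MRR = (1/|Q|) * sum over queries q of 1/rank(q), and
# Hits@k is the fraction of queries whose (hard) answer ranks within the top k.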



def download_file(url, destination_folder="."):
response = requests.get(url, stream=True)
if response.status_code == 200:
filename = os.path.join(destination_folder, os.path.basename(urlparse(url).path))
with open(filename, 'wb') as file:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
file.write(chunk)
print(f"Downloaded: {filename}")
else:
print(f"Failed to download: {url}")


def download_files_from_url(base_url, destination_folder="."):
response = requests.get(base_url)
if response.status_code == 200:
soup = BeautifulSoup(response.text, 'html.parser')
# Find the table with id "list"
table = soup.find('table', {'id': 'list'})
# Extract all hrefs under the table
hrefs = [a['href'] for a in table.find_all('a', href=True)]
# To remove '?C=N&O=A', '?C=N&O=D', '?C=S&O=A', '?C=S&O=D', '?C=M&O=A', '?C=M&O=D', '../'
hrefs = [i for i in hrefs if len(i) > 3 and "." in i]
for file_url in hrefs:
download_file(base_url + "/" + file_url, destination_folder)


def download_pretrained_model(url: str) -> str:
assert url[-1] != "/"
dir_name = url[url.rfind("/") + 1:]
url_to_download_from = f"https://files.dice-research.org/projects/DiceEmbeddings/{dir_name}"
if os.path.exists(dir_name):
print("Path exists", dir_name)
else:
os.mkdir(dir_name)
download_files_from_url(url_to_download_from, destination_folder=dir_name)
return dir_name
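
# Usage sketch (hedged): download_pretrained_model rebuilds the URL against the fixed
# files.dice-research.org/projects/DiceEmbeddings base, so only the final path segment
# of the given url matters.
#   path = download_pretrained_model("https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll")
#   model = KGE(path=path)  # equivalent to passing url= directly to KGE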
42 changes: 42 additions & 0 deletions examples/download_pretrained_model.py
@@ -0,0 +1,42 @@
# pip install dicee
from dicee import KGE
import pandas as pd
from dicee.static_funcs import get_er_vocab
from dicee.eval_static_funcs import evaluate_link_prediction_performance_with_reciprocals

# (1) Download a pre-trained model and store it in a newly created directory (KINSHIP-Keci-dim128-epoch256-KvsAll)
model = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll")
# (2) Make a prediction
print(model.predict(h="person49", r="term12", t="person39", logits=False))
# (3) Load the train, validation, and test splits (assumes the KINSHIP dataset is available locally under KGs/KINSHIP)
train_triples = pd.read_csv("KGs/KINSHIP/train.txt",
sep="\s+",
header=None, usecols=[0, 1, 2],
names=['subject', 'relation', 'object'],
dtype=str).values.tolist()
valid_triples = pd.read_csv("KGs/KINSHIP/valid.txt",
sep="\s+",
header=None, usecols=[0, 1, 2],
names=['subject', 'relation', 'object'],
dtype=str).values.tolist()
test_triples = pd.read_csv("KGs/KINSHIP/test.txt",
sep="\s+",
header=None, usecols=[0, 1, 2],
names=['subject', 'relation', 'object'],
dtype=str).values.tolist()
# Compute the mapping from each unique (head entity, relation) pair to all known tail entities,
# i.e. V_{e_i,r_j} = {x | x \in Entities s.t. (e_i, r_j, x) \in Train \cup Val \cup Test}.
# This mapping is used to compute the filtered MRR and Hits@n.
er_vocab = get_er_vocab(train_triples + valid_triples + test_triples)

result = model.get_eval_report()

print(result["Train"])
print(evaluate_link_prediction_performance_with_reciprocals(model, triples=train_triples,
er_vocab=er_vocab))
print(result["Val"])
print(evaluate_link_prediction_performance_with_reciprocals(model, triples=valid_triples,
er_vocab=er_vocab))
print(result["Test"])
print(evaluate_link_prediction_performance_with_reciprocals(model, triples=test_triples,
er_vocab=er_vocab))
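
# For reference (hedged): get_er_vocab maps each (head, relation) pair to all known tails,
# which are filtered out during ranking, e.g.:
#   get_er_vocab([["a", "r", "b"], ["a", "r", "c"]])[("a", "r")]  # -> ['b', 'c']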
3 changes: 2 additions & 1 deletion requirements.txt
@@ -14,4 +14,5 @@ rdflib>=7.0.0
sphinx>=7.1.2
ghp-import>=2.1.0
furo>=2023.08.19
tiktoken>=0.5.1
tiktoken>=0.5.1
beautifulsoup4>=4.12.2