Skip to content

Commit

Permalink
BTE with KvsAll is done along with downloading a pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Nov 10, 2023
1 parent 8ef1e3c commit 765101d
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 10 deletions.
5 changes: 1 addition & 4 deletions dicee/knowledge_graph_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os.path
from typing import List, Tuple, Set, Iterable, Dict, Union
import torch
from torch import optim
Expand All @@ -8,7 +7,6 @@
from .static_funcs import random_prediction, deploy_triple_prediction, deploy_tail_entity_prediction, \
deploy_relation_prediction, deploy_head_entity_prediction, load_pickle
from .static_funcs_training import evaluate_lp
from .static_preprocess_funcs import create_constraints
import numpy as np
import sys
import gradio as gr
Expand All @@ -20,12 +18,11 @@ class KGE(BaseInteractiveKGE):
def __init__(self, path=None, url=None, construct_ensemble=False,
model_name=None,
apply_semantic_constraint=False):
super().__init__(path=path, url=url,construct_ensemble=construct_ensemble, model_name=model_name)
super().__init__(path=path, url=url, construct_ensemble=construct_ensemble, model_name=model_name)

def __str__(self):
return "KGE | " + str(self.model)


# given a string, return is bpe encoded embeddings
def eval_lp_performance(self, dataset=List[Tuple[str, str, str]], filtered=True):
assert isinstance(dataset, list) and len(dataset) > 0
Expand Down
6 changes: 0 additions & 6 deletions tests/test_link_prediction_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,13 @@ def test_distmult_kvsall(self):
model = KGE(result["path_experiment_folder"])
assert result["Train"] == evaluate_link_prediction_performance_with_reciprocals(model, triples=train_triples,
er_vocab=get_er_vocab(
all_triples),
re_vocab=get_re_vocab(
all_triples))

assert result["Val"] == evaluate_link_prediction_performance_with_reciprocals(model, triples=valid_triples,
er_vocab=get_er_vocab(
all_triples),
re_vocab=get_re_vocab(
all_triples))
assert result["Test"] == evaluate_link_prediction_performance_with_reciprocals(model, triples=test_triples,
er_vocab=get_er_vocab(
all_triples),
re_vocab=get_re_vocab(
all_triples))

@pytest.mark.filterwarnings('ignore::UserWarning')
Expand Down

0 comments on commit 765101d

Please sign in to comment.