diff --git a/dicee/knowledge_graph_embeddings.py b/dicee/knowledge_graph_embeddings.py index 4afee4b4..40d0a643 100644 --- a/dicee/knowledge_graph_embeddings.py +++ b/dicee/knowledge_graph_embeddings.py @@ -1,4 +1,3 @@ -import os.path from typing import List, Tuple, Set, Iterable, Dict, Union import torch from torch import optim @@ -8,7 +7,6 @@ from .static_funcs import random_prediction, deploy_triple_prediction, deploy_tail_entity_prediction, \ deploy_relation_prediction, deploy_head_entity_prediction, load_pickle from .static_funcs_training import evaluate_lp -from .static_preprocess_funcs import create_constraints import numpy as np import sys import gradio as gr @@ -20,12 +18,11 @@ class KGE(BaseInteractiveKGE): def __init__(self, path=None, url=None, construct_ensemble=False, model_name=None, apply_semantic_constraint=False): - super().__init__(path=path, url=url,construct_ensemble=construct_ensemble, model_name=model_name) + super().__init__(path=path, url=url, construct_ensemble=construct_ensemble, model_name=model_name) def __str__(self): return "KGE | " + str(self.model) - # given a string, return is bpe encoded embeddings def eval_lp_performance(self, dataset=List[Tuple[str, str, str]], filtered=True): assert isinstance(dataset, list) and len(dataset) > 0 diff --git a/tests/test_link_prediction_evaluation.py b/tests/test_link_prediction_evaluation.py index 5b8f56f3..74cfc642 100644 --- a/tests/test_link_prediction_evaluation.py +++ b/tests/test_link_prediction_evaluation.py @@ -97,19 +97,13 @@ def test_distmult_kvsall(self): model = KGE(result["path_experiment_folder"]) assert result["Train"] == evaluate_link_prediction_performance_with_reciprocals(model, triples=train_triples, er_vocab=get_er_vocab( - all_triples), - re_vocab=get_re_vocab( all_triples)) assert result["Val"] == evaluate_link_prediction_performance_with_reciprocals(model, triples=valid_triples, er_vocab=get_er_vocab( - all_triples), - re_vocab=get_re_vocab( all_triples)) assert result["Test"] == evaluate_link_prediction_performance_with_reciprocals(model, triples=test_triples, er_vocab=get_er_vocab( - all_triples), - re_vocab=get_re_vocab( all_triples)) @pytest.mark.filterwarnings('ignore::UserWarning')