From c27337dc0b0ebe772d838615fb3827fc816092f2 Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Wed, 29 Nov 2023 17:37:22 +0100 Subject: [PATCH] Fix gradio --- dicee/knowledge_graph_embeddings.py | 10 +++++----- dicee/static_funcs.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/dicee/knowledge_graph_embeddings.py b/dicee/knowledge_graph_embeddings.py index 40d0a643..4cd2e41d 100644 --- a/dicee/knowledge_graph_embeddings.py +++ b/dicee/knowledge_graph_embeddings.py @@ -1050,11 +1050,11 @@ def predict(str_subject: str, str_predicate: str, str_object: str, random_exampl gr.Interface( fn=predict, - inputs=[gr.inputs.Textbox(lines=1, placeholder=None, label='Subject'), - gr.inputs.Textbox(lines=1, placeholder=None, label='Predicate'), - gr.inputs.Textbox(lines=1, placeholder=None, label='Object'), "checkbox"], - outputs=[gr.outputs.Textbox(label='Input Triple'), - gr.outputs.Dataframe(label='Outputs', type='pandas')], + inputs=[gr.Textbox(lines=1, placeholder=None, label='Subject'), + gr.Textbox(lines=1, placeholder=None, label='Predicate'), + gr.Textbox(lines=1, placeholder=None, label='Object'), "checkbox"], + outputs=[gr.Textbox(label='Input Triple'), + gr.Dataframe(label='Outputs', type='pandas')], title=f'{self.name} Deployment', description='1. Enter a triple to compute its score,\n' '2. Enter a subject and predicate pair to obtain most likely top ten entities or\n' diff --git a/dicee/static_funcs.py b/dicee/static_funcs.py index 0b7a9b84..ce1e4664 100644 --- a/dicee/static_funcs.py +++ b/dicee/static_funcs.py @@ -456,22 +456,22 @@ def deploy_tail_entity_prediction(pre_trained_kge, str_subject, str_predicate, t if pre_trained_kge.model.name == 'Shallom': print('Tail entity prediction is not available for Shallom') raise NotImplementedError - scores, entity = pre_trained_kge.predict_topk(h=[str_subject], r=[str_predicate], topk=top_k) - return f'( {str_subject}, {str_predicate}, ? )', pd.DataFrame({'Entity': entity, 'Score': scores}) + str_entity_scores = pre_trained_kge.predict_topk(h=[str_subject], r=[str_predicate], topk=top_k) + + return f'( {str_subject}, {str_predicate}, ? )', pd.DataFrame(str_entity_scores,columns=["entity","score"]) def deploy_head_entity_prediction(pre_trained_kge, str_object, str_predicate, top_k): if pre_trained_kge.model.name == 'Shallom': print('Head entity prediction is not available for Shallom') raise NotImplementedError - - scores, entity = pre_trained_kge.predict_topk(t=[str_object], r=[str_predicate], topk=top_k) - return f'( ?, {str_predicate}, {str_object} )', pd.DataFrame({'Entity': entity, 'Score': scores}) + str_entity_scores = pre_trained_kge.predict_topk(t=[str_object], r=[str_predicate], topk=top_k) + return f'( ?, {str_predicate}, {str_object} )', pd.DataFrame(str_entity_scores,columns=["entity","score"]) def deploy_relation_prediction(pre_trained_kge, str_subject, str_object, top_k): - scores, relations = pre_trained_kge.predict_topk(h=[str_subject], t=[str_object], topk=top_k) - return f'( {str_subject}, ?, {str_object} )', pd.DataFrame({'Relations': relations, 'Score': scores}) + str_relation_scores = pre_trained_kge.predict_topk(h=[str_subject], t=[str_object], topk=top_k) + return f'( {str_subject}, ?, {str_object} )', pd.DataFrame(str_relation_scores,columns=["relation","score"]) @timeit