Skip to content

Commit

Permalink
Fix gradio
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Nov 29, 2023
1 parent cbf3a51 commit c27337d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions dicee/knowledge_graph_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,11 +1050,11 @@ def predict(str_subject: str, str_predicate: str, str_object: str, random_exampl

gr.Interface(
fn=predict,
inputs=[gr.inputs.Textbox(lines=1, placeholder=None, label='Subject'),
gr.inputs.Textbox(lines=1, placeholder=None, label='Predicate'),
gr.inputs.Textbox(lines=1, placeholder=None, label='Object'), "checkbox"],
outputs=[gr.outputs.Textbox(label='Input Triple'),
gr.outputs.Dataframe(label='Outputs', type='pandas')],
inputs=[gr.Textbox(lines=1, placeholder=None, label='Subject'),
gr.Textbox(lines=1, placeholder=None, label='Predicate'),
gr.Textbox(lines=1, placeholder=None, label='Object'), "checkbox"],
outputs=[gr.Textbox(label='Input Triple'),
gr.Dataframe(label='Outputs', type='pandas')],
title=f'{self.name} Deployment',
description='1. Enter a triple to compute its score,\n'
'2. Enter a subject and predicate pair to obtain most likely top ten entities or\n'
Expand Down
14 changes: 7 additions & 7 deletions dicee/static_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,22 +456,22 @@ def deploy_tail_entity_prediction(pre_trained_kge, str_subject, str_predicate, t
if pre_trained_kge.model.name == 'Shallom':
print('Tail entity prediction is not available for Shallom')
raise NotImplementedError
scores, entity = pre_trained_kge.predict_topk(h=[str_subject], r=[str_predicate], topk=top_k)
return f'( {str_subject}, {str_predicate}, ? )', pd.DataFrame({'Entity': entity, 'Score': scores})
str_entity_scores = pre_trained_kge.predict_topk(h=[str_subject], r=[str_predicate], topk=top_k)

return f'( {str_subject}, {str_predicate}, ? )', pd.DataFrame(str_entity_scores,columns=["entity","score"])


def deploy_head_entity_prediction(pre_trained_kge, str_object, str_predicate, top_k):
if pre_trained_kge.model.name == 'Shallom':
print('Head entity prediction is not available for Shallom')
raise NotImplementedError

scores, entity = pre_trained_kge.predict_topk(t=[str_object], r=[str_predicate], topk=top_k)
return f'( ?, {str_predicate}, {str_object} )', pd.DataFrame({'Entity': entity, 'Score': scores})
str_entity_scores = pre_trained_kge.predict_topk(t=[str_object], r=[str_predicate], topk=top_k)
return f'( ?, {str_predicate}, {str_object} )', pd.DataFrame(str_entity_scores,columns=["entity","score"])


def deploy_relation_prediction(pre_trained_kge, str_subject, str_object, top_k):
scores, relations = pre_trained_kge.predict_topk(h=[str_subject], t=[str_object], topk=top_k)
return f'( {str_subject}, ?, {str_object} )', pd.DataFrame({'Relations': relations, 'Score': scores})
str_relation_scores = pre_trained_kge.predict_topk(h=[str_subject], t=[str_object], topk=top_k)
return f'( {str_subject}, ?, {str_object} )', pd.DataFrame(str_relation_scores,columns=["relation","score"])


@timeit
Expand Down

0 comments on commit c27337d

Please sign in to comment.