Skip to content

Commit

Permalink
Merge branch 'develop' into release/1.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lukostaz committed Mar 9, 2020
2 parents 8ff796d + 2c9237e commit b10cd23
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 6 deletions.
4 changes: 4 additions & 0 deletions ampligraph/latent_features/models/EmbeddingModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def __init__(self,
verbose : bool
Verbose mode.
"""
if (loss == "bce") ^ (self.name == "ConvE"):
raise ValueError('Invalid Model - Loss combination. '
'ConvE model can be used with BCE loss only and vice versa.')

# Store for restoring later.
self.all_params = \
{
Expand Down
13 changes: 7 additions & 6 deletions tests/ampligraph/evaluation/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,28 @@
# test for #186
def test_evaluate_performance_too_many_entities_warning():
X = load_yago3_10()
model = TransE(batches_count=1000, seed=0, epochs=1, k=5, eta=1, verbose=True)
model = TransE(batches_count=200, seed=0, epochs=1, k=5, eta=1, verbose=True)
model.fit(X['train'])

# no entity list declared
with pytest.warns(UserWarning):
evaluate_performance(X['test'][::1], model, verbose=True, corrupt_side='o')
evaluate_performance(X['test'][::100], model, verbose=True, corrupt_side='o')

# with larger than threshold entity list
with pytest.warns(UserWarning):
# TOO_MANY_ENT_TH threshold is set to 50,000 entities. Using explicit value to comply with linting
# and thus avoiding exporting unused global variable.
entities_subset = np.union1d(np.unique(X["train"][:, 0]), np.unique(X["train"][:, 2]))[:50000]
evaluate_performance(X['test'][::1], model, verbose=True, corrupt_side='o', entities_subset=entities_subset)
evaluate_performance(X['test'][::100], model, verbose=True, corrupt_side='o', entities_subset=entities_subset)

# with small entity list (no exception expected)
evaluate_performance(X['test'][::1], model, verbose=True, corrupt_side='o', entities_subset=entities_subset[:10])
evaluate_performance(X['test'][::100], model, verbose=True, corrupt_side='o', entities_subset=entities_subset[:10])

# with smaller dataset, no entity list declared (no exception expected)
X_wn18rr = load_wn18rr()
model_wn18 = TransE(batches_count=1000, seed=0, epochs=1, k=5, eta=1, verbose=True)
model_wn18 = TransE(batches_count=200, seed=0, epochs=1, k=5, eta=1, verbose=True)
model_wn18.fit(X_wn18rr['train'])
evaluate_performance(X_wn18rr['test'][::1], model_wn18, verbose=True, corrupt_side='o')
evaluate_performance(X_wn18rr['test'][::100], model_wn18, verbose=True, corrupt_side='o')


def test_evaluate_performance_filter_without_xtest():
Expand Down Expand Up @@ -76,6 +76,7 @@ def test_evaluate_performance_ranking_against_specified_entities():
ranks = ranks.reshape(-1)
assert(np.sum(ranks>len(entities_subset))==0)


def test_evaluate_performance_ranking_against_shuffled_all_entities():
""" Compares mrr of test set by using default protocol against all entities vs
mrr of corruptions generated by corrupting using entities_subset = all entities shuffled
Expand Down
16 changes: 16 additions & 0 deletions tests/ampligraph/latent_features/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@
from ampligraph.evaluation.protocol import to_idx


def test_conve_bce_combo():
# no exception
model = ConvE(loss='bce')

# no exception
model = TransE(loss='nll')

# Invalid combination. Hence exception.
with pytest.raises(ValueError):
model = TransE(loss='bce')

# Invalid combination. Hence exception.
with pytest.raises(ValueError):
model = ConvE(loss='nll')


def test_large_graph_mode():
set_entity_threshold(10)
X = load_wn18()
Expand Down

0 comments on commit b10cd23

Please sign in to comment.