Skip to content

Commit

Permalink
Replace and relocate print with warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin ROYER committed Jul 2, 2024
1 parent 88a7a16 commit 5e17323
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/python/gudhi/representations/vector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# - 2020/12 Gard: A more flexible Betti curve class capable of computing exact curves.
# - 2021/11 Vincent Rouvreau: factorize _automatic_sample_range

import warnings

import numpy as np
from scipy.spatial.distance import cdist
from sklearn.base import BaseEstimator, TransformerMixin
Expand Down Expand Up @@ -819,13 +821,11 @@ def fit(self, X, y=None, sample_weight=None):
# In fitting we remove infinite birth/death time points so that every center is finite. We do not care about duplicates.
filtered_measures_concat = measures_concat[~np.isinf(measures_concat).any(axis=1), :] if len(measures_concat) else measures_concat
filtered_weights_concat = weights_concat[~np.isinf(measures_concat).any(axis=1)] if len(measures_concat) else weights_concat

n_points = len(filtered_measures_concat)
if not n_points:
raise ValueError("Cannot fit Atol on measure with infinite components only.")

if n_points < n_clusters:
# If not enough points to fit (including 0), we will arbitrarily put centers as [-np.inf]^measure_dim at the end.
print(f"[Atol] had {n_points} points to fit {n_clusters} clusters, adding meaningless cluster centers.")
self.quantiser.n_clusters = n_points

self.quantiser.fit(X=filtered_measures_concat, sample_weight=filtered_weights_concat)
Expand All @@ -844,7 +844,9 @@ def fit(self, X, y=None, sample_weight=None):
self.inertias = np.min(dist_centers, axis=0)/2

if n_points < n_clusters:
# Where we arbitrarily put centers as [-np.inf]^measure_dim.
# There weren't enough points to fit n_clusters, so we arbitrarily put centers as [-np.inf]^measure_dim.
warnings.warn(f"[Atol] after flitering had only {n_points} points to fit {n_clusters} clusters,"
f"adding meaningless cluster centers.", RuntimeWarning)
fill_center = np.repeat(np.inf, repeats=X[0].shape[1])
fill_inertia = 0
self.centers = np.concatenate([self.centers, np.repeat([fill_center], repeats=n_clusters-n_points, axis=0)])
Expand Down

0 comments on commit 5e17323

Please sign in to comment.