diff --git a/is_fid_pytorch.py b/is_fid_pytorch.py index 7a7c759..c0c7537 100644 --- a/is_fid_pytorch.py +++ b/is_fid_pytorch.py @@ -246,7 +246,7 @@ def __calc_is(preds, n_split, return_each_score=False): scores.append(entropy(pyx, py)) split_scores.append(np.exp(np.mean(scores))) if n_split == 1 and return_each_score: - return np.mean(split_scores), np.std(split_scores), scores + return scores, 0 return np.mean(split_scores), np.std(split_scores) @staticmethod