diff --git a/tsnex/tsne.py b/tsnex/tsne.py index 7f2f9e0..b5cb276 100644 --- a/tsnex/tsne.py +++ b/tsnex/tsne.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Callable import jax import jax.numpy as jnp @@ -40,7 +41,7 @@ def _binary_search_perplexity(distances, target, tol=1e-5, max_iter=200): sigma = 1.0 def cond_fun(val): - (sigma, perplexity, i, sigma_min, sigma_max) = val + (_, perplexity, i, _, _) = val return (jnp.abs(perplexity - target) > tol) & (i < max_iter) def body_fun(val): @@ -59,6 +60,19 @@ def body_fun(val): return sigma +@partial( + jax.jit, + static_argnames=[ + "n_components", + "perplexity", + "learning_rate", + "init", + "seed", + "n_iter", + "metric_fn", + "early_exageration", + ], +) def transform( X: jax.Array, *, @@ -68,7 +82,7 @@ def transform( init: str = "pca", seed: int = 0, n_iter: int = 1000, - metric_fn: Callable = None, + metric_fn: Callable = euclidean_distance, early_exageration: float = 12.0, ) -> jax.Array: """ @@ -96,8 +110,6 @@ def transform( else: raise ValueError(f"Unknown init_method: {init}") - if metric_fn is None: - metric_fn = euclidean_distance metric_fn = jax.vmap(jax.vmap(metric_fn, in_axes=(0, None)), in_axes=(None, 0)) # Compute the probability of neighbours on the original embedding.