Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: always jit the transform function #6

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions tsnex/tsne.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Callable
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -40,7 +41,7 @@ def _binary_search_perplexity(distances, target, tol=1e-5, max_iter=200):
sigma = 1.0

def cond_fun(val):
(sigma, perplexity, i, sigma_min, sigma_max) = val
(_, perplexity, i, _, _) = val
return (jnp.abs(perplexity - target) > tol) & (i < max_iter)

def body_fun(val):
Expand All @@ -59,6 +60,19 @@ def body_fun(val):
return sigma


@partial(
jax.jit,
static_argnames=[
"n_components",
"perplexity",
"learning_rate",
"init",
"seed",
"n_iter",
"metric_fn",
"early_exageration",
],
)
def transform(
X: jax.Array,
*,
Expand All @@ -68,7 +82,7 @@ def transform(
init: str = "pca",
seed: int = 0,
n_iter: int = 1000,
metric_fn: Callable = None,
metric_fn: Callable = euclidean_distance,
early_exageration: float = 12.0,
) -> jax.Array:
"""
Expand Down Expand Up @@ -96,8 +110,6 @@ def transform(
else:
raise ValueError(f"Unknown init_method: {init}")

if metric_fn is None:
metric_fn = euclidean_distance
metric_fn = jax.vmap(jax.vmap(metric_fn, in_axes=(0, None)), in_axes=(None, 0))

# Compute the probability of neighbours on the original embedding.
Expand Down
Loading