-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit c7f2103
Showing
11 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
version: 2 | ||
updates: | ||
- package-ecosystem: pip | ||
directory: "/" | ||
schedule: | ||
interval: weekly | ||
timezone: CET | ||
open-pull-requests-limit: 10 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# This workflow will install Python dependencies, run tests with a variety of Python versions | ||
name: Python testing | ||
|
||
on: | ||
push: | ||
branches: [ main ] | ||
pull_request: | ||
branches: [ main ] | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.9", "3.10", "3.11"] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi | ||
python -m pip install . | ||
python -m pip install pytest | ||
- name: Test with pytest | ||
run: | | ||
pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
name: Upload Python Package | ||
|
||
on: | ||
release: | ||
types: [published] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
deploy: | ||
|
||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: '3.x' | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install build | ||
- name: Build package | ||
run: python -m build | ||
- name: Publish package | ||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 | ||
with: | ||
user: __token__ | ||
password: ${{ secrets.PYPI_API_TOKEN }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Ignore python files | ||
**/__pycache__/** | ||
venv/** | ||
|
||
# Ignore ctags | ||
tags |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2023 Albert Alonso | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# tsnex: Minimal t-SNEx implementation in JAX | ||
|
||
**tsnex** is a lightweight, high-performance Python library for t-Distributed Stochastic Neighbor Embedding (t-SNE) built on top of JAX. Leveraging the power of JAX, `tsnex` offers JIT compilation, automatic differentiation, and hardware acceleration support to efficiently handle high-dimensional data for visualization and clustering tasks. | ||
|
||
| [**Usage**](#usage) | ||
[**Installation**](#installation) | ||
| [**Contributing**](#contributing) | ||
| [**License**](#license) | ||
|
||
## Usage<a id="usage"></a> | ||
```python | ||
|
||
key = jax.random.key(0) | ||
X = jax.random.normal(key, shape=(100, 50)) | ||
|
||
X_embedded = tsnex.transform(X, n_components=2) | ||
``` | ||
|
||
## Installation<a id="installation"></a> | ||
`tsnex` can be installed using [PyPI](https://pypi.org/project/tsnex/) via `pip`: | ||
``` | ||
pip install tsnex | ||
``` | ||
or from GitHub directly | ||
``` | ||
pip install git+git://github.com/alonfnt/tsnex.git | ||
``` | ||
|
||
Likewise, you can clone this repository and install it locally | ||
|
||
```bash | ||
git clone https://github.com/alonfnt/tsnex.git | ||
cd tsnex | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Contributing<a id="contributing"></a> | ||
We welcome contributions to **tsnex**! Whether it's adding new features, improving documentation, or reporting issues, please feel free to make a pull request or open an issue. | ||
|
||
## License<a id="license"></a> | ||
Bayex is licensed under the MIT License. See the ![LICENSE](LICENSE) file for more details. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
[build-system] | ||
requires = ["setuptools", "wheel"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[project] | ||
name = "tsnex" | ||
version = "0.0.1" | ||
description = "Minimal t-distributed stochastic neighbor embedding (t-SNE) implementation in JAX." | ||
readme = "README.md" | ||
requires-python = ">=3.9" | ||
authors = [ | ||
{name = "Albert Alonso", email = "[email protected]"}, | ||
{name = "Antonio Matas Gil"} | ||
] | ||
license = {file = "LICENSE"} | ||
keywords = ["jax", "tsne", "python", "data-visualization"] | ||
classifiers = [ | ||
"Development Status :: 3 - Alpha", | ||
"Intended Audience :: Developers", | ||
"License :: OSI Approved :: MIT License", | ||
"Programming Language :: Python :: 3", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
] | ||
|
||
[project.urls] | ||
"Homepage" = "https://github.com/alonfnt/tsnex" | ||
"Documentation" = "https://github.com/alonfnt/tsnex" | ||
"Source" = "https://github.com/alonfnt/tsnex" | ||
"Bug Tracker" = "https://github.com/alonfnt/tsnex/issues" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
jax | ||
jaxlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .tsne import transform | ||
|
||
__all__ = ['transform'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
import pcax | ||
|
||
|
||
def kl_divergence(p, q): | ||
return jnp.sum(p * jnp.log(p / q)) | ||
|
||
|
||
def euclidean_distance(x, y): | ||
return jnp.sum((x - y) ** 2, axis=-1) | ||
|
||
|
||
def probability_fn(x, y, sigma): | ||
return jnp.exp(-euclidean_distance(x, y) / (2 * sigma**2)) | ||
|
||
|
||
def transform( | ||
X, | ||
*, | ||
n_components=2, | ||
perplexity=30.0, | ||
learning_rate=1e-3, | ||
init="pca", | ||
seed=0, | ||
n_iter=1000, | ||
metric_fn=None, | ||
): | ||
""" | ||
Transform X to a lower dimensional representation using T-distributed Stochastic Neighbor Embedding. | ||
Args: | ||
Returns: | ||
- X_new: jax.Array, shape (n_samples, n_components) | ||
""" | ||
|
||
if init == "pca": | ||
state = pcax.fit(X, n_components) | ||
X_new = pcax.transform(state, X) | ||
elif init == "random": | ||
X_new = jax.random.normal(jax.random.key(seed), (X.shape[0], n_components)) | ||
else: | ||
raise ValueError(f"Unknown init_method: {init}") | ||
|
||
if metric_fn is None: | ||
metric_fn = euclidean_distance | ||
metric_fn = jax.vmap(jax.vmap(metric_fn, in_axes=(0, None)), in_axes=(None, 0)) | ||
|
||
# Compute the probability of neighbours on the original embedding. | ||
vmapped_prob = jax.vmap( | ||
jax.vmap(probability_fn, in_axes=(0, None, None)), in_axes=(None, 0, None) | ||
) | ||
|
||
P = vmapped_prob(X, X, perplexity) | ||
|
||
@jax.grad | ||
def loss_fn(x): | ||
distances = metric_fn(x, x) | ||
Q = jax.nn.softmax(-distances) | ||
return kl_divergence(P, Q) | ||
|
||
def train_step(x, _): | ||
grads = loss_fn(x) | ||
x_new = x - learning_rate * grads | ||
return x_new, None | ||
|
||
n_exageration = 250 | ||
X_new, _ = jax.lax.scan(train_step, X_new, xs=None, length=n_iter + n_exageration) | ||
|
||
return X_new |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
import pytest | ||
import tsnex | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"p,q,expected", | ||
[ | ||
( | ||
jnp.array([0.1, 0.9]), | ||
jnp.array([0.9, 0.1]), | ||
jnp.sum(jnp.array([0.1, 0.9]) * jnp.log(jnp.array([0.1, 0.9]) / jnp.array([0.9, 0.1]))), | ||
), | ||
(jnp.array([0.5, 0.5]), jnp.array([0.5, 0.5]), 0), | ||
], | ||
) | ||
def test_kl_divergence(p, q, expected): | ||
result = tsnex.tsne.kl_divergence(p, q) | ||
assert jnp.isclose(result, expected, atol=1e-6) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"x,y,expected", | ||
[ | ||
(jnp.array([0, 0]), jnp.array([1, 1]), 2), | ||
(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]), 27), | ||
], | ||
) | ||
def test_euclidean_distance(x, y, expected): | ||
result = tsnex.tsne.euclidean_distance(x, y) | ||
assert jnp.isclose(result, expected) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"x,y,sigma,expected", | ||
[ | ||
(jnp.array([0, 0]), jnp.array([1, 1]), 1, jnp.exp(-1)), | ||
], | ||
) | ||
def test_probability_fn(x, y, sigma, expected): | ||
result = tsnex.tsne.probability_fn(x, y, sigma) | ||
assert jnp.isclose(result, expected, atol=1e-6) | ||
|
||
|
||
@pytest.mark.parametrize("init_method", ["pca", "random"]) | ||
def test_transform_shape_and_locality(init_method): | ||
key = jax.random.key(0) | ||
X = jax.random.normal(key, shape=(100, 50)) | ||
X_transformed = tsnex.transform(X, init=init_method, seed=42, n_iter=10) | ||
assert X_transformed.shape == (100, 2) | ||
|
||
|
||
def test_transform_invalid_init(): | ||
key = jax.random.key(0) | ||
with pytest.raises(ValueError): | ||
tsnex.transform(jax.random.uniform(key, shape=(10, 5)), init="invalid_init") |