Skip to content

Commit

Permalink
feat: initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
alonfnt committed Mar 2, 2024
0 parents commit c7f2103
Show file tree
Hide file tree
Showing 11 changed files with 308 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
version: 2
updates:
- package-ecosystem: pip
directory: "/"
schedule:
interval: weekly
timezone: CET
open-pull-requests-limit: 10

33 changes: 33 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# This workflow will install Python dependencies, run tests with a variety of Python versions
name: Python testing

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
python -m pip install .
python -m pip install pytest
- name: Test with pytest
run: |
pytest
31 changes: 31 additions & 0 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Upload Python Package

on:
release:
types: [published]

permissions:
contents: read

jobs:
deploy:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Ignore python files
**/__pycache__/**
venv/**

# Ignore ctags
tags
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Albert Alonso

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# tsnex: Minimal t-SNEx implementation in JAX

**tsnex** is a lightweight, high-performance Python library for t-Distributed Stochastic Neighbor Embedding (t-SNE) built on top of JAX. Leveraging the power of JAX, `tsnex` offers JIT compilation, automatic differentiation, and hardware acceleration support to efficiently handle high-dimensional data for visualization and clustering tasks.

| [**Usage**](#usage)
[**Installation**](#installation)
| [**Contributing**](#contributing)
| [**License**](#license)

## Usage<a id="usage"></a>
```python

key = jax.random.key(0)
X = jax.random.normal(key, shape=(100, 50))

X_embedded = tsnex.transform(X, n_components=2)
```

## Installation<a id="installation"></a>
`tsnex` can be installed using [PyPI](https://pypi.org/project/tsnex/) via `pip`:
```
pip install tsnex
```
or from GitHub directly
```
pip install git+git://github.com/alonfnt/tsnex.git
```

Likewise, you can clone this repository and install it locally

```bash
git clone https://github.com/alonfnt/tsnex.git
cd tsnex
pip install -r requirements.txt
```

## Contributing<a id="contributing"></a>
We welcome contributions to **tsnex**! Whether it's adding new features, improving documentation, or reporting issues, please feel free to make a pull request or open an issue.

## License<a id="license"></a>
Bayex is licensed under the MIT License. See the ![LICENSE](LICENSE) file for more details.

31 changes: 31 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "tsnex"
version = "0.0.1"
description = "Minimal t-distributed stochastic neighbor embedding (t-SNE) implementation in JAX."
readme = "README.md"
requires-python = ">=3.9"
authors = [
{name = "Albert Alonso", email = "[email protected]"},
{name = "Antonio Matas Gil"}
]
license = {file = "LICENSE"}
keywords = ["jax", "tsne", "python", "data-visualization"]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]

[project.urls]
"Homepage" = "https://github.com/alonfnt/tsnex"
"Documentation" = "https://github.com/alonfnt/tsnex"
"Source" = "https://github.com/alonfnt/tsnex"
"Bug Tracker" = "https://github.com/alonfnt/tsnex/issues"
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
jax
jaxlib
3 changes: 3 additions & 0 deletions tsnex/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .tsne import transform

__all__ = ['transform']
73 changes: 73 additions & 0 deletions tsnex/tsne.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import jax
import jax.numpy as jnp

import pcax


def kl_divergence(p, q):
return jnp.sum(p * jnp.log(p / q))


def euclidean_distance(x, y):
return jnp.sum((x - y) ** 2, axis=-1)


def probability_fn(x, y, sigma):
return jnp.exp(-euclidean_distance(x, y) / (2 * sigma**2))


def transform(
X,
*,
n_components=2,
perplexity=30.0,
learning_rate=1e-3,
init="pca",
seed=0,
n_iter=1000,
metric_fn=None,
):
"""
Transform X to a lower dimensional representation using T-distributed Stochastic Neighbor Embedding.
Args:
Returns:
- X_new: jax.Array, shape (n_samples, n_components)
"""

if init == "pca":
state = pcax.fit(X, n_components)
X_new = pcax.transform(state, X)
elif init == "random":
X_new = jax.random.normal(jax.random.key(seed), (X.shape[0], n_components))
else:
raise ValueError(f"Unknown init_method: {init}")

if metric_fn is None:
metric_fn = euclidean_distance
metric_fn = jax.vmap(jax.vmap(metric_fn, in_axes=(0, None)), in_axes=(None, 0))

# Compute the probability of neighbours on the original embedding.
vmapped_prob = jax.vmap(
jax.vmap(probability_fn, in_axes=(0, None, None)), in_axes=(None, 0, None)
)

P = vmapped_prob(X, X, perplexity)

@jax.grad
def loss_fn(x):
distances = metric_fn(x, x)
Q = jax.nn.softmax(-distances)
return kl_divergence(P, Q)

def train_step(x, _):
grads = loss_fn(x)
x_new = x - learning_rate * grads
return x_new, None

n_exageration = 250
X_new, _ = jax.lax.scan(train_step, X_new, xs=None, length=n_iter + n_exageration)

return X_new
57 changes: 57 additions & 0 deletions tsnex/tsne_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import jax
import jax.numpy as jnp
import pytest
import tsnex


@pytest.mark.parametrize(
"p,q,expected",
[
(
jnp.array([0.1, 0.9]),
jnp.array([0.9, 0.1]),
jnp.sum(jnp.array([0.1, 0.9]) * jnp.log(jnp.array([0.1, 0.9]) / jnp.array([0.9, 0.1]))),
),
(jnp.array([0.5, 0.5]), jnp.array([0.5, 0.5]), 0),
],
)
def test_kl_divergence(p, q, expected):
result = tsnex.tsne.kl_divergence(p, q)
assert jnp.isclose(result, expected, atol=1e-6)


@pytest.mark.parametrize(
"x,y,expected",
[
(jnp.array([0, 0]), jnp.array([1, 1]), 2),
(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]), 27),
],
)
def test_euclidean_distance(x, y, expected):
result = tsnex.tsne.euclidean_distance(x, y)
assert jnp.isclose(result, expected)


@pytest.mark.parametrize(
"x,y,sigma,expected",
[
(jnp.array([0, 0]), jnp.array([1, 1]), 1, jnp.exp(-1)),
],
)
def test_probability_fn(x, y, sigma, expected):
result = tsnex.tsne.probability_fn(x, y, sigma)
assert jnp.isclose(result, expected, atol=1e-6)


@pytest.mark.parametrize("init_method", ["pca", "random"])
def test_transform_shape_and_locality(init_method):
key = jax.random.key(0)
X = jax.random.normal(key, shape=(100, 50))
X_transformed = tsnex.transform(X, init=init_method, seed=42, n_iter=10)
assert X_transformed.shape == (100, 2)


def test_transform_invalid_init():
key = jax.random.key(0)
with pytest.raises(ValueError):
tsnex.transform(jax.random.uniform(key, shape=(10, 5)), init="invalid_init")

0 comments on commit c7f2103

Please sign in to comment.