Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Vepricov committed Nov 5, 2024
1 parent 581227e commit 8c41573
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 29 deletions.
49 changes: 20 additions & 29 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,26 @@ name: Python Package using Conda
on: [push, pull_request, workflow_dispatch]

jobs:
build-linux:
build:
runs-on: ubuntu-latest

strategy:
max-parallel: 5

matrix:
python-version: [3.7]

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Install dependencies
run: |
conda env update --file environment.yml --name base
- name: Lint with flake8
run: |
conda install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
conda install pytest
pytest
- name: Testing
run: |
PYTHONPATH=tests/test_simple.py
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.7

- name: Checkout 🛎️
uses: actions/checkout@v2

- name: Install Dependencies
run: |
pip install torch
ls ./
- name: Testing
run: |
PYTHONPATH=tests/test_simple.py
26 changes: 26 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sys, os
import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
from relaxit.distributions import (
HardConcrete,
GaussianRelaxedBernoulli
)

def test_rsample():
# a = torch.tensor([0.2, 0.4, 0.3, 0.1], requires_grad=True)
loc = torch.tensor([0.], requires_grad=True)
scale = torch.tensor([1.], requires_grad=True)
alpha = torch.tensor([1.], requires_grad=True)
beta = torch.tensor([2.], requires_grad=True)
gamma = torch.tensor([-3.], requires_grad=True)
xi = torch.tensor([4.], requires_grad=True)

distr_2 = GaussianRelaxedBernoulli(loc = loc, scale=scale)
samples_2 = distr_2.rsample(sample_shape = torch.Size([3]))
assert samples_2.shape == torch.Size([3, 1])
assert samples_2.requires_grad == True
distr_3 = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
samples_3 = distr_3.rsample(sample_shape = torch.Size([3]))
assert samples_3.shape == torch.Size([3, 1])
assert samples_3.requires_grad == True
print("rsample is OK")

0 comments on commit 8c41573

Please sign in to comment.