diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4d0b0b3..ad05ea9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,35 +3,26 @@ name: Python Package using Conda on: [push, pull_request, workflow_dispatch] jobs: - build-linux: + build: runs-on: ubuntu-latest + strategy: - max-parallel: 5 - + matrix: + python-version: [3.7] + steps: - - uses: actions/checkout@v4 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 - with: - python-version: '3.10' - - name: Add conda to system path - run: | - # $CONDA is an environment variable pointing to the root of the miniconda directory - echo $CONDA/bin >> $GITHUB_PATH - - name: Install dependencies - run: | - conda env update --file environment.yml --name base - - name: Lint with flake8 - run: | - conda install flake8 - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - conda install pytest - pytest - - name: Testing - run: | - PYTHONPATH=tests/test_simple.py + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.7 + + - name: Checkout 🛎️ + uses: actions/checkout@v2 + + - name: Install Dependencies + run: | + pip install torch + ls ./ + - name: Testing + run: | + PYTHONPATH=tests/test_simple.py diff --git a/tests/test_simple.py b/tests/test_simple.py new file mode 100644 index 0000000..e76990b --- /dev/null +++ b/tests/test_simple.py @@ -0,0 +1,26 @@ +import sys, os +import torch +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) +from relaxit.distributions import ( + HardConcrete, + GaussianRelaxedBernoulli +) + +def test_rsample(): + # a = torch.tensor([0.2, 0.4, 0.3, 0.1], requires_grad=True) + loc = torch.tensor([0.], requires_grad=True) + scale = torch.tensor([1.], requires_grad=True) + alpha = torch.tensor([1.], requires_grad=True) + beta = torch.tensor([2.], requires_grad=True) + gamma = torch.tensor([-3.], requires_grad=True) + xi = torch.tensor([4.], requires_grad=True) + + distr_2 = GaussianRelaxedBernoulli(loc = loc, scale=scale) + samples_2 = distr_2.rsample(sample_shape = torch.Size([3])) + assert samples_2.shape == torch.Size([3, 1]) + assert samples_2.requires_grad == True + distr_3 = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi) + samples_3 = distr_3.rsample(sample_shape = torch.Size([3])) + assert samples_3.shape == torch.Size([3, 1]) + assert samples_3.requires_grad == True + print("rsample is OK") \ No newline at end of file