Skip to content

Make Delta.log_prob jit-able on metal and raise explicit error for Delta sampling site during initialization. #3727

Make Delta.log_prob jit-able on metal and raise explicit error for Delta sampling site during initialization.

Make Delta.log_prob jit-able on metal and raise explicit error for Delta sampling site during initialization. #3727

Workflow file for this run

# adapted from https://github.com/actions/starter-workflows/blob/master/ci/python-package.yml
name: CI

# Run on pushes to master and on pull requests whose base branch is master.
on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

# Workflow-wide environment: extra pytest options that collect coverage for
# the numpyro package and append to existing coverage data between runs.
env:
  PYTEST_ADDOPTS: "--cov=numpyro --cov-append"

jobs:
# Lint the code base, build the documentation and run the doctests on every
# Python version in the matrix. The test jobs declare `needs: lint`.
lint:
  runs-on: ubuntu-latest
  strategy:
    matrix:
      # Quoted so that "3.10" is not parsed as the float 3.1.
      python-version: ["3.9", "3.10", "3.12"]
  steps:
    # checkout@v2 / setup-python@v2 target the deprecated Node.js 12 runtime;
    # use the currently maintained major versions instead.
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        # Refresh the package index first so the install does not fail on
        # stale runner-image metadata; apt-get is the script-stable CLI.
        sudo apt-get update
        sudo apt-get install -y pandoc gsfonts
        python -m pip install --upgrade pip
        pip install jaxlib
        pip install jax
        pip install .[doc,test]
        pip install https://github.com/pyro-ppl/funsor/archive/master.zip
        pip install -r docs/requirements.txt
        pip freeze
    - name: Lint with mypy and ruff
      run: |
        make lint
    - name: Build documentation
      run: |
        make docs
    - name: Test documentation
      run: |
        make doctest
        python -m doctest -v README.md
# Run the test suite excluding the example tests, test/infer/ and
# test/contrib/, then a targeted x64 run, and upload coverage to Coveralls.
test-modeling:
  runs-on: ubuntu-latest
  needs: lint
  strategy:
    matrix:
      # Quoted so that "3.10" is not parsed as the float 3.1.
      python-version: ["3.9", "3.10"]
  steps:
    # checkout@v2 / setup-python@v2 target the deprecated Node.js 12 runtime;
    # use the currently maintained major versions instead.
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        # Refresh the package index first so the install does not fail on
        # stale runner-image metadata; apt-get is the script-stable CLI.
        sudo apt-get update
        sudo apt-get install -y graphviz
        python -m pip install --upgrade pip
        # Keep track of pyro-api master branch
        pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
        pip install jaxlib
        pip install jax
        pip install https://github.com/pyro-ppl/funsor/archive/master.zip
        pip install -e .[dev,test]
        pip freeze
    - name: Test with pytest
      run: |
        CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
    - name: Test x64
      run: |
        JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k powerLaw
    - name: Coveralls
      # Upload only from the upstream repository, and only once per matrix.
      if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.10'
      uses: coverallsapp/github-action@v2
      with:
        github-token: ${{ secrets.GITHUB_TOKEN }}
        parallel: true
        flag-name: test-modeling
# Run the inference (test/infer) and contrib (test/contrib) test suites,
# including x64, multi-device, custom-PRNG and nested-sampling variants, and
# upload coverage to Coveralls.
test-inference:
  runs-on: ubuntu-latest
  needs: lint
  strategy:
    matrix:
      # Quoted so that "3.10" is not parsed as the float 3.1.
      python-version: ["3.9", "3.10"]
  steps:
    # checkout@v2 / setup-python@v2 target the deprecated Node.js 12 runtime;
    # use the currently maintained major versions instead.
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        # Keep track of pyro-api master branch
        pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
        pip install jaxlib
        pip install jax
        pip install https://github.com/pyro-ppl/funsor/archive/master.zip
        pip install -e .[dev,test]
        pip freeze
    - name: Test with pytest
      # test_nested_sampling.py and test_dcc.py are excluded here and run in
      # the dedicated steps below.
      run: |
        pytest -vs --durations=20 test/infer/test_mcmc.py
        pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py
        pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py
    - name: Test x64
      run: |
        JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64
    - name: Test chains
      # XLA is told to expose two host devices for these runs.
      run: |
        XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
        XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain"
        XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py
        XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
    - name: Test custom prng
      run: |
        JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py
    - name: Test nested sampling
      run: |
        JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py
    - name: Coveralls
      # Upload only from the upstream repository, and only once per matrix.
      if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.10'
      uses: coverallsapp/github-action@v2
      with:
        github-token: ${{ secrets.GITHUB_TOKEN }}
        parallel: true
        flag-name: test-inference
# Run the example tests (pytest -k test_example) with two forced host devices
# and upload coverage to Coveralls.
examples:
  runs-on: ubuntu-latest
  needs: lint
  strategy:
    matrix:
      # Quoted so that "3.10" is not parsed as the float 3.1.
      python-version: ["3.10"]
  steps:
    # checkout@v2 / setup-python@v2 target the deprecated Node.js 12 runtime;
    # use the currently maintained major versions instead.
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install jaxlib
        pip install jax
        pip install https://github.com/pyro-ppl/funsor/archive/master.zip
        pip install -e .[dev,examples,test]
        pip freeze
    - name: Test with pytest
      run: |
        CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example
    - name: Coveralls
      # Upload only from the upstream repository.
      if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.10'
      uses: coverallsapp/github-action@v2
      with:
        github-token: ${{ secrets.GITHUB_TOKEN }}
        parallel: true
        flag-name: examples
# Runs after all coverage-producing jobs and tells Coveralls that the
# parallel uploads (flags test-modeling, test-inference, examples) are done.
finish:
  needs: [test-modeling, test-inference, examples]
  runs-on: ubuntu-latest
  steps:
    - name: Coveralls finished
      # Same repository guard as the per-job Coveralls upload steps: outside
      # the upstream repository nothing was uploaded, so there is no parallel
      # build to close.
      if: github.repository == 'pyro-ppl/numpyro'
      uses: coverallsapp/github-action@v2
      with:
        github-token: ${{ secrets.GITHUB_TOKEN }}
        parallel-finished: true
        carryforward: "test-modeling,test-inference,examples"