Skip to content

Commit

Permalink
Deploying to gh-pages from @ ab4a438 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
kisnikser committed Dec 10, 2024
1 parent e04800f commit 3856c0c
Show file tree
Hide file tree
Showing 16 changed files with 594 additions and 315 deletions.
Binary file modified .doctrees/environment.pickle
Binary file not shown.
Binary file modified .doctrees/relaxit.distributions.doctree
Binary file not shown.
83 changes: 38 additions & 45 deletions _modules/relaxit/distributions/CorrelatedRelaxedBernoulli.html
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,7 @@
<div itemprop="articleBody">

<h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div class="highlight"><pre>
<span></span><span class="ch">#!/usr/bin/env python3</span>
<span class="c1"># -*- coding: utf-8 -*-</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd">The :mod:`relaxit.distributions.CorrelatedRelaxedBernoulli` contains classes:</span>

<span class="sd">- :class:`relaxit.distributions.CorrelatedRelaxedBernoulli.CorrelatedRelaxedBernoulli`</span>

<span class="sd">&quot;&quot;&quot;</span>
<span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span>

<span class="n">__docformat__</span> <span class="o">=</span> <span class="s2">&quot;restructuredtext&quot;</span>

<span class="kn">import</span> <span class="nn">torch</span>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">pyro.distributions.torch_distribution</span> <span class="kn">import</span> <span class="n">TorchDistribution</span>
<span class="kn">from</span> <span class="nn">torch.distributions</span> <span class="kn">import</span> <span class="n">constraints</span>
<span class="kn">from</span> <span class="nn">torch.distributions.normal</span> <span class="kn">import</span> <span class="n">Normal</span>
Expand All @@ -98,14 +86,12 @@ <h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div cl
<a class="viewcode-back" href="../../../relaxit.distributions.html#relaxit.distributions.CorrelatedRelaxedBernoulli.CorrelatedRelaxedBernoulli">[docs]</a>
<span class="k">class</span> <span class="nc">CorrelatedRelaxedBernoulli</span><span class="p">(</span><span class="n">TorchDistribution</span><span class="p">):</span>
<span class="w"> </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Correlated Relaxed Bernoulli distribution class inheriting from Pyro&#39;s TorchDistribution.</span>

<span class="sd"> :param pi: Selection probability vector.</span>
<span class="sd"> :type pi: torch.Tensor</span>
<span class="sd"> :param R: Covariance matrix.</span>
<span class="sd"> :type R: torch.Tensor</span>
<span class="sd"> :param tau: Temperature hyper-parameter.</span>
<span class="sd"> :type tau: torch.Tensor</span>
<span class="sd"> Correlated Relaxed Bernoulli distribution class from https://openreview.net/pdf?id=oDFvtxzPOx.</span>

<span class="sd"> Args:</span>
<span class="sd"> pi (torch.Tensor): Selection probability vector.</span>
<span class="sd"> R (torch.Tensor): Covariance matrix.</span>
<span class="sd"> tau (torch.Tensor): Temperature hyper-parameter.</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="n">arg_constraints</span> <span class="o">=</span> <span class="p">{</span>
Expand All @@ -123,16 +109,14 @@ <h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div cl
<span class="n">tau</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">validate_args</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Initializes the CorrelatedRelaxedBernoulli distribution.</span>

<span class="sd"> :param pi: Selection probability vector.</span>
<span class="sd"> :type pi: torch.Tensor</span>
<span class="sd"> :param R: Covariance matrix.</span>
<span class="sd"> :type R: torch.Tensor</span>
<span class="sd"> :param tau: Temperature hyper-parameter.</span>
<span class="sd"> :type tau: torch.Tensor</span>
<span class="sd"> :param validate_args: Whether to validate arguments.</span>
<span class="sd"> :type validate_args: bool</span>
<span class="w"> </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Initializes the CorrelatedRelaxedBernoulli distribution.</span>

<span class="sd"> Args:</span>
<span class="sd"> pi (torch.Tensor): Selection probability vector.</span>
<span class="sd"> R (torch.Tensor): Covariance matrix.</span>
<span class="sd"> tau (torch.Tensor): Temperature hyper-parameter.</span>
<span class="sd"> validate_args (bool, optional): Whether to validate arguments. Defaults to None.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">validate_args</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_validate_args</span><span class="p">(</span><span class="n">pi</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">tau</span><span class="p">)</span>
Expand All @@ -153,6 +137,9 @@ <h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div cl
<span class="sd"> The batch shape represents the shape of independent distributions.</span>
<span class="sd"> For example, if `pi` is a tensor of shape (batch_size, pi_shape),</span>
<span class="sd"> the batch shape will be `[batch_size]`, indicating batch_size independent Bernoulli distributions.</span>

<span class="sd"> Returns:</span>
<span class="sd"> torch.Size: The batch shape of the distribution.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

Expand All @@ -164,6 +151,9 @@ <h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div cl
<span class="sd"> The event shape represents the shape of each individual event.</span>
<span class="sd"> For example, if `pi` is a tensor of shape (batch_size, pi_shape),</span>
<span class="sd"> the event shape will be `[pi_shape]`.</span>

<span class="sd"> Returns:</span>
<span class="sd"> torch.Size: The event shape of the distribution.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span>

Expand All @@ -173,10 +163,11 @@ <h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div cl
<span class="w"> </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Generates a sample from the distribution using the reparameterization trick.</span>

<span class="sd"> :param sample_shape: The shape of the sample.</span>
<span class="sd"> :type sample_shape: torch.Size</span>
<span class="sd"> :return: A sample from the distribution.</span>
<span class="sd"> :rtype: torch.Tensor</span>
<span class="sd"> Args:</span>
<span class="sd"> sample_shape (torch.Size, optional): The shape of the sample. Defaults to torch.Size().</span>

<span class="sd"> Returns:</span>
<span class="sd"> torch.Tensor: A sample from the distribution.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># Sample from the standard multivariate normal distribution</span>
<span class="n">shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">sample_shape</span><span class="p">)</span> <span class="o">+</span> <span class="nb">tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pi</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
Expand Down Expand Up @@ -205,10 +196,11 @@ <h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div cl
<span class="w"> </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Generates a sample from the distribution.</span>

<span class="sd"> :param sample_shape: The shape of the sample.</span>
<span class="sd"> :type sample_shape: torch.Size</span>
<span class="sd"> :return: A sample from the distribution.</span>
<span class="sd"> :rtype: torch.Tensor</span>
<span class="sd"> Args:</span>
<span class="sd"> sample_shape (torch.Size, optional): The shape of the sample. Defaults to torch.Size().</span>

<span class="sd"> Returns:</span>
<span class="sd"> torch.Tensor: A sample from the distribution.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">rsample</span><span class="p">(</span><span class="n">sample_shape</span><span class="p">)</span></div>
Expand All @@ -220,10 +212,11 @@ <h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div cl
<span class="w"> </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Computes the log probability of the given value.</span>

<span class="sd"> :param value: The value for which to compute the log probability.</span>
<span class="sd"> :type value: torch.Tensor</span>
<span class="sd"> :return: The log probability of the given value.</span>
<span class="sd"> :rtype: torch.Tensor</span>
<span class="sd"> Args:</span>
<span class="sd"> value (torch.Tensor): The value for which to compute the log probability.</span>

<span class="sd"> Returns:</span>
<span class="sd"> torch.Tensor: The log probability of the given value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_validate_args</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_validate_sample</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
Expand All @@ -244,8 +237,8 @@ <h1>Source code for relaxit.distributions.CorrelatedRelaxedBernoulli</h1><div cl
<span class="w"> </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Validates the given sample value.</span>

<span class="sd"> :param value: The sample value to validate.</span>
<span class="sd"> :type value: torch.Tensor</span>
<span class="sd"> Args:</span>
<span class="sd"> value (torch.Tensor): The sample value to validate.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_validate_args</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">value</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">all</span><span class="p">()</span> <span class="ow">or</span> <span class="ow">not</span> <span class="p">(</span><span class="n">value</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">all</span><span class="p">():</span>
Expand Down
Loading

0 comments on commit 3856c0c

Please sign in to comment.