Add example of PyMC usage #41
This is a great idea! It's related to this issue in NumPyro. The thing to figure out here is how to convert PyMC3's log-posterior into a log-prior and a log-likelihood that take in a single data point as well as the model parameters; that way the gradient estimators in sgmcmcjax can be used.
You can obtain the "log-prior" graph via `model.varlogpt`, instead of just using the full `model.logpt` (the observed-data part is `model.datalogpt`).
@ricardoV94: the log-prior graph sounds like the correct thing, yes. However, the log-likelihood would need to be a function without the data "baked in": a function that takes in the model parameters along with a single data point (or a minibatch). The algorithms in sgmcmcjax use this function to evaluate the log-likelihood for only a subset of the data at a time, as you mentioned. You can see how this happens in this page of the docs, in cell number 5.
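For reference, the interface sgmcmcjax expects looks roughly like this (adapted from the sgmcmcjax README; the Gaussian toy model, step size, and batch size are just illustrative):

```python
import jax.numpy as jnp
from jax import random
from sgmcmcjax.samplers import build_sgld_sampler

# Log-likelihood of a *single* data point x given parameters theta;
# sgmcmcjax vmaps this over each minibatch internally.
def loglikelihood(theta, x):
    return -0.5 * jnp.dot(x - theta, x - theta)

def logprior(theta):
    return -0.5 * jnp.dot(theta, theta) * 0.01

# Toy dataset: N points, each of dimension D
N, D = 10_000, 100
key = random.PRNGKey(0)
X_data = random.normal(key, shape=(N, D))

# Build an SGLD sampler that subsamples minibatches of the data, then run it
batch_size = int(0.1 * N)
sampler = build_sgld_sampler(1e-5, loglikelihood, logprior, (X_data,), batch_size)
samples = sampler(key, 10_000, jnp.zeros(D))
```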
How do you make sure the data is partitioned correctly? If I have a vectorized Normal likelihood with multiple observations, which axis do you partition on? Anyway, to have the data as an input, you can do something more involved like:

```python
import pymc as pm
from pymc.sampling_jax import get_jaxified_graph

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y", x, observed=[0, 1, 2, 3])

# Log-likelihood of the observed data, with the data baked in
print(model.compile_fn(model.datalogpt)({"x": 0}))

# Swap the constant observed data for free inputs
original_data = []
dummy_data_inputs = []
for observed_RV in model.observed_RVs:
    data = model.rvs_to_values[observed_RV]
    dummy_data_input = data.type()
    # TODO: You should revert these inplace changes after you're done
    model.rvs_to_values[observed_RV] = dummy_data_input
    original_data.append(data.data)
    dummy_data_inputs.append(dummy_data_input)

# Compile a JAX function that takes the data as an explicit argument
loglike_fn = get_jaxified_graph(
    inputs=model.value_vars + dummy_data_inputs,
    outputs=[model.datalogpt],
)

# The log-likelihood can now be evaluated on arbitrary subsets of the data
print(
    loglike_fn(0, original_data[0][:1]),
    loglike_fn(0, original_data[0][:2]),
    loglike_fn(0, original_data[0][:3]),
    loglike_fn(0, original_data[0][:4]),
    sep="\n",
)
```
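Assuming the snippet above works, gluing it into the signature sgmcmcjax expects might look something like this (untested sketch; `loglike_fn` and `original_data` come from the snippet above):

```python
import jax.numpy as jnp

# Hypothetical adapter (untested): sgmcmcjax wants loglikelihood(theta, x)
# for a single data point x. loglike_fn sums over whatever data it is given,
# so evaluating it on a length-1 slice gives one point's contribution.
def loglikelihood(theta, x):
    return loglike_fn(theta, jnp.atleast_1d(x))[0]

data = (jnp.asarray(original_data[0]),)  # the dataset to subsample from
```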
Your second link should be: https://github.com/blackjax-devs/blackjax/blob/main/examples/use_with_pymc.ipynb
I don't quite understand what this means; could you explain some more, please? In case this is what you were asking about: the standard way to estimate the log-likelihood for these SGMCMC algorithms is to use equation (4) in this paper. This is implemented in this library here.
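For reference, the unbiased minibatch estimator in question (assuming equation (4) is the standard one from the SGMCMC literature) rescales the minibatch sum by N/n:

```latex
% S is a random minibatch of size n drawn from the N data points:
\widehat{\log p(x_{1:N} \mid \theta)}
  = \frac{N}{n} \sum_{i \in S} \log p(x_i \mid \theta)
```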
@jeremiecoullon thanks for the reply. The idea of splitting the data in the likelihood just seemed surprising to me. For instance, in a linear regression each "datapoint" includes multiple predictors plus the observation(s), and with a multivariate likelihood you may have several observations per "datapoint". I was just curious how you handle those cases. Anyway, let me know if the snippet above is sufficient to make an example with PyMC :)
Splitting data in the likelihood: the approach is exactly the same as in stochastic gradient descent (and related algorithms like Adam, RMSProp, etc.).

Are you asking what happens when the data is high-dimensional, and what we do in the case of supervised learning? To give an example, consider a dataset of 5 points: D = {(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5)}. This is a supervised problem, and each x_i may be high-dimensional. A minibatch of data might be D_minibatch = {(x2, y2), (x5, y5)}, or D_minibatch = {(x1, y1), (x4, y4), (x5, y5)}. The likelihood for this smaller dataset is faster to evaluate than the likelihood for the entire dataset. So you don't "split" the dimensionality of x_i, and you always keep x_i and y_i together; see the sketch below. Is this what you meant?

Snippet: I'll have a look at it to understand it and see if that works!
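A minimal sketch of that subsampling in JAX (the dataset and minibatch size are made up; the point is that X and y are indexed with the same rows):

```python
import jax.numpy as jnp
from jax import random

N, D, n = 5, 10, 2
key, subkey = random.split(random.PRNGKey(0))
X = random.normal(key, shape=(N, D))   # 5 points, each x_i is 10-dimensional
y = random.normal(subkey, shape=(N,))

# Sample minibatch indices, then take the *same* rows from X and y,
# so each (x_i, y_i) pair stays together and x_i is never split.
idx = random.choice(subkey, N, shape=(n,), replace=False)
X_batch, y_batch = X[idx], y[idx]
```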
Yes that's what I meant! Thanks for clarifying.
I added a notebook with a basic Gaussian model example working. Some questions:
This is great @jeremiecoullon, thanks for adding that! As for SolveTriangular: we still need to add a JAX implementation for it.
PyMC v4 has a JAX backend and can use samplers like those from NumPyro or BlackJAX, so it should be pretty easy to add an example of how to use SGMCMCJax with a PyMC model.
https://github.com/pymc-devs/pymc/blob/main/pymc/sampling_jax.py#L141
https://github.com/blackjax-devs/blackjax/blob/main/examples/use_with_pymc3.ipynb
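For context, a rough sketch of what the linked code already provides (PyMC v4 API as of this writing; `get_jaxified_logp` is what the BlackJAX example uses to get the full-data log-posterior as a JAX function, and the remaining work for SGMCMCJax would be splitting it into a log-prior and a per-data-point log-likelihood):

```python
import pymc as pm
from pymc.sampling_jax import get_jaxified_logp

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y", x, observed=[0, 1, 2, 3])

# JAX-traceable log-posterior; takes a list with one value per free variable
logp_fn = get_jaxified_logp(model)
print(logp_fn([0.0]))
```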