
Add some sort of batched reduction to desc.batching #1564

Closed
f0uriest opened this issue Feb 4, 2025 · 1 comment · Fixed by #1577
Labels
P3 (Highest Priority, someone is/should be actively working on this), performance (New feature or request to make the code faster)

Comments

@f0uriest
Member

f0uriest commented Feb 4, 2025

> Not for this PR, but at some point it might be worth implementing some sort of batched reduction in `desc.batching`.

Originally posted by @f0uriest in #1440 (comment)

The use case is evaluating some expensive function over many inputs and then reducing the results (e.g. summing them). Currently the two main options are a single loop, which is slow on GPU, or a full vmap, which means materializing the full array of results in memory (or hoping the compiler fuses operations so that it isn't, which is not guaranteed). One could potentially speed up the loop by partially unrolling it, though that increases compile time. Ideally we want something that takes advantage of the fact that the loop iterations are independent.
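For concreteness, the two existing options described above might look roughly like this in JAX. This is a minimal sketch, not code from `desc.batching`; `expensive_fun`, `loop_reduce`, and `vmap_reduce` are hypothetical names, and the reduction is fixed to a sum.

```python
import jax
import jax.numpy as jnp


def expensive_fun(xi):
    # Hypothetical stand-in for an expensive per-input computation.
    return jnp.sin(xi) ** 2 + jnp.cos(2.0 * xi)


def loop_reduce(fun, x):
    # Option 1: a sequential loop. Only one result is live at a time,
    # but iterations run one after another, which underuses a GPU.
    out = jax.eval_shape(fun, x[0])  # output shape/dtype without running fun
    init = jnp.zeros(out.shape, out.dtype)

    def body(i, acc):
        return acc + fun(x[i])

    return jax.lax.fori_loop(0, x.shape[0], body, init)


def vmap_reduce(fun, x):
    # Option 2: vectorize over all inputs at once. Fully parallel, but the
    # whole array of per-input results exists before the sum (unless XLA
    # happens to fuse it away).
    return jnp.sum(jax.vmap(fun)(x), axis=0)
```

Both compute the same value; they differ only in how much parallelism they expose and how much intermediate memory they need.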

Semantically, it should be equivalent to the following:

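```python
import jax.numpy as jnp

def unbatched_reduce(fun, x, reduction=jnp.add):
    # Apply fun to each element of x in turn and fold the results together
    # with reduction; the initial value 0 assumes an additive reduction.
    out = 0
    for xi in x:
        out = reduction(out, fun(xi))
    return out
```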
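One way to exploit the independence of the iterations is to split the inputs into fixed-size chunks, `vmap` within each chunk, and loop over the chunks with `jax.lax.scan`, accumulating the reduced value. The sketch below is a hypothetical illustration, not necessarily how `desc.batching` does it: `batched_sum` and its `batch_size` parameter are assumed names, it is specialized to summation (the default `reduction` above), and it assumes the number of inputs is divisible by `batch_size`.

```python
import jax
import jax.numpy as jnp


def batched_sum(fun, x, batch_size):
    """Compute sum(fun(xi) for xi in x), vectorizing within chunks.

    Hypothetical helper: assumes x.shape[0] is divisible by batch_size.
    Roughly one chunk of results (batch_size outputs) is live at a time,
    rather than all x.shape[0] of them.
    """
    n = x.shape[0]
    if n % batch_size != 0:
        raise ValueError("this sketch requires len(x) % batch_size == 0")
    # Split the leading axis into (num_batches, batch_size, ...).
    chunks = x.reshape((n // batch_size, batch_size) + x.shape[1:])

    # Zero accumulator with fun's output shape/dtype, without evaluating fun.
    out = jax.eval_shape(fun, x[0])
    init = jnp.zeros(out.shape, out.dtype)

    def step(acc, chunk):
        # Independent evaluations within a chunk run in parallel via vmap;
        # the chunk is summed before being added to the running total.
        return acc + jnp.sum(jax.vmap(fun)(chunk), axis=0), None

    total, _ = jax.lax.scan(step, init, chunks)
    return total
```

With `batch_size=1` this degenerates to the sequential loop, and with `batch_size=len(x)` it becomes the full `vmap`, so the chunk size trades memory for parallelism. A general version would also need to handle a remainder (e.g. by padding and masking) and reductions other than `+`.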

@unalmis added the performance label Feb 9, 2025
@unalmis self-assigned this Feb 9, 2025
@unalmis added the easy label Feb 9, 2025
@unalmis removed the easy label Feb 9, 2025
@dpanici added the P3 label Feb 10, 2025
@unalmis
Collaborator

unalmis commented Feb 12, 2025

@f0uriest, if you think this could be useful for partially applying a Jacobian/Hessian before materializing it completely, please make an issue for that at some point.
