Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent type-inferability escaping for rrule of sortslices #817

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Jan 1, 2025

In 1.11 something changed i guess with inlining, constant-propagation and/or unrolling.
And now inds = ntuple(d -> d == dims ? p : (:), N) doesn't infer.
It used to be able to work it out based on constant folding over dims and N
but now it gives back a Tuple{Union{Colon, Vector{Int64}}, Union{Colon, Vector{Int64}}}}
I couldn't workout how to get it to do that again.
But it is so cheap to recompute approprate inds since N is like under 5 most of the time, recomputing it is cheap.
(unlike recomputing p which is not)

This at least stops there being a non-bitstype field on the pullback closure.
and contains the inference failure from poluting outside the function.

Together with #816 we should then have 1.11 passing again.

Comment on lines 27 to +29
test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2))
test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last))
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false)
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2))
test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last))
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false)
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum))
test_rrule(sortslices, rand(3, 4); fkwargs=(; dims=2))
test_rrule(sortslices, rand(5, 4); fkwargs=(; dims=1, rev=true, by=last))
test_rrule(sortslices, rand(3, 4, 5); fkwargs=(; dims=3, by=sum))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant