
extract_central_slices_rfft performance #62

Open
rsanchezgarc opened this issue Feb 23, 2024 · 4 comments
Labels
enhancement New feature or request

Comments

@rsanchezgarc
Collaborator

Hi,
I have been profiling the function `extract_central_slices_rfft` and found that the two lines applying `conjugate_mask` account for a significant fraction of the total execution time:

```python
grid[conjugate_mask] *= -1
projections[conjugate_mask] = torch.conj(projections[conjugate_mask])
```

These are the results of my profiler (run with `CUDA_LAUNCH_BLOCKING=1`, so that CUDA kernels execute synchronously and their time is attributed to the lines that launch them):

```
Function: extract_central_slices_rfft at line 58

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    58                                           @profile
    59                                           def extract_central_slices_rfft(
    60                                               dft: torch.Tensor,
    61                                               image_shape: tuple[int, int, int],
    62                                               rotation_matrices: torch.Tensor,
    63                                               rotation_matrix_zyx: bool,
    64                                           ):
    65                                               """Extract central slice from an fftshifted rfft."""
    66                                               # generate grid of DFT sample frequencies for a central slice in the xy-plane
    67                                               # these are a coordinate grid for the DFT
    68      1050    7847185.4   7473.5     42.5      grid = rotated_central_slice_grid(
    69       525        391.5      0.7      0.0          image_shape=image_shape,
    70       525        246.1      0.5      0.0          rotation_matrices=rotation_matrices,
    71       525        151.7      0.3      0.0          rotation_matrix_zyx=rotation_matrix_zyx,
    72       525        378.2      0.7      0.0          rfft=True,
    73       525        114.8      0.2      0.0          fftshift=True,
    74       525        821.4      1.6      0.0          device=dft.device,
    75                                               )  # (..., h, w, 3)
    76                                           
    77                                               # flip coordinates in redundant half transform
    78       525     213000.9    405.7      1.2      conjugate_mask = grid[..., 2] < 0
    79                                               # conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') #This operation does not compile
    80       525      29983.9     57.1      0.2      conjugate_mask = conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3) #This does compile
    81       525    3829907.7   7295.1     20.7      grid[conjugate_mask] *= -1 #This is super slower. masked_scatter_ seems 15% faster, still slow #TODO: This is the cornercase
    82                                               # grid.masked_scatter_(conjugate_mask, -1 * grid.masked_select(conjugate_mask))
    83                                           
    84       525      20220.4     38.5      0.1      conjugate_mask = conjugate_mask[..., 0]  # un-repeat
    85                                           
    86                                               # convert frequencies to array coordinates and sample from DFT
    87      1050    2115877.3   2015.1     11.5      grid = fftfreq_to_dft_coordinates(
    88       525        566.8      1.1      0.0          frequencies=grid,
    89       525        631.4      1.2      0.0          image_shape=image_shape,
    90       525        417.4      0.8      0.0          rfft=True
    91                                               )
    92       525    2975877.7   5668.3     16.1      projections = sample_dft_3d(dft=dft, coordinates=grid)  # (..., h, w) rfft
    93                                           
    94                                               # take complex conjugate of values from redundant half transform
    95       525    1429438.3   2722.7      7.7      projections[conjugate_mask] = torch.conj(projections[conjugate_mask]) #This is slower
    96                                               # projections.masked_scatter_(conjugate_mask, torch.conj(projections.masked_select(conjugate_mask)))
    97       525       1142.1      2.2      0.0      return projections
```
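For context on why the mask exists: a real-valued volume has a Hermitian-symmetric DFT, F(-k) = conj(F(k)), so an rfft stores only the half-space with a non-negative last-axis frequency. Samples whose rotated coordinate lands in the missing half (`grid[..., 2] < 0`) are read at the negated coordinate and their value is conjugated. The standalone sketch below checks that identity with `torch.fft`; it uses plain (non-fftshifted) integer indices rather than the library's centred coordinates, and `value_at` is a hypothetical helper for illustration only.

```python
import torch

# A small real-valued 3D volume.
torch.manual_seed(0)
n = 8
volume = torch.randn(n, n, n, dtype=torch.float64)

full = torch.fft.fftn(volume)   # full complex DFT, shape (n, n, n)
half = torch.fft.rfftn(volume)  # rfft keeps only last-axis frequencies 0..n//2

def value_at(k0: int, k1: int, k2: int) -> torch.Tensor:
    """Read F(k) from the rfft, using Hermitian symmetry when k2 < 0."""
    if k2 < 0:
        # Redundant half: flip the whole coordinate and conjugate the value.
        return torch.conj(half[-k0 % n, -k1 % n, -k2])
    return half[k0 % n, k1 % n, k2]

k = (3, -2, -3)
expected = full[k[0] % n, k[1] % n, k[2] % n]
assert torch.allclose(value_at(*k), expected)
print(half.shape)  # torch.Size([8, 8, 5])
```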



The following changes speed up the code slightly:

```python
grid.masked_scatter_(conjugate_mask, -1 * grid.masked_select(conjugate_mask))
projections.masked_scatter_(conjugate_mask, torch.conj(projections.masked_select(conjugate_mask)))
```
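To make the semantics of the `masked_scatter_` variant concrete: `masked_select` returns the elements where the mask is `True` as a flat 1D tensor in row-major order, and `masked_scatter_` writes a source tensor's elements, in that same order, back into the `True` positions. A small CPU check on toy shapes (not the real grid) suggests both forms produce identical results:

```python
import torch

torch.manual_seed(0)
# Toy stand-ins for the (..., h, w, 3) grid and the (..., h, w) projections.
grid = torch.randn(2, 4, 3, 3)
projections = torch.randn(2, 4, 3, dtype=torch.complex64)

mask2d = grid[..., 2] < 0                      # (..., h, w)
mask3d = mask2d.unsqueeze(-1).expand_as(grid)  # (..., h, w, 3)

# Boolean-indexing version, as in the profiled code.
grid_a = grid.clone()
grid_a[mask3d] *= -1
proj_a = projections.clone()
proj_a[mask2d] = torch.conj(proj_a[mask2d])

# masked_scatter_ version, as proposed above.
grid_b = grid.clone()
grid_b.masked_scatter_(mask3d, -1 * grid_b.masked_select(mask3d))
proj_b = projections.clone()
proj_b.masked_scatter_(mask2d, torch.conj(proj_b.masked_select(mask2d)))

assert torch.equal(grid_a, grid_b)
assert torch.equal(proj_a, proj_b)
```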
rsanchezgarc added the enhancement (New feature or request) label on Feb 23, 2024
@alisterburt
Collaborator

Hey! Thanks for this analysis.

Am I reading this right that it would save 15% of 20% of the total execution time, i.e. about 3% of the total?
If so, it's a nice optimisation, but the `masked_scatter_`/`masked_select` API feels so unintuitive to me that I'd hesitate to use it here.
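The estimate can be checked against the `% Time` column of the profile; the 15% figure is the one noted in the profiled code's comment for the `grid` line, and the sketch applies it to that line only:

```python
# Per-line share of total runtime, taken from the "% Time" column of the profile.
grid_mask_pct = 20.7        # grid[conjugate_mask] *= -1
projections_mask_pct = 7.7  # projections[conjugate_mask] = torch.conj(...)
speedup_fraction = 0.15     # "masked_scatter_ seems 15% faster" (grid line)

# Total share of runtime spent in the two mask lines.
mask_total_pct = grid_mask_pct + projections_mask_pct
# Overall saving if only the grid line gets 15% faster.
saving_grid_only = speedup_fraction * grid_mask_pct

print(f"mask lines: {mask_total_pct:.1f}% of runtime")
print(f"overall saving (grid line only): {saving_grid_only:.1f}%")
```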

The biggest optimisation I haven't done yet is limiting the Fourier coefficients used when extracting/inserting; that would be a huge speedup in lots of cases.

@rsanchezgarc
Collaborator Author

You are correct: the overall impact will be <4%, and if you manage to compile the code there will probably be no difference (I am not yet able to compile the code into something useful, but that is another story), so I am not saying we should use this trick. The point of this issue is to keep track of this bottleneck (almost 30% of the running time is spent applying the mask) so we can think about how to deal with it; my suggestion was just one of the many things I tried. I am quite surprised by the large impact these two lines have on the overall performance of my projection-matching code.
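One direction that might be worth trying (untested here, not measured in this thread): boolean-mask indexing such as `grid[conjugate_mask]` produces a data-dependent output size, so on CUDA it needs a device-to-host synchronization to count the `True` elements, plus gather/scatter kernels. Purely elementwise arithmetic and `torch.where` keep every shape static, which also tends to suit `torch.compile` better. A sketch assuming the same `(..., h, w, 3)` grid and `(..., h, w)` projections as the profiled code; `flip_redundant_half` and `conjugate_redundant_half` are hypothetical helper names:

```python
import torch

def flip_redundant_half(grid: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Negate coordinates whose last component is negative, without boolean indexing.

    Returns the flipped (..., h, w, 3) grid and the (..., h, w) conjugate mask.
    """
    conjugate_mask = grid[..., 2] < 0
    # +1 where the mask is False, -1 where it is True; static shapes throughout.
    sign = 1 - 2 * conjugate_mask.to(grid.dtype)
    return grid * sign.unsqueeze(-1), conjugate_mask

def conjugate_redundant_half(projections: torch.Tensor, conjugate_mask: torch.Tensor) -> torch.Tensor:
    """Conjugate masked values with an elementwise select instead of index_put."""
    return torch.where(conjugate_mask, torch.conj(projections), projections)

# Usage on a toy grid: every z component is non-negative afterwards.
grid = torch.randn(2, 4, 3, 3)
grid, conjugate_mask = flip_redundant_half(grid)
assert bool((grid[..., 2] >= 0).all())
```

The trade-off is that every element is multiplied (and every projection conjugated) rather than only the masked ones, exchanging extra arithmetic for the removal of the synchronization and the scatter.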

@alisterburt
Collaborator

I'm with you, really appreciate the effort! I agree it's quite big, but a lot of elements are being modified at once, so it doesn't feel so surprising. If you don't implement it this way, you have to do twice as much work in `rotated_central_slice_grid`.

If we add a similar function, `rotated_central_slice_coords`, which returns a `(..., d, 3)` array of coordinates up to a certain `fftfreq_max`, we can remove a lot of work when not all frequencies are needed. Some other code would also need to be modified to work with this API.
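A rough sketch of what such a function could look like, under stated assumptions: coordinates are ordered (z, y, x) with the rfft halved along the last axis (matching the `grid[..., 2] < 0` mask), frequencies are in cycles/sample, and rotations are applied as `R @ v`. The name and signature follow the proposal above but are hypothetical, not an existing library API. The key point is that the radial cutoff is applied once to the unrotated slice, so every subsequent rotation and sampling step only touches `d` points:

```python
import torch

def rotated_central_slice_coords(
    image_shape: tuple[int, int],
    rotation_matrices: torch.Tensor,  # (..., 3, 3)
    fftfreq_max: float,
) -> torch.Tensor:
    """Hypothetical sketch: rotated rfft central-slice frequencies with |f| <= fftfreq_max.

    Returns a (..., d, 3) tensor, where d is the number of retained samples.
    """
    h, w = image_shape
    fy = torch.fft.fftfreq(h)   # (h,)
    fx = torch.fft.rfftfreq(w)  # (w // 2 + 1,)
    yy, xx = torch.meshgrid(fy, fx, indexing="ij")
    # The unrotated central slice lies in the plane z = 0; order is (z, y, x).
    slice_freqs = torch.stack([torch.zeros_like(yy), yy, xx], dim=-1)  # (h, w//2+1, 3)
    # Apply the radial cutoff once, before rotating, so later work scales with d.
    keep = torch.linalg.norm(slice_freqs, dim=-1) <= fftfreq_max
    coords = slice_freqs[keep]  # (d, 3)
    # Rotate every retained coordinate by every matrix: (..., 1, 3, 3) @ (d, 3, 1).
    rotated = rotation_matrices[..., None, :, :] @ coords[..., None]  # (..., d, 3, 1)
    return rotated[..., 0]

# Usage: five identity rotations of a 32 x 32 slice, keeping |f| <= 0.25.
coords = rotated_central_slice_coords((32, 32), torch.eye(3).expand(5, 3, 3), 0.25)
print(coords.shape)  # (5, d, 3)
```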

@rsanchezgarc
Collaborator Author

> I'm with you, really appreciate the effort! I agree it's quite big but it's a lot of elements being modified at once so it doesn't feel so surprising - if you don't implement it this way then you have to do twice as much work in rotated_central_slice_grid
>
> if we add a similar function rotated_central_slice_coords which returns a (..., d, 3) array of coords up to a certain fftfreq_max we can remove a lot of work if not all frequencies are needed -> some other code would also need to be modified to work with this API elsewhere

I think this will be quite a useful feature if it cuts computing time by more than a factor of 2, but I am pretty sure the implementation will be trickier than it sounds, so it is totally up to you whether it is worth it.

One comment about the advantage of using the mask over running `rotated_central_slice_grid` twice: perhaps `torch.cuda.stream()` could run the two executions in parallel efficiently, leading to a shorter execution time. I have never tried it, but who knows.
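A minimal sketch of that idea using PyTorch's stream API (`torch.cuda.Stream`, the `torch.cuda.stream` context manager, `wait_stream`, `record_stream`). The helper `run_on_two_streams` is hypothetical, it assumes each callable returns a single tensor, and any real gain depends on whether one `rotated_central_slice_grid` call already saturates the GPU; if it does, the two kernels will simply serialize.

```python
import torch

def run_on_two_streams(fn_a, fn_b):
    """Hypothetical helper: launch two independent GPU workloads on separate CUDA streams.

    Assumes fn_a and fn_b each return one tensor. Falls back to sequential
    execution when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return fn_a(), fn_b()

    main = torch.cuda.current_stream()
    stream_a, stream_b = torch.cuda.Stream(), torch.cuda.Stream()
    # Side streams must not start before work already queued on the main stream
    # (e.g. producing the rotation matrices they read).
    stream_a.wait_stream(main)
    stream_b.wait_stream(main)
    with torch.cuda.stream(stream_a):
        out_a = fn_a()
    with torch.cuda.stream(stream_b):
        out_b = fn_b()
    # The main stream must wait for both results before anything consumes them.
    main.wait_stream(stream_a)
    main.wait_stream(stream_b)
    # Tell the caching allocator these tensors are now also used on the main stream.
    out_a.record_stream(main)
    out_b.record_stream(main)
    return out_a, out_b
```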

Projects
None yet
Development

No branches or pull requests

2 participants