Modify/Refactor Recompile #1066

Open
seanlaw opened this issue Jan 25, 2025 · 11 comments
Labels
question Further information is requested

Comments

@seanlaw
Contributor

seanlaw commented Jan 25, 2025

Recently, @NimaSarajpoor added the cache._recompile() function in PR #1048. Since then, I have read:

  1. https://numba.discourse.group/t/jit-recompile-with-new-arguments/1137
  2. https://numba.discourse.group/t/switching-between-parallel-serial-mode/1125

Specifically, this response caught my attention:

RE: .recompile() etc. The Numba dispatcher object (the thing returned by @jit) is configured directly from the @jit decorator options, the configuration is not intended to be changed once the dispatcher object is constructed. The recompile() method exists to handle cases such as a global variable that a @jit function refers to changing (recall that Numba considers global variables as compile time constants, so if they change, need to recompile).

I am starting to question whether we should be calling func.recompile()? Instead, I don't mind the monkey-patching solution. The key difference is that we are NOT monkey-patching an EXTERNAL module function and, instead, we are patching an INTERNAL module function, which is fine/safe!

Additionally, it seems much more maintainable.

This might resolve/simplify the unit tests that failed when NUMBA_JIT_DISABLE=1.

seanlaw added the question label Jan 25, 2025
@seanlaw
Contributor Author

seanlaw commented Jan 25, 2025

@NimaSarajpoor It looks like we can possibly avoid doing func.recompile() by doing monkey-patching like this:

from numba import njit, config
import time

@njit(fastmath=True)
def _add(a, b):
    return a + b

if __name__ == "__main__":
    # Compile and run
    start = time.time()
    _add(10, 20)
    print(time.time() - start)  # 0.16077685356140137

    # Run without compiling
    start = time.time()
    _add(10, 20)
    print(time.time() - start)  # 2.6941299438476562e-05

    # Check flags
    print(_add.targetoptions['fastmath'])  # True

    # Prepare to change the function
    _add_py_func = _add.py_func

    # Change the function/flag
    _add = njit(_add_py_func, fastmath=False)

    # Check flags
    print(_add.targetoptions['fastmath'])  # False

    # Compile and run
    start = time.time()
    _add(10, 20)
    print(time.time() - start)  # 0.013135671615600586

    # Run without compiling
    start = time.time()
    _add(10, 20)
    print(time.time() - start)  # 2.8133392333984375e-05

    # Change the function/flag BACK
    _add = njit(_add_py_func, fastmath=True)

    # Check flags
    print(_add.targetoptions['fastmath'])  # True

    # Compile and run
    start = time.time()
    _add(10, 20)
    print(time.time() - start)  # 0.015630722045898438

    # Run without compiling
    start = time.time()
    _add(10, 20)
    print(time.time() - start)

@seanlaw
Contributor Author

seanlaw commented Jan 25, 2025

Of course, we should really check this with the fastmath._add_assoc() example but I'm fairly certain it will work. This feels more natural/easier to maintain than calling func.recompile()

seanlaw changed the title from "Modify Recompile" to "Modify/Refactor Recompile" Jan 25, 2025
@seanlaw
Contributor Author

seanlaw commented Jan 26, 2025

To avoid func.recompile, it appears that we can simply:

  1. copy the python function
  2. copy the targetoptions dict
  3. clean it up a bit (i.e., remove the "nopython" key since we are already calling njit)
  4. make any necessary changes (i.e., fastmath=False)
  5. create a new njitted-function using the modified dictionary/signature

from numba import njit

@njit(fastmath=True)
def _add(a, b):
    return a + b

if __name__ == "__main__":
    _add(10, 20)

    py_func = _add.py_func  # Step 1
    njit_signature = _add.targetoptions.copy()  # Step 2
    njit_signature.pop('nopython', None)  # Step 3
    njit_signature['fastmath'] = False  # Step 4
    _add = njit(py_func, **njit_signature)  # Step 5

Now, when _add is called, it will NOT use fastmath. I think this also means that instead of storing the default fastmath flags, we'll simply store the entire (cleaned up) njit_signature for each STUMPY njit function

@NimaSarajpoor
Collaborator

NimaSarajpoor commented Jan 26, 2025

@seanlaw

I am starting to question whether we should be calling func.recompile()?
more natural/easier to maintain than calling func.recompile()

I think it is more natural because it is like overwriting "code" (which is the actual goal here) rather than recompiling it.

_add = njit(_add_py_func, fastmath=False) should work as you pointed out. I am curious to know what fastmath._set will look like. Currently, it shows:

module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)
try:
    func.targetoptions["fastmath"] = flag
    func.recompile()

I think if we revise the lines 57 & 58 as follows:

func = njit(func.py_func, fastmath=flag)
setattr(module, func_name, func)

it should work as tests are passing (I removed the check for fastmath=config just to test this part locally). Then maybe we take a look at the try-except block and see if we can avoid it.


  1. make any necessary changes (i.e., fastmath=False)
  2. create a new njitted-function using the modified dictionary/signature

I think this also means that instead of storing the default fastmath flags, we'll simply store the entire (cleaned up) njit_signature for each STUMPY njit function

I am trying to understand if it is okay to have such flexibility. I am thinking about what you mentioned in another case that sometimes we want to limit the scope of the function. Let's say someone decides to set the key parallel to True. Does it work? I think it works if the non-parallelized function already uses prange (because, in non-parallel mode, it will be treated as range and it is okay). But, IIUC, when the function uses range in a for-loop, then setting the parallel to True should not work.


[Update]
Regarding:

module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)
try:
    func.targetoptions["fastmath"] = flag
    func.recompile()

Replacing line 57 with func = njit(func.py_func, fastmath=flag) is a bad idea because it ignores the other arguments in the njit signature. So, we can keep func.targetoptions["fastmath"] = flag and just replace line 58 with:

setattr(module, func_name, func)

@seanlaw
Contributor Author

seanlaw commented Jan 26, 2025

@NimaSarajpoor Let me submit a PR for you to review

@seanlaw
Contributor Author

seanlaw commented Jan 28, 2025

@NimaSarajpoor Based on your review of my PR, I have come to the conclusion that this is NOT better, so I will close the PR. Thank you for your insights. The only question that I have is whether we should remove the func.recompile() line in the fastmath._set() function, because we ALWAYS need to call cache._recompile() afterwards anyway to ensure that all parent functions are using the latest target version of func. So, func.recompile() is redundant since calling it by itself is almost never sufficient.

Instead, we should add a warning to fastmath._set() and fastmath._reset() to remind the user (i.e., us) to always explicitly call cache._recompile() when they are done changing the fastmath flags.

@NimaSarajpoor
Collaborator

NimaSarajpoor commented Jan 28, 2025

@seanlaw

Instead, we should add a warning to fastmath._set() and fastmath._reset() to remind the user (i.e., us) to always explicitly call cache._recompile() when they are done changing the fastmath flags.

Correct. Adding the warning is definitely a good idea regardless of keeping / removing the line func.recompile().

The only question that I have is whether we should remove the func.recompile() line in the fastmath._set() function

So, func.recompile() is redundant since calling it by itself is almost never enough/sufficient.

Do not know the correct answer but going to share my opinion and leave the decision to you as you have better experience. My intention was to make sure fastmath._set does what it says for one njit function, meaning if a user calls fastmath._set() to set fastmath flag for a particular function, I wanted to make sure that it becomes effective for THAT function without having the need to call another function.

so I will close the PR.

I think we should open an issue for the concern that was raised regarding the cache path. This can help us to track it separately. What do you think?

@seanlaw
Contributor Author

seanlaw commented Jan 28, 2025

Do not know the correct answer but going to share my opinion and leave the decision to you as you have better experience. My intention was to make sure fastmath._set does what it says for one njit function, meaning if a user calls fastmath._set() to set fastmath flag for a particular function, I wanted to make sure that it becomes effective for THAT function without having the need to call another function.

Yes, I understood the intention. However, it gives the wrong impression that calling fastmath._set is all that you have to do and that STUMPY will operate "correctly" thereafter, and this is simply not the case. Even I made the wrong assumption, so we should protect ourselves, because directly calling func.recompile() is very misleading when you read the code (and even more confusing if we add the warning on top of it!).

I think we should open an issue for #1067 (comment). This can help us to track it separately. What do you think?

Yes, please see #1069

@NimaSarajpoor
Collaborator

NimaSarajpoor commented Jan 29, 2025

@seanlaw
I can now see your view. Thanks for the explanation!

I can submit a new PR to remove .recompile and add a warning. In addition, I can add a note to the docstring of fastmath._set()
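
A rough idea of what the warning and the docstring note might look like is sketched below. `_set_sketch` is a hypothetical stand-in, not the actual `fastmath._set()` implementation, and the message wording is only a placeholder:

```python
import warnings


def _set_sketch(module_name, func_name, flag):
    """
    Hypothetical stand-in for `fastmath._set()`.

    Notes
    -----
    Changing the flag is not sufficient on its own. Call `cache._recompile()`
    afterwards so that all callers pick up the new setting.
    """
    # The real function would update func.targetoptions["fastmath"] here.
    warnings.warn(
        f"fastmath flag for {module_name}.{func_name} set to {flag}; "
        "call cache._recompile() when you are done changing flags",
        stacklevel=2,
    )


# Capture the warning so the reminder can be inspected
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _set_sketch("fastmath", "_add_assoc", False)

print(caught[0].message)
```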

@seanlaw
Contributor Author

seanlaw commented Jan 29, 2025

@NimaSarajpoor I just added to PR #1070. Would you mind reviewing?

@NimaSarajpoor
Collaborator

@seanlaw
Sure. Will do!
