Make Cache a subclass of torch.Tensor #35792
base: main
Conversation
Force-pushed from 1114e7e to b67b6eb
In principle LGTM. I'm calling in the torch.export <> transformers expert to double-check that these changes are also okay for that goal 🤗
Question: the Cache object holds a list of tensors, usually a pair of tensors per layer. In some cases, the tensors of a cache can live on different devices. Would this conflict with the new inheritance?
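For context on the device question, here is a toy sketch (not the PR's actual implementation; all names are illustrative) of a torch.Tensor subclass that owns no storage itself and only keeps references to per-layer key/value tensors, each of which can live on a different device:

```python
from typing import List, Tuple

import torch


class ToyCache(torch.Tensor):
    """Toy container: a torch.Tensor subclass that owns no data itself and
    only keeps references to per-layer key/value tensors."""

    def __init__(self):
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(
        self, key: torch.Tensor, value: torch.Tensor, layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Append a new layer, or concatenate along the sequence dimension.
        if layer_idx == len(self.key_cache):
            self.key_cache.append(key)
            self.value_cache.append(value)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


cache = ToyCache()
cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)
if torch.cuda.is_available():
    # Per-layer tensors are plain attributes, so they can sit on a different
    # device than each other (and than the empty parent tensor).
    cache.update(
        torch.randn(1, 2, 4, 8, device="cuda"),
        torch.randn(1, 2, 4, 8, device="cuda"),
        layer_idx=1,
    )
assert isinstance(cache, torch.Tensor)
```

Since the per-layer tensors are regular Python attributes rather than the subclass's own storage, mixed-device placement should behave the same as before; whether the real implementation works this way is what the question above asks.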
Double-checks:
- Have you confirmed that the slow llama tests and the slow cache tests have no regressions with respect to main? (RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv and RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv)
- Have you confirmed that llama + static cache + compilation preserves throughput? (can share a script if needed :) ) A rough benchmark sketch follows this list.
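Not the script mentioned above, but a rough sketch of what such a throughput check could look like; the checkpoint name (meta-llama/Llama-2-7b-hf), token counts, and compile mode are placeholders:

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder: any llama checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cuda"
)

# Static KV cache + compiled forward.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tok("Hello, my name is", return_tensors="pt").to(model.device)

# Warm-up run triggers compilation; the second run is timed.
model.generate(**inputs, max_new_tokens=32, do_sample=False)
torch.cuda.synchronize()
start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")
```

Comparing the tokens/s number on this branch against main (same checkpoint, same settings) would answer the throughput question.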
{},
proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)
# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
This is for optimum and you're part of optimum, so I'm assuming it's okay :D
Yeah, I'm not sure why this was needed either; tagging @echarlaix @mht-sharma for more info
Not sure either
Adding @michaelbenayoun, who worked on this
@guangy10, as requested on Slack, have a look if you're available 🙏
For the correctness testing: no extensive testing, but we do have some correctness guarantee for supported models.

Also, I'm not exactly sure if the [...]

Alternatively, since the motivation is to handle legacy TorchScript tracing (I assume the traffic to this path will be lower and lower over time), would it be a cleaner separation if we created a dedicated Cache subclass for it while keeping the one for PyTorch 2.0+ as [...]?
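If that route were taken, the split could look something like the sketch below; the class names are hypothetical and this is not what the PR implements:

```python
from typing import List

import torch


class Cache:
    """Plain Python container kept for the PyTorch 2.0+ / torch.export path."""

    def __init__(self):
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []


class TorchScriptCache(torch.Tensor):
    """Tensor-subclass variant used only where legacy torch.jit.trace
    requires module inputs to be tensors."""

    def __init__(self):
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
```

That would keep tensor-subclass quirks confined to the legacy tracing path, at the cost of maintaining two cache variants.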
Shouldn't be an issue as we're not using the
One example is the QuantizedTensor subclass, which has two dtypes (a public one and an inner storage one).
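For illustration only (this is not the actual QuantizedTensor implementation): a wrapper subclass built with torch.Tensor._make_wrapper_subclass can advertise a public dtype that differs from the dtype of the data it actually stores.

```python
import torch


class FakeQuantTensor(torch.Tensor):
    """Toy wrapper subclass: public metadata says float32, payload is int8."""

    @staticmethod
    def __new__(cls, int_data: torch.Tensor, scale: float):
        # The wrapper carries the "public" dtype; the int8 payload lives
        # in a plain attribute set in __init__.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=torch.float32, device=int_data.device
        )

    def __init__(self, int_data: torch.Tensor, scale: float):
        self.int_data = int_data
        self.scale = scale

    def dequantize(self) -> torch.Tensor:
        return self.int_data.to(torch.float32) * self.scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Naive fallback: dequantize any FakeQuantTensor argument and run
        # the op in float32.
        def unwrap(t):
            return t.dequantize() if isinstance(t, FakeQuantTensor) else t

        args = tuple(unwrap(a) for a in args)
        kwargs = {k: unwrap(v) for k, v in (kwargs or {}).items()}
        return func(*args, **kwargs)


q = FakeQuantTensor(torch.randint(-128, 127, (4, 4), dtype=torch.int8), scale=0.1)
print(q.dtype)           # torch.float32 -> the public dtype
print(q.int_data.dtype)  # torch.int8    -> the dtype of the stored payload
```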
Running them right now (btw, is there a way to trigger them on the CI?). I was only running the llama fast tests and the llama + executorch integration tests.
Edit: confirmed these two tests fail on main as well.

Running [...]: in the first, [...]; in the second, the assertion is not verbose enough: [...]

Adding some verbosity: [...]
Force-pushed from 5829a6a to da60604
What does this PR do?
Both TorchScript tracing and torch dynamo/FX have restrictions on input types (TorchScript has more), which makes the export fail when one torch module (the model) passes another (the cache) around as its input. Making Cache a subclass of torch.Tensor bypasses these issues and, imo, makes more sense, since the Cache class has no forward and is just a container of torch tensors.
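As a minimal illustration of the TorchScript-side restriction (a toy module, not the transformers code): torch.jit.trace rejects example inputs that are arbitrary Python objects, whereas an input that is a torch.Tensor subclass satisfies the same tensor-only type check.

```python
import torch


class PlainCache:
    """Arbitrary Python container: not accepted as a traced input."""

    def __init__(self, key: torch.Tensor, value: torch.Tensor):
        self.key, self.value = key, value


class ToyModule(torch.nn.Module):
    def forward(self, hidden_states, cache):
        return hidden_states + cache.key.mean()


module = ToyModule()
x = torch.randn(2, 4)
k = v = torch.randn(2, 4)

try:
    # Fails: example inputs may only contain tensors (or nested
    # lists/tuples/dicts of tensors).
    torch.jit.trace(module, (x, PlainCache(k, v)))
except Exception as err:  # exact exception type may vary across torch versions
    print("torch.jit.trace rejected the plain container:", err)


class TensorCache(torch.Tensor):
    """A cache that subclasses torch.Tensor is itself a tensor,
    so it passes the tracer's input-type check."""

    def __init__(self):
        super().__init__()
        self.key_cache = []


assert isinstance(TensorCache(), torch.Tensor)
```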
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.