
Make Cache a subclass of torch.Tensor #35792

Open · wants to merge 17 commits into main

Conversation

IlyasMoutawwakil (Member)

What does this PR do?

Both TorchScript tracing and torch dynamo/FX place restrictions on input types (TorchScript more so), which makes export fail when one torch module (the model) passes another (the cache) around as its input. Making Cache a subclass of torch.Tensor bypasses these issues and, in my opinion, makes more sense, since the Cache class has no forward and is just a container of torch tensors.
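
To illustrate the input-type restriction (an editorial sketch, not code from this PR; the container and model below are hypothetical stand-ins): torch.jit.trace only accepts tensors, or nested tuples/lists/dicts of tensors, as example inputs, so handing it a plain Python container object is expected to fail.

    import torch

    class PlainContainer:
        # hypothetical stand-in for a cache-like object that is not a Tensor
        def __init__(self, tensors):
            self.tensors = tensors

    class ToyModel(torch.nn.Module):
        def forward(self, x, cache):
            return x + cache.tensors[0]

    x = torch.randn(2, 3)
    cache = PlainContainer([torch.randn(2, 3)])

    try:
        torch.jit.trace(ToyModel(), (x, cache))
    except Exception as e:
        # tracing rejects example inputs that are not tensors (or nested
        # containers of tensors), which is the failure this PR works around
        print(type(e).__name__, e)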

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante (Member) left a comment


In principle LGTM. I'm calling in our torch.export <> transformers expert to review and double-check that these changes are also okay for that goal 🤗

Question: the Cache object holds a list of tensors, usually a pair of tensors per layer. In some cases, different tensors of a cache can be on different devices. Would this conflict with the new inheritance?

Double-checks:

  1. Have you confirmed that slow llama tests and slow cache tests have no regressions with respect to main? (RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv and RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv)
  2. Have you confirmed that llama + static cache + compilation preserves throughput? (can share a script if needed :) )
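
For reference, here is a minimal sketch of what such a throughput check could look like (an editorial illustration, not the script mentioned above; the checkpoint and token counts are placeholders):

    import time
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

    # static cache + compiled forward, following the documented torch.compile recipe
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

    inputs = tokenizer("The theory of relativity states", return_tensors="pt").to(model.device)
    for _ in range(2):  # warmup, triggers compilation
        model.generate(**inputs, max_new_tokens=64, do_sample=False)

    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
    print(f"{new_tokens / elapsed:.1f} tokens/s")

Running it on this branch and on main with the same settings gives a rough before/after comparison.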

src/transformers/cache_utils.py
{},
proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)
# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
Member:
This is for optimum and you're part of optimum, so I'm assuming it's okay :D

Member Author (IlyasMoutawwakil):
Yeah, I'm not sure why this was needed either; tagging @echarlaix @mht-sharma for more info.

Contributor:
Not sure either

Collaborator:
Adding @michaelbenayoun, who worked on this.

@gante (Member) commented Jan 20, 2025

@guangy10 as requested on Slack, have a look if you're available 🙏

@guangy10 (Contributor)

For correctness testing: we don't have extensive tests, but we do have some correctness guarantees for supported models via test_export_static_cache (pointer). Can you run the slow tests on this PR?

Also, I'm not exactly sure the StaticCache will function as expected, because with nn.Module the cache is registered as a mutable buffer and lifted to a graph input during export, and I'm curious how that works with a tensor subclass. It seems tensor subclasses do not directly support buffer registration the way nn.Module does. Can we compare the exported graph of the nn.Module solution vs. the tensor subclass solution?

Alternatively, since the motivation is to handle legacy TorchScript tracing (I assume traffic on this path will keep decreasing over time), would it be a cleaner separation to create a dedicated Cache subclass for it, while keeping the one for PyTorch 2.0+ as an nn.Module? That way there is no need to maintain compatibility with the TorchScript solution.
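
To make the buffer-lifting point concrete, here is a minimal, hypothetical sketch (a toy stand-in, not the real StaticCache) of how an nn.Module-registered buffer is lifted to a graph input by torch.export and its in-place mutation recorded:

    import torch
    from torch.export import export

    class ToyCacheModule(torch.nn.Module):
        # hypothetical nn.Module-based cache: the storage is a registered buffer
        def __init__(self, seq_len=8, dim=4):
            super().__init__()
            self.register_buffer("key_cache", torch.zeros(seq_len, dim))

        def forward(self, new_keys):
            self.key_cache.copy_(new_keys)  # in-place mutation of the buffer
            return self.key_cache + 1

    ep = export(ToyCacheModule(), (torch.randn(8, 4),))
    # the buffer shows up among the lifted inputs, and its mutation is tracked
    print(ep.graph_signature)

Comparing this signature against what the tensor-subclass version produces would answer the question above.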

@IlyasMoutawwakil (Member Author) commented Jan 22, 2025

Question: the Cache object holds a list of tensors, usually a pair of tensors per layer. In some cases, different tensors of a cache can be on different devices. Would this conflict with the new inheritance?

Shouldn't be an issue, as we're not using _make_subclass() but rather _make_wrapper_subclass(); the difference is explained by @albanD:

These two functions do quite different things. The main difference is that when you do _make_subclass(), the current object is a honest to goodness Tensor with data in its storage and everything. When you do _make_wrapper_subclass(), the current object has no data and it is expected that some field on the Tensor will be another Tensor (hence the outer one being called wrapper) that contains real data.
(from https://dev-discuss.pytorch.org/t/whats-the-difference-between-torch-tensor-make-subclass-and-torch-tensor-make-wrapper-subclass/1839)

One example is the QuantizedTensor subclass, which has two dtypes (a public one, qt.dtype, and an internal one, qt._data.dtype).
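
A minimal sketch of that difference (an editorial illustration with hypothetical class names; the actual Cache implementation in this PR will differ):

    import torch

    class RealSubclass(torch.Tensor):
        pass

    class WrapperSubclass(torch.Tensor):
        @staticmethod
        def __new__(cls, inner):
            # the wrapper holds no storage of its own; it only mirrors the metadata
            return torch.Tensor._make_wrapper_subclass(
                cls, inner.shape, dtype=inner.dtype, device=inner.device
            )

        def __init__(self, inner):
            self.inner = inner  # the real data lives on an attribute

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # ops on the wrapper itself are out of scope for this illustration
            raise NotImplementedError(f"{func} is not supported in this sketch")

    data = torch.arange(6.0).reshape(2, 3)
    real = torch.Tensor._make_subclass(RealSubclass, data)  # a genuine tensor sharing data's storage
    wrapped = WrapperSubclass(data)                          # a storage-less shell around data

    print(torch.equal(real, data))             # True: `real` carries the data itself
    print(type(wrapped), wrapped.inner.shape)  # the wrapper's data is only on .inner

Because the wrapper carries no storage of its own, each inner tensor can keep its own device and dtype, which is why caches spread across devices are not a problem.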

Have you confirmed that slow llama tests and slow cache tests have no regressions with respect to main? (RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv and RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv)
Have you confirmed that llama + static cache + compilation preserves throughput? (can share a script if needed :) )

Running them right now (by the way, is there a way to trigger them on the CI?). So far I was only running the llama fast tests and the llama + ExecuTorch integration tests.

@IlyasMoutawwakil (Member Author) commented Jan 22, 2025

Edit: confirmed that these two tests fail on main as well.

Running RUN_SLOW=1 pytest tests/models/llama/test_modeling_llama.py -vv gives two errors, which I guess are related to the machine I'm testing on (an A100 vs. the A10 used in the CI):

FAILED tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_llama_3_1_hard - AssertionError: 'Tell[74 chars]ical social and political upheaval in France t[557 chars]s.\n' != 'Tell[74 chars]ical political...
FAILED tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits_bf16 - AssertionError: False is not true

In the first, "social and political" is reversed to "political and social":

E       AssertionError: 'Tell[74 chars]ical social and political upheaval in France t[557 chars]s.\n' != 'Tell[74 chars]ical political and social upheaval in France t[557 chars]s.\n'
E       Diff is 1259 characters long. Set self.maxDiff to None to see it.

In the second, the assertion is not verbose enough:

>       self.assertTrue(
            torch.allclose(
                EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device),
                out.logits.float().mean(-1),
                atol=1e-2,
                rtol=1e-2
            )
        )
E       AssertionError: False is not true

adding some verbosity:
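(presumably via a msg argument to assertTrue; the snippet below is a hypothetical reconstruction, not the exact change)

    expected = EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device)
    got = out.logits.float().mean(-1)
    self.assertTrue(
        torch.allclose(expected, got, atol=1e-2, rtol=1e-2),
        msg=f"Expected: {expected}\nGot: {got}",
    )

which yields: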

E       AssertionError: False is not true : Expected: tensor([[-6.5208, -4.1218, -4.9377, -3.2536,  0.8127, -2.9811,  1.2918, -3.3848]],
E              device='cuda:0')
E       Got: tensor([[-6.5081, -4.1175, -4.9761, -3.1678,  0.8199, -3.0029,  1.2809, -3.3309]],
E              device='cuda:0')
