This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add more compile compatibility for Float8Tensor ops #285

Closed
wants to merge 9 commits

Conversation

@ani300 (Contributor) commented Jun 14, 2024

No description provided.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Jun 14, 2024
@drisspg (Contributor) commented Jun 16, 2024

We have been pretty targeted with the ops we support for Float8Tensor, so I am curious if you have any concrete use cases for these ops. Could you also add some test cases? Otherwise, thanks for the contributions. I am surprised that the constructor isn't working correctly and would also love a test case there!

@drisspg self-requested a review Jun 16, 2024
@ani300 (Contributor, Author) commented Jun 18, 2024

Yes, most of these ops are due to using the Float8Tensor to handle an FP8 kv-cache. The example usage for all of these will be in https://github.com/pytorch-labs/fp468-llm today or tomorrow at the latest. I'll add the tests for both the ops and the constructor. The issue with the constructor was that it didn't matter what the original dtype was; it always returned fp32.
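For context, a minimal sketch of the intended conversion back to high precision, assuming a hypothetical from_float8 helper (the names here are illustrative, not the repo's exact API):

```python
import torch

def from_float8(data: torch.Tensor, scale: torch.Tensor,
                orig_dtype: torch.dtype) -> torch.Tensor:
    # Unscale in fp32 for accuracy, then cast back to the dtype the
    # tensor was constructed from. The reported bug is that this last
    # cast was missing, so the result was always fp32.
    return (data.to(torch.float32) / scale).to(orig_dtype)
```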

@ani300 (Contributor, Author) commented Jun 20, 2024

@drisspg As I'm writing the unit tests, I'm thinking about what a correct copy_ operation looks like. If we try to copy an FP32/FP16/BF16 tensor into an FP8 one, should we do some scaling if the Float8Tensor has it? And what does the opposite operation look like, say copying an FP8 tensor with scales into an FP32/FP16 one? Should we unscale through the FromFloat8Constructor?

@ani300 changed the title from "Add more compile compatibility for Float8Tensor ops and fix FromFloat8Construct to return original dtype" to "Add more compile compatibility for Float8Tensor ops" Jun 20, 2024
@drisspg (Contributor) commented Jun 20, 2024

@ani300 Great questions

For a copy from scaled_fp8 to hp_type, I think we should unscale and copy in.

For a copy from hp to an fp8 tensor, I think the semantics are a little hazier. Do you have a clear need for this operation? Otherwise I would potentially ban it for now. Some options:

  • Calculate new scales and copy in both the data and the new scales; probably my first choice
  • Create a temporary FP8 tensor using the scales that exist on the LHS, and then copy in the data
  • Just directly convert the existing tensor to a float8_e* type, and only copy in the data

I actually recently thought about a related problem when adding copy_ dispatch to NF4Tensor. That was to enable Subclass -> Subclass copy_. The most reasonable semantic I could come up with is to use the high-precision dtype as the intermediary for the conversion: pytorch/ao#45
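A rough sketch of that semantic; to_original_precision as the dequantize step and _data/_scale as the subclass's payload fields are assumed names, not necessarily the exact API:

```python
import torch

def subclass_to_subclass_copy_(dst, src):
    # Round-trip through the high-precision dtype so that dst keeps its
    # own scale/metadata instead of inheriting src's quantized payload.
    src_hp = src.to_original_precision()  # dequantize the source
    # Requantize with dst's scale (assumed convention: fp8_data = hp * scale).
    dst._data.copy_((src_hp * dst._scale).to(dst._data.dtype))
    return dst
```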

@vkuzo any strong thoughts on the semantics here?

@vkuzo (Contributor) commented Jun 20, 2024

If we try to copy an FP32/FP16/BF16 tensor into an FP8 one, should we do some scaling if the Float8Tensor has it? Or what does the opposite operation look like as well?

IMO:

  • If the user wants to use copy_ to copy bf16 to float8, the copy is done with a direct cast, without scaling, and the user gets back a torch.Tensor with dtype float8_... and not a Float8Tensor. If the user actually wants scaling, there should be some wrapper code which scales the data with whichever scaling strategy is relevant (per-tensor/row/group/block, dynamic/delayed/static, etc). I'm not a fan of defining copy_ with an assumed scaling strategy which returns a Float8Tensor, because of the ambiguity of the scaling details.
  • If the user wants to use copy_ to copy a Float8Tensor to a bf16 tensor, I think it's fine to unscale and copy, as there is no ambiguity. (Both directions are sketched below.)
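A hedged sketch of those two semantics; _data and _scale are assumed field names for the Float8Tensor payload:

```python
import torch

def copy_hp_to_fp8_(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # bf16 -> float8: direct cast, no scaling. dst stays a plain
    # torch.Tensor with a float8_* dtype; any scaling strategy belongs
    # in wrapper code, not in copy_ itself.
    return dst.copy_(src.to(dst.dtype))

def copy_fp8_to_hp_(dst: torch.Tensor, src) -> torch.Tensor:
    # Float8Tensor -> bf16: unambiguous, so unscale and copy
    # (assumed convention: fp8_data = hp * scale).
    return dst.copy_(src._data.to(dst.dtype) / src._scale)
```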

@ani300 (Contributor, Author) commented Jun 20, 2024

Thanks @drisspg and @vkuzo for your comments and opinions! I'll implement the FP8 -> BF16 copy (which is the one I'm using anyway), add the Float8Tensor -> Float8Tensor copy when everything is equal (scale, mm_config, etc.), and ban everything else. A sketch of that dispatch follows.
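A minimal sketch of that dispatch plan; the internals (_data, _scale, _mm_config) and the import path are assumptions, not the exact implementation:

```python
import torch
from float8_experimental.float8_tensor import Float8Tensor  # assumed path

def float8_copy_(dst, src):
    src_f8 = isinstance(src, Float8Tensor)
    dst_f8 = isinstance(dst, Float8Tensor)
    if src_f8 and not dst_f8:
        # FP8 -> FP32/FP16/BF16: unscale and copy into the plain tensor.
        return dst.copy_(src._data.to(dst.dtype) / src._scale)
    if src_f8 and dst_f8:
        # Only allowed when the two subclasses are interchangeable.
        assert torch.equal(dst._scale, src._scale)
        assert dst._mm_config == src._mm_config
        dst._data.copy_(src._data)
        return dst
    # hp -> Float8Tensor is banned per the discussion above.
    raise RuntimeError("copy_ from a plain tensor into a Float8Tensor is not supported")
```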

@ani300 (Contributor, Author) commented Jun 20, 2024

For the failing unit test, I'm waiting on this fix to land first, since PyTorch CI currently fails to run at all: pytorch/pytorch#128758

@vkuzo (Contributor) left a review comment:


awesome!

Review threads on float8_experimental/float8_ops.py (outdated, resolved)
@drisspg (Contributor) commented Jun 25, 2024

@ani300 The failing CI is because we are still using last night's nightly.

@facebook-github-bot (Contributor) commented:

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@drisspg merged this pull request in b5a444a.
