Add Deepseek-v3 #63

Open · wants to merge 5 commits into main
Conversation

miladm (Collaborator) commented Jan 30, 2025

Goal:

  • Add Deepseek-v3
  • Enable single-chip TPU functionality
  • Use TorchAx (a minimal single-chip sketch follows below)

Non-Goals / Next Steps:

  • Real input tensors
  • Real weights
  • FP8 quantization kernel enablement (fp8_gmm)
  • Distributed
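
As a concrete illustration of the single-chip goal, here is a minimal sketch of running a torch model on one TPU chip through torchax. It assumes torchax's enable_globally() and 'jax' device API, and a tiny stand-in module takes the place of the forked Deepseek-v3 model; it is not the PR's actual entry point.

```python
# Minimal single-chip sketch. Assumes a TPU VM with torchax installed;
# the Sequential below is a stand-in for the forked Deepseek-v3 model.
import torch
import torchax

torchax.enable_globally()  # torch ops on the 'jax' device now dispatch to JAX

model = torch.nn.Sequential(
    torch.nn.Embedding(32000, 64),
    torch.nn.Linear(64, 32000),
).eval()
model = model.to('jax')  # move weights onto the JAX/TPU backend

tokens = torch.randint(0, 32000, (1, 128)).to('jax')  # placeholder input ids
with torch.no_grad():
    logits = model(tokens)
print(logits.shape)  # torch.Size([1, 128, 32000])
```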

miladm self-assigned this Jan 30, 2025

google-cla bot commented Jan 30, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

miladm requested review from qihqi and yaochengji on January 30, 2025 09:58
miladm (Collaborator, Author) commented Jan 30, 2025

Should probably submit the original model before this PR to easily spot the diff.

qihqi (Collaborator) commented Jan 30, 2025

Great to get it running! A few high-level changes:

  1. Let's remove all the README.md files, along with their perf graphs, PDFs, etc. Instead, in the .py files that we forked, add the URL of the GitHub repo we forked them from.
  2. Make a new README.md with the command you run to launch it on TPU.
  3. Edit requirements.txt to list the TPU requirements.

qihqi (Collaborator) left a review comment:

Stamp to unblock; feel free to merge after the change.

tengyifei (Collaborator) commented Jan 30, 2025

Can we add a unit test to make sure it runs and produces correct results as compared to CPU eager? Example for Llama: https://github.com/AI-Hypercomputer/torchprime/blob/main/torchprime/experimental/torchax_models/test/test_llama.py#L40
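
For concreteness, a rough sketch of what such a parity test could look like, loosely modeled on the linked test_llama.py. The small Sequential is a placeholder for the Deepseek-v3 Model, and the torchax 'jax'/'cpu' round-trip is an assumption:

```python
# Rough sketch of the requested parity test: run the same input through the
# model on CPU eager and on the torchax 'jax' device, then compare outputs.
import unittest
import torch
import torchax

class DeepseekParityTest(unittest.TestCase):
    def test_matches_cpu_eager(self):
        torch.manual_seed(0)
        # Stand-in for a small Deepseek-v3 config; swap in the real Model.
        model = torch.nn.Sequential(
            torch.nn.Embedding(1000, 32),
            torch.nn.Linear(32, 1000),
        ).eval()
        tokens = torch.randint(0, 1000, (1, 16))

        with torch.no_grad():
            expected = model(tokens)  # CPU eager reference output

        torchax.enable_globally()
        with torch.no_grad():
            actual = model.to('jax')(tokens.to('jax'))

        # Moving the result back via .to('cpu') is assumed to be supported.
        torch.testing.assert_close(actual.to('cpu'), expected,
                                   atol=1e-3, rtol=1e-3)

if __name__ == '__main__':
    unittest.main()
```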

Btw, we also need to fix lint. You can do that by running `ruff check` and `ruff format`.

Thanks!

yaochengji (Collaborator) commented

@miladm it looks like only part of the model is converted to the jax device.

I reverted the code change in model.py and tried to call model.to("jax") outside instead. I got the error: torchax.tensor.OperatorNotFound: Operator with name aten::rms_norm has no lowering.

I can first rewrite the rms_norm op into fine-grained ops, as the llama model does. @qihqi, BTW, do we apply decompositions before converting torch ops to jax ops?
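
For reference, a sketch of that workaround: an RMSNorm expressed with fine-grained ops (pow, mean, rsqrt, mul) that already have lowerings, the way the Llama port does, instead of relying on aten::rms_norm:

```python
import torch

class RMSNorm(torch.nn.Module):
    """RMSNorm built from elementwise/reduction ops, avoiding aten::rms_norm."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x * rsqrt(mean(x^2) + eps), computed in float32 for stability
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x = (x.float() * torch.rsqrt(variance + self.eps)).type_as(x)
        return x * self.weight
```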
