
[WIP] Changes for training entropy model and correcting attention in local models #25

Merged
merged 1 commit into main on Jan 17, 2025

Conversation

@EntilZha (Contributor)

Summary:

  • Refactor local model configs to be separate and clearer
  • Add attention arguments and correct which attention is used in local models
  • Prepare for adding an entropy model training script
  • Fix failing unit tests

Test Plan:

@facebook-github-bot added the CLA Signed label Jan 16, 2025
@EntilZha marked this pull request as ready for review January 17, 2025 01:02
@EntilZha changed the title from "[WIP] Changes for training entropy model and correcting attention in local models" to "Changes for training entropy model and correcting attention in local models" Jan 17, 2025
@artidoro left a comment

A few suggestions for improvement, but the changes look functionally good!

@@ -176,6 +179,10 @@ class TrainArgs(BaseModel):
data: DataloaderArgs = DataloaderArgs()
optim: OptimArgs = OptimArgs()
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
# This is only needed for training the entropy model


In our old code we had an architecture parameter which selects either vanilla Transformer or BLT. That seems easier than having all these args. Could we do that instead?

@EntilZha (Contributor Author)

In general I agree, but I'm not quite sure yet how to do this with pydantic. I'll see what I can do in the next PR, which will have the entropy model training config/code. Ideally we would have a parameter to specify the architecture; the tricky bit is having pydantic instantiate the default values for the model based on that.
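
One pydantic option might be a discriminated union, where a single `arch` field picks the args class (and therefore its defaults). A rough sketch, with made-up class and field names rather than the actual bytelatent configs:

```python
from typing import Literal, Union

from pydantic import BaseModel, Field


# Illustrative stand-ins for the real arg classes; the field names are made up.
class VanillaTransformerArgs(BaseModel):
    arch: Literal["transformer"] = "transformer"
    dim: int = 512
    n_layers: int = 8


class BLTArgs(BaseModel):
    arch: Literal["blt"] = "blt"
    dim: int = 512
    n_layers: int = 8
    cross_attn_encoder: bool = False


class TrainArgs(BaseModel):
    # Pydantic picks the union member (and its defaults) based on `arch`.
    model: Union[VanillaTransformerArgs, BLTArgs] = Field(
        default_factory=BLTArgs, discriminator="arch"
    )


# Switching architectures is then a single config field:
args = TrainArgs.model_validate({"model": {"arch": "transformer", "dim": 256}})
assert isinstance(args.model, VanillaTransformerArgs)
```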

cross_attn_decoder=False,
cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
# Defaults

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid copying all the defaults below?

@EntilZha (Contributor Author)

I actually prefer the explicit copy rather than sharing a config in which not all the parameters are significant. There might be a way to copy all the values, though; I'll look into that at least.
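
If we do end up wanting to copy the shared values mechanically, something like this could work (a sketch, assuming both configs are pydantic models and that overlapping field names mean the same thing in both):

```python
# Copy every field of `args` that LocalModelArgs also declares, then override
# the local-model-specific values explicitly.
shared = {
    k: v for k, v in args.model_dump().items() if k in LocalModelArgs.model_fields
}
local_args = LocalModelArgs(**{**shared, "cross_attn_decoder": False})
```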


How about latent_transformer instead of global_transformer? We might want to apply that rename across all files.

@EntilZha (Contributor Author)

Makes sense, I can do that.

from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.tokenizers.blt_tokenizer import BOE_ID

logger = logging.getLogger()


class LocalModelArgs(BaseModel):
model_config = ConfigDict(extra="forbid")


It seems like this could be simplified by inheriting from the BaseTransformerArgs. There should be very few additional things that the local models need to know about.

@EntilZha (Contributor Author)

Agreed, assuming it's possible to override the default args in BaseTransformerArgs.
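
For reference, overriding a default in a pydantic subclass is just redeclaring the field; a minimal sketch with placeholder fields (not the real BaseTransformerArgs definition):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


# Placeholder base class; the real BaseTransformerArgs has different fields.
class BaseTransformerArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dim: int = 512
    n_heads: int = 8
    attn_impl: str = "sdpa"


class LocalModelArgs(BaseTransformerArgs):
    # Redeclaring a field overrides the inherited default...
    attn_impl: str = "xformers"
    # ...and local-model-only fields are simply added alongside.
    cross_attn_k: Optional[int] = None
    cross_attn_init_by_pooling: bool = False
```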

):
if attn_impl is None:


It might be clearer in a single line:
attn_impl = attn_impl or self.attn_impl

@EntilZha (Contributor Author)

I think I actually prefer the explicit style, but it's not a strong preference.

logging.warning(
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention."
)
WARNED_SDPA = True
return "causal"
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)


Add an assert that the bias type is causal here and for sdpa.

@EntilZha (Contributor Author)

I didn't add an assert here, since then our code wouldn't run at all without xformers, and some other issue comments mention needing to run on less capable GPUs. I think I could do an intermediate solution where I provide a way to suppress the error but crash training by default.
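
Something along these lines, where the `check_attn_bias` helper and the `BLT_SUPPRESS_ATTN_ERROR` environment variable are hypothetical names, not existing code:

```python
import logging
import os


def check_attn_bias(attn_impl: str, attn_bias_type: str) -> None:
    # Fail by default when the bias type isn't causal, but allow an explicit
    # opt-out so training can still run on GPUs without xformers/flex_attention.
    if attn_bias_type == "causal":
        return
    if os.environ.get("BLT_SUPPRESS_ATTN_ERROR", "0") == "1":
        logging.warning(
            "attn_bias_type=%s has no specialized implementation for "
            "attn_impl=%s; proceeding because BLT_SUPPRESS_ATTN_ERROR=1.",
            attn_bias_type,
            attn_impl,
        )
    else:
        raise ValueError(
            f"attn_impl={attn_impl} expects a causal bias type here; set "
            "BLT_SUPPRESS_ATTN_ERROR=1 to bypass this check."
        )
```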

@Vectorrent (Contributor)

This may or may not be the correct place to discuss... but I ran into a problem with the entropy model, in the patcher code.

This code expects the entropy model to exist as an on-disk checkpoint, but I was hoping to pass an already-instantiated entropy model to the patcher so it can be trained alongside the latent model. Is there any way we could rewrite the realtime patching to allow a user to pass an arbitrary nn.Module to the Patcher as an alternative to loading from a checkpoint?
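
Roughly the interface I have in mind, as a purely illustrative sketch (this is not the current Patcher API, and `load_entropy_model` is a placeholder for whatever checkpoint-loading helper the repo actually uses):

```python
from typing import Optional

import torch.nn as nn


def load_entropy_model(checkpoint_dir: str) -> nn.Module:
    # Placeholder for the repo's actual checkpoint-loading helper.
    raise NotImplementedError


class Patcher:
    def __init__(self, patcher_args, entropy_model: Optional[nn.Module] = None):
        if entropy_model is not None:
            # Use the caller's already-instantiated model, e.g. one being
            # trained alongside the latent model.
            self.entropy_model = entropy_model
        else:
            # Existing behavior: load the entropy model from an on-disk
            # checkpoint referenced by the patcher args.
            self.entropy_model = load_entropy_model(
                patcher_args.entropy_model_checkpoint_dir
            )
```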

@EntilZha (Contributor Author)

@Vectorrent I think it would be better to open a separate issue, since this will be closed once I merge my PR. If you open a new issue, I'll comment there.

@EntilZha changed the title from "Changes for training entropy model and correcting attention in local models" to "[WIP] Changes for training entropy model and correcting attention in local models" Jan 17, 2025
@EntilZha merged commit 6ffeb66 into main Jan 17, 2025
7 checks passed