[WIP] Changes for training entropy model and correcting attention in local models #25
Conversation
A few suggestions for improvements. But the changes seem functionally good!
@@ -176,6 +179,10 @@ class TrainArgs(BaseModel):
    data: DataloaderArgs = DataloaderArgs()
    optim: OptimArgs = OptimArgs()
    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
    # This is only needed for training the entropy model
In our old code we had an architecture parameter which selects either vanilla Transformer or BLT. That seems easier than having all these args. Could we do that instead?
In general, agreed; I'm not quite sure yet how to do this with pydantic. I'll see what I can do in the next PR, which will have the entropy model training config/code. Ideally, I agree we can have a parameter to specify the architecture; the tricky bit is having pydantic instantiate the default values for model based on that.
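For what it's worth, a minimal sketch of one way this could work (assuming pydantic v2; the `arch` field and the placeholder config classes below are illustrative, not the repo's actual definitions): a model validator picks the default `model` config based on the selected architecture.

```python
from typing import Literal, Optional, Union

from pydantic import BaseModel, model_validator


class TransformerArgs(BaseModel):  # stand-in for a vanilla transformer config
    dim: int = 512


class ByteLatentTransformerArgs(BaseModel):  # stand-in for the BLT config
    dim: int = 512
    cross_attn_k: Optional[int] = None


class TrainArgs(BaseModel):
    arch: Literal["transformer", "blt"] = "blt"
    # Left unset by the user; the validator below fills in the matching default.
    model: Optional[Union[TransformerArgs, ByteLatentTransformerArgs]] = None

    @model_validator(mode="after")
    def _default_model_for_arch(self):
        if self.model is None:
            self.model = (
                TransformerArgs()
                if self.arch == "transformer"
                else ByteLatentTransformerArgs()
            )
        return self
```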
cross_attn_decoder=False,
cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
# Defaults
Can we avoid copying all the defaults below?
I actually prefer the explicit copy rather than sharing a config in which not all the parameters are significant. There might be a way to copy all the values, though; I'll look into that at least.
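One possible way to copy the values without restating each default by hand, sketched under the assumption of pydantic v2 (`copy_shared_fields` is a hypothetical helper, not existing repo code): copy only the fields the two models have in common and let explicit overrides win.

```python
from pydantic import BaseModel


def copy_shared_fields(src: BaseModel, dst_cls: type[BaseModel], **overrides):
    # Keep only the fields that also exist on the destination model and
    # that aren't being overridden explicitly.
    shared = {
        name: value
        for name, value in src.model_dump().items()
        if name in dst_cls.model_fields and name not in overrides
    }
    return dst_cls(**shared, **overrides)


# Hypothetical usage with the args/LocalModelArgs from this PR:
# local_args = copy_shared_fields(args, LocalModelArgs, cross_attn_decoder=False)
```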
How about latent_transformer instead of global_transformer? We might want to apply that rename across all files.
Makes sense, I can do that.
from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.tokenizers.blt_tokenizer import BOE_ID

logger = logging.getLogger()


class LocalModelArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
It seems like this could be simplified by inheriting from BaseTransformerArgs. There should be very few additional things the local models need to know about.
Agreed, assuming it's possible to override the default args in BaseTransformerArgs.
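Overriding an inherited default in pydantic is just a matter of redeclaring the field, so that should work. A rough sketch (the BaseTransformerArgs fields below are placeholders, not the real ones):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


class BaseTransformerArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dim: int = 512
    n_layers: int = 8
    attn_impl: str = "sdpa"


class LocalModelArgs(BaseTransformerArgs):
    # Redeclaring a field overrides the inherited default.
    n_layers: int = 1
    # Only the local-model-specific knobs need to be declared here.
    cross_attn_k: Optional[int] = None
```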
):
    if attn_impl is None:
It might be clearer in a single line:
attn_impl = attn_impl or self.attn_impl
I think I actually prefer the explicit style, but it's not a super strong preference.
        logging.warning(
            "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention."
        )
        WARNED_SDPA = True
        return "causal"
    elif attn_impl == "flex_attention":
        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
Add an assert that the bias type is causal here and for SDPA.
I didn't add an assert here since then our code wouldn't run at all without xformers, and some other issue comments mention needing to run on less capable GPUs. I think I could do an intermediate solution where I provide a way to suppress the error but crash training by default.
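A rough sketch of that intermediate solution (hypothetical names; `allow_sdpa_fallback` and `resolve_sdpa_mask` are not existing flags or functions in the repo): raise by default when SDPA is asked for a block-causal mask, but let the user opt into the plain-causal fallback.

```python
import logging

logger = logging.getLogger()


def resolve_sdpa_mask(attn_bias_type: str, allow_sdpa_fallback: bool = False) -> str:
    if attn_bias_type in ("block_causal", "local_block_causal"):
        msg = (
            "SDPA has no specialized implementation for block_causal/"
            "local_block_causal attention; it would fall back to plain causal."
        )
        if not allow_sdpa_fallback:
            # Crash training by default so the fallback is never silent.
            raise ValueError(msg + " Set allow_sdpa_fallback=True to run anyway.")
        logger.warning(msg)
    return "causal"
```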
This may or may not be the correct place to discuss this, but I ran into a problem with the entropy model in the patcher code. This code expects an entropy model to exist as a checkpoint on disk, but I was hoping to pass an already-instantiated entropy model to the patcher, to train it alongside the latent model. Is there any way we could rewrite the realtime patching to allow a user to pass an arbitrary model?
@Vectorrent I think it would be better to open a separate issue, since this will be closed once I merge my PR. If you open a new issue, I'll comment there.
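For reference, a hypothetical sketch of what the requested change might look like (this is not the repo's actual Patcher API; the constructor arguments and loader callback are assumptions): accept either an already-instantiated module or a checkpoint directory.

```python
from typing import Callable, Optional

import torch.nn as nn


class Patcher:  # illustrative only, not the real bytelatent Patcher
    def __init__(
        self,
        entropy_model: Optional[nn.Module] = None,
        entropy_model_checkpoint_dir: Optional[str] = None,
        load_fn: Optional[Callable[[str], nn.Module]] = None,
    ):
        if entropy_model is not None:
            # Use the live module, e.g. one being trained alongside the latent model.
            self.entropy_model = entropy_model
        elif entropy_model_checkpoint_dir is not None and load_fn is not None:
            # Fall back to the existing on-disk checkpoint path.
            self.entropy_model = load_fn(entropy_model_checkpoint_dir)
        else:
            self.entropy_model = None  # static patching modes need no entropy model
```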
…local models

Summary:
- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Preparation for being able to have an entropy train script
- Fix failing unit tests

Test Plan: