diff --git a/dyana/loaders/megatron/main.py b/dyana/loaders/megatron/main.py
index 1abc1d8..616a3c4 100644
--- a/dyana/loaders/megatron/main.py
+++ b/dyana/loaders/megatron/main.py
@@ -1,3 +1,5 @@
+# ruff: noqa: I001, E402, F401, F821
+# type: ignore
 import os
 import sys
 import logging
@@ -11,12 +13,12 @@
 warnings.filterwarnings("ignore", category=UserWarning)
 
 # Import torch and configure CUDA
-import torch
+import torch  # noqa: E402
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 if torch.cuda.is_available():
-    torch.cuda.init()
+    torch.cuda.init()  # type: ignore[no-untyped-call]
     torch.cuda.set_device(0)
 
 if __name__ == "__main__":
@@ -31,7 +33,7 @@
             raise RuntimeError("CUDA is not available but required")
 
         # Force CUDA initialization
-        torch.cuda.init()
+        torch.cuda.init()  # type: ignore[no-untyped-call]
         torch.cuda.set_device(0)
         # Allocate a small tensor to ensure CUDA is working
         test_tensor = torch.zeros(1, device="cuda")
@@ -93,20 +95,20 @@
 
         try:
             te.initialize()
-            print(f"Initialized Transformer Engine version: {te.__version__}")
+            print(f"Initialized Transformer Engine version: {te.__version__}")  # noqa: F821
         except Exception as e:
             print(f"Warning: Transformer Engine initialization failed: {e}")
 
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         try:
-            print(f"Transformer Engine version: {transformer_engine.__version__}")
+            print(f"Transformer Engine version: {transformer_engine.__version__}")  # noqa: F821
             print(f"CUDA devices: {torch.cuda.device_count()}")
             print(f"CUDA version: {torch.version.cuda}")
             profiler.track(
                 "env_info",
                 {
-                    "te_version": transformer_engine.__version__,
+                    "te_version": transformer_engine.__version__,  # noqa: F821
                     "cuda_devices": torch.cuda.device_count(),
                     "cuda_version": torch.version.cuda,
                 },
@@ -147,7 +149,7 @@
         tokenizer = LlamaTokenizer.from_pretrained(str(tokenizer_path.parent), local_files_only=True)
         profiler.on_stage("tokenizer_loaded")
 
-        model = GPTModel(
+        model = GPTModel(  # noqa: F821
             config=config,
             vocab_size=tokenizer.vocab_size,
             max_sequence_length=4096,