CUDA/Pytorch multiprocessing workaround and test fixes #1304

Draft · wants to merge 9 commits into base: main
megatron/checkpointing.py: 3 additions & 1 deletion
@@ -96,7 +96,9 @@ def do_forward_pass(neox_args, model, inference=False):
tokens, attention_mask, position_ids = get_batch(
neox_args, context_tokens_tensor[:, : neox_args.seq_length]
)
logits = model((tokens, position_ids, attention_mask))
output = model((tokens, position_ids, attention_mask))
logits = output[0] if isinstance(output, tuple) else output


# reset to train mode, if model was in training before
if model_was_in_train:
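
The hunk above makes do_forward_pass tolerant of models that return a tuple rather than a bare logits tensor. A minimal sketch of the same unwrapping pattern, using a toy torch module as a hypothetical stand-in for the NeoX model:

import torch

def extract_logits(output):
    # Unwrap (logits, ...) tuples; pass plain tensors through unchanged.
    return output[0] if isinstance(output, tuple) else output

model = torch.nn.Linear(16, 16)   # stand-in for the real model
batch = torch.randn(4, 16)
logits = extract_logits(model(batch))            # plain tensor case
logits_t = extract_logits((model(batch), None))  # tuple case
assert torch.is_tensor(logits) and torch.is_tensor(logits_t)
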
megatron/logging.py: 54 additions & 0 deletions
@@ -13,6 +13,7 @@
# limitations under the License.

import sys
import os

import torch

@@ -26,6 +27,7 @@
import math


'''
class Tee:
"""Duplicate output to both stdout/err and file"""

@@ -61,6 +63,58 @@ def flush(self) -> None:
self.file.flush()
except OSError:
pass
'''

class Tee:
"""Duplicate output to both stdout/err and file"""

def __init__(self, file, err: bool = False) -> None:
self.err = err
self.std = sys.stderr if err else sys.stdout

if isinstance(file, str):
try:
# Ensure the directory exists if file is a path
os.makedirs(os.path.dirname(file), exist_ok=True)
self.file = open(file, "w")
except IOError as e:
print(f"Warning: Could not open file {file} for writing. {str(e)}", file=self.std)
self.file = None
elif hasattr(file, 'write') and hasattr(file, 'flush'):
# If it's a file-like object, use it directly
self.file = file
else:
raise ValueError("'file' must be either a file path or a file-like object")

if not err:
sys.stdout = self
else:
sys.stderr = self

def __del__(self) -> None:
if not self.err:
sys.stdout = self.std
else:
sys.stderr = self.std

if self.file and hasattr(self.file, 'close'):
self.file.close()

def write(self, data) -> None:
self.std.write(data)
if self.file:
try:
self.file.write(data)
except IOError as e:
print(f"Warning: Could not write to file. {str(e)}", file=self.std)

def flush(self) -> None:
self.std.flush()
if self.file:
try:
self.file.flush()
except IOError as e:
print(f"Warning: Could not flush file. {str(e)}", file=self.std)


def human_readable_flops(num) -> str:
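
A short usage sketch for the rewritten Tee helper; the log path below is hypothetical, and the import assumes the class lives in megatron/logging.py as this diff shows.

import os
import tempfile
from megatron.logging import Tee

# The constructor creates the parent directory if needed and redirects
# sys.stdout through the Tee, so prints land on the terminal and in the file.
log_path = os.path.join(tempfile.mkdtemp(), "logs", "stdout.log")
tee = Tee(log_path)
print("written to the terminal and to stdout.log")
tee.flush()
del tee  # __del__ restores sys.stdout and closes the log file
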
megatron/neox_arguments/arguments.py: 2 additions & 1 deletion
@@ -59,6 +59,7 @@
OKAY = f"{GREEN}[OKAY]{END}"
WARNING = f"{YELLOW}[WARNING]{END}"
FAIL = f"{RED}[FAIL]{END}"
ERROR = f"{RED}[ERROR]{END}"
INFO = "[INFO]"

# ZERO defaults by deepspeed
@@ -875,7 +876,6 @@ def calculate_derived(self):
"""
Derives additional configuration values necessary for training from the current config
"""

# number of gpus
# Get number of GPUs param or hostfile to determine train_batch_size
global_num_gpus = getattr(self, "global_num_gpus", None)
@@ -896,6 +896,7 @@
else:
global_num_gpus = torch.cuda.device_count()
self.update_value("global_num_gpus", global_num_gpus)


logging.info(
self.__class__.__name__
tests/common.py: 14 additions & 11 deletions
@@ -16,6 +16,8 @@
import time
import shutil
import itertools
import inspect
import subprocess
from pathlib import Path
from abc import ABC, abstractmethod
from deepspeed.accelerator import get_accelerator
@@ -48,6 +50,14 @@
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
DEEPSPEED_TEST_TIMEOUT = 600

def is_rocm_pytorch():
"""
Check if the current PyTorch installation is using ROCm.

Returns:
bool: True if PyTorch is using ROCm, False otherwise.
"""
return hasattr(torch.version, 'hip') and torch.version.hip is not None

def get_xdist_worker_id():
xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", None)
@@ -67,7 +77,6 @@ def get_master_port():

_num_gpus = None


def set_accelerator_visible():
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
xdist_worker_id = get_xdist_worker_id()
@@ -123,8 +132,6 @@ def set_accelerator_visible():
def count_gpus():
global _num_gpus
if _num_gpus is None:
import subprocess

nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"])
_num_gpus = len(nvidia_smi.decode("utf-8").strip().split("\n"))
return _num_gpus
@@ -137,8 +144,6 @@ def set_cuda_visibile():
xdist_worker_id = 0
if cuda_visible is None:
# CUDA_VISIBLE_DEVICES is not set, discover it from nvidia-smi instead
import subprocess

nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"])
num_gpus = len(nvidia_smi.decode("utf-8").strip().split("\n"))
cuda_visible = ",".join(map(str, range(num_gpus)))
@@ -428,9 +433,7 @@ def test_2(self, val1, val2, val3, val4):
assert int(os.environ["WORLD_SIZE"]) == 1
assert all(val1, val2, val3, val4)
"""

def __init__(self):
self.is_dist_test = True
is_dist_test = True

# Temporary directory that is shared among test methods in a class
@pytest.fixture(autouse=True, scope="class")
@@ -476,7 +479,7 @@ def get_test_path(filename):
def model_setup(yaml_list=None, param_dict=None, clear_data=True):
from megatron.neox_arguments import NeoXArgs
from megatron.mpu import destroy_model_parallel
from megatron import initialize_megatron
from megatron.initialize import initialize_megatron
from megatron.training import setup_model_and_optimizer

destroy_model_parallel() # mpu model parallel contains remaining global vars
@@ -509,10 +512,10 @@ def model_setup(yaml_list=None, param_dict=None, clear_data=True):
args_loaded.build_tokenizer()

initialize_megatron(neox_args=args_loaded)
model, optimizer, lr_scheduler = setup_model_and_optimizer(
model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
neox_args=args_loaded, use_cache=True
)
return model, optimizer, lr_scheduler, args_loaded
return model, optimizer, lr_scheduler, reference_model, args_loaded


def simulate_deepy_env(monkeypatch, input_args):
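
Two details in this file are worth calling out: pytest does not collect test classes that define __init__, which is why is_dist_test became a class attribute, and the new is_rocm_pytorch() helper lets tests branch on the GPU backend. A small illustrative sketch, with a hypothetical marker name that is not part of this diff:

import pytest
import torch

def is_rocm_pytorch():
    # Same check as the helper added above: ROCm builds expose torch.version.hip.
    return hasattr(torch.version, "hip") and torch.version.hip is not None

# Hypothetical skip marker built on the helper.
cuda_only = pytest.mark.skipif(
    is_rocm_pytorch(), reason="exercises CUDA-only fused kernels"
)

class TestCollectedByPytest:
    # Class attribute instead of __init__ so pytest can still collect the class.
    is_dist_test = True

    @cuda_only
    def test_backend_flag(self):
        assert isinstance(is_rocm_pytorch(), bool)
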
tests/model/test_fused_kernels.py: 3 additions & 0 deletions
@@ -30,6 +30,7 @@
)


@pytest.mark.forked
@pytest.mark.xfail(reason="SystemExit: None")
def test_load_fused_kernels():
load()
@@ -45,6 +46,7 @@ def test_load_fused_kernels():
raise e


@pytest.mark.forked
@pytest.mark.xfail(reason="SystemExit: None")
def test_fused_softmax():
load()
@@ -148,6 +150,7 @@ def test_fused_softmax():
)


@pytest.mark.forked
@pytest.mark.xfail(reason="SystemExit: None")
def test_fused_upper_triangle_mask_softmax():
load()
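
The added @pytest.mark.forked markers come from the pytest-forked plugin (assumed to be installed): each marked test runs in a forked child process, the usual workaround for CUDA state leaking between tests that share one interpreter. A minimal sketch:

import pytest
import torch

@pytest.mark.forked  # requires the pytest-forked plugin
def test_cuda_in_isolated_process():
    # CUDA is touched only inside the forked child, so a poisoned or
    # partially initialized context cannot affect later tests.
    if torch.cuda.is_available():
        x = torch.ones(4, device="cuda")
        assert x.sum().item() == 4.0
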
tests/model/test_model_checkpoint.py: 18 additions & 33 deletions
@@ -33,7 +33,8 @@
import torch

PARAMS_TO_TEST = {
"pipe_parallel_size,model_parallel_size": [[0, 1], [1, 2], [0, 2], [2, 1]],
"include":["localhost:0,1"],
"pipe_parallel_size,model_parallel_size": [[1, 2], [0, 2], [2, 1]],
"checkpoint_validation_with_forward_pass": [True],
"fp16,fp32_allreduce": [
[
@@ -61,31 +62,22 @@
}

parameters, names = parametrize(
PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=42
)

class TestModelCheckpoint(DistributedTest):
world_size = 2

@pytest.mark.skip
@pytest.mark.parametrize("param_dict", parameters, ids=names)
def test_train(param_dict):
import tempfile

d = tempfile.mkdtemp()
param_dict["save"] = d

t1 = test_run_checkpoint_test_class()
t1.run_checkpoint_test(param_dict=param_dict)


class test_run_checkpoint_test_class(DistributedTest):
def run_checkpoint_test(yaml_list=None, param_dict=None):

@pytest.mark.parametrize("param_dict", parameters, ids=names)
def test_checkpoint(self, param_dict, tmpdir):
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
print("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")

model, optimizer, lr_scheduler, args_loaded = model_setup(
yaml_list, param_dict, clear_data=True
model, optimizer, lr_scheduler, reference_model, args_loaded = model_setup(
yaml_list=None, param_dict=param_dict, clear_data=True
)
print("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC")

# save model checkpoint
save_checkpoint(
@@ -101,8 +93,9 @@ def run_checkpoint_test(yaml_list=None, param_dict=None):
reloaded_model,
reloaded_optimizer,
reloaded_lr_scheduler,
reloaded_reference_model,
args_reloaded,
) = model_setup(yaml_list, param_dict, clear_data=False)
) = model_setup(yaml_list=None, param_dict=param_dict, clear_data=False)
iteration = load_checkpoint(
neox_args=args_reloaded,
model=reloaded_model,
@@ -111,9 +104,7 @@
)

# ensure same checkpoint is loaded
assert (
iteration == 42
), "run_checkpoint_test() iteration loaded from checkpoint correct"
assert iteration == 42, "Iteration loaded from checkpoint is incorrect"

# check all weight groups are the same
for idx, ((n1, p1), (n2, p2)) in enumerate(
@@ -123,14 +114,8 @@
)
):
assert n1 == n2
params_equal = (p1 == p2).all().item()
assert params_equal, "run_checkpoint_test() params equal: " + str(n1)

params_equal = torch.all(p1 == p2).item()
assert params_equal, f"Parameters not equal: {n1}"

if __name__ == "__main__":
params = list(
parametrize(
PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
)
)
test_train(params[0])
# Clean up
del model, reloaded_model
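
The reworked test boils down to a save/reload round trip followed by a name-by-name parameter comparison. A framework-free sketch of that assertion pattern, using a toy module rather than the NeoX stack:

import os
import tempfile

import torch
import torch.nn as nn

def params_match(m1, m2):
    # Mirrors the comparison loop above: same names, identical tensors.
    for (n1, p1), (n2, p2) in zip(m1.named_parameters(), m2.named_parameters()):
        assert n1 == n2
        assert torch.all(p1 == p2).item(), f"Parameters not equal: {n1}"
    return True

model = nn.Linear(8, 8)
ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save(model.state_dict(), ckpt_path)

reloaded = nn.Linear(8, 8)
reloaded.load_state_dict(torch.load(ckpt_path))
assert params_match(model, reloaded)
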
tests/model/test_model_generation.py: 9 additions & 12 deletions
@@ -25,6 +25,7 @@
from tests.common import DistributedTest, model_setup, parametrize

PARAMS_TO_TEST = {
"include":["localhost:0,1"],
"pipe_parallel_size,model_parallel_size,world_size": [
[0, 1, 1],
[0, 1, 2],
@@ -63,18 +64,11 @@
PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
)


@pytest.mark.skip
@pytest.mark.parametrize("param_dict", parameters, ids=names)
def test_train(param_dict):
t1 = run_generate_test_class()
t1.run_generate_test(param_dict, param_dict.pop("prompt"))


class run_generate_test_class(DistributedTest):
class TestModelGeneration(DistributedTest):
world_size = 2

def run_generate_test(param_dict, prompt):
@pytest.mark.parametrize("param_dict", parameters, ids=names)
def test_generate(self, param_dict, tmpdir):
from megatron.text_generation_utils import generate_samples_from_prompt
from megatron.utils import is_mp_rank_0

@@ -89,10 +83,10 @@ def run_generate_test(param_dict, prompt):
}

param_dict.update(fixed_params)
# TODO: we don't need to reinstantiate the model every time if we're only changing sampling settings - should be a workaround for this
model, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
model, _, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
model.eval()

prompt = param_dict.pop("prompt")
prompts = [prompt for _ in range(args_loaded.num_samples)]
output = generate_samples_from_prompt(
neox_args=args_loaded,
Expand All @@ -111,3 +105,6 @@ def run_generate_test(param_dict, prompt):
for prompt, out in zip(prompts, output):
assert prompt == out["context"]
assert len(out["text"]) > 0

# Clean up
del model
tests/model/test_model_instantiation.py: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ class test_instantiate_optimizers_class(DistributedTest):
def run_test_model_instantiation(yaml_list=None, param_dict=None):
from deepspeed.runtime.pipe.engine import PipelineEngine, DeepSpeedEngine

model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict)
model, optimizer, lr_scheduler, reference_model, args_loaded = model_setup(yaml_list, param_dict)
if args_loaded.pipe_parallel_size < 2:
assert isinstance(
model, DeepSpeedEngine
tests/model/test_model_train.py: 0 additions & 4 deletions
@@ -48,10 +48,6 @@

keys_to_test = PARAMS_TO_TEST.keys()

# TODO: fix model training tests
@pytest.mark.skip(
reason="All model tests are skipped until we fix the CUDA + torch multiprocessing issue."
)
@pytest.mark.parametrize(
"key, value",
[(key, value) for key in keys_to_test for value in PARAMS_TO_TEST[key]],
tests/unit/test_format_conversion_scripts.py: 0 additions & 3 deletions
@@ -4,9 +4,6 @@
from megatron.neox_arguments.neox_args import NeoXArgsTokenizer


@pytest.mark.skip(
reason="Conversion test is skipped until we fix the CUDA + torch multiprocessing issue."
)
def test_gpt_neox_to_huggingface(monkeypatch, tmpdir, tmp_path):
# Generate random GPT-NEOX model, check we can convert to hf format
