feat: detailed logs of used cuda devices, closes #32
sehoffmann committed Jan 6, 2025
1 parent 97b3642 commit 2f470da
Showing 3 changed files with 63 additions and 13 deletions.
60 changes: 56 additions & 4 deletions dmlcloud/core/callbacks.py
@@ -1,11 +1,13 @@
import csv
import json
import os
import sys
from datetime import datetime, timedelta
from enum import IntEnum
from pathlib import Path
from typing import Callable, Optional, TYPE_CHECKING, Union

import pynvml
import torch
from omegaconf import OmegaConf
from progress_table import ProgressTable
@@ -16,6 +18,7 @@
from . import logging as dml_logging
from .distributed import all_gather_object, is_root


if TYPE_CHECKING:
from .pipeline import Pipeline
from .stage import Stage
Expand All @@ -33,6 +36,7 @@
'CsvCallback',
'WandbCallback',
'TensorboardCallback',
'CudaCallback',
]


@@ -91,8 +95,9 @@ class CbPriority(IntEnum):
CHECKPOINT = -190
STAGE_TIMER = -180
DIAGNOSTICS = -170
GIT = -160
METRIC_REDUCTION = -150
CUDA = -160
GIT = -150
METRIC_REDUCTION = -100

OBJECT_METHODS = 0

@@ -482,13 +487,60 @@ class GitDiffCallback(Callback):
def pre_run(self, pipe):
diff = git_diff()

if pipe.checkpointing_enabled:
if pipe.checkpointing_enabled and is_root():
self._save(pipe.checkpoint_dir.path / 'git_diff.txt', diff)

msg = '* GIT-DIFF:\n'
msg += '\n'.join('\t' + line for line in diff.splitlines())
msg += '\n'.join(' ' + line for line in diff.splitlines())
dml_logging.info(msg)

def _save(self, path, diff):
with open(path, 'w') as f:
f.write(diff)


class CudaCallback(Callback):
"""
Logs various properties pertaining to CUDA devices.
"""

def pre_run(self, pipe):
handle = torch.cuda._get_pynvml_handler(pipe.device)

info = {
'name': pynvml.nvmlDeviceGetName(handle),
'uuid': pynvml.nvmlDeviceGetUUID(handle),
'serial': pynvml.nvmlDeviceGetSerial(handle),
'torch_device': str(pipe.device),
'minor_number': pynvml.nvmlDeviceGetMinorNumber(handle),
'architecture': pynvml.nvmlDeviceGetArchitecture(handle),
'brand': pynvml.nvmlDeviceGetBrand(handle),
'vbios_version': pynvml.nvmlDeviceGetVbiosVersion(handle),
'driver_version': pynvml.nvmlSystemGetDriverVersion(),
'cuda_driver_version': pynvml.nvmlSystemGetCudaDriverVersion_v2(),
'nvml_version': pynvml.nvmlSystemGetNVMLVersion(),
'total_memory': pynvml.nvmlDeviceGetMemoryInfo(handle, pynvml.nvmlMemory_v2).total,
'reserved_memory': pynvml.nvmlDeviceGetMemoryInfo(handle, pynvml.nvmlMemory_v2).reserved,
'num_gpu_cores': pynvml.nvmlDeviceGetNumGpuCores(handle),
            'power_management_limit': pynvml.nvmlDeviceGetPowerManagementLimit(handle),
            'power_management_default_limit': pynvml.nvmlDeviceGetPowerManagementDefaultLimit(handle),
'cuda_compute_capability': pynvml.nvmlDeviceGetCudaComputeCapability(handle),
}
all_devices = all_gather_object(info)

msg = '* CUDA-DEVICES:\n'
info_strings = [
f'{info["torch_device"]} -> /dev/nvidia{info["minor_number"]} -> {info["name"]} (UUID: {info["uuid"]}) (VRAM: {info["total_memory"] / 1000 ** 2:.0f} MB)'
for info in all_devices
]
msg += '\n'.join(f' - [{i}] {info_str}' for i, info_str in enumerate(info_strings))
dml_logging.info(msg)

if pipe.checkpointing_enabled and is_root():
self._save(pipe.checkpoint_dir.path / 'cuda_devices.json', all_devices)

def _save(self, path, all_devices):
with open(path, 'w') as f:
devices = {f'rank_{i}': device for i, device in enumerate(all_devices)}
obj = {'devices': devices}
json.dump(obj, f, indent=4)
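
For orientation, here is a minimal standalone sketch of the kind of pynvml queries the new CudaCallback issues, run outside of the pipeline and without the distributed gather. The device index 0, the reduced field selection, and the print format are illustrative assumptions and not part of this commit; the snippet requires the pynvml package and an NVIDIA driver to be present.

    import pynvml

    pynvml.nvmlInit()
    try:
        # Query the first visible GPU; CudaCallback instead resolves the handle from pipe.device.
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        info = {
            'name': pynvml.nvmlDeviceGetName(handle),
            'uuid': pynvml.nvmlDeviceGetUUID(handle),
            'driver_version': pynvml.nvmlSystemGetDriverVersion(),
            'total_memory': pynvml.nvmlDeviceGetMemoryInfo(handle).total,
        }
        # Same formatting idea as the callback's per-device log line.
        print(f'{info["name"]} (UUID: {info["uuid"]}) (VRAM: {info["total_memory"] / 1000 ** 2:.0f} MB)')
    finally:
        pynvml.nvmlShutdown()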
4 changes: 4 additions & 0 deletions dmlcloud/core/pipeline.py
@@ -14,6 +14,7 @@
CbPriority,
CheckpointCallback,
CsvCallback,
CudaCallback,
DiagnosticsCallback,
GitDiffCallback,
TensorboardCallback,
@@ -178,6 +179,9 @@ def enable_checkpointing(
self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True), CbPriority.CSV)
self.add_callback(TensorboardCallback(self.checkpoint_dir.path), CbPriority.TENSORBOARD)

if self.device.type == 'cuda':
self.add_callback(CudaCallback(), CbPriority.CUDA)

def enable_wandb(
self,
project: str | None = None,
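
With this change, enable_checkpointing registers the CudaCallback automatically whenever the pipeline runs on a CUDA device. The same callback can also be attached by hand using only names that appear in this diff; the snippet below is a hedged usage sketch that assumes an already constructed Pipeline instance named pipe and that both classes are importable from dmlcloud.core.callbacks, as the relative imports above suggest.

    from dmlcloud.core.callbacks import CbPriority, CudaCallback

    # Hypothetical: pipe is an existing Pipeline whose device may or may not be a GPU.
    if pipe.device.type == 'cuda':
        pipe.add_callback(CudaCallback(), CbPriority.CUDA)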
12 changes: 3 additions & 9 deletions dmlcloud/util/logging.py
@@ -1,6 +1,5 @@
import io
import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path
@@ -114,21 +113,16 @@ def general_diagnostics() -> str:
msg += f' - backend: {dist.get_backend()}\n'
msg += f' - cuda: {torch.cuda.is_available()}\n'

if torch.cuda.is_available():
msg += '* GPUs (root):\n'
nvsmi = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode()
for line in nvsmi.splitlines():
msg += f' - {line}\n'

msg += '* VERSIONS:\n'
msg += f' - python: {sys.version}\n'
msg += f' - dmlcloud: {dmlcloud.__version__}\n'
msg += f' - cuda: {torch.version.cuda}\n'
msg += f' - cuda (torch): {torch.version.cuda}\n'
try:
msg += ' - ' + Path('/proc/driver/nvidia/version').read_text().splitlines()[0] + '\n'
except (FileNotFoundError, IndexError):
pass

msg += f' - dmlcloud: {dmlcloud.__version__}\n'

for module_name in ML_MODULES:
if is_imported(module_name):
msg += f' - {module_name}: {try_get_version(module_name)}\n'
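
The removed block shelled out to nvidia-smi -L to list GPUs on the root process; that information is now collected per rank by the CudaCallback above. For comparison, here is a hedged sketch of an equivalent listing done directly through pynvml, which the new callback already depends on; it is an illustration only, not code from this commit.

    import pynvml

    pynvml.nvmlInit()
    try:
        # Mirror the output of `nvidia-smi -L`: one line per visible GPU.
        for i in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            name = pynvml.nvmlDeviceGetName(handle)
            uuid = pynvml.nvmlDeviceGetUUID(handle)
            print(f'GPU {i}: {name} (UUID: {uuid})')
    finally:
        pynvml.nvmlShutdown()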
