Refactor HF checkpointer (#1690)
milocress authored Jan 30, 2025
1 parent 63a733d commit cc0df9f
Showing 4 changed files with 119 additions and 203 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yaml
@@ -16,7 +16,7 @@ jobs:
        uses: actions/checkout@v3
        with:
          repository: mosaicml/ci-testing
          ref: v0.2.2
          ref: v0.3.3
          path: ./ci-testing
      - uses: ./ci-testing/.github/actions/coverage
        with:
2 changes: 1 addition & 1 deletion .github/workflows/pr-cpu.yaml
@@ -28,7 +28,7 @@ jobs:
pytest_command: "coverage run -m pytest"
steps:
- name: Run PR CPU Tests
uses: mosaicml/ci-testing/.github/actions/pytest-cpu@v0.2.2
uses: mosaicml/ci-testing/.github/actions/pytest-cpu@v0.3.2
with:
name: ${{ matrix.name }}
container: ${{ matrix.container }}
316 changes: 117 additions & 199 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -131,7 +131,7 @@ def _maybe_get_license_filename(
def _log_model_with_multi_process(
    mlflow_logger: MLFlowLogger,
    python_logging_level: int,
    transformers_model: str,
    transformers_model: Union[dict[str, Any], str],
    artifact_path: str,
    pretrained_model_name: str,
    registered_model_name: Optional[str],
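
The widened `transformers_model` parameter matches what `mlflow.transformers.log_model` accepts: either a directory that `save_pretrained` already wrote, or a dict of in-memory components. A minimal sketch of the two call shapes, assuming a small Hugging Face model and an active MLflow run (exact keyword support varies by mlflow version):

import mlflow
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

with mlflow.start_run():
    # Shape 1: pass in-memory components as a dict (used for the PEFT path).
    mlflow.transformers.log_model(
        transformers_model={'model': model, 'tokenizer': tokenizer},
        artifact_path='final_model_checkpoint',
        task='text-generation',
    )

    # Shape 2: pass a directory that save_pretrained produced.
    model.save_pretrained('/tmp/hf_ckpt')
    tokenizer.save_pretrained('/tmp/hf_ckpt')
    mlflow.transformers.log_model(
        transformers_model='/tmp/hf_ckpt',
        artifact_path='final_model_checkpoint_from_dir',
        task='text-generation',
    )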
@@ -213,35 +213,6 @@ def save_model_patch(*args: Any, **kwargs: Any):
    )


def _register_model_with_run_id_multiprocess(
    mlflow_logger: MLFlowLogger,
    composer_logging_level: int,
    model_uri: str,
    name: str,
    await_creation_for: int,
):
    """Call MLFlowLogger.register_model_with_run_id.

    Used mainly to register from a child process.
    """
    # Setup logging for child process. This ensures that any logs from composer are surfaced.
    if composer_logging_level > 0:
        # If logging_level is 0, then the composer logger was unset.
        logging.basicConfig(
            format=
            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
            force=True,
        )
        logging.getLogger('composer').setLevel(composer_logging_level)

    # Register model.
    mlflow_logger.register_model_with_run_id(
        model_uri=model_uri,
        name=name,
        await_creation_for=await_creation_for,
    )


class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.
@@ -582,23 +553,7 @@ def transform_model_pre_registration(
        """
        return model

    def _save_checkpoint(
        self,
        state: State,
        logger: Logger,
        upload_to_save_folder: bool,
        register_to_mlflow: bool,
    ):
        """Save a HuggingFace formatted checkpoint.

        Args:
            state (State): The training state.
            logger (Logger): The logger.
            upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder.
            register_to_mlflow (bool): Whether to register the model to MLFlow
        """
        del logger  # unused

    def _get_hf_model(self, state: State):
        self.last_checkpoint_batch = state.timestamp.batch

        log.info('Saving HuggingFace formatted checkpoint')
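
`transform_model_pre_registration` (whose tail appears at the top of this hunk) is a deliberate no-op hook: subclasses override it to adjust the model immediately before registration, and the refactor keeps calling it from the new `_register_hf_model`. A hypothetical override, for illustration only:

from transformers import PreTrainedModel

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


class MyCheckpointer(HuggingFaceCheckpointer):
    """Illustrative subclass; not part of this commit."""

    def transform_model_pre_registration(
        self,
        model: PreTrainedModel,
    ) -> PreTrainedModel:
        # Example edit: make sure generation uses the KV cache once deployed.
        model.config.use_cache = True
        return model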
@@ -608,19 +563,6 @@ def _save_checkpoint(
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

        save_dir = format_name_with_dist_and_time(
            str(
                Path(self.save_dir_format_str) /
                self.huggingface_folder_name_fstr,
            ),
            state.run_name,
            state.timestamp,
        )

        # Use a temporary directory if save_dir is remote.
        use_temp_dir = self.remote_ud is not None
        temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir

        log.debug('Gathering state dict')

        if state.is_model_ddp:
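
The `state.is_model_ddp` check that opens the gathering logic exists because DistributedDataParallel wraps the user's model and prefixes every state-dict key with `module.`; the checkpointer must read weights from the inner module so the keys line up with a plain Hugging Face model. The general unwrapping pattern, sketched outside Composer:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def unwrap_for_saving(model: nn.Module) -> nn.Module:
    # DDP keeps the original model at `.module`; saving the wrapper directly
    # would produce state-dict keys like 'module.transformer.wte.weight'.
    if isinstance(model, DistributedDataParallel):
        return model.module
    return model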
@@ -735,6 +677,107 @@ def tensor_hook(

        log.debug('Saving Hugging Face checkpoint to disk')

        return new_model_instance, original_tokenizer

    def _register_hf_model(
        self,
        temp_save_dir: str,
        original_tokenizer: PreTrainedTokenizerBase,
        use_temp_dir: bool,
        new_model_instance: PreTrainedModel,
    ):
        assert new_model_instance is not None
        new_model_instance = self.transform_model_pre_registration(
            new_model_instance,
        )
        register_save_dir = os.path.join(
            temp_save_dir,
            'register_save',
        )
        new_model_instance.save_pretrained(
            register_save_dir,
            max_shard_size='1GB',
        )
        if original_tokenizer:
            original_tokenizer.save_pretrained(register_save_dir)

        self.pre_register_edit(register_save_dir)

        for mlflow_logger in self.mlflow_loggers:
            if self.mlflow_registered_model_name:
                log.debug(
                    f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
                )

            # Save the monitor process to be restored after registering the model.
            with _monitor_process_saver(mlflow_logger):
                process = SpawnProcess(
                    target=_log_model_with_multi_process,
                    kwargs={
                        'mlflow_logger':
                            mlflow_logger,
                        'python_logging_level':
                            logging.getLogger('llmfoundry').level,
                        'transformers_model': {
                            'model': new_model_instance,
                            'tokenizer': original_tokenizer,
                        } if self.using_peft else register_save_dir,
                        'artifact_path':
                            'final_model_checkpoint',
                        'pretrained_model_name':
                            self.pretrained_model_name,
                        'registered_model_name':
                            self.mlflow_registered_model_name,
                        'await_registration_for':
                            3600,
                        'mlflow_logging_config':
                            self.mlflow_logging_config,
                    },
                )

                process.start()
                self.register_processes.append(process)

        # Save the temporary directory to be cleaned up later.
        if use_temp_dir:
            self.temp_save_dir = temp_save_dir
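
Because registration can block for a long time (`await_registration_for=3600`), `_register_hf_model` runs `_log_model_with_multi_process` in a spawned process and only records the handle in `self.register_processes`, so training resumes immediately and the processes can be joined later. A self-contained sketch of that fire-and-track pattern (the slow function is a stand-in):

import multiprocessing as mp
import time


def slow_registration(name: str):
    time.sleep(5)  # stand-in for a slow MLflow registration call
    print(f'registered {name}')


if __name__ == '__main__':
    SpawnProcess = mp.get_context('spawn').Process
    register_processes = []

    process = SpawnProcess(target=slow_registration, kwargs={'name': 'my-model'})
    process.start()  # returns immediately; training would continue here
    register_processes.append(process)

    # Much later (e.g. at fit end), wait for every pending registration.
    for p in register_processes:
        p.join()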

    def _save_checkpoint(
        self,
        state: State,
        logger: Logger,
        upload_to_save_folder: bool,
        register_to_mlflow: bool,
    ):
        """Save a HuggingFace formatted checkpoint.

        Args:
            state (State): The training state.
            logger (Logger): The logger.
            upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder.
            register_to_mlflow (bool): Whether to register the model to MLFlow
        """
        del logger  # unused

        save_dir = format_name_with_dist_and_time(
            str(
                Path(self.save_dir_format_str) /
                self.huggingface_folder_name_fstr,
            ),
            state.run_name,
            state.timestamp,
        )

        # Use a temporary directory if save_dir is remote.
        use_temp_dir = self.remote_ud is not None
        temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir

        new_model_instance, original_tokenizer = self._get_hf_model(state)

        dist.barrier()

        if dist.get_global_rank() == 0:
            assert new_model_instance is not None
            if upload_to_save_folder:
                # This context manager casts the TE extra state in io.BytesIO format to tensor format
                # Needed for proper hf ckpt saving.
@@ -783,150 +826,25 @@ def tensor_hook(
        dist.barrier()

        if dist.get_global_rank() == 0:
            assert new_model_instance is not None
            if self.using_peft:
                model_name = self.mlflow_logging_config.get('metadata', {}).get(
                    'pretrained_model_name',
                    None,
                )
                if model_name is not None:
                    new_model_instance.name_or_path = model_name
                    new_model_instance.model.name_or_path = model_name
                    new_model_instance.base_model.name_or_path = model_name
            if register_to_mlflow:
                assert new_model_instance is not None
                new_model_instance = self.transform_model_pre_registration(
                self._register_hf_model(
                    temp_save_dir,
                    original_tokenizer,
                    use_temp_dir,
                    new_model_instance,
                )
                if self.using_peft:

                    # Save and register peft model to mlflow, this code path uses our older two step logic
                    self._save_and_register_peft_model(
                        state,
                        new_model_instance,
                        original_tokenizer,
                        temp_save_dir,
                    )
                else:
                    register_save_dir = os.path.join(
                        temp_save_dir,
                        'register_save',
                    )
                    new_model_instance.save_pretrained(
                        register_save_dir,
                        max_shard_size='1GB',
                    )
                    if original_tokenizer:
                        original_tokenizer.save_pretrained(register_save_dir)

                    self.pre_register_edit(register_save_dir)

                    for mlflow_logger in self.mlflow_loggers:
                        if self.mlflow_registered_model_name:
                            log.debug(
                                f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
                            )

                        # Save the monitor process to be restored after registering the model.
                        with _monitor_process_saver(mlflow_logger):
                            process = SpawnProcess(
                                target=_log_model_with_multi_process,
                                kwargs={
                                    'mlflow_logger':
                                        mlflow_logger,
                                    'python_logging_level':
                                        logging.getLogger('llmfoundry').level,
                                    'transformers_model':
                                        register_save_dir,
                                    'artifact_path':
                                        'final_model_checkpoint',
                                    'pretrained_model_name':
                                        self.pretrained_model_name,
                                    'registered_model_name':
                                        self.mlflow_registered_model_name,
                                    'await_registration_for':
                                        3600,
                                    'mlflow_logging_config':
                                        self.mlflow_logging_config,
                                },
                            )

                            process.start()
                            self.register_processes.append(process)

                    # Save the temporary directory to be cleaned up later.
                    if use_temp_dir:
                        self.temp_save_dir = temp_save_dir
            else:
                # Clean up the temporary directory if we don't need to register to mlflow.
                if use_temp_dir:
                    shutil.rmtree(temp_save_dir)
        dist.barrier()
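
The refactored `_save_checkpoint` now reads as a short pipeline: build the HF model on all ranks (`_get_hf_model`), synchronize, then let global rank 0 alone upload and register, with a final `dist.barrier()` so no rank races ahead while rank 0 is still writing. The core idiom in plain `torch.distributed`, as a sketch:

import torch.distributed as dist


def save_on_rank_zero(save_fn):
    # Every rank arrives before rank 0 starts writing...
    dist.barrier()
    if dist.get_rank() == 0:
        save_fn()
    # ...and no rank proceeds until rank 0 has finished.
    dist.barrier()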

    def _save_and_register_peft_model(
        self,
        state: State,
        new_model_instance: Any,
        original_tokenizer: Optional[Any],
        save_dir: str,
    ):
        components = {'model': new_model_instance}
        if original_tokenizer is not None:
            components['tokenizer'] = original_tokenizer

        log.debug('Logging Hugging Face model to MLFlow')
        for i, mlflow_logger in enumerate(self.mlflow_loggers):
            log.debug(
                f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
            )

            local_save_path = str(Path(save_dir) / f'mlflow_save_{i}',)

            # TODO: Remove after mlflow fixes the bug that makes this necessary
            import mlflow
            import mlflow.store
            mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''

            model_saving_kwargs: dict[str, Any] = {
                'path': local_save_path,
            }
            model_saving_kwargs['flavor'] = 'peft'
            model_saving_kwargs['save_pretrained_dir'] = save_dir
            model_saving_kwargs['metadata'] = self.mlflow_logging_config[
                'metadata']

            context_manager = te.onnx_export(
                True,
            ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
            )
            with context_manager:
                # Add the pip requirements directly to avoid mlflow
                # attempting to run inference on the model
                model_saving_kwargs['pip_requirements'] = [
                    'transformers',
                    'torch',
                ]
                mlflow_logger.save_model(**model_saving_kwargs)

            # Upload the license file generated by mlflow during the model saving.
            # Get and log the license file.
            license_filename = _maybe_get_license_filename(
                local_save_path,
                self.pretrained_model_name,
            )
            if license_filename is not None:
                mlflow_logger._mlflow_client.log_artifact(
                    mlflow_logger._run_id,
                    os.path.join(local_save_path, license_filename),
                )

            self.pre_register_edit(local_save_path)

            with _monitor_process_saver(mlflow_logger):
                process = SpawnProcess(
                    target=_register_model_with_run_id_multiprocess,
                    kwargs={
                        'mlflow_logger':
                            mlflow_logger,
                        'composer_logging_level':
                            logging.getLogger('composer').level,
                        'model_uri':
                            local_save_path,
                        'name':
                            self.mlflow_registered_model_name,
                        'await_creation_for':
                            3600,
                    },
                )
                process.start()
                self.register_processes.append(process)
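
The deleted method above was the older two-step flow: `save_model` wrote a local MLflow model, then a child process registered the saved URI via `register_model_with_run_id`. The refactor collapses both steps into the single `_log_model_with_multi_process` call, where passing `registered_model_name` makes `log_model` register the model as part of logging it. Roughly, and subject to mlflow version differences:

import mlflow


def register_one_step(model, tokenizer, registered_name: str):
    # One call logs the model AND registers it; no separate register step.
    with mlflow.start_run():
        mlflow.transformers.log_model(
            transformers_model={'model': model, 'tokenizer': tokenizer},
            artifact_path='final_model_checkpoint',
            task='text-generation',
            registered_model_name=registered_name,
            await_registration_for=3600,
        )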
2 changes: 0 additions & 2 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1167,7 +1167,6 @@ def transform_model_pre_registration(self, model: PreTrainedModel):
    assert model_cfg is not None
    assert tokenizer_name is not None

    checkpointer._save_and_register_peft_model = MagicMock()
    checkpointer.using_peft = True
    checkpointer._save_checkpoint(
        state=state,
@@ -1176,7 +1175,6 @@ def transform_model_pre_registration(self, model: PreTrainedModel):
        register_to_mlflow=True,
    )

    checkpointer._save_and_register_peft_model.assert_not_called()
    assert mlflow_logger_mock.log_model.call_count == 1
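
With the PEFT-specific branch gone, the test no longer mocks `_save_and_register_peft_model`; it only checks that the unified path calls `log_model` exactly once. The call-count assertion pattern in isolation, as a runnable sketch:

from unittest.mock import MagicMock


def register_with_loggers(loggers):
    # Stand-in for the logger loop in HuggingFaceCheckpointer._register_hf_model.
    for mlflow_logger in loggers:
        mlflow_logger.log_model(artifact_path='final_model_checkpoint')


mlflow_logger_mock = MagicMock()
register_with_loggers([mlflow_logger_mock])

assert mlflow_logger_mock.log_model.call_count == 1
mlflow_logger_mock.log_model.assert_called_once_with(
    artifact_path='final_model_checkpoint',
)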


