Skip to content

Commit

Permalink
Merge branch 'main' into docker_builds
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka authored Mar 20, 2024
2 parents e44fa64 + a06a1a6 commit 877f0ea
Show file tree
Hide file tree
Showing 25 changed files with 1,083 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fsdp_config:
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
_deps = [
"setuptools<=59.5.0",
"pyyaml>=5.0.0",
"numpy>=1.0.0",
"numpy>=1.17.0",
"matplotlib>=3.0.0",
"merge-args>=0.1.0",
"onnx>=1.5.0,<1.15.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def _clear_missing_keys(module, incompatible_keys):
self.register_load_state_dict_post_hook(_clear_missing_keys)

def forward(self, *args, **kwargs):
self.teacher_model.eval()
if not self.kd_enabled:
return self.student_model(*args, **kwargs)

Expand Down Expand Up @@ -118,6 +117,13 @@ def named_modules(
memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
)

def named_children(self):
return self.student_model.named_children()

def train(self, mode: bool = True):
self.student_model.train(mode)
return self

def __getattr__(self, name: str) -> Any:
try:
return super().__getattr__(name)
Expand Down
8 changes: 6 additions & 2 deletions src/sparseml/pytorch/model_load/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,21 @@ def reload_model_from_checkpoint(model: Module, checkpoint: Optional[str] = None


def save_model_and_recipe(
model: Module, save_path: str, tokenizer: Optional[Any] = None
model: Module,
save_path: str,
tokenizer: Optional[Any] = None,
save_safetensors: bool = False,
):
"""
Save a model, tokenizer and the currently loaded recipe to file
:param model: pytorch model to save
:param save_path: path to save output to
:param tokenizer: model tokenizer to save
:param save_safetensors: whether to save as safetensors or pickle (bin)
"""

model.save_pretrained(save_path)
model.save_pretrained(save_path, safe_serialization=save_safetensors)

if tokenizer is not None:
tokenizer.save_pretrained(save_path)
Expand Down
Loading

0 comments on commit 877f0ea

Please sign in to comment.