update load weights
Cui-yshoho committed Dec 19, 2024
1 parent 5f5da70 commit 922ab29
Showing 4 changed files with 19 additions and 20 deletions.
4 changes: 2 additions & 2 deletions mindone/diffusers/models/model_loading_utils.py
@@ -111,9 +111,9 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
     local_state = {k: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            v.set_dtype(local_state[k].dtype)
+            state_dict[k] = ms.Parameter(v.to(local_state[k].dtype), name=k)
         else:
-            pass  # unexpect key keeps origin dtype
+            state_dict[k] = ms.Parameter(v, name=k)  # unexpect key keeps origin dtype
     ms.load_param_into_net(model_to_load, state_dict, strict_load=True)
     return error_msgs

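In the new code path every entry of the incoming state dict is wrapped as an `ms.Parameter`, and entries that match a model parameter are first cast to that parameter's dtype; the old in-place `set_dtype` call is gone. A minimal, self-contained sketch of the same pattern (the toy model and fp16 checkpoint are illustrative, not taken from the commit):

```python
import numpy as np
import mindspore as ms
from mindspore import nn

# Toy fp32 model; the checkpoint below is fp16 on purpose.
net = nn.Dense(4, 2)  # parameter names: "weight" (2, 4) and "bias" (2,)
checkpoint = {
    "weight": ms.Tensor(np.ones((2, 4), np.float16)),
    "bias": ms.Tensor(np.zeros((2,), np.float16)),
}

local_state = {k: v for k, v in net.parameters_and_names()}
for k, v in checkpoint.items():
    if k in local_state:
        # known key: cast to the dtype the model parameter expects
        checkpoint[k] = ms.Parameter(v.to(local_state[k].dtype), name=k)
    else:
        # unexpected key: keep its original dtype
        checkpoint[k] = ms.Parameter(v, name=k)
ms.load_param_into_net(net, checkpoint, strict_load=True)
```

Because the cast happens per key, a checkpoint saved in one dtype loads cleanly into a model built in another, without touching the tensors returned by `load_file`.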
21 changes: 10 additions & 11 deletions mindone/diffusers/models/modeling_utils.py
@@ -27,6 +27,7 @@
 
 import mindspore as ms
 from mindspore import nn, ops
+from mindspore.nn.utils import no_init_parameters
 
 from mindone.safetensors.mindspore import save_file as safe_save_file

@@ -61,9 +62,7 @@ def _get_pt2ms_mappings(m):
     mappings = {}  # pt_param_name: (ms_param_name, pt_param_to_ms_param_func)
     for name, cell in m.cells_and_names():
         if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)):
-            mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter(
-                ops.expand_dims(x, axis=-2), name=x.name
-            )
+            mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ops.expand_dims(x, axis=-2)
         elif isinstance(cell, nn.Embedding):
             mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x
         elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
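The Conv1d/Conv1dTranspose mapping now returns the reshaped tensor directly instead of an `ms.Parameter`; Parameter wrapping and dtype handling happen later in `_load_state_dict_into_model`. The extra axis is still needed because MindSpore's `nn.Conv1d` keeps a 4-D weight internally, so a 3-D PyTorch-style weight gains a singleton dimension. A small illustration with arbitrary shapes (not from the repo):

```python
import numpy as np
import mindspore as ms
from mindspore import ops

pt_weight = ms.Tensor(np.random.randn(8, 4, 3).astype(np.float32))  # (C_out, C_in, K), PyTorch Conv1d layout
ms_weight = ops.expand_dims(pt_weight, axis=-2)                      # (C_out, C_in, 1, K), MindSpore Conv1d layout
print(ms_weight.shape)  # (8, 4, 1, 3)
```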
@@ -608,8 +607,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 user_agent=user_agent,
                 commit_hash=commit_hash,
             )
+        with no_init_parameters():
+            model = cls.from_config(config, **unused_kwargs)
 
-        model = cls.from_config(config, **unused_kwargs)
+        if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
+            raise ValueError(
+                f"{mindspore_dtype} needs to be of type `ms.Type`, e.g. `ms.float16`, but is {type(mindspore_dtype)}."
+            )
+        elif mindspore_dtype is not None:
+            model = model.to(mindspore_dtype)
 
         if is_sharded:
             load_checkpoint_and_dispatch(
@@ -637,13 +643,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 "error_msgs": error_msgs,
             }
 
-        if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
-            raise ValueError(
-                f"{mindspore_dtype} needs to be of type `ms.Type`, e.g. `ms.float16`, but is {type(mindspore_dtype)}."
-            )
-        elif mindspore_dtype is not None:
-            model = model.to(mindspore_dtype)
-
         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
 
         # Set model in evaluation mode to deactivate DropOut modules by default
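Net effect for diffusers models: `cls.from_config` now runs under `no_init_parameters()`, so the freshly built network skips its random weight initialization, and the `mindspore_dtype` check plus `model.to(mindspore_dtype)` are applied to that empty skeleton before the checkpoint is loaded (previously the cast came after loading). A hedged usage sketch; the model class and repo id are examples, not taken from this commit:

```python
import mindspore as ms
from mindone.diffusers import AutoencoderKL

# dtype is validated and applied to the uninitialized model first,
# then the checkpoint tensors are cast key-by-key while loading.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",  # example repo id, illustrative only
    mindspore_dtype=ms.float16,
)
```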
2 changes: 1 addition & 1 deletion mindone/safetensors/mindspore.py
@@ -125,7 +125,7 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, ms.Tensor]:
 
 def _np2ms(np_dict: Dict[str, np.ndarray]) -> Dict[str, ms.Tensor]:
     for k, v in np_dict.items():
-        np_dict[k] = ms.Parameter(v, name=k)
+        np_dict[k] = ms.tensor(v)
     return np_dict


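With `_np2ms` returning plain tensors, `load_file` now yields a `Dict[str, ms.Tensor]` rather than named `ms.Parameter` objects; naming and Parameter wrapping are left to the model-loading code above. A short usage sketch (the file path is illustrative):

```python
from mindone.safetensors.mindspore import load_file

state_dict = load_file("model.safetensors")  # values are ms.Tensor, not ms.Parameter
for name, tensor in state_dict.items():
    print(name, tensor.dtype, tensor.shape)
```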
12 changes: 6 additions & 6 deletions mindone/transformers/modeling_utils.py
@@ -52,6 +52,7 @@
 
 import mindspore as ms
 from mindspore import Tensor, nn, ops
+from mindspore.nn.utils import no_init_parameters
 
 from .integrations import PeftAdapterMixin
 from .modeling_attn_mask_utils import dtype_to_min
@@ -71,9 +72,7 @@ def _get_pt2ms_mappings(m):
     mappings = {}  # pt_param_name: (ms_param_name, pt_param_to_ms_param_func)
     for name, cell in m.cells_and_names():
         if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)):
-            mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter(
-                ops.expand_dims(x, axis=-2), name=x.name
-            )
+            mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ops.expand_dims(x, axis=-2)
         elif isinstance(cell, nn.Embedding):
             mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x
         elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
@@ -294,9 +293,9 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
     local_state = {start_prefix + k: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            v.set_dtype(local_state[k].dtype)
+            state_dict[k] = ms.Parameter(v.to(local_state[k].dtype), name=k)
         else:
-            pass  # unexpect key keeps origin dtype
+            state_dict[k] = ms.Parameter(v, name=k)  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
     with cm:
         ms.load_param_into_net(model_to_load, state_dict, strict_load=True)
@@ -1730,7 +1729,8 @@ def from_pretrained(
 
         config.name_or_path = pretrained_model_name_or_path
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
-        model = cls(config, *model_args, **model_kwargs)
+        with no_init_parameters():
+            model = cls(config, *model_args, **model_kwargs)
         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
         if mindspore_dtype is not None:
             model = model.to(mindspore_dtype)
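The transformers path mirrors the diffusers change: `cls(config, ...)` is constructed under `no_init_parameters()`, so parameters are not randomly initialized only to be overwritten by checkpoint weights, and the cast to `mindspore_dtype` still follows immediately after construction. A minimal sketch of the deferred-init pattern on a toy cell; the cell and the manual dtype loop are stand-ins for the real model and `model.to(mindspore_dtype)`:

```python
import mindspore as ms
from mindspore import nn
from mindspore.nn.utils import no_init_parameters

class TinyNet(nn.Cell):  # stand-in for cls(config, *model_args, **model_kwargs)
    def __init__(self):
        super().__init__()
        self.proj = nn.Dense(16, 16)

    def construct(self, x):
        return self.proj(x)

with no_init_parameters():        # parameter data is created lazily, no random init cost
    model = TinyNet()
for p in model.get_parameters():  # stand-in for mindone's model.to(mindspore_dtype)
    p.set_dtype(ms.float16)
# checkpoint weights are then loaded into the uninitialized fp16 parameters
```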
