Problem

For a model containing a QuantumModule, torch.save(model.state_dict(), "model.pt") followed by model.load_state_dict(torch.load("model.pt")) may not work, because some state keys are only created lazily during the forward pass.

Example

I used Model1 from the Quantum Convolution (Quanvolution) example. The recipe in Save and Load QNN models does not work directly with it. Setting the training part aside, I trained the model and saved the state_dict with:
torch.save(model.state_dict(), "model.pt")
When I tried to load this into a newly created model, the fresh model did not yet have some of the keys present in the saved state_dict, because those keys are only created lazily during the forward pass. For example, loading into a new model:
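model2 = Model().to(device)
model2.load_state_dict(torch.load("model.pt"))

The error is: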
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 2
      1 model2 = Model().to(device)
----> 2 model2.load_state_dict(torch.load("model.pt"))

File d:\Programming\Anaconda3\envs\torchquantum\lib\site-packages\torch\nn\modules\module.py:2153, in Module.load_state_dict(self, state_dict, strict, assign)
   2148     error_msgs.insert(
   2149         0, 'Missing key(s) in state_dict: {}. '.format(
   2150             ', '.join(f'"{k}"' for k in missing_keys)))
   2152 if len(error_msgs) > 0:
-> 2153     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2154         self.__class__.__name__, "\n\t".join(error_msgs)))
   2155 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Model:
	Unexpected key(s) in state_dict: "qf.q_layer.q_device.state", "qf.q_layer.q_device.states".
Potential Solution
The cause is that the model creates new attributes, and therefore new state keys, during the forward pass.
For example, in the example above, TrainableQuanvFilter creates self.q_layer inside __init__ with self.q_layer = U3CU3Layer0(self.arch) (torchquantum/layer/layers/u3_layer.py).
U3CU3Layer0 inherits from LayerTemplate0, including its forward() method. Inside forward(), a new attribute is attached to the object, and it is only attached once the model is actually forwarded (torchquantum/layer/layers/layers.py).
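The failure mode can be reproduced without TorchQuantum. Below is a minimal, self-contained sketch of the same lazy-attribute pattern (toy code, not TorchQuantum's actual implementation):

import torch
import torch.nn as nn

class LazyBufferModule(nn.Module):
    """Toy module that, like the TorchQuantum layers above, only
    registers a buffer the first time forward() runs."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        if not hasattr(self, "state"):
            # Lazily created state key, analogous to
            # "qf.q_layer.q_device.state" in the traceback above.
            self.register_buffer("state", torch.zeros(4))
        return self.linear(x) + self.state

m1 = LazyBufferModule()
m1(torch.randn(1, 4))                     # forward() creates "state"
torch.save(m1.state_dict(), "toy.pt")

m2 = LazyBufferModule()                   # fresh model: no "state" yet
m2.load_state_dict(torch.load("toy.pt"))  # RuntimeError: Unexpected key(s)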
One solution is to create all such attributes during __init__, but I am not familiar with torchquantum's design principles. Since forward() takes q_device as an input and assigns it to the attribute, the layer may be intentionally designed to be reused across many devices. This change could therefore require a large interface change, and might mean creating one layer object per device instead of reusing a single layer object. On the toy module above, eager creation would look like the sketch below.
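A sketch of the eager-creation option on the toy module (it assumes the buffer's shape is known at construction time, which may not hold for TorchQuantum's multi-device design):

class EagerBufferModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Register the buffer up front so that save and load see the
        # same set of state keys, with no forward pass required.
        self.register_buffer("state", torch.zeros(4))

    def forward(self, x):
        return self.linear(x) + self.state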
Another way is to forward the model once before loading the state_dict, so that the lazily created keys already exist. This may be time-consuming when the model is large.
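A minimal sketch of this workaround for the Quanvolution model (dummy_input is a hypothetical placeholder for any input tensor of the shape the model expects; Model and device are as defined in the example):

model2 = Model().to(device)
with torch.no_grad():
    model2(dummy_input.to(device))  # materializes the lazily created keys
model2.load_state_dict(torch.load("model.pt"))

Since the offending keys hold quantum-state buffers rather than trainable parameters, calling load_state_dict(..., strict=False) should also succeed, at the cost of silently ignoring any genuinely mismatched keys.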
Related Issues
The following issues appear to be related.
#210 is not resolved yet. #49 provided the save-and-load example, but as explained above, it does not work in general: that example saves and then immediately loads within the same session, so the model already has all the lazily created keys and the key sets match.