-
-
Notifications
You must be signed in to change notification settings - Fork 618
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Pytorch Load / Save Plugin This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load` with untrusted data can lead to arbitrary code execution, and improper use of `torch.save` might expose sensitive data or lead to data corruption. Signed-off-by: Luke Hinds <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add missing save check Signed-off-by: Luke Hinds <[email protected]> * Review fixes from 8b92a02 Signed-off-by: Luke Hinds <[email protected]> * Fix tox issues Signed-off-by: Luke Hinds <[email protected]> * Review fixes Signed-off-by: Luke Hinds <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_functional.py * Update bandit/plugins/pytorch_load_save.py Co-authored-by: Eric Brown <[email protected]> * Update bandit/plugins/pytorch_load_save.py Co-authored-by: Eric Brown <[email protected]> * Update doc/source/plugins/b704_pytorch_load_save.rst Co-authored-by: Eric Brown <[email protected]> * Update bandit/plugins/pytorch_load_save.py Co-authored-by: Eric Brown <[email protected]> --------- Signed-off-by: Luke Hinds <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Brown <[email protected]>
- Loading branch information
1 parent
4ac55df
commit 36fd650
Showing
5 changed files
with
109 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright (c) 2024 Stacklok, Inc. | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
r""" | ||
========================================== | ||
B614: Test for unsafe PyTorch load or save | ||
========================================== | ||
This plugin checks for the use of `torch.load` and `torch.save`. Using | ||
`torch.load` with untrusted data can lead to arbitrary code execution, and | ||
improper use of `torch.save` might expose sensitive data or lead to data | ||
corruption. A safe alternative is to use `torch.load` with the `safetensors` | ||
library from hugingface, which provides a safe deserialization mechanism. | ||
:Example: | ||
.. code-block:: none | ||
>> Issue: Use of unsafe PyTorch load or save | ||
Severity: Medium Confidence: High | ||
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html) | ||
Location: examples/pytorch_load_save.py:8 | ||
7 loaded_model.load_state_dict(torch.load('model_weights.pth')) | ||
8 another_model.load_state_dict(torch.load('model_weights.pth', | ||
map_location='cpu')) | ||
9 | ||
10 print("Model loaded successfully!") | ||
.. seealso:: | ||
- https://cwe.mitre.org/data/definitions/94.html | ||
- https://pytorch.org/docs/stable/generated/torch.load.html#torch.load | ||
- https://github.com/huggingface/safetensors | ||
.. versionadded:: 1.7.10 | ||
""" | ||
import bandit | ||
from bandit.core import issue | ||
from bandit.core import test_properties as test | ||
|
||
|
||
@test.checks("Call") | ||
@test.test_id("B614") | ||
def pytorch_load_save(context): | ||
""" | ||
This plugin checks for the use of `torch.load` and `torch.save`. Using | ||
`torch.load` with untrusted data can lead to arbitrary code execution, | ||
and improper use of `torch.save` might expose sensitive data or lead | ||
to data corruption. | ||
""" | ||
imported = context.is_module_imported_exact("torch") | ||
qualname = context.call_function_name_qual | ||
if not imported and isinstance(qualname, str): | ||
return | ||
|
||
qualname_list = qualname.split(".") | ||
func = qualname_list[-1] | ||
if all( | ||
[ | ||
"torch" in qualname_list, | ||
func in ["load", "save"], | ||
not context.check_call_arg_value("map_location", "cpu"), | ||
] | ||
): | ||
return bandit.Issue( | ||
severity=bandit.MEDIUM, | ||
confidence=bandit.HIGH, | ||
text="Use of unsafe PyTorch load or save", | ||
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA, | ||
lineno=context.get_lineno_for_call_arg("load"), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
----------------------- | ||
B614: pytorch_load_save | ||
----------------------- | ||
|
||
.. automodule:: bandit.plugins.pytorch_load_save |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
import torchvision.models as models | ||
|
||
# Example of saving a model | ||
model = models.resnet18(pretrained=True) | ||
torch.save(model.state_dict(), 'model_weights.pth') | ||
|
||
# Example of loading the model weights in an insecure way | ||
loaded_model = models.resnet18() | ||
loaded_model.load_state_dict(torch.load('model_weights.pth')) | ||
|
||
# Save the model | ||
torch.save(loaded_model.state_dict(), 'model_weights.pth') | ||
|
||
# Another example using torch.load with more parameters | ||
another_model = models.resnet18() | ||
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) | ||
|
||
# Save the model | ||
torch.save(another_model.state_dict(), 'model_weights.pth') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters