diff --git a/bandit/plugins/pytorch_load_save.py b/bandit/plugins/pytorch_load_save.py new file mode 100644 index 000000000..77522da22 --- /dev/null +++ b/bandit/plugins/pytorch_load_save.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024 Stacklok, Inc. +# +# SPDX-License-Identifier: Apache-2.0 +r""" +========================================== +B614: Test for unsafe PyTorch load or save +========================================== + +This plugin checks for the use of `torch.load` and `torch.save`. Using +`torch.load` with untrusted data can lead to arbitrary code execution, and +improper use of `torch.save` might expose sensitive data or lead to data +corruption. A safe alternative is to use `torch.load` with the `safetensors` +library from hugingface, which provides a safe deserialization mechanism. + +:Example: + +.. code-block:: none + + >> Issue: Use of unsafe PyTorch load or save + Severity: Medium Confidence: High + CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html) + Location: examples/pytorch_load_save.py:8 + 7 loaded_model.load_state_dict(torch.load('model_weights.pth')) + 8 another_model.load_state_dict(torch.load('model_weights.pth', + map_location='cpu')) + 9 + 10 print("Model loaded successfully!") + +.. seealso:: + + - https://cwe.mitre.org/data/definitions/94.html + - https://pytorch.org/docs/stable/generated/torch.load.html#torch.load + - https://github.com/huggingface/safetensors + +.. versionadded:: 1.7.10 + +""" +import bandit +from bandit.core import issue +from bandit.core import test_properties as test + + +@test.checks("Call") +@test.test_id("B614") +def pytorch_load_save(context): + """ + This plugin checks for the use of `torch.load` and `torch.save`. Using + `torch.load` with untrusted data can lead to arbitrary code execution, + and improper use of `torch.save` might expose sensitive data or lead + to data corruption. + """ + imported = context.is_module_imported_exact("torch") + qualname = context.call_function_name_qual + if not imported and isinstance(qualname, str): + return + + qualname_list = qualname.split(".") + func = qualname_list[-1] + if all( + [ + "torch" in qualname_list, + func in ["load", "save"], + not context.check_call_arg_value("map_location", "cpu"), + ] + ): + return bandit.Issue( + severity=bandit.MEDIUM, + confidence=bandit.HIGH, + text="Use of unsafe PyTorch load or save", + cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA, + lineno=context.get_lineno_for_call_arg("load"), + ) diff --git a/doc/source/plugins/b704_pytorch_load_save.rst b/doc/source/plugins/b704_pytorch_load_save.rst new file mode 100644 index 000000000..dcc1ae3a0 --- /dev/null +++ b/doc/source/plugins/b704_pytorch_load_save.rst @@ -0,0 +1,5 @@ +----------------------- +B614: pytorch_load_save +----------------------- + +.. automodule:: bandit.plugins.pytorch_load_save diff --git a/examples/pytorch_load_save.py b/examples/pytorch_load_save.py new file mode 100644 index 000000000..e1f912022 --- /dev/null +++ b/examples/pytorch_load_save.py @@ -0,0 +1,21 @@ +import torch +import torchvision.models as models + +# Example of saving a model +model = models.resnet18(pretrained=True) +torch.save(model.state_dict(), 'model_weights.pth') + +# Example of loading the model weights in an insecure way +loaded_model = models.resnet18() +loaded_model.load_state_dict(torch.load('model_weights.pth')) + +# Save the model +torch.save(loaded_model.state_dict(), 'model_weights.pth') + +# Another example using torch.load with more parameters +another_model = models.resnet18() +another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) + +# Save the model +torch.save(another_model.state_dict(), 'model_weights.pth') + diff --git a/setup.cfg b/setup.cfg index 52128b17d..23c20cc56 100644 --- a/setup.cfg +++ b/setup.cfg @@ -152,6 +152,9 @@ bandit.plugins = #bandit/plugins/tarfile_unsafe_members.py tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members + #bandit/plugins/pytorch_load_save.py + pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save + # bandit/plugins/trojansource.py trojansource = bandit.plugins.trojansource:trojansource diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 681e45edf..d8241142b 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -930,6 +930,14 @@ def test_tarfile_unsafe_members(self): } self.check_example("tarfile_extractall.py", expect) + def test_pytorch_load_save(self): + """Test insecure usage of torch.load and torch.save.""" + expect = { + "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0}, + "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4}, + } + self.check_example("pytorch_load_save.py", expect) + def test_trojansource(self): expect = { "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 1},