Use filter arg in tarfile.extractall to prevent unsafe unarchival operations #2722

Merged 4 commits on May 10, 2024
4 changes: 3 additions & 1 deletion .github/workflows/release_notes_updated.yaml
@@ -10,6 +10,8 @@ jobs:
       - name: Check for development branch
         id: branch
         shell: python
+        env:
+          REF: ${{ github.event.pull_request.head.ref }}
         run: |
           from re import compile
           main = '^main$'
@@ -19,7 +21,7 @@ jobs:
           min_dep_update = '^min-dep-update-[a-f0-9]{7}$'
           regex = main, release, backport, dep_update, min_dep_update
           patterns = list(map(compile, regex))
-          ref = "${{ github.event.pull_request.head.ref }}"
+          ref = "$REF"
           is_dev = not any(pattern.match(ref) for pattern in patterns)
           print('::set-output name=is_dev::' + str(is_dev))
       - if: ${{ steps.branch.outputs.is_dev == 'true' }}
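
One thing worth double-checking in the workflow change above: the step runs with `shell: python`, so as far as I can tell nothing expands `"$REF"` before Python sees it, and `ref` would hold the literal text `$REF`. A minimal sketch of reading the value from the environment instead follows. The function name `is_dev_branch` is hypothetical, and only the two patterns visible in the diff are included; the `release`, `backport`, and `dep_update` patterns are not shown in the hunks and are left out here.

```python
import os
from re import compile

# Only the patterns visible in the diff; the others are elided there.
main = '^main$'
min_dep_update = '^min-dep-update-[a-f0-9]{7}$'
patterns = list(map(compile, (main, min_dep_update)))


def is_dev_branch(ref):
    """Return True when ref matches none of the known non-dev patterns."""
    return not any(pattern.match(ref) for pattern in patterns)


# In the workflow step, REF is supplied through the `env:` block.
# Falling back to "" here is an assumption made so the sketch runs anywhere.
ref = os.environ.get("REF", "")
is_dev = is_dev_branch(ref)
```

Passing the branch name through `env:` still keeps it out of the script text, which is the point of the original change; the sketch only changes how the script picks it up.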
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
@@ -13,6 +13,7 @@ Future Release
         * Temporarily restrict Dask version (:pr:`2694`)
         * Remove support for creating ``EntitySets`` from Dask or Pyspark dataframes (:pr:`2705`)
         * Bump minimum versions of ``tqdm`` and ``pip`` in requirements files (:pr:`2716`)
+        * Use ``filter`` arg in call to ``tarfile.extractall`` to safely deserialize EntitySets (:pr:`2722`)
     * Documentation Changes
     * Testing Changes
         * Fix serialization test to work with pytest 8.1.1 (:pr:`2694`)
10 changes: 9 additions & 1 deletion featuretools/entityset/deserialize.py
@@ -2,6 +2,7 @@
 import os
 import tarfile
 import tempfile
+from inspect import getfullargspec

 import pandas as pd
 import woodwork.type_sys.type_system as ww_type_system
@@ -140,6 +141,8 @@ def read_data_description(path):
 def read_entityset(path, profile_name=None, **kwargs):
     """Read entityset from disk, S3 path, or URL.

+    NOTE: Never attempt to read an archived EntitySet from an untrusted source.
+
     Args:
         path (str): Directory on disk, S3 path, or URL to read `data_description.json`.
         profile_name (str, bool): The AWS profile specified to write to S3. Will default to None and search for AWS credentials.
@@ -159,7 +162,12 @@ def read_entityset(path, profile_name=None, **kwargs):
                 use_smartopen_es(local_path, path, transport_params)

             with tarfile.open(str(local_path)) as tar:
-                tar.extractall(path=tmpdir)
+                if "filter" in getfullargspec(tar.extractall).kwonlyargs:
+                    tar.extractall(path=tmpdir, filter="data")
+                else:
+                    raise RuntimeError(
+                        "Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.",
+                    )

             data_description = read_data_description(tmpdir)
             return description_to_entityset(data_description, **kwargs)
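
For readers who want to try the check in isolation, here is a minimal standalone sketch of the same pattern: detect whether `extractall` accepts the keyword-only `filter` argument, extract with `filter="data"` if it does, and refuse otherwise. The helper names `extract_with_data_filter` and `build_demo_archive`, and the file names inside the demo archive, are hypothetical and not part of featuretools.

```python
import os
import tarfile
import tempfile
from inspect import getfullargspec


def extract_with_data_filter(archive_path, dest):
    """Extract archive_path into dest, refusing to run without filter support."""
    with tarfile.open(archive_path) as tar:
        # `filter` is a keyword-only argument on interpreters that support it.
        if "filter" not in getfullargspec(tar.extractall).kwonlyargs:
            raise RuntimeError("tarfile.extractall has no `filter` argument")
        tar.extractall(path=dest, filter="data")


def build_demo_archive(directory):
    """Create a tiny archive holding a single text file (illustrative only)."""
    member_path = os.path.join(directory, "hello.txt")
    with open(member_path, "w") as f:
        f.write("hello")
    archive_path = os.path.join(directory, "demo.tar")
    with tarfile.open(archive_path, "w") as tar:
        tar.add(member_path, arcname="hello.txt")
    return archive_path
```

The Python `tarfile` documentation also suggests `hasattr(tarfile, "data_filter")` as a feature check; either approach should identify interpreters that ship extraction filters. The signature inspection used in the PR has the side benefit that the test can mock `getfullargspec` to simulate an older interpreter.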
14 changes: 13 additions & 1 deletion featuretools/tests/entityset_tests/test_serialization.py
@@ -2,7 +2,7 @@
 import logging
 import os
 import tempfile
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 from urllib.request import urlretrieve

 import boto3
@@ -292,6 +292,18 @@ def test_deserialize_local_tar(es):
         assert es.__eq__(new_es, deep=True)


+@patch("featuretools.entityset.deserialize.getfullargspec")
+def test_deserialize_errors_if_python_version_unsafe(mock_inspect, es):
+    mock_response = MagicMock()
+    mock_response.kwonlyargs = []
+    mock_inspect.return_value = mock_response
+    with tempfile.TemporaryDirectory() as tmp_path:
+        temp_tar_filepath = os.path.join(tmp_path, TEST_FILE)
+        urlretrieve(URL, filename=temp_tar_filepath)
+        with pytest.raises(RuntimeError, match=""):
+            deserialize.read_entityset(temp_tar_filepath)
+
+
 def test_deserialize_url_csv(es):
     new_es = deserialize.read_entityset(URL)
     assert es.__eq__(new_es, deep=True)
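
A small observation on `match=""` in `test_deserialize_errors_if_python_version_unsafe`: `pytest.raises` applies `re.search` to the string form of the exception, so an empty pattern accepts any `RuntimeError` message. The sketch below shows the difference using `re` directly. The stricter pattern is only a suggestion and is not what the PR uses.

```python
import re

# Message copied from the RuntimeError raised in deserialize.py.
message = (
    "Please upgrade your Python version to the latest patch release "
    "to allow for safe extraction of the EntitySet archive."
)

# An empty pattern matches every string, so it cannot tell errors apart.
empty_pattern_matches = re.search("", message) is not None

# A fragment of the real message only matches the intended error.
pattern = "upgrade your Python version"
specific_pattern_matches = re.search(pattern, message) is not None
unrelated_matches = re.search(pattern, "some other failure") is not None
```

Passing a fragment such as `pattern` to `match=` would make the test fail if a different `RuntimeError` were raised along the way.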