Skip to content

Commit

Permalink
Fixed handling of overwrite flags is s3pathhandler and CircleCI error…
Browse files Browse the repository at this point in the history
…s. (facebookresearch#11)

Summary: Pull Request resolved: fairinternal/iopath#11

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

buck test mode/opt  //fair_infra/data/iopath/tests:iopath_test

Reviewed By: kkondaka

Differential Revision: D32680188

Pulled By: sujitoc

fbshipit-source-id: f0fca7de34a15ab0217a01306fe87f46126c23d6
  • Loading branch information
Sujit Verma authored and facebook-github-bot committed Dec 13, 2021
1 parent 709157d commit 0e2abaf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
13 changes: 10 additions & 3 deletions iopath/common/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,11 @@ def _copy_from_local(
"S3PathHandler does not currently support uploading directories"
)

if not overwrite and self._exists(dst_path):
logger = logging.getLogger(__name__)
logger.error("Error: Destination path {} already exists.".format(dst_path))
return False

bucket, s3_path = self._parse_uri(dst_path)
client = self._get_client(bucket)
try:
Expand Down Expand Up @@ -423,6 +428,11 @@ def _copy(
"""
self._check_kwargs(kwargs)

if not overwrite and self._exists(dst_path):
logger = logging.getLogger(__name__)
logger.error("Error: Destination path {} already exists.".format(dst_path))
return False

src_bucket, src_s3_path = self._parse_uri(src_path)
dst_bucket, dst_s3_path = self._parse_uri(dst_path)
assert src_bucket == dst_bucket, "For now, can only _copy() within a bucket."
Expand Down Expand Up @@ -756,6 +766,3 @@ def _read_chunk_to_buffer(self, start_offset: int) -> None:
self.buffer.seek(0)
self.buffer.write(ret)
self.buffered_window = download_range

def read1(self, size: int = -1) -> bytes:
return self.read(size)
24 changes: 19 additions & 5 deletions tests/test_s3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import io
import os
import unittest
from unittest.mock import patch

Expand All @@ -14,10 +13,26 @@
except ImportError:
boto3 = None

test_bucket = "TEST_BUCKET_NAME_REPLACE_ME"
test_rel_path = "TEST_REL_PATH_REPLACE_ME"


def test_bucket_defined():
return (
test_bucket != "TEST_BUCKET_NAME_REPLACE_ME"
and test_rel_path != "TEST_REL_PATH_REPLACE_ME"
)


@unittest.skipIf(not boto3, "Requires boto3 install")
@unittest.skipIf(
test_bucket == "TEST_BUCKET_NAME_REPLACE_ME", "Test Bucket not specified."
)
@unittest.skipIf(
test_rel_path == "TEST_REL_PATH_REPLACE_ME", "Test relative path not specified."
)
class TestsS3(unittest.TestCase):
s3_auth = True
s3_auth = test_bucket == test_bucket_defined()
skip_s3_auth_required_tests_message = (
"Provide an s3 project and bucket you are"
+ "authorised against, then set the s3_auth flag to True"
Expand All @@ -33,9 +48,9 @@ def run(self, result=None):
@classmethod
def setUpClass(cls):
# NOTE: user should change this location.
cls.s3_bucket = "TEST_BUCKET_NAME_REPLACE_ME"
cls.s3_bucket = test_bucket
# NOTE: user should change this to a valid bucket path that is accessible.
cls.s3_rel_path = "TEST_REL_PATH_REPLACE_ME"
cls.s3_rel_path = test_rel_path
cls.s3_full_path = "s3://" + cls.s3_bucket + "/" + cls.s3_rel_path
cls.s3_pathhandler = S3PathHandler()
cls.pathmanager = PathManager()
Expand Down Expand Up @@ -86,7 +101,6 @@ def tearDownClass(cls):
# Up here, test class attributes,
# and helpers that don't require S3 access.
#############################################

def test_00_supported_prefixes(self):
supported_prefixes = self.s3_pathhandler._get_supported_prefixes()
self.assertEqual(supported_prefixes, ["s3://"])
Expand Down

0 comments on commit 0e2abaf

Please sign in to comment.