diff --git a/google/cloud/dataflow/io/fileio.py b/google/cloud/dataflow/io/fileio.py index ea0f652..f5e3825 100644 --- a/google/cloud/dataflow/io/fileio.py +++ b/google/cloud/dataflow/io/fileio.py @@ -256,6 +256,23 @@ def rename(src, dst): except OSError as err: raise IOError(err) + @staticmethod + def copytree(src, dst): + if src.startswith('gs://'): + assert dst.startswith('gs://'), dst + assert src.endswith('/'), src + assert dst.endswith('/'), dst + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + gcsio.GcsIO().copytree(src, dst) + else: + try: + if os.path.exists(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + except OSError as err: + raise IOError(err) + @staticmethod def exists(path): if path.startswith('gs://'): diff --git a/google/cloud/dataflow/io/gcsio.py b/google/cloud/dataflow/io/gcsio.py index 5b62400..59ec9bc 100644 --- a/google/cloud/dataflow/io/gcsio.py +++ b/google/cloud/dataflow/io/gcsio.py @@ -171,6 +171,22 @@ def copy(self, src, dest): destinationObject=dest_path) self.client.objects.Copy(request) + # We intentionally do not decorate this method with a retry, since the + # underlying copy and delete operations are already idempotent operations + # protected by retry decorators. + def copytree(self, src, dest): + """Renames the given GCS "directory" recursively from src to dest. + + Args: + src: GCS file path pattern in the form gs:////. + dest: GCS file path pattern in the form gs:////. + """ + assert src.endswith('/') + assert dest.endswith('/') + for entry in self.glob(src + '*'): + rel_path = entry[len(src):] + self.copy(entry, dest + rel_path) + # We intentionally do not decorate this method with a retry, since the # underlying copy and delete operations are already idempotent operations # protected by retry decorators. diff --git a/google/cloud/dataflow/io/gcsio_test.py b/google/cloud/dataflow/io/gcsio_test.py index deb179d..12fcf5d 100644 --- a/google/cloud/dataflow/io/gcsio_test.py +++ b/google/cloud/dataflow/io/gcsio_test.py @@ -219,6 +219,31 @@ def test_copy(self): self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in self.client.objects.files) + def test_copytree(self): + src_dir_name = 'gs://gcsio-test/source/' + dest_dir_name = 'gs://gcsio-test/dest/' + file_size = 1024 + paths = ['a', 'b/c', 'b/d'] + for path in paths: + src_file_name = src_dir_name + path + dest_file_name = dest_dir_name + path + self._insert_random_file(self.client, src_file_name, + file_size) + self.assertTrue(gcsio.parse_gcs_path(src_file_name) in + self.client.objects.files) + self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in + self.client.objects.files) + + self.gcs.copytree(src_dir_name, dest_dir_name) + + for path in paths: + src_file_name = src_dir_name + path + dest_file_name = dest_dir_name + path + self.assertTrue(gcsio.parse_gcs_path(src_file_name) in + self.client.objects.files) + self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in + self.client.objects.files) + def test_rename(self): src_file_name = 'gs://gcsio-test/source' dest_file_name = 'gs://gcsio-test/dest'