diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py index 3cd8752c..3ce47287 100644 --- a/compiler_opt/distributed/worker.py +++ b/compiler_opt/distributed/worker.py @@ -15,7 +15,6 @@ """Common abstraction for a worker contract.""" import abc -import sys from typing import Any, List, Iterable, Optional, Protocol, TypeVar import gin @@ -104,8 +103,4 @@ def get_full_worker_args(worker_class: 'type[Worker]', **current_kwargs): # we don't have a way to check if `worker_class` is even known to gin, and # it's not a requirement that it were. Tests, for instance, don't use gin. pass - # Issue #38 - if sys.version_info >= (3, 9): - return current_kwargs | gin_config - else: - return {**current_kwargs, **gin_config} + return current_kwargs | gin_config diff --git a/compiler_opt/rl/data_collector_test.py b/compiler_opt/rl/data_collector_test.py index 44768d32..0f8d2f6a 100644 --- a/compiler_opt/rl/data_collector_test.py +++ b/compiler_opt/rl/data_collector_test.py @@ -15,7 +15,6 @@ """Tests for data_collector.""" # pylint: disable=protected-access -import sys from unittest import mock from absl.testing import absltest @@ -29,11 +28,7 @@ def test_build_distribution_monitor(self): data = [3, 2, 1] monitor_dict = data_collector.build_distribution_monitor(data) reference_dict = {'mean': 2, 'p_0.1': 1} - # Issue #38 - if sys.version_info >= (3, 9): - self.assertEqual(monitor_dict, monitor_dict | reference_dict) - else: - self.assertEqual(monitor_dict, {**monitor_dict, **reference_dict}) + self.assertEqual(monitor_dict, monitor_dict | reference_dict) @mock.patch('time.time') def test_early_exit(self, mock_time): diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py index 643efc6e..27735964 100644 --- a/compiler_opt/rl/local_data_collector_test.py +++ b/compiler_opt/rl/local_data_collector_test.py @@ -17,7 +17,6 @@ # pylint: disable=protected-access import collections import string -import sys from typing import List, Tuple import tensorflow as tf @@ -180,15 +179,8 @@ def _test_iterator_fn(data_list): 'total_trajectory_length': 18, } } - # Issue #38 - if sys.version_info >= (3, 9): - self.assertEqual(monitor_dict, - monitor_dict | expected_monitor_dict_subset) - else: - self.assertEqual(monitor_dict, { - **monitor_dict, - **expected_monitor_dict_subset - }) + self.assertEqual(monitor_dict, + monitor_dict | expected_monitor_dict_subset) data_iterator, monitor_dict = collector.collect_data( policy=_mock_policy, model_id=0) data = list(data_iterator) @@ -200,15 +192,8 @@ def _test_iterator_fn(data_list): 'total_trajectory_length': 18, } } - # Issue #38 - if sys.version_info >= (3, 9): - self.assertEqual(monitor_dict, - monitor_dict | expected_monitor_dict_subset) - else: - self.assertEqual(monitor_dict, { - **monitor_dict, - **expected_monitor_dict_subset - }) + self.assertEqual(monitor_dict, + monitor_dict | expected_monitor_dict_subset) collector.close_pool()