Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ecu_status: should use any_child_ecu_in_update flag instead of any_in_update to avoid self-lockedown, still wait on any_child_ecu_in_update before reboot #459

Merged
merged 14 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/otaclient/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class OTAClientStatus:

@dataclass
class MultipleECUStatusFlags:
any_in_update: mp_sync.Event
any_child_ecu_in_update: mp_sync.Event
any_requires_network: mp_sync.Event
all_success: mp_sync.Event

Expand Down
77 changes: 21 additions & 56 deletions src/otaclient/grpc/api_v2/ecu_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,11 @@
import math
import time
from itertools import chain
from typing import Dict, Iterable, Optional

from otaclient._types import MultipleECUStatusFlags, OTAClientStatus
from otaclient.configs.cfg import cfg, ecu_info
from otaclient.grpc.api_v2.types import convert_to_apiv2_status
from otaclient_api.v2 import types as api_types
from otaclient_common.typing import T

logger = logging.getLogger(__name__)

Expand All @@ -64,23 +62,6 @@
ACTIVE_POLLING_INTERVAL = 1 # seconds


class _OrderedSet(Dict[T, None]):
def __init__(self, _input: Optional[Iterable[T]]):
if _input:
for elem in _input:
self[elem] = None
super().__init__()

def add(self, value: T):
self[value] = None

def remove(self, value: T):
super().pop(value)

def discard(self, value: T):
super().pop(value, None)


class ECUStatusStorage:

def __init__(
Expand All @@ -95,27 +76,13 @@ def __init__(
# ECU status storage
self.storage_last_updated_timestamp = 0

# ECUs that are/will be active during an OTA session,
# at init it will be the ECUs listed in available_ecu_ids defined
# in ecu_info.yaml.
# When receives update request, the list will be set to include ECUs
# listed in the update request, and be extended by merging
# available_ecu_ids in sub ECUs' status report.
# Internally referenced when generating overall ECU status report.
# TODO: in the future if otaclient can preserve OTA session info,
# ECUStatusStorage should restore the tracked_active_ecus info
# in the saved session info.
self._tracked_active_ecus: _OrderedSet[str] = _OrderedSet(
ecu_info.get_available_ecu_ids()
)

# The attribute that will be exported in status API response,
# NOTE(20230801): available_ecu_ids only serves information purpose,
# it should only be updated with ecu_info.yaml or merging
# available_ecu_ids field in sub ECUs' status report.
# NOTE(20230801): for web.auto user, available_ecu_ids in status API response
# will be used to generate update request list, so be-careful!
self._available_ecu_ids: _OrderedSet[str] = _OrderedSet(
# NOTE(20241219): we will only look at status of ECUs listed in available_ecu_ids.
# ECUs that in the secondaries field but no in available_ecu_ids field
# are considered to be the ECUs not ready for OTA. See ecu_info.yaml doc.
self._available_ecu_ids: dict[str, None] = dict.fromkeys(
ecu_info.get_available_ecu_ids()
)

Expand Down Expand Up @@ -185,7 +152,7 @@ async def _generate_overall_status_report(self):
for status in chain(
self._all_ecus_status_v2.values(), self._all_ecus_status_v1.values()
)
if status.ecu_id in self._tracked_active_ecus
if status.ecu_id in self._available_ecu_ids
and status.is_in_update
and status.ecu_id not in lost_ecus
}
Expand All @@ -195,10 +162,11 @@ async def _generate_overall_status_report(self):
"new ECU(s) that acks update request and enters OTA update detected"
f"{_new_in_update_ecu}, current updating ECUs: {in_update_ecus_id}"
)
if in_update_ecus_id:
ecu_status_flags.any_in_update.set()

if self.in_update_child_ecus_id:
ecu_status_flags.any_child_ecu_in_update.set()
else:
ecu_status_flags.any_in_update.clear()
ecu_status_flags.any_child_ecu_in_update.clear()

# check if there is any failed child/self ECU in tracked active ECUs set
_old_failed_ecus_id = self.failed_ecus_id
Expand All @@ -207,7 +175,7 @@ async def _generate_overall_status_report(self):
for status in chain(
self._all_ecus_status_v2.values(), self._all_ecus_status_v1.values()
)
if status.ecu_id in self._tracked_active_ecus
if status.ecu_id in self._available_ecu_ids
and status.is_failed
and status.ecu_id not in lost_ecus
}
Expand All @@ -223,7 +191,7 @@ async def _generate_overall_status_report(self):
for status in chain(
self._all_ecus_status_v2.values(), self._all_ecus_status_v1.values()
)
if status.ecu_id in self._tracked_active_ecus
if status.ecu_id in self._available_ecu_ids
and status.ecu_id not in lost_ecus
)
):
Expand All @@ -240,12 +208,12 @@ async def _generate_overall_status_report(self):
for status in chain(
self._all_ecus_status_v2.values(), self._all_ecus_status_v1.values()
)
if status.ecu_id in self._tracked_active_ecus
if status.ecu_id in self._available_ecu_ids
and status.is_success
and status.ecu_id not in lost_ecus
}
# NOTE: all_success doesn't count the lost ECUs
if len(self.success_ecus_id) == len(self._tracked_active_ecus):
if self.success_ecus_id == set(self._available_ecu_ids):
ecu_status_flags.all_success.set()
else:
ecu_status_flags.all_success.clear()
Expand Down Expand Up @@ -334,19 +302,20 @@ async def on_ecus_accept_update_request(self, ecus_accept_update: set[str]):
"""
ecu_status_flags = self.ecu_status_flags
async with self._properties_update_lock:
self._tracked_active_ecus = _OrderedSet(ecus_accept_update)

self.last_update_request_received_timestamp = int(time.time())
self.lost_ecus_id -= ecus_accept_update
self.failed_ecus_id -= ecus_accept_update
self.success_ecus_id -= ecus_accept_update

self.in_update_ecus_id.update(ecus_accept_update)
self.in_update_child_ecus_id = self.in_update_ecus_id - {self.my_ecu_id}
self.success_ecus_id -= ecus_accept_update

ecu_status_flags.all_success.clear()
ecu_status_flags.any_requires_network.set()
ecu_status_flags.any_in_update.set()
if self.in_update_child_ecus_id:
ecu_status_flags.any_child_ecu_in_update.set()
else:
ecu_status_flags.any_child_ecu_in_update.clear()

def get_polling_interval(self) -> int:
"""Return <ACTIVE_POLLING_INTERVAL> if there is active OTA update,
Expand All @@ -355,11 +324,8 @@ def get_polling_interval(self) -> int:
NOTE: use get_polling_waiter if want to wait, only call this method
if one only wants to get the polling interval value.
"""
ecu_status_flags = self.ecu_status_flags
return (
ACTIVE_POLLING_INTERVAL
if ecu_status_flags.any_in_update.is_set()
else IDLE_POLLING_INTERVAL
ACTIVE_POLLING_INTERVAL if self.in_update_ecus_id else IDLE_POLLING_INTERVAL
)

def get_polling_waiter(self):
Expand All @@ -377,13 +343,12 @@ def get_polling_waiter(self):
_inner_wait_interval = 1 # second

async def _waiter():
ecu_status_flags = self.ecu_status_flags
if ecu_status_flags.any_in_update.is_set():
if self.in_update_ecus_id:
await asyncio.sleep(ACTIVE_POLLING_INTERVAL)
return

for _ in range(math.ceil(IDLE_POLLING_INTERVAL / _inner_wait_interval)):
if ecu_status_flags.any_in_update.is_set():
if self.in_update_ecus_id:
return
await asyncio.sleep(_inner_wait_interval)

Expand Down
2 changes: 1 addition & 1 deletion src/otaclient/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def main() -> None: # pragma: no cover
local_otaclient_op_queue = mp_ctx.Queue()
local_otaclient_resp_queue = mp_ctx.Queue()
ecu_status_flags = MultipleECUStatusFlags(
any_in_update=mp_ctx.Event(),
any_child_ecu_in_update=mp_ctx.Event(),
any_requires_network=mp_ctx.Event(),
all_success=mp_ctx.Event(),
)
Expand Down
4 changes: 1 addition & 3 deletions src/otaclient/ota_core.py
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert to the behavior of otaclient v3.8.x and before: wait for all child ECUs finish OTA update.

Original file line number Diff line number Diff line change
Expand Up @@ -579,11 +579,9 @@ def _execute_update(self):
)
)

# NOTE: we don't need to wait for sub ECUs if sub ECUs don't
# depend on otaproxy on this ECU.
if proxy_info.enable_local_ota_proxy:
wait_and_log(
check_flag=self.ecu_status_flags.any_requires_network.is_set,
check_flag=self.ecu_status_flags.any_child_ecu_in_update.is_set,
check_for=False,
message="permit reboot flag",
log_func=logger.info,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_otaclient/test_create_standby.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_update_with_rebuild_mode(
):
status_collector, status_report_queue = ota_status_collector
ecu_status_flags = mocker.MagicMock()
ecu_status_flags.any_requires_network.is_set = mocker.MagicMock(
ecu_status_flags.any_child_ecu_in_update.is_set = mocker.MagicMock(
return_value=False
)

Expand Down Expand Up @@ -145,7 +145,7 @@ def test_update_with_rebuild_mode(
# ------ assertions ------ #
persist_handler.assert_called_once()

ecu_status_flags.any_requires_network.is_set.assert_called_once()
ecu_status_flags.any_child_ecu_in_update.is_set.assert_called_once()
# --- ensure the update stats are collected
collected_status = status_collector.otaclient_status
assert collected_status
Expand Down
67 changes: 59 additions & 8 deletions tests/test_otaclient/test_grpc/test_api_v2/test_ecu_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture: ECUInfo):
# init and setup the ecu_storage
# NOTE: here we use threading.Event instead
self.ecu_status_flags = ecu_status_flags = MultipleECUStatusFlags(
any_in_update=threading.Event(), # type: ignore[assignment]
any_child_ecu_in_update=threading.Event(), # type: ignore[assignment]
any_requires_network=threading.Event(), # type: ignore[assignment]
all_success=threading.Event(), # type: ignore[assignment]
)
Expand Down Expand Up @@ -380,7 +380,7 @@ async def test_export(
},
# ecu_status_flags
{
"any_in_update": True,
"any_child_ecu_in_update": True,
"any_requires_network": True,
"all_success": False,
},
Expand Down Expand Up @@ -429,7 +429,55 @@ async def test_export(
},
# ecu_status_flags
{
"any_in_update": True,
"any_child_ecu_in_update": True,
"any_requires_network": True,
"all_success": False,
},
),
# case 3:
# only main ECU doing OTA update.
(
# local ECU status: UPDATING
_internal_types.OTAClientStatus(
ota_status=_internal_types.OTAStatus.UPDATING,
update_phase=_internal_types.UpdatePhase.DOWNLOADING_OTA_FILES,
),
# sub ECUs status
[
# p1: SUCCESS
api_types.StatusResponse(
available_ecu_ids=["p1"],
ecu_v2=[
api_types.StatusResponseEcuV2(
ecu_id="p1",
ota_status=api_types.StatusOta.SUCCESS,
),
],
),
# p2: SUCCESS
api_types.StatusResponse(
available_ecu_ids=["p2"],
ecu=[
api_types.StatusResponseEcu(
ecu_id="p2",
status=api_types.Status(
status=api_types.StatusOta.SUCCESS,
),
)
],
),
],
# expected overal ECUs status report set by on_ecus_accept_update_request,
{
"lost_ecus_id": set(),
"in_update_ecus_id": {"autoware"},
"in_update_child_ecus_id": set(),
"failed_ecus_id": set(),
"success_ecus_id": {"p1", "p2"},
},
# ecu_status_flags
{
"any_child_ecu_in_update": False,
"any_requires_network": True,
"all_success": False,
},
Expand Down Expand Up @@ -510,7 +558,7 @@ async def test_overall_ecu_status_report_generation(
},
# ecu_status_flags
{
"any_in_update": True,
"any_child_ecu_in_update": True,
"any_requires_network": True,
"all_success": False,
},
Expand Down Expand Up @@ -562,7 +610,7 @@ async def test_overall_ecu_status_report_generation(
},
# ecu_status_flags
{
"any_in_update": True,
"any_child_ecu_in_update": True,
"any_requires_network": True,
"all_success": False,
},
Expand Down Expand Up @@ -604,15 +652,18 @@ async def test_on_receive_update_request(
for k, v in flags_status.items():
assert getattr(self.ecu_status_flags, k).is_set() == v

async def test_polling_waiter_switching_from_idling_to_active(self):
async def test_polling_waiter_switching_from_idling_to_active(
self, mocker: pytest_mock.MockerFixture
):
"""Waiter should immediately return if active_ota_update_present is set."""
_sleep_time, _mocked_interval = 3, 60

mocker.patch(f"{ECU_STATUS_MODULE}.IDLE_POLLING_INTERVAL", _mocked_interval)

async def _event_setter():
await asyncio.sleep(_sleep_time)
self.ecu_status_flags.any_in_update.set()
await self.ecu_storage.on_ecus_accept_update_request({"autoware"})

self.ecu_status_flags.any_in_update.clear()
_waiter = self.ecu_storage.get_polling_waiter()
asyncio.create_task(_event_setter())
# waiter should return on active_ota_update_present is set, instead of waiting the
Expand Down
6 changes: 3 additions & 3 deletions tests/test_otaclient/test_ota_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_otaupdater(
) -> None:
_, report_queue = ota_status_collector
ecu_status_flags = mocker.MagicMock()
ecu_status_flags.any_requires_network.is_set = mocker.MagicMock(
ecu_status_flags.any_child_ecu_in_update.is_set = mocker.MagicMock(
return_value=False
)

Expand Down Expand Up @@ -202,7 +202,7 @@ def test_otaupdater(
assert _downloaded_files_size == self._delta_bundle.total_download_files_size

# assert the control_flags has been waited
ecu_status_flags.any_requires_network.is_set.assert_called_once()
ecu_status_flags.any_child_ecu_in_update.is_set.assert_called_once()

assert _updater.update_version == str(cfg.UPDATE_VERSION)

Expand Down Expand Up @@ -235,7 +235,7 @@ def mock_setup(
):
_, status_report_queue = ota_status_collector
ecu_status_flags = mocker.MagicMock()
ecu_status_flags.any_requires_network.is_set = mocker.MagicMock(
ecu_status_flags.any_child_ecu_in_update.is_set = mocker.MagicMock(
return_value=False
)

Expand Down
Loading