From c9c094cfab13b9692dd804c344bcd2d8e5be839d Mon Sep 17 00:00:00 2001 From: SomberNight Date: Tue, 8 Feb 2022 12:34:49 +0100 Subject: [PATCH] requirements: bump min aiorpcx to 0.22.0 aiorpcx 0.20 changed the behaviour/API of TaskGroups. When used as a context manager, TaskGroups no longer propagate exceptions raised by their tasks. Instead, the calling code has to explicitly check the results of tasks and decide whether to re-raise any exceptions. This is a significant change, and so this commit introduces "OldTaskGroup", which should behave as the TaskGroup class of old aiorpcx. All existing usages of TaskGroup are replaced with OldTaskGroup. closes https://github.com/spesmilo/electrum/issues/7446 --- contrib/deterministic-build/requirements.txt | 6 +-- contrib/requirements/requirements.txt | 2 +- electrum/address_synchronizer.py | 6 +-- electrum/bip39_recovery.py | 7 ++- electrum/daemon.py | 10 ++--- electrum/exchange_rate.py | 6 +-- electrum/interface.py | 7 ++- electrum/lnpeer.py | 10 ++--- electrum/lnworker.py | 12 ++--- electrum/network.py | 18 ++++---- electrum/synchronizer.py | 6 +-- electrum/tests/test_lnpeer.py | 38 ++++++++-------- electrum/tests/test_lntransport.py | 9 ++-- electrum/util.py | 46 ++++++++++++++++++-- electrum/wallet.py | 8 ++-- run_electrum | 4 +- 16 files changed, 115 insertions(+), 80 deletions(-) diff --git a/contrib/deterministic-build/requirements.txt b/contrib/deterministic-build/requirements.txt index 33769c320eda..b23dbe3e7c54 100644 --- a/contrib/deterministic-build/requirements.txt +++ b/contrib/deterministic-build/requirements.txt @@ -74,9 +74,9 @@ aiohttp==3.8.1 \ aiohttp-socks==0.7.1 \ --hash=sha256:2215cac4891ef3fa14b7d600ed343ed0f0a670c23b10e4142aa862b3db20341a \ --hash=sha256:94bcff5ef73611c6c6231c2ffc1be4af1599abec90dbd2fdbbd63233ec2fb0ff -aiorpcX==0.18.7 \ - --hash=sha256:7fa48423e1c06cd0ffb7b60f2cca7e819b6cbbf57d4bc8a82944994ef5038f05 \ - --hash=sha256:808a9ec9172df11677a0f7b459b69d1a6cf8b19c19da55541fa31fb1afce5ce7 +aiorpcX==0.22.1 \ + --hash=sha256:6026f7bed3432e206589c94dcf599be8cd85b5736b118c7275845c1bd922a553 \ + --hash=sha256:e74f9fbed3fd21598e71fe05066618fc2c06feec504fe29490ddda05fdbdde62 aiosignal==1.2.0 \ --hash=sha256:26e62109036cd181df6e6ad646f91f0dcfd05fe16d0cb924138ff2ab75d64e3a \ --hash=sha256:78ed67db6c7b7ced4f98e495e572106d5c432a93e1ddd1bf475e1dc05f5b7df2 diff --git a/contrib/requirements/requirements.txt b/contrib/requirements/requirements.txt index 581eedd55b8b..04b0a77f3167 100644 --- a/contrib/requirements/requirements.txt +++ b/contrib/requirements/requirements.txt @@ -1,7 +1,7 @@ qrcode protobuf>=3.12 qdarkstyle>=2.7 -aiorpcx>=0.18.7,<0.19 +aiorpcx>=0.22.0,<0.23 aiohttp>=3.3.0,<4.0.0 aiohttp_socks>=0.3 certifi diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index 5ed4a35ae1f7..8d213dac56e1 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -28,11 +28,9 @@ from collections import defaultdict from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List -from aiorpcx import TaskGroup - from . import bitcoin, util from .bitcoin import COINBASE_MATURITY -from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException, with_lock +from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException, with_lock, OldTaskGroup from .transaction import Transaction, TxOutput, TxInput, PartialTxInput, TxOutpoint, PartialTransaction from .synchronizer import Synchronizer from .verifier import SPV @@ -193,7 +191,7 @@ def on_blockchain_updated(self, event, *args): async def stop(self): if self.network: try: - async with TaskGroup() as group: + async with OldTaskGroup() as group: if self.synchronizer: await group.spawn(self.synchronizer.stop()) if self.verifier: diff --git a/electrum/bip39_recovery.py b/electrum/bip39_recovery.py index b0cd2f9022e3..c9aeaa41ecb9 100644 --- a/electrum/bip39_recovery.py +++ b/electrum/bip39_recovery.py @@ -4,20 +4,19 @@ from typing import TYPE_CHECKING -from aiorpcx import TaskGroup - from . import bitcoin from .constants import BIP39_WALLET_FORMATS from .bip32 import BIP32_PRIME, BIP32Node from .bip32 import convert_bip32_path_to_list_of_uint32 as bip32_str_to_ints from .bip32 import convert_bip32_intpath_to_strpath as bip32_ints_to_str +from .util import OldTaskGroup if TYPE_CHECKING: from .network import Network async def account_discovery(network: 'Network', get_account_xpub): - async with TaskGroup() as group: + async with OldTaskGroup() as group: account_scan_tasks = [] for wallet_format in BIP39_WALLET_FORMATS: account_scan = scan_for_active_accounts(network, get_account_xpub, wallet_format) @@ -46,7 +45,7 @@ async def scan_for_active_accounts(network: 'Network', get_account_xpub, wallet_ async def account_has_history(network: 'Network', account_node: BIP32Node, script_type: str) -> bool: gap_limit = 20 - async with TaskGroup() as group: + async with OldTaskGroup() as group: get_history_tasks = [] for address_index in range(gap_limit): address_node = account_node.subkey_at_public_derivation("0/" + str(address_index)) diff --git a/electrum/daemon.py b/electrum/daemon.py index afac69529438..20fb2634f225 100644 --- a/electrum/daemon.py +++ b/electrum/daemon.py @@ -36,13 +36,13 @@ import aiohttp from aiohttp import web, client_exceptions -from aiorpcx import TaskGroup, timeout_after, TaskTimeout, ignore_after +from aiorpcx import timeout_after, TaskTimeout, ignore_after from . import util from .network import Network from .util import (json_decode, to_bytes, to_string, profiler, standardize_path, constant_time_compare) from .invoices import PR_PAID, PR_EXPIRED -from .util import log_exceptions, ignore_exceptions, randrange +from .util import log_exceptions, ignore_exceptions, randrange, OldTaskGroup from .wallet import Wallet, Abstract_Wallet from .storage import WalletStorage from .wallet_db import WalletDB @@ -493,7 +493,7 @@ def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True): self._stop_entered = False self._stopping_soon_or_errored = threading.Event() self._stopped_event = threading.Event() - self.taskgroup = TaskGroup() + self.taskgroup = OldTaskGroup() asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs), self.asyncio_loop) @log_exceptions @@ -591,12 +591,12 @@ async def stop(self): if self.gui_object: self.gui_object.stop() self.logger.info("stopping all wallets") - async with TaskGroup() as group: + async with OldTaskGroup() as group: for k, wallet in self._wallets.items(): await group.spawn(wallet.stop()) self.logger.info("stopping network and taskgroup") async with ignore_after(2): - async with TaskGroup() as group: + async with OldTaskGroup() as group: if self.network: await group.spawn(self.network.stop(full_shutdown=True)) await group.spawn(self.taskgroup.cancel_remaining()) diff --git a/electrum/exchange_rate.py b/electrum/exchange_rate.py index 4ec8fa540fbd..03b23356e90a 100644 --- a/electrum/exchange_rate.py +++ b/electrum/exchange_rate.py @@ -10,13 +10,13 @@ from decimal import Decimal from typing import Sequence, Optional -from aiorpcx.curio import timeout_after, TaskTimeout, TaskGroup +from aiorpcx.curio import timeout_after, TaskTimeout import aiohttp from . import util from .bitcoin import COIN from .i18n import _ -from .util import (ThreadJob, make_dir, log_exceptions, +from .util import (ThreadJob, make_dir, log_exceptions, OldTaskGroup, make_aiohttp_session, resource_path) from .network import Network from .simple_config import SimpleConfig @@ -449,7 +449,7 @@ async def get_currencies_safe(name, exchange): async def query_all_exchanges_for_their_ccys_over_network(): async with timeout_after(10): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for name, klass in exchanges.items(): exchange = klass(None, None) await group.spawn(get_currencies_safe(name, exchange)) diff --git a/electrum/interface.py b/electrum/interface.py index a0c416a633bc..fc97ff780b8b 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -38,7 +38,6 @@ import functools import aiorpcx -from aiorpcx import TaskGroup from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer from aiorpcx.curio import timeout_after, TaskTimeout from aiorpcx.jsonrpc import JSONRPC, CodeMessageError @@ -47,7 +46,7 @@ from .util import (ignore_exceptions, log_exceptions, bfh, MySocksProxy, is_integer, is_non_negative_integer, is_hash256_str, is_hex_str, - is_int_or_float, is_non_negative_int_or_float) + is_int_or_float, is_non_negative_int_or_float, OldTaskGroup) from . import util from . import x509 from . import pem @@ -376,7 +375,7 @@ def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[di # Dump network messages (only for this interface). Set at runtime from the console. self.debug = False - self.taskgroup = TaskGroup() + self.taskgroup = OldTaskGroup() async def spawn_task(): task = await self.network.taskgroup.spawn(self.run()) @@ -675,7 +674,7 @@ async def ping(self): async def request_fee_estimates(self): from .simple_config import FEE_ETA_TARGETS while True: - async with TaskGroup() as group: + async with OldTaskGroup() as group: fee_tasks = [] for i in FEE_ETA_TARGETS: fee_tasks.append((i, await group.spawn(self.get_estimatefee(i)))) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index d9072c148cba..167fa1961270 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -14,14 +14,14 @@ import functools import aiorpcx -from aiorpcx import TaskGroup, ignore_after +from aiorpcx import ignore_after from .crypto import sha256, sha256d from . import bitcoin, util from . import ecc from .ecc import sig_string_from_r_and_s, der_sig_from_sig_string from . import constants -from .util import (bh2u, bfh, log_exceptions, ignore_exceptions, chunks, TaskGroup, +from .util import (bh2u, bfh, log_exceptions, ignore_exceptions, chunks, OldTaskGroup, UnrelatedTransactionException) from . import transaction from .bitcoin import make_op_return @@ -105,7 +105,7 @@ def __init__( self.announcement_signatures = defaultdict(asyncio.Queue) self.orphan_channel_updates = OrderedDict() # type: OrderedDict[ShortChannelID, dict] Logger.__init__(self) - self.taskgroup = TaskGroup() + self.taskgroup = OldTaskGroup() # HTLCs offered by REMOTE, that we started removing but are still active: self.received_htlcs_pending_removal = set() # type: Set[Tuple[Channel, int]] self.received_htlc_removed_event = asyncio.Event() @@ -1859,7 +1859,7 @@ async def htlc_switch(self): # we can get triggered for events that happen on the downstream peer. # TODO: trampoline forwarding relies on the polling async with ignore_after(0.1): - async with TaskGroup(wait=any) as group: + async with OldTaskGroup(wait=any) as group: await group.spawn(self._received_revack_event.wait()) await group.spawn(self.downstream_htlc_resolved_event.wait()) self._htlc_switch_iterstart_event.set() @@ -1943,7 +1943,7 @@ async def htlc_switch_iteration(): await self._htlc_switch_iterstart_event.wait() await self._htlc_switch_iterdone_event.wait() - async with TaskGroup(wait=any) as group: + async with OldTaskGroup(wait=any) as group: await group.spawn(htlc_switch_iteration()) await group.spawn(self.got_disconnected.wait()) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index b22824f796d4..48194491719b 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -22,11 +22,11 @@ import dns.resolver import dns.exception -from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after +from aiorpcx import run_in_thread, NetAddress, ignore_after from . import constants, util from . import keystore -from .util import profiler, chunks +from .util import profiler, chunks, OldTaskGroup from .invoices import PR_TYPE_LN, PR_UNPAID, PR_EXPIRED, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING, LNInvoice, LN_EXPIRY_NEVER from .util import NetworkRetryManager, JsonRPCClient from .lnutil import LN_MAX_FUNDING_SAT @@ -200,7 +200,7 @@ def __init__(self, xprv, features: LnFeatures): self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY) self.backup_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.BACKUP_CIPHER).privkey self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock - self.taskgroup = TaskGroup() + self.taskgroup = OldTaskGroup() self.listen_server = None # type: Optional[asyncio.AbstractServer] self.features = features self.network = None # type: Optional[Network] @@ -767,13 +767,13 @@ async def wait_for_received_pending_htlcs_to_get_removed(self): # to wait a bit for it to become irrevocably removed. # Note: we don't wait for *all htlcs* to get removed, only for those # that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in self.peers.values(): await group.spawn(peer.wait_one_htlc_switch_iteration()) while True: if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()): break - async with TaskGroup(wait=any) as group: + async with OldTaskGroup(wait=any) as group: for peer in self.peers.values(): await group.spawn(peer.received_htlc_removed_event.wait()) @@ -2269,7 +2269,7 @@ async def _request_force_close_from_backup(self, channel_id: bytes): transport = LNTransport(privkey, peer_addr, proxy=self.network.proxy) peer = Peer(self, node_id, transport, is_channel_backup=True) try: - async with TaskGroup(wait=any) as group: + async with OldTaskGroup(wait=any) as group: await group.spawn(peer._message_loop()) await group.spawn(peer.trigger_force_close(channel_id)) return diff --git a/electrum/network.py b/electrum/network.py index a24349427bbd..4d7235eb8199 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -40,11 +40,11 @@ import functools import aiorpcx -from aiorpcx import TaskGroup, ignore_after +from aiorpcx import ignore_after from aiohttp import ClientResponse from . import util -from .util import (log_exceptions, ignore_exceptions, +from .util import (log_exceptions, ignore_exceptions, OldTaskGroup, bfh, make_aiohttp_session, send_exception_to_crash_reporter, is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager, nullcontext) @@ -246,7 +246,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): LOGGING_SHORTCUT = 'n' - taskgroup: Optional[TaskGroup] + taskgroup: Optional[OldTaskGroup] interface: Optional[Interface] interfaces: Dict[ServerAddr, Interface] _connecting_ifaces: Set[ServerAddr] @@ -462,7 +462,7 @@ async def get_server_peers(): async def get_relay_fee(): self.relay_fee = await interface.get_relay_fee() - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(get_banner) await group.spawn(get_donation_address) await group.spawn(get_server_peers) @@ -839,7 +839,7 @@ async def make_reliable_wrapper(self: 'Network', *args, **kwargs): assert iface.ready.done(), "interface not ready yet" # try actual request try: - async with TaskGroup(wait=any) as group: + async with OldTaskGroup(wait=any) as group: task = await group.spawn(func(self, *args, **kwargs)) await group.spawn(iface.got_disconnected.wait()) except RequestTimedOut: @@ -1184,7 +1184,7 @@ def export_checkpoints(self, path): async def _start(self): assert not self.taskgroup - self.taskgroup = taskgroup = TaskGroup() + self.taskgroup = taskgroup = OldTaskGroup() assert not self.interface and not self.interfaces assert not self._connecting_ifaces assert not self._closing_ifaces @@ -1225,7 +1225,7 @@ async def stop(self, *, full_shutdown: bool = True): # timeout: if full_shutdown, it is up to the caller to time us out, # otherwise if e.g. restarting due to proxy changes, we time out fast async with (nullcontext() if full_shutdown else ignore_after(1)): - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(self.taskgroup.cancel_remaining()) if full_shutdown: await group.spawn(self.stop_gossip(full_shutdown=full_shutdown)) @@ -1278,7 +1278,7 @@ async def maintain_main_interface(): except asyncio.CancelledError: # suppress spurious cancellations group = self.taskgroup - if not group or group.closed(): + if not group or group.joined: raise await asyncio.sleep(0.1) @@ -1352,7 +1352,7 @@ async def get_response(server: ServerAddr): except Exception as e: res = e responses[interface.server] = res - async with TaskGroup() as group: + async with OldTaskGroup() as group: for server in servers: await group.spawn(get_response(server)) return responses diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py index 5e7257bd5964..dacfed2658ac 100644 --- a/electrum/synchronizer.py +++ b/electrum/synchronizer.py @@ -28,11 +28,11 @@ from collections import defaultdict import logging -from aiorpcx import TaskGroup, run_in_thread, RPCError +from aiorpcx import run_in_thread, RPCError from . import util from .transaction import Transaction, PartialTransaction -from .util import bh2u, make_aiohttp_session, NetworkJobOnDefaultServer, random_shuffled_copy +from .util import bh2u, make_aiohttp_session, NetworkJobOnDefaultServer, random_shuffled_copy, OldTaskGroup from .bitcoin import address_to_scripthash, is_address from .logging import Logger from .interface import GracefulDisconnect, NetworkTimeout @@ -218,7 +218,7 @@ async def _request_missing_txs(self, hist, *, allow_server_not_finding_tx=False) self.requested_tx[tx_hash] = tx_height if not transaction_hashes: return - async with TaskGroup() as group: + async with OldTaskGroup() as group: for tx_hash in transaction_hashes: await group.spawn(self._get_transaction(tx_hash, allow_server_not_finding_tx=allow_server_not_finding_tx)) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 49d8c54330a0..e8cceacec046 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -10,7 +10,7 @@ import unittest from typing import Iterable, NamedTuple, Tuple, List, Dict -from aiorpcx import TaskGroup, timeout_after, TaskTimeout +from aiorpcx import timeout_after, TaskTimeout import electrum import electrum.trampoline @@ -21,7 +21,7 @@ from electrum import simple_config, lnutil from electrum.lnaddr import lnencode, LnAddr, lndecode from electrum.bitcoin import COIN, sha256 -from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh +from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh, OldTaskGroup from electrum.lnpeer import Peer, UpfrontShutdownScriptViolation from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving @@ -125,7 +125,7 @@ def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_que NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) self.node_keypair = local_keypair self.network = MockNetwork(tx_queue) - self.taskgroup = TaskGroup() + self.taskgroup = OldTaskGroup() self.lnwatcher = None self.listen_server = None self._channels = {chan.channel_id: chan for chan in chans} @@ -365,7 +365,7 @@ def setUp(self): def tearDown(self): async def cleanup_lnworkers(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for lnworker in self._lnworkers_created: await group.spawn(lnworker.stop()) self._lnworkers_created.clear() @@ -569,7 +569,7 @@ async def pay(lnaddr, pay_req): self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) raise PaymentDone() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p1.htlc_switch()) await group.spawn(p2._message_loop()) @@ -643,7 +643,7 @@ async def pay(): raise PaymentDone() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p1.htlc_switch()) await group.spawn(p2._message_loop()) @@ -667,10 +667,10 @@ async def single_payment(pay_req): async with max_htlcs_in_flight: await w1.pay_invoice(pay_req) async def many_payments(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_msat=payment_value_msat)) for i in range(num_payments)] - async with TaskGroup() as group: + async with OldTaskGroup() as group: for pay_req_task in pay_reqs_tasks: lnaddr, pay_req = pay_req_task.result() await group.spawn(single_payment(pay_req)) @@ -696,7 +696,7 @@ async def pay(lnaddr, pay_req): self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) raise PaymentDone() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -740,7 +740,7 @@ async def pay(pay_req): [edge.short_channel_id for edge in log[0].route]) raise PaymentDone() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -764,7 +764,7 @@ async def pay(lnaddr, pay_req): self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) raise PaymentDone() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -799,7 +799,7 @@ async def pay(lnaddr, pay_req): self.assertEqual(500100000000, graph.channels[('dave', 'bob')].balance(LOCAL)) raise PaymentDone() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -862,7 +862,7 @@ async def pay(lnaddr, pay_req): raise PaymentDone() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -909,7 +909,7 @@ async def pay( raise NoPathFound() async def f(kwargs): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -948,7 +948,7 @@ async def pay(lnaddr, pay_req): async def f(): await turn_on_trampoline_alice() - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -1026,7 +1026,7 @@ async def pay(): raise SuccessfulTest() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -1196,7 +1196,7 @@ async def send_weird_messages(): raise SuccessfulTest() async def f(): - async with TaskGroup() as group: + async with OldTaskGroup() as group: for peer in [p1, p2]: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) @@ -1223,7 +1223,7 @@ async def send_weird_messages(): failing_task = None async def f(): nonlocal failing_task - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p1.htlc_switch()) failing_task = await group.spawn(p2._message_loop()) @@ -1252,7 +1252,7 @@ async def send_weird_messages(): failing_task = None async def f(): nonlocal failing_task - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p1.htlc_switch()) failing_task = await group.spawn(p2._message_loop()) diff --git a/electrum/tests/test_lntransport.py b/electrum/tests/test_lntransport.py index 8b7d567ba619..4d355bb83a6f 100644 --- a/electrum/tests/test_lntransport.py +++ b/electrum/tests/test_lntransport.py @@ -3,8 +3,7 @@ from electrum.ecc import ECPrivkey from electrum.lnutil import LNPeerAddr from electrum.lntransport import LNResponderTransport, LNTransport - -from aiorpcx import TaskGroup +from electrum.util import OldTaskGroup from . import ElectrumTestCase from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -73,7 +72,7 @@ async def write_messages(transport, expected_messages): async def cb(reader, writer): t = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer) self.assertEqual(await t.handshake(), initiator_key.get_public_key_bytes()) - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(read_messages(t, messages_sent_by_client)) await group.spawn(write_messages(t, messages_sent_by_server)) responder_shaked.set() @@ -81,7 +80,7 @@ async def connect(): peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes()) t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None) await t.handshake() - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(read_messages(t, messages_sent_by_server)) await group.spawn(write_messages(t, messages_sent_by_client)) server_shaked.set() @@ -89,7 +88,7 @@ async def connect(): async def f(): server = await asyncio.start_server(cb, '127.0.0.1', 42898) try: - async with TaskGroup() as group: + async with OldTaskGroup() as group: await group.spawn(connect()) await group.spawn(responder_shaked.wait()) await group.spawn(server_shaked.wait()) diff --git a/electrum/util.py b/electrum/util.py index a218dd668c2e..14f5c9f4fe24 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -52,7 +52,6 @@ import aiohttp from aiohttp_socks import ProxyConnector, ProxyType import aiorpcx -from aiorpcx import TaskGroup import certifi import dns.resolver @@ -1226,6 +1225,47 @@ def make_aiohttp_session(proxy: Optional[dict], headers=None, timeout=None): return aiohttp.ClientSession(headers=headers, timeout=timeout, connector=connector) +class OldTaskGroup(aiorpcx.TaskGroup): + """Automatically raises exceptions on join; as in aiorpcx prior to version 0.20. + That is, when using TaskGroup as a context manager, if any task encounters an exception, + we would like that exception to be re-raised (propagated out). For the wait=all case, + the OldTaskGroup class is emulating the following code-snippet: + ``` + async with TaskGroup() as group: + await group.spawn(task1()) + await group.spawn(task2()) + + async for task in group: + if not task.cancelled(): + task.result() + ``` + So instead of the above, one can just write: + ``` + async with OldTaskGroup() as group: + await group.spawn(task1()) + await group.spawn(task2()) + ``` + """ + async def join(self): + if self._wait is all: + exc = False + try: + async for task in self: + if not task.cancelled(): + task.result() + except BaseException: # including asyncio.CancelledError + exc = True + raise + finally: + if exc: + await self.cancel_remaining() + await super().join() + else: + await super().join() + if self.completed: + self.completed.result() + + class NetworkJobOnDefaultServer(Logger, ABC): """An abstract base class for a job that runs on the main network interface. Every time the main interface changes, the job is @@ -1251,14 +1291,14 @@ def _reset(self): """Initialise fields. Called every time the underlying server connection changes. """ - self.taskgroup = TaskGroup() + self.taskgroup = OldTaskGroup() async def _start(self, interface: 'Interface'): self.interface = interface await interface.taskgroup.spawn(self._run_tasks(taskgroup=self.taskgroup)) @abstractmethod - async def _run_tasks(self, *, taskgroup: TaskGroup) -> None: + async def _run_tasks(self, *, taskgroup: OldTaskGroup) -> None: """Start tasks in taskgroup. Called every time the underlying server connection changes. """ diff --git a/electrum/wallet.py b/electrum/wallet.py index 081be26b448e..9ab5d02029dd 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -46,13 +46,13 @@ import threading import enum -from aiorpcx import TaskGroup, timeout_after, TaskTimeout, ignore_after +from aiorpcx import timeout_after, TaskTimeout, ignore_after from .i18n import _ from .bip32 import BIP32Node, convert_bip32_intpath_to_strpath, convert_bip32_path_to_list_of_uint32 from .crypto import sha256 from . import util -from .util import (NotEnoughFunds, UserCancelled, profiler, +from .util import (NotEnoughFunds, UserCancelled, profiler, OldTaskGroup, format_satoshis, format_fee_satoshis, NoDynamicFeeEstimates, WalletFileException, BitcoinException, InvalidPassword, format_time, timestamp_to_datetime, Satoshis, @@ -134,7 +134,7 @@ async def append_single_utxo(item): inputs.append(txin) u = await network.listunspent_for_scripthash(scripthash) - async with TaskGroup() as group: + async with OldTaskGroup() as group: for item in u: if len(inputs) >= imax: break @@ -155,7 +155,7 @@ async def find_utxos_for_privkey(txin_type, privkey, compressed): inputs = [] # type: List[PartialTxInput] keypairs = {} - async with TaskGroup() as group: + async with OldTaskGroup() as group: for sec in privkeys: txin_type, privkey, compressed = bitcoin.deserialize_privkey(sec) await group.spawn(find_utxos_for_privkey(txin_type, privkey, compressed)) diff --git a/run_electrum b/run_electrum index 3a7abd50b061..d80f24b5057c 100755 --- a/run_electrum +++ b/run_electrum @@ -63,8 +63,8 @@ def check_imports(): import aiorpcx except ImportError as e: sys.exit(f"Error: {str(e)}. Try 'sudo python3 -m pip install '") - if not ((0, 18, 7) <= aiorpcx._version < (0, 19)): - raise RuntimeError(f'aiorpcX version {aiorpcx._version} does not match required: 0.18.7<=ver<0.19') + if not ((0, 22, 0) <= aiorpcx._version < (0, 23)): + raise RuntimeError(f'aiorpcX version {aiorpcx._version} does not match required: 0.22.0<=ver<0.23') # the following imports are for pyinstaller from google.protobuf import descriptor from google.protobuf import message