URI: 
       Merge pull request #7099 from SomberNight/202103_fail_pending_htlcs_on_shutdown - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 6004a047053e97b9f613d8329f690d1fd6d920fb
   DIR parent 05e58671c92940620aab36cf5946a0f1dd24c013
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Fri, 12 Mar 2021 11:01:07 +0100
       
       Merge pull request #7099 from SomberNight/202103_fail_pending_htlcs_on_shutdown
       
       fail pending htlcs on shutdown
       Diffstat:
         M electrum/lnhtlc.py                  |      49 +++++++++++++++++++++++++++++++
         M electrum/lnpeer.py                  |      45 +++++++++++++++++++++++++++++---
         M electrum/lnworker.py                |      41 +++++++++++++++++++++++++++----
         M electrum/tests/test_lnpeer.py       |      62 +++++++++++++++++++++++++++----
       
       4 files changed, 182 insertions(+), 15 deletions(-)
       ---
   DIR diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
       @@ -360,6 +360,55 @@ class HTLCManager:
                return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner)
        
            @with_lock
       +    def is_htlc_irrevocably_removed_yet(
       +            self,
       +            *,
       +            ctx_owner: HTLCOwner = None,
       +            htlc_proposer: HTLCOwner,
       +            htlc_id: int,
       +    ) -> bool:
       +        """Returns whether the removal of an htlc was irrevocably committed to `ctx_owner's` ctx.
       +        The removal can either be a fulfill/settle or a fail; they are not distinguished.
       +        If `ctx_owner` is None, both parties' ctxs are checked.
       +        """
       +        in_local = self._is_htlc_irrevocably_removed_yet(
       +            ctx_owner=LOCAL, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
       +        in_remote = self._is_htlc_irrevocably_removed_yet(
       +            ctx_owner=REMOTE, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
       +        if ctx_owner is None:
       +            return in_local and in_remote
       +        elif ctx_owner == LOCAL:
       +            return in_local
       +        elif ctx_owner == REMOTE:
       +            return in_remote
       +        else:
       +            raise Exception(f"unexpected ctx_owner: {ctx_owner!r}")
       +
       +    @with_lock
       +    def _is_htlc_irrevocably_removed_yet(
       +            self,
       +            *,
       +            ctx_owner: HTLCOwner,
       +            htlc_proposer: HTLCOwner,
       +            htlc_id: int,
       +    ) -> bool:
       +        htlc_id = int(htlc_id)
       +        if htlc_id >= self.get_next_htlc_id(htlc_proposer):
       +            return False
       +        if htlc_id in self.log[htlc_proposer]['settles']:
       +            ctn_of_settle = self.log[htlc_proposer]['settles'][htlc_id][ctx_owner]
       +        else:
       +            ctn_of_settle = None
       +        if htlc_id in self.log[htlc_proposer]['fails']:
       +            ctn_of_fail = self.log[htlc_proposer]['fails'][htlc_id][ctx_owner]
       +        else:
       +            ctn_of_fail = None
       +        ctn_of_rm = ctn_of_settle or ctn_of_fail or None
       +        if ctn_of_rm is None:
       +            return False
       +        return ctn_of_rm <= self.ctn_oldest_unrevoked(ctx_owner)
       +
       +    @with_lock
            def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
                                   ctn: int = None) -> Dict[int, UpdateAddHtlc]:
                """Return the dict of received or sent (depending on direction) HTLCs
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       @@ -9,11 +9,12 @@ from collections import OrderedDict, defaultdict
        import asyncio
        import os
        import time
       -from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union
       +from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set
        from datetime import datetime
        import functools
        
        import aiorpcx
       +from aiorpcx import TaskGroup
        
        from .crypto import sha256, sha256d
        from . import bitcoin, util
       @@ -74,6 +75,7 @@ class Peer(Logger):
                self._sent_init = False  # type: bool
                self._received_init = False  # type: bool
                self.initialized = asyncio.Future()
       +        self.got_disconnected = asyncio.Event()
                self.querying = asyncio.Event()
                self.transport = transport
                self.pubkey = pubkey  # remote pubkey
       @@ -98,6 +100,11 @@ class Peer(Logger):
                self.orphan_channel_updates = OrderedDict()
                Logger.__init__(self)
                self.taskgroup = SilentTaskGroup()
       +        # HTLCs offered by REMOTE, that we started removing but are still active:
       +        self.received_htlcs_pending_removal = set()  # type: Set[Tuple[Channel, int]]
       +        self.received_htlc_removed_event = asyncio.Event()
       +        self._htlc_switch_iterstart_event = asyncio.Event()
       +        self._htlc_switch_iterdone_event = asyncio.Event()
        
            def send_message(self, message_name: str, **kwargs):
                assert type(message_name) is str
       @@ -492,6 +499,7 @@ class Peer(Logger):
                except:
                    pass
                self.lnworker.peer_closed(self)
       +        self.got_disconnected.set()
        
            def is_static_remotekey(self):
                return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT)
       @@ -1575,6 +1583,7 @@ class Peer(Logger):
                self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
                assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
                assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
       +        self.received_htlcs_pending_removal.add((chan, htlc_id))
                chan.settle_htlc(preimage, htlc_id)
                self.send_message(
                    "update_fulfill_htlc",
       @@ -1585,6 +1594,7 @@ class Peer(Logger):
            def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
                self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
                assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
       +        self.received_htlcs_pending_removal.add((chan, htlc_id))
                chan.fail_htlc(htlc_id)
                self.send_message(
                    "update_fail_htlc",
       @@ -1596,9 +1606,10 @@ class Peer(Logger):
            def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure):
                self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
                assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
       -        chan.fail_htlc(htlc_id)
                if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32):
                    raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}")
       +        self.received_htlcs_pending_removal.add((chan, htlc_id))
       +        chan.fail_htlc(htlc_id)
                self.send_message(
                    "update_fail_malformed_htlc",
                    channel_id=chan.channel_id,
       @@ -1800,8 +1811,13 @@ class Peer(Logger):
            async def htlc_switch(self):
                await self.initialized
                while True:
       -            await asyncio.sleep(0.1)
       +            self._htlc_switch_iterdone_event.set()
       +            self._htlc_switch_iterdone_event.clear()
       +            await asyncio.sleep(0.1)  # TODO maybe make this partly event-driven
       +            self._htlc_switch_iterstart_event.set()
       +            self._htlc_switch_iterstart_event.clear()
                    self.ping_if_required()
       +            self._maybe_cleanup_received_htlcs_pending_removal()
                    for chan_id, chan in self.channels.items():
                        if not chan.can_send_ctx_updates():
                            continue
       @@ -1853,6 +1869,29 @@ class Peer(Logger):
                        for htlc_id in done:
                            unfulfilled.pop(htlc_id)
        
       +    def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
       +        done = set()
       +        for chan, htlc_id in self.received_htlcs_pending_removal:
       +            if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
       +                done.add((chan, htlc_id))
       +        if done:
       +            for key in done:
       +                self.received_htlcs_pending_removal.remove(key)
       +            self.received_htlc_removed_event.set()
       +            self.received_htlc_removed_event.clear()
       +
       +    async def wait_one_htlc_switch_iteration(self) -> None:
       +        """Waits until the HTLC switch does a full iteration or the peer disconnects,
       +        whichever happens first.
       +        """
       +        async def htlc_switch_iteration():
       +            await self._htlc_switch_iterstart_event.wait()
       +            await self._htlc_switch_iterdone_event.wait()
       +
       +        async with TaskGroup(wait=any) as group:
       +            await group.spawn(htlc_switch_iteration())
       +            await group.spawn(self.got_disconnected.wait())
       +
            async def process_unfulfilled_htlc(
                    self, *,
                    chan: Channel,
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       @@ -22,7 +22,7 @@ import urllib.parse
        
        import dns.resolver
        import dns.exception
       -from aiorpcx import run_in_thread, TaskGroup, NetAddress
       +from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after
        
        from . import constants, util
        from . import keystore
       @@ -195,6 +195,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
                self.features = features
                self.network = None  # type: Optional[Network]
                self.config = None  # type: Optional[SimpleConfig]
       +        self.stopping_soon = False  # whether we are being shut down
        
                util.register_callback(self.on_proxy_changed, ['proxy_set'])
        
       @@ -268,6 +269,8 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
            async def _maintain_connectivity(self):
                while True:
                    await asyncio.sleep(1)
       +            if self.stopping_soon:
       +                return
                    now = time.time()
                    if len(self._peers) >= NUM_PEERS_TARGET:
                        continue
       @@ -575,6 +578,7 @@ class LNWallet(LNWorker):
        
            lnwatcher: Optional['LNWalletWatcher']
            MPP_EXPIRY = 120
       +    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3  # seconds
        
            def __init__(self, wallet: 'Abstract_Wallet', xprv):
                self.wallet = wallet
       @@ -707,9 +711,32 @@ class LNWallet(LNWorker):
                    asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
        
            async def stop(self):
       -        await super().stop()
       -        await self.lnwatcher.stop()
       -        self.lnwatcher = None
       +        self.stopping_soon = True
       +        if self.listen_server:  # stop accepting new peers
       +            self.listen_server.close()
       +        async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
       +            await self.wait_for_received_pending_htlcs_to_get_removed()
       +        await LNWorker.stop(self)
       +        if self.lnwatcher:
       +            await self.lnwatcher.stop()
       +            self.lnwatcher = None
       +
       +    async def wait_for_received_pending_htlcs_to_get_removed(self):
       +        assert self.stopping_soon is True
       +        # We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
       +        # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
       +        #       to wait a bit for it to become irrevocably removed.
       +        # Note: we don't wait for *all htlcs* to get removed, only for those
       +        #       that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
       +        async with TaskGroup() as group:
       +            for peer in self.peers.values():
       +                await group.spawn(peer.wait_one_htlc_switch_iteration())
       +        while True:
       +            if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
       +                break
       +            async with TaskGroup(wait=any) as group:
       +                for peer in self.peers.values():
       +                    await group.spawn(peer.received_htlc_removed_event.wait())
        
            def peer_closed(self, peer):
                for chan in self.channels_for_peer(peer.pubkey).values():
       @@ -1635,7 +1662,9 @@ class LNWallet(LNWorker):
                if not is_accepted and not is_expired:
                    total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
                    first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
       -            if time.time() - first_timestamp > self.MPP_EXPIRY:
       +            if self.stopping_soon:
       +                is_expired = True  # try to time out pending HTLCs before shutting down
       +            elif time.time() - first_timestamp > self.MPP_EXPIRY:
                        is_expired = True
                    elif total == expected_msat:
                        is_accepted = True
       @@ -1897,6 +1926,8 @@ class LNWallet(LNWorker):
            async def reestablish_peers_and_channels(self):
                while True:
                    await asyncio.sleep(1)
       +            if self.stopping_soon:
       +                return
                    for chan in self.channels.values():
                        if chan.is_closed():
                            continue
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       @@ -10,7 +10,7 @@ from concurrent import futures
        import unittest
        from typing import Iterable, NamedTuple, Tuple, List
        
       -from aiorpcx import TaskGroup
       +from aiorpcx import TaskGroup, timeout_after, TaskTimeout
        
        from electrum import bitcoin
        from electrum import constants
       @@ -113,7 +113,8 @@ class MockWallet:
        
        
        class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
       -    MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
       +    MPP_EXPIRY = 2  # HTLC timestamps are cast to int, so this cannot be 1
       +    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
        
            def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
                self.name = name
       @@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
                self.node_keypair = local_keypair
                self.network = MockNetwork(tx_queue)
       +        self.taskgroup = TaskGroup()
       +        self.lnwatcher = None
       +        self.listen_server = None
                self._channels = {chan.channel_id: chan for chan in chans}
                self.payments = {}
                self.logs = defaultdict(list)
       @@ -147,6 +151,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                self.trampoline_forwarding_failures = {}
                self.inflight_payments = set()
                self.preimages = {}
       +        self.stopping_soon = False
        
            def get_invoice_status(self, key):
                pass
       @@ -183,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                return self.name
        
            async def stop(self):
       +        await LNWallet.stop(self)
                if self.channel_db:
                    self.channel_db.stop()
                    await self.channel_db.stopped_event.wait()
       @@ -215,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
            _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
            handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
            is_trampoline_peer = LNWallet.is_trampoline_peer
       +    wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
       +    on_proxy_changed = LNWallet.on_proxy_changed
        
        
        class MockTransport:
       @@ -290,13 +298,9 @@ class SquareGraph(NamedTuple):
            def all_lnworkers(self) -> Iterable[MockLNWallet]:
                return self.w_a, self.w_b, self.w_c, self.w_d
        
       -    async def stop_and_cleanup(self):
       -        async with TaskGroup() as group:
       -            for lnworker in self.all_lnworkers():
       -                await group.spawn(lnworker.stop())
       -
        
        class PaymentDone(Exception): pass
       +class TestSuccess(Exception): pass
        
        
        class TestPeer(ElectrumTestCase):
       @@ -837,6 +841,50 @@ class TestPeer(ElectrumTestCase):
                self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3})
        
            @needs_test_with_all_chacha20_implementations
       +    def test_fail_pending_htlcs_on_shutdown(self):
       +        """Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
       +        Dave shuts down (stops wallet).
       +        We test if Dave fails the pending HTLCs during shutdown.
       +        """
       +        graph = self.prepare_chans_and_peers_in_square()
       +        self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
       +        self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
       +        amount_to_pay = 600_000_000_000
       +        peers = graph.all_peers()
       +        graph.w_d.MPP_EXPIRY = 120
       +        graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3
       +        async def pay():
       +            graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
       +            graph.w_b.enable_htlc_forwarding.clear()  # Bob will hold forwarded HTLCs
       +            assert graph.w_a.network.channel_db is not None
       +            lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
       +            try:
       +                async with timeout_after(0.5):
       +                    result, log = await graph.w_a.pay_invoice(pay_req, attempts=1)
       +            except TaskTimeout:
       +                # by now Dave hopefully received some HTLCs:
       +                self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0)
       +                self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0)
       +            else:
       +                self.fail(f"pay_invoice finished but was not supposed to. result={result}")
       +            await graph.w_d.stop()
       +            # Dave is supposed to have failed the pending incomplete MPP HTLCs
       +            self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL)))
       +            self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE)))
       +            raise TestSuccess()
       +
       +        async def f():
       +            async with TaskGroup() as group:
       +                for peer in peers:
       +                    await group.spawn(peer._message_loop())
       +                    await group.spawn(peer.htlc_switch())
       +                await asyncio.sleep(0.2)
       +                await group.spawn(pay())
       +
       +        with self.assertRaises(TestSuccess):
       +            run(f())
       +
       +    @needs_test_with_all_chacha20_implementations
            def test_close(self):
                alice_channel, bob_channel = create_test_channels()
                p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)