URI: 
       ttest_lnpeer: add test for mpp_timeout - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 549b9a95df6fedbf0880df139a99f419931150e1
   DIR parent d4de25a8cde35ee47a7594194eef794a1b4f7b2b
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Wed, 10 Mar 2021 17:09:07 +0100
       
       ttest_lnpeer: add test for mpp_timeout
       
       Diffstat:
         M electrum/lnpeer.py                  |       6 ++++--
         M electrum/lnworker.py                |       6 ++++--
         M electrum/tests/test_lnpeer.py       |      46 +++++++++++++++++++++----------
       
       3 files changed, 39 insertions(+), 19 deletions(-)
       ---
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -1820,7 +1820,7 @@ class Peer(Logger):
                                error_reason = e
                            else:
                                try:
       -                            preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(
       +                            preimage, fw_info, error_bytes = await self.process_unfulfilled_htlc(
                                        chan=chan,
                                        htlc=htlc,
                                        forwarding_info=forwarding_info,
       t@@ -1849,7 +1849,7 @@ class Peer(Logger):
                        for htlc_id in done:
                            unfulfilled.pop(htlc_id)
        
       -    def process_unfulfilled_htlc(
       +    async def process_unfulfilled_htlc(
                    self, *,
                    chan: Channel,
                    htlc: UpdateAddHtlc,
       t@@ -1885,6 +1885,7 @@ class Peer(Logger):
                                    processed_onion=trampoline_onion,
                                    is_trampoline=True)
                            else:
       +                        await self.lnworker.enable_htlc_forwarding.wait()
                                self.maybe_forward_trampoline(
                                    chan=chan,
                                    htlc=htlc,
       t@@ -1899,6 +1900,7 @@ class Peer(Logger):
                                raise error_reason
        
                elif not forwarding_info:
       +            await self.lnworker.enable_htlc_forwarding.wait()
                    next_chan_id, next_htlc_id = self.maybe_forward_htlc(
                        htlc=htlc,
                        processed_onion=processed_onion)
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -91,7 +91,6 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID] # status that are persisted
        
        
        NUM_PEERS_TARGET = 4
       -MPP_EXPIRY = 120
        
        
        FALLBACK_NODE_LIST_TESTNET = (
       t@@ -575,6 +574,7 @@ class LNGossip(LNWorker):
        class LNWallet(LNWorker):
        
            lnwatcher: Optional['LNWalletWatcher']
       +    MPP_EXPIRY = 120
        
            def __init__(self, wallet: 'Abstract_Wallet', xprv):
                self.wallet = wallet
       t@@ -592,6 +592,8 @@ class LNWallet(LNWorker):
                # used in tests
                self.enable_htlc_settle = asyncio.Event()
                self.enable_htlc_settle.set()
       +        self.enable_htlc_forwarding = asyncio.Event()
       +        self.enable_htlc_forwarding.set()
        
                # note: accessing channels (besides simple lookup) needs self.lock!
                self._channels = {}  # type: Dict[bytes, Channel]
       t@@ -1633,7 +1635,7 @@ class LNWallet(LNWorker):
                if not is_accepted and not is_expired:
                    total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
                    first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
       -            if time.time() - first_timestamp > MPP_EXPIRY:
       +            if time.time() - first_timestamp > self.MPP_EXPIRY:
                        is_expired = True
                    elif total == expected_msat:
                        is_accepted = True
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -113,6 +113,7 @@ class MockWallet:
        
        
        class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
       +    MPP_EXPIRY = 1
            def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
                self.name = name
                Logger.__init__(self)
       t@@ -136,6 +137,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                # used in tests
                self.enable_htlc_settle = asyncio.Event()
                self.enable_htlc_settle.set()
       +        self.enable_htlc_forwarding = asyncio.Event()
       +        self.enable_htlc_forwarding.set()
                self.received_htlcs = dict()
                self.sent_htlcs = defaultdict(asyncio.Queue)
                self.sent_htlcs_routes = dict()
       t@@ -747,7 +750,7 @@ class TestPeer(ElectrumTestCase):
                with self.assertRaises(PaymentDone):
                    run(f())
        
       -    def _test_multipart_payment(self, graph, *, attempts):
       +    async def _run_mpp(self, graph, *, attempts):
                self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
                self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
                amount_to_pay = 600_000_000_000
       t@@ -761,32 +764,45 @@ class TestPeer(ElectrumTestCase):
                        raise PaymentDone()
                    else:
                        raise NoPathFound()
       -        async def f():
       -            async with TaskGroup() as group:
       -                for peer in peers:
       -                    await group.spawn(peer._message_loop())
       -                    await group.spawn(peer.htlc_switch())
       -                await asyncio.sleep(0.2)
       -                await group.spawn(pay())
       -        self.assertFalse(graph.w_d.features.supports(LnFeatures.BASIC_MPP_OPT))
       -        with self.assertRaises(NoPathFound):
       -            run(f())
       +        async with TaskGroup() as group:
       +            for peer in peers:
       +                await group.spawn(peer._message_loop())
       +                await group.spawn(peer.htlc_switch())
       +            await asyncio.sleep(0.2)
       +            await group.spawn(pay())
       +
       +    @needs_test_with_all_chacha20_implementations
       +    def test_multipart_payment_with_timeout(self):
       +        graph = self.prepare_chans_and_peers_in_square()
                graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
       +        graph.w_b.enable_htlc_forwarding.clear()
       +        with self.assertRaises(NoPathFound):
       +           run(self._run_mpp(graph, attempts=1))
       +        graph.w_b.enable_htlc_forwarding.set()
                with self.assertRaises(PaymentDone):
       -            run(f())
       +           run(self._run_mpp(graph, attempts=1))
        
            @needs_test_with_all_chacha20_implementations
            def test_multipart_payment(self):
                graph = self.prepare_chans_and_peers_in_square()
       -        self._test_multipart_payment(graph, attempts=1)
       +        self.assertFalse(graph.w_d.features.supports(LnFeatures.BASIC_MPP_OPT))
       +        with self.assertRaises(NoPathFound):
       +           run(self._run_mpp(graph, attempts=1))
       +        graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
       +        with self.assertRaises(PaymentDone):
       +           run(self._run_mpp(graph, attempts=1))
        
            @needs_test_with_all_chacha20_implementations
            def test_multipart_payment_with_trampoline(self):
                graph = self.prepare_chans_and_peers_in_square()
       +        graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
                graph.w_a.network.channel_db.stop()
                graph.w_a.network.channel_db = None
       -        # Note: first attempt will fail with insufficient trampoline fee
       -        self._test_multipart_payment(graph, attempts=3)
       +        # Note: single attempt will fail with insufficient trampoline fee
       +        with self.assertRaises(NoPathFound):
       +           run(self._run_mpp(graph, attempts=1))
       +        with self.assertRaises(PaymentDone):
       +           run(self._run_mpp(graph, attempts=3))
        
            @needs_test_with_all_chacha20_implementations
            def test_close(self):