URI: 
       ttest_lnpeer: add some multi-hop payment unit tests - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit cc4029c335dea48a93c3a78fa3b27262d34458d9
   DIR parent 7153e753d1829ccdf1f0ee6821082c1f13ba0f21
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Wed,  6 May 2020 11:00:58 +0200
       
       ttest_lnpeer: add some multi-hop payment unit tests
       
       Diffstat:
         M electrum/lnpeer.py                  |       4 +++-
         M electrum/lnworker.py                |       1 +
         M electrum/tests/test_lnpeer.py       |     184 ++++++++++++++++++++++++++++++-
       
       3 files changed, 186 insertions(+), 3 deletions(-)
       ---
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -1510,7 +1510,9 @@ class Peer(Logger):
                                self.logger.info(f"error processing onion packet: {e!r}")
                                error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
                            else:
       -                        if processed_onion.are_we_final:
       +                        if self.lnworker._fail_htlcs_with_temp_node_failure:
       +                            error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
       +                        elif processed_onion.are_we_final:
                                    preimage, error_reason = self.maybe_fulfill_htlc(
                                        chan=chan,
                                        htlc=htlc,
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -494,6 +494,7 @@ class LNWallet(LNWorker):
                # used in tests
                self.enable_htlc_settle = asyncio.Event()
                self.enable_htlc_settle.set()
       +        self._fail_htlcs_with_temp_node_failure = False
        
                # note: accessing channels (besides simple lookup) needs self.lock!
                self._channels = {}  # type: Dict[bytes, Channel]
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -8,7 +8,7 @@ import logging
        import concurrent
        from concurrent import futures
        import unittest
       -from typing import Iterable
       +from typing import Iterable, NamedTuple
        
        from aiorpcx import TaskGroup
        
       t@@ -24,12 +24,13 @@ from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
        from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
        from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
        from electrum.lnchannel import ChannelState, PeerState, Channel
       -from electrum.lnrouter import LNPathFinder
       +from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
        from electrum.channel_db import ChannelDB
        from electrum.lnworker import LNWallet, NoPathFound
        from electrum.lnmsg import encode_msg, decode_msg
        from electrum.logging import console_stderr_handler, Logger
        from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
       +from electrum.lnonion import OnionFailureCode
        
        from .test_lnchannel import create_test_channels
        from .test_bitcoin import needs_test_with_all_chacha20_implementations
       t@@ -117,6 +118,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                # used in tests
                self.enable_htlc_settle = asyncio.Event()
                self.enable_htlc_settle.set()
       +        self._fail_htlcs_with_temp_node_failure = False
        
            def get_invoice_status(self, key):
                pass
       t@@ -212,6 +214,37 @@ def transport_pair(k1, k2, name1, name2):
            return t1, t2
        
        
       +class DiamondGraph(NamedTuple):
       +    #        A
       +    #      /   \
       +    #     B     C
       +    #      \   /
       +    #        D
       +    w_a: MockLNWallet
       +    w_b: MockLNWallet
       +    w_c: MockLNWallet
       +    w_d: MockLNWallet
       +    peer_ab: Peer
       +    peer_ac: Peer
       +    peer_ba: Peer
       +    peer_bd: Peer
       +    peer_ca: Peer
       +    peer_cd: Peer
       +    peer_db: Peer
       +    peer_dc: Peer
       +    chan_ab: Channel
       +    chan_ac: Channel
       +    chan_ba: Channel
       +    chan_bd: Channel
       +    chan_ca: Channel
       +    chan_cd: Channel
       +    chan_db: Channel
       +    chan_dc: Channel
       +
       +    def all_peers(self) -> Iterable[Peer]:
       +        return self.peer_ab, self.peer_ac, self.peer_ba, self.peer_bd, self.peer_ca, self.peer_cd, self.peer_db, self.peer_dc
       +
       +
        class PaymentDone(Exception): pass
        
        
       t@@ -252,6 +285,77 @@ class TestPeer(ElectrumTestCase):
                p2.mark_open(bob_channel)
                return p1, p2, w1, w2, q1, q2
        
       +    def prepare_chans_and_peers_in_diamond(self) -> DiamondGraph:
       +        key_a, key_b, key_c, key_d = [keypair() for i in range(4)]
       +        chan_ab, chan_ba = create_test_channels(alice_name="alice", bob_name="bob", alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey)
       +        chan_ac, chan_ca = create_test_channels(alice_name="alice", bob_name="carol", alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey)
       +        chan_bd, chan_db = create_test_channels(alice_name="bob", bob_name="dave", alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey)
       +        chan_cd, chan_dc = create_test_channels(alice_name="carol", bob_name="dave", alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey)
       +        trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name)
       +        trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name)
       +        trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name)
       +        trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name)
       +        txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for i in range(4)]
       +        w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a)
       +        w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b)
       +        w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c)
       +        w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d)
       +        peer_ab = Peer(w_a, key_b.pubkey, trans_ab)
       +        peer_ac = Peer(w_a, key_c.pubkey, trans_ac)
       +        peer_ba = Peer(w_b, key_a.pubkey, trans_ba)
       +        peer_bd = Peer(w_b, key_d.pubkey, trans_bd)
       +        peer_ca = Peer(w_c, key_a.pubkey, trans_ca)
       +        peer_cd = Peer(w_c, key_d.pubkey, trans_cd)
       +        peer_db = Peer(w_d, key_b.pubkey, trans_db)
       +        peer_dc = Peer(w_d, key_c.pubkey, trans_dc)
       +        w_a._peers[peer_ab.pubkey] = peer_ab
       +        w_a._peers[peer_ac.pubkey] = peer_ac
       +        w_b._peers[peer_ba.pubkey] = peer_ba
       +        w_b._peers[peer_bd.pubkey] = peer_bd
       +        w_c._peers[peer_ca.pubkey] = peer_ca
       +        w_c._peers[peer_cd.pubkey] = peer_cd
       +        w_d._peers[peer_db.pubkey] = peer_db
       +        w_d._peers[peer_dc.pubkey] = peer_dc
       +
       +        w_b.network.config.set_key('lightning_forward_payments', True)
       +        w_c.network.config.set_key('lightning_forward_payments', True)
       +
       +        # mark_open won't work if state is already OPEN.
       +        # so set it to FUNDED
       +        for chan in [chan_ab, chan_ac, chan_ba, chan_bd, chan_ca, chan_cd, chan_db, chan_dc]:
       +            chan._state = ChannelState.FUNDED
       +        # this populates the channel graph:
       +        peer_ab.mark_open(chan_ab)
       +        peer_ac.mark_open(chan_ac)
       +        peer_ba.mark_open(chan_ba)
       +        peer_bd.mark_open(chan_bd)
       +        peer_ca.mark_open(chan_ca)
       +        peer_cd.mark_open(chan_cd)
       +        peer_db.mark_open(chan_db)
       +        peer_dc.mark_open(chan_dc)
       +        return DiamondGraph(
       +            w_a=w_a,
       +            w_b=w_b,
       +            w_c=w_c,
       +            w_d=w_d,
       +            peer_ab=peer_ab,
       +            peer_ac=peer_ac,
       +            peer_ba=peer_ba,
       +            peer_bd=peer_bd,
       +            peer_ca=peer_ca,
       +            peer_cd=peer_cd,
       +            peer_db=peer_db,
       +            peer_dc=peer_dc,
       +            chan_ab=chan_ab,
       +            chan_ac=chan_ac,
       +            chan_ba=chan_ba,
       +            chan_bd=chan_bd,
       +            chan_ca=chan_ca,
       +            chan_cd=chan_cd,
       +            chan_db=chan_db,
       +            chan_dc=chan_dc,
       +        )
       +
            @staticmethod
            async def prepare_invoice(
                    w2: MockLNWallet,  # receiver
       t@@ -383,6 +487,82 @@ class TestPeer(ElectrumTestCase):
                self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, alice_channel.balance(HTLCOwner.REMOTE))
        
            @needs_test_with_all_chacha20_implementations
       +    def test_payment_multihop(self):
       +        graph = self.prepare_chans_and_peers_in_diamond()
       +        peers = graph.all_peers()
       +        async def pay(pay_req):
       +            result, log = await graph.w_a._pay(pay_req)
       +            self.assertTrue(result)
       +            raise PaymentDone()
       +        async def f():
       +            async with TaskGroup() as group:
       +                for peer in peers:
       +                    await group.spawn(peer._message_loop())
       +                    await group.spawn(peer.htlc_switch())
       +                await asyncio.sleep(0.2)
       +                pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
       +                await group.spawn(pay(pay_req))
       +        with self.assertRaises(PaymentDone):
       +            run(f())
       +
       +    @needs_test_with_all_chacha20_implementations
       +    def test_payment_multihop_with_preselected_path(self):
       +        graph = self.prepare_chans_and_peers_in_diamond()
       +        peers = graph.all_peers()
       +        async def pay(pay_req):
       +            with self.subTest(msg="bad path: edges do not chain together"):
       +                path = [PathEdge(node_id=graph.w_c.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id),
       +                        PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)]
       +                result, log = await graph.w_a._pay(pay_req, full_path=path)
       +                self.assertFalse(result)
       +                self.assertTrue(isinstance(log[0].exception, LNPathInconsistent))
       +            with self.subTest(msg="bad path: last node id differs from invoice pubkey"):
       +                path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id)]
       +                result, log = await graph.w_a._pay(pay_req, full_path=path)
       +                self.assertFalse(result)
       +                self.assertTrue(isinstance(log[0].exception, LNPathInconsistent))
       +            with self.subTest(msg="good path"):
       +                path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id),
       +                        PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)]
       +                result, log = await graph.w_a._pay(pay_req, full_path=path)
       +                self.assertTrue(result)
       +                self.assertEqual([edge.short_channel_id for edge in path],
       +                                 [edge.short_channel_id for edge in log[0].route])
       +            raise PaymentDone()
       +        async def f():
       +            async with TaskGroup() as group:
       +                for peer in peers:
       +                    await group.spawn(peer._message_loop())
       +                    await group.spawn(peer.htlc_switch())
       +                await asyncio.sleep(0.2)
       +                pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
       +                await group.spawn(pay(pay_req))
       +        with self.assertRaises(PaymentDone):
       +            run(f())
       +
       +    @needs_test_with_all_chacha20_implementations
       +    def test_payment_multihop_temp_node_failure(self):
       +        graph = self.prepare_chans_and_peers_in_diamond()
       +        graph.w_b._fail_htlcs_with_temp_node_failure = True
       +        graph.w_c._fail_htlcs_with_temp_node_failure = True
       +        peers = graph.all_peers()
       +        async def pay(pay_req):
       +            result, log = await graph.w_a._pay(pay_req)
       +            self.assertFalse(result)
       +            self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_details.failure_msg.code)
       +            raise PaymentDone()
       +        async def f():
       +            async with TaskGroup() as group:
       +                for peer in peers:
       +                    await group.spawn(peer._message_loop())
       +                    await group.spawn(peer.htlc_switch())
       +                await asyncio.sleep(0.2)
       +                pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
       +                await group.spawn(pay(pay_req))
       +        with self.assertRaises(PaymentDone):
       +            run(f())
       +
       +    @needs_test_with_all_chacha20_implementations
            def test_close(self):
                alice_channel, bob_channel = create_test_channels()
                p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)