URI: 
       tlnrouter: add PathEdge/LNPaymentPath for (node_id, scid) - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 63b18dc30f54d7a42df923475f79cb338a706c2a
   DIR parent 04d018cd0f5fafb8abd4b2771fcc1c6206f80279
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Wed,  6 May 2020 10:51:45 +0200
       
       lnrouter: add PathEdge/LNPaymentPath for (node_id, scid)
       
       Diffstat:
         M electrum/lnrouter.py                |      39 +++++++++++++++++--------------
         M electrum/tests/test_lnrouter.py     |      25 ++++++-------------------
       
       2 files changed, 27 insertions(+), 37 deletions(-)
       ---
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -50,11 +50,15 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
                   + (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)
        
        
       -@attr.s
       -class RouteEdge:
       +@attr.s(slots=True)
       +class PathEdge:
            """if you travel through short_channel_id, you will reach node_id"""
            node_id = attr.ib(type=bytes, kw_only=True)
            short_channel_id = attr.ib(type=ShortChannelID, kw_only=True)
       +
       +
       +@attr.s
       +class RouteEdge(PathEdge):
            fee_base_msat = attr.ib(type=int, kw_only=True)
            fee_proportional_millionths = attr.ib(type=int, kw_only=True)
            cltv_expiry_delta = attr.ib(type=int, kw_only=True)
       t@@ -93,6 +97,7 @@ class RouteEdge:
                return bool(features & LnFeatures.VAR_ONION_REQ or features & LnFeatures.VAR_ONION_OPT)
        
        
       +LNPaymentPath = Sequence[PathEdge]
        LNPaymentRoute = Sequence[RouteEdge]
        
        
       t@@ -186,8 +191,8 @@ class LNPathFinder(Logger):
        
            def get_distances(self, nodeA: bytes, nodeB: bytes,
                              invoice_amount_msat: int, *,
       -                      my_channels: Dict[ShortChannelID, 'Channel'] = None) \
       -                      -> Optional[Sequence[Tuple[bytes, bytes]]]:
       +                      my_channels: Dict[ShortChannelID, 'Channel'] = None
       +                      ) -> Dict[bytes, PathEdge]:
                # note: we don't lock self.channel_db, so while the path finding runs,
                #       the underlying graph could potentially change... (not good but maybe ~OK?)
        
       t@@ -196,7 +201,7 @@ class LNPathFinder(Logger):
                # to properly calculate compound routing fees.
                distance_from_start = defaultdict(lambda: float('inf'))
                distance_from_start[nodeB] = 0
       -        prev_node = {}
       +        prev_node = {}  # type: Dict[bytes, PathEdge]
                nodes_to_explore = queue.PriorityQueue()
                nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!
        
       t@@ -237,7 +242,8 @@ class LNPathFinder(Logger):
                        alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
                        if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                            distance_from_start[edge_startnode] = alt_dist_to_neighbour
       -                    prev_node[edge_startnode] = edge_endnode, edge_channel_id
       +                    prev_node[edge_startnode] = PathEdge(node_id=edge_endnode,
       +                                                         short_channel_id=ShortChannelID(edge_channel_id))
                            amount_to_forward_msat = amount_msat + fee_for_edge_msat
                            nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
        
       t@@ -247,13 +253,8 @@ class LNPathFinder(Logger):
            def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
                                      invoice_amount_msat: int, *,
                                      my_channels: Dict[ShortChannelID, 'Channel'] = None) \
       -            -> Optional[Sequence[Tuple[bytes, bytes]]]:
       -        """Return a path from nodeA to nodeB.
       -
       -        Returns a list of (node_id, short_channel_id) representing a path.
       -        To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
       -        i.e. an element reads as, "to get to node_id, travel through short_channel_id"
       -        """
       +            -> Optional[LNPaymentPath]:
       +        """Return a path from nodeA to nodeB."""
                assert type(nodeA) is bytes
                assert type(nodeB) is bytes
                assert type(invoice_amount_msat) is int
       t@@ -270,19 +271,21 @@ class LNPathFinder(Logger):
                edge_startnode = nodeA
                path = []
                while edge_startnode != nodeB:
       -            edge_endnode, edge_taken = prev_node[edge_startnode]
       -            path += [(edge_endnode, edge_taken)]
       -            edge_startnode = edge_endnode
       +            edge = prev_node[edge_startnode]
       +            path += [edge]
       +            edge_startnode = edge.node_id
                return path
        
       -    def create_route_from_path(self, path, from_node_id: bytes, *,
       +    def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *,
                                       my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
                assert isinstance(from_node_id, bytes)
                if path is None:
                    raise Exception('cannot create route from None path')
                route = []
                prev_node_id = from_node_id
       -        for node_id, short_channel_id in path:
       +        for edge in path:
       +            node_id = edge.node_id
       +            short_channel_id = edge.short_channel_id
                    channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
                                                                         node_id=prev_node_id,
                                                                         my_channels=my_channels)
   DIR diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py
       t@@ -10,6 +10,7 @@ from electrum.lnonion import (OnionHopsDataSingle, new_onion_packet,
        from electrum import bitcoin, lnrouter
        from electrum.constants import BitcoinTestnet
        from electrum.simple_config import SimpleConfig
       +from electrum.lnrouter import PathEdge
        
        from . import TestCaseForTestnet
        from .test_bitcoin import needs_test_with_all_chacha20_implementations
       t@@ -17,20 +18,6 @@ from .test_bitcoin import needs_test_with_all_chacha20_implementations
        
        class Test_LNRouter(TestCaseForTestnet):
        
       -    #@staticmethod
       -    #def parse_witness_list(witness_bytes):
       -    #    amount_witnesses = witness_bytes[0]
       -    #    witness_bytes = witness_bytes[1:]
       -    #    res = []
       -    #    for i in range(amount_witnesses):
       -    #        witness_length = witness_bytes[0]
       -    #        this_witness = witness_bytes[1:witness_length+1]
       -    #        assert len(this_witness) == witness_length
       -    #        witness_bytes = witness_bytes[witness_length+1:]
       -    #        res += [bytes(this_witness)]
       -    #    assert witness_bytes == b"", witness_bytes
       -    #    return res
       -
            def setUp(self):
                super().setUp()
                self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
       t@@ -97,13 +84,13 @@ class Test_LNRouter(TestCaseForTestnet):
                cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
                cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
                path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000)
       -        self.assertEqual([(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', b'\x00\x00\x00\x00\x00\x00\x00\x03'),
       -                          (b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x02'),
       +        self.assertEqual([PathEdge(node_id=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')),
       +                          PathEdge(node_id=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')),
                                 ], path)
       -        start_node = b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb'
       +        start_node = b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
                route = path_finder.create_route_from_path(path, start_node)
       -        self.assertEqual(route[0].node_id, start_node)
       -        self.assertEqual(route[0].short_channel_id, bfh('0000000000000003'))
       +        self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id)
       +        self.assertEqual(bfh('0000000000000003'),                 route[0].short_channel_id)
        
                # need to duplicate tear_down here, as we also need to wait for the sql thread to stop
                self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)