tlnrouter: add PathEdge/LNPaymentPath for (node_id, scid) - electrum - Electrum Bitcoin wallet HTML git clone https://git.parazyd.org/electrum DIR Log DIR Files DIR Refs DIR Submodules --- DIR commit 63b18dc30f54d7a42df923475f79cb338a706c2a DIR parent 04d018cd0f5fafb8abd4b2771fcc1c6206f80279 HTML Author: SomberNight <somber.night@protonmail.com> Date: Wed, 6 May 2020 10:51:45 +0200 lnrouter: add PathEdge/LNPaymentPath for (node_id, scid) Diffstat: M electrum/lnrouter.py | 39 +++++++++++++++++-------------- M electrum/tests/test_lnrouter.py | 25 ++++++------------------- 2 files changed, 27 insertions(+), 37 deletions(-) --- DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py t@@ -50,11 +50,15 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor + (forwarded_amount_msat * fee_proportional_millionths // 1_000_000) -@attr.s -class RouteEdge: +@attr.s(slots=True) +class PathEdge: """if you travel through short_channel_id, you will reach node_id""" node_id = attr.ib(type=bytes, kw_only=True) short_channel_id = attr.ib(type=ShortChannelID, kw_only=True) + + +@attr.s +class RouteEdge(PathEdge): fee_base_msat = attr.ib(type=int, kw_only=True) fee_proportional_millionths = attr.ib(type=int, kw_only=True) cltv_expiry_delta = attr.ib(type=int, kw_only=True) t@@ -93,6 +97,7 @@ class RouteEdge: return bool(features & LnFeatures.VAR_ONION_REQ or features & LnFeatures.VAR_ONION_OPT) +LNPaymentPath = Sequence[PathEdge] LNPaymentRoute = Sequence[RouteEdge] t@@ -186,8 +191,8 @@ class LNPathFinder(Logger): def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) \ - -> Optional[Sequence[Tuple[bytes, bytes]]]: + my_channels: Dict[ShortChannelID, 'Channel'] = None + ) -> Dict[bytes, PathEdge]: # note: we don't lock self.channel_db, so while the path finding runs, # the underlying graph could potentially change... (not good but maybe ~OK?) t@@ -196,7 +201,7 @@ class LNPathFinder(Logger): # to properly calculate compound routing fees. distance_from_start = defaultdict(lambda: float('inf')) distance_from_start[nodeB] = 0 - prev_node = {} + prev_node = {} # type: Dict[bytes, PathEdge] nodes_to_explore = queue.PriorityQueue() nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters! t@@ -237,7 +242,8 @@ class LNPathFinder(Logger): alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost if alt_dist_to_neighbour < distance_from_start[edge_startnode]: distance_from_start[edge_startnode] = alt_dist_to_neighbour - prev_node[edge_startnode] = edge_endnode, edge_channel_id + prev_node[edge_startnode] = PathEdge(node_id=edge_endnode, + short_channel_id=ShortChannelID(edge_channel_id)) amount_to_forward_msat = amount_msat + fee_for_edge_msat nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode)) t@@ -247,13 +253,8 @@ class LNPathFinder(Logger): def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, my_channels: Dict[ShortChannelID, 'Channel'] = None) \ - -> Optional[Sequence[Tuple[bytes, bytes]]]: - """Return a path from nodeA to nodeB. - - Returns a list of (node_id, short_channel_id) representing a path. - To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1]; - i.e. an element reads as, "to get to node_id, travel through short_channel_id" - """ + -> Optional[LNPaymentPath]: + """Return a path from nodeA to nodeB.""" assert type(nodeA) is bytes assert type(nodeB) is bytes assert type(invoice_amount_msat) is int t@@ -270,19 +271,21 @@ class LNPathFinder(Logger): edge_startnode = nodeA path = [] while edge_startnode != nodeB: - edge_endnode, edge_taken = prev_node[edge_startnode] - path += [(edge_endnode, edge_taken)] - edge_startnode = edge_endnode + edge = prev_node[edge_startnode] + path += [edge] + edge_startnode = edge.node_id return path - def create_route_from_path(self, path, from_node_id: bytes, *, + def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute: assert isinstance(from_node_id, bytes) if path is None: raise Exception('cannot create route from None path') route = [] prev_node_id = from_node_id - for node_id, short_channel_id in path: + for edge in path: + node_id = edge.node_id + short_channel_id = edge.short_channel_id channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id, node_id=prev_node_id, my_channels=my_channels) DIR diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py t@@ -10,6 +10,7 @@ from electrum.lnonion import (OnionHopsDataSingle, new_onion_packet, from electrum import bitcoin, lnrouter from electrum.constants import BitcoinTestnet from electrum.simple_config import SimpleConfig +from electrum.lnrouter import PathEdge from . import TestCaseForTestnet from .test_bitcoin import needs_test_with_all_chacha20_implementations t@@ -17,20 +18,6 @@ from .test_bitcoin import needs_test_with_all_chacha20_implementations class Test_LNRouter(TestCaseForTestnet): - #@staticmethod - #def parse_witness_list(witness_bytes): - # amount_witnesses = witness_bytes[0] - # witness_bytes = witness_bytes[1:] - # res = [] - # for i in range(amount_witnesses): - # witness_length = witness_bytes[0] - # this_witness = witness_bytes[1:witness_length+1] - # assert len(this_witness) == witness_length - # witness_bytes = witness_bytes[witness_length+1:] - # res += [bytes(this_witness)] - # assert witness_bytes == b"", witness_bytes - # return res - def setUp(self): super().setUp() self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop() t@@ -97,13 +84,13 @@ class Test_LNRouter(TestCaseForTestnet): cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000) - self.assertEqual([(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', b'\x00\x00\x00\x00\x00\x00\x00\x03'), - (b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x02'), + self.assertEqual([PathEdge(node_id=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')), + PathEdge(node_id=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')), ], path) - start_node = b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb' + start_node = b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa' route = path_finder.create_route_from_path(path, start_node) - self.assertEqual(route[0].node_id, start_node) - self.assertEqual(route[0].short_channel_id, bfh('0000000000000003')) + self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id) + self.assertEqual(bfh('0000000000000003'), route[0].short_channel_id) # need to duplicate tear_down here, as we also need to wait for the sql thread to stop self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)