URI: 
       tlnrouter: run Dijkstra in reverse direction - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 2364de930b17a04e7191846956231feb2cb725ca
   DIR parent 7edbd5682ae2de059e7608b00ff7b6184c122e10
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Sat, 20 Oct 2018 17:52:29 +0200
       
       lnrouter: run Dijkstra in reverse direction
       
       Diffstat:
         M electrum/lnrouter.py                |      73 +++++++++++++++++--------------
         M electrum/lnworker.py                |       4 ++--
         M electrum/tests/test_lnrouter.py     |       5 +++++
       
       3 files changed, 47 insertions(+), 35 deletions(-)
       ---
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -565,59 +565,64 @@ class LNPathFinder(PrintError):
                self.blacklist = set()
        
            def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
       -                   payment_amt_msat: int, ignore_cltv=False) -> float:
       -        """Heuristic cost of going through a channel."""
       +                   payment_amt_msat: int, ignore_costs=False) -> Tuple[float, int]:
       +        """Heuristic cost of going through a channel.
       +        Returns (heuristic_cost, fee_for_edge_msat).
       +        """
                channel_info = self.channel_db.get_channel_info(short_channel_id)  # type: ChannelInfo
                if channel_info is None:
       -            return float('inf')
       +            return float('inf'), 0
        
                channel_policy = channel_info.get_policy_for_node(start_node)
       -        if channel_policy is None: return float('inf')
       -        if channel_policy.disabled: return float('inf')
       +        if channel_policy is None: return float('inf'), 0
       +        if channel_policy.disabled: return float('inf'), 0
                route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
                if payment_amt_msat < channel_policy.htlc_minimum_msat:
       -            return float('inf')  # payment amount too little
       +            return float('inf'), 0  # payment amount too little
                if channel_info.capacity_sat is not None and \
                        payment_amt_msat // 1000 > channel_info.capacity_sat:
       -            return float('inf')  # payment amount too large
       +            return float('inf'), 0  # payment amount too large
                if channel_policy.htlc_maximum_msat is not None and \
                        payment_amt_msat > channel_policy.htlc_maximum_msat:
       -            return float('inf')  # payment amount too large
       +            return float('inf'), 0  # payment amount too large
                if not route_edge.is_sane_to_use(payment_amt_msat):
       -            return float('inf')  # thanks but no thanks
       -        fee_msat = route_edge.fee_for_edge(payment_amt_msat)
       +            return float('inf'), 0  # thanks but no thanks
       +        fee_msat = route_edge.fee_for_edge(payment_amt_msat) if not ignore_costs else 0
                # TODO revise
                # paying 10 more satoshis ~ waiting one more block
                fee_cost = fee_msat / 1000 / 10
       -        cltv_cost = route_edge.cltv_expiry_delta if not ignore_cltv else 0
       -        return cltv_cost + fee_cost + 1
       +        cltv_cost = route_edge.cltv_expiry_delta if not ignore_costs else 0
       +        return cltv_cost + fee_cost + 1, fee_msat
        
            @profiler
       -    def find_path_for_payment(self, from_node_id: bytes, to_node_id: bytes,
       -                              amount_msat: int, my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
       -        """Return a path between from_node_id and to_node_id.
       +    def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
       +                              invoice_amount_msat: int,
       +                              my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
       +        """Return a path from nodeA to nodeB.
        
                Returns a list of (node_id, short_channel_id) representing a path.
                To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
                i.e. an element reads as, "to get to node_id, travel through short_channel_id"
                """
       -        assert type(amount_msat) is int
       +        assert type(invoice_amount_msat) is int
                if my_channels is None: my_channels = []
       -        unable_channels = set(map(lambda x: x.short_channel_id, filter(lambda x: not x.can_pay(amount_msat), my_channels)))
       +        unable_channels = set(map(lambda x: x.short_channel_id,
       +                                  filter(lambda x: not x.can_pay(invoice_amount_msat), my_channels)))
        
       -        # TODO find multiple paths??
                # FIXME paths cannot be longer than 21 edges (onion packet)...
        
                # run Dijkstra
       +        # The search is run in the REVERSE direction, from nodeB to nodeA,
       +        # to properly calculate compound routing fees.
                distance_from_start = defaultdict(lambda: float('inf'))
       -        distance_from_start[from_node_id] = 0
       +        distance_from_start[nodeB] = 0
                prev_node = {}
                nodes_to_explore = queue.PriorityQueue()
       -        nodes_to_explore.put((0, from_node_id))
       +        nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!
        
                while nodes_to_explore.qsize() > 0:
       -            dist_to_cur_node, cur_node = nodes_to_explore.get()
       -            if cur_node == to_node_id:
       +            dist_to_cur_node, amount_msat, cur_node = nodes_to_explore.get()
       +            if cur_node == nodeA:
                        break
                    if dist_to_cur_node != distance_from_start[cur_node]:
                        # queue.PriorityQueue does not implement decrease_priority,
       t@@ -628,27 +633,29 @@ class LNPathFinder(PrintError):
                        if edge_channel_id in self.blacklist or edge_channel_id in unable_channels:
                            continue
                        channel_info = self.channel_db.get_channel_info(edge_channel_id)
       -                node1, node2 = channel_info.node_id_1, channel_info.node_id_2
       -                neighbour = node2 if node1 == cur_node else node1
       -                ignore_cltv_delta_in_edge_cost = cur_node == from_node_id
       -                edge_cost = self._edge_cost(edge_channel_id, cur_node, neighbour, amount_msat,
       -                                            ignore_cltv=ignore_cltv_delta_in_edge_cost)
       +                neighbour = channel_info.node_id_2 if channel_info.node_id_1 == cur_node else channel_info.node_id_1
       +                ignore_costs = neighbour == nodeA  # no fees when using our own channel
       +                edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id,
       +                                                               start_node=neighbour,
       +                                                               end_node=cur_node,
       +                                                               payment_amt_msat=amount_msat,
       +                                                               ignore_costs=ignore_costs)
                        alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost
                        if alt_dist_to_neighbour < distance_from_start[neighbour]:
                            distance_from_start[neighbour] = alt_dist_to_neighbour
                            prev_node[neighbour] = cur_node, edge_channel_id
       -                    nodes_to_explore.put((alt_dist_to_neighbour, neighbour))
       +                    amount_to_forward_msat = amount_msat + fee_for_edge_msat
       +                    nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, neighbour))
                else:
                    return None  # no path found
        
       -        # backtrack from end to start
       -        cur_node = to_node_id
       +        # backtrack from search_end (nodeA) to search_start (nodeB)
       +        cur_node = nodeA
                path = []
       -        while cur_node != from_node_id:
       +        while cur_node != nodeB:
                    prev_node_id, edge_taken = prev_node[cur_node]
       -            path += [(cur_node, edge_taken)]
       +            path += [(prev_node_id, edge_taken)]
                    cur_node = prev_node_id
       -        path.reverse()
                return path
        
            def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]:
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -260,14 +260,14 @@ class LNWorker(PrintError):
                        f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
                route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
                node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
       -        peer = self.peers[node_id]
                with self.lock:
                    channels = list(self.channels.values())
                for chan in channels:
                    if chan.short_channel_id == short_channel_id:
                        break
                else:
       -            raise Exception("ChannelDB returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
       +            raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
       +        peer = self.peers[node_id]
                coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
                return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
        
   DIR diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py
       t@@ -93,6 +93,11 @@ class Test_LNRouter(TestCaseForTestnet):
                cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(99999999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
                cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(150), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
                self.assertNotEqual(None, path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000))
       +        self.assertEqual([(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', b'\x00\x00\x00\x00\x00\x00\x00\x03'),
       +                          (b'\x02cccccccccccccccccccccccccccccccc', b'\x00\x00\x00\x00\x00\x00\x00\x01'),
       +                          (b'\x02dddddddddddddddddddddddddddddddd', b'\x00\x00\x00\x00\x00\x00\x00\x04'),
       +                          (b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x05')],
       +                         path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000))