URI: 
       tpass blacklist to lnrouter.find_route, so that lnrouter is stateless (see #6778) - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit ad91257729dc67cc8813d18ee8f1464c5571d72e
   DIR parent 9d7a3174042ffcd35d20833ba844c28123e30f0d
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Mon, 11 Jan 2021 15:19:50 +0100
       
       pass blacklist to lnrouter.find_route, so that lnrouter is stateless (see #6778)
       
       Diffstat:
         M electrum/lnrouter.py                |      32 ++++++++++---------------------
         M electrum/lnutil.py                  |      16 +++++++++++++++-
         M electrum/lnworker.py                |      11 +++++++----
         M electrum/network.py                 |       3 ++-
         M electrum/tests/test_lnpeer.py       |       2 ++
       
       5 files changed, 36 insertions(+), 28 deletions(-)
       ---
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -135,24 +135,12 @@ def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
            return False
        
        
       -BLACKLIST_DURATION = 3600
        
        class LNPathFinder(Logger):
        
            def __init__(self, channel_db: ChannelDB):
                Logger.__init__(self)
                self.channel_db = channel_db
       -        self.blacklist = dict() # short_chan_id -> timestamp
       -
       -    def add_to_blacklist(self, short_channel_id: ShortChannelID):
       -        self.logger.info(f'blacklisting channel {short_channel_id}')
       -        now = int(time.time())
       -        self.blacklist[short_channel_id] = now
       -
       -    def is_blacklisted(self, short_channel_id: ShortChannelID) -> bool:
       -        now = int(time.time())
       -        t = self.blacklist.get(short_channel_id, 0)
       -        return now - t < BLACKLIST_DURATION
        
            def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
                           payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
       t@@ -200,10 +188,9 @@ class LNPathFinder(Logger):
                overall_cost = base_cost + fee_msat + cltv_cost
                return overall_cost, fee_msat
        
       -    def get_distances(self, nodeA: bytes, nodeB: bytes,
       -                      invoice_amount_msat: int, *,
       -                      my_channels: Dict[ShortChannelID, 'Channel'] = None
       -                      ) -> Dict[bytes, PathEdge]:
       +    def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
       +                      my_channels: Dict[ShortChannelID, 'Channel'] = None,
       +                      blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]:
                # note: we don't lock self.channel_db, so while the path finding runs,
                #       the underlying graph could potentially change... (not good but maybe ~OK?)
        
       t@@ -216,7 +203,6 @@ class LNPathFinder(Logger):
                nodes_to_explore = queue.PriorityQueue()
                nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!
        
       -
                # main loop of search
                while nodes_to_explore.qsize() > 0:
                    dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
       t@@ -229,7 +215,7 @@ class LNPathFinder(Logger):
                        continue
                    for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
                        assert isinstance(edge_channel_id, bytes)
       -                if self.is_blacklisted(edge_channel_id):
       +                if blacklist and edge_channel_id in blacklist:
                            continue
                        channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
                        edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
       t@@ -263,7 +249,8 @@ class LNPathFinder(Logger):
            @profiler
            def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
                                      invoice_amount_msat: int, *,
       -                              my_channels: Dict[ShortChannelID, 'Channel'] = None) \
       +                              my_channels: Dict[ShortChannelID, 'Channel'] = None,
       +                              blacklist: Set[ShortChannelID] = None) \
                    -> Optional[LNPaymentPath]:
                """Return a path from nodeA to nodeB."""
                assert type(nodeA) is bytes
       t@@ -272,7 +259,7 @@ class LNPathFinder(Logger):
                if my_channels is None:
                    my_channels = {}
        
       -        prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels)
       +        prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
        
                if nodeA not in prev_node:
                    return None  # no path found
       t@@ -312,8 +299,9 @@ class LNPathFinder(Logger):
                return route
        
            def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
       -                   path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[LNPaymentRoute]:
       +                   path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
       +                   blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]:
                if not path:
       -            path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels)
       +            path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
                if path:
                    return self.create_route_from_path(path, nodeA, my_channels=my_channels)
   DIR diff --git a/electrum/lnutil.py b/electrum/lnutil.py
       t@@ -8,7 +8,7 @@ import json
        from collections import namedtuple, defaultdict
        from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence
        import re
       -
       +import time
        import attr
        from aiorpcx import NetAddress
        
       t@@ -1313,3 +1313,17 @@ class OnionFailureCodeMetaFlag(IntFlag):
            NODE     = 0x2000
            UPDATE   = 0x1000
        
       +
       +class ChannelBlackList:
       +
       +    def __init__(self):
       +        self.blacklist = dict() # short_chan_id -> timestamp
       +
       +    def add(self, short_channel_id: ShortChannelID):
       +        now = int(time.time())
       +        self.blacklist[short_channel_id] = now
       +
       +    def get_current_list(self) -> Set[ShortChannelID]:
       +        BLACKLIST_DURATION = 3600
       +        now = int(time.time())
       +        return set(k for k, t in self.blacklist.items() if now - t < BLACKLIST_DURATION)
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -7,7 +7,7 @@ import os
        from decimal import Decimal
        import random
        import time
       -from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
       +from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
        import threading
        import socket
        import aiohttp
       t@@ -540,6 +540,7 @@ class LNGossip(LNWorker):
                    if categorized_chan_upds.good:
                        self.logger.debug(f'on_channel_update: {len(categorized_chan_upds.good)}/{len(chan_upds_chunk)}')
        
       +
        class LNWallet(LNWorker):
        
            lnwatcher: Optional['LNWalletWatcher']
       t@@ -1014,7 +1015,8 @@ class LNWallet(LNWorker):
                            except IndexError:
                                self.logger.info("payment destination reported error")
                            else:
       -                        self.network.path_finder.add_to_blacklist(short_chan_id)
       +                        self.logger.info(f'blacklisting channel {short_channel_id}')
       +                        self.network.channel_blacklist.add(short_chan_id)
                    else:
                        # probably got "update_fail_malformed_htlc". well... who to penalise now?
                        assert payment_attempt.failure_message is not None
       t@@ -1127,6 +1129,7 @@ class LNWallet(LNWorker):
                channels = list(self.channels.values())
                scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
                                       if chan.short_channel_id is not None}
       +        blacklist = self.network.channel_blacklist.get_current_list()
                for private_route in r_tags:
                    if len(private_route) == 0:
                        continue
       t@@ -1144,7 +1147,7 @@ class LNWallet(LNWorker):
                    try:
                        route = self.network.path_finder.find_route(
                            self.node_keypair.pubkey, border_node_pubkey, amount_msat,
       -                    path=path, my_channels=scid_to_my_channels)
       +                    path=path, my_channels=scid_to_my_channels, blacklist=blacklist)
                    except NoChannelPolicy:
                        continue
                    if not route:
       t@@ -1186,7 +1189,7 @@ class LNWallet(LNWorker):
                if route is None:
                    route = self.network.path_finder.find_route(
                        self.node_keypair.pubkey, invoice_pubkey, amount_msat,
       -                path=full_path, my_channels=scid_to_my_channels)
       +                path=full_path, my_channels=scid_to_my_channels, blacklist=blacklist)
                    if not route:
                        raise NoPathFound()
                    if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
   DIR diff --git a/electrum/network.py b/electrum/network.py
       t@@ -45,7 +45,6 @@ from . import util
        from .util import (log_exceptions, ignore_exceptions,
                           bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
                           is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager)
       -
        from .bitcoin import COIN
        from . import constants
        from . import blockchain
       t@@ -60,6 +59,7 @@ from .version import PROTOCOL_VERSION
        from .simple_config import SimpleConfig
        from .i18n import _
        from .logging import get_logger, Logger
       +from .lnutil import ChannelBlackList
        
        if TYPE_CHECKING:
            from .channel_db import ChannelDB
       t@@ -335,6 +335,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
                self._has_ever_managed_to_connect_to_server = False
        
                # lightning network
       +        self.channel_blacklist = ChannelBlackList()
                self.channel_db = None  # type: Optional[ChannelDB]
                self.lngossip = None  # type: Optional[LNGossip]
                self.local_watchtower = None  # type: Optional[WatchTower]
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -32,6 +32,7 @@ from electrum.lnmsg import encode_msg, decode_msg
        from electrum.logging import console_stderr_handler, Logger
        from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
        from electrum.lnonion import OnionFailureCode
       +from electrum.lnutil import ChannelBlackList
        
        from .test_lnchannel import create_test_channels
        from .test_bitcoin import needs_test_with_all_chacha20_implementations
       t@@ -62,6 +63,7 @@ class MockNetwork:
                self.path_finder = LNPathFinder(self.channel_db)
                self.tx_queue = tx_queue
                self._blockchain = MockBlockchain()
       +        self.channel_blacklist = ChannelBlackList()
        
            @property
            def callback_lock(self):