URI: 
       tChannelDB: avoid duplicate (host,port) entries in ChannelDB._addresses - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 2ec548dda3664834d4a50d4f323886130285257a
   DIR parent 9a803cd1d683b72a2246ed70d76f77562f8d9981
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Sat,  9 Jan 2021 19:56:05 +0100
       
       ChannelDB: avoid duplicate (host,port) entries in ChannelDB._addresses
       
       before:
       node_id -> set of (host, port, ts)
       after:
       node_id -> NetAddress -> timestamp
       
       Look at e.g. add_recent_peer; we only want to store
       tthe last connection time, not all of them.
       
       Diffstat:
         M electrum/channel_db.py              |      52 +++++++++++++++++--------------
         M electrum/lnutil.py                  |       8 ++++++--
       
       2 files changed, 34 insertions(+), 26 deletions(-)
       ---
   DIR diff --git a/electrum/channel_db.py b/electrum/channel_db.py
       t@@ -34,6 +34,7 @@ import asyncio
        import threading
        from enum import IntEnum
        
       +from aiorpcx import NetAddress
        
        from .sql_db import SqlDB, sql
        from . import constants, util
       t@@ -53,14 +54,6 @@ FLAG_DISABLE   = 1 << 1
        FLAG_DIRECTION = 1 << 0
        
        
       -class NodeAddress(NamedTuple):
       -    """Holds address information of Lightning nodes
       -    and how up to date this info is."""
       -    host: str
       -    port: int
       -    timestamp: int
       -
       -
        class ChannelInfo(NamedTuple):
            short_channel_id: ShortChannelID
            node1_id: bytes
       t@@ -295,8 +288,8 @@ class ChannelDB(SqlDB):
                self._channels = {}  # type: Dict[ShortChannelID, ChannelInfo]
                self._policies = {}  # type: Dict[Tuple[bytes, ShortChannelID], Policy]  # (node_id, scid) -> Policy
                self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
       -        # node_id -> (host, port, ts)
       -        self._addresses = defaultdict(set)  # type: Dict[bytes, Set[NodeAddress]]
       +        # node_id -> NetAddress -> timestamp
       +        self._addresses = defaultdict(dict)  # type: Dict[bytes, Dict[NetAddress, int]]
                self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
                self._recent_peers = []  # type: List[bytes]  # list of node_ids
                self._chans_with_0_policies = set()  # type: Set[ShortChannelID]
       t@@ -321,7 +314,7 @@ class ChannelDB(SqlDB):
                now = int(time.time())
                node_id = peer.pubkey
                with self.lock:
       -            self._addresses[node_id].add(NodeAddress(peer.host, peer.port, now))
       +            self._addresses[node_id][peer.net_addr()] = now
                    # list is ordered
                    if node_id in self._recent_peers:
                        self._recent_peers.remove(node_id)
       t@@ -336,12 +329,12 @@ class ChannelDB(SqlDB):
        
            def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
                """Returns latest address we successfully connected to, for given node."""
       -        r = self._addresses.get(node_id)
       -        if not r:
       +        addr_to_ts = self._addresses.get(node_id)
       +        if not addr_to_ts:
                    return None
       -        addr = sorted(list(r), key=lambda x: x.timestamp, reverse=True)[0]
       +        addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
                try:
       -            return LNPeerAddr(addr.host, addr.port, node_id)
       +            return LNPeerAddr(str(addr.host), addr.port, node_id)
                except ValueError:
                    return None
        
       t@@ -583,7 +576,8 @@ class ChannelDB(SqlDB):
                        self._db_save_node_info(node_id, msg_payload['raw'])
                    with self.lock:
                        for addr in node_addresses:
       -                    self._addresses[node_id].add(NodeAddress(addr.host, addr.port, 0))
       +                    net_addr = NetAddress(addr.host, addr.port)
       +                    self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
                    self._db_save_node_addresses(node_addresses)
        
                self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
       t@@ -634,8 +628,13 @@ class ChannelDB(SqlDB):
                # delete from database
                self._db_delete_channel(short_channel_id)
        
       -    def get_node_addresses(self, node_id):
       -        return self._addresses.get(node_id)
       +    def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
       +        """Returns list of (host, port, timestamp)."""
       +        addr_to_ts = self._addresses.get(node_id)
       +        if not addr_to_ts:
       +            return []
       +        return [(str(net_addr.host), net_addr.port, ts)
       +                for net_addr, ts in addr_to_ts.items()]
        
            @sql
            @profiler
       t@@ -643,17 +642,19 @@ class ChannelDB(SqlDB):
                if self.data_loaded.is_set():
                    return
                # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
       -        #       I believe lnmsg (and lightning.json) will need a rewrite anyway, so instead of tweaking
       -        #       load_data() here, that should be done. see #6006
                c = self.conn.cursor()
                c.execute("""SELECT * FROM address""")
                for x in c:
                    node_id, host, port, timestamp = x
       -            self._addresses[node_id].add(NodeAddress(str(host), int(port), int(timestamp or 0)))
       +            try:
       +                net_addr = NetAddress(host, port)
       +            except Exception:
       +                continue
       +            self._addresses[node_id][net_addr] = int(timestamp or 0)
                def newest_ts_for_node_id(node_id):
                    newest_ts = 0
       -            for addr in self._addresses[node_id]:
       -                newest_ts = max(newest_ts, addr.timestamp)
       +            for addr, ts in self._addresses[node_id].items():
       +                newest_ts = max(newest_ts, ts)
                    return newest_ts
                sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
                self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
       t@@ -791,7 +792,10 @@ class ChannelDB(SqlDB):
                        graph['nodes'].append(
                            nodeinfo._asdict(),
                        )
       -                graph['nodes'][-1]['addresses'] = [addr._asdict() for addr in self._addresses[pk]]
       +                graph['nodes'][-1]['addresses'] = [
       +                    {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
       +                    for addr, ts in self._addresses[pk].items()
       +                ]
        
                    # gather channels
                    for cid, channelinfo in self._channels.items():
   DIR diff --git a/electrum/lnutil.py b/electrum/lnutil.py
       t@@ -1106,6 +1106,7 @@ def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> byte
        
        
        class LNPeerAddr:
       +    # note: while not programmatically enforced, this class is meant to be *immutable*
        
            def __init__(self, host: str, port: int, pubkey: bytes):
                assert isinstance(host, str), repr(host)
       t@@ -1120,7 +1121,7 @@ class LNPeerAddr:
                self.host = host
                self.port = port
                self.pubkey = pubkey
       -        self._net_addr_str = str(net_addr)
       +        self._net_addr = net_addr
        
            def __str__(self):
                return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str())
       t@@ -1128,8 +1129,11 @@ class LNPeerAddr:
            def __repr__(self):
                return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>'
        
       +    def net_addr(self) -> NetAddress:
       +        return self._net_addr
       +
            def net_addr_str(self) -> str:
       -        return self._net_addr_str
       +        return str(self._net_addr)
        
            def __eq__(self, other):
                if not isinstance(other, LNPeerAddr):