URI: 
       tlnrouter: load data before finding path - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 34f22e6681f9360f743dd1cea1675845c43b8b55
   DIR parent dac686b11d1c2cfc4cd130899cfd9f80656fda28
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Thu, 21 Mar 2019 12:44:32 +0100
       
       lnrouter: load data before finding path
       
       Diffstat:
         M electrum/lnrouter.py                |      73 ++++++++++++++++---------------
         M electrum/lnworker.py                |       1 +
       
       2 files changed, 38 insertions(+), 36 deletions(-)
       ---
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -229,7 +229,9 @@ class ChannelDB(SqlDB):
        
            def _update_counts(self):
                self.num_channels = self.DBSession.query(ChannelInfo).count()
       +        self.num_policies = self.DBSession.query(Policy).count()
                self.num_nodes = self.DBSession.query(NodeInfo).count()
       +        self.print_error('update counts', self.num_channels, self.num_policies)
        
            @sql
            def add_recent_peer(self, peer: LNPeerAddr):
       t@@ -273,19 +275,6 @@ class ChannelDB(SqlDB):
                return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
        
            @sql
       -    def get_channel_info(self, channel_id: bytes):
       -        return self._chan_query_for_id(channel_id).one_or_none()
       -
       -    @sql
       -    def get_channels_for_node(self, node_id):
       -        """Returns the set of channels that have node_id as one of the endpoints."""
       -        condition = or_(
       -                 ChannelInfo.node1_id == node_id.hex(),
       -                 ChannelInfo.node2_id == node_id.hex())
       -        rows = self.DBSession.query(ChannelInfo).filter(condition).all()
       -        return [bytes.fromhex(x.short_channel_id) for x in rows]
       -
       -    @sql
            def missing_short_chan_ids(self) -> Set[int]:
                expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
                chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
       t@@ -296,7 +285,7 @@ class ChannelDB(SqlDB):
            @sql
            def add_verified_channel_info(self, short_id, capacity):
                # called from lnchannelverifier
       -        channel_info = self._chan_query_for_id(short_id).one_or_none()
       +        channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none()
                channel_info.trusted = True
                channel_info.capacity = capacity
                self.DBSession.commit()
       t@@ -372,7 +361,6 @@ class ChannelDB(SqlDB):
                    if p and p.timestamp >= new_policy.timestamp:
                        continue
                    new_policies[(short_channel_id, node_id)] = new_policy
       -        #self.print_error('on_channel_update: %d/%d'%(len(new_policies), len(msg_payloads)))
                # commit pending removals
                self.DBSession.commit()
                # add and commit new policies
       t@@ -380,7 +368,9 @@ class ChannelDB(SqlDB):
                    self.DBSession.add(new_policy)
                self.DBSession.commit()
                if new_policies:
       +            self.print_error('on_channel_update: %d/%d'%(len(new_policies), len(msg_payloads)))
                    self.print_error('last timestamp:', datetime.fromtimestamp(self._get_last_timestamp()).ctime())
       +            self._update_counts()
        
            @sql
            #@profiler
       t@@ -432,7 +422,7 @@ class ChannelDB(SqlDB):
                if not start_node_id or not short_channel_id: return None
                channel_info = self.get_channel_info(short_channel_id)
                if channel_info is not None:
       -            return self.get_policy_for_node(channel_info, start_node_id)
       +            return self.get_policy_for_node(short_channel_id, start_node_id)
                msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
                if not msg:
                    return None
       t@@ -446,12 +436,12 @@ class ChannelDB(SqlDB):
        
            @sql
            def remove_channel(self, short_channel_id):
       -        self._chan_query_for_id(short_channel_id).delete('evaluate')
       +        r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none()
       +        if not r:
       +            return
       +        self.DBSession.delete(r)
                self.DBSession.commit()
        
       -    def _chan_query_for_id(self, short_channel_id) -> Query:
       -        return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
       -
            def print_graph(self, full_ids=False):
                # used for debugging.
                # FIXME there is a race here - iterables could change size from another thread
       t@@ -489,22 +479,32 @@ class ChannelDB(SqlDB):
        
        
            @sql
       -    def get_policy_for_node(self, channel_info, node) -> Optional['Policy']:
       -        """
       -        raises when initiator/non-initiator both unequal node
       -        """
       -        if node.hex() not in (channel_info.node1_id, channel_info.node2_id):
       -            raise Exception("the given node is not a party in this channel")
       -        n1 = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node = channel_info.node1_id).one_or_none()
       -        if n1:
       -            return n1
       -        n2 = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node = channel_info.node2_id).one_or_none()
       -        return n2
       -
       -    @sql
            def get_node_addresses(self, node_info):
                return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
        
       +    @sql
       +    @profiler
       +    def load_data(self):
       +        r = self.DBSession.query(ChannelInfo).all()
       +        self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
       +        r = self.DBSession.query(Policy).filter_by().all()
       +        self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
       +        self._channels_for_node = defaultdict(set)
       +        for channel_info in self._channels.values():
       +            self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id))
       +            self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id))
       +        self.print_error('load data', len(self._channels), len(self._policies), len(self._channels_for_node))
       +
       +    def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
       +        return self._policies.get((node_id, short_channel_id))
       +
       +    def get_channel_info(self, channel_id: bytes):
       +        return self._channels.get(channel_id)
       +
       +    def get_channels_for_node(self, node_id):
       +        """Returns the set of channels that have node_id as one of the endpoints."""
       +        return self._channels_for_node.get(node_id)
       +
        
        
        class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
       t@@ -586,9 +586,9 @@ class LNPathFinder(PrintError):
                channel_info = self.channel_db.get_channel_info(short_channel_id)  # type: ChannelInfo
                if channel_info is None:
                    return float('inf'), 0
       -
       -        channel_policy = self.channel_db.get_policy_for_node(channel_info, start_node)
       -        if channel_policy is None: return float('inf'), 0
       +        channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node)
       +        if channel_policy is None:
       +            return float('inf'), 0
                if channel_policy.is_disabled(): return float('inf'), 0
                route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
                if payment_amt_msat < channel_policy.htlc_minimum_msat:
       t@@ -618,6 +618,7 @@ class LNPathFinder(PrintError):
                To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
                i.e. an element reads as, "to get to node_id, travel through short_channel_id"
                """
       +        self.channel_db.load_data()
                assert type(nodeA) is bytes
                assert type(nodeB) is bytes
                assert type(invoice_amount_msat) is int
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -611,6 +611,7 @@ class LNWorker(PrintError):
        
            def _calc_routing_hints_for_invoice(self, amount_sat):
                """calculate routing hints (BOLT-11 'r' field)"""
       +        self.channel_db.load_data()
                routing_hints = []
                with self.lock:
                    channels = list(self.channels.values())