URI: 
       tlnrouter: perform SQL requests in a separate thread. persist database. - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 46aa5c19584a1118deabb9dc2a6eba6012d38509
   DIR parent 9f188c087c379078443cfb589a1b5fbd0146dd21
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Tue,  5 Mar 2019 12:20:56 +0100
       
       lnrouter: perform SQL requests in a separate thread. persist database.
       
       Diffstat:
         M electrum/lnrouter.py                |      81 +++++++++++++++++++++-----------
       
       1 file changed, 54 insertions(+), 27 deletions(-)
       ---
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -29,11 +29,11 @@ import queue
        import os
        import json
        import threading
       +import concurrent
        from collections import defaultdict
        from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
        import binascii
        import base64
       -import asyncio
        
        from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
        from sqlalchemy.pool import StaticPool
       t@@ -212,43 +212,59 @@ class Address(Base):
            port = Column(Integer, primary_key=True)
            last_connected_date = Column(DateTime(), nullable=False)
        
       -class ChannelDB:
       +
       +class ChannelDB(PrintError):
        
            NUM_MAX_RECENT_PEERS = 20
        
            def __init__(self, network: 'Network'):
                self.network = network
       -
                self.num_nodes = 0
                self.num_channels = 0
       -
                self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
       -
       -        # (intentionally not persisted)
                self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
       -
                self.ca_verifier = LNChannelVerifier(network, self)
       +        self.db_requests = queue.Queue()
       +        threading.Thread(target=self.sql_thread).start()
        
       -        self.network.run_from_another_thread(self.sqlinit())
       -
       -    async def sqlinit(self):
       -        """
       -        this has to run on the async thread since that is where
       -        the lnpeer loop is running from, which will do call in here
       -        """
       +    def sql_thread(self):
                engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
                self.DBSession = scoped_session(session_factory)
                self.DBSession.remove()
                self.DBSession.configure(bind=engine, autoflush=False)
       +        if not os.path.exists(self.path):
       +            Base.metadata.create_all(engine)
       +        self._update_counts()
       +        while self.network.asyncio_loop.is_running():
       +            try:
       +                future, func, args, kwargs = self.db_requests.get(timeout=0.1)
       +            except queue.Empty:
       +                continue
       +            try:
       +                result = func(self, *args, **kwargs)
       +            except BaseException as e:
       +                future.set_exception(e)
       +                continue
       +            future.set_result(result)
       +        # write
       +        self.DBSession.commit()
       +        self.DBSession.remove()
       +        self.print_error("SQL thread terminated")
        
       -        Base.metadata.drop_all(engine)
       -        Base.metadata.create_all(engine)
       +    def sql(func):
       +        def wrapper(self, *args, **kwargs):
       +            f = concurrent.futures.Future()
       +            self.db_requests.put((f, func, args, kwargs))
       +            return f.result(timeout=10)
       +        return wrapper
        
       -    def update_counts(self):
       +    # not @sql
       +    def _update_counts(self):
                self.num_channels = self.DBSession.query(ChannelInfo).count()
                self.num_nodes = self.DBSession.query(NodeInfo).count()
        
       -    def add_recent_peer(self, peer : LNPeerAddr):
       +    @sql
       +    def add_recent_peer(self, peer: LNPeerAddr):
                addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
                if addr is None:
                    addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
       t@@ -257,6 +273,7 @@ class ChannelDB:
                self.DBSession.add(addr)
                self.DBSession.commit()
        
       +    @sql
            def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
                unshuffled = self.DBSession \
                    .query(NodeInfo) \
       t@@ -265,15 +282,14 @@ class ChannelDB:
                    .all()
                return random.sample(unshuffled, len(unshuffled))
        
       +    @sql
            def nodes_get(self, node_id):
       -        return self.network.run_from_another_thread(self._nodes_get(node_id))
       -
       -    async def _nodes_get(self, node_id):
                return self.DBSession \
                    .query(NodeInfo) \
                    .filter_by(node_id = node_id.hex()) \
                    .one_or_none()
        
       +    @sql
            def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
                adr_db = self.DBSession \
                    .query(Address) \
       t@@ -284,6 +300,7 @@ class ChannelDB:
                    return None
                return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id))
        
       +    @sql
            def get_recent_peers(self):
                return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in self.DBSession \
                    .query(Address) \
       t@@ -291,9 +308,11 @@ class ChannelDB:
                    .order_by(Address.last_connected_date.desc()) \
                    .limit(self.NUM_MAX_RECENT_PEERS)]
        
       +    @sql
            def get_channel_info(self, channel_id: bytes):
       -        return self.chan_query_for_id(channel_id).one_or_none()
       +        return self._chan_query_for_id(channel_id).one_or_none()
        
       +    @sql
            def get_channels_for_node(self, node_id):
                """Returns the set of channels that have node_id as one of the endpoints."""
                condition = or_(
       t@@ -302,6 +321,7 @@ class ChannelDB:
                rows = self.DBSession.query(ChannelInfo).filter(condition).all()
                return [bytes.fromhex(x.short_channel_id) for x in rows]
        
       +    @sql
            def missing_short_chan_ids(self) -> Set[int]:
                expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
                chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
       t@@ -318,13 +338,15 @@ class ChannelDB:
                    return chan_ids_from_id2
                return set()
        
       +    @sql
            def add_verified_channel_info(self, short_id, capacity):
                # called from lnchannelverifier
       -        channel_info = self.get_channel_info(short_id)
       +        channel_info = self._chan_query_for_id(short_id).one_or_none()
                channel_info.trusted = True
                channel_info.capacity = capacity
                self.DBSession.commit()
        
       +    @sql
            @profiler
            def on_channel_announcement(self, msg_payloads, trusted=False):
                if type(msg_payloads) is dict:
       t@@ -344,9 +366,10 @@ class ChannelDB:
                    self.DBSession.add(channel_info)
                    if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
                self.DBSession.commit()
       +        self._update_counts()
                self.network.trigger_callback('ln_status')
       -        self.update_counts()
        
       +    @sql
            @profiler
            def on_channel_update(self, msg_payloads, trusted=False):
                if type(msg_payloads) is dict:
       t@@ -364,6 +387,7 @@ class ChannelDB:
                    self._update_channel_info(channel_info, msg_payload, trusted=trusted)
                self.DBSession.commit()
        
       +    @sql
            @profiler
            def on_node_announcement(self, msg_payloads):
                if type(msg_payloads) is dict:
       t@@ -411,8 +435,8 @@ class ChannelDB:
                if old_addr:
                    del old_addr
                self.DBSession.commit()
       +        self._update_counts()
                self.network.trigger_callback('ln_status')
       -        self.update_counts()
        
            def get_routing_policy_for_channel(self, start_node_id: bytes,
                                               short_channel_id: bytes) -> Optional[bytes]:
       t@@ -431,11 +455,12 @@ class ChannelDB:
                short_channel_id = msg_payload['short_channel_id']
                self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
        
       +    @sql
            def remove_channel(self, short_channel_id):
       -        self.chan_query_for_id(short_channel_id).delete('evaluate')
       +        self._chan_query_for_id(short_channel_id).delete('evaluate')
                self.DBSession.commit()
        
       -    def chan_query_for_id(self, short_channel_id) -> Query:
       +    def _chan_query_for_id(self, short_channel_id) -> Query:
                return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
        
            def print_graph(self, full_ids=False):
       t@@ -495,6 +520,7 @@ class ChannelDB:
                old_policy.channel_flags               = new_policy.channel_flags
                old_policy.timestamp                   = new_policy.timestamp
        
       +    @sql
            def get_policy_for_node(self, node) -> Optional['Policy']:
                """
                raises when initiator/non-initiator both unequal node
       t@@ -507,6 +533,7 @@ class ChannelDB:
                n2 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none()
                return n2
        
       +    @sql
            def get_node_addresses(self, node_info):
                return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()