URI: 
       tFlatten the structure of lnrouter, so that DBSession is not used outside of ChannelDB - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 9f188c087c379078443cfb589a1b5fbd0146dd21
   DIR parent 95a217478932b75503732ce4864621b7112629c1
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Tue,  5 Mar 2019 11:22:00 +0100
       
       Flatten the structure of lnrouter, so that DBSession is not used outside of ChannelDB
       
       Diffstat:
         M electrum/lnrouter.py                |     162 +++++++++++++++----------------
         M electrum/lnworker.py                |       6 +++---
       
       2 files changed, 82 insertions(+), 86 deletions(-)
       ---
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -70,7 +70,6 @@ def validate_features(features : int):
        
        Base = declarative_base()
        session_factory = sessionmaker()
       -DBSession = scoped_session(session_factory)
        
        FLAG_DISABLE   = 1 << 1
        FLAG_DIRECTION = 1 << 0
       t@@ -88,16 +87,12 @@ class ChannelInfo(Base):
            def from_msg(channel_announcement_payload):
                features = int.from_bytes(channel_announcement_payload['features'], 'big')
                validate_features(features)
       -
                channel_id = channel_announcement_payload['short_channel_id'].hex()
                node_id_1 = channel_announcement_payload['node_id_1'].hex()
                node_id_2 = channel_announcement_payload['node_id_2'].hex()
                assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
       -
                msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex()
       -
                capacity_sat = None
       -
                return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
                        node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
                        trusted = False)
       t@@ -106,42 +101,6 @@ class ChannelInfo(Base):
            def msg_payload(self):
                return bytes.fromhex(self.msg_payload_hex)
        
       -    def on_channel_update(self, msg: dict, trusted=False):
       -        assert self.short_channel_id == msg['short_channel_id'].hex()
       -        flags = int.from_bytes(msg['channel_flags'], 'big')
       -        direction = flags & FLAG_DIRECTION
       -        if direction == 0:
       -            node_id = self.node1_id
       -        else:
       -            node_id = self.node2_id
       -        new_policy = Policy.from_msg(msg, node_id, self.short_channel_id)
       -        old_policy = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node=node_id).one_or_none()
       -        if not old_policy:
       -            DBSession.add(new_policy)
       -            return
       -        if old_policy.timestamp >= new_policy.timestamp:
       -            return  # ignore
       -        if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
       -            return  # ignore
       -        old_policy.cltv_expiry_delta           = new_policy.cltv_expiry_delta
       -        old_policy.htlc_minimum_msat           = new_policy.htlc_minimum_msat
       -        old_policy.htlc_maximum_msat           = new_policy.htlc_maximum_msat
       -        old_policy.fee_base_msat               = new_policy.fee_base_msat
       -        old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
       -        old_policy.channel_flags               = new_policy.channel_flags
       -        old_policy.timestamp                   = new_policy.timestamp
       -
       -    def get_policy_for_node(self, node) -> Optional['Policy']:
       -        """
       -        raises when initiator/non-initiator both unequal node
       -        """
       -        if node.hex() not in (self.node1_id, self.node2_id):
       -            raise Exception("the given node is not a party in this channel")
       -        n1 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none()
       -        if n1:
       -            return n1
       -        n2 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none()
       -        return n2
        
        class Policy(Base):
            __tablename__ = 'policy'
       t@@ -193,9 +152,6 @@ class NodeInfo(Base):
            timestamp = Column(Integer, nullable=False)
            alias = Column(String(64), nullable=False)
        
       -    def get_addresses(self):
       -        return DBSession.query(Address).join(NodeInfo).filter_by(node_id = self.node_id).all()
       -
            @staticmethod
            def from_msg(node_announcement_payload, addresses_already_parsed=False):
                node_id = node_announcement_payload['node_id'].hex()
       t@@ -281,27 +237,28 @@ class ChannelDB:
                the lnpeer loop is running from, which will do call in here
                """
                engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
       -        DBSession.remove()
       -        DBSession.configure(bind=engine, autoflush=False)
       +        self.DBSession = scoped_session(session_factory)
       +        self.DBSession.remove()
       +        self.DBSession.configure(bind=engine, autoflush=False)
        
                Base.metadata.drop_all(engine)
                Base.metadata.create_all(engine)
        
            def update_counts(self):
       -        self.num_channels = DBSession.query(ChannelInfo).count()
       -        self.num_nodes = DBSession.query(NodeInfo).count()
       +        self.num_channels = self.DBSession.query(ChannelInfo).count()
       +        self.num_nodes = self.DBSession.query(NodeInfo).count()
        
            def add_recent_peer(self, peer : LNPeerAddr):
       -        addr = DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
       +        addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
                if addr is None:
                    addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
                else:
                    addr.last_connected_date = datetime.datetime.now()
       -        DBSession.add(addr)
       -        DBSession.commit()
       +        self.DBSession.add(addr)
       +        self.DBSession.commit()
        
            def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
       -        unshuffled = DBSession \
       +        unshuffled = self.DBSession \
                    .query(NodeInfo) \
                    .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
                    .limit(200) \
       t@@ -312,13 +269,13 @@ class ChannelDB:
                return self.network.run_from_another_thread(self._nodes_get(node_id))
        
            async def _nodes_get(self, node_id):
       -        return DBSession \
       +        return self.DBSession \
                    .query(NodeInfo) \
                    .filter_by(node_id = node_id.hex()) \
                    .one_or_none()
        
            def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
       -        adr_db = DBSession \
       +        adr_db = self.DBSession \
                    .query(Address) \
                    .filter_by(node_id = node_id.hex()) \
                    .order_by(Address.last_connected_date.desc()) \
       t@@ -328,7 +285,7 @@ class ChannelDB:
                return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id))
        
            def get_recent_peers(self):
       -        return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \
       +        return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in self.DBSession \
                    .query(Address) \
                    .select_from(NodeInfo) \
                    .order_by(Address.last_connected_date.desc()) \
       t@@ -342,21 +299,21 @@ class ChannelDB:
                condition = or_(
                         ChannelInfo.node1_id == node_id.hex(),
                         ChannelInfo.node2_id == node_id.hex())
       -        rows = DBSession.query(ChannelInfo).filter(condition).all()
       +        rows = self.DBSession.query(ChannelInfo).filter(condition).all()
                return [bytes.fromhex(x.short_channel_id) for x in rows]
        
            def missing_short_chan_ids(self) -> Set[int]:
       -        expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id)))
       -        chan_ids_from_policy = set(x[0] for x in DBSession.query(Policy.short_channel_id).filter(expr).all())
       +        expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
       +        chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
                if chan_ids_from_policy:
                    return chan_ids_from_policy
                # fetch channels for node_ids missing in node_info. that will also give us node_announcement
       -        expr = not_(ChannelInfo.node1_id.in_(DBSession.query(NodeInfo.node_id)))
       -        chan_ids_from_id1 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
       +        expr = not_(ChannelInfo.node1_id.in_(self.DBSession.query(NodeInfo.node_id)))
       +        chan_ids_from_id1 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
                if chan_ids_from_id1:
                    return chan_ids_from_id1
       -        expr = not_(ChannelInfo.node2_id.in_(DBSession.query(NodeInfo.node_id)))
       -        chan_ids_from_id2 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
       +        expr = not_(ChannelInfo.node2_id.in_(self.DBSession.query(NodeInfo.node_id)))
       +        chan_ids_from_id2 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
                if chan_ids_from_id2:
                    return chan_ids_from_id2
                return set()
       t@@ -366,7 +323,7 @@ class ChannelDB:
                channel_info = self.get_channel_info(short_id)
                channel_info.trusted = True
                channel_info.capacity = capacity
       -        DBSession.commit()
       +        self.DBSession.commit()
        
            @profiler
            def on_channel_announcement(self, msg_payloads, trusted=False):
       t@@ -374,7 +331,7 @@ class ChannelDB:
                    msg_payloads = [msg_payloads]
                for msg in msg_payloads:
                    short_channel_id = msg['short_channel_id']
       -            if DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
       +            if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
                        continue
                    if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                        #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
       t@@ -384,9 +341,9 @@ class ChannelDB:
                    except UnknownEvenFeatureBits:
                        continue
                    channel_info.trusted = trusted
       -            DBSession.add(channel_info)
       +            self.DBSession.add(channel_info)
                    if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
       -        DBSession.commit()
       +        self.DBSession.commit()
                self.network.trigger_callback('ln_status')
                self.update_counts()
        
       t@@ -395,7 +352,7 @@ class ChannelDB:
                if type(msg_payloads) is dict:
                    msg_payloads = [msg_payloads]
                short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
       -        channel_infos_list = DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
       +        channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
                channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
                for msg_payload in msg_payloads:
                    short_channel_id = msg_payload['short_channel_id']
       t@@ -404,19 +361,19 @@ class ChannelDB:
                    channel_info = channel_infos.get(short_channel_id)
                    if not channel_info:
                        continue
       -            channel_info.on_channel_update(msg_payload, trusted=trusted)
       -        DBSession.commit()
       +            self._update_channel_info(channel_info, msg_payload, trusted=trusted)
       +        self.DBSession.commit()
        
            @profiler
            def on_node_announcement(self, msg_payloads):
                if type(msg_payloads) is dict:
                    msg_payloads = [msg_payloads]
       -        addresses = DBSession.query(Address).all()
       +        addresses = self.DBSession.query(Address).all()
                have_addr = {}
                for addr in addresses:
                    have_addr[(addr.node_id, addr.host, addr.port)] = addr
        
       -        nodes = DBSession.query(NodeInfo).all()
       +        nodes = self.DBSession.query(NodeInfo).all()
                timestamps = {}
                for node in nodes:
                    no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
       t@@ -434,7 +391,7 @@ class ChannelDB:
                        continue
                    if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
                        continue  # ignore
       -            DBSession.add(new_node_info)
       +            self.DBSession.add(new_node_info)
                    for new_addr in addresses:
                        key = (new_addr.node_id, new_addr.host, new_addr.port)
                        old_addr = have_addr.get(key)
       t@@ -444,7 +401,7 @@ class ChannelDB:
                            old_addr.last_connected_date = new_addr.last_connected_date
                            del new_addr
                        else:
       -                    DBSession.add(new_addr)
       +                    self.DBSession.add(new_addr)
                            have_addr[key] = new_addr
                    # TODO if this message is for a new node, and if we have no associated
                    # channels for this node, we should ignore the message and return here,
       t@@ -453,7 +410,7 @@ class ChannelDB:
                del nodes, addresses
                if old_addr:
                    del old_addr
       -        DBSession.commit()
       +        self.DBSession.commit()
                self.network.trigger_callback('ln_status')
                self.update_counts()
        
       t@@ -462,9 +419,10 @@ class ChannelDB:
                if not start_node_id or not short_channel_id: return None
                channel_info = self.get_channel_info(short_channel_id)
                if channel_info is not None:
       -            return channel_info.get_policy_for_node(start_node_id)
       +            return self.get_policy_for_node(channel_info, start_node_id)
                msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
       -        if not msg: return None
       +        if not msg:
       +            return None
                return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB
        
            def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
       t@@ -475,10 +433,10 @@ class ChannelDB:
        
            def remove_channel(self, short_channel_id):
                self.chan_query_for_id(short_channel_id).delete('evaluate')
       -        DBSession.commit()
       +        self.DBSession.commit()
        
            def chan_query_for_id(self, short_channel_id) -> Query:
       -        return DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
       +        return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
        
            def print_graph(self, full_ids=False):
                # used for debugging.
       t@@ -492,15 +450,15 @@ class ChannelDB:
                    return other if full_ids else other[-4:]
        
                self.print_msg('nodes')
       -        for node in DBSession.query(NodeInfo).all():
       +        for node in self.DBSession.query(NodeInfo).all():
                    self.print_msg(node)
        
                self.print_msg('channels')
       -        for channel_info in DBSession.query(ChannelInfo).all():
       +        for channel_info in self.DBSession.query(ChannelInfo).all():
                    node1 = channel_info.node1_id
                    node2 = channel_info.node2_id
       -            direction1 = channel_info.get_policy_for_node(node1) is not None
       -            direction2 = channel_info.get_policy_for_node(node2) is not None
       +            direction1 = self.get_policy_for_node(channel_info, node1) is not None
       +            direction2 = self.get_policy_for_node(channel_info, node2) is not None
                    if direction1 and direction2:
                        direction = 'both'
                    elif direction1:
       t@@ -515,6 +473,44 @@ class ChannelDB:
                                           bh2u(node2) if full_ids else bh2u(node2[-4:]),
                                           direction))
        
       +    def _update_channel_info(self, channel_info, msg: dict, trusted=False):
       +        assert channel_info.short_channel_id == msg['short_channel_id'].hex()
       +        flags = int.from_bytes(msg['channel_flags'], 'big')
       +        direction = flags & FLAG_DIRECTION
       +        node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
       +        new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id)
       +        old_policy = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node=node_id).one_or_none()
       +        if not old_policy:
       +            self.DBSession.add(new_policy)
       +            return
       +        if old_policy.timestamp >= new_policy.timestamp:
       +            return  # ignore
       +        if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
       +            return  # ignore
       +        old_policy.cltv_expiry_delta           = new_policy.cltv_expiry_delta
       +        old_policy.htlc_minimum_msat           = new_policy.htlc_minimum_msat
       +        old_policy.htlc_maximum_msat           = new_policy.htlc_maximum_msat
       +        old_policy.fee_base_msat               = new_policy.fee_base_msat
       +        old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
       +        old_policy.channel_flags               = new_policy.channel_flags
       +        old_policy.timestamp                   = new_policy.timestamp
       +
       +    def get_policy_for_node(self, node) -> Optional['Policy']:
       +        """
       +        raises when initiator/non-initiator both unequal node
       +        """
       +        if node.hex() not in (self.node1_id, self.node2_id):
       +            raise Exception("the given node is not a party in this channel")
       +        n1 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none()
       +        if n1:
       +            return n1
       +        n2 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none()
       +        return n2
       +
       +    def get_node_addresses(self, node_info):
       +        return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
       +
       +
        
        class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
                                                 ('short_channel_id', bytes),
       t@@ -596,7 +592,7 @@ class LNPathFinder(PrintError):
                if channel_info is None:
                    return float('inf'), 0
        
       -        channel_policy = channel_info.get_policy_for_node(start_node)
       +        channel_policy = self.channel_db.get_policy_for_node(channel_info, start_node)
                if channel_policy is None: return float('inf'), 0
                if channel_policy.is_disabled(): return float('inf'), 0
                route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -444,7 +444,7 @@ class LNWorker(PrintError):
                    else:
                        if not node_info:
                            raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
       -                addrs = node_info.get_addresses()
       +                addrs = self.channel_db.get_node_addresses(node_info)
                        if len(addrs) == 0:
                            raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
                        host, port = self.choose_preferred_address(addrs)
       t@@ -710,7 +710,7 @@ class LNWorker(PrintError):
                unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
                if unconnected_nodes:
                    for node in unconnected_nodes:
       -                addrs = node.get_addresses()
       +                addrs = self.channel_db.get_node_addresses(node)
                        if not addrs:
                            continue
                        host, port = self.choose_preferred_address(addrs)
       t@@ -776,7 +776,7 @@ class LNWorker(PrintError):
                    # try random address for node_id
                    node_info = await self.channel_db._nodes_get(chan.node_id)
                    if not node_info: return
       -            addresses = node_info.get_addresses()
       +            addresses = self.channel_db.get_node_addresses(node_info)
                    if not addresses: return
                    adr_obj = random.choice(addresses)
                    host, port = adr_obj.host, adr_obj.port