        git clone https://git.parazyd.org/electrum
       ---
        commit f2d58d0e3f97975d4dcfcbcacc96d7e206190ef6
        parent 180f6d34bec2f3e443488f922591d51c11cab1f6
        Author: ThomasV <thomasv@electrum.org>
       Date:   Tue, 18 Jun 2019 13:49:31 +0200
       
       optimize channel_db:
        - use python objects mirrored by sql database
        - write sql to file asynchronously
        - the sql decorator is awaited in sweepstore, not in channel_db
       
       Diffstat:
         M electrum/channel_db.py              |     534 ++++++++++++++-----------------
         M electrum/gui/qt/lightning_dialog.py |       9 +++++----
         M electrum/lnaddr.py                  |      11 +++++++++--
         M electrum/lnchannel.py               |      12 +++++-------
         M electrum/lnpeer.py                  |      87 ++++++++++++++-----------------
         M electrum/lnrouter.py                |      19 ++++++++++---------
         M electrum/lnwatcher.py               |      49 ++++++++++++++++++-------------
         M electrum/lnworker.py                |     131 ++++++++++++++++---------------
         M electrum/sql_db.py                  |      22 +++++++++++++++++-----
         M electrum/tests/test_lnpeer.py       |       3 ++-
         M electrum/tests/test_lnrouter.py     |      12 ++++++------
       
       11 files changed, 435 insertions(+), 454 deletions(-)
       ---
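
        The heart of this commit is a change of data flow: the routing tables
        (channels, policies, nodes, addresses) are now read from plain Python
        dicts of NamedTuples, and SQL is used only as a write-behind mirror.
        Writes go through the @sql decorator, which hands the statement to a
        dedicated database thread and returns an awaitable; SweepStore awaits
        the result, while ChannelDB fires and forgets. A minimal sketch of such
        a decorator, assuming callers run on an asyncio event loop (simplified,
        not Electrum's exact sql_db.py):

            import asyncio
            import functools
            from concurrent.futures import ThreadPoolExecutor

            _db_thread = ThreadPoolExecutor(max_workers=1)  # serializes all SQL

            def sql(func):
                @functools.wraps(func)
                def wrapper(self, *args, **kwargs):
                    fut = _db_thread.submit(func, self, *args, **kwargs)
                    # awaitable for callers that need the result (sweepstore);
                    # write-only callers (channel_db) may simply drop it
                    return asyncio.wrap_future(fut)
                return wrapper
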
        diff --git a/electrum/channel_db.py b/electrum/channel_db.py
        @@ -51,6 +51,7 @@ from .crypto import sha256d
        from . import ecc
        from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
                             NotFoundChanAnnouncementForUpdate)
       +from .lnverifier import verify_sig_for_channel_update
        from .lnmsg import encode_msg
        
        if TYPE_CHECKING:
        @@ -70,85 +71,83 @@ Base = declarative_base()
        FLAG_DISABLE   = 1 << 1
        FLAG_DIRECTION = 1 << 0
        
       -class ChannelInfo(Base):
       -    __tablename__ = 'channel_info'
       -    short_channel_id = Column(String(64), primary_key=True)
       -    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
       -    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
       -    capacity_sat = Column(Integer)
       -    msg_payload_hex = Column(String(1024), nullable=False)
       -    trusted = Column(Boolean, nullable=False)
       +class ChannelInfo(NamedTuple):
       +    short_channel_id: bytes
       +    node1_id: bytes
       +    node2_id: bytes
       +    capacity_sat: int
       +    msg_payload: bytes
       +    trusted: bool
        
            @staticmethod
            def from_msg(payload):
                features = int.from_bytes(payload['features'], 'big')
                validate_features(features)
       -        channel_id = payload['short_channel_id'].hex()
       -        node_id_1 = payload['node_id_1'].hex()
       -        node_id_2 = payload['node_id_2'].hex()
       +        channel_id = payload['short_channel_id']
       +        node_id_1 = payload['node_id_1']
       +        node_id_2 = payload['node_id_2']
                assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
       -        msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
       +        msg_payload = encode_msg('channel_announcement', **payload)
                capacity_sat = None
       -        return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
       -                node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
       -                trusted = False)
       -
       -    @property
       -    def msg_payload(self):
       -        return bytes.fromhex(self.msg_payload_hex)
       -
       -
       -class Policy(Base):
       -    __tablename__ = 'policy'
       -    start_node                  = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
       -    short_channel_id            = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
       -    cltv_expiry_delta           = Column(Integer, nullable=False)
       -    htlc_minimum_msat           = Column(Integer, nullable=False)
       -    htlc_maximum_msat           = Column(Integer)
       -    fee_base_msat               = Column(Integer, nullable=False)
       -    fee_proportional_millionths = Column(Integer, nullable=False)
       -    channel_flags               = Column(Integer, nullable=False)
       -    timestamp                   = Column(Integer, nullable=False)
       +        return ChannelInfo(
       +            short_channel_id = channel_id,
       +            node1_id = node_id_1,
       +            node2_id = node_id_2,
       +            capacity_sat = capacity_sat,
       +            msg_payload = msg_payload,
       +            trusted = False)
       +
       +
       +
       +class Policy(NamedTuple):
       +    key: bytes
       +    cltv_expiry_delta: int
       +    htlc_minimum_msat: int
       +    htlc_maximum_msat: int
       +    fee_base_msat: int
       +    fee_proportional_millionths: int
       +    channel_flags: int
       +    timestamp: int
        
            @staticmethod
            def from_msg(payload):
       -        cltv_expiry_delta           = int.from_bytes(payload['cltv_expiry_delta'], "big")
       -        htlc_minimum_msat           = int.from_bytes(payload['htlc_minimum_msat'], "big")
       -        htlc_maximum_msat           = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
       -        fee_base_msat               = int.from_bytes(payload['fee_base_msat'], "big")
       -        fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
       -        channel_flags               = int.from_bytes(payload['channel_flags'], "big")
       -        timestamp                   = int.from_bytes(payload['timestamp'], "big")
       -        start_node                  = payload['start_node'].hex()
       -        short_channel_id            = payload['short_channel_id'].hex()
       -
       -        return Policy(start_node=start_node,
       -                short_channel_id=short_channel_id,
       -                cltv_expiry_delta=cltv_expiry_delta,
       -                htlc_minimum_msat=htlc_minimum_msat,
       -                fee_base_msat=fee_base_msat,
       -                fee_proportional_millionths=fee_proportional_millionths,
       -                channel_flags=channel_flags,
       -                timestamp=timestamp,
       -                htlc_maximum_msat=htlc_maximum_msat)
       +        return Policy(
       +            key                         = payload['short_channel_id'] + payload['start_node'],
       +            cltv_expiry_delta           = int.from_bytes(payload['cltv_expiry_delta'], "big"),
       +            htlc_minimum_msat           = int.from_bytes(payload['htlc_minimum_msat'], "big"),
       +            htlc_maximum_msat           = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None,
       +            fee_base_msat               = int.from_bytes(payload['fee_base_msat'], "big"),
       +            fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big"),
       +            channel_flags               = int.from_bytes(payload['channel_flags'], "big"),
       +            timestamp                   = int.from_bytes(payload['timestamp'], "big")
       +        )
        
            def is_disabled(self):
                return self.channel_flags & FLAG_DISABLE
        
       -class NodeInfo(Base):
       -    __tablename__ = 'node_info'
       -    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
       -    features = Column(Integer, nullable=False)
       -    timestamp = Column(Integer, nullable=False)
       -    alias = Column(String(64), nullable=False)
       +    @property
       +    def short_channel_id(self):
       +        return self.key[0:8]
       +
       +    @property
       +    def start_node(self):
       +        return self.key[8:]
       +
       +
       +
       +class NodeInfo(NamedTuple):
       +    node_id: bytes
       +    features: int
       +    timestamp: int
       +    alias: str
        
            @staticmethod
            def from_msg(payload):
       -        node_id = payload['node_id'].hex()
       +        node_id = payload['node_id']
                features = int.from_bytes(payload['features'], "big")
                validate_features(features)
                addresses = NodeInfo.parse_addresses_field(payload['addresses'])
       -        alias = payload['alias'].rstrip(b'\x00').hex()
       +        alias = payload['alias'].rstrip(b'\x00')
                timestamp = int.from_bytes(payload['timestamp'], "big")
                return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
                    Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses]
        @@ -193,110 +192,136 @@ class NodeInfo(Base):
                        break
                return addresses
        
       -class Address(Base):
       +
       +class Address(NamedTuple):
       +    node_id: bytes
       +    host: str
       +    port: int
       +    last_connected_date: int
       +
       +
       +class ChannelInfoBase(Base):
       +    __tablename__ = 'channel_info'
       +    short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
       +    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
       +    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
       +    capacity_sat = Column(Integer)
       +    msg_payload = Column(String(1024), nullable=False)
       +    trusted = Column(Boolean, nullable=False)
       +
       +    def to_nametuple(self):
       +        return ChannelInfo(
       +            short_channel_id=self.short_channel_id,
       +            node1_id=self.node1_id,
       +            node2_id=self.node2_id,
       +            capacity_sat=self.capacity_sat,
       +            msg_payload=self.msg_payload,
       +            trusted=self.trusted
       +        )
       +
       +class PolicyBase(Base):
       +    __tablename__ = 'policy'
       +    key                         = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
       +    cltv_expiry_delta           = Column(Integer, nullable=False)
       +    htlc_minimum_msat           = Column(Integer, nullable=False)
       +    htlc_maximum_msat           = Column(Integer)
       +    fee_base_msat               = Column(Integer, nullable=False)
       +    fee_proportional_millionths = Column(Integer, nullable=False)
       +    channel_flags               = Column(Integer, nullable=False)
       +    timestamp                   = Column(Integer, nullable=False)
       +
       +    def to_nametuple(self):
       +        return Policy(
       +            key=self.key,
       +            cltv_expiry_delta=self.cltv_expiry_delta,
       +            htlc_minimum_msat=self.htlc_minimum_msat,
       +            htlc_maximum_msat=self.htlc_maximum_msat,
       +            fee_base_msat= self.fee_base_msat,
       +            fee_proportional_millionths = self.fee_proportional_millionths,
       +            channel_flags=self.channel_flags,
       +            timestamp=self.timestamp
       +        )
       +
       +class NodeInfoBase(Base):
       +    __tablename__ = 'node_info'
       +    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
       +    features = Column(Integer, nullable=False)
       +    timestamp = Column(Integer, nullable=False)
       +    alias = Column(String(64), nullable=False)
       +
       +class AddressBase(Base):
            __tablename__ = 'address'
       -    node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
       +    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
            host = Column(String(256), primary_key=True)
            port = Column(Integer, primary_key=True)
            last_connected_date = Column(Integer(), nullable=True)
        
        
       -
        class ChannelDB(SqlDB):
        
            NUM_MAX_RECENT_PEERS = 20
        
            def __init__(self, network: 'Network'):
                path = os.path.join(get_headers_dir(network.config), 'channel_db')
       -        super().__init__(network, path, Base)
       +        super().__init__(network, path, Base, commit_interval=100)
                self.num_nodes = 0
                self.num_channels = 0
                self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
                self.ca_verifier = LNChannelVerifier(network, self)
       -        self.update_counts()
       +        # initialized in load_data
       +        self._channels = {}
       +        self._policies = {}
       +        self._nodes = {}
       +        self._addresses = defaultdict(set)
       +        self._channels_for_node = defaultdict(set)
        
       -    @sql
            def update_counts(self):
       -        self._update_counts()
       +        self.num_channels = len(self._channels)
       +        self.num_policies = len(self._policies)
       +        self.num_nodes = len(self._nodes)
        
       -    def _update_counts(self):
       -        self.num_channels = self.DBSession.query(ChannelInfo).count()
       -        self.num_policies = self.DBSession.query(Policy).count()
       -        self.num_nodes = self.DBSession.query(NodeInfo).count()
       +    def get_channel_ids(self):
       +        return set(self._channels.keys())
        
       -    @sql
       -    def known_ids(self):
       -        known = self.DBSession.query(ChannelInfo.short_channel_id).all()
       -        return set(bfh(r.short_channel_id) for r in known)
       -
       -    @sql
            def add_recent_peer(self, peer: LNPeerAddr):
                now = int(time.time())
       -        node_id = peer.pubkey.hex()
       -        addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
       +        node_id = peer.pubkey
       +        self._addresses[node_id].add((peer.host, peer.port, now))
       +        self.save_address(node_id, peer, now)
       +
       +    @sql
       +    def save_address(self, node_id, peer, now):
       +        addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
                if addr:
                    addr.last_connected_date = now
                else:
       -            addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
       +            addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
                    self.DBSession.add(addr)
       -        self.DBSession.commit()
       -
       -    @sql
       -    def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
       -        unshuffled = self.DBSession \
       -            .query(NodeInfo) \
       -            .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
       -            .limit(200) \
       -            .all()
       -        return random.sample(unshuffled, len(unshuffled))
        
       -    @sql
       -    def nodes_get(self, node_id):
       -        return self.DBSession \
       -            .query(NodeInfo) \
       -            .filter_by(node_id = node_id.hex()) \
       -            .one_or_none()
       +    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
       +        unshuffled = set(self._nodes.keys()) - node_ids
       +        return random.sample(unshuffled, min(200, len(unshuffled)))
        
       -    @sql
            def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
       -        r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all()
       +        r = self._addresses.get(node_id)
                if not r:
                    return None
       -        addr = r[0]
       -        return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id))
       +        addr = sorted(list(r), key=lambda x: x[2])[0]
       +        host, port, timestamp = addr
       +        return LNPeerAddr(host, port, node_id)
        
       -    @sql
            def get_recent_peers(self):
       -        r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all()
       -        return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
       -
       -    @sql
       -    def missing_channel_announcements(self) -> Set[int]:
       -        expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
       -        return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
       +        r = [self.get_last_good_address(x) for x in self._addresses.keys()]
       +        r = r[-self.NUM_MAX_RECENT_PEERS:]
       +        return r
        
       -    @sql
       -    def missing_channel_updates(self) -> Set[int]:
       -        expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
       -        return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
       -
       -    @sql
       -    def add_verified_channel_info(self, short_id, capacity):
       -        # called from lnchannelverifier
       -        channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none()
       -        channel_info.trusted = True
       -        channel_info.capacity = capacity
       -        self.DBSession.commit()
       -
       -    @sql
       -    @profiler
       -    def on_channel_announcement(self, msg_payloads, trusted=True):
       +    def add_channel_announcement(self, msg_payloads, trusted=True):
                if type(msg_payloads) is dict:
                    msg_payloads = [msg_payloads]
       -        new_channels = {}
       +        added = 0
                for msg in msg_payloads:
       -            short_channel_id = bh2u(msg['short_channel_id'])
       -            if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
       +            short_channel_id = msg['short_channel_id']
       +            if short_channel_id in self._channels:
                        continue
                    if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                        self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
        @@ -306,24 +331,24 @@ class ChannelDB(SqlDB):
                    except UnknownEvenFeatureBits:
                        self.logger.info("unknown feature bits")
                        continue
       -            channel_info.trusted = trusted
       -            new_channels[short_channel_id] = channel_info
       +            #channel_info.trusted = trusted
       +            added += 1
       +            self._channels[short_channel_id] = channel_info
       +            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
       +            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
       +            self.save_channel(channel_info)
                    if not trusted:
                        self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
       -        for channel_info in new_channels.values():
       -            self.DBSession.add(channel_info)
       -        self.DBSession.commit()
       -        self._update_counts()
       -        self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
        
       -    @sql
       -    def get_last_timestamp(self):
       -        return self._get_last_timestamp()
       +        self.update_counts()
       +        self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))
        
       -    def _get_last_timestamp(self):
       -        from sqlalchemy.sql import func
       -        r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
       -        return r.max_timestamp or 0
       +
       +    #def add_verified_channel_info(self, short_id, capacity):
       +    #    # called from lnchannelverifier
       +    #    channel_info = self.DBSession.query(ChannelInfoBase).filter_by(short_channel_id = short_id).one_or_none()
       +    #    channel_info.trusted = True
       +    #    channel_info.capacity = capacity
        
            def print_change(self, old_policy, new_policy):
                # print what changed between policies
        @@ -340,89 +365,74 @@ class ChannelDB(SqlDB):
                if old_policy.channel_flags != new_policy.channel_flags:
                    self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
        
       -    @sql
       -    def get_info_for_updates(self, payloads):
       -        short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
       -        channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
       -        channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
       -        return channel_infos
       -
       -    @sql
       -    def get_policies_for_updates(self, payloads):
       -        out = {}
       -        for payload in payloads:
       -            short_channel_id = payload['short_channel_id'].hex()
       -            start_node = payload['start_node'].hex()
       -            policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
       -            if policy:
       -                out[short_channel_id+start_node] = policy
       -        return out
       -
       -    @profiler
       -    def filter_channel_updates(self, payloads, max_age=None):
       +    def add_channel_updates(self, payloads, max_age=None, verify=True):
                orphaned = []      # no channel announcement for channel update
                expired = []       # update older than two weeks
                deprecated = []    # update older than database entry
       -        good = {}          # good updates
       +        good = []          # good updates
                to_delete = []     # database entries to delete
                # filter orphaned and expired first
                known = []
                now = int(time.time())
       -        channel_infos = self.get_info_for_updates(payloads)
                for payload in payloads:
                    short_channel_id = payload['short_channel_id']
                    timestamp = int.from_bytes(payload['timestamp'], "big")
                    if max_age and now - timestamp > max_age:
                        expired.append(short_channel_id)
                        continue
       -            channel_info = channel_infos.get(short_channel_id)
       +            channel_info = self._channels.get(short_channel_id)
                    if not channel_info:
                        orphaned.append(short_channel_id)
                        continue
                    flags = int.from_bytes(payload['channel_flags'], 'big')
                    direction = flags & FLAG_DIRECTION
                    start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
       -            payload['start_node'] = bfh(start_node)
       +            payload['start_node'] = start_node
                    known.append(payload)
                # compare updates to existing database entries
       -        old_policies = self.get_policies_for_updates(known)
                for payload in known:
                    timestamp = int.from_bytes(payload['timestamp'], "big")
                    start_node = payload['start_node']
                    short_channel_id = payload['short_channel_id']
       -            key = (short_channel_id+start_node).hex()
       -            old_policy = old_policies.get(key)
       -            if old_policy:
       -                if timestamp <= old_policy.timestamp:
       -                    deprecated.append(short_channel_id)
       -                else:
       -                    good[key] = payload
       -                    to_delete.append(old_policy)
       -            else:
       -                good[key] = payload
       -        good = list(good.values())
       +            key = (start_node, short_channel_id)
       +            old_policy = self._policies.get(key)
       +            if old_policy and timestamp <= old_policy.timestamp:
       +                deprecated.append(short_channel_id)
       +                continue
       +            good.append(payload)
       +            if verify:
       +                self.verify_channel_update(payload)
       +            policy = Policy.from_msg(payload)
       +            self._policies[key] = policy
       +            self.save_policy(policy)
       +        #
       +        self.update_counts()
                return orphaned, expired, deprecated, good, to_delete
        
            def add_channel_update(self, payload):
       -        orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
       +        orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
                assert len(good) == 1
       -        self.update_policies(good, to_delete)
        
            @sql
       -    @profiler
       -    def update_policies(self, to_add, to_delete):
       -        for policy in to_delete:
       -            self.DBSession.delete(policy)
       -        self.DBSession.commit()
       -        for payload in to_add:
       -            policy = Policy.from_msg(payload)
       -            self.DBSession.add(policy)
       -        self.DBSession.commit()
       -        self._update_counts()
       +    def save_policy(self, policy):
       +        self.DBSession.execute(PolicyBase.__table__.insert().values(policy))
        
            @sql
       -    @profiler
       -    def on_node_announcement(self, msg_payloads):
       +    def delete_policy(self, short_channel_id, node_id):
       +        self.DBSession.execute(PolicyBase.__table__.delete().values(policy))
       +
       +    @sql
       +    def save_channel(self, channel_info):
       +        self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info))
       +
       +    def verify_channel_update(self, payload):
       +        short_channel_id = payload['short_channel_id']
       +        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
       +            raise Exception('wrong chain hash')
       +        if not verify_sig_for_channel_update(payload, payload['start_node']):
       +            raise BaseException('verify error')
       +
       +    def add_node_announcement(self, msg_payloads):
                if type(msg_payloads) is dict:
                    msg_payloads = [msg_payloads]
                old_addr = None
        @@ -435,29 +445,35 @@ class ChannelDB(SqlDB):
                        continue
                    node_id = node_info.node_id
                    # Ignore node if it has no associated channel (DoS protection)
       -            # FIXME this is slow
       -            expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id)
       -            if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0:
       +            if node_id not in self._channels_for_node:
                        #self.logger.info('ignoring orphan node_announcement')
                        continue
       -            node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
       +            node = self._nodes.get(node_id)
                    if node and node.timestamp >= node_info.timestamp:
                        continue
                    node = new_nodes.get(node_id)
                    if node and node.timestamp >= node_info.timestamp:
                        continue
       -            new_nodes[node_id] = node_info
       +            # save
       +            self._nodes[node_id] = node_info
       +            self.save_node(node_info)
                    for addr in node_addresses:
       -                new_addresses[(addr.node_id,addr.host,addr.port)] = addr
       +                self._addresses[node_id].add((addr.host, addr.port, 0))
       +            self.save_node_addresses(node_id, node_addresses)
       +
                self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
       -        for node_info in new_nodes.values():
       -            self.DBSession.add(node_info)
       -        for new_addr in new_addresses.values():
       -            old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
       +        self.update_counts()
       +
       +    @sql
       +    def save_node_addresses(self, node_if, node_addresses):
       +        for new_addr in node_addresses:
       +            old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
                    if not old_addr:
       -                self.DBSession.add(new_addr)
       -        self.DBSession.commit()
       -        self._update_counts()
       +                self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
       +
       +    @sql
       +    def save_node(self, node_info):
       +        self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
        
            def get_routing_policy_for_channel(self, start_node_id: bytes,
                                               short_channel_id: bytes) -> Optional[bytes]:
        @@ -470,41 +486,28 @@ class ChannelDB(SqlDB):
                    return None
                return Policy.from_msg(msg) # won't actually be written to DB
        
       -    @sql
       -    @profiler
            def get_old_policies(self, delta):
       -        timestamp = int(time.time()) - delta
       -        old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
       -        return old_policies.distinct().count()
       +        now = int(time.time())
       +        return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
        
       -    @sql
       -    @profiler
            def prune_old_policies(self, delta):
       -        # note: delete queries are order sensitive
       -        timestamp = int(time.time()) - delta
       -        old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
       -        delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies))
       -        delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp)
       -        self.DBSession.execute(delete_old_channels)
       -        self.DBSession.execute(delete_old_policies)
       -        self.DBSession.commit()
       -        self._update_counts()
       +        l = self.get_old_policies(delta)
       +        for k in l:
       +            self._policies.pop(k)
       +        if l:
       +            self.logger.info(f'Deleting {len(l)} old policies')
        
       -    @sql
       -    @profiler
            def get_orphaned_channels(self):
       -        subquery = self.DBSession.query(Policy.short_channel_id)
       -        orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery)))
       -        return orphaned.count()
       +        ids = set(x[1] for x in self._policies.keys())
       +        return list(x for x in self._channels.keys() if x not in ids)
        
       -    @sql
       -    @profiler
            def prune_orphaned_channels(self):
       -        subquery = self.DBSession.query(Policy.short_channel_id)
       -        delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery)))
       -        self.DBSession.execute(delete_orphaned)
       -        self.DBSession.commit()
       -        self._update_counts()
       +        l = self.get_orphaned_channels()
       +        for k in l:
       +            self._channels.pop(k)
       +        self.update_counts()
       +        if l:
       +            self.logger.info(f'Deleting {len(l)} orphaned channels')
        
            def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
                if not verify_sig_for_channel_update(msg_payload, start_node_id):
        @@ -513,67 +516,27 @@ class ChannelDB(SqlDB):
                msg_payload['start_node'] = start_node_id
                self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
        
       -    @sql
            def remove_channel(self, short_channel_id):
       -        r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none()
       -        if not r:
       -            return
       -        self.DBSession.delete(r)
       -        self.DBSession.commit()
       -
       -    def print_graph(self, full_ids=False):
       -        # used for debugging.
       -        # FIXME there is a race here - iterables could change size from another thread
       -        def other_node_id(node_id, channel_id):
       -            channel_info = self.get_channel_info(channel_id)
       -            if node_id == channel_info.node1_id:
       -                other = channel_info.node2_id
       -            else:
       -                other = channel_info.node1_id
       -            return other if full_ids else other[-4:]
       -
       -        print_msg('nodes')
       -        for node in self.DBSession.query(NodeInfo).all():
       -            print_msg(node)
       -
       -        print_msg('channels')
       -        for channel_info in self.DBSession.query(ChannelInfo).all():
       -            short_channel_id = channel_info.short_channel_id
       -            node1 = channel_info.node1_id
       -            node2 = channel_info.node2_id
       -            direction1 = self.get_policy_for_node(channel_info, node1) is not None
       -            direction2 = self.get_policy_for_node(channel_info, node2) is not None
       -            if direction1 and direction2:
       -                direction = 'both'
       -            elif direction1:
       -                direction = 'forward'
       -            elif direction2:
       -                direction = 'backward'
       -            else:
       -                direction = 'none'
       -            print_msg('{}: {}, {}, {}'
       -                           .format(bh2u(short_channel_id),
       -                                   bh2u(node1) if full_ids else bh2u(node1[-4:]),
       -                                   bh2u(node2) if full_ids else bh2u(node2[-4:]),
       -                                   direction))
       +        self._channels.pop(short_channel_id, None)
        
       -
       -    @sql
       -    def get_node_addresses(self, node_info):
       -        return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
       +    def get_node_addresses(self, node_id):
       +        return self._addresses.get(node_id)
        
            @sql
            @profiler
            def load_data(self):
       -        r = self.DBSession.query(ChannelInfo).all()
       -        self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
       -        r = self.DBSession.query(Policy).filter_by().all()
       -        self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
       -        self._channels_for_node = defaultdict(set)
       +        for x in self.DBSession.query(AddressBase).all():
       +            self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0)))
       +        for x in self.DBSession.query(ChannelInfoBase).all():
       +            self._channels[x.short_channel_id] = x.to_nametuple()
       +        for x in self.DBSession.query(PolicyBase).filter_by().all():
       +            p = x.to_nametuple()
       +            self._policies[(p.start_node, p.short_channel_id)] = p
                for channel_info in self._channels.values():
       -            self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id))
       -            self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id))
       +            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
       +            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
                self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
       +        self.update_counts()
        
            def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
                return self._policies.get((node_id, short_channel_id))
        @@ -584,6 +547,3 @@ class ChannelDB(SqlDB):
            def get_channels_for_node(self, node_id) -> Set[bytes]:
                """Returns the set of channels that have node_id as one of the endpoints."""
                return self._channels_for_node.get(node_id) or set()
       -
       -
       -
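
        Two details of the mirroring scheme above are worth noting. Each ORM
        class (ChannelInfoBase, PolicyBase, NodeInfoBase, AddressBase) exists
        only for persistence: to_nametuple() converts a loaded row into the
        immutable NamedTuple used by the rest of the code, the save_* methods
        insert the NamedTuple directly via insert().values(...), which binds
        values positionally (so NamedTuple field order must match column
        order), and sqlite_on_conflict_primary_key='REPLACE' turns repeated
        inserts into upserts. The round trip, condensed from the diff above:

            @sql
            def save_policy(self, policy):
                # the NamedTuple unpacks positionally into the table's columns
                self.DBSession.execute(PolicyBase.__table__.insert().values(policy))

            @sql
            def load_data(self):
                for row in self.DBSession.query(PolicyBase).all():
                    p = row.to_nametuple()
                    self._policies[(p.start_node, p.short_channel_id)] = p
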
        diff --git a/electrum/gui/qt/lightning_dialog.py b/electrum/gui/qt/lightning_dialog.py
        @@ -56,10 +56,11 @@ class WatcherList(MyTreeView):
                    return
                self.model().clear()
                self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
       -        sweepstore = self.parent.lnwatcher.sweepstore
       -        for outpoint in sweepstore.list_sweep_tx():
       -            n = sweepstore.get_num_tx(outpoint)
       -            status = self.parent.lnwatcher.get_channel_status(outpoint)
       +        lnwatcher = self.parent.lnwatcher
       +        l = lnwatcher.list_sweep_tx()
       +        for outpoint in l:
       +            n = lnwatcher.get_num_tx(outpoint)
       +            status = lnwatcher.get_channel_status(outpoint)
                    items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
                    self.model().insertRow(self.model().rowCount(), items)
        
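
        Why this GUI change was needed: SweepStore's @sql methods now return
        awaitables, so the Qt thread can no longer call them directly. Instead,
        LNWatcher grows small synchronous wrappers that submit a coroutine to
        the network's event loop and block for the result, and the dialog talks
        only to those wrappers. One such wrapper, as it appears in the
        lnwatcher.py diff below:

            def get_num_tx(self, outpoint):
                async def f():
                    return await self.sweepstore.get_num_tx(outpoint)
                return self.network.run_from_another_thread(f())
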
        diff --git a/electrum/lnaddr.py b/electrum/lnaddr.py
        @@ -258,14 +258,21 @@ class LnAddr(object):
            def get_min_final_cltv_expiry(self) -> int:
                return self._min_final_cltv_expiry
        
       -    def get_description(self):
       +    def get_tag(self, tag):
                description = ''
                for k,v in self.tags:
       -            if k == 'd':
       +            if k == tag:
                        description = v
                        break
                return description
        
       +    def get_description(self):
       +        return self.get_tag('d')
       +
       +    def get_expiry(self):
       +        return int(self.get_tag('x') or '3600')
       +
       +
        
        def lndecode(a, verbose=False, expected_hrp=None):
            if expected_hrp is None:
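
        get_tag generalizes BOLT-11 tagged-field lookup, so get_description
        becomes a one-liner and get_expiry can fall back to 3600 seconds, the
        default invoice expiry specified by BOLT-11 when no 'x' field is
        present. Hypothetical usage (invoice string elided):

            addr = lndecode(invoice, expected_hrp='lnbc')
            description = addr.get_description()  # 'd' tag
            expiry = addr.get_expiry()             # 'x' tag, or 3600
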
        diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
        @@ -163,8 +163,6 @@ class Channel(Logger):
                self._is_funding_txo_spent = None  # "don't know"
                self._state = None
                self.set_state('DISCONNECTED')
       -        self.lnwatcher = None
       -
                self.local_commitment = None
                self.remote_commitment = None
                self.sweep_info = None
        @@ -453,13 +451,10 @@ class Channel(Logger):
                return secret, point
        
            def process_new_revocation_secret(self, per_commitment_secret: bytes):
       -        if not self.lnwatcher:
       -            return
                outpoint = self.funding_outpoint.to_str()
                ctx = self.remote_commitment_to_be_revoked  # FIXME can't we just reconstruct it?
                sweeptxs = create_sweeptxs_for_their_revoked_ctx(self, ctx, per_commitment_secret, self.sweep_address)
       -        for tx in sweeptxs:
       -            self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx))
       +        return sweeptxs
        
            def receive_revocation(self, revocation: RevokeAndAck):
                self.logger.info("receive_revocation")
        @@ -477,9 +472,10 @@ class Channel(Logger):
        
                # be robust to exceptions raised in lnwatcher
                try:
       -            self.process_new_revocation_secret(revocation.per_commitment_secret)
       +            sweeptxs = self.process_new_revocation_secret(revocation.per_commitment_secret)
                except Exception as e:
                    self.logger.info("Could not process revocation secret: {}".format(repr(e)))
       +            sweeptxs = []
        
                ##### start applying fee/htlc changes
        
        @@ -505,6 +501,8 @@ class Channel(Logger):
        
                self.set_remote_commitment()
                self.remote_commitment_to_be_revoked = prev_remote_commitment
       +        # return sweep transactions for watchtower
       +        return sweeptxs
        
            def balance(self, whose, *, ctx_owner=HTLCOwner.LOCAL, ctn=None):
                """
        diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
        @@ -42,7 +42,6 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
                             MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY)
        from .lntransport import LNTransport, LNTransportBase
        from .lnmsg import encode_msg, decode_msg
       -from .lnverifier import verify_sig_for_channel_update
        from .interface import GracefulDisconnect
        
        if TYPE_CHECKING:
        @@ -242,22 +241,20 @@ class Peer(Logger):
                    # channel announcements
                    for chan_anns_chunk in chunks(chan_anns, 300):
                        self.verify_channel_announcements(chan_anns_chunk)
       -                self.channel_db.on_channel_announcement(chan_anns_chunk)
       +                self.channel_db.add_channel_announcement(chan_anns_chunk)
                    # node announcements
                    for node_anns_chunk in chunks(node_anns, 100):
                        self.verify_node_announcements(node_anns_chunk)
       -                self.channel_db.on_node_announcement(node_anns_chunk)
       +                self.channel_db.add_node_announcement(node_anns_chunk)
                    # channel updates
                    for chan_upds_chunk in chunks(chan_upds, 1000):
       -                orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds_chunk,
       -                                                                                                        max_age=self.network.lngossip.max_age)
       +                orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates(
       +                    chan_upds_chunk, max_age=self.network.lngossip.max_age)
                        if orphaned:
                            self.logger.info(f'adding {len(orphaned)} unknown channel ids')
       -                    self.network.lngossip.add_new_ids(orphaned)
       +                    await self.network.lngossip.add_new_ids(orphaned)
                        if good:
                            self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds_chunk)}')
       -                    self.verify_channel_updates(good)
       -                    self.channel_db.update_policies(good, to_delete)
                    # refresh gui
                    if chan_anns or node_anns or chan_upds:
                        self.network.lngossip.refresh_gui()
        @@ -279,14 +276,6 @@ class Peer(Logger):
                    if not ecc.verify_signature(pubkey, signature, h):
                        raise Exception('signature failed')
        
       -    def verify_channel_updates(self, chan_upds):
       -        for payload in chan_upds:
       -            short_channel_id = payload['short_channel_id']
       -            if constants.net.rev_genesis_bytes() != payload['chain_hash']:
       -                raise Exception('wrong chain hash')
       -            if not verify_sig_for_channel_update(payload, payload['start_node']):
       -                raise BaseException('verify error')
       -
            async def query_gossip(self):
                try:
                    await asyncio.wait_for(self.initialized.wait(), 10)
        @@ -298,7 +287,7 @@ class Peer(Logger):
                    except asyncio.TimeoutError as e:
                        raise GracefulDisconnect("query_channel_range timed out") from e
                    self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
       -            self.lnworker.add_new_ids(ids)
       +            await self.lnworker.add_new_ids(ids)
                    while True:
                        todo = self.lnworker.get_ids_to_query()
                        if not todo:
        @@ -658,7 +647,7 @@ class Peer(Logger):
                )
                chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig)
                self.lnworker.save_channel(chan)
       -        self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
       +        await self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
                self.lnworker.on_channels_updated()
                while True:
                    try:
        @@ -862,8 +851,6 @@ class Peer(Logger):
                    bitcoin_key_2=bitcoin_keys[1]
                )
        
       -        print("SENT CHANNEL ANNOUNCEMENT")
       -
            def mark_open(self, chan: Channel):
                assert chan.short_channel_id is not None
                if chan.get_state() == "OPEN":
        @@ -872,6 +859,10 @@ class Peer(Logger):
                assert chan.config[LOCAL].funding_locked_received
                chan.set_state("OPEN")
                self.network.trigger_callback('channel', chan)
       +        asyncio.ensure_future(self.add_own_channel(chan))
       +        self.logger.info("CHANNEL OPENING COMPLETED")
       +
       +    async def add_own_channel(self, chan):
                # add channel to database
                bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
                sorted_node_ids = list(sorted(self.node_ids))
        @@ -887,7 +878,7 @@ class Peer(Logger):
                #   that the remote sends, even if the channel was not announced
                #   (from BOLT-07: "MAY create a channel_update to communicate the channel
                #    parameters to the final node, even though the channel has not yet been announced")
       -        self.channel_db.on_channel_announcement(
       +        self.channel_db.add_channel_announcement(
                    {
                        "short_channel_id": chan.short_channel_id,
                        "node_id_1": node_ids[0],
        @@ -922,8 +913,6 @@ class Peer(Logger):
                if pending_channel_update:
                    self.channel_db.add_channel_update(pending_channel_update)
        
       -        self.logger.info("CHANNEL OPENING COMPLETED")
       -
            def send_announcement_signatures(self, chan: Channel):
        
                bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
        @@ -962,36 +951,34 @@ class Peer(Logger):
            def on_update_fail_htlc(self, payload):
                channel_id = payload["channel_id"]
                htlc_id = int.from_bytes(payload["id"], "big")
       -        key = (channel_id, htlc_id)
       -        try:
       -            route = self.attempted_route[key]
       -        except KeyError:
       -            # the remote might try to fail an htlc after we restarted...
       -            # attempted_route is not persisted, so we will get here then
       -            self.logger.info("UPDATE_FAIL_HTLC. cannot decode! attempted route is MISSING. {}".format(key))
       -        else:
       -            try:
       -                self._handle_error_code_from_failed_htlc(payload["reason"], route, channel_id, htlc_id)
       -            except Exception:
       -                # exceptions are suppressed as failing to handle an error code
       -                # should not block us from removing the htlc
       -                traceback.print_exc(file=sys.stderr)
       -        # process update_fail_htlc on channel
                chan = self.channels[channel_id]
                chan.receive_fail_htlc(htlc_id)
                local_ctn = chan.get_current_ctn(LOCAL)
       -        asyncio.ensure_future(self._on_update_fail_htlc(chan, htlc_id, local_ctn))
       +        asyncio.ensure_future(self._handle_error_code_from_failed_htlc(payload, channel_id, htlc_id))
       +        asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn))
        
            @log_exceptions
       -    async def _on_update_fail_htlc(self, chan, htlc_id, local_ctn):
       +    async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn):
       +        chan = self.channels[channel_id]
                await self.await_local(chan, local_ctn)
                self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False)
        
       -    def _handle_error_code_from_failed_htlc(self, error_reason, route: List['RouteEdge'], channel_id, htlc_id):
       +    @log_exceptions
       +    async def _handle_error_code_from_failed_htlc(self, payload, channel_id, htlc_id):
                chan = self.channels[channel_id]
       -        failure_msg, sender_idx = decode_onion_error(error_reason,
       -                                                     [x.node_id for x in route],
       -                                                     chan.onion_keys[htlc_id])
       +        key = (channel_id, htlc_id)
       +        try:
       +            route = self.attempted_route[key]
       +        except KeyError:
       +            # the remote might try to fail an htlc after we restarted...
       +            # attempted_route is not persisted, so we will get here then
       +            self.logger.info("UPDATE_FAIL_HTLC. cannot decode! attempted route is MISSING. {}".format(key))
       +            return
       +        error_reason = payload["reason"]
       +        failure_msg, sender_idx = decode_onion_error(
       +            error_reason,
       +            [x.node_id for x in route],
       +            chan.onion_keys[htlc_id])
                code, data = failure_msg.code, failure_msg.data
                self.logger.info(f"UPDATE_FAIL_HTLC {repr(code)} {data}")
                self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}")
        @@ -1009,11 +996,9 @@ class Peer(Logger):
                    channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
                    message_type, payload = decode_msg(channel_update)
                    payload['raw'] = channel_update
       -            orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
       +            orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates([payload])
                    blacklist = False
                    if good:
       -                self.verify_channel_updates(good)
       -                self.channel_db.update_policies(good, to_delete)
                        self.logger.info("applied channel update on our db")
                    elif orphaned:
                        # maybe it is a private channel (and data in invoice was outdated)
        @@ -1276,11 +1261,17 @@ class Peer(Logger):
                self.logger.info("on_revoke_and_ack")
                channel_id = payload["channel_id"]
                chan = self.channels[channel_id]
       -        chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]))
       +        sweeptxs = chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]))
                self._remote_changed_events[chan.channel_id].set()
                self._remote_changed_events[chan.channel_id].clear()
                self.lnworker.save_channel(chan)
                self.maybe_send_commitment(chan)
       +        asyncio.ensure_future(self._on_revoke_and_ack(chan, sweeptxs))
       +
       +    async def _on_revoke_and_ack(self, chan, sweeptxs):
       +        outpoint = chan.funding_outpoint.to_str()
       +        for tx in sweeptxs:
       +            await self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx))
        
            def on_update_fee(self, payload):
                channel_id = payload["channel_id"]
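
        A pattern repeated throughout lnpeer.py in this commit: BOLT message
        handlers stay synchronous and mutate channel state immediately, then
        schedule their awaitable side effects (watcher writes, gossip queries)
        with asyncio.ensure_future, so the message pump is never blocked on
        I/O. Reduced to its shape, with assumed names:

            def on_revoke_and_ack_sketch(self, payload):
                sweeptxs = self.apply_revocation(payload)   # pure state change
                # deferred I/O: runs on the event loop after the handler returns
                asyncio.ensure_future(self._persist_sweeps(sweeptxs))

            async def _persist_sweeps(self, sweeptxs):
                for tx in sweeptxs:
                    await self.lnwatcher.add_sweep_tx(
                        self.outpoint, tx.prevout(0), str(tx))
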
        diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
        @@ -37,7 +37,7 @@ import binascii
        import base64
        
        from . import constants
       -from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
       +from .util import bh2u, profiler, get_headers_dir, is_ip_address, list_enabled_bits, print_msg, chunks
        from .logging import Logger
        from .storage import JsonDB
        from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
        @@ -169,7 +169,6 @@ class LNPathFinder(Logger):
                To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
                i.e. an element reads as, "to get to node_id, travel through short_channel_id"
                """
       -        self.channel_db.load_data()
                assert type(nodeA) is bytes
                assert type(nodeB) is bytes
                assert type(invoice_amount_msat) is int
        @@ -195,11 +194,12 @@ class LNPathFinder(Logger):
                        else:  # payment incoming, on our channel. (funny business, cycle weirdness)
                            assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
                            pass  # TODO?
       -            edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id,
       -                                                           start_node=edge_startnode,
       -                                                           end_node=edge_endnode,
       -                                                           payment_amt_msat=amount_msat,
       -                                                           ignore_costs=(edge_startnode == nodeA))
       +            edge_cost, fee_for_edge_msat = self._edge_cost(
       +                edge_channel_id,
       +                start_node=edge_startnode,
       +                end_node=edge_endnode,
       +                payment_amt_msat=amount_msat,
       +                ignore_costs=(edge_startnode == nodeA))
                    alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
                    if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                        distance_from_start[edge_startnode] = alt_dist_to_neighbour
       t@@ -219,9 +219,10 @@ class LNPathFinder(Logger):
                        continue
                    for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
                        assert type(edge_channel_id) is bytes
       -                if edge_channel_id in self.blacklist: continue
       +                if edge_channel_id in self.blacklist:
       +                    continue
                        channel_info = self.channel_db.get_channel_info(edge_channel_id)
       -                edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id)
       +                edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
                        inspect_edge()
                else:
                    return None  # no path found
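
The lnrouter hunks keep node ids as bytes end-to-end (dropping the
`bfh`/`bh2u` round-trips) and reformat the `_edge_cost` call; the loop
around them is ordinary Dijkstra run backwards from the destination, so
each edge relaxes the start node's distance using the end node's. A
minimal sketch of the relaxation step visible above (`relax`, `distance`
and `prev_edge` are illustrative names; the real code uses
`distance_from_start` and records the path through channel ids):

    from typing import Dict

    def relax(distance: Dict[bytes, float], prev_edge: Dict[bytes, bytes],
              edge_startnode: bytes, edge_endnode: bytes,
              edge_channel_id: bytes, edge_cost: float) -> None:
        # an entry reads: "to get to edge_endnode, travel through edge_channel_id"
        alt = distance.get(edge_endnode, float('inf')) + edge_cost
        if alt < distance.get(edge_startnode, float('inf')):
            distance[edge_startnode] = alt
            prev_edge[edge_startnode] = edge_channel_id
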
   DIR diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
       t@@ -70,11 +70,11 @@ class SweepStore(SqlDB):
            @sql
            def get_tx_by_index(self, funding_outpoint, index):
                r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
       -        return r.prevout, bh2u(r.tx)
       +        return str(r.prevout), bh2u(r.tx)
        
            @sql
            def list_sweep_tx(self):
       -        return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
       +        return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all())
        
            @sql
            def add_sweep_tx(self, funding_outpoint, prevout, tx):
       t@@ -84,7 +84,7 @@ class SweepStore(SqlDB):
        
            @sql
            def get_num_tx(self, funding_outpoint):
       -        return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
        +        return int(self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).count())
        
            @sql
            def remove_sweep_tx(self, funding_outpoint):
       t@@ -111,11 +111,11 @@ class SweepStore(SqlDB):
            @sql
            def get_address(self, outpoint):
                r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
       -        return r.address if r else None
       +        return str(r.address) if r else None
        
            @sql
            def list_channel_info(self):
       -        return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
       +        return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()]
        
        
        class LNWatcher(AddressSynchronizer):
       t@@ -150,14 +150,21 @@ class LNWatcher(AddressSynchronizer):
                self.watchtower_queue = asyncio.Queue()
        
            def get_num_tx(self, outpoint):
       -        return self.sweepstore.get_num_tx(outpoint)
       +        async def f():
       +            return await self.sweepstore.get_num_tx(outpoint)
       +        return self.network.run_from_another_thread(f())
       +
       +    def list_sweep_tx(self):
       +        async def f():
       +            return await self.sweepstore.list_sweep_tx()
       +        return self.network.run_from_another_thread(f())
        
            @ignore_exceptions
            @log_exceptions
            async def watchtower_task(self):
                self.logger.info('watchtower task started')
                # initial check
       -        for address, outpoint in self.sweepstore.list_channel_info():
       +        for address, outpoint in await self.sweepstore.list_channel_info():
                    await self.watchtower_queue.put(outpoint)
                while True:
                    outpoint = await self.watchtower_queue.get()
       t@@ -165,30 +172,30 @@ class LNWatcher(AddressSynchronizer):
                        continue
                    # synchronize with remote
                    try:
       -                local_n = self.sweepstore.get_num_tx(outpoint)
       +                local_n = await self.sweepstore.get_num_tx(outpoint)
                        n = self.watchtower.get_num_tx(outpoint)
                        if n == 0:
       -                    address = self.sweepstore.get_address(outpoint)
       +                    address = await self.sweepstore.get_address(outpoint)
                            self.watchtower.add_channel(outpoint, address)
                        self.logger.info("sending %d transactions to watchtower"%(local_n - n))
                        for index in range(n, local_n):
       -                    prevout, tx = self.sweepstore.get_tx_by_index(outpoint, index)
       +                    prevout, tx = await self.sweepstore.get_tx_by_index(outpoint, index)
                            self.watchtower.add_sweep_tx(outpoint, prevout, tx)
                    except ConnectionRefusedError:
                        self.logger.info('could not reach watchtower, will retry in 5s')
                        await asyncio.sleep(5)
                        await self.watchtower_queue.put(outpoint)
        
       -    def add_channel(self, outpoint, address):
       +    async def add_channel(self, outpoint, address):
                self.add_address(address)
                with self.lock:
       -            if not self.sweepstore.has_channel(outpoint):
       -                self.sweepstore.add_channel(outpoint, address)
       +            if not await self.sweepstore.has_channel(outpoint):
       +                await self.sweepstore.add_channel(outpoint, address)
        
       -    def unwatch_channel(self, address, funding_outpoint):
       +    async def unwatch_channel(self, address, funding_outpoint):
                self.logger.info(f'unwatching {funding_outpoint}')
       -        self.sweepstore.remove_sweep_tx(funding_outpoint)
       -        self.sweepstore.remove_channel(funding_outpoint)
       +        await self.sweepstore.remove_sweep_tx(funding_outpoint)
       +        await self.sweepstore.remove_channel(funding_outpoint)
                if funding_outpoint in self.tx_progress:
                    self.tx_progress[funding_outpoint].all_done.set()
        
       t@@ -202,7 +209,7 @@ class LNWatcher(AddressSynchronizer):
                    return
                if not self.synchronizer.is_up_to_date():
                    return
       -        for address, outpoint in self.sweepstore.list_channel_info():
       +        for address, outpoint in await self.sweepstore.list_channel_info():
                    await self.check_onchain_situation(address, outpoint)
        
            async def check_onchain_situation(self, address, funding_outpoint):
       t@@ -223,7 +230,7 @@ class LNWatcher(AddressSynchronizer):
                                                  closing_height, closing_tx)  # FIXME sooo many args..
                    await self.do_breach_remedy(funding_outpoint, spenders)
                if not keep_watching:
       -            self.unwatch_channel(address, funding_outpoint)
       +            await self.unwatch_channel(address, funding_outpoint)
                else:
                    #self.logger.info(f'we will keep_watching {funding_outpoint}')
                    pass
       t@@ -260,7 +267,7 @@ class LNWatcher(AddressSynchronizer):
                for prevout, spender in spenders.items():
                    if spender is not None:
                        continue
       -            sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prevout)
       +            sweep_txns = await self.sweepstore.get_sweep_tx(funding_outpoint, prevout)
                    for tx in sweep_txns:
                        if not await self.broadcast_or_log(funding_outpoint, tx):
                            self.logger.info(f'{tx.name} could not publish tx: {str(tx)}, prevout: {prevout}')
       t@@ -279,8 +286,8 @@ class LNWatcher(AddressSynchronizer):
                        await self.tx_progress[funding_outpoint].tx_queue.put(tx)
                    return txid
        
       -    def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str):
       -        self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx)
       +    async def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str):
       +        await self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx)
                if self.watchtower:
                    self.watchtower_queue.put_nowait(funding_outpoint)
        
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -108,12 +108,14 @@ class LNWorker(Logger):
        
            @log_exceptions
            async def main_loop(self):
       +        # fixme: only lngossip should do that
       +        await self.channel_db.load_data()
                while True:
                    await asyncio.sleep(1)
                    now = time.time()
                    if len(self.peers) >= NUM_PEERS_TARGET:
                        continue
       -            peers = self._get_next_peers_to_try()
       +            peers = await self._get_next_peers_to_try()
                    for peer in peers:
                        last_tried = self._last_tried_peer.get(peer, 0)
                        if last_tried + PEER_RETRY_INTERVAL < now:
       t@@ -130,7 +132,8 @@ class LNWorker(Logger):
                peer = Peer(self, node_id, transport)
                await self.network.main_taskgroup.spawn(peer.main_loop())
                self.peers[node_id] = peer
       -        self.network.lngossip.refresh_gui()
       +        #if self.network.lngossip:
       +        #    self.network.lngossip.refresh_gui()
                return peer
        
            def start_network(self, network: 'Network'):
       t@@ -148,7 +151,7 @@ class LNWorker(Logger):
                        self._add_peer(host, int(port), bfh(pubkey)),
                        self.network.asyncio_loop)
        
       -    def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
       +    async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
                now = time.time()
                recent_peers = self.channel_db.get_recent_peers()
                # maintenance for last tried times
       t@@ -158,19 +161,22 @@ class LNWorker(Logger):
                        del self._last_tried_peer[peer]
                # first try from recent peers
                for peer in recent_peers:
       -            if peer.pubkey in self.peers: continue
       -            if peer in self._last_tried_peer: continue
       +            if peer.pubkey in self.peers:
       +                continue
       +            if peer in self._last_tried_peer:
       +                continue
                    return [peer]
                # try random peer from graph
                unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
                if unconnected_nodes:
       -            for node in unconnected_nodes:
       -                addrs = self.channel_db.get_node_addresses(node)
       +            for node_id in unconnected_nodes:
       +                addrs = self.channel_db.get_node_addresses(node_id)
                        if not addrs:
                            continue
       -                host, port = self.choose_preferred_address(addrs)
       -                peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id))
       -                if peer in self._last_tried_peer: continue
       +                host, port, timestamp = self.choose_preferred_address(addrs)
       +                peer = LNPeerAddr(host, port, node_id)
       +                if peer in self._last_tried_peer:
       +                    continue
                        #self.logger.info('taking random ln peer from our channel db')
                        return [peer]
        
       t@@ -223,15 +229,13 @@ class LNWorker(Logger):
            def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
                assert len(addr_list) >= 1
                # choose first one that is an IP
       -        for addr_in_db in addr_list:
       -            host = addr_in_db.host
       -            port = addr_in_db.port
       +        for host, port, timestamp in addr_list:
                    if is_ip_address(host):
       -                return host, port
       +                return host, port, timestamp
                # otherwise choose one at random
                # TODO maybe filter out onion if not on tor?
                choice = random.choice(addr_list)
       -        return choice.host, choice.port
       +        return choice
        
        
        class LNGossip(LNWorker):
       t@@ -260,26 +264,19 @@ class LNGossip(LNWorker):
                self.network.trigger_callback('ln_status', num_peers, num_nodes, known, unknown)
        
            async def maintain_db(self):
       -        n = self.channel_db.get_orphaned_channels()
       -        if n:
       -            self.logger.info(f'Deleting {n} orphaned channels')
       -            self.channel_db.prune_orphaned_channels()
       -            self.refresh_gui()
       +        self.channel_db.prune_orphaned_channels()
                while True:
       -            n = self.channel_db.get_old_policies(self.max_age)
       -            if n:
       -                self.logger.info(f'Deleting {n} old channels')
       -                self.channel_db.prune_old_policies(self.max_age)
       -                self.refresh_gui()
       +            self.channel_db.prune_old_policies(self.max_age)
       +            self.refresh_gui()
                    await asyncio.sleep(5)
        
       -    def add_new_ids(self, ids):
       -        known = self.channel_db.known_ids()
       +    async def add_new_ids(self, ids):
       +        known = self.channel_db.get_channel_ids()
                new = set(ids) - set(known)
                self.unknown_ids.update(new)
        
            def get_ids_to_query(self):
       -        N = 500
       +        N = 100
                l = list(self.unknown_ids)
                self.unknown_ids = set(l[N:])
                return l[0:N]
       t@@ -324,9 +321,10 @@ class LNWallet(LNWorker):
                self.network.register_callback(self.on_network_update, ['wallet_updated', 'network_updated', 'verified', 'fee'])  # thread safe
                self.network.register_callback(self.on_channel_open, ['channel_open'])
                self.network.register_callback(self.on_channel_closed, ['channel_closed'])
       +
                for chan_id, chan in self.channels.items():
       -            self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
       -            chan.lnwatcher = network.lnwatcher
       +            self.network.lnwatcher.add_address(chan.get_funding_address())
       +
                super().start_network(network)
                for coro in [
                        self.maybe_listen(),
       t@@ -494,7 +492,7 @@ class LNWallet(LNWorker):
                chan = self.channel_by_txo(funding_outpoint)
                if not chan:
                    return
       -        self.logger.debug(f'on_channel_open {funding_outpoint}')
       +        #self.logger.debug(f'on_channel_open {funding_outpoint}')
                self.channel_timestamps[bh2u(chan.channel_id)] = funding_txid, funding_height.height, funding_height.timestamp, None, None, None
                self.storage.put('lightning_channel_timestamps', self.channel_timestamps)
                chan.set_funding_txo_spentness(False)
       t@@ -606,7 +604,8 @@ class LNWallet(LNWorker):
                            self.logger.info('REBROADCASTING CLOSING TX')
                            await self.force_close_channel(chan.channel_id)
        
       -    async def _open_channel_coroutine(self, peer, local_amount_sat, push_sat, password):
       +    async def _open_channel_coroutine(self, connect_str, local_amount_sat, push_sat, password):
       +        peer = await self.add_peer(connect_str)
                # peer might just have been connected to
                await asyncio.wait_for(peer.initialized.wait(), 5)
                chan = await peer.channel_establishment_flow(
       t@@ -615,24 +614,22 @@ class LNWallet(LNWorker):
                    push_msat=push_sat * 1000,
                    temp_channel_id=os.urandom(32))
                self.save_channel(chan)
       -        self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
       +        self.network.lnwatcher.add_address(chan.get_funding_address())
       +        await self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
                self.on_channels_updated()
                return chan
        
            def on_channels_updated(self):
                self.network.trigger_callback('channels')
        
       -    def add_peer(self, connect_str, timeout=20):
       +    async def add_peer(self, connect_str, timeout=20):
                node_id, rest = extract_nodeid(connect_str)
                peer = self.peers.get(node_id)
                if not peer:
                    if rest is not None:
                        host, port = split_host_port(rest)
                    else:
       -                node_info = self.network.channel_db.nodes_get(node_id)
       -                if not node_info:
       -                    raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
       -                addrs = self.channel_db.get_node_addresses(node_info)
       +                addrs = self.channel_db.get_node_addresses(node_id)
                        if len(addrs) == 0:
                            raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
                        host, port = self.choose_preferred_address(addrs)
       t@@ -640,18 +637,12 @@ class LNWallet(LNWorker):
                        socket.getaddrinfo(host, int(port))
                    except socket.gaierror:
                        raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
       -            peer_future = asyncio.run_coroutine_threadsafe(
       -                self._add_peer(host, port, node_id),
       -                self.network.asyncio_loop)
       -            try:
       -                peer = peer_future.result(timeout)
       -            except concurrent.futures.TimeoutError:
       -                raise Exception(_("add_peer timed out"))
       +            # add peer
       +            peer = await self._add_peer(host, port, node_id)
                return peer
        
            def open_channel(self, connect_str, local_amt_sat, push_amt_sat, password=None, timeout=20):
       -        peer = self.add_peer(connect_str, timeout)
       -        coro = self._open_channel_coroutine(peer, local_amt_sat, push_amt_sat, password)
       +        coro = self._open_channel_coroutine(connect_str, local_amt_sat, push_amt_sat, password)
                fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
                try:
                    chan = fut.result(timeout=timeout)
       t@@ -664,6 +655,9 @@ class LNWallet(LNWorker):
                Can be called from other threads
                Raises timeout exception if htlc is not fulfilled
                """
       +        addr = self._check_invoice(invoice, amount_sat)
       +        self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False)
       +        self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description())
                fut = asyncio.run_coroutine_threadsafe(
                    self._pay(invoice, attempts, amount_sat),
                    self.network.asyncio_loop)
       t@@ -680,8 +674,6 @@ class LNWallet(LNWorker):
        
            async def _pay(self, invoice, attempts=1, amount_sat=None):
                addr = self._check_invoice(invoice, amount_sat)
       -        self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False)
       -        self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description())
                for i in range(attempts):
                    route = await self._create_route_from_invoice(decoded_invoice=addr)
                    if not self.get_channel_by_short_id(route[0].short_channel_id):
       t@@ -691,7 +683,7 @@ class LNWallet(LNWorker):
                        return True
                return False
        
       -    async def _pay_to_route(self, route, addr, pay_req):
       +    async def _pay_to_route(self, route, addr, invoice):
                short_channel_id = route[0].short_channel_id
                chan = self.get_channel_by_short_id(short_channel_id)
                if not chan:
       t@@ -713,6 +705,9 @@ class LNWallet(LNWorker):
                    raise InvoiceError("{}\n{}".format(
                        _("Invoice wants us to risk locking funds for unreasonably long."),
                        f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
       +        #now = int(time.time())
        +        #if addr.date + addr.get_expiry() < now:
       +        #    raise InvoiceError(_('Invoice expired'))
                return addr
        
            async def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
       t@@ -730,11 +725,14 @@ class LNWallet(LNWorker):
                with self.lock:
                    channels = list(self.channels.values())
                for private_route in r_tags:
       -            if len(private_route) == 0: continue
       -            if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: continue
       +            if len(private_route) == 0:
       +                continue
       +            if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
       +                continue
                    border_node_pubkey = private_route[0][0]
                    path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
       -            if not path: continue
       +            if not path:
       +                continue
                    route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
                    # we need to shift the node pubkey by one towards the destination:
                    private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
       t@@ -770,10 +768,18 @@ class LNWallet(LNWorker):
                return route
        
            def add_invoice(self, amount_sat, message):
       +        coro = self._add_invoice_coro(amount_sat, message)
       +        fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
       +        try:
       +            return fut.result(timeout=5)
       +        except concurrent.futures.TimeoutError:
       +            raise Exception(_("add_invoice timed out"))
       +
       +    async def _add_invoice_coro(self, amount_sat, message):
                payment_preimage = os.urandom(32)
                payment_hash = sha256(payment_preimage)
                amount_btc = amount_sat/Decimal(COIN) if amount_sat else None
       -        routing_hints = self._calc_routing_hints_for_invoice(amount_sat)
       +        routing_hints = await self._calc_routing_hints_for_invoice(amount_sat)
                if not routing_hints:
                    self.logger.info("Warning. No routing hints added to invoice. "
                                     "Other clients will likely not be able to send to us.")
       t@@ -847,19 +853,20 @@ class LNWallet(LNWorker):
                    })
                return out
        
       -    def _calc_routing_hints_for_invoice(self, amount_sat):
       +    async def _calc_routing_hints_for_invoice(self, amount_sat):
                """calculate routing hints (BOLT-11 'r' field)"""
       -        self.channel_db.load_data()
                routing_hints = []
                with self.lock:
                    channels = list(self.channels.values())
                # note: currently we add *all* our channels; but this might be a privacy leak?
                for chan in channels:
                    # check channel is open
       -            if chan.get_state() != "OPEN": continue
       +            if chan.get_state() != "OPEN":
       +                continue
                    # check channel has sufficient balance
                    # FIXME because of on-chain fees of ctx, this check is insufficient
       -            if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat: continue
       +            if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat:
       +                continue
                    chan_id = chan.short_channel_id
                    assert type(chan_id) is bytes, chan_id
                    channel_info = self.channel_db.get_channel_info(chan_id)
       t@@ -949,14 +956,10 @@ class LNWallet(LNWorker):
                        await self._add_peer(peer.host, peer.port, peer.pubkey)
                        return
                # try random address for node_id
       -        node_info = self.channel_db.nodes_get(chan.node_id)
       -        if not node_info:
       -            return
       -        addresses = self.channel_db.get_node_addresses(node_info)
       +        addresses = self.channel_db.get_node_addresses(chan.node_id)
                if not addresses:
                    return
       -        adr_obj = random.choice(addresses)
       -        host, port = adr_obj.host, adr_obj.port
       +        host, port, t = random.choice(list(addresses))
                peer = LNPeerAddr(host, port, chan.node_id)
                last_tried = self._last_tried_peer.get(peer, 0)
                if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
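
Several public LNWallet methods above (`open_channel`, `pay`, `add_invoice`)
keep a synchronous signature and schedule a private coroutine with
`asyncio.run_coroutine_threadsafe`, translating a timeout into an
exception. A compact sketch of that facade (class and method names are
illustrative, and the 5s timeout simply mirrors `_add_invoice_coro`):

    import asyncio
    import concurrent.futures

    class WorkerFacade:
        # `loop` runs in the network thread; do_work may be called from any thread
        def __init__(self, loop: asyncio.AbstractEventLoop):
            self.loop = loop

        async def _do_work(self, arg):
            await asyncio.sleep(0)  # placeholder for the real coroutine
            return arg

        def do_work(self, arg, timeout=5):
            fut = asyncio.run_coroutine_threadsafe(self._do_work(arg), self.loop)
            try:
                return fut.result(timeout=timeout)
            except concurrent.futures.TimeoutError:
                raise Exception("do_work timed out")
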
   DIR diff --git a/electrum/sql_db.py b/electrum/sql_db.py
       t@@ -2,6 +2,7 @@ import os
        import concurrent
        import queue
        import threading
       +import asyncio
        
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool
       t@@ -18,28 +19,32 @@ def sql(func):
            """wrapper for sql methods"""
            def wrapper(self, *args, **kwargs):
                assert threading.currentThread() != self.sql_thread
       -        f = concurrent.futures.Future()
       +        f = asyncio.Future()
                self.db_requests.put((f, func, args, kwargs))
       -        return f.result(timeout=10)
       +        return f
            return wrapper
        
        class SqlDB(Logger):
            
       -    def __init__(self, network, path, base):
       +    def __init__(self, network, path, base, commit_interval=None):
                Logger.__init__(self)
                self.base = base
                self.network = network
                self.path = path
       +        self.commit_interval = commit_interval
                self.db_requests = queue.Queue()
                self.sql_thread = threading.Thread(target=self.run_sql)
                self.sql_thread.start()
        
            def run_sql(self):
       +        #return
       +        self.logger.info("SQL thread started")
                engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
                DBSession = sessionmaker(bind=engine, autoflush=False)
       -        self.DBSession = DBSession()
                if not os.path.exists(self.path):
                    self.base.metadata.create_all(engine)
       +        self.DBSession = DBSession()
       +        i = 0
                while self.network.asyncio_loop.is_running():
                    try:
                        future, func, args, kwargs = self.db_requests.get(timeout=0.1)
       t@@ -50,7 +55,14 @@ class SqlDB(Logger):
                    except BaseException as e:
                        future.set_exception(e)
                        continue
       -            future.set_result(result)
       +            if not future.cancelled():
       +                future.set_result(result)
       +            # note: in sweepstore session.commit() is called inside
        +            # the sql-decorated methods, so committing to disk is awaited
       +            if self.commit_interval:
       +                i = (i + 1) % self.commit_interval
       +                if i == 0:
       +                    self.DBSession.commit()
                # write
                self.DBSession.commit()
                self.logger.info("SQL thread terminated")
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -16,7 +16,8 @@ from electrum.lnpeer import Peer
        from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
        from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
        from electrum.lnutil import PaymentFailure, LnLocalFeatures
       -from electrum.lnrouter import ChannelDB, LNPathFinder
       +from electrum.lnrouter import LNPathFinder
       +from electrum.channel_db import ChannelDB
        from electrum.lnworker import LNWallet
        from electrum.lnmsg import encode_msg, decode_msg
        from electrum.logging import console_stderr_handler
   DIR diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py
       t@@ -59,33 +59,33 @@ class Test_LNRouter(TestCaseForTestnet):
                cdb = fake_network.channel_db
                path_finder = lnrouter.LNPathFinder(cdb)
                self.assertEqual(cdb.num_channels, 0)
       -        cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
       +        cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
                                             'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02cccccccccccccccccccccccccccccccc',
                                             'short_channel_id': bfh('0000000000000001'),
                                             'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                             'len': b'\x00\x00', 'features': b''}, trusted=True)
                self.assertEqual(cdb.num_channels, 1)
       -        cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
       +        cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
                                             'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
                                             'short_channel_id': bfh('0000000000000002'),
                                             'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                             'len': b'\x00\x00', 'features': b''}, trusted=True)
       -        cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
       +        cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
                                             'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
                                             'short_channel_id': bfh('0000000000000003'),
                                             'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                             'len': b'\x00\x00', 'features': b''}, trusted=True)
       -        cdb.on_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
       +        cdb.add_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
                                             'bitcoin_key_1': b'\x02cccccccccccccccccccccccccccccccc', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd',
                                             'short_channel_id': bfh('0000000000000004'),
                                             'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                             'len': b'\x00\x00', 'features': b''}, trusted=True)
       -        cdb.on_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
       +        cdb.add_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
                                             'bitcoin_key_1': b'\x02dddddddddddddddddddddddddddddddd', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
                                             'short_channel_id': bfh('0000000000000005'),
                                             'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                             'len': b'\x00\x00', 'features': b''}, trusted=True)
       -        cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
       +        cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
                                             'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd',
                                             'short_channel_id': bfh('0000000000000006'),
                                             'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),