URI: 
       tlnwatcher: save sweepstore in sqlite database - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit b861e2e955c4a790d8e2b4ce262b894a67c3b470
   DIR parent bfdf0a7e8823b8f250df6d7fb8d691dea689b2a5
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Tue,  5 Mar 2019 17:28:24 +0100
       
       lnwatcher: save sweepstore in sqlite database
       
       Diffstat:
         M electrum/gui/qt/watchtower_window.… |       6 ++++--
         M electrum/lnwatcher.py               |     161 +++++++++++++++++++++++--------
       
       2 files changed, 124 insertions(+), 43 deletions(-)
       ---
   DIR diff --git a/electrum/gui/qt/watchtower_window.py b/electrum/gui/qt/watchtower_window.py
       t@@ -52,9 +52,11 @@ class WatcherList(MyTreeView):
            def update(self):
                self.model().clear()
                self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
       -        for outpoint, sweep_dict in self.parent.lnwatcher.sweepstore.items():
       +        sweepstore = self.parent.lnwatcher.sweepstore
       +        for outpoint in sweepstore.list_sweep_tx():
       +            n = sweepstore.num_sweep_tx(outpoint)
                    status = self.parent.lnwatcher.get_channel_status(outpoint)
       -            items = [QStandardItem(e) for e in [outpoint, "%d"%len(sweep_dict), status]]
       +            items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
                    self.model().insertRow(self.model().rowCount(), items)
        
        
   DIR diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
       t@@ -2,9 +2,11 @@
        # Distributed under the MIT software license, see the accompanying
        # file LICENCE or http://www.opensource.org/licenses/mit-license.php
        
       -import threading
        from typing import NamedTuple, Iterable, TYPE_CHECKING
        import os
       +import queue
       +import threading
       +import concurrent
        from collections import defaultdict
        import asyncio
        from enum import IntEnum, auto
       t@@ -35,27 +37,125 @@ class TxMinedDepth(IntEnum):
            FREE = auto()
        
        
       +from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
       +from sqlalchemy.pool import StaticPool
       +from sqlalchemy.orm import sessionmaker
       +from sqlalchemy.orm.query import Query
       +from sqlalchemy.ext.declarative import declarative_base
       +from sqlalchemy.sql import not_, or_
       +from sqlalchemy.orm import scoped_session
       +
       +Base = declarative_base()
       +
       +class SweepTx(Base):
       +    __tablename__ = 'sweep_txs'
       +    funding_outpoint = Column(String(34))
       +    prev_txid = Column(String(32))
       +    tx = Column(String())
       +    txid = Column(String(32), primary_key=True) # txid of tx
       +
       +class ChannelInfo(Base):
       +    __tablename__ = 'channel_info'
       +    address = Column(String(32), primary_key=True)
       +    outpoint = Column(String(34))
       +
       +
       +class SweepStore(PrintError):
       +
       +    def __init__(self, path, network):
       +        PrintError.__init__(self)
       +        self.path = path
       +        self.network = network
       +        self.db_requests = queue.Queue()
       +        threading.Thread(target=self.sql_thread).start()
       +
       +    def sql_thread(self):
       +        engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
       +        DBSession = sessionmaker(bind=engine, autoflush=False)
       +        self.DBSession = DBSession()
       +        if not os.path.exists(self.path):
       +            Base.metadata.create_all(engine)
       +        while self.network.asyncio_loop.is_running():
       +            try:
       +                future, func, args, kwargs = self.db_requests.get(timeout=0.1)
       +            except queue.Empty:
       +                continue
       +            try:
       +                result = func(self, *args, **kwargs)
       +            except BaseException as e:
       +                future.set_exception(e)
       +                continue
       +            future.set_result(result)
       +        # write
       +        self.DBSession.commit()
       +        self.print_error("SQL thread terminated")
       +
       +    def sql(func):
       +        def wrapper(self, *args, **kwargs):
       +            f = concurrent.futures.Future()
       +            self.db_requests.put((f, func, args, kwargs))
       +            return f.result(timeout=10)
       +        return wrapper
       +
       +    @sql
       +    def get_sweep_tx(self, funding_outpoint, prev_txid):
       +        return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()]
       +
       +    @sql
       +    def list_sweep_tx(self):
       +        return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
       +
       +    @sql
       +    def add_sweep_tx(self, funding_outpoint, prev_txid, tx):
       +        self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid()))
       +        self.DBSession.commit()
       +
       +    @sql
       +    def num_sweep_tx(self, funding_outpoint):
       +        return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
       +
       +    @sql
       +    def remove_sweep_tx(self, funding_outpoint):
       +        v = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
       +        self.DBSession.delete(v)
       +        self.DBSession.commit()
       +
       +    @sql
       +    def add_channel_info(self, address, outpoint):
       +        self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
       +        self.DBSession.commit()
       +
       +    @sql
       +    def remove_channel_info(self, address):
       +        v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
       +        self.DBSession.delete(v)
       +        self.DBSession.commit()
       +
       +    @sql
       +    def has_channel_info(self, address):
       +        return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none())
       +
       +    @sql
       +    def get_channel_info(self, address):
       +        r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
       +        return r.outpoint if r else None
       +
       +    @sql
       +    def list_channel_info(self):
       +        return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
       +
       +
        class LNWatcher(AddressSynchronizer):
            verbosity_filter = 'W'
        
            def __init__(self, network: 'Network'):
       -        path = os.path.join(network.config.path, "watcher_db")
       +        path = os.path.join(network.config.path, "watchtower_wallet")
                storage = WalletStorage(path)
                AddressSynchronizer.__init__(self, storage)
                self.config = network.config
                self.start_network(network)
                self.lock = threading.RLock()
       -        self.channel_info = storage.get('channel_info', {})  # access with 'lock'
       -        # [funding_outpoint_str][prev_txid] -> set of Transaction
       -        # prev_txid is the txid of a tx that is watched for confirmations
       -        # access with 'lock'
       -        self.sweepstore = defaultdict(lambda: defaultdict(set))
       -        for funding_outpoint, ctxs in storage.get('sweepstore', {}).items():
       -            for txid, set_of_txns in ctxs.items():
       -                for tx in set_of_txns:
       -                    tx2 = Transaction.from_dict(tx)
       -                    self.sweepstore[funding_outpoint][txid].add(tx2)
       -
       +        self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network)
                self.network.register_callback(self.on_network_update,
                                               ['network_updated', 'blockchain_updated', 'verified', 'wallet_updated'])
                self.set_remote_watchtower()
       t@@ -97,34 +197,18 @@ class LNWatcher(AddressSynchronizer):
                        await asyncio.sleep(5)
                        await self.watchtower_queue.put((name, args, kwargs))
        
       -    def write_to_disk(self):
       -        # FIXME: json => every update takes linear instead of constant disk write
       -        with self.lock:
       -            storage = self.storage
       -            storage.put('channel_info', self.channel_info)
       -            # self.sweepstore
       -            sweepstore = {}
       -            for funding_outpoint, ctxs in self.sweepstore.items():
       -                sweepstore[funding_outpoint] = {}
       -                for prev_txid, set_of_txns in ctxs.items():
       -                    sweepstore[funding_outpoint][prev_txid] = [tx.as_dict() for tx in set_of_txns]
       -            storage.put('sweepstore', sweepstore)
       -        storage.write()
        
            @with_watchtower
            def watch_channel(self, address, outpoint):
                self.add_address(address)
                with self.lock:
       -            if address not in self.channel_info:
       -                self.channel_info[address] = outpoint
       -            self.write_to_disk()
       +            if not self.sweepstore.has_channel_info(address):
       +                self.sweepstore.add_channel_info(address, outpoint)
        
            def unwatch_channel(self, address, funding_outpoint):
                self.print_error('unwatching', funding_outpoint)
       -        with self.lock:
       -            self.channel_info.pop(address)
       -            self.sweepstore.pop(funding_outpoint)
       -            self.write_to_disk()
       +        self.sweepstore.remove_sweep_tx(funding_outpoint)
       +        self.sweepstore.remove_channel_info(address)
                if funding_outpoint in self.tx_progress:
                    self.tx_progress[funding_outpoint].all_done.set()
        
       t@@ -138,9 +222,7 @@ class LNWatcher(AddressSynchronizer):
                    return
                if not self.synchronizer.is_up_to_date():
                    return
       -        with self.lock:
       -            channel_info_items = list(self.channel_info.items())
       -        for address, outpoint in channel_info_items:
       +        for address, outpoint in self.sweepstore.list_channel_info():
                    await self.check_onchain_situation(address, outpoint)
        
            async def check_onchain_situation(self, address, funding_outpoint):
       t@@ -192,8 +274,7 @@ class LNWatcher(AddressSynchronizer):
                    if spender is not None:
                        continue
                    prev_txid, prev_n = prevout.split(':')
       -            with self.lock:
       -                sweep_txns = self.sweepstore[funding_outpoint][prev_txid]
       +            sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid)
                    for tx in sweep_txns:
                        if not await self.broadcast_or_log(funding_outpoint, tx):
                            self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}')
       t@@ -215,9 +296,7 @@ class LNWatcher(AddressSynchronizer):
            @with_watchtower
            def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
                tx = Transaction.from_dict(tx_dict)
       -        with self.lock:
       -            self.sweepstore[funding_outpoint][prev_txid].add(tx)
       -        self.write_to_disk()
       +        self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx)
        
            def get_tx_mined_depth(self, txid: str):
                if not txid: