URI: 
       tbasic watchtower synchronization - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 3abe30e9d8a383fe88bd9edd23b3be749e9c1c03
   DIR parent c155293166a315a3479bfc9af66b05cb8ef22696
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Tue, 12 Mar 2019 18:33:36 +0100
       
       basic watchtower synchronization
       
       Diffstat:
         M electrum/daemon.py                  |       3 ++-
         M electrum/gui/qt/watchtower_window.… |       2 +-
         M electrum/lnchannel.py               |       2 +-
         M electrum/lnpeer.py                  |       2 +-
         M electrum/lnwatcher.py               |      82 +++++++++++++++++--------------
         M electrum/lnworker.py                |       4 ++--
       
       6 files changed, 53 insertions(+), 42 deletions(-)
       ---
   DIR diff --git a/electrum/daemon.py b/electrum/daemon.py
       t@@ -135,7 +135,8 @@ class WatchTower(DaemonThread):
                port = self.config.get('watchtower_port', 12345)
                server = SimpleJSONRPCServer((host, port), logRequests=True)
                server.register_function(self.lnwatcher.add_sweep_tx, 'add_sweep_tx')
       -        server.register_function(self.lnwatcher.watch_channel, 'watch_channel')
       +        server.register_function(self.lnwatcher.add_channel, 'add_channel')
       +        server.register_function(self.lnwatcher.get_num_tx, 'get_num_tx')
                server.timeout = 0.1
                while self.is_running():
                    server.handle_request()
   DIR diff --git a/electrum/gui/qt/watchtower_window.py b/electrum/gui/qt/watchtower_window.py
       t@@ -54,7 +54,7 @@ class WatcherList(MyTreeView):
                self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
                sweepstore = self.parent.lnwatcher.sweepstore
                for outpoint in sweepstore.list_sweep_tx():
       -            n = sweepstore.num_sweep_tx(outpoint)
       +            n = sweepstore.get_num_tx(outpoint)
                    status = self.parent.lnwatcher.get_channel_status(outpoint)
                    items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
                    self.model().insertRow(self.model().rowCount(), items)
   DIR diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
       t@@ -464,7 +464,7 @@ class Channel(PrintError):
                sweeptxs = create_sweeptxs_for_their_just_revoked_ctx(self, ctx, per_commitment_secret, self.sweep_address)
                for prev_txid, tx in sweeptxs.items():
                    if tx is not None:
       -                self.lnwatcher.add_sweep_tx(outpoint, prev_txid, tx.as_dict())
       +                self.lnwatcher.add_sweep_tx(outpoint, prev_txid, str(tx))
        
            def receive_revocation(self, revocation: RevokeAndAck):
                self.print_error("receive_revocation")
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -490,7 +490,7 @@ class Peer(PrintError):
                )
                chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig)
                self.lnworker.save_channel(chan)
       -        self.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str())
       +        self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
                self.lnworker.on_channels_updated()
                while True:
                    try:
   DIR diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
       t@@ -46,15 +46,15 @@ Base = declarative_base()
        
        class SweepTx(Base):
            __tablename__ = 'sweep_txs'
       -    funding_outpoint = Column(String(34))
       +    funding_outpoint = Column(String(34), primary_key=True)
       +    index = Column(Integer(), primary_key=True)
            prev_txid = Column(String(32))
            tx = Column(String())
       -    txid = Column(String(32), primary_key=True) # txid of tx
        
        class ChannelInfo(Base):
            __tablename__ = 'channel_info'
       -    address = Column(String(32), primary_key=True)
       -    outpoint = Column(String(34))
       +    outpoint = Column(String(34), primary_key=True)
       +    address = Column(String(32))
        
        
        
       t@@ -68,16 +68,22 @@ class SweepStore(SqlDB):
                return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()]
        
            @sql
       +    def get_tx_by_index(self, funding_outpoint, index):
       +        r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
       +        return r.prev_txid, bh2u(r.tx)
       +
       +    @sql
            def list_sweep_tx(self):
                return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
        
            @sql
            def add_sweep_tx(self, funding_outpoint, prev_txid, tx):
       -        self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=bfh(str(tx)), txid=tx.txid()))
       +        n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
       +        self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prev_txid=prev_txid, tx=bfh(tx)))
                self.DBSession.commit()
        
            @sql
       -    def num_sweep_tx(self, funding_outpoint):
       +    def get_num_tx(self, funding_outpoint):
                return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
        
            @sql
       t@@ -87,24 +93,24 @@ class SweepStore(SqlDB):
                self.DBSession.commit()
        
            @sql
       -    def add_channel_info(self, address, outpoint):
       +    def add_channel(self, outpoint, address):
                self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
                self.DBSession.commit()
        
            @sql
       -    def remove_channel_info(self, address):
       -        v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
       +    def remove_channel(self, outpoint):
       +        v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
                self.DBSession.delete(v)
                self.DBSession.commit()
        
            @sql
       -    def has_channel_info(self, address):
       -        return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none())
       +    def has_channel(self, outpoint):
       +        return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none())
        
            @sql
       -    def get_channel_info(self, address):
       -        r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
       -        return r.outpoint if r else None
       +    def get_address(self, outpoint):
       +        r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
       +        return r.address if r else None
        
            @sql
            def list_channel_info(self):
       t@@ -139,42 +145,46 @@ class LNWatcher(AddressSynchronizer):
                self.watchtower = jsonrpclib.Server(watchtower_url) if watchtower_url else None
                self.watchtower_queue = asyncio.Queue()
        
       -    def with_watchtower(func):
       -        def wrapper(self, *args, **kwargs):
       -            if self.watchtower:
       -                self.watchtower_queue.put_nowait((func.__name__, args, kwargs))
       -            return func(self, *args, **kwargs)
       -        return wrapper
       +    def get_num_tx(self, outpoint):
       +        return self.sweepstore.get_num_tx(outpoint)
        
            @ignore_exceptions
            @log_exceptions
            async def watchtower_task(self):
                self.print_error('watchtower task started')
       +        # initial check
       +        for address, outpoint in self.sweepstore.list_channel_info():
       +            await self.watchtower_queue.put(outpoint)
                while True:
       -            name, args, kwargs = await self.watchtower_queue.get()
       +            outpoint = await self.watchtower_queue.get()
                    if self.watchtower is None:
                        continue
       -            func = getattr(self.watchtower, name)
       +            # synchronize with remote
                    try:
       -                r = func(*args, **kwargs)
       -                self.print_error("watchtower answer", r)
       -            except:
       -                self.print_error('could not reach watchtower, will retry in 5s', name, args)
       +                local_n = self.sweepstore.get_num_tx(outpoint)
       +                n = self.watchtower.get_num_tx(outpoint)
       +                if n == 0:
       +                    address = self.sweepstore.get_address(outpoint)
       +                    self.watchtower.add_channel(outpoint, address)
       +                self.print_error("sending %d transactions to watchtower"%(local_n - n))
       +                for index in range(n, local_n):
       +                    prev_txid, tx = self.sweepstore.get_tx_by_index(outpoint, index)
       +                    self.watchtower.add_sweep_tx(outpoint, prev_txid, tx)
       +            except ConnectionRefusedError:
       +                self.print_error('could not reach watchtower, will retry in 5s')
                        await asyncio.sleep(5)
       -                await self.watchtower_queue.put((name, args, kwargs))
       -
       +                await self.watchtower_queue.put(outpoint)
        
       -    @with_watchtower
       -    def watch_channel(self, address, outpoint):
       +    def add_channel(self, outpoint, address):
                self.add_address(address)
                with self.lock:
       -            if not self.sweepstore.has_channel_info(address):
       -                self.sweepstore.add_channel_info(address, outpoint)
       +            if not self.sweepstore.has_channel(outpoint):
       +                self.sweepstore.add_channel(outpoint, address)
        
            def unwatch_channel(self, address, funding_outpoint):
                self.print_error('unwatching', funding_outpoint)
                self.sweepstore.remove_sweep_tx(funding_outpoint)
       -        self.sweepstore.remove_channel_info(address)
       +        self.sweepstore.remove_channel_info(funding_outpoint)
                if funding_outpoint in self.tx_progress:
                    self.tx_progress[funding_outpoint].all_done.set()
        
       t@@ -259,10 +269,10 @@ class LNWatcher(AddressSynchronizer):
                        await self.tx_progress[funding_outpoint].tx_queue.put(tx)
                    return txid
        
       -    @with_watchtower
       -    def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
       -        tx = Transaction.from_dict(tx_dict)
       +    def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx: str):
                self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx)
       +        if self.watchtower:
       +            self.watchtower_queue.put_nowait(funding_outpoint)
        
            def get_tx_mined_depth(self, txid: str):
                if not txid:
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -92,7 +92,7 @@ class LNWorker(PrintError):
                self.config = network.config
                self.channel_db = self.network.channel_db
                for chan_id, chan in self.channels.items():
       -            self.network.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str())
       +            self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
                    chan.lnwatcher = network.lnwatcher
                self._last_tried_peer = {}  # LNPeerAddr -> unix timestamp
                self._add_peers_from_config()
       t@@ -425,7 +425,7 @@ class LNWorker(PrintError):
                    push_msat=push_sat * 1000,
                    temp_channel_id=os.urandom(32))
                self.save_channel(chan)
       -        self.network.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str())
       +        self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
                self.on_channels_updated()
                return chan