        git clone https://git.parazyd.org/electrum
       ---
        commit f3e5ba6ac16e8e30879079a8a18f7ad931a6017a
        parent 362a3a5a442d63374ea9d3daf358f63cd18e88f9
        Author: SomberNight <somber.night@protonmail.com>
       Date:   Mon, 30 Jul 2018 13:51:03 +0200
       
       more reliable peer and channel re-establishing
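
        Channel state transitions now go through set_state()/get_state();
        channels track whether their funding output has been spent; the
        ChannelDB remembers the last address at which each node was
        successfully reached; and LNWorker retries disconnected channel
        peers every 30 seconds (vs 600 seconds for ordinary gossip peers),
        tearing down failed peers via the new close_and_cleanup().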
       
       Diffstat:
         M electrum/gui/qt/channels_list.py    |       2 +-
         M electrum/lnbase.py                  |      48 ++++++++++++++++++++++---------
         M electrum/lnhtlc.py                  |      16 +++++++++++++++-
         M electrum/lnrouter.py                |      15 +++++++++++++++
         M electrum/lnworker.py                |      90 ++++++++++++++++++++++---------
       
       5 files changed, 130 insertions(+), 41 deletions(-)
       ---
        diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py
        @@ -24,7 +24,7 @@ class ChannelsList(MyTreeWidget):
                    bh2u(chan.node_id),
                    self.parent.format_amount(chan.local_state.amount_msat//1000),
                    self.parent.format_amount(chan.remote_state.amount_msat//1000),
       -            chan.state
       +            chan.get_state()
                ]
        
            def create_menu(self, position):
        diff --git a/electrum/lnbase.py b/electrum/lnbase.py
        @@ -207,6 +207,10 @@ class HandshakeState(object):
                self.h = sha256(self.h + data)
                return self.h
        
       +
       +class HandshakeFailed(Exception): pass
       +
       +
        def get_nonce_bytes(n):
            """BOLT 8 requires the nonce to be 12 bytes, 4 bytes leading
            zeroes and 8 bytes little endian encoded 64 bit integer.
        @@ -285,6 +289,7 @@ class Peer(PrintError):
                self.host = host
                self.port = port
                self.pubkey = pubkey
       +        self.peer_addr = LNPeerAddr(host, port, pubkey)
                self.lnworker = lnworker
                self.privkey = lnworker.privkey
                self.network = lnworker.network
        @@ -340,7 +345,10 @@ class Peer(PrintError):
                            self.read_buffer = self.read_buffer[offset:]
                            msg = aead_decrypt(rk_m, rn_m, b'', c)
                            return msg
       -            s = await self.reader.read(2**10)
       +            try:
       +                s = await self.reader.read(2**10)
       +            except:
       +                s = None
                    if not s:
                        raise LightningPeerConnectionClosed()
                    self.read_buffer += s
        @@ -354,9 +362,11 @@
                # act 1
                self.writer.write(msg)
                rspns = await self.reader.read(2**10)
       -        assert len(rspns) == 50, "Lightning handshake act 1 response has bad length, are you sure this is the right pubkey? " + str(bh2u(self.pubkey))
       +        if len(rspns) != 50:
       +            raise HandshakeFailed("Lightning handshake act 1 response has bad length, are you sure this is the right pubkey? " + str(bh2u(self.pubkey)))
                hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:]
       -        assert bytes([hver]) == hs.handshake_version
       +        if bytes([hver]) != hs.handshake_version:
       +            raise HandshakeFailed("unexpected handshake version: {}".format(hver))
                # act 2
                hs.update(alice_epub)
                ss = get_ecdh(epriv, alice_epub)
        @@ -461,15 +471,21 @@ class Peer(PrintError):
            @aiosafe
            async def main_loop(self):
                await asyncio.wait_for(self.initialize(), 5)
       -        self.channel_db.add_recent_peer(LNPeerAddr(self.host, self.port, self.pubkey))
       +        self.channel_db.add_recent_peer(self.peer_addr)
                # loop
                while True:
                    self.ping_if_required()
                    msg = await self.read_message()
                    self.process_message(msg)
       -        # close socket
       -        self.print_error('closing lnbase')
       -        self.writer.close()
       +
       +    def close_and_cleanup(self):
       +        try:
       +            self.writer.close()
       +        except:
       +            pass
       +        for chan in self.channels.values():
       +            chan.set_state('DISCONNECTED')
       +            self.network.trigger_callback('channel', chan)
        
            @aiosafe
            async def channel_establishment_flow(self, wallet, config, password, funding_sat, push_msat, temp_channel_id):
        @@ -601,14 +617,18 @@
                assert success, success
                m.remote_state = m.remote_state._replace(ctn=0)
                m.local_state = m.local_state._replace(ctn=0, current_commitment_signature=remote_sig)
       -        m.state = 'OPENING'
       +        m.set_state('OPENING')
                return m
        
            @aiosafe
            async def reestablish_channel(self, chan):
                await self.initialized
                chan_id = chan.channel_id
       -        chan.state = 'REESTABLISHING'
       +        if chan.get_state() != 'DISCONNECTED':
       +            self.print_error('reestablish_channel was called but channel {} already in state {}'
       +                             .format(chan_id, chan.get_state()))
       +            return
       +        chan.set_state('REESTABLISHING')
                self.network.trigger_callback('channel', chan)
                self.send_message(gen_msg("channel_reestablish",
                    channel_id=chan_id,
        @@ -616,7 +636,7 @@
                    next_remote_revocation_number=chan.remote_state.ctn
                ))
                await self.channel_reestablished[chan_id]
       -        chan.state = 'OPENING'
       +        chan.set_state('OPENING')
                if chan.local_state.funding_locked_received and chan.short_channel_id:
                    self.mark_open(chan)
                self.network.trigger_callback('channel', chan)
        @@ -727,10 +747,10 @@
                print("SENT CHANNEL ANNOUNCEMENT")
        
            def mark_open(self, chan):
       -        if chan.state == "OPEN":
       +        if chan.get_state() == "OPEN":
                    return
                assert chan.local_state.funding_locked_received
       -        chan.state = "OPEN"
       +        chan.set_state("OPEN")
                self.network.trigger_callback('channel', chan)
                # add channel to database
                node_ids = [self.pubkey, self.lnworker.pubkey]
        @@ -820,7 +840,7 @@
        
            @aiosafe
            async def pay(self, path, chan, amount_msat, payment_hash, pubkey_in_invoice, min_final_cltv_expiry):
       -        assert chan.state == "OPEN"
       +        assert chan.get_state() == "OPEN"
                assert amount_msat > 0, "amount_msat is not greater zero"
                height = self.network.get_local_height()
                route = self.network.path_finder.create_route_from_path(path, self.lnworker.pubkey)
        @@ -911,7 +931,7 @@
                htlc_id = int.from_bytes(htlc["id"], 'big')
                assert htlc_id == chan.remote_state.next_htlc_id, (htlc_id, chan.remote_state.next_htlc_id)
        
       -        assert chan.state == "OPEN"
       +        assert chan.get_state() == "OPEN"
        
                cltv_expiry = int.from_bytes(htlc["cltv_expiry"], 'big')
                # TODO verify sanity of their cltv expiry
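
        With the asserts replaced by HandshakeFailed and the new guard in
        reestablish_channel, a channel now moves through a one-way state
        cycle per connection: DISCONNECTED -> REESTABLISHING -> OPENING ->
        OPEN, and back to DISCONNECTED via close_and_cleanup. A minimal
        sketch of that lifecycle (ChannelStub is an illustrative stand-in
        for the real HTLCStateMachine):

            class ChannelStub:
                def __init__(self):
                    self._state = 'DISCONNECTED'
                def set_state(self, state: str):
                    self._state = state
                def get_state(self):
                    return self._state

            chan = ChannelStub()
            # reestablish_channel refuses to run twice for the same disconnect:
            if chan.get_state() == 'DISCONNECTED':
                chan.set_state('REESTABLISHING')  # channel_reestablish sent
                chan.set_state('OPENING')         # reestablish acknowledged
                chan.set_state('OPEN')            # mark_open: funding_locked received
            # on a socket error, close_and_cleanup resets the cycle:
            chan.set_state('DISCONNECTED')
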
        diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
        @@ -138,7 +138,21 @@ class HTLCStateMachine(PrintError):
                self.local_commitment = self.pending_local_commitment
                self.remote_commitment = self.pending_remote_commitment
        
       -        self.state = 'DISCONNECTED'
       +        self._is_funding_txo_spent = None  # "don't know"
       +        self.set_state('DISCONNECTED')
       +
       +    def set_state(self, state: str):
       +        self._state = state
       +
       +    def get_state(self):
       +        return self._state
       +
       +    def set_funding_txo_spentness(self, is_spent: bool):
       +        assert isinstance(is_spent, bool)
       +        self._is_funding_txo_spent = is_spent
       +
       +    def should_try_to_reestablish_peer(self) -> bool:
       +        return self._is_funding_txo_spent is False and self._state == 'DISCONNECTED'
        
            def get_funding_address(self):
                script = funding_output_script(self.local_config, self.remote_config)
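
        Note the tri-state spentness: _is_funding_txo_spent starts as None
        ("don't know"), so a freshly loaded channel is not reconnected
        until the watcher has actually reported the funding output as
        unspent. The predicate tests "is False", not falsiness; a small
        standalone check makes this explicit (the stub class is
        illustrative):

            class ChanStub:
                _state = 'DISCONNECTED'
                _is_funding_txo_spent = None  # unknown until on_channel_utxos runs

                def should_try_to_reestablish_peer(self) -> bool:
                    # None (unknown) must not trigger a reconnect attempt
                    return self._is_funding_txo_spent is False \
                        and self._state == 'DISCONNECTED'

            c = ChanStub()
            assert not c.should_try_to_reestablish_peer()  # spentness unknown
            c._is_funding_txo_spent = False                # watcher saw it unspent
            assert c.should_try_to_reestablish_peer()
            c._state = 'OPEN'
            assert not c.should_try_to_reestablish_peer()  # already connected
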
        diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
        @@ -269,6 +269,7 @@ class ChannelDB(JsonDB):
                self._channels_for_node = defaultdict(set)  # node -> set(short_channel_id)
                self.nodes = {}  # node_id -> NodeInfo
                self._recent_peers = []
       +        self._last_good_address = {}  # node_id -> LNPeerAddr
        
                self.ca_verifier = LNChanAnnVerifier(network, self)
                self.network.add_jobs([self.ca_verifier])
        @@ -297,6 +298,11 @@
                for host, port, pubkey in recent_peers:
                    peer = LNPeerAddr(str(host), int(port), bfh(pubkey))
                    self._recent_peers.append(peer)
       +        # last good address
       +        last_good_addr = self.get('last_good_address', {})
       +        for node_id, host_and_port in last_good_addr.items():
       +            host, port = host_and_port
       +            self._last_good_address[bfh(node_id)] = LNPeerAddr(str(host), int(port), bfh(node_id))
        
            def save_data(self):
                with self.lock:
        @@ -316,6 +322,11 @@
                        recent_peers.append(
                            [str(peer.host), int(peer.port), bh2u(peer.pubkey)])
                    self.put('recent_peers', recent_peers)
       +            # last good address
       +            last_good_addr = {}
       +            for node_id, peer in self._last_good_address.items():
       +                last_good_addr[bh2u(node_id)] = [str(peer.host), int(peer.port)]
       +            self.put('last_good_address', last_good_addr)
                self.write()
        
            def __len__(self):
        @@ -347,6 +358,10 @@
                        self._recent_peers.remove(peer)
                    self._recent_peers.insert(0, peer)
                    self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
       +            self._last_good_address[peer.pubkey] = peer
       +
       +    def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
       +        return self._last_good_address.get(node_id, None)
        
            def on_channel_announcement(self, msg_payload, trusted=False):
                short_channel_id = msg_payload['short_channel_id']
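
        The new map is persisted in the JsonDB next to recent_peers, keyed
        by hex node id with a [host, port] value; the node id doubles as
        the pubkey when the LNPeerAddr is rebuilt on load. Roughly, the
        on-disk shape and round-trip (bh2u/bfh are electrum's hex
        encode/decode helpers; the node id here is a dummy value):

            node_id = bytes(33)  # stand-in for a 33-byte node pubkey
            stored = {bh2u(node_id): ['ecdsa.net', 9735]}  # 'last_good_address' entry
            for node_id_hex, (host, port) in stored.items():
                peer = LNPeerAddr(str(host), int(port), bfh(node_id_hex))
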
        diff --git a/electrum/lnworker.py b/electrum/lnworker.py
        @@ -4,6 +4,7 @@ from decimal import Decimal
        import random
        import time
        from typing import Optional, Sequence
       +import threading
        
        import dns.resolver
        import dns.exception
        @@ -22,6 +23,7 @@ from .i18n import _
        
        NUM_PEERS_TARGET = 4
        PEER_RETRY_INTERVAL = 600  # seconds
       +PEER_RETRY_INTERVAL_FOR_CHANNELS = 30  # seconds
        
        FALLBACK_NODE_LIST = (
            LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')),
        @@ -34,6 +36,7 @@ class LNWorker(PrintError):
                self.wallet = wallet
                self.network = network
                self.channel_db = self.network.channel_db
       +        self.lock = threading.RLock()
                pk = wallet.storage.get('lightning_privkey')
                if pk is None:
                    pk = bh2u(os.urandom(32))
        @@ -48,9 +51,6 @@
                for chan_id, chan in self.channels.items():
                    self.network.lnwatcher.watch_channel(chan, self.on_channel_utxos)
                self._last_tried_peer = {}  # LNPeerAddr -> unix timestamp
       -        # TODO peers that we have channels with should also be added now
       -        # but we don't store their IP/port yet.. also what if it changes?
       -        # need to listen for node_announcements and save the new IP/port
                self._add_peers_from_config()
                # wait until we see confirmations
                self.network.register_callback(self.on_network_update, ['updated', 'verified', 'fee_histogram']) # thread safe
        @@ -72,15 +72,14 @@
        
            def channels_for_peer(self, node_id):
                assert type(node_id) is bytes
       -        return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
       +        with self.lock:
       +            return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
        
            def add_peer(self, host, port, node_id):
                port = int(port)
                peer_addr = LNPeerAddr(host, port, node_id)
                if node_id in self.peers:
                    return
       -        if peer_addr in self._last_tried_peer:
       -            return
                self._last_tried_peer[peer_addr] = time.time()
                self.print_error("adding peer", peer_addr)
                peer = Peer(self, host, port, node_id, request_initial_sync=self.config.get("request_initial_sync", True))
        @@ -90,10 +89,11 @@
        
            def save_channel(self, openchannel):
                assert type(openchannel) is HTLCStateMachine
       -        self.channels[openchannel.channel_id] = openchannel
                if openchannel.remote_state.next_per_commitment_point == openchannel.remote_state.current_per_commitment_point:
                    raise Exception("Tried to save channel with next_point == current_point, this should not happen")
       -        dumped = [x.serialize() for x in self.channels.values()]
       +        with self.lock:
       +            self.channels[openchannel.channel_id] = openchannel
       +            dumped = [x.serialize() for x in self.channels.values()]
                self.wallet.storage.put("channels", dumped)
                self.wallet.storage.write()
                self.network.trigger_callback('channel', openchannel)
        @@ -104,7 +104,7 @@
        
                If the Funding TX has not been mined, return None
                """
       -        assert chan.state in ["OPEN", "OPENING"]
       +        assert chan.get_state() in ["OPEN", "OPENING"]
                peer = self.peers[chan.node_id]
                conf = self.wallet.get_tx_height(chan.funding_outpoint.txid)[1]
                if conf >= chan.constraints.funding_txn_minimum_depth:
        @@ -121,16 +121,12 @@
            def on_channel_utxos(self, chan, utxos):
                outpoints = [Outpoint(x["tx_hash"], x["tx_pos"]) for x in utxos]
                if chan.funding_outpoint not in outpoints:
       -            chan.state = "CLOSED"
       +            chan.set_funding_txo_spentness(True)
       +            chan.set_state("CLOSED")
                    # FIXME is this properly GC-ed? (or too soon?)
                    LNChanCloseHandler(self.network, self.wallet, chan)
       -        elif chan.state == 'DISCONNECTED':
       -            if chan.node_id not in self.peers:
       -                self.print_error("received channel_utxos for channel which does not have peer (errored?)")
       -                return
       -            peer = self.peers[chan.node_id]
       -            coro = peer.reestablish_channel(chan)
       -            asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
       +        else:
       +            chan.set_funding_txo_spentness(False)
                self.network.trigger_callback('channel', chan)
        
            def on_network_update(self, event, *args):
        @@ -139,8 +135,10 @@
                # since short_channel_id could be changed while saving.
                # Mitigated by posting to loop:
                async def network_jobs():
       -            for chan in self.channels.values():
       -                if chan.state == "OPENING":
       +            with self.lock:
       +                channels = list(self.channels.values())
       +            for chan in channels:
       +                if chan.get_state() == "OPENING":
                            res = self.save_short_chan_id(chan)
                            if not res:
                                self.print_error("network update but funding tx is still not at sufficient depth")
        @@ -148,7 +146,7 @@
                            # this results in the channel being marked OPEN
                            peer = self.peers[chan.node_id]
                            peer.funding_locked(chan)
       -                elif chan.state == "OPEN":
       +                elif chan.get_state() == "OPEN":
                            peer = self.peers.get(chan.node_id)
                            if peer is None:
                                self.print_error("peer not found for {}".format(bh2u(chan.node_id)))
        @@ -177,6 +175,7 @@
                return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
        
            def pay(self, invoice, amount_sat=None):
       +        # TODO try some number of paths (e.g. 10) in case of failures
                addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
                payment_hash = addr.paymenthash
                invoice_pubkey = addr.pubkey.serialize()
        @@ -189,7 +188,9 @@
                    raise Exception("No path found")
                node_id, short_channel_id = path[0]
                peer = self.peers[node_id]
       -        for chan in self.channels.values():
       +        with self.lock:
       +            channels = list(self.channels.values())
       +        for chan in channels:
                    if chan.short_channel_id == short_channel_id:
                        break
                else:
        @@ -216,7 +217,8 @@
                self.wallet.storage.write()
        
            def list_channels(self):
       -        return [str(x) for x in self.channels]
       +        with self.lock:
       +            return [str(x) for x in self.channels]
        
            def close_channel(self, chan_id):
                chan = self.channels[chan_id]
        @@ -250,7 +252,7 @@
                # try random peer from graph
                all_nodes = self.channel_db.nodes
                if all_nodes:
       -            self.print_error('trying to get ln peers from channel db')
       +            #self.print_error('trying to get ln peers from channel db')
                    node_ids = list(all_nodes)
                    max_tries = min(200, len(all_nodes))
                    for i in range(max_tries):
        @@ -259,7 +261,7 @@
                        if node is None: continue
                        addresses = node.addresses
                        if not addresses: continue
       -                host, port = addresses[0]
       +                host, port = random.choice(addresses)
                        peer = LNPeerAddr(host, port, node_id)
                        if peer.pubkey in self.peers: continue
                        if peer in self._last_tried_peer: continue
        @@ -309,16 +311,54 @@
                self.print_error('got {} ln peers from dns seed'.format(len(peers)))
                return peers
        
       +    def reestablish_peers_and_channels(self):
       +        def reestablish_peer_for_given_channel():
       +            # try last good address first
       +            peer = self.channel_db.get_last_good_address(chan.node_id)
       +            if peer:
       +                last_tried = self._last_tried_peer.get(peer, 0)
       +                if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
       +                    self.add_peer(peer.host, peer.port, peer.pubkey)
       +                    return
       +            # try random address for node_id
       +            node_info = self.channel_db.nodes.get(chan.node_id, None)
       +            if not node_info: return
       +            addresses = node_info.addresses
       +            if not addresses: return
       +            host, port = random.choice(addresses)
       +            peer = LNPeerAddr(host, port, chan.node_id)
       +            last_tried = self._last_tried_peer.get(peer, 0)
       +            if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
       +                self.add_peer(host, port, chan.node_id)
       +
       +        with self.lock:
       +            channels = list(self.channels.values())
       +        now = time.time()
       +        for chan in channels:
       +            if not chan.should_try_to_reestablish_peer():
       +                continue
       +            peer = self.peers.get(chan.node_id, None)
       +            if peer is None:
       +                reestablish_peer_for_given_channel()
       +            else:
       +                coro = peer.reestablish_channel(chan)
       +                asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
       +
            @aiosafe
            async def main_loop(self):
                while True:
                    await asyncio.sleep(1)
       +            now = time.time()
                    for node_id, peer in list(self.peers.items()):
                        if peer.exception:
                            self.print_error("removing peer", peer.host)
       +                    peer.close_and_cleanup()
                            self.peers.pop(node_id)
       +            self.reestablish_peers_and_channels()
                    if len(self.peers) >= NUM_PEERS_TARGET:
                        continue
                    peers = self._get_next_peers_to_try()
                    for peer in peers:
       -                self.add_peer(peer.host, peer.port, peer.pubkey)
       +                last_tried = self._last_tried_peer.get(peer, 0)
       +                if last_tried + PEER_RETRY_INTERVAL < now:
       +                    self.add_peer(peer.host, peer.port, peer.pubkey)
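
        Net effect in main_loop: peers with a recorded exception are now
        torn down via close_and_cleanup (marking their channels
        DISCONNECTED), and reestablish_peers_and_channels runs every
        second. An address is retried at most once per
        PEER_RETRY_INTERVAL_FOR_CHANNELS (30s) when we have a channel
        with that node, vs once per PEER_RETRY_INTERVAL (600s) for
        ordinary peers. The throttle used in both paths, condensed into a
        hypothetical helper:

            import time

            def may_retry(last_tried_peer: dict, peer_addr, interval: int) -> bool:
                # last_tried_peer maps LNPeerAddr -> unix timestamp of last attempt
                return last_tried_peer.get(peer_addr, 0) + interval < time.time()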