URI: 
       tMerge pull request #4767 from SomberNight/auto_jump_forks - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 684e69763a5d1091887793e560609efd55d58d0d
   DIR parent cd5152a02d0cc3e544d2c0143baf4d3b7405c226
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Fri, 12 Oct 2018 10:50:47 +0200
       
       Merge pull request #4767 from SomberNight/auto_jump_forks
       
       network: auto-switch servers to preferred fork (or longest chain)
       Diffstat:
         M electrum/blockchain.py              |      24 +++++++++++++++++-------
         M electrum/gui/kivy/main_window.py    |       4 ++--
         M electrum/gui/qt/network_dialog.py   |      11 ++++-------
         M electrum/interface.py               |       1 +
         M electrum/network.py                 |     120 ++++++++++++++++++++-----------
         M electrum/verifier.py                |       2 +-
       
       6 files changed, 103 insertions(+), 59 deletions(-)
       ---
   DIR diff --git a/electrum/blockchain.py b/electrum/blockchain.py
       t@@ -22,7 +22,7 @@
        # SOFTWARE.
        import os
        import threading
       -from typing import Optional
       +from typing import Optional, Dict
        
        from . import util
        from .bitcoin import Hash, hash_encode, int_to_hex, rev_hex
       t@@ -73,7 +73,7 @@ def hash_header(header: dict) -> str:
            return hash_encode(Hash(bfh(serialize_header(header))))
        
        
       -blockchains = {}
       +blockchains = {}  # type: Dict[int, Blockchain]
        blockchains_lock = threading.Lock()
        
        
       t@@ -100,7 +100,7 @@ class Blockchain(util.PrintError):
            Manages blockchain headers and their verification
            """
        
       -    def __init__(self, config, forkpoint: int, parent_id: int):
       +    def __init__(self, config, forkpoint: int, parent_id: Optional[int]):
                self.config = config
                self.forkpoint = forkpoint
                self.checkpoints = constants.net.CHECKPOINTS
       t@@ -124,22 +124,32 @@ class Blockchain(util.PrintError):
                children = list(filter(lambda y: y.parent_id==self.forkpoint, chains))
                return max([x.forkpoint for x in children]) if children else None
        
       -    def get_forkpoint(self) -> int:
       +    def get_max_forkpoint(self) -> int:
       +        """Returns the max height where there is a fork
       +        related to this chain.
       +        """
                mc = self.get_max_child()
                return mc if mc is not None else self.forkpoint
        
            def get_branch_size(self) -> int:
       -        return self.height() - self.get_forkpoint() + 1
       +        return self.height() - self.get_max_forkpoint() + 1
        
            def get_name(self) -> str:
       -        return self.get_hash(self.get_forkpoint()).lstrip('00')[0:10]
       +        return self.get_hash(self.get_max_forkpoint()).lstrip('00')[0:10]
        
            def check_header(self, header: dict) -> bool:
                header_hash = hash_header(header)
                height = header.get('block_height')
       +        return self.check_hash(height, header_hash)
       +
       +    def check_hash(self, height: int, header_hash: str) -> bool:
       +        """Returns whether the hash of the block at given height
       +        is the given hash.
       +        """
       +        assert isinstance(header_hash, str) and len(header_hash) == 64, header_hash  # hex
                try:
                    return header_hash == self.get_hash(height)
       -        except MissingHeader:
       +        except Exception:
                    return False
        
            def fork(parent, header: dict) -> 'Blockchain':
   DIR diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py
       t@@ -120,7 +120,7 @@ class ElectrumWindow(App):
                    with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items())
                    for index, b in blockchain_items:
                        if name == b.get_name():
       -                    self.network.run_from_another_thread(self.network.follow_chain(index))
       +                    self.network.run_from_another_thread(self.network.follow_chain_given_id(index))
                names = [blockchain.blockchains[b].get_name() for b in chains]
                if len(names) > 1:
                    cur_chain = self.network.blockchain().get_name()
       t@@ -664,7 +664,7 @@ class ElectrumWindow(App):
                self.num_nodes = len(self.network.get_interfaces())
                self.num_chains = len(self.network.get_blockchains())
                chain = self.network.blockchain()
       -        self.blockchain_forkpoint = chain.get_forkpoint()
       +        self.blockchain_forkpoint = chain.get_max_forkpoint()
                self.blockchain_name = chain.get_name()
                interface = self.network.interface
                if interface:
   DIR diff --git a/electrum/gui/qt/network_dialog.py b/electrum/gui/qt/network_dialog.py
       t@@ -107,7 +107,7 @@ class NodesListWidget(QTreeWidget):
                    b = blockchain.blockchains[k]
                    name = b.get_name()
                    if n_chains >1:
       -                x = QTreeWidgetItem([name + '@%d'%b.get_forkpoint(), '%d'%b.height()])
       +                x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
                        x.setData(0, Qt.UserRole, 1)
                        x.setData(1, Qt.UserRole, b.forkpoint)
                    else:
       t@@ -364,7 +364,7 @@ class NetworkChoiceLayout(object):
                chains = self.network.get_blockchains()
                if len(chains) > 1:
                    chain = self.network.blockchain()
       -            forkpoint = chain.get_forkpoint()
       +            forkpoint = chain.get_max_forkpoint()
                    name = chain.get_name()
                    msg = _('Chain split detected at block {0}').format(forkpoint) + '\n'
                    msg += (_('You are following branch') if auto_connect else _('Your server is on branch'))+ ' ' + name
       t@@ -411,14 +411,11 @@ class NetworkChoiceLayout(object):
                self.set_server()
        
            def follow_branch(self, index):
       -        self.network.run_from_another_thread(self.network.follow_chain(index))
       +        self.network.run_from_another_thread(self.network.follow_chain_given_id(index))
                self.update()
        
            def follow_server(self, server):
       -        net_params = self.network.get_parameters()
       -        host, port, protocol = deserialize_server(server)
       -        net_params = net_params._replace(host=host, port=port, protocol=protocol)
       -        self.network.run_from_another_thread(self.network.set_parameters(net_params))
       +        self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
                self.update()
        
            def server_changed(self, x):
   DIR diff --git a/electrum/interface.py b/electrum/interface.py
       t@@ -384,6 +384,7 @@ class Interface(PrintError):
                    self.mark_ready()
                    await self._process_header_at_tip()
                    self.network.trigger_callback('network_updated')
       +            await self.network.switch_unwanted_fork_interface()
                    await self.network.switch_lagging_interface()
        
            async def _process_header_at_tip(self):
   DIR diff --git a/electrum/network.py b/electrum/network.py
       t@@ -32,7 +32,7 @@ import json
        import sys
        import ipaddress
        import asyncio
       -from typing import NamedTuple, Optional, Sequence, List
       +from typing import NamedTuple, Optional, Sequence, List, Dict
        import traceback
        
        import dns
       t@@ -172,10 +172,9 @@ class Network(PrintError):
                self.config = SimpleConfig(config) if isinstance(config, dict) else config
                self.num_server = 10 if not self.config.get('oneserver') else 0
                blockchain.blockchains = blockchain.read_blockchains(self.config)
       -        self.print_error("blockchains", list(blockchain.blockchains.keys()))
       -        self.blockchain_index = config.get('blockchain_index', 0)
       -        if self.blockchain_index not in blockchain.blockchains.keys():
       -            self.blockchain_index = 0
       +        self.print_error("blockchains", list(blockchain.blockchains))
       +        self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None)  # type: Optional[Dict]
       +        self._blockchain_index = 0
                # Server for addresses and transactions
                self.default_server = self.config.get('server', None)
                # Sanitize default server
       t@@ -213,11 +212,10 @@ class Network(PrintError):
                # retry times
                self.server_retry_time = time.time()
                self.nodes_retry_time = time.time()
       -        # kick off the network.  interface is the main server we are currently
       -        # communicating with.  interfaces is the set of servers we are connecting
       -        # to or have an ongoing connection with
       +        # the main server we are currently communicating with
                self.interface = None  # type: Interface
       -        self.interfaces = {}
       +        # set of servers we have an ongoing connection with
       +        self.interfaces = {}  # type: Dict[str, Interface]
                self.auto_connect = self.config.get('auto_connect', True)
                self.connecting = set()
                self.server_queue = None
       t@@ -227,8 +225,8 @@ class Network(PrintError):
                #self.asyncio_loop.set_debug(1)
                self._run_forever = asyncio.Future()
                self._thread = threading.Thread(target=self.asyncio_loop.run_until_complete,
       -                                                args=(self._run_forever,),
       -                                                name='Network')
       +                                        args=(self._run_forever,),
       +                                        name='Network')
                self._thread.start()
        
            def run_from_another_thread(self, coro):
       t@@ -523,20 +521,40 @@ class Network(PrintError):
        
            async def switch_lagging_interface(self):
                '''If auto_connect and lagging, switch interface'''
       -        if await self._server_is_lagging() and self.auto_connect:
       +        if self.auto_connect and await self._server_is_lagging():
                    # switch to one that has the correct header (not height)
       -            header = self.blockchain().read_header(self.get_local_height())
       -            def filt(x):
       -                a = x[1].tip_header
       -                b = header
       -                assert type(a) is type(b)
       -                return a == b
       -
       -            with self.interfaces_lock: interfaces_items = list(self.interfaces.items())
       -            filtered = list(map(lambda x: x[0], filter(filt, interfaces_items)))
       +            best_header = self.blockchain().read_header(self.get_local_height())
       +            with self.interfaces_lock: interfaces = list(self.interfaces.values())
       +            filtered = list(filter(lambda iface: iface.tip_header == best_header, interfaces))
                    if filtered:
       -                choice = random.choice(filtered)
       -                await self.switch_to_interface(choice)
       +                chosen_iface = random.choice(filtered)
       +                await self.switch_to_interface(chosen_iface.server)
       +
       +    async def switch_unwanted_fork_interface(self):
       +        """If auto_connect and main interface is not on preferred fork,
       +        try to switch to preferred fork.
       +        """
       +        if not self.auto_connect:
       +            return
       +        with self.interfaces_lock: interfaces = list(self.interfaces.values())
       +        # try to switch to preferred fork
       +        if self._blockchain_preferred_block:
       +            pref_height = self._blockchain_preferred_block['height']
       +            pref_hash   = self._blockchain_preferred_block['hash']
       +            filtered = list(filter(lambda iface: iface.blockchain.check_hash(pref_height, pref_hash),
       +                                   interfaces))
       +            if filtered:
       +                chosen_iface = random.choice(filtered)
       +                await self.switch_to_interface(chosen_iface.server)
       +                return
       +        # try to switch to longest chain
       +        if self.blockchain().parent_id is None:
       +            return  # already on longest chain
       +        filtered = list(filter(lambda iface: iface.blockchain.parent_id is None,
       +                               interfaces))
       +        if filtered:
       +            chosen_iface = random.choice(filtered)
       +            await self.switch_to_interface(chosen_iface.server)
        
            async def switch_to_interface(self, server: str):
                """Switch to server as our main interface. If no connection exists,
       t@@ -704,8 +722,8 @@ class Network(PrintError):
            def blockchain(self) -> Blockchain:
                interface = self.interface
                if interface and interface.blockchain is not None:
       -            self.blockchain_index = interface.blockchain.forkpoint
       -        return blockchain.blockchains[self.blockchain_index]
       +            self._blockchain_index = interface.blockchain.forkpoint
       +        return blockchain.blockchains[self._blockchain_index]
        
            def get_blockchains(self):
                out = {}  # blockchain_id -> list(interfaces)
       t@@ -724,24 +742,42 @@ class Network(PrintError):
                    await self.connection_down(interface.server)
                return ifaces
        
       -    async def follow_chain(self, chain_id):
       -        bc = blockchain.blockchains.get(chain_id)
       -        if bc:
       -            self.blockchain_index = chain_id
       -            self.config.set_key('blockchain_index', chain_id)
       -            with self.interfaces_lock: interfaces_values = list(self.interfaces.values())
       -            for iface in interfaces_values:
       -                if iface.blockchain == bc:
       -                    await self.switch_to_interface(iface.server)
       -                    break
       -        else:
       -            raise Exception('blockchain not found', chain_id)
       +    def _set_preferred_chain(self, chain: Blockchain):
       +        height = chain.get_max_forkpoint()
       +        header_hash = chain.get_hash(height)
       +        self._blockchain_preferred_block = {
       +            'height': height,
       +            'hash': header_hash,
       +        }
       +        self.config.set_key('blockchain_preferred_block', self._blockchain_preferred_block)
        
       -        if self.interface:
       -            net_params = self.get_parameters()
       -            host, port, protocol = deserialize_server(self.interface.server)
       -            net_params = net_params._replace(host=host, port=port, protocol=protocol)
       -            await self.set_parameters(net_params)
       +    async def follow_chain_given_id(self, chain_id: int) -> None:
       +        bc = blockchain.blockchains.get(chain_id)
       +        if not bc:
       +            raise Exception('blockchain {} not found'.format(chain_id))
       +        self._set_preferred_chain(bc)
       +        # select server on this chain
       +        with self.interfaces_lock: interfaces = list(self.interfaces.values())
       +        interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
       +        if len(interfaces_on_selected_chain) == 0: return
       +        chosen_iface = random.choice(interfaces_on_selected_chain)
       +        # switch to server (and save to config)
       +        net_params = self.get_parameters()
       +        host, port, protocol = deserialize_server(chosen_iface.server)
       +        net_params = net_params._replace(host=host, port=port, protocol=protocol)
       +        await self.set_parameters(net_params)
       +
       +    async def follow_chain_given_server(self, server_str: str) -> None:
       +        # note that server_str should correspond to a connected interface
       +        iface = self.interfaces.get(server_str)
       +        if iface is None:
       +            return
       +        self._set_preferred_chain(iface.blockchain)
       +        # switch to server (and save to config)
       +        net_params = self.get_parameters()
       +        host, port, protocol = deserialize_server(server_str)
       +        net_params = net_params._replace(host=host, port=port, protocol=protocol)
       +        await self.set_parameters(net_params)
        
            def get_local_height(self):
                return self.blockchain().height()
   DIR diff --git a/electrum/verifier.py b/electrum/verifier.py
       t@@ -156,7 +156,7 @@ class SPV(NetworkJobOnDefaultServer):
        
            async def _maybe_undo_verifications(self):
                def undo_verifications():
       -            height = self.blockchain.get_forkpoint()
       +            height = self.blockchain.get_max_forkpoint()
                    self.print_error("undoing verifications back to height {}".format(height))
                    tx_hashes = self.wallet.undo_verifications(self.blockchain, height)
                    for tx_hash in tx_hashes: