URI: 
       tMerge pull request #4453 from SomberNight/network_locks - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 112b0e0544dc25fce718e66ef678cd7b360d405b
   DIR parent a7589a97ad04a5dbb9889d06e958a7ea55d8ac8e
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Fri, 22 Jun 2018 13:05:40 +0200
       
       Merge pull request #4453 from SomberNight/network_locks
       
       locks in network.py
       Diffstat:
         M lib/network.py                      |     203 ++++++++++++++++++++-----------
       
       1 file changed, 135 insertions(+), 68 deletions(-)
       ---
   DIR diff --git a/lib/network.py b/lib/network.py
       t@@ -171,7 +171,7 @@ class Network(util.DaemonThread):
                util.DaemonThread.__init__(self)
                self.config = SimpleConfig(config) if isinstance(config, dict) else config
                self.num_server = 10 if not self.config.get('oneserver') else 0
       -        self.blockchains = blockchain.read_blockchains(self.config)
       +        self.blockchains = blockchain.read_blockchains(self.config)  # note: needs self.blockchains_lock
                self.print_error("blockchains", self.blockchains.keys())
                self.blockchain_index = config.get('blockchain_index', 0)
                if self.blockchain_index not in self.blockchains.keys():
       t@@ -187,27 +187,35 @@ class Network(util.DaemonThread):
                        self.default_server = None
                if not self.default_server:
                    self.default_server = pick_random_server()
       -        self.lock = threading.Lock()
       +
       +        # locks: if you need to take multiple ones, acquire them in the order they are defined here!
       +        self.interface_lock = threading.RLock()            # <- re-entrant
       +        self.callback_lock = threading.Lock()
       +        self.pending_sends_lock = threading.Lock()
       +        self.recent_servers_lock = threading.RLock()       # <- re-entrant
       +        self.subscribed_addresses_lock = threading.Lock()
       +        self.blockchains_lock = threading.Lock()
       +
                self.pending_sends = []
                self.message_id = 0
                self.debug = False
                self.irc_servers = {} # returned by interface (list from irc)
       -        self.recent_servers = self.read_recent_servers()
       +        self.recent_servers = self.read_recent_servers()  # note: needs self.recent_servers_lock
        
                self.banner = ''
                self.donation_address = ''
                self.relay_fee = None
                # callbacks passed with subscriptions
       -        self.subscriptions = defaultdict(list)
       -        self.sub_cache = {}
       +        self.subscriptions = defaultdict(list)  # note: needs self.callback_lock
       +        self.sub_cache = {}                     # note: needs self.interface_lock
                # callbacks set by the GUI
       -        self.callbacks = defaultdict(list)
       +        self.callbacks = defaultdict(list)      # note: needs self.callback_lock
        
                dir_path = os.path.join( self.config.path, 'certs')
                util.make_dir(dir_path)
        
                # subscriptions and requests
       -        self.subscribed_addresses = set()
       +        self.subscribed_addresses = set()  # note: needs self.subscribed_addresses_lock
                self.h2addr = {}
                # Requests from client we've not seen a response to
                self.unanswered_requests = {}
       t@@ -217,8 +225,8 @@ class Network(util.DaemonThread):
                # kick off the network.  interface is the main server we are currently
                # communicating with.  interfaces is the set of servers we are connecting
                # to or have an ongoing connection with
       -        self.interface = None
       -        self.interfaces = {}
       +        self.interface = None              # note: needs self.interface_lock
       +        self.interfaces = {}               # note: needs self.interface_lock
                self.auto_connect = self.config.get('auto_connect', True)
                self.connecting = set()
                self.requested_chunks = set()
       t@@ -226,19 +234,31 @@ class Network(util.DaemonThread):
                self.start_network(deserialize_server(self.default_server)[2],
                                   deserialize_proxy(self.config.get('proxy')))
        
       +    def with_interface_lock(func):
       +        def func_wrapper(self, *args, **kwargs):
       +            with self.interface_lock:
       +                return func(self, *args, **kwargs)
       +        return func_wrapper
       +
       +    def with_recent_servers_lock(func):
       +        def func_wrapper(self, *args, **kwargs):
       +            with self.recent_servers_lock:
       +                return func(self, *args, **kwargs)
       +        return func_wrapper
       +
            def register_callback(self, callback, events):
       -        with self.lock:
       +        with self.callback_lock:
                    for event in events:
                        self.callbacks[event].append(callback)
        
            def unregister_callback(self, callback):
       -        with self.lock:
       +        with self.callback_lock:
                    for callbacks in self.callbacks.values():
                        if callback in callbacks:
                            callbacks.remove(callback)
        
            def trigger_callback(self, event, *args):
       -        with self.lock:
       +        with self.callback_lock:
                    callbacks = self.callbacks[event][:]
                [callback(event, *args) for callback in callbacks]
        
       t@@ -253,6 +273,7 @@ class Network(util.DaemonThread):
                except:
                    return []
        
       +    @with_recent_servers_lock
            def save_recent_servers(self):
                if not self.config.path:
                    return
       t@@ -264,6 +285,7 @@ class Network(util.DaemonThread):
                except:
                    pass
        
       +    @with_interface_lock
            def get_server_height(self):
                return self.interface.tip if self.interface else 0
        
       t@@ -291,11 +313,15 @@ class Network(util.DaemonThread):
            def is_up_to_date(self):
                return self.unanswered_requests == {}
        
       +    @with_interface_lock
            def queue_request(self, method, params, interface=None):
                # If you want to queue a request on any interface it must go
                # through this function so message ids are properly tracked
                if interface is None:
                    interface = self.interface
       +        if interface is None:
       +            self.print_error('warning: dropping request', method, params)
       +            return
                message_id = self.message_id
                self.message_id += 1
                if self.debug:
       t@@ -303,7 +329,9 @@ class Network(util.DaemonThread):
                interface.queue_request(method, params, message_id)
                return message_id
        
       +    @with_interface_lock
            def send_subscriptions(self):
       +        assert self.interface
                self.print_error('sending subscriptions to', self.interface.server, len(self.unanswered_requests), len(self.subscribed_addresses))
                self.sub_cache.clear()
                # Resend unanswered requests
       t@@ -317,8 +345,9 @@ class Network(util.DaemonThread):
                self.queue_request('server.peers.subscribe', [])
                self.request_fee_estimates()
                self.queue_request('blockchain.relayfee', [])
       -        for h in list(self.subscribed_addresses):
       -            self.queue_request('blockchain.scripthash.subscribe', [h])
       +        with self.subscribed_addresses_lock:
       +            for h in self.subscribed_addresses:
       +                self.queue_request('blockchain.scripthash.subscribe', [h])
        
            def request_fee_estimates(self):
                from .simple_config import FEE_ETA_TARGETS
       t@@ -358,10 +387,12 @@ class Network(util.DaemonThread):
                if self.is_connected():
                    return self.donation_address
        
       +    @with_interface_lock
            def get_interfaces(self):
                '''The interfaces that are in connected state'''
                return list(self.interfaces.keys())
        
       +    @with_recent_servers_lock
            def get_servers(self):
                out = constants.net.DEFAULT_SERVERS
                if self.irc_servers:
       t@@ -376,6 +407,7 @@ class Network(util.DaemonThread):
                            out[host] = { protocol:port }
                return out
        
       +    @with_interface_lock
            def start_interface(self, server):
                if (not server in self.interfaces and not server in self.connecting):
                    if server == self.default_server:
       t@@ -385,7 +417,8 @@ class Network(util.DaemonThread):
                    c = Connection(server, self.socket_queue, self.config.path)
        
            def start_random_interface(self):
       -        exclude_set = self.disconnected_servers.union(set(self.interfaces))
       +        with self.interface_lock:
       +            exclude_set = self.disconnected_servers.union(set(self.interfaces))
                server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
                if server:
                    self.start_interface(server)
       t@@ -433,15 +466,17 @@ class Network(util.DaemonThread):
                    else:
                        socket.getaddrinfo = socket._getaddrinfo
        
       +    @with_interface_lock
            def start_network(self, protocol, proxy):
                assert not self.interface and not self.interfaces
                assert not self.connecting and self.socket_queue.empty()
                self.print_error('starting network')
       -        self.disconnected_servers = set([])
       +        self.disconnected_servers = set([])  # note: needs self.interface_lock
                self.protocol = protocol
                self.set_proxy(proxy)
                self.start_interfaces()
        
       +    @with_interface_lock
            def stop_network(self):
                self.print_error("stopping network")
                for interface in list(self.interfaces.values()):
       t@@ -491,6 +526,7 @@ class Network(util.DaemonThread):
                if servers:
                    self.switch_to_interface(random.choice(servers))
        
       +    @with_interface_lock
            def switch_lagging_interface(self):
                '''If auto_connect and lagging, switch interface'''
                if self.server_is_lagging() and self.auto_connect:
       t@@ -501,6 +537,7 @@ class Network(util.DaemonThread):
                        choice = random.choice(filtered)
                        self.switch_to_interface(choice)
        
       +    @with_interface_lock
            def switch_to_interface(self, server):
                '''Switch to server as our interface.  If no connection exists nor
                being opened, start a thread to connect.  The actual switch will
       t@@ -522,6 +559,7 @@ class Network(util.DaemonThread):
                    self.set_status('connected')
                    self.notify('updated')
        
       +    @with_interface_lock
            def close_interface(self, interface):
                if interface:
                    if interface.server in self.interfaces:
       t@@ -530,6 +568,7 @@ class Network(util.DaemonThread):
                        self.interface = None
                    interface.close()
        
       +    @with_recent_servers_lock
            def add_recent_server(self, server):
                # list is ordered
                if server in self.recent_servers:
       t@@ -587,7 +626,8 @@ class Network(util.DaemonThread):
                for callback in callbacks:
                    callback(response)
        
       -    def get_index(self, method, params):
       +    @classmethod
       +    def get_index(cls, method, params):
                """ hashable index for subscriptions and cache"""
                return str(method) + (':' + str(params[0]) if params else '')
        
       t@@ -602,12 +642,15 @@ class Network(util.DaemonThread):
                        # and are placed in the unanswered_requests dictionary
                        client_req = self.unanswered_requests.pop(message_id, None)
                        if client_req:
       -                    assert interface == self.interface
       +                    if interface != self.interface:
       +                        # we probably changed the current interface
       +                        # in the meantime; drop this.
       +                        return
                            callbacks = [client_req[2]]
                        else:
                            # fixme: will only work for subscriptions
                            k = self.get_index(method, params)
       -                    callbacks = self.subscriptions.get(k, [])
       +                    callbacks = list(self.subscriptions.get(k, []))
        
                        # Copy the request method and params to the response
                        response['method'] = method
       t@@ -615,7 +658,8 @@ class Network(util.DaemonThread):
                        # Only once we've received a response to an addr subscription
                        # add it to the list; avoids double-sends on reconnection
                        if method == 'blockchain.scripthash.subscribe':
       -                    self.subscribed_addresses.add(params[0])
       +                    with self.subscribed_addresses_lock:
       +                        self.subscribed_addresses.add(params[0])
                    else:
                        if not response:  # Closed remotely / misbehaving
                            self.connection_down(interface.server)
       t@@ -630,27 +674,29 @@ class Network(util.DaemonThread):
                        elif method == 'blockchain.scripthash.subscribe':
                            response['params'] = [params[0]]  # addr
                            response['result'] = params[1]
       -                callbacks = self.subscriptions.get(k, [])
       +                callbacks = list(self.subscriptions.get(k, []))
        
                    # update cache if it's a subscription
                    if method.endswith('.subscribe'):
       -                self.sub_cache[k] = response
       +                with self.interface_lock:
       +                    self.sub_cache[k] = response
                    # Response is now in canonical form
                    self.process_response(interface, response, callbacks)
        
            def send(self, messages, callback):
                '''Messages is a list of (method, params) tuples'''
                messages = list(messages)
       -        with self.lock:
       +        with self.pending_sends_lock:
                    self.pending_sends.append((messages, callback))
        
       +    @with_interface_lock
            def process_pending_sends(self):
                # Requests needs connectivity.  If we don't have an interface,
                # we cannot process them.
                if not self.interface:
                    return
        
       -        with self.lock:
       +        with self.pending_sends_lock:
                    sends = self.pending_sends
                    self.pending_sends = []
        
       t@@ -660,10 +706,11 @@ class Network(util.DaemonThread):
                        if method.endswith('.subscribe'):
                            k = self.get_index(method, params)
                            # add callback to list
       -                    l = self.subscriptions.get(k, [])
       +                    l = list(self.subscriptions.get(k, []))
                            if callback not in l:
                                l.append(callback)
       -                    self.subscriptions[k] = l
       +                    with self.callback_lock:
       +                        self.subscriptions[k] = l
                            # check cached response for subscriptions
                            r = self.sub_cache.get(k)
        
       t@@ -679,11 +726,12 @@ class Network(util.DaemonThread):
                # Note: we can't unsubscribe from the server, so if we receive
                # subsequent notifications process_response() will emit a harmless
                # "received unexpected notification" warning
       -        with self.lock:
       +        with self.callback_lock:
                    for v in self.subscriptions.values():
                        if callback in v:
                            v.remove(callback)
        
       +    @with_interface_lock
            def connection_down(self, server):
                '''A connection to server either went down, or was never made.
                We distinguish by whether it is in self.interfaces.'''
       t@@ -693,9 +741,10 @@ class Network(util.DaemonThread):
                if server in self.interfaces:
                    self.close_interface(self.interfaces[server])
                    self.notify('interfaces')
       -        for b in self.blockchains.values():
       -            if b.catch_up == server:
       -                b.catch_up = None
       +        with self.blockchains_lock:
       +            for b in self.blockchains.values():
       +                if b.catch_up == server:
       +                    b.catch_up = None
        
            def new_interface(self, server, socket):
                # todo: get tip first, then decide which checkpoint to use.
       t@@ -706,7 +755,8 @@ class Network(util.DaemonThread):
                interface.tip = 0
                interface.mode = 'default'
                interface.request = None
       -        self.interfaces[server] = interface
       +        with self.interface_lock:
       +            self.interfaces[server] = interface
                # server.version should be the first message
                params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
                self.queue_request('server.version', params, interface)
       t@@ -729,7 +779,9 @@ class Network(util.DaemonThread):
        
                # Send pings and shut down stale interfaces
                # must use copy of values
       -        for interface in list(self.interfaces.values()):
       +        with self.interface_lock:
       +            interfaces = list(self.interfaces.values())
       +        for interface in interfaces:
                    if interface.has_timed_out():
                        self.connection_down(interface.server)
                    elif interface.ping_required():
       t@@ -737,28 +789,30 @@ class Network(util.DaemonThread):
        
                now = time.time()
                # nodes
       -        if len(self.interfaces) + len(self.connecting) < self.num_server:
       -            self.start_random_interface()
       -            if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
       -                self.print_error('network: retrying connections')
       -                self.disconnected_servers = set([])
       -                self.nodes_retry_time = now
       +        with self.interface_lock:
       +            if len(self.interfaces) + len(self.connecting) < self.num_server:
       +                self.start_random_interface()
       +                if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
       +                    self.print_error('network: retrying connections')
       +                    self.disconnected_servers = set([])
       +                    self.nodes_retry_time = now
        
                # main interface
       -        if not self.is_connected():
       -            if self.auto_connect:
       -                if not self.is_connecting():
       -                    self.switch_to_random_interface()
       -            else:
       -                if self.default_server in self.disconnected_servers:
       -                    if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
       -                        self.disconnected_servers.remove(self.default_server)
       -                        self.server_retry_time = now
       +        with self.interface_lock:
       +            if not self.is_connected():
       +                if self.auto_connect:
       +                    if not self.is_connecting():
       +                        self.switch_to_random_interface()
                        else:
       -                    self.switch_to_interface(self.default_server)
       -        else:
       -            if self.config.is_fee_estimates_update_required():
       -                self.request_fee_estimates()
       +                    if self.default_server in self.disconnected_servers:
       +                        if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
       +                            self.disconnected_servers.remove(self.default_server)
       +                            self.server_retry_time = now
       +                    else:
       +                        self.switch_to_interface(self.default_server)
       +            else:
       +                if self.config.is_fee_estimates_update_required():
       +                    self.request_fee_estimates()
        
            def request_chunk(self, interface, index):
                if index in self.requested_chunks:
       t@@ -876,7 +930,8 @@ class Network(util.DaemonThread):
                            if bh > interface.good:
                                if not interface.blockchain.check_header(interface.bad_header):
                                    b = interface.blockchain.fork(interface.bad_header)
       -                            self.blockchains[interface.bad] = b
       +                            with self.blockchains_lock:
       +                                self.blockchains[interface.bad] = b
                                    interface.blockchain = b
                                    interface.print_error("new chain", b.checkpoint)
                                    interface.mode = 'catch_up'
       t@@ -928,7 +983,9 @@ class Network(util.DaemonThread):
                self.notify('interfaces')
        
            def maintain_requests(self):
       -        for interface in list(self.interfaces.values()):
       +        with self.interface_lock:
       +            interfaces = list(self.interfaces.values())
       +        for interface in interfaces:
                    if interface.request and time.time() - interface.request_time > 20:
                        interface.print_error("blockchain request timed out")
                        self.connection_down(interface.server)
       t@@ -940,14 +997,14 @@ class Network(util.DaemonThread):
                if not self.interfaces:
                    time.sleep(0.1)
                    return
       -        rin = [i for i in self.interfaces.values()]
       -        win = [i for i in self.interfaces.values() if i.num_requests()]
       +        with self.interface_lock:
       +            interfaces = list(self.interfaces.values())
       +        rin = [i for i in interfaces]
       +        win = [i for i in interfaces if i.num_requests()]
                try:
                    rout, wout, xout = select.select(rin, win, [], 0.1)
                except socket.error as e:
       -            # TODO: py3, get code from e
       -            code = None
       -            if code == errno.EINTR:
       +            if e.errno == errno.EINTR:
                        return
                    raise
                assert not xout
       t@@ -1004,7 +1061,8 @@ class Network(util.DaemonThread):
                    self.notify('updated')
                    self.notify('interfaces')
                    return
       -        tip = max([x.height() for x in self.blockchains.values()])
       +        with self.blockchains_lock:
       +            tip = max([x.height() for x in self.blockchains.values()])
                if tip >=0:
                    interface.mode = 'backward'
                    interface.bad = height
       t@@ -1016,19 +1074,24 @@ class Network(util.DaemonThread):
                        chain.catch_up = interface
                        interface.mode = 'catch_up'
                        interface.blockchain = chain
       -                self.print_error("switching to catchup mode", tip,  self.blockchains)
       +                with self.blockchains_lock:
       +                    self.print_error("switching to catchup mode", tip,  self.blockchains)
                        self.request_header(interface, 0)
                    else:
                        self.print_error("chain already catching up with", chain.catch_up.server)
        
       +    @with_interface_lock
            def blockchain(self):
                if self.interface and self.interface.blockchain is not None:
                    self.blockchain_index = self.interface.blockchain.checkpoint
                return self.blockchains[self.blockchain_index]
        
       +    @with_interface_lock
            def get_blockchains(self):
                out = {}
       -        for k, b in self.blockchains.items():
       +        with self.blockchains_lock:
       +            blockchain_items = list(self.blockchains.items())
       +        for k, b in blockchain_items:
                    r = list(filter(lambda i: i.blockchain==b, list(self.interfaces.values())))
                    if r:
                        out[k] = r
       t@@ -1039,18 +1102,21 @@ class Network(util.DaemonThread):
                if blockchain:
                    self.blockchain_index = index
                    self.config.set_key('blockchain_index', index)
       -            for i in self.interfaces.values():
       +            with self.interface_lock:
       +                interfaces = list(self.interfaces.values())
       +            for i in interfaces:
                        if i.blockchain == blockchain:
                            self.switch_to_interface(i.server)
                            break
                else:
                    raise Exception('blockchain not found', index)
        
       -        if self.interface:
       -            server = self.interface.server
       -            host, port, protocol, proxy, auto_connect = self.get_parameters()
       -            host, port, protocol = server.split(':')
       -            self.set_parameters(host, port, protocol, proxy, auto_connect)
       +        with self.interface_lock:
       +            if self.interface:
       +                server = self.interface.server
       +                host, port, protocol, proxy, auto_connect = self.get_parameters()
       +                host, port, protocol = server.split(':')
       +                self.set_parameters(host, port, protocol, proxy, auto_connect)
        
            def get_local_height(self):
                return self.blockchain().height()
       t@@ -1189,5 +1255,6 @@ class Network(util.DaemonThread):
                with open(path, 'w', encoding='utf-8') as f:
                    f.write(json.dumps(cp, indent=4))
        
       -    def max_checkpoint(self):
       +    @classmethod
       +    def max_checkpoint(cls):
                return max(0, len(constants.net.CHECKPOINTS) * 2016 - 1)