URI: 
       tlnworker: fix threading issues for .channels attribute - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit b9b53e7f76eb054773a9a495593396b6b18af85a
   DIR parent f5eb91900ab0a1c3bb776d14f417e566932e596e
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Thu, 30 Apr 2020 21:08:26 +0200
       
       lnworker: fix threading issues for .channels attribute
       
       external code (commands/gui) did not always take lock when iterating lnworker.channels.
       instead of exposing lock, let's take a copy internally (as with .peers)
       
       Diffstat:
         M electrum/lnpeer.py                  |       4 +++-
         M electrum/lnworker.py                |      73 ++++++++++++++-----------------
         M electrum/tests/test_lnpeer.py       |      10 +++++++---
       
       3 files changed, 43 insertions(+), 44 deletions(-)
       ---
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -107,7 +107,7 @@ class Peer(Logger):
                if not (message_name.startswith("update_") or is_commitment_signed):
                    return
                assert channel_id
       -        chan = self.lnworker.channels[channel_id]  # type: Channel
       +        chan = self.channels[channel_id]
                chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
                if is_commitment_signed:
                    # saving now, to ensure replaying updates works (in case of channel reestablishment)
       t@@ -139,6 +139,8 @@ class Peer(Logger):
        
            @property
            def channels(self) -> Dict[bytes, Channel]:
       +        # FIXME this iterates over all channels in lnworker,
       +        #       so if we just want to lookup a channel by channel_id, it's wasteful
                return self.lnworker.channels_for_peer(self.pubkey)
        
            def diagnostic_name(self):
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -491,22 +491,26 @@ class LNWallet(LNWorker):
                self.enable_htlc_settle.set()
        
                # note: accessing channels (besides simple lookup) needs self.lock!
       -        self.channels = {}
       +        self._channels = {}  # type: Dict[bytes, Channel]
                channels = self.db.get_dict("channels")
                for channel_id, c in channels.items():
       -            self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
       +            self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
        
                self.pending_payments = defaultdict(asyncio.Future)  # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]]
        
       +    @property
       +    def channels(self) -> Mapping[bytes, Channel]:
       +        """Returns a read-only copy of channels."""
       +        with self.lock:
       +            return self._channels.copy()
       +
            @ignore_exceptions
            @log_exceptions
            async def sync_with_local_watchtower(self):
                watchtower = self.network.local_watchtower
                if watchtower:
                    while True:
       -                with self.lock:
       -                    channels = list(self.channels.values())
       -                for chan in channels:
       +                for chan in self.channels.values():
                            await self.sync_channel_with_watchtower(chan, watchtower.sweepstore)
                        await asyncio.sleep(5)
        
       t@@ -524,12 +528,10 @@ class LNWallet(LNWorker):
                    watchtower_url = self.config.get('watchtower_url')
                    if not watchtower_url:
                        continue
       -            with self.lock:
       -                channels = list(self.channels.values())
                    try:
                        async with make_aiohttp_session(proxy=self.network.proxy) as session:
                            watchtower = myAiohttpClient(session, watchtower_url)
       -                    for chan in channels:
       +                    for chan in self.channels.values():
                                await self.sync_channel_with_watchtower(chan, watchtower)
                    except aiohttp.client_exceptions.ClientConnectorError:
                        self.logger.info(f'could not contact remote watchtower {watchtower_url}')
       t@@ -574,9 +576,7 @@ class LNWallet(LNWorker):
                # return one item per payment_hash
                # note: with AMP we will have several channels per payment
                out = defaultdict(list)
       -        with self.lock:
       -            channels = list(self.channels.values())
       -        for chan in channels:
       +        for chan in self.channels.values():
                    d = chan.get_settled_payments()
                    for k, v in d.items():
                        out[k].append(v)
       t@@ -628,9 +628,7 @@ class LNWallet(LNWorker):
            def get_onchain_history(self):
                out = {}
                # add funding events
       -        with self.lock:
       -            channels = list(self.channels.values())
       -        for chan in channels:
       +        for chan in self.channels.values():
                    item = chan.get_funding_height()
                    if item is None:
                        continue
       t@@ -693,8 +691,7 @@ class LNWallet(LNWorker):
        
            def channels_for_peer(self, node_id):
                assert type(node_id) is bytes
       -        with self.lock:
       -            return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
       +        return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
        
            def channel_state_changed(self, chan):
                self.save_channel(chan)
       t@@ -708,9 +705,7 @@ class LNWallet(LNWorker):
                util.trigger_callback('channel', chan)
        
            def channel_by_txo(self, txo):
       -        with self.lock:
       -            channels = list(self.channels.values())
       -        for chan in channels:
       +        for chan in self.channels.values():
                    if chan.funding_outpoint.to_str() == txo:
                        return chan
        
       t@@ -762,7 +757,7 @@ class LNWallet(LNWorker):
        
            def add_channel(self, chan):
                with self.lock:
       -            self.channels[chan.channel_id] = chan
       +            self._channels[chan.channel_id] = chan
                self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
        
            def add_new_channel(self, chan):
       t@@ -805,10 +800,9 @@ class LNWallet(LNWorker):
                success = fut.result()
        
            def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel:
       -        with self.lock:
       -            for chan in self.channels.values():
       -                if chan.short_channel_id == short_channel_id:
       -                    return chan
       +        for chan in self.channels.values():
       +            if chan.short_channel_id == short_channel_id:
       +                return chan
        
            async def _pay(self, invoice, amount_sat=None, attempts=1) -> bool:
                lnaddr = self._check_invoice(invoice, amount_sat)
       t@@ -981,8 +975,7 @@ class LNWallet(LNWorker):
                # if there are multiple hints, we will use the first one that works,
                # from a random permutation
                random.shuffle(r_tags)
       -        with self.lock:
       -            channels = list(self.channels.values())
       +        channels = list(self.channels.values())
                scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
                                       if chan.short_channel_id is not None}
                for private_route in r_tags:
       t@@ -1196,8 +1189,7 @@ class LNWallet(LNWorker):
            async def _calc_routing_hints_for_invoice(self, amount_sat: Optional[int]):
                """calculate routing hints (BOLT-11 'r' field)"""
                routing_hints = []
       -        with self.lock:
       -            channels = list(self.channels.values())
       +        channels = list(self.channels.values())
                scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
                                       if chan.short_channel_id is not None}
                ignore_min_htlc_value = False
       t@@ -1251,24 +1243,27 @@ class LNWallet(LNWorker):
        
            def get_balance(self):
                with self.lock:
       -            return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000
       +            return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0
       +                               for chan in self.channels.values())) / 1000
        
            def num_sats_can_send(self) -> Union[Decimal, int]:
                with self.lock:
       -            return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0
       +            return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0
       +                               for chan in self.channels.values()))/1000 if self.channels else 0
        
            def num_sats_can_receive(self) -> Union[Decimal, int]:
                with self.lock:
       -            return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0
       +            return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0
       +                               for chan in self.channels.values()))/1000 if self.channels else 0
        
            async def close_channel(self, chan_id):
       -        chan = self.channels[chan_id]
       +        chan = self._channels[chan_id]
                peer = self._peers[chan.node_id]
                return await peer.close_channel(chan_id)
        
            async def force_close_channel(self, chan_id):
                # returns txid or raises
       -        chan = self.channels[chan_id]
       +        chan = self._channels[chan_id]
                tx = chan.force_close_tx()
                await self.network.broadcast_transaction(tx)
                chan.set_state(ChannelState.FORCE_CLOSING)
       t@@ -1276,16 +1271,16 @@ class LNWallet(LNWorker):
        
            async def try_force_closing(self, chan_id):
                # fails silently but sets the state, so that we will retry later
       -        chan = self.channels[chan_id]
       +        chan = self._channels[chan_id]
                tx = chan.force_close_tx()
                chan.set_state(ChannelState.FORCE_CLOSING)
                await self.network.try_broadcasting(tx, 'force-close')
        
            def remove_channel(self, chan_id):
       -        chan = self.channels[chan_id]
       +        chan = self._channels[chan_id]
                assert chan.get_state() == ChannelState.REDEEMED
                with self.lock:
       -            self.channels.pop(chan_id)
       +            self._channels.pop(chan_id)
                    self.db.get('channels').pop(chan_id.hex())
        
                util.trigger_callback('channels_updated', self.wallet)
       t@@ -1316,9 +1311,7 @@ class LNWallet(LNWorker):
            async def reestablish_peers_and_channels(self):
                while True:
                    await asyncio.sleep(1)
       -            with self.lock:
       -                channels = list(self.channels.values())
       -            for chan in channels:
       +            for chan in self.channels.values():
                        if chan.is_closed():
                            continue
                        # reestablish
       t@@ -1340,7 +1333,7 @@ class LNWallet(LNWorker):
                return max(253, feerate_per_kvbyte // 4)
        
            def create_channel_backup(self, channel_id):
       -        chan = self.channels[channel_id]
       +        chan = self._channels[channel_id]
                peer_addresses = list(chan.get_peer_addresses())
                peer_addr = peer_addresses[0]
                return ChannelBackupStorage(
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -102,7 +102,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                self.remote_keypair = remote_keypair
                self.node_keypair = local_keypair
                self.network = MockNetwork(tx_queue)
       -        self.channels = {chan.channel_id: chan}
       +        self._channels = {chan.channel_id: chan}
                self.payments = {}
                self.logs = defaultdict(list)
                self.wallet = MockWallet()
       t@@ -123,6 +123,10 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                return noop_lock()
        
            @property
       +    def channels(self):
       +        return self._channels
       +
       +    @property
            def peers(self):
                return self._peers
        
       t@@ -131,11 +135,11 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                return {self.remote_keypair.pubkey: self.peer}
        
            def channels_for_peer(self, pubkey):
       -        return self.channels
       +        return self._channels
        
            def get_channel_by_short_id(self, short_channel_id):
                with self.lock:
       -            for chan in self.channels.values():
       +            for chan in self._channels.values():
                        if chan.short_channel_id == short_channel_id:
                            return chan