URI: 
       tlnhtlc: local update raw messages must not be deleted before acked - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit a27b03be6da5a9f4b397830151daf55c682beda6
   DIR parent 4fc9f243f7d1c957770589582261389a6a74b9ee
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Mon, 12 Aug 2019 18:37:13 +0200
       
       lnhtlc: local update raw messages must not be deleted before acked
       
       In recv_rev() previously all unacked_local_updates were deleted
       as it was assumed that all of them have been acked at that point by
       tthe revoke_and_ack itself. However this is not necessarily the case:
       see new test case.
       
       renamed log['unacked_local_updates'] to log['unacked_local_updates2']
       tto avoid breaking existing wallet files
       
       Diffstat:
         M electrum/lnhtlc.py                  |      33 +++++++++++++++++++++----------
         M electrum/lnpeer.py                  |      13 ++++++++-----
         M electrum/lntransport.py             |       2 +-
         M electrum/tests/test_lnhtlc.py       |      31 +++++++++++++++++++++++++++++++
       
       4 files changed, 63 insertions(+), 16 deletions(-)
       ---
   DIR diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
       t@@ -1,5 +1,5 @@
        from copy import deepcopy
       -from typing import Optional, Sequence, Tuple, List
       +from typing import Optional, Sequence, Tuple, List, Dict
        
        from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
        from .util import bh2u, bfh
       t@@ -33,9 +33,10 @@ class HTLCManager:
                        log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
                        # "side who initiated fee update" -> action -> list of FeeUpdates
                        log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
       -        if 'unacked_local_updates' not in log:
       -            log['unacked_local_updates'] = []
       -        log['unacked_local_updates'] = [bfh(upd) for upd in log['unacked_local_updates']]
       +        if 'unacked_local_updates2' not in log:
       +            log['unacked_local_updates2'] = {}
       +        log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages]
       +                                         for ctn, messages in log['unacked_local_updates2'].items()}
                # maybe bootstrap fee_updates if initial_feerate was provided
                if initial_feerate is not None:
                    assert type(initial_feerate) is int
       t@@ -74,7 +75,8 @@ class HTLCManager:
                    log[sub]['adds'] = d
                    # fee_updates
                    log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
       -        log['unacked_local_updates'] = [bh2u(upd) for upd in log['unacked_local_updates']]
       +        log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages]
       +                                         for ctn, messages in log['unacked_local_updates2'].items()}
                return log
        
            ##### Actions on channel:
       t@@ -175,7 +177,7 @@ class HTLCManager:
                    if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE):
                        fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
                # no need to keep local update raw msgs anymore, they have just been ACKed.
       -        self.log['unacked_local_updates'].clear()
       +        self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None)
        
            def discard_unsigned_remote_updates(self):
                """Discard updates sent by the remote, that the remote itself
       t@@ -200,11 +202,22 @@ class HTLCManager:
                    if fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL):
                        del self.log[REMOTE]['fee_updates'][i]
        
       -    def store_local_update_raw_msg(self, raw_update_msg: bytes):
       -        self.log['unacked_local_updates'].append(raw_update_msg)
       +    def store_local_update_raw_msg(self, raw_update_msg: bytes, *, is_commitment_signed: bool) -> None:
       +        """We need to be able to replay unacknowledged updates we sent to the remote
       +        in case of disconnections. Hence, raw update and commitment_signed messages
       +        are stored temporarily (until they are acked)."""
       +        # self.log['unacked_local_updates2'][ctn_idx] is a list of raw messages
       +        # containing some number of updates and then a single commitment_signed
       +        if is_commitment_signed:
       +            ctn_idx = self.ctn_latest(REMOTE)
       +        else:
       +            ctn_idx = self.ctn_latest(REMOTE) + 1
       +        if ctn_idx not in self.log['unacked_local_updates2']:
       +            self.log['unacked_local_updates2'][ctn_idx] = []
       +        self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg)
        
       -    def get_unacked_local_updates(self) -> Sequence[bytes]:
       -        return self.log['unacked_local_updates']
       +    def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
       +        return self.log['unacked_local_updates2']
        
            ##### Queries re HTLCs:
        
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -96,12 +96,13 @@ class Peer(Logger):
                self.transport.send_bytes(raw_msg)
        
            def _store_raw_msg_if_local_update(self, raw_msg: bytes, *, message_name: str, channel_id: Optional[bytes]):
       -        if not (message_name.startswith("update_") or message_name == "commitment_signed"):
       +        is_commitment_signed = message_name == "commitment_signed"
       +        if not (message_name.startswith("update_") or is_commitment_signed):
                    return
                assert channel_id
                chan = self.lnworker.channels[channel_id]  # type: Channel
       -        chan.hm.store_local_update_raw_msg(raw_msg)
       -        if message_name == "commitment_signed":
       +        chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
       +        if is_commitment_signed:
                    # saving now, to ensure replaying updates works (in case of channel reestablishment)
                    self.lnworker.save_channel(chan)
        
       t@@ -755,8 +756,9 @@ class Peer(Logger):
                # Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns,
                # e.g. for watchtowers, hence we must ensure these ctxs coincide.
                # We replay the local updates even if they were not yet committed.
       -        for raw_upd_msg in chan.hm.get_unacked_local_updates():
       -            self.transport.send_bytes(raw_upd_msg)
       +        for ctn, messages in chan.hm.get_unacked_local_updates():
       +            for raw_upd_msg in messages:
       +                self.transport.send_bytes(raw_upd_msg)
        
                should_close_we_are_ahead = False
                should_close_they_are_ahead = False
       t@@ -831,6 +833,7 @@ class Peer(Logger):
                    self.lnworker.force_close_channel(chan_id)
                    return
        
       +        # note: chan.short_channel_id being set implies the funding txn is already at sufficient depth
                if their_next_local_ctn == next_local_ctn == 1 and chan.short_channel_id:
                    self.send_funding_locked(chan)
                # checks done
   DIR diff --git a/electrum/lntransport.py b/electrum/lntransport.py
       t@@ -88,7 +88,7 @@ def create_ephemeral_key() -> (bytes, bytes):
        
        class LNTransportBase:
        
       -    def send_bytes(self, msg):
       +    def send_bytes(self, msg: bytes) -> None:
                l = len(msg).to_bytes(2, 'big')
                lc = aead_encrypt(self.sk, self.sn(), b'', l)
                c = aead_encrypt(self.sk, self.sn(), b'', msg)
   DIR diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py
       t@@ -211,3 +211,34 @@ class TestHTLCManager(unittest.TestCase):
                self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
                self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
                self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
       +
       +    def test_unacked_local_updates(self):
       +        A = HTLCManager()
       +        B = HTLCManager()
       +        A.channel_open_finished()
       +        B.channel_open_finished()
       +        self.assertEqual({}, A.get_unacked_local_updates())
       +
       +        ah0 = H('A', 0)
       +        B.recv_htlc(A.send_htlc(ah0))
       +        A.store_local_update_raw_msg(b"upd_msg0", is_commitment_signed=False)
       +        self.assertEqual({1: [b"upd_msg0"]}, A.get_unacked_local_updates())
       +
       +        ah1 = H('A', 1)
       +        B.recv_htlc(A.send_htlc(ah1))
       +        A.store_local_update_raw_msg(b"upd_msg1", is_commitment_signed=False)
       +        self.assertEqual({1: [b"upd_msg0", b"upd_msg1"]}, A.get_unacked_local_updates())
       +
       +        A.send_ctx()
       +        B.recv_ctx()
       +        A.store_local_update_raw_msg(b"ctx1", is_commitment_signed=True)
       +        self.assertEqual({1: [b"upd_msg0", b"upd_msg1", b"ctx1"]}, A.get_unacked_local_updates())
       +
       +        ah2 = H('A', 2)
       +        B.recv_htlc(A.send_htlc(ah2))
       +        A.store_local_update_raw_msg(b"upd_msg2", is_commitment_signed=False)
       +        self.assertEqual({1: [b"upd_msg0", b"upd_msg1", b"ctx1"], 2: [b"upd_msg2"]}, A.get_unacked_local_updates())
       +
       +        B.send_rev()
       +        A.recv_rev()
       +        self.assertEqual({2: [b"upd_msg2"]}, A.get_unacked_local_updates())