URI: 
       tlnchan: make sign_next_commitment revert state - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 72187a43416831306fd13b068b7341a3d0c05003
   DIR parent 001bb4ca0983b7f420f156781c4918ffc024ecbf
  HTML Author: Janus <ysangkok@gmail.com>
       Date:   Fri, 26 Oct 2018 17:05:03 +0200
       
       lnchan: make sign_next_commitment revert state
       
       Diffstat:
         M electrum/lnchan.py                  |      80 ++++++++++++++++++-------------
         M electrum/tests/test_lnchan.py       |      17 +++++++++++++++++
       
       2 files changed, 63 insertions(+), 34 deletions(-)
       ---
   DIR diff --git a/electrum/lnchan.py b/electrum/lnchan.py
       t@@ -27,6 +27,7 @@ import binascii
        import json
        from enum import Enum, auto
        from typing import Optional, Dict, List, Tuple
       +from copy import deepcopy
        
        from .util import bfh, PrintError, bh2u
        from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS
       t@@ -79,21 +80,20 @@ class FeeUpdate(defaultdict):
                    return self.rate
                # implicit return None
        
       -class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'locked_in', 'htlc_id'])):
       +class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
       +    """
       +    This whole class body is so that if you pass a hex-string as payment_hash,
       +    it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
       +    """
            __slots__ = ()
            def __new__(cls, *args, **kwargs):
                if len(args) > 0:
                    args = list(args)
                    if type(args[1]) is str:
                        args[1] = bfh(args[1])
       -            args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()}
                    return super().__new__(cls, *args)
                if type(kwargs['payment_hash']) is str:
                    kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
       -        if 'locked_in' not in kwargs:
       -            kwargs['locked_in'] = {LOCAL: None, REMOTE: None}
       -        else:
       -            kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in'].items()}
                return super().__new__(cls, **kwargs)
        
        def decodeAll(d, local):
       t@@ -162,6 +162,7 @@ class Channel(PrintError):
                    'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
                    'settles': [], # List[HTLC_ID]
                    'fails': [], # List[HTLC_ID]
       +            'locked_in': [], # List[HTLC_ID]
                }
                self.log = {LOCAL: template(), REMOTE: template()}
                for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
       t@@ -269,7 +270,8 @@ class Channel(PrintError):
                This docstring was adapted from LND.
                """
                self.print_error("sign_next_commitment")
       -        self.lock_in_htlc_changes(LOCAL)
       +
       +        old_logs = dict(self.lock_in_htlc_changes(LOCAL))
        
                pending_remote_commitment = self.pending_remote_commitment
                sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE])
       t@@ -290,29 +292,28 @@ class Channel(PrintError):
                        htlc_sig = ecc.sig_string_from_der_sig(sig[:-1])
                        htlcsigs.append((pending_remote_commitment.htlc_output_indices[htlc.payment_hash], htlc_sig))
        
       -        for pending_fee in self.fee_mgr:
       -            if not self.constraints.is_initiator:
       -                pending_fee[FUNDEE_SIGNED] = True
       -            if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]:
       -                pending_fee[FUNDER_SIGNED] = True
       -
                self.process_new_offchain_ctx(pending_remote_commitment, ours=False)
        
                htlcsigs.sort()
                htlcsigs = [x[1] for x in htlcsigs]
        
       +        # we can't know if this message arrives.
       +        # since we shouldn't actually throw away
       +        # failed htlcs yet (or mark htlc locked in),
       +        # roll back the changes that were made
       +        self.log = old_logs
       +
                return sig_64, htlcsigs
        
            def lock_in_htlc_changes(self, subject):
                for sub in (LOCAL, REMOTE):
       -            for htlc_id in self.log[-sub]['fails']:
       -                adds = self.log[sub]['adds']
       -                htlc = adds.pop(htlc_id)
       -            self.log[-sub]['fails'].clear()
       +            log = self.log[sub]
       +            yield (sub, deepcopy(log))
       +            for htlc_id in log['fails']:
       +                log['adds'].pop(htlc_id)
       +            log['fails'].clear()
        
       -        for htlc in self.log[subject]['adds'].values():
       -            if htlc.locked_in[subject] is None:
       -                htlc.locked_in[subject] = self.config[subject].ctn
       +        self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys()
        
            def receive_new_commitment(self, sig, htlc_sigs):
                """
       t@@ -328,7 +329,9 @@ class Channel(PrintError):
                This docstring is from LND.
                """
                self.print_error("receive_new_commitment")
       -        self.lock_in_htlc_changes(REMOTE)
       +
       +        for _ in self.lock_in_htlc_changes(REMOTE): pass
       +
                assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
        
                pending_local_commitment = self.pending_local_commitment
       t@@ -443,11 +446,20 @@ class Channel(PrintError):
            def receive_revocation(self, revocation) -> Tuple[int, int]:
                self.print_error("receive_revocation")
        
       +        old_logs = dict(self.lock_in_htlc_changes(LOCAL))
       +
                cur_point = self.config[REMOTE].current_per_commitment_point
                derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
                if cur_point != derived_point:
       +            self.log = old_logs
                    raise Exception('revoked secret not for current point')
        
       +        for pending_fee in self.fee_mgr:
       +            if not self.constraints.is_initiator:
       +                pending_fee[FUNDEE_SIGNED] = True
       +            if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]:
       +                pending_fee[FUNDER_SIGNED] = True
       +
                # FIXME not sure this is correct... but it seems to work
                # if there are update_add_htlc msgs between commitment_signed and rev_ack,
                # this might break
       t@@ -462,11 +474,11 @@ class Channel(PrintError):
                    """
                    old_amount = htlcsum(self.htlcs(subject, False))
        
       -            for htlc_id in self.log[-subject]['settles']:
       +            for htlc_id in self.log[subject]['settles']:
                        adds = self.log[subject]['adds']
                        htlc = adds.pop(htlc_id)
                        self.settled[subject].append(htlc.amount_msat)
       -            self.log[-subject]['settles'].clear()
       +            self.log[subject]['settles'].clear()
        
                    return old_amount - htlcsum(self.htlcs(subject, False))
        
       t@@ -588,13 +600,12 @@ class Channel(PrintError):
                only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
                """
                update_log = self.log[subject]
       -        other_log = self.log[-subject]
                res = []
                for htlc in update_log['adds'].values():
       -            locked_in = htlc.locked_in[subject]
       -            settled = htlc.htlc_id in other_log['settles']
       -            failed =  htlc.htlc_id in other_log['fails']
       -            if locked_in is None:
       +            locked_in = htlc.htlc_id in update_log['locked_in']
       +            settled = htlc.htlc_id in update_log['settles']
       +            failed =  htlc.htlc_id in update_log['fails']
       +            if not locked_in:
                        continue
                    if only_pending == (settled or failed):
                        continue
       t@@ -608,23 +619,23 @@ class Channel(PrintError):
                self.print_error("settle_htlc")
                htlc = self.log[REMOTE]['adds'][htlc_id]
                assert htlc.payment_hash == sha256(preimage)
       -        self.log[LOCAL]['settles'].append(htlc_id)
       +        self.log[REMOTE]['settles'].append(htlc_id)
                # not saving preimage because it's already saved in LNWorker.invoices
        
            def receive_htlc_settle(self, preimage, htlc_id):
                self.print_error("receive_htlc_settle")
                htlc = self.log[LOCAL]['adds'][htlc_id]
                assert htlc.payment_hash == sha256(preimage)
       -        self.log[REMOTE]['settles'].append(htlc_id)
       +        self.log[LOCAL]['settles'].append(htlc_id)
                # we don't save the preimage because we don't need to forward it anyway
        
            def fail_htlc(self, htlc_id):
                self.print_error("fail_htlc")
       -        self.log[LOCAL]['fails'].append(htlc_id)
       +        self.log[REMOTE]['fails'].append(htlc_id)
        
            def receive_fail_htlc(self, htlc_id):
                self.print_error("receive_fail_htlc")
       -        self.log[REMOTE]['fails'].append(htlc_id)
       +        self.log[LOCAL]['fails'].append(htlc_id)
        
            @property
            def current_height(self):
       t@@ -654,8 +665,9 @@ class Channel(PrintError):
                """
                removed = []
                htlcs = []
       -        for i in self.log[subject]['adds'].values():
       -            locked_in = i.locked_in[LOCAL] is not None or i.locked_in[REMOTE] is not None
       +        log = self.log[subject]
       +        for htlc_id, i in log['adds'].items():
       +            locked_in = htlc_id in log['locked_in']
                    if locked_in:
                        htlcs.append(i._asdict())
                    else:
   DIR diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py
       t@@ -396,6 +396,23 @@ class TestChannel(unittest.TestCase):
                    self.alice_channel.add_htlc(new)
                self.assertIn('Not enough local balance', cm.exception.args[0])
        
       +    def test_sign_commitment_is_pure(self):
       +        force_state_transition(self.alice_channel, self.bob_channel)
       +        self.htlc_dict['payment_hash'] = bitcoin.sha256(b'\x02' * 32)
       +        aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict)
       +        before_signing = self.alice_channel.to_save()
       +        self.alice_channel.sign_next_commitment()
       +        after_signing = self.alice_channel.to_save()
       +        try:
       +            self.assertEqual(before_signing, after_signing)
       +        except:
       +            try:
       +                from deepdiff import DeepDiff
       +                from pprint import pformat
       +            except ImportError:
       +                raise
       +            raise Exception(pformat(DeepDiff(before_signing, after_signing)))
       +
        class TestAvailableToSpend(unittest.TestCase):
            def test_DesyncHTLCs(self):
                alice_channel, bob_channel = create_test_channels()