URI: 
       tlnchan: use NamedTuple for logs instead of dict with static keys (adds, locked_in, settles, fails) - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 39fa13b93861d324406ac58cf54047dea961ce5d
   DIR parent 72187a43416831306fd13b068b7341a3d0c05003
  HTML Author: Janus <ysangkok@gmail.com>
       Date:   Fri, 26 Oct 2018 18:46:33 +0200
       
       lnchan: use NamedTuple for logs instead of dict with static keys (adds, locked_in, settles, fails)
       
       Diffstat:
         M electrum/lnchan.py                  |     101 ++++++++++++++++++++-----------
         M electrum/tests/test_lnchan.py       |       6 +++++-
       
       2 files changed, 69 insertions(+), 38 deletions(-)
       ---
   DIR diff --git a/electrum/lnchan.py b/electrum/lnchan.py
       t@@ -26,7 +26,7 @@ from collections import namedtuple, defaultdict
        import binascii
        import json
        from enum import Enum, auto
       -from typing import Optional, Dict, List, Tuple
       +from typing import Optional, Dict, List, Tuple, NamedTuple, Set
        from copy import deepcopy
        
        from .util import bfh, PrintError, bh2u
       t@@ -121,6 +121,20 @@ def str_bytes_dict_from_save(x):
        def str_bytes_dict_to_save(x):
            return {str(k): bh2u(v) for k, v in x.items()}
        
       +class HtlcChanges(NamedTuple):
       +    # ints are htlc ids
       +    adds: Dict[int, UpdateAddHtlc]
       +    settles: Set[int]
       +    fails: Set[int]
       +    locked_in: Set[int]
       +
       +    @staticmethod
       +    def new():
       +        """
       +        Since we can't use default arguments for these types (they would be shared among instances)
       +        """
       +        return HtlcChanges({}, set(), set(), set())
       +
        class Channel(PrintError):
            def diagnostic_name(self):
                if self.name:
       t@@ -158,18 +172,12 @@ class Channel(PrintError):
                # any past commitment transaction and use that instead; until then...
                self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
        
       -        template = lambda: {
       -            'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
       -            'settles': [], # List[HTLC_ID]
       -            'fails': [], # List[HTLC_ID]
       -            'locked_in': [], # List[HTLC_ID]
       -        }
       -        self.log = {LOCAL: template(), REMOTE: template()}
       +        self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()}
                for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
                    if strname not in state: continue
                    for y in state[strname]:
                        htlc = UpdateAddHtlc(**y)
       -                self.log[subject]['adds'][htlc.htlc_id] = htlc
       +                self.log[subject].adds[htlc.htlc_id] = htlc
        
                self.name = name
        
       t@@ -185,6 +193,9 @@ class Channel(PrintError):
        
                self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
        
       +        for sub in (LOCAL, REMOTE):
       +            self.log[sub].locked_in.update(self.log[sub].adds.keys())
       +
            def set_state(self, state: str):
                self._state = state
        
       t@@ -232,7 +243,7 @@ class Channel(PrintError):
                assert type(htlc) is dict
                self._check_can_pay(htlc['amount_msat'])
                htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
       -        self.log[LOCAL]['adds'][htlc.htlc_id] = htlc
       +        self.log[LOCAL].adds[htlc.htlc_id] = htlc
                self.print_error("add_htlc")
                self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
                return htlc.htlc_id
       t@@ -251,7 +262,7 @@ class Channel(PrintError):
                    raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
                            f' Available at remote: {self.available_to_spend(REMOTE)},' +\
                            f' HTLC amount: {htlc.amount_msat}')
       -        adds = self.log[REMOTE]['adds']
       +        adds = self.log[REMOTE].adds
                adds[htlc.htlc_id] = htlc
                self.print_error("receive_htlc")
                self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
       t@@ -309,11 +320,11 @@ class Channel(PrintError):
                for sub in (LOCAL, REMOTE):
                    log = self.log[sub]
                    yield (sub, deepcopy(log))
       -            for htlc_id in log['fails']:
       -                log['adds'].pop(htlc_id)
       -            log['fails'].clear()
       +            for htlc_id in log.fails:
       +                log.adds.pop(htlc_id)
       +            log.fails.clear()
        
       -        self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys()
       +        self.log[subject].locked_in.update(self.log[subject].adds.keys())
        
            def receive_new_commitment(self, sig, htlc_sigs):
                """
       t@@ -474,11 +485,11 @@ class Channel(PrintError):
                    """
                    old_amount = htlcsum(self.htlcs(subject, False))
        
       -            for htlc_id in self.log[subject]['settles']:
       -                adds = self.log[subject]['adds']
       +            for htlc_id in self.log[subject].settles:
       +                adds = self.log[subject].adds
                        htlc = adds.pop(htlc_id)
                        self.settled[subject].append(htlc.amount_msat)
       -            self.log[subject]['settles'].clear()
       +            self.log[subject].settles.clear()
        
                    return old_amount - htlcsum(self.htlcs(subject, False))
        
       t@@ -533,7 +544,7 @@ class Channel(PrintError):
                pending outgoing HTLCs, is used in the UI.
                """
                return self.balance(subject)\
       -                - htlcsum(self.log[subject]['adds'].values())
       +                - htlcsum(self.log[subject].adds.values())
        
            def available_to_spend(self, subject):
                """
       t@@ -541,7 +552,7 @@ class Channel(PrintError):
                not be used in the UI cause it fluctuates (commit fee)
                """
                return self.balance_minus_outgoing_htlcs(subject)\
       -                - htlcsum(self.log[subject]['adds'].values())\
       +                - htlcsum(self.log[subject].adds.values())\
                        - self.config[-subject].reserve_sat * 1000\
                        - calc_onchain_fees(
                              # TODO should we include a potential new htlc, when we are called from receive_htlc?
       t@@ -601,10 +612,10 @@ class Channel(PrintError):
                """
                update_log = self.log[subject]
                res = []
       -        for htlc in update_log['adds'].values():
       -            locked_in = htlc.htlc_id in update_log['locked_in']
       -            settled = htlc.htlc_id in update_log['settles']
       -            failed =  htlc.htlc_id in update_log['fails']
       +        for htlc in update_log.adds.values():
       +            locked_in = htlc.htlc_id in update_log.locked_in
       +            settled = htlc.htlc_id in update_log.settles
       +            failed =  htlc.htlc_id in update_log.fails
                    if not locked_in:
                        continue
                    if only_pending == (settled or failed):
       t@@ -617,25 +628,33 @@ class Channel(PrintError):
                SettleHTLC attempts to settle an existing outstanding received HTLC.
                """
                self.print_error("settle_htlc")
       -        htlc = self.log[REMOTE]['adds'][htlc_id]
       +        log = self.log[REMOTE]
       +        htlc = log.adds[htlc_id]
                assert htlc.payment_hash == sha256(preimage)
       -        self.log[REMOTE]['settles'].append(htlc_id)
       +        assert htlc_id not in log.settles
       +        log.settles.add(htlc_id)
                # not saving preimage because it's already saved in LNWorker.invoices
        
            def receive_htlc_settle(self, preimage, htlc_id):
                self.print_error("receive_htlc_settle")
       -        htlc = self.log[LOCAL]['adds'][htlc_id]
       +        log = self.log[LOCAL]
       +        htlc = log.adds[htlc_id]
                assert htlc.payment_hash == sha256(preimage)
       -        self.log[LOCAL]['settles'].append(htlc_id)
       +        assert htlc_id not in log.settles
       +        log.settles.add(htlc_id)
                # we don't save the preimage because we don't need to forward it anyway
        
            def fail_htlc(self, htlc_id):
                self.print_error("fail_htlc")
       -        self.log[REMOTE]['fails'].append(htlc_id)
       +        log = self.log[REMOTE]
       +        assert htlc_id not in log.fails
       +        log.fails.add(htlc_id)
        
            def receive_fail_htlc(self, htlc_id):
                self.print_error("receive_fail_htlc")
       -        self.log[LOCAL]['fails'].append(htlc_id)
       +        log = self.log[LOCAL]
       +        assert htlc_id not in log.fails
       +        log.fails.add(htlc_id)
        
            @property
            def current_height(self):
       t@@ -666,8 +685,8 @@ class Channel(PrintError):
                removed = []
                htlcs = []
                log = self.log[subject]
       -        for htlc_id, i in log['adds'].items():
       -            locked_in = htlc_id in log['locked_in']
       +        for i in log.adds.values():
       +            locked_in = i.htlc_id in log.locked_in
                    if locked_in:
                        htlcs.append(i._asdict())
                    else:
       t@@ -710,18 +729,26 @@ class Channel(PrintError):
        
            def serialize(self):
                namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
       -        serialized_channel = {k: namedtuples_to_dict(v) if isinstance(v, tuple) else v for k, v in self.to_save().items()}
       +        serialized_channel = {}
       +        to_save_ref = self.to_save()
       +        for k, v in to_save_ref.items():
       +            if isinstance(v, tuple):
       +                serialized_channel[k] = namedtuples_to_dict(v)
       +            else:
       +                serialized_channel[k] = v
                dumped = ChannelJsonEncoder().encode(serialized_channel)
                roundtripped = json.loads(dumped)
                reconstructed = Channel(roundtripped)
       -        if reconstructed.to_save() != self.to_save():
       -            from pprint import pformat
       +        to_save_new = reconstructed.to_save()
       +        if to_save_new != to_save_ref:
       +            from pprint import PrettyPrinter
       +            pp = PrettyPrinter(indent=168)
                    try:
                        from deepdiff import DeepDiff
                    except ImportError:
       -                raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(reconstructed.to_save()) + "\n" + pformat(self.to_save()))
       +                raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new))
                    else:
       -                raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(DeepDiff(reconstructed.to_save(), self.to_save())))
       +                raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new)))
                return roundtripped
        
            def __str__(self):
   DIR diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py
       t@@ -183,7 +183,7 @@ class TestChannel(unittest.TestCase):
        
                self.bob_pending_remote_balance = after
        
       -        self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0]
       +        self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
        
            def test_SimpleAddSettleWorkflow(self):
                alice_channel, bob_channel = self.alice_channel, self.bob_channel
       t@@ -217,6 +217,10 @@ class TestChannel(unittest.TestCase):
                # forward since she's sending an outgoing HTLC.
                alice_channel.receive_revocation(bobRevocation)
        
       +        # test serializing with locked_in htlc
       +        self.assertEqual(len(alice_channel.to_save()['local_log']), 1)
       +        alice_channel.serialize()
       +
                # Alice then processes bob's signature, and since she just received
                # the revocation, she expect this signature to cover everything up to
                # the point where she sent her signature, including the HTLC.