URI: 
       tMPP receive: allow payer to retry after mpp timeout - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 7f61f22857e54954dfe1132ceb91f8006e055a57
   DIR parent 0ce6adffcc271a8e31705e16b7b24a0e93a80ca0
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Sat, 27 Feb 2021 11:48:14 +0100
       
       MPP receive: allow payer to retry after mpp timeout
       
       Diffstat:
         M electrum/lnpeer.py                  |       6 +++---
         M electrum/lnworker.py                |      44 ++++++++++++++++++-------------
         M electrum/tests/test_lnpeer.py       |       4 ++--
       
       3 files changed, 30 insertions(+), 24 deletions(-)
       ---
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -1576,10 +1576,10 @@ class Peer(Logger):
                invoice_msat = info.amount_msat
                if not (invoice_msat is None or invoice_msat <= total_msat <= 2 * invoice_msat):
                    raise exc_incorrect_or_unknown_pd
       -        accepted, expired = self.lnworker.htlc_received(chan.short_channel_id, htlc, total_msat)
       -        if accepted:
       +        mpp_status = self.lnworker.add_received_htlc(chan.short_channel_id, htlc, total_msat)
       +        if mpp_status == True:
                    return preimage
       -        elif expired:
       +        elif mpp_status == False:
                    raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
                else:
                    return None
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -657,7 +657,7 @@ class LNWallet(LNWorker):
                    self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
        
                self.sent_htlcs = defaultdict(asyncio.Queue)  # type: Dict[bytes, asyncio.Queue[HtlcLog]]
       -        self.received_htlcs = defaultdict(set) # type: Dict[bytes, set]
       +        self.received_htlcs = dict()                  # RHASH -> mpp_status, htlc_set
                self.htlc_routes = dict()
        
                self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
       t@@ -1682,24 +1682,30 @@ class LNWallet(LNWorker):
                    self.payments[key] = info.amount_msat, info.direction, info.status
                self.wallet.save_db()
        
       -    def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int):
       -        status = self.get_payment_status(htlc.payment_hash)
       -        if status == PR_PAID:
       -            return True, None
       -        s = self.received_htlcs[htlc.payment_hash]
       -        if (short_channel_id, htlc) not in s:
       -            s.add((short_channel_id, htlc))
       -        total = sum([htlc.amount_msat for scid, htlc in s])
       -        first_timestamp = min([htlc.timestamp for scid, htlc in s])
       -        expired = time.time() - first_timestamp > MPP_EXPIRY
       -        if total == expected_msat and not expired:
       -            # status must be persisted
       -            self.set_payment_status(htlc.payment_hash, PR_PAID)
       -            util.trigger_callback('request_status', self.wallet, htlc.payment_hash.hex(), PR_PAID)
       -            return True, None
       -        if expired:
       -            return None, True
       -        return None, None
       +    def add_received_htlc(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
       +        """ return MPP status: True (accepted), False (expired) or None """
       +        payment_hash = htlc.payment_hash
       +        mpp_status, htlc_set = self.received_htlcs.get(payment_hash, (None, set()))
       +        key = (short_channel_id, htlc)
       +        if key not in htlc_set:
       +            htlc_set.add(key)
       +        if mpp_status is None:
       +            total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
       +            first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
       +            expired = time.time() - first_timestamp > MPP_EXPIRY
       +            if expired:
       +                mpp_status = False
       +            elif total == expected_msat:
       +                mpp_status = True
       +                self.set_payment_status(payment_hash, PR_PAID)
       +                util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID)
       +        if mpp_status is not None:
       +            htlc_set.remove(key)
       +        if len(htlc_set) > 0:
       +            self.received_htlcs[payment_hash] = mpp_status, htlc_set
       +        elif payment_hash in self.received_htlcs:
       +            self.received_htlcs.pop(payment_hash)
       +        return mpp_status
        
            def get_payment_status(self, payment_hash):
                info = self.get_payment_info(payment_hash)
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -132,7 +132,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                # used in tests
                self.enable_htlc_settle = asyncio.Event()
                self.enable_htlc_settle.set()
       -        self.received_htlcs = defaultdict(set)
       +        self.received_htlcs = dict()
                self.sent_htlcs = defaultdict(asyncio.Queue)
                self.htlc_routes = defaultdict(list)
        
       t@@ -170,7 +170,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
            set_invoice_status = LNWallet.set_invoice_status
            set_payment_status = LNWallet.set_payment_status
            get_payment_status = LNWallet.get_payment_status
       -    htlc_received = LNWallet.htlc_received
       +    add_received_htlc = LNWallet.add_received_htlc
            htlc_fulfilled = LNWallet.htlc_fulfilled
            htlc_failed = LNWallet.htlc_failed
            save_preimage = LNWallet.save_preimage