URI: 
       tbasic_mpp: receive multi-part payments - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit ef5a26544981a426945b1d32b22b170272cb0671
   DIR parent c0bf9b4509cfdeace824a9054d9c90449ef46ad6
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Wed, 27 Jan 2021 19:27:06 +0100
       
       basic_mpp: receive multi-part payments
       
       Diffstat:
         M electrum/lnchannel.py               |       6 +-----
         M electrum/lnonion.py                 |       1 +
         M electrum/lnpeer.py                  |      37 +++++++++++++++++++------------
         M electrum/lnworker.py                |      27 ++++++++++++++++++++++++---
         M electrum/tests/test_lnpeer.py       |       2 ++
       
       5 files changed, 51 insertions(+), 22 deletions(-)
       ---
   DIR diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
       t@@ -969,10 +969,6 @@ class Channel(AbstractChannel):
                    raise Exception("refusing to revoke as remote sig does not fit")
                with self.db_lock:
                    self.hm.send_rev()
       -        if self.lnworker:
       -            received = self.hm.received_in_ctn(new_ctn)
       -            for htlc in received:
       -                self.lnworker.payment_received(self, htlc.payment_hash)
                last_secret, last_point = self.get_secret_and_point(LOCAL, new_ctn - 1)
                next_secret, next_point = self.get_secret_and_point(LOCAL, new_ctn + 1)
                return RevokeAndAck(last_secret, next_point)
       t@@ -1054,7 +1050,7 @@ class Channel(AbstractChannel):
                    if is_sent:
                        self.lnworker.payment_sent(self, payment_hash)
                    else:
       -                self.lnworker.payment_received(self, payment_hash)
       +                self.lnworker.payment_received(payment_hash)
        
            def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int:
                assert type(whose) is HTLCOwner
   DIR diff --git a/electrum/lnonion.py b/electrum/lnonion.py
       t@@ -498,6 +498,7 @@ class OnionFailureCode(IntEnum):
            CHANNEL_DISABLED =                        UPDATE | 20
            EXPIRY_TOO_FAR =                          21
            INVALID_ONION_PAYLOAD =                   PERM | 22
       +    MPP_TIMEOUT =                             23
        
        
        # don't use these elsewhere, the names are ambiguous without context
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -1389,10 +1389,6 @@ class Peer(Logger):
                        reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
                        return None, reason
                expected_received_msat = info.amount_msat
       -        if expected_received_msat is not None and \
       -                not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat):
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
       -            return None, reason
                # Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height.
                # We should not release the preimage for an HTLC that its sender could already time out as
                # then they might try to force-close and it becomes a race.
       t@@ -1415,20 +1411,34 @@ class Peer(Logger):
                        data=htlc.cltv_expiry.to_bytes(4, byteorder="big"))
                    return None, reason
                try:
       -            amount_from_onion = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
       +            amt_to_forward = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
                except:
                    reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
                    return None, reason
                try:
       -            amount_from_onion = processed_onion.hop_data.payload["payment_data"]["total_msat"]
       +            total_msat = processed_onion.hop_data.payload["payment_data"]["total_msat"]
                except:
       -            pass  # fall back to "amt_to_forward"
       -        if amount_from_onion > htlc.amount_msat:
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
       -                                                data=htlc.amount_msat.to_bytes(8, byteorder="big"))
       +            total_msat = amt_to_forward # fall back to "amt_to_forward"
       +
       +        if amt_to_forward != htlc.amount_msat:
       +            reason = OnionRoutingFailureMessage(
       +                code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
       +                data=total_msat.to_bytes(8, byteorder="big"))
                    return None, reason
       -        # all good
       -        return preimage, None
       +        if expected_received_msat is None:
       +            return preimage, None
       +        if not (expected_received_msat <= total_msat <= 2 * expected_received_msat):
       +            reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
       +            return None, reason
       +        accepted, expired = self.lnworker.htlc_received(chan.short_channel_id, htlc, expected_received_msat)
       +        if accepted:
       +            return preimage, None
       +        elif expired:
       +            reason = OnionRoutingFailureMessage(code=OnionFailureCode.MPP_TIMEOUT)
       +            return None, reason
       +        else:
       +            # waiting for more htlcs
       +            return None, None
        
            def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
                self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
       t@@ -1669,7 +1679,7 @@ class Peer(Logger):
                        for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items():
                            if not chan.hm.is_add_htlc_irrevocably_committed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
                                continue
       -                    chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
       +                    #chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
                            htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id)
                            payment_hash = htlc.payment_hash
                            error_reason = None  # type: Optional[OnionRoutingFailureMessage]
       t@@ -1694,7 +1704,6 @@ class Peer(Logger):
                                    error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_VERSION, data=sha256(onion_packet_bytes))
                                if self.network.config.get('test_fail_htlcs_with_temp_node_failure'):
                                    error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
       -
                            if not error_reason:
                                if processed_onion.are_we_final:
                                    preimage, error_reason = self.maybe_fulfill_htlc(
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -86,6 +86,7 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID] # status that are persisted
        
        
        NUM_PEERS_TARGET = 4
       +MPP_EXPIRY = 120
        
        
        FALLBACK_NODE_LIST_TESTNET = (
       t@@ -164,7 +165,8 @@ BASE_FEATURES = LnFeatures(0)\
        LNWALLET_FEATURES = BASE_FEATURES\
            | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ\
            | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ\
       -    | LnFeatures.GOSSIP_QUERIES_REQ
       +    | LnFeatures.GOSSIP_QUERIES_REQ\
       +    | LnFeatures.BASIC_MPP_OPT
        
        LNGOSSIP_FEATURES = BASE_FEATURES\
            | LnFeatures.GOSSIP_QUERIES_OPT\
       t@@ -581,6 +583,7 @@ class LNWallet(LNWorker):
                    self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
        
                self.pending_payments = defaultdict(asyncio.Future)  # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]]
       +        self.pending_htlcs = defaultdict(set) # type: Dict[bytes, set]
        
                self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
                # detect inflight payments
       t@@ -1284,6 +1287,24 @@ class LNWallet(LNWorker):
                    self.payments[key] = info.amount_msat, info.direction, info.status
                self.wallet.save_db()
        
       +    def htlc_received(self, short_channel_id, htlc, expected_msat):
       +        status = self.get_payment_status(htlc.payment_hash)
       +        if status == PR_PAID:
       +            return True, None
       +        s = self.pending_htlcs[htlc.payment_hash]
       +        if (short_channel_id, htlc) not in s:
       +            s.add((short_channel_id, htlc))
       +        total = sum([htlc.amount_msat for scid, htlc in s])
       +        first_timestamp = min([htlc.timestamp for scid, htlc in s])
       +        expired = time.time() - first_timestamp > MPP_EXPIRY
       +        if total >= expected_msat and not expired:
       +            # status must be persisted
       +            self.payment_received(htlc.payment_hash)
       +            return True, None
       +        if expired:
       +            return None, True
       +        return None, None
       +
            def get_payment_status(self, payment_hash):
                info = self.get_payment_info(payment_hash)
                return info.status if info else PR_UNPAID
       t@@ -1359,10 +1380,10 @@ class LNWallet(LNWorker):
                    util.trigger_callback('payment_succeeded', self.wallet, key)
                util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
        
       -    def payment_received(self, chan, payment_hash: bytes):
       +    def payment_received(self, payment_hash: bytes):
                self.set_payment_status(payment_hash, PR_PAID)
                util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID)
       -        util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
       +        #util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
        
            async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
                """calculate routing hints (BOLT-11 'r' field)"""
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -132,6 +132,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                # used in tests
                self.enable_htlc_settle = asyncio.Event()
                self.enable_htlc_settle.set()
       +        self.pending_htlcs = defaultdict(set)
        
            def get_invoice_status(self, key):
                pass
       t@@ -167,6 +168,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
            set_invoice_status = LNWallet.set_invoice_status
            set_payment_status = LNWallet.set_payment_status
            get_payment_status = LNWallet.get_payment_status
       +    htlc_received = LNWallet.htlc_received
            await_payment = LNWallet.await_payment
            payment_received = LNWallet.payment_received
            payment_sent = LNWallet.payment_sent