tbasic_mpp: receive multi-part payments - electrum - Electrum Bitcoin wallet HTML git clone https://git.parazyd.org/electrum DIR Log DIR Files DIR Refs DIR Submodules --- DIR commit ef5a26544981a426945b1d32b22b170272cb0671 DIR parent c0bf9b4509cfdeace824a9054d9c90449ef46ad6 HTML Author: ThomasV <thomasv@electrum.org> Date: Wed, 27 Jan 2021 19:27:06 +0100 basic_mpp: receive multi-part payments Diffstat: M electrum/lnchannel.py | 6 +----- M electrum/lnonion.py | 1 + M electrum/lnpeer.py | 37 +++++++++++++++++++------------ M electrum/lnworker.py | 27 ++++++++++++++++++++++++--- M electrum/tests/test_lnpeer.py | 2 ++ 5 files changed, 51 insertions(+), 22 deletions(-) --- DIR diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py t@@ -969,10 +969,6 @@ class Channel(AbstractChannel): raise Exception("refusing to revoke as remote sig does not fit") with self.db_lock: self.hm.send_rev() - if self.lnworker: - received = self.hm.received_in_ctn(new_ctn) - for htlc in received: - self.lnworker.payment_received(self, htlc.payment_hash) last_secret, last_point = self.get_secret_and_point(LOCAL, new_ctn - 1) next_secret, next_point = self.get_secret_and_point(LOCAL, new_ctn + 1) return RevokeAndAck(last_secret, next_point) t@@ -1054,7 +1050,7 @@ class Channel(AbstractChannel): if is_sent: self.lnworker.payment_sent(self, payment_hash) else: - self.lnworker.payment_received(self, payment_hash) + self.lnworker.payment_received(payment_hash) def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int: assert type(whose) is HTLCOwner DIR diff --git a/electrum/lnonion.py b/electrum/lnonion.py t@@ -498,6 +498,7 @@ class OnionFailureCode(IntEnum): CHANNEL_DISABLED = UPDATE | 20 EXPIRY_TOO_FAR = 21 INVALID_ONION_PAYLOAD = PERM | 22 + MPP_TIMEOUT = 23 # don't use these elsewhere, the names are ambiguous without context DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py t@@ -1389,10 +1389,6 @@ class Peer(Logger): reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') return None, reason expected_received_msat = info.amount_msat - if expected_received_msat is not None and \ - not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat): - reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') - return None, reason # Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height. # We should not release the preimage for an HTLC that its sender could already time out as # then they might try to force-close and it becomes a race. t@@ -1415,20 +1411,34 @@ class Peer(Logger): data=htlc.cltv_expiry.to_bytes(4, byteorder="big")) return None, reason try: - amount_from_onion = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"] + amt_to_forward = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"] except: reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') return None, reason try: - amount_from_onion = processed_onion.hop_data.payload["payment_data"]["total_msat"] + total_msat = processed_onion.hop_data.payload["payment_data"]["total_msat"] except: - pass # fall back to "amt_to_forward" - if amount_from_onion > htlc.amount_msat: - reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT, - data=htlc.amount_msat.to_bytes(8, byteorder="big")) + total_msat = amt_to_forward # fall back to "amt_to_forward" + + if amt_to_forward != htlc.amount_msat: + reason = OnionRoutingFailureMessage( + code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT, + data=total_msat.to_bytes(8, byteorder="big")) return None, reason - # all good - return preimage, None + if expected_received_msat is None: + return preimage, None + if not (expected_received_msat <= total_msat <= 2 * expected_received_msat): + reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') + return None, reason + accepted, expired = self.lnworker.htlc_received(chan.short_channel_id, htlc, expected_received_msat) + if accepted: + return preimage, None + elif expired: + reason = OnionRoutingFailureMessage(code=OnionFailureCode.MPP_TIMEOUT) + return None, reason + else: + # waiting for more htlcs + return None, None def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") t@@ -1669,7 +1679,7 @@ class Peer(Logger): for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items(): if not chan.hm.is_add_htlc_irrevocably_committed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): continue - chan.logger.info(f'found unfulfilled htlc: {htlc_id}') + #chan.logger.info(f'found unfulfilled htlc: {htlc_id}') htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id) payment_hash = htlc.payment_hash error_reason = None # type: Optional[OnionRoutingFailureMessage] t@@ -1694,7 +1704,6 @@ class Peer(Logger): error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_VERSION, data=sha256(onion_packet_bytes)) if self.network.config.get('test_fail_htlcs_with_temp_node_failure'): error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') - if not error_reason: if processed_onion.are_we_final: preimage, error_reason = self.maybe_fulfill_htlc( DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py t@@ -86,6 +86,7 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID] # status that are persisted NUM_PEERS_TARGET = 4 +MPP_EXPIRY = 120 FALLBACK_NODE_LIST_TESTNET = ( t@@ -164,7 +165,8 @@ BASE_FEATURES = LnFeatures(0)\ LNWALLET_FEATURES = BASE_FEATURES\ | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ\ | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ\ - | LnFeatures.GOSSIP_QUERIES_REQ + | LnFeatures.GOSSIP_QUERIES_REQ\ + | LnFeatures.BASIC_MPP_OPT LNGOSSIP_FEATURES = BASE_FEATURES\ | LnFeatures.GOSSIP_QUERIES_OPT\ t@@ -581,6 +583,7 @@ class LNWallet(LNWorker): self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]] + self.pending_htlcs = defaultdict(set) # type: Dict[bytes, set] self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) # detect inflight payments t@@ -1284,6 +1287,24 @@ class LNWallet(LNWorker): self.payments[key] = info.amount_msat, info.direction, info.status self.wallet.save_db() + def htlc_received(self, short_channel_id, htlc, expected_msat): + status = self.get_payment_status(htlc.payment_hash) + if status == PR_PAID: + return True, None + s = self.pending_htlcs[htlc.payment_hash] + if (short_channel_id, htlc) not in s: + s.add((short_channel_id, htlc)) + total = sum([htlc.amount_msat for scid, htlc in s]) + first_timestamp = min([htlc.timestamp for scid, htlc in s]) + expired = time.time() - first_timestamp > MPP_EXPIRY + if total >= expected_msat and not expired: + # status must be persisted + self.payment_received(htlc.payment_hash) + return True, None + if expired: + return None, True + return None, None + def get_payment_status(self, payment_hash): info = self.get_payment_info(payment_hash) return info.status if info else PR_UNPAID t@@ -1359,10 +1380,10 @@ class LNWallet(LNWorker): util.trigger_callback('payment_succeeded', self.wallet, key) util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id) - def payment_received(self, chan, payment_hash: bytes): + def payment_received(self, payment_hash: bytes): self.set_payment_status(payment_hash, PR_PAID) util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID) - util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id) + #util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id) async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]): """calculate routing hints (BOLT-11 'r' field)""" DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py t@@ -132,6 +132,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): # used in tests self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() + self.pending_htlcs = defaultdict(set) def get_invoice_status(self, key): pass t@@ -167,6 +168,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): set_invoice_status = LNWallet.set_invoice_status set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status + htlc_received = LNWallet.htlc_received await_payment = LNWallet.await_payment payment_received = LNWallet.payment_received payment_sent = LNWallet.payment_sent