ttest_lnpeer.py - electrum - Electrum Bitcoin wallet
HTML git clone https://git.parazyd.org/electrum
DIR Log
DIR Files
DIR Refs
DIR Submodules
---
ttest_lnpeer.py (46112B)
---
1 import asyncio
2 import tempfile
3 from decimal import Decimal
4 import os
5 from contextlib import contextmanager
6 from collections import defaultdict
7 import logging
8 import concurrent
9 from concurrent import futures
10 import unittest
11 from typing import Iterable, NamedTuple, Tuple, List
12
13 from aiorpcx import TaskGroup, timeout_after, TaskTimeout
14
15 from electrum import bitcoin
16 from electrum import constants
17 from electrum.network import Network
18 from electrum.ecc import ECPrivkey
19 from electrum import simple_config, lnutil
20 from electrum.lnaddr import lnencode, LnAddr, lndecode
21 from electrum.bitcoin import COIN, sha256
22 from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh
23 from electrum.lnpeer import Peer, UpfrontShutdownScriptViolation
24 from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
25 from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
26 from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
27 from electrum.lnchannel import ChannelState, PeerState, Channel
28 from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
29 from electrum.channel_db import ChannelDB
30 from electrum.lnworker import LNWallet, NoPathFound
31 from electrum.lnmsg import encode_msg, decode_msg
32 from electrum.logging import console_stderr_handler, Logger
33 from electrum.lnworker import PaymentInfo, RECEIVED
34 from electrum.lnonion import OnionFailureCode
35 from electrum.lnutil import ChannelBlackList, derive_payment_secret_from_payment_preimage
36 from electrum.lnutil import LOCAL, REMOTE
37 from electrum.invoices import PR_PAID, PR_UNPAID
38
39 from .test_lnchannel import create_test_channels
40 from .test_bitcoin import needs_test_with_all_chacha20_implementations
41 from . import ElectrumTestCase
42
def keypair():
    """Generate a brand-new random node Keypair (public and private key)."""
    secret = ECPrivkey.generate_random_key().get_secret_bytes()
    return Keypair(pubkey=privkey_to_pubkey(secret), privkey=secret)
49
@contextmanager
def noop_lock():
    """Lock stand-in: entering and exiting the context does nothing."""
    yield None
53
class MockNetwork:
    """Minimal in-process stand-in for electrum's Network object.

    Each instance owns a SimpleConfig rooted in a freshly created temporary
    user dir, a ChannelDB flagged as loaded, an LNPathFinder over that DB,
    and a broadcast hook that pushes transactions onto ``tx_queue``.
    """

    def __init__(self, tx_queue):
        self.callbacks = defaultdict(list)
        self.lnwatcher = None
        self.interface = None
        # NOTE(review): mkdtemp() directories are never removed by this class.
        config_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
        self.config = simple_config.SimpleConfig({}, read_user_dir_function=lambda: config_dir)
        self.asyncio_loop = asyncio.get_event_loop()
        # ChannelDB receives this object itself (presumably reading config/loop
        # from it), so it is built only after the attributes above exist.
        self.channel_db = ChannelDB(self)
        self.channel_db.data_loaded.set()
        self.path_finder = LNPathFinder(self.channel_db)
        self.tx_queue = tx_queue
        self._blockchain = MockBlockchain()
        self.channel_blacklist = ChannelBlackList()

    @property
    def callback_lock(self):
        """A do-nothing context manager in place of a real callback lock."""
        return noop_lock()

    def get_local_height(self):
        return 0

    def blockchain(self):
        return self._blockchain

    async def broadcast_transaction(self, tx):
        """Put ``tx`` on the test queue; do nothing when no queue was supplied."""
        if not self.tx_queue:
            return
        await self.tx_queue.put(tx)

    async def try_broadcasting(self, tx, name):
        # `name` is accepted for signature compatibility and ignored here.
        await self.broadcast_transaction(tx)
86
87
class MockBlockchain:
    """Fake blockchain view: fixed height of 0, tip never reported as stale."""

    _FIXED_HEIGHT = 0
    _TIP_STALE = False

    def height(self):
        return self._FIXED_HEIGHT

    def is_tip_stale(self):
        return self._TIP_STALE
95
96
class MockWallet:
    """Inert wallet stub: labelling/persistence are no-ops and every address counts as ours."""

    def set_label(self, x, y):
        return None

    def save_db(self):
        return None

    def add_transaction(self, tx):
        return None

    def is_lightning_backup(self):
        return False

    def is_mine(self, addr):
        # every address is treated as belonging to this wallet
        return True
113
114
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
    """Lightweight stand-in for LNWallet used to drive Peer objects in tests.

    ``__init__`` creates only the state the tests need. The real payment/HTLC
    logic comes from ``LNWallet``: its functions are assigned as class
    attributes at the bottom of this class, so they run as methods of this mock.
    """
    MPP_EXPIRY = 2  # HTLC timestamps are cast to int, so this cannot be 1
    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0

    def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
        """
        :param local_keypair: this node's keypair.
        :param chans: channels owned by this node. Any iterable is accepted,
            including one-shot iterators such as generators.
        :param tx_queue: handed to MockNetwork, which puts broadcast transactions on it.
        :param name: value returned by diagnostic_name().
        """
        # `chans` is consumed twice below (channel map + lnworker back-link).
        # Materialize it once: a generator would otherwise be exhausted by the
        # first pass and `chan.lnworker` would silently never be set.
        chans = list(chans)
        self.name = name  # set first; diagnostic_name() returns it
        Logger.__init__(self)
        NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
        self.node_keypair = local_keypair
        self.network = MockNetwork(tx_queue)
        self.taskgroup = TaskGroup()
        self.lnwatcher = None
        self.listen_server = None
        self._channels = {chan.channel_id: chan for chan in chans}
        self.payments = {}
        self.logs = defaultdict(list)
        self.wallet = MockWallet()
        self.features = LnFeatures(0)
        self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
        self.features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
        self.features |= LnFeatures.VAR_ONION_OPT
        self.features |= LnFeatures.PAYMENT_SECRET_OPT
        self.features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT
        self.pending_payments = defaultdict(asyncio.Future)
        for chan in chans:
            chan.lnworker = self
        self._peers = {}  # bytes -> Peer
        # used in tests: both events start set; a test may clear one to hold HTLCs back
        self.enable_htlc_settle = asyncio.Event()
        self.enable_htlc_settle.set()
        self.enable_htlc_forwarding = asyncio.Event()
        self.enable_htlc_forwarding.set()
        self.received_mpp_htlcs = dict()
        self.sent_htlcs = defaultdict(asyncio.Queue)
        self.sent_htlcs_routes = dict()
        self.sent_buckets = defaultdict(set)
        self.trampoline_forwarding_failures = {}
        self.inflight_payments = set()
        self.preimages = {}
        self.stopping_soon = False

    def get_invoice_status(self, key):
        """Stub: always returns None."""
        pass

    @property
    def lock(self):
        """A do-nothing context manager in place of a real lock."""
        return noop_lock()

    @property
    def channel_db(self):
        """The network's ChannelDB, or None when there is no network."""
        return self.network.channel_db if self.network else None

    @property
    def channels(self):
        return self._channels

    @property
    def peers(self):
        return self._peers

    def get_channel_by_short_id(self, short_channel_id):
        """Return the channel with this short channel id, or None if there is none."""
        with self.lock:
            for chan in self._channels.values():
                if chan.short_channel_id == short_channel_id:
                    return chan

    def channel_state_changed(self, chan):
        pass

    def save_channel(self, chan):
        # persistence is deliberately skipped in tests
        print("Ignoring channel save")

    def diagnostic_name(self):
        return self.name

    async def stop(self):
        """Run LNWallet.stop, then stop the channel DB (if any) and wait until it has stopped."""
        await LNWallet.stop(self)
        if self.channel_db:
            self.channel_db.stop()
            await self.channel_db.stopped_event.wait()

    def get_first_timestamp(self):
        return 0

    # Real LNWallet implementations, bound as methods of this mock.
    get_payments = LNWallet.get_payments
    get_payment_info = LNWallet.get_payment_info
    save_payment_info = LNWallet.save_payment_info
    set_invoice_status = LNWallet.set_invoice_status
    set_request_status = LNWallet.set_request_status
    set_payment_status = LNWallet.set_payment_status
    get_payment_status = LNWallet.get_payment_status
    check_received_mpp_htlc = LNWallet.check_received_mpp_htlc
    htlc_fulfilled = LNWallet.htlc_fulfilled
    htlc_failed = LNWallet.htlc_failed
    save_preimage = LNWallet.save_preimage
    get_preimage = LNWallet.get_preimage
    create_route_for_payment = LNWallet.create_route_for_payment
    create_routes_for_payment = LNWallet.create_routes_for_payment
    create_routes_from_invoice = LNWallet.create_routes_from_invoice
    _check_invoice = staticmethod(LNWallet._check_invoice)
    pay_to_route = LNWallet.pay_to_route
    pay_to_node = LNWallet.pay_to_node
    pay_invoice = LNWallet.pay_invoice
    force_close_channel = LNWallet.force_close_channel
    try_force_closing = LNWallet.try_force_closing
    on_peer_successfully_established = LNWallet.on_peer_successfully_established
    get_channel_by_id = LNWallet.get_channel_by_id
    channels_for_peer = LNWallet.channels_for_peer
    _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
    handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
    is_trampoline_peer = LNWallet.is_trampoline_peer
    wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
    on_proxy_changed = LNWallet.on_proxy_changed
226
227
class MockTransport:
    """In-memory transport: whatever is put on ``queue`` becomes an inbound message."""

    def __init__(self, name):
        self._name = name
        self.queue = asyncio.Queue()

    def name(self):
        return self._name

    async def read_messages(self):
        """Yield queued messages indefinitely, waiting whenever the queue is empty."""
        while True:
            message = await self.queue.get()
            yield message
239
class NoFeaturesTransport(MockTransport):
    """Transport whose remote end replies to our ``init`` with an ``init`` advertising no features.

    Used for testing that DATA_LOSS_PROTECT is required.
    """
    def send_bytes(self, data):
        msg = decode_msg(data)
        print(msg)
        if msg[0] != 'init':
            return
        reply = encode_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")
        self.queue.put_nowait(reply)
250
class PutIntoOthersQueueTransport(MockTransport):
    """One end of a loopback pair: bytes sent here land on the partner transport's queue."""

    def __init__(self, keypair, name):
        super().__init__(name)
        self.other_mock_transport = None  # the partner; linked by transport_pair()
        self.privkey = keypair.privkey

    def send_bytes(self, data):
        partner_queue = self.other_mock_transport.queue
        partner_queue.put_nowait(data)
259
def transport_pair(k1, k2, name1, name2):
    """Create two cross-linked loopback transports and return them as a tuple.

    The first transport uses keypair ``k1`` but is labelled ``name2``; the second
    uses ``k2`` and is labelled ``name1``. Whatever one sends is queued on the other.
    """
    first = PutIntoOthersQueueTransport(k1, name2)
    second = PutIntoOthersQueueTransport(k2, name1)
    first.other_mock_transport, second.other_mock_transport = second, first
    return first, second
266
267
class SquareGraph(NamedTuple):
    """Wallets, peers and channels of a four-node square topology.

    ::

                     A
        high fee   /   \\   low fee
                  B     C
        high fee   \\   /   low fee
                     D

    ``peer_xy`` / ``chan_xy`` are node x's view of its link to node y.
    """
    w_a: MockLNWallet
    w_b: MockLNWallet
    w_c: MockLNWallet
    w_d: MockLNWallet
    peer_ab: Peer
    peer_ac: Peer
    peer_ba: Peer
    peer_bd: Peer
    peer_ca: Peer
    peer_cd: Peer
    peer_db: Peer
    peer_dc: Peer
    chan_ab: Channel
    chan_ac: Channel
    chan_ba: Channel
    chan_bd: Channel
    chan_ca: Channel
    chan_cd: Channel
    chan_db: Channel
    chan_dc: Channel

    def all_peers(self) -> Iterable[Peer]:
        """All eight Peer objects, grouped by owning node (A, B, C, D)."""
        return (
            self.peer_ab, self.peer_ac,
            self.peer_ba, self.peer_bd,
            self.peer_ca, self.peer_cd,
            self.peer_db, self.peer_dc,
        )

    def all_lnworkers(self) -> Iterable[MockLNWallet]:
        """The four wallets, in A, B, C, D order."""
        return (self.w_a, self.w_b, self.w_c, self.w_d)
300
301
class PaymentDone(Exception):
    """Raised from a test coroutine once the payment checks pass; tests expect it via assertRaises."""


class TestSuccess(Exception):
    """Generic success sentinel exception, analogous to PaymentDone."""
304
305
306 class TestPeer(ElectrumTestCase):
307
@classmethod
def setUpClass(cls):
    """Run the base-class setup, then make the console stderr handler emit DEBUG records."""
    super().setUpClass()
    debug_level = logging.DEBUG
    console_stderr_handler.setLevel(debug_level)
312
def setUp(self):
    """Start an event loop for this test and reset the list of created lnworkers."""
    super().setUp()
    loop, stop_future, loop_thread = create_and_start_event_loop()
    self.asyncio_loop = loop
    self._stop_loop = stop_future
    self._loop_thread = loop_thread
    # every MockLNWallet built by the prepare_* helpers is recorded here and stopped in tearDown
    self._lnworkers_created = []  # type: List[MockLNWallet]
317
def tearDown(self):
    """Stop all lnworkers created by the test, then stop the event loop and join its thread."""
    async def stop_all_lnworkers():
        async with TaskGroup() as tg:
            for worker in self._lnworkers_created:
                await tg.spawn(worker.stop())
        self._lnworkers_created.clear()
    # `run` is a module-level helper defined outside this view
    run(stop_all_lnworkers())

    self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
    self._loop_thread.join(timeout=1)
    super().tearDown()
329
def prepare_peers(self, alice_channel, bob_channel):
    """Connect Alice and Bob (one wallet and one Peer each) for a single channel pair.

    Returns ``(p1, p2, w1, w2, q1, q2)``: Alice's and Bob's Peer, their
    MockLNWallets, and the queues that receive each wallet's broadcast transactions.
    """
    key_alice, key_bob = keypair(), keypair()
    # each side's channel stores the node id of the *remote* party
    alice_channel.node_id = key_bob.pubkey
    bob_channel.node_id = key_alice.pubkey
    trans_alice, trans_bob = transport_pair(key_alice, key_bob, alice_channel.name, bob_channel.name)
    txq_alice, txq_bob = asyncio.Queue(), asyncio.Queue()
    # NOTE(review): each wallet is named after the *other* side's channel name; kept as-is
    w_alice = MockLNWallet(local_keypair=key_alice, chans=[alice_channel], tx_queue=txq_alice, name=bob_channel.name)
    w_bob = MockLNWallet(local_keypair=key_bob, chans=[bob_channel], tx_queue=txq_bob, name=alice_channel.name)
    self._lnworkers_created.extend([w_alice, w_bob])
    peer_alice = Peer(w_alice, key_bob.pubkey, trans_alice)
    peer_bob = Peer(w_bob, key_alice.pubkey, trans_bob)
    for wallet, peer in ((w_alice, peer_alice), (w_bob, peer_bob)):
        wallet._peers[peer.pubkey] = peer
    # mark_open won't work on a channel that is already OPEN, so rewind to FUNDED
    for chan in (alice_channel, bob_channel):
        chan._state = ChannelState.FUNDED
    # opening the channels populates the channel graph
    peer_alice.mark_open(alice_channel)
    peer_bob.mark_open(bob_channel)
    return peer_alice, peer_bob, w_alice, w_bob, txq_alice, txq_bob
351
def prepare_chans_and_peers_in_square(self) -> SquareGraph:
    """Build the A-B-D / A-C-D square with connected mock wallets and peers.

    Every node gets one Peer per neighbour over a loopback transport. Bob and
    Carol may forward payments. The forwarding fees of all channels touching Bob
    are multiplied by 500, so the route through Bob is the expensive one. All
    channels are opened via mark_open, which populates the channel graph.
    """
    key_a, key_b, key_c, key_d = [keypair() for _ in range(4)]
    chan_ab, chan_ba = create_test_channels(alice_name="alice", bob_name="bob", alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey)
    chan_ac, chan_ca = create_test_channels(alice_name="alice", bob_name="carol", alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey)
    chan_bd, chan_db = create_test_channels(alice_name="bob", bob_name="dave", alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey)
    chan_cd, chan_dc = create_test_channels(alice_name="carol", bob_name="dave", alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey)
    trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name)
    trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name)
    trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name)
    trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name)
    txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for _ in range(4)]
    w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a, name="alice")
    w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b, name="bob")
    w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c, name="carol")
    w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d, name="dave")
    self._lnworkers_created.extend([w_a, w_b, w_c, w_d])
    peer_ab = Peer(w_a, key_b.pubkey, trans_ab)
    peer_ac = Peer(w_a, key_c.pubkey, trans_ac)
    peer_ba = Peer(w_b, key_a.pubkey, trans_ba)
    peer_bd = Peer(w_b, key_d.pubkey, trans_bd)
    peer_ca = Peer(w_c, key_a.pubkey, trans_ca)
    peer_cd = Peer(w_c, key_d.pubkey, trans_cd)
    peer_db = Peer(w_d, key_b.pubkey, trans_db)
    peer_dc = Peer(w_d, key_c.pubkey, trans_dc)

    all_peers = (peer_ab, peer_ac, peer_ba, peer_bd, peer_ca, peer_cd, peer_db, peer_dc)
    all_chans = (chan_ab, chan_ac, chan_ba, chan_bd, chan_ca, chan_cd, chan_db, chan_dc)
    peer_owners = (w_a, w_a, w_b, w_b, w_c, w_c, w_d, w_d)
    for owner, peer in zip(peer_owners, all_peers):
        owner._peers[peer.pubkey] = peer

    for forwarder in (w_b, w_c):
        forwarder.network.config.set_key('lightning_forward_payments', True)

    # forwarding fees, etc: make every channel touching Bob 500x more expensive
    for chan in (chan_ab, chan_ba, chan_bd, chan_db):
        chan.forwarding_fee_proportional_millionths *= 500
        chan.forwarding_fee_base_msat *= 500

    # mark_open won't work on a channel that is already OPEN, so rewind to FUNDED
    for chan in all_chans:
        chan._state = ChannelState.FUNDED
    # opening the channels populates the channel graph
    for peer, chan in zip(all_peers, all_chans):
        peer.mark_open(chan)

    return SquareGraph(
        w_a=w_a, w_b=w_b, w_c=w_c, w_d=w_d,
        peer_ab=peer_ab, peer_ac=peer_ac, peer_ba=peer_ba, peer_bd=peer_bd,
        peer_ca=peer_ca, peer_cd=peer_cd, peer_db=peer_db, peer_dc=peer_dc,
        chan_ab=chan_ab, chan_ac=chan_ac, chan_ba=chan_ba, chan_bd=chan_bd,
        chan_ca=chan_ca, chan_cd=chan_cd, chan_db=chan_db, chan_dc=chan_dc,
    )
433
@staticmethod
async def prepare_invoice(
        w2: MockLNWallet,  # receiver
        *,
        amount_msat=100_000_000,
        include_routing_hints=False,
) -> Tuple[LnAddr, str]:
    """Register a new incoming payment on receiver ``w2`` and return a signed invoice for it.

    A random preimage and an unpaid PaymentInfo are stored on ``w2``. Returns
    ``(lnaddr, invoice)``: the invoice decoded again (so, unlike the LnAddr
    that was encoded, it has the pubkey set) and the encoded invoice string.
    """
    preimage = os.urandom(32)
    payment_hash = sha256(preimage)
    w2.save_preimage(payment_hash, preimage)
    w2.save_payment_info(PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID))

    routing_hints = await w2._calc_routing_hints_for_invoice(amount_msat) if include_routing_hints else []
    trampoline_hints = []
    for hint in routing_hints:
        hops = hint[1]
        node_id, _scid, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = hops[0]
        # single-hop hints through a trampoline peer are also advertised as 't' hints
        if len(hops) == 1 and w2.is_trampoline_peer(node_id):
            trampoline_hints.append(('t', (node_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta)))

    invoice_features = w2.features.for_invoice()
    if invoice_features.supports(LnFeatures.PAYMENT_SECRET_OPT):
        payment_secret = derive_payment_secret_from_payment_preimage(preimage)
    else:
        payment_secret = None

    tags = [
        ('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
        ('d', 'coffee'),
        ('9', invoice_features),
    ]
    tags += routing_hints
    tags += trampoline_hints
    addr_to_encode = LnAddr(
        paymenthash=payment_hash,
        amount=amount_msat / Decimal(COIN * 1000),
        tags=tags,
        payment_secret=payment_secret,
    )
    invoice = lnencode(addr_to_encode, w2.node_keypair.privkey)
    return lndecode(invoice), invoice
473
def test_reestablish(self):
    """Both peers reestablish a channel marked DISCONNECTED and end up in PeerState.GOOD.

    The reestablish coroutine cancels the surrounding gather once its assertions
    pass. That is how the otherwise endless loops are stopped, so the test expects
    CancelledError.
    """
    alice_channel, bob_channel = create_test_channels()
    p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
    for chan in (alice_channel, bob_channel):
        chan.peer_state = PeerState.DISCONNECTED
    async def reestablish():
        await asyncio.gather(
            p1.reestablish_channel(alice_channel),
            p2.reestablish_channel(bob_channel))
        self.assertEqual(alice_channel.peer_state, PeerState.GOOD)
        self.assertEqual(bob_channel.peer_state, PeerState.GOOD)
        gath.cancel()
    # one message loop and one HTLC switch per peer (p1's switch must not be started twice)
    gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
    async def f():
        await gath
    with self.assertRaises(concurrent.futures.CancelledError):
        run(f())
491
@needs_test_with_all_chacha20_implementations
def test_reestablish_with_old_state(self):
    """Alice reconnects with channel state from before a payment.

    Two identical channel pairs are created from the same seed. Alice pays Bob
    over the first pair. Then Alice's untouched copy is reestablished against
    Bob's up-to-date channel: Alice's side must become PeerState.BAD and Bob's
    channel must be FORCE_CLOSING.
    """
    seed = os.urandom(32)
    alice_channel, bob_channel = create_test_channels(random_seed=seed)
    alice_channel_0, bob_channel_0 = create_test_channels(random_seed=seed)  # identical copies
    p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
    lnaddr, pay_req = run(self.prepare_invoice(w2))

    async def pay():
        result, log = await w1.pay_invoice(pay_req)
        self.assertEqual(result, True)
        payment_gather.cancel()
    payment_gather = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
    async def await_payment():
        await payment_gather
    with self.assertRaises(concurrent.futures.CancelledError):
        run(await_payment())

    # Alice comes back with her pre-payment state, Bob with his current one.
    old_p1, old_p2, _w1, _w2, _q1b, _q2b = self.prepare_peers(alice_channel_0, bob_channel)
    for chan in (alice_channel_0, bob_channel):
        chan.peer_state = PeerState.DISCONNECTED
    async def reestablish():
        await asyncio.gather(
            old_p1.reestablish_channel(alice_channel_0),
            old_p2.reestablish_channel(bob_channel))
        self.assertEqual(alice_channel_0.peer_state, PeerState.BAD)
        self.assertEqual(bob_channel._state, ChannelState.FORCE_CLOSING)
        reestablish_gather.cancel()
    reestablish_gather = asyncio.gather(reestablish(), old_p1._message_loop(), old_p2._message_loop(), old_p1.htlc_switch(), old_p2.htlc_switch())
    async def await_reestablish():
        await reestablish_gather
    with self.assertRaises(concurrent.futures.CancelledError):
        run(await_reestablish())
526
@needs_test_with_all_chacha20_implementations
def test_payment(self):
    """Alice pays Bob one HTLC over their direct channel; Bob's invoice flips from unpaid to paid."""
    alice_channel, bob_channel = create_test_channels()
    p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
    async def pay(lnaddr, pay_req):
        payment_hash = lnaddr.paymenthash
        self.assertEqual(PR_UNPAID, w2.get_payment_status(payment_hash))
        result, log = await w1.pay_invoice(pay_req)
        self.assertTrue(result)
        self.assertEqual(PR_PAID, w2.get_payment_status(payment_hash))
        raise PaymentDone()
    async def f():
        async with TaskGroup() as group:
            for peer in (p1, p2):
                await group.spawn(peer._message_loop())
                await group.spawn(peer.htlc_switch())
            await asyncio.sleep(0.01)
            lnaddr, pay_req = await self.prepare_invoice(w2)
            # a default invoice must not advertise multi-part payments
            self.assertFalse(lnaddr.get_features().supports(LnFeatures.BASIC_MPP_OPT))
            await group.spawn(pay(lnaddr, pay_req))
    with self.assertRaises(PaymentDone):
        run(f())
551
@needs_test_with_all_chacha20_implementations
def test_payment_race(self):
    """Alice and Bob pay each other at the same time.

    Each side sends ``update_add_htlc`` and receives the other's before either
    sends ``commitment_signed``. Neither may fulfill its incoming HTLC before it
    is irrevocably committed, and both payments must still succeed.
    """
    alice_channel, bob_channel = create_test_channels()
    p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)

    async def add_htlc_without_commitment(peer, wallet, lnaddr):
        """Have `wallet` send an HTLC paying `lnaddr` while `peer` withholds commitment_signed."""
        real_maybe_send_commitment = peer.maybe_send_commitment
        peer.maybe_send_commitment = lambda x: None
        amount_msat = lnaddr.get_amount_msat()
        route = wallet.create_routes_from_invoice(amount_msat, decoded_invoice=lnaddr)[0][0]
        await wallet.pay_to_route(
            route=route,
            amount_msat=amount_msat,
            total_msat=amount_msat,
            amount_receiver_msat=amount_msat,
            payment_hash=lnaddr.paymenthash,
            min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
            payment_secret=lnaddr.payment_secret,
        )
        peer.maybe_send_commitment = real_maybe_send_commitment

    async def pay():
        await asyncio.wait_for(p1.initialized, 1)
        await asyncio.wait_for(p2.initialized, 1)
        lnaddr2, _pay_req2 = await self.prepare_invoice(w2)
        lnaddr1, _pay_req1 = await self.prepare_invoice(w1)
        # touch the defaultdicts now so the HTLC-log queues exist before any HTLC is sent
        q1 = w1.sent_htlcs[lnaddr2.paymenthash]
        q2 = w2.sent_htlcs[lnaddr1.paymenthash]
        # Alice, then Bob, add an HTLC but do NOT send commitment_signed
        await add_htlc_without_commitment(p1, w1, lnaddr2)
        await add_htlc_without_commitment(p2, w2, lnaddr1)
        # give both sides time to receive the messages sent so far
        await asyncio.sleep(0.2)
        # only now does each side send commitment_signed
        p1.maybe_send_commitment(alice_channel)
        p2.maybe_send_commitment(bob_channel)

        htlc_log1 = await q1.get()
        assert htlc_log1.success
        htlc_log2 = await q2.get()
        assert htlc_log2.success
        raise PaymentDone()

    async def f():
        async with TaskGroup() as group:
            for peer in (p1, p2):
                await group.spawn(peer._message_loop())
                await group.spawn(peer.htlc_switch())
            await asyncio.sleep(0.01)
            await group.spawn(pay())
    with self.assertRaises(PaymentDone):
        run(f())
622
623 #@unittest.skip("too expensive")
624 #@needs_test_with_all_chacha20_implementations
def test_payments_stresstest(self):
    """Alice pays Bob many invoices concurrently; final balances must shift by exactly the total.

    50 invoices of 10_000_000 msat are paid with at most 5 payments in flight
    at once, checked from both Alice's and Bob's view of the channel.
    """
    alice_channel, bob_channel = create_test_channels()
    p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
    alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL)
    bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL)
    num_payments = 50
    payment_value_msat = 10_000_000  # large enough that the HTLCs actually appear on the ctx
    max_htlcs_in_flight = asyncio.Semaphore(5)

    async def single_payment(pay_req):
        async with max_htlcs_in_flight:
            await w1.pay_invoice(pay_req)

    async def many_payments():
        invoice_tasks = []
        async with TaskGroup() as group:
            for _ in range(num_payments):
                invoice_tasks.append(await group.spawn(self.prepare_invoice(w2, amount_msat=payment_value_msat)))
        async with TaskGroup() as group:
            for task in invoice_tasks:
                _lnaddr, pay_req = task.result()
                await group.spawn(single_payment(pay_req))
        gathered.cancel()

    gathered = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
    async def f():
        await gathered
    with self.assertRaises(concurrent.futures.CancelledError):
        run(f())

    total_msat = num_payments * payment_value_msat
    self.assertEqual(alice_init_balance_msat - total_msat, alice_channel.balance(HTLCOwner.LOCAL))
    self.assertEqual(alice_init_balance_msat - total_msat, bob_channel.balance(HTLCOwner.REMOTE))
    self.assertEqual(bob_init_balance_msat + total_msat, bob_channel.balance(HTLCOwner.LOCAL))
    self.assertEqual(bob_init_balance_msat + total_msat, alice_channel.balance(HTLCOwner.REMOTE))
654
@needs_test_with_all_chacha20_implementations
def test_payment_multihop(self):
    """Alice pays Dave across the square graph using an invoice with routing hints."""
    graph = self.prepare_chans_and_peers_in_square()
    peers = graph.all_peers()
    payer, payee = graph.w_a, graph.w_d
    async def pay(lnaddr, pay_req):
        self.assertEqual(PR_UNPAID, payee.get_payment_status(lnaddr.paymenthash))
        result, log = await payer.pay_invoice(pay_req)
        self.assertTrue(result)
        self.assertEqual(PR_PAID, payee.get_payment_status(lnaddr.paymenthash))
        raise PaymentDone()
    async def f():
        async with TaskGroup() as group:
            for peer in peers:
                await group.spawn(peer._message_loop())
                await group.spawn(peer.htlc_switch())
            await asyncio.sleep(0.2)
            lnaddr, pay_req = await self.prepare_invoice(payee, include_routing_hints=True)
            await group.spawn(pay(lnaddr, pay_req))
    with self.assertRaises(PaymentDone):
        run(f())
675
@needs_test_with_all_chacha20_implementations
def test_payment_multihop_with_preselected_path(self):
    """Paying with an explicit ``full_path``: inconsistent paths are rejected, a valid one is used verbatim."""
    graph = self.prepare_chans_and_peers_in_square()
    peers = graph.all_peers()
    async def pay(pay_req):
        pk_a = graph.w_a.node_keypair.pubkey
        pk_b = graph.w_b.node_keypair.pubkey
        pk_c = graph.w_c.node_keypair.pubkey
        pk_d = graph.w_d.node_keypair.pubkey
        scid_ab = graph.chan_ab.short_channel_id
        scid_bd = graph.chan_bd.short_channel_id
        with self.subTest(msg="bad path: edges do not chain together"):
            path = [PathEdge(start_node=pk_a, end_node=pk_c, short_channel_id=scid_ab),
                    PathEdge(start_node=pk_b, end_node=pk_d, short_channel_id=scid_bd)]
            with self.assertRaises(LNPathInconsistent):
                await graph.w_a.pay_invoice(pay_req, full_path=path)
        with self.subTest(msg="bad path: last node id differs from invoice pubkey"):
            path = [PathEdge(start_node=pk_a, end_node=pk_b, short_channel_id=scid_ab)]
            with self.assertRaises(LNPathInconsistent):
                await graph.w_a.pay_invoice(pay_req, full_path=path)
        with self.subTest(msg="good path"):
            path = [PathEdge(start_node=pk_a, end_node=pk_b, short_channel_id=scid_ab),
                    PathEdge(start_node=pk_b, end_node=pk_d, short_channel_id=scid_bd)]
            result, log = await graph.w_a.pay_invoice(pay_req, full_path=path)
            self.assertTrue(result)
            expected_scids = [edge.short_channel_id for edge in path]
            used_scids = [edge.short_channel_id for edge in log[0].route]
            self.assertEqual(expected_scids, used_scids)
        raise PaymentDone()
    async def f():
        async with TaskGroup() as group:
            for peer in peers:
                await group.spawn(peer._message_loop())
                await group.spawn(peer.htlc_switch())
            await asyncio.sleep(0.2)
            _lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
            await group.spawn(pay(pay_req))
    with self.assertRaises(PaymentDone):
        run(f())
719
@needs_test_with_all_chacha20_implementations
def test_payment_multihop_temp_node_failure(self):
    """Both intermediaries fail HTLCs with TEMPORARY_NODE_FAILURE, so Alice's payment to Dave fails."""
    graph = self.prepare_chans_and_peers_in_square()
    for intermediary in (graph.w_b, graph.w_c):
        intermediary.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
    peers = graph.all_peers()
    async def pay(lnaddr, pay_req):
        payment_hash = lnaddr.paymenthash
        self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(payment_hash))
        result, log = await graph.w_a.pay_invoice(pay_req)
        self.assertFalse(result)
        self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(payment_hash))
        self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
        raise PaymentDone()
    async def f():
        async with TaskGroup() as group:
            for peer in peers:
                await group.spawn(peer._message_loop())
                await group.spawn(peer.htlc_switch())
            await asyncio.sleep(0.2)
            lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
            await group.spawn(pay(lnaddr, pay_req))
    with self.assertRaises(PaymentDone):
        run(f())
743
@needs_test_with_all_chacha20_implementations
def test_payment_multihop_route_around_failure(self):
    """Alice pays Dave, routing around a failing Carol.

    The cheaper A->C->D route is tried first. Carol fails the HTLC with
    TEMPORARY_NODE_FAILURE and gets blacklisted, so the second attempt goes
    A->B->D and succeeds.
    """
    graph = self.prepare_chans_and_peers_in_square()
    graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
    peers = graph.all_peers()

    def scids(route):
        return [edge.short_channel_id for edge in route]

    async def pay(lnaddr, pay_req):
        self.assertEqual(500000000000, graph.chan_ab.balance(LOCAL))
        self.assertEqual(500000000000, graph.chan_db.balance(LOCAL))
        self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
        result, log = await graph.w_a.pay_invoice(pay_req, attempts=2)
        self.assertEqual(2, len(log))
        self.assertTrue(result)
        self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
        first_attempt, second_attempt = log[0], log[1]
        self.assertEqual([graph.chan_ac.short_channel_id, graph.chan_cd.short_channel_id],
                         scids(first_attempt.route))
        self.assertEqual([graph.chan_ab.short_channel_id, graph.chan_bd.short_channel_id],
                         scids(second_attempt.route))
        self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, first_attempt.failure_msg.code)
        self.assertEqual(499899450000, graph.chan_ab.balance(LOCAL))
        await asyncio.sleep(0.2)  # wait for COMMITMENT_SIGNED / REVACK msgs to update balance
        self.assertEqual(500100000000, graph.chan_db.balance(LOCAL))
        raise PaymentDone()

    async def f():
        async with TaskGroup() as group:
            for peer in peers:
                await group.spawn(peer._message_loop())
                await group.spawn(peer.htlc_switch())
            await asyncio.sleep(0.2)
            lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
            self.assertFalse(lnaddr.get_features().supports(LnFeatures.BASIC_MPP_OPT))
            await group.spawn(pay(lnaddr, pay_req))
    with self.assertRaises(PaymentDone):
        run(f())
780
def _run_mpp(self, graph, kwargs1, kwargs2):
    """Run a 600_000_000_000 msat payment from Alice to Dave twice on the same graph.

    Alice starts with 500_000_000_000 msat in each of her two channels, so one
    channel alone cannot carry the amount. The run configured by ``kwargs1`` must
    end in NoPathFound, the run configured by ``kwargs2`` in PaymentDone.
    Accepted keys: attempts, alice_uses_trampoline, bob_forwarding, mpp_invoice.
    """
    self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
    self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
    amount_to_pay = 600_000_000_000
    peers = graph.all_peers()
    alice, bob, dave = graph.w_a, graph.w_b, graph.w_d

    async def pay(attempts=1,
                  alice_uses_trampoline=False,
                  bob_forwarding=True,
                  mpp_invoice=True):
        if mpp_invoice:
            dave.features |= LnFeatures.BASIC_MPP_OPT
        if not bob_forwarding:
            bob.enable_htlc_forwarding.clear()
        if alice_uses_trampoline:
            # without a channel DB Alice cannot path-find herself
            channel_db = alice.network.channel_db
            if channel_db:
                channel_db.stop()
                await channel_db.stopped_event.wait()
                alice.network.channel_db = None
        else:
            assert alice.network.channel_db is not None
        lnaddr, pay_req = await self.prepare_invoice(dave, include_routing_hints=True, amount_msat=amount_to_pay)
        self.assertEqual(PR_UNPAID, dave.get_payment_status(lnaddr.paymenthash))
        result, log = await alice.pay_invoice(pay_req, attempts=attempts)
        if not bob_forwarding:
            # restore forwarding, then sleep 2s so that the second htlc can time out
            bob.enable_htlc_forwarding.set()
            await asyncio.sleep(2)
        if not result:
            raise NoPathFound()
        self.assertEqual(PR_PAID, dave.get_payment_status(lnaddr.paymenthash))
        raise PaymentDone()

    async def f(kwargs):
        async with TaskGroup() as group:
            for peer in peers:
                await group.spawn(peer._message_loop())
                await group.spawn(peer.htlc_switch())
            await asyncio.sleep(0.2)
            await group.spawn(pay(**kwargs))

    with self.assertRaises(NoPathFound):
        run(f(kwargs1))
    with self.assertRaises(PaymentDone):
        run(f(kwargs2))
826
@needs_test_with_all_chacha20_implementations
def test_multipart_payment_with_timeout(self):
    # first attempt: Bob withholds forwarding (expected NoPathFound);
    # second attempt: Bob forwards normally (expected PaymentDone)
    square = self.prepare_chans_and_peers_in_square()
    failing_attempt = dict(bob_forwarding=False)
    succeeding_attempt = dict(bob_forwarding=True)
    self._run_mpp(square, failing_attempt, succeeding_attempt)
831
@needs_test_with_all_chacha20_implementations
def test_multipart_payment(self):
    # first attempt: invoice without MPP support (expected NoPathFound);
    # second attempt: invoice advertises MPP (expected PaymentDone)
    square = self.prepare_chans_and_peers_in_square()
    failing_attempt = dict(mpp_invoice=False)
    succeeding_attempt = dict(mpp_invoice=True)
    self._run_mpp(square, failing_attempt, succeeding_attempt)
836
@needs_test_with_all_chacha20_implementations
def test_multipart_payment_with_trampoline(self):
    # single attempt will fail with insufficient trampoline fee
    square = self.prepare_chans_and_peers_in_square()
    failing_attempt = dict(alice_uses_trampoline=True, attempts=1)
    succeeding_attempt = dict(alice_uses_trampoline=True, attempts=3)
    self._run_mpp(square, failing_attempt, succeeding_attempt)
842
@needs_test_with_all_chacha20_implementations
def test_fail_pending_htlcs_on_shutdown(self):
    """Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
    Dave shuts down (stops wallet).
    Checks that Dave fails the pending HTLCs during shutdown.
    """
    graph = self.prepare_chans_and_peers_in_square()
    self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
    self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
    amount_to_pay = 600_000_000_000
    peers = graph.all_peers()
    graph.w_d.MPP_EXPIRY = 120
    graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3

    async def pay():
        graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
        graph.w_b.enable_htlc_forwarding.clear()  # Bob will hold forwarded HTLCs
        assert graph.w_a.network.channel_db is not None
        lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
        try:
            # the payment cannot complete while Bob holds HTLCs, so it is
            # expected to hit this timeout
            async with timeout_after(0.5):
                result, _log = await graph.w_a.pay_invoice(pay_req, attempts=1)
        except TaskTimeout:
            # by now Dave hopefully received some HTLCs;
            # assertGreater reports the actual count on failure
            self.assertGreater(len(graph.chan_dc.hm.htlcs(LOCAL)), 0)
            self.assertGreater(len(graph.chan_dc.hm.htlcs(REMOTE)), 0)
        else:
            self.fail(f"pay_invoice finished but was not supposed to. result={result}")
        await graph.w_d.stop()
        # Dave is supposed to have failed the pending incomplete MPP HTLCs
        self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL)))
        self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE)))
        raise TestSuccess()

    async def f():
        async with TaskGroup() as group:
            for peer in peers:
                await group.spawn(peer._message_loop())
                await group.spawn(peer.htlc_switch())
            await asyncio.sleep(0.2)
            await group.spawn(pay())

    with self.assertRaises(TestSuccess):
        run(f())
886
@needs_test_with_all_chacha20_implementations
def test_close(self):
    """Alice sends Bob an HTLC, then calls ``close_channel`` right away.

    Bob's HTLC settling starts out disabled and is re-enabled after 0.1s.
    Once ``close_channel`` returns, the gathered peer tasks are cancelled.
    The test therefore expects ``CancelledError``.
    """
    alice_channel, bob_channel = create_test_channels()
    p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
    w1.network.config.set_key('dynamic_fees', False)
    w2.network.config.set_key('dynamic_fees', False)
    w1.network.config.set_key('fee_per_kb', 5000)
    w2.network.config.set_key('fee_per_kb', 1000)
    w2.enable_htlc_settle.clear()
    lnaddr, pay_req = run(self.prepare_invoice(w2))

    async def send_htlc_then_close():
        for peer in (p1, p2):
            await asyncio.wait_for(peer.initialized, 1)
        invoice_msat = lnaddr.get_amount_msat()
        # alice sends the htlc along the first route found for the invoice
        found_routes = w1.create_routes_from_invoice(invoice_msat, decoded_invoice=lnaddr)
        route, _ = found_routes[0][0:2]
        p1.pay(route=route,
               chan=alice_channel,
               amount_msat=invoice_msat,
               total_msat=invoice_msat,
               payment_hash=lnaddr.paymenthash,
               min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
               payment_secret=lnaddr.payment_secret)
        # alice closes, then tears down every gathered task
        await p1.close_channel(alice_channel.channel_id)
        everything.cancel()

    async def allow_settle_later():
        await asyncio.sleep(0.1)
        w2.enable_htlc_settle.set()

    everything = asyncio.gather(
        send_htlc_then_close(),
        allow_settle_later(),
        p1._message_loop(),
        p2._message_loop(),
        p1.htlc_switch(),
        p2.htlc_switch(),
    )

    async def wait_for_everything():
        await everything

    with self.assertRaises(concurrent.futures.CancelledError):
        run(wait_for_everything())
920
@needs_test_with_all_chacha20_implementations
def test_close_upfront_shutdown_script(self):
    """Exercise closing a channel where Bob committed to an upfront shutdown script.

    Phase 1: Alice's view of Bob's config holds ``bob_uss``, but Bob's own
    config has no upfront script. The close is expected to raise
    ``UpfrontShutdownScriptViolation``.

    Phase 2: both sides hold ``bob_uss``. The close completes and the
    gathered tasks are cancelled, so ``CancelledError`` is expected.
    """
    alice_channel, bob_channel = create_test_channels()

    # create upfront shutdown script for bob, alice doesn't use upfront
    # shutdown script
    bob_uss_pub = lnutil.privkey_to_pubkey(os.urandom(32))
    bob_uss_addr = bitcoin.pubkey_to_address('p2wpkh', bh2u(bob_uss_pub))
    bob_uss = bfh(bitcoin.address_to_script(bob_uss_addr))

    def prepare_close():
        """Build fresh peers for the current channel configs and return the
        (not yet started) coroutine in which Alice closes the channel while
        both peers' message loops and htlc switches run."""
        p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
        w1.network.config.set_key('dynamic_fees', False)
        w2.network.config.set_key('dynamic_fees', False)
        w1.network.config.set_key('fee_per_kb', 5000)
        w2.network.config.set_key('fee_per_kb', 1000)

        async def test():
            async def close():
                await asyncio.wait_for(p1.initialized, 1)
                await asyncio.wait_for(p2.initialized, 1)
                # alice (p1) initiates the close; which script bob uses in
                # the negotiation is controlled by bob_channel's config
                await p1.close_channel(alice_channel.channel_id)
                # close finished: stop the peer loops gathered below
                gath.cancel()

            async def main_loop(peer):
                async with peer.taskgroup as group:
                    await group.spawn(peer._message_loop())
                    await group.spawn(peer.htlc_switch())

            coros = [close(), main_loop(p1), main_loop(p2)]
            gath = asyncio.gather(*coros)
            await gath

        return test()

    # bob commits to close to bob_uss
    alice_channel.config[HTLCOwner.REMOTE].upfront_shutdown_script = bob_uss
    # but bob closes to some receiving address, which we achieve by not
    # setting the upfront shutdown script in the channel config
    bob_channel.config[HTLCOwner.LOCAL].upfront_shutdown_script = b''
    close_coro = prepare_close()
    with self.assertRaises(UpfrontShutdownScriptViolation):
        run(close_coro)

    # bob sends the same upfront_shutdown_script as he announced
    alice_channel.config[HTLCOwner.REMOTE].upfront_shutdown_script = bob_uss
    bob_channel.config[HTLCOwner.LOCAL].upfront_shutdown_script = bob_uss
    close_coro = prepare_close()
    with self.assertRaises(concurrent.futures.CancelledError):
        run(close_coro)
990
def test_channel_usage_after_closing(self):
    """After a force close, the channel must not be usable for payments.

    Route finding must raise ``NoPathFound``.  Paying along a route that was
    computed before the close must raise ``PaymentFailure``.
    """
    alice_channel, bob_channel = create_test_channels()
    p1, p2, w1, w2, q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
    _, pay_req = run(self.prepare_invoice(w2))

    lnaddr = w1._check_invoice(pay_req)
    route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2]
    # unittest asserts (unlike bare `assert`) are not stripped under `python -O`
    self.assertEqual(amount_msat, lnaddr.get_amount_msat())

    run(w1.force_close_channel(alice_channel.channel_id))
    # check if a tx (commitment transaction) was broadcasted:
    self.assertEqual(1, q1.qsize())

    with self.assertRaises(NoPathFound):
        w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)

    # the first hop of the pre-close route must still be one of w1's peers
    self.assertIn(route[0].node_id, w1.peers)

    # Reusing the route computed before the close must fail with
    # PaymentFailure, since the channel it goes through is closed.
    async def f():
        min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
        payment_hash = lnaddr.paymenthash
        payment_secret = lnaddr.payment_secret
        pay = w1.pay_to_route(
            route=route,
            amount_msat=amount_msat,
            total_msat=amount_msat,
            amount_receiver_msat=amount_msat,
            payment_hash=payment_hash,
            payment_secret=payment_secret,
            min_cltv_expiry=min_cltv_expiry)
        await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())

    with self.assertRaises(PaymentFailure):
        run(f())
1025
1026
def run(coro):
    """Submit *coro* to this thread's current event loop and block until it completes.

    Returns the coroutine's result, or re-raises the exception it raised.
    NOTE(review): the loop presumably runs in another thread (started in test
    setup); if nothing drives it, this call blocks forever.
    """
    event_loop = asyncio.get_event_loop()
    concurrent_future = asyncio.run_coroutine_threadsafe(coro, event_loop)
    return concurrent_future.result()