URI: 
       tlnworker/lnpeer: add some type hints, force some kwargs - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 691ebaf4f816b45ca10eabceae3068b7465e6bc5
   DIR parent d800f88bfcdc83a2cb9d359791c0c7e8cf6fcaff
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Wed, 24 Feb 2021 20:03:12 +0100
       
       lnworker/lnpeer: add some type hints, force some kwargs
       
       Diffstat:
         M electrum/lnonion.py                 |       9 ++++++---
         M electrum/lnpeer.py                  |      59 +++++++++++++++++++++----------
         M electrum/lnrater.py                 |       5 ++++-
         M electrum/lnworker.py                |     150 +++++++++++++++++++++----------
         M electrum/tests/test_lnpeer.py       |       8 +++++++-
       
       5 files changed, 160 insertions(+), 71 deletions(-)
       ---
   DIR diff --git a/electrum/lnonion.py b/electrum/lnonion.py
       t@@ -437,9 +437,12 @@ class OnionRoutingFailure(Exception):
                    return str(self.code.name)
                return f"Unknown error ({self.code!r})"
        
       -def construct_onion_error(reason: OnionRoutingFailure,
       -                          onion_packet: OnionPacket,
       -                          our_onion_private_key: bytes) -> bytes:
       +
       +def construct_onion_error(
       +        reason: OnionRoutingFailure,
       +        onion_packet: OnionPacket,
       +        our_onion_private_key: bytes,
       +) -> bytes:
            # create payload
            failure_msg = reason.to_bytes()
            failure_len = len(failure_msg)
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -1373,9 +1373,12 @@ class Peer(Logger):
                chan.receive_htlc(htlc, onion_packet)
                util.trigger_callback('htlc_added', chan, htlc, RECEIVED)
        
       -    def maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
       -                           onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket
       -                           ) -> Tuple[Optional[bytes], Optional[int], Optional[OnionRoutingFailure]]:
       +    def maybe_forward_htlc(
       +            self,
       +            *,
       +            htlc: UpdateAddHtlc,
       +            processed_onion: ProcessedOnionPacket,
       +    ) -> Tuple[bytes, int]:
                # Forward HTLC
                # FIXME: there are critical safety checks MISSING here
                forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
       t@@ -1662,7 +1665,7 @@ class Peer(Logger):
                self.shutdown_received[chan_id] = asyncio.Future()
                await self.send_shutdown(chan)
                payload = await self.shutdown_received[chan_id]
       -        txid = await self._shutdown(chan, payload, True)
       +        txid = await self._shutdown(chan, payload, is_local=True)
                self.logger.info(f'({chan.get_id_for_log()}) Channel closed {txid}')
                return txid
        
       t@@ -1686,10 +1689,10 @@ class Peer(Logger):
                else:
                    chan = self.channels[chan_id]
                    await self.send_shutdown(chan)
       -            txid = await self._shutdown(chan, payload, False)
       +            txid = await self._shutdown(chan, payload, is_local=False)
                    self.logger.info(f'({chan.get_id_for_log()}) Channel closed by remote peer {txid}')
        
       -    def can_send_shutdown(self, chan):
       +    def can_send_shutdown(self, chan: Channel):
                if chan.get_state() >= ChannelState.OPENING:
                    return True
                if chan.constraints.is_initiator and chan.channel_id in self.funding_created_sent:
       t@@ -1718,7 +1721,7 @@ class Peer(Logger):
                chan.set_can_send_ctx_updates(True)
        
            @log_exceptions
       -    async def _shutdown(self, chan: Channel, payload, is_local):
       +    async def _shutdown(self, chan: Channel, payload, *, is_local: bool):
                # wait until no HTLCs remain in either commitment transaction
                while len(chan.hm.htlcs(LOCAL)) + len(chan.hm.htlcs(REMOTE)) > 0:
                    self.logger.info(f'(chan: {chan.short_channel_id}) waiting for htlcs to settle...')
       t@@ -1826,7 +1829,12 @@ class Peer(Logger):
                                error_reason = e
                            else:
                                try:
       -                            preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet)
       +                            preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(
       +                                chan=chan,
       +                                htlc=htlc,
       +                                forwarding_info=forwarding_info,
       +                                onion_packet_bytes=onion_packet_bytes,
       +                                onion_packet=onion_packet)
                                except OnionRoutingFailure as e:
                                    error_bytes = construct_onion_error(e, onion_packet, our_onion_private_key=self.privkey)
                            if fw_info:
       t@@ -1850,13 +1858,24 @@ class Peer(Logger):
                        for htlc_id in done:
                            unfulfilled.pop(htlc_id)
        
       -    def process_unfulfilled_htlc(self, chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet):
       +    def process_unfulfilled_htlc(
       +            self,
       +            *,
       +            chan: Channel,
       +            htlc: UpdateAddHtlc,
       +            forwarding_info: Tuple[str, int],
       +            onion_packet_bytes: bytes,
       +            onion_packet: OnionPacket,
       +    ) -> Tuple[Optional[bytes], Union[bool, None, Tuple[str, int]], Optional[bytes]]:
                """
                returns either preimage or fw_info or error_bytes or (None, None, None)
                raise an OnionRoutingFailure if we need to fail the htlc
                """
                payment_hash = htlc.payment_hash
       -        processed_onion = self.process_onion_packet(onion_packet, payment_hash, onion_packet_bytes)
       +        processed_onion = self.process_onion_packet(
       +            onion_packet,
       +            payment_hash=payment_hash,
       +            onion_packet_bytes=onion_packet_bytes)
                if processed_onion.are_we_final:
                    preimage = self.maybe_fulfill_htlc(
                        chan=chan,
       t@@ -1867,8 +1886,8 @@ class Peer(Logger):
                        if not forwarding_info:
                            trampoline_onion = self.process_onion_packet(
                                processed_onion.trampoline_onion_packet,
       -                        htlc.payment_hash,
       -                        onion_packet_bytes,
       +                        payment_hash=htlc.payment_hash,
       +                        onion_packet_bytes=onion_packet_bytes,
                                is_trampoline=True)
                            if trampoline_onion.are_we_final:
                                preimage = self.maybe_fulfill_htlc(
       t@@ -1892,13 +1911,10 @@ class Peer(Logger):
        
                elif not forwarding_info:
                    next_chan_id, next_htlc_id = self.maybe_forward_htlc(
       -                chan=chan,
                        htlc=htlc,
       -                onion_packet=onion_packet,
                        processed_onion=processed_onion)
       -            if next_chan_id:
       -                fw_info = (next_chan_id.hex(), next_htlc_id)
       -                return None, fw_info, None
       +            fw_info = (next_chan_id.hex(), next_htlc_id)
       +            return None, fw_info, None
                else:
                    preimage = self.lnworker.get_preimage(payment_hash)
                    next_chan_id_hex, htlc_id = forwarding_info
       t@@ -1913,7 +1929,14 @@ class Peer(Logger):
                    return preimage, None, None
                return None, None, None
        
       -    def process_onion_packet(self, onion_packet, payment_hash, onion_packet_bytes, is_trampoline=False):
       +    def process_onion_packet(
       +            self,
       +            onion_packet: OnionPacket,
       +            *,
       +            payment_hash: bytes,
       +            onion_packet_bytes: bytes,
       +            is_trampoline: bool = False,
       +    ) -> ProcessedOnionPacket:
                failure_data = sha256(onion_packet_bytes)
                try:
                    processed_onion = process_onion_packet(
   DIR diff --git a/electrum/lnrater.py b/electrum/lnrater.py
       t@@ -268,7 +268,10 @@ class LNRater(Logger):
        
                return pk, self._node_stats[pk]
        
       -    def suggest_peer(self):
       +    def suggest_peer(self) -> Optional[bytes]:
       +        """Suggests a LN node to open a channel with.
       +        Returns a node ID (pubkey).
       +        """
                self.maybe_analyze_graph()
                if self._node_ratings:
                    return self.suggest_node_channel_open()[0]
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -7,7 +7,8 @@ import os
        from decimal import Decimal
        import random
        import time
       -from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
       +from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
       +                    NamedTuple, Union, Mapping, Any, Iterable)
        import threading
        import socket
        import aiohttp
       t@@ -266,10 +267,10 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
                with self.lock:
                    return self._peers.copy()
        
       -    def channels_for_peer(self, node_id):
       +    def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
                return {}
        
       -    def get_node_alias(self, node_id):
       +    def get_node_alias(self, node_id: bytes) -> str:
                if self.channel_db:
                    node_info = self.channel_db.get_node_info_for_node_id(node_id)
                    node_alias = (node_info.alias if node_info else '') or node_id.hex()
       t@@ -380,7 +381,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
                        self._add_peer(host, int(port), bfh(pubkey)),
                        self.network.asyncio_loop)
        
       -    def is_good_peer(self, peer):
       +    def is_good_peer(self, peer: LNPeerAddr) -> bool:
                # the purpose of this method is to filter peers that advertise the desired feature bits
                # it is disabled for now, because feature bits published in node announcements seem to be unreliable
                return True
       t@@ -566,7 +567,7 @@ class LNGossip(LNWorker):
                        self.channel_db.prune_orphaned_channels()
                    await asyncio.sleep(120)
        
       -    async def add_new_ids(self, ids):
       +    async def add_new_ids(self, ids: Iterable[bytes]):
                known = self.channel_db.get_channel_ids()
                new = set(ids) - set(known)
                self.unknown_ids.update(new)
       t@@ -574,7 +575,7 @@ class LNGossip(LNWorker):
                util.trigger_callback('gossip_peers', self.num_peers())
                util.trigger_callback('ln_gossip_sync_progress')
        
       -    def get_ids_to_query(self):
       +    def get_ids_to_query(self) -> Sequence[bytes]:
                N = 500
                l = list(self.unknown_ids)
                self.unknown_ids = set(l[N:])
       t@@ -910,7 +911,7 @@ class LNWallet(LNWorker):
                    if chan.funding_outpoint.to_str() == txo:
                        return chan
        
       -    async def on_channel_update(self, chan):
       +    async def on_channel_update(self, chan: Channel):
        
                if chan.get_state() == ChannelState.OPEN and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height()):
                    self.logger.info(f"force-closing due to expiring htlcs")
       t@@ -938,10 +939,14 @@ class LNWallet(LNWorker):
        
            @log_exceptions
            async def _open_channel_coroutine(
       -            self, *, connect_str: str,
       +            self,
       +            *,
       +            connect_str: str,
                    funding_tx: PartialTransaction,
       -            funding_sat: int, push_sat: int,
       -            password: Optional[str]) -> Tuple[Channel, PartialTransaction]:
       +            funding_sat: int,
       +            push_sat: int,
       +            password: Optional[str],
       +    ) -> Tuple[Channel, PartialTransaction]:
                peer = await self.add_peer(connect_str)
                coro = peer.channel_establishment_flow(
                    funding_tx=funding_tx,
       t@@ -1006,7 +1011,7 @@ class LNWallet(LNWorker):
                    if chan.short_channel_id == short_channel_id:
                        return chan
        
       -    def create_routes_from_invoice(self, amount_msat, decoded_invoice, *, full_path=None):
       +    def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
                return self.create_routes_for_payment(
                    amount_msat=amount_msat,
                    invoice_pubkey=decoded_invoice.pubkey.serialize(),
       t@@ -1051,9 +1056,16 @@ class LNWallet(LNWorker):
                util.trigger_callback('invoice_status', self.wallet, key)
                try:
                    await self.pay_to_node(
       -                invoice_pubkey, payment_hash, payment_secret, amount_to_pay,
       -                min_cltv_expiry, r_tags, t_tags, invoice_features,
       -                attempts=attempts, full_path=full_path)
       +                node_pubkey=invoice_pubkey,
       +                payment_hash=payment_hash,
       +                payment_secret=payment_secret,
       +                amount_to_pay=amount_to_pay,
       +                min_cltv_expiry=min_cltv_expiry,
       +                r_tags=r_tags,
       +                t_tags=t_tags,
       +                invoice_features=invoice_features,
       +                attempts=attempts,
       +                full_path=full_path)
                    success = True
                except PaymentFailure as e:
                    self.logger.exception('')
       t@@ -1068,12 +1080,23 @@ class LNWallet(LNWorker):
                log = self.logs[key]
                return success, log
        
       -
            async def pay_to_node(
       -            self, node_pubkey, payment_hash, payment_secret, amount_to_pay,
       -            min_cltv_expiry, r_tags, t_tags, invoice_features, *,
       -            attempts: int = 1, full_path: LNPaymentPath=None,
       -            trampoline_onion=None, trampoline_fee=None, trampoline_cltv_delta=None):
       +            self,
       +            *,
       +            node_pubkey: bytes,
       +            payment_hash: bytes,
       +            payment_secret: Optional[bytes],
       +            amount_to_pay: int,  # in msat
       +            min_cltv_expiry: int,
       +            r_tags,
       +            t_tags,
       +            invoice_features: int,
       +            attempts: int = 1,
       +            full_path: LNPaymentPath = None,
       +            trampoline_onion=None,
       +            trampoline_fee=None,
       +            trampoline_cltv_delta=None,
       +    ) -> None:
        
                if trampoline_onion:
                    # todo: compare to the fee of the actual route we found
       t@@ -1095,7 +1118,14 @@ class LNWallet(LNWorker):
                            min_cltv_expiry, r_tags, t_tags, invoice_features, full_path=full_path))
                        # 2. send htlcs
                        for route, amount_msat in routes:
       -                    await self.pay_to_route(route, amount_msat, amount_to_pay, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion)
       +                    await self.pay_to_route(
       +                        route,
       +                        amount_msat=amount_msat,
       +                        total_msat=amount_to_pay,
       +                        payment_hash=payment_hash,
       +                        payment_secret=payment_secret,
       +                        min_cltv_expiry=min_cltv_expiry,
       +                        trampoline_onion=trampoline_onion)
                            amount_inflight += amount_msat
                        util.trigger_callback('invoice_status', self.wallet, payment_hash.hex())
                    # 3. await a queue
       t@@ -1111,9 +1141,17 @@ class LNWallet(LNWorker):
                    # if we get a channel update, we might retry the same route and amount
                    self.handle_error_code_from_failed_htlc(htlc_log)
        
       -    async def pay_to_route(self, route: LNPaymentRoute, amount_msat: int,
       -                           total_msat: int, payment_hash: bytes, payment_secret: bytes,
       -                           min_cltv_expiry: int, trampoline_onion: bytes=None):
       +    async def pay_to_route(
       +            self,
       +            route: LNPaymentRoute,
       +            *,
       +            amount_msat: int,
       +            total_msat: int,
       +            payment_hash: bytes,
       +            payment_secret: Optional[bytes],
       +            min_cltv_expiry: int,
       +            trampoline_onion: bytes = None,
       +    ) -> None:
                # send a single htlc
                short_channel_id = route[0].short_channel_id
                chan = self.get_channel_by_short_id(short_channel_id)
       t@@ -1267,7 +1305,7 @@ class LNWallet(LNWorker):
                        result.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
                return result.tobytes()
        
       -    def is_trampoline_peer(self, node_id):
       +    def is_trampoline_peer(self, node_id: bytes) -> bool:
                # until trampoline is advertised in lnfeatures, check against hardcoded list
                if is_hardcoded_trampoline(node_id):
                    return True
       t@@ -1276,8 +1314,11 @@ class LNWallet(LNWorker):
                    return True
                return False
        
       -    def suggest_peer(self):
       -        return self.lnrater.suggest_peer() if self.channel_db else random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
       +    def suggest_peer(self) -> Optional[bytes]:
       +        if self.channel_db:
       +            return self.lnrater.suggest_peer()
       +        else:
       +            return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
        
            def create_trampoline_route(
                    self, amount_msat:int,
       t@@ -1400,8 +1441,10 @@ class LNWallet(LNWorker):
                    invoice_pubkey,
                    min_cltv_expiry,
                    r_tags, t_tags,
       -            invoice_features,
       -            *, full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]:
       +            invoice_features: int,
       +            *,
       +            full_path: LNPaymentPath = None,
       +    ) -> Sequence[Tuple[LNPaymentRoute, int]]:
                """Creates multiple routes for splitting a payment over the available
                private channels.
        
       t@@ -1411,13 +1454,14 @@ class LNWallet(LNWorker):
                # try to send over a single channel
                try:
                    routes = [self.create_route_for_payment(
       -                amount_msat,
       -                invoice_pubkey,
       -                min_cltv_expiry,
       -                r_tags, t_tags,
       -                invoice_features,
       -                None,
       -                full_path=full_path
       +                amount_msat=amount_msat,
       +                invoice_pubkey=invoice_pubkey,
       +                min_cltv_expiry=min_cltv_expiry,
       +                r_tags=r_tags,
       +                t_tags=t_tags,
       +                invoice_features=invoice_features,
       +                outgoing_channel=None,
       +                full_path=full_path,
                    )]
                except NoPathFound:
                    if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
       t@@ -1439,12 +1483,13 @@ class LNWallet(LNWorker):
                                    # its capacity. This could be dealt with by temporarily
                                    # iteratively blacklisting channels for this mpp attempt.
                                    route, amt = self.create_route_for_payment(
       -                                part_amount_msat,
       -                                invoice_pubkey,
       -                                min_cltv_expiry,
       -                                r_tags, t_tags,
       -                                invoice_features,
       -                                channel,
       +                                amount_msat=part_amount_msat,
       +                                invoice_pubkey=invoice_pubkey,
       +                                min_cltv_expiry=min_cltv_expiry,
       +                                r_tags=r_tags,
       +                                t_tags=t_tags,
       +                                invoice_features=invoice_features,
       +                                outgoing_channel=channel,
                                        full_path=None)
                                    routes.append((route, amt))
                            self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}")
       t@@ -1457,13 +1502,16 @@ class LNWallet(LNWorker):
        
            def create_route_for_payment(
                    self,
       +            *,
                    amount_msat: int,
       -            invoice_pubkey,
       -            min_cltv_expiry,
       -            r_tags, t_tags,
       -            invoice_features,
       +            invoice_pubkey: bytes,
       +            min_cltv_expiry: int,
       +            r_tags,
       +            t_tags,
       +            invoice_features: int,
                    outgoing_channel: Channel = None,
       -            *, full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
       +            full_path: Optional[LNPaymentPath],
       +    ) -> Tuple[LNPaymentRoute, int]:
        
                channels = [outgoing_channel] if outgoing_channel else list(self.channels.values())
                if not self.channel_db:
       t@@ -1554,7 +1602,13 @@ class LNWallet(LNWorker):
                    raise Exception(_("add invoice timed out"))
        
            @log_exceptions
       -    async def create_invoice(self, *, amount_msat: Optional[int], message, expiry: int):
       +    async def create_invoice(
       +            self,
       +            *,
       +            amount_msat: Optional[int],
       +            message,
       +            expiry: int,
       +    ) -> Tuple[LnAddr, str]:
                timestamp = int(time.time())
                routing_hints = await self._calc_routing_hints_for_invoice(amount_msat)
                if not routing_hints:
       t@@ -1628,7 +1682,7 @@ class LNWallet(LNWorker):
                    self.payments[key] = info.amount_msat, info.direction, info.status
                self.wallet.save_db()
        
       -    def htlc_received(self, short_channel_id, htlc, expected_msat):
       +    def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int):
                status = self.get_payment_status(htlc.payment_hash)
                if status == PR_PAID:
                    return True, None
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -775,7 +775,13 @@ class TestPeer(ElectrumTestCase):
                    min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
                    payment_hash = lnaddr.paymenthash
                    payment_secret = lnaddr.payment_secret
       -            pay = w1.pay_to_route(route, amount_msat, amount_msat, payment_hash, payment_secret, min_cltv_expiry)
       +            pay = w1.pay_to_route(
       +                route,
       +                amount_msat=amount_msat,
       +                total_msat=amount_msat,
       +                payment_hash=payment_hash,
       +                payment_secret=payment_secret,
       +                min_cltv_expiry=min_cltv_expiry)
                    await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
                with self.assertRaises(PaymentFailure):
                    run(f())