URI: 
       tln gossip: don't put own channels into db; always pass them to fn calls - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 46d8080c76e79670e8abaaaa0eb2d4d4a74544c1
   DIR parent 7d65fe1ba32200ae7e46841b7e0e4b6397bf7a2b
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Mon, 17 Feb 2020 20:38:41 +0100
       
       ln gossip: don't put own channels into db; always pass them to fn calls
       
       Previously we would put fake chan announcement and fake outgoing chan upd
       for own channels into db (to make path finding work). See Peer.add_own_channel().
       Now, instead of above, we pass a "my_channels" param to the relevant ChannelDB methods.
       
       Diffstat:
         M electrum/channel_db.py              |      88 ++++++++++++++++++++++---------
         M electrum/lnchannel.py               |      76 +++++++++++++++++++++++++++++--
         M electrum/lnpeer.py                  |     111 ++++---------------------------
         M electrum/lnrouter.py                |      32 ++++++++++++++++++-------------
         M electrum/lnworker.py                |      25 ++++++++++++++++++-------
         M electrum/tests/test_lnpeer.py       |       9 ++++++---
       
       6 files changed, 190 insertions(+), 151 deletions(-)
       ---
   DIR diff --git a/electrum/channel_db.py b/electrum/channel_db.py
       t@@ -39,9 +39,11 @@ from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enab
        from .logging import Logger
        from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
        from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
       +from .lnmsg import decode_msg
        
        if TYPE_CHECKING:
            from .network import Network
       +    from .lnchannel import Channel
        
        
        class UnknownEvenFeatureBits(Exception): pass
       t@@ -63,7 +65,7 @@ class ChannelInfo(NamedTuple):
            capacity_sat: Optional[int]
        
            @staticmethod
       -    def from_msg(payload):
       +    def from_msg(payload: dict) -> 'ChannelInfo':
                features = int.from_bytes(payload['features'], 'big')
                validate_features(features)
                channel_id = payload['short_channel_id']
       t@@ -78,6 +80,11 @@ class ChannelInfo(NamedTuple):
                    capacity_sat = capacity_sat
                )
        
       +    @staticmethod
       +    def from_raw_msg(raw: bytes) -> 'ChannelInfo':
       +        payload_dict = decode_msg(raw)[1]
       +        return ChannelInfo.from_msg(payload_dict)
       +
        
        class Policy(NamedTuple):
            key: bytes
       t@@ -91,7 +98,7 @@ class Policy(NamedTuple):
            timestamp: int
        
            @staticmethod
       -    def from_msg(payload):
       +    def from_msg(payload: dict) -> 'Policy':
                return Policy(
                    key                         = payload['short_channel_id'] + payload['start_node'],
                    cltv_expiry_delta           = int.from_bytes(payload['cltv_expiry_delta'], "big"),
       t@@ -248,11 +255,11 @@ class ChannelDB(SqlDB):
                self.ca_verifier = LNChannelVerifier(network, self)
                # initialized in load_data
                self._channels = {}  # type: Dict[bytes, ChannelInfo]
       -        self._policies = {}
       +        self._policies = {}  # type: Dict[Tuple[bytes, bytes], Policy]  # (node_id, scid) -> Policy
                self._nodes = {}
                # node_id -> (host, port, ts)
                self._addresses = defaultdict(set)  # type: Dict[bytes, Set[Tuple[str, int, int]]]
       -        self._channels_for_node = defaultdict(set)
       +        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
                self.data_loaded = asyncio.Event()
                self.network = network # only for callback
        
       t@@ -495,17 +502,6 @@ class ChannelDB(SqlDB):
                self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
                self.update_counts()
        
       -    def get_routing_policy_for_channel(self, start_node_id: bytes,
       -                                       short_channel_id: bytes) -> Optional[Policy]:
       -        if not start_node_id or not short_channel_id: return None
       -        channel_info = self.get_channel_info(short_channel_id)
       -        if channel_info is not None:
       -            return self.get_policy_for_node(short_channel_id, start_node_id)
       -        msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
       -        if not msg:
       -            return None
       -        return Policy.from_msg(msg) # won't actually be written to DB
       -
            def get_old_policies(self, delta):
                now = int(time.time())
                return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
       t@@ -587,12 +583,56 @@ class ChannelDB(SqlDB):
                        out.add(short_channel_id)
                self.logger.info(f'semi-orphaned: {len(out)}')
        
       -    def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
       -        return self._policies.get((node_id, short_channel_id))
       -
       -    def get_channel_info(self, channel_id: bytes) -> ChannelInfo:
       -        return self._channels.get(channel_id)
       -
       -    def get_channels_for_node(self, node_id) -> Set[bytes]:
       -        """Returns the set of channels that have node_id as one of the endpoints."""
       -        return self._channels_for_node.get(node_id) or set()
       +    def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
       +                            my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
       +        channel_info = self.get_channel_info(short_channel_id)
       +        if channel_info is not None:  # publicly announced channel
       +            policy = self._policies.get((node_id, short_channel_id))
       +            if policy:
       +                return policy
       +        else:  # private channel
       +            chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
       +            if chan_upd_dict:
       +                return Policy.from_msg(chan_upd_dict)
       +        # check if it's one of our own channels
       +        if not my_channels:
       +            return
       +        chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
       +        if not chan:
       +            return
       +        if node_id == chan.node_id:  # incoming direction (to us)
       +            remote_update_raw = chan.get_remote_update()
       +            if not remote_update_raw:
       +                return
       +            now = int(time.time())
       +            remote_update_decoded = decode_msg(remote_update_raw)[1]
       +            remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
       +            remote_update_decoded['start_node'] = node_id
       +            return Policy.from_msg(remote_update_decoded)
       +        elif node_id == chan.get_local_pubkey():  # outgoing direction (from us)
       +            local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
       +            local_update_decoded['start_node'] = node_id
       +            return Policy.from_msg(local_update_decoded)
       +
       +    def get_channel_info(self, short_channel_id: bytes, *,
       +                         my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]:
       +        ret = self._channels.get(short_channel_id)
       +        if ret:
       +            return ret
       +        # check if it's one of our own channels
       +        if not my_channels:
       +            return
       +        chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
       +        ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
       +        return ci._replace(capacity_sat=chan.constraints.capacity)
       +
       +    def get_channels_for_node(self, node_id: bytes, *,
       +                              my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
       +        """Returns the set of short channel IDs where node_id is one of the channel participants."""
       +        relevant_channels = self._channels_for_node.get(node_id) or set()
       +        relevant_channels = set(relevant_channels)  # copy
       +        # add our own channels  # TODO maybe slow?
       +        for chan in (my_channels.values() or []):
       +            if node_id in (chan.node_id, chan.get_local_pubkey()):
       +                relevant_channels.add(chan.short_channel_id)
       +        return relevant_channels
   DIR diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
       t@@ -32,13 +32,14 @@ import time
        import threading
        
        from . import ecc
       +from . import constants
        from .util import bfh, bh2u
        from .bitcoin import redeem_script_to_address
        from .crypto import sha256, sha256d
        from .transaction import Transaction, PartialTransaction
        from .logging import Logger
       -
        from .lnonion import decode_onion_error
       +from . import lnutil
        from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints,
                            get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx,
                            sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey,
       t@@ -47,10 +48,10 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey
                            funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
                            ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script,
                            ShortChannelID, map_htlcs_to_ctx_output_idxs)
       -from .lnutil import FeeUpdate
        from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
        from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo
        from .lnhtlc import HTLCManager
       +from .lnmsg import encode_msg, decode_msg
        
        if TYPE_CHECKING:
            from .lnworker import LNWallet
       t@@ -136,7 +137,6 @@ class Channel(Logger):
                self.funding_outpoint = state["funding_outpoint"]
                self.node_id = bfh(state["node_id"])
                self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
       -        self.short_channel_id_predicted = self.short_channel_id
                self.onion_keys = state['onion_keys']
                self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
                self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
       t@@ -144,6 +144,7 @@ class Channel(Logger):
                self.peer_state = peer_states.DISCONNECTED
                self.sweep_info = {}  # type: Dict[str, Dict[str, SweepInfo]]
                self._outgoing_channel_update = None  # type: Optional[bytes]
       +        self._chan_ann_without_sigs = None  # type: Optional[bytes]
                self.revocation_store = RevocationStore(state["revocation_store"])
        
            def set_onion_key(self, key, value):
       t@@ -158,12 +159,77 @@ class Channel(Logger):
            def get_data_loss_protect_remote_pcp(self, key):
                return self.data_loss_protect_remote_pcp.get(key)
        
       -    def set_remote_update(self, raw):
       +    def get_local_pubkey(self) -> bytes:
       +        if not self.lnworker:
       +            raise Exception('lnworker not set for channel!')
       +        return self.lnworker.node_keypair.pubkey
       +
       +    def set_remote_update(self, raw: bytes) -> None:
                self.storage['remote_update'] = raw.hex()
        
       -    def get_remote_update(self):
       +    def get_remote_update(self) -> Optional[bytes]:
                return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None
        
       +    def get_outgoing_gossip_channel_update(self) -> bytes:
       +        if self._outgoing_channel_update is not None:
       +            return self._outgoing_channel_update
       +        if not self.lnworker:
       +            raise Exception('lnworker not set for channel!')
       +        sorted_node_ids = list(sorted([self.node_id, self.get_local_pubkey()]))
       +        channel_flags = b'\x00' if sorted_node_ids[0] == self.get_local_pubkey() else b'\x01'
       +        now = int(time.time())
       +        htlc_maximum_msat = min(self.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * self.constraints.capacity)
       +
       +        chan_upd = encode_msg(
       +            "channel_update",
       +            short_channel_id=self.short_channel_id,
       +            channel_flags=channel_flags,
       +            message_flags=b'\x01',
       +            cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
       +            htlc_minimum_msat=self.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
       +            htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
       +            fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
       +            fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
       +            chain_hash=constants.net.rev_genesis_bytes(),
       +            timestamp=now.to_bytes(4, byteorder="big"),
       +        )
       +        sighash = sha256d(chan_upd[2 + 64:])
       +        sig = ecc.ECPrivkey(self.lnworker.node_keypair.privkey).sign(sighash, ecc.sig_string_from_r_and_s)
       +        message_type, payload = decode_msg(chan_upd)
       +        payload['signature'] = sig
       +        chan_upd = encode_msg(message_type, **payload)
       +
       +        self._outgoing_channel_update = chan_upd
       +        return chan_upd
       +
       +    def construct_channel_announcement_without_sigs(self) -> bytes:
       +        if self._chan_ann_without_sigs is not None:
       +            return self._chan_ann_without_sigs
       +        if not self.lnworker:
       +            raise Exception('lnworker not set for channel!')
       +
       +        bitcoin_keys = [self.config[REMOTE].multisig_key.pubkey,
       +                        self.config[LOCAL].multisig_key.pubkey]
       +        node_ids = [self.node_id, self.get_local_pubkey()]
       +        sorted_node_ids = list(sorted(node_ids))
       +        if sorted_node_ids != node_ids:
       +            node_ids = sorted_node_ids
       +            bitcoin_keys.reverse()
       +
       +        chan_ann = encode_msg("channel_announcement",
       +            len=0,
       +            features=b'',
       +            chain_hash=constants.net.rev_genesis_bytes(),
       +            short_channel_id=self.short_channel_id,
       +            node_id_1=node_ids[0],
       +            node_id_2=node_ids[1],
       +            bitcoin_key_1=bitcoin_keys[0],
       +            bitcoin_key_2=bitcoin_keys[1]
       +        )
       +
       +        self._chan_ann_without_sigs = chan_ann
       +        return chan_ann
       +
            def set_short_channel_id(self, short_id):
                self.short_channel_id = short_id
                self.storage["short_channel_id"] = short_id
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -953,112 +953,25 @@ class Peer(Logger):
                assert chan.config[LOCAL].funding_locked_received
                chan.set_state(channel_states.OPEN)
                self.network.trigger_callback('channel', chan)
       -        self.add_own_channel(chan)
       +        # peer may have sent us a channel update for the incoming direction previously
       +        pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
       +        if pending_channel_update:
       +            chan.set_remote_update(pending_channel_update['raw'])
                self.logger.info(f"CHANNEL OPENING COMPLETED for {scid}")
                forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
                if forwarding_enabled:
                    # send channel_update of outgoing edge to peer,
                    # so that channel can be used to to receive payments
                    self.logger.info(f"sending channel update for outgoing edge of {scid}")
       -            chan_upd = self.get_outgoing_gossip_channel_update_for_chan(chan)
       +            chan_upd = chan.get_outgoing_gossip_channel_update()
                    self.transport.send_bytes(chan_upd)
        
       -    def add_own_channel(self, chan):
       -        # add channel to database
       -        bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
       -        sorted_node_ids = list(sorted(self.node_ids))
       -        if sorted_node_ids != self.node_ids:
       -            bitcoin_keys.reverse()
       -        # note: we inject a channel announcement, and a channel update (for outgoing direction)
       -        # This is atm needed for
       -        # - finding routes
       -        # - the ChanAnn is needed so that we can anchor to it a future ChanUpd
       -        #   that the remote sends, even if the channel was not announced
       -        #   (from BOLT-07: "MAY create a channel_update to communicate the channel
       -        #    parameters to the final node, even though the channel has not yet been announced")
       -        self.channel_db.add_channel_announcement(
       -            {
       -                "short_channel_id": chan.short_channel_id,
       -                "node_id_1": sorted_node_ids[0],
       -                "node_id_2": sorted_node_ids[1],
       -                'chain_hash': constants.net.rev_genesis_bytes(),
       -                'len': b'\x00\x00',
       -                'features': b'',
       -                'bitcoin_key_1': bitcoin_keys[0],
       -                'bitcoin_key_2': bitcoin_keys[1]
       -            },
       -            trusted=True)
       -        # only inject outgoing direction:
       -        chan_upd_bytes = self.get_outgoing_gossip_channel_update_for_chan(chan)
       -        chan_upd_payload = decode_msg(chan_upd_bytes)[1]
       -        self.channel_db.add_channel_update(chan_upd_payload)
       -        # peer may have sent us a channel update for the incoming direction previously
       -        pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
       -        if pending_channel_update:
       -            chan.set_remote_update(pending_channel_update['raw'])
       -        # add remote update with a fresh timestamp
       -        if chan.get_remote_update():
       -            now = int(time.time())
       -            remote_update_decoded = decode_msg(chan.get_remote_update())[1]
       -            remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
       -            self.channel_db.add_channel_update(remote_update_decoded)
       -
       -    def get_outgoing_gossip_channel_update_for_chan(self, chan: Channel) -> bytes:
       -        if chan._outgoing_channel_update is not None:
       -            return chan._outgoing_channel_update
       -        sorted_node_ids = list(sorted(self.node_ids))
       -        channel_flags = b'\x00' if sorted_node_ids[0] == privkey_to_pubkey(self.privkey) else b'\x01'
       -        now = int(time.time())
       -        htlc_maximum_msat = min(chan.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * chan.constraints.capacity)
       -
       -        chan_upd = encode_msg(
       -            "channel_update",
       -            short_channel_id=chan.short_channel_id,
       -            channel_flags=channel_flags,
       -            message_flags=b'\x01',
       -            cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
       -            htlc_minimum_msat=chan.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
       -            htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
       -            fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
       -            fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
       -            chain_hash=constants.net.rev_genesis_bytes(),
       -            timestamp=now.to_bytes(4, byteorder="big"),
       -        )
       -        sighash = sha256d(chan_upd[2 + 64:])
       -        sig = ecc.ECPrivkey(self.privkey).sign(sighash, sig_string_from_r_and_s)
       -        message_type, payload = decode_msg(chan_upd)
       -        payload['signature'] = sig
       -        chan_upd = encode_msg(message_type, **payload)
       -
       -        chan._outgoing_channel_update = chan_upd
       -        return chan_upd
       -
            def send_announcement_signatures(self, chan: Channel):
       -
       -        bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
       -                        chan.config[LOCAL].multisig_key.pubkey]
       -
       -        sorted_node_ids = list(sorted(self.node_ids))
       -        if sorted_node_ids != self.node_ids:
       -            node_ids = sorted_node_ids
       -            bitcoin_keys.reverse()
       -        else:
       -            node_ids = self.node_ids
       -
       -        chan_ann = encode_msg("channel_announcement",
       -            len=0,
       -            #features not set (defaults to zeros)
       -            chain_hash=constants.net.rev_genesis_bytes(),
       -            short_channel_id=chan.short_channel_id,
       -            node_id_1=node_ids[0],
       -            node_id_2=node_ids[1],
       -            bitcoin_key_1=bitcoin_keys[0],
       -            bitcoin_key_2=bitcoin_keys[1]
       -        )
       -        to_hash = chan_ann[256+2:]
       -        h = sha256d(to_hash)
       -        bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(h, sig_string_from_r_and_s)
       -        node_signature = ecc.ECPrivkey(self.privkey).sign(h, sig_string_from_r_and_s)
       +        chan_ann = chan.construct_channel_announcement_without_sigs()
       +        preimage = chan_ann[256+2:]
       +        msg_hash = sha256d(preimage)
       +        bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(msg_hash, sig_string_from_r_and_s)
       +        node_signature = ecc.ECPrivkey(self.privkey).sign(msg_hash, sig_string_from_r_and_s)
                self.send_message("announcement_signatures",
                    channel_id=chan.channel_id,
                    short_channel_id=chan.short_channel_id,
       t@@ -1066,7 +979,7 @@ class Peer(Logger):
                    bitcoin_signature=bitcoin_signature
                )
        
       -        return h, node_signature, bitcoin_signature
       +        return msg_hash, node_signature, bitcoin_signature
        
            def on_update_fail_htlc(self, payload):
                channel_id = payload["channel_id"]
       t@@ -1255,7 +1168,7 @@ class Peer(Logger):
                    reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
                    await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
                    return
       -        outgoing_chan_upd = self.get_outgoing_gossip_channel_update_for_chan(next_chan)[2:]
       +        outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:]
                outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
                if next_chan.get_state() != channel_states.OPEN:
                    self.logger.info(f"cannot forward htlc. next_chan not OPEN: {next_chan_scid} in state {next_chan.get_state()}")
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -129,18 +129,20 @@ class LNPathFinder(Logger):
                self.blacklist.add(short_channel_id)
        
            def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
       -                   payment_amt_msat: int, ignore_costs=False, is_mine=False) -> Tuple[float, int]:
       +                   payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
       +                   my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]:
                """Heuristic cost of going through a channel.
                Returns (heuristic_cost, fee_for_edge_msat).
                """
       -        channel_info = self.channel_db.get_channel_info(short_channel_id)
       +        channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels)
                if channel_info is None:
                    return float('inf'), 0
       -        channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node)
       +        channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels)
                if channel_policy is None:
                    return float('inf'), 0
                # channels that did not publish both policies often return temporary channel failure
       -        if self.channel_db.get_policy_for_node(short_channel_id, end_node) is None and not is_mine:
       +        if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \
       +                and not is_mine:
                    return float('inf'), 0
                if channel_policy.is_disabled():
                    return float('inf'), 0
       t@@ -164,8 +166,9 @@ class LNPathFinder(Logger):
        
            @profiler
            def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
       -                              invoice_amount_msat: int,
       -                              my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
       +                              invoice_amount_msat: int, *,
       +                              my_channels: Dict[ShortChannelID, 'Channel'] = None) \
       +            -> Optional[Sequence[Tuple[bytes, bytes]]]:
                """Return a path from nodeA to nodeB.
        
                Returns a list of (node_id, short_channel_id) representing a path.
       t@@ -175,8 +178,7 @@ class LNPathFinder(Logger):
                assert type(nodeA) is bytes
                assert type(nodeB) is bytes
                assert type(invoice_amount_msat) is int
       -        if my_channels is None: my_channels = []
       -        my_channels = {chan.short_channel_id: chan for chan in my_channels}
       +        if my_channels is None: my_channels = {}
        
                # FIXME paths cannot be longer than 20 edges (onion packet)...
        
       t@@ -204,7 +206,8 @@ class LNPathFinder(Logger):
                        end_node=edge_endnode,
                        payment_amt_msat=amount_msat,
                        ignore_costs=(edge_startnode == nodeA),
       -                is_mine=is_mine)
       +                is_mine=is_mine,
       +                my_channels=my_channels)
                    alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
                    if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                        distance_from_start[edge_startnode] = alt_dist_to_neighbour
       t@@ -222,11 +225,11 @@ class LNPathFinder(Logger):
                        # so instead of decreasing priorities, we add items again into the queue.
                        # so there are duplicates in the queue, that we discard now:
                        continue
       -            for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
       +            for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
                        assert isinstance(edge_channel_id, bytes)
                        if edge_channel_id in self.blacklist:
                            continue
       -                channel_info = self.channel_db.get_channel_info(edge_channel_id)
       +                channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
                        edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
                        inspect_edge()
                else:
       t@@ -241,14 +244,17 @@ class LNPathFinder(Logger):
                    edge_startnode = edge_endnode
                return path
        
       -    def create_route_from_path(self, path, from_node_id: bytes) -> LNPaymentRoute:
       +    def create_route_from_path(self, path, from_node_id: bytes, *,
       +                               my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
                assert isinstance(from_node_id, bytes)
                if path is None:
                    raise Exception('cannot create route from None path')
                route = []
                prev_node_id = from_node_id
                for node_id, short_channel_id in path:
       -            channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
       +            channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
       +                                                                 node_id=prev_node_id,
       +                                                                 my_channels=my_channels)
                    if channel_policy is None:
                        raise NoChannelPolicy(short_channel_id)
                    route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -942,16 +942,20 @@ class LNWallet(LNWorker):
                random.shuffle(r_tags)
                with self.lock:
                    channels = list(self.channels.values())
       +        scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
       +                               if chan.short_channel_id is not None}
                for private_route in r_tags:
                    if len(private_route) == 0:
                        continue
                    if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
                        continue
                    border_node_pubkey = private_route[0][0]
       -            path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
       +            path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat,
       +                                                                  my_channels=scid_to_my_channels)
                    if not path:
                        continue
       -            route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
       +            route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
       +                                                                    my_channels=scid_to_my_channels)
                    # we need to shift the node pubkey by one towards the destination:
                    private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
                    private_route_rest = [edge[1:] for edge in private_route]
       t@@ -961,7 +965,9 @@ class LNWallet(LNWorker):
                        short_channel_id = ShortChannelID(short_channel_id)
                        # if we have a routing policy for this edge in the db, that takes precedence,
                        # as it is likely from a previous failure
       -                channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
       +                channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
       +                                                                     node_id=prev_node_id,
       +                                                                     my_channels=scid_to_my_channels)
                        if channel_policy:
                            fee_base_msat = channel_policy.fee_base_msat
                            fee_proportional_millionths = channel_policy.fee_proportional_millionths
       t@@ -977,10 +983,12 @@ class LNWallet(LNWorker):
                    break
                # if could not find route using any hint; try without hint now
                if route is None:
       -            path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat, channels)
       +            path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat,
       +                                                                  my_channels=scid_to_my_channels)
                    if not path:
                        raise NoPathFound()
       -            route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
       +            route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
       +                                                                    my_channels=scid_to_my_channels)
                    if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
                        self.logger.info(f"rejecting insane route {route}")
                        raise NoPathFound()
       t@@ -1099,6 +1107,8 @@ class LNWallet(LNWorker):
                routing_hints = []
                with self.lock:
                    channels = list(self.channels.values())
       +        scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
       +                               if chan.short_channel_id is not None}
                # note: currently we add *all* our channels; but this might be a privacy leak?
                for chan in channels:
                    # check channel is open
       t@@ -1110,7 +1120,7 @@ class LNWallet(LNWorker):
                        continue
                    chan_id = chan.short_channel_id
                    assert isinstance(chan_id, bytes), chan_id
       -            channel_info = self.channel_db.get_channel_info(chan_id)
       +            channel_info = self.channel_db.get_channel_info(chan_id, my_channels=scid_to_my_channels)
                    # note: as a fallback, if we don't have a channel update for the
                    # incoming direction of our private channel, we fill the invoice with garbage.
                    # the sender should still be able to pay us, but will incur an extra round trip
       t@@ -1120,7 +1130,8 @@ class LNWallet(LNWorker):
                    cltv_expiry_delta = 1  # lnd won't even try with zero
                    missing_info = True
                    if channel_info:
       -                policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id)
       +                policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id,
       +                                                             my_channels=scid_to_my_channels)
                        if policy:
                            fee_base_msat = policy.fee_base_msat
                            fee_proportional_millionths = policy.fee_proportional_millionths
   DIR diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -18,7 +18,7 @@ from electrum.lnpeer import Peer
        from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
        from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
        from electrum.lnutil import PaymentFailure, LnLocalFeatures
       -from electrum.lnchannel import channel_states, peer_states
       +from electrum.lnchannel import channel_states, peer_states, Channel
        from electrum.lnrouter import LNPathFinder
        from electrum.channel_db import ChannelDB
        from electrum.lnworker import LNWallet, NoPathFound
       t@@ -77,7 +77,7 @@ class MockWallet:
                return False
        
        class MockLNWallet:
       -    def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
       +    def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
                self.remote_keypair = remote_keypair
                self.node_keypair = local_keypair
                self.network = MockNetwork(tx_queue)
       t@@ -88,6 +88,8 @@ class MockLNWallet:
                self.localfeatures = LnLocalFeatures(0)
                self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
                self.pending_payments = defaultdict(asyncio.Future)
       +        chan.lnworker = self
       +        chan.node_id = remote_keypair.pubkey
        
            def get_invoice_status(self, key):
                pass
       t@@ -127,6 +129,7 @@ class MockLNWallet:
            _pay_to_route = LNWallet._pay_to_route
            force_close_channel = LNWallet.force_close_channel
            get_first_timestamp = lambda self: 0
       +    payment_completed = LNWallet.payment_completed
        
        class MockTransport:
            def __init__(self, name):
       t@@ -264,7 +267,7 @@ class TestPeer(ElectrumTestCase):
                pay_req = self.prepare_invoice(w2)
                async def pay():
                    result = await LNWallet._pay(w1, pay_req)
       -            self.assertEqual(result, True)
       +            self.assertTrue(result)
                    gath.cancel()
                gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
                async def f():