URI: 
       tprotect against getting robbed through routing fees - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 2fafd01945569cb0ef1cad9320e426874fb40f7d
   DIR parent c577df84898c470d45287b39ef484d16cfed4cc4
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Fri, 19 Oct 2018 21:47:51 +0200
       
       protect against getting robbed through routing fees
       
       Diffstat:
         M electrum/lnonion.py                 |      10 ++++++----
         M electrum/lnrouter.py                |     105 ++++++++++++++++++++++---------
         M electrum/lnutil.py                  |       9 ++++++---
         M electrum/lnworker.py                |      18 ++++++++++++++++--
       
       4 files changed, 102 insertions(+), 40 deletions(-)
       ---
   DIR diff --git a/electrum/lnonion.py b/electrum/lnonion.py
       t@@ -33,11 +33,10 @@ from cryptography.hazmat.backends import default_backend
        from . import ecc
        from .crypto import sha256, hmac_oneshot
        from .util import bh2u, profiler, xor_bytes, bfh
       -from .lnutil import get_ecdh
       +from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH
        from .lnrouter import RouteEdge
        
        
       -NUM_MAX_HOPS_IN_PATH = 20
        HOPS_DATA_SIZE = 1300      # also sometimes called routingInfoSize in bolt-04
        PER_HOP_FULL_SIZE = 65     # HOPS_DATA_SIZE / 20
        NUM_STREAM_BYTES = HOPS_DATA_SIZE + PER_HOP_FULL_SIZE
       t@@ -192,6 +191,9 @@ def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_c
            """Returns the hops_data to be used for constructing an onion packet,
            and the amount_msat and cltv to be used on our immediate channel.
            """
       +    if len(route) > NUM_MAX_HOPS_IN_PAYMENT_PATH:
       +        raise PaymentFailure(f"too long route ({len(route)} hops)")
       +
            amt = amount_msat
            cltv = final_cltv
            hops_data = [OnionHopsDataSingle(OnionPerHop(b"\x00" * 8,
       t@@ -209,7 +211,7 @@ def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_c
        
        def generate_filler(key_type: bytes, num_hops: int, hop_size: int,
                            shared_secrets: Sequence[bytes]) -> bytes:
       -    filler_size = (NUM_MAX_HOPS_IN_PATH + 1) * hop_size
       +    filler_size = (NUM_MAX_HOPS_IN_PAYMENT_PATH + 1) * hop_size
            filler = bytearray(filler_size)
        
            for i in range(0, num_hops-1):  # -1, as last hop does not obfuscate
       t@@ -219,7 +221,7 @@ def generate_filler(key_type: bytes, num_hops: int, hop_size: int,
                stream_bytes = generate_cipher_stream(stream_key, filler_size)
                filler = xor_bytes(filler, stream_bytes)
        
       -    return filler[(NUM_MAX_HOPS_IN_PATH-num_hops+2)*hop_size:]
       +    return filler[(NUM_MAX_HOPS_IN_PAYMENT_PATH-num_hops+2)*hop_size:]
        
        
        def generate_cipher_stream(stream_key: bytes, num_bytes: int) -> bytes:
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -39,7 +39,7 @@ from .storage import JsonDB
        from .lnchannelverifier import LNChannelVerifier, verify_sig_for_channel_update
        from .crypto import Hash
        from . import ecc
       -from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr
       +from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_HOPS_IN_PAYMENT_PATH
        
        
        class UnknownEvenFeatureBits(Exception): pass
       t@@ -502,10 +502,61 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
                                                 ('cltv_expiry_delta', int)])):
            """if you travel through short_channel_id, you will reach node_id"""
        
       -    def fee_for_edge(self, amount_msat):
       +    def fee_for_edge(self, amount_msat: int) -> int:
                return self.fee_base_msat \
                       + (amount_msat * self.fee_proportional_millionths // 1_000_000)
        
       +    @classmethod
       +    def from_channel_policy(cls, channel_policy: ChannelInfoDirectedPolicy,
       +                            short_channel_id: bytes, end_node: bytes) -> 'RouteEdge':
       +        return RouteEdge(end_node,
       +                         short_channel_id,
       +                         channel_policy.fee_base_msat,
       +                         channel_policy.fee_proportional_millionths,
       +                         channel_policy.cltv_expiry_delta)
       +
       +    def is_sane_to_use(self, amount_msat: int) -> bool:
       +        # TODO revise ad-hoc heuristics
       +        # cltv cannot be more than 2 weeks
       +        if self.cltv_expiry_delta > 14 * 144: return False
       +        total_fee = self.fee_for_edge(amount_msat)
       +        # fees below 50 sat are fine
       +        if total_fee > 50_000:
       +            # fee cannot be higher than amt
       +            if total_fee > amount_msat: return False
       +            # fee cannot be higher than 5000 sat
       +            if total_fee > 5_000_000: return False
       +            # unless amt is tiny, fee cannot be more than 10%
       +            if amount_msat > 1_000_000 and total_fee > amount_msat/10: return False
       +        return True
       +
       +
       +def is_route_sane_to_use(route: List[RouteEdge], invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
       +    """Run some sanity checks on the whole route, before attempting to use it.
       +    called when we are paying; so e.g. lower cltv is better
       +    """
       +    if len(route) > NUM_MAX_HOPS_IN_PAYMENT_PATH:
       +        return False
       +    amt = invoice_amount_msat
       +    cltv = min_final_cltv_expiry
       +    for route_edge in reversed(route[1:]):
       +        if not route_edge.is_sane_to_use(amt): return False
       +        amt += route_edge.fee_for_edge(amt)
       +        cltv += route_edge.cltv_expiry_delta
       +    total_fee = amt - invoice_amount_msat
       +    # TODO revise ad-hoc heuristics
       +    # cltv cannot be more than 2 months
       +    if cltv > 60 * 144: return False
       +    # fees below 50 sat are fine
       +    if total_fee > 50_000:
       +        # fee cannot be higher than amt
       +        if total_fee > invoice_amount_msat: return False
       +        # fee cannot be higher than 5000 sat
       +        if total_fee > 5_000_000: return False
       +        # unless amt is tiny, fee cannot be more than 10%
       +        if invoice_amount_msat > 1_000_000 and total_fee > invoice_amount_msat/10: return False
       +    return True
       +
        
        class LNPathFinder(PrintError):
        
       t@@ -513,11 +564,9 @@ class LNPathFinder(PrintError):
                self.channel_db = channel_db
                self.blacklist = set()
        
       -    def _edge_cost(self, short_channel_id: bytes, start_node: bytes, payment_amt_msat: int,
       -                   ignore_cltv=False) -> float:
       -        """Heuristic cost of going through a channel.
       -        direction: 0 or 1. --- 0 means node_id_1 -> node_id_2
       -        """
       +    def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
       +                   payment_amt_msat: int, ignore_cltv=False) -> float:
       +        """Heuristic cost of going through a channel."""
                channel_info = self.channel_db.get_channel_info(short_channel_id)  # type: ChannelInfo
                if channel_info is None:
                    return float('inf')
       t@@ -525,41 +574,39 @@ class LNPathFinder(PrintError):
                channel_policy = channel_info.get_policy_for_node(start_node)
                if channel_policy is None: return float('inf')
                if channel_policy.disabled: return float('inf')
       -        cltv_expiry_delta           = channel_policy.cltv_expiry_delta
       -        htlc_minimum_msat           = channel_policy.htlc_minimum_msat
       -        fee_base_msat               = channel_policy.fee_base_msat
       -        fee_proportional_millionths = channel_policy.fee_proportional_millionths
       -        if payment_amt_msat is not None:
       -            if payment_amt_msat < htlc_minimum_msat:
       -                return float('inf')  # payment amount too little
       -            if channel_info.capacity_sat is not None and \
       -                    payment_amt_msat // 1000 > channel_info.capacity_sat:
       -                return float('inf')  # payment amount too large
       -            if channel_policy.htlc_maximum_msat is not None and \
       -                    payment_amt_msat > channel_policy.htlc_maximum_msat:
       -                return float('inf')  # payment amount too large
       -        amt = payment_amt_msat or 50000 * 1000  # guess for typical payment amount
       -        fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1_000_000
       +        route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
       +        if payment_amt_msat < channel_policy.htlc_minimum_msat:
       +            return float('inf')  # payment amount too little
       +        if channel_info.capacity_sat is not None and \
       +                payment_amt_msat // 1000 > channel_info.capacity_sat:
       +            return float('inf')  # payment amount too large
       +        if channel_policy.htlc_maximum_msat is not None and \
       +                payment_amt_msat > channel_policy.htlc_maximum_msat:
       +            return float('inf')  # payment amount too large
       +        if not route_edge.is_sane_to_use(payment_amt_msat):
       +            return float('inf')  # thanks but no thanks
       +        fee_msat = route_edge.fee_for_edge(payment_amt_msat)
                # TODO revise
                # paying 10 more satoshis ~ waiting one more block
                fee_cost = fee_msat / 1000 / 10
       -        cltv_cost = cltv_expiry_delta if not ignore_cltv else 0
       +        cltv_cost = route_edge.cltv_expiry_delta if not ignore_cltv else 0
                return cltv_cost + fee_cost + 1
        
            @profiler
            def find_path_for_payment(self, from_node_id: bytes, to_node_id: bytes,
       -                              amount_msat: int=None, my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
       +                              amount_msat: int, my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
                """Return a path between from_node_id and to_node_id.
        
                Returns a list of (node_id, short_channel_id) representing a path.
                To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
                i.e. an element reads as, "to get to node_id, travel through short_channel_id"
                """
       -        if amount_msat is not None: assert type(amount_msat) is int
       +        assert type(amount_msat) is int
                if my_channels is None: my_channels = []
                unable_channels = set(map(lambda x: x.short_channel_id, filter(lambda x: not x.can_pay(amount_msat), my_channels)))
        
                # TODO find multiple paths??
       +        # FIXME paths cannot be longer than 20 (onion packet)...
        
                # run Dijkstra
                distance_from_start = defaultdict(lambda: float('inf'))
       t@@ -584,7 +631,7 @@ class LNPathFinder(PrintError):
                        node1, node2 = channel_info.node_id_1, channel_info.node_id_2
                        neighbour = node2 if node1 == cur_node else node1
                        ignore_cltv_delta_in_edge_cost = cur_node == from_node_id
       -                edge_cost = self._edge_cost(edge_channel_id, cur_node, amount_msat,
       +                edge_cost = self._edge_cost(edge_channel_id, cur_node, neighbour, amount_msat,
                                                    ignore_cltv=ignore_cltv_delta_in_edge_cost)
                        alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost
                        if alt_dist_to_neighbour < distance_from_start[neighbour]:
       t@@ -614,10 +661,6 @@ class LNPathFinder(PrintError):
                    channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
                    if channel_policy is None:
                        raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
       -            route.append(RouteEdge(node_id,
       -                                   short_channel_id,
       -                                   channel_policy.fee_base_msat,
       -                                   channel_policy.fee_proportional_millionths,
       -                                   channel_policy.cltv_expiry_delta))
       +            route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
                    prev_node_id = node_id
                return route
   DIR diff --git a/electrum/lnutil.py b/electrum/lnutil.py
       t@@ -1,7 +1,7 @@
        from enum import IntFlag, IntEnum
        import json
        from collections import namedtuple
       -from typing import NamedTuple, List, Tuple, Mapping
       +from typing import NamedTuple, List, Tuple, Mapping, Optional
        import re
        
        from .util import bfh, bh2u, inv_dict
       t@@ -16,6 +16,7 @@ from .i18n import _
        from .lnaddr import lndecode
        from .keystore import BIP32_KeyStore
        
       +
        HTLC_TIMEOUT_WEIGHT = 663
        HTLC_SUCCESS_WEIGHT = 703
        
       t@@ -597,8 +598,6 @@ def generate_keypair(ln_keystore: BIP32_KeyStore, key_family: LnKeyFamily, index
            return Keypair(*ln_keystore.get_keypair([key_family, 0, index], None))
        
        
       -from typing import Optional
       -
        class EncumberedTransaction(NamedTuple("EncumberedTransaction", [('tx', Transaction),
                                                                         ('csv_delay', Optional[int])])):
            def to_json(self) -> dict:
       t@@ -612,3 +611,7 @@ class EncumberedTransaction(NamedTuple("EncumberedTransaction", [('tx', Transact
                d2 = dict(d)
                d2['tx'] = Transaction(d['tx'])
                return EncumberedTransaction(**d2)
       +
       +
       +NUM_MAX_HOPS_IN_PAYMENT_PATH = 20
       +
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -25,10 +25,11 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
                             get_compressed_pubkey_from_bech32, extract_nodeid,
                             PaymentFailure, split_host_port, ConnStringFormatError,
                             generate_keypair, LnKeyFamily, LOCAL, REMOTE,
       -                     UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE)
       +                     UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
       +                     NUM_MAX_HOPS_IN_PAYMENT_PATH)
        from .lnaddr import lndecode
        from .i18n import _
       -from .lnrouter import RouteEdge
       +from .lnrouter import RouteEdge, is_route_sane_to_use
        
        NUM_PEERS_TARGET = 4
        PEER_RETRY_INTERVAL = 600  # seconds
       t@@ -253,6 +254,10 @@ class LNWorker(PrintError):
                if amount_sat is None:
                    raise InvoiceError(_("Missing amount"))
                amount_msat = int(amount_sat * 1000)
       +        if addr.get_min_final_cltv_expiry() > 60 * 144:
       +            raise InvoiceError("{}\n{}".format(
       +                _("Invoice wants us to risk locking funds for unreasonably long."),
       +                f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
                route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
                node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
                peer = self.peers[node_id]
       t@@ -281,6 +286,7 @@ class LNWorker(PrintError):
                    channels = list(self.channels.values())
                for private_route in r_tags:
                    if len(private_route) == 0: continue
       +            if len(private_route) > NUM_MAX_HOPS_IN_PAYMENT_PATH: continue
                    border_node_pubkey = private_route[0][0]
                    path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
                    if not path: continue
       t@@ -301,6 +307,11 @@ class LNWorker(PrintError):
                        route.append(RouteEdge(node_pubkey, short_channel_id, fee_base_msat, fee_proportional_millionths,
                                               cltv_expiry_delta))
                        prev_node_id = node_pubkey
       +            # test sanity
       +            if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
       +                self.print_error(f"rejecting insane route {route}")
       +                route = None
       +                continue
                    break
                # if could not find route using any hint; try without hint now
                if route is None:
       t@@ -308,6 +319,9 @@ class LNWorker(PrintError):
                    if not path:
                        raise PaymentFailure(_("No path found"))
                    route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
       +            if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
       +                self.print_error(f"rejecting insane route {route}")
       +                raise PaymentFailure(_("No path found"))
                return route
        
            def add_invoice(self, amount_sat, message):