tlnrouter.py - electrum - Electrum Bitcoin wallet
HTML git clone https://git.parazyd.org/electrum
DIR Log
DIR Files
DIR Refs
DIR Submodules
---
tlnrouter.py (18008B)
---
1 # -*- coding: utf-8 -*-
2 #
3 # Electrum - lightweight Bitcoin client
4 # Copyright (C) 2018 The Electrum developers
5 #
6 # Permission is hereby granted, free of charge, to any person
7 # obtaining a copy of this software and associated documentation files
8 # (the "Software"), to deal in the Software without restriction,
9 # including without limitation the rights to use, copy, modify, merge,
10 # publish, distribute, sublicense, and/or sell copies of the Software,
11 # and to permit persons to whom the Software is furnished to do so,
12 # subject to the following conditions:
13 #
14 # The above copyright notice and this permission notice shall be
15 # included in all copies or substantial portions of the Software.
16 #
17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 # SOFTWARE.
25
26 import queue
27 from collections import defaultdict
28 from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
29 import time
30 import attr
31
32 from .util import bh2u, profiler
33 from .logging import Logger
34 from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures,
35 NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE)
36 from .channel_db import ChannelDB, Policy, NodeInfo
37
38 if TYPE_CHECKING:
39 from .lnchannel import Channel
40
41
class NoChannelPolicy(Exception):
    """Raised when no forwarding policy can be found for a channel.

    The given short_channel_id is normalized via ShortChannelID.normalize
    before being rendered into the error message.
    """

    def __init__(self, short_channel_id: bytes):
        short_channel_id = ShortChannelID.normalize(short_channel_id)
        message = f'cannot find channel policy for short_channel_id: {short_channel_id}'
        super().__init__(message)
46
47
class LNPathInconsistent(Exception):
    """Raised when a payment path does not agree with the channel graph.

    E.g. an edge's endpoints differ from those known for its
    short_channel_id, or consecutive edges do not chain together.
    """
49
50
def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
    """Return the fee (msat) a hop charges for forwarding ``forwarded_amount_msat``.

    fee = fee_base_msat + floor(forwarded_amount_msat * fee_proportional_millionths / 1e6),
    computed entirely in integer arithmetic.
    """
    proportional_part = forwarded_amount_msat * fee_proportional_millionths // 1_000_000
    return fee_base_msat + proportional_part
54
55
@attr.s(slots=True)
class PathEdge:
    """One directed hop of a payment path: ``start_node`` -> ``end_node``
    through the channel ``short_channel_id``. All fields are keyword-only.
    """
    # node ids are rendered as hex in repr()
    start_node = attr.ib(type=bytes, kw_only=True, repr=lambda node_id: node_id.hex())
    end_node = attr.ib(type=bytes, kw_only=True, repr=lambda node_id: node_id.hex())
    short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=str)

    @property
    def node_id(self) -> bytes:
        """Legacy alias for ``end_node`` (TODO: remove once callers migrate)."""
        return self.end_node
66
@attr.s
class RouteEdge(PathEdge):
    """A PathEdge enriched with the forwarding policy needed to pay through it.

    The fee and CLTV fields describe the policy of ``start_node`` for this
    channel; ``node_features`` holds the feature bits of ``end_node``.
    """
    fee_base_msat = attr.ib(type=int, kw_only=True)
    fee_proportional_millionths = attr.ib(type=int, kw_only=True)
    cltv_expiry_delta = attr.ib(type=int, kw_only=True)
    node_features = attr.ib(type=int, kw_only=True, repr=lambda val: str(int(val)))  # note: for end node!

    def fee_for_edge(self, amount_msat: int) -> int:
        """Return the fee (msat) this hop charges for forwarding ``amount_msat``."""
        return fee_for_edge_msat(forwarded_amount_msat=amount_msat,
                                 fee_base_msat=self.fee_base_msat,
                                 fee_proportional_millionths=self.fee_proportional_millionths)

    @classmethod
    def from_channel_policy(
            cls,
            *,
            channel_policy: 'Policy',
            short_channel_id: bytes,
            start_node: bytes,
            end_node: bytes,
            node_info: Optional[NodeInfo],  # for end_node
    ) -> 'RouteEdge':
        """Build an edge from a channel policy.

        :param channel_policy: forwarding policy of ``start_node`` for this channel.
        :param short_channel_id: channel id; normalized via ShortChannelID.normalize.
        :param node_info: info for ``end_node``; if None the edge gets feature bits 0.
        :returns: an instance of ``cls`` (so subclasses construct their own type).
        """
        assert isinstance(short_channel_id, bytes)
        assert type(start_node) is bytes
        assert type(end_node) is bytes
        # use cls, not a hard-coded RouteEdge, so this alternate constructor
        # is inheritance-friendly
        return cls(
            start_node=start_node,
            end_node=end_node,
            short_channel_id=ShortChannelID.normalize(short_channel_id),
            fee_base_msat=channel_policy.fee_base_msat,
            fee_proportional_millionths=channel_policy.fee_proportional_millionths,
            cltv_expiry_delta=channel_policy.cltv_expiry_delta,
            node_features=node_info.features if node_info else 0)

    def is_sane_to_use(self, amount_msat: int) -> bool:
        """Ad-hoc heuristic check of this single hop for forwarding ``amount_msat``.

        Rejects a CLTV delta above two weeks of blocks (14 * 144) and a fee
        that ``is_fee_sane`` considers unreasonable.
        """
        # TODO revise ad-hoc heuristics
        # cltv cannot be more than 2 weeks
        if self.cltv_expiry_delta > 14 * 144:
            return False
        total_fee = self.fee_for_edge(amount_msat)
        if not is_fee_sane(total_fee, payment_amount_msat=amount_msat):
            return False
        return True

    def has_feature_varonion(self) -> bool:
        """Whether ``end_node`` advertises LnFeatures.VAR_ONION_OPT support."""
        features = LnFeatures(self.node_features)
        return features.supports(LnFeatures.VAR_ONION_OPT)

    def is_trampoline(self) -> bool:
        """Plain route edges are not trampoline edges."""
        return False
117
@attr.s
class TrampolineEdge(RouteEdge):
    """RouteEdge variant for which ``is_trampoline()`` is True.

    Optionally carries invoice routing info and invoice feature bits.
    """
    invoice_routing_info = attr.ib(type=bytes, default=None)
    invoice_features = attr.ib(type=int, default=None)
    # re-declared from the parent only so that it has a default value
    short_channel_id = attr.ib(default=ShortChannelID(8), repr=str)

    def is_trampoline(self) -> bool:
        """Always True for this edge type."""
        return True
127
128
# A route is a sequence of hops that carry their forwarding policy (fees, CLTV);
# a path is the same hop sequence without policy information.
LNPaymentRoute = Sequence[RouteEdge]
LNPaymentPath = Sequence[PathEdge]
131
132
def is_route_sane_to_use(route: LNPaymentRoute, invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
    """Run some sanity checks on the whole route, before attempting to use it.

    Called when we are paying, so e.g. a lower cltv is better. The route is
    rejected if it has more than NUM_MAX_EDGES_IN_PAYMENT_PATH edges, if any
    edge after the first fails ``RouteEdge.is_sane_to_use``, if the
    accumulated CLTV exceeds NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE, or if
    the total fee is not sane according to ``is_fee_sane``.
    """
    if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
        return False
    amount_msat = invoice_amount_msat
    cltv_total = min_final_cltv_expiry
    # Walk backwards from the destination, compounding fees and CLTV deltas.
    # The first edge is skipped: it starts at the payer, so no fee or CLTV
    # delta is added for it.
    for edge in reversed(route[1:]):
        if not edge.is_sane_to_use(amount_msat):
            return False
        amount_msat += edge.fee_for_edge(amount_msat)
        cltv_total += edge.cltv_expiry_delta
    # TODO revise ad-hoc heuristics
    if cltv_total > NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:
        return False
    total_fee_msat = amount_msat - invoice_amount_msat
    return is_fee_sane(total_fee_msat, payment_amount_msat=invoice_amount_msat)
152
153
def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
    """Ad-hoc check that a fee is acceptable.

    A fee is acceptable if it is at most 5 sat (5000 msat), or at most 1% of
    ``payment_amount_msat``.
    """
    small_in_absolute_terms = fee_msat <= 5_000
    small_relative_to_payment = 100 * fee_msat <= payment_amount_msat
    return small_in_absolute_terms or small_relative_to_payment
162
163
164
class LNPathFinder(Logger):
    """Find payment paths and routes through the Lightning channel graph.

    Channel, policy and node information is read from ``channel_db``. The
    caller's own channels (``my_channels``) and private route hints
    (``private_route_edges``) are forwarded to the channel_db lookups.
    """

    def __init__(self, channel_db: ChannelDB):
        Logger.__init__(self)
        self.channel_db = channel_db

    def _edge_cost(
            self,
            *,
            short_channel_id: bytes,
            start_node: bytes,
            end_node: bytes,
            payment_amt_msat: int,
            ignore_costs=False,
            is_mine=False,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Tuple[float, int]:
        """Heuristic cost (distance metric) of going through a channel.

        The channel is traversed from ``start_node`` to ``end_node`` while
        forwarding ``payment_amt_msat``.

        Returns (heuristic_cost, fee_for_edge_msat). The cost is
        ``float('inf')`` (with fee 0) when the channel is unusable: it is
        unknown, has no policy for ``start_node``, lacks the reverse policy
        (unless it is ours or a private hint), is disabled, cannot carry the
        amount, or fails ``RouteEdge.is_sane_to_use``. With ``ignore_costs``
        set, a usable edge costs only the base cost and charges no fee.
        """
        if private_route_edges is None:
            private_route_edges = {}
        channel_info = self.channel_db.get_channel_info(
            short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
        if channel_info is None:
            return float('inf'), 0
        channel_policy = self.channel_db.get_policy_for_node(
            short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
        if channel_policy is None:
            return float('inf'), 0
        # channels that did not publish both policies often return temporary channel failure
        channel_policy_backwards = self.channel_db.get_policy_for_node(
            short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
        if (channel_policy_backwards is None
                and not is_mine
                and short_channel_id not in private_route_edges):
            return float('inf'), 0
        if channel_policy.is_disabled():
            return float('inf'), 0
        if payment_amt_msat < channel_policy.htlc_minimum_msat:
            return float('inf'), 0  # payment amount too little
        # Compare in msat: flooring the amount to sat (amount // 1000) would
        # let payments up to 999 msat above the capacity through.
        if channel_info.capacity_sat is not None and \
                payment_amt_msat > channel_info.capacity_sat * 1000:
            return float('inf'), 0  # payment amount too large
        if channel_policy.htlc_maximum_msat is not None and \
                payment_amt_msat > channel_policy.htlc_maximum_msat:
            return float('inf'), 0  # payment amount too large
        route_edge = private_route_edges.get(short_channel_id, None)
        if route_edge is None:
            node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
            route_edge = RouteEdge.from_channel_policy(
                channel_policy=channel_policy,
                short_channel_id=short_channel_id,
                start_node=start_node,
                end_node=end_node,
                node_info=node_info)
        if not route_edge.is_sane_to_use(payment_amt_msat):
            return float('inf'), 0  # thanks but no thanks

        # Distance metric notes:  # TODO constants are ad-hoc
        # ( somewhat based on https://github.com/lightningnetwork/lnd/pull/1358 )
        # - Edges have a base cost. (more edges -> less likely none will fail)
        # - The larger the payment amount, and the longer the CLTV,
        #   the more irritating it is if the HTLC gets stuck.
        # - Paying lower fees is better. :)
        base_cost = 500  # one more edge ~ paying 500 msat more fees
        if ignore_costs:
            return base_cost, 0
        fee_msat = route_edge.fee_for_edge(payment_amt_msat)
        cltv_cost = route_edge.cltv_expiry_delta * payment_amt_msat * 15 / 1_000_000_000
        overall_cost = base_cost + fee_msat + cltv_cost
        return overall_cost, fee_msat

    def get_distances(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Dict[bytes, PathEdge]:
        """Run Dijkstra from nodeB back towards nodeA.

        The search runs in the reverse direction so that the amount each hop
        must forward (invoice amount plus downstream fees) is known when its
        cost is evaluated. Channels in ``blacklist`` are skipped.

        Returns a dict mapping each node reached by the search to the
        PathEdge that leads from it one hop closer to nodeB. ``nodeA`` is
        absent from the result if no path was found.
        """
        # note: we don't lock self.channel_db, so while the path finding runs,
        # the underlying graph could potentially change... (not good but maybe ~OK?)
        if my_channels is None:
            # the membership test below (`in my_channels`) needs a container
            my_channels = {}

        # run Dijkstra
        # The search is run in the REVERSE direction, from nodeB to nodeA,
        # to properly calculate compound routing fees.
        distance_from_start = defaultdict(lambda: float('inf'))
        distance_from_start[nodeB] = 0
        prev_node = {}  # type: Dict[bytes, PathEdge]
        nodes_to_explore = queue.PriorityQueue()
        nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!

        # main loop of search
        while nodes_to_explore.qsize() > 0:
            dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
            if edge_endnode == nodeA:
                break
            if dist_to_edge_endnode != distance_from_start[edge_endnode]:
                # queue.PriorityQueue does not implement decrease_priority,
                # so instead of decreasing priorities, we add items again into the queue.
                # so there are duplicates in the queue, that we discard now:
                continue
            for edge_channel_id in self.channel_db.get_channels_for_node(
                    edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
                assert isinstance(edge_channel_id, bytes)
                if blacklist and edge_channel_id in blacklist:
                    continue
                channel_info = self.channel_db.get_channel_info(
                    edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
                if channel_info is None:
                    continue
                edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
                is_mine = edge_channel_id in my_channels
                if is_mine:
                    if edge_startnode == nodeA:  # payment outgoing, on our channel
                        if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True):
                            continue
                    else:  # payment incoming, on our channel. (funny business, cycle weirdness)
                        assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
                        if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
                            continue
                edge_cost, fee_for_edge_msat = self._edge_cost(
                    short_channel_id=edge_channel_id,
                    start_node=edge_startnode,
                    end_node=edge_endnode,
                    payment_amt_msat=amount_msat,
                    ignore_costs=(edge_startnode == nodeA),
                    is_mine=is_mine,
                    my_channels=my_channels,
                    private_route_edges=private_route_edges)
                alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
                if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                    distance_from_start[edge_startnode] = alt_dist_to_neighbour
                    prev_node[edge_startnode] = PathEdge(
                        start_node=edge_startnode,
                        end_node=edge_endnode,
                        short_channel_id=ShortChannelID(edge_channel_id))
                    # the upstream node must forward our amount plus this hop's fee
                    amount_to_forward_msat = amount_msat + fee_for_edge_msat
                    nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))

        return prev_node

    @profiler
    def find_path_for_payment(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Optional[LNPaymentPath]:
        """Return a path from nodeA to nodeB, or None if no path was found.

        The path is a list of PathEdge ordered from nodeA towards nodeB.
        """
        assert type(nodeA) is bytes
        assert type(nodeB) is bytes
        assert type(invoice_amount_msat) is int
        if my_channels is None:
            my_channels = {}

        prev_node = self.get_distances(
            nodeA=nodeA,
            nodeB=nodeB,
            invoice_amount_msat=invoice_amount_msat,
            my_channels=my_channels,
            blacklist=blacklist,
            private_route_edges=private_route_edges)

        if nodeA not in prev_node:
            return None  # no path found

        # backtrack from search_end (nodeA) to search_start (nodeB)
        # FIXME paths cannot be longer than 20 edges (onion packet)...
        edge_startnode = nodeA
        path = []
        while edge_startnode != nodeB:
            edge = prev_node[edge_startnode]
            path.append(edge)
            edge_startnode = edge.end_node
        return path

    def create_route_from_path(
            self,
            path: Optional[LNPaymentPath],
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> LNPaymentRoute:
        """Turn a path into a route by attaching each edge's forwarding policy.

        Edges found in ``private_route_edges`` are used as-is; otherwise the
        policy of the edge's start node is looked up in channel_db.

        :raises Exception: if ``path`` is None.
        :raises LNPathInconsistent: if an edge's endpoints disagree with its
            short_channel_id, or consecutive edges do not chain together.
        :raises NoChannelPolicy: if no policy is found for an edge.
        """
        if path is None:
            raise Exception('cannot create route from None path')
        if private_route_edges is None:
            private_route_edges = {}
        route = []
        prev_end_node = path[0].start_node
        for path_edge in path:
            short_channel_id = path_edge.short_channel_id
            _endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
            if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
                raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
            if path_edge.start_node != prev_end_node:
                raise LNPathInconsistent("edges do not chain together")
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge is None:
                channel_policy = self.channel_db.get_policy_for_node(
                    short_channel_id=short_channel_id,
                    node_id=path_edge.start_node,
                    my_channels=my_channels)
                if channel_policy is None:
                    raise NoChannelPolicy(short_channel_id)
                node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
                route_edge = RouteEdge.from_channel_policy(
                    channel_policy=channel_policy,
                    short_channel_id=short_channel_id,
                    start_node=path_edge.start_node,
                    end_node=path_edge.end_node,
                    node_info=node_info)
            route.append(route_edge)
            prev_end_node = path_edge.end_node
        return route

    def find_route(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            path = None,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Optional[LNPaymentRoute]:
        """Return a route from nodeA to nodeB, or None if no path was found.

        If ``path`` is falsy, a path is searched for first with
        ``find_path_for_payment``; a given path is converted directly.
        """
        route = None
        if not path:
            path = self.find_path_for_payment(
                nodeA=nodeA,
                nodeB=nodeB,
                invoice_amount_msat=invoice_amount_msat,
                my_channels=my_channels,
                blacklist=blacklist,
                private_route_edges=private_route_edges)
        if path:
            route = self.create_route_from_path(
                path, my_channels=my_channels, private_route_edges=private_route_edges)
        return route