channel_db.py - electrum - Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
---
channel_db.py (37455B)
---
# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import time
import random
import os
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64
import asyncio
import threading
from enum import IntEnum

from aiorpcx import NetAddress

from .sql_db import SqlDB, sql
from . import constants, util
from .util import bh2u, profiler, get_headers_dir, is_ip_address, json_normalize
from .logging import Logger
from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
                     validate_features, IncompatibleOrInsaneFeatures)
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg

if TYPE_CHECKING:
    from .network import Network
    from .lnchannel import Channel
    from .lnrouter import RouteEdge


FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0

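# Illustrative sketch (not part of the upstream module): how these two bits are read
# out of a channel_update's one-byte channel_flags field, as ChannelDB.add_channel_update
# does further below. `payload` is assumed to be a dict produced by lnmsg.decode_msg.
#
#   flags = int.from_bytes(payload['channel_flags'], 'big')
#   direction = flags & FLAG_DIRECTION  # 0 -> update published by node1, 1 -> by node2
#   disabled = bool(flags & FLAG_DISABLE)  # channel flagged as temporarily unusable
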
class ChannelInfo(NamedTuple):
    short_channel_id: ShortChannelID
    node1_id: bytes
    node2_id: bytes
    capacity_sat: Optional[int]

    @staticmethod
    def from_msg(payload: dict) -> 'ChannelInfo':
        features = int.from_bytes(payload['features'], 'big')
        validate_features(features)
        channel_id = payload['short_channel_id']
        node_id_1 = payload['node_id_1']
        node_id_2 = payload['node_id_2']
        assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
        capacity_sat = None
        return ChannelInfo(
            short_channel_id = ShortChannelID.normalize(channel_id),
            node1_id = node_id_1,
            node2_id = node_id_2,
            capacity_sat = capacity_sat
        )

    @staticmethod
    def from_raw_msg(raw: bytes) -> 'ChannelInfo':
        payload_dict = decode_msg(raw)[1]
        return ChannelInfo.from_msg(payload_dict)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
        node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node])
        return ChannelInfo(
            short_channel_id=route_edge.short_channel_id,
            node1_id=node1_id,
            node2_id=node2_id,
            capacity_sat=None,
        )

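# Usage sketch (illustrative; `raw` is assumed to hold a BOLT-7 channel_announcement,
# as stored in the channel_info table below): from_raw_msg decodes the message and
# keeps only the fields the router needs.
#
#   ci = ChannelInfo.from_raw_msg(raw)
#   ci.short_channel_id        # 8 bytes: block height, tx position, output index
#   ci.node1_id, ci.node2_id   # 33-byte node pubkeys, sorted as required by BOLT-7
#   ci.capacity_sat            # None here; may be filled in later, e.g. for own channels
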
class Policy(NamedTuple):
    key: bytes
    cltv_expiry_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: Optional[int]
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int
    message_flags: int
    timestamp: int

    @staticmethod
    def from_msg(payload: dict) -> 'Policy':
        return Policy(
            key = payload['short_channel_id'] + payload['start_node'],
            cltv_expiry_delta = payload['cltv_expiry_delta'],
            htlc_minimum_msat = payload['htlc_minimum_msat'],
            htlc_maximum_msat = payload.get('htlc_maximum_msat', None),
            fee_base_msat = payload['fee_base_msat'],
            fee_proportional_millionths = payload['fee_proportional_millionths'],
            message_flags = int.from_bytes(payload['message_flags'], "big"),
            channel_flags = int.from_bytes(payload['channel_flags'], "big"),
            timestamp = payload['timestamp'],
        )

    @staticmethod
    def from_raw_msg(key: bytes, raw: bytes) -> 'Policy':
        payload = decode_msg(raw)[1]
        payload['start_node'] = key[8:]
        return Policy.from_msg(payload)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
        return Policy(
            key=route_edge.short_channel_id + route_edge.start_node,
            cltv_expiry_delta=route_edge.cltv_expiry_delta,
            htlc_minimum_msat=0,
            htlc_maximum_msat=None,
            fee_base_msat=route_edge.fee_base_msat,
            fee_proportional_millionths=route_edge.fee_proportional_millionths,
            channel_flags=0,
            message_flags=0,
            timestamp=0,
        )

    def is_disabled(self):
        return self.channel_flags & FLAG_DISABLE

    @property
    def short_channel_id(self) -> ShortChannelID:
        return ShortChannelID.normalize(self.key[0:8])

    @property
    def start_node(self) -> bytes:
        return self.key[8:]

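# Key layout sketch (illustrative): a Policy's `key` is the 41-byte value that is also
# the primary key of the `policy` table below, built from the fields of the
# channel_update it came from:
#
#   key = short_channel_id (8 bytes) || start_node pubkey (33 bytes)
#   Policy.short_channel_id  -> ShortChannelID.normalize(key[0:8])
#   Policy.start_node        -> key[8:]
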
class NodeInfo(NamedTuple):
    node_id: bytes
    features: int
    timestamp: int
    alias: str

    @staticmethod
    def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        peer_addrs = []
        for host, port in addresses:
            try:
                peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id))
            except ValueError:
                pass
        alias = payload['alias'].rstrip(b'\x00')
        try:
            alias = alias.decode('utf8')
        except UnicodeDecodeError:
            alias = ''
        timestamp = payload['timestamp']
        node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
        return node_info, peer_addrs

    @staticmethod
    def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        payload_dict = decode_msg(raw)[1]
        return NodeInfo.from_msg(payload_dict)

    @staticmethod
    def parse_addresses_field(addresses_field):
        buf = addresses_field
        def read(n):
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data
        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                pass
            elif atype == 1:  # IPv4
                ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv4_addr) and port != 0:
                    addresses.append((ipv4_addr, port))
            elif atype == 2:  # IPv6
                ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
                ipv6_addr = ipv6_addr.decode('ascii')
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv6_addr) and port != 0:
                    addresses.append((ipv6_addr, port))
            elif atype == 3:  # onion v2
                host = base64.b32encode(read(10)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            elif atype == 4:  # onion v3
                host = base64.b32encode(read(35)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            else:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
        return addresses

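# Wire-format sketch (illustrative) of the `addresses` field parsed above, per BOLT-7:
# a concatenation of typed entries, each a 1-byte descriptor type, type-specific data,
# and a 2-byte big-endian port. For example, a single IPv4 entry:
#
#   addresses_field = bytes([1]) + bytes([127, 0, 0, 1]) + (9735).to_bytes(2, 'big')
#   NodeInfo.parse_addresses_field(addresses_field)  # -> [('127.0.0.1', 9735)]
#
# An unknown descriptor type aborts parsing, since its length is not known.
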
class UpdateStatus(IntEnum):
    ORPHANED = 0
    EXPIRED = 1
    DEPRECATED = 2
    UNCHANGED = 3
    GOOD = 4

class CategorizedChannelUpdates(NamedTuple):
    orphaned: List    # no channel announcement for channel update
    expired: List     # update older than two weeks
    deprecated: List  # update older than database entry
    unchanged: List   # unchanged policies
    good: List        # good updates


def get_mychannel_info(short_channel_id: ShortChannelID,
                       my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
    chan = my_channels.get(short_channel_id)
    if not chan:
        return
    ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
    return ci._replace(capacity_sat=chan.constraints.capacity)

def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
                         my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[Policy]:
    chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
    if not chan:
        return
    if node_id == chan.node_id:  # incoming direction (to us)
        remote_update_raw = chan.get_remote_update()
        if not remote_update_raw:
            return
        now = int(time.time())
        remote_update_decoded = decode_msg(remote_update_raw)[1]
        remote_update_decoded['timestamp'] = now
        remote_update_decoded['start_node'] = node_id
        return Policy.from_msg(remote_update_decoded)
    elif node_id == chan.get_local_pubkey():  # outgoing direction (from us)
        local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
        local_update_decoded['start_node'] = node_id
        return Policy.from_msg(local_update_decoded)

create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id BLOB(8),
msg BLOB,
PRIMARY KEY(short_channel_id)
)"""

create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key BLOB(41),
msg BLOB,
PRIMARY KEY(key)
)"""

create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id BLOB(33),
host STRING(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""

create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id BLOB(33),
msg BLOB,
PRIMARY KEY(node_id)
)"""

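# The tables above form a small gossip cache keyed by raw message blobs. A quick way to
# inspect an existing database (illustrative; the 'gossip_db' file lives under the
# headers directory, see ChannelDB.__init__ below, and the path here is only an example):
#
#   import sqlite3
#   conn = sqlite3.connect('/path/to/wallet_dir/gossip_db')
#   conn.execute("SELECT COUNT(*) FROM channel_info").fetchone()
#   conn.execute("SELECT COUNT(*) FROM policy").fetchone()
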
class ChannelDB(SqlDB):

    NUM_MAX_RECENT_PEERS = 20

    def __init__(self, network: 'Network'):
        path = os.path.join(get_headers_dir(network.config), 'gossip_db')
        super().__init__(network.asyncio_loop, path, commit_interval=100)
        self.lock = threading.RLock()
        self.num_nodes = 0
        self.num_channels = 0
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)

        # initialized in load_data
        # note: modify/iterate needs self.lock
        self._channels = {}  # type: Dict[ShortChannelID, ChannelInfo]
        self._policies = {}  # type: Dict[Tuple[bytes, ShortChannelID], Policy]  # (node_id, scid) -> Policy
        self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
        # node_id -> NetAddress -> timestamp
        self._addresses = defaultdict(dict)  # type: Dict[bytes, Dict[NetAddress, int]]
        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
        self._recent_peers = []  # type: List[bytes]  # list of node_ids
        self._chans_with_0_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_1_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_2_policies = set()  # type: Set[ShortChannelID]

        self.data_loaded = asyncio.Event()
        self.network = network  # only for callback

    def update_counts(self):
        self.num_nodes = len(self._nodes)
        self.num_channels = len(self._channels)
        self.num_policies = len(self._policies)
        util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
        util.trigger_callback('ln_gossip_sync_progress')

    def get_channel_ids(self):
        with self.lock:
            return set(self._channels.keys())

    def add_recent_peer(self, peer: LNPeerAddr):
        now = int(time.time())
        node_id = peer.pubkey
        with self.lock:
            self._addresses[node_id][peer.net_addr()] = now
            # list is ordered
            if node_id in self._recent_peers:
                self._recent_peers.remove(node_id)
            self._recent_peers.insert(0, node_id)
            self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
        self._db_save_node_address(peer, now)

    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
        with self.lock:
            unshuffled = set(self._nodes.keys()) - node_ids
        # random.sample needs a sequence (sampling from a set is deprecated since Python 3.9)
        return random.sample(list(unshuffled), min(200, len(unshuffled)))

    def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
        """Returns latest address we successfully connected to, for given node."""
        addr_to_ts = self._addresses.get(node_id)
        if not addr_to_ts:
            return None
        addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
        try:
            return LNPeerAddr(str(addr.host), addr.port, node_id)
        except ValueError:
            return None

    def get_recent_peers(self):
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        with self.lock:
            ret = [self.get_last_good_address(node_id)
                   for node_id in self._recent_peers]
        return ret

    # note: currently channel announcements are trusted by default (trusted=True);
    #       they are not SPV-verified. Verifying them would make the gossip sync
    #       even slower; especially as servers will start throttling us.
    #       It would probably put significant strain on servers if all clients
    #       verified the complete gossip.
    def add_channel_announcement(self, msg_payloads, *, trusted=True):
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        added = 0
        for msg in msg_payloads:
            short_channel_id = ShortChannelID(msg['short_channel_id'])
            if short_channel_id in self._channels:
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
                continue
            try:
                channel_info = ChannelInfo.from_msg(msg)
            except IncompatibleOrInsaneFeatures as e:
                self.logger.info(f"unknown or insane feature bits: {e!r}")
                continue
            if trusted:
                added += 1
                self.add_verified_channel_info(msg)
            else:
                added += self.ca_verifier.add_new_channel_info(short_channel_id, msg)

        self.update_counts()
        self.logger.debug('add_channel_announcement: %d/%d' % (added, len(msg_payloads)))

    def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None:
        try:
            channel_info = ChannelInfo.from_msg(msg)
        except IncompatibleOrInsaneFeatures:
            return
        channel_info = channel_info._replace(capacity_sat=capacity_sat)
        with self.lock:
            self._channels[channel_info.short_channel_id] = channel_info
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
        self._update_num_policies_for_chan(channel_info.short_channel_id)
        if 'raw' in msg:
            self._db_save_channel(channel_info.short_channel_id, msg['raw'])

    def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
        changed = False
        if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
            changed |= True
            if verbose:
                self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
        if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
            changed |= True
            if verbose:
                self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
        if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
            changed |= True
            if verbose:
                self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
        if old_policy.fee_base_msat != new_policy.fee_base_msat:
            changed |= True
            if verbose:
                self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
        if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
            changed |= True
            if verbose:
                self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
        if old_policy.channel_flags != new_policy.channel_flags:
            changed |= True
            if verbose:
                self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
        if old_policy.message_flags != new_policy.message_flags:
            changed |= True
            if verbose:
                self.logger.info(f'message_flags: {old_policy.message_flags} -> {new_policy.message_flags}')
        if not changed and verbose:
            self.logger.info(f'policy unchanged: {old_policy.timestamp} -> {new_policy.timestamp}')
        return changed

    def add_channel_update(self, payload, max_age=None, verify=False, verbose=True):
        now = int(time.time())
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        timestamp = payload['timestamp']
        if max_age and now - timestamp > max_age:
            return UpdateStatus.EXPIRED
        if timestamp - now > 60:
            return UpdateStatus.DEPRECATED
        channel_info = self._channels.get(short_channel_id)
        if not channel_info:
            return UpdateStatus.ORPHANED
        flags = int.from_bytes(payload['channel_flags'], 'big')
        direction = flags & FLAG_DIRECTION
        start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
        payload['start_node'] = start_node
        # compare updates to existing database entries
        timestamp = payload['timestamp']
        start_node = payload['start_node']
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        key = (start_node, short_channel_id)
        old_policy = self._policies.get(key)
        if old_policy and timestamp <= old_policy.timestamp + 60:
            return UpdateStatus.DEPRECATED
        if verify:
            self.verify_channel_update(payload)
        policy = Policy.from_msg(payload)
        with self.lock:
            self._policies[key] = policy
        self._update_num_policies_for_chan(short_channel_id)
        if 'raw' in payload:
            self._db_save_policy(policy.key, payload['raw'])
        if old_policy and not self.policy_changed(old_policy, policy, verbose):
            return UpdateStatus.UNCHANGED
        else:
            return UpdateStatus.GOOD

    def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
        orphaned = []
        expired = []
        deprecated = []
        unchanged = []
        good = []
        for payload in payloads:
            r = self.add_channel_update(payload, max_age=max_age, verbose=False)
            if r == UpdateStatus.ORPHANED:
                orphaned.append(payload)
            elif r == UpdateStatus.EXPIRED:
                expired.append(payload)
            elif r == UpdateStatus.DEPRECATED:
                deprecated.append(payload)
            elif r == UpdateStatus.UNCHANGED:
                unchanged.append(payload)
            elif r == UpdateStatus.GOOD:
                good.append(payload)
        self.update_counts()
        return CategorizedChannelUpdates(
            orphaned=orphaned,
            expired=expired,
            deprecated=deprecated,
            unchanged=unchanged,
            good=good)

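    # Usage sketch (illustrative; `channel_db` and `payloads` are assumed names, the
    # latter a list of channel_update dicts from lnmsg.decode_msg): callers pass batches
    # to add_channel_updates and inspect the returned CategorizedChannelUpdates, e.g. to
    # re-request announcements for updates that arrived without a known channel.
    #
    #   categorized = channel_db.add_channel_updates(payloads, max_age=14 * 24 * 3600)
    #   if categorized.orphaned:
    #       ...  # e.g. ask the peer for the matching channel_announcements
    #   num_useful = len(categorized.good) + len(categorized.unchanged)
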
    def create_database(self):
        c = self.conn.cursor()
        c.execute(create_node_info)
        c.execute(create_address)
        c.execute(create_policy)
        c.execute(create_channel_info)
        self.conn.commit()

    @sql
    def _db_save_policy(self, key: bytes, msg: bytes):
        # 'msg' is a 'channel_update' message
        c = self.conn.cursor()
        c.execute("""REPLACE INTO policy (key, msg) VALUES (?,?)""", [key, msg])

    @sql
    def _db_delete_policy(self, node_id: bytes, short_channel_id: ShortChannelID):
        key = short_channel_id + node_id
        c = self.conn.cursor()
        c.execute("""DELETE FROM policy WHERE key=?""", (key,))

    @sql
    def _db_save_channel(self, short_channel_id: ShortChannelID, msg: bytes):
        # 'msg' is a 'channel_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO channel_info (short_channel_id, msg) VALUES (?,?)", [short_channel_id, msg])

    @sql
    def _db_delete_channel(self, short_channel_id: ShortChannelID):
        c = self.conn.cursor()
        c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))

    @sql
    def _db_save_node_info(self, node_id: bytes, msg: bytes):
        # 'msg' is a 'node_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg])

    @sql
    def _db_save_node_address(self, peer: LNPeerAddr, timestamp: int):
        c = self.conn.cursor()
        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)",
                  (peer.pubkey, peer.host, peer.port, timestamp))

    @sql
    def _db_save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]):
        c = self.conn.cursor()
        for addr in node_addresses:
            c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.pubkey, addr.host, addr.port))
            r = c.fetchall()
            if r == []:
                c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.pubkey, addr.host, addr.port, 0))

    def verify_channel_update(self, payload):
        short_channel_id = payload['short_channel_id']
        short_channel_id = ShortChannelID(short_channel_id)
        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
            raise Exception('wrong chain hash')
        if not verify_sig_for_channel_update(payload, payload['start_node']):
            raise Exception(f'failed verifying channel update for {short_channel_id}')

    def add_node_announcement(self, msg_payloads):
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        new_nodes = {}
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except IncompatibleOrInsaneFeatures:
                continue
            node_id = node_info.node_id
            # Ignore node if it has no associated channel (DoS protection)
            if node_id not in self._channels_for_node:
                #self.logger.info('ignoring orphan node_announcement')
                continue
            node = self._nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            node = new_nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            # save
            new_nodes[node_id] = node_info
            with self.lock:
                self._nodes[node_id] = node_info
            if 'raw' in msg_payload:
                self._db_save_node_info(node_id, msg_payload['raw'])
            with self.lock:
                for addr in node_addresses:
                    net_addr = NetAddress(addr.host, addr.port)
                    self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
            self._db_save_node_addresses(node_addresses)

        self.logger.debug("on_node_announcement: %d/%d" % (len(new_nodes), len(msg_payloads)))
        self.update_counts()

    def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]:
        with self.lock:
            _policies = self._policies.copy()
        now = int(time.time())
        return list(k for k, v in _policies.items() if v.timestamp <= now - delta)

    def prune_old_policies(self, delta):
        old_policies = self.get_old_policies(delta)
        if old_policies:
            for key in old_policies:
                node_id, scid = key
                with self.lock:
                    self._policies.pop(key)
                self._db_delete_policy(*key)
                self._update_num_policies_for_chan(scid)
            self.update_counts()
            self.logger.info(f'Deleting {len(old_policies)} old policies')

    def prune_orphaned_channels(self):
        with self.lock:
            orphaned_chans = self._chans_with_0_policies.copy()
        if orphaned_chans:
            for short_channel_id in orphaned_chans:
                self.remove_channel(short_channel_id)
            self.update_counts()
            self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels')

    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes) -> bool:
        """Returns True iff the channel update was successfully added and it was different than
        what we had before (if any).
        """
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return False  # ignore
        short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
        msg_payload['start_node'] = start_node_id
        key = (start_node_id, short_channel_id)
        prev_chanupd = self._channel_updates_for_private_channels.get(key)
        if prev_chanupd == msg_payload:
            return False
        self._channel_updates_for_private_channels[key] = msg_payload
        return True

    def remove_channel(self, short_channel_id: ShortChannelID):
        # FIXME what about rm-ing policies?
        with self.lock:
            channel_info = self._channels.pop(short_channel_id, None)
            if channel_info:
                self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
                self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
        self._update_num_policies_for_chan(short_channel_id)
        # delete from database
        self._db_delete_channel(short_channel_id)

    def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
        """Returns list of (host, port, timestamp)."""
        addr_to_ts = self._addresses.get(node_id)
        if not addr_to_ts:
            return []
        return [(str(net_addr.host), net_addr.port, ts)
                for net_addr, ts in addr_to_ts.items()]

    @sql
    @profiler
    def load_data(self):
        if self.data_loaded.is_set():
            return
        # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            node_id, host, port, timestamp = x
            try:
                net_addr = NetAddress(host, port)
            except Exception:
                continue
            self._addresses[node_id][net_addr] = int(timestamp or 0)
        def newest_ts_for_node_id(node_id):
            newest_ts = 0
            for addr, ts in self._addresses[node_id].items():
                newest_ts = max(newest_ts, ts)
            return newest_ts
        sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
        self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
        c.execute("""SELECT * FROM channel_info""")
        for short_channel_id, msg in c:
            try:
                ci = ChannelInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            self._channels[ShortChannelID.normalize(short_channel_id)] = ci
        c.execute("""SELECT * FROM node_info""")
        for node_id, msg in c:
            try:
                node_info, node_addresses = NodeInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            # don't load node_addresses because they don't have timestamps
            self._nodes[node_id] = node_info
        c.execute("""SELECT * FROM policy""")
        for key, msg in c:
            p = Policy.from_raw_msg(key, msg)
            self._policies[(p.start_node, p.short_channel_id)] = p
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
            self._update_num_policies_for_chan(channel_info.short_channel_id)
        self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
        self.update_counts()
        (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count()
        self.logger.info(f'num_channels_partitioned_by_policy_count. '
                         f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}')
        self.data_loaded.set()
        util.trigger_callback('gossip_db_loaded')

    def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None:
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is None:
            with self.lock:
                self._chans_with_0_policies.discard(short_channel_id)
                self._chans_with_1_policies.discard(short_channel_id)
                self._chans_with_2_policies.discard(short_channel_id)
            return
        p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id)
        p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id)
        with self.lock:
            self._chans_with_0_policies.discard(short_channel_id)
            self._chans_with_1_policies.discard(short_channel_id)
            self._chans_with_2_policies.discard(short_channel_id)
            if p1 is not None and p2 is not None:
                self._chans_with_2_policies.add(short_channel_id)
            elif p1 is None and p2 is None:
                self._chans_with_0_policies.add(short_channel_id)
            else:
                self._chans_with_1_policies.add(short_channel_id)

    def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]:
        nchans_with_0p = len(self._chans_with_0_policies)
        nchans_with_1p = len(self._chans_with_1_policies)
        nchans_with_2p = len(self._chans_with_2_policies)
        return nchans_with_0p, nchans_with_1p, nchans_with_2p

    def get_policy_for_node(
            self,
            short_channel_id: bytes,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional['Policy']:
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            policy = self._policies.get((node_id, short_channel_id))
            if policy:
                return policy
        else:  # private channel
            chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
            if chan_upd_dict:
                return Policy.from_msg(chan_upd_dict)
        # check if it's one of our own channels
        if my_channels:
            policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
            if policy:
                return policy
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge:
                return Policy.from_route_edge(route_edge)

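    # Lookup-order sketch (illustrative) for get_policy_for_node above: a publicly
    # announced channel is answered from self._policies; otherwise a cached private
    # channel_update, then our own channels, then caller-supplied private route edges
    # are consulted. For example (`channel_db`, `scid`, `node_id` and `my_scid_to_chan`
    # are assumed names built by the caller, not defined in this module):
    #
    #   policy = channel_db.get_policy_for_node(
    #       scid, node_id,
    #       my_channels=my_scid_to_chan,   # Dict[ShortChannelID, Channel]
    #       private_route_edges={},        # e.g. edges decoded from an invoice
    #   )
    #   if policy is None:
    #       ...  # no usable policy; the edge cannot be priced for routing
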
    def get_channel_info(
            self,
            short_channel_id: ShortChannelID,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional[ChannelInfo]:
        ret = self._channels.get(short_channel_id)
        if ret:
            return ret
        # check if it's one of our own channels
        if my_channels:
            channel_info = get_mychannel_info(short_channel_id, my_channels)
            if channel_info:
                return channel_info
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id)
            if route_edge:
                return ChannelInfo.from_route_edge(route_edge)

    def get_channels_for_node(
            self,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Set[bytes]:
        """Returns the set of short channel IDs where node_id is one of the channel participants."""
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        relevant_channels = self._channels_for_node.get(node_id) or set()
        relevant_channels = set(relevant_channels)  # copy
        # add our own channels  # TODO maybe slow?
        if my_channels:
            for chan in my_channels.values():
                if node_id in (chan.node_id, chan.get_local_pubkey()):
                    relevant_channels.add(chan.short_channel_id)
        # add private channels  # TODO maybe slow?
        if private_route_edges:
            for route_edge in private_route_edges.values():
                if node_id in (route_edge.start_node, route_edge.end_node):
                    relevant_channels.add(route_edge.short_channel_id)
        return relevant_channels

    def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,
                              my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]:
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            return channel_info.node1_id, channel_info.node2_id
        # check if it's one of our own channels
        if not my_channels:
            return
        chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
        if not chan:
            return
        return chan.get_local_pubkey(), chan.node_id

    def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
        return self._nodes.get(node_id)

    def get_node_infos(self) -> Dict[bytes, NodeInfo]:
        with self.lock:
            return self._nodes.copy()

    def get_node_policies(self) -> Dict[Tuple[bytes, ShortChannelID], Policy]:
        with self.lock:
            return self._policies.copy()

    def to_dict(self) -> dict:
        """ Generates a graph representation in terms of a dictionary.

        The dictionary contains only native python types and can be encoded
        to json.
        """
        with self.lock:
            graph = {'nodes': [], 'channels': []}

            # gather nodes
            for pk, nodeinfo in self._nodes.items():
                # use _asdict() to convert NamedTuples to json encodable dicts
                graph['nodes'].append(
                    nodeinfo._asdict(),
                )
                graph['nodes'][-1]['addresses'] = [
                    {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
                    for addr, ts in self._addresses[pk].items()
                ]

            # gather channels
            for cid, channelinfo in self._channels.items():
                graph['channels'].append(
                    channelinfo._asdict(),
                )
                policy1 = self._policies.get(
                    (channelinfo.node1_id, channelinfo.short_channel_id))
                policy2 = self._policies.get(
                    (channelinfo.node2_id, channelinfo.short_channel_id))
                graph['channels'][-1]['policy1'] = policy1._asdict() if policy1 else None
                graph['channels'][-1]['policy2'] = policy2._asdict() if policy2 else None

        # need to use json_normalize otherwise json encoding in rpc server fails
        graph = json_normalize(graph)
        return graph
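
# Export sketch (illustrative): to_dict() returns only plain Python types, so the whole
# graph can be serialized directly from code that holds a ChannelDB instance
# (`channel_db` is an assumed name):
#
#   import json
#   graph = channel_db.to_dict()
#   with open('ln_graph.json', 'w') as f:
#       json.dump(graph, f, indent=2)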