lnmsg.py - electrum - Electrum Bitcoin wallet
HTML git clone https://git.parazyd.org/electrum
DIR Log
DIR Files
DIR Refs
DIR Submodules
---
lnmsg.py (23442B)
---
1 import os
2 import csv
3 import io
4 from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
5 from collections import OrderedDict
6
7 from .lnutil import OnionFailureCodeMetaFlag
8
9
# Exception hierarchy for malformed / unparseable Lightning wire messages.
class MalformedMsg(Exception): pass
# a message scheme referenced a field type name we do not know how to (de)serialize
class UnknownMsgFieldType(MalformedMsg): pass
# the byte stream ended while more bytes were still expected
class UnexpectedEndOfStream(MalformedMsg): pass
# a variable-length or truncated integer was not minimally encoded
class FieldEncodingNotMinimal(MalformedMsg): pass
# a TLV stream contained an unknown record with an even ("mandatory") type
class UnknownMandatoryTLVRecordType(MalformedMsg): pass
# a TLV record contained bytes beyond its declared fields
class MsgTrailingGarbage(MalformedMsg): pass
# TLV record types were not monotonically increasing within a stream
class MsgInvalidFieldOrder(MalformedMsg): pass
# the encoder was given a value whose size does not match the scheme
class UnexpectedFieldSizeForEncoder(MalformedMsg): pass
18
19
20 def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
21 cur_pos = fd.tell()
22 end_pos = fd.seek(0, io.SEEK_END)
23 fd.seek(cur_pos)
24 return end_pos - cur_pos
25
26
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
    """Raise UnexpectedEndOfStream if fewer than n bytes remain in fd.

    Note: reading n bytes and then checking how many were actually read is
    faster than doing this pre-check followed by the read.
    """
    navailable = _num_remaining_bytes_to_read(fd)
    if navailable < n:
        raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {navailable} bytes left")
33
34
def write_bigsize_int(i: int) -> bytes:
    """Serialize a non-negative integer using the BigSize encoding (BOLT-01)."""
    assert i >= 0, i
    # (prefix byte, payload length, exclusive upper bound) per encoding tier
    for prefix, nbytes, limit in ((b"", 1, 0xfd),
                                  (b"\xfd", 2, 0x1_0000),
                                  (b"\xfe", 4, 0x1_0000_0000)):
        if i < limit:
            return prefix + i.to_bytes(nbytes, byteorder="big", signed=False)
    return b"\xff" + i.to_bytes(8, byteorder="big", signed=False)
45
46
def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
    """Read a BigSize-encoded integer (BOLT-01) from fd.

    Returns None if the stream is already at EOF.
    Raises UnexpectedEndOfStream if the payload is truncated, and
    FieldEncodingNotMinimal if a shorter encoding would have sufficed.
    """
    prefix = fd.read(1)
    if not prefix:
        return None  # end of file
    first = prefix[0]
    if first < 0xfd:
        return first
    # prefix byte -> (payload length, smallest value that requires this tier)
    nbytes, minimum = {
        0xfd: (2, 0xfd),
        0xfe: (4, 0x1_0000),
        0xff: (8, 0x1_0000_0000),
    }[first]
    buf = fd.read(nbytes)
    if len(buf) != nbytes:
        raise UnexpectedEndOfStream()
    val = int.from_bytes(buf, byteorder="big", signed=False)
    # a value below the tier's minimum should have used a shorter encoding
    if val < minimum:
        raise FieldEncodingNotMinimal()
    return val
79
80
81 # TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
82 # if field_type is a numeric, we could return a list of ints?
83 def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]:
84 if not fd: raise Exception()
85 if isinstance(count, int):
86 assert count >= 0, f"{count!r} must be non-neg int"
87 elif count == "...":
88 pass
89 else:
90 raise Exception(f"unexpected field count: {count!r}")
91 if count == 0:
92 return b""
93 type_len = None
94 if field_type == 'byte':
95 type_len = 1
96 elif field_type in ('u8', 'u16', 'u32', 'u64'):
97 if field_type == 'u8':
98 type_len = 1
99 elif field_type == 'u16':
100 type_len = 2
101 elif field_type == 'u32':
102 type_len = 4
103 else:
104 assert field_type == 'u64'
105 type_len = 8
106 assert count == 1, count
107 buf = fd.read(type_len)
108 if len(buf) != type_len:
109 raise UnexpectedEndOfStream()
110 return int.from_bytes(buf, byteorder="big", signed=False)
111 elif field_type in ('tu16', 'tu32', 'tu64'):
112 if field_type == 'tu16':
113 type_len = 2
114 elif field_type == 'tu32':
115 type_len = 4
116 else:
117 assert field_type == 'tu64'
118 type_len = 8
119 assert count == 1, count
120 raw = fd.read(type_len)
121 if len(raw) > 0 and raw[0] == 0x00:
122 raise FieldEncodingNotMinimal()
123 return int.from_bytes(raw, byteorder="big", signed=False)
124 elif field_type == 'varint':
125 assert count == 1, count
126 val = read_bigsize_int(fd)
127 if val is None:
128 raise UnexpectedEndOfStream()
129 return val
130 elif field_type == 'chain_hash':
131 type_len = 32
132 elif field_type == 'channel_id':
133 type_len = 32
134 elif field_type == 'sha256':
135 type_len = 32
136 elif field_type == 'signature':
137 type_len = 64
138 elif field_type == 'point':
139 type_len = 33
140 elif field_type == 'short_channel_id':
141 type_len = 8
142
143 if count == "...":
144 total_len = -1 # read all
145 else:
146 if type_len is None:
147 raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
148 total_len = count * type_len
149
150 buf = fd.read(total_len)
151 if total_len >= 0 and len(buf) != total_len:
152 raise UnexpectedEndOfStream()
153 return buf
154
155
156 # TODO: maybe for "value" we could accept a list with len "count" of appropriate items
157 def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
158 value: Union[bytes, int]) -> None:
159 if not fd: raise Exception()
160 if isinstance(count, int):
161 assert count >= 0, f"{count!r} must be non-neg int"
162 elif count == "...":
163 pass
164 else:
165 raise Exception(f"unexpected field count: {count!r}")
166 if count == 0:
167 return
168 type_len = None
169 if field_type == 'byte':
170 type_len = 1
171 elif field_type == 'u8':
172 type_len = 1
173 elif field_type == 'u16':
174 type_len = 2
175 elif field_type == 'u32':
176 type_len = 4
177 elif field_type == 'u64':
178 type_len = 8
179 elif field_type in ('tu16', 'tu32', 'tu64'):
180 if field_type == 'tu16':
181 type_len = 2
182 elif field_type == 'tu32':
183 type_len = 4
184 else:
185 assert field_type == 'tu64'
186 type_len = 8
187 assert count == 1, count
188 if isinstance(value, int):
189 value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
190 if not isinstance(value, (bytes, bytearray)):
191 raise Exception(f"can only write bytes into fd. got: {value!r}")
192 while len(value) > 0 and value[0] == 0x00:
193 value = value[1:]
194 nbytes_written = fd.write(value)
195 if nbytes_written != len(value):
196 raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
197 return
198 elif field_type == 'varint':
199 assert count == 1, count
200 if isinstance(value, int):
201 value = write_bigsize_int(value)
202 if not isinstance(value, (bytes, bytearray)):
203 raise Exception(f"can only write bytes into fd. got: {value!r}")
204 nbytes_written = fd.write(value)
205 if nbytes_written != len(value):
206 raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
207 return
208 elif field_type == 'chain_hash':
209 type_len = 32
210 elif field_type == 'channel_id':
211 type_len = 32
212 elif field_type == 'sha256':
213 type_len = 32
214 elif field_type == 'signature':
215 type_len = 64
216 elif field_type == 'point':
217 type_len = 33
218 elif field_type == 'short_channel_id':
219 type_len = 8
220 total_len = -1
221 if count != "...":
222 if type_len is None:
223 raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
224 total_len = count * type_len
225 if isinstance(value, int) and (count == 1 or field_type == 'byte'):
226 value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
227 if not isinstance(value, (bytes, bytearray)):
228 raise Exception(f"can only write bytes into fd. got: {value!r}")
229 if count != "..." and total_len != len(value):
230 raise UnexpectedFieldSizeForEncoder(f"expected: {total_len}, got {len(value)}")
231 nbytes_written = fd.write(value)
232 if nbytes_written != len(value):
233 raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
234
235
def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
    """Read one TLV record from fd and return (type, value).

    The length prefix is consumed internally and used to size the value.
    """
    if not fd:
        raise Exception()
    record_type = _read_field(fd=fd, field_type="varint", count=1)
    record_len = _read_field(fd=fd, field_type="varint", count=1)
    record_val = _read_field(fd=fd, field_type="byte", count=record_len)
    return record_type, record_val
242
243
def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
    """Write one TLV record (type, length, value) to fd."""
    if not fd:
        raise Exception()
    length = len(tlv_val)
    _write_field(fd=fd, field_type="varint", count=1, value=tlv_type)
    _write_field(fd=fd, field_type="varint", count=1, value=length)
    _write_field(fd=fd, field_type="byte", count=length, value=tlv_val)
250
251
252 def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]:
253 """Returns an evaluated field count, typically an int.
254 If allow_any is True, the return value can be a str with value=="...".
255 """
256 if field_count_str == "":
257 field_count = 1
258 elif field_count_str == "...":
259 if not allow_any:
260 raise Exception("field count is '...' but allow_any is False")
261 return field_count_str
262 else:
263 try:
264 field_count = int(field_count_str)
265 except ValueError:
266 field_count = vars_dict[field_count_str]
267 if isinstance(field_count, (bytes, bytearray)):
268 field_count = int.from_bytes(field_count, byteorder="big")
269 assert isinstance(field_count, int)
270 return field_count
271
272
273 def _parse_msgtype_intvalue_for_onion_wire(value: str) -> int:
274 msg_type_int = 0
275 for component in value.split("|"):
276 try:
277 msg_type_int |= int(component)
278 except ValueError:
279 msg_type_int |= OnionFailureCodeMetaFlag[component]
280 return msg_type_int
281
282
class LNSerializer:
    """Serializer/deserializer for Lightning Network wire messages.

    Message and TLV schemes are loaded at construction time from CSV files
    in lnwire/: peer_wire.csv for the peer-to-peer protocol, or
    onion_wire.csv when for_onion_wire=True.
    """

    def __init__(self, *, for_onion_wire: bool = False):
        # TODO msg_type could be 'int' everywhere...
        # msg type (2 bytes, big-endian) -> scheme rows
        # (first the "msgtype" row, followed by its "msgdata" rows)
        self.msg_scheme_from_type = {}  # type: Dict[bytes, List[Sequence[str]]]
        # msg name -> msg type (2 bytes, big-endian)
        self.msg_type_from_name = {}  # type: Dict[str, bytes]

        # per TLV stream name: record type -> scheme rows ("tlvtype" row first,
        # then its "tlvdata" rows); plus record name<->type lookup tables
        self.in_tlv_stream_get_tlv_record_scheme_from_type = {}  # type: Dict[str, Dict[int, List[Sequence[str]]]]
        self.in_tlv_stream_get_record_type_from_name = {}  # type: Dict[str, Dict[str, int]]
        self.in_tlv_stream_get_record_name_from_type = {}  # type: Dict[str, Dict[int, str]]

        if for_onion_wire:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
        else:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
        with open(path, newline='') as f:
            csvreader = csv.reader(f)
            # NOTE: the "msgdata"/"tlvdata" branches below intentionally rely on
            # msg_type_name / msg_type_bytes / tlv_stream_name / tlv_record_name /
            # tlv_record_type still being bound from the most recent
            # "msgtype"/"tlvtype" row, i.e. data rows must follow their header row.
            for row in csvreader:
                #print(f">>> {row!r}")
                if row[0] == "msgtype":
                    # msgtype,<msgname>,<value>[,<option>]
                    msg_type_name = row[1]
                    if for_onion_wire:
                        msg_type_int = _parse_msgtype_intvalue_for_onion_wire(str(row[2]))
                    else:
                        msg_type_int = int(row[2])
                    msg_type_bytes = msg_type_int.to_bytes(2, 'big')
                    assert msg_type_bytes not in self.msg_scheme_from_type, f"type collision? for {msg_type_name}"
                    assert msg_type_name not in self.msg_type_from_name, f"type collision? for {msg_type_name}"
                    row[2] = msg_type_int
                    self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
                    self.msg_type_from_name[msg_type_name] = msg_type_bytes
                elif row[0] == "msgdata":
                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                    assert msg_type_name == row[1]
                    self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
                elif row[0] == "tlvtype":
                    # tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
                    tlv_stream_name = row[1]
                    tlv_record_name = row[2]
                    tlv_record_type = int(row[3])
                    row[3] = tlv_record_type
                    if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
                        self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
                        self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
                        self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
                    assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    assert tlv_record_type not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
                    self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
                    self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
                    # the record was just inserted, so a strictly larger
                    # pre-existing key means the CSV listed types out of order
                    if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
                        raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
                                        f"stream={tlv_stream_name}")
                elif row[0] == "tlvdata":
                    # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                    assert tlv_stream_name == row[1]
                    assert tlv_record_name == row[2]
                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
                else:
                    pass  # TODO

    def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
        """Serialize the TLV records given in kwargs (keyed by record name) into fd.

        Records not present in kwargs are omitted. Records are emitted in
        increasing record-type order, as the spec requires.
        """
        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
        for tlv_record_type, scheme in scheme_map.items():  # note: tlv_record_type is monotonically increasing
            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
            if tlv_record_name not in kwargs:
                continue
            # serialize the record value into its own buffer first,
            # so its length is known for the TLV length prefix
            with io.BytesIO() as tlv_record_fd:
                for row in scheme:
                    if row[0] == "tlvtype":
                        pass
                    elif row[0] == "tlvdata":
                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                        assert tlv_stream_name == row[1]
                        assert tlv_record_name == row[2]
                        field_name = row[3]
                        field_type = row[4]
                        field_count_str = row[5]
                        field_count = _resolve_field_count(field_count_str,
                                                           vars_dict=kwargs[tlv_record_name],
                                                           allow_any=True)
                        field_value = kwargs[tlv_record_name][field_name]
                        _write_field(fd=tlv_record_fd,
                                     field_type=field_type,
                                     count=field_count,
                                     value=field_value)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
                _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())

    def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
        """Parse a TLV stream from fd until EOF.

        Returns {record name: {field name: value}}. Unknown odd record types
        are skipped; unknown even ("mandatory") types raise
        UnknownMandatoryTLVRecordType. Record types must be strictly
        increasing, otherwise MsgInvalidFieldOrder is raised.
        """
        parsed = {}  # type: Dict[str, Dict[str, Any]]
        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
        last_seen_tlv_record_type = -1  # type: int
        while _num_remaining_bytes_to_read(fd) > 0:
            tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
            if not (tlv_record_type > last_seen_tlv_record_type):
                raise MsgInvalidFieldOrder(f"TLV records must be monotonically increasing by type. "
                                           f"cur: {tlv_record_type}. prev: {last_seen_tlv_record_type}")
            last_seen_tlv_record_type = tlv_record_type
            try:
                scheme = scheme_map[tlv_record_type]
            except KeyError:
                if tlv_record_type % 2 == 0:
                    # unknown "even" type: hard fail
                    raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
                else:
                    # unknown "odd" type: skip it
                    continue
            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
            parsed[tlv_record_name] = {}
            with io.BytesIO(tlv_record_val) as tlv_record_fd:
                for row in scheme:
                    #print(f"row: {row!r}")
                    if row[0] == "tlvtype":
                        pass
                    elif row[0] == "tlvdata":
                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                        assert tlv_stream_name == row[1]
                        assert tlv_record_name == row[2]
                        field_name = row[3]
                        field_type = row[4]
                        field_count_str = row[5]
                        # count may reference fields parsed earlier in this same record
                        field_count = _resolve_field_count(field_count_str,
                                                           vars_dict=parsed[tlv_record_name],
                                                           allow_any=True)
                        #print(f">> count={field_count}. parsed={parsed}")
                        parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
                                                                          field_type=field_type,
                                                                          count=field_count)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
                if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
                    raise MsgTrailingGarbage(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
        return parsed

    def encode_msg(self, msg_type: str, **kwargs) -> bytes:
        """
        Encode kwargs into a Lightning message (bytes)
        of the type given in the msg_type string.

        Missing mandatory fields default to zero; a missing optional field
        (scheme row has an option column) ends the message early.
        """
        #print(f">>> encode_msg. msg_type={msg_type}, payload={kwargs!r}")
        msg_type_bytes = self.msg_type_from_name[msg_type]
        scheme = self.msg_scheme_from_type[msg_type_bytes]
        with io.BytesIO() as fd:
            fd.write(msg_type_bytes)
            for row in scheme:
                if row[0] == "msgtype":
                    pass
                elif row[0] == "msgdata":
                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                    field_name = row[2]
                    field_type = row[3]
                    field_count_str = row[4]
                    #print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
                    field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
                    if field_name == "tlvs":
                        # the field is a whole TLV stream; its name is the stream name
                        tlv_stream_name = field_type
                        if tlv_stream_name in kwargs:
                            self.write_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name, **(kwargs[tlv_stream_name]))
                        continue
                    try:
                        field_value = kwargs[field_name]
                    except KeyError:
                        if len(row) > 5:
                            break  # optional feature field not present
                        else:
                            field_value = 0  # default mandatory fields to zero
                    #print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
                    _write_field(fd=fd,
                                 field_type=field_type,
                                 count=field_count,
                                 value=field_value)
                    #print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
                else:
                    raise Exception(f"unexpected row in scheme: {row!r}")
            return fd.getvalue()

    def decode_msg(self, data: bytes) -> Tuple[str, dict]:
        """
        Decode Lightning message by reading the first
        two bytes to determine message type.

        Returns message type string and parsed message contents dict.
        """
        #print(f"decode_msg >>> {data.hex()}")
        assert len(data) >= 2
        msg_type_bytes = data[:2]
        msg_type_int = int.from_bytes(msg_type_bytes, byteorder="big", signed=False)
        scheme = self.msg_scheme_from_type[msg_type_bytes]
        assert scheme[0][2] == msg_type_int
        msg_type_name = scheme[0][1]
        parsed = {}
        with io.BytesIO(data[2:]) as fd:
            for row in scheme:
                #print(f"row: {row!r}")
                if row[0] == "msgtype":
                    pass
                elif row[0] == "msgdata":
                    field_name = row[2]
                    field_type = row[3]
                    field_count_str = row[4]
                    # count may reference fields parsed earlier in this message
                    field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
                    if field_name == "tlvs":
                        # the field is a whole TLV stream; its name is the stream name
                        tlv_stream_name = field_type
                        d = self.read_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name)
                        parsed[tlv_stream_name] = d
                        continue
                    #print(f">> count={field_count}. parsed={parsed}")
                    try:
                        parsed[field_name] = _read_field(fd=fd,
                                                         field_type=field_type,
                                                         count=field_count)
                    except UnexpectedEndOfStream as e:
                        if len(row) > 5:
                            break  # optional feature field not present
                        else:
                            raise
                else:
                    raise Exception(f"unexpected row in scheme: {row!r}")
        return msg_type_name, parsed
506
507
# Default serializer for the peer-to-peer wire protocol; module-level
# convenience aliases expose its bound methods directly.
_inst = LNSerializer()
encode_msg = _inst.encode_msg
decode_msg = _inst.decode_msg


# Separate serializer loaded with the onion error ("onion_wire") scheme.
OnionWireSerializer = LNSerializer(for_onion_wire=True)