tlntransport.py - electrum - Electrum Bitcoin wallet
HTML git clone https://git.parazyd.org/electrum
DIR Log
DIR Files
DIR Refs
DIR Submodules
---
tlntransport.py (9838B)
---
1 # Copyright (C) 2018 Adam Gibson (waxwing)
2 # Copyright (C) 2018 The Electrum developers
3 # Distributed under the MIT software license, see the accompanying
4 # file LICENCE or http://www.opensource.org/licenses/mit-license.php
5
6 # Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
7
import asyncio
import hashlib
from asyncio import StreamReader, StreamWriter
from typing import Optional, Tuple

from .crypto import sha256, hmac_oneshot, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
                     HandshakeFailed, LNPeerAddr)
from . import ecc
from .util import bh2u, MySocksProxy
18
19
class HandshakeState:
    """Mutable Noise_XK handshake state: chaining key ``ck`` and hash ``h``.

    On construction ``h`` and ``ck`` both start as SHA256(protocol_name);
    the prologue and then the responder's public key are mixed into ``h``.
    ``ck`` is left untouched by :meth:`update`.
    """
    prologue = b"lightning"
    protocol_name = b"Noise_XK_secp256k1_ChaChaPoly_SHA256"
    handshake_version = b"\x00"

    def __init__(self, responder_pub):
        self.responder_pub = responder_pub
        self.ck = self.h = sha256(self.protocol_name)
        for item in (self.prologue, self.responder_pub):
            self.update(item)

    def update(self, data):
        """Set h = SHA256(h || data) and return the new value of h."""
        self.h = sha256(self.h + data)
        return self.h
35
def get_nonce_bytes(n):
    """Encode the integer counter ``n`` as a 12-byte BOLT 8 nonce.

    Layout: four zero bytes, then ``n`` as an 8-byte little-endian integer.
    """
    counter = n.to_bytes(8, byteorder='little')
    return bytes(4) + counter
41
def aead_encrypt(key: bytes, nonce: int, associated_data: bytes, data: bytes) -> bytes:
    """ChaCha20-Poly1305 encrypt ``data`` with an integer BOLT 8 nonce.

    The integer ``nonce`` is converted with get_nonce_bytes before being
    handed to chacha20_poly1305_encrypt, whose output is returned as-is.
    """
    return chacha20_poly1305_encrypt(
        key=key,
        nonce=get_nonce_bytes(nonce),
        associated_data=associated_data,
        data=data,
    )
48
def aead_decrypt(key: bytes, nonce: int, associated_data: bytes, data: bytes) -> bytes:
    """ChaCha20-Poly1305 decrypt ``data`` with an integer BOLT 8 nonce.

    Counterpart of aead_encrypt: same nonce encoding (get_nonce_bytes), then
    delegates to chacha20_poly1305_decrypt and returns its result.
    """
    return chacha20_poly1305_decrypt(
        key=key,
        nonce=get_nonce_bytes(nonce),
        associated_data=associated_data,
        data=data,
    )
55
def get_bolt8_hkdf(salt, ikm):
    """RFC 5869 HKDF with HMAC-SHA256, specialised as in BOLT 8.

    Extracts a pseudo-random key from ``ikm`` using ``salt``, then expands it
    to 64 bytes with an empty info string. Returns the two 32-byte halves.
    """
    # HKDF-Extract
    prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
    assert len(prk) == 32
    # HKDF-Expand: T(1) = HMAC(prk, info || 0x01), T(2) = HMAC(prk, T(1) || info || 0x02)
    info = b""
    first = hmac_oneshot(prk, info + b"\x01", digest=hashlib.sha256)
    second = hmac_oneshot(prk, first + info + b"\x02", digest=hashlib.sha256)
    assert len(first + second) == 64
    return first, second
73
def act1_initiator_message(hs, epriv, epub):
    """Build a 50-byte handshake act (version || epub || MAC), advancing ``hs``.

    Performs ECDH between ``epriv`` and ``hs.responder_pub``, replaces
    ``hs.ck`` with the new chaining key, authenticates an empty payload under
    the derived temporary key and mixes ``epub`` and the MAC into ``hs.h``.
    The responder also reuses this for act 2 after pointing
    ``hs.responder_pub`` at the initiator's ephemeral key.

    Returns:
        (msg, temp_k): the 50-byte message and the derived temporary key.
    """
    shared_secret = get_ecdh(epriv, hs.responder_pub)
    hs.ck, temp_k = get_bolt8_hkdf(hs.ck, shared_secret)
    tag = aead_encrypt(temp_k, 0, hs.update(epub), b"")
    # Mix the MAC into h so a following act continues from the right transcript.
    hs.update(tag)
    msg = b"".join((hs.handshake_version, epub, tag))
    assert len(msg) == 50
    return msg, temp_k
84
85
def create_ephemeral_key() -> Tuple[bytes, bytes]:
    """Generate a fresh random key pair for a single handshake.

    Returns:
        (secret_bytes, public_key_bytes) taken from a newly generated
        ``ecc.ECPrivkey``.
    """
    privkey = ecc.ECPrivkey.generate_random_key()
    return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
89
90
class LNTransportBase:
    """Encrypted BOLT 8 message framing shared by both handshake roles.

    Subclasses provide ``reader``/``writer``/``privkey`` and, once their
    handshake succeeds, set the session keys ``sk`` (send) and ``rk``
    (receive) and call :meth:`init_counters`.
    """
    reader: StreamReader
    writer: StreamWriter
    privkey: bytes

    def name(self) -> str:
        """Return a human-readable identifier; implemented by subclasses."""
        raise NotImplementedError()

    def send_bytes(self, msg: bytes) -> None:
        """Encrypt ``msg`` and write it as one frame.

        A frame is the encrypted 2-byte big-endian length (18 bytes with MAC)
        followed by the encrypted payload (len(msg) + 16 bytes). Each part
        consumes one sending nonce.
        """
        length_bytes = len(msg).to_bytes(2, 'big')
        # self.sk is read before self.sn() runs, so a key rotation triggered by
        # that call only takes effect for the next encryption.
        lc = aead_encrypt(self.sk, self.sn(), b'', length_bytes)
        c = aead_encrypt(self.sk, self.sn(), b'', msg)
        assert len(lc) == 18
        assert len(c) == len(msg) + 16
        self.writer.write(lc + c)

    async def read_messages(self):
        """Async generator yielding decrypted messages from the peer.

        Raises:
            LightningPeerConnectionClosed: when the stream reaches EOF or the
                underlying read fails (the failure is chained as the cause).
        """
        read_buffer = b''
        while True:
            # Reserve the nonces/keys for this message's header and body up front.
            rn_l, rk_l = self.rn()
            rn_m, rk_m = self.rn()
            length = None  # payload length, decrypted once the header is buffered
            while True:
                if length is None and len(read_buffer) >= 18:
                    length_bytes = aead_decrypt(rk_l, rn_l, b'', read_buffer[:18])
                    length = int.from_bytes(length_bytes, 'big')
                if length is not None:
                    offset = 18 + length + 16
                    if len(read_buffer) >= offset:
                        c = read_buffer[18:offset]
                        read_buffer = read_buffer[offset:]
                        msg = aead_decrypt(rk_m, rn_m, b'', c)
                        yield msg
                        break
                try:
                    s = await self.reader.read(2**10)
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    raise LightningPeerConnectionClosed() from e
                if not s:
                    raise LightningPeerConnectionClosed()
                read_buffer += s

    def rn(self):
        """Return (nonce, key) for the next decryption and advance the counter.

        After 1000 uses the receiving key is rotated via
        HKDF(r_ck, rk) and the nonce restarts at 0.
        """
        o = self._rn, self.rk
        self._rn += 1
        if self._rn == 1000:
            self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk)
            self._rn = 0
        return o

    def sn(self):
        """Return the nonce for the next encryption and advance the counter.

        After 1000 uses the sending key is rotated via
        HKDF(s_ck, sk) and the nonce restarts at 0.
        """
        o = self._sn
        self._sn += 1
        if self._sn == 1000:
            self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk)
            self._sn = 0
        return o

    def init_counters(self, ck):
        """Reset both nonces to 0 and seed both rotation chaining keys with ``ck``."""
        self._sn = 0
        self._rn = 0
        self.r_ck = ck
        self.s_ck = ck

    def close(self):
        """Close the underlying writer."""
        self.writer.close()
159
160
class LNResponderTransport(LNTransportBase):
    """Transport initiated by remote party.

    We answer the BOLT 8 handshake as responder, using our static private
    key ``privkey`` on an already-open reader/writer pair.
    """

    def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
        LNTransportBase.__init__(self)
        self.privkey = privkey
        self.reader, self.writer = reader, writer

    def name(self):
        return "responder"

    async def handshake(self, **kwargs):
        """Run acts 1-3 as responder and return the initiator's static pubkey.

        Keyword args:
            epriv: optional ephemeral private key used for act 2 instead of
                a freshly generated one.

        Raises:
            HandshakeFailed: on disconnect, short read, or bad version byte.
        """
        async def read_exactly(count):
            # Accumulate partial reads until ``count`` bytes have arrived.
            data = b''
            while len(data) < count:
                chunk = await self.reader.read(count - len(data))
                if not chunk:
                    raise HandshakeFailed('responder disconnected')
                data += chunk
            return data

        version = HandshakeState.handshake_version
        hs = HandshakeState(privkey_to_pubkey(self.privkey))

        # act 1: version (1) || initiator ephemeral pubkey (33) || MAC (16)
        act1 = await read_exactly(50)
        if len(act1) != 50:
            raise HandshakeFailed('responder: short act 1 read, length is ' + str(len(act1)))
        if act1[:1] != version:
            raise HandshakeFailed('responder: bad handshake version in act 1')
        mac1 = act1[-16:]
        re = act1[1:34]
        h = hs.update(re)
        ss = get_ecdh(self.privkey, re)
        # hs.ck still equals sha256(protocol_name) here; HandshakeState.update only touches h.
        ck, temp_k1 = get_bolt8_hkdf(sha256(HandshakeState.protocol_name), ss)
        # Result unused; presumably raises on a bad MAC — confirm in .crypto.
        aead_decrypt(temp_k1, 0, h, mac1)
        hs.update(mac1)

        # act 2: reuse the act-1 builder, with the initiator's ephemeral key
        # standing in as the "responder" key for the ECDH step.
        if 'epriv' in kwargs:
            epriv = kwargs['epriv']
            epub = ecc.ECPrivkey(epriv).get_public_key_bytes()
        else:
            epriv, epub = create_ephemeral_key()
        hs.ck = ck
        hs.responder_pub = re
        act2, temp_k2 = act1_initiator_message(hs, epriv, epub)
        self.writer.write(act2)

        # act 3: version (1) || encrypted static pubkey (33 + 16) || MAC (16)
        act3 = await read_exactly(66)
        if len(act3) != 66:
            raise HandshakeFailed('responder: short act 3 read, length is ' + str(len(act3)))
        if act3[:1] != version:
            raise HandshakeFailed('responder: bad handshake version in act 3')
        encrypted_rs = act3[1:50]
        mac3 = act3[-16:]
        rs = aead_decrypt(temp_k2, 1, hs.h, encrypted_rs)
        ss = get_ecdh(epriv, rs)
        ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
        aead_decrypt(temp_k3, 0, hs.update(encrypted_rs), mac3)
        # The responder receives with the first derived key and sends with the second.
        self.rk, self.sk = get_bolt8_hkdf(ck, b'')
        self.init_counters(ck)
        return rs
225
226
class LNTransport(LNTransportBase):
    """Transport initiated by local party.

    Opens a connection to ``peer_addr`` (directly or via ``proxy``) and runs
    the BOLT 8 handshake as initiator, using ``peer_addr.pubkey`` as the
    responder's static key.
    """

    def __init__(self, privkey: bytes, peer_addr: LNPeerAddr, *,
                 proxy: Optional[dict]):
        LNTransportBase.__init__(self)
        assert type(privkey) is bytes and len(privkey) == 32
        self.privkey = privkey
        self.peer_addr = peer_addr
        self.proxy = MySocksProxy.from_proxy_dict(proxy)

    def name(self):
        return self.peer_addr.net_addr_str()

    async def handshake(self):
        """Connect to the peer and complete the three-act handshake.

        Raises:
            HandshakeFailed: if the act-2 reply is truncated (including the
                peer closing the connection) or has an unexpected version byte.
        """
        if not self.proxy:
            self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
        else:
            self.reader, self.writer = await self.proxy.open_connection(self.peer_addr.host, self.peer_addr.port)
        hs = HandshakeState(self.peer_addr.pubkey)
        # Get a new ephemeral key
        epriv, epub = create_ephemeral_key()

        # act 1: send version || our ephemeral pubkey || MAC
        msg, _temp_k1 = act1_initiator_message(hs, epriv, epub)
        self.writer.write(msg)

        # act 2: the reply is exactly 50 bytes, but a single read() may return
        # only part of it (read(n) returns *up to* n bytes). Keep reading until
        # all 50 bytes arrived or the peer closed the stream.
        rspns = b''
        while len(rspns) < 50:
            chunk = await self.reader.read(50 - len(rspns))
            if not chunk:
                break
            rspns += chunk
        if len(rspns) != 50:
            raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, "
                                  f"are you sure this is the right pubkey? {self.peer_addr}")
        hver, responder_epub, tag = rspns[0], rspns[1:34], rspns[34:]
        if bytes([hver]) != hs.handshake_version:
            raise HandshakeFailed("unexpected handshake version: {}".format(hver))
        hs.update(responder_epub)
        ss = get_ecdh(epriv, responder_epub)
        ck, temp_k2 = get_bolt8_hkdf(hs.ck, ss)
        hs.ck = ck
        # Result unused; presumably raises if the MAC does not verify — confirm in .crypto.
        aead_decrypt(temp_k2, 0, hs.h, tag)
        hs.update(tag)

        # act 3: send version || encrypted static pubkey || MAC
        my_pubkey = privkey_to_pubkey(self.privkey)
        c = aead_encrypt(temp_k2, 1, hs.h, my_pubkey)
        hs.update(c)
        ss = get_ecdh(self.privkey[:32], responder_epub)
        ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
        hs.ck = ck
        t = aead_encrypt(temp_k3, 0, hs.h, b'')
        msg = hs.handshake_version + c + t
        self.writer.write(msg)
        # The initiator sends with the first derived key and receives with the second.
        self.sk, self.rk = get_bolt8_hkdf(hs.ck, b'')
        self.init_counters(ck)