URI: 
       tlntransport.py - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
       tlntransport.py (9838B)
       ---
            1 # Copyright (C) 2018 Adam Gibson (waxwing)
            2 # Copyright (C) 2018 The Electrum developers
            3 # Distributed under the MIT software license, see the accompanying
            4 # file LICENCE or http://www.opensource.org/licenses/mit-license.php
            5 
            6 # Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
            7 
            8 import hashlib
            9 import asyncio
           10 from asyncio import StreamReader, StreamWriter
           11 from typing import Optional
           12 
           13 from .crypto import sha256, hmac_oneshot, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
           14 from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
           15                      HandshakeFailed, LNPeerAddr)
           16 from . import ecc
           17 from .util import bh2u, MySocksProxy
           18 
           19 
           20 class HandshakeState(object):
           21     prologue = b"lightning"
           22     protocol_name = b"Noise_XK_secp256k1_ChaChaPoly_SHA256"
           23     handshake_version = b"\x00"
           24 
           25     def __init__(self, responder_pub):
           26         self.responder_pub = responder_pub
           27         self.h = sha256(self.protocol_name)
           28         self.ck = self.h
           29         self.update(self.prologue)
           30         self.update(self.responder_pub)
           31 
           32     def update(self, data):
           33         self.h = sha256(self.h + data)
           34         return self.h
           35 
           36 def get_nonce_bytes(n):
           37     """BOLT 8 requires the nonce to be 12 bytes, 4 bytes leading
           38     zeroes and 8 bytes little endian encoded 64 bit integer.
           39     """
           40     return b"\x00"*4 + n.to_bytes(8, 'little')
           41 
           42 def aead_encrypt(key: bytes, nonce: int, associated_data: bytes, data: bytes) -> bytes:
           43     nonce_bytes = get_nonce_bytes(nonce)
           44     return chacha20_poly1305_encrypt(key=key,
           45                                      nonce=nonce_bytes,
           46                                      associated_data=associated_data,
           47                                      data=data)
           48 
           49 def aead_decrypt(key: bytes, nonce: int, associated_data: bytes, data: bytes) -> bytes:
           50     nonce_bytes = get_nonce_bytes(nonce)
           51     return chacha20_poly1305_decrypt(key=key,
           52                                      nonce=nonce_bytes,
           53                                      associated_data=associated_data,
           54                                      data=data)
           55 
           56 def get_bolt8_hkdf(salt, ikm):
           57     """RFC5869 HKDF instantiated in the specific form
           58     used in Lightning BOLT 8:
           59     Extract and expand to 64 bytes using HMAC-SHA256,
           60     with info field set to a zero length string as per BOLT8
           61     Return as two 32 byte fields.
           62     """
           63     #Extract
           64     prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
           65     assert len(prk) == 32
           66     #Expand
           67     info = b""
           68     T0 = b""
           69     T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256)
           70     T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256)
           71     assert len(T1 + T2) == 64
           72     return T1, T2
           73 
           74 def act1_initiator_message(hs, epriv, epub):
           75     ss = get_ecdh(epriv, hs.responder_pub)
           76     ck2, temp_k1 = get_bolt8_hkdf(hs.ck, ss)
           77     hs.ck = ck2
           78     c = aead_encrypt(temp_k1, 0, hs.update(epub), b"")
           79     #for next step if we do it
           80     hs.update(c)
           81     msg = hs.handshake_version + epub + c
           82     assert len(msg) == 50
           83     return msg, temp_k1
           84 
           85 
           86 def create_ephemeral_key() -> (bytes, bytes):
           87     privkey = ecc.ECPrivkey.generate_random_key()
           88     return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
           89 
           90 
           91 class LNTransportBase:
           92     reader: StreamReader
           93     writer: StreamWriter
           94     privkey: bytes
           95 
           96     def name(self) -> str:
           97         raise NotImplementedError()
           98 
           99     def send_bytes(self, msg: bytes) -> None:
          100         l = len(msg).to_bytes(2, 'big')
          101         lc = aead_encrypt(self.sk, self.sn(), b'', l)
          102         c = aead_encrypt(self.sk, self.sn(), b'', msg)
          103         assert len(lc) == 18
          104         assert len(c) == len(msg) + 16
          105         self.writer.write(lc+c)
          106 
          107     async def read_messages(self):
          108         read_buffer = b''
          109         while True:
          110             rn_l, rk_l = self.rn()
          111             rn_m, rk_m = self.rn()
          112             while True:
          113                 if len(read_buffer) >= 18:
          114                     lc = read_buffer[:18]
          115                     l = aead_decrypt(rk_l, rn_l, b'', lc)
          116                     length = int.from_bytes(l, 'big')
          117                     offset = 18 + length + 16
          118                     if len(read_buffer) >= offset:
          119                         c = read_buffer[18:offset]
          120                         read_buffer = read_buffer[offset:]
          121                         msg = aead_decrypt(rk_m, rn_m, b'', c)
          122                         yield msg
          123                         break
          124                 try:
          125                     s = await self.reader.read(2**10)
          126                 except asyncio.CancelledError:
          127                     raise
          128                 except Exception:
          129                     s = None
          130                 if not s:
          131                     raise LightningPeerConnectionClosed()
          132                 read_buffer += s
          133 
          134     def rn(self):
          135         o = self._rn, self.rk
          136         self._rn += 1
          137         if self._rn == 1000:
          138             self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk)
          139             self._rn = 0
          140         return o
          141 
          142     def sn(self):
          143         o = self._sn
          144         self._sn += 1
          145         if self._sn == 1000:
          146             self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk)
          147             self._sn = 0
          148         return o
          149 
          150     def init_counters(self, ck):
          151         # init counters
          152         self._sn = 0
          153         self._rn = 0
          154         self.r_ck = ck
          155         self.s_ck = ck
          156 
          157     def close(self):
          158         self.writer.close()
          159 
          160 
          161 class LNResponderTransport(LNTransportBase):
          162     """Transport initiated by remote party."""
          163 
          164     def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
          165         LNTransportBase.__init__(self)
          166         self.reader = reader
          167         self.writer = writer
          168         self.privkey = privkey
          169 
          170     def name(self):
          171         return "responder"
          172 
          173     async def handshake(self, **kwargs):
          174         hs = HandshakeState(privkey_to_pubkey(self.privkey))
          175         act1 = b''
          176         while len(act1) < 50:
          177             buf = await self.reader.read(50 - len(act1))
          178             if not buf:
          179                 raise HandshakeFailed('responder disconnected')
          180             act1 += buf
          181         if len(act1) != 50:
          182             raise HandshakeFailed('responder: short act 1 read, length is ' + str(len(act1)))
          183         if bytes([act1[0]]) != HandshakeState.handshake_version:
          184             raise HandshakeFailed('responder: bad handshake version in act 1')
          185         c = act1[-16:]
          186         re = act1[1:34]
          187         h = hs.update(re)
          188         ss = get_ecdh(self.privkey, re)
          189         ck, temp_k1 = get_bolt8_hkdf(sha256(HandshakeState.protocol_name), ss)
          190         _p = aead_decrypt(temp_k1, 0, h, c)
          191         hs.update(c)
          192 
          193         # act 2
          194         if 'epriv' not in kwargs:
          195             epriv, epub = create_ephemeral_key()
          196         else:
          197             epriv = kwargs['epriv']
          198             epub = ecc.ECPrivkey(epriv).get_public_key_bytes()
          199         hs.ck = ck
          200         hs.responder_pub = re
          201 
          202         msg, temp_k2 = act1_initiator_message(hs, epriv, epub)
          203         self.writer.write(msg)
          204 
          205         # act 3
          206         act3 = b''
          207         while len(act3) < 66:
          208             buf = await self.reader.read(66 - len(act3))
          209             if not buf:
          210                 raise HandshakeFailed('responder disconnected')
          211             act3 += buf
          212         if len(act3) != 66:
          213             raise HandshakeFailed('responder: short act 3 read, length is ' + str(len(act3)))
          214         if bytes([act3[0]]) != HandshakeState.handshake_version:
          215             raise HandshakeFailed('responder: bad handshake version in act 3')
          216         c = act3[1:50]
          217         t = act3[-16:]
          218         rs = aead_decrypt(temp_k2, 1, hs.h, c)
          219         ss = get_ecdh(epriv, rs)
          220         ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
          221         _p = aead_decrypt(temp_k3, 0, hs.update(c), t)
          222         self.rk, self.sk = get_bolt8_hkdf(ck, b'')
          223         self.init_counters(ck)
          224         return rs
          225 
          226 
          227 class LNTransport(LNTransportBase):
          228     """Transport initiated by local party."""
          229 
          230     def __init__(self, privkey: bytes, peer_addr: LNPeerAddr, *,
          231                  proxy: Optional[dict]):
          232         LNTransportBase.__init__(self)
          233         assert type(privkey) is bytes and len(privkey) == 32
          234         self.privkey = privkey
          235         self.peer_addr = peer_addr
          236         self.proxy = MySocksProxy.from_proxy_dict(proxy)
          237 
          238     def name(self):
          239         return self.peer_addr.net_addr_str()
          240 
          241     async def handshake(self):
          242         if not self.proxy:
          243             self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
          244         else:
          245             self.reader, self.writer = await self.proxy.open_connection(self.peer_addr.host, self.peer_addr.port)
          246         hs = HandshakeState(self.peer_addr.pubkey)
          247         # Get a new ephemeral key
          248         epriv, epub = create_ephemeral_key()
          249 
          250         msg, _temp_k1 = act1_initiator_message(hs, epriv, epub)
          251         # act 1
          252         self.writer.write(msg)
          253         rspns = await self.reader.read(2**10)
          254         if len(rspns) != 50:
          255             raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, "
          256                                   f"are you sure this is the right pubkey? {self.peer_addr}")
          257         hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:]
          258         if bytes([hver]) != hs.handshake_version:
          259             raise HandshakeFailed("unexpected handshake version: {}".format(hver))
          260         # act 2
          261         hs.update(alice_epub)
          262         ss = get_ecdh(epriv, alice_epub)
          263         ck, temp_k2 = get_bolt8_hkdf(hs.ck, ss)
          264         hs.ck = ck
          265         p = aead_decrypt(temp_k2, 0, hs.h, tag)
          266         hs.update(tag)
          267         # act 3
          268         my_pubkey = privkey_to_pubkey(self.privkey)
          269         c = aead_encrypt(temp_k2, 1, hs.h, my_pubkey)
          270         hs.update(c)
          271         ss = get_ecdh(self.privkey[:32], alice_epub)
          272         ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
          273         hs.ck = ck
          274         t = aead_encrypt(temp_k3, 0, hs.h, b'')
          275         msg = hs.handshake_version + c + t
          276         self.writer.write(msg)
          277         self.sk, self.rk = get_bolt8_hkdf(hs.ck, b'')
          278         self.init_counters(ck)