URI: 
       tcreate transport and perform handshake before creating Peer - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit b5482e4470dcf2da331bf33b65c8a37ba009aa57
   DIR parent 61638664f769fa28d33711115c849267517ede08
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Fri,  1 Feb 2019 20:21:59 +0100
       
       create transport and perform handshake before creating Peer
       
       Diffstat:
         M electrum/lnbase.py                  |      28 ++++++++++------------------
         M electrum/lntransport.py             |      27 +++++++++++++++++----------
         M electrum/lnworker.py                |      20 ++++++++++----------
         M electrum/tests/test_lnbase.py       |      11 ++++++-----
       
       4 files changed, 43 insertions(+), 43 deletions(-)
       ---
   DIR diff --git a/electrum/lnbase.py b/electrum/lnbase.py
       t@@ -197,15 +197,14 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
        
        class Peer(PrintError):
        
       -    def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr, responding=False,
       -                 request_initial_sync=False, transport: LNTransportBase=None):
       -        self.responding = responding
       +    def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase,
       +                 request_initial_sync=False):
                self.initialized = asyncio.Event()
                self.transport = transport
       -        self.peer_addr = peer_addr
       +        self.pubkey = pubkey
                self.lnworker = lnworker
                self.privkey = lnworker.node_keypair.privkey
       -        self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)]
       +        self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)]
                self.network = lnworker.network
                self.lnwatcher = lnworker.network.lnwatcher
                self.channel_db = lnworker.network.channel_db
       t@@ -233,19 +232,14 @@ class Peer(PrintError):
                self.transport.send_bytes(gen_msg(message_name, **kwargs))
        
            async def initialize(self):
       -        if not self.transport:
       -            reader, writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
       -            transport = LNTransport(self.privkey, self.peer_addr.pubkey, reader, writer)
       -            await transport.handshake()
       -            self.transport = transport
                self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures)
        
            @property
            def channels(self) -> Dict[bytes, Channel]:
       -        return self.lnworker.channels_for_peer(self.peer_addr.pubkey)
       +        return self.lnworker.channels_for_peer(self.pubkey)
        
            def diagnostic_name(self):
       -        return str(self.peer_addr.host) + ':' + str(self.peer_addr.port)
       +        return self.transport.name()
        
            def ping_if_required(self):
                if time.time() - self.ping_time > 120:
       t@@ -352,7 +346,7 @@ class Peer(PrintError):
                        self.print_error("disconnecting gracefully. {}".format(e))
                    finally:
                        self.close_and_cleanup()
       -                self.lnworker.peers.pop(self.peer_addr.pubkey)
       +                self.lnworker.peers.pop(self.pubkey)
                return wrapper_func
        
            @ignore_exceptions  # do not kill main_taskgroup
       t@@ -373,8 +367,6 @@ class Peer(PrintError):
                except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
                    self.print_error('initialize failed, disconnecting: {}'.format(repr(e)))
                    return
       -        if not self.responding:
       -            self.channel_db.add_recent_peer(self.peer_addr)
                # loop
                async for msg in self.transport.read_messages():
                    self.process_message(msg)
       t@@ -513,7 +505,7 @@ class Peer(PrintError):
                # remote commitment transaction
                channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index)
                chan_dict = {
       -                "node_id": self.peer_addr.pubkey,
       +                "node_id": self.pubkey,
                        "channel_id": channel_id,
                        "short_channel_id": None,
                        "funding_outpoint": Outpoint(funding_txid, funding_index),
       t@@ -587,7 +579,7 @@ class Peer(PrintError):
                remote_dust_limit_sat = int.from_bytes(payload['dust_limit_satoshis'], byteorder='big') # TODO validate
                remote_reserve_sat = self.validate_remote_reserve(payload['channel_reserve_satoshis'], remote_dust_limit_sat, funding_sat)
                chan_dict = {
       -                "node_id": self.peer_addr.pubkey,
       +                "node_id": self.pubkey,
                        "channel_id": channel_id,
                        "short_channel_id": None,
                        "funding_outpoint": Outpoint(funding_txid, funding_idx),
       t@@ -794,7 +786,7 @@ class Peer(PrintError):
                remote_bitcoin_sig = announcement_signatures_msg["bitcoin_signature"]
                if not ecc.verify_signature(chan.config[REMOTE].multisig_key.pubkey, remote_bitcoin_sig, h):
                    raise Exception("bitcoin_sig invalid in announcement_signatures")
       -        if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h):
       +        if not ecc.verify_signature(self.pubkey, remote_node_sig, h):
                    raise Exception("node_sig invalid in announcement_signatures")
        
                node_sigs = [remote_node_sig, local_node_sig]
   DIR diff --git a/electrum/lntransport.py b/electrum/lntransport.py
       t@@ -6,6 +6,7 @@
        # Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
        
        import hashlib
       +import asyncio
        from asyncio import StreamReader, StreamWriter
        from Cryptodome.Cipher import ChaCha20_Poly1305
        
       t@@ -87,10 +88,6 @@ def create_ephemeral_key() -> (bytes, bytes):
        
        class LNTransportBase:
        
       -    def __init__(self, reader: StreamReader, writer: StreamWriter):
       -        self.reader = reader
       -        self.writer = writer
       -
            def send_bytes(self, msg):
                l = len(msg).to_bytes(2, 'big')
                lc = aead_encrypt(self.sk, self.sn(), b'', l)
       t@@ -153,12 +150,16 @@ class LNTransportBase:
        
        class LNResponderTransport(LNTransportBase):
            def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
       -        LNTransportBase.__init__(self, reader, writer)
       +        LNTransportBase.__init__(self)
       +        self.reader = reader
       +        self.writer = writer
                self.privkey = privkey
        
       +    def name(self):
       +        return "responder"
       +
            async def handshake(self, **kwargs):
                hs = HandshakeState(privkey_to_pubkey(self.privkey))
       -
                act1 = b''
                while len(act1) < 50:
                    act1 += await self.reader.read(50 - len(act1))
       t@@ -205,14 +206,20 @@ class LNResponderTransport(LNTransportBase):
                return rs
        
        class LNTransport(LNTransportBase):
       -    def __init__(self, privkey: bytes, remote_pubkey: bytes,
       -                 reader: StreamReader, writer: StreamWriter):
       -        LNTransportBase.__init__(self, reader, writer)
       +
       +    def __init__(self, privkey: bytes, peer_addr):
       +        LNTransportBase.__init__(self)
                assert type(privkey) is bytes and len(privkey) == 32
                self.privkey = privkey
       -        self.remote_pubkey = remote_pubkey
       +        self.remote_pubkey = peer_addr.pubkey
       +        self.host = peer_addr.host
       +        self.port = peer_addr.port
       +
       +    def name(self):
       +        return str(self.host) + ':' + str(self.port)
        
            async def handshake(self):
       +        self.reader, self.writer = await asyncio.open_connection(self.host, self.port)
                hs = HandshakeState(self.remote_pubkey)
                # Get a new ephemeral key
                epriv, epub = create_ephemeral_key()
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -28,7 +28,7 @@ from .crypto import sha256
        from .bip32 import bip32_root
        from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions
        from .util import timestamp_to_datetime
       -from .lntransport import LNResponderTransport
       +from .lntransport import LNTransport, LNResponderTransport
        from .lnbase import Peer
        from .lnaddr import lnencode, LnAddr, lndecode
        from .ecc import der_sig_from_sig_string
       t@@ -244,13 +244,16 @@ class LNWorker(PrintError):
                    return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
        
            async def add_peer(self, host, port, node_id):
       -        port = int(port)
       -        peer_addr = LNPeerAddr(host, port, node_id)
                if node_id in self.peers:
                    return self.peers[node_id]
       +        port = int(port)
       +        peer_addr = LNPeerAddr(host, port, node_id)
       +        transport = LNTransport(self.node_keypair.privkey, peer_addr)
       +        await transport.handshake()
       +        self.channel_db.add_recent_peer(peer_addr)
                self._last_tried_peer[peer_addr] = time.time()
                self.print_error("adding peer", peer_addr)
       -        peer = Peer(self, peer_addr, request_initial_sync=self.config.get("request_initial_sync", True))
       +        peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True))
                await self.network.main_taskgroup.spawn(peer.main_loop())
                self.peers[node_id] = peer
                self.network.trigger_callback('ln_status')
       t@@ -797,16 +800,13 @@ class LNWorker(PrintError):
                        # ipv6
                        addr = addr[1:-1]
                    async def cb(reader, writer):
       -                t = LNResponderTransport(self.node_keypair.privkey, reader, writer)
       +                transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
                        try:
       -                    node_id = await t.handshake()
       +                    node_id = await transport.handshake()
                        except:
                            self.print_error('handshake failure from incoming connection')
                            return
       -                # FIXME extract host and port from transport
       -                peer = Peer(self, LNPeerAddr("bogus", 1337, node_id), responding=True,
       -                            request_initial_sync=self.config.get("request_initial_sync", True),
       -                            transport=t)
       +                peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True))
                        self.peers[node_id] = peer
                        await self.network.main_taskgroup.spawn(peer.main_loop())
                        self.network.trigger_callback('ln_status')
   DIR diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
       t@@ -113,6 +113,9 @@ class MockTransport:
            def __init__(self):
                self.queue = asyncio.Queue()
        
       +    def name(self):
       +        return ""
       +
            async def read_messages(self):
                while True:
                    yield await self.queue.get()
       t@@ -150,7 +153,7 @@ class TestPeer(unittest.TestCase):
            def test_require_data_loss_protect(self):
                mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
                mock_transport = NoFeaturesTransport()
       -        p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
       +        p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False)
                mock_lnworker.peer = p1
                with self.assertRaises(LightningPeerConnectionClosed):
                    run(asyncio.wait_for(p1._main_loop(), 1))
       t@@ -161,10 +164,8 @@ class TestPeer(unittest.TestCase):
                q1, q2 = asyncio.Queue(), asyncio.Queue()
                w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
                w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
       -        p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
       -                request_initial_sync=False, transport=t1)
       -        p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
       -                request_initial_sync=False, transport=t2)
       +        p1 = Peer(w1, k1.pubkey, t1, request_initial_sync=False)
       +        p2 = Peer(w2, k2.pubkey, t2, request_initial_sync=False)
                w1.peer = p1
                w2.peer = p2
                # mark_open won't work if state is already OPEN.