URI: 
       tlnbase: mark initialized later, add tests, etc - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 85789d8a09523cd6c5635ec71ec03d099caf0c48
   DIR parent a42c1067abcdd1a610f3d26da804703055a34b1a
  HTML Author: Janus <ysangkok@gmail.com>
       Date:   Thu, 25 Oct 2018 18:28:18 +0200
       
       lnbase: mark initialized later, add tests, etc
       
       - consistent node_id sorting
       - require OPTION_DATA_LOSS_PROTECT and test it
       
       Diffstat:
         M electrum/lnbase.py                  |      50 +++++++++++++++++++-------------
         A electrum/tests/test_lnbase.py       |      66 +++++++++++++++++++++++++++++++
       
       2 files changed, 96 insertions(+), 20 deletions(-)
       ---
   DIR diff --git a/electrum/lnbase.py b/electrum/lnbase.py
       t@@ -201,6 +201,7 @@ class Peer(PrintError):
                self.peer_addr = peer_addr
                self.lnworker = lnworker
                self.privkey = lnworker.node_keypair.privkey
       +        self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)]
                self.network = lnworker.network
                self.lnwatcher = lnworker.network.lnwatcher
                self.channel_db = lnworker.network.channel_db
       t@@ -218,7 +219,7 @@ class Peer(PrintError):
                self.localfeatures = LnLocalFeatures(0)
                if request_initial_sync:
                    self.localfeatures |= LnLocalFeatures.INITIAL_ROUTING_SYNC
       -        self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
       +        self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
                self.attempted_route = {}
                self.orphan_channel_updates = OrderedDict()
        
       t@@ -234,7 +235,6 @@ class Peer(PrintError):
                    await transport.handshake()
                    self.transport = transport
                self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures)
       -        self.initialized.set_result(True)
        
            @property
            def channels(self) -> Dict[bytes, Channel]:
       t@@ -310,6 +310,7 @@ class Peer(PrintError):
                            raise LightningPeerConnectionClosed("remote does not have even flag {}"
                                                                .format(str(LnLocalFeatures(1 << flag))))
                        self.localfeatures ^= 1 << flag  # disable flag
       +        self.initialized.set_result(True)
        
            def on_channel_update(self, payload):
                try:
       t@@ -349,6 +350,13 @@ class Peer(PrintError):
            @log_exceptions
            @handle_disconnect
            async def main_loop(self):
       +        """
       +        This is used from the GUI. It is not merged with the other function,
       +        so that we can test if the correct exceptions are getting thrown.
       +        """
       +        await self._main_loop()
       +
       +    async def _main_loop(self):
                try:
                    await asyncio.wait_for(self.initialize(), 10)
                except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
       t@@ -757,16 +765,17 @@ class Peer(PrintError):
                if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h):
                    raise Exception("node_sig invalid in announcement_signatures")
        
       -        node_sigs = [local_node_sig, remote_node_sig]
       -        bitcoin_sigs = [local_bitcoin_sig, remote_bitcoin_sig]
       -        node_ids = [privkey_to_pubkey(self.privkey), self.peer_addr.pubkey]
       -        bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
       +        node_sigs = [remote_node_sig, local_node_sig]
       +        bitcoin_sigs = [remote_bitcoin_sig, local_bitcoin_sig]
       +        bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, chan.config[LOCAL].multisig_key.pubkey]
        
       -        if node_ids[0] > node_ids[1]:
       +        if self.node_ids[0] > self.node_ids[1]:
                    node_sigs.reverse()
                    bitcoin_sigs.reverse()
       -            node_ids.reverse()
       +            node_ids = list(reversed(self.node_ids))
                    bitcoin_keys.reverse()
       +        else:
       +            node_ids = self.node_ids
        
                self.send_message("channel_announcement",
                    node_signatures_1=node_sigs[0],
       t@@ -793,14 +802,13 @@ class Peer(PrintError):
                chan.set_state("OPEN")
                self.network.trigger_callback('channel', chan)
                # add channel to database
       -        pubkey_ours = self.lnworker.node_keypair.pubkey
       -        pubkey_theirs = self.peer_addr.pubkey
       -        node_ids = [pubkey_theirs, pubkey_ours]
                bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
       -        sorted_node_ids = list(sorted(node_ids))
       -        if sorted_node_ids != node_ids:
       +        sorted_node_ids = list(sorted(self.node_ids))
       +        if sorted_node_ids != self.node_ids:
                    node_ids = sorted_node_ids
                    bitcoin_keys.reverse()
       +        else:
       +            node_ids = self.node_ids
                # note: we inject a channel announcement, and a channel update (for outgoing direction)
                # This is atm needed for
                # - finding routes
       t@@ -813,7 +821,10 @@ class Peer(PrintError):
                                                         'bitcoin_key_1': bitcoin_keys[0], 'bitcoin_key_2': bitcoin_keys[1]},
                                                        trusted=True)
                # only inject outgoing direction:
       -        channel_flags = b'\x00' if node_ids[0] == pubkey_ours else b'\x01'
       +        if node_ids[0] == privkey_to_pubkey(self.privkey):
       +            channel_flags = b'\x00'
       +        else:
       +            channel_flags = b'\x01'
                now = int(time.time()).to_bytes(4, byteorder="big")
                self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'channel_flags': channel_flags, 'cltv_expiry_delta': b'\x90',
                                                   'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01',
       t@@ -832,16 +843,15 @@ class Peer(PrintError):
        
            def send_announcement_signatures(self, chan):
        
       -        bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey,
       -                        chan.config[REMOTE].multisig_key.pubkey]
       -
       -        node_ids = [privkey_to_pubkey(self.privkey),
       -                    self.peer_addr.pubkey]
       +        bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
       +                        chan.config[LOCAL].multisig_key.pubkey]
        
       -        sorted_node_ids = list(sorted(node_ids))
       +        sorted_node_ids = list(sorted(self.node_ids))
                if sorted_node_ids != node_ids:
                    node_ids = sorted_node_ids
                    bitcoin_keys.reverse()
       +        else:
       +            node_ids = self.node_ids
        
                chan_ann = gen_msg("channel_announcement",
                    len=0,
   DIR diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
       t@@ -0,0 +1,66 @@
       +from electrum.lnbase import Peer, decode_msg, gen_msg
       +from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
       +from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
       +from electrum.ecc import ECPrivkey
       +from electrum.lnrouter import ChannelDB
       +import unittest
       +import asyncio
       +from electrum import simple_config
       +import tempfile
       +from .test_lnchan import create_test_channels
       +
       +class MockNetwork:
       +    def __init__(self):
       +        self.lnwatcher = None
       +        user_config = {}
       +        user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
       +        self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
       +        self.asyncio_loop = asyncio.get_event_loop()
       +        self.channel_db = ChannelDB(self)
       +        self.interface = None
       +    def register_callback(self, cb, trigger_names):
       +        print("callback registered", repr(trigger_names))
       +    def trigger_callback(self, trigger_name, obj):
       +        print("callback triggered", repr(trigger_name))
       +
       +class MockLNWorker:
       +    def __init__(self, remote_peer_pubkey, chan):
       +        self.chan = chan
       +        self.remote_peer_pubkey = remote_peer_pubkey
       +        priv = ECPrivkey.generate_random_key().get_secret_bytes()
       +        self.node_keypair = Keypair(
       +                pubkey=privkey_to_pubkey(priv),
       +                privkey=priv)
       +        self.network = MockNetwork()
       +    @property
       +    def peers(self):
       +        return {self.remote_peer_pubkey: self.peer}
       +    def channels_for_peer(self, pubkey):
       +        return {self.chan.channel_id: self.chan}
       +
       +class MockTransport:
       +    def __init__(self):
       +        self.queue = asyncio.Queue()
       +    async def read_messages(self):
       +        while True:
       +            yield await self.queue.get()
       +
       +class BadFeaturesTransport(MockTransport):
       +    def send_bytes(self, data):
       +        decoded = decode_msg(data)
       +        print(decoded)
       +        if decoded[0] == 'init':
       +            self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
       +
       +class TestPeer(unittest.TestCase):
       +    def setUp(self):
       +        self.alice_channel, self.bob_channel = create_test_channels()
       +    def test_bad_feature_flags(self):
       +        # we should require DATA_LOSS_PROTECT
       +        mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel)
       +        mock_transport = BadFeaturesTransport()
       +        p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport)
       +        mock_lnworker.peer = p1
       +        with self.assertRaises(LightningPeerConnectionClosed):
       +            asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
       +