URI: 
       ttest_lnbase: add test that pays to another local electrum - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 7e76e821522eb0fcb6aef0fcbd9897d77606b382
   DIR parent ce2b572fa5a97a4686e92d9eaca7c3a95ece02ec
  HTML Author: Janus <ysangkok@gmail.com>
       Date:   Thu, 25 Oct 2018 21:59:16 +0200
       
       ttest_lnbase: add test that pays to another local electrum
       
       Diffstat:
         M electrum/lnbase.py                  |       5 +++++
         M electrum/lnworker.py                |      12 ++++++++++--
         M electrum/tests/test_lnbase.py       |     155 ++++++++++++++++++++++++++-----
       
       3 files changed, 146 insertions(+), 26 deletions(-)
       ---
   DIR diff --git a/electrum/lnbase.py b/electrum/lnbase.py
       t@@ -350,6 +350,11 @@ class Peer(PrintError):
            @log_exceptions
            @handle_disconnect
            async def main_loop(self):
       +        """
       +        This is used in LNWorker and is necessary so that we don't kill the main
       +        task group. It is not merged with _main_loop, so that we can test if the
       +        correct exceptions are getting thrown using _main_loop.
       +        """
                await self._main_loop()
        
            async def _main_loop(self):
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -32,7 +32,6 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
                             generate_keypair, LnKeyFamily, LOCAL, REMOTE,
                             UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
                             NUM_MAX_EDGES_IN_PAYMENT_PATH)
       -from .lnaddr import lndecode
        from .i18n import _
        from .lnrouter import RouteEdge, is_route_sane_to_use
        
       t@@ -258,6 +257,15 @@ class LNWorker(PrintError):
                return bh2u(chan.node_id)
        
            def pay(self, invoice, amount_sat=None):
       +        """
       +        This is not merged with _pay so that we can run the test with
       +        one thread only.
       +        """
       +        addr, peer, coro = self._pay(invoice, amount_sat)
       +        fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
       +        return addr, peer, fut
       +
       +    def _pay(self, invoice, amount_sat=None):
                addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
                payment_hash = addr.paymenthash
                amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
       t@@ -279,7 +287,7 @@ class LNWorker(PrintError):
                    raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
                peer = self.peers[node_id]
                coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
       -        return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
       +        return addr, peer, coro
        
            def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
                invoice_pubkey = decoded_invoice.pubkey.serialize()
   DIR diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
       t@@ -1,16 +1,40 @@
       -from electrum.lnbase import Peer, decode_msg, gen_msg
       -from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
       -from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
       -from electrum.ecc import ECPrivkey
       -from electrum.lnrouter import ChannelDB
        import unittest
        import asyncio
       -from electrum import simple_config
        import tempfile
       +from decimal import Decimal
       +import os
       +from contextlib import contextmanager
       +from collections import defaultdict
       +
       +from electrum.network import Network
       +from electrum.ecc import ECPrivkey
       +from electrum import simple_config, lnutil
       +from electrum.lnaddr import lnencode, LnAddr, lndecode
       +from electrum.bitcoin import COIN, sha256
       +from electrum.util import bh2u
       +
       +from electrum.lnbase import Peer, decode_msg, gen_msg
       +from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
       +from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
       +from electrum.lnrouter import ChannelDB, LNPathFinder
       +from electrum.lnworker import LNWorker
       +
        from .test_lnchan import create_test_channels
        
       +def keypair():
       +    priv = ECPrivkey.generate_random_key().get_secret_bytes()
       +    k1 = Keypair(
       +            pubkey=privkey_to_pubkey(priv),
       +            privkey=priv)
       +    return k1
       +
       +@contextmanager
       +def noop_lock():
       +    yield
       +
        class MockNetwork:
            def __init__(self):
       +        self.callbacks = defaultdict(list)
                self.lnwatcher = None
                user_config = {}
                user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
       t@@ -18,49 +42,132 @@ class MockNetwork:
                self.asyncio_loop = asyncio.get_event_loop()
                self.channel_db = ChannelDB(self)
                self.interface = None
       -    def register_callback(self, cb, trigger_names):
       -        print("callback registered", repr(trigger_names))
       -    def trigger_callback(self, trigger_name, obj):
       -        print("callback triggered", repr(trigger_name))
       +        self.path_finder = LNPathFinder(self.channel_db)
       +
       +    @property
       +    def callback_lock(self):
       +        return noop_lock()
       +
       +    register_callback = Network.register_callback
       +    unregister_callback = Network.unregister_callback
       +    trigger_callback = Network.trigger_callback
       +
       +    def get_local_height(self):
       +        return 0
        
        class MockLNWorker:
       -    def __init__(self, remote_peer_pubkey, chan):
       +    def __init__(self, remote_keypair, local_keypair, chan):
                self.chan = chan
       -        self.remote_peer_pubkey = remote_peer_pubkey
       -        priv = ECPrivkey.generate_random_key().get_secret_bytes()
       -        self.node_keypair = Keypair(
       -                pubkey=privkey_to_pubkey(priv),
       -                privkey=priv)
       +        self.remote_keypair = remote_keypair
       +        self.node_keypair = local_keypair
                self.network = MockNetwork()
       +        self.channels = {self.chan.channel_id: self.chan}
       +        self.invoices = {}
       +
       +    @property
       +    def lock(self):
       +        return noop_lock()
       +
            @property
            def peers(self):
       -        return {self.remote_peer_pubkey: self.peer}
       +        return {self.remote_keypair.pubkey: self.peer}
       +
            def channels_for_peer(self, pubkey):
       -        return {self.chan.channel_id: self.chan}
       +        return self.channels
       +
       +    def save_channel(self, chan):
       +        pass
       +
       +    get_invoice = LNWorker.get_invoice
       +    _create_route_from_invoice = LNWorker._create_route_from_invoice
        
        class MockTransport:
            def __init__(self):
                self.queue = asyncio.Queue()
       +
            async def read_messages(self):
                while True:
                    yield await self.queue.get()
        
       -class BadFeaturesTransport(MockTransport):
       +class NoFeaturesTransport(MockTransport):
       +    """
       +    This answers the init message with a init that doesn't signal any features.
       +    Used for testing that we require DATA_LOSS_PROTECT.
       +    """
            def send_bytes(self, data):
                decoded = decode_msg(data)
                print(decoded)
                if decoded[0] == 'init':
                    self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
        
       +class PutIntoOthersQueueTransport(MockTransport):
       +    def __init__(self):
       +        super().__init__()
       +        self.other_mock_transport = None
       +
       +    def send_bytes(self, data):
       +        self.other_mock_transport.queue.put_nowait(data)
       +
       +def transport_pair():
       +    t1 = PutIntoOthersQueueTransport()
       +    t2 = PutIntoOthersQueueTransport()
       +    t1.other_mock_transport = t2
       +    t2.other_mock_transport = t1
       +    return t1, t2
       +
        class TestPeer(unittest.TestCase):
            def setUp(self):
                self.alice_channel, self.bob_channel = create_test_channels()
       -    def test_bad_feature_flags(self):
       -        # we should require DATA_LOSS_PROTECT
       -        mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel)
       -        mock_transport = BadFeaturesTransport()
       -        p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport)
       +
       +    def test_require_data_loss_protect(self):
       +        mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
       +        mock_transport = NoFeaturesTransport()
       +        p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
                mock_lnworker.peer = p1
                with self.assertRaises(LightningPeerConnectionClosed):
                    asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
        
       +    def test_payment(self):
       +        k1, k2 = keypair(), keypair()
       +        t1, t2 = transport_pair()
       +        w1 = MockLNWorker(k1, k2, self.alice_channel)
       +        w2 = MockLNWorker(k2, k1, self.bob_channel)
       +        p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
       +                request_initial_sync=False, transport=t1)
       +        p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
       +                request_initial_sync=False, transport=t2)
       +        w1.peer = p1
       +        w2.peer = p2
       +        # mark_open won't work if state is already OPEN.
       +        # so set it to OPENING
       +        self.alice_channel.set_state("OPENING")
       +        self.bob_channel.set_state("OPENING")
       +        # this populates the channel graph:
       +        p1.mark_open(self.alice_channel)
       +        p2.mark_open(self.bob_channel)
       +        amount_btc = 100000/Decimal(COIN)
       +        payment_preimage = os.urandom(32)
       +        RHASH = sha256(payment_preimage)
       +        addr = LnAddr(
       +                    RHASH,
       +                    amount_btc,
       +                    tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
       +                          ('d', 'coffee')
       +                         ])
       +        pay_req = lnencode(addr, w2.node_keypair.privkey)
       +        w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
       +        l = asyncio.get_event_loop()
       +        async def pay():
       +            fut = asyncio.Future()
       +            def evt_set(event, _lnworker, msg):
       +                fut.set_result(msg)
       +            w2.network.register_callback(evt_set, ['ln_message'])
       +
       +            addr, peer, coro = LNWorker._pay(w1, pay_req)
       +            await coro
       +            print("HTLC ADDED")
       +            self.assertEqual(await fut, 'Payment received')
       +            gath.cancel()
       +        gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
       +        with self.assertRaises(asyncio.CancelledError):
       +            l.run_until_complete(gath)