URI: 
       tmove force_close_channel to lnbase, test it, add FORCE_CLOSING state - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 0ea87278fb18a535c5ba35eb542ea7dfd5672f18
   DIR parent 6211e656a8cbdc9b4ab31b127e94bce6523f92d3
  HTML Author: Janus <ysangkok@gmail.com>
       Date:   Fri,  2 Nov 2018 19:16:42 +0100
       
       move force_close_channel to lnbase, test it, add FORCE_CLOSING state
       
       Diffstat:
         M electrum/lnbase.py                  |      21 ++++++++++++++++++++-
         M electrum/lnworker.py                |      56 +++++++++++++++----------------
         M electrum/tests/test_lnbase.py       |      80 +++++++++++++++++++++++++-------
         M electrum/tests/test_lnchan.py       |      12 +++++++++++-
       
       4 files changed, 122 insertions(+), 47 deletions(-)
       ---
   DIR diff --git a/electrum/lnbase.py b/electrum/lnbase.py
       t@@ -19,7 +19,7 @@ import aiorpcx
        from .crypto import sha256, sha256d
        from . import bitcoin
        from . import ecc
       -from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string
       +from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string
        from . import constants
        from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions
        from .transaction import Transaction, TxOutput
       t@@ -1158,6 +1158,25 @@ class Peer(PrintError):
                self.print_error('Channel closed', txid)
                return txid
        
       +    async def force_close_channel(self, chan_id):
       +        chan = self.channels[chan_id]
       +        # local_commitment always gives back the next expected local_commitment,
       +        # but in this case, we want the current one. So substract one ctn number
       +        old_local_state = chan.config[LOCAL]
       +        chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
       +        tx = chan.pending_local_commitment
       +        chan.config[LOCAL] = old_local_state
       +        tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
       +        remote_sig = chan.config[LOCAL].current_commitment_signature
       +        remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
       +        none_idx = tx._inputs[0]["signatures"].index(None)
       +        tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
       +        assert tx.is_complete()
       +        # TODO persist FORCE_CLOSING state to disk
       +        chan.set_state('FORCE_CLOSING')
       +        self.lnworker.save_channel(chan)
       +        return await self.network.broadcast_transaction(tx)
       +
            @log_exceptions
            async def on_shutdown(self, payload):
                # length of scripts allowed in BOLT-02
   DIR diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -11,6 +11,7 @@ from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
        import threading
        import socket
        import json
       +from decimal import Decimal
        
        import dns.resolver
        import dns.exception
       t@@ -267,18 +268,13 @@ class LNWorker(PrintError):
                return addr, peer, fut
        
            def _pay(self, invoice, amount_sat=None):
       -        addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
       -        payment_hash = addr.paymenthash
       -        amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
       -        if amount_sat is None:
       -            raise InvoiceError(_("Missing amount"))
       -        amount_msat = int(amount_sat * 1000)
       -        if addr.get_min_final_cltv_expiry() > 60 * 144:
       -            raise InvoiceError("{}\n{}".format(
       -                _("Invoice wants us to risk locking funds for unreasonably long."),
       -                f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
       -        route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
       -        node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
       +        addr = self._check_invoice(invoice, amount_sat)
       +        route = self._create_route_from_invoice(decoded_invoice=addr)
       +        peer = self.peers[route[0].node_id]
       +        return addr, peer, self._pay_to_route(route, addr)
       +
       +    async def _pay_to_route(self, route, addr):
       +        short_channel_id = route[0].short_channel_id
                with self.lock:
                    channels = list(self.channels.values())
                for chan in channels:
       t@@ -286,11 +282,24 @@ class LNWorker(PrintError):
                        break
                else:
                    raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
       -        peer = self.peers[node_id]
       -        coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
       -        return addr, peer, coro
       +        peer = self.peers[route[0].node_id]
       +        return await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry())
        
       -    def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
       +    @staticmethod
       +    def _check_invoice(invoice, amount_sat=None):
       +        addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
       +        if amount_sat:
       +            addr.amount = Decimal(amount_sat) / COIN
       +        if addr.amount is None:
       +            raise InvoiceError(_("Missing amount"))
       +        if addr.get_min_final_cltv_expiry() > 60 * 144:
       +            raise InvoiceError("{}\n{}".format(
       +                _("Invoice wants us to risk locking funds for unreasonably long."),
       +                f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
       +        return addr
       +
       +    def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
       +        amount_msat = int(decoded_invoice.amount * COIN * 1000)
                invoice_pubkey = decoded_invoice.pubkey.serialize()
                # use 'r' field from invoice
                route = None  # type: List[RouteEdge]
       t@@ -441,19 +450,8 @@ class LNWorker(PrintError):
        
            async def force_close_channel(self, chan_id):
                chan = self.channels[chan_id]
       -        # local_commitment always gives back the next expected local_commitment,
       -        # but in this case, we want the current one. So substract one ctn number
       -        old_local_state = chan.config[LOCAL]
       -        chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
       -        tx = chan.pending_local_commitment
       -        chan.config[LOCAL] = old_local_state
       -        tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
       -        remote_sig = chan.config[LOCAL].current_commitment_signature
       -        remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
       -        none_idx = tx._inputs[0]["signatures"].index(None)
       -        tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
       -        assert tx.is_complete()
       -        return await self.network.broadcast_transaction(tx)
       +        peer = self.peers[chan.node_id]
       +        return await peer.force_close_channel(chan_id)
        
            def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
                now = time.time()
   DIR diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
       t@@ -16,6 +16,7 @@ from electrum.util import bh2u
        from electrum.lnbase import Peer, decode_msg, gen_msg
        from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
        from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
       +from electrum.lnutil import PaymentFailure
        from electrum.lnrouter import ChannelDB, LNPathFinder
        from electrum.lnworker import LNWorker
        
       t@@ -33,7 +34,7 @@ def noop_lock():
            yield
        
        class MockNetwork:
       -    def __init__(self):
       +    def __init__(self, tx_queue):
                self.callbacks = defaultdict(list)
                self.lnwatcher = None
                user_config = {}
       t@@ -43,6 +44,7 @@ class MockNetwork:
                self.channel_db = ChannelDB(self)
                self.interface = None
                self.path_finder = LNPathFinder(self.channel_db)
       +        self.tx_queue = tx_queue
        
            @property
            def callback_lock(self):
       t@@ -55,12 +57,16 @@ class MockNetwork:
            def get_local_height(self):
                return 0
        
       +    async def broadcast_transaction(self, tx):
       +        if self.tx_queue:
       +            await self.tx_queue.put(tx)
       +
        class MockLNWorker:
       -    def __init__(self, remote_keypair, local_keypair, chan):
       +    def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
                self.chan = chan
                self.remote_keypair = remote_keypair
                self.node_keypair = local_keypair
       -        self.network = MockNetwork()
       +        self.network = MockNetwork(tx_queue)
                self.channels = {self.chan.channel_id: self.chan}
                self.invoices = {}
        
       t@@ -76,10 +82,12 @@ class MockLNWorker:
                return self.channels
        
            def save_channel(self, chan):
       -        pass
       +        print("Ignoring channel save")
        
            get_invoice = LNWorker.get_invoice
            _create_route_from_invoice = LNWorker._create_route_from_invoice
       +    _check_invoice = staticmethod(LNWorker._check_invoice)
       +    _pay_to_route = LNWorker._pay_to_route
        
        class MockTransport:
            def __init__(self):
       t@@ -120,18 +128,19 @@ class TestPeer(unittest.TestCase):
                self.alice_channel, self.bob_channel = create_test_channels()
        
            def test_require_data_loss_protect(self):
       -        mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
       +        mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
                mock_transport = NoFeaturesTransport()
                p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
                mock_lnworker.peer = p1
                with self.assertRaises(LightningPeerConnectionClosed):
       -            asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
       +            run(asyncio.wait_for(p1._main_loop(), 1))
        
       -    def test_payment(self):
       +    def prepare_peers(self):
                k1, k2 = keypair(), keypair()
                t1, t2 = transport_pair()
       -        w1 = MockLNWorker(k1, k2, self.alice_channel)
       -        w2 = MockLNWorker(k2, k1, self.bob_channel)
       +        q1, q2 = asyncio.Queue(), asyncio.Queue()
       +        w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
       +        w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
                p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
                        request_initial_sync=False, transport=t1)
                p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
       t@@ -145,6 +154,11 @@ class TestPeer(unittest.TestCase):
                # this populates the channel graph:
                p1.mark_open(self.alice_channel)
                p2.mark_open(self.bob_channel)
       +        return p1, p2, w1, w2, q1, q2
       +
       +    @staticmethod
       +    def prepare_invoice(w2 # receiver
       +            ):
                amount_btc = 100000/Decimal(COIN)
                payment_preimage = os.urandom(32)
                RHASH = sha256(payment_preimage)
       t@@ -156,13 +170,23 @@ class TestPeer(unittest.TestCase):
                                 ])
                pay_req = lnencode(addr, w2.node_keypair.privkey)
                w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
       -        l = asyncio.get_event_loop()
       -        async def pay():
       -            fut = asyncio.Future()
       -            def evt_set(event, _lnworker, msg):
       -                fut.set_result(msg)
       -            w2.network.register_callback(evt_set, ['ln_message'])
       +        return pay_req
       +
       +    @staticmethod
       +    def prepare_ln_message_future(w2 # receiver
       +            ):
       +        fut = asyncio.Future()
       +        def evt_set(event, _lnworker, msg):
       +            fut.set_result(msg)
       +        w2.network.register_callback(evt_set, ['ln_message'])
       +        return fut
       +
       +    def test_payment(self):
       +        p1, p2, w1, w2, _q1, _q2 = self.prepare_peers()
       +        pay_req = self.prepare_invoice(w2)
       +        fut = self.prepare_ln_message_future(w2)
        
       +        async def pay():
                    addr, peer, coro = LNWorker._pay(w1, pay_req)
                    await coro
                    print("HTLC ADDED")
       t@@ -170,4 +194,28 @@ class TestPeer(unittest.TestCase):
                    gath.cancel()
                gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
                with self.assertRaises(asyncio.CancelledError):
       -            l.run_until_complete(gath)
       +            run(gath)
       +
       +    def test_channel_usage_after_closing(self):
       +        p1, p2, w1, w2, q1, q2 = self.prepare_peers()
       +        pay_req = self.prepare_invoice(w2)
       +
       +        addr = w1._check_invoice(pay_req)
       +        route = w1._create_route_from_invoice(decoded_invoice=addr)
       +
       +        run(p1.force_close_channel(self.alice_channel.channel_id))
       +        # check if a tx (commitment transaction) was broadcasted:
       +        assert q1.qsize() == 1
       +
       +        with self.assertRaises(PaymentFailure) as e:
       +            w1._create_route_from_invoice(decoded_invoice=addr)
       +        self.assertEqual(str(e.exception), 'No path found')
       +
       +        peer = w1.peers[route[0].node_id]
       +        # AssertionError is ok since we shouldn't use old routes, and the
       +        # route finding should fail when channel is closed
       +        with self.assertRaises(AssertionError):
       +            run(asyncio.gather(w1._pay_to_route(route, addr), p1._main_loop(), p2._main_loop()))
       +
       +def run(coro):
       +    asyncio.get_event_loop().run_until_complete(coro)
   DIR diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py
       t@@ -29,6 +29,7 @@ from electrum import lnchan
        from electrum import lnutil
        from electrum import bip32 as bip32_utils
        from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
       +from electrum.ecc import sig_string_from_der_sig
        
        one_bitcoin_in_msat = bitcoin.COIN * 1000
        
       t@@ -81,7 +82,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
                        per_commitment_secret_seed=seed,
                        funding_locked_received=True,
                        was_announced=False,
       -                current_commitment_signature=None,
       +                # just a random signature
       +                current_commitment_signature=sig_string_from_der_sig(bytes.fromhex('3046022100c66e112e22b91b96b795a6dd5f4b004f3acccd9a2a31bf104840f256855b7aa3022100e711b868b62d87c7edd95a2370e496b9cb6a38aff13c9f64f9ff2f3b2a0052dd')),
                        current_htlc_signatures=None,
                    ),
                    "constraints":lnbase.ChannelConstraints(
       t@@ -185,6 +187,14 @@ class TestChannel(unittest.TestCase):
        
                self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
        
       +    def test_concurrent_reversed_payment(self):
       +        self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
       +        self.htlc_dict['amount_msat'] += 1000
       +        bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
       +        alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
       +        self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
       +        self.assertEquals(len(self.alice_channel.pending_remote_commitment.outputs()), 3)
       +
            def test_SimpleAddSettleWorkflow(self):
                alice_channel, bob_channel = self.alice_channel, self.bob_channel
                htlc = self.htlc