URI: 
       tprotocol: Do task cleanup when peer disconnects. - obelisk - Electrum server using libbitcoin as its backend
  HTML git clone https://git.parazyd.org/obelisk
   DIR Log
   DIR Files
   DIR Refs
   DIR README
   DIR LICENSE
       ---
   DIR commit f6449ea78a20d6ef3d62d7dc00de34ec05bbfc10
   DIR parent 0c8ef25aea1b5e0bab605b950f77d93279861c9b
  HTML Author: parazyd <parazyd@dyne.org>
       Date:   Mon, 19 Apr 2021 17:59:10 +0200
       
       protocol: Do task cleanup when peer disconnects.
       
       Diffstat:
         M obelisk/protocol.py                 |     135 ++++++++++++++++++-------------
         M obelisk/zeromq.py                   |      10 +++++-----
         M tests/test_electrum_protocol.py     |      12 ++++++++++++
       
       3 files changed, 94 insertions(+), 63 deletions(-)
       ---
   DIR diff --git a/obelisk/protocol.py b/obelisk/protocol.py
       t@@ -65,13 +65,9 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
                self.endpoints = endpoints
                self.server_cfg = server_cfg
                self.loop = asyncio.get_event_loop()
       -        self.chain_tip = 0
       -        # Consider renaming bx to something else
                self.bx = Client(log, endpoints, self.loop)
                self.block_queue = None
       -        # TODO: Clean up on client disconnect
       -        self.tasks = []
       -        self.sh_subscriptions = {}
       +        self.peers = {}
        
                if chain == "mainnet":  # pragma: no cover
                    self.genesis = "000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f"
       t@@ -112,28 +108,32 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
                self.log.debug("ElectrumProtocol.stop()")
                self.stopped = True
                if self.bx:
       -            # unsub_pool = []
       -            # for i in self.sh_subscriptions:  # pragma: no cover
       -            # self.log.debug("bx.unsubscribe %s", i)
       -            # unsub_pool.append(self.bx.unsubscribe_scripthash(i))
       -            # await asyncio.gather(*unsub_pool, return_exceptions=True)
       +            for i in self.peers:
       +                await self._peer_cleanup(i)
                    await self.bx.stop()
        
       -        # idxs = []
       -        # for task in self.tasks:
       -        # idxs.append(self.tasks.index(task))
       -        # task.cancel()
       -        # for i in idxs:
       -        # del self.tasks[i]
       +    async def _peer_cleanup(self, peer):
       +        """Cleanup tasks and data for peer"""
       +        self.log.debug("Cleaning up data for %s", peer)
       +        for i in self.peers[peer]["tasks"]:
       +            i.cancel()
       +        for i in self.peers[peer]["sh"]:
       +            self.peers[peer]["sh"][i]["task"].cancel()
       +
       +    @staticmethod
       +    def _get_peer(writer):
       +        peer_t = writer._transport.get_extra_info("peername")  # pylint: disable=W0212
       +        return f"{peer_t[0]}:{peer_t[1]}"
        
            async def recv(self, reader, writer):
                """Loop ran upon a connection which acts as a JSON-RPC handler"""
                recv_buf = bytearray()
       +        self.peers[self._get_peer(writer)] = {"tasks": [], "sh": {}}
       +
                while not self.stopped:
                    data = await reader.read(4096)
                    if not data or len(data) == 0:
       -                self.log.debug("Received EOF, disconnect")
       -                # TODO: cancel asyncio tasks for this client here?
       +                await self._peer_cleanup(self._get_peer(writer))
                        return
                    recv_buf.extend(data)
                    lb = recv_buf.find(b"\n")
       t@@ -181,12 +181,7 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
        
            async def handle_query(self, writer, query):  # pylint: disable=R0915,R0912,R0911
                """Electrum protocol method handler mapper"""
       -        if "method" not in query:
       -            self.log.debug("No 'method' in query: %s", query)
       -            return await self._send_reply(writer, JsonRPCError.invalidrequest(),
       -                                          None)
       -        if "id" not in query:
       -            self.log.debug("No 'id' in query: %s", query)
       +        if "method" not in query or "id" not in query:
                    return await self._send_reply(writer, JsonRPCError.invalidrequest(),
                                                  None)
        
       t@@ -304,13 +299,11 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
                self.block_queue = asyncio.Queue()
                await self.bx.subscribe_to_blocks(self.block_queue)
                while True:
       -            # item = (seq, height, block_data)
                    item = await self.block_queue.get()
                    if len(item) != 3:
                        self.log.debug("error: item from block queue len != 3")
                        continue
        
       -            self.chain_tip = item[1]
                    header = block_to_header(item[2])
                    params = [{"height": item[1], "hex": safe_hexlify(header)}]
                    await self._send_notification(writer,
       t@@ -331,8 +324,8 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
                    self.log.debug("Got error: %s", repr(_ec))
                    return JsonRPCError.internalerror()
        
       -        self.chain_tip = height
       -        self.tasks.append(asyncio.create_task(self.header_notifier(writer)))
       +        self.peers[self._get_peer(writer)]["tasks"].append(
       +            asyncio.create_task(self.header_notifier(writer)))
                ret = {"height": height, "hex": safe_hexlify(tip_header)}
                return {"result": ret}
        
       t@@ -428,32 +421,56 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
        
                return {"result": ret}
        
       +    async def scripthash_renewer(self, scripthash, queue):
       +        while True:
       +            try:
       +                self.log.debug("scriphash renewer: %s", scripthash)
       +                _ec = await self.bx.subscribe_scripthash(scripthash, queue)
       +                if _ec and _ec != 0:
       +                    self.log.error("bx.subscribe_scripthash failed: %s",
       +                                   repr(_ec))
       +                await asyncio.sleep(60)
       +            except asyncio.CancelledError:
       +                self.log.debug("%s renewer cancelled", scripthash)
       +                break
       +
            async def scripthash_notifier(self, writer, scripthash):
                # TODO: Mempool
       +        # TODO: This is still flaky and not always notified. Investigate.
       +        self.log.debug("notifier")
                method = "blockchain.scripthash.subscribe"
       -        while True:
       -            _ec, sh_queue = await self.bx.subscribe_scripthash(scripthash)
       -            if _ec and _ec != 0:
       -                self.log.error("bx.subscribe_scripthash failed: %s", repr(_ec))
       -                return
       -
       -            item = await sh_queue.get()
       -            _ec, height, txid = struct.unpack("<HI32s", item)
       +        queue = asyncio.Queue()
       +        renew_task = asyncio.create_task(
       +            self.scripthash_renewer(scripthash, queue))
        
       -            if (_ec == ErrorCode.service_stopped.value and height == 0 and
       -                    not self.stopped):
       -                # Subscription expired
       -                continue
       -
       -            self.sh_subscriptions[scripthash]["status"].append(
       -                (hash_to_hex_str(txid), height))
       -
       -            params = [
       -                scripthash,
       -                ElectrumProtocol.__scripthash_status_encode(
       -                    self.sh_subscriptions[scripthash]["status"]),
       -            ]
       -            await self._send_notification(writer, method, params)
       +        while True:
       +            try:
       +                item = await queue.get()
       +                _ec, height, txid = struct.unpack("<HI32s", item)
       +
       +                if (_ec == ErrorCode.service_stopped.value and height == 0 and
       +                        not self.stopped):
       +                    self.log.debug("subscription expired: %s", scripthash)
       +                    # Subscription expired
       +                    continue
       +
       +                self.peers[self._get_peer(writer)]["sh"]["status"].append(
       +                    (hash_to_hex_str(txid), height))
       +
       +                self.log.debug("shnotifier: Got _ec: %d", _ec)
       +                self.log.debug("shnotifier: Got height: %d", height)
       +                self.log.debug("shnotifier: Got txid: %s",
       +                               hash_to_hex_str(txid))
       +
       +                params = [
       +                    scripthash,
       +                    ElectrumProtocol.__scripthash_status_encode(self.peers[
       +                        self._get_peer(writer)]["sh"]["scripthash"]["status"]),
       +                ]
       +                await self._send_notification(writer, method, params)
       +            except asyncio.CancelledError:
       +                break
       +        renew_task.cancel()
        
            async def scripthash_subscribe(self, writer, query):  # pylint: disable=W0613
                """Method: blockchain.scripthash.subscribe
       t@@ -470,16 +487,17 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
                if _ec and _ec != 0:
                    return JsonRPCError.internalerror()
        
       -        if len(history) < 1:
       -            return {"result": None}
       -
                # TODO: Check how history4 acts for mempool/unconfirmed
                status = ElectrumProtocol.__scripthash_status_from_history(history)
       -        self.sh_subscriptions[scripthash] = {"status": status}
       +        self.peers[self._get_peer(writer)]["sh"][scripthash] = {
       +            "status": status
       +        }
        
                task = asyncio.create_task(self.scripthash_notifier(writer, scripthash))
       -        self.sh_subscriptions[scripthash]["task"] = task
       +        self.peers[self._get_peer(writer)]["sh"][scripthash]["task"] = task
        
       +        if len(history) < 1:
       +            return {"result": None}
                return {"result": ElectrumProtocol.__scripthash_status_encode(status)}
        
            @staticmethod
       t@@ -517,10 +535,11 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
                if not is_hash256_str(scripthash):
                    return JsonRPCError.invalidparams()
        
       -        if scripthash in self.sh_subscriptions:
       -            self.sh_subscriptions[scripthash]["task"].cancel()
       +        if scripthash in self.peers[self._get_peer(writer)]["sh"]:
       +            self.peers[self._get_peer(
       +                writer)]["sh"][scripthash]["task"].cancel()
                    # await self.bx.unsubscribe_scripthash(scripthash)
       -            del self.sh_subscriptions[scripthash]
       +            del self.peers[self._get_peer(writer)]["sh"][scripthash]
                    return {"result": True}
        
                return {"result": False}
   DIR diff --git a/obelisk/zeromq.py b/obelisk/zeromq.py
       t@@ -266,11 +266,11 @@ class Client:
                socket.connect(self._endpoints["query"])
                return socket
        
       -    async def _subscription_request(self, command, data):
       +    async def _subscription_request(self, command, data, queue):
                request = await self._request(command, data)
       -        request.queue = asyncio.Queue()
       +        request.queue = queue
                error_code, _ = await self._wait_for_response(request)
       -        return error_code, request.queue
       +        return error_code
        
            async def _simple_request(self, command, data):
                return await self._wait_for_response(await self._request(command, data))
       t@@ -345,11 +345,11 @@ class Client:
                    return error_code, None
                return error_code, data
        
       -    async def subscribe_scripthash(self, scripthash):
       +    async def subscribe_scripthash(self, scripthash, queue):
                """Subscribe to scripthash"""
                command = b"subscribe.key"
                decoded_address = unhexlify(scripthash)
       -        return await self._subscription_request(command, decoded_address)
       +        return await self._subscription_request(command, decoded_address, queue)
        
            async def unsubscribe_scripthash(self, scripthash):
                """Unsubscribe scripthash"""
   DIR diff --git a/tests/test_electrum_protocol.py b/tests/test_electrum_protocol.py
       t@@ -399,11 +399,21 @@ async def test_send_reply(protocol, writer, method):
            assert_equal(writer.mock, expect)
        
        
       +class MockTransport:
       +
       +    def __init__(self):
       +        self.peername = ("foo", 42)
       +
       +    def get_extra_info(self, param):
       +        return self.peername
       +
       +
        class MockWriter(asyncio.StreamWriter):  # pragma: no cover
            """Mock class for StreamWriter"""
        
            def __init__(self):
                self.mock = None
       +        self._transport = MockTransport()
        
            def write(self, data):
                self.mock = data
       t@@ -455,6 +465,8 @@ async def main():
            protocol = ElectrumProtocol(log, "testnet", libbitcoin, {})
            writer = MockWriter()
        
       +    protocol.peers[protocol._get_peer(writer)] = {"tasks": [], "sh": {}}
       +
            for func in orchestration:
                try:
                    await orchestration[func](protocol, writer, func)