tprotocol: Do task cleanup when peer disconnects. - obelisk - Electrum server using libbitcoin as its backend HTML git clone https://git.parazyd.org/obelisk DIR Log DIR Files DIR Refs DIR README DIR LICENSE --- DIR commit f6449ea78a20d6ef3d62d7dc00de34ec05bbfc10 DIR parent 0c8ef25aea1b5e0bab605b950f77d93279861c9b HTML Author: parazyd <parazyd@dyne.org> Date: Mon, 19 Apr 2021 17:59:10 +0200 protocol: Do task cleanup when peer disconnects. Diffstat: M obelisk/protocol.py | 135 ++++++++++++++++++------------- M obelisk/zeromq.py | 10 +++++----- M tests/test_electrum_protocol.py | 12 ++++++++++++ 3 files changed, 94 insertions(+), 63 deletions(-) --- DIR diff --git a/obelisk/protocol.py b/obelisk/protocol.py t@@ -65,13 +65,9 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902 self.endpoints = endpoints self.server_cfg = server_cfg self.loop = asyncio.get_event_loop() - self.chain_tip = 0 - # Consider renaming bx to something else self.bx = Client(log, endpoints, self.loop) self.block_queue = None - # TODO: Clean up on client disconnect - self.tasks = [] - self.sh_subscriptions = {} + self.peers = {} if chain == "mainnet": # pragma: no cover self.genesis = "000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f" t@@ -112,28 +108,32 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902 self.log.debug("ElectrumProtocol.stop()") self.stopped = True if self.bx: - # unsub_pool = [] - # for i in self.sh_subscriptions: # pragma: no cover - # self.log.debug("bx.unsubscribe %s", i) - # unsub_pool.append(self.bx.unsubscribe_scripthash(i)) - # await asyncio.gather(*unsub_pool, return_exceptions=True) + for i in self.peers: + await self._peer_cleanup(i) await self.bx.stop() - # idxs = [] - # for task in self.tasks: - # idxs.append(self.tasks.index(task)) - # task.cancel() - # for i in idxs: - # del self.tasks[i] + async def _peer_cleanup(self, peer): + """Cleanup tasks and data for peer""" + self.log.debug("Cleaning up data for %s", peer) + for i in self.peers[peer]["tasks"]: + i.cancel() + for i in self.peers[peer]["sh"]: + self.peers[peer]["sh"][i]["task"].cancel() + + @staticmethod + def _get_peer(writer): + peer_t = writer._transport.get_extra_info("peername") # pylint: disable=W0212 + return f"{peer_t[0]}:{peer_t[1]}" async def recv(self, reader, writer): """Loop ran upon a connection which acts as a JSON-RPC handler""" recv_buf = bytearray() + self.peers[self._get_peer(writer)] = {"tasks": [], "sh": {}} + while not self.stopped: data = await reader.read(4096) if not data or len(data) == 0: - self.log.debug("Received EOF, disconnect") - # TODO: cancel asyncio tasks for this client here? + await self._peer_cleanup(self._get_peer(writer)) return recv_buf.extend(data) lb = recv_buf.find(b"\n") t@@ -181,12 +181,7 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902 async def handle_query(self, writer, query): # pylint: disable=R0915,R0912,R0911 """Electrum protocol method handler mapper""" - if "method" not in query: - self.log.debug("No 'method' in query: %s", query) - return await self._send_reply(writer, JsonRPCError.invalidrequest(), - None) - if "id" not in query: - self.log.debug("No 'id' in query: %s", query) + if "method" not in query or "id" not in query: return await self._send_reply(writer, JsonRPCError.invalidrequest(), None) t@@ -304,13 +299,11 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902 self.block_queue = asyncio.Queue() await self.bx.subscribe_to_blocks(self.block_queue) while True: - # item = (seq, height, block_data) item = await self.block_queue.get() if len(item) != 3: self.log.debug("error: item from block queue len != 3") continue - self.chain_tip = item[1] header = block_to_header(item[2]) params = [{"height": item[1], "hex": safe_hexlify(header)}] await self._send_notification(writer, t@@ -331,8 +324,8 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902 self.log.debug("Got error: %s", repr(_ec)) return JsonRPCError.internalerror() - self.chain_tip = height - self.tasks.append(asyncio.create_task(self.header_notifier(writer))) + self.peers[self._get_peer(writer)]["tasks"].append( + asyncio.create_task(self.header_notifier(writer))) ret = {"height": height, "hex": safe_hexlify(tip_header)} return {"result": ret} t@@ -428,32 +421,56 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902 return {"result": ret} + async def scripthash_renewer(self, scripthash, queue): + while True: + try: + self.log.debug("scriphash renewer: %s", scripthash) + _ec = await self.bx.subscribe_scripthash(scripthash, queue) + if _ec and _ec != 0: + self.log.error("bx.subscribe_scripthash failed: %s", + repr(_ec)) + await asyncio.sleep(60) + except asyncio.CancelledError: + self.log.debug("%s renewer cancelled", scripthash) + break + async def scripthash_notifier(self, writer, scripthash): # TODO: Mempool + # TODO: This is still flaky and not always notified. Investigate. + self.log.debug("notifier") method = "blockchain.scripthash.subscribe" - while True: - _ec, sh_queue = await self.bx.subscribe_scripthash(scripthash) - if _ec and _ec != 0: - self.log.error("bx.subscribe_scripthash failed: %s", repr(_ec)) - return - - item = await sh_queue.get() - _ec, height, txid = struct.unpack("<HI32s", item) + queue = asyncio.Queue() + renew_task = asyncio.create_task( + self.scripthash_renewer(scripthash, queue)) - if (_ec == ErrorCode.service_stopped.value and height == 0 and - not self.stopped): - # Subscription expired - continue - - self.sh_subscriptions[scripthash]["status"].append( - (hash_to_hex_str(txid), height)) - - params = [ - scripthash, - ElectrumProtocol.__scripthash_status_encode( - self.sh_subscriptions[scripthash]["status"]), - ] - await self._send_notification(writer, method, params) + while True: + try: + item = await queue.get() + _ec, height, txid = struct.unpack("<HI32s", item) + + if (_ec == ErrorCode.service_stopped.value and height == 0 and + not self.stopped): + self.log.debug("subscription expired: %s", scripthash) + # Subscription expired + continue + + self.peers[self._get_peer(writer)]["sh"]["status"].append( + (hash_to_hex_str(txid), height)) + + self.log.debug("shnotifier: Got _ec: %d", _ec) + self.log.debug("shnotifier: Got height: %d", height) + self.log.debug("shnotifier: Got txid: %s", + hash_to_hex_str(txid)) + + params = [ + scripthash, + ElectrumProtocol.__scripthash_status_encode(self.peers[ + self._get_peer(writer)]["sh"]["scripthash"]["status"]), + ] + await self._send_notification(writer, method, params) + except asyncio.CancelledError: + break + renew_task.cancel() async def scripthash_subscribe(self, writer, query): # pylint: disable=W0613 """Method: blockchain.scripthash.subscribe t@@ -470,16 +487,17 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902 if _ec and _ec != 0: return JsonRPCError.internalerror() - if len(history) < 1: - return {"result": None} - # TODO: Check how history4 acts for mempool/unconfirmed status = ElectrumProtocol.__scripthash_status_from_history(history) - self.sh_subscriptions[scripthash] = {"status": status} + self.peers[self._get_peer(writer)]["sh"][scripthash] = { + "status": status + } task = asyncio.create_task(self.scripthash_notifier(writer, scripthash)) - self.sh_subscriptions[scripthash]["task"] = task + self.peers[self._get_peer(writer)]["sh"][scripthash]["task"] = task + if len(history) < 1: + return {"result": None} return {"result": ElectrumProtocol.__scripthash_status_encode(status)} @staticmethod t@@ -517,10 +535,11 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902 if not is_hash256_str(scripthash): return JsonRPCError.invalidparams() - if scripthash in self.sh_subscriptions: - self.sh_subscriptions[scripthash]["task"].cancel() + if scripthash in self.peers[self._get_peer(writer)]["sh"]: + self.peers[self._get_peer( + writer)]["sh"][scripthash]["task"].cancel() # await self.bx.unsubscribe_scripthash(scripthash) - del self.sh_subscriptions[scripthash] + del self.peers[self._get_peer(writer)]["sh"][scripthash] return {"result": True} return {"result": False} DIR diff --git a/obelisk/zeromq.py b/obelisk/zeromq.py t@@ -266,11 +266,11 @@ class Client: socket.connect(self._endpoints["query"]) return socket - async def _subscription_request(self, command, data): + async def _subscription_request(self, command, data, queue): request = await self._request(command, data) - request.queue = asyncio.Queue() + request.queue = queue error_code, _ = await self._wait_for_response(request) - return error_code, request.queue + return error_code async def _simple_request(self, command, data): return await self._wait_for_response(await self._request(command, data)) t@@ -345,11 +345,11 @@ class Client: return error_code, None return error_code, data - async def subscribe_scripthash(self, scripthash): + async def subscribe_scripthash(self, scripthash, queue): """Subscribe to scripthash""" command = b"subscribe.key" decoded_address = unhexlify(scripthash) - return await self._subscription_request(command, decoded_address) + return await self._subscription_request(command, decoded_address, queue) async def unsubscribe_scripthash(self, scripthash): """Unsubscribe scripthash""" DIR diff --git a/tests/test_electrum_protocol.py b/tests/test_electrum_protocol.py t@@ -399,11 +399,21 @@ async def test_send_reply(protocol, writer, method): assert_equal(writer.mock, expect) +class MockTransport: + + def __init__(self): + self.peername = ("foo", 42) + + def get_extra_info(self, param): + return self.peername + + class MockWriter(asyncio.StreamWriter): # pragma: no cover """Mock class for StreamWriter""" def __init__(self): self.mock = None + self._transport = MockTransport() def write(self, data): self.mock = data t@@ -455,6 +465,8 @@ async def main(): protocol = ElectrumProtocol(log, "testnet", libbitcoin, {}) writer = MockWriter() + protocol.peers[protocol._get_peer(writer)] = {"tasks": [], "sh": {}} + for func in orchestration: try: await orchestration[func](protocol, writer, func)