URI: 
       tMerge pull request #6315 from SomberNight/202007_interface_check_server_response - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 8d7370d897314d8542906aecc6a45cc949651f77
   DIR parent 3393ff757e05af9df8fcfc6d56caa8b72c474fda
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Thu,  2 Jul 2020 18:00:21 +0200
       
       Merge pull request #6315 from SomberNight/202007_interface_check_server_response
       
       interface: check server response for some methods
       Diffstat:
         M electrum/interface.py               |     164 ++++++++++++++++++++++++++++++-
         M electrum/network.py                 |      48 +++++++++----------------------
         M electrum/synchronizer.py            |       7 ++-----
         M electrum/util.py                    |      24 ++++++++++++++++++++++++
       
       4 files changed, 200 insertions(+), 43 deletions(-)
       ---
   DIR diff --git a/electrum/interface.py b/electrum/interface.py
       t@@ -29,7 +29,7 @@ import sys
        import traceback
        import asyncio
        import socket
       -from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
       +from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any
        from collections import defaultdict
        from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
        import itertools
       t@@ -44,16 +44,19 @@ from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
        from aiorpcx.rawsocket import RSClient
        import certifi
        
       -from .util import ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy
       +from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
       +                   is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
       +                   is_real_number)
        from . import util
        from . import x509
        from . import pem
        from . import version
        from . import blockchain
       -from .blockchain import Blockchain
       +from .blockchain import Blockchain, HEADER_SIZE
        from . import constants
        from .i18n import _
        from .logging import Logger
       +from .transaction import Transaction
        
        if TYPE_CHECKING:
            from .network import Network
       t@@ -82,6 +85,45 @@ class NetworkTimeout:
                RELAXED = 20
                MOST_RELAXED = 60
        
       +
       +def assert_non_negative_integer(val: Any) -> None:
       +    if not is_non_negative_integer(val):
       +        raise RequestCorrupted(f'{val!r} should be a non-negative integer')
       +
       +
       +def assert_integer(val: Any) -> None:
       +    if not is_integer(val):
       +        raise RequestCorrupted(f'{val!r} should be an integer')
       +
       +
       +def assert_real_number(val: Any, *, as_str: bool = False) -> None:
       +    if not is_real_number(val, as_str=as_str):
       +        raise RequestCorrupted(f'{val!r} should be a number')
       +
       +
       +def assert_hash256_str(val: Any) -> None:
       +    if not is_hash256_str(val):
       +        raise RequestCorrupted(f'{val!r} should be a hash256 str')
       +
       +
       +def assert_hex_str(val: Any) -> None:
       +    if not is_hex_str(val):
       +        raise RequestCorrupted(f'{val!r} should be a hex str')
       +
       +
       +def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
       +    if not isinstance(d, dict):
       +        raise RequestCorrupted(f'{d!r} should be a dict')
       +    if field_name not in d:
       +        raise RequestCorrupted(f'required field {field_name!r} missing from dict')
       +    return d[field_name]
       +
       +
       +def assert_list_or_tuple(val: Any) -> None:
       +    if not isinstance(val, (list, tuple)):
       +        raise RequestCorrupted(f'{val!r} should be a list or tuple')
       +
       +
        class NotificationSession(RPCSession):
        
            def __init__(self, *args, **kwargs):
       t@@ -187,7 +229,7 @@ class RequestTimedOut(GracefulDisconnect):
                return _("Network request timed out.")
        
        
       -class RequestCorrupted(GracefulDisconnect): pass
       +class RequestCorrupted(Exception): pass
        
        class ErrorParsingSSLCert(Exception): pass
        class ErrorGettingSSLCertFromServer(Exception): pass
       t@@ -529,6 +571,8 @@ class Interface(Logger):
                return blockchain.deserialize_header(bytes.fromhex(res), height)
        
            async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
       +        if not is_non_negative_integer(height):
       +            raise Exception(f"{repr(height)} is not a block height")
                index = height // 2016
                if can_return_early and index in self._requested_chunks:
                    return
       t@@ -542,6 +586,16 @@ class Interface(Logger):
                    res = await self.session.send_request('blockchain.block.headers', [index * 2016, size])
                finally:
                    self._requested_chunks.discard(index)
       +        assert_dict_contains_field(res, field_name='count')
       +        assert_dict_contains_field(res, field_name='hex')
       +        assert_dict_contains_field(res, field_name='max')
       +        assert_non_negative_integer(res['count'])
       +        assert_non_negative_integer(res['max'])
       +        assert_hex_str(res['hex'])
       +        if len(res['hex']) != HEADER_SIZE * 2 * res['count']:
       +            raise RequestCorrupted('inconsistent chunk hex and count')
       +        if res['count'] != size:
       +            raise RequestCorrupted(f"expected {size} headers but only got {res['count']}")
                conn = self.blockchain.connect_chunk(index, res['hex'])
                if not conn:
                    return conn, 0
       t@@ -819,6 +873,108 @@ class Interface(Logger):
                    self._ipaddr_bucket = do_bucket()
                return self._ipaddr_bucket
        
       +    async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
       +        if not is_hash256_str(tx_hash):
       +            raise Exception(f"{repr(tx_hash)} is not a txid")
       +        if not is_non_negative_integer(tx_height):
       +            raise Exception(f"{repr(tx_height)} is not a block height")
       +        # do request
       +        res = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
       +        # check response
       +        block_height = assert_dict_contains_field(res, field_name='block_height')
       +        merkle = assert_dict_contains_field(res, field_name='merkle')
       +        pos = assert_dict_contains_field(res, field_name='pos')
       +        # note: tx_height was just a hint to the server, don't enforce the response to match it
       +        assert_non_negative_integer(block_height)
       +        assert_non_negative_integer(pos)
       +        assert_list_or_tuple(merkle)
       +        for item in merkle:
       +            assert_hash256_str(item)
       +        return res
       +
       +    async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
       +        if not is_hash256_str(tx_hash):
       +            raise Exception(f"{repr(tx_hash)} is not a txid")
       +        raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
       +        # validate response
       +        tx = Transaction(raw)
       +        try:
       +            tx.deserialize()  # see if raises
       +        except Exception as e:
       +            raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e
       +        if tx.txid() != tx_hash:
       +            raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})")
       +        return raw
       +
       +    async def get_history_for_scripthash(self, sh: str) -> List[dict]:
       +        if not is_hash256_str(sh):
       +            raise Exception(f"{repr(sh)} is not a scripthash")
       +        # do request
       +        res = await self.session.send_request('blockchain.scripthash.get_history', [sh])
       +        # check response
       +        assert_list_or_tuple(res)
       +        for tx_item in res:
       +            assert_dict_contains_field(tx_item, field_name='height')
       +            assert_dict_contains_field(tx_item, field_name='tx_hash')
       +            assert_integer(tx_item['height'])
       +            assert_hash256_str(tx_item['tx_hash'])
       +            if tx_item['height'] in (-1, 0):
       +                assert_dict_contains_field(tx_item, field_name='fee')
       +                assert_non_negative_integer(tx_item['fee'])
       +        return res
       +
       +    async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
       +        if not is_hash256_str(sh):
       +            raise Exception(f"{repr(sh)} is not a scripthash")
       +        # do request
       +        res = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
       +        # check response
       +        assert_list_or_tuple(res)
       +        for utxo_item in res:
       +            assert_dict_contains_field(utxo_item, field_name='tx_pos')
       +            assert_dict_contains_field(utxo_item, field_name='value')
       +            assert_dict_contains_field(utxo_item, field_name='tx_hash')
       +            assert_dict_contains_field(utxo_item, field_name='height')
       +            assert_non_negative_integer(utxo_item['tx_pos'])
       +            assert_non_negative_integer(utxo_item['value'])
       +            assert_non_negative_integer(utxo_item['height'])
       +            assert_hash256_str(utxo_item['tx_hash'])
       +        return res
       +
       +    async def get_balance_for_scripthash(self, sh: str) -> dict:
       +        if not is_hash256_str(sh):
       +            raise Exception(f"{repr(sh)} is not a scripthash")
       +        # do request
       +        res = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
       +        # check response
       +        assert_dict_contains_field(res, field_name='confirmed')
       +        assert_dict_contains_field(res, field_name='unconfirmed')
       +        assert_non_negative_integer(res['confirmed'])
       +        assert_non_negative_integer(res['unconfirmed'])
       +        return res
       +
       +    async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool):
       +        if not is_non_negative_integer(tx_height):
       +            raise Exception(f"{repr(tx_height)} is not a block height")
       +        if not is_non_negative_integer(tx_pos):
       +            raise Exception(f"{repr(tx_pos)} should be non-negative integer")
       +        # do request
       +        res = await self.session.send_request(
       +            'blockchain.transaction.id_from_pos',
       +            [tx_height, tx_pos, merkle],
       +        )
       +        # check response
       +        if merkle:
       +            assert_dict_contains_field(res, field_name='tx_hash')
       +            assert_dict_contains_field(res, field_name='merkle')
       +            assert_hash256_str(res['tx_hash'])
       +            assert_list_or_tuple(res['merkle'])
       +            for node_hash in res['merkle']:
       +                assert_hash256_str(node_hash)
       +        else:
       +            assert_hash256_str(res)
       +        return res
       +
        
        def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
            chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
   DIR diff --git a/electrum/network.py b/electrum/network.py
       t@@ -816,7 +816,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
                            if success_fut.exception():
                                try:
                                    raise success_fut.exception()
       -                        except (RequestTimedOut, RequestCorrupted):
       +                        except RequestTimedOut:
       +                            await iface.close()
       +                            await iface.got_disconnected
       +                            continue  # try again
       +                        except RequestCorrupted as e:
       +                            # TODO ban server?
       +                            iface.logger.exception(f"RequestCorrupted: {e}")
                                    await iface.close()
                                    await iface.got_disconnected
                                    continue  # try again
       t@@ -836,11 +842,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
            @best_effort_reliable
            @catch_server_exceptions
            async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
       -        if not is_hash256_str(tx_hash):
       -            raise Exception(f"{repr(tx_hash)} is not a txid")
       -        if not is_non_negative_integer(tx_height):
       -            raise Exception(f"{repr(tx_height)} is not a block height")
       -        return await self.interface.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
       +        return await self.interface.get_merkle_for_transaction(tx_hash=tx_hash, tx_height=tx_height)
        
            @best_effort_reliable
            async def broadcast_transaction(self, tx: 'Transaction', *, timeout=None) -> None:
       t@@ -1012,54 +1014,32 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
            @best_effort_reliable
            @catch_server_exceptions
            async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
       -        if not is_non_negative_integer(height):
       -            raise Exception(f"{repr(height)} is not a block height")
                return await self.interface.request_chunk(height, tip=tip, can_return_early=can_return_early)
        
            @best_effort_reliable
            @catch_server_exceptions
            async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
       -        if not is_hash256_str(tx_hash):
       -            raise Exception(f"{repr(tx_hash)} is not a txid")
       -        iface = self.interface
       -        raw = await iface.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
       -        # validate response
       -        tx = Transaction(raw)
       -        try:
       -            tx.deserialize()  # see if raises
       -        except Exception as e:
       -            self.logger.warning(f"cannot deserialize received transaction (txid {tx_hash}). from {str(iface)}")
       -            raise RequestCorrupted() from e  # TODO ban server?
       -        if tx.txid() != tx_hash:
       -            self.logger.warning(f"received tx does not match expected txid {tx_hash} (got {tx.txid()}). from {str(iface)}")
       -            raise RequestCorrupted()  # TODO ban server?
       -        return raw
       +        return await self.interface.get_transaction(tx_hash=tx_hash, timeout=timeout)
        
            @best_effort_reliable
            @catch_server_exceptions
            async def get_history_for_scripthash(self, sh: str) -> List[dict]:
       -        if not is_hash256_str(sh):
       -            raise Exception(f"{repr(sh)} is not a scripthash")
       -        return await self.interface.session.send_request('blockchain.scripthash.get_history', [sh])
       +        return await self.interface.get_history_for_scripthash(sh)
        
            @best_effort_reliable
            @catch_server_exceptions
            async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
       -        if not is_hash256_str(sh):
       -            raise Exception(f"{repr(sh)} is not a scripthash")
       -        return await self.interface.session.send_request('blockchain.scripthash.listunspent', [sh])
       +        return await self.interface.listunspent_for_scripthash(sh)
        
            @best_effort_reliable
            @catch_server_exceptions
            async def get_balance_for_scripthash(self, sh: str) -> dict:
       -        if not is_hash256_str(sh):
       -            raise Exception(f"{repr(sh)} is not a scripthash")
       -        return await self.interface.session.send_request('blockchain.scripthash.get_balance', [sh])
       +        return await self.interface.get_balance_for_scripthash(sh)
        
            @best_effort_reliable
       +    @catch_server_exceptions
            async def get_txid_from_txpos(self, tx_height, tx_pos, merkle):
       -        command = 'blockchain.transaction.id_from_pos'
       -        return await self.interface.session.send_request(command, [tx_height, tx_pos, merkle])
       +        return await self.interface.get_txid_from_txpos(tx_height, tx_pos, merkle)
        
            def blockchain(self) -> Blockchain:
                interface = self.interface
   DIR diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py
       t@@ -168,15 +168,12 @@ class Synchronizer(SynchronizerBase):
                self.requested_histories.add((addr, status))
                h = address_to_scripthash(addr)
                self._requests_sent += 1
       -        result = await self.network.get_history_for_scripthash(h)
       +        result = await self.interface.get_history_for_scripthash(h)
                self._requests_answered += 1
                self.logger.info(f"receiving history {addr} {len(result)}")
                hashes = set(map(lambda item: item['tx_hash'], result))
                hist = list(map(lambda item: (item['tx_hash'], item['height']), result))
                # tx_fees
       -        for item in result:
       -            if item['height'] in (-1, 0) and 'fee' not in item:
       -                raise Exception("server response to get_history contains unconfirmed tx without fee")
                tx_fees = [(item['tx_hash'], item.get('fee')) for item in result]
                tx_fees = dict(filter(lambda x:x[1] is not None, tx_fees))
                # Check that txids are unique
       t@@ -214,7 +211,7 @@ class Synchronizer(SynchronizerBase):
            async def _get_transaction(self, tx_hash, *, allow_server_not_finding_tx=False):
                self._requests_sent += 1
                try:
       -            raw_tx = await self.network.get_transaction(tx_hash)
       +            raw_tx = await self.interface.get_transaction(tx_hash)
                except UntrustedServerReturnedError as e:
                    # most likely, "No such mempool or blockchain transaction"
                    if allow_server_not_finding_tx:
   DIR diff --git a/electrum/util.py b/electrum/util.py
       t@@ -582,6 +582,30 @@ def is_non_negative_integer(val) -> bool:
            return False
        
        
       +def is_integer(val) -> bool:
       +    try:
       +        int(val)
       +    except:
       +        return False
       +    else:
       +        return True
       +
       +
       +def is_real_number(val, *, as_str: bool = False) -> bool:
       +    if as_str:  # only accept str
       +        if not isinstance(val, str):
       +            return False
       +    else:  # only accept int/float/etc.
       +        if isinstance(val, str):
       +            return False
       +    try:
       +        Decimal(val)
       +    except:
       +        return False
       +    else:
       +        return True
       +
       +
        def chunks(items, size: int):
            """Break up items, an iterable, into chunks of length size."""
            if size < 1: