        git clone https://git.parazyd.org/electrum
        ---
        commit 1011245c5e29a128edf4629d024c4cf345190282
        parent 95376226e8567adb4b7bd4eea298e623f58c26e4
        Author: ThomasV <thomasv@electrum.org>
       Date:   Mon, 13 May 2019 14:30:02 +0200
       
       LNGossip: sync channel db using query_channel_range
       
       Diffstat:
         M electrum/lnpeer.py                  |      84 ++++++++++++++++++-------------
         M electrum/lnrouter.py                |      57 ++++++++++++++++++++++++++-----
         M electrum/lnworker.py                |      86 +++++++++++++++++++++++++------
       
       3 files changed, 168 insertions(+), 59 deletions(-)
       ---
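        For orientation: with gossip_queries, the syncing node sends
        query_channel_range for a block range, the peer answers with one or
        more reply_channel_range messages carrying encoded short channel ids,
        and the channels still missing locally are then fetched with
        query_short_channel_ids. A minimal sketch of the payloads involved
        (assuming the BOLT-7 field layout; Electrum itself builds these via
        encode_msg):

            import zlib

            def build_query_channel_range(chain_hash: bytes, first_blocknum: int,
                                          number_of_blocks: int) -> bytes:
                # BOLT-7: chain_hash (32 bytes) || first_blocknum (u32) || number_of_blocks (u32)
                return (chain_hash
                        + first_blocknum.to_bytes(4, 'big')
                        + number_of_blocks.to_bytes(4, 'big'))

            def parse_encoded_short_ids(encoded: bytes) -> list:
                # leading byte selects the encoding: 0 = raw, 1 = zlib
                body = encoded[1:]
                data = zlib.decompress(body) if encoded[0] == 1 else body
                return [data[i:i+8] for i in range(0, len(data), 8)]
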
        diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
        @@ -57,9 +57,7 @@ class Peer(Logger):
        
            def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase):
                self.initialized = asyncio.Event()
       -        self.node_anns = []
       -        self.chan_anns = []
       -        self.chan_upds = []
       +        self.querying_lock = asyncio.Lock()
                self.transport = transport
                self.pubkey = pubkey
                self.lnworker = lnworker
        @@ -70,6 +68,7 @@ class Peer(Logger):
                self.lnwatcher = lnworker.network.lnwatcher
                self.channel_db = lnworker.network.channel_db
                self.ping_time = 0
       +        self.reply_channel_range = asyncio.Queue()
                self.shutdown_received = defaultdict(asyncio.Future)
                self.channel_accepted = defaultdict(asyncio.Queue)
                self.channel_reestablished = defaultdict(asyncio.Future)
        @@ -89,7 +88,7 @@ class Peer(Logger):
        
            def send_message(self, message_name: str, **kwargs):
                assert type(message_name) is str
       -        self.logger.info(f"Sending {message_name.upper()}")
       +        self.logger.debug(f"Sending {message_name.upper()}")
                self.transport.send_bytes(encode_msg(message_name, **kwargs))
        
            async def initialize(self):
        @@ -177,13 +176,13 @@ class Peer(Logger):
                self.initialized.set()
        
            def on_node_announcement(self, payload):
       -        self.node_anns.append(payload)
       +        self.channel_db.node_anns.append(payload)
        
            def on_channel_update(self, payload):
       -        self.chan_upds.append(payload)
       +        self.channel_db.chan_upds.append(payload)
        
            def on_channel_announcement(self, payload):
       -        self.chan_anns.append(payload)
       +        self.channel_db.chan_anns.append(payload)
        
            def on_announcement_signatures(self, payload):
                channel_id = payload['channel_id']
        @@ -207,15 +206,11 @@ class Peer(Logger):
            @handle_disconnect
            async def main_loop(self):
                async with aiorpcx.TaskGroup() as group:
       -            await group.spawn(self._gossip_loop())
                    await group.spawn(self._message_loop())
                    # kill group if the peer times out
                    await group.spawn(asyncio.wait_for(self.initialized.wait(), 10))
        
       -    @log_exceptions
       -    async def _gossip_loop(self):
       -        await self.initialized.wait()
       -        timestamp = self.channel_db.get_last_timestamp()
       +    def request_gossip(self, timestamp=0):
                if timestamp == 0:
                    self.logger.info('requesting whole channel graph')
                else:
        @@ -225,28 +220,47 @@ class Peer(Logger):
                    chain_hash=constants.net.rev_genesis_bytes(),
                    first_timestamp=timestamp,
                    timestamp_range=b'\xff'*4)
       -        while True:
       -            await asyncio.sleep(5)
       -            if self.node_anns:
       -                self.channel_db.on_node_announcement(self.node_anns)
       -                self.node_anns = []
       -            if self.chan_anns:
       -                self.channel_db.on_channel_announcement(self.chan_anns)
       -                self.chan_anns = []
       -            if self.chan_upds:
       -                self.channel_db.on_channel_update(self.chan_upds)
       -                self.chan_upds = []
       -            # todo: enable when db is fixed
       -            #need_to_get = sorted(self.channel_db.missing_short_chan_ids())
       -            #if need_to_get and not self.receiving_channels:
       -            #    self.logger.info(f'missing {len(need_to_get)} channels')
       -            #    zlibencoded = zlib.compress(bfh(''.join(need_to_get[0:100])))
       -            #    self.send_message(
       -            #        'query_short_channel_ids',
       -            #        chain_hash=constants.net.rev_genesis_bytes(),
       -            #        len=1+len(zlibencoded),
       -            #        encoded_short_ids=b'\x01' + zlibencoded)
       -            #    self.receiving_channels = True
       +
       +    def query_channel_range(self, index, num):
        +        self.logger.info('query channel range')
       +        self.send_message(
       +            'query_channel_range',
       +            chain_hash=constants.net.rev_genesis_bytes(),
       +            first_blocknum=index,
       +            number_of_blocks=num)
       +
       +    def encode_short_ids(self, ids):
        +        return b'\x01' + zlib.compress(bfh(''.join(ids)))
       +
       +    def decode_short_ids(self, encoded):
       +        if encoded[0] == 0:
       +            decoded = encoded[1:]
       +        elif encoded[0] == 1:
       +            decoded = zlib.decompress(encoded[1:])
       +        else:
        +            raise Exception('unknown encoded_short_ids encoding')
       +        ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)]
       +        return ids
       +
       +    def on_reply_channel_range(self, payload):
       +        first = int.from_bytes(payload['first_blocknum'], 'big')
       +        num = int.from_bytes(payload['number_of_blocks'], 'big')
       +        complete = bool(payload['complete'])
       +        encoded = payload['encoded_short_ids']
       +        ids = self.decode_short_ids(encoded)
       +        self.reply_channel_range.put_nowait((first, num, complete, ids))
       +
       +    async def query_short_channel_ids(self, ids, compressed=True):
       +        await self.querying_lock.acquire()
       +        #self.logger.info('querying {} short_channel_ids'.format(len(ids)))
       +        s = b''.join(ids)
       +        encoded = zlib.compress(s) if compressed else s
       +        prefix = b'\x01' if compressed else b'\x00'
       +        self.send_message(
       +            'query_short_channel_ids',
       +            chain_hash=constants.net.rev_genesis_bytes(),
       +            len=1+len(encoded),
       +            encoded_short_ids=prefix+encoded)
        
            async def _message_loop(self):
                try:
        @@ -260,7 +274,7 @@ class Peer(Logger):
                    self.ping_if_required()
        
            def on_reply_short_channel_ids_end(self, payload):
       -        self.receiving_channels = False
       +        self.querying_lock.release()
        
            def close_and_cleanup(self):
                try:
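
        To illustrate the short_channel_id encoding handled by
        decode_short_ids above (a one-byte prefix selects raw or
        zlib-compressed), a standalone round trip with made-up ids:

            import os, zlib

            ids = [os.urandom(8) for _ in range(3)]           # three fake 8-byte short ids
            encoded = b'\x01' + zlib.compress(b''.join(ids))  # prefix 1 = zlib-compressed

            body = encoded[1:]
            data = zlib.decompress(body) if encoded[0] == 1 else body
            assert [data[i:i+8] for i in range(0, len(data), 8)] == ids

        Note also how querying_lock acts as flow control: it is acquired
        before each query_short_channel_ids batch and only released by
        on_reply_short_channel_ids_end, so at most one batch per peer is in
        flight at a time.
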
        diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
        @@ -223,6 +223,20 @@ class ChannelDB(SqlDB):
                self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
                self.ca_verifier = LNChannelVerifier(network, self)
                self.update_counts()
       +        self.node_anns = []
       +        self.chan_anns = []
       +        self.chan_upds = []
       +
       +    def process_gossip(self):
       +        if self.node_anns:
       +            self.on_node_announcement(self.node_anns)
       +            self.node_anns = []
       +        if self.chan_anns:
       +            self.on_channel_announcement(self.chan_anns)
       +            self.chan_anns = []
       +        if self.chan_upds:
       +            self.on_channel_update(self.chan_upds)
       +            self.chan_upds = []
        
            @sql
            def update_counts(self):
        @@ -232,7 +246,32 @@ class ChannelDB(SqlDB):
                self.num_channels = self.DBSession.query(ChannelInfo).count()
                self.num_policies = self.DBSession.query(Policy).count()
                self.num_nodes = self.DBSession.query(NodeInfo).count()
       -        self.logger.info(f'update counts {self.num_channels} {self.num_policies}')
       +
       +    @sql
       +    @profiler
       +    def purge_unknown_channels(self, channel_ids):
       +        ids = [x.hex() for x in channel_ids]
       +        missing = self.DBSession \
       +                      .query(ChannelInfo) \
       +                      .filter(not_(ChannelInfo.short_channel_id.in_(ids))) \
       +                      .all()
       +        if missing:
       +            self.logger.info("deleting {} channels".format(len(missing)))
       +            delete_query = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(ids)))
       +            self.DBSession.execute(delete_query)
       +            self.DBSession.commit()
       +
       +    @sql
       +    @profiler
       +    def compare_channels(self, channel_ids):
       +        ids = [x.hex() for x in channel_ids]
        +        # return the channels we already know; the caller derives the unknown set
        +        # (detecting channels that need a refresh is a TODO)
       +        known = self.DBSession \
       +                 .query(ChannelInfo) \
       +                 .filter(ChannelInfo.short_channel_id.in_(ids)) \
       +                 .all()
       +        known = [bfh(r.short_channel_id) for r in known]
       +        return known
        
            @sql
            def add_recent_peer(self, peer: LNPeerAddr):
        @@ -276,12 +315,14 @@ class ChannelDB(SqlDB):
                return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
        
            @sql
       -    def missing_short_chan_ids(self) -> Set[int]:
       +    def missing_channel_announcements(self) -> Set[int]:
                expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
       -        chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
       -        if chan_ids_from_policy:
       -            return chan_ids_from_policy
       -        return set()
       +        return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
       +
       +    @sql
       +    def missing_channel_updates(self) -> Set[int]:
       +        expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
       +        return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
        
            @sql
            def add_verified_channel_info(self, short_id, capacity):
        @@ -316,8 +357,8 @@ class ChannelDB(SqlDB):
                for channel_info in new_channels.values():
                    self.DBSession.add(channel_info)
                self.DBSession.commit()
       -        #self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
                self._update_counts()
       +        self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
                self.network.trigger_callback('ln_status')
        
            @sql
        @@ -370,7 +411,7 @@ class ChannelDB(SqlDB):
                self.DBSession.commit()
                if new_policies:
                    self.logger.info(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
       -            self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
       +            #self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
                    self._update_counts()
        
            @sql
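
        The node_anns/chan_anns/chan_upds lists moved into ChannelDB turn
        per-message database writes into periodic batch commits: the peer's
        message handlers only append raw payloads, and process_gossip()
        flushes each list in one call. A minimal sketch of the same
        buffer-and-flush pattern (hypothetical names, not Electrum's API):

            class GossipBuffer:
                def __init__(self, flush):
                    self.pending = []      # raw payloads appended by message handlers
                    self.flush = flush     # batch consumer, e.g. one DB transaction

                def add(self, payload):
                    self.pending.append(payload)   # cheap: no DB work on the message path

                def process(self):
                    if self.pending:
                        batch, self.pending = self.pending, []
                        self.flush(batch)          # one commit per batch
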
        diff --git a/electrum/lnworker.py b/electrum/lnworker.py
        @@ -133,9 +133,7 @@ class LNWorker(Logger):
                self.channel_db = self.network.channel_db
                self._last_tried_peer = {}  # LNPeerAddr -> unix timestamp
                self._add_peers_from_config()
       -        # wait until we see confirmations
                asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop)
       -        self.first_timestamp_requested = None
        
            def _add_peers_from_config(self):
                peer_list = self.config.get('lightning_peers', [])
        @@ -215,9 +213,24 @@ class LNWorker(Logger):
                self.logger.info('got {} ln peers from dns seed'.format(len(peers)))
                return peers
        
       +    @staticmethod
       +    def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
       +        assert len(addr_list) >= 1
       +        # choose first one that is an IP
       +        for addr_in_db in addr_list:
       +            host = addr_in_db.host
       +            port = addr_in_db.port
       +            if is_ip_address(host):
       +                return host, port
       +        # otherwise choose one at random
       +        # TODO maybe filter out onion if not on tor?
       +        choice = random.choice(addr_list)
       +        return choice.host, choice.port
        
        
        class LNGossip(LNWorker):
       +    # height of first channel announcements
       +    first_block = 497000
        
            def __init__(self, network):
                seed = os.urandom(32)
        @@ -226,6 +239,61 @@ class LNGossip(LNWorker):
                super().__init__(xprv)
                self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ
        
       +    def start_network(self, network: 'Network'):
       +        super().start_network(network)
       +        asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.gossip_task()), self.network.asyncio_loop)
       +
       +    async def gossip_task(self):
       +        req_index = self.first_block
       +        req_num = self.network.get_local_height() - req_index
        +        while len(self.peers) == 0:
        +            await asyncio.sleep(1)
       +        # todo: parallelize over peers
       +        peer = list(self.peers.values())[0]
       +        await peer.initialized.wait()
        +        # send query_channel_range; the peer will reply with several reply_channel_range intervals
       +        peer.query_channel_range(req_index, req_num)
       +        intervals = []
       +        ids = set()
       +        # wait until requested range is covered
       +        while True:
       +            index, num, complete, _ids = await peer.reply_channel_range.get()
       +            ids.update(_ids)
       +            intervals.append((index, index+num))
       +            intervals.sort()
       +            while len(intervals) > 1:
       +                a,b = intervals[0]
       +                c,d = intervals[1]
       +                if b == c:
       +                    intervals = [(a,d)] + intervals[2:]
       +                else:
       +                    break
       +            if len(intervals) == 1:
       +                a, b = intervals[0]
       +                if a <= req_index and b >= req_index + req_num:
       +                    break
       +        self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
       +        # TODO: filter results by date of last channel update, purge DB
       +        #if complete:
       +        #    self.channel_db.purge_unknown_channels(ids)
       +        known = self.channel_db.compare_channels(ids)
       +        unknown = list(ids - set(known))
       +        total = len(unknown)
       +        N = 500
       +        while unknown:
       +            self.channel_db.process_gossip()
       +            await peer.query_short_channel_ids(unknown[0:N])
       +            unknown = unknown[N:]
       +            self.logger.info(f'Querying channels: {total - len(unknown)}/{total}. Count: {self.channel_db.num_channels}')
       +
        +        # request gossip from the current time onwards
       +        now = int(time.time())
       +        peer.request_gossip(now)
       +        while True:
       +            await asyncio.sleep(5)
       +            self.channel_db.process_gossip()
       +
        
        class LNWallet(LNWorker):
        
        @@ -548,20 +616,6 @@ class LNWallet(LNWorker):
            def on_channels_updated(self):
                self.network.trigger_callback('channels')
        
       -    @staticmethod
       -    def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
       -        assert len(addr_list) >= 1
       -        # choose first one that is an IP
       -        for addr_in_db in addr_list:
       -            host = addr_in_db.host
       -            port = addr_in_db.port
       -            if is_ip_address(host):
       -                return host, port
       -        # otherwise choose one at random
       -        # TODO maybe filter out onion if not on tor?
       -        choice = random.choice(addr_list)
       -        return choice.host, choice.port
       -
            def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=20):
                node_id, rest = extract_nodeid(connect_contents)
                peer = self.peers.get(node_id)
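
        The loop in gossip_task merges exactly-adjacent (start, end) block
        intervals from successive reply_channel_range messages until the
        requested range is covered. The same logic as a standalone sketch
        (hypothetical helper, with a worked example using the commit's
        first_block of 497000):

            def range_covered(intervals, req_index, req_num):
                intervals = sorted(intervals)
                while len(intervals) > 1:
                    (a, b), (c, d) = intervals[0], intervals[1]
                    if b == c:                        # fuse only exactly-adjacent intervals
                        intervals = [(a, d)] + intervals[2:]
                    else:
                        break
                return (len(intervals) == 1
                        and intervals[0][0] <= req_index
                        and intervals[0][1] >= req_index + req_num)

            assert range_covered([(497000, 500000), (500000, 600000)], 497000, 103000)
            assert not range_covered([(497000, 500000), (500100, 600000)], 497000, 103000)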