URI: 
       timprove filter_channel_updates blacklist channels that do not really get updated - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit eb4e6bb0de493877c0c7553928b4408787d79c92
   DIR parent f4b3d7627d99777d9c33758fe5874d1360cac1ab
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Thu, 16 May 2019 19:00:44 +0200
       
       improve filter_channel_updates
       blacklist channels that do not really get updated
       
       Diffstat:
         M electrum/lnpeer.py                  |      38 +++++++++++++++++++------------
         M electrum/lnrouter.py                |     142 ++++++++++++++++++-------------
       
       2 files changed, 106 insertions(+), 74 deletions(-)
       ---
   DIR diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -241,12 +241,14 @@ class Peer(Logger):
                    self.verify_node_announcements(node_anns)
                    self.channel_db.on_node_announcement(node_anns)
                    # channel updates
       -            good, bad = self.channel_db.filter_channel_updates(chan_upds)
       -            if bad:
       -                self.logger.info(f'adding {len(bad)} unknown channel ids')
       -                self.network.lngossip.add_new_ids(bad)
       -            self.verify_channel_updates(good)
       -            self.channel_db.on_channel_update(good)
       +            orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds, max_age=self.network.lngossip.max_age)
       +            if orphaned:
       +                self.logger.info(f'adding {len(orphaned)} unknown channel ids')
       +                self.network.lngossip.add_new_ids(orphaned)
       +            if good:
       +                self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds)}')
       +                self.verify_channel_updates(good)
       +                self.channel_db.update_policies(good, to_delete)
                    # refresh gui
                    if chan_anns or node_anns or chan_upds:
                        self.network.lngossip.refresh_gui()
       t@@ -273,7 +275,7 @@ class Peer(Logger):
                    short_channel_id = payload['short_channel_id']
                    if constants.net.rev_genesis_bytes() != payload['chain_hash']:
                        raise Exception('wrong chain hash')
       -            if not verify_sig_for_channel_update(payload, payload['node_id']):
       +            if not verify_sig_for_channel_update(payload, payload['start_node']):
                        raise BaseException('verify error')
        
            @log_exceptions
       t@@ -990,21 +992,29 @@ class Peer(Logger):
                    OnionFailureCode.EXPIRY_TOO_SOON: 2,
                    OnionFailureCode.CHANNEL_DISABLED: 4,
                }
       -        offset = failure_codes.get(code)
       -        if offset:
       +        if code in failure_codes:
       +            offset = failure_codes[code]
                    channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
                    message_type, payload = decode_msg(channel_update)
                    payload['raw'] = channel_update
       -            try:
       -                self.logger.info(f"trying to apply channel update on our db {payload}")
       -                self.channel_db.add_channel_update(payload)
       -                self.logger.info("successfully applied channel update on our db")
       -            except NotFoundChanAnnouncementForUpdate:
       +            orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
       +            if good:
       +                self.verify_channel_updates(good)
       +                self.channel_db.update_policies(good, to_delete)
       +                self.logger.info("applied channel update on our db")
       +            elif orphaned:
                        # maybe it is a private channel (and data in invoice was outdated)
                        self.logger.info("maybe channel update is for private channel?")
                        start_node_id = route[sender_idx].node_id
                        self.channel_db.add_channel_update_for_private_channel(payload, start_node_id)
       +            elif expired:
       +                blacklist = True
       +            elif deprecated:
       +                self.logger.info(f'channel update is not more recent. blacklisting channel')
       +                blacklist = True
                else:
       +            blacklist = True
       +        if blacklist:
                    # blacklist channel after reporter node
                    # TODO this should depend on the error (even more granularity)
                    # also, we need finer blacklisting (directed edges; nodes)
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -114,22 +114,16 @@ class Policy(Base):
            timestamp                   = Column(Integer, nullable=False)
        
            @staticmethod
       -    def from_msg(payload, start_node, short_channel_id):
       -        cltv_expiry_delta           = payload['cltv_expiry_delta']
       -        htlc_minimum_msat           = payload['htlc_minimum_msat']
       -        fee_base_msat               = payload['fee_base_msat']
       -        fee_proportional_millionths = payload['fee_proportional_millionths']
       -        channel_flags               = payload['channel_flags']
       -        timestamp                   = payload['timestamp']
       -        htlc_maximum_msat           = payload.get('htlc_maximum_msat')  # optional
       -
       -        cltv_expiry_delta           = int.from_bytes(cltv_expiry_delta, "big")
       -        htlc_minimum_msat           = int.from_bytes(htlc_minimum_msat, "big")
       -        htlc_maximum_msat           = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
       -        fee_base_msat               = int.from_bytes(fee_base_msat, "big")
       -        fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
       -        channel_flags               = int.from_bytes(channel_flags, "big")
       -        timestamp                   = int.from_bytes(timestamp, "big")
       +    def from_msg(payload):
       +        cltv_expiry_delta           = int.from_bytes(payload['cltv_expiry_delta'], "big")
       +        htlc_minimum_msat           = int.from_bytes(payload['htlc_minimum_msat'], "big")
       +        htlc_maximum_msat           = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
       +        fee_base_msat               = int.from_bytes(payload['fee_base_msat'], "big")
       +        fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
       +        channel_flags               = int.from_bytes(payload['channel_flags'], "big")
       +        timestamp                   = int.from_bytes(payload['timestamp'], "big")
       +        start_node                  = payload['start_node'].hex()
       +        short_channel_id            = payload['short_channel_id'].hex()
        
                return Policy(start_node=start_node,
                        short_channel_id=short_channel_id,
       t@@ -341,71 +335,98 @@ class ChannelDB(SqlDB):
                r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
                return r.max_timestamp or 0
        
       +    def print_change(self, old_policy, new_policy):
       +        # print what changed between policies
       +        if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
       +            self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
       +        if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
       +            self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
       +        if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
       +            self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
       +        if old_policy.fee_base_msat != new_policy.fee_base_msat:
       +            self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
       +        if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
       +            self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
       +        if old_policy.channel_flags != new_policy.channel_flags:
       +            self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
       +
            @sql
       -    def get_info_for_updates(self, msg_payloads):
       -        short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
       +    def get_info_for_updates(self, payloads):
       +        short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
                channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
                channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
                return channel_infos
        
       +    @sql
       +    def get_policies_for_updates(self, payloads):
       +        out = {}
       +        for payload in payloads:
       +            short_channel_id = payload['short_channel_id'].hex()
       +            start_node = payload['start_node'].hex()
       +            policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
       +            if policy:
       +                out[short_channel_id+start_node] = policy
       +        return out
       +
            @profiler
       -    def filter_channel_updates(self, payloads):
       -        # add 'node_id' to payload
       -        channel_infos = self.get_info_for_updates(payloads)
       +    def filter_channel_updates(self, payloads, max_age=None):
       +        orphaned = []      # no channel announcement for channel update
       +        expired = []       # update older than two weeks
       +        deprecated = []    # update older than database entry
       +        good = []          # good updates
       +        to_delete = []     # database entries to delete
       +        # filter orphaned and expired first
                known = []
       -        unknown = []
       +        now = int(time.time())
       +        channel_infos = self.get_info_for_updates(payloads)
                for payload in payloads:
                    short_channel_id = payload['short_channel_id']
       +            timestamp = int.from_bytes(payload['timestamp'], "big")
       +            if max_age and now - timestamp > max_age:
       +                expired.append(short_channel_id)
       +                continue
                    channel_info = channel_infos.get(short_channel_id)
                    if not channel_info:
       -                unknown.append(short_channel_id)
       +                orphaned.append(short_channel_id)
                        continue
                    flags = int.from_bytes(payload['channel_flags'], 'big')
                    direction = flags & FLAG_DIRECTION
       -            node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id)
       -            payload['node_id'] = node_id
       +            start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
       +            payload['start_node'] = bfh(start_node)
                    known.append(payload)
       -        return known, unknown
       +        # compare updates to existing database entries
       +        old_policies = self.get_policies_for_updates(known)
       +        for payload in known:
       +            timestamp = int.from_bytes(payload['timestamp'], "big")
       +            start_node = payload['start_node'].hex()
       +            short_channel_id = payload['short_channel_id'].hex()
       +            old_policy = old_policies.get(short_channel_id+start_node)
       +            if old_policy:
       +                if timestamp <= old_policy.timestamp:
       +                    deprecated.append(short_channel_id)
       +                else:
       +                    good.append(payload)
       +                    to_delete.append(old_policy)
       +            else:
       +                good.append(payload)
       +        return orphaned, expired, deprecated, good, to_delete
        
            def add_channel_update(self, payload):
       -        # called in tests/test_lnrouter
       -        good, bad = self.filter_channel_updates([payload])
       -        assert len(bad) == 0
       -        self.on_channel_update(good)
       +        orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
       +        assert len(good) == 1
       +        self.update_policies(good, to_delete)
        
            @sql
            @profiler
       -    def on_channel_update(self, msg_payloads):
       -        now = int(time.time())
       -        if type(msg_payloads) is dict:
       -            msg_payloads = [msg_payloads]
       -        new_policies = {}
       -        for msg_payload in msg_payloads:
       -            short_channel_id = msg_payload['short_channel_id'].hex()
       -            node_id = msg_payload['node_id'].hex()
       -            new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id)
       -            # must not be older than two weeks
       -            if new_policy.timestamp < now - 14*24*3600:
       -                continue
       -            old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
       -            if old_policy:
       -                if old_policy.timestamp >= new_policy.timestamp:
       -                    continue
       -                self.DBSession.delete(old_policy)
       -            p = new_policies.get((short_channel_id, node_id))
       -            if p and p.timestamp >= new_policy.timestamp:
       -                continue
       -            new_policies[(short_channel_id, node_id)] = new_policy
       -        # commit pending removals
       +    def update_policies(self, to_add, to_delete):
       +        for policy in to_delete:
       +            self.DBSession.delete(policy)
                self.DBSession.commit()
       -        # add and commit new policies
       -        for new_policy in new_policies.values():
       -            self.DBSession.add(new_policy)
       +        for payload in to_add:
       +            policy = Policy.from_msg(payload)
       +            self.DBSession.add(policy)
                self.DBSession.commit()
       -        if new_policies:
       -            self.logger.debug(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
       -            #self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
       -            self._update_counts()
       +        self._update_counts()
        
            @sql
            @profiler
       t@@ -454,7 +475,7 @@ class ChannelDB(SqlDB):
                msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
                if not msg:
                    return None
       -        return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB
       +        return Policy.from_msg(msg) # won't actually be written to DB
        
            @sql
            @profiler
       t@@ -496,6 +517,7 @@ class ChannelDB(SqlDB):
                if not verify_sig_for_channel_update(msg_payload, start_node_id):
                    return  # ignore
                short_channel_id = msg_payload['short_channel_id']
       +        msg_payload['start_node'] = start_node_id
                self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
        
            @sql