        fix sql conflicts in lnrouter - electrum - Electrum Bitcoin wallet
        git clone https://git.parazyd.org/electrum
       ---
        commit e7888a50bedd8133fec92e960c8a44bedf8c1311
        parent eae8f1a139ae5ebc411c9a33352744ee1897bf1c
        Author: ThomasV <thomasv@electrum.org>
       Date:   Sun, 17 Mar 2019 11:54:31 +0100
       
       fix sql conflicts in lnrouter
       
       Diffstat:
         M electrum/lnrouter.py                |     138 ++++++++++++++++---------------
       
       1 file changed, 70 insertions(+), 68 deletions(-)
       ---
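        The patch below avoids SQL conflicts in two ways: the SQLAlchemy DateTime columns
        are replaced with plain integer Unix timestamps, and incoming gossip messages are
        deduplicated (and existing rows looked up by their full primary key) before anything
        is added to the session, so a single commit no longer tries to insert the same
        primary key twice. A minimal sketch of that look-up-then-update pattern, using a
        simplified stand-in table and SQLAlchemy 1.x-style imports matching the diff, not
        Electrum's actual schema:

            import time

            from sqlalchemy import Column, Integer, String, create_engine
            from sqlalchemy.ext.declarative import declarative_base
            from sqlalchemy.orm import sessionmaker

            Base = declarative_base()

            class PeerAddress(Base):
                # illustrative stand-in for the Address table in lnrouter.py
                __tablename__ = 'peer_address'
                node_id = Column(String(66), primary_key=True)
                host = Column(String(256), primary_key=True)
                port = Column(Integer, primary_key=True)
                last_connected_date = Column(Integer, nullable=False)  # unix time, not DateTime

            engine = create_engine('sqlite://')   # in-memory database for the sketch
            Base.metadata.create_all(engine)
            session = sessionmaker(bind=engine)()

            def add_recent_peer(node_id, host, port):
                now = int(time.time())
                # Look the row up by its complete primary key before inserting.
                addr = session.query(PeerAddress).filter_by(
                    node_id=node_id, host=host, port=port).one_or_none()
                if addr:
                    addr.last_connected_date = now   # row exists: update in place
                else:
                    session.add(PeerAddress(node_id=node_id, host=host, port=port,
                                            last_connected_date=now))
                session.commit()

            add_recent_peer('02deadbeef', '127.0.0.1', 9735)
            add_recent_peer('02deadbeef', '127.0.0.1', 9735)  # updates, no IntegrityError

        Calling add_recent_peer twice with the same key updates last_connected_date instead
        of raising a constraint error, which is the behaviour the patched add_recent_peer
        relies on.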
        diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
        @@ -23,7 +23,7 @@
        # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        # SOFTWARE.
        
       -import datetime
       +import time
        import random
        import queue
        import os
        @@ -35,7 +35,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
        import binascii
        import base64
        
       -from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
       +from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
        from sqlalchemy.orm.query import Query
        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy.sql import not_, or_
        @@ -81,14 +81,14 @@ class ChannelInfo(Base):
            trusted = Column(Boolean, nullable=False)
        
            @staticmethod
       -    def from_msg(channel_announcement_payload):
       -        features = int.from_bytes(channel_announcement_payload['features'], 'big')
       +    def from_msg(payload):
       +        features = int.from_bytes(payload['features'], 'big')
                validate_features(features)
       -        channel_id = channel_announcement_payload['short_channel_id'].hex()
       -        node_id_1 = channel_announcement_payload['node_id_1'].hex()
       -        node_id_2 = channel_announcement_payload['node_id_2'].hex()
       +        channel_id = payload['short_channel_id'].hex()
       +        node_id_1 = payload['node_id_1'].hex()
       +        node_id_2 = payload['node_id_2'].hex()
                assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
       -        msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex()
       +        msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
                capacity_sat = None
                return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
                        node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
        @@ -109,17 +109,17 @@ class Policy(Base):
            fee_base_msat               = Column(Integer, nullable=False)
            fee_proportional_millionths = Column(Integer, nullable=False)
            channel_flags               = Column(Integer, nullable=False)
       -    timestamp                   = Column(DateTime, nullable=False)
       +    timestamp                   = Column(Integer, nullable=False)
        
            @staticmethod
       -    def from_msg(channel_update_payload, start_node, short_channel_id):
       -        cltv_expiry_delta           = channel_update_payload['cltv_expiry_delta']
       -        htlc_minimum_msat           = channel_update_payload['htlc_minimum_msat']
       -        fee_base_msat               = channel_update_payload['fee_base_msat']
       -        fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
       -        channel_flags               = channel_update_payload['channel_flags']
       -        timestamp                   = channel_update_payload['timestamp']
       -        htlc_maximum_msat           = channel_update_payload.get('htlc_maximum_msat')  # optional
       +    def from_msg(payload, start_node, short_channel_id):
       +        cltv_expiry_delta           = payload['cltv_expiry_delta']
       +        htlc_minimum_msat           = payload['htlc_minimum_msat']
       +        fee_base_msat               = payload['fee_base_msat']
       +        fee_proportional_millionths = payload['fee_proportional_millionths']
       +        channel_flags               = payload['channel_flags']
       +        timestamp                   = payload['timestamp']
       +        htlc_maximum_msat           = payload.get('htlc_maximum_msat')  # optional
        
                cltv_expiry_delta           = int.from_bytes(cltv_expiry_delta, "big")
                htlc_minimum_msat           = int.from_bytes(htlc_minimum_msat, "big")
        @@ -127,7 +127,7 @@ class Policy(Base):
                fee_base_msat               = int.from_bytes(fee_base_msat, "big")
                fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
                channel_flags               = int.from_bytes(channel_flags, "big")
       -        timestamp                   = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big"))
       +        timestamp                   = int.from_bytes(timestamp, "big")
        
                return Policy(start_node=start_node,
                        short_channel_id=short_channel_id,
        @@ -150,17 +150,16 @@ class NodeInfo(Base):
            alias = Column(String(64), nullable=False)
        
            @staticmethod
       -    def from_msg(node_announcement_payload, addresses_already_parsed=False):
       -        node_id = node_announcement_payload['node_id'].hex()
       -        features = int.from_bytes(node_announcement_payload['features'], "big")
       +    def from_msg(payload):
       +        node_id = payload['node_id'].hex()
       +        features = int.from_bytes(payload['features'], "big")
                validate_features(features)
       -        if not addresses_already_parsed:
       -            addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses'])
       -        else:
       -            addresses = node_announcement_payload['addresses']
       -        alias = node_announcement_payload['alias'].rstrip(b'\x00').hex()
       -        timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big"))
       -        return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
       +        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
       +        alias = payload['alias'].rstrip(b'\x00').hex()
       +        timestamp = int.from_bytes(payload['timestamp'], "big")
       +        now = int(time.time())
       +        return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
       +            Address(host=host, port=port, node_id=node_id, last_connected_date=now) for host, port in addresses]
        
            @staticmethod
            def parse_addresses_field(addresses_field):
        @@ -207,7 +206,7 @@ class Address(Base):
            node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
            host = Column(String(256), primary_key=True)
            port = Column(Integer, primary_key=True)
       -    last_connected_date = Column(DateTime(), nullable=False)
       +    last_connected_date = Column(Integer(), nullable=False)
        
        
        
        @@ -235,12 +234,14 @@ class ChannelDB(SqlDB):
        
            @sql
            def add_recent_peer(self, peer: LNPeerAddr):
       -        addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
       -        if addr is None:
       -            addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
       +        now = int(time.time())
       +        node_id = peer.pubkey.hex()
       +        addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
       +        if addr:
       +            addr.last_connected_date = now
                else:
       -            addr.last_connected_date = datetime.datetime.now()
       -        self.DBSession.add(addr)
       +            addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
       +            self.DBSession.add(addr)
                self.DBSession.commit()
        
            @sql
        @@ -317,25 +318,31 @@ class ChannelDB(SqlDB):
                self.DBSession.commit()
        
            @sql
       -    @profiler
       +    #@profiler
            def on_channel_announcement(self, msg_payloads, trusted=False):
                if type(msg_payloads) is dict:
                    msg_payloads = [msg_payloads]
       +        new_channels = {}
                for msg in msg_payloads:
       -            short_channel_id = msg['short_channel_id']
       -            if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
       +            short_channel_id = bh2u(msg['short_channel_id'])
       +            if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
                        continue
                    if constants.net.rev_genesis_bytes() != msg['chain_hash']:
       -                #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
       +                self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
                        continue
                    try:
                        channel_info = ChannelInfo.from_msg(msg)
                    except UnknownEvenFeatureBits:
       +                self.print_error("unknown feature bits")
                        continue
                    channel_info.trusted = trusted
       +            new_channels[short_channel_id] = channel_info
       +            if not trusted:
       +                self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
       +        for channel_info in new_channels.values():
                    self.DBSession.add(channel_info)
       -            if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
                self.DBSession.commit()
       +        self.print_error('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
                self._update_counts()
                self.network.trigger_callback('ln_status')
        
        @@ -379,21 +386,13 @@ class ChannelDB(SqlDB):
                self.DBSession.commit()
        
            @sql
       -    @profiler
       +    #@profiler
            def on_node_announcement(self, msg_payloads):
                if type(msg_payloads) is dict:
                    msg_payloads = [msg_payloads]
       -        addresses = self.DBSession.query(Address).all()
       -        have_addr = {}
       -        for addr in addresses:
       -            have_addr[(addr.node_id, addr.host, addr.port)] = addr
       -
       -        nodes = self.DBSession.query(NodeInfo).all()
       -        timestamps = {}
       -        for node in nodes:
       -            no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
       -            timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S")
                old_addr = None
       +        new_nodes = {}
       +        new_addresses = {}
                for msg_payload in msg_payloads:
                    pubkey = msg_payload['node_id']
                    signature = msg_payload['signature']
        @@ -401,30 +400,33 @@ class ChannelDB(SqlDB):
                    if not ecc.verify_signature(pubkey, signature, h):
                        continue
                    try:
       -                new_node_info, addresses = NodeInfo.from_msg(msg_payload)
       +                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
                    except UnknownEvenFeatureBits:
                        continue
       -            if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
       -                continue  # ignore
       -            self.DBSession.add(new_node_info)
       -            for new_addr in addresses:
       -                key = (new_addr.node_id, new_addr.host, new_addr.port)
       -                old_addr = have_addr.get(key)
       -                if old_addr:
       -                    # since old_addr is embedded in have_addr,
       -                    # it will still live when commmit is called
       -                    old_addr.last_connected_date = new_addr.last_connected_date
       -                    del new_addr
       -                else:
       -                    self.DBSession.add(new_addr)
       -                    have_addr[key] = new_addr
       +            node_id = node_info.node_id
       +            node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
       +            if node and node.timestamp >= node_info.timestamp:
       +                continue
       +            node = new_nodes.get(node_id)
       +            if node and node.timestamp >= node_info.timestamp:
       +                continue
       +            new_nodes[node_id] = node_info
       +            for addr in node_addresses:
       +                new_addresses[(addr.node_id,addr.host,addr.port)] = addr
       +
       +        self.print_error("on_node_announcements: %d/%d"%(len(new_nodes), len(msg_payloads)))
       +        for node_info in new_nodes.values():
       +            self.DBSession.add(node_info)
       +        for new_addr in new_addresses.values():
       +            old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
       +            if old_addr:
       +                old_addr.last_connected_date = new_addr.last_connected_date
       +            else:
       +                self.DBSession.add(new_addr)
                    # TODO if this message is for a new node, and if we have no associated
                    # channels for this node, we should ignore the message and return here,
                    # to mitigate DOS. but race condition: the channels we have for this
                    # node, might be under verification in self.ca_verifier, what then?
       -        del nodes, addresses
       -        if old_addr:
       -            del old_addr
                self.DBSession.commit()
                self._update_counts()
                self.network.trigger_callback('ln_status')
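        The same batching idea drives the reworked on_node_announcement above: announcements
        are deduplicated in memory, keeping only the newest timestamp per node_id, and the
        surviving rows are written in a single pass before commit. A rough sketch of that
        dedup step, assuming payloads are already-decoded dicts rather than Electrum's
        actual gossip message objects:

            def dedup_node_announcements(payloads):
                # Keep only the newest announcement per node_id, mirroring the
                # new_nodes dict built in on_node_announcement.
                newest = {}
                for p in payloads:
                    prev = newest.get(p['node_id'])
                    if prev is None or p['timestamp'] > prev['timestamp']:
                        newest[p['node_id']] = p
                return list(newest.values())

            msgs = [
                {'node_id': '02aa', 'timestamp': 100},
                {'node_id': '02aa', 'timestamp': 200},  # newer duplicate wins
                {'node_id': '03bb', 'timestamp': 150},
            ]
            assert len(dedup_node_announcements(msgs)) == 2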