URI: 
       tget rid of sql_alchemy - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit 238f3c949ca2d23cbe4cc4e5a1ccd15371050663
   DIR parent 0eab1692d6a64eacdc13f2b3568a1453ce1c3761
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Thu, 27 Jun 2019 09:03:34 +0200
       
       get rid of sql_alchemy
       
       Diffstat:
         M contrib/requirements/requirements.… |       1 -
         M electrum/channel_db.py              |     174 +++++++++++++++----------------
         M electrum/lnwatcher.py               |     100 +++++++++++++++++++-------------
         M electrum/sql_db.py                  |      25 +++++++------------------
       
       4 files changed, 153 insertions(+), 147 deletions(-)
       ---
   DIR diff --git a/contrib/requirements/requirements.txt b/contrib/requirements/requirements.txt
       t@@ -11,4 +11,3 @@ aiohttp_socks
        certifi
        bitstring
        pycryptodomex>=3.7
       -sqlalchemy>=1.3.0b3
   DIR diff --git a/electrum/channel_db.py b/electrum/channel_db.py
       t@@ -36,10 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
        import binascii
        import base64
        
       -from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
       -from sqlalchemy.orm.query import Query
       -from sqlalchemy.ext.declarative import declarative_base
       -from sqlalchemy.sql import not_, or_
        
        from .sql_db import SqlDB, sql
        from . import constants
       t@@ -66,7 +62,6 @@ def validate_features(features : int):
                if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
                    raise UnknownEvenFeatureBits()
        
       -Base = declarative_base()
        
        FLAG_DISABLE   = 1 << 1
        FLAG_DIRECTION = 1 << 0
       t@@ -193,57 +188,45 @@ class Address(NamedTuple):
            port: int
            last_connected_date: int
        
       -
       -class ChannelInfoBase(Base):
       -    __tablename__ = 'channel_info'
       -    short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
       -    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
       -    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
       -    capacity_sat = Column(Integer)
       -    def to_nametuple(self):
       -        return ChannelInfo(
       -            short_channel_id=self.short_channel_id,
       -            node1_id=self.node1_id,
       -            node2_id=self.node2_id,
       -            capacity_sat=self.capacity_sat
       -        )
       -
       -class PolicyBase(Base):
       -    __tablename__ = 'policy'
       -    key                         = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
       -    cltv_expiry_delta           = Column(Integer, nullable=False)
       -    htlc_minimum_msat           = Column(Integer, nullable=False)
       -    htlc_maximum_msat           = Column(Integer)
       -    fee_base_msat               = Column(Integer, nullable=False)
       -    fee_proportional_millionths = Column(Integer, nullable=False)
       -    channel_flags               = Column(Integer, nullable=False)
       -    timestamp                   = Column(Integer, nullable=False)
       -
       -    def to_nametuple(self):
       -        return Policy(
       -            key=self.key,
       -            cltv_expiry_delta=self.cltv_expiry_delta,
       -            htlc_minimum_msat=self.htlc_minimum_msat,
       -            htlc_maximum_msat=self.htlc_maximum_msat,
       -            fee_base_msat= self.fee_base_msat,
       -            fee_proportional_millionths = self.fee_proportional_millionths,
       -            channel_flags=self.channel_flags,
       -            timestamp=self.timestamp
       -        )
       -
       -class NodeInfoBase(Base):
       -    __tablename__ = 'node_info'
       -    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
       -    features = Column(Integer, nullable=False)
       -    timestamp = Column(Integer, nullable=False)
       -    alias = Column(String(64), nullable=False)
       -
       -class AddressBase(Base):
       -    __tablename__ = 'address'
       -    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
       -    host = Column(String(256))
       -    port = Column(Integer)
       -    last_connected_date = Column(Integer(), nullable=True)
       +create_channel_info = """
       +CREATE TABLE IF NOT EXISTS channel_info (
       +short_channel_id VARCHAR(64),
       +node1_id VARCHAR(66),
       +node2_id VARCHAR(66),
       +capacity_sat INTEGER,
       +PRIMARY KEY(short_channel_id)
       +)"""
       +
       +create_policy = """
       +CREATE TABLE IF NOT EXISTS policy (
       +key VARCHAR(66),
       +cltv_expiry_delta INTEGER NOT NULL,
       +htlc_minimum_msat INTEGER NOT NULL,
       +htlc_maximum_msat INTEGER,
       +fee_base_msat INTEGER NOT NULL,
       +fee_proportional_millionths INTEGER NOT NULL,
       +channel_flags INTEGER NOT NULL,
       +timestamp INTEGER NOT NULL,
       +PRIMARY KEY(key)
       +)"""
       +
       +create_address = """
       +CREATE TABLE IF NOT EXISTS address (
       +node_id VARCHAR(66),
       +host STRING(256),
       +port INTEGER NOT NULL,
       +timestamp INTEGER,
       +PRIMARY KEY(node_id, host, port)
       +)"""
       +
       +create_node_info = """
       +CREATE TABLE IF NOT EXISTS node_info (
       +node_id VARCHAR(66),
       +features INTEGER NOT NULL,
       +timestamp INTEGER NOT NULL,
       +alias STRING(64),
       +PRIMARY KEY(node_id)
       +)"""
        
        
        class ChannelDB(SqlDB):
       t@@ -252,7 +235,7 @@ class ChannelDB(SqlDB):
        
            def __init__(self, network: 'Network'):
                path = os.path.join(get_headers_dir(network.config), 'channel_db')
       -        super().__init__(network, path, Base, commit_interval=100)
       +        super().__init__(network, path, commit_interval=100)
                self.num_nodes = 0
                self.num_channels = 0
                self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
       t@@ -276,16 +259,7 @@ class ChannelDB(SqlDB):
                now = int(time.time())
                node_id = peer.pubkey
                self._addresses[node_id].add((peer.host, peer.port, now))
       -        self.save_address(node_id, peer, now)
       -
       -    @sql
       -    def save_address(self, node_id, peer, now):
       -        addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
       -        if addr:
       -            addr.last_connected_date = now
       -        else:
       -            addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
       -            self.DBSession.add(addr)
       +        self.save_node_address(node_id, peer, now)
        
            def get_200_randomly_sorted_nodes_not_in(self, node_ids):
                unshuffled = set(self._nodes.keys()) - node_ids
       t@@ -394,17 +368,47 @@ class ChannelDB(SqlDB):
                orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
                assert len(good) == 1
        
       +    def create_database(self):
       +        c = self.conn.cursor()
       +        c.execute(create_node_info)
       +        c.execute(create_address)
       +        c.execute(create_policy)
       +        c.execute(create_channel_info)
       +        self.conn.commit()
       +
            @sql
            def save_policy(self, policy):
       -        self.DBSession.execute(PolicyBase.__table__.insert().values(policy))
       +        c = self.conn.cursor()
       +        c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy))
        
            @sql
            def delete_policy(self, short_channel_id, node_id):
       -        self.DBSession.execute(PolicyBase.__table__.delete().values(policy))
       +        c = self.conn.cursor()
       +        c.execute("""DELETE FROM policy WHERE key=?""", (key,))
        
            @sql
            def save_channel(self, channel_info):
       -        self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info))
       +        c = self.conn.cursor()
       +        c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info))
       +
       +    @sql
       +    def save_node(self, node_info):
       +        c = self.conn.cursor()
       +        c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info))
       +
       +    @sql
       +    def save_node_address(self, node_id, peer, now):
       +        c = self.conn.cursor()
       +        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now))
       +
       +    @sql
       +    def save_node_addresses(self, node_id, node_addresses):
       +        c = self.conn.cursor()
       +        for addr in node_addresses:
       +            c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port))
       +            r = c.fetchall()
       +            if r == []:
       +                c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0))
        
            def verify_channel_update(self, payload):
                short_channel_id = payload['short_channel_id']
       t@@ -418,7 +422,6 @@ class ChannelDB(SqlDB):
                    msg_payloads = [msg_payloads]
                old_addr = None
                new_nodes = {}
       -        new_addresses = {}
                for msg_payload in msg_payloads:
                    try:
                        node_info, node_addresses = NodeInfo.from_msg(msg_payload)
       t@@ -445,17 +448,6 @@ class ChannelDB(SqlDB):
                self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
                self.update_counts()
        
       -    @sql
       -    def save_node_addresses(self, node_if, node_addresses):
       -        for new_addr in node_addresses:
       -            old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
       -            if not old_addr:
       -                self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
       -
       -    @sql
       -    def save_node(self, node_info):
       -        self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
       -
            def get_routing_policy_for_channel(self, start_node_id: bytes,
                                               short_channel_id: bytes) -> Optional[bytes]:
                if not start_node_id or not short_channel_id: return None
       t@@ -506,12 +498,18 @@ class ChannelDB(SqlDB):
            @sql
            @profiler
            def load_data(self):
       -        for x in self.DBSession.query(AddressBase).all():
       -            self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0)))
       -        for x in self.DBSession.query(ChannelInfoBase).all():
       -            self._channels[x.short_channel_id] = x.to_nametuple()
       -        for x in self.DBSession.query(PolicyBase).filter_by().all():
       -            p = x.to_nametuple()
       +        c = self.conn.cursor()
       +        c.execute("""SELECT * FROM address""")
       +        for x in c:
       +            node_id, host, port, timestamp = x
       +            self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
       +        c.execute("""SELECT * FROM channel_info""")
       +        for x in c:
       +            ci = ChannelInfo(*x)
       +            self._channels[ci.short_channel_id] = ci
       +        c.execute("""SELECT * FROM policy""")
       +        for x in c:
       +            p = Policy(*x)
                    self._policies[(p.start_node, p.short_channel_id)] = p
                for channel_info in self._channels.values():
                    self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
   DIR diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
       t@@ -13,12 +13,7 @@ from enum import IntEnum, auto
        from typing import NamedTuple, Dict
        import jsonrpclib
        
       -from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
       -from sqlalchemy.orm.query import Query
       -from sqlalchemy.ext.declarative import declarative_base
       -from sqlalchemy.sql import not_, or_
        from .sql_db import SqlDB, sql
       -
        from .util import bh2u, bfh, log_exceptions, ignore_exceptions
        from . import wallet
        from .storage import WalletStorage
       t@@ -42,80 +37,105 @@ class TxMinedDepth(IntEnum):
            FREE = auto()
        
        
       -Base = declarative_base()
       -
       -class SweepTx(Base):
       -    __tablename__ = 'sweep_txs'
       -    funding_outpoint = Column(String(34), primary_key=True)
       -    index = Column(Integer(), primary_key=True)
       -    prevout = Column(String(34))
       -    tx = Column(String())
       -
       -class ChannelInfo(Base):
       -    __tablename__ = 'channel_info'
       -    outpoint = Column(String(34), primary_key=True)
       -    address = Column(String(32))
       +create_sweep_txs="""
       +CREATE TABLE IF NOT EXISTS sweep_txs (
       +funding_outpoint VARCHAR(34) NOT NULL,
       +"index" INTEGER NOT NULL,
       +prevout VARCHAR(34),
       +tx VARCHAR,
       +PRIMARY KEY(funding_outpoint, "index")
       +)"""
        
       +create_channel_info="""
       +CREATE TABLE IF NOT EXISTS channel_info (
       +outpoint VARCHAR(34) NOT NULL,
       +address VARCHAR(32),
       +PRIMARY KEY(outpoint)
       +)"""
        
        
        class SweepStore(SqlDB):
        
            def __init__(self, path, network):
       -        super().__init__(network, path, Base)
       +        super().__init__(network, path)
       +
       +    def create_database(self):
       +        c = self.conn.cursor()
       +        c.execute(create_channel_info)
       +        c.execute(create_sweep_txs)
       +        self.conn.commit()
        
            @sql
            def get_sweep_tx(self, funding_outpoint, prevout):
       -        return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prevout==prevout).all()]
       +        c = self.conn.cursor()
       +        c.execute("SELECT tx FROM sweep_txs WHERE funding_outpoint=? AND prevout=?", (funding_outpoint, prevout))
       +        return [Transaction(bh2u(r[0])) for r in c.fetchall()]
        
            @sql
            def get_tx_by_index(self, funding_outpoint, index):
       -        r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
       -        return str(r.prevout), bh2u(r.tx)
       +        c = self.conn.cursor()
       +        c.execute("""SELECT prevout, tx FROM sweep_txs WHERE funding_outpoint=? AND "index"=?""", (funding_outpoint, index))
       +        r = c.fetchone()[0]
       +        return str(r[0]), bh2u(r[1])
        
            @sql
            def list_sweep_tx(self):
       -        return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all())
       +        c = self.conn.cursor()
       +        c.execute("SELECT funding_outpoint FROM sweep_txs")
       +        return set([r[0] for r in c.fetchall()])
        
            @sql
            def add_sweep_tx(self, funding_outpoint, prevout, tx):
       -        n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
       -        self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prevout=prevout, tx=bfh(tx)))
       -        self.DBSession.commit()
       +        c = self.conn.cursor()
       +        c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
       +        n = int(c.fetchone()[0])
       +        c.execute("""INSERT INTO sweep_txs (funding_outpoint, "index", prevout, tx) VALUES (?,?,?,?)""", (funding_outpoint, n, prevout, bfh(str(tx))))
       +        self.conn.commit()
        
            @sql
            def get_num_tx(self, funding_outpoint):
       -        return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count())
       +        c = self.conn.cursor()
       +        c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
       +        return int(c.fetchone()[0])
        
            @sql
            def remove_sweep_tx(self, funding_outpoint):
       -        r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
       -        for x in r:
       -            self.DBSession.delete(x)
       -        self.DBSession.commit()
       +        c = self.conn.cursor()
       +        c.execute("DELETE FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
       +        self.conn.commit()
        
            @sql
            def add_channel(self, outpoint, address):
       -        self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
       -        self.DBSession.commit()
       +        c = self.conn.cursor()
       +        c.execute("INSERT INTO channel_info (address, outpoint) VALUES (?,?)", (address, outpoint))
       +        self.conn.commit()
        
            @sql
            def remove_channel(self, outpoint):
       -        v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
       -        self.DBSession.delete(v)
       -        self.DBSession.commit()
       +        c = self.conn.cursor()
       +        c.execute("DELETE FROM channel_info WHERE outpoint=?", (outpoint,))
       +        self.conn.commit()
        
            @sql
            def has_channel(self, outpoint):
       -        return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none())
       +        c = self.conn.cursor()
       +        c.execute("SELECT * FROM channel_info WHERE outpoint=?", (outpoint,))
       +        r = c.fetchone()
       +        return r is not None
        
            @sql
            def get_address(self, outpoint):
       -        r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
       -        return str(r.address) if r else None
       +        c = self.conn.cursor()
       +        c.execute("SELECT address FROM channel_info WHERE outpoint=?", (outpoint,))
       +        r = c.fetchone()
       +        return r[0] if r else None
        
            @sql
            def list_channel_info(self):
       -        return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()]
       +        c = self.conn.cursor()
       +        c.execute("SELECT address, outpoint FROM channel_info")
       +        return [(r[0], r[1]) for r in c.fetchall()]
       +
        
        
        class LNWatcher(AddressSynchronizer):
   DIR diff --git a/electrum/sql_db.py b/electrum/sql_db.py
       t@@ -3,18 +3,11 @@ import concurrent
        import queue
        import threading
        import asyncio
       -
       -from sqlalchemy import create_engine
       -from sqlalchemy.pool import StaticPool
       -from sqlalchemy.orm import sessionmaker
       +import sqlite3
        
        from .logging import Logger
        
        
       -# https://stackoverflow.com/questions/26971050/sqlalchemy-sqlite-too-many-sql-variables
       -SQLITE_LIMIT_VARIABLE_NUMBER = 999
       -
       -
        def sql(func):
            """wrapper for sql methods"""
            def wrapper(self, *args, **kwargs):
       t@@ -26,9 +19,8 @@ def sql(func):
        
        class SqlDB(Logger):
            
       -    def __init__(self, network, path, base, commit_interval=None):
       +    def __init__(self, network, path, commit_interval=None):
                Logger.__init__(self)
       -        self.base = base
                self.network = network
                self.path = path
                self.commit_interval = commit_interval
       t@@ -37,13 +29,10 @@ class SqlDB(Logger):
                self.sql_thread.start()
        
            def run_sql(self):
       -        #return
                self.logger.info("SQL thread started")
       -        engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
       -        DBSession = sessionmaker(bind=engine, autoflush=False)
       -        if not os.path.exists(self.path):
       -            self.base.metadata.create_all(engine)
       -        self.DBSession = DBSession()
       +        self.conn = sqlite3.connect(self.path)
       +        self.logger.info("Creating database")
       +        self.create_database()
                i = 0
                while self.network.asyncio_loop.is_running():
                    try:
       t@@ -62,7 +51,7 @@ class SqlDB(Logger):
                    if self.commit_interval:
                        i = (i + 1) % self.commit_interval
                        if i == 0:
       -                    self.DBSession.commit()
       +                    self.conn.commit()
                # write
       -        self.DBSession.commit()
       +        self.conn.commit()
                self.logger.info("SQL thread terminated")