URI: 
       tcreate parent class for sql databases - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit d8e9a9a49e38fb9353fb80b3d72debc0ccd711ce
   DIR parent b861e2e955c4a790d8e2b4ce262b894a67c3b470
  HTML Author: ThomasV <thomasv@electrum.org>
       Date:   Wed,  6 Mar 2019 09:56:22 +0100
       
       create parent class for sql databases
       
       Diffstat:
         M electrum/lnrouter.py                |      51 ++++++++-----------------------
         M electrum/lnwatcher.py               |      52 ++++++-------------------------
         A electrum/sql_db.py                  |      51 +++++++++++++++++++++++++++++++
       
       3 files changed, 72 insertions(+), 82 deletions(-)
       ---
   DIR diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -35,13 +35,11 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
        import binascii
        import base64
        
       -from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
       -from sqlalchemy.pool import StaticPool
       -from sqlalchemy.orm import sessionmaker
       +from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
        from sqlalchemy.orm.query import Query
        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy.sql import not_, or_
       -from sqlalchemy.orm import scoped_session
       +from .sql_db import SqlDB, sql
        
        from . import constants
        from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
       t@@ -212,50 +210,25 @@ class Address(Base):
            last_connected_date = Column(DateTime(), nullable=False)
        
        
       -class ChannelDB(PrintError):
       +
       +
       +class ChannelDB(SqlDB):
        
            NUM_MAX_RECENT_PEERS = 20
        
            def __init__(self, network: 'Network'):
       -        self.network = network
       +        path = os.path.join(get_headers_dir(network.config), 'channel_db')
       +        super().__init__(network, path, Base)
       +        print(Base)
                self.num_nodes = 0
                self.num_channels = 0
       -        self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
                self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
                self.ca_verifier = LNChannelVerifier(network, self)
       -        self.db_requests = queue.Queue()
       -        threading.Thread(target=self.sql_thread).start()
       -
       -    def sql_thread(self):
       -        self.sql_thread = threading.currentThread()
       -        engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
       -        DBSession = sessionmaker(bind=engine, autoflush=False)
       -        self.DBSession = DBSession()
       -        if not os.path.exists(self.path):
       -            Base.metadata.create_all(engine)
       +        self.update_counts()
       +
       +    @sql
       +    def update_counts(self):
                self._update_counts()
       -        while self.network.asyncio_loop.is_running():
       -            try:
       -                future, func, args, kwargs = self.db_requests.get(timeout=0.1)
       -            except queue.Empty:
       -                continue
       -            try:
       -                result = func(self, *args, **kwargs)
       -            except BaseException as e:
       -                future.set_exception(e)
       -                continue
       -            future.set_result(result)
       -        # write
       -        self.DBSession.commit()
       -        self.print_error("SQL thread terminated")
       -
       -    def sql(func):
       -        def wrapper(self, *args, **kwargs):
       -            assert threading.currentThread() != self.sql_thread
       -            f = concurrent.futures.Future()
       -            self.db_requests.put((f, func, args, kwargs))
       -            return f.result(timeout=10)
       -        return wrapper
        
            def _update_counts(self):
                self.num_channels = self.DBSession.query(ChannelInfo).count()
   DIR diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
       t@@ -11,9 +11,14 @@ from collections import defaultdict
        import asyncio
        from enum import IntEnum, auto
        from typing import NamedTuple, Dict
       -
        import jsonrpclib
        
       +from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
       +from sqlalchemy.orm.query import Query
       +from sqlalchemy.ext.declarative import declarative_base
       +from sqlalchemy.sql import not_, or_
       +from .sql_db import SqlDB, sql
       +
        from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
        from . import wallet
        from .storage import WalletStorage
       t@@ -37,14 +42,6 @@ class TxMinedDepth(IntEnum):
            FREE = auto()
        
        
       -from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
       -from sqlalchemy.pool import StaticPool
       -from sqlalchemy.orm import sessionmaker
       -from sqlalchemy.orm.query import Query
       -from sqlalchemy.ext.declarative import declarative_base
       -from sqlalchemy.sql import not_, or_
       -from sqlalchemy.orm import scoped_session
       -
        Base = declarative_base()
        
        class SweepTx(Base):
       t@@ -60,42 +57,11 @@ class ChannelInfo(Base):
            outpoint = Column(String(34))
        
        
       -class SweepStore(PrintError):
        
       -    def __init__(self, path, network):
       -        PrintError.__init__(self)
       -        self.path = path
       -        self.network = network
       -        self.db_requests = queue.Queue()
       -        threading.Thread(target=self.sql_thread).start()
       -
       -    def sql_thread(self):
       -        engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
       -        DBSession = sessionmaker(bind=engine, autoflush=False)
       -        self.DBSession = DBSession()
       -        if not os.path.exists(self.path):
       -            Base.metadata.create_all(engine)
       -        while self.network.asyncio_loop.is_running():
       -            try:
       -                future, func, args, kwargs = self.db_requests.get(timeout=0.1)
       -            except queue.Empty:
       -                continue
       -            try:
       -                result = func(self, *args, **kwargs)
       -            except BaseException as e:
       -                future.set_exception(e)
       -                continue
       -            future.set_result(result)
       -        # write
       -        self.DBSession.commit()
       -        self.print_error("SQL thread terminated")
       +class SweepStore(SqlDB):
        
       -    def sql(func):
       -        def wrapper(self, *args, **kwargs):
       -            f = concurrent.futures.Future()
       -            self.db_requests.put((f, func, args, kwargs))
       -            return f.result(timeout=10)
       -        return wrapper
       +    def __init__(self, path, network):
       +        super().__init__(network, path, Base)
        
            @sql
            def get_sweep_tx(self, funding_outpoint, prev_txid):
   DIR diff --git a/electrum/sql_db.py b/electrum/sql_db.py
       t@@ -0,0 +1,51 @@
       +import os
       +import concurrent
       +import queue
       +import threading
       +
       +from sqlalchemy import create_engine
       +from sqlalchemy.pool import StaticPool
       +from sqlalchemy.orm import sessionmaker
       +
       +from .util import PrintError
       +
       +
       +def sql(func):
       +    """wrapper for sql methods"""
       +    def wrapper(self, *args, **kwargs):
       +        assert threading.currentThread() != self.sql_thread
       +        f = concurrent.futures.Future()
       +        self.db_requests.put((f, func, args, kwargs))
       +        return f.result(timeout=10)
       +    return wrapper
       +
       +class SqlDB(PrintError):
       +    
       +    def __init__(self, network, path, base):
       +        self.base = base
       +        self.network = network
       +        self.path = path
       +        self.db_requests = queue.Queue()
       +        self.sql_thread = threading.Thread(target=self.run_sql)
       +        self.sql_thread.start()
       +
       +    def run_sql(self):
       +        engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
       +        DBSession = sessionmaker(bind=engine, autoflush=False)
       +        self.DBSession = DBSession()
       +        if not os.path.exists(self.path):
       +            self.base.metadata.create_all(engine)
       +        while self.network.asyncio_loop.is_running():
       +            try:
       +                future, func, args, kwargs = self.db_requests.get(timeout=0.1)
       +            except queue.Empty:
       +                continue
       +            try:
       +                result = func(self, *args, **kwargs)
       +            except BaseException as e:
       +                future.set_exception(e)
       +                continue
       +            future.set_result(result)
       +        # write
       +        self.DBSession.commit()
       +        self.print_error("SQL thread terminated")