URI: 
       tkeepkey.py - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
       tkeepkey.py (20387B)
       ---
            1 from binascii import hexlify, unhexlify
            2 import traceback
            3 import sys
            4 from typing import NamedTuple, Any, Optional, Dict, Union, List, Tuple, TYPE_CHECKING
            5 
            6 from electrum.util import bfh, bh2u, UserCancelled, UserFacingException
            7 from electrum.bip32 import BIP32Node
            8 from electrum import constants
            9 from electrum.i18n import _
           10 from electrum.transaction import Transaction, PartialTransaction, PartialTxInput, PartialTxOutput
           11 from electrum.keystore import Hardware_KeyStore
           12 from electrum.plugin import Device, runs_in_hwd_thread
           13 from electrum.base_wizard import ScriptTypeNotSupported
           14 
           15 from ..hw_wallet import HW_PluginBase
           16 from ..hw_wallet.plugin import (is_any_tx_output_on_change_branch, trezor_validate_op_return_output_and_get_data,
           17                                 get_xpubs_and_der_suffixes_from_txinout)
           18 
           19 if TYPE_CHECKING:
           20     import usb1
           21     from .client import KeepKeyClient
           22 
           23 
           24 # TREZOR initialization methods
           25 TIM_NEW, TIM_RECOVER, TIM_MNEMONIC, TIM_PRIVKEY = range(0, 4)
           26 
           27 
           28 class KeepKey_KeyStore(Hardware_KeyStore):
           29     hw_type = 'keepkey'
           30     device = 'KeepKey'
           31 
           32     plugin: 'KeepKeyPlugin'
           33 
           34     def get_client(self, force_pair=True):
           35         return self.plugin.get_client(self, force_pair)
           36 
           37     def decrypt_message(self, sequence, message, password):
           38         raise UserFacingException(_('Encryption and decryption are not implemented by {}').format(self.device))
           39 
           40     @runs_in_hwd_thread
           41     def sign_message(self, sequence, message, password):
           42         client = self.get_client()
           43         address_path = self.get_derivation_prefix() + "/%d/%d"%sequence
           44         address_n = client.expand_path(address_path)
           45         msg_sig = client.sign_message(self.plugin.get_coin_name(), address_n, message)
           46         return msg_sig.signature
           47 
           48     @runs_in_hwd_thread
           49     def sign_transaction(self, tx, password):
           50         if tx.is_complete():
           51             return
           52         # previous transactions used as inputs
           53         prev_tx = {}
           54         for txin in tx.inputs():
           55             tx_hash = txin.prevout.txid.hex()
           56             if txin.utxo is None and not txin.is_segwit():
           57                 raise UserFacingException(_('Missing previous tx for legacy input.'))
           58             prev_tx[tx_hash] = txin.utxo
           59 
           60         self.plugin.sign_transaction(self, tx, prev_tx)
           61 
           62 
           63 class KeepKeyPlugin(HW_PluginBase):
           64     # Derived classes provide:
           65     #
           66     #  class-static variables: client_class, firmware_URL, handler_class,
           67     #     libraries_available, libraries_URL, minimum_firmware,
           68     #     wallet_class, ckd_public, types, HidTransport
           69 
           70     firmware_URL = 'https://www.keepkey.com'
           71     libraries_URL = 'https://github.com/keepkey/python-keepkey'
           72     minimum_firmware = (1, 0, 0)
           73     keystore_class = KeepKey_KeyStore
           74     SUPPORTED_XTYPES = ('standard', 'p2wpkh-p2sh', 'p2wpkh', 'p2wsh-p2sh', 'p2wsh')
           75 
           76     MAX_LABEL_LEN = 32
           77 
           78     def __init__(self, parent, config, name):
           79         HW_PluginBase.__init__(self, parent, config, name)
           80 
           81         try:
           82             from . import client
           83             import keepkeylib
           84             import keepkeylib.ckd_public
           85             import keepkeylib.transport_hid
           86             import keepkeylib.transport_webusb
           87             self.client_class = client.KeepKeyClient
           88             self.ckd_public = keepkeylib.ckd_public
           89             self.types = keepkeylib.client.types
           90             self.DEVICE_IDS = (keepkeylib.transport_hid.DEVICE_IDS +
           91                                keepkeylib.transport_webusb.DEVICE_IDS)
           92             # only "register" hid device id:
           93             self.device_manager().register_devices(keepkeylib.transport_hid.DEVICE_IDS, plugin=self)
           94             # for webusb transport, use custom enumerate function:
           95             self.device_manager().register_enumerate_func(self.enumerate)
           96             self.libraries_available = True
           97         except ImportError:
           98             self.libraries_available = False
           99 
          100     @runs_in_hwd_thread
          101     def enumerate(self):
          102         from keepkeylib.transport_webusb import WebUsbTransport
          103         results = []
          104         for dev in WebUsbTransport.enumerate():
          105             path = self._dev_to_str(dev)
          106             results.append(Device(path=path,
          107                                   interface_number=-1,
          108                                   id_=path,
          109                                   product_key=(dev.getVendorID(), dev.getProductID()),
          110                                   usage_page=0,
          111                                   transport_ui_string=f"webusb:{path}"))
          112         return results
          113 
          114     @staticmethod
          115     def _dev_to_str(dev: "usb1.USBDevice") -> str:
          116         return ":".join(str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList())
          117 
          118     @runs_in_hwd_thread
          119     def hid_transport(self, pair):
          120         from keepkeylib.transport_hid import HidTransport
          121         return HidTransport(pair)
          122 
          123     @runs_in_hwd_thread
          124     def webusb_transport(self, device):
          125         from keepkeylib.transport_webusb import WebUsbTransport
          126         for dev in WebUsbTransport.enumerate():
          127             if device.path == self._dev_to_str(dev):
          128                 return WebUsbTransport(dev)
          129 
          130     @runs_in_hwd_thread
          131     def _try_hid(self, device):
          132         self.logger.info("Trying to connect over USB...")
          133         if device.interface_number == 1:
          134             pair = [None, device.path]
          135         else:
          136             pair = [device.path, None]
          137 
          138         try:
          139             return self.hid_transport(pair)
          140         except BaseException as e:
          141             # see fdb810ba622dc7dbe1259cbafb5b28e19d2ab114
          142             # raise
          143             self.logger.info(f"cannot connect at {device.path} {e}")
          144             return None
          145 
          146     @runs_in_hwd_thread
          147     def _try_webusb(self, device):
          148         self.logger.info("Trying to connect over WebUSB...")
          149         try:
          150             return self.webusb_transport(device)
          151         except BaseException as e:
          152             self.logger.info(f"cannot connect at {device.path} {e}")
          153             return None
          154 
          155     @runs_in_hwd_thread
          156     def create_client(self, device, handler):
          157         if device.product_key[1] == 2:
          158             transport = self._try_webusb(device)
          159         else:
          160             transport = self._try_hid(device)
          161 
          162         if not transport:
          163             self.logger.info("cannot connect to device")
          164             return
          165 
          166         self.logger.info(f"connected to device at {device.path}")
          167 
          168         client = self.client_class(transport, handler, self)
          169 
          170         # Try a ping for device sanity
          171         try:
          172             client.ping('t')
          173         except BaseException as e:
          174             self.logger.info(f"ping failed {e}")
          175             return None
          176 
          177         if not client.atleast_version(*self.minimum_firmware):
          178             msg = (_('Outdated {} firmware for device labelled {}. Please '
          179                      'download the updated firmware from {}')
          180                    .format(self.device, client.label(), self.firmware_URL))
          181             self.logger.info(msg)
          182             if handler:
          183                 handler.show_error(msg)
          184             else:
          185                 raise UserFacingException(msg)
          186             return None
          187 
          188         return client
          189 
          190     @runs_in_hwd_thread
          191     def get_client(self, keystore, force_pair=True, *,
          192                    devices=None, allow_user_interaction=True) -> Optional['KeepKeyClient']:
          193         client = super().get_client(keystore, force_pair,
          194                                     devices=devices,
          195                                     allow_user_interaction=allow_user_interaction)
          196         # returns the client for a given keystore. can use xpub
          197         if client:
          198             client.used()
          199         return client
          200 
          201     def get_coin_name(self):
          202         return "Testnet" if constants.net.TESTNET else "Bitcoin"
          203 
          204     def initialize_device(self, device_id, wizard, handler):
          205         # Initialization method
          206         msg = _("Choose how you want to initialize your {}.\n\n"
          207                 "The first two methods are secure as no secret information "
          208                 "is entered into your computer.\n\n"
          209                 "For the last two methods you input secrets on your keyboard "
          210                 "and upload them to your {}, and so you should "
          211                 "only do those on a computer you know to be trustworthy "
          212                 "and free of malware."
          213         ).format(self.device, self.device)
          214         choices = [
          215             # Must be short as QT doesn't word-wrap radio button text
          216             (TIM_NEW, _("Let the device generate a completely new seed randomly")),
          217             (TIM_RECOVER, _("Recover from a seed you have previously written down")),
          218             (TIM_MNEMONIC, _("Upload a BIP39 mnemonic to generate the seed")),
          219             (TIM_PRIVKEY, _("Upload a master private key"))
          220         ]
          221         def f(method):
          222             import threading
          223             settings = self.request_trezor_init_settings(wizard, method, self.device)
          224             t = threading.Thread(target=self._initialize_device_safe, args=(settings, method, device_id, wizard, handler))
          225             t.setDaemon(True)
          226             t.start()
          227             exit_code = wizard.loop.exec_()
          228             if exit_code != 0:
          229                 # this method (initialize_device) was called with the expectation
          230                 # of leaving the device in an initialized state when finishing.
          231                 # signal that this is not the case:
          232                 raise UserCancelled()
          233         wizard.choice_dialog(title=_('Initialize Device'), message=msg, choices=choices, run_next=f)
          234 
          235     def _initialize_device_safe(self, settings, method, device_id, wizard, handler):
          236         exit_code = 0
          237         try:
          238             self._initialize_device(settings, method, device_id, wizard, handler)
          239         except UserCancelled:
          240             exit_code = 1
          241         except BaseException as e:
          242             self.logger.exception('')
          243             handler.show_error(repr(e))
          244             exit_code = 1
          245         finally:
          246             wizard.loop.exit(exit_code)
          247 
          248     @runs_in_hwd_thread
          249     def _initialize_device(self, settings, method, device_id, wizard, handler):
          250         item, label, pin_protection, passphrase_protection = settings
          251 
          252         language = 'english'
          253         devmgr = self.device_manager()
          254         client = devmgr.client_by_id(device_id)
          255         if not client:
          256             raise Exception(_("The device was disconnected."))
          257 
          258         if method == TIM_NEW:
          259             strength = 64 * (item + 2)  # 128, 192 or 256
          260             client.reset_device(True, strength, passphrase_protection,
          261                                 pin_protection, label, language)
          262         elif method == TIM_RECOVER:
          263             word_count = 6 * (item + 2)  # 12, 18 or 24
          264             client.step = 0
          265             client.recovery_device(word_count, passphrase_protection,
          266                                        pin_protection, label, language)
          267         elif method == TIM_MNEMONIC:
          268             pin = pin_protection  # It's the pin, not a boolean
          269             client.load_device_by_mnemonic(str(item), pin,
          270                                            passphrase_protection,
          271                                            label, language)
          272         else:
          273             pin = pin_protection  # It's the pin, not a boolean
          274             client.load_device_by_xprv(item, pin, passphrase_protection,
          275                                        label, language)
          276 
          277     def _make_node_path(self, xpub, address_n):
          278         bip32node = BIP32Node.from_xkey(xpub)
          279         node = self.types.HDNodeType(
          280             depth=bip32node.depth,
          281             fingerprint=int.from_bytes(bip32node.fingerprint, 'big'),
          282             child_num=int.from_bytes(bip32node.child_number, 'big'),
          283             chain_code=bip32node.chaincode,
          284             public_key=bip32node.eckey.get_public_key_bytes(compressed=True),
          285         )
          286         return self.types.HDNodePathType(node=node, address_n=address_n)
          287 
          288     def setup_device(self, device_info, wizard, purpose):
          289         device_id = device_info.device.id_
          290         client = self.scan_and_create_client_for_device(device_id=device_id, wizard=wizard)
          291         if not device_info.initialized:
          292             self.initialize_device(device_id, wizard, client.handler)
          293         wizard.run_task_without_blocking_gui(
          294             task=lambda: client.get_xpub("m", 'standard'))
          295         client.used()
          296         return client
          297 
          298     def get_xpub(self, device_id, derivation, xtype, wizard):
          299         if xtype not in self.SUPPORTED_XTYPES:
          300             raise ScriptTypeNotSupported(_('This type of script is not supported with {}.').format(self.device))
          301         client = self.scan_and_create_client_for_device(device_id=device_id, wizard=wizard)
          302         xpub = client.get_xpub(derivation, xtype)
          303         client.used()
          304         return xpub
          305 
          306     def get_keepkey_input_script_type(self, electrum_txin_type: str):
          307         if electrum_txin_type in ('p2wpkh', 'p2wsh'):
          308             return self.types.SPENDWITNESS
          309         if electrum_txin_type in ('p2wpkh-p2sh', 'p2wsh-p2sh'):
          310             return self.types.SPENDP2SHWITNESS
          311         if electrum_txin_type in ('p2pkh', ):
          312             return self.types.SPENDADDRESS
          313         if electrum_txin_type in ('p2sh', ):
          314             return self.types.SPENDMULTISIG
          315         raise ValueError('unexpected txin type: {}'.format(electrum_txin_type))
          316 
          317     def get_keepkey_output_script_type(self, electrum_txin_type: str):
          318         if electrum_txin_type in ('p2wpkh', 'p2wsh'):
          319             return self.types.PAYTOWITNESS
          320         if electrum_txin_type in ('p2wpkh-p2sh', 'p2wsh-p2sh'):
          321             return self.types.PAYTOP2SHWITNESS
          322         if electrum_txin_type in ('p2pkh', ):
          323             return self.types.PAYTOADDRESS
          324         if electrum_txin_type in ('p2sh', ):
          325             return self.types.PAYTOMULTISIG
          326         raise ValueError('unexpected txin type: {}'.format(electrum_txin_type))
          327 
          328     @runs_in_hwd_thread
          329     def sign_transaction(self, keystore, tx: PartialTransaction, prev_tx):
          330         self.prev_tx = prev_tx
          331         client = self.get_client(keystore)
          332         inputs = self.tx_inputs(tx, for_sig=True, keystore=keystore)
          333         outputs = self.tx_outputs(tx, keystore=keystore)
          334         signatures = client.sign_tx(self.get_coin_name(), inputs, outputs,
          335                                     lock_time=tx.locktime, version=tx.version)[0]
          336         signatures = [(bh2u(x) + '01') for x in signatures]
          337         tx.update_signatures(signatures)
          338 
          339     @runs_in_hwd_thread
          340     def show_address(self, wallet, address, keystore=None):
          341         if keystore is None:
          342             keystore = wallet.get_keystore()
          343         if not self.show_address_helper(wallet, address, keystore):
          344             return
          345         client = self.get_client(keystore)
          346         if not client.atleast_version(1, 3):
          347             keystore.handler.show_error(_("Your device firmware is too old"))
          348             return
          349         deriv_suffix = wallet.get_address_index(address)
          350         derivation = keystore.get_derivation_prefix()
          351         address_path = "%s/%d/%d"%(derivation, *deriv_suffix)
          352         address_n = client.expand_path(address_path)
          353         script_type = self.get_keepkey_input_script_type(wallet.txin_type)
          354 
          355         # prepare multisig, if available:
          356         xpubs = wallet.get_master_public_keys()
          357         if len(xpubs) > 1:
          358             pubkeys = wallet.get_public_keys(address)
          359             # sort xpubs using the order of pubkeys
          360             sorted_pairs = sorted(zip(pubkeys, xpubs))
          361             multisig = self._make_multisig(
          362                 wallet.m,
          363                 [(xpub, deriv_suffix) for pubkey, xpub in sorted_pairs])
          364         else:
          365             multisig = None
          366 
          367         client.get_address(self.get_coin_name(), address_n, True, multisig=multisig, script_type=script_type)
          368 
          369     def tx_inputs(self, tx: Transaction, *, for_sig=False, keystore: 'KeepKey_KeyStore' = None):
          370         inputs = []
          371         for txin in tx.inputs():
          372             txinputtype = self.types.TxInputType()
          373             if txin.is_coinbase_input():
          374                 prev_hash = b"\x00"*32
          375                 prev_index = 0xffffffff  # signed int -1
          376             else:
          377                 if for_sig:
          378                     assert isinstance(tx, PartialTransaction)
          379                     assert isinstance(txin, PartialTxInput)
          380                     assert keystore
          381                     if len(txin.pubkeys) > 1:
          382                         xpubs_and_deriv_suffixes = get_xpubs_and_der_suffixes_from_txinout(tx, txin)
          383                         multisig = self._make_multisig(txin.num_sig, xpubs_and_deriv_suffixes)
          384                     else:
          385                         multisig = None
          386                     script_type = self.get_keepkey_input_script_type(txin.script_type)
          387                     txinputtype = self.types.TxInputType(
          388                         script_type=script_type,
          389                         multisig=multisig)
          390                     my_pubkey, full_path = keystore.find_my_pubkey_in_txinout(txin)
          391                     if full_path:
          392                         txinputtype.address_n.extend(full_path)
          393 
          394                 prev_hash = txin.prevout.txid
          395                 prev_index = txin.prevout.out_idx
          396 
          397             if txin.value_sats() is not None:
          398                 txinputtype.amount = txin.value_sats()
          399             txinputtype.prev_hash = prev_hash
          400             txinputtype.prev_index = prev_index
          401 
          402             if txin.script_sig is not None:
          403                 txinputtype.script_sig = txin.script_sig
          404 
          405             txinputtype.sequence = txin.nsequence
          406 
          407             inputs.append(txinputtype)
          408 
          409         return inputs
          410 
          411     def _make_multisig(self, m, xpubs):
          412         if len(xpubs) == 1:
          413             return None
          414         pubkeys = [self._make_node_path(xpub, deriv) for xpub, deriv in xpubs]
          415         return self.types.MultisigRedeemScriptType(
          416             pubkeys=pubkeys,
          417             signatures=[b''] * len(pubkeys),
          418             m=m)
          419 
          420     def tx_outputs(self, tx: PartialTransaction, *, keystore: 'KeepKey_KeyStore'):
          421 
          422         def create_output_by_derivation():
          423             script_type = self.get_keepkey_output_script_type(txout.script_type)
          424             if len(txout.pubkeys) > 1:
          425                 xpubs_and_deriv_suffixes = get_xpubs_and_der_suffixes_from_txinout(tx, txout)
          426                 multisig = self._make_multisig(txout.num_sig, xpubs_and_deriv_suffixes)
          427             else:
          428                 multisig = None
          429             my_pubkey, full_path = keystore.find_my_pubkey_in_txinout(txout)
          430             assert full_path
          431             txoutputtype = self.types.TxOutputType(
          432                 multisig=multisig,
          433                 amount=txout.value,
          434                 address_n=full_path,
          435                 script_type=script_type)
          436             return txoutputtype
          437 
          438         def create_output_by_address():
          439             txoutputtype = self.types.TxOutputType()
          440             txoutputtype.amount = txout.value
          441             if address:
          442                 txoutputtype.script_type = self.types.PAYTOADDRESS
          443                 txoutputtype.address = address
          444             else:
          445                 txoutputtype.script_type = self.types.PAYTOOPRETURN
          446                 txoutputtype.op_return_data = trezor_validate_op_return_output_and_get_data(txout)
          447             return txoutputtype
          448 
          449         outputs = []
          450         has_change = False
          451         any_output_on_change_branch = is_any_tx_output_on_change_branch(tx)
          452 
          453         for txout in tx.outputs():
          454             address = txout.address
          455             use_create_by_derivation = False
          456 
          457             if txout.is_mine and not has_change:
          458                 # prioritise hiding outputs on the 'change' branch from user
          459                 # because no more than one change address allowed
          460                 if txout.is_change == any_output_on_change_branch:
          461                     use_create_by_derivation = True
          462                     has_change = True
          463 
          464             if use_create_by_derivation:
          465                 txoutputtype = create_output_by_derivation()
          466             else:
          467                 txoutputtype = create_output_by_address()
          468             outputs.append(txoutputtype)
          469 
          470         return outputs
          471 
          472     def electrum_tx_to_txtype(self, tx: Optional[Transaction]):
          473         t = self.types.TransactionType()
          474         if tx is None:
          475             # probably for segwit input and we don't need this prev txn
          476             return t
          477         tx.deserialize()
          478         t.version = tx.version
          479         t.lock_time = tx.locktime
          480         inputs = self.tx_inputs(tx)
          481         t.inputs.extend(inputs)
          482         for out in tx.outputs():
          483             o = t.bin_outputs.add()
          484             o.amount = out.value
          485             o.script_pubkey = out.scriptpubkey
          486         return t
          487 
          488     # This function is called from the TREZOR libraries (via tx_api)
          489     def get_tx(self, tx_hash):
          490         tx = self.prev_tx[tx_hash]
          491         return self.electrum_tx_to_txtype(tx)