URI: 
       trsakey.py - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
       trsakey.py (16814B)
       ---
            1 #!/usr/bin/env python
            2 #
            3 # Electrum - lightweight Bitcoin client
            4 # Copyright (C) 2015 Thomas Voegtlin
            5 #
            6 # Permission is hereby granted, free of charge, to any person
            7 # obtaining a copy of this software and associated documentation files
            8 # (the "Software"), to deal in the Software without restriction,
            9 # including without limitation the rights to use, copy, modify, merge,
           10 # publish, distribute, sublicense, and/or sell copies of the Software,
           11 # and to permit persons to whom the Software is furnished to do so,
           12 # subject to the following conditions:
           13 #
           14 # The above copyright notice and this permission notice shall be
           15 # included in all copies or substantial portions of the Software.
           16 #
           17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
           18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
           19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
           20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
           21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
           22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
           23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
           24 # SOFTWARE.
           25 
           26 # This module uses functions from TLSLite (public domain)
           27 #
           28 # TLSLite Authors:
           29 #   Trevor Perrin
           30 #   Martin von Loewis - python 3 port
           31 #   Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
           32 #
           33 
           34 """Pure-Python RSA implementation."""
           35 
           36 import os
           37 import math
           38 import hashlib
           39 
           40 
           41 def SHA1(x):
           42     return hashlib.sha1(x).digest()
           43 
           44 
           45 # **************************************************************************
           46 # PRNG Functions
           47 # **************************************************************************
           48 
           49 # Check that os.urandom works
           50 import zlib
           51 length = len(zlib.compress(os.urandom(1000)))
           52 assert(length > 900)
           53 
           54 def getRandomBytes(howMany):
           55     b = bytearray(os.urandom(howMany))
           56     assert(len(b) == howMany)
           57     return b
           58 
           59 prngName = "os.urandom"
           60 
           61 
           62 # **************************************************************************
           63 # Converter Functions
           64 # **************************************************************************
           65 
           66 def bytesToNumber(b):
           67     total = 0
           68     multiplier = 1
           69     for count in range(len(b)-1, -1, -1):
           70         byte = b[count]
           71         total += multiplier * byte
           72         multiplier *= 256
           73     return total
           74 
           75 def numberToByteArray(n, howManyBytes=None):
           76     """Convert an integer into a bytearray, zero-pad to howManyBytes.
           77 
           78     The returned bytearray may be smaller than howManyBytes, but will
           79     not be larger.  The returned bytearray will contain a big-endian
           80     encoding of the input integer (n).
           81     """
           82     if howManyBytes == None:
           83         howManyBytes = numBytes(n)
           84     b = bytearray(howManyBytes)
           85     for count in range(howManyBytes-1, -1, -1):
           86         b[count] = int(n % 256)
           87         n >>= 8
           88     return b
           89 
           90 def mpiToNumber(mpi): #mpi is an openssl-format bignum string
           91     if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number
           92         raise AssertionError()
           93     b = bytearray(mpi[4:])
           94     return bytesToNumber(b)
           95 
           96 def numberToMPI(n):
           97     b = numberToByteArray(n)
           98     ext = 0
           99     #If the high-order bit is going to be set,
          100     #add an extra byte of zeros
          101     if (numBits(n) & 0x7)==0:
          102         ext = 1
          103     length = numBytes(n) + ext
          104     b = bytearray(4+ext) + b
          105     b[0] = (length >> 24) & 0xFF
          106     b[1] = (length >> 16) & 0xFF
          107     b[2] = (length >> 8) & 0xFF
          108     b[3] = length & 0xFF
          109     return bytes(b)
          110 
          111 
          112 # **************************************************************************
          113 # Misc. Utility Functions
          114 # **************************************************************************
          115 
          116 def numBits(n):
          117     if n==0:
          118         return 0
          119     s = "%x" % n
          120     return ((len(s)-1)*4) + \
          121     {'0':0, '1':1, '2':2, '3':2,
          122      '4':3, '5':3, '6':3, '7':3,
          123      '8':4, '9':4, 'a':4, 'b':4,
          124      'c':4, 'd':4, 'e':4, 'f':4,
          125      }[s[0]]
          126 
          127 def numBytes(n):
          128     if n==0:
          129         return 0
          130     bits = numBits(n)
          131     return int(math.ceil(bits / 8.0))
          132 
          133 # **************************************************************************
          134 # Big Number Math
          135 # **************************************************************************
          136 
          137 def getRandomNumber(low, high):
          138     if low >= high:
          139         raise AssertionError()
          140     howManyBits = numBits(high)
          141     howManyBytes = numBytes(high)
          142     lastBits = howManyBits % 8
          143     while 1:
          144         bytes = getRandomBytes(howManyBytes)
          145         if lastBits:
          146             bytes[0] = bytes[0] % (1 << lastBits)
          147         n = bytesToNumber(bytes)
          148         if n >= low and n < high:
          149             return n
          150 
          151 def gcd(a,b):
          152     a, b = max(a,b), min(a,b)
          153     while b:
          154         a, b = b, a % b
          155     return a
          156 
          157 def lcm(a, b):
          158     return (a * b) // gcd(a, b)
          159 
          160 #Returns inverse of a mod b, zero if none
          161 #Uses Extended Euclidean Algorithm
          162 def invMod(a, b):
          163     c, d = a, b
          164     uc, ud = 1, 0
          165     while c != 0:
          166         q = d // c
          167         c, d = d-(q*c), c
          168         uc, ud = ud - (q * uc), uc
          169     if d == 1:
          170         return ud % b
          171     return 0
          172 
          173 
          174 def powMod(base, power, modulus):
          175     if power < 0:
          176         result = pow(base, power*-1, modulus)
          177         result = invMod(result, modulus)
          178         return result
          179     else:
          180         return pow(base, power, modulus)
          181 
          182 #Pre-calculate a sieve of the ~100 primes < 1000:
          183 def makeSieve(n):
          184     sieve = list(range(n))
          185     for count in range(2, int(math.sqrt(n))+1):
          186         if sieve[count] == 0:
          187             continue
          188         x = sieve[count] * 2
          189         while x < len(sieve):
          190             sieve[x] = 0
          191             x += sieve[count]
          192     sieve = [x for x in sieve[2:] if x]
          193     return sieve
          194 
          195 sieve = makeSieve(1000)
          196 
          197 def isPrime(n, iterations=5, display=False):
          198     #Trial division with sieve
          199     for x in sieve:
          200         if x >= n: return True
          201         if n % x == 0: return False
          202     #Passed trial division, proceed to Rabin-Miller
          203     #Rabin-Miller implemented per Ferguson & Schneier
          204     #Compute s, t for Rabin-Miller
          205     if display: print("*", end=' ')
          206     s, t = n-1, 0
          207     while s % 2 == 0:
          208         s, t = s//2, t+1
          209     #Repeat Rabin-Miller x times
          210     a = 2 #Use 2 as a base for first iteration speedup, per HAC
          211     for count in range(iterations):
          212         v = powMod(a, s, n)
          213         if v==1:
          214             continue
          215         i = 0
          216         while v != n-1:
          217             if i == t-1:
          218                 return False
          219             else:
          220                 v, i = powMod(v, 2, n), i+1
          221         a = getRandomNumber(2, n)
          222     return True
          223 
          224 def getRandomPrime(bits, display=False):
          225     if bits < 10:
          226         raise AssertionError()
          227     #The 1.5 ensures the 2 MSBs are set
          228     #Thus, when used for p,q in RSA, n will have its MSB set
          229     #
          230     #Since 30 is lcm(2,3,5), we'll set our test numbers to
          231     #29 % 30 and keep them there
          232     low = ((2 ** (bits-1)) * 3) // 2
          233     high = 2 ** bits - 30
          234     p = getRandomNumber(low, high)
          235     p += 29 - (p % 30)
          236     while 1:
          237         if display: print(".", end=' ')
          238         p += 30
          239         if p >= high:
          240             p = getRandomNumber(low, high)
          241             p += 29 - (p % 30)
          242         if isPrime(p, display=display):
          243             return p
          244 
          245 #Unused at the moment...
          246 def getRandomSafePrime(bits, display=False):
          247     if bits < 10:
          248         raise AssertionError()
          249     #The 1.5 ensures the 2 MSBs are set
          250     #Thus, when used for p,q in RSA, n will have its MSB set
          251     #
          252     #Since 30 is lcm(2,3,5), we'll set our test numbers to
          253     #29 % 30 and keep them there
          254     low = (2 ** (bits-2)) * 3//2
          255     high = (2 ** (bits-1)) - 30
          256     q = getRandomNumber(low, high)
          257     q += 29 - (q % 30)
          258     while 1:
          259         if display: print(".", end=' ')
          260         q += 30
          261         if (q >= high):
          262             q = getRandomNumber(low, high)
          263             q += 29 - (q % 30)
          264         #Ideas from Tom Wu's SRP code
          265         #Do trial division on p and q before Rabin-Miller
          266         if isPrime(q, 0, display=display):
          267             p = (2 * q) + 1
          268             if isPrime(p, display=display):
          269                 if isPrime(q, display=display):
          270                     return p
          271 
          272 
          273 class RSAKey(object):
          274 
          275     def __init__(self, n=0, e=0, d=0, p=0, q=0, dP=0, dQ=0, qInv=0):
          276         if (n and not e) or (e and not n):
          277             raise AssertionError()
          278         self.n = n
          279         self.e = e
          280         self.d = d
          281         self.p = p
          282         self.q = q
          283         self.dP = dP
          284         self.dQ = dQ
          285         self.qInv = qInv
          286         self.blinder = 0
          287         self.unblinder = 0
          288 
          289     def __len__(self):
          290         """Return the length of this key in bits.
          291 
          292         @rtype: int
          293         """
          294         return numBits(self.n)
          295 
          296     def hasPrivateKey(self):
          297         return self.d != 0
          298 
          299     def hashAndSign(self, bytes):
          300         """Hash and sign the passed-in bytes.
          301 
          302         This requires the key to have a private component.  It performs
          303         a PKCS1-SHA1 signature on the passed-in data.
          304 
          305         @type bytes: str or L{bytearray} of unsigned bytes
          306         @param bytes: The value which will be hashed and signed.
          307 
          308         @rtype: L{bytearray} of unsigned bytes.
          309         @return: A PKCS1-SHA1 signature on the passed-in data.
          310         """
          311         hashBytes = SHA1(bytearray(bytes))
          312         prefixedHashBytes = self._addPKCS1SHA1Prefix(hashBytes)
          313         sigBytes = self.sign(prefixedHashBytes)
          314         return sigBytes
          315 
          316     def hashAndVerify(self, sigBytes, bytes):
          317         """Hash and verify the passed-in bytes with the signature.
          318 
          319         This verifies a PKCS1-SHA1 signature on the passed-in data.
          320 
          321         @type sigBytes: L{bytearray} of unsigned bytes
          322         @param sigBytes: A PKCS1-SHA1 signature.
          323 
          324         @type bytes: str or L{bytearray} of unsigned bytes
          325         @param bytes: The value which will be hashed and verified.
          326 
          327         @rtype: bool
          328         @return: Whether the signature matches the passed-in data.
          329         """
          330         hashBytes = SHA1(bytearray(bytes))
          331 
          332         # Try it with/without the embedded NULL
          333         prefixedHashBytes1 = self._addPKCS1SHA1Prefix(hashBytes, False)
          334         prefixedHashBytes2 = self._addPKCS1SHA1Prefix(hashBytes, True)
          335         result1 = self.verify(sigBytes, prefixedHashBytes1)
          336         result2 = self.verify(sigBytes, prefixedHashBytes2)
          337         return (result1 or result2)
          338 
          339     def sign(self, bytes):
          340         """Sign the passed-in bytes.
          341 
          342         This requires the key to have a private component.  It performs
          343         a PKCS1 signature on the passed-in data.
          344 
          345         @type bytes: L{bytearray} of unsigned bytes
          346         @param bytes: The value which will be signed.
          347 
          348         @rtype: L{bytearray} of unsigned bytes.
          349         @return: A PKCS1 signature on the passed-in data.
          350         """
          351         if not self.hasPrivateKey():
          352             raise AssertionError()
          353         paddedBytes = self._addPKCS1Padding(bytes, 1)
          354         m = bytesToNumber(paddedBytes)
          355         if m >= self.n:
          356             raise ValueError()
          357         c = self._rawPrivateKeyOp(m)
          358         sigBytes = numberToByteArray(c, numBytes(self.n))
          359         return sigBytes
          360 
          361     def verify(self, sigBytes, bytes):
          362         """Verify the passed-in bytes with the signature.
          363 
          364         This verifies a PKCS1 signature on the passed-in data.
          365 
          366         @type sigBytes: L{bytearray} of unsigned bytes
          367         @param sigBytes: A PKCS1 signature.
          368 
          369         @type bytes: L{bytearray} of unsigned bytes
          370         @param bytes: The value which will be verified.
          371 
          372         @rtype: bool
          373         @return: Whether the signature matches the passed-in data.
          374         """
          375         if len(sigBytes) != numBytes(self.n):
          376             return False
          377         paddedBytes = self._addPKCS1Padding(bytes, 1)
          378         c = bytesToNumber(sigBytes)
          379         if c >= self.n:
          380             return False
          381         m = self._rawPublicKeyOp(c)
          382         checkBytes = numberToByteArray(m, numBytes(self.n))
          383         return checkBytes == paddedBytes
          384 
          385     def encrypt(self, bytes):
          386         """Encrypt the passed-in bytes.
          387 
          388         This performs PKCS1 encryption of the passed-in data.
          389 
          390         @type bytes: L{bytearray} of unsigned bytes
          391         @param bytes: The value which will be encrypted.
          392 
          393         @rtype: L{bytearray} of unsigned bytes.
          394         @return: A PKCS1 encryption of the passed-in data.
          395         """
          396         paddedBytes = self._addPKCS1Padding(bytes, 2)
          397         m = bytesToNumber(paddedBytes)
          398         if m >= self.n:
          399             raise ValueError()
          400         c = self._rawPublicKeyOp(m)
          401         encBytes = numberToByteArray(c, numBytes(self.n))
          402         return encBytes
          403 
          404     def decrypt(self, encBytes):
          405         """Decrypt the passed-in bytes.
          406 
          407         This requires the key to have a private component.  It performs
          408         PKCS1 decryption of the passed-in data.
          409 
          410         @type encBytes: L{bytearray} of unsigned bytes
          411         @param encBytes: The value which will be decrypted.
          412 
          413         @rtype: L{bytearray} of unsigned bytes or None.
          414         @return: A PKCS1 decryption of the passed-in data or None if
          415         the data is not properly formatted.
          416         """
          417         if not self.hasPrivateKey():
          418             raise AssertionError()
          419         if len(encBytes) != numBytes(self.n):
          420             return None
          421         c = bytesToNumber(encBytes)
          422         if c >= self.n:
          423             return None
          424         m = self._rawPrivateKeyOp(c)
          425         decBytes = numberToByteArray(m, numBytes(self.n))
          426         #Check first two bytes
          427         if decBytes[0] != 0 or decBytes[1] != 2:
          428             return None
          429         #Scan through for zero separator
          430         for x in range(1, len(decBytes)-1):
          431             if decBytes[x]== 0:
          432                 break
          433         else:
          434             return None
          435         return decBytes[x+1:] #Return everything after the separator
          436 
          437 
          438 
          439 
          440     # **************************************************************************
          441     # Helper Functions for RSA Keys
          442     # **************************************************************************
          443 
          444     def _addPKCS1SHA1Prefix(self, bytes, withNULL=True):
          445         # There is a long history of confusion over whether the SHA1
          446         # algorithmIdentifier should be encoded with a NULL parameter or
          447         # with the parameter omitted.  While the original intention was
          448         # apparently to omit it, many toolkits went the other way.  TLS 1.2
          449         # specifies the NULL should be included, and this behavior is also
          450         # mandated in recent versions of PKCS #1, and is what tlslite has
          451         # always implemented.  Anyways, verification code should probably
          452         # accept both.  However, nothing uses this code yet, so this is
          453         # all fairly moot.
          454         if not withNULL:
          455             prefixBytes = bytearray(\
          456             [0x30,0x1f,0x30,0x07,0x06,0x05,0x2b,0x0e,0x03,0x02,0x1a,0x04,0x14])
          457         else:
          458             prefixBytes = bytearray(\
          459             [0x30,0x21,0x30,0x09,0x06,0x05,0x2b,0x0e,0x03,0x02,0x1a,0x05,0x00,0x04,0x14])
          460         prefixedBytes = prefixBytes + bytes
          461         return prefixedBytes
          462 
          463     def _addPKCS1Padding(self, bytes, blockType):
          464         padLength = (numBytes(self.n) - (len(bytes)+3))
          465         if blockType == 1: #Signature padding
          466             pad = [0xFF] * padLength
          467         elif blockType == 2: #Encryption padding
          468             pad = bytearray(0)
          469             while len(pad) < padLength:
          470                 padBytes = getRandomBytes(padLength * 2)
          471                 pad = [b for b in padBytes if b != 0]
          472                 pad = pad[:padLength]
          473         else:
          474             raise AssertionError()
          475 
          476         padding = bytearray([0,blockType] + pad + [0])
          477         paddedBytes = padding + bytes
          478         return paddedBytes
          479 
          480 
          481 
          482 
          483     def _rawPrivateKeyOp(self, m):
          484         #Create blinding values, on the first pass:
          485         if not self.blinder:
          486             self.unblinder = getRandomNumber(2, self.n)
          487             self.blinder = powMod(invMod(self.unblinder, self.n), self.e,
          488                                   self.n)
          489 
          490         #Blind the input
          491         m = (m * self.blinder) % self.n
          492 
          493         #Perform the RSA operation
          494         c = self._rawPrivateKeyOpHelper(m)
          495 
          496         #Unblind the output
          497         c = (c * self.unblinder) % self.n
          498 
          499         #Update blinding values
          500         self.blinder = (self.blinder * self.blinder) % self.n
          501         self.unblinder = (self.unblinder * self.unblinder) % self.n
          502 
          503         #Return the output
          504         return c
          505 
          506 
          507     def _rawPrivateKeyOpHelper(self, m):
          508         #Non-CRT version
          509         #c = powMod(m, self.d, self.n)
          510 
          511         #CRT version  (~3x faster)
          512         s1 = powMod(m, self.dP, self.p)
          513         s2 = powMod(m, self.dQ, self.q)
          514         h = ((s1 - s2) * self.qInv) % self.p
          515         c = s2 + self.q * h
          516         return c
          517 
          518     def _rawPublicKeyOp(self, c):
          519         m = powMod(c, self.e, self.n)
          520         return m
          521 
          522     def acceptsPassword(self):
          523         return False
          524 
          525     def generate(bits):
          526         key = RSAKey()
          527         p = getRandomPrime(bits//2, False)
          528         q = getRandomPrime(bits//2, False)
          529         t = lcm(p-1, q-1)
          530         key.n = p * q
          531         key.e = 65537
          532         key.d = invMod(key.e, t)
          533         key.p = p
          534         key.q = q
          535         key.dP = key.d % (p-1)
          536         key.dQ = key.d % (q-1)
          537         key.qInv = invMod(q, p)
          538         return key
          539     generate = staticmethod(generate)