trsakey.py - electrum - Electrum Bitcoin wallet
HTML git clone https://git.parazyd.org/electrum
DIR Log
DIR Files
DIR Refs
DIR Submodules
---
trsakey.py (16814B)
---
1 #!/usr/bin/env python
2 #
3 # Electrum - lightweight Bitcoin client
4 # Copyright (C) 2015 Thomas Voegtlin
5 #
6 # Permission is hereby granted, free of charge, to any person
7 # obtaining a copy of this software and associated documentation files
8 # (the "Software"), to deal in the Software without restriction,
9 # including without limitation the rights to use, copy, modify, merge,
10 # publish, distribute, sublicense, and/or sell copies of the Software,
11 # and to permit persons to whom the Software is furnished to do so,
12 # subject to the following conditions:
13 #
14 # The above copyright notice and this permission notice shall be
15 # included in all copies or substantial portions of the Software.
16 #
17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 # SOFTWARE.
25
26 # This module uses functions from TLSLite (public domain)
27 #
28 # TLSLite Authors:
29 # Trevor Perrin
30 # Martin von Loewis - python 3 port
31 # Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
32 #
33
34 """Pure-Python RSA implementation."""
35
36 import os
37 import math
38 import hashlib
39
40
def SHA1(x):
    """Return the raw 20-byte SHA-1 digest of the bytes-like object x."""
    hasher = hashlib.sha1()
    hasher.update(x)
    return hasher.digest()
43
44
45 # **************************************************************************
46 # PRNG Functions
47 # **************************************************************************
48
# Check that os.urandom works: 1000 random bytes should be essentially
# incompressible, so the zlib-compressed size must stay above 900 bytes.
import zlib
length = len(zlib.compress(os.urandom(1000)))
if length <= 900:
    # Explicit raise (rather than `assert`) so the check still runs under
    # `python -O`, which strips assert statements.
    raise AssertionError("os.urandom output looks non-random")
53
def getRandomBytes(howMany):
    """Return howMany bytes from os.urandom as a mutable bytearray."""
    result = bytearray(os.urandom(howMany))
    # Sanity check only: os.urandom is expected to honour the length.
    assert len(result) == howMany
    return result


# Human-readable name of the entropy source behind getRandomBytes.
prngName = "os.urandom"
60
61
62 # **************************************************************************
63 # Converter Functions
64 # **************************************************************************
65
def bytesToNumber(b):
    """Interpret b as an unsigned big-endian integer.

    @type b: bytes, bytearray, or another iterable of ints in range(256)
    @param b: The big-endian encoding to decode.

    @rtype: int
    @return: The decoded value; 0 for empty input.
    """
    # int.from_bytes is linear-time; the previous manual loop multiplied an
    # ever-growing bignum multiplier, which is quadratic in len(b).
    return int.from_bytes(b, "big")
74
def numberToByteArray(n, howManyBytes=None):
    """Convert an integer into a fixed-length big-endian bytearray.

    The result is always exactly howManyBytes long. Smaller values are
    zero-padded on the left. If n does not fit, only its low-order
    howManyBytes bytes are kept; the high-order bytes are silently dropped.
    When howManyBytes is omitted, numBytes(n) is used, giving the minimal
    encoding (empty for n == 0).

    @type n: int
    @type howManyBytes: int or None
    @rtype: L{bytearray}
    """
    if howManyBytes is None:
        howManyBytes = numBytes(n)
    # Masking to the low 8*howManyBytes bits reproduces the previous
    # byte-at-a-time loop's truncation (and its two's-complement output for
    # negative n) in a single linear-time conversion.
    mask = (1 << (8 * howManyBytes)) - 1
    return bytearray((n & mask).to_bytes(howManyBytes, "big"))
89
def mpiToNumber(mpi):
    """Decode an OpenSSL-format MPI into a non-negative integer.

    The MPI is a 4-byte big-endian length header followed by the
    big-endian magnitude; the header is skipped, not validated.

    @type mpi: bytes or bytearray
    @rtype: int
    @raise AssertionError: if mpi is shorter than the 4-byte header, or if
        the magnitude's top bit is set (a negative number).
    """
    b = bytearray(mpi)
    if len(b) < 4:
        raise AssertionError()
    # Index a bytearray: on Python 3, bytes[4] is already an int and the
    # old ord(mpi[4]) raised TypeError. A zero-length magnitude (OpenSSL's
    # encoding of 0) has no sign byte to check.
    if len(b) > 4 and (b[4] & 0x80) != 0:
        raise AssertionError()
    return int.from_bytes(b[4:], "big")
95
def numberToMPI(n):
    """Encode an integer as an OpenSSL-style MPI byte string.

    Layout: 4-byte big-endian length, then an optional zero byte, then the
    big-endian magnitude. The zero byte is inserted whenever numBits(n) is
    a multiple of 8 (including n == 0). For n > 0 this keeps the top bit of
    the first magnitude byte from being read back as a sign bit.
    """
    magnitude = numberToByteArray(n)
    extra = 1 if (numBits(n) & 0x7) == 0 else 0
    size = numBytes(n) + extra
    header = bytearray((size >> shift) & 0xFF for shift in (24, 16, 8, 0))
    return bytes(header + bytearray(extra) + magnitude)
110
111
112 # **************************************************************************
113 # Misc. Utility Functions
114 # **************************************************************************
115
def numBits(n):
    """Return the number of bits needed to represent n (0 for n == 0).

    @type n: int
    @rtype: int
    """
    # int.bit_length is exact and avoids formatting n as a hex string and
    # looking up its leading digit.
    return n.bit_length()
126
def numBytes(n):
    """Return the number of bytes needed to hold n (0 for n == 0).

    Equivalent to ceil(numBits(n) / 8), computed with integer arithmetic
    instead of a float division and math.ceil.

    @type n: int
    @rtype: int
    """
    return (n.bit_length() + 7) // 8
132
133 # **************************************************************************
134 # Big Number Math
135 # **************************************************************************
136
def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Rejection sampling: draw numBytes(high) random bytes, trim the leading
    byte to the bit width of high, and retry until the value is in range.

    @raise AssertionError: if low >= high.
    """
    if low >= high:
        raise AssertionError()
    bitWidth = numBits(high)
    byteWidth = numBytes(high)
    spareBits = bitWidth % 8
    while True:
        candidate = getRandomBytes(byteWidth)
        if spareBits:
            # Keep only the low spareBits bits of the most significant byte.
            candidate[0] = candidate[0] % (1 << spareBits)
        value = bytesToNumber(candidate)
        if low <= value < high:
            return value
150
def gcd(a, b):
    """Return the greatest common divisor of a and b.

    Delegates to math.gcd, so the result is always >= 0 (the previous
    hand-rolled loop could return a negative value for negative inputs),
    and gcd(0, 0) == 0.
    """
    return math.gcd(a, b)
156
def lcm(a, b):
    """Return the least common multiple of a and b.

    Returns 0 when either argument is 0, the conventional value; the
    previous version divided by gcd(0, 0) == 0 and raised
    ZeroDivisionError for lcm(0, 0).
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
159
def invMod(a, b):
    """Return the inverse of a modulo b, or 0 if no inverse exists.

    Uses the extended Euclidean algorithm. For positive b a non-zero result
    lies in [1, b).

    @type a: int
    @type b: int
    @rtype: int
    """
    # Reduce a into [0, b) first. For negative a the loop below ends with
    # d == -1, so the old code reported "no inverse" even when one existed.
    if b > 0:
        a %= b
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        q = d // c
        c, d = d - (q * c), c
        uc, ud = ud - (q * uc), uc
    # d now holds gcd(a, b); an inverse exists only when it is 1.
    if d == 1:
        return ud % b
    return 0
172
173
def powMod(base, power, modulus):
    """Compute base ** power mod modulus, allowing a negative power.

    A negative power is computed as the modular inverse (via invMod) of
    base ** -power; invMod yields 0 when no inverse exists.
    """
    if power >= 0:
        return pow(base, power, modulus)
    positive = pow(base, -power, modulus)
    return invMod(positive, modulus)
181
def makeSieve(n):
    """Return the list of primes below n (sieve of Eratosthenes)."""
    table = list(range(n))
    for step in range(2, int(math.sqrt(n)) + 1):
        if not table[step]:
            continue
        # Cross out every multiple of this prime, starting at 2 * step.
        for multiple in range(2 * step, n, step):
            table[multiple] = 0
    return list(filter(None, table[2:]))


# Pre-calculated table of the 168 primes below 1000, used by isPrime for
# trial division.
sieve = makeSieve(1000)
196
def isPrime(n, iterations=5, display=False):
    """Return True if n is (probably) prime.

    n is first trial-divided by the module-level sieve of small primes.
    That settles every n that a sieve entry divides or is >= to. Other n
    get `iterations` rounds of Rabin-Miller (per Ferguson & Schneier): the
    first round uses base 2 (speedup per HAC), later rounds random bases.
    With iterations=0 only trial division is performed. If display is
    true, a "*" is printed before the Rabin-Miller stage.
    """
    # 0, 1 and negatives are not prime. Without this guard the sieve loop
    # below returned True for them, since 2 >= n on its first step.
    if n < 2:
        return False
    # Trial division with the precomputed small primes.
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False
    if display: print("*", end=' ')
    # Write n - 1 = s * 2**t with s odd.
    s, t = n - 1, 0
    while s % 2 == 0:
        s, t = s // 2, t + 1
    for count in range(iterations):
        # Draw a fresh base for every round after the first. Previously a
        # round ending with v == 1 hit `continue` before the redraw, so
        # the next round re-tested the same base.
        a = 2 if count == 0 else getRandomNumber(2, n - 1)
        v = pow(a, s, n)
        if v == 1:
            continue
        i = 0
        while v != n - 1:
            if i == t - 1:
                return False
            v, i = pow(v, 2, n), i + 1
    return True
223
def getRandomPrime(bits, display=False):
    """Return a random prime below 2**bits with its top two bits set.

    Candidates start at or above 1.5 * 2**(bits-1), so the two most
    significant bits are set. When two such primes are used as p and q
    for RSA, n = p * q therefore has its MSB set. Candidates are kept
    congruent to 29 mod 30 (30 = lcm(2, 3, 5)), so they are never
    divisible by 2, 3 or 5. If display is true, progress dots are printed.

    @raise AssertionError: if bits < 10.
    """
    if bits < 10:
        raise AssertionError()
    low = ((2 ** (bits - 1)) * 3) // 2
    high = 2 ** bits - 30

    def _startCandidate():
        # Random point in [low, high), bumped up to the next value that is
        # 29 mod 30.
        r = getRandomNumber(low, high)
        return r + 29 - (r % 30)

    candidate = _startCandidate()
    while True:
        if display: print(".", end=' ')
        candidate += 30
        if candidate >= high:
            # Walked past the range: restart from a new random point.
            candidate = _startCandidate()
        if isPrime(candidate, display=display):
            return candidate
244
#Unused at the moment...
def getRandomSafePrime(bits, display=False):
    """Return a random safe prime p = 2*q + 1, where q is also prime.

    q is drawn from [1.5 * 2**(bits-2), 2**(bits-1) - 30), so the top two
    bits of q, and hence of p, are set. Candidates for q are kept
    congruent to 29 mod 30 (30 = lcm(2, 3, 5)). If display is true,
    progress markers are printed.

    @raise AssertionError: if bits < 10.
    """
    if bits < 10:
        raise AssertionError()
    low = (2 ** (bits - 2)) * 3 // 2
    high = (2 ** (bits - 1)) - 30

    def _startCandidate():
        # Random point in [low, high), bumped up to the next value that is
        # 29 mod 30.
        r = getRandomNumber(low, high)
        return r + 29 - (r % 30)

    q = _startCandidate()
    while True:
        if display: print(".", end=' ')
        q += 30
        if q >= high:
            q = _startCandidate()
        # Idea from Tom Wu's SRP code: trial-divide q first (0 Rabin-Miller
        # rounds) before paying for the full tests on p and q.
        if not isPrime(q, 0, display=display):
            continue
        p = (2 * q) + 1
        if isPrime(p, display=display) and isPrime(q, display=display):
            return p
271
272
class RSAKey(object):
    """Pure-Python RSA key with PKCS #1 v1.5 signing and encryption.

    A public key needs n and e; private-key operations also need d. When
    all CRT parameters (p, q, dP, dQ, qInv) are set they are used to speed
    up private-key operations; otherwise the plain m ** d mod n form is
    used.
    """

    def __init__(self, n=0, e=0, d=0, p=0, q=0, dP=0, dQ=0, qInv=0):
        # The modulus and the public exponent must be given together.
        if (n and not e) or (e and not n):
            raise AssertionError()
        self.n = n
        self.e = e
        self.d = d
        self.p = p
        self.q = q
        self.dP = dP
        self.dQ = dQ
        self.qInv = qInv
        # Blinding pair for private-key operations: created lazily by
        # _rawPrivateKeyOp and updated after every use.
        self.blinder = 0
        self.unblinder = 0

    def __len__(self):
        """Return the length of this key in bits.

        @rtype: int
        """
        return numBits(self.n)

    def hasPrivateKey(self):
        """Return True if the private exponent d is set.

        @rtype: bool
        """
        return self.d != 0

    def hashAndSign(self, bytes):
        """Hash and sign the passed-in bytes.

        This requires the key to have a private component. It performs
        a PKCS1-SHA1 signature on the passed-in data, using the SHA-1
        DigestInfo that includes the NULL parameter.

        @type bytes: bytes or L{bytearray} of unsigned bytes
        @param bytes: The value which will be hashed and signed.

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1-SHA1 signature on the passed-in data.
        """
        hashBytes = SHA1(bytearray(bytes))
        prefixedHashBytes = self._addPKCS1SHA1Prefix(hashBytes)
        sigBytes = self.sign(prefixedHashBytes)
        return sigBytes

    def hashAndVerify(self, sigBytes, bytes):
        """Hash and verify the passed-in bytes with the signature.

        This verifies a PKCS1-SHA1 signature on the passed-in data. Both
        the with-NULL and without-NULL DigestInfo encodings are accepted.

        @type sigBytes: L{bytearray} of unsigned bytes
        @param sigBytes: A PKCS1-SHA1 signature.

        @type bytes: bytes or L{bytearray} of unsigned bytes
        @param bytes: The value which will be hashed and verified.

        @rtype: bool
        @return: Whether the signature matches the passed-in data.
        """
        hashBytes = SHA1(bytearray(bytes))

        # Try it with/without the embedded NULL
        prefixedHashBytes1 = self._addPKCS1SHA1Prefix(hashBytes, False)
        prefixedHashBytes2 = self._addPKCS1SHA1Prefix(hashBytes, True)
        result1 = self.verify(sigBytes, prefixedHashBytes1)
        result2 = self.verify(sigBytes, prefixedHashBytes2)
        return (result1 or result2)

    def sign(self, bytes):
        """Sign the passed-in bytes.

        This requires the key to have a private component. It performs
        a PKCS1 (block type 1) signature on the passed-in data.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be signed.

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1 signature on the passed-in data.

        @raise AssertionError: if the key has no private component.
        @raise ValueError: if the data is too long for this key.
        """
        if not self.hasPrivateKey():
            raise AssertionError()
        paddedBytes = self._addPKCS1Padding(bytes, 1)
        m = bytesToNumber(paddedBytes)
        if m >= self.n:
            raise ValueError()
        c = self._rawPrivateKeyOp(m)
        sigBytes = numberToByteArray(c, numBytes(self.n))
        return sigBytes

    def verify(self, sigBytes, bytes):
        """Verify the passed-in bytes with the signature.

        This verifies a PKCS1 (block type 1) signature on the passed-in
        data.

        @type sigBytes: L{bytearray} of unsigned bytes
        @param sigBytes: A PKCS1 signature.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be verified.

        @rtype: bool
        @return: Whether the signature matches the passed-in data. Data
            too long to be signed with this key never matches.
        """
        if len(sigBytes) != numBytes(self.n):
            return False
        try:
            paddedBytes = self._addPKCS1Padding(bytes, 1)
        except ValueError:
            # Too long to fit a valid PKCS1 block for this key.
            return False
        c = bytesToNumber(sigBytes)
        if c >= self.n:
            return False
        m = self._rawPublicKeyOp(c)
        checkBytes = numberToByteArray(m, numBytes(self.n))
        return checkBytes == paddedBytes

    def encrypt(self, bytes):
        """Encrypt the passed-in bytes.

        This performs PKCS1 (block type 2) encryption of the passed-in
        data.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be encrypted.

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1 encryption of the passed-in data.

        @raise ValueError: if the data is too long for this key.
        """
        paddedBytes = self._addPKCS1Padding(bytes, 2)
        m = bytesToNumber(paddedBytes)
        if m >= self.n:
            raise ValueError()
        c = self._rawPublicKeyOp(m)
        encBytes = numberToByteArray(c, numBytes(self.n))
        return encBytes

    def decrypt(self, encBytes):
        """Decrypt the passed-in bytes.

        This requires the key to have a private component. It performs
        PKCS1 decryption of the passed-in data.

        @type encBytes: L{bytearray} of unsigned bytes
        @param encBytes: The value which will be decrypted.

        @rtype: L{bytearray} of unsigned bytes or None.
        @return: A PKCS1 decryption of the passed-in data or None if
            the data is not properly formatted.

        @raise AssertionError: if the key has no private component.
        """
        if not self.hasPrivateKey():
            raise AssertionError()
        if len(encBytes) != numBytes(self.n):
            return None
        c = bytesToNumber(encBytes)
        if c >= self.n:
            return None
        m = self._rawPrivateKeyOp(c)
        decBytes = numberToByteArray(m, numBytes(self.n))
        # Block must start with 0x00 0x02 (encryption block type).
        if decBytes[0] != 0 or decBytes[1] != 2:
            return None
        # Scan for the zero byte separating the padding from the message.
        for x in range(1, len(decBytes) - 1):
            if decBytes[x] == 0:
                break
        else:
            return None
        return decBytes[x + 1:]  # Everything after the separator

    # **************************************************************************
    # Helper Functions for RSA Keys
    # **************************************************************************

    def _addPKCS1SHA1Prefix(self, bytes, withNULL=True):
        """Prepend the DER SHA-1 DigestInfo header to a 20-byte hash.

        There is a long history of confusion over whether the SHA1
        algorithmIdentifier should be encoded with a NULL parameter or with
        the parameter omitted. TLS 1.2 and recent versions of PKCS #1
        specify that the NULL is included (the default here), but some
        toolkits omit it. hashAndVerify therefore tries both forms.
        """
        if not withNULL:
            prefixBytes = bytearray(
                [0x30, 0x1f, 0x30, 0x07, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02,
                 0x1a, 0x04, 0x14])
        else:
            prefixBytes = bytearray(
                [0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02,
                 0x1a, 0x05, 0x00, 0x04, 0x14])
        prefixedBytes = prefixBytes + bytes
        return prefixedBytes

    def _addPKCS1Padding(self, bytes, blockType):
        """Return 0x00 || blockType || PS || 0x00 || bytes, sized to n.

        PS is 0xFF bytes for block type 1 (signatures) and random non-zero
        bytes for block type 2 (encryption).

        @raise AssertionError: if blockType is not 1 or 2.
        @raise ValueError: if bytes is too long to leave the minimum of
            8 padding bytes required by PKCS #1 v1.5.
        """
        if blockType not in (1, 2):
            raise AssertionError()
        padLength = numBytes(self.n) - (len(bytes) + 3)
        # Previously an over-long message gave a short or empty PS, and for
        # block type 2 then crashed with a list + bytearray TypeError.
        if padLength < 8:
            raise ValueError("message too long for PKCS#1 v1.5 padding")
        if blockType == 1:  # Signature padding
            pad = bytearray([0xFF] * padLength)
        else:  # Encryption padding
            pad = bytearray()
            while len(pad) < padLength:
                # Draw twice what is needed and drop zero bytes (PS must be
                # non-zero); redraw in the unlikely case too few remain.
                padBytes = getRandomBytes(padLength * 2)
                pad = bytearray(b for b in padBytes if b != 0)
            pad = pad[:padLength]
        padding = bytearray([0, blockType]) + pad + bytearray([0])
        paddedBytes = padding + bytes
        return paddedBytes

    def _rawPrivateKeyOp(self, m):
        """Return m ** d mod n, computed with RSA blinding."""
        # Create blinding values on first use: unblinder is a random r and
        # blinder is r**-e mod n, so (m * r**-e)**d * r == m**d (mod n).
        # Loop because a non-invertible r gives invMod() == 0, hence a zero
        # blinder, which would corrupt the result.
        while not self.blinder:
            self.unblinder = getRandomNumber(2, self.n)
            self.blinder = powMod(invMod(self.unblinder, self.n), self.e,
                                  self.n)

        # Blind the input
        m = (m * self.blinder) % self.n

        # Perform the RSA operation
        c = self._rawPrivateKeyOpHelper(m)

        # Unblind the output
        c = (c * self.unblinder) % self.n

        # Update blinding values; squaring both keeps blinder equal to
        # unblinder**-e mod n.
        self.blinder = (self.blinder * self.blinder) % self.n
        self.unblinder = (self.unblinder * self.unblinder) % self.n

        return c

    def _rawPrivateKeyOpHelper(self, m):
        """Return m ** d mod n, via CRT when its parameters are available."""
        if self.p and self.q and self.dP and self.dQ and self.qInv:
            # CRT version (~3x faster)
            s1 = pow(m, self.dP, self.p)
            s2 = pow(m, self.dQ, self.q)
            h = ((s1 - s2) * self.qInv) % self.p
            return s2 + self.q * h
        # Keys built from (n, e, d) alone have no CRT parameters; the old
        # code called pow(m, 0, 0) here and raised ValueError.
        return pow(m, self.d, self.n)

    def _rawPublicKeyOp(self, c):
        """Return c ** e mod n."""
        return pow(c, self.e, self.n)

    def acceptsPassword(self):
        """Return False: this key type is not password-protected."""
        return False

    @staticmethod
    def generate(bits):
        """Generate a new RSA key whose primes are bits // 2 bits each.

        Uses e = 65537. Prime pairs are redrawn when p == q, or when e has
        no inverse modulo lcm(p-1, q-1) (invMod returns 0). Either case
        would yield an unusable key; the old code could return d == 0.

        @rtype: L{RSAKey}
        """
        e = 65537
        while True:
            p = getRandomPrime(bits // 2, False)
            q = getRandomPrime(bits // 2, False)
            t = lcm(p - 1, q - 1)
            d = invMod(e, t)
            if p != q and d:
                break
        key = RSAKey()
        key.n = p * q
        key.e = e
        key.d = d
        key.p = p
        key.q = q
        key.dP = d % (p - 1)
        key.dQ = d % (q - 1)
        key.qInv = invMod(q, p)
        return key