URI: 
       tlnaddr.py - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
       tlnaddr.py (19074B)
       ---
            1 #! /usr/bin/env python3
            2 # This was forked from https://github.com/rustyrussell/lightning-payencode/tree/acc16ec13a3fa1dc16c07af6ec67c261bd8aff23
            3 
            4 import re
            5 import time
            6 from hashlib import sha256
            7 from binascii import hexlify
            8 from decimal import Decimal
            9 from typing import Optional, TYPE_CHECKING
           10 
           11 import random
           12 import bitstring
           13 
           14 from .bitcoin import hash160_to_b58_address, b58_address_to_hash160, TOTAL_COIN_SUPPLY_LIMIT_IN_BTC
           15 from .segwit_addr import bech32_encode, bech32_decode, CHARSET
           16 from . import constants
           17 from . import ecc
           18 from .bitcoin import COIN
           19 
           20 if TYPE_CHECKING:
           21     from .lnutil import LnFeatures
           22 
           23 
           24 # BOLT #11:
           25 #
           26 # A writer MUST encode `amount` as a positive decimal integer with no
           27 # leading zeroes, SHOULD use the shortest representation possible.
           28 def shorten_amount(amount):
           29     """ Given an amount in bitcoin, shorten it
           30     """
           31     # Convert to pico initially
           32     amount = int(amount * 10**12)
           33     units = ['p', 'n', 'u', 'm', '']
           34     for unit in units:
           35         if amount % 1000 == 0:
           36             amount //= 1000
           37         else:
           38             break
           39     return str(amount) + unit
           40 
           41 def unshorten_amount(amount) -> Decimal:
           42     """ Given a shortened amount, convert it into a decimal
           43     """
           44     # BOLT #11:
           45     # The following `multiplier` letters are defined:
           46     #
           47     #* `m` (milli): multiply by 0.001
           48     #* `u` (micro): multiply by 0.000001
           49     #* `n` (nano): multiply by 0.000000001
           50     #* `p` (pico): multiply by 0.000000000001
           51     units = {
           52         'p': 10**12,
           53         'n': 10**9,
           54         'u': 10**6,
           55         'm': 10**3,
           56     }
           57     unit = str(amount)[-1]
           58     # BOLT #11:
           59     # A reader SHOULD fail if `amount` contains a non-digit, or is followed by
           60     # anything except a `multiplier` in the table above.
           61     if not re.fullmatch("\\d+[pnum]?", str(amount)):
           62         raise ValueError("Invalid amount '{}'".format(amount))
           63 
           64     if unit in units.keys():
           65         return Decimal(amount[:-1]) / units[unit]
           66     else:
           67         return Decimal(amount)
           68 
           69 _INT_TO_BINSTR = {a: '0' * (5-len(bin(a)[2:])) + bin(a)[2:] for a in range(32)}
           70 
           71 # Bech32 spits out array of 5-bit values.  Shim here.
           72 def u5_to_bitarray(arr):
           73     b = ''.join(_INT_TO_BINSTR[a] for a in arr)
           74     return bitstring.BitArray(bin=b)
           75 
           76 def bitarray_to_u5(barr):
           77     assert barr.len % 5 == 0
           78     ret = []
           79     s = bitstring.ConstBitStream(barr)
           80     while s.pos != s.len:
           81         ret.append(s.read(5).uint)
           82     return ret
           83 
           84 def encode_fallback(fallback, currency):
           85     """ Encode all supported fallback addresses.
           86     """
           87     if currency in [constants.BitcoinMainnet.SEGWIT_HRP, constants.BitcoinTestnet.SEGWIT_HRP]:
           88         fbhrp, witness = bech32_decode(fallback, ignore_long_length=True)
           89         if fbhrp:
           90             if fbhrp != currency:
           91                 raise ValueError("Not a bech32 address for this currency")
           92             wver = witness[0]
           93             if wver > 16:
           94                 raise ValueError("Invalid witness version {}".format(witness[0]))
           95             wprog = u5_to_bitarray(witness[1:])
           96         else:
           97             addrtype, addr = b58_address_to_hash160(fallback)
           98             if is_p2pkh(currency, addrtype):
           99                 wver = 17
          100             elif is_p2sh(currency, addrtype):
          101                 wver = 18
          102             else:
          103                 raise ValueError("Unknown address type for {}".format(currency))
          104             wprog = addr
          105         return tagged('f', bitstring.pack("uint:5", wver) + wprog)
          106     else:
          107         raise NotImplementedError("Support for currency {} not implemented".format(currency))
          108 
          109 def parse_fallback(fallback, currency):
          110     if currency in [constants.BitcoinMainnet.SEGWIT_HRP, constants.BitcoinTestnet.SEGWIT_HRP]:
          111         wver = fallback[0:5].uint
          112         if wver == 17:
          113             addr=hash160_to_b58_address(fallback[5:].tobytes(), base58_prefix_map[currency][0])
          114         elif wver == 18:
          115             addr=hash160_to_b58_address(fallback[5:].tobytes(), base58_prefix_map[currency][1])
          116         elif wver <= 16:
          117             addr=bech32_encode(currency, bitarray_to_u5(fallback))
          118         else:
          119             return None
          120     else:
          121         addr=fallback.tobytes()
          122     return addr
          123 
          124 
          125 # Map of classical and witness address prefixes
          126 base58_prefix_map = {
          127     constants.BitcoinMainnet.SEGWIT_HRP : (constants.BitcoinMainnet.ADDRTYPE_P2PKH, constants.BitcoinMainnet.ADDRTYPE_P2SH),
          128     constants.BitcoinTestnet.SEGWIT_HRP : (constants.BitcoinTestnet.ADDRTYPE_P2PKH, constants.BitcoinTestnet.ADDRTYPE_P2SH)
          129 }
          130 
          131 def is_p2pkh(currency, prefix):
          132     return prefix == base58_prefix_map[currency][0]
          133 
          134 def is_p2sh(currency, prefix):
          135     return prefix == base58_prefix_map[currency][1]
          136 
          137 # Tagged field containing BitArray
          138 def tagged(char, l):
          139     # Tagged fields need to be zero-padded to 5 bits.
          140     while l.len % 5 != 0:
          141         l.append('0b0')
          142     return bitstring.pack("uint:5, uint:5, uint:5",
          143                           CHARSET.find(char),
          144                           (l.len / 5) / 32, (l.len / 5) % 32) + l
          145 
          146 # Tagged field containing bytes
          147 def tagged_bytes(char, l):
          148     return tagged(char, bitstring.BitArray(l))
          149 
          150 def trim_to_min_length(bits):
          151     """Ensures 'bits' have min number of leading zeroes.
          152     Assumes 'bits' is big-endian, and that it needs to be encoded in 5 bit blocks.
          153     """
          154     bits = bits[:]  # copy
          155     # make sure we can be split into 5 bit blocks
          156     while bits.len % 5 != 0:
          157         bits.prepend('0b0')
          158     # Get minimal length by trimming leading 5 bits at a time.
          159     while bits.startswith('0b00000'):
          160         if len(bits) == 5:
          161             break  # v == 0
          162         bits = bits[5:]
          163     return bits
          164 
          165 # Discard trailing bits, convert to bytes.
          166 def trim_to_bytes(barr):
          167     # Adds a byte if necessary.
          168     b = barr.tobytes()
          169     if barr.len % 8 != 0:
          170         return b[:-1]
          171     return b
          172 
          173 # Try to pull out tagged data: returns tag, tagged data and remainder.
          174 def pull_tagged(stream):
          175     tag = stream.read(5).uint
          176     length = stream.read(5).uint * 32 + stream.read(5).uint
          177     return (CHARSET[tag], stream.read(length * 5), stream)
          178 
          179 def lnencode(addr: 'LnAddr', privkey) -> str:
          180     if addr.amount:
          181         amount = addr.currency + shorten_amount(addr.amount)
          182     else:
          183         amount = addr.currency if addr.currency else ''
          184 
          185     hrp = 'ln' + amount
          186 
          187     # Start with the timestamp
          188     data = bitstring.pack('uint:35', addr.date)
          189 
          190     tags_set = set()
          191 
          192     # Payment hash
          193     data += tagged_bytes('p', addr.paymenthash)
          194     tags_set.add('p')
          195 
          196     if addr.payment_secret is not None:
          197         data += tagged_bytes('s', addr.payment_secret)
          198         tags_set.add('s')
          199 
          200     for k, v in addr.tags:
          201 
          202         # BOLT #11:
          203         #
          204         # A writer MUST NOT include more than one `d`, `h`, `n` or `x` fields,
          205         if k in ('d', 'h', 'n', 'x', 'p', 's'):
          206             if k in tags_set:
          207                 raise ValueError("Duplicate '{}' tag".format(k))
          208 
          209         if k == 'r':
          210             route = bitstring.BitArray()
          211             for step in v:
          212                 pubkey, channel, feebase, feerate, cltv = step
          213                 route.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
          214             data += tagged('r', route)
          215         elif k == 't':
          216             pubkey, feebase, feerate, cltv = v
          217             route = bitstring.BitArray(pubkey) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv)
          218             data += tagged('t', route)
          219         elif k == 'f':
          220             data += encode_fallback(v, addr.currency)
          221         elif k == 'd':
          222             # truncate to max length: 1024*5 bits = 639 bytes
          223             data += tagged_bytes('d', v.encode()[0:639])
          224         elif k == 'x':
          225             expirybits = bitstring.pack('intbe:64', v)
          226             expirybits = trim_to_min_length(expirybits)
          227             data += tagged('x', expirybits)
          228         elif k == 'h':
          229             data += tagged_bytes('h', sha256(v.encode('utf-8')).digest())
          230         elif k == 'n':
          231             data += tagged_bytes('n', v)
          232         elif k == 'c':
          233             finalcltvbits = bitstring.pack('intbe:64', v)
          234             finalcltvbits = trim_to_min_length(finalcltvbits)
          235             data += tagged('c', finalcltvbits)
          236         elif k == '9':
          237             if v == 0:
          238                 continue
          239             feature_bits = bitstring.BitArray(uint=v, length=v.bit_length())
          240             feature_bits = trim_to_min_length(feature_bits)
          241             data += tagged('9', feature_bits)
          242         else:
          243             # FIXME: Support unknown tags?
          244             raise ValueError("Unknown tag {}".format(k))
          245 
          246         tags_set.add(k)
          247 
          248     # BOLT #11:
          249     #
          250     # A writer MUST include either a `d` or `h` field, and MUST NOT include
          251     # both.
          252     if 'd' in tags_set and 'h' in tags_set:
          253         raise ValueError("Cannot include both 'd' and 'h'")
          254     if not 'd' in tags_set and not 'h' in tags_set:
          255         raise ValueError("Must include either 'd' or 'h'")
          256 
          257     # We actually sign the hrp, then data (padded to 8 bits with zeroes).
          258     msg = hrp.encode("ascii") + data.tobytes()
          259     privkey = ecc.ECPrivkey(privkey)
          260     sig = privkey.sign_message(msg, is_compressed=False, algo=lambda x:sha256(x).digest())
          261     recovery_flag = bytes([sig[0] - 27])
          262     sig = bytes(sig[1:]) + recovery_flag
          263     data += sig
          264 
          265     return bech32_encode(hrp, bitarray_to_u5(data))
          266 
          267 class LnAddr(object):
          268     def __init__(self, *, paymenthash: bytes = None, amount=None, currency=None, tags=None, date=None,
          269                  payment_secret: bytes = None):
          270         self.date = int(time.time()) if not date else int(date)
          271         self.tags = [] if not tags else tags
          272         self.unknown_tags = []
          273         self.paymenthash = paymenthash
          274         self.payment_secret = payment_secret
          275         self.signature = None
          276         self.pubkey = None
          277         self.currency = constants.net.SEGWIT_HRP if currency is None else currency
          278         self._amount = amount  # type: Optional[Decimal]  # in bitcoins
          279         self._min_final_cltv_expiry = 18
          280 
          281     @property
          282     def amount(self) -> Optional[Decimal]:
          283         return self._amount
          284 
          285     @amount.setter
          286     def amount(self, value):
          287         if not (isinstance(value, Decimal) or value is None):
          288             raise ValueError(f"amount must be Decimal or None, not {value!r}")
          289         if value is None:
          290             self._amount = None
          291             return
          292         assert isinstance(value, Decimal)
          293         if value.is_nan() or not (0 <= value <= TOTAL_COIN_SUPPLY_LIMIT_IN_BTC):
          294             raise ValueError(f"amount is out-of-bounds: {value!r} BTC")
          295         if value * 10**12 % 10:
          296             # max resolution is millisatoshi
          297             raise ValueError(f"Cannot encode {value!r}: too many decimal places")
          298         self._amount = value
          299 
          300     def get_amount_sat(self) -> Optional[Decimal]:
          301         # note that this has msat resolution potentially
          302         if self.amount is None:
          303             return None
          304         return self.amount * COIN
          305 
          306     def get_routing_info(self, tag):
          307         # note: tag will be 't' for trampoline
          308         r_tags = list(filter(lambda x: x[0] == tag, self.tags))
          309         # strip the tag type, it's implicitly 'r' now
          310         r_tags = list(map(lambda x: x[1], r_tags))
          311         # if there are multiple hints, we will use the first one that works,
          312         # from a random permutation
          313         random.shuffle(r_tags)
          314         return r_tags
          315 
          316     def get_amount_msat(self) -> Optional[int]:
          317         if self.amount is None:
          318             return None
          319         return int(self.amount * COIN * 1000)
          320 
          321     def get_features(self) -> 'LnFeatures':
          322         from .lnutil import LnFeatures
          323         return LnFeatures(self.get_tag('9') or 0)
          324 
          325     def __str__(self):
          326         return "LnAddr[{}, amount={}{} tags=[{}]]".format(
          327             hexlify(self.pubkey.serialize()).decode('utf-8') if self.pubkey else None,
          328             self.amount, self.currency,
          329             ", ".join([k + '=' + str(v) for k, v in self.tags])
          330         )
          331 
          332     def get_min_final_cltv_expiry(self) -> int:
          333         return self._min_final_cltv_expiry
          334 
          335     def get_tag(self, tag):
          336         for k, v in self.tags:
          337             if k == tag:
          338                 return v
          339         return None
          340 
          341     def get_description(self) -> str:
          342         return self.get_tag('d') or ''
          343 
          344     def get_expiry(self) -> int:
          345         exp = self.get_tag('x')
          346         if exp is None:
          347             exp = 3600
          348         return int(exp)
          349 
          350     def is_expired(self) -> bool:
          351         now = time.time()
          352         # BOLT-11 does not specify what expiration of '0' means.
          353         # we treat it as 0 seconds here (instead of never)
          354         return now > self.get_expiry() + self.date
          355 
          356 
          357 class LnDecodeException(Exception): pass
          358 
          359 class SerializableKey:
          360     def __init__(self, pubkey):
          361         self.pubkey = pubkey
          362     def serialize(self):
          363         return self.pubkey.get_public_key_bytes(True)
          364 
          365 def lndecode(invoice: str, *, verbose=False, expected_hrp=None) -> LnAddr:
          366     if expected_hrp is None:
          367         expected_hrp = constants.net.SEGWIT_HRP
          368     hrp, data = bech32_decode(invoice, ignore_long_length=True)
          369     if not hrp:
          370         raise ValueError("Bad bech32 checksum")
          371 
          372     # BOLT #11:
          373     #
          374     # A reader MUST fail if it does not understand the `prefix`.
          375     if not hrp.startswith('ln'):
          376         raise ValueError("Does not start with ln")
          377 
          378     if not hrp[2:].startswith(expected_hrp):
          379         raise ValueError("Wrong Lightning invoice HRP " + hrp[2:] + ", should be " + expected_hrp)
          380 
          381     data = u5_to_bitarray(data)
          382 
          383     # Final signature 65 bytes, split it off.
          384     if len(data) < 65*8:
          385         raise ValueError("Too short to contain signature")
          386     sigdecoded = data[-65*8:].tobytes()
          387     data = bitstring.ConstBitStream(data[:-65*8])
          388 
          389     addr = LnAddr()
          390     addr.pubkey = None
          391 
          392     m = re.search("[^\\d]+", hrp[2:])
          393     if m:
          394         addr.currency = m.group(0)
          395         amountstr = hrp[2+m.end():]
          396         # BOLT #11:
          397         #
          398         # A reader SHOULD indicate if amount is unspecified, otherwise it MUST
          399         # multiply `amount` by the `multiplier` value (if any) to derive the
          400         # amount required for payment.
          401         if amountstr != '':
          402             addr.amount = unshorten_amount(amountstr)
          403 
          404     addr.date = data.read(35).uint
          405 
          406     while data.pos != data.len:
          407         tag, tagdata, data = pull_tagged(data)
          408 
          409         # BOLT #11:
          410         #
          411         # A reader MUST skip over unknown fields, an `f` field with unknown
          412         # `version`, or a `p`, `h`, or `n` field which does not have
          413         # `data_length` 52, 52, or 53 respectively.
          414         data_length = len(tagdata) / 5
          415 
          416         if tag == 'r':
          417             # BOLT #11:
          418             #
          419             # * `r` (3): `data_length` variable.  One or more entries
          420             # containing extra routing information for a private route;
          421             # there may be more than one `r` field, too.
          422             #    * `pubkey` (264 bits)
          423             #    * `short_channel_id` (64 bits)
          424             #    * `feebase` (32 bits, big-endian)
          425             #    * `feerate` (32 bits, big-endian)
          426             #    * `cltv_expiry_delta` (16 bits, big-endian)
          427             route=[]
          428             s = bitstring.ConstBitStream(tagdata)
          429             while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
          430                 route.append((s.read(264).tobytes(),
          431                               s.read(64).tobytes(),
          432                               s.read(32).uintbe,
          433                               s.read(32).uintbe,
          434                               s.read(16).uintbe))
          435             addr.tags.append(('r',route))
          436         elif tag == 't':
          437             s = bitstring.ConstBitStream(tagdata)
          438             e = (s.read(264).tobytes(),
          439                  s.read(32).uintbe,
          440                  s.read(32).uintbe,
          441                  s.read(16).uintbe)
          442             addr.tags.append(('t', e))
          443         elif tag == 'f':
          444             fallback = parse_fallback(tagdata, addr.currency)
          445             if fallback:
          446                 addr.tags.append(('f', fallback))
          447             else:
          448                 # Incorrect version.
          449                 addr.unknown_tags.append((tag, tagdata))
          450                 continue
          451 
          452         elif tag == 'd':
          453             addr.tags.append(('d', trim_to_bytes(tagdata).decode('utf-8')))
          454 
          455         elif tag == 'h':
          456             if data_length != 52:
          457                 addr.unknown_tags.append((tag, tagdata))
          458                 continue
          459             addr.tags.append(('h', trim_to_bytes(tagdata)))
          460 
          461         elif tag == 'x':
          462             addr.tags.append(('x', tagdata.uint))
          463 
          464         elif tag == 'p':
          465             if data_length != 52:
          466                 addr.unknown_tags.append((tag, tagdata))
          467                 continue
          468             addr.paymenthash = trim_to_bytes(tagdata)
          469 
          470         elif tag == 's':
          471             if data_length != 52:
          472                 addr.unknown_tags.append((tag, tagdata))
          473                 continue
          474             addr.payment_secret = trim_to_bytes(tagdata)
          475 
          476         elif tag == 'n':
          477             if data_length != 53:
          478                 addr.unknown_tags.append((tag, tagdata))
          479                 continue
          480             pubkeybytes = trim_to_bytes(tagdata)
          481             addr.pubkey = pubkeybytes
          482 
          483         elif tag == 'c':
          484             addr._min_final_cltv_expiry = tagdata.uint
          485 
          486         elif tag == '9':
          487             features = tagdata.uint
          488             addr.tags.append(('9', features))
          489             from .lnutil import validate_features
          490             validate_features(features)
          491 
          492         else:
          493             addr.unknown_tags.append((tag, tagdata))
          494 
          495     if verbose:
          496         print('hex of signature data (32 byte r, 32 byte s): {}'
          497               .format(hexlify(sigdecoded[0:64])))
          498         print('recovery flag: {}'.format(sigdecoded[64]))
          499         print('hex of data for signing: {}'
          500               .format(hexlify(hrp.encode("ascii") + data.tobytes())))
          501         print('SHA256 of above: {}'.format(sha256(hrp.encode("ascii") + data.tobytes()).hexdigest()))
          502 
          503     # BOLT #11:
          504     #
          505     # A reader MUST check that the `signature` is valid (see the `n` tagged
          506     # field specified below).
          507     addr.signature = sigdecoded[:65]
          508     hrp_hash = sha256(hrp.encode("ascii") + data.tobytes()).digest()
          509     if addr.pubkey: # Specified by `n`
          510         # BOLT #11:
          511         #
          512         # A reader MUST use the `n` field to validate the signature instead of
          513         # performing signature recovery if a valid `n` field is provided.
          514         ecc.ECPubkey(addr.pubkey).verify_message_hash(sigdecoded[:64], hrp_hash)
          515         pubkey_copy = addr.pubkey
          516         class WrappedBytesKey:
          517             serialize = lambda: pubkey_copy
          518         addr.pubkey = WrappedBytesKey
          519     else: # Recover pubkey from signature.
          520         addr.pubkey = SerializableKey(ecc.ECPubkey.from_sig_string(sigdecoded[:64], sigdecoded[64], hrp_hash))
          521 
          522     return addr
          523 
          524 
          525 
          526 
          527 if __name__ == '__main__':
          528     # run using
          529     # python3 -m electrum.lnaddr <invoice> <expected hrp>
          530     # python3 -m electrum.lnaddr lntb1n1pdlcakepp5e7rn0knl0gm46qqp9eqdsza2c942d8pjqnwa5903n39zu28sgk3sdq423jhxapqv3hkuct5d9hkucqp2rzjqwyx8nu2hygyvgc02cwdtvuxe0lcxz06qt3lpsldzcdr46my5epmj9vk9sqqqlcqqqqqqqlgqqqqqqgqjqdhnmkgahfaynuhe9md8k49xhxuatnv6jckfmsjq8maxta2l0trh5sdrqlyjlwutdnpd5gwmdnyytsl9q0dj6g08jacvthtpeg383k0sq542rz2 tb1n
          531     import sys
          532     print(lndecode(sys.argv[1], expected_hrp=sys.argv[2]))