URI: 
       coinchooser: refactor so that penalty_func has access to change outputs - electrum - Electrum Bitcoin wallet
  HTML git clone https://git.parazyd.org/electrum
   DIR Log
   DIR Files
   DIR Refs
   DIR Submodules
       ---
   DIR commit f409b5da40e607819f94093f4ca6a91f8a2b71f0
   DIR parent 6424163d4bad3de72733849db797d10f11b47479
  HTML Author: SomberNight <somber.night@protonmail.com>
       Date:   Thu, 20 Jun 2019 17:45:56 +0200
       
       coinchooser: refactor so that penalty_func has access to change outputs
       
       Diffstat:
         M electrum/coinchooser.py             |     164 +++++++++++++++++--------------
       
       1 file changed, 88 insertions(+), 76 deletions(-)
       ---
   DIR diff --git a/electrum/coinchooser.py b/electrum/coinchooser.py
       @@ -24,7 +24,7 @@
        # SOFTWARE.
        from collections import defaultdict
        from math import floor, log10
       -from typing import NamedTuple, List
       +from typing import NamedTuple, List, Callable
        
        from .bitcoin import sha256, COIN, TYPE_ADDRESS, is_address
        from .transaction import Transaction, TxOutput
       @@ -79,6 +79,12 @@ class Bucket(NamedTuple):
            witness: bool       # whether any coin uses segwit
        
        
       +class ScoredCandidate(NamedTuple):
       +    penalty: float
       +    tx: Transaction
       +    buckets: List[Bucket]
       +
       +
        def strip_unneeded(bkts, sufficient_funds):
            '''Remove buckets that are unnecessary in achieving the spend amount'''
            if sufficient_funds([], bucket_value_sum=0):
       @@ -121,12 +127,10 @@ class CoinChooserBase(Logger):
        
                return list(map(make_Bucket, buckets.keys(), buckets.values()))
        
       -    def penalty_func(self, tx, *, fee_for_buckets):
       -        def penalty(candidate):
       -            return 0
       -        return penalty
       +    def penalty_func(self, base_tx, *, tx_from_buckets) -> Callable[[List[Bucket]], ScoredCandidate]:
       +        raise NotImplementedError
        
       -    def change_amounts(self, tx, count, fee_estimator, dust_threshold):
       +    def _change_amounts(self, tx, count, fee_estimator):
                # Break change up if bigger than max_change
                output_amounts = [o.value for o in tx.outputs()]
                # Don't split change of less than 0.02 BTC
       @@ -180,22 +184,60 @@ class CoinChooserBase(Logger):
        
                return amounts
        
       -    def change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
       -        amounts = self.change_amounts(tx, len(change_addrs), fee_estimator,
       -                                      dust_threshold)
       +    def _change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
       +        amounts = self._change_amounts(tx, len(change_addrs), fee_estimator)
                assert min(amounts) >= 0
                assert len(change_addrs) >= len(amounts)
                # If change is above dust threshold after accounting for the
                # size of the change output, add it to the transaction.
       -        dust = sum(amount for amount in amounts if amount < dust_threshold)
                amounts = [amount for amount in amounts if amount >= dust_threshold]
                change = [TxOutput(TYPE_ADDRESS, addr, amount)
                          for addr, amount in zip(change_addrs, amounts)]
       -        self.logger.info(f'change: {change}')
       -        if dust:
       -            self.logger.info(f'not keeping dust {dust}')
                return change
        
       +    def _construct_tx_from_selected_buckets(self, *, buckets, base_tx, change_addrs,
       +                                            fee_estimator_w, dust_threshold, base_weight):
       +        # make a copy of base_tx so it won't get mutated
       +        tx = Transaction.from_io(base_tx.inputs()[:], base_tx.outputs()[:])
       +
       +        tx.add_inputs([coin for b in buckets for coin in b.coins])
       +        tx_weight = self._get_tx_weight(buckets, base_weight=base_weight)
       +
       +        # change is sent back to sending address unless specified
       +        if not change_addrs:
       +            change_addrs = [tx.inputs()[0]['address']]
       +            # note: this is not necessarily the final "first input address"
       +            # because the inputs had not been sorted at this point
       +            assert is_address(change_addrs[0])
       +
       +        # This takes a count of change outputs and returns a tx fee
       +        output_weight = 4 * Transaction.estimated_output_size(change_addrs[0])
       +        fee = lambda count: fee_estimator_w(tx_weight + count * output_weight)
       +        change = self._change_outputs(tx, change_addrs, fee, dust_threshold)
       +        tx.add_outputs(change)
       +
       +        return tx, change
       +
       +    def _get_tx_weight(self, buckets, *, base_weight) -> int:
       +        """Given a collection of buckets, return the total weight of the
       +        resulting transaction.
       +        base_weight is the weight of the tx that includes the fixed (non-change)
       +        outputs and potentially some fixed inputs. Note that the change outputs
       +        at this point are not yet known so they are NOT accounted for.
       +        """
       +        total_weight = base_weight + sum(bucket.weight for bucket in buckets)
       +        is_segwit_tx = any(bucket.witness for bucket in buckets)
       +        if is_segwit_tx:
       +            total_weight += 2  # marker and flag
       +            # non-segwit inputs were previously assumed to have
       +            # a witness of '' instead of '00' (hex)
       +            # note that mixed legacy/segwit buckets are already ok
       +            num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins)
       +                                    for bucket in buckets)
       +            total_weight += num_legacy_inputs
       +
       +        return total_weight
       +
            def make_tx(self, coins, inputs, outputs, change_addrs, fee_estimator,
                        dust_threshold):
                """Select unspent coins to spend to pay outputs.  If the change is
       @@ -211,34 +253,20 @@ class CoinChooserBase(Logger):
                self.p = PRNG(''.join(sorted(utxos)))
        
                # Copy the outputs so when adding change we don't modify "outputs"
       -        tx = Transaction.from_io(inputs[:], outputs[:])
       -        input_value = tx.input_value()
       +        base_tx = Transaction.from_io(inputs[:], outputs[:])
       +        input_value = base_tx.input_value()
        
                # Weight of the transaction with no inputs and no change
                # Note: this will use legacy tx serialization as the need for "segwit"
                # would be detected from inputs. The only side effect should be that the
                # marker and flag are excluded, which is compensated in get_tx_weight()
                # FIXME calculation will be off by this (2 wu) in case of RBF batching
       -        base_weight = tx.estimated_weight()
       -        spent_amount = tx.output_value()
       +        base_weight = base_tx.estimated_weight()
       +        spent_amount = base_tx.output_value()
        
                def fee_estimator_w(weight):
                    return fee_estimator(Transaction.virtual_size_from_weight(weight))
        
       -        def get_tx_weight(buckets):
       -            total_weight = base_weight + sum(bucket.weight for bucket in buckets)
       -            is_segwit_tx = any(bucket.witness for bucket in buckets)
       -            if is_segwit_tx:
       -                total_weight += 2  # marker and flag
       -                # non-segwit inputs were previously assumed to have
       -                # a witness of '' instead of '00' (hex)
       -                # note that mixed legacy/segwit buckets are already ok
       -                num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins)
       -                                        for bucket in buckets)
       -                total_weight += num_legacy_inputs
       -
       -            return total_weight
       -
                def sufficient_funds(buckets, *, bucket_value_sum):
                    '''Given a list of buckets, return True if it has enough
                    value to pay for the transaction'''
       @@ -248,45 +276,30 @@ class CoinChooserBase(Logger):
                        return False
                    # note re performance: so far this was constant time
                    # what follows is linear in len(buckets)
       -            total_weight = get_tx_weight(buckets)
       +            total_weight = self._get_tx_weight(buckets, base_weight=base_weight)
                    return total_input >= spent_amount + fee_estimator_w(total_weight)
        
       -        def fee_for_buckets(buckets) -> int:
       -            """Given a list of buckets, return the total fee paid by the
       -            transaction, in satoshis.
       -            Note that the change output(s) are not yet known here,
       -            so fees for those are excluded and hence this is a lower bound.
       -            """
       -            total_weight = get_tx_weight(buckets)
       -            return fee_estimator_w(total_weight)
       +        def tx_from_buckets(buckets):
       +            return self._construct_tx_from_selected_buckets(buckets=buckets,
       +                                                            base_tx=base_tx,
       +                                                            change_addrs=change_addrs,
       +                                                            fee_estimator_w=fee_estimator_w,
       +                                                            dust_threshold=dust_threshold,
       +                                                            base_weight=base_weight)
        
                # Collect the coins into buckets, choose a subset of the buckets
       -        buckets = self.bucketize_coins(coins)
       -        buckets = self.choose_buckets(buckets, sufficient_funds,
       -                                      self.penalty_func(tx, fee_for_buckets=fee_for_buckets))
       -
       -        tx.add_inputs([coin for b in buckets for coin in b.coins])
       -        tx_weight = get_tx_weight(buckets)
       -
       -        # change is sent back to sending address unless specified
       -        if not change_addrs:
       -            change_addrs = [tx.inputs()[0]['address']]
       -            # note: this is not necessarily the final "first input address"
       -            # because the inputs had not been sorted at this point
       -            assert is_address(change_addrs[0])
       -
       -        # This takes a count of change outputs and returns a tx fee
       -        output_weight = 4 * Transaction.estimated_output_size(change_addrs[0])
       -        fee = lambda count: fee_estimator_w(tx_weight + count * output_weight)
       -        change = self.change_outputs(tx, change_addrs, fee, dust_threshold)
       -        tx.add_outputs(change)
       +        all_buckets = self.bucketize_coins(coins)
       +        scored_candidate = self.choose_buckets(all_buckets, sufficient_funds,
       +                                               self.penalty_func(base_tx, tx_from_buckets=tx_from_buckets))
       +        tx = scored_candidate.tx
        
                self.logger.info(f"using {len(tx.inputs())} inputs")
       -        self.logger.info(f"using buckets: {[bucket.desc for bucket in buckets]}")
       +        self.logger.info(f"using buckets: {[bucket.desc for bucket in scored_candidate.buckets]}")
        
                return tx
        
       -    def choose_buckets(self, buckets, sufficient_funds, penalty_func):
       +    def choose_buckets(self, buckets, sufficient_funds,
       +                       penalty_func: Callable[[List[Bucket]], ScoredCandidate]) -> ScoredCandidate:
                raise NotImplemented('To be subclassed')
        
        
       @@ -368,12 +381,14 @@ class CoinChooserRandom(CoinChooserBase):
        
            def choose_buckets(self, buckets, sufficient_funds, penalty_func):
                candidates = self.bucket_candidates_prefer_confirmed(buckets, sufficient_funds)
       -        penalties = [penalty_func(cand) for cand in candidates]
       -        winner = candidates[penalties.index(min(penalties))]
       -        self.logger.info(f"Bucket sets: {len(buckets)}")
       -        self.logger.info(f"Winning penalty: {min(penalties)}")
       +        scored_candidates = [penalty_func(cand) for cand in candidates]
       +        winner = min(scored_candidates, key=lambda x: x.penalty)
       +        self.logger.info(f"Total number of buckets: {len(buckets)}")
       +        self.logger.info(f"Num candidates considered: {len(candidates)}. "
       +                         f"Winning penalty: {winner.penalty}")
                return winner
        
       +
        class CoinChooserPrivacy(CoinChooserRandom):
            """Attempts to better preserve user privacy.
            First, if any coin is spent from a user address, all coins are.
       @@ -388,18 +403,15 @@ class CoinChooserPrivacy(CoinChooserRandom):
            def keys(self, coins):
                return [coin['address'] for coin in coins]
        
       -    def penalty_func(self, tx, *, fee_for_buckets):
       -        min_change = min(o.value for o in tx.outputs()) * 0.75
       -        max_change = max(o.value for o in tx.outputs()) * 1.33
       -        spent_amount = sum(o.value for o in tx.outputs())
       +    def penalty_func(self, base_tx, *, tx_from_buckets):
       +        min_change = min(o.value for o in base_tx.outputs()) * 0.75
       +        max_change = max(o.value for o in base_tx.outputs()) * 1.33
        
       -        def penalty(buckets):
       +        def penalty(buckets) -> ScoredCandidate:
       +            # Penalize using many buckets (~inputs)
                    badness = len(buckets) - 1
       -            total_input = sum(bucket.value for bucket in buckets)
       -            # FIXME fee_for_buckets does not include fees needed to cover the change output(s)
       -            # so fee here is a lower bound
       -            fee = fee_for_buckets(buckets)
       -            change = float(total_input - spent_amount - fee)
       +            tx, change_outputs = tx_from_buckets(buckets)
       +            change = sum(o.value for o in change_outputs)
                    # Penalize change not roughly in output range
                    if change < min_change:
                        badness += (min_change - change) / (min_change + 10000)
       @@ -407,7 +419,7 @@ class CoinChooserPrivacy(CoinChooserRandom):
                        badness += (change - max_change) / (max_change + 10000)
                        # Penalize large change; 5 BTC excess ~= using 1 more input
                        badness += change / (COIN * 5)
       -            return badness
       +            return ScoredCandidate(badness, tx, buckets)
        
                return penalty