瀏覽代碼

incrementally build blockchains based on their predecessors

Malte Kraus 8 年之前
父節點
當前提交
7ceaa03073
共有 4 個文件被更改,包括 61 次插入34 次删除
  1. 29 19
      src/blockchain.py
  2. 29 12
      src/chainbuilder.py
  3. 1 1
      tests/test_verifications.py
  4. 2 2
      tests/utils.py

+ 29 - 19
src/blockchain.py

@@ -20,25 +20,35 @@ class Blockchain:
     :vartype unspent_coins: Dict[TransactionInput, TransactionTarget]
     """
 
-    def __init__(self, blocks: 'List[Block]'):
-        self.blocks = blocks
+    def __init__(self):
+        self.blocks = [GENESIS_BLOCK]
         assert self.blocks[0].height == 0
-        self.block_indices = {block.hash: i for (i, block) in enumerate(blocks)}
-        self.unspent_coins = self._compute_unspent_coins()
-
-    def _compute_unspent_coins(self):
-        val = {}
-
-        for b in self.blocks:
-            for t in b.transactions:
-                for inp in t.inputs:
-                    if inp not in val:
-                        logging.warning("Aborting computation of unspent transactions because a transaction spent an unavailable coin.")
-                        return {}
-                    del val[inp]
-                for i, target in enumerate(t.targets):
-                    val[TransactionInput(t.get_hash(), i)] = target
-        return val
+        self.block_indices = {GENESIS_BLOCK_HASH: 0}
+        assert not GENESIS_BLOCK.transactions
+        self.unspent_coins = {}
+
+    def try_append(self, block: 'Block') -> 'Optional[Blockchain]':
+        unspent_coins = self.unspent_coins.copy()
+
+        for t in block.transactions:
+            for inp in t.inputs:
+                if inp not in unspent_coins:
+                    logging.warning("Aborting computation of unspent transactions because a transaction spent an unavailable coin.")
+                    return None
+                del unspent_coins[inp]
+            for i, target in enumerate(t.targets):
+                unspent_coins[TransactionInput(t.get_hash(), i)] = target
+
+        chain = Blockchain()
+        chain.unspent_coins = unspent_coins
+        chain.blocks = self.blocks + [block]
+        chain.block_indices = self.block_indices.copy()
+        chain.block_indices[block.hash] = len(self.blocks)
+
+        if not block.verify(chain):
+            return None
+
+        return chain
 
     def get_transaction_by_hash(self, hash_val: bytes) -> 'Optional[Transaction]':
         """ Returns a transaction from its hash, or None. """
@@ -133,5 +143,5 @@ class Blockchain:
             reward = reward // 2
         return reward
 
-from .block import Block
+from .block import Block, GENESIS_BLOCK, GENESIS_BLOCK_HASH
 from .transaction import TransactionInput, Transaction

+ 29 - 12
src/chainbuilder.py

@@ -14,8 +14,8 @@ from .blockchain import Blockchain
 __all__ = ['ChainBuilder']
 
 class PartialChain:
-    def __init__(self, start_block: Block):
-        self.blocks = [start_block]
+    def __init__(self):
+        self.blocks = []
         self.last_update = datetime.utcnow()
         # TODO: delete partial chains after some time
 
@@ -40,8 +40,10 @@ class ChainBuilder:
     """
 
     def __init__(self, protocol):
-        self.primary_block_chain = Blockchain([GENESIS_BLOCK])
+        self.primary_block_chain = Blockchain()
         self._block_requests = {}
+        # TODO: delete some old checkpoints
+        self._blockchain_checkpoints = { GENESIS_BLOCK_HASH: self.primary_block_chain }
 
         self.block_cache = { GENESIS_BLOCK_HASH: GENESIS_BLOCK }
         self.unconfirmed_transactions = {}
@@ -91,6 +93,19 @@ class ChainBuilder:
         for handler in self.chain_change_handlers:
             handler()
 
+        self._blockchain_checkpoints[chain.head.hash] = chain
+
+        # stop trying to build shorter block chains
+        block_requests = {}
+        for block_hash, requests in self._block_requests.items():
+            new_requests = []
+            for partial_chain in requests:
+                if partial_chain.blocks[0].height > chain.head.height:
+                    new_requests.append(partial_chain)
+            if new_requests:
+                block_requests[block_hash] = new_requests
+        self._block_requests = block_requests
+
         self.protocol.broadcast_primary_block(chain.head)
 
     def new_block_received(self, block: 'Block'):
@@ -103,10 +118,7 @@ class ChainBuilder:
 
         if block.hash not in self._block_requests:
             if block.height > self.primary_block_chain.head.height:
-                self._block_requests.setdefault(block.prev_block_hash, []).append(PartialChain(block))
-                block = self.block_cache.get(block.prev_block_hash)
-                if block is None:
-                    return
+                self._block_requests.setdefault(block.hash, []).append(PartialChain())
             else:
                 return
 
@@ -116,16 +128,21 @@ class ChainBuilder:
             for partial_chain in requests:
                 partial_chain.blocks.append(block)
                 partial_chain.last_update = datetime.utcnow()
-            if block.prev_block_hash not in self.block_cache:
+            if block.prev_block_hash not in self.block_cache or block.prev_block_hash in self._blockchain_checkpoints:
                 break
             block = self.block_cache[block.prev_block_hash]
         self._block_requests.setdefault(block.prev_block_hash, []).extend(requests)
-        if block.hash == GENESIS_BLOCK_HASH:
+
+        if block.prev_block_hash in self._blockchain_checkpoints:
             winner = self.primary_block_chain
             for partial_chain in requests:
-                chain = Blockchain(partial_chain.blocks[::-1])
-                if chain.head.height > winner.head.height and \
-                        chain.verify_all():
+                chain = self._blockchain_checkpoints[block.prev_block_hash]
+                for b in partial_chain.blocks[::-1]:
+                    next_chain = chain.try_append(b)
+                    if next_chain is None:
+                        break
+                    chain = next_chain
+                if chain.head.height > winner.head.height:
                     winner = chain
             if winner is not self.primary_block_chain:
                 self._new_primary_block_chain(winner)

+ 1 - 1
tests/test_verifications.py

@@ -10,7 +10,7 @@ def trans_test(fn):
         src.proof_of_work.verify_proof_of_work = lambda b: True
         src.block.verify_proof_of_work = src.proof_of_work.verify_proof_of_work
 
-        gen_chain = Blockchain([GENESIS_BLOCK])
+        gen_chain = Blockchain()
         assert gen_chain.verify_all()
         key = Signing.generate_private_key()
         reward_trans = Transaction([], [TransactionTarget(key, gen_chain.compute_blockreward(gen_chain.head))])

+ 2 - 2
tests/utils.py

@@ -15,8 +15,8 @@ def extend_blockchain(chain, trans:list=None, verify_res=True):
     ts = datetime.utcfromtimestamp(0)
     new_block = Block.create(chain, trans, ts)
     new_block.hash = new_block.get_hash()
-    new_chain = Blockchain(chain.blocks + [new_block])
-    assert new_chain.verify_all() == verify_res
+    new_chain = chain.try_append(new_block)
+    assert (new_chain is not None) == verify_res
     return new_chain
 
 def trans_as_input(trans, out_idx=0):