Selaa lähdekoodia

only keep a logarithmic amount of blockchain checkpoints

Malte Kraus 8 vuotta sitten
vanhempi
sitoutus
0bbb90e8ba
1 muutettua tiedostoa jossa 35 lisäystä ja 15 poistoa
  1. 35 15
      src/chainbuilder.py

+ 35 - 15
src/chainbuilder.py

@@ -5,6 +5,7 @@ candidate for an even longer chain that it attempts to download and verify.
 
 import threading
 import logging
+import math
 from typing import List, Dict, Callable, Optional
 from datetime import datetime, timedelta
 
@@ -87,7 +88,6 @@ class ChainBuilder:
     def __init__(self, protocol: 'Protocol'):
         self.primary_block_chain = Blockchain()
         self._block_requests = {}
-        # TODO: delete some old checkpoints
         self._blockchain_checkpoints = { GENESIS_BLOCK_HASH: self.primary_block_chain }
 
         self.block_cache = { GENESIS_BLOCK_HASH: GENESIS_BLOCK }
@@ -138,12 +138,42 @@ class ChainBuilder:
         for handler in self.chain_change_handlers:
             handler()
 
-        self._blockchain_checkpoints[chain.head.hash] = chain
         self._retry_expired_requests()
         self._clean_block_requests()
 
         self.protocol.broadcast_primary_block(chain.head)
 
+    def _build_blockchain(self, checkpoint: 'Blockchain', blocks: 'List[Block]'):
+        def checkpoint_hashes(chain):
+            chain_len = len(chain.blocks)
+            idx = 0
+            yield GENESIS_BLOCK_HASH
+            while chain_len > 1:
+                cp = 2 ** (math.floor(math.log(chain_len, 2) - 1))
+                idx += cp
+                yield chain.blocks[idx].hash
+                chain_len = chain_len - cp
+
+        chain = checkpoint
+        checkpoints = self._blockchain_checkpoints.copy()
+        for b in blocks:
+            next_chain = chain.try_append(b)
+            if next_chain is None:
+                logging.warning("invalid block")
+                break
+            chain = next_chain
+            checkpoints[chain.head.hash] = chain
+
+        if chain.head.height <= self.primary_block_chain.head.height:
+            logging.warning("discarding shorter chain")
+            return
+
+        for hash_val in checkpoints.keys() - set(checkpoint_hashes(next_chain)):
+            del checkpoints[hash_val]
+        self._blockchain_checkpoints = checkpoints
+        self._new_primary_block_chain(chain)
+
+
     def _retry_expired_requests(self):
         """ Sends new block requests to our peers for unanswered pending requests. """
         for request in self._block_requests.values():
@@ -204,20 +234,10 @@ class ChainBuilder:
             self._block_requests[block.prev_block_hash] = request
 
         if block.prev_block_hash in self._blockchain_checkpoints:
-            winner = self.primary_block_chain
-            for partial_chain in request.partial_chains:
-                chain = self._blockchain_checkpoints[block.prev_block_hash]
-                for b in partial_chain[::-1]:
-                    next_chain = chain.try_append(b)
-                    if next_chain is None:
-                        logging.warning("invalid block")
-                        break
-                    chain = next_chain
-                if chain.head.height > winner.head.height:
-                    winner = chain
             del self._block_requests[block.prev_block_hash]
-            if winner is not self.primary_block_chain:
-                self._new_primary_block_chain(winner)
+            checkpoint = self._blockchain_checkpoints[block.prev_block_hash]
+            for partial_chain in request.partial_chains:
+                self._build_blockchain(checkpoint, partial_chain[::-1])
         request.checked_retry(self.protocol)
 
 from .protocol import Protocol