Kaynağa Gözat

make sure that the chain builder is only used from a single thread

Malte Kraus 8 yıl önce
ebeveyn
işleme
056a5828c9
2 değiştirilmiş dosya ile 15 ekleme ve 1 silme
  1. 14 0
      src/chainbuilder.py
  2. 1 1
      tests/test_proto.py

+ 14 - 0
src/chainbuilder.py

@@ -1,6 +1,8 @@
 from .block import GENESIS_BLOCK, GENESIS_BLOCK_HASH
 from .block import GENESIS_BLOCK, GENESIS_BLOCK_HASH
 from .blockchain import Blockchain
 from .blockchain import Blockchain
 
 
+import threading
+
 __all__ = ['ChainBuilder']
 __all__ = ['ChainBuilder']
 
 
 class ChainBuilder:
 class ChainBuilder:
@@ -25,11 +27,20 @@ class ChainBuilder:
         protocol.block_request_handlers.append(self.block_request_received)
         protocol.block_request_handlers.append(self.block_request_received)
         self.protocol = protocol
         self.protocol = protocol
 
 
+        self._thread_id = None
+
+    def _assert_thread_safety(self):
+        if self._thread_id is None:
+            self._thread_id = threading.get_ident()
+        assert self._thread_id == threading.get_ident()
+
     def block_request_received(self, block_hash):
     def block_request_received(self, block_hash):
+        self._assert_thread_safety()
         return self.block_cache.get(block_hash)
         return self.block_cache.get(block_hash)
 
 
     def new_transaction_received(self, transaction):
     def new_transaction_received(self, transaction):
         """ Event handler that is called by the network layer when a transaction is received. """
         """ Event handler that is called by the network layer when a transaction is received. """
+        self._assert_thread_safety()
         hash_val = transaction.get_hash()
         hash_val = transaction.get_hash()
         if self.primary_block_chain.get_transaction_by_hash(hash_val) is None:
         if self.primary_block_chain.get_transaction_by_hash(hash_val) is None:
            self.unconfirmed_transactions[hash_val] = transaction
            self.unconfirmed_transactions[hash_val] = transaction
@@ -38,6 +49,7 @@ class ChainBuilder:
         """
         """
         Does all the housekeeping that needs to be done when a new longest chain is found.
         Does all the housekeeping that needs to be done when a new longest chain is found.
         """
         """
+        self._assert_thread_safety()
         self.primary_block_chain = chain
         self.primary_block_chain = chain
         todelete = set()
         todelete = set()
         for (hash_val, trans) in self.unconfirmed_transactions.items():
         for (hash_val, trans) in self.unconfirmed_transactions.items():
@@ -54,6 +66,7 @@ class ChainBuilder:
         Helper function that tries to complete the unconfirmed chain,
         Helper function that tries to complete the unconfirmed chain,
         possibly asking the network layer for more blocks.
         possibly asking the network layer for more blocks.
         """
         """
+        self._assert_thread_safety()
         unc = self.unconfirmed_block_chain
         unc = self.unconfirmed_block_chain
         while unc[-1].prev_block_hash in self.block_cache:
         while unc[-1].prev_block_hash in self.block_cache:
             unc.append(self.block_cache[unc[-1].prev_block_hash])
             unc.append(self.block_cache[unc[-1].prev_block_hash])
@@ -69,6 +82,7 @@ class ChainBuilder:
 
 
     def new_block_received(self, block):
     def new_block_received(self, block):
         """ Event handler that is called by the network layer when a block is received. """
         """ Event handler that is called by the network layer when a block is received. """
+        self._assert_thread_safety()
         if not block.verify_difficulty() or not block.verify_merkle() or block.hash in self.block_cache:
         if not block.verify_difficulty() or not block.verify_merkle() or block.hash in self.block_cache:
             return
             return
         self.block_cache[block.hash] = block
         self.block_cache[block.hash] = block

+ 1 - 1
tests/test_proto.py

@@ -25,7 +25,7 @@ strans2 = miner2.chainbuilder.primary_block_chain.head.transactions[0]
 strans2 = TransactionInput(strans2.get_hash(), 0)
 strans2 = TransactionInput(strans2.get_hash(), 0)
 trans = Transaction([strans1, strans2], [])
 trans = Transaction([strans1, strans2], [])
 trans.sign([reward_key, reward_key])
 trans.sign([reward_key, reward_key])
-miner2.chainbuilder.new_transaction_received(trans)
+proto2.received('transaction', trans.to_json_compatible(), None)
 sleep(5)
 sleep(5)
 print(len(miner1.chainbuilder.primary_block_chain.blocks))
 print(len(miner1.chainbuilder.primary_block_chain.blocks))
 print(len(miner2.chainbuilder.primary_block_chain.blocks))
 print(len(miner2.chainbuilder.primary_block_chain.blocks))