Bläddra i källkod

properly close P2P connections, without race conditions

Malte Kraus 8 år sedan
förälder
incheckning
08cfd1761a
1 ändrade filer med 43 tillägg och 21 borttagningar
  1. 43 21
      src/protocol.py

+ 43 - 21
src/protocol.py

@@ -21,8 +21,8 @@ MAX_PEERS = 10
 HELLO_MSG = b"bl0ckch41n" + hexlify(GENESIS_BLOCK_HASH)[:30]
 """ The hello message two peers use to make sure they are speaking the same protocol. """
 
-# TODO: set this centrally
-socket.setdefaulttimeout(30)
+SOCKET_TIMEOUT = 30
+""" The socket timeout for P2P connections. """
 
 class PeerConnection:
     """
@@ -45,6 +45,7 @@ class PeerConnection:
         self.is_connected = False
         self._sent_uuid = str(uuid4())
         self.outgoing_msgs = Queue()
+        self._close_lock = Lock()
 
         Thread(target=self.run, daemon=True).start()
 
@@ -60,13 +61,21 @@ class PeerConnection:
 
         Does not return until the writer thread does.
         """
-        if self.socket is None:
-            logging.info("connecting to peer %s", repr(self._sock_addr))
-            self.socket = socket.create_connection(self._sock_addr)
-        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-        self.socket.sendall(HELLO_MSG)
-        if self.socket.recv(len(HELLO_MSG)) != HELLO_MSG:
-            return
+        try:
+            if self.socket is None:
+                logging.info("connecting to peer %s", repr(self._sock_addr))
+                self.socket = socket.create_connection(self._sock_addr, SOCKET_TIMEOUT)
+            else:
+                self.socket.settimeout(SOCKET_TIMEOUT)
+            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+            self.socket.sendall(HELLO_MSG)
+            if self.socket.recv(len(HELLO_MSG)) != HELLO_MSG:
+                raise OSError("peer talks a different protocol")
+        except OSError as e:
+            self.proto.received("disconnected", None, self)
+            if self.socket is not None:
+                self.socket.close()
+            raise e
         self.is_connected = True
 
         self.send_msg("myport", self.proto.server.server_address[1])
@@ -93,19 +102,19 @@ class PeerConnection:
     def close(self):
         """ Closes the connection to this peer. """
 
-        # TODO: use locks to avoid the race conditions here
+        with self._close_lock:
+            if not self.is_connected:
+                return
 
-        if not self.is_connected:
-            return
+            logging.info("closing connection to peer %s", self._sock_addr)
 
-        logging.info("closing connection to peer %s", self._sock_addr)
-        while not self.outgoing_msgs.empty():
-            self.outgoing_msgs.get_nowait()
-        self.outgoing_msgs.put(None)
-        self.is_connected = False
-        if self in self.proto.peers:
-            self.proto.peers.remove(self)
-        self.socket.close()
+            while not self.outgoing_msgs.empty():
+                self.outgoing_msgs.get_nowait()
+            self.outgoing_msgs.put(None)
+            self.is_connected = False
+            self.proto.received("disconnected", None, self)
+
+            self.socket.close()
 
     def send_msg(self, msg_type: str, msg_param):
         """
@@ -316,7 +325,10 @@ class Protocol:
     def received_myport(self, _, sender: PeerConnection):
         for peer in self.peers:
             if peer.is_connected and peer is not sender:
-                peer.send_msg("peer", list(sender.peer_addr))
+                if peer.peer_addr == sender.peer_addr:
+                    sender.close()
+                else:
+                    peer.send_msg("peer", list(sender.peer_addr))
 
     def received_getblock(self, block_hash: str, peer: PeerConnection):
         """ We received a request for a new block from a certain peer. """
@@ -336,6 +348,16 @@ class Protocol:
         for handler in self.trans_receive_handlers:
             handler(Transaction.from_json_compatible(transaction))
 
+    def received_disconnected(self, _, peer):
+        """
+        Removes a disconnected peer from our list of connected peers.
+
+        (Not actually a message received from the peer, but a message sent by the reader or writer
+        thread to the main thread.)
+        """
+        if not peer.is_connected:
+            self.peers.remove(peer)
+
     def send_block_request(self, block_hash: bytes):
         """ Sends a request for a block to all our peers. """
         for peer in self.peers: