Pārlūkot izejas kodu

avoid connections to ourselves, store connected host names using their ips

Malte Kraus 8 gadi atpakaļ
vecāks
revīzija
9c8bb6fa9a
1 mainītis faili ar 27 papildinājumiem un 5 dzēšanām
  1. 27 5
      src/protocol.py

+ 27 - 5
src/protocol.py

@@ -7,6 +7,7 @@ import logging
 from threading import Thread, Lock
 from queue import Queue, PriorityQueue
 from binascii import unhexlify, hexlify
+from uuid import UUID, uuid4
 from typing import Callable, List
 
 
@@ -67,10 +68,10 @@ class PeerConnection:
 
         self.send_msg("myport", self.proto.server.server_address[1])
         self.send_msg("block", self.proto._primary_block)
+        self._sent_uuid = str(uuid4())
+        self.send_msg("id", self._sent_uuid)
         self.send_peers()
 
-        # TODO: broadcast this new peer to our current peers, under certain circumstances
-
         Thread(target=self.reader_thread, daemon=True).start()
         self.writer_thread()
 
@@ -90,6 +91,8 @@ class PeerConnection:
     def close(self):
         """ Closes the connection to this peer. """
 
+        # TODO: use locks to avoid the race conditions here
+
         if not self.is_connected:
             return
 
@@ -161,9 +164,9 @@ class PeerConnection:
             logging.debug("received %s", obj['msg_type'])
 
             if msg_type == 'myport':
-                self.peer_addr = (self._sock_addr[0],) + (int(msg_param),) + self._sock_addr[2:]
-            else:
-                self.proto.received(msg_type, msg_param, self)
+                addr = self.socket.getpeername()
+                self.peer_addr = (addr[0],) + (int(msg_param),) + addr[2:]
+            self.proto.received(msg_type, msg_param, self)
 
 
 class SocketServer(socketserver.TCPServer):
@@ -281,11 +284,25 @@ class Protocol:
                 except OSError:
                     pass
 
+    def received_id(self, uuid: str, sender: PeerConnection):
+        """
+        A unique connection id was received. We use this to detect and close connections to
+        ourselves.
+
+        TODO: detect duplicate connections to other peers (needs TLS or something similar)
+        """
+        for peer in self.peers:
+            if peer._sent_uuid == uuid:
+                peer.close()
+                sender.close()
+                break
+
     def received_peer(self, peer_addr: list, _):
         """ Information about a peer has been received. """
 
         peer_addr = tuple(peer_addr)
         if len(self.peers) >= MAX_PEERS:
+            # TODO: maintain list of known, not connected peers
             return
 
         for peer in self.peers:
@@ -295,6 +312,11 @@ class Protocol:
         # TODO: if the other peer also just learned of us, we can end up with two connections (one from each direction)
         self.peers.append(PeerConnection(peer_addr, self))
 
+    def received_myport(self, _, sender: PeerConnection):
+        for peer in self.peers:
+            if peer.is_connected and peer is not sender:
+                peer.send_msg("peer", list(sender.peer_addr))
+
     def received_getblock(self, block_hash: str, peer: PeerConnection):
         """ We received a request for a new block from a certain peer. """
         for handler in self.block_request_handlers: