KiriPacketSocket/__init__.py

415 lines
13 KiB
Python
Raw Normal View History

2024-06-29 13:52:51 -07:00
#!/usr/bin/python3
# Copyright © 2024 Kiri Jolly
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the “Software”), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import collections
import enum
import socket
import sys
import threading
import time
class PacketSocket:
    """Socket wrapper that encapsulates its own packet protocol and
    manages its own threads. Polling-style interface.

    Packets are simply binary blobs with a little-endian 32-bit
    unsigned integer size prepended to them. Input data will be split
    into individual packets based on these.

    TCP/IP only.

    Main public API:

    - start_client(address) - Connect to an open port.

    - start_server(address) - Start a server with an open port. This
      will start receiving connections, which must be accepted with
      get_next_server_connection().

    - get_next_server_connection() - Get the next incoming connection
      (returns None if nothing is currently connecting, or a new
      PacketSocket that can actually receive packets, otherwise).
      Polling interface.

    - send_packet(b) - Sends a byte array (b) as a packet. This adds
      it to a queue which is sent in its own thread.

    - get_next_packet() - Gets the next complete incoming packet as a
      byte array or None if no packet yet (polling interface).

    - get_last_error() - Gets the last error as a string.

    - is_disconnected_or_error() - True if we've had a problem (use
      get_last_error() to see what the problem was).

    - stop() - Shutdown server or disconnect client (undoes
      start_client/start_server). Also must be used on the
      PacketSocket objects returned by get_next_server_connection().

    Internal packet format:

    - size: 4-byte little-endian unsigned integer indicating the
      number of bytes to follow for the next packet.

    - data: As many bytes of data as were specified in the 'size'.

    Note: 'size' does not include the space needed for the 'size'
    value itself.

    Example packet (bytes):

      [ 0x03, 0x00, 0x00, 0x00, 0x41, 0x42, 0x43 ]

    The first four bytes, [ 0x03, 0x00, 0x00, 0x00 ], indicate the
    size of the data, in a little endian unsigned integer. In this
    case that decodes to three. The next three bytes, [ 0x41, 0x42,
    0x43 ] are the packet data. In this case they represent the ASCII
    string "ABC".

    To send this example packet, one would use send_packet(b'ABC').
    On the receiving side, the return value of get_next_packet() would
    be b'ABC'.
    """

    # Number of bytes requested from the socket per recv() call.
    _RECV_CHUNK_SIZE = 1024

    # Packet wrangling timeout (seconds). Should be low so we can send
    # and receive packets fast.
    _IO_TIMEOUT_SECONDS = 0.0001

    # Timeout (seconds) for waiting on incoming connections. This can
    # be "large". Worst case we add latency to shutdown with this.
    _ACCEPT_TIMEOUT_SECONDS = 0.01

    # Largest payload the 4-byte size header can describe.
    _MAX_PACKET_SIZE = 0xFFFFFFFF

    class PacketBuffer:
        """Receiving buffer for packets. Accumulates bytes until
        complete packets are formed.

        Not synchronized on its own; the owning PacketSocket guards
        every call with its state lock.
        """

        # Size of the little-endian length prefix, in bytes.
        _HEADER_SIZE = 4

        def __init__(self):
            # Raw bytes received so far that have not yet been split
            # into packets.
            self._receive_buffer = bytearray()
            # Complete packets (bytes objects), oldest first.
            self._packet_buffer = collections.deque()

        def _grab_complete_packets(self):
            """Move every complete packet from the raw buffer into the
            packet queue, leaving any trailing partial packet behind.
            """
            raw = self._receive_buffer
            header_size = self._HEADER_SIZE
            available = len(raw)
            offset = 0
            while available - offset >= header_size:
                payload_start = offset + header_size
                payload_size = int.from_bytes(
                    raw[offset:payload_start], "little")
                payload_end = payload_start + payload_size
                if payload_end > available:
                    # Header is here but the payload is still in flight.
                    break
                self._packet_buffer.append(
                    bytes(raw[payload_start:payload_end]))
                offset = payload_end
            # Drop all consumed bytes in one go instead of re-copying
            # the remainder once per packet.
            if offset:
                del raw[:offset]

        def _have_complete_packet(self):
            """Return True if at least one whole packet is queued."""
            self._grab_complete_packets()
            return len(self._packet_buffer) > 0

        def get_next_packet(self):
            """Return the oldest complete packet as bytes, or None if
            no complete packet is available yet.
            """
            if not self._have_complete_packet():
                return None
            return self._packet_buffer.popleft()

        def add_bytes(self, incoming_bytes):
            """Append freshly received raw bytes to the buffer."""
            self._receive_buffer += incoming_bytes

    class SocketState(enum.Enum):
        DISCONNECTED = 0
        CONNECTING = 1
        CONNECTED = 2
        LISTENING = 3
        ERROR = 4

    def __init__(self):
        # Set by stop() to ask the worker thread to exit its loop.
        self._should_quit = False
        self._packet_buffer = self.PacketBuffer()
        self._state = self.SocketState.DISCONNECTED
        # Packets waiting to be framed and sent, oldest first.
        self._outgoing_packet_queue = collections.deque()
        # Guards _state, _error_string, _packet_buffer and both queues.
        self._state_lock = threading.Lock()
        self._worker_thread = None
        # For servers: accepted connections not yet handed to the user.
        self._new_connections_to_server = collections.deque()
        self._error_string = None

    def __del__(self):
        # WE BETTER NOT HAVE ZOMBIE THREADS SITTING AROUND.
        # getattr: __init__ may not have completed before finalization.
        assert(not getattr(self, "_worker_thread", None))

    def send_packet(self, packet_bytes):
        """Add a binary blob to the send queue.

        Empty packets are rejected (assertion), as before. Raises
        ValueError if the blob is too large for the 4-byte size header.
        """
        assert(packet_bytes)
        if len(packet_bytes) > self._MAX_PACKET_SIZE:
            # Catch this here rather than letting the size encoding
            # blow up inside the worker thread.
            raise ValueError("packet too large for 32-bit size header")
        with self._state_lock:
            self._outgoing_packet_queue.append(packet_bytes)

    def get_next_packet(self):
        """Get a binary blob from the receive queue.

        Returns None if no complete packet has arrived yet.
        """
        with self._state_lock:
            return self._packet_buffer.get_next_packet()

    def get_next_server_connection(self):
        """For servers: Get the next incoming connection as a
        PacketSocket instance.

        Returns None if there are no incoming connections.
        """
        with self._state_lock:
            if self._new_connections_to_server:
                return self._new_connections_to_server.popleft()
            return None

    def get_last_error(self):
        """Get the last error, as a string. (From the thrown exception.)

        Returns None if there has not been an error.
        """
        with self._state_lock:
            return self._error_string

    def is_disconnected_or_error(self):
        """Returns True if this socket has disconnected, or thrown an
        error. (One way or another, the connection is gone and
        resources may be cleaned up.)

        Note: May still have un-processed packets queued.
        """
        with self._state_lock:
            return self._state in (
                self.SocketState.DISCONNECTED,
                self.SocketState.ERROR,
            )

    def get_state(self):
        """Returns the current SocketState of this object."""
        with self._state_lock:
            return self._state

    def start_server(self, address):
        """For servers: Start a listening server. (And start worker
        thread.)

        Address is a tuple with a host IP (string) and a port number
        (int). Use "0.0.0.0" to open on every interface.
        """
        # Check first so a misuse does not clobber the state of an
        # already-running socket.
        assert(not self._worker_thread)
        self._set_state(self.SocketState.LISTENING)
        self._worker_thread = threading.Thread(
            target=self._server_thread_func,
            args=[address],
            daemon=True)
        self._worker_thread.start()

    def start_client(self, address):
        """For clients: Attempt to connect to a listening server. (And
        start worker thread.)

        Address is a tuple with a host IP (string) and a port number
        (int).
        """
        assert(not self._worker_thread)
        self._set_state(self.SocketState.CONNECTING)
        self._worker_thread = threading.Thread(
            target=self._client_thread_func,
            args=[address],
            daemon=True)
        self._worker_thread.start()

    def stop(self):
        """Disconnect and shutdown the thread.

        For servers: Note that this does not disconnect PacketSocket
        instances that were established from this server.
        """
        assert(self._worker_thread)
        self._should_quit = True
        self._worker_thread.join()
        self._worker_thread = None
        # Reset so the object can be started again. (The method
        # _client_thread_func must NOT be overwritten here, otherwise
        # a later start_client() would run a thread with no target.)
        self._should_quit = False

    def is_running(self):
        """Return True while a worker thread is attached (i.e. between
        a start call and stop()).
        """
        return self._worker_thread is not None

    def _normal_communication_loop(self, sock, address):
        """Shared communication loop between clients and servers.

        Alternates between reading whatever has arrived and writing as
        much queued outgoing data as the socket accepts, until stop()
        is requested or the peer closes the connection. Socket errors
        other than timeouts propagate to the caller. 'address' is
        unused and kept for signature compatibility.
        """
        sock.settimeout(self._IO_TIMEOUT_SECONDS)

        # Framed bytes (size header + payload) not yet accepted by the
        # socket. Lets a partial send resume on the next iteration
        # without corrupting the stream.
        pending = bytearray()

        while not self._should_quit:

            # Get new data. socket.timeout is the same class as
            # TimeoutError on Python 3.10+, and is what older versions
            # raise.
            try:
                incoming_bytes = sock.recv(self._RECV_CHUNK_SIZE)
            except socket.timeout:
                incoming_bytes = None

            if incoming_bytes is not None:
                if not incoming_bytes:
                    # Empty read: the peer closed the connection.
                    break
                # The polling side reads this buffer under the lock.
                with self._state_lock:
                    self._packet_buffer.add_bytes(incoming_bytes)

            # Frame all queued packets. Only the queue manipulation
            # happens under the lock; the network write does not.
            with self._state_lock:
                while self._outgoing_packet_queue:
                    next_outgoing_packet = self._outgoing_packet_queue.popleft()
                    pending += len(next_outgoing_packet).to_bytes(4, "little")
                    pending += next_outgoing_packet

            # send() may accept only part of the data; keep the rest
            # for the next pass. A timeout means nothing was accepted.
            if pending:
                try:
                    sent_count = sock.send(pending)
                except socket.timeout:
                    sent_count = 0
                del pending[:sent_count]

    def _client_thread_func(self, address):
        """Client startup thread function. Attempts to establish a
        connection, then runs the communication loop.

        Ends in the DISCONNECTED state on a clean exit, or ERROR (with
        the exception text) on any socket error.
        """
        # The 'with' block closes the socket on every exit path.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                # Connect to the server.
                self._set_state(self.SocketState.CONNECTING)
                sock.connect(address)
                self._set_state(self.SocketState.CONNECTED)

                self._normal_communication_loop(sock, address)

            except OSError as ex:
                # OSError covers ConnectionError plus resolution and
                # timeout failures that would otherwise kill the
                # thread with a stale state.
                self._set_state(self.SocketState.ERROR, str(ex) or repr(ex))
            else:
                # We are now disconnected.
                self._set_state(self.SocketState.DISCONNECTED)

    def _set_state(self, state, error_string=None):
        """Set the current state. An error string is required for (and
        only allowed with) the ERROR state.
        """
        with self._state_lock:
            self._state = state
            if self._state == self.SocketState.ERROR:
                assert(error_string)
                self._error_string = error_string
            else:
                assert(not error_string)
                self._error_string = None

    def _server_to_client_thread_func(self, connection, address):
        """Server connection startup thread function. Initiated
        internally from a server listening for connections.

        Runs the communication loop on an accepted connection and
        closes that connection when the loop ends.
        """
        self._set_state(self.SocketState.CONNECTED)

        try:
            self._normal_communication_loop(connection, address)
        except OSError as ex:
            self._set_state(self.SocketState.ERROR, str(ex) or repr(ex))
        finally:
            # We own the accepted socket; nobody else will close it.
            connection.close()

        # Only switch to "disconnected" if we were most recently
        # connected, otherwise we could mask an error.
        if self.get_state() == self.SocketState.CONNECTED:
            self._set_state(self.SocketState.DISCONNECTED)

    def _server_thread_func(self, address):
        """Server thread function. Attempts to bind to an address and
        listen for incoming connections, in a loop.

        Each accepted connection is wrapped in a new PacketSocket and
        queued for get_next_server_connection(). Bind, listen or accept
        failures put this object in the ERROR state and end the thread.
        """
        if self._should_quit:
            return

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:

            # We aren't doing any message wrangling here besides
            # getting connections in, so this timeout can be "large".
            sock.settimeout(self._ACCEPT_TIMEOUT_SECONDS)

            try:
                # SO_REUSEADDR is valid for TCP: it lets us re-bind a
                # port whose previous socket is still in TIME_WAIT.
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(address)
                sock.listen()
            except Exception as ex:
                # FIXME: I wonder if we should do this in the main
                # thread so we can get the exceptions back up to
                # the start_server function and up from there.
                self._set_state(self.SocketState.ERROR, str(ex) or repr(ex))
                return

            while not self._should_quit:
                try:
                    # Separate name: do not clobber the bind address.
                    connection, peer_address = sock.accept()
                except socket.timeout:
                    continue
                except OSError as ex:
                    self._set_state(
                        self.SocketState.ERROR, str(ex) or repr(ex))
                    return

                new_client = PacketSocket()
                new_client._start_client_connection_from_server(
                    connection, peer_address)
                with self._state_lock:
                    self._new_connections_to_server.append(new_client)

    def _start_client_connection_from_server(self, connection, address):
        """Entrypoint for PacketSocket instances created by servers
        accepting connections.
        """
        assert(not self._worker_thread)
        # The connection is already established. Mark it now so a
        # caller polling before the worker gets scheduled does not see
        # DISCONNECTED.
        self._set_state(self.SocketState.CONNECTED)
        self._worker_thread = threading.Thread(
            target=self._server_to_client_thread_func,
            args=(connection, address),
            daemon=True)
        self._worker_thread.start()