diff --git a/.gitignore b/.gitignore index 6c1eab14..d54a2972 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ venv # ides .vscode .idea +.cursor *~ .*.swp diff --git a/server/pyproject.toml b/server/pyproject.toml index a4a09d89..aa9b12d9 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -36,7 +36,7 @@ optional-dependencies.test = [ "pytest", "testcontainers>=4" ] urls.Repository = "https://github.com/ChannelFinder/recsync" [tool.setuptools] -packages = [ "recceiver", "twisted.plugins" ] +packages = [ "recceiver", "recceiver.protocol", "twisted.plugins" ] include-package-data = true package-data.twisted = [ "plugins/recceiver_plugin.py" ] diff --git a/server/recceiver/announce.py b/server/recceiver/announce.py index f6442035..266e5cc8 100644 --- a/server/recceiver/announce.py +++ b/server/recceiver/announce.py @@ -1,16 +1,14 @@ # -*- coding: utf-8 -*- import logging -import struct -import sys from twisted.internet import protocol from twisted.internet.error import MessageLengthError -_log = logging.getLogger(__name__) +from recceiver.protocol.announce import ANNOUNCE_PORT, BROADCAST_ADDRESS, Announce +_log = logging.getLogger(__name__) -_Ann = struct.Struct(">HH4sHHI") __all__ = ["Announcer"] @@ -20,19 +18,18 @@ def __init__( self, tcpport, key=0, - tcpaddr="\xff\xff\xff\xff", - udpaddrs=[("", 5049)], + host=BROADCAST_ADDRESS, + udpaddrs=None, period=15.0, ): from twisted.internet import reactor self.reactor = reactor - if sys.version_info[0] < 3: - self.msg = _Ann.pack(0x5243, 0, tcpaddr, tcpport, 0, key) - else: - self.msg = _Ann.pack(0x5243, 0, tcpaddr.encode("latin-1"), tcpport, 0, key) + if udpaddrs is None: + udpaddrs = [("", ANNOUNCE_PORT)] + self.msg = Announce(tcp_port=tcpport, key=key, host=host).encode() self.delay = period self.udps = udpaddrs self.udpErr = set() @@ -43,7 +40,7 @@ def __init__( def startProtocol(self): _log.info("Setup Announcer") self.D = self.reactor.callLater(0, self.sendOne) - # we won't process any receieved traffic, so no reason to wake + # we won't process any received traffic, so no reason to wake # up for it... self.transport.pauseProducing() diff --git a/server/recceiver/protocol/__init__.py b/server/recceiver/protocol/__init__.py new file mode 100644 index 00000000..9b3a2c1d --- /dev/null +++ b/server/recceiver/protocol/__init__.py @@ -0,0 +1 @@ +"""Framework-neutral recsync protocol helpers.""" diff --git a/server/recceiver/protocol/announce.py b/server/recceiver/protocol/announce.py new file mode 100644 index 00000000..afdb6c54 --- /dev/null +++ b/server/recceiver/protocol/announce.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +"""Framework-neutral RecSync UDP announce packet.""" + +import socket +import struct +from dataclasses import dataclass +from typing import ClassVar + +from recceiver.protocol.messages import PROTO_ID, ProtocolError + +ANNOUNCE_PORT = 5049 +BROADCAST_ADDRESS = "255.255.255.255" + + +@dataclass(frozen=True) +class Announce: + """UDP packet advertising the RecCeiver's TCP endpoint. + + Wire layout: PROTO_ID(2) + 0(2) + SERV_ADDR(4) + PORT(2) + pad(2) + SERV_KEY(4). + """ + + payload: ClassVar = struct.Struct("!HH4sHxxI") + + tcp_port: int + key: int = 0 + host: str = BROADCAST_ADDRESS + + def encode(self): + try: + addr = socket.inet_aton(self.host) + except OSError: + raise ProtocolError(f"invalid announce address: {self.host}") + reserved = 0 + return self.payload.pack(PROTO_ID, reserved, addr, self.tcp_port, self.key) + + @classmethod + def decode(cls, wire_bytes): + if len(wire_bytes) < cls.payload.size: + raise ProtocolError(f"announce packet must be at least {cls.payload.size} bytes") + proto_id, version, addr, tcp_port, key = cls.payload.unpack(wire_bytes[: cls.payload.size]) + if proto_id != PROTO_ID: + raise ProtocolError(f"bad protocol id {proto_id:#06x}") + if version != 0: + raise ProtocolError(f"unsupported announce version {version}") + return cls(tcp_port=tcp_port, key=key, host=socket.inet_ntoa(addr)) + + +assert Announce.payload.size == 16 diff --git a/server/recceiver/protocol/messages.py b/server/recceiver/protocol/messages.py new file mode 100644 index 00000000..979f6097 --- /dev/null +++ b/server/recceiver/protocol/messages.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- +"""Framework-neutral RecSync TCP protocol messages.""" + +import struct +from dataclasses import astuple, dataclass +from enum import IntEnum +from typing import ClassVar + +PROTO_ID = 0x5243 # b'RC' + + +class RecordKind(IntEnum): + RECORD = 0 + ALIAS = 1 + + +class ProtocolError(ValueError): + """Raised when a RecSync message violates the wire protocol.""" + + +def _require_length(data, expected_length, message_name, exact=True): + if exact and len(data) == expected_length: + return + if not exact and len(data) >= expected_length: + return + + relation = "must be" if exact else "must be at least" + raise ProtocolError(f"{message_name} {relation} {expected_length} bytes, got {len(data)}") + + +@dataclass(frozen=True) +class Header: + """TCP frame header. + + Wire layout: PROTO_ID(2) + MSG_ID(2) + BLEN(4). + """ + + payload: ClassVar = struct.Struct("!HHI") + + msg_id: int + body_length: int + + def encode(self): + return self.payload.pack(PROTO_ID, int(self.msg_id), self.body_length) + + @classmethod + def decode(cls, wire_bytes): + _require_length(wire_bytes, cls.payload.size, cls.__name__) + proto_id, msg_id, body_length = cls.payload.unpack(wire_bytes) + if proto_id != PROTO_ID: + raise ProtocolError(f"bad protocol id {proto_id:#06x}") + return cls(msg_id, body_length) + + +assert Header.payload.size == 8 + + +class TCPMessage: + """Base class for TCP payloads carried by a RecSync frame.""" + + payload = struct.Struct("!") + msg_id = None + + def encode(self): + return self.payload.pack(*astuple(self)) + + def frame(self): + body = self.encode() + return Header(self.msg_id, len(body)).encode() + body + + @classmethod + def decode(cls, body): + _require_length(body, cls.payload.size, cls.__name__) + return cls(*cls.payload.unpack(body)) + + +assert TCPMessage.payload.size == 0 + + +class ClientMessage(TCPMessage): + """Message sent from RecCaster to RecCeiver (msg_id 0x0XXX).""" + + +class ServerMessage(TCPMessage): + """Message sent from RecCeiver to RecCaster (msg_id 0x8XXX).""" + + +@dataclass(frozen=True) +class ClientGreeting(ClientMessage): + """Client greeting sent on connection establishment. + + Wire layout: VERSION(1) + CLIENT_TYPE(1) + pad(2) + SERV_KEY(4). + """ + + payload: ClassVar = struct.Struct("!BBxxI") + msg_id: ClassVar = 0x0001 + + version: int + client_type: int + server_key: int + + +assert ClientGreeting.payload.size == 8 + + +@dataclass(frozen=True) +class Pong(ClientMessage): + """Ping response echoing the server's nonce. + + Wire layout: NONCE(4). + """ + + payload: ClassVar = struct.Struct("!I") + msg_id: ClassVar = 0x0002 + + nonce: int + + +assert Pong.payload.size == 4 + + +@dataclass(frozen=True) +class AddRecord(ClientMessage): + """Registers a new record or alias. + + Wire layout: RECID(4) + KIND(1) + RTLEN(1) + RNLEN(2) [+ RTYPE(RTLEN) + RNAME(RNLEN)]. + """ + + payload: ClassVar = struct.Struct("!IBBH") + msg_id: ClassVar = 0x0003 + + record_id: int + kind: RecordKind + record_type: str + record_name: str + + @property + def is_alias(self): + return self.kind == RecordKind.ALIAS + + def encode(self): + kind = _record_kind(self.kind) + record_type = self.record_type.encode() + record_name = self.record_name.encode() + _validate_record_lengths(kind, len(record_type), len(record_name)) + return ( + self.payload.pack(self.record_id, int(kind), len(record_type), len(record_name)) + record_type + record_name + ) + + @classmethod + def decode(cls, body): + _require_length(body, cls.payload.size, cls.__name__, exact=False) + record_id, kind, record_type_length, record_name_length = cls.payload.unpack(body[: cls.payload.size]) + kind = _record_kind(kind) + text = body[cls.payload.size :] + _validate_record_lengths(kind, record_type_length, record_name_length) + _require_length(text, record_type_length + record_name_length, "add record text") + record_type = text[:record_type_length].decode() + record_name = text[record_type_length:].decode() + return cls(record_id, kind, record_type, record_name) + + +assert AddRecord.payload.size == 8 + + +@dataclass(frozen=True) +class DelRecord(ClientMessage): + """Removes a previously registered record. + + Wire layout: RECID(4). + """ + + payload: ClassVar = struct.Struct("!I") + msg_id: ClassVar = 0x0004 + + record_id: int + + +assert DelRecord.payload.size == 4 + + +@dataclass(frozen=True) +class UploadDone(ClientMessage): + """Signals the end of the initial record upload. + + Carries a dummy zero payload that is accepted but ignored. + """ + + payload: ClassVar = struct.Struct("!I") + msg_id: ClassVar = 0x0005 + + def encode(self): + return self.payload.pack(0) + + @classmethod + def decode(cls, body): + _require_length(body, cls.payload.size, cls.__name__) + return cls() + + +assert UploadDone.payload.size == 4 + + +@dataclass(frozen=True) +class AddInfo(ClientMessage): + """Key-value info update for a record or the IOC. + + Wire layout: RECID(4) + KEYLEN(1) + pad(1) + VALEN(2) [+ KEY(KEYLEN) + VALUE(VALEN)]. + RECID=0 targets the IOC itself rather than a specific record. + """ + + payload: ClassVar = struct.Struct("!IBxH") + msg_id: ClassVar = 0x0006 + + record_id: int + key: str + value: str + + def encode(self): + key = self.key.encode() + value = self.value.encode() + if len(key) == 0: + raise ProtocolError("add info key must not be empty") + return self.payload.pack(self.record_id, len(key), len(value)) + key + value + + @classmethod + def decode(cls, body): + _require_length(body, cls.payload.size, cls.__name__, exact=False) + record_id, key_length, value_length = cls.payload.unpack(body[: cls.payload.size]) + text = body[cls.payload.size :] + if key_length == 0: + raise ProtocolError("add info key must not be empty") + _require_length(text, key_length + value_length, "add info text") + key = text[:key_length].decode() + value = text[key_length:].decode() + return cls(record_id, key, value) + + +assert AddInfo.payload.size == 8 + + +@dataclass(frozen=True) +class ServerGreeting(ServerMessage): + """Server greeting sent on connection acceptance. + + Wire layout: VERSION(1). + """ + + payload: ClassVar = struct.Struct("!B") + msg_id: ClassVar = 0x8001 + + version: int = 0 + + +assert ServerGreeting.payload.size == 1 + + +@dataclass(frozen=True) +class Ping(ServerMessage): + """Keepalive ping carrying a random nonce. + + Wire layout: NONCE(4). + """ + + payload: ClassVar = struct.Struct("!I") + msg_id: ClassVar = 0x8002 + + nonce: int + + +assert Ping.payload.size == 4 + + +def _record_kind(value): + try: + return RecordKind(value) + except ValueError: + raise ProtocolError(f"unknown record kind {value}") + + +def _validate_record_lengths(kind, record_type_length, record_name_length): + if record_name_length == 0: + raise ProtocolError("add record name must not be empty") + if kind == RecordKind.RECORD and record_type_length == 0: + raise ProtocolError("record type must not be empty for records") + if kind == RecordKind.ALIAS and record_type_length != 0: + raise ProtocolError("record type must be empty for aliases") diff --git a/server/recceiver/recast.py b/server/recceiver/recast.py index 1dbeac2e..0f55ae18 100644 --- a/server/recceiver/recast.py +++ b/server/recceiver/recast.py @@ -3,7 +3,6 @@ import collections import logging import random -import struct import sys import time @@ -12,29 +11,10 @@ from zope.interface import implementer from .interfaces import ITransaction +from .protocol import messages _log = logging.getLogger(__name__) -_M = 0x5243 - -_Head = struct.Struct(">HHI") -assert _Head.size == 8 - -_ping = struct.Struct(">I") -assert _ping.size == 4 - -_s_greet = struct.Struct(">B") -assert _s_greet.size == 1 - -_c_greet = struct.Struct(">BBxxI") -assert _c_greet.size == 8 - -_c_info = struct.Struct(">IBxH") -assert _c_info.size == 8 - -_c_rec = struct.Struct(">IBBH") -assert _c_rec.size == 8 - class CastReceiver(stateful.StatefulProtocol): timeout = 3.0 @@ -50,17 +30,12 @@ def __init__(self, active=True): self.rxfn = collections.defaultdict(self.dfact) - self.rxfn[1] = (self.recvClientGreeting, _c_greet.size) - self.rxfn[2] = (self.recvPong, _ping.size) - self.rxfn[3] = (self.recvAddRec, _c_rec.size) - self.rxfn[4] = (self.recvDelRec, _ping.size) - self.rxfn[5] = (self.recvDone, -1) - self.rxfn[6] = (self.recvInfo, _c_info.size) - - def writeMsg(self, msgid, body): - head = _Head.pack(_M, msgid, len(body)) - msg = b"".join((head, body)) - self.transport.write(msg) + self.rxfn[messages.ClientGreeting.msg_id] = (self.recvClientGreeting, messages.ClientGreeting.payload.size) + self.rxfn[messages.Pong.msg_id] = (self.recvPong, messages.Pong.payload.size) + self.rxfn[messages.AddRecord.msg_id] = (self.recvAddRec, messages.AddRecord.payload.size) + self.rxfn[messages.DelRecord.msg_id] = (self.recvDelRec, messages.DelRecord.payload.size) + self.rxfn[messages.UploadDone.msg_id] = (self.recvDone, messages.UploadDone.payload.size) + self.rxfn[messages.AddInfo.msg_id] = (self.recvInfo, messages.AddInfo.payload.size) def dataReceived(self, data): self.uploadSize += len(data) @@ -71,7 +46,7 @@ def connectionMade(self): # Full speed ahead self.phase = 1 # 1: send ping, 2: receive pong self.T = self.reactor.callLater(self.timeout, self.writePing) - self.writeMsg(0x8001, _s_greet.pack(self.version)) + self.transport.write(messages.ServerGreeting(self.version).frame()) self.uploadStart = time.time() else: # apply brakes @@ -99,43 +74,57 @@ def writePing(self): self.restartPingTimer() self.phase = 2 self.nonce = random.randint(0, 0xFFFFFFFF) - self.writeMsg(0x8002, _ping.pack(self.nonce)) + self.transport.write(messages.Ping(self.nonce).frame()) _log.debug("ping nonce: " + str(self.nonce)) def getInitialState(self): - return (self.recvHeader, 8) + return (self.recvHeader, messages.Header.payload.size) def recvHeader(self, data): self.restartPingTimer() - magic, msgid, blen = _Head.unpack(data) - if magic != _M: - _log.error("Protocol error! Bad magic {magic}".format(magic=magic)) + try: + header = messages.Header.decode(data) + except messages.ProtocolError as exc: + _log.error(f"Protocol error! {exc}") self.transport.loseConnection() return - self.msgid = msgid + if header.body_length == 0: + _log.debug(f"Ignoring empty message {header.msg_id:#06x}") + return self.getInitialState() + self.msgid = header.msg_id fn, minlen = self.rxfn[self.msgid] - if minlen >= 0 and blen < minlen: - return (self.ignoreBody, blen) + if minlen >= 0 and header.body_length < minlen: + return (self.ignoreBody, header.body_length) else: - return (fn, blen) + return (fn, header.body_length) # 0x0001 def recvClientGreeting(self, body): - cver, ctype, skey = _c_greet.unpack(body[: _c_greet.size]) - if ctype != 0: - _log.error("I don't understand you! {s}".format(s=ctype)) + try: + greeting = messages.ClientGreeting.decode(body) + except messages.ProtocolError as exc: + _log.error(f"Protocol error! {exc}") + self.transport.loseConnection() + return + if greeting.client_type != 0: + _log.error(f"unsupported client type {greeting.client_type}") self.transport.loseConnection() return - self.version = min(self.version, cver) - self.clientKey = skey + self.version = min(self.version, greeting.version) + self.clientKey = greeting.server_key self.sess = self.factory.addClient(self, self.transport.getPeer()) return self.getInitialState() # 0x0002 def recvPong(self, body): - (nonce,) = _ping.unpack(body[: _ping.size]) - if nonce != self.nonce: - _log.error("pong nonce does not match! {pong_nonce}!={nonce}".format(pong_nonce=nonce, nonce=self.nonce)) + try: + pong = messages.Pong.decode(body) + except messages.ProtocolError as exc: + _log.error(f"Protocol error! {exc}") + self.transport.loseConnection() + return + if pong.nonce != self.nonce: + _log.error(f"pong nonce does not match! {pong.nonce}!={self.nonce}") self.transport.loseConnection() else: _log.debug("pong nonce match") @@ -144,47 +133,48 @@ def recvPong(self, body): # 0x0006 def recvInfo(self, body): - record_id, klen, vlen = _c_info.unpack(body[: _c_info.size]) - text = body[_c_info.size :] - text = text.decode() - if klen == 0 or klen + vlen < len(text): + try: + info = messages.AddInfo.decode(body) + except messages.ProtocolError: _log.error("Ignoring info update") return self.getInitialState() - key = text[:klen] - val = text[klen : klen + vlen] - if record_id: - self.sess.recInfo(record_id, key, val) + if info.record_id: + self.sess.recInfo(info.record_id, info.key, info.value) else: - self.sess.iocInfo(key, val) + self.sess.iocInfo(info.key, info.value) return self.getInitialState() # 0x0003 def recvAddRec(self, body): - record_id, record_type, rtlen, rnlen = _c_rec.unpack(body[: _c_rec.size]) - text = body[_c_rec.size :] - text = text.decode() - if rnlen == 0 or rtlen + rnlen < len(text): + try: + record = messages.AddRecord.decode(body) + except messages.ProtocolError: _log.error("Ignoring record update") - - elif rtlen > 0 and record_type == 0: # new record - rectype = text[:rtlen] - recname = text[rtlen : rtlen + rnlen] - self.sess.addRecord(record_id, rectype, recname) - - elif record_type == 1: # record alias - recname = text[rtlen : rtlen + rnlen] - self.sess.addAlias(record_id, recname) + return self.getInitialState() + if record.is_alias: + self.sess.addAlias(record.record_id, record.record_name) + else: + self.sess.addRecord(record.record_id, record.record_type, record.record_name) return self.getInitialState() # 0x0004 def recvDelRec(self, body): - (record_id,) = _ping.unpack(body[: _ping.size]) - self.sess.delRecord(record_id) + try: + record = messages.DelRecord.decode(body) + except messages.ProtocolError: + _log.error("Ignoring delete record update") + return self.getInitialState() + self.sess.delRecord(record.record_id) return self.getInitialState() # 0x0005 def recvDone(self, body): + try: + messages.UploadDone.decode(body) + except messages.ProtocolError: + _log.error("Ignoring done update") + return self.getInitialState() self.factory.isDone(self, self.active) self.sess.done() if self.phase == 1: diff --git a/server/tests/test_protocol.py b/server/tests/test_protocol.py new file mode 100644 index 00000000..d15699a3 --- /dev/null +++ b/server/tests/test_protocol.py @@ -0,0 +1,131 @@ +import struct + +import pytest + +from recceiver.protocol import announce, messages + + +def test_encode_server_greeting(): + assert messages.ServerGreeting(0).frame() == struct.pack( + ">HHIB", messages.PROTO_ID, messages.ServerGreeting.msg_id, 1, 0 + ) + + +def test_encode_ping(): + assert messages.Ping(0x12345678).frame() == struct.pack( + ">HHII", messages.PROTO_ID, messages.Ping.msg_id, 4, 0x12345678 + ) + + +def test_decode_header_rejects_bad_protocol_id(): + with pytest.raises(messages.ProtocolError): + messages.Header.decode(struct.pack(">HHI", 0, messages.AddRecord.msg_id, 8)) + + +def test_decode_header_rejects_partial_header(): + with pytest.raises(messages.ProtocolError): + messages.Header.decode(b"short") + + +def test_decode_client_greeting(): + greeting = messages.ClientGreeting.decode(struct.pack(">BBxxI", 3, 0, 0xCAFEF00D)) + + assert greeting == messages.ClientGreeting(version=3, client_type=0, server_key=0xCAFEF00D) + + +def test_decode_pong_requires_exact_length(): + with pytest.raises(messages.ProtocolError): + messages.Pong.decode(struct.pack(">II", 1, 2)) + + +def test_decode_add_record(): + body = struct.pack(">IBBH", 11, 0, 2, 8) + b"aiIOC1:PV1" + + record = messages.AddRecord.decode(body) + + assert record == messages.AddRecord( + record_id=11, + kind=messages.RecordKind.RECORD, + record_type="ai", + record_name="IOC1:PV1", + ) + assert not record.is_alias + + +def test_decode_add_alias(): + body = struct.pack(">IBBH", 11, 1, 0, 14) + b"IOC1:PV1:ALIAS" + + record = messages.AddRecord.decode(body) + + assert record == messages.AddRecord( + record_id=11, + kind=messages.RecordKind.ALIAS, + record_type="", + record_name="IOC1:PV1:ALIAS", + ) + assert record.is_alias + + +@pytest.mark.parametrize( + "body", + [ + struct.pack(">IBBH", 11, 0, 0, 8) + b"IOC1:PV1", + struct.pack(">IBBH", 11, 1, 2, 8) + b"aiIOC1:PV1", + struct.pack(">IBBH", 11, 2, 2, 8) + b"aiIOC1:PV1", + struct.pack(">IBBH", 11, 0, 2, 8) + b"ai", + struct.pack(">IBBH", 11, 0, 2, 0) + b"ai", + ], +) +def test_decode_add_record_rejects_malformed_body(body): + with pytest.raises(messages.ProtocolError): + messages.AddRecord.decode(body) + + +def test_decode_add_info_for_ioc_property(): + body = struct.pack(">IBxH", 0, 7, 5) + b"iocNameIOC-1" + + info = messages.AddInfo.decode(body) + + assert info == messages.AddInfo(record_id=0, key="iocName", value="IOC-1") + + +@pytest.mark.parametrize( + "body", + [ + struct.pack(">IBxH", 0, 0, 5) + b"IOC-1", + struct.pack(">IBxH", 0, 7, 5) + b"iocName", + ], +) +def test_decode_add_info_rejects_malformed_body(body): + with pytest.raises(messages.ProtocolError): + messages.AddInfo.decode(body) + + +def test_decode_delete_record(): + assert messages.DelRecord.decode(struct.pack(">I", 11)) == messages.DelRecord(11) + + +def test_decode_upload_done_uses_client_dummy_payload(): + assert messages.UploadDone.decode(struct.pack(">I", 0)) == messages.UploadDone() + + +def test_encode_announce_matches_c_layout(): + assert announce.Announce(tcp_port=1234, key=0xCAFEF00D, host="255.255.255.255").encode() == struct.pack( + ">HH4sHHI", + messages.PROTO_ID, + 0, + b"\xff\xff\xff\xff", + 1234, + 0, + 0xCAFEF00D, + ) + + +def test_decode_announce(): + packet = struct.pack(">HH4sHHI", messages.PROTO_ID, 0, b"\x7f\x00\x00\x01", 1234, 0, 0xCAFEF00D) + + assert announce.Announce.decode(packet) == announce.Announce( + tcp_port=1234, + key=0xCAFEF00D, + host="127.0.0.1", + )