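"""Tests for ``bmspy.server``: device-spec parsing, ``DeviceState``,
length-prefixed JSON request/response framing, the signal and socket-cleanup
helpers, and the ``main()`` server loop."""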
import contextlib
import json
import socket as socket_module
import struct
import sys
import threading
from unittest.mock import MagicMock, patch

import pytest
import serial

from bmspy.classes import BMSScalarField, BMSInfoField, UPS
from bmspy.jbd_bms import JBDBMS
from bmspy.server import DeviceState, parse_device, read_request, send_response


# ---------------------------------------------------------------------------
# parse_device
# ---------------------------------------------------------------------------


class TestParseDevice:
    """``parse_device`` turns a ``[name:]path`` device spec into ``(name, path)``."""

    def test_plain_path(self):
        """Without a ``name:`` prefix the final path segment is the name."""
        expected = ("ttyUSB0", "/dev/ttyUSB0")
        assert parse_device("/dev/ttyUSB0") == expected

    def test_named_path(self):
        """An explicit ``name:`` prefix is used as the name."""
        expected = ("myups", "/dev/ttyUSB1")
        assert parse_device("myups:/dev/ttyUSB1") == expected

    def test_nested_path(self):
        """Only the last segment of a deep path becomes the name."""
        spec = "/dev/serial/by-id/usb-FTDI"
        assert parse_device(spec) == ("usb-FTDI", spec)

    def test_name_without_slash(self):
        """A bare token with no ``/`` and no ``:`` is both name and path."""
        assert parse_device("ttyUSB0") == ("ttyUSB0", "ttyUSB0")

    def test_name_colon_path_no_leading_slash(self):
        """A spec that starts with a name (not ``/``) is split at the colon."""
        expected = ("office", "/dev/ttyUSB2")
        assert parse_device("office:/dev/ttyUSB2") == expected


# ---------------------------------------------------------------------------
# DeviceState
# ---------------------------------------------------------------------------


class TestDeviceState:
    """Construction defaults and mutability of the ``DeviceState`` record."""

    def test_defaults(self):
        """A fresh state keeps the given port and has no data yet."""
        port = serial.Serial()
        state = DeviceState(ser=port)
        assert state.ser is port
        assert state.data is None
        assert state.timestamp == 0.0

    def test_fields_are_mutable(self):
        """``data`` and ``timestamp`` can be reassigned after construction."""
        state = DeviceState(ser=serial.Serial())
        state.data = UPS.from_dict({})
        state.timestamp = 123.4
        assert isinstance(state.data, UPS)
        assert state.timestamp == 123.4


# ---------------------------------------------------------------------------
|
|
# read_request / send_response round-trip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _send_framed(sock: socket_module.socket, data: dict) -> None:
|
|
payload = json.dumps(data).encode()
|
|
sock.sendall(struct.pack("!I", len(payload)) + payload)
|
|
|
|
|
|
def _recv_framed(sock: socket_module.socket) -> dict:
|
|
length = struct.unpack("!I", sock.recv(4))[0]
|
|
return json.loads(sock.recv(length))
|
|
|
|
|
|
class TestReadRequest:
    """A framed JSON request written by a peer is decoded by ``read_request``."""

    def test_round_trip(self):
        """The decoded request equals the dict that was sent."""
        request = {"command": "GET", "client": "test"}
        srv, cli = socket_module.socketpair()
        with srv, cli:
            _send_framed(cli, request)
            assert read_request(srv) == request

    def test_with_ups_filter(self):
        """An optional ``ups`` key survives decoding."""
        srv, cli = socket_module.socketpair()
        with srv, cli:
            _send_framed(cli, {"command": "GET", "client": "test", "ups": "myups"})
            decoded = read_request(srv)
        assert decoded["ups"] == "myups"


class TestSendResponse:
    """``send_response`` writes a dict as one length-prefixed JSON frame."""

    @staticmethod
    def _round_trip(payload):
        """Send *payload* over a socket pair and return what the peer decodes."""
        srv, cli = socket_module.socketpair()
        with srv, cli:
            send_response(srv, payload, "test")
            return _recv_framed(cli)

    def test_plain_dict(self):
        received = self._round_trip({"status": "REGISTERED", "client": "test"})
        assert received == {"status": "REGISTERED", "client": "test"}

    def test_ups_object_serialized_via_items(self):
        """UPS objects must be serialized using items(), not dataclass_asdict."""
        bms = JBDBMS()
        bms.bms_voltage_total_volts = BMSScalarField(
            help="Total Voltage", raw_value=52.0, value="52.00", units="V"
        )
        bms.bms_manufacture_date = BMSInfoField(
            help="Date of Manufacture", info="2023-01-15"
        )

        result = self._round_trip({"myups": bms})

        assert "myups" in result
        payload = result["myups"]
        # Populated fields are present with their raw values ...
        assert "bms_voltage_total_volts" in payload
        assert payload["bms_voltage_total_volts"]["raw_value"] == 52.0
        # ... while None fields and the client field are left out.
        assert "bms_current_amps" not in payload
        assert "client" not in payload

    def test_empty_ups_serializes_to_empty_dict(self):
        assert self._round_trip({"myups": JBDBMS()})["myups"] == {}

    def test_plain_dict_passthrough(self):
        assert self._round_trip({"key": "value", "number": 42}) == {
            "key": "value",
            "number": 42,
        }

    def test_closed_socket_raises_os_error(self):
        srv, cli = socket_module.socketpair()
        for end in (srv, cli):
            end.close()
        with pytest.raises(OSError):
            send_response(srv, {"status": "OK"}, "test")


# ---------------------------------------------------------------------------
# signalHandler
# ---------------------------------------------------------------------------


class TestSignalHandler:
    """``signalHandler`` aborts by raising ``SystemExit``."""

    @staticmethod
    def _invoke():
        """Import ``bmspy.server.signalHandler`` and call it with no arguments."""
        from bmspy.server import signalHandler

        signalHandler()

    def test_raises_system_exit(self):
        with pytest.raises(SystemExit):
            self._invoke()

    def test_exit_message_contains_terminating(self):
        with pytest.raises(SystemExit, match="terminating"):
            self._invoke()


# ---------------------------------------------------------------------------
# socket_cleanup
# ---------------------------------------------------------------------------


class TestSocketCleanup:
    """``socket_cleanup`` deletes a socket file and fails if it is missing."""

    def test_removes_socket_file(self, tmp_path):
        from bmspy.server import socket_cleanup

        stale = tmp_path / "test.sock"
        stale.touch()
        assert stale.exists()
        socket_cleanup(str(stale))
        assert not stale.exists()

    def test_raises_when_file_missing(self, tmp_path):
        from bmspy.server import socket_cleanup

        missing = str(tmp_path / "nonexistent.sock")
        with pytest.raises(FileNotFoundError):
            socket_cleanup(missing)


# ---------------------------------------------------------------------------
# read_request — error paths
# ---------------------------------------------------------------------------


class TestReadRequestErrors:
    """Failure modes and debug logging of ``read_request``."""

    def test_invalid_json_raises_exception(self):
        """A correctly framed but non-JSON body is reported as unreadable."""
        body = b"not valid json !!!"
        srv, cli = socket_module.socketpair()
        with srv, cli:
            cli.sendall(struct.pack("!I", len(body)) + body)
            with pytest.raises(Exception, match="unable to read incoming request"):
                read_request(srv)

    def test_truncated_length_bytes_raises(self):
        """EOF after only half of the 4-byte length header raises."""
        srv, cli = socket_module.socketpair()
        with srv:
            with cli:
                cli.sendall(b"\x00\x00")  # only 2 of 4 length bytes
            # cli is closed now, so srv sees EOF in the middle of the header.
            with pytest.raises(Exception):
                read_request(srv)

    def test_recv_raises_os_error(self):
        """When recv raises on first read, read_request should raise OSError."""
        conn = MagicMock()
        conn.recv.side_effect = OSError("connection reset")
        with pytest.raises(OSError, match="unable to read request length"):
            read_request(conn)

    def test_recv_body_raises_os_error(self):
        """When recv raises on second read (body), should raise OSError."""
        conn = MagicMock()
        # First recv yields a valid 10-byte length header, the second fails.
        conn.recv.side_effect = [struct.pack("!I", 10), OSError("body read error")]
        with pytest.raises(OSError, match="unable to read socket"):
            read_request(conn)

    @staticmethod
    def _read_with_debug(debug):
        """Send a valid GET request and read it back at the given debug level."""
        srv, cli = socket_module.socketpair()
        with srv, cli:
            _send_framed(cli, {"command": "GET", "client": "test"})
            read_request(srv, debug=debug)

    def test_debug_5_logs_length(self, capsys):
        # debug > 4 logs the incoming length
        self._read_with_debug(5)
        assert "incoming length" in capsys.readouterr().out

    def test_debug_4_logs_request_bytes(self, capsys):
        self._read_with_debug(4)
        assert "incoming request" in capsys.readouterr().out

    def test_debug_3_logs_received(self, capsys):
        self._read_with_debug(3)
        assert "received" in capsys.readouterr().out


class TestServerMain:
    """Drive server ``main()`` end-to-end over a real UNIX-domain socket.

    ``main()`` runs in a daemon thread with mocks for:

    * serial I/O and data collection;
    * sleeping;
    * filesystem and privilege calls.

    The patch targets are reached through ``bmspy.server`` (``os``, ``time``,
    ``signal``) and ``socket.socket.listen``. These presumably resolve to the
    shared stdlib module objects, so the mocks are visible process-wide for
    as long as the server thread holds them.

    Every test therefore shuts its server down with ``_stop_server``. It does
    not leave the thread, and its patches, alive for the rest of the session.
    For the same reason the helpers delay with ``threading.Event.wait``
    rather than the mocked ``time.sleep``.
    """

    # Request ``command`` that, once the stop event is set, makes the patched
    # ``read_request`` raise KeyboardInterrupt inside the server.
    _STOP_COMMAND = "__bmspy_test_stop__"

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _fake_bms():
        """Return a ``JBDBMS`` reading with only the total voltage populated."""
        bms = JBDBMS()
        bms.bms_voltage_total_volts = BMSScalarField(
            help="Voltage", raw_value=52.0, value="52.00", units="V"
        )
        return bms

    @staticmethod
    def _running_as_root_patches(*, setuid_error=None, setgid_error=None):
        """Return patchers that make the server believe it runs as uid/gid 0.

        ``pwd``/``grp`` are replaced by mocks whose lookups resolve to id
        65534, so the host's user databases are never consulted.
        *setuid_error*/*setgid_error*, when given, are raised by
        ``os.setuid``/``os.setgid``.
        """
        pwd_mod = MagicMock()
        pwd_mod.getpwnam.return_value = [None, None, 65534]
        pwd_mod.getpwuid.return_value = ["nobody"]
        grp_mod = MagicMock()
        grp_mod.getgrnam.return_value = [None, None, 65534]
        grp_mod.getgrgid.return_value = ["dialout"]
        return (
            patch("bmspy.server.os.getuid", return_value=0),
            patch("bmspy.server.os.getgid", return_value=0),
            patch("bmspy.server.os.setuid", side_effect=setuid_error),
            patch("bmspy.server.os.setgid", side_effect=setgid_error),
            patch("bmspy.server.os.umask", return_value=0o022),
            patch.dict(sys.modules, {"pwd": pwd_mod, "grp": grp_mod}),
        )

    def _make_server_thread(self, sock_path: str, ready_event: threading.Event,
                            stop_event: threading.Event, *,
                            devices=("/dev/ttyUSB0",), debug: int = 0,
                            fake_bms=None, extra_patches=()):
        """Build (but do not start) a daemon thread that runs server ``main()``.

        Args:
            sock_path: UNIX socket path passed as ``--socket``.
            ready_event: set as soon as the server socket is listening.
            stop_event: once set, a request whose command is
                ``_STOP_COMMAND`` makes the server's ``read_request`` raise
                ``KeyboardInterrupt``; see ``_stop_server``.
            devices: values passed as repeated ``--device`` arguments.
            debug: number of ``-v`` flags.
            fake_bms: value returned by the mocked ``collect_data``;
                defaults to ``_fake_bms()``.
            extra_patches: additional unstarted patchers. They are entered
                after the defaults, so they override any default with the
                same target.
        """
        from bmspy.server import main as server_main

        if fake_bms is None:
            fake_bms = self._fake_bms()
        stop_command = self._STOP_COMMAND

        argv = ["bmspy-server", "--socket", sock_path]
        for device in devices:
            argv += ["--device", device]
        argv += ["-v"] * debug

        original_listen = socket_module.socket.listen

        def _listen_and_signal(sock, *args, **kwargs):
            result = original_listen(sock, *args, **kwargs)
            ready_event.set()
            return result

        def _stoppable_read_request(*args, **kwargs):
            request = read_request(*args, **kwargs)
            if (
                stop_event.is_set()
                and isinstance(request, dict)
                and request.get("command") == stop_command
            ):
                # main() is expected to handle this like a Ctrl-C (see
                # test_keyboard_interrupt_closes_connection) and exit.
                raise KeyboardInterrupt("test shutdown")
            return request

        patchers = [
            patch("sys.argv", argv),
            patch("bmspy.server.signal.signal"),
            patch("bmspy.server.initialise_serial", return_value=MagicMock()),
            patch("bmspy.server.collect_data", return_value=fake_bms),
            patch("bmspy.server.time.sleep"),
            patch("bmspy.server.os.path.isdir", return_value=True),
            patch("bmspy.server.os.path.exists", return_value=False),
            patch("bmspy.server.read_request", side_effect=_stoppable_read_request),
            patch.object(socket_module.socket, "listen", _listen_and_signal),
            *extra_patches,
        ]

        def _run():
            # ExitStack unwinds in reverse order, so an override from
            # extra_patches is removed before the default it shadows, and all
            # patches are undone even if main() fails unexpectedly.
            with contextlib.ExitStack() as stack:
                for patcher in patchers:
                    stack.enter_context(patcher)
                try:
                    server_main()
                except (SystemExit, KeyboardInterrupt, OSError):
                    pass

        return threading.Thread(target=_run, daemon=True)

    def _start_server(self, sock_path: str, **kwargs):
        """Start ``main()`` in a thread and wait (about 5 s max) until it listens.

        Stops waiting early if the thread exits before it starts listening.
        Keyword arguments go to ``_make_server_thread``.

        Returns:
            ``(thread, stop_event)`` for use with ``_stop_server``.
        """
        ready = threading.Event()
        stop = threading.Event()
        thread = self._make_server_thread(sock_path, ready, stop, **kwargs)
        thread.start()
        for _ in range(100):
            if ready.wait(timeout=0.05) or not thread.is_alive():
                break
        return thread, stop

    def _stop_server(self, sock_path: str, stop_event: threading.Event,
                     thread: threading.Thread, timeout: float = 2.0) -> None:
        """Shut down a server thread from ``_make_server_thread`` and join it.

        The stop sentinel is sent on a fresh connection. The listen backlog
        is FIFO, so (assuming the server serves one connection at a time)
        every request the test sent earlier is handled first.

        This is best effort. If the server already exited there is nothing
        to connect to, and the join returns immediately.
        """
        stop_event.set()
        with contextlib.suppress(OSError):
            with socket_module.socket(
                socket_module.AF_UNIX, socket_module.SOCK_STREAM
            ) as waker:
                waker.settimeout(timeout)
                waker.connect(sock_path)
                _send_framed(
                    waker, {"command": self._STOP_COMMAND, "client": "test-shutdown"}
                )
        thread.join(timeout=timeout)

    @staticmethod
    def _connect(sock_path: str, timeout: float = 5.0, attempts: int = 20):
        """Return a client socket connected to *sock_path*, retrying briefly.

        Each attempt uses a fresh socket with *timeout* applied, so a
        stalled server cannot hang the test.

        Raises:
            OSError: the last connection error if all *attempts* (>= 1) fail.
        """
        # Never set: wait() is a plain delay that does not go through the
        # time.sleep the server patches out.
        pause = threading.Event()
        last_error = None
        for _ in range(attempts):
            sock = socket_module.socket(socket_module.AF_UNIX, socket_module.SOCK_STREAM)
            sock.settimeout(timeout)
            try:
                sock.connect(sock_path)
            except OSError as exc:
                sock.close()
                last_error = exc
                pause.wait(0.1)
            else:
                return sock
        raise last_error

    @staticmethod
    def _recv_exactly(sock, count: int) -> bytes:
        """Read up to *count* bytes, looping over short reads; stop at EOF."""
        buf = bytearray()
        while len(buf) < count:
            chunk = sock.recv(count - len(buf))
            if not chunk:
                break
            buf += chunk
        return bytes(buf)

    def _send_command(self, sock_path: str, cmd: dict, timeout: float = 5.0) -> dict:
        """Send one framed *cmd* to the server and return the decoded reply.

        Returns ``{}`` when the server closes the connection without replying.

        Raises:
            OSError: if the server cannot be reached or times out. Also
                raised (as ``ConnectionError``) if the server closes
                part-way through a reply.
            ValueError: if the reply body is not valid JSON.
        """
        with self._connect(sock_path, timeout) as sock:
            _send_framed(sock, cmd)
            header = self._recv_exactly(sock, 4)
            if not header:
                return {}
            if len(header) < 4:
                raise ConnectionError("truncated response length header")
            (length,) = struct.unpack("!I", header)
            body = self._recv_exactly(sock, length)
            if len(body) < length:
                raise ConnectionError("truncated response body")
        return json.loads(body)

    # ------------------------------------------------------------------
    # tests
    # ------------------------------------------------------------------

    def test_register_command(self, tmp_path):
        """REGISTER is acknowledged with status REGISTERED."""
        sock_path = str(tmp_path / "server.sock")
        thread, stop = self._start_server(sock_path)
        try:
            response = self._send_command(
                sock_path, {"command": "REGISTER", "client": "test"}
            )
            assert response.get("status") == "REGISTERED"
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_get_command(self, tmp_path):
        """GET returns data keyed by the device name (``ttyUSB0`` here)."""
        sock_path = str(tmp_path / "server_get.sock")
        thread, stop = self._start_server(sock_path)
        try:
            response = self._send_command(sock_path, {"command": "GET", "client": "test"})
            assert "ttyUSB0" in response
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_deregister_command(self, tmp_path):
        """REGISTER followed by DEREGISTER reports both state changes."""
        sock_path = str(tmp_path / "server_dereg.sock")
        thread, stop = self._start_server(sock_path)
        try:
            registered = self._send_command(
                sock_path, {"command": "REGISTER", "client": "test"}
            )
            assert registered.get("status") == "REGISTERED"
            deregistered = self._send_command(
                sock_path, {"command": "DEREGISTER", "client": "test"}
            )
            assert deregistered.get("status") == "DEREGISTERED"
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_get_with_ups_filter(self, tmp_path):
        """GET with a ``ups`` filter still answers with a dict."""
        sock_path = str(tmp_path / "server_getups.sock")
        thread, stop = self._start_server(sock_path)
        try:
            response = self._send_command(
                sock_path, {"command": "GET", "client": "test", "ups": "ttyUSB0"}
            )
            # Either the ttyUSB0 data or an empty mapping.
            assert isinstance(response, dict)
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_debug_mode_verbose(self, tmp_path, capsys):
        """Smoke test: run ``main()`` with one ``-v`` and send it a GET."""
        sock_path = str(tmp_path / "server_dbg.sock")
        thread, stop = self._start_server(sock_path, debug=1)
        try:
            with contextlib.suppress(OSError, ValueError):
                self._send_command(sock_path, {"command": "GET", "client": "test"})
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_duplicate_device_name_skipped(self, tmp_path, capsys):
        """Test that duplicate UPS names are skipped."""
        sock_path = str(tmp_path / "server_dup.sock")
        thread, stop = self._start_server(
            sock_path,
            devices=("myups:/dev/ttyUSB0", "myups:/dev/ttyUSB1"),  # duplicate name
        )
        try:
            # The server should have started with one "myups" device.
            response = self._send_command(sock_path, {"command": "GET", "client": "test"})
            assert "myups" in response
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_socket_dir_created_if_missing(self, tmp_path):
        """The socket directory is created when ``isdir`` reports it missing."""
        sock_path = str(tmp_path / "server_mkdir.sock")
        # Record calls only; the real directory (tmp_path) already exists.
        makedirs = MagicMock()
        thread, stop = self._start_server(
            sock_path,
            fake_bms=JBDBMS(),
            extra_patches=(
                patch("bmspy.server.os.path.isdir", return_value=False),
                patch("bmspy.server.os.makedirs", makedirs),
            ),
        )
        try:
            with contextlib.suppress(OSError, ValueError):
                self._send_command(sock_path, {"command": "REGISTER", "client": "test"})
        finally:
            self._stop_server(sock_path, stop, thread)
        assert makedirs.called

    def test_debug_3_logs_startup(self, tmp_path, capsys):
        """Test debug=3 logs 'starting up' message."""
        sock_path = str(tmp_path / "server_dbg3.sock")
        thread, stop = self._start_server(sock_path, debug=3)
        try:
            with contextlib.suppress(OSError, ValueError):
                self._send_command(sock_path, {"command": "REGISTER", "client": "test"})
        finally:
            self._stop_server(sock_path, stop, thread)
        out = capsys.readouterr().out.lower()
        # debug > 2 triggers "starting up" and "waiting for connection".
        assert "starting up" in out or "waiting" in out

    def test_socket_already_exists_raises(self, tmp_path):
        """``main()`` refuses to start when the socket path already exists."""
        from bmspy.server import main as server_main

        sock_path = str(tmp_path / "existing.sock")
        argv = ["bmspy-server", "--socket", sock_path, "--device", "/dev/ttyUSB0"]
        with patch("sys.argv", argv), \
                patch("bmspy.server.signal.signal"), \
                patch("bmspy.server.initialise_serial", return_value=MagicMock()), \
                patch("bmspy.server.os.path.isdir", return_value=True), \
                patch("bmspy.server.os.path.exists", return_value=True):
            with pytest.raises(OSError, match="already exists"):
                server_main()

    def test_deregister_nonexistent_client_no_error(self, tmp_path):
        """Test DEREGISTER for a client that was never registered (KeyError suppressed)."""
        sock_path = str(tmp_path / "server_noerr.sock")
        thread, stop = self._start_server(sock_path)
        try:
            response = self._send_command(
                sock_path, {"command": "DEREGISTER", "client": "ghost"}
            )
            assert response.get("status") == "DEREGISTERED"
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_keyboard_interrupt_closes_connection(self, tmp_path):
        """Test KeyboardInterrupt handler closes connection when connection is active."""
        sock_path = str(tmp_path / "server_kbi.sock")
        calls = []

        def _interrupt_on_second_call(*args, **kwargs):
            calls.append(args)
            if len(calls) >= 2:
                raise KeyboardInterrupt("test interrupt")
            return read_request(*args, **kwargs)

        thread, stop = self._start_server(
            sock_path,
            fake_bms=JBDBMS(),
            extra_patches=(
                patch("bmspy.server.read_request", side_effect=_interrupt_on_second_call),
            ),
        )
        try:
            # The first request is served normally ...
            with contextlib.suppress(OSError, ValueError):
                self._send_command(sock_path, {"command": "REGISTER", "client": "test"})
            # ... the second raises KeyboardInterrupt while its connection is open.
            with contextlib.suppress(OSError, ValueError):
                self._send_command(sock_path, {"command": "REGISTER", "client": "test2"})
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_socket_read_error_logs_and_continues(self, tmp_path, capsys):
        """Test that read_request errors are caught and logged."""
        sock_path = str(tmp_path / "server_err.sock")
        thread, stop = self._start_server(sock_path)
        try:
            # Send only 2 bytes (incomplete length header), then close.
            with self._connect(sock_path) as sock:
                sock.sendall(b"\x00\x00")
        finally:
            # The stop sentinel queues behind the bad connection, so the
            # server handles the truncated request before shutting down.
            self._stop_server(sock_path, stop, thread)

    def test_root_user_socket_dir_created(self, tmp_path, capsys):
        """Test root user path when socket dir doesn't exist (chown/chmod triggered)."""
        sock_path = str(tmp_path / "server_root_mkdir.sock")
        thread, stop = self._start_server(
            sock_path,
            debug=2,
            fake_bms=JBDBMS(),
            extra_patches=(
                patch("bmspy.server.os.path.isdir", return_value=False),
                patch("bmspy.server.os.makedirs"),
                patch("bmspy.server.os.chown"),
                patch("bmspy.server.os.chmod"),
                *self._running_as_root_patches(),
            ),
        )
        try:
            with contextlib.suppress(OSError, ValueError):
                self._send_command(sock_path, {"command": "REGISTER", "client": "test"})
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_root_user_setgid_error(self, tmp_path, capsys):
        """Test root user path when setgid raises OSError."""
        sock_path = str(tmp_path / "server_setgid_err.sock")
        thread, stop = self._start_server(
            sock_path,
            fake_bms=JBDBMS(),
            extra_patches=self._running_as_root_patches(
                setuid_error=OSError("cannot set uid"),
                setgid_error=OSError("cannot set gid"),
            ),
        )
        try:
            with contextlib.suppress(OSError, ValueError):
                self._send_command(sock_path, {"command": "REGISTER", "client": "test"})
        finally:
            self._stop_server(sock_path, stop, thread)
        out = capsys.readouterr().out.lower()
        # Should log errors about setgid/setuid.
        assert "gid" in out or "uid" in out

    def test_root_user_uid_gid_handling(self, tmp_path, capsys):
        """Test the uid==0 path for privilege dropping."""
        sock_path = str(tmp_path / "server_root.sock")
        thread, stop = self._start_server(
            sock_path,
            fake_bms=JBDBMS(),
            extra_patches=self._running_as_root_patches(),
        )
        try:
            with contextlib.suppress(OSError, ValueError):
                self._send_command(sock_path, {"command": "REGISTER", "client": "test"})
        finally:
            self._stop_server(sock_path, stop, thread)

    def test_invalid_command_breaks_loop(self, tmp_path, capsys):
        """Send an unknown command. The server is expected to log an error and
        leave its loop; nothing is asserted here."""
        sock_path = str(tmp_path / "server_invalid.sock")
        thread, stop = self._start_server(sock_path)
        try:
            # The server may stop serving after an invalid command, so a
            # failed connect/send is tolerated.
            with contextlib.suppress(OSError):
                with self._connect(sock_path) as sock:
                    _send_framed(sock, {"command": "INVALID", "client": "test"})
        finally:
            self._stop_server(sock_path, stop, thread)


class TestSendResponseDebug:
    """Higher ``debug`` levels make ``send_response`` print more detail."""

    @staticmethod
    def _send_and_capture(capsys, debug):
        """Send a small response at *debug* and return the captured stdout."""
        srv, cli = socket_module.socketpair()
        with srv, cli:
            send_response(srv, {"status": "OK"}, "test", debug=debug)
            _recv_framed(cli)
        return capsys.readouterr().out

    def test_debug_3_logs_sending(self, capsys):
        assert "sending" in self._send_and_capture(capsys, 3)

    def test_debug_5_logs_length(self, capsys):
        assert "length" in self._send_and_capture(capsys, 5)

    def test_debug_4_logs_response(self, capsys):
        assert "outgoing response" in self._send_and_capture(capsys, 4)