diff --git a/bmspy/__init__.py b/bmspy/__init__.py index 1e17e0b..aca74e2 100644 --- a/bmspy/__init__.py +++ b/bmspy/__init__.py @@ -8,7 +8,9 @@ import pprint from bmspy.utilities import debugger -def parse_args(): + + +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Query JBD BMS and report status", add_help=True, @@ -115,7 +117,7 @@ def parse_args(): return args -def main(): +def main() -> None: try: args = parse_args() @@ -163,7 +165,7 @@ def main(): client.handle_registration(args.socket, "bmspy", debug) atexit.register(client.handle_registration, args.socket, "bmspy", debug) - # {ups_name: JBDUPS} + # {ups_name: UPS} data = client.read_data(args.socket, "bmspy", ups=args.ups, debug=debug) if args.report_json: diff --git a/bmspy/classes.py b/bmspy/classes.py index 85f8a73..1427c2e 100644 --- a/bmspy/classes.py +++ b/bmspy/classes.py @@ -1,4 +1,7 @@ -from dataclasses import dataclass, fields as dataclass_fields +from abc import ABC, abstractmethod +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any @dataclass @@ -10,7 +13,7 @@ class BMSScalarField: value: str units: str | None = None - def get(self, key, default=None): + def get(self, key: str, default: Any = None) -> Any: return getattr(self, key, default) @@ -20,11 +23,11 @@ class BMSMultiField: help: str label: str - raw_values: dict - values: dict + raw_values: dict[int, float | bool] + values: dict[int, str] units: str | None = None - def get(self, key, default=None): + def get(self, key: str, default: Any = None) -> Any: return getattr(self, key, default) @@ -35,5 +38,44 @@ class BMSInfoField: help: str info: str - def get(self, key, default=None): + def get(self, key: str, default: Any = None) -> Any: return getattr(self, key, default) + + +type BMSField = BMSScalarField | BMSMultiField | BMSInfoField + + +def _field_from_dict(d: dict[str, Any]) -> BMSField: + """Reconstruct a BMSField from its JSON-serialized dict.""" + if "raw_values" in d: + return BMSMultiField(**d) + elif "info" in d: + return BMSInfoField(**d) + else: + return BMSScalarField(**d) + + +class UPS(ABC): + """Abstract base class for all UPS/BMS device data containers.""" + + @abstractmethod + def items(self) -> Iterator[tuple[str, BMSField]]: + ... + + def __bool__(self) -> bool: + return next(iter(self.items()), None) is not None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "UPS": + """Reconstruct a UPS snapshot from a JSON-decoded field dict.""" + return _UPSSnapshot({name: _field_from_dict(v) for name, v in data.items()}) + + +class _UPSSnapshot(UPS): + """Generic UPS snapshot reconstructed from serialized data by the client.""" + + def __init__(self, fields: dict[str, BMSField]) -> None: + self._fields = fields + + def items(self) -> Iterator[tuple[str, BMSField]]: + yield from self._fields.items() diff --git a/bmspy/client.py b/bmspy/client.py index eeebe7d..cf0f6b9 100644 --- a/bmspy/client.py +++ b/bmspy/client.py @@ -5,18 +5,21 @@ import sys import struct import json import socket +from typing import Any + from bmspy.utilities import debugger +from bmspy.classes import UPS is_registered = False -def handle_registration(socket_path, client_name, debug=0): +def handle_registration(socket_path: str, client_name: str, debug: int = 0) -> dict[str, Any]: global is_registered - data = dict() + data: dict[str, Any] = dict() if is_registered: - message = {"command": "DEREGISTER", "client": client_name} + message: dict[str, Any] = {"command": "DEREGISTER", "client": client_name} else: # fork server if it's not already running message = {"command": "REGISTER", "client": client_name} @@ -41,8 +44,8 @@ def handle_registration(socket_path, client_name, debug=0): return data -def socket_comms(socket_path, request_data, debug=0): - response_data = dict() +def socket_comms(socket_path: str, request_data: dict[str, Any], debug: int = 0) -> dict[str, Any]: + response_data: dict[str, Any] = dict() # Create a UDS socket sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) @@ -106,9 +109,9 @@ def socket_comms(socket_path, request_data, debug=0): return response_data -def read_data(socket_path, client_name, ups=None, debug=0): - """Return {ups_name: JBDUPS} for all UPSes, or just the named one.""" - request = {"command": "GET", "client": client_name} +def read_data(socket_path: str, client_name: str, ups: str | None = None, debug: int = 0) -> dict[str, UPS]: + """Return {ups_name: UPS} for all devices, or just the named one.""" + request: dict[str, Any] = {"command": "GET", "client": client_name} if ups is not None: request["ups"] = ups @@ -117,7 +120,7 @@ def read_data(socket_path, client_name, ups=None, debug=0): if data is None: raise RuntimeError("No data received from daemon") - return data + return {name: UPS.from_dict(fields) for name, fields in data.items()} if __name__ == "__main__": diff --git a/bmspy/influxdb.py b/bmspy/influxdb.py index 04e58e3..5280a43 100644 --- a/bmspy/influxdb.py +++ b/bmspy/influxdb.py @@ -5,26 +5,27 @@ import sys import time from influxdb_client_3 import InfluxDBClient3, Point from bmspy import client +from bmspy.classes import UPS from bmspy.utilities import debugger DAEMON_UPDATE_PERIOD = 30 -def influx_shutdown(influxclient): +def influx_shutdown(influxclient: InfluxDBClient3 | None) -> None: if influxclient is not None: influxclient.close() def influxdb_export( - bucket, - url=None, - org=None, - token=None, - socket_path=None, - ups=None, - daemonize=True, - debug=0, -): + bucket: str, + url: str | None = None, + org: str | None = None, + token: str | None = None, + socket_path: str | None = None, + ups: str | None = None, + daemonize: bool = True, + debug: int = 0, +) -> None: if not url: url = os.environ["INFLUXDB_V2_URL"] org = os.environ.get("INFLUXDB_V2_ORG") @@ -46,7 +47,7 @@ def influxdb_export( atexit.unregister(influx_shutdown) -def influxdb_write_snapshot(influxclient, bucket, ups_data, debug=0): +def influxdb_write_snapshot(influxclient: InfluxDBClient3, bucket: str, ups_data: dict[str, UPS], debug: int = 0) -> None: if debug > 1: debugger("influxdb: creating snapshot") points = influxdb_create_snapshot(ups_data, debug) @@ -58,8 +59,8 @@ def influxdb_write_snapshot(influxclient, bucket, ups_data, debug=0): debugger(e) -def influxdb_create_snapshot(ups_data, debug=0): - """Build InfluxDB points from {ups_name: JBDUPS}, tagging each point with the UPS name.""" +def influxdb_create_snapshot(ups_data: dict[str, UPS], debug: int = 0) -> list[Point]: + """Build InfluxDB points from {ups_name: UPS}, tagging each point with the UPS name.""" points = [] now = datetime.datetime.now(datetime.timezone.utc) diff --git a/bmspy/server.py b/bmspy/server.py index 86a8ee7..7c0b9f0 100755 --- a/bmspy/server.py +++ b/bmspy/server.py @@ -3,6 +3,7 @@ # Daemon: listens on a Unix socket and serves JBD BMS data to clients # import os +import socket import sys import stat import time @@ -11,6 +12,7 @@ import signal import json import struct from dataclasses import asdict as dataclass_asdict +from typing import Any, NoReturn from bmspy.utilities import debugger from bmspy.jbd_ups import collect_data, initialise_serial @@ -33,15 +35,15 @@ from bmspy.jbd_ups import collect_data, initialise_serial connected_clients = list() -def signalHandler(): +def signalHandler() -> NoReturn: raise SystemExit("terminating") -def socket_cleanup(socket_path, debug=0): +def socket_cleanup(socket_path: str, debug: int = 0) -> None: os.unlink(socket_path) -def read_request(connection, debug=0): +def read_request(connection: socket.socket, debug: int = 0) -> dict[str, Any]: # get length of expected json string request = bytes() try: @@ -72,7 +74,7 @@ def read_request(connection, debug=0): return request_data -def send_response(connection, response_data, client, debug=0): +def send_response(connection: socket.socket, response_data: Any, client: str, debug: int = 0) -> None: if debug > 2: debugger("socket: sending {!r}".format(response_data)) try: @@ -94,7 +96,7 @@ def send_response(connection, response_data, client, debug=0): raise OSError("unable to encode response: {}".format(e)) -def parse_device(device_str): +def parse_device(device_str: str) -> tuple[str, str]: """Parse 'name:/dev/path' or '/dev/path' into (name, path).""" if not device_str.startswith("/") and ":" in device_str: name, path = device_str.split(":", 1) @@ -105,7 +107,6 @@ def parse_device(device_str): def main(): import argparse - import socket import pwd import grp diff --git a/bmspy/ups.py b/bmspy/ups.py index 83936e5..1b71d99 100644 --- a/bmspy/ups.py +++ b/bmspy/ups.py @@ -2,6 +2,7 @@ from collections import deque import argparse import atexit, datetime, os, re, sys, time import smtplib, ssl, socket +from typing import Any from bmspy import client DAEMON_UPDATE_PERIOD = 30 @@ -11,7 +12,7 @@ critical_sent = False warning_sent = False alert_sent = False -def handle_shutdown(action = 'cancel', delay = 0, debug = 0): +def handle_shutdown(action: str = 'cancel', delay: int = 0, debug: int = 0) -> None: global scheduled_shutdown if action == 'shutdown': @@ -25,7 +26,7 @@ def handle_shutdown(action = 'cancel', delay = 0, debug = 0): return -def handle_email(text, level, recipient = "root", mailserver = "localhost", port = 25, mailuser = None, mailpass = None, debug = 0): +def handle_email(text: str, level: str | None, recipient: str = "root", mailserver: str = "localhost", port: int = 25, mailuser: str | None = None, mailpass: str | None = None, debug: int = 0) -> None: isSSL = False hostname = socket.gethostname() @@ -53,7 +54,7 @@ def handle_email(text, level, recipient = "root", mailserver = "localhost", port return -def main(): +def main() -> None: global alert_sent, warning_sent, critical_sent parser = argparse.ArgumentParser( diff --git a/bmspy/utilities.py b/bmspy/utilities.py index fde696b..4146931 100755 --- a/bmspy/utilities.py +++ b/bmspy/utilities.py @@ -4,9 +4,10 @@ # import datetime import pprint +from typing import Any -def debugger(data, pretty: bool = False): +def debugger(data: Any, pretty: bool = False) -> None: if pretty: pp = pprint.PrettyPrinter(indent=4) pp.pprint(