Add typing, as well as a new UPS type that should be inherited by all new models

This commit is contained in:
2026-05-02 18:06:52 +02:00
parent 973dc9bc96
commit ff578f1889
7 changed files with 92 additions and 41 deletions
+5 -3
View File
@@ -8,7 +8,9 @@ import pprint
from bmspy.utilities import debugger from bmspy.utilities import debugger
def parse_args():
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Query JBD BMS and report status", description="Query JBD BMS and report status",
add_help=True, add_help=True,
@@ -115,7 +117,7 @@ def parse_args():
return args return args
def main(): def main() -> None:
try: try:
args = parse_args() args = parse_args()
@@ -163,7 +165,7 @@ def main():
client.handle_registration(args.socket, "bmspy", debug) client.handle_registration(args.socket, "bmspy", debug)
atexit.register(client.handle_registration, args.socket, "bmspy", debug) atexit.register(client.handle_registration, args.socket, "bmspy", debug)
# {ups_name: JBDUPS} # {ups_name: UPS}
data = client.read_data(args.socket, "bmspy", ups=args.ups, debug=debug) data = client.read_data(args.socket, "bmspy", ups=args.ups, debug=debug)
if args.report_json: if args.report_json:
+48 -6
View File
@@ -1,4 +1,7 @@
from dataclasses import dataclass, fields as dataclass_fields from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any
@dataclass @dataclass
@@ -10,7 +13,7 @@ class BMSScalarField:
value: str value: str
units: str | None = None units: str | None = None
def get(self, key, default=None): def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default) return getattr(self, key, default)
@@ -20,11 +23,11 @@ class BMSMultiField:
help: str help: str
label: str label: str
raw_values: dict raw_values: dict[int, float | bool]
values: dict values: dict[int, str]
units: str | None = None units: str | None = None
def get(self, key, default=None): def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default) return getattr(self, key, default)
@@ -35,5 +38,44 @@ class BMSInfoField:
help: str help: str
info: str info: str
def get(self, key, default=None): def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default) return getattr(self, key, default)
type BMSField = BMSScalarField | BMSMultiField | BMSInfoField
def _field_from_dict(d: dict[str, Any]) -> BMSField:
"""Reconstruct a BMSField from its JSON-serialized dict."""
if "raw_values" in d:
return BMSMultiField(**d)
elif "info" in d:
return BMSInfoField(**d)
else:
return BMSScalarField(**d)
class UPS(ABC):
"""Abstract base class for all UPS/BMS device data containers."""
@abstractmethod
def items(self) -> Iterator[tuple[str, BMSField]]:
...
def __bool__(self) -> bool:
return next(iter(self.items()), None) is not None
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "UPS":
"""Reconstruct a UPS snapshot from a JSON-decoded field dict."""
return _UPSSnapshot({name: _field_from_dict(v) for name, v in data.items()})
class _UPSSnapshot(UPS):
"""Generic UPS snapshot reconstructed from serialized data by the client."""
def __init__(self, fields: dict[str, BMSField]) -> None:
self._fields = fields
def items(self) -> Iterator[tuple[str, BMSField]]:
yield from self._fields.items()
+12 -9
View File
@@ -5,18 +5,21 @@ import sys
import struct import struct
import json import json
import socket import socket
from typing import Any
from bmspy.utilities import debugger from bmspy.utilities import debugger
from bmspy.classes import UPS
is_registered = False is_registered = False
def handle_registration(socket_path, client_name, debug=0): def handle_registration(socket_path: str, client_name: str, debug: int = 0) -> dict[str, Any]:
global is_registered global is_registered
data = dict() data: dict[str, Any] = dict()
if is_registered: if is_registered:
message = {"command": "DEREGISTER", "client": client_name} message: dict[str, Any] = {"command": "DEREGISTER", "client": client_name}
else: else:
# fork server if it's not already running # fork server if it's not already running
message = {"command": "REGISTER", "client": client_name} message = {"command": "REGISTER", "client": client_name}
@@ -41,8 +44,8 @@ def handle_registration(socket_path, client_name, debug=0):
return data return data
def socket_comms(socket_path, request_data, debug=0): def socket_comms(socket_path: str, request_data: dict[str, Any], debug: int = 0) -> dict[str, Any]:
response_data = dict() response_data: dict[str, Any] = dict()
# Create a UDS socket # Create a UDS socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
@@ -106,9 +109,9 @@ def socket_comms(socket_path, request_data, debug=0):
return response_data return response_data
def read_data(socket_path, client_name, ups=None, debug=0): def read_data(socket_path: str, client_name: str, ups: str | None = None, debug: int = 0) -> dict[str, UPS]:
"""Return {ups_name: JBDUPS} for all UPSes, or just the named one.""" """Return {ups_name: UPS} for all devices, or just the named one."""
request = {"command": "GET", "client": client_name} request: dict[str, Any] = {"command": "GET", "client": client_name}
if ups is not None: if ups is not None:
request["ups"] = ups request["ups"] = ups
@@ -117,7 +120,7 @@ def read_data(socket_path, client_name, ups=None, debug=0):
if data is None: if data is None:
raise RuntimeError("No data received from daemon") raise RuntimeError("No data received from daemon")
return data return {name: UPS.from_dict(fields) for name, fields in data.items()}
if __name__ == "__main__": if __name__ == "__main__":
+14 -13
View File
@@ -5,26 +5,27 @@ import sys
import time import time
from influxdb_client_3 import InfluxDBClient3, Point from influxdb_client_3 import InfluxDBClient3, Point
from bmspy import client from bmspy import client
from bmspy.classes import UPS
from bmspy.utilities import debugger from bmspy.utilities import debugger
DAEMON_UPDATE_PERIOD = 30 DAEMON_UPDATE_PERIOD = 30
def influx_shutdown(influxclient): def influx_shutdown(influxclient: InfluxDBClient3 | None) -> None:
if influxclient is not None: if influxclient is not None:
influxclient.close() influxclient.close()
def influxdb_export( def influxdb_export(
bucket, bucket: str,
url=None, url: str | None = None,
org=None, org: str | None = None,
token=None, token: str | None = None,
socket_path=None, socket_path: str | None = None,
ups=None, ups: str | None = None,
daemonize=True, daemonize: bool = True,
debug=0, debug: int = 0,
): ) -> None:
if not url: if not url:
url = os.environ["INFLUXDB_V2_URL"] url = os.environ["INFLUXDB_V2_URL"]
org = os.environ.get("INFLUXDB_V2_ORG") org = os.environ.get("INFLUXDB_V2_ORG")
@@ -46,7 +47,7 @@ def influxdb_export(
atexit.unregister(influx_shutdown) atexit.unregister(influx_shutdown)
def influxdb_write_snapshot(influxclient, bucket, ups_data, debug=0): def influxdb_write_snapshot(influxclient: InfluxDBClient3, bucket: str, ups_data: dict[str, UPS], debug: int = 0) -> None:
if debug > 1: if debug > 1:
debugger("influxdb: creating snapshot") debugger("influxdb: creating snapshot")
points = influxdb_create_snapshot(ups_data, debug) points = influxdb_create_snapshot(ups_data, debug)
@@ -58,8 +59,8 @@ def influxdb_write_snapshot(influxclient, bucket, ups_data, debug=0):
debugger(e) debugger(e)
def influxdb_create_snapshot(ups_data, debug=0): def influxdb_create_snapshot(ups_data: dict[str, UPS], debug: int = 0) -> list[Point]:
"""Build InfluxDB points from {ups_name: JBDUPS}, tagging each point with the UPS name.""" """Build InfluxDB points from {ups_name: UPS}, tagging each point with the UPS name."""
points = [] points = []
now = datetime.datetime.now(datetime.timezone.utc) now = datetime.datetime.now(datetime.timezone.utc)
+7 -6
View File
@@ -3,6 +3,7 @@
# Daemon: listens on a Unix socket and serves JBD BMS data to clients # Daemon: listens on a Unix socket and serves JBD BMS data to clients
# #
import os import os
import socket
import sys import sys
import stat import stat
import time import time
@@ -11,6 +12,7 @@ import signal
import json import json
import struct import struct
from dataclasses import asdict as dataclass_asdict from dataclasses import asdict as dataclass_asdict
from typing import Any, NoReturn
from bmspy.utilities import debugger from bmspy.utilities import debugger
from bmspy.jbd_ups import collect_data, initialise_serial from bmspy.jbd_ups import collect_data, initialise_serial
@@ -33,15 +35,15 @@ from bmspy.jbd_ups import collect_data, initialise_serial
connected_clients = list() connected_clients = list()
def signalHandler(): def signalHandler() -> NoReturn:
raise SystemExit("terminating") raise SystemExit("terminating")
def socket_cleanup(socket_path, debug=0): def socket_cleanup(socket_path: str, debug: int = 0) -> None:
os.unlink(socket_path) os.unlink(socket_path)
def read_request(connection, debug=0): def read_request(connection: socket.socket, debug: int = 0) -> dict[str, Any]:
# get length of expected json string # get length of expected json string
request = bytes() request = bytes()
try: try:
@@ -72,7 +74,7 @@ def read_request(connection, debug=0):
return request_data return request_data
def send_response(connection, response_data, client, debug=0): def send_response(connection: socket.socket, response_data: Any, client: str, debug: int = 0) -> None:
if debug > 2: if debug > 2:
debugger("socket: sending {!r}".format(response_data)) debugger("socket: sending {!r}".format(response_data))
try: try:
@@ -94,7 +96,7 @@ def send_response(connection, response_data, client, debug=0):
raise OSError("unable to encode response: {}".format(e)) raise OSError("unable to encode response: {}".format(e))
def parse_device(device_str): def parse_device(device_str: str) -> tuple[str, str]:
"""Parse 'name:/dev/path' or '/dev/path' into (name, path).""" """Parse 'name:/dev/path' or '/dev/path' into (name, path)."""
if not device_str.startswith("/") and ":" in device_str: if not device_str.startswith("/") and ":" in device_str:
name, path = device_str.split(":", 1) name, path = device_str.split(":", 1)
@@ -105,7 +107,6 @@ def parse_device(device_str):
def main(): def main():
import argparse import argparse
import socket
import pwd import pwd
import grp import grp
+4 -3
View File
@@ -2,6 +2,7 @@ from collections import deque
import argparse import argparse
import atexit, datetime, os, re, sys, time import atexit, datetime, os, re, sys, time
import smtplib, ssl, socket import smtplib, ssl, socket
from typing import Any
from bmspy import client from bmspy import client
DAEMON_UPDATE_PERIOD = 30 DAEMON_UPDATE_PERIOD = 30
@@ -11,7 +12,7 @@ critical_sent = False
warning_sent = False warning_sent = False
alert_sent = False alert_sent = False
def handle_shutdown(action = 'cancel', delay = 0, debug = 0): def handle_shutdown(action: str = 'cancel', delay: int = 0, debug: int = 0) -> None:
global scheduled_shutdown global scheduled_shutdown
if action == 'shutdown': if action == 'shutdown':
@@ -25,7 +26,7 @@ def handle_shutdown(action = 'cancel', delay = 0, debug = 0):
return return
def handle_email(text, level, recipient = "root", mailserver = "localhost", port = 25, mailuser = None, mailpass = None, debug = 0): def handle_email(text: str, level: str | None, recipient: str = "root", mailserver: str = "localhost", port: int = 25, mailuser: str | None = None, mailpass: str | None = None, debug: int = 0) -> None:
isSSL = False isSSL = False
hostname = socket.gethostname() hostname = socket.gethostname()
@@ -53,7 +54,7 @@ def handle_email(text, level, recipient = "root", mailserver = "localhost", port
return return
def main(): def main() -> None:
global alert_sent, warning_sent, critical_sent global alert_sent, warning_sent, critical_sent
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
+2 -1
View File
@@ -4,9 +4,10 @@
# #
import datetime import datetime
import pprint import pprint
from typing import Any
def debugger(data, pretty: bool = False): def debugger(data: Any, pretty: bool = False) -> None:
if pretty: if pretty:
pp = pprint.PrettyPrinter(indent=4) pp = pprint.PrettyPrinter(indent=4)
pp.pprint( pp.pprint(