Add typing, as well as a new UPS type that should be inherited by all new models

This commit is contained in:
2026-05-02 18:06:52 +02:00
parent 973dc9bc96
commit ff578f1889
7 changed files with 92 additions and 41 deletions
+5 -3
View File
@@ -8,7 +8,9 @@ import pprint
from bmspy.utilities import debugger
def parse_args():
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Query JBD BMS and report status",
add_help=True,
@@ -115,7 +117,7 @@ def parse_args():
return args
def main():
def main() -> None:
try:
args = parse_args()
@@ -163,7 +165,7 @@ def main():
client.handle_registration(args.socket, "bmspy", debug)
atexit.register(client.handle_registration, args.socket, "bmspy", debug)
# {ups_name: JBDUPS}
# {ups_name: UPS}
data = client.read_data(args.socket, "bmspy", ups=args.ups, debug=debug)
if args.report_json:
+48 -6
View File
@@ -1,4 +1,7 @@
from dataclasses import dataclass, fields as dataclass_fields
from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any
@dataclass
@@ -10,7 +13,7 @@ class BMSScalarField:
value: str
units: str | None = None
def get(self, key, default=None):
def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)
@@ -20,11 +23,11 @@ class BMSMultiField:
help: str
label: str
raw_values: dict
values: dict
raw_values: dict[int, float | bool]
values: dict[int, str]
units: str | None = None
def get(self, key, default=None):
def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)
@@ -35,5 +38,44 @@ class BMSInfoField:
help: str
info: str
def get(self, key, default=None):
def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)
type BMSField = BMSScalarField | BMSMultiField | BMSInfoField
def _field_from_dict(d: dict[str, Any]) -> BMSField:
"""Reconstruct a BMSField from its JSON-serialized dict."""
if "raw_values" in d:
return BMSMultiField(**d)
elif "info" in d:
return BMSInfoField(**d)
else:
return BMSScalarField(**d)
class UPS(ABC):
"""Abstract base class for all UPS/BMS device data containers."""
@abstractmethod
def items(self) -> Iterator[tuple[str, BMSField]]:
...
def __bool__(self) -> bool:
return next(iter(self.items()), None) is not None
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "UPS":
"""Reconstruct a UPS snapshot from a JSON-decoded field dict."""
return _UPSSnapshot({name: _field_from_dict(v) for name, v in data.items()})
class _UPSSnapshot(UPS):
"""Generic UPS snapshot reconstructed from serialized data by the client."""
def __init__(self, fields: dict[str, BMSField]) -> None:
self._fields = fields
def items(self) -> Iterator[tuple[str, BMSField]]:
yield from self._fields.items()
+12 -9
View File
@@ -5,18 +5,21 @@ import sys
import struct
import json
import socket
from typing import Any
from bmspy.utilities import debugger
from bmspy.classes import UPS
is_registered = False
def handle_registration(socket_path, client_name, debug=0):
def handle_registration(socket_path: str, client_name: str, debug: int = 0) -> dict[str, Any]:
global is_registered
data = dict()
data: dict[str, Any] = dict()
if is_registered:
message = {"command": "DEREGISTER", "client": client_name}
message: dict[str, Any] = {"command": "DEREGISTER", "client": client_name}
else:
# fork server if it's not already running
message = {"command": "REGISTER", "client": client_name}
@@ -41,8 +44,8 @@ def handle_registration(socket_path, client_name, debug=0):
return data
def socket_comms(socket_path, request_data, debug=0):
response_data = dict()
def socket_comms(socket_path: str, request_data: dict[str, Any], debug: int = 0) -> dict[str, Any]:
response_data: dict[str, Any] = dict()
# Create a UDS socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
@@ -106,9 +109,9 @@ def socket_comms(socket_path, request_data, debug=0):
return response_data
def read_data(socket_path, client_name, ups=None, debug=0):
"""Return {ups_name: JBDUPS} for all UPSes, or just the named one."""
request = {"command": "GET", "client": client_name}
def read_data(socket_path: str, client_name: str, ups: str | None = None, debug: int = 0) -> dict[str, UPS]:
"""Return {ups_name: UPS} for all devices, or just the named one."""
request: dict[str, Any] = {"command": "GET", "client": client_name}
if ups is not None:
request["ups"] = ups
@@ -117,7 +120,7 @@ def read_data(socket_path, client_name, ups=None, debug=0):
if data is None:
raise RuntimeError("No data received from daemon")
return data
return {name: UPS.from_dict(fields) for name, fields in data.items()}
if __name__ == "__main__":
+14 -13
View File
@@ -5,26 +5,27 @@ import sys
import time
from influxdb_client_3 import InfluxDBClient3, Point
from bmspy import client
from bmspy.classes import UPS
from bmspy.utilities import debugger
DAEMON_UPDATE_PERIOD = 30
def influx_shutdown(influxclient):
def influx_shutdown(influxclient: InfluxDBClient3 | None) -> None:
if influxclient is not None:
influxclient.close()
def influxdb_export(
bucket,
url=None,
org=None,
token=None,
socket_path=None,
ups=None,
daemonize=True,
debug=0,
):
bucket: str,
url: str | None = None,
org: str | None = None,
token: str | None = None,
socket_path: str | None = None,
ups: str | None = None,
daemonize: bool = True,
debug: int = 0,
) -> None:
if not url:
url = os.environ["INFLUXDB_V2_URL"]
org = os.environ.get("INFLUXDB_V2_ORG")
@@ -46,7 +47,7 @@ def influxdb_export(
atexit.unregister(influx_shutdown)
def influxdb_write_snapshot(influxclient, bucket, ups_data, debug=0):
def influxdb_write_snapshot(influxclient: InfluxDBClient3, bucket: str, ups_data: dict[str, UPS], debug: int = 0) -> None:
if debug > 1:
debugger("influxdb: creating snapshot")
points = influxdb_create_snapshot(ups_data, debug)
@@ -58,8 +59,8 @@ def influxdb_write_snapshot(influxclient, bucket, ups_data, debug=0):
debugger(e)
def influxdb_create_snapshot(ups_data, debug=0):
"""Build InfluxDB points from {ups_name: JBDUPS}, tagging each point with the UPS name."""
def influxdb_create_snapshot(ups_data: dict[str, UPS], debug: int = 0) -> list[Point]:
"""Build InfluxDB points from {ups_name: UPS}, tagging each point with the UPS name."""
points = []
now = datetime.datetime.now(datetime.timezone.utc)
+7 -6
View File
@@ -3,6 +3,7 @@
# Daemon: listens on a Unix socket and serves JBD BMS data to clients
#
import os
import socket
import sys
import stat
import time
@@ -11,6 +12,7 @@ import signal
import json
import struct
from dataclasses import asdict as dataclass_asdict
from typing import Any, NoReturn
from bmspy.utilities import debugger
from bmspy.jbd_ups import collect_data, initialise_serial
@@ -33,15 +35,15 @@ from bmspy.jbd_ups import collect_data, initialise_serial
connected_clients = list()
def signalHandler():
def signalHandler() -> NoReturn:
raise SystemExit("terminating")
def socket_cleanup(socket_path, debug=0):
def socket_cleanup(socket_path: str, debug: int = 0) -> None:
os.unlink(socket_path)
def read_request(connection, debug=0):
def read_request(connection: socket.socket, debug: int = 0) -> dict[str, Any]:
# get length of expected json string
request = bytes()
try:
@@ -72,7 +74,7 @@ def read_request(connection, debug=0):
return request_data
def send_response(connection, response_data, client, debug=0):
def send_response(connection: socket.socket, response_data: Any, client: str, debug: int = 0) -> None:
if debug > 2:
debugger("socket: sending {!r}".format(response_data))
try:
@@ -94,7 +96,7 @@ def send_response(connection, response_data, client, debug=0):
raise OSError("unable to encode response: {}".format(e))
def parse_device(device_str):
def parse_device(device_str: str) -> tuple[str, str]:
"""Parse 'name:/dev/path' or '/dev/path' into (name, path)."""
if not device_str.startswith("/") and ":" in device_str:
name, path = device_str.split(":", 1)
@@ -105,7 +107,6 @@ def parse_device(device_str):
def main():
import argparse
import socket
import pwd
import grp
+4 -3
View File
@@ -2,6 +2,7 @@ from collections import deque
import argparse
import atexit, datetime, os, re, sys, time
import smtplib, ssl, socket
from typing import Any
from bmspy import client
DAEMON_UPDATE_PERIOD = 30
@@ -11,7 +12,7 @@ critical_sent = False
warning_sent = False
alert_sent = False
def handle_shutdown(action = 'cancel', delay = 0, debug = 0):
def handle_shutdown(action: str = 'cancel', delay: int = 0, debug: int = 0) -> None:
global scheduled_shutdown
if action == 'shutdown':
@@ -25,7 +26,7 @@ def handle_shutdown(action = 'cancel', delay = 0, debug = 0):
return
def handle_email(text, level, recipient = "root", mailserver = "localhost", port = 25, mailuser = None, mailpass = None, debug = 0):
def handle_email(text: str, level: str | None, recipient: str = "root", mailserver: str = "localhost", port: int = 25, mailuser: str | None = None, mailpass: str | None = None, debug: int = 0) -> None:
isSSL = False
hostname = socket.gethostname()
@@ -53,7 +54,7 @@ def handle_email(text, level, recipient = "root", mailserver = "localhost", port
return
def main():
def main() -> None:
global alert_sent, warning_sent, critical_sent
parser = argparse.ArgumentParser(
+2 -1
View File
@@ -4,9 +4,10 @@
#
import datetime
import pprint
from typing import Any
def debugger(data, pretty: bool = False):
def debugger(data: Any, pretty: bool = False) -> None:
if pretty:
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(