Files
openpilot/selfdrive/models_manager/helpers.py
Comma Device 4b34ea50cd external radar
2026-03-30 22:16:15 +08:00

188 lines
7.4 KiB
Python
Executable File

"""
Model Manager helpers: file verification, bundle version compatibility, active bundle.
Ported from sunnypilot.
"""
import asyncio
import hashlib
import json
import os

from cereal import custom
from openpilot.common.params import Params
from openpilot.common.swaglog import cloudlog
# Model selector versioning (see the sunnypilot model selector README).
# is_bundle_version_compatible() accepts a bundle only when its
# "minimumSelectorVersion" lies in the inclusive range
# [REQUIRED_MIN_SELECTOR_VERSION, CURRENT_SELECTOR_VERSION].
REQUIRED_MIN_SELECTOR_VERSION: int = 14  # lower bound of the accepted range
CURRENT_SELECTOR_VERSION: int = 15  # upper bound of the accepted range
def _sha256_of_file(file_path: str, chunk_size: int = 1 << 20):
  """Hash ``file_path`` with SHA256 and return the lowercase hex digest.

  Reads in ``chunk_size`` pieces so memory stays bounded for large files.
  Returns None if the file disappears between the caller's existence check
  and the open() call; any other OSError propagates.
  """
  digest = hashlib.sha256()
  try:
    with open(file_path, "rb") as file:
      for chunk in iter(lambda: file.read(chunk_size), b""):
        digest.update(chunk)
  except FileNotFoundError:
    return None
  return digest.hexdigest()


async def verify_file(file_path: str, expected_hash: str) -> bool:
  """Verify that ``file_path`` is a regular file whose SHA256 equals ``expected_hash``.

  The comparison is case-insensitive and ignores surrounding whitespace in
  ``expected_hash``. Returns False when the path is missing or is not a
  regular file (e.g. a directory), or when ``expected_hash`` is empty/None.
  Other I/O errors (such as permission denied) propagate to the caller.

  The hashing itself runs in the event loop's default executor so reading a
  large file does not block other coroutines on the loop.
  """
  if not expected_hash or not os.path.isfile(file_path):
    return False
  loop = asyncio.get_running_loop()
  actual_hash = await loop.run_in_executor(None, _sha256_of_file, file_path)
  return actual_hash is not None and actual_hash == expected_hash.strip().lower()
def is_bundle_version_compatible(bundle: dict) -> bool:
  """
  Check whether a model bundle is compatible with the current selector version.

  A bundle is compatible when its "minimumSelectorVersion" lies in the inclusive
  range [REQUIRED_MIN_SELECTOR_VERSION, CURRENT_SELECTOR_VERSION]. A missing
  value counts as 0. Numeric strings such as "14" are accepted. A non-dict
  bundle, or a version that is neither a number nor a numeric string, yields
  False instead of raising.
  """
  if not isinstance(bundle, dict):
    return False
  version = bundle.get("minimumSelectorVersion", 0)
  if isinstance(version, str):
    try:
      version = int(version.strip())
    except ValueError:
      return False
  # Anything else non-numeric (None, lists, dicts, ...) cannot be ordered against ints.
  if not isinstance(version, (int, float)):
    return False
  return REQUIRED_MIN_SELECTOR_VERSION <= version <= CURRENT_SELECTOR_VERSION
def _safe_str(v, default=""):
  """Return ``v`` as a whitespace-stripped ``str``, or ``default`` when it holds no text.

  - ``bytes``/``bytearray`` are decoded as UTF-8, with invalid sequences replaced by U+FFFD.
  - Other falsy values (``None``, ``""``, ``0``, empty containers) map to ``default``.
  - Anything else is converted with ``str()``.
  Text that is empty after stripping maps to ``default`` for every input type,
  so ``"  "`` and ``b"  "`` behave the same.
  """
  if isinstance(v, (bytes, bytearray)):
    text = v.decode("utf-8", errors="replace")
  elif not v:
    return default
  else:
    text = str(v)
  return text.strip() or default
def _artifact_to_dict(artifact) -> dict:
  """Serialize a capnp Artifact into the plain dict stored in ModelManager_ActiveBundle.

  A None artifact, or any error while reading one, produces an all-empty entry,
  so callers always receive the complete key structure.
  """
  blank = {"file_name": "", "download_uri": {"url": "", "sha256": ""}}
  if artifact is None:
    return blank
  try:
    download = getattr(artifact, "downloadUri", None)
    url, sha256 = "", ""
    if download:
      url = getattr(download, "uri", "") or ""
      sha256 = getattr(download, "sha256", "") or ""
    file_name = _safe_str(getattr(artifact, "fileName", ""))
    return {
      "file_name": file_name,
      "download_uri": {"url": _safe_str(url), "sha256": _safe_str(sha256)},
    }
  except Exception:
    return blank
def _model_to_dict(model) -> dict:
  """Turn one capnp Model into a JSON-serializable dict.

  "type" is stored as an int. A value exposing ``.raw`` (enum-like) is converted
  through that attribute, and any other value through int(). None, or a conversion
  raising TypeError/ValueError, gives 0. "artifact" and "metadata" are serialized
  with _artifact_to_dict.
  """
  try:
    model_type = getattr(model, "type", None)
    if model_type is None:
      type_int = 0
    elif hasattr(model_type, "raw"):
      type_int = int(model_type.raw)
    else:
      type_int = int(model_type)
  except (TypeError, ValueError):
    type_int = 0
  serialized = {"type": type_int}
  serialized["artifact"] = _artifact_to_dict(getattr(model, "artifact", None))
  serialized["metadata"] = _artifact_to_dict(getattr(model, "metadata", None))
  return serialized
def _capnp_to_int(value, default: int = 0) -> int:
  """Coerce a plain number, or an enum-like value exposing ``.raw``, to int.

  Returns ``default`` for None or when int() raises TypeError/ValueError.
  """
  if value is None:
    return default
  try:
    return int(getattr(value, "raw", value))
  except (TypeError, ValueError):
    return default


def _best_effort_list(items, convert, fallback) -> list:
  """Map ``convert`` over a sized, indexable sequence; an element that raises becomes ``fallback()``.

  None yields an empty list.
  """
  if items is None:
    return []
  result = []
  # Indexing stays inside the guard so a failure reading one element only
  # replaces that element instead of aborting the whole conversion.
  for i in range(len(items)):
    try:
      result.append(convert(items[i]))
    except Exception:
      result.append(fallback())
  return result


def _empty_artifact_dict() -> dict:
  """Return a fresh placeholder artifact entry with every field present and empty."""
  return {"file_name": "", "download_uri": {"url": "", "sha256": ""}}


def _override_to_dict(override) -> dict:
  """Serialize one override entry as a {"key", "value"} pair of strings."""
  return {
    "key": _safe_str(getattr(override, "key", "")),
    "value": _safe_str(getattr(override, "value", "")),
  }


def bundle_to_dict(bundle) -> dict:
  """
  Convert a capnp ModelBundle to a dict that can be passed to json.dumps and later
  loaded with get_active_bundle. Use this when saving ModelManager_ActiveBundle.

  Conversion is best-effort:
  - A model or override entry that fails to convert is replaced by an empty placeholder.
  - Numeric/enum fields that cannot be coerced become 0.
  - If anything else raises, the error is logged and {} is returned.
  """
  try:
    models = _best_effort_list(
      getattr(bundle, "models", None),
      _model_to_dict,
      lambda: {"type": 0, "artifact": _empty_artifact_dict(), "metadata": _empty_artifact_dict()},
    )
    overrides = _best_effort_list(
      getattr(bundle, "overrides", None),
      _override_to_dict,
      lambda: {"key": "", "value": ""},
    )
    display_name = _safe_str(getattr(bundle, "displayName", "") or getattr(bundle, "display_name", ""))
    return {
      "index": _capnp_to_int(getattr(bundle, "index", 0)),
      "short_name": _safe_str(getattr(bundle, "internalName", "")),
      # Stored under both spellings; get_active_bundle reads either one.
      "display_name": display_name,
      "displayName": display_name,
      "models": models,
      "status": _capnp_to_int(getattr(bundle, "status", 0)),
      "generation": _capnp_to_int(getattr(bundle, "generation", 0)),
      "environment": _safe_str(getattr(bundle, "environment", "")),
      "runner": _capnp_to_int(getattr(bundle, "runner", None)),
      "is_20hz": bool(getattr(bundle, "is20hz", False)),
      "ref": _safe_str(getattr(bundle, "ref", "")),
      "minimumSelectorVersion": _capnp_to_int(getattr(bundle, "minimumSelectorVersion", 0)),
      "overrides": overrides,
    }
  except Exception as e:
    cloudlog.exception("ModelManager: bundle_to_dict failed: %s", e)
    return {}
def _fill_artifact_builder(target, data) -> None:
  """Copy a persisted artifact dict onto a capnp Artifact builder in place.

  Accepts the snake_case layout written by bundle_to_dict ("file_name",
  "download_uri" -> "url") as well as camelCase keys ("fileName",
  "downloadUri" -> "uri"). Missing, null or non-dict entries become "".
  """
  if not isinstance(data, dict):
    data = {}
  uri = data.get("download_uri") or data.get("downloadUri") or {}
  if not isinstance(uri, dict):
    uri = {}
  target.fileName = data.get("file_name") or data.get("fileName") or ""
  target.downloadUri.uri = uri.get("url") or uri.get("uri") or ""
  target.downloadUri.sha256 = uri.get("sha256") or ""


def _bundle_from_dict(stored: dict):
  """Build a new capnp ModelBundle message from a dict in the bundle_to_dict layout."""
  bundle = custom.ModelManagerSP.ModelBundle.new_message()
  bundle.index = int(stored.get("index", 0))
  bundle.internalName = stored.get("short_name") or stored.get("internalName") or ""
  bundle.displayName = stored.get("display_name") or stored.get("displayName") or ""
  # Enum fields are assigned the stored integers (or enumerant names) directly.
  bundle.status = stored.get("status", 0)
  bundle.generation = int(stored.get("generation", 0))
  bundle.environment = stored.get("environment") or ""
  runner = stored.get("runner", 2)
  bundle.runner = runner if isinstance(runner, int) else custom.ModelManagerSP.Runner.stock
  bundle.is20hz = bool(stored.get("is_20hz", stored.get("is20hz", False)))
  bundle.ref = stored.get("ref") or ""
  bundle.minimumSelectorVersion = int(stored.get("minimumSelectorVersion", 0))

  model_entries = [m for m in (stored.get("models") or []) if isinstance(m, dict)]
  models = bundle.init("models", len(model_entries))
  for i, entry in enumerate(model_entries):
    model = models[i]
    model.type = entry.get("type") or 0
    _fill_artifact_builder(model.artifact, entry.get("artifact"))
    _fill_artifact_builder(model.metadata, entry.get("metadata"))

  # Round-trip the key/value overrides that bundle_to_dict persists.
  # NOTE(review): assumes Override.key/value accept strings (bundle_to_dict
  # serializes both through _safe_str) -- confirm against the schema.
  override_entries = [o for o in (stored.get("overrides") or []) if isinstance(o, dict)]
  overrides = bundle.init("overrides", len(override_entries))
  for i, entry in enumerate(override_entries):
    overrides[i].key = str(entry.get("key") or "")
    overrides[i].value = str(entry.get("value") or "")
  return bundle


def get_active_bundle(params: "Params | None" = None) -> "custom.ModelManagerSP.ModelBundle | None":
  """Gets the active model bundle from the ModelManager_ActiveBundle param.

  Args:
    params: Params instance to read from. A new Params() is created when omitted.

  Returns:
    A freshly built ModelBundle message, or None in any of these cases:
    - the param is empty;
    - it does not hold a JSON object;
    - it fails is_bundle_version_compatible;
    - it cannot be converted. Conversion errors are logged via
      cloudlog.exception rather than silently dropped.
  """
  if params is None:
    params = Params()
  try:
    raw = params.get("ModelManager_ActiveBundle")
    if not raw:
      return None
    if isinstance(raw, (bytes, bytearray)):
      raw = raw.decode("utf-8")
    # NOTE(review): a non-text value is used as-is in case Params already returns
    # decoded JSON for this key -- confirm against the param's declared type.
    stored = json.loads(raw) if isinstance(raw, str) else raw
    if not isinstance(stored, dict) or not stored or not is_bundle_version_compatible(stored):
      return None
    return _bundle_from_dict(stored)
  except Exception:
    cloudlog.exception("ModelManager: failed to load ModelManager_ActiveBundle")
    return None