188 lines
7.4 KiB
Python
Executable File
188 lines
7.4 KiB
Python
Executable File
"""
|
|
Model Manager helpers: file verification, bundle version compatibility, active bundle.
|
|
Ported from sunnypilot.
|
|
"""
|
|
|
|
import asyncio
import hashlib
import json
import os

from openpilot.common.params import Params
from openpilot.common.swaglog import cloudlog
from cereal import custom
|
|
|
|
# Model selector versioning (see the sunnypilot model selector README).
# is_bundle_version_compatible accepts a bundle only when its
# minimumSelectorVersion lies in [REQUIRED_MIN_SELECTOR_VERSION, CURRENT_SELECTOR_VERSION].
REQUIRED_MIN_SELECTOR_VERSION: int = 14
CURRENT_SELECTOR_VERSION: int = 15
|
|
|
|
|
|
def _sha256_of_file(file_path: str, chunk_size: int = 1 << 20) -> str:
  """Return the lowercase hex SHA-256 digest of *file_path*, reading it in chunks."""
  digest = hashlib.sha256()
  with open(file_path, "rb") as file:
    for chunk in iter(lambda: file.read(chunk_size), b""):
      digest.update(chunk)
  return digest.hexdigest()


async def verify_file(file_path: str, expected_hash: str) -> bool:
  """Verify that the file at *file_path* has the expected SHA-256 digest.

  The comparison is case-insensitive and ignores surrounding whitespace in
  *expected_hash*. The hashing itself runs in the event loop's default executor,
  so reading the file does not block the event loop.

  Args:
    file_path: Path of the file to check.
    expected_hash: Expected SHA-256 digest as a hex string.

  Returns:
    True if *file_path* is an existing regular file whose digest matches;
    False if it is missing, is not a regular file (e.g. a directory), the
    expected hash is empty/None, or the digests differ.
  """
  # isfile (not exists): a directory would otherwise reach open() and raise.
  if not expected_hash or not os.path.isfile(file_path):
    return False
  loop = asyncio.get_running_loop()
  actual = await loop.run_in_executor(None, _sha256_of_file, file_path)
  return actual == expected_hash.strip().lower()
|
|
|
|
|
|
def is_bundle_version_compatible(bundle: dict) -> bool:
  """Return whether *bundle* can be used by the current model selector.

  A bundle is compatible when its ``minimumSelectorVersion`` lies in the closed
  range [REQUIRED_MIN_SELECTOR_VERSION, CURRENT_SELECTOR_VERSION]; a missing
  value counts as 0 and is therefore incompatible.

  Malformed input is reported as incompatible instead of raising. This covers a
  *bundle* without ``.get`` (e.g. None) and a version that cannot be compared
  with an int (e.g. None or a string).
  """
  try:
    version = bundle.get("minimumSelectorVersion", 0)
    return bool(REQUIRED_MIN_SELECTOR_VERSION <= version <= CURRENT_SELECTOR_VERSION)
  except (AttributeError, TypeError):
    return False
|
|
|
|
|
|
def _safe_str(v, default=""):
|
|
if v is None:
|
|
return default
|
|
if isinstance(v, bytes):
|
|
return v.decode("utf-8", errors="replace").strip() or default
|
|
return str(v).strip() if v else default
|
|
|
|
|
|
def _artifact_to_dict(artifact) -> dict:
|
|
"""Convert capnp Artifact to JSON-serializable dict (for ActiveBundle storage)."""
|
|
if artifact is None:
|
|
return {"file_name": "", "download_uri": {"url": "", "sha256": ""}}
|
|
try:
|
|
uri = getattr(artifact, "downloadUri", None)
|
|
url = getattr(uri, "uri", "") or "" if uri else ""
|
|
sha256 = getattr(uri, "sha256", "") or "" if uri else ""
|
|
return {
|
|
"file_name": _safe_str(getattr(artifact, "fileName", "")),
|
|
"download_uri": {"url": _safe_str(url), "sha256": _safe_str(sha256)},
|
|
}
|
|
except Exception:
|
|
return {"file_name": "", "download_uri": {"url": "", "sha256": ""}}
|
|
|
|
|
|
def _model_to_dict(model) -> dict:
  """Serialize a capnp Model into a plain, JSON-friendly dict.

  The model ``type`` is stored as an int: the enum's ``raw`` value when it has
  one, otherwise ``int(type)``. A missing or unconvertible type becomes 0. The
  ``artifact`` and ``metadata`` entries are serialized with ``_artifact_to_dict``.
  """
  try:
    raw_type = getattr(model, "type", None)
    type_val = 0 if raw_type is None else int(getattr(raw_type, "raw", raw_type))
  except (TypeError, ValueError):
    type_val = 0
  return {
    "type": type_val,
    "artifact": _artifact_to_dict(getattr(model, "artifact", None)),
    "metadata": _artifact_to_dict(getattr(model, "metadata", None)),
  }
|
|
|
|
|
|
def _enum_to_int(value, default: int = 0) -> int:
  """Coerce a capnp enum (via its ``raw`` attribute) or an int-like value to int.

  Returns *default* when *value* is None or cannot be converted.
  """
  if value is None:
    return default
  try:
    return int(getattr(value, "raw", value))
  except (TypeError, ValueError):
    return default


def _empty_model_dict() -> dict:
  """Return a fresh placeholder for a model that failed to serialize."""
  return {
    "type": 0,
    "artifact": {"file_name": "", "download_uri": {"url": "", "sha256": ""}},
    "metadata": {"file_name": "", "download_uri": {"url": "", "sha256": ""}},
  }


def bundle_to_dict(bundle) -> dict:
  """
  Convert a capnp ModelBundle into a dict that can be passed to json.dumps and
  later loaded with get_active_bundle. Use this when saving ModelManager_ActiveBundle.

  The display name is written under both ``display_name`` and ``displayName``.
  Enum-like fields (status, generation, runner, minimumSelectorVersion) are
  stored as ints, defaulting to 0 when missing or unconvertible. A model or
  override that fails to serialize is replaced by an empty placeholder, so
  there is still one entry per source element. If serialization fails outright,
  the error is logged and ``{}`` is returned.
  """
  try:
    models = []
    for model in getattr(bundle, "models", None) or []:
      try:
        models.append(_model_to_dict(model))
      except Exception:
        models.append(_empty_model_dict())

    overrides = []
    for override in getattr(bundle, "overrides", None) or []:
      try:
        overrides.append({
          "key": _safe_str(getattr(override, "key", "")),
          "value": _safe_str(getattr(override, "value", "")),
        })
      except Exception:
        overrides.append({"key": "", "value": ""})

    display_name = _safe_str(getattr(bundle, "displayName", "") or getattr(bundle, "display_name", ""))

    return {
      "index": int(getattr(bundle, "index", 0)),
      "short_name": _safe_str(getattr(bundle, "internalName", "")),
      "display_name": display_name,
      "displayName": display_name,
      "models": models,
      "status": _enum_to_int(getattr(bundle, "status", 0)),
      "generation": _enum_to_int(getattr(bundle, "generation", 0)),
      "environment": _safe_str(getattr(bundle, "environment", "")),
      "runner": _enum_to_int(getattr(bundle, "runner", None)),
      "is_20hz": bool(getattr(bundle, "is20hz", False)),
      "ref": _safe_str(getattr(bundle, "ref", "")),
      "minimumSelectorVersion": _enum_to_int(getattr(bundle, "minimumSelectorVersion", 0)),
      "overrides": overrides,
    }
  except Exception as e:
    cloudlog.exception("ModelManager: bundle_to_dict failed: %s", e)
    return {}
|
|
|
|
|
|
def _artifact_from_dict(target, data) -> None:
  """Copy a serialized artifact dict (as produced by ``_artifact_to_dict``) onto a capnp Artifact builder.

  Accepts either the snake_case ``file_name`` key or the camelCase ``fileName``
  key. A missing or None *data* / ``download_uri`` leaves the fields empty.
  """
  data = data or {}
  download_uri = data.get("download_uri") or {}
  target.fileName = data.get("file_name", "") or data.get("fileName", "")
  target.downloadUri.uri = download_uri.get("url", "")
  target.downloadUri.sha256 = download_uri.get("sha256", "")


def _model_from_dict(data):
  """Build a capnp ``ModelManagerSP.Model`` from one serialized model dict."""
  model = custom.ModelManagerSP.Model()
  model_type = data.get("type", 0)
  model.type = custom.ModelManagerSP.Model.Type(model_type) if isinstance(model_type, int) else model_type
  _artifact_from_dict(model.artifact, data.get("artifact"))
  _artifact_from_dict(model.metadata, data.get("metadata"))
  return model


def _restore_overrides(bundle, overrides) -> None:
  """Copy serialized ``{"key": ..., "value": ...}`` dicts onto ``bundle.overrides``.

  Non-dict entries are skipped. If restoring fails, the error is logged and the
  overrides list is left empty, so the rest of the bundle remains usable.
  """
  entries = [o for o in (overrides or []) if isinstance(o, dict)]
  try:
    target = bundle.init("overrides", len(entries))
    for i, entry in enumerate(entries):
      target[i].key = entry.get("key", "") or ""
      target[i].value = entry.get("value", "") or ""
  except Exception:
    cloudlog.exception("ModelManager: failed to restore active bundle overrides")
    bundle.overrides = []


def get_active_bundle(params: Params = None) -> custom.ModelManagerSP.ModelBundle:
  """Load the active model bundle stored in the ``ModelManager_ActiveBundle`` param.

  The param holds the JSON form produced from ``bundle_to_dict``. It is decoded,
  checked with ``is_bundle_version_compatible`` and rebuilt into a capnp
  ``ModelManagerSP.ModelBundle``, including its models and overrides.

  Args:
    params: Params instance to read from; a new ``Params()`` is created when None.

  Returns:
    The rebuilt ModelBundle, or None when the param is unset or empty, is not
    version-compatible, or cannot be decoded or rebuilt (failures are logged).
  """
  if params is None:
    params = Params()

  try:
    raw = params.get("ModelManager_ActiveBundle")
    if not raw:
      return None
    if isinstance(raw, dict):
      # NOTE(review): guards against a Params implementation that returns JSON
      # params already decoded; confirm against the Params version in use.
      active_bundle = raw
    else:
      active_bundle = json.loads(raw.decode("utf-8") if isinstance(raw, bytes) else raw)
    if not active_bundle or not is_bundle_version_compatible(active_bundle):
      return None

    # capnp expects the schema's camelCase field names; the stored dict may use
    # the snake_case keys written by bundle_to_dict or the camelCase ones.
    models = [_model_from_dict(m) for m in active_bundle.get("models", [])]

    bundle = custom.ModelManagerSP.ModelBundle()
    bundle.index = int(active_bundle.get("index", 0))
    bundle.internalName = active_bundle.get("short_name", "") or active_bundle.get("internalName", "")
    bundle.displayName = active_bundle.get("display_name", "") or active_bundle.get("displayName", "")
    bundle.models = models
    bundle.status = active_bundle.get("status", 0)
    bundle.generation = int(active_bundle.get("generation", 0))
    bundle.environment = active_bundle.get("environment", "")
    runner = active_bundle.get("runner", 2)
    bundle.runner = custom.ModelManagerSP.Runner(runner) if isinstance(runner, int) else custom.ModelManagerSP.Runner.stock
    bundle.is20hz = active_bundle.get("is_20hz", active_bundle.get("is20hz", False))
    bundle.ref = active_bundle.get("ref", "")
    bundle.minimumSelectorVersion = int(active_bundle.get("minimumSelectorVersion", 0))
    # bundle_to_dict persists overrides, so restore them rather than dropping them.
    _restore_overrides(bundle, active_bundle.get("overrides"))
    return bundle
  except Exception:
    cloudlog.exception("ModelManager: get_active_bundle failed")

  return None
|