Files
openpilot/selfdrive/models_manager/manager.py
Comma Device 4b34ea50cd external radar
2026-03-30 22:16:15 +08:00

222 lines
8.9 KiB
Python
Executable File

"""
Model Manager: downloads model bundles and reports status via modelManagerSP.
Ported from sunnypilot.
"""
import asyncio
import json
import os
import time
import os
import aiohttp
from openpilot.common.params import Params
from openpilot.common.realtime import Ratekeeper
from openpilot.common.swaglog import cloudlog
from openpilot.system.hardware.hw import Paths
from cereal import messaging, custom
from openpilot.selfdrive.models_manager.fetcher import ModelFetcher
from openpilot.selfdrive.models_manager.helpers import verify_file, get_active_bundle, bundle_to_dict
class ModelManagerSP:
  """Manages model bundle downloads and publishes their status on the modelManagerSP socket.

  ``main_thread`` polls Params once per second (Ratekeeper(1)) for download and clear-cache
  requests, services them, and publishes the current state after every iteration.
  """

  # Streaming read size for downloads (128 kB, decimal).
  _CHUNK_SIZE = 128 * 1000
  # aiohttp's default ClientTimeout caps the *whole* request at 5 minutes, which aborts large
  # model files on slow links (bandwidth is also shared by concurrent downloads). Bound only
  # connection setup and per-read stalls instead; the values are tunable.
  _CONNECT_TIMEOUT_S = 30
  _READ_TIMEOUT_S = 60
  # ETA reported before there is enough information to estimate one.
  _DEFAULT_ETA_S = 60

  def __init__(self):
    self.params = Params()
    self.model_fetcher = ModelFetcher(self.params)
    self.pm = messaging.PubMaster(["modelManagerSP"])
    self.available_models: list = []
    self.selected_bundle = None
    self.active_bundle = get_active_bundle(self.params)
    self._chunk_size = self._CHUNK_SIZE
    # fileName -> time.monotonic() at which that file's download started (used for ETA).
    self._download_start_times: dict = {}

  @staticmethod
  def _is_param_set(value) -> bool:
    """Return True if a raw ``Params.get`` result represents a present, non-empty entry.

    ``Params.get`` may return bytes/str or an already-decoded value. A decoded ``0`` is a
    valid bundle index, so only ``None`` and empty bytes/str count as "not set".
    """
    if value is None:
      return False
    if isinstance(value, (bytes, str)):
      return len(value) > 0
    return True

  @staticmethod
  def _remove_file_quietly(path: str) -> None:
    """Delete ``path`` if present; log (rather than raise) other OS errors.

    Used on error paths so a cleanup failure never masks the original exception.
    """
    try:
      os.remove(path)
    except FileNotFoundError:
      pass
    except OSError as e:
      cloudlog.error(f"ModelManager: failed to remove {path}: {e}")

  def _calculate_eta(self, filename: str, progress: float) -> int:
    """Estimate the seconds remaining for ``filename``'s download.

    Args:
      filename: key into ``_download_start_times``.
      progress: completion percentage (0-100).

    Returns:
      ``_DEFAULT_ETA_S`` when no start time is known or no progress has been made yet,
      otherwise a linear extrapolation of the remaining time, never less than 1.
    """
    start = self._download_start_times.get(filename)
    if start is None or progress <= 0:
      return self._DEFAULT_ETA_S
    elapsed = time.monotonic() - start
    if elapsed <= 0:
      return self._DEFAULT_ETA_S
    total_estimated = (elapsed / progress) * 100
    return max(1, int(total_estimated - elapsed))

  async def _download_file(self, url: str, path: str, model) -> None:
    """Stream ``url`` into ``path``, updating ``model.downloadProgress`` after each chunk.

    Progress is only reported when the server sends a content-length.

    Raises:
      aiohttp errors from ``raise_for_status`` or connect/read timeouts, and
      ``Exception("Download cancelled")`` when the download request param is cleared mid-transfer.
    """
    self._download_start_times[model.fileName] = time.monotonic()
    timeout = aiohttp.ClientTimeout(total=None, sock_connect=self._CONNECT_TIMEOUT_S, sock_read=self._READ_TIMEOUT_S)
    try:
      async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url) as response:
          response.raise_for_status()
          total_size = int(response.headers.get("content-length", 0))
          bytes_downloaded = 0
          with open(path, "wb") as f:
            async for chunk in response.content.iter_chunked(self._chunk_size):
              f.write(chunk)
              bytes_downloaded += len(chunk)
              # Clearing the request param aborts the transfer.
              if not self._is_param_set(self.params.get("ModelManager_DownloadIndex")):
                raise Exception("Download cancelled")
              if total_size > 0:
                progress = (bytes_downloaded / total_size) * 100
                model.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloading
                model.downloadProgress.progress = progress
                model.downloadProgress.eta = self._calculate_eta(model.fileName, progress)
                self._report_status()
    finally:
      # Drop the timing entry whether we finished, failed or were cancelled.
      self._download_start_times.pop(model.fileName, None)

  async def _process_artifact(self, artifact, destination_path: str) -> None:
    """Ensure one artifact is present in ``destination_path`` with the expected sha256.

    A verified existing file is reported as ``cached``; otherwise it is downloaded and verified.
    Artifacts without a download URI are skipped. On failure the partial file is removed, the
    artifact and the selected bundle are marked ``failed``, and the exception is re-raised.
    """
    if not artifact.downloadUri.uri:
      return

    url = artifact.downloadUri.uri
    expected_hash = artifact.downloadUri.sha256
    filename = artifact.fileName
    full_path = os.path.join(destination_path, filename)

    try:
      if os.path.exists(full_path) and await verify_file(full_path, expected_hash):
        artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.cached
        artifact.downloadProgress.progress = 100
        artifact.downloadProgress.eta = 0
        self._report_status()
        return

      await self._download_file(url, full_path, artifact)
      if not await verify_file(full_path, expected_hash):
        raise ValueError(f"Hash validation failed for {filename}")

      artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloaded
      # Match the cached branch; without a content-length progress was never updated.
      artifact.downloadProgress.progress = 100
      artifact.downloadProgress.eta = 0
      self._report_status()
    except asyncio.CancelledError:
      # CancelledError is not an Exception subclass. This happens e.g. when asyncio.run tears
      # down remaining gather children after a sibling failed; don't leave a partial file behind.
      self._remove_file_quietly(full_path)
      raise
    except Exception as e:
      cloudlog.error(f"Error downloading {filename}: {e}")
      self._remove_file_quietly(full_path)
      artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.failed
      artifact.downloadProgress.eta = 0
      self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
      self._report_status()
      raise

  async def _process_model(self, model, destination_path: str) -> None:
    """Fetch a model's metadata, then its main artifact (sequentially)."""
    await self._process_artifact(model.metadata, destination_path)
    await self._process_artifact(model.artifact, destination_path)

  def _report_status(self) -> None:
    """Publish selected/active/available bundles on the modelManagerSP socket."""
    msg = messaging.new_message("modelManagerSP", valid=True)
    state = msg.modelManagerSP
    if self.selected_bundle:
      state.selectedBundle = self.selected_bundle
    if self.active_bundle:
      state.activeBundle = self.active_bundle
    state.availableBundles = self.available_models
    self.pm.send("modelManagerSP", msg)

  def _reset_calibration(self) -> None:
    """Remove the stored CalibrationParams via the Params API so calibration restarts."""
    try:
      self.params.remove("CalibrationParams")
      cloudlog.info("ModelManager: CalibrationParams deleted for recalibration")
    except Exception as e:
      cloudlog.error(f"ModelManager: failed to delete CalibrationParams: {e}")

  def _activate_bundle(self, bundle) -> None:
    """Persist ``bundle`` as the active bundle, flag recalibration and request a reboot.

    If the bundle cannot be serialized, an error is logged and nothing is persisted.
    """
    d = bundle_to_dict(bundle)
    if not d:
      cloudlog.error("ModelManager: failed to serialize active bundle for params")
      return

    self.params.put("ModelManager_ActiveBundle", json.dumps(d))
    cloudlog.info("ModelManager: saved active bundle %s", d.get("display_name", d.get("index")))
    self.params.put_bool("ModelManager_NeedsRecalibration", True)
    self._reset_calibration()
    # Trigger a reboot (pairs with the DoReboot handling in settings.cc).
    self.params.put_bool("DoReboot", True)
    cloudlog.info("ModelManager: triggering reboot after model download")

  async def _download_bundle(self, model_bundle, destination_path: str) -> None:
    """Download every model of ``model_bundle`` concurrently, then activate it.

    The bundle status is ``downloading`` while in progress, ``downloaded`` on success and
    ``failed`` on any error (which is re-raised). Status is always published on exit.
    """
    self.selected_bundle = model_bundle
    self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.downloading

    try:
      os.makedirs(destination_path, exist_ok=True)
      await asyncio.gather(*(self._process_model(m, destination_path) for m in self.selected_bundle.models))
      self.active_bundle = self.selected_bundle
      self.active_bundle.status = custom.ModelManagerSP.DownloadStatus.downloaded
      self._activate_bundle(self.active_bundle)
      self.selected_bundle = None
    except Exception:
      self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
      raise
    finally:
      self._report_status()

  def download(self, model_bundle, destination_path: str) -> None:
    """Synchronously download ``model_bundle`` into ``destination_path`` (re-raises failures)."""
    asyncio.run(self._download_bundle(model_bundle, destination_path))

  def _handle_download_request(self) -> None:
    """Service a pending ModelManager_DownloadIndex request.

    The request param is always removed afterwards, whether the download succeeded or failed,
    whether the index was unreadable, or whether no bundle matched.
    """
    try:
      index_to_download = self.params.get_int("ModelManager_DownloadIndex")
    except Exception:
      index_to_download = None

    try:
      if index_to_download is None:
        return

      model_to_download = next((m for m in self.available_models if m.index == index_to_download), None)
      if model_to_download is None:
        cloudlog.warning(
          f"ModelManager: no bundle index {index_to_download} in {len(self.available_models)} available"
        )
        return

      cloudlog.info(f"ModelManager: starting download for bundle index {index_to_download} ({model_to_download.displayName})")
      try:
        self.download(model_to_download, Paths.model_root())
        cloudlog.info("ModelManager: download completed")
      except Exception as e:
        cloudlog.exception(e)
      finally:
        self.selected_bundle = None
    finally:
      self.params.remove("ModelManager_DownloadIndex")

  def main_thread(self) -> None:
    """Run forever at 1 Hz: refresh bundle lists, service requests and publish status."""
    rk = Ratekeeper(1, print_delay_threshold=None)
    while True:
      try:
        self.available_models = self.model_fetcher.get_available_bundles()
        self.active_bundle = get_active_bundle(self.params)

        if self._is_param_set(self.params.get("ModelManager_DownloadIndex")):
          self._handle_download_request()

        if self.params.get("ModelManager_ClearCache"):
          self.clear_model_cache()
          self.params.remove("ModelManager_ClearCache")

        self._report_status()
      except Exception as e:
        cloudlog.exception(f"Error in model manager main thread: {e}")
      rk.keep_time()

  def clear_model_cache(self) -> None:
    """Delete every regular file in the model directory not used by the active bundle."""
    keep = set()
    if self.active_bundle is not None:
      for model in self.active_bundle.models:
        for part in (getattr(model, "artifact", None), getattr(model, "metadata", None)):
          if part is not None and part.fileName:
            keep.add(part.fileName)

    model_dir = Paths.model_root()
    try:
      for filename in os.listdir(model_dir):
        if filename in keep:
          continue
        file_path = os.path.join(model_dir, filename)
        if os.path.isfile(file_path):
          os.remove(file_path)
      cloudlog.info("Model cache cleared, keeping active model files")
    except Exception as e:
      cloudlog.exception(f"Error clearing model cache: {e}")
def main() -> None:
  """Entry point: build a ModelManagerSP and block in its polling loop."""
  manager = ModelManagerSP()
  manager.main_thread()


if __name__ == "__main__":
  main()