"""
Model Manager: downloads model bundles and reports status via modelManagerSP.

Ported from sunnypilot.
"""
import asyncio
import json
import os
import time

import aiohttp

from openpilot.common.params import Params
from openpilot.common.realtime import Ratekeeper
from openpilot.common.swaglog import cloudlog
from openpilot.system.hardware.hw import Paths
from cereal import messaging, custom
from openpilot.selfdrive.models_manager.fetcher import ModelFetcher
from openpilot.selfdrive.models_manager.helpers import verify_file, get_active_bundle, bundle_to_dict


class ModelManagerSP:
  """Manages model downloads and status reporting.

  Driven by Params:
    * ``ModelManager_DownloadIndex`` -- index of the bundle to download. Clearing it while a
      download is streaming cancels that download.
    * ``ModelManager_ClearCache`` -- when set, deletes model files not used by the active bundle.
  Progress and state are published on the ``modelManagerSP`` topic.
  """

  # Streaming timeouts for model downloads. There is deliberately no total timeout: aiohttp's
  # default ClientTimeout caps the whole request (including reading the body) at 5 minutes,
  # which a large model file on a slow link can exceed. Stalls are still caught by sock_read.
  _DOWNLOAD_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=60)

  def __init__(self):
    self.params = Params()
    self.model_fetcher = ModelFetcher(self.params)
    self.pm = messaging.PubMaster(["modelManagerSP"])
    self.available_models: list = []
    self.selected_bundle = None  # bundle currently being downloaded, if any
    self.active_bundle = get_active_bundle(self.params)
    self._chunk_size = 128 * 1000  # 128 kB per streamed chunk
    self._download_start_times: dict = {}  # fileName -> time.monotonic() when its download started

  def _calculate_eta(self, filename: str, progress: float) -> int:
    """Estimate the remaining download time in seconds.

    :param filename: artifact file name, used as the key into ``_download_start_times``.
    :param progress: completion percentage (0-100).
    :return: remaining seconds (at least 1), or 60 when no estimate is possible yet.
    """
    if filename not in self._download_start_times or progress <= 0:
      return 60
    elapsed = time.monotonic() - self._download_start_times[filename]
    if elapsed <= 0:
      return 60
    # Linear extrapolation: total time ~= elapsed * 100 / progress.
    total_estimated = (elapsed / progress) * 100
    return max(1, int(total_estimated - elapsed))

  async def _download_file(self, url: str, path: str, model) -> None:
    """Stream ``url`` into ``path``, updating ``model.downloadProgress`` as chunks arrive.

    :raises aiohttp.ClientResponseError: on a non-success HTTP status.
    :raises Exception: "Download cancelled" when ModelManager_DownloadIndex is cleared mid-download.
    """
    self._download_start_times[model.fileName] = time.monotonic()
    async with aiohttp.ClientSession(timeout=self._DOWNLOAD_TIMEOUT) as session:
      async with session.get(url) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        bytes_downloaded = 0
        with open(path, "wb") as f:
          async for chunk in response.content.iter_chunked(self._chunk_size):
            f.write(chunk)
            bytes_downloaded += len(chunk)
            # Clearing the download index is the cancellation signal.
            if not self.params.get("ModelManager_DownloadIndex"):
              raise Exception("Download cancelled")
            # Without a Content-Length no percentage can be computed.
            if total_size > 0:
              progress = (bytes_downloaded / total_size) * 100
              model.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloading
              model.downloadProgress.progress = progress
              model.downloadProgress.eta = self._calculate_eta(model.fileName, progress)
              self._report_status()
    self._download_start_times.pop(model.fileName, None)

  async def _process_artifact(self, artifact, destination_path: str) -> None:
    """Ensure one artifact exists in ``destination_path`` with the expected SHA-256.

    An existing file with a matching hash is reported as ``cached``. Otherwise the file is
    downloaded and verified. On any failure the partial file is removed, the artifact and the
    selected bundle are marked ``failed``, and the exception is re-raised.
    Artifacts without a download URI are skipped.
    """
    if not artifact.downloadUri.uri:
      return

    url = artifact.downloadUri.uri
    expected_hash = artifact.downloadUri.sha256
    filename = artifact.fileName
    full_path = os.path.join(destination_path, filename)

    try:
      if os.path.exists(full_path) and await verify_file(full_path, expected_hash):
        artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.cached
        artifact.downloadProgress.progress = 100
        artifact.downloadProgress.eta = 0
        self._report_status()
        return

      await self._download_file(url, full_path, artifact)
      if not await verify_file(full_path, expected_hash):
        raise ValueError(f"Hash validation failed for {filename}")

      artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloaded
      # Set explicitly: progress is never updated while streaming if the server omits Content-Length.
      artifact.downloadProgress.progress = 100
      artifact.downloadProgress.eta = 0
      self._report_status()
    except Exception as e:
      cloudlog.error(f"Error downloading {filename}: {e}")
      # Best-effort cleanup: a failing remove must not mask the original error
      # or skip the failure status updates below.
      try:
        if os.path.exists(full_path):
          os.remove(full_path)
      except OSError as cleanup_err:
        cloudlog.error(f"Error removing {full_path}: {cleanup_err}")
      artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.failed
      artifact.downloadProgress.eta = 0
      self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
      self._report_status()
      self._download_start_times.pop(artifact.fileName, None)
      raise

  async def _process_model(self, model, destination_path: str) -> None:
    """Fetch a model's metadata artifact, then its main artifact."""
    await self._process_artifact(model.metadata, destination_path)
    await self._process_artifact(model.artifact, destination_path)

  def _report_status(self) -> None:
    """Publish selected/active/available bundles on the modelManagerSP topic."""
    msg = messaging.new_message("modelManagerSP", valid=True)
    state = msg.modelManagerSP
    if self.selected_bundle:
      state.selectedBundle = self.selected_bundle
    if self.active_bundle:
      state.activeBundle = self.active_bundle
    state.availableBundles = self.available_models
    self.pm.send("modelManagerSP", msg)

  async def _download_bundle(self, model_bundle, destination_path: str) -> None:
    """Download every model of ``model_bundle`` concurrently and activate it on success.

    On success the bundle is saved to ModelManager_ActiveBundle, recalibration is requested
    and DoReboot is set. On failure the bundle is marked ``failed`` and the exception
    propagates. Status is published in both cases.
    """
    self.selected_bundle = model_bundle
    self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.downloading
    os.makedirs(destination_path, exist_ok=True)

    try:
      tasks = [self._process_model(m, destination_path) for m in self.selected_bundle.models]
      await asyncio.gather(*tasks)

      self.active_bundle = self.selected_bundle
      self.active_bundle.status = custom.ModelManagerSP.DownloadStatus.downloaded
      d = bundle_to_dict(self.active_bundle)
      if d:
        self.params.put("ModelManager_ActiveBundle", json.dumps(d))
        self.params.put_bool("ModelManager_NeedsRecalibration", True)
        # NOTE(review): hard-coded on-device path; confirm it matches where
        # CalibrationParams is actually stored.
        calib_path = "/data/params/d_tmp/CalibrationParams"
        try:
          if os.path.exists(calib_path):
            os.remove(calib_path)
            cloudlog.info("ModelManager: CalibrationParams deleted for recalibration")
        except Exception as e:
          cloudlog.error(f"ModelManager: failed to delete CalibrationParams: {e}")
        # Trigger a reboot (works together with DoReboot handling in settings.cc).
        self.params.put_bool("DoReboot", True)
        cloudlog.info("ModelManager: triggering reboot after model download")
        cloudlog.info("ModelManager: saved active bundle %s", d.get("display_name", d.get("index")))
      else:
        cloudlog.error("ModelManager: failed to serialize active bundle for params")
      self.selected_bundle = None
    except Exception:
      self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
      raise
    finally:
      self._report_status()

  def download(self, model_bundle, destination_path: str) -> None:
    """Synchronously download ``model_bundle`` into ``destination_path``."""
    asyncio.run(self._download_bundle(model_bundle, destination_path))

  def main_thread(self) -> None:
    """Run the 1 Hz loop: refresh bundles, serve download/clear-cache requests, publish status."""
    rk = Ratekeeper(1, print_delay_threshold=None)
    while True:
      try:
        self.available_models = self.model_fetcher.get_available_bundles()
        self.active_bundle = get_active_bundle(self.params)

        raw_index = self.params.get("ModelManager_DownloadIndex")
        has_index = raw_index is not None and (
          len(raw_index) > 0 if isinstance(raw_index, (bytes, str)) else bool(raw_index)
        )
        if has_index:
          try:
            index_to_download = self.params.get_int("ModelManager_DownloadIndex")
          except Exception:
            index_to_download = None

          if index_to_download is not None:
            model_to_download = next((m for m in self.available_models if m.index == index_to_download), None)
            if model_to_download is not None:
              cloudlog.info(f"ModelManager: starting download for bundle index {index_to_download} ({model_to_download.displayName})")
              try:
                self.download(model_to_download, Paths.model_root())
                cloudlog.info("ModelManager: download completed")
              except Exception as e:
                cloudlog.exception(e)
              finally:
                self.params.remove("ModelManager_DownloadIndex")
                self.selected_bundle = None
            else:
              cloudlog.warning(
                f"ModelManager: no bundle index {index_to_download} in {len(self.available_models)} available"
              )
              self.params.remove("ModelManager_DownloadIndex")
          else:
            self.params.remove("ModelManager_DownloadIndex")

        if self.params.get("ModelManager_ClearCache"):
          self.clear_model_cache()
          self.params.remove("ModelManager_ClearCache")

        self._report_status()
        rk.keep_time()
      except Exception as e:
        cloudlog.exception(f"Error in model manager main thread: {e}")
        rk.keep_time()

  def clear_model_cache(self) -> None:
    """Delete every regular file in the model directory not referenced by the active bundle."""
    active_files: set = set()
    if self.active_bundle is not None:
      for model in self.active_bundle.models:
        if hasattr(model, "artifact") and model.artifact.fileName:
          active_files.add(model.artifact.fileName)
        if hasattr(model, "metadata") and model.metadata.fileName:
          active_files.add(model.metadata.fileName)

    model_dir = Paths.model_root()
    try:
      for filename in os.listdir(model_dir):
        if filename in active_files:
          continue
        file_path = os.path.join(model_dir, filename)
        if os.path.isfile(file_path):
          os.remove(file_path)
      cloudlog.info("Model cache cleared, keeping active model files")
    except Exception as e:
      cloudlog.exception(f"Error clearing model cache: {e}")


def main():
  ModelManagerSP().main_thread()


if __name__ == "__main__":
  main()