222 lines
8.9 KiB
Python
Executable File
222 lines
8.9 KiB
Python
Executable File
"""
|
|
Model Manager: downloads model bundles and reports status via modelManagerSP.
|
|
Ported from sunnypilot.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import time
|
|
import os
|
|
|
|
import aiohttp
|
|
from openpilot.common.params import Params
|
|
from openpilot.common.realtime import Ratekeeper
|
|
from openpilot.common.swaglog import cloudlog
|
|
from openpilot.system.hardware.hw import Paths
|
|
|
|
from cereal import messaging, custom
|
|
from openpilot.selfdrive.models_manager.fetcher import ModelFetcher
|
|
from openpilot.selfdrive.models_manager.helpers import verify_file, get_active_bundle, bundle_to_dict
|
|
|
|
|
|
class ModelManagerSP:
  """Manage model bundle downloads and publish their status on ``modelManagerSP``.

  ``main_thread`` polls Params once per second: a value in ``ModelManager_DownloadIndex``
  starts a download of the bundle with that index, and ``ModelManager_ClearCache``
  deletes model files the active bundle does not use. Every iteration publishes the
  available, selected and active bundles.
  """

  def __init__(self):
    self.params = Params()
    self.model_fetcher = ModelFetcher(self.params)
    self.pm = messaging.PubMaster(["modelManagerSP"])
    self.available_models: list = []
    # Bundle currently being downloaded; None when idle.
    self.selected_bundle = None
    self.active_bundle = get_active_bundle(self.params)
    self._chunk_size = 128 * 1000  # 128 KB per streamed chunk
    # fileName -> time.monotonic() at which its download started; used for ETA estimates.
    self._download_start_times: dict = {}

  @staticmethod
  def _is_index_set(raw_index) -> bool:
    """Return True if a raw ``ModelManager_DownloadIndex`` value is a pending download request.

    ``None`` and empty bytes/str mean "no request". Any other value counts as a request,
    including an integer ``0`` (which ``bool()`` would wrongly treat as absent when the
    Params implementation returns already-decoded ints).
    """
    if raw_index is None:
      return False
    if isinstance(raw_index, (bytes, str)):
      return len(raw_index) > 0
    return True

  @staticmethod
  def _parse_index(raw_index):
    """Convert a raw bytes/str/int index value to ``int``, or return None if it is not a number."""
    try:
      return int(raw_index)
    except (TypeError, ValueError):
      return None

  def _download_requested(self) -> bool:
    """Return True while ``ModelManager_DownloadIndex`` holds a download request."""
    return self._is_index_set(self.params.get("ModelManager_DownloadIndex"))

  def _calculate_eta(self, filename: str, progress: float) -> int:
    """Estimate the seconds remaining for ``filename`` given ``progress`` in percent (0-100).

    Returns 60 as a placeholder when no start time is recorded or no progress has been
    made yet; otherwise a linear extrapolation of the elapsed time, at least 1 second.
    """
    started_at = self._download_start_times.get(filename)
    if started_at is None or progress <= 0:
      return 60
    elapsed = time.monotonic() - started_at
    if elapsed <= 0:
      return 60
    total_estimated = (elapsed / progress) * 100
    return max(1, int(total_estimated - elapsed))

  async def _download_file(self, url: str, path: str, model) -> None:
    """Stream ``url`` into ``path`` in chunks, updating ``model.downloadProgress``.

    Progress and ETA are only reported when the server sends a Content-Length.

    Raises:
      aiohttp.ClientResponseError: if the server answers with an error status.
      Exception: "Download cancelled" if the download request param is cleared mid-transfer.
    """
    self._download_start_times[model.fileName] = time.monotonic()
    try:
      async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
          response.raise_for_status()
          total_size = int(response.headers.get("content-length", 0))
          bytes_downloaded = 0
          with open(path, "wb") as f:
            async for chunk in response.content.iter_chunked(self._chunk_size):
              f.write(chunk)
              bytes_downloaded += len(chunk)
              # Clearing ModelManager_DownloadIndex aborts an in-flight download.
              if not self._download_requested():
                raise Exception("Download cancelled")
              if total_size > 0:
                progress = (bytes_downloaded / total_size) * 100
                model.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloading
                model.downloadProgress.progress = progress
                model.downloadProgress.eta = self._calculate_eta(model.fileName, progress)
                self._report_status()
    finally:
      # Drop the ETA bookkeeping on success, failure and cancellation alike.
      self._download_start_times.pop(model.fileName, None)

  async def _process_artifact(self, artifact, destination_path: str) -> None:
    """Ensure ``artifact`` exists in ``destination_path`` and matches its expected SHA-256.

    A file that already verifies is reported as cached; otherwise it is downloaded and
    verified. On failure the file is removed, the artifact and the selected bundle are
    marked failed, and the exception is re-raised. Artifacts without a URI are skipped.
    """
    if not artifact.downloadUri.uri:
      return

    url = artifact.downloadUri.uri
    expected_hash = artifact.downloadUri.sha256
    filename = artifact.fileName
    full_path = os.path.join(destination_path, filename)
    download_started = False

    try:
      if os.path.exists(full_path) and await verify_file(full_path, expected_hash):
        artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.cached
        artifact.downloadProgress.progress = 100
        artifact.downloadProgress.eta = 0
        self._report_status()
        return

      download_started = True
      await self._download_file(url, full_path, artifact)
      if not await verify_file(full_path, expected_hash):
        raise ValueError(f"Hash validation failed for {filename}")

      artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloaded
      # Without a Content-Length the streaming loop never sets progress, so finish it here.
      artifact.downloadProgress.progress = 100
      artifact.downloadProgress.eta = 0
      self._report_status()
    except asyncio.CancelledError:
      # CancelledError is not an Exception subclass, so the cleanup below would not run.
      # Cancellation happens e.g. when a sibling artifact failed and asyncio.run cancels the
      # remaining tasks: remove a truncated download, but keep a pre-existing cached file.
      if download_started and os.path.exists(full_path):
        os.remove(full_path)
      raise
    except Exception as e:
      cloudlog.error(f"Error downloading {filename}: {e}")
      if os.path.exists(full_path):
        os.remove(full_path)
      artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.failed
      artifact.downloadProgress.eta = 0
      if self.selected_bundle is not None:
        self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
      self._report_status()
      raise

  async def _process_model(self, model, destination_path: str) -> None:
    """Fetch a model's metadata file, then its artifact file."""
    await self._process_artifact(model.metadata, destination_path)
    await self._process_artifact(model.artifact, destination_path)

  def _report_status(self) -> None:
    """Publish the selected, active and available bundles on ``modelManagerSP``."""
    msg = messaging.new_message("modelManagerSP", valid=True)
    state = msg.modelManagerSP
    if self.selected_bundle:
      state.selectedBundle = self.selected_bundle
    if self.active_bundle:
      state.activeBundle = self.active_bundle
    state.availableBundles = self.available_models
    self.pm.send("modelManagerSP", msg)

  def _reset_calibration(self) -> None:
    """Remove the stored CalibrationParams so the new model recalibrates from scratch.

    Goes through the Params API instead of a hard-coded file path, so it follows wherever
    this Params instance actually stores its keys. Failures are logged, not raised.
    """
    try:
      if self.params.get("CalibrationParams") is not None:
        self.params.remove("CalibrationParams")
        cloudlog.info("ModelManager: CalibrationParams deleted for recalibration")
    except Exception as e:
      cloudlog.error(f"ModelManager: failed to delete CalibrationParams: {e}")

  def _activate_selected_bundle(self) -> None:
    """Make the downloaded bundle active, persist it, reset calibration and request a reboot."""
    self.active_bundle = self.selected_bundle
    self.active_bundle.status = custom.ModelManagerSP.DownloadStatus.downloaded

    d = bundle_to_dict(self.active_bundle)
    if not d:
      cloudlog.error("ModelManager: failed to serialize active bundle for params")
      return

    self.params.put("ModelManager_ActiveBundle", json.dumps(d))
    self.params.put_bool("ModelManager_NeedsRecalibration", True)
    self._reset_calibration()

    # Trigger a reboot (works together with the DoReboot handling in settings.cc).
    self.params.put_bool("DoReboot", True)
    cloudlog.info("ModelManager: triggering reboot after model download")

    cloudlog.info("ModelManager: saved active bundle %s", d.get("display_name", d.get("index")))

  async def _download_bundle(self, model_bundle, destination_path: str) -> None:
    """Download every model of ``model_bundle`` concurrently and activate it on success.

    Status is published on exit whether the download succeeded or failed.

    Raises:
      Exception: the first error raised by any artifact; the bundle is marked failed.
    """
    self.selected_bundle = model_bundle
    self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.downloading
    os.makedirs(destination_path, exist_ok=True)

    try:
      await asyncio.gather(*(self._process_model(m, destination_path) for m in self.selected_bundle.models))
      self._activate_selected_bundle()
      self.selected_bundle = None
    except Exception:
      self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
      raise
    finally:
      self._report_status()

  def download(self, model_bundle, destination_path: str) -> None:
    """Synchronously download ``model_bundle`` into ``destination_path``."""
    asyncio.run(self._download_bundle(model_bundle, destination_path))

  def _handle_download_request(self) -> None:
    """Download the bundle named by ``ModelManager_DownloadIndex`` (if any), then clear it.

    The request param is removed in every outcome (no request, an unparseable index, an
    unknown index, or a finished/failed download), so a bad request cannot wedge the loop.
    """
    raw_index = self.params.get("ModelManager_DownloadIndex")
    if not self._is_index_set(raw_index):
      self.params.remove("ModelManager_DownloadIndex")
      return

    index_to_download = self._parse_index(raw_index)
    if index_to_download is None:
      cloudlog.warning(f"ModelManager: ignoring unparseable bundle index {raw_index!r}")
      self.params.remove("ModelManager_DownloadIndex")
      return

    model_to_download = next((m for m in self.available_models if m.index == index_to_download), None)
    if model_to_download is None:
      cloudlog.warning(
        f"ModelManager: no bundle index {index_to_download} in {len(self.available_models)} available"
      )
      self.params.remove("ModelManager_DownloadIndex")
      return

    cloudlog.info(f"ModelManager: starting download for bundle index {index_to_download} ({model_to_download.displayName})")
    try:
      self.download(model_to_download, Paths.model_root())
      cloudlog.info("ModelManager: download completed")
    except Exception as e:
      cloudlog.exception(e)
    finally:
      self.params.remove("ModelManager_DownloadIndex")
      self.selected_bundle = None

  def main_thread(self) -> None:
    """Run the 1 Hz poll/publish loop forever; errors are logged and the loop continues."""
    rk = Ratekeeper(1, print_delay_threshold=None)
    while True:
      try:
        self.available_models = self.model_fetcher.get_available_bundles()
        self.active_bundle = get_active_bundle(self.params)

        self._handle_download_request()

        if self.params.get("ModelManager_ClearCache"):
          self.clear_model_cache()
          self.params.remove("ModelManager_ClearCache")

        self._report_status()
      except Exception as e:
        cloudlog.exception(f"Error in model manager main thread: {e}")
      rk.keep_time()

  def clear_model_cache(self) -> None:
    """Delete every regular file in the model directory not used by the active bundle.

    With no active bundle, every regular file is deleted. Errors are logged, not raised.
    """
    active_files = set()
    if self.active_bundle is not None:
      for model in self.active_bundle.models:
        if hasattr(model, "artifact") and model.artifact.fileName:
          active_files.add(model.artifact.fileName)
        if hasattr(model, "metadata") and model.metadata.fileName:
          active_files.add(model.metadata.fileName)

    model_dir = Paths.model_root()
    try:
      for filename in os.listdir(model_dir):
        if filename in active_files:
          continue
        file_path = os.path.join(model_dir, filename)
        if os.path.isfile(file_path):
          os.remove(file_path)
      cloudlog.info("Model cache cleared, keeping active model files")
    except Exception as e:
      cloudlog.exception(f"Error clearing model cache: {e}")
|
|
|
|
|
|
def main() -> None:
  """Start the model manager and run its polling loop; this call does not return."""
  manager = ModelManagerSP()
  manager.main_thread()


if __name__ == "__main__":
  main()
|