Files
openpilot/selfdrive/models_manager/fetcher.py
Comma Device 4b34ea50cd external radar
2026-03-30 22:16:15 +08:00

208 lines
7.8 KiB
Python
Executable File

"""
Model Manager fetcher: fetch and cache model index from remote JSON.
Ported from sunnypilot.
"""
import json
import time
import requests
from requests.exceptions import SSLError, RequestException, HTTPError
from openpilot.common.params import Params
from openpilot.common.swaglog import cloudlog
from openpilot.selfdrive.models_manager.helpers import is_bundle_version_compatible, bundle_to_dict
from cereal import custom
class ModelParser:
  """Parses model index JSON into cereal ModelManagerSP structs.

  Optional fields that are missing *or* JSON ``null`` fall back to the same
  empty/zero default. Only a bundle's ``"index"`` is required; a bundle without
  it raises ``KeyError`` so that callers see the payload as malformed.
  """

  @staticmethod
  def _parse_download_uri(download_uri_data) -> custom.ModelManagerSP.DownloadUri:
    """Build a DownloadUri from a mapping with optional "url" and "sha256" keys."""
    download_uri_data = download_uri_data or {}  # tolerate a null "download_uri"
    download_uri = custom.ModelManagerSP.DownloadUri()
    download_uri.uri = download_uri_data.get("url") or ""
    download_uri.sha256 = download_uri_data.get("sha256") or ""
    return download_uri

  @staticmethod
  def _parse_artifact(artifact_data) -> custom.ModelManagerSP.Artifact:
    """Build an Artifact from a mapping with optional "file_name" and "download_uri" keys."""
    artifact_data = artifact_data or {}  # tolerate a null "artifact" / "metadata"
    artifact = custom.ModelManagerSP.Artifact()
    artifact.fileName = artifact_data.get("file_name") or ""
    artifact.downloadUri = ModelParser._parse_download_uri(artifact_data.get("download_uri"))
    return artifact

  @staticmethod
  def _model_type_from_str(t: str):
    """Map a model type name (case-insensitive) to a Model.Type enumerant.

    Non-string values and unrecognised names both map to ``supercombo``.
    """
    model_type = custom.ModelManagerSP.Model.Type
    if not isinstance(t, str):
      return model_type.supercombo
    # JSON name -> enumerant attribute name. The attribute is looked up lazily so
    # only the enumerant that is actually requested has to exist in the schema.
    attr_by_name = {
      "vision": "vision",
      "policy": "policy",
      "navigation": "navigation",
      "offpolicy": "offPolicy",
    }
    return getattr(model_type, attr_by_name.get(t.lower(), "supercombo"))

  @staticmethod
  def _parse_model(model_data) -> custom.ModelManagerSP.Model:
    """Build a Model from its JSON mapping; a missing "metadata" yields an empty Artifact."""
    model = custom.ModelManagerSP.Model()
    model.type = ModelParser._model_type_from_str(model_data.get("type"))
    model.artifact = ModelParser._parse_artifact(model_data.get("artifact"))
    model.metadata = ModelParser._parse_artifact(model_data.get("metadata"))
    return model

  @staticmethod
  def _parse_overrides(overrides_data: dict) -> list:
    """Convert an ``{key: value}`` mapping (or None) into a list of Override structs."""
    overrides = []
    for key, value in (overrides_data or {}).items():
      override = custom.ModelManagerSP.Override()
      override.key = key
      override.value = value
      overrides.append(override)
    return overrides

  @staticmethod
  def _runner_from_bundle(bundle) -> custom.ModelManagerSP.Runner:
    """Resolve the bundle's "runner" field to a Runner enumerant.

    A missing or null runner means ``stock``. A string is matched
    case-insensitively against "tinygrad" and "stock"; any other string maps
    to ``snpe``. Non-string values are passed through unchanged.
    """
    runners = custom.ModelManagerSP.Runner
    r = bundle.get("runner")
    if r is None:
      return runners.stock
    if not isinstance(r, str):
      return r
    name = r.lower()
    if name == "tinygrad":
      return runners.tinygrad
    if name == "stock":
      return runners.stock
    return runners.snpe

  @staticmethod
  def _parse_bundle(bundle) -> custom.ModelManagerSP.ModelBundle:
    """Build a ModelBundle from one entry of the index's "bundles" list.

    Raises:
      KeyError: if the bundle has no "index" key.
    """
    model_bundle = custom.ModelManagerSP.ModelBundle()
    model_bundle.index = int(bundle["index"])
    model_bundle.internalName = bundle.get("short_name") or ""
    model_bundle.displayName = bundle.get("display_name") or ""
    model_bundle.models = [ModelParser._parse_model(m) for m in bundle.get("models") or []]
    # NOTE(review): 0 is presumably the schema's initial/default status - confirm against custom.capnp.
    model_bundle.status = 0
    # `or 0` covers missing, null and "" alike, so int() never sees a non-numeric placeholder.
    model_bundle.generation = int(bundle.get("generation") or 0)
    model_bundle.environment = bundle.get("environment") or ""
    model_bundle.runner = ModelParser._runner_from_bundle(bundle)
    model_bundle.is20hz = bundle.get("is_20hz") or False
    model_bundle.ref = bundle.get("ref") or ""
    model_bundle.minimumSelectorVersion = int(bundle.get("minimum_selector_version") or 0)
    model_bundle.overrides = ModelParser._parse_overrides(bundle.get("overrides"))
    return model_bundle

  @staticmethod
  def parse_models(json_data: dict) -> list:
    """Parse every bundle in ``json_data["bundles"]`` and keep the version-compatible ones.

    Compatibility is decided by ``is_bundle_version_compatible`` on the dict form
    of each parsed bundle. Any malformed bundle makes the whole call raise.
    """
    found_bundles = [ModelParser._parse_bundle(b) for b in json_data.get("bundles") or []]
    return [b for b in found_bundles if is_bundle_version_compatible(bundle_to_dict(b))]
class ModelCache:
  """Caches model index in Params to avoid frequent network fetches.

  Two Params keys are used: one holds the index as a JSON string, the other the
  ``time.monotonic()`` second at which it was stored.
  """

  def __init__(self, params: Params, cache_timeout_seconds: int = 3600):
    """
    Args:
      params: Params-like store providing ``get(key)`` and ``put(key, value)``.
      cache_timeout_seconds: age in seconds after which the cached index is
        reported as expired.
    """
    self.params = params
    self.cache_timeout_seconds = cache_timeout_seconds
    self._LAST_SYNC_KEY = "ModelManager_LastSyncTime"
    self._CACHE_KEY = "ModelManager_ModelsCache"

  def _is_expired(self) -> bool:
    """Return True when the sync timestamp is missing, unreadable, or too old."""
    try:
      raw = self.params.get(self._LAST_SYNC_KEY)
      if not raw:
        return True
      last_sync = int(raw.decode("utf-8") if isinstance(raw, bytes) else raw)
    except (TypeError, ValueError):
      return True
    elapsed = time.monotonic() - last_sync
    # time.monotonic() has an undefined reference point (on Linux it restarts at
    # boot). A stored timestamp that is ahead of the current clock therefore
    # comes from an earlier clock epoch; its age is unknown, so treat it as
    # expired rather than as "fresh" for an arbitrarily long time.
    return elapsed < 0 or elapsed >= self.cache_timeout_seconds

  def get(self) -> tuple[dict, bool]:
    """Returns (cached_data dict, is_expired).

    Returns ``({}, True)`` when nothing is cached or the cached value cannot be
    decoded; decoding errors are logged and never raised.
    """
    try:
      raw = self.params.get(self._CACHE_KEY)
      if not raw:
        return {}, True
      text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
      return json.loads(text), self._is_expired()
    except Exception as e:
      cloudlog.exception(f"Error retrieving cached model data: {e}")
      return {}, True

  def set(self, data: dict) -> None:
    """Store ``data`` as JSON and stamp it with the current monotonic second."""
    self.params.put(self._CACHE_KEY, json.dumps(data))
    self.params.put(self._LAST_SYNC_KEY, str(int(time.monotonic())))
class ModelFetcher:
  """Fetches and caches model index from remote URL.

  A cached index that has not expired is served without touching the network.
  Otherwise the index is downloaded; after a failed download, further attempts
  are suppressed for ``NETWORK_RETRY_INTERVAL`` seconds and the expired cache
  (if any) is served instead.
  """
  MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v15.json"
  NETWORK_RETRY_INTERVAL = 300  # 5 minutes between network retry attempts when offline

  def __init__(self, params: Params):
    self.params = params
    self.model_cache = ModelCache(params)
    self.model_parser = ModelParser()
    self._last_fetch_attempt = 0  # time.monotonic() of the most recent network attempt
    self._fetch_failed = False  # True while the most recent network attempt has failed

  def _fetch_and_cache_models(self) -> list | None:
    """Download the index, parse it, and cache it only if parsing succeeded.

    Returns:
      The parsed bundles, or None when the request or the parsing failed. All
      errors are logged; none are raised.
    """
    try:
      response = requests.get(self.MODEL_URL, timeout=10)
      if response.status_code == 404:
        cloudlog.error(f"Models URL returned 404: {self.MODEL_URL}")
        raise HTTPError(f"404 Not Found: {self.MODEL_URL}", response=response)
      response.raise_for_status()
      json_data = response.json()
      # Parse before caching: a payload that cannot be parsed must never replace
      # the cache, otherwise every later cache hit would fail the same way.
      bundles = self.model_parser.parse_models(json_data)
      self.model_cache.set(json_data)
      cloudlog.debug("Successfully updated models cache")
      return bundles
    except (ConnectionError, SSLError, RequestException) as e:
      cloudlog.warning(f"Request error fetching models: {e}")
    except Exception as e:
      cloudlog.exception(f"Error fetching models: {e}")
    return None

  def _parse_cached(self, cached_data: dict) -> list | None:
    """Parse cached index data; log and return None if it cannot be parsed."""
    try:
      return self.model_parser.parse_models(cached_data)
    except Exception as e:
      cloudlog.exception(f"Error parsing cached model data: {e}")
      return None

  def get_available_bundles(self) -> list:
    """Return the compatible model bundles, preferring a fresh cache over the network.

    Returns:
      The parsed bundles, or an empty list when neither the network nor the
      cache can provide a usable index.
    """
    cached_data, is_expired = self.model_cache.get()

    # If cache is valid, use it without attempting network fetch
    if cached_data and not is_expired:
      bundles = self._parse_cached(cached_data)
      if bundles is not None:
        cloudlog.debug("Using valid cached models data")
        self._fetch_failed = False
        return bundles
      # The cache is unusable: behave as if nothing were cached and go to the network.
      cached_data = {}

    # If previous fetch failed, wait before retrying to avoid hammering network
    current_time = time.monotonic()
    if self._fetch_failed and (current_time - self._last_fetch_attempt) < self.NETWORK_RETRY_INTERVAL:
      cloudlog.debug(f"Network fetch failed recently, waiting {self.NETWORK_RETRY_INTERVAL}s before retry")
      if cached_data:
        cloudlog.warning("Using expired cache as fallback (network offline)")
        return self._parse_cached(cached_data) or []
      return []

    # Attempt to fetch from network
    self._last_fetch_attempt = current_time
    fetched_bundles = self._fetch_and_cache_models()
    if fetched_bundles is not None:
      self._fetch_failed = False
      return fetched_bundles

    # Network fetch failed
    self._fetch_failed = True
    if not cached_data:
      cloudlog.warning("Failed to fetch models and no cache available")
      return []
    cloudlog.warning("Using expired cache as fallback")
    return self._parse_cached(cached_data) or []