"""
Model Manager fetcher: fetch and cache model index from remote JSON.

Ported from sunnypilot.
"""
import json
import time

import requests
from requests.exceptions import SSLError, RequestException, HTTPError

from openpilot.common.params import Params
from openpilot.common.swaglog import cloudlog
from openpilot.selfdrive.models_manager.helpers import is_bundle_version_compatible, bundle_to_dict
from cereal import custom


class ModelParser:
  """Parses model index JSON into cereal ModelManagerSP structs.

  Optional fields that are missing *or* explicitly ``null`` in the JSON fall back to
  empty defaults, so a ``null`` in the remote index does not abort parsing.
  """

  @staticmethod
  def _as_int(value, default: int = 0) -> int:
    """Convert ``value`` to int, mapping ``None`` and ``""`` to ``default``.

    Any other unconvertible value still raises ``ValueError``/``TypeError``.
    """
    if value is None or value == "":
      return default
    return int(value)

  @staticmethod
  def _parse_download_uri(download_uri_data) -> custom.ModelManagerSP.DownloadUri:
    """Build a DownloadUri from a dict with optional ``url`` and ``sha256`` keys."""
    download_uri = custom.ModelManagerSP.DownloadUri()
    download_uri.uri = download_uri_data.get("url") or ""
    download_uri.sha256 = download_uri_data.get("sha256") or ""
    return download_uri

  @staticmethod
  def _parse_artifact(artifact_data) -> custom.ModelManagerSP.Artifact:
    """Build an Artifact from a dict with optional ``file_name`` and ``download_uri`` keys."""
    artifact = custom.ModelManagerSP.Artifact()
    artifact.fileName = artifact_data.get("file_name") or ""
    artifact.downloadUri = ModelParser._parse_download_uri(artifact_data.get("download_uri") or {})
    return artifact

  @staticmethod
  def _model_type_from_str(t: str):
    """Map a model type name (case-insensitive) to Model.Type.

    Non-string, empty, or unrecognised names map to ``supercombo``.
    """
    if not isinstance(t, str):
      return custom.ModelManagerSP.Model.Type.supercombo
    name = (t or "supercombo").lower()
    if name == "vision":
      return custom.ModelManagerSP.Model.Type.vision
    if name == "policy":
      return custom.ModelManagerSP.Model.Type.policy
    if name == "navigation":
      return custom.ModelManagerSP.Model.Type.navigation
    if name == "offpolicy":
      return custom.ModelManagerSP.Model.Type.offPolicy
    return custom.ModelManagerSP.Model.Type.supercombo

  @staticmethod
  def _parse_model(model_data) -> custom.ModelManagerSP.Model:
    """Build a Model from one entry of a bundle's ``models`` list.

    A missing/null ``metadata`` entry yields an empty metadata Artifact.
    """
    model = custom.ModelManagerSP.Model()
    model.type = ModelParser._model_type_from_str(model_data.get("type"))
    model.artifact = ModelParser._parse_artifact(model_data.get("artifact") or {})
    model.metadata = ModelParser._parse_artifact(model_data.get("metadata") or {})
    return model

  @staticmethod
  def _parse_overrides(overrides_data: dict) -> list:
    """Convert a ``{key: value}`` mapping (or ``None``) into a list of Override structs."""
    overrides = []
    for key, value in (overrides_data or {}).items():
      override = custom.ModelManagerSP.Override()
      override.key = key
      override.value = value
      overrides.append(override)
    return overrides

  @staticmethod
  def _runner_from_bundle(bundle) -> custom.ModelManagerSP.Runner:
    """Resolve the bundle's ``runner`` field to a Runner enum.

    Missing/null -> ``stock``. Strings are matched case-insensitively: "tinygrad" and
    "stock" map to themselves, and any other string maps to ``snpe``. Non-string values
    are passed through unchanged.
    """
    r = bundle.get("runner")
    if r is None:
      return custom.ModelManagerSP.Runner.stock
    if isinstance(r, str):
      name = r.lower()
      if name == "tinygrad":
        return custom.ModelManagerSP.Runner.tinygrad
      if name == "stock":
        return custom.ModelManagerSP.Runner.stock
      return custom.ModelManagerSP.Runner.snpe
    return r

  @staticmethod
  def _parse_bundle(bundle) -> custom.ModelManagerSP.ModelBundle:
    """Build a ModelBundle from one entry of the index's ``bundles`` list.

    Raises KeyError if ``index`` is absent and ValueError/TypeError if a numeric field
    cannot be converted.
    """
    model_bundle = custom.ModelManagerSP.ModelBundle()
    model_bundle.index = int(bundle["index"])
    model_bundle.internalName = bundle.get("short_name") or ""
    model_bundle.displayName = bundle.get("display_name") or ""
    model_bundle.models = [ModelParser._parse_model(m) for m in bundle.get("models") or []]
    model_bundle.status = 0
    model_bundle.generation = ModelParser._as_int(bundle.get("generation"))
    model_bundle.environment = bundle.get("environment") or ""
    model_bundle.runner = ModelParser._runner_from_bundle(bundle)
    model_bundle.is20hz = bundle.get("is_20hz", False)
    model_bundle.ref = bundle.get("ref") or ""
    model_bundle.minimumSelectorVersion = ModelParser._as_int(bundle.get("minimum_selector_version"))
    model_bundle.overrides = ModelParser._parse_overrides(bundle.get("overrides"))
    return model_bundle

  @staticmethod
  def parse_models(json_data: dict) -> list:
    """Parse ``json_data["bundles"]`` and keep only version-compatible bundles.

    Raises if any bundle is malformed; callers are expected to handle that.
    """
    found_bundles = [ModelParser._parse_bundle(b) for b in json_data.get("bundles") or []]
    return [b for b in found_bundles if is_bundle_version_compatible(bundle_to_dict(b))]


class ModelCache:
  """Caches the raw model index JSON in Params to avoid frequent network fetches."""

  def __init__(self, params: Params, cache_timeout_seconds: int = 3600):
    self.params = params
    self.cache_timeout_seconds = cache_timeout_seconds
    self._LAST_SYNC_KEY = "ModelManager_LastSyncTime"
    self._CACHE_KEY = "ModelManager_ModelsCache"

  @staticmethod
  def _decode(raw):
    """Return ``raw`` as text if Params handed back bytes; otherwise return it unchanged."""
    return raw.decode("utf-8") if isinstance(raw, bytes) else raw

  def _is_expired(self) -> bool:
    """True if there is no usable sync timestamp or it is older than the timeout."""
    try:
      raw = self.params.get(self._LAST_SYNC_KEY)
      if not raw:
        return True
      last_sync = int(self._decode(raw))
    except (TypeError, ValueError):
      return True
    # The timestamp is persisted, so it must be wall-clock time: a monotonic clock's origin
    # is not stable across reboots. A negative delta means the wall clock moved backwards
    # (e.g. not yet synced after boot); treat that as expired rather than as fresh.
    elapsed = time.time() - last_sync
    return elapsed < 0 or elapsed >= self.cache_timeout_seconds

  def get(self) -> tuple:
    """Return ``(cached_data, is_expired)``.

    ``cached_data`` is the dict last stored by ``set``. If there is no cache, or it cannot
    be decoded into a JSON object, the result is ``({}, True)``.
    """
    try:
      raw = self.params.get(self._CACHE_KEY)
      if not raw:
        return {}, True
      data = json.loads(self._decode(raw))
      if not isinstance(data, dict):
        # Callers call .get() on the cached index; anything but a JSON object is unusable.
        cloudlog.warning(f"Ignoring cached model data of unexpected type: {type(data).__name__}")
        return {}, True
      return data, self._is_expired()
    except Exception as e:
      cloudlog.exception(f"Error retrieving cached model data: {e}")
      return {}, True

  def set(self, data: dict) -> None:
    """Store ``data`` as JSON and record the current wall-clock time as the sync time."""
    self.params.put(self._CACHE_KEY, json.dumps(data))
    self.params.put(self._LAST_SYNC_KEY, str(int(time.time())))


class ModelFetcher:
  """Fetches and caches model index from remote URL."""

  MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v15.json"
  NETWORK_RETRY_INTERVAL = 300  # 5 minutes between network retry attempts when offline

  def __init__(self, params: Params):
    self.params = params
    self.model_cache = ModelCache(params)
    self.model_parser = ModelParser()
    # In-process only (never persisted), so time.monotonic() is appropriate here.
    self._last_fetch_attempt = 0
    self._fetch_failed = False

  def _fetch_and_cache_models(self) -> list | None:
    """Download the index, parse it, and cache it only if parsing succeeded.

    Returns the parsed bundles, or None on any network or parse failure (logged).
    """
    try:
      response = requests.get(self.MODEL_URL, timeout=10)
      if response.status_code == 404:
        cloudlog.error(f"Models URL returned 404: {self.MODEL_URL}")
        raise HTTPError(f"404 Not Found: {self.MODEL_URL}", response=response)
      response.raise_for_status()
      json_data = response.json()
      if not isinstance(json_data, dict):
        raise ValueError(f"Unexpected models index type: {type(json_data).__name__}")
      # Parse before caching so a malformed index can never overwrite a good cache
      # (and then be served as "valid" until the cache timeout elapses).
      bundles = self.model_parser.parse_models(json_data)
      self.model_cache.set(json_data)
      cloudlog.debug("Successfully updated models cache")
      return bundles
    except (ConnectionError, SSLError, RequestException) as e:
      cloudlog.warning(f"Request error fetching models: {e}")
    except Exception as e:
      cloudlog.exception(f"Error fetching models: {e}")
    return None

  def _parse_cached(self, cached_data: dict) -> list | None:
    """Parse cached index data; return None (and log) if it is malformed."""
    try:
      return self.model_parser.parse_models(cached_data)
    except Exception as e:
      cloudlog.exception(f"Error parsing cached models data: {e}")
      return None

  def _expired_cache_fallback(self, cached_data: dict, message: str) -> list:
    """Serve expired cached bundles, logging ``message``; empty list if unusable."""
    if not cached_data:
      return []
    bundles = self._parse_cached(cached_data)
    if bundles is None:
      return []
    cloudlog.warning(message)
    return bundles

  def get_available_bundles(self) -> list:
    """Return the available model bundles, preferring a fresh cache over the network.

    Order: a fresh, parseable cache; otherwise a network fetch, rate-limited to one attempt
    per NETWORK_RETRY_INTERVAL after a failure; otherwise expired cached data; otherwise [].
    """
    cached_data, is_expired = self.model_cache.get()

    # If cache is valid, use it without attempting network fetch
    if cached_data and not is_expired:
      cached_bundles = self._parse_cached(cached_data)
      if cached_bundles is not None:
        cloudlog.debug("Using valid cached models data")
        self._fetch_failed = False
        return cached_bundles
      # Cached index is unparseable: behave as if there were no cache and try the network.
      cached_data = {}

    # If previous fetch failed, wait before retrying to avoid hammering network
    current_time = time.monotonic()
    if self._fetch_failed and (current_time - self._last_fetch_attempt) < self.NETWORK_RETRY_INTERVAL:
      cloudlog.debug(f"Network fetch failed recently, waiting {self.NETWORK_RETRY_INTERVAL}s before retry")
      return self._expired_cache_fallback(cached_data, "Using expired cache as fallback (network offline)")

    # Attempt to fetch from network
    self._last_fetch_attempt = current_time
    fetched_bundles = self._fetch_and_cache_models()
    if fetched_bundles is not None:
      self._fetch_failed = False
      return fetched_bundles

    # Network fetch failed
    self._fetch_failed = True
    if not cached_data:
      cloudlog.warning("Failed to fetch models and no cache available")
      return []
    return self._expired_cache_fallback(cached_data, "Using expired cache as fallback")