208 lines
7.8 KiB
Python
Executable File
208 lines
7.8 KiB
Python
Executable File
"""
Model Manager fetcher: fetch and cache model index from remote JSON.

Ported from sunnypilot.
"""
|
|
|
|
import json
|
|
import time
|
|
|
|
import requests
|
|
from requests.exceptions import SSLError, RequestException, HTTPError
|
|
|
|
from openpilot.common.params import Params
|
|
from openpilot.common.swaglog import cloudlog
|
|
from openpilot.selfdrive.models_manager.helpers import is_bundle_version_compatible, bundle_to_dict
|
|
|
|
from cereal import custom
|
|
|
|
|
|
class ModelParser:
  """Parses model index JSON into cereal ModelManagerSP structs.

  Every helper is a staticmethod, so the class acts as a namespace. Optional
  fields that are absent *or* explicitly JSON ``null`` fall back to empty
  strings, empty lists, or zero instead of raising.
  """

  @staticmethod
  def _to_int(value, default: int = 0) -> int:
    """Return ``int(value)``, or *default* when *value* is ``None`` or ``""``.

    Other non-numeric values raise ValueError/TypeError, exactly as ``int()`` does.
    """
    if value is None or value == "":
      return default
    return int(value)

  @staticmethod
  def _parse_download_uri(download_uri_data) -> custom.ModelManagerSP.DownloadUri:
    """Build a DownloadUri from a dict with optional "url" and "sha256" keys (missing -> "")."""
    data = download_uri_data or {}
    download_uri = custom.ModelManagerSP.DownloadUri()
    download_uri.uri = data.get("url") or ""
    download_uri.sha256 = data.get("sha256") or ""
    return download_uri

  @staticmethod
  def _parse_artifact(artifact_data) -> custom.ModelManagerSP.Artifact:
    """Build an Artifact from a dict with optional "file_name" and "download_uri" keys."""
    data = artifact_data or {}
    artifact = custom.ModelManagerSP.Artifact()
    artifact.fileName = data.get("file_name") or ""
    # `or {}` covers an explicit JSON null as well as a missing key.
    artifact.downloadUri = ModelParser._parse_download_uri(data.get("download_uri") or {})
    return artifact

  @staticmethod
  def _model_type_from_str(t: str):
    """Map a case-insensitive type name to Model.Type; unknown or non-str values -> supercombo."""
    model_type = custom.ModelManagerSP.Model.Type
    if not isinstance(t, str):
      return model_type.supercombo
    known_types = {
      "vision": model_type.vision,
      "policy": model_type.policy,
      "navigation": model_type.navigation,
      "offpolicy": model_type.offPolicy,
    }
    return known_types.get(t.lower(), model_type.supercombo)

  @staticmethod
  def _parse_model(model_data) -> custom.ModelManagerSP.Model:
    """Build a Model (type, artifact, metadata) from one entry of a bundle's "models" list."""
    model = custom.ModelManagerSP.Model()
    model.type = ModelParser._model_type_from_str(model_data.get("type"))
    model.artifact = ModelParser._parse_artifact(model_data.get("artifact") or {})
    # Metadata is optional; an empty Artifact is stored when it is absent.
    model.metadata = ModelParser._parse_artifact(model_data.get("metadata") or {})
    return model

  @staticmethod
  def _parse_overrides(overrides_data: dict) -> list:
    """Convert a ``{key: value}`` dict into a list of Override structs (None -> [])."""
    overrides = []
    for key, value in (overrides_data or {}).items():
      override = custom.ModelManagerSP.Override()
      override.key = key
      override.value = value
      overrides.append(override)
    return overrides

  @staticmethod
  def _runner_from_bundle(bundle) -> custom.ModelManagerSP.Runner:
    """Resolve a bundle's "runner" field.

    The strings "tinygrad" and "stock" (case-insensitive) map to those runners,
    and any other string maps to snpe. A missing field defaults to stock, and
    a non-string value is returned unchanged.
    """
    runner = custom.ModelManagerSP.Runner
    r = bundle.get("runner", runner.stock)
    if not isinstance(r, str):
      return r
    name = r.lower()
    if name == "tinygrad":
      return runner.tinygrad
    if name == "stock":
      return runner.stock
    return runner.snpe

  @staticmethod
  def _parse_bundle(bundle) -> custom.ModelManagerSP.ModelBundle:
    """Build a ModelBundle from one entry of the index's "bundles" list.

    "index" is required (KeyError if missing). All other fields are optional,
    and null/empty values fall back to ""/0/False/[].
    """
    model_bundle = custom.ModelManagerSP.ModelBundle()
    model_bundle.index = int(bundle["index"])
    model_bundle.internalName = bundle.get("short_name") or ""
    model_bundle.displayName = bundle.get("display_name") or ""
    model_bundle.models = [ModelParser._parse_model(m) for m in bundle.get("models") or []]
    model_bundle.status = 0  # not read from the index JSON; always starts at 0
    model_bundle.generation = ModelParser._to_int(bundle.get("generation"))
    model_bundle.environment = bundle.get("environment") or ""
    model_bundle.runner = ModelParser._runner_from_bundle(bundle)
    model_bundle.is20hz = bundle.get("is_20hz") or False
    model_bundle.ref = bundle.get("ref") or ""
    model_bundle.minimumSelectorVersion = ModelParser._to_int(bundle.get("minimum_selector_version"))
    model_bundle.overrides = ModelParser._parse_overrides(bundle.get("overrides"))
    return model_bundle

  @staticmethod
  def parse_models(json_data: dict) -> list:
    """Parse every entry of json_data["bundles"] and keep only version-compatible bundles."""
    found_bundles = [ModelParser._parse_bundle(b) for b in json_data.get("bundles") or []]
    return [b for b in found_bundles if is_bundle_version_compatible(bundle_to_dict(b))]
|
|
|
|
|
|
class ModelCache:
  """Caches the model index JSON in Params to avoid frequent network fetches.

  Two Params keys are used: one holds the JSON-serialized index, the other
  the wall-clock (Unix epoch) second at which it was stored.
  """

  def __init__(self, params: Params, cache_timeout_seconds: int = 3600):
    self.params = params
    self.cache_timeout_seconds = cache_timeout_seconds
    self._LAST_SYNC_KEY = "ModelManager_LastSyncTime"
    self._CACHE_KEY = "ModelManager_ModelsCache"

  @staticmethod
  def _decode(raw):
    """Decode *raw* as UTF-8 if it is bytes; return any other value unchanged."""
    return raw.decode("utf-8") if isinstance(raw, bytes) else raw

  def _is_expired(self) -> bool:
    """Return True if no valid sync stamp exists or it is older than the timeout."""
    try:
      raw = self.params.get(self._LAST_SYNC_KEY)
      if not raw:
        return True
      last_sync = int(self._decode(raw))
    except (TypeError, ValueError):
      return True
    # Wall-clock time is required here: the stamp lives in persistent Params,
    # and time.monotonic() has no meaning across reboots/processes.
    age = time.time() - last_sync
    # A stamp in the future (e.g. the clock moved backwards) cannot be trusted.
    return age < 0 or age >= self.cache_timeout_seconds

  def get(self) -> tuple:
    """Returns (cached_data dict, is_expired); ({}, True) when absent or unreadable."""
    try:
      raw = self.params.get(self._CACHE_KEY)
      if not raw:
        return {}, True
      # Params may hand back bytes, str, or an already-decoded dict.
      data = raw if isinstance(raw, dict) else json.loads(self._decode(raw))
      return data, self._is_expired()
    except Exception as e:
      cloudlog.exception(f"Error retrieving cached model data: {e}")
      return {}, True

  def set(self, data: dict) -> None:
    """Store *data* as JSON and stamp the current wall-clock second."""
    self.params.put(self._CACHE_KEY, json.dumps(data))
    self.params.put(self._LAST_SYNC_KEY, str(int(time.time())))
|
|
|
|
|
|
class ModelFetcher:
  """Fetches and caches the model index from a remote URL.

  A fresh Params cache is used without touching the network. Otherwise a
  fetch is attempted, throttled to one attempt per NETWORK_RETRY_INTERVAL
  seconds after a failure. An expired cache is used as a fallback.
  """

  MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v15.json"
  NETWORK_RETRY_INTERVAL = 300  # 5 minutes between network retry attempts when offline

  def __init__(self, params: Params):
    self.params = params
    self.model_cache = ModelCache(params)
    self.model_parser = ModelParser()
    self._last_fetch_attempt = 0  # time.monotonic() of the last network attempt
    self._fetch_failed = False

  def _fetch_and_cache_models(self) -> list | None:
    """Download, parse, and cache the model index.

    Returns the compatible bundles, or None on any failure (errors are logged).
    The cache is written only after the JSON parsed successfully, so an index
    that cannot be parsed is never stored as a "fresh" cache.
    """
    try:
      response = requests.get(self.MODEL_URL, timeout=10)
      if response.status_code == 404:
        cloudlog.error(f"Models URL returned 404: {self.MODEL_URL}")
        raise HTTPError(f"404 Not Found: {self.MODEL_URL}", response=response)
      response.raise_for_status()
      json_data = response.json()
      bundles = self.model_parser.parse_models(json_data)
      self.model_cache.set(json_data)
      cloudlog.debug("Successfully updated models cache")
      return bundles
    except (ConnectionError, SSLError, RequestException) as e:
      cloudlog.warning(f"Request error fetching models: {e}")
    except Exception as e:
      cloudlog.exception(f"Error fetching models: {e}")
    return None

  def _parse_cached(self, cached_data: dict) -> list | None:
    """Parse cached index JSON; return None (after logging) if it cannot be parsed."""
    try:
      return self.model_parser.parse_models(cached_data)
    except Exception as e:
      cloudlog.exception(f"Error parsing cached model data: {e}")
      return None

  def _expired_cache_fallback(self, cached_data: dict, message: str) -> list:
    """Return bundles from an expired cache (logging *message*), or [] if unusable."""
    if not cached_data:
      return []
    bundles = self._parse_cached(cached_data)
    if bundles is None:
      return []
    cloudlog.warning(message)
    return bundles

  def get_available_bundles(self) -> list:
    """Return compatible model bundles, preferring a fresh cache over the network."""
    cached_data, is_expired = self.model_cache.get()

    # If cache is valid, use it without attempting network fetch
    if cached_data and not is_expired:
      cached_bundles = self._parse_cached(cached_data)
      if cached_bundles is not None:
        cloudlog.debug("Using valid cached models data")
        self._fetch_failed = False
        return cached_bundles
      cached_data = {}  # unparseable cache: ignore it and try the network instead

    # If previous fetch failed, wait before retrying to avoid hammering network
    current_time = time.monotonic()
    if self._fetch_failed and (current_time - self._last_fetch_attempt) < self.NETWORK_RETRY_INTERVAL:
      cloudlog.debug(f"Network fetch failed recently, waiting {self.NETWORK_RETRY_INTERVAL}s before retry")
      return self._expired_cache_fallback(cached_data, "Using expired cache as fallback (network offline)")

    # Attempt to fetch from network
    self._last_fetch_attempt = current_time
    fetched_bundles = self._fetch_and_cache_models()
    if fetched_bundles is not None:
      self._fetch_failed = False
      return fetched_bundles

    # Network fetch failed
    self._fetch_failed = True
    if not cached_data:
      cloudlog.warning("Failed to fetch models and no cache available")
      return []
    return self._expired_cache_fallback(cached_data, "Using expired cache as fallback")
|