settings_service.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. from __future__ import annotations
  2. from pathlib import Path
  3. from typing import Any
  4. from fastapi import HTTPException
  5. from .database import DATA_DIR, get_db
  6. from .scanner import now_iso
  7. from .schemas import SystemSettingsUpdate
  8. SETTING_KEYS = {
  9. "default_ai_provider_id",
  10. "default_ai_model_id",
  11. "default_ai_temperature",
  12. "automation_file_root",
  13. "automation_screen_path",
  14. "automation_error_path",
  15. "automation_runtime_path",
  16. "automation_auto_screenshot_enabled",
  17. "automation_auto_screenshot_interval",
  18. }
  19. def list_settings() -> dict[str, Any]:
  20. with get_db() as conn:
  21. rows = conn.execute("SELECT * FROM app_settings ORDER BY key ASC").fetchall()
  22. values = {row["key"]: row["value"] for row in rows}
  23. return {"settings": normalize_settings(values), "items": rows}
  24. def update_settings(payload: SystemSettingsUpdate) -> dict[str, Any]:
  25. values = payload.model_dump(exclude_unset=True)
  26. now = now_iso()
  27. with get_db() as conn:
  28. for key, value in values.items():
  29. if key not in SETTING_KEYS:
  30. continue
  31. if key.endswith("_path") or key == "automation_file_root":
  32. value = normalize_relative_path(value)
  33. conn.execute(
  34. """
  35. INSERT INTO app_settings (key, value, description, updated_at)
  36. VALUES (?, ?, COALESCE((SELECT description FROM app_settings WHERE key = ?), ''), ?)
  37. ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at
  38. """,
  39. (key, serialize_value(value), key, now),
  40. )
  41. return list_settings()
  42. def normalize_settings(values: dict[str, str | None]) -> dict[str, Any]:
  43. return {
  44. "default_ai_provider_id": optional_int(values.get("default_ai_provider_id")),
  45. "default_ai_model_id": optional_int(values.get("default_ai_model_id")),
  46. "default_ai_temperature": optional_float(values.get("default_ai_temperature"), 0.1),
  47. "automation_file_root": values.get("automation_file_root") or "automation",
  48. "automation_screen_path": values.get("automation_screen_path") or "automation/screens",
  49. "automation_error_path": values.get("automation_error_path") or "automation/errors",
  50. "automation_runtime_path": values.get("automation_runtime_path") or "automation/runtime",
  51. "automation_auto_screenshot_enabled": parse_bool(values.get("automation_auto_screenshot_enabled")),
  52. "automation_auto_screenshot_interval": optional_int(values.get("automation_auto_screenshot_interval"), 30),
  53. }
  54. def get_settings_dict() -> dict[str, Any]:
  55. return list_settings()["settings"]
  56. def default_ai_params() -> dict[str, Any]:
  57. settings = get_settings_dict()
  58. return {
  59. "provider_id": settings.get("default_ai_provider_id"),
  60. "model_id": settings.get("default_ai_model_id"),
  61. "temperature": settings.get("default_ai_temperature", 0.1),
  62. }
  63. def resolve_data_path(setting_key: str, fallback: str) -> Path:
  64. settings = get_settings_dict()
  65. relative = settings.get(setting_key) or fallback
  66. normalized = normalize_relative_path(relative)
  67. path = (DATA_DIR / normalized).resolve()
  68. data_root = DATA_DIR.resolve()
  69. if data_root != path and data_root not in path.parents:
  70. raise HTTPException(status_code=400, detail="Configured path escapes data directory")
  71. path.mkdir(parents=True, exist_ok=True)
  72. return path
  73. def normalize_relative_path(value: str | None) -> str:
  74. raw = (value or "").strip().replace("\\", "/")
  75. if not raw:
  76. return ""
  77. path = Path(raw)
  78. if path.is_absolute() or ".." in path.parts:
  79. raise HTTPException(status_code=400, detail="Path must be relative and must not contain ..")
  80. return "/".join(part for part in path.parts if part not in {"", "."})
  81. def optional_int(value: Any, default: int | None = None) -> int | None:
  82. if value in (None, ""):
  83. return default
  84. try:
  85. return int(value)
  86. except (TypeError, ValueError):
  87. return default
  88. def optional_float(value: Any, default: float | None = None) -> float | None:
  89. if value in (None, ""):
  90. return default
  91. try:
  92. return float(value)
  93. except (TypeError, ValueError):
  94. return default
  95. def parse_bool(value: Any) -> bool:
  96. return str(value).lower() in {"1", "true", "yes", "on"}
  97. def serialize_value(value: Any) -> str:
  98. if value is None:
  99. return ""
  100. if isinstance(value, bool):
  101. return "1" if value else "0"
  102. return str(value)