from __future__ import annotations import json import re import sqlite3 from typing import Any import httpx from fastapi import HTTPException from pydantic import ValidationError from . import ai_gemini, ai_openai from .database import get_db from .scanner import now_iso from .schemas import ( AiImportItem, AiModelCreate, AiModelUpdate, AiProviderCreate, AiProviderUpdate, ) def public_provider(row: dict[str, Any]) -> dict[str, Any]: item = dict(row) item["enabled"] = bool(item["enabled"]) item["api_key_set"] = bool(item.get("api_key")) item.pop("api_key", None) return item def public_model(row: dict[str, Any]) -> dict[str, Any]: item = dict(row) item["is_default"] = bool(item["is_default"]) return item def list_providers() -> list[dict[str, Any]]: with get_db() as conn: rows = conn.execute("SELECT * FROM ai_providers ORDER BY name ASC").fetchall() return [public_provider(row) for row in rows] def create_provider(payload: AiProviderCreate) -> dict[str, Any]: now = now_iso() try: with get_db() as conn: cursor = conn.execute( """ INSERT INTO ai_providers (name, provider_type, base_url, api_key, enabled, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?) """, ( payload.name.strip(), payload.provider_type, clean_optional(payload.base_url), clean_optional(payload.api_key), 1 if payload.enabled else 0, now, now, ), ) row = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (cursor.lastrowid,)).fetchone() except sqlite3.IntegrityError as exc: raise HTTPException(status_code=409, detail="AI provider name already exists") from exc return public_provider(row) def update_provider(provider_id: int, payload: AiProviderUpdate) -> dict[str, Any]: now = now_iso() with get_db() as conn: existing = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone() if not existing: raise HTTPException(status_code=404, detail="AI provider not found") api_key = existing.get("api_key") if payload.clear_api_key: api_key = None elif payload.api_key is not None and payload.api_key != "": api_key = payload.api_key try: conn.execute( """ UPDATE ai_providers SET name = ?, provider_type = ?, base_url = ?, api_key = ?, enabled = ?, updated_at = ? WHERE id = ? """, ( payload.name.strip(), payload.provider_type, clean_optional(payload.base_url), clean_optional(api_key), 1 if payload.enabled else 0, now, provider_id, ), ) except sqlite3.IntegrityError as exc: raise HTTPException(status_code=409, detail="AI provider name already exists") from exc row = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone() return public_provider(row) def delete_provider(provider_id: int) -> dict[str, Any]: with get_db() as conn: cursor = conn.execute("DELETE FROM ai_providers WHERE id = ?", (provider_id,)) if cursor.rowcount == 0: raise HTTPException(status_code=404, detail="AI provider not found") return {"deleted": cursor.rowcount} def list_models(provider_id: int | None = None) -> list[dict[str, Any]]: where = "WHERE m.provider_id = ?" if provider_id else "" params = [provider_id] if provider_id else [] with get_db() as conn: rows = conn.execute( f""" SELECT m.*, p.name AS provider_name, p.provider_type FROM ai_models m JOIN ai_providers p ON p.id = m.provider_id {where} ORDER BY p.name ASC, m.is_default DESC, m.name ASC """, params, ).fetchall() return [public_model(row) for row in rows] def create_model(payload: AiModelCreate) -> dict[str, Any]: now = now_iso() with get_db() as conn: ensure_provider_exists(conn, payload.provider_id) try: cursor = conn.execute( """ INSERT INTO ai_models (provider_id, name, display_name, is_default, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?) """, ( payload.provider_id, payload.name.strip(), clean_optional(payload.display_name), 1 if payload.is_default else 0, now, now, ), ) except sqlite3.IntegrityError as exc: raise HTTPException(status_code=409, detail="AI model already exists for this provider") from exc if payload.is_default: clear_other_default_models(conn, payload.provider_id, cursor.lastrowid) row = get_model_row(conn, cursor.lastrowid) return public_model(row) def update_model(model_id: int, payload: AiModelUpdate) -> dict[str, Any]: now = now_iso() with get_db() as conn: ensure_provider_exists(conn, payload.provider_id) if not conn.execute("SELECT id FROM ai_models WHERE id = ?", (model_id,)).fetchone(): raise HTTPException(status_code=404, detail="AI model not found") try: conn.execute( """ UPDATE ai_models SET provider_id = ?, name = ?, display_name = ?, is_default = ?, updated_at = ? WHERE id = ? """, ( payload.provider_id, payload.name.strip(), clean_optional(payload.display_name), 1 if payload.is_default else 0, now, model_id, ), ) except sqlite3.IntegrityError as exc: raise HTTPException(status_code=409, detail="AI model already exists for this provider") from exc if payload.is_default: clear_other_default_models(conn, payload.provider_id, model_id) row = get_model_row(conn, model_id) return public_model(row) def delete_model(model_id: int) -> dict[str, Any]: with get_db() as conn: cursor = conn.execute("DELETE FROM ai_models WHERE id = ?", (model_id,)) if cursor.rowcount == 0: raise HTTPException(status_code=404, detail="AI model not found") return {"deleted": cursor.rowcount} def chat(provider_id: int, model_id: int, prompt: str, temperature: float) -> dict[str, Any]: provider, model = get_provider_and_model(provider_id, model_id) try: if provider["provider_type"] in {"OPENAI", "OPENAI_COMPATIBLE"}: result = ai_openai.chat(provider, model, prompt, temperature) elif provider["provider_type"] == "GOOGLE_GEMINI": result = ai_gemini.chat(provider, model, prompt, temperature) else: raise HTTPException(status_code=400, detail="Unsupported AI provider type") except httpx.HTTPStatusError as exc: detail = exc.response.text[:1000] if exc.response is not None else str(exc) raise HTTPException(status_code=502, detail=f"AI provider returned an error: {detail}") from exc except httpx.HTTPError as exc: raise HTTPException(status_code=502, detail=f"AI provider request failed: {exc}") from exc return { "provider": public_provider(provider), "model": public_model(model), **result, } def chat_with_images( provider_id: int, model_id: int, prompt: str, images: list[dict[str, str]], temperature: float, ) -> dict[str, Any]: provider, model = get_provider_and_model(provider_id, model_id) try: if provider["provider_type"] in {"OPENAI", "OPENAI_COMPATIBLE"}: result = ai_openai.chat_with_images(provider, model, prompt, images, temperature) elif provider["provider_type"] == "GOOGLE_GEMINI": result = ai_gemini.chat_with_images(provider, model, prompt, images, temperature) else: raise HTTPException(status_code=400, detail="Unsupported AI provider type") except httpx.HTTPStatusError as exc: detail = exc.response.text[:1000] if exc.response is not None else str(exc) raise HTTPException(status_code=502, detail=f"AI provider returned an error: {detail}") from exc except httpx.HTTPError as exc: raise HTTPException(status_code=502, detail=f"AI provider request failed: {exc}") from exc return { "provider": public_provider(provider), "model": public_model(model), **result, } def parse_ai_items(content: str) -> list[dict[str, Any]]: parsed = json.loads(extract_json_text(content)) items = parsed.get("items") if isinstance(parsed, dict) else parsed if not isinstance(items, list): raise ValueError("AI output must be a JSON array or an object containing items") validated = [] for item in items: validated.append(AiImportItem.model_validate(item).model_dump()) return validated def extract_json_text(content: str) -> str: text = content.strip() fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE) if fenced: text = fenced.group(1).strip() if text.startswith("[") or text.startswith("{"): return text start_candidates = [index for index in [text.find("["), text.find("{")] if index >= 0] if not start_candidates: return text start = min(start_candidates) end = max(text.rfind("]"), text.rfind("}")) return text[start : end + 1] if end > start else text[start:] def get_provider_and_model(provider_id: int, model_id: int) -> tuple[dict[str, Any], dict[str, Any]]: with get_db() as conn: provider = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone() model = conn.execute("SELECT * FROM ai_models WHERE id = ?", (model_id,)).fetchone() if not provider: raise HTTPException(status_code=404, detail="AI provider not found") if not provider["enabled"]: raise HTTPException(status_code=400, detail="AI provider is disabled") if not model or model["provider_id"] != provider_id: raise HTTPException(status_code=400, detail="AI model does not belong to this provider") return provider, model def ensure_provider_exists(conn, provider_id: int) -> None: if not conn.execute("SELECT id FROM ai_providers WHERE id = ?", (provider_id,)).fetchone(): raise HTTPException(status_code=404, detail="AI provider not found") def get_model_row(conn, model_id: int) -> dict[str, Any]: row = conn.execute( """ SELECT m.*, p.name AS provider_name, p.provider_type FROM ai_models m JOIN ai_providers p ON p.id = m.provider_id WHERE m.id = ? """, (model_id,), ).fetchone() if not row: raise HTTPException(status_code=404, detail="AI model not found") return row def clear_other_default_models(conn, provider_id: int, model_id: int) -> None: conn.execute( "UPDATE ai_models SET is_default = 0, updated_at = ? WHERE provider_id = ? AND id <> ?", (now_iso(), provider_id, model_id), ) def clean_optional(value: str | None) -> str | None: if value is None: return None stripped = value.strip() return stripped or None