# automation_service.py
from __future__ import annotations

import base64
import json
import math
import mimetypes
import time
from pathlib import Path
from typing import Any

import psutil
from fastapi import HTTPException

from . import ai_service, settings_service, windows_automation
from .database import DATA_DIR, get_db
from .scanner import now_iso
from .schemas import (
    AutomationKeyboardActionRequest,
    AutomationMouseActionRequest,
    AutomationElementLocateRequest,
    AutomationScreenshotCaptureRequest,
    AutomationStartProgramRequest,
    AutomationTextInputRequest,
    AutomationVisionAnalyzeRequest,
    AutomationWorkflowRunRequest,
    AutomationWorkflowSaveRequest,
)
# Automation storage locations expressed relative to DATA_DIR.
# NOTE(review): the functions visible in this file resolve their folders through
# settings_service.resolve_data_path() instead of these constants — confirm
# whether other modules still import them before removing anything.
AUTOMATION_DIR = DATA_DIR / "automation"
SCREEN_DIR = AUTOMATION_DIR / "screens"
ERROR_DIR = AUTOMATION_DIR / "errors"
RUNTIME_DIR = AUTOMATION_DIR / "runtime"
# PIDs of processes that appeared during automation actions. Filled by
# diff_new_processes()/execute_start_program() and drained by close_opened_programs().
OPENED_PROCESS_IDS: set[int] = set()
# Prompt used by analyze_screen(): asks the vision model for a JSON object with
# interface_name, description, is_windows_desktop, is_browser_webpage and an
# elements array (name + approximate textual location, no coordinates).
SCREEN_ANALYZE_PROMPT = """请作为 AI 视觉自动化助手分析这张 Windows 屏幕截图,并严格只输出 JSON 对象。
输出字段:
- interface_name:界面名称,简洁中文。
- description:界面描述,说明当前主要窗口或桌面内容。
- is_windows_desktop:boolean,截图是否处于 Windows 桌面。
- is_browser_webpage:boolean,截图是否为浏览器中的网页。
- elements:可操作元素数组。
元素字段:
- name:元素名称。
- approximate_location:元素在界面中的大致位置文字描述,例如“窗口右上角”“左侧导航栏中部”“底部任务栏靠左”。不要输出具体坐标或百分比。
判断规则:
1. 如果截图位于 Windows 桌面,请识别桌面图标、开始菜单入口、任务栏应用、托盘区域等可操作元素。
2. 如果不是 Windows 桌面,也就是存在打开的前台窗口或全屏界面,只识别该前台窗口内的可操作元素,不要识别被遮挡的桌面元素。
3. 不要输出 Markdown,不要解释,只输出 JSON。
"""
# Prompt used by locate_element(): the {name}, {approximate_location} and
# {screen_description} placeholders are filled with str.replace(), and the model
# is asked for has_element / x_percent / y_percent / reason as JSON.
ELEMENT_LOCATE_PROMPT = """请作为 AI 视觉定位助手,在这张 Windows 屏幕截图中查找一个具体的可操作元素。
目标元素名称:
{name}
目标元素大致位置描述:
{approximate_location}
所在界面描述:
{screen_description}
请严格只输出 JSON 对象,字段为:
- has_element:boolean,图片中是否能找到该目标元素。
- x_percent:元素中心点 X 相对整张截图宽度的百分比,范围 0-100,可以保留 2 位小数。找不到时为 null。
- y_percent:元素中心点 Y 相对整张截图高度的百分比,范围 0-100,可以保留 2 位小数。找不到时为 null。
- reason:简短中文原因。
只定位这个目标元素,不要列出其他元素。不要输出 Markdown,不要解释,只输出 JSON。
"""
# Prompt used by validate_screen_before_action(): image 1 is the live screenshot,
# image 2 the stored target screenshot; {description} is filled with str.replace()
# and the model is asked for is_match / similarity / reason as JSON.
SCREEN_COMPARE_PROMPT = """请作为 AI 视觉自动化校验器判断两张截图是否处于同一个目标界面。
图片1是当前实际屏幕截图。图片2是数据库中保存的目标界面截图。
目标界面描述如下:
{description}
请严格只输出 JSON 对象,字段为:
- is_match:boolean,图片1是否仍然处于目标界面。
- similarity:0 到 1 的数值,表示相似度。
- reason:简短中文原因。
判断时可以允许小的光标位置、时间、列表内容滚动或轻微刷新差异,但如果前台窗口、网页、弹窗、主要页面或应用已经不同,应返回 false。
"""
  68. def ensure_dirs() -> None:
  69. """确保自动化截图、错误截图和运行时目录存在。"""
  70. for path in [screen_dir(), error_dir(), runtime_dir()]:
  71. path.mkdir(parents=True, exist_ok=True)
  72. def screen_dir() -> Path:
  73. """根据系统设置获取已识别界面截图目录。"""
  74. return settings_service.resolve_data_path("automation_screen_path", "automation/screens")
  75. def error_dir() -> Path:
  76. """根据系统设置获取错误截图目录。"""
  77. return settings_service.resolve_data_path("automation_error_path", "automation/errors")
  78. def runtime_dir() -> Path:
  79. """根据系统设置获取临时截图目录。"""
  80. return settings_service.resolve_data_path("automation_runtime_path", "automation/runtime")
  81. def image_to_base64(path: str | Path) -> dict[str, str]:
  82. """读取图片文件并转为 AI 服务可接收的 base64 结构。"""
  83. file_path = stored_path(path)
  84. mime_type = mimetypes.guess_type(file_path.name)[0] or "image/png"
  85. return {
  86. "base64": base64.b64encode(file_path.read_bytes()).decode("ascii"),
  87. "mime_type": mime_type,
  88. }
  89. def json_from_ai(content: str) -> dict[str, Any]:
  90. """从 AI 输出中提取 JSON 对象,兼容模型误加代码块的情况。"""
  91. parsed = json.loads(ai_service.extract_json_text(content))
  92. if not isinstance(parsed, dict):
  93. raise ValueError("AI output must be a JSON object")
  94. return parsed
  95. def take_screenshot_file(folder: Path, prefix: str) -> dict[str, Any]:
  96. """截取当前屏幕并保存为 PNG 文件,同时返回 base64 和分辨率信息。"""
  97. ensure_dirs()
  98. filename = f"{prefix}_{int(time.time() * 1000)}.png"
  99. path = folder / filename
  100. result = windows_automation.take_screenshot(str(path), include_base64=True)
  101. result["path"] = str(path)
  102. result["db_path"] = data_relative_path(path)
  103. return result
  104. def data_relative_path(path: str | Path) -> str:
  105. """把 data 目录下的文件路径转换为数据库保存用的相对路径。"""
  106. file_path = Path(path).resolve()
  107. try:
  108. return file_path.relative_to(DATA_DIR.resolve()).as_posix()
  109. except ValueError:
  110. return str(file_path)
  111. def stored_path(path: str | Path) -> Path:
  112. """把数据库中的相对路径还原成真实文件路径,同时兼容旧的绝对路径。"""
  113. file_path = Path(path)
  114. if file_path.is_absolute():
  115. return file_path
  116. return (DATA_DIR / file_path).resolve()
  117. def resolve_ai_params(
  118. provider_id: int | None,
  119. model_id: int | None,
  120. temperature: float | None,
  121. ) -> tuple[int, int, float]:
  122. """合并请求参数和系统默认 AI 参数。"""
  123. defaults = settings_service.default_ai_params()
  124. resolved_provider = provider_id or defaults.get("provider_id")
  125. resolved_model = model_id or defaults.get("model_id")
  126. resolved_temperature = temperature if temperature is not None else defaults.get("temperature", 0.1)
  127. if not resolved_provider or not resolved_model:
  128. raise HTTPException(status_code=400, detail="AI provider and model are required. Configure system defaults or pass them explicitly.")
  129. return int(resolved_provider), int(resolved_model), float(resolved_temperature)
  130. def capture_screenshot(payload: AutomationScreenshotCaptureRequest) -> dict[str, Any]:
  131. """截取当前屏幕并返回给前端显示,不进行 AI 分析。"""
  132. if payload.save:
  133. screenshot = take_screenshot_file(runtime_dir(), "manual_screenshot")
  134. else:
  135. screenshot = windows_automation.take_screenshot(None, include_base64=True)
  136. screenshot["path"] = None
  137. screenshot["db_path"] = None
  138. return {
  139. "width": screenshot["width"],
  140. "height": screenshot["height"],
  141. "image_base64": screenshot["image_base64"],
  142. "mime_type": screenshot["mime_type"],
  143. "path": screenshot.get("db_path"),
  144. }
  145. def analyze_screen(payload: AutomationVisionAnalyzeRequest) -> dict[str, Any]:
  146. """截图当前屏幕,调用 AI 识别界面和可操作元素,并保存识别结果。"""
  147. provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature)
  148. screenshot = take_screenshot_file(screen_dir(), "screen")
  149. image = image_to_base64(screenshot["path"])
  150. ai_result = ai_service.chat_with_images(
  151. provider_id,
  152. model_id,
  153. SCREEN_ANALYZE_PROMPT,
  154. [image],
  155. temperature,
  156. )
  157. try:
  158. parsed = json_from_ai(ai_result["content"])
  159. except (json.JSONDecodeError, ValueError) as exc:
  160. raise HTTPException(status_code=502, detail=f"AI vision output is not valid JSON: {exc}") from exc
  161. width = int(screenshot["width"])
  162. height = int(screenshot["height"])
  163. elements = normalize_elements(parsed.get("elements"), width, height)
  164. now = now_iso()
  165. with get_db() as conn:
  166. cursor = conn.execute(
  167. """
  168. INSERT INTO automation_screens (
  169. interface_name, description, image_path, width, height,
  170. is_windows_desktop, is_browser_webpage, raw_ai_json, created_at, updated_at
  171. )
  172. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  173. """,
  174. (
  175. str(parsed.get("interface_name") or "未命名界面")[:160],
  176. parsed.get("description"),
  177. screenshot["db_path"],
  178. width,
  179. height,
  180. 1 if bool(parsed.get("is_windows_desktop")) else 0,
  181. 1 if bool(parsed.get("is_browser_webpage")) else 0,
  182. json.dumps(parsed, ensure_ascii=False),
  183. now,
  184. now,
  185. ),
  186. )
  187. screen_id = cursor.lastrowid
  188. for index, element in enumerate(elements, start=1):
  189. conn.execute(
  190. """
  191. INSERT INTO automation_screen_elements (
  192. screen_id, element_index, name, x_percent, y_percent, x, y,
  193. approximate_location, is_located, raw_json, created_at
  194. )
  195. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  196. """,
  197. (
  198. screen_id,
  199. index,
  200. element["name"],
  201. element["x_percent"],
  202. element["y_percent"],
  203. element["x"],
  204. element["y"],
  205. element["approximate_location"],
  206. 1 if element["is_located"] else 0,
  207. json.dumps(element.get("raw") or element, ensure_ascii=False),
  208. now,
  209. ),
  210. )
  211. detail = get_screen(screen_id)
  212. detail["image_base64"] = screenshot["image_base64"]
  213. detail["mime_type"] = screenshot["mime_type"]
  214. detail["ai_raw_content"] = ai_result["content"]
  215. return detail
  216. def normalize_elements(raw_elements: Any, width: int, height: int) -> list[dict[str, Any]]:
  217. """规范化 AI 返回的可操作元素清单;初始分析阶段不要求坐标。"""
  218. if not isinstance(raw_elements, list):
  219. return []
  220. result = []
  221. for item in raw_elements:
  222. if not isinstance(item, dict):
  223. continue
  224. name = str(item.get("name") or f"元素 {len(result) + 1}")[:160]
  225. approximate_location = str(item.get("approximate_location") or item.get("location") or "未定位")[:300]
  226. x_percent = normalize_percent(item.get("x_percent")) if item.get("x_percent") is not None else 0.0
  227. y_percent = normalize_percent(item.get("y_percent")) if item.get("y_percent") is not None else 0.0
  228. is_located = item.get("x_percent") is not None and item.get("y_percent") is not None
  229. x = round(width * x_percent / 100)
  230. y = round(height * y_percent / 100)
  231. result.append(
  232. {
  233. "name": name,
  234. "x_percent": x_percent,
  235. "y_percent": y_percent,
  236. "x": max(0, min(width - 1, x)),
  237. "y": max(0, min(height - 1, y)),
  238. "approximate_location": approximate_location,
  239. "is_located": is_located,
  240. "raw": item,
  241. }
  242. )
  243. return result
  244. def locate_element(screen_id: int, element_id: int, payload: AutomationElementLocateRequest) -> dict[str, Any]:
  245. """针对单个可操作元素调用 AI 精确定位,并更新该元素的像素坐标。"""
  246. provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature)
  247. screen = get_screen(screen_id)
  248. element = next((item for item in screen.get("elements", []) if item["id"] == element_id), None)
  249. if not element:
  250. raise HTTPException(status_code=404, detail="Automation screen element not found")
  251. prompt = (
  252. ELEMENT_LOCATE_PROMPT
  253. .replace("{name}", element.get("name") or "")
  254. .replace("{approximate_location}", element.get("approximate_location") or "")
  255. .replace("{screen_description}", screen.get("description") or screen.get("interface_name") or "")
  256. )
  257. ai_result = ai_service.chat_with_images(
  258. provider_id,
  259. model_id,
  260. prompt,
  261. [image_to_base64(screen["image_path"])],
  262. temperature,
  263. )
  264. try:
  265. parsed = json_from_ai(ai_result["content"])
  266. except (json.JSONDecodeError, ValueError) as exc:
  267. raise HTTPException(status_code=502, detail=f"AI locate output is not valid JSON: {exc}") from exc
  268. if not bool(parsed.get("has_element")) or parsed.get("x_percent") is None or parsed.get("y_percent") is None:
  269. return {"located": False, "element": element, "ai_result": parsed, "ai_raw_content": ai_result["content"]}
  270. x_percent = normalize_percent(parsed.get("x_percent"))
  271. y_percent = normalize_percent(parsed.get("y_percent"))
  272. x = max(0, min(int(screen["width"]) - 1, round(int(screen["width"]) * x_percent / 100)))
  273. y = max(0, min(int(screen["height"]) - 1, round(int(screen["height"]) * y_percent / 100)))
  274. raw = {**parsed, "previous": element.get("raw_json")}
  275. with get_db() as conn:
  276. conn.execute(
  277. """
  278. UPDATE automation_screen_elements
  279. SET x_percent = ?, y_percent = ?, x = ?, y = ?, is_located = 1, raw_json = ?
  280. WHERE id = ? AND screen_id = ?
  281. """,
  282. (x_percent, y_percent, x, y, json.dumps(raw, ensure_ascii=False), element_id, screen_id),
  283. )
  284. updated = get_screen(screen_id, include_image=True)
  285. updated_element = next(item for item in updated["elements"] if item["id"] == element_id)
  286. return {
  287. "located": True,
  288. "element": updated_element,
  289. "screen": updated,
  290. "ai_result": parsed,
  291. "ai_raw_content": ai_result["content"],
  292. }
  293. def normalize_percent(value: Any) -> float:
  294. """规范化百分比数值,兼容模型偶尔输出 0-1 小数的情况。"""
  295. try:
  296. number = float(value)
  297. except (TypeError, ValueError):
  298. return 0.0
  299. if 0 <= number <= 1:
  300. number *= 100
  301. return max(0.0, min(100.0, round(number, 2)))
  302. def list_screens(page: int, page_size: int) -> dict[str, Any]:
  303. """分页查询已识别界面列表。"""
  304. offset = (page - 1) * page_size
  305. with get_db() as conn:
  306. total = conn.execute("SELECT COUNT(*) AS total FROM automation_screens").fetchone()["total"]
  307. rows = conn.execute(
  308. """
  309. SELECT s.*, COUNT(e.id) AS element_count
  310. FROM automation_screens s
  311. LEFT JOIN automation_screen_elements e ON e.screen_id = s.id
  312. GROUP BY s.id
  313. ORDER BY s.created_at DESC
  314. LIMIT ? OFFSET ?
  315. """,
  316. (page_size, offset),
  317. ).fetchall()
  318. return {"items": [public_screen(row) for row in rows], "total": total, "page": page, "page_size": page_size}
  319. def get_screen(screen_id: int, include_image: bool = False) -> dict[str, Any]:
  320. """读取单个已识别界面的详情和可操作元素。"""
  321. with get_db() as conn:
  322. screen = conn.execute("SELECT * FROM automation_screens WHERE id = ?", (screen_id,)).fetchone()
  323. if not screen:
  324. raise HTTPException(status_code=404, detail="Automation screen not found")
  325. elements = conn.execute(
  326. "SELECT * FROM automation_screen_elements WHERE screen_id = ? ORDER BY element_index ASC",
  327. (screen_id,),
  328. ).fetchall()
  329. item = public_screen(screen)
  330. item["elements"] = [public_element(row) for row in elements]
  331. if include_image and stored_path(item["image_path"]).exists():
  332. image = image_to_base64(item["image_path"])
  333. item["image_base64"] = image["base64"]
  334. item["mime_type"] = image["mime_type"]
  335. return item
  336. def delete_screen(screen_id: int) -> dict[str, Any]:
  337. """删除已识别界面记录,图片文件保留用于审计。"""
  338. with get_db() as conn:
  339. cursor = conn.execute("DELETE FROM automation_screens WHERE id = ?", (screen_id,))
  340. if cursor.rowcount == 0:
  341. raise HTTPException(status_code=404, detail="Automation screen not found")
  342. return {"deleted": cursor.rowcount}
  343. def public_screen(row: dict[str, Any]) -> dict[str, Any]:
  344. """把数据库中的界面行转换为接口返回格式。"""
  345. item = dict(row)
  346. item["is_windows_desktop"] = bool(item.get("is_windows_desktop"))
  347. item["is_browser_webpage"] = bool(item.get("is_browser_webpage"))
  348. return item
  349. def public_element(row: dict[str, Any]) -> dict[str, Any]:
  350. """把数据库中的元素行转换为接口返回格式。"""
  351. item = dict(row)
  352. item["is_located"] = bool(item.get("is_located"))
  353. return item
  354. def process_snapshot() -> dict[int, dict[str, Any]]:
  355. """获取当前进程快照,只用于自动化动作前后对比,不写入进程扫描表。"""
  356. snapshot: dict[int, dict[str, Any]] = {}
  357. for proc in psutil.process_iter(["pid", "name", "exe"]):
  358. try:
  359. snapshot[int(proc.info["pid"])] = {
  360. "pid": int(proc.info["pid"]),
  361. "name": proc.info.get("name"),
  362. "exe": proc.info.get("exe"),
  363. }
  364. except (psutil.Error, OSError, TypeError, ValueError):
  365. continue
  366. return snapshot
  367. def diff_new_processes(before: dict[int, dict[str, Any]], after: dict[int, dict[str, Any]]) -> list[dict[str, Any]]:
  368. """比较动作前后的进程快照,找出本次自动化动作新增的进程。"""
  369. new_items = [after[pid] for pid in sorted(set(after) - set(before))]
  370. OPENED_PROCESS_IDS.update(item["pid"] for item in new_items)
  371. return new_items
  372. def validate_screen_before_action(
  373. screen_id: int | None,
  374. provider_id: int | None,
  375. model_id: int | None,
  376. temperature: float,
  377. action_type: str,
  378. workflow_id: int | None = None,
  379. node_id: int | None = None,
  380. ) -> dict[str, Any] | None:
  381. """如果动作绑定了界面 ID,则先用 AI 判断当前屏幕是否仍处于目标界面。"""
  382. if screen_id is None:
  383. return None
  384. provider_id, model_id, temperature = resolve_ai_params(provider_id, model_id, temperature)
  385. target = get_screen(screen_id)
  386. current = take_screenshot_file(error_dir(), "compare_current")
  387. prompt = SCREEN_COMPARE_PROMPT.replace("{description}", target.get("description") or target.get("interface_name") or "")
  388. ai_result = ai_service.chat_with_images(
  389. provider_id,
  390. model_id,
  391. prompt,
  392. [image_to_base64(current["path"]), image_to_base64(target["image_path"])],
  393. temperature,
  394. )
  395. try:
  396. parsed = json_from_ai(ai_result["content"])
  397. except (json.JSONDecodeError, ValueError) as exc:
  398. raise HTTPException(status_code=502, detail=f"AI compare output is not valid JSON: {exc}") from exc
  399. is_match = bool(parsed.get("is_match"))
  400. similarity = safe_float(parsed.get("similarity"))
  401. if not is_match:
  402. error = record_error(
  403. action_type=action_type,
  404. message=str(parsed.get("reason") or "界面对比失败,当前屏幕不是目标界面"),
  405. screen_id=screen_id,
  406. workflow_id=workflow_id,
  407. node_id=node_id,
  408. similarity=similarity,
  409. expected_image_path=target["image_path"],
  410. actual_image_path=current["db_path"],
  411. compare_result=parsed,
  412. )
  413. raise HTTPException(status_code=409, detail={"message": error["message"], "error": error})
  414. return parsed
  415. def safe_float(value: Any) -> float | None:
  416. """安全转换浮点数。"""
  417. try:
  418. return float(value)
  419. except (TypeError, ValueError):
  420. return None
  421. def record_error(
  422. action_type: str,
  423. message: str,
  424. screen_id: int | None = None,
  425. workflow_id: int | None = None,
  426. node_id: int | None = None,
  427. similarity: float | None = None,
  428. expected_image_path: str | None = None,
  429. actual_image_path: str | None = None,
  430. compare_result: dict[str, Any] | None = None,
  431. ) -> dict[str, Any]:
  432. """保存自动化错误记录,便于在错误记录菜单中回看。"""
  433. now = now_iso()
  434. with get_db() as conn:
  435. cursor = conn.execute(
  436. """
  437. INSERT INTO automation_errors (
  438. workflow_id, node_id, screen_id, action_type, message, similarity,
  439. expected_image_path, actual_image_path, compare_result_json, created_at
  440. )
  441. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  442. """,
  443. (
  444. workflow_id,
  445. node_id,
  446. screen_id,
  447. action_type,
  448. message,
  449. similarity,
  450. expected_image_path,
  451. actual_image_path,
  452. json.dumps(compare_result or {}, ensure_ascii=False),
  453. now,
  454. ),
  455. )
  456. row = conn.execute("SELECT * FROM automation_errors WHERE id = ?", (cursor.lastrowid,)).fetchone()
  457. return public_error(row)
  458. def execute_mouse_action(payload: AutomationMouseActionRequest) -> dict[str, Any]:
  459. """执行鼠标点击类动作,并记录动作前后新增进程。"""
  460. before = process_snapshot()
  461. compare = validate_screen_before_action(
  462. payload.screen_id,
  463. payload.provider_id,
  464. payload.model_id,
  465. payload.temperature,
  466. f"mouse_{payload.mouse_action}",
  467. payload.workflow_id,
  468. payload.node_id,
  469. )
  470. action_map = {"click": "click", "double_click": "double_click", "right_click": "right_click"}
  471. result = windows_automation.mouse_action(action_map[payload.mouse_action], x=payload.x, y=payload.y)
  472. time.sleep(0.5)
  473. new_processes = diff_new_processes(before, process_snapshot())
  474. return {"result": result, "compare": compare, "new_processes": new_processes}
  475. def execute_keyboard_action(payload: AutomationKeyboardActionRequest) -> dict[str, Any]:
  476. """执行键盘组合键动作,并记录动作前后新增进程。"""
  477. before = process_snapshot()
  478. compare = validate_screen_before_action(
  479. payload.screen_id,
  480. payload.provider_id,
  481. payload.model_id,
  482. payload.temperature,
  483. "keyboard",
  484. payload.workflow_id,
  485. payload.node_id,
  486. )
  487. result = windows_automation.keyboard_action("hotkey" if len(payload.keys) > 1 else "press", key=payload.keys[0], keys=payload.keys)
  488. time.sleep(0.5)
  489. new_processes = diff_new_processes(before, process_snapshot())
  490. return {"result": result, "compare": compare, "new_processes": new_processes}
  491. def execute_text_input(payload: AutomationTextInputRequest) -> dict[str, Any]:
  492. """通过剪贴板粘贴文本,避免直接模拟按键时中文输入不稳定。"""
  493. before = process_snapshot()
  494. compare = validate_screen_before_action(
  495. payload.screen_id,
  496. payload.provider_id,
  497. payload.model_id,
  498. payload.temperature,
  499. "text_input",
  500. payload.workflow_id,
  501. payload.node_id,
  502. )
  503. try:
  504. import pyperclip
  505. except ImportError as exc:
  506. raise HTTPException(status_code=500, detail="pyperclip is not installed") from exc
  507. pyperclip.copy(payload.text)
  508. result = windows_automation.keyboard_action("hotkey", keys=["ctrl", "v"])
  509. time.sleep(0.5)
  510. new_processes = diff_new_processes(before, process_snapshot())
  511. return {"result": result, "compare": compare, "new_processes": new_processes}
  512. def execute_start_program(payload: AutomationStartProgramRequest) -> dict[str, Any]:
  513. """启动程序,并把动作后新增的进程记录为本次自动化打开的程序。"""
  514. before = process_snapshot()
  515. compare = validate_screen_before_action(
  516. payload.screen_id,
  517. payload.provider_id,
  518. payload.model_id,
  519. payload.temperature,
  520. "start_program",
  521. payload.workflow_id,
  522. payload.node_id,
  523. )
  524. result = windows_automation.start_program(payload.command, payload.cwd, payload.shell)
  525. time.sleep(1)
  526. new_processes = diff_new_processes(before, process_snapshot())
  527. if result.get("pid"):
  528. OPENED_PROCESS_IDS.add(int(result["pid"]))
  529. return {"result": result, "compare": compare, "new_processes": new_processes}
  530. def close_opened_programs(pids: list[int] | None = None) -> dict[str, Any]:
  531. """关闭本次自动化过程中记录的新进程。"""
  532. targets = sorted(set(pids or list(OPENED_PROCESS_IDS)))
  533. closed = []
  534. for pid in targets:
  535. try:
  536. closed.append(windows_automation.stop_program(pid=pid))
  537. OPENED_PROCESS_IDS.discard(pid)
  538. except HTTPException as exc:
  539. closed.append({"pid": pid, "error": exc.detail})
  540. return {"action": "close_opened_programs", "items": closed}
  541. def save_workflow(payload: AutomationWorkflowSaveRequest) -> dict[str, Any]:
  542. """保存前端记录或手动编辑的自动化工作流和节点。"""
  543. now = now_iso()
  544. raw_json = payload.model_dump()
  545. with get_db() as conn:
  546. cursor = conn.execute(
  547. """
  548. INSERT INTO automation_workflows (name, description, raw_json, created_at, updated_at)
  549. VALUES (?, ?, ?, ?, ?)
  550. """,
  551. (payload.name.strip(), payload.description, json.dumps(raw_json, ensure_ascii=False), now, now),
  552. )
  553. workflow_id = cursor.lastrowid
  554. insert_workflow_nodes(conn, workflow_id, payload.nodes, now)
  555. return get_workflow(workflow_id)
  556. def update_workflow(workflow_id: int, payload: AutomationWorkflowSaveRequest) -> dict[str, Any]:
  557. """更新工作流基础信息和节点列表。"""
  558. now = now_iso()
  559. raw_json = payload.model_dump()
  560. with get_db() as conn:
  561. existing = conn.execute("SELECT id FROM automation_workflows WHERE id = ?", (workflow_id,)).fetchone()
  562. if not existing:
  563. raise HTTPException(status_code=404, detail="Automation workflow not found")
  564. conn.execute(
  565. """
  566. UPDATE automation_workflows
  567. SET name = ?, description = ?, raw_json = ?, updated_at = ?
  568. WHERE id = ?
  569. """,
  570. (payload.name.strip(), payload.description, json.dumps(raw_json, ensure_ascii=False), now, workflow_id),
  571. )
  572. conn.execute("DELETE FROM automation_workflow_nodes WHERE workflow_id = ?", (workflow_id,))
  573. insert_workflow_nodes(conn, workflow_id, payload.nodes, now)
  574. return get_workflow(workflow_id)
  575. def insert_workflow_nodes(conn, workflow_id: int, nodes: list[Any], now: str) -> None:
  576. """批量写入工作流节点。"""
  577. for index, node in enumerate(nodes, start=1):
  578. conn.execute(
  579. """
  580. INSERT INTO automation_workflow_nodes (
  581. workflow_id, node_index, node_key, node_type, screen_id, title,
  582. position_x, position_y, next_node_keys, config_json, created_at, updated_at
  583. )
  584. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  585. """,
  586. (
  587. workflow_id,
  588. index,
  589. node.node_key or f"node_{index}",
  590. node.node_type,
  591. node.screen_id,
  592. node.title,
  593. node.position_x,
  594. node.position_y,
  595. json.dumps(node.next_node_keys, ensure_ascii=False),
  596. json.dumps(node.config, ensure_ascii=False),
  597. now,
  598. now,
  599. ),
  600. )
  601. def list_workflows(page: int, page_size: int) -> dict[str, Any]:
  602. """分页查询自动化工作流列表。"""
  603. offset = (page - 1) * page_size
  604. with get_db() as conn:
  605. total = conn.execute("SELECT COUNT(*) AS total FROM automation_workflows").fetchone()["total"]
  606. rows = conn.execute(
  607. """
  608. SELECT w.*, COUNT(n.id) AS node_count
  609. FROM automation_workflows w
  610. LEFT JOIN automation_workflow_nodes n ON n.workflow_id = w.id
  611. GROUP BY w.id
  612. ORDER BY w.updated_at DESC
  613. LIMIT ? OFFSET ?
  614. """,
  615. (page_size, offset),
  616. ).fetchall()
  617. return {"items": rows, "total": total, "page": page, "page_size": page_size}
  618. def get_workflow(workflow_id: int) -> dict[str, Any]:
  619. """读取工作流详情和节点列表。"""
  620. with get_db() as conn:
  621. workflow = conn.execute("SELECT * FROM automation_workflows WHERE id = ?", (workflow_id,)).fetchone()
  622. if not workflow:
  623. raise HTTPException(status_code=404, detail="Automation workflow not found")
  624. nodes = conn.execute(
  625. "SELECT * FROM automation_workflow_nodes WHERE workflow_id = ? ORDER BY node_index ASC",
  626. (workflow_id,),
  627. ).fetchall()
  628. item = dict(workflow)
  629. item["nodes"] = [public_node(row) for row in nodes]
  630. return item
  631. def delete_workflow(workflow_id: int) -> dict[str, Any]:
  632. """删除工作流及其节点。"""
  633. with get_db() as conn:
  634. cursor = conn.execute("DELETE FROM automation_workflows WHERE id = ?", (workflow_id,))
  635. if cursor.rowcount == 0:
  636. raise HTTPException(status_code=404, detail="Automation workflow not found")
  637. return {"deleted": cursor.rowcount}
  638. def run_workflow(workflow_id: int, payload: AutomationWorkflowRunRequest) -> dict[str, Any]:
  639. """按数据库中保存的工作流节点和连线顺序在后端执行整个工作流。"""
  640. workflow = get_workflow(workflow_id)
  641. defaults = settings_service.default_ai_params()
  642. provider_id = payload.provider_id or defaults.get("provider_id")
  643. model_id = payload.model_id or defaults.get("model_id")
  644. temperature = payload.temperature if payload.temperature is not None else defaults.get("temperature", 0.1)
  645. nodes = ordered_workflow_nodes(workflow.get("nodes") or [])
  646. results: list[dict[str, Any]] = []
  647. opened_pids: list[int] = []
  648. for node in nodes:
  649. try:
  650. result = execute_workflow_node(workflow_id, node, provider_id, model_id, temperature, opened_pids)
  651. opened_pids.extend(
  652. item["pid"]
  653. for item in result.get("new_processes", [])
  654. if item.get("pid") and item["pid"] not in opened_pids
  655. )
  656. results.append({"node": node, "status": "SUCCESS", "result": result})
  657. except HTTPException as exc:
  658. if not (isinstance(exc.detail, dict) and exc.detail.get("error")):
  659. record_error(
  660. action_type=node.get("node_type") or "workflow",
  661. message=str(exc.detail),
  662. screen_id=node.get("screen_id"),
  663. workflow_id=workflow_id,
  664. node_id=node.get("id"),
  665. )
  666. failure = {"node": node, "status": "FAILED", "detail": exc.detail}
  667. results.append(failure)
  668. return {"workflow_id": workflow_id, "status": "FAILED", "failed": failure, "results": results}
  669. return {"workflow_id": workflow_id, "status": "SUCCESS", "results": results}
  670. def ordered_workflow_nodes(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
  671. """根据节点连线得到执行顺序;没有连线时沿用节点序号。"""
  672. if not nodes:
  673. return []
  674. by_key = {node.get("node_key") or f"node_{node.get('node_index')}": node for node in nodes}
  675. targeted = {key for node in nodes for key in node.get("next_node_keys", [])}
  676. start_keys = [key for key in by_key if key not in targeted] or [next(iter(by_key))]
  677. ordered: list[dict[str, Any]] = []
  678. visited: set[str] = set()
  679. def visit(key: str) -> None:
  680. if key in visited or key not in by_key:
  681. return
  682. visited.add(key)
  683. node = by_key[key]
  684. ordered.append(node)
  685. for next_key in node.get("next_node_keys", []):
  686. visit(next_key)
  687. for key in start_keys:
  688. visit(key)
  689. for key in by_key:
  690. visit(key)
  691. return ordered
  692. def execute_workflow_node(
  693. workflow_id: int,
  694. node: dict[str, Any],
  695. provider_id: int | None,
  696. model_id: int | None,
  697. temperature: float,
  698. opened_pids: list[int],
  699. ) -> dict[str, Any]:
  700. """把工作流节点配置转换成已有高层动作并执行。"""
  701. node_type = node.get("node_type")
  702. config = node.get("config") or {}
  703. base = {
  704. "screen_id": node.get("screen_id"),
  705. "provider_id": provider_id,
  706. "model_id": model_id,
  707. "temperature": temperature,
  708. "workflow_id": workflow_id,
  709. "node_id": node.get("id"),
  710. }
  711. if node_type == "mouse":
  712. return execute_mouse_action(
  713. AutomationMouseActionRequest(
  714. **base,
  715. x=int(config.get("x", 0)),
  716. y=int(config.get("y", 0)),
  717. mouse_action=config.get("mouse_action") or "click",
  718. )
  719. )
  720. if node_type == "keyboard":
  721. return execute_keyboard_action(AutomationKeyboardActionRequest(**base, keys=config.get("keys") or []))
  722. if node_type == "text_input":
  723. return execute_text_input(AutomationTextInputRequest(**base, text=str(config.get("text") or "")))
  724. if node_type == "start_program":
  725. return execute_start_program(
  726. AutomationStartProgramRequest(
  727. **base,
  728. command=str(config.get("command") or ""),
  729. cwd=config.get("cwd"),
  730. shell=bool(config.get("shell", True)),
  731. )
  732. )
  733. if node_type == "close_programs":
  734. return close_opened_programs(config.get("pids") or opened_pids)
  735. raise HTTPException(status_code=400, detail=f"Unsupported workflow node type: {node_type}")
  736. def public_node(row: dict[str, Any]) -> dict[str, Any]:
  737. """把工作流节点行转换为接口返回格式。"""
  738. item = dict(row)
  739. try:
  740. item["config"] = json.loads(item.pop("config_json") or "{}")
  741. except json.JSONDecodeError:
  742. item["config"] = {}
  743. try:
  744. item["next_node_keys"] = json.loads(item.get("next_node_keys") or "[]")
  745. except json.JSONDecodeError:
  746. item["next_node_keys"] = []
  747. return item
  748. def list_errors(page: int, page_size: int) -> dict[str, Any]:
  749. """分页查询自动化错误记录。"""
  750. offset = (page - 1) * page_size
  751. with get_db() as conn:
  752. total = conn.execute("SELECT COUNT(*) AS total FROM automation_errors").fetchone()["total"]
  753. rows = conn.execute(
  754. """
  755. SELECT e.*, s.interface_name
  756. FROM automation_errors e
  757. LEFT JOIN automation_screens s ON s.id = e.screen_id
  758. ORDER BY e.created_at DESC
  759. LIMIT ? OFFSET ?
  760. """,
  761. (page_size, offset),
  762. ).fetchall()
  763. return {"items": [public_error(row) for row in rows], "total": total, "page": page, "page_size": page_size}
  764. def get_error(error_id: int, include_images: bool = False) -> dict[str, Any]:
  765. """读取单条自动化错误详情,可附带目标截图和实际截图。"""
  766. with get_db() as conn:
  767. row = conn.execute("SELECT * FROM automation_errors WHERE id = ?", (error_id,)).fetchone()
  768. if not row:
  769. raise HTTPException(status_code=404, detail="Automation error not found")
  770. item = public_error(row)
  771. if include_images:
  772. for key in ["expected_image_path", "actual_image_path"]:
  773. path = item.get(key)
  774. if path and stored_path(path).exists():
  775. image = image_to_base64(path)
  776. item[key.replace("_path", "_base64")] = image["base64"]
  777. item[key.replace("_path", "_mime_type")] = image["mime_type"]
  778. return item
  779. def public_error(row: dict[str, Any]) -> dict[str, Any]:
  780. """把错误记录行转换为接口返回格式。"""
  781. item = dict(row)
  782. try:
  783. item["compare_result"] = json.loads(item.pop("compare_result_json") or "{}")
  784. except json.JSONDecodeError:
  785. item["compare_result"] = {}
  786. return item