vision.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
from __future__ import annotations

import json
import math
import random
import time
import uuid
from pathlib import Path
from typing import Any

from fastapi import HTTPException

from ... import ai_service, settings_service, windows_automation
from ..context import WorkflowContext
from ..registry import control_ports, field_def, register_node
# Prompt template sent to the vision model by _locate_target().
# It is filled with str.format(); the placeholders are target_description,
# screen_context and selection_rule. The doubled braces around the JSON
# schema are format-escapes that render as single literal braces.
LOCATE_TARGET_PROMPT = """请作为 AI 视觉自动化定位助手,在这张真实屏幕截图中寻找用户指定的可点击目标。
目标描述:
{target_description}
当前页面/操作上下文:
{screen_context}
选择要求:
1. 如果有多个候选目标,{selection_rule}
2. 返回目标可点击区域的中心点,不要返回窗口、浏览器地址栏或整块页面的中心。
3. 坐标必须是相对整张截图宽高的百分比,范围 0-100。
4. 如果目标不可见、被遮挡、需要滚动、页面未加载完成或你不确定,请返回 found=false。
严格只输出 JSON 对象,不要输出 Markdown:
{{
"found": boolean,
"x_percent": number|null,
"y_percent": number|null,
"confidence": number,
"target_label": string,
"reason": string
}}"""
# Prompt template sent to the vision model by verify_page_node().
# It is filled with str.format(); the placeholders are expected_state and
# screen_context. The doubled braces around the JSON schema are
# format-escapes that render as single literal braces.
VERIFY_PAGE_PROMPT = """请作为 AI 视觉自动化校验器,判断当前屏幕是否符合预期状态。
预期状态:
{expected_state}
当前页面/操作上下文:
{screen_context}
严格只输出 JSON 对象,不要输出 Markdown:
{{
"matched": boolean,
"page_state": string,
"confidence": number,
"reason": string
}}"""
  43. def _number(value: Any, default: float = 0) -> float:
  44. try:
  45. return float(value)
  46. except (TypeError, ValueError):
  47. return default
  48. def _boolean(value: Any, default: bool = False) -> bool:
  49. if value in (None, ""):
  50. return default
  51. if isinstance(value, str):
  52. return value.strip().lower() in {"1", "true", "yes", "y", "on"}
  53. return bool(value)
  54. def _percent(value: Any) -> float | None:
  55. try:
  56. number = float(value)
  57. except (TypeError, ValueError):
  58. return None
  59. if 0 <= number <= 1:
  60. number *= 100
  61. return max(0.0, min(100.0, number))
  62. def _runtime_screenshot_path() -> Path:
  63. """生成 workflow 运行期截图路径,便于失败排查和任务结果追踪。"""
  64. folder = settings_service.resolve_data_path("automation_runtime_path", "automation/runtime")
  65. folder.mkdir(parents=True, exist_ok=True)
  66. return folder / f"vision_locate_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}.png"
  67. def _capture_screen(save_screenshot: bool) -> dict[str, Any]:
  68. save_path = _runtime_screenshot_path() if save_screenshot else None
  69. screenshot = windows_automation.take_screenshot(str(save_path) if save_path else None, include_base64=True)
  70. screenshot["mime_type"] = "image/png"
  71. return screenshot
  72. def _parse_ai_json(content: str) -> dict[str, Any]:
  73. parsed = json.loads(ai_service.extract_json_text(content))
  74. if not isinstance(parsed, dict):
  75. raise ValueError("AI locate output must be a JSON object")
  76. return parsed
  77. def _vision_json(context: WorkflowContext, prompt: str, screenshot: dict[str, Any], temperature: float) -> tuple[dict[str, Any], dict[str, Any]]:
  78. ai_result = ai_service.chat_with_images(
  79. int(context.provider_id),
  80. int(context.model_id),
  81. prompt,
  82. [{"base64": screenshot["image_base64"], "mime_type": screenshot["mime_type"]}],
  83. temperature,
  84. )
  85. return _parse_ai_json(ai_result["content"]), ai_result
  86. def _locate_target(
  87. context: WorkflowContext,
  88. target_description: str,
  89. screen_context: str,
  90. randomize: bool,
  91. save_screenshot: bool,
  92. temperature: float,
  93. ) -> dict[str, Any]:
  94. screenshot = _capture_screen(save_screenshot)
  95. if screenshot.get("path"):
  96. context.runtime["current_screenshot_path"] = screenshot["path"]
  97. if randomize:
  98. selection_rule = f"请结合随机种子 {random.randint(1, 1_000_000)},从可见候选中随机挑选一个"
  99. else:
  100. selection_rule = "请选择最符合目标描述、最容易点击的一个"
  101. prompt = LOCATE_TARGET_PROMPT.format(
  102. target_description=target_description,
  103. screen_context=screen_context,
  104. selection_rule=selection_rule,
  105. )
  106. try:
  107. parsed, ai_result = _vision_json(context, prompt, screenshot, temperature)
  108. except (json.JSONDecodeError, ValueError) as exc:
  109. raise HTTPException(status_code=502, detail=f"AI locate output is not valid JSON: {exc}") from exc
  110. found = bool(parsed.get("found"))
  111. x_percent = _percent(parsed.get("x_percent"))
  112. y_percent = _percent(parsed.get("y_percent"))
  113. base = {
  114. "screenshot_path": screenshot.get("path"),
  115. "width": screenshot.get("width"),
  116. "height": screenshot.get("height"),
  117. "ai_result": parsed,
  118. "ai_raw_content": ai_result["content"],
  119. }
  120. if not found or x_percent is None or y_percent is None:
  121. return {"located": False, "found": False, "next_port": "not_found", **base}
  122. width = int(screenshot["width"])
  123. height = int(screenshot["height"])
  124. x = max(0, min(width - 1, round(width * x_percent / 100)))
  125. y = max(0, min(height - 1, round(height * y_percent / 100)))
  126. return {
  127. "located": True,
  128. "found": True,
  129. "x_percent": x_percent,
  130. "y_percent": y_percent,
  131. "x": x,
  132. "y": y,
  133. "confidence": parsed.get("confidence"),
  134. "target_label": parsed.get("target_label"),
  135. "reason": parsed.get("reason"),
  136. **base,
  137. }
  138. def locate_element_node(node: dict[str, Any], inputs: dict[str, Any], context: WorkflowContext) -> dict[str, Any]:
  139. params = node.get("params", {})
  140. if not context.provider_id or not context.model_id:
  141. raise HTTPException(status_code=400, detail="AI 视觉定位节点需要配置默认 AI 服务商和模型")
  142. target_description = str(inputs.get("target_description", params.get("target_description")) or "").strip()
  143. if not target_description:
  144. raise HTTPException(status_code=400, detail="target_description is required")
  145. screen_context = str(inputs.get("screen_context", params.get("screen_context")) or "当前屏幕").strip()
  146. randomize = _boolean(inputs.get("randomize", params.get("randomize")), False)
  147. save_screenshot = _boolean(inputs.get("save_screenshot", params.get("save_screenshot")), True)
  148. fail_if_not_found = _boolean(inputs.get("fail_if_not_found", params.get("fail_if_not_found")), True)
  149. temperature = _number(inputs.get("temperature", params.get("temperature")), context.temperature)
  150. result = _locate_target(
  151. context,
  152. target_description=target_description,
  153. screen_context=screen_context,
  154. randomize=randomize,
  155. save_screenshot=save_screenshot,
  156. temperature=temperature,
  157. )
  158. if not result.get("located"):
  159. if fail_if_not_found:
  160. ai_result = result.get("ai_result") if isinstance(result.get("ai_result"), dict) else {}
  161. raise HTTPException(status_code=404, detail=ai_result.get("reason") or "AI 未定位到目标元素")
  162. return result
  163. return result
  164. def verify_page_node(node: dict[str, Any], inputs: dict[str, Any], context: WorkflowContext) -> dict[str, Any]:
  165. params = node.get("params", {})
  166. if not context.provider_id or not context.model_id:
  167. raise HTTPException(status_code=400, detail="AI 页面校验节点需要配置默认 AI 服务商和模型")
  168. expected_state = str(inputs.get("expected_state", params.get("expected_state")) or "").strip()
  169. if not expected_state:
  170. raise HTTPException(status_code=400, detail="expected_state is required")
  171. screen_context = str(inputs.get("screen_context", params.get("screen_context")) or "当前屏幕").strip()
  172. save_screenshot = _boolean(inputs.get("save_screenshot", params.get("save_screenshot")), True)
  173. temperature = _number(inputs.get("temperature", params.get("temperature")), context.temperature)
  174. screenshot = _capture_screen(save_screenshot)
  175. if screenshot.get("path"):
  176. context.runtime["current_screenshot_path"] = screenshot["path"]
  177. prompt = VERIFY_PAGE_PROMPT.format(expected_state=expected_state, screen_context=screen_context)
  178. try:
  179. parsed, ai_result = _vision_json(context, prompt, screenshot, temperature)
  180. except (json.JSONDecodeError, ValueError) as exc:
  181. raise HTTPException(status_code=502, detail=f"AI verify output is not valid JSON: {exc}") from exc
  182. matched = bool(parsed.get("matched"))
  183. return {
  184. "matched": matched,
  185. "next_port": "matched" if matched else "not_matched",
  186. "page_state": parsed.get("page_state"),
  187. "confidence": parsed.get("confidence"),
  188. "reason": parsed.get("reason"),
  189. "screenshot_path": screenshot.get("path"),
  190. "width": screenshot.get("width"),
  191. "height": screenshot.get("height"),
  192. "ai_result": parsed,
  193. "ai_raw_content": ai_result["content"],
  194. }
  195. def click_target_node(node: dict[str, Any], inputs: dict[str, Any], context: WorkflowContext) -> dict[str, Any]:
  196. params = node.get("params", {})
  197. target_description = str(inputs.get("target_description", params.get("target_description")) or "").strip()
  198. if not target_description:
  199. raise HTTPException(status_code=400, detail="target_description is required")
  200. screen_context = str(inputs.get("screen_context", params.get("screen_context")) or "当前屏幕").strip()
  201. randomize = _boolean(inputs.get("randomize", params.get("randomize")), False)
  202. save_screenshot = _boolean(inputs.get("save_screenshot", params.get("save_screenshot")), True)
  203. fail_if_not_found = _boolean(inputs.get("fail_if_not_found", params.get("fail_if_not_found")), True)
  204. temperature = _number(inputs.get("temperature", params.get("temperature")), context.temperature)
  205. button = str(inputs.get("button", params.get("button")) or "left")
  206. clicks = int(max(1, min(_number(inputs.get("clicks", params.get("clicks")), 1), 20)))
  207. result = _locate_target(context, target_description, screen_context, randomize, save_screenshot, temperature)
  208. if not result.get("located"):
  209. if fail_if_not_found:
  210. ai_result = result.get("ai_result") if isinstance(result.get("ai_result"), dict) else {}
  211. raise HTTPException(status_code=404, detail=ai_result.get("reason") or "AI 未定位到可点击目标")
  212. return result
  213. clicked = windows_automation.mouse_action("click", x=int(result["x"]), y=int(result["y"]), button=button, clicks=clicks)
  214. return {**result, "clicked": True, "click": clicked, "button": button, "clicks": clicks}
  215. def close_popups_node(node: dict[str, Any], inputs: dict[str, Any], context: WorkflowContext) -> dict[str, Any]:
  216. params = node.get("params", {})
  217. target_description = str(
  218. inputs.get("target_description", params.get("target_description"))
  219. or "当前页面可见的弹窗关闭按钮、跳过按钮、稍后再说按钮、我知道了按钮或拒绝按钮"
  220. )
  221. screen_context = str(inputs.get("screen_context", params.get("screen_context")) or "当前浏览器页面").strip()
  222. attempts = int(max(1, min(_number(inputs.get("attempts", params.get("attempts")), 2), 5)))
  223. temperature = _number(inputs.get("temperature", params.get("temperature")), context.temperature)
  224. closed: list[dict[str, Any]] = []
  225. for _ in range(attempts):
  226. result = _locate_target(context, target_description, screen_context, False, True, temperature)
  227. if not result.get("located"):
  228. return {"closed_count": len(closed), "items": closed, "next_port": "success"}
  229. clicked = windows_automation.mouse_action("click", x=int(result["x"]), y=int(result["y"]))
  230. closed.append({**result, "click": clicked})
  231. time.sleep(0.8)
  232. return {"closed_count": len(closed), "items": closed, "next_port": "success"}
  233. register_node(
  234. {
  235. "type": "vision.locate_element",
  236. "category": "vision",
  237. "label": "AI 视觉定位元素",
  238. "params": {
  239. "target_description": field_def("text", "目标描述", required=True),
  240. "screen_context": field_def("text", "页面上下文"),
  241. "randomize": field_def("boolean", "多候选随机选择", False),
  242. "save_screenshot": field_def("boolean", "保存截图", True),
  243. "fail_if_not_found": field_def("boolean", "找不到时报错", True),
  244. "temperature": field_def("number", "定位温度", 0.1, minimum=0, maximum=2),
  245. },
  246. "inputs": {
  247. "target_description": field_def("string", "目标描述"),
  248. "screen_context": field_def("string", "页面上下文"),
  249. "randomize": field_def("boolean", "多候选随机选择"),
  250. "save_screenshot": field_def("boolean", "保存截图"),
  251. "fail_if_not_found": field_def("boolean", "找不到时报错"),
  252. "temperature": field_def("number", "定位温度"),
  253. },
  254. "outputs": {
  255. "located": {"type": "boolean", "label": "是否定位成功"},
  256. "x_percent": {"type": "number", "label": "X 百分比"},
  257. "y_percent": {"type": "number", "label": "Y 百分比"},
  258. "x": {"type": "number", "label": "X 坐标"},
  259. "y": {"type": "number", "label": "Y 坐标"},
  260. "confidence": {"type": "number", "label": "置信度"},
  261. "target_label": {"type": "string", "label": "目标标签"},
  262. "screenshot_path": {"type": "string", "label": "截图路径"},
  263. "ai_result": {"type": "object", "label": "AI 结果"},
  264. },
  265. "control_ports": control_ports(["success", "not_found", "failure"]),
  266. },
  267. locate_element_node,
  268. )
# Registers the "vision.verify_page" node with verify_page_node as its
# handler. "params" are the static node settings; "inputs" mirror them so a
# wired value can override the parameter (the handler reads inputs first).
# The handler sets next_port to "matched" or "not_matched".
# NOTE(review): the "failure" port is presumably taken by the engine when the
# handler raises - confirm against the workflow runner.
register_node(
    {
        "type": "vision.verify_page",
        "category": "vision",
        "label": "AI 校验页面状态",
        "description": "截取当前屏幕,让多模态 AI 判断页面是否符合预期,并按 matched/not_matched 分支继续。",
        "params": {
            "expected_state": field_def("text", "预期状态", required=True),
            "screen_context": field_def("text", "页面上下文"),
            "save_screenshot": field_def("boolean", "保存截图", True),
            "temperature": field_def("number", "校验温度", 0.1, minimum=0, maximum=2),
        },
        "inputs": {
            "expected_state": field_def("string", "预期状态"),
            "screen_context": field_def("string", "页面上下文"),
            "save_screenshot": field_def("boolean", "保存截图"),
            "temperature": field_def("number", "校验温度"),
        },
        # The handler also returns width, height and ai_raw_content, which
        # are not declared here.
        "outputs": {
            "matched": {"type": "boolean", "label": "是否匹配"},
            "page_state": {"type": "string", "label": "页面状态"},
            "confidence": {"type": "number", "label": "置信度"},
            "reason": {"type": "string", "label": "原因"},
            "screenshot_path": {"type": "string", "label": "截图路径"},
            "ai_result": {"type": "object", "label": "AI 结果"},
        },
        "control_ports": control_ports(["matched", "not_matched", "failure"]),
    },
    verify_page_node,
)
  299. register_node(
  300. {
  301. "type": "vision.click_target",
  302. "category": "vision",
  303. "label": "AI 定位并点击",
  304. "description": "截屏定位目标元素,换算坐标后立即点击,适合封装常见视觉点击步骤。",
  305. "params": {
  306. "target_description": field_def("text", "目标描述", required=True),
  307. "screen_context": field_def("text", "页面上下文"),
  308. "randomize": field_def("boolean", "多候选随机选择", False),
  309. "button": field_def("select", "按键", "left", options=["left", "middle", "right"]),
  310. "clicks": field_def("number", "点击次数", 1, minimum=1, maximum=20),
  311. "save_screenshot": field_def("boolean", "保存截图", True),
  312. "fail_if_not_found": field_def("boolean", "找不到时报错", True),
  313. "temperature": field_def("number", "定位温度", 0.1, minimum=0, maximum=2),
  314. },
  315. "inputs": {
  316. "target_description": field_def("string", "目标描述"),
  317. "screen_context": field_def("string", "页面上下文"),
  318. "randomize": field_def("boolean", "多候选随机选择"),
  319. "button": field_def("string", "按键"),
  320. "clicks": field_def("number", "点击次数"),
  321. "save_screenshot": field_def("boolean", "保存截图"),
  322. "fail_if_not_found": field_def("boolean", "找不到时报错"),
  323. "temperature": field_def("number", "定位温度"),
  324. },
  325. "outputs": {
  326. "located": {"type": "boolean", "label": "是否定位成功"},
  327. "clicked": {"type": "boolean", "label": "是否已点击"},
  328. "x": {"type": "number", "label": "X 坐标"},
  329. "y": {"type": "number", "label": "Y 坐标"},
  330. "confidence": {"type": "number", "label": "置信度"},
  331. "target_label": {"type": "string", "label": "目标标签"},
  332. "click": {"type": "object", "label": "点击结果"},
  333. "ai_result": {"type": "object", "label": "AI 结果"},
  334. },
  335. "control_ports": control_ports(["success", "not_found", "failure"]),
  336. },
  337. click_target_node,
  338. )
# Registers the "vision.close_popups" node with close_popups_node as its
# handler. "params" are the static node settings; "inputs" mirror them so a
# wired value can override the parameter (the handler reads inputs first).
# The handler always returns next_port "success".
register_node(
    {
        "type": "vision.close_popups",
        "category": "vision",
        "label": "AI 关闭弹窗",
        "description": "尝试识别并点击当前页面上的关闭、跳过、稍后再说等弹窗按钮。",
        "params": {
            "target_description": field_def("text", "关闭目标", "当前页面可见的弹窗关闭按钮、跳过按钮、稍后再说按钮、我知道了按钮或拒绝按钮"),
            "screen_context": field_def("text", "页面上下文", "当前浏览器页面"),
            # The handler clamps attempts to 1..5, matching these bounds.
            "attempts": field_def("number", "最多尝试", 2, minimum=1, maximum=5),
            "temperature": field_def("number", "定位温度", 0.1, minimum=0, maximum=2),
        },
        "inputs": {
            "target_description": field_def("string", "关闭目标"),
            "screen_context": field_def("string", "页面上下文"),
            "attempts": field_def("number", "最多尝试"),
            "temperature": field_def("number", "定位温度"),
        },
        "outputs": {
            "closed_count": {"type": "number", "label": "关闭数量"},
            "items": {"type": "array", "label": "关闭记录"},
        },
        # NOTE(review): control_ports() is called without arguments, so the
        # port set is whatever the registry defaults to - not visible here.
        "control_ports": control_ports(),
    },
    close_popups_node,
)
  364. )