research.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. from __future__ import annotations
  2. import json
  3. from typing import Any
  4. from urllib.parse import urlparse
  5. from fastapi import HTTPException
  6. from ... import ai_service
  7. from ..context import WorkflowContext
  8. from ..registry import control_ports, field_def, register_node
  9. from .web_search import WebSearchRunner, _integer
  10. def parse_object(value: Any, field_name: str) -> dict[str, Any]:
  11. """兼容 API 直接传对象和编辑器中保存的 JSON 文本。"""
  12. if isinstance(value, dict):
  13. return value
  14. if isinstance(value, str) and value.strip():
  15. try:
  16. parsed = json.loads(value)
  17. except json.JSONDecodeError as exc:
  18. raise HTTPException(status_code=400, detail=f"{field_name} 不是有效 JSON 对象: {exc}") from exc
  19. if isinstance(parsed, dict):
  20. return parsed
  21. raise HTTPException(status_code=400, detail=f"{field_name} 必须是 JSON 对象")
  22. def validate_json_data(data: Any, schema: dict[str, Any]) -> dict[str, Any]:
  23. """使用标准 JSON Schema 校验最终返回数据。"""
  24. try:
  25. from jsonschema import Draft202012Validator
  26. except ImportError as exc:
  27. raise HTTPException(status_code=500, detail="jsonschema is not installed") from exc
  28. try:
  29. Draft202012Validator.check_schema(schema)
  30. except Exception as exc:
  31. raise HTTPException(status_code=400, detail=f"output_schema 无效: {exc}") from exc
  32. validator = Draft202012Validator(schema)
  33. errors = sorted(validator.iter_errors(data), key=lambda item: list(item.absolute_path))
  34. return {
  35. "schema_valid": not errors,
  36. "errors": [
  37. {
  38. "path": ".".join(str(part) for part in error.absolute_path),
  39. "message": error.message,
  40. }
  41. for error in errors
  42. ],
  43. }
  44. def validate_research_result(
  45. data: Any,
  46. schema: dict[str, Any],
  47. constraints: dict[str, Any],
  48. evidence: list[dict[str, Any]],
  49. ) -> dict[str, Any]:
  50. """组合 JSON Schema 与来源数量、必需域名等确定性约束。"""
  51. result = validate_json_data(data, schema)
  52. constraint_errors: list[str] = []
  53. sources = sources_from_evidence(evidence)
  54. try:
  55. min_sources = max(0, int(constraints.get("min_sources", 0)))
  56. except (TypeError, ValueError):
  57. min_sources = 0
  58. if len(sources) < min_sources:
  59. constraint_errors.append(f"来源数量 {len(sources)} 少于要求的 {min_sources}")
  60. required_domains = constraints.get("required_domains") or []
  61. if isinstance(required_domains, list):
  62. source_hosts = {urlparse(item["url"]).netloc.lower() for item in sources}
  63. for domain in required_domains:
  64. normalized = str(domain).lower().strip()
  65. if normalized and not any(host == normalized or host.endswith(f".{normalized}") for host in source_hosts):
  66. constraint_errors.append(f"缺少必需来源域名: {normalized}")
  67. result["constraints_valid"] = not constraint_errors
  68. result["constraint_errors"] = constraint_errors
  69. result["valid"] = result["schema_valid"] and result["constraints_valid"]
  70. return result
  71. class AiWebResearchRunner:
  72. """AI 驱动的多轮视觉网页研究状态机。"""
  73. def __init__(self, context: WorkflowContext, params: dict[str, Any]) -> None:
  74. if not context.provider_id or not context.model_id:
  75. raise HTTPException(status_code=400, detail="AI 搜索研究节点需要配置默认 AI 服务商和模型")
  76. self.context = context
  77. self.params = params
  78. self.objective = str(params.get("objective") or "").strip()
  79. if not self.objective:
  80. raise HTTPException(status_code=400, detail="研究目标不能为空")
  81. self.output_schema = parse_object(params.get("output_schema"), "output_schema")
  82. self.constraints = parse_object(params.get("constraints") or {}, "constraints")
  83. self.max_attempts = _integer(params.get("max_attempts"), 3, 1, 10)
  84. self.search_engine = str(params.get("search_engine") or "bing")
  85. self.browser = str(params.get("browser") or "edge")
  86. self.max_search_pages = _integer(params.get("max_search_pages"), 2, 1, 10)
  87. self.result_count = _integer(params.get("result_count"), 2, 1, 5)
  88. self.detail_max_pages = _integer(params.get("detail_max_pages"), 2, 1, 10)
  89. def run(self) -> dict[str, Any]:
  90. plan = self._create_plan()
  91. pending_queries = self._normalize_queries(plan.get("queries"))
  92. if not pending_queries:
  93. pending_queries = [self.objective]
  94. searched_queries: list[str] = []
  95. evidence: list[dict[str, Any]] = []
  96. attempts: list[dict[str, Any]] = []
  97. latest_assessment: dict[str, Any] = {}
  98. latest_data: Any = {}
  99. latest_validation = validate_research_result(latest_data, self.output_schema, self.constraints, evidence)
  100. for attempt_number in range(1, self.max_attempts + 1):
  101. query = self._next_query(pending_queries, searched_queries, latest_assessment)
  102. searched_queries.append(query)
  103. search_output = WebSearchRunner(
  104. self.context,
  105. {
  106. "query": query,
  107. "search_engine": self.search_engine,
  108. "browser": self.browser,
  109. "max_search_pages": self.max_search_pages,
  110. "result_count": self.result_count,
  111. "detail_max_pages": self.detail_max_pages,
  112. "click_attempts": self.params.get("click_attempts", 2),
  113. "maximize_browser": self.params.get("maximize_browser", True),
  114. "page_load_wait_seconds": self.params.get("page_load_wait_seconds", 8),
  115. "action_wait_seconds": self.params.get("action_wait_seconds", 1),
  116. "wait_jitter_min_seconds": self.params.get("wait_jitter_min_seconds", 0),
  117. "wait_jitter_max_seconds": self.params.get("wait_jitter_max_seconds", 0),
  118. "close_browser": True,
  119. "include_debug_analyses": False,
  120. },
  121. ).run()
  122. attempt_evidence = compact_evidence(search_output)
  123. evidence.extend(attempt_evidence)
  124. latest_assessment = self._assess_progress(plan, searched_queries, evidence)
  125. latest_data = latest_assessment.get("candidate_data")
  126. latest_validation = validate_research_result(
  127. latest_data,
  128. self.output_schema,
  129. self.constraints,
  130. evidence,
  131. )
  132. goal_achieved = bool(latest_assessment.get("goal_achieved")) and latest_validation["valid"]
  133. attempts.append(
  134. {
  135. "attempt": attempt_number,
  136. "query": query,
  137. "search_result_count": search_output.get("result_count", 0),
  138. "researched_count": search_output.get("researched_count", 0),
  139. "sources": sources_from_evidence(attempt_evidence),
  140. "assessment": {
  141. "goal_achieved": bool(latest_assessment.get("goal_achieved")),
  142. "confidence": latest_assessment.get("confidence"),
  143. "reason": latest_assessment.get("reason"),
  144. "missing_information": latest_assessment.get("missing_information") or [],
  145. },
  146. "validation": latest_validation,
  147. }
  148. )
  149. if goal_achieved:
  150. return self._build_output(
  151. plan,
  152. attempts,
  153. evidence,
  154. latest_data,
  155. latest_validation,
  156. latest_assessment,
  157. True,
  158. )
  159. pending_queries.extend(self._normalize_queries(latest_assessment.get("next_queries")))
  160. return self._build_output(
  161. plan,
  162. attempts,
  163. evidence,
  164. latest_data,
  165. latest_validation,
  166. latest_assessment,
  167. False,
  168. )
  169. def _text_json(self, prompt: str) -> dict[str, Any]:
  170. result = ai_service.chat(
  171. int(self.context.provider_id),
  172. int(self.context.model_id),
  173. prompt,
  174. self.context.temperature,
  175. )
  176. try:
  177. parsed = json.loads(ai_service.extract_json_text(result["content"]))
  178. except (json.JSONDecodeError, TypeError, ValueError) as exc:
  179. raise HTTPException(status_code=502, detail=f"AI 研究模型未返回有效 JSON: {exc}") from exc
  180. if not isinstance(parsed, dict):
  181. raise HTTPException(status_code=502, detail="AI 研究模型返回值必须是 JSON 对象")
  182. return parsed
  183. def _create_plan(self) -> dict[str, Any]:
  184. prompt = f"""请为一个使用真实浏览器和视觉截图的网页研究任务制定搜索计划。
  185. 研究目标:
  186. {self.objective}
  187. 最终输出 JSON Schema:
  188. {json.dumps(self.output_schema, ensure_ascii=False, indent=2)}
  189. 约束:
  190. {json.dumps(self.constraints, ensure_ascii=False, indent=2)}
  191. 最多尝试次数:{self.max_attempts}
  192. 请严格只输出 JSON:
  193. {{
  194. "summary": string,
  195. "acceptance_criteria": [string],
  196. "queries": [string],
  197. "source_preferences": [string],
  198. "risks": [string]
  199. }}
  200. queries 应按优先级排列,数量不超过最多尝试次数。"""
  201. return self._text_json(prompt)
  202. def _assess_progress(
  203. self,
  204. plan: dict[str, Any],
  205. searched_queries: list[str],
  206. evidence: list[dict[str, Any]],
  207. ) -> dict[str, Any]:
  208. prompt = f"""请评估网页研究任务是否已经达成,并生成符合指定 JSON Schema 的候选数据。
  209. 研究目标:
  210. {self.objective}
  211. 研究计划:
  212. {json.dumps(plan, ensure_ascii=False)}
  213. 输出 JSON Schema:
  214. {json.dumps(self.output_schema, ensure_ascii=False, indent=2)}
  215. 约束:
  216. {json.dumps(self.constraints, ensure_ascii=False)}
  217. 已搜索查询:
  218. {json.dumps(searched_queries, ensure_ascii=False)}
  219. 已获得证据:
  220. {json.dumps(evidence[-20:], ensure_ascii=False)}
  221. 判断规则:
  222. 1. 只有证据足以覆盖研究目标和计划中的验收标准时,goal_achieved 才能为 true。
  223. 2. candidate_data 必须严格匹配给定 JSON Schema,不要添加 Schema 未允许的包装字段。
  224. 3. 缺少信息时给出下一轮更精确、且与已搜索内容不同的查询词。
  225. 4. 不要把搜索摘要中的推测当作已验证事实。
  226. 严格只输出 JSON:
  227. {{
  228. "goal_achieved": boolean,
  229. "confidence": number,
  230. "reason": string,
  231. "missing_information": [string],
  232. "next_queries": [string],
  233. "candidate_data": object
  234. }}"""
  235. return self._text_json(prompt)
  236. def _next_query(
  237. self,
  238. pending_queries: list[str],
  239. searched_queries: list[str],
  240. assessment: dict[str, Any],
  241. ) -> str:
  242. searched = {item.strip().lower() for item in searched_queries}
  243. while pending_queries:
  244. query = pending_queries.pop(0).strip()
  245. if query and query.lower() not in searched:
  246. return query
  247. missing = assessment.get("missing_information") or []
  248. suffix = " ".join(str(item) for item in missing[:2])
  249. return f"{self.objective} {suffix} 补充资料 第{len(searched_queries) + 1}轮".strip()
  250. @staticmethod
  251. def _normalize_queries(value: Any) -> list[str]:
  252. if not isinstance(value, list):
  253. return []
  254. return [str(item).strip() for item in value if str(item).strip()]
  255. def _build_output(
  256. self,
  257. plan: dict[str, Any],
  258. attempts: list[dict[str, Any]],
  259. evidence: list[dict[str, Any]],
  260. data: Any,
  261. validation: dict[str, Any],
  262. assessment: dict[str, Any],
  263. goal_achieved: bool,
  264. ) -> dict[str, Any]:
  265. return {
  266. "status": "GOAL_ACHIEVED" if goal_achieved else "MAX_ATTEMPTS_REACHED",
  267. "goal_achieved": goal_achieved,
  268. "objective": self.objective,
  269. "attempts_used": len(attempts),
  270. "max_attempts": self.max_attempts,
  271. "data": data,
  272. "validation": validation,
  273. "assessment": {
  274. "confidence": assessment.get("confidence"),
  275. "reason": assessment.get("reason"),
  276. "missing_information": assessment.get("missing_information") or [],
  277. },
  278. "sources": sources_from_evidence(evidence),
  279. "plan": plan,
  280. "attempts": attempts,
  281. "next_port": "success" if goal_achieved else "partial",
  282. }
  283. def compact_evidence(search_output: dict[str, Any]) -> list[dict[str, Any]]:
  284. """只保留评估所需字段,控制多轮提示词长度。"""
  285. evidence: list[dict[str, Any]] = []
  286. for detail in search_output.get("researched_details") or []:
  287. if not isinstance(detail, dict):
  288. continue
  289. result = detail.get("result") if isinstance(detail.get("result"), dict) else {}
  290. cleaned = detail.get("cleaned") if isinstance(detail.get("cleaned"), dict) else {}
  291. evidence.append(
  292. {
  293. "title": cleaned.get("clean_title") or result.get("title"),
  294. "url": detail.get("visited_url") or result.get("url"),
  295. "text": cleaned.get("clean_text") or detail.get("error") or "",
  296. "key_points": cleaned.get("key_points") or [],
  297. "opened_detail_page": bool(detail.get("opened_detail_page")),
  298. }
  299. )
  300. return evidence
  301. def sources_from_evidence(evidence: list[dict[str, Any]]) -> list[dict[str, str]]:
  302. sources: list[dict[str, str]] = []
  303. seen: set[str] = set()
  304. for item in evidence:
  305. url = str(item.get("url") or "").strip()
  306. if not url or url in seen:
  307. continue
  308. seen.add(url)
  309. sources.append({"title": str(item.get("title") or url), "url": url})
  310. return sources
  311. def ai_web_research_node(node: dict[str, Any], inputs: dict[str, Any], context: WorkflowContext) -> dict[str, Any]:
  312. params = {**(node.get("params") or {}), **inputs}
  313. return AiWebResearchRunner(context, params).run()
  314. register_node(
  315. {
  316. "type": "research.ai_web_research",
  317. "category": "research",
  318. "label": "AI 多轮网页研究",
  319. "params": {
  320. "objective": field_def("textarea", "研究目标", required=True),
  321. "output_schema": field_def("textarea", "返回 JSON Schema", required=True),
  322. "constraints": field_def("textarea", "研究约束", "{}"),
  323. "max_attempts": field_def("number", "最多尝试次数", 3, minimum=1, maximum=10),
  324. "search_engine": field_def("select", "搜索引擎", "bing", options=["google", "bing"]),
  325. "browser": field_def("select", "浏览器", "edge", options=["default", "edge"]),
  326. "max_search_pages": field_def("number", "每轮搜索页屏", 2, minimum=1, maximum=10),
  327. "result_count": field_def("number", "每轮研究结果数", 2, minimum=1, maximum=5),
  328. "detail_max_pages": field_def("number", "每个详情页屏", 2, minimum=1, maximum=10),
  329. "click_attempts": field_def("number", "标题点击重试", 2, minimum=1, maximum=5),
  330. "maximize_browser": field_def("boolean", "打开后最大化浏览器", True),
  331. "page_load_wait_seconds": field_def("number", "页面加载等待秒数", 8, minimum=0, maximum=120),
  332. "action_wait_seconds": field_def("number", "操作等待秒数", 1, minimum=0, maximum=30),
  333. "wait_jitter_min_seconds": field_def("number", "等待抖动最小秒数", 0, minimum=0, maximum=30),
  334. "wait_jitter_max_seconds": field_def("number", "等待抖动最大秒数", 0, minimum=0, maximum=30),
  335. },
  336. "inputs": {
  337. "objective": field_def("string", "研究目标"),
  338. "output_schema": field_def("object", "返回 JSON Schema"),
  339. "constraints": field_def("object", "研究约束"),
  340. "max_attempts": field_def("number", "最多尝试次数"),
  341. },
  342. "outputs": {
  343. "status": {"type": "string", "label": "研究状态"},
  344. "goal_achieved": {"type": "boolean", "label": "是否达成目标"},
  345. "data": {"type": "object", "label": "符合 Schema 的数据"},
  346. "validation": {"type": "object", "label": "Schema 校验结果"},
  347. "assessment": {"type": "object", "label": "目标评估"},
  348. "sources": {"type": "array", "label": "来源"},
  349. "attempts": {"type": "array", "label": "尝试记录"},
  350. },
  351. "control_ports": control_ports(["success", "partial", "failure"]),
  352. },
  353. ai_web_research_node,
  354. )