research.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. from __future__ import annotations
  2. import json
  3. from typing import Any
  4. from urllib.parse import urlparse
  5. from fastapi import HTTPException
  6. from ... import ai_service
  7. from ..context import WorkflowContext
  8. from ..registry import control_ports, field_def, register_node
  9. from .web_search import WebSearchRunner, _integer
  10. def parse_object(value: Any, field_name: str) -> dict[str, Any]:
  11. """兼容 API 直接传对象和编辑器中保存的 JSON 文本。"""
  12. if isinstance(value, dict):
  13. return value
  14. if isinstance(value, str) and value.strip():
  15. try:
  16. parsed = json.loads(value)
  17. except json.JSONDecodeError as exc:
  18. raise HTTPException(status_code=400, detail=f"{field_name} 不是有效 JSON 对象: {exc}") from exc
  19. if isinstance(parsed, dict):
  20. return parsed
  21. raise HTTPException(status_code=400, detail=f"{field_name} 必须是 JSON 对象")
  22. def validate_json_data(data: Any, schema: dict[str, Any]) -> dict[str, Any]:
  23. """使用标准 JSON Schema 校验最终返回数据。"""
  24. try:
  25. from jsonschema import Draft202012Validator
  26. except ImportError as exc:
  27. raise HTTPException(status_code=500, detail="jsonschema is not installed") from exc
  28. try:
  29. Draft202012Validator.check_schema(schema)
  30. except Exception as exc:
  31. raise HTTPException(status_code=400, detail=f"output_schema 无效: {exc}") from exc
  32. validator = Draft202012Validator(schema)
  33. errors = sorted(validator.iter_errors(data), key=lambda item: list(item.absolute_path))
  34. return {
  35. "schema_valid": not errors,
  36. "errors": [
  37. {
  38. "path": ".".join(str(part) for part in error.absolute_path),
  39. "message": error.message,
  40. }
  41. for error in errors
  42. ],
  43. }
  44. def validate_research_result(
  45. data: Any,
  46. schema: dict[str, Any],
  47. constraints: dict[str, Any],
  48. evidence: list[dict[str, Any]],
  49. ) -> dict[str, Any]:
  50. """组合 JSON Schema 与来源数量、必需域名等确定性约束。"""
  51. result = validate_json_data(data, schema)
  52. constraint_errors: list[str] = []
  53. sources = sources_from_evidence(evidence)
  54. try:
  55. min_sources = max(0, int(constraints.get("min_sources", 0)))
  56. except (TypeError, ValueError):
  57. min_sources = 0
  58. if len(sources) < min_sources:
  59. constraint_errors.append(f"来源数量 {len(sources)} 少于要求的 {min_sources}")
  60. required_domains = constraints.get("required_domains") or []
  61. if isinstance(required_domains, list):
  62. source_hosts = {urlparse(item["url"]).netloc.lower() for item in sources}
  63. for domain in required_domains:
  64. normalized = str(domain).lower().strip()
  65. if normalized and not any(host == normalized or host.endswith(f".{normalized}") for host in source_hosts):
  66. constraint_errors.append(f"缺少必需来源域名: {normalized}")
  67. result["constraints_valid"] = not constraint_errors
  68. result["constraint_errors"] = constraint_errors
  69. result["valid"] = result["schema_valid"] and result["constraints_valid"]
  70. return result
  71. class AiWebResearchRunner:
  72. """AI 驱动的多轮视觉网页研究状态机。"""
  73. def __init__(self, context: WorkflowContext, params: dict[str, Any]) -> None:
  74. if not context.provider_id or not context.model_id:
  75. raise HTTPException(status_code=400, detail="AI 搜索研究节点需要配置默认 AI 服务商和模型")
  76. self.context = context
  77. self.params = params
  78. self.objective = str(params.get("objective") or "").strip()
  79. if not self.objective:
  80. raise HTTPException(status_code=400, detail="研究目标不能为空")
  81. self.output_schema = parse_object(params.get("output_schema"), "output_schema")
  82. self.constraints = parse_object(params.get("constraints") or {}, "constraints")
  83. self.max_attempts = _integer(params.get("max_attempts"), 3, 1, 10)
  84. self.search_engine = str(params.get("search_engine") or "bing")
  85. self.browser = str(params.get("browser") or "edge")
  86. self.max_search_pages = _integer(params.get("max_search_pages"), 2, 1, 10)
  87. self.result_count = _integer(params.get("result_count"), 2, 1, 5)
  88. self.detail_max_pages = _integer(params.get("detail_max_pages"), 2, 1, 10)
  89. def run(self) -> dict[str, Any]:
  90. plan = self._create_plan()
  91. pending_queries = self._normalize_queries(plan.get("queries"))
  92. if not pending_queries:
  93. pending_queries = [self.objective]
  94. searched_queries: list[str] = []
  95. evidence: list[dict[str, Any]] = []
  96. attempts: list[dict[str, Any]] = []
  97. latest_assessment: dict[str, Any] = {}
  98. latest_data: Any = {}
  99. latest_validation = validate_research_result(latest_data, self.output_schema, self.constraints, evidence)
  100. for attempt_number in range(1, self.max_attempts + 1):
  101. query = self._next_query(pending_queries, searched_queries, latest_assessment)
  102. searched_queries.append(query)
  103. search_output = WebSearchRunner(
  104. self.context,
  105. {
  106. "query": query,
  107. "search_engine": self.search_engine,
  108. "browser": self.browser,
  109. "max_search_pages": self.max_search_pages,
  110. "result_count": self.result_count,
  111. "detail_max_pages": self.detail_max_pages,
  112. "click_attempts": self.params.get("click_attempts", 2),
  113. "page_load_wait_seconds": self.params.get("page_load_wait_seconds", 8),
  114. "action_wait_seconds": self.params.get("action_wait_seconds", 1),
  115. "close_browser": True,
  116. "include_debug_analyses": False,
  117. },
  118. ).run()
  119. attempt_evidence = compact_evidence(search_output)
  120. evidence.extend(attempt_evidence)
  121. latest_assessment = self._assess_progress(plan, searched_queries, evidence)
  122. latest_data = latest_assessment.get("candidate_data")
  123. latest_validation = validate_research_result(
  124. latest_data,
  125. self.output_schema,
  126. self.constraints,
  127. evidence,
  128. )
  129. goal_achieved = bool(latest_assessment.get("goal_achieved")) and latest_validation["valid"]
  130. attempts.append(
  131. {
  132. "attempt": attempt_number,
  133. "query": query,
  134. "search_result_count": search_output.get("result_count", 0),
  135. "researched_count": search_output.get("researched_count", 0),
  136. "sources": sources_from_evidence(attempt_evidence),
  137. "assessment": {
  138. "goal_achieved": bool(latest_assessment.get("goal_achieved")),
  139. "confidence": latest_assessment.get("confidence"),
  140. "reason": latest_assessment.get("reason"),
  141. "missing_information": latest_assessment.get("missing_information") or [],
  142. },
  143. "validation": latest_validation,
  144. }
  145. )
  146. if goal_achieved:
  147. return self._build_output(
  148. plan,
  149. attempts,
  150. evidence,
  151. latest_data,
  152. latest_validation,
  153. latest_assessment,
  154. True,
  155. )
  156. pending_queries.extend(self._normalize_queries(latest_assessment.get("next_queries")))
  157. return self._build_output(
  158. plan,
  159. attempts,
  160. evidence,
  161. latest_data,
  162. latest_validation,
  163. latest_assessment,
  164. False,
  165. )
  166. def _text_json(self, prompt: str) -> dict[str, Any]:
  167. result = ai_service.chat(
  168. int(self.context.provider_id),
  169. int(self.context.model_id),
  170. prompt,
  171. self.context.temperature,
  172. )
  173. try:
  174. parsed = json.loads(ai_service.extract_json_text(result["content"]))
  175. except (json.JSONDecodeError, TypeError, ValueError) as exc:
  176. raise HTTPException(status_code=502, detail=f"AI 研究模型未返回有效 JSON: {exc}") from exc
  177. if not isinstance(parsed, dict):
  178. raise HTTPException(status_code=502, detail="AI 研究模型返回值必须是 JSON 对象")
  179. return parsed
  180. def _create_plan(self) -> dict[str, Any]:
  181. prompt = f"""请为一个使用真实浏览器和视觉截图的网页研究任务制定搜索计划。
  182. 研究目标:
  183. {self.objective}
  184. 最终输出 JSON Schema:
  185. {json.dumps(self.output_schema, ensure_ascii=False, indent=2)}
  186. 约束:
  187. {json.dumps(self.constraints, ensure_ascii=False, indent=2)}
  188. 最多尝试次数:{self.max_attempts}
  189. 请严格只输出 JSON:
  190. {{
  191. "summary": string,
  192. "acceptance_criteria": [string],
  193. "queries": [string],
  194. "source_preferences": [string],
  195. "risks": [string]
  196. }}
  197. queries 应按优先级排列,数量不超过最多尝试次数。"""
  198. return self._text_json(prompt)
  199. def _assess_progress(
  200. self,
  201. plan: dict[str, Any],
  202. searched_queries: list[str],
  203. evidence: list[dict[str, Any]],
  204. ) -> dict[str, Any]:
  205. prompt = f"""请评估网页研究任务是否已经达成,并生成符合指定 JSON Schema 的候选数据。
  206. 研究目标:
  207. {self.objective}
  208. 研究计划:
  209. {json.dumps(plan, ensure_ascii=False)}
  210. 输出 JSON Schema:
  211. {json.dumps(self.output_schema, ensure_ascii=False, indent=2)}
  212. 约束:
  213. {json.dumps(self.constraints, ensure_ascii=False)}
  214. 已搜索查询:
  215. {json.dumps(searched_queries, ensure_ascii=False)}
  216. 已获得证据:
  217. {json.dumps(evidence[-20:], ensure_ascii=False)}
  218. 判断规则:
  219. 1. 只有证据足以覆盖研究目标和计划中的验收标准时,goal_achieved 才能为 true。
  220. 2. candidate_data 必须严格匹配给定 JSON Schema,不要添加 Schema 未允许的包装字段。
  221. 3. 缺少信息时给出下一轮更精确、且与已搜索内容不同的查询词。
  222. 4. 不要把搜索摘要中的推测当作已验证事实。
  223. 严格只输出 JSON:
  224. {{
  225. "goal_achieved": boolean,
  226. "confidence": number,
  227. "reason": string,
  228. "missing_information": [string],
  229. "next_queries": [string],
  230. "candidate_data": object
  231. }}"""
  232. return self._text_json(prompt)
  233. def _next_query(
  234. self,
  235. pending_queries: list[str],
  236. searched_queries: list[str],
  237. assessment: dict[str, Any],
  238. ) -> str:
  239. searched = {item.strip().lower() for item in searched_queries}
  240. while pending_queries:
  241. query = pending_queries.pop(0).strip()
  242. if query and query.lower() not in searched:
  243. return query
  244. missing = assessment.get("missing_information") or []
  245. suffix = " ".join(str(item) for item in missing[:2])
  246. return f"{self.objective} {suffix} 补充资料 第{len(searched_queries) + 1}轮".strip()
  247. @staticmethod
  248. def _normalize_queries(value: Any) -> list[str]:
  249. if not isinstance(value, list):
  250. return []
  251. return [str(item).strip() for item in value if str(item).strip()]
  252. def _build_output(
  253. self,
  254. plan: dict[str, Any],
  255. attempts: list[dict[str, Any]],
  256. evidence: list[dict[str, Any]],
  257. data: Any,
  258. validation: dict[str, Any],
  259. assessment: dict[str, Any],
  260. goal_achieved: bool,
  261. ) -> dict[str, Any]:
  262. return {
  263. "status": "GOAL_ACHIEVED" if goal_achieved else "MAX_ATTEMPTS_REACHED",
  264. "goal_achieved": goal_achieved,
  265. "objective": self.objective,
  266. "attempts_used": len(attempts),
  267. "max_attempts": self.max_attempts,
  268. "data": data,
  269. "validation": validation,
  270. "assessment": {
  271. "confidence": assessment.get("confidence"),
  272. "reason": assessment.get("reason"),
  273. "missing_information": assessment.get("missing_information") or [],
  274. },
  275. "sources": sources_from_evidence(evidence),
  276. "plan": plan,
  277. "attempts": attempts,
  278. "next_port": "success" if goal_achieved else "partial",
  279. }
  280. def compact_evidence(search_output: dict[str, Any]) -> list[dict[str, Any]]:
  281. """只保留评估所需字段,控制多轮提示词长度。"""
  282. evidence: list[dict[str, Any]] = []
  283. for detail in search_output.get("researched_details") or []:
  284. if not isinstance(detail, dict):
  285. continue
  286. result = detail.get("result") if isinstance(detail.get("result"), dict) else {}
  287. cleaned = detail.get("cleaned") if isinstance(detail.get("cleaned"), dict) else {}
  288. evidence.append(
  289. {
  290. "title": cleaned.get("clean_title") or result.get("title"),
  291. "url": detail.get("visited_url") or result.get("url"),
  292. "text": cleaned.get("clean_text") or detail.get("error") or "",
  293. "key_points": cleaned.get("key_points") or [],
  294. "opened_detail_page": bool(detail.get("opened_detail_page")),
  295. }
  296. )
  297. return evidence
  298. def sources_from_evidence(evidence: list[dict[str, Any]]) -> list[dict[str, str]]:
  299. sources: list[dict[str, str]] = []
  300. seen: set[str] = set()
  301. for item in evidence:
  302. url = str(item.get("url") or "").strip()
  303. if not url or url in seen:
  304. continue
  305. seen.add(url)
  306. sources.append({"title": str(item.get("title") or url), "url": url})
  307. return sources
  308. def ai_web_research_node(node: dict[str, Any], inputs: dict[str, Any], context: WorkflowContext) -> dict[str, Any]:
  309. params = {**(node.get("params") or {}), **inputs}
  310. return AiWebResearchRunner(context, params).run()
# Register the AI multi-round web research node: its editor parameters, data
# inputs/outputs, control ports, and the handler that executes it.
register_node(
    {
        "type": "research.ai_web_research",
        "category": "research",
        "label": "AI 多轮网页研究",
        # Editor-configurable parameters. The numeric defaults and bounds here
        # match the values passed to _integer in AiWebResearchRunner.__init__.
        # NOTE(review): the third positional argument of field_def looks like
        # the default value - confirm against registry.field_def.
        "params": {
            "objective": field_def("textarea", "研究目标", required=True),
            "output_schema": field_def("textarea", "返回 JSON Schema", required=True),
            "constraints": field_def("textarea", "研究约束", "{}"),
            "max_attempts": field_def("number", "最多尝试次数", 3, minimum=1, maximum=10),
            "search_engine": field_def("select", "搜索引擎", "bing", options=["google", "bing"]),
            "browser": field_def("select", "浏览器", "edge", options=["default", "edge"]),
            "max_search_pages": field_def("number", "每轮搜索页屏", 2, minimum=1, maximum=10),
            "result_count": field_def("number", "每轮研究结果数", 2, minimum=1, maximum=5),
            "detail_max_pages": field_def("number", "每个详情页屏", 2, minimum=1, maximum=10),
        },
        # Data inputs share their keys with params; ai_web_research_node
        # merges inputs over params, so a supplied input overrides the
        # stored parameter of the same name.
        "inputs": {
            "objective": field_def("string", "研究目标"),
            "output_schema": field_def("object", "返回 JSON Schema"),
            "constraints": field_def("object", "研究约束"),
            "max_attempts": field_def("number", "最多尝试次数"),
        },
        # A subset of the keys returned by AiWebResearchRunner._build_output.
        "outputs": {
            "status": {"type": "string", "label": "研究状态"},
            "goal_achieved": {"type": "boolean", "label": "是否达成目标"},
            "data": {"type": "object", "label": "符合 Schema 的数据"},
            "validation": {"type": "object", "label": "Schema 校验结果"},
            "assessment": {"type": "object", "label": "目标评估"},
            "sources": {"type": "array", "label": "来源"},
            "attempts": {"type": "array", "label": "尝试记录"},
        },
        # The runner only ever sets next_port to "success" or "partial".
        # NOTE(review): "failure" is presumably taken by the framework when
        # the handler raises - confirm against the workflow engine.
        "control_ports": control_ports(["success", "partial", "failure"]),
    },
    ai_web_research_node,
)