| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402 |
- from __future__ import annotations
- import json
- from typing import Any
- from urllib.parse import urlparse
- from fastapi import HTTPException
- from ... import ai_service
- from ..context import WorkflowContext
- from ..registry import control_ports, field_def, register_node
- from .web_search import WebSearchRunner, _integer
def parse_object(value: Any, field_name: str) -> dict[str, Any]:
    """Return *value* as a dict, accepting either a dict or its JSON text.

    API callers pass objects directly while the editor stores JSON strings,
    so both forms are accepted. A dict is returned as-is (same object).

    Raises:
        HTTPException: 400 when the text is not valid JSON, when it decodes to
            something other than an object, or when *value* is neither a dict
            nor a non-blank string.
    """
    if isinstance(value, dict):
        return value
    if not isinstance(value, str) or not value.strip():
        raise HTTPException(status_code=400, detail=f"{field_name} 必须是 JSON 对象")
    try:
        decoded = json.loads(value)
    except json.JSONDecodeError as err:
        raise HTTPException(status_code=400, detail=f"{field_name} 不是有效 JSON 对象: {err}") from err
    if not isinstance(decoded, dict):
        raise HTTPException(status_code=400, detail=f"{field_name} 必须是 JSON 对象")
    return decoded
def validate_json_data(data: Any, schema: dict[str, Any]) -> dict[str, Any]:
    """Validate *data* against *schema* using JSON Schema draft 2020-12.

    Returns a dict with ``schema_valid`` (bool) and ``errors``: a list of
    ``{"path", "message"}`` entries sorted by instance path, where ``path``
    is the dot-joined instance path ("" for the root).

    Raises:
        HTTPException: 500 if the ``jsonschema`` package is unavailable;
            400 if *schema* itself is not a valid schema.
    """
    try:
        from jsonschema import Draft202012Validator
    except ImportError as exc:
        raise HTTPException(status_code=500, detail="jsonschema is not installed") from exc
    try:
        Draft202012Validator.check_schema(schema)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"output_schema 无效: {exc}") from exc

    def path_key(err: Any) -> list[Any]:
        return list(err.absolute_path)

    ordered = sorted(Draft202012Validator(schema).iter_errors(data), key=path_key)
    error_list = [
        {"path": ".".join(map(str, err.absolute_path)), "message": err.message}
        for err in ordered
    ]
    return {"schema_valid": len(error_list) == 0, "errors": error_list}
def validate_research_result(
    data: Any,
    schema: dict[str, Any],
    constraints: dict[str, Any],
    evidence: list[dict[str, Any]],
) -> dict[str, Any]:
    """Combine JSON Schema validation with deterministic source constraints.

    Supported ``constraints`` keys:
        min_sources: minimum number of distinct source URLs in *evidence*;
            values ``int()`` cannot convert are treated as 0.
        required_domains: one domain string or a list of them; each must match
            at least one source host exactly or as a parent domain
            (``example.com`` matches ``www.example.com``).

    Returns:
        The ``validate_json_data`` result extended with ``constraints_valid``,
        ``constraint_errors`` (list of messages) and the combined ``valid``.

    Raises:
        HTTPException: propagated from ``validate_json_data``.
    """
    result = validate_json_data(data, schema)
    constraint_errors: list[str] = []
    sources = sources_from_evidence(evidence)

    try:
        min_sources = max(0, int(constraints.get("min_sources", 0)))
    except (TypeError, ValueError):
        min_sources = 0
    if len(sources) < min_sources:
        constraint_errors.append(f"来源数量 {len(sources)} 少于要求的 {min_sources}")

    required_domains = constraints.get("required_domains") or []
    if isinstance(required_domains, str):
        # Accept a single domain written as a plain string instead of ignoring it.
        required_domains = [required_domains]
    if isinstance(required_domains, list):
        # .hostname (unlike .netloc) drops any "user@" prefix and ":port"
        # suffix and is already lower-cased, so "https://www.example.com:443/"
        # still matches "example.com". URLs without a host yield "".
        source_hosts = {urlparse(item["url"]).hostname or "" for item in sources}
        for domain in required_domains:
            # A leading dot (".gov.cn") would otherwise never match any host.
            normalized = str(domain).lower().strip().lstrip(".")
            if normalized and not any(
                host == normalized or host.endswith(f".{normalized}") for host in source_hosts
            ):
                constraint_errors.append(f"缺少必需来源域名: {normalized}")

    result["constraints_valid"] = not constraint_errors
    result["constraint_errors"] = constraint_errors
    result["valid"] = result["schema_valid"] and result["constraints_valid"]
    return result
class AiWebResearchRunner:
    """AI-driven, multi-round visual web research state machine.

    Each round runs one real-browser search (``WebSearchRunner``), appends
    compacted evidence, asks the configured AI model to assess progress and
    draft ``candidate_data``, then validates that data against
    ``output_schema`` and ``constraints``. The loop stops when the model
    reports the goal achieved AND validation passes, or after
    ``max_attempts`` rounds.
    """

    def __init__(self, context: WorkflowContext, params: dict[str, Any]) -> None:
        """Read node parameters.

        Raises:
            HTTPException: 400 when the default AI provider/model is missing,
                the objective is blank, or ``output_schema``/``constraints``
                are not JSON objects (see ``parse_object``).
        """
        if not context.provider_id or not context.model_id:
            raise HTTPException(status_code=400, detail="AI 搜索研究节点需要配置默认 AI 服务商和模型")
        self.context = context
        self.params = params
        self.objective = str(params.get("objective") or "").strip()
        if not self.objective:
            raise HTTPException(status_code=400, detail="研究目标不能为空")
        self.output_schema = parse_object(params.get("output_schema"), "output_schema")
        self.constraints = parse_object(params.get("constraints") or {}, "constraints")
        # _integer(value, default, low, high) comes from web_search; presumably it
        # falls back to the default and clamps into [low, high] — TODO confirm.
        self.max_attempts = _integer(params.get("max_attempts"), 3, 1, 10)
        self.search_engine = str(params.get("search_engine") or "bing")
        self.browser = str(params.get("browser") or "edge")
        self.max_search_pages = _integer(params.get("max_search_pages"), 2, 1, 10)
        self.result_count = _integer(params.get("result_count"), 2, 1, 5)
        self.detail_max_pages = _integer(params.get("detail_max_pages"), 2, 1, 10)

    def run(self) -> dict[str, Any]:
        """Execute the research loop and return the node output (see _build_output)."""
        searched_queries: list[str] = []
        evidence: list[dict[str, Any]] = []
        attempts: list[dict[str, Any]] = []
        latest_assessment: dict[str, Any] = {}
        latest_data: Any = {}
        # Validating once up front also checks that output_schema is a valid
        # JSON Schema, so a bad schema fails before any AI call is spent.
        latest_validation = validate_research_result(latest_data, self.output_schema, self.constraints, evidence)

        plan = self._create_plan()
        pending_queries = self._normalize_queries(plan.get("queries")) or [self.objective]

        for attempt_number in range(1, self.max_attempts + 1):
            query = self._next_query(pending_queries, searched_queries, latest_assessment)
            searched_queries.append(query)
            search_output = self._search(query)
            attempt_evidence = compact_evidence(search_output)
            evidence.extend(attempt_evidence)

            latest_assessment = self._assess_progress(plan, searched_queries, evidence)
            latest_data = latest_assessment.get("candidate_data")
            latest_validation = validate_research_result(
                latest_data,
                self.output_schema,
                self.constraints,
                evidence,
            )
            # The model's own verdict is not trusted alone; deterministic
            # schema/constraint validation must pass as well.
            goal_achieved = bool(latest_assessment.get("goal_achieved")) and latest_validation["valid"]
            attempts.append(
                self._attempt_record(
                    attempt_number,
                    query,
                    search_output,
                    attempt_evidence,
                    latest_assessment,
                    latest_validation,
                )
            )
            if goal_achieved:
                return self._build_output(
                    plan,
                    attempts,
                    evidence,
                    latest_data,
                    latest_validation,
                    latest_assessment,
                    True,
                )
            pending_queries.extend(self._normalize_queries(latest_assessment.get("next_queries")))

        return self._build_output(
            plan,
            attempts,
            evidence,
            latest_data,
            latest_validation,
            latest_assessment,
            False,
        )

    def _search(self, query: str) -> dict[str, Any]:
        """Run one browser search round for *query* and return WebSearchRunner's output."""
        return WebSearchRunner(
            self.context,
            {
                "query": query,
                "search_engine": self.search_engine,
                "browser": self.browser,
                "max_search_pages": self.max_search_pages,
                "result_count": self.result_count,
                "detail_max_pages": self.detail_max_pages,
                "click_attempts": self.params.get("click_attempts", 2),
                "maximize_browser": self.params.get("maximize_browser", True),
                "page_load_wait_seconds": self.params.get("page_load_wait_seconds", 8),
                "action_wait_seconds": self.params.get("action_wait_seconds", 1),
                "wait_jitter_min_seconds": self.params.get("wait_jitter_min_seconds", 0),
                "wait_jitter_max_seconds": self.params.get("wait_jitter_max_seconds", 0),
                "close_browser": True,
                "include_debug_analyses": False,
            },
        ).run()

    @staticmethod
    def _attempt_record(
        attempt_number: int,
        query: str,
        search_output: dict[str, Any],
        attempt_evidence: list[dict[str, Any]],
        assessment: dict[str, Any],
        validation: dict[str, Any],
    ) -> dict[str, Any]:
        """Summarize one research round for the ``attempts`` output list."""
        return {
            "attempt": attempt_number,
            "query": query,
            "search_result_count": search_output.get("result_count", 0),
            "researched_count": search_output.get("researched_count", 0),
            "sources": sources_from_evidence(attempt_evidence),
            "assessment": {
                # The model's own claim, before deterministic validation.
                "goal_achieved": bool(assessment.get("goal_achieved")),
                "confidence": assessment.get("confidence"),
                "reason": assessment.get("reason"),
                "missing_information": assessment.get("missing_information") or [],
            },
            "validation": validation,
        }

    def _text_json(self, prompt: str) -> dict[str, Any]:
        """Send *prompt* to the configured model and parse its reply as a JSON object.

        Raises:
            HTTPException: 502 when the reply lacks a ``content`` entry, is not
                valid JSON, or is JSON but not an object.
        """
        result = ai_service.chat(
            int(self.context.provider_id),
            int(self.context.model_id),
            prompt,
            self.context.temperature,
        )
        try:
            parsed = json.loads(ai_service.extract_json_text(result["content"]))
        except (json.JSONDecodeError, KeyError, TypeError, ValueError) as exc:
            raise HTTPException(status_code=502, detail=f"AI 研究模型未返回有效 JSON: {exc}") from exc
        if not isinstance(parsed, dict):
            raise HTTPException(status_code=502, detail="AI 研究模型返回值必须是 JSON 对象")
        return parsed

    def _create_plan(self) -> dict[str, Any]:
        """Ask the model for a search plan (summary, acceptance criteria, queries, ...)."""
        prompt = f"""请为一个使用真实浏览器和视觉截图的网页研究任务制定搜索计划。
研究目标:
{self.objective}
最终输出 JSON Schema:
{json.dumps(self.output_schema, ensure_ascii=False, indent=2)}
约束:
{json.dumps(self.constraints, ensure_ascii=False, indent=2)}
最多尝试次数:{self.max_attempts}
请严格只输出 JSON:
{{
"summary": string,
"acceptance_criteria": [string],
"queries": [string],
"source_preferences": [string],
"risks": [string]
}}
queries 应按优先级排列,数量不超过最多尝试次数。"""
        return self._text_json(prompt)

    def _assess_progress(
        self,
        plan: dict[str, Any],
        searched_queries: list[str],
        evidence: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Ask the model whether the goal is met and to draft schema-shaped candidate data.

        Only the last 20 evidence entries are included to bound prompt size.
        """
        prompt = f"""请评估网页研究任务是否已经达成,并生成符合指定 JSON Schema 的候选数据。
研究目标:
{self.objective}
研究计划:
{json.dumps(plan, ensure_ascii=False)}
输出 JSON Schema:
{json.dumps(self.output_schema, ensure_ascii=False, indent=2)}
约束:
{json.dumps(self.constraints, ensure_ascii=False)}
已搜索查询:
{json.dumps(searched_queries, ensure_ascii=False)}
已获得证据:
{json.dumps(evidence[-20:], ensure_ascii=False)}
判断规则:
1. 只有证据足以覆盖研究目标和计划中的验收标准时,goal_achieved 才能为 true。
2. candidate_data 必须严格匹配给定 JSON Schema,不要添加 Schema 未允许的包装字段。
3. 缺少信息时给出下一轮更精确、且与已搜索内容不同的查询词。
4. 不要把搜索摘要中的推测当作已验证事实。
严格只输出 JSON:
{{
"goal_achieved": boolean,
"confidence": number,
"reason": string,
"missing_information": [string],
"next_queries": [string],
"candidate_data": object
}}"""
        return self._text_json(prompt)

    def _next_query(
        self,
        pending_queries: list[str],
        searched_queries: list[str],
        assessment: dict[str, Any],
    ) -> str:
        """Pop the next not-yet-searched pending query, or synthesize a fallback.

        Consumes ``pending_queries`` from the front (mutates it). Duplicate
        detection is case-insensitive and ignores surrounding whitespace.
        """
        searched = {item.strip().lower() for item in searched_queries}
        while pending_queries:
            query = pending_queries.pop(0).strip()
            if query and query.lower() not in searched:
                return query
        # Fallback: widen the objective with up to two missing-information hints.
        # Normalizing first means a bare string is one hint, not sliced into chars.
        missing = self._normalize_queries(assessment.get("missing_information"))
        parts = [self.objective, *missing[:2], "补充资料", f"第{len(searched_queries) + 1}轮"]
        return " ".join(part for part in parts if part)

    @staticmethod
    def _normalize_queries(value: Any) -> list[str]:
        """Coerce an AI-provided list (or a single string) into non-empty stripped strings."""
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, list):
            return []
        return [text for text in (str(item).strip() for item in value) if text]

    def _build_output(
        self,
        plan: dict[str, Any],
        attempts: list[dict[str, Any]],
        evidence: list[dict[str, Any]],
        data: Any,
        validation: dict[str, Any],
        assessment: dict[str, Any],
        goal_achieved: bool,
    ) -> dict[str, Any]:
        """Assemble the node output; ``next_port`` is "success" or "partial"."""
        return {
            "status": "GOAL_ACHIEVED" if goal_achieved else "MAX_ATTEMPTS_REACHED",
            "goal_achieved": goal_achieved,
            "objective": self.objective,
            "attempts_used": len(attempts),
            "max_attempts": self.max_attempts,
            "data": data,
            "validation": validation,
            "assessment": {
                "confidence": assessment.get("confidence"),
                "reason": assessment.get("reason"),
                "missing_information": assessment.get("missing_information") or [],
            },
            "sources": sources_from_evidence(evidence),
            "plan": plan,
            "attempts": attempts,
            "next_port": "success" if goal_achieved else "partial",
        }
def compact_evidence(search_output: dict[str, Any]) -> list[dict[str, Any]]:
    """Reduce each researched detail to the fields the assessment prompt needs.

    Keeping only these fields limits prompt growth across rounds. Non-dict
    entries in ``researched_details`` are skipped. The text falls back to the
    detail's ``error`` message, then to "".
    """

    def as_dict(candidate: Any) -> dict[str, Any]:
        return candidate if isinstance(candidate, dict) else {}

    def summarize(entry: dict[str, Any]) -> dict[str, Any]:
        hit = as_dict(entry.get("result"))
        clean = as_dict(entry.get("cleaned"))
        return {
            "title": cleaned_or(clean.get("clean_title"), hit.get("title")),
            "url": cleaned_or(entry.get("visited_url"), hit.get("url")),
            "text": clean.get("clean_text") or entry.get("error") or "",
            "key_points": clean.get("key_points") or [],
            "opened_detail_page": bool(entry.get("opened_detail_page")),
        }

    def cleaned_or(preferred: Any, fallback: Any) -> Any:
        return preferred or fallback

    details = search_output.get("researched_details") or []
    return [summarize(entry) for entry in details if isinstance(entry, dict)]
def sources_from_evidence(evidence: list[dict[str, Any]]) -> list[dict[str, str]]:
    """Return ``{"title", "url"}`` for each distinct non-blank URL, first-seen order.

    URLs are whitespace-stripped before deduplication. The title of the first
    occurrence wins and falls back to the URL itself.
    """
    by_url: dict[str, dict[str, str]] = {}
    for entry in evidence:
        link = str(entry.get("url") or "").strip()
        if link and link not in by_url:
            by_url[link] = {"title": str(entry.get("title") or link), "url": link}
    # dicts preserve insertion order, so this keeps first-seen ordering.
    return list(by_url.values())
def ai_web_research_node(node: dict[str, Any], inputs: dict[str, Any], context: WorkflowContext) -> dict[str, Any]:
    """Workflow handler: run the research with node params overridden by runtime inputs."""
    merged = dict(node.get("params") or {})
    # Runtime input values take precedence over the editor-configured params.
    merged.update(inputs)
    runner = AiWebResearchRunner(context, merged)
    return runner.run()
# Register the "research.ai_web_research" node type at import time. The dict is
# the node's declarative spec: editor params, input ports, declared output
# fields, and control ports. ai_web_research_node is the handler invoked for it.
register_node(
    {
        "type": "research.ai_web_research",
        "category": "research",
        "label": "AI 多轮网页研究",
        # Editor form fields. field_def(type, label, [default], ...): the third
        # positional argument appears to be the default value — confirm in registry.
        "params": {
            "objective": field_def("textarea", "研究目标", required=True),
            # Edited as JSON text; parse_object also accepts an already-parsed dict.
            "output_schema": field_def("textarea", "返回 JSON Schema", required=True),
            "constraints": field_def("textarea", "研究约束", "{}"),
            "max_attempts": field_def("number", "最多尝试次数", 3, minimum=1, maximum=10),
            "search_engine": field_def("select", "搜索引擎", "bing", options=["google", "bing"]),
            "browser": field_def("select", "浏览器", "edge", options=["default", "edge"]),
            "max_search_pages": field_def("number", "每轮搜索页屏", 2, minimum=1, maximum=10),
            "result_count": field_def("number", "每轮研究结果数", 2, minimum=1, maximum=5),
            "detail_max_pages": field_def("number", "每个详情页屏", 2, minimum=1, maximum=10),
            # The remaining fields are forwarded to WebSearchRunner for each round.
            "click_attempts": field_def("number", "标题点击重试", 2, minimum=1, maximum=5),
            "maximize_browser": field_def("boolean", "打开后最大化浏览器", True),
            "page_load_wait_seconds": field_def("number", "页面加载等待秒数", 8, minimum=0, maximum=120),
            "action_wait_seconds": field_def("number", "操作等待秒数", 1, minimum=0, maximum=30),
            "wait_jitter_min_seconds": field_def("number", "等待抖动最小秒数", 0, minimum=0, maximum=30),
            "wait_jitter_max_seconds": field_def("number", "等待抖动最大秒数", 0, minimum=0, maximum=30),
        },
        # Runtime input ports. ai_web_research_node merges inputs over params,
        # so a value arriving here overrides the same-named param.
        "inputs": {
            "objective": field_def("string", "研究目标"),
            "output_schema": field_def("object", "返回 JSON Schema"),
            "constraints": field_def("object", "研究约束"),
            "max_attempts": field_def("number", "最多尝试次数"),
        },
        # NOTE(review): the runner's result also carries objective, attempts_used,
        # max_attempts, plan and next_port, which are not declared here.
        "outputs": {
            "status": {"type": "string", "label": "研究状态"},
            "goal_achieved": {"type": "boolean", "label": "是否达成目标"},
            "data": {"type": "object", "label": "符合 Schema 的数据"},
            "validation": {"type": "object", "label": "Schema 校验结果"},
            "assessment": {"type": "object", "label": "目标评估"},
            "sources": {"type": "array", "label": "来源"},
            "attempts": {"type": "array", "label": "尝试记录"},
        },
        # The runner sets next_port to "success" or "partial" only. "failure" is
        # presumably routed by the framework when the handler raises — confirm.
        "control_ports": control_ports(["success", "partial", "failure"]),
    },
    ai_web_research_node,
)
|