Workflow Copilot: LLM-based YAML autocorrection (#4504)
This commit is contained in:
committed by
GitHub
parent
0777d27fee
commit
5d7814a925
146
skyvern/forge/sdk/routes/event_source_stream.py
Normal file
146
skyvern/forge/sdk/routes/event_source_stream.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Awaitable, Callable, Protocol
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
|
||||||
|
|
||||||
|
# Default number of seconds between SSE keep-alive pings sent to the client
# (used by FastAPIEventSourceStream.create as the ping interval default).
DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 10
|
||||||
|
|
||||||
|
|
||||||
|
class EventSourceStream(Protocol):
    """Protocol for Server-Sent Events (SSE) streams.

    Implementations push events to a connected client; handlers depend on this
    structural interface rather than on a concrete transport class.
    """

    async def send(self, data: Any) -> bool:
        """
        Send data as an SSE event.

        Returns:
            True if the event was queued successfully, False if disconnected or closed.
        """
        ...

    async def is_disconnected(self) -> bool:
        """Check if the client has disconnected."""
        ...

    async def close(self) -> None:
        """Signal that the stream is complete."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class FastAPIEventSourceStream:
    """
    FastAPI implementation of EventSourceStream.

    This class provides a cleaner interface for sending SSE updates from async functions
    instead of using yield-based generators directly.

    Usage:
        @app.post("/stream")
        async def my_endpoint(request: Request) -> EventSourceResponse:
            async def handler(stream: EventSourceStream) -> None:
                await stream.send(MyUpdateModel(status="Processing..."))
                result = await do_work()
                await stream.send({"status": "Done", "result": result})

            return FastAPIEventSourceStream.create(request, handler)
    """

    def __init__(self, request: Request) -> None:
        self._request = request
        # Unbounded queue of pending events; a None sentinel (put by close())
        # marks end-of-stream for the consuming generator.
        self._queue: asyncio.Queue[Any] = asyncio.Queue()
        self._closed = False

    async def send(self, data: Any) -> bool:
        """
        Send data as an SSE event. Accepts Pydantic models or dicts.

        Returns:
            True if the event was queued successfully, False if disconnected or closed.
        """
        if self._closed or await self.is_disconnected():
            return False
        await self._queue.put(data)
        return True

    async def is_disconnected(self) -> bool:
        """Check if the client has disconnected."""
        return await self._request.is_disconnected()

    async def close(self) -> None:
        """Signal that the stream is complete."""
        self._closed = True
        # The sentinel unblocks the consumer waiting in _generate().
        await self._queue.put(None)

    def _serialize(self, data: Any) -> Any:
        """Serialize data to JSON-compatible format."""
        if isinstance(data, BaseModel):
            return data.model_dump(mode="json")
        return data

    async def _generate(self) -> Any:
        """Internal generator that yields SSE events from the queue."""
        while True:
            try:
                data = await self._queue.get()
                if data is None:  # end-of-stream sentinel from close()
                    break
                if await self.is_disconnected():
                    break
                yield JSONServerSentEvent(data=self._serialize(data))
            except Exception:
                # Best-effort stream: any failure (serialization, transport)
                # ends the stream instead of crashing the response.
                break

    @classmethod
    def create(
        cls,
        request: Request,
        handler: Callable[[EventSourceStream], Awaitable[None]],
        ping_interval: int = DEFAULT_KEEPALIVE_INTERVAL_SECONDS,
    ) -> EventSourceResponse:
        """
        Create an EventSourceResponse that runs the handler with an EventSourceStream.

        Args:
            request: The FastAPI request object
            handler: An async function that receives the stream and sends events
            ping_interval: Interval in seconds for keep-alive pings (default: 10)

        Returns:
            An EventSourceResponse that can be returned from a FastAPI endpoint
        """
        stream = cls(request)

        async def event_generator() -> Any:
            task = asyncio.create_task(cls._run_handler(stream, handler))
            try:
                async for event in stream._generate():
                    yield event
            finally:
                if not task.done():
                    task.cancel()
                # Always await the task, even when it already finished, so a
                # handler exception is retrieved; otherwise asyncio logs
                # "Task exception was never retrieved" and the traceback is lost.
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # The SSE response has already ended here; the handler's
                    # exception cannot be surfaced to the client at this point.
                    pass

        def ping_message_factory() -> ServerSentEvent:
            return ServerSentEvent(comment="keep-alive")

        return EventSourceResponse(
            event_generator(),
            ping=ping_interval,
            ping_message_factory=ping_message_factory,
        )

    @staticmethod
    async def _run_handler(
        stream: EventSourceStream,
        handler: Callable[[EventSourceStream], Awaitable[None]],
    ) -> None:
        """Run the handler and ensure the stream is closed when done."""
        try:
            await handler(stream)
        finally:
            await stream.close()
|
||||||
@@ -2,18 +2,21 @@ import time
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncGenerator
|
from typing import Any
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import Depends, HTTPException, Request, status
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
|
from pydantic import ValidationError
|
||||||
|
from sse_starlette import EventSourceResponse
|
||||||
|
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.prompts import prompt_engine
|
from skyvern.forge.prompts import prompt_engine
|
||||||
|
from skyvern.forge.sdk.api.llm.api_handler import LLMAPIHandler
|
||||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||||
from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
|
from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
|
||||||
|
from skyvern.forge.sdk.routes.event_source_stream import EventSourceStream, FastAPIEventSourceStream
|
||||||
from skyvern.forge.sdk.routes.routers import base_router
|
from skyvern.forge.sdk.routes.routers import base_router
|
||||||
from skyvern.forge.sdk.routes.run_blocks import DEFAULT_LOGIN_PROMPT
|
from skyvern.forge.sdk.routes.run_blocks import DEFAULT_LOGIN_PROMPT
|
||||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||||
@@ -29,6 +32,7 @@ from skyvern.forge.sdk.schemas.workflow_copilot import (
|
|||||||
WorkflowCopilotStreamResponseUpdate,
|
WorkflowCopilotStreamResponseUpdate,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.services import org_auth_service
|
from skyvern.forge.sdk.services import org_auth_service
|
||||||
|
from skyvern.forge.sdk.workflow.exceptions import BaseWorkflowHTTPException
|
||||||
from skyvern.forge.sdk.workflow.models.parameter import ParameterType
|
from skyvern.forge.sdk.workflow.models.parameter import ParameterType
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowDefinition
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowDefinition
|
||||||
from skyvern.forge.sdk.workflow.workflow_definition_converter import convert_workflow_definition
|
from skyvern.forge.sdk.workflow.workflow_definition_converter import convert_workflow_definition
|
||||||
@@ -39,7 +43,6 @@ from skyvern.schemas.workflows import (
|
|||||||
|
|
||||||
WORKFLOW_KNOWLEDGE_BASE_PATH = Path("skyvern/forge/prompts/skyvern/workflow_knowledge_base.txt")
|
WORKFLOW_KNOWLEDGE_BASE_PATH = Path("skyvern/forge/prompts/skyvern/workflow_knowledge_base.txt")
|
||||||
CHAT_HISTORY_CONTEXT_MESSAGES = 10
|
CHAT_HISTORY_CONTEXT_MESSAGES = 10
|
||||||
SSE_KEEPALIVE_INTERVAL_SECONDS = 10
|
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
@@ -88,19 +91,41 @@ async def _get_debug_run_info(organization_id: str, workflow_run_id: str | None)
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_chat_history(chat_history: list[WorkflowCopilotChatHistoryMessage]) -> str:
    """Render a chat history as newline-separated "sender: content" lines.

    Returns an empty string when the history is empty.
    """
    if not chat_history:
        return ""
    return "\n".join(f"{message.sender}: {message.content}" for message in chat_history)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_llm_response(llm_response: dict[str, Any] | Any) -> Any:
    """Extract the action payload from an LLM response.

    Unwraps a top-level "output" key when present, then requires the payload
    to be a dict; anything else is reported as an invalid LLM response.
    """
    action_data = llm_response
    if isinstance(llm_response, dict) and "output" in llm_response:
        action_data = llm_response["output"]

    if isinstance(action_data, dict):
        return action_data

    LOG.error(
        "LLM response is not valid JSON",
        response_type=type(action_data).__name__,
    )
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="Invalid response from LLM",
    )
|
||||||
|
|
||||||
|
|
||||||
async def copilot_call_llm(
|
async def copilot_call_llm(
|
||||||
|
stream: EventSourceStream,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
chat_request: WorkflowCopilotChatRequest,
|
chat_request: WorkflowCopilotChatRequest,
|
||||||
chat_history: list[WorkflowCopilotChatHistoryMessage],
|
chat_history: list[WorkflowCopilotChatHistoryMessage],
|
||||||
global_llm_context: str | None,
|
global_llm_context: str | None,
|
||||||
debug_run_info_text: str,
|
debug_run_info_text: str,
|
||||||
) -> tuple[str, WorkflowDefinition | None, str | None]:
|
) -> tuple[str, WorkflowDefinition | None, str | None]:
|
||||||
current_datetime = datetime.now(timezone.utc).isoformat()
|
chat_history_text = _format_chat_history(chat_history)
|
||||||
|
|
||||||
chat_history_text = ""
|
|
||||||
if chat_history:
|
|
||||||
history_lines = [f"{msg.sender}: {msg.content}" for msg in chat_history]
|
|
||||||
chat_history_text = "\n".join(history_lines)
|
|
||||||
|
|
||||||
workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
|
workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
|
||||||
|
|
||||||
@@ -111,7 +136,7 @@ async def copilot_call_llm(
|
|||||||
user_message=chat_request.message,
|
user_message=chat_request.message,
|
||||||
chat_history=chat_history_text,
|
chat_history=chat_history_text,
|
||||||
global_llm_context=global_llm_context or "",
|
global_llm_context=global_llm_context or "",
|
||||||
current_datetime=current_datetime,
|
current_datetime=datetime.now(timezone.utc).isoformat(),
|
||||||
debug_run_info=debug_run_info_text,
|
debug_run_info=debug_run_info_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,21 +175,7 @@ async def copilot_call_llm(
|
|||||||
llm_response=llm_response,
|
llm_response=llm_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(llm_response, dict) and "output" in llm_response:
|
action_data = _parse_llm_response(llm_response)
|
||||||
action_data = llm_response["output"]
|
|
||||||
else:
|
|
||||||
action_data = llm_response
|
|
||||||
|
|
||||||
if not isinstance(action_data, dict):
|
|
||||||
LOG.error(
|
|
||||||
"LLM response is not valid JSON",
|
|
||||||
organization_id=organization_id,
|
|
||||||
response_type=type(action_data).__name__,
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Invalid response from LLM",
|
|
||||||
)
|
|
||||||
|
|
||||||
action_type = action_data.get("type")
|
action_type = action_data.get("type")
|
||||||
user_response_value = action_data.get("user_response")
|
user_response_value = action_data.get("user_response")
|
||||||
@@ -183,7 +194,29 @@ async def copilot_call_llm(
|
|||||||
global_llm_context = str(global_llm_context)
|
global_llm_context = str(global_llm_context)
|
||||||
|
|
||||||
if action_type == "REPLACE_WORKFLOW":
|
if action_type == "REPLACE_WORKFLOW":
|
||||||
updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, action_data.get("workflow_yaml", ""))
|
workflow_yaml = action_data.get("workflow_yaml", "")
|
||||||
|
try:
|
||||||
|
updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, workflow_yaml)
|
||||||
|
except (yaml.YAMLError, ValidationError, BaseWorkflowHTTPException) as e:
|
||||||
|
await stream.send(
|
||||||
|
WorkflowCopilotProcessingUpdate(
|
||||||
|
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
|
||||||
|
status="Validating workflow definition...",
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
corrected_workflow_yaml = await _auto_correct_workflow_yaml(
|
||||||
|
llm_api_handler=llm_api_handler,
|
||||||
|
organization_id=organization_id,
|
||||||
|
user_response=user_response,
|
||||||
|
workflow_yaml=workflow_yaml,
|
||||||
|
chat_history=chat_history,
|
||||||
|
global_llm_context=global_llm_context,
|
||||||
|
debug_run_info_text=debug_run_info_text,
|
||||||
|
error=e,
|
||||||
|
)
|
||||||
|
updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, corrected_workflow_yaml)
|
||||||
|
|
||||||
return user_response, updated_workflow, global_llm_context
|
return user_response, updated_workflow, global_llm_context
|
||||||
elif action_type == "REPLY":
|
elif action_type == "REPLY":
|
||||||
return user_response, None, global_llm_context
|
return user_response, None, global_llm_context
|
||||||
@@ -198,45 +231,80 @@ async def copilot_call_llm(
|
|||||||
return "I received your request but I'm not sure how to help. Could you rephrase?", None, None
|
return "I received your request but I'm not sure how to help. Could you rephrase?", None, None
|
||||||
|
|
||||||
|
|
||||||
|
async def _auto_correct_workflow_yaml(
    llm_api_handler: LLMAPIHandler,
    organization_id: str,
    user_response: str,
    workflow_yaml: str,
    chat_history: list[WorkflowCopilotChatHistoryMessage],
    global_llm_context: str | None,
    debug_run_info_text: str,
    error: Exception,
) -> str:
    """Ask the LLM to repair a workflow YAML that failed parsing/validation.

    Re-prompts the workflow-copilot template with the failing YAML plus the
    error as the user message, and returns the corrected YAML from the LLM's
    action payload (falling back to the original YAML if none is returned).

    Args:
        llm_api_handler: Handler used to call the LLM.
        organization_id: Organization issuing the request.
        user_response: The assistant message that accompanied the bad YAML;
            appended to the chat history so the LLM sees its own prior turn.
        workflow_yaml: The YAML that failed to parse or validate.
        chat_history: Prior chat messages for context (not mutated; a copy is made).
        global_llm_context: Optional extra context string for the prompt.
        debug_run_info_text: Debug-run context included in the prompt.
        error: The exception raised by the failed parse/validation attempt.

    Returns:
        The corrected workflow YAML string (or the original on no correction).
    """
    # Compact error summary fed to the LLM, e.g. "YAMLError: ...".
    failure_reason = f"{error.__class__.__name__}: {error}"

    # Copy so the caller's chat_history list is not mutated.
    new_chat_history = chat_history[:]
    new_chat_history.append(
        WorkflowCopilotChatHistoryMessage(
            sender=WorkflowCopilotChatSender.AI,
            content=user_response,
            created_at=datetime.now(timezone.utc),
        )
    )

    workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
    llm_prompt = prompt_engine.load_prompt(
        template="workflow-copilot",
        workflow_knowledge_base=workflow_knowledge_base,
        workflow_yaml=workflow_yaml,
        user_message=f"Workflow YAML parsing failed, please fix it: {failure_reason}",
        chat_history=_format_chat_history(new_chat_history),
        global_llm_context=global_llm_context or "",
        current_datetime=datetime.now(timezone.utc).isoformat(),
        debug_run_info=debug_run_info_text,
    )
    llm_start_time = time.monotonic()
    llm_response = await llm_api_handler(
        prompt=llm_prompt,
        prompt_name="workflow-copilot",
        organization_id=organization_id,
    )
    LOG.info(
        "Auto-correction LLM response",
        duration_seconds=time.monotonic() - llm_start_time,
        llm_response_len=len(llm_response),
        llm_response=llm_response,
    )
    # Raises HTTPException if the response payload is not a dict.
    action_data = _parse_llm_response(llm_response)

    return action_data.get("workflow_yaml", workflow_yaml)
|
||||||
|
|
||||||
|
|
||||||
async def _process_workflow_yaml(workflow_id: str, workflow_yaml: str) -> WorkflowDefinition:
    """Parse, normalize, validate, and convert an LLM-produced workflow YAML.

    Parse/validation errors (yaml.YAMLError, pydantic ValidationError, workflow
    exceptions) propagate to the caller, which may trigger LLM auto-correction.

    Args:
        workflow_id: The workflow the definition belongs to.
        workflow_yaml: Raw YAML text produced by the LLM.

    Returns:
        The converted WorkflowDefinition.
    """
    parsed_yaml = yaml.safe_load(workflow_yaml)

    # Fixing trivial common LLM mistakes
    # (ensure every block has a "title" key so schema validation does not fail).
    workflow_definition = parsed_yaml.get("workflow_definition", None)
    if workflow_definition:
        blocks = workflow_definition.get("blocks", [])
        for block in blocks:
            block["title"] = block.get("title", "")

    workflow_yaml_request = WorkflowCreateYAMLRequest.model_validate(parsed_yaml)

    # Post-processing
    # Login blocks without a navigation goal get the default login prompt.
    for block in workflow_yaml_request.workflow_definition.blocks:
        if isinstance(block, LoginBlockYAML) and not block.navigation_goal:
            block.navigation_goal = DEFAULT_LOGIN_PROMPT

    # OUTPUT parameters are derived, not user-declared; strip any the LLM emitted.
    workflow_yaml_request.workflow_definition.parameters = [
        p for p in workflow_yaml_request.workflow_definition.parameters if p.parameter_type != ParameterType.OUTPUT
    ]

    updated_workflow = convert_workflow_definition(
        workflow_definition_yaml=workflow_yaml_request.workflow_definition,
        workflow_id=workflow_id,
    )
    return updated_workflow
|
||||||
|
|
||||||
|
|
||||||
@@ -246,7 +314,7 @@ async def workflow_copilot_chat_post(
|
|||||||
chat_request: WorkflowCopilotChatRequest,
|
chat_request: WorkflowCopilotChatRequest,
|
||||||
organization: Organization = Depends(org_auth_service.get_current_org),
|
organization: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> EventSourceResponse:
|
) -> EventSourceResponse:
|
||||||
async def event_stream() -> AsyncGenerator[JSONServerSentEvent, None]:
|
async def stream_handler(stream: EventSourceStream) -> None:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Workflow copilot chat request",
|
"Workflow copilot chat request",
|
||||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||||
@@ -257,12 +325,12 @@ async def workflow_copilot_chat_post(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield JSONServerSentEvent(
|
await stream.send(
|
||||||
data=WorkflowCopilotProcessingUpdate(
|
WorkflowCopilotProcessingUpdate(
|
||||||
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
|
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
|
||||||
status="Processing...",
|
status="Processing...",
|
||||||
timestamp=datetime.now(timezone.utc),
|
timestamp=datetime.now(timezone.utc),
|
||||||
).model_dump(mode="json"),
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if chat_request.workflow_copilot_chat_id:
|
if chat_request.workflow_copilot_chat_id:
|
||||||
@@ -302,15 +370,15 @@ async def workflow_copilot_chat_post(
|
|||||||
if debug_run_info.html:
|
if debug_run_info.html:
|
||||||
debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"
|
debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"
|
||||||
|
|
||||||
yield JSONServerSentEvent(
|
await stream.send(
|
||||||
data=WorkflowCopilotProcessingUpdate(
|
WorkflowCopilotProcessingUpdate(
|
||||||
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
|
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
|
||||||
status="Thinking...",
|
status="Thinking...",
|
||||||
timestamp=datetime.now(timezone.utc),
|
timestamp=datetime.now(timezone.utc),
|
||||||
).model_dump(mode="json"),
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if await request.is_disconnected():
|
if await stream.is_disconnected():
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Workflow copilot chat request is disconnected before LLM call",
|
"Workflow copilot chat request is disconnected before LLM call",
|
||||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||||
@@ -318,6 +386,7 @@ async def workflow_copilot_chat_post(
|
|||||||
return
|
return
|
||||||
|
|
||||||
user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm(
|
user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm(
|
||||||
|
stream,
|
||||||
organization.organization_id,
|
organization.organization_id,
|
||||||
chat_request,
|
chat_request,
|
||||||
convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
|
convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
|
||||||
@@ -325,7 +394,7 @@ async def workflow_copilot_chat_post(
|
|||||||
debug_run_info_text,
|
debug_run_info_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
if await request.is_disconnected():
|
if await stream.is_disconnected():
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Workflow copilot chat request is disconnected after LLM call",
|
"Workflow copilot chat request is disconnected after LLM call",
|
||||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||||
@@ -347,62 +416,50 @@ async def workflow_copilot_chat_post(
|
|||||||
global_llm_context=updated_global_llm_context,
|
global_llm_context=updated_global_llm_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield JSONServerSentEvent(
|
await stream.send(
|
||||||
data=WorkflowCopilotStreamResponseUpdate(
|
WorkflowCopilotStreamResponseUpdate(
|
||||||
type=WorkflowCopilotStreamMessageType.RESPONSE,
|
type=WorkflowCopilotStreamMessageType.RESPONSE,
|
||||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||||
message=user_response,
|
message=user_response,
|
||||||
updated_workflow=updated_workflow.model_dump(mode="json") if updated_workflow else None,
|
updated_workflow=updated_workflow.model_dump(mode="json") if updated_workflow else None,
|
||||||
response_time=assistant_message.created_at,
|
response_time=assistant_message.created_at,
|
||||||
).model_dump(mode="json"),
|
)
|
||||||
)
|
)
|
||||||
except HTTPException as exc:
|
except HTTPException as exc:
|
||||||
if await request.is_disconnected():
|
await stream.send(
|
||||||
return
|
WorkflowCopilotStreamErrorUpdate(
|
||||||
yield JSONServerSentEvent(
|
|
||||||
data=WorkflowCopilotStreamErrorUpdate(
|
|
||||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||||
error=exc.detail,
|
error=exc.detail,
|
||||||
).model_dump(mode="json"),
|
)
|
||||||
)
|
)
|
||||||
except LLMProviderError as exc:
|
except LLMProviderError as exc:
|
||||||
if await request.is_disconnected():
|
|
||||||
return
|
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"LLM provider error",
|
"LLM provider error",
|
||||||
organization_id=organization.organization_id,
|
organization_id=organization.organization_id,
|
||||||
error=str(exc),
|
error=str(exc),
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
yield JSONServerSentEvent(
|
await stream.send(
|
||||||
data=WorkflowCopilotStreamErrorUpdate(
|
WorkflowCopilotStreamErrorUpdate(
|
||||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||||
error="Failed to process your request. Please try again.",
|
error="Failed to process your request. Please try again.",
|
||||||
).model_dump(mode="json"),
|
)
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if await request.is_disconnected():
|
|
||||||
return
|
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"Unexpected error in workflow copilot",
|
"Unexpected error in workflow copilot",
|
||||||
organization_id=organization.organization_id,
|
organization_id=organization.organization_id,
|
||||||
error=str(exc),
|
error=str(exc),
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
yield JSONServerSentEvent(
|
await stream.send(
|
||||||
data=WorkflowCopilotStreamErrorUpdate(
|
WorkflowCopilotStreamErrorUpdate(
|
||||||
type=WorkflowCopilotStreamMessageType.ERROR, error="An error occurred. Please try again."
|
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||||
).model_dump(mode="json"),
|
error="An error occurred. Please try again.",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def ping_message_factory() -> ServerSentEvent:
|
return FastAPIEventSourceStream.create(request, stream_handler)
|
||||||
return ServerSentEvent(comment="keep-alive")
|
|
||||||
|
|
||||||
return EventSourceResponse(
|
|
||||||
event_stream(),
|
|
||||||
ping=SSE_KEEPALIVE_INTERVAL_SECONDS,
|
|
||||||
ping_message_factory=ping_message_factory,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/workflow/copilot/chat-history", include_in_schema=False)
|
@base_router.get("/workflow/copilot/chat-history", include_in_schema=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user