Workflow Copilot: LLM-based YAML autocorrection (#4504)

This commit is contained in:
Stanislav Novosad
2026-01-20 17:11:35 -07:00
committed by GitHub
parent 0777d27fee
commit 5d7814a925
2 changed files with 298 additions and 95 deletions

View File

@@ -0,0 +1,146 @@
import asyncio
from typing import Any, Awaitable, Callable, Protocol
from fastapi import Request
from pydantic import BaseModel
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 10
class EventSourceStream(Protocol):
    """Protocol for Server-Sent Events (SSE) streams.

    Structural (duck-typed) interface for an object that accepts events to be
    pushed to a connected SSE client. Implementations must provide all three
    coroutines below.
    """

    async def send(self, data: Any) -> bool:
        """
        Send data as an SSE event.

        Returns:
            True if the event was queued successfully, False if disconnected or closed.
        """
        ...

    async def is_disconnected(self) -> bool:
        """Check if the client has disconnected."""
        ...

    async def close(self) -> None:
        """Signal that the stream is complete."""
        ...
class FastAPIEventSourceStream:
    """
    FastAPI implementation of EventSourceStream.

    This class provides a cleaner interface for sending SSE updates from async functions
    instead of using yield-based generators directly.

    Usage:
        @app.post("/stream")
        async def my_endpoint(request: Request) -> EventSourceResponse:
            async def handler(stream: EventSourceStream) -> None:
                await stream.send(MyUpdateModel(status="Processing..."))
                result = await do_work()
                await stream.send({"status": "Done", "result": result})

            return FastAPIEventSourceStream.create(request, handler)
    """

    def __init__(self, request: Request) -> None:
        self._request = request
        # Pending events; a None sentinel enqueued by close() marks end-of-stream.
        self._queue: asyncio.Queue[Any] = asyncio.Queue()
        self._closed = False

    async def send(self, data: Any) -> bool:
        """
        Send data as an SSE event. Accepts Pydantic models or dicts.

        Returns:
            True if the event was queued successfully, False if disconnected or closed.
        """
        if self._closed or await self.is_disconnected():
            return False
        await self._queue.put(data)
        return True

    async def is_disconnected(self) -> bool:
        """Check if the client has disconnected."""
        return await self._request.is_disconnected()

    async def close(self) -> None:
        """Signal that the stream is complete."""
        self._closed = True
        # Sentinel unblocks the consumer in _generate() so it can finish.
        await self._queue.put(None)

    def _serialize(self, data: Any) -> Any:
        """Serialize data to JSON-compatible format."""
        if isinstance(data, BaseModel):
            return data.model_dump(mode="json")
        return data

    async def _generate(self) -> Any:
        """Internal generator that yields SSE events from the queue."""
        while True:
            try:
                data = await self._queue.get()
                if data is None:
                    # close() was called: end the stream normally.
                    break
                if await self.is_disconnected():
                    break
                yield JSONServerSentEvent(data=self._serialize(data))
            except Exception:
                # NOTE(review): deliberately best-effort — any error while
                # serializing/emitting terminates the stream quietly.
                break

    @classmethod
    def create(
        cls,
        request: Request,
        handler: Callable[[EventSourceStream], Awaitable[None]],
        ping_interval: int = DEFAULT_KEEPALIVE_INTERVAL_SECONDS,
    ) -> EventSourceResponse:
        """
        Create an EventSourceResponse that runs the handler with an EventSourceStream.

        Args:
            request: The FastAPI request object
            handler: An async function that receives the stream and sends events
            ping_interval: Interval in seconds for keep-alive pings (default: 10)

        Returns:
            An EventSourceResponse that can be returned from a FastAPI endpoint
        """
        stream = cls(request)

        async def event_generator() -> Any:
            task = asyncio.create_task(cls._run_handler(stream, handler))
            try:
                async for event in stream._generate():
                    yield event
            finally:
                if not task.done():
                    task.cancel()
                # Always await the task — even when it already finished — so a
                # handler exception is retrieved and surfaced instead of being
                # dropped with asyncio's "Task exception was never retrieved"
                # warning. Only cancellation (our own, above) is suppressed.
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        def ping_message_factory() -> ServerSentEvent:
            return ServerSentEvent(comment="keep-alive")

        return EventSourceResponse(
            event_generator(),
            ping=ping_interval,
            ping_message_factory=ping_message_factory,
        )

    @staticmethod
    async def _run_handler(
        stream: EventSourceStream,
        handler: Callable[[EventSourceStream], Awaitable[None]],
    ) -> None:
        """Run the handler and ensure the stream is closed when done."""
        try:
            await handler(stream)
        finally:
            # Guarantee the sentinel is enqueued even if the handler raised,
            # so _generate() does not block forever.
            await stream.close()

View File

@@ -2,18 +2,21 @@ import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import AsyncGenerator
from typing import Any
import structlog
import yaml
from fastapi import Depends, HTTPException, Request, status
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.api_handler import LLMAPIHandler
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
from skyvern.forge.sdk.routes.event_source_stream import EventSourceStream, FastAPIEventSourceStream
from skyvern.forge.sdk.routes.routers import base_router
from skyvern.forge.sdk.routes.run_blocks import DEFAULT_LOGIN_PROMPT
from skyvern.forge.sdk.schemas.organizations import Organization
@@ -29,6 +32,7 @@ from skyvern.forge.sdk.schemas.workflow_copilot import (
WorkflowCopilotStreamResponseUpdate,
)
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.workflow.exceptions import BaseWorkflowHTTPException
from skyvern.forge.sdk.workflow.models.parameter import ParameterType
from skyvern.forge.sdk.workflow.models.workflow import WorkflowDefinition
from skyvern.forge.sdk.workflow.workflow_definition_converter import convert_workflow_definition
@@ -39,7 +43,6 @@ from skyvern.schemas.workflows import (
WORKFLOW_KNOWLEDGE_BASE_PATH = Path("skyvern/forge/prompts/skyvern/workflow_knowledge_base.txt")
CHAT_HISTORY_CONTEXT_MESSAGES = 10
SSE_KEEPALIVE_INTERVAL_SECONDS = 10
LOG = structlog.get_logger()
@@ -88,19 +91,41 @@ async def _get_debug_run_info(organization_id: str, workflow_run_id: str | None)
)
def _format_chat_history(chat_history: list[WorkflowCopilotChatHistoryMessage]) -> str:
    """Render the chat history as newline-separated "sender: content" lines.

    Returns an empty string when there is no history.
    """
    if not chat_history:
        return ""
    return "\n".join(f"{message.sender}: {message.content}" for message in chat_history)
def _parse_llm_response(llm_response: dict[str, Any] | Any) -> Any:
if isinstance(llm_response, dict) and "output" in llm_response:
action_data = llm_response["output"]
else:
action_data = llm_response
if not isinstance(action_data, dict):
LOG.error(
"LLM response is not valid JSON",
response_type=type(action_data).__name__,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid response from LLM",
)
return action_data
async def copilot_call_llm(
stream: EventSourceStream,
organization_id: str,
chat_request: WorkflowCopilotChatRequest,
chat_history: list[WorkflowCopilotChatHistoryMessage],
global_llm_context: str | None,
debug_run_info_text: str,
) -> tuple[str, WorkflowDefinition | None, str | None]:
current_datetime = datetime.now(timezone.utc).isoformat()
chat_history_text = ""
if chat_history:
history_lines = [f"{msg.sender}: {msg.content}" for msg in chat_history]
chat_history_text = "\n".join(history_lines)
chat_history_text = _format_chat_history(chat_history)
workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
@@ -111,7 +136,7 @@ async def copilot_call_llm(
user_message=chat_request.message,
chat_history=chat_history_text,
global_llm_context=global_llm_context or "",
current_datetime=current_datetime,
current_datetime=datetime.now(timezone.utc).isoformat(),
debug_run_info=debug_run_info_text,
)
@@ -150,21 +175,7 @@ async def copilot_call_llm(
llm_response=llm_response,
)
if isinstance(llm_response, dict) and "output" in llm_response:
action_data = llm_response["output"]
else:
action_data = llm_response
if not isinstance(action_data, dict):
LOG.error(
"LLM response is not valid JSON",
organization_id=organization_id,
response_type=type(action_data).__name__,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid response from LLM",
)
action_data = _parse_llm_response(llm_response)
action_type = action_data.get("type")
user_response_value = action_data.get("user_response")
@@ -183,7 +194,29 @@ async def copilot_call_llm(
global_llm_context = str(global_llm_context)
if action_type == "REPLACE_WORKFLOW":
updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, action_data.get("workflow_yaml", ""))
workflow_yaml = action_data.get("workflow_yaml", "")
try:
updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, workflow_yaml)
except (yaml.YAMLError, ValidationError, BaseWorkflowHTTPException) as e:
await stream.send(
WorkflowCopilotProcessingUpdate(
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
status="Validating workflow definition...",
timestamp=datetime.now(timezone.utc),
)
)
corrected_workflow_yaml = await _auto_correct_workflow_yaml(
llm_api_handler=llm_api_handler,
organization_id=organization_id,
user_response=user_response,
workflow_yaml=workflow_yaml,
chat_history=chat_history,
global_llm_context=global_llm_context,
debug_run_info_text=debug_run_info_text,
error=e,
)
updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, corrected_workflow_yaml)
return user_response, updated_workflow, global_llm_context
elif action_type == "REPLY":
return user_response, None, global_llm_context
@@ -198,45 +231,80 @@ async def copilot_call_llm(
return "I received your request but I'm not sure how to help. Could you rephrase?", None, None
async def _auto_correct_workflow_yaml(
    llm_api_handler: LLMAPIHandler,
    organization_id: str,
    user_response: str,
    workflow_yaml: str,
    chat_history: list[WorkflowCopilotChatHistoryMessage],
    global_llm_context: str | None,
    debug_run_info_text: str,
    error: Exception,
) -> str:
    """Ask the LLM to repair a workflow YAML that failed parsing or validation.

    Re-prompts the workflow-copilot template with the failure reason as the user
    message and the AI's previous reply appended to the chat history, then
    returns the corrected YAML. Falls back to the original YAML when the LLM
    response carries no "workflow_yaml" key.
    """
    failure_reason = f"{error.__class__.__name__}: {error}"

    # Append the assistant's last reply to a copy so the caller's list is untouched.
    extended_history = list(chat_history)
    extended_history.append(
        WorkflowCopilotChatHistoryMessage(
            sender=WorkflowCopilotChatSender.AI,
            content=user_response,
            created_at=datetime.now(timezone.utc),
        )
    )

    knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
    llm_prompt = prompt_engine.load_prompt(
        template="workflow-copilot",
        workflow_knowledge_base=knowledge_base,
        workflow_yaml=workflow_yaml,
        user_message=f"Workflow YAML parsing failed, please fix it: {failure_reason}",
        chat_history=_format_chat_history(extended_history),
        global_llm_context=global_llm_context or "",
        current_datetime=datetime.now(timezone.utc).isoformat(),
        debug_run_info=debug_run_info_text,
    )

    started_at = time.monotonic()
    llm_response = await llm_api_handler(
        prompt=llm_prompt,
        prompt_name="workflow-copilot",
        organization_id=organization_id,
    )
    LOG.info(
        "Auto-correction LLM response",
        duration_seconds=time.monotonic() - started_at,
        llm_response_len=len(llm_response),
        llm_response=llm_response,
    )

    action_data = _parse_llm_response(llm_response)
    return action_data.get("workflow_yaml", workflow_yaml)
async def _process_workflow_yaml(workflow_id: str, workflow_yaml: str) -> WorkflowDefinition:
try:
parsed_yaml = yaml.safe_load(workflow_yaml)
except yaml.YAMLError as e:
LOG.error("Invalid YAML from LLM", yaml=workflow_yaml, exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"LLM generated invalid YAML: {str(e)}",
)
parsed_yaml = yaml.safe_load(workflow_yaml)
try:
# Fixing trivial common LLM mistakes
workflow_definition = parsed_yaml.get("workflow_definition", None)
if workflow_definition:
blocks = workflow_definition.get("blocks", [])
for block in blocks:
block["title"] = block.get("title", "")
# Fixing trivial common LLM mistakes
workflow_definition = parsed_yaml.get("workflow_definition", None)
if workflow_definition:
blocks = workflow_definition.get("blocks", [])
for block in blocks:
block["title"] = block.get("title", "")
workflow_yaml_request = WorkflowCreateYAMLRequest.model_validate(parsed_yaml)
workflow_yaml_request = WorkflowCreateYAMLRequest.model_validate(parsed_yaml)
# Post-processing
for block in workflow_yaml_request.workflow_definition.blocks:
if isinstance(block, LoginBlockYAML) and not block.navigation_goal:
block.navigation_goal = DEFAULT_LOGIN_PROMPT
# Post-processing
for block in workflow_yaml_request.workflow_definition.blocks:
if isinstance(block, LoginBlockYAML) and not block.navigation_goal:
block.navigation_goal = DEFAULT_LOGIN_PROMPT
workflow_yaml_request.workflow_definition.parameters = [
p for p in workflow_yaml_request.workflow_definition.parameters if p.parameter_type != ParameterType.OUTPUT
]
workflow_yaml_request.workflow_definition.parameters = [
p for p in workflow_yaml_request.workflow_definition.parameters if p.parameter_type != ParameterType.OUTPUT
]
updated_workflow = convert_workflow_definition(
workflow_definition_yaml=workflow_yaml_request.workflow_definition,
workflow_id=workflow_id,
)
except Exception as e:
LOG.error("YAML from LLM does not conform to Skyvern workflow schema", yaml=workflow_yaml, exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"LLM generated YAML that doesn't match workflow schema: {str(e)}",
)
updated_workflow = convert_workflow_definition(
workflow_definition_yaml=workflow_yaml_request.workflow_definition,
workflow_id=workflow_id,
)
return updated_workflow
@@ -246,7 +314,7 @@ async def workflow_copilot_chat_post(
chat_request: WorkflowCopilotChatRequest,
organization: Organization = Depends(org_auth_service.get_current_org),
) -> EventSourceResponse:
async def event_stream() -> AsyncGenerator[JSONServerSentEvent, None]:
async def stream_handler(stream: EventSourceStream) -> None:
LOG.info(
"Workflow copilot chat request",
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
@@ -257,12 +325,12 @@ async def workflow_copilot_chat_post(
)
try:
yield JSONServerSentEvent(
data=WorkflowCopilotProcessingUpdate(
await stream.send(
WorkflowCopilotProcessingUpdate(
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
status="Processing...",
timestamp=datetime.now(timezone.utc),
).model_dump(mode="json"),
)
)
if chat_request.workflow_copilot_chat_id:
@@ -302,15 +370,15 @@ async def workflow_copilot_chat_post(
if debug_run_info.html:
debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"
yield JSONServerSentEvent(
data=WorkflowCopilotProcessingUpdate(
await stream.send(
WorkflowCopilotProcessingUpdate(
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
status="Thinking...",
timestamp=datetime.now(timezone.utc),
).model_dump(mode="json"),
)
)
if await request.is_disconnected():
if await stream.is_disconnected():
LOG.info(
"Workflow copilot chat request is disconnected before LLM call",
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
@@ -318,6 +386,7 @@ async def workflow_copilot_chat_post(
return
user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm(
stream,
organization.organization_id,
chat_request,
convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
@@ -325,7 +394,7 @@ async def workflow_copilot_chat_post(
debug_run_info_text,
)
if await request.is_disconnected():
if await stream.is_disconnected():
LOG.info(
"Workflow copilot chat request is disconnected after LLM call",
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
@@ -347,62 +416,50 @@ async def workflow_copilot_chat_post(
global_llm_context=updated_global_llm_context,
)
yield JSONServerSentEvent(
data=WorkflowCopilotStreamResponseUpdate(
await stream.send(
WorkflowCopilotStreamResponseUpdate(
type=WorkflowCopilotStreamMessageType.RESPONSE,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
message=user_response,
updated_workflow=updated_workflow.model_dump(mode="json") if updated_workflow else None,
response_time=assistant_message.created_at,
).model_dump(mode="json"),
)
)
except HTTPException as exc:
if await request.is_disconnected():
return
yield JSONServerSentEvent(
data=WorkflowCopilotStreamErrorUpdate(
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
error=exc.detail,
).model_dump(mode="json"),
)
)
except LLMProviderError as exc:
if await request.is_disconnected():
return
LOG.error(
"LLM provider error",
organization_id=organization.organization_id,
error=str(exc),
exc_info=True,
)
yield JSONServerSentEvent(
data=WorkflowCopilotStreamErrorUpdate(
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
error="Failed to process your request. Please try again.",
).model_dump(mode="json"),
)
)
except Exception as exc:
if await request.is_disconnected():
return
LOG.error(
"Unexpected error in workflow copilot",
organization_id=organization.organization_id,
error=str(exc),
exc_info=True,
)
yield JSONServerSentEvent(
data=WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR, error="An error occurred. Please try again."
).model_dump(mode="json"),
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
error="An error occurred. Please try again.",
)
)
def ping_message_factory() -> ServerSentEvent:
return ServerSentEvent(comment="keep-alive")
return EventSourceResponse(
event_stream(),
ping=SSE_KEEPALIVE_INTERVAL_SECONDS,
ping_message_factory=ping_message_factory,
)
return FastAPIEventSourceStream.create(request, stream_handler)
@base_router.get("/workflow/copilot/chat-history", include_in_schema=False)