diff --git a/skyvern/forge/sdk/routes/event_source_stream.py b/skyvern/forge/sdk/routes/event_source_stream.py new file mode 100644 index 00000000..f8bff123 --- /dev/null +++ b/skyvern/forge/sdk/routes/event_source_stream.py @@ -0,0 +1,146 @@ +import asyncio +from typing import Any, Awaitable, Callable, Protocol + +from fastapi import Request +from pydantic import BaseModel +from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent + +DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 10 + + +class EventSourceStream(Protocol): + """Protocol for Server-Sent Events (SSE) streams.""" + + async def send(self, data: Any) -> bool: + """ + Send data as an SSE event. + + Returns: + True if the event was queued successfully, False if disconnected or closed. + """ + ... + + async def is_disconnected(self) -> bool: + """Check if the client has disconnected.""" + ... + + async def close(self) -> None: + """Signal that the stream is complete.""" + ... + + +class FastAPIEventSourceStream: + """ + FastAPI implementation of EventSourceStream. + + This class provides a cleaner interface for sending SSE updates from async functions + instead of using yield-based generators directly. + + Usage: + @app.post("/stream") + async def my_endpoint(request: Request) -> EventSourceResponse: + async def handler(stream: EventSourceStream) -> None: + await stream.send(MyUpdateModel(status="Processing...")) + result = await do_work() + await stream.send({"status": "Done", "result": result}) + + return FastAPIEventSourceStream.create(request, handler) + """ + + def __init__(self, request: Request) -> None: + self._request = request + self._queue: asyncio.Queue[Any] = asyncio.Queue() + self._closed = False + + async def send(self, data: Any) -> bool: + """ + Send data as an SSE event. Accepts Pydantic models or dicts. + + Returns: + True if the event was queued successfully, False if disconnected or closed. + """ + if self._closed or await self.is_disconnected(): + return False + await self._queue.put(data) + return True + + async def is_disconnected(self) -> bool: + """Check if the client has disconnected.""" + return await self._request.is_disconnected() + + async def close(self) -> None: + """Signal that the stream is complete.""" + self._closed = True + await self._queue.put(None) + + def _serialize(self, data: Any) -> Any: + """Serialize data to JSON-compatible format.""" + if isinstance(data, BaseModel): + return data.model_dump(mode="json") + return data + + async def _generate(self) -> Any: + """Internal generator that yields SSE events from the queue.""" + while True: + try: + data = await self._queue.get() + if data is None: + break + if await self.is_disconnected(): + break + yield JSONServerSentEvent(data=self._serialize(data)) + except Exception: + break + + @classmethod + def create( + cls, + request: Request, + handler: Callable[[EventSourceStream], Awaitable[None]], + ping_interval: int = DEFAULT_KEEPALIVE_INTERVAL_SECONDS, + ) -> EventSourceResponse: + """ + Create an EventSourceResponse that runs the handler with an EventSourceStream. + + Args: + request: The FastAPI request object + handler: An async function that receives the stream and sends events + ping_interval: Interval in seconds for keep-alive pings (default: 10) + + Returns: + An EventSourceResponse that can be returned from a FastAPI endpoint + """ + stream = cls(request) + + async def event_generator() -> Any: + task = asyncio.create_task(cls._run_handler(stream, handler)) + try: + async for event in stream._generate(): + yield event + finally: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + def ping_message_factory() -> ServerSentEvent: + return ServerSentEvent(comment="keep-alive") + + return EventSourceResponse( + event_generator(), + ping=ping_interval, + ping_message_factory=ping_message_factory, + ) + + @staticmethod + async def _run_handler( + stream: EventSourceStream, + handler: Callable[[EventSourceStream], Awaitable[None]], + ) -> None: + """Run the handler and ensure the stream is closed when done.""" + try: + await handler(stream) + finally: + await stream.close() diff --git a/skyvern/forge/sdk/routes/workflow_copilot.py b/skyvern/forge/sdk/routes/workflow_copilot.py index 03889a84..69a4978b 100644 --- a/skyvern/forge/sdk/routes/workflow_copilot.py +++ b/skyvern/forge/sdk/routes/workflow_copilot.py @@ -2,18 +2,21 @@ import time from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import AsyncGenerator +from typing import Any import structlog import yaml from fastapi import Depends, HTTPException, Request, status -from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent +from pydantic import ValidationError +from sse_starlette import EventSourceResponse from skyvern.forge import app from skyvern.forge.prompts import prompt_engine +from skyvern.forge.sdk.api.llm.api_handler import LLMAPIHandler from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type +from skyvern.forge.sdk.routes.event_source_stream import EventSourceStream, FastAPIEventSourceStream from skyvern.forge.sdk.routes.routers import base_router from skyvern.forge.sdk.routes.run_blocks import DEFAULT_LOGIN_PROMPT from skyvern.forge.sdk.schemas.organizations import Organization @@ -29,6 +32,7 @@ from skyvern.forge.sdk.schemas.workflow_copilot import ( WorkflowCopilotStreamResponseUpdate, ) from skyvern.forge.sdk.services import org_auth_service +from skyvern.forge.sdk.workflow.exceptions import BaseWorkflowHTTPException from skyvern.forge.sdk.workflow.models.parameter import ParameterType from skyvern.forge.sdk.workflow.models.workflow import WorkflowDefinition from skyvern.forge.sdk.workflow.workflow_definition_converter import convert_workflow_definition @@ -39,7 +43,6 @@ from skyvern.schemas.workflows import ( WORKFLOW_KNOWLEDGE_BASE_PATH = Path("skyvern/forge/prompts/skyvern/workflow_knowledge_base.txt") CHAT_HISTORY_CONTEXT_MESSAGES = 10 -SSE_KEEPALIVE_INTERVAL_SECONDS = 10 LOG = structlog.get_logger() @@ -88,19 +91,41 @@ async def _get_debug_run_info(organization_id: str, workflow_run_id: str | None) ) +def _format_chat_history(chat_history: list[WorkflowCopilotChatHistoryMessage]) -> str: + chat_history_text = "" + if chat_history: + history_lines = [f"{msg.sender}: {msg.content}" for msg in chat_history] + chat_history_text = "\n".join(history_lines) + return chat_history_text + + +def _parse_llm_response(llm_response: dict[str, Any] | Any) -> Any: + if isinstance(llm_response, dict) and "output" in llm_response: + action_data = llm_response["output"] + else: + action_data = llm_response + + if not isinstance(action_data, dict): + LOG.error( + "LLM response is not valid JSON", + response_type=type(action_data).__name__, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Invalid response from LLM", + ) + return action_data + + async def copilot_call_llm( + stream: EventSourceStream, organization_id: str, chat_request: WorkflowCopilotChatRequest, chat_history: list[WorkflowCopilotChatHistoryMessage], global_llm_context: str | None, debug_run_info_text: str, ) -> tuple[str, WorkflowDefinition | None, str | None]: - current_datetime = datetime.now(timezone.utc).isoformat() - - chat_history_text = "" - if chat_history: - history_lines = [f"{msg.sender}: {msg.content}" for msg in chat_history] - chat_history_text = "\n".join(history_lines) + chat_history_text = _format_chat_history(chat_history) workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8") @@ -111,7 +136,7 @@ async def copilot_call_llm( user_message=chat_request.message, chat_history=chat_history_text, global_llm_context=global_llm_context or "", - current_datetime=current_datetime, + current_datetime=datetime.now(timezone.utc).isoformat(), debug_run_info=debug_run_info_text, ) @@ -150,21 +175,7 @@ async def copilot_call_llm( llm_response=llm_response, ) - if isinstance(llm_response, dict) and "output" in llm_response: - action_data = llm_response["output"] - else: - action_data = llm_response - - if not isinstance(action_data, dict): - LOG.error( - "LLM response is not valid JSON", - organization_id=organization_id, - response_type=type(action_data).__name__, - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Invalid response from LLM", - ) + action_data = _parse_llm_response(llm_response) action_type = action_data.get("type") user_response_value = action_data.get("user_response") @@ -183,7 +194,29 @@ async def copilot_call_llm( global_llm_context = str(global_llm_context) if action_type == "REPLACE_WORKFLOW": - updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, action_data.get("workflow_yaml", "")) + workflow_yaml = action_data.get("workflow_yaml", "") + try: + updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, workflow_yaml) + except (yaml.YAMLError, ValidationError, BaseWorkflowHTTPException) as e: + await stream.send( + WorkflowCopilotProcessingUpdate( + type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE, + status="Validating workflow definition...", + timestamp=datetime.now(timezone.utc), + ) + ) + corrected_workflow_yaml = await _auto_correct_workflow_yaml( + llm_api_handler=llm_api_handler, + organization_id=organization_id, + user_response=user_response, + workflow_yaml=workflow_yaml, + chat_history=chat_history, + global_llm_context=global_llm_context, + debug_run_info_text=debug_run_info_text, + error=e, + ) + updated_workflow = await _process_workflow_yaml(chat_request.workflow_id, corrected_workflow_yaml) + return user_response, updated_workflow, global_llm_context elif action_type == "REPLY": return user_response, None, global_llm_context @@ -198,45 +231,80 @@ async def copilot_call_llm( return "I received your request but I'm not sure how to help. Could you rephrase?", None, None +async def _auto_correct_workflow_yaml( + llm_api_handler: LLMAPIHandler, + organization_id: str, + user_response: str, + workflow_yaml: str, + chat_history: list[WorkflowCopilotChatHistoryMessage], + global_llm_context: str | None, + debug_run_info_text: str, + error: Exception, +) -> str: + failure_reason = f"{error.__class__.__name__}: {error}" + + new_chat_history = chat_history[:] + new_chat_history.append( + WorkflowCopilotChatHistoryMessage( + sender=WorkflowCopilotChatSender.AI, + content=user_response, + created_at=datetime.now(timezone.utc), + ) + ) + + workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8") + llm_prompt = prompt_engine.load_prompt( + template="workflow-copilot", + workflow_knowledge_base=workflow_knowledge_base, + workflow_yaml=workflow_yaml, + user_message=f"Workflow YAML parsing failed, please fix it: {failure_reason}", + chat_history=_format_chat_history(new_chat_history), + global_llm_context=global_llm_context or "", + current_datetime=datetime.now(timezone.utc).isoformat(), + debug_run_info=debug_run_info_text, + ) + llm_start_time = time.monotonic() + llm_response = await llm_api_handler( + prompt=llm_prompt, + prompt_name="workflow-copilot", + organization_id=organization_id, + ) + LOG.info( + "Auto-correction LLM response", + duration_seconds=time.monotonic() - llm_start_time, + llm_response_len=len(llm_response), + llm_response=llm_response, + ) + action_data = _parse_llm_response(llm_response) + + return action_data.get("workflow_yaml", workflow_yaml) + + async def _process_workflow_yaml(workflow_id: str, workflow_yaml: str) -> WorkflowDefinition: - try: - parsed_yaml = yaml.safe_load(workflow_yaml) - except yaml.YAMLError as e: - LOG.error("Invalid YAML from LLM", yaml=workflow_yaml, exc_info=True) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"LLM generated invalid YAML: {str(e)}", - ) + parsed_yaml = yaml.safe_load(workflow_yaml) - try: - # Fixing trivial common LLM mistakes - workflow_definition = parsed_yaml.get("workflow_definition", None) - if workflow_definition: - blocks = workflow_definition.get("blocks", []) - for block in blocks: - block["title"] = block.get("title", "") + # Fixing trivial common LLM mistakes + workflow_definition = parsed_yaml.get("workflow_definition", None) + if workflow_definition: + blocks = workflow_definition.get("blocks", []) + for block in blocks: + block["title"] = block.get("title", "") - workflow_yaml_request = WorkflowCreateYAMLRequest.model_validate(parsed_yaml) + workflow_yaml_request = WorkflowCreateYAMLRequest.model_validate(parsed_yaml) - # Post-processing - for block in workflow_yaml_request.workflow_definition.blocks: - if isinstance(block, LoginBlockYAML) and not block.navigation_goal: - block.navigation_goal = DEFAULT_LOGIN_PROMPT + # Post-processing + for block in workflow_yaml_request.workflow_definition.blocks: + if isinstance(block, LoginBlockYAML) and not block.navigation_goal: + block.navigation_goal = DEFAULT_LOGIN_PROMPT - workflow_yaml_request.workflow_definition.parameters = [ - p for p in workflow_yaml_request.workflow_definition.parameters if p.parameter_type != ParameterType.OUTPUT - ] + workflow_yaml_request.workflow_definition.parameters = [ + p for p in workflow_yaml_request.workflow_definition.parameters if p.parameter_type != ParameterType.OUTPUT + ] - updated_workflow = convert_workflow_definition( - workflow_definition_yaml=workflow_yaml_request.workflow_definition, - workflow_id=workflow_id, - ) - except Exception as e: - LOG.error("YAML from LLM does not conform to Skyvern workflow schema", yaml=workflow_yaml, exc_info=True) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"LLM generated YAML that doesn't match workflow schema: {str(e)}", - ) + updated_workflow = convert_workflow_definition( + workflow_definition_yaml=workflow_yaml_request.workflow_definition, + workflow_id=workflow_id, + ) return updated_workflow @@ -246,7 +314,7 @@ async def workflow_copilot_chat_post( chat_request: WorkflowCopilotChatRequest, organization: Organization = Depends(org_auth_service.get_current_org), ) -> EventSourceResponse: - async def event_stream() -> AsyncGenerator[JSONServerSentEvent, None]: + async def stream_handler(stream: EventSourceStream) -> None: LOG.info( "Workflow copilot chat request", workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id, @@ -257,12 +325,12 @@ async def workflow_copilot_chat_post( ) try: - yield JSONServerSentEvent( - data=WorkflowCopilotProcessingUpdate( + await stream.send( + WorkflowCopilotProcessingUpdate( type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE, status="Processing...", timestamp=datetime.now(timezone.utc), - ).model_dump(mode="json"), + ) ) if chat_request.workflow_copilot_chat_id: @@ -302,15 +370,15 @@ async def workflow_copilot_chat_post( if debug_run_info.html: debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}" - yield JSONServerSentEvent( - data=WorkflowCopilotProcessingUpdate( + await stream.send( + WorkflowCopilotProcessingUpdate( type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE, status="Thinking...", timestamp=datetime.now(timezone.utc), - ).model_dump(mode="json"), + ) ) - if await request.is_disconnected(): + if await stream.is_disconnected(): LOG.info( "Workflow copilot chat request is disconnected before LLM call", workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id, @@ -318,6 +386,7 @@ async def workflow_copilot_chat_post( return user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm( + stream, organization.organization_id, chat_request, convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]), @@ -325,7 +394,7 @@ async def workflow_copilot_chat_post( debug_run_info_text, ) - if await request.is_disconnected(): + if await stream.is_disconnected(): LOG.info( "Workflow copilot chat request is disconnected after LLM call", workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id, @@ -347,62 +416,50 @@ async def workflow_copilot_chat_post( global_llm_context=updated_global_llm_context, ) - yield JSONServerSentEvent( - data=WorkflowCopilotStreamResponseUpdate( + await stream.send( + WorkflowCopilotStreamResponseUpdate( type=WorkflowCopilotStreamMessageType.RESPONSE, workflow_copilot_chat_id=chat.workflow_copilot_chat_id, message=user_response, updated_workflow=updated_workflow.model_dump(mode="json") if updated_workflow else None, response_time=assistant_message.created_at, - ).model_dump(mode="json"), + ) ) except HTTPException as exc: - if await request.is_disconnected(): - return - yield JSONServerSentEvent( - data=WorkflowCopilotStreamErrorUpdate( + await stream.send( + WorkflowCopilotStreamErrorUpdate( type=WorkflowCopilotStreamMessageType.ERROR, error=exc.detail, - ).model_dump(mode="json"), + ) ) except LLMProviderError as exc: - if await request.is_disconnected(): - return LOG.error( "LLM provider error", organization_id=organization.organization_id, error=str(exc), exc_info=True, ) - yield JSONServerSentEvent( - data=WorkflowCopilotStreamErrorUpdate( + await stream.send( + WorkflowCopilotStreamErrorUpdate( type=WorkflowCopilotStreamMessageType.ERROR, error="Failed to process your request. Please try again.", - ).model_dump(mode="json"), + ) ) except Exception as exc: - if await request.is_disconnected(): - return LOG.error( "Unexpected error in workflow copilot", organization_id=organization.organization_id, error=str(exc), exc_info=True, ) - yield JSONServerSentEvent( - data=WorkflowCopilotStreamErrorUpdate( - type=WorkflowCopilotStreamMessageType.ERROR, error="An error occurred. Please try again." - ).model_dump(mode="json"), + await stream.send( + WorkflowCopilotStreamErrorUpdate( + type=WorkflowCopilotStreamMessageType.ERROR, + error="An error occurred. Please try again.", + ) ) - def ping_message_factory() -> ServerSentEvent: - return ServerSentEvent(comment="keep-alive") - - return EventSourceResponse( - event_stream(), - ping=SSE_KEEPALIVE_INTERVAL_SECONDS, - ping_message_factory=ping_message_factory, - ) + return FastAPIEventSourceStream.create(request, stream_handler) @base_router.get("/workflow/copilot/chat-history", include_in_schema=False)