Workflow Copilot: server update streaming with "cancel" ability (#4456)

This commit is contained in:
Stanislav Novosad
2026-01-14 18:34:09 -07:00
committed by GitHub
parent 6b9ea59e67
commit 9cf1f87514
7 changed files with 462 additions and 170 deletions

View File

@@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator
import structlog
import yaml
from fastapi import Depends, HTTPException, status
from fastapi import Depends, HTTPException, Request, status
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
from skyvern.forge import app
@@ -22,8 +22,11 @@ from skyvern.forge.sdk.schemas.workflow_copilot import (
WorkflowCopilotChatHistoryResponse,
WorkflowCopilotChatMessage,
WorkflowCopilotChatRequest,
WorkflowCopilotChatResponse,
WorkflowCopilotChatSender,
WorkflowCopilotProcessingUpdate,
WorkflowCopilotStreamErrorUpdate,
WorkflowCopilotStreamMessageType,
WorkflowCopilotStreamResponseUpdate,
)
from skyvern.forge.sdk.services import org_auth_service
from skyvern.schemas.workflows import LoginBlockYAML, WorkflowCreateYAMLRequest
@@ -236,66 +239,81 @@ async def _process_workflow_yaml(action_data: dict[str, Any]) -> None | str:
@base_router.post("/workflow/copilot/chat-post", include_in_schema=False)
async def workflow_copilot_chat_post(
request: Request,
chat_request: WorkflowCopilotChatRequest,
organization: Organization = Depends(org_auth_service.get_current_org),
) -> EventSourceResponse:
LOG.info(
"Workflow copilot chat request",
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
workflow_run_id=chat_request.workflow_run_id,
message=chat_request.message,
workflow_yaml_length=len(chat_request.workflow_yaml),
organization_id=organization.organization_id,
)
request_started_at = datetime.now(timezone.utc)
if chat_request.workflow_copilot_chat_id:
chat = await app.DATABASE.get_workflow_copilot_chat_by_id(
organization_id=organization.organization_id,
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)
if not chat:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
else:
chat = await app.DATABASE.create_workflow_copilot_chat(
organization_id=organization.organization_id,
workflow_permanent_id=chat_request.workflow_permanent_id,
)
chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
)
global_llm_context = None
for message in reversed(chat_messages):
if message.global_llm_context is not None:
global_llm_context = message.global_llm_context
break
debug_run_info = await _get_debug_run_info(organization.organization_id, chat_request.workflow_run_id)
# Format debug run info for prompt
debug_run_info_text = ""
if debug_run_info:
debug_run_info_text = f"Block Label: {debug_run_info.block_label}"
debug_run_info_text += f" Block Type: {debug_run_info.block_type}"
debug_run_info_text += f" Status: {debug_run_info.block_status}"
if debug_run_info.failure_reason:
debug_run_info_text += f"\nFailure Reason: {debug_run_info.failure_reason}"
if debug_run_info.html:
debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"
await app.DATABASE.create_workflow_copilot_chat_message(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
sender=WorkflowCopilotChatSender.USER,
content=chat_request.message,
)
async def event_stream() -> AsyncGenerator[JSONServerSentEvent, None]:
LOG.info(
"Workflow copilot chat request",
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
workflow_run_id=chat_request.workflow_run_id,
message=chat_request.message,
workflow_yaml_length=len(chat_request.workflow_yaml),
organization_id=organization.organization_id,
)
try:
yield JSONServerSentEvent(
data=WorkflowCopilotProcessingUpdate(
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
status="Processing...",
timestamp=datetime.now(timezone.utc),
).model_dump(mode="json"),
)
if chat_request.workflow_copilot_chat_id:
chat = await app.DATABASE.get_workflow_copilot_chat_by_id(
organization_id=organization.organization_id,
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)
if not chat:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
else:
chat = await app.DATABASE.create_workflow_copilot_chat(
organization_id=organization.organization_id,
workflow_permanent_id=chat_request.workflow_permanent_id,
)
chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
)
global_llm_context = None
for message in reversed(chat_messages):
if message.global_llm_context is not None:
global_llm_context = message.global_llm_context
break
debug_run_info = await _get_debug_run_info(organization.organization_id, chat_request.workflow_run_id)
# Format debug run info for prompt
debug_run_info_text = ""
if debug_run_info:
debug_run_info_text = f"Block Label: {debug_run_info.block_label}"
debug_run_info_text += f" Block Type: {debug_run_info.block_type}"
debug_run_info_text += f" Status: {debug_run_info.block_status}"
if debug_run_info.failure_reason:
debug_run_info_text += f"\nFailure Reason: {debug_run_info.failure_reason}"
if debug_run_info.html:
debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"
yield JSONServerSentEvent(
data=WorkflowCopilotProcessingUpdate(
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
status="Thinking...",
timestamp=datetime.now(timezone.utc),
).model_dump(mode="json"),
)
if await request.is_disconnected():
LOG.info(
"Workflow copilot chat request is disconnected before LLM call",
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)
return
user_response, updated_workflow_yaml, updated_global_llm_context = await copilot_call_llm(
organization.organization_id,
chat_request,
@@ -303,6 +321,21 @@ async def workflow_copilot_chat_post(
global_llm_context,
debug_run_info_text,
)
if await request.is_disconnected():
LOG.info(
"Workflow copilot chat request is disconnected after LLM call",
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)
return
await app.DATABASE.create_workflow_copilot_chat_message(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
sender=WorkflowCopilotChatSender.USER,
content=chat_request.message,
)
assistant_message = await app.DATABASE.create_workflow_copilot_chat_message(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
@@ -311,32 +344,53 @@ async def workflow_copilot_chat_post(
global_llm_context=updated_global_llm_context,
)
response_payload = WorkflowCopilotChatResponse(
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
message=user_response,
updated_workflow_yaml=updated_workflow_yaml,
request_time=request_started_at,
response_time=assistant_message.created_at,
).model_dump(mode="json")
yield JSONServerSentEvent(response_payload)
yield JSONServerSentEvent(
data=WorkflowCopilotStreamResponseUpdate(
type=WorkflowCopilotStreamMessageType.RESPONSE,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
message=user_response,
updated_workflow_yaml=updated_workflow_yaml,
response_time=assistant_message.created_at,
).model_dump(mode="json"),
)
except HTTPException as exc:
yield JSONServerSentEvent({"error": exc.detail})
if await request.is_disconnected():
return
yield JSONServerSentEvent(
data=WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
error=exc.detail,
).model_dump(mode="json"),
)
except LLMProviderError as exc:
if await request.is_disconnected():
return
LOG.error(
"LLM provider error",
organization_id=organization.organization_id,
error=str(exc),
exc_info=True,
)
yield JSONServerSentEvent({"error": "Failed to process your request. Please try again."})
yield JSONServerSentEvent(
data=WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
error="Failed to process your request. Please try again.",
).model_dump(mode="json"),
)
except Exception as exc:
if await request.is_disconnected():
return
LOG.error(
"Unexpected error in workflow copilot",
organization_id=organization.organization_id,
error=str(exc),
exc_info=True,
)
yield JSONServerSentEvent({"error": "An error occurred. Please try again."})
yield JSONServerSentEvent(
data=WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR, error="An error occurred. Please try again."
).model_dump(mode="json"),
)
def ping_message_factory() -> ServerSentEvent:
return ServerSentEvent(comment="keep-alive")

View File

@@ -39,14 +39,6 @@ class WorkflowCopilotChatRequest(BaseModel):
workflow_yaml: str = Field(..., description="Current workflow YAML including unsaved changes")
class WorkflowCopilotChatResponse(BaseModel):
workflow_copilot_chat_id: str = Field(..., description="The chat ID")
message: str = Field(..., description="The message sent to the user")
updated_workflow_yaml: str | None = Field(None, description="The updated workflow yaml")
request_time: datetime = Field(..., description="When the request was received")
response_time: datetime = Field(..., description="When the assistant message was created")
class WorkflowCopilotChatHistoryMessage(BaseModel):
sender: WorkflowCopilotChatSender = Field(..., description="Message sender")
content: str = Field(..., description="Message content")
@@ -56,3 +48,32 @@ class WorkflowCopilotChatHistoryMessage(BaseModel):
class WorkflowCopilotChatHistoryResponse(BaseModel):
workflow_copilot_chat_id: str | None = Field(None, description="Latest chat ID for the workflow")
chat_history: list[WorkflowCopilotChatHistoryMessage] = Field(default_factory=list, description="Chat messages")
class WorkflowCopilotStreamMessageType(StrEnum):
PROCESSING_UPDATE = "processing_update"
RESPONSE = "response"
ERROR = "error"
class WorkflowCopilotProcessingUpdate(BaseModel):
type: WorkflowCopilotStreamMessageType = Field(
WorkflowCopilotStreamMessageType.PROCESSING_UPDATE, description="Message type"
)
status: str = Field(..., description="Processing status text")
timestamp: datetime = Field(..., description="Server timestamp")
class WorkflowCopilotStreamResponseUpdate(BaseModel):
type: WorkflowCopilotStreamMessageType = Field(
WorkflowCopilotStreamMessageType.RESPONSE, description="Message type"
)
workflow_copilot_chat_id: str = Field(..., description="The chat ID")
message: str = Field(..., description="The message sent to the user")
updated_workflow_yaml: str | None = Field(None, description="The updated workflow yaml")
response_time: datetime = Field(..., description="When the assistant message was created")
class WorkflowCopilotStreamErrorUpdate(BaseModel):
type: WorkflowCopilotStreamMessageType = Field(WorkflowCopilotStreamMessageType.ERROR, description="Message type")
error: str = Field(..., description="Error message")