Workflow Copilot: server update streaming with "cancel" ability (#4456)
This commit is contained in:
committed by
GitHub
parent
6b9ea59e67
commit
9cf1f87514
@@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
|
||||
|
||||
from skyvern.forge import app
|
||||
@@ -22,8 +22,11 @@ from skyvern.forge.sdk.schemas.workflow_copilot import (
|
||||
WorkflowCopilotChatHistoryResponse,
|
||||
WorkflowCopilotChatMessage,
|
||||
WorkflowCopilotChatRequest,
|
||||
WorkflowCopilotChatResponse,
|
||||
WorkflowCopilotChatSender,
|
||||
WorkflowCopilotProcessingUpdate,
|
||||
WorkflowCopilotStreamErrorUpdate,
|
||||
WorkflowCopilotStreamMessageType,
|
||||
WorkflowCopilotStreamResponseUpdate,
|
||||
)
|
||||
from skyvern.forge.sdk.services import org_auth_service
|
||||
from skyvern.schemas.workflows import LoginBlockYAML, WorkflowCreateYAMLRequest
|
||||
@@ -236,66 +239,81 @@ async def _process_workflow_yaml(action_data: dict[str, Any]) -> None | str:
|
||||
|
||||
@base_router.post("/workflow/copilot/chat-post", include_in_schema=False)
|
||||
async def workflow_copilot_chat_post(
|
||||
request: Request,
|
||||
chat_request: WorkflowCopilotChatRequest,
|
||||
organization: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> EventSourceResponse:
|
||||
LOG.info(
|
||||
"Workflow copilot chat request",
|
||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||
workflow_run_id=chat_request.workflow_run_id,
|
||||
message=chat_request.message,
|
||||
workflow_yaml_length=len(chat_request.workflow_yaml),
|
||||
organization_id=organization.organization_id,
|
||||
)
|
||||
|
||||
request_started_at = datetime.now(timezone.utc)
|
||||
|
||||
if chat_request.workflow_copilot_chat_id:
|
||||
chat = await app.DATABASE.get_workflow_copilot_chat_by_id(
|
||||
organization_id=organization.organization_id,
|
||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||
)
|
||||
if not chat:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||
if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
|
||||
else:
|
||||
chat = await app.DATABASE.create_workflow_copilot_chat(
|
||||
organization_id=organization.organization_id,
|
||||
workflow_permanent_id=chat_request.workflow_permanent_id,
|
||||
)
|
||||
|
||||
chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
)
|
||||
global_llm_context = None
|
||||
for message in reversed(chat_messages):
|
||||
if message.global_llm_context is not None:
|
||||
global_llm_context = message.global_llm_context
|
||||
break
|
||||
|
||||
debug_run_info = await _get_debug_run_info(organization.organization_id, chat_request.workflow_run_id)
|
||||
|
||||
# Format debug run info for prompt
|
||||
debug_run_info_text = ""
|
||||
if debug_run_info:
|
||||
debug_run_info_text = f"Block Label: {debug_run_info.block_label}"
|
||||
debug_run_info_text += f" Block Type: {debug_run_info.block_type}"
|
||||
debug_run_info_text += f" Status: {debug_run_info.block_status}"
|
||||
if debug_run_info.failure_reason:
|
||||
debug_run_info_text += f"\nFailure Reason: {debug_run_info.failure_reason}"
|
||||
if debug_run_info.html:
|
||||
debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"
|
||||
|
||||
await app.DATABASE.create_workflow_copilot_chat_message(
|
||||
organization_id=chat.organization_id,
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
sender=WorkflowCopilotChatSender.USER,
|
||||
content=chat_request.message,
|
||||
)
|
||||
|
||||
async def event_stream() -> AsyncGenerator[JSONServerSentEvent, None]:
|
||||
LOG.info(
|
||||
"Workflow copilot chat request",
|
||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||
workflow_run_id=chat_request.workflow_run_id,
|
||||
message=chat_request.message,
|
||||
workflow_yaml_length=len(chat_request.workflow_yaml),
|
||||
organization_id=organization.organization_id,
|
||||
)
|
||||
|
||||
try:
|
||||
yield JSONServerSentEvent(
|
||||
data=WorkflowCopilotProcessingUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
|
||||
status="Processing...",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
).model_dump(mode="json"),
|
||||
)
|
||||
|
||||
if chat_request.workflow_copilot_chat_id:
|
||||
chat = await app.DATABASE.get_workflow_copilot_chat_by_id(
|
||||
organization_id=organization.organization_id,
|
||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||
)
|
||||
if not chat:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||
if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
|
||||
else:
|
||||
chat = await app.DATABASE.create_workflow_copilot_chat(
|
||||
organization_id=organization.organization_id,
|
||||
workflow_permanent_id=chat_request.workflow_permanent_id,
|
||||
)
|
||||
|
||||
chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
)
|
||||
global_llm_context = None
|
||||
for message in reversed(chat_messages):
|
||||
if message.global_llm_context is not None:
|
||||
global_llm_context = message.global_llm_context
|
||||
break
|
||||
|
||||
debug_run_info = await _get_debug_run_info(organization.organization_id, chat_request.workflow_run_id)
|
||||
|
||||
# Format debug run info for prompt
|
||||
debug_run_info_text = ""
|
||||
if debug_run_info:
|
||||
debug_run_info_text = f"Block Label: {debug_run_info.block_label}"
|
||||
debug_run_info_text += f" Block Type: {debug_run_info.block_type}"
|
||||
debug_run_info_text += f" Status: {debug_run_info.block_status}"
|
||||
if debug_run_info.failure_reason:
|
||||
debug_run_info_text += f"\nFailure Reason: {debug_run_info.failure_reason}"
|
||||
if debug_run_info.html:
|
||||
debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"
|
||||
|
||||
yield JSONServerSentEvent(
|
||||
data=WorkflowCopilotProcessingUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
|
||||
status="Thinking...",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
).model_dump(mode="json"),
|
||||
)
|
||||
|
||||
if await request.is_disconnected():
|
||||
LOG.info(
|
||||
"Workflow copilot chat request is disconnected before LLM call",
|
||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||
)
|
||||
return
|
||||
|
||||
user_response, updated_workflow_yaml, updated_global_llm_context = await copilot_call_llm(
|
||||
organization.organization_id,
|
||||
chat_request,
|
||||
@@ -303,6 +321,21 @@ async def workflow_copilot_chat_post(
|
||||
global_llm_context,
|
||||
debug_run_info_text,
|
||||
)
|
||||
|
||||
if await request.is_disconnected():
|
||||
LOG.info(
|
||||
"Workflow copilot chat request is disconnected after LLM call",
|
||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||
)
|
||||
return
|
||||
|
||||
await app.DATABASE.create_workflow_copilot_chat_message(
|
||||
organization_id=chat.organization_id,
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
sender=WorkflowCopilotChatSender.USER,
|
||||
content=chat_request.message,
|
||||
)
|
||||
|
||||
assistant_message = await app.DATABASE.create_workflow_copilot_chat_message(
|
||||
organization_id=chat.organization_id,
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
@@ -311,32 +344,53 @@ async def workflow_copilot_chat_post(
|
||||
global_llm_context=updated_global_llm_context,
|
||||
)
|
||||
|
||||
response_payload = WorkflowCopilotChatResponse(
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
message=user_response,
|
||||
updated_workflow_yaml=updated_workflow_yaml,
|
||||
request_time=request_started_at,
|
||||
response_time=assistant_message.created_at,
|
||||
).model_dump(mode="json")
|
||||
yield JSONServerSentEvent(response_payload)
|
||||
yield JSONServerSentEvent(
|
||||
data=WorkflowCopilotStreamResponseUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.RESPONSE,
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
message=user_response,
|
||||
updated_workflow_yaml=updated_workflow_yaml,
|
||||
response_time=assistant_message.created_at,
|
||||
).model_dump(mode="json"),
|
||||
)
|
||||
except HTTPException as exc:
|
||||
yield JSONServerSentEvent({"error": exc.detail})
|
||||
if await request.is_disconnected():
|
||||
return
|
||||
yield JSONServerSentEvent(
|
||||
data=WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
error=exc.detail,
|
||||
).model_dump(mode="json"),
|
||||
)
|
||||
except LLMProviderError as exc:
|
||||
if await request.is_disconnected():
|
||||
return
|
||||
LOG.error(
|
||||
"LLM provider error",
|
||||
organization_id=organization.organization_id,
|
||||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
yield JSONServerSentEvent({"error": "Failed to process your request. Please try again."})
|
||||
yield JSONServerSentEvent(
|
||||
data=WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
error="Failed to process your request. Please try again.",
|
||||
).model_dump(mode="json"),
|
||||
)
|
||||
except Exception as exc:
|
||||
if await request.is_disconnected():
|
||||
return
|
||||
LOG.error(
|
||||
"Unexpected error in workflow copilot",
|
||||
organization_id=organization.organization_id,
|
||||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
yield JSONServerSentEvent({"error": "An error occurred. Please try again."})
|
||||
yield JSONServerSentEvent(
|
||||
data=WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR, error="An error occurred. Please try again."
|
||||
).model_dump(mode="json"),
|
||||
)
|
||||
|
||||
def ping_message_factory() -> ServerSentEvent:
|
||||
return ServerSentEvent(comment="keep-alive")
|
||||
|
||||
@@ -39,14 +39,6 @@ class WorkflowCopilotChatRequest(BaseModel):
|
||||
workflow_yaml: str = Field(..., description="Current workflow YAML including unsaved changes")
|
||||
|
||||
|
||||
class WorkflowCopilotChatResponse(BaseModel):
|
||||
workflow_copilot_chat_id: str = Field(..., description="The chat ID")
|
||||
message: str = Field(..., description="The message sent to the user")
|
||||
updated_workflow_yaml: str | None = Field(None, description="The updated workflow yaml")
|
||||
request_time: datetime = Field(..., description="When the request was received")
|
||||
response_time: datetime = Field(..., description="When the assistant message was created")
|
||||
|
||||
|
||||
class WorkflowCopilotChatHistoryMessage(BaseModel):
|
||||
sender: WorkflowCopilotChatSender = Field(..., description="Message sender")
|
||||
content: str = Field(..., description="Message content")
|
||||
@@ -56,3 +48,32 @@ class WorkflowCopilotChatHistoryMessage(BaseModel):
|
||||
class WorkflowCopilotChatHistoryResponse(BaseModel):
|
||||
workflow_copilot_chat_id: str | None = Field(None, description="Latest chat ID for the workflow")
|
||||
chat_history: list[WorkflowCopilotChatHistoryMessage] = Field(default_factory=list, description="Chat messages")
|
||||
|
||||
|
||||
class WorkflowCopilotStreamMessageType(StrEnum):
|
||||
PROCESSING_UPDATE = "processing_update"
|
||||
RESPONSE = "response"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class WorkflowCopilotProcessingUpdate(BaseModel):
|
||||
type: WorkflowCopilotStreamMessageType = Field(
|
||||
WorkflowCopilotStreamMessageType.PROCESSING_UPDATE, description="Message type"
|
||||
)
|
||||
status: str = Field(..., description="Processing status text")
|
||||
timestamp: datetime = Field(..., description="Server timestamp")
|
||||
|
||||
|
||||
class WorkflowCopilotStreamResponseUpdate(BaseModel):
|
||||
type: WorkflowCopilotStreamMessageType = Field(
|
||||
WorkflowCopilotStreamMessageType.RESPONSE, description="Message type"
|
||||
)
|
||||
workflow_copilot_chat_id: str = Field(..., description="The chat ID")
|
||||
message: str = Field(..., description="The message sent to the user")
|
||||
updated_workflow_yaml: str | None = Field(None, description="The updated workflow yaml")
|
||||
response_time: datetime = Field(..., description="When the assistant message was created")
|
||||
|
||||
|
||||
class WorkflowCopilotStreamErrorUpdate(BaseModel):
|
||||
type: WorkflowCopilotStreamMessageType = Field(WorkflowCopilotStreamMessageType.ERROR, description="Message type")
|
||||
error: str = Field(..., description="Error message")
|
||||
|
||||
Reference in New Issue
Block a user