Workflow Copilot: Use streaming in /chat-post (#4437)
This commit is contained in:
committed by
GitHub
parent
1d38c7bfe8
commit
a6f0781491
@@ -2,11 +2,12 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
@@ -29,6 +30,7 @@ from skyvern.schemas.workflows import LoginBlockYAML, WorkflowCreateYAMLRequest
|
||||
|
||||
WORKFLOW_KNOWLEDGE_BASE_PATH = Path("skyvern/forge/prompts/skyvern/workflow_knowledge_base.txt")
|
||||
CHAT_HISTORY_CONTEXT_MESSAGES = 10
|
||||
SSE_KEEPALIVE_INTERVAL_SECONDS = 10
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -236,7 +238,7 @@ async def _process_workflow_yaml(action_data: dict[str, Any]) -> None | str:
|
||||
async def workflow_copilot_chat_post(
|
||||
chat_request: WorkflowCopilotChatRequest,
|
||||
organization: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> WorkflowCopilotChatResponse:
|
||||
) -> EventSourceResponse:
|
||||
LOG.info(
|
||||
"Workflow copilot chat request",
|
||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
||||
@@ -292,53 +294,57 @@ async def workflow_copilot_chat_post(
|
||||
content=chat_request.message,
|
||||
)
|
||||
|
||||
try:
|
||||
user_response, updated_workflow_yaml, updated_global_llm_context = await copilot_call_llm(
|
||||
organization.organization_id,
|
||||
chat_request,
|
||||
convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
|
||||
global_llm_context,
|
||||
debug_run_info_text,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except LLMProviderError as e:
|
||||
LOG.error(
|
||||
"LLM provider error",
|
||||
organization_id=organization.organization_id,
|
||||
error=str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to process your request. Please try again.",
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.error(
|
||||
"Unexpected error in workflow copilot",
|
||||
organization_id=organization.organization_id,
|
||||
error=str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"An error occurred: {str(e)}",
|
||||
)
|
||||
async def event_stream() -> AsyncGenerator[JSONServerSentEvent, None]:
|
||||
try:
|
||||
user_response, updated_workflow_yaml, updated_global_llm_context = await copilot_call_llm(
|
||||
organization.organization_id,
|
||||
chat_request,
|
||||
convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
|
||||
global_llm_context,
|
||||
debug_run_info_text,
|
||||
)
|
||||
assistant_message = await app.DATABASE.create_workflow_copilot_chat_message(
|
||||
organization_id=chat.organization_id,
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
sender=WorkflowCopilotChatSender.AI,
|
||||
content=user_response,
|
||||
global_llm_context=updated_global_llm_context,
|
||||
)
|
||||
|
||||
assistant_message = await app.DATABASE.create_workflow_copilot_chat_message(
|
||||
organization_id=chat.organization_id,
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
sender=WorkflowCopilotChatSender.AI,
|
||||
content=user_response,
|
||||
global_llm_context=updated_global_llm_context,
|
||||
)
|
||||
response_payload = WorkflowCopilotChatResponse(
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
message=user_response,
|
||||
updated_workflow_yaml=updated_workflow_yaml,
|
||||
request_time=request_started_at,
|
||||
response_time=assistant_message.created_at,
|
||||
).model_dump(mode="json")
|
||||
yield JSONServerSentEvent(response_payload)
|
||||
except HTTPException as exc:
|
||||
yield JSONServerSentEvent({"error": exc.detail})
|
||||
except LLMProviderError as exc:
|
||||
LOG.error(
|
||||
"LLM provider error",
|
||||
organization_id=organization.organization_id,
|
||||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
yield JSONServerSentEvent({"error": "Failed to process your request. Please try again."})
|
||||
except Exception as exc:
|
||||
LOG.error(
|
||||
"Unexpected error in workflow copilot",
|
||||
organization_id=organization.organization_id,
|
||||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
yield JSONServerSentEvent({"error": "An error occurred. Please try again."})
|
||||
|
||||
return WorkflowCopilotChatResponse(
|
||||
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
|
||||
message=user_response,
|
||||
updated_workflow_yaml=updated_workflow_yaml,
|
||||
request_time=request_started_at,
|
||||
response_time=assistant_message.created_at,
|
||||
def ping_message_factory() -> ServerSentEvent:
|
||||
return ServerSentEvent(comment="keep-alive")
|
||||
|
||||
return EventSourceResponse(
|
||||
event_stream(),
|
||||
ping=SSE_KEEPALIVE_INTERVAL_SECONDS,
|
||||
ping_message_factory=ping_message_factory,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user