Workflow Copilot: backend side of the first version (#4401)

This commit is contained in:
Stanislav Novosad
2026-01-06 14:58:44 -07:00
committed by GitHub
parent 1e314ce149
commit e3dd75d7c1
10 changed files with 1440 additions and 0 deletions

View File

@@ -47,6 +47,8 @@ from skyvern.forge.sdk.db.models import (
TaskV2Model,
ThoughtModel,
TOTPCodeModel,
WorkflowCopilotChatMessageModel,
WorkflowCopilotChatModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunBlockModel,
@@ -72,6 +74,7 @@ from skyvern.forge.sdk.db.utils import (
convert_to_task,
convert_to_task_v2,
convert_to_workflow,
convert_to_workflow_copilot_chat_message,
convert_to_workflow_parameter,
convert_to_workflow_run,
convert_to_workflow_run_block,
@@ -100,6 +103,11 @@ from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType
from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
from skyvern.forge.sdk.schemas.totp_codes import OTPType, TOTPCode
from skyvern.forge.sdk.schemas.workflow_copilot import (
WorkflowCopilotChat,
WorkflowCopilotChatMessage,
WorkflowCopilotChatSender,
)
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
@@ -3640,6 +3648,91 @@ class AgentDB(BaseAlchemyDB):
await session.refresh(new_ai_suggestion)
return AISuggestion.model_validate(new_ai_suggestion)
async def create_workflow_copilot_chat(
    self,
    organization_id: str,
    workflow_permanent_id: str,
) -> WorkflowCopilotChat:
    """Persist a new workflow copilot chat row and return its schema object."""
    async with self.Session() as session:
        chat_model = WorkflowCopilotChatModel(
            organization_id=organization_id,
            workflow_permanent_id=workflow_permanent_id,
        )
        session.add(chat_model)
        await session.commit()
        # Refresh to pick up DB/ORM-generated defaults (id, timestamps).
        await session.refresh(chat_model)
        return WorkflowCopilotChat.model_validate(chat_model)
async def create_workflow_copilot_chat_message(
    self,
    organization_id: str,
    workflow_copilot_chat_id: str,
    sender: WorkflowCopilotChatSender,
    content: str,
    global_llm_context: str | None = None,
) -> WorkflowCopilotChatMessage:
    """Persist a single chat message (user or AI) and return its schema object."""
    async with self.Session() as session:
        message_model = WorkflowCopilotChatMessageModel(
            workflow_copilot_chat_id=workflow_copilot_chat_id,
            organization_id=organization_id,
            sender=sender,
            content=content,
            global_llm_context=global_llm_context,
        )
        session.add(message_model)
        await session.commit()
        # Refresh so generated id/timestamps are populated before conversion.
        await session.refresh(message_model)
        return convert_to_workflow_copilot_chat_message(message_model, self.debug_enabled)
async def get_workflow_copilot_chat_messages(
    self,
    workflow_copilot_chat_id: str,
) -> list[WorkflowCopilotChatMessage]:
    """Return every message in a chat, ordered ascending by message id.

    NOTE(review): ordering by the string primary key assumes the generated
    ids sort chronologically — confirm against generate_id's guarantees;
    ordering by created_at would be unambiguous.
    """
    async with self.Session() as session:
        stmt = (
            select(WorkflowCopilotChatMessageModel)
            .filter(WorkflowCopilotChatMessageModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
            .order_by(WorkflowCopilotChatMessageModel.workflow_copilot_chat_message_id.asc())
        )
        rows = (await session.scalars(stmt)).all()
        return [convert_to_workflow_copilot_chat_message(row, self.debug_enabled) for row in rows]
async def get_workflow_copilot_chat_by_id(
    self,
    organization_id: str,
    workflow_copilot_chat_id: str,
) -> WorkflowCopilotChat | None:
    """Fetch a chat by its primary key, scoped to the organization.

    Returns None when no matching chat exists for this organization.
    """
    async with self.Session() as session:
        # workflow_copilot_chat_id is the table's primary key, so the filter
        # matches at most one row; the original order_by/limit(1) was redundant.
        query = (
            select(WorkflowCopilotChatModel)
            .filter(WorkflowCopilotChatModel.organization_id == organization_id)
            .filter(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
        )
        chat = (await session.scalars(query)).first()
        if not chat:
            return None
        return WorkflowCopilotChat.model_validate(chat)
async def get_latest_workflow_copilot_chat(
    self,
    organization_id: str,
    workflow_permanent_id: str,
) -> WorkflowCopilotChat | None:
    """Return the most recently created chat for a workflow, or None."""
    async with self.Session() as session:
        stmt = (
            select(WorkflowCopilotChatModel)
            .filter(
                WorkflowCopilotChatModel.organization_id == organization_id,
                WorkflowCopilotChatModel.workflow_permanent_id == workflow_permanent_id,
            )
            .order_by(WorkflowCopilotChatModel.created_at.desc())
            .limit(1)
        )
        row = (await session.scalars(stmt)).first()
        return WorkflowCopilotChat.model_validate(row) if row else None
async def get_task_generation_by_prompt_hash(
self,
user_prompt_hash: str,

View File

@@ -69,6 +69,8 @@ WORKFLOW_RUN_PREFIX = "wr"
WORKFLOW_SCRIPT_PREFIX = "ws"
WORKFLOW_TEMPLATE_PREFIX = "wt"
ORGANIZATION_BILLING_PREFIX = "ob"
WORKFLOW_COPILOT_CHAT_PREFIX = "wcc"
WORKFLOW_COPILOT_CHAT_MESSAGE_PREFIX = "wccm"
def generate_workflow_id() -> str:
@@ -266,6 +268,16 @@ def generate_billing_id() -> str:
return f"{ORGANIZATION_BILLING_PREFIX}_{int_id}"
def generate_workflow_copilot_chat_id() -> str:
    """Return a fresh chat id of the form "wcc_<int>"."""
    return f"{WORKFLOW_COPILOT_CHAT_PREFIX}_{generate_id()}"
def generate_workflow_copilot_chat_message_id() -> str:
    """Return a fresh chat-message id of the form "wccm_<int>"."""
    return f"{WORKFLOW_COPILOT_CHAT_MESSAGE_PREFIX}_{generate_id()}"
############# Helper functions below ##############
def generate_id() -> int:
"""

View File

@@ -51,6 +51,8 @@ from skyvern.forge.sdk.db.id import (
generate_task_v2_id,
generate_thought_id,
generate_totp_code_id,
generate_workflow_copilot_chat_id,
generate_workflow_copilot_chat_message_id,
generate_workflow_id,
generate_workflow_parameter_id,
generate_workflow_permanent_id,
@@ -1081,3 +1083,40 @@ class ScriptBlockModel(Base):
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class WorkflowCopilotChatModel(Base):
    """ORM row for one workflow copilot chat session (one per conversation)."""

    __tablename__ = "workflow_copilot_chats"

    # Primary key, generated as "wcc_<int>".
    workflow_copilot_chat_id = Column(String, primary_key=True, default=generate_workflow_copilot_chat_id)
    organization_id = Column(String, nullable=False)
    # Indexed: chats are looked up by the workflow they belong to.
    workflow_permanent_id = Column(String, nullable=False, index=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # Auto-touched on every UPDATE via onupdate.
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
        nullable=False,
    )
class WorkflowCopilotChatMessageModel(Base):
    """ORM row for one message within a workflow copilot chat."""

    __tablename__ = "workflow_copilot_chat_messages"

    # Primary key, generated as "wccm_<int>".
    workflow_copilot_chat_message_id = Column(
        String, primary_key=True, default=generate_workflow_copilot_chat_message_id
    )
    # Indexed: messages are fetched per chat.
    workflow_copilot_chat_id = Column(String, nullable=False, index=True)
    organization_id = Column(String, nullable=False)
    # Stored as the string value of WorkflowCopilotChatSender ("user"/"ai").
    sender = Column(String, nullable=False)
    content = Column(UnicodeText, nullable=False)
    # Optional LLM scratch context carried forward between chat turns.
    global_llm_context = Column(UnicodeText, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # Auto-touched on every UPDATE via onupdate.
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
        nullable=False,
    )

View File

@@ -21,6 +21,7 @@ from skyvern.forge.sdk.db.models import (
StepModel,
TaskModel,
TaskV2Model,
WorkflowCopilotChatMessageModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunBlockModel,
@@ -39,6 +40,7 @@ from skyvern.forge.sdk.schemas.organizations import (
)
from skyvern.forge.sdk.schemas.task_v2 import TaskV2
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.schemas.workflow_copilot import WorkflowCopilotChatMessage as WorkflowCopilotChatMessageSchema
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
@@ -217,6 +219,17 @@ def convert_to_task_v2(task_v2_model: TaskV2Model, debug_enabled: bool = False)
return TaskV2.model_validate(task_v2_data)
def convert_to_workflow_copilot_chat_message(
    message_model: WorkflowCopilotChatMessageModel, debug_enabled: bool = False
) -> WorkflowCopilotChatMessageSchema:
    """Convert an ORM chat-message row into its pydantic schema.

    When debug_enabled is set, logs the message id before conversion.
    """
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowCopilotChatMessage to WorkflowCopilotChatMessageSchema",
            workflow_copilot_chat_message_id=message_model.workflow_copilot_chat_message_id,
        )
    return WorkflowCopilotChatMessageSchema.model_validate(message_model)
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
if debug_enabled:
LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)

View File

@@ -9,6 +9,7 @@ from skyvern.forge.sdk.routes import run_blocks # noqa: F401
from skyvern.forge.sdk.routes import scripts # noqa: F401
from skyvern.forge.sdk.routes import sdk # noqa: F401
from skyvern.forge.sdk.routes import webhooks # noqa: F401
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401

View File

@@ -0,0 +1,348 @@
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import structlog
import yaml
from fastapi import Depends, HTTPException, status
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.routes.routers import base_router
from skyvern.forge.sdk.routes.run_blocks import DEFAULT_LOGIN_PROMPT
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.workflow_copilot import (
WorkflowCopilotChatHistoryMessage,
WorkflowCopilotChatHistoryResponse,
WorkflowCopilotChatMessage,
WorkflowCopilotChatRequest,
WorkflowCopilotChatResponse,
WorkflowCopilotChatSender,
)
from skyvern.forge.sdk.services import org_auth_service
from skyvern.schemas.workflows import LoginBlockYAML, WorkflowCreateYAMLRequest
WORKFLOW_KNOWLEDGE_BASE_PATH = Path("skyvern/forge/prompts/skyvern/workflow_knowledge_base.txt")
CHAT_HISTORY_CONTEXT_MESSAGES = 10
LOG = structlog.get_logger()
@dataclass(frozen=True)
class RunInfo:
    """Debug snapshot of the first block of a workflow run, fed into the LLM prompt."""

    block_label: str | None  # block label, if the block has one
    block_type: str  # block type enum .name
    block_status: str | None  # run status of the block
    failure_reason: str | None  # populated when the block failed
    html: str | None  # visible-elements-tree HTML artifact content, if available
async def _get_debug_artifact(organization_id: str, workflow_run_id: str) -> Artifact | None:
    """Return the first VISIBLE_ELEMENTS_TREE artifact of the run, or None.

    Bug fix: the original returned ``artifacts[0]`` whenever a list came back,
    which raises IndexError on an empty list; guard on non-emptiness too.
    """
    artifacts = await app.DATABASE.get_artifacts_for_run(
        run_id=workflow_run_id, organization_id=organization_id, artifact_types=[ArtifactType.VISIBLE_ELEMENTS_TREE]
    )
    # An empty result must yield None, not IndexError.
    if isinstance(artifacts, list) and artifacts:
        return artifacts[0]
    return None
async def _get_debug_run_info(organization_id: str, workflow_run_id: str | None) -> RunInfo | None:
    """Build a RunInfo for the first block of the run, or None when unavailable."""
    if not workflow_run_id:
        return None
    run_blocks = await app.DATABASE.get_workflow_run_blocks(
        workflow_run_id=workflow_run_id, organization_id=organization_id
    )
    if not run_blocks:
        return None
    first_block = run_blocks[0]
    # Attach the visible-elements-tree HTML when the artifact exists and has content.
    html: str | None = None
    artifact = await _get_debug_artifact(organization_id, workflow_run_id)
    if artifact:
        raw_bytes = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
        if raw_bytes:
            html = raw_bytes.decode("utf-8")
    return RunInfo(
        block_label=first_block.label,
        block_type=first_block.block_type.name,
        block_status=first_block.status,
        failure_reason=first_block.failure_reason,
        html=html,
    )
async def copilot_call_llm(
    organization_id: str,
    chat_request: WorkflowCopilotChatRequest,
    chat_history: list[WorkflowCopilotChatHistoryMessage],
    global_llm_context: str | None,
    debug_run_info_text: str,
) -> tuple[str, str | None, str | None]:
    """Call the copilot LLM and interpret its structured action response.

    Returns (user_response, updated_workflow_yaml, updated_global_llm_context);
    updated_workflow_yaml is only non-None for REPLACE_WORKFLOW actions.

    Raises HTTPException(500) when the LLM response is not a dict or when the
    generated workflow YAML is invalid (via _process_workflow_yaml).
    """
    # Single source of truth for the canned reply (was duplicated in the original).
    fallback_response = "I received your request but I'm not sure how to help. Could you rephrase?"
    current_datetime = datetime.now(timezone.utc).isoformat()
    chat_history_text = ""
    if chat_history:
        chat_history_text = "\n".join(f"{msg.sender}: {msg.content}" for msg in chat_history)
    workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
    llm_prompt = prompt_engine.load_prompt(
        template="workflow-copilot",
        workflow_knowledge_base=workflow_knowledge_base,
        workflow_yaml=chat_request.workflow_yaml or "",
        user_message=chat_request.message,
        chat_history=chat_history_text,
        global_llm_context=global_llm_context or "",
        current_datetime=current_datetime,
        debug_run_info=debug_run_info_text,
    )
    LOG.info(
        "Calling LLM for workflow copilot",
        prompt_length=len(llm_prompt),
    )
    llm_response = await app.LLM_API_HANDLER(
        prompt=llm_prompt,
        prompt_name="workflow-copilot",
        organization_id=organization_id,
    )
    # Some handlers wrap the payload in an "output" envelope; unwrap it.
    if isinstance(llm_response, dict) and "output" in llm_response:
        action_data = llm_response["output"]
    else:
        action_data = llm_response
    if not isinstance(action_data, dict):
        LOG.error(
            "LLM response is not valid JSON",
            organization_id=organization_id,
            response_type=type(action_data).__name__,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Invalid response from LLM",
        )
    action_type = action_data.get("type")
    user_response_value = action_data.get("user_response")
    user_response = fallback_response if user_response_value is None else str(user_response_value)
    LOG.info(
        "LLM response received",
        organization_id=organization_id,
        action_type=action_type,
    )
    global_llm_context = action_data.get("global_llm_context")
    if global_llm_context is not None:
        global_llm_context = str(global_llm_context)
    if action_type == "REPLACE_WORKFLOW":
        workflow_yaml = await _process_workflow_yaml(action_data)
        return user_response, workflow_yaml, global_llm_context
    if action_type in ("REPLY", "ASK_QUESTION"):
        # Both are plain-text turns; merged from two identical branches.
        return user_response, None, global_llm_context
    LOG.error(
        "Unknown action type from LLM",
        organization_id=organization_id,
        action_type=action_type,
    )
    # Unknown action: discard any YAML/context update and return the canned reply.
    return fallback_response, None, None
async def _process_workflow_yaml(action_data: dict[str, Any]) -> str | None:
    """Validate and normalize LLM-generated workflow YAML.

    Parses the YAML, patches trivial LLM omissions (missing block titles,
    missing login navigation goals), validates it against the Skyvern
    workflow schema, and returns the re-serialized YAML string.

    Raises HTTPException(500) when the YAML cannot be parsed or does not
    conform to the workflow schema.
    """
    workflow_yaml = action_data.get("workflow_yaml", "")
    try:
        parsed_yaml = yaml.safe_load(workflow_yaml)
    except yaml.YAMLError as e:
        LOG.error(
            "Invalid YAML from LLM",
            error=str(e),
            yaml=f"\n{str(e)}\n{workflow_yaml}",
        )
        # Chain the cause so the YAML error is preserved in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"LLM generated invalid YAML: {str(e)}",
        ) from e
    try:
        # Fixing trivial common LLM mistakes
        workflow_definition = parsed_yaml.get("workflow_definition", None)
        if workflow_definition:
            blocks = workflow_definition.get("blocks", [])
            for block in blocks:
                # Ensure every block carries a title key before validation.
                block["title"] = block.get("title", "")
        workflow = WorkflowCreateYAMLRequest.model_validate(parsed_yaml)
        # Post-processing
        for block in workflow.workflow_definition.blocks:
            if isinstance(block, LoginBlockYAML) and not block.navigation_goal:
                block.navigation_goal = DEFAULT_LOGIN_PROMPT
        workflow_yaml = yaml.safe_dump(workflow.model_dump(mode="json"), sort_keys=False)
    except Exception as e:
        LOG.error(
            "YAML from LLM does not conform to Skyvern workflow schema",
            error=str(e),
            yaml=f"\n{str(e)}\n{workflow_yaml}",
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"LLM generated YAML that doesn't match workflow schema: {str(e)}",
        ) from e
    return workflow_yaml
@base_router.post("/workflow/copilot/chat-post", include_in_schema=False)
async def workflow_copilot_chat_post(
    chat_request: WorkflowCopilotChatRequest,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowCopilotChatResponse:
    """Handle one copilot chat turn.

    Persists the user message, calls the LLM with chat history and debug run
    context, persists the assistant reply, and returns it (optionally with an
    updated workflow YAML).

    Raises 404 when the supplied chat id does not exist, 400 when it belongs
    to a different workflow, and 500 on LLM/internal failures.
    """
    LOG.info(
        "Workflow copilot chat request",
        workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
        workflow_run_id=chat_request.workflow_run_id,
        message=chat_request.message,
        workflow_yaml_length=len(chat_request.workflow_yaml),
        organization_id=organization.organization_id,
    )
    request_started_at = datetime.now(timezone.utc)
    # Resume an existing chat when an id is supplied; otherwise start a new one.
    if chat_request.workflow_copilot_chat_id:
        chat = await app.DATABASE.get_workflow_copilot_chat_by_id(
            organization_id=organization.organization_id,
            workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
        )
        if not chat:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
        if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
            # The chat exists but belongs to a different workflow.
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
    else:
        chat = await app.DATABASE.create_workflow_copilot_chat(
            organization_id=organization.organization_id,
            workflow_permanent_id=chat_request.workflow_permanent_id,
        )
    # Fetched BEFORE the new user message is stored, so the history sent to the
    # LLM excludes the current message (it is passed separately in chat_request).
    chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(
        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
    )
    # The most recent message with a non-null global_llm_context wins.
    global_llm_context = None
    for message in reversed(chat_messages):
        if message.global_llm_context is not None:
            global_llm_context = message.global_llm_context
            break
    debug_run_info = await _get_debug_run_info(organization.organization_id, chat_request.workflow_run_id)
    # Format debug run info for prompt
    debug_run_info_text = ""
    if debug_run_info:
        debug_run_info_text = f"Block Label: {debug_run_info.block_label}"
        debug_run_info_text += f" Block Type: {debug_run_info.block_type}"
        debug_run_info_text += f" Status: {debug_run_info.block_status}"
        if debug_run_info.failure_reason:
            debug_run_info_text += f"\nFailure Reason: {debug_run_info.failure_reason}"
        if debug_run_info.html:
            debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"
    # The user message is persisted before the LLM call; it remains stored even
    # if the call below fails.
    await app.DATABASE.create_workflow_copilot_chat_message(
        organization_id=chat.organization_id,
        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
        sender=WorkflowCopilotChatSender.USER,
        content=chat_request.message,
    )
    try:
        # Only the last CHAT_HISTORY_CONTEXT_MESSAGES messages are sent as context.
        user_response, updated_workflow_yaml, updated_global_llm_context = await copilot_call_llm(
            organization.organization_id,
            chat_request,
            convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
            global_llm_context,
            debug_run_info_text,
        )
    except HTTPException:
        # Already shaped for the client (e.g. invalid-YAML 500s); re-raise as-is.
        raise
    except LLMProviderError as e:
        LOG.error(
            "LLM provider error",
            organization_id=organization.organization_id,
            error=str(e),
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to process your request. Please try again.",
        )
    except Exception as e:
        LOG.error(
            "Unexpected error in workflow copilot",
            organization_id=organization.organization_id,
            error=str(e),
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An error occurred: {str(e)}",
        )
    # Persist the assistant reply together with any updated global context.
    assistant_message = await app.DATABASE.create_workflow_copilot_chat_message(
        organization_id=chat.organization_id,
        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
        sender=WorkflowCopilotChatSender.AI,
        content=user_response,
        global_llm_context=updated_global_llm_context,
    )
    return WorkflowCopilotChatResponse(
        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
        message=user_response,
        updated_workflow_yaml=updated_workflow_yaml,
        request_time=request_started_at,
        response_time=assistant_message.created_at,
    )
@base_router.get("/workflow/copilot/chat-history", include_in_schema=False)
async def workflow_copilot_chat_history(
    workflow_permanent_id: str,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowCopilotChatHistoryResponse:
    """Return the message history of the latest copilot chat for a workflow."""
    latest_chat = await app.DATABASE.get_latest_workflow_copilot_chat(
        organization_id=organization.organization_id,
        workflow_permanent_id=workflow_permanent_id,
    )
    if latest_chat is None:
        # No chat exists yet for this workflow: empty history, no chat id.
        return WorkflowCopilotChatHistoryResponse(workflow_copilot_chat_id=None, chat_history=[])
    messages = await app.DATABASE.get_workflow_copilot_chat_messages(
        workflow_copilot_chat_id=latest_chat.workflow_copilot_chat_id,
    )
    return WorkflowCopilotChatHistoryResponse(
        workflow_copilot_chat_id=latest_chat.workflow_copilot_chat_id,
        chat_history=convert_to_history_messages(messages),
    )
def convert_to_history_messages(
    messages: list[WorkflowCopilotChatMessage],
) -> list[WorkflowCopilotChatHistoryMessage]:
    """Project full chat messages onto the lighter history schema, preserving order."""
    history: list[WorkflowCopilotChatHistoryMessage] = []
    for message in messages:
        history.append(
            WorkflowCopilotChatHistoryMessage(
                sender=message.sender,
                content=message.content,
                created_at=message.created_at,
            )
        )
    return history

View File

@@ -0,0 +1,58 @@
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field
class WorkflowCopilotChat(BaseModel):
    """A copilot chat session tied to one workflow (validated from the ORM row)."""

    # from_attributes lets model_validate accept SQLAlchemy model instances.
    model_config = ConfigDict(from_attributes=True)

    workflow_copilot_chat_id: str = Field(..., description="ID for the workflow copilot chat")
    organization_id: str = Field(..., description="Organization ID for the chat")
    workflow_permanent_id: str = Field(..., description="Workflow permanent ID for the chat")
    created_at: datetime = Field(..., description="When the chat was created")
    modified_at: datetime = Field(..., description="When the chat was last modified")
class WorkflowCopilotChatSender(StrEnum):
    """Author of a chat message; persisted as its string value ("user"/"ai")."""

    USER = "user"
    AI = "ai"
class WorkflowCopilotChatMessage(BaseModel):
    """A single stored chat message (validated from the ORM row).

    NOTE(review): organization_id exists on the ORM model but is deliberately
    absent here — confirm it is not needed by any consumer of this schema.
    """

    # from_attributes lets model_validate accept SQLAlchemy model instances.
    model_config = ConfigDict(from_attributes=True)

    workflow_copilot_chat_message_id: str = Field(..., description="ID for the workflow copilot chat message")
    workflow_copilot_chat_id: str = Field(..., description="ID of the parent workflow copilot chat")
    sender: WorkflowCopilotChatSender = Field(..., description="Message sender")
    content: str = Field(..., description="Message content")
    global_llm_context: str | None = Field(None, description="Optional global LLM context for the message")
    created_at: datetime = Field(..., description="When the message was created")
    modified_at: datetime = Field(..., description="When the message was last modified")
class WorkflowCopilotChatRequest(BaseModel):
    """Request body for one copilot chat turn; omitting the chat id starts a new chat."""

    workflow_permanent_id: str = Field(..., description="Workflow permanent ID for the chat")
    workflow_copilot_chat_id: str | None = Field(None, description="The chat ID to send the message to")
    workflow_run_id: str | None = Field(None, description="The workflow run ID to use for the context")
    message: str = Field(..., description="The message that user sends")
    workflow_yaml: str = Field(..., description="Current workflow YAML including unsaved changes")
class WorkflowCopilotChatResponse(BaseModel):
    """Response for one chat turn; updated_workflow_yaml is set only when the LLM rewrote the workflow."""

    workflow_copilot_chat_id: str = Field(..., description="The chat ID")
    message: str = Field(..., description="The message sent to the user")
    updated_workflow_yaml: str | None = Field(None, description="The updated workflow yaml")
    request_time: datetime = Field(..., description="When the request was received")
    response_time: datetime = Field(..., description="When the assistant message was created")
class WorkflowCopilotChatHistoryMessage(BaseModel):
    """A trimmed-down chat message for history display and LLM prompt context."""

    sender: WorkflowCopilotChatSender = Field(..., description="Message sender")
    content: str = Field(..., description="Message content")
    created_at: datetime = Field(..., description="When the message was created")
class WorkflowCopilotChatHistoryResponse(BaseModel):
    """History of the latest chat for a workflow; chat id is None when no chat exists."""

    workflow_copilot_chat_id: str | None = Field(None, description="Latest chat ID for the workflow")
    chat_history: list[WorkflowCopilotChatHistoryMessage] = Field(default_factory=list, description="Chat messages")