Move the code over from private repository (#3)
This commit is contained in:
0
skyvern/forge/__init__.py
Normal file
0
skyvern/forge/__init__.py
Normal file
18
skyvern/forge/__main__.py
Normal file
18
skyvern/forge/__main__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import structlog
import uvicorn
from dotenv import load_dotenv

import skyvern.forge.sdk.forge_log as forge_log
from skyvern.forge.sdk.settings_manager import SettingsManager

LOG = structlog.stdlib.get_logger()


if __name__ == "__main__":
    # Load .env FIRST so environment overrides are visible to the settings
    # read below. (Previously load_dotenv() ran after PORT/ENV were read,
    # so .env values could not affect them.)
    load_dotenv()
    forge_log.setup_logger()
    port = SettingsManager.get_settings().PORT
    LOG.info("Agent server starting.", host="0.0.0.0", port=port)

    # Auto-reload only in local development.
    reload = SettingsManager.get_settings().ENV == "local"
    uvicorn.run("skyvern.forge.app:app", host="0.0.0.0", port=port, log_level="info", reload=reload)
985
skyvern/forge/agent.py
Normal file
985
skyvern/forge/agent.py
Normal file
@@ -0,0 +1,985 @@
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Any, Tuple
|
||||
|
||||
import requests
|
||||
import structlog
|
||||
from playwright._impl._errors import TargetClosedError
|
||||
|
||||
from skyvern.exceptions import (
|
||||
BrowserStateMissingPage,
|
||||
FailedToSendWebhook,
|
||||
InvalidWorkflowTaskURLState,
|
||||
MissingBrowserStatePage,
|
||||
TaskNotFound,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.agent import Agent
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.models import Organization, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.context_manager import ContextManager
|
||||
from skyvern.forge.sdk.workflow.models.block import TaskBlock
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
|
||||
from skyvern.webeye.actions.actions import Action, ActionType, CompleteAction, parse_actions
|
||||
from skyvern.webeye.actions.handler import ActionHandler
|
||||
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
|
||||
from skyvern.webeye.actions.responses import ActionResult
|
||||
from skyvern.webeye.browser_factory import BrowserState
|
||||
from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ForgeAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
LOG.info(
|
||||
"Initializing ForgeAgent",
|
||||
env=SettingsManager.get_settings().ENV,
|
||||
execute_all_steps=SettingsManager.get_settings().EXECUTE_ALL_STEPS,
|
||||
browser_type=SettingsManager.get_settings().BROWSER_TYPE,
|
||||
max_scraping_retries=SettingsManager.get_settings().MAX_SCRAPING_RETRIES,
|
||||
video_path=SettingsManager.get_settings().VIDEO_PATH,
|
||||
browser_action_timeout_ms=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS,
|
||||
max_steps_per_run=SettingsManager.get_settings().MAX_STEPS_PER_RUN,
|
||||
long_running_task_warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO,
|
||||
debug_mode=SettingsManager.get_settings().DEBUG_MODE,
|
||||
)
|
||||
if SettingsManager.get_settings().ADDITIONAL_MODULES:
|
||||
for module in SettingsManager.get_settings().ADDITIONAL_MODULES:
|
||||
LOG.info("Loading additional module", module=module)
|
||||
__import__(module)
|
||||
LOG.info("Additional modules loaded", modules=SettingsManager.get_settings().ADDITIONAL_MODULES)
|
||||
|
||||
async def validate_step_execution(
|
||||
self,
|
||||
task: Task,
|
||||
step: Step,
|
||||
) -> None:
|
||||
"""
|
||||
Checks if the step can be executed.
|
||||
:return: A tuple of whether the step can be executed and a list of reasons why it can't be executed.
|
||||
"""
|
||||
reasons = []
|
||||
# can't execute if task status is not running
|
||||
has_valid_task_status = task.status == TaskStatus.running
|
||||
if not has_valid_task_status:
|
||||
reasons.append(f"invalid_task_status:{task.status}")
|
||||
# can't execute if the step is already running or completed
|
||||
has_valid_step_status = step.status in [StepStatus.created, StepStatus.failed]
|
||||
if not has_valid_step_status:
|
||||
reasons.append(f"invalid_step_status:{step.status}")
|
||||
# can't execute if the task has another step that is running
|
||||
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
|
||||
has_no_running_steps = not any(step.status == StepStatus.running for step in steps)
|
||||
if not has_no_running_steps:
|
||||
reasons.append(f"another_step_is_running_for_task:{task.task_id}")
|
||||
|
||||
can_execute = has_valid_task_status and has_valid_step_status and has_no_running_steps
|
||||
if not can_execute:
|
||||
raise Exception(f"Cannot execute step. Reasons: {reasons}, Step: {step}")
|
||||
|
||||
async def create_task_and_step_from_block(
|
||||
self,
|
||||
task_block: TaskBlock,
|
||||
workflow: Workflow,
|
||||
workflow_run: WorkflowRun,
|
||||
context_manager: ContextManager,
|
||||
task_order: int,
|
||||
task_retry: int,
|
||||
) -> tuple[Task, Step]:
|
||||
task_block_parameters = task_block.parameters
|
||||
navigation_payload = {}
|
||||
for parameter in task_block_parameters:
|
||||
navigation_payload[parameter.key] = context_manager.get_value(parameter.key)
|
||||
|
||||
task_url = task_block.url
|
||||
if task_url is None:
|
||||
browser_state = await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)
|
||||
if not browser_state.page:
|
||||
LOG.error("BrowserState has no page", workflow_run_id=workflow_run.workflow_run_id)
|
||||
raise MissingBrowserStatePage(workflow_run_id=workflow_run.workflow_run_id)
|
||||
|
||||
if browser_state.page.url == "about:blank":
|
||||
raise InvalidWorkflowTaskURLState(workflow_run.workflow_run_id)
|
||||
|
||||
task_url = browser_state.page.url
|
||||
|
||||
task = await app.DATABASE.create_task(
|
||||
url=task_url,
|
||||
webhook_callback_url=None,
|
||||
navigation_goal=task_block.navigation_goal,
|
||||
data_extraction_goal=task_block.data_extraction_goal,
|
||||
navigation_payload=navigation_payload,
|
||||
organization_id=workflow.organization_id,
|
||||
proxy_location=workflow_run.proxy_location,
|
||||
extracted_information_schema=task_block.data_schema,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
order=task_order,
|
||||
retry=task_retry,
|
||||
)
|
||||
LOG.info(
|
||||
"Created new task for workflow run",
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
task_id=task.task_id,
|
||||
url=task.url,
|
||||
nav_goal=task.navigation_goal,
|
||||
data_goal=task.data_extraction_goal,
|
||||
proxy_location=task.proxy_location,
|
||||
task_order=task_order,
|
||||
task_retry=task_retry,
|
||||
)
|
||||
# Update task status to running
|
||||
task = await app.DATABASE.update_task(
|
||||
task_id=task.task_id, organization_id=task.organization_id, status=TaskStatus.running
|
||||
)
|
||||
step = await app.DATABASE.create_step(
|
||||
task.task_id,
|
||||
order=0,
|
||||
retry_index=0,
|
||||
organization_id=task.organization_id,
|
||||
)
|
||||
LOG.info(
|
||||
"Created new step for workflow run",
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
step_id=step.step_id,
|
||||
task_id=task.task_id,
|
||||
order=step.order,
|
||||
retry_index=step.retry_index,
|
||||
)
|
||||
return task, step
|
||||
|
||||
async def create_task(self, task_request: TaskRequest, organization_id: str | None = None) -> Task:
|
||||
task = await app.DATABASE.create_task(
|
||||
url=task_request.url,
|
||||
webhook_callback_url=task_request.webhook_callback_url,
|
||||
navigation_goal=task_request.navigation_goal,
|
||||
data_extraction_goal=task_request.data_extraction_goal,
|
||||
navigation_payload=task_request.navigation_payload,
|
||||
organization_id=organization_id,
|
||||
proxy_location=task_request.proxy_location,
|
||||
extracted_information_schema=task_request.extracted_information_schema,
|
||||
)
|
||||
LOG.info(
|
||||
"Created new task",
|
||||
task_id=task.task_id,
|
||||
url=task.url,
|
||||
nav_goal=task.navigation_goal,
|
||||
data_goal=task.data_extraction_goal,
|
||||
proxy_location=task.proxy_location,
|
||||
)
|
||||
return task
|
||||
|
||||
async def execute_step(
|
||||
self,
|
||||
organization: Organization,
|
||||
task: Task,
|
||||
step: Step,
|
||||
api_key: str | None = None,
|
||||
workflow_run: WorkflowRun | None = None,
|
||||
close_browser_on_completion: bool = True,
|
||||
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
|
||||
next_step: Step | None = None
|
||||
detailed_output: DetailedAgentStepOutput | None = None
|
||||
try:
|
||||
# Check some conditions before executing the step, throw an exception if the step can't be executed
|
||||
await self.validate_step_execution(task, step)
|
||||
step, browser_state, detailed_output = await self._initialize_execution_state(task, step, workflow_run)
|
||||
step, detailed_output = await self.agent_step(task, step, browser_state, organization=organization)
|
||||
retry = False
|
||||
|
||||
# If the step failed, mark the step as failed and retry
|
||||
if step.status == StepStatus.failed:
|
||||
maybe_next_step = await self.handle_failed_step(task, step)
|
||||
# If there is no next step, it means that the task has failed
|
||||
if maybe_next_step:
|
||||
next_step = maybe_next_step
|
||||
retry = True
|
||||
else:
|
||||
await self.send_task_response(
|
||||
task=task,
|
||||
last_step=step,
|
||||
api_key=api_key,
|
||||
close_browser_on_completion=close_browser_on_completion,
|
||||
)
|
||||
return step, detailed_output, None
|
||||
elif step.status == StepStatus.completed:
|
||||
# TODO (kerem): keep the task object uptodate at all times so that send_task_response can just use it
|
||||
is_task_completed, maybe_last_step, maybe_next_step = await self.handle_completed_step(
|
||||
organization, task, step
|
||||
)
|
||||
if is_task_completed is not None and maybe_last_step:
|
||||
last_step = maybe_last_step
|
||||
await self.send_task_response(
|
||||
task=task,
|
||||
last_step=last_step,
|
||||
api_key=api_key,
|
||||
close_browser_on_completion=close_browser_on_completion,
|
||||
)
|
||||
return last_step, detailed_output, None
|
||||
elif maybe_next_step:
|
||||
next_step = maybe_next_step
|
||||
retry = False
|
||||
else:
|
||||
LOG.error(
|
||||
"Step completed but task is not completed and next step is not created.",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
is_task_completed=is_task_completed,
|
||||
maybe_last_step=maybe_last_step,
|
||||
maybe_next_step=maybe_next_step,
|
||||
)
|
||||
else:
|
||||
LOG.error(
|
||||
"Unexpected step status after agent_step",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
step_status=step.status,
|
||||
)
|
||||
|
||||
if retry and next_step:
|
||||
return await self.execute_step(
|
||||
organization,
|
||||
task,
|
||||
next_step,
|
||||
api_key=api_key,
|
||||
)
|
||||
elif SettingsManager.get_settings().execute_all_steps() and next_step:
|
||||
return await self.execute_step(
|
||||
organization,
|
||||
task,
|
||||
next_step,
|
||||
api_key=api_key,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
"Step executed but continuous execution is disabled.",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
is_cloud_env=SettingsManager.get_settings().is_cloud_environment(),
|
||||
execute_all_steps=SettingsManager.get_settings().execute_all_steps(),
|
||||
next_step_id=next_step.step_id if next_step else None,
|
||||
)
|
||||
|
||||
return step, detailed_output, next_step
|
||||
# TODO (kerem): Let's add other exceptions that we know about here as custom exceptions as well
|
||||
except FailedToSendWebhook:
|
||||
LOG.exception(
|
||||
"Failed to send webhook",
|
||||
exc_info=True,
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
task=task,
|
||||
step=step,
|
||||
)
|
||||
return step, detailed_output, next_step
|
||||
|
||||
    async def agent_step(
        self,
        task: Task,
        step: Step,
        browser_state: BrowserState,
        organization: Organization | None = None,
    ) -> tuple[Step, DetailedAgentStepOutput]:
        """
        Run one step: scrape the page, ask the LLM for actions, execute them.

        Returns the step (updated to completed or failed) plus a
        DetailedAgentStepOutput capturing every intermediate artifact of the
        step. Any unexpected exception marks the step failed rather than
        propagating.
        """
        # Start with an all-None output so partial progress survives an exception.
        detailed_agent_step_output = DetailedAgentStepOutput(
            scraped_page=None,
            extract_action_prompt=None,
            llm_response=None,
            actions=None,
            action_results=None,
            actions_and_results=None,
        )
        try:
            LOG.info(
                "Starting agent step",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
            )
            step = await self.update_step(step=step, status=StepStatus.running)
            # Scrape the page and build the extract-action prompt for the LLM.
            scraped_page, extract_action_prompt = await self._build_and_record_step_prompt(
                task,
                step,
                browser_state,
            )
            detailed_agent_step_output.scraped_page = scraped_page
            detailed_agent_step_output.extract_action_prompt = extract_action_prompt
            json_response = None
            actions: list[Action]
            if task.navigation_goal:
                # Ask the LLM which actions to take on this page.
                json_response = await app.OPENAI_CLIENT.chat_completion(
                    step=step,
                    prompt=extract_action_prompt,
                    screenshots=scraped_page.screenshots,
                )
                detailed_agent_step_output.llm_response = json_response

                actions = parse_actions(task, json_response["actions"])
            else:
                # No navigation goal: complete immediately (extraction-only task).
                actions = [
                    CompleteAction(
                        reasoning="Task has no navigation goal.", data_extraction_goal=task.data_extraction_goal
                    )
                ]
            detailed_agent_step_output.actions = actions
            if len(actions) == 0:
                LOG.info(
                    "No actions to execute, marking step as failed",
                    task_id=task.task_id,
                    step_id=step.step_id,
                    step_order=step.order,
                    step_retry=step.retry_index,
                )
                step = await self.update_step(
                    step=step, status=StepStatus.failed, output=detailed_agent_step_output.to_agent_step_output()
                )
                # Rebuild the output with empty (not None) result lists for this terminal case.
                detailed_agent_step_output = DetailedAgentStepOutput(
                    scraped_page=scraped_page,
                    extract_action_prompt=extract_action_prompt,
                    llm_response=json_response,
                    actions=actions,
                    action_results=[],
                    actions_and_results=[],
                )
                return step, detailed_agent_step_output

            # Execute the actions
            LOG.info(
                "Executing actions",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
                actions=actions,
            )
            action_results: list[ActionResult] = []
            detailed_agent_step_output.action_results = action_results
            # filter out wait action if there are other actions in the list
            # we do this because WAIT action is considered as a failure
            # which will block following actions if we don't remove it from the list
            # if the list only contains WAIT action, we will execute WAIT action(s)
            if len(actions) > 1:
                wait_actions_to_skip = [action for action in actions if action.action_type == ActionType.WAIT]
                wait_actions_len = len(wait_actions_to_skip)
                # if there are wait actions and there are other actions in the list, skip wait actions
                if wait_actions_len > 0 and wait_actions_len < len(actions):
                    actions = [action for action in actions if action.action_type != ActionType.WAIT]
                    LOG.info("Skipping wait actions", wait_actions_to_skip=wait_actions_to_skip, actions=actions)

            # initialize list of tuples and set actions as the first element of each tuple so that in the case
            # of an exception, we can still see all the actions
            detailed_agent_step_output.actions_and_results = [(action, []) for action in actions]

            for action_idx, action in enumerate(actions):
                results = await ActionHandler.handle_action(scraped_page, task, step, browser_state, action)
                detailed_agent_step_output.actions_and_results[action_idx] = (action, results)
                # wait random time between actions to avoid detection
                await asyncio.sleep(random.uniform(1.0, 2.0))
                await self.record_artifacts_after_action(task, step, browser_state)
                # Stamp each result with the step it came from for later history windows.
                for result in results:
                    result.step_retry_number = step.retry_index
                    result.step_order = step.order
                action_results.extend(results)
                # Check the last result for this action. If that succeeded, assume the entire action is successful
                if results and results[-1].success:
                    LOG.info(
                        "Action succeeded",
                        task_id=task.task_id,
                        step_id=step.step_id,
                        step_order=step.order,
                        step_retry=step.retry_index,
                        action_idx=action_idx,
                        action=action,
                        action_result=results,
                    )
                    # if the action triggered javascript calls
                    # this action should be the last action this round and do not take more actions.
                    # for now, we're being optimistic and assuming that
                    # js call doesn't have impact on the following actions
                    if results[-1].javascript_triggered:
                        LOG.info("Action triggered javascript, ", action=action)
                else:
                    LOG.warning(
                        "Action failed, marking step as failed",
                        task_id=task.task_id,
                        step_id=step.step_id,
                        step_order=step.order,
                        step_retry=step.retry_index,
                        action_idx=action_idx,
                        action=action,
                        action_result=results,
                        actions_and_results=detailed_agent_step_output.actions_and_results,
                    )
                    # if the action failed, don't execute the rest of the actions, mark the step as failed, and retry
                    failed_step = await self.update_step(
                        step=step, status=StepStatus.failed, output=detailed_agent_step_output.to_agent_step_output()
                    )
                    return failed_step, detailed_agent_step_output

            LOG.info(
                "Actions executed successfully, marking step as completed",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
                action_results=action_results,
            )
            # If no action errors return the agent state and output
            completed_step = await self.update_step(
                step=step, status=StepStatus.completed, output=detailed_agent_step_output.to_agent_step_output()
            )
            return completed_step, detailed_agent_step_output
        except Exception:
            # Catch-all: a crashed step becomes a failed step with whatever
            # partial output was accumulated before the exception.
            LOG.exception(
                "Unexpected exception in agent_step, marking step as failed",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
            )
            failed_step = await self.update_step(
                step=step, status=StepStatus.failed, output=detailed_agent_step_output.to_agent_step_output()
            )
            return failed_step, detailed_agent_step_output
async def record_artifacts_after_action(self, task: Task, step: Step, browser_state: BrowserState) -> None:
|
||||
if not browser_state.page:
|
||||
raise BrowserStateMissingPage()
|
||||
try:
|
||||
screenshot = await browser_state.page.screenshot(full_page=True)
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.SCREENSHOT_ACTION,
|
||||
data=screenshot,
|
||||
)
|
||||
except Exception:
|
||||
LOG.error(
|
||||
"Failed to record screenshot after action",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
html = await browser_state.page.content()
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.HTML_ACTION,
|
||||
data=html.encode(),
|
||||
)
|
||||
except Exception:
|
||||
LOG.error(
|
||||
"Failed to record html after action",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
video_data = await app.BROWSER_MANAGER.get_video_data(task_id=task.task_id, browser_state=browser_state)
|
||||
await app.ARTIFACT_MANAGER.update_artifact_data(
|
||||
artifact_id=browser_state.browser_artifacts.video_artifact_id,
|
||||
organization_id=task.organization_id,
|
||||
data=video_data,
|
||||
)
|
||||
except Exception:
|
||||
LOG.error(
|
||||
"Failed to record video after action",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _initialize_execution_state(
|
||||
self, task: Task, step: Step, workflow_run: WorkflowRun | None = None
|
||||
) -> tuple[Step, BrowserState, DetailedAgentStepOutput]:
|
||||
if workflow_run:
|
||||
browser_state = await app.BROWSER_MANAGER.get_or_create_for_workflow_run(
|
||||
workflow_run=workflow_run, url=task.url
|
||||
)
|
||||
else:
|
||||
browser_state = await app.BROWSER_MANAGER.get_or_create_for_task(task)
|
||||
# Initialize video artifact for the task here, afterwards it'll only get updated
|
||||
if browser_state and not browser_state.browser_artifacts.video_artifact_id:
|
||||
video_data = await app.BROWSER_MANAGER.get_video_data(task_id=task.task_id, browser_state=browser_state)
|
||||
video_artifact_id = await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.RECORDING,
|
||||
data=video_data,
|
||||
)
|
||||
app.BROWSER_MANAGER.set_video_artifact_for_task(task, video_artifact_id)
|
||||
|
||||
detailed_output = DetailedAgentStepOutput(
|
||||
scraped_page=None,
|
||||
extract_action_prompt=None,
|
||||
llm_response=None,
|
||||
actions=None,
|
||||
action_results=None,
|
||||
actions_and_results=None,
|
||||
)
|
||||
return step, browser_state, detailed_output
|
||||
|
||||
async def _build_and_record_step_prompt(
|
||||
self,
|
||||
task: Task,
|
||||
step: Step,
|
||||
browser_state: BrowserState,
|
||||
) -> tuple[ScrapedPage, str]:
|
||||
# Scrape the web page and get the screenshot and the elements
|
||||
scraped_page = await scrape_website(
|
||||
browser_state,
|
||||
task.url,
|
||||
)
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.HTML_SCRAPE,
|
||||
data=scraped_page.html.encode(),
|
||||
)
|
||||
LOG.info(
|
||||
"Scraped website",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
step_order=step.order,
|
||||
step_retry=step.retry_index,
|
||||
num_elements=len(scraped_page.elements),
|
||||
url=task.url,
|
||||
)
|
||||
# Get action results from the last app.SETTINGS.PROMPT_ACTION_HISTORY_WINDOW steps
|
||||
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
|
||||
window_steps = steps[-1 * SettingsManager.get_settings().PROMPT_ACTION_HISTORY_WINDOW :]
|
||||
action_results: list[ActionResult] = []
|
||||
for window_step in window_steps:
|
||||
if window_step.output and window_step.output.action_results:
|
||||
action_results.extend(window_step.output.action_results)
|
||||
action_results_str = json.dumps([action_result.model_dump() for action_result in action_results])
|
||||
# Generate the extract action prompt
|
||||
navigation_goal = task.navigation_goal
|
||||
extract_action_prompt = prompt_engine.load_prompt(
|
||||
"extract-action",
|
||||
navigation_goal=navigation_goal,
|
||||
navigation_payload_str=json.dumps(task.navigation_payload),
|
||||
url=task.url,
|
||||
elements=scraped_page.element_tree_trimmed, # scraped_page.element_tree,
|
||||
data_extraction_goal=task.data_extraction_goal,
|
||||
action_history=action_results_str,
|
||||
utc_datetime=datetime.utcnow(),
|
||||
)
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_XPATH_MAP,
|
||||
data=json.dumps(scraped_page.id_to_xpath_dict, indent=2).encode(),
|
||||
)
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE,
|
||||
data=json.dumps(scraped_page.element_tree, indent=2).encode(),
|
||||
)
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED,
|
||||
data=json.dumps(scraped_page.element_tree_trimmed, indent=2).encode(),
|
||||
)
|
||||
|
||||
return scraped_page, extract_action_prompt
|
||||
|
||||
async def get_extracted_information_for_task(self, task: Task) -> dict[str, Any] | list | str | None:
|
||||
"""
|
||||
Find the last successful ScrapeAction for the task and return the extracted information.
|
||||
"""
|
||||
steps = await app.DATABASE.get_task_steps(
|
||||
task_id=task.task_id,
|
||||
organization_id=task.organization_id,
|
||||
)
|
||||
for step in reversed(steps):
|
||||
if step.status != StepStatus.completed:
|
||||
continue
|
||||
if not step.output or not step.output.actions_and_results:
|
||||
continue
|
||||
for action, action_results in step.output.actions_and_results:
|
||||
if action.action_type != ActionType.COMPLETE:
|
||||
continue
|
||||
|
||||
for action_result in action_results:
|
||||
if action_result.success:
|
||||
LOG.info(
|
||||
"Extracted information for task",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
extracted_information=action_result.data,
|
||||
)
|
||||
return action_result.data
|
||||
|
||||
LOG.warning(
|
||||
"Failed to find extracted information for task",
|
||||
task_id=task.task_id,
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_failure_reason_for_task(self, task: Task) -> str | None:
|
||||
"""
|
||||
Find the TerminateAction for the task and return the reasoning.
|
||||
# TODO (kerem): Also return meaningful exceptions when we add them [WYV-311]
|
||||
"""
|
||||
steps = await app.DATABASE.get_task_steps(
|
||||
task_id=task.task_id,
|
||||
organization_id=task.organization_id,
|
||||
)
|
||||
for step in reversed(steps):
|
||||
if step.status != StepStatus.completed:
|
||||
continue
|
||||
if not step.output:
|
||||
continue
|
||||
|
||||
if step.output.actions_and_results:
|
||||
for action, action_results in step.output.actions_and_results:
|
||||
if action.action_type == ActionType.TERMINATE:
|
||||
return action.reasoning
|
||||
|
||||
LOG.error(
|
||||
"Failed to find failure reasoning for task",
|
||||
task_id=task.task_id,
|
||||
)
|
||||
return None
|
||||
|
||||
    async def send_task_response(
        self,
        task: Task,
        last_step: Step,
        api_key: str | None = None,
        close_browser_on_completion: bool = True,
    ) -> None:
        """
        send the task response to the webhook callback url

        Also captures a final screenshot, cleans up the browser and its
        artifacts, and waits for artifact uploads to finish. Returns early
        (without a webhook) when the task belongs to a workflow run, has no
        webhook_callback_url, or no api_key was provided. Raises
        FailedToSendWebhook if the HTTP POST itself fails; TaskNotFound if the
        task has vanished from the database.
        """
        # Take one last screenshot and create an artifact before closing the browser to see the final state
        browser_state: BrowserState = await app.BROWSER_MANAGER.get_or_create_for_task(task)
        page = await browser_state.get_or_create_page()
        try:
            screenshot = await page.screenshot(full_page=True)
            await app.ARTIFACT_MANAGER.create_artifact(
                step=last_step,
                artifact_type=ArtifactType.SCREENSHOT_FINAL,
                data=screenshot,
            )
        except TargetClosedError as e:
            # Best-effort: the page may already be gone; proceed without the final shot.
            LOG.warning(
                "Failed to take screenshot before sending task response, page is closed",
                task_id=task.task_id,
                step_id=last_step.step_id,
                error=e,
            )

        # Workflow-run tasks report through the workflow machinery, not per-task webhooks.
        if task.workflow_run_id:
            LOG.info(
                "Task is part of a workflow run, not sending a webhook response",
                task_id=task.task_id,
                workflow_run_id=task.workflow_run_id,
            )
            return

        await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task)

        # Wait for all tasks to complete before generating the links for the artifacts
        await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_task(task.task_id)

        if not task.webhook_callback_url:
            LOG.warning(
                "Task has no webhook callback url. Not sending task response",
                task_id=task.task_id,
            )
            return

        if not api_key:
            LOG.warning(
                "Request has no api key. Not sending task response",
                task_id=task.task_id,
            )
            return

        # get the artifact of the screenshot and get the screenshot_url
        screenshot_artifact = await app.DATABASE.get_artifact(
            task_id=task.task_id,
            step_id=last_step.step_id,
            artifact_type=ArtifactType.SCREENSHOT_FINAL,
            organization_id=task.organization_id,
        )
        screenshot_url = None
        if screenshot_artifact:
            screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)

        recording_artifact = await app.DATABASE.get_artifact(
            task_id=task.task_id,
            step_id=last_step.step_id,
            artifact_type=ArtifactType.RECORDING,
            organization_id=task.organization_id,
        )
        recording_url = None
        if recording_artifact:
            recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)

        # get the latest task from the db to get the latest status, extracted_information, and failure_reason
        task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
        if not task_from_db:
            LOG.error("Failed to get task from db when sending task response")
            raise TaskNotFound(task_id=task.task_id)
        task = task_from_db
        # NOTE(review): webhook_callback_url is re-checked after the refetch —
        # presumably it can be cleared concurrently; confirm before removing.
        if not task.webhook_callback_url:
            LOG.info("Task has no webhook callback url. Not sending task response")
            return

        task_response = task.to_task_response(screenshot_url=screenshot_url, recording_url=recording_url)

        # send task_response to the webhook callback url
        # TODO: use async requests (httpx)
        timestamp = str(int(datetime.utcnow().timestamp()))
        # navigation_payload is excluded from the signed body (may contain secrets).
        payload = task_response.model_dump_json(exclude={"request": {"navigation_payload"}})
        signature = generate_skyvern_signature(
            payload=payload,
            api_key=api_key,
        )
        headers = {
            "x-skyvern-timestamp": timestamp,
            "x-skyvern-signature": signature,
            "Content-Type": "application/json",
        }
        LOG.info(
            "Sending task response to webhook callback url",
            task_id=task.task_id,
            webhook_callback_url=task.webhook_callback_url,
            payload=payload,
            headers=headers,
        )
        try:
            # Blocking call inside an async method — see TODO above about httpx.
            resp = requests.post(task.webhook_callback_url, data=payload, headers=headers)
            if resp.ok:
                LOG.info(
                    "Webhook sent successfully",
                    task_id=task.task_id,
                    resp_code=resp.status_code,
                    resp_text=resp.text,
                )
            else:
                # NOTE(review): resp.json() raises on a non-JSON error body, which
                # would surface as FailedToSendWebhook below — confirm intended.
                LOG.info(
                    "Webhook failed",
                    task_id=task.task_id,
                    resp=resp,
                    resp_code=resp.status_code,
                    resp_json=resp.json(),
                    resp_text=resp.text,
                )
        except Exception as e:
            raise FailedToSendWebhook(task_id=task.task_id) from e
    async def cleanup_browser_and_create_artifacts(
        self, close_browser_on_completion: bool, last_step: Step, task: Task
    ) -> None:
        """Release the task's browser and persist its recording and HAR artifacts.

        Runs unconditionally at task end so the browser is always cleaned up,
        even for tasks with no webhook callback url or api key.

        :param close_browser_on_completion: whether the underlying browser should
            actually be closed (as opposed to being kept around for reuse).
        :param last_step: final step of the task; the HAR artifact is attached to it.
        :param task: the task whose browser state is being cleaned up.
        """
        # We need to close the browser even if there is no webhook callback url or api key
        browser_state = await app.BROWSER_MANAGER.cleanup_for_task(task.task_id, close_browser_on_completion)
        if browser_state:
            # Update recording artifact after closing the browser, so we can get an accurate recording
            video_data = await app.BROWSER_MANAGER.get_video_data(task_id=task.task_id, browser_state=browser_state)
            if video_data:
                await app.ARTIFACT_MANAGER.update_artifact_data(
                    artifact_id=browser_state.browser_artifacts.video_artifact_id,
                    organization_id=task.organization_id,
                    data=video_data,
                )

            # The HAR is only known once the browser session is finished, so it is
            # created here rather than during step execution.
            har_data = await app.BROWSER_MANAGER.get_har_data(task_id=task.task_id, browser_state=browser_state)
            if har_data:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=last_step,
                    artifact_type=ArtifactType.HAR,
                    data=har_data,
                )
        else:
            # No browser state to clean up — unexpected; log for investigation.
            LOG.warning(
                "BrowserState is missing before sending response to webhook_callback_url",
                web_hook_url=task.webhook_callback_url,
            )
|
||||
|
||||
async def update_step(
|
||||
self,
|
||||
step: Step,
|
||||
status: StepStatus | None = None,
|
||||
output: AgentStepOutput | None = None,
|
||||
is_last: bool | None = None,
|
||||
retry_index: int | None = None,
|
||||
) -> Step:
|
||||
step.validate_update(status, output, is_last)
|
||||
updates: dict[str, Any] = {}
|
||||
if status is not None:
|
||||
updates["status"] = status
|
||||
if output is not None:
|
||||
updates["output"] = output
|
||||
if is_last is not None:
|
||||
updates["is_last"] = is_last
|
||||
if retry_index is not None:
|
||||
updates["retry_index"] = retry_index
|
||||
update_comparison = {
|
||||
key: {"old": getattr(step, key), "new": value}
|
||||
for key, value in updates.items()
|
||||
if getattr(step, key) != value
|
||||
}
|
||||
LOG.info(
|
||||
"Updating step in db",
|
||||
task_id=step.task_id,
|
||||
step_id=step.step_id,
|
||||
diff=update_comparison,
|
||||
)
|
||||
return await app.DATABASE.update_step(
|
||||
task_id=step.task_id,
|
||||
step_id=step.step_id,
|
||||
organization_id=step.organization_id,
|
||||
**updates,
|
||||
)
|
||||
|
||||
async def update_task(
|
||||
self,
|
||||
task: Task,
|
||||
status: TaskStatus,
|
||||
extracted_information: dict[str, Any] | list | str | None = None,
|
||||
failure_reason: str | None = None,
|
||||
) -> Task:
|
||||
task.validate_update(status, extracted_information, failure_reason)
|
||||
updates: dict[str, Any] = {}
|
||||
if status is not None:
|
||||
updates["status"] = status
|
||||
if extracted_information is not None:
|
||||
updates["extracted_information"] = extracted_information
|
||||
if failure_reason is not None:
|
||||
updates["failure_reason"] = failure_reason
|
||||
update_comparison = {
|
||||
key: {"old": getattr(task, key), "new": value}
|
||||
for key, value in updates.items()
|
||||
if getattr(task, key) != value
|
||||
}
|
||||
LOG.info("Updating task in db", task_id=task.task_id, diff=update_comparison)
|
||||
return await app.DATABASE.update_task(
|
||||
task.task_id,
|
||||
organization_id=task.organization_id,
|
||||
**updates,
|
||||
)
|
||||
|
||||
async def handle_failed_step(self, task: Task, step: Step) -> Step | None:
|
||||
if step.retry_index >= SettingsManager.get_settings().MAX_RETRIES_PER_STEP:
|
||||
LOG.warning(
|
||||
"Step failed after max retries, marking task as failed",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
step_order=step.order,
|
||||
step_retry=step.retry_index,
|
||||
max_retries=SettingsManager.get_settings().MAX_RETRIES_PER_STEP,
|
||||
)
|
||||
await self.update_task(
|
||||
task,
|
||||
TaskStatus.failed,
|
||||
failure_reason=f"Max retries per step ({SettingsManager.get_settings().MAX_RETRIES_PER_STEP}) exceeded",
|
||||
)
|
||||
return None
|
||||
else:
|
||||
LOG.warning(
|
||||
"Step failed, retrying",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
step_order=step.order,
|
||||
step_retry=step.retry_index,
|
||||
)
|
||||
next_step = await app.DATABASE.create_step(
|
||||
task_id=task.task_id,
|
||||
organization_id=task.organization_id,
|
||||
order=step.order,
|
||||
retry_index=step.retry_index + 1,
|
||||
)
|
||||
return next_step
|
||||
|
||||
    async def handle_completed_step(
        self, organization: Organization, task: Task, step: Step
    ) -> tuple[bool | None, Step | None, Step | None]:
        """Route a completed step to its outcome.

        Returns a ``(task_finished, last_step, next_step)`` triple:
          * ``(True, last_step, None)``  — goal achieved; task marked completed.
          * ``(False, last_step, None)`` — agent terminated, or step budget
            exceeded; task marked terminated/failed.
          * ``(None, None, next_step)``  — task continues with ``next_step``.
        """
        if step.is_goal_achieved():
            LOG.info(
                "Step completed and goal achieved, marking task as completed",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
                output=step.output,
            )
            last_step = await self.update_step(step, is_last=True)
            extracted_information = await self.get_extracted_information_for_task(task)
            await self.update_task(task, status=TaskStatus.completed, extracted_information=extracted_information)
            return True, last_step, None
        if step.is_terminated():
            LOG.info(
                "Step completed and terminated by the agent, marking task as terminated",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
                output=step.output,
            )
            last_step = await self.update_step(step, is_last=True)
            failure_reason = await self.get_failure_reason_for_task(task)
            await self.update_task(task, status=TaskStatus.terminated, failure_reason=failure_reason)
            return False, last_step, None
        # If the max steps are exceeded, mark the current step as the last step and conclude the task
        context = skyvern_context.ensure_context()
        override_max_steps_per_run = context.max_steps_override
        # Precedence: per-request override > org-level limit > global default.
        max_steps_per_run = (
            override_max_steps_per_run
            or organization.max_steps_per_run
            or SettingsManager.get_settings().MAX_STEPS_PER_RUN
        )
        if step.order + 1 >= max_steps_per_run:
            LOG.info(
                "Step completed but max steps reached, marking task as failed",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
                output=step.output,
                max_steps=max_steps_per_run,
            )
            last_step = await self.update_step(step, is_last=True)
            await self.update_task(
                task,
                status=TaskStatus.failed,
                failure_reason=f"Max steps per task ({max_steps_per_run}) exceeded",
            )
            return False, last_step, None
        else:
            LOG.info(
                "Step completed, creating next step",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
                output=step.output,
            )
            next_step = await app.DATABASE.create_step(
                task_id=task.task_id,
                order=step.order + 1,
                retry_index=0,
                organization_id=task.organization_id,
            )

            # Emit a single warning when the task crosses the long-running
            # threshold (warning_ratio of the total step budget).
            if step.order == int(
                max_steps_per_run * SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO - 1
            ):
                LOG.info(
                    "Long running task warning",
                    order=step.order,
                    max_steps=max_steps_per_run,
                    warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO,
                )
            return None, None, next_step
|
||||
36
skyvern/forge/app.py
Normal file
36
skyvern/forge/app.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from ddtrace import tracer
|
||||
from ddtrace.filters import FilterRequestsOnUrl
|
||||
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.sdk.api.open_ai import OpenAIClientManager
|
||||
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
||||
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||
from skyvern.forge.sdk.forge_log import setup_logger
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.service import WorkflowService
|
||||
from skyvern.webeye.browser_manager import BrowserManager
|
||||
|
||||
# Drop health-check requests from Datadog traces to reduce noise.
tracer.configure(
    settings={
        "FILTERS": [
            FilterRequestsOnUrl(r"http://.*/heartbeat$"),
        ],
    },
)

setup_logger()
# Module-level singletons shared across the forge package; other modules
# access them as `app.DATABASE`, `app.ARTIFACT_MANAGER`, etc.
SETTINGS_MANAGER = SettingsManager.get_settings()
DATABASE = AgentDB(
    SettingsManager.get_settings().DATABASE_STRING, debug_enabled=SettingsManager.get_settings().DEBUG_MODE
)
STORAGE = StorageFactory.get_storage()
ASYNC_EXECUTOR = AsyncExecutorFactory.get_executor()
ARTIFACT_MANAGER = ArtifactManager()
BROWSER_MANAGER = BrowserManager()
OPENAI_CLIENT = OpenAIClientManager()
WORKFLOW_SERVICE = WorkflowService()
agent = ForgeAgent()

# ASGI application instance served by uvicorn (see skyvern/forge/__main__.py).
app = agent.get_agent_app()
|
||||
4
skyvern/forge/prompts.py
Normal file
4
skyvern/forge/prompts.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from skyvern.forge.sdk.prompting import PromptEngine

# Initialize the prompt engine for the "skyvern" prompt set
# (the Jinja templates under skyvern/forge/prompts/skyvern/).
prompt_engine = PromptEngine("skyvern")
|
||||
65
skyvern/forge/prompts/skyvern/extract-action.j2
Normal file
65
skyvern/forge/prompts/skyvern/extract-action.j2
Normal file
@@ -0,0 +1,65 @@
|
||||
Identify actions to help user progress towards the user goal using the DOM elements given in the list and the screenshot of the website.
|
||||
Include only the elements that are relevant to the user goal, without altering or imagining new elements.
|
||||
Use the details from the user details to fill in necessary values. Always complete required fields if the field isn't already filled in.
|
||||
MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc.
|
||||
Each element is tagged with an ID.
|
||||
If you see any information in red in the page screenshot, this means a condition wasn't satisfied. Prioritize actions with the red information.
|
||||
If you see a popup in the page screenshot, prioritize actions on the popup.
|
||||
|
||||
{% if "lever" in url %}
|
||||
DO NOT UPDATE ANY LOCATION FIELDS
|
||||
{% endif %}
|
||||
|
||||
Reply in JSON format with the following keys:
|
||||
{
|
||||
"actions": array // An array of actions. Here's the format of each action:
|
||||
[{
|
||||
"reasoning": str, // The reasoning behind the action. Be specific, referencing any user information and their fields and element ids in your reasoning. Mention why you chose the action type, and why you chose the element id. Keep the reasoning short and to the point.
|
||||
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
|
||||
"action_type": str, // It's a string enum: "CLICK", "INPUT_TEXT", "UPLOAD_FILE", "SELECT_OPTION", "WAIT", "SOLVE_CAPTCHA", "COMPLETE", "TERMINATE". "CLICK" is an element you'd like to click. "INPUT_TEXT" is an element you'd like to input text into. "UPLOAD_FILE" is an element you'd like to upload a file into. "SELECT_OPTION" is an element you'd like to select an option from. "WAIT" action should be used if there are no actions to take and there is some indication on screen that waiting could yield more actions. "WAIT" should not be used if there are actions to take. "SOLVE_CAPTCHA" should be used if there's a captcha to solve on the screen. "COMPLETE" is used when the user goal has been achieved AND if there's any data extraction goal, you should be able to get data from the page. If there is any other action to take, do not add "COMPLETE" type at all. "TERMINATE" is used to terminate the whole task with a failure when it doesn't seem like the user goal can be achieved. Do not use "TERMINATE" if waiting could lead the user towards the goal. Only return "TERMINATE" if you are on a page where the user goal cannot be achieved. All other actions are ignored when "TERMINATE" is returned.
|
||||
"id": int, // The id of the element to take action on. The id has to be one from the elements list
|
||||
"text": str, // Text for INPUT_TEXT action only
|
||||
"file_url": str, // The url of the file to upload if applicable. This field must be present for UPLOAD_FILE but can also be present for CLICK only if the click is to upload the file. It should be null otherwise.
|
||||
"option": { // The option to select for SELECT_OPTION action only. null if not SELECT_OPTION action
|
||||
"label": str, // the label of the option if any. MAKE SURE YOU USE THIS LABEL TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION LABEL HERE
|
||||
"index": int, // the id corresponding to the optionIndex under the the select element.
|
||||
"value": str // the value of the option. MAKE SURE YOU USE THIS VALUE TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION VALUE HERE
|
||||
}
|
||||
}],
|
||||
}
|
||||
|
||||
{% if action_history %}
|
||||
Consider the action history from the last step and the screenshot together, if actions from the last step don't yield positive impact, try other actions or other action combinations.
|
||||
{% endif %}
|
||||
|
||||
Clickable elements from `{{ url }}`:
|
||||
```
|
||||
{{ elements }}
|
||||
```
|
||||
|
||||
User goal:
|
||||
```
|
||||
{{ navigation_goal }}
|
||||
```
|
||||
{% if data_extraction_goal %}
|
||||
|
||||
User Data Extraction Goal:
|
||||
```
|
||||
{{ data_extraction_goal }}
|
||||
```
|
||||
{% endif %}
|
||||
|
||||
User details:
|
||||
```
|
||||
{{ navigation_payload_str }}
|
||||
```
|
||||
{% if action_history %}
|
||||
|
||||
Action results from previous steps: (note: even if the action history suggests goal is achieved, check the screenshot and the DOM elements to make sure the goal is achieved)
|
||||
{{ action_history }}
|
||||
{% endif %}
|
||||
|
||||
Current datetime in UTC:
|
||||
```
|
||||
{{ utc_datetime }}
|
||||
```
|
||||
16
skyvern/forge/prompts/skyvern/extract-information.j2
Normal file
16
skyvern/forge/prompts/skyvern/extract-information.j2
Normal file
@@ -0,0 +1,16 @@
|
||||
You are given a screenshot, user data extraction goal, the JSON schema for the output data format, and the current URL.
|
||||
|
||||
Your task is to extract the requested information from the screenshot and {% if extracted_information_schema %}output it in the specified JSON schema format:
|
||||
{{ extracted_information_schema }} {% else %}output in strictly JSON format {% endif %}
|
||||
|
||||
Add as much details as possible to the output JSON object while conforming to the output JSON schema.
|
||||
|
||||
Do not ever include anything other than the JSON object in your output, and do not ever include any additional fields in the JSON object.
|
||||
|
||||
If you are unable to extract the requested information for a specific field in the json schema, please output a null value for that field.
|
||||
|
||||
User Data Extraction Goal: {{ data_extraction_goal }}
|
||||
|
||||
Current URL: {{ current_url }}
|
||||
|
||||
Text extracted from the webpage: {{ extracted_text }}
|
||||
0
skyvern/forge/sdk/__init__.py
Normal file
0
skyvern/forge/sdk/__init__.py
Normal file
97
skyvern/forge/sdk/agent.py
Normal file
97
skyvern/forge/sdk/agent.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, FastAPI, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
from starlette_context.plugins.base import Plugin
|
||||
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Agent:
    def get_agent_app(self, router: APIRouter = base_router) -> FastAPI:
        """
        Start the agent server.

        Builds the FastAPI application: CORS, the API router under /api/v1,
        agent-injection middleware, request-context middleware, a catch-all
        exception handler, and a per-request Skyvern context.
        """

        app = FastAPI()

        # Add CORS middleware
        origins = [
            "http://localhost:5000",
            "http://127.0.0.1:5000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://localhost:8080",
            "http://127.0.0.1:8080",
            # Add any other origins you want to whitelist
        ]

        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        app.include_router(router, prefix="/api/v1")

        # Expose this Agent instance to request handlers via the ASGI scope.
        app.add_middleware(AgentMiddleware, agent=self)

        app.add_middleware(
            RawContextMiddleware,
            plugins=(
                # TODO (suchintan): We should set these up
                ExecutionDatePlugin(),
                # RequestIdPlugin(),
                # UserAgentPlugin(),
            ),
        )

        @app.exception_handler(Exception)
        async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
            # Last-resort handler: log the error and return a generic 500 payload.
            LOG.exception("Unexpected error in agent server.", exc_info=exc)
            return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"})

        @app.middleware("http")
        async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
            # Give every request its own SkyvernContext with a fresh request id,
            # and always clear it afterwards so state never leaks across requests.
            request_id = str(uuid.uuid4())
            skyvern_context.set(SkyvernContext(request_id=request_id))

            try:
                return await call_next(request)
            finally:
                skyvern_context.reset()

        return app
|
||||
|
||||
|
||||
class AgentMiddleware:
    """ASGI middleware that injects the agent instance into the request scope."""

    def __init__(self, app: FastAPI, agent: Agent):
        # Downstream ASGI app, and the agent to expose to request handlers.
        self.app = app
        self.agent = agent

    async def __call__(self, scope, receive, send):  # type: ignore
        # Make the agent reachable from any handler via the connection scope.
        scope["agent"] = self.agent
        await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class ExecutionDatePlugin(Plugin):
    # Context key under which the request's execution timestamp is stored.
    key = "execution_date"

    async def process_request(self, request: Request | HTTPConnection) -> datetime:
        """Return the moment the request was processed.

        NOTE(review): uses naive local time (datetime.now()) — confirm whether
        UTC was intended here.
        """
        return datetime.now()
|
||||
0
skyvern/forge/sdk/api/__init__.py
Normal file
0
skyvern/forge/sdk/api/__init__.py
Normal file
134
skyvern/forge/sdk/api/aws.py
Normal file
134
skyvern/forge/sdk/api/aws.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import functools
from enum import StrEnum
from typing import Any, Callable
from urllib.parse import urlparse

import aioboto3
import structlog
from aiobotocore.client import AioBaseClient

from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AWSClientType(StrEnum):
    # Values are the boto3 service names passed to session.client(...).
    S3 = "s3"
    SECRETS_MANAGER = "secretsmanager"
|
||||
|
||||
|
||||
def execute_with_async_client(client_type: AWSClientType) -> Callable:
    """Decorator factory for AsyncAWSClient methods.

    Wraps an async method so each call opens a fresh aioboto3 client of
    *client_type* and passes it as the ``client`` keyword argument; the client
    is closed when the call returns (async context manager).
    """

    def decorator(f: Callable) -> Callable:
        # functools.wraps preserves the wrapped method's __name__/__doc__ so
        # tracebacks and logs show the real method instead of "wrapper".
        @functools.wraps(f)
        async def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
            self = args[0]
            assert isinstance(self, AsyncAWSClient)
            session = aioboto3.Session()
            async with session.client(client_type) as client:
                return await f(*args, client=client, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
class AsyncAWSClient:
    """Thin async wrappers around S3 and Secrets Manager.

    Every method opens a fresh client via execute_with_async_client. Failures
    are logged and swallowed: methods return None (or nothing) on error.
    """

    @execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
    async def get_secret(self, secret_name: str, client: AioBaseClient = None) -> str | None:
        """Fetch a secret's string value; returns None if the lookup fails."""
        try:
            response = await client.get_secret_value(SecretId=secret_name)
            return response["SecretString"]
        except Exception as e:
            try:
                error_code = e.response["Error"]["Code"]  # type: ignore
            except Exception:
                # Not a botocore ClientError (or unexpected shape); keep a placeholder.
                error_code = "failed-to-get-error-code"
            LOG.exception("Failed to get secret.", secret_name=secret_name, error_code=error_code, exc_info=True)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def upload_file(self, uri: str, data: bytes, client: AioBaseClient = None) -> str | None:
        """Upload raw bytes to an s3:// uri; returns the uri, or None on failure."""
        try:
            parsed_uri = S3Uri(uri)
            await client.put_object(Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key)
            LOG.debug("Upload file success", uri=uri)
            return uri
        except Exception:
            LOG.exception("S3 upload failed.", uri=uri)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
        """Upload a local file to an s3:// uri; failures are logged, not raised."""
        try:
            parsed_uri = S3Uri(uri)
            await client.upload_file(file_path, parsed_uri.bucket, parsed_uri.key)
            LOG.info("Upload file from path success", uri=uri)
        except Exception:
            LOG.exception("S3 upload failed.", uri=uri)

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def download_file(self, uri: str, client: AioBaseClient = None) -> bytes | None:
        """Download an s3:// object's bytes; returns None on failure."""
        try:
            parsed_uri = S3Uri(uri)
            response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
            return await response["Body"].read()
        except Exception:
            LOG.exception("S3 download failed", uri=uri)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def create_presigned_url(self, uri: str, client: AioBaseClient = None) -> str | None:
        """Create a time-limited GET url for an s3:// object; None on failure.

        Expiration comes from settings.PRESIGNED_URL_EXPIRATION.
        """
        try:
            parsed_uri = S3Uri(uri)
            url = await client.generate_presigned_url(
                "get_object",
                Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
                ExpiresIn=SettingsManager.get_settings().PRESIGNED_URL_EXPIRATION,
            )
            return url
        except Exception:
            LOG.exception("Failed to create presigned url.", uri=uri)
            return None
|
||||
|
||||
|
||||
class S3Uri(object):
    # From: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path
    """Parse an ``s3://bucket/key`` URI into its bucket and key parts.

    >>> s = S3Uri("s3://bucket/hello/world")
    >>> s.bucket
    'bucket'
    >>> s.key
    'hello/world'
    >>> s.uri
    's3://bucket/hello/world'

    >>> s = S3Uri("s3://bucket/hello/world?qwe1=3#ddd")
    >>> s.bucket
    'bucket'
    >>> s.key
    'hello/world?qwe1=3#ddd'
    >>> s.uri
    's3://bucket/hello/world?qwe1=3#ddd'

    >>> s = S3Uri("s3://bucket/hello/world#foo?bar=2")
    >>> s.key
    'hello/world#foo?bar=2'
    >>> s.uri
    's3://bucket/hello/world#foo?bar=2'
    """

    def __init__(self, uri: str) -> None:
        # allow_fragments=False keeps "#..." inside the path instead of
        # splitting it off, so keys containing '#' round-trip intact.
        self._parsed = urlparse(uri, allow_fragments=False)

    @property
    def bucket(self) -> str:
        """Bucket name (the URI's network-location component)."""
        return self._parsed.netloc

    @property
    def key(self) -> str:
        """Object key: path without the leading slash, query string re-attached."""
        path = self._parsed.path.lstrip("/")
        query = self._parsed.query
        return f"{path}?{query}" if query else path

    @property
    def uri(self) -> str:
        """The original URI, reassembled from its parsed components."""
        return self._parsed.geturl()
|
||||
25
skyvern/forge/sdk/api/chat_completion_price.py
Normal file
25
skyvern/forge/sdk/api/chat_completion_price.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Per-model (input, output) prices — presumably USD per 1K tokens, matching
# OpenAI's published pricing at the time; verify before relying on the values.
openai_model_to_price_lambdas = {
    "gpt-4-vision-preview": (0.01, 0.03),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-3.5-turbo": (0.001, 0.002),
    "gpt-3.5-turbo-1106": (0.001, 0.002),
}


class ChatCompletionPrice(BaseModel):
    # Token counts as reported by the completion's usage object.
    input_token_count: int
    output_token_count: int
    # Maps (input_tokens, output_tokens) -> total cost for the chosen model.
    openai_model_to_price_lambda: Callable[[int, int], float]

    def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
        """Build a price record for *model_name*.

        Raises KeyError when the model is not in openai_model_to_price_lambdas.
        """
        input_token_price, output_token_price = openai_model_to_price_lambdas[model_name]
        super().__init__(
            input_token_count=input_token_count,
            output_token_count=output_token_count,
            # Prices in the table are per 1000 tokens, hence the / 1000.
            openai_model_to_price_lambda=lambda input_token, output_token: input_token_price * input_token / 1000
            + output_token_price * output_token / 1000,
        )
|
||||
47
skyvern/forge/sdk/api/files.py
Normal file
47
skyvern/forge/sdk/api/files.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import structlog
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
def download_file(url: str) -> str | None:
    """Download *url* into a fresh temp directory and return the local file path.

    Returns None when the server responds with a non-200 status. The local file
    name is the last segment of the URL path.
    """
    # Stream the response so large files are written chunk by chunk instead of
    # being loaded into memory at once. The `with` block guarantees the response
    # (and its pooled connection) is released even on the early-return error
    # path — previously a streamed response was never closed.
    with requests.get(url, stream=True) as r:
        # Check if the request is successful
        if r.status_code != 200:
            LOG.error(f"Failed to download file, status code: {r.status_code}")
            return None

        # Parse the URL to derive the file name from its path.
        a = urlparse(url)

        temp_dir = tempfile.mkdtemp(prefix="skyvern_downloads_")

        file_name = os.path.basename(a.path)
        file_path = os.path.join(temp_dir, file_name)

        LOG.info(f"Downloading file to {file_path}")
        with open(file_path, "wb") as f:
            # Write the content of the request into the file
            for chunk in r.iter_content(1024):
                f.write(chunk)
        LOG.info(f"File downloaded successfully to {file_path}")
        return file_path
|
||||
|
||||
|
||||
def zip_files(files_path: str, zip_file_path: str) -> str:
    """Recursively zip everything under *files_path* into *zip_file_path*.

    Archive member names are relative to *files_path*. Returns *zip_file_path*.
    """
    with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for current_dir, _subdirs, filenames in os.walk(files_path):
            for filename in filenames:
                absolute_path = os.path.join(current_dir, filename)
                # Store entries relative to the root so the archive carries no
                # absolute paths.
                member_name = os.path.relpath(absolute_path, files_path)
                archive.write(absolute_path, member_name)

    return zip_file_path
|
||||
221
skyvern/forge/sdk/api/open_ai.py
Normal file
221
skyvern/forge/sdk/api/open_ai.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
import openai
|
||||
import structlog
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class OpenAIKeyClientWrapper:
    """One OpenAI API key plus its client and a rough remaining-request budget."""

    client: AsyncOpenAI
    key: str
    # Budget from the x-ratelimit-remaining-requests header; None means the key
    # has never been used yet.
    remaining_requests: int | None

    def __init__(self, key: str, remaining_requests: int | None) -> None:
        self.key = key
        self.remaining_requests = remaining_requests
        # Timestamp of the last budget observation (naive UTC).
        self.updated_at = datetime.utcnow()
        self.client = AsyncOpenAI(api_key=self.key)

    def update_remaining_requests(self, remaining_requests: int | None) -> None:
        """Record the latest observed budget and when it was observed."""
        self.remaining_requests = remaining_requests
        self.updated_at = datetime.utcnow()

    def is_available(self) -> bool:
        """Whether this key can be used for the next request."""
        # If remaining_requests is None, then it's the first time we're trying this key
        # so we can assume it's available, otherwise we check if it's greater than 0
        if self.remaining_requests is None:
            return True

        if self.remaining_requests > 0:
            return True

        # If we haven't checked this in over 1 minutes, check it again
        # Most of our failures are because of Tokens-per-minute (TPM) limits
        if self.updated_at < (datetime.utcnow() - timedelta(minutes=1)):
            return True

        return False
|
||||
|
||||
|
||||
class OpenAIClientManager:
    """Distributes chat-completion requests across multiple OpenAI API keys,
    tracking each key's rate-limit headroom and retrying on another key when
    one is exhausted.
    """

    # TODO Support other models for requests without screenshots, track rate limits for each model and key as well if any
    clients: list[OpenAIKeyClientWrapper]

    def __init__(self, api_keys: list[str] | None = None) -> None:
        """Wrap each configured API key in a client.

        Bug fix: the key list used to be a default argument evaluated at import
        time (``api_keys=SettingsManager.get_settings().OPENAI_API_KEYS``); it
        is now read at construction time, which is backward-compatible for all
        callers and avoids the eager/mutable default pitfall.
        """
        if api_keys is None:
            api_keys = SettingsManager.get_settings().OPENAI_API_KEYS
        self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]

    def get_available_client(self) -> OpenAIKeyClientWrapper | None:
        """Return a client that still has rate-limit headroom, or None."""
        available_clients = [client for client in self.clients if client.is_available()]

        if not available_clients:
            return None

        # Randomly select an available client to distribute requests across our accounts
        return random.choice(available_clients)

    async def content_builder(
        self,
        step: Step,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> list[dict[str, Any]]:
        """Build the multi-part message content (text + images) for the request,
        persisting each part as an artifact for later debugging.
        """
        content: list[dict[str, Any]] = []

        if prompt is not None:
            content.append(
                {
                    "type": "text",
                    "text": prompt,
                }
            )

            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_PROMPT,
                data=prompt.encode("utf-8"),
            )
        if screenshots:
            for screenshot in screenshots:
                encoded_image = base64.b64encode(screenshot).decode("utf-8")
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{encoded_image}",
                        },
                    }
                )
                # create artifact for each image
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.SCREENSHOT_LLM,
                    data=screenshot,
                )

        return content

    async def chat_completion(
        self,
        step: Step,
        model: str = "gpt-4-vision-preview",
        max_tokens: int = 4096,
        temperature: int = 0,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> dict[str, Any]:
        """Send a chat-completion request, rotating to another key on rate limit.

        Persists request/response/parsed-response artifacts and records the
        step's token cost. Raises NoAvailableOpenAIClients when every key is
        exhausted and OpenAIRequestTooBigError when the request itself is over
        the size limit (HTTP 429 with code 429).
        """
        LOG.info(
            "Sending LLM request",  # was an f-string with no placeholders
            task_id=step.task_id,
            step_id=step.step_id,
            num_screenshots=len(screenshots) if screenshots else 0,
        )
        messages = [
            {
                "role": "user",
                "content": await self.content_builder(
                    step=step,
                    screenshots=screenshots,
                    prompt=prompt,
                ),
            }
        ]

        chat_completion_kwargs = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_REQUEST,
            data=json.dumps(chat_completion_kwargs).encode("utf-8"),
        )
        available_client = self.get_available_client()
        if available_client is None:
            raise NoAvailableOpenAIClients()
        try:
            response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
        except openai.RateLimitError as e:
            # If we get a RateLimitError, we can assume the key is not available anymore
            if e.code == 429:
                raise OpenAIRequestTooBigError(e.message)
            LOG.warning(
                "OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
            )
            available_client.update_remaining_requests(remaining_requests=0)
            available_client = self.get_available_client()
            if available_client is None:
                raise NoAvailableOpenAIClients()
            # Retry once on a different key with the identical request.
            return await self.chat_completion(
                step=step,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                screenshots=screenshots,
                prompt=prompt,
            )
        # TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
        # use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
        # x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys

        # If we get a response, we can assume the key is available and update the remaining requests
        ratelimit_remaining_requests = response.headers.get("x-ratelimit-remaining-requests")
        if ratelimit_remaining_requests:
            available_client.update_remaining_requests(remaining_requests=int(ratelimit_remaining_requests))
        else:
            # Bug fix: previously this logged with an invalid positional structlog
            # argument and then crashed calling int(None); now we keep the old
            # budget and just warn (with a keyword argument).
            LOG.warning("Invalid x-ratelimit-remaining-requests from OpenAI", headers=dict(response.headers))

        chat_completion = response.parse()

        if chat_completion.usage is not None:
            # TODO (Suchintan): Is this bad design?
            step = await app.DATABASE.update_step(
                step_id=step.step_id,
                task_id=step.task_id,
                organization_id=step.organization_id,
                chat_completion_price=ChatCompletionPrice(
                    input_token_count=chat_completion.usage.prompt_tokens,
                    output_token_count=chat_completion.usage.completion_tokens,
                    model_name=model,
                ),
            )
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE,
            data=chat_completion.model_dump_json(indent=2).encode("utf-8"),
        )
        parsed_response = self.parse_response(chat_completion)
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
            data=json.dumps(parsed_response, indent=2).encode("utf-8"),
        )
        return parsed_response

    def parse_response(self, response: ChatCompletion) -> dict[str, str]:
        """Strip markdown code fences and parse the model output as JSON.

        Uses commentjson so stray // comments don't break parsing. Raises
        InvalidOpenAIResponseFormat (chained) on empty or unparseable content.
        """
        try:
            content = response.choices[0].message.content
            content = content.replace("```json", "")
            content = content.replace("```", "")
            if not content:
                raise Exception("openai response content is empty")
            return commentjson.loads(content)
        except Exception as e:
            raise InvalidOpenAIResponseFormat(str(response)) from e
|
||||
0
skyvern/forge/sdk/artifact/__init__.py
Normal file
0
skyvern/forge/sdk/artifact/__init__.py
Normal file
112
skyvern/forge/sdk/artifact/manager.py
Normal file
112
skyvern/forge/sdk/artifact/manager.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.id import generate_artifact_id
|
||||
from skyvern.forge.sdk.models import Step
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class ArtifactManager:
    """Creates artifact records and uploads their payloads asynchronously.

    Uploads are fire-and-forget asyncio tasks; they are tracked per task_id in
    ``upload_aiotasks_map`` so callers can await completion (with a timeout)
    before declaring a task finished.
    """

    # task_id -> list of aio_tasks for uploading artifacts
    upload_aiotasks_map: dict[str, list[asyncio.Task[None]]] = defaultdict(list)

    async def create_artifact(
        self, step: Step, artifact_type: ArtifactType, data: bytes | None = None, path: str | None = None
    ) -> str:
        """Create an artifact DB record and schedule its payload upload.

        Exactly one of ``data`` (in-memory bytes) or ``path`` (file on disk)
        must be provided.

        Returns:
            The generated artifact id.

        Raises:
            ValueError: if neither or both of ``data`` / ``path`` are given.
        """
        # TODO (kerem): Which is better?
        # current: generate the artifact_id here, build the URI from it, then
        #          create the DB row and store the payload (single DB call).
        # alternative: create the row first, then update it with the URI
        #          (two DB calls, but the DB owns id generation).
        if data is None and path is None:
            raise ValueError("Either data or path must be provided to create an artifact.")
        # NOTE: use `is not None` rather than truthiness — empty bytes (b"")
        # is a legitimate payload and must not be treated as "not provided".
        if data is not None and path is not None:
            raise ValueError("Both data and path cannot be provided to create an artifact.")
        artifact_id = generate_artifact_id()
        uri = app.STORAGE.build_uri(artifact_id, step, artifact_type)
        artifact = await app.DATABASE.create_artifact(
            artifact_id,
            step.step_id,
            step.task_id,
            artifact_type,
            uri,
            organization_id=step.organization_id,
        )
        if data is not None:
            # Fire and forget. (The original `if data:` silently skipped the
            # upload when data == b"".)
            aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
            self.upload_aiotasks_map[step.task_id].append(aio_task)
        elif path is not None:
            # Fire and forget
            aio_task = asyncio.create_task(app.STORAGE.store_artifact_from_path(artifact, path))
            self.upload_aiotasks_map[step.task_id].append(aio_task)

        return artifact_id

    async def update_artifact_data(self, artifact_id: str | None, organization_id: str | None, data: bytes) -> None:
        """Re-upload ``data`` for an existing artifact (fire-and-forget).

        No-ops silently when the ids are missing or the artifact is unknown.
        """
        if not artifact_id or not organization_id:
            return None
        artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id)
        if not artifact:
            return
        # Fire and forget
        aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
        self.upload_aiotasks_map[artifact.task_id].append(aio_task)

    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Fetch the artifact's payload from the configured storage backend."""
        return await app.STORAGE.retrieve_artifact(artifact)

    async def get_share_link(self, artifact: Artifact) -> str | None:
        """Return a shareable link for the artifact, if the backend supports one."""
        return await app.STORAGE.get_share_link(artifact)

    async def wait_for_upload_aiotasks_for_task(self, task_id: str) -> None:
        """Await all pending uploads for one task (30s cap), then drop tracking.

        Timeouts are logged but not raised; the tracking entry is removed
        either way so the map does not grow unboundedly.
        """
        try:
            st = time.time()
            async with asyncio.timeout(30):
                await asyncio.gather(
                    *[aio_task for aio_task in self.upload_aiotasks_map[task_id] if not aio_task.done()]
                )
            LOG.info(
                f"S3 upload tasks for task_id={task_id} completed in {time.time() - st:.2f}s",
                task_id=task_id,
                duration=time.time() - st,
            )
        except asyncio.TimeoutError:
            LOG.error(f"Timeout (30s) while waiting for upload tasks for task_id={task_id}", task_id=task_id)

        del self.upload_aiotasks_map[task_id]

    async def wait_for_upload_aiotasks_for_tasks(self, task_ids: list[str]) -> None:
        """Await all pending uploads for several tasks (shared 30s cap).

        Same semantics as :meth:`wait_for_upload_aiotasks_for_task`, but the
        30-second budget covers all the given tasks together.
        """
        try:
            st = time.time()
            async with asyncio.timeout(30):
                await asyncio.gather(
                    *[
                        aio_task
                        for task_id in task_ids
                        for aio_task in self.upload_aiotasks_map[task_id]
                        if not aio_task.done()
                    ]
                )
            LOG.info(
                f"S3 upload tasks for task_ids={task_ids} completed in {time.time() - st:.2f}s",
                task_ids=task_ids,
                duration=time.time() - st,
            )
        except asyncio.TimeoutError:
            LOG.error(f"Timeout (30s) while waiting for upload tasks for task_ids={task_ids}", task_ids=task_ids)

        for task_id in task_ids:
            del self.upload_aiotasks_map[task_id]
78
skyvern/forge/sdk/artifact/models.py
Normal file
78
skyvern/forge/sdk/artifact/models.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ArtifactType(StrEnum):
    """Kinds of artifacts produced while executing a task step.

    Values are stable string identifiers persisted in the database and used
    to derive file extensions for stored payloads.
    """

    # Browser session recording (video).
    RECORDING = "recording"

    # DEPRECATED. pls use SCREENSHOT_LLM, SCREENSHOT_ACTION or SCREENSHOT_FINAL
    SCREENSHOT = "screenshot"

    # USE THESE for screenshots
    SCREENSHOT_LLM = "screenshot_llm"  # screenshot sent to the LLM as context
    SCREENSHOT_ACTION = "screenshot_action"  # screenshot taken around an action
    SCREENSHOT_FINAL = "screenshot_final"  # screenshot at the end of a task

    # LLM interaction artifacts, in request/response order.
    LLM_PROMPT = "llm_prompt"
    LLM_REQUEST = "llm_request"
    LLM_RESPONSE = "llm_response"
    LLM_RESPONSE_PARSED = "llm_response_parsed"
    # Scraped-DOM artifacts.
    VISIBLE_ELEMENTS_ID_XPATH_MAP = "visible_elements_id_xpath_map"
    VISIBLE_ELEMENTS_TREE = "visible_elements_tree"
    VISIBLE_ELEMENTS_TREE_TRIMMED = "visible_elements_tree_trimmed"

    # DEPRECATED. pls use HTML_SCRAPE or HTML_ACTION
    HTML = "html"

    # USE THESE for htmls
    HTML_SCRAPE = "html_scrape"
    HTML_ACTION = "html_action"

    # Debugging
    TRACE = "trace"  # playwright trace archive
    HAR = "har"  # HTTP archive of network traffic
||||
class Artifact(BaseModel):
    """A stored by-product of a task step (screenshot, LLM payload, etc.).

    ``uri`` points at the payload in whatever storage backend created it
    (e.g. a file:// path for local storage).
    """

    # NOTE(review): `json_encoders` is a pydantic v1-style Field argument;
    # confirm it still takes effect under the pydantic version in use.
    created_at: datetime = Field(
        ...,
        description="The creation datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    artifact_id: str = Field(
        ...,
        description="The ID of the task artifact.",
        examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
    )
    task_id: str = Field(
        ...,
        description="The ID of the task this artifact belongs to.",
        examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
    )
    step_id: str = Field(
        ...,
        description="The ID of the task step this artifact belongs to.",
        examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
    )
    artifact_type: ArtifactType = Field(
        ...,
        description="The type of the artifact.",
        examples=["screenshot"],
    )
    uri: str = Field(
        ...,
        description="The URI of the artifact.",
        examples=["/Users/skyvern/hello/world.png"],
    )
    # Owning organization; optional for legacy/unscoped artifacts.
    organization_id: str | None = None
||||
0
skyvern/forge/sdk/artifact/storage/__init__.py
Normal file
0
skyvern/forge/sdk/artifact/storage/__init__.py
Normal file
45
skyvern/forge/sdk/artifact/storage/base.py
Normal file
45
skyvern/forge/sdk/artifact/storage/base.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
|
||||
# TODO: This should be a part of the ArtifactType model
|
||||
# TODO: This should be a part of the ArtifactType model
# Maps each artifact type to the file extension used when storing its payload.
# (Name keeps the historical "EXTENTSION" typo: it is imported by storage
# backends, so renaming would break callers.)
FILE_EXTENTSION_MAP: dict[ArtifactType, str] = {
    ArtifactType.RECORDING: "webm",
    # Deprecated types still exist in old data; map them too so build_uri
    # does not raise KeyError for legacy artifacts.
    ArtifactType.SCREENSHOT: "png",
    ArtifactType.SCREENSHOT_LLM: "png",
    ArtifactType.SCREENSHOT_ACTION: "png",
    ArtifactType.SCREENSHOT_FINAL: "png",
    ArtifactType.LLM_PROMPT: "txt",
    ArtifactType.LLM_REQUEST: "json",
    ArtifactType.LLM_RESPONSE: "json",
    ArtifactType.LLM_RESPONSE_PARSED: "json",
    ArtifactType.VISIBLE_ELEMENTS_ID_XPATH_MAP: "json",
    ArtifactType.VISIBLE_ELEMENTS_TREE: "json",
    ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED: "json",
    ArtifactType.HTML: "html",
    ArtifactType.HTML_SCRAPE: "html",
    ArtifactType.HTML_ACTION: "html",
    ArtifactType.TRACE: "zip",
    ArtifactType.HAR: "har",
}
||||
class BaseStorage(ABC):
    """Abstract interface for artifact payload storage backends."""

    @abstractmethod
    def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
        """Build the URI where this artifact's payload will live."""
        pass

    @abstractmethod
    async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
        """Write ``data`` to the location named by ``artifact.uri``."""
        pass

    @abstractmethod
    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Read the payload back, or None if unavailable."""
        pass

    @abstractmethod
    async def get_share_link(self, artifact: Artifact) -> str | None:
        """Return a link usable outside the backend (e.g. a presigned URL)."""
        pass

    @abstractmethod
    async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
        """Move/copy an existing local file at ``path`` into the backend."""
        pass
||||
14
skyvern/forge/sdk/artifact/storage/factory.py
Normal file
14
skyvern/forge/sdk/artifact/storage/factory.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from skyvern.forge.sdk.artifact.storage.base import BaseStorage
|
||||
from skyvern.forge.sdk.artifact.storage.local import LocalStorage
|
||||
|
||||
|
||||
class StorageFactory:
    """Process-wide registry for the active artifact storage backend.

    Defaults to LocalStorage; deployments swap in another backend at startup
    via :meth:`set_storage`.
    """

    # Name-mangled to discourage direct access; use get_storage().
    __storage: BaseStorage = LocalStorage()

    @staticmethod
    def set_storage(storage: BaseStorage) -> None:
        """Replace the active storage backend (intended for app startup)."""
        StorageFactory.__storage = storage

    @staticmethod
    def get_storage() -> BaseStorage:
        """Return the currently active storage backend."""
        return StorageFactory.__storage
||||
66
skyvern/forge/sdk/artifact/storage/local.py
Normal file
66
skyvern/forge/sdk/artifact/storage/local.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LocalStorage(BaseStorage):
    """Artifact storage backend that writes payloads to the local filesystem.

    URIs are of the form ``file://<artifact_path>/<task>/<step>/<file>``.
    All store/retrieve errors are logged and swallowed (best-effort storage).
    """

    def __init__(self, artifact_path: str = SettingsManager.get_settings().ARTIFACT_STORAGE_PATH) -> None:
        # NOTE: the default is evaluated once at import time; pass
        # artifact_path explicitly if settings may change afterwards.
        self.artifact_path = artifact_path

    def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
        """Build a file:// URI unique per step/attempt/artifact."""
        file_ext = FILE_EXTENTSION_MAP[artifact_type]
        return f"file://{self.artifact_path}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

    async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
        """Write ``data`` to the artifact's URI, creating parent dirs as needed."""
        file_path = None
        try:
            file_path = Path(self._parse_uri_to_path(artifact.uri))
            self._create_directories_if_not_exists(file_path)
            with open(file_path, "wb") as f:
                f.write(data)
        except Exception:
            LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)

    async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
        """Move the file at ``path`` into the artifact's URI location.

        NOTE(review): Path.replace() fails across filesystems — confirm source
        and destination always share a mount.
        """
        file_path = None
        try:
            file_path = Path(self._parse_uri_to_path(artifact.uri))
            self._create_directories_if_not_exists(file_path)
            Path(path).replace(file_path)
        except Exception:
            LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)

    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Read the artifact's payload; returns None on any error."""
        file_path = None
        try:
            file_path = self._parse_uri_to_path(artifact.uri)
            with open(file_path, "rb") as f:
                return f.read()
        except Exception:
            LOG.exception("Failed to retrieve local artifact.", file_path=file_path, artifact=artifact)
            return None

    async def get_share_link(self, artifact: Artifact) -> str:
        """Local files have no presigned URLs; the URI itself is the link."""
        return artifact.uri

    @staticmethod
    def _parse_uri_to_path(uri: str) -> str:
        """Convert a file:// URI into a filesystem path.

        Raises:
            ValueError: if the URI scheme is not ``file``.
        """
        parsed_uri = urlparse(uri)
        if parsed_uri.scheme != "file":
            # BUGFIX: this message was a plain string, so "{parsed_uri.scheme}"
            # was emitted literally instead of the actual scheme.
            raise ValueError(f"Invalid URI scheme: {parsed_uri.scheme} expected: file")
        path = parsed_uri.netloc + parsed_uri.path
        return unquote(path)

    @staticmethod
    def _create_directories_if_not_exists(path_including_file_name: Path) -> None:
        """Ensure the parent directory of the target file exists."""
        path = path_including_file_name.parent
        path.mkdir(parents=True, exist_ok=True)
0
skyvern/forge/sdk/core/__init__.py
Normal file
0
skyvern/forge/sdk/core/__init__.py
Normal file
41
skyvern/forge/sdk/core/security.py
Normal file
41
skyvern/forge/sdk/core/security.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
|
||||
from jose import jwt
|
||||
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def create_access_token(
    subject: Union[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed JWT access token for *subject*.

    The token carries an ``exp`` claim set ``expires_delta`` from now, falling
    back to the configured ACCESS_TOKEN_EXPIRE_MINUTES when no delta is given.
    """
    default_ttl = timedelta(
        minutes=SettingsManager.get_settings().ACCESS_TOKEN_EXPIRE_MINUTES,
    )
    # A falsy delta (None or timedelta(0)) falls back to the default TTL,
    # matching the original if/else branching.
    expire = datetime.utcnow() + (expires_delta or default_ttl)
    claims = {"exp": expire, "sub": str(subject)}
    return jwt.encode(claims, SettingsManager.get_settings().SECRET_KEY, algorithm=ALGORITHM)
|
||||
def generate_skyvern_signature(
    payload: str,
    api_key: str,
) -> str:
    """
    Generate Skyvern signature.

    Computes the hex-encoded HMAC-SHA256 of the request body, keyed by the
    organization's Skyvern api key.

    :param payload: the request body
    :param api_key: the Skyvern api key

    :return: the Skyvern signature
    """
    digest = hmac.new(
        api_key.encode("utf-8"),
        payload.encode("utf-8"),
        hashlib.sha256,
    )
    return digest.hexdigest()
||||
73
skyvern/forge/sdk/core/skyvern_context.py
Normal file
73
skyvern/forge/sdk/core/skyvern_context.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkyvernContext:
|
||||
request_id: str | None = None
|
||||
organization_id: str | None = None
|
||||
task_id: str | None = None
|
||||
workflow_id: str | None = None
|
||||
workflow_run_id: str | None = None
|
||||
max_steps_override: int | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, max_steps_override={self.max_steps_override})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
# Holder for the current SkyvernContext. A ContextVar scopes the value to the
# current thread / asyncio task, so concurrent requests do not see each
# other's context.
_context: ContextVar[SkyvernContext | None] = ContextVar(
    "Global context",
    default=None,
)


def current() -> SkyvernContext | None:
    """
    Get the current context

    Returns:
        The current context, or None if there is none
    """
    return _context.get()


def ensure_context() -> SkyvernContext:
    """
    Get the current context, or raise an error if there is none

    Returns:
        The current context if there is one

    Raises:
        RuntimeError: If there is no current context
    """
    context = current()
    if context is None:
        raise RuntimeError("No skyvern context")
    return context


# NOTE: intentionally named `set` to read as skyvern_context.set(...); it
# shadows the builtin only if imported unqualified.
def set(context: SkyvernContext) -> None:
    """
    Set the current context

    Args:
        context: The context to set

    Returns:
        None
    """
    _context.set(context)


def reset() -> None:
    """
    Reset the current context

    Returns:
        None
    """
    _context.set(None)
0
skyvern/forge/sdk/db/__init__.py
Normal file
0
skyvern/forge/sdk/db/__init__.py
Normal file
900
skyvern/forge/sdk/db/client.py
Normal file
900
skyvern/forge/sdk/db/client.py
Normal file
@@ -0,0 +1,900 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, create_engine, delete
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from skyvern.exceptions import WorkflowParameterNotFound
|
||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ArtifactModel,
|
||||
AWSSecretParameterModel,
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
_custom_json_serializer,
|
||||
convert_to_artifact,
|
||||
convert_to_aws_secret_parameter,
|
||||
convert_to_organization,
|
||||
convert_to_organization_auth_token,
|
||||
convert_to_step,
|
||||
convert_to_task,
|
||||
convert_to_workflow,
|
||||
convert_to_workflow_parameter,
|
||||
convert_to_workflow_run,
|
||||
convert_to_workflow_run_parameter,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus
|
||||
from skyvern.webeye.actions.models import AgentStepOutput
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AgentDB:
|
||||
    def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
        """Create a DB client bound to *database_string* (a SQLAlchemy URL).

        debug_enabled is threaded through to the model->schema converters.
        """
        super().__init__()
        self.debug_enabled = debug_enabled
        # Custom serializer handles JSON columns containing non-standard types.
        self.engine = create_engine(database_string, json_serializer=_custom_json_serializer)
        self.Session = sessionmaker(bind=self.engine)
||||
    async def create_task(
        self,
        url: str,
        navigation_goal: str | None,
        data_extraction_goal: str | None,
        navigation_payload: dict[str, Any] | list | str | None,
        webhook_callback_url: str | None = None,
        organization_id: str | None = None,
        proxy_location: ProxyLocation | None = None,
        extracted_information_schema: dict[str, Any] | list | str | None = None,
        workflow_run_id: str | None = None,
        order: int | None = None,
        retry: int | None = None,
    ) -> Task:
        """Insert a new task row (status "created") and return it as a Task.

        ``order``/``retry`` position the task inside a workflow run, when any.
        Database errors are logged and re-raised.
        """
        try:
            with self.Session() as session:
                new_task = TaskModel(
                    status="created",
                    url=url,
                    webhook_callback_url=webhook_callback_url,
                    navigation_goal=navigation_goal,
                    data_extraction_goal=data_extraction_goal,
                    navigation_payload=navigation_payload,
                    organization_id=organization_id,
                    proxy_location=proxy_location,
                    extracted_information_schema=extracted_information_schema,
                    workflow_run_id=workflow_run_id,
                    order=order,
                    retry=retry,
                )
                session.add(new_task)
                session.commit()
                # refresh() populates server-generated columns (ids, timestamps).
                session.refresh(new_task)
                return convert_to_task(new_task, self.debug_enabled)
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
||||
    async def create_step(
        self,
        task_id: str,
        order: int,
        retry_index: int,
        organization_id: str | None = None,
    ) -> Step:
        """Insert a new step row (status "created") for a task and return it.

        ``order`` is the step's position in the task; ``retry_index`` counts
        retries of the same order. Database errors are logged and re-raised.
        """
        try:
            with self.Session() as session:
                new_step = StepModel(
                    task_id=task_id,
                    order=order,
                    retry_index=retry_index,
                    status="created",
                    organization_id=organization_id,
                )
                session.add(new_step)
                session.commit()
                session.refresh(new_step)
                return convert_to_step(new_step, debug_enabled=self.debug_enabled)
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
||||
async def create_artifact(
|
||||
self,
|
||||
artifact_id: str,
|
||||
step_id: str,
|
||||
task_id: str,
|
||||
artifact_type: str,
|
||||
uri: str,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
new_artifact = ArtifactModel(
|
||||
artifact_id=artifact_id,
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
artifact_type=artifact_type,
|
||||
uri=uri,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
session.commit()
|
||||
session.refresh(new_artifact)
|
||||
return convert_to_artifact(new_artifact, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.exception("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.exception("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
    async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
        """Get a task by its id.

        Returns None (and logs at info) when no matching row exists for the
        given organization. Database errors are logged and re-raised.
        """
        try:
            with self.Session() as session:
                if task_obj := (
                    session.query(TaskModel)
                    .filter_by(task_id=task_id)
                    .filter_by(organization_id=organization_id)
                    .first()
                ):
                    return convert_to_task(task_obj, self.debug_enabled)
                else:
                    LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
                    return None
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
        """Get a step by its id, or None if it does not exist.

        Note: the query filters by step_id and organization_id only; the
        task_id parameter is accepted for signature symmetry but not used
        in the filter.
        """
        try:
            with self.Session() as session:
                if step := (
                    session.query(StepModel)
                    .filter_by(step_id=step_id)
                    .filter_by(organization_id=organization_id)
                    .first()
                ):
                    return convert_to_step(step, debug_enabled=self.debug_enabled)

                else:
                    return None
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def get_task_steps(self, task_id: str, organization_id: str | None = None) -> list[Step]:
        """Return all steps of a task ordered by (order, retry_index).

        Returns an empty list when the task has no steps.
        """
        try:
            with self.Session() as session:
                if (
                    steps := session.query(StepModel)
                    .filter_by(task_id=task_id)
                    .filter_by(organization_id=organization_id)
                    .order_by(StepModel.order)
                    .order_by(StepModel.retry_index)
                    .all()
                ):
                    return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
                else:
                    return []
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> list[StepModel]:
        """Like get_task_steps, but returns raw ORM StepModel rows.

        NOTE(review): the returned rows are detached once the session closes —
        callers should not lazy-load relationships on them.
        """
        try:
            with self.Session() as session:
                return (
                    session.query(StepModel)
                    .filter_by(task_id=task_id)
                    .filter_by(organization_id=organization_id)
                    .order_by(StepModel.order)
                    .order_by(StepModel.retry_index)
                    .all()
                )
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
        """Return the step with the highest ``order`` for a task, or None.

        NOTE(review): sorts by ``order`` desc only; when several retries share
        the same order, which row comes first is backend-dependent — confirm
        whether a secondary ``retry_index.desc()`` sort is intended.
        """
        try:
            with self.Session() as session:
                if step := (
                    session.query(StepModel)
                    .filter_by(task_id=task_id)
                    .filter_by(organization_id=organization_id)
                    .order_by(StepModel.order.desc())
                    .first()
                ):
                    return convert_to_step(step, debug_enabled=self.debug_enabled)
                else:
                    LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
                    return None
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def update_step(
        self,
        task_id: str,
        step_id: str,
        status: StepStatus | None = None,
        output: AgentStepOutput | None = None,
        is_last: bool | None = None,
        retry_index: int | None = None,
        organization_id: str | None = None,
        chat_completion_price: ChatCompletionPrice | None = None,
    ) -> Step:
        """Partially update a step: only non-None arguments are applied.

        ``chat_completion_price`` accumulates token counts onto the step and
        recomputes its cost. The updated step is re-read after commit so the
        returned object reflects persisted state.

        Raises:
            NotFoundError: if no matching step exists (before or after update).
        """
        try:
            with self.Session() as session:
                if (
                    step := session.query(StepModel)
                    .filter_by(task_id=task_id)
                    .filter_by(step_id=step_id)
                    .filter_by(organization_id=organization_id)
                    .first()
                ):
                    if status is not None:
                        step.status = status
                    if output is not None:
                        step.output = output.model_dump()
                    if is_last is not None:
                        step.is_last = is_last
                    if retry_index is not None:
                        step.retry_index = retry_index
                    if chat_completion_price is not None:
                        # Token counters start as NULL in the DB; normalize to 0
                        # before accumulating.
                        if step.input_token_count is None:
                            step.input_token_count = 0

                        if step.output_token_count is None:
                            step.output_token_count = 0

                        step.input_token_count += chat_completion_price.input_token_count
                        step.output_token_count += chat_completion_price.output_token_count
                        # Recompute cost from the accumulated totals, not just
                        # this call's delta.
                        step.step_cost = chat_completion_price.openai_model_to_price_lambda(
                            step.input_token_count, step.output_token_count
                        )

                    session.commit()
                    updated_step = await self.get_step(task_id, step_id, organization_id)
                    if not updated_step:
                        raise NotFoundError("Step not found")
                    return updated_step
                else:
                    raise NotFoundError("Step not found")
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except NotFoundError:
            LOG.error("NotFoundError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def update_task(
        self,
        task_id: str,
        status: TaskStatus,
        extracted_information: dict[str, Any] | list | str | None = None,
        failure_reason: str | None = None,
        organization_id: str | None = None,
    ) -> Task:
        """Set a task's status and optionally its extracted_information /
        failure_reason, then re-read and return the updated task.

        Raises:
            NotFoundError: if the task does not exist for this organization.
        """
        try:
            with self.Session() as session:
                if (
                    task := session.query(TaskModel)
                    .filter_by(task_id=task_id)
                    .filter_by(organization_id=organization_id)
                    .first()
                ):
                    task.status = status
                    if extracted_information is not None:
                        task.extracted_information = extracted_information
                    if failure_reason is not None:
                        task.failure_reason = failure_reason
                    session.commit()
                    # Re-fetch so the caller sees committed, converted state.
                    updated_task = await self.get_task(task_id, organization_id=organization_id)
                    if not updated_task:
                        raise NotFoundError("Task not found")
                    return updated_task
                else:
                    raise NotFoundError("Task not found")
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except NotFoundError:
            LOG.error("NotFoundError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def get_tasks(self, page: int = 1, page_size: int = 10, organization_id: str | None = None) -> list[Task]:
        """
        Get all tasks, newest first, paginated.

        :param page: Starts at 1
        :param page_size: number of tasks per page
        :return: the page of tasks (possibly empty)
        :raises ValueError: if page < 1
        """
        if page < 1:
            raise ValueError(f"Page must be greater than 0, got {page}")

        try:
            with self.Session() as session:
                db_page = page - 1  # offset logic is 0 based
                tasks = (
                    session.query(TaskModel)
                    .filter_by(organization_id=organization_id)
                    .order_by(TaskModel.created_at.desc())
                    .limit(page_size)
                    .offset(db_page * page_size)
                    .all()
                )
                return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def get_organization(self, organization_id: str) -> Organization | None:
        """Look up an organization by id; returns None if not found."""
        try:
            with self.Session() as session:
                if organization := (
                    session.query(OrganizationModel).filter_by(organization_id=organization_id).first()
                ):
                    return convert_to_organization(organization)
                else:
                    return None
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
    async def create_organization(
        self,
        organization_name: str,
        webhook_callback_url: str | None = None,
        max_steps_per_run: int | None = None,
    ) -> Organization:
        """Insert a new organization row and return it.

        NOTE(review): unlike the other methods here, this one has no
        try/except logging wrapper — confirm whether that is intentional.
        """
        with self.Session() as session:
            org = OrganizationModel(
                organization_name=organization_name,
                webhook_callback_url=webhook_callback_url,
                max_steps_per_run=max_steps_per_run,
            )
            session.add(org)
            session.commit()
            session.refresh(org)

            return convert_to_organization(org)
|
||||
    async def get_valid_org_auth_token(
        self,
        organization_id: str,
        token_type: OrganizationAuthTokenType,
    ) -> OrganizationAuthToken | None:
        """Return the first valid token of the given type for an organization,
        or None when the organization has no valid token of that type.
        """
        try:
            with self.Session() as session:
                if token := (
                    session.query(OrganizationAuthTokenModel)
                    .filter_by(organization_id=organization_id)
                    .filter_by(token_type=token_type)
                    .filter_by(valid=True)
                    .first()
                ):
                    return convert_to_organization_auth_token(token)
                else:
                    return None
        except SQLAlchemyError:
            LOG.error("SQLAlchemyError", exc_info=True)
            raise
        except Exception:
            LOG.error("UnexpectedError", exc_info=True)
            raise
|
||||
async def validate_org_auth_token(
    self,
    organization_id: str,
    token_type: OrganizationAuthTokenType,
    token: str,
) -> OrganizationAuthToken | None:
    """Check that the exact token string is registered and still valid for the org.

    Returns the matching token record, or None when no valid match exists.
    """
    try:
        with self.Session() as session:
            match = (
                session.query(OrganizationAuthTokenModel)
                .filter_by(organization_id=organization_id)
                .filter_by(token_type=token_type)
                .filter_by(token=token)
                .filter_by(valid=True)
                .first()
            )
            if match is None:
                return None
            return convert_to_organization_auth_token(match)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def create_org_auth_token(
    self,
    organization_id: str,
    token_type: OrganizationAuthTokenType,
    token: str,
) -> OrganizationAuthToken:
    """Persist a new auth token for an organization and return it.

    Fix: the original rebound the `token` str parameter to the ORM row
    (shadowing), which is confusing and discards the input name; the row
    now gets its own local name.
    """
    with self.Session() as session:
        auth_token = OrganizationAuthTokenModel(
            organization_id=organization_id,
            token_type=token_type,
            token=token,
        )
        session.add(auth_token)
        session.commit()
        session.refresh(auth_token)
        return convert_to_organization_auth_token(auth_token)
|
||||
|
||||
async def get_artifacts_for_task_step(
    self,
    task_id: str,
    step_id: str,
    organization_id: str | None = None,
) -> list[Artifact]:
    """Return all artifacts recorded for one step of a task (empty list when none)."""
    try:
        with self.Session() as session:
            rows = (
                session.query(ArtifactModel)
                .filter_by(task_id=task_id)
                .filter_by(step_id=step_id)
                .filter_by(organization_id=organization_id)
                .all()
            )
            # an empty result naturally yields [] here
            return [convert_to_artifact(row, self.debug_enabled) for row in rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_artifact_by_id(
    self,
    artifact_id: str,
    organization_id: str,
) -> Artifact | None:
    """Fetch an artifact by ID scoped to an organization; None when not found.

    Fix: handlers used LOG.exception(..., exc_info=True), which is redundant
    (exception() already records the traceback) and inconsistent with every
    sibling method's LOG.error(..., exc_info=True).
    """
    try:
        with self.Session() as session:
            if artifact := (
                session.query(ArtifactModel)
                .filter_by(artifact_id=artifact_id)
                .filter_by(organization_id=organization_id)
                .first()
            ):
                return convert_to_artifact(artifact, self.debug_enabled)
            else:
                return None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_artifact(
    self,
    task_id: str,
    step_id: str,
    artifact_type: ArtifactType,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the newest artifact of a given type for a task step, or None."""
    try:
        with self.Session() as session:
            latest = (
                session.query(ArtifactModel)
                .filter_by(task_id=task_id)
                .filter_by(step_id=step_id)
                .filter_by(organization_id=organization_id)
                .filter_by(artifact_type=artifact_type)
                .order_by(ArtifactModel.created_at.desc())
                .first()
            )
            return convert_to_artifact(latest, self.debug_enabled) if latest else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_artifact_for_workflow_run(
    self,
    workflow_run_id: str,
    artifact_type: ArtifactType,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the newest artifact of a given type across all tasks of a workflow run."""
    try:
        with self.Session() as session:
            # artifacts are linked to tasks, so join through TaskModel to
            # reach the workflow_run_id
            latest = (
                session.query(ArtifactModel)
                .join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
                .filter(TaskModel.workflow_run_id == workflow_run_id)
                .filter(ArtifactModel.artifact_type == artifact_type)
                .filter(ArtifactModel.organization_id == organization_id)
                .order_by(ArtifactModel.created_at.desc())
                .first()
            )
            return convert_to_artifact(latest, self.debug_enabled) if latest else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_latest_artifact(
    self,
    task_id: str,
    step_id: str | None = None,
    artifact_types: list[ArtifactType] | None = None,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the newest artifact for a task, optionally narrowed by step / types / org.

    Fix: handlers used LOG.exception(..., exc_info=True) — redundant
    (exception() already implies exc_info) and inconsistent with sibling
    methods that use LOG.error(..., exc_info=True).
    """
    try:
        with self.Session() as session:
            artifact_query = session.query(ArtifactModel).filter_by(task_id=task_id)
            if step_id:
                artifact_query = artifact_query.filter_by(step_id=step_id)
            if organization_id:
                artifact_query = artifact_query.filter_by(organization_id=organization_id)
            if artifact_types:
                artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))

            artifact = artifact_query.order_by(ArtifactModel.created_at.desc()).first()
            if artifact:
                return convert_to_artifact(artifact, self.debug_enabled)
            return None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_latest_task_by_workflow_id(
    self,
    organization_id: str,
    workflow_id: str,
    before: datetime | None = None,
) -> Task | None:
    """Return the most recently created task for a workflow, optionally before a cutoff."""
    try:
        with self.Session() as session:
            task_query = (
                session.query(TaskModel)
                .filter_by(organization_id=organization_id)
                .filter_by(workflow_id=workflow_id)
            )
            if before:
                task_query = task_query.filter(TaskModel.created_at < before)
            newest = task_query.order_by(TaskModel.created_at.desc()).first()
            return convert_to_task(newest, debug_enabled=self.debug_enabled) if newest else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def create_workflow(
    self,
    organization_id: str,
    title: str,
    workflow_definition: dict[str, Any],
    description: str | None = None,
) -> Workflow:
    """Insert a new workflow row and return it as a Workflow schema object."""
    with self.Session() as session:
        new_workflow = WorkflowModel(
            organization_id=organization_id,
            title=title,
            description=description,
            workflow_definition=workflow_definition,
        )
        session.add(new_workflow)
        session.commit()
        session.refresh(new_workflow)
        return convert_to_workflow(new_workflow, self.debug_enabled)
|
||||
|
||||
async def get_workflow(self, workflow_id: str) -> Workflow | None:
    """Fetch a workflow by ID; None when it does not exist."""
    try:
        with self.Session() as session:
            row = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
            return convert_to_workflow(row, self.debug_enabled) if row else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def update_workflow(
    self,
    workflow_id: str,
    title: str | None = None,
    description: str | None = None,
    workflow_definition: dict[str, Any] | None = None,
) -> Workflow | None:
    """Apply any provided fields to a workflow; None (with a log) when it is missing."""
    with self.Session() as session:
        workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
        if workflow is None:
            LOG.error("Workflow not found, nothing to update", workflow_id=workflow_id)
            return None
        # NOTE(review): truthiness checks mean empty strings / empty dicts are
        # ignored, so a field cannot be cleared through this method — confirm
        # callers never rely on clearing.
        if title:
            workflow.title = title
        if description:
            workflow.description = description
        if workflow_definition:
            workflow.workflow_definition = workflow_definition
        session.commit()
        session.refresh(workflow)
        return convert_to_workflow(workflow, self.debug_enabled)
|
||||
|
||||
async def create_workflow_run(
    self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
) -> WorkflowRun:
    """Insert a workflow run in the initial "created" status and return it."""
    try:
        with self.Session() as session:
            new_run = WorkflowRunModel(
                workflow_id=workflow_id,
                proxy_location=proxy_location,
                status="created",
                webhook_callback_url=webhook_callback_url,
            )
            session.add(new_run)
            session.commit()
            session.refresh(new_run)
            return convert_to_workflow_run(new_run)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def update_workflow_run(self, workflow_run_id: str, status: WorkflowRunStatus) -> WorkflowRun | None:
    """Set a workflow run's status; return the updated run, or None when missing."""
    with self.Session() as session:
        run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
        if run is None:
            LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
            return None
        run.status = status
        session.commit()
        session.refresh(run)
        return convert_to_workflow_run(run)
|
||||
|
||||
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
    """Fetch a workflow run by ID; None when absent."""
    try:
        with self.Session() as session:
            run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
            if run is None:
                return None
            return convert_to_workflow_run(run)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
    """Return every run recorded for a workflow (empty list when none)."""
    try:
        with self.Session() as session:
            rows = session.query(WorkflowRunModel).filter_by(workflow_id=workflow_id).all()
            return [convert_to_workflow_run(row) for row in rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def create_workflow_parameter(
    self,
    workflow_id: str,
    workflow_parameter_type: WorkflowParameterType,
    key: str,
    default_value: Any,
    description: str | None = None,
) -> WorkflowParameter:
    """Insert a workflow parameter definition and return it."""
    try:
        with self.Session() as session:
            new_parameter = WorkflowParameterModel(
                workflow_id=workflow_id,
                workflow_parameter_type=workflow_parameter_type,
                key=key,
                default_value=default_value,
                description=description,
            )
            session.add(new_parameter)
            session.commit()
            session.refresh(new_parameter)
            return convert_to_workflow_parameter(new_parameter, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def create_aws_secret_parameter(
    self,
    workflow_id: str,
    key: str,
    aws_key: str,
    description: str | None = None,
) -> AWSSecretParameter:
    """Insert an AWS-secret-backed workflow parameter and return it."""
    with self.Session() as session:
        new_parameter = AWSSecretParameterModel(
            workflow_id=workflow_id,
            key=key,
            aws_key=aws_key,
            description=description,
        )
        session.add(new_parameter)
        session.commit()
        session.refresh(new_parameter)
        return convert_to_aws_secret_parameter(new_parameter)
|
||||
|
||||
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
    """Return all parameter definitions of a workflow.

    Fix: forward self.debug_enabled to convert_to_workflow_parameter for
    consistency with get_workflow_parameter / create_workflow_parameter,
    which both pass the flag (this only affects debug logging).
    """
    try:
        with self.Session() as session:
            workflow_parameters = session.query(WorkflowParameterModel).filter_by(workflow_id=workflow_id).all()
            return [convert_to_workflow_parameter(parameter, self.debug_enabled) for parameter in workflow_parameters]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
    """Fetch one workflow parameter definition by ID; None when absent."""
    try:
        with self.Session() as session:
            row = (
                session.query(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id).first()
            )
            return convert_to_workflow_parameter(row, self.debug_enabled) if row else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def create_workflow_run_parameter(
    self, workflow_run_id: str, workflow_parameter_id: str, value: Any
) -> WorkflowRunParameter:
    """Record the concrete value bound to a workflow parameter for one run.

    Raises WorkflowParameterNotFound when the referenced parameter
    definition does not exist.
    """
    try:
        with self.Session() as session:
            run_parameter = WorkflowRunParameterModel(
                workflow_run_id=workflow_run_id,
                workflow_parameter_id=workflow_parameter_id,
                value=value,
            )
            session.add(run_parameter)
            session.commit()
            session.refresh(run_parameter)
            definition = await self.get_workflow_parameter(workflow_parameter_id)
            if not definition:
                raise WorkflowParameterNotFound(workflow_parameter_id)
            return convert_to_workflow_run_parameter(run_parameter, definition, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_workflow_run_parameters(
    self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
    """Return (definition, bound value) pairs for every parameter of a run.

    Raises WorkflowParameterNotFound if a run parameter references a
    missing definition.
    """
    try:
        with self.Session() as session:
            run_parameters = (
                session.query(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id).all()
            )
            pairs = []
            # NOTE: one lookup query per run parameter (N+1 pattern)
            for run_parameter in run_parameters:
                definition = await self.get_workflow_parameter(run_parameter.workflow_parameter_id)
                if not definition:
                    raise WorkflowParameterNotFound(
                        workflow_parameter_id=run_parameter.workflow_parameter_id
                    )
                converted = convert_to_workflow_run_parameter(
                    run_parameter, definition, self.debug_enabled
                )
                pairs.append((definition, converted))
            return pairs
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
    """Return the most recently created task of a workflow run, or None."""
    try:
        with self.Session() as session:
            newest = (
                session.query(TaskModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .order_by(TaskModel.created_at.desc())
                .first()
            )
            if newest is None:
                return None
            return convert_to_task(newest, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
    """Return all tasks of a workflow run, oldest first."""
    try:
        with self.Session() as session:
            rows = (
                session.query(TaskModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .order_by(TaskModel.created_at)
                .all()
            )
            return [convert_to_task(row, debug_enabled=self.debug_enabled) for row in rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
    """Hard-delete every artifact row belonging to one task of an organization."""
    with self.Session() as session:
        delete_stmt = delete(ArtifactModel).where(
            and_(
                ArtifactModel.organization_id == organization_id,
                ArtifactModel.task_id == task_id,
            )
        )
        session.execute(delete_stmt)
        session.commit()
|
||||
|
||||
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
    """Hard-delete every step row belonging to one task of an organization."""
    with self.Session() as session:
        # delete steps by filtering organization_id and task_id
        # (the original comment said "artifacts" — copy-paste from
        # delete_task_artifacts; this statement targets StepModel)
        stmt = delete(StepModel).where(
            and_(
                StepModel.organization_id == organization_id,
                StepModel.task_id == task_id,
            )
        )
        session.execute(stmt)
        session.commit()
|
||||
15
skyvern/forge/sdk/db/enums.py
Normal file
15
skyvern/forge/sdk/db/enums.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class OrganizationAuthTokenType(StrEnum):
    """Kinds of auth tokens an organization can hold."""

    # token used to authenticate API requests
    api = "api"
|
||||
|
||||
|
||||
class ScheduleRuleUnit(StrEnum):
    """Time units available for schedule rules."""

    # No support for scheduling every second
    minute = "minute"
    hour = "hour"
    day = "day"
    week = "week"
    month = "month"
    year = "year"
|
||||
2
skyvern/forge/sdk/db/exceptions.py
Normal file
2
skyvern/forge/sdk/db/exceptions.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class NotFoundError(Exception):
    """Raised when a requested database record does not exist."""
|
||||
136
skyvern/forge/sdk/db/id.py
Normal file
136
skyvern/forge/sdk/db/id.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import hashlib
|
||||
import itertools
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import time
|
||||
|
||||
# 6/20/2022 12AM — custom epoch so the 32-bit second timestamp lasts longer
BASE_EPOCH = 1655683200
VERSION = 0

# Number of bits (32 + 21 + 10 + 1 = 64)
TIMESTAMP_BITS = 32
WORKER_ID_BITS = 21
SEQUENCE_BITS = 10
VERSION_BITS = 1

# Bit shifts (left)
TIMESTAMP_SHIFT = 32
WORKER_ID_SHIFT = 11
SEQUENCE_SHIFT = 1
VERSION_SHIFT = 0

SEQUENCE_MAX = (2**SEQUENCE_BITS) - 1
# lazily seeded with a random offset on first use
_sequence_start = None
SEQUENCE_COUNTER = itertools.count()
# cached per-process worker identity hash, computed on first use
_worker_hash = None

# prefix
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
ORG_PREFIX = "o"
TASK_PREFIX = "tsk"
USER_PREFIX = "u"
STEP_PREFIX = "stp"
ARTIFACT_PREFIX = "a"
WORKFLOW_PREFIX = "w"
WORKFLOW_RUN_PREFIX = "wr"
WORKFLOW_PARAMETER_PREFIX = "wp"
AWS_SECRET_PARAMETER_PREFIX = "asp"
|
||||
|
||||
|
||||
def generate_workflow_id() -> str:
    """Return a new workflow ID of the form "w_<64-bit int>"."""
    return f"{WORKFLOW_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_workflow_run_id() -> str:
    """Return a new workflow run ID of the form "wr_<64-bit int>"."""
    return f"{WORKFLOW_RUN_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_aws_secret_parameter_id() -> str:
    """Return a new AWS secret parameter ID of the form "asp_<64-bit int>"."""
    return f"{AWS_SECRET_PARAMETER_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_workflow_parameter_id() -> str:
    """Return a new workflow parameter ID of the form "wp_<64-bit int>"."""
    return f"{WORKFLOW_PARAMETER_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_organization_auth_token_id() -> str:
    """Return a new org auth token ID of the form "oat_<64-bit int>"."""
    return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_org_id() -> str:
    """Return a new organization ID of the form "o_<64-bit int>"."""
    return f"{ORG_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_task_id() -> str:
    """Return a new task ID of the form "tsk_<64-bit int>"."""
    return f"{TASK_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_step_id() -> str:
    """Return a new step ID of the form "stp_<64-bit int>"."""
    return f"{STEP_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_artifact_id() -> str:
    """Return a new artifact ID of the form "a_<64-bit int>"."""
    return f"{ARTIFACT_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_user_id() -> str:
    """Return a new user ID of the form "u_<64-bit int>"."""
    return f"{USER_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_id() -> int:
    """Generate a 64-bit int ID composed of [timestamp|worker|sequence|version] bit fields."""
    elapsed = current_time() - BASE_EPOCH
    sequence = _increment_and_get_sequence()

    # OR the four masked-and-shifted fields together into one 64-bit value
    return (
        _mask_shift(elapsed, TIMESTAMP_BITS, TIMESTAMP_SHIFT)
        | _mask_shift(_get_worker_hash(), WORKER_ID_BITS, WORKER_ID_SHIFT)
        | _mask_shift(sequence, SEQUENCE_BITS, SEQUENCE_SHIFT)
        | _mask_shift(VERSION, VERSION_BITS, VERSION_SHIFT)
    )
|
||||
|
||||
|
||||
def _increment_and_get_sequence() -> int:
    """Return the next sequence value in [0, SEQUENCE_MAX].

    Fix (off-by-one): the modulus was SEQUENCE_MAX (2**SEQUENCE_BITS - 1),
    which both made the top value unreachable and shortened the cycle by
    one; the full bit range is 2**SEQUENCE_BITS == SEQUENCE_MAX + 1.
    """
    global _sequence_start
    if _sequence_start is None:
        # random start reduces collisions across process restarts within a second
        _sequence_start = random.randint(0, SEQUENCE_MAX)

    return (_sequence_start + next(SEQUENCE_COUNTER)) % (SEQUENCE_MAX + 1)
|
||||
|
||||
|
||||
def current_time() -> int:
    """Whole seconds elapsed since the Unix epoch."""
    now = time.time()
    return int(now)
|
||||
|
||||
|
||||
def current_time_ms() -> int:
    """Whole milliseconds elapsed since the Unix epoch."""
    now_ms = time.time() * 1000
    return int(now_ms)
|
||||
|
||||
|
||||
def _mask_shift(value: int, mask_bits: int, shift_bits: int) -> int:
|
||||
return (value & ((2**mask_bits) - 1)) << shift_bits
|
||||
|
||||
|
||||
def _get_worker_hash() -> int:
    """Compute-once accessor for this process's worker identity hash."""
    global _worker_hash
    if _worker_hash is not None:
        return _worker_hash
    _worker_hash = _generate_worker_hash()
    return _worker_hash
|
||||
|
||||
|
||||
def _generate_worker_hash() -> int:
|
||||
worker_identity = f"{platform.node()}:{os.getpid()}"
|
||||
return int(hashlib.md5(worker_identity.encode()).hexdigest()[-15:], 16)
|
||||
172
skyvern/forge/sdk/db/models.py
Normal file
172
skyvern/forge/sdk/db/models.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, UnicodeText
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.id import (
|
||||
generate_artifact_id,
|
||||
generate_aws_secret_parameter_id,
|
||||
generate_org_id,
|
||||
generate_organization_auth_token_id,
|
||||
generate_step_id,
|
||||
generate_task_id,
|
||||
generate_workflow_id,
|
||||
generate_workflow_parameter_id,
|
||||
generate_workflow_run_id,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base shared by all Skyvern ORM models."""

    pass
|
||||
|
||||
|
||||
class TaskModel(Base):
    """ORM row for a single agent task, optionally belonging to a workflow run."""

    __tablename__ = "tasks"

    task_id = Column(String, primary_key=True, index=True, default=generate_task_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    status = Column(String)
    webhook_callback_url = Column(String)
    url = Column(String)
    navigation_goal = Column(String)
    data_extraction_goal = Column(String)
    navigation_payload = Column(JSON)
    extracted_information = Column(JSON)
    failure_reason = Column(String)
    proxy_location = Column(Enum(ProxyLocation))
    extracted_information_schema = Column(JSON)
    # set only when the task is part of a workflow run
    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
    order = Column(Integer, nullable=True)
    retry = Column(Integer, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class StepModel(Base):
    """ORM row for one execution step of a task, including token usage and cost."""

    __tablename__ = "steps"

    step_id = Column(String, primary_key=True, index=True, default=generate_step_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    task_id = Column(String, ForeignKey("tasks.task_id"))
    status = Column(String)
    output = Column(JSON)
    order = Column(Integer)
    is_last = Column(Boolean, default=False)
    retry_index = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # LLM accounting for this step
    input_token_count = Column(Integer, default=0)
    output_token_count = Column(Integer, default=0)
    step_cost = Column(Numeric, default=0)
|
||||
|
||||
|
||||
class OrganizationModel(Base):
    """ORM row for an organization (tenant)."""

    __tablename__ = "organizations"

    organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
    organization_name = Column(String, nullable=False)
    webhook_callback_url = Column(UnicodeText)
    max_steps_per_run = Column(Integer)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # BUGFIX: onupdate was `datetime.datetime` (the class itself) — SQLAlchemy
    # would invoke it with no arguments on UPDATE and raise; use utcnow like
    # every other model in this module.
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class OrganizationAuthTokenModel(Base):
    """ORM row for an organization-scoped auth token (soft-invalidated via `valid`)."""

    __tablename__ = "organization_auth_tokens"

    id = Column(
        String,
        primary_key=True,
        index=True,
        default=generate_organization_auth_token_id,
    )

    organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
    token_type = Column(Enum(OrganizationAuthTokenType), nullable=False)
    token = Column(String, index=True, nullable=False)
    valid = Column(Boolean, nullable=False, default=True)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # BUGFIX: onupdate was `datetime.datetime` (the class itself) — SQLAlchemy
    # would invoke it with no arguments on UPDATE and raise; use utcnow like
    # every other model in this module.
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class ArtifactModel(Base):
    """ORM row for an artifact (screenshot, HTML, LLM I/O, …) produced by a task step."""

    __tablename__ = "artifacts"

    artifact_id = Column(String, primary_key=True, index=True, default=generate_artifact_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    task_id = Column(String, ForeignKey("tasks.task_id"))
    step_id = Column(String, ForeignKey("steps.step_id"))
    artifact_type = Column(String)
    # storage location of the artifact payload
    uri = Column(String)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class WorkflowModel(Base):
    """ORM row for a workflow definition owned by an organization."""

    __tablename__ = "workflows"

    workflow_id = Column(String, primary_key=True, index=True, default=generate_workflow_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    title = Column(String, nullable=False)
    description = Column(String, nullable=True)
    workflow_definition = Column(JSON, nullable=False)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # soft-delete marker; no code in this module sets it yet
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class WorkflowRunModel(Base):
    """ORM row for a single execution of a workflow."""

    __tablename__ = "workflow_runs"

    workflow_run_id = Column(String, primary_key=True, index=True, default=generate_workflow_run_id)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=False)
    status = Column(String, nullable=False)
    proxy_location = Column(Enum(ProxyLocation))
    webhook_callback_url = Column(String)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class WorkflowParameterModel(Base):
    """ORM row for a parameter declared by a workflow definition."""

    __tablename__ = "workflow_parameters"

    workflow_parameter_id = Column(String, primary_key=True, index=True, default=generate_workflow_parameter_id)
    workflow_parameter_type = Column(String, nullable=False)
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    # stored as a string regardless of the declared parameter type
    default_value = Column(String, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class AWSSecretParameterModel(Base):
    """ORM row for a workflow parameter whose value is resolved from an AWS secret."""

    __tablename__ = "aws_secret_parameters"

    aws_secret_parameter_id = Column(String, primary_key=True, index=True, default=generate_aws_secret_parameter_id)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    # identifier of the secret in AWS; the secret value itself is never stored
    aws_key = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class WorkflowRunParameterModel(Base):
    """ORM row binding a concrete value to a workflow parameter for one run.

    Composite primary key (workflow_run_id, workflow_parameter_id): at most
    one value per parameter per run.
    """

    __tablename__ = "workflow_run_parameters"

    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
    workflow_parameter_id = Column(
        String, ForeignKey("workflow_parameters.workflow_parameter_id"), primary_key=True, index=True
    )
    # Can be bool | int | float | str | dict | list depending on the workflow parameter type
    value = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
220
skyvern/forge/sdk/db/utils.py
Normal file
220
skyvern/forge/sdk/db/utils.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import json
|
||||
import typing
|
||||
|
||||
import pydantic.json
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ArtifactModel,
|
||||
AWSSecretParameterModel,
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDefinition,
|
||||
WorkflowRun,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@typing.no_type_check
def _custom_json_serializer(*args, **kwargs) -> str:
    """Encode JSON the same way pydantic does (datetimes, enums, models, …)."""
    fallback_encoder = pydantic.json.pydantic_encoder
    return json.dumps(*args, default=fallback_encoder, **kwargs)
|
||||
|
||||
|
||||
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
    """Map a TaskModel ORM row onto the Task API schema, field for field."""
    if debug_enabled:
        LOG.debug("Converting TaskModel to Task", task_id=task_obj.task_id)
    task = Task(
        task_id=task_obj.task_id,
        status=TaskStatus(task_obj.status),
        created_at=task_obj.created_at,
        modified_at=task_obj.modified_at,
        url=task_obj.url,
        webhook_callback_url=task_obj.webhook_callback_url,
        navigation_goal=task_obj.navigation_goal,
        data_extraction_goal=task_obj.data_extraction_goal,
        navigation_payload=task_obj.navigation_payload,
        extracted_information=task_obj.extracted_information,
        failure_reason=task_obj.failure_reason,
        organization_id=task_obj.organization_id,
        # proxy_location column may be NULL; only wrap when present
        proxy_location=ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None,
        extracted_information_schema=task_obj.extracted_information_schema,
        workflow_run_id=task_obj.workflow_run_id,
        order=task_obj.order,
        retry=task_obj.retry,
    )
    return task
|
||||
|
||||
|
||||
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
    """Map a StepModel database row onto the Step API schema object."""
    if debug_enabled:
        LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)
    # Coerce the stored status string into the StepStatus enum.
    step_status = StepStatus(step_model.status)
    return Step(
        task_id=step_model.task_id,
        step_id=step_model.step_id,
        created_at=step_model.created_at,
        modified_at=step_model.modified_at,
        status=step_status,
        output=step_model.output,
        order=step_model.order,
        is_last=step_model.is_last,
        retry_index=step_model.retry_index,
        organization_id=step_model.organization_id,
        input_token_count=step_model.input_token_count,
        output_token_count=step_model.output_token_count,
        step_cost=step_model.step_cost,
    )
|
||||
|
||||
|
||||
def convert_to_organization(org_model: OrganizationModel) -> Organization:
    """Map an OrganizationModel database row onto the Organization schema object."""
    return Organization(
        organization_id=org_model.organization_id,
        organization_name=org_model.organization_name,
        webhook_callback_url=org_model.webhook_callback_url,
        max_steps_per_run=org_model.max_steps_per_run,
        created_at=org_model.created_at,
        modified_at=org_model.modified_at,
    )
|
||||
|
||||
|
||||
def convert_to_organization_auth_token(org_auth_token: OrganizationAuthTokenModel) -> OrganizationAuthToken:
    """Map an OrganizationAuthTokenModel database row onto the OrganizationAuthToken schema object."""
    return OrganizationAuthToken(
        id=org_auth_token.id,
        organization_id=org_auth_token.organization_id,
        # Coerce the stored token type string into its enum.
        token_type=OrganizationAuthTokenType(org_auth_token.token_type),
        token=org_auth_token.token,
        valid=org_auth_token.valid,
        created_at=org_auth_token.created_at,
        modified_at=org_auth_token.modified_at,
    )
|
||||
|
||||
|
||||
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
    """Map an ArtifactModel database row onto the Artifact schema object."""
    if debug_enabled:
        LOG.debug("Converting ArtifactModel to Artifact", artifact_id=artifact_model.artifact_id)

    # Enum lookup is by member name, hence the upper-casing of the stored value.
    artifact_type = ArtifactType[artifact_model.artifact_type.upper()]
    return Artifact(
        artifact_id=artifact_model.artifact_id,
        artifact_type=artifact_type,
        uri=artifact_model.uri,
        task_id=artifact_model.task_id,
        step_id=artifact_model.step_id,
        created_at=artifact_model.created_at,
        modified_at=artifact_model.modified_at,
        organization_id=artifact_model.organization_id,
    )
|
||||
|
||||
|
||||
def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = False) -> Workflow:
    """Map a WorkflowModel database row onto the Workflow schema object."""
    if debug_enabled:
        LOG.debug("Converting WorkflowModel to Workflow", workflow_id=workflow_model.workflow_id)

    # Re-validate the stored definition payload into its pydantic model.
    definition = WorkflowDefinition.model_validate(workflow_model.workflow_definition)
    return Workflow(
        workflow_id=workflow_model.workflow_id,
        organization_id=workflow_model.organization_id,
        title=workflow_model.title,
        description=workflow_model.description,
        workflow_definition=definition,
        created_at=workflow_model.created_at,
        modified_at=workflow_model.modified_at,
        deleted_at=workflow_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: bool = False) -> WorkflowRun:
    """Map a WorkflowRunModel database row onto the WorkflowRun schema object."""
    if debug_enabled:
        LOG.debug("Converting WorkflowRunModel to WorkflowRun", workflow_run_id=workflow_run_model.workflow_run_id)

    # Status is stored as the enum member name; proxy location only wrapped when present.
    run_status = WorkflowRunStatus[workflow_run_model.status]
    run_proxy_location = (
        ProxyLocation(workflow_run_model.proxy_location) if workflow_run_model.proxy_location else None
    )
    return WorkflowRun(
        workflow_run_id=workflow_run_model.workflow_run_id,
        workflow_id=workflow_run_model.workflow_id,
        status=run_status,
        proxy_location=run_proxy_location,
        webhook_callback_url=workflow_run_model.webhook_callback_url,
        created_at=workflow_run_model.created_at,
        modified_at=workflow_run_model.modified_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_parameter(
    workflow_parameter_model: WorkflowParameterModel, debug_enabled: bool = False
) -> WorkflowParameter:
    """Map a WorkflowParameterModel row onto the WorkflowParameter schema, coercing its default value."""
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowParameterModel to WorkflowParameter",
            workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
        )

    # The stored type (upper-cased to the enum member name) drives the value coercion below.
    param_type = WorkflowParameterType[workflow_parameter_model.workflow_parameter_type.upper()]
    coerced_default = param_type.convert_value(workflow_parameter_model.default_value)
    return WorkflowParameter(
        workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
        workflow_parameter_type=param_type,
        workflow_id=workflow_parameter_model.workflow_id,
        default_value=coerced_default,
        key=workflow_parameter_model.key,
        description=workflow_parameter_model.description,
        created_at=workflow_parameter_model.created_at,
        modified_at=workflow_parameter_model.modified_at,
        deleted_at=workflow_parameter_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_aws_secret_parameter(
    aws_secret_parameter_model: AWSSecretParameterModel, debug_enabled: bool = False
) -> AWSSecretParameter:
    """
    Map an AWSSecretParameterModel database row onto the AWSSecretParameter schema object.

    :param aws_secret_parameter_model: the database row to convert.
    :param debug_enabled: when True, emit a debug log for the conversion.
    """
    if debug_enabled:
        LOG.debug(
            "Converting AWSSecretParameterModel to AWSSecretParameter",
            # Log the same identifier field the constructor below reads
            # (the previous `.id` attribute did not match it).
            aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
        )

    return AWSSecretParameter(
        aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
        workflow_id=aws_secret_parameter_model.workflow_id,
        key=aws_secret_parameter_model.key,
        description=aws_secret_parameter_model.description,
        aws_key=aws_secret_parameter_model.aws_key,
        created_at=aws_secret_parameter_model.created_at,
        modified_at=aws_secret_parameter_model.modified_at,
        deleted_at=aws_secret_parameter_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_run_parameter(
    workflow_run_parameter_model: WorkflowRunParameterModel,
    workflow_parameter: WorkflowParameter,
    debug_enabled: bool = False,
) -> WorkflowRunParameter:
    """
    Map a WorkflowRunParameterModel row onto the WorkflowRunParameter schema object.

    The companion WorkflowParameter supplies the type used to coerce the stored value.
    """
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowRunParameterModel to WorkflowRunParameter",
            workflow_run_id=workflow_run_parameter_model.workflow_run_id,
            workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
        )

    coerced_value = workflow_parameter.workflow_parameter_type.convert_value(workflow_run_parameter_model.value)
    return WorkflowRunParameter(
        workflow_run_id=workflow_run_parameter_model.workflow_run_id,
        workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
        value=coerced_value,
        created_at=workflow_run_parameter_model.created_at,
    )
|
||||
0
skyvern/forge/sdk/executor/__init__.py
Normal file
0
skyvern/forge/sdk/executor/__init__.py
Normal file
85
skyvern/forge/sdk/executor/async_executor.py
Normal file
85
skyvern/forge/sdk/executor/async_executor.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import abc
|
||||
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.models import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
|
||||
|
||||
class AsyncExecutor(abc.ABC):
    """Interface for scheduling asynchronous execution of tasks and workflow runs."""

    @abc.abstractmethod
    async def execute_task(
        self,
        background_tasks: BackgroundTasks,
        task: Task,
        organization: Organization,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """
        Schedule execution of *task* on behalf of *organization*.

        :param background_tasks: FastAPI background-task queue of the current request.
        :param task: the task to execute.
        :param organization: owner of the task.
        :param max_steps_override: optional per-run override of the step limit.
        :param api_key: optional API key forwarded to the execution backend.
        """
        pass

    @abc.abstractmethod
    async def execute_workflow(
        self,
        background_tasks: BackgroundTasks,
        organization: Organization,
        workflow_id: str,
        workflow_run_id: str,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """
        Schedule execution of the workflow run identified by *workflow_run_id*.

        :param background_tasks: FastAPI background-task queue of the current request.
        :param organization: owner of the workflow.
        :param workflow_id: id of the workflow definition.
        :param workflow_run_id: id of the run to execute.
        :param max_steps_override: optional per-run override of the step limit.
        :param api_key: optional API key forwarded to the execution backend.
        """
        pass
|
||||
|
||||
|
||||
class BackgroundTaskExecutor(AsyncExecutor):
    """Executor that runs work in-process via FastAPI's BackgroundTasks queue."""

    async def execute_task(
        self,
        background_tasks: BackgroundTasks,
        task: Task,
        organization: Organization,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """Create the first step, mark the task running, and queue agent execution."""
        first_step = await app.DATABASE.create_step(
            task.task_id,
            order=0,
            retry_index=0,
            organization_id=organization.organization_id,
        )

        task = await app.DATABASE.update_task(
            task.task_id,
            TaskStatus.running,
            organization_id=organization.organization_id,
        )

        # Record the run's identifiers on the request-scoped context so the
        # logging processors can attach them to every event.
        ctx: SkyvernContext = skyvern_context.ensure_context()
        ctx.task_id = task.task_id
        ctx.organization_id = organization.organization_id
        ctx.max_steps_override = max_steps_override

        background_tasks.add_task(
            app.agent.execute_step,
            organization,
            task,
            first_step,
            api_key,
        )

    async def execute_workflow(
        self,
        background_tasks: BackgroundTasks,
        organization: Organization,
        workflow_id: str,
        workflow_run_id: str,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """Queue a workflow run; only the run id and api key are forwarded to the service."""
        background_tasks.add_task(
            app.WORKFLOW_SERVICE.execute_workflow,
            workflow_run_id=workflow_run_id,
            api_key=api_key,
        )
|
||||
13
skyvern/forge/sdk/executor/factory.py
Normal file
13
skyvern/forge/sdk/executor/factory.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from skyvern.forge.sdk.executor.async_executor import AsyncExecutor, BackgroundTaskExecutor
|
||||
|
||||
|
||||
class AsyncExecutorFactory:
    """Process-wide holder for the AsyncExecutor implementation in use."""

    # Defaults to the in-process FastAPI background-task executor.
    __instance: AsyncExecutor = BackgroundTaskExecutor()

    @staticmethod
    def set_executor(executor: AsyncExecutor) -> None:
        """Replace the process-wide executor (e.g. with a different scheduling backend)."""
        AsyncExecutorFactory.__instance = executor

    @staticmethod
    def get_executor() -> AsyncExecutor:
        """Return the currently configured executor."""
        return AsyncExecutorFactory.__instance
|
||||
90
skyvern/forge/sdk/forge_log.py
Normal file
90
skyvern/forge/sdk/forge_log.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import logging
|
||||
|
||||
import structlog
|
||||
from structlog.typing import EventDict
|
||||
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
|
||||
def add_kv_pairs_to_msg(logger: logging.Logger, method_name: str, event_dict: EventDict) -> EventDict:
    """
    Structlog processor that enriches events with request context and, for the
    main log levels, folds the remaining key-value pairs into the 'msg' field.
    """
    # Copy per-request identifiers from the skyvern context, skipping unset ones.
    context = skyvern_context.current()
    if context:
        for field in ("request_id", "organization_id", "task_id", "workflow_id", "workflow_run_id"):
            value = getattr(context, field)
            if value:
                event_dict[field] = value

    # Tag every event with the running environment.
    event_dict["env"] = SettingsManager.get_settings().ENV

    if method_name not in ["info", "warning", "error", "critical", "exception"]:
        # Only modify the log for these log levels
        return event_dict

    # Fold everything except the message itself and the standard fields into 'msg'.
    message = event_dict.get("msg", "")
    extras = {key: value for key, value in event_dict.items() if key not in ["msg", "timestamp", "level"]}
    if extras:
        joined = ", ".join(f"{key}={value}" for key, value in extras.items())
        message += f" | {joined}"

    event_dict["msg"] = message
    return event_dict
|
||||
|
||||
|
||||
def setup_logger() -> None:
    """
    Configure structlog for the process: JSON or console rendering depending on
    settings, with callsite/context enrichment in JSON mode. Also silences the
    uvicorn loggers so structlog output is the single source of logs.
    """
    settings = SettingsManager.get_settings()

    if settings.JSON_LOGGING:
        renderer = structlog.processors.JSONRenderer()
    else:
        renderer = structlog.dev.ConsoleRenderer()

    processors: list = [
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.format_exc_info,
    ]
    if settings.JSON_LOGGING:
        # Extra enrichment only makes sense for machine-readable output.
        processors.extend(
            [
                structlog.processors.EventRenamer("msg"),
                add_kv_pairs_to_msg,
                structlog.processors.CallsiteParameterAdder(
                    {
                        structlog.processors.CallsiteParameter.PATHNAME,
                        structlog.processors.CallsiteParameter.FILENAME,
                        structlog.processors.CallsiteParameter.MODULE,
                        structlog.processors.CallsiteParameter.FUNC_NAME,
                        structlog.processors.CallsiteParameter.LINENO,
                    }
                ),
            ]
        )
    processors.append(renderer)

    structlog.configure(processors=processors)

    # Disable uvicorn's own loggers to avoid duplicate/unstructured output.
    for logger_name in ("uvicorn.error", "uvicorn.access"):
        logging.getLogger(logger_name).disabled = True
|
||||
137
skyvern/forge/sdk/models.py
Normal file
137
skyvern/forge/sdk/models.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
from skyvern.webeye.actions.models import AgentStepOutput
|
||||
|
||||
|
||||
class StepStatus(StrEnum):
|
||||
created = "created"
|
||||
running = "running"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
|
||||
def can_update_to(self, new_status: StepStatus) -> bool:
|
||||
allowed_transitions: dict[StepStatus, set[StepStatus]] = {
|
||||
StepStatus.created: {StepStatus.running},
|
||||
StepStatus.running: {StepStatus.completed, StepStatus.failed},
|
||||
StepStatus.failed: set(),
|
||||
StepStatus.completed: set(),
|
||||
}
|
||||
return new_status in allowed_transitions[self]
|
||||
|
||||
def requires_output(self) -> bool:
|
||||
status_requires_output = {StepStatus.completed}
|
||||
return self in status_requires_output
|
||||
|
||||
def cant_have_output(self) -> bool:
|
||||
status_cant_have_output = {StepStatus.created, StepStatus.running}
|
||||
return self in status_cant_have_output
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
status_is_terminal = {StepStatus.failed, StepStatus.completed}
|
||||
return self in status_is_terminal
|
||||
|
||||
|
||||
class Step(BaseModel):
    """A single execution step of a task, including its status and agent output."""

    created_at: datetime
    modified_at: datetime
    task_id: str
    step_id: str
    status: StepStatus
    output: AgentStepOutput | None = None
    order: int
    is_last: bool
    retry_index: int = 0
    organization_id: str | None = None
    input_token_count: int = 0
    output_token_count: int = 0
    step_cost: float = 0

    def validate_update(self, status: StepStatus | None, output: AgentStepOutput | None, is_last: bool | None) -> None:
        """Raise ValueError when the proposed update violates step-state invariants."""
        old_status = self.status

        # Status may only move along the allowed transition graph.
        if status and not old_status.can_update_to(status):
            raise ValueError(f"invalid_status_transition({old_status},{status},{self.step_id})")

        # Output presence must agree with the target status.
        if status and status.requires_output() and output is None:
            raise ValueError(f"status_requires_output({status},{self.step_id})")
        if status and status.cant_have_output() and output is not None:
            raise ValueError(f"status_cant_have_output({status},{self.step_id})")

        # Output can only be set together with a status change, and never twice.
        if output is not None and status is None:
            raise ValueError(f"cant_set_output_without_updating_status({self.step_id})")
        if self.output is not None and output is not None:
            raise ValueError(f"cant_override_output({self.step_id})")

        # is_last can only be set (never cleared), and only on a terminal step.
        if is_last and not self.status.is_terminal():
            raise ValueError(f"is_last_but_status_not_terminal({self.status},{self.step_id})")
        if is_last is False:
            raise ValueError(f"cant_set_is_last_to_false({self.step_id})")

    def _has_successful_action(self, action_type: ActionType) -> bool:
        """True when this completed step's output contains a successful action of *action_type*."""
        if self.status != StepStatus.completed:
            return False
        # TODO (kerem): Remove this check once we have backfilled all the steps
        if self.output is None or self.output.actions_and_results is None:
            return False
        return any(
            action.action_type == action_type and any(result.success for result in results)
            for action, results in self.output.actions_and_results
        )

    def is_goal_achieved(self) -> bool:
        """True when the step completed with a successful COMPLETE action."""
        return self._has_successful_action(ActionType.COMPLETE)

    def is_terminated(self) -> bool:
        """True when the step completed with a successful TERMINATE action."""
        return self._has_successful_action(ActionType.TERMINATE)
|
||||
|
||||
|
||||
class Organization(BaseModel):
    """An organization (tenant) that owns tasks, steps, and workflows."""

    organization_id: str
    organization_name: str
    # Optional default webhook target for this organization's callbacks.
    webhook_callback_url: str | None = None
    # Optional cap on steps per run for this organization.
    max_steps_per_run: int | None = None

    created_at: datetime
    modified_at: datetime
|
||||
|
||||
|
||||
class OrganizationAuthToken(BaseModel):
    """An authentication token issued to an organization."""

    id: str
    organization_id: str
    # Kind of token, see OrganizationAuthTokenType.
    token_type: OrganizationAuthTokenType
    token: str
    # False once the token has been invalidated/revoked.
    valid: bool
    created_at: datetime
    modified_at: datetime
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
    """Decoded auth-token payload (`sub`/`exp` follow the standard JWT claim names)."""

    # Subject claim — presumably the token owner's id; confirm against the issuer.
    sub: str
    # Expiry claim as a unix timestamp.
    exp: int
|
||||
98
skyvern/forge/sdk/prompting.py
Normal file
98
skyvern/forge/sdk/prompting.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Relative to this file there is a prompt directory located at ../prompts
|
||||
In this directory there will be a techniques directory and a directory for each model - gpt-3.5-turbo gpt-4, llama-2-70B, code-llama-7B etc
|
||||
|
||||
Each directory will have jinja2 templates for the prompts.
|
||||
prompts in the model directories can use the techniques in the techniques directory.
|
||||
|
||||
Write the code I'd need to load and populate the templates.
|
||||
|
||||
I want the following functions:
|
||||
|
||||
class PromptEngine:
|
||||
|
||||
def __init__(self, model):
|
||||
pass
|
||||
|
||||
    def load_prompt(model, prompt_name, prompt_args) -> str:
|
||||
pass
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from difflib import get_close_matches
|
||||
from typing import Any, List
|
||||
|
||||
import structlog
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class PromptEngine:
    """
    Loads and renders Jinja2 prompt templates, selecting the closest-matching
    model directory under ../prompts.
    """

    def __init__(self, model: str):
        """
        Initialize the PromptEngine for *model*.

        Scans the prompts directory for per-model template directories
        (excluding the shared "techniques" directory) and picks the closest
        match to the requested model name.

        Args:
            model (str): The model to use for loading prompts.
        """
        self.model = model

        try:
            prompts_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts"))
            # Collect candidate model directories, skipping the shared techniques dir.
            candidates = []
            for entry in glob.glob(os.path.join(prompts_root, "*/")):
                if os.path.isdir(entry) and "techniques" not in entry:
                    candidates.append(os.path.basename(os.path.normpath(entry)))

            self.model = self.get_closest_match(self.model, candidates)

            self.env = Environment(loader=FileSystemLoader(prompts_root))
        except Exception:
            LOG.error("Error initializing PromptEngine.", model=model, exc_info=True)
            raise

    @staticmethod
    def get_closest_match(target: str, model_dirs: List[str]) -> str:
        """
        Find the closest match to *target* among *model_dirs*.

        Args:
            target (str): The target model.
            model_dirs (list): The list of available model directories.

        Returns:
            str: The closest match to the target.

        Raises:
            Exception: re-raised (after logging) when no candidate clears the
            similarity cutoff.
        """
        try:
            best = get_close_matches(target, model_dirs, n=1, cutoff=0.1)
            return best[0]
        except Exception:
            LOG.error("Failed to get closest match.", target=target, model_dirs=model_dirs, exc_info=True)
            raise

    def load_prompt(self, template: str, **kwargs: Any) -> str:
        """
        Load and render the named template with the given arguments.

        Args:
            template (str): The name of the template to load.
            **kwargs: The arguments to populate the template with.

        Returns:
            str: The populated template.
        """
        try:
            # Rebind so the error log below shows the resolved per-model path.
            template = os.path.join(self.model, template)
            compiled = self.env.get_template(f"{template}.j2")
            return compiled.render(**kwargs)
        except Exception:
            LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
            raise
|
||||
0
skyvern/forge/sdk/routes/__init__.py
Normal file
0
skyvern/forge/sdk/routes/__init__.py
Normal file
397
skyvern/forge/sdk/routes/agent_protocol.py
Normal file
397
skyvern/forge/sdk/routes/agent_protocol.py
Normal file
@@ -0,0 +1,397 @@
|
||||
from typing import Annotated, Any
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, status
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.exceptions import StepNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.models import Organization, Step
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.services import org_auth_service
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
RunWorkflowResponse,
|
||||
WorkflowRequestBody,
|
||||
WorkflowRunStatusResponse,
|
||||
)
|
||||
|
||||
base_router = APIRouter()
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@base_router.post("/webhook", tags=["server"])
async def webhook(
    request: Request,
    x_skyvern_signature: Annotated[str | None, Header()] = None,
    x_skyvern_timestamp: Annotated[str | None, Header()] = None,
) -> Response:
    """Receive a webhook, recompute its signature, and log whether it validates."""
    payload = await request.body()

    # Both headers are required to validate the payload.
    if not (x_skyvern_signature and x_skyvern_timestamp):
        LOG.error(
            "Webhook signature or timestamp missing",
            x_skyvern_signature=x_skyvern_signature,
            x_skyvern_timestamp=x_skyvern_timestamp,
            payload=payload,
        )
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing webhook signature or timestamp")

    expected = generate_skyvern_signature(
        payload.decode("utf-8"),
        SettingsManager.get_settings().SKYVERN_API_KEY,
    )

    LOG.info(
        "Webhook received",
        x_skyvern_signature=x_skyvern_signature,
        x_skyvern_timestamp=x_skyvern_timestamp,
        payload=payload,
        generated_signature=expected,
        valid_signature=x_skyvern_signature == expected,
    )
    return Response(content="webhook validation", status_code=200)
|
||||
|
||||
|
||||
@base_router.get("/heartbeat", tags=["server"])
async def check_server_status() -> Response:
    """
    Check if the server is running.

    Liveness endpoint: always responds 200 with a fixed body.
    """
    return Response(content="Server is running.", status_code=200)
|
||||
|
||||
|
||||
@base_router.post("/tasks", tags=["agent"], response_model=CreateTaskResponse)
async def create_agent_task(
    background_tasks: BackgroundTasks,
    request: Request,
    task: TaskRequest,
    current_org: Organization = Depends(org_auth_service.get_current_org),
    x_api_key: Annotated[str | None, Header()] = None,
    x_max_steps_override: Annotated[int | None, Header()] = None,
) -> CreateTaskResponse:
    """Persist a new task for the calling organization and schedule its execution."""
    agent = request["agent"]

    new_task = await agent.create_task(task, current_org.organization_id)
    if x_max_steps_override:
        LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
    # Hand the task off to the configured async executor; execution happens
    # after this response is returned.
    await app.ASYNC_EXECUTOR.execute_task(
        background_tasks=background_tasks,
        task=new_task,
        organization=current_org,
        max_steps_override=x_max_steps_override,
        api_key=x_api_key,
    )
    return CreateTaskResponse(task_id=new_task.task_id)
|
||||
|
||||
|
||||
@base_router.post(
    "/tasks/{task_id}/steps/{step_id}",
    tags=["agent"],
    response_model=Step,
    summary="Executes a specific step",
)
@base_router.post(
    "/tasks/{task_id}/steps/",
    tags=["agent"],
    response_model=Step,
    summary="Executes the next step",
)
async def execute_agent_task_step(
    request: Request,
    task_id: str,
    step_id: str | None = None,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Execute one step of a task: the named step when step_id is given,
    otherwise the task's latest step.

    :raises HTTPException: 404 when the task does not exist.
    :raises StepNotFound: when the requested (or latest) step does not exist.
    """
    agent = request["agent"]
    task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not task:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No task found with id {task_id}",
        )
    # An empty step request means that the agent should execute the next step for the task.
    if not step_id:
        step = await app.DATABASE.get_latest_step(task_id=task_id, organization_id=current_org.organization_id)
        if not step:
            raise StepNotFound(current_org.organization_id, task_id)
        LOG.info(
            "Executing latest step since no step_id was provided",
            task_id=task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )
        # NOTE: the previous second `if not step` branch here was unreachable
        # (StepNotFound is already raised above) and has been removed.
    else:
        step = await app.DATABASE.get_step(task_id, step_id, organization_id=current_org.organization_id)
        if not step:
            raise StepNotFound(current_org.organization_id, task_id, step_id)
        LOG.info(
            "Executing step",
            task_id=task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )
        # NOTE: same here — the duplicated not-found check was dead code.
    step, _, _ = await agent.execute_step(current_org, task, step)
    return Response(
        content=step.model_dump_json() if step else "",
        status_code=200,
        media_type="application/json",
    )
|
||||
|
||||
|
||||
@base_router.get("/tasks/{task_id}", response_model=TaskResponse)
async def get_task(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskResponse:
    """
    Get a single task, enriched with its latest screenshot/recording share
    links and (for failed tasks) an assembled failure reason.

    :raises HTTPException: 404 when the task does not exist for this org.
    """
    request["agent"]
    task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not task_obj:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task not found {task_id}",
        )

    # get latest step; without any step there is nothing extra to attach.
    latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=current_org.organization_id)
    if not latest_step:
        return task_obj.to_task_response()

    screenshot_url = None
    # todo (kerem): only access artifacts through the artifact manager instead of db
    screenshot_artifact = await app.DATABASE.get_latest_artifact(
        task_id=task_obj.task_id,
        step_id=latest_step.step_id,
        artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
        organization_id=current_org.organization_id,
    )
    if screenshot_artifact:
        screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)

    # Recording is task-wide, not tied to the latest step.
    recording_artifact = await app.DATABASE.get_latest_artifact(
        task_id=task_obj.task_id,
        artifact_types=[ArtifactType.RECORDING],
        organization_id=current_org.organization_id,
    )
    recording_url = None
    if recording_artifact:
        recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)

    # For failed tasks, combine the stored reasoning with any action exceptions
    # from the latest step into a single human-readable string.
    failure_reason = None
    if task_obj.status == TaskStatus.failed and (latest_step.output or task_obj.failure_reason):
        failure_reason = ""
        if task_obj.failure_reason:
            failure_reason += f"Reasoning: {task_obj.failure_reason or ''}"
            failure_reason += "\n"
        if latest_step.output and latest_step.output.action_results:
            failure_reason += "Exceptions: "
            failure_reason += str(
                [f"[{ar.exception_type}]: {ar.exception_message}" for ar in latest_step.output.action_results]
            )

    return task_obj.to_task_response(
        screenshot_url=screenshot_url,
        recording_url=recording_url,
        failure_reason=failure_reason,
    )
|
||||
|
||||
|
||||
@base_router.get("/internal/tasks/{task_id}", response_model=list[Task])
async def get_task_internal(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get a single task by id (internal variant: returns the raw Task model
    dump rather than the public TaskResponse shape).

    :param request:
    :param task_id: id of the task to fetch.
    :return: The task serialized as JSON.
    :raises HTTPException: 404 when the task does not exist for this org.
    """
    task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not task:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task not found {task_id}",
        )
    return ORJSONResponse(task.model_dump())
|
||||
|
||||
|
||||
@base_router.get("/tasks", tags=["agent"], response_model=list[Task])
async def get_agent_tasks(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get all tasks.
    :param request:
    :param page: Starting page, defaults to 1
    :param page_size: Page size, defaults to 10
    :return: List of tasks with pagination without steps populated. Steps can be populated by calling the
        get_agent_task endpoint.
    """
    request["agent"]
    page_of_tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
    serialized = [item.to_task_response().model_dump() for item in page_of_tasks]
    return ORJSONResponse(serialized)
|
||||
|
||||
|
||||
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
async def get_agent_tasks_internal(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get all tasks (internal variant: returns raw Task model dumps rather than
    the public TaskResponse shape).
    :param request:
    :param page: Starting page, defaults to 1
    :param page_size: Page size, defaults to 10
    :return: List of tasks with pagination without steps populated. Steps can be populated by calling the
        get_agent_task endpoint.
    """
    request["agent"]
    page_of_tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
    serialized = [item.model_dump() for item in page_of_tasks]
    return ORJSONResponse(serialized)
|
||||
|
||||
|
||||
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
async def get_agent_task_steps(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get all steps for a task.

    :param request:
    :param task_id:
    :return: List of steps for a task with pagination.
    """
    request["agent"]
    task_steps = await app.DATABASE.get_task_steps(task_id, organization_id=current_org.organization_id)
    serialized = [task_step.model_dump() for task_step in task_steps]
    return ORJSONResponse(serialized)
|
||||
|
||||
|
||||
@base_router.get("/tasks/{task_id}/steps/{step_id}/artifacts", tags=["agent"], response_model=list[Artifact])
async def get_agent_task_step_artifacts(
    request: Request,
    task_id: str,
    step_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get all artifacts recorded for one step of a task.

    :param request:
    :param task_id: the step's parent task
    :param step_id: the step whose artifacts are requested
    :return: JSON array of Artifact dicts
    """
    # NOTE(review): bare scope subscript, value discarded — confirm intent.
    request["agent"]
    artifacts = await app.DATABASE.get_artifacts_for_task_step(
        task_id,
        step_id,
        organization_id=current_org.organization_id,
    )
    return ORJSONResponse([artifact.model_dump() for artifact in artifacts])
|
||||
|
||||
|
||||
class ActionResultTmp(BaseModel):
    """Loosely-typed view of a persisted action result (used by get_task_actions)."""

    # Raw action payload as stored in the step output.
    action: dict[str, Any]
    # Data produced by the action, if any.
    data: dict[str, Any] | list | str | None = None
    # Populated when the action raised an exception.
    exception_message: str | None = None
    success: bool = True
|
||||
|
||||
|
||||
@base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp])
async def get_task_actions(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[ActionResultTmp]:
    """Flatten every step's stored action_results for a task, in step order."""
    request["agent"]
    step_models = await app.DATABASE.get_task_step_models(task_id, organization_id=current_org.organization_id)
    action_results: list[ActionResultTmp] = []
    for step_model in step_models:
        # Steps without output (or without recorded actions) contribute nothing.
        if not step_model.output or "action_results" not in step_model.output:
            continue
        action_results.extend(
            ActionResultTmp.model_validate(raw_result) for raw_result in step_model.output["action_results"]
        )
    return action_results
|
||||
|
||||
|
||||
@base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse)
async def execute_workflow(
    background_tasks: BackgroundTasks,
    request: Request,
    workflow_id: str,
    workflow_request: WorkflowRequestBody,
    current_org: Organization = Depends(org_auth_service.get_current_org),
    x_api_key: Annotated[str | None, Header()] = None,
    x_max_steps_override: Annotated[int | None, Header()] = None,
) -> RunWorkflowResponse:
    """
    Kick off a workflow run.

    Creates the workflow_run record synchronously, then hands execution to the
    async executor and returns immediately with the run id.

    :param x_api_key: forwarded to the executor — presumably used for
        downstream calls/webhooks; confirm against the executor implementation.
    :param x_max_steps_override: optional per-run cap on steps.
    """
    LOG.info(
        f"Running workflow {workflow_id}",
        workflow_id=workflow_id,
    )
    context = skyvern_context.ensure_context()
    request_id = context.request_id
    workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
        request_id=request_id,
        workflow_request=workflow_request,
        workflow_id=workflow_id,
        organization_id=current_org.organization_id,
        max_steps_override=x_max_steps_override,
    )
    if x_max_steps_override:
        LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
    # Execution is delegated; the HTTP response does not wait for the run.
    await app.ASYNC_EXECUTOR.execute_workflow(
        background_tasks=background_tasks,
        organization=current_org,
        workflow_id=workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        max_steps_override=x_max_steps_override,
        api_key=x_api_key,
    )
    return RunWorkflowResponse(
        workflow_id=workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
    )
|
||||
|
||||
|
||||
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}", response_model=WorkflowRunStatusResponse)
async def get_workflow_run(
    request: Request,
    workflow_id: str,
    workflow_run_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowRunStatusResponse:
    """Return the status payload for one workflow run of the caller's org."""
    request["agent"]
    return await app.WORKFLOW_SERVICE.build_workflow_run_status_response(
        workflow_id=workflow_id,
        workflow_run_id=workflow_run_id,
        organization_id=current_org.organization_id,
    )
|
||||
0
skyvern/forge/sdk/schemas/__init__.py
Normal file
0
skyvern/forge/sdk/schemas/__init__.py
Normal file
181
skyvern/forge/sdk/schemas/tasks.py
Normal file
181
skyvern/forge/sdk/schemas/tasks.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProxyLocation(StrEnum):
    """Proxy egress location requested for a task; NONE disables proxying."""

    US_CA = "US-CA"
    US_NY = "US-NY"
    US_TX = "US-TX"
    US_FL = "US-FL"
    US_WA = "US-WA"
    RESIDENTIAL = "RESIDENTIAL"
    NONE = "NONE"
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
    """User-supplied specification of a task: where to go, what to do, and
    where to report results."""

    url: str = Field(
        ...,
        min_length=1,
        description="Starting URL for the task.",
        examples=["https://www.geico.com"],
    )
    # TODO: use HttpUrl instead of str
    webhook_callback_url: str | None = Field(
        default=None,
        description="The URL to call when the task is completed.",
        examples=["https://my-webhook.com"],
    )
    navigation_goal: str | None = Field(
        default=None,
        description="The user's goal for the task.",
        examples=["Get a quote for car insurance"],
    )
    data_extraction_goal: str | None = Field(
        default=None,
        description="The user's goal for data extraction.",
        examples=["Extract the quote price"],
    )
    navigation_payload: dict[str, Any] | list | str | None = Field(
        None,
        description="The user's details needed to achieve the task.",
        examples=[{"name": "John Doe", "email": "john@doe.com"}],
    )
    proxy_location: ProxyLocation | None = Field(
        None,
        description="The location of the proxy to use for the task.",
        examples=["US-WA", "US-CA", "US-FL", "US-NY", "US-TX"],
    )
    extracted_information_schema: dict[str, Any] | list | str | None = Field(
        None,
        description="The requested schema of the extracted information.",
    )
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
|
||||
created = "created"
|
||||
running = "running"
|
||||
failed = "failed"
|
||||
terminated = "terminated"
|
||||
completed = "completed"
|
||||
|
||||
def is_final(self) -> bool:
|
||||
return self in {TaskStatus.failed, TaskStatus.terminated, TaskStatus.completed}
|
||||
|
||||
def can_update_to(self, new_status: TaskStatus) -> bool:
|
||||
allowed_transitions: dict[TaskStatus, set[TaskStatus]] = {
|
||||
TaskStatus.created: {TaskStatus.running},
|
||||
TaskStatus.running: {TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated},
|
||||
TaskStatus.failed: set(),
|
||||
TaskStatus.completed: set(),
|
||||
}
|
||||
return new_status in allowed_transitions[self]
|
||||
|
||||
def requires_extracted_info(self) -> bool:
|
||||
status_requires_extracted_information = {TaskStatus.completed}
|
||||
return self in status_requires_extracted_information
|
||||
|
||||
def cant_have_extracted_info(self) -> bool:
|
||||
status_cant_have_extracted_information = {
|
||||
TaskStatus.created,
|
||||
TaskStatus.running,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.terminated,
|
||||
}
|
||||
return self in status_cant_have_extracted_information
|
||||
|
||||
def requires_failure_reason(self) -> bool:
|
||||
status_requires_failure_reason = {TaskStatus.failed, TaskStatus.terminated}
|
||||
return self in status_requires_failure_reason
|
||||
|
||||
|
||||
class Task(TaskRequest):
    """A persisted task: the original request plus lifecycle/result state."""

    created_at: datetime = Field(
        ...,
        description="The creation datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
    )
    task_id: str = Field(
        ...,
        description="The ID of the task.",
        examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
    )
    status: TaskStatus = Field(..., description="The status of the task.", examples=["created"])
    extracted_information: dict[str, Any] | list | str | None = Field(
        None,
        description="The extracted information from the task.",
    )
    failure_reason: str | None = Field(
        None,
        description="The reason for the task failure.",
    )
    organization_id: str | None = None
    workflow_run_id: str | None = None
    order: int | None = None
    retry: int | None = None

    def validate_update(
        self,
        status: TaskStatus,
        extracted_information: dict[str, Any] | list | str | None,
        failure_reason: str | None = None,
    ) -> None:
        """Raise ValueError if updating this task to the given state is illegal.

        Validates the status transition, the target status's requirements on
        failure_reason/extracted_information, and that neither result field is
        overwritten once set.
        """
        old_status = self.status

        # bug fix: the three messages below were missing their closing ")".
        if not old_status.can_update_to(status):
            raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id})")

        if status.requires_failure_reason() and failure_reason is None:
            raise ValueError(f"status_requires_failure_reason({status},{self.task_id})")

        # Extracted info is only mandatory when the task actually had a
        # data-extraction goal.
        if status.requires_extracted_info() and self.data_extraction_goal and extracted_information is None:
            raise ValueError(f"status_requires_extracted_information({status},{self.task_id})")

        if status.cant_have_extracted_info() and extracted_information is not None:
            raise ValueError(f"status_cant_have_extracted_information({self.task_id})")

        if self.extracted_information is not None and extracted_information is not None:
            raise ValueError(f"cant_override_extracted_information({self.task_id})")

        if self.failure_reason is not None and failure_reason is not None:
            raise ValueError(f"cant_override_failure_reason({self.task_id})")

    def to_task_response(
        self, screenshot_url: str | None = None, recording_url: str | None = None, failure_reason: str | None = None
    ) -> TaskResponse:
        """Build the public TaskResponse view of this task; explicit
        failure_reason takes precedence over the stored one."""
        return TaskResponse(
            request=self,
            task_id=self.task_id,
            status=self.status,
            created_at=self.created_at,
            modified_at=self.modified_at,
            extracted_information=self.extracted_information,
            failure_reason=failure_reason or self.failure_reason,
            screenshot_url=screenshot_url,
            recording_url=recording_url,
        )
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
    """Public API representation of a task: the original request plus status,
    results, and artifact URLs."""

    request: TaskRequest
    task_id: str
    status: TaskStatus
    created_at: datetime
    modified_at: datetime
    extracted_information: list | dict[str, Any] | str | None = None
    screenshot_url: str | None = None
    recording_url: str | None = None
    failure_reason: str | None = None
|
||||
|
||||
|
||||
class CreateTaskResponse(BaseModel):
    """Response body for task creation: only the new task's id."""

    task_id: str
|
||||
0
skyvern/forge/sdk/services/__init__.py
Normal file
0
skyvern/forge/sdk/services/__init__.py
Normal file
76
skyvern/forge/sdk/services/org_auth_service.py
Normal file
76
skyvern/forge/sdk/services/org_auth_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
from asyncache import cached
|
||||
from cachetools import TTLCache
|
||||
from fastapi import Header, HTTPException, status
|
||||
from jose import jwt
|
||||
from jose.exceptions import JWTError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthTokenType, TokenPayload
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
AUTHENTICATION_TTL = 60 * 60 # one hour
|
||||
CACHE_SIZE = 128
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
async def get_current_org(
    x_api_key: Annotated[str | None, Header()] = None,
) -> Organization:
    """FastAPI dependency: resolve the organization for the X-Api-Key header.

    Raises 403 when the header is missing; otherwise delegates to the cached
    validator.
    """
    if x_api_key:
        return await _get_current_org_cached(x_api_key, app.DATABASE)
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid credentials",
    )
|
||||
|
||||
|
||||
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
    """
    Validate an API-key JWT and resolve its organization.

    Authentication is cached for one hour (keyed on the api key and db), so
    token revocation can take up to AUTHENTICATION_TTL to be observed.
    Raises HTTP 403 for malformed/expired/revoked tokens and 404 for a
    missing organization.
    """
    try:
        payload = jwt.decode(
            x_api_key,
            SettingsManager.get_settings().SECRET_KEY,
            algorithms=[ALGORITHM],
        )
        api_key_data = TokenPayload(**payload)
    except (JWTError, ValidationError):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    # NOTE(review): jose's jwt.decode already rejects expired "exp" claims, so
    # this explicit check is presumably defense in depth — confirm.
    if api_key_data.exp < time.time():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Auth token is expired",
        )

    organization = await db.get_organization(organization_id=api_key_data.sub)
    if not organization:
        raise HTTPException(status_code=404, detail="Organization not found")

    # check if the token exists in the database (i.e. was issued and not revoked)
    api_key_db_obj = await db.validate_org_auth_token(
        organization_id=organization.organization_id,
        token_type=OrganizationAuthTokenType.api,
        token=x_api_key,
    )
    if not api_key_db_obj:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )

    # set organization_id in skyvern context and log context
    context = skyvern_context.current()
    if context:
        context.organization_id = organization.organization_id
    return organization
|
||||
14
skyvern/forge/sdk/settings_manager.py
Normal file
14
skyvern/forge/sdk/settings_manager.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from skyvern.config import Settings
|
||||
from skyvern.config import settings as base_settings
|
||||
|
||||
|
||||
class SettingsManager:
    """Process-wide holder for the active Settings instance.

    Defaults to the module-level settings from skyvern.config; set_settings
    swaps the instance (e.g. for tests or embedders).
    """

    __instance: Settings = base_settings

    @classmethod
    def get_settings(cls) -> Settings:
        """Return the currently active settings."""
        return cls.__instance

    @classmethod
    def set_settings(cls, settings: Settings) -> None:
        """Replace the active settings instance."""
        cls.__instance = settings
|
||||
0
skyvern/forge/sdk/workflow/__init__.py
Normal file
0
skyvern/forge/sdk/workflow/__init__.py
Normal file
79
skyvern/forge/sdk/workflow/context_manager.py
Normal file
79
skyvern/forge/sdk/workflow/context_manager.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunParameter
|
||||
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ContextManager:
    """
    Holds the parameter definitions and resolved values for one workflow run.

    Workflow parameters are seeded from the run's WorkflowRunParameters at
    construction; block-level parameters (AWS secrets, context parameters)
    are registered later via register_block_parameters.
    """

    aws_client: AsyncAWSClient
    # key -> parameter definition
    parameters: dict[str, PARAMETER_TYPE]
    # key -> resolved runtime value
    values: dict[str, Any]

    def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
        self.aws_client = AsyncAWSClient()
        self.parameters = {}
        self.values = {}
        for parameter, run_parameter in workflow_parameter_tuples:
            if parameter.key in self.parameters:
                # Last write wins; a duplicate key indicates a misconfigured workflow.
                prev_value = self.parameters[parameter.key]
                new_value = run_parameter.value
                LOG.error(
                    f"Duplicate parameter key {parameter.key} found while initializing context manager, previous value: {prev_value}, new value: {new_value}. Using new value."
                )

            self.parameters[parameter.key] = parameter
            self.values[parameter.key] = run_parameter.value

    async def register_parameter_value(
        self,
        parameter: PARAMETER_TYPE,
    ) -> None:
        """Resolve and store the runtime value for a non-workflow parameter.

        Workflow parameters are rejected (they must arrive via __init__);
        AWS secrets are fetched from AWS; context parameters are left to the
        blocks to fill in.
        """
        if parameter.parameter_type == ParameterType.WORKFLOW:
            LOG.error(f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}")
            raise ValueError(
                f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
            )
        elif parameter.parameter_type == ParameterType.AWS_SECRET:
            # NOTE(review): a None secret is silently skipped — a later
            # get_value for this key will raise KeyError; confirm intended.
            secret_value = await self.aws_client.get_secret(parameter.aws_key)
            if secret_value is not None:
                self.values[parameter.key] = secret_value
        else:
            # ContextParameter values will be set within the blocks
            return None

    async def register_block_parameters(
        self,
        parameters: list[PARAMETER_TYPE],
    ) -> None:
        """Register a block's parameters, resolving values for new ones only."""
        for parameter in parameters:
            if parameter.key in self.parameters:
                LOG.debug(f"Parameter {parameter.key} already registered, skipping")
                continue

            if parameter.parameter_type == ParameterType.WORKFLOW:
                LOG.error(
                    f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
                )
                raise ValueError(
                    f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
                )

            self.parameters[parameter.key] = parameter
            await self.register_parameter_value(parameter)

    def get_parameter(self, key: str) -> Parameter:
        """Return the parameter definition for key (KeyError if unknown)."""
        return self.parameters[key]

    def get_value(self, key: str) -> Any:
        """Return the resolved value for key (KeyError if unresolved)."""
        return self.values[key]

    def set_value(self, key: str, value: Any) -> None:
        """Set (or overwrite) the resolved value for key."""
        self.values[key] = value
|
||||
0
skyvern/forge/sdk/workflow/models/__init__.py
Normal file
0
skyvern/forge/sdk/workflow/models/__init__.py
Normal file
221
skyvern/forge/sdk/workflow/models/block.py
Normal file
221
skyvern/forge/sdk/workflow/models/block.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import abc
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from skyvern.exceptions import (
|
||||
ContextParameterValueNotFound,
|
||||
MissingBrowserStatePage,
|
||||
TaskNotFound,
|
||||
UnexpectedTaskStatus,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import ContextManager
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class BlockType(StrEnum):
    """Discriminator value identifying each concrete workflow block class."""

    TASK = "task"
    FOR_LOOP = "for_loop"
|
||||
|
||||
|
||||
class Block(BaseModel, abc.ABC):
    """Abstract base for workflow blocks (the nodes of a workflow definition)."""

    block_type: BlockType
    parent_block_id: str | None = None
    next_block_id: str | None = None

    @classmethod
    def get_subclasses(cls) -> tuple[type["Block"], ...]:
        """Return the direct subclasses (the known concrete block types)."""
        return tuple(cls.__subclasses__())

    @abc.abstractmethod
    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        """Run this block within the given workflow run."""
        pass

    @abc.abstractmethod
    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Return every parameter this block needs registered before execute."""
        pass
|
||||
|
||||
|
||||
class TaskBlock(Block):
    """Workflow block that runs a single Skyvern task, retrying failed
    attempts up to max_retries times."""

    block_type: Literal[BlockType.TASK] = BlockType.TASK

    url: str | None = None
    navigation_goal: str | None = None
    data_extraction_goal: str | None = None
    data_schema: dict[str, Any] | None = None
    # Number of retries after the first failed attempt (0 = single attempt).
    max_retries: int = 0
    parameters: list[PARAMETER_TYPE] = []

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Parameters declared directly on this block."""
        return self.parameters

    @staticmethod
    async def get_task_order(workflow_run_id: str, current_retry: int) -> tuple[int, int]:
        """
        Returns the order and retry for the next task in the workflow run as a tuple.
        """
        last_task_for_workflow_run = await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
        # If there is no previous task, the order will be 0 and the retry will be 0.
        if last_task_for_workflow_run is None:
            return 0, 0
        # If there is a previous task but the current retry is 0, the order will be the order of the last task + 1
        # and the retry will be 0.
        order = last_task_for_workflow_run.order or 0
        if current_retry == 0:
            return order + 1, 0
        # If there is a previous task and the current retry is not 0, the order will be the order of the last task
        # and the retry will be the retry of the last task + 1. (There is a validation that makes sure the retry
        # of the last task is equal to current_retry - 1) if it is not, we use last task retry + 1.
        retry = last_task_for_workflow_run.retry or 0
        if retry + 1 != current_retry:
            LOG.error(
                f"Last task for workflow run is retry number {last_task_for_workflow_run.retry}, "
                f"but current retry is {current_retry}. Could be race condition. Using last task retry + 1",
                workflow_run_id=workflow_run_id,
                last_task_id=last_task_for_workflow_run.task_id,
                last_task_retry=last_task_for_workflow_run.retry,
                current_retry=current_retry,
            )

        return order, retry + 1

    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        """Run the task until it completes or retries are exhausted.

        Each attempt creates a fresh task/step pair, re-navigates to the
        block's URL, and executes one step. Raises for a missing org, a
        browser state without a page, a vanished task, or a task that ends
        in a non-final status.
        """
        task = None
        current_retry = 0
        # initial value for will_retry is True, so that the loop runs at least once
        will_retry = True
        workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id=workflow_run_id)
        workflow = await app.WORKFLOW_SERVICE.get_workflow(workflow_id=workflow_run.workflow_id)
        # TODO (kerem) we should always retry on terminated. We should make a distinction between retriable and
        # non-retryable terminations
        while will_retry:
            task_order, task_retry = await self.get_task_order(workflow_run_id, current_retry)
            task, step = await app.agent.create_task_and_step_from_block(
                task_block=self,
                workflow=workflow,
                workflow_run=workflow_run,
                context_manager=context_manager,
                task_order=task_order,
                task_retry=task_retry,
            )
            organization = await app.DATABASE.get_organization(organization_id=workflow.organization_id)
            if not organization:
                raise Exception(f"Organization is missing organization_id={workflow.organization_id}")
            # Browser state is shared across attempts for the same run.
            browser_state = await app.BROWSER_MANAGER.get_or_create_for_workflow_run(
                workflow_run=workflow_run, url=self.url
            )
            if not browser_state.page:
                LOG.error("BrowserState has no page", workflow_run_id=workflow_run.workflow_run_id)
                raise MissingBrowserStatePage(workflow_run_id=workflow_run.workflow_run_id)

            LOG.info(
                f"Navigating to page",
                url=self.url,
                workflow_run_id=workflow_run_id,
                task_id=task.task_id,
                workflow_id=workflow.workflow_id,
                organization_id=workflow.organization_id,
                step_id=step.step_id,
            )

            # Re-navigate every attempt so retries start from the block's URL.
            if self.url:
                await browser_state.page.goto(self.url)

            await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
            # Check task status
            updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
            if not updated_task:
                raise TaskNotFound(task.task_id)
            if not updated_task.status.is_final():
                raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)
            if updated_task.status == TaskStatus.completed:
                will_retry = False
            else:
                current_retry += 1
                will_retry = current_retry <= self.max_retries
                retry_message = f", retrying task {current_retry}/{self.max_retries}" if will_retry else ""
                LOG.warning(
                    f"Task failed with status {updated_task.status}{retry_message}",
                    task_id=updated_task.task_id,
                    status=updated_task.status,
                    workflow_run_id=workflow_run_id,
                    workflow_id=workflow.workflow_id,
                    organization_id=workflow.organization_id,
                    current_retry=current_retry,
                    max_retries=self.max_retries,
                )
|
||||
|
||||
|
||||
class ForLoopBlock(Block):
    """Workflow block that runs its child block once per value of loop_over."""

    block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP

    # TODO (kerem): Add support for ContextParameter
    loop_over: PARAMETER_TYPE
    loop_block: "BlockTypeVar"

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Child block's parameters plus the loop_over parameter itself."""
        return self.loop_block.get_all_parameters() + [self.loop_over]

    def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any) -> list[ContextParameter]:
        """Bind the child block's ContextParameters to keys of one loop item.

        loop_data must be a dict; every ContextParameter key must be present
        in it, otherwise ContextParameterValueNotFound is raised.
        """
        if not isinstance(loop_data, dict):
            # TODO (kerem): Should we add support for other types?
            raise ValueError("loop_data should be a dictionary")

        loop_block_parameters = self.loop_block.get_all_parameters()
        context_parameters = [
            parameter for parameter in loop_block_parameters if isinstance(parameter, ContextParameter)
        ]
        for context_parameter in context_parameters:
            if context_parameter.key not in loop_data:
                raise ContextParameterValueNotFound(
                    parameter_key=context_parameter.key,
                    existing_keys=list(loop_data.keys()),
                    workflow_run_id=workflow_run_id,
                )
            context_parameter.value = loop_data[context_parameter.key]

        return context_parameters

    def get_loop_over_parameter_values(self, context_manager: ContextManager) -> list[Any]:
        """Return the iterable of values to loop over; a non-list value is
        wrapped into a single-element list."""
        if isinstance(self.loop_over, WorkflowParameter):
            parameter_value = context_manager.get_value(self.loop_over.key)
            if isinstance(parameter_value, list):
                return parameter_value
            else:
                # TODO (kerem): Should we raise an error here?
                return [parameter_value]
        else:
            # TODO (kerem): Implement this for context parameters
            raise NotImplementedError

    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        """Execute the child block sequentially for each loop value, pushing
        that value's context-parameter bindings before each iteration."""
        loop_over_values = self.get_loop_over_parameter_values(context_manager)
        LOG.info(
            f"Number of loop_over values: {len(loop_over_values)}",
            block_type=self.block_type,
            workflow_run_id=workflow_run_id,
            num_loop_over_values=len(loop_over_values),
        )
        for loop_over_value in loop_over_values:
            context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
            for context_parameter in context_parameters_with_value:
                context_manager.set_value(context_parameter.key, context_parameter.value)
            await self.loop_block.execute(workflow_run_id=workflow_run_id, context_manager=context_manager)

        return None
|
||||
|
||||
|
||||
# Discriminated union of all concrete blocks; the "block_type" discriminator
# lets pydantic pick the right class during validation.
BlockSubclasses = Union[ForLoopBlock, TaskBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
84
skyvern/forge/sdk/workflow/models/parameter.py
Normal file
84
skyvern/forge/sdk/workflow/models/parameter.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import abc
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ParameterType(StrEnum):
    """Discriminator value identifying each concrete Parameter class."""

    WORKFLOW = "workflow"
    CONTEXT = "context"
    AWS_SECRET = "aws_secret"
|
||||
|
||||
|
||||
class Parameter(BaseModel, abc.ABC):
    """Abstract base for workflow parameters, keyed by a unique string key."""

    # TODO (kerem): Should we also have organization_id here?
    parameter_type: ParameterType
    key: str
    description: str | None = None

    @classmethod
    def get_subclasses(cls) -> tuple[type["Parameter"], ...]:
        """Return the direct subclasses (the known concrete parameter types)."""
        return tuple(cls.__subclasses__())
|
||||
|
||||
|
||||
class AWSSecretParameter(Parameter):
    """Parameter whose value is fetched from AWS Secrets Manager at run time
    (see ContextManager.register_parameter_value)."""

    parameter_type: Literal[ParameterType.AWS_SECRET] = ParameterType.AWS_SECRET

    aws_secret_parameter_id: str
    workflow_id: str
    # Name/key of the secret to fetch from AWS.
    aws_key: str

    created_at: datetime
    modified_at: datetime
    deleted_at: datetime | None = None
|
||||
|
||||
|
||||
class WorkflowParameterType(StrEnum):
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
FLOAT = "float"
|
||||
BOOLEAN = "boolean"
|
||||
JSON = "json"
|
||||
|
||||
def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None:
|
||||
if value is None:
|
||||
return None
|
||||
if self == WorkflowParameterType.STRING:
|
||||
return value
|
||||
elif self == WorkflowParameterType.INTEGER:
|
||||
return int(value)
|
||||
elif self == WorkflowParameterType.FLOAT:
|
||||
return float(value)
|
||||
elif self == WorkflowParameterType.BOOLEAN:
|
||||
return value.lower() in ["true", "1"]
|
||||
elif self == WorkflowParameterType.JSON:
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
class WorkflowParameter(Parameter):
    """Parameter declared on a workflow and supplied per run via
    WorkflowRunParameter."""

    parameter_type: Literal[ParameterType.WORKFLOW] = ParameterType.WORKFLOW

    workflow_parameter_id: str
    workflow_parameter_type: WorkflowParameterType
    workflow_id: str
    # the type of default_value will be determined by the workflow_parameter_type
    default_value: str | int | float | bool | dict | list | None = None

    created_at: datetime
    modified_at: datetime
    deleted_at: datetime | None = None
|
||||
|
||||
|
||||
class ContextParameter(Parameter):
    """Parameter derived from another parameter's value at run time
    (e.g. bound per-iteration inside a for-loop block)."""

    parameter_type: Literal[ParameterType.CONTEXT] = ParameterType.CONTEXT

    source: WorkflowParameter
    # value will be populated by the context manager
    value: str | int | float | bool | dict | list | None = None
|
||||
|
||||
|
||||
# Discriminated union of all concrete parameters; "parameter_type" tells
# pydantic which class to instantiate during validation.
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter]
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]
|
||||
74
skyvern/forge/sdk/workflow/models/workflow.py
Normal file
74
skyvern/forge/sdk/workflow/models/workflow.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
|
||||
|
||||
|
||||
class WorkflowRequestBody(BaseModel):
    """Request body for running a workflow; data carries run parameter values."""

    data: dict[str, Any] | None = None
    proxy_location: ProxyLocation | None = None
    webhook_callback_url: str | None = None
|
||||
|
||||
|
||||
class RunWorkflowResponse(BaseModel):
    """Response identifying the workflow run that was created."""

    workflow_id: str
    workflow_run_id: str
class WorkflowDefinition(BaseModel):
    """The ordered list of top-level blocks that make up a workflow."""

    blocks: List[BlockTypeVar]
class Workflow(BaseModel):
    """A stored workflow: ownership metadata plus its block definition."""

    workflow_id: str
    organization_id: str
    title: str
    description: str | None = None
    workflow_definition: WorkflowDefinition

    created_at: datetime
    modified_at: datetime
    # soft-delete marker; None means the workflow is active
    deleted_at: datetime | None = None
class WorkflowRunStatus(StrEnum):
    """Lifecycle states of a workflow run."""

    created = "created"
    running = "running"
    failed = "failed"
    terminated = "terminated"
    completed = "completed"
class WorkflowRun(BaseModel):
    """A single execution of a workflow."""

    workflow_run_id: str
    workflow_id: str
    status: WorkflowRunStatus
    proxy_location: ProxyLocation | None = None
    webhook_callback_url: str | None = None

    created_at: datetime
    modified_at: datetime
class WorkflowRunParameter(BaseModel):
    """The concrete value bound to one workflow parameter for one run."""

    workflow_run_id: str
    workflow_parameter_id: str
    value: bool | int | float | str | dict | list
    created_at: datetime
class WorkflowRunStatusResponse(BaseModel):
    """Run-status payload returned to clients and POSTed to the webhook callback URL."""

    workflow_id: str
    workflow_run_id: str
    status: WorkflowRunStatus
    proxy_location: ProxyLocation | None = None
    webhook_callback_url: str | None = None
    created_at: datetime
    modified_at: datetime
    # parameter key -> bound value for this run
    parameters: dict[str, Any]
    # latest screenshots from up to the last 3 tasks of the run
    screenshot_urls: list[str] | None = None
    recording_url: str | None = None
509
skyvern/forge/sdk/workflow/service.py
Normal file
509
skyvern/forge/sdk/workflow/service.py
Normal file
@@ -0,0 +1,509 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
import structlog
|
||||
|
||||
from skyvern.exceptions import (
|
||||
FailedToSendWebhook,
|
||||
MissingValueForParameter,
|
||||
WorkflowNotFound,
|
||||
WorkflowOrganizationMismatch,
|
||||
WorkflowRunNotFound,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.context_manager import ContextManager
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDefinition,
|
||||
WorkflowRequestBody,
|
||||
WorkflowRun,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunStatusResponse,
|
||||
)
|
||||
from skyvern.webeye.browser_factory import BrowserState
|
||||
|
||||
# Module-level structured logger for the workflow service.
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class WorkflowService:
    """Service layer for workflows: CRUD, run setup and execution, artifact
    persistence, and delivery of the signed run-status webhook."""

    async def setup_workflow_run(
        self,
        request_id: str | None,
        workflow_request: WorkflowRequestBody,
        workflow_id: str,
        organization_id: str,
        max_steps_override: int | None = None,
    ) -> WorkflowRun:
        """
        Create a workflow run and its parameters. Validate the workflow and the organization. If there are missing
        parameters with no default value, mark the workflow run as failed.
        :param request_id: The request id for the workflow run.
        :param workflow_request: The request body for the workflow run, containing the parameters and the config.
        :param workflow_id: The workflow id to run.
        :param organization_id: The organization id for the workflow.
        :param max_steps_override: The max steps override for the workflow run, if any.
        :return: The created workflow run.
        """
        LOG.info(f"Setting up workflow run for workflow {workflow_id}", workflow_id=workflow_id)
        # Validate the workflow and the organization
        workflow = await self.get_workflow(workflow_id=workflow_id)
        # NOTE(review): get_workflow raises WorkflowNotFound instead of returning None,
        # so this None branch looks unreachable — confirm.
        if workflow is None:
            LOG.error(f"Workflow {workflow_id} not found")
            raise WorkflowNotFound(workflow_id=workflow_id)
        if workflow.organization_id != organization_id:
            LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
            raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)
        # Create the workflow run and set skyvern context
        workflow_run = await self.create_workflow_run(workflow_request=workflow_request, workflow_id=workflow_id)
        LOG.info(
            f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}",
            request_id=request_id,
            workflow_run_id=workflow_run.workflow_run_id,
            workflow_id=workflow.workflow_id,
            proxy_location=workflow_request.proxy_location,
        )
        skyvern_context.set(
            SkyvernContext(
                organization_id=organization_id,
                request_id=request_id,
                workflow_id=workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
                max_steps_override=max_steps_override,
            )
        )

        # Set workflow run status to running, create workflow run parameters
        await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)

        # Create all the workflow run parameters, AWSSecretParameter won't have workflow run parameters created.
        all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id)
        workflow_run_parameters = []
        for workflow_parameter in all_workflow_parameters:
            # Precedence: value from the request body, then the parameter's default;
            # a parameter with neither fails the whole run.
            if workflow_request.data and workflow_parameter.key in workflow_request.data:
                request_body_value = workflow_request.data[workflow_parameter.key]
                workflow_run_parameter = await self.create_workflow_run_parameter(
                    workflow_run_id=workflow_run.workflow_run_id,
                    workflow_parameter_id=workflow_parameter.workflow_parameter_id,
                    value=request_body_value,
                )
            elif workflow_parameter.default_value is not None:
                workflow_run_parameter = await self.create_workflow_run_parameter(
                    workflow_run_id=workflow_run.workflow_run_id,
                    workflow_parameter_id=workflow_parameter.workflow_parameter_id,
                    value=workflow_parameter.default_value,
                )
            else:
                await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
                raise MissingValueForParameter(
                    parameter_key=workflow_parameter.key,
                    workflow_id=workflow.workflow_id,
                    workflow_run_id=workflow_run.workflow_run_id,
                )

            workflow_run_parameters.append(workflow_run_parameter)

        LOG.info(
            f"Created workflow run parameters for workflow run {workflow_run.workflow_run_id}",
            workflow_run_id=workflow_run.workflow_run_id,
        )

        return workflow_run

    async def execute_workflow(
        self,
        workflow_run_id: str,
        api_key: str,
    ) -> WorkflowRun:
        """Execute a workflow: run every top-level block in order, derive the run
        status from the last task, then send the webhook response.

        :param workflow_run_id: The run to execute (must already be set up).
        :param api_key: API key used to sign the webhook payload.
        :return: The workflow run that was executed.
        """
        workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
        workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)

        await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)

        # Get all <workflow parameter, workflow run parameter> tuples
        wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
        # todo(kerem): do this in a better way (a shared context manager? (not really shared because we use batch job))
        context_manager = ContextManager(wp_wps_tuples)
        # Execute workflow blocks
        blocks = workflow.workflow_definition.blocks
        for block_idx, block in enumerate(blocks):
            parameters = block.get_all_parameters()
            await context_manager.register_block_parameters(parameters)
            LOG.info(
                f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run.workflow_run_id}",
                block_type=block.block_type,
                workflow_run_id=workflow_run.workflow_run_id,
                block_idx=block_idx,
            )
            await block.execute(workflow_run_id=workflow_run.workflow_run_id, context_manager=context_manager)

        # Get last task for workflow run
        task = await self.get_last_task_for_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
        if not task:
            LOG.warning(
                f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook",
                workflow_run_id=workflow_run.workflow_run_id,
            )
            return workflow_run

        # Update workflow status
        # NOTE(review): the run's final status mirrors only the LAST task's status —
        # earlier failed tasks don't affect it. Confirm that's intentional.
        if task.status == "completed":
            await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
        elif task.status == "failed":
            await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
        elif task.status == "terminated":
            await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id)
        else:
            LOG.warning(
                f"Task {task.task_id} has an incomplete status {task.status}, not updating workflow run status",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
                task_id=task.task_id,
                status=task.status,
                workflow_run_status=workflow_run.status,
            )

        await self.send_workflow_response(
            workflow=workflow,
            workflow_run=workflow_run,
            api_key=api_key,
            last_task=task,
        )
        return workflow_run

    async def create_workflow(
        self,
        organization_id: str,
        title: str,
        workflow_definition: WorkflowDefinition,
        description: str | None = None,
    ) -> Workflow:
        """Persist a new workflow for the organization and return it."""
        return await app.DATABASE.create_workflow(
            organization_id=organization_id,
            title=title,
            description=description,
            workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
        )

    async def get_workflow(self, workflow_id: str) -> Workflow:
        """Fetch a workflow by id.

        :raises WorkflowNotFound: if no workflow with this id exists.
        """
        workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id)
        if not workflow:
            raise WorkflowNotFound(workflow_id)
        return workflow

    async def update_workflow(
        self,
        workflow_id: str,
        title: str | None = None,
        description: str | None = None,
        workflow_definition: WorkflowDefinition | None = None,
    ) -> Workflow | None:
        """Update the given fields of a workflow; None fields are left unchanged by the DB layer."""
        return await app.DATABASE.update_workflow(
            workflow_id=workflow_id,
            title=title,
            description=description,
            workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
        )

    async def create_workflow_run(self, workflow_request: WorkflowRequestBody, workflow_id: str) -> WorkflowRun:
        """Create a workflow run record from the request's proxy/webhook config."""
        return await app.DATABASE.create_workflow_run(
            workflow_id=workflow_id,
            proxy_location=workflow_request.proxy_location,
            webhook_callback_url=workflow_request.webhook_callback_url,
        )

    async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> None:
        """Set the run's status to completed."""
        LOG.info(
            f"Marking workflow run {workflow_run_id} as completed", workflow_run_id=workflow_run_id, status="completed"
        )
        await app.DATABASE.update_workflow_run(
            workflow_run_id=workflow_run_id,
            status=WorkflowRunStatus.completed,
        )

    async def mark_workflow_run_as_failed(self, workflow_run_id: str) -> None:
        """Set the run's status to failed."""
        LOG.info(f"Marking workflow run {workflow_run_id} as failed", workflow_run_id=workflow_run_id, status="failed")
        await app.DATABASE.update_workflow_run(
            workflow_run_id=workflow_run_id,
            status=WorkflowRunStatus.failed,
        )

    async def mark_workflow_run_as_running(self, workflow_run_id: str) -> None:
        """Set the run's status to running."""
        LOG.info(
            f"Marking workflow run {workflow_run_id} as running", workflow_run_id=workflow_run_id, status="running"
        )
        await app.DATABASE.update_workflow_run(
            workflow_run_id=workflow_run_id,
            status=WorkflowRunStatus.running,
        )

    async def mark_workflow_run_as_terminated(self, workflow_run_id: str) -> None:
        """Set the run's status to terminated."""
        LOG.info(
            f"Marking workflow run {workflow_run_id} as terminated",
            workflow_run_id=workflow_run_id,
            status="terminated",
        )
        await app.DATABASE.update_workflow_run(
            workflow_run_id=workflow_run_id,
            status=WorkflowRunStatus.terminated,
        )

    async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
        """Return all runs of the given workflow."""
        return await app.DATABASE.get_workflow_runs(workflow_id=workflow_id)

    async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
        """Fetch a workflow run by id.

        :raises WorkflowRunNotFound: if no run with this id exists.
        """
        workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id)
        if not workflow_run:
            raise WorkflowRunNotFound(workflow_run_id)
        return workflow_run

    async def create_workflow_parameter(
        self,
        workflow_id: str,
        workflow_parameter_type: WorkflowParameterType,
        key: str,
        default_value: bool | int | float | str | dict | list | None = None,
        description: str | None = None,
    ) -> WorkflowParameter:
        """Declare a typed parameter on a workflow, optionally with a default value."""
        return await app.DATABASE.create_workflow_parameter(
            workflow_id=workflow_id,
            workflow_parameter_type=workflow_parameter_type,
            key=key,
            description=description,
            default_value=default_value,
        )

    async def create_aws_secret_parameter(
        self, workflow_id: str, aws_key: str, key: str, description: str | None = None
    ) -> AWSSecretParameter:
        """Declare an AWS-secret-backed parameter on a workflow."""
        return await app.DATABASE.create_aws_secret_parameter(
            workflow_id=workflow_id, aws_key=aws_key, key=key, description=description
        )

    async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
        """Return all parameters declared on the workflow."""
        return await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id)

    async def create_workflow_run_parameter(
        self,
        workflow_run_id: str,
        workflow_parameter_id: str,
        value: bool | int | float | str | dict | list,
    ) -> WorkflowRunParameter:
        """Bind a concrete value to a workflow parameter for this run.

        dict/list values are stored JSON-encoded; scalars are stored as-is.
        """
        return await app.DATABASE.create_workflow_run_parameter(
            workflow_run_id=workflow_run_id,
            workflow_parameter_id=workflow_parameter_id,
            value=json.dumps(value) if isinstance(value, (dict, list)) else value,
        )

    async def get_workflow_run_parameter_tuples(
        self, workflow_run_id: str
    ) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
        """Return (parameter declaration, bound run value) pairs for the run."""
        return await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)

    async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
        """Return the most recent task of the run, or None if it has no tasks."""
        return await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)

    async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
        """Return all tasks belonging to the run."""
        return await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)

    async def build_workflow_run_status_response(
        self, workflow_id: str, workflow_run_id: str, organization_id: str
    ) -> WorkflowRunStatusResponse:
        """Assemble the run-status payload: status, bound parameters, up to 3
        recent screenshots, and the recording link if one exists.

        :raises WorkflowNotFound: if the workflow does not exist.
        :raises WorkflowOrganizationMismatch: if it belongs to another organization.
        """
        workflow = await self.get_workflow(workflow_id=workflow_id)
        # NOTE(review): get_workflow raises WorkflowNotFound instead of returning None,
        # so this None branch looks unreachable — confirm.
        if workflow is None:
            LOG.error(f"Workflow {workflow_id} not found")
            raise WorkflowNotFound(workflow_id=workflow_id)
        if workflow.organization_id != organization_id:
            LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
            raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)

        workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
        workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
        screenshot_urls = []
        # get the last screenshot for the last 3 tasks of the workflow run
        for task in workflow_run_tasks[::-1]:
            screenshot_artifact = await app.DATABASE.get_latest_artifact(
                task_id=task.task_id,
                artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
                organization_id=organization_id,
            )
            if screenshot_artifact:
                screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)
                if screenshot_url:
                    screenshot_urls.append(screenshot_url)
            if len(screenshot_urls) >= 3:
                break

        recording_url = None
        recording_artifact = await app.DATABASE.get_artifact_for_workflow_run(
            workflow_run_id=workflow_run_id, artifact_type=ArtifactType.RECORDING, organization_id=organization_id
        )
        if recording_artifact:
            recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)

        workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
        parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
        return WorkflowRunStatusResponse(
            workflow_id=workflow_id,
            workflow_run_id=workflow_run_id,
            status=workflow_run.status,
            proxy_location=workflow_run.proxy_location,
            webhook_callback_url=workflow_run.webhook_callback_url,
            created_at=workflow_run.created_at,
            modified_at=workflow_run.modified_at,
            parameters=parameters_with_value,
            screenshot_urls=screenshot_urls,
            recording_url=recording_url,
        )

    async def send_workflow_response(
        self,
        workflow: Workflow,
        workflow_run: WorkflowRun,
        last_task: Task,
        api_key: str | None = None,
        close_browser_on_completion: bool = True,
    ) -> None:
        """Finalize the run: clean up the browser, persist video/HAR artifacts,
        wait for pending S3 uploads, then POST the signed status payload to the
        run's webhook callback URL (if configured and an api_key is available).

        :raises FailedToSendWebhook: if the webhook POST itself fails.
        """
        browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
            workflow_run.workflow_run_id, close_browser_on_completion
        )
        if browser_state:
            await self.persist_video_data(browser_state, workflow, workflow_run)
            await self.persist_har_data(browser_state, last_task, workflow, workflow_run)

        # Wait for all tasks to complete before generating the links for the artifacts
        all_workflow_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(
            workflow_run_id=workflow_run.workflow_run_id
        )
        all_workflow_task_ids = [task.task_id for task in all_workflow_tasks]
        await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)

        try:
            # Wait for all tasks to complete. Currently we're using asyncio.create_task() only for uploading artifacts to S3.
            # We're excluding the current task from the list of tasks to wait for to prevent a deadlock.
            st = time.time()
            async with asyncio.timeout(30):
                await asyncio.gather(
                    *[aio_task for aio_task in (asyncio.all_tasks() - {asyncio.current_task()}) if not aio_task.done()]
                )
            LOG.info(
                f"Waiting for all S3 uploads to complete took {time.time() - st} seconds",
                duration=time.time() - st,
            )
        except asyncio.TimeoutError:
            LOG.warning(
                "Timed out waiting for all S3 uploads to complete, not all artifacts may be uploaded. Waited 30 seconds.",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
            )

        if not workflow_run.webhook_callback_url:
            LOG.warning(
                "Workflow has no webhook callback url. Not sending workflow response",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
            )
            return

        if not api_key:
            LOG.warning(
                "Request has no api key. Not sending workflow response",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
            )
            return

        workflow_run_status_response = await self.build_workflow_run_status_response(
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            organization_id=workflow.organization_id,
        )
        # send task_response to the webhook callback url
        # TODO: use async requests (httpx)
        # NOTE(review): datetime.utcnow() is naive (and deprecated in newer Python);
        # consider datetime.now(timezone.utc) — confirm the signature consumers agree.
        timestamp = str(int(datetime.utcnow().timestamp()))
        payload = workflow_run_status_response.model_dump_json()
        signature = generate_skyvern_signature(
            payload=payload,
            api_key=api_key,
        )
        headers = {
            "x-skyvern-timestamp": timestamp,
            "x-skyvern-signature": signature,
            "Content-Type": "application/json",
        }
        LOG.info(
            "Sending webhook run status to webhook callback url",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            webhook_callback_url=workflow_run.webhook_callback_url,
            payload=payload,
            headers=headers,
        )
        try:
            # Synchronous POST inside an async method (see TODO above) — blocks the event loop.
            resp = requests.post(workflow_run.webhook_callback_url, data=payload, headers=headers)
            if resp.ok:
                LOG.info(
                    "Webhook sent successfully",
                    workflow_id=workflow.workflow_id,
                    workflow_run_id=workflow_run.workflow_run_id,
                    resp_code=resp.status_code,
                    resp_text=resp.text,
                )
            else:
                # NOTE(review): resp.json() raises if the body isn't valid JSON, which would
                # convert this failure log into a FailedToSendWebhook — confirm that's acceptable.
                LOG.info(
                    "Webhook failed",
                    workflow_id=workflow.workflow_id,
                    workflow_run_id=workflow_run.workflow_run_id,
                    resp=resp,
                    resp_code=resp.status_code,
                    resp_text=resp.text,
                    resp_json=resp.json(),
                )
        except Exception as e:
            raise FailedToSendWebhook(
                workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id
            ) from e

    async def persist_video_data(
        self, browser_state: BrowserState, workflow: Workflow, workflow_run: WorkflowRun
    ) -> None:
        """Attach the browser recording (if any) to the run's pre-created video artifact."""
        # Create recording artifact after closing the browser, so we can get an accurate recording
        video_data = await app.BROWSER_MANAGER.get_video_data(
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            browser_state=browser_state,
        )
        if video_data:
            await app.ARTIFACT_MANAGER.update_artifact_data(
                artifact_id=browser_state.browser_artifacts.video_artifact_id,
                organization_id=workflow.organization_id,
                data=video_data,
            )

    async def persist_har_data(
        self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
    ) -> None:
        """Store the browser HAR capture (if any) as an artifact on the run's last step."""
        har_data = await app.BROWSER_MANAGER.get_har_data(
            workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
        )
        if har_data:
            last_step = await app.DATABASE.get_latest_step(
                task_id=last_task.task_id, organization_id=last_task.organization_id
            )

            if last_step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=last_step,
                    artifact_type=ArtifactType.HAR,
                    data=har_data,
                )
Reference in New Issue
Block a user