Move the code over from private repository (#3)

Kerem Yilmaz
2024-03-01 10:09:30 -08:00
committed by GitHub
parent 32dd6d92a5
commit 9eddb3d812
93 changed files with 16798 additions and 0 deletions

@@ -0,0 +1,97 @@
import uuid
from datetime import datetime
from typing import Awaitable, Callable
import structlog
from fastapi import APIRouter, FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.requests import HTTPConnection, Request
from starlette_context.middleware import RawContextMiddleware
from starlette_context.plugins.base import Plugin
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.routes.agent_protocol import base_router
LOG = structlog.get_logger()
class Agent:
def get_agent_app(self, router: APIRouter = base_router) -> FastAPI:
"""
Start the agent server.
"""
app = FastAPI()
# Add CORS middleware
origins = [
"http://localhost:5000",
"http://127.0.0.1:5000",
"http://localhost:8000",
"http://127.0.0.1:8000",
"http://localhost:8080",
"http://127.0.0.1:8080",
# Add any other origins you want to whitelist
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(router, prefix="/api/v1")
app.add_middleware(AgentMiddleware, agent=self)
app.add_middleware(
RawContextMiddleware,
plugins=(
# TODO (suchintan): We should set these up
ExecutionDatePlugin(),
# RequestIdPlugin(),
# UserAgentPlugin(),
),
)
@app.exception_handler(Exception)
async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
LOG.exception("Unexpected error in agent server.", exc_info=exc)
return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"})
@app.middleware("http")
async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
request_id = str(uuid.uuid4())
skyvern_context.set(SkyvernContext(request_id=request_id))
try:
return await call_next(request)
finally:
skyvern_context.reset()
return app
class AgentMiddleware:
"""
Middleware that injects the agent instance into the request scope.
"""
def __init__(self, app: FastAPI, agent: Agent):
self.app = app
self.agent = agent
async def __call__(self, scope, receive, send): # type: ignore
scope["agent"] = self.agent
await self.app(scope, receive, send)
class ExecutionDatePlugin(Plugin):
key = "execution_date"
async def process_request(self, request: Request | HTTPConnection) -> datetime:
return datetime.now()
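
A minimal sketch of serving the app above, assuming uvicorn is installed (uvicorn is not imported by this module; host and port are placeholders):

    import uvicorn

    agent = Agent()
    uvicorn.run(agent.get_agent_app(), host="127.0.0.1", port=8000)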

@@ -0,0 +1,134 @@
from enum import StrEnum
from typing import Any, Callable
from urllib.parse import urlparse
import aioboto3
import structlog
from aiobotocore.client import AioBaseClient
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class AWSClientType(StrEnum):
S3 = "s3"
SECRETS_MANAGER = "secretsmanager"
def execute_with_async_client(client_type: AWSClientType) -> Callable:
def decorator(f: Callable) -> Callable:
async def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
self = args[0]
assert isinstance(self, AsyncAWSClient)
session = aioboto3.Session()
async with session.client(client_type) as client:
return await f(*args, client=client, **kwargs)
return wrapper
return decorator
class AsyncAWSClient:
@execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
async def get_secret(self, secret_name: str, client: AioBaseClient = None) -> str | None:
try:
response = await client.get_secret_value(SecretId=secret_name)
return response["SecretString"]
except Exception as e:
try:
error_code = e.response["Error"]["Code"] # type: ignore
except Exception:
error_code = "failed-to-get-error-code"
LOG.exception("Failed to get secret.", secret_name=secret_name, error_code=error_code, exc_info=True)
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def upload_file(self, uri: str, data: bytes, client: AioBaseClient = None) -> str | None:
try:
parsed_uri = S3Uri(uri)
await client.put_object(Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key)
LOG.debug("Upload file success", uri=uri)
return uri
except Exception:
LOG.exception("S3 upload failed.", uri=uri)
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
try:
parsed_uri = S3Uri(uri)
await client.upload_file(file_path, parsed_uri.bucket, parsed_uri.key)
LOG.info("Upload file from path success", uri=uri)
except Exception:
LOG.exception("S3 upload failed.", uri=uri)
@execute_with_async_client(client_type=AWSClientType.S3)
async def download_file(self, uri: str, client: AioBaseClient = None) -> bytes | None:
try:
parsed_uri = S3Uri(uri)
response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
return await response["Body"].read()
except Exception:
LOG.exception("S3 download failed", uri=uri)
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def create_presigned_url(self, uri: str, client: AioBaseClient = None) -> str | None:
try:
parsed_uri = S3Uri(uri)
url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
ExpiresIn=SettingsManager.get_settings().PRESIGNED_URL_EXPIRATION,
)
return url
except Exception:
LOG.exception("Failed to create presigned url.", uri=uri)
return None
class S3Uri(object):
# From: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path
"""
>>> s = S3Uri("s3://bucket/hello/world")
>>> s.bucket
'bucket'
>>> s.key
'hello/world'
>>> s.uri
's3://bucket/hello/world'
>>> s = S3Uri("s3://bucket/hello/world?qwe1=3#ddd")
>>> s.bucket
'bucket'
>>> s.key
'hello/world?qwe1=3#ddd'
>>> s.uri
's3://bucket/hello/world?qwe1=3#ddd'
>>> s = S3Uri("s3://bucket/hello/world#foo?bar=2")
>>> s.key
'hello/world#foo?bar=2'
>>> s.uri
's3://bucket/hello/world#foo?bar=2'
"""
def __init__(self, uri: str) -> None:
self._parsed = urlparse(uri, allow_fragments=False)
@property
def bucket(self) -> str:
return self._parsed.netloc
@property
def key(self) -> str:
if self._parsed.query:
return self._parsed.path.lstrip("/") + "?" + self._parsed.query
else:
return self._parsed.path.lstrip("/")
@property
def uri(self) -> str:
return self._parsed.geturl()

@@ -0,0 +1,25 @@
from typing import Callable
from pydantic import BaseModel
openai_model_to_price_lambdas = {
"gpt-4-vision-preview": (0.01, 0.03),
"gpt-4-1106-preview": (0.01, 0.03),
"gpt-3.5-turbo": (0.001, 0.002),
"gpt-3.5-turbo-1106": (0.001, 0.002),
}
class ChatCompletionPrice(BaseModel):
input_token_count: int
output_token_count: int
openai_model_to_price_lambda: Callable[[int, int], float]
def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
input_token_price, output_token_price = openai_model_to_price_lambdas[model_name]
super().__init__(
input_token_count=input_token_count,
output_token_count=output_token_count,
openai_model_to_price_lambda=lambda input_token, output_token: input_token_price * input_token / 1000
+ output_token_price * output_token / 1000,
)
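
A worked example of the per-1K-token pricing above, following the gpt-4-vision-preview row of the table:

    price = ChatCompletionPrice(input_token_count=1000, output_token_count=500, model_name="gpt-4-vision-preview")
    cost = price.openai_model_to_price_lambda(1000, 500)
    assert abs(cost - 0.025) < 1e-9  # 0.01 * 1000/1000 + 0.03 * 500/1000 = $0.025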

@@ -0,0 +1,47 @@
import os
import tempfile
import zipfile
from urllib.parse import urlparse
import requests
import structlog
LOG = structlog.get_logger()
def download_file(url: str) -> str | None:
# Send an HTTP request to the URL of the file, stream=True to prevent loading the content at once into memory
r = requests.get(url, stream=True)
# Check if the request is successful
if r.status_code == 200:
# Parse the URL
a = urlparse(url)
# Get the file name
temp_dir = tempfile.mkdtemp(prefix="skyvern_downloads_")
file_name = os.path.basename(a.path)
file_path = os.path.join(temp_dir, file_name)
LOG.info(f"Downloading file to {file_path}")
with open(file_path, "wb") as f:
# Write the content of the request into the file
for chunk in r.iter_content(1024):
f.write(chunk)
LOG.info(f"File downloaded successfully to {file_path}")
return file_path
else:
LOG.error(f"Failed to download file, status code: {r.status_code}")
return None
def zip_files(files_path: str, zip_file_path: str) -> str:
with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(files_path):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, files_path) # Relative path within the zip
zipf.write(file_path, arcname)
return zip_file_path
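
A usage sketch for the two helpers above; the URL and the zip path are placeholder assumptions:

    if (downloaded := download_file("https://example.com/report.pdf")) is not None:
        archive = zip_files(os.path.dirname(downloaded), "/tmp/skyvern_download.zip")
        LOG.info(f"Created archive at {archive}")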

@@ -0,0 +1,221 @@
import base64
import json
import random
from datetime import datetime, timedelta
from typing import Any
import commentjson
import openai
import structlog
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
from skyvern.forge import app
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class OpenAIKeyClientWrapper:
client: AsyncOpenAI
key: str
remaining_requests: int | None
def __init__(self, key: str, remaining_requests: int | None) -> None:
self.key = key
self.remaining_requests = remaining_requests
self.updated_at = datetime.utcnow()
self.client = AsyncOpenAI(api_key=self.key)
def update_remaining_requests(self, remaining_requests: int | None) -> None:
self.remaining_requests = remaining_requests
self.updated_at = datetime.utcnow()
def is_available(self) -> bool:
# If remaining_requests is None, then it's the first time we're trying this key
# so we can assume it's available, otherwise we check if it's greater than 0
if self.remaining_requests is None:
return True
if self.remaining_requests > 0:
return True
# If we haven't checked this key in over 1 minute, consider it available again.
# Most of our failures are because of tokens-per-minute (TPM) limits.
if self.updated_at < (datetime.utcnow() - timedelta(minutes=1)):
return True
return False
class OpenAIClientManager:
# TODO: Support other models for requests without screenshots; track rate limits for each model and key as well, if any.
clients: list[OpenAIKeyClientWrapper]
def __init__(self, api_keys: list[str] = SettingsManager.get_settings().OPENAI_API_KEYS) -> None:
self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]
def get_available_client(self) -> OpenAIKeyClientWrapper | None:
available_clients = [client for client in self.clients if client.is_available()]
if not available_clients:
return None
# Randomly select an available client to distribute requests across our accounts
return random.choice(available_clients)
async def content_builder(
self,
step: Step,
screenshots: list[bytes] | None = None,
prompt: str | None = None,
) -> list[dict[str, Any]]:
content: list[dict[str, Any]] = []
if prompt is not None:
content.append(
{
"type": "text",
"text": prompt,
}
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_PROMPT,
data=prompt.encode("utf-8"),
)
if screenshots:
for screenshot in screenshots:
encoded_image = base64.b64encode(screenshot).decode("utf-8")
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encoded_image}",
},
}
)
# create artifact for each image
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)
return content
async def chat_completion(
self,
step: Step,
model: str = "gpt-4-vision-preview",
max_tokens: int = 4096,
temperature: int = 0,
screenshots: list[bytes] | None = None,
prompt: str | None = None,
) -> dict[str, Any]:
LOG.info(
f"Sending LLM request",
task_id=step.task_id,
step_id=step.step_id,
num_screenshots=len(screenshots) if screenshots else 0,
)
messages = [
{
"role": "user",
"content": await self.content_builder(
step=step,
screenshots=screenshots,
prompt=prompt,
),
}
]
chat_completion_kwargs = {
"model": model,
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
}
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_REQUEST,
data=json.dumps(chat_completion_kwargs).encode("utf-8"),
)
available_client = self.get_available_client()
if available_client is None:
raise NoAvailableOpenAIClients()
try:
response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
except openai.RateLimitError as e:
if e.code == 429:
    raise OpenAIRequestTooBigError(e.message)
# Otherwise, assume this key is not available anymore and rotate to another one.
LOG.warning(
"OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
)
available_client.update_remaining_requests(remaining_requests=0)
available_client = self.get_available_client()
if available_client is None:
raise NoAvailableOpenAIClients()
return await self.chat_completion(
step=step,
model=model,
max_tokens=max_tokens,
temperature=temperature,
screenshots=screenshots,
prompt=prompt,
)
# TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
# use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
# x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys
# If we get a response, we can assume the key is available and update the remaining requests
if ratelimit_remaining_requests := response.headers.get("x-ratelimit-remaining-requests"):
    available_client.update_remaining_requests(remaining_requests=int(ratelimit_remaining_requests))
else:
    LOG.warning("Invalid x-ratelimit-remaining-requests header from OpenAI", headers=response.headers)
chat_completion = response.parse()
if chat_completion.usage is not None:
# TODO (Suchintan): Is this bad design?
step = await app.DATABASE.update_step(
step_id=step.step_id,
task_id=step.task_id,
organization_id=step.organization_id,
chat_completion_price=ChatCompletionPrice(
input_token_count=chat_completion.usage.prompt_tokens,
output_token_count=chat_completion.usage.completion_tokens,
model_name=model,
),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE,
data=chat_completion.model_dump_json(indent=2).encode("utf-8"),
)
parsed_response = self.parse_response(chat_completion)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
)
return parsed_response
def parse_response(self, response: ChatCompletion) -> dict[str, str]:
try:
content = response.choices[0].message.content
if not content:
    raise Exception("openai response content is empty")
content = content.replace("```json", "").replace("```", "")
return commentjson.loads(content)
except Exception as e:
raise InvalidOpenAIResponseFormat(str(response)) from e

@@ -0,0 +1,112 @@
import asyncio
import time
from collections import defaultdict
import structlog
from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.id import generate_artifact_id
from skyvern.forge.sdk.models import Step
LOG = structlog.get_logger(__name__)
class ArtifactManager:
# task_id -> list of aio_tasks for uploading artifacts
upload_aiotasks_map: dict[str, list[asyncio.Task[None]]] = defaultdict(list)
async def create_artifact(
self, step: Step, artifact_type: ArtifactType, data: bytes | None = None, path: str | None = None
) -> str:
# TODO (kerem): Which is better?
# current: (disadvantage: we create the artifact_id UUID here)
# 1. generate artifact_id UUID here
# 2. build uri with artifact_id, step_id, task_id, artifact_type
# 3. create artifact in db using artifact_id, step_id, task_id, artifact_type, uri
# 4. store artifact in storage
# alternative: (disadvantage: two db calls)
# 1. create artifact in db without the URI
# 2. build uri with artifact_id, step_id, task_id, artifact_type
# 3. update artifact in db with the URI
# 4. store artifact in storage
if data is None and path is None:
raise ValueError("Either data or path must be provided to create an artifact.")
if data and path:
raise ValueError("Both data and path cannot be provided to create an artifact.")
artifact_id = generate_artifact_id()
uri = app.STORAGE.build_uri(artifact_id, step, artifact_type)
artifact = await app.DATABASE.create_artifact(
artifact_id,
step.step_id,
step.task_id,
artifact_type,
uri,
organization_id=step.organization_id,
)
if data:
# Fire and forget
aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
self.upload_aiotasks_map[step.task_id].append(aio_task)
elif path:
# Fire and forget
aio_task = asyncio.create_task(app.STORAGE.store_artifact_from_path(artifact, path))
self.upload_aiotasks_map[step.task_id].append(aio_task)
return artifact_id
async def update_artifact_data(self, artifact_id: str | None, organization_id: str | None, data: bytes) -> None:
if not artifact_id or not organization_id:
return None
artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id)
if not artifact:
return
# Fire and forget
aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
self.upload_aiotasks_map[artifact.task_id].append(aio_task)
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
return await app.STORAGE.retrieve_artifact(artifact)
async def get_share_link(self, artifact: Artifact) -> str | None:
return await app.STORAGE.get_share_link(artifact)
async def wait_for_upload_aiotasks_for_task(self, task_id: str) -> None:
try:
st = time.time()
async with asyncio.timeout(30):
await asyncio.gather(
*[aio_task for aio_task in self.upload_aiotasks_map[task_id] if not aio_task.done()]
)
LOG.info(
f"S3 upload tasks for task_id={task_id} completed in {time.time() - st:.2f}s",
task_id=task_id,
duration=time.time() - st,
)
except asyncio.TimeoutError:
LOG.error(f"Timeout (30s) while waiting for upload tasks for task_id={task_id}", task_id=task_id)
del self.upload_aiotasks_map[task_id]
async def wait_for_upload_aiotasks_for_tasks(self, task_ids: list[str]) -> None:
try:
st = time.time()
async with asyncio.timeout(30):
await asyncio.gather(
*[
aio_task
for task_id in task_ids
for aio_task in self.upload_aiotasks_map[task_id]
if not aio_task.done()
]
)
LOG.info(
f"S3 upload tasks for task_ids={task_ids} completed in {time.time() - st:.2f}s",
task_ids=task_ids,
duration=time.time() - st,
)
except asyncio.TimeoutError:
LOG.error(f"Timeout (30s) while waiting for upload tasks for task_ids={task_ids}", task_ids=task_ids)
for task_id in task_ids:
del self.upload_aiotasks_map[task_id]

@@ -0,0 +1,78 @@
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
class ArtifactType(StrEnum):
RECORDING = "recording"
# DEPRECATED. Use SCREENSHOT_LLM, SCREENSHOT_ACTION, or SCREENSHOT_FINAL instead.
SCREENSHOT = "screenshot"
# USE THESE for screenshots
SCREENSHOT_LLM = "screenshot_llm"
SCREENSHOT_ACTION = "screenshot_action"
SCREENSHOT_FINAL = "screenshot_final"
LLM_PROMPT = "llm_prompt"
LLM_REQUEST = "llm_request"
LLM_RESPONSE = "llm_response"
LLM_RESPONSE_PARSED = "llm_response_parsed"
VISIBLE_ELEMENTS_ID_XPATH_MAP = "visible_elements_id_xpath_map"
VISIBLE_ELEMENTS_TREE = "visible_elements_tree"
VISIBLE_ELEMENTS_TREE_TRIMMED = "visible_elements_tree_trimmed"
# DEPRECATED. Use HTML_SCRAPE or HTML_ACTION instead.
HTML = "html"
# USE THESE for htmls
HTML_SCRAPE = "html_scrape"
HTML_ACTION = "html_action"
# Debugging
TRACE = "trace"
HAR = "har"
class Artifact(BaseModel):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
examples=["2023-01-01T00:00:00Z"],
json_encoders={datetime: lambda v: v.isoformat()},
)
modified_at: datetime = Field(
...,
description="The modification datetime of the task.",
examples=["2023-01-01T00:00:00Z"],
json_encoders={datetime: lambda v: v.isoformat()},
)
artifact_id: str = Field(
...,
description="The ID of the task artifact.",
examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
)
task_id: str = Field(
...,
description="The ID of the task this artifact belongs to.",
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
)
step_id: str = Field(
...,
description="The ID of the task step this artifact belongs to.",
examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
)
artifact_type: ArtifactType = Field(
...,
description="The type of the artifact.",
examples=["screenshot"],
)
uri: str = Field(
...,
description="The URI of the artifact.",
examples=["/Users/skyvern/hello/world.png"],
)
organization_id: str | None = None

@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.models import Step
# TODO: This should be a part of the ArtifactType model
FILE_EXTENSION_MAP: dict[ArtifactType, str] = {
ArtifactType.RECORDING: "webm",
ArtifactType.SCREENSHOT_LLM: "png",
ArtifactType.SCREENSHOT_ACTION: "png",
ArtifactType.SCREENSHOT_FINAL: "png",
ArtifactType.LLM_PROMPT: "txt",
ArtifactType.LLM_REQUEST: "json",
ArtifactType.LLM_RESPONSE: "json",
ArtifactType.LLM_RESPONSE_PARSED: "json",
ArtifactType.VISIBLE_ELEMENTS_ID_XPATH_MAP: "json",
ArtifactType.VISIBLE_ELEMENTS_TREE: "json",
ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED: "json",
ArtifactType.HTML_SCRAPE: "html",
ArtifactType.HTML_ACTION: "html",
ArtifactType.TRACE: "zip",
ArtifactType.HAR: "har",
}
class BaseStorage(ABC):
@abstractmethod
def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
pass
@abstractmethod
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
pass
@abstractmethod
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
pass
@abstractmethod
async def get_share_link(self, artifact: Artifact) -> str | None:
pass
@abstractmethod
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
pass

@@ -0,0 +1,14 @@
from skyvern.forge.sdk.artifact.storage.base import BaseStorage
from skyvern.forge.sdk.artifact.storage.local import LocalStorage
class StorageFactory:
__storage: BaseStorage = LocalStorage()
@staticmethod
def set_storage(storage: BaseStorage) -> None:
StorageFactory.__storage = storage
@staticmethod
def get_storage() -> BaseStorage:
return StorageFactory.__storage
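
A sketch of swapping the process-wide backend; the artifact path is a placeholder assumption:

    StorageFactory.set_storage(LocalStorage(artifact_path="/tmp/skyvern-artifacts"))
    storage = StorageFactory.get_storage()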

@@ -0,0 +1,66 @@
from datetime import datetime
from pathlib import Path
from urllib.parse import unquote, urlparse
import structlog
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class LocalStorage(BaseStorage):
def __init__(self, artifact_path: str = SettingsManager.get_settings().ARTIFACT_STORAGE_PATH) -> None:
self.artifact_path = artifact_path
def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
file_ext = FILE_EXTENSION_MAP[artifact_type]
return (
    f"file://{self.artifact_path}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}"
    f"/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
)
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
file_path = None
try:
file_path = Path(self._parse_uri_to_path(artifact.uri))
self._create_directories_if_not_exists(file_path)
with open(file_path, "wb") as f:
f.write(data)
except Exception:
LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
file_path = None
try:
file_path = Path(self._parse_uri_to_path(artifact.uri))
self._create_directories_if_not_exists(file_path)
Path(path).replace(file_path)
except Exception:
LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
file_path = None
try:
file_path = self._parse_uri_to_path(artifact.uri)
with open(file_path, "rb") as f:
return f.read()
except Exception:
LOG.exception("Failed to retrieve local artifact.", file_path=file_path, artifact=artifact)
return None
async def get_share_link(self, artifact: Artifact) -> str:
return artifact.uri
@staticmethod
def _parse_uri_to_path(uri: str) -> str:
parsed_uri = urlparse(uri)
if parsed_uri.scheme != "file":
raise ValueError("Invalid URI scheme: {parsed_uri.scheme} expected: file")
path = parsed_uri.netloc + parsed_uri.path
return unquote(path)
@staticmethod
def _create_directories_if_not_exists(path_including_file_name: Path) -> None:
path = path_including_file_name.parent
path.mkdir(parents=True, exist_ok=True)
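
A round-trip sketch for the URI helper above; the path is a placeholder:

    uri = "file:///tmp/skyvern/artifacts/tsk_1/00_0_stp_1/a_1.png"
    assert LocalStorage._parse_uri_to_path(uri) == "/tmp/skyvern/artifacts/tsk_1/00_0_stp_1/a_1.png"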

@@ -0,0 +1,41 @@
import hashlib
import hmac
from datetime import datetime, timedelta
from typing import Any, Union
from jose import jwt
from skyvern.forge.sdk.settings_manager import SettingsManager
ALGORITHM = "HS256"
def create_access_token(
subject: Union[str, Any],
expires_delta: timedelta | None = None,
) -> str:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=SettingsManager.get_settings().ACCESS_TOKEN_EXPIRE_MINUTES,
)
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, SettingsManager.get_settings().SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def generate_skyvern_signature(
payload: str,
api_key: str,
) -> str:
"""
Generate Skyvern signature.
:param payload: the request body
:param api_key: the Skyvern api key
:return: the Skyvern signature
"""
hash_obj = hmac.new(api_key.encode("utf-8"), msg=payload.encode("utf-8"), digestmod=hashlib.sha256)
return hash_obj.hexdigest()
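
A verification sketch for the signature above (validate_skyvern_signature is not part of this commit); hmac.compare_digest gives a constant-time comparison:

    def validate_skyvern_signature(payload: str, api_key: str, signature: str) -> bool:
        expected = generate_skyvern_signature(payload, api_key)
        return hmac.compare_digest(expected, signature)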

@@ -0,0 +1,73 @@
from contextvars import ContextVar
from dataclasses import dataclass
@dataclass
class SkyvernContext:
request_id: str | None = None
organization_id: str | None = None
task_id: str | None = None
workflow_id: str | None = None
workflow_run_id: str | None = None
max_steps_override: int | None = None
def __repr__(self) -> str:
return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, max_steps_override={self.max_steps_override})"
def __str__(self) -> str:
return self.__repr__()
_context: ContextVar[SkyvernContext | None] = ContextVar(
"Global context",
default=None,
)
def current() -> SkyvernContext | None:
"""
Get the current context
Returns:
The current context, or None if there is none
"""
return _context.get()
def ensure_context() -> SkyvernContext:
"""
Get the current context, or raise an error if there is none
Returns:
The current context if there is one
Raises:
RuntimeError: If there is no current context
"""
context = current()
if context is None:
raise RuntimeError("No skyvern context")
return context
def set(context: SkyvernContext) -> None:
"""
Set the current context
Args:
context: The context to set
Returns:
None
"""
_context.set(context)
def reset() -> None:
"""
Reset the current context
Returns:
None
"""
_context.set(None)
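
A usage sketch mirroring the request middleware: set a context for a unit of work, read it back, then reset (the IDs are placeholders):

    set(SkyvernContext(request_id="req-123", task_id="tsk-456"))
    try:
        assert ensure_context().request_id == "req-123"
    finally:
        reset()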

@@ -0,0 +1,900 @@
from datetime import datetime
from typing import Any
import structlog
from sqlalchemy import and_, create_engine, delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ArtifactModel,
AWSSecretParameterModel,
OrganizationAuthTokenModel,
OrganizationModel,
StepModel,
TaskModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.db.utils import (
_custom_json_serializer,
convert_to_artifact,
convert_to_aws_secret_parameter,
convert_to_organization,
convert_to_organization_auth_token,
convert_to_step,
convert_to_task,
convert_to_workflow,
convert_to_workflow_parameter,
convert_to_workflow_run,
convert_to_workflow_run_parameter,
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus
from skyvern.webeye.actions.models import AgentStepOutput
LOG = structlog.get_logger()
class AgentDB:
def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
super().__init__()
self.debug_enabled = debug_enabled
self.engine = create_engine(database_string, json_serializer=_custom_json_serializer)
self.Session = sessionmaker(bind=self.engine)
async def create_task(
self,
url: str,
navigation_goal: str | None,
data_extraction_goal: str | None,
navigation_payload: dict[str, Any] | list | str | None,
webhook_callback_url: str | None = None,
organization_id: str | None = None,
proxy_location: ProxyLocation | None = None,
extracted_information_schema: dict[str, Any] | list | str | None = None,
workflow_run_id: str | None = None,
order: int | None = None,
retry: int | None = None,
) -> Task:
try:
with self.Session() as session:
new_task = TaskModel(
status="created",
url=url,
webhook_callback_url=webhook_callback_url,
navigation_goal=navigation_goal,
data_extraction_goal=data_extraction_goal,
navigation_payload=navigation_payload,
organization_id=organization_id,
proxy_location=proxy_location,
extracted_information_schema=extracted_information_schema,
workflow_run_id=workflow_run_id,
order=order,
retry=retry,
)
session.add(new_task)
session.commit()
session.refresh(new_task)
return convert_to_task(new_task, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_step(
self,
task_id: str,
order: int,
retry_index: int,
organization_id: str | None = None,
) -> Step:
try:
with self.Session() as session:
new_step = StepModel(
task_id=task_id,
order=order,
retry_index=retry_index,
status="created",
organization_id=organization_id,
)
session.add(new_step)
session.commit()
session.refresh(new_step)
return convert_to_step(new_step, debug_enabled=self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_artifact(
self,
artifact_id: str,
step_id: str,
task_id: str,
artifact_type: str,
uri: str,
organization_id: str | None = None,
) -> Artifact:
try:
with self.Session() as session:
new_artifact = ArtifactModel(
artifact_id=artifact_id,
task_id=task_id,
step_id=step_id,
artifact_type=artifact_type,
uri=uri,
organization_id=organization_id,
)
session.add(new_artifact)
session.commit()
session.refresh(new_artifact)
return convert_to_artifact(new_artifact, self.debug_enabled)
except SQLAlchemyError:
LOG.exception("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.exception("UnexpectedError", exc_info=True)
raise
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
"""Get a task by its id"""
try:
with self.Session() as session:
if task_obj := (
session.query(TaskModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.first()
):
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
try:
with self.Session() as session:
if step := (
session.query(StepModel)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.first()
):
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_task_steps(self, task_id: str, organization_id: str | None = None) -> list[Step]:
try:
with self.Session() as session:
if (
steps := session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
.all()
):
return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
else:
return []
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> list[StepModel]:
try:
with self.Session() as session:
return (
session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
.all()
)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
try:
with self.Session() as session:
if step := (
session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order.desc())
.first()
):
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def update_step(
self,
task_id: str,
step_id: str,
status: StepStatus | None = None,
output: AgentStepOutput | None = None,
is_last: bool | None = None,
retry_index: int | None = None,
organization_id: str | None = None,
chat_completion_price: ChatCompletionPrice | None = None,
) -> Step:
try:
with self.Session() as session:
if (
step := session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.first()
):
if status is not None:
step.status = status
if output is not None:
step.output = output.model_dump()
if is_last is not None:
step.is_last = is_last
if retry_index is not None:
step.retry_index = retry_index
if chat_completion_price is not None:
if step.input_token_count is None:
step.input_token_count = 0
if step.output_token_count is None:
step.output_token_count = 0
step.input_token_count += chat_completion_price.input_token_count
step.output_token_count += chat_completion_price.output_token_count
step.step_cost = chat_completion_price.openai_model_to_price_lambda(
step.input_token_count, step.output_token_count
)
session.commit()
updated_step = await self.get_step(task_id, step_id, organization_id)
if not updated_step:
raise NotFoundError("Step not found")
return updated_step
else:
raise NotFoundError("Step not found")
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except NotFoundError:
LOG.error("NotFoundError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def update_task(
self,
task_id: str,
status: TaskStatus,
extracted_information: dict[str, Any] | list | str | None = None,
failure_reason: str | None = None,
organization_id: str | None = None,
) -> Task:
try:
with self.Session() as session:
if (
task := session.query(TaskModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.first()
):
task.status = status
if extracted_information is not None:
task.extracted_information = extracted_information
if failure_reason is not None:
task.failure_reason = failure_reason
session.commit()
updated_task = await self.get_task(task_id, organization_id=organization_id)
if not updated_task:
raise NotFoundError("Task not found")
return updated_task
else:
raise NotFoundError("Task not found")
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except NotFoundError:
LOG.error("NotFoundError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_tasks(self, page: int = 1, page_size: int = 10, organization_id: str | None = None) -> list[Task]:
"""
Get all tasks.
:param page: Starts at 1
:param page_size:
:return:
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
try:
with self.Session() as session:
db_page = page - 1 # offset logic is 0 based
tasks = (
session.query(TaskModel)
.filter_by(organization_id=organization_id)
.order_by(TaskModel.created_at.desc())
.limit(page_size)
.offset(db_page * page_size)
.all()
)
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_organization(self, organization_id: str) -> Organization | None:
try:
with self.Session() as session:
if organization := (
session.query(OrganizationModel).filter_by(organization_id=organization_id).first()
):
return convert_to_organization(organization)
else:
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_organization(
self,
organization_name: str,
webhook_callback_url: str | None = None,
max_steps_per_run: int | None = None,
) -> Organization:
with self.Session() as session:
org = OrganizationModel(
organization_name=organization_name,
webhook_callback_url=webhook_callback_url,
max_steps_per_run=max_steps_per_run,
)
session.add(org)
session.commit()
session.refresh(org)
return convert_to_organization(org)
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
) -> OrganizationAuthToken | None:
try:
with self.Session() as session:
if token := (
session.query(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(valid=True)
.first()
):
return convert_to_organization_auth_token(token)
else:
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def validate_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str,
) -> OrganizationAuthToken | None:
try:
with self.Session() as session:
if token_obj := (
session.query(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(token=token)
.filter_by(valid=True)
.first()
):
return convert_to_organization_auth_token(token_obj)
else:
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str,
) -> OrganizationAuthToken:
with self.Session() as session:
token = OrganizationAuthTokenModel(
organization_id=organization_id,
token_type=token_type,
token=token,
)
session.add(token)
session.commit()
session.refresh(token)
return convert_to_organization_auth_token(token)
async def get_artifacts_for_task_step(
self,
task_id: str,
step_id: str,
organization_id: str | None = None,
) -> list[Artifact]:
try:
with self.Session() as session:
if artifacts := (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.all()
):
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
else:
return []
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_artifact_by_id(
self,
artifact_id: str,
organization_id: str,
) -> Artifact | None:
try:
with self.Session() as session:
if artifact := (
session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.filter_by(organization_id=organization_id)
.first()
):
return convert_to_artifact(artifact, self.debug_enabled)
else:
return None
except SQLAlchemyError:
LOG.exception("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.exception("UnexpectedError", exc_info=True)
raise
async def get_artifact(
self,
task_id: str,
step_id: str,
artifact_type: ArtifactType,
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
artifact = (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.filter_by(artifact_type=artifact_type)
.order_by(ArtifactModel.created_at.desc())
.first()
)
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_artifact_for_workflow_run(
self,
workflow_run_id: str,
artifact_type: ArtifactType,
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
artifact = (
session.query(ArtifactModel)
.join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
.filter(TaskModel.workflow_run_id == workflow_run_id)
.filter(ArtifactModel.artifact_type == artifact_type)
.filter(ArtifactModel.organization_id == organization_id)
.order_by(ArtifactModel.created_at.desc())
.first()
)
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_latest_artifact(
self,
task_id: str,
step_id: str | None = None,
artifact_types: list[ArtifactType] | None = None,
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
artifact_query = session.query(ArtifactModel).filter_by(task_id=task_id)
if step_id:
artifact_query = artifact_query.filter_by(step_id=step_id)
if organization_id:
artifact_query = artifact_query.filter_by(organization_id=organization_id)
if artifact_types:
artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))
artifact = artifact_query.order_by(ArtifactModel.created_at.desc()).first()
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.exception("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.exception("UnexpectedError", exc_info=True)
raise
async def get_latest_task_by_workflow_id(
self,
organization_id: str,
workflow_id: str,
before: datetime | None = None,
) -> Task | None:
try:
with self.Session() as session:
query = (
session.query(TaskModel)
.filter_by(organization_id=organization_id)
.filter_by(workflow_id=workflow_id)
)
if before:
query = query.filter(TaskModel.created_at < before)
task = query.order_by(TaskModel.created_at.desc()).first()
if task:
return convert_to_task(task, debug_enabled=self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def create_workflow(
self,
organization_id: str,
title: str,
workflow_definition: dict[str, Any],
description: str | None = None,
) -> Workflow:
with self.Session() as session:
workflow = WorkflowModel(
organization_id=organization_id,
title=title,
description=description,
workflow_definition=workflow_definition,
)
session.add(workflow)
session.commit()
session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
async def get_workflow(self, workflow_id: str) -> Workflow | None:
try:
with self.Session() as session:
if workflow := session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first():
return convert_to_workflow(workflow, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def update_workflow(
self,
workflow_id: str,
title: str | None = None,
description: str | None = None,
workflow_definition: dict[str, Any] | None = None,
) -> Workflow | None:
with self.Session() as session:
workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
if workflow:
if title:
workflow.title = title
if description:
workflow.description = description
if workflow_definition:
workflow.workflow_definition = workflow_definition
session.commit()
session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
LOG.error("Workflow not found, nothing to update", workflow_id=workflow_id)
return None
async def create_workflow_run(
self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
) -> WorkflowRun:
try:
with self.Session() as session:
workflow_run = WorkflowRunModel(
workflow_id=workflow_id,
proxy_location=proxy_location,
status="created",
webhook_callback_url=webhook_callback_url,
)
session.add(workflow_run)
session.commit()
session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def update_workflow_run(self, workflow_run_id: str, status: WorkflowRunStatus) -> WorkflowRun | None:
with self.Session() as session:
workflow_run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
if workflow_run:
workflow_run.status = status
session.commit()
session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
return None
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
try:
with self.Session() as session:
if workflow_run := session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first():
return convert_to_workflow_run(workflow_run)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
try:
with self.Session() as session:
workflow_runs = session.query(WorkflowRunModel).filter_by(workflow_id=workflow_id).all()
return [convert_to_workflow_run(run) for run in workflow_runs]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def create_workflow_parameter(
self,
workflow_id: str,
workflow_parameter_type: WorkflowParameterType,
key: str,
default_value: Any,
description: str | None = None,
) -> WorkflowParameter:
try:
with self.Session() as session:
workflow_parameter = WorkflowParameterModel(
workflow_id=workflow_id,
workflow_parameter_type=workflow_parameter_type,
key=key,
default_value=default_value,
description=description,
)
session.add(workflow_parameter)
session.commit()
session.refresh(workflow_parameter)
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def create_aws_secret_parameter(
self,
workflow_id: str,
key: str,
aws_key: str,
description: str | None = None,
) -> AWSSecretParameter:
with self.Session() as session:
aws_secret_parameter = AWSSecretParameterModel(
workflow_id=workflow_id,
key=key,
aws_key=aws_key,
description=description,
)
session.add(aws_secret_parameter)
session.commit()
session.refresh(aws_secret_parameter)
return convert_to_aws_secret_parameter(aws_secret_parameter)
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
try:
with self.Session() as session:
workflow_parameters = session.query(WorkflowParameterModel).filter_by(workflow_id=workflow_id).all()
return [convert_to_workflow_parameter(parameter) for parameter in workflow_parameters]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
try:
with self.Session() as session:
if workflow_parameter := (
session.query(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id).first()
):
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def create_workflow_run_parameter(
self, workflow_run_id: str, workflow_parameter_id: str, value: Any
) -> WorkflowRunParameter:
try:
with self.Session() as session:
workflow_run_parameter = WorkflowRunParameterModel(
workflow_run_id=workflow_run_id,
workflow_parameter_id=workflow_parameter_id,
value=value,
)
session.add(workflow_run_parameter)
session.commit()
session.refresh(workflow_run_parameter)
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
if not workflow_parameter:
raise WorkflowParameterNotFound(workflow_parameter_id)
return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflow_run_parameters(
self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
try:
with self.Session() as session:
workflow_run_parameters = (
session.query(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id).all()
)
results = []
for workflow_run_parameter in workflow_run_parameters:
workflow_parameter = await self.get_workflow_parameter(workflow_run_parameter.workflow_parameter_id)
if not workflow_parameter:
raise WorkflowParameterNotFound(
workflow_parameter_id=workflow_run_parameter.workflow_parameter_id
)
results.append(
(
workflow_parameter,
convert_to_workflow_run_parameter(
workflow_run_parameter, workflow_parameter, self.debug_enabled
),
)
)
return results
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
try:
with self.Session() as session:
if task := (
session.query(TaskModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(TaskModel.created_at.desc())
.first()
):
return convert_to_task(task, debug_enabled=self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
try:
with self.Session() as session:
tasks = (
session.query(TaskModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(TaskModel.created_at)
.all()
)
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
with self.Session() as session:
# delete artifacts by filtering organization_id and task_id
stmt = delete(ArtifactModel).where(
and_(
ArtifactModel.organization_id == organization_id,
ArtifactModel.task_id == task_id,
)
)
session.execute(stmt)
session.commit()
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
with self.Session() as session:
# delete steps by filtering organization_id and task_id
stmt = delete(StepModel).where(
and_(
StepModel.organization_id == organization_id,
StepModel.task_id == task_id,
)
)
session.execute(stmt)
session.commit()
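
A construction sketch using a SQLite DSN for illustration (the production connection string is configured elsewhere):

    db = AgentDB("sqlite:///skyvern.db", debug_enabled=True)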

@@ -0,0 +1,15 @@
from enum import StrEnum
class OrganizationAuthTokenType(StrEnum):
api = "api"
class ScheduleRuleUnit(StrEnum):
# No support for scheduling every second
minute = "minute"
hour = "hour"
day = "day"
week = "week"
month = "month"
year = "year"

@@ -0,0 +1,2 @@
class NotFoundError(Exception):
pass

skyvern/forge/sdk/db/id.py

@@ -0,0 +1,136 @@
import hashlib
import itertools
import os
import platform
import random
import time
# 2022-06-20 00:00:00 UTC
BASE_EPOCH = 1655683200
VERSION = 0
# Number of bits
TIMESTAMP_BITS = 32
WORKER_ID_BITS = 21
SEQUENCE_BITS = 10
VERSION_BITS = 1
# Bit shifts (left)
TIMESTAMP_SHIFT = 32
WORKER_ID_SHIFT = 11
SEQUENCE_SHIFT = 1
VERSION_SHIFT = 0
SEQUENCE_MAX = (2**SEQUENCE_BITS) - 1
_sequence_start = None
SEQUENCE_COUNTER = itertools.count()
_worker_hash = None
# prefix
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
ORG_PREFIX = "o"
TASK_PREFIX = "tsk"
USER_PREFIX = "u"
STEP_PREFIX = "stp"
ARTIFACT_PREFIX = "a"
WORKFLOW_PREFIX = "w"
WORKFLOW_RUN_PREFIX = "wr"
WORKFLOW_PARAMETER_PREFIX = "wp"
AWS_SECRET_PARAMETER_PREFIX = "asp"
def generate_workflow_id() -> str:
int_id = generate_id()
return f"{WORKFLOW_PREFIX}_{int_id}"
def generate_workflow_run_id() -> str:
int_id = generate_id()
return f"{WORKFLOW_RUN_PREFIX}_{int_id}"
def generate_aws_secret_parameter_id() -> str:
int_id = generate_id()
return f"{AWS_SECRET_PARAMETER_PREFIX}_{int_id}"
def generate_workflow_parameter_id() -> str:
int_id = generate_id()
return f"{WORKFLOW_PARAMETER_PREFIX}_{int_id}"
def generate_organization_auth_token_id() -> str:
int_id = generate_id()
return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{int_id}"
def generate_org_id() -> str:
int_id = generate_id()
return f"{ORG_PREFIX}_{int_id}"
def generate_task_id() -> str:
int_id = generate_id()
return f"{TASK_PREFIX}_{int_id}"
def generate_step_id() -> str:
int_id = generate_id()
return f"{STEP_PREFIX}_{int_id}"
def generate_artifact_id() -> str:
int_id = generate_id()
return f"{ARTIFACT_PREFIX}_{int_id}"
def generate_user_id() -> str:
int_id = generate_id()
return f"{USER_PREFIX}_{int_id}"
def generate_id() -> int:
"""
generate a 64-bit int ID
"""
create_at = current_time() - BASE_EPOCH
sequence = _increment_and_get_sequence()
time_part = _mask_shift(create_at, TIMESTAMP_BITS, TIMESTAMP_SHIFT)
worker_part = _mask_shift(_get_worker_hash(), WORKER_ID_BITS, WORKER_ID_SHIFT)
sequence_part = _mask_shift(sequence, SEQUENCE_BITS, SEQUENCE_SHIFT)
version_part = _mask_shift(VERSION, VERSION_BITS, VERSION_SHIFT)
return time_part | worker_part | sequence_part | version_part
def _increment_and_get_sequence() -> int:
global _sequence_start
if _sequence_start is None:
_sequence_start = random.randint(0, SEQUENCE_MAX)
return (_sequence_start + next(SEQUENCE_COUNTER)) % SEQUENCE_MAX
def current_time() -> int:
return int(time.time())
def current_time_ms() -> int:
return int(time.time() * 1000)
def _mask_shift(value: int, mask_bits: int, shift_bits: int) -> int:
return (value & ((2**mask_bits) - 1)) << shift_bits
def _get_worker_hash() -> int:
global _worker_hash
if _worker_hash is None:
_worker_hash = _generate_worker_hash()
return _worker_hash
def _generate_worker_hash() -> int:
worker_identity = f"{platform.node()}:{os.getpid()}"
return int(hashlib.md5(worker_identity.encode()).hexdigest()[-15:], 16)
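
A decoding sketch for the ID layout above (decode_id is not part of this commit); fields are unpacked per the shift and width constants:

    def decode_id(int_id: int) -> dict[str, int]:
        return {
            "created_at": (int_id >> TIMESTAMP_SHIFT) + BASE_EPOCH,  # back to epoch seconds
            "worker": (int_id >> WORKER_ID_SHIFT) & ((2**WORKER_ID_BITS) - 1),
            "sequence": (int_id >> SEQUENCE_SHIFT) & ((2**SEQUENCE_BITS) - 1),
            "version": int_id & ((2**VERSION_BITS) - 1),
        }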

@@ -0,0 +1,172 @@
import datetime
from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, UnicodeText
from sqlalchemy.orm import DeclarativeBase
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.id import (
generate_artifact_id,
generate_aws_secret_parameter_id,
generate_org_id,
generate_organization_auth_token_id,
generate_step_id,
generate_task_id,
generate_workflow_id,
generate_workflow_parameter_id,
generate_workflow_run_id,
)
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
class Base(DeclarativeBase):
pass
class TaskModel(Base):
__tablename__ = "tasks"
task_id = Column(String, primary_key=True, index=True, default=generate_task_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
status = Column(String)
webhook_callback_url = Column(String)
url = Column(String)
navigation_goal = Column(String)
data_extraction_goal = Column(String)
navigation_payload = Column(JSON)
extracted_information = Column(JSON)
failure_reason = Column(String)
proxy_location = Column(Enum(ProxyLocation))
extracted_information_schema = Column(JSON)
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
order = Column(Integer, nullable=True)
retry = Column(Integer, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class StepModel(Base):
__tablename__ = "steps"
step_id = Column(String, primary_key=True, index=True, default=generate_step_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
task_id = Column(String, ForeignKey("tasks.task_id"))
status = Column(String)
output = Column(JSON)
order = Column(Integer)
is_last = Column(Boolean, default=False)
retry_index = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
input_token_count = Column(Integer, default=0)
output_token_count = Column(Integer, default=0)
step_cost = Column(Numeric, default=0)
class OrganizationModel(Base):
__tablename__ = "organizations"
organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
organization_name = Column(String, nullable=False)
webhook_callback_url = Column(UnicodeText)
max_steps_per_run = Column(Integer)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class OrganizationAuthTokenModel(Base):
__tablename__ = "organization_auth_tokens"
id = Column(
String,
primary_key=True,
index=True,
default=generate_organization_auth_token_id,
)
organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
token_type = Column(Enum(OrganizationAuthTokenType), nullable=False)
token = Column(String, index=True, nullable=False)
valid = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class ArtifactModel(Base):
__tablename__ = "artifacts"
artifact_id = Column(String, primary_key=True, index=True, default=generate_artifact_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
task_id = Column(String, ForeignKey("tasks.task_id"))
step_id = Column(String, ForeignKey("steps.step_id"))
artifact_type = Column(String)
uri = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class WorkflowModel(Base):
__tablename__ = "workflows"
workflow_id = Column(String, primary_key=True, index=True, default=generate_workflow_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
title = Column(String, nullable=False)
description = Column(String, nullable=True)
workflow_definition = Column(JSON, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class WorkflowRunModel(Base):
__tablename__ = "workflow_runs"
workflow_run_id = Column(String, primary_key=True, index=True, default=generate_workflow_run_id)
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=False)
status = Column(String, nullable=False)
proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class WorkflowParameterModel(Base):
__tablename__ = "workflow_parameters"
workflow_parameter_id = Column(String, primary_key=True, index=True, default=generate_workflow_parameter_id)
workflow_parameter_type = Column(String, nullable=False)
key = Column(String, nullable=False)
description = Column(String, nullable=True)
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
default_value = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class AWSSecretParameterModel(Base):
__tablename__ = "aws_secret_parameters"
aws_secret_parameter_id = Column(String, primary_key=True, index=True, default=generate_aws_secret_parameter_id)
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
key = Column(String, nullable=False)
description = Column(String, nullable=True)
aws_key = Column(String, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class WorkflowRunParameterModel(Base):
__tablename__ = "workflow_run_parameters"
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
workflow_parameter_id = Column(
String, ForeignKey("workflow_parameters.workflow_parameter_id"), primary_key=True, index=True
)
# Can be bool | int | float | str | dict | list depending on the workflow parameter type
value = Column(String, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View File

@@ -0,0 +1,220 @@
import json
import typing
import pydantic.json
import structlog
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.models import (
ArtifactModel,
AWSSecretParameterModel,
OrganizationAuthTokenModel,
OrganizationModel,
StepModel,
TaskModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowDefinition,
WorkflowRun,
WorkflowRunParameter,
WorkflowRunStatus,
)
LOG = structlog.get_logger()
@typing.no_type_check
def _custom_json_serializer(*args, **kwargs) -> str:
"""
Encodes json in the same way that pydantic does.
"""
return json.dumps(*args, default=pydantic.json.pydantic_encoder, **kwargs)
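# Illustrative example (not in the source): the serializer handles the types
# pydantic knows how to encode, e.g. datetimes:
#   _custom_json_serializer({"t": datetime.datetime(2024, 1, 1)})
#   -> '{"t": "2024-01-01T00:00:00"}'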
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
if debug_enabled:
LOG.debug("Converting TaskModel to Task", task_id=task_obj.task_id)
task = Task(
task_id=task_obj.task_id,
status=TaskStatus(task_obj.status),
created_at=task_obj.created_at,
modified_at=task_obj.modified_at,
url=task_obj.url,
webhook_callback_url=task_obj.webhook_callback_url,
navigation_goal=task_obj.navigation_goal,
data_extraction_goal=task_obj.data_extraction_goal,
navigation_payload=task_obj.navigation_payload,
extracted_information=task_obj.extracted_information,
failure_reason=task_obj.failure_reason,
organization_id=task_obj.organization_id,
proxy_location=ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None,
extracted_information_schema=task_obj.extracted_information_schema,
workflow_run_id=task_obj.workflow_run_id,
order=task_obj.order,
retry=task_obj.retry,
)
return task
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
if debug_enabled:
LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)
return Step(
task_id=step_model.task_id,
step_id=step_model.step_id,
created_at=step_model.created_at,
modified_at=step_model.modified_at,
status=StepStatus(step_model.status),
output=step_model.output,
order=step_model.order,
is_last=step_model.is_last,
retry_index=step_model.retry_index,
organization_id=step_model.organization_id,
input_token_count=step_model.input_token_count,
output_token_count=step_model.output_token_count,
step_cost=step_model.step_cost,
)
def convert_to_organization(org_model: OrganizationModel) -> Organization:
return Organization(
organization_id=org_model.organization_id,
organization_name=org_model.organization_name,
webhook_callback_url=org_model.webhook_callback_url,
max_steps_per_run=org_model.max_steps_per_run,
created_at=org_model.created_at,
modified_at=org_model.modified_at,
)
def convert_to_organization_auth_token(org_auth_token: OrganizationAuthTokenModel) -> OrganizationAuthToken:
return OrganizationAuthToken(
id=org_auth_token.id,
organization_id=org_auth_token.organization_id,
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
token=org_auth_token.token,
valid=org_auth_token.valid,
created_at=org_auth_token.created_at,
modified_at=org_auth_token.modified_at,
)
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
if debug_enabled:
LOG.debug("Converting ArtifactModel to Artifact", artifact_id=artifact_model.artifact_id)
return Artifact(
artifact_id=artifact_model.artifact_id,
artifact_type=ArtifactType[artifact_model.artifact_type.upper()],
uri=artifact_model.uri,
task_id=artifact_model.task_id,
step_id=artifact_model.step_id,
created_at=artifact_model.created_at,
modified_at=artifact_model.modified_at,
organization_id=artifact_model.organization_id,
)
def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = False) -> Workflow:
if debug_enabled:
LOG.debug("Converting WorkflowModel to Workflow", workflow_id=workflow_model.workflow_id)
return Workflow(
workflow_id=workflow_model.workflow_id,
organization_id=workflow_model.organization_id,
title=workflow_model.title,
description=workflow_model.description,
workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
created_at=workflow_model.created_at,
modified_at=workflow_model.modified_at,
deleted_at=workflow_model.deleted_at,
)
def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: bool = False) -> WorkflowRun:
if debug_enabled:
LOG.debug("Converting WorkflowRunModel to WorkflowRun", workflow_run_id=workflow_run_model.workflow_run_id)
return WorkflowRun(
workflow_run_id=workflow_run_model.workflow_run_id,
workflow_id=workflow_run_model.workflow_id,
status=WorkflowRunStatus[workflow_run_model.status],
proxy_location=ProxyLocation(workflow_run_model.proxy_location) if workflow_run_model.proxy_location else None,
webhook_callback_url=workflow_run_model.webhook_callback_url,
created_at=workflow_run_model.created_at,
modified_at=workflow_run_model.modified_at,
)
def convert_to_workflow_parameter(
workflow_parameter_model: WorkflowParameterModel, debug_enabled: bool = False
) -> WorkflowParameter:
if debug_enabled:
LOG.debug(
"Converting WorkflowParameterModel to WorkflowParameter",
workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
)
workflow_parameter_type = WorkflowParameterType[workflow_parameter_model.workflow_parameter_type.upper()]
return WorkflowParameter(
workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
workflow_parameter_type=workflow_parameter_type,
workflow_id=workflow_parameter_model.workflow_id,
default_value=workflow_parameter_type.convert_value(workflow_parameter_model.default_value),
key=workflow_parameter_model.key,
description=workflow_parameter_model.description,
created_at=workflow_parameter_model.created_at,
modified_at=workflow_parameter_model.modified_at,
deleted_at=workflow_parameter_model.deleted_at,
)
def convert_to_aws_secret_parameter(
aws_secret_parameter_model: AWSSecretParameterModel, debug_enabled: bool = False
) -> AWSSecretParameter:
if debug_enabled:
LOG.debug(
"Converting AWSSecretParameterModel to AWSSecretParameter",
            aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
)
return AWSSecretParameter(
aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
workflow_id=aws_secret_parameter_model.workflow_id,
key=aws_secret_parameter_model.key,
description=aws_secret_parameter_model.description,
aws_key=aws_secret_parameter_model.aws_key,
created_at=aws_secret_parameter_model.created_at,
modified_at=aws_secret_parameter_model.modified_at,
deleted_at=aws_secret_parameter_model.deleted_at,
)
def convert_to_workflow_run_parameter(
workflow_run_parameter_model: WorkflowRunParameterModel,
workflow_parameter: WorkflowParameter,
debug_enabled: bool = False,
) -> WorkflowRunParameter:
if debug_enabled:
LOG.debug(
"Converting WorkflowRunParameterModel to WorkflowRunParameter",
workflow_run_id=workflow_run_parameter_model.workflow_run_id,
workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
)
return WorkflowRunParameter(
workflow_run_id=workflow_run_parameter_model.workflow_run_id,
workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
value=workflow_parameter.workflow_parameter_type.convert_value(workflow_run_parameter_model.value),
created_at=workflow_run_parameter_model.created_at,
)

View File

View File

@@ -0,0 +1,85 @@
import abc
from fastapi import BackgroundTasks
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Organization
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
class AsyncExecutor(abc.ABC):
@abc.abstractmethod
async def execute_task(
self,
background_tasks: BackgroundTasks,
task: Task,
organization: Organization,
max_steps_override: int | None,
api_key: str | None,
) -> None:
pass
@abc.abstractmethod
async def execute_workflow(
self,
background_tasks: BackgroundTasks,
organization: Organization,
workflow_id: str,
workflow_run_id: str,
max_steps_override: int | None,
api_key: str | None,
) -> None:
pass
class BackgroundTaskExecutor(AsyncExecutor):
async def execute_task(
self,
background_tasks: BackgroundTasks,
task: Task,
organization: Organization,
max_steps_override: int | None,
api_key: str | None,
) -> None:
step = await app.DATABASE.create_step(
task.task_id,
order=0,
retry_index=0,
organization_id=organization.organization_id,
)
task = await app.DATABASE.update_task(
task.task_id,
TaskStatus.running,
organization_id=organization.organization_id,
)
context: SkyvernContext = skyvern_context.ensure_context()
context.task_id = task.task_id
context.organization_id = organization.organization_id
context.max_steps_override = max_steps_override
background_tasks.add_task(
app.agent.execute_step,
organization,
task,
step,
api_key,
)
async def execute_workflow(
self,
background_tasks: BackgroundTasks,
organization: Organization,
workflow_id: str,
workflow_run_id: str,
max_steps_override: int | None,
api_key: str | None,
) -> None:
background_tasks.add_task(
app.WORKFLOW_SERVICE.execute_workflow,
workflow_run_id=workflow_run_id,
api_key=api_key,
)

View File

@@ -0,0 +1,13 @@
from skyvern.forge.sdk.executor.async_executor import AsyncExecutor, BackgroundTaskExecutor
class AsyncExecutorFactory:
__instance: AsyncExecutor = BackgroundTaskExecutor()
@staticmethod
def set_executor(executor: AsyncExecutor) -> None:
AsyncExecutorFactory.__instance = executor
@staticmethod
def get_executor() -> AsyncExecutor:
return AsyncExecutorFactory.__instance
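# Illustrative usage (assumption, e.g. in tests or an alternative deployment):
# swap in a custom executor before the app starts handling requests:
#
#   class QueueExecutor(AsyncExecutor):  # hypothetical implementation
#       ...
#
#   AsyncExecutorFactory.set_executor(QueueExecutor())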

View File

@@ -0,0 +1,90 @@
import logging
import structlog
from structlog.typing import EventDict
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.settings_manager import SettingsManager
def add_kv_pairs_to_msg(logger: logging.Logger, method_name: str, event_dict: EventDict) -> EventDict:
"""
A custom processor to add key-value pairs to the 'msg' field.
"""
# Add context to the log
context = skyvern_context.current()
if context:
if context.request_id:
event_dict["request_id"] = context.request_id
if context.organization_id:
event_dict["organization_id"] = context.organization_id
if context.task_id:
event_dict["task_id"] = context.task_id
if context.workflow_id:
event_dict["workflow_id"] = context.workflow_id
if context.workflow_run_id:
event_dict["workflow_run_id"] = context.workflow_run_id
# Add env to the log
event_dict["env"] = SettingsManager.get_settings().ENV
if method_name not in ["info", "warning", "error", "critical", "exception"]:
# Only modify the log for these log levels
return event_dict
# Assuming 'event' or 'msg' is the field to update
msg_field = event_dict.get("msg", "")
# Add key-value pairs
kv_pairs = {k: v for k, v in event_dict.items() if k not in ["msg", "timestamp", "level"]}
if kv_pairs:
additional_info = ", ".join(f"{k}={v}" for k, v in kv_pairs.items())
msg_field += f" | {additional_info}"
event_dict["msg"] = msg_field
return event_dict
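# Illustrative effect (values are hypothetical): for an event dict such as
#   {"msg": "Task started", "task_id": "tsk_123", "level": "info"}
# the processor above rewrites "msg" to
#   "Task started | task_id=tsk_123, env=<ENV>"
# while the excluded fields ("msg", "timestamp", "level") are left out of the suffix.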
def setup_logger() -> None:
"""
Setup the logger with the specified format
"""
# logging.config.dictConfig(logging_config)
renderer = (
structlog.processors.JSONRenderer()
if SettingsManager.get_settings().JSON_LOGGING
else structlog.dev.ConsoleRenderer()
)
additional_processors = (
[
structlog.processors.EventRenamer("msg"),
add_kv_pairs_to_msg,
structlog.processors.CallsiteParameterAdder(
{
structlog.processors.CallsiteParameter.PATHNAME,
structlog.processors.CallsiteParameter.FILENAME,
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
}
),
]
if SettingsManager.get_settings().JSON_LOGGING
else []
)
structlog.configure(
processors=[
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso"),
# structlog.processors.dict_tracebacks,
structlog.processors.format_exc_info,
]
+ additional_processors
+ [renderer]
)
uvicorn_error = logging.getLogger("uvicorn.error")
uvicorn_error.disabled = True
uvicorn_access = logging.getLogger("uvicorn.access")
uvicorn_access.disabled = True

skyvern/forge/sdk/models.py
View File

@@ -0,0 +1,137 @@
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.webeye.actions.actions import ActionType
from skyvern.webeye.actions.models import AgentStepOutput
class StepStatus(StrEnum):
created = "created"
running = "running"
failed = "failed"
completed = "completed"
def can_update_to(self, new_status: StepStatus) -> bool:
allowed_transitions: dict[StepStatus, set[StepStatus]] = {
StepStatus.created: {StepStatus.running},
StepStatus.running: {StepStatus.completed, StepStatus.failed},
StepStatus.failed: set(),
StepStatus.completed: set(),
}
return new_status in allowed_transitions[self]
def requires_output(self) -> bool:
status_requires_output = {StepStatus.completed}
return self in status_requires_output
def cant_have_output(self) -> bool:
status_cant_have_output = {StepStatus.created, StepStatus.running}
return self in status_cant_have_output
def is_terminal(self) -> bool:
status_is_terminal = {StepStatus.failed, StepStatus.completed}
return self in status_is_terminal
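# Illustrative checks (not in the source) showing the transition table above:
#   StepStatus.created.can_update_to(StepStatus.running)    # True
#   StepStatus.running.can_update_to(StepStatus.completed)  # True
#   StepStatus.completed.can_update_to(StepStatus.running)  # False (terminal)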
class Step(BaseModel):
created_at: datetime
modified_at: datetime
task_id: str
step_id: str
status: StepStatus
output: AgentStepOutput | None = None
order: int
is_last: bool
retry_index: int = 0
organization_id: str | None = None
input_token_count: int = 0
output_token_count: int = 0
step_cost: float = 0
def validate_update(self, status: StepStatus | None, output: AgentStepOutput | None, is_last: bool | None) -> None:
old_status = self.status
if status and not old_status.can_update_to(status):
raise ValueError(f"invalid_status_transition({old_status},{status},{self.step_id})")
if status and status.requires_output() and output is None:
raise ValueError(f"status_requires_output({status},{self.step_id})")
if status and status.cant_have_output() and output is not None:
raise ValueError(f"status_cant_have_output({status},{self.step_id})")
if output is not None and status is None:
raise ValueError(f"cant_set_output_without_updating_status({self.step_id})")
if self.output is not None and output is not None:
raise ValueError(f"cant_override_output({self.step_id})")
if is_last and not self.status.is_terminal():
raise ValueError(f"is_last_but_status_not_terminal({self.status},{self.step_id})")
if is_last is False:
raise ValueError(f"cant_set_is_last_to_false({self.step_id})")
def is_goal_achieved(self) -> bool:
if self.status != StepStatus.completed:
return False
# TODO (kerem): Remove this check once we have backfilled all the steps
if self.output is None or self.output.actions_and_results is None:
return False
# Check if there is a successful complete action
for action, action_results in self.output.actions_and_results:
if action.action_type != ActionType.COMPLETE:
continue
if any(action_result.success for action_result in action_results):
return True
return False
def is_terminated(self) -> bool:
if self.status != StepStatus.completed:
return False
# TODO (kerem): Remove this check once we have backfilled all the steps
if self.output is None or self.output.actions_and_results is None:
return False
# Check if there is a successful terminate action
for action, action_results in self.output.actions_and_results:
if action.action_type != ActionType.TERMINATE:
continue
if any(action_result.success for action_result in action_results):
return True
return False
class Organization(BaseModel):
organization_id: str
organization_name: str
webhook_callback_url: str | None = None
max_steps_per_run: int | None = None
created_at: datetime
modified_at: datetime
class OrganizationAuthToken(BaseModel):
id: str
organization_id: str
token_type: OrganizationAuthTokenType
token: str
valid: bool
created_at: datetime
modified_at: datetime
class TokenPayload(BaseModel):
sub: str
exp: int

View File

@@ -0,0 +1,98 @@
"""
Relative to this file there is a prompts directory located at ../prompts.
It contains a techniques directory and a directory for each model - gpt-3.5-turbo, gpt-4, llama-2-70B, code-llama-7B, etc.
Each model directory holds Jinja2 templates for the prompts, and prompts in the
model directories can use the techniques in the techniques directory.
The code below loads and populates those templates through:

class PromptEngine:
    def __init__(self, model):
        ...

    def load_prompt(self, prompt_name, **prompt_args) -> str:
        ...
"""
import glob
import os
from difflib import get_close_matches
from typing import Any, List
import structlog
from jinja2 import Environment, FileSystemLoader
LOG = structlog.get_logger()
class PromptEngine:
"""
Class to handle loading and populating Jinja2 templates for prompts.
"""
def __init__(self, model: str):
"""
Initialize the PromptEngine with the specified model.
Args:
model (str): The model to use for loading prompts.
"""
self.model = model
try:
# Get the list of all model directories
models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts"))
model_names = [
os.path.basename(os.path.normpath(d))
for d in glob.glob(os.path.join(models_dir, "*/"))
if os.path.isdir(d) and "techniques" not in d
]
self.model = self.get_closest_match(self.model, model_names)
self.env = Environment(loader=FileSystemLoader(models_dir))
except Exception:
LOG.error("Error initializing PromptEngine.", model=model, exc_info=True)
raise
@staticmethod
def get_closest_match(target: str, model_dirs: List[str]) -> str:
"""
Find the closest match to the target in the list of model directories.
Args:
target (str): The target model.
model_dirs (list): The list of available model directories.
Returns:
str: The closest match to the target.
"""
try:
matches = get_close_matches(target, model_dirs, n=1, cutoff=0.1)
return matches[0]
except Exception:
LOG.error("Failed to get closest match.", target=target, model_dirs=model_dirs, exc_info=True)
raise
def load_prompt(self, template: str, **kwargs: Any) -> str:
"""
Load and populate the specified template.
Args:
template (str): The name of the template to load.
**kwargs: The arguments to populate the template with.
Returns:
str: The populated template.
"""
try:
template = os.path.join(self.model, template)
jinja_template = self.env.get_template(f"{template}.j2")
return jinja_template.render(**kwargs)
except Exception:
LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
raise
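# Illustrative usage (the model and template names are hypothetical; templates
# are resolved under ../prompts/<closest-matching-model-dir>/<template>.j2):
#   engine = PromptEngine("gpt-4")
#   prompt = engine.load_prompt("extract-actions", goal="Get a quote")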

View File

View File

@@ -0,0 +1,397 @@
from typing import Annotated, Any
import structlog
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, status
from fastapi.responses import ORJSONResponse
from pydantic import BaseModel
from skyvern.exceptions import StepNotFound
from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.models.workflow import (
RunWorkflowResponse,
WorkflowRequestBody,
WorkflowRunStatusResponse,
)
base_router = APIRouter()
LOG = structlog.get_logger()
@base_router.post("/webhook", tags=["server"])
async def webhook(
request: Request,
x_skyvern_signature: Annotated[str | None, Header()] = None,
x_skyvern_timestamp: Annotated[str | None, Header()] = None,
) -> Response:
payload = await request.body()
if not x_skyvern_signature or not x_skyvern_timestamp:
LOG.error(
"Webhook signature or timestamp missing",
x_skyvern_signature=x_skyvern_signature,
x_skyvern_timestamp=x_skyvern_timestamp,
payload=payload,
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing webhook signature or timestamp")
generated_signature = generate_skyvern_signature(
payload.decode("utf-8"),
SettingsManager.get_settings().SKYVERN_API_KEY,
)
LOG.info(
"Webhook received",
x_skyvern_signature=x_skyvern_signature,
x_skyvern_timestamp=x_skyvern_timestamp,
payload=payload,
generated_signature=generated_signature,
valid_signature=x_skyvern_signature == generated_signature,
)
return Response(content="webhook validation", status_code=200)
@base_router.get("/heartbeat", tags=["server"])
async def check_server_status() -> Response:
"""
Check if the server is running.
"""
return Response(content="Server is running.", status_code=200)
@base_router.post("/tasks", tags=["agent"], response_model=CreateTaskResponse)
async def create_agent_task(
background_tasks: BackgroundTasks,
request: Request,
task: TaskRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None,
x_max_steps_override: Annotated[int | None, Header()] = None,
) -> CreateTaskResponse:
agent = request["agent"]
created_task = await agent.create_task(task, current_org.organization_id)
if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await app.ASYNC_EXECUTOR.execute_task(
background_tasks=background_tasks,
task=created_task,
organization=current_org,
max_steps_override=x_max_steps_override,
api_key=x_api_key,
)
return CreateTaskResponse(task_id=created_task.task_id)
@base_router.post(
"/tasks/{task_id}/steps/{step_id}",
tags=["agent"],
response_model=Step,
summary="Executes a specific step",
)
@base_router.post(
"/tasks/{task_id}/steps/",
tags=["agent"],
response_model=Step,
summary="Executes the next step",
)
async def execute_agent_task_step(
request: Request,
task_id: str,
step_id: str | None = None,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
agent = request["agent"]
task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
if not task:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No task found with id {task_id}",
)
# An empty step request means that the agent should execute the next step for the task.
if not step_id:
step = await app.DATABASE.get_latest_step(task_id=task_id, organization_id=current_org.organization_id)
if not step:
raise StepNotFound(current_org.organization_id, task_id)
LOG.info(
"Executing latest step since no step_id was provided",
task_id=task_id,
step_id=step.step_id,
step_order=step.order,
step_retry=step.retry_index,
)
else:
step = await app.DATABASE.get_step(task_id, step_id, organization_id=current_org.organization_id)
if not step:
raise StepNotFound(current_org.organization_id, task_id, step_id)
LOG.info(
"Executing step",
task_id=task_id,
step_id=step.step_id,
step_order=step.order,
step_retry=step.retry_index,
)
step, _, _ = await agent.execute_step(current_org, task, step)
return Response(
content=step.model_dump_json() if step else "",
status_code=200,
media_type="application/json",
)
@base_router.get("/tasks/{task_id}", response_model=TaskResponse)
async def get_task(
request: Request,
task_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskResponse:
request["agent"]
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
if not task_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task not found {task_id}",
)
# get latest step
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=current_org.organization_id)
if not latest_step:
return task_obj.to_task_response()
screenshot_url = None
# todo (kerem): only access artifacts through the artifact manager instead of db
screenshot_artifact = await app.DATABASE.get_latest_artifact(
task_id=task_obj.task_id,
step_id=latest_step.step_id,
artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
organization_id=current_org.organization_id,
)
if screenshot_artifact:
screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)
recording_artifact = await app.DATABASE.get_latest_artifact(
task_id=task_obj.task_id,
artifact_types=[ArtifactType.RECORDING],
organization_id=current_org.organization_id,
)
recording_url = None
if recording_artifact:
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
failure_reason = None
if task_obj.status == TaskStatus.failed and (latest_step.output or task_obj.failure_reason):
failure_reason = ""
if task_obj.failure_reason:
failure_reason += f"Reasoning: {task_obj.failure_reason or ''}"
failure_reason += "\n"
if latest_step.output and latest_step.output.action_results:
failure_reason += "Exceptions: "
failure_reason += str(
[f"[{ar.exception_type}]: {ar.exception_message}" for ar in latest_step.output.action_results]
)
return task_obj.to_task_response(
screenshot_url=screenshot_url,
recording_url=recording_url,
failure_reason=failure_reason,
)
@base_router.get("/internal/tasks/{task_id}", response_model=list[Task])
async def get_task_internal(
request: Request,
task_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
"""
Get all tasks.
:param request:
:param page: Starting page, defaults to 1
:param page_size:
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
get_agent_task endpoint.
"""
task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
if not task:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task not found {task_id}",
)
return ORJSONResponse(task.model_dump())
@base_router.get("/tasks", tags=["agent"], response_model=list[Task])
async def get_agent_tasks(
request: Request,
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
"""
Get all tasks.
:param request:
:param page: Starting page, defaults to 1
:param page_size: Page size, defaults to 10
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
get_agent_task endpoint.
"""
request["agent"]
tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
return ORJSONResponse([task.to_task_response().model_dump() for task in tasks])
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
async def get_agent_tasks_internal(
request: Request,
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
"""
Get all tasks.
:param request:
:param page: Starting page, defaults to 1
:param page_size: Page size, defaults to 10
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
get_agent_task endpoint.
"""
request["agent"]
tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
return ORJSONResponse([task.model_dump() for task in tasks])
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
async def get_agent_task_steps(
request: Request,
task_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
"""
Get all steps for a task.
:param request:
:param task_id:
:return: List of steps for a task with pagination.
"""
request["agent"]
steps = await app.DATABASE.get_task_steps(task_id, organization_id=current_org.organization_id)
return ORJSONResponse([step.model_dump() for step in steps])
@base_router.get("/tasks/{task_id}/steps/{step_id}/artifacts", tags=["agent"], response_model=list[Artifact])
async def get_agent_task_step_artifacts(
request: Request,
task_id: str,
step_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
"""
Get all artifacts for a list of steps.
:param request:
:param task_id:
:param step_id:
:return: List of artifacts for a list of steps.
"""
request["agent"]
artifacts = await app.DATABASE.get_artifacts_for_task_step(
task_id,
step_id,
organization_id=current_org.organization_id,
)
return ORJSONResponse([artifact.model_dump() for artifact in artifacts])
class ActionResultTmp(BaseModel):
action: dict[str, Any]
data: dict[str, Any] | list | str | None = None
exception_message: str | None = None
success: bool = True
@base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp])
async def get_task_actions(
request: Request,
task_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[ActionResultTmp]:
request["agent"]
steps = await app.DATABASE.get_task_step_models(task_id, organization_id=current_org.organization_id)
results: list[ActionResultTmp] = []
for step_s in steps:
if not step_s.output or "action_results" not in step_s.output:
continue
for action_result in step_s.output["action_results"]:
results.append(ActionResultTmp.model_validate(action_result))
return results
@base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse)
async def execute_workflow(
background_tasks: BackgroundTasks,
request: Request,
workflow_id: str,
workflow_request: WorkflowRequestBody,
current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None,
x_max_steps_override: Annotated[int | None, Header()] = None,
) -> RunWorkflowResponse:
LOG.info(
f"Running workflow {workflow_id}",
workflow_id=workflow_id,
)
context = skyvern_context.ensure_context()
request_id = context.request_id
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
request_id=request_id,
workflow_request=workflow_request,
workflow_id=workflow_id,
organization_id=current_org.organization_id,
max_steps_override=x_max_steps_override,
)
if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await app.ASYNC_EXECUTOR.execute_workflow(
background_tasks=background_tasks,
organization=current_org,
workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
max_steps_override=x_max_steps_override,
api_key=x_api_key,
)
return RunWorkflowResponse(
workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
)
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}", response_model=WorkflowRunStatusResponse)
async def get_workflow_run(
request: Request,
workflow_id: str,
workflow_run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowRunStatusResponse:
request["agent"]
return await app.WORKFLOW_SERVICE.build_workflow_run_status_response(
workflow_id=workflow_id, workflow_run_id=workflow_run_id, organization_id=current_org.organization_id
)

View File

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class ProxyLocation(StrEnum):
US_CA = "US-CA"
US_NY = "US-NY"
US_TX = "US-TX"
US_FL = "US-FL"
US_WA = "US-WA"
RESIDENTIAL = "RESIDENTIAL"
NONE = "NONE"
class TaskRequest(BaseModel):
url: str = Field(
...,
min_length=1,
description="Starting URL for the task.",
examples=["https://www.geico.com"],
)
# TODO: use HttpUrl instead of str
webhook_callback_url: str | None = Field(
default=None,
description="The URL to call when the task is completed.",
examples=["https://my-webhook.com"],
)
navigation_goal: str | None = Field(
default=None,
description="The user's goal for the task.",
examples=["Get a quote for car insurance"],
)
data_extraction_goal: str | None = Field(
default=None,
description="The user's goal for data extraction.",
examples=["Extract the quote price"],
)
navigation_payload: dict[str, Any] | list | str | None = Field(
None,
description="The user's details needed to achieve the task.",
examples=[{"name": "John Doe", "email": "john@doe.com"}],
)
proxy_location: ProxyLocation | None = Field(
None,
description="The location of the proxy to use for the task.",
examples=["US-WA", "US-CA", "US-FL", "US-NY", "US-TX"],
)
extracted_information_schema: dict[str, Any] | list | str | None = Field(
None,
description="The requested schema of the extracted information.",
)
class TaskStatus(StrEnum):
created = "created"
running = "running"
failed = "failed"
terminated = "terminated"
completed = "completed"
def is_final(self) -> bool:
return self in {TaskStatus.failed, TaskStatus.terminated, TaskStatus.completed}
def can_update_to(self, new_status: TaskStatus) -> bool:
allowed_transitions: dict[TaskStatus, set[TaskStatus]] = {
TaskStatus.created: {TaskStatus.running},
TaskStatus.running: {TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated},
            TaskStatus.failed: set(),
            TaskStatus.terminated: set(),
            TaskStatus.completed: set(),
}
return new_status in allowed_transitions[self]
def requires_extracted_info(self) -> bool:
status_requires_extracted_information = {TaskStatus.completed}
return self in status_requires_extracted_information
def cant_have_extracted_info(self) -> bool:
status_cant_have_extracted_information = {
TaskStatus.created,
TaskStatus.running,
TaskStatus.failed,
TaskStatus.terminated,
}
return self in status_cant_have_extracted_information
def requires_failure_reason(self) -> bool:
status_requires_failure_reason = {TaskStatus.failed, TaskStatus.terminated}
return self in status_requires_failure_reason
class Task(TaskRequest):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
examples=["2023-01-01T00:00:00Z"],
)
modified_at: datetime = Field(
...,
description="The modification datetime of the task.",
examples=["2023-01-01T00:00:00Z"],
)
task_id: str = Field(
...,
description="The ID of the task.",
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
)
status: TaskStatus = Field(..., description="The status of the task.", examples=["created"])
extracted_information: dict[str, Any] | list | str | None = Field(
None,
description="The extracted information from the task.",
)
failure_reason: str | None = Field(
None,
description="The reason for the task failure.",
)
organization_id: str | None = None
workflow_run_id: str | None = None
order: int | None = None
retry: int | None = None
def validate_update(
self,
status: TaskStatus,
extracted_information: dict[str, Any] | list | str | None,
failure_reason: str | None = None,
) -> None:
old_status = self.status
if not old_status.can_update_to(status):
raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id}")
if status.requires_failure_reason() and failure_reason is None:
raise ValueError(f"status_requires_failure_reason({status},{self.task_id}")
if status.requires_extracted_info() and self.data_extraction_goal and extracted_information is None:
raise ValueError(f"status_requires_extracted_information({status},{self.task_id}")
if status.cant_have_extracted_info() and extracted_information is not None:
raise ValueError(f"status_cant_have_extracted_information({self.task_id})")
if self.extracted_information is not None and extracted_information is not None:
raise ValueError(f"cant_override_extracted_information({self.task_id})")
if self.failure_reason is not None and failure_reason is not None:
raise ValueError(f"cant_override_failure_reason({self.task_id})")
def to_task_response(
self, screenshot_url: str | None = None, recording_url: str | None = None, failure_reason: str | None = None
) -> TaskResponse:
return TaskResponse(
request=self,
task_id=self.task_id,
status=self.status,
created_at=self.created_at,
modified_at=self.modified_at,
extracted_information=self.extracted_information,
failure_reason=failure_reason or self.failure_reason,
screenshot_url=screenshot_url,
recording_url=recording_url,
)
class TaskResponse(BaseModel):
request: TaskRequest
task_id: str
status: TaskStatus
created_at: datetime
modified_at: datetime
extracted_information: list | dict[str, Any] | str | None = None
screenshot_url: str | None = None
recording_url: str | None = None
failure_reason: str | None = None
class CreateTaskResponse(BaseModel):
task_id: str

View File

View File

@@ -0,0 +1,76 @@
import time
from typing import Annotated
from asyncache import cached
from cachetools import TTLCache
from fastapi import Header, HTTPException, status
from jose import jwt
from jose.exceptions import JWTError
from pydantic import ValidationError
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.models import Organization, OrganizationAuthTokenType, TokenPayload
from skyvern.forge.sdk.settings_manager import SettingsManager
AUTHENTICATION_TTL = 60 * 60 # one hour
CACHE_SIZE = 128
ALGORITHM = "HS256"
async def get_current_org(
x_api_key: Annotated[str | None, Header()] = None,
) -> Organization:
if not x_api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
return await _get_current_org_cached(x_api_key, app.DATABASE)
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
"""
Authentication is cached for one hour
"""
try:
payload = jwt.decode(
x_api_key,
SettingsManager.get_settings().SECRET_KEY,
algorithms=[ALGORITHM],
)
api_key_data = TokenPayload(**payload)
except (JWTError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
if api_key_data.exp < time.time():
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Auth token is expired",
)
organization = await db.get_organization(organization_id=api_key_data.sub)
if not organization:
raise HTTPException(status_code=404, detail="Organization not found")
# check if the token exists in the database
api_key_db_obj = await db.validate_org_auth_token(
organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api,
token=x_api_key,
)
if not api_key_db_obj:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
# set organization_id in skyvern context and log context
context = skyvern_context.current()
if context:
context.organization_id = organization.organization_id
return organization
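# Illustrative sketch (not in the source): a token accepted by the decode path
# above carries a TokenPayload-shaped claim set signed with SECRET_KEY/HS256.
# The organization id below is hypothetical:
#   token = jwt.encode(
#       {"sub": "o_123", "exp": int(time.time()) + AUTHENTICATION_TTL},
#       SettingsManager.get_settings().SECRET_KEY,
#       algorithm=ALGORITHM,
#   )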

View File

@@ -0,0 +1,14 @@
from skyvern.config import Settings
from skyvern.config import settings as base_settings
class SettingsManager:
__instance: Settings = base_settings
@staticmethod
def get_settings() -> Settings:
return SettingsManager.__instance
@staticmethod
def set_settings(settings: Settings) -> None:
SettingsManager.__instance = settings
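# Illustrative override (assumption, e.g. in tests; field names come from the
# Settings model in skyvern.config):
#   SettingsManager.set_settings(Settings(ENV="test"))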

View File

View File

@@ -0,0 +1,79 @@
from typing import TYPE_CHECKING, Any
import structlog
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
if TYPE_CHECKING:
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunParameter
LOG = structlog.get_logger()
class ContextManager:
aws_client: AsyncAWSClient
parameters: dict[str, PARAMETER_TYPE]
values: dict[str, Any]
def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
self.aws_client = AsyncAWSClient()
self.parameters = {}
self.values = {}
for parameter, run_parameter in workflow_parameter_tuples:
if parameter.key in self.parameters:
prev_value = self.parameters[parameter.key]
new_value = run_parameter.value
LOG.error(
f"Duplicate parameter key {parameter.key} found while initializing context manager, previous value: {prev_value}, new value: {new_value}. Using new value."
)
self.parameters[parameter.key] = parameter
self.values[parameter.key] = run_parameter.value
async def register_parameter_value(
self,
parameter: PARAMETER_TYPE,
) -> None:
if parameter.parameter_type == ParameterType.WORKFLOW:
LOG.error(f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}")
raise ValueError(
f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
)
elif parameter.parameter_type == ParameterType.AWS_SECRET:
secret_value = await self.aws_client.get_secret(parameter.aws_key)
if secret_value is not None:
self.values[parameter.key] = secret_value
else:
# ContextParameter values will be set within the blocks
return None
async def register_block_parameters(
self,
parameters: list[PARAMETER_TYPE],
) -> None:
for parameter in parameters:
if parameter.key in self.parameters:
LOG.debug(f"Parameter {parameter.key} already registered, skipping")
continue
if parameter.parameter_type == ParameterType.WORKFLOW:
LOG.error(
f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
)
raise ValueError(
f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
)
self.parameters[parameter.key] = parameter
await self.register_parameter_value(parameter)
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
def get_value(self, key: str) -> Any:
return self.values[key]
def set_value(self, key: str, value: Any) -> None:
self.values[key] = value
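# Illustrative flow (parameter keys are hypothetical): workflow parameters are
# seeded in __init__ from the run's (parameter, run_parameter) tuples, while
# block-level parameters (e.g. AWS secrets) are resolved lazily:
#   await context_manager.register_block_parameters(block.get_all_parameters())
#   api_key = context_manager.get_value("my_aws_secret_key")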

View File

@@ -0,0 +1,221 @@
import abc
from enum import StrEnum
from typing import Annotated, Any, Literal, Union
import structlog
from pydantic import BaseModel, Field
from skyvern.exceptions import (
ContextParameterValueNotFound,
MissingBrowserStatePage,
TaskNotFound,
UnexpectedTaskStatus,
)
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
LOG = structlog.get_logger()
class BlockType(StrEnum):
TASK = "task"
FOR_LOOP = "for_loop"
class Block(BaseModel, abc.ABC):
block_type: BlockType
parent_block_id: str | None = None
next_block_id: str | None = None
@classmethod
def get_subclasses(cls) -> tuple[type["Block"], ...]:
return tuple(cls.__subclasses__())
@abc.abstractmethod
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
pass
@abc.abstractmethod
def get_all_parameters(
self,
) -> list[PARAMETER_TYPE]:
pass
class TaskBlock(Block):
block_type: Literal[BlockType.TASK] = BlockType.TASK
url: str | None = None
navigation_goal: str | None = None
data_extraction_goal: str | None = None
data_schema: dict[str, Any] | None = None
max_retries: int = 0
parameters: list[PARAMETER_TYPE] = []
def get_all_parameters(
self,
) -> list[PARAMETER_TYPE]:
return self.parameters
@staticmethod
async def get_task_order(workflow_run_id: str, current_retry: int) -> tuple[int, int]:
"""
Returns the order and retry for the next task in the workflow run as a tuple.
"""
last_task_for_workflow_run = await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
# If there is no previous task, the order will be 0 and the retry will be 0.
if last_task_for_workflow_run is None:
return 0, 0
# If there is a previous task but the current retry is 0, the order will be the order of the last task + 1
# and the retry will be 0.
order = last_task_for_workflow_run.order or 0
if current_retry == 0:
return order + 1, 0
        # If there is a previous task and the current retry is not 0, the order will be the order of the last
        # task and the retry will be the retry of the last task + 1. A validation ensures the last task's
        # retry equals current_retry - 1; if it does not, we log the mismatch and still use last task retry + 1.
retry = last_task_for_workflow_run.retry or 0
if retry + 1 != current_retry:
LOG.error(
f"Last task for workflow run is retry number {last_task_for_workflow_run.retry}, "
f"but current retry is {current_retry}. Could be race condition. Using last task retry + 1",
workflow_run_id=workflow_run_id,
last_task_id=last_task_for_workflow_run.task_id,
last_task_retry=last_task_for_workflow_run.retry,
current_retry=current_retry,
)
return order, retry + 1
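    # Illustrative walk-through of get_task_order (hypothetical values):
    #   no previous task,          current_retry=0 -> returns (0, 0)
    #   last task order=2 retry=0, current_retry=0 -> returns (3, 0)
    #   last task order=2 retry=0, current_retry=1 -> returns (2, 1)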
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
task = None
current_retry = 0
# initial value for will_retry is True, so that the loop runs at least once
will_retry = True
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await app.WORKFLOW_SERVICE.get_workflow(workflow_id=workflow_run.workflow_id)
# TODO (kerem) we should always retry on terminated. We should make a distinction between retriable and
# non-retryable terminations
while will_retry:
task_order, task_retry = await self.get_task_order(workflow_run_id, current_retry)
task, step = await app.agent.create_task_and_step_from_block(
task_block=self,
workflow=workflow,
workflow_run=workflow_run,
context_manager=context_manager,
task_order=task_order,
task_retry=task_retry,
)
organization = await app.DATABASE.get_organization(organization_id=workflow.organization_id)
if not organization:
raise Exception(f"Organization is missing organization_id={workflow.organization_id}")
browser_state = await app.BROWSER_MANAGER.get_or_create_for_workflow_run(
workflow_run=workflow_run, url=self.url
)
if not browser_state.page:
LOG.error("BrowserState has no page", workflow_run_id=workflow_run.workflow_run_id)
raise MissingBrowserStatePage(workflow_run_id=workflow_run.workflow_run_id)
LOG.info(
f"Navigating to page",
url=self.url,
workflow_run_id=workflow_run_id,
task_id=task.task_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
step_id=step.step_id,
)
if self.url:
await browser_state.page.goto(self.url)
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
# Check task status
updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
if not updated_task:
raise TaskNotFound(task.task_id)
if not updated_task.status.is_final():
raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)
if updated_task.status == TaskStatus.completed:
will_retry = False
else:
current_retry += 1
will_retry = current_retry <= self.max_retries
retry_message = f", retrying task {current_retry}/{self.max_retries}" if will_retry else ""
LOG.warning(
f"Task failed with status {updated_task.status}{retry_message}",
task_id=updated_task.task_id,
status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
current_retry=current_retry,
max_retries=self.max_retries,
)
class ForLoopBlock(Block):
block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP
# TODO (kerem): Add support for ContextParameter
loop_over: PARAMETER_TYPE
loop_block: "BlockTypeVar"
def get_all_parameters(
self,
) -> list[PARAMETER_TYPE]:
return self.loop_block.get_all_parameters() + [self.loop_over]
def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any) -> list[ContextParameter]:
if not isinstance(loop_data, dict):
# TODO (kerem): Should we add support for other types?
raise ValueError("loop_data should be a dictionary")
loop_block_parameters = self.loop_block.get_all_parameters()
context_parameters = [
parameter for parameter in loop_block_parameters if isinstance(parameter, ContextParameter)
]
for context_parameter in context_parameters:
if context_parameter.key not in loop_data:
raise ContextParameterValueNotFound(
parameter_key=context_parameter.key,
existing_keys=list(loop_data.keys()),
workflow_run_id=workflow_run_id,
)
context_parameter.value = loop_data[context_parameter.key]
return context_parameters
def get_loop_over_parameter_values(self, context_manager: ContextManager) -> list[Any]:
if isinstance(self.loop_over, WorkflowParameter):
parameter_value = context_manager.get_value(self.loop_over.key)
if isinstance(parameter_value, list):
return parameter_value
else:
# TODO (kerem): Should we raise an error here?
return [parameter_value]
else:
# TODO (kerem): Implement this for context parameters
raise NotImplementedError
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
loop_over_values = self.get_loop_over_parameter_values(context_manager)
LOG.info(
f"Number of loop_over values: {len(loop_over_values)}",
block_type=self.block_type,
workflow_run_id=workflow_run_id,
num_loop_over_values=len(loop_over_values),
)
for loop_over_value in loop_over_values:
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value:
context_manager.set_value(context_parameter.key, context_parameter.value)
await self.loop_block.execute(workflow_run_id=workflow_run_id, context_manager=context_manager)
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -0,0 +1,84 @@
import abc
import json
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field
class ParameterType(StrEnum):
WORKFLOW = "workflow"
CONTEXT = "context"
AWS_SECRET = "aws_secret"
class Parameter(BaseModel, abc.ABC):
# TODO (kerem): Should we also have organization_id here?
parameter_type: ParameterType
key: str
description: str | None = None
@classmethod
def get_subclasses(cls) -> tuple[type["Parameter"], ...]:
return tuple(cls.__subclasses__())
class AWSSecretParameter(Parameter):
parameter_type: Literal[ParameterType.AWS_SECRET] = ParameterType.AWS_SECRET
aws_secret_parameter_id: str
workflow_id: str
aws_key: str
created_at: datetime
modified_at: datetime
deleted_at: datetime | None = None
class WorkflowParameterType(StrEnum):
STRING = "string"
INTEGER = "integer"
FLOAT = "float"
BOOLEAN = "boolean"
JSON = "json"
def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None:
if value is None:
return None
if self == WorkflowParameterType.STRING:
return value
elif self == WorkflowParameterType.INTEGER:
return int(value)
elif self == WorkflowParameterType.FLOAT:
return float(value)
elif self == WorkflowParameterType.BOOLEAN:
return value.lower() in ["true", "1"]
elif self == WorkflowParameterType.JSON:
return json.loads(value)
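    # Illustrative conversions (values hypothetical); parameter values are
    # stored as strings and converted on read:
    #   WorkflowParameterType.INTEGER.convert_value("42")     -> 42
    #   WorkflowParameterType.BOOLEAN.convert_value("True")   -> True
    #   WorkflowParameterType.JSON.convert_value('{"a": 1}')  -> {"a": 1}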
class WorkflowParameter(Parameter):
    parameter_type: Literal[ParameterType.WORKFLOW] = ParameterType.WORKFLOW
    workflow_parameter_id: str
    workflow_parameter_type: WorkflowParameterType
    workflow_id: str
    # the type of default_value will be determined by the workflow_parameter_type
    default_value: str | int | float | bool | dict | list | None = None
    created_at: datetime
    modified_at: datetime
    deleted_at: datetime | None = None


class ContextParameter(Parameter):
    parameter_type: Literal[ParameterType.CONTEXT] = ParameterType.CONTEXT
    source: WorkflowParameter
    # value will be populated by the context manager
    value: str | int | float | bool | dict | list | None = None


ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter]
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]
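

# A minimal usage sketch (an illustrative addition, not part of the original
# module): pydantic resolves the concrete Parameter subclass from the
# "parameter_type" discriminator. The ids, key, and timestamps below are
# hypothetical.
if __name__ == "__main__":
    from pydantic import TypeAdapter

    _adapter = TypeAdapter(PARAMETER_TYPE)
    _parameter = _adapter.validate_python(
        {
            "parameter_type": "workflow",
            "key": "website_url",  # hypothetical parameter key
            "workflow_parameter_id": "wp_123",  # hypothetical id
            "workflow_parameter_type": "string",
            "workflow_id": "w_123",  # hypothetical id
            "created_at": "2024-03-01T00:00:00",
            "modified_at": "2024-03-01T00:00:00",
        }
    )
    assert isinstance(_parameter, WorkflowParameter)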

View File

@@ -0,0 +1,74 @@
from datetime import datetime
from enum import StrEnum
from typing import Any, List

from pydantic import BaseModel

from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar


class WorkflowRequestBody(BaseModel):
    data: dict[str, Any] | None = None
    proxy_location: ProxyLocation | None = None
    webhook_callback_url: str | None = None


class RunWorkflowResponse(BaseModel):
    workflow_id: str
    workflow_run_id: str


class WorkflowDefinition(BaseModel):
    blocks: List[BlockTypeVar]


class Workflow(BaseModel):
    workflow_id: str
    organization_id: str
    title: str
    description: str | None = None
    workflow_definition: WorkflowDefinition
    created_at: datetime
    modified_at: datetime
    deleted_at: datetime | None = None


class WorkflowRunStatus(StrEnum):
    created = "created"
    running = "running"
    failed = "failed"
    terminated = "terminated"
    completed = "completed"


class WorkflowRun(BaseModel):
    workflow_run_id: str
    workflow_id: str
    status: WorkflowRunStatus
    proxy_location: ProxyLocation | None = None
    webhook_callback_url: str | None = None
    created_at: datetime
    modified_at: datetime


class WorkflowRunParameter(BaseModel):
    workflow_run_id: str
    workflow_parameter_id: str
    value: bool | int | float | str | dict | list
    created_at: datetime


class WorkflowRunStatusResponse(BaseModel):
    workflow_id: str
    workflow_run_id: str
    status: WorkflowRunStatus
    proxy_location: ProxyLocation | None = None
    webhook_callback_url: str | None = None
    created_at: datetime
    modified_at: datetime
    parameters: dict[str, Any]
    screenshot_urls: list[str] | None = None
    recording_url: str | None = None
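

# A minimal usage sketch (an illustrative addition, not part of the original
# module): the request body that starts a run carries workflow parameter values
# in "data", keyed by each parameter's key. The key and urls are hypothetical.
if __name__ == "__main__":
    _request = WorkflowRequestBody(
        data={"website_url": "https://example.com"},  # hypothetical key/value
        webhook_callback_url="https://example.com/webhooks/skyvern",  # hypothetical url
    )
    print(_request.model_dump_json())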

View File

@@ -0,0 +1,509 @@
import asyncio
import json
import time
from datetime import datetime

import requests
import structlog

from skyvern.exceptions import (
    FailedToSendWebhook,
    MissingValueForParameter,
    WorkflowNotFound,
    WorkflowOrganizationMismatch,
    WorkflowRunNotFound,
)
from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import (
    Workflow,
    WorkflowDefinition,
    WorkflowRequestBody,
    WorkflowRun,
    WorkflowRunParameter,
    WorkflowRunStatus,
    WorkflowRunStatusResponse,
)
from skyvern.webeye.browser_factory import BrowserState

LOG = structlog.get_logger()

class WorkflowService:
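    """
    Orchestrates workflow runs: sets up a run and its parameters, executes the
    workflow's blocks in order, and sends the final status to the configured
    webhook callback url.
    """
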
async def setup_workflow_run(
self,
request_id: str | None,
workflow_request: WorkflowRequestBody,
workflow_id: str,
organization_id: str,
max_steps_override: int | None = None,
) -> WorkflowRun:
"""
Create a workflow run and its parameters. Validate the workflow and the organization. If there are missing
        parameters with no default value, mark the workflow run as failed.

:param request_id: The request id for the workflow run.
:param workflow_request: The request body for the workflow run, containing the parameters and the config.
:param workflow_id: The workflow id to run.
:param organization_id: The organization id for the workflow.
:param max_steps_override: The max steps override for the workflow run, if any.
:return: The created workflow run.
"""
LOG.info(f"Setting up workflow run for workflow {workflow_id}", workflow_id=workflow_id)
# Validate the workflow and the organization
workflow = await self.get_workflow(workflow_id=workflow_id)
if workflow is None:
LOG.error(f"Workflow {workflow_id} not found")
raise WorkflowNotFound(workflow_id=workflow_id)
if workflow.organization_id != organization_id:
LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)
# Create the workflow run and set skyvern context
workflow_run = await self.create_workflow_run(workflow_request=workflow_request, workflow_id=workflow_id)
LOG.info(
f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}",
request_id=request_id,
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id,
proxy_location=workflow_request.proxy_location,
)
skyvern_context.set(
SkyvernContext(
organization_id=organization_id,
request_id=request_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
max_steps_override=max_steps_override,
)
)
        # Set the workflow run status to running and create the workflow run parameters
        await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
        # Create all the workflow run parameters; AWSSecretParameters don't get workflow run parameters created.
all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id)
workflow_run_parameters = []
for workflow_parameter in all_workflow_parameters:
if workflow_request.data and workflow_parameter.key in workflow_request.data:
request_body_value = workflow_request.data[workflow_parameter.key]
workflow_run_parameter = await self.create_workflow_run_parameter(
workflow_run_id=workflow_run.workflow_run_id,
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
value=request_body_value,
)
elif workflow_parameter.default_value is not None:
workflow_run_parameter = await self.create_workflow_run_parameter(
workflow_run_id=workflow_run.workflow_run_id,
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
value=workflow_parameter.default_value,
)
else:
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
raise MissingValueForParameter(
parameter_key=workflow_parameter.key,
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
)
workflow_run_parameters.append(workflow_run_parameter)
LOG.info(
f"Created workflow run parameters for workflow run {workflow_run.workflow_run_id}",
workflow_run_id=workflow_run.workflow_run_id,
)
        return workflow_run

async def execute_workflow(
self,
workflow_run_id: str,
api_key: str,
) -> WorkflowRun:
"""Execute a workflow."""
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)
await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)
# Get all <workflow parameter, workflow run parameter> tuples
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
        # TODO (kerem): Find a better way to do this. A truly shared context manager isn't
        # possible right now because each workflow run executes as its own batch job.
context_manager = ContextManager(wp_wps_tuples)
# Execute workflow blocks
blocks = workflow.workflow_definition.blocks
for block_idx, block in enumerate(blocks):
parameters = block.get_all_parameters()
await context_manager.register_block_parameters(parameters)
LOG.info(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run.workflow_run_id}",
block_type=block.block_type,
workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx,
)
await block.execute(workflow_run_id=workflow_run.workflow_run_id, context_manager=context_manager)
# Get last task for workflow run
task = await self.get_last_task_for_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
if not task:
LOG.warning(
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook",
workflow_run_id=workflow_run.workflow_run_id,
)
return workflow_run
# Update workflow status
if task.status == "completed":
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
elif task.status == "failed":
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
elif task.status == "terminated":
await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id)
else:
LOG.warning(
f"Task {task.task_id} has an incomplete status {task.status}, not updating workflow run status",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
task_id=task.task_id,
status=task.status,
workflow_run_status=workflow_run.status,
)
await self.send_workflow_response(
workflow=workflow,
workflow_run=workflow_run,
api_key=api_key,
last_task=task,
)
        return workflow_run

async def create_workflow(
self,
organization_id: str,
title: str,
workflow_definition: WorkflowDefinition,
description: str | None = None,
) -> Workflow:
return await app.DATABASE.create_workflow(
organization_id=organization_id,
title=title,
description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
        )

async def get_workflow(self, workflow_id: str) -> Workflow:
workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id)
if not workflow:
raise WorkflowNotFound(workflow_id)
        return workflow

async def update_workflow(
self,
workflow_id: str,
title: str | None = None,
description: str | None = None,
workflow_definition: WorkflowDefinition | None = None,
) -> Workflow | None:
return await app.DATABASE.update_workflow(
workflow_id=workflow_id,
title=title,
description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
        )

    async def create_workflow_run(self, workflow_request: WorkflowRequestBody, workflow_id: str) -> WorkflowRun:
        return await app.DATABASE.create_workflow_run(
            workflow_id=workflow_id,
            proxy_location=workflow_request.proxy_location,
            webhook_callback_url=workflow_request.webhook_callback_url,
        )

    async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> None:
        LOG.info(
            f"Marking workflow run {workflow_run_id} as completed", workflow_run_id=workflow_run_id, status="completed"
        )
        await app.DATABASE.update_workflow_run(
            workflow_run_id=workflow_run_id,
            status=WorkflowRunStatus.completed,
        )

    async def mark_workflow_run_as_failed(self, workflow_run_id: str) -> None:
        LOG.info(f"Marking workflow run {workflow_run_id} as failed", workflow_run_id=workflow_run_id, status="failed")
        await app.DATABASE.update_workflow_run(
            workflow_run_id=workflow_run_id,
            status=WorkflowRunStatus.failed,
        )

    async def mark_workflow_run_as_running(self, workflow_run_id: str) -> None:
        LOG.info(
            f"Marking workflow run {workflow_run_id} as running", workflow_run_id=workflow_run_id, status="running"
        )
        await app.DATABASE.update_workflow_run(
            workflow_run_id=workflow_run_id,
            status=WorkflowRunStatus.running,
        )

    async def mark_workflow_run_as_terminated(self, workflow_run_id: str) -> None:
        LOG.info(
            f"Marking workflow run {workflow_run_id} as terminated",
            workflow_run_id=workflow_run_id,
            status="terminated",
        )
        await app.DATABASE.update_workflow_run(
            workflow_run_id=workflow_run_id,
            status=WorkflowRunStatus.terminated,
        )

    async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
        return await app.DATABASE.get_workflow_runs(workflow_id=workflow_id)

    async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
        workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id)
        if not workflow_run:
            raise WorkflowRunNotFound(workflow_run_id)
        return workflow_run

    async def create_workflow_parameter(
        self,
        workflow_id: str,
        workflow_parameter_type: WorkflowParameterType,
        key: str,
        default_value: bool | int | float | str | dict | list | None = None,
        description: str | None = None,
    ) -> WorkflowParameter:
        return await app.DATABASE.create_workflow_parameter(
            workflow_id=workflow_id,
            workflow_parameter_type=workflow_parameter_type,
            key=key,
            description=description,
            default_value=default_value,
        )

    async def create_aws_secret_parameter(
        self, workflow_id: str, aws_key: str, key: str, description: str | None = None
    ) -> AWSSecretParameter:
        return await app.DATABASE.create_aws_secret_parameter(
            workflow_id=workflow_id, aws_key=aws_key, key=key, description=description
        )

    async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
        return await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id)

    async def create_workflow_run_parameter(
        self,
        workflow_run_id: str,
        workflow_parameter_id: str,
        value: bool | int | float | str | dict | list,
    ) -> WorkflowRunParameter:
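        # dict and list values are stored as JSON strings; scalar values are stored as-is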
        return await app.DATABASE.create_workflow_run_parameter(
            workflow_run_id=workflow_run_id,
            workflow_parameter_id=workflow_parameter_id,
            value=json.dumps(value) if isinstance(value, (dict, list)) else value,
        )

    async def get_workflow_run_parameter_tuples(
        self, workflow_run_id: str
    ) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
        return await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)

    async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
        return await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)

    async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
        return await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)

async def build_workflow_run_status_response(
self, workflow_id: str, workflow_run_id: str, organization_id: str
) -> WorkflowRunStatusResponse:
workflow = await self.get_workflow(workflow_id=workflow_id)
if workflow is None:
LOG.error(f"Workflow {workflow_id} not found")
raise WorkflowNotFound(workflow_id=workflow_id)
if workflow.organization_id != organization_id:
LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
screenshot_urls = []
# get the last screenshot for the last 3 tasks of the workflow run
for task in workflow_run_tasks[::-1]:
screenshot_artifact = await app.DATABASE.get_latest_artifact(
task_id=task.task_id,
artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
organization_id=organization_id,
)
if screenshot_artifact:
screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)
if screenshot_url:
screenshot_urls.append(screenshot_url)
if len(screenshot_urls) >= 3:
break
recording_url = None
recording_artifact = await app.DATABASE.get_artifact_for_workflow_run(
workflow_run_id=workflow_run_id, artifact_type=ArtifactType.RECORDING, organization_id=organization_id
)
if recording_artifact:
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
return WorkflowRunStatusResponse(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
status=workflow_run.status,
proxy_location=workflow_run.proxy_location,
webhook_callback_url=workflow_run.webhook_callback_url,
created_at=workflow_run.created_at,
modified_at=workflow_run.modified_at,
parameters=parameters_with_value,
screenshot_urls=screenshot_urls,
recording_url=recording_url,
)
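
    # For reference (illustrative values, not part of the original module), the
    # serialized status response sent to the webhook looks roughly like:
    #   {"workflow_id": "w_123", "workflow_run_id": "wr_456", "status": "completed",
    #    "proxy_location": null, "webhook_callback_url": "https://example.com/hook",
    #    "created_at": "2024-03-01T00:00:00", "modified_at": "2024-03-01T00:00:05",
    #    "parameters": {"website_url": "https://example.com"},
    #    "screenshot_urls": ["https://..."], "recording_url": null}
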
async def send_workflow_response(
self,
workflow: Workflow,
workflow_run: WorkflowRun,
last_task: Task,
api_key: str | None = None,
close_browser_on_completion: bool = True,
) -> None:
browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
workflow_run.workflow_run_id, close_browser_on_completion
)
if browser_state:
await self.persist_video_data(browser_state, workflow, workflow_run)
await self.persist_har_data(browser_state, last_task, workflow, workflow_run)
# Wait for all tasks to complete before generating the links for the artifacts
all_workflow_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(
workflow_run_id=workflow_run.workflow_run_id
)
all_workflow_task_ids = [task.task_id for task in all_workflow_tasks]
await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)
try:
# Wait for all tasks to complete. Currently we're using asyncio.create_task() only for uploading artifacts to S3.
# We're excluding the current task from the list of tasks to wait for to prevent a deadlock.
st = time.time()
async with asyncio.timeout(30):
await asyncio.gather(
*[aio_task for aio_task in (asyncio.all_tasks() - {asyncio.current_task()}) if not aio_task.done()]
)
            duration = time.time() - st
            LOG.info(
                f"Waited {duration:.2f} seconds for all S3 uploads to complete",
                duration=duration,
            )
        except asyncio.TimeoutError:
            LOG.warning(
                "Timed out after 30 seconds waiting for S3 uploads to complete; some artifacts may not have been uploaded.",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
            )
if not workflow_run.webhook_callback_url:
LOG.warning(
"Workflow has no webhook callback url. Not sending workflow response",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
)
return
if not api_key:
LOG.warning(
"Request has no api key. Not sending workflow response",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
)
return
workflow_run_status_response = await self.build_workflow_run_status_response(
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow.organization_id,
)
        # send the workflow run status response to the webhook callback url
# TODO: use async requests (httpx)
timestamp = str(int(datetime.utcnow().timestamp()))
payload = workflow_run_status_response.model_dump_json()
signature = generate_skyvern_signature(
payload=payload,
api_key=api_key,
)
headers = {
"x-skyvern-timestamp": timestamp,
"x-skyvern-signature": signature,
"Content-Type": "application/json",
}
LOG.info(
"Sending webhook run status to webhook callback url",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
webhook_callback_url=workflow_run.webhook_callback_url,
payload=payload,
headers=headers,
)
try:
resp = requests.post(workflow_run.webhook_callback_url, data=payload, headers=headers)
if resp.ok:
LOG.info(
"Webhook sent successfully",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
resp_code=resp.status_code,
resp_text=resp.text,
)
            else:
                # Log what we can; resp.text already captures the body, and calling
                # resp.json() here would raise if the error response is not JSON.
                LOG.info(
                    "Webhook failed",
                    workflow_id=workflow.workflow_id,
                    workflow_run_id=workflow_run.workflow_run_id,
                    resp=resp,
                    resp_code=resp.status_code,
                    resp_text=resp.text,
                )
except Exception as e:
raise FailedToSendWebhook(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id
) from e
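
    # Receiver-side sketch (an illustrative addition): a webhook consumer can verify
    # x-skyvern-signature by recomputing it over the raw request body with its api
    # key. This assumes generate_skyvern_signature produces an HMAC-SHA256 hex
    # digest keyed by the api key; check the actual implementation in
    # skyvern.forge.sdk.core.security before relying on this.
    #
    #   import hashlib
    #   import hmac
    #
    #   def verify_skyvern_signature(payload: bytes, api_key: str, signature: str) -> bool:
    #       expected = hmac.new(api_key.encode(), payload, hashlib.sha256).hexdigest()
    #       return hmac.compare_digest(expected, signature)
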
async def persist_video_data(
self, browser_state: BrowserState, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
# Create recording artifact after closing the browser, so we can get an accurate recording
video_data = await app.BROWSER_MANAGER.get_video_data(
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
browser_state=browser_state,
)
if video_data:
await app.ARTIFACT_MANAGER.update_artifact_data(
artifact_id=browser_state.browser_artifacts.video_artifact_id,
organization_id=workflow.organization_id,
data=video_data,
            )

async def persist_har_data(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
har_data = await app.BROWSER_MANAGER.get_har_data(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
)
if har_data:
last_step = await app.DATABASE.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
)
if last_step:
await app.ARTIFACT_MANAGER.create_artifact(
step=last_step,
artifact_type=ArtifactType.HAR,
data=har_data,
)