Move the code over from private repository (#3)
This commit is contained in:
0
skyvern/forge/sdk/__init__.py
Normal file
0
skyvern/forge/sdk/__init__.py
Normal file
97
skyvern/forge/sdk/agent.py
Normal file
97
skyvern/forge/sdk/agent.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, FastAPI, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
from starlette_context.plugins.base import Plugin
|
||||
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Agent:
    def get_agent_app(self, router: APIRouter = base_router) -> FastAPI:
        """Build and return the FastAPI application serving the agent API."""
        app = FastAPI()

        # Browser origins allowed to call the API (local dev hosts).
        allowed_origins = [
            "http://localhost:5000",
            "http://127.0.0.1:5000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://localhost:8080",
            "http://127.0.0.1:8080",
            # Add any other origins you want to whitelist
        ]
        app.add_middleware(
            CORSMiddleware,
            allow_origins=allowed_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # All agent endpoints are versioned under /api/v1.
        app.include_router(router, prefix="/api/v1")

        # Expose this Agent instance to handlers via the ASGI scope.
        app.add_middleware(AgentMiddleware, agent=self)

        app.add_middleware(
            RawContextMiddleware,
            plugins=(
                # TODO (suchintan): We should set these up
                ExecutionDatePlugin(),
                # RequestIdPlugin(),
                # UserAgentPlugin(),
            ),
        )

        @app.exception_handler(Exception)
        async def handle_unexpected_error(request: Request, exc: Exception) -> JSONResponse:
            # Last-resort handler: log the failure and answer with a generic 500.
            LOG.exception("Unexpected error in agent server.", exc_info=exc)
            return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"})

        @app.middleware("http")
        async def bind_request_context(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
            # Give every request its own SkyvernContext keyed by a fresh UUID,
            # and always tear it down once the response has been produced.
            skyvern_context.set(SkyvernContext(request_id=str(uuid.uuid4())))
            try:
                return await call_next(request)
            finally:
                skyvern_context.reset()

        return app
|
||||
|
||||
|
||||
class AgentMiddleware:
    """
    Middleware that injects the agent instance into the request scope.
    """

    def __init__(self, app: FastAPI, agent: Agent):
        # Next ASGI application in the middleware chain.
        self.app = app
        # Agent made available to request handlers via scope["agent"].
        self.agent = agent

    async def __call__(self, scope, receive, send):  # type: ignore
        # Pure pass-through ASGI middleware: annotate the scope, then delegate.
        scope["agent"] = self.agent
        await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class ExecutionDatePlugin(Plugin):
    """starlette-context plugin that stamps each request with its start time."""

    # Context key under which the timestamp is stored.
    key = "execution_date"

    async def process_request(self, request: Request | HTTPConnection) -> datetime:
        # NOTE(review): naive local time; consider a timezone-aware timestamp —
        # confirm downstream consumers before changing.
        return datetime.now()
|
||||
0
skyvern/forge/sdk/api/__init__.py
Normal file
0
skyvern/forge/sdk/api/__init__.py
Normal file
134
skyvern/forge/sdk/api/aws.py
Normal file
134
skyvern/forge/sdk/api/aws.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aioboto3
|
||||
import structlog
|
||||
from aiobotocore.client import AioBaseClient
|
||||
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AWSClientType(StrEnum):
    """AWS service names, as accepted by aioboto3's Session.client()."""

    S3 = "s3"
    SECRETS_MANAGER = "secretsmanager"
|
||||
|
||||
|
||||
def execute_with_async_client(client_type: AWSClientType) -> Callable:
    """Decorator factory that injects a freshly-opened aioboto3 client.

    The wrapped coroutine method receives an extra ``client`` keyword argument
    holding an async AWS client for *client_type*; the client is closed
    automatically when the call completes. Only valid on AsyncAWSClient
    methods (enforced by the assert below).
    """
    import functools

    def decorator(f: Callable) -> Callable:
        # FIX: preserve the wrapped method's name/docstring/signature metadata;
        # the original returned an anonymous "wrapper", which breaks
        # introspection and debugging/log output.
        @functools.wraps(f)
        async def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
            self = args[0]
            assert isinstance(self, AsyncAWSClient)
            session = aioboto3.Session()
            # The async context manager guarantees the underlying HTTP
            # connection pool is released even if `f` raises.
            async with session.client(client_type) as client:
                return await f(*args, client=client, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
class AsyncAWSClient:
    """Async AWS helper built on aioboto3 for the S3 and Secrets Manager
    operations Skyvern needs.

    Every method gets a fresh service client injected by the
    execute_with_async_client decorator. Methods never raise on AWS errors:
    failures are logged and None is returned (where a return value exists).
    """

    @execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
    async def get_secret(self, secret_name: str, client: AioBaseClient = None) -> str | None:
        """Fetch a secret's string value, or None when the lookup fails."""
        try:
            response = await client.get_secret_value(SecretId=secret_name)
            return response["SecretString"]
        except Exception as e:
            # Botocore ClientErrors carry a response dict with an error code;
            # anything else (or a malformed error) falls back to a placeholder.
            try:
                error_code = e.response["Error"]["Code"]  # type: ignore
            except Exception:
                error_code = "failed-to-get-error-code"
            LOG.exception("Failed to get secret.", secret_name=secret_name, error_code=error_code, exc_info=True)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def upload_file(self, uri: str, data: bytes, client: AioBaseClient = None) -> str | None:
        """Upload raw bytes to an s3:// URI; returns the URI on success, else None."""
        try:
            parsed_uri = S3Uri(uri)
            await client.put_object(Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key)
            LOG.debug("Upload file success", uri=uri)
            return uri
        except Exception:
            LOG.exception("S3 upload failed.", uri=uri)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
        """Upload a local file to an s3:// URI; errors are logged, not raised."""
        try:
            parsed_uri = S3Uri(uri)
            await client.upload_file(file_path, parsed_uri.bucket, parsed_uri.key)
            LOG.info("Upload file from path success", uri=uri)
        except Exception:
            LOG.exception("S3 upload failed.", uri=uri)

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def download_file(self, uri: str, client: AioBaseClient = None) -> bytes | None:
        """Download the object at an s3:// URI as bytes, or None on failure."""
        try:
            parsed_uri = S3Uri(uri)
            response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
            return await response["Body"].read()
        except Exception:
            LOG.exception("S3 download failed", uri=uri)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def create_presigned_url(self, uri: str, client: AioBaseClient = None) -> str | None:
        """Create a time-limited GET URL for the object at an s3:// URI.

        Expiration comes from settings (PRESIGNED_URL_EXPIRATION). Returns None
        when URL generation fails.
        """
        try:
            parsed_uri = S3Uri(uri)
            url = await client.generate_presigned_url(
                "get_object",
                Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
                ExpiresIn=SettingsManager.get_settings().PRESIGNED_URL_EXPIRATION,
            )
            return url
        except Exception:
            LOG.exception("Failed to create presigned url.", uri=uri)
            return None
|
||||
|
||||
|
||||
class S3Uri:
    """Parse an ``s3://bucket/key`` URI into bucket, key and canonical URI.

    Fragments are disabled in parsing (``allow_fragments=False``), so ``#``
    and ``?`` characters are treated as part of the object key itself —
    e.g. ``s3://bucket/hello/world?qwe1=3#ddd`` has key
    ``hello/world?qwe1=3#ddd``.

    Adapted from:
    https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path
    """

    def __init__(self, uri: str) -> None:
        self._parsed = urlparse(uri, allow_fragments=False)

    @property
    def bucket(self) -> str:
        # The bucket is the network-location component of the URI.
        return self._parsed.netloc

    @property
    def key(self) -> str:
        # Re-attach any query string: for S3 keys it is part of the key itself.
        stripped = self._parsed.path.lstrip("/")
        return f"{stripped}?{self._parsed.query}" if self._parsed.query else stripped

    @property
    def uri(self) -> str:
        # Round-trip back to the canonical URI string.
        return self._parsed.geturl()
|
||||
25
skyvern/forge/sdk/api/chat_completion_price.py
Normal file
25
skyvern/forge/sdk/api/chat_completion_price.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Per-model (input, output) token prices — presumably USD per 1K tokens, which
# matches the /1000 scaling applied in ChatCompletionPrice below.
# NOTE(review): despite the name, the values are (float, float) tuples, not
# lambdas; renaming would break importers, so flagging only.
openai_model_to_price_lambdas = {
    "gpt-4-vision-preview": (0.01, 0.03),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-3.5-turbo": (0.001, 0.002),
    "gpt-3.5-turbo-1106": (0.001, 0.002),
}
|
||||
|
||||
|
||||
class ChatCompletionPrice(BaseModel):
    """Token usage for one chat completion plus a pricing function bound to the
    model's rates at construction time."""

    # Prompt tokens consumed by the request.
    input_token_count: int
    # Completion tokens produced by the response.
    output_token_count: int
    # (input_tokens, output_tokens) -> total price, using the per-1K-token
    # rates captured for the model at construction time.
    openai_model_to_price_lambda: Callable[[int, int], float]

    def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
        # Raises KeyError for models missing from openai_model_to_price_lambdas.
        input_token_price, output_token_price = openai_model_to_price_lambdas[model_name]
        super().__init__(
            input_token_count=input_token_count,
            output_token_count=output_token_count,
            # Rates are per 1K tokens, hence the /1000 scaling.
            openai_model_to_price_lambda=lambda input_token, output_token: input_token_price * input_token / 1000
            + output_token_price * output_token / 1000,
        )
|
||||
47
skyvern/forge/sdk/api/files.py
Normal file
47
skyvern/forge/sdk/api/files.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import structlog
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
def download_file(url: str) -> str | None:
    """Download *url* into a fresh temp directory and return the local path.

    The response is streamed in 1 KiB chunks so large files are never fully
    buffered in memory. Returns None when the server responds with a non-200
    status.
    """
    # FIX: the original never closed the streamed response (stream=True keeps
    # the connection open until consumed or closed, and the non-200 branch
    # leaked it). The `with` releases the connection on every path. A timeout
    # is also set so a hung server cannot block the caller forever.
    with requests.get(url, stream=True, timeout=60) as r:
        if r.status_code != 200:
            LOG.error(f"Failed to download file, status code: {r.status_code}")
            return None

        # Derive the local file name from the URL path.
        parsed_url = urlparse(url)
        temp_dir = tempfile.mkdtemp(prefix="skyvern_downloads_")
        file_name = os.path.basename(parsed_url.path)
        file_path = os.path.join(temp_dir, file_name)

        LOG.info(f"Downloading file to {file_path}")
        with open(file_path, "wb") as f:
            # Write the content of the request into the file
            for chunk in r.iter_content(1024):
                f.write(chunk)
        LOG.info(f"File downloaded successfully to {file_path}")
        return file_path
|
||||
|
||||
|
||||
def zip_files(files_path: str, zip_file_path: str) -> str:
    """Recursively zip everything under *files_path* into *zip_file_path*.

    Members are stored with paths relative to *files_path* so the archive
    does not embed the absolute directory layout. Returns *zip_file_path*.
    """
    with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for current_dir, _subdirs, filenames in os.walk(files_path):
            for filename in filenames:
                absolute_path = os.path.join(current_dir, filename)
                # Relative path within the zip
                member_name = os.path.relpath(absolute_path, files_path)
                archive.write(absolute_path, member_name)

    return zip_file_path
|
||||
221
skyvern/forge/sdk/api/open_ai.py
Normal file
221
skyvern/forge/sdk/api/open_ai.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
import openai
|
||||
import structlog
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class OpenAIKeyClientWrapper:
    """Pairs one OpenAI API key with its AsyncOpenAI client and tracks the
    key's remaining request quota so the manager can rotate across keys."""

    # Async client bound to `key`.
    client: AsyncOpenAI
    # The raw API key string.
    key: str
    # Last observed x-ratelimit-remaining-requests value; None until first use.
    remaining_requests: int | None

    def __init__(self, key: str, remaining_requests: int | None) -> None:
        self.key = key
        self.remaining_requests = remaining_requests
        # NOTE(review): naive UTC timestamp (utcnow is deprecated in 3.12);
        # consistent with the rest of this module — migrate together.
        self.updated_at = datetime.utcnow()
        self.client = AsyncOpenAI(api_key=self.key)

    def update_remaining_requests(self, remaining_requests: int | None) -> None:
        """Record the latest remaining-request count and refresh the timestamp."""
        self.remaining_requests = remaining_requests
        self.updated_at = datetime.utcnow()

    def is_available(self) -> bool:
        """Whether this key may be used for the next request."""
        # If remaining_requests is None, then it's the first time we're trying this key
        # so we can assume it's available, otherwise we check if it's greater than 0
        if self.remaining_requests is None:
            return True

        if self.remaining_requests > 0:
            return True

        # If we haven't checked this in over 1 minutes, check it again
        # Most of our failures are because of Tokens-per-minute (TPM) limits
        if self.updated_at < (datetime.utcnow() - timedelta(minutes=1)):
            return True

        return False
|
||||
|
||||
|
||||
class OpenAIClientManager:
    """Distributes chat-completion requests across a pool of OpenAI API keys,
    marking rate-limited keys unavailable and retrying with another key."""

    # TODO Support other models for requests without screenshots, track rate limits for each model and key as well if any
    clients: list[OpenAIKeyClientWrapper]

    def __init__(self, api_keys: list[str] = SettingsManager.get_settings().OPENAI_API_KEYS) -> None:
        # NOTE(review): the default is evaluated once at import time; later
        # settings changes are not picked up — confirm intended.
        self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]

    def get_available_client(self) -> OpenAIKeyClientWrapper | None:
        """Return a random client whose key still has quota, or None."""
        available_clients = [client for client in self.clients if client.is_available()]

        if not available_clients:
            return None

        # Randomly select an available client to distribute requests across our accounts
        return random.choice(available_clients)

    async def content_builder(
        self,
        step: Step,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> list[dict[str, Any]]:
        """Assemble the multimodal `content` list for one chat message, and
        record the prompt and each screenshot as artifacts on *step*."""
        content: list[dict[str, Any]] = []

        if prompt is not None:
            content.append(
                {
                    "type": "text",
                    "text": prompt,
                }
            )

            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_PROMPT,
                data=prompt.encode("utf-8"),
            )
        if screenshots:
            for screenshot in screenshots:
                # Screenshots are inlined as base64 data URLs per the OpenAI
                # vision API format.
                encoded_image = base64.b64encode(screenshot).decode("utf-8")
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{encoded_image}",
                        },
                    }
                )
                # create artifact for each image
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.SCREENSHOT_LLM,
                    data=screenshot,
                )

        return content

    async def chat_completion(
        self,
        step: Step,
        model: str = "gpt-4-vision-preview",
        max_tokens: int = 4096,
        temperature: int = 0,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> dict[str, Any]:
        """Send one chat completion for *step*, persisting request/response
        artifacts, retrying on another key if the current one is rate limited,
        and returning the model's output parsed as a dict.

        Raises NoAvailableOpenAIClients when every key is exhausted and
        OpenAIRequestTooBigError / InvalidOpenAIResponseFormat on failures.
        """
        # NOTE(review): f-string has no placeholders — drop the f prefix.
        LOG.info(
            f"Sending LLM request",
            task_id=step.task_id,
            step_id=step.step_id,
            num_screenshots=len(screenshots) if screenshots else 0,
        )
        messages = [
            {
                "role": "user",
                "content": await self.content_builder(
                    step=step,
                    screenshots=screenshots,
                    prompt=prompt,
                ),
            }
        ]

        chat_completion_kwargs = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        # Persist the exact request payload before sending it.
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_REQUEST,
            data=json.dumps(chat_completion_kwargs).encode("utf-8"),
        )
        available_client = self.get_available_client()
        if available_client is None:
            raise NoAvailableOpenAIClients()
        try:
            # with_raw_response exposes HTTP headers for rate-limit tracking.
            response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
        except openai.RateLimitError as e:
            # If we get a RateLimitError, we can assume the key is not available anymore
            # NOTE(review): RateLimitError is an HTTP 429 by definition, so if
            # e.code is the HTTP status this guard fires on every rate limit and
            # the key-rotation path below is unreachable — confirm what e.code
            # actually holds (it may be an error-string code instead).
            if e.code == 429:
                raise OpenAIRequestTooBigError(e.message)
            LOG.warning(
                "OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
            )
            available_client.update_remaining_requests(remaining_requests=0)
            available_client = self.get_available_client()
            if available_client is None:
                raise NoAvailableOpenAIClients()
            # Retry recursively with another key, same arguments.
            return await self.chat_completion(
                step=step,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                screenshots=screenshots,
                prompt=prompt,
            )
        # TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
        # use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
        # x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys

        # If we get a response, we can assume the key is available and update the remaining requests
        ratelimit_remaining_requests = response.headers.get("x-ratelimit-remaining-requests")

        if not ratelimit_remaining_requests:
            # NOTE(review): this branch only warns, but int(...) below will still
            # raise TypeError/ValueError when the header is missing or empty.
            # NOTE(review): structlog loggers take keyword context, not a
            # positional dict — passing response.headers positionally likely
            # raises at runtime.
            LOG.warning("Invalid x-ratelimit-remaining-requests from OpenAI", response.headers)

        available_client.update_remaining_requests(remaining_requests=int(ratelimit_remaining_requests))
        chat_completion = response.parse()

        if chat_completion.usage is not None:
            # TODO (Suchintan): Is this bad design?
            # Record token usage/pricing on the step before storing artifacts.
            step = await app.DATABASE.update_step(
                step_id=step.step_id,
                task_id=step.task_id,
                organization_id=step.organization_id,
                chat_completion_price=ChatCompletionPrice(
                    input_token_count=chat_completion.usage.prompt_tokens,
                    output_token_count=chat_completion.usage.completion_tokens,
                    model_name=model,
                ),
            )
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE,
            data=chat_completion.model_dump_json(indent=2).encode("utf-8"),
        )
        parsed_response = self.parse_response(chat_completion)
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
            data=json.dumps(parsed_response, indent=2).encode("utf-8"),
        )
        return parsed_response

    def parse_response(self, response: ChatCompletion) -> dict[str, str]:
        """Strip markdown code fences from the first choice's content and parse
        it as comment-tolerant JSON.

        Raises InvalidOpenAIResponseFormat (chained) on any parse failure.
        """
        try:
            content = response.choices[0].message.content
            # NOTE(review): if content is None, .replace raises AttributeError,
            # which is folded into InvalidOpenAIResponseFormat below — possibly
            # intended, but an explicit None check would be clearer.
            content = content.replace("```json", "")
            content = content.replace("```", "")
            if not content:
                raise Exception("openai response content is empty")
            return commentjson.loads(content)
        except Exception as e:
            raise InvalidOpenAIResponseFormat(str(response)) from e
|
||||
0
skyvern/forge/sdk/artifact/__init__.py
Normal file
0
skyvern/forge/sdk/artifact/__init__.py
Normal file
112
skyvern/forge/sdk/artifact/manager.py
Normal file
112
skyvern/forge/sdk/artifact/manager.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.id import generate_artifact_id
|
||||
from skyvern.forge.sdk.models import Step
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class ArtifactManager:
    """Creates artifact records and schedules their payload uploads as
    background asyncio tasks, with helpers to await pending uploads before a
    task is finalized."""

    # task_id -> list of aio_tasks for uploading artifacts
    # NOTE(review): class-level mutable — shared by all ArtifactManager
    # instances. Fine if the app holds a single instance; confirm otherwise.
    upload_aiotasks_map: dict[str, list[asyncio.Task[None]]] = defaultdict(list)

    async def create_artifact(
        self, step: Step, artifact_type: ArtifactType, data: bytes | None = None, path: str | None = None
    ) -> str:
        """Create a DB artifact row for *step* and upload the payload in the
        background; returns the new artifact_id immediately.

        Exactly one of *data* (raw bytes) or *path* (existing local file) must
        be provided, otherwise ValueError is raised.
        """
        # TODO (kerem): Which is better?
        # current: (disadvantage: we create the artifact_id UUID here)
        # 1. generate artifact_id UUID here
        # 2. build uri with artifact_id, step_id, task_id, artifact_type
        # 3. create artifact in db using artifact_id, step_id, task_id, artifact_type, uri
        # 4. store artifact in storage
        # alternative: (disadvantage: two db calls)
        # 1. create artifact in db without the URI
        # 2. build uri with artifact_id, step_id, task_id, artifact_type
        # 3. update artifact in db with the URI
        # 4. store artifact in storage
        if data is None and path is None:
            raise ValueError("Either data or path must be provided to create an artifact.")
        if data and path:
            raise ValueError("Both data and path cannot be provided to create an artifact.")
        artifact_id = generate_artifact_id()
        uri = app.STORAGE.build_uri(artifact_id, step, artifact_type)
        artifact = await app.DATABASE.create_artifact(
            artifact_id,
            step.step_id,
            step.task_id,
            artifact_type,
            uri,
            organization_id=step.organization_id,
        )
        if data:
            # Fire and forget
            aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
            self.upload_aiotasks_map[step.task_id].append(aio_task)
        elif path:
            # Fire and forget
            aio_task = asyncio.create_task(app.STORAGE.store_artifact_from_path(artifact, path))
            self.upload_aiotasks_map[step.task_id].append(aio_task)

        return artifact_id

    async def update_artifact_data(self, artifact_id: str | None, organization_id: str | None, data: bytes) -> None:
        """Re-upload *data* for an existing artifact; silently no-ops when ids
        are missing or the artifact cannot be found."""
        if not artifact_id or not organization_id:
            return None
        artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id)
        if not artifact:
            return
        # Fire and forget
        aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
        self.upload_aiotasks_map[artifact.task_id].append(aio_task)

    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Fetch the artifact's payload from the storage backend."""
        return await app.STORAGE.retrieve_artifact(artifact)

    async def get_share_link(self, artifact: Artifact) -> str | None:
        """Return a shareable link for the artifact from the storage backend."""
        return await app.STORAGE.get_share_link(artifact)

    async def wait_for_upload_aiotasks_for_task(self, task_id: str) -> None:
        """Wait (up to 30s) for every pending upload of *task_id*, then drop
        its bookkeeping entry — even when the wait timed out."""
        try:
            st = time.time()
            async with asyncio.timeout(30):
                # Only gather tasks that are still running.
                await asyncio.gather(
                    *[aio_task for aio_task in self.upload_aiotasks_map[task_id] if not aio_task.done()]
                )
            LOG.info(
                f"S3 upload tasks for task_id={task_id} completed in {time.time() - st:.2f}s",
                task_id=task_id,
                duration=time.time() - st,
            )
        except asyncio.TimeoutError:
            LOG.error(f"Timeout (30s) while waiting for upload tasks for task_id={task_id}", task_id=task_id)

        del self.upload_aiotasks_map[task_id]

    async def wait_for_upload_aiotasks_for_tasks(self, task_ids: list[str]) -> None:
        """Batch variant of wait_for_upload_aiotasks_for_task: one shared 30s
        window across all *task_ids*, entries dropped afterwards regardless."""
        try:
            st = time.time()
            async with asyncio.timeout(30):
                await asyncio.gather(
                    *[
                        aio_task
                        for task_id in task_ids
                        for aio_task in self.upload_aiotasks_map[task_id]
                        if not aio_task.done()
                    ]
                )
            LOG.info(
                f"S3 upload tasks for task_ids={task_ids} completed in {time.time() - st:.2f}s",
                task_ids=task_ids,
                duration=time.time() - st,
            )
        except asyncio.TimeoutError:
            LOG.error(f"Timeout (30s) while waiting for upload tasks for task_ids={task_ids}", task_ids=task_ids)

        for task_id in task_ids:
            del self.upload_aiotasks_map[task_id]
||||
78
skyvern/forge/sdk/artifact/models.py
Normal file
78
skyvern/forge/sdk/artifact/models.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ArtifactType(StrEnum):
    """Kinds of artifacts stored for a task step; the string values are what
    gets persisted in the database."""

    RECORDING = "recording"

    # DEPRECATED. pls use SCREENSHOT_LLM, SCREENSHOT_ACTION or SCREENSHOT_FINAL
    SCREENSHOT = "screenshot"

    # USE THESE for screenshots
    SCREENSHOT_LLM = "screenshot_llm"
    SCREENSHOT_ACTION = "screenshot_action"
    SCREENSHOT_FINAL = "screenshot_final"

    LLM_PROMPT = "llm_prompt"
    LLM_REQUEST = "llm_request"
    LLM_RESPONSE = "llm_response"
    LLM_RESPONSE_PARSED = "llm_response_parsed"
    VISIBLE_ELEMENTS_ID_XPATH_MAP = "visible_elements_id_xpath_map"
    VISIBLE_ELEMENTS_TREE = "visible_elements_tree"
    VISIBLE_ELEMENTS_TREE_TRIMMED = "visible_elements_tree_trimmed"

    # DEPRECATED. pls use HTML_SCRAPE or HTML_ACTION
    HTML = "html"

    # USE THESE for htmls
    HTML_SCRAPE = "html_scrape"
    HTML_ACTION = "html_action"

    # Debugging
    TRACE = "trace"
    HAR = "har"
|
||||
|
||||
|
||||
class Artifact(BaseModel):
    """A single stored artifact (screenshot, HTML dump, LLM trace, ...)
    produced while executing a task step."""

    created_at: datetime = Field(
        ...,
        description="The creation datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
        # NOTE(review): json_encoders is not a standard Field argument — it
        # normally belongs in the model Config; confirm it has any effect here.
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    artifact_id: str = Field(
        ...,
        description="The ID of the task artifact.",
        examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
    )
    task_id: str = Field(
        ...,
        description="The ID of the task this artifact belongs to.",
        examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
    )
    step_id: str = Field(
        ...,
        description="The ID of the task step this artifact belongs to.",
        examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
    )
    artifact_type: ArtifactType = Field(
        ...,
        description="The type of the artifact.",
        examples=["screenshot"],
    )
    uri: str = Field(
        ...,
        description="The URI of the artifact.",
        examples=["/Users/skyvern/hello/world.png"],
    )
    # Owning organization, when known.
    organization_id: str | None = None
|
||||
0
skyvern/forge/sdk/artifact/storage/__init__.py
Normal file
0
skyvern/forge/sdk/artifact/storage/__init__.py
Normal file
45
skyvern/forge/sdk/artifact/storage/base.py
Normal file
45
skyvern/forge/sdk/artifact/storage/base.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
|
||||
# TODO: This should be a part of the ArtifactType model
# Maps each artifact type to the file extension used when building storage URIs.
# NOTE(review): "EXTENTSION" is a typo, but renaming would break importers
# (e.g. the local storage backend) — fix in a coordinated change.
# NOTE(review): deprecated ArtifactType.SCREENSHOT and ArtifactType.HTML have no
# entry, so building a URI for them raises KeyError — confirm intended.
FILE_EXTENTSION_MAP: dict[ArtifactType, str] = {
    ArtifactType.RECORDING: "webm",
    ArtifactType.SCREENSHOT_LLM: "png",
    ArtifactType.SCREENSHOT_ACTION: "png",
    ArtifactType.SCREENSHOT_FINAL: "png",
    ArtifactType.LLM_PROMPT: "txt",
    ArtifactType.LLM_REQUEST: "json",
    ArtifactType.LLM_RESPONSE: "json",
    ArtifactType.LLM_RESPONSE_PARSED: "json",
    ArtifactType.VISIBLE_ELEMENTS_ID_XPATH_MAP: "json",
    ArtifactType.VISIBLE_ELEMENTS_TREE: "json",
    ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED: "json",
    ArtifactType.HTML_SCRAPE: "html",
    ArtifactType.HTML_ACTION: "html",
    ArtifactType.TRACE: "zip",
    ArtifactType.HAR: "har",
}
|
||||
|
||||
|
||||
class BaseStorage(ABC):
    """Abstract interface for artifact persistence backends (local disk, S3, ...)."""

    @abstractmethod
    def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
        """Return the backend-specific URI under which the artifact will live."""
        pass

    @abstractmethod
    async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
        """Persist raw bytes at the artifact's URI."""
        pass

    @abstractmethod
    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Read the artifact's bytes back, or None when unavailable."""
        pass

    @abstractmethod
    async def get_share_link(self, artifact: Artifact) -> str | None:
        """Return a URL suitable for sharing the artifact, or None."""
        pass

    @abstractmethod
    async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
        """Persist an existing local file at the artifact's URI."""
        pass
|
||||
14
skyvern/forge/sdk/artifact/storage/factory.py
Normal file
14
skyvern/forge/sdk/artifact/storage/factory.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from skyvern.forge.sdk.artifact.storage.base import BaseStorage
|
||||
from skyvern.forge.sdk.artifact.storage.local import LocalStorage
|
||||
|
||||
|
||||
class StorageFactory:
    """Process-wide holder for the active artifact storage backend."""

    # Defaults to local-disk storage; swapped at startup for other backends.
    __storage: BaseStorage = LocalStorage()

    @staticmethod
    def set_storage(storage: BaseStorage) -> None:
        # Replace the active backend (e.g. with an S3 implementation).
        StorageFactory.__storage = storage

    @staticmethod
    def get_storage() -> BaseStorage:
        # Return the currently configured backend.
        return StorageFactory.__storage
|
||||
66
skyvern/forge/sdk/artifact/storage/local.py
Normal file
66
skyvern/forge/sdk/artifact/storage/local.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LocalStorage(BaseStorage):
    """Storage backend that persists artifacts on the local filesystem under
    file:// URIs rooted at the configured artifact path."""

    def __init__(self, artifact_path: str = SettingsManager.get_settings().ARTIFACT_STORAGE_PATH) -> None:
        # NOTE: the default is evaluated once at import time; pass an explicit
        # path to override per-instance.
        self.artifact_path = artifact_path

    def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
        """Build a unique, timestamped file:// URI for a new artifact of *artifact_type*."""
        file_ext = FILE_EXTENTSION_MAP[artifact_type]
        return f"file://{self.artifact_path}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

    async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
        """Write raw bytes to the artifact's URI; failures are logged, not raised."""
        file_path = None
        try:
            file_path = Path(self._parse_uri_to_path(artifact.uri))
            self._create_directories_if_not_exists(file_path)
            with open(file_path, "wb") as f:
                f.write(data)
        except Exception:
            LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)

    async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
        """Move the existing file at *path* to the artifact's URI; failures are logged."""
        file_path = None
        try:
            file_path = Path(self._parse_uri_to_path(artifact.uri))
            self._create_directories_if_not_exists(file_path)
            # replace() moves (not copies) the source file into place.
            Path(path).replace(file_path)
        except Exception:
            LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)

    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Read and return the artifact's bytes, or None when unreadable."""
        file_path = None
        try:
            file_path = self._parse_uri_to_path(artifact.uri)
            with open(file_path, "rb") as f:
                return f.read()
        except Exception:
            LOG.exception("Failed to retrieve local artifact.", file_path=file_path, artifact=artifact)
            return None

    async def get_share_link(self, artifact: Artifact) -> str:
        """Local artifacts are 'shared' by their file:// URI directly."""
        return artifact.uri

    @staticmethod
    def _parse_uri_to_path(uri: str) -> str:
        """Convert a file:// URI to a local filesystem path (percent-decoded)."""
        parsed_uri = urlparse(uri)
        if parsed_uri.scheme != "file":
            # FIX: the original message lacked the f-prefix and printed the
            # literal "{parsed_uri.scheme}" instead of the actual scheme.
            raise ValueError(f"Invalid URI scheme: {parsed_uri.scheme} expected: file")
        path = parsed_uri.netloc + parsed_uri.path
        return unquote(path)

    @staticmethod
    def _create_directories_if_not_exists(path_including_file_name: Path) -> None:
        """Ensure the parent directory chain of the target file exists."""
        path = path_including_file_name.parent
        path.mkdir(parents=True, exist_ok=True)
|
||||
0
skyvern/forge/sdk/core/__init__.py
Normal file
0
skyvern/forge/sdk/core/__init__.py
Normal file
41
skyvern/forge/sdk/core/security.py
Normal file
41
skyvern/forge/sdk/core/security.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
|
||||
from jose import jwt
|
||||
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def create_access_token(
    subject: Union[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed JWT whose ``sub`` claim is ``subject``.

    Expiry is ``expires_delta`` from now when given (and non-zero), otherwise the
    configured ``ACCESS_TOKEN_EXPIRE_MINUTES``.
    """
    # Falsy (None or zero) deltas fall back to the configured default, matching
    # the original truthiness check.
    delta = expires_delta or timedelta(
        minutes=SettingsManager.get_settings().ACCESS_TOKEN_EXPIRE_MINUTES,
    )
    claims = {"exp": datetime.utcnow() + delta, "sub": str(subject)}
    return jwt.encode(claims, SettingsManager.get_settings().SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def generate_skyvern_signature(
    payload: str,
    api_key: str,
) -> str:
    """Return the hex HMAC-SHA256 digest of ``payload`` keyed by ``api_key``.

    :param payload: the request body
    :param api_key: the Skyvern api key
    :return: the Skyvern signature
    """
    digest = hmac.new(
        api_key.encode("utf-8"),
        msg=payload.encode("utf-8"),
        digestmod=hashlib.sha256,
    )
    return digest.hexdigest()
|
||||
73
skyvern/forge/sdk/core/skyvern_context.py
Normal file
73
skyvern/forge/sdk/core/skyvern_context.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkyvernContext:
|
||||
request_id: str | None = None
|
||||
organization_id: str | None = None
|
||||
task_id: str | None = None
|
||||
workflow_id: str | None = None
|
||||
workflow_run_id: str | None = None
|
||||
max_steps_override: int | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, max_steps_override={self.max_steps_override})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
# Process-global holder for the current SkyvernContext; empty until set().
_context: ContextVar[SkyvernContext | None] = ContextVar("Global context", default=None)
|
||||
|
||||
|
||||
def current() -> SkyvernContext | None:
    """Return the current Skyvern context, or None when none has been set."""
    return _context.get()
|
||||
|
||||
|
||||
def ensure_context() -> SkyvernContext:
    """Return the current context.

    :raises RuntimeError: when no context has been set.
    """
    ctx = current()
    if ctx is None:
        raise RuntimeError("No skyvern context")
    return ctx
|
||||
|
||||
|
||||
def set(context: SkyvernContext) -> None:
    """Install ``context`` as the current Skyvern context.

    NOTE: intentionally shadows the builtin ``set``; callers use the qualified
    form ``skyvern_context.set(...)``.
    """
    _context.set(context)
|
||||
|
||||
|
||||
def reset() -> None:
    """Clear the current Skyvern context."""
    _context.set(None)
|
||||
0
skyvern/forge/sdk/db/__init__.py
Normal file
0
skyvern/forge/sdk/db/__init__.py
Normal file
900
skyvern/forge/sdk/db/client.py
Normal file
900
skyvern/forge/sdk/db/client.py
Normal file
@@ -0,0 +1,900 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, create_engine, delete
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from skyvern.exceptions import WorkflowParameterNotFound
|
||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ArtifactModel,
|
||||
AWSSecretParameterModel,
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
_custom_json_serializer,
|
||||
convert_to_artifact,
|
||||
convert_to_aws_secret_parameter,
|
||||
convert_to_organization,
|
||||
convert_to_organization_auth_token,
|
||||
convert_to_step,
|
||||
convert_to_task,
|
||||
convert_to_workflow,
|
||||
convert_to_workflow_parameter,
|
||||
convert_to_workflow_run,
|
||||
convert_to_workflow_run_parameter,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus
|
||||
from skyvern.webeye.actions.models import AgentStepOutput
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AgentDB:
|
||||
def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
    """Create a synchronous SQLAlchemy engine and session factory.

    :param database_string: SQLAlchemy connection URL
    :param debug_enabled: forwarded to the convert_to_* helpers for verbose conversion
    """
    super().__init__()
    self.debug_enabled = debug_enabled
    self.engine = create_engine(database_string, json_serializer=_custom_json_serializer)
    self.Session = sessionmaker(bind=self.engine)
|
||||
|
||||
async def create_task(
    self,
    url: str,
    navigation_goal: str | None,
    data_extraction_goal: str | None,
    navigation_payload: dict[str, Any] | list | str | None,
    webhook_callback_url: str | None = None,
    organization_id: str | None = None,
    proxy_location: ProxyLocation | None = None,
    extracted_information_schema: dict[str, Any] | list | str | None = None,
    workflow_run_id: str | None = None,
    order: int | None = None,
    retry: int | None = None,
) -> Task:
    """Insert a new task row with status "created" and return it as a Task."""
    try:
        with self.Session() as session:
            task = TaskModel(
                status="created",
                url=url,
                webhook_callback_url=webhook_callback_url,
                navigation_goal=navigation_goal,
                data_extraction_goal=data_extraction_goal,
                navigation_payload=navigation_payload,
                organization_id=organization_id,
                proxy_location=proxy_location,
                extracted_information_schema=extracted_information_schema,
                workflow_run_id=workflow_run_id,
                order=order,
                retry=retry,
            )
            session.add(task)
            session.commit()
            session.refresh(task)
            return convert_to_task(task, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def create_step(
    self,
    task_id: str,
    order: int,
    retry_index: int,
    organization_id: str | None = None,
) -> Step:
    """Insert a new step row with status "created" for the given task and return it."""
    try:
        with self.Session() as session:
            step = StepModel(
                task_id=task_id,
                order=order,
                retry_index=retry_index,
                status="created",
                organization_id=organization_id,
            )
            session.add(step)
            session.commit()
            session.refresh(step)
            return convert_to_step(step, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def create_artifact(
    self,
    artifact_id: str,
    step_id: str,
    task_id: str,
    artifact_type: str,
    uri: str,
    organization_id: str | None = None,
) -> Artifact:
    """Insert a new artifact row pointing at ``uri`` and return it as an Artifact."""
    try:
        with self.Session() as session:
            artifact = ArtifactModel(
                artifact_id=artifact_id,
                task_id=task_id,
                step_id=step_id,
                artifact_type=artifact_type,
                uri=uri,
                organization_id=organization_id,
            )
            session.add(artifact)
            session.commit()
            session.refresh(artifact)
            return convert_to_artifact(artifact, self.debug_enabled)
    except SQLAlchemyError:
        LOG.exception("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.exception("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
    """Fetch a task by id (scoped to an organization), or None when missing."""
    try:
        with self.Session() as session:
            task_obj = (
                session.query(TaskModel)
                .filter_by(task_id=task_id, organization_id=organization_id)
                .first()
            )
            if task_obj is None:
                LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
                return None
            return convert_to_task(task_obj, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
    """Fetch a step by id (scoped to an organization), or None when missing.

    NOTE(review): ``task_id`` is accepted but not used in the lookup, matching
    the original query — confirm whether it should filter too.
    """
    try:
        with self.Session() as session:
            step = (
                session.query(StepModel)
                .filter_by(step_id=step_id, organization_id=organization_id)
                .first()
            )
            if step is None:
                return None
            return convert_to_step(step, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_task_steps(self, task_id: str, organization_id: str | None = None) -> list[Step]:
    """Return all steps of a task ordered by (order, retry_index); [] when none."""
    try:
        with self.Session() as session:
            steps = (
                session.query(StepModel)
                .filter_by(task_id=task_id, organization_id=organization_id)
                .order_by(StepModel.order)
                .order_by(StepModel.retry_index)
                .all()
            )
            return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> list[StepModel]:
    """Return the raw StepModel rows of a task ordered by (order, retry_index)."""
    try:
        with self.Session() as session:
            query = (
                session.query(StepModel)
                .filter_by(task_id=task_id, organization_id=organization_id)
                .order_by(StepModel.order)
                .order_by(StepModel.retry_index)
            )
            return query.all()
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
    """Return the step with the highest order for a task, or None when there are none.

    NOTE(review): ties between retries of the same order are not broken by
    retry_index here (unlike get_task_steps) — confirm intended.
    """
    try:
        with self.Session() as session:
            step = (
                session.query(StepModel)
                .filter_by(task_id=task_id, organization_id=organization_id)
                .order_by(StepModel.order.desc())
                .first()
            )
            if step is None:
                LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
                return None
            return convert_to_step(step, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def update_step(
    self,
    task_id: str,
    step_id: str,
    status: StepStatus | None = None,
    output: AgentStepOutput | None = None,
    is_last: bool | None = None,
    retry_index: int | None = None,
    organization_id: str | None = None,
    chat_completion_price: ChatCompletionPrice | None = None,
) -> Step:
    """Apply all non-None fields to a step, commit, and return the refreshed Step.

    :raises NotFoundError: when the step does not exist (before or after the update).
    """
    try:
        with self.Session() as session:
            step = (
                session.query(StepModel)
                .filter_by(task_id=task_id, step_id=step_id, organization_id=organization_id)
                .first()
            )
            if step is None:
                raise NotFoundError("Step not found")
            if status is not None:
                step.status = status
            if output is not None:
                step.output = output.model_dump()
            if is_last is not None:
                step.is_last = is_last
            if retry_index is not None:
                step.retry_index = retry_index
            if chat_completion_price is not None:
                # Token counters accumulate across updates; missing counts as zero.
                step.input_token_count = (step.input_token_count or 0) + chat_completion_price.input_token_count
                step.output_token_count = (step.output_token_count or 0) + chat_completion_price.output_token_count
                step.step_cost = chat_completion_price.openai_model_to_price_lambda(
                    step.input_token_count, step.output_token_count
                )
            session.commit()
            # Re-read through get_step so callers get the converted, refreshed view.
            updated_step = await self.get_step(task_id, step_id, organization_id)
            if not updated_step:
                raise NotFoundError("Step not found")
            return updated_step
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except NotFoundError:
        LOG.error("NotFoundError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def update_task(
    self,
    task_id: str,
    status: TaskStatus,
    extracted_information: dict[str, Any] | list | str | None = None,
    failure_reason: str | None = None,
    organization_id: str | None = None,
) -> Task:
    """Set a task's status (plus optional result fields) and return the refreshed Task.

    :raises NotFoundError: when the task does not exist.
    """
    try:
        with self.Session() as session:
            task = (
                session.query(TaskModel)
                .filter_by(task_id=task_id, organization_id=organization_id)
                .first()
            )
            if task is None:
                raise NotFoundError("Task not found")
            task.status = status
            if extracted_information is not None:
                task.extracted_information = extracted_information
            if failure_reason is not None:
                task.failure_reason = failure_reason
            session.commit()
            updated_task = await self.get_task(task_id, organization_id=organization_id)
            if not updated_task:
                raise NotFoundError("Task not found")
            return updated_task
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except NotFoundError:
        LOG.error("NotFoundError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_tasks(self, page: int = 1, page_size: int = 10, organization_id: str | None = None) -> list[Task]:
    """Return one page of tasks, newest first.

    :param page: 1-based page number
    :param page_size: number of tasks per page
    :raises ValueError: when ``page`` is less than 1
    """
    if page < 1:
        raise ValueError(f"Page must be greater than 0, got {page}")
    offset = (page - 1) * page_size  # SQL OFFSET is 0-based
    try:
        with self.Session() as session:
            tasks = (
                session.query(TaskModel)
                .filter_by(organization_id=organization_id)
                .order_by(TaskModel.created_at.desc())
                .limit(page_size)
                .offset(offset)
                .all()
            )
            return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_organization(self, organization_id: str) -> Organization | None:
    """Fetch an organization by id, or None when it does not exist."""
    try:
        with self.Session() as session:
            organization = (
                session.query(OrganizationModel).filter_by(organization_id=organization_id).first()
            )
            return convert_to_organization(organization) if organization else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def create_organization(
    self,
    organization_name: str,
    webhook_callback_url: str | None = None,
    max_steps_per_run: int | None = None,
) -> Organization:
    """Insert a new organization row and return it as an Organization."""
    with self.Session() as session:
        organization = OrganizationModel(
            organization_name=organization_name,
            webhook_callback_url=webhook_callback_url,
            max_steps_per_run=max_steps_per_run,
        )
        session.add(organization)
        session.commit()
        session.refresh(organization)
        return convert_to_organization(organization)
|
||||
|
||||
async def get_valid_org_auth_token(
    self,
    organization_id: str,
    token_type: OrganizationAuthTokenType,
) -> OrganizationAuthToken | None:
    """Return a still-valid auth token of the given type for an organization, or None."""
    try:
        with self.Session() as session:
            token = (
                session.query(OrganizationAuthTokenModel)
                .filter_by(organization_id=organization_id, token_type=token_type, valid=True)
                .first()
            )
            return convert_to_organization_auth_token(token) if token else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def validate_org_auth_token(
    self,
    organization_id: str,
    token_type: OrganizationAuthTokenType,
    token: str,
) -> OrganizationAuthToken | None:
    """Return the matching valid auth token row for ``token``, or None when invalid."""
    try:
        with self.Session() as session:
            token_obj = (
                session.query(OrganizationAuthTokenModel)
                .filter_by(
                    organization_id=organization_id,
                    token_type=token_type,
                    token=token,
                    valid=True,
                )
                .first()
            )
            return convert_to_organization_auth_token(token_obj) if token_obj else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def create_org_auth_token(
    self,
    organization_id: str,
    token_type: OrganizationAuthTokenType,
    token: str,
) -> OrganizationAuthToken:
    """Persist a new auth token for an organization and return it.

    :param organization_id: owner of the token
    :param token_type: kind of token being issued
    :param token: the token string to store
    """
    with self.Session() as session:
        # Fix: the original rebound the `token: str` parameter to the ORM model
        # instance, shadowing its own input and violating the annotation.
        auth_token = OrganizationAuthTokenModel(
            organization_id=organization_id,
            token_type=token_type,
            token=token,
        )
        session.add(auth_token)
        session.commit()
        session.refresh(auth_token)
        return convert_to_organization_auth_token(auth_token)
|
||||
|
||||
async def get_artifacts_for_task_step(
    self,
    task_id: str,
    step_id: str,
    organization_id: str | None = None,
) -> list[Artifact]:
    """Return every artifact recorded for a (task, step) pair; [] when none."""
    try:
        with self.Session() as session:
            artifacts = (
                session.query(ArtifactModel)
                .filter_by(task_id=task_id, step_id=step_id, organization_id=organization_id)
                .all()
            )
            return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_artifact_by_id(
    self,
    artifact_id: str,
    organization_id: str,
) -> Artifact | None:
    """Fetch an artifact by id within an organization, or None when missing."""
    try:
        with self.Session() as session:
            artifact = (
                session.query(ArtifactModel)
                .filter_by(artifact_id=artifact_id, organization_id=organization_id)
                .first()
            )
            return convert_to_artifact(artifact, self.debug_enabled) if artifact else None
    except SQLAlchemyError:
        LOG.exception("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.exception("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_artifact(
    self,
    task_id: str,
    step_id: str,
    artifact_type: ArtifactType,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the newest artifact of a given type for a (task, step), or None."""
    try:
        with self.Session() as session:
            artifact = (
                session.query(ArtifactModel)
                .filter_by(
                    task_id=task_id,
                    step_id=step_id,
                    organization_id=organization_id,
                    artifact_type=artifact_type,
                )
                .order_by(ArtifactModel.created_at.desc())
                .first()
            )
            return convert_to_artifact(artifact, self.debug_enabled) if artifact else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_artifact_for_workflow_run(
    self,
    workflow_run_id: str,
    artifact_type: ArtifactType,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the newest artifact of a given type across all tasks of a workflow run."""
    try:
        with self.Session() as session:
            artifact = (
                session.query(ArtifactModel)
                # Artifacts attach to tasks; join through tasks to reach the run.
                .join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
                .filter(TaskModel.workflow_run_id == workflow_run_id)
                .filter(ArtifactModel.artifact_type == artifact_type)
                .filter(ArtifactModel.organization_id == organization_id)
                .order_by(ArtifactModel.created_at.desc())
                .first()
            )
            return convert_to_artifact(artifact, self.debug_enabled) if artifact else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_latest_artifact(
    self,
    task_id: str,
    step_id: str | None = None,
    artifact_types: list[ArtifactType] | None = None,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the newest artifact of a task, optionally narrowed by step/types/org."""
    try:
        with self.Session() as session:
            query = session.query(ArtifactModel).filter_by(task_id=task_id)
            # Optional narrowing filters are applied only when provided.
            if step_id:
                query = query.filter_by(step_id=step_id)
            if organization_id:
                query = query.filter_by(organization_id=organization_id)
            if artifact_types:
                query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
            artifact = query.order_by(ArtifactModel.created_at.desc()).first()
            return convert_to_artifact(artifact, self.debug_enabled) if artifact else None
    except SQLAlchemyError:
        LOG.exception("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.exception("UnexpectedError", exc_info=True)
        raise
|
||||
|
||||
async def get_latest_task_by_workflow_id(
    self,
    organization_id: str,
    workflow_id: str,
    before: datetime | None = None,
) -> Task | None:
    """Return the most recently created task of a workflow, optionally created before ``before``."""
    try:
        with self.Session() as session:
            query = session.query(TaskModel).filter_by(
                organization_id=organization_id, workflow_id=workflow_id
            )
            if before:
                query = query.filter(TaskModel.created_at < before)
            task = query.order_by(TaskModel.created_at.desc()).first()
            return convert_to_task(task, debug_enabled=self.debug_enabled) if task else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def create_workflow(
    self,
    organization_id: str,
    title: str,
    workflow_definition: dict[str, Any],
    description: str | None = None,
) -> Workflow:
    """Insert a new workflow row and return it as a Workflow."""
    with self.Session() as session:
        workflow = WorkflowModel(
            organization_id=organization_id,
            title=title,
            description=description,
            workflow_definition=workflow_definition,
        )
        session.add(workflow)
        session.commit()
        session.refresh(workflow)
        return convert_to_workflow(workflow, self.debug_enabled)
|
||||
|
||||
async def get_workflow(self, workflow_id: str) -> Workflow | None:
    """Fetch a workflow by id, or None when it does not exist."""
    try:
        with self.Session() as session:
            workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
            return convert_to_workflow(workflow, self.debug_enabled) if workflow else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def update_workflow(
    self,
    workflow_id: str,
    title: str | None = None,
    description: str | None = None,
    workflow_definition: dict[str, Any] | None = None,
) -> Workflow | None:
    """Update any provided (truthy) fields of a workflow; returns None when missing."""
    with self.Session() as session:
        workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
        if workflow is None:
            LOG.error("Workflow not found, nothing to update", workflow_id=workflow_id)
            return None
        if title:
            workflow.title = title
        if description:
            workflow.description = description
        if workflow_definition:
            workflow.workflow_definition = workflow_definition
        session.commit()
        session.refresh(workflow)
        return convert_to_workflow(workflow, self.debug_enabled)
|
||||
|
||||
async def create_workflow_run(
    self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
) -> WorkflowRun:
    """Insert a new workflow run with status "created" and return it."""
    try:
        with self.Session() as session:
            run = WorkflowRunModel(
                workflow_id=workflow_id,
                proxy_location=proxy_location,
                status="created",
                webhook_callback_url=webhook_callback_url,
            )
            session.add(run)
            session.commit()
            session.refresh(run)
            return convert_to_workflow_run(run)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def update_workflow_run(self, workflow_run_id: str, status: WorkflowRunStatus) -> WorkflowRun | None:
    """Set a workflow run's status; returns None when the run does not exist."""
    with self.Session() as session:
        workflow_run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
        if workflow_run is None:
            LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
            return None
        workflow_run.status = status
        session.commit()
        session.refresh(workflow_run)
        return convert_to_workflow_run(workflow_run)
|
||||
|
||||
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
    """Fetch a workflow run by id, or None when it does not exist."""
    try:
        with self.Session() as session:
            workflow_run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
            return convert_to_workflow_run(workflow_run) if workflow_run else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
    """Return every run of a workflow; [] when there are none."""
    try:
        with self.Session() as session:
            rows = session.query(WorkflowRunModel).filter_by(workflow_id=workflow_id).all()
            return [convert_to_workflow_run(row) for row in rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def create_workflow_parameter(
    self,
    workflow_id: str,
    workflow_parameter_type: WorkflowParameterType,
    key: str,
    default_value: Any,
    description: str | None = None,
) -> WorkflowParameter:
    """Insert a workflow parameter definition and return it."""
    try:
        with self.Session() as session:
            parameter = WorkflowParameterModel(
                workflow_id=workflow_id,
                workflow_parameter_type=workflow_parameter_type,
                key=key,
                default_value=default_value,
                description=description,
            )
            session.add(parameter)
            session.commit()
            session.refresh(parameter)
            return convert_to_workflow_parameter(parameter, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def create_aws_secret_parameter(
    self,
    workflow_id: str,
    key: str,
    aws_key: str,
    description: str | None = None,
) -> AWSSecretParameter:
    """Insert an AWS secret parameter definition for a workflow and return it."""
    with self.Session() as session:
        parameter = AWSSecretParameterModel(
            workflow_id=workflow_id,
            key=key,
            aws_key=aws_key,
            description=description,
        )
        session.add(parameter)
        session.commit()
        session.refresh(parameter)
        return convert_to_aws_secret_parameter(parameter)
|
||||
|
||||
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
    """Return every parameter definition of a workflow; [] when there are none.

    NOTE(review): unlike get_workflow_parameter, this does not forward
    self.debug_enabled to the converter — confirm whether that is intended.
    """
    try:
        with self.Session() as session:
            rows = session.query(WorkflowParameterModel).filter_by(workflow_id=workflow_id).all()
            return [convert_to_workflow_parameter(row) for row in rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
    """Fetch a workflow parameter definition by id, or None when missing."""
    try:
        with self.Session() as session:
            parameter = (
                session.query(WorkflowParameterModel)
                .filter_by(workflow_parameter_id=workflow_parameter_id)
                .first()
            )
            return convert_to_workflow_parameter(parameter, self.debug_enabled) if parameter else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def create_workflow_run_parameter(
    self, workflow_run_id: str, workflow_parameter_id: str, value: Any
) -> WorkflowRunParameter:
    """Record the concrete value of a workflow parameter for one run and return it.

    :raises WorkflowParameterNotFound: when the referenced parameter definition is missing.
    """
    try:
        with self.Session() as session:
            run_parameter = WorkflowRunParameterModel(
                workflow_run_id=workflow_run_id,
                workflow_parameter_id=workflow_parameter_id,
                value=value,
            )
            session.add(run_parameter)
            session.commit()
            session.refresh(run_parameter)
            definition = await self.get_workflow_parameter(workflow_parameter_id)
            if not definition:
                raise WorkflowParameterNotFound(workflow_parameter_id)
            return convert_to_workflow_run_parameter(run_parameter, definition, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_workflow_run_parameters(
    self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
    """Return (definition, run value) pairs for every parameter of a workflow run.

    Raises WorkflowParameterNotFound if a run parameter references a missing
    parameter definition.
    """
    try:
        with self.Session() as session:
            rows = session.query(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id).all()
            pairs: list[tuple[WorkflowParameter, WorkflowRunParameter]] = []
            for row in rows:
                definition = await self.get_workflow_parameter(row.workflow_parameter_id)
                if not definition:
                    raise WorkflowParameterNotFound(workflow_parameter_id=row.workflow_parameter_id)
                converted = convert_to_workflow_run_parameter(row, definition, self.debug_enabled)
                pairs.append((definition, converted))
            return pairs
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
    """Most recently created task of the workflow run, or None if it has no tasks."""
    try:
        with self.Session() as session:
            newest = (
                session.query(TaskModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .order_by(TaskModel.created_at.desc())
                .first()
            )
            if newest is None:
                return None
            return convert_to_task(newest, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
    """All tasks belonging to a workflow run, ordered oldest first."""
    try:
        with self.Session() as session:
            query = (
                session.query(TaskModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .order_by(TaskModel.created_at)
            )
            return [convert_to_task(row, debug_enabled=self.debug_enabled) for row in query.all()]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
|
||||
|
||||
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
    """Delete every artifact row belonging to the given organization's task."""
    with self.Session() as session:
        session.execute(
            delete(ArtifactModel).where(
                and_(
                    ArtifactModel.organization_id == organization_id,
                    ArtifactModel.task_id == task_id,
                )
            )
        )
        session.commit()
|
||||
|
||||
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
    """Delete every step row belonging to the given organization's task."""
    with self.Session() as session:
        session.execute(
            delete(StepModel).where(
                and_(
                    StepModel.organization_id == organization_id,
                    StepModel.task_id == task_id,
                )
            )
        )
        session.commit()
|
||||
15
skyvern/forge/sdk/db/enums.py
Normal file
15
skyvern/forge/sdk/db/enums.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class OrganizationAuthTokenType(StrEnum):
    """Kinds of auth tokens an organization can hold; values are persisted in the DB."""

    api = "api"
|
||||
|
||||
|
||||
class ScheduleRuleUnit(StrEnum):
    """Time units available for schedule rules."""

    # No support for scheduling every second
    minute = "minute"
    hour = "hour"
    day = "day"
    week = "week"
    month = "month"
    year = "year"
|
||||
2
skyvern/forge/sdk/db/exceptions.py
Normal file
2
skyvern/forge/sdk/db/exceptions.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class NotFoundError(Exception):
    """Raised when a requested database record does not exist."""

    pass
|
||||
136
skyvern/forge/sdk/db/id.py
Normal file
136
skyvern/forge/sdk/db/id.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import hashlib
|
||||
import itertools
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import time
|
||||
|
||||
# 6/20/2022 12AM
|
||||
BASE_EPOCH = 1655683200
|
||||
VERSION = 0
|
||||
|
||||
# Number of bits
|
||||
TIMESTAMP_BITS = 32
|
||||
WORKER_ID_BITS = 21
|
||||
SEQUENCE_BITS = 10
|
||||
VERSION_BITS = 1
|
||||
|
||||
# Bit shits (left)
|
||||
TIMESTAMP_SHIFT = 32
|
||||
WORKER_ID_SHIFT = 11
|
||||
SEQUENCE_SHIFT = 1
|
||||
VERSION_SHIFT = 0
|
||||
|
||||
SEQUENCE_MAX = (2**SEQUENCE_BITS) - 1
|
||||
_sequence_start = None
|
||||
SEQUENCE_COUNTER = itertools.count()
|
||||
_worker_hash = None
|
||||
|
||||
# prefix
|
||||
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
|
||||
ORG_PREFIX = "o"
|
||||
TASK_PREFIX = "tsk"
|
||||
USER_PREFIX = "u"
|
||||
STEP_PREFIX = "stp"
|
||||
ARTIFACT_PREFIX = "a"
|
||||
WORKFLOW_PREFIX = "w"
|
||||
WORKFLOW_RUN_PREFIX = "wr"
|
||||
WORKFLOW_PARAMETER_PREFIX = "wp"
|
||||
AWS_SECRET_PARAMETER_PREFIX = "asp"
|
||||
|
||||
|
||||
def generate_workflow_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{WORKFLOW_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_workflow_run_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{WORKFLOW_RUN_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_aws_secret_parameter_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{AWS_SECRET_PARAMETER_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_workflow_parameter_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{WORKFLOW_PARAMETER_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_organization_auth_token_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_org_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{ORG_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_task_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{TASK_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_step_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{STEP_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_artifact_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{ARTIFACT_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_user_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{USER_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_id() -> int:
|
||||
"""
|
||||
generate a 64-bit int ID
|
||||
"""
|
||||
create_at = current_time() - BASE_EPOCH
|
||||
sequence = _increment_and_get_sequence()
|
||||
|
||||
time_part = _mask_shift(create_at, TIMESTAMP_BITS, TIMESTAMP_SHIFT)
|
||||
worker_part = _mask_shift(_get_worker_hash(), WORKER_ID_BITS, WORKER_ID_SHIFT)
|
||||
sequence_part = _mask_shift(sequence, SEQUENCE_BITS, SEQUENCE_SHIFT)
|
||||
version_part = _mask_shift(VERSION, VERSION_BITS, VERSION_SHIFT)
|
||||
|
||||
return time_part | worker_part | sequence_part | version_part
|
||||
|
||||
|
||||
def _increment_and_get_sequence() -> int:
|
||||
global _sequence_start
|
||||
if _sequence_start is None:
|
||||
_sequence_start = random.randint(0, SEQUENCE_MAX)
|
||||
|
||||
return (_sequence_start + next(SEQUENCE_COUNTER)) % SEQUENCE_MAX
|
||||
|
||||
|
||||
def current_time() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def current_time_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def _mask_shift(value: int, mask_bits: int, shift_bits: int) -> int:
|
||||
return (value & ((2**mask_bits) - 1)) << shift_bits
|
||||
|
||||
|
||||
def _get_worker_hash() -> int:
|
||||
global _worker_hash
|
||||
if _worker_hash is None:
|
||||
_worker_hash = _generate_worker_hash()
|
||||
return _worker_hash
|
||||
|
||||
|
||||
def _generate_worker_hash() -> int:
|
||||
worker_identity = f"{platform.node()}:{os.getpid()}"
|
||||
return int(hashlib.md5(worker_identity.encode()).hexdigest()[-15:], 16)
|
||||
172
skyvern/forge/sdk/db/models.py
Normal file
172
skyvern/forge/sdk/db/models.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, UnicodeText
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.id import (
|
||||
generate_artifact_id,
|
||||
generate_aws_secret_parameter_id,
|
||||
generate_org_id,
|
||||
generate_organization_auth_token_id,
|
||||
generate_step_id,
|
||||
generate_task_id,
|
||||
generate_workflow_id,
|
||||
generate_workflow_parameter_id,
|
||||
generate_workflow_run_id,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Common declarative base for all ORM models in this module."""

    pass
|
||||
|
||||
|
||||
class TaskModel(Base):
    """ORM row for a browser-automation task."""

    __tablename__ = "tasks"

    task_id = Column(String, primary_key=True, index=True, default=generate_task_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    status = Column(String)
    webhook_callback_url = Column(String)
    url = Column(String)
    navigation_goal = Column(String)
    data_extraction_goal = Column(String)
    navigation_payload = Column(JSON)
    extracted_information = Column(JSON)
    failure_reason = Column(String)
    proxy_location = Column(Enum(ProxyLocation))
    extracted_information_schema = Column(JSON)
    # Set when the task was spawned by a workflow run; NULL for standalone tasks.
    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
    # NOTE(review): order/retry appear to track the task's position and retry
    # count within its workflow run — confirm against the workflow service.
    order = Column(Integer, nullable=True)
    retry = Column(Integer, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class StepModel(Base):
    """ORM row for a single agent step within a task."""

    __tablename__ = "steps"

    step_id = Column(String, primary_key=True, index=True, default=generate_step_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    task_id = Column(String, ForeignKey("tasks.task_id"))
    status = Column(String)
    output = Column(JSON)
    # Position of this step within the task; retry_index counts re-attempts of it.
    order = Column(Integer)
    is_last = Column(Boolean, default=False)
    retry_index = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # LLM accounting for this step.
    input_token_count = Column(Integer, default=0)
    output_token_count = Column(Integer, default=0)
    step_cost = Column(Numeric, default=0)
|
||||
|
||||
|
||||
class OrganizationModel(Base):
    """ORM row for an organization (tenant) that owns tasks, steps and workflows."""

    __tablename__ = "organizations"

    organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
    organization_name = Column(String, nullable=False)
    webhook_callback_url = Column(UnicodeText)
    max_steps_per_run = Column(Integer)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # Bug fix: onupdate previously passed ``datetime.datetime`` (the class itself)
    # instead of the ``utcnow`` callable, unlike every other model in this module.
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class OrganizationAuthTokenModel(Base):
    """ORM row for an auth token issued to an organization (soft-deletable)."""

    __tablename__ = "organization_auth_tokens"

    id = Column(
        String,
        primary_key=True,
        index=True,
        default=generate_organization_auth_token_id,
    )

    organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
    token_type = Column(Enum(OrganizationAuthTokenType), nullable=False)
    token = Column(String, index=True, nullable=False)
    # Tokens are revoked by flipping ``valid`` rather than deleting the row.
    valid = Column(Boolean, nullable=False, default=True)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # Bug fix: onupdate previously passed ``datetime.datetime`` (the class itself)
    # instead of the ``utcnow`` callable, unlike every other model in this module.
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class ArtifactModel(Base):
    """ORM row for an artifact (screenshot, log, etc.) produced by a step."""

    __tablename__ = "artifacts"

    artifact_id = Column(String, primary_key=True, index=True, default=generate_artifact_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    task_id = Column(String, ForeignKey("tasks.task_id"))
    step_id = Column(String, ForeignKey("steps.step_id"))
    artifact_type = Column(String)
    # Location of the artifact payload (converted via ArtifactType in db/utils.py).
    uri = Column(String)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class WorkflowModel(Base):
    """ORM row for a workflow definition (soft-deletable via deleted_at)."""

    __tablename__ = "workflows"

    workflow_id = Column(String, primary_key=True, index=True, default=generate_workflow_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    title = Column(String, nullable=False)
    description = Column(String, nullable=True)
    # Serialized WorkflowDefinition (validated on read in db/utils.py).
    workflow_definition = Column(JSON, nullable=False)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class WorkflowRunModel(Base):
    """ORM row for one execution of a workflow."""

    __tablename__ = "workflow_runs"

    workflow_run_id = Column(String, primary_key=True, index=True, default=generate_workflow_run_id)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=False)
    status = Column(String, nullable=False)
    proxy_location = Column(Enum(ProxyLocation))
    webhook_callback_url = Column(String)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class WorkflowParameterModel(Base):
    """ORM row for a parameter declared by a workflow (soft-deletable)."""

    __tablename__ = "workflow_parameters"

    workflow_parameter_id = Column(String, primary_key=True, index=True, default=generate_workflow_parameter_id)
    # Stored as a string; mapped to WorkflowParameterType (by upper-cased name) on read.
    workflow_parameter_type = Column(String, nullable=False)
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    # Stored as a string; converted back to the declared type on read.
    default_value = Column(String, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class AWSSecretParameterModel(Base):
    """ORM row for a workflow parameter whose value lives in AWS Secrets Manager."""

    __tablename__ = "aws_secret_parameters"

    aws_secret_parameter_id = Column(String, primary_key=True, index=True, default=generate_aws_secret_parameter_id)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    # Reference to the secret in AWS — only the key is stored, never the value.
    aws_key = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class WorkflowRunParameterModel(Base):
    """ORM row binding a concrete parameter value to a workflow run (composite PK)."""

    __tablename__ = "workflow_run_parameters"

    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
    workflow_parameter_id = Column(
        String, ForeignKey("workflow_parameters.workflow_parameter_id"), primary_key=True, index=True
    )
    # Can be bool | int | float | str | dict | list depending on the workflow parameter type
    value = Column(String, nullable=False)
    # No modified_at: rows are insert-only (see create_workflow_run_parameter).
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
220
skyvern/forge/sdk/db/utils.py
Normal file
220
skyvern/forge/sdk/db/utils.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import json
|
||||
import typing
|
||||
|
||||
import pydantic.json
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ArtifactModel,
|
||||
AWSSecretParameterModel,
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDefinition,
|
||||
WorkflowRun,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@typing.no_type_check
def _custom_json_serializer(*args, **kwargs) -> str:
    """
    Encodes json in the same way that pydantic does.

    Delegates non-JSON-native types (datetimes, enums, pydantic models, ...) to
    pydantic's encoder so DB-bound JSON matches the API's serialization.
    """
    return json.dumps(*args, default=pydantic.json.pydantic_encoder, **kwargs)
|
||||
|
||||
|
||||
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
    """Convert a TaskModel row into a Task domain object (1:1 field mapping)."""
    if debug_enabled:
        LOG.debug("Converting TaskModel to Task", task_id=task_obj.task_id)
    task = Task(
        task_id=task_obj.task_id,
        status=TaskStatus(task_obj.status),
        created_at=task_obj.created_at,
        modified_at=task_obj.modified_at,
        url=task_obj.url,
        webhook_callback_url=task_obj.webhook_callback_url,
        navigation_goal=task_obj.navigation_goal,
        data_extraction_goal=task_obj.data_extraction_goal,
        navigation_payload=task_obj.navigation_payload,
        extracted_information=task_obj.extracted_information,
        failure_reason=task_obj.failure_reason,
        organization_id=task_obj.organization_id,
        proxy_location=ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None,
        extracted_information_schema=task_obj.extracted_information_schema,
        workflow_run_id=task_obj.workflow_run_id,
        order=task_obj.order,
        retry=task_obj.retry,
    )
    return task
|
||||
|
||||
|
||||
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
    """Convert a StepModel row into a Step domain object (1:1 field mapping)."""
    if debug_enabled:
        LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)
    return Step(
        task_id=step_model.task_id,
        step_id=step_model.step_id,
        created_at=step_model.created_at,
        modified_at=step_model.modified_at,
        status=StepStatus(step_model.status),
        output=step_model.output,
        order=step_model.order,
        is_last=step_model.is_last,
        retry_index=step_model.retry_index,
        organization_id=step_model.organization_id,
        input_token_count=step_model.input_token_count,
        output_token_count=step_model.output_token_count,
        step_cost=step_model.step_cost,
    )
|
||||
|
||||
|
||||
def convert_to_organization(org_model: OrganizationModel) -> Organization:
    """Convert an OrganizationModel row into an Organization domain object."""
    return Organization(
        organization_id=org_model.organization_id,
        organization_name=org_model.organization_name,
        webhook_callback_url=org_model.webhook_callback_url,
        max_steps_per_run=org_model.max_steps_per_run,
        created_at=org_model.created_at,
        modified_at=org_model.modified_at,
    )
|
||||
|
||||
|
||||
def convert_to_organization_auth_token(org_auth_token: OrganizationAuthTokenModel) -> OrganizationAuthToken:
    """Convert an OrganizationAuthTokenModel row into an OrganizationAuthToken."""
    return OrganizationAuthToken(
        id=org_auth_token.id,
        organization_id=org_auth_token.organization_id,
        token_type=OrganizationAuthTokenType(org_auth_token.token_type),
        token=org_auth_token.token,
        valid=org_auth_token.valid,
        created_at=org_auth_token.created_at,
        modified_at=org_auth_token.modified_at,
    )
|
||||
|
||||
|
||||
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
    """Convert an ArtifactModel row into an Artifact domain object."""
    if debug_enabled:
        LOG.debug("Converting ArtifactModel to Artifact", artifact_id=artifact_model.artifact_id)

    return Artifact(
        artifact_id=artifact_model.artifact_id,
        # DB stores the type as a string; the enum member is looked up by
        # upper-cased NAME — assumes stored values match member names case-insensitively.
        artifact_type=ArtifactType[artifact_model.artifact_type.upper()],
        uri=artifact_model.uri,
        task_id=artifact_model.task_id,
        step_id=artifact_model.step_id,
        created_at=artifact_model.created_at,
        modified_at=artifact_model.modified_at,
        organization_id=artifact_model.organization_id,
    )
|
||||
|
||||
|
||||
def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = False) -> Workflow:
    """Convert a WorkflowModel row into a Workflow, validating the stored definition."""
    if debug_enabled:
        LOG.debug("Converting WorkflowModel to Workflow", workflow_id=workflow_model.workflow_id)

    return Workflow(
        workflow_id=workflow_model.workflow_id,
        organization_id=workflow_model.organization_id,
        title=workflow_model.title,
        description=workflow_model.description,
        # Re-validate the stored JSON against the pydantic schema on every read.
        workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
        created_at=workflow_model.created_at,
        modified_at=workflow_model.modified_at,
        deleted_at=workflow_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: bool = False) -> WorkflowRun:
    """Convert a WorkflowRunModel row into a WorkflowRun domain object."""
    if debug_enabled:
        LOG.debug("Converting WorkflowRunModel to WorkflowRun", workflow_run_id=workflow_run_model.workflow_run_id)

    return WorkflowRun(
        workflow_run_id=workflow_run_model.workflow_run_id,
        workflow_id=workflow_run_model.workflow_id,
        # NOTE(review): status is looked up by enum *name* here, whereas TaskStatus
        # above is constructed from the value — only equivalent if names == values.
        status=WorkflowRunStatus[workflow_run_model.status],
        proxy_location=ProxyLocation(workflow_run_model.proxy_location) if workflow_run_model.proxy_location else None,
        webhook_callback_url=workflow_run_model.webhook_callback_url,
        created_at=workflow_run_model.created_at,
        modified_at=workflow_run_model.modified_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_parameter(
    workflow_parameter_model: WorkflowParameterModel, debug_enabled: bool = False
) -> WorkflowParameter:
    """Convert a WorkflowParameterModel row into a WorkflowParameter domain object."""
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowParameterModel to WorkflowParameter",
            workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
        )

    # DB stores the type as a string; enum member is looked up by upper-cased name.
    workflow_parameter_type = WorkflowParameterType[workflow_parameter_model.workflow_parameter_type.upper()]

    return WorkflowParameter(
        workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
        workflow_parameter_type=workflow_parameter_type,
        workflow_id=workflow_parameter_model.workflow_id,
        # default_value is stored as a string; convert it back to the declared type.
        default_value=workflow_parameter_type.convert_value(workflow_parameter_model.default_value),
        key=workflow_parameter_model.key,
        description=workflow_parameter_model.description,
        created_at=workflow_parameter_model.created_at,
        modified_at=workflow_parameter_model.modified_at,
        deleted_at=workflow_parameter_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_aws_secret_parameter(
    aws_secret_parameter_model: AWSSecretParameterModel, debug_enabled: bool = False
) -> AWSSecretParameter:
    """Convert an AWSSecretParameterModel row into an AWSSecretParameter domain object."""
    if debug_enabled:
        LOG.debug(
            "Converting AWSSecretParameterModel to AWSSecretParameter",
            # Bug fix: the model has no ``id`` attribute — its primary key is
            # ``aws_secret_parameter_id`` — so ``.id`` raised AttributeError here.
            aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
        )

    return AWSSecretParameter(
        aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
        workflow_id=aws_secret_parameter_model.workflow_id,
        key=aws_secret_parameter_model.key,
        description=aws_secret_parameter_model.description,
        aws_key=aws_secret_parameter_model.aws_key,
        created_at=aws_secret_parameter_model.created_at,
        modified_at=aws_secret_parameter_model.modified_at,
        deleted_at=aws_secret_parameter_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_run_parameter(
    workflow_run_parameter_model: WorkflowRunParameterModel,
    workflow_parameter: WorkflowParameter,
    debug_enabled: bool = False,
) -> WorkflowRunParameter:
    """Convert a WorkflowRunParameterModel row into a WorkflowRunParameter.

    The parameter definition is required so the string-stored value can be
    converted back to its declared type.
    """
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowRunParameterModel to WorkflowRunParameter",
            workflow_run_id=workflow_run_parameter_model.workflow_run_id,
            workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
        )

    return WorkflowRunParameter(
        workflow_run_id=workflow_run_parameter_model.workflow_run_id,
        workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
        value=workflow_parameter.workflow_parameter_type.convert_value(workflow_run_parameter_model.value),
        created_at=workflow_run_parameter_model.created_at,
    )
|
||||
0
skyvern/forge/sdk/executor/__init__.py
Normal file
0
skyvern/forge/sdk/executor/__init__.py
Normal file
85
skyvern/forge/sdk/executor/async_executor.py
Normal file
85
skyvern/forge/sdk/executor/async_executor.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import abc
|
||||
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.models import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
|
||||
|
||||
class AsyncExecutor(abc.ABC):
    """Interface for scheduling task and workflow execution asynchronously."""

    @abc.abstractmethod
    async def execute_task(
        self,
        background_tasks: BackgroundTasks,
        task: Task,
        organization: Organization,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """Schedule execution of a single task on behalf of *organization*."""
        pass

    @abc.abstractmethod
    async def execute_workflow(
        self,
        background_tasks: BackgroundTasks,
        organization: Organization,
        workflow_id: str,
        workflow_run_id: str,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """Schedule execution of a workflow run on behalf of *organization*."""
        pass
|
||||
|
||||
|
||||
class BackgroundTaskExecutor(AsyncExecutor):
    """Executor that schedules work on FastAPI's in-process BackgroundTasks."""

    async def execute_task(
        self,
        background_tasks: BackgroundTasks,
        task: Task,
        organization: Organization,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """Create the first step, mark the task running, and queue agent execution."""
        # First step of the task: order 0, no retries yet.
        step = await app.DATABASE.create_step(
            task.task_id,
            order=0,
            retry_index=0,
            organization_id=organization.organization_id,
        )

        task = await app.DATABASE.update_task(
            task.task_id,
            TaskStatus.running,
            organization_id=organization.organization_id,
        )

        # Propagate identifiers into the request-scoped context so logs are tagged.
        context: SkyvernContext = skyvern_context.ensure_context()
        context.task_id = task.task_id
        context.organization_id = organization.organization_id
        context.max_steps_override = max_steps_override

        # Run the agent step after the HTTP response has been sent.
        background_tasks.add_task(
            app.agent.execute_step,
            organization,
            task,
            step,
            api_key,
        )

    async def execute_workflow(
        self,
        background_tasks: BackgroundTasks,
        organization: Organization,
        workflow_id: str,
        workflow_run_id: str,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """Queue a workflow run in the background.

        NOTE(review): workflow_id and max_steps_override are accepted but not
        forwarded to WORKFLOW_SERVICE.execute_workflow — confirm that is intended.
        """
        background_tasks.add_task(
            app.WORKFLOW_SERVICE.execute_workflow,
            workflow_run_id=workflow_run_id,
            api_key=api_key,
        )
|
||||
13
skyvern/forge/sdk/executor/factory.py
Normal file
13
skyvern/forge/sdk/executor/factory.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from skyvern.forge.sdk.executor.async_executor import AsyncExecutor, BackgroundTaskExecutor
|
||||
|
||||
|
||||
class AsyncExecutorFactory:
    """Process-wide holder for the active AsyncExecutor implementation.

    Defaults to the in-process BackgroundTaskExecutor; deployments can swap in a
    different executor via set_executor().
    """

    _executor: AsyncExecutor = BackgroundTaskExecutor()

    @staticmethod
    def set_executor(executor: AsyncExecutor) -> None:
        """Replace the process-wide executor."""
        AsyncExecutorFactory._executor = executor

    @staticmethod
    def get_executor() -> AsyncExecutor:
        """Return the currently configured executor."""
        return AsyncExecutorFactory._executor
|
||||
90
skyvern/forge/sdk/forge_log.py
Normal file
90
skyvern/forge/sdk/forge_log.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import logging
|
||||
|
||||
import structlog
|
||||
from structlog.typing import EventDict
|
||||
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
|
||||
def add_kv_pairs_to_msg(logger: logging.Logger, method_name: str, event_dict: EventDict) -> EventDict:
    """
    A custom processor to add key-value pairs to the 'msg' field.

    Also stamps request/organization/task/workflow identifiers from the current
    Skyvern context and the deployment env onto every event.
    """
    # Add context to the log
    context = skyvern_context.current()
    if context:
        if context.request_id:
            event_dict["request_id"] = context.request_id
        if context.organization_id:
            event_dict["organization_id"] = context.organization_id
        if context.task_id:
            event_dict["task_id"] = context.task_id
        if context.workflow_id:
            event_dict["workflow_id"] = context.workflow_id
        if context.workflow_run_id:
            event_dict["workflow_run_id"] = context.workflow_run_id

    # Add env to the log
    event_dict["env"] = SettingsManager.get_settings().ENV

    if method_name not in ["info", "warning", "error", "critical", "exception"]:
        # Only modify the log for these log levels
        return event_dict

    # Assuming 'event' or 'msg' is the field to update
    msg_field = event_dict.get("msg", "")

    # Add key-value pairs
    kv_pairs = {k: v for k, v in event_dict.items() if k not in ["msg", "timestamp", "level"]}
    if kv_pairs:
        additional_info = ", ".join(f"{k}={v}" for k, v in kv_pairs.items())
        msg_field += f" | {additional_info}"

    event_dict["msg"] = msg_field

    return event_dict
|
||||
|
||||
|
||||
def setup_logger() -> None:
    """
    Setup the logger with the specified format

    When JSON_LOGGING is enabled, events are rendered as JSON with callsite
    metadata and kv-pairs folded into the msg; otherwise a dev console renderer
    is used. Uvicorn's own loggers are disabled either way.
    """
    # logging.config.dictConfig(logging_config)
    renderer = (
        structlog.processors.JSONRenderer()
        if SettingsManager.get_settings().JSON_LOGGING
        else structlog.dev.ConsoleRenderer()
    )
    additional_processors = (
        [
            structlog.processors.EventRenamer("msg"),
            add_kv_pairs_to_msg,
            structlog.processors.CallsiteParameterAdder(
                {
                    structlog.processors.CallsiteParameter.PATHNAME,
                    structlog.processors.CallsiteParameter.FILENAME,
                    structlog.processors.CallsiteParameter.MODULE,
                    structlog.processors.CallsiteParameter.FUNC_NAME,
                    structlog.processors.CallsiteParameter.LINENO,
                }
            ),
        ]
        if SettingsManager.get_settings().JSON_LOGGING
        else []
    )

    structlog.configure(
        processors=[
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
            # structlog.processors.dict_tracebacks,
            structlog.processors.format_exc_info,
        ]
        + additional_processors
        + [renderer]
    )
    # Silence uvicorn's own loggers; structlog handles request logging instead.
    uvicorn_error = logging.getLogger("uvicorn.error")
    uvicorn_error.disabled = True
    uvicorn_access = logging.getLogger("uvicorn.access")
    uvicorn_access.disabled = True
|
||||
137
skyvern/forge/sdk/models.py
Normal file
137
skyvern/forge/sdk/models.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
from skyvern.webeye.actions.models import AgentStepOutput
|
||||
|
||||
|
||||
class StepStatus(StrEnum):
|
||||
created = "created"
|
||||
running = "running"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
|
||||
def can_update_to(self, new_status: StepStatus) -> bool:
|
||||
allowed_transitions: dict[StepStatus, set[StepStatus]] = {
|
||||
StepStatus.created: {StepStatus.running},
|
||||
StepStatus.running: {StepStatus.completed, StepStatus.failed},
|
||||
StepStatus.failed: set(),
|
||||
StepStatus.completed: set(),
|
||||
}
|
||||
return new_status in allowed_transitions[self]
|
||||
|
||||
def requires_output(self) -> bool:
|
||||
status_requires_output = {StepStatus.completed}
|
||||
return self in status_requires_output
|
||||
|
||||
def cant_have_output(self) -> bool:
|
||||
status_cant_have_output = {StepStatus.created, StepStatus.running}
|
||||
return self in status_cant_have_output
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
status_is_terminal = {StepStatus.failed, StepStatus.completed}
|
||||
return self in status_is_terminal
|
||||
|
||||
|
||||
class Step(BaseModel):
    """One execution step of a task, with its status, output, and cost accounting."""

    created_at: datetime
    modified_at: datetime
    task_id: str
    step_id: str
    status: StepStatus
    output: AgentStepOutput | None = None
    # Position of this step within the task's sequence of steps.
    order: int
    is_last: bool
    # Incremented each time this step order is retried.
    retry_index: int = 0
    organization_id: str | None = None
    # LLM token accounting and cost attribution for this step.
    input_token_count: int = 0
    output_token_count: int = 0
    step_cost: float = 0

    def validate_update(self, status: StepStatus | None, output: AgentStepOutput | None, is_last: bool | None) -> None:
        """
        Validate a proposed update to this step.

        Raises:
            ValueError: if the status transition or output combination is illegal.
        """
        old_status = self.status

        if status and not old_status.can_update_to(status):
            raise ValueError(f"invalid_status_transition({old_status},{status},{self.step_id})")

        if status and status.requires_output() and output is None:
            raise ValueError(f"status_requires_output({status},{self.step_id})")

        if status and status.cant_have_output() and output is not None:
            raise ValueError(f"status_cant_have_output({status},{self.step_id})")

        if output is not None and status is None:
            raise ValueError(f"cant_set_output_without_updating_status({self.step_id})")

        if self.output is not None and output is not None:
            raise ValueError(f"cant_override_output({self.step_id})")

        if is_last and not self.status.is_terminal():
            raise ValueError(f"is_last_but_status_not_terminal({self.status},{self.step_id})")

        if is_last is False:
            raise ValueError(f"cant_set_is_last_to_false({self.step_id})")

    def _has_successful_action(self, action_type: ActionType) -> bool:
        """Return True if this completed step contains a successful action of ``action_type``.

        Shared implementation for is_goal_achieved / is_terminated, which were
        previously copy-paste duplicates differing only in the action type.
        """
        if self.status != StepStatus.completed:
            return False
        # TODO (kerem): Remove this check once we have backfilled all the steps
        if self.output is None or self.output.actions_and_results is None:
            return False

        for action, action_results in self.output.actions_and_results:
            if action.action_type != action_type:
                continue

            if any(action_result.success for action_result in action_results):
                return True

        return False

    def is_goal_achieved(self) -> bool:
        """True if the step completed with a successful COMPLETE action."""
        return self._has_successful_action(ActionType.COMPLETE)

    def is_terminated(self) -> bool:
        """True if the step completed with a successful TERMINATE action."""
        return self._has_successful_action(ActionType.TERMINATE)
|
||||
|
||||
|
||||
class Organization(BaseModel):
    """An organization (tenant) that owns tasks and workflow runs."""

    organization_id: str
    organization_name: str
    # Default callback URL for this org — presumably used when a task
    # finishes and no per-task webhook is set; confirm against callers.
    webhook_callback_url: str | None = None
    # Optional per-organization cap on steps a single run may execute.
    max_steps_per_run: int | None = None

    created_at: datetime
    modified_at: datetime
|
||||
|
||||
|
||||
class OrganizationAuthToken(BaseModel):
    """An auth token issued to an organization (e.g. an API key)."""

    id: str
    organization_id: str
    token_type: OrganizationAuthTokenType
    token: str
    # NOTE(review): presumably False once revoked — checked via
    # db.validate_org_auth_token in org_auth_service; confirm.
    valid: bool
    created_at: datetime
    modified_at: datetime
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
    """Decoded JWT claims carried by an organization API key."""

    # JWT subject; used as the organization_id (see org_auth_service).
    sub: str
    # Expiry as a Unix timestamp in seconds (compared against time.time()).
    exp: int
|
||||
98
skyvern/forge/sdk/prompting.py
Normal file
98
skyvern/forge/sdk/prompting.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Relative to this file there is a prompts directory located at ../prompts
|
||||
In this directory there will be a techniques directory and a directory for each model - gpt-3.5-turbo gpt-4, llama-2-70B, code-llama-7B etc
|
||||
|
||||
Each directory will have Jinja2 templates for the prompts.
|
||||
prompts in the model directories can use the techniques in the techniques directory.
|
||||
|
||||
Write the code I'd need to load and populate the templates.
|
||||
|
||||
I want the following functions:
|
||||
|
||||
class PromptEngine:
|
||||
|
||||
def __init__(self, model):
|
||||
pass
|
||||
|
||||
    def load_prompt(model, prompt_name, prompt_args) -> str:
|
||||
pass
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from difflib import get_close_matches
|
||||
from typing import Any, List
|
||||
|
||||
import structlog
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class PromptEngine:
    """
    Loads and renders Jinja2 prompt templates for a given model.

    Templates live under ``../prompts/<model>/``; the engine resolves the
    model directory whose name most closely matches the requested model.
    """

    def __init__(self, model: str):
        """
        Args:
            model: The model whose prompt directory should be used.
        """
        self.model = model

        try:
            prompts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts"))
            # Every subdirectory except "techniques" is a candidate model directory.
            candidates = []
            for entry in glob.glob(os.path.join(prompts_dir, "*/")):
                if os.path.isdir(entry) and "techniques" not in entry:
                    candidates.append(os.path.basename(os.path.normpath(entry)))

            self.model = self.get_closest_match(self.model, candidates)

            self.env = Environment(loader=FileSystemLoader(prompts_dir))
        except Exception:
            LOG.error("Error initializing PromptEngine.", model=model, exc_info=True)
            raise

    @staticmethod
    def get_closest_match(target: str, model_dirs: List[str]) -> str:
        """
        Return the entry of ``model_dirs`` most similar to ``target``.

        Args:
            target: The requested model name.
            model_dirs: The available model directory names.

        Returns:
            The closest matching directory name.
        """
        try:
            # cutoff=0.1 is deliberately permissive so some directory is
            # (almost) always chosen; an empty result raises IndexError.
            best = get_close_matches(target, model_dirs, n=1, cutoff=0.1)
            return best[0]
        except Exception:
            LOG.error("Failed to get closest match.", target=target, model_dirs=model_dirs, exc_info=True)
            raise

    def load_prompt(self, template: str, **kwargs: Any) -> str:
        """
        Render the named template with ``kwargs`` and return the result.

        Args:
            template: The template name (without the ``.j2`` suffix).
            **kwargs: Values to populate the template with.

        Returns:
            The rendered template text.
        """
        try:
            template = os.path.join(self.model, template)
            tmpl = self.env.get_template(f"{template}.j2")
            return tmpl.render(**kwargs)
        except Exception:
            LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
            raise
|
||||
0
skyvern/forge/sdk/routes/__init__.py
Normal file
0
skyvern/forge/sdk/routes/__init__.py
Normal file
397
skyvern/forge/sdk/routes/agent_protocol.py
Normal file
397
skyvern/forge/sdk/routes/agent_protocol.py
Normal file
@@ -0,0 +1,397 @@
|
||||
from typing import Annotated, Any
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, status
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.exceptions import StepNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.models import Organization, Step
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.services import org_auth_service
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
RunWorkflowResponse,
|
||||
WorkflowRequestBody,
|
||||
WorkflowRunStatusResponse,
|
||||
)
|
||||
|
||||
base_router = APIRouter()
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@base_router.post("/webhook", tags=["server"])
async def webhook(
    request: Request,
    x_skyvern_signature: Annotated[str | None, Header()] = None,
    x_skyvern_timestamp: Annotated[str | None, Header()] = None,
) -> Response:
    """Verify an incoming webhook's signature headers and log the outcome."""
    payload = await request.body()

    # Both headers are required to check authenticity of the payload.
    if not x_skyvern_signature or not x_skyvern_timestamp:
        LOG.error(
            "Webhook signature or timestamp missing",
            x_skyvern_signature=x_skyvern_signature,
            x_skyvern_timestamp=x_skyvern_timestamp,
            payload=payload,
        )
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing webhook signature or timestamp")

    expected_signature = generate_skyvern_signature(
        payload.decode("utf-8"),
        SettingsManager.get_settings().SKYVERN_API_KEY,
    )

    LOG.info(
        "Webhook received",
        x_skyvern_signature=x_skyvern_signature,
        x_skyvern_timestamp=x_skyvern_timestamp,
        payload=payload,
        generated_signature=expected_signature,
        valid_signature=x_skyvern_signature == expected_signature,
    )
    return Response(content="webhook validation", status_code=200)
|
||||
|
||||
|
||||
@base_router.get("/heartbeat", tags=["server"])
async def check_server_status() -> Response:
    """Liveness probe: respond 200 while the server is up."""
    return Response(content="Server is running.", status_code=200)
|
||||
|
||||
|
||||
@base_router.post("/tasks", tags=["agent"], response_model=CreateTaskResponse)
async def create_agent_task(
    background_tasks: BackgroundTasks,
    request: Request,
    task: TaskRequest,
    current_org: Organization = Depends(org_auth_service.get_current_org),
    x_api_key: Annotated[str | None, Header()] = None,
    x_max_steps_override: Annotated[int | None, Header()] = None,
) -> CreateTaskResponse:
    """Create a task for the calling organization and kick off its execution."""
    # The Agent instance is injected into the request scope by middleware.
    agent = request["agent"]

    created_task = await agent.create_task(task, current_org.organization_id)
    if x_max_steps_override:
        LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)

    # Execution is dispatched asynchronously; the response returns immediately.
    await app.ASYNC_EXECUTOR.execute_task(
        background_tasks=background_tasks,
        task=created_task,
        organization=current_org,
        max_steps_override=x_max_steps_override,
        api_key=x_api_key,
    )
    return CreateTaskResponse(task_id=created_task.task_id)
|
||||
|
||||
|
||||
@base_router.post(
    "/tasks/{task_id}/steps/{step_id}",
    tags=["agent"],
    response_model=Step,
    summary="Executes a specific step",
)
@base_router.post(
    "/tasks/{task_id}/steps/",
    tags=["agent"],
    response_model=Step,
    summary="Executes the next step",
)
async def execute_agent_task_step(
    request: Request,
    task_id: str,
    step_id: str | None = None,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Execute one step of a task.

    When ``step_id`` is omitted, the task's latest step is executed.

    Raises:
        HTTPException(404): if the task does not exist.
        StepNotFound: if the requested (or latest) step does not exist.
    """
    agent = request["agent"]
    task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not task:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No task found with id {task_id}",
        )
    # An empty step request means that the agent should execute the next step for the task.
    if not step_id:
        step = await app.DATABASE.get_latest_step(task_id=task_id, organization_id=current_org.organization_id)
        if not step:
            raise StepNotFound(current_org.organization_id, task_id)
        LOG.info(
            "Executing latest step since no step_id was provided",
            task_id=task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )
    else:
        step = await app.DATABASE.get_step(task_id, step_id, organization_id=current_org.organization_id)
        if not step:
            raise StepNotFound(current_org.organization_id, task_id, step_id)
        LOG.info(
            "Executing step",
            task_id=task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )
    # BUGFIX: the duplicate `if not step:` checks that previously followed the
    # StepNotFound raises were unreachable dead code and have been removed.
    step, _, _ = await agent.execute_step(current_org, task, step)
    return Response(
        content=step.model_dump_json() if step else "",
        status_code=200,
        media_type="application/json",
    )
|
||||
|
||||
|
||||
@base_router.get("/tasks/{task_id}", response_model=TaskResponse)
async def get_task(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskResponse:
    """
    Return the public view of a task, enriched with share links to its
    latest screenshot/recording artifacts and a composed failure reason.

    Raises:
        HTTPException(404): if the task does not exist for this organization.
    """
    request["agent"]
    task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not task_obj:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task not found {task_id}",
        )

    # get latest step
    latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=current_org.organization_id)
    if not latest_step:
        # No steps yet — return the bare task with no artifact links.
        return task_obj.to_task_response()

    screenshot_url = None
    # todo (kerem): only access artifacts through the artifact manager instead of db
    screenshot_artifact = await app.DATABASE.get_latest_artifact(
        task_id=task_obj.task_id,
        step_id=latest_step.step_id,
        artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
        organization_id=current_org.organization_id,
    )
    if screenshot_artifact:
        screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)

    # The recording is task-scoped (no step_id filter), unlike the screenshot.
    recording_artifact = await app.DATABASE.get_latest_artifact(
        task_id=task_obj.task_id,
        artifact_types=[ArtifactType.RECORDING],
        organization_id=current_org.organization_id,
    )
    recording_url = None
    if recording_artifact:
        recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)

    # Compose a human-readable failure reason from the task's own reason plus
    # any exceptions recorded on the latest step.
    failure_reason = None
    if task_obj.status == TaskStatus.failed and (latest_step.output or task_obj.failure_reason):
        failure_reason = ""
        if task_obj.failure_reason:
            failure_reason += f"Reasoning: {task_obj.failure_reason or ''}"
            failure_reason += "\n"
        if latest_step.output and latest_step.output.action_results:
            failure_reason += "Exceptions: "
            failure_reason += str(
                [f"[{ar.exception_type}]: {ar.exception_message}" for ar in latest_step.output.action_results]
            )

    return task_obj.to_task_response(
        screenshot_url=screenshot_url,
        recording_url=recording_url,
        failure_reason=failure_reason,
    )
|
||||
|
||||
|
||||
@base_router.get("/internal/tasks/{task_id}", response_model=list[Task])
async def get_task_internal(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Fetch a single task (internal variant: returns the raw Task model
    rather than the public TaskResponse shape).

    Raises:
        HTTPException(404): if the task does not exist for this organization.
    """
    found = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task not found {task_id}",
        )
    return ORJSONResponse(found.model_dump())
|
||||
|
||||
|
||||
@base_router.get("/tasks", tags=["agent"], response_model=list[Task])
async def get_agent_tasks(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    List the organization's tasks, paginated, without steps populated.
    Steps can be populated by calling the get_agent_task endpoint.

    :param page: Starting page, defaults to 1
    :param page_size: Page size, defaults to 10
    """
    request["agent"]
    rows = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
    return ORJSONResponse([row.to_task_response().model_dump() for row in rows])
|
||||
|
||||
|
||||
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
async def get_agent_tasks_internal(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    List the organization's tasks (internal variant: raw Task models rather
    than the public TaskResponse shape), paginated, without steps populated.

    :param page: Starting page, defaults to 1
    :param page_size: Page size, defaults to 10
    """
    request["agent"]
    rows = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
    return ORJSONResponse([row.model_dump() for row in rows])
|
||||
|
||||
|
||||
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
async def get_agent_task_steps(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Return every step recorded for the given task.

    :param task_id: The task whose steps are listed.
    """
    request["agent"]
    task_steps = await app.DATABASE.get_task_steps(task_id, organization_id=current_org.organization_id)
    return ORJSONResponse([s.model_dump() for s in task_steps])
|
||||
|
||||
|
||||
@base_router.get("/tasks/{task_id}/steps/{step_id}/artifacts", tags=["agent"], response_model=list[Artifact])
async def get_agent_task_step_artifacts(
    request: Request,
    task_id: str,
    step_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Return all artifacts recorded for one step of a task.

    :param task_id: The task the step belongs to.
    :param step_id: The step whose artifacts are listed.
    """
    request["agent"]
    step_artifacts = await app.DATABASE.get_artifacts_for_task_step(
        task_id,
        step_id,
        organization_id=current_org.organization_id,
    )
    return ORJSONResponse([item.model_dump() for item in step_artifacts])
|
||||
|
||||
|
||||
class ActionResultTmp(BaseModel):
    """Shape of a single action result as stored in a step's raw output dict.

    NOTE(review): the "Tmp" suffix suggests this mirrors serialized DB data
    rather than a first-class domain model — confirm before reusing.
    """

    # Raw serialized action that was executed.
    action: dict[str, Any]
    data: dict[str, Any] | list | str | None = None
    exception_message: str | None = None
    success: bool = True
|
||||
|
||||
|
||||
@base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp])
async def get_task_actions(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[ActionResultTmp]:
    """Flatten and return all recorded action results across the task's steps."""
    request["agent"]
    step_models = await app.DATABASE.get_task_step_models(task_id, organization_id=current_org.organization_id)
    collected: list[ActionResultTmp] = []
    for model in step_models:
        # Steps without a recorded output (or with no action_results key)
        # contribute nothing.
        if not model.output or "action_results" not in model.output:
            continue
        collected.extend(ActionResultTmp.model_validate(item) for item in model.output["action_results"])
    return collected
|
||||
|
||||
|
||||
@base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse)
async def execute_workflow(
    background_tasks: BackgroundTasks,
    request: Request,
    workflow_id: str,
    workflow_request: WorkflowRequestBody,
    current_org: Organization = Depends(org_auth_service.get_current_org),
    x_api_key: Annotated[str | None, Header()] = None,
    x_max_steps_override: Annotated[int | None, Header()] = None,
) -> RunWorkflowResponse:
    """
    Create a workflow run for the given workflow and dispatch its execution.

    Returns immediately with the new run's id; execution happens
    asynchronously via the async executor.
    """
    LOG.info(
        f"Running workflow {workflow_id}",
        workflow_id=workflow_id,
    )
    # Correlate the run with the request that triggered it.
    context = skyvern_context.ensure_context()
    request_id = context.request_id
    workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
        request_id=request_id,
        workflow_request=workflow_request,
        workflow_id=workflow_id,
        organization_id=current_org.organization_id,
        max_steps_override=x_max_steps_override,
    )
    if x_max_steps_override:
        LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
    await app.ASYNC_EXECUTOR.execute_workflow(
        background_tasks=background_tasks,
        organization=current_org,
        workflow_id=workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        max_steps_override=x_max_steps_override,
        api_key=x_api_key,
    )
    return RunWorkflowResponse(
        workflow_id=workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
    )
|
||||
|
||||
|
||||
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}", response_model=WorkflowRunStatusResponse)
async def get_workflow_run(
    request: Request,
    workflow_id: str,
    workflow_run_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowRunStatusResponse:
    """Return the status of a single workflow run."""
    request["agent"]
    return await app.WORKFLOW_SERVICE.build_workflow_run_status_response(
        workflow_id=workflow_id, workflow_run_id=workflow_run_id, organization_id=current_org.organization_id
    )
|
||||
0
skyvern/forge/sdk/schemas/__init__.py
Normal file
0
skyvern/forge/sdk/schemas/__init__.py
Normal file
181
skyvern/forge/sdk/schemas/tasks.py
Normal file
181
skyvern/forge/sdk/schemas/tasks.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProxyLocation(StrEnum):
    """Proxy egress location options for running a task."""

    US_CA = "US-CA"
    US_NY = "US-NY"
    US_TX = "US-TX"
    US_FL = "US-FL"
    US_WA = "US-WA"
    # Presumably a residential proxy pool with no specific geography —
    # confirm against the proxy configuration.
    RESIDENTIAL = "RESIDENTIAL"
    # No proxy.
    NONE = "NONE"
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
    """User-supplied specification of a browsing task."""

    url: str = Field(
        ...,
        min_length=1,
        description="Starting URL for the task.",
        examples=["https://www.geico.com"],
    )
    # TODO: use HttpUrl instead of str
    webhook_callback_url: str | None = Field(
        default=None,
        description="The URL to call when the task is completed.",
        examples=["https://my-webhook.com"],
    )
    navigation_goal: str | None = Field(
        default=None,
        description="The user's goal for the task.",
        examples=["Get a quote for car insurance"],
    )
    data_extraction_goal: str | None = Field(
        default=None,
        description="The user's goal for data extraction.",
        examples=["Extract the quote price"],
    )
    navigation_payload: dict[str, Any] | list | str | None = Field(
        None,
        description="The user's details needed to achieve the task.",
        examples=[{"name": "John Doe", "email": "john@doe.com"}],
    )
    proxy_location: ProxyLocation | None = Field(
        None,
        description="The location of the proxy to use for the task.",
        examples=["US-WA", "US-CA", "US-FL", "US-NY", "US-TX"],
    )
    extracted_information_schema: dict[str, Any] | list | str | None = Field(
        None,
        description="The requested schema of the extracted information.",
    )
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
|
||||
created = "created"
|
||||
running = "running"
|
||||
failed = "failed"
|
||||
terminated = "terminated"
|
||||
completed = "completed"
|
||||
|
||||
def is_final(self) -> bool:
|
||||
return self in {TaskStatus.failed, TaskStatus.terminated, TaskStatus.completed}
|
||||
|
||||
def can_update_to(self, new_status: TaskStatus) -> bool:
|
||||
allowed_transitions: dict[TaskStatus, set[TaskStatus]] = {
|
||||
TaskStatus.created: {TaskStatus.running},
|
||||
TaskStatus.running: {TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated},
|
||||
TaskStatus.failed: set(),
|
||||
TaskStatus.completed: set(),
|
||||
}
|
||||
return new_status in allowed_transitions[self]
|
||||
|
||||
def requires_extracted_info(self) -> bool:
|
||||
status_requires_extracted_information = {TaskStatus.completed}
|
||||
return self in status_requires_extracted_information
|
||||
|
||||
def cant_have_extracted_info(self) -> bool:
|
||||
status_cant_have_extracted_information = {
|
||||
TaskStatus.created,
|
||||
TaskStatus.running,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.terminated,
|
||||
}
|
||||
return self in status_cant_have_extracted_information
|
||||
|
||||
def requires_failure_reason(self) -> bool:
|
||||
status_requires_failure_reason = {TaskStatus.failed, TaskStatus.terminated}
|
||||
return self in status_requires_failure_reason
|
||||
|
||||
|
||||
class Task(TaskRequest):
    """A persisted task: the original request plus server-side state."""

    created_at: datetime = Field(
        ...,
        description="The creation datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
    )
    task_id: str = Field(
        ...,
        description="The ID of the task.",
        examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
    )
    status: TaskStatus = Field(..., description="The status of the task.", examples=["created"])
    extracted_information: dict[str, Any] | list | str | None = Field(
        None,
        description="The extracted information from the task.",
    )
    failure_reason: str | None = Field(
        None,
        description="The reason for the task failure.",
    )
    organization_id: str | None = None
    workflow_run_id: str | None = None
    order: int | None = None
    retry: int | None = None

    def validate_update(
        self,
        status: TaskStatus,
        extracted_information: dict[str, Any] | list | str | None,
        failure_reason: str | None = None,
    ) -> None:
        """
        Validate a proposed status/extracted-information update for this task.

        Raises:
            ValueError: if the transition or payload combination is illegal.
        """
        old_status = self.status

        # BUGFIX: these error messages were missing their closing parenthesis
        # (e.g. "invalid_status_transition(a,b,c" instead of "...c)").
        if not old_status.can_update_to(status):
            raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id})")

        if status.requires_failure_reason() and failure_reason is None:
            raise ValueError(f"status_requires_failure_reason({status},{self.task_id})")

        if status.requires_extracted_info() and self.data_extraction_goal and extracted_information is None:
            raise ValueError(f"status_requires_extracted_information({status},{self.task_id})")

        if status.cant_have_extracted_info() and extracted_information is not None:
            raise ValueError(f"status_cant_have_extracted_information({self.task_id})")

        if self.extracted_information is not None and extracted_information is not None:
            raise ValueError(f"cant_override_extracted_information({self.task_id})")

        if self.failure_reason is not None and failure_reason is not None:
            raise ValueError(f"cant_override_failure_reason({self.task_id})")

    def to_task_response(
        self, screenshot_url: str | None = None, recording_url: str | None = None, failure_reason: str | None = None
    ) -> TaskResponse:
        """Build the public TaskResponse view, optionally attaching artifact links."""
        return TaskResponse(
            request=self,
            task_id=self.task_id,
            status=self.status,
            created_at=self.created_at,
            modified_at=self.modified_at,
            extracted_information=self.extracted_information,
            # An explicitly passed failure_reason takes precedence over the stored one.
            failure_reason=failure_reason or self.failure_reason,
            screenshot_url=screenshot_url,
            recording_url=recording_url,
        )
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
    """Public API representation of a task."""

    # The original request that created the task.
    request: TaskRequest
    task_id: str
    status: TaskStatus
    created_at: datetime
    modified_at: datetime
    extracted_information: list | dict[str, Any] | str | None = None
    # Share links to the latest screenshot / recording artifacts, if any.
    screenshot_url: str | None = None
    recording_url: str | None = None
    failure_reason: str | None = None
|
||||
|
||||
|
||||
class CreateTaskResponse(BaseModel):
    """Response body for task creation: the newly created task's id."""

    task_id: str
|
||||
0
skyvern/forge/sdk/services/__init__.py
Normal file
0
skyvern/forge/sdk/services/__init__.py
Normal file
76
skyvern/forge/sdk/services/org_auth_service.py
Normal file
76
skyvern/forge/sdk/services/org_auth_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
from asyncache import cached
|
||||
from cachetools import TTLCache
|
||||
from fastapi import Header, HTTPException, status
|
||||
from jose import jwt
|
||||
from jose.exceptions import JWTError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthTokenType, TokenPayload
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
AUTHENTICATION_TTL = 60 * 60 # one hour
|
||||
CACHE_SIZE = 128
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
async def get_current_org(
    x_api_key: Annotated[str | None, Header()] = None,
) -> Organization:
    """FastAPI dependency: resolve the calling organization from its API key.

    Raises:
        HTTPException(403): if the header is missing or the key is invalid.
    """
    if not x_api_key:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )
    # Delegates to the TTL-cached resolver so repeated calls with the same
    # key skip JWT decoding and DB lookups.
    return await _get_current_org_cached(x_api_key, app.DATABASE)
|
||||
|
||||
|
||||
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
    """
    Resolve the organization that owns ``x_api_key``.

    Authentication is cached for one hour.

    Raises:
        HTTPException(403): if the token is malformed, expired, or not valid
            in the database.
        HTTPException(404): if the organization no longer exists.
    """
    try:
        payload = jwt.decode(
            x_api_key,
            SettingsManager.get_settings().SECRET_KEY,
            algorithms=[ALGORITHM],
        )
        api_key_data = TokenPayload(**payload)
    except (JWTError, ValidationError):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    # Explicit expiry check on the decoded claim.
    # NOTE(review): jose typically also validates `exp` during decode, which
    # would make this belt-and-braces — confirm.
    if api_key_data.exp < time.time():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Auth token is expired",
        )

    # The JWT subject is the organization id.
    organization = await db.get_organization(organization_id=api_key_data.sub)
    if not organization:
        raise HTTPException(status_code=404, detail="Organization not found")

    # check if the token exists in the database
    api_key_db_obj = await db.validate_org_auth_token(
        organization_id=organization.organization_id,
        token_type=OrganizationAuthTokenType.api,
        token=x_api_key,
    )
    if not api_key_db_obj:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )

    # set organization_id in skyvern context and log context
    context = skyvern_context.current()
    if context:
        context.organization_id = organization.organization_id
    return organization
|
||||
14
skyvern/forge/sdk/settings_manager.py
Normal file
14
skyvern/forge/sdk/settings_manager.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from skyvern.config import Settings
|
||||
from skyvern.config import settings as base_settings
|
||||
|
||||
|
||||
class SettingsManager:
    """Process-wide holder for the active :class:`Settings` instance.

    Starts out pointing at the module-level ``base_settings`` and can be
    swapped wholesale (e.g. in tests) via :meth:`set_settings`.
    """

    # Name-mangled so subclasses / outside code go through the accessors.
    __instance: Settings = base_settings

    @staticmethod
    def get_settings() -> Settings:
        """Return the currently active settings object."""
        return SettingsManager.__instance

    @staticmethod
    def set_settings(settings: Settings) -> None:
        """Replace the active settings object for the whole process."""
        SettingsManager.__instance = settings
|
||||
0
skyvern/forge/sdk/workflow/__init__.py
Normal file
0
skyvern/forge/sdk/workflow/__init__.py
Normal file
79
skyvern/forge/sdk/workflow/context_manager.py
Normal file
79
skyvern/forge/sdk/workflow/context_manager.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunParameter
|
||||
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ContextManager:
    """Holds the parameter definitions and resolved runtime values for a
    single workflow run, including secrets fetched from AWS.

    ``parameters`` maps parameter key -> parameter definition;
    ``values`` maps parameter key -> resolved runtime value.
    """

    aws_client: AsyncAWSClient
    parameters: dict[str, PARAMETER_TYPE]
    values: dict[str, Any]

    def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
        self.aws_client = AsyncAWSClient()
        self.parameters = {}
        self.values = {}
        for parameter, run_parameter in workflow_parameter_tuples:
            if parameter.key in self.parameters:
                prev_value = self.parameters[parameter.key]
                new_value = run_parameter.value
                LOG.error(
                    f"Duplicate parameter key {parameter.key} found while initializing context manager, previous value: {prev_value}, new value: {new_value}. Using new value."
                )

            # Last tuple wins on duplicate keys (see error log above).
            self.parameters[parameter.key] = parameter
            self.values[parameter.key] = run_parameter.value

    async def register_parameter_value(
        self,
        parameter: PARAMETER_TYPE,
    ) -> None:
        """Resolve and store the runtime value for a single parameter.

        Workflow parameters must already have been set in ``__init__``
        (raises ValueError otherwise); AWS secret parameters are fetched
        from AWS; context parameters are resolved later by the blocks.
        """
        if parameter.parameter_type == ParameterType.WORKFLOW:
            LOG.error(f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}")
            raise ValueError(
                f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
            )
        elif parameter.parameter_type == ParameterType.AWS_SECRET:
            secret_value = await self.aws_client.get_secret(parameter.aws_key)
            if secret_value is not None:
                self.values[parameter.key] = secret_value
            else:
                # Fix: a missing secret used to be skipped silently, which
                # later surfaced as an unexplained KeyError in get_value().
                # Log it here so the root cause is visible.
                LOG.warning(
                    f"AWS secret {parameter.aws_key} returned no value; parameter {parameter.key} is left unset"
                )
        else:
            # ContextParameter values will be set within the blocks
            return None

    async def register_block_parameters(
        self,
        parameters: list[PARAMETER_TYPE],
    ) -> None:
        """Register every parameter a block declares, resolving values as needed.

        Already-registered keys are skipped; workflow parameters reaching this
        point indicate a setup bug and raise ValueError.
        """
        for parameter in parameters:
            if parameter.key in self.parameters:
                LOG.debug(f"Parameter {parameter.key} already registered, skipping")
                continue

            if parameter.parameter_type == ParameterType.WORKFLOW:
                LOG.error(
                    f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
                )
                raise ValueError(
                    f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
                )

            self.parameters[parameter.key] = parameter
            await self.register_parameter_value(parameter)

    def get_parameter(self, key: str) -> Parameter:
        """Return the parameter definition for *key* (raises KeyError if absent)."""
        return self.parameters[key]

    def get_value(self, key: str) -> Any:
        """Return the resolved runtime value for *key* (raises KeyError if absent)."""
        return self.values[key]

    def set_value(self, key: str, value: Any) -> None:
        """Set or overwrite the runtime value for *key*."""
        self.values[key] = value
|
||||
0
skyvern/forge/sdk/workflow/models/__init__.py
Normal file
0
skyvern/forge/sdk/workflow/models/__init__.py
Normal file
221
skyvern/forge/sdk/workflow/models/block.py
Normal file
221
skyvern/forge/sdk/workflow/models/block.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import abc
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from skyvern.exceptions import (
|
||||
ContextParameterValueNotFound,
|
||||
MissingBrowserStatePage,
|
||||
TaskNotFound,
|
||||
UnexpectedTaskStatus,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import ContextManager
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class BlockType(StrEnum):
    """Discriminator values for the Block union (see BlockTypeVar below)."""

    TASK = "task"
    FOR_LOOP = "for_loop"
|
||||
|
||||
|
||||
class Block(BaseModel, abc.ABC):
    """Abstract base for workflow blocks, discriminated by ``block_type``."""

    block_type: BlockType
    # Linked-list style wiring between blocks; None at the ends.
    parent_block_id: str | None = None
    next_block_id: str | None = None

    @classmethod
    def get_subclasses(cls) -> tuple[type["Block"], ...]:
        """Return all direct subclasses of this block type."""
        return tuple(cls.__subclasses__())

    @abc.abstractmethod
    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        # Run the block within the given workflow run; subclasses define semantics.
        pass

    @abc.abstractmethod
    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        # Every parameter this block (and any nested blocks) requires.
        pass
|
||||
|
||||
|
||||
class TaskBlock(Block):
    """Workflow block that runs a single browser task, retrying on failure
    up to ``max_retries`` times."""

    block_type: Literal[BlockType.TASK] = BlockType.TASK

    url: str | None = None
    navigation_goal: str | None = None
    data_extraction_goal: str | None = None
    data_schema: dict[str, Any] | None = None
    max_retries: int = 0
    # Mutable default is safe here: pydantic deep-copies field defaults.
    parameters: list[PARAMETER_TYPE] = []

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Return the parameters declared directly on this block."""
        return self.parameters

    @staticmethod
    async def get_task_order(workflow_run_id: str, current_retry: int) -> tuple[int, int]:
        """
        Returns the order and retry for the next task in the workflow run as a tuple.
        """
        last_task_for_workflow_run = await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
        # If there is no previous task, the order will be 0 and the retry will be 0.
        if last_task_for_workflow_run is None:
            return 0, 0
        # If there is a previous task but the current retry is 0, the order will be the order of the last task + 1
        # and the retry will be 0.
        order = last_task_for_workflow_run.order or 0
        if current_retry == 0:
            return order + 1, 0
        # If there is a previous task and the current retry is not 0, the order will be the order of the last task
        # and the retry will be the retry of the last task + 1. (There is a validation that makes sure the retry
        # of the last task is equal to current_retry - 1) if it is not, we use last task retry + 1.
        retry = last_task_for_workflow_run.retry or 0
        if retry + 1 != current_retry:
            LOG.error(
                f"Last task for workflow run is retry number {last_task_for_workflow_run.retry}, "
                f"but current retry is {current_retry}. Could be race condition. Using last task retry + 1",
                workflow_run_id=workflow_run_id,
                last_task_id=last_task_for_workflow_run.task_id,
                last_task_retry=last_task_for_workflow_run.retry,
                current_retry=current_retry,
            )

        return order, retry + 1

    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        """Create and run the task, retrying until it completes or
        ``max_retries`` is exhausted.

        NOTE(review): when retries are exhausted the loop simply exits — no
        exception is raised here; the caller inspects the last task's status.
        """
        task = None
        current_retry = 0
        # initial value for will_retry is True, so that the loop runs at least once
        will_retry = True
        workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id=workflow_run_id)
        workflow = await app.WORKFLOW_SERVICE.get_workflow(workflow_id=workflow_run.workflow_id)
        # TODO (kerem) we should always retry on terminated. We should make a distinction between retriable and
        # non-retryable terminations
        while will_retry:
            task_order, task_retry = await self.get_task_order(workflow_run_id, current_retry)
            task, step = await app.agent.create_task_and_step_from_block(
                task_block=self,
                workflow=workflow,
                workflow_run=workflow_run,
                context_manager=context_manager,
                task_order=task_order,
                task_retry=task_retry,
            )
            organization = await app.DATABASE.get_organization(organization_id=workflow.organization_id)
            if not organization:
                raise Exception(f"Organization is missing organization_id={workflow.organization_id}")
            browser_state = await app.BROWSER_MANAGER.get_or_create_for_workflow_run(
                workflow_run=workflow_run, url=self.url
            )
            if not browser_state.page:
                LOG.error("BrowserState has no page", workflow_run_id=workflow_run.workflow_run_id)
                raise MissingBrowserStatePage(workflow_run_id=workflow_run.workflow_run_id)

            # Fix: was an f-string with no placeholders (F541).
            LOG.info(
                "Navigating to page",
                url=self.url,
                workflow_run_id=workflow_run_id,
                task_id=task.task_id,
                workflow_id=workflow.workflow_id,
                organization_id=workflow.organization_id,
                step_id=step.step_id,
            )

            if self.url:
                await browser_state.page.goto(self.url)

            await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
            # Check task status
            updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
            if not updated_task:
                raise TaskNotFound(task.task_id)
            if not updated_task.status.is_final():
                raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)
            if updated_task.status == TaskStatus.completed:
                will_retry = False
            else:
                current_retry += 1
                will_retry = current_retry <= self.max_retries
                retry_message = f", retrying task {current_retry}/{self.max_retries}" if will_retry else ""
                LOG.warning(
                    f"Task failed with status {updated_task.status}{retry_message}",
                    task_id=updated_task.task_id,
                    status=updated_task.status,
                    workflow_run_id=workflow_run_id,
                    workflow_id=workflow.workflow_id,
                    organization_id=workflow.organization_id,
                    current_retry=current_retry,
                    max_retries=self.max_retries,
                )
|
||||
|
||||
|
||||
class ForLoopBlock(Block):
    """Workflow block that runs a nested block once per value of an iterable
    parameter."""

    block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP

    # TODO (kerem): Add support for ContextParameter
    # The parameter whose value supplies the iteration items.
    loop_over: PARAMETER_TYPE
    # The block executed for each item (forward ref; union defined below).
    loop_block: "BlockTypeVar"

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Parameters of the nested block plus the loop-over parameter itself."""
        return self.loop_block.get_all_parameters() + [self.loop_over]

    def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any) -> list[ContextParameter]:
        """Bind each ContextParameter of the nested block to its entry in
        *loop_data* (which must be a dict keyed by parameter key).

        Raises ContextParameterValueNotFound when a key is missing.
        NOTE(review): this mutates the shared parameter objects in place;
        each iteration overwrites the previous iteration's values.
        """
        if not isinstance(loop_data, dict):
            # TODO (kerem): Should we add support for other types?
            raise ValueError("loop_data should be a dictionary")

        loop_block_parameters = self.loop_block.get_all_parameters()
        context_parameters = [
            parameter for parameter in loop_block_parameters if isinstance(parameter, ContextParameter)
        ]
        for context_parameter in context_parameters:
            if context_parameter.key not in loop_data:
                raise ContextParameterValueNotFound(
                    parameter_key=context_parameter.key,
                    existing_keys=list(loop_data.keys()),
                    workflow_run_id=workflow_run_id,
                )
            context_parameter.value = loop_data[context_parameter.key]

        return context_parameters

    def get_loop_over_parameter_values(self, context_manager: ContextManager) -> list[Any]:
        """Return the list of items to iterate over; a non-list value is
        wrapped in a single-element list."""
        if isinstance(self.loop_over, WorkflowParameter):
            parameter_value = context_manager.get_value(self.loop_over.key)
            if isinstance(parameter_value, list):
                return parameter_value
            else:
                # TODO (kerem): Should we raise an error here?
                return [parameter_value]
        else:
            # TODO (kerem): Implement this for context parameters
            raise NotImplementedError

    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        """Run the nested block once per loop-over value, publishing each
        iteration's context parameter values before executing."""
        loop_over_values = self.get_loop_over_parameter_values(context_manager)
        LOG.info(
            f"Number of loop_over values: {len(loop_over_values)}",
            block_type=self.block_type,
            workflow_run_id=workflow_run_id,
            num_loop_over_values=len(loop_over_values),
        )
        for loop_over_value in loop_over_values:
            context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
            for context_parameter in context_parameters_with_value:
                context_manager.set_value(context_parameter.key, context_parameter.value)
            await self.loop_block.execute(workflow_run_id=workflow_run_id, context_manager=context_manager)

        return None
|
||||
|
||||
|
||||
# Discriminated union of all concrete block types; pydantic picks the concrete
# class from the `block_type` field. Also resolves ForLoopBlock's forward ref.
BlockSubclasses = Union[ForLoopBlock, TaskBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
84
skyvern/forge/sdk/workflow/models/parameter.py
Normal file
84
skyvern/forge/sdk/workflow/models/parameter.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import abc
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ParameterType(StrEnum):
    """Discriminator values for the Parameter union (see PARAMETER_TYPE below)."""

    WORKFLOW = "workflow"
    CONTEXT = "context"
    AWS_SECRET = "aws_secret"
|
||||
|
||||
|
||||
class Parameter(BaseModel, abc.ABC):
    """Abstract base for workflow parameters, discriminated by ``parameter_type``."""

    # TODO (kerem): Should we also have organization_id here?
    parameter_type: ParameterType
    # Unique key used to look the parameter up in the context manager.
    key: str
    description: str | None = None

    @classmethod
    def get_subclasses(cls) -> tuple[type["Parameter"], ...]:
        """Return all direct subclasses of this parameter type."""
        return tuple(cls.__subclasses__())
|
||||
|
||||
|
||||
class AWSSecretParameter(Parameter):
    """Parameter whose value is fetched at run time from AWS Secrets Manager
    via ``aws_key`` (resolved by the context manager)."""

    parameter_type: Literal[ParameterType.AWS_SECRET] = ParameterType.AWS_SECRET

    aws_secret_parameter_id: str
    workflow_id: str
    # Identifier of the secret to fetch from AWS.
    aws_key: str

    created_at: datetime
    modified_at: datetime
    deleted_at: datetime | None = None
|
||||
|
||||
|
||||
class WorkflowParameterType(StrEnum):
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
FLOAT = "float"
|
||||
BOOLEAN = "boolean"
|
||||
JSON = "json"
|
||||
|
||||
def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None:
|
||||
if value is None:
|
||||
return None
|
||||
if self == WorkflowParameterType.STRING:
|
||||
return value
|
||||
elif self == WorkflowParameterType.INTEGER:
|
||||
return int(value)
|
||||
elif self == WorkflowParameterType.FLOAT:
|
||||
return float(value)
|
||||
elif self == WorkflowParameterType.BOOLEAN:
|
||||
return value.lower() in ["true", "1"]
|
||||
elif self == WorkflowParameterType.JSON:
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
class WorkflowParameter(Parameter):
    """Parameter declared on a workflow; its run-time value comes from the
    run request body or from ``default_value``."""

    parameter_type: Literal[ParameterType.WORKFLOW] = ParameterType.WORKFLOW

    workflow_parameter_id: str
    workflow_parameter_type: WorkflowParameterType
    workflow_id: str
    # the type of default_value will be determined by the workflow_parameter_type
    default_value: str | int | float | bool | dict | list | None = None

    created_at: datetime
    modified_at: datetime
    deleted_at: datetime | None = None
|
||||
|
||||
|
||||
class ContextParameter(Parameter):
    """Parameter derived from another parameter's value at run time
    (e.g. per-iteration values inside a for-loop block)."""

    parameter_type: Literal[ParameterType.CONTEXT] = ParameterType.CONTEXT

    # The workflow parameter this context parameter is derived from.
    source: WorkflowParameter
    # value will be populated by the context manager
    value: str | int | float | bool | dict | list | None = None
|
||||
|
||||
|
||||
# Discriminated union of all concrete parameter types; pydantic picks the
# concrete class from the `parameter_type` field.
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter]
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]
|
||||
74
skyvern/forge/sdk/workflow/models/workflow.py
Normal file
74
skyvern/forge/sdk/workflow/models/workflow.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
|
||||
|
||||
|
||||
class WorkflowRequestBody(BaseModel):
    """Request body for running a workflow."""

    # Parameter values keyed by workflow parameter key.
    data: dict[str, Any] | None = None
    proxy_location: ProxyLocation | None = None
    # URL to POST the final status to when the run finishes.
    webhook_callback_url: str | None = None
|
||||
|
||||
|
||||
class RunWorkflowResponse(BaseModel):
    """Response returned when a workflow run is started."""

    workflow_id: str
    workflow_run_id: str
|
||||
|
||||
|
||||
class WorkflowDefinition(BaseModel):
    """Ordered list of top-level blocks that make up a workflow."""

    # Use the builtin generic (PEP 585), consistent with the rest of the file.
    blocks: list[BlockTypeVar]
|
||||
|
||||
|
||||
class Workflow(BaseModel):
    """A stored workflow: metadata plus its block definition."""

    workflow_id: str
    organization_id: str
    title: str
    description: str | None = None
    workflow_definition: WorkflowDefinition

    created_at: datetime
    modified_at: datetime
    # Soft-delete marker; None while the workflow is active.
    deleted_at: datetime | None = None
|
||||
|
||||
|
||||
class WorkflowRunStatus(StrEnum):
    """Lifecycle states of a workflow run."""

    created = "created"
    running = "running"
    failed = "failed"
    terminated = "terminated"
    completed = "completed"
|
||||
|
||||
|
||||
class WorkflowRun(BaseModel):
    """One execution of a workflow."""

    workflow_run_id: str
    workflow_id: str
    status: WorkflowRunStatus
    proxy_location: ProxyLocation | None = None
    webhook_callback_url: str | None = None

    created_at: datetime
    modified_at: datetime
|
||||
|
||||
|
||||
class WorkflowRunParameter(BaseModel):
    """The concrete value bound to one workflow parameter for one run."""

    workflow_run_id: str
    workflow_parameter_id: str
    # dict/list values are stored JSON-serialized by the service layer.
    value: bool | int | float | str | dict | list
    created_at: datetime
|
||||
|
||||
|
||||
class WorkflowRunStatusResponse(BaseModel):
    """API response describing a workflow run's status and artifacts."""

    workflow_id: str
    workflow_run_id: str
    status: WorkflowRunStatus
    proxy_location: ProxyLocation | None = None
    webhook_callback_url: str | None = None
    created_at: datetime
    modified_at: datetime
    # Resolved parameter values keyed by parameter key.
    parameters: dict[str, Any]
    screenshot_urls: list[str] | None = None
    recording_url: str | None = None
|
||||
509
skyvern/forge/sdk/workflow/service.py
Normal file
509
skyvern/forge/sdk/workflow/service.py
Normal file
@@ -0,0 +1,509 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
import structlog
|
||||
|
||||
from skyvern.exceptions import (
|
||||
FailedToSendWebhook,
|
||||
MissingValueForParameter,
|
||||
WorkflowNotFound,
|
||||
WorkflowOrganizationMismatch,
|
||||
WorkflowRunNotFound,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.context_manager import ContextManager
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDefinition,
|
||||
WorkflowRequestBody,
|
||||
WorkflowRun,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunStatusResponse,
|
||||
)
|
||||
from skyvern.webeye.browser_factory import BrowserState
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
    async def setup_workflow_run(
        self,
        request_id: str | None,
        workflow_request: WorkflowRequestBody,
        workflow_id: str,
        organization_id: str,
        max_steps_override: int | None = None,
    ) -> WorkflowRun:
        """
        Create a workflow run and its parameters. Validate the workflow and the organization. If there are missing
        parameters with no default value, mark the workflow run as failed.
        :param request_id: The request id for the workflow run.
        :param workflow_request: The request body for the workflow run, containing the parameters and the config.
        :param workflow_id: The workflow id to run.
        :param organization_id: The organization id for the workflow.
        :param max_steps_override: The max steps override for the workflow run, if any.
        :return: The created workflow run.
        :raises WorkflowNotFound: if the workflow does not exist.
        :raises WorkflowOrganizationMismatch: if the workflow belongs to another org.
        :raises MissingValueForParameter: if a parameter has no value and no default.
        """
        LOG.info(f"Setting up workflow run for workflow {workflow_id}", workflow_id=workflow_id)
        # Validate the workflow and the organization
        workflow = await self.get_workflow(workflow_id=workflow_id)
        # NOTE(review): get_workflow already raises WorkflowNotFound on a
        # missing workflow, so this branch appears unreachable — confirm.
        if workflow is None:
            LOG.error(f"Workflow {workflow_id} not found")
            raise WorkflowNotFound(workflow_id=workflow_id)
        if workflow.organization_id != organization_id:
            LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
            raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)
        # Create the workflow run and set skyvern context
        workflow_run = await self.create_workflow_run(workflow_request=workflow_request, workflow_id=workflow_id)
        LOG.info(
            f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}",
            request_id=request_id,
            workflow_run_id=workflow_run.workflow_run_id,
            workflow_id=workflow.workflow_id,
            proxy_location=workflow_request.proxy_location,
        )
        skyvern_context.set(
            SkyvernContext(
                organization_id=organization_id,
                request_id=request_id,
                workflow_id=workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
                max_steps_override=max_steps_override,
            )
        )

        # Set workflow run status to running, create workflow run parameters
        await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)

        # Create all the workflow run parameters, AWSSecretParameter won't have workflow run parameters created.
        all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id)
        workflow_run_parameters = []
        for workflow_parameter in all_workflow_parameters:
            # Value priority: explicit request body value, then the declared default.
            if workflow_request.data and workflow_parameter.key in workflow_request.data:
                request_body_value = workflow_request.data[workflow_parameter.key]
                workflow_run_parameter = await self.create_workflow_run_parameter(
                    workflow_run_id=workflow_run.workflow_run_id,
                    workflow_parameter_id=workflow_parameter.workflow_parameter_id,
                    value=request_body_value,
                )
            elif workflow_parameter.default_value is not None:
                workflow_run_parameter = await self.create_workflow_run_parameter(
                    workflow_run_id=workflow_run.workflow_run_id,
                    workflow_parameter_id=workflow_parameter.workflow_parameter_id,
                    value=workflow_parameter.default_value,
                )
            else:
                # No value anywhere: fail the run before any blocks execute.
                await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
                raise MissingValueForParameter(
                    parameter_key=workflow_parameter.key,
                    workflow_id=workflow.workflow_id,
                    workflow_run_id=workflow_run.workflow_run_id,
                )

            workflow_run_parameters.append(workflow_run_parameter)

        LOG.info(
            f"Created workflow run parameters for workflow run {workflow_run.workflow_run_id}",
            workflow_run_id=workflow_run.workflow_run_id,
        )

        return workflow_run
|
||||
|
||||
    async def execute_workflow(
        self,
        workflow_run_id: str,
        api_key: str,
    ) -> WorkflowRun:
        """Execute a workflow.

        Runs each top-level block in order, then derives the run's final
        status from the status of the last task it produced and fires the
        webhook callback (if configured).
        """
        workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
        workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)

        # Ensure a browser session exists before any block runs.
        await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)

        # Get all <workflow parameter, workflow run parameter> tuples
        wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
        # todo(kerem): do this in a better way (a shared context manager? (not really shared because we use batch job))
        context_manager = ContextManager(wp_wps_tuples)
        # Execute workflow blocks
        blocks = workflow.workflow_definition.blocks
        for block_idx, block in enumerate(blocks):
            # Register (and resolve, e.g. AWS secrets) the block's parameters
            # before executing it.
            parameters = block.get_all_parameters()
            await context_manager.register_block_parameters(parameters)
            LOG.info(
                f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run.workflow_run_id}",
                block_type=block.block_type,
                workflow_run_id=workflow_run.workflow_run_id,
                block_idx=block_idx,
            )
            await block.execute(workflow_run_id=workflow_run.workflow_run_id, context_manager=context_manager)

        # Get last task for workflow run
        task = await self.get_last_task_for_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
        if not task:
            LOG.warning(
                f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook",
                workflow_run_id=workflow_run.workflow_run_id,
            )
            return workflow_run

        # Update workflow status
        # The run's final status mirrors the last task's final status.
        if task.status == "completed":
            await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
        elif task.status == "failed":
            await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
        elif task.status == "terminated":
            await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id)
        else:
            LOG.warning(
                f"Task {task.task_id} has an incomplete status {task.status}, not updating workflow run status",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
                task_id=task.task_id,
                status=task.status,
                workflow_run_status=workflow_run.status,
            )

        await self.send_workflow_response(
            workflow=workflow,
            workflow_run=workflow_run,
            api_key=api_key,
            last_task=task,
        )
        return workflow_run
|
||||
|
||||
    async def create_workflow(
        self,
        organization_id: str,
        title: str,
        workflow_definition: WorkflowDefinition,
        description: str | None = None,
    ) -> Workflow:
        """Persist a new workflow for the organization and return it."""
        return await app.DATABASE.create_workflow(
            organization_id=organization_id,
            title=title,
            description=description,
            # NOTE(review): workflow_definition is a required parameter, so
            # the falsy guard here looks redundant — confirm before removing.
            workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
        )
|
||||
|
||||
async def get_workflow(self, workflow_id: str) -> Workflow:
|
||||
workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise WorkflowNotFound(workflow_id)
|
||||
return workflow
|
||||
|
||||
async def update_workflow(
|
||||
self,
|
||||
workflow_id: str,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
workflow_definition: WorkflowDefinition | None = None,
|
||||
) -> Workflow | None:
|
||||
return await app.DATABASE.update_workflow(
|
||||
workflow_id=workflow_id,
|
||||
title=title,
|
||||
description=description,
|
||||
workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
|
||||
)
|
||||
|
||||
async def create_workflow_run(self, workflow_request: WorkflowRequestBody, workflow_id: str) -> WorkflowRun:
|
||||
return await app.DATABASE.create_workflow_run(
|
||||
workflow_id=workflow_id,
|
||||
proxy_location=workflow_request.proxy_location,
|
||||
webhook_callback_url=workflow_request.webhook_callback_url,
|
||||
)
|
||||
|
||||
async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> None:
|
||||
LOG.info(
|
||||
f"Marking workflow run {workflow_run_id} as completed", workflow_run_id=workflow_run_id, status="completed"
|
||||
)
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
status=WorkflowRunStatus.completed,
|
||||
)
|
||||
|
||||
async def mark_workflow_run_as_failed(self, workflow_run_id: str) -> None:
|
||||
LOG.info(f"Marking workflow run {workflow_run_id} as failed", workflow_run_id=workflow_run_id, status="failed")
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
status=WorkflowRunStatus.failed,
|
||||
)
|
||||
|
||||
async def mark_workflow_run_as_running(self, workflow_run_id: str) -> None:
|
||||
LOG.info(
|
||||
f"Marking workflow run {workflow_run_id} as running", workflow_run_id=workflow_run_id, status="running"
|
||||
)
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
status=WorkflowRunStatus.running,
|
||||
)
|
||||
|
||||
async def mark_workflow_run_as_terminated(self, workflow_run_id: str) -> None:
|
||||
LOG.info(
|
||||
f"Marking workflow run {workflow_run_id} as terminated",
|
||||
workflow_run_id=workflow_run_id,
|
||||
status="terminated",
|
||||
)
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
status=WorkflowRunStatus.terminated,
|
||||
)
|
||||
|
||||
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
|
||||
return await app.DATABASE.get_workflow_runs(workflow_id=workflow_id)
|
||||
|
||||
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFound(workflow_run_id)
|
||||
return workflow_run
|
||||
|
||||
async def create_workflow_parameter(
    self,
    workflow_id: str,
    workflow_parameter_type: WorkflowParameterType,
    key: str,
    default_value: bool | int | float | str | dict | list | None = None,
    description: str | None = None,
) -> WorkflowParameter:
    """Create a parameter definition on a workflow and return the stored record."""
    created = await app.DATABASE.create_workflow_parameter(
        workflow_id=workflow_id,
        key=key,
        workflow_parameter_type=workflow_parameter_type,
        default_value=default_value,
        description=description,
    )
    return created
||||
|
||||
async def create_aws_secret_parameter(
    self, workflow_id: str, aws_key: str, key: str, description: str | None = None
) -> AWSSecretParameter:
    """Persist an AWS-secret parameter definition for the workflow and return it."""
    return await app.DATABASE.create_aws_secret_parameter(
        workflow_id=workflow_id,
        aws_key=aws_key,
        key=key,
        description=description,
    )
||||
|
||||
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
    """Return the parameter definitions declared on the given workflow."""
    parameters = await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id)
    return parameters
||||
|
||||
async def create_workflow_run_parameter(
    self,
    workflow_run_id: str,
    workflow_parameter_id: str,
    value: bool | int | float | str | dict | list,
) -> WorkflowRunParameter:
    """Record the concrete value supplied for one parameter of a workflow run.

    dict/list values are JSON-serialized before storage; scalar values are
    stored as-is.
    """
    serialized_value = json.dumps(value) if isinstance(value, (dict, list)) else value
    return await app.DATABASE.create_workflow_run_parameter(
        workflow_run_id=workflow_run_id,
        workflow_parameter_id=workflow_parameter_id,
        value=serialized_value,
    )
||||
|
||||
async def get_workflow_run_parameter_tuples(
    self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
    """Return (parameter definition, supplied run value) pairs for the run."""
    pairs = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
    return pairs
||||
|
||||
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
    """Return the most recent task of the workflow run, or None if it has none."""
    last_task = await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
    return last_task
||||
|
||||
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
    """Return every task that belongs to the given workflow run."""
    tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
    return tasks
||||
|
||||
async def build_workflow_run_status_response(
    self, workflow_id: str, workflow_run_id: str, organization_id: str
) -> WorkflowRunStatusResponse:
    """Assemble the status payload for a workflow run.

    Includes the run's state, its parameter values, share links for the last
    screenshot of up to three of the most recent tasks, and the recording
    share link when one exists.

    Raises:
        WorkflowNotFound: if the workflow does not exist.
        WorkflowOrganizationMismatch: if it belongs to another organization.
    """
    workflow = await self.get_workflow(workflow_id=workflow_id)
    if workflow is None:
        LOG.error(f"Workflow {workflow_id} not found")
        raise WorkflowNotFound(workflow_id=workflow_id)
    if workflow.organization_id != organization_id:
        LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
        raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)

    workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)

    # Walk tasks newest-first, collecting share links for each task's latest
    # screenshot until we have three.
    run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
    screenshot_urls: list[str] = []
    for task in reversed(run_tasks):
        if len(screenshot_urls) >= 3:
            break
        artifact = await app.DATABASE.get_latest_artifact(
            task_id=task.task_id,
            artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
            organization_id=organization_id,
        )
        if not artifact:
            continue
        url = await app.ARTIFACT_MANAGER.get_share_link(artifact)
        if url:
            screenshot_urls.append(url)

    recording_artifact = await app.DATABASE.get_artifact_for_workflow_run(
        workflow_run_id=workflow_run_id, artifact_type=ArtifactType.RECORDING, organization_id=organization_id
    )
    recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact) if recording_artifact else None

    parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
    parameters_with_value = {definition.key: run_value.value for definition, run_value in parameter_tuples}

    return WorkflowRunStatusResponse(
        workflow_id=workflow_id,
        workflow_run_id=workflow_run_id,
        status=workflow_run.status,
        proxy_location=workflow_run.proxy_location,
        webhook_callback_url=workflow_run.webhook_callback_url,
        created_at=workflow_run.created_at,
        modified_at=workflow_run.modified_at,
        parameters=parameters_with_value,
        screenshot_urls=screenshot_urls,
        recording_url=recording_url,
    )
||||
|
||||
async def send_workflow_response(
    self,
    workflow: Workflow,
    workflow_run: WorkflowRun,
    last_task: Task,
    api_key: str | None = None,
    close_browser_on_completion: bool = True,
) -> None:
    """Finalize a workflow run and deliver its status to the webhook callback.

    Cleans up the run's browser, persists the video/HAR artifacts, waits for
    pending S3 upload tasks (30s cap), then POSTs the signed run status to
    ``workflow_run.webhook_callback_url``. Skips the webhook (with a warning)
    when no callback url or api key is available.

    Raises:
        FailedToSendWebhook: when the HTTP call to the callback url fails.
    """
    browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
        workflow_run.workflow_run_id, close_browser_on_completion
    )
    if browser_state:
        await self.persist_video_data(browser_state, workflow, workflow_run)
        await self.persist_har_data(browser_state, last_task, workflow, workflow_run)

    # Wait for all tasks to complete before generating the links for the artifacts
    all_workflow_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(
        workflow_run_id=workflow_run.workflow_run_id
    )
    all_workflow_task_ids = [task.task_id for task in all_workflow_tasks]
    await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)

    try:
        # Wait for all tasks to complete. Currently we're using asyncio.create_task() only for uploading artifacts to S3.
        # We're excluding the current task from the list of tasks to wait for to prevent a deadlock.
        st = time.time()
        async with asyncio.timeout(30):
            # return_exceptions=True: a failing background task must not abort
            # the webhook delivery below — we only need the tasks to settle.
            await asyncio.gather(
                *[
                    aio_task
                    for aio_task in (asyncio.all_tasks() - {asyncio.current_task()})
                    if not aio_task.done()
                ],
                return_exceptions=True,
            )
        LOG.info(
            f"Waiting for all S3 uploads to complete took {time.time() - st} seconds",
            duration=time.time() - st,
        )
    except asyncio.TimeoutError:
        LOG.warning(
            "Timed out waiting for all S3 uploads to complete, not all artifacts may be uploaded. Waited 30 seconds.",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
        )

    if not workflow_run.webhook_callback_url:
        LOG.warning(
            "Workflow has no webhook callback url. Not sending workflow response",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
        )
        return

    if not api_key:
        LOG.warning(
            "Request has no api key. Not sending workflow response",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
        )
        return

    workflow_run_status_response = await self.build_workflow_run_status_response(
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        organization_id=workflow.organization_id,
    )
    # send task_response to the webhook callback url
    # TODO: use async requests (httpx)
    # NOTE(review): datetime.utcnow() is naive and deprecated in Python 3.12;
    # consider datetime.now(timezone.utc) once the module imports allow it.
    timestamp = str(int(datetime.utcnow().timestamp()))
    payload = workflow_run_status_response.model_dump_json()
    signature = generate_skyvern_signature(
        payload=payload,
        api_key=api_key,
    )
    headers = {
        "x-skyvern-timestamp": timestamp,
        "x-skyvern-signature": signature,
        "Content-Type": "application/json",
    }
    LOG.info(
        "Sending webhook run status to webhook callback url",
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        webhook_callback_url=workflow_run.webhook_callback_url,
        payload=payload,
        headers=headers,
    )
    try:
        resp = requests.post(workflow_run.webhook_callback_url, data=payload, headers=headers)
        if resp.ok:
            LOG.info(
                "Webhook sent successfully",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
                resp_code=resp.status_code,
                resp_text=resp.text,
            )
        else:
            # The error body may not be JSON (e.g. an HTML gateway error page).
            # An unguarded resp.json() here would raise and be misreported as
            # FailedToSendWebhook even though the webhook WAS delivered.
            try:
                resp_json = resp.json()
            except Exception:
                resp_json = None
            LOG.info(
                "Webhook failed",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
                resp=resp,
                resp_code=resp.status_code,
                resp_text=resp.text,
                resp_json=resp_json,
            )
    except Exception as e:
        raise FailedToSendWebhook(
            workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id
        ) from e
||||
|
||||
async def persist_video_data(
    self, browser_state: BrowserState, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
    """Upload the browser session's video recording (if any) to its artifact.

    Called after the browser is closed so the recording is complete.
    """
    video_data = await app.BROWSER_MANAGER.get_video_data(
        browser_state=browser_state,
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
    )
    if not video_data:
        return
    await app.ARTIFACT_MANAGER.update_artifact_data(
        artifact_id=browser_state.browser_artifacts.video_artifact_id,
        organization_id=workflow.organization_id,
        data=video_data,
    )
||||
|
||||
async def persist_har_data(
    self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
    """Store the run's HAR capture (if any) as an artifact on the latest step
    of the run's last task."""
    har_data = await app.BROWSER_MANAGER.get_har_data(
        workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
    )
    if not har_data:
        return
    last_step = await app.DATABASE.get_latest_step(
        task_id=last_task.task_id, organization_id=last_task.organization_id
    )
    if not last_step:
        return
    await app.ARTIFACT_MANAGER.create_artifact(
        step=last_step,
        artifact_type=ArtifactType.HAR,
        data=har_data,
    )
|
||||
Reference in New Issue
Block a user