Move the code over from private repository (#3)

This commit is contained in:
Kerem Yilmaz
2024-03-01 10:09:30 -08:00
committed by GitHub
parent 32dd6d92a5
commit 9eddb3d812
93 changed files with 16798 additions and 0 deletions

22
.env.example Normal file
View File

@@ -0,0 +1,22 @@
# Environment that the agent will run in.
ENV=local
# Your OpenAI API keys, given as a JSON list of strings. Keys will be used in order until the rate limit is reached for all keys
OPENAI_API_KEYS=["abc","def","ghi"]
# can be either "chromium-headless" or "chromium-headful".
BROWSER_TYPE="chromium-headful"
# number of times to retry scraping a page before giving up, currently set to 0
MAX_SCRAPING_RETRIES=0
# path to the directory where videos will be saved
VIDEO_PATH=./videos
# timeout for browser actions in milliseconds
BROWSER_ACTION_TIMEOUT_MS=5000
# maximum number of steps to execute per run unless the agent finishes with a terminal state (last step or error)
MAX_STEPS_PER_RUN=50
# Control log level
LOG_LEVEL=INFO
# Database connection string
DATABASE_STRING="postgresql+psycopg://skyvern-open-source@localhost/skyvern-open-source"
# Port to run the agent on
PORT=8000

13
.flake8 Normal file
View File

@@ -0,0 +1,13 @@
[flake8]
max-line-length = 88
select = E303, W293, W292, E305, E231, E302
exclude =
.tox,
__pycache__,
*.pyc,
.env
venv*/*,
.venv/*,
reports/*,
dist/*,
code,

169
.gitignore vendored Normal file
View File

@@ -0,0 +1,169 @@
## Original ignores
*.env
.vscode
.idea/*
log.txt
log-ingestion.txt
logs
*.log
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
develop-eggs/
dist/
plugins/
plugins_config.yaml
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
site/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.direnv/
.env
.venv
env/
venv*/
ENV/
env.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
llama-*
vicuna-*
# mac
.DS_Store
openai/
# news
CURRENT_BULLETIN.md
*.sqlite
.mypy_cache
.pytest_cache
.vscode
ig_*
# IntelliJ
.idea/
# Skyvern ignores
videos/
artifacts/
traces/
*.pkl
har/
# Streamlit ignores
**/secrets*.toml

87
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,87 @@
default_language_version:
python: python3.11
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-added-large-files
args: ['--maxkb=5000']
exclude: |
(?x)(
inputs.*
)
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- id: detect-private-key
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.11
- repo: https://github.com/psf/black
rev: 23.11.0
hooks:
- id: black
language_version: python3.11
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: python-check-blanket-noqa
- id: python-check-mock-methods
- id: python-no-log-warn
- id: python-use-type-annotations
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.0
hooks:
- id: mypy
args: [--show-error-codes, --warn-unused-configs, --disallow-untyped-calls, --disallow-untyped-defs, --disallow-incomplete-defs, --check-untyped-defs, --python-version=3.11]
additional_dependencies:
- requests
- types-requests
- types-cachetools
- alembic
exclude: |
(?x)(
^tests.*|
^streamlit_app.*|
^alembic.*
)
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.1
hooks:
- id: autoflake
name: autoflake
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports
language: python
types: [ python ]
# TODO: fix - the mono repo has broken this hook
# - id: pytest-check
# name: pytest-check
# entry: pytest
# language: system
# pass_filenames: false
# always_run: true
- repo: https://github.com/pre-commit/mirrors-prettier
rev: "v3.1.0" # Use the sha or tag you want to point at
hooks:
- id: prettier
types: [javascript]
- repo: https://github.com/thlorenz/doctoc
rev: v2.2.0
hooks:
- id: doctoc
- repo: local
hooks:
- id: alembic-check
name: Alembic Check
entry: ./run_alembic_check.sh
language: script
stages: [ commit ]

2
.prettierignore Normal file
View File

@@ -0,0 +1,2 @@
# Ignore chrome extensions
skyvern/extensions

5
.streamlit/config.toml Normal file
View File

@@ -0,0 +1,5 @@
[theme]
# The preset Streamlit theme that your custom theme inherits from.
# One of "light" or "dark".
base = "dark"

117
alembic.ini Normal file
View File

@@ -0,0 +1,117 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
timezone = UTC
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
; sqlalchemy.url = driver://user:pass@localhost/dbname
sqlalchemy.url = postgresql+psycopg://skyvern-open-source@localhost/skyvern-open-source
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

30
alembic/README.md Normal file
View File

@@ -0,0 +1,30 @@
<!-- START doctoc generated TOC please keep comment here to allow auto update -->
<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->
**Table of Contents** *generated with [DocToc](https://github.com/thlorenz/doctoc)*
- [Creating a new revision](#creating-a-new-revision)
- [Running migrations](#running-migrations)
- [Downgrading migrations](#downgrading-migrations)
- [Check your current alembic setup](#check-your-current-alembic-setup)
<!-- END doctoc generated TOC please keep comment here to allow auto update -->
# Creating a new revision
```
alembic revision --autogenerate -m "enter description here"
```
**Note:** Please read [What does Autogenerate Detect (and what does it not detect?)](https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect) and always make sure to review the generated revision file before running it.
# Running migrations
```
alembic upgrade head
```
# Downgrading migrations
```
alembic downgrade -1
```
# Check your current alembic setup
```
alembic current
```

81
alembic/env.py Normal file
View File

@@ -0,0 +1,81 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
# This is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging (sets up the loggers).
# Skipped when Alembic was started without a config file.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# MetaData object of the application's models, used for 'autogenerate' support.
from skyvern.forge.sdk.db import models

target_metadata = models.Base.metadata

# Other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
from skyvern.forge.sdk.settings_manager import SettingsManager

# Replace the sqlalchemy.url read from the .ini file with the DATABASE_STRING
# of the application settings, so migrations use the application's database.
config.set_main_option("sqlalchemy.url", SettingsManager.get_settings().DATABASE_STRING)
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The context is configured from the database URL alone, so no Engine is
    created and no DBAPI has to be available. Calls to context.execute()
    emit the given string to the script output.
    """
    offline_options = {
        "url": config.get_main_option("sqlalchemy.url"),
        "target_metadata": target_metadata,
        # Render bound parameters inline in the emitted SQL.
        "literal_binds": True,
        "dialect_opts": {"paramstyle": "named"},
    }
    context.configure(**offline_options)

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Builds an Engine from the ``sqlalchemy.``-prefixed options of the active
    config section, opens a connection and runs the migrations on it.
    """
    engine_options = config.get_section(config.config_ini_section, {})
    engine = engine_from_config(
        engine_options,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with engine.connect() as db_connection:
        context.configure(connection=db_connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
import sys

# Report the mode on stderr rather than stdout: in offline mode Alembic writes
# the generated SQL script to stdout, and an extra line there would end up in
# the middle of that script.
print(f"Alembic mode: {'offline' if context.is_offline_mode() else 'online'}", file=sys.stderr)

# Run the migrations in whichever mode Alembic was invoked with.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

26
alembic/script.py.mako Normal file
View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,284 @@
"""Create tables
Revision ID: 99423c1dec60
Revises:
Create Date: 2024-03-01 05:37:31.862957+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "99423c1dec60"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the initial schema: every table together with its indexes.

    Tables are created parents-first, so each foreign key refers to a table
    that already exists when the constraint is declared.
    """
    # ### commands auto generated by Alembic - please adjust! ###

    # organizations: referenced by auth tokens, workflows, tasks, steps and artifacts.
    op.create_table(
        "organizations",
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.Column("organization_name", sa.String(), nullable=False),
        sa.Column("webhook_callback_url", sa.UnicodeText(), nullable=True),
        sa.Column("max_steps_per_run", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.PrimaryKeyConstraint("organization_id"),
    )
    op.create_index(op.f("ix_organizations_organization_id"), "organizations", ["organization_id"], unique=False)

    # organization_auth_tokens: tokens belonging to an organization.
    # Declares the named enum type "organizationauthtokentype".
    op.create_table(
        "organization_auth_tokens",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.Column("token_type", sa.Enum("api", name="organizationauthtokentype"), nullable=False),
        sa.Column("token", sa.String(), nullable=False),
        sa.Column("valid", sa.Boolean(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.Column("deleted_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.organization_id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_organization_auth_tokens_id"), "organization_auth_tokens", ["id"], unique=False)
    op.create_index(
        op.f("ix_organization_auth_tokens_organization_id"),
        "organization_auth_tokens",
        ["organization_id"],
        unique=False,
    )
    op.create_index(op.f("ix_organization_auth_tokens_token"), "organization_auth_tokens", ["token"], unique=False)

    # workflows: optionally linked to an organization (organization_id is nullable).
    op.create_table(
        "workflows",
        sa.Column("workflow_id", sa.String(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=True),
        sa.Column("title", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("workflow_definition", sa.JSON(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.Column("deleted_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.organization_id"],
        ),
        sa.PrimaryKeyConstraint("workflow_id"),
    )
    op.create_index(op.f("ix_workflows_workflow_id"), "workflows", ["workflow_id"], unique=False)

    # aws_secret_parameters: per-workflow rows holding a key and an aws_key.
    op.create_table(
        "aws_secret_parameters",
        sa.Column("aws_secret_parameter_id", sa.String(), nullable=False),
        sa.Column("workflow_id", sa.String(), nullable=False),
        sa.Column("key", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("aws_key", sa.String(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.Column("deleted_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["workflow_id"],
            ["workflows.workflow_id"],
        ),
        sa.PrimaryKeyConstraint("aws_secret_parameter_id"),
    )
    op.create_index(
        op.f("ix_aws_secret_parameters_aws_secret_parameter_id"),
        "aws_secret_parameters",
        ["aws_secret_parameter_id"],
        unique=False,
    )
    op.create_index(
        op.f("ix_aws_secret_parameters_workflow_id"), "aws_secret_parameters", ["workflow_id"], unique=False
    )

    # workflow_parameters: per-workflow parameter definitions with an optional default value.
    op.create_table(
        "workflow_parameters",
        sa.Column("workflow_parameter_id", sa.String(), nullable=False),
        sa.Column("workflow_parameter_type", sa.String(), nullable=False),
        sa.Column("key", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("workflow_id", sa.String(), nullable=False),
        sa.Column("default_value", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.Column("deleted_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["workflow_id"],
            ["workflows.workflow_id"],
        ),
        sa.PrimaryKeyConstraint("workflow_parameter_id"),
    )
    op.create_index(op.f("ix_workflow_parameters_workflow_id"), "workflow_parameters", ["workflow_id"], unique=False)
    op.create_index(
        op.f("ix_workflow_parameters_workflow_parameter_id"),
        "workflow_parameters",
        ["workflow_parameter_id"],
        unique=False,
    )

    # workflow_runs: one row per run of a workflow.
    # Declares the named enum type "proxylocation" (also used by tasks below).
    op.create_table(
        "workflow_runs",
        sa.Column("workflow_run_id", sa.String(), nullable=False),
        sa.Column("workflow_id", sa.String(), nullable=False),
        sa.Column("status", sa.String(), nullable=False),
        sa.Column(
            "proxy_location",
            sa.Enum("US_CA", "US_NY", "US_TX", "US_FL", "US_WA", "RESIDENTIAL", "NONE", name="proxylocation"),
            nullable=True,
        ),
        sa.Column("webhook_callback_url", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(
            ["workflow_id"],
            ["workflows.workflow_id"],
        ),
        sa.PrimaryKeyConstraint("workflow_run_id"),
    )
    op.create_index(op.f("ix_workflow_runs_workflow_run_id"), "workflow_runs", ["workflow_run_id"], unique=False)

    # tasks: optionally linked to an organization and to a workflow run.
    op.create_table(
        "tasks",
        sa.Column("task_id", sa.String(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=True),
        sa.Column("status", sa.String(), nullable=True),
        sa.Column("webhook_callback_url", sa.String(), nullable=True),
        sa.Column("url", sa.String(), nullable=True),
        sa.Column("navigation_goal", sa.String(), nullable=True),
        sa.Column("data_extraction_goal", sa.String(), nullable=True),
        sa.Column("navigation_payload", sa.JSON(), nullable=True),
        sa.Column("extracted_information", sa.JSON(), nullable=True),
        sa.Column("failure_reason", sa.String(), nullable=True),
        sa.Column(
            "proxy_location",
            sa.Enum("US_CA", "US_NY", "US_TX", "US_FL", "US_WA", "RESIDENTIAL", "NONE", name="proxylocation"),
            nullable=True,
        ),
        sa.Column("extracted_information_schema", sa.JSON(), nullable=True),
        sa.Column("workflow_run_id", sa.String(), nullable=True),
        sa.Column("order", sa.Integer(), nullable=True),
        sa.Column("retry", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.organization_id"],
        ),
        sa.ForeignKeyConstraint(
            ["workflow_run_id"],
            ["workflow_runs.workflow_run_id"],
        ),
        sa.PrimaryKeyConstraint("task_id"),
    )
    op.create_index(op.f("ix_tasks_task_id"), "tasks", ["task_id"], unique=False)

    # workflow_run_parameters: value of a workflow parameter for a given run;
    # the primary key is the (workflow_run_id, workflow_parameter_id) pair.
    op.create_table(
        "workflow_run_parameters",
        sa.Column("workflow_run_id", sa.String(), nullable=False),
        sa.Column("workflow_parameter_id", sa.String(), nullable=False),
        sa.Column("value", sa.String(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(
            ["workflow_parameter_id"],
            ["workflow_parameters.workflow_parameter_id"],
        ),
        sa.ForeignKeyConstraint(
            ["workflow_run_id"],
            ["workflow_runs.workflow_run_id"],
        ),
        sa.PrimaryKeyConstraint("workflow_run_id", "workflow_parameter_id"),
    )
    op.create_index(
        op.f("ix_workflow_run_parameters_workflow_parameter_id"),
        "workflow_run_parameters",
        ["workflow_parameter_id"],
        unique=False,
    )
    op.create_index(
        op.f("ix_workflow_run_parameters_workflow_run_id"), "workflow_run_parameters", ["workflow_run_id"], unique=False
    )

    # steps: optionally linked to an organization and to a task; also stores
    # token counts and a step cost.
    op.create_table(
        "steps",
        sa.Column("step_id", sa.String(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=True),
        sa.Column("task_id", sa.String(), nullable=True),
        sa.Column("status", sa.String(), nullable=True),
        sa.Column("output", sa.JSON(), nullable=True),
        sa.Column("order", sa.Integer(), nullable=True),
        sa.Column("is_last", sa.Boolean(), nullable=True),
        sa.Column("retry_index", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.Column("input_token_count", sa.Integer(), nullable=True),
        sa.Column("output_token_count", sa.Integer(), nullable=True),
        sa.Column("step_cost", sa.Numeric(), nullable=True),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.organization_id"],
        ),
        sa.ForeignKeyConstraint(
            ["task_id"],
            ["tasks.task_id"],
        ),
        sa.PrimaryKeyConstraint("step_id"),
    )
    op.create_index(op.f("ix_steps_step_id"), "steps", ["step_id"], unique=False)

    # artifacts: optionally linked to an organization, a task and a step; stores a type and a URI.
    op.create_table(
        "artifacts",
        sa.Column("artifact_id", sa.String(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=True),
        sa.Column("task_id", sa.String(), nullable=True),
        sa.Column("step_id", sa.String(), nullable=True),
        sa.Column("artifact_type", sa.String(), nullable=True),
        sa.Column("uri", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.organization_id"],
        ),
        sa.ForeignKeyConstraint(
            ["step_id"],
            ["steps.step_id"],
        ),
        sa.ForeignKeyConstraint(
            ["task_id"],
            ["tasks.task_id"],
        ),
        sa.PrimaryKeyConstraint("artifact_id"),
    )
    op.create_index(op.f("ix_artifacts_artifact_id"), "artifacts", ["artifact_id"], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Remove everything created by ``upgrade()``.

    Indexes and tables are dropped children-first (the reverse of the creation
    order) so that no foreign key refers to a table that is already gone. The
    named enum types are dropped last, once no column uses them any more.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f("ix_artifacts_artifact_id"), table_name="artifacts")
    op.drop_table("artifacts")
    op.drop_index(op.f("ix_steps_step_id"), table_name="steps")
    op.drop_table("steps")
    op.drop_index(op.f("ix_workflow_run_parameters_workflow_run_id"), table_name="workflow_run_parameters")
    op.drop_index(op.f("ix_workflow_run_parameters_workflow_parameter_id"), table_name="workflow_run_parameters")
    op.drop_table("workflow_run_parameters")
    op.drop_index(op.f("ix_tasks_task_id"), table_name="tasks")
    op.drop_table("tasks")
    op.drop_index(op.f("ix_workflow_runs_workflow_run_id"), table_name="workflow_runs")
    op.drop_table("workflow_runs")
    op.drop_index(op.f("ix_workflow_parameters_workflow_parameter_id"), table_name="workflow_parameters")
    op.drop_index(op.f("ix_workflow_parameters_workflow_id"), table_name="workflow_parameters")
    op.drop_table("workflow_parameters")
    op.drop_index(op.f("ix_aws_secret_parameters_workflow_id"), table_name="aws_secret_parameters")
    op.drop_index(op.f("ix_aws_secret_parameters_aws_secret_parameter_id"), table_name="aws_secret_parameters")
    op.drop_table("aws_secret_parameters")
    op.drop_index(op.f("ix_workflows_workflow_id"), table_name="workflows")
    op.drop_table("workflows")
    op.drop_index(op.f("ix_organization_auth_tokens_token"), table_name="organization_auth_tokens")
    op.drop_index(op.f("ix_organization_auth_tokens_organization_id"), table_name="organization_auth_tokens")
    op.drop_index(op.f("ix_organization_auth_tokens_id"), table_name="organization_auth_tokens")
    op.drop_table("organization_auth_tokens")
    op.drop_index(op.f("ix_organizations_organization_id"), table_name="organizations")
    op.drop_table("organizations")
    # ### end Alembic commands ###

    # op.drop_table() only knows the table name, not its columns, so the named
    # enum types created alongside the tables in upgrade() are left behind on
    # PostgreSQL. Drop them explicitly; otherwise a later upgrade() fails
    # because the types already exist.
    # checkfirst=False keeps this usable in offline (--sql) mode, where no
    # existence query can be run.
    bind = op.get_bind()
    sa.Enum(name="proxylocation").drop(bind, checkfirst=False)
    sa.Enum(name="organizationauthtokentype").drop(bind, checkfirst=False)

10
mypy.ini Normal file
View File

@@ -0,0 +1,10 @@
[mypy]
namespace_packages = True
follow_imports = skip
check_untyped_defs = True
disallow_untyped_defs = True
exclude = ^(venv|venv-dev)
ignore_missing_imports = True
[mypy-numpy.*]
ignore_errors = True

6863
poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

92
pyproject.toml Normal file
View File

@@ -0,0 +1,92 @@
[tool.poetry]
name = "skyvern"
version = "0.1.0"
description = ""
authors = ["Skyvern AI <info@skyvern.com>"]
readme = "README.md"
packages = [{ include = "skyvern" }]
[tool.poetry.dependencies]
python = "^3.10"
python-dotenv = "^1.0.0"
openai = "<1.8"
tenacity = "^8.2.2"
sqlalchemy = "^2.0.23"
aiohttp = "^3.8.5"
colorlog = "^6.7.0"
chromadb = "^0.4.10"
python-multipart = "^0.0.6"
toml = "^0.10.2"
jinja2 = "^3.1.2"
uvicorn = {extras = ["standard"], version = "^0.24.0.post1"}
litellm = "^1.0.0"
duckduckgo-search = "^3.8.0"
selenium = "^4.13.0"
bs4 = "^0.0.1"
webdriver-manager = "^4.0.1"
playwright = "^1.39.0"
pre-commit = "^3.5.0"
pillow = "^10.1.0"
starlette-context = "^0.3.6"
sqlalchemy-stubs = "^0.4"
ddtrace = "^2.3.2"
pydantic = "^2.5.2"
pydantic-settings = "^2.1.0"
fastapi = "^0.104.1"
psycopg = {extras = ["binary", "pool"], version = "^3.1.13"}
alembic = "^1.12.1"
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
cachetools = "^5.3.2"
aioboto3 = "^12.0.0"
commentjson = "^0.9.0"
asyncache = "^0.3.1"
orjson = "^3.9.10"
structlog = "^23.2.0"
plotly = "^5.18.0"
[tool.poetry.group.dev.dependencies]
isort = "^5.12.0"
black = "^23.3.0"
pre-commit = "^3.3.3"
mypy = "^1.4.1"
flake8 = "^6.0.0"
types-requests = "^2.31.0.2"
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
watchdog = "^3.0.0"
mock = "^5.1.0"
autoflake = "^2.2.0"
pydevd-pycharm = "^233.6745.319"
ipython = "^8.17.2"
streamlit = "^1.28.1"
typer = "^0.9.0"
ipykernel = "^6.26.0"
notebook = "^7.0.6"
freezegun = "^1.2.2"
snoop = "^0.4.3"
rich = {extras = ["jupyter"], version = "^13.7.0"}
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.black]
line-length = 120
target-version = ['py311']
include = '\.pyi?$'
packages = []
extend-exclude = '(/dist|/.venv|/venv|/build)/'
[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 120
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
skip_glob = [".tox", "__pycache__", "*.pyc", "venv*/*", "reports", "venv", "env", "node_modules", ".env", ".venv", "dist"]
skip = ["webeye/actions/__init__.py", "forge/sdk/__init__.py"]

9
run Executable file
View File

@@ -0,0 +1,9 @@
#!/bin/bash
# Start the Skyvern server: free port 8000, make sure a .env file exists,
# then launch the skyvern.forge module through poetry.

# Stop whatever is currently listening on port 8000. Only call kill when lsof
# actually reported a process; a bare `kill` just prints a usage error.
pids=$(lsof -t -i :8000)
if [ -n "$pids" ]; then
    # Intentionally unquoted: lsof prints one PID per line.
    kill $pids
fi

# Create a .env from the example on first run.
if [ ! -f .env ]; then
    cp .env.example .env
    echo "Please add your api keys to the .env file."
fi

poetry run python -m skyvern.forge

6
run_alembic_check.sh Executable file
View File

@@ -0,0 +1,6 @@
#!/bin/sh
# Apply all migrations, then verify that the database matches the models.
# Exit at the first failing command so a failed upgrade fails the hook
# instead of being followed by a check against a half-migrated database.
set -e

# make the virtualenv's alembic executable reachable
export PATH=$PATH:.venv/bin

# first apply migrations
alembic upgrade head

# then check if the database is up to date with the models
alembic check

47
scripts/create_api_key.py Normal file
View File

@@ -0,0 +1,47 @@
import asyncio
from datetime import timedelta
import typer
from skyvern.forge.app import DATABASE
from skyvern.forge.sdk.core import security
from skyvern.forge.sdk.models import OrganizationAuthToken, OrganizationAuthTokenType
# Lifetime given to generated API keys: 5200 weeks (roughly 100 years).
API_KEY_LIFETIME = timedelta(weeks=5200)


async def create_org_api_token(org_id: str) -> OrganizationAuthToken:
    """Create and store an API token for an existing organization.

    Args:
        org_id: The id of the organization that receives the token.

    Returns:
        The OrganizationAuthToken record created for the organization.

    Raises:
        Exception: If no organization with ``org_id`` is found.
    """
    # Refuse to issue tokens for organizations that do not exist.
    if not await DATABASE.get_organization(org_id):
        raise Exception(f"Organization id {org_id} not found")

    # Sign an access token for the organization ...
    access_token = security.create_access_token(org_id, expires_delta=API_KEY_LIFETIME)

    # ... and persist it as an OrganizationAuthToken of type "api".
    auth_token_record = await DATABASE.create_org_auth_token(
        organization_id=org_id,
        token=access_token,
        token_type=OrganizationAuthTokenType.api,
    )

    # setup.sh extracts the token from this printed line, so keep its format.
    print(f"Created API token for organization {auth_token_record}")
    return auth_token_record
def main(org_id: str) -> None:
    """Command-line entry point: create an API token for ``org_id``."""
    token_coroutine = create_org_api_token(org_id)
    asyncio.run(token_coroutine)


# Let typer build the command-line interface from main()'s signature.
if __name__ == "__main__":
    typer.run(main)

View File

@@ -0,0 +1,21 @@
import asyncio
from typing import Annotated, Optional
import typer
from scripts.create_api_key import create_org_api_token
from skyvern.forge.app import DATABASE
async def create_org(org_name: str, webhook_callback_url: str | None = None) -> None:
    """Create an organization and an API token for it.

    Args:
        org_name: Name of the organization to create.
        webhook_callback_url: Optional webhook URL passed on to the database layer.
    """
    new_org = await DATABASE.create_organization(org_name, webhook_callback_url)
    print(f"Created organization: {new_org}")

    # Every new organization immediately receives its own API token.
    new_org_id = new_org.organization_id
    await create_org_api_token(new_org_id)
def main(org_name: str, webhook_callback_url: Annotated[Optional[str], typer.Argument()] = None) -> None:
    """Command-line entry point: create the organization named ``org_name``."""
    org_coroutine = create_org(org_name, webhook_callback_url)
    asyncio.run(org_coroutine)


# Let typer build the command-line interface from main()'s signature.
if __name__ == "__main__":
    typer.run(main)

92
setup.sh Executable file
View File

@@ -0,0 +1,92 @@
#!/bin/bash
# Succeed (exit status 0) when the command named by $1 can be found by the shell.
command_exists() {
    command -v "$1" > /dev/null 2>&1
}
# Abort before doing any work if a tool this script relies on is missing
for required_cmd in poetry pre-commit brew python; do
    command_exists "$required_cmd" && continue
    echo "Error: $required_cmd is not installed." >&2
    exit 1
done
# Delete the virtualenv that poetry reports for this project, if there is one
remove_poetry_env() {
    local poetry_env_dir
    poetry_env_dir=$(poetry env info --path)

    # Nothing to remove when poetry reports no (existing) environment directory.
    if [ ! -d "$poetry_env_dir" ]; then
        echo "No poetry environment found."
        return
    fi

    rm -rf "$poetry_env_dir"
    echo "Removed the poetry environment at $poetry_env_dir."
}
# Install the project's dependencies with poetry, then the pre-commit git hooks.
# Returns non-zero as soon as one of the two installs fails, so a failed
# `poetry install` is no longer hidden by the exit status of the second command.
install_dependencies() {
    poetry install && pre-commit install
}
# Activate poetry's virtualenv in the current shell so that later commands
# (alembic, python) run inside it.
activate_poetry_env() {
    local venv_dir
    venv_dir="$(poetry env info --path)"
    source "$venv_dir/bin/activate"
}
# Install and start PostgreSQL 14 through brew, then make sure the
# skyvern-open-source role and database exist.
# Returns 1 when the database still cannot be reached after trying to create it.
setup_postgresql() {
    echo "Installing postgresql using brew"
    brew install postgresql@14
    brew services start postgresql@14

    # A successful connection means both the role and the database are present.
    if psql skyvern-open-source -U skyvern-open-source -c '\q'; then
        echo "Connection successful. Database and user exist."
        return 0
    fi

    # Run both commands unconditionally: the role may already exist while the
    # database does not, in which case createuser fails but createdb is still needed.
    createuser skyvern-open-source
    createdb skyvern-open-source -O skyvern-open-source

    # Report success only if the database is actually reachable now.
    if psql skyvern-open-source -U skyvern-open-source -c '\q'; then
        echo "Database and user created successfully."
    else
        echo "Error: could not create or connect to the skyvern-open-source database." >&2
        return 1
    fi
}
# Bring the database schema up to the latest Alembic revision
run_alembic_upgrade() {
    printf '%s\n' "Running Alembic upgrade..."
    alembic upgrade head
}
# Create the Skyvern-Open-Source organization with an API token and store the
# token in .streamlit/secrets.toml.
# Returns 1, leaving any existing secrets.toml untouched, when the organization
# script fails or no token can be found in its output.
create_organization() {
    echo "Creating organization and API token..."
    local org_output api_token

    # `local` is declared separately above so that the exit status tested here
    # is the python script's own.
    if ! org_output=$(python scripts/create_organization.py Skyvern-Open-Source); then
        echo "Error: failed to create the organization." >&2
        return 1
    fi

    # The script prints the created token record; keep only the text between
    # token=' and the closing quote.
    api_token=$(echo "$org_output" | awk '/token=/{gsub(/.*token='\''|'\''.*/, ""); print}')
    if [ -z "$api_token" ]; then
        echo "Error: could not find an API token in the organization script output." >&2
        return 1
    fi

    # Ensure .streamlit directory exists
    mkdir -p .streamlit

    # Check if secrets.toml exists and back it up
    if [ -f ".streamlit/secrets.toml" ]; then
        mv .streamlit/secrets.toml .streamlit/secrets.backup.toml
        echo "Existing secrets.toml file backed up as secrets.backup.toml"
    fi

    # Write the new .streamlit/secrets.toml file
    echo -e "[skyvern]\nconfigs = [\n {\"env\" = \"local\", \"host\" = \"http://0.0.0.0:8000/api/v1\", \"orgs\" = [{name=\"Skyvern-Open-Source\", cred=\"$api_token\"}]}\n]" > .streamlit/secrets.toml
    echo ".streamlit/secrets.toml file updated with organization details."
}
# Main function: run every setup step in order and stop at the first failure,
# so "Setup completed successfully." is only printed when that is true.
main() {
    local step

    # The migrations take their database URL from the application settings,
    # which read .env. Create it from the example when missing so that they
    # target the skyvern-open-source database this script sets up.
    if [ ! -f .env ]; then
        cp .env.example .env
        echo "Created .env from .env.example. Please add your api keys to the .env file."
    fi

    # Removing the old environment always reports success, so it is not checked.
    remove_poetry_env

    for step in install_dependencies setup_postgresql activate_poetry_env run_alembic_upgrade create_organization; do
        if ! "$step"; then
            echo "Error: setup step '$step' failed." >&2
            exit 1
        fi
    done

    echo "Setup completed successfully."
}

# Execute main function
main

0
skyvern/__init__.py Normal file
View File

61
skyvern/config.py Normal file
View File

@@ -0,0 +1,61 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from skyvern.constants import SKYVERN_DIR
class Settings(BaseSettings):
    """Agent configuration; values are read from the environment and the listed .env files."""

    # Unknown keys found in the env files are ignored rather than rejected.
    model_config = SettingsConfigDict(env_file=(".env", ".env.staging", ".env.prod"), extra="ignore")

    ADDITIONAL_MODULES: list[str] = []

    BROWSER_TYPE: str = "chromium-headful"
    MAX_SCRAPING_RETRIES: int = 0
    VIDEO_PATH: str | None = None
    HAR_PATH: str | None = "./har"
    BROWSER_ACTION_TIMEOUT_MS: int = 5000
    MAX_STEPS_PER_RUN: int = 75
    MAX_NUM_SCREENSHOTS: int = 6
    # Ratio should be between 0 and 1.
    # If the task has been running for more steps than this ratio of the max steps per run, then we'll log a warning.
    LONG_RUNNING_TASK_WARNING_RATIO: float = 0.95
    MAX_RETRIES_PER_STEP: int = 5
    DEBUG_MODE: bool = False
    DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern"
    PROMPT_ACTION_HISTORY_WINDOW: int = 5
    OPENAI_API_KEYS: list[str] = []

    ENV: str = "local"
    EXECUTE_ALL_STEPS: bool = True
    JSON_LOGGING: bool = False
    PORT: int = 8000

    # NOTE(review): this default secret is committed with the source; presumably it is meant to be
    # overridden through the environment outside local development - confirm.
    SECRET_KEY: str = "RX1NvhujcJqBPi8O78-7aSfJEWuT86-fll4CzKc_uek"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7  # one week

    SKYVERN_API_KEY: str = "SKYVERN_API_KEY"

    # Artifact storage settings
    ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts"

    ASYNC_ENABLED: bool = False

    def is_cloud_environment(self) -> bool:
        """
        :return: True if env is not local, else False
        """
        running_locally = self.ENV == "local"
        return not running_locally

    def execute_all_steps(self) -> bool:
        """
        Tell whether steps should be executed back to back.

        Step-by-step execution (used by the Streamlit UI) is only possible locally:
        ***Value is always True if ENV is not local.***

        :return: True if env is not local, else the value of EXECUTE_ALL_STEPS
        """
        return True if self.is_cloud_environment() else self.EXECUTE_ALL_STEPS
settings = Settings()

6
skyvern/constants.py Normal file
View File

@@ -0,0 +1,6 @@
from pathlib import Path
# This is the attribute name used to tag interactable elements
SKYVERN_ID_ATTR: str = "unique_id"
# Directory that contains this module.
SKYVERN_DIR = Path(__file__).parent
# Parent directory of SKYVERN_DIR.
REPO_ROOT_DIR = SKYVERN_DIR.parent

181
skyvern/exceptions.py Normal file
View File

@@ -0,0 +1,181 @@
class SkyvernException(Exception):
def __init__(self, message: str | None = None):
self.message = message
super().__init__(message)
class NoAvailableOpenAIClients(SkyvernException):
    """Raised when no OpenAI API client is available.

    The ``message`` argument is accepted but not used; the text is fixed.
    """

    def __init__(self, message: str | None = None):
        detail = "No available OpenAI API clients found."
        super().__init__(detail)
class InvalidOpenAIResponseFormat(SkyvernException):
    """Raised when an OpenAI response does not have the expected format."""

    def __init__(self, message: str | None = None):
        detail = f"Invalid response format: {message}"
        super().__init__(detail)
class OpenAIRequestTooBigError(SkyvernException):
    """Raised for an OpenAI request that failed with a 429 error."""

    def __init__(self, message: str | None = None):
        detail = f"OpenAI request 429 error: {message}"
        super().__init__(detail)
class FailedToSendWebhook(SkyvernException):
    """Raised when a webhook could not be sent.

    The message lists whichever of ``workflow_run_id``, ``workflow_id`` and ``task_id`` were given.
    """

    def __init__(self, task_id: str | None = None, workflow_run_id: str | None = None, workflow_id: str | None = None):
        identifiers = (
            ("workflow_run_id", workflow_run_id),
            ("workflow_id", workflow_id),
            ("task_id", task_id),
        )
        # Only the ids that were provided are mentioned, so the message carries no doubled or
        # trailing spaces where an id is absent.
        parts = ["Failed to send webhook."]
        parts.extend(f"{label}={value}" for label, value in identifiers if value)
        super().__init__(" ".join(parts))
class ProxyLocationNotSupportedError(SkyvernException):
    """Raised for a proxy location that is not known."""

    def __init__(self, proxy_location: str | None = None):
        detail = f"Unknown proxy location: {proxy_location}"
        super().__init__(detail)
class TaskNotFound(SkyvernException):
    """Raised when a task cannot be found by its id."""

    def __init__(self, task_id: str | None = None):
        detail = f"Task {task_id} not found"
        super().__init__(detail)
class ScriptNotFound(SkyvernException):
    """Raised when a script cannot be found by name."""

    def __init__(self, script_name: str | None = None):
        detail = f"Script {script_name} not found. Has the script been registered?"
        super().__init__(detail)
class MissingElement(SkyvernException):
    """Raised when no element was found for the given xpath / element id."""

    def __init__(self, xpath: str | None = None, element_id: int | None = None):
        explanation = "Found no elements. Might be due to previous actions which removed this element."
        lookup = f" xpath={xpath} element_id={element_id}"
        super().__init__(explanation + lookup)
class MultipleElementsFound(SkyvernException):
    """Raised when a lookup that should match exactly one element matched ``num`` elements."""

    def __init__(self, num: int, xpath: str | None = None, element_id: int | None = None):
        detail = f"Found {num} elements. Expected 1. num_elements={num} xpath={xpath} element_id={element_id}"
        super().__init__(detail)
class MissingFileUrl(SkyvernException):
    """Raised when a file url is required but missing."""

    def __init__(self) -> None:
        detail = "File url is missing."
        super().__init__(detail)
class ImaginaryFileUrl(SkyvernException):
    """Raised for a file url that is reported as imaginary."""

    def __init__(self, file_url: str) -> None:
        detail = f"File url {file_url} is imaginary."
        super().__init__(detail)
class MissingBrowserState(SkyvernException):
    """Raised when the browser state of a task is missing."""

    def __init__(self, task_id: str) -> None:
        detail = f"Browser state for task {task_id} is missing."
        super().__init__(detail)
class MissingBrowserStatePage(SkyvernException):
    """Raised when a browser state has no page.

    The message lists whichever of ``task_id`` and ``workflow_run_id`` were given.
    """

    def __init__(self, task_id: str | None = None, workflow_run_id: str | None = None):
        identifiers = (("task_id", task_id), ("workflow_run_id", workflow_run_id))
        # Only the ids that were provided are mentioned, so the message carries no doubled or
        # trailing spaces where an id is absent.
        parts = ["Browser state page is missing."]
        parts.extend(f"{label}={value}" for label, value in identifiers if value)
        super().__init__(" ".join(parts))
class MissingWorkflowRunBrowserState(SkyvernException):
    """Raised when the browser state for a workflow run and task is missing."""

    def __init__(self, workflow_run_id: str, task_id: str) -> None:
        detail = f"Browser state for workflow run {workflow_run_id} and task {task_id} is missing."
        super().__init__(detail)
class CaptchaNotSolvedInTime(SkyvernException):
    """Raised when a captcha was not solved in time; the message includes the final state."""

    def __init__(self, task_id: str, final_state: str) -> None:
        detail = f"Captcha not solved in time for task {task_id}. Final state: {final_state}"
        super().__init__(detail)
class EnablingCaptchaSolver(SkyvernException):
    """Raised to signal that the captcha solver is being enabled and the page should be reloaded."""

    def __init__(self) -> None:
        detail = "Enabling captcha solver. Reload the page and try again."
        super().__init__(detail)
class ContextParameterValueNotFound(SkyvernException):
    """Raised when a context parameter has no value during a workflow run.

    The message includes the keys that do exist.
    """

    def __init__(self, parameter_key: str, existing_keys: list[str], workflow_run_id: str) -> None:
        where = f"Context parameter value not found during workflow run {workflow_run_id}. "
        what = f"Parameter key: {parameter_key}. Existing keys: {existing_keys}"
        super().__init__(where + what)
class UnknownBlockType(SkyvernException):
    """Raised for a block type that is not known."""

    def __init__(self, block_type: str) -> None:
        detail = f"Unknown block type {block_type}"
        super().__init__(detail)
class WorkflowNotFound(SkyvernException):
    """Raised when a workflow cannot be found by its id."""

    def __init__(self, workflow_id: str) -> None:
        detail = f"Workflow {workflow_id} not found"
        super().__init__(detail)
class WorkflowRunNotFound(SkyvernException):
    """Raised when a workflow run cannot be found by its id."""

    def __init__(self, workflow_run_id: str) -> None:
        detail = f"WorkflowRun {workflow_run_id} not found"
        super().__init__(detail)
class WorkflowOrganizationMismatch(SkyvernException):
    """Raised when a workflow does not belong to the given organization."""

    def __init__(self, workflow_id: str, organization_id: str) -> None:
        detail = f"Workflow {workflow_id} does not belong to organization {organization_id}"
        super().__init__(detail)
class MissingValueForParameter(SkyvernException):
    """Raised when a workflow run has no value for one of the workflow's parameters."""

    def __init__(self, parameter_key: str, workflow_id: str, workflow_run_id: str) -> None:
        detail = (
            f"Missing value for parameter {parameter_key} in workflow run {workflow_run_id} of workflow {workflow_id}"
        )
        super().__init__(detail)
class WorkflowParameterNotFound(SkyvernException):
    """Raised when a workflow parameter cannot be found by its id."""

    def __init__(self, workflow_parameter_id: str) -> None:
        detail = f"Workflow parameter {workflow_parameter_id} not found"
        super().__init__(detail)
class FailedToNavigateToUrl(SkyvernException):
    """Raised when navigating to a url failed; the message includes the underlying error message."""

    def __init__(self, url: str, error_message: str) -> None:
        detail = f"Failed to navigate to url {url}. Error message: {error_message}"
        super().__init__(detail)
class UnexpectedTaskStatus(SkyvernException):
    """Raised when a task is in a status that was not expected."""

    def __init__(self, task_id: str, status: str) -> None:
        detail = f"Unexpected task status {status} for task {task_id}"
        super().__init__(detail)
class InvalidWorkflowTaskURLState(SkyvernException):
    """Raised when no valid URL is available for the first task of a workflow run."""

    def __init__(self, workflow_run_id: str) -> None:
        # The run id is part of the message so the failure can be traced to a specific run.
        super().__init__(f"No Valid URL found in the first task. workflow_run_id={workflow_run_id}")
class DisabledFeature(SkyvernException):
    """Raised when a disabled feature is used."""

    def __init__(self, feature: str) -> None:
        detail = f"Feature {feature} is disabled"
        super().__init__(detail)
class UnknownBrowserType(SkyvernException):
    """Raised for a browser type that is not known."""

    def __init__(self, browser_type: str) -> None:
        detail = f"Unknown browser type {browser_type}"
        super().__init__(detail)
class UnknownErrorWhileCreatingBrowserContext(SkyvernException):
    """Raised for an unexpected failure while creating a browser context.

    The message carries the type and text of the underlying exception.
    """

    def __init__(self, browser_type: str, exception: Exception) -> None:
        context = f"Unknown error while creating browser context for {browser_type}. "
        cause = f"Exception type: {type(exception)} Exception message: {exception!s}"
        super().__init__(context + cause)
class BrowserStateMissingPage(SkyvernException):
    """Raised when a BrowserState has no main page."""

    def __init__(self) -> None:
        detail = "BrowserState is missing the main page"
        super().__init__(detail)
class OrganizationNotFound(SkyvernException):
    """Raised when an organization cannot be found by its id."""

    def __init__(self, organization_id: str) -> None:
        detail = f"Organization {organization_id} not found"
        super().__init__(detail)
class StepNotFound(SkyvernException):
    """Raised when a step cannot be found; without a ``step_id`` the message says 'latest'."""

    def __init__(self, organization_id: str, task_id: str, step_id: str | None = None) -> None:
        detail = f"Step {step_id or 'latest'} not found. organization_id={organization_id} task_id={task_id}"
        super().__init__(detail)

View File

18
skyvern/forge/__main__.py Normal file
View File

@@ -0,0 +1,18 @@
import structlog
import uvicorn
from dotenv import load_dotenv
import skyvern.forge.sdk.forge_log as forge_log
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.stdlib.get_logger()
if __name__ == "__main__":
    # Entry point: configure logging, then serve the agent app with uvicorn on all interfaces.
    forge_log.setup_logger()
    bind_host = "0.0.0.0"
    port = SettingsManager.get_settings().PORT
    LOG.info("Agent server starting.", host=bind_host, port=port)
    load_dotenv()
    # Auto-reload is switched on only when ENV is "local".
    use_reload = SettingsManager.get_settings().ENV == "local"
    uvicorn.run("skyvern.forge.app:app", host=bind_host, port=port, log_level="info", reload=use_reload)

985
skyvern/forge/agent.py Normal file
View File

@@ -0,0 +1,985 @@
import asyncio
import json
import random
from datetime import datetime
from typing import Any, Tuple
import requests
import structlog
from playwright._impl._errors import TargetClosedError
from skyvern.exceptions import (
BrowserStateMissingPage,
FailedToSendWebhook,
InvalidWorkflowTaskURLState,
MissingBrowserStatePage,
TaskNotFound,
)
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.agent import Agent
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.models import Organization, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.workflow.models.block import TaskBlock
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
from skyvern.webeye.actions.actions import Action, ActionType, CompleteAction, parse_actions
from skyvern.webeye.actions.handler import ActionHandler
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
LOG = structlog.get_logger()
class ForgeAgent(Agent):
def __init__(self) -> None:
    """Log the effective agent settings and import every module listed in ADDITIONAL_MODULES."""
    settings = SettingsManager.get_settings()
    LOG.info(
        "Initializing ForgeAgent",
        env=settings.ENV,
        execute_all_steps=settings.EXECUTE_ALL_STEPS,
        browser_type=settings.BROWSER_TYPE,
        max_scraping_retries=settings.MAX_SCRAPING_RETRIES,
        video_path=settings.VIDEO_PATH,
        browser_action_timeout_ms=settings.BROWSER_ACTION_TIMEOUT_MS,
        max_steps_per_run=settings.MAX_STEPS_PER_RUN,
        long_running_task_warning_ratio=settings.LONG_RUNNING_TASK_WARNING_RATIO,
        debug_mode=settings.DEBUG_MODE,
    )
    extra_modules = settings.ADDITIONAL_MODULES
    if not extra_modules:
        return
    for module_name in extra_modules:
        LOG.info("Loading additional module", module=module_name)
        __import__(module_name)
    # Settings are looked up again here because the imports above are arbitrary code.
    LOG.info("Additional modules loaded", modules=SettingsManager.get_settings().ADDITIONAL_MODULES)
async def validate_step_execution(
    self,
    task: Task,
    step: Step,
) -> None:
    """
    Verify that the step is allowed to run; raise if it is not.

    A step may run only when the task is running, the step itself is in the created or failed
    state, and no step of the task is currently running.

    :raises Exception: listing every reason the step cannot be executed.
    """
    blockers = []

    # can't execute if task status is not running
    if task.status != TaskStatus.running:
        blockers.append(f"invalid_task_status:{task.status}")

    # can't execute if the step is already running or completed
    if step.status not in [StepStatus.created, StepStatus.failed]:
        blockers.append(f"invalid_step_status:{step.status}")

    # can't execute if the task has another step that is running
    task_steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
    for task_step in task_steps:
        if task_step.status == StepStatus.running:
            blockers.append(f"another_step_is_running_for_task:{task.task_id}")
            break

    if blockers:
        raise Exception(f"Cannot execute step. Reasons: {blockers}, Step: {step}")
async def create_task_and_step_from_block(
    self,
    task_block: TaskBlock,
    workflow: Workflow,
    workflow_run: WorkflowRun,
    context_manager: ContextManager,
    task_order: int,
    task_retry: int,
) -> tuple[Task, Step]:
    """
    Create a task for a workflow task block, mark it running and create its first step.

    When the block has no url, the url currently open in the workflow run's browser is used.

    :raises MissingBrowserStatePage: the workflow run's browser state has no page.
    :raises InvalidWorkflowTaskURLState: the browser is still on "about:blank".
    :return: the running task and its first step (order 0, retry index 0).
    """
    # Resolve every block parameter to its value in the workflow context.
    navigation_payload = {
        parameter.key: context_manager.get_value(parameter.key) for parameter in task_block.parameters
    }

    task_url = task_block.url
    if task_url is None:
        browser_state = await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)
        current_page = browser_state.page
        if not current_page:
            LOG.error("BrowserState has no page", workflow_run_id=workflow_run.workflow_run_id)
            raise MissingBrowserStatePage(workflow_run_id=workflow_run.workflow_run_id)
        if current_page.url == "about:blank":
            raise InvalidWorkflowTaskURLState(workflow_run.workflow_run_id)
        task_url = current_page.url

    created_task = await app.DATABASE.create_task(
        url=task_url,
        webhook_callback_url=None,
        navigation_goal=task_block.navigation_goal,
        data_extraction_goal=task_block.data_extraction_goal,
        navigation_payload=navigation_payload,
        organization_id=workflow.organization_id,
        proxy_location=workflow_run.proxy_location,
        extracted_information_schema=task_block.data_schema,
        workflow_run_id=workflow_run.workflow_run_id,
        order=task_order,
        retry=task_retry,
    )
    LOG.info(
        "Created new task for workflow run",
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        task_id=created_task.task_id,
        url=created_task.url,
        nav_goal=created_task.navigation_goal,
        data_goal=created_task.data_extraction_goal,
        proxy_location=created_task.proxy_location,
        task_order=task_order,
        task_retry=task_retry,
    )

    # Update task status to running
    running_task = await app.DATABASE.update_task(
        task_id=created_task.task_id, organization_id=created_task.organization_id, status=TaskStatus.running
    )

    first_step = await app.DATABASE.create_step(
        running_task.task_id,
        order=0,
        retry_index=0,
        organization_id=running_task.organization_id,
    )
    LOG.info(
        "Created new step for workflow run",
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        step_id=first_step.step_id,
        task_id=running_task.task_id,
        order=first_step.order,
        retry_index=first_step.retry_index,
    )
    return running_task, first_step
async def create_task(self, task_request: TaskRequest, organization_id: str | None = None) -> Task:
    """
    Store a new task built from the request's fields and log its creation.

    :return: the task returned by the database layer.
    """
    new_task = await app.DATABASE.create_task(
        url=task_request.url,
        webhook_callback_url=task_request.webhook_callback_url,
        navigation_goal=task_request.navigation_goal,
        data_extraction_goal=task_request.data_extraction_goal,
        navigation_payload=task_request.navigation_payload,
        organization_id=organization_id,
        proxy_location=task_request.proxy_location,
        extracted_information_schema=task_request.extracted_information_schema,
    )
    LOG.info(
        "Created new task",
        task_id=new_task.task_id,
        url=new_task.url,
        nav_goal=new_task.navigation_goal,
        data_goal=new_task.data_extraction_goal,
        proxy_location=new_task.proxy_location,
    )
    return new_task
async def execute_step(
    self,
    organization: Organization,
    task: Task,
    step: Step,
    api_key: str | None = None,
    workflow_run: WorkflowRun | None = None,
    close_browser_on_completion: bool = True,
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
    """
    Execute a step and, depending on its outcome, the steps that follow it.

    A failed step is retried when handle_failed_step provides a retry step. A completed step is
    followed by the next step when continuous execution is enabled. Once the task is over, the
    task response is sent.

    :param organization: organization passed on to agent_step and handle_completed_step.
    :param task: task the step belongs to.
    :param step: step to execute.
    :param api_key: passed on to send_task_response.
    :param workflow_run: when given, the browser state of this workflow run is used.
    :param close_browser_on_completion: passed on to send_task_response.
    :return: the (last) executed step, its detailed output if any, and the next step that was
        created but not executed (None when there is none).
    """
    next_step: Step | None = None
    detailed_output: DetailedAgentStepOutput | None = None
    try:
        # Check some conditions before executing the step, throw an exception if the step can't be executed
        await self.validate_step_execution(task, step)
        step, browser_state, detailed_output = await self._initialize_execution_state(task, step, workflow_run)
        step, detailed_output = await self.agent_step(task, step, browser_state, organization=organization)
        retry = False

        # If the step failed, mark the step as failed and retry
        if step.status == StepStatus.failed:
            maybe_next_step = await self.handle_failed_step(task, step)
            # If there is no next step, it means that the task has failed
            if maybe_next_step:
                next_step = maybe_next_step
                retry = True
            else:
                await self.send_task_response(
                    task=task,
                    last_step=step,
                    api_key=api_key,
                    close_browser_on_completion=close_browser_on_completion,
                )
                return step, detailed_output, None
        elif step.status == StepStatus.completed:
            # TODO (kerem): keep the task object uptodate at all times so that send_task_response can just use it
            is_task_completed, maybe_last_step, maybe_next_step = await self.handle_completed_step(
                organization, task, step
            )
            if is_task_completed is not None and maybe_last_step:
                last_step = maybe_last_step
                await self.send_task_response(
                    task=task,
                    last_step=last_step,
                    api_key=api_key,
                    close_browser_on_completion=close_browser_on_completion,
                )
                return last_step, detailed_output, None
            elif maybe_next_step:
                next_step = maybe_next_step
                retry = False
            else:
                LOG.error(
                    "Step completed but task is not completed and next step is not created.",
                    task_id=task.task_id,
                    step_id=step.step_id,
                    is_task_completed=is_task_completed,
                    maybe_last_step=maybe_last_step,
                    maybe_next_step=maybe_next_step,
                )
        else:
            LOG.error(
                "Unexpected step status after agent_step",
                task_id=task.task_id,
                step_id=step.step_id,
                step_status=step.status,
            )

        # Run the follow-up step when it is a retry of a failed step, or when continuous
        # execution is enabled. The workflow run and the browser-closing preference are
        # forwarded so the follow-up step runs with the same browser context and cleanup
        # behaviour as this one instead of falling back to the defaults.
        if next_step and (retry or SettingsManager.get_settings().execute_all_steps()):
            return await self.execute_step(
                organization,
                task,
                next_step,
                api_key=api_key,
                workflow_run=workflow_run,
                close_browser_on_completion=close_browser_on_completion,
            )

        LOG.info(
            "Step executed but continuous execution is disabled.",
            task_id=task.task_id,
            step_id=step.step_id,
            is_cloud_env=SettingsManager.get_settings().is_cloud_environment(),
            execute_all_steps=SettingsManager.get_settings().execute_all_steps(),
            next_step_id=next_step.step_id if next_step else None,
        )
        return step, detailed_output, next_step
    # TODO (kerem): Let's add other exceptions that we know about here as custom exceptions as well
    except FailedToSendWebhook:
        LOG.exception(
            "Failed to send webhook",
            exc_info=True,
            task_id=task.task_id,
            step_id=step.step_id,
            task=task,
            step=step,
        )
        return step, detailed_output, next_step
async def agent_step(
    self,
    task: Task,
    step: Step,
    browser_state: BrowserState,
    organization: Organization | None = None,
) -> tuple[Step, DetailedAgentStepOutput]:
    """
    Run one step: build the prompt, get actions, execute them and record the outcome.

    The step is marked running first. It ends up failed when there are no actions, when an
    action's last result is not successful, or when any exception is raised; otherwise it is
    marked completed. Exceptions are logged and not propagated.

    :return: the updated step and the detailed output collected so far.
    """
    detailed_agent_step_output = DetailedAgentStepOutput(
        scraped_page=None,
        extract_action_prompt=None,
        llm_response=None,
        actions=None,
        action_results=None,
        actions_and_results=None,
    )
    try:
        LOG.info(
            "Starting agent step",
            task_id=task.task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )
        step = await self.update_step(step=step, status=StepStatus.running)
        scraped_page, extract_action_prompt = await self._build_and_record_step_prompt(
            task,
            step,
            browser_state,
        )
        detailed_agent_step_output.scraped_page = scraped_page
        detailed_agent_step_output.extract_action_prompt = extract_action_prompt
        json_response = None
        actions: list[Action]
        # The LLM is only consulted when there is a navigation goal; otherwise the step consists
        # of a single CompleteAction carrying the data extraction goal.
        if task.navigation_goal:
            json_response = await app.OPENAI_CLIENT.chat_completion(
                step=step,
                prompt=extract_action_prompt,
                screenshots=scraped_page.screenshots,
            )
            detailed_agent_step_output.llm_response = json_response

            actions = parse_actions(task, json_response["actions"])
        else:
            actions = [
                CompleteAction(
                    reasoning="Task has no navigation goal.", data_extraction_goal=task.data_extraction_goal
                )
            ]
        detailed_agent_step_output.actions = actions
        if len(actions) == 0:
            LOG.info(
                "No actions to execute, marking step as failed",
                task_id=task.task_id,
                step_id=step.step_id,
                step_order=step.order,
                step_retry=step.retry_index,
            )
            step = await self.update_step(
                step=step, status=StepStatus.failed, output=detailed_agent_step_output.to_agent_step_output()
            )
            # The returned output is rebuilt with empty result lists in place of None.
            detailed_agent_step_output = DetailedAgentStepOutput(
                scraped_page=scraped_page,
                extract_action_prompt=extract_action_prompt,
                llm_response=json_response,
                actions=actions,
                action_results=[],
                actions_and_results=[],
            )
            return step, detailed_agent_step_output

        # Execute the actions
        LOG.info(
            "Executing actions",
            task_id=task.task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
            actions=actions,
        )
        # The output references this same list, so results appended below show up in the output.
        action_results: list[ActionResult] = []
        detailed_agent_step_output.action_results = action_results
        # filter out wait action if there are other actions in the list
        # we do this because WAIT action is considered as a failure
        # which will block following actions if we don't remove it from the list
        # if the list only contains WAIT action, we will execute WAIT action(s)
        if len(actions) > 1:
            wait_actions_to_skip = [action for action in actions if action.action_type == ActionType.WAIT]
            wait_actions_len = len(wait_actions_to_skip)
            # if there are wait actions and there are other actions in the list, skip wait actions
            if wait_actions_len > 0 and wait_actions_len < len(actions):
                actions = [action for action in actions if action.action_type != ActionType.WAIT]
                LOG.info("Skipping wait actions", wait_actions_to_skip=wait_actions_to_skip, actions=actions)

        # initialize list of tuples and set actions as the first element of each tuple so that in the case
        # of an exception, we can still see all the actions
        detailed_agent_step_output.actions_and_results = [(action, []) for action in actions]

        for action_idx, action in enumerate(actions):
            results = await ActionHandler.handle_action(scraped_page, task, step, browser_state, action)
            detailed_agent_step_output.actions_and_results[action_idx] = (action, results)
            # wait random time between actions to avoid detection
            await asyncio.sleep(random.uniform(1.0, 2.0))
            await self.record_artifacts_after_action(task, step, browser_state)
            # Stamp every result with the order and retry index of the step that produced it.
            for result in results:
                result.step_retry_number = step.retry_index
                result.step_order = step.order
            action_results.extend(results)
            # Check the last result for this action. If that succeeded, assume the entire action is successful
            if results and results[-1].success:
                LOG.info(
                    "Action succeeded",
                    task_id=task.task_id,
                    step_id=step.step_id,
                    step_order=step.order,
                    step_retry=step.retry_index,
                    action_idx=action_idx,
                    action=action,
                    action_result=results,
                )
                # if the action triggered javascript calls
                # this action should be the last action this round and do not take more actions.
                # for now, we're being optimistic and assuming that
                # js call doesn't have impact on the following actions
                if results[-1].javascript_triggered:
                    LOG.info("Action triggered javascript, ", action=action)
            else:
                # An empty result list is handled here as well, i.e. it counts as a failure.
                LOG.warning(
                    "Action failed, marking step as failed",
                    task_id=task.task_id,
                    step_id=step.step_id,
                    step_order=step.order,
                    step_retry=step.retry_index,
                    action_idx=action_idx,
                    action=action,
                    action_result=results,
                    actions_and_results=detailed_agent_step_output.actions_and_results,
                )
                # if the action failed, don't execute the rest of the actions, mark the step as failed, and retry
                failed_step = await self.update_step(
                    step=step, status=StepStatus.failed, output=detailed_agent_step_output.to_agent_step_output()
                )
                return failed_step, detailed_agent_step_output

        LOG.info(
            "Actions executed successfully, marking step as completed",
            task_id=task.task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
            action_results=action_results,
        )
        # If no action errors return the agent state and output
        completed_step = await self.update_step(
            step=step, status=StepStatus.completed, output=detailed_agent_step_output.to_agent_step_output()
        )
        return completed_step, detailed_agent_step_output
    except Exception:
        # Any failure above marks the step as failed; whatever output was collected is kept.
        LOG.exception(
            "Unexpected exception in agent_step, marking step as failed",
            task_id=task.task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )
        failed_step = await self.update_step(
            step=step, status=StepStatus.failed, output=detailed_agent_step_output.to_agent_step_output()
        )
        return failed_step, detailed_agent_step_output
async def record_artifacts_after_action(self, task: Task, step: Step, browser_state: BrowserState) -> None:
    """
    Record a screenshot, the page html and the video data after an action.

    Each of the three recordings is attempted on its own; a failure is logged and does not
    prevent the remaining ones.

    :raises BrowserStateMissingPage: the browser state has no page.
    """
    if not browser_state.page:
        raise BrowserStateMissingPage()

    log_context = {"task_id": task.task_id, "step_id": step.step_id}

    # Full-page screenshot
    try:
        screenshot_bytes = await browser_state.page.screenshot(full_page=True)
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.SCREENSHOT_ACTION,
            data=screenshot_bytes,
        )
    except Exception:
        LOG.error("Failed to record screenshot after action", **log_context, exc_info=True)

    # Page html
    try:
        page_html = await browser_state.page.content()
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.HTML_ACTION,
            data=page_html.encode(),
        )
    except Exception:
        LOG.error("Failed to record html after action", **log_context, exc_info=True)

    # Video: the existing video artifact is updated with the current data
    try:
        recording = await app.BROWSER_MANAGER.get_video_data(task_id=task.task_id, browser_state=browser_state)
        await app.ARTIFACT_MANAGER.update_artifact_data(
            artifact_id=browser_state.browser_artifacts.video_artifact_id,
            organization_id=task.organization_id,
            data=recording,
        )
    except Exception:
        LOG.error("Failed to record video after action", **log_context, exc_info=True)
async def _initialize_execution_state(
    self, task: Task, step: Step, workflow_run: WorkflowRun | None = None
) -> tuple[Step, BrowserState, DetailedAgentStepOutput]:
    """
    Get the browser state for the step and make sure a video artifact exists for it.

    The browser state comes from the workflow run when one is given, otherwise from the task.

    :return: the step unchanged, the browser state and an empty detailed output.
    """
    if not workflow_run:
        browser_state = await app.BROWSER_MANAGER.get_or_create_for_task(task)
    else:
        browser_state = await app.BROWSER_MANAGER.get_or_create_for_workflow_run(
            workflow_run=workflow_run, url=task.url
        )

    # Initialize video artifact for the task here, afterwards it'll only get updated
    if browser_state and not browser_state.browser_artifacts.video_artifact_id:
        recording = await app.BROWSER_MANAGER.get_video_data(task_id=task.task_id, browser_state=browser_state)
        recording_artifact_id = await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.RECORDING,
            data=recording,
        )
        app.BROWSER_MANAGER.set_video_artifact_for_task(task, recording_artifact_id)

    empty_output = DetailedAgentStepOutput(
        scraped_page=None,
        extract_action_prompt=None,
        llm_response=None,
        actions=None,
        action_results=None,
        actions_and_results=None,
    )
    return step, browser_state, empty_output
async def _build_and_record_step_prompt(
    self,
    task: Task,
    step: Step,
    browser_state: BrowserState,
) -> tuple[ScrapedPage, str]:
    """
    Scrape the task's url, build the "extract-action" prompt and store the related artifacts.

    The prompt includes the action results of the most recent steps, limited by the
    PROMPT_ACTION_HISTORY_WINDOW setting.

    :return: the scraped page and the prompt text.
    """
    # Scrape the web page and get the screenshot and the elements
    scraped_page = await scrape_website(
        browser_state,
        task.url,
    )
    await app.ARTIFACT_MANAGER.create_artifact(
        step=step,
        artifact_type=ArtifactType.HTML_SCRAPE,
        data=scraped_page.html.encode(),
    )
    LOG.info(
        "Scraped website",
        task_id=task.task_id,
        step_id=step.step_id,
        step_order=step.order,
        step_retry=step.retry_index,
        num_elements=len(scraped_page.elements),
        url=task.url,
    )

    # Get action results from the last PROMPT_ACTION_HISTORY_WINDOW steps
    all_steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
    history_window = SettingsManager.get_settings().PROMPT_ACTION_HISTORY_WINDOW
    recent_steps = all_steps[-history_window:]
    action_results: list[ActionResult] = [
        past_result
        for recent_step in recent_steps
        if recent_step.output and recent_step.output.action_results
        for past_result in recent_step.output.action_results
    ]
    action_history = json.dumps([past_result.model_dump() for past_result in action_results])

    # Generate the extract action prompt
    extract_action_prompt = prompt_engine.load_prompt(
        "extract-action",
        navigation_goal=task.navigation_goal,
        navigation_payload_str=json.dumps(task.navigation_payload),
        url=task.url,
        elements=scraped_page.element_tree_trimmed,  # scraped_page.element_tree,
        data_extraction_goal=task.data_extraction_goal,
        action_history=action_history,
        utc_datetime=datetime.utcnow(),
    )

    # Store the id-to-xpath map and both element trees as indented JSON, in this order.
    json_artifacts = (
        (ArtifactType.VISIBLE_ELEMENTS_ID_XPATH_MAP, scraped_page.id_to_xpath_dict),
        (ArtifactType.VISIBLE_ELEMENTS_TREE, scraped_page.element_tree),
        (ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED, scraped_page.element_tree_trimmed),
    )
    for artifact_type, payload in json_artifacts:
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=artifact_type,
            data=json.dumps(payload, indent=2).encode(),
        )

    return scraped_page, extract_action_prompt
async def get_extracted_information_for_task(self, task: Task) -> dict[str, Any] | list | str | None:
    """
    Return the data of the most recent successful COMPLETE action of the task.

    Completed steps are searched from the last one backwards; the first successful result of a
    COMPLETE action wins. Returns None (after logging a warning) when there is none.
    """
    task_steps = await app.DATABASE.get_task_steps(
        task_id=task.task_id,
        organization_id=task.organization_id,
    )
    for candidate in reversed(task_steps):
        has_results = (
            candidate.status == StepStatus.completed and candidate.output and candidate.output.actions_and_results
        )
        if not has_results:
            continue
        for action, action_results in candidate.output.actions_and_results:
            if action.action_type == ActionType.COMPLETE:
                for action_result in action_results:
                    if not action_result.success:
                        continue
                    LOG.info(
                        "Extracted information for task",
                        task_id=task.task_id,
                        step_id=candidate.step_id,
                        extracted_information=action_result.data,
                    )
                    return action_result.data

    LOG.warning(
        "Failed to find extracted information for task",
        task_id=task.task_id,
    )
    return None
async def get_failure_reason_for_task(self, task: Task) -> str | None:
    """
    Return the reasoning of the most recent TERMINATE action of the task.

    Only completed steps are inspected, from the last one backwards. Returns None (after logging
    an error) when no TERMINATE action is found.
    # TODO (kerem): Also return meaningful exceptions when we add them [WYV-311]
    """
    task_steps = await app.DATABASE.get_task_steps(
        task_id=task.task_id,
        organization_id=task.organization_id,
    )
    for candidate in reversed(task_steps):
        if candidate.status != StepStatus.completed or not candidate.output:
            continue
        for action, _ in candidate.output.actions_and_results or []:
            if action.action_type == ActionType.TERMINATE:
                return action.reasoning

    LOG.error(
        "Failed to find failure reasoning for task",
        task_id=task.task_id,
    )
    return None
async def send_task_response(
    self,
    task: Task,
    last_step: Step,
    api_key: str | None = None,
    close_browser_on_completion: bool = True,
) -> None:
    """
    Capture the final browser state of the task and send the task response to the webhook callback url.

    A final full-page screenshot artifact is always attempted first. If the task is part of a workflow
    run, nothing else is done. Otherwise the browser is cleaned up, pending artifact uploads are awaited
    and, when both a webhook callback url and an api key are available, the latest task state from the
    database is signed and POSTed to the webhook.

    Args:
        task: the task to report on.
        last_step: the step that the final screenshot (and HAR) artifacts are attached to.
        api_key: key used to sign the payload; no webhook is sent without it.
        close_browser_on_completion: forwarded to the browser cleanup.

    Raises:
        TaskNotFound: the task could not be re-read from the database.
        FailedToSendWebhook: posting to the webhook raised an exception.
    """
    # Take one last screenshot and create an artifact before closing the browser to see the final state
    browser_state: BrowserState = await app.BROWSER_MANAGER.get_or_create_for_task(task)
    page = await browser_state.get_or_create_page()
    try:
        screenshot = await page.screenshot(full_page=True)
        await app.ARTIFACT_MANAGER.create_artifact(
            step=last_step,
            artifact_type=ArtifactType.SCREENSHOT_FINAL,
            data=screenshot,
        )
    except TargetClosedError as e:
        LOG.warning(
            "Failed to take screenshot before sending task response, page is closed",
            task_id=task.task_id,
            step_id=last_step.step_id,
            error=e,
        )

    if task.workflow_run_id:
        LOG.info(
            "Task is part of a workflow run, not sending a webhook response",
            task_id=task.task_id,
            workflow_run_id=task.workflow_run_id,
        )
        return

    await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task)

    # Wait for all tasks to complete before generating the links for the artifacts
    await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_task(task.task_id)

    if not task.webhook_callback_url:
        LOG.warning(
            "Task has no webhook callback url. Not sending task response",
            task_id=task.task_id,
        )
        return

    if not api_key:
        LOG.warning(
            "Request has no api key. Not sending task response",
            task_id=task.task_id,
        )
        return

    # get the artifact of the screenshot and get the screenshot_url
    screenshot_artifact = await app.DATABASE.get_artifact(
        task_id=task.task_id,
        step_id=last_step.step_id,
        artifact_type=ArtifactType.SCREENSHOT_FINAL,
        organization_id=task.organization_id,
    )
    screenshot_url = None
    if screenshot_artifact:
        screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)

    recording_artifact = await app.DATABASE.get_artifact(
        task_id=task.task_id,
        step_id=last_step.step_id,
        artifact_type=ArtifactType.RECORDING,
        organization_id=task.organization_id,
    )
    recording_url = None
    if recording_artifact:
        recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)

    # get the latest task from the db to get the latest status, extracted_information, and failure_reason
    task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
    if not task_from_db:
        LOG.error("Failed to get task from db when sending task response")
        raise TaskNotFound(task_id=task.task_id)

    task = task_from_db
    # Re-check on the freshly loaded task: the stored callback url is the one that is used below
    if not task.webhook_callback_url:
        LOG.info("Task has no webhook callback url. Not sending task response")
        return

    task_response = task.to_task_response(screenshot_url=screenshot_url, recording_url=recording_url)

    # send task_response to the webhook callback url
    # TODO: use async requests (httpx)
    timestamp = str(int(datetime.utcnow().timestamp()))
    payload = task_response.model_dump_json(exclude={"request": {"navigation_payload"}})
    signature = generate_skyvern_signature(
        payload=payload,
        api_key=api_key,
    )
    headers = {
        "x-skyvern-timestamp": timestamp,
        "x-skyvern-signature": signature,
        "Content-Type": "application/json",
    }
    LOG.info(
        "Sending task response to webhook callback url",
        task_id=task.task_id,
        webhook_callback_url=task.webhook_callback_url,
        payload=payload,
        headers=headers,
    )
    try:
        # NOTE(review): this is a blocking call without a timeout inside a coroutine; see the TODO above
        resp = requests.post(task.webhook_callback_url, data=payload, headers=headers)
        if resp.ok:
            LOG.info(
                "Webhook sent successfully",
                task_id=task.task_id,
                resp_code=resp.status_code,
                resp_text=resp.text,
            )
        else:
            # Error bodies are frequently not JSON (HTML error pages, empty bodies). Decoding must not
            # turn a delivered-but-rejected webhook into a FailedToSendWebhook just for logging.
            try:
                resp_json = resp.json()
            except ValueError:
                resp_json = None
            LOG.info(
                "Webhook failed",
                task_id=task.task_id,
                resp=resp,
                resp_code=resp.status_code,
                resp_json=resp_json,
                resp_text=resp.text,
            )
    except Exception as e:
        raise FailedToSendWebhook(task_id=task.task_id) from e
async def cleanup_browser_and_create_artifacts(
    self, close_browser_on_completion: bool, last_step: Step, task: Task
) -> None:
    """
    Clean up the task's browser and persist the recording and HAR data it produced.

    The existing video artifact is updated with the final recording and a new HAR artifact is
    attached to last_step. When no browser state is returned by the cleanup, only a warning is logged.
    """
    # We need to close the browser even if there is no webhook callback url or api key
    browser_state = await app.BROWSER_MANAGER.cleanup_for_task(task.task_id, close_browser_on_completion)
    if not browser_state:
        LOG.warning(
            "BrowserState is missing before sending response to webhook_callback_url",
            web_hook_url=task.webhook_callback_url,
        )
        return

    # Update recording artifact after closing the browser, so we can get an accurate recording
    video_data = await app.BROWSER_MANAGER.get_video_data(task_id=task.task_id, browser_state=browser_state)
    if video_data:
        await app.ARTIFACT_MANAGER.update_artifact_data(
            artifact_id=browser_state.browser_artifacts.video_artifact_id,
            organization_id=task.organization_id,
            data=video_data,
        )

    har_data = await app.BROWSER_MANAGER.get_har_data(task_id=task.task_id, browser_state=browser_state)
    if har_data:
        await app.ARTIFACT_MANAGER.create_artifact(
            step=last_step,
            artifact_type=ArtifactType.HAR,
            data=har_data,
        )
async def update_step(
    self,
    step: Step,
    status: StepStatus | None = None,
    output: AgentStepOutput | None = None,
    is_last: bool | None = None,
    retry_index: int | None = None,
) -> Step:
    """
    Validate and persist the given step fields, returning the step as stored by the database.

    Only arguments that are not None are written. The fields whose value actually changes are logged.
    """
    step.validate_update(status, output, is_last)

    requested = {
        "status": status,
        "output": output,
        "is_last": is_last,
        "retry_index": retry_index,
    }
    # None means "leave this field untouched"
    updates: dict[str, Any] = {name: value for name, value in requested.items() if value is not None}

    # Old/new pairs for the fields that really change, for logging only
    update_comparison: dict[str, dict[str, Any]] = {}
    for name, new_value in updates.items():
        old_value = getattr(step, name)
        if old_value != new_value:
            update_comparison[name] = {"old": old_value, "new": new_value}

    LOG.info(
        "Updating step in db",
        task_id=step.task_id,
        step_id=step.step_id,
        diff=update_comparison,
    )
    return await app.DATABASE.update_step(
        task_id=step.task_id,
        step_id=step.step_id,
        organization_id=step.organization_id,
        **updates,
    )
async def update_task(
    self,
    task: Task,
    status: TaskStatus,
    extracted_information: dict[str, Any] | list | str | None = None,
    failure_reason: str | None = None,
) -> Task:
    """
    Validate and persist the given task fields, returning the task as stored by the database.

    Only arguments that are not None are written. The fields whose value actually changes are logged.
    """
    task.validate_update(status, extracted_information, failure_reason)

    requested = {
        "status": status,
        "extracted_information": extracted_information,
        "failure_reason": failure_reason,
    }
    # None means "leave this field untouched"
    updates: dict[str, Any] = {name: value for name, value in requested.items() if value is not None}

    # Old/new pairs for the fields that really change, for logging only
    update_comparison: dict[str, dict[str, Any]] = {}
    for name, new_value in updates.items():
        old_value = getattr(task, name)
        if old_value != new_value:
            update_comparison[name] = {"old": old_value, "new": new_value}

    LOG.info("Updating task in db", task_id=task.task_id, diff=update_comparison)
    return await app.DATABASE.update_task(
        task.task_id,
        organization_id=task.organization_id,
        **updates,
    )
async def handle_failed_step(self, task: Task, step: Step) -> Step | None:
    """
    Decide what happens after a step failed.

    While the step's retry index is below MAX_RETRIES_PER_STEP, a new step with the same order and the
    next retry index is created and returned. Otherwise the task is marked as failed and None is returned.
    """
    max_retries = SettingsManager.get_settings().MAX_RETRIES_PER_STEP

    if step.retry_index < max_retries:
        LOG.warning(
            "Step failed, retrying",
            task_id=task.task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )
        return await app.DATABASE.create_step(
            task_id=task.task_id,
            organization_id=task.organization_id,
            order=step.order,
            retry_index=step.retry_index + 1,
        )

    LOG.warning(
        "Step failed after max retries, marking task as failed",
        task_id=task.task_id,
        step_id=step.step_id,
        step_order=step.order,
        step_retry=step.retry_index,
        max_retries=max_retries,
    )
    await self.update_task(
        task,
        TaskStatus.failed,
        failure_reason=f"Max retries per step ({max_retries}) exceeded",
    )
    return None
async def handle_completed_step(
    self, organization: Organization, task: Task, step: Step
) -> tuple[bool | None, Step | None, Step | None]:
    """
    Decide what happens after a step completed.

    Returns a tuple (is_task_completed, last_step, next_step):
    - (True, last_step, None) when the step achieved the goal; the task is marked completed.
    - (False, last_step, None) when the agent terminated or the step limit was reached; the task is
      marked terminated or failed respectively.
    - (None, None, next_step) when the task continues with a newly created step.
    """
    # Fields shared by every log line emitted for this step
    step_log_fields = {
        "task_id": task.task_id,
        "step_id": step.step_id,
        "step_order": step.order,
        "step_retry": step.retry_index,
        "output": step.output,
    }

    if step.is_goal_achieved():
        LOG.info("Step completed and goal achieved, marking task as completed", **step_log_fields)
        last_step = await self.update_step(step, is_last=True)
        extracted_information = await self.get_extracted_information_for_task(task)
        await self.update_task(task, status=TaskStatus.completed, extracted_information=extracted_information)
        return True, last_step, None

    if step.is_terminated():
        LOG.info("Step completed and terminated by the agent, marking task as terminated", **step_log_fields)
        last_step = await self.update_step(step, is_last=True)
        failure_reason = await self.get_failure_reason_for_task(task)
        await self.update_task(task, status=TaskStatus.terminated, failure_reason=failure_reason)
        return False, last_step, None

    # If the max steps are exceeded, mark the current step as the last step and conclude the task.
    # Precedence: per-request override, then the organization's limit, then the global setting.
    context = skyvern_context.ensure_context()
    max_steps_per_run = (
        context.max_steps_override
        or organization.max_steps_per_run
        or SettingsManager.get_settings().MAX_STEPS_PER_RUN
    )

    if step.order + 1 >= max_steps_per_run:
        LOG.info(
            "Step completed but max steps reached, marking task as failed",
            **step_log_fields,
            max_steps=max_steps_per_run,
        )
        last_step = await self.update_step(step, is_last=True)
        await self.update_task(
            task,
            status=TaskStatus.failed,
            failure_reason=f"Max steps per task ({max_steps_per_run}) exceeded",
        )
        return False, last_step, None

    LOG.info("Step completed, creating next step", **step_log_fields)
    next_step = await app.DATABASE.create_step(
        task_id=task.task_id,
        order=step.order + 1,
        retry_index=0,
        organization_id=task.organization_id,
    )

    # Log once, at the step whose order equals the configured fraction of the step budget
    warning_ratio = SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO
    if step.order == int(max_steps_per_run * warning_ratio - 1):
        LOG.info(
            "Long running task warning",
            order=step.order,
            max_steps=max_steps_per_run,
            warning_ratio=warning_ratio,
        )
    return None, None, next_step

36
skyvern/forge/app.py Normal file
View File

@@ -0,0 +1,36 @@
from ddtrace import tracer
from ddtrace.filters import FilterRequestsOnUrl
from skyvern.forge.agent import ForgeAgent
from skyvern.forge.sdk.api.open_ai import OpenAIClientManager
from skyvern.forge.sdk.artifact.manager import ArtifactManager
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.forge_log import setup_logger
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.service import WorkflowService
from skyvern.webeye.browser_manager import BrowserManager
# Install a ddtrace URL filter matching heartbeat requests
# (presumably so health checks are not traced — confirm against the ddtrace FilterRequestsOnUrl docs).
tracer.configure(
    settings={
        "FILTERS": [
            FilterRequestsOnUrl(r"http://.*/heartbeat$"),
        ],
    },
)

# Logging is configured before any of the module-level singletons below are constructed.
setup_logger()

# NOTE: despite the name, this is the value returned by SettingsManager.get_settings().
SETTINGS_MANAGER = SettingsManager.get_settings()

# Module-level singletons; other modules reach them as app.DATABASE, app.STORAGE,
# app.ARTIFACT_MANAGER, app.BROWSER_MANAGER, etc. They are created at import time, in this order.
DATABASE = AgentDB(
    SettingsManager.get_settings().DATABASE_STRING, debug_enabled=SettingsManager.get_settings().DEBUG_MODE
)
STORAGE = StorageFactory.get_storage()
ASYNC_EXECUTOR = AsyncExecutorFactory.get_executor()
ARTIFACT_MANAGER = ArtifactManager()
BROWSER_MANAGER = BrowserManager()
OPENAI_CLIENT = OpenAIClientManager()
WORKFLOW_SERVICE = WorkflowService()

# The agent and the FastAPI application it builds.
agent = ForgeAgent()
app = agent.get_agent_app()

4
skyvern/forge/prompts.py Normal file
View File

@@ -0,0 +1,4 @@
from skyvern.forge.sdk.prompting import PromptEngine
# Initialize the prompt engine (module-level instance, constructed with the name "skyvern")
prompt_engine = PromptEngine("skyvern")

View File

@@ -0,0 +1,65 @@
Identify actions to help user progress towards the user goal using the DOM elements given in the list and the screenshot of the website.
Include only the elements that are relevant to the user goal, without altering or imagining new elements.
Use the details from the user details to fill in necessary values. Always complete required fields if the field isn't already filled in.
MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc.
Each element is tagged with an ID.
If you see any information in red in the page screenshot, this means a condition wasn't satisfied. Prioritize actions with the red information.
If you see a popup in the page screenshot, prioritize actions on the popup.
{% if "lever" in url %}
DO NOT UPDATE ANY LOCATION FIELDS
{% endif %}
Reply in JSON format with the following keys:
{
"actions": array // An array of actions. Here's the format of each action:
[{
"reasoning": str, // The reasoning behind the action. Be specific, referencing any user information and their fields and element ids in your reasoning. Mention why you chose the action type, and why you chose the element id. Keep the reasoning short and to the point.
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
"action_type": str, // It's a string enum: "CLICK", "INPUT_TEXT", "UPLOAD_FILE", "SELECT_OPTION", "WAIT", "SOLVE_CAPTCHA", "COMPLETE", "TERMINATE". "CLICK" is an element you'd like to click. "INPUT_TEXT" is an element you'd like to input text into. "UPLOAD_FILE" is an element you'd like to upload a file into. "SELECT_OPTION" is an element you'd like to select an option from. "WAIT" action should be used if there are no actions to take and there is some indication on screen that waiting could yield more actions. "WAIT" should not be used if there are actions to take. "SOLVE_CAPTCHA" should be used if there's a captcha to solve on the screen. "COMPLETE" is used when the user goal has been achieved AND if there's any data extraction goal, you should be able to get data from the page. If there is any other action to take, do not add "COMPLETE" type at all. "TERMINATE" is used to terminate the whole task with a failure when it doesn't seem like the user goal can be achieved. Do not use "TERMINATE" if waiting could lead the user towards the goal. Only return "TERMINATE" if you are on a page where the user goal cannot be achieved. All other actions are ignored when "TERMINATE" is returned.
"id": int, // The id of the element to take action on. The id has to be one from the elements list
"text": str, // Text for INPUT_TEXT action only
"file_url": str, // The url of the file to upload if applicable. This field must be present for UPLOAD_FILE but can also be present for CLICK only if the click is to upload the file. It should be null otherwise.
"option": { // The option to select for SELECT_OPTION action only. null if not SELECT_OPTION action
"label": str, // the label of the option if any. MAKE SURE YOU USE THIS LABEL TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION LABEL HERE
"index": int, // the id corresponding to the optionIndex under the select element.
"value": str // the value of the option. MAKE SURE YOU USE THIS VALUE TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION VALUE HERE
}
}],
}
{% if action_history %}
Consider the action history from the last step and the screenshot together, if actions from the last step don't yield positive impact, try other actions or other action combinations.
{% endif %}
Clickable elements from `{{ url }}`:
```
{{ elements }}
```
User goal:
```
{{ navigation_goal }}
```
{% if data_extraction_goal %}
User Data Extraction Goal:
```
{{ data_extraction_goal }}
```
{% endif %}
User details:
```
{{ navigation_payload_str }}
```
{% if action_history %}
Action results from previous steps: (note: even if the action history suggests goal is achieved, check the screenshot and the DOM elements to make sure the goal is achieved)
{{ action_history }}
{% endif %}
Current datetime in UTC:
```
{{ utc_datetime }}
```

View File

@@ -0,0 +1,16 @@
You are given a screenshot, user data extraction goal, the JSON schema for the output data format, and the current URL.
Your task is to extract the requested information from the screenshot and {% if extracted_information_schema %}output it in the specified JSON schema format:
{{ extracted_information_schema }} {% else %}output in strictly JSON format {% endif %}
Add as many details as possible to the output JSON object while conforming to the output JSON schema.
Do not ever include anything other than the JSON object in your output, and do not ever include any additional fields in the JSON object.
If you are unable to extract the requested information for a specific field in the json schema, please output a null value for that field.
User Data Extraction Goal: {{ data_extraction_goal }}
Current URL: {{ current_url }}
Text extracted from the webpage: {{ extracted_text }}

View File

View File

@@ -0,0 +1,97 @@
import uuid
from datetime import datetime
from typing import Awaitable, Callable
import structlog
from fastapi import APIRouter, FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.requests import HTTPConnection, Request
from starlette_context.middleware import RawContextMiddleware
from starlette_context.plugins.base import Plugin
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.routes.agent_protocol import base_router
LOG = structlog.get_logger()
class Agent:
    """Builds the FastAPI application that serves the agent API."""

    def get_agent_app(self, router: APIRouter = base_router) -> FastAPI:
        """
        Create and configure the agent's FastAPI application.

        Registers CORS for the local development origins, mounts the router under /api/v1,
        installs the agent and request-context middlewares, a catch-all exception handler and an
        HTTP middleware that sets a fresh SkyvernContext (with a new request id) for each request.
        """
        fastapi_app = FastAPI()

        # Add CORS middleware
        allowed_origins = [
            "http://localhost:5000",
            "http://127.0.0.1:5000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://localhost:8080",
            "http://127.0.0.1:8080",
            # Add any other origins you want to whitelist
        ]
        fastapi_app.add_middleware(
            CORSMiddleware,
            allow_origins=allowed_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        fastapi_app.include_router(router, prefix="/api/v1")

        # Middleware registration order is kept exactly as is; it determines the wrapping order
        fastapi_app.add_middleware(AgentMiddleware, agent=self)
        context_plugins = (
            # TODO (suchintan): We should set these up
            ExecutionDatePlugin(),
            # RequestIdPlugin(),
            # UserAgentPlugin(),
        )
        fastapi_app.add_middleware(RawContextMiddleware, plugins=context_plugins)

        @fastapi_app.exception_handler(Exception)
        async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
            # Last-resort handler: log with traceback and answer with a generic 500 payload
            LOG.exception("Unexpected error in agent server.", exc_info=exc)
            return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"})

        @fastapi_app.middleware("http")
        async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
            # Give every request its own context and always reset it afterwards
            skyvern_context.set(SkyvernContext(request_id=str(uuid.uuid4())))
            try:
                return await call_next(request)
            finally:
                skyvern_context.reset()

        return fastapi_app
class AgentMiddleware:
    """
    Middleware that injects the agent instance into the request scope.
    """

    def __init__(self, app: FastAPI, agent: Agent):
        # app is the next ASGI application in the chain; agent is stored to be exposed per request
        self.app = app
        self.agent = agent

    async def __call__(self, scope, receive, send):  # type: ignore
        # Expose the agent under scope["agent"], then delegate to the wrapped application
        scope["agent"] = self.agent
        await self.app(scope, receive, send)
class ExecutionDatePlugin(Plugin):
    """Context plugin that provides a timestamp under the "execution_date" key."""

    key = "execution_date"

    async def process_request(self, request: Request | HTTPConnection) -> datetime:
        # Naive local time (datetime.now() without a timezone); the request itself is not inspected
        return datetime.now()

View File

View File

@@ -0,0 +1,134 @@
from enum import StrEnum
from typing import Any, Callable
from urllib.parse import urlparse
import aioboto3
import structlog
from aiobotocore.client import AioBaseClient
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class AWSClientType(StrEnum):
    """AWS service names passed to the aioboto3 session when a client is created."""

    S3 = "s3"
    SECRETS_MANAGER = "secretsmanager"
def execute_with_async_client(client_type: AWSClientType) -> Callable:
    """
    Decorator factory for AsyncAWSClient coroutine methods that need an AWS client.

    For every call, a new aioboto3 session and a client of the given type are created, the client is
    passed to the wrapped method as the `client` keyword argument, and it is closed when the call ends.

    :param client_type: the AWS service the client is created for.
    :return: a decorator that wraps an async method of AsyncAWSClient.
    """
    # Imported locally so this block does not depend on the module's import section
    import functools

    def decorator(f: Callable) -> Callable:
        # functools.wraps keeps the wrapped method's __name__, __doc__ and signature metadata
        @functools.wraps(f)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # The decorator is only meant for AsyncAWSClient methods; args[0] is `self`
            self = args[0]
            assert isinstance(self, AsyncAWSClient)
            session = aioboto3.Session()
            async with session.client(client_type) as client:
                return await f(*args, client=client, **kwargs)

        return wrapper

    return decorator
class AsyncAWSClient:
    """
    Async helpers for AWS Secrets Manager and S3.

    Every method is wrapped by execute_with_async_client, which supplies the `client` keyword argument;
    callers do not pass it. All methods log and swallow exceptions instead of raising.
    """

    @execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
    async def get_secret(self, secret_name: str, client: AioBaseClient = None) -> str | None:
        """Return the SecretString of the named secret, or None if it could not be fetched."""
        try:
            response = await client.get_secret_value(SecretId=secret_name)
            return response["SecretString"]
        except Exception as e:
            # Best effort: pull the AWS error code out of the exception if it carries a response
            try:
                error_code = e.response["Error"]["Code"]  # type: ignore
            except Exception:
                error_code = "failed-to-get-error-code"
            LOG.exception("Failed to get secret.", secret_name=secret_name, error_code=error_code, exc_info=True)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def upload_file(self, uri: str, data: bytes, client: AioBaseClient = None) -> str | None:
        """Upload data to the given s3:// uri. Returns the uri on success, None on failure."""
        try:
            parsed_uri = S3Uri(uri)
            await client.put_object(Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key)
            LOG.debug("Upload file success", uri=uri)
            return uri
        except Exception:
            LOG.exception("S3 upload failed.", uri=uri)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
        """Upload the local file at file_path to the given s3:// uri. Failures are only logged."""
        try:
            parsed_uri = S3Uri(uri)
            await client.upload_file(file_path, parsed_uri.bucket, parsed_uri.key)
            LOG.info("Upload file from path success", uri=uri)
        except Exception:
            LOG.exception("S3 upload failed.", uri=uri)

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def download_file(self, uri: str, client: AioBaseClient = None) -> bytes | None:
        """Return the full content of the object at the given s3:// uri, or None on failure."""
        try:
            parsed_uri = S3Uri(uri)
            response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
            return await response["Body"].read()
        except Exception:
            LOG.exception("S3 download failed", uri=uri)
            return None

    @execute_with_async_client(client_type=AWSClientType.S3)
    async def create_presigned_url(self, uri: str, client: AioBaseClient = None) -> str | None:
        """
        Return a presigned get_object url for the given s3:// uri, or None on failure.

        The expiration comes from the PRESIGNED_URL_EXPIRATION setting.
        """
        try:
            parsed_uri = S3Uri(uri)
            url = await client.generate_presigned_url(
                "get_object",
                Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
                ExpiresIn=SettingsManager.get_settings().PRESIGNED_URL_EXPIRATION,
            )
            return url
        except Exception:
            LOG.exception("Failed to create presigned url.", uri=uri)
            return None
class S3Uri:
    # From: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path
    """
    Split an s3:// uri into its bucket and key.

    Fragments are not parsed, and a query string is treated as part of the key, so keys
    containing "?" or "#" survive unchanged.

    >>> s = S3Uri("s3://bucket/hello/world")
    >>> s.bucket
    'bucket'
    >>> s.key
    'hello/world'
    >>> s.uri
    's3://bucket/hello/world'

    >>> s = S3Uri("s3://bucket/hello/world?qwe1=3#ddd")
    >>> s.bucket
    'bucket'
    >>> s.key
    'hello/world?qwe1=3#ddd'
    >>> s.uri
    's3://bucket/hello/world?qwe1=3#ddd'

    >>> s = S3Uri("s3://bucket/hello/world#foo?bar=2")
    >>> s.key
    'hello/world#foo?bar=2'
    >>> s.uri
    's3://bucket/hello/world#foo?bar=2'
    """

    def __init__(self, uri: str) -> None:
        # allow_fragments=False keeps "#..." inside the path/query instead of splitting it off
        self._parsed = urlparse(uri, allow_fragments=False)

    @property
    def bucket(self) -> str:
        """The bucket name (the network-location part of the uri)."""
        return self._parsed.netloc

    @property
    def key(self) -> str:
        """The object key: the path without leading slashes, with any query string re-attached."""
        path_part = self._parsed.path.lstrip("/")
        query_part = self._parsed.query
        return f"{path_part}?{query_part}" if query_part else path_part

    @property
    def uri(self) -> str:
        """The uri reassembled from its parsed parts."""
        return self._parsed.geturl()

View File

@@ -0,0 +1,25 @@
from typing import Callable
from pydantic import BaseModel
# Model name -> (input token price, output token price).
# ChatCompletionPrice applies both prices per 1000 tokens. Despite the name, the values are
# plain tuples, not lambdas.
openai_model_to_price_lambdas = {
    "gpt-4-vision-preview": (0.01, 0.03),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-3.5-turbo": (0.001, 0.002),
    "gpt-3.5-turbo-1106": (0.001, 0.002),
}
class ChatCompletionPrice(BaseModel):
    """
    Token usage of one chat completion together with a pricing function for its model.

    Raises KeyError on construction when model_name has no entry in openai_model_to_price_lambdas.
    """

    # Number of prompt (input) tokens used
    input_token_count: int
    # Number of completion (output) tokens used
    output_token_count: int
    # (input_token, output_token) -> price, using the model's per-1000-token prices
    openai_model_to_price_lambda: Callable[[int, int], float]

    def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
        prompt_rate, completion_rate = openai_model_to_price_lambdas[model_name]
        super().__init__(
            input_token_count=input_token_count,
            output_token_count=output_token_count,
            openai_model_to_price_lambda=(
                lambda input_token, output_token: prompt_rate * input_token / 1000
                + completion_rate * output_token / 1000
            ),
        )

View File

@@ -0,0 +1,47 @@
import os
import tempfile
import zipfile
from urllib.parse import urlparse
import requests
import structlog
LOG = structlog.get_logger()
def download_file(url: str) -> str | None:
    """
    Download the file at url into a fresh temporary directory.

    The file is named after the last segment of the url path; when that segment is empty
    (for example the path ends with "/"), the name "download" is used instead.

    :param url: location of the file to fetch.
    :return: path of the downloaded file, or None when the response status is not 200.
    """
    # stream=True prevents loading the content into memory at once. The context manager closes
    # the response on every path, so the streamed connection is always released.
    with requests.get(url, stream=True) as r:
        # Check if the request is successful
        if r.status_code != 200:
            LOG.error(f"Failed to download file, status code: {r.status_code}")
            return None

        temp_dir = tempfile.mkdtemp(prefix="skyvern_downloads_")
        # An empty basename would make file_path point at the directory itself and open() would fail
        file_name = os.path.basename(urlparse(url).path) or "download"
        file_path = os.path.join(temp_dir, file_name)

        LOG.info(f"Downloading file to {file_path}")
        with open(file_path, "wb") as f:
            # Write the content of the request into the file
            for chunk in r.iter_content(1024):
                f.write(chunk)
        LOG.info(f"File downloaded successfully to {file_path}")
        return file_path
def zip_files(files_path: str, zip_file_path: str) -> str:
    """
    Write every file found under files_path (recursively) into a deflate-compressed zip archive.

    Entries are stored under their path relative to files_path.

    :param files_path: directory whose files are archived.
    :param zip_file_path: path of the zip archive to create.
    :return: zip_file_path.
    """
    with zipfile.ZipFile(zip_file_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        for current_dir, _subdirs, file_names in os.walk(files_path):
            for file_name in file_names:
                source_path = os.path.join(current_dir, file_name)
                # Relative path within the zip
                archive.write(source_path, os.path.relpath(source_path, files_path))
    return zip_file_path

View File

@@ -0,0 +1,221 @@
import base64
import json
import random
from datetime import datetime, timedelta
from typing import Any
import commentjson
import openai
import structlog
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
from skyvern.forge import app
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class OpenAIKeyClientWrapper:
    """An AsyncOpenAI client for one api key, plus the last known remaining-request count for it."""

    client: AsyncOpenAI
    key: str
    remaining_requests: int | None

    def __init__(self, key: str, remaining_requests: int | None) -> None:
        self.key = key
        self.remaining_requests = remaining_requests
        self.updated_at = datetime.utcnow()
        self.client = AsyncOpenAI(api_key=self.key)

    def update_remaining_requests(self, remaining_requests: int | None) -> None:
        """Record a new remaining-request count and the time it was recorded."""
        self.remaining_requests = remaining_requests
        self.updated_at = datetime.utcnow()

    def is_available(self) -> bool:
        """Return True when this key may be used for a request."""
        # None means the key has never been tried, so it is assumed to be usable;
        # otherwise it is usable while requests remain.
        if self.remaining_requests is None or self.remaining_requests > 0:
            return True

        # If we haven't checked this in over 1 minutes, check it again
        # Most of our failures are because of Tokens-per-minute (TPM) limits
        one_minute_ago = datetime.utcnow() - timedelta(minutes=1)
        return self.updated_at < one_minute_ago
class OpenAIClientManager:
    """
    Sends chat completion requests to OpenAI, spreading them over a pool of api keys.

    Keys that hit a rate limit are marked unavailable and the request is retried with another key.
    Prompts, screenshots, requests and responses are stored as artifacts of the given step.
    """

    # TODO Support other models for requests without screenshots, track rate limits for each model and key as well if any
    clients: list[OpenAIKeyClientWrapper]

    def __init__(self, api_keys: list[str] = SettingsManager.get_settings().OPENAI_API_KEYS) -> None:
        self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]

    def get_available_client(self) -> OpenAIKeyClientWrapper | None:
        """Return a randomly chosen available client, or None when every key is unavailable."""
        available_clients = [client for client in self.clients if client.is_available()]

        if not available_clients:
            return None

        # Randomly select an available client to distribute requests across our accounts
        return random.choice(available_clients)

    async def content_builder(
        self,
        step: Step,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> list[dict[str, Any]]:
        """
        Build the message content (text part followed by base64 image parts) for a chat request.

        The prompt and each screenshot are also stored as artifacts of the step.
        """
        content: list[dict[str, Any]] = []

        if prompt is not None:
            content.append(
                {
                    "type": "text",
                    "text": prompt,
                }
            )
            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_PROMPT,
                data=prompt.encode("utf-8"),
            )

        if screenshots:
            for screenshot in screenshots:
                encoded_image = base64.b64encode(screenshot).decode("utf-8")
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{encoded_image}",
                        },
                    }
                )
                # create artifact for each image
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.SCREENSHOT_LLM,
                    data=screenshot,
                )

        return content

    async def chat_completion(
        self,
        step: Step,
        model: str = "gpt-4-vision-preview",
        max_tokens: int = 4096,
        temperature: int = 0,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> dict[str, Any]:
        """
        Send one chat completion request and return the parsed JSON content of the answer.

        Raises:
            NoAvailableOpenAIClients: no api key is currently available.
            OpenAIRequestTooBigError: raised from the rate-limit handler (see the note there).
            InvalidOpenAIResponseFormat: the response content could not be parsed.
        """
        LOG.info(
            "Sending LLM request",
            task_id=step.task_id,
            step_id=step.step_id,
            num_screenshots=len(screenshots) if screenshots else 0,
        )
        messages = [
            {
                "role": "user",
                "content": await self.content_builder(
                    step=step,
                    screenshots=screenshots,
                    prompt=prompt,
                ),
            }
        ]

        chat_completion_kwargs = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_REQUEST,
            data=json.dumps(chat_completion_kwargs).encode("utf-8"),
        )

        available_client = self.get_available_client()
        if available_client is None:
            raise NoAvailableOpenAIClients()
        try:
            response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
        except openai.RateLimitError as e:
            # If we get a RateLimitError, we can assume the key is not available anymore
            # NOTE(review): confirm that `e.code` is the field that can equal 429 here; the openai
            # client also exposes an HTTP `status_code` on status errors.
            if e.code == 429:
                raise OpenAIRequestTooBigError(e.message)
            LOG.warning(
                "OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
            )
            available_client.update_remaining_requests(remaining_requests=0)
            available_client = self.get_available_client()
            if available_client is None:
                raise NoAvailableOpenAIClients()
            # Retry from scratch; note that this stores the prompt/request artifacts again
            return await self.chat_completion(
                step=step,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                screenshots=screenshots,
                prompt=prompt,
            )

        # TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
        # use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
        # x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys

        # If we get a response, we can assume the key is available and update the remaining requests.
        # A missing or non-numeric header must not fail a request that already succeeded: the count
        # is then recorded as None, which is_available() treats as "available".
        raw_remaining_requests = response.headers.get("x-ratelimit-remaining-requests")
        try:
            remaining_requests: int | None = int(raw_remaining_requests)
        except (TypeError, ValueError):
            LOG.warning("Invalid x-ratelimit-remaining-requests from OpenAI", headers=response.headers)
            remaining_requests = None
        available_client.update_remaining_requests(remaining_requests=remaining_requests)

        completion = response.parse()

        if completion.usage is not None:
            # TODO (Suchintan): Is this bad design?
            step = await app.DATABASE.update_step(
                step_id=step.step_id,
                task_id=step.task_id,
                organization_id=step.organization_id,
                chat_completion_price=ChatCompletionPrice(
                    input_token_count=completion.usage.prompt_tokens,
                    output_token_count=completion.usage.completion_tokens,
                    model_name=model,
                ),
            )
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE,
            data=completion.model_dump_json(indent=2).encode("utf-8"),
        )
        parsed_response = self.parse_response(completion)
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
            data=json.dumps(parsed_response, indent=2).encode("utf-8"),
        )
        return parsed_response

    def parse_response(self, response: ChatCompletion) -> dict[str, str]:
        """
        Parse the first choice of the completion as (comment-tolerant) JSON.

        Markdown code fences are stripped first. Any failure, including missing or empty content,
        is reported as InvalidOpenAIResponseFormat.
        """
        try:
            # content may be None; treat it like an empty answer instead of failing on .replace
            content = response.choices[0].message.content or ""
            content = content.replace("```json", "")
            content = content.replace("```", "")
            if not content:
                raise Exception("openai response content is empty")
            return commentjson.loads(content)
        except Exception as e:
            raise InvalidOpenAIResponseFormat(str(response)) from e

View File

View File

@@ -0,0 +1,112 @@
import asyncio
import time
from collections import defaultdict
import structlog
from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.id import generate_artifact_id
from skyvern.forge.sdk.models import Step
LOG = structlog.get_logger(__name__)
class ArtifactManager:
    """Registers artifacts in the database and schedules their upload to storage.

    Uploads are started as background asyncio tasks ("fire and forget") and
    tracked per task id so callers can later wait for them with the
    ``wait_for_upload_aiotasks_*`` methods.
    """

    # task_id -> list of aio_tasks for uploading artifacts
    # NOTE: this is a class attribute, so the map is shared by all instances.
    upload_aiotasks_map: dict[str, list[asyncio.Task[None]]] = defaultdict(list)

    async def create_artifact(
        self, step: Step, artifact_type: ArtifactType, data: bytes | None = None, path: str | None = None
    ) -> str:
        """Create the artifact record for ``step`` and schedule its upload.

        Exactly one of ``data`` (raw bytes) or ``path`` (file to hand to the
        storage backend) is expected.

        Returns:
            The generated artifact id.

        Raises:
            ValueError: if neither or both of ``data`` and ``path`` are given.
        """
        # TODO (kerem): Which is better?
        #   current: (disadvantage: we create the artifact_id UUID here)
        #       1. generate artifact_id UUID here
        #       2. build uri with artifact_id, step_id, task_id, artifact_type
        #       3. create artifact in db using artifact_id, step_id, task_id, artifact_type, uri
        #       4. store artifact in storage
        #   alternative: (disadvantage: two db calls)
        #       1. create artifact in db without the URI
        #       2. build uri with artifact_id, step_id, task_id, artifact_type
        #       3. update artifact in db with the URI
        #       4. store artifact in storage
        if data is None and path is None:
            raise ValueError("Either data or path must be provided to create an artifact.")
        if data and path:
            raise ValueError("Both data and path cannot be provided to create an artifact.")

        artifact_id = generate_artifact_id()
        uri = app.STORAGE.build_uri(artifact_id, step, artifact_type)
        artifact = await app.DATABASE.create_artifact(
            artifact_id,
            step.step_id,
            step.task_id,
            artifact_type,
            uri,
            organization_id=step.organization_id,
        )

        if path:
            upload = app.STORAGE.store_artifact_from_path(artifact, path)
        elif data is not None:
            # Empty payloads (b"") are still uploaded; otherwise the database
            # record would point at a URI that is never written.
            upload = app.STORAGE.store_artifact(artifact, data)
        else:
            return artifact_id

        # Fire and forget
        self.upload_aiotasks_map[step.task_id].append(asyncio.create_task(upload))
        return artifact_id

    async def update_artifact_data(self, artifact_id: str | None, organization_id: str | None, data: bytes) -> None:
        """Schedule a re-upload of ``data`` for an existing artifact.

        Does nothing when either id is missing or the artifact is not found.
        """
        if not artifact_id or not organization_id:
            return None
        artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id)
        if not artifact:
            return
        # Fire and forget
        aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
        self.upload_aiotasks_map[artifact.task_id].append(aio_task)

    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Return whatever the storage backend returns for ``artifact``."""
        return await app.STORAGE.retrieve_artifact(artifact)

    async def get_share_link(self, artifact: Artifact) -> str | None:
        """Return the storage backend's share link for ``artifact``."""
        return await app.STORAGE.get_share_link(artifact)

    async def wait_for_upload_aiotasks_for_task(self, task_id: str) -> None:
        """Wait up to 30 seconds for the pending uploads of one task, then forget them."""
        try:
            st = time.time()
            async with asyncio.timeout(30):
                await asyncio.gather(
                    *[aio_task for aio_task in self.upload_aiotasks_map.get(task_id, []) if not aio_task.done()]
                )
            LOG.info(
                f"S3 upload tasks for task_id={task_id} completed in {time.time() - st:.2f}s",
                task_id=task_id,
                duration=time.time() - st,
            )
        except asyncio.TimeoutError:
            LOG.error(f"Timeout (30s) while waiting for upload tasks for task_id={task_id}", task_id=task_id)
        finally:
            # Always drop the bookkeeping entry, even if an upload raised.
            self.upload_aiotasks_map.pop(task_id, None)

    async def wait_for_upload_aiotasks_for_tasks(self, task_ids: list[str]) -> None:
        """Wait up to 30 seconds for the pending uploads of several tasks, then forget them."""
        try:
            st = time.time()
            async with asyncio.timeout(30):
                await asyncio.gather(
                    *[
                        aio_task
                        for task_id in task_ids
                        for aio_task in self.upload_aiotasks_map.get(task_id, [])
                        if not aio_task.done()
                    ]
                )
            LOG.info(
                f"S3 upload tasks for task_ids={task_ids} completed in {time.time() - st:.2f}s",
                task_ids=task_ids,
                duration=time.time() - st,
            )
        except asyncio.TimeoutError:
            LOG.error(f"Timeout (30s) while waiting for upload tasks for task_ids={task_ids}", task_ids=task_ids)
        finally:
            # pop() tolerates duplicate or unknown ids, where `del` would raise KeyError.
            for task_id in task_ids:
                self.upload_aiotasks_map.pop(task_id, None)

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
class ArtifactType(StrEnum):
RECORDING = "recording"
# DEPRECATED. pls use SCREENSHOT_LLM, SCREENSHOT_ACTION or SCREENSHOT_FINAL
SCREENSHOT = "screenshot"
# USE THESE for screenshots
SCREENSHOT_LLM = "screenshot_llm"
SCREENSHOT_ACTION = "screenshot_action"
SCREENSHOT_FINAL = "screenshot_final"
LLM_PROMPT = "llm_prompt"
LLM_REQUEST = "llm_request"
LLM_RESPONSE = "llm_response"
LLM_RESPONSE_PARSED = "llm_response_parsed"
VISIBLE_ELEMENTS_ID_XPATH_MAP = "visible_elements_id_xpath_map"
VISIBLE_ELEMENTS_TREE = "visible_elements_tree"
VISIBLE_ELEMENTS_TREE_TRIMMED = "visible_elements_tree_trimmed"
# DEPRECATED. pls use HTML_SCRAPE or HTML_ACTION
HTML = "html"
# USE THESE for htmls
HTML_SCRAPE = "html_scrape"
HTML_ACTION = "html_action"
# Debugging
TRACE = "trace"
HAR = "har"
class Artifact(BaseModel):
    # API-facing representation of a stored artifact: identifiers of the
    # artifact, its task and step, its type, and the URI where it is stored.
    created_at: datetime = Field(
        ...,
        # Fixed copy-paste wording: this timestamp belongs to the artifact, not the task.
        description="The creation datetime of the artifact.",
        examples=["2023-01-01T00:00:00Z"],
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the artifact.",
        examples=["2023-01-01T00:00:00Z"],
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    artifact_id: str = Field(
        ...,
        description="The ID of the task artifact.",
        examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
    )
    task_id: str = Field(
        ...,
        description="The ID of the task this artifact belongs to.",
        examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
    )
    step_id: str = Field(
        ...,
        description="The ID of the task step this artifact belongs to.",
        examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
    )
    artifact_type: ArtifactType = Field(
        ...,
        description="The type of the artifact.",
        examples=["screenshot"],
    )
    uri: str = Field(
        ...,
        description="The URI of the artifact.",
        examples=["/Users/skyvern/hello/world.png"],
    )
    # Optional owner; None when the artifact is not tied to an organization.
    organization_id: str | None = None

View File

@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.models import Step
# TODO: This should be a part of the ArtifactType model
# Maps each artifact type to the file extension (without the dot) used when
# building its storage URI. The deprecated SCREENSHOT and HTML members are
# included so that a lookup for any ArtifactType member succeeds instead of
# raising KeyError.
FILE_EXTENTSION_MAP: dict[ArtifactType, str] = {
    ArtifactType.RECORDING: "webm",
    ArtifactType.SCREENSHOT: "png",
    ArtifactType.SCREENSHOT_LLM: "png",
    ArtifactType.SCREENSHOT_ACTION: "png",
    ArtifactType.SCREENSHOT_FINAL: "png",
    ArtifactType.LLM_PROMPT: "txt",
    ArtifactType.LLM_REQUEST: "json",
    ArtifactType.LLM_RESPONSE: "json",
    ArtifactType.LLM_RESPONSE_PARSED: "json",
    ArtifactType.VISIBLE_ELEMENTS_ID_XPATH_MAP: "json",
    ArtifactType.VISIBLE_ELEMENTS_TREE: "json",
    ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED: "json",
    ArtifactType.HTML: "html",
    ArtifactType.HTML_SCRAPE: "html",
    ArtifactType.HTML_ACTION: "html",
    ArtifactType.TRACE: "zip",
    ArtifactType.HAR: "har",
}
class BaseStorage(ABC):
    """Abstract interface every artifact storage backend must implement."""

    @abstractmethod
    def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
        """Return the URI under which the given artifact is to be stored."""

    @abstractmethod
    async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
        """Persist ``data`` for ``artifact``."""

    @abstractmethod
    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Return the stored bytes for ``artifact``, or None."""

    @abstractmethod
    async def get_share_link(self, artifact: Artifact) -> str | None:
        """Return a link for ``artifact``, or None."""

    @abstractmethod
    async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
        """Persist the file located at ``path`` for ``artifact``."""

View File

@@ -0,0 +1,14 @@
from skyvern.forge.sdk.artifact.storage.base import BaseStorage
from skyvern.forge.sdk.artifact.storage.local import LocalStorage
class StorageFactory:
    """Holds the process-wide storage backend; a LocalStorage instance until replaced."""

    # Name-mangled class attribute; read and written only through the static
    # accessors below.
    __storage: BaseStorage = LocalStorage()

    @staticmethod
    def get_storage() -> BaseStorage:
        """Return the currently configured storage backend."""
        return StorageFactory.__storage

    @staticmethod
    def set_storage(storage: BaseStorage) -> None:
        """Replace the configured storage backend with ``storage``."""
        StorageFactory.__storage = storage

View File

@@ -0,0 +1,66 @@
from datetime import datetime
from pathlib import Path
from urllib.parse import unquote, urlparse
import structlog
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class LocalStorage(BaseStorage):
    """Storage backend that keeps artifacts as files addressed by ``file://`` URIs."""

    def __init__(self, artifact_path: str | None = None) -> None:
        """Remember the root directory for artifacts.

        Args:
            artifact_path: root directory; when omitted, the
                ``ARTIFACT_STORAGE_PATH`` setting is read at construction time
                (not once at import time, as a default argument would be).
        """
        if artifact_path is None:
            artifact_path = SettingsManager.get_settings().ARTIFACT_STORAGE_PATH
        self.artifact_path = artifact_path

    def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
        """Build the ``file://`` URI for an artifact of ``step``.

        Raises:
            KeyError: if ``artifact_type`` has no entry in FILE_EXTENTSION_MAP.
        """
        file_ext = FILE_EXTENTSION_MAP[artifact_type]
        return f"file://{self.artifact_path}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

    async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
        """Write ``data`` to the file named by ``artifact.uri``; failures are logged, not raised."""
        file_path = None
        try:
            file_path = Path(self._parse_uri_to_path(artifact.uri))
            self._create_directories_if_not_exists(file_path)
            with open(file_path, "wb") as f:
                f.write(data)
        except Exception:
            LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)

    async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
        """Move the file at ``path`` to the location named by ``artifact.uri``; failures are logged, not raised."""
        file_path = None
        try:
            file_path = Path(self._parse_uri_to_path(artifact.uri))
            self._create_directories_if_not_exists(file_path)
            Path(path).replace(file_path)
        except Exception:
            LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)

    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        """Return the bytes stored for ``artifact``, or None if reading fails (the error is logged)."""
        file_path = None
        try:
            file_path = self._parse_uri_to_path(artifact.uri)
            with open(file_path, "rb") as f:
                return f.read()
        except Exception:
            LOG.exception("Failed to retrieve local artifact.", file_path=file_path, artifact=artifact)
            return None

    async def get_share_link(self, artifact: Artifact) -> str:
        """Return the artifact's own URI as its link."""
        return artifact.uri

    @staticmethod
    def _parse_uri_to_path(uri: str) -> str:
        """Convert a ``file://`` URI into a filesystem path.

        Raises:
            ValueError: if the URI scheme is not ``file``.
        """
        parsed_uri = urlparse(uri)
        if parsed_uri.scheme != "file":
            # The f-prefix was missing, so the message showed the literal placeholder.
            raise ValueError(f"Invalid URI scheme: {parsed_uri.scheme} expected: file")
        # netloc is prepended because build_uri() places artifact_path directly
        # after "file://", so its first segment can be parsed as the netloc.
        path = parsed_uri.netloc + parsed_uri.path
        return unquote(path)

    @staticmethod
    def _create_directories_if_not_exists(path_including_file_name: Path) -> None:
        """Create the parent directory of the given file path, including missing ancestors."""
        path = path_including_file_name.parent
        path.mkdir(parents=True, exist_ok=True)

View File

View File

@@ -0,0 +1,41 @@
import hashlib
import hmac
from datetime import datetime, timedelta
from typing import Any, Union
from jose import jwt
from skyvern.forge.sdk.settings_manager import SettingsManager
ALGORITHM = "HS256"
def create_access_token(
    subject: Union[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed JWT whose ``sub`` claim is ``str(subject)``.

    Args:
        subject: value stored (stringified) in the ``sub`` claim.
        expires_delta: token lifetime. When None, the
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` setting is used. An explicit zero
            timedelta is honoured rather than being replaced by the default.

    Returns:
        The token encoded with the ``SECRET_KEY`` setting and ``ALGORITHM``.
    """
    if expires_delta is None:
        expires_delta = timedelta(
            minutes=SettingsManager.get_settings().ACCESS_TOKEN_EXPIRE_MINUTES,
        )
    expire = datetime.utcnow() + expires_delta
    to_encode = {"exp": expire, "sub": str(subject)}
    encoded_jwt = jwt.encode(to_encode, SettingsManager.get_settings().SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
def generate_skyvern_signature(
    payload: str,
    api_key: str,
) -> str:
    """
    Compute the HMAC-SHA256 signature of a request body.

    :param payload: the request body
    :param api_key: the Skyvern api key, used as the HMAC key
    :return: the signature as a lowercase hexadecimal string
    """
    key_bytes = api_key.encode("utf-8")
    body_bytes = payload.encode("utf-8")
    return hmac.digest(key_bytes, body_bytes, hashlib.sha256).hex()

View File

@@ -0,0 +1,73 @@
from contextvars import ContextVar
from dataclasses import dataclass
@dataclass
class SkyvernContext:
    """Identifiers and overrides carried in the current context; every field defaults to None."""

    request_id: str | None = None
    organization_id: str | None = None
    task_id: str | None = None
    workflow_id: str | None = None
    workflow_run_id: str | None = None
    max_steps_override: int | None = None

    def __repr__(self) -> str:
        # Values are rendered with str() (no quotes), unlike the repr that
        # @dataclass would generate.
        field_names = (
            "request_id",
            "organization_id",
            "task_id",
            "workflow_id",
            "workflow_run_id",
            "max_steps_override",
        )
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
        return f"SkyvernContext({rendered})"

    def __str__(self) -> str:
        return repr(self)
# Context variable holding the active SkyvernContext. Its default is None,
# so readers get None until a context has been stored in it.
_context: ContextVar[SkyvernContext | None] = ContextVar(
    "Global context",
    default=None,
)
def current() -> SkyvernContext | None:
    """
    Look up the context stored in the module's context variable.

    Returns:
        The stored SkyvernContext, or None when nothing is stored.
    """
    return _context.get()
def ensure_context() -> SkyvernContext:
    """
    Return the current context, failing loudly when there is none.

    Returns:
        The current SkyvernContext.

    Raises:
        RuntimeError: If current() returns None.
    """
    if (active := current()) is not None:
        return active
    raise RuntimeError("No skyvern context")
def set(context: SkyvernContext) -> None:
    """
    Store ``context`` as the current context.

    Args:
        context: The SkyvernContext to make current.

    Returns:
        None
    """
    _context.set(context)
def reset() -> None:
    """
    Clear the current context by storing None in the context variable.

    Returns:
        None
    """
    _context.set(None)

View File

View File

@@ -0,0 +1,900 @@
from datetime import datetime
from typing import Any
import structlog
from sqlalchemy import and_, create_engine, delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ArtifactModel,
AWSSecretParameterModel,
OrganizationAuthTokenModel,
OrganizationModel,
StepModel,
TaskModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.db.utils import (
_custom_json_serializer,
convert_to_artifact,
convert_to_aws_secret_parameter,
convert_to_organization,
convert_to_organization_auth_token,
convert_to_step,
convert_to_task,
convert_to_workflow,
convert_to_workflow_parameter,
convert_to_workflow_run,
convert_to_workflow_run_parameter,
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus
from skyvern.webeye.actions.models import AgentStepOutput
LOG = structlog.get_logger()
class AgentDB:
def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
    """Create the SQLAlchemy engine and session factory for ``database_string``.

    Args:
        database_string: connection string passed to ``create_engine``.
        debug_enabled: flag forwarded to the ``convert_to_*`` helpers.
    """
    super().__init__()
    self.debug_enabled = debug_enabled
    engine = create_engine(database_string, json_serializer=_custom_json_serializer)
    self.engine = engine
    self.Session = sessionmaker(bind=engine)
async def create_task(
    self,
    url: str,
    navigation_goal: str | None,
    data_extraction_goal: str | None,
    navigation_payload: dict[str, Any] | list | str | None,
    webhook_callback_url: str | None = None,
    organization_id: str | None = None,
    proxy_location: ProxyLocation | None = None,
    extracted_information_schema: dict[str, Any] | list | str | None = None,
    workflow_run_id: str | None = None,
    order: int | None = None,
    retry: int | None = None,
) -> Task:
    """Insert a task row with status "created" and return it as a Task.

    Database and unexpected errors are logged and re-raised.
    """
    try:
        with self.Session() as session:
            task_row = TaskModel(
                status="created",
                url=url,
                webhook_callback_url=webhook_callback_url,
                navigation_goal=navigation_goal,
                data_extraction_goal=data_extraction_goal,
                navigation_payload=navigation_payload,
                organization_id=organization_id,
                proxy_location=proxy_location,
                extracted_information_schema=extracted_information_schema,
                workflow_run_id=workflow_run_id,
                order=order,
                retry=retry,
            )
            session.add(task_row)
            session.commit()
            # Reload so database-generated column values are populated.
            session.refresh(task_row)
            return convert_to_task(task_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def create_step(
    self,
    task_id: str,
    order: int,
    retry_index: int,
    organization_id: str | None = None,
) -> Step:
    """Insert a step row with status "created" for ``task_id`` and return it as a Step.

    Database and unexpected errors are logged and re-raised.
    """
    try:
        with self.Session() as session:
            step_row = StepModel(
                task_id=task_id,
                order=order,
                retry_index=retry_index,
                status="created",
                organization_id=organization_id,
            )
            session.add(step_row)
            session.commit()
            session.refresh(step_row)
            return convert_to_step(step_row, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def create_artifact(
    self,
    artifact_id: str,
    step_id: str,
    task_id: str,
    artifact_type: str,
    uri: str,
    organization_id: str | None = None,
) -> Artifact:
    """Insert an artifact row with the caller-supplied id and URI and return it as an Artifact.

    Database and unexpected errors are logged and re-raised.
    """
    try:
        with self.Session() as session:
            artifact_row = ArtifactModel(
                artifact_id=artifact_id,
                task_id=task_id,
                step_id=step_id,
                artifact_type=artifact_type,
                uri=uri,
                organization_id=organization_id,
            )
            session.add(artifact_row)
            session.commit()
            session.refresh(artifact_row)
            return convert_to_artifact(artifact_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.exception("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.exception("UnexpectedError", exc_info=True)
        raise
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
    """Fetch the task matching ``task_id`` and ``organization_id``; None (and an info log) when absent."""
    try:
        with self.Session() as session:
            task_row = session.query(TaskModel).filter_by(task_id=task_id, organization_id=organization_id).first()
            if not task_row:
                LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
                return None
            return convert_to_task(task_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
    """Fetch the step matching ``step_id`` and ``organization_id``, or None.

    NOTE(review): ``task_id`` is accepted but not used in the query.
    """
    try:
        with self.Session() as session:
            step_row = session.query(StepModel).filter_by(step_id=step_id, organization_id=organization_id).first()
            if not step_row:
                return None
            return convert_to_step(step_row, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_task_steps(self, task_id: str, organization_id: str | None = None) -> list[Step]:
    """Return the task's steps ordered by ``order`` then ``retry_index``; empty list when there are none."""
    try:
        with self.Session() as session:
            step_rows = (
                session.query(StepModel)
                .filter_by(task_id=task_id, organization_id=organization_id)
                .order_by(StepModel.order, StepModel.retry_index)
                .all()
            )
            return [convert_to_step(row, debug_enabled=self.debug_enabled) for row in step_rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> list[StepModel]:
    """Return the raw StepModel rows of a task ordered by ``order`` then ``retry_index``."""
    try:
        with self.Session() as session:
            query = (
                session.query(StepModel)
                .filter_by(task_id=task_id, organization_id=organization_id)
                .order_by(StepModel.order, StepModel.retry_index)
            )
            return query.all()
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
    """Return the task's step with the highest ``order``, or None when it has no steps.

    Steps that share an ``order`` (retries) are disambiguated by taking the
    highest ``retry_index``; without that tie-break the database could return
    any of them.
    """
    try:
        with self.Session() as session:
            if step := (
                session.query(StepModel)
                .filter_by(task_id=task_id)
                .filter_by(organization_id=organization_id)
                .order_by(StepModel.order.desc())
                .order_by(StepModel.retry_index.desc())
                .first()
            ):
                return convert_to_step(step, debug_enabled=self.debug_enabled)
            else:
                LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
                return None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def update_step(
    self,
    task_id: str,
    step_id: str,
    status: StepStatus | None = None,
    output: AgentStepOutput | None = None,
    is_last: bool | None = None,
    retry_index: int | None = None,
    organization_id: str | None = None,
    chat_completion_price: ChatCompletionPrice | None = None,
) -> Step:
    """Apply the non-None arguments to a step and return the re-fetched Step.

    When ``chat_completion_price`` is given, its token counts are added to the
    step's running totals and the step cost is recomputed from those totals.

    Raises:
        NotFoundError: if the step does not exist (before or after the update).
    """
    try:
        with self.Session() as session:
            step_row = (
                session.query(StepModel)
                .filter_by(task_id=task_id, step_id=step_id, organization_id=organization_id)
                .first()
            )
            if not step_row:
                raise NotFoundError("Step not found")

            if status is not None:
                step_row.status = status
            if output is not None:
                step_row.output = output.model_dump()
            if is_last is not None:
                step_row.is_last = is_last
            if retry_index is not None:
                step_row.retry_index = retry_index
            if chat_completion_price is not None:
                # Missing counters are treated as zero before accumulating.
                if step_row.input_token_count is None:
                    step_row.input_token_count = 0
                if step_row.output_token_count is None:
                    step_row.output_token_count = 0
                step_row.input_token_count += chat_completion_price.input_token_count
                step_row.output_token_count += chat_completion_price.output_token_count
                step_row.step_cost = chat_completion_price.openai_model_to_price_lambda(
                    step_row.input_token_count, step_row.output_token_count
                )

            session.commit()
            refreshed = await self.get_step(task_id, step_id, organization_id)
            if refreshed:
                return refreshed
            raise NotFoundError("Step not found")
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except NotFoundError:
        LOG.error("NotFoundError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def update_task(
    self,
    task_id: str,
    status: TaskStatus,
    extracted_information: dict[str, Any] | list | str | None = None,
    failure_reason: str | None = None,
    organization_id: str | None = None,
) -> Task:
    """Set a task's status (and optional result fields) and return the re-fetched Task.

    ``extracted_information`` and ``failure_reason`` are only written when not None.

    Raises:
        NotFoundError: if the task does not exist (before or after the update).
    """
    try:
        with self.Session() as session:
            task_row = session.query(TaskModel).filter_by(task_id=task_id, organization_id=organization_id).first()
            if not task_row:
                raise NotFoundError("Task not found")

            task_row.status = status
            if extracted_information is not None:
                task_row.extracted_information = extracted_information
            if failure_reason is not None:
                task_row.failure_reason = failure_reason
            session.commit()

            refreshed = await self.get_task(task_id, organization_id=organization_id)
            if refreshed:
                return refreshed
            raise NotFoundError("Task not found")
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except NotFoundError:
        LOG.error("NotFoundError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_tasks(self, page: int = 1, page_size: int = 10, organization_id: str | None = None) -> list[Task]:
    """
    Return one page of an organization's tasks, newest first.

    :param page: 1-based page number
    :param page_size: maximum number of tasks in the page
    :param organization_id: organization whose tasks are listed
    :return: the tasks on the requested page
    :raises ValueError: if ``page`` is smaller than 1
    """
    if page < 1:
        raise ValueError(f"Page must be greater than 0, got {page}")

    try:
        with self.Session() as session:
            # Pages are 1-based for callers, offsets are 0-based.
            skip = (page - 1) * page_size
            task_rows = (
                session.query(TaskModel)
                .filter_by(organization_id=organization_id)
                .order_by(TaskModel.created_at.desc())
                .limit(page_size)
                .offset(skip)
                .all()
            )
            return [convert_to_task(row, debug_enabled=self.debug_enabled) for row in task_rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_organization(self, organization_id: str) -> Organization | None:
    """Fetch an organization by id, or None when it does not exist."""
    try:
        with self.Session() as session:
            org_row = session.query(OrganizationModel).filter_by(organization_id=organization_id).first()
            return convert_to_organization(org_row) if org_row else None
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def create_organization(
    self,
    organization_name: str,
    webhook_callback_url: str | None = None,
    max_steps_per_run: int | None = None,
) -> Organization:
    """Insert an organization row and return it as an Organization."""
    with self.Session() as session:
        org_row = OrganizationModel(
            organization_name=organization_name,
            webhook_callback_url=webhook_callback_url,
            max_steps_per_run=max_steps_per_run,
        )
        session.add(org_row)
        session.commit()
        session.refresh(org_row)
        return convert_to_organization(org_row)
async def get_valid_org_auth_token(
    self,
    organization_id: str,
    token_type: OrganizationAuthTokenType,
) -> OrganizationAuthToken | None:
    """Return a token of ``token_type`` flagged valid for the organization, or None."""
    try:
        with self.Session() as session:
            token_row = (
                session.query(OrganizationAuthTokenModel)
                .filter_by(organization_id=organization_id, token_type=token_type, valid=True)
                .first()
            )
            if not token_row:
                return None
            return convert_to_organization_auth_token(token_row)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def validate_org_auth_token(
    self,
    organization_id: str,
    token_type: OrganizationAuthTokenType,
    token: str,
) -> OrganizationAuthToken | None:
    """Return the stored token record matching ``token`` exactly and flagged valid, or None."""
    try:
        with self.Session() as session:
            token_row = (
                session.query(OrganizationAuthTokenModel)
                .filter_by(organization_id=organization_id, token_type=token_type, token=token, valid=True)
                .first()
            )
            if not token_row:
                return None
            return convert_to_organization_auth_token(token_row)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def create_org_auth_token(
    self,
    organization_id: str,
    token_type: OrganizationAuthTokenType,
    token: str,
) -> OrganizationAuthToken:
    """Insert an auth-token row for the organization and return it as an OrganizationAuthToken."""
    with self.Session() as session:
        # Separate local name so the ``token`` string parameter is not rebound to the row.
        token_row = OrganizationAuthTokenModel(
            organization_id=organization_id,
            token_type=token_type,
            token=token,
        )
        session.add(token_row)
        session.commit()
        session.refresh(token_row)
        return convert_to_organization_auth_token(token_row)
async def get_artifacts_for_task_step(
    self,
    task_id: str,
    step_id: str,
    organization_id: str | None = None,
) -> list[Artifact]:
    """Return every artifact recorded for the given task step; empty list when there are none."""
    try:
        with self.Session() as session:
            artifact_rows = (
                session.query(ArtifactModel)
                .filter_by(task_id=task_id, step_id=step_id, organization_id=organization_id)
                .all()
            )
            return [convert_to_artifact(row, self.debug_enabled) for row in artifact_rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_artifact_by_id(
    self,
    artifact_id: str,
    organization_id: str,
) -> Artifact | None:
    """Fetch an artifact by id within an organization, or None when it does not exist."""
    try:
        with self.Session() as session:
            artifact_row = (
                session.query(ArtifactModel)
                .filter_by(artifact_id=artifact_id, organization_id=organization_id)
                .first()
            )
            if not artifact_row:
                return None
            return convert_to_artifact(artifact_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.exception("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.exception("UnexpectedError", exc_info=True)
        raise
async def get_artifact(
    self,
    task_id: str,
    step_id: str,
    artifact_type: ArtifactType,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the most recently created artifact of ``artifact_type`` for a task step, or None."""
    try:
        with self.Session() as session:
            newest_row = (
                session.query(ArtifactModel)
                .filter_by(
                    task_id=task_id,
                    step_id=step_id,
                    organization_id=organization_id,
                    artifact_type=artifact_type,
                )
                .order_by(ArtifactModel.created_at.desc())
                .first()
            )
            if not newest_row:
                return None
            return convert_to_artifact(newest_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_artifact_for_workflow_run(
    self,
    workflow_run_id: str,
    artifact_type: ArtifactType,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the newest artifact of ``artifact_type`` among the tasks of a workflow run, or None.

    Artifacts are joined to their tasks so they can be filtered by the task's
    ``workflow_run_id``.
    """
    try:
        with self.Session() as session:
            newest_row = (
                session.query(ArtifactModel)
                .join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
                .filter(
                    TaskModel.workflow_run_id == workflow_run_id,
                    ArtifactModel.artifact_type == artifact_type,
                    ArtifactModel.organization_id == organization_id,
                )
                .order_by(ArtifactModel.created_at.desc())
                .first()
            )
            if not newest_row:
                return None
            return convert_to_artifact(newest_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_latest_artifact(
    self,
    task_id: str,
    step_id: str | None = None,
    artifact_types: list[ArtifactType] | None = None,
    organization_id: str | None = None,
) -> Artifact | None:
    """Return the most recently created artifact of a task, or None.

    ``step_id``, ``organization_id`` and ``artifact_types`` narrow the search
    only when they are truthy; a falsy value applies no filter at all.
    """
    try:
        with self.Session() as session:
            query = session.query(ArtifactModel).filter_by(task_id=task_id)
            if step_id:
                query = query.filter_by(step_id=step_id)
            if organization_id:
                query = query.filter_by(organization_id=organization_id)
            if artifact_types:
                query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))

            newest_row = query.order_by(ArtifactModel.created_at.desc()).first()
            if not newest_row:
                return None
            return convert_to_artifact(newest_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.exception("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.exception("UnexpectedError", exc_info=True)
        raise
async def get_latest_task_by_workflow_id(
    self,
    organization_id: str,
    workflow_id: str,
    before: datetime | None = None,
) -> Task | None:
    """Return the newest task for a workflow, optionally only those created before ``before``.

    NOTE(review): this filters TaskModel by ``workflow_id``, while create_task
    only sets ``workflow_run_id`` - confirm the column exists on the model.
    """
    try:
        with self.Session() as session:
            query = session.query(TaskModel).filter_by(organization_id=organization_id, workflow_id=workflow_id)
            if before:
                query = query.filter(TaskModel.created_at < before)
            newest_row = query.order_by(TaskModel.created_at.desc()).first()
            if not newest_row:
                return None
            return convert_to_task(newest_row, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def create_workflow(
    self,
    organization_id: str,
    title: str,
    workflow_definition: dict[str, Any],
    description: str | None = None,
) -> Workflow:
    """Insert a workflow row and return it as a Workflow."""
    with self.Session() as session:
        workflow_row = WorkflowModel(
            organization_id=organization_id,
            title=title,
            description=description,
            workflow_definition=workflow_definition,
        )
        session.add(workflow_row)
        session.commit()
        session.refresh(workflow_row)
        return convert_to_workflow(workflow_row, self.debug_enabled)
async def get_workflow(self, workflow_id: str) -> Workflow | None:
    """Fetch a workflow by id, or None when it does not exist."""
    try:
        with self.Session() as session:
            workflow_row = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
            if not workflow_row:
                return None
            return convert_to_workflow(workflow_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def update_workflow(
    self,
    workflow_id: str,
    title: str | None = None,
    description: str | None = None,
    workflow_definition: dict[str, Any] | None = None,
) -> Workflow | None:
    """Update the given fields of a workflow and return the refreshed Workflow.

    A field is written whenever its argument is not None, so empty values
    (for example ``description=""``) can be stored; this matches the
    ``is not None`` convention of update_step and update_task.

    Returns:
        The updated Workflow, or None (with an error log) if it does not exist.
    """
    with self.Session() as session:
        workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
        if workflow:
            if title is not None:
                workflow.title = title
            if description is not None:
                workflow.description = description
            if workflow_definition is not None:
                workflow.workflow_definition = workflow_definition
            session.commit()
            session.refresh(workflow)
            return convert_to_workflow(workflow, self.debug_enabled)
        LOG.error("Workflow not found, nothing to update", workflow_id=workflow_id)
        return None
async def create_workflow_run(
    self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
) -> WorkflowRun:
    """Insert a workflow-run row with status "created" and return it as a WorkflowRun."""
    try:
        with self.Session() as session:
            run_row = WorkflowRunModel(
                workflow_id=workflow_id,
                proxy_location=proxy_location,
                status="created",
                webhook_callback_url=webhook_callback_url,
            )
            session.add(run_row)
            session.commit()
            session.refresh(run_row)
            return convert_to_workflow_run(run_row)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def update_workflow_run(self, workflow_run_id: str, status: WorkflowRunStatus) -> WorkflowRun | None:
    """Set the status of a workflow run; returns None (with an error log) if it does not exist."""
    with self.Session() as session:
        run_row = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
        if not run_row:
            LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
            return None
        run_row.status = status
        session.commit()
        session.refresh(run_row)
        return convert_to_workflow_run(run_row)
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
    """Fetch a workflow run by id, or None when it does not exist."""
    try:
        with self.Session() as session:
            run_row = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
            if not run_row:
                return None
            return convert_to_workflow_run(run_row)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
    """Return every run recorded for the given workflow."""
    try:
        with self.Session() as session:
            run_rows = session.query(WorkflowRunModel).filter_by(workflow_id=workflow_id).all()
            return list(map(convert_to_workflow_run, run_rows))
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def create_workflow_parameter(
    self,
    workflow_id: str,
    workflow_parameter_type: WorkflowParameterType,
    key: str,
    default_value: Any,
    description: str | None = None,
) -> WorkflowParameter:
    """Insert a parameter definition for a workflow and return it as a WorkflowParameter."""
    try:
        with self.Session() as session:
            parameter_row = WorkflowParameterModel(
                workflow_id=workflow_id,
                workflow_parameter_type=workflow_parameter_type,
                key=key,
                default_value=default_value,
                description=description,
            )
            session.add(parameter_row)
            session.commit()
            session.refresh(parameter_row)
            return convert_to_workflow_parameter(parameter_row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def create_aws_secret_parameter(
    self,
    workflow_id: str,
    key: str,
    aws_key: str,
    description: str | None = None,
) -> AWSSecretParameter:
    """
    Insert a new AWS secret parameter row for a workflow.

    Args:
        workflow_id: ID of the workflow the parameter belongs to.
        key: Parameter key stored on the row.
        aws_key: Value stored in the row's ``aws_key`` column.
        description: Optional description.

    Returns:
        The committed and refreshed row, converted to an AWSSecretParameter.

    Raises:
        SQLAlchemyError: Logged and re-raised, matching the other methods of
            this class.
    """
    try:
        with self.Session() as session:
            aws_secret_parameter = AWSSecretParameterModel(
                workflow_id=workflow_id,
                key=key,
                aws_key=aws_key,
                description=description,
            )
            session.add(aws_secret_parameter)
            session.commit()
            session.refresh(aws_secret_parameter)
            return convert_to_aws_secret_parameter(aws_secret_parameter)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
    """
    Fetch every workflow parameter defined for the given workflow.

    Args:
        workflow_id: ID of the workflow whose parameters are requested.

    Returns:
        A list of converted workflow parameters (empty when there are none).

    Raises:
        SQLAlchemyError: Logged and re-raised.
    """
    try:
        with self.Session() as session:
            workflow_parameters = session.query(WorkflowParameterModel).filter_by(workflow_id=workflow_id).all()
            # Forward the debug flag like the other methods of this class do,
            # so conversion debug logging is consistent.
            return [
                convert_to_workflow_parameter(parameter, self.debug_enabled) for parameter in workflow_parameters
            ]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
    """
    Look up a single workflow parameter by its ID.

    Returns the converted parameter, or None when no row matches.
    SQLAlchemy errors are logged and re-raised.
    """
    try:
        with self.Session() as session:
            lookup = session.query(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id)
            row = lookup.first()
            if not row:
                return None
            return convert_to_workflow_parameter(row, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def create_workflow_run_parameter(
    self, workflow_run_id: str, workflow_parameter_id: str, value: Any
) -> WorkflowRunParameter:
    """
    Store the value of one workflow parameter for one workflow run.

    The row is committed first; the owning workflow parameter is then loaded
    so the stored value can be converted. Raises WorkflowParameterNotFound
    when that parameter does not exist. SQLAlchemy errors are logged and
    re-raised.
    """
    try:
        with self.Session() as session:
            new_row = WorkflowRunParameterModel(
                workflow_run_id=workflow_run_id,
                workflow_parameter_id=workflow_parameter_id,
                value=value,
            )
            session.add(new_row)
            session.commit()
            session.refresh(new_row)

            parameter = await self.get_workflow_parameter(workflow_parameter_id)
            if parameter:
                return convert_to_workflow_run_parameter(new_row, parameter, self.debug_enabled)
            raise WorkflowParameterNotFound(workflow_parameter_id)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_workflow_run_parameters(
    self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
    """
    Fetch all parameter values recorded for a workflow run.

    Each entry pairs the workflow parameter definition with the converted
    run parameter. Raises WorkflowParameterNotFound when a stored value
    refers to a parameter that cannot be loaded. SQLAlchemy errors are logged
    and re-raised.
    """
    try:
        with self.Session() as session:
            rows = session.query(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id).all()
            pairs = []
            for row in rows:
                # One lookup per row: the definition is needed to convert the value.
                parameter = await self.get_workflow_parameter(row.workflow_parameter_id)
                if not parameter:
                    raise WorkflowParameterNotFound(workflow_parameter_id=row.workflow_parameter_id)
                run_parameter = convert_to_workflow_run_parameter(row, parameter, self.debug_enabled)
                pairs.append((parameter, run_parameter))
            return pairs
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
    """
    Return the most recently created task of a workflow run.

    Tasks are ordered by ``created_at`` descending and the first one is
    converted and returned; None is returned when the run has no tasks.
    SQLAlchemy errors are logged and re-raised.
    """
    try:
        with self.Session() as session:
            newest_first = (
                session.query(TaskModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .order_by(TaskModel.created_at.desc())
            )
            latest = newest_first.first()
            if not latest:
                return None
            return convert_to_task(latest, debug_enabled=self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
    """
    Return all tasks of a workflow run, ordered by ``created_at`` ascending.

    SQLAlchemy errors are logged and re-raised.
    """
    try:
        with self.Session() as session:
            oldest_first = (
                session.query(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
            )
            rows = oldest_first.all()
            return [convert_to_task(row, debug_enabled=self.debug_enabled) for row in rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
    """
    Delete every artifact row of a task within an organization.

    Args:
        organization_id: Organization the artifacts must belong to.
        task_id: Task whose artifacts are deleted.

    Raises:
        SQLAlchemyError: Logged and re-raised, matching the other methods of
            this class.
    """
    try:
        with self.Session() as session:
            # delete artifacts by filtering organization_id and task_id
            stmt = delete(ArtifactModel).where(
                and_(
                    ArtifactModel.organization_id == organization_id,
                    ArtifactModel.task_id == task_id,
                )
            )
            session.execute(stmt)
            session.commit()
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
    """
    Delete every step row of a task within an organization.

    Args:
        organization_id: Organization the steps must belong to.
        task_id: Task whose steps are deleted.

    Raises:
        SQLAlchemyError: Logged and re-raised, matching the other methods of
            this class.
    """
    try:
        with self.Session() as session:
            # delete steps by filtering organization_id and task_id
            stmt = delete(StepModel).where(
                and_(
                    StepModel.organization_id == organization_id,
                    StepModel.task_id == task_id,
                )
            )
            session.execute(stmt)
            session.commit()
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise

View File

@@ -0,0 +1,15 @@
from enum import StrEnum
class OrganizationAuthTokenType(StrEnum):
    """String enum of organization auth token types; ``api`` is the only member."""

    api = "api"
class ScheduleRuleUnit(StrEnum):
    """String enum of time units for schedule rules, from minute up to year."""

    # No support for scheduling every second
    minute = "minute"
    hour = "hour"
    day = "day"
    week = "week"
    month = "month"
    year = "year"

View File

@@ -0,0 +1,2 @@
class NotFoundError(Exception):
    """Exception type signalling that a requested record was not found."""

136
skyvern/forge/sdk/db/id.py Normal file
View File

@@ -0,0 +1,136 @@
import hashlib
import itertools
import os
import platform
import random
import time
# Custom epoch subtracted from the current time in generate_id:
# 6/20/2022 12AM (1655683200 is 2022-06-20 00:00:00 UTC).
BASE_EPOCH = 1655683200
VERSION = 0
# Number of bits reserved for each field of a generated ID.
# Together they add up to 64 bits: 32 + 21 + 10 + 1.
TIMESTAMP_BITS = 32
WORKER_ID_BITS = 21
SEQUENCE_BITS = 10
VERSION_BITS = 1
# Bit shifts (left). Each field is shifted past the fields stored below it:
# version at bit 0, sequence at bit 1, worker ID at bit 11, timestamp at bit 32.
TIMESTAMP_SHIFT = 32
WORKER_ID_SHIFT = 11
SEQUENCE_SHIFT = 1
VERSION_SHIFT = 0
# Largest value that fits in the sequence field.
SEQUENCE_MAX = (2**SEQUENCE_BITS) - 1
# Random starting offset for the sequence; chosen lazily on first use.
_sequence_start = None
# Counter advanced once per generated ID.
SEQUENCE_COUNTER = itertools.count()
# Cached worker hash; computed lazily on first use.
_worker_hash = None
# Prefixes prepended to the integer ID to form string IDs for each entity type.
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
ORG_PREFIX = "o"
TASK_PREFIX = "tsk"
USER_PREFIX = "u"
STEP_PREFIX = "stp"
ARTIFACT_PREFIX = "a"
WORKFLOW_PREFIX = "w"
WORKFLOW_RUN_PREFIX = "wr"
WORKFLOW_PARAMETER_PREFIX = "wp"
AWS_SECRET_PARAMETER_PREFIX = "asp"
def generate_workflow_id() -> str:
    """Return a new workflow ID: WORKFLOW_PREFIX, an underscore, and a fresh integer ID."""
    return f"{WORKFLOW_PREFIX}_{generate_id()}"
def generate_workflow_run_id() -> str:
    """Return a new workflow run ID: WORKFLOW_RUN_PREFIX, an underscore, and a fresh integer ID."""
    return f"{WORKFLOW_RUN_PREFIX}_{generate_id()}"
def generate_aws_secret_parameter_id() -> str:
    """Return a new AWS secret parameter ID: AWS_SECRET_PARAMETER_PREFIX, an underscore, and a fresh integer ID."""
    return f"{AWS_SECRET_PARAMETER_PREFIX}_{generate_id()}"
def generate_workflow_parameter_id() -> str:
    """Return a new workflow parameter ID: WORKFLOW_PARAMETER_PREFIX, an underscore, and a fresh integer ID."""
    return f"{WORKFLOW_PARAMETER_PREFIX}_{generate_id()}"
def generate_organization_auth_token_id() -> str:
    """Return a new auth token ID: ORGANIZATION_AUTH_TOKEN_PREFIX, an underscore, and a fresh integer ID."""
    return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{generate_id()}"
def generate_org_id() -> str:
    """Return a new organization ID: ORG_PREFIX, an underscore, and a fresh integer ID."""
    return f"{ORG_PREFIX}_{generate_id()}"
def generate_task_id() -> str:
    """Return a new task ID: TASK_PREFIX, an underscore, and a fresh integer ID."""
    return f"{TASK_PREFIX}_{generate_id()}"
def generate_step_id() -> str:
    """Return a new step ID: STEP_PREFIX, an underscore, and a fresh integer ID."""
    return f"{STEP_PREFIX}_{generate_id()}"
def generate_artifact_id() -> str:
    """Return a new artifact ID: ARTIFACT_PREFIX, an underscore, and a fresh integer ID."""
    return f"{ARTIFACT_PREFIX}_{generate_id()}"
def generate_user_id() -> str:
    """Return a new user ID: USER_PREFIX, an underscore, and a fresh integer ID."""
    return f"{USER_PREFIX}_{generate_id()}"
def generate_id() -> int:
    """
    Generate a 64-bit integer ID.

    The ID packs four fields, each masked to its width and shifted to its
    position: seconds elapsed since BASE_EPOCH, the worker hash, a
    per-process sequence number, and the format version.
    """
    elapsed = current_time() - BASE_EPOCH
    seq = _increment_and_get_sequence()
    fields = (
        (elapsed, TIMESTAMP_BITS, TIMESTAMP_SHIFT),
        (_get_worker_hash(), WORKER_ID_BITS, WORKER_ID_SHIFT),
        (seq, SEQUENCE_BITS, SEQUENCE_SHIFT),
        (VERSION, VERSION_BITS, VERSION_SHIFT),
    )
    packed = 0
    for value, bits, shift in fields:
        packed |= _mask_shift(value, bits, shift)
    return packed
def _increment_and_get_sequence() -> int:
    """
    Return the next sequence number for ID generation.

    The sequence starts at a random offset, chosen on first use, and advances
    by one per call. It wraps around so the result always fits in the
    sequence field, i.e. lies in ``0..SEQUENCE_MAX`` inclusive.
    """
    global _sequence_start
    if _sequence_start is None:
        _sequence_start = random.randint(0, SEQUENCE_MAX)
    # SEQUENCE_MAX is the largest representable value, so the modulus must be
    # SEQUENCE_MAX + 1; using SEQUENCE_MAX itself would never produce the top
    # value and would wrap one step early.
    return (_sequence_start + next(SEQUENCE_COUNTER)) % (SEQUENCE_MAX + 1)
def current_time() -> int:
    """Return the current Unix time in whole seconds."""
    now = time.time()
    return int(now)
def current_time_ms() -> int:
    """Return the current Unix time in whole milliseconds."""
    now = time.time()
    return int(now * 1000)
def _mask_shift(value: int, mask_bits: int, shift_bits: int) -> int:
    """Keep the lowest ``mask_bits`` bits of ``value`` and shift them left by ``shift_bits``."""
    mask = pow(2, mask_bits) - 1
    kept = value & mask
    return kept << shift_bits
def _get_worker_hash() -> int:
    """Return the worker hash, computing and caching it on first use."""
    global _worker_hash
    if _worker_hash is not None:
        return _worker_hash
    _worker_hash = _generate_worker_hash()
    return _worker_hash
def _generate_worker_hash() -> int:
    """
    Derive an integer from this worker's identity.

    The identity is the host's node name and the process ID joined by a
    colon; the last 15 hex digits of its MD5 digest are parsed as an integer.
    """
    identity = f"{platform.node()}:{os.getpid()}".encode()
    digest_hex = hashlib.md5(identity).hexdigest()
    return int(digest_hex[-15:], 16)

View File

@@ -0,0 +1,172 @@
import datetime
from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, UnicodeText
from sqlalchemy.orm import DeclarativeBase
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.id import (
generate_artifact_id,
generate_aws_secret_parameter_id,
generate_org_id,
generate_organization_auth_token_id,
generate_step_id,
generate_task_id,
generate_workflow_id,
generate_workflow_parameter_id,
generate_workflow_run_id,
)
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
class Base(DeclarativeBase):
    """Declarative base class shared by all ORM models in this module."""
class TaskModel(Base):
    """ORM model mapped to the ``tasks`` table."""

    __tablename__ = "tasks"

    # Primary key; generate_task_id supplies the value when none is given.
    task_id = Column(String, primary_key=True, index=True, default=generate_task_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    status = Column(String)
    webhook_callback_url = Column(String)
    url = Column(String)
    navigation_goal = Column(String)
    data_extraction_goal = Column(String)
    navigation_payload = Column(JSON)
    extracted_information = Column(JSON)
    failure_reason = Column(String)
    proxy_location = Column(Enum(ProxyLocation))
    extracted_information_schema = Column(JSON)
    # Optional link to a workflow run.
    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
    order = Column(Integer, nullable=True)
    retry = Column(Integer, nullable=True)
    # Timestamps: set on insert; modified_at is also refreshed on every update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class StepModel(Base):
    """ORM model mapped to the ``steps`` table."""

    __tablename__ = "steps"

    # Primary key; generate_step_id supplies the value when none is given.
    step_id = Column(String, primary_key=True, index=True, default=generate_step_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    task_id = Column(String, ForeignKey("tasks.task_id"))
    status = Column(String)
    output = Column(JSON)
    order = Column(Integer)
    is_last = Column(Boolean, default=False)
    retry_index = Column(Integer, default=0)
    # Timestamps: set on insert; modified_at is also refreshed on every update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # Token counts and cost recorded for the step; all default to 0.
    input_token_count = Column(Integer, default=0)
    output_token_count = Column(Integer, default=0)
    step_cost = Column(Numeric, default=0)
class OrganizationModel(Base):
    """ORM model mapped to the ``organizations`` table."""

    __tablename__ = "organizations"

    # Primary key; generate_org_id supplies the value when none is given.
    organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
    organization_name = Column(String, nullable=False)
    webhook_callback_url = Column(UnicodeText)
    max_steps_per_run = Column(Integer)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # onupdate must be the utcnow callable, as on the other models. Passing the
    # datetime class itself would make SQLAlchemy call datetime.datetime() with
    # no arguments on every UPDATE, which raises TypeError.
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class OrganizationAuthTokenModel(Base):
    """ORM model mapped to the ``organization_auth_tokens`` table."""

    __tablename__ = "organization_auth_tokens"

    # Primary key; generate_organization_auth_token_id supplies the value
    # when none is given.
    id = Column(
        String,
        primary_key=True,
        index=True,
        default=generate_organization_auth_token_id,
    )
    organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
    token_type = Column(Enum(OrganizationAuthTokenType), nullable=False)
    token = Column(String, index=True, nullable=False)
    valid = Column(Boolean, nullable=False, default=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # onupdate must be the utcnow callable, as on the other models. Passing the
    # datetime class itself would make SQLAlchemy call datetime.datetime() with
    # no arguments on every UPDATE, which raises TypeError.
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
class ArtifactModel(Base):
    """ORM model mapped to the ``artifacts`` table."""

    __tablename__ = "artifacts"

    # Primary key; generate_artifact_id supplies the value when none is given.
    artifact_id = Column(String, primary_key=True, index=True, default=generate_artifact_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    task_id = Column(String, ForeignKey("tasks.task_id"))
    step_id = Column(String, ForeignKey("steps.step_id"))
    artifact_type = Column(String)
    uri = Column(String)
    # Timestamps: set on insert; modified_at is also refreshed on every update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class WorkflowModel(Base):
    """ORM model mapped to the ``workflows`` table."""

    __tablename__ = "workflows"

    # Primary key; generate_workflow_id supplies the value when none is given.
    workflow_id = Column(String, primary_key=True, index=True, default=generate_workflow_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    title = Column(String, nullable=False)
    description = Column(String, nullable=True)
    # Workflow definition stored as JSON.
    workflow_definition = Column(JSON, nullable=False)
    # Timestamps: set on insert; modified_at is also refreshed on every update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # Nullable deletion timestamp.
    deleted_at = Column(DateTime, nullable=True)
class WorkflowRunModel(Base):
    """ORM model mapped to the ``workflow_runs`` table."""

    __tablename__ = "workflow_runs"

    # Primary key; generate_workflow_run_id supplies the value when none is given.
    workflow_run_id = Column(String, primary_key=True, index=True, default=generate_workflow_run_id)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=False)
    status = Column(String, nullable=False)
    proxy_location = Column(Enum(ProxyLocation))
    webhook_callback_url = Column(String)
    # Timestamps: set on insert; modified_at is also refreshed on every update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class WorkflowParameterModel(Base):
    """ORM model mapped to the ``workflow_parameters`` table."""

    __tablename__ = "workflow_parameters"

    # Primary key; generate_workflow_parameter_id supplies the value when none is given.
    workflow_parameter_id = Column(String, primary_key=True, index=True, default=generate_workflow_parameter_id)
    # Parameter type, stored as a string.
    workflow_parameter_type = Column(String, nullable=False)
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    # Default value, stored as a string.
    default_value = Column(String, nullable=True)
    # Timestamps: set on insert; modified_at is also refreshed on every update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # Nullable deletion timestamp.
    deleted_at = Column(DateTime, nullable=True)
class AWSSecretParameterModel(Base):
    """ORM model mapped to the ``aws_secret_parameters`` table."""

    __tablename__ = "aws_secret_parameters"

    # Primary key; generate_aws_secret_parameter_id supplies the value when none is given.
    aws_secret_parameter_id = Column(String, primary_key=True, index=True, default=generate_aws_secret_parameter_id)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    aws_key = Column(String, nullable=False)
    # Timestamps: set on insert; modified_at is also refreshed on every update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # Nullable deletion timestamp.
    deleted_at = Column(DateTime, nullable=True)
class WorkflowRunParameterModel(Base):
    """
    ORM model mapped to the ``workflow_run_parameters`` table.

    The primary key is composite: (workflow_run_id, workflow_parameter_id).
    """

    __tablename__ = "workflow_run_parameters"

    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
    workflow_parameter_id = Column(
        String, ForeignKey("workflow_parameters.workflow_parameter_id"), primary_key=True, index=True
    )
    # Can be bool | int | float | str | dict | list depending on the workflow parameter type
    # (stored as a string in the database).
    value = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View File

@@ -0,0 +1,220 @@
import json
import typing
import pydantic.json
import structlog
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.models import (
ArtifactModel,
AWSSecretParameterModel,
OrganizationAuthTokenModel,
OrganizationModel,
StepModel,
TaskModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowDefinition,
WorkflowRun,
WorkflowRunParameter,
WorkflowRunStatus,
)
LOG = structlog.get_logger()
@typing.no_type_check
def _custom_json_serializer(*args, **kwargs) -> str:
    """
    Serialize to JSON using pydantic's encoder as the fallback for values
    that ``json.dumps`` cannot handle natively.
    """
    fallback_encoder = pydantic.json.pydantic_encoder
    return json.dumps(*args, default=fallback_encoder, **kwargs)
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
    """
    Build a Task schema object from a TaskModel row.

    The stored status string is parsed into a TaskStatus, and the proxy
    location into a ProxyLocation (or None when unset). A debug line is
    logged when ``debug_enabled`` is true.
    """
    if debug_enabled:
        LOG.debug("Converting TaskModel to Task", task_id=task_obj.task_id)

    return Task(
        task_id=task_obj.task_id,
        status=TaskStatus(task_obj.status),
        created_at=task_obj.created_at,
        modified_at=task_obj.modified_at,
        url=task_obj.url,
        webhook_callback_url=task_obj.webhook_callback_url,
        navigation_goal=task_obj.navigation_goal,
        data_extraction_goal=task_obj.data_extraction_goal,
        navigation_payload=task_obj.navigation_payload,
        extracted_information=task_obj.extracted_information,
        failure_reason=task_obj.failure_reason,
        organization_id=task_obj.organization_id,
        proxy_location=None if not task_obj.proxy_location else ProxyLocation(task_obj.proxy_location),
        extracted_information_schema=task_obj.extracted_information_schema,
        workflow_run_id=task_obj.workflow_run_id,
        order=task_obj.order,
        retry=task_obj.retry,
    )
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
    """
    Build a Step object from a StepModel row.

    The stored status string is parsed into a StepStatus; all other fields
    are copied as-is. A debug line is logged when ``debug_enabled`` is true.
    """
    if debug_enabled:
        LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)

    step = Step(
        task_id=step_model.task_id,
        step_id=step_model.step_id,
        created_at=step_model.created_at,
        modified_at=step_model.modified_at,
        status=StepStatus(step_model.status),
        output=step_model.output,
        order=step_model.order,
        is_last=step_model.is_last,
        retry_index=step_model.retry_index,
        organization_id=step_model.organization_id,
        input_token_count=step_model.input_token_count,
        output_token_count=step_model.output_token_count,
        step_cost=step_model.step_cost,
    )
    return step
def convert_to_organization(org_model: OrganizationModel) -> Organization:
    """Build an Organization object by copying the fields of an OrganizationModel row."""
    organization = Organization(
        organization_id=org_model.organization_id,
        organization_name=org_model.organization_name,
        webhook_callback_url=org_model.webhook_callback_url,
        max_steps_per_run=org_model.max_steps_per_run,
        created_at=org_model.created_at,
        modified_at=org_model.modified_at,
    )
    return organization
def convert_to_organization_auth_token(org_auth_token: OrganizationAuthTokenModel) -> OrganizationAuthToken:
    """
    Build an OrganizationAuthToken from an OrganizationAuthTokenModel row.

    The stored token type is parsed into an OrganizationAuthTokenType; the
    row's ``deleted_at`` column is not carried over.
    """
    auth_token = OrganizationAuthToken(
        id=org_auth_token.id,
        organization_id=org_auth_token.organization_id,
        token_type=OrganizationAuthTokenType(org_auth_token.token_type),
        token=org_auth_token.token,
        valid=org_auth_token.valid,
        created_at=org_auth_token.created_at,
        modified_at=org_auth_token.modified_at,
    )
    return auth_token
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
    """
    Build an Artifact object from an ArtifactModel row.

    The stored artifact type is upper-cased and looked up by member name in
    ArtifactType. A debug line is logged when ``debug_enabled`` is true.
    """
    if debug_enabled:
        LOG.debug("Converting ArtifactModel to Artifact", artifact_id=artifact_model.artifact_id)

    artifact = Artifact(
        artifact_id=artifact_model.artifact_id,
        artifact_type=ArtifactType[artifact_model.artifact_type.upper()],
        uri=artifact_model.uri,
        task_id=artifact_model.task_id,
        step_id=artifact_model.step_id,
        created_at=artifact_model.created_at,
        modified_at=artifact_model.modified_at,
        organization_id=artifact_model.organization_id,
    )
    return artifact
def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = False) -> Workflow:
    """
    Build a Workflow object from a WorkflowModel row.

    The stored JSON definition is validated into a WorkflowDefinition. A
    debug line is logged when ``debug_enabled`` is true.
    """
    if debug_enabled:
        LOG.debug("Converting WorkflowModel to Workflow", workflow_id=workflow_model.workflow_id)

    workflow = Workflow(
        workflow_id=workflow_model.workflow_id,
        organization_id=workflow_model.organization_id,
        title=workflow_model.title,
        description=workflow_model.description,
        workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
        created_at=workflow_model.created_at,
        modified_at=workflow_model.modified_at,
        deleted_at=workflow_model.deleted_at,
    )
    return workflow
def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: bool = False) -> WorkflowRun:
    """
    Build a WorkflowRun object from a WorkflowRunModel row.

    The stored status is looked up by member name in WorkflowRunStatus, and
    the proxy location is parsed into a ProxyLocation (or None when unset).
    A debug line is logged when ``debug_enabled`` is true.
    """
    if debug_enabled:
        LOG.debug("Converting WorkflowRunModel to WorkflowRun", workflow_run_id=workflow_run_model.workflow_run_id)

    workflow_run = WorkflowRun(
        workflow_run_id=workflow_run_model.workflow_run_id,
        workflow_id=workflow_run_model.workflow_id,
        status=WorkflowRunStatus[workflow_run_model.status],
        proxy_location=None
        if not workflow_run_model.proxy_location
        else ProxyLocation(workflow_run_model.proxy_location),
        webhook_callback_url=workflow_run_model.webhook_callback_url,
        created_at=workflow_run_model.created_at,
        modified_at=workflow_run_model.modified_at,
    )
    return workflow_run
def convert_to_workflow_parameter(
    workflow_parameter_model: WorkflowParameterModel, debug_enabled: bool = False
) -> WorkflowParameter:
    """
    Build a WorkflowParameter object from a WorkflowParameterModel row.

    The stored type string is upper-cased and looked up by member name in
    WorkflowParameterType; that type then converts the stored default value.
    A debug line is logged when ``debug_enabled`` is true.
    """
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowParameterModel to WorkflowParameter",
            workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
        )

    type_name = workflow_parameter_model.workflow_parameter_type.upper()
    param_type = WorkflowParameterType[type_name]
    return WorkflowParameter(
        workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
        workflow_parameter_type=param_type,
        workflow_id=workflow_parameter_model.workflow_id,
        default_value=param_type.convert_value(workflow_parameter_model.default_value),
        key=workflow_parameter_model.key,
        description=workflow_parameter_model.description,
        created_at=workflow_parameter_model.created_at,
        modified_at=workflow_parameter_model.modified_at,
        deleted_at=workflow_parameter_model.deleted_at,
    )
def convert_to_aws_secret_parameter(
    aws_secret_parameter_model: AWSSecretParameterModel, debug_enabled: bool = False
) -> AWSSecretParameter:
    """
    Build an AWSSecretParameter object from an AWSSecretParameterModel row.

    Args:
        aws_secret_parameter_model: Row to convert.
        debug_enabled: When true, log a debug line identifying the row.

    Returns:
        An AWSSecretParameter carrying the row's fields.
    """
    if debug_enabled:
        LOG.debug(
            "Converting AWSSecretParameterModel to AWSSecretParameter",
            # The model's primary key is aws_secret_parameter_id; it has no
            # ``id`` attribute, so reading ``.id`` here raised AttributeError.
            aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
        )
    return AWSSecretParameter(
        aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
        workflow_id=aws_secret_parameter_model.workflow_id,
        key=aws_secret_parameter_model.key,
        description=aws_secret_parameter_model.description,
        aws_key=aws_secret_parameter_model.aws_key,
        created_at=aws_secret_parameter_model.created_at,
        modified_at=aws_secret_parameter_model.modified_at,
        deleted_at=aws_secret_parameter_model.deleted_at,
    )
def convert_to_workflow_run_parameter(
    workflow_run_parameter_model: WorkflowRunParameterModel,
    workflow_parameter: WorkflowParameter,
    debug_enabled: bool = False,
) -> WorkflowRunParameter:
    """
    Build a WorkflowRunParameter from a WorkflowRunParameterModel row.

    The stored value is converted using the type of the given workflow
    parameter. A debug line is logged when ``debug_enabled`` is true.
    """
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowRunParameterModel to WorkflowRunParameter",
            workflow_run_id=workflow_run_parameter_model.workflow_run_id,
            workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
        )

    run_parameter = WorkflowRunParameter(
        workflow_run_id=workflow_run_parameter_model.workflow_run_id,
        workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
        value=workflow_parameter.workflow_parameter_type.convert_value(workflow_run_parameter_model.value),
        created_at=workflow_run_parameter_model.created_at,
    )
    return run_parameter

View File

View File

@@ -0,0 +1,85 @@
import abc
from fastapi import BackgroundTasks
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Organization
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
class AsyncExecutor(abc.ABC):
    """Abstract interface for starting task and workflow execution."""

    @abc.abstractmethod
    async def execute_task(
        self,
        background_tasks: BackgroundTasks,
        task: Task,
        organization: Organization,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """Start executing ``task`` for ``organization``; implemented by subclasses."""

    @abc.abstractmethod
    async def execute_workflow(
        self,
        background_tasks: BackgroundTasks,
        organization: Organization,
        workflow_id: str,
        workflow_run_id: str,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """Start executing the given workflow run; implemented by subclasses."""
class BackgroundTaskExecutor(AsyncExecutor):
    """AsyncExecutor that schedules work on FastAPI's BackgroundTasks."""

    async def execute_task(
        self,
        background_tasks: BackgroundTasks,
        task: Task,
        organization: Organization,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """
        Create the first step, mark the task as running, and schedule the
        agent to execute that step in the background.
        """
        org_id = organization.organization_id

        first_step = await app.DATABASE.create_step(
            task.task_id,
            order=0,
            retry_index=0,
            organization_id=org_id,
        )
        task = await app.DATABASE.update_task(
            task.task_id,
            TaskStatus.running,
            organization_id=org_id,
        )

        # Record the task on the current context before scheduling.
        ctx: SkyvernContext = skyvern_context.ensure_context()
        ctx.task_id = task.task_id
        ctx.organization_id = org_id
        ctx.max_steps_override = max_steps_override

        background_tasks.add_task(app.agent.execute_step, organization, task, first_step, api_key)

    async def execute_workflow(
        self,
        background_tasks: BackgroundTasks,
        organization: Organization,
        workflow_id: str,
        workflow_run_id: str,
        max_steps_override: int | None,
        api_key: str | None,
    ) -> None:
        """
        Schedule the workflow service to execute the given workflow run in
        the background. Only ``workflow_run_id`` and ``api_key`` are forwarded.
        """
        run_workflow = app.WORKFLOW_SERVICE.execute_workflow
        background_tasks.add_task(run_workflow, workflow_run_id=workflow_run_id, api_key=api_key)

View File

@@ -0,0 +1,13 @@
from skyvern.forge.sdk.executor.async_executor import AsyncExecutor, BackgroundTaskExecutor
class AsyncExecutorFactory:
    """Holds the process-wide AsyncExecutor; defaults to a BackgroundTaskExecutor."""

    __instance: AsyncExecutor = BackgroundTaskExecutor()

    @staticmethod
    def set_executor(executor: AsyncExecutor) -> None:
        """Replace the executor returned by ``get_executor``."""
        AsyncExecutorFactory.__instance = executor

    @staticmethod
    def get_executor() -> AsyncExecutor:
        """Return the currently configured executor."""
        current = AsyncExecutorFactory.__instance
        return current

View File

@@ -0,0 +1,90 @@
import logging
import structlog
from structlog.typing import EventDict
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.settings_manager import SettingsManager
def add_kv_pairs_to_msg(logger: logging.Logger, method_name: str, event_dict: EventDict) -> EventDict:
    """
    structlog processor that enriches an event and folds its key-value pairs
    into the 'msg' field.

    Context identifiers (when set) and the configured ENV are always added to
    the event. For info/warning/error/critical/exception calls, every key
    other than msg, timestamp and level is also appended to 'msg' as
    ``key=value`` text.
    """
    # Add context to the log
    context = skyvern_context.current()
    if context:
        for field in ("request_id", "organization_id", "task_id", "workflow_id", "workflow_run_id"):
            field_value = getattr(context, field)
            if field_value:
                event_dict[field] = field_value

    # Add env to the log
    event_dict["env"] = SettingsManager.get_settings().ENV

    # Only modify the message for these log levels
    if method_name not in ["info", "warning", "error", "critical", "exception"]:
        return event_dict

    msg_field = event_dict.get("msg", "")
    extras = {}
    for name, item in event_dict.items():
        if name not in ["msg", "timestamp", "level"]:
            extras[name] = item
    if extras:
        additional_info = ", ".join(f"{name}={item}" for name, item in extras.items())
        msg_field += f" | {additional_info}"
    event_dict["msg"] = msg_field
    return event_dict
def setup_logger() -> None:
    """
    Configure structlog and disable uvicorn's own loggers.

    With JSON_LOGGING enabled, events are rendered as JSON, the event is
    renamed to 'msg', key-value pairs are folded into it, and call-site
    details are attached. Otherwise the console renderer is used with no
    extra processors.
    """
    if SettingsManager.get_settings().JSON_LOGGING:
        renderer = structlog.processors.JSONRenderer()
    else:
        renderer = structlog.dev.ConsoleRenderer()

    if SettingsManager.get_settings().JSON_LOGGING:
        additional_processors = [
            structlog.processors.EventRenamer("msg"),
            add_kv_pairs_to_msg,
            structlog.processors.CallsiteParameterAdder(
                {
                    structlog.processors.CallsiteParameter.PATHNAME,
                    structlog.processors.CallsiteParameter.FILENAME,
                    structlog.processors.CallsiteParameter.MODULE,
                    structlog.processors.CallsiteParameter.FUNC_NAME,
                    structlog.processors.CallsiteParameter.LINENO,
                }
            ),
        ]
    else:
        additional_processors = []

    # Base processors run first, then the optional ones, and the renderer last.
    pipeline = [
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.format_exc_info,
    ]
    pipeline = pipeline + additional_processors + [renderer]
    structlog.configure(processors=pipeline)

    for logger_name in ("uvicorn.error", "uvicorn.access"):
        logging.getLogger(logger_name).disabled = True

137
skyvern/forge/sdk/models.py Normal file
View File

@@ -0,0 +1,137 @@
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.webeye.actions.actions import ActionType
from skyvern.webeye.actions.models import AgentStepOutput
class StepStatus(StrEnum):
created = "created"
running = "running"
failed = "failed"
completed = "completed"
def can_update_to(self, new_status: StepStatus) -> bool:
allowed_transitions: dict[StepStatus, set[StepStatus]] = {
StepStatus.created: {StepStatus.running},
StepStatus.running: {StepStatus.completed, StepStatus.failed},
StepStatus.failed: set(),
StepStatus.completed: set(),
}
return new_status in allowed_transitions[self]
def requires_output(self) -> bool:
status_requires_output = {StepStatus.completed}
return self in status_requires_output
def cant_have_output(self) -> bool:
status_cant_have_output = {StepStatus.created, StepStatus.running}
return self in status_cant_have_output
def is_terminal(self) -> bool:
status_is_terminal = {StepStatus.failed, StepStatus.completed}
return self in status_is_terminal
class Step(BaseModel):
    """A single step of a task, with its status, output, and usage counters."""

    created_at: datetime
    modified_at: datetime
    task_id: str
    step_id: str
    status: StepStatus
    output: AgentStepOutput | None = None
    order: int
    is_last: bool
    retry_index: int = 0
    organization_id: str | None = None
    input_token_count: int = 0
    output_token_count: int = 0
    step_cost: float = 0

    def validate_update(self, status: StepStatus | None, output: AgentStepOutput | None, is_last: bool | None) -> None:
        """
        Check that a proposed update is allowed for this step.

        Raises ValueError for the first violated rule, checked in this order:
        invalid status transition, missing required output, output on a
        status that cannot have one, output without a status change,
        overriding existing output, ``is_last`` on a step whose current
        status is not terminal, and resetting ``is_last`` to False.
        """
        current = self.status
        if status:
            if not current.can_update_to(status):
                raise ValueError(f"invalid_status_transition({current},{status},{self.step_id})")
            if status.requires_output() and output is None:
                raise ValueError(f"status_requires_output({status},{self.step_id})")
            if status.cant_have_output() and output is not None:
                raise ValueError(f"status_cant_have_output({status},{self.step_id})")

        if status is None and output is not None:
            raise ValueError(f"cant_set_output_without_updating_status({self.step_id})")

        if output is not None and self.output is not None:
            raise ValueError(f"cant_override_output({self.step_id})")

        if is_last and not self.status.is_terminal():
            raise ValueError(f"is_last_but_status_not_terminal({self.status},{self.step_id})")

        if is_last is False:
            raise ValueError(f"cant_set_is_last_to_false({self.step_id})")

    def is_goal_achieved(self) -> bool:
        """Return True when the step completed and contains a successful COMPLETE action."""
        if self.status != StepStatus.completed:
            return False
        # TODO (kerem): Remove this check once we have backfilled all the steps
        if self.output is None or self.output.actions_and_results is None:
            return False

        return any(
            action.action_type == ActionType.COMPLETE and any(result.success for result in results)
            for action, results in self.output.actions_and_results
        )

    def is_terminated(self) -> bool:
        """Return True when the step completed and contains a successful TERMINATE action."""
        if self.status != StepStatus.completed:
            return False
        # TODO (kerem): Remove this check once we have backfilled all the steps
        if self.output is None or self.output.actions_and_results is None:
            return False

        return any(
            action.action_type == ActionType.TERMINATE and any(result.success for result in results)
            for action, results in self.output.actions_and_results
        )
class Organization(BaseModel):
    """Pydantic representation of an organization."""

    organization_id: str
    organization_name: str
    # Optional settings; both default to None.
    webhook_callback_url: str | None = None
    max_steps_per_run: int | None = None
    created_at: datetime
    modified_at: datetime
class OrganizationAuthToken(BaseModel):
    """Pydantic representation of an organization's auth token."""

    id: str
    organization_id: str
    token_type: OrganizationAuthTokenType
    token: str
    valid: bool
    created_at: datetime
    modified_at: datetime
class TokenPayload(BaseModel):
    """Payload with ``sub`` and ``exp`` fields."""

    # NOTE(review): names match the JWT "subject" and "expiration" claims - confirm against the auth code.
    sub: str
    exp: int

View File

@@ -0,0 +1,98 @@
"""
Relative to this file there is a prompts directory, located at ../prompts.
In this directory there is a techniques directory and a directory for each model - gpt-3.5-turbo, gpt-4, llama-2-70B, code-llama-7B, etc.
Each directory holds Jinja2 templates for the prompts.
Prompts in the model directories can use the techniques in the techniques directory.
Write the code I'd need to load and populate the templates.
I want the following functions:
class PromptEngine:
def __init__(self, model):
pass
def load_prompt(model, prompt_name, prompt_ags) -> str:
pass
"""
import glob
import os
from difflib import get_close_matches
from typing import Any, List
import structlog
from jinja2 import Environment, FileSystemLoader
LOG = structlog.get_logger()
class PromptEngine:
    """
    Class to handle loading and populating Jinja2 templates for prompts.
    """

    def __init__(self, model: str):
        """
        Initialize the PromptEngine with the specified model.

        The requested model is resolved to the closest-matching model
        directory under ``../prompts`` (relative to this file), and a Jinja2
        environment rooted at that prompts directory is created.

        Args:
            model (str): The model to use for loading prompts.

        Raises:
            Exception: Any error during initialization is logged and re-raised.
        """
        self.model = model
        try:
            # Get the list of all model directories
            models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts"))
            model_names = []
            for path in glob.glob(os.path.join(models_dir, "*/")):
                name = os.path.basename(os.path.normpath(path))
                # Filter on the directory name only. Testing the full path
                # would drop every model whenever a parent directory happens
                # to contain "techniques".
                if os.path.isdir(path) and "techniques" not in name:
                    model_names.append(name)

            self.model = self.get_closest_match(self.model, model_names)
            self.env = Environment(loader=FileSystemLoader(models_dir))
        except Exception:
            LOG.error("Error initializing PromptEngine.", model=model, exc_info=True)
            raise

    @staticmethod
    def get_closest_match(target: str, model_dirs: List[str]) -> str:
        """
        Find the closest match to the target in the list of model directories.

        Args:
            target (str): The target model.
            model_dirs (list): The list of available model directories.

        Returns:
            str: The closest match to the target.

        Raises:
            IndexError: When nothing in ``model_dirs`` is close enough to
                ``target`` (logged and re-raised).
        """
        try:
            matches = get_close_matches(target, model_dirs, n=1, cutoff=0.1)
            return matches[0]
        except Exception:
            LOG.error("Failed to get closest match.", target=target, model_dirs=model_dirs, exc_info=True)
            raise

    def load_prompt(self, template: str, **kwargs: Any) -> str:
        """
        Load and populate the specified template.

        Args:
            template (str): The name of the template to load (without the
                ``.j2`` extension), relative to the model directory.
            **kwargs: The arguments to populate the template with.

        Returns:
            str: The populated template.

        Raises:
            Exception: Any loading or rendering error is logged and re-raised.
        """
        # Jinja2 template names always use "/" as the separator, regardless of
        # the operating system, so do not build them with os.path.join.
        template_name = f"{self.model}/{template}.j2"
        try:
            jinja_template = self.env.get_template(template_name)
            return jinja_template.render(**kwargs)
        except Exception:
            LOG.error("Failed to load prompt.", template=template_name, kwargs_keys=kwargs.keys(), exc_info=True)
            raise

View File

View File

@@ -0,0 +1,397 @@
from typing import Annotated, Any
import structlog
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, status
from fastapi.responses import ORJSONResponse
from pydantic import BaseModel
from skyvern.exceptions import StepNotFound
from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.models.workflow import (
RunWorkflowResponse,
WorkflowRequestBody,
WorkflowRunStatusResponse,
)
# Router on which every endpoint in this module is registered.
base_router = APIRouter()
# Module-level structlog logger shared by the route handlers below.
LOG = structlog.get_logger()
@base_router.post("/webhook", tags=["server"])
async def webhook(
    request: Request,
    x_skyvern_signature: Annotated[str | None, Header()] = None,
    x_skyvern_timestamp: Annotated[str | None, Header()] = None,
) -> Response:
    """
    Receive a webhook call and log whether its signature matches.

    The request body is signed locally with the configured SKYVERN_API_KEY and
    the result is compared with the ``x-skyvern-signature`` header. The outcome
    is only logged; the endpoint answers 200 either way once both headers are
    present.

    Raises:
        HTTPException: 400 when the signature or timestamp header is missing.
    """
    payload = await request.body()

    headers_present = bool(x_skyvern_signature) and bool(x_skyvern_timestamp)
    if not headers_present:
        LOG.error(
            "Webhook signature or timestamp missing",
            x_skyvern_signature=x_skyvern_signature,
            x_skyvern_timestamp=x_skyvern_timestamp,
            payload=payload,
        )
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing webhook signature or timestamp")

    signing_key = SettingsManager.get_settings().SKYVERN_API_KEY
    generated_signature = generate_skyvern_signature(payload.decode("utf-8"), signing_key)
    signatures_match = x_skyvern_signature == generated_signature

    LOG.info(
        "Webhook received",
        x_skyvern_signature=x_skyvern_signature,
        x_skyvern_timestamp=x_skyvern_timestamp,
        payload=payload,
        generated_signature=generated_signature,
        valid_signature=signatures_match,
    )
    return Response(content="webhook validation", status_code=200)
@base_router.get("/heartbeat", tags=["server"])
async def check_server_status() -> Response:
    """
    Heartbeat endpoint: always answers 200 with a fixed text body.
    """
    heartbeat = Response(content="Server is running.", status_code=200)
    return heartbeat
@base_router.post("/tasks", tags=["agent"], response_model=CreateTaskResponse)
async def create_agent_task(
    background_tasks: BackgroundTasks,
    request: Request,
    task: TaskRequest,
    current_org: Organization = Depends(org_auth_service.get_current_org),
    x_api_key: Annotated[str | None, Header()] = None,
    x_max_steps_override: Annotated[int | None, Header()] = None,
) -> CreateTaskResponse:
    """
    Create a task for the caller's organization and hand it to the async executor.

    The optional ``x-max-steps-override`` header is logged when set and always
    forwarded to the executor together with the API key.

    Returns:
        CreateTaskResponse: carries the id of the newly created task.
    """
    created_task = await request["agent"].create_task(task, current_org.organization_id)

    if x_max_steps_override:
        LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)

    executor_kwargs = {
        "background_tasks": background_tasks,
        "task": created_task,
        "organization": current_org,
        "max_steps_override": x_max_steps_override,
        "api_key": x_api_key,
    }
    await app.ASYNC_EXECUTOR.execute_task(**executor_kwargs)

    return CreateTaskResponse(task_id=created_task.task_id)
@base_router.post(
    "/tasks/{task_id}/steps/{step_id}",
    tags=["agent"],
    response_model=Step,
    summary="Executes a specific step",
)
@base_router.post(
    "/tasks/{task_id}/steps/",
    tags=["agent"],
    response_model=Step,
    summary="Executes the next step",
)
async def execute_agent_task_step(
    request: Request,
    task_id: str,
    step_id: str | None = None,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Execute one step of a task that belongs to the caller's organization.

    When ``step_id`` is omitted the latest step of the task is executed,
    otherwise the step with the given id.

    Returns:
        Response: JSON body with the executed step (empty body when the agent
        returns no step), status 200.

    Raises:
        HTTPException: 404 when the task does not exist for this organization.
        StepNotFound: when the requested (or latest) step cannot be found.
    """
    agent = request["agent"]
    organization_id = current_org.organization_id

    task = await app.DATABASE.get_task(task_id, organization_id=organization_id)
    if not task:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No task found with id {task_id}",
        )

    # An empty step request means that the agent should execute the next step for the task.
    # Each branch previously carried a second `if not step` check after the
    # StepNotFound raise; it could never run and has been removed.
    if not step_id:
        step = await app.DATABASE.get_latest_step(task_id=task_id, organization_id=organization_id)
        if not step:
            raise StepNotFound(organization_id, task_id)
        LOG.info(
            "Executing latest step since no step_id was provided",
            task_id=task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )
    else:
        step = await app.DATABASE.get_step(task_id, step_id, organization_id=organization_id)
        if not step:
            raise StepNotFound(organization_id, task_id, step_id)
        LOG.info(
            "Executing step",
            task_id=task_id,
            step_id=step.step_id,
            step_order=step.order,
            step_retry=step.retry_index,
        )

    step, _, _ = await agent.execute_step(current_org, task, step)
    return Response(
        content=step.model_dump_json() if step else "",
        status_code=200,
        media_type="application/json",
    )
@base_router.get("/tasks/{task_id}", response_model=TaskResponse)
async def get_task(
request: Request,
task_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskResponse:
"""
Return one task of the caller's organization as a TaskResponse.

When the task has a latest step, the response is enriched with share links
for the latest screenshot and recording artifacts (when they exist) and,
for failed tasks, with a failure reason assembled from the task's own
failure reason and the latest step's action results.

Raises:
HTTPException: 404 when the task is not found for this organization.
"""
# The looked-up value is discarded; only the lookup itself is performed.
request["agent"]
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
if not task_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task not found {task_id}",
)
# get latest step
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=current_org.organization_id)
# Without any step there is nothing to enrich: return the plain task response.
if not latest_step:
return task_obj.to_task_response()
screenshot_url = None
# todo (kerem): only access artifacts through the artifact manager instead of db
# Latest screenshot (action or final) recorded for the latest step.
screenshot_artifact = await app.DATABASE.get_latest_artifact(
task_id=task_obj.task_id,
step_id=latest_step.step_id,
artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
organization_id=current_org.organization_id,
)
if screenshot_artifact:
screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)
# The recording is looked up per task (no step_id filter is passed).
recording_artifact = await app.DATABASE.get_latest_artifact(
task_id=task_obj.task_id,
artifact_types=[ArtifactType.RECORDING],
organization_id=current_org.organization_id,
)
recording_url = None
if recording_artifact:
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
# failure_reason stays None unless the task failed and there is something to report.
failure_reason = None
if task_obj.status == TaskStatus.failed and (latest_step.output or task_obj.failure_reason):
failure_reason = ""
if task_obj.failure_reason:
failure_reason += f"Reasoning: {task_obj.failure_reason or ''}"
failure_reason += "\n"
if latest_step.output and latest_step.output.action_results:
failure_reason += "Exceptions: "
# One "[type]: message" entry per action result, rendered as the str() of a list.
failure_reason += str(
[f"[{ar.exception_type}]: {ar.exception_message}" for ar in latest_step.output.action_results]
)
# to_task_response falls back to the task's own failure_reason when this one is empty/None.
return task_obj.to_task_response(
screenshot_url=screenshot_url,
recording_url=recording_url,
failure_reason=failure_reason,
)
# The declared response_model is a single Task: the handler returns one task
# dump, not a list. Because an ORJSONResponse is returned directly, the model
# only affects the generated OpenAPI schema, not the payload.
@base_router.get("/internal/tasks/{task_id}", response_model=Task)
async def get_task_internal(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get a single task in its internal (full Task model) representation.
    :param request:
    :param task_id: Id of the task to fetch; looked up within the caller's organization.
    :return: The task dumped as JSON.
    :raises HTTPException: 404 when the task is not found for this organization.
    """
    task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not task:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task not found {task_id}",
        )
    return ORJSONResponse(task.model_dump())
# The payload is built from to_task_response(), so the documented schema is a
# list of TaskResponse. Because an ORJSONResponse is returned directly, the
# response_model only affects the generated OpenAPI schema, not the payload.
@base_router.get("/tasks", tags=["agent"], response_model=list[TaskResponse])
async def get_agent_tasks(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get all tasks.
    :param request:
    :param page: Starting page, defaults to 1
    :param page_size: Page size, defaults to 10
    :return: List of task responses with pagination without steps populated. Steps can be populated by
        calling the get_agent_task endpoint.
    """
    # The looked-up value is discarded; only the lookup itself is performed.
    request["agent"]
    tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
    return ORJSONResponse([task.to_task_response().model_dump() for task in tasks])
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
async def get_agent_tasks_internal(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    List the organization's tasks in their internal (full Task model) form.
    :param request:
    :param page: Starting page, defaults to 1
    :param page_size: Page size, defaults to 10
    :return: One page of tasks, each dumped as a plain dict.
    """
    # The looked-up value is discarded; only the lookup itself is performed.
    request["agent"]
    org_id = current_org.organization_id
    page_of_tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=org_id)
    dumped_tasks = [item.model_dump() for item in page_of_tasks]
    return ORJSONResponse(dumped_tasks)
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
async def get_agent_task_steps(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get the steps recorded for a task.
    :param request:
    :param task_id: Id of the task whose steps are fetched within the caller's organization.
    :return: List of the task's steps, each dumped as a plain dict.
    """
    # The looked-up value is discarded; only the lookup itself is performed.
    request["agent"]
    org_id = current_org.organization_id
    task_steps = await app.DATABASE.get_task_steps(task_id, organization_id=org_id)
    dumped_steps = [task_step.model_dump() for task_step in task_steps]
    return ORJSONResponse(dumped_steps)
@base_router.get("/tasks/{task_id}/steps/{step_id}/artifacts", tags=["agent"], response_model=list[Artifact])
async def get_agent_task_step_artifacts(
    request: Request,
    task_id: str,
    step_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
    """
    Get the artifacts recorded for one step of a task.
    :param request:
    :param task_id: Id of the task the step belongs to.
    :param step_id: Id of the step whose artifacts are fetched.
    :return: List of artifacts for the step, each dumped as a plain dict.
    """
    # The looked-up value is discarded; only the lookup itself is performed.
    request["agent"]
    org_id = current_org.organization_id
    step_artifacts = await app.DATABASE.get_artifacts_for_task_step(task_id, step_id, organization_id=org_id)
    dumped_artifacts = [item.model_dump() for item in step_artifacts]
    return ORJSONResponse(dumped_artifacts)
class ActionResultTmp(BaseModel):
    """Response shape for a single action result returned by the task-actions endpoint."""

    # The only required field.
    action: dict[str, Any]
    # Optional payload of the result; defaults to None.
    data: dict[str, Any] | list | str | None = None
    # Optional exception text; defaults to None.
    exception_message: str | None = None
    # Success flag; defaults to True when absent from the input.
    success: bool = True
@base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp])
async def get_task_actions(
    request: Request,
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[ActionResultTmp]:
    """
    Collect the action results of every step of a task, in step order.

    Steps without output, or whose output has no "action_results" key, are
    skipped. Every remaining entry is validated into an ActionResultTmp.
    """
    # The looked-up value is discarded; only the lookup itself is performed.
    request["agent"]
    step_models = await app.DATABASE.get_task_step_models(task_id, organization_id=current_org.organization_id)
    return [
        ActionResultTmp.model_validate(raw_result)
        for step_model in step_models
        if step_model.output and "action_results" in step_model.output
        for raw_result in step_model.output["action_results"]
    ]
@base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse)
async def execute_workflow(
    background_tasks: BackgroundTasks,
    request: Request,
    workflow_id: str,
    workflow_request: WorkflowRequestBody,
    current_org: Organization = Depends(org_auth_service.get_current_org),
    x_api_key: Annotated[str | None, Header()] = None,
    x_max_steps_override: Annotated[int | None, Header()] = None,
) -> RunWorkflowResponse:
    """
    Set up a run of the given workflow and hand it to the async executor.

    The workflow run is created first (tagged with the current request id),
    then passed to the executor together with the optional max-steps override
    and API key.

    Returns:
        RunWorkflowResponse: carries the workflow id and the new workflow run id.
    """
    LOG.info(f"Running workflow {workflow_id}", workflow_id=workflow_id)

    workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
        request_id=skyvern_context.ensure_context().request_id,
        workflow_request=workflow_request,
        workflow_id=workflow_id,
        organization_id=current_org.organization_id,
        max_steps_override=x_max_steps_override,
    )
    run_id = workflow_run.workflow_run_id

    if x_max_steps_override:
        LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)

    await app.ASYNC_EXECUTOR.execute_workflow(
        background_tasks=background_tasks,
        organization=current_org,
        workflow_id=workflow_id,
        workflow_run_id=run_id,
        max_steps_override=x_max_steps_override,
        api_key=x_api_key,
    )
    return RunWorkflowResponse(workflow_id=workflow_id, workflow_run_id=run_id)
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}", response_model=WorkflowRunStatusResponse)
async def get_workflow_run(
    request: Request,
    workflow_id: str,
    workflow_run_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowRunStatusResponse:
    """
    Return the status response of a workflow run, as built by the workflow service.
    """
    # The looked-up value is discarded; only the lookup itself is performed.
    request["agent"]
    status_response = await app.WORKFLOW_SERVICE.build_workflow_run_status_response(
        workflow_id=workflow_id,
        workflow_run_id=workflow_run_id,
        organization_id=current_org.organization_id,
    )
    return status_response

View File

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class ProxyLocation(StrEnum):
US_CA = "US-CA"
US_NY = "US-NY"
US_TX = "US-TX"
US_FL = "US-FL"
US_WA = "US-WA"
RESIDENTIAL = "RESIDENTIAL"
NONE = "NONE"
class TaskRequest(BaseModel):
    """Request body describing a task: where to start, what to do and what to extract."""

    # Required; must be a non-empty string.
    url: str = Field(
        ...,
        min_length=1,
        description="Starting URL for the task.",
        examples=["https://www.geico.com"],
    )
    # TODO: use HttpUrl instead of str
    webhook_callback_url: str | None = Field(
        default=None,
        description="The URL to call when the task is completed.",
        examples=["https://my-webhook.com"],
    )
    navigation_goal: str | None = Field(
        default=None,
        description="The user's goal for the task.",
        examples=["Get a quote for car insurance"],
    )
    data_extraction_goal: str | None = Field(
        default=None,
        description="The user's goal for data extraction.",
        examples=["Extract the quote price"],
    )
    navigation_payload: dict[str, Any] | list | str | None = Field(
        default=None,
        description="The user's details needed to achieve the task.",
        examples=[{"name": "John Doe", "email": "john@doe.com"}],
    )
    proxy_location: ProxyLocation | None = Field(
        default=None,
        description="The location of the proxy to use for the task.",
        examples=["US-WA", "US-CA", "US-FL", "US-NY", "US-TX"],
    )
    extracted_information_schema: dict[str, Any] | list | str | None = Field(
        default=None,
        description="The requested schema of the extracted information.",
    )
class TaskStatus(StrEnum):
created = "created"
running = "running"
failed = "failed"
terminated = "terminated"
completed = "completed"
def is_final(self) -> bool:
return self in {TaskStatus.failed, TaskStatus.terminated, TaskStatus.completed}
def can_update_to(self, new_status: TaskStatus) -> bool:
allowed_transitions: dict[TaskStatus, set[TaskStatus]] = {
TaskStatus.created: {TaskStatus.running},
TaskStatus.running: {TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated},
TaskStatus.failed: set(),
TaskStatus.completed: set(),
}
return new_status in allowed_transitions[self]
def requires_extracted_info(self) -> bool:
status_requires_extracted_information = {TaskStatus.completed}
return self in status_requires_extracted_information
def cant_have_extracted_info(self) -> bool:
status_cant_have_extracted_information = {
TaskStatus.created,
TaskStatus.running,
TaskStatus.failed,
TaskStatus.terminated,
}
return self in status_cant_have_extracted_information
def requires_failure_reason(self) -> bool:
status_requires_failure_reason = {TaskStatus.failed, TaskStatus.terminated}
return self in status_requires_failure_reason
class Task(TaskRequest):
    """A stored task: the original request fields plus identity, status and results."""

    created_at: datetime = Field(
        ...,
        description="The creation datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the task.",
        examples=["2023-01-01T00:00:00Z"],
    )
    task_id: str = Field(
        ...,
        description="The ID of the task.",
        examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
    )
    status: TaskStatus = Field(..., description="The status of the task.", examples=["created"])
    extracted_information: dict[str, Any] | list | str | None = Field(
        None,
        description="The extracted information from the task.",
    )
    failure_reason: str | None = Field(
        None,
        description="The reason for the task failure.",
    )
    organization_id: str | None = None
    workflow_run_id: str | None = None
    order: int | None = None
    retry: int | None = None

    def validate_update(
        self,
        status: TaskStatus,
        extracted_information: dict[str, Any] | list | str | None,
        failure_reason: str | None = None,
    ) -> None:
        """
        Check that updating this task to the given state is allowed.

        Args:
            status: The status the task is about to move to.
            extracted_information: Extracted information to store with the update, if any.
            failure_reason: Failure reason to store with the update, if any.

        Raises:
            ValueError: When the status transition is not allowed, when the new
                status requires a failure reason or extracted information that
                is missing, when extracted information is given for a status
                that cannot carry it, or when the update would overwrite
                already stored extracted information or failure reason.
        """
        old_status = self.status

        # The first three messages previously lacked their closing parenthesis.
        if not old_status.can_update_to(status):
            raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id})")

        if status.requires_failure_reason() and failure_reason is None:
            raise ValueError(f"status_requires_failure_reason({status},{self.task_id})")

        # Extracted information is only mandatory when the task asked for extraction.
        if status.requires_extracted_info() and self.data_extraction_goal and extracted_information is None:
            raise ValueError(f"status_requires_extracted_information({status},{self.task_id})")

        if status.cant_have_extracted_info() and extracted_information is not None:
            raise ValueError(f"status_cant_have_extracted_information({self.task_id})")

        if self.extracted_information is not None and extracted_information is not None:
            raise ValueError(f"cant_override_extracted_information({self.task_id})")

        if self.failure_reason is not None and failure_reason is not None:
            raise ValueError(f"cant_override_failure_reason({self.task_id})")

    def to_task_response(
        self, screenshot_url: str | None = None, recording_url: str | None = None, failure_reason: str | None = None
    ) -> TaskResponse:
        """
        Build the TaskResponse for this task.

        Args:
            screenshot_url: Optional screenshot link to include.
            recording_url: Optional recording link to include.
            failure_reason: Optional failure reason; when empty or None the
                task's own ``failure_reason`` is used instead.
        """
        return TaskResponse(
            request=self,
            task_id=self.task_id,
            status=self.status,
            created_at=self.created_at,
            modified_at=self.modified_at,
            extracted_information=self.extracted_information,
            failure_reason=failure_reason or self.failure_reason,
            screenshot_url=screenshot_url,
            recording_url=recording_url,
        )
class TaskResponse(BaseModel):
    """Response model for a task: the original request plus status, results and links."""

    # The request the task was created from.
    request: TaskRequest

    # Identity, status and timestamps.
    task_id: str
    status: TaskStatus
    created_at: datetime
    modified_at: datetime

    # Optional results and links; all default to None.
    extracted_information: list | dict[str, Any] | str | None = None
    screenshot_url: str | None = None
    recording_url: str | None = None
    failure_reason: str | None = None
class CreateTaskResponse(BaseModel):
    """Response model carrying only the id of a task."""

    task_id: str

View File

View File

@@ -0,0 +1,76 @@
import time
from typing import Annotated
from asyncache import cached
from cachetools import TTLCache
from fastapi import Header, HTTPException, status
from jose import jwt
from jose.exceptions import JWTError
from pydantic import ValidationError
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.models import Organization, OrganizationAuthTokenType, TokenPayload
from skyvern.forge.sdk.settings_manager import SettingsManager
# Time-to-live, in seconds, of entries in the authentication cache below.
AUTHENTICATION_TTL = 60 * 60 # one hour
# Maximum number of entries kept in the authentication cache.
CACHE_SIZE = 128
# The only algorithm accepted when decoding API-key JWTs.
ALGORITHM = "HS256"
async def get_current_org(
    x_api_key: Annotated[str | None, Header()] = None,
) -> Organization:
    """
    Resolve the organization that owns the API key sent in the ``x-api-key`` header.

    Validation of the key is delegated to the cached lookup, so a key is fully
    verified at most once per cache lifetime.

    Returns:
        Organization: The organization the API key belongs to.

    Raises:
        HTTPException: 403 when the header is missing; other errors are raised
            by the cached lookup.
    """
    if not x_api_key:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )
    organization = await _get_current_org_cached(x_api_key, app.DATABASE)

    # The cached lookup's body does not run on a cache hit, so the request
    # context must be tagged here to get an organization_id on every request
    # rather than only on the first one per cache lifetime.
    context = skyvern_context.current()
    if context:
        context.organization_id = organization.organization_id

    return organization
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
    """
    Validate an API key and return its organization; results are cached for one hour.

    The key is decoded as a JWT, checked for expiry, its subject is looked up
    as an organization, and the token itself must be present in the database.

    Raises:
        HTTPException: 403 when the key cannot be decoded or validated, has
            expired, or is not found in the database; 404 when the
            organization does not exist.
    """
    secret_key = SettingsManager.get_settings().SECRET_KEY
    try:
        claims = jwt.decode(x_api_key, secret_key, algorithms=[ALGORITHM])
        api_key_data = TokenPayload(**claims)
    except (JWTError, ValidationError):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )

    if time.time() > api_key_data.exp:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Auth token is expired",
        )

    organization = await db.get_organization(organization_id=api_key_data.sub)
    if not organization:
        raise HTTPException(status_code=404, detail="Organization not found")

    # check if the token exists in the database
    stored_token = await db.validate_org_auth_token(
        organization_id=organization.organization_id,
        token_type=OrganizationAuthTokenType.api,
        token=x_api_key,
    )
    if not stored_token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )

    # set organization_id in skyvern context and log context
    active_context = skyvern_context.current()
    if active_context:
        active_context.organization_id = organization.organization_id
    return organization

View File

@@ -0,0 +1,14 @@
from skyvern.config import Settings
from skyvern.config import settings as base_settings
class SettingsManager:
    """Holder of the process-wide Settings object, initialised with the base settings."""

    __instance: Settings = base_settings

    @staticmethod
    def get_settings() -> Settings:
        """Return the Settings object currently held by the manager."""
        current_settings = SettingsManager.__instance
        return current_settings

    @staticmethod
    def set_settings(settings: Settings) -> None:
        """Replace the held Settings object with ``settings``."""
        SettingsManager.__instance = settings

View File

View File

@@ -0,0 +1,79 @@
from typing import TYPE_CHECKING, Any
import structlog
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
if TYPE_CHECKING:
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunParameter
# Module-level structlog logger used by ContextManager.
LOG = structlog.get_logger()
class ContextManager:
    """
    Keeps the parameters of a workflow run and their resolved values, both keyed by parameter key.

    Workflow parameters are registered at construction time from the workflow
    run parameters. Other parameters are registered later through
    ``register_block_parameters`` / ``register_parameter_value``.
    """

    aws_client: AsyncAWSClient
    parameters: dict[str, PARAMETER_TYPE]
    values: dict[str, Any]

    def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
        """
        Register every workflow parameter together with the value of its run parameter.

        Args:
            workflow_parameter_tuples: Pairs of a workflow parameter and the run
                parameter carrying its value. When a key occurs more than once
                the last pair wins and the collision is logged.
        """
        self.aws_client = AsyncAWSClient()
        self.parameters = {}
        self.values = {}
        for parameter, run_parameter in workflow_parameter_tuples:
            if parameter.key in self.parameters:
                # Report the previously stored VALUE. The log used to print the
                # previous Parameter object under the "previous value" label.
                prev_value = self.values[parameter.key]
                new_value = run_parameter.value
                LOG.error(
                    f"Duplicate parameter key {parameter.key} found while initializing context manager, previous value: {prev_value}, new value: {new_value}. Using new value."
                )
            self.parameters[parameter.key] = parameter
            self.values[parameter.key] = run_parameter.value

    async def register_parameter_value(
        self,
        parameter: PARAMETER_TYPE,
    ) -> None:
        """
        Resolve and store the value of a non-workflow parameter.

        AWS secret parameters are fetched through the AWS client and stored only
        when a value is returned. Any other parameter type is left without a
        value here.

        Raises:
            ValueError: When called with a workflow parameter; those are set in ``__init__``.
        """
        if parameter.parameter_type == ParameterType.WORKFLOW:
            LOG.error(f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}")
            raise ValueError(
                f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
            )
        elif parameter.parameter_type == ParameterType.AWS_SECRET:
            secret_value = await self.aws_client.get_secret(parameter.aws_key)
            if secret_value is not None:
                self.values[parameter.key] = secret_value
        else:
            # ContextParameter values will be set within the blocks
            return None

    async def register_block_parameters(
        self,
        parameters: list[PARAMETER_TYPE],
    ) -> None:
        """
        Register the parameters of a block, skipping keys that are already known.

        Raises:
            ValueError: When an unregistered workflow parameter is encountered;
                those must come in through the workflow run parameters.
        """
        for parameter in parameters:
            if parameter.key in self.parameters:
                LOG.debug(f"Parameter {parameter.key} already registered, skipping")
                continue

            if parameter.parameter_type == ParameterType.WORKFLOW:
                LOG.error(
                    f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
                )
                raise ValueError(
                    f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
                )

            self.parameters[parameter.key] = parameter
            await self.register_parameter_value(parameter)

    def get_parameter(self, key: str) -> Parameter:
        """Return the parameter registered under ``key``; raises KeyError when unknown."""
        return self.parameters[key]

    def get_value(self, key: str) -> Any:
        """Return the value stored under ``key``; raises KeyError when no value is stored."""
        return self.values[key]

    def set_value(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, replacing any previous value."""
        self.values[key] = value

View File

@@ -0,0 +1,221 @@
import abc
from enum import StrEnum
from typing import Annotated, Any, Literal, Union
import structlog
from pydantic import BaseModel, Field
from skyvern.exceptions import (
ContextParameterValueNotFound,
MissingBrowserStatePage,
TaskNotFound,
UnexpectedTaskStatus,
)
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
# Module-level structlog logger used by the workflow blocks below.
LOG = structlog.get_logger()
class BlockType(StrEnum):
TASK = "task"
FOR_LOOP = "for_loop"
class Block(BaseModel, abc.ABC):
    """Abstract base of all workflow blocks."""

    block_type: BlockType
    # Optional links to other blocks; both default to None.
    parent_block_id: str | None = None
    next_block_id: str | None = None

    @classmethod
    def get_subclasses(cls) -> tuple[type["Block"], ...]:
        """Return the direct subclasses of this class as a tuple."""
        direct_subclasses = cls.__subclasses__()
        return tuple(direct_subclasses)

    @abc.abstractmethod
    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        """Run the block for the given workflow run; implemented by subclasses."""

    @abc.abstractmethod
    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Return every parameter the block uses; implemented by subclasses."""
class TaskBlock(Block):
    """Workflow block that creates a task and executes it, retrying up to ``max_retries`` times."""

    block_type: Literal[BlockType.TASK] = BlockType.TASK

    url: str | None = None
    navigation_goal: str | None = None
    data_extraction_goal: str | None = None
    data_schema: dict[str, Any] | None = None
    max_retries: int = 0
    parameters: list[PARAMETER_TYPE] = []

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Return the parameters declared on this block."""
        return self.parameters

    @staticmethod
    async def get_task_order(workflow_run_id: str, current_retry: int) -> tuple[int, int]:
        """
        Returns the order and retry for the next task in the workflow run as a tuple.
        """
        previous_task = await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)

        # No task yet in this run: start at order 0, retry 0.
        if previous_task is None:
            return 0, 0

        # A first attempt (retry 0) takes the slot after the previous task.
        order = previous_task.order or 0
        if current_retry == 0:
            return order + 1, 0

        # A retry keeps the previous task's order and bumps its retry counter.
        # The previous retry is expected to be current_retry - 1; a mismatch is
        # logged and the previous retry + 1 is used regardless.
        retry = previous_task.retry or 0
        if retry + 1 != current_retry:
            LOG.error(
                f"Last task for workflow run is retry number {previous_task.retry}, "
                f"but current retry is {current_retry}. Could be race condition. Using last task retry + 1",
                workflow_run_id=workflow_run_id,
                last_task_id=previous_task.task_id,
                last_task_retry=previous_task.retry,
                current_retry=current_retry,
            )
        return order, retry + 1

    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        """
        Create a task (and its first step) from this block and execute it.

        The attempt is repeated while the task ends in a final status other
        than ``completed`` and the retry budget (``max_retries``) is not used up.

        Raises:
            Exception: When the workflow's organization cannot be found.
            MissingBrowserStatePage: When the browser state has no page.
            TaskNotFound: When the task cannot be re-read after execution.
            UnexpectedTaskStatus: When the task is not in a final status after execution.
        """
        workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id=workflow_run_id)
        workflow = await app.WORKFLOW_SERVICE.get_workflow(workflow_id=workflow_run.workflow_id)

        current_retry = 0
        # Starts as True so that the loop body runs at least once.
        keep_trying = True
        # TODO (kerem) we should always retry on terminated. We should make a distinction between retriable and
        # non-retryable terminations
        while keep_trying:
            task_order, task_retry = await self.get_task_order(workflow_run_id, current_retry)
            task, step = await app.agent.create_task_and_step_from_block(
                task_block=self,
                workflow=workflow,
                workflow_run=workflow_run,
                context_manager=context_manager,
                task_order=task_order,
                task_retry=task_retry,
            )

            organization = await app.DATABASE.get_organization(organization_id=workflow.organization_id)
            if not organization:
                raise Exception(f"Organization is missing organization_id={workflow.organization_id}")

            browser_state = await app.BROWSER_MANAGER.get_or_create_for_workflow_run(
                workflow_run=workflow_run, url=self.url
            )
            if not browser_state.page:
                LOG.error("BrowserState has no page", workflow_run_id=workflow_run.workflow_run_id)
                raise MissingBrowserStatePage(workflow_run_id=workflow_run.workflow_run_id)

            LOG.info(
                "Navigating to page",
                url=self.url,
                workflow_run_id=workflow_run_id,
                task_id=task.task_id,
                workflow_id=workflow.workflow_id,
                organization_id=workflow.organization_id,
                step_id=step.step_id,
            )
            if self.url:
                await browser_state.page.goto(self.url)

            await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)

            # Check task status
            updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
            if not updated_task:
                raise TaskNotFound(task.task_id)
            if not updated_task.status.is_final():
                raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)

            if updated_task.status == TaskStatus.completed:
                keep_trying = False
                continue

            # Any other final status consumes one retry.
            current_retry += 1
            keep_trying = current_retry <= self.max_retries
            retry_message = f", retrying task {current_retry}/{self.max_retries}" if keep_trying else ""
            LOG.warning(
                f"Task failed with status {updated_task.status}{retry_message}",
                task_id=updated_task.task_id,
                status=updated_task.status,
                workflow_run_id=workflow_run_id,
                workflow_id=workflow.workflow_id,
                organization_id=workflow.organization_id,
                current_retry=current_retry,
                max_retries=self.max_retries,
            )
class ForLoopBlock(Block):
    """Workflow block that executes its inner block once per value of the ``loop_over`` parameter."""

    block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP

    # TODO (kerem): Add support for ContextParameter
    loop_over: PARAMETER_TYPE
    loop_block: "BlockTypeVar"

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Return the inner block's parameters followed by the ``loop_over`` parameter."""
        return [*self.loop_block.get_all_parameters(), self.loop_over]

    def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any) -> list[ContextParameter]:
        """
        Fill the inner block's context parameters from one loop item.

        Each ContextParameter of the inner block gets ``loop_data[key]`` assigned
        to its ``value`` (the parameter objects are modified in place).

        Raises:
            ValueError: When ``loop_data`` is not a dict.
            ContextParameterValueNotFound: When a context parameter key is missing from ``loop_data``.
        """
        if not isinstance(loop_data, dict):
            # TODO (kerem): Should we add support for other types?
            raise ValueError("loop_data should be a dictionary")

        filled_parameters: list[ContextParameter] = []
        for candidate in self.loop_block.get_all_parameters():
            if not isinstance(candidate, ContextParameter):
                continue
            if candidate.key not in loop_data:
                raise ContextParameterValueNotFound(
                    parameter_key=candidate.key,
                    existing_keys=list(loop_data.keys()),
                    workflow_run_id=workflow_run_id,
                )
            candidate.value = loop_data[candidate.key]
            filled_parameters.append(candidate)
        return filled_parameters

    def get_loop_over_parameter_values(self, context_manager: ContextManager) -> list[Any]:
        """
        Return the values to iterate over; a non-list value is wrapped in a one-element list.

        Raises:
            NotImplementedError: When ``loop_over`` is not a WorkflowParameter.
        """
        if not isinstance(self.loop_over, WorkflowParameter):
            # TODO (kerem): Implement this for context parameters
            raise NotImplementedError

        stored_value = context_manager.get_value(self.loop_over.key)
        # TODO (kerem): Should we raise an error here?
        return stored_value if isinstance(stored_value, list) else [stored_value]

    async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
        """
        Execute the inner block once for each loop value.

        Before every iteration the context parameters derived from the current
        value are written into the context manager.
        """
        loop_over_values = self.get_loop_over_parameter_values(context_manager)
        value_count = len(loop_over_values)
        LOG.info(
            f"Number of loop_over values: {value_count}",
            block_type=self.block_type,
            workflow_run_id=workflow_run_id,
            num_loop_over_values=value_count,
        )
        for current_value in loop_over_values:
            for filled in self.get_loop_block_context_parameters(workflow_run_id, current_value):
                context_manager.set_value(filled.key, filled.value)
            await self.loop_block.execute(workflow_run_id=workflow_run_id, context_manager=context_manager)

        return None
# Union of the concrete block models. BlockTypeVar annotates that union with a pydantic
# discriminator on the `block_type` field, which each block pins to a Literal value.
BlockSubclasses = Union[ForLoopBlock, TaskBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -0,0 +1,84 @@
import abc
import json
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field
# Kinds of workflow parameter. Each Parameter subclass pins its `parameter_type` field to
# exactly one of these members, which is what the PARAMETER_TYPE discriminator keys on.
class ParameterType(StrEnum):
WORKFLOW = "workflow"
CONTEXT = "context"
AWS_SECRET = "aws_secret"
# Abstract base model shared by all parameter kinds: a type tag, a lookup key and an
# optional human-readable description.
class Parameter(BaseModel, abc.ABC):
# TODO (kerem): Should we also have organization_id here?
parameter_type: ParameterType
key: str
description: str | None = None
# Returns the direct subclasses of the class it is called on, as a tuple
# (`__subclasses__()` does not include subclasses of subclasses).
@classmethod
def get_subclasses(cls) -> tuple[type["Parameter"], ...]:
return tuple(cls.__subclasses__())
# Parameter of kind AWS_SECRET, attached to a workflow and carrying an `aws_key`.
# NOTE(review): `aws_key` presumably names the secret to look up rather than holding the
# secret value itself - confirm against the code that resolves it.
class AWSSecretParameter(Parameter):
parameter_type: Literal[ParameterType.AWS_SECRET] = ParameterType.AWS_SECRET
aws_secret_parameter_id: str
workflow_id: str
aws_key: str
created_at: datetime
modified_at: datetime
# None unless a deletion timestamp has been recorded.
deleted_at: datetime | None = None
class WorkflowParameterType(StrEnum):
STRING = "string"
INTEGER = "integer"
FLOAT = "float"
BOOLEAN = "boolean"
JSON = "json"
def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None:
if value is None:
return None
if self == WorkflowParameterType.STRING:
return value
elif self == WorkflowParameterType.INTEGER:
return int(value)
elif self == WorkflowParameterType.FLOAT:
return float(value)
elif self == WorkflowParameterType.BOOLEAN:
return value.lower() in ["true", "1"]
elif self == WorkflowParameterType.JSON:
return json.loads(value)
# Parameter of kind WORKFLOW: declared on a workflow with a value type and an optional default.
class WorkflowParameter(Parameter):
parameter_type: Literal[ParameterType.WORKFLOW] = ParameterType.WORKFLOW
workflow_parameter_id: str
workflow_parameter_type: WorkflowParameterType
workflow_id: str
# the type of default_value will be determined by the workflow_parameter_type
default_value: str | int | float | bool | dict | list | None = None
created_at: datetime
modified_at: datetime
# None unless a deletion timestamp has been recorded.
deleted_at: datetime | None = None
# Parameter of kind CONTEXT: tied to a source WorkflowParameter, with a `value` that starts
# as None and is assigned later rather than supplied up front.
class ContextParameter(Parameter):
parameter_type: Literal[ParameterType.CONTEXT] = ParameterType.CONTEXT
source: WorkflowParameter
# value will be populated by the context manager
value: str | int | float | bool | dict | list | None = None
# Union of the concrete parameter models. PARAMETER_TYPE annotates that union with a pydantic
# discriminator on `parameter_type`, which each subclass pins to a Literal value.
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter]
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]

View File

@@ -0,0 +1,74 @@
from datetime import datetime
from enum import StrEnum
from typing import Any, List
from pydantic import BaseModel
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
# Request body for starting a workflow run. Every field is optional.
class WorkflowRequestBody(BaseModel):
# Parameter values for the run, keyed by workflow parameter key.
data: dict[str, Any] | None = None
proxy_location: ProxyLocation | None = None
# Where the run status is POSTed once the run finishes; no webhook is sent when unset.
webhook_callback_url: str | None = None
# Identifiers of a workflow and of one run of it.
# NOTE(review): presumably the API response returned when a run is started - confirm at the route.
class RunWorkflowResponse(BaseModel):
workflow_id: str
workflow_run_id: str
# The body of a workflow: an ordered list of blocks, each validated as one of the
# block models behind BlockTypeVar.
class WorkflowDefinition(BaseModel):
blocks: List[BlockTypeVar]
# A stored workflow: owning organization, title/description, its block definition and timestamps.
class Workflow(BaseModel):
workflow_id: str
organization_id: str
title: str
description: str | None = None
workflow_definition: WorkflowDefinition
created_at: datetime
modified_at: datetime
# None unless a deletion timestamp has been recorded.
deleted_at: datetime | None = None
# Status values a workflow run can hold. Member names are lowercase and equal to their values.
class WorkflowRunStatus(StrEnum):
created = "created"
running = "running"
failed = "failed"
terminated = "terminated"
completed = "completed"
# One execution of a workflow: its status plus the per-run options it was created with.
class WorkflowRun(BaseModel):
workflow_run_id: str
workflow_id: str
status: WorkflowRunStatus
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
created_at: datetime
modified_at: datetime
# The value bound to one workflow parameter for one specific workflow run.
class WorkflowRunParameter(BaseModel):
workflow_run_id: str
workflow_parameter_id: str
# Required: unlike WorkflowParameter.default_value, None is not an accepted value here.
value: bool | int | float | str | dict | list
created_at: datetime
# Status report for a workflow run: the run's own fields plus its parameter values and
# links to screenshot/recording artifacts.
class WorkflowRunStatusResponse(BaseModel):
workflow_id: str
workflow_run_id: str
status: WorkflowRunStatus
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
created_at: datetime
modified_at: datetime
# Parameter values used by the run, keyed by parameter key.
parameters: dict[str, Any]
screenshot_urls: list[str] | None = None
recording_url: str | None = None

View File

@@ -0,0 +1,509 @@
import asyncio
import json
import time
from datetime import datetime
import requests
import structlog
from skyvern.exceptions import (
FailedToSendWebhook,
MissingValueForParameter,
WorkflowNotFound,
WorkflowOrganizationMismatch,
WorkflowRunNotFound,
)
from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowDefinition,
WorkflowRequestBody,
WorkflowRun,
WorkflowRunParameter,
WorkflowRunStatus,
WorkflowRunStatusResponse,
)
from skyvern.webeye.browser_factory import BrowserState
LOG = structlog.get_logger()
# Service layer for workflows: sets up and executes workflow runs, wraps the workflow-related
# app.DATABASE calls, and builds/sends the run status response.
class WorkflowService:
async def setup_workflow_run(
    self,
    request_id: str | None,
    workflow_request: WorkflowRequestBody,
    workflow_id: str,
    organization_id: str,
    max_steps_override: int | None = None,
) -> WorkflowRun:
    """
    Create a workflow run and its parameters. Validate the workflow and the organization. If there are missing
    parameters with no default value, mark the workflow run as failed.

    :param request_id: The request id for the workflow run.
    :param workflow_request: The request body for the workflow run, containing the parameters and the config.
    :param workflow_id: The workflow id to run.
    :param organization_id: The organization id for the workflow.
    :param max_steps_override: The max steps override for the workflow run, if any.
    :return: The created workflow run.
    :raises WorkflowNotFound: If the workflow does not exist.
    :raises WorkflowOrganizationMismatch: If the workflow belongs to another organization.
    :raises MissingValueForParameter: If a workflow parameter has neither a request value nor a default.
    """
    LOG.info(f"Setting up workflow run for workflow {workflow_id}", workflow_id=workflow_id)
    # Validate the workflow and the organization
    workflow = await self.get_workflow(workflow_id=workflow_id)
    if workflow is None:
        LOG.error(f"Workflow {workflow_id} not found")
        raise WorkflowNotFound(workflow_id=workflow_id)
    if workflow.organization_id != organization_id:
        LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
        raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)

    # Create the workflow run and set skyvern context
    workflow_run = await self.create_workflow_run(workflow_request=workflow_request, workflow_id=workflow_id)
    LOG.info(
        f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}",
        request_id=request_id,
        workflow_run_id=workflow_run.workflow_run_id,
        workflow_id=workflow.workflow_id,
        proxy_location=workflow_request.proxy_location,
    )
    skyvern_context.set(
        SkyvernContext(
            organization_id=organization_id,
            request_id=request_id,
            workflow_id=workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            max_steps_override=max_steps_override,
        )
    )

    # Set workflow run status to running, create workflow run parameters
    await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)

    # Create all the workflow run parameters, AWSSecretParameter won't have workflow run parameters created.
    all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id)
    # A missing or empty request body simply supplies no values.
    request_data = workflow_request.data or {}
    for workflow_parameter in all_workflow_parameters:
        # The request value wins over the default; having neither fails the run.
        if workflow_parameter.key in request_data:
            parameter_value = request_data[workflow_parameter.key]
        elif workflow_parameter.default_value is not None:
            parameter_value = workflow_parameter.default_value
        else:
            await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
            raise MissingValueForParameter(
                parameter_key=workflow_parameter.key,
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
            )
        await self.create_workflow_run_parameter(
            workflow_run_id=workflow_run.workflow_run_id,
            workflow_parameter_id=workflow_parameter.workflow_parameter_id,
            value=parameter_value,
        )

    LOG.info(
        f"Created workflow run parameters for workflow run {workflow_run.workflow_run_id}",
        workflow_run_id=workflow_run.workflow_run_id,
    )
    return workflow_run
async def execute_workflow(
    self,
    workflow_run_id: str,
    api_key: str,
) -> WorkflowRun:
    """Run every root block of a workflow run, then record the outcome and send the response.

    The run's final status is taken from the status of the last task created for the run.
    When the run produced no task, the status is left untouched and no response is sent.

    :param workflow_run_id: The workflow run to execute.
    :param api_key: Forwarded to ``send_workflow_response``.
    :return: The workflow run object loaded at the start (it is not re-read after the status update).
    """
    workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
    workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)
    await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)

    # Get all <workflow parameter, workflow run parameter> tuples
    parameter_pairs = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
    # todo(kerem): do this in a better way (a shared context manager? (not really shared because we use batch job))
    context_manager = ContextManager(parameter_pairs)

    # Execute workflow blocks
    for block_idx, block in enumerate(workflow.workflow_definition.blocks):
        await context_manager.register_block_parameters(block.get_all_parameters())
        LOG.info(
            f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run.workflow_run_id}",
            block_type=block.block_type,
            workflow_run_id=workflow_run.workflow_run_id,
            block_idx=block_idx,
        )
        await block.execute(workflow_run_id=workflow_run.workflow_run_id, context_manager=context_manager)

    # Get last task for workflow run
    task = await self.get_last_task_for_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
    if not task:
        LOG.warning(
            f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook",
            workflow_run_id=workflow_run.workflow_run_id,
        )
        return workflow_run

    # Update workflow status: the first matching terminal status picks the marker to call.
    terminal_markers = (
        ("completed", self.mark_workflow_run_as_completed),
        ("failed", self.mark_workflow_run_as_failed),
        ("terminated", self.mark_workflow_run_as_terminated),
    )
    for terminal_status, mark_run in terminal_markers:
        if task.status == terminal_status:
            await mark_run(workflow_run_id=workflow_run.workflow_run_id)
            break
    else:
        LOG.warning(
            f"Task {task.task_id} has an incomplete status {task.status}, not updating workflow run status",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            task_id=task.task_id,
            status=task.status,
            workflow_run_status=workflow_run.status,
        )

    await self.send_workflow_response(
        workflow=workflow,
        workflow_run=workflow_run,
        api_key=api_key,
        last_task=task,
    )
    return workflow_run
async def create_workflow(
    self,
    organization_id: str,
    title: str,
    workflow_definition: WorkflowDefinition,
    description: str | None = None,
) -> Workflow:
    """Store a new workflow and return what the database layer returns.

    The definition is passed as its ``model_dump()`` dict; a falsy definition is passed as None.
    """
    definition_payload = None
    if workflow_definition:
        definition_payload = workflow_definition.model_dump()
    return await app.DATABASE.create_workflow(
        organization_id=organization_id,
        title=title,
        description=description,
        workflow_definition=definition_payload,
    )
async def get_workflow(self, workflow_id: str) -> Workflow:
    """Load a workflow by id.

    :raises WorkflowNotFound: If the database returns nothing for ``workflow_id``.
    """
    found = await app.DATABASE.get_workflow(workflow_id=workflow_id)
    if found:
        return found
    raise WorkflowNotFound(workflow_id)
async def update_workflow(
    self,
    workflow_id: str,
    title: str | None = None,
    description: str | None = None,
    workflow_definition: WorkflowDefinition | None = None,
) -> Workflow | None:
    """Forward a workflow update to the database layer and return its result.

    The definition is passed as its ``model_dump()`` dict; a falsy definition is passed as None.
    """
    definition_payload = None
    if workflow_definition:
        definition_payload = workflow_definition.model_dump()
    return await app.DATABASE.update_workflow(
        workflow_id=workflow_id,
        title=title,
        description=description,
        workflow_definition=definition_payload,
    )
async def create_workflow_run(self, workflow_request: WorkflowRequestBody, workflow_id: str) -> WorkflowRun:
    """Create a run of ``workflow_id`` carrying the request's proxy location and webhook URL."""
    created_run = await app.DATABASE.create_workflow_run(
        workflow_id=workflow_id,
        proxy_location=workflow_request.proxy_location,
        webhook_callback_url=workflow_request.webhook_callback_url,
    )
    return created_run
async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> None:
    """Log the transition, then set the run's status to ``completed`` in the database."""
    LOG.info(
        f"Marking workflow run {workflow_run_id} as completed",
        workflow_run_id=workflow_run_id,
        status="completed",
    )
    target_status = WorkflowRunStatus.completed
    await app.DATABASE.update_workflow_run(workflow_run_id=workflow_run_id, status=target_status)
async def mark_workflow_run_as_failed(self, workflow_run_id: str) -> None:
    """Log the transition, then set the run's status to ``failed`` in the database."""
    LOG.info(
        f"Marking workflow run {workflow_run_id} as failed",
        workflow_run_id=workflow_run_id,
        status="failed",
    )
    target_status = WorkflowRunStatus.failed
    await app.DATABASE.update_workflow_run(workflow_run_id=workflow_run_id, status=target_status)
async def mark_workflow_run_as_running(self, workflow_run_id: str) -> None:
    """Log the transition, then set the run's status to ``running`` in the database."""
    LOG.info(
        f"Marking workflow run {workflow_run_id} as running",
        workflow_run_id=workflow_run_id,
        status="running",
    )
    target_status = WorkflowRunStatus.running
    await app.DATABASE.update_workflow_run(workflow_run_id=workflow_run_id, status=target_status)
async def mark_workflow_run_as_terminated(self, workflow_run_id: str) -> None:
    """Log the transition, then set the run's status to ``terminated`` in the database."""
    LOG.info(
        f"Marking workflow run {workflow_run_id} as terminated",
        workflow_run_id=workflow_run_id,
        status="terminated",
    )
    target_status = WorkflowRunStatus.terminated
    await app.DATABASE.update_workflow_run(workflow_run_id=workflow_run_id, status=target_status)
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
    """Return the runs the database layer reports for ``workflow_id``."""
    runs = await app.DATABASE.get_workflow_runs(workflow_id=workflow_id)
    return runs
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
    """Load a workflow run by id.

    :raises WorkflowRunNotFound: If the database returns nothing for ``workflow_run_id``.
    """
    found_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id)
    if found_run:
        return found_run
    raise WorkflowRunNotFound(workflow_run_id)
async def create_workflow_parameter(
    self,
    workflow_id: str,
    workflow_parameter_type: WorkflowParameterType,
    key: str,
    default_value: bool | int | float | str | dict | list | None = None,
    description: str | None = None,
) -> WorkflowParameter:
    """Create a workflow parameter on ``workflow_id``; all arguments are forwarded unchanged."""
    created_parameter = await app.DATABASE.create_workflow_parameter(
        workflow_id=workflow_id,
        workflow_parameter_type=workflow_parameter_type,
        key=key,
        description=description,
        default_value=default_value,
    )
    return created_parameter
async def create_aws_secret_parameter(
    self, workflow_id: str, aws_key: str, key: str, description: str | None = None
) -> AWSSecretParameter:
    """Create an AWS secret parameter on ``workflow_id``; all arguments are forwarded unchanged."""
    created_parameter = await app.DATABASE.create_aws_secret_parameter(
        workflow_id=workflow_id,
        aws_key=aws_key,
        key=key,
        description=description,
    )
    return created_parameter
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
    """Return the workflow parameters the database layer reports for ``workflow_id``."""
    workflow_parameters = await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id)
    return workflow_parameters
async def create_workflow_run_parameter(
    self,
    workflow_run_id: str,
    workflow_parameter_id: str,
    value: bool | int | float | str | dict | list,
) -> WorkflowRunParameter:
    """Store the value of one workflow parameter for one run.

    dict and list values are serialised with ``json.dumps`` first; every other value is
    passed through as given.
    """
    stored_value = value
    if isinstance(value, (dict, list)):
        stored_value = json.dumps(value)
    return await app.DATABASE.create_workflow_run_parameter(
        workflow_run_id=workflow_run_id,
        workflow_parameter_id=workflow_parameter_id,
        value=stored_value,
    )
async def get_workflow_run_parameter_tuples(
    self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
    """Return (workflow parameter, workflow run parameter) pairs for ``workflow_run_id``."""
    parameter_pairs = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
    return parameter_pairs
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
    """Return the last task of the run as reported by the database layer, or None."""
    last_task = await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
    return last_task
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
    """Return the tasks the database layer reports for ``workflow_run_id``."""
    run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
    return run_tasks
async def build_workflow_run_status_response(
    self, workflow_id: str, workflow_run_id: str, organization_id: str
) -> WorkflowRunStatusResponse:
    """Assemble the status report for a workflow run.

    The report carries the run's fields, its parameter values keyed by parameter key, up to
    three screenshot links (taken from the run's tasks, walking the task list backwards) and
    the recording link when a recording artifact exists.

    :raises WorkflowNotFound: If the workflow does not exist.
    :raises WorkflowOrganizationMismatch: If the workflow belongs to another organization.
    :raises WorkflowRunNotFound: If the workflow run does not exist.
    """
    workflow = await self.get_workflow(workflow_id=workflow_id)
    if workflow is None:
        LOG.error(f"Workflow {workflow_id} not found")
        raise WorkflowNotFound(workflow_id=workflow_id)
    if workflow.organization_id != organization_id:
        LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
        raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)

    workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
    run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)

    # get the last screenshot for the last 3 tasks of the workflow run
    screenshot_urls = []
    for task in reversed(run_tasks):
        latest_screenshot = await app.DATABASE.get_latest_artifact(
            task_id=task.task_id,
            artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
            organization_id=organization_id,
        )
        if not latest_screenshot:
            continue
        share_link = await app.ARTIFACT_MANAGER.get_share_link(latest_screenshot)
        if share_link:
            screenshot_urls.append(share_link)
        if len(screenshot_urls) >= 3:
            break

    recording_artifact = await app.DATABASE.get_artifact_for_workflow_run(
        workflow_run_id=workflow_run_id, artifact_type=ArtifactType.RECORDING, organization_id=organization_id
    )
    recording_url = None
    if recording_artifact:
        recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)

    parameter_pairs = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
    parameters_with_value = {}
    for workflow_parameter, run_parameter in parameter_pairs:
        parameters_with_value[workflow_parameter.key] = run_parameter.value

    return WorkflowRunStatusResponse(
        workflow_id=workflow_id,
        workflow_run_id=workflow_run_id,
        status=workflow_run.status,
        proxy_location=workflow_run.proxy_location,
        webhook_callback_url=workflow_run.webhook_callback_url,
        created_at=workflow_run.created_at,
        modified_at=workflow_run.modified_at,
        parameters=parameters_with_value,
        screenshot_urls=screenshot_urls,
        recording_url=recording_url,
    )
async def send_workflow_response(
    self,
    workflow: Workflow,
    workflow_run: WorkflowRun,
    last_task: Task,
    api_key: str | None = None,
    close_browser_on_completion: bool = True,
) -> None:
    """Finalize a workflow run and POST its status to the run's webhook.

    Steps, in order: clean up the run's browser state and persist the video and HAR data,
    wait for artifact uploads (bounded to 30 seconds for the remaining asyncio tasks), then
    build the status response, sign it and send it.

    Nothing is sent when the run has no webhook callback url or when no api key is given.
    A non-2xx webhook response is logged, not raised.

    :param workflow: The workflow the run belongs to.
    :param workflow_run: The run being reported.
    :param last_task: The run's last task; its latest step receives the HAR artifact.
    :param api_key: Key used to sign the payload.
    :param close_browser_on_completion: Forwarded to the browser manager's cleanup.
    :raises FailedToSendWebhook: If sending the request itself raises.
    """
    browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
        workflow_run.workflow_run_id, close_browser_on_completion
    )
    if browser_state:
        await self.persist_video_data(browser_state, workflow, workflow_run)
        await self.persist_har_data(browser_state, last_task, workflow, workflow_run)

    # Wait for all tasks to complete before generating the links for the artifacts
    all_workflow_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(
        workflow_run_id=workflow_run.workflow_run_id
    )
    all_workflow_task_ids = [task.task_id for task in all_workflow_tasks]
    await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)

    try:
        # Wait for all tasks to complete. Currently we're using asyncio.create_task() only for uploading artifacts to S3.
        # We're excluding the current task from the list of tasks to wait for to prevent a deadlock.
        st = time.time()
        async with asyncio.timeout(30):
            await asyncio.gather(
                *[aio_task for aio_task in (asyncio.all_tasks() - {asyncio.current_task()}) if not aio_task.done()]
            )
        LOG.info(
            f"Waiting for all S3 uploads to complete took {time.time() - st} seconds",
            duration=time.time() - st,
        )
    except asyncio.TimeoutError:
        LOG.warning(
            "Timed out waiting for all S3 uploads to complete, not all artifacts may be uploaded. Waited 30 seconds.",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
        )

    if not workflow_run.webhook_callback_url:
        LOG.warning(
            "Workflow has no webhook callback url. Not sending workflow response",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
        )
        return

    if not api_key:
        LOG.warning(
            "Request has no api key. Not sending workflow response",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
        )
        return

    workflow_run_status_response = await self.build_workflow_run_status_response(
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        organization_id=workflow.organization_id,
    )

    # send task_response to the webhook callback url
    # TODO: use async requests (httpx)
    # time.time() is the Unix epoch time regardless of the host's timezone. The previous
    # datetime.utcnow().timestamp() interpreted the naive UTC value as local time, which
    # shifted the timestamp by the UTC offset on any host not running in UTC.
    timestamp = str(int(time.time()))
    payload = workflow_run_status_response.model_dump_json()
    signature = generate_skyvern_signature(
        payload=payload,
        api_key=api_key,
    )
    headers = {
        "x-skyvern-timestamp": timestamp,
        "x-skyvern-signature": signature,
        "Content-Type": "application/json",
    }
    LOG.info(
        "Sending webhook run status to webhook callback url",
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        webhook_callback_url=workflow_run.webhook_callback_url,
        payload=payload,
        headers=headers,
    )
    try:
        # NOTE(review): this call is blocking and has no timeout, so an unresponsive
        # webhook endpoint stalls the caller; see the TODO above.
        resp = requests.post(workflow_run.webhook_callback_url, data=payload, headers=headers)
        if resp.ok:
            LOG.info(
                "Webhook sent successfully",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
                resp_code=resp.status_code,
                resp_text=resp.text,
            )
        else:
            # Error responses are often not JSON (HTML error pages, empty bodies). Parsing
            # must not turn a logged webhook failure into FailedToSendWebhook.
            try:
                resp_json = resp.json()
            except ValueError:
                resp_json = None
            LOG.info(
                "Webhook failed",
                workflow_id=workflow.workflow_id,
                workflow_run_id=workflow_run.workflow_run_id,
                resp=resp,
                resp_code=resp.status_code,
                resp_text=resp.text,
                resp_json=resp_json,
            )
    except Exception as e:
        raise FailedToSendWebhook(
            workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id
        ) from e
async def persist_video_data(
    self, browser_state: BrowserState, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
    """Write the run's video data into the browser state's video artifact, if there is any data."""
    # Create recording artifact after closing the browser, so we can get an accurate recording
    video_data = await app.BROWSER_MANAGER.get_video_data(
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        browser_state=browser_state,
    )
    if not video_data:
        return
    await app.ARTIFACT_MANAGER.update_artifact_data(
        artifact_id=browser_state.browser_artifacts.video_artifact_id,
        organization_id=workflow.organization_id,
        data=video_data,
    )
async def persist_har_data(
    self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
    """Store the run's HAR data as an artifact on the latest step of ``last_task``.

    Nothing is stored when there is no HAR data or the task has no step.
    """
    har_data = await app.BROWSER_MANAGER.get_har_data(
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        browser_state=browser_state,
    )
    if not har_data:
        return

    latest_step = await app.DATABASE.get_latest_step(
        task_id=last_task.task_id, organization_id=last_task.organization_id
    )
    if not latest_step:
        return

    await app.ARTIFACT_MANAGER.create_artifact(
        step=latest_step,
        artifact_type=ArtifactType.HAR,
        data=har_data,
    )

View File

View File

View File

@@ -0,0 +1,204 @@
import abc
from enum import StrEnum
from typing import Any, Dict, List
import structlog
from pydantic import BaseModel
from skyvern.forge.sdk.schemas.tasks import Task
LOG = structlog.get_logger()
# Kinds of action the agent can emit. Member names are the upper-case form of their values,
# which is what parse_actions relies on when it looks a type up by name.
class ActionType(StrEnum):
CLICK = "click"
INPUT_TEXT = "input_text"
UPLOAD_FILE = "upload_file"
SELECT_OPTION = "select_option"
CHECKBOX = "checkbox"
WAIT = "wait"
NULL_ACTION = "null_action"
SOLVE_CAPTCHA = "solve_captcha"
TERMINATE = "terminate"
COMPLETE = "complete"
# Note: Remember to update ActionTypeUnion with new actions
# Base model for every action: its type plus an optional description and reasoning text.
class Action(BaseModel):
action_type: ActionType
description: str | None = None
reasoning: str | None = None
# Base for actions that target a page element, identified by an integer element id.
class WebAction(Action, abc.ABC):
element_id: int
# Click on an element; may carry an optional file url.
class ClickAction(WebAction):
action_type: ActionType = ActionType.CLICK
file_url: str | None = None
# Compact repr showing only the element id and file url.
def __repr__(self) -> str:
return f"ClickAction(element_id={self.element_id}, file_url={self.file_url})"
# Type `text` into an element.
class InputTextAction(WebAction):
action_type: ActionType = ActionType.INPUT_TEXT
text: str
# Compact repr showing only the element id and the text.
def __repr__(self) -> str:
return f"InputTextAction(element_id={self.element_id}, text={self.text})"
# Upload the file at `file_url` through an element.
class UploadFileAction(WebAction):
action_type: ActionType = ActionType.UPLOAD_FILE
file_url: str
# NOTE(review): presumably records whether the target element is a real file input - confirm in the handler.
is_upload_file_tag: bool = True
# Compact repr; note the file url is printed under the label `file`.
def __repr__(self) -> str:
return f"UploadFileAction(element_id={self.element_id}, file={self.file_url}, is_upload_file_tag={self.is_upload_file_tag})"
# Placeholder action; parse_actions emits it when a response entry carries no action type.
class NullAction(Action):
action_type: ActionType = ActionType.NULL_ACTION
# Action of type SOLVE_CAPTCHA; carries no fields beyond the base Action ones.
class SolveCaptchaAction(Action):
action_type: ActionType = ActionType.SOLVE_CAPTCHA
# One option of a select element, described by label, value and index.
# Each field accepts None, but none has a default, so all three must be passed explicitly.
class SelectOption(BaseModel):
label: str | None
value: str | None
index: int | None
def __repr__(self) -> str:
return f"SelectOption(label={self.label}, value={self.value}, index={self.index})"
# Choose `option` on a select element.
class SelectOptionAction(WebAction):
action_type: ActionType = ActionType.SELECT_OPTION
option: SelectOption
# Compact repr showing only the element id and the chosen option.
def __repr__(self) -> str:
return f"SelectOptionAction(element_id={self.element_id}, option={self.option})"
###
# This action causes more harm than good.
# It frequently misbehaves or gets stuck in click loops.
# Treating checkbox actions as click actions seems to perform far more reliably.
# Developers who tried this and failed: 2 (Suchintan and Shu 😂)
###
class CheckboxAction(WebAction):
action_type: ActionType = ActionType.CHECKBOX
# Carries the checked state requested for the checkbox.
is_checked: bool
def __repr__(self) -> str:
return f"CheckboxAction(element_id={self.element_id}, is_checked={self.is_checked})"
# Action of type WAIT; carries no fields beyond the base Action ones.
class WaitAction(Action):
action_type: ActionType = ActionType.WAIT
# Action of type TERMINATE; parse_actions logs "Agent decided to terminate" when it creates one.
class TerminateAction(Action):
action_type: ActionType = ActionType.TERMINATE
# Action of type COMPLETE; parse_actions returns it as the only action once the goal is reached.
class CompleteAction(Action):
action_type: ActionType = ActionType.COMPLETE
# parse_actions copies this from the task's data_extraction_goal.
data_extraction_goal: str | None = None
def parse_actions(task: Task, json_response: List[Dict[str, Any]]) -> List[Action]:
    """Convert the LLM's list of raw action dicts into Action models.

    Entries are processed in order. A COMPLETE entry ends parsing immediately: everything
    parsed so far is discarded and a single CompleteAction is returned. An entry with a
    missing or None ``action_type``, with type ``null_action``, or with the bare alias
    ``null`` becomes a NullAction.

    :param task: The task the response belongs to; used for logging and for the
        data extraction goal of a CompleteAction.
    :param json_response: The raw action dicts. Every entry must carry an ``id`` key.
    :return: The parsed actions.
    :raises KeyError: If an entry lacks ``id`` or a field its action type requires, or if
        its ``action_type`` names no ActionType member.
    """
    actions = []
    for action in json_response:
        # NOTE(review): "id" is required even for action types that never use it.
        element_id = action["id"]
        reasoning = action.get("reasoning")
        raw_action_type = action.get("action_type")
        if raw_action_type is None:
            actions.append(NullAction(reasoning=reasoning))
            continue

        # `.upper()` handles the case where the LLM returns a lowercase action type (e.g. "click" instead of "CLICK")
        action_type_name = raw_action_type.upper()
        if action_type_name == "NULL":
            # "null" is not an ActionType member name, so it has to be handled before the
            # enum lookup; the old `action_type == "null"` branch could never be reached.
            actions.append(NullAction(reasoning=reasoning))
            continue
        action_type = ActionType[action_type_name]

        if action_type == ActionType.TERMINATE:
            LOG.warning(
                "Agent decided to terminate",
                task_id=task.task_id,
                llm_response=json_response,
                reasoning=reasoning,
                actions=actions,
            )
            actions.append(TerminateAction(reasoning=reasoning))
        elif action_type == ActionType.CLICK:
            file_url = action.get("file_url")
            actions.append(ClickAction(element_id=element_id, reasoning=reasoning, file_url=file_url))
        elif action_type == ActionType.INPUT_TEXT:
            actions.append(InputTextAction(element_id=element_id, text=action["text"], reasoning=reasoning))
        elif action_type == ActionType.UPLOAD_FILE:
            # TODO: see if the element is a file input element. if it's not, convert this action into a click action
            actions.append(UploadFileAction(element_id=element_id, file_url=action["file_url"], reasoning=reasoning))
        elif action_type == ActionType.SELECT_OPTION:
            actions.append(
                SelectOptionAction(
                    element_id=element_id,
                    option=SelectOption(
                        label=action["option"]["label"],
                        value=action["option"]["value"],
                        index=action["option"]["index"],
                    ),
                    reasoning=reasoning,
                )
            )
        elif action_type == ActionType.CHECKBOX:
            actions.append(CheckboxAction(element_id=element_id, is_checked=action["is_checked"], reasoning=reasoning))
        elif action_type == ActionType.WAIT:
            actions.append(WaitAction(reasoning=reasoning))
        elif action_type == ActionType.COMPLETE:
            if actions:
                LOG.info(
                    "Navigation goal achieved, creating complete action and discarding all other actions except "
                    "complete action",
                    task_id=task.task_id,
                    nav_goal=task.navigation_goal,
                    actions=actions,
                    llm_response=json_response,
                )
            return [CompleteAction(reasoning=reasoning, data_extraction_goal=task.data_extraction_goal)]
        elif action_type == ActionType.NULL_ACTION:
            # Previously compared against the string "null", which no ActionType member
            # equals, so NULL_ACTION fell through to the "unsupported" branch and was dropped.
            actions.append(NullAction(reasoning=reasoning))
        elif action_type == ActionType.SOLVE_CAPTCHA:
            actions.append(SolveCaptchaAction(reasoning=reasoning))
        else:
            # Only reachable for an ActionType member added without a branch above.
            LOG.error(
                "Unsupported action type when parsing actions",
                task_id=task.task_id,
                action_type=action_type,
                raw_action=action,
            )
    return actions
class ScrapeResult(BaseModel):
"""
Scraped response from a webpage, including:
1. JSON representation of what the user is seeing
"""
# Either a single JSON object or a list of JSON objects.
scraped_data: dict[str, Any] | list[dict[str, Any]]
# Union of every concrete action model, so a field typed with it accepts any of them.
# Keep it in sync with ActionType when adding a new action.
# https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83
ActionTypeUnion = (
ClickAction
| InputTextAction
| UploadFileAction
| SelectOptionAction
| CheckboxAction
| WaitAction
| NullAction
| SolveCaptchaAction
| TerminateAction
| CompleteAction
)

View File

@@ -0,0 +1,445 @@
import asyncio
import re
from typing import Awaitable, Callable, List
import structlog
from playwright.async_api import Locator, Page
from skyvern.exceptions import ImaginaryFileUrl, MissingElement, MissingFileUrl, MultipleElementsFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.files import download_file
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.actions import actions
from skyvern.webeye.actions.actions import Action, ActionType, ClickAction, ScrapeResult, UploadFileAction, WebAction
from skyvern.webeye.actions.responses import ActionFailure, ActionResult, ActionSuccess
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ScrapedPage
LOG = structlog.get_logger()
class ActionHandler:
    """Registry of per-action-type handlers plus the dispatcher that invokes them.

    Handlers are stored on the class, so the registry is shared by every user of it.
    """

    _handled_action_types: dict[
        ActionType, Callable[[Action, Page, ScrapedPage, Task, Step], Awaitable[list[ActionResult]]]
    ] = {}

    @classmethod
    def register_action_type(
        cls,
        action_type: ActionType,
        handler: Callable[[Action, Page, ScrapedPage, Task, Step], Awaitable[list[ActionResult]]],
    ) -> None:
        """Register ``handler`` for ``action_type``, replacing any handler registered before."""
        cls._handled_action_types[action_type] = handler

    @staticmethod
    async def handle_action(
        scraped_page: ScrapedPage,
        task: Task,
        step: Step,
        browser_state: BrowserState,
        action: Action,
    ) -> list[ActionResult]:
        """Run the handler registered for the action's type and return its results.

        Exceptions raised by the handler are not propagated: every failure, including an
        action type with no registered handler, is returned as a single ActionFailure.
        """
        LOG.info("Handling action", action=action)
        page = await browser_state.get_or_create_page()
        try:
            registry = ActionHandler._handled_action_types
            if action.action_type not in registry:
                LOG.error("Unsupported action type in handler", action=action, type=type(action))
                return [ActionFailure(Exception(f"Unsupported action type: {type(action)}"))]

            registered_handler = registry[action.action_type]
            return await registered_handler(action, page, scraped_page, task, step)
        except MissingElement as missing:
            # Expected failure mode: logged at info level, without a traceback.
            LOG.info(
                "Known exceptions",
                action=action,
                exception_type=type(missing),
                exception_message=str(missing),
            )
            return [ActionFailure(missing)]
        except MultipleElementsFound as ambiguous:
            LOG.exception(
                "Cannot handle multiple elements with the same xpath in one action.",
                action=action,
                exception=ambiguous,
            )
            return [ActionFailure(ambiguous)]
        except Exception as unexpected:
            LOG.exception("Unhandled exception in action handler", action=action, exception=unexpected)
            return [ActionFailure(unexpected)]
async def handle_solve_captcha_action(
    action: actions.SolveCaptchaAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Log a request to solve the captcha, sleep for 30 seconds, then report success.

    The page is not inspected or modified here; success is returned unconditionally.
    """
    pause_seconds = 30
    LOG.warning("Please solve the captcha on the page, you have 30 seconds", action=action)
    await asyncio.sleep(pause_seconds)
    return [ActionSuccess()]
async def handle_click_action(
    action: actions.ClickAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Validate the action's target in the DOM, pause for 0.3 seconds, then click it.

    The click itself is delegated to ``chain_click`` with the configured browser action timeout.
    """
    target_xpath = await validate_actions_in_dom(action, page, scraped_page)
    await asyncio.sleep(0.3)
    click_timeout = SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS
    click_results = await chain_click(page, action, target_xpath, timeout=click_timeout)
    return click_results
async def handle_input_text_action(
    action: actions.InputTextAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Type ``action.text`` into the validated target element.

    The element is cleared and filled, then ``Tab`` is pressed. If the element's value is empty
    afterwards, the text is typed key by key followed by ``Enter``. The handler returns
    ``ActionSuccess`` whenever no Playwright call raises; the final input value is only logged.

    Every Playwright call, including ``clear``, uses ``BROWSER_ACTION_TIMEOUT_MS`` so a stuck
    element fails after the configured timeout instead of Playwright's default one.
    """
    xpath = await validate_actions_in_dom(action, page, scraped_page)
    timeout = SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS
    locator = page.locator(f"xpath={xpath}")

    await locator.clear(timeout=timeout)
    await locator.fill(action.text, timeout=timeout)
    # This is a hack that gets dropdowns to select the "best" option based on what's typed
    # Fixes situations like tsk_228671423990405776 where the location isn't being autocompleted
    await locator.press("Tab", timeout=timeout)

    input_value = await locator.input_value(timeout=timeout)
    if not input_value:
        LOG.info("Failed to input the text, trying to press sequentially with an enter click", action=action)
        await locator.clear(timeout=timeout)
        await locator.press_sequentially(action.text, timeout=timeout)
        await locator.press("Enter", timeout=timeout)
        input_value = await locator.input_value(timeout=timeout)

    LOG.info("Input value", input_value=input_value, action=action)
    return [ActionSuccess()]
async def handle_upload_file_action(
    action: actions.UploadFileAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Upload the file referenced by ``action.file_url`` through the target element.

    Returns an ``ActionFailure`` when the action has no file url, or when the url does not occur
    in the string form of ``task.navigation_payload``. If the target is an
    ``<input type="file">`` the file is downloaded and set on it directly; otherwise the action
    is handed to ``chain_click``, which downloads the file itself and feeds it to the file
    chooser. The download is therefore done only on the direct-upload path, not on both.
    """
    if not action.file_url:
        LOG.warning("InputFileAction has no file_url", action=action)
        return [ActionFailure(MissingFileUrl())]
    if action.file_url not in str(task.navigation_payload):
        LOG.warning(
            "LLM might be imagining the file url, which is not in navigation payload",
            action=action,
            file_url=action.file_url,
        )
        return [ActionFailure(ImaginaryFileUrl(action.file_url))]

    xpath = await validate_actions_in_dom(action, page, scraped_page)
    timeout = SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS
    locator = page.locator(f"xpath={xpath}")

    if not await is_file_input_element(locator):
        LOG.info("Taking UploadFileAction. Found non file input tag", action=action)
        # treat it as a click action
        action.is_upload_file_tag = False
        return await chain_click(page, action, xpath, timeout=timeout)

    LOG.info("Taking UploadFileAction. Found file input tag", action=action)
    file_path = download_file(action.file_url)
    if not file_path:
        return [ActionFailure(Exception(f"Failed to download file from {action.file_url}"))]

    await locator.set_input_files(file_path, timeout=timeout)
    # Sleep for 10 seconds after uploading a file to let the page process it
    await asyncio.sleep(10)
    return [ActionSuccess()]
async def handle_null_action(
    action: actions.NullAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Do nothing and report success; none of the arguments are used."""
    no_op_result: list[ActionResult] = [ActionSuccess()]
    return no_op_result
async def _select_option_between_clicks(
    page: Page,
    xpath: str,
    timeout: int,
    label: str | None = None,
    index: int | None = None,
) -> None:
    """Click the select element, choose an option by ``label`` or ``index``, then click it again."""
    selector = f"xpath={xpath}"
    await page.click(selector, timeout=timeout)
    await page.select_option(selector, label=label, index=index, timeout=timeout)
    await page.click(selector, timeout=timeout)


async def handle_select_option_action(
    action: actions.SelectOptionAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Select an option of the validated ``<select>`` element, by label first and by index second.

    If selecting by ``action.option.label`` raises and ``action.option.index`` is set, the index
    is tried. The index is first treated as an element id: when its xpath ends in
    ``option[N]``, ``N`` is used as the option index. When the id is unknown or its xpath is not
    an option, the supplied index is used as-is. Any error in the index attempt is returned as
    an ``ActionFailure``.

    Every Playwright call uses the explicit ``xpath=`` selector prefix, matching the other
    handlers in this module.
    """
    xpath = await validate_actions_in_dom(action, page, scraped_page)
    timeout = SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS

    try:
        # First click by label (if it matches)
        await _select_option_between_clicks(page, xpath, timeout, label=action.option.label)
        return [ActionSuccess()]
    except Exception as e:
        if action.option.index is None:
            return [ActionFailure(e)]
        LOG.warning(
            "Failed to click on the option by label, trying by index",
            exc_info=e,
            action=action,
            xpath=xpath,
        )

    try:
        # An index that is not a known element id falls through to the raw-index case
        # instead of failing with a KeyError.
        option_xpath = scraped_page.id_to_xpath_dict.get(action.option.index, "")
        match = re.search(r"option\[(\d+)]$", option_xpath)
        if match:
            # This means we were trying to select an option xpath, click the option
            # NOTE(review): XPath positions are 1-based while select_option indexes are
            # 0-based; confirm how the scraper numbers option[...] before relying on this.
            option_index = int(match.group(1))
        else:
            # This means the supplied index was for the select element, not a reference to the xpath dict
            option_index = action.option.index
        await _select_option_between_clicks(page, xpath, timeout, index=option_index)
        return [ActionSuccess()]
    except Exception as e:
        LOG.warning("Failed to click on the option by index", exception=e, action=action)
        return [ActionFailure(e)]
async def handle_checkbox_action(
    self: actions.CheckboxAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """
    ******* NOT REGISTERED *******
    This action causes more harm than it does good.
    It frequently mis-behaves, or gets stuck in click loops.
    Treating checkbox actions as click actions seem to perform way more reliably
    Developers who tried this and failed: 2 (Suchintan and Shu 😂)

    Checks or unchecks the validated element depending on ``self.is_checked`` (``self`` is the
    checkbox action here, not an instance of a class) and reports success.
    """
    xpath = await validate_actions_in_dom(self, page, scraped_page)
    action_timeout = SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS
    # Pick the Playwright call that moves the checkbox to the requested state.
    toggle = page.check if self.is_checked else page.uncheck
    await toggle(xpath, timeout=action_timeout)
    # TODO (suchintan): Why does checking the label work, but not the actual input element?
    return [ActionSuccess()]
async def handle_wait_action(
    action: actions.WaitAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Sleep for 10 seconds and then report the wait as a failed action."""
    await asyncio.sleep(10)
    wait_failure = ActionFailure(exception=Exception("Wait action is treated as a failure"))
    return [wait_failure]
async def handle_terminate_action(
    action: actions.TerminateAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Report success without touching the page; none of the arguments are used."""
    terminate_result: list[ActionResult] = [ActionSuccess()]
    return terminate_result
async def handle_complete_action(
    action: actions.CompleteAction, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> list[ActionResult]:
    """Report success, attaching extracted data when the action has a data extraction goal.

    Without ``action.data_extraction_goal`` the success result carries ``data=None``.
    """
    if not action.data_extraction_goal:
        return [ActionSuccess(data=None)]

    scrape_outcome = await extract_information_for_navigation_goal(
        scraped_page=scraped_page,
        task=task,
        step=step,
    )
    return [ActionSuccess(data=scrape_outcome.scraped_data)]
# Register the module-level handlers with ActionHandler when this module is imported.
# handle_checkbox_action is defined above but deliberately left out of this list.
ActionHandler.register_action_type(ActionType.SOLVE_CAPTCHA, handle_solve_captcha_action)
ActionHandler.register_action_type(ActionType.CLICK, handle_click_action)
ActionHandler.register_action_type(ActionType.INPUT_TEXT, handle_input_text_action)
ActionHandler.register_action_type(ActionType.UPLOAD_FILE, handle_upload_file_action)
ActionHandler.register_action_type(ActionType.NULL_ACTION, handle_null_action)
ActionHandler.register_action_type(ActionType.SELECT_OPTION, handle_select_option_action)
ActionHandler.register_action_type(ActionType.WAIT, handle_wait_action)
ActionHandler.register_action_type(ActionType.TERMINATE, handle_terminate_action)
ActionHandler.register_action_type(ActionType.COMPLETE, handle_complete_action)
async def validate_actions_in_dom(action: WebAction, page: Page, scraped_page: ScrapedPage) -> str:
    """Return the xpath of the action's element after checking it matches exactly one DOM node.

    The xpath is looked up in ``scraped_page.id_to_xpath_dict`` by ``action.element_id``.

    Raises:
        MissingElement: the xpath matches no element on the page.
        MultipleElementsFound: the xpath matches more than one element.
    """
    xpath = scraped_page.id_to_xpath_dict[action.element_id]
    match_count = await page.locator(xpath).count()

    if match_count > 1:
        LOG.warning(
            "Multiple elements found with action xpath. Expected 1. Validation failed.",
            action=action,
            num_elements=match_count,
        )
        raise MultipleElementsFound(num=match_count, xpath=xpath, element_id=action.element_id)

    if match_count < 1:
        LOG.warning("No elements found with action xpath. Validation failed.", action=action, xpath=xpath)
        raise MissingElement(xpath=xpath, element_id=action.element_id)

    LOG.info("Validated action xpath in DOM", action=action)
    return xpath
async def chain_click(
    page: Page,
    action: ClickAction | UploadFileAction,
    xpath: str,
    timeout: int = SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS,
) -> List[ActionResult]:
    """Click the element at ``xpath``; on failure try its sibling label, then its parent.

    A ``filechooser`` listener is registered for the duration of the call. It answers any file
    chooser opened by the click with the file downloaded from ``action.file_url``, or with an
    empty file list when there is none. The listener is removed in all cases, including when
    the initial javascript check raises.

    :param page: page to click on
    :param action: the click or upload action being executed
    :param xpath: xpath of the element to click
    :param timeout: timeout in milliseconds for each click attempt
    :return: one result per attempt, in order; the last entry is the final outcome
    """
    # Add a defensive page handler here in case a click action opens a file chooser.
    # This automatically dismisses the dialog
    # File choosers are impossible to close if you don't expect one. Instead of dealing with it, close it!
    # TODO (suchintan): This should likely result in an ActionFailure -- we can figure out how to do this later!
    LOG.info("Chain click starts", action=action, xpath=xpath)
    file: list[str] | str = []
    if action.file_url:
        file = download_file(action.file_url) or []

    fc_func = lambda fc: fc.set_files(files=file)
    page.on("filechooser", fc_func)
    LOG.info("Registered file chooser listener", action=action, path=file)

    selector = f"xpath={xpath}"
    try:
        # Inside the try/finally so the listener is removed even if this check raises.
        javascript_triggered = await is_javascript_triggered(page, xpath)
        try:
            await page.click(selector, timeout=timeout)
            LOG.info("Chain click: main element click succeeded", action=action, xpath=xpath)
            return [ActionSuccess(javascript_triggered=javascript_triggered)]
        except Exception as e:
            action_results: list[ActionResult] = [ActionFailure(e, javascript_triggered=javascript_triggered)]

            if await is_input_element(page.locator(selector)):
                LOG.info("Chain click: it's an input element. going to try sibling click", action=action, xpath=xpath)
                sibling_action_result = await click_sibling_of_input(page.locator(selector), timeout=timeout)
                action_results.append(sibling_action_result)
                if isinstance(sibling_action_result, ActionSuccess):
                    return action_results

            parent_xpath = f"{xpath}/.."
            try:
                parent_javascript_triggered = await is_javascript_triggered(page, parent_xpath)
                javascript_triggered = javascript_triggered or parent_javascript_triggered
                parent_locator = page.locator(selector).locator("..")
                await parent_locator.click(timeout=timeout)
                LOG.info("Chain click: successfully clicked parent element", action=action, parent_xpath=parent_xpath)
                action_results.append(
                    ActionSuccess(
                        javascript_triggered=javascript_triggered,
                        interacted_with_parent=True,
                    )
                )
            except Exception as pe:
                LOG.warning("Failed to click parent element", action=action, parent_xpath=parent_xpath, exc_info=True)
                action_results.append(
                    ActionFailure(pe, javascript_triggered=javascript_triggered, interacted_with_parent=True)
                )
                # We don't raise exception here because we do log the exception, and return ActionFailure as the last action
            return action_results
    finally:
        LOG.info("Remove file chooser listener", action=action)
        # Sleep for 10 seconds after uploading a file to let the page process it
        # Removing this breaks file uploads using the filechooser
        # KEREM DO NOT REMOVE
        if file:
            await asyncio.sleep(10)
        page.remove_listener("filechooser", fc_func)
async def is_javascript_triggered(page: Page, xpath: str) -> bool:
    """Return True when the first element at ``xpath`` is an anchor whose href starts with ``javascript:``."""
    element = page.locator(f"xpath={xpath}").first

    tag_name = await element.evaluate("e => e.tagName")
    if tag_name.lower() != "a":
        return False

    href = await element.evaluate("e => e.href")
    if not href.lower().startswith("javascript:"):
        return False

    LOG.info("Found javascript call in anchor tag, marking step as completed. Dropping remaining actions")
    return True
async def is_file_input_element(locator: Locator) -> bool:
    """Return True when the locator's first element is an ``<input>`` whose type is ``file``."""
    element = locator.first
    if not element:
        return False

    tag_name = await element.evaluate("el => el.tagName")
    type_name = await element.evaluate("el => el.type")
    is_input_tag = tag_name.lower() == "input"
    return is_input_tag and type_name == "file"
async def is_input_element(locator: Locator) -> bool:
    """Return True when the locator's first element has the tag name ``input`` (case-insensitive)."""
    element = locator.first
    if not element:
        return False

    tag_name = await element.evaluate("el => el.tagName")
    return tag_name.lower() == "input"
async def click_sibling_of_input(
    locator: Locator,
    timeout: int,
    javascript_triggered: bool = False,
) -> ActionResult:
    """Click the ``<label for=...>`` that belongs to an input element.

    The label is searched under the input's parent using the input's ``id`` attribute. An input
    without an ``id`` cannot be matched to a label, so it is reported as a failure immediately
    instead of waiting out ``timeout`` on a selector that can never match.

    :param locator: locator of the input element
    :param timeout: timeout in milliseconds for the label click
    :param javascript_triggered: value copied into the returned result
    :return: ``ActionSuccess`` when the label was clicked, otherwise ``ActionFailure``
    """
    try:
        input_element = locator.first
        parent_locator = locator.locator("..")
        if input_element:
            input_id = await input_element.get_attribute("id")
            if not input_id:
                LOG.warning("Input element has no id, cannot look up its sibling label")
                return ActionFailure(
                    exception=Exception("Input element has no id, cannot click sibling label"),
                    javascript_triggered=javascript_triggered,
                )
            sibling_label_xpath = f'//label[@for="{input_id}"]'
            label_locator = parent_locator.locator(sibling_label_xpath)
            await label_locator.click(timeout=timeout)
            LOG.info(
                "Successfully clicked sibling label of input element",
                sibling_label_xpath=sibling_label_xpath,
            )
            return ActionSuccess(javascript_triggered=javascript_triggered, interacted_with_sibling=True)
        # Should never get here
        return ActionFailure(
            exception=Exception("Failed while trying to click sibling of input element"),
            javascript_triggered=javascript_triggered,
            interacted_with_sibling=True,
        )
    except Exception as e:
        LOG.warning("Failed to click sibling label of input element", exc_info=e)
        return ActionFailure(exception=e, javascript_triggered=javascript_triggered)
async def extract_information_for_navigation_goal(
    task: Task,
    step: Step,
    scraped_page: ScrapedPage,
) -> ScrapeResult:
    """
    Asks the LLM to extract information from an already scraped page.

    The "extract-information" prompt is rendered from the task's goals and schema together with
    the scraped page's element tree, url and extracted text, and is sent with the page
    screenshots. The JSON response is returned as ``ScrapeResult.scraped_data``.
    """
    prompt = prompt_engine.load_prompt(
        "extract-information",
        navigation_goal=task.navigation_goal,
        elements=scraped_page.element_tree,
        data_extraction_goal=task.data_extraction_goal,
        extracted_information_schema=task.extracted_information_schema,
        current_url=scraped_page.url,
        extracted_text=scraped_page.extracted_text,
    )

    llm_json = await app.OPENAI_CLIENT.chat_completion(
        step=step,
        prompt=prompt,
        screenshots=scraped_page.screenshots,
    )

    return ScrapeResult(scraped_data=llm_json)

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import Any
from pydantic import BaseModel
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.actions.actions import Action, ActionTypeUnion
from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.scraper.scraper import ScrapedPage
class AgentStepOutput(BaseModel):
    """
    Output of the agent step, this is recorded in the database.
    """

    # Legacy format: a flat list of results. Will be deprecated once we move to the format below.
    action_results: list[ActionResult] | None = None
    # New format: each action paired with its results.
    # Nullable for backwards compatibility, once backfill is done, this won't be nullable anymore
    actions_and_results: list[tuple[ActionTypeUnion, list[ActionResult]]] | None = None

    def __repr__(self) -> str:
        dumped = self.model_dump()
        return f"AgentStepOutput({dumped})"

    def __str__(self) -> str:
        return repr(self)
class DetailedAgentStepOutput(BaseModel):
    """
    Output of the agent step, this is not recorded in the database, only used for debugging in the Jupyter notebook.
    """

    scraped_page: ScrapedPage | None
    extract_action_prompt: str | None
    llm_response: dict[str, Any] | None
    actions: list[Action] | None
    action_results: list[ActionResult] | None
    actions_and_results: list[tuple[ActionTypeUnion, list[ActionResult]]] | None

    class Config:
        # NOTE(review): `exclude` does not appear to be a config key pydantic acts on, so this
        # list alone does not drop the fields from model_dump(); __repr__ excludes them itself.
        exclude = ["scraped_page", "extract_action_prompt"]

    def __repr__(self) -> str:
        """Return the detailed dump in debug mode, otherwise the compact AgentStepOutput form."""
        if SettingsManager.get_settings().DEBUG_MODE:
            # Leave the two bulky fields out of the debug representation.
            dumped = self.model_dump(exclude={"scraped_page", "extract_action_prompt"})
            return f"DetailedAgentStepOutput({dumped})"
        else:
            return f"AgentStepOutput({self.to_agent_step_output().model_dump()})"

    def __str__(self) -> str:
        return self.__repr__()

    def to_agent_step_output(self) -> AgentStepOutput:
        """Build the persisted ``AgentStepOutput``, replacing missing result lists with ``[]``."""
        return AgentStepOutput(
            action_results=self.action_results if self.action_results else [],
            actions_and_results=self.actions_and_results if self.actions_and_results else [],
        )

View File

@@ -0,0 +1,62 @@
from typing import Any
from pydantic import BaseModel
from skyvern.webeye.string_util import remove_whitespace
class ActionResult(BaseModel):
    """Outcome of executing a single action; base class of ``ActionSuccess`` and ``ActionFailure``."""

    success: bool
    exception_type: str | None = None
    exception_message: str | None = None
    data: dict[str, Any] | list | str | None = None
    step_retry_number: int | None = None
    step_order: int | None = None
    javascript_triggered: bool = False
    # None is used for old data so that we can differentiate between old and new data which only has boolean
    interacted_with_sibling: bool | None = None
    interacted_with_parent: bool | None = None

    def __str__(self) -> str:
        # `data` sits inside the parentheses; the closing paren used to come before it.
        return (
            f"ActionResult(success={self.success}, exception_type={self.exception_type}, "
            f"exception_message={self.exception_message}, data={self.data})"
        )

    def __repr__(self) -> str:
        return self.__str__()
class ActionSuccess(ActionResult):
    """An ``ActionResult`` with ``success=True`` and optional result data."""

    def __init__(
        self,
        data: dict[str, Any] | list | str | None = None,
        javascript_triggered: bool = False,
        interacted_with_sibling: bool = False,
        interacted_with_parent: bool = False,
    ):
        interaction_flags = {
            "javascript_triggered": javascript_triggered,
            "interacted_with_sibling": interacted_with_sibling,
            "interacted_with_parent": interacted_with_parent,
        }
        super().__init__(success=True, data=data, **interaction_flags)
class ActionFailure(ActionResult):
    """An ``ActionResult`` with ``success=False`` that records the exception's type name and message."""

    def __init__(
        self,
        exception: Exception,
        javascript_triggered: bool = False,
        interacted_with_sibling: bool = False,
        interacted_with_parent: bool = False,
    ):
        # Only the class name and the whitespace-normalized message are stored, not the exception itself.
        error_name = type(exception).__name__
        error_text = remove_whitespace(str(exception))
        super().__init__(
            success=False,
            exception_type=error_name,
            exception_message=error_text,
            javascript_triggered=javascript_triggered,
            interacted_with_sibling=interacted_with_sibling,
            interacted_with_parent=interacted_with_parent,
        )

View File

@@ -0,0 +1,167 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any, Awaitable, Protocol
import structlog
from playwright.async_api import BrowserContext, Error, Page, Playwright, async_playwright
from pydantic import BaseModel
from skyvern.exceptions import FailedToNavigateToUrl, UnknownBrowserType, UnknownErrorWhileCreatingBrowserContext
from skyvern.forge.sdk.core.skyvern_context import current
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class BrowserContextCreator(Protocol):
    """Call signature of the creators registered with ``BrowserContextFactory.register_type``.

    A creator is called with the Playwright instance plus keyword arguments and returns an
    awaitable that resolves to the new browser context and its ``BrowserArtifacts``.
    """

    def __call__(
        self, playwright: Playwright, **kwargs: dict[str, Any]
    ) -> Awaitable[tuple[BrowserContext, BrowserArtifacts]]:
        ...
class BrowserContextFactory:
    """Registry of browser-context creators, selected by the ``BROWSER_TYPE`` setting."""

    _creators: dict[str, BrowserContextCreator] = {}

    @staticmethod
    def get_subdir() -> str:
        """Return the current context's task id, else its request id, else a fresh UUID string."""
        curr_context = current()
        if curr_context and curr_context.task_id:
            return curr_context.task_id
        elif curr_context and curr_context.request_id:
            return curr_context.request_id
        return str(uuid.uuid4())

    @staticmethod
    def build_browser_args() -> dict[str, Any]:
        """Build the keyword arguments passed to ``browser.new_context``.

        Video and HAR output go under a per-day directory. The date is computed once so both
        paths always carry the same day, even when the call straddles midnight.
        """
        settings = SettingsManager.get_settings()
        date_stamp = datetime.utcnow().strftime("%Y-%m-%d")
        video_dir = f"{settings.VIDEO_PATH}/{date_stamp}"
        har_dir = f"{settings.HAR_PATH}/{date_stamp}/{BrowserContextFactory.get_subdir()}.har"
        return {
            "record_har_path": har_dir,
            "record_video_dir": video_dir,
            "viewport": {"width": 1920, "height": 1080},
        }

    @staticmethod
    def build_browser_artifacts(
        video_path: str | None = None, har_path: str | None = None, video_artifact_id: str | None = None
    ) -> BrowserArtifacts:
        """Bundle the given artifact locations into a ``BrowserArtifacts`` instance."""
        return BrowserArtifacts(video_path=video_path, har_path=har_path, video_artifact_id=video_artifact_id)

    @classmethod
    def register_type(cls, browser_type: str, creator: BrowserContextCreator) -> None:
        """Register ``creator`` for ``browser_type``, replacing any previous registration."""
        cls._creators[browser_type] = creator

    @classmethod
    async def create_browser_context(
        cls, playwright: Playwright, **kwargs: Any
    ) -> tuple[BrowserContext, BrowserArtifacts]:
        """Create a browser context with the creator registered for the configured browser type.

        Raises:
            UnknownBrowserType: no creator is registered for ``BROWSER_TYPE``.
            UnknownErrorWhileCreatingBrowserContext: the creator raised any other exception.
        """
        browser_type = SettingsManager.get_settings().BROWSER_TYPE
        try:
            creator = cls._creators.get(browser_type)
            if not creator:
                raise UnknownBrowserType(browser_type)
            return await creator(playwright, **kwargs)
        except UnknownBrowserType:
            # Pass through unchanged rather than wrapping it in the generic error below.
            raise
        except Exception as e:
            raise UnknownErrorWhileCreatingBrowserContext(browser_type, e) from e
class BrowserArtifacts(BaseModel):
    """Locations and ids of the artifacts recorded for one browser context."""

    # Path of the recorded video; BrowserState.check_and_fix_state fills it in when it is None.
    video_path: str | None = None
    # Id of the stored video artifact; all three fields default to None.
    video_artifact_id: str | None = None
    # Path of the HAR file; the chromium creators pass build_browser_args()["record_har_path"].
    har_path: str | None = None
async def _create_headless_chromium(playwright: Playwright, **kwargs: dict) -> tuple[BrowserContext, BrowserArtifacts]:
    """Launch headless Chromium and open a context with the factory's default arguments.

    ``kwargs`` is accepted for interface compatibility and is not used.
    """
    chromium = await playwright.chromium.launch(headless=True)
    context_options = BrowserContextFactory.build_browser_args()
    artifacts = BrowserContextFactory.build_browser_artifacts(har_path=context_options["record_har_path"])
    context = await chromium.new_context(**context_options)
    return context, artifacts
async def _create_headful_chromium(playwright: Playwright, **kwargs: dict) -> tuple[BrowserContext, BrowserArtifacts]:
    """Launch Chromium with a visible window and open a context with the factory's default arguments.

    ``kwargs`` is accepted for interface compatibility and is not used.
    """
    chromium = await playwright.chromium.launch(headless=False)
    context_options = BrowserContextFactory.build_browser_args()
    artifacts = BrowserContextFactory.build_browser_artifacts(har_path=context_options["record_har_path"])
    context = await chromium.new_context(**context_options)
    return context, artifacts
# Register the two Chromium creators at import time; the key is matched against the
# BROWSER_TYPE setting in BrowserContextFactory.create_browser_context.
BrowserContextFactory.register_type("chromium-headless", _create_headless_chromium)
BrowserContextFactory.register_type("chromium-headful", _create_headful_chromium)
class BrowserState:
    """Holds the Playwright instance, browser context, page and artifacts used together.

    Missing pieces are created lazily by ``check_and_fix_state``.
    """

    instance = None

    def __init__(
        self,
        pw: Playwright | None = None,
        browser_context: BrowserContext | None = None,
        page: Page | None = None,
        browser_artifacts: BrowserArtifacts | None = None,
    ):
        self.pw = pw
        self.browser_context = browser_context
        self.page = page
        # A fresh object per instance: a default built in the signature would be shared (and
        # mutated through video_path) by every BrowserState created without artifacts.
        self.browser_artifacts = browser_artifacts if browser_artifacts is not None else BrowserArtifacts()

    async def _close_all_other_pages(self) -> None:
        """Close every page of the context except ``self.page``."""
        if not self.browser_context or not self.page:
            return
        pages = self.browser_context.pages
        for page in pages:
            if page != self.page:
                await page.close()

    async def check_and_fix_state(self, url: str | None = None) -> None:
        """Create whichever of playwright, browser context and page is missing.

        ``url`` is passed to the context creator and, only when a new page is created, the page
        is navigated to it.

        Raises:
            FailedToNavigateToUrl: Playwright reported an error while navigating to ``url``.
        """
        if self.pw is None:
            LOG.info("Starting playwright")
            self.pw = await async_playwright().start()
            LOG.info("playwright is started")
        if self.browser_context is None:
            LOG.info("creating browser context")
            browser_context, browser_artifacts = await BrowserContextFactory.create_browser_context(self.pw, url=url)
            self.browser_context = browser_context
            self.browser_artifacts = browser_artifacts
            LOG.info("browser context is created")

        assert self.browser_context is not None

        if self.page is None:
            LOG.info("Creating a new page")
            self.page = await self.browser_context.new_page()
            await self._close_all_other_pages()
            LOG.info("A new page is created")
            if url:
                LOG.info(f"Navigating page to {url} and waiting for 5 seconds")
                try:
                    await self.page.goto(url)
                except Error as playright_error:
                    LOG.exception(f"Error while navigating to url: {str(playright_error)}", exc_info=True)
                    raise FailedToNavigateToUrl(url=url, error_message=str(playright_error)) from playright_error
                LOG.info(f"Successfully went to {url}")

        # page.video is None when the context was created without video recording.
        if self.browser_artifacts.video_path is None and self.page.video is not None:
            self.browser_artifacts.video_path = await self.page.video.path()

    async def get_or_create_page(self, url: str | None = None) -> Page:
        """Return the page, creating playwright, context and page first when needed."""
        await self.check_and_fix_state(url)
        assert self.page is not None
        return self.page

    async def close(self, close_browser_on_completion: bool = True) -> None:
        """Close the browser context and stop playwright, unless ``close_browser_on_completion`` is False."""
        LOG.info("Closing browser state")
        if self.browser_context and close_browser_on_completion:
            LOG.info("Closing browser context and its pages")
            await self.browser_context.close()
            LOG.info("Main browser context and all its pages are closed")
        if self.pw and close_browser_on_completion:
            LOG.info("Stopping playwright")
            await self.pw.stop()
            LOG.info("Playwright is stopped")

View File

@@ -0,0 +1,152 @@
from __future__ import annotations
import structlog
from playwright.async_api import Browser, Playwright, async_playwright
from skyvern.exceptions import MissingBrowserState
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
from skyvern.webeye.browser_factory import BrowserContextFactory, BrowserState
LOG = structlog.get_logger()
class BrowserManager:
    """Process-wide singleton mapping task ids and workflow run ids to their ``BrowserState``."""

    instance = None
    # Shared registry; a task running inside a workflow run shares that run's BrowserState.
    pages: dict[str, BrowserState] = dict()

    def __new__(cls) -> BrowserManager:
        if cls.instance is None:
            cls.instance = super().__new__(cls)
        return cls.instance

    @staticmethod
    async def _create_browser_state(
        proxy_location: ProxyLocation | None = None, url: str | None = None
    ) -> BrowserState:
        """Start playwright and create a browser context; playwright is stopped again on failure."""
        pw = await async_playwright().start()
        try:
            browser_context, browser_artifacts = await BrowserContextFactory.create_browser_context(
                pw, proxy_location=proxy_location, url=url
            )
        except Exception:
            # Nobody else holds a reference to pw yet, so it would never be stopped otherwise.
            await pw.stop()
            raise
        return BrowserState(pw=pw, browser_context=browser_context, page=None, browser_artifacts=browser_artifacts)

    @staticmethod
    async def _open_first_page(browser_state: BrowserState, url: str | None) -> None:
        """Create the state's first page; close the state if that fails, since it is not registered yet."""
        try:
            await browser_state.get_or_create_page(url)
        except Exception:
            await browser_state.close()
            raise

    async def get_or_create_for_task(self, task: Task) -> BrowserState:
        """Return the task's browser state, reusing its workflow run's state or creating a new one."""
        if task.task_id in self.pages:
            return self.pages[task.task_id]
        elif task.workflow_run_id in self.pages:
            LOG.info(
                "Browser state for task not found. Using browser state for workflow run",
                task_id=task.task_id,
                workflow_run_id=task.workflow_run_id,
            )
            self.pages[task.task_id] = self.pages[task.workflow_run_id]
            return self.pages[task.task_id]

        LOG.info("Creating browser state for task", task_id=task.task_id)
        browser_state = await self._create_browser_state(task.proxy_location, task.url)

        # The URL here is only used when creating a new page, and not when using an existing page.
        # This will make sure browser_state.page is not None.
        await self._open_first_page(browser_state, task.url)

        self.pages[task.task_id] = browser_state
        return browser_state

    async def get_or_create_for_workflow_run(self, workflow_run: WorkflowRun, url: str | None = None) -> BrowserState:
        """Return the workflow run's browser state, creating and registering it when missing."""
        if workflow_run.workflow_run_id in self.pages:
            return self.pages[workflow_run.workflow_run_id]
        LOG.info("Creating browser state for workflow run", workflow_run_id=workflow_run.workflow_run_id)
        browser_state = await self._create_browser_state(workflow_run.proxy_location, url=url)

        # The URL here is only used when creating a new page, and not when using an existing page.
        # This will make sure browser_state.page is not None.
        await self._open_first_page(browser_state, url)

        self.pages[workflow_run.workflow_run_id] = browser_state
        return browser_state

    def set_video_artifact_for_task(self, task: Task, artifact_id: str) -> None:
        """Record ``artifact_id`` on the workflow run's browser state if there is one, else on the task's.

        Raises:
            MissingBrowserState: neither the workflow run nor the task has a registered state.
        """
        if task.workflow_run_id and task.workflow_run_id in self.pages:
            if self.pages[task.workflow_run_id].browser_artifacts.video_artifact_id:
                LOG.warning(
                    "Video artifact is already set for workflow run. Overwriting",
                    workflow_run_id=task.workflow_run_id,
                    old_artifact_id=self.pages[task.workflow_run_id].browser_artifacts.video_artifact_id,
                    new_artifact_id=artifact_id,
                )
            self.pages[task.workflow_run_id].browser_artifacts.video_artifact_id = artifact_id
            return
        if task.task_id in self.pages:
            if self.pages[task.task_id].browser_artifacts.video_artifact_id:
                LOG.warning(
                    "Video artifact is already set for task. Overwriting",
                    task_id=task.task_id,
                    old_artifact_id=self.pages[task.task_id].browser_artifacts.video_artifact_id,
                    new_artifact_id=artifact_id,
                )
            self.pages[task.task_id].browser_artifacts.video_artifact_id = artifact_id
            return

        raise MissingBrowserState(task_id=task.task_id)

    async def get_video_data(
        self, browser_state: BrowserState, task_id: str = "", workflow_id: str = "", workflow_run_id: str = ""
    ) -> bytes:
        """Return the recorded video's bytes, or ``b""`` when no video path is known."""
        if browser_state:
            path = browser_state.browser_artifacts.video_path
            if path:
                with open(path, "rb") as f:
                    return f.read()
        LOG.warning(
            "Video data not found for task", task_id=task_id, workflow_id=workflow_id, workflow_run_id=workflow_run_id
        )
        return b""

    async def get_har_data(
        self, browser_state: BrowserState, task_id: str = "", workflow_id: str = "", workflow_run_id: str = ""
    ) -> bytes:
        """Return the HAR file's bytes, or ``b""`` when no HAR path is known."""
        if browser_state:
            path = browser_state.browser_artifacts.har_path
            if path:
                with open(path, "rb") as f:
                    return f.read()
        LOG.warning(
            "HAR data not found for task", task_id=task_id, workflow_id=workflow_id, workflow_run_id=workflow_run_id
        )
        return b""

    @classmethod
    async def connect_to_scraping_browser(cls, pw: Playwright) -> Browser:
        """Connect over CDP to the endpoint stored in the ``REMOTE_BROWSER_KEY`` setting."""
        if not SettingsManager.get_settings().REMOTE_BROWSER_KEY:
            raise Exception("REMOTE_BROWSER_KEY is empty. Cannot connect to remote browser.")
        browser = await pw.chromium.connect_over_cdp(SettingsManager.get_settings().REMOTE_BROWSER_KEY)
        LOG.info("Connected to remote browser", browser_type=SettingsManager.get_settings().BROWSER_TYPE)
        return browser

    @classmethod
    async def close(cls) -> None:
        """Close every registered browser state once and empty the registry."""
        LOG.info("Closing BrowserManager")
        # One state can be registered under both a task id and a workflow run id.
        already_closed: set[int] = set()
        for browser_state in cls.pages.values():
            if id(browser_state) in already_closed:
                continue
            already_closed.add(id(browser_state))
            await browser_state.close()
        cls.pages = dict()
        LOG.info("BrowserManger is closed")

    async def cleanup_for_task(self, task_id: str, close_browser_on_completion: bool = True) -> BrowserState | None:
        """Unregister the task's browser state and close it; returns the state, or None if absent."""
        LOG.info("Cleaning up for task")
        browser_state_to_close = self.pages.pop(task_id, None)
        if browser_state_to_close:
            await browser_state_to_close.close(close_browser_on_completion=close_browser_on_completion)
        LOG.info("Task is cleaned up")

        return browser_state_to_close

    async def cleanup_for_workflow_run(
        self, workflow_run_id: str, close_browser_on_completion: bool = True
    ) -> BrowserState | None:
        """Unregister the workflow run's browser state and close it; returns the state, or None if absent."""
        LOG.info("Cleaning up for workflow run")
        browser_state_to_close = self.pages.pop(workflow_run_id, None)
        if browser_state_to_close:
            await browser_state_to_close.close(close_browser_on_completion=close_browser_on_completion)
        LOG.info("Workflow run is cleaned up")

        return browser_state_to_close

View File

View File

@@ -0,0 +1,806 @@
// Helpers for building and comparing plain rect objects
// ({ bottom, top, left, right, width, height }).
class Rect {
  // Build a rect from its top-left (x1, y1) and bottom-right (x2, y2) corners.
  static create(x1, y1, x2, y2) {
    const width = x2 - x1;
    const height = y2 - y1;
    return { bottom: y2, top: y1, left: x1, right: x2, width, height };
  }

  // Copy the six rect properties of `rect` into a new plain object.
  static copy(rect) {
    const { bottom, top, left, right, width, height } = rect;
    return { bottom, top, left, right, width, height };
  }

  // Return a rect shifted by x horizontally and y vertically; a null/undefined offset counts as 0.
  static translate(rect, x, y) {
    const dx = x ?? 0;
    const dy = y ?? 0;
    return {
      bottom: rect.bottom + dy,
      top: rect.top + dy,
      left: rect.left + dx,
      right: rect.right + dx,
      width: rect.width,
      height: rect.height,
    };
  }

  // Two rects overlap when they overlap on both axes; touching edges do not count.
  static intersects(rect1, rect2) {
    const overlapsHorizontally =
      rect1.right > rect2.left && rect1.left < rect2.right;
    const overlapsVertically =
      rect1.bottom > rect2.top && rect1.top < rect2.bottom;
    return overlapsHorizontally && overlapsVertically;
  }

  // Strict equality of all six rect properties.
  static equals(rect1, rect2) {
    const properties = ["top", "bottom", "left", "right", "width", "height"];
    return properties.every((property) => rect1[property] === rect2[property]);
  }
}
// Viewport and visibility helpers built on top of Rect.
class DomUtils {
  //
  // Clamps the rect's left and top edges to 0 (right and bottom are kept as given) and returns
  // the resulting rect. Returns null when the clamped rect starts within 4px of, or beyond,
  // the bottom or right edge of the viewport. Rects smaller than 3px are NOT rejected here;
  // getVisibleClientRect applies that size check to the returned rect.
  //
  static cropRectToVisible(rect) {
    const boundedRect = Rect.create(
      Math.max(rect.left, 0),
      Math.max(rect.top, 0),
      rect.right,
      rect.bottom,
    );
    if (
      boundedRect.top >= window.innerHeight - 4 ||
      boundedRect.left >= window.innerWidth - 4
    ) {
      return null;
    } else {
      return boundedRect;
    }
  }

  // Returns the first client rect of `element` that is at least 3x3 after cropping and whose
  // element has visibility "visible", or null when there is none. With `testChildren` set,
  // zero-sized rects are resolved through floated/absolutely positioned children instead.
  static getVisibleClientRect(element, testChildren) {
    // Note: this call will be expensive if we modify the DOM in between calls.
    let clientRect;
    if (testChildren == null) testChildren = false;
    // Copy the live client rects into plain objects.
    const clientRects = (() => {
      const result = [];
      for (clientRect of element.getClientRects()) {
        result.push(Rect.copy(clientRect));
      }
      return result;
    })();

    // Inline elements with font-size: 0px; will declare a height of zero, even if a child with
    // non-zero font-size contains text.
    let isInlineZeroHeight = function () {
      const elementComputedStyle = window.getComputedStyle(element, null);
      const isInlineZeroFontSize =
        0 ===
          elementComputedStyle.getPropertyValue("display").indexOf("inline") &&
        elementComputedStyle.getPropertyValue("font-size") === "0px";
      // Override the function to return this value for the rest of this context.
      isInlineZeroHeight = () => isInlineZeroFontSize;
      return isInlineZeroFontSize;
    };

    for (clientRect of clientRects) {
      // If the link has zero dimensions, it may be wrapping visible but floated elements. Check for
      // this.
      let computedStyle;
      if ((clientRect.width === 0 || clientRect.height === 0) && testChildren) {
        for (const child of Array.from(element.children)) {
          computedStyle = window.getComputedStyle(child, null);
          // Ignore child elements which are not floated and not absolutely positioned for parent
          // elements with zero width/height, as long as the case described at isInlineZeroHeight
          // does not apply.
          // NOTE(mrmr1993): This ignores floated/absolutely positioned descendants nested within
          // inline children.
          const position = computedStyle.getPropertyValue("position");
          if (
            computedStyle.getPropertyValue("float") === "none" &&
            !["absolute", "fixed"].includes(position) &&
            !(
              clientRect.height === 0 &&
              isInlineZeroHeight() &&
              0 === computedStyle.getPropertyValue("display").indexOf("inline")
            )
          ) {
            continue;
          }
          // Recurse into the child; the first child rect of at least 3x3 wins.
          const childClientRect = this.getVisibleClientRect(child, true);
          if (
            childClientRect === null ||
            childClientRect.width < 3 ||
            childClientRect.height < 3
          )
            continue;
          return childClientRect;
        }
      } else {
        clientRect = this.cropRectToVisible(clientRect);
        if (
          clientRect === null ||
          clientRect.width < 3 ||
          clientRect.height < 3
        )
          continue;
        // eliminate invisible elements (see test_harnesses/visibility_test.html)
        computedStyle = window.getComputedStyle(element, null);
        if (computedStyle.getPropertyValue("visibility") !== "visible")
          continue;

        return clientRect;
      }
    }

    return null;
  }

  // Returns { top, left } computed from the document element's bounding rect: the negated rect
  // offsets, corrected by the element's margins (static position without containment) or by its
  // clientTop/clientLeft (every other case).
  static getViewportTopLeft() {
    const box = document.documentElement;
    const style = getComputedStyle(box);
    const rect = box.getBoundingClientRect();
    if (
      style.position === "static" &&
      !/content|paint|strict/.test(style.contain || "")
    ) {
      // The margin is included in the client rect, so we need to subtract it back out.
      const marginTop = parseInt(style.marginTop);
      const marginLeft = parseInt(style.marginLeft);
      return {
        top: -rect.top + marginTop,
        left: -rect.left + marginLeft,
      };
    } else {
      const { clientTop, clientLeft } = box;
      return {
        top: -rect.top - clientTop,
        left: -rect.left - clientLeft,
      };
    }
  }
}
// from playwright
/**
 * Returns the computed style of `element` (optionally for a pseudo-element), or
 * undefined when the element has no owner document or that document has no window.
 */
function getElementComputedStyle(element, pseudo) {
  const view = element.ownerDocument?.defaultView;
  if (!view) {
    return undefined;
  }
  return view.getComputedStyle(element, pseudo);
}
// from playwright
/**
 * Returns true when no computed style can be resolved, or when the element passes
 * `checkVisibility` (opacity and the visibility CSS property excluded from that check)
 * and its computed `visibility` is "visible".
 * `style` may be passed in to avoid recomputing it.
 */
function isElementStyleVisibilityVisible(element, style) {
  const resolvedStyle = style ?? getElementComputedStyle(element);
  if (!resolvedStyle) {
    return true;
  }
  const passesVisibilityCheck = element.checkVisibility({
    checkOpacity: false,
    checkVisibilityCSS: false,
  });
  if (!passesVisibilityCheck) {
    return false;
  }
  return resolvedStyle.visibility === "visible";
}
// from playwright
/**
 * Decides whether `element` is visible.
 * - <option> elements defer to their parent (the result is the parent's falsy value,
 *   e.g. null, when there is no parent).
 * - Elements without a resolvable computed style count as visible.
 * - `display: contents` elements are visible when any element child is visible.
 * - Otherwise the element must pass the style visibility check and have a
 *   bounding rect with positive width and height.
 */
function isElementVisible(element) {
  // TODO: This is a hack to not check visibility for option elements
  // because they are not visible by default. We check their parent instead for visibility.
  if (element.tagName.toLowerCase() === "option") {
    return element.parentElement && isElementVisible(element.parentElement);
  }

  const style = getElementComputedStyle(element);
  if (!style) {
    return true;
  }

  if (style.display === "contents") {
    // display:contents is not rendered itself, but its child nodes are.
    // Only element nodes (nodeType 1) are considered; text and other nodes are skipped.
    return Array.from(element.childNodes).some(
      (node) => node.nodeType === 1 /* Node.ELEMENT_NODE */ && isElementVisible(node),
    );
  }

  if (!isElementStyleVisibilityVisible(element, style)) {
    return false;
  }
  const { width, height } = element.getBoundingClientRect();
  return width > 0 && height > 0;
}
/**
 * Truthy when the element computes to `display: none`, or has its `hidden` or
 * `disabled` property set.
 */
function isHiddenOrDisabled(element) {
  const computed = getElementComputedStyle(element);
  if (computed?.display === "none") {
    return true;
  }
  return element.hidden || element.disabled;
}
/** Returns true for <script> and <style> elements (tag name compared case-insensitively). */
function isScriptOrStyle(element) {
  return ["script", "style"].includes(element.tagName.toLowerCase());
}
/**
 * Returns true when the element's `role` attribute (trimmed, case-insensitive) is one
 * of the ARIA widget roles this scraper treats as interactable.
 */
function hasWidgetRole(element) {
  // https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/Roles#2._widget_roles
  // Not all roles make sense for the time being so we only check for the ones that do
  const supportedWidgetRoles = new Set([
    "button",
    "link",
    "checkbox",
    "menuitem",
    "menuitemcheckbox",
    "menuitemradio",
    "radio",
    "tab",
    "combobox",
    "textbox",
    "searchbox",
    "slider",
    "spinbutton",
    "switch",
    "gridcell",
  ]);
  const rawRole = element.getAttribute("role");
  return rawRole ? supportedWidgetRoles.has(rawRole.toLowerCase().trim()) : false;
}
/**
 * Returns true for an <input> whose explicit `type` attribute (trimmed,
 * case-insensitive) is in the allow-list below. Non-input elements and inputs
 * without a `type` attribute return false so that other checks can decide.
 */
function isInteractableInput(element) {
  if (element.tagName.toLowerCase() !== "input") {
    return false;
  }
  const declaredType = element.getAttribute("type");
  if (!declaredType) {
    // let other checks decide
    return false;
  }
  const interactableTypes = new Set([
    "button",
    "checkbox",
    "date",
    "datetime-local",
    "email",
    "file",
    "image",
    "month",
    "number",
    "password",
    "radio",
    "range",
    "reset",
    "search",
    "submit",
    "tel",
    "text",
    "time",
    "url",
    "week",
  ]);
  return interactableTypes.has(declaredType.toLowerCase().trim());
}
/**
 * Decides whether the scraper should treat `element` as something a user can interact with.
 *
 * Rejections come first (not visible, hidden/disabled, script/style); after that the
 * first matching positive rule wins: widget role, interactable <input> type, anchor
 * with href, button/select/option/textarea, label bound to an enabled control,
 * inline click handlers / contentEditable / jsaction, and finally div/img/span
 * elements rendered with a pointer cursor.
 *
 * @param {Element} element element to classify
 * @returns {boolean} true when the element is considered interactable
 */
function isInteractable(element) {
  if (!isElementVisible(element)) {
    return false;
  }
  if (isHiddenOrDisabled(element)) {
    return false;
  }
  if (isScriptOrStyle(element)) {
    return false;
  }
  if (hasWidgetRole(element)) {
    return true;
  }
  if (isInteractableInput(element)) {
    return true;
  }
  const tagName = element.tagName.toLowerCase();
  if (tagName === "a" && element.href) {
    return true;
  }
  if (
    tagName === "button" ||
    tagName === "select" ||
    tagName === "option" ||
    tagName === "textarea"
  ) {
    return true;
  }
  if (tagName === "label" && element.control && !element.control.disabled) {
    return true;
  }
  if (
    element.hasAttribute("onclick") ||
    element.isContentEditable ||
    element.hasAttribute("jsaction")
  ) {
    return true;
  }
  if (tagName === "div" || tagName === "img" || tagName === "span") {
    // Generic containers only count when styled as clickable. The previous extra
    // comparison against the value "cursor" was dropped: "cursor" is not a CSS
    // cursor keyword, so a computed style can never report it.
    return window.getComputedStyle(element).cursor === "pointer";
  }
  return false;
}
/**
 * Collapses every run of whitespace in `str` into a single space.
 * Falsy inputs (null, undefined, "") are returned unchanged.
 */
function removeMultipleSpaces(str) {
  return str ? str.replace(/\s+/g, " ") : str;
}
/**
 * Normalises scraped text: strips the SVG fallback message, collapses whitespace
 * runs to a single space and trims the ends.
 *
 * Every occurrence of the fallback message is removed. A string pattern passed to
 * String.prototype.replace only replaces the first match, so a global regular
 * expression is used instead.
 *
 * @param {string} text raw text content
 * @returns {string} cleaned text
 */
function cleanupText(text) {
  const withoutSvgFallback = text.replace(
    /SVGs not supported by this browser\./g,
    "",
  );
  return removeMultipleSpaces(withoutSvgFallback).trim();
}
/**
 * Collects (depth-first) the text under `element` that does not belong to a node
 * tagged with a `unique_id` attribute, joining the pieces with ";".
 *
 * - Direct text nodes are included only when `element` itself has no `unique_id`.
 * - Child elements carrying `unique_id` are skipped entirely; others are recursed into.
 * - When the joined result is longer than 1000 characters an empty string is returned.
 *
 * The join and the length check now run once after the loop instead of on every
 * iteration; the returned value is the same.
 *
 * @param {Element} element root of the subtree to read
 * @returns {string} ";"-joined context, or "" when there is none or it is too long
 */
function getElementContext(element) {
  const charLimit = 1000;
  // Loop-invariant: whether this element's own text nodes may contribute.
  const includeOwnText = !element.hasAttribute("unique_id");
  const contextParts = [];
  for (const child of element.childNodes) {
    let childContext = "";
    if (child.nodeType === Node.TEXT_NODE) {
      if (includeOwnText) {
        childContext = child.data.trim();
      }
    } else if (child.nodeType === Node.ELEMENT_NODE) {
      if (!child.hasAttribute("unique_id")) {
        childContext = getElementContext(child);
      }
    }
    if (childContext.length > 0) {
      contextParts.push(childContext);
    }
  }
  const fullContext = contextParts.join(";");
  return fullContext.length > charLimit ? "" : fullContext;
}
/**
 * Builds the text content of `element` by walking its child nodes depth-first and
 * joining the non-empty pieces with ";".
 *
 * The result is capped at 1000 characters. When the full text exceeds the cap, the
 * element's own direct text nodes (without descendants) are used instead if they fit;
 * otherwise an empty string is returned.
 *
 * Whitespace-only text nodes are no longer recorded in the direct-text list, so the
 * fallback no longer contains empty ";"-separated entries.
 *
 * @param {Element} element element to read
 * @returns {string} cleaned text for the element, possibly ""
 */
function getElementContent(element) {
  // Currently we don't support too much context. Character limit is 1000 per element.
  // we don't think element context has to be that big
  const charLimit = 1000;
  let textContent = element.textContent;
  let nodeContent = "";
  // if element has children, then build a list of text and join with a semicolon
  if (element.childNodes.length > 0) {
    const childTextContentList = [];
    const nodeTextContentList = [];
    for (const child of element.childNodes) {
      let childText = "";
      if (child.nodeType === Node.TEXT_NODE) {
        childText = child.data.trim();
        if (childText.length > 0) {
          nodeTextContentList.push(childText);
        }
      } else if (child.nodeType === Node.ELEMENT_NODE) {
        childText = getElementContent(child);
      } else {
        console.log("Unhandled node type: ", child.nodeType);
      }
      if (childText.length > 0) {
        childTextContentList.push(childText);
      }
    }
    textContent = childTextContentList.join(";");
    nodeContent = cleanupText(nodeTextContentList.join(";"));
  }
  let finalTextContent = cleanupText(textContent);
  if (finalTextContent.length > charLimit) {
    finalTextContent = nodeContent.length <= charLimit ? nodeContent : "";
  }
  return finalTextContent;
}
/**
 * Describes every option of a <select> element as
 * `{ optionIndex, text }`, with whitespace runs in the text collapsed.
 */
function getSelectOptions(element) {
  return Array.from(element.options).map((option) => ({
    optionIndex: option.index,
    text: removeMultipleSpaces(option.textContent),
  }));
}
/**
 * Walks the DOM from document.body, tags every interactable element with a
 * `unique_id` attribute and returns both a flat list and a tree of element records.
 *
 * Each record has the shape
 * `{ id, tagName, attributes, text, children, rect, options?, context? }`:
 * - `id` is the element's index in the flat list and the value of its `unique_id`.
 * - `children` holds the interactable descendants whose nearest interactable
 *   ancestor is this element.
 * - `options` is set for <select> elements only.
 * - `context` is set when a <label>/<fieldset> ancestor yields non-empty context text.
 *
 * Side effects on the page: existing `unique_id` attributes are cleared first,
 * and `target="_blank"` is removed from anchors.
 *
 * @returns {[object[], object[]]} `[elements, resultArray]` — the flat list and the
 *   root-level records of the tree
 */
function buildTreeFromBody() {
  const elements = [];
  const resultArray = [];
  // Attributes recorded as `true` whenever they are present, whatever their value.
  const presenceOnlyAttributes = new Set([
    "required",
    "aria-required",
    "checked",
    "aria-checked",
    "selected",
    "aria-selected",
    "readonly",
    "aria-readonly",
  ]);
  // Ancestor tags whose text is used as context for an element.
  const targetParentElements = new Set(["label", "fieldset"]);

  // Tags the DOM element and builds its serialisable record.
  function buildElementObject(element) {
    const elementId = elements.length;
    const elementTagNameLower = element.tagName.toLowerCase();
    element.setAttribute("unique_id", elementId);
    // if element is an "a" tag and has a target="_blank" attribute, remove the target attribute
    // We're doing this so that skyvern can do all the navigation in a single page/tab and not open new tab
    if (
      elementTagNameLower === "a" &&
      element.getAttribute("target") === "_blank"
    ) {
      element.removeAttribute("target");
    }
    const attrs = {};
    for (const attr of element.attributes) {
      attrs[attr.name] = presenceOnlyAttributes.has(attr.name)
        ? true
        : attr.value;
    }
    if (elementTagNameLower === "input" || elementTagNameLower === "textarea") {
      // The attribute only holds the initial value; the property holds the live one.
      attrs["value"] = element.value;
    }
    const elementObj = {
      id: elementId,
      tagName: elementTagNameLower,
      attributes: attrs,
      text: getElementContent(element),
      children: [],
      rect: DomUtils.getVisibleClientRect(element, true),
    };
    if (elementTagNameLower === "select") {
      elementObj.options = getSelectOptions(element);
    }
    return elementObj;
  }

  // Records `element` when interactable, then recurses into its element children.
  // `interactableParentId` is the id of the nearest interactable ancestor, or null.
  function processElement(element, interactableParentId) {
    let parentIdForChildren = interactableParentId;
    if (isInteractable(element)) {
      const elementObj = buildElementObject(element);
      elements.push(elementObj);
      if (interactableParentId === null) {
        // No interactable ancestor: this element starts a new tree.
        resultArray.push(elementObj);
      } else {
        elements[interactableParentId].children.push(elementObj);
      }
      parentIdForChildren = elementObj.id;
    }
    for (const child of Array.from(element.children)) {
      processElement(child, parentIdForChildren);
    }
  }

  // TODO: Handle iframes
  // Clear all the unique_id attributes so that there are no conflicts
  removeAllUniqueIdAttributes();
  processElement(document.body, null);

  for (const element of elements) {
    if (
      ((element.tagName === "input" && element.attributes["type"] === "text") ||
        element.tagName === "textarea") &&
      (element.attributes["required"] || element.attributes["aria-required"]) &&
      element.attributes.value === ""
    ) {
      // TODO (kerem): we may want to pass these elements to the LLM as empty but required fields in the future
      console.log(
        "input element with required attribute and no value",
        element,
      );
    }

    // Look up to 10 levels up for a "label" or "fieldset" ancestor. There is no early
    // exit, so when several ancestors match, the outermost one within range is used.
    let targetContextualParent = null;
    let parentEle = document.querySelector(`[unique_id="${element.id}"]`);
    for (let i = 0; i < 10; i++) {
      parentEle = parentEle.parentElement;
      if (!parentEle) {
        break;
      }
      if (targetParentElements.has(parentEle.tagName.toLowerCase())) {
        targetContextualParent = parentEle;
      }
    }
    if (!targetContextualParent) {
      continue;
    }

    let context = "";
    const lowerCaseTagName = targetContextualParent.tagName.toLowerCase();
    if (lowerCaseTagName === "label") {
      context = getElementContext(targetContextualParent);
    } else if (lowerCaseTagName === "fieldset") {
      // fieldset is usually within a form or another element that contains the whole context
      targetContextualParent = targetContextualParent.parentElement;
      if (targetContextualParent) {
        context = getElementContext(targetContextualParent);
      }
    }
    if (context.length > 0) {
      element.context = context;
    }
  }
  return [elements, resultArray];
}
/**
 * Draws bounding boxes on the page for the given element records: overlapping
 * elements are grouped, one marker is created per group, and the markers are
 * appended to the page.
 */
function drawBoundingBoxes(elements) {
  const visualGroups = groupElementsVisually(elements);
  addHintMarkersToPage(createHintMarkersForGroups(visualGroups));
}
/** Strips the `unique_id` attribute from every element in the document that has one. */
function removeAllUniqueIdAttributes() {
  for (const tagged of document.querySelectorAll("[unique_id]")) {
    tagged.removeAttribute("unique_id");
  }
}
/**
 * Logs that a captcha was solved and increments `window.captchaSolvedCounter`,
 * treating a missing or falsy counter as 0.
 */
function captchaSolvedCallback() {
  console.log("captcha solved");
  // For some reason this isn't being called.. TODO figure out why
  const solvedSoFar = window["captchaSolvedCounter"] || 0;
  window["captchaSolvedCounter"] = solvedSoFar + 1;
}
/**
 * Returns `window.captchaSolvedCounter`, initialising it to 0 first when it is
 * missing or falsy.
 */
function getCaptchaSolves() {
  const current = window["captchaSolvedCounter"];
  if (current) {
    return current;
  }
  window["captchaSolvedCounter"] = 0;
  return 0;
}
/**
 * Groups element records whose rects overlap.
 *
 * Elements without a rect are skipped. Each remaining element joins the first
 * existing group that contains a rect intersecting its own, or starts a new group.
 * Groups are never merged afterwards. Every group gets a `rect` enclosing all of
 * its members.
 *
 * @returns {{elements: object[], rect: object}[]} the groups, in creation order
 */
function groupElementsVisually(elements) {
  const groups = [];
  // O(n^2): every element is compared against the members of every existing group.
  // *** if we start from the bigger elements (top -> bottom) we can avoid merging groups
  for (const candidate of elements) {
    if (!candidate.rect) {
      continue;
    }
    const overlappingGroup = groups.find((existing) =>
      existing.elements.some((member) =>
        Rect.intersects(member.rect, candidate.rect),
      ),
    );
    if (overlappingGroup) {
      overlappingGroup.elements.push(candidate);
    } else {
      groups.push({ elements: [candidate] });
    }
  }
  // Give each group the rectangle that encompasses all of its members.
  groups.forEach((group) => {
    group.rect = createRectangleForGroup(group);
  });
  return groups;
}
/**
 * Returns the smallest Rect that encloses the rects of all elements in `group`.
 */
function createRectangleForGroup(group) {
  let top = Infinity;
  let left = Infinity;
  let bottom = -Infinity;
  let right = -Infinity;
  for (const member of group.elements) {
    const memberRect = member.rect;
    top = Math.min(top, memberRect.top);
    left = Math.min(left, memberRect.left);
    bottom = Math.max(bottom, memberRect.bottom);
    right = Math.max(right, memberRect.right);
  }
  return Rect.create(left, top, right, bottom);
}
/**
 * Generates `count` hint strings from the characters "sadfjklewcmpgh".
 *
 * Candidates are built breadth-first by prefixing each character to already
 * generated strings until enough unconsumed candidates exist; the selected
 * `count` strings are returned sorted alphabetically.
 */
function generateHintStrings(count) {
  const hintCharacters = "sadfjklewcmpgh";
  const candidates = [""];
  let consumed = 0;
  // The seed "" is always expanded once; afterwards expand until enough
  // unconsumed candidates are available.
  do {
    const suffix = candidates[consumed++];
    for (const ch of hintCharacters) {
      candidates.push(ch + suffix);
    }
  } while (candidates.length - consumed < count);
  const selected = candidates.slice(consumed, consumed + count);
  selected.sort();
  return selected;
}
/**
 * Creates one hint marker per group and labels each with a generated hint string
 * (stored on `hintString`, upper-cased into the marker element's HTML).
 * Returns an empty array, after logging, when there are no groups.
 */
function createHintMarkersForGroups(groups) {
  if (groups.length === 0) {
    console.log("No groups found, not adding hint markers to page.");
    return [];
  }
  const markers = groups.map((group) => createHintMarkerForGroup(group));
  // fill in marker text
  const labels = generateHintStrings(markers.length);
  markers.forEach((marker, index) => {
    marker.hintString = labels[index];
    marker.element.innerHTML = marker.hintString.toUpperCase();
  });
  return markers;
}
/**
 * Creates the marker for one visual group: a label element positioned at the
 * group's top-left corner and an absolutely positioned, non-interactive blue
 * bounding box covering the group's rect (offset by the current scroll position).
 *
 * Each call hands out the next z-index from a counter stored on this function, so
 * a group's label and box share a z-index and later groups stack above earlier
 * ones. The previous code read `this.currentZIndex`, which is never initialised in
 * a plain function, so no valid z-index was ever applied. It also assigned an
 * invalid `display` value and concatenated "px" strings into `bottom`/`right`;
 * those assignments were dropped because the browser ignores invalid values.
 *
 * @param {{rect: object}} group group produced by groupElementsVisually
 * @returns {{element: HTMLElement, boundingBox: HTMLElement, group: object}}
 */
function createHintMarkerForGroup(group) {
  const zIndex = createHintMarkerForGroup.nextZIndex ?? 1;
  createHintMarkerForGroup.nextZIndex = zIndex + 1;

  // yellow annotation box with string
  const el = document.createElement("div");
  el.style.left = group.rect.left + "px";
  el.style.top = group.rect.top + "px";
  el.style.zIndex = String(zIndex);

  // The bounding box around the group of hints.
  const boundingBox = document.createElement("div");
  // Calculate the position of the element relative to the document
  const scrollTop = window.pageYOffset || document.documentElement.scrollTop;
  const scrollLeft = window.pageXOffset || document.documentElement.scrollLeft;
  // Set styles for the bounding box
  boundingBox.style.position = "absolute";
  boundingBox.style.left = group.rect.left + scrollLeft + "px";
  boundingBox.style.top = group.rect.top + scrollTop + "px";
  boundingBox.style.width = group.rect.width + "px";
  boundingBox.style.height = group.rect.height + "px";
  boundingBox.style.border = "2px solid blue"; // Change the border color as needed
  boundingBox.style.pointerEvents = "none"; // Ensures the box doesn't interfere with other interactions
  boundingBox.style.zIndex = String(zIndex);

  return {
    element: el,
    boundingBox: boundingBox,
    group: group,
  };
}
/**
 * Appends a container with id "boundingBoxContainer" to the document root, holding
 * the bounding box of every marker. The markers' label elements are not added.
 */
function addHintMarkersToPage(hintMarkers) {
  const container = document.createElement("div");
  container.id = "boundingBoxContainer";
  hintMarkers.forEach((marker) => {
    // container.appendChild(marker.element);
    container.appendChild(marker.boundingBox);
  });
  document.documentElement.appendChild(container);
}
/** Removes the first "#boundingBoxContainer" element from the page, if one exists. */
function removeBoundingBoxes() {
  document.querySelector("#boundingBoxContainer")?.remove();
}
/**
 * Removes existing bounding boxes and scrolls the window to the top. When
 * `draw_boxes` is truthy the element tree is rebuilt and boxes are drawn again.
 * Returns the resulting `window.scrollY`.
 */
function scrollToTop(draw_boxes) {
  removeBoundingBoxes();
  window.scrollTo(0, 0);
  if (draw_boxes) {
    const [flatElements] = buildTreeFromBody();
    drawBoundingBoxes(flatElements);
  }
  return window.scrollY;
}
/**
 * Removes existing bounding boxes and scrolls down by one viewport height minus
 * 200px (so consecutive pages overlap). When `draw_boxes` is truthy the element
 * tree is rebuilt and boxes are drawn again.
 * Returns the resulting `window.scrollY`; callers can compare it with the previous
 * value to detect that the page did not move.
 */
function scrollToNextPage(draw_boxes) {
  removeBoundingBoxes();
  const pageStep = window.innerHeight - 200;
  window.scrollBy(0, pageStep);
  if (draw_boxes) {
    const [flatElements] = buildTreeFromBody();
    drawBoundingBoxes(flatElements);
  }
  return window.scrollY;
}

View File

@@ -0,0 +1,316 @@
import asyncio
import copy
import structlog
from playwright.async_api import Page
from pydantic import BaseModel
from skyvern.constants import SKYVERN_DIR, SKYVERN_ID_ATTR
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.browser_factory import BrowserState
LOG = structlog.get_logger()
# Attribute names that _trimmed_attributes keeps when building the trimmed element tree;
# any other attribute is dropped (except "id" on input/textarea/select elements).
# NOTE(review): "aria-role" is listed but the plain "role" attribute is not — confirm
# whether "role" should be preserved as well.
RESERVED_ATTRIBUTES = {
"accept", # for input file
"alt",
"aria-checked", # for option tag
"aria-current",
"aria-label",
"aria-required",
"aria-role",
"aria-selected", # for option tag
"checked",
"data-ui",
"for",
"href", # For a tags
"maxlength",
"name",
"pattern",
"placeholder",
"readonly",
"required",
"selected", # for option tag
"src", # do we need this?
"text-value",
"title",
"type",
"value",
}
def load_js_script() -> str:
    """
    Read the domUtils.js helper script from the Skyvern package directory.

    The file is decoded as UTF-8 explicitly so the result does not depend on the
    platform's default encoding.

    :return: The JavaScript source as a string.
    :raises FileNotFoundError: When the script is missing; the failure is logged first.
    """
    # TODO: Handle file location better. This is a hacky way to find the file location.
    path = f"{SKYVERN_DIR}/webeye/scraper/domUtils.js"
    try:
        # TODO: Implement TS of domUtils.js and use the complied JS file instead of the raw JS file.
        # This will allow our code to be type safe.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        LOG.exception("Failed to load the JS script", exc_info=True, path=path)
        # Bare raise re-raises the active exception with its traceback untouched.
        raise
JS_FUNCTION_DEFS = load_js_script()
class ScrapedPage(BaseModel):
    """
    Scraped response from a webpage, including:
    1. elements: flat list of the scraped element dicts
    2. id_to_xpath_dict: element ID to xpath map
    3. element_tree: the element tree of the page (list of dicts). Each element has children and attributes.
    4. element_tree_trimmed: a copy of the element tree reduced by trim_element_tree
    5. screenshots: the screenshots of the page, as raw bytes
    6. url: the URL of the page
    7. html: the HTML of the page
    8. extracted_text: the extracted text from the page (optional)
    """

    elements: list[dict]
    id_to_xpath_dict: dict[int, str]
    element_tree: list[dict]
    element_tree_trimmed: list[dict]
    screenshots: list[bytes]
    url: str
    html: str
    extracted_text: str | None = None
async def scrape_website(
    browser_state: BrowserState,
    url: str,
    num_retry: int = 0,
) -> ScrapedPage:
    """
    ************************************************************************************************
    ************ NOTE: MAX_SCRAPING_RETRIES is set to 0 in both staging and production *************
    ************************************************************************************************

    Scrape a web page with retries. Delegates to scrape_web_unsafe and, when that raises, retries
    until the number of attempts exceeds the MAX_SCRAPING_RETRIES setting.

    :param browser_state: BrowserState used to obtain the page to scrape.
    :param url: URL of the web page to be scraped.
    :param num_retry: Number of attempts already made, defaults to 0. Used by the recursive retry.

    :return: The ScrapedPage produced by scrape_web_unsafe.
    :raises Exception: When scraping fails after maximum retries. The last underlying error is
        attached as the exception's cause.
    """
    try:
        num_retry += 1
        return await scrape_web_unsafe(browser_state, url)
    except Exception as e:
        # NOTE: MAX_SCRAPING_RETRIES is set to 0 in both staging and production
        max_retries = SettingsManager.get_settings().MAX_SCRAPING_RETRIES
        if num_retry > max_retries:
            LOG.error(
                "Scraping failed after max retries, aborting.",
                max_retries=max_retries,
                url=url,
                exc_info=True,
            )
            raise Exception("Scraping failed.") from e
        LOG.info("Scraping failed, will retry", num_retry=num_retry, url=url)
        return await scrape_website(
            browser_state,
            url,
            num_retry=num_retry,
        )
async def get_all_visible_text(page: Page) -> str:
    """
    Return the `innerText` of the page's body.

    :param page: Page instance to read the text from.
    :return: The text reported by `document.body.innerText`.
    """
    visible_text = await page.evaluate("() => document.body.innerText")
    return visible_text
async def scrape_web_unsafe(
    browser_state: BrowserState,
    url: str,
) -> ScrapedPage:
    """
    Scrape the current page without any built-in error handling; exceptions propagate to the caller.

    Steps: wait 5 seconds, screenshot the page viewport by viewport with bounding boxes drawn,
    remove the boxes, scroll back to the top, then collect the interactable element tree, the
    visible text and the HTML.

    :param browser_state: BrowserState used to get the existing page or create a new one.
    :param url: URL of the web page to be scraped. Used only when creating a new page.

    :return: ScrapedPage with elements, xpath map, element trees, screenshots, URL, HTML and text.

    :note: This function does not handle exceptions. Ensure proper error handling in the calling context.
    """
    # An existing page is reused so that we keep working on the page actions were taken on.
    # This also means URL is only used when creating a new page, and not when using an existing page.
    page = await browser_state.get_or_create_page(url)

    LOG.info("Waiting for 5 seconds before scraping the website.")
    await asyncio.sleep(5)

    # Screenshot from the top, one viewport at a time, with bounding boxes drawn.
    # An unchanged scroll offset means the end of the page was reached. This also covers pages that
    # cannot scroll because of a popup (e.g. geico first popup on the homepage after clicking
    # start my quote).
    screenshots: list[bytes] = []
    previous_scroll_y = -1.0
    current_scroll_y = await scroll_to_top(page, drow_boxes=True)
    while True:
        if previous_scroll_y == current_scroll_y:
            break
        # Cap the number of screenshots to prevent an infinite loop
        if len(screenshots) >= SettingsManager.get_settings().MAX_NUM_SCREENSHOTS:
            break
        screenshots.append(await page.screenshot(full_page=False))
        previous_scroll_y = current_scroll_y
        LOG.info("Scrolling to next page", url=url, num_screenshots=len(screenshots))
        current_scroll_y = await scroll_to_next_page(page, drow_boxes=True)
        LOG.info("Scrolled to next page", scroll_y_px=current_scroll_y, scroll_y_px_old=previous_scroll_y)

    await remove_bounding_boxes(page)
    await scroll_to_top(page, drow_boxes=False)

    elements, raw_element_tree = await get_interactable_element_tree(page)
    element_tree = cleanup_elements(copy.deepcopy(raw_element_tree))

    # get_interactable_element_tree marks each interactable element with a unique_id attribute
    id_to_xpath_dict = {
        element["id"]: f"//*[@{SKYVERN_ID_ATTR}='{element['id']}']" for element in elements
    }

    text_content = await get_all_visible_text(page)
    trimmed_tree = trim_element_tree(copy.deepcopy(element_tree))
    page_html = await page.content()
    return ScrapedPage(
        elements=elements,
        id_to_xpath_dict=id_to_xpath_dict,
        element_tree=element_tree,
        element_tree_trimmed=trimmed_tree,
        screenshots=screenshots,
        url=page.url,
        html=page_html,
        extracted_text=text_content,
    )
async def get_interactable_element_tree(page: Page) -> tuple[list[dict], list[dict]]:
    """
    Inject the DOM helper script into the page and build the interactable element tree.

    :param page: Page instance to get the element tree from.
    :return: Tuple of (flat list of interactable elements, element tree).
    """
    await page.evaluate(JS_FUNCTION_DEFS)
    flat_elements, tree = await page.evaluate("() => buildTreeFromBody()")
    return flat_elements, tree
async def scroll_to_top(page: Page, drow_boxes: bool) -> float:
    """
    Scroll to the top of the page, optionally redrawing the bounding boxes.

    :param page: Page instance to scroll.
    :param drow_boxes: If True, draw bounding boxes around the elements.
    :return: The vertical scroll offset reported by the page after scrolling.
    """
    await page.evaluate(JS_FUNCTION_DEFS)
    draw_flag = str(drow_boxes).lower()
    return await page.evaluate(f"() => scrollToTop({draw_flag})")
async def scroll_to_next_page(page: Page, drow_boxes: bool) -> float:
    """
    Scroll down by one page, optionally redrawing the bounding boxes.

    The return annotation was `bool`, but the value returned is the scroll offset produced by
    the page's `scrollToNextPage` helper, matching `scroll_to_top`.

    :param page: Page instance to scroll.
    :param drow_boxes: If True, draw bounding boxes around the elements.
    :return: The vertical scroll offset reported by the page after scrolling.
    """
    await page.evaluate(JS_FUNCTION_DEFS)
    js_script = f"() => scrollToNextPage({str(drow_boxes).lower()})"
    scroll_y_px = await page.evaluate(js_script)
    return scroll_y_px
async def remove_bounding_boxes(page: Page) -> None:
    """
    Remove the bounding boxes from the page by calling the page's `removeBoundingBoxes` helper.

    :param page: Page instance to remove the bounding boxes from.
    """
    await page.evaluate("() => removeBoundingBoxes()")
def cleanup_elements(elements: list[dict]) -> list[dict]:
    """
    Remove the "rect" entry from every element in the tree, in place.

    The reason we're doing it is to
    1. reduce unnecessary data so that the LLM gets less distraction
    # TODO later: 2. reduce tokens sent to llm to save money

    Each element is handled independently, so the traversal order does not matter; a stack is
    used because popping from the front of a list is linear per pop.

    :param elements: Root elements of the tree; nested elements live under "children".
    :return: The same list, with "rect" removed from every element.
    """
    pending = list(elements)
    while pending:
        element = pending.pop()
        _remove_rect(element)
        # TODO: we can come back to test removing the unique_id
        # from element attributes to make sure this won't increase hallucination
        # _remove_unique_id(element)
        if "children" in element:
            pending.extend(element["children"])
    return elements
def trim_element_tree(elements: list[dict]) -> list[dict]:
    """
    Reduce every element in the tree to the data worth keeping, in place.

    For each element:
    - "attributes" is filtered through _trimmed_attributes and removed when nothing is left;
    - an empty "children" list is removed;
    - "text" is removed when it is empty or whitespace-only.

    Each element is handled independently, so the traversal order does not matter; a stack is
    used because popping from the front of a list is linear per pop.

    :param elements: Root elements of the tree; nested elements live under "children".
    :return: The same list, trimmed.
    """
    pending = list(elements)
    while pending:
        element = pending.pop()
        if "attributes" in element:
            tag_name = element.get("tagName", "")
            new_attributes = _trimmed_attributes(tag_name, element["attributes"])
            if new_attributes:
                element["attributes"] = new_attributes
            else:
                del element["attributes"]
        if "children" in element:
            pending.extend(element["children"])
            if not element["children"]:
                del element["children"]
        if "text" in element and not str(element["text"]).strip():
            del element["text"]
    return elements
def _trimmed_attributes(tag_name: str, attributes: dict) -> dict:
    """
    Return the subset of `attributes` worth keeping for an element.

    An attribute is kept when its name is in RESERVED_ATTRIBUTES. "id" is also kept for
    input, textarea and select elements, in case there's a label for it.
    """
    keeps_id = tag_name in ["input", "textarea", "select"]
    return {
        name: value
        for name, value in attributes.items()
        if name in RESERVED_ATTRIBUTES or (name == "id" and keeps_id)
    }
def _remove_rect(element: dict) -> None:
    """Drop the "rect" entry from `element` in place; a missing entry is ignored."""
    element.pop("rect", None)
def _remove_unique_id(element: dict) -> None:
    """Drop the SKYVERN_ID_ATTR entry from the element's "attributes" in place, when present."""
    if "attributes" in element and SKYVERN_ID_ATTR in element["attributes"]:
        del element["attributes"][SKYVERN_ID_ATTR]

View File

@@ -0,0 +1,5 @@
import re
def remove_whitespace(string: str) -> str:
    """Collapse every run of spaces, newlines and tabs in `string` into a single space."""
    whitespace_run = re.compile("[ \n\t]+")
    return whitespace_run.sub(" ", string)

View File

View File

@@ -0,0 +1,74 @@
import json
from typing import Any
import requests
from skyvern.forge.sdk.schemas.tasks import TaskRequest
class SkyvernClient:
    """Small HTTP client for the Skyvern task API, authenticated with an `x-api-key` header."""

    def __init__(self, base_url: str, credentials: str):
        """
        :param base_url: Base URL that the endpoint paths are appended to.
        :param credentials: Value sent in the `x-api-key` header.
        """
        self.base_url = base_url
        self.credentials = credentials

    def _auth_headers(self) -> dict[str, str]:
        """Return the authentication header sent with every request."""
        return {"x-api-key": self.credentials}

    def create_task(self, task_request_body: TaskRequest) -> str | None:
        """
        Create a task.

        :param task_request_body: The task request; it is dumped and sent as a JSON body.
        :return: The `task_id` from the response, or None when the response has none.
        """
        url = f"{self.base_url}/tasks"
        payload = task_request_body.model_dump()
        headers = {
            "Content-Type": "application/json",
            **self._auth_headers(),
        }
        response = requests.post(url, headers=headers, data=json.dumps(payload))
        # Parse the body once instead of once for the check and once for the lookup.
        body = response.json()
        if "task_id" not in body:
            return None
        return body["task_id"]

    def get_task(self, task_id: str) -> dict[str, Any] | None:
        """Get a task by id. Returns None when the response status is not 200."""
        url = f"{self.base_url}/internal/tasks/{task_id}"
        response = requests.get(url, headers=self._auth_headers())
        if response.status_code != 200:
            return None
        return response.json()

    def get_agent_tasks(self, page: int = 1, page_size: int = 15) -> dict[str, Any]:
        """Get all tasks with pagination."""
        url = f"{self.base_url}/internal/tasks"
        params = {"page": page, "page_size": page_size}
        response = requests.get(url, params=params, headers=self._auth_headers())
        return response.json()

    def get_agent_task_steps(self, task_id: str, page: int = 1, page_size: int = 15) -> list[dict[str, Any]]:
        """
        Get all steps for a task with pagination.

        Each step's `output["actions_and_results"]` is replaced by its JSON string. Steps with
        no output, or without that key, are returned unchanged instead of raising.
        """
        url = f"{self.base_url}/tasks/{task_id}/steps"
        params = {"page": page, "page_size": page_size}
        response = requests.get(url, params=params, headers=self._auth_headers())
        steps = response.json()
        for step in steps:
            output = step.get("output")
            if output and "actions_and_results" in output:
                output["actions_and_results"] = json.dumps(output["actions_and_results"])
        return steps

    def get_agent_task_video_artifact(self, task_id: str) -> dict[str, Any] | None:
        """
        Get the video artifact from the first step of the task.

        :return: The first artifact of type "recording" on the first step, or None when the task
            has no steps or no such artifact.
        """
        steps = self.get_agent_task_steps(task_id)
        if not steps:
            return None
        first_step_id = steps[0]["step_id"]
        artifacts = self.get_agent_artifacts(task_id, first_step_id)
        for artifact in artifacts:
            if artifact["artifact_type"] == "recording":
                return artifact
        return None

    def get_agent_artifacts(self, task_id: str, step_id: str) -> list[dict[str, Any]]:
        """Get all artifacts for a single step of a task."""
        url = f"{self.base_url}/tasks/{task_id}/steps/{step_id}/artifacts"
        response = requests.get(url, headers=self._auth_headers())
        return response.json()

View File

@@ -0,0 +1,69 @@
import asyncio
import random
import string
import typing
from typing import Any, Callable
from PIL import Image
from skyvern.forge.sdk.api.aws import AsyncAWSClient
async_s3_client = AsyncAWSClient()
def read_artifact(uri: str, is_image: bool = False, is_webm: bool = False) -> Image.Image | str | bytes:
    """Load the content of an artifact from an ``s3://`` or ``file://`` URI.

    Args:
        uri: Location of the artifact.
        is_image: Treat the artifact as an image.
        is_webm: Treat the artifact as a binary recording.

    Returns:
        For ``s3://`` URIs: the raw downloaded bytes for images and
        recordings, otherwise the bytes decoded as UTF-8 text.
        For ``file://`` URIs: a fully loaded ``PIL.Image.Image`` for images,
        raw bytes for recordings, otherwise the file's text decoded as UTF-8.

    Raises:
        ValueError: If the URI uses any other scheme.
    """
    if uri.startswith("s3://"):
        downloaded_bytes = asyncio.run(async_s3_client.download_file(uri))
        if is_image or is_webm:
            return downloaded_bytes
        return downloaded_bytes.decode("utf-8")

    if uri.startswith("file://"):
        # Means it's a local file
        path = uri.removeprefix("file://")
        if is_image:
            with open(path, "rb") as f:
                image = Image.open(f)
                # Force the pixel data to be read while the file is still open.
                image.load()
                return image
        if is_webm:
            with open(path, "rb") as f:
                return f.read()
        # Decode explicitly as UTF-8 so local artifacts are read the same way
        # as S3 ones instead of depending on the platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()

    raise ValueError(f"Unsupported URI: {uri}")
def read_artifact_safe(uri: str, is_image: bool = False, is_webm: bool = False) -> Image.Image | str | bytes:
    """Best-effort variant of ``read_artifact`` that never raises.

    Any exception raised while loading is turned into the string
    ``"Failed to load artifact: <error>"``, which is returned in place of the
    artifact content.
    """
    try:
        artifact = read_artifact(uri, is_image=is_image, is_webm=is_webm)
    except Exception as exc:
        return f"Failed to load artifact: {exc}"
    return artifact
def streamlit_content_safe(st_obj: Any, f: Callable, content: bytes, message: str, **kwargs: dict[str, Any]) -> None:
    """Render ``content`` with ``f``, falling back to a plain message.

    ``f(content, **kwargs)`` is called when ``content`` is truthy. When it is
    falsy, or when anything in the attempt raises, ``message`` is written to
    ``st_obj`` instead.
    """
    try:
        if not content:
            st_obj.write(message)
            return
        f(content, **kwargs)
    except Exception:
        st_obj.write(message)
@typing.no_type_check
def streamlit_show_recording(st_obj: Any, uri: str) -> None:
    """Show a task recording with a download button inside ``st_obj``.

    Args:
        st_obj: Streamlit container (e.g. a tab) to render into.
        uri: URI of the recording artifact.

    When the recording cannot be loaded or is empty, no download button is
    shown and "No recording available." is written instead.
    """
    loaded = read_artifact_safe(uri, is_webm=True)
    # read_artifact_safe reports failures as a (truthy) error string. Only
    # binary content is a real recording; anything else is treated as missing
    # so the error text is never offered as a .webm download or handed to
    # st.video.
    recording = loaded if isinstance(loaded, (bytes, bytearray)) else b""
    if recording:
        # Random widget key so several download buttons can coexist on a page.
        random_key = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
        st_obj.download_button("Download recording", recording, f"recording{uri.split('/')[-1]}.webm", key=random_key)
    streamlit_content_safe(st_obj, st_obj.video, recording, "No recording available.", format="video/webm", start_time=0)

View File

@@ -0,0 +1,30 @@
from typing import Any, Optional
from streamlit_app.visualizer.api import SkyvernClient
class TaskRepository:
    """Thin data-access layer that delegates every lookup to a ``SkyvernClient``."""

    def __init__(self, client: SkyvernClient):
        self.client = client

    def get_task(self, task_id: str) -> dict[str, Any] | None:
        """Return the task with the given id, as provided by the client."""
        return self.client.get_task(task_id)

    def get_tasks(self, page: int = 1, page_size: int = 15) -> dict[str, Any]:
        """Return one page of tasks."""
        return self.client.get_agent_tasks(page=page, page_size=page_size)

    def get_task_steps(self, task_id: str) -> list[dict[str, Any]]:
        """Return the steps of a task, using the client's default paging."""
        return self.client.get_agent_task_steps(task_id)

    def get_artifacts(self, task_id: str, step_id: str) -> list[dict[str, Any]]:
        """Return the artifacts of one step of a task."""
        return self.client.get_agent_artifacts(task_id, step_id)

    def get_task_recording_uri(self, task: dict[str, Any]) -> Optional[str]:
        """Return the URI of the task's recording artifact, or ``None`` if it has none."""
        recording = self.client.get_agent_task_video_artifact(task["task_id"])
        return recording["uri"] if recording is not None else None

View File

@@ -0,0 +1,185 @@
import json
def get_sample_url() -> str:
    """Return the sample starting URL used to pre-fill the task form."""
    sample_url = "https://www.geico.com"
    return sample_url
def get_sample_navigation_goal() -> str:
    """Return the sample navigation goal used to pre-fill the task form."""
    sentences = (
        "Navigate through the website until you generate an auto insurance quote.",
        "Do not generate a home insurance quote.",
        "If this page contains an auto insurance quote, consider the goal achieved",
    )
    return " ".join(sentences)
def get_sample_data_extraction_goal() -> str:
    """Return the sample data-extraction goal used to pre-fill the task form."""
    sample_goal = (
        "Extract all quote information in JSON format "
        "including the premium amount, the timeframe for the quote."
    )
    return sample_goal
def get_sample_navigation_payload() -> str:
    """Return the sample navigation payload serialized as a JSON string.

    The payload describes a fictional insurance applicant: personal and spouse
    details, one vehicle with its desired coverages, home ownership and an
    address. The nested parts are built first and then placed at fixed
    positions, because insertion order determines the key order of the
    resulting JSON text.
    """
    desired_coverages = {
        "bodily_injury_per_incident_limit": 50000,
        "bodily_injury_per_person_limit": 25000,
        "collision_deductible": 1000,
        "comprehensive_deductible": 1000,
        "personal_injury_protection": None,
        "property_damage_per_incident_limit": None,
        "property_damage_per_person_limit": 25000,
        "rental_reimbursement_per_incident_limit": None,
        "rental_reimbursement_per_person_limit": None,
        "roadside_assistance_limit": None,
        "underinsured_motorist_bodily_injury_per_incident_limit": 50000,
        "underinsured_motorist_bodily_injury_per_person_limit": 25000,
        "underinsured_motorist_property_limit": None,
    }
    car = {
        "style": "AWD 3.0 quattro TDI 4dr Sedan",
        "model": "A8 L",
        "price_estimate": 29084,
        "year": 2015,
        "make": "Audi",
    }
    vehicle_entry = {
        "annual_mileage": 10000,
        "commute_mileage": 4000,
        "existing_coverages": None,
        "ideal_coverages": desired_coverages,
        "ownership": "Owned",
        "parked": "Garage",
        "purpose": "commute",
        "vehicle": car,
        "vehicle_id": None,
        "vin": None,
    }
    home_entry = {
        "home_ownership": "owned",
    }
    mailing_address = {
        "city": "HOUSTON",
        "country": "US",
        "state": "TX",
        "street": "9625 GARFIELD AVE.",
        "zip": "77082",
    }
    applicant = {
        "licensed_at_age": 19,
        "education_level": "HIGH_SCHOOL",
        "phone_number": "8042221111",
        "full_name": "Chris P. Bacon",
        "past_claim": [],
        "has_claims": False,
        "spouse_occupation": "Florist",
        "auto_current_carrier": "None",
        "home_commercial_uses": None,
        "spouse_full_name": "Amy Stake",
        "auto_commercial_uses": None,
        "requires_sr22": False,
        "previous_address_move_date": None,
        "line_of_work": None,
        "spouse_age": "1987-12-12",
        "auto_insurance_deadline": None,
        "email": "chris.p.bacon@abc.com",
        "net_worth_numeric": 1000000,
        "spouse_gender": "F",
        "marital_status": "married",
        "spouse_licensed_at_age": 20,
        "license_number": "AAAAAAA090AA",
        "spouse_license_number": "AAAAAAA080AA",
        "how_much_can_you_lose": 25000,
        "vehicles": [vehicle_entry],
        "additional_drivers": [],
        "home": [home_entry],
        "spouse_line_of_work": "Agriculture, Forestry and Fishing",
        "occupation": "Customer Service Representative",
        "id": None,
        "gender": "M",
        "credit_check_authorized": False,
        "age": "1987-11-11",
        "license_state": "Washington",
        "cash_on_hand": "$1000014999",
        "address": mailing_address,
        "spouse_education_level": "MASTERS",
        "spouse_email": "amy.stake@abc.com",
        "spouse_added_to_auto_policy": True,
    }
    return json.dumps(applicant)
def get_sample_extracted_information_schema() -> str:
    """Return the sample extraction schema serialized as a JSON string.

    The schema is a JSON-Schema-style object with a single ``quotes`` array.
    Each quote carries its coverages, premium amount, quote number, timeframe
    and per-vehicle coverages. The ``description`` texts are part of the
    emitted schema, so they are runtime data rather than comments.
    """
    extracted_information_schema = {
        "additionalProperties": False,
        "properties": {
            "quotes": {
                "items": {
                    "additionalProperties": False,
                    "properties": {
                        "coverages": {
                            "items": {
                                "additionalProperties": False,
                                "properties": {
                                    "amount": {
                                        "description": "The coverage amount in USD, which can be a single value or a range (e.g., '$300,000' or '$300,000/$300,000').",
                                        "type": "string",
                                    },
                                    "included": {
                                        "description": "Indicates whether the coverage is included in the policy (true or False).",
                                        "type": "boolean",
                                    },
                                    "type": {
                                        "description": "The limit of the coverage (e.g., 'bodily_injury_limit', 'property_damage_limit', 'underinsured_motorist_bodily_injury_limit').\nTranslate the english name of the coverage to snake case values in the following list:\n * bodily_injury_limit\n * property_damage_limit\n * underinsured_motorist_bodily_injury_limit\n * personal_injury_protection\n * accidental_death\n * work_loss_exclusion\n",
                                        "type": "string",
                                    },
                                },
                                "type": "object",
                            },
                            "type": "array",
                        },
                        "premium_amount": {
                            "description": "The total premium amount for the whole quote timeframe in USD, formatted as a string (e.g., '$321.57').",
                            "type": "string",
                        },
                        "quote_number": {
                            "description": "The quote number generated by the carrier that identifies this quote",
                            "type": "string",
                        },
                        "timeframe": {
                            "description": "The duration of the coverage, typically expressed in months or years.",
                            "type": "string",
                        },
                        "vehicle_coverages": {
                            "items": {
                                "additionalProperties": False,
                                "properties": {
                                    "collision_deductible": {
                                        "description": "The collision deductible amount in USD, which is a single value (e.g., '$500') or null if it is not included",
                                        "type": "string",
                                    },
                                    "comprehensive_deductible": {
                                        # Previously a copy of the collision text, which described
                                        # the wrong field to whoever consumes the schema.
                                        "description": "The comprehensive deductible amount in USD, which is a single value (e.g., '$500') or null if it is not included",
                                        "type": "string",
                                    },
                                    "for_vehicle": {
                                        "additionalProperties": False,
                                        "description": "The vehicle that the collision and comprehensive coverage is for",
                                        "properties": {
                                            "make": {"description": "The make of the vehicle", "type": "string"},
                                            "model": {"description": "The model of the vehicle", "type": "string"},
                                            "year": {"description": "The year of the vehicle", "type": "string"},
                                        },
                                        "type": "object",
                                    },
                                    "underinsured_property_damage": {
                                        "description": "The underinsured property damage limit for this vehicle, which is a limit and a deductible (e.g., '$25,000/$250 deductible') or null if it is not included",
                                        "type": "string",
                                    },
                                },
                                "type": "object",
                            },
                            "type": "array",
                        },
                    },
                    "type": "object",
                },
                "type": "array",
            }
        },
        "type": "object",
    }
    return json.dumps(extracted_information_schema)

View File

@@ -0,0 +1,383 @@
import pandas as pd
import streamlit as st
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, TaskRequest
from streamlit_app.visualizer import styles
from streamlit_app.visualizer.api import SkyvernClient
from streamlit_app.visualizer.artifact_loader import (
read_artifact_safe,
streamlit_content_safe,
streamlit_show_recording,
)
from streamlit_app.visualizer.repository import TaskRepository
from streamlit_app.visualizer.sample_data import (
get_sample_data_extraction_goal,
get_sample_extracted_information_schema,
get_sample_navigation_goal,
get_sample_navigation_payload,
get_sample_url,
)
# Streamlit UI Configuration
# NOTE(review): Streamlit expects set_page_config to be the first Streamlit
# command on a page, so keep it ahead of the st.markdown calls below.
st.set_page_config(layout="wide")
# Apply styles
# The style constants are raw <style> blocks, hence unsafe_allow_html=True.
st.markdown(styles.page_font_style, unsafe_allow_html=True)
st.markdown(styles.button_style, unsafe_allow_html=True)
# Configuration
def reset_session_state() -> None:
    """Remove every entry from ``st.session_state``.

    Used as the ``on_change`` callback of the environment and organization
    select boxes, so the client, repository and selections are rebuilt for
    the newly chosen settings.
    """
    # Delete all the items in Session state when env or org is changed
    stale_keys = list(st.session_state.keys())
    for stale_key in stale_keys:
        del st.session_state[stale_key]
# Environment/organization configuration comes from the "skyvern" section of
# the Streamlit secrets.
CONFIGS_DICT = st.secrets["skyvern"]["configs"]
if not CONFIGS_DICT:
    raise Exception("No configuration found. Copy the values from 1P and restart the app.")
# env name -> {"host": API host, "orgs": {org name -> credential}}
SETTINGS = {
    entry["env"]: {
        "host": entry["host"],
        "orgs": {org["name"]: org["cred"] for org in entry["orgs"]},
    }
    for entry in CONFIGS_DICT
}
st.sidebar.markdown("#### **Settings**")
# Changing either select box clears the session state, so the client is
# recreated for the new environment/organization.
environment_names = list(SETTINGS.keys())
select_env = st.sidebar.selectbox("Environment", environment_names, on_change=reset_session_state)
organization_names = list(SETTINGS[select_env]["orgs"].keys())
select_org = st.sidebar.selectbox("Organization", organization_names, on_change=reset_session_state)
# Initialize session state
# Each entry is created only when missing, so values survive Streamlit reruns
# until reset_session_state clears them.
session = st.session_state
if "client" not in session:
    session.client = SkyvernClient(
        base_url=SETTINGS[select_env]["host"],
        credentials=SETTINGS[select_env]["orgs"][select_org],
    )
if "repository" not in session:
    session.repository = TaskRepository(session.client)
if "task_page_number" not in session:
    session.task_page_number = 1
if "selected_task" not in session:
    session.selected_task = None
    session.selected_task_recording_uri = None
    session.task_steps = None
if "selected_step" not in session:
    session.selected_step = None
    session.selected_step_index = None

# Snapshot of the session values for this script run; the handlers and the UI
# code below read these module-level names.
client = session.client
repository = session.repository
task_page_number = session.task_page_number
selected_task = session.selected_task
selected_task_recording_uri = session.selected_task_recording_uri
task_steps = session.task_steps
selected_step = session.selected_step
selected_step_index = session.selected_step_index
# Onclick handlers
def select_task(task: dict) -> None:
    """Make ``task`` the selected task and load its recording URI and steps.

    The step selection is cleared first; when the task has steps, the first
    one becomes the selected step (index 0).
    """
    session = st.session_state
    session.selected_task = task
    session.selected_task_recording_uri = repository.get_task_recording_uri(task)
    # reset step selection
    session.selected_step = None
    # save task's steps in session state
    steps = repository.get_task_steps(task["task_id"])
    session.task_steps = steps
    if steps:
        session.selected_step = steps[0]
        session.selected_step_index = 0
def go_to_previous_step() -> None:
    """Select the step before the current one, staying on the first step at the start."""
    previous_index = selected_step_index - 1
    if previous_index < 0:
        previous_index = 0
    select_step(task_steps[previous_index])
def go_to_next_step() -> None:
    """Select the step after the current one, staying on the last step at the end."""
    last_index = len(task_steps) - 1
    next_index = selected_step_index + 1
    if next_index > last_index:
        next_index = last_index
    select_step(task_steps[next_index])
def select_step(step: dict) -> None:
    """Store ``step`` and its position within ``task_steps`` as the current selection."""
    session = st.session_state
    session.selected_step = step
    session.selected_step_index = task_steps.index(step)
# Streamlit UI Logic
st.markdown("# **:dragon: Skyvern :dragon:**")
st.markdown(f"### **{select_env} - {select_org}**")
execute_tab, visualizer_tab = st.tabs(["Execute", "Visualizer"])

# "Execute" tab: a form that creates a task (left) next to a short explanation
# of every field (right). Fields are pre-filled with the sample data.
with execute_tab:
    create_column, explanation_column = st.columns([1, 2])
    with create_column:
        with st.form("task_form"):
            st.markdown("## Run a task")
            # Create all the fields to create a TaskRequest object
            st_url = st.text_input("URL*", value=get_sample_url(), key="url")
            st_webhook_callback_url = st.text_input("Webhook Callback URL", key="webhook", placeholder="Optional")
            st_navigation_goal = st.text_input(
                "Navigation Goal",
                key="nav_goal",
                placeholder="Describe the navigation goal",
                value=get_sample_navigation_goal(),
            )
            st_data_extraction_goal = st.text_input(
                "Data Extraction Goal",
                key="data_goal",
                placeholder="Describe the data extraction goal",
                value=get_sample_data_extraction_goal(),
            )
            st_navigation_payload = st.text_area(
                "Navigation Payload JSON",
                key="nav_payload",
                placeholder='{"name": "John Doe", "email": "abc@123.com"}',
                value=get_sample_navigation_payload(),
            )
            st_extracted_information_schema = st.text_area(
                "Extracted Information Schema",
                key="extracted_info_schema",
                placeholder='{"quote_price": "float"}',
                value=get_sample_extracted_information_schema(),
            )
            # Submit the form
            if st.form_submit_button("Execute Task", use_container_width=True):
                # Build the request only after the submit button has been
                # rendered and clicked. Building it on every rerun ahead of
                # the button meant one rejected value raised before the
                # button existed, leaving a form that could not be resubmitted.
                # NOTE(review): assumes TaskRequest reports invalid input with
                # a ValueError subclass (as pydantic models do) - confirm.
                try:
                    # Create a TaskRequest object from the form fields
                    task_request_body = TaskRequest(
                        url=st_url,
                        webhook_callback_url=st_webhook_callback_url,
                        navigation_goal=st_navigation_goal,
                        data_extraction_goal=st_data_extraction_goal,
                        proxy_location=ProxyLocation.NONE,
                        navigation_payload=st_navigation_payload,
                        extracted_information_schema=st_extracted_information_schema,
                    )
                except ValueError as validation_error:
                    st.error(f"Invalid task request: {validation_error}")
                else:
                    # Call the API to create a task
                    task_id = client.create_task(task_request_body)
                    if not task_id:
                        st.error("Failed to create task!")
                    else:
                        st.success("Task created successfully, task_id: " + task_id)

    with explanation_column:
        st.markdown("### **Task Request**")
        st.markdown("#### **URL**")
        st.markdown("The starting URL for the task.")
        st.markdown("#### **Webhook Callback URL**")
        st.markdown("The URL to call with the results when the task is completed.")
        st.markdown("#### **Navigation Goal**")
        st.markdown("The user's goal for the task. Nullable if the task is only for data extraction.")
        st.markdown("#### **Data Extraction Goal**")
        st.markdown("The user's goal for data extraction. Nullable if the task is only for navigation.")
        st.markdown("#### **Navigation Payload**")
        st.markdown("The user's details needed to achieve the task. AI will use this information as needed.")
        st.markdown("#### **Extracted Information Schema**")
        st.markdown("The requested schema of the extracted information for data extraction goal.")
# "Visualizer" tab: task search and paginated task list on the left, the steps
# of the selected task in the middle, and the artifacts of the selected
# task/step on the right.
with visualizer_tab:
    task_id_input = st.text_input("task_id", value="")

    def search_task() -> None:
        """Look up the task id typed into the search box and select it if found."""
        if not task_id_input:
            return
        task = repository.get_task(task_id_input)
        if task:
            select_task(task)
        else:
            st.error(f"Task with id {task_id_input} not found.")

    st.button("search task", on_click=search_task)

    col_tasks, _, col_steps, _, col_artifacts = st.columns([4, 1, 6, 1, 18])
    col_tasks.markdown("#### Tasks")
    col_steps.markdown("#### Steps")
    col_artifacts.markdown("#### Artifacts")

    tasks_response = repository.get_tasks(task_page_number)
    if "error" in tasks_response or not isinstance(tasks_response, list):
        # Show the unexpected payload, then continue with an empty task list
        # instead of failing while iterating it as if it were tasks.
        st.write(tasks_response)
        tasks_response = []

    # Display tasks for selection (keyed by id, so each id gets one button)
    tasks = {task["task_id"]: task for task in tasks_response}
    for task_id, task in tasks.items():
        col_tasks.button(
            f"{task_id}",
            on_click=select_task,
            args=(task,),
            use_container_width=True,
            type="primary" if selected_task and task_id == selected_task["task_id"] else "secondary",
        )

    # Display pagination buttons
    # The page number is changed in on_click callbacks, which run before the
    # script is re-executed, so the list above is fetched for the new page
    # right away. Changing it after the fetch left the view one click behind.
    def go_to_next_task_page() -> None:
        """Advance the task list by one page."""
        st.session_state.task_page_number += 1

    def go_to_previous_task_page() -> None:
        """Go back one page in the task list, never below page 1."""
        st.session_state.task_page_number = max(1, st.session_state.task_page_number - 1)

    task_page_prev, _, show_task_page_number, _, task_page_next = col_tasks.columns([1, 1, 1, 1, 1])
    show_task_page_number.button(str(task_page_number), disabled=True)
    # "\\>" / "\\<" produce the same labels as before (a backslash followed by
    # the symbol) without relying on an invalid escape sequence.
    task_page_next.button("\\>", on_click=go_to_next_task_page)
    task_page_prev.button("\\<", disabled=task_page_number == 1, on_click=go_to_previous_task_page)

    (
        tab_task,
        tab_step,
        tab_recording,
        tab_screenshot,
        tab_post_action_screenshot,
        tab_id_to_xpath,
        tab_element_tree,
        tab_element_tree_trimmed,
        tab_llm_prompt,
        tab_llm_request,
        tab_llm_response_parsed,
        tab_llm_response_raw,
        tab_html,
    ) = col_artifacts.tabs(
        [
            ":green[Task]",
            ":blue[Step]",
            ":violet[Recording]",
            ":rainbow[Screenshot]",
            ":rainbow[Action Screenshots]",
            ":red[ID -> XPath]",
            ":orange[Element Tree]",
            ":blue[Element Tree (Trimmed)]",
            ":yellow[LLM Prompt]",
            ":green[LLM Request]",
            ":blue[LLM Response (Parsed)]",
            ":violet[LLM Response (Raw)]",
            ":rainbow[Html (Raw)]",
        ]
    )
    tab_task_details, tab_task_steps, tab_task_action_results = tab_task.tabs(["Details", "Steps", "Action Results"])

    if selected_task:
        tab_task_details.json(selected_task)
        if selected_task_recording_uri:
            streamlit_show_recording(tab_recording, selected_task_recording_uri)

        if task_steps:
            col_steps_prev, _, col_steps_next = col_steps.columns([3, 1, 3])
            col_steps_prev.button(
                "prev", on_click=go_to_previous_step, key="previous_step_button", use_container_width=True
            )
            col_steps_next.button("next", on_click=go_to_next_step, key="next_step_button", use_container_width=True)

            # One button per step, labelled "<order> - <retry_index> - <step_id>"
            for step in task_steps:
                col_steps.button(
                    f"{step['order']} - {step['retry_index']} - {step['step_id']}",
                    on_click=select_step,
                    args=(step,),
                    use_container_width=True,
                    type="primary" if selected_step and step["step_id"] == selected_step["step_id"] else "secondary",
                )

            df = pd.json_normalize(task_steps)
            tab_task_steps.dataframe(df, use_container_width=True, height=1000)

            # Flatten the action results of every step into one table, tagging
            # each row with the step it came from.
            task_action_results = []
            for step in task_steps:
                output = step.get("output")
                step_id = step["step_id"]
                if output:
                    step_action_results = output.get("action_results", [])
                    for action_result in step_action_results:
                        task_action_results.append(
                            {
                                "step_id": step_id,
                                "order": step["order"],
                                "retry_index": step["retry_index"],
                                **action_result,
                            }
                        )
            df = pd.json_normalize(task_action_results)
            df = df.reindex(sorted(df.columns), axis=1)
            tab_task_action_results.dataframe(df, use_container_width=True, height=1000)

        if selected_step:
            tab_step.json(selected_step)

            artifacts_response = repository.get_artifacts(selected_task["task_id"], selected_step["step_id"])
            # last path segment of each URI -> full URI
            file_name_to_uris = {
                artifact["uri"].rsplit("/", 1)[-1]: artifact["uri"] for artifact in artifacts_response
            }

            # Route every artifact to its tab based on the file-name suffix.
            for file_name, uri in file_name_to_uris.items():
                file_name = file_name.lower()
                if file_name.endswith(("screenshot_llm.png", "screenshot.png")):
                    streamlit_content_safe(
                        tab_screenshot,
                        tab_screenshot.image,
                        read_artifact_safe(uri, is_image=True),
                        "No screenshot available.",
                        use_column_width=True,
                    )
                elif file_name.endswith("screenshot_action.png"):
                    streamlit_content_safe(
                        tab_post_action_screenshot,
                        tab_post_action_screenshot.image,
                        read_artifact_safe(uri, is_image=True),
                        "No action screenshot available.",
                        use_column_width=True,
                    )
                elif file_name.endswith("id_xpath_map.json"):
                    streamlit_content_safe(
                        tab_id_to_xpath, tab_id_to_xpath.json, read_artifact_safe(uri), "No ID -> XPath map available."
                    )
                elif file_name.endswith("tree.json"):
                    streamlit_content_safe(
                        tab_element_tree,
                        tab_element_tree.json,
                        read_artifact_safe(uri),
                        "No element tree available.",
                    )
                elif file_name.endswith("tree_trimmed.json"):
                    streamlit_content_safe(
                        tab_element_tree_trimmed,
                        tab_element_tree_trimmed.json,
                        read_artifact_safe(uri),
                        "No element tree trimmed available.",
                    )
                elif file_name.endswith("llm_prompt.txt"):
                    content = read_artifact_safe(uri)
                    # this is a hacky way to call this generic method to get it working with st.text_area:
                    # the content is passed positionally (it lands in the label slot, which is hidden
                    # via label_visibility) and again as the actual value
                    streamlit_content_safe(
                        tab_llm_prompt,
                        tab_llm_prompt.text_area,
                        content,
                        "No LLM prompt available.",
                        value=content,
                        height=1000,
                        label_visibility="collapsed",
                    )
                elif file_name.endswith("llm_request.json"):
                    streamlit_content_safe(
                        tab_llm_request, tab_llm_request.json, read_artifact_safe(uri), "No LLM request available."
                    )
                elif file_name.endswith("llm_response_parsed.json"):
                    streamlit_content_safe(
                        tab_llm_response_parsed,
                        tab_llm_response_parsed.json,
                        read_artifact_safe(uri),
                        "No parsed LLM response available.",
                    )
                elif file_name.endswith("llm_response.json"):
                    streamlit_content_safe(
                        tab_llm_response_raw,
                        tab_llm_response_raw.json,
                        read_artifact_safe(uri),
                        "No raw LLM response available.",
                    )
                elif file_name.endswith(("html_scrape.html", "html_action.html")):
                    streamlit_content_safe(tab_html, tab_html.text, read_artifact_safe(uri), "No html available.")
                else:
                    st.write(f"Artifact {file_name} not supported.")

View File

@@ -0,0 +1,38 @@
# Global font override: imports Roboto Mono from Google Fonts and applies it
# to every element ("*"). The value is a raw <style> block.
page_font_style = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@300&display=swap');
* {
font-family: 'Roboto Mono', monospace;
}
</style>
"""
# Look of every Streamlit button (.stButton): base style, highlighted border
# for kind="primary", and separate colors for the disabled and hover states.
# The value is a raw <style> block.
# The text color used to read "##3C414A"; the doubled "#" is not a valid CSS
# color, so the declaration did not take effect.
button_style = """
<style>
/* Apply the custom styles to Streamlit button */
.stButton > button {
text-align: center; /* Center button text */
font-size: 10px; /* Set font size here */
border: none; /* No border */
border-radius: 20px; /* Rounded corners */
background-color: #67748E;
color: #3C414A;
padding: 10px 10px; /* Some padding */
box-shadow: 0 4px 8px rgba(0,0,0,0.2); /* Box shadow */
}
.stButton > button[kind="primary"] {
border: 3px solid #DCFF94; /* Red border */
}
.stButton > button:disabled {
background-color: #636B7D;
}
.stButton > button:hover {
background-color: #73678F;
color: #B6E359;
}
</style>
"""