add skyvern sdk (#1741)
This commit is contained in:
144
skyvern/agent/local.py
Normal file
144
skyvern/agent/local.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from skyvern.agent.parameter import TaskV1Request, TaskV2Request
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import security, skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskStatus
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.services import observer_service
|
||||
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
from skyvern.utils import migrate_db
|
||||
|
||||
|
||||
class Agent:
    """Run Skyvern tasks in-process against a local database.

    Construction loads ``.env`` and applies pending DB migrations; every
    task then runs under a lazily-created "skyvern.local" organization.
    """

    def __init__(self) -> None:
        # Load local configuration and bring the DB schema up to date
        # before any task is executed.
        load_dotenv(".env")
        migrate_db()

    async def _get_organization(self) -> Organization:
        """Return the local organization, creating it and its API token on first use."""
        organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
        if not organization:
            organization = await app.DATABASE.create_organization(
                organization_name="Skyvern-local",
                domain="skyvern.local",
                max_steps_per_run=10,
                max_retries_per_step=3,
            )
            api_key = security.create_access_token(
                organization.organization_id,
                expires_delta=API_KEY_LIFETIME,
            )
            # Persist an OrganizationAuthToken so later lookups via
            # get_valid_org_auth_token succeed.
            await app.DATABASE.create_org_auth_token(
                organization_id=organization.organization_id,
                token=api_key,
                token_type=OrganizationAuthTokenType.api,
            )
        return organization

    async def run_task_v1(self, task_request: TaskV1Request) -> TaskResponse:
        """Create a v1 task, execute it synchronously, and build its response.

        NOTE(review): only one step (order 0) is created here; presumably
        app.agent.execute_step drives any follow-up steps internally — confirm.
        """
        organization = await self._get_organization()

        org_auth_token = await app.DATABASE.get_valid_org_auth_token(
            organization_id=organization.organization_id,
            token_type=OrganizationAuthTokenType.api,
        )

        created_task = await app.agent.create_task(task_request, organization.organization_id)

        # Make org/task/max-steps visible to downstream code via the
        # request-scoped context.
        skyvern_context.set(
            SkyvernContext(
                organization_id=organization.organization_id,
                task_id=created_task.task_id,
                max_steps_override=task_request.max_steps,
            )
        )

        # First step of the task (order 0, no retries yet).
        step = await app.DATABASE.create_step(
            created_task.task_id,
            order=0,
            retry_index=0,
            organization_id=organization.organization_id,
        )
        updated_task = await app.DATABASE.update_task(
            created_task.task_id,
            status=TaskStatus.running,
            organization_id=organization.organization_id,
        )

        step, _, _ = await app.agent.execute_step(
            organization=organization,
            task=updated_task,
            step=step,
            api_key=org_auth_token.token if org_auth_token else None,
        )

        # Re-read the task: execute_step may have changed its status.
        refreshed_task = await app.DATABASE.get_task(created_task.task_id, organization.organization_id)
        if refreshed_task:
            updated_task = refreshed_task

        # Assemble a human-readable failure reason from the task's own
        # failure_reason plus any failed actions of the last step.
        failure_reason: str | None = None
        if updated_task.status == TaskStatus.failed and (step.output or updated_task.failure_reason):
            failure_reason = ""
            if updated_task.failure_reason:
                failure_reason += updated_task.failure_reason or ""
            if step.output is not None and step.output.actions_and_results is not None:
                action_results_string: list[str] = []
                for action, results in step.output.actions_and_results:
                    if len(results) == 0:
                        continue
                    if results[-1].success:
                        # Only the latest result per action counts as failure.
                        continue
                    action_results_string.append(f"{action.action_type} action failed.")

                if len(action_results_string) > 0:
                    failure_reason += "(Exceptions: " + str(action_results_string) + ")"
        return await app.agent.build_task_response(
            task=updated_task, last_step=step, failure_reason=failure_reason, need_browser_log=True
        )

    async def run_task_v2(self, task_request: TaskV2Request) -> ObserverTask:
        """Create and run a v2 (observer) task to completion, returning its final state."""
        organization = await self._get_organization()

        observer_task = await observer_service.initialize_observer_task(
            organization=organization,
            user_prompt=task_request.user_prompt,
            user_url=str(task_request.url) if task_request.url else None,
            totp_identifier=task_request.totp_identifier,
            totp_verification_url=task_request.totp_verification_url,
            webhook_callback_url=task_request.webhook_callback_url,
            proxy_location=task_request.proxy_location,
            publish_workflow=task_request.publish_workflow,
        )

        if not observer_task.workflow_run_id:
            raise Exception("Observer cruise missing workflow run id")

        # mark observer cruise as queued
        await app.DATABASE.update_observer_cruise(
            observer_cruise_id=observer_task.observer_cruise_id,
            status=ObserverTaskStatus.queued,
            organization_id=organization.organization_id,
        )
        await app.DATABASE.update_workflow_run(
            workflow_run_id=observer_task.workflow_run_id,
            status=WorkflowRunStatus.queued,
        )

        # Run synchronously until the observer loop finishes (bounded by
        # max_iterations).
        await observer_service.run_observer_task(
            organization=organization,
            observer_cruise_id=observer_task.observer_cruise_id,
            max_iterations_override=task_request.max_iterations,
        )

        # Prefer the freshly-persisted state; fall back to the in-memory
        # object if the re-read returns nothing.
        refreshed_observer_task = await app.DATABASE.get_observer_cruise(
            observer_cruise_id=observer_task.observer_cruise_id, organization_id=organization.organization_id
        )
        if refreshed_observer_task:
            return refreshed_observer_task

        return observer_task
|
||||
10
skyvern/agent/parameter.py
Normal file
10
skyvern/agent/parameter.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskRequest
|
||||
|
||||
|
||||
class TaskV1Request(TaskRequest):
    """A v1 task request extended with a local cap on agent steps."""

    # Upper bound on the number of steps the agent may take for this task.
    max_steps: int = 10
|
||||
|
||||
|
||||
class TaskV2Request(ObserverTaskRequest):
    """A v2 (observer) task request extended with a local iteration cap."""

    # Upper bound on observer-loop iterations for this task.
    max_iterations: int = 10
|
||||
43
skyvern/agent/remote.py
Normal file
43
skyvern/agent/remote.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import httpx
|
||||
|
||||
from skyvern.agent.parameter import TaskV1Request, TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTask
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskResponse
|
||||
|
||||
|
||||
class RemoteAgent:
    """Async client for the hosted Skyvern HTTP API.

    Every request carries the JSON content type and the ``x-api-key``
    header via a single shared ``httpx.AsyncClient``.
    """

    def __init__(self, api_key: str, endpoint: str = "https://api.skyvern.com"):
        self.endpoint = endpoint
        self.api_key = api_key
        # One shared client so all calls reuse the same default headers.
        default_headers = {
            "Content-Type": "application/json",
            "x-api-key": self.api_key,
        }
        self.client = httpx.AsyncClient(headers=default_headers)

    async def run_task_v1(self, task: TaskV1Request) -> CreateTaskResponse:
        """Submit a v1 task for execution."""
        response = await self.client.post(
            f"{self.endpoint}/api/v1/tasks",
            headers={"x_max_steps_override": str(task.max_steps)},
            data=task.model_dump_json(),
        )
        return CreateTaskResponse.model_validate(response.json())

    async def run_task_v2(self, task: TaskV2Request) -> ObserverTask:
        """Submit a v2 (observer) task for execution."""
        response = await self.client.post(
            f"{self.endpoint}/api/v2/tasks",
            headers={"x_max_iterations_override": str(task.max_iterations)},
            data=task.model_dump_json(),
        )
        return ObserverTask.model_validate(response.json())

    async def get_task_v1(self, task_id: str) -> TaskResponse:
        """Get a task by id."""
        response = await self.client.get(f"{self.endpoint}/api/v1/tasks/{task_id}")
        return TaskResponse.model_validate(response.json())

    async def get_task_v2(self, task_id: str) -> ObserverTask:
        """Get a task by id."""
        response = await self.client.get(f"{self.endpoint}/api/v2/tasks/{task_id}")
        return ObserverTask.model_validate(response.json())
|
||||
126
skyvern/cli/commands.py
Normal file
126
skyvern/cli/commands.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
|
||||
from skyvern.utils import migrate_db
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def command_exists(command: str) -> bool:
    """Return True if *command* resolves to an executable on PATH."""
    resolved = shutil.which(command)
    return resolved is not None
|
||||
|
||||
|
||||
def run_command(command: str, check: bool = True) -> tuple[Optional[str], Optional[int]]:
    """Run *command* through the shell and return ``(stdout, returncode)``.

    stdout is stripped of surrounding whitespace. With ``check=True`` a
    non-zero exit is reported as ``(None, returncode)`` instead of raising.
    """
    try:
        completed = subprocess.run(command, shell=True, check=check, capture_output=True, text=True)
    except subprocess.CalledProcessError as exc:
        return None, exc.returncode
    return completed.stdout.strip(), completed.returncode
|
||||
|
||||
|
||||
def is_postgres_running() -> bool:
    """Best-effort check that a local PostgreSQL server is accepting connections."""
    if not command_exists("pg_isready"):
        return False
    # check=False: pg_isready exits non-zero when the server is down, which
    # is an expected outcome here, not an error to swallow into (None, code).
    output, _ = run_command("pg_isready", check=False)
    # Bug fix: the original tested `"accepting connections" in result` where
    # `result` is the (stdout, returncode) tuple — tuple membership compares
    # whole elements, so the substring could never match. Test stdout itself.
    return output is not None and "accepting connections" in output
|
||||
|
||||
|
||||
def database_exists(dbname: str, user: str) -> bool:
    """Return True if ``psql`` can connect to *dbname* as *user*."""
    check_db_command = f'psql {dbname} -U {user} -c "\\q"'
    # Bug fix: with check=False, run_command never returns None (the except
    # branch only fires when check=True), so the original `is not None` test
    # was always True. Use the exit code to detect a successful connection.
    _, returncode = run_command(check_db_command, check=False)
    return returncode == 0
|
||||
|
||||
|
||||
def create_database_and_user() -> None:
    """Create the ``skyvern`` role and a ``skyvern`` database it owns.

    Assumes a local PostgreSQL server is running and ``createuser`` /
    ``createdb`` are on PATH; errors from the shell commands are not checked.
    """
    print("Creating database user and database...")
    run_command("createuser skyvern")
    run_command("createdb skyvern -O skyvern")
    print("Database and user created successfully.")
|
||||
|
||||
|
||||
def is_docker_running() -> bool:
    """True when the docker CLI is installed and the daemon answers ``docker info``."""
    if not command_exists("docker"):
        return False
    _, exit_code = run_command("docker info", check=False)
    return exit_code == 0
|
||||
|
||||
|
||||
def is_postgres_running_in_docker() -> bool:
    """True when a container named ``postgresql-container`` is currently running."""
    _, exit_code = run_command("docker ps | grep -q postgresql-container", check=False)
    return exit_code == 0
|
||||
|
||||
|
||||
def is_postgres_container_exists() -> bool:
    """True when a ``postgresql-container`` exists, whether running or stopped."""
    _, exit_code = run_command("docker ps -a | grep -q postgresql-container", check=False)
    return exit_code == 0
|
||||
|
||||
|
||||
def setup_postgresql() -> None:
    """Ensure a PostgreSQL server with a ``skyvern`` user and database exists.

    Preference order: use an already-running local server; otherwise start
    (or create) a ``postgresql-container`` via Docker. Exits the process with
    status 1 if neither a local server nor Docker is available.
    """
    print("Setting up PostgreSQL...")

    # Case 1: a local (non-Docker) server is already up — just make sure
    # the skyvern database/user exist, then stop.
    if command_exists("psql") and is_postgres_running():
        print("PostgreSQL is already running locally.")
        if database_exists("skyvern", "skyvern"):
            print("Database and user exist.")
        else:
            create_database_and_user()
        return

    if not is_docker_running():
        print("Docker is not running or not installed. Please install or start Docker and try again.")
        exit(1)

    # Case 2: run PostgreSQL in Docker, reusing an existing container
    # when there is one.
    if is_postgres_running_in_docker():
        print("PostgreSQL is already running in a Docker container.")
    else:
        print("Attempting to install PostgreSQL via Docker...")
        if not is_postgres_container_exists():
            # POSTGRES_HOST_AUTH_METHOD=trust: local dev only — no password.
            run_command(
                "docker run --name postgresql-container -e POSTGRES_HOST_AUTH_METHOD=trust -d -p 5432:5432 postgres:14"
            )
        else:
            run_command("docker start postgresql-container")
        print("PostgreSQL has been installed and started using Docker.")

        # Fixed grace period for the server inside the container to accept
        # connections before we issue psql commands against it.
        print("Waiting for PostgreSQL to start...")
        time.sleep(20)

    # Ensure the skyvern role exists inside the container (grep -q sets the
    # exit code we branch on).
    _, code = run_command('docker exec postgresql-container psql -U postgres -c "\\du" | grep -q skyvern', check=False)
    if code == 0:
        print("Database user exists.")
    else:
        print("Creating database user...")
        run_command("docker exec postgresql-container createuser -U postgres skyvern")

    # Ensure the skyvern database exists inside the container.
    _, code = run_command(
        "docker exec postgresql-container psql -U postgres -lqt | cut -d \\| -f 1 | grep -qw skyvern", check=False
    )
    if code == 0:
        print("Database exists.")
    else:
        print("Creating database...")
        run_command("docker exec postgresql-container createdb -U postgres skyvern -O skyvern")
        print("Database and user created successfully.")
|
||||
|
||||
|
||||
@app.command(name="init")
def init(
    openai_api_key: str = typer.Option(..., help="The OpenAI API key"),
    log_level: str = typer.Option("CRITICAL", help="The log level"),
) -> None:
    # Provision PostgreSQL, then write the .env configuration file.
    # (No docstring on purpose: typer would surface it as --help text.)
    setup_postgresql()
    # Generate .env file
    env_lines = [
        "ENABLE_OPENAI=true\n",
        f"OPENAI_API_KEY={openai_api_key}\n",
        f"LOG_LEVEL={log_level}\n",
        "ARTIFACT_STORAGE_PATH=./artifacts\n",
    ]
    with open(".env", "w") as env_file:
        env_file.writelines(env_lines)
    print(".env file created with the parameters provided.")
|
||||
|
||||
|
||||
@app.command(name="migrate")
def migrate() -> None:
    # CLI entry point: apply any pending database migrations.
    # (Comment, not docstring: typer would surface a docstring as --help text.)
    migrate_db()
|
||||
@@ -11,7 +11,7 @@ class Settings(BaseSettings):
|
||||
|
||||
BROWSER_TYPE: str = "chromium-headful"
|
||||
MAX_SCRAPING_RETRIES: int = 0
|
||||
VIDEO_PATH: str | None = None
|
||||
VIDEO_PATH: str | None = "./video"
|
||||
HAR_PATH: str | None = "./har"
|
||||
LOG_PATH: str = "./log"
|
||||
TEMP_PATH: str = "./temp"
|
||||
|
||||
10
skyvern/utils/__init__.py
Normal file
10
skyvern/utils/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from skyvern.constants import REPO_ROOT_DIR
|
||||
|
||||
|
||||
def migrate_db() -> None:
    """Upgrade the Skyvern database schema to the latest Alembic revision."""
    alembic_config = Config()
    script_location = f"{REPO_ROOT_DIR}/alembic"
    alembic_config.set_main_option("script_location", script_location)
    command.upgrade(alembic_config, "head")
|
||||
Reference in New Issue
Block a user