Update docs plus init (#2073)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Suchintan
2025-04-03 00:46:57 -04:00
committed by GitHub
parent 816d0e34d1
commit ff57f9977c
11 changed files with 804 additions and 750 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import os
import subprocess
from typing import Any, cast
@@ -35,21 +36,21 @@ class SkyvernAgent:
self.extra_headers = extra_headers
self.client: SkyvernClient | None = None
if base_url is None and api_key is None:
# TODO: run at the root wherever the code is initiated
if not os.path.exists(".env"):
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
load_dotenv(".env")
migrate_db()
# TODO: will this change the already imported settings?
# TODO: maybe refresh the settings
self.cdp_url = cdp_url
if browser_path:
# TODO validate browser_path
# Supported Browsers: Google Chrome, Brave Browser, Microsoft Edge, Firefox
if "Chrome" in browser_path or "Brave" in browser_path or "Edge" in browser_path:
self.browser_process = subprocess.Popen(
browser_process = subprocess.Popen(
[browser_path, "--remote-debugging-port=9222"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
if self.browser_process.poll() is not None:
if browser_process.poll() is not None:
raise Exception(f"Failed to open browser. browser_path: {browser_path}")
self.cdp_url = "http://127.0.0.1:9222"
@@ -76,7 +77,7 @@ class SkyvernAgent:
else:
raise ValueError("base_url and api_key must be both provided")
async def _get_organization(self) -> Organization:
async def get_organization(self) -> Organization:
organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
if not organization:
organization = await app.DATABASE.create_organization(
@@ -154,7 +155,7 @@ class SkyvernAgent:
self,
task_request: TaskRequest,
) -> CreateTaskResponse:
organization = await self._get_organization()
organization = await self.get_organization()
created_task = await app.agent.create_task(task_request, organization.organization_id)
@@ -165,7 +166,7 @@ class SkyvernAgent:
self,
task_id: str,
) -> TaskResponse | None:
organization = await self._get_organization()
organization = await self.get_organization()
task = await app.DATABASE.get_task(task_id, organization.organization_id)
if task is None:
@@ -212,7 +213,7 @@ class SkyvernAgent:
await asyncio.sleep(1)
async def observer_task_v_2(self, task_request: TaskV2Request) -> TaskV2:
organization = await self._get_organization()
organization = await self.get_organization()
task_v2 = await task_v2_service.initialize_task_v2(
organization=organization,
@@ -232,7 +233,7 @@ class SkyvernAgent:
return task_v2
async def get_observer_task_v_2(self, task_id: str) -> TaskV2 | None:
organization = await self._get_organization()
organization = await self.get_organization()
return await app.DATABASE.get_task_v2(task_id, organization.organization_id)
async def run_observer_task_v_2(self, task_request: TaskV2Request, timeout_seconds: int = 600) -> TaskV2:
@@ -249,7 +250,7 @@ class SkyvernAgent:
############### officially supported interfaces ###############
async def get_run(self, run_id: str) -> RunResponse | None:
if not self.client:
organization = await self._get_organization()
organization = await self.get_organization()
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
return await self.client.get_run(run_id)
@@ -276,7 +277,7 @@ class SkyvernAgent:
data_extraction_goal = None
navigation_goal = prompt
navigation_payload = None
organization = await self._get_organization()
organization = await self.get_organization()
task_generation = await task_v1_service.generate_task(
user_prompt=prompt,
organization=organization,
@@ -318,7 +319,7 @@ class SkyvernAgent:
return cast(TaskRunResponse, run_obj)
elif engine == RunEngine.skyvern_v2:
# initialize task v2
organization = await self._get_organization()
organization = await self.get_organization()
task_v2 = await task_v2_service.initialize_task_v2(
organization=organization,