Sync cloud skyvern to oss skyvern (#55)

This commit is contained in:
Kerem Yilmaz
2024-03-12 22:28:16 -07:00
committed by GitHub
parent 647ea2ac0f
commit 15d78d7b08
25 changed files with 554 additions and 163 deletions

View File

@@ -55,6 +55,7 @@ class AgentDB:
async def create_task(
self,
url: str,
title: str | None,
navigation_goal: str | None,
data_extraction_goal: str | None,
navigation_payload: dict[str, Any] | list | str | None,
@@ -65,12 +66,14 @@ class AgentDB:
workflow_run_id: str | None = None,
order: int | None = None,
retry: int | None = None,
error_code_mapping: dict[str, str] | None = None,
) -> Task:
try:
with self.Session() as session:
new_task = TaskModel(
status="created",
url=url,
title=title,
webhook_callback_url=webhook_callback_url,
navigation_goal=navigation_goal,
data_extraction_goal=data_extraction_goal,
@@ -81,6 +84,7 @@ class AgentDB:
workflow_run_id=workflow_run_id,
order=order,
retry=retry,
error_code_mapping=error_code_mapping,
)
session.add(new_task)
session.commit()
@@ -312,11 +316,16 @@ class AgentDB:
async def update_task(
self,
task_id: str,
status: TaskStatus,
status: TaskStatus | None = None,
extracted_information: dict[str, Any] | list | str | None = None,
failure_reason: str | None = None,
errors: list[dict[str, Any]] | None = None,
organization_id: str | None = None,
) -> Task:
if status is None and extracted_information is None and failure_reason is None and errors is None:
raise ValueError(
"At least one of status, extracted_information, or failure_reason must be provided to update the task"
)
try:
with self.Session() as session:
if (
@@ -325,11 +334,14 @@ class AgentDB:
.filter_by(organization_id=organization_id)
.first()
):
task.status = status
if status is not None:
task.status = status
if extracted_information is not None:
task.extracted_information = extracted_information
if failure_reason is not None:
task.failure_reason = failure_reason
if errors is not None:
task.errors = errors
session.commit()
updated_task = await self.get_task(task_id, organization_id=organization_id)
if not updated_task:

View File

@@ -29,6 +29,7 @@ class TaskModel(Base):
organization_id = Column(String, ForeignKey("organizations.organization_id"))
status = Column(String)
webhook_callback_url = Column(String)
title = Column(String)
url = Column(String)
navigation_goal = Column(String)
data_extraction_goal = Column(String)
@@ -40,6 +41,8 @@ class TaskModel(Base):
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
order = Column(Integer, nullable=True)
retry = Column(Integer, nullable=True)
error_code_mapping = Column(JSON, nullable=True)
errors = Column(JSON, default=[], nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)

View File

@@ -48,6 +48,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
status=TaskStatus(task_obj.status),
created_at=task_obj.created_at,
modified_at=task_obj.modified_at,
title=task_obj.title,
url=task_obj.url,
webhook_callback_url=task_obj.webhook_callback_url,
navigation_goal=task_obj.navigation_goal,
@@ -61,6 +62,8 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
workflow_run_id=task_obj.workflow_run_id,
order=task_obj.order,
retry=task_obj.retry,
error_code_mapping=task_obj.error_code_mapping,
errors=task_obj.errors,
)
return task