Implement actions api changes (#1007)

This commit is contained in:
Shuchang Zheng
2024-10-18 12:50:02 -07:00
committed by GitHub
parent 8271813077
commit ec9b77c699
2 changed files with 26 additions and 11 deletions

View File

@@ -274,6 +274,26 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_task_actions(self, task_id: str, organization_id: str | None = None) -> list[Action]:
try:
async with self.Session() as session:
query = (
select(ActionModel)
.filter(ActionModel.organization_id == organization_id)
.filter(ActionModel.task_id == task_id)
.order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [Action.model_validate(action) for action in actions]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
try:
async with self.Session() as session:

View File

@@ -52,6 +52,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatusResponse,
)
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.webeye.actions.actions import Action
base_router = APIRouter()
@@ -508,25 +509,19 @@ class ActionResultTmp(BaseModel):
success: bool = True
@base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp])
@base_router.get("/tasks/{task_id}/actions", response_model=list[Action])
@base_router.get(
"/tasks/{task_id}/actions/",
response_model=list[ActionResultTmp],
response_model=list[Action],
include_in_schema=False,
)
async def get_task_actions(
task_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[ActionResultTmp]:
) -> list[Action]:
analytics.capture("skyvern-oss-agent-task-actions-get")
steps = await app.DATABASE.get_task_step_models(task_id, organization_id=current_org.organization_id)
results: list[ActionResultTmp] = []
for step_s in steps:
if not step_s.output or "action_results" not in step_s.output:
continue
for action_result in step_s.output["action_results"]:
results.append(ActionResultTmp.model_validate(action_result))
return results
actions = await app.DATABASE.get_task_actions(task_id, organization_id=current_org.organization_id)
return actions
@base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse)