add Request context to async_executor (#709)

This commit is contained in:
Kerem Yilmaz
2024-08-16 08:25:10 +03:00
committed by GitHub
parent 4a8b1473ec
commit fd5fdb9d32
4 changed files with 47 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
import abc
import structlog
from fastapi import BackgroundTasks
from fastapi import BackgroundTasks, Request
from skyvern.exceptions import OrganizationNotFound
from skyvern.forge import app
@@ -16,6 +16,7 @@ class AsyncExecutor(abc.ABC):
@abc.abstractmethod
async def execute_task(
self,
request: Request | None,
background_tasks: BackgroundTasks,
task_id: str,
organization_id: str,
@@ -28,6 +29,7 @@ class AsyncExecutor(abc.ABC):
@abc.abstractmethod
async def execute_workflow(
self,
request: Request | None,
background_tasks: BackgroundTasks,
organization_id: str,
workflow_id: str,
@@ -42,6 +44,7 @@ class AsyncExecutor(abc.ABC):
class BackgroundTaskExecutor(AsyncExecutor):
async def execute_task(
self,
request: Request | None,
background_tasks: BackgroundTasks,
task_id: str,
organization_id: str,
@@ -83,6 +86,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
async def execute_workflow(
self,
request: Request | None,
background_tasks: BackgroundTasks,
organization_id: str,
workflow_id: str,

View File

@@ -110,6 +110,7 @@ async def check_server_status() -> Response:
include_in_schema=False,
)
async def create_agent_task(
request: Request,
background_tasks: BackgroundTasks,
task: TaskRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
@@ -123,6 +124,7 @@ async def create_agent_task(
if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await AsyncExecutorFactory.get_executor().execute_task(
request=request,
background_tasks=background_tasks,
task_id=created_task.task_id,
organization_id=current_org.organization_id,
@@ -518,6 +520,7 @@ async def get_task_actions(
include_in_schema=False,
)
async def execute_workflow(
request: Request,
background_tasks: BackgroundTasks,
workflow_id: str, # this is the workflow_permanent_id
workflow_request: WorkflowRequestBody,
@@ -540,6 +543,7 @@ async def execute_workflow(
if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await AsyncExecutorFactory.get_executor().execute_workflow(
request=request,
background_tasks=background_tasks,
organization_id=current_org.organization_id,
workflow_id=workflow_run.workflow_id,