add Request context to async_executor (#709)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
|
||||
import structlog
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import BackgroundTasks, Request
|
||||
|
||||
from skyvern.exceptions import OrganizationNotFound
|
||||
from skyvern.forge import app
|
||||
@@ -16,6 +16,7 @@ class AsyncExecutor(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def execute_task(
|
||||
self,
|
||||
request: Request | None,
|
||||
background_tasks: BackgroundTasks,
|
||||
task_id: str,
|
||||
organization_id: str,
|
||||
@@ -28,6 +29,7 @@ class AsyncExecutor(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def execute_workflow(
|
||||
self,
|
||||
request: Request | None,
|
||||
background_tasks: BackgroundTasks,
|
||||
organization_id: str,
|
||||
workflow_id: str,
|
||||
@@ -42,6 +44,7 @@ class AsyncExecutor(abc.ABC):
|
||||
class BackgroundTaskExecutor(AsyncExecutor):
|
||||
async def execute_task(
|
||||
self,
|
||||
request: Request | None,
|
||||
background_tasks: BackgroundTasks,
|
||||
task_id: str,
|
||||
organization_id: str,
|
||||
@@ -83,6 +86,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
|
||||
async def execute_workflow(
|
||||
self,
|
||||
request: Request | None,
|
||||
background_tasks: BackgroundTasks,
|
||||
organization_id: str,
|
||||
workflow_id: str,
|
||||
|
||||
@@ -110,6 +110,7 @@ async def check_server_status() -> Response:
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def create_agent_task(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
task: TaskRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
@@ -123,6 +124,7 @@ async def create_agent_task(
|
||||
if x_max_steps_override:
|
||||
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
|
||||
await AsyncExecutorFactory.get_executor().execute_task(
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
task_id=created_task.task_id,
|
||||
organization_id=current_org.organization_id,
|
||||
@@ -518,6 +520,7 @@ async def get_task_actions(
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def execute_workflow(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
workflow_id: str, # this is the workflow_permanent_id
|
||||
workflow_request: WorkflowRequestBody,
|
||||
@@ -540,6 +543,7 @@ async def execute_workflow(
|
||||
if x_max_steps_override:
|
||||
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
|
||||
await AsyncExecutorFactory.get_executor().execute_workflow(
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
organization_id=current_org.organization_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
|
||||
Reference in New Issue
Block a user