From 2f10e3c4306281ec6775af26bad7de161699dcb3 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 29 Apr 2025 03:55:52 +0800 Subject: [PATCH] Add organization_name and workflow_permanent_id to skyvern context, pass workflow_permanet_id when deciding which identifier to use with which llm (#2245) --- skyvern/agent/agent.py | 1 + skyvern/forge/agent.py | 2 +- skyvern/forge/sdk/core/skyvern_context.py | 2 ++ skyvern/forge/sdk/executor/async_executor.py | 9 +++------ skyvern/forge/sdk/forge_log.py | 4 ++++ skyvern/forge/sdk/routes/agent_protocol.py | 4 ++-- skyvern/forge/sdk/services/org_auth_service.py | 1 + skyvern/forge/sdk/workflow/models/block.py | 3 +++ skyvern/forge/sdk/workflow/service.py | 9 +++++---- skyvern/services/task_v2_service.py | 3 +-- skyvern/services/workflow_service.py | 11 ++++++----- 11 files changed, 29 insertions(+), 20 deletions(-) diff --git a/skyvern/agent/agent.py b/skyvern/agent/agent.py index 63c897bc..aa7b0528 100644 --- a/skyvern/agent/agent.py +++ b/skyvern/agent/agent.py @@ -119,6 +119,7 @@ class SkyvernAgent: skyvern_context.set( SkyvernContext( organization_id=organization.organization_id, + organization_name=organization.organization_name, task_id=task.task_id, max_steps_override=max_steps, ) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 4fea82d5..36cf8e86 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -2657,7 +2657,7 @@ class ForgeAgent: ) data_extraction_summary_resp = await app.SECONDARY_LLM_API_HANDLER( - prompt=prompt, step=step, screenshots=scraped_page.screenshots, prompt_name="data-extraction-summary" + prompt=prompt, step=step, prompt_name="data-extraction-summary" ) return ExtractAction( reasoning=data_extraction_summary_resp.get("summary", "Extracting information from the page"), diff --git a/skyvern/forge/sdk/core/skyvern_context.py b/skyvern/forge/sdk/core/skyvern_context.py index d86a91d2..819cb3ab 100644 --- a/skyvern/forge/sdk/core/skyvern_context.py +++ b/skyvern/forge/sdk/core/skyvern_context.py @@ -9,8 +9,10 @@ from playwright.async_api import Frame class SkyvernContext: request_id: str | None = None organization_id: str | None = None + organization_name: str | None = None task_id: str | None = None workflow_id: str | None = None + workflow_permanent_id: str | None = None workflow_run_id: str | None = None task_v2_id: str | None = None max_steps_override: int | None = None diff --git a/skyvern/forge/sdk/executor/async_executor.py b/skyvern/forge/sdk/executor/async_executor.py index 135cc57b..2d1b0f30 100644 --- a/skyvern/forge/sdk/executor/async_executor.py +++ b/skyvern/forge/sdk/executor/async_executor.py @@ -7,6 +7,7 @@ from skyvern.exceptions import OrganizationNotFound from skyvern.forge import app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext +from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.task_v2 import TaskV2Status from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus @@ -36,7 +37,7 @@ class AsyncExecutor(abc.ABC): self, request: Request | None, background_tasks: BackgroundTasks, - organization_id: str, + organization: Organization, workflow_id: str, workflow_run_id: str, max_steps_override: int | None, @@ -120,7 +121,7 @@ class BackgroundTaskExecutor(AsyncExecutor): self, request: Request | None, background_tasks: BackgroundTasks | None, - organization_id: str, + organization: Organization, workflow_id: str, workflow_run_id: str, max_steps_override: int | None, @@ -133,10 +134,6 @@ class BackgroundTaskExecutor(AsyncExecutor): workflow_run_id=workflow_run_id, ) - organization = await app.DATABASE.get_organization(organization_id) - if organization is None: - raise OrganizationNotFound(organization_id) - if background_tasks: background_tasks.add_task( app.WORKFLOW_SERVICE.execute_workflow, diff --git a/skyvern/forge/sdk/forge_log.py b/skyvern/forge/sdk/forge_log.py index 02efcbf0..58df81f8 100644 --- a/skyvern/forge/sdk/forge_log.py +++ b/skyvern/forge/sdk/forge_log.py @@ -27,12 +27,16 @@ def add_kv_pairs_to_msg(logger: logging.Logger, method_name: str, event_dict: Ev event_dict["request_id"] = context.request_id if context.organization_id: event_dict["organization_id"] = context.organization_id + if context.organization_name: + event_dict["organization_name"] = context.organization_name if context.task_id: event_dict["task_id"] = context.task_id if context.workflow_id: event_dict["workflow_id"] = context.workflow_id if context.workflow_run_id: event_dict["workflow_run_id"] = context.workflow_run_id + if context.workflow_permanent_id: + event_dict["workflow_permanent_id"] = context.workflow_permanent_id if context.task_v2_id: event_dict["task_v2_id"] = context.task_v2_id if context.browser_session_id: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 1bb27887..f38f773c 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -667,7 +667,7 @@ async def run_workflow_legacy( workflow_run = await workflow_service.run_workflow( workflow_id=workflow_id, - organization_id=current_org.organization_id, + organization=current_org, workflow_request=workflow_request, template=template, version=version, @@ -1634,7 +1634,7 @@ async def run_workflow( ) workflow_run = await workflow_service.run_workflow( workflow_id=workflow_id, - organization_id=current_org.organization_id, + organization=current_org, workflow_request=legacy_workflow_request, template=template, version=None, diff --git a/skyvern/forge/sdk/services/org_auth_service.py b/skyvern/forge/sdk/services/org_auth_service.py index edd57d4e..f9af5459 100644 --- a/skyvern/forge/sdk/services/org_auth_service.py +++ b/skyvern/forge/sdk/services/org_auth_service.py @@ -128,4 +128,5 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization: context = skyvern_context.current() if context: context.organization_id = organization.organization_id + context.organization_name = organization.organization_name return organization diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index edb687a8..63d25815 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -2497,6 +2497,9 @@ class TaskV2Block(Block): skyvern_context.set( skyvern_context.SkyvernContext( organization_id=organization_id, + organization_name=organization.organization_name, + workflow_id=workflow_run.workflow_id, + workflow_permanent_id=workflow_run.workflow_permanent_id, workflow_run_id=workflow_run_id, browser_session_id=browser_session_id, ) diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index e9694dbe..c0cb1d3f 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -102,7 +102,7 @@ class WorkflowService: request_id: str | None, workflow_request: WorkflowRequestBody, workflow_permanent_id: str, - organization_id: str, + organization: Organization, is_template_workflow: bool = False, version: int | None = None, max_steps_override: int | None = None, @@ -121,7 +121,7 @@ class WorkflowService: # Validate the workflow and the organization workflow = await self.get_workflow_by_permanent_id( workflow_permanent_id=workflow_permanent_id, - organization_id=None if is_template_workflow else organization_id, + organization_id=None if is_template_workflow else organization.organization_id, version=version, ) if workflow is None: @@ -137,7 +137,7 @@ class WorkflowService: workflow_request=workflow_request, workflow_permanent_id=workflow_permanent_id, workflow_id=workflow_id, - organization_id=organization_id, + organization_id=organization.organization_id, parent_workflow_run_id=parent_workflow_run_id, ) LOG.info( @@ -151,7 +151,8 @@ class WorkflowService: ) skyvern_context.set( SkyvernContext( - organization_id=organization_id, + organization_id=organization.organization_id, + organization_name=organization.organization_name, request_id=request_id, workflow_id=workflow_id, workflow_run_id=workflow_run.workflow_run_id, diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py index af585ff0..1d29aaca 100644 --- a/skyvern/services/task_v2_service.py +++ b/skyvern/services/task_v2_service.py @@ -211,7 +211,7 @@ async def initialize_task_v2( request_id=None, workflow_request=WorkflowRequestBody(), workflow_permanent_id=new_workflow.workflow_permanent_id, - organization_id=organization.organization_id, + organization=organization, version=None, max_steps_override=max_steps_override, parent_workflow_run_id=parent_workflow_run_id, @@ -1491,7 +1491,6 @@ async def _summarize_task_v2( ) task_v2_summary_resp = await app.LLM_API_HANDLER( prompt=task_v2_summary_prompt, - screenshots=screenshots, thought=thought, prompt_name="task_v2_summary", ) diff --git a/skyvern/services/workflow_service.py b/skyvern/services/workflow_service.py index f12d7aeb..2f4ca9a0 100644 --- a/skyvern/services/workflow_service.py +++ b/skyvern/services/workflow_service.py @@ -3,6 +3,7 @@ from fastapi import BackgroundTasks, Request from skyvern.forge import app from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory +from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.workflow.exceptions import InvalidTemplateWorkflowPermanentId from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRun from skyvern.schemas.runs import RunType @@ -12,7 +13,7 @@ LOG = structlog.get_logger(__name__) async def run_workflow( workflow_id: str, - organization_id: str, + organization: Organization, workflow_request: WorkflowRequestBody, # this is the deprecated workflow request body template: bool = False, version: int | None = None, @@ -30,19 +31,19 @@ async def run_workflow( request_id=request_id, workflow_request=workflow_request, workflow_permanent_id=workflow_id, - organization_id=organization_id, + organization=organization, version=version, max_steps_override=max_steps, is_template_workflow=template, ) workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( workflow_permanent_id=workflow_id, - organization_id=None if template else organization_id, + organization_id=None if template else organization.organization_id, version=version, ) await app.DATABASE.create_task_run( task_run_type=RunType.workflow_run, - organization_id=organization_id, + organization_id=organization.organization_id, run_id=workflow_run.workflow_run_id, title=workflow.title, ) @@ -51,7 +52,7 @@ async def run_workflow( await AsyncExecutorFactory.get_executor().execute_workflow( request=request, background_tasks=background_tasks, - organization_id=organization_id, + organization=organization, workflow_id=workflow_run.workflow_id, workflow_run_id=workflow_run.workflow_run_id, max_steps_override=max_steps,