diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index f4287299..5f8021ef 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -164,17 +164,19 @@ async def run_task( data_extraction_schema = run_request.data_extraction_schema navigation_goal = run_request.prompt navigation_payload = None - task_generation = await task_v1_service.generate_task( - user_prompt=run_request.prompt, - organization=current_org, - ) - url = url or task_generation.url - navigation_goal = task_generation.navigation_goal or run_request.prompt - if run_request.engine in CUA_ENGINES: - navigation_goal = run_request.prompt - navigation_payload = task_generation.navigation_payload - data_extraction_goal = task_generation.data_extraction_goal - data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema + if not url: + task_generation = await task_v1_service.generate_task( + user_prompt=run_request.prompt, + organization=current_org, + ) + # What if it's a SDK request with browser_session_id? + url = task_generation.url + navigation_goal = task_generation.navigation_goal or run_request.prompt + if run_request.engine in CUA_ENGINES: + navigation_goal = run_request.prompt + navigation_payload = task_generation.navigation_payload + data_extraction_goal = task_generation.data_extraction_goal + data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema task_v1_request = TaskRequest( title=run_request.title, diff --git a/skyvern/services/task_v1_service.py b/skyvern/services/task_v1_service.py index 7e1f0397..15efe093 100644 --- a/skyvern/services/task_v1_service.py +++ b/skyvern/services/task_v1_service.py @@ -47,7 +47,7 @@ async def generate_task(user_prompt: str, organization: Organization) -> TaskGen llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=user_prompt) try: - llm_response = await app.LLM_API_HANDLER( + llm_response = await app.SECONDARY_LLM_API_HANDLER( prompt=llm_prompt, prompt_name="generate-task", organization_id=organization.organization_id ) parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)