diff --git a/skyvern/core/script_generations/skyvern_page.py b/skyvern/core/script_generations/skyvern_page.py index 5b3d018c..b79956c0 100644 --- a/skyvern/core/script_generations/skyvern_page.py +++ b/skyvern/core/script_generations/skyvern_page.py @@ -330,6 +330,7 @@ class SkyvernPage: json_response = await app.SINGLE_CLICK_AGENT_LLM_API_HANDLER( prompt=single_click_prompt, prompt_name="single-click-action", + organization_id=context.organization_id, ) actions = json_response.get("actions", []) if actions: @@ -404,6 +405,7 @@ class SkyvernPage: json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER( prompt=script_generation_input_text_prompt, prompt_name="script-generation-input-text-generatiion", + organization_id=context.organization_id if context else None, ) value = json_response.get("answer", value) except Exception: diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index a87a7572..064baea4 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -100,6 +100,7 @@ class LLMAPIHandlerFactory: ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, + organization_id: str | None = None, ) -> dict[str, Any]: """ Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision. @@ -204,45 +205,43 @@ class LLMAPIHandlerFactory: cached_tokens = 0 completion_token_detail = None cached_token_detail = None - llm_cost = 0 - if step or thought: - try: - # FIXME: volcengine doesn't support litellm cost calculation. - llm_cost = litellm.completion_cost(completion_response=response) - except Exception as e: - LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) - llm_cost = 0 - prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = response.get("usage", {}).get("completion_tokens", 0) - reasoning_tokens = 0 - completion_token_detail = response.get("usage", {}).get("completion_tokens_details") - if completion_token_detail: - reasoning_tokens = completion_token_detail.reasoning_tokens or 0 - cached_tokens = 0 - cached_token_detail = response.get("usage", {}).get("prompt_tokens_details") - if cached_token_detail: - cached_tokens = cached_token_detail.cached_tokens or 0 - if step: - await app.DATABASE.update_step( - task_id=step.task_id, - step_id=step.step_id, - organization_id=step.organization_id, - incremental_cost=llm_cost, - incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, - incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, - incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, - incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, - ) - if thought: - await app.DATABASE.update_thought( - thought_id=thought.observer_thought_id, - organization_id=thought.organization_id, - input_token_count=prompt_tokens if prompt_tokens > 0 else None, - output_token_count=completion_tokens if completion_tokens > 0 else None, - thought_cost=llm_cost, - reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, - cached_token_count=cached_tokens if cached_tokens > 0 else None, - ) + try: + # FIXME: volcengine doesn't support litellm cost calculation. + llm_cost = litellm.completion_cost(completion_response=response) + except Exception as e: + LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True) + llm_cost = 0 + prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0) + completion_tokens = response.get("usage", {}).get("completion_tokens", 0) + reasoning_tokens = 0 + completion_token_detail = response.get("usage", {}).get("completion_tokens_details") + if completion_token_detail: + reasoning_tokens = completion_token_detail.reasoning_tokens or 0 + cached_tokens = 0 + cached_token_detail = response.get("usage", {}).get("prompt_tokens_details") + if cached_token_detail: + cached_tokens = cached_token_detail.cached_tokens or 0 + if step: + await app.DATABASE.update_step( + task_id=step.task_id, + step_id=step.step_id, + organization_id=step.organization_id, + incremental_cost=llm_cost, + incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, + incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, + incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, + incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, + ) + if thought: + await app.DATABASE.update_thought( + thought_id=thought.observer_thought_id, + organization_id=thought.organization_id, + input_token_count=prompt_tokens if prompt_tokens > 0 else None, + output_token_count=completion_tokens if completion_tokens > 0 else None, + thought_cost=llm_cost, + reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, + cached_token_count=cached_tokens if cached_tokens > 0 else None, + ) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix) await app.ARTIFACT_MANAGER.create_llm_artifact( data=json.dumps(parsed_response, indent=2).encode("utf-8"), @@ -267,6 +266,9 @@ class LLMAPIHandlerFactory: ) # Track LLM API handler duration, token counts, and cost + organization_id = organization_id or ( + step.organization_id if step else (thought.organization_id if thought else None) + ) duration_seconds = time.time() - start_time LOG.info( "LLM API handler duration metrics", @@ -276,7 +278,7 @@ class LLMAPIHandlerFactory: duration_seconds=duration_seconds, step_id=step.step_id if step else None, thought_id=thought.observer_thought_id if thought else None, - organization_id=step.organization_id if step else (thought.organization_id if thought else None), + organization_id=organization_id, input_tokens=prompt_tokens if prompt_tokens > 0 else None, output_tokens=completion_tokens if completion_tokens > 0 else None, reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, @@ -310,6 +312,7 @@ class LLMAPIHandlerFactory: ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, + organization_id: str | None = None, ) -> dict[str, Any]: start_time = time.time() active_parameters = base_parameters or {} @@ -421,45 +424,44 @@ class LLMAPIHandlerFactory: cached_tokens = 0 completion_token_detail = None cached_token_detail = None - llm_cost = 0 - if step or thought: - try: - # FIXME: volcengine doesn't support litellm cost calculation. - llm_cost = litellm.completion_cost(completion_response=response) - except Exception as e: - LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) - llm_cost = 0 - prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = response.get("usage", {}).get("completion_tokens", 0) - reasoning_tokens = 0 - completion_token_detail = response.get("usage", {}).get("completion_tokens_details") - if completion_token_detail: - reasoning_tokens = completion_token_detail.reasoning_tokens or 0 - cached_tokens = 0 - cached_token_detail = response.get("usage", {}).get("prompt_tokens_details") - if cached_token_detail: - cached_tokens = cached_token_detail.cached_tokens or 0 - if step: - await app.DATABASE.update_step( - task_id=step.task_id, - step_id=step.step_id, - organization_id=step.organization_id, - incremental_cost=llm_cost, - incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, - incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, - incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, - incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, - ) - if thought: - await app.DATABASE.update_thought( - thought_id=thought.observer_thought_id, - organization_id=thought.organization_id, - input_token_count=prompt_tokens if prompt_tokens > 0 else None, - output_token_count=completion_tokens if completion_tokens > 0 else None, - reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, - cached_token_count=cached_tokens if cached_tokens > 0 else None, - thought_cost=llm_cost, - ) + try: + # FIXME: volcengine doesn't support litellm cost calculation. + llm_cost = litellm.completion_cost(completion_response=response) + except Exception as e: + LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True) + llm_cost = 0 + prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0) + completion_tokens = response.get("usage", {}).get("completion_tokens", 0) + reasoning_tokens = 0 + completion_token_detail = response.get("usage", {}).get("completion_tokens_details") + if completion_token_detail: + reasoning_tokens = completion_token_detail.reasoning_tokens or 0 + cached_tokens = 0 + cached_token_detail = response.get("usage", {}).get("prompt_tokens_details") + if cached_token_detail: + cached_tokens = cached_token_detail.cached_tokens or 0 + + if step: + await app.DATABASE.update_step( + task_id=step.task_id, + step_id=step.step_id, + organization_id=step.organization_id, + incremental_cost=llm_cost, + incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, + incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, + incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, + incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, + ) + if thought: + await app.DATABASE.update_thought( + thought_id=thought.observer_thought_id, + organization_id=thought.organization_id, + input_token_count=prompt_tokens if prompt_tokens > 0 else None, + output_token_count=completion_tokens if completion_tokens > 0 else None, + reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, + cached_token_count=cached_tokens if cached_tokens > 0 else None, + thought_cost=llm_cost, + ) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix) await app.ARTIFACT_MANAGER.create_llm_artifact( data=json.dumps(parsed_response, indent=2).encode("utf-8"), @@ -484,6 +486,9 @@ class LLMAPIHandlerFactory: ) # Track LLM API handler duration, token counts, and cost + organization_id = organization_id or ( + step.organization_id if step else (thought.organization_id if thought else None) + ) duration_seconds = time.time() - start_time LOG.info( "LLM API handler duration metrics", @@ -493,7 +498,7 @@ class LLMAPIHandlerFactory: duration_seconds=duration_seconds, step_id=step.step_id if step else None, thought_id=thought.observer_thought_id if thought else None, - organization_id=step.organization_id if step else (thought.organization_id if thought else None), + organization_id=organization_id, input_tokens=prompt_tokens if prompt_tokens > 0 else None, output_tokens=completion_tokens if completion_tokens > 0 else None, reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, @@ -572,6 +577,7 @@ class LLMCaller: use_message_history: bool = False, raw_response: bool = False, window_dimension: Resolution | None = None, + organization_id: str | None = None, **extra_parameters: Any, ) -> dict[str, Any]: start_time = time.perf_counter() @@ -702,30 +708,32 @@ class LLMCaller: ai_suggestion=ai_suggestion, ) - call_stats = None - if step or thought: - call_stats = await self.get_call_stats(response) - if step: - await app.DATABASE.update_step( - task_id=step.task_id, - step_id=step.step_id, - organization_id=step.organization_id, - incremental_cost=call_stats.llm_cost, - incremental_input_tokens=call_stats.input_tokens, - incremental_output_tokens=call_stats.output_tokens, - incremental_reasoning_tokens=call_stats.reasoning_tokens, - incremental_cached_tokens=call_stats.cached_tokens, - ) - if thought: - await app.DATABASE.update_thought( - thought_id=thought.observer_thought_id, - organization_id=thought.organization_id, - input_token_count=call_stats.input_tokens, - output_token_count=call_stats.output_tokens, - reasoning_token_count=call_stats.reasoning_tokens, - cached_token_count=call_stats.cached_tokens, - thought_cost=call_stats.llm_cost, - ) + call_stats = await self.get_call_stats(response) + if step: + await app.DATABASE.update_step( + task_id=step.task_id, + step_id=step.step_id, + organization_id=step.organization_id, + incremental_cost=call_stats.llm_cost, + incremental_input_tokens=call_stats.input_tokens, + incremental_output_tokens=call_stats.output_tokens, + incremental_reasoning_tokens=call_stats.reasoning_tokens, + incremental_cached_tokens=call_stats.cached_tokens, + ) + if thought: + await app.DATABASE.update_thought( + thought_id=thought.observer_thought_id, + organization_id=thought.organization_id, + input_token_count=call_stats.input_tokens, + output_token_count=call_stats.output_tokens, + reasoning_token_count=call_stats.reasoning_tokens, + cached_token_count=call_stats.cached_tokens, + thought_cost=call_stats.llm_cost, + ) + + organization_id = organization_id or ( + step.organization_id if step else (thought.organization_id if thought else None) + ) # Track LLM API handler duration, token counts, and cost duration_seconds = time.perf_counter() - start_time LOG.info( @@ -736,7 +744,7 @@ class LLMCaller: duration_seconds=duration_seconds, step_id=step.step_id if step else None, thought_id=thought.observer_thought_id if thought else None, - organization_id=step.organization_id if step else (thought.organization_id if thought else None), + organization_id=organization_id, input_tokens=call_stats.input_tokens if call_stats and call_stats.input_tokens else None, output_tokens=call_stats.output_tokens if call_stats and call_stats.output_tokens else None, reasoning_tokens=call_stats.reasoning_tokens if call_stats and call_stats.reasoning_tokens else None, @@ -920,7 +928,7 @@ class LLMCaller: try: llm_cost = litellm.completion_cost(completion_response=response) except Exception as e: - LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) + LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True) llm_cost = 0 input_tokens = response.get("usage", {}).get("prompt_tokens", 0) output_tokens = response.get("usage", {}).get("completion_tokens", 0) diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py index 1e4ff6ad..b059a4c7 100644 --- a/skyvern/forge/sdk/api/llm/models.py +++ b/skyvern/forge/sdk/api/llm/models.py @@ -95,6 +95,7 @@ class LLMAPIHandler(Protocol): ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, + organization_id: str | None = None, ) -> Awaitable[dict[str, Any]]: ... @@ -107,5 +108,6 @@ async def dummy_llm_api_handler( ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, + organization_id: str | None = None, ) -> dict[str, Any]: raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.") diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index d0b8dc87..587eb3d5 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1863,7 +1863,10 @@ async def suggest( ) llm_response = await app.LLM_API_HANDLER( - prompt=llm_prompt, ai_suggestion=new_ai_suggestion, prompt_name="suggest-data-schema" + prompt=llm_prompt, + ai_suggestion=new_ai_suggestion, + prompt_name="suggest-data-schema", + organization_id=current_org.organization_id, ) parsed_ai_suggestion = AISuggestionBase.model_validate(llm_response) diff --git a/skyvern/forge/sdk/routes/credentials.py b/skyvern/forge/sdk/routes/credentials.py index 35922334..16167599 100644 --- a/skyvern/forge/sdk/routes/credentials.py +++ b/skyvern/forge/sdk/routes/credentials.py @@ -32,9 +32,11 @@ from skyvern.forge.sdk.services.bitwarden import BitwardenService LOG = structlog.get_logger() -async def parse_totp_code(content: str) -> str | None: +async def parse_totp_code(content: str, organization_id: str) -> str | None: prompt = prompt_engine.load_prompt("parse-verification-code", content=content) - code_resp = await app.SECONDARY_LLM_API_HANDLER(prompt=prompt, prompt_name="parse-verification-code") + code_resp = await app.SECONDARY_LLM_API_HANDLER( + prompt=prompt, prompt_name="parse-verification-code", organization_id=organization_id + ) LOG.info("TOTP Code Parser Response", code_resp=code_resp) return code_resp.get("code", None) @@ -58,7 +60,8 @@ async def parse_totp_code(content: str) -> str | None: include_in_schema=False, ) async def send_totp_code( - data: TOTPCodeCreate, curr_org: Organization = Depends(org_auth_service.get_current_org) + data: TOTPCodeCreate, + curr_org: Organization = Depends(org_auth_service.get_current_org), ) -> TOTPCode: LOG.info( "Saving TOTP code", @@ -72,7 +75,7 @@ async def send_totp_code( code: str | None = content # We assume the user is sending the code directly when the length of code is less than or equal to 10 if len(content) > 10: - code = await parse_totp_code(content) + code = await parse_totp_code(content, curr_org.organization_id) if not code: LOG.error( "Failed to parse totp code", diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 46379d83..03791a14 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -696,6 +696,7 @@ class WorkflowService: metadata_response = await app.LLM_API_HANDLER( prompt=metadata_prompt, prompt_name="conversational_ui_goal", + organization_id=organization.organization_id, ) block_label: str = metadata_response.get("block_label", DEFAULT_FIRST_BLOCK_LABEL) diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index fc5b5cdc..3cf24409 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -1360,6 +1360,7 @@ async def generate_text( json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER( prompt=script_generation_input_text_prompt, prompt_name="script-generation-input-text-generatiion", + organization_id=context.organization_id, ) new_text = json_response.get("answer", new_text) except Exception: diff --git a/skyvern/services/task_v1_service.py b/skyvern/services/task_v1_service.py index b975211d..b38a8af0 100644 --- a/skyvern/services/task_v1_service.py +++ b/skyvern/services/task_v1_service.py @@ -47,7 +47,9 @@ async def generate_task(user_prompt: str, organization: Organization) -> TaskGen llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=user_prompt) try: - llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task") + llm_response = await app.LLM_API_HANDLER( + prompt=llm_prompt, prompt_name="generate-task", organization_id=organization.organization_id + ) parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response) # generate a TaskGenerationModel diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py index d371fb62..8de87f3c 100644 --- a/skyvern/services/task_v2_service.py +++ b/skyvern/services/task_v2_service.py @@ -1284,6 +1284,7 @@ async def _generate_extraction_task( generate_extraction_task_prompt, task_v2=task_v2, prompt_name="task_v2_generate_extraction_task", + organization_id=task_v2.organization_id, ) LOG.info("Data extraction response", data_extraction_response=generate_extraction_task_response) diff --git a/skyvern/webeye/actions/handler.py b/skyvern/webeye/actions/handler.py index 194c4857..bcd76b95 100644 --- a/skyvern/webeye/actions/handler.py +++ b/skyvern/webeye/actions/handler.py @@ -3800,7 +3800,7 @@ async def _get_input_or_select_context( starter=element_handle, frame=skyvern_element.get_frame_id(), ) - clean_up_func = app.AGENT_FUNCTION.cleanup_element_tree_factory() + clean_up_func = app.AGENT_FUNCTION.cleanup_element_tree_factory(step=step) element_tree = await clean_up_func(skyvern_element.get_frame(), "", copy.deepcopy(element_tree)) element_tree_trimmed = trim_element_tree(copy.deepcopy(element_tree)) element_tree_builder = ScrapedPage( diff --git a/skyvern/webeye/actions/parse_actions.py b/skyvern/webeye/actions/parse_actions.py index b6c53764..23ec06ca 100644 --- a/skyvern/webeye/actions/parse_actions.py +++ b/skyvern/webeye/actions/parse_actions.py @@ -741,6 +741,7 @@ async def generate_cua_fallback_actions( action_response = await app.LLM_API_HANDLER( prompt=fallback_action_prompt, prompt_name="cua-fallback-action", + step=step, ) LOG.info("Fallback action response", action_response=action_response) skyvern_action_type = action_response.get("action")