create cruise related artifact in cruise api (#1355)

Author: Shuchang Zheng
Date:   2024-12-08 21:17:58 -08:00
Committed by: GitHub
Parent: bda119027e
Commit: 5842bfc1fd
6 changed files with 209 additions and 75 deletions
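
The change consolidates artifact creation in the two LLM API handlers: the old inline `if step:` blocks that wrote LLM_PROMPT, SCREENSHOT_LLM, LLM_REQUEST, LLM_RESPONSE, and LLM_RESPONSE_PARSED artifacts only for workflow steps are replaced by a single `create_llm_artifact` call that can also attach those artifacts to an observer cruise or an observer thought. The manager-side helper itself lives in one of the changed files not shown below; the sketch that follows is only a guess at its shape, with the parameter names taken from the call sites in this diff and the body assumed:

# Hypothetical sketch of the new ArtifactManager.create_llm_artifact helper.
# Parameter names come from the call sites in this diff; the method body and
# the owner-dispatch order are assumptions, not code from this commit.
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought


class ArtifactManagerSketch:
    async def create_artifact(self, **kwargs: object) -> None:
        """Stand-in for the existing step-scoped API used by the old code."""

    async def create_llm_artifact(
        self,
        data: bytes,
        artifact_type: ArtifactType,
        screenshots: list[bytes] | None = None,
        step: Step | None = None,
        observer_cruise: ObserverCruise | None = None,
        observer_thought: ObserverThought | None = None,
    ) -> None:
        # The step path mirrors the inline code this commit removes.
        if step:
            await self.create_artifact(step=step, artifact_type=artifact_type, data=data)
            for screenshot in screenshots or []:
                await self.create_artifact(
                    step=step, artifact_type=ArtifactType.SCREENSHOT_LLM, data=screenshot
                )
        elif observer_thought:
            ...  # assumed: persist the artifact against the thought
        elif observer_cruise:
            ...  # assumed: persist the artifact against the cruise
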

View File

@@ -20,6 +20,7 @@ from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouter
 from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.models import Step
+from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought
 
 LOG = structlog.get_logger()
@@ -58,6 +59,8 @@ class LLMAPIHandlerFactory:
         async def llm_api_handler_with_router_and_fallback(
             prompt: str,
             step: Step | None = None,
+            observer_cruise: ObserverCruise | None = None,
+            observer_thought: ObserverThought | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
         ) -> dict[str, Any]:
@@ -76,32 +79,29 @@ class LLMAPIHandlerFactory:
             if parameters is None:
                 parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_PROMPT,
-                    data=prompt.encode("utf-8"),
-                )
-                for screenshot in screenshots or []:
-                    await app.ARTIFACT_MANAGER.create_artifact(
-                        step=step,
-                        artifact_type=ArtifactType.SCREENSHOT_LLM,
-                        data=screenshot,
-                    )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=prompt.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_PROMPT,
+                screenshots=screenshots,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_REQUEST,
-                    data=json.dumps(
-                        {
-                            "model": llm_key,
-                            "messages": messages,
-                            **parameters,
-                        }
-                    ).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(
+                    {
+                        "model": llm_key,
+                        "messages": messages,
+                        **parameters,
+                    }
+                ).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_REQUEST,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             try:
                 response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
                 LOG.info("LLM API call successful", llm_key=llm_key, model=llm_config.model_name)
@@ -122,12 +122,14 @@ class LLMAPIHandlerFactory:
                 )
                 raise LLMProviderError(llm_key) from e
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE,
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=response.model_dump_json(indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             llm_cost = litellm.completion_cost(completion_response=response)
             prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
             completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
@@ -140,12 +142,13 @@ class LLMAPIHandlerFactory:
                 incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
             )
             parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             return parsed_response
 
         return llm_api_handler_with_router_and_fallback
@@ -162,6 +165,8 @@ class LLMAPIHandlerFactory:
         async def llm_api_handler(
             prompt: str,
             step: Step | None = None,
+            observer_cruise: ObserverCruise | None = None,
+            observer_thought: ObserverThought | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
         ) -> dict[str, Any]:
@@ -173,37 +178,33 @@ class LLMAPIHandlerFactory:
             if llm_config.litellm_params:  # type: ignore
                 active_parameters.update(llm_config.litellm_params)  # type: ignore
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_PROMPT,
-                    data=prompt.encode("utf-8"),
-                )
-                for screenshot in screenshots or []:
-                    await app.ARTIFACT_MANAGER.create_artifact(
-                        step=step,
-                        artifact_type=ArtifactType.SCREENSHOT_LLM,
-                        data=screenshot,
-                    )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=prompt.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_PROMPT,
+                screenshots=screenshots,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             # TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
             if not llm_config.supports_vision:
                 screenshots = None
             messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_REQUEST,
-                    data=json.dumps(
-                        {
-                            "model": llm_config.model_name,
-                            "messages": messages,
-                            # we're not using active_parameters here because it may contain sensitive information
-                            **parameters,
-                        }
-                    ).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(
+                    {
+                        "model": llm_config.model_name,
+                        "messages": messages,
+                        # we're not using active_parameters here because it may contain sensitive information
+                        **parameters,
+                    }
+                ).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_REQUEST,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             t_llm_request = time.perf_counter()
             try:
                 # TODO (kerem): add a timeout to this call
@@ -231,12 +232,16 @@ class LLMAPIHandlerFactory:
             except Exception as e:
                 LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
                 raise LLMProviderError(llm_key) from e
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE,
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=response.model_dump_json(indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             llm_cost = litellm.completion_cost(completion_response=response)
             prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
             completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
@@ -249,12 +254,13 @@ class LLMAPIHandlerFactory:
                 incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
             )
             parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             return parsed_response
 
         return llm_api_handler
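
With both handlers updated, callers that run outside a workflow step can record the same artifact trail by passing the observer objects through. A minimal illustrative call site, assuming a handler already obtained from the factory (the wrapper function and variable names here are examples, not code from the commit):

# Illustrative only: exercises the new keyword arguments added above.
from typing import Any

from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought


async def run_cruise_prompt(
    llm_handler: LLMAPIHandler,
    prompt: str,
    observer_cruise: ObserverCruise,
    observer_thought: ObserverThought | None = None,
) -> dict[str, Any]:
    # No step is passed, so the LLM_PROMPT/LLM_REQUEST/LLM_RESPONSE
    # artifacts are attached to the cruise (and thought) instead.
    return await llm_handler(
        prompt,
        observer_cruise=observer_cruise,
        observer_thought=observer_thought,
    )
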

View File

@@ -4,6 +4,7 @@ from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict
 from litellm import AllowedFailsPolicy
 from skyvern.forge.sdk.models import Step
+from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought
 from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -78,6 +79,8 @@ class LLMAPIHandler(Protocol):
         self,
         prompt: str,
         step: Step | None = None,
+        observer_cruise: ObserverCruise | None = None,
+        observer_thought: ObserverThought | None = None,
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
     ) -> Awaitable[dict[str, Any]]: ...
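
Because `LLMAPIHandler` is a `typing.Protocol`, widening `__call__` here is what keeps the factory-built handlers and any hand-written ones structurally compatible. For instance, a stub like the following (illustrative only, e.g. for tests; not part of the commit) still satisfies the updated protocol:

# Illustrative conformance check against the updated protocol.
from typing import Any

from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought


async def stub_handler(
    prompt: str,
    step: Step | None = None,
    observer_cruise: ObserverCruise | None = None,
    observer_thought: ObserverThought | None = None,
    screenshots: list[bytes] | None = None,
    parameters: dict[str, Any] | None = None,
) -> dict[str, Any]:
    # Calling an async function yields a coroutine, which satisfies
    # the protocol's Awaitable[dict[str, Any]] return type.
    return {"action": "noop"}


handler: LLMAPIHandler = stub_handler  # accepted by a static type checker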