allow user-based generic prompt improval [sic] (#3965)
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
Original prompt:
|
||||
|
||||
```
|
||||
{{ prompt }}
|
||||
```
|
||||
{% if context %}
|
||||
|
||||
Additional context about the user's needs:
|
||||
```
|
||||
{{ context }}
|
||||
```
|
||||
{% endif %}
|
||||
|
||||
Can you improve the original prompt for an AI browser agent?
|
||||
|
||||
Respond ONLY with valid JSON in this format with no additional text before or after it:
|
||||
```json
|
||||
{
|
||||
"improved_prompt": str, // The improved version of the prompt
|
||||
}
|
||||
|
||||
Ensure that the "improved_prompt" contains liberal whitespace tokens for formatting, clarity, and legibility.
|
||||
@@ -3,6 +3,7 @@ from skyvern.forge.sdk.routes import browser_profiles # noqa: F401
|
||||
from skyvern.forge.sdk.routes import browser_sessions # noqa: F401
|
||||
from skyvern.forge.sdk.routes import credentials # noqa: F401
|
||||
from skyvern.forge.sdk.routes import debug_sessions # noqa: F401
|
||||
from skyvern.forge.sdk.routes import prompts # noqa: F401
|
||||
from skyvern.forge.sdk.routes import pylon # noqa: F401
|
||||
from skyvern.forge.sdk.routes import run_blocks # noqa: F401
|
||||
from skyvern.forge.sdk.routes import scripts # noqa: F401
|
||||
|
||||
103
skyvern/forge/sdk/routes/prompts.py
Normal file
103
skyvern/forge/sdk/routes/prompts.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Endpoints for prompt management.
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import Depends, HTTPException, Query, status
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||
from skyvern.forge.sdk.routes.routers import base_router
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.prompts import ImprovePromptRequest, ImprovePromptResponse
|
||||
from skyvern.forge.sdk.services import org_auth_service
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Constants:
|
||||
ImprovePromptUseCaseToTemplateMap = {
|
||||
"new_workflow": "improve-prompt-for-ai-browser-agent",
|
||||
"task_v2_prompt": "improve-prompt-for-ai-browser-agent",
|
||||
}
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/prompts/improve",
|
||||
tags=["Prompts"],
|
||||
description="Improve a prompt based on a specific use-case",
|
||||
summary="Improve prompt",
|
||||
)
|
||||
async def improve_prompt(
|
||||
request: ImprovePromptRequest,
|
||||
use_case: str = Query(..., alias="use-case", description="The use-case for prompt improvement"),
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> ImprovePromptResponse:
|
||||
"""
|
||||
Improve a prompt based on a specific use-case.
|
||||
"""
|
||||
if use_case not in Constants.ImprovePromptUseCaseToTemplateMap:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"'{use_case}' use-case is unsupported.",
|
||||
)
|
||||
|
||||
template_name = Constants.ImprovePromptUseCaseToTemplateMap[use_case]
|
||||
|
||||
llm_prompt = prompt_engine.load_prompt(
|
||||
context=request.context,
|
||||
prompt=request.prompt,
|
||||
template=template_name,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
"Improving prompt",
|
||||
use_case=use_case,
|
||||
organization_id=current_org.organization_id,
|
||||
prompt=request.prompt,
|
||||
llm_prompt=llm_prompt,
|
||||
)
|
||||
|
||||
try:
|
||||
llm_response = await app.LLM_API_HANDLER(
|
||||
prompt=llm_prompt,
|
||||
prompt_name=template_name,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
|
||||
if isinstance(llm_response, dict) and "output" in llm_response:
|
||||
output = llm_response["output"]
|
||||
else:
|
||||
output = llm_response
|
||||
|
||||
if not isinstance(output, dict):
|
||||
error = "LLM response is not valid JSON."
|
||||
output = ""
|
||||
elif "improved_prompt" not in output:
|
||||
error = "LLM response missing 'improved_prompt' field."
|
||||
output = ""
|
||||
else:
|
||||
error = None
|
||||
output = output["improved_prompt"]
|
||||
|
||||
response = ImprovePromptResponse(
|
||||
error=error,
|
||||
improved=output,
|
||||
original=request.prompt,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except LLMProviderError:
|
||||
LOG.error("Failed to improve prompt", use_case=use_case, exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to improve prompt. Please try again later.",
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.error("Unexpected error improving prompt", use_case=use_case, error=str(e), exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Failed to improve prompt: {str(e)}",
|
||||
)
|
||||
@@ -19,3 +19,14 @@ class CreateWorkflowFromPromptRequestV2(BaseModel):
|
||||
CreateFromPromptRequest = t.Annotated[
|
||||
t.Union[CreateWorkflowFromPromptRequestV1, CreateWorkflowFromPromptRequestV2], Field(discriminator="task_version")
|
||||
]
|
||||
|
||||
|
||||
class ImprovePromptRequest(BaseModel):
|
||||
context: t.Optional[str] = Field(None, description="Additional context about the user's needs")
|
||||
prompt: str = Field(..., min_length=1, description="The original prompt to improve")
|
||||
|
||||
|
||||
class ImprovePromptResponse(BaseModel):
|
||||
error: t.Optional[str] = Field(None, description="Error message if prompt improvement failed")
|
||||
improved: str = Field(..., description="The improved version of the prompt")
|
||||
original: str = Field(..., description="The original prompt provided for improvement")
|
||||
|
||||
Reference in New Issue
Block a user