complete_on_downloads for task block (#403)

This commit is contained in:
Kerem Yilmaz
2024-06-02 23:24:30 -07:00
committed by GitHub
parent 343937e12c
commit f1d5a3a687
9 changed files with 118 additions and 30 deletions

View File

@@ -23,6 +23,7 @@ from skyvern.exceptions import (
from skyvern.forge import app
from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.files import get_number_of_files_in_directory, get_path_for_workflow_download_directory
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
@@ -195,10 +196,18 @@ class ForgeAgent:
api_key: str | None = None,
workflow_run: WorkflowRun | None = None,
close_browser_on_completion: bool = True,
# If complete_on_download is True and there is a workflow run, the task will be marked as completed
# if a download happens during the step execution.
complete_on_download: bool = False,
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
next_step: Step | None = None
detailed_output: DetailedAgentStepOutput | None = None
num_files_before = 0
try:
if task.workflow_run_id:
num_files_before = get_number_of_files_in_directory(
get_path_for_workflow_download_directory(task.workflow_run_id)
)
# Check some conditions before executing the step, throw an exception if the step can't be executed
await app.AGENT_FUNCTION.validate_step_execution(task, step)
(
@@ -214,6 +223,30 @@ class ForgeAgent:
task = await self.update_task_errors_from_detailed_output(task, detailed_output)
retry = False
if complete_on_download and task.workflow_run_id:
num_files_after = get_number_of_files_in_directory(
get_path_for_workflow_download_directory(task.workflow_run_id)
)
if num_files_after > num_files_before:
LOG.info(
"Task marked as completed due to download",
task_id=task.task_id,
num_files_before=num_files_before,
num_files_after=num_files_after,
)
last_step = await self.update_step(step, is_last=True)
completed_task = await self.update_task(
task,
status=TaskStatus.completed,
)
await self.send_task_response(
task=completed_task,
last_step=last_step,
api_key=api_key,
close_browser_on_completion=close_browser_on_completion,
)
return last_step, detailed_output, None
# If the step failed, mark the step as failed and retry
if step.status == StepStatus.failed:
maybe_next_step = await self.handle_failed_step(organization, task, step)
@@ -273,6 +306,7 @@ class ForgeAgent:
next_step,
api_key=api_key,
close_browser_on_completion=close_browser_on_completion,
complete_on_download=complete_on_download,
)
elif SettingsManager.get_settings().execute_all_steps() and next_step:
return await self.execute_step(
@@ -281,6 +315,7 @@ class ForgeAgent:
next_step,
api_key=api_key,
close_browser_on_completion=close_browser_on_completion,
complete_on_download=complete_on_download,
)
else:
LOG.info(

View File

@@ -68,3 +68,13 @@ def zip_files(files_path: str, zip_file_path: str) -> str:
def get_path_for_workflow_download_directory(workflow_run_id: str) -> Path:
return Path(f"{REPO_ROOT_DIR}/downloads/{workflow_run_id}/")
def get_number_of_files_in_directory(directory: Path, recursive: bool = False) -> int:
count = 0
for root, dirs, files in os.walk(directory):
if not recursive:
count += len(files)
break
count += len(files)
return count

View File

@@ -217,7 +217,7 @@ class WorkflowRunContext:
self, parameter: OutputParameter, value: dict[str, Any] | list | str | None
) -> None:
if parameter.key in self.values:
LOG.error(f"Output parameter {parameter.output_parameter_id} already has a registered value")
LOG.warning(f"Output parameter {parameter.output_parameter_id} already has a registered value")
return
self.values[parameter.key] = value

View File

@@ -71,6 +71,14 @@ class Block(BaseModel, abc.ABC):
workflow_run_id: str,
value: dict[str, Any] | list | str | None = None,
) -> None:
if workflow_run_context.has_value(self.output_parameter.key):
LOG.warning(
"Output parameter value already recorded",
output_parameter_id=self.output_parameter.output_parameter_id,
workflow_run_id=workflow_run_id,
)
return
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=value,
@@ -150,6 +158,7 @@ class TaskBlock(Block):
max_retries: int = 0
max_steps_per_run: int | None = None
parameters: list[PARAMETER_TYPE] = []
complete_on_download: bool = False
def get_all_parameters(
self,
@@ -265,6 +274,7 @@ class TaskBlock(Block):
task=task,
step=step,
workflow_run=workflow_run,
complete_on_download=self.complete_on_download,
)
except Exception as e:
# Make sure the task is marked as failed in the database before raising the exception

View File

@@ -87,6 +87,7 @@ class TaskBlockYAML(BlockYAML):
max_retries: int = 0
max_steps_per_run: int | None = None
parameter_keys: list[str] | None = None
complete_on_download: bool = False
class ForLoopBlockYAML(BlockYAML):

View File

@@ -969,6 +969,7 @@ class WorkflowService:
error_code_mapping=block_yaml.error_code_mapping,
max_steps_per_run=block_yaml.max_steps_per_run,
max_retries=block_yaml.max_retries,
complete_on_download=block_yaml.complete_on_download,
)
elif block_yaml.block_type == BlockType.FOR_LOOP:
loop_blocks = [