adopt ruff as the replacement for python black (#332)

This commit is contained in:
Shuchang Zheng
2024-05-16 18:20:11 -07:00
committed by GitHub
parent 7a2be7e355
commit 2466897158
44 changed files with 1081 additions and 321 deletions

View File

@@ -93,7 +93,7 @@ class WorkflowRunContext:
assume it's an actual parameter value and return it.
"""
if type(secret_id_or_value) is str:
if isinstance(secret_id_or_value, str):
return self.secrets.get(secret_id_or_value)
return None
@@ -149,7 +149,7 @@ class WorkflowRunContext:
url = self.values[parameter.url_parameter_key]
else:
LOG.error(f"URL parameter {parameter.url_parameter_key} not found or has no value")
raise ValueError(f"URL parameter for Bitwarden login credentials not found or has no value")
raise ValueError("URL parameter for Bitwarden login credentials not found or has no value")
try:
secret_credentials = BitwardenService.get_secret_value_from_url(
@@ -224,7 +224,9 @@ class WorkflowRunContext:
await self.set_parameter_values_for_output_parameter_dependent_blocks(parameter, value)
async def set_parameter_values_for_output_parameter_dependent_blocks(
self, output_parameter: OutputParameter, value: dict[str, Any] | list | str | None
self,
output_parameter: OutputParameter,
value: dict[str, Any] | list | str | None,
) -> None:
for key, parameter in self.parameters.items():
if (
@@ -268,7 +270,7 @@ class WorkflowRunContext:
isinstance(x, ContextParameter),
# This makes sure that ContextParameters with a ContextParameter source are processed after all other
# ContextParameters
isinstance(x.source, ContextParameter) if isinstance(x, ContextParameter) else False,
(isinstance(x.source, ContextParameter) if isinstance(x, ContextParameter) else False),
isinstance(x, BitwardenLoginCredentialParameter),
)
)

View File

@@ -81,16 +81,20 @@ class Block(BaseModel, abc.ABC):
value=value,
)
LOG.info(
f"Registered output parameter value",
"Registered output parameter value",
output_parameter_id=self.output_parameter.output_parameter_id,
workflow_run_id=workflow_run_id,
)
def build_block_result(
self, success: bool, output_parameter_value: dict[str, Any] | list | str | None = None
self,
success: bool,
output_parameter_value: dict[str, Any] | list | str | None = None,
) -> BlockResult:
return BlockResult(
success=success, output_parameter=self.output_parameter, output_parameter_value=output_parameter_value
success=success,
output_parameter=self.output_parameter,
output_parameter_value=output_parameter_value,
)
@classmethod
@@ -236,11 +240,14 @@ class TaskBlock(Block):
workflow_run=workflow_run, url=self.url
)
if not browser_state.page:
LOG.error("BrowserState has no page", workflow_run_id=workflow_run.workflow_run_id)
LOG.error(
"BrowserState has no page",
workflow_run_id=workflow_run.workflow_run_id,
)
raise MissingBrowserStatePage(workflow_run_id=workflow_run.workflow_run_id)
LOG.info(
f"Navigating to page",
"Navigating to page",
url=self.url,
workflow_run_id=workflow_run_id,
task_id=task.task_id,
@@ -253,7 +260,12 @@ class TaskBlock(Block):
await browser_state.page.goto(self.url, timeout=settings.BROWSER_LOADING_TIMEOUT_MS)
try:
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
await app.agent.execute_step(
organization=organization,
task=task,
step=step,
workflow_run=workflow_run,
)
except Exception as e:
# Make sure the task is marked as failed in the database before raising the exception
await app.DATABASE.update_task(
@@ -273,7 +285,7 @@ class TaskBlock(Block):
if updated_task.status == TaskStatus.completed or updated_task.status == TaskStatus.terminated:
LOG.info(
f"Task completed",
"Task completed",
task_id=updated_task.task_id,
task_status=updated_task.status,
workflow_run_id=workflow_run_id,
@@ -400,7 +412,7 @@ class ForLoopBlock(Block):
)
if not loop_over_values or len(loop_over_values) == 0:
LOG.info(
f"No loop_over values found",
"No loop_over values found",
block_type=self.block_type,
workflow_run_id=workflow_run_id,
num_loop_over_values=len(loop_over_values),
@@ -519,7 +531,11 @@ class TextPromptBlock(Block):
+ json.dumps(self.json_schema, indent=2)
+ "\n```\n\n"
)
LOG.info("TextPromptBlock: Sending prompt to LLM", prompt=prompt, llm_key=self.llm_key)
LOG.info(
"TextPromptBlock: Sending prompt to LLM",
prompt=prompt,
llm_key=self.llm_key,
)
response = await llm_api_handler(prompt=prompt)
LOG.info("TextPromptBlock: Received response from LLM", response=response)
return response
@@ -692,7 +708,12 @@ class SendEmailBlock(Block):
workflow_run_id: str,
) -> list[PARAMETER_TYPE]:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
parameters = [self.smtp_host, self.smtp_port, self.smtp_username, self.smtp_password]
parameters = [
self.smtp_host,
self.smtp_port,
self.smtp_username,
self.smtp_password,
]
if self.file_attachments:
for file_path in self.file_attachments:
@@ -732,7 +753,12 @@ class SendEmailBlock(Block):
if email_config_problems:
raise InvalidEmailClientConfiguration(email_config_problems)
return smtp_host_value, smtp_port_value, smtp_username_value, smtp_password_value
return (
smtp_host_value,
smtp_port_value,
smtp_username_value,
smtp_password_value,
)
def _get_file_paths(self, workflow_run_context: WorkflowRunContext, workflow_run_id: str) -> list[str]:
file_paths = []
@@ -846,7 +872,12 @@ class SendEmailBlock(Block):
subtype=subtype,
)
with open(path, "rb") as fp:
msg.add_attachment(fp.read(), maintype=maintype, subtype=subtype, filename=attachment_filename)
msg.add_attachment(
fp.read(),
maintype=maintype,
subtype=subtype,
filename=attachment_filename,
)
finally:
if path:
os.unlink(path)
@@ -884,6 +915,12 @@ class SendEmailBlock(Block):
BlockSubclasses = Union[
ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock, DownloadToS3Block, UploadToS3Block, SendEmailBlock
ForLoopBlock,
TaskBlock,
CodeBlock,
TextPromptBlock,
DownloadToS3Block,
UploadToS3Block,
SendEmailBlock,
]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -114,6 +114,10 @@ class OutputParameter(Parameter):
ParameterSubclasses = Union[
WorkflowParameter, ContextParameter, AWSSecretParameter, BitwardenLoginCredentialParameter, OutputParameter
WorkflowParameter,
ContextParameter,
AWSSecretParameter,
BitwardenLoginCredentialParameter,
OutputParameter,
]
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]

View File

@@ -166,7 +166,10 @@ class WorkflowService:
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
workflow_output_parameters = await self.get_workflow_output_parameters(workflow_id=workflow.workflow_id)
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(
workflow_run_id, wp_wps_tuples, workflow_output_parameters, context_parameters
workflow_run_id,
wp_wps_tuples,
workflow_output_parameters,
context_parameters,
)
# Execute workflow blocks
blocks = workflow.workflow_definition.blocks
@@ -203,7 +206,11 @@ class WorkflowService:
)
else:
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
await self.send_workflow_response(workflow=workflow, workflow_run=workflow_run, api_key=api_key)
await self.send_workflow_response(
workflow=workflow,
workflow_run=workflow_run,
api_key=api_key,
)
return workflow_run
except Exception:
@@ -224,13 +231,21 @@ class WorkflowService:
# Create a mapping of status to (action, log_func, log_message)
status_action_mapping = {
TaskStatus.running: (None, LOG.error, "has running tasks, this should not happen"),
TaskStatus.running: (
None,
LOG.error,
"has running tasks, this should not happen",
),
TaskStatus.terminated: (
self.mark_workflow_run_as_terminated,
LOG.warning,
"has terminated tasks, marking as terminated",
),
TaskStatus.failed: (self.mark_workflow_run_as_failed, LOG.warning, "has failed tasks, marking as failed"),
TaskStatus.failed: (
self.mark_workflow_run_as_failed,
LOG.warning,
"has failed tasks, marking as failed",
),
TaskStatus.completed: (
self.mark_workflow_run_as_completed,
LOG.info,
@@ -333,7 +348,7 @@ class WorkflowService:
title=title,
organization_id=organization_id,
description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
workflow_definition=(workflow_definition.model_dump() if workflow_definition else None),
)
async def delete_workflow_by_permanent_id(
@@ -529,7 +544,10 @@ class WorkflowService:
for task in workflow_run_tasks[::-1]:
screenshot_artifact = await app.DATABASE.get_latest_artifact(
task_id=task.task_id,
artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
artifact_types=[
ArtifactType.SCREENSHOT_ACTION,
ArtifactType.SCREENSHOT_FINAL,
],
organization_id=organization_id,
)
if screenshot_artifact:
@@ -541,17 +559,19 @@ class WorkflowService:
recording_url = None
recording_artifact = await app.DATABASE.get_artifact_for_workflow_run(
workflow_run_id=workflow_run_id, artifact_type=ArtifactType.RECORDING, organization_id=organization_id
workflow_run_id=workflow_run_id,
artifact_type=ArtifactType.RECORDING,
organization_id=organization_id,
)
if recording_artifact:
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
output_parameter_tuples: list[tuple[OutputParameter, WorkflowRunOutputParameter]] = (
await self.get_output_parameter_workflow_run_output_parameter_tuples(
workflow_id=workflow_id, workflow_run_id=workflow_run_id
)
output_parameter_tuples: list[
tuple[OutputParameter, WorkflowRunOutputParameter]
] = await self.get_output_parameter_workflow_run_output_parameter_tuples(
workflow_id=workflow_id, workflow_run_id=workflow_run_id
)
if output_parameter_tuples:
outputs = {output_parameter.key: output.value for output_parameter, output in output_parameter_tuples}
@@ -587,7 +607,9 @@ class WorkflowService:
tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
all_workflow_task_ids = [task.task_id for task in tasks]
browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
workflow_run.workflow_run_id, all_workflow_task_ids, close_browser_on_completion
workflow_run.workflow_run_id,
all_workflow_task_ids,
close_browser_on_completion,
)
if browser_state:
await self.persist_video_data(browser_state, workflow, workflow_run)
@@ -600,7 +622,10 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow.organization_id,
)
LOG.info("Built workflow run status response", workflow_run_status_response=workflow_run_status_response)
LOG.info(
"Built workflow run status response",
workflow_run_status_response=workflow_run_status_response,
)
if not workflow_run.webhook_callback_url:
LOG.warning(
@@ -661,7 +686,8 @@ class WorkflowService:
)
except Exception as e:
raise FailedToSendWebhook(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
) from e
async def persist_video_data(
@@ -681,10 +707,16 @@ class WorkflowService:
)
async def persist_har_data(
self, browser_state: BrowserState, last_step: Step, workflow: Workflow, workflow_run: WorkflowRun
self,
browser_state: BrowserState,
last_step: Step,
workflow: Workflow,
workflow_run: WorkflowRun,
) -> None:
har_data = await app.BROWSER_MANAGER.get_har_data(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
browser_state=browser_state,
)
if har_data:
await app.ARTIFACT_MANAGER.create_artifact(
@@ -703,7 +735,11 @@ class WorkflowService:
await app.ARTIFACT_MANAGER.create_artifact(step=last_step, artifact_type=ArtifactType.TRACE, path=trace_path)
async def persist_debug_artifacts(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
self,
browser_state: BrowserState,
last_task: Task,
workflow: Workflow,
workflow_run: WorkflowRun,
) -> None:
last_step = await app.DATABASE.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
@@ -720,7 +756,11 @@ class WorkflowService:
request: WorkflowCreateYAMLRequest,
workflow_permanent_id: str | None = None,
) -> Workflow:
LOG.info("Creating workflow from request", organization_id=organization_id, title=request.title)
LOG.info(
"Creating workflow from request",
organization_id=organization_id,
title=request.title,
)
try:
if workflow_permanent_id:
existing_latest_workflow = await self.get_workflow_by_permanent_id(
@@ -769,7 +809,8 @@ class WorkflowService:
# Create output parameters for all blocks
block_output_parameters = await WorkflowService._create_all_output_parameters_for_workflow(
workflow_id=workflow.workflow_id, block_yamls=request.workflow_definition.blocks
workflow_id=workflow.workflow_id,
block_yamls=request.workflow_definition.blocks,
)
for block_output_parameter in block_output_parameters.values():
parameters[block_output_parameter.key] = block_output_parameter
@@ -822,7 +863,8 @@ class WorkflowService:
for context_parameter in context_parameter_yamls:
if context_parameter.source_parameter_key not in parameters:
raise ContextParameterSourceNotDefined(
context_parameter_key=context_parameter.key, source_key=context_parameter.source_parameter_key
context_parameter_key=context_parameter.key,
source_key=context_parameter.source_parameter_key,
)
if context_parameter.key in parameters:
@@ -901,7 +943,9 @@ class WorkflowService:
@staticmethod
async def block_yaml_to_block(
workflow: Workflow, block_yaml: BLOCK_YAML_TYPES, parameters: dict[str, Parameter]
workflow: Workflow,
block_yaml: BLOCK_YAML_TYPES,
parameters: dict[str, Parameter],
) -> BlockTypeVar:
output_parameter = parameters[f"{block_yaml.label}_output"]
if block_yaml.block_type == BlockType.TASK: