adopt ruff as the replacement for python black (#332)

This commit is contained in:
Shuchang Zheng
2024-05-16 18:20:11 -07:00
committed by GitHub
parent 7a2be7e355
commit 2466897158
44 changed files with 1081 additions and 321 deletions

View File

@@ -44,7 +44,7 @@ class LLMAPIHandlerFactory:
),
num_retries=llm_config.num_retries,
retry_after=llm_config.retry_delay_seconds,
set_verbose=False if SettingsManager.get_settings().is_cloud_environment() else llm_config.set_verbose,
set_verbose=(False if SettingsManager.get_settings().is_cloud_environment() else llm_config.set_verbose),
enable_pre_call_checks=True,
)
main_model_group = llm_config.main_model_group
@@ -101,7 +101,11 @@ class LLMAPIHandlerFactory:
except openai.OpenAIError as e:
raise LLMProviderError(llm_key) from e
except Exception as e:
LOG.exception("LLM request failed unexpectedly", llm_key=llm_key, model=main_model_group)
LOG.exception(
"LLM request failed unexpectedly",
llm_key=llm_key,
model=main_model_group,
)
raise LLMProviderError(llm_key) from e
if step:

View File

@@ -58,35 +58,58 @@ if not any(
if SettingsManager.get_settings().ENABLE_OPENAI:
LLMConfigRegistry.register_config(
"OPENAI_GPT4_TURBO",
LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=False, add_assistant_prefix=False),
LLMConfig(
"gpt-4-turbo",
["OPENAI_API_KEY"],
supports_vision=False,
add_assistant_prefix=False,
),
)
LLMConfigRegistry.register_config(
"OPENAI_GPT4V", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False)
"OPENAI_GPT4V",
LLMConfig(
"gpt-4-turbo",
["OPENAI_API_KEY"],
supports_vision=True,
add_assistant_prefix=False,
),
)
if SettingsManager.get_settings().ENABLE_ANTHROPIC:
LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3",
LLMConfig(
"anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
"anthropic/claude-3-sonnet-20240229",
["ANTHROPIC_API_KEY"],
supports_vision=True,
add_assistant_prefix=True,
),
)
LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_OPUS",
LLMConfig(
"anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
"anthropic/claude-3-opus-20240229",
["ANTHROPIC_API_KEY"],
supports_vision=True,
add_assistant_prefix=True,
),
)
LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_SONNET",
LLMConfig(
"anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
"anthropic/claude-3-sonnet-20240229",
["ANTHROPIC_API_KEY"],
supports_vision=True,
add_assistant_prefix=True,
),
)
LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_HAIKU",
LLMConfig(
"anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
"anthropic/claude-3-haiku-20240307",
["ANTHROPIC_API_KEY"],
supports_vision=True,
add_assistant_prefix=True,
),
)
@@ -125,7 +148,12 @@ if SettingsManager.get_settings().ENABLE_AZURE:
"AZURE_OPENAI_GPT4V",
LLMConfig(
f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
[
"AZURE_DEPLOYMENT",
"AZURE_API_KEY",
"AZURE_API_BASE",
"AZURE_API_VERSION",
],
supports_vision=True,
add_assistant_prefix=False,
),

View File

@@ -33,7 +33,10 @@ async def llm_messages_builder(
)
# Anthropic models seems to struggle to always output a valid json object so we need to prefill the response to force it:
if add_assistant_prefix:
return [{"role": "user", "content": messages}, {"role": "assistant", "content": "{"}]
return [
{"role": "user", "content": messages},
{"role": "assistant", "content": "{"},
]
return [{"role": "user", "content": messages}]

View File

@@ -17,7 +17,11 @@ class ArtifactManager:
upload_aiotasks_map: dict[str, list[asyncio.Task[None]]] = defaultdict(list)
async def create_artifact(
self, step: Step, artifact_type: ArtifactType, data: bytes | None = None, path: str | None = None
self,
step: Step,
artifact_type: ArtifactType,
data: bytes | None = None,
path: str | None = None,
) -> str:
# TODO (kerem): Which is better?
# current: (disadvantage: we create the artifact_id UUID here)
@@ -87,7 +91,10 @@ class ArtifactManager:
duration=time.time() - st,
)
except asyncio.TimeoutError:
LOG.error(f"Timeout (30s) while waiting for upload tasks for task_id={task_id}", task_id=task_id)
LOG.error(
f"Timeout (30s) while waiting for upload tasks for task_id={task_id}",
task_id=task_id,
)
del self.upload_aiotasks_map[task_id]
@@ -109,7 +116,10 @@ class ArtifactManager:
duration=time.time() - st,
)
except asyncio.TimeoutError:
LOG.error(f"Timeout (30s) while waiting for upload tasks for task_ids={task_ids}", task_ids=task_ids)
LOG.error(
f"Timeout (30s) while waiting for upload tasks for task_ids={task_ids}",
task_ids=task_ids,
)
for task_id in task_ids:
del self.upload_aiotasks_map[task_id]

View File

@@ -28,7 +28,11 @@ class LocalStorage(BaseStorage):
with open(file_path, "wb") as f:
f.write(data)
except Exception:
LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)
LOG.exception(
"Failed to store artifact locally.",
file_path=file_path,
artifact=artifact,
)
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
file_path = None
@@ -37,7 +41,11 @@ class LocalStorage(BaseStorage):
self._create_directories_if_not_exists(file_path)
Path(path).replace(file_path)
except Exception:
LOG.exception("Failed to store artifact locally.", file_path=file_path, artifact=artifact)
LOG.exception(
"Failed to store artifact locally.",
file_path=file_path,
artifact=artifact,
)
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
file_path = None
@@ -46,7 +54,11 @@ class LocalStorage(BaseStorage):
with open(file_path, "rb") as f:
return f.read()
except Exception:
LOG.exception("Failed to retrieve local artifact.", file_path=file_path, artifact=artifact)
LOG.exception(
"Failed to retrieve local artifact.",
file_path=file_path,
artifact=artifact,
)
return None
async def get_share_link(self, artifact: Artifact) -> str:

View File

@@ -184,7 +184,11 @@ class AgentDB:
).first():
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
LOG.info(
"Task not found",
task_id=task_id,
organization_id=organization_id,
)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -266,7 +270,11 @@ class AgentDB:
).first():
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
LOG.info(
"Latest step not found",
task_id=task_id,
organization_id=organization_id,
)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -812,7 +820,10 @@ class AgentDB:
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(WorkflowModel.organization_id, WorkflowModel.workflow_permanent_id)
.group_by(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
)
.subquery()
)
main_query = (
@@ -924,7 +935,10 @@ class AgentDB:
await session.commit()
await session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
LOG.error(
"WorkflowRun not found, nothing to update",
workflow_run_id=workflow_run_id,
)
return None
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
@@ -1066,7 +1080,10 @@ class AgentDB:
raise
async def create_workflow_run_output_parameter(
self, workflow_run_id: str, output_parameter_id: str, value: dict[str, Any] | list | str | None
self,
workflow_run_id: str,
output_parameter_id: str,
value: dict[str, Any] | list | str | None,
) -> WorkflowRunOutputParameter:
try:
async with self.Session() as session:
@@ -1149,7 +1166,9 @@ class AgentDB:
(
workflow_parameter,
convert_to_workflow_run_parameter(
workflow_run_parameter, workflow_parameter, self.debug_enabled
workflow_run_parameter,
workflow_parameter,
self.debug_enabled,
),
)
)

View File

@@ -63,7 +63,11 @@ class TaskModel(Base):
max_steps_per_run = Column(Integer, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False, index=True
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
index=True,
)
@@ -80,7 +84,12 @@ class StepModel(Base):
is_last = Column(Boolean, default=False)
retry_index = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
input_token_count = Column(Integer, default=0)
output_token_count = Column(Integer, default=0)
step_cost = Column(Numeric, default=0)
@@ -96,7 +105,12 @@ class OrganizationModel(Base):
max_retries_per_step = Column(Integer, nullable=True)
domain = Column(String, nullable=True, index=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime,
nullable=False,
)
class OrganizationAuthTokenModel(Base):
@@ -115,7 +129,12 @@ class OrganizationAuthTokenModel(Base):
valid = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime,
nullable=False,
)
deleted_at = Column(DateTime, nullable=True)
@@ -130,13 +149,23 @@ class ArtifactModel(Base):
artifact_type = Column(String)
uri = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
class WorkflowModel(Base):
__tablename__ = "workflows"
__table_args__ = (
UniqueConstraint("organization_id", "workflow_permanent_id", "version", name="uc_org_permanent_id_version"),
UniqueConstraint(
"organization_id",
"workflow_permanent_id",
"version",
name="uc_org_permanent_id_version",
),
Index("permanent_id_version_idx", "workflow_permanent_id", "version"),
)
@@ -149,7 +178,12 @@ class WorkflowModel(Base):
webhook_callback_url = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
deleted_at = Column(DateTime, nullable=True)
workflow_permanent_id = Column(String, nullable=False, default=generate_workflow_permanent_id, index=True)
@@ -166,7 +200,12 @@ class WorkflowRunModel(Base):
webhook_callback_url = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
class WorkflowParameterModel(Base):
@@ -179,7 +218,12 @@ class WorkflowParameterModel(Base):
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
default_value = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
deleted_at = Column(DateTime, nullable=True)
@@ -191,7 +235,12 @@ class OutputParameterModel(Base):
description = Column(String, nullable=True)
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
deleted_at = Column(DateTime, nullable=True)
@@ -204,7 +253,12 @@ class AWSSecretParameterModel(Base):
description = Column(String, nullable=True)
aws_key = Column(String, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
deleted_at = Column(DateTime, nullable=True)
@@ -212,7 +266,10 @@ class BitwardenLoginCredentialParameterModel(Base):
__tablename__ = "bitwarden_login_credential_parameters"
bitwarden_login_credential_parameter_id = Column(
String, primary_key=True, index=True, default=generate_bitwarden_login_credential_parameter_id
String,
primary_key=True,
index=True,
default=generate_bitwarden_login_credential_parameter_id,
)
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
key = Column(String, nullable=False)
@@ -222,16 +279,29 @@ class BitwardenLoginCredentialParameterModel(Base):
bitwarden_master_password_aws_secret_key = Column(String, nullable=False)
url_parameter_key = Column(String, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
deleted_at = Column(DateTime, nullable=True)
class WorkflowRunParameterModel(Base):
__tablename__ = "workflow_run_parameters"
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
workflow_run_id = Column(
String,
ForeignKey("workflow_runs.workflow_run_id"),
primary_key=True,
index=True,
)
workflow_parameter_id = Column(
String, ForeignKey("workflow_parameters.workflow_parameter_id"), primary_key=True, index=True
String,
ForeignKey("workflow_parameters.workflow_parameter_id"),
primary_key=True,
index=True,
)
# Can be bool | int | float | str | dict | list depending on the workflow parameter type
value = Column(String, nullable=False)
@@ -241,9 +311,17 @@ class WorkflowRunParameterModel(Base):
class WorkflowRunOutputParameterModel(Base):
__tablename__ = "workflow_run_output_parameters"
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
workflow_run_id = Column(
String,
ForeignKey("workflow_runs.workflow_run_id"),
primary_key=True,
index=True,
)
output_parameter_id = Column(
String, ForeignKey("output_parameters.output_parameter_id"), primary_key=True, index=True
String,
ForeignKey("output_parameters.output_parameter_id"),
primary_key=True,
index=True,
)
value = Column(JSON, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View File

@@ -67,7 +67,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
extracted_information=task_obj.extracted_information,
failure_reason=task_obj.failure_reason,
organization_id=task_obj.organization_id,
proxy_location=ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None,
proxy_location=(ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None),
extracted_information_schema=task_obj.extracted_information_schema,
workflow_run_id=task_obj.workflow_run_id,
order=task_obj.order,
@@ -112,7 +112,9 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
)
def convert_to_organization_auth_token(org_auth_token: OrganizationAuthTokenModel) -> OrganizationAuthToken:
def convert_to_organization_auth_token(
org_auth_token: OrganizationAuthTokenModel,
) -> OrganizationAuthToken:
return OrganizationAuthToken(
id=org_auth_token.id,
organization_id=org_auth_token.organization_id,
@@ -126,7 +128,10 @@ def convert_to_organization_auth_token(org_auth_token: OrganizationAuthTokenMode
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
if debug_enabled:
LOG.debug("Converting ArtifactModel to Artifact", artifact_id=artifact_model.artifact_id)
LOG.debug(
"Converting ArtifactModel to Artifact",
artifact_id=artifact_model.artifact_id,
)
return Artifact(
artifact_id=artifact_model.artifact_id,
@@ -142,7 +147,10 @@ def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = Fal
def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = False) -> Workflow:
if debug_enabled:
LOG.debug("Converting WorkflowModel to Workflow", workflow_id=workflow_model.workflow_id)
LOG.debug(
"Converting WorkflowModel to Workflow",
workflow_id=workflow_model.workflow_id,
)
return Workflow(
workflow_id=workflow_model.workflow_id,
@@ -150,7 +158,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
title=workflow_model.title,
workflow_permanent_id=workflow_model.workflow_permanent_id,
webhook_callback_url=workflow_model.webhook_callback_url,
proxy_location=ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None,
proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None),
version=workflow_model.version,
description=workflow_model.description,
workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
@@ -162,13 +170,18 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: bool = False) -> WorkflowRun:
if debug_enabled:
LOG.debug("Converting WorkflowRunModel to WorkflowRun", workflow_run_id=workflow_run_model.workflow_run_id)
LOG.debug(
"Converting WorkflowRunModel to WorkflowRun",
workflow_run_id=workflow_run_model.workflow_run_id,
)
return WorkflowRun(
workflow_run_id=workflow_run_model.workflow_run_id,
workflow_id=workflow_run_model.workflow_id,
status=WorkflowRunStatus[workflow_run_model.status],
proxy_location=ProxyLocation(workflow_run_model.proxy_location) if workflow_run_model.proxy_location else None,
proxy_location=(
ProxyLocation(workflow_run_model.proxy_location) if workflow_run_model.proxy_location else None
),
webhook_callback_url=workflow_run_model.webhook_callback_url,
created_at=workflow_run_model.created_at,
modified_at=workflow_run_model.modified_at,
@@ -221,7 +234,8 @@ def convert_to_aws_secret_parameter(
def convert_to_bitwarden_login_credential_parameter(
bitwarden_login_credential_parameter_model: BitwardenLoginCredentialParameterModel, debug_enabled: bool = False
bitwarden_login_credential_parameter_model: BitwardenLoginCredentialParameterModel,
debug_enabled: bool = False,
) -> BitwardenLoginCredentialParameter:
if debug_enabled:
LOG.debug(

View File

@@ -91,7 +91,10 @@ class BackgroundTaskExecutor(AsyncExecutor):
api_key: str | None,
**kwargs: dict,
) -> None:
LOG.info("Executing workflow using background task executor", workflow_run_id=workflow_run_id)
LOG.info(
"Executing workflow using background task executor",
workflow_run_id=workflow_run_id,
)
background_tasks.add_task(
app.WORKFLOW_SERVICE.execute_workflow,
workflow_run_id=workflow_run_id,

View File

@@ -53,7 +53,12 @@ class Step(BaseModel):
output_token_count: int = 0
step_cost: float = 0
def validate_update(self, status: StepStatus | None, output: AgentStepOutput | None, is_last: bool | None) -> None:
def validate_update(
self,
status: StepStatus | None,
output: AgentStepOutput | None,
is_last: bool | None,
) -> None:
old_status = self.status
if status and not old_status.can_update_to(status):

View File

@@ -78,7 +78,12 @@ class PromptEngine:
matches = get_close_matches(target, model_dirs, n=1, cutoff=0.1)
return matches[0]
except Exception:
LOG.error("Failed to get closest match.", target=target, model_dirs=model_dirs, exc_info=True)
LOG.error(
"Failed to get closest match.",
target=target,
model_dirs=model_dirs,
exc_info=True,
)
raise
def load_prompt(self, template: str, **kwargs: Any) -> str:
@@ -97,7 +102,12 @@ class PromptEngine:
jinja_template = self.env.get_template(f"{template}.j2")
return jinja_template.render(**kwargs)
except Exception:
LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
LOG.error(
"Failed to load prompt.",
template=template,
kwargs_keys=kwargs.keys(),
exc_info=True,
)
raise
def load_prompt_from_string(self, template: str, **kwargs: Any) -> str:
@@ -115,5 +125,10 @@ class PromptEngine:
jinja_template = self.env.from_string(template)
return jinja_template.render(**kwargs)
except Exception:
LOG.error("Failed to load prompt from string.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
LOG.error(
"Failed to load prompt from string.",
template=template,
kwargs_keys=kwargs.keys(),
exc_info=True,
)
raise

View File

@@ -54,7 +54,10 @@ async def webhook(
x_skyvern_timestamp=x_skyvern_timestamp,
payload=payload,
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing webhook signature or timestamp")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing webhook signature or timestamp",
)
generated_signature = generate_skyvern_signature(
payload.decode("utf-8"),
@@ -82,7 +85,12 @@ async def check_server_status() -> Response:
@base_router.post("/tasks", tags=["agent"], response_model=CreateTaskResponse)
@base_router.post("/tasks/", tags=["agent"], response_model=CreateTaskResponse, include_in_schema=False)
@base_router.post(
"/tasks/",
tags=["agent"],
response_model=CreateTaskResponse,
include_in_schema=False,
)
async def create_agent_task(
background_tasks: BackgroundTasks,
task: TaskRequest,
@@ -342,13 +350,21 @@ async def get_agent_tasks(
"""
analytics.capture("skyvern-oss-agent-tasks-get")
tasks = await app.DATABASE.get_tasks(
page, page_size, task_status=task_status, organization_id=current_org.organization_id
page,
page_size,
task_status=task_status,
organization_id=current_org.organization_id,
)
return ORJSONResponse([task.to_task_response().model_dump() for task in tasks])
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
@base_router.get("/internal/tasks/", tags=["agent"], response_model=list[Task], include_in_schema=False)
@base_router.get(
"/internal/tasks/",
tags=["agent"],
response_model=list[Task],
include_in_schema=False,
)
async def get_agent_tasks_internal(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1),
@@ -367,7 +383,12 @@ async def get_agent_tasks_internal(
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
@base_router.get("/tasks/{task_id}/steps/", tags=["agent"], response_model=list[Step], include_in_schema=False)
@base_router.get(
"/tasks/{task_id}/steps/",
tags=["agent"],
response_model=list[Step],
include_in_schema=False,
)
async def get_agent_task_steps(
task_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
@@ -382,7 +403,11 @@ async def get_agent_task_steps(
return ORJSONResponse([step.model_dump() for step in steps])
@base_router.get("/tasks/{task_id}/steps/{step_id}/artifacts", tags=["agent"], response_model=list[Artifact])
@base_router.get(
"/tasks/{task_id}/steps/{step_id}/artifacts",
tags=["agent"],
response_model=list[Artifact],
)
@base_router.get(
"/tasks/{task_id}/steps/{step_id}/artifacts/",
tags=["agent"],
@@ -412,7 +437,11 @@ async def get_agent_task_step_artifacts(
for i, artifact in enumerate(artifacts):
artifact.signed_url = signed_urls[i]
else:
LOG.warning("Failed to get signed urls for artifacts", task_id=task_id, step_id=step_id)
LOG.warning(
"Failed to get signed urls for artifacts",
task_id=task_id,
step_id=step_id,
)
return ORJSONResponse([artifact.model_dump() for artifact in artifacts])
@@ -424,7 +453,11 @@ class ActionResultTmp(BaseModel):
@base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp])
@base_router.get("/tasks/{task_id}/actions/", response_model=list[ActionResultTmp], include_in_schema=False)
@base_router.get(
"/tasks/{task_id}/actions/",
response_model=list[ActionResultTmp],
include_in_schema=False,
)
async def get_task_actions(
task_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
@@ -441,7 +474,11 @@ async def get_task_actions(
@base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse)
@base_router.post("/workflows/{workflow_id}/run/", response_model=RunWorkflowResponse, include_in_schema=False)
@base_router.post(
"/workflows/{workflow_id}/run/",
response_model=RunWorkflowResponse,
include_in_schema=False,
)
async def execute_workflow(
background_tasks: BackgroundTasks,
workflow_id: str,
@@ -476,7 +513,10 @@ async def execute_workflow(
)
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}", response_model=WorkflowRunStatusResponse)
@base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}",
response_model=WorkflowRunStatusResponse,
)
@base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}/",
response_model=WorkflowRunStatusResponse,

View File

@@ -82,13 +82,27 @@ class TaskStatus(StrEnum):
completed = "completed"
def is_final(self) -> bool:
return self in {TaskStatus.failed, TaskStatus.terminated, TaskStatus.completed, TaskStatus.timed_out}
return self in {
TaskStatus.failed,
TaskStatus.terminated,
TaskStatus.completed,
TaskStatus.timed_out,
}
def can_update_to(self, new_status: TaskStatus) -> bool:
allowed_transitions: dict[TaskStatus, set[TaskStatus]] = {
TaskStatus.created: {TaskStatus.queued, TaskStatus.running, TaskStatus.timed_out},
TaskStatus.created: {
TaskStatus.queued,
TaskStatus.running,
TaskStatus.timed_out,
},
TaskStatus.queued: {TaskStatus.running, TaskStatus.timed_out},
TaskStatus.running: {TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated, TaskStatus.timed_out},
TaskStatus.running: {
TaskStatus.completed,
TaskStatus.failed,
TaskStatus.terminated,
TaskStatus.timed_out,
},
TaskStatus.failed: set(),
TaskStatus.terminated: set(),
TaskStatus.completed: set(),

View File

@@ -53,7 +53,11 @@ class BitwardenService:
"""
# Step 1: Set up environment variables and log in
try:
env = {"BW_CLIENTID": client_id, "BW_CLIENTSECRET": client_secret, "BW_PASSWORD": master_password}
env = {
"BW_CLIENTID": client_id,
"BW_CLIENTSECRET": client_secret,
"BW_PASSWORD": master_password,
}
login_command = ["bw", "login", "--apikey"]
login_result = BitwardenService.run_command(login_command, env)
@@ -81,7 +85,15 @@ class BitwardenService:
raise BitwardenUnlockError("Session key is empty.")
# Step 3: Retrieve the items
list_command = ["bw", "list", "items", "--url", url, "--session", session_key]
list_command = [
"bw",
"list",
"items",
"--url",
url,
"--session",
session_key,
]
items_result = BitwardenService.run_command(list_command)
if items_result.stderr and "Event post failed" not in items_result.stderr:
@@ -100,7 +112,11 @@ class BitwardenService:
totp_result = BitwardenService.run_command(totp_command)
if totp_result.stderr and "Event post failed" not in totp_result.stderr:
LOG.warning("Bitwarden TOTP Error", error=totp_result.stderr, e=BitwardenTOTPError(totp_result.stderr))
LOG.warning(
"Bitwarden TOTP Error",
error=totp_result.stderr,
e=BitwardenTOTPError(totp_result.stderr),
)
totp_code = totp_result.stdout
credentials: list[dict[str, str]] = [

View File

@@ -39,7 +39,9 @@ async def get_current_org(
)
async def get_current_org_with_api_key(x_api_key: Annotated[str | None, Header()] = None) -> Organization:
async def get_current_org_with_api_key(
x_api_key: Annotated[str | None, Header()] = None,
) -> Organization:
if not x_api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -48,7 +50,9 @@ async def get_current_org_with_api_key(x_api_key: Annotated[str | None, Header()
return await _get_current_org_cached(x_api_key, app.DATABASE)
async def get_current_org_with_authentication(authorization: Annotated[str | None, Header()] = None) -> Organization:
async def get_current_org_with_authentication(
authorization: Annotated[str | None, Header()] = None,
) -> Organization:
if not authorization:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,

View File

@@ -35,5 +35,5 @@ async def create_org_api_token(org_id: str) -> OrganizationAuthToken:
token=api_key,
token_type=OrganizationAuthTokenType.api,
)
LOG.info(f"Created API token for organization", organization_id=org_id)
LOG.info("Created API token for organization", organization_id=org_id)
return org_auth_token

View File

@@ -93,7 +93,7 @@ class WorkflowRunContext:
assume it's an actual parameter value and return it.
"""
if type(secret_id_or_value) is str:
if isinstance(secret_id_or_value, str):
return self.secrets.get(secret_id_or_value)
return None
@@ -149,7 +149,7 @@ class WorkflowRunContext:
url = self.values[parameter.url_parameter_key]
else:
LOG.error(f"URL parameter {parameter.url_parameter_key} not found or has no value")
raise ValueError(f"URL parameter for Bitwarden login credentials not found or has no value")
raise ValueError("URL parameter for Bitwarden login credentials not found or has no value")
try:
secret_credentials = BitwardenService.get_secret_value_from_url(
@@ -224,7 +224,9 @@ class WorkflowRunContext:
await self.set_parameter_values_for_output_parameter_dependent_blocks(parameter, value)
async def set_parameter_values_for_output_parameter_dependent_blocks(
self, output_parameter: OutputParameter, value: dict[str, Any] | list | str | None
self,
output_parameter: OutputParameter,
value: dict[str, Any] | list | str | None,
) -> None:
for key, parameter in self.parameters.items():
if (
@@ -268,7 +270,7 @@ class WorkflowRunContext:
isinstance(x, ContextParameter),
# This makes sure that ContextParameters witha ContextParameter source are processed after all other
# ContextParameters
isinstance(x.source, ContextParameter) if isinstance(x, ContextParameter) else False,
(isinstance(x.source, ContextParameter) if isinstance(x, ContextParameter) else False),
isinstance(x, BitwardenLoginCredentialParameter),
)
)

View File

@@ -81,16 +81,20 @@ class Block(BaseModel, abc.ABC):
value=value,
)
LOG.info(
f"Registered output parameter value",
"Registered output parameter value",
output_parameter_id=self.output_parameter.output_parameter_id,
workflow_run_id=workflow_run_id,
)
def build_block_result(
self, success: bool, output_parameter_value: dict[str, Any] | list | str | None = None
self,
success: bool,
output_parameter_value: dict[str, Any] | list | str | None = None,
) -> BlockResult:
return BlockResult(
success=success, output_parameter=self.output_parameter, output_parameter_value=output_parameter_value
success=success,
output_parameter=self.output_parameter,
output_parameter_value=output_parameter_value,
)
@classmethod
@@ -236,11 +240,14 @@ class TaskBlock(Block):
workflow_run=workflow_run, url=self.url
)
if not browser_state.page:
LOG.error("BrowserState has no page", workflow_run_id=workflow_run.workflow_run_id)
LOG.error(
"BrowserState has no page",
workflow_run_id=workflow_run.workflow_run_id,
)
raise MissingBrowserStatePage(workflow_run_id=workflow_run.workflow_run_id)
LOG.info(
f"Navigating to page",
"Navigating to page",
url=self.url,
workflow_run_id=workflow_run_id,
task_id=task.task_id,
@@ -253,7 +260,12 @@ class TaskBlock(Block):
await browser_state.page.goto(self.url, timeout=settings.BROWSER_LOADING_TIMEOUT_MS)
try:
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
await app.agent.execute_step(
organization=organization,
task=task,
step=step,
workflow_run=workflow_run,
)
except Exception as e:
# Make sure the task is marked as failed in the database before raising the exception
await app.DATABASE.update_task(
@@ -273,7 +285,7 @@ class TaskBlock(Block):
if updated_task.status == TaskStatus.completed or updated_task.status == TaskStatus.terminated:
LOG.info(
f"Task completed",
"Task completed",
task_id=updated_task.task_id,
task_status=updated_task.status,
workflow_run_id=workflow_run_id,
@@ -400,7 +412,7 @@ class ForLoopBlock(Block):
)
if not loop_over_values or len(loop_over_values) == 0:
LOG.info(
f"No loop_over values found",
"No loop_over values found",
block_type=self.block_type,
workflow_run_id=workflow_run_id,
num_loop_over_values=len(loop_over_values),
@@ -519,7 +531,11 @@ class TextPromptBlock(Block):
+ json.dumps(self.json_schema, indent=2)
+ "\n```\n\n"
)
LOG.info("TextPromptBlock: Sending prompt to LLM", prompt=prompt, llm_key=self.llm_key)
LOG.info(
"TextPromptBlock: Sending prompt to LLM",
prompt=prompt,
llm_key=self.llm_key,
)
response = await llm_api_handler(prompt=prompt)
LOG.info("TextPromptBlock: Received response from LLM", response=response)
return response
@@ -692,7 +708,12 @@ class SendEmailBlock(Block):
workflow_run_id: str,
) -> list[PARAMETER_TYPE]:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
parameters = [self.smtp_host, self.smtp_port, self.smtp_username, self.smtp_password]
parameters = [
self.smtp_host,
self.smtp_port,
self.smtp_username,
self.smtp_password,
]
if self.file_attachments:
for file_path in self.file_attachments:
@@ -732,7 +753,12 @@ class SendEmailBlock(Block):
if email_config_problems:
raise InvalidEmailClientConfiguration(email_config_problems)
return smtp_host_value, smtp_port_value, smtp_username_value, smtp_password_value
return (
smtp_host_value,
smtp_port_value,
smtp_username_value,
smtp_password_value,
)
def _get_file_paths(self, workflow_run_context: WorkflowRunContext, workflow_run_id: str) -> list[str]:
file_paths = []
@@ -846,7 +872,12 @@ class SendEmailBlock(Block):
subtype=subtype,
)
with open(path, "rb") as fp:
msg.add_attachment(fp.read(), maintype=maintype, subtype=subtype, filename=attachment_filename)
msg.add_attachment(
fp.read(),
maintype=maintype,
subtype=subtype,
filename=attachment_filename,
)
finally:
if path:
os.unlink(path)
@@ -884,6 +915,12 @@ class SendEmailBlock(Block):
BlockSubclasses = Union[
ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock, DownloadToS3Block, UploadToS3Block, SendEmailBlock
ForLoopBlock,
TaskBlock,
CodeBlock,
TextPromptBlock,
DownloadToS3Block,
UploadToS3Block,
SendEmailBlock,
]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -114,6 +114,10 @@ class OutputParameter(Parameter):
ParameterSubclasses = Union[
WorkflowParameter, ContextParameter, AWSSecretParameter, BitwardenLoginCredentialParameter, OutputParameter
WorkflowParameter,
ContextParameter,
AWSSecretParameter,
BitwardenLoginCredentialParameter,
OutputParameter,
]
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]

View File

@@ -166,7 +166,10 @@ class WorkflowService:
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
workflow_output_parameters = await self.get_workflow_output_parameters(workflow_id=workflow.workflow_id)
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(
workflow_run_id, wp_wps_tuples, workflow_output_parameters, context_parameters
workflow_run_id,
wp_wps_tuples,
workflow_output_parameters,
context_parameters,
)
# Execute workflow blocks
blocks = workflow.workflow_definition.blocks
@@ -203,7 +206,11 @@ class WorkflowService:
)
else:
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
await self.send_workflow_response(workflow=workflow, workflow_run=workflow_run, api_key=api_key)
await self.send_workflow_response(
workflow=workflow,
workflow_run=workflow_run,
api_key=api_key,
)
return workflow_run
except Exception:
@@ -224,13 +231,21 @@ class WorkflowService:
# Create a mapping of status to (action, log_func, log_message)
status_action_mapping = {
TaskStatus.running: (None, LOG.error, "has running tasks, this should not happen"),
TaskStatus.running: (
None,
LOG.error,
"has running tasks, this should not happen",
),
TaskStatus.terminated: (
self.mark_workflow_run_as_terminated,
LOG.warning,
"has terminated tasks, marking as terminated",
),
TaskStatus.failed: (self.mark_workflow_run_as_failed, LOG.warning, "has failed tasks, marking as failed"),
TaskStatus.failed: (
self.mark_workflow_run_as_failed,
LOG.warning,
"has failed tasks, marking as failed",
),
TaskStatus.completed: (
self.mark_workflow_run_as_completed,
LOG.info,
@@ -333,7 +348,7 @@ class WorkflowService:
title=title,
organization_id=organization_id,
description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
workflow_definition=(workflow_definition.model_dump() if workflow_definition else None),
)
async def delete_workflow_by_permanent_id(
@@ -529,7 +544,10 @@ class WorkflowService:
for task in workflow_run_tasks[::-1]:
screenshot_artifact = await app.DATABASE.get_latest_artifact(
task_id=task.task_id,
artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
artifact_types=[
ArtifactType.SCREENSHOT_ACTION,
ArtifactType.SCREENSHOT_FINAL,
],
organization_id=organization_id,
)
if screenshot_artifact:
@@ -541,17 +559,19 @@ class WorkflowService:
recording_url = None
recording_artifact = await app.DATABASE.get_artifact_for_workflow_run(
workflow_run_id=workflow_run_id, artifact_type=ArtifactType.RECORDING, organization_id=organization_id
workflow_run_id=workflow_run_id,
artifact_type=ArtifactType.RECORDING,
organization_id=organization_id,
)
if recording_artifact:
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
output_parameter_tuples: list[tuple[OutputParameter, WorkflowRunOutputParameter]] = (
await self.get_output_parameter_workflow_run_output_parameter_tuples(
workflow_id=workflow_id, workflow_run_id=workflow_run_id
)
output_parameter_tuples: list[
tuple[OutputParameter, WorkflowRunOutputParameter]
] = await self.get_output_parameter_workflow_run_output_parameter_tuples(
workflow_id=workflow_id, workflow_run_id=workflow_run_id
)
if output_parameter_tuples:
outputs = {output_parameter.key: output.value for output_parameter, output in output_parameter_tuples}
@@ -587,7 +607,9 @@ class WorkflowService:
tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
all_workflow_task_ids = [task.task_id for task in tasks]
browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
workflow_run.workflow_run_id, all_workflow_task_ids, close_browser_on_completion
workflow_run.workflow_run_id,
all_workflow_task_ids,
close_browser_on_completion,
)
if browser_state:
await self.persist_video_data(browser_state, workflow, workflow_run)
@@ -600,7 +622,10 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow.organization_id,
)
LOG.info("Built workflow run status response", workflow_run_status_response=workflow_run_status_response)
LOG.info(
"Built workflow run status response",
workflow_run_status_response=workflow_run_status_response,
)
if not workflow_run.webhook_callback_url:
LOG.warning(
@@ -661,7 +686,8 @@ class WorkflowService:
)
except Exception as e:
raise FailedToSendWebhook(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
) from e
async def persist_video_data(
@@ -681,10 +707,16 @@ class WorkflowService:
)
async def persist_har_data(
self, browser_state: BrowserState, last_step: Step, workflow: Workflow, workflow_run: WorkflowRun
self,
browser_state: BrowserState,
last_step: Step,
workflow: Workflow,
workflow_run: WorkflowRun,
) -> None:
har_data = await app.BROWSER_MANAGER.get_har_data(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
browser_state=browser_state,
)
if har_data:
await app.ARTIFACT_MANAGER.create_artifact(
@@ -703,7 +735,11 @@ class WorkflowService:
await app.ARTIFACT_MANAGER.create_artifact(step=last_step, artifact_type=ArtifactType.TRACE, path=trace_path)
async def persist_debug_artifacts(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
self,
browser_state: BrowserState,
last_task: Task,
workflow: Workflow,
workflow_run: WorkflowRun,
) -> None:
last_step = await app.DATABASE.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
@@ -720,7 +756,11 @@ class WorkflowService:
request: WorkflowCreateYAMLRequest,
workflow_permanent_id: str | None = None,
) -> Workflow:
LOG.info("Creating workflow from request", organization_id=organization_id, title=request.title)
LOG.info(
"Creating workflow from request",
organization_id=organization_id,
title=request.title,
)
try:
if workflow_permanent_id:
existing_latest_workflow = await self.get_workflow_by_permanent_id(
@@ -769,7 +809,8 @@ class WorkflowService:
# Create output parameters for all blocks
block_output_parameters = await WorkflowService._create_all_output_parameters_for_workflow(
workflow_id=workflow.workflow_id, block_yamls=request.workflow_definition.blocks
workflow_id=workflow.workflow_id,
block_yamls=request.workflow_definition.blocks,
)
for block_output_parameter in block_output_parameters.values():
parameters[block_output_parameter.key] = block_output_parameter
@@ -822,7 +863,8 @@ class WorkflowService:
for context_parameter in context_parameter_yamls:
if context_parameter.source_parameter_key not in parameters:
raise ContextParameterSourceNotDefined(
context_parameter_key=context_parameter.key, source_key=context_parameter.source_parameter_key
context_parameter_key=context_parameter.key,
source_key=context_parameter.source_parameter_key,
)
if context_parameter.key in parameters:
@@ -901,7 +943,9 @@ class WorkflowService:
@staticmethod
async def block_yaml_to_block(
workflow: Workflow, block_yaml: BLOCK_YAML_TYPES, parameters: dict[str, Parameter]
workflow: Workflow,
block_yaml: BLOCK_YAML_TYPES,
parameters: dict[str, Parameter],
) -> BlockTypeVar:
output_parameter = parameters[f"{block_yaml.label}_output"]
if block_yaml.block_type == BlockType.TASK: