Add Run id as a parameter to the artifacts table (#2799)

Co-authored-by: Suchintan Singh <suchintansingh@gmail.com>
This commit is contained in:
Shuchang Zheng
2025-06-27 00:27:48 +09:00
committed by GitHub
parent a549e19f61
commit 43cab04454
12 changed files with 69 additions and 15 deletions

View File

@@ -33,6 +33,7 @@ class ArtifactManager:
workflow_run_block_id: str | None = None,
thought_id: str | None = None,
task_v2_id: str | None = None,
run_id: str | None = None,
ai_suggestion_id: str | None = None,
data: bytes | None = None,
path: str | None = None,
@@ -49,6 +50,8 @@ class ArtifactManager:
task_v2_id = context.task_v2_id
if not task_id and context:
task_id = context.task_id
if not run_id and context:
run_id = context.run_id
artifact = await app.DATABASE.create_artifact(
artifact_id,
@@ -60,6 +63,7 @@ class ArtifactManager:
workflow_run_block_id=workflow_run_block_id,
thought_id=thought_id,
task_v2_id=task_v2_id,
run_id=run_id,
organization_id=organization_id,
ai_suggestion_id=ai_suggestion_id,
)

View File

@@ -18,6 +18,7 @@ class SkyvernContext:
max_steps_override: int | None = None
browser_session_id: str | None = None
tz_info: ZoneInfo | None = None
run_id: str | None = None
totp_codes: dict[str, str | None] = field(default_factory=dict)
log: list[dict] = field(default_factory=list)
hashed_href_map: dict[str, str] = field(default_factory=dict)
@@ -26,7 +27,7 @@ class SkyvernContext:
max_screenshot_scrolling_times: int | None = None
def __repr__(self) -> str:
return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override})"
return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override}, run_id={self.run_id})"
def __str__(self) -> str:
return self.__repr__()

View File

@@ -224,14 +224,15 @@ class AgentDB:
artifact_id: str,
artifact_type: str,
uri: str,
organization_id: str,
step_id: str | None = None,
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
task_v2_id: str | None = None,
run_id: str | None = None,
thought_id: str | None = None,
ai_suggestion_id: str | None = None,
organization_id: str | None = None,
) -> Artifact:
try:
async with self.Session() as session:
@@ -245,6 +246,7 @@ class AgentDB:
workflow_run_block_id=workflow_run_block_id,
observer_cruise_id=task_v2_id,
observer_thought_id=thought_id,
run_id=run_id,
ai_suggestion_id=ai_suggestion_id,
organization_id=organization_id,
)
@@ -1024,18 +1026,7 @@ class AgentDB:
async with self.Session() as session:
query = select(ArtifactModel).filter_by(organization_id=organization_id)
if run.task_run_type in [
RunType.task_v1,
RunType.openai_cua,
RunType.anthropic_cua,
]:
query = query.filter_by(task_id=run.run_id)
elif run.task_run_type == RunType.task_v2:
query = query.filter_by(observer_cruise_id=run.run_id)
elif run.task_run_type == RunType.workflow_run:
query = query.filter_by(workflow_run_id=run.run_id)
else:
return []
query = query.filter_by(run_id=run.run_id)
if artifact_types:
query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))

View File

@@ -193,6 +193,7 @@ class ArtifactModel(Base):
step_id = Column(String, index=True)
artifact_type = Column(String)
uri = Column(String)
run_id = Column(String, nullable=True, index=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,

View File

@@ -105,6 +105,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
context: SkyvernContext = skyvern_context.ensure_context()
context.task_id = task.task_id
context.run_id = context.run_id or task.task_id
context.organization_id = organization_id
context.max_steps_override = max_steps_override
context.max_screenshot_scrolling_times = task.max_screenshot_scrolling_times

View File

@@ -31,6 +31,8 @@ def add_kv_pairs_to_msg(logger: logging.Logger, method_name: str, event_dict: Ev
event_dict["organization_name"] = context.organization_name
if context.task_id:
event_dict["task_id"] = context.task_id
if context.run_id:
event_dict["run_id"] = context.run_id
if context.workflow_id:
event_dict["workflow_id"] = context.workflow_id
if context.workflow_run_id:

View File

@@ -2544,6 +2544,8 @@ class TaskV2Block(Block):
browser_session_id=browser_session_id,
)
finally:
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
current_run_id = context.run_id if context and context.run_id else workflow_run_id
skyvern_context.set(
skyvern_context.SkyvernContext(
organization_id=organization_id,
@@ -2551,6 +2553,7 @@ class TaskV2Block(Block):
workflow_id=workflow_run.workflow_id,
workflow_permanent_id=workflow_run.workflow_permanent_id,
workflow_run_id=workflow_run_id,
run_id=current_run_id,
browser_session_id=browser_session_id,
max_screenshot_scrolling_times=workflow_run.max_screenshot_scrolling_times,
)

View File

@@ -177,6 +177,8 @@ class WorkflowService:
webhook_callback_url=workflow_request.webhook_callback_url,
max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolling_times,
)
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
current_run_id = context.run_id if context and context.run_id else workflow_run.workflow_run_id
skyvern_context.set(
SkyvernContext(
organization_id=organization.organization_id,
@@ -184,6 +186,8 @@ class WorkflowService:
request_id=request_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
run_id=current_run_id,
workflow_permanent_id=workflow_run.workflow_permanent_id,
max_steps_override=max_steps_override,
max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolling_times,
)