shu/workflowrun timeline get observer cruise id by workflow run id (#1430)

This commit is contained in:
Shuchang Zheng
2024-12-23 11:48:27 -08:00
committed by GitHub
parent aad741d8de
commit acbdcb14e3
5 changed files with 117 additions and 4 deletions

View File

@@ -0,0 +1,43 @@
"""add more columns for different blocks
Revision ID: 835522a23b19
Revises: cf3cd8d666b0
Create Date: 2024-12-23 19:41:48.849308+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "835522a23b19"
down_revision: Union[str, None] = "cf3cd8d666b0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index("oc_org_wfr_index", "observer_cruises", ["organization_id", "workflow_run_id"], unique=False)
op.add_column("workflow_run_blocks", sa.Column("recipients", sa.JSON(), nullable=True))
op.add_column("workflow_run_blocks", sa.Column("attachments", sa.JSON(), nullable=True))
op.add_column("workflow_run_blocks", sa.Column("subject", sa.String(), nullable=True))
op.add_column("workflow_run_blocks", sa.Column("body", sa.String(), nullable=True))
op.add_column("workflow_run_blocks", sa.Column("prompt", sa.String(), nullable=True))
op.add_column("workflow_run_blocks", sa.Column("wait_sec", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("workflow_run_blocks", "wait_sec")
op.drop_column("workflow_run_blocks", "prompt")
op.drop_column("workflow_run_blocks", "body")
op.drop_column("workflow_run_blocks", "subject")
op.drop_column("workflow_run_blocks", "attachments")
op.drop_column("workflow_run_blocks", "recipients")
op.drop_index("oc_org_wfr_index", table_name="observer_cruises")
# ### end Alembic commands ###

View File

@@ -1891,6 +1891,22 @@ class AgentDB:
return ObserverCruise.model_validate(observer_cruise)
return None
async def get_observer_cruise_by_workflow_run_id(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> ObserverCruise | None:
async with self.Session() as session:
if observer_cruise := (
await session.scalars(
select(ObserverCruiseModel)
.filter_by(organization_id=organization_id)
.filter_by(workflow_run_id=workflow_run_id)
)
).first():
return ObserverCruise.model_validate(observer_cruise)
return None
async def get_observer_thought(
self, observer_thought_id: str, organization_id: str | None = None
) -> ObserverThought | None:
@@ -2087,6 +2103,12 @@ class AgentDB:
loop_values: list | None = None,
current_value: str | None = None,
current_index: int | None = None,
recipients: list[str] | None = None,
attachments: list[str] | None = None,
subject: str | None = None,
body: str | None = None,
prompt: str | None = None,
wait_sec: int | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
workflow_run_block = (
@@ -2111,6 +2133,18 @@ class AgentDB:
workflow_run_block.current_value = current_value
if current_index:
workflow_run_block.current_index = current_index
if recipients:
workflow_run_block.recipients = recipients
if attachments:
workflow_run_block.attachments = attachments
if subject:
workflow_run_block.subject = subject
if body:
workflow_run_block.body = body
if prompt:
workflow_run_block.prompt = prompt
if wait_sec:
workflow_run_block.wait_sec = wait_sec
await session.commit()
await session.refresh(workflow_run_block)
else:

View File

@@ -504,16 +504,31 @@ class WorkflowRunBlockModel(Base):
output = Column(JSON, nullable=True)
continue_on_failure = Column(Boolean, nullable=False, default=False)
failure_reason = Column(String, nullable=True)
# for loop block
loop_values = Column(JSON, nullable=True)
current_value = Column(String, nullable=True)
current_index = Column(Integer, nullable=True)
# email block
recipients = Column(JSON, nullable=True)
attachments = Column(JSON, nullable=True)
subject = Column(String, nullable=True)
body = Column(String, nullable=True)
# prompt block
prompt = Column(String, nullable=True)
# wait block
wait_sec = Column(Integer, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class ObserverCruiseModel(Base):
__tablename__ = "observer_cruises"
__table_args__ = (Index("oc_org_wfr_index", "organization_id", "workflow_run_id"),)
observer_cruise_id = Column(String, primary_key=True, default=generate_observer_cruise_id)
status = Column(String, nullable=False, default="created")

View File

@@ -734,21 +734,24 @@ async def get_workflow_run(
"/workflows/{workflow_id}/runs/{workflow_run_id}/timeline/",
)
async def get_workflow_run_timeline(
workflow_id: str,
workflow_run_id: str,
observer_cruise_id: str | None = None,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[WorkflowRunTimeline]:
# get observer cruise by workflow run id
observer_cruise_obj = await app.DATABASE.get_observer_cruise_by_workflow_run_id(
workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)
# get all the workflow run blocks
workflow_run_block_timeline = await app.WORKFLOW_SERVICE.get_workflow_run_timeline(
workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)
if observer_cruise_id:
if observer_cruise_obj and observer_cruise_obj.observer_cruise_id:
observer_thought_timeline = await observer_service.get_observer_thought_timelines(
observer_cruise_id=observer_cruise_id,
observer_cruise_id=observer_cruise_obj.observer_cruise_id,
organization_id=current_org.organization_id,
)
workflow_run_block_timeline.extend(observer_thought_timeline)

View File

@@ -1067,6 +1067,11 @@ class TextPromptBlock(Block):
) -> BlockResult:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
await app.DATABASE.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
prompt=self.prompt,
)
try:
self.format_potential_template_parameters(workflow_run_context)
except Exception as e:
@@ -1535,6 +1540,14 @@ class SendEmailBlock(Block):
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
await app.DATABASE.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
recipients=self.recipients,
attachments=self.file_attachments,
subject=self.subject,
body=self.body,
)
try:
self.format_potential_template_parameters(workflow_run_context)
except Exception as e:
@@ -1692,6 +1705,11 @@ class WaitBlock(Block):
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
# TODO: we need to support to interrupt the sleep when the workflow run failed/cancelled/terminated
await app.DATABASE.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
wait_sec=self.wait_sec,
)
LOG.info(
"Going to pause the workflow for a while",
second=self.wait_sec,