add functionality to cache task_run (#1755)
This commit is contained in:
@@ -348,7 +348,7 @@ class ForgeAgent:
|
||||
step,
|
||||
browser_state,
|
||||
detailed_output,
|
||||
) = await self._initialize_execution_state(task, step, workflow_run, browser_session_id)
|
||||
) = await self.initialize_execution_state(task, step, workflow_run, browser_session_id)
|
||||
|
||||
if (
|
||||
not task.navigation_goal
|
||||
@@ -759,7 +759,7 @@ class ForgeAgent:
|
||||
(
|
||||
scraped_page,
|
||||
extract_action_prompt,
|
||||
) = await self._build_and_record_step_prompt(
|
||||
) = await self.build_and_record_step_prompt(
|
||||
task,
|
||||
step,
|
||||
browser_state,
|
||||
@@ -1245,7 +1245,7 @@ class ForgeAgent:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _initialize_execution_state(
|
||||
async def initialize_execution_state(
|
||||
self,
|
||||
task: Task,
|
||||
step: Step,
|
||||
@@ -1322,7 +1322,7 @@ class ForgeAgent:
|
||||
scrape_exclude=app.scrape_exclude,
|
||||
)
|
||||
|
||||
async def _build_and_record_step_prompt(
|
||||
async def build_and_record_step_prompt(
|
||||
self,
|
||||
task: Task,
|
||||
step: Step,
|
||||
|
||||
@@ -12,7 +12,7 @@ Reply in JSON format with the following keys:
|
||||
"user_detail_query": str, // Think of this value as a Jeopardy question. Ask the user for the details you need for executing this action. Ask the question even if the details are disclosed in user instruction or user details. If you are clicking on something specific, ask about what to click on. If you're downloading a file and you have multiple options, ask the user which one to download. Otherwise, use null. Examples are: "What is the previous insurance provider of the user?", "Which invoice should I download?", "Does the user have any pets?". If the action doesn't require any user details, use null.
|
||||
"user_detail_answer": str, // The answer to the `user_detail_query`. The source of this answer can be user instruction or user details.
|
||||
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
|
||||
"action_type": str, // It's a string enum: "CLICK". "CLICK" is an element you'd like to click.
|
||||
"action_type": str, // It's a string enum: "CLICK". "CLICK" type means there's an element you'd like to click.
|
||||
"id": str, // The id of the element to take action on. The id has to be one from the elements list.
|
||||
"download": bool, // If true, the browser will trigger a download by clicking the element. If false, the browser will click the element without triggering a download.
|
||||
}]
|
||||
@@ -25,7 +25,7 @@ HTML elements from `{{ current_url }}`:
|
||||
{{ elements }}
|
||||
```
|
||||
|
||||
User instruction:
|
||||
User instruction (user's intention or self questioning to help figure out what to click):
|
||||
```
|
||||
{{ navigation_goal }}
|
||||
```
|
||||
@@ -33,7 +33,12 @@ User instruction:
|
||||
User details:
|
||||
```
|
||||
{{ navigation_payload_str }}
|
||||
```{% if user_context %}
|
||||
|
||||
Context of the big goal user wants to achieve:
|
||||
```
|
||||
{{ user_context }}
|
||||
```{% endif %}
|
||||
|
||||
Current datetime, ISO format:
|
||||
```
|
||||
|
||||
@@ -2672,3 +2672,30 @@ class AgentDB:
|
||||
await session.commit()
|
||||
await session.refresh(task_run)
|
||||
return TaskRun.model_validate(task_run)
|
||||
|
||||
async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> TaskRun:
|
||||
async with self.Session() as session:
|
||||
task_run = await session.scalars(
|
||||
select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
|
||||
).first()
|
||||
if task_run:
|
||||
task_run.cached = True
|
||||
await session.commit()
|
||||
await session.refresh(task_run)
|
||||
return TaskRun.model_validate(task_run)
|
||||
raise NotFoundError(f"TaskRun {run_id} not found")
|
||||
|
||||
async def get_cached_task_run(
|
||||
self, task_run_type: TaskRunType, url_hash: str | None = None, organization_id: str | None = None
|
||||
) -> TaskRun | None:
|
||||
async with self.Session() as session:
|
||||
query = select(TaskRunModel)
|
||||
if task_run_type:
|
||||
query = query.filter_by(task_run_type=task_run_type)
|
||||
if url_hash:
|
||||
query = query.filter_by(url_hash=url_hash)
|
||||
if organization_id:
|
||||
query = query.filter_by(organization_id=organization_id)
|
||||
query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
|
||||
task_run = await session.scalars(query).first()
|
||||
return TaskRun.model_validate(task_run) if task_run else None
|
||||
|
||||
@@ -614,7 +614,10 @@ class PersistentBrowserSessionModel(Base):
|
||||
|
||||
class TaskRunModel(Base):
|
||||
__tablename__ = "task_runs"
|
||||
__table_args__ = (Index("task_run_org_url_index", "organization_id", "url_hash", "cached"),)
|
||||
__table_args__ = (
|
||||
Index("task_run_org_url_index", "organization_id", "url_hash", "cached"),
|
||||
Index("task_run_org_run_id_index", "organization_id", "run_id"),
|
||||
)
|
||||
|
||||
task_run_id = Column(String, primary_key=True, default=generate_task_run_id)
|
||||
organization_id = Column(String, nullable=False)
|
||||
|
||||
Reference in New Issue
Block a user