store confidence_float in db (#446)

This commit is contained in:
Kerem Yilmaz
2024-06-07 10:57:53 -07:00
committed by GitHub
parent 12cfef09d4
commit 3801bcbf19
2 changed files with 20 additions and 6 deletions

View File

@@ -854,7 +854,7 @@ class ForgeAgent:
return json.dumps( return json.dumps(
[ [
{ {
"action": action.model_dump(exclude_none=True, exclude={"text"}), "action": action.model_dump(exclude_none=True, exclude={"text", "confidence_float"}),
"results": [ "results": [
result.model_dump( result.model_dump(
exclude_none=True, exclude_none=True,

View File

@@ -45,6 +45,7 @@ class SelectOption(BaseModel):
class Action(BaseModel): class Action(BaseModel):
action_type: ActionType action_type: ActionType
confidence_float: float | None = None
description: str | None = None description: str | None = None
reasoning: str | None = None reasoning: str | None = None
element_id: Annotated[str, Field(coerce_numbers_to_str=True)] | None = None element_id: Annotated[str, Field(coerce_numbers_to_str=True)] | None = None
@@ -158,9 +159,10 @@ def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None
element_id = None element_id = None
reasoning = action["reasoning"] if "reasoning" in action else None reasoning = action["reasoning"] if "reasoning" in action else None
confidence_float = action["confidence_float"] if "confidence_float" in action else None
if "action_type" not in action or action["action_type"] is None: if "action_type" not in action or action["action_type"] is None:
return NullAction(reasoning=reasoning) return NullAction(reasoning=reasoning, confidence_float=confidence_float)
# `.upper()` handles the case where the LLM returns a lowercase action type (e.g. "click" instead of "CLICK") # `.upper()` handles the case where the LLM returns a lowercase action type (e.g. "click" instead of "CLICK")
action_type = ActionType[action["action_type"].upper()] action_type = ActionType[action["action_type"].upper()]
@@ -168,6 +170,7 @@ def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None
if action_type == ActionType.TERMINATE: if action_type == ActionType.TERMINATE:
return TerminateAction( return TerminateAction(
reasoning=reasoning, reasoning=reasoning,
confidence_float=confidence_float,
errors=action["errors"] if "errors" in action else [], errors=action["errors"] if "errors" in action else [],
) )
@@ -176,17 +179,24 @@ def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None
return ClickAction( return ClickAction(
element_id=element_id, element_id=element_id,
reasoning=reasoning, reasoning=reasoning,
confidence_float=confidence_float,
file_url=file_url, file_url=file_url,
download=action.get("download", False), download=action.get("download", False),
) )
if action_type == ActionType.INPUT_TEXT: if action_type == ActionType.INPUT_TEXT:
return InputTextAction(element_id=element_id, text=action["text"], reasoning=reasoning) return InputTextAction(
element_id=element_id,
text=action["text"],
reasoning=reasoning,
confidence_float=confidence_float,
)
if action_type == ActionType.UPLOAD_FILE: if action_type == ActionType.UPLOAD_FILE:
# TODO: see if the element is a file input element. if it's not, convert this action into a click action # TODO: see if the element is a file input element. if it's not, convert this action into a click action
return UploadFileAction( return UploadFileAction(
element_id=element_id, element_id=element_id,
confidence_float=confidence_float,
file_url=action["file_url"], file_url=action["file_url"],
reasoning=reasoning, reasoning=reasoning,
) )
@@ -197,6 +207,7 @@ def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None
element_id=element_id, element_id=element_id,
file_name=action["file_name"], file_name=action["file_name"],
reasoning=reasoning, reasoning=reasoning,
confidence_float=confidence_float,
) )
if action_type == ActionType.SELECT_OPTION: if action_type == ActionType.SELECT_OPTION:
@@ -208,6 +219,7 @@ def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None
index=action["option"]["index"], index=action["option"]["index"],
), ),
reasoning=reasoning, reasoning=reasoning,
confidence_float=confidence_float,
) )
if action_type == ActionType.CHECKBOX: if action_type == ActionType.CHECKBOX:
@@ -215,23 +227,25 @@ def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None
element_id=element_id, element_id=element_id,
is_checked=action["is_checked"], is_checked=action["is_checked"],
reasoning=reasoning, reasoning=reasoning,
confidence_float=confidence_float,
) )
if action_type == ActionType.WAIT: if action_type == ActionType.WAIT:
return WaitAction(reasoning=reasoning) return WaitAction(reasoning=reasoning, confidence_float=confidence_float)
if action_type == ActionType.COMPLETE: if action_type == ActionType.COMPLETE:
return CompleteAction( return CompleteAction(
reasoning=reasoning, reasoning=reasoning,
confidence_float=confidence_float,
data_extraction_goal=data_extraction_goal, data_extraction_goal=data_extraction_goal,
errors=action["errors"] if "errors" in action else [], errors=action["errors"] if "errors" in action else [],
) )
if action_type == "null": if action_type == "null":
return NullAction(reasoning=reasoning) return NullAction(reasoning=reasoning, confidence_float=confidence_float)
if action_type == ActionType.SOLVE_CAPTCHA: if action_type == ActionType.SOLVE_CAPTCHA:
return SolveCaptchaAction(reasoning=reasoning) return SolveCaptchaAction(reasoning=reasoning, confidence_float=confidence_float)
raise UnsupportedActionType(action_type=action_type) raise UnsupportedActionType(action_type=action_type)