Include max steps in saved task (#499)

This commit is contained in:
Kerem Yilmaz
2024-06-21 13:08:00 -07:00
committed by GitHub
parent 49d7e77b3a
commit 58735a9c20
4 changed files with 77 additions and 5 deletions

View File

@@ -1 +1,3 @@
export const PAGE_SIZE = 15;
export const MAX_STEPS_DEFAULT = 10;

View File

@@ -46,6 +46,7 @@ import {
} from "@/components/ui/accordion";
import { OrganizationApiResponse } from "@/api/types";
import { Skeleton } from "@/components/ui/skeleton";
import { MAX_STEPS_DEFAULT } from "../constants";
const createNewTaskFormSchema = z
.object({
@@ -94,7 +95,6 @@ function createTaskRequestObject(formValues: CreateNewTaskFormValues) {
navigation_goal: transform(formValues.navigationGoal),
data_extraction_goal: transform(formValues.dataExtractionGoal),
proxy_location: "RESIDENTIAL",
error_code_mapping: null,
navigation_payload: transform(formValues.navigationPayload),
extracted_information_schema: transform(
formValues.extractedInformationSchema,
@@ -102,8 +102,6 @@ function createTaskRequestObject(formValues: CreateNewTaskFormValues) {
};
}
const MAX_STEPS_DEFAULT = 10;
function CreateNewTaskForm({ initialValues }: Props) {
const queryClient = useQueryClient();
const { toast } = useToast();

View File

@@ -48,6 +48,8 @@ function CreateNewTaskFormPage() {
const dataSchema = data.workflow_definition.blocks[0].data_schema;
const maxSteps = data.workflow_definition.blocks[0].max_steps_per_run;
return (
<SavedTaskForm
initialValues={{
@@ -61,6 +63,7 @@ function CreateNewTaskFormPage() {
data.workflow_definition.blocks[0].data_extraction_goal,
extractedInformationSchema: JSON.stringify(dataSchema, null, 2),
navigationPayload,
maxSteps,
}}
/>
);

View File

@@ -24,7 +24,7 @@ import { apiBaseUrl } from "@/util/env";
import { zodResolver } from "@hookform/resolvers/zod";
import { InfoCircledIcon, ReloadIcon } from "@radix-ui/react-icons";
import { ToastAction } from "@radix-ui/react-toast";
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import fetchToCurl from "fetch-to-curl";
import { useForm, useFormState } from "react-hook-form";
import { Link, useParams } from "react-router-dom";
@@ -46,6 +46,9 @@ import {
AccordionItem,
AccordionTrigger,
} from "@/components/ui/accordion";
import { OrganizationApiResponse } from "@/api/types";
import { MAX_STEPS_DEFAULT } from "../constants";
import { Skeleton } from "@/components/ui/skeleton";
const savedTaskFormSchema = z
.object({
@@ -60,6 +63,7 @@ const savedTaskFormSchema = z
dataExtractionGoal: z.string().or(z.null()).optional(),
navigationPayload: z.string().or(z.null()).optional(),
extractedInformationSchema: z.string().or(z.null()).optional(),
maxSteps: z.number().optional(),
})
.superRefine(
(
@@ -154,6 +158,7 @@ function createTaskTemplateRequestObject(values: SavedTaskFormValues) {
navigation_goal: values.navigationGoal,
data_extraction_goal: values.dataExtractionGoal,
data_schema: extractedInformationSchema,
max_steps_per_run: values.maxSteps,
},
],
},
@@ -167,9 +172,30 @@ function SavedTaskForm({ initialValues }: Props) {
const apiCredential = useApiCredential();
const { template } = useParams();
const { data: organizations, isPending: organizationIsPending } = useQuery<
Array<OrganizationApiResponse>
>({
queryKey: ["organizations"],
queryFn: async () => {
const client = await getClient(credentialGetter);
return await client
.get("/organizations")
.then((response) => response.data.organizations);
},
});
const organization = organizations?.[0];
const form = useForm<SavedTaskFormValues>({
resolver: zodResolver(savedTaskFormSchema),
defaultValues: initialValues,
values: {
...initialValues,
maxSteps:
initialValues.maxSteps ??
organization?.max_steps_per_run ??
MAX_STEPS_DEFAULT,
},
});
const { isDirty } = useFormState({ control: form.control });
@@ -178,10 +204,19 @@ function SavedTaskForm({ initialValues }: Props) {
mutationFn: async (formValues: SavedTaskFormValues) => {
const taskRequest = createTaskRequestObject(formValues);
const client = await getClient(credentialGetter);
const includeOverrideHeader =
formValues.maxSteps !== organization?.max_steps_per_run &&
formValues.maxSteps !== MAX_STEPS_DEFAULT;
return client.post<
ReturnType<typeof createTaskRequestObject>,
{ data: { task_id: string } }
>("/tasks", taskRequest);
>("/tasks", taskRequest, {
...(includeOverrideHeader && {
headers: {
"x-max-steps-override": formValues.maxSteps ?? MAX_STEPS_DEFAULT,
},
}),
});
},
onError: (error: AxiosError) => {
if (error.response?.status === 402) {
@@ -516,6 +551,40 @@ function SavedTaskForm({ initialValues }: Props) {
</FormItem>
)}
/>
<FormField
control={form.control}
name="maxSteps"
render={({ field }) => {
return (
<FormItem>
<FormLabel>Max Steps</FormLabel>
<FormDescription>
Max steps for this task. This will override your
organization wide setting.
</FormDescription>
<FormControl>
{organizationIsPending ? (
<Skeleton className="h-8" />
) : (
<Input
{...field}
type="number"
min={1}
max={
organization?.max_steps_per_run ??
MAX_STEPS_DEFAULT
}
value={field.value ?? MAX_STEPS_DEFAULT}
onChange={(event) => {
field.onChange(parseInt(event.target.value));
}}
/>
)}
</FormControl>
</FormItem>
);
}}
/>
</AccordionContent>
</AccordionItem>
</Accordion>