Task streaming frontend (#512)

Co-authored-by: Muhammed Salih Altun <muhammedsalihaltun@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-06-25 12:18:44 -07:00
committed by GitHub
parent fe1c7214f7
commit a0a3aa6f83
8 changed files with 487 additions and 207 deletions

View File

@@ -3,5 +3,9 @@ VITE_API_BASE_URL=http://localhost:8000/api/v1
# server to load artifacts from file URIs
VITE_ARTIFACT_API_BASE_URL=http://localhost:9090
# websocket base URL used to stream live task screenshots
# VITE_WSS_BASE_URL=wss://api-staging.skyvern.com/api/v1
VITE_WSS_BASE_URL=ws://localhost:8000/api/v1
# your api key - for x-api-key header
VITE_SKYVERN_API_KEY=

View File

@@ -12,21 +12,23 @@ import {
import { useCredentialGetter } from "@/hooks/useCredentialGetter";
import { cn } from "@/util/utils";
import {
ArrowLeftIcon,
ArrowRightIcon,
ArrowDownIcon,
ArrowUpIcon,
CheckCircledIcon,
CrossCircledIcon,
DotFilledIcon,
} from "@radix-ui/react-icons";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef } from "react";
import { ReactNode, useEffect, useRef } from "react";
import { useParams } from "react-router-dom";
type Props = {
data: Array<Action | null>;
onNext: () => void;
onPrevious: () => void;
onActiveIndexChange: (index: number) => void;
activeIndex: number;
onActiveIndexChange: (index: number | "stream") => void;
activeIndex: number | "stream";
showStreamOption: boolean;
};
function ScrollableActionList({
@@ -35,22 +37,105 @@ function ScrollableActionList({
onPrevious,
activeIndex,
onActiveIndexChange,
showStreamOption,
}: Props) {
const { taskId } = useParams();
const queryClient = useQueryClient();
const credentialGetter = useCredentialGetter();
const refs = useRef<Array<HTMLDivElement | null>>(
Array.from({ length: data.length }),
Array.from({ length: data.length + 1 }),
);
useEffect(() => {
if (refs.current[activeIndex]) {
if (typeof activeIndex === "number" && refs.current[activeIndex]) {
refs.current[activeIndex]?.scrollIntoView({
behavior: "smooth",
block: "nearest",
});
}
}, [activeIndex]);
if (activeIndex === "stream") {
refs.current[data.length]?.scrollIntoView({
behavior: "smooth",
block: "nearest",
});
}
}, [activeIndex, data.length]);
/**
 * Builds the action cards, most recent action first.
 *
 * Null entries in `data` are skipped. Each card is registered in `refs` under
 * its reversed index (`data.length - position - 1`), which is the same index
 * space `activeIndex` and `onActiveIndexChange` use, while the visible label
 * keeps the 1-based chronological number. Hovering a card prefetches the
 * artifacts of the step the action belongs to.
 *
 * @returns The rendered cards, newest action first.
 */
function getReverseActions() {
  const cards: ReactNode[] = [];
  const lastPosition = data.length - 1;

  for (let reversedIndex = 0; reversedIndex <= lastPosition; reversedIndex++) {
    // `position` is the chronological slot in `data`; `reversedIndex` is the
    // slot used for selection, refs and change notifications.
    const position = lastPosition - reversedIndex;
    const entry = data[position];
    if (!entry) {
      continue;
    }

    const isSelected = activeIndex === reversedIndex;
    const cardClassName = cn(
      "flex p-4 rounded-lg shadow-md border hover:border-slate-300 cursor-pointer",
      {
        "border-slate-300": isSelected,
      },
    );

    const prefetchArtifacts = () => {
      queryClient.prefetchQuery({
        queryKey: ["task", taskId, "steps", entry.stepId, "artifacts"],
        queryFn: async () => {
          const client = await getClient(credentialGetter);
          const response = await client.get(
            `/tasks/${taskId}/steps/${entry.stepId}/artifacts`,
          );
          return response.data;
        },
      });
    };

    const outcomeIcon = entry.success ? (
      <CheckCircledIcon className="w-6 h-6 text-success" />
    ) : (
      <CrossCircledIcon className="w-6 h-6 text-destructive" />
    );

    cards.push(
      <div
        key={position}
        ref={(element) => {
          refs.current[reversedIndex] = element;
        }}
        className={cardClassName}
        onClick={() => onActiveIndexChange(reversedIndex)}
        onMouseEnter={prefetchArtifacts}
      >
        <div className="flex-1 p-2 pt-0 space-y-2">
          <div className="flex justify-between">
            <div className="flex gap-2 items-center">
              <span>#{position + 1}</span>
              <Badge>{ReadableActionTypes[entry.type]}</Badge>
            </div>
            <div className="flex items-center gap-2">
              {typeof entry.confidence === "number" && (
                <TooltipProvider>
                  <Tooltip>
                    <TooltipTrigger>
                      <Badge variant="secondary">{entry.confidence}</Badge>
                    </TooltipTrigger>
                    <TooltipContent>Confidence Score</TooltipContent>
                  </Tooltip>
                </TooltipProvider>
              )}
              {outcomeIcon}
            </div>
          </div>
          <div className="text-sm">{entry.reasoning}</div>
          {entry.type === ActionTypes.InputText && (
            <>
              <Separator className="bg-slate-50 block" />
              <div className="text-sm">Input: {entry.input}</div>
            </>
          )}
        </div>
      </div>,
    );
  }

  return cards;
}
const actionIndex =
typeof activeIndex === "number" ? data.length - activeIndex - 1 : "stream";
return (
<div className="w-1/3 flex flex-col items-center border rounded h-[40rem]">
@@ -61,89 +146,37 @@ function ScrollableActionList({
onPrevious();
}}
>
<ArrowLeftIcon />
<ArrowUpIcon />
</Button>
{activeIndex + 1} of {data.length} total actions
{typeof actionIndex === "number" &&
`#${actionIndex + 1} of ${data.length} total actions`}
{activeIndex === "stream" && "Livestream"}
<Button size="icon" onClick={() => onNext()}>
<ArrowRightIcon />
<ArrowDownIcon />
</Button>
</div>
<div className="overflow-y-scroll w-full px-4 pb-4 space-y-4">
{data.map((action, index) => {
if (!action) {
return null;
}
const selected = activeIndex === index;
return (
<div
key={index}
ref={(element) => {
refs.current[index] = element;
}}
className={cn(
"flex p-4 rounded-lg shadow-md border hover:border-slate-500 cursor-pointer",
{
"border-slate-500": selected,
},
)}
onClick={() => onActiveIndexChange(index)}
onMouseEnter={() => {
queryClient.prefetchQuery({
queryKey: [
"task",
taskId,
"steps",
action.stepId,
"artifacts",
],
queryFn: async () => {
const client = await getClient(credentialGetter);
return client
.get(`/tasks/${taskId}/steps/${action.stepId}/artifacts`)
.then((response) => response.data);
},
staleTime: Infinity,
});
}}
>
<div className="flex-1 p-2 pt-0 space-y-2">
<div className="flex justify-between">
<div className="flex gap-2 items-center">
<span>#{index + 1}</span>
<Badge>{ReadableActionTypes[action.type]}</Badge>
</div>
<div className="flex items-center gap-2">
{typeof action.confidence === "number" && (
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<Badge variant="secondary">
{action.confidence}
</Badge>
</TooltipTrigger>
<TooltipContent>Confidence Score</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
{action.success ? (
<CheckCircledIcon className="w-6 h-6 text-success" />
) : (
<CrossCircledIcon className="w-6 h-6 text-destructive" />
)}
</div>
</div>
<div className="text-sm">{action.reasoning}</div>
{action.type === ActionTypes.InputText && (
<>
<Separator className="bg-slate-50 block" />
<div className="text-sm">Input: {action.input}</div>
</>
)}
</div>
{showStreamOption && (
<div
key="stream"
ref={(element) => {
refs.current[data.length] = element;
}}
className={cn(
"flex p-4 rounded-lg shadow-md border hover:border-slate-300 cursor-pointer",
{
"border-slate-300": activeIndex === "stream",
},
)}
onClick={() => onActiveIndexChange("stream")}
>
<div className="text-lg flex gap-2 items-center">
<DotFilledIcon className="w-6 h-6 text-red-500" />
Live
</div>
);
})}
</div>
)}
{getReverseActions()}
</div>
</div>
);

View File

@@ -1,19 +1,179 @@
import { useState } from "react";
import { useEffect, useState } from "react";
import { useParams } from "react-router-dom";
import { ActionScreenshot } from "./ActionScreenshot";
import { ScrollableActionList } from "./ScrollableActionList";
import { useActions } from "./useActions";
import { keepPreviousData, useQuery } from "@tanstack/react-query";
import {
ActionApiResponse,
ActionTypes,
Status,
StepApiResponse,
TaskApiResponse,
} from "@/api/types";
import { getClient } from "@/api/AxiosClient";
import { useCredentialGetter } from "@/hooks/useCredentialGetter";
import { Skeleton } from "@/components/ui/skeleton";
import { toast } from "@/components/ui/use-toast";
/** Shape of a JSON message received from the task stream websocket. */
type StreamMessage = {
task_id: string;
// compared against "completed", "failed" and "terminated" by the message handler in this file
status: string;
// stored as the stream image source and rendered through a base64 PNG data URI
screenshot?: string;
};
// Module-level handle to the stream websocket; reset to null when the socket closes.
let socket: WebSocket | null = null;
// Websocket base URL, read from the VITE_WSS_BASE_URL environment variable.
const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL;
/**
 * Derives the display text for an action's input.
 *
 * @param action - Action payload as returned by the API.
 * @returns The typed text for input-text actions that carry text, the literal
 *   "Click" for click actions, the option label for select-option actions that
 *   carry an option, and an empty string in every other case.
 */
function getActionInput(action: ActionApiResponse) {
  switch (action.action_type) {
    case ActionTypes.InputText:
      if (action.text) {
        return action.text;
      }
      break;
    case ActionTypes.Click:
      return "Click";
    case ActionTypes.SelectOption:
      if (action.option) {
        return action.option.label;
      }
      break;
  }
  // No input to show for this action.
  return "";
}
function TaskActions() {
const { taskId } = useParams();
const credentialGetter = useCredentialGetter();
const [streamImgSrc, setStreamImgSrc] = useState<string>("");
const [selectedAction, setSelectedAction] = useState<number | "stream">(0);
const { data, isFetching } = useActions(taskId!);
const [selectedActionIndex, setSelectedAction] = useState(0);
const { data: task, isLoading: taskIsLoading } = useQuery<TaskApiResponse>({
queryKey: ["task", taskId],
queryFn: async () => {
const client = await getClient(credentialGetter);
return client.get(`/tasks/${taskId}`).then((response) => response.data);
},
refetchInterval: (query) => {
if (
query.state.data?.status === Status.Running ||
query.state.data?.status === Status.Queued
) {
return 5000;
}
return false;
},
placeholderData: keepPreviousData,
});
const taskIsRunningOrQueued =
task?.status === Status.Running || task?.status === Status.Queued;
const activeAction = data?.[selectedActionIndex];
useEffect(() => {
if (!taskIsRunningOrQueued) {
return;
}
if (isFetching) {
async function run() {
// Create WebSocket connection.
const credential = await credentialGetter!();
if (socket) {
socket.close();
}
socket = new WebSocket(
`${wssBaseUrl}/stream/tasks/${taskId}?token=Bearer ${credential}`,
);
// Listen for messages
socket.addEventListener("message", (event) => {
try {
const message: StreamMessage = JSON.parse(event.data);
if (message.screenshot) {
setStreamImgSrc(message.screenshot);
}
if (
message.status === "completed" ||
message.status === "failed" ||
message.status === "terminated"
) {
socket?.close();
setSelectedAction(0);
if (
message.status === "failed" ||
message.status === "terminated"
) {
toast({
title: "Task Failed",
description: "The task has failed.",
variant: "destructive",
});
} else if (message.status === "completed") {
toast({
title: "Task Completed",
description: "The task has been completed.",
variant: "success",
});
}
}
} catch (e) {
console.error("Failed to parse message", e);
}
});
socket.addEventListener("close", () => {
socket = null;
});
}
run();
return () => {
if (socket) {
socket.close();
socket = null;
}
};
}, [credentialGetter, taskId, taskIsRunningOrQueued]);
useEffect(() => {
if (!taskIsLoading && taskIsRunningOrQueued) {
setSelectedAction("stream");
}
}, [taskIsLoading, taskIsRunningOrQueued]);
const { data: steps, isLoading: stepsIsLoading } = useQuery<
Array<StepApiResponse>
>({
queryKey: ["task", taskId, "steps"],
queryFn: async () => {
const client = await getClient(credentialGetter);
return client
.get(`/tasks/${taskId}/steps`)
.then((response) => response.data);
},
enabled: !!task,
refetchOnWindowFocus: taskIsRunningOrQueued,
refetchInterval: taskIsRunningOrQueued ? 5000 : false,
placeholderData: keepPreviousData,
});
const actions = steps
?.map((step) => {
const actionsAndResults = step.output?.actions_and_results ?? [];
const actions = actionsAndResults.map((actionAndResult, index) => {
const action = actionAndResult[0];
const actionResult = actionAndResult[1];
if (actionResult.length === 0) {
return null;
}
return {
reasoning: action.reasoning,
confidence: action.confidence_float,
input: getActionInput(action),
type: action.action_type,
success: actionResult?.[0]?.success ?? false,
stepId: step.step_id,
index,
};
});
return actions;
})
.flat();
if (taskIsLoading || stepsIsLoading) {
return (
<div className="flex gap-2">
<div className="h-[40rem] w-3/4">
@@ -26,36 +186,102 @@ function TaskActions() {
);
}
if (!data) {
return <div>No actions</div>;
}
const activeAction =
typeof selectedAction === "number" &&
actions?.[actions.length - selectedAction - 1];
if (!activeAction) {
return <div>No active action</div>;
/**
 * Picks what the livestream pane shows, based on the task status and on
 * whether a stream frame has been stored in `streamImgSrc` yet.
 *
 * @returns A queue notice while the task is queued; while it is running,
 *   either a "starting" placeholder (no frame yet) or the latest frame;
 *   `null` for any other status.
 */
function getStream() {
  const status = task?.status;

  if (status === Status.Queued) {
    return (
      <div className="w-full h-full flex flex-col gap-4 items-center justify-center text-lg bg-slate-900">
        <span>Your task is queued. Typical queue time is 1-2 minutes.</span>
        <span>Stream will start when the task is running.</span>
      </div>
    );
  }

  if (status !== Status.Running) {
    return null;
  }

  const hasFrame = streamImgSrc.length > 0;
  return hasFrame ? (
    <div className="w-full h-full">
      <img src={`data:image/png;base64,${streamImgSrc}`} />
    </div>
  ) : (
    <div className="w-full h-full flex items-center justify-center text-lg bg-slate-900">
      Starting the stream...
    </div>
  );
}
return (
<div className="flex gap-2">
<div className="w-2/3 border rounded">
<div className="p-4">
<ActionScreenshot
stepId={activeAction.stepId}
index={activeAction.index}
/>
<div className="p-4 w-full h-full">
{selectedAction === "stream" ? getStream() : null}
{typeof selectedAction === "number" && activeAction ? (
<ActionScreenshot
stepId={activeAction.stepId}
index={activeAction.index}
/>
) : null}
</div>
</div>
<ScrollableActionList
activeIndex={selectedActionIndex}
data={data}
activeIndex={selectedAction}
data={actions ?? []}
onActiveIndexChange={setSelectedAction}
onNext={() =>
setSelectedAction((prev) =>
prev === data.length - 1 ? prev : prev + 1,
)
}
onPrevious={() =>
setSelectedAction((prev) => (prev === 0 ? prev : prev - 1))
}
showStreamOption={taskIsRunningOrQueued}
onNext={() => {
if (!actions) {
return;
}
setSelectedAction((prev) => {
if (taskIsRunningOrQueued) {
if (actions.length === 0) {
return "stream";
}
if (prev === actions.length - 1) {
return actions.length - 1;
}
if (prev === "stream") {
return 0;
}
return prev + 1;
}
if (typeof prev === "number") {
return prev === actions.length - 1 ? prev : prev + 1;
}
return 0;
});
}}
onPrevious={() => {
if (!actions) {
return;
}
setSelectedAction((prev) => {
if (taskIsRunningOrQueued) {
if (actions.length === 0) {
return "stream";
}
if (prev === 0) {
return "stream";
}
if (prev === "stream") {
return "stream";
}
return prev - 1;
}
if (typeof prev === "number") {
return prev === 0 ? prev : prev - 1;
}
return 0;
});
}}
/>
</div>
);

View File

@@ -1,7 +1,6 @@
import { getClient } from "@/api/AxiosClient";
import { Status, TaskApiResponse } from "@/api/types";
import { StatusBadge } from "@/components/StatusBadge";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Skeleton } from "@/components/ui/skeleton";
import { Textarea } from "@/components/ui/textarea";
@@ -16,11 +15,11 @@ function TaskDetails() {
const {
data: task,
isFetching: taskIsFetching,
isLoading: taskIsLoading,
isError: taskIsError,
error: taskError,
} = useQuery<TaskApiResponse>({
queryKey: ["task", taskId, "details"],
queryKey: ["task", taskId],
queryFn: async () => {
const client = await getClient(credentialGetter);
return client.get(`/tasks/${taskId}`).then((response) => response.data);
@@ -30,7 +29,7 @@ function TaskDetails() {
query.state.data?.status === Status.Running ||
query.state.data?.status === Status.Queued
) {
return 30000;
return 10000;
}
return false;
},
@@ -72,14 +71,14 @@ function TaskDetails() {
return (
<div className="flex flex-col gap-8">
<div className="flex items-center gap-4">
<Input value={taskId} className="w-52" readOnly />
{taskIsFetching ? (
<span className="text-lg">{taskId}</span>
{taskIsLoading ? (
<Skeleton className="w-28 h-8" />
) : task ? (
<StatusBadge status={task?.status} />
) : null}
</div>
{taskIsFetching ? (
{taskIsLoading ? (
<div className="flex items-center gap-2">
<Skeleton className="w-32 h-32" />
<Skeleton className="w-full h-32" />

View File

@@ -36,8 +36,6 @@ function TaskRecording() {
return <div>Error loading recording</div>;
}
console.log(task);
return (
<div className="flex mx-auto">
{task.recording_url ? (

View File

@@ -0,0 +1,104 @@
import { Status } from "@/api/types";
import { useCredentialGetter } from "@/hooks/useCredentialGetter";
import { useEffect, useState } from "react";
import { useParams } from "react-router-dom";
/** Shape of a JSON message received from the task stream websocket. */
type StreamMessage = {
task_id: string;
// inspected by the message handler to decide when to close the socket
status: string;
// stored as the image source and rendered through a base64 PNG data URI
screenshot?: string;
};
// Module-level handle to the stream websocket; reset to null when the socket closes.
let socket: WebSocket | null = null;
/** Props accepted by TaskStream. */
type Props = {
// task status; selects which placeholder or frame the component renders
status: Status;
};
// Websocket base URL, read from the VITE_WSS_BASE_URL environment variable.
const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL;
/**
 * Live view of a task, fed by screenshots pushed over a websocket.
 *
 * While mounted, the component opens a websocket for the task in the route
 * params and stores every screenshot it receives. The connection is closed
 * when the server reports a terminal status or when the effect is cleaned up.
 *
 * @param status - Current task status; selects what is rendered.
 * @returns A queue notice while queued; while running, a "starting"
 *   placeholder until the first frame arrives and the latest frame after
 *   that; `null` for any other status.
 */
function TaskStream({ status }: Props) {
  const { taskId } = useParams();
  const credentialGetter = useCredentialGetter();
  const [imgSrc, setImgSrc] = useState<string>("");

  useEffect(() => {
    if (!taskId || !credentialGetter) {
      console.error("TaskStream: Task ID and credential getter are required");
      return;
    }
    // Capture the narrowed getter so the async function needs no `!` assertion.
    const getCredential = credentialGetter;
    // Flipped by the cleanup; stops a late credential from opening a socket
    // for an effect that is already gone.
    let cancelled = false;
    // The socket opened by this effect run, if any.
    let ownSocket: WebSocket | null = null;

    async function run() {
      const credential = await getCredential();
      if (cancelled) {
        return;
      }
      // Only one stream connection is kept per module: drop any previous one.
      if (socket) {
        socket.close();
      }
      const currentSocket = new WebSocket(
        `${wssBaseUrl}/stream/tasks/${taskId}?token=Bearer ${credential}`,
      );
      ownSocket = currentSocket;
      socket = currentSocket;

      currentSocket.addEventListener("message", (event) => {
        try {
          const message: StreamMessage = JSON.parse(event.data);
          if (message.screenshot) {
            setImgSrc(message.screenshot);
          }
          // Every terminal status ends the stream, not only "completed".
          if (
            message.status === "completed" ||
            message.status === "failed" ||
            message.status === "terminated"
          ) {
            currentSocket.close();
          }
        } catch (e) {
          console.error("Failed to parse message", e);
        }
      });

      currentSocket.addEventListener("close", () => {
        // The close event is delivered asynchronously: by then a newer socket
        // may already be stored, and it must not be discarded.
        if (socket === currentSocket) {
          socket = null;
        }
      });
    }

    run().catch((e: unknown) => {
      console.error("TaskStream: failed to open the stream", e);
    });

    return () => {
      cancelled = true;
      if (ownSocket) {
        ownSocket.close();
        if (socket === ownSocket) {
          socket = null;
        }
        ownSocket = null;
      }
    };
  }, [credentialGetter, taskId]);

  if (status === Status.Queued) {
    return (
      <div className="w-full h-full flex flex-col gap-4 items-center justify-center text-lg bg-slate-900">
        <span>Your task is queued. Typical queue time is 1-2 minutes.</span>
        <span>Stream will start when the task is running.</span>
      </div>
    );
  }

  if (status === Status.Running && imgSrc.length === 0) {
    return (
      <div className="w-full h-full flex items-center justify-center text-lg bg-slate-900">
        Starting the stream...
      </div>
    );
  }

  if (status === Status.Running && imgSrc.length > 0) {
    return (
      <div className="w-full h-full">
        <img src={`data:image/png;base64,${imgSrc}`} />
      </div>
    );
  }

  return null;
}
export { TaskStream };

View File

@@ -1,84 +0,0 @@
import { getClient } from "@/api/AxiosClient";
import {
Action,
ActionApiResponse,
ActionTypes,
Status,
StepApiResponse,
TaskApiResponse,
} from "@/api/types";
import { useCredentialGetter } from "@/hooks/useCredentialGetter";
import { useQuery } from "@tanstack/react-query";
/**
 * Derives the display text for an action's input.
 *
 * @param action - Action payload as returned by the API.
 * @returns The typed text for input-text actions that carry text, the literal
 *   "Click" for click actions, the option label for select-option actions that
 *   carry an option, and an empty string in every other case.
 */
function getActionInput(action: ActionApiResponse) {
  if (action.action_type === ActionTypes.InputText && action.text) {
    return action.text;
  }
  if (action.action_type === ActionTypes.Click) {
    return "Click";
  }
  if (action.action_type === ActionTypes.SelectOption && action.option) {
    return action.option.label;
  }
  // No input to show for this action.
  return "";
}
/**
 * Loads a task and its steps and flattens the steps into one action list.
 *
 * The steps query only runs once the task has loaded. While the task is
 * running or queued the steps use a 30 ms stale time and refetch on window
 * focus; otherwise they never go stale.
 *
 * @param taskId - Identifier of the task whose actions are requested.
 * @returns `data`: one entry per action across all steps, with `null` for an
 *   action that has no results (undefined until the steps have loaded);
 *   `isFetching`: true while either the task or the steps query is fetching.
 */
function useActions(taskId: string): {
  data?: Array<Action | null>;
  isFetching: boolean;
} {
  const credentialGetter = useCredentialGetter();

  const taskQuery = useQuery<TaskApiResponse>({
    queryKey: ["task", taskId],
    queryFn: async () => {
      const client = await getClient(credentialGetter);
      const response = await client.get(`/tasks/${taskId}`);
      return response.data;
    },
  });

  const taskStatus = taskQuery.data?.status;
  const taskIsRunningOrQueued =
    taskStatus === Status.Running || taskStatus === Status.Queued;

  const { data: steps, isFetching: stepsIsFetching } = useQuery<
    Array<StepApiResponse>
  >({
    queryKey: ["task", taskId, "steps"],
    queryFn: async () => {
      const client = await getClient(credentialGetter);
      const response = await client.get(`/tasks/${taskId}/steps`);
      return response.data;
    },
    enabled: Boolean(taskQuery.data),
    staleTime: taskIsRunningOrQueued ? 30 : Infinity,
    refetchOnWindowFocus: taskIsRunningOrQueued,
  });

  // One flat list: every step contributes its actions in order.
  const flattened = steps?.flatMap((step) => {
    const pairs = step.output?.actions_and_results ?? [];
    return pairs.map((pair, index) => {
      const action = pair[0];
      const results = pair[1];
      // An action without any result is kept as a null placeholder.
      if (results.length === 0) {
        return null;
      }
      return {
        reasoning: action.reasoning,
        confidence: action.confidence_float,
        input: getActionInput(action),
        type: action.action_type,
        success: results?.[0]?.success ?? false,
        stepId: step.step_id,
        index,
      };
    });
  });

  return {
    data: flattened,
    isFetching: stepsIsFetching || taskQuery.isFetching,
  };
}
export { useActions };

View File

@@ -1,12 +1,12 @@
/**
 * Formats a timestamp for display, e.g. "Tue, Jun 25, 2024 at 12:18:44 PM".
 *
 * @param time - A date string accepted by the `Date` constructor.
 * @returns The en-US date (short weekday, short month, numeric day and year)
 *   and the en-US time, joined by " at ".
 */
function basicTimeFormat(time: string): string {
  const date = new Date(time);
  const dateString = date.toLocaleDateString("en-US", {
    weekday: "short",
    year: "numeric",
    month: "short",
    day: "numeric",
  });
  const timeString = date.toLocaleTimeString("en-US");
  return `${dateString} at ${timeString}`;
}