diff --git a/skyvern-frontend/src/api/types.ts b/skyvern-frontend/src/api/types.ts index 4c471056..b54ebc3e 100644 --- a/skyvern-frontend/src/api/types.ts +++ b/skyvern-frontend/src/api/types.ts @@ -392,6 +392,10 @@ export type CreditCardCredential = { card_holder_name: string; }; +export type ModelsResponse = { + models: string[]; +}; + export const RunEngine = { SkyvernV1: "skyvern-1.0", SkyvernV2: "skyvern-2.0", diff --git a/skyvern-frontend/src/components/ModelSelector.tsx b/skyvern-frontend/src/components/ModelSelector.tsx new file mode 100644 index 00000000..73677d7b --- /dev/null +++ b/skyvern-frontend/src/components/ModelSelector.tsx @@ -0,0 +1,66 @@ +import { HelpTooltip } from "@/components/HelpTooltip"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectTrigger, + SelectContent, + SelectValue, + SelectItem, +} from "@/components/ui/select"; +import { getClient } from "@/api/AxiosClient"; +import { useQuery } from "@tanstack/react-query"; +import { useCredentialGetter } from "@/hooks/useCredentialGetter"; +import { ModelsResponse } from "@/api/types"; +import { WorkflowModel } from "@/routes/workflows/types/workflowTypes"; + +type Props = { + className?: string; + value: WorkflowModel | null; + // -- + onChange: (value: WorkflowModel | null) => void; +}; + +function ModelSelector({ value, onChange, className }: Props) { + const credentialGetter = useCredentialGetter(); + + const { data: availableModels } = useQuery({ + queryKey: ["models"], + queryFn: async () => { + const client = await getClient(credentialGetter); + return client.get("/models").then((res) => res.data); + }, + }); + + const models = availableModels?.models ?? []; + + return ( +
+
+ + +
+ +
+ ); +} + +ModelSelector.displayName = "ModelSelector"; + +export { ModelSelector }; diff --git a/skyvern-frontend/src/routes/workflows/editor/FlowRenderer.tsx b/skyvern-frontend/src/routes/workflows/editor/FlowRenderer.tsx index b41f6732..cfebc235 100644 --- a/skyvern-frontend/src/routes/workflows/editor/FlowRenderer.tsx +++ b/skyvern-frontend/src/routes/workflows/editor/FlowRenderer.tsx @@ -233,6 +233,7 @@ function FlowRenderer({ proxy_location: data.settings.proxyLocation, webhook_callback_url: data.settings.webhookCallbackUrl, persist_browser_session: data.settings.persistBrowserSession, + model: data.settings.model, totp_verification_url: workflow.totp_verification_url, workflow_definition: { parameters: data.parameters, diff --git a/skyvern-frontend/src/routes/workflows/editor/WorkflowEditor.tsx b/skyvern-frontend/src/routes/workflows/editor/WorkflowEditor.tsx index 70728ea8..1cdf2157 100644 --- a/skyvern-frontend/src/routes/workflows/editor/WorkflowEditor.tsx +++ b/skyvern-frontend/src/routes/workflows/editor/WorkflowEditor.tsx @@ -58,6 +58,7 @@ function WorkflowEditor() { persistBrowserSession: workflow.persist_browser_session, proxyLocation: workflow.proxy_location, webhookCallbackUrl: workflow.webhook_callback_url, + model: workflow.model, }; const elements = getElements( diff --git a/skyvern-frontend/src/routes/workflows/editor/nodes/ActionNode/ActionNode.tsx b/skyvern-frontend/src/routes/workflows/editor/nodes/ActionNode/ActionNode.tsx index a9699c0f..88721c3e 100644 --- a/skyvern-frontend/src/routes/workflows/editor/nodes/ActionNode/ActionNode.tsx +++ b/skyvern-frontend/src/routes/workflows/editor/nodes/ActionNode/ActionNode.tsx @@ -35,6 +35,7 @@ import { getAvailableOutputParameterKeys } from "../../workflowEditorUtils"; import { ParametersMultiSelect } from "../TaskNode/ParametersMultiSelect"; import { useIsFirstBlockInWorkflow } from "../../hooks/useIsFirstNodeInWorkflow"; import { RunEngineSelector } from "@/components/EngineSelector"; +import { ModelSelector } from "@/components/ModelSelector"; const urlTooltip = "The URL Skyvern is navigating to. Leave this field blank to pick up from where the last block left off."; @@ -59,6 +60,7 @@ function ActionNode({ id, data }: NodeProps) { cacheActions: data.cacheActions, downloadSuffix: data.downloadSuffix, totpVerificationUrl: data.totpVerificationUrl, + model: data.model, totpIdentifier: data.totpIdentifier, engine: data.engine, }); @@ -175,6 +177,13 @@ function ActionNode({ id, data }: NodeProps) {
+ { + handleChange("model", value); + }} + /> ) { const { updateNodeData } = useReactFlow(); @@ -50,6 +51,7 @@ function ExtractionNode({ id, data }: NodeProps) { continueOnFailure: data.continueOnFailure, cacheActions: data.cacheActions, engine: data.engine, + model: data.model, }); const deleteNodeCallback = useDeleteNodeCallback(); const nodes = useNodes(); @@ -153,6 +155,13 @@ function ExtractionNode({ id, data }: NodeProps) {
+ { + handleChange("model", value); + }} + /> ) { totpVerificationUrl: data.totpVerificationUrl, totpIdentifier: data.totpIdentifier, engine: data.engine, + model: data.model, }); const deleteNodeCallback = useDeleteNodeCallback(); @@ -167,6 +169,13 @@ function FileDownloadNode({ id, data }: NodeProps) {
+ { + handleChange("model", value); + }} + /> ) { const { updateNodeData } = useReactFlow(); const { editable } = data; @@ -55,6 +57,7 @@ function LoginNode({ id, data }: NodeProps) { completeCriterion: data.completeCriterion, terminateCriterion: data.terminateCriterion, engine: data.engine, + model: data.model, }); const deleteNodeCallback = useDeleteNodeCallback(); @@ -177,6 +180,13 @@ function LoginNode({ id, data }: NodeProps) {
+ { + handleChange("model", value); + }} + /> ) { const { updateNodeData } = useReactFlow(); @@ -58,6 +59,7 @@ function NavigationNode({ id, data }: NodeProps) { completeCriterion: data.completeCriterion, terminateCriterion: data.terminateCriterion, engine: data.engine, + model: data.model, includeActionHistoryInVerification: data.includeActionHistoryInVerification, }); const deleteNodeCallback = useDeleteNodeCallback(); @@ -198,6 +200,13 @@ function NavigationNode({ id, data }: NodeProps) { />
+ { + handleChange("model", value); + }} + />