improve validations on parameter run ui (#4000)
Co-authored-by: Jonathan Dobson <jon.m.dobson@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import { AxiosError } from "axios";
|
import { AxiosError } from "axios";
|
||||||
import { PlayIcon, ReloadIcon } from "@radix-ui/react-icons";
|
import { PlayIcon, ReloadIcon } from "@radix-ui/react-icons";
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { useForm } from "react-hook-form";
|
import { type FieldErrors, useForm } from "react-hook-form";
|
||||||
import { useNavigate, useParams } from "react-router-dom";
|
import { useNavigate, useParams } from "react-router-dom";
|
||||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
@@ -99,6 +99,34 @@ function parseValuesForWorkflowRun(
|
|||||||
) {
|
) {
|
||||||
return [key, value.s3uri];
|
return [key, value.s3uri];
|
||||||
}
|
}
|
||||||
|
// Convert boolean values to strings for backend storage
|
||||||
|
if (
|
||||||
|
parameter?.workflow_parameter_type === "boolean" &&
|
||||||
|
typeof value === "boolean"
|
||||||
|
) {
|
||||||
|
return [key, String(value)];
|
||||||
|
}
|
||||||
|
if (parameter?.workflow_parameter_type === "string") {
|
||||||
|
if (value === null || value === undefined) {
|
||||||
|
return [key, ""];
|
||||||
|
}
|
||||||
|
return [key, String(value)];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
parameter?.workflow_parameter_type === "integer" ||
|
||||||
|
parameter?.workflow_parameter_type === "float"
|
||||||
|
) {
|
||||||
|
if (
|
||||||
|
value === null ||
|
||||||
|
value === undefined ||
|
||||||
|
(typeof value === "number" && Number.isNaN(value))
|
||||||
|
) {
|
||||||
|
return [key, ""];
|
||||||
|
}
|
||||||
|
return [key, String(value)];
|
||||||
|
}
|
||||||
|
|
||||||
return [key, value];
|
return [key, value];
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
@@ -211,6 +239,8 @@ function RunWorkflowForm({
|
|||||||
const { data: workflow } = useWorkflowQuery({ workflowPermanentId });
|
const { data: workflow } = useWorkflowQuery({ workflowPermanentId });
|
||||||
|
|
||||||
const form = useForm<RunWorkflowFormType>({
|
const form = useForm<RunWorkflowFormType>({
|
||||||
|
mode: "onTouched",
|
||||||
|
reValidateMode: "onChange",
|
||||||
defaultValues: {
|
defaultValues: {
|
||||||
...initialValues,
|
...initialValues,
|
||||||
webhookCallbackUrl: initialSettings.webhookCallbackUrl,
|
webhookCallbackUrl: initialSettings.webhookCallbackUrl,
|
||||||
@@ -268,6 +298,7 @@ function RunWorkflowForm({
|
|||||||
unknown
|
unknown
|
||||||
> | null>(null);
|
> | null>(null);
|
||||||
const [cacheKeyValue, setCacheKeyValue] = useState<string>("");
|
const [cacheKeyValue, setCacheKeyValue] = useState<string>("");
|
||||||
|
const [isFormReset, setIsFormReset] = useState(false);
|
||||||
const cacheKey = workflow?.cache_key ?? "default";
|
const cacheKey = workflow?.cache_key ?? "default";
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -297,10 +328,41 @@ function RunWorkflowForm({
|
|||||||
setHasCode(Object.keys(blockScripts ?? {}).length > 0);
|
setHasCode(Object.keys(blockScripts ?? {}).length > 0);
|
||||||
}, [blockScripts]);
|
}, [blockScripts]);
|
||||||
|
|
||||||
|
// Watch form changes and update run parameters without triggering validation
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
onChange(form.getValues());
|
const subscription = form.watch((values) => {
|
||||||
|
onChange(values as RunWorkflowFormType);
|
||||||
|
});
|
||||||
|
return () => subscription.unsubscribe();
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, [form]);
|
}, []);
|
||||||
|
|
||||||
|
// Reset form with initial values after all fields are registered
|
||||||
|
useEffect(() => {
|
||||||
|
form.reset({
|
||||||
|
...initialValues,
|
||||||
|
webhookCallbackUrl: initialSettings.webhookCallbackUrl,
|
||||||
|
proxyLocation: initialSettings.proxyLocation,
|
||||||
|
browserSessionId: null,
|
||||||
|
cdpAddress: initialSettings.cdpAddress,
|
||||||
|
maxScreenshotScrolls: initialSettings.maxScreenshotScrolls,
|
||||||
|
extraHttpHeaders: initialSettings.extraHttpHeaders
|
||||||
|
? JSON.stringify(initialSettings.extraHttpHeaders)
|
||||||
|
: null,
|
||||||
|
runWithCode: workflow?.run_with === "code",
|
||||||
|
aiFallback: workflow?.ai_fallback ?? true,
|
||||||
|
});
|
||||||
|
setIsFormReset(true);
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Trigger validation after form is reset and re-rendered
|
||||||
|
useEffect(() => {
|
||||||
|
if (isFormReset) {
|
||||||
|
form.trigger();
|
||||||
|
}
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [isFormReset]);
|
||||||
|
|
||||||
// if we're coming from debugger, block scripts may already be cached; let's ensure we bust it
|
// if we're coming from debugger, block scripts may already be cached; let's ensure we bust it
|
||||||
// on mount
|
// on mount
|
||||||
@@ -360,6 +422,22 @@ function RunWorkflowForm({
|
|||||||
setRunParameters(parsedParameters);
|
setRunParameters(parsedParameters);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleInvalid = (errors: FieldErrors<RunWorkflowFormType>) => {
|
||||||
|
const hasBlockingErrors = workflowParameters.some(
|
||||||
|
(param) =>
|
||||||
|
(param.workflow_parameter_type === "boolean" ||
|
||||||
|
param.workflow_parameter_type === "integer" ||
|
||||||
|
param.workflow_parameter_type === "float" ||
|
||||||
|
param.workflow_parameter_type === "file_url" ||
|
||||||
|
param.workflow_parameter_type === "json") &&
|
||||||
|
errors[param.key],
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!hasBlockingErrors) {
|
||||||
|
onSubmit(form.getValues());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if (!workflowPermanentId || !workflow) {
|
if (!workflowPermanentId || !workflow) {
|
||||||
return <div>Invalid workflow</div>;
|
return <div>Invalid workflow</div>;
|
||||||
}
|
}
|
||||||
@@ -367,8 +445,7 @@ function RunWorkflowForm({
|
|||||||
return (
|
return (
|
||||||
<Form {...form}>
|
<Form {...form}>
|
||||||
<form
|
<form
|
||||||
onChange={form.handleSubmit(onChange)}
|
onSubmit={form.handleSubmit(onSubmit, handleInvalid)}
|
||||||
onSubmit={form.handleSubmit(onSubmit)}
|
|
||||||
className="space-y-8"
|
className="space-y-8"
|
||||||
>
|
>
|
||||||
<div className="space-y-8 rounded-lg bg-slate-elevation3 px-6 py-5">
|
<div className="space-y-8 rounded-lg bg-slate-elevation3 px-6 py-5">
|
||||||
@@ -383,19 +460,74 @@ function RunWorkflowForm({
|
|||||||
name={parameter.key}
|
name={parameter.key}
|
||||||
rules={{
|
rules={{
|
||||||
validate: (value) => {
|
validate: (value) => {
|
||||||
if (
|
if (parameter.workflow_parameter_type === "json") {
|
||||||
parameter.workflow_parameter_type === "json" &&
|
if (value === null || value === undefined) {
|
||||||
typeof value === "string"
|
return "This field is required";
|
||||||
) {
|
|
||||||
try {
|
|
||||||
JSON.parse(value);
|
|
||||||
return true;
|
|
||||||
} catch (e) {
|
|
||||||
return "Invalid JSON";
|
|
||||||
}
|
}
|
||||||
|
if (typeof value === "string") {
|
||||||
|
const trimmed = value.trim();
|
||||||
|
if (trimmed === "") {
|
||||||
|
return "This field is required";
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
JSON.parse(trimmed);
|
||||||
|
return true;
|
||||||
|
} catch (e) {
|
||||||
|
return "Invalid JSON";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
if (value === null) {
|
|
||||||
return "This field is required";
|
// Boolean parameters are required - show error and block submission
|
||||||
|
if (parameter.workflow_parameter_type === "boolean") {
|
||||||
|
if (value === null || value === undefined) {
|
||||||
|
return "This field is required";
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Numeric parameters are required - show error and block submission
|
||||||
|
if (
|
||||||
|
parameter.workflow_parameter_type === "integer" ||
|
||||||
|
parameter.workflow_parameter_type === "float"
|
||||||
|
) {
|
||||||
|
if (
|
||||||
|
value === null ||
|
||||||
|
value === undefined ||
|
||||||
|
Number.isNaN(value)
|
||||||
|
) {
|
||||||
|
return "This field is required";
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parameter.workflow_parameter_type === "file_url") {
|
||||||
|
if (
|
||||||
|
value === null ||
|
||||||
|
value === undefined ||
|
||||||
|
(typeof value === "string" && value.trim() === "") ||
|
||||||
|
(typeof value === "object" &&
|
||||||
|
value !== null &&
|
||||||
|
"s3uri" in value &&
|
||||||
|
!value.s3uri)
|
||||||
|
) {
|
||||||
|
return "This field is required";
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For string parameters, show warning but don't block
|
||||||
|
if (
|
||||||
|
parameter.workflow_parameter_type === "string" &&
|
||||||
|
(value === null || value === "")
|
||||||
|
) {
|
||||||
|
return "Warning: you left this field empty";
|
||||||
|
}
|
||||||
|
|
||||||
|
// For all other non-boolean types, show warning but don't block
|
||||||
|
if (value === null || value === undefined) {
|
||||||
|
return "Warning: you left this field empty";
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
@@ -403,7 +535,7 @@ function RunWorkflowForm({
|
|||||||
return (
|
return (
|
||||||
<FormItem>
|
<FormItem>
|
||||||
<div className="flex gap-16">
|
<div className="flex gap-16">
|
||||||
<FormLabel>
|
<FormLabel className="!text-slate-50">
|
||||||
<div className="w-72">
|
<div className="w-72">
|
||||||
<div className="flex items-center gap-2 text-lg">
|
<div className="flex items-center gap-2 text-lg">
|
||||||
{parameter.key}
|
{parameter.key}
|
||||||
@@ -423,11 +555,27 @@ function RunWorkflowForm({
|
|||||||
<WorkflowParameterInput
|
<WorkflowParameterInput
|
||||||
type={parameter.workflow_parameter_type}
|
type={parameter.workflow_parameter_type}
|
||||||
value={field.value}
|
value={field.value}
|
||||||
onChange={field.onChange}
|
onChange={(value) => {
|
||||||
|
field.onChange(value);
|
||||||
|
form.trigger(parameter.key);
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
{form.formState.errors[parameter.key] && (
|
{form.formState.errors[parameter.key] && (
|
||||||
<div className="text-destructive">
|
<div
|
||||||
|
className={`text-xs ${
|
||||||
|
parameter.workflow_parameter_type ===
|
||||||
|
"boolean" ||
|
||||||
|
parameter.workflow_parameter_type ===
|
||||||
|
"integer" ||
|
||||||
|
parameter.workflow_parameter_type === "float" ||
|
||||||
|
parameter.workflow_parameter_type ===
|
||||||
|
"file_url" ||
|
||||||
|
parameter.workflow_parameter_type === "json"
|
||||||
|
? "text-destructive"
|
||||||
|
: "text-warning"
|
||||||
|
}`}
|
||||||
|
>
|
||||||
{form.formState.errors[parameter.key]?.message}
|
{form.formState.errors[parameter.key]?.message}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
import { FileInputValue, FileUpload } from "@/components/FileUpload";
|
import { FileInputValue, FileUpload } from "@/components/FileUpload";
|
||||||
import { Checkbox } from "@/components/ui/checkbox";
|
|
||||||
import { Input } from "@/components/ui/input";
|
import { Input } from "@/components/ui/input";
|
||||||
|
import {
|
||||||
|
Select,
|
||||||
|
SelectContent,
|
||||||
|
SelectItem,
|
||||||
|
SelectTrigger,
|
||||||
|
SelectValue,
|
||||||
|
} from "@/components/ui/select";
|
||||||
import { CodeEditor } from "./components/CodeEditor";
|
import { CodeEditor } from "./components/CodeEditor";
|
||||||
import { AutoResizingTextarea } from "@/components/AutoResizingTextarea/AutoResizingTextarea";
|
import { AutoResizingTextarea } from "@/components/AutoResizingTextarea/AutoResizingTextarea";
|
||||||
import { Label } from "@/components/ui/label";
|
|
||||||
import { WorkflowParameterValueType } from "./types/workflowTypes";
|
import { WorkflowParameterValueType } from "./types/workflowTypes";
|
||||||
import { CredentialSelector } from "./components/CredentialSelector";
|
import { CredentialSelector } from "./components/CredentialSelector";
|
||||||
|
|
||||||
@@ -60,16 +65,19 @@ function WorkflowParameterInput({ type, value, onChange }: Props) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (type === "boolean") {
|
if (type === "boolean") {
|
||||||
const checked = typeof value === "boolean" ? value : Boolean(value);
|
|
||||||
return (
|
return (
|
||||||
<div className="flex items-center gap-2">
|
<Select
|
||||||
<Checkbox
|
value={value === null ? "" : String(value)}
|
||||||
checked={checked}
|
onValueChange={(v) => onChange(v === "true")}
|
||||||
onCheckedChange={(checked) => onChange(checked)}
|
>
|
||||||
className="block"
|
<SelectTrigger className="w-48">
|
||||||
/>
|
<SelectValue placeholder="Select value..." />
|
||||||
<Label>{value ? "True" : "False"}</Label>
|
</SelectTrigger>
|
||||||
</div>
|
<SelectContent>
|
||||||
|
<SelectItem value="true">True</SelectItem>
|
||||||
|
<SelectItem value="false">False</SelectItem>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -116,6 +116,26 @@ function convertToParametersYAML(
|
|||||||
| CredentialParameterYAML
|
| CredentialParameterYAML
|
||||||
| undefined => {
|
| undefined => {
|
||||||
if (parameter.parameterType === WorkflowEditorParameterTypes.Workflow) {
|
if (parameter.parameterType === WorkflowEditorParameterTypes.Workflow) {
|
||||||
|
// Convert boolean default values to strings for backend
|
||||||
|
let defaultValue = parameter.defaultValue;
|
||||||
|
if (
|
||||||
|
parameter.dataType === "boolean" &&
|
||||||
|
typeof parameter.defaultValue === "boolean"
|
||||||
|
) {
|
||||||
|
defaultValue = String(parameter.defaultValue);
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
(parameter.dataType === "integer" ||
|
||||||
|
parameter.dataType === "float") &&
|
||||||
|
(typeof parameter.defaultValue === "number" ||
|
||||||
|
typeof parameter.defaultValue === "string")
|
||||||
|
) {
|
||||||
|
defaultValue =
|
||||||
|
parameter.defaultValue === null
|
||||||
|
? parameter.defaultValue
|
||||||
|
: String(parameter.defaultValue);
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
parameter_type: WorkflowParameterTypes.Workflow,
|
parameter_type: WorkflowParameterTypes.Workflow,
|
||||||
key: parameter.key,
|
key: parameter.key,
|
||||||
@@ -123,7 +143,7 @@ function convertToParametersYAML(
|
|||||||
workflow_parameter_type: parameter.dataType,
|
workflow_parameter_type: parameter.dataType,
|
||||||
...(parameter.defaultValue === null
|
...(parameter.defaultValue === null
|
||||||
? {}
|
? {}
|
||||||
: { default_value: parameter.defaultValue }),
|
: { default_value: defaultValue }),
|
||||||
};
|
};
|
||||||
} else if (
|
} else if (
|
||||||
parameter.parameterType === WorkflowEditorParameterTypes.Context
|
parameter.parameterType === WorkflowEditorParameterTypes.Context
|
||||||
|
|||||||
@@ -538,11 +538,32 @@ function WorkflowParameterEditPanel({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const defaultValue =
|
let defaultValue = defaultValueState.defaultValue;
|
||||||
|
|
||||||
|
// Handle JSON parsing
|
||||||
|
if (
|
||||||
parameterType === "json" &&
|
parameterType === "json" &&
|
||||||
typeof defaultValueState.defaultValue === "string"
|
typeof defaultValueState.defaultValue === "string"
|
||||||
? JSON.parse(defaultValueState.defaultValue)
|
) {
|
||||||
: defaultValueState.defaultValue;
|
defaultValue = JSON.parse(defaultValueState.defaultValue);
|
||||||
|
}
|
||||||
|
// Convert boolean to string for backend storage
|
||||||
|
else if (
|
||||||
|
parameterType === "boolean" &&
|
||||||
|
typeof defaultValueState.defaultValue === "boolean"
|
||||||
|
) {
|
||||||
|
defaultValue = String(defaultValueState.defaultValue);
|
||||||
|
}
|
||||||
|
// Convert numeric defaults to strings for backend storage
|
||||||
|
else if (
|
||||||
|
(parameterType === "integer" ||
|
||||||
|
parameterType === "float") &&
|
||||||
|
(typeof defaultValueState.defaultValue === "number" ||
|
||||||
|
typeof defaultValueState.defaultValue === "string")
|
||||||
|
) {
|
||||||
|
defaultValue = String(defaultValueState.defaultValue);
|
||||||
|
}
|
||||||
|
|
||||||
onSave({
|
onSave({
|
||||||
key,
|
key,
|
||||||
parameterType: "workflow",
|
parameterType: "workflow",
|
||||||
|
|||||||
@@ -1852,11 +1852,30 @@ function convertParametersToParameterYAML(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
case WorkflowParameterTypes.Workflow: {
|
case WorkflowParameterTypes.Workflow: {
|
||||||
|
// Convert default values to strings for backend when needed
|
||||||
|
let defaultValue = parameter.default_value;
|
||||||
|
if (
|
||||||
|
parameter.workflow_parameter_type === "boolean" &&
|
||||||
|
typeof parameter.default_value === "boolean"
|
||||||
|
) {
|
||||||
|
defaultValue = String(parameter.default_value);
|
||||||
|
} else if (
|
||||||
|
(parameter.workflow_parameter_type === "integer" ||
|
||||||
|
parameter.workflow_parameter_type === "float") &&
|
||||||
|
(typeof parameter.default_value === "number" ||
|
||||||
|
typeof parameter.default_value === "string")
|
||||||
|
) {
|
||||||
|
defaultValue =
|
||||||
|
parameter.default_value === null
|
||||||
|
? parameter.default_value
|
||||||
|
: String(parameter.default_value);
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...base,
|
...base,
|
||||||
parameter_type: WorkflowParameterTypes.Workflow,
|
parameter_type: WorkflowParameterTypes.Workflow,
|
||||||
workflow_parameter_type: parameter.workflow_parameter_type,
|
workflow_parameter_type: parameter.workflow_parameter_type,
|
||||||
default_value: parameter.default_value,
|
default_value: defaultValue,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
case WorkflowParameterTypes.Credential: {
|
case WorkflowParameterTypes.Credential: {
|
||||||
|
|||||||
@@ -12,31 +12,24 @@ export const getInitialValues = (
|
|||||||
? location.state.data
|
? location.state.data
|
||||||
: workflowParameters?.reduce(
|
: workflowParameters?.reduce(
|
||||||
(acc, curr) => {
|
(acc, curr) => {
|
||||||
if (curr.workflow_parameter_type === "json") {
|
const hasDefaultValue =
|
||||||
if (typeof curr.default_value === "string") {
|
curr.default_value !== null && curr.default_value !== undefined;
|
||||||
acc[curr.key] = curr.default_value;
|
if (hasDefaultValue) {
|
||||||
|
// Handle JSON parameters
|
||||||
|
if (curr.workflow_parameter_type === "json") {
|
||||||
|
if (typeof curr.default_value === "string") {
|
||||||
|
acc[curr.key] = curr.default_value;
|
||||||
|
} else {
|
||||||
|
acc[curr.key] = JSON.stringify(curr.default_value, null, 2);
|
||||||
|
}
|
||||||
return acc;
|
return acc;
|
||||||
}
|
}
|
||||||
if (curr.default_value) {
|
if (curr.workflow_parameter_type === "boolean") {
|
||||||
acc[curr.key] = JSON.stringify(curr.default_value, null, 2);
|
// Backend stores as strings, convert to boolean for frontend
|
||||||
|
acc[curr.key] =
|
||||||
|
curr.default_value === "true" || curr.default_value === true;
|
||||||
return acc;
|
return acc;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if (
|
|
||||||
curr.default_value &&
|
|
||||||
curr.workflow_parameter_type === "boolean"
|
|
||||||
) {
|
|
||||||
acc[curr.key] = Boolean(curr.default_value);
|
|
||||||
return acc;
|
|
||||||
}
|
|
||||||
if (
|
|
||||||
curr.default_value === null &&
|
|
||||||
curr.workflow_parameter_type === "string"
|
|
||||||
) {
|
|
||||||
acc[curr.key] = "";
|
|
||||||
return acc;
|
|
||||||
}
|
|
||||||
if (curr.default_value) {
|
|
||||||
acc[curr.key] = curr.default_value;
|
acc[curr.key] = curr.default_value;
|
||||||
return acc;
|
return acc;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -673,6 +673,12 @@ class LLMAPIHandlerFactory:
|
|||||||
# Add Vertex AI cache reference only for the intended cached prompt
|
# Add Vertex AI cache reference only for the intended cached prompt
|
||||||
vertex_cache_attached = False
|
vertex_cache_attached = False
|
||||||
cache_resource_name = getattr(context, "vertex_cache_name", None)
|
cache_resource_name = getattr(context, "vertex_cache_name", None)
|
||||||
|
LOG.info(
|
||||||
|
"Vertex cache attachment check",
|
||||||
|
cache_resource_name=cache_resource_name,
|
||||||
|
prompt_name=prompt_name,
|
||||||
|
use_prompt_caching=getattr(context, "use_prompt_caching", None) if context else None,
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
cache_resource_name
|
cache_resource_name
|
||||||
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
|
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
|
||||||
|
|||||||
@@ -8,16 +8,13 @@ Unlike the Anthropic-style cache_control markers, Vertex AI requires:
|
|||||||
3. Referencing that cache name in subsequent requests
|
3. Referencing that cache name in subsequent requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import google.auth
|
import google.auth
|
||||||
import requests
|
import requests
|
||||||
import structlog
|
import structlog
|
||||||
from google.auth.credentials import Credentials
|
|
||||||
from google.auth.transport.requests import Request
|
from google.auth.transport.requests import Request
|
||||||
from google.oauth2 import service_account
|
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
|
|
||||||
@@ -32,7 +29,7 @@ class VertexCacheManager:
|
|||||||
unlike implicit caching which requires exact prompt matches.
|
unlike implicit caching which requires exact prompt matches.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, project_id: str, location: str = "global", credentials_json: str | None = None):
|
def __init__(self, project_id: str, location: str = "global"):
|
||||||
self.project_id = project_id
|
self.project_id = project_id
|
||||||
self.location = location
|
self.location = location
|
||||||
# Use regional endpoint for non-global locations, global endpoint for global
|
# Use regional endpoint for non-global locations, global endpoint for global
|
||||||
@@ -41,39 +38,13 @@ class VertexCacheManager:
|
|||||||
else:
|
else:
|
||||||
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
|
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
|
||||||
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
|
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
|
||||||
self._scopes = ["https://www.googleapis.com/auth/cloud-platform"]
|
|
||||||
self._default_credentials = None
|
|
||||||
self._service_account_credentials = None
|
|
||||||
self._service_account_info: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
if credentials_json:
|
|
||||||
try:
|
|
||||||
self._service_account_info = json.loads(credentials_json)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
LOG.warning("Failed to parse Vertex credentials JSON, falling back to ADC", error=str(exc))
|
|
||||||
|
|
||||||
def _get_access_token(self) -> str:
|
def _get_access_token(self) -> str:
|
||||||
"""Get Google Cloud access token for API calls."""
|
"""Get Google Cloud access token for API calls."""
|
||||||
try:
|
try:
|
||||||
credentials: Credentials | None = None
|
# Try to use default credentials
|
||||||
if self._service_account_info:
|
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||||
if not self._service_account_credentials:
|
credentials.refresh(Request())
|
||||||
self._service_account_credentials = service_account.Credentials.from_service_account_info(
|
|
||||||
self._service_account_info,
|
|
||||||
scopes=self._scopes,
|
|
||||||
)
|
|
||||||
credentials = self._service_account_credentials
|
|
||||||
else:
|
|
||||||
if not self._default_credentials:
|
|
||||||
self._default_credentials, _ = google.auth.default(scopes=self._scopes)
|
|
||||||
credentials = self._default_credentials
|
|
||||||
|
|
||||||
if credentials is None:
|
|
||||||
raise RuntimeError("Unable to initialize Google credentials for Vertex cache manager")
|
|
||||||
|
|
||||||
if not credentials.valid or credentials.expired:
|
|
||||||
credentials.refresh(Request())
|
|
||||||
|
|
||||||
return credentials.token
|
return credentials.token
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error("Failed to get access token", error=str(e))
|
LOG.error("Failed to get access token", error=str(e))
|
||||||
@@ -226,11 +197,7 @@ def get_cache_manager() -> VertexCacheManager:
|
|||||||
# Default to "global" to match the model configs in cloud/__init__.py
|
# Default to "global" to match the model configs in cloud/__init__.py
|
||||||
# Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching)
|
# Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching)
|
||||||
location = settings.VERTEX_LOCATION or "global"
|
location = settings.VERTEX_LOCATION or "global"
|
||||||
_global_cache_manager = VertexCacheManager(
|
_global_cache_manager = VertexCacheManager(project_id, location)
|
||||||
project_id=project_id,
|
|
||||||
location=location,
|
|
||||||
credentials_json=settings.VERTEX_CREDENTIALS,
|
|
||||||
)
|
|
||||||
LOG.info("Created global cache manager", project_id=project_id, location=location)
|
LOG.info("Created global cache manager", project_id=project_id, location=location)
|
||||||
|
|
||||||
return _global_cache_manager
|
return _global_cache_manager
|
||||||
|
|||||||
Reference in New Issue
Block a user