feat: migrate fal workflow to server (#7581)

This commit is contained in:
darkskygit
2024-07-26 04:04:38 +00:00
parent cb0d91facd
commit 470262d400
24 changed files with 741 additions and 299 deletions

View File

@@ -288,6 +288,7 @@ export class CopilotController {
if (latestMessage) {
params = Object.assign({}, params, latestMessage.params, {
content: latestMessage.content,
attachments: latestMessage.attachments,
});
}
@@ -302,14 +303,22 @@ export class CopilotController {
merge(
// actual chat event stream
shared$.pipe(
map(data =>
data.status === GraphExecutorState.EmitContent
? {
map(data => {
switch (data.status) {
case GraphExecutorState.EmitContent:
return {
type: 'message' as const,
id: messageId,
data: data.content,
}
: {
};
case GraphExecutorState.EmitAttachment:
return {
type: 'attachment' as const,
id: messageId,
data: data.attachment,
};
default:
return {
type: 'event' as const,
id: messageId,
data: {
@@ -317,8 +326,9 @@ export class CopilotController {
id: data.node.id,
type: data.node.config.nodeType,
} as any,
}
)
};
}
})
),
// save the generated text to the session
shared$.pipe(
@@ -378,6 +388,7 @@ export class CopilotController {
const source$ = from(
provider.generateImagesStream(session.finish(params), session.model, {
...session.config.promptConfig,
seed: this.parseNumber(params.seed),
signal: this.getSignal(req),
user: user.id,

View File

@@ -27,8 +27,6 @@ function extractMustacheParams(template: string) {
return Array.from(new Set(params));
}
const EXCLUDE_MISSING_WARN_PARAMS = ['lora'];
export class ChatPrompt {
private readonly logger = new Logger(ChatPrompt.name);
public readonly encoder: Tokenizer | null;
@@ -104,12 +102,12 @@ export class ChatPrompt {
typeof income !== 'string' ||
(Array.isArray(options) && !options.includes(income))
) {
if (sessionId && !EXCLUDE_MISSING_WARN_PARAMS.includes(key)) {
if (sessionId) {
const prefix = income
? `Invalid param value: ${key}=${income}`
: `Missing param value: ${key}`;
this.logger.warn(
`${prefix} in session ${sessionId}, use default options: ${options[0]}`
`${prefix} in session ${sessionId}, use default options: ${Array.isArray(options) ? options[0] : options}`
);
}
if (Array.isArray(options)) {
@@ -129,11 +127,28 @@ export class ChatPrompt {
*/
finish(params: PromptParams, sessionId?: string): PromptMessage[] {
this.checkParams(params, sessionId);
return this.messages.map(({ content, params: _, ...rest }) => ({
...rest,
params,
content: Mustache.render(content, params),
}));
const { attachments: attach, ...restParams } = params;
const paramsAttach = Array.isArray(attach) ? attach : [];
return this.messages.map(
({ attachments: attach, content, params: _, ...rest }) => {
const result: PromptMessage = {
...rest,
params,
content: Mustache.render(content, restParams),
};
const attachments = [
...(Array.isArray(attach) ? attach : []),
...paramsAttach,
];
if (attachments.length && rest.role === 'user') {
result.attachments = attachments;
}
return result;
}
);
}
}

View File

@@ -59,9 +59,15 @@ const FalStreamOutputSchema = z.object({
});
type FalPrompt = {
model_name?: string;
image_url?: string;
prompt?: string;
lora?: string[];
loras?: { path: string; scale?: number }[];
controlnets?: {
image_url: string;
start_percentage?: number;
end_percentage?: number;
}[];
};
export class FalProvider
@@ -83,10 +89,8 @@ export class FalProvider
'face-to-sticker',
'imageutils/rembg',
'fast-sdxl/image-to-image',
'workflows/darkskygit/animie',
'workflows/darkskygit/clay',
'workflows/darkskygit/pixel-art',
'workflows/darkskygit/sketch',
'workflowutils/teed',
'lora/image-to-image',
// image to text
'llava-next',
];
@@ -112,7 +116,15 @@ export class FalProvider
return this.availableModels.includes(model);
}
private extractPrompt(message?: PromptMessage): FalPrompt {
private extractArray<T>(value: T | T[] | undefined): T[] {
if (Array.isArray(value)) return value;
return value ? [value] : [];
}
private extractPrompt(
message?: PromptMessage,
options: CopilotImageOptions = {}
): FalPrompt {
if (!message) throw new CopilotPromptInvalid('Prompt is empty');
const { content, attachments, params } = message;
// prompt attachments require at least one
@@ -122,17 +134,23 @@ export class FalProvider
if (Array.isArray(attachments) && attachments.length > 1) {
throw new CopilotPromptInvalid('Only one attachment is allowed');
}
const lora = (
params?.lora
? Array.isArray(params.lora)
? params.lora
: [params.lora]
: []
).filter(v => typeof v === 'string' && v.length);
const lora = [
...this.extractArray(params?.lora),
...this.extractArray(options.loras),
].filter(
(v): v is { path: string; scale?: number } =>
!!v && typeof v === 'object' && typeof v.path === 'string'
);
const controlnets = this.extractArray(params?.controlnets).filter(
(v): v is { image_url: string } =>
!!v && typeof v === 'object' && typeof v.image_url === 'string'
);
return {
model_name: options.modelName || undefined,
image_url: attachments?.[0],
prompt: content.trim(),
lora: lora.length ? lora : undefined,
loras: lora.length ? lora : undefined,
controlnets: controlnets.length ? controlnets : undefined,
};
}
@@ -246,7 +264,7 @@ export class FalProvider
options: CopilotImageOptions = {}
) {
// by default, image prompt assumes there is only one message
const prompt = this.extractPrompt(messages.pop());
const prompt = this.extractPrompt(messages.pop(), options);
if (model.startsWith('workflows/')) {
const stream = await falStream(model, { input: prompt });
return this.parseSchema(FalStreamOutputSchema, await stream.done())

View File

@@ -42,6 +42,7 @@ export class OpenAIProvider
readonly availableModels = [
// text to text
'gpt-4o',
'gpt-4o-mini',
'gpt-4-vision-preview',
'gpt-4-turbo-preview',
'gpt-3.5-turbo',

View File

@@ -50,7 +50,7 @@ const PureMessageSchema = z.object({
content: z.string(),
attachments: z.array(z.string()).optional().nullable(),
params: z
.record(z.union([z.string(), z.array(z.string())]))
.record(z.union([z.string(), z.array(z.string()), z.record(z.any())]))
.optional()
.nullable(),
});
@@ -64,12 +64,21 @@ export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>;
export const PromptConfigStrictSchema = z.object({
// openai
jsonMode: z.boolean().nullable().optional(),
frequencyPenalty: z.number().nullable().optional(),
presencePenalty: z.number().nullable().optional(),
temperature: z.number().nullable().optional(),
topP: z.number().nullable().optional(),
maxTokens: z.number().nullable().optional(),
// fal
modelName: z.string().nullable().optional(),
loras: z
.array(
z.object({ path: z.string(), scale: z.number().nullable().optional() })
)
.nullable()
.optional(),
});
export const PromptConfigSchema =
@@ -175,9 +184,13 @@ export type CopilotEmbeddingOptions = z.infer<
typeof CopilotEmbeddingOptionsSchema
>;
const CopilotImageOptionsSchema = CopilotProviderOptionsSchema.extend({
seed: z.number().optional(),
}).optional();
const CopilotImageOptionsSchema = CopilotProviderOptionsSchema.merge(
PromptConfigStrictSchema
)
.extend({
seed: z.number().optional(),
})
.optional();
export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;

View File

@@ -63,28 +63,31 @@ export class CopilotChatImageExecutor extends AutoRegisteredWorkflowExecutor {
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<NodeExecuteResult> {
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
const [{ paramKey, paramToucher, id }, prompt, provider] =
await this.initExecutor(data);
const finalMessage = prompt.finish(params);
const config = { ...prompt.config, ...options };
if (paramKey) {
// update params with custom key
const result = {
[paramKey]: await provider.generateImages(
finalMessage,
prompt.model,
config
),
};
yield {
type: NodeExecuteState.Params,
params: {
[paramKey]: await provider.generateImages(
finalMessage,
prompt.model,
options
),
},
params: paramToucher?.(result) ?? result,
};
} else {
for await (const content of provider.generateImagesStream(
for await (const attachment of provider.generateImagesStream(
finalMessage,
prompt.model,
options
config
)) {
yield { type: NodeExecuteState.Content, nodeId: id, content };
yield { type: NodeExecuteState.Attachment, nodeId: id, attachment };
}
}
}

View File

@@ -63,26 +63,29 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<NodeExecuteResult> {
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
const [{ paramKey, paramToucher, id }, prompt, provider] =
await this.initExecutor(data);
const finalMessage = prompt.finish(params);
const config = { ...prompt.config, ...options };
if (paramKey) {
// update params with custom key
const result = {
[paramKey]: await provider.generateText(
finalMessage,
prompt.model,
config
),
};
yield {
type: NodeExecuteState.Params,
params: {
[paramKey]: await provider.generateText(
finalMessage,
prompt.model,
options
),
},
params: paramToucher?.(result) ?? result,
};
} else {
for await (const content of provider.generateTextStream(
finalMessage,
prompt.model,
options
config
)) {
yield { type: NodeExecuteState.Content, nodeId: id, content };
}

View File

@@ -26,7 +26,7 @@ export class CopilotCheckHtmlExecutor extends AutoRegisteredWorkflowExecutor {
}
private async checkHtml(
content?: string | string[],
content?: string | string[] | Record<string, any>,
strict?: boolean
): Promise<boolean> {
try {

View File

@@ -25,7 +25,9 @@ export class CopilotCheckJsonExecutor extends AutoRegisteredWorkflowExecutor {
return NodeExecutorType.CheckJson;
}
private checkJson(content?: string | string[]): boolean {
private checkJson(
content?: string | string[] | Record<string, any>
): boolean {
try {
if (content && typeof content === 'string') {
JSON.parse(content);

View File

@@ -14,13 +14,15 @@ export enum NodeExecuteState {
EndRun,
Params,
Content,
Attachment,
}
export type NodeExecuteResult =
| { type: NodeExecuteState.StartRun; nodeId: string }
| { type: NodeExecuteState.EndRun; nextNode?: WorkflowNode }
| { type: NodeExecuteState.Params; params: WorkflowParams }
| { type: NodeExecuteState.Content; nodeId: string; content: string };
| { type: NodeExecuteState.Content; nodeId: string; content: string }
| { type: NodeExecuteState.Attachment; nodeId: string; attachment: string };
export abstract class NodeExecutor {
abstract get type(): NodeExecutorType;

View File

@@ -1,87 +0,0 @@
import { NodeExecutorType } from './executor';
import type { WorkflowGraphs, WorkflowNodeState } from './types';
import { WorkflowNodeType } from './types';
export const WorkflowGraphList: WorkflowGraphs = [
{
name: 'presentation',
graph: [
{
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step1',
paramKey: 'language',
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate presentation',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step2',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step 3: format presentation if needed',
nodeType: WorkflowNodeType.Decision,
condition: (nodeIds: string[], params: WorkflowNodeState) => {
const lines = params.content?.split('\n') || [];
return nodeIds[
Number(
!lines.some(line => {
try {
if (line.trim()) {
JSON.parse(line);
}
return false;
} catch {
return true;
}
})
)
];
},
edges: ['step4', 'step5'],
},
{
id: 'step4',
name: 'Step 4: format presentation',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step4',
edges: ['step5'],
},
{
id: 'step5',
name: 'Step 5: finish',
nodeType: WorkflowNodeType.Nope,
edges: [],
},
],
},
{
name: 'brainstorm',
graph: [
{
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:brainstorm:step1',
paramKey: 'language',
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate brainstorm mind map',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:brainstorm:step2',
edges: [],
},
],
},
];

View File

@@ -0,0 +1,25 @@
import { NodeExecutorType } from '../executor';
import { type WorkflowGraph, WorkflowNodeType } from '../types';
export const brainstorm: WorkflowGraph = {
name: 'brainstorm',
graph: [
{
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:brainstorm:step1',
paramKey: 'language',
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate brainstorm mind map',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:brainstorm:step2',
edges: [],
},
],
};

View File

@@ -0,0 +1,183 @@
import { NodeExecutorType } from '../executor';
import type { WorkflowGraph, WorkflowParams } from '../types';
import { WorkflowNodeType } from '../types';
export const sketch: WorkflowGraph = {
name: 'image-sketch',
graph: [
{
id: 'start',
name: 'Start: extract edge',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatImage,
promptName: 'debug:action:fal-teed',
paramKey: 'controlnets',
paramToucher: params => {
if (Array.isArray(params.controlnets)) {
const controlnets = params.controlnets.map(image_url => ({
path: 'diffusers/controlnet-canny-sdxl-1.0',
image_url,
start_percentage: 0.1,
end_percentage: 0.6,
}));
return { controlnets } as WorkflowParams;
} else {
return {};
}
},
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate tags',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:image-sketch:step2',
paramKey: 'tags',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step3: generate image',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatImage,
promptName: 'workflow:image-sketch:step3',
edges: [],
},
],
};
export const clay: WorkflowGraph = {
name: 'image-clay',
graph: [
{
id: 'start',
name: 'Start: extract edge',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatImage,
promptName: 'debug:action:fal-teed',
paramKey: 'controlnets',
paramToucher: params => {
if (Array.isArray(params.controlnets)) {
const controlnets = params.controlnets.map(image_url => ({
path: 'diffusers/controlnet-canny-sdxl-1.0',
image_url,
start_percentage: 0.1,
end_percentage: 0.6,
}));
return { controlnets } as WorkflowParams;
} else {
return {};
}
},
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate tags',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:image-clay:step2',
paramKey: 'tags',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step3: generate image',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatImage,
promptName: 'workflow:image-clay:step3',
edges: [],
},
],
};
export const anime: WorkflowGraph = {
name: 'image-anime',
graph: [
{
id: 'start',
name: 'Start: extract edge',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatImage,
promptName: 'debug:action:fal-teed',
paramKey: 'controlnets',
paramToucher: params => {
if (Array.isArray(params.controlnets)) {
const controlnets = params.controlnets.map(image_url => ({
path: 'diffusers/controlnet-canny-sdxl-1.0',
image_url,
start_percentage: 0.1,
end_percentage: 0.6,
}));
return { controlnets } as WorkflowParams;
} else {
return {};
}
},
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate tags',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:image-anime:step2',
paramKey: 'tags',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step3: generate image',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatImage,
promptName: 'workflow:image-anime:step3',
edges: [],
},
],
};
export const pixel: WorkflowGraph = {
name: 'image-pixel',
graph: [
{
id: 'start',
name: 'Start: extract edge',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatImage,
promptName: 'debug:action:fal-teed',
paramKey: 'controlnets',
paramToucher: params => {
if (Array.isArray(params.controlnets)) {
const controlnets = params.controlnets.map(image_url => ({
path: 'diffusers/controlnet-canny-sdxl-1.0',
image_url,
start_percentage: 0.1,
end_percentage: 0.6,
}));
return { controlnets } as WorkflowParams;
} else {
return {};
}
},
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate tags',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:image-pixel:step2',
paramKey: 'tags',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step3: generate image',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatImage,
promptName: 'workflow:image-pixel:step3',
edges: [],
},
],
};

View File

@@ -0,0 +1,13 @@
import type { WorkflowGraphs } from '../types';
import { brainstorm } from './brainstorm';
import { anime, clay, pixel, sketch } from './image-filter';
import { presentation } from './presentation';
export const WorkflowGraphList: WorkflowGraphs = [
brainstorm,
presentation,
sketch,
clay,
anime,
pixel,
];

View File

@@ -0,0 +1,63 @@
import { NodeExecutorType } from '../executor';
import type { WorkflowGraph, WorkflowNodeState } from '../types';
import { WorkflowNodeType } from '../types';
export const presentation: WorkflowGraph = {
name: 'presentation',
graph: [
{
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step1',
paramKey: 'language',
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate presentation',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step2',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step 3: format presentation if needed',
nodeType: WorkflowNodeType.Decision,
condition: (nodeIds: string[], params: WorkflowNodeState) => {
const lines = params.content?.split('\n') || [];
return nodeIds[
Number(
!lines.some(line => {
try {
if (line.trim()) {
JSON.parse(line);
}
return false;
} catch {
return true;
}
})
)
];
},
edges: ['step4', 'step5'],
},
{
id: 'step4',
name: 'Step 4: format presentation',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step4',
edges: ['step5'],
},
{
id: 'step5',
name: 'Step 5: finish',
nodeType: WorkflowNodeType.Nope,
edges: [],
},
],
};

View File

@@ -16,6 +16,7 @@ export type WorkflowNodeData = { id: string; name: string } & (
promptName?: string;
// update the prompt params by output with the custom key
paramKey?: string;
paramToucher?: (params: WorkflowParams) => WorkflowParams;
}
| {
nodeType: WorkflowNodeType.Decision;
@@ -44,5 +45,8 @@ export type WorkflowGraphs = Array<WorkflowGraph>;
// ===================== executor =====================
export type WorkflowParams = Record<string, string | string[]>;
export type WorkflowParams = Record<
string,
string | string[] | Record<string, any>
>;
export type WorkflowNodeState = Record<string, string>;

View File

@@ -9,12 +9,14 @@ import { WorkflowNodeType } from './types';
export enum GraphExecutorState {
EnterNode = 'EnterNode',
EmitContent = 'EmitContent',
EmitAttachment = 'EmitAttachment',
ExitNode = 'ExitNode',
}
export type GraphExecutorStatus = { status: GraphExecutorState } & (
| { status: GraphExecutorState.EnterNode; node: WorkflowNode }
| { status: GraphExecutorState.EmitContent; content: string }
| { status: GraphExecutorState.EmitAttachment; attachment: string }
| { status: GraphExecutorState.ExitNode; node: WorkflowNode }
);
@@ -66,6 +68,15 @@ export class WorkflowGraphExecutor {
} else {
result += ret.content;
}
} else if (
ret.type === NodeExecuteState.Attachment &&
!currentNode.hasEdges
) {
// pass through content as a stream response if node is end node
yield {
status: GraphExecutorState.EmitAttachment,
attachment: ret.attachment,
};
}
}