feat: add more workflow executor (#7231)

This commit is contained in:
darkskygit
2024-06-25 10:54:16 +00:00
parent 532a628989
commit cffaf815e1
23 changed files with 946 additions and 324 deletions

View File

@@ -63,11 +63,13 @@
"dotenv": "^16.4.5",
"dotenv-cli": "^7.4.1",
"express": "^4.19.2",
"fast-xml-parser": "^4.4.0",
"get-stream": "^9.0.1",
"graphql": "^16.8.1",
"graphql-scalars": "^1.23.0",
"graphql-type-json": "^0.3.2",
"graphql-upload": "^16.0.2",
"html-validate": "^8.20.1",
"ioredis": "^5.3.2",
"keyv": "^4.5.4",
"lodash-es": "^4.17.21",

View File

@@ -27,7 +27,7 @@ export async function createApp() {
app.use(
graphqlUploadExpress({
// TODO(@darksky): dynamic limit by quota maybe?
// TODO(@darkskygit): dynamic limit by quota maybe?
maxFileSize: 100 * 1024 * 1024,
maxFiles: 5,
})

View File

@@ -484,8 +484,7 @@ content: {{content}}`,
messages: [
{
role: 'system',
content:
"You are a PPT creator. You need to analyze and expand the input content based on the input, not more than 30 words per page for title and 500 words per page for content and give the keywords to call the images via unsplash to match each paragraph. Output according to the indented formatting template given below, without redundancy, at least 8 pages of PPT, of which the first page is the cover page, consisting of title, description and optional image, the title should not exceed 4 words.\nThe following are PPT templates, you can choose any template to apply, page name, column name, title, keywords, content should be removed by text replacement, do not retain. Keywords need to be generic enough for broad, mass categorization. The output ignores template titles like template1 and template2. The first template is allowed to be used only once and as a cover, please strictly follow the template's hierarchical indentation and my requirements, bolding, headings and other formatting (e.g., #, **) are not allowed, or penalties will be applied:\ntemplate1:\n- {page name}\n  - {title}\n    - keywords\n    - {description}\ntemplate2:\n- {page name}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\ntemplate3:\n- {page name}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\ntemplate4:\n- {page name}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\ntemplate5:\n- {page name}\n  - {section name}\n    - keywords\n    - {content}",
content: `You are a PPT creator. You need to analyze and expand the input content based on the input, not more than 30 words per page for title and 500 words per page for content and give the keywords to call the images via unsplash to match each paragraph. Output according to the indented formatting template given below, without redundancy, at least 8 pages of PPT, of which the first page is the cover page, consisting of title, description and optional image, the title should not exceed 4 words.\nThe following are PPT templates, you can choose any template to apply, page name, column name, title, keywords, content should be removed by text replacement, do not retain, no responses should contain markdown formatting. Keywords need to be generic enough for broad, mass categorization. The output ignores template titles like template1 and template2. The first template is allowed to be used only once and as a cover, please strictly follow the template's ND-JSON field, format and my requirements, or penalties will be applied:\n{"page":1,"type":"name","content":"page name"}\n{"page":1,"type":"title","content":"title"}\n{"page":1,"type":"content","content":"keywords"}\n{"page":1,"type":"content","content":"description"}\n{"page":2,"type":"name","content":"page name"}\n{"page":2,"type":"title","content":"section name"}\n{"page":2,"type":"content","content":"keywords"}\n{"page":2,"type":"content","content":"description"}\n{"page":2,"type":"title","content":"section name"}\n{"page":2,"type":"content","content":"keywords"}\n{"page":2,"type":"content","content":"description"}\n{"page":3,"type":"name","content":"page name"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}\n{"page":3,"type":"title","content":"section 
name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}`,
},
{
role: 'assistant',
@@ -498,35 +497,18 @@ content: {{content}}`,
],
},
{
name: 'workflow:presentation:step3',
action: 'workflow:presentation:step3',
name: 'workflow:presentation:step4',
action: 'workflow:presentation:step4',
model: 'gpt-4o',
messages: [
{
role: 'system',
content:
'You are very strict text indentation judgment model, you need to judge the input and output True if it is text that has no problem with indentation, otherwise output False.',
},
{
role: 'user',
content: '{{content}}',
},
],
},
{
name: 'workflow:presentation:step5',
action: 'workflow:presentation:step5',
model: 'gpt-4o',
messages: [
{
role: 'system',
content:
"You are a text indentation format checking model with very strict formatting requirements, and you need to optimize the input so that it fully conforms to the template's indentation format and output.\nPage names, section names, titles, keywords, and content should be removed via text replacement and not retained. The first template is only allowed to be used once and as a cover, please strictly adhere to the template's hierarchical indentation and my requirement that bold, headings, and other formatting (e.g., #, **) are not allowed or penalties will be applied.",
"You are a ND-JSON text format checking model with very strict formatting requirements, and you need to optimize the input so that it fully conforms to the template's indentation format and output.\nPage names, section names, titles, keywords, and content should be removed via text replacement and not retained. The first template is only allowed to be used once and as a cover, please strictly adhere to the template's hierarchical indentation and my requirement that bold, headings, and other formatting (e.g., #, **, ```) are not allowed or penalties will be applied, no responses should contain markdown formatting.",
},
{
role: 'assistant',
content:
"You are a PPT creator. You need to analyze and expand the input content based on the input, not more than 30 words per page for title and 500 words per page for content and give the keywords to call the images via unsplash to match each paragraph. Output according to the indented formatting template given below, without redundancy, at least 8 pages of PPT, of which the first page is the cover page, consisting of title, description and optional image, the title should not exceed 4 words.\nThe following are PPT templates, you can choose any template to apply, page name, column name, title, keywords, content should be removed by text replacement, do not retain. Keywords need to be generic enough for broad, mass categorization. The output ignores template titles like template1 and template2. The first template is allowed to be used only once and as a cover, please strictly follow the template's hierarchical indentation and my requirements, bolding, headings and other formatting (e.g., #, **) are not allowed, or penalties will be applied:\n//template1:\n- {page name}\n  - {title}\n    - keywords\n    - {description}\n//template2:\n- {page name}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n//template3:\n- {page name}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n//template4:\n- {page name}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n  - {section name}\n    - keywords\n    - {content}\n//template5:\n- {page name}\n  - {section name}\n    - keywords\n    - {content}",
content: `You are a PPT creator. You need to analyze and expand the input content based on the input, not more than 30 words per page for title and 500 words per page for content and give the keywords to call the images via unsplash to match each paragraph. Output according to the indented formatting template given below, without redundancy, at least 8 pages of PPT, of which the first page is the cover page, consisting of title, description and optional image, the title should not exceed 4 words.\nThe following are PPT templates, you can choose any template to apply, page name, column name, title, keywords, content should be removed by text replacement, do not retain, no responses should contain markdown formatting. Keywords need to be generic enough for broad, mass categorization. The output ignores template titles like template1 and template2. The first template is allowed to be used only once and as a cover, please strictly follow the template's ND-JSON field, format and my requirements, or penalties will be applied:\n{"page":1,"type":"name","content":"page name"}\n{"page":1,"type":"title","content":"title"}\n{"page":1,"type":"content","content":"keywords"}\n{"page":1,"type":"content","content":"description"}\n{"page":2,"type":"name","content":"page name"}\n{"page":2,"type":"title","content":"section name"}\n{"page":2,"type":"content","content":"keywords"}\n{"page":2,"type":"content","content":"description"}\n{"page":2,"type":"title","content":"section name"}\n{"page":2,"type":"content","content":"keywords"}\n{"page":2,"type":"content","content":"description"}\n{"page":3,"type":"name","content":"page name"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}\n{"page":3,"type":"title","content":"section 
name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}`,
},
{
role: 'user',

View File

@@ -38,10 +38,10 @@ import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotCapability, CopilotTextProvider } from './types';
import { CopilotWorkflowService } from './workflow';
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
export interface ChatEvent {
type: 'attachment' | 'message' | 'error';
type: 'event' | 'attachment' | 'message' | 'error';
id?: string;
data: string | object;
}
@@ -134,6 +134,15 @@ export class CopilotController {
return session;
}
/**
 * Normalize the query params shared by the chat endpoints: extract
 * `messageId` (taking the first value when the key is repeated) and
 * decode the `jsonMode` flag from its string form.
 */
private prepareParams(params: Record<string, string | string[]>) {
// messageId may appear multiple times in the query string; use the first.
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
// Only the string "true" (any casing) enables json mode; anything else,
// including a missing key ("undefined" after String()), disables it.
const jsonMode = String(params.jsonMode).toLowerCase() === 'true';
delete params.messageId;
// NOTE(review): `jsonMode` is deliberately NOT deleted from `params` —
// the OpenAI provider re-reads it from message params (see
// extractOptionFromMessages) — TODO confirm this is intended.
return { messageId, jsonMode, params };
}
private getSignal(req: Request) {
const controller = new AbortController();
req.on('close', () => controller.abort());
@@ -158,9 +167,7 @@ export class CopilotController {
@Param('sessionId') sessionId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { messageId, jsonMode } = this.prepareParams(params);
const provider = await this.chooseTextProvider(
user.id,
sessionId,
@@ -170,14 +177,10 @@ export class CopilotController {
const session = await this.appendSessionMessage(sessionId, messageId);
try {
delete params.messageId;
const content = await provider.generateText(
session.finish(params),
session.model,
{
signal: this.getSignal(req),
user: user.id,
}
{ jsonMode, signal: this.getSignal(req), user: user.id }
);
session.push({
@@ -201,9 +204,7 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { messageId, jsonMode } = this.prepareParams(params);
const provider = await this.chooseTextProvider(
user.id,
sessionId,
@@ -211,10 +212,10 @@ export class CopilotController {
);
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
jsonMode,
signal: this.getSignal(req),
user: user.id,
})
@@ -255,12 +256,8 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { messageId, jsonMode } = this.prepareParams(params);
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const latestMessage = session.stashMessages.findLast(
m => m.role === 'user'
);
@@ -272,6 +269,7 @@ export class CopilotController {
return from(
this.workflow.runGraph(params, session.model, {
jsonMode,
signal: this.getSignal(req),
user: user.id,
})
@@ -280,7 +278,23 @@ export class CopilotController {
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: messageId, data }))
map(data =>
data.status === GraphExecutorState.EmitContent
? {
type: 'message' as const,
id: messageId,
data: data.content,
}
: {
type: 'event' as const,
id: messageId,
data: {
status: data.status,
id: data.node.id,
type: data.node.config.nodeType,
} as any,
}
)
),
// save the generated text to the session
shared$.pipe(
@@ -312,9 +326,7 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { messageId } = this.prepareParams(params);
const { model, hasAttachment } = await this.checkRequest(
user.id,
sessionId,
@@ -331,7 +343,6 @@ export class CopilotController {
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,

View File

@@ -120,19 +120,37 @@ export class OpenAIProvider
});
}
/**
 * Scan message-level params for a `jsonMode` override and promote it onto
 * the chat options. NOTE: mutates the `options` argument in place.
 */
private extractOptionFromMessages(
messages: PromptMessage[],
options: CopilotChatOptions
) {
// Merge params from all messages; later messages overwrite earlier keys.
const params: Record<string, string | string[]> = {};
for (const message of messages) {
if (message.params) {
Object.assign(params, message.params);
}
}
// Only the string "true" (any casing) enables json mode.
if (params.jsonMode && options) {
options.jsonMode = String(params.jsonMode).toLowerCase() === 'true';
}
}
protected checkParams({
messages,
embeddings,
model,
options = {},
}: {
messages?: PromptMessage[];
embeddings?: string[];
model: string;
options: CopilotChatOptions;
}) {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (Array.isArray(messages) && messages.length > 0) {
this.extractOptionFromMessages(messages, options);
if (
messages.some(
m =>
@@ -158,6 +176,14 @@ export class OpenAIProvider
) {
throw new Error('Invalid message role');
}
// json mode needs the 'json' keyword in the message content
// ref: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
if (
options.jsonMode &&
!messages.some(m => m.content.toLowerCase().includes('json'))
) {
throw new Error('Prompt not support json mode');
}
} else if (
Array.isArray(embeddings) &&
embeddings.some(e => typeof e !== 'string' || !e || !e.trim())
@@ -173,13 +199,16 @@ export class OpenAIProvider
model: string = 'gpt-3.5-turbo',
options: CopilotChatOptions = {}
): Promise<string> {
this.checkParams({ messages, model });
this.checkParams({ messages, model, options });
const result = await this.instance.chat.completions.create(
{
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',
},
user: options.user,
},
{ signal: options.signal }
@@ -196,7 +225,7 @@ export class OpenAIProvider
model: string = 'gpt-3.5-turbo',
options: CopilotChatOptions = {}
): AsyncIterable<string> {
this.checkParams({ messages, model });
this.checkParams({ messages, model, options });
const result = await this.instance.chat.completions.create(
{
stream: true,
@@ -204,6 +233,9 @@ export class OpenAIProvider
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',
},
user: options.user,
},
{
@@ -231,7 +263,7 @@ export class OpenAIProvider
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model });
this.checkParams({ embeddings: messages, model, options });
const result = await this.instance.embeddings.create({
model: model,

View File

@@ -137,6 +137,7 @@ const CopilotProviderOptionsSchema = z.object({
});
const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
jsonMode: z.boolean().optional(),
temperature: z.number().optional(),
maxTokens: z.number().optional(),
}).optional();

View File

@@ -0,0 +1,91 @@
import { Injectable } from '@nestjs/common';
import { ChatPrompt, PromptService } from '../../prompt';
import { CopilotProviderService } from '../../providers';
import { CopilotChatOptions, CopilotImageProvider } from '../../types';
import { WorkflowNodeData, WorkflowNodeType } from '../types';
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
export class CopilotChatImageExecutor extends AutoRegisteredWorkflowExecutor {
  constructor(
    private readonly promptService: PromptService,
    private readonly providerService: CopilotProviderService
  ) {
    super();
  }

  override get type() {
    return NodeExecutorType.ChatImage;
  }

  /**
   * Resolve everything this executor needs from the node data: the node
   * itself (narrowed to a Basic node), its prompt, and an image-capable
   * provider for the prompt's model. Throws when any piece is missing.
   */
  private async initExecutor(
    data: WorkflowNodeData
  ): Promise<
    [
      WorkflowNodeData & { nodeType: WorkflowNodeType.Basic },
      ChatPrompt,
      CopilotImageProvider,
    ]
  > {
    if (data.nodeType !== WorkflowNodeType.Basic) {
      throw new Error(
        `Executor ${this.type} not support ${data.nodeType} node`
      );
    }
    if (!data.promptName) {
      throw new Error(
        `Prompt name not found when running workflow node ${data.name}`
      );
    }
    const prompt = await this.promptService.get(data.promptName);
    if (!prompt) {
      throw new Error(
        `Prompt ${data.promptName} not found when running workflow node ${data.name}`
      );
    }
    const provider = await this.providerService.getProviderByModel(
      prompt.model
    );
    // The `in` check narrows the provider to an image-capable one.
    if (provider && 'generateImages' in provider) {
      return [data, prompt, provider];
    }
    throw new Error(
      `Provider not found for model ${prompt.model} when running workflow node ${data.name}`
    );
  }

  /**
   * Run the node: render the prompt with `params`, then either stash the
   * generated result under `paramKey` (when set) or stream each generated
   * result out as node content.
   */
  override async *next(
    data: WorkflowNodeData,
    params: Record<string, string>,
    options?: CopilotChatOptions
  ): AsyncIterable<NodeExecuteResult> {
    const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
    const message = prompt.finish(params);
    if (!paramKey) {
      // No custom key: emit each streamed result as node content.
      for await (const content of provider.generateImagesStream(
        message,
        prompt.model,
        options
      )) {
        yield { type: NodeExecuteState.Content, nodeId: id, content };
      }
      return;
    }
    // Custom key set: update workflow params with the full result instead.
    const images = await provider.generateImages(
      message,
      prompt.model,
      options
    );
    yield {
      type: NodeExecuteState.Params,
      params: { [paramKey]: images },
    };
  }
}

View File

@@ -3,13 +3,8 @@ import { Injectable } from '@nestjs/common';
import { ChatPrompt, PromptService } from '../../prompt';
import { CopilotProviderService } from '../../providers';
import { CopilotChatOptions, CopilotTextProvider } from '../../types';
import {
NodeData,
WorkflowNodeType,
WorkflowResult,
WorkflowResultType,
} from '../types';
import { WorkflowExecutorType } from './types';
import { WorkflowNodeData, WorkflowNodeType } from '../types';
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
@@ -22,10 +17,10 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
}
private async initExecutor(
data: NodeData
data: WorkflowNodeData
): Promise<
[
NodeData & { nodeType: WorkflowNodeType.Basic },
WorkflowNodeData & { nodeType: WorkflowNodeType.Basic },
ChatPrompt,
CopilotTextProvider,
]
@@ -36,6 +31,11 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
);
}
if (!data.promptName) {
throw new Error(
`Prompt name not found when running workflow node ${data.name}`
);
}
const prompt = await this.promptService.get(data.promptName);
if (!prompt) {
throw new Error(
@@ -55,21 +55,21 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
}
override get type() {
return WorkflowExecutorType.ChatText;
return NodeExecutorType.ChatText;
}
override async *next(
data: NodeData,
data: WorkflowNodeData,
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult> {
): AsyncIterable<NodeExecuteResult> {
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
const finalMessage = prompt.finish(params);
if (paramKey) {
// update params with custom key
yield {
type: WorkflowResultType.Params,
type: NodeExecuteState.Params,
params: {
[paramKey]: await provider.generateText(
finalMessage,
@@ -84,11 +84,7 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: id,
content,
};
yield { type: NodeExecuteState.Content, nodeId: id, content };
}
}
}

View File

@@ -0,0 +1,64 @@
import { Injectable } from '@nestjs/common';
import { XMLValidator } from 'fast-xml-parser';
import { HtmlValidate } from 'html-validate/node';
import { WorkflowNodeData, WorkflowNodeType, WorkflowParams } from '../types';
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
export class CopilotCheckHtmlExecutor extends AutoRegisteredWorkflowExecutor {
  // Shared validator instance, reused across runs.
  private readonly html = new HtmlValidate();

  /** Ensure this executor only runs on Basic workflow nodes. */
  private async initExecutor(
    data: WorkflowNodeData
  ): Promise<WorkflowNodeData & { nodeType: WorkflowNodeType.Basic }> {
    if (data.nodeType !== WorkflowNodeType.Basic) {
      throw new Error(
        `Executor ${this.type} not support ${data.nodeType} node`
      );
    }
    return data;
  }

  override get type() {
    return NodeExecutorType.CheckHtml;
  }

  /**
   * Validate `content` as markup. A fast XML well-formedness check runs
   * first; when `strict` is set, the content must additionally pass the
   * html-validate "standard" ruleset. Non-string content and any validator
   * error yield `false` rather than propagating.
   */
  private async checkHtml(
    content?: string | string[],
    strict?: boolean
  ): Promise<boolean> {
    try {
      if (!content || typeof content !== 'string') {
        return false;
      }
      if (XMLValidator.validate(content) !== true) {
        return false;
      }
      if (strict) {
        const report = await this.html.validateString(content, {
          extends: ['html-validate:standard'],
        });
        return report.valid;
      }
      return true;
    } catch {
      // Best-effort check: treat validator failures as invalid input.
      return false;
    }
  }

  /**
   * Run the node: check `params.content` and emit the stringified boolean
   * result, either under `paramKey` (when set) or as node content.
   */
  override async *next(
    data: WorkflowNodeData,
    params: WorkflowParams
  ): AsyncIterable<NodeExecuteResult> {
    const { paramKey, id } = await this.initExecutor(data);
    // NOTE(review): `!!params.strict` is truthy for ANY non-empty string,
    // including "false" — confirm callers only set it when strict mode is
    // really wanted.
    const ret = String(await this.checkHtml(params.content, !!params.strict));
    if (paramKey) {
      yield { type: NodeExecuteState.Params, params: { [paramKey]: ret } };
    } else {
      yield { type: NodeExecuteState.Content, nodeId: id, content: ret };
    }
  }
}

View File

@@ -0,0 +1,53 @@
import { Injectable } from '@nestjs/common';
import { WorkflowNodeData, WorkflowNodeType, WorkflowParams } from '../types';
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
export class CopilotCheckJsonExecutor extends AutoRegisteredWorkflowExecutor {
  /** Ensure this executor only runs on Basic workflow nodes. */
  private async initExecutor(
    data: WorkflowNodeData
  ): Promise<WorkflowNodeData & { nodeType: WorkflowNodeType.Basic }> {
    if (data.nodeType !== WorkflowNodeType.Basic) {
      throw new Error(
        `Executor ${this.type} not support ${data.nodeType} node`
      );
    }
    return data;
  }

  override get type() {
    return NodeExecutorType.CheckJson;
  }

  /**
   * Returns true only when `content` is a single string containing valid
   * JSON. String arrays are rejected; parse failures yield false.
   */
  private checkJson(content?: string | string[]): boolean {
    try {
      if (content && typeof content === 'string') {
        JSON.parse(content);
        return true;
      }
      return false;
    } catch {
      return false;
    }
  }

  /**
   * Run the node: check `params.content` and emit the stringified boolean
   * result, either under `paramKey` (when set) or as node content.
   */
  override async *next(
    data: WorkflowNodeData,
    params: WorkflowParams
  ): AsyncIterable<NodeExecuteResult> {
    const { paramKey, id } = await this.initExecutor(data);
    const ret = String(this.checkJson(params.content));
    if (paramKey) {
      yield { type: NodeExecuteState.Params, params: { [paramKey]: ret } };
    } else {
      yield { type: NodeExecuteState.Content, nodeId: id, content: ret };
    }
  }
}

View File

@@ -1,7 +1,21 @@
import { CopilotChatImageExecutor } from './chat-image';
import { CopilotChatTextExecutor } from './chat-text';
import { CopilotCheckHtmlExecutor } from './check-html';
import { CopilotCheckJsonExecutor } from './check-json';
export const CopilotWorkflowExecutors = [CopilotChatTextExecutor];
export const CopilotWorkflowExecutors = [
CopilotChatImageExecutor,
CopilotChatTextExecutor,
CopilotCheckHtmlExecutor,
CopilotCheckJsonExecutor,
];
export { type WorkflowExecutor, WorkflowExecutorType } from './types';
export type { NodeExecuteResult, NodeExecutor } from './types';
export { NodeExecuteState, NodeExecutorType } from './types';
export { getWorkflowExecutor } from './utils';
export { CopilotChatTextExecutor };
export {
CopilotChatImageExecutor,
CopilotChatTextExecutor,
CopilotCheckHtmlExecutor,
CopilotCheckJsonExecutor,
};

View File

@@ -1,15 +1,32 @@
import { CopilotChatOptions } from '../../types';
import { NodeData, WorkflowResult } from '../types';
import type { WorkflowNode } from '../node';
import { WorkflowNodeData, WorkflowParams } from '../types';
export enum WorkflowExecutorType {
export enum NodeExecutorType {
ChatText = 'ChatText',
ChatImage = 'ChatImage',
CheckJson = 'CheckJson',
CheckHtml = 'CheckHtml',
}
export abstract class WorkflowExecutor {
abstract get type(): WorkflowExecutorType;
abstract next(
data: NodeData,
params: Record<string, string | string[]>,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult>;
export enum NodeExecuteState {
StartRun,
EndRun,
Params,
Content,
}
export type NodeExecuteResult =
| { type: NodeExecuteState.StartRun; nodeId: string }
| { type: NodeExecuteState.EndRun; nextNode?: WorkflowNode }
| { type: NodeExecuteState.Params; params: WorkflowParams }
| { type: NodeExecuteState.Content; nodeId: string; content: string };
export abstract class NodeExecutor {
abstract get type(): NodeExecutorType;
abstract next(
data: WorkflowNodeData,
params: WorkflowParams,
options?: CopilotChatOptions
): AsyncIterable<NodeExecuteResult>;
}

View File

@@ -1,19 +1,17 @@
import { Logger, OnModuleInit } from '@nestjs/common';
import { WorkflowExecutor, type WorkflowExecutorType } from './types';
import { NodeExecutor, type NodeExecutorType } from './types';
const WORKFLOW_EXECUTOR: Map<string, WorkflowExecutor> = new Map();
const WORKFLOW_EXECUTOR: Map<string, NodeExecutor> = new Map();
function registerWorkflowExecutor(e: WorkflowExecutor) {
function registerWorkflowExecutor(e: NodeExecutor) {
const existing = WORKFLOW_EXECUTOR.get(e.type);
if (existing && existing === e) return false;
WORKFLOW_EXECUTOR.set(e.type, e);
return true;
}
export function getWorkflowExecutor(
type: WorkflowExecutorType
): WorkflowExecutor {
export function getWorkflowExecutor(type: NodeExecutorType): NodeExecutor {
const executor = WORKFLOW_EXECUTOR.get(type);
if (!executor) {
throw new Error(`Executor ${type} not defined`);
@@ -23,7 +21,7 @@ export function getWorkflowExecutor(
}
export abstract class AutoRegisteredWorkflowExecutor
extends WorkflowExecutor
extends NodeExecutor
implements OnModuleInit
{
onModuleInit() {

View File

@@ -1,4 +1,4 @@
import { WorkflowExecutorType } from './executor';
import { NodeExecutorType } from './executor';
import type { WorkflowGraphs } from './types';
import { WorkflowNodeState, WorkflowNodeType } from './types';
@@ -10,7 +10,7 @@ export const WorkflowGraphList: WorkflowGraphs = [
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step1',
paramKey: 'language',
edges: ['step2'],
@@ -19,38 +19,44 @@ export const WorkflowGraphList: WorkflowGraphs = [
id: 'step2',
name: 'Step 2: generate presentation',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step2',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step 3: check format',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step3',
paramKey: 'needFormat',
edges: ['step4'],
name: 'Step 3: format presentation if needed',
nodeType: WorkflowNodeType.Decision,
condition: (nodeIds: string[], params: WorkflowNodeState) => {
const lines = params.content?.split('\n') || [];
return nodeIds[
Number(
!lines.some(line => {
try {
if (line.trim()) {
JSON.parse(line);
}
return false;
} catch {
return true;
}
})
)
];
},
edges: ['step4', 'step5'],
},
{
id: 'step4',
name: 'Step 4: format presentation if needed',
nodeType: WorkflowNodeType.Decision,
condition: (nodeIds: string[], params: WorkflowNodeState) =>
nodeIds[Number(String(params.needFormat).toLowerCase() !== 'true')],
edges: ['step5', 'step6'],
name: 'Step 4: format presentation',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step4',
edges: ['step5'],
},
{
id: 'step5',
name: 'Step 5: format presentation',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step5',
edges: ['step6'],
},
{
id: 'step6',
name: 'Step 6: finish',
name: 'Step 5: finish',
nodeType: WorkflowNodeType.Nope,
edges: [],
},

View File

@@ -1,72 +1,8 @@
import { Injectable, Logger } from '@nestjs/common';
import { CopilotChatOptions } from '../types';
import { WorkflowGraphList } from './graph';
import { WorkflowNode } from './node';
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
import { CopilotWorkflow } from './workflow';
@Injectable()
export class CopilotWorkflowService {
private readonly logger = new Logger(CopilotWorkflowService.name);
constructor() {}
private initWorkflow(graph: WorkflowGraph) {
const workflow = new Map<string, WorkflowNode>();
for (const nodeData of graph.graph) {
const { edges: _, ...data } = nodeData;
const node = new WorkflowNode(graph, data);
workflow.set(node.id, node);
}
// add edges
for (const nodeData of graph.graph) {
const node = workflow.get(nodeData.id);
if (!node) {
this.logger.error(
`Failed to init workflow ${name}: node ${nodeData.id} not found`
);
throw new Error(`Node ${nodeData.id} not found`);
}
for (const edgeId of nodeData.edges) {
const edge = workflow.get(edgeId);
if (!edge) {
this.logger.error(
`Failed to init workflow ${name}: edge ${edgeId} not found in node ${nodeData.id}`
);
throw new Error(`Edge ${edgeId} not found`);
}
node.addEdge(edge);
}
}
return workflow;
}
// TODO(@darkskygit): get workflow from database
private async getWorkflow(
graphName: string
): Promise<WorkflowGraphInstances> {
const graph = WorkflowGraphList.find(g => g.name === graphName);
if (!graph) {
throw new Error(`Graph ${graphName} not found`);
}
return this.initWorkflow(graph);
}
async *runGraph(
params: Record<string, string>,
graphName: string,
options?: CopilotChatOptions
): AsyncIterable<string> {
const workflowGraph = await this.getWorkflow(graphName);
const workflow = new CopilotWorkflow(workflowGraph);
for await (const result of workflow.runGraph(params, options)) {
yield result;
}
}
}
export { CopilotChatTextExecutor, CopilotWorkflowExecutors } from './executor';
export { WorkflowNodeType } from './types';
export { CopilotWorkflowService } from './service';
export {
type WorkflowGraph,
type WorkflowNodeData,
WorkflowNodeType,
} from './types';
export { GraphExecutorState, WorkflowGraphExecutor } from './workflow';

View File

@@ -5,33 +5,33 @@ import { Logger } from '@nestjs/common';
import Piscina from 'piscina';
import { CopilotChatOptions } from '../types';
import { getWorkflowExecutor, WorkflowExecutor } from './executor';
import type { NodeExecuteResult, NodeExecutor } from './executor';
import { getWorkflowExecutor, NodeExecuteState } from './executor';
import type {
NodeData,
WorkflowGraph,
WorkflowNodeData,
WorkflowNodeState,
WorkflowResult,
} from './types';
import { WorkflowNodeType, WorkflowResultType } from './types';
import { WorkflowNodeType } from './types';
export class WorkflowNode {
private readonly logger = new Logger(WorkflowNode.name);
private readonly edges: WorkflowNode[] = [];
private readonly parents: WorkflowNode[] = [];
private readonly executor: WorkflowExecutor | null = null;
private readonly executor: NodeExecutor | null = null;
private readonly condition:
| ((params: WorkflowNodeState) => Promise<any>)
| null = null;
constructor(
graph: WorkflowGraph,
private readonly data: NodeData
private readonly data: WorkflowNodeData
) {
if (data.nodeType === WorkflowNodeType.Basic) {
this.executor = getWorkflowExecutor(data.type);
} else if (data.nodeType === WorkflowNodeType.Decision) {
// prepare decision condition, reused in each run
const iife = `(${data.condition})(nodeIds, params)`;
const iife = `return (${data.condition})(nodeIds, params)`;
// only eval the condition in worker if graph has been modified
if (graph.modified) {
const worker = new Piscina({
@@ -55,11 +55,7 @@ export class WorkflowNode {
const func =
typeof data.condition === 'function'
? data.condition
: new Function(
'nodeIds',
'params',
`(${data.condition})(nodeIds, params)`
);
: new Function('nodeIds', 'params', iife);
this.condition = (params: WorkflowNodeState) =>
func(
this.edges.map(node => node.id),
@@ -77,7 +73,7 @@ export class WorkflowNode {
return this.data.name;
}
get config(): NodeData {
get config(): WorkflowNodeData {
return Object.assign({}, this.data);
}
@@ -106,6 +102,8 @@ export class WorkflowNode {
!this.data.condition
) {
throw new Error(`Decision block must have a condition`);
} else if (this.data.nodeType === WorkflowNodeType.Nope) {
throw new Error(`Nope block cannot have edges`);
}
node.parent = this;
this.edges.push(node);
@@ -133,8 +131,8 @@ export class WorkflowNode {
async *next(
params: WorkflowNodeState,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult> {
yield { type: WorkflowResultType.StartRun, nodeId: this.id };
): AsyncIterable<NodeExecuteResult> {
yield { type: NodeExecuteState.StartRun, nodeId: this.id };
// choose next node in graph
let nextNode: WorkflowNode | undefined = this.edges[0];
@@ -155,12 +153,12 @@ export class WorkflowNode {
yield* this.executor.next(this.data, params, options);
} else {
yield {
type: WorkflowResultType.Content,
type: NodeExecuteState.Content,
nodeId: this.id,
content: params.content,
};
}
yield { type: WorkflowResultType.EndRun, nextNode };
yield { type: NodeExecuteState.EndRun, nextNode };
}
}

View File

@@ -0,0 +1,68 @@
import { Injectable, Logger } from '@nestjs/common';
import { CopilotChatOptions } from '../types';
import { WorkflowGraphList } from './graph';
import { WorkflowNode } from './node';
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
import { type GraphExecutorStatus, WorkflowGraphExecutor } from './workflow';
@Injectable()
export class CopilotWorkflowService {
private readonly logger = new Logger(CopilotWorkflowService.name);
initWorkflow(graph: WorkflowGraph) {
const workflow = new Map<string, WorkflowNode>();
for (const nodeData of graph.graph) {
const { edges: _, ...data } = nodeData;
const node = new WorkflowNode(graph, data);
workflow.set(node.id, node);
}
// add edges
for (const nodeData of graph.graph) {
const node = workflow.get(nodeData.id);
if (!node) {
this.logger.error(
`Failed to init workflow ${name}: node ${nodeData.id} not found`
);
throw new Error(`Node ${nodeData.id} not found`);
}
for (const edgeId of nodeData.edges) {
const edge = workflow.get(edgeId);
if (!edge) {
this.logger.error(
`Failed to init workflow ${name}: edge ${edgeId} not found in node ${nodeData.id}`
);
throw new Error(`Edge ${edgeId} not found`);
}
node.addEdge(edge);
}
}
return workflow;
}
// TODO(@darkskygit): get workflow from database
private async getWorkflow(
graphName: string
): Promise<WorkflowGraphInstances> {
const graph = WorkflowGraphList.find(g => g.name === graphName);
if (!graph) {
throw new Error(`Graph ${graphName} not found`);
}
return this.initWorkflow(graph);
}
async *runGraph(
params: Record<string, string>,
graphName: string,
options?: CopilotChatOptions
): AsyncIterable<GraphExecutorStatus> {
const workflowGraph = await this.getWorkflow(graphName);
const executor = new WorkflowGraphExecutor(workflowGraph);
for await (const result of executor.runGraph(params, options)) {
yield result;
}
}
}

View File

@@ -1,17 +1,19 @@
import type { WorkflowExecutorType } from './executor';
import type { NodeExecutorType } from './executor';
import type { WorkflowNode } from './node';
// ===================== node =====================
export enum WorkflowNodeType {
Basic = 'basic',
Decision = 'decision',
Nope = 'nope',
}
export type NodeData = { id: string; name: string } & (
export type WorkflowNodeData = { id: string; name: string } & (
| {
nodeType: WorkflowNodeType.Basic;
promptName: string;
type: WorkflowExecutorType;
type: NodeExecutorType;
promptName?: string;
// update the prompt params by output with the custom key
paramKey?: string;
}
@@ -25,35 +27,22 @@ export type NodeData = { id: string; name: string } & (
| { nodeType: WorkflowNodeType.Nope }
);
export type WorkflowNodeState = Record<string, string>;
export type WorkflowGraphInstances = Map<string, WorkflowNode>;
export type WorkflowGraphData = Array<NodeData & { edges: string[] }>;
// ===================== graph =====================
export type WorkflowGraphDefinition = Array<
WorkflowNodeData & { edges: string[] }
>;
export type WorkflowGraph = {
name: string;
// true if the graph has been modified
modified?: boolean;
graph: WorkflowGraphData;
graph: WorkflowGraphDefinition;
};
export type WorkflowGraphs = Array<WorkflowGraph>;
export enum WorkflowResultType {
StartRun,
EndRun,
Params,
Content,
}
// ===================== executor =====================
export type WorkflowResult =
| { type: WorkflowResultType.StartRun; nodeId: string }
| { type: WorkflowResultType.EndRun; nextNode?: WorkflowNode }
| {
type: WorkflowResultType.Params;
params: Record<string, string | string[]>;
}
| {
type: WorkflowResultType.Content;
nodeId: string;
content: string;
};
export type WorkflowGraphInstances = Map<string, WorkflowNode>;
export type WorkflowParams = Record<string, string | string[]>;
export type WorkflowNodeState = Record<string, string>;

View File

@@ -1,16 +1,25 @@
import { Logger } from '@nestjs/common';
import { CopilotChatOptions } from '../types';
import { NodeExecuteState } from './executor';
import { WorkflowNode } from './node';
import {
type WorkflowGraphInstances,
type WorkflowNodeState,
WorkflowNodeType,
WorkflowResultType,
} from './types';
import type { WorkflowGraphInstances, WorkflowNodeState } from './types';
import { WorkflowNodeType } from './types';
export class CopilotWorkflow {
private readonly logger = new Logger(CopilotWorkflow.name);
export enum GraphExecutorState {
EnterNode = 'EnterNode',
EmitContent = 'EmitContent',
ExitNode = 'ExitNode',
}
export type GraphExecutorStatus = { status: GraphExecutorState } & (
| { status: GraphExecutorState.EnterNode; node: WorkflowNode }
| { status: GraphExecutorState.EmitContent; content: string }
| { status: GraphExecutorState.ExitNode; node: WorkflowNode }
);
export class WorkflowGraphExecutor {
private readonly logger = new Logger(WorkflowGraphExecutor.name);
private readonly rootNode: WorkflowNode;
constructor(workflow: WorkflowGraphInstances) {
@@ -24,7 +33,7 @@ export class CopilotWorkflow {
async *runGraph(
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<string> {
): AsyncIterable<GraphExecutorStatus> {
let currentNode: WorkflowNode | undefined = this.rootNode;
const lastParams: WorkflowNodeState = { ...params };
@@ -33,10 +42,13 @@ export class CopilotWorkflow {
let nextNode: WorkflowNode | undefined;
for await (const ret of currentNode.next(lastParams, options)) {
if (ret.type === WorkflowResultType.EndRun) {
if (ret.type === NodeExecuteState.StartRun) {
yield { status: GraphExecutorState.EnterNode, node: currentNode };
} else if (ret.type === NodeExecuteState.EndRun) {
yield { status: GraphExecutorState.ExitNode, node: currentNode };
nextNode = ret.nextNode;
break;
} else if (ret.type === WorkflowResultType.Params) {
} else if (ret.type === NodeExecuteState.Params) {
Object.assign(lastParams, ret.params);
if (currentNode.config.nodeType === WorkflowNodeType.Basic) {
const { type, promptName } = currentNode.config;
@@ -44,10 +56,13 @@ export class CopilotWorkflow {
`[${currentNode.name}][${type}][${promptName}]: update params - '${JSON.stringify(ret.params)}'`
);
}
} else if (ret.type === WorkflowResultType.Content) {
} else if (ret.type === NodeExecuteState.Content) {
if (!currentNode.hasEdges) {
// pass through content as a stream response if node is end node
yield ret.content;
yield {
status: GraphExecutorState.EmitContent,
content: ret.content,
};
} else {
result += ret.content;
}

View File

@@ -29,6 +29,7 @@ import {
signUp,
} from './utils';
import {
array2sse,
chatWithImages,
chatWithText,
chatWithTextStream,
@@ -37,6 +38,7 @@ import {
createCopilotSession,
getHistories,
MockCopilotTestProvider,
sse2array,
textToEventStream,
} from './utils/copilot';
@@ -227,9 +229,9 @@ test('should be able to chat with api', async t => {
const ret3 = await chatWithImages(app, token, sessionId, messageId);
t.is(
ret3,
array2sse(sse2array(ret3).filter(e => e.event !== 'event')),
textToEventStream(
['https://example.com/test.jpg', 'generate text to text stream'],
['https://example.com/test.jpg', 'hello '],
messageId,
'attachment'
),
@@ -258,7 +260,7 @@ test('should be able to chat with api by workflow', async t => {
);
const ret = await chatWithWorkflow(app, token, sessionId, messageId);
t.is(
ret,
array2sse(sse2array(ret).filter(e => e.event !== 'event')),
textToEventStream(['generate text to text stream'], messageId),
'should be able to chat with workflow'
);

View File

@@ -25,19 +25,24 @@ import {
import {
CopilotChatTextExecutor,
CopilotWorkflowService,
GraphExecutorState,
type WorkflowGraph,
WorkflowGraphExecutor,
type WorkflowNodeData,
WorkflowNodeType,
} from '../src/plugins/copilot/workflow';
import {
CopilotChatImageExecutor,
CopilotCheckHtmlExecutor,
CopilotCheckJsonExecutor,
getWorkflowExecutor,
WorkflowExecutorType,
NodeExecuteState,
NodeExecutorType,
} from '../src/plugins/copilot/workflow/executor';
import { AutoRegisteredWorkflowExecutor } from '../src/plugins/copilot/workflow/executor/utils';
import { WorkflowGraphList } from '../src/plugins/copilot/workflow/graph';
import {
NodeData,
WorkflowResultType,
} from '../src/plugins/copilot/workflow/types';
import { createTestingModule } from './utils';
import { MockCopilotTestProvider } from './utils/copilot';
import { MockCopilotTestProvider, WorkflowTestCases } from './utils/copilot';
const test = ava as TestFn<{
auth: AuthService;
@@ -46,7 +51,12 @@ const test = ava as TestFn<{
provider: CopilotProviderService;
session: ChatSessionService;
workflow: CopilotWorkflowService;
textWorkflowExecutor: CopilotChatTextExecutor;
executors: {
image: CopilotChatImageExecutor;
text: CopilotChatTextExecutor;
html: CopilotCheckHtmlExecutor;
json: CopilotCheckJsonExecutor;
};
}>;
test.beforeEach(async t => {
@@ -74,7 +84,6 @@ test.beforeEach(async t => {
const provider = module.get(CopilotProviderService);
const session = module.get(ChatSessionService);
const workflow = module.get(CopilotWorkflowService);
const textWorkflowExecutor = module.get(CopilotChatTextExecutor);
t.context.module = module;
t.context.auth = auth;
@@ -82,7 +91,12 @@ test.beforeEach(async t => {
t.context.provider = provider;
t.context.session = session;
t.context.workflow = workflow;
t.context.textWorkflowExecutor = textWorkflowExecutor;
t.context.executors = {
image: module.get(CopilotChatImageExecutor),
text: module.get(CopilotChatTextExecutor),
html: module.get(CopilotCheckHtmlExecutor),
json: module.get(CopilotCheckJsonExecutor),
};
});
test.afterEach.always(async t => {
@@ -563,9 +577,9 @@ test('should be able to register test provider', async t => {
// this test used to preview the final result of the workflow
// for the functional test of the API itself, refer to the follow tests
test.skip('should be able to preview workflow', async t => {
const { prompt, workflow, textWorkflowExecutor } = t.context;
const { prompt, workflow, executors } = t.context;
textWorkflowExecutor.register();
executors.text.register();
registerCopilotProvider(OpenAIProvider);
for (const p of prompts) {
@@ -577,8 +591,14 @@ test.skip('should be able to preview workflow', async t => {
{ content: 'apple company' },
'presentation'
)) {
result += ret;
console.log('stream result:', ret);
if (ret.status === GraphExecutorState.EnterNode) {
console.log('enter node:', ret.node.name);
} else if (ret.status === GraphExecutorState.ExitNode) {
console.log('exit node:', ret.node.name);
} else {
result += ret.content;
// console.log('stream result:', ret);
}
}
console.log('final stream result:', result);
t.truthy(result, 'should return result');
@@ -586,14 +606,78 @@ test.skip('should be able to preview workflow', async t => {
unregisterCopilotProvider(OpenAIProvider.type);
});
test('should be able to run workflow', async t => {
const { prompt, workflow, textWorkflowExecutor } = t.context;
const runWorkflow = async function* runWorkflow(
workflowService: CopilotWorkflowService,
graph: WorkflowGraph,
params: Record<string, string>
) {
const instance = workflowService.initWorkflow(graph);
const workflow = new WorkflowGraphExecutor(instance);
for await (const result of workflow.runGraph(params)) {
yield result;
}
};
textWorkflowExecutor.register();
test('should be able to run pre defined workflow', async t => {
const { prompt, workflow, executors } = t.context;
executors.text.register();
executors.html.register();
executors.json.register();
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
const executor = Sinon.spy(textWorkflowExecutor, 'next');
const executor = Sinon.spy(executors.text, 'next');
for (const testCase of WorkflowTestCases) {
const { graph, prompts, callCount, input, params, result } = testCase;
console.log('running workflow test:', graph.name);
for (const p of prompts) {
await prompt.set(p.name, p.model, p.messages);
}
for (const [idx, i] of input.entries()) {
let content: string | undefined = undefined;
const param: any = Object.assign({ content: i }, params[idx]);
for await (const ret of runWorkflow(workflow, graph!, param)) {
if (ret.status === GraphExecutorState.EmitContent) {
if (!content) content = '';
content += ret.content;
}
}
t.is(
content,
result[idx],
`workflow ${graph.name} should generate correct text: ${result[idx]}`
);
t.is(
executor.callCount,
callCount[idx],
`should call executor ${callCount} times`
);
// check run order
for (const [idx, node] of graph!.graph
.filter(g => g.nodeType === WorkflowNodeType.Basic)
.entries()) {
const params = executor.getCall(idx);
t.is(params.args[0].id, node.id, 'graph id should correct');
}
}
}
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});
test('should be able to run workflow', async t => {
const { prompt, workflow, executors } = t.context;
executors.text.register();
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
const executor = Sinon.spy(executors.text, 'next');
for (const p of prompts) {
await prompt.set(p.name, p.model, p.messages);
@@ -603,19 +687,21 @@ test('should be able to run workflow', async t => {
const graph = WorkflowGraphList.find(g => g.name === graphName);
t.truthy(graph, `graph ${graphName} not defined`);
// todo: use Array.fromAsync
// TODO(@darkskygit): use Array.fromAsync
let result = '';
for await (const ret of workflow.runGraph(
{ content: 'apple company' },
graphName
)) {
result += ret;
if (ret.status === GraphExecutorState.EmitContent) {
result += ret;
}
}
t.assert(result, 'generate text to text stream');
// presentation workflow has condition node, it will always false
// so the latest 2 nodes will not be executed
const callCount = graph!.graph.length - 3;
const callCount = graph!.graph.length - 2;
t.is(
executor.callCount,
callCount,
@@ -627,20 +713,18 @@ test('should be able to run workflow', async t => {
.entries()) {
const params = executor.getCall(idx);
if (idx < callCount) {
t.is(params.args[0].id, node.id, 'graph id should correct');
t.is(params.args[0].id, node.id, 'graph id should correct');
t.is(
params.args[1].content,
'generate text to text stream',
'graph params should correct'
);
t.is(
params.args[1].language,
'generate text to text',
'graph params should correct'
);
}
t.is(
params.args[1].content,
'generate text to text stream',
'graph params should correct'
);
t.is(
params.args[1].language,
'generate text to text',
'graph params should correct'
);
}
unregisterCopilotProvider(MockCopilotTestProvider.type);
@@ -658,29 +742,33 @@ const wrapAsyncIter = async <T>(iter: AsyncIterable<T>) => {
};
test('should be able to run executor', async t => {
const { textWorkflowExecutor } = t.context;
const { executors } = t.context;
textWorkflowExecutor.register();
const executor = getWorkflowExecutor(textWorkflowExecutor.type);
t.is(executor.type, textWorkflowExecutor.type, 'should get executor');
const assertExecutor = async (proto: AutoRegisteredWorkflowExecutor) => {
proto.register();
const executor = getWorkflowExecutor(proto.type);
t.is(executor.type, proto.type, 'should get executor');
await t.throwsAsync(
wrapAsyncIter(
executor.next(
{ id: 'nope', name: 'nope', nodeType: WorkflowNodeType.Nope },
{}
)
),
{ instanceOf: Error },
'should throw error if run non basic node'
);
};
await t.throwsAsync(
wrapAsyncIter(
executor.next(
{ id: 'nope', name: 'nope', nodeType: WorkflowNodeType.Nope },
{}
)
),
{ instanceOf: Error },
'should throw error if run non basic node'
);
await assertExecutor(executors.image);
await assertExecutor(executors.text);
});
test('should be able to run text executor', async t => {
const { textWorkflowExecutor, provider, prompt } = t.context;
const { executors, provider, prompt } = t.context;
textWorkflowExecutor.register();
const executor = getWorkflowExecutor(textWorkflowExecutor.type);
executors.text.register();
const executor = getWorkflowExecutor(executors.text.type);
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set('test', 'test', [
@@ -692,12 +780,12 @@ test('should be able to run text executor', async t => {
const text = Sinon.spy(testProvider, 'generateText');
const textStream = Sinon.spy(testProvider, 'generateTextStream');
const nodeData: NodeData = {
const nodeData: WorkflowNodeData = {
id: 'basic',
name: 'basic',
nodeType: WorkflowNodeType.Basic,
promptName: 'test',
type: WorkflowExecutorType.ChatText,
type: NodeExecutorType.ChatText,
};
// text
@@ -708,7 +796,7 @@ test('should be able to run text executor', async t => {
t.deepEqual(ret, [
{
type: WorkflowResultType.Params,
type: NodeExecuteState.Params,
params: { key: 'generate text to text' },
},
]);
@@ -732,7 +820,7 @@ test('should be able to run text executor', async t => {
Array.from('generate text to text stream').map(t => ({
content: t,
nodeId: 'basic',
type: WorkflowResultType.Content,
type: NodeExecuteState.Content,
}))
);
t.deepEqual(
@@ -746,3 +834,84 @@ test('should be able to run text executor', async t => {
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});
test('should be able to run image executor', async t => {
const { executors, provider, prompt } = t.context;
executors.image.register();
const executor = getWorkflowExecutor(executors.image.type);
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set('test', 'test', [
{ role: 'user', content: 'tag1, tag2, tag3, {{#tags}}{{.}}, {{/tags}}' },
]);
// mock provider
const testProvider =
(await provider.getProviderByModel<CopilotCapability.TextToImage>('test'))!;
const image = Sinon.spy(testProvider, 'generateImages');
const imageStream = Sinon.spy(testProvider, 'generateImagesStream');
const nodeData: WorkflowNodeData = {
id: 'basic',
name: 'basic',
nodeType: WorkflowNodeType.Basic,
promptName: 'test',
type: NodeExecutorType.ChatText,
};
// image
{
const ret = await wrapAsyncIter(
executor.next(
{ ...nodeData, paramKey: 'key' },
{ tags: ['tag4', 'tag5'] }
)
);
t.deepEqual(ret, [
{
type: NodeExecuteState.Params,
params: {
key: [
'https://example.com/test.jpg',
'tag1, tag2, tag3, tag4, tag5, ',
],
},
},
]);
t.deepEqual(
image.lastCall.args[0][0].content,
'tag1, tag2, tag3, tag4, tag5, ',
'should render the prompt with params array'
);
}
// image stream with attachment
{
const ret = await wrapAsyncIter(
executor.next(nodeData, {
attachments: ['https://affine.pro/example.jpg'],
})
);
t.deepEqual(
ret,
Array.from(['https://example.com/test.jpg', 'tag1, tag2, tag3, ']).map(
t => ({
content: t,
nodeId: 'basic',
type: NodeExecuteState.Content,
})
)
);
t.deepEqual(
imageStream.lastCall.args[0][0].params?.attachments,
['https://affine.pro/example.jpg'],
'should pass attachments to provider'
);
}
Sinon.restore();
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});

View File

@@ -9,6 +9,8 @@ import {
} from '../../src/plugins/copilot/providers/openai';
import {
CopilotCapability,
CopilotChatOptions,
CopilotEmbeddingOptions,
CopilotImageToImageProvider,
CopilotImageToTextProvider,
CopilotProviderType,
@@ -17,6 +19,12 @@ import {
CopilotTextToTextProvider,
PromptMessage,
} from '../../src/plugins/copilot/types';
import { NodeExecutorType } from '../../src/plugins/copilot/workflow/executor';
import {
WorkflowGraph,
WorkflowNodeType,
WorkflowParams,
} from '../../src/plugins/copilot/workflow/types';
import { gql } from './common';
import { handleGraphQLError } from './utils';
@@ -72,28 +80,18 @@ export class MockCopilotTestProvider
override async generateText(
messages: PromptMessage[],
model: string = 'test',
_options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
options: CopilotChatOptions = {}
): Promise<string> {
this.checkParams({ messages, model });
this.checkParams({ messages, model, options });
return 'generate text to text';
}
override async *generateTextStream(
messages: PromptMessage[],
model: string = 'gpt-3.5-turbo',
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
options: CopilotChatOptions = {}
): AsyncIterable<string> {
this.checkParams({ messages, model });
this.checkParams({ messages, model, options });
const result = 'generate text to text stream';
for await (const message of result) {
@@ -109,14 +107,10 @@ export class MockCopilotTestProvider
override async generateEmbedding(
messages: string | string[],
model: string,
options: {
dimensions: number;
signal?: AbortSignal;
user?: string;
} = { dimensions: DEFAULT_DIMENSIONS }
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model });
this.checkParams({ embeddings: messages, model, options });
return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
}
@@ -130,7 +124,7 @@ export class MockCopilotTestProvider
user?: string;
} = {}
): Promise<Array<string>> {
const { content: prompt } = messages.pop() || {};
const { content: prompt } = messages[0] || {};
if (!prompt) {
throw new Error('Prompt is required');
}
@@ -253,6 +247,32 @@ export async function chatWithImages(
return chatWithText(app, userToken, sessionId, messageId, '/images');
}
export function sse2array(eventSource: string) {
const blocks = eventSource.replace(/^\n(.*?)\n$/, '$1').split(/\n\n+/);
return blocks.map(block =>
block.split('\n').reduce(
(prev, curr) => {
const [key, ...values] = curr.split(': ');
return Object.assign(prev, { [key]: values.join(': ') });
},
{} as Record<string, string>
)
);
}
export function array2sse(blocks: Record<string, string>[]) {
return blocks
.map(
e =>
'\n' +
Object.entries(e)
.filter(([k]) => !!k)
.map(([k, v]) => `${k}: ${v}`)
.join('\n')
)
.join('\n');
}
export function textToEventStream(
content: string | string[],
id: string,
@@ -331,3 +351,103 @@ export async function getHistories(
return res.body.data.currentUser?.copilot?.histories || [];
}
type Prompt = { name: string; model: string; messages: PromptMessage[] };
type WorkflowTestCase = {
graph: WorkflowGraph;
prompts: Prompt[];
callCount: number[];
input: string[];
params: WorkflowParams[];
result: (string | undefined)[];
};
export const WorkflowTestCases: WorkflowTestCase[] = [
{
prompts: [
{
name: 'test1',
model: 'test',
messages: [{ role: 'user', content: '{{content}}' }],
},
],
graph: {
name: 'test chat text node',
graph: [
{
id: 'start',
name: 'test chat text node',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'test1',
edges: [],
},
],
},
callCount: [1],
input: ['test'],
params: [],
result: ['generate text to text stream'],
},
{
prompts: [],
graph: {
name: 'test check json node',
graph: [
{
id: 'start',
name: 'basic node',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.CheckJson,
edges: [],
},
],
},
callCount: [1, 1],
input: ['{"test": "true"}', '{"test": '],
params: [],
result: ['true', 'false'],
},
{
prompts: [],
graph: {
name: 'test check html node',
graph: [
{
id: 'start',
name: 'basic node',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.CheckHtml,
edges: [],
},
],
},
callCount: [1, 1, 1, 1],
params: [{}, { strict: 'true' }, {}, {}],
input: [
'<html><span /></html>',
'<html><span /></html>',
'<img src="http://123.com/1.jpg" />',
'{"test": "true"}',
],
result: ['true', 'false', 'true', 'false'],
},
{
prompts: [],
graph: {
name: 'test nope node',
graph: [
{
id: 'start',
name: 'nope node',
nodeType: WorkflowNodeType.Nope,
edges: [],
},
],
},
callCount: [1],
input: ['test'],
params: [],
result: ['test'],
},
];