mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 20:38:52 +00:00
feat: add more workflow executor (#7231)
This commit is contained in:
@@ -38,10 +38,10 @@ import { CopilotProviderService } from './providers';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { CopilotCapability, CopilotTextProvider } from './types';
|
||||
import { CopilotWorkflowService } from './workflow';
|
||||
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
|
||||
|
||||
export interface ChatEvent {
|
||||
type: 'attachment' | 'message' | 'error';
|
||||
type: 'event' | 'attachment' | 'message' | 'error';
|
||||
id?: string;
|
||||
data: string | object;
|
||||
}
|
||||
@@ -134,6 +134,15 @@ export class CopilotController {
|
||||
return session;
|
||||
}
|
||||
|
||||
private prepareParams(params: Record<string, string | string[]>) {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const jsonMode = String(params.jsonMode).toLowerCase() === 'true';
|
||||
delete params.messageId;
|
||||
return { messageId, jsonMode, params };
|
||||
}
|
||||
|
||||
private getSignal(req: Request) {
|
||||
const controller = new AbortController();
|
||||
req.on('close', () => controller.abort());
|
||||
@@ -158,9 +167,7 @@ export class CopilotController {
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const { messageId, jsonMode } = this.prepareParams(params);
|
||||
const provider = await this.chooseTextProvider(
|
||||
user.id,
|
||||
sessionId,
|
||||
@@ -170,14 +177,10 @@ export class CopilotController {
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
|
||||
try {
|
||||
delete params.messageId;
|
||||
const content = await provider.generateText(
|
||||
session.finish(params),
|
||||
session.model,
|
||||
{
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
}
|
||||
{ jsonMode, signal: this.getSignal(req), user: user.id }
|
||||
);
|
||||
|
||||
session.push({
|
||||
@@ -201,9 +204,7 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const { messageId, jsonMode } = this.prepareParams(params);
|
||||
const provider = await this.chooseTextProvider(
|
||||
user.id,
|
||||
sessionId,
|
||||
@@ -211,10 +212,10 @@ export class CopilotController {
|
||||
);
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
return from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
jsonMode,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
@@ -255,12 +256,8 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
|
||||
const { messageId, jsonMode } = this.prepareParams(params);
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
const latestMessage = session.stashMessages.findLast(
|
||||
m => m.role === 'user'
|
||||
);
|
||||
@@ -272,6 +269,7 @@ export class CopilotController {
|
||||
|
||||
return from(
|
||||
this.workflow.runGraph(params, session.model, {
|
||||
jsonMode,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
@@ -280,7 +278,23 @@ export class CopilotController {
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(data => ({ type: 'message' as const, id: messageId, data }))
|
||||
map(data =>
|
||||
data.status === GraphExecutorState.EmitContent
|
||||
? {
|
||||
type: 'message' as const,
|
||||
id: messageId,
|
||||
data: data.content,
|
||||
}
|
||||
: {
|
||||
type: 'event' as const,
|
||||
id: messageId,
|
||||
data: {
|
||||
status: data.status,
|
||||
id: data.node.id,
|
||||
type: data.node.config.nodeType,
|
||||
} as any,
|
||||
}
|
||||
)
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
@@ -312,9 +326,7 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const { messageId } = this.prepareParams(params);
|
||||
const { model, hasAttachment } = await this.checkRequest(
|
||||
user.id,
|
||||
sessionId,
|
||||
@@ -331,7 +343,6 @@ export class CopilotController {
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
const handleRemoteLink = this.storage.handleRemoteLink.bind(
|
||||
this.storage,
|
||||
|
||||
@@ -120,19 +120,37 @@ export class OpenAIProvider
|
||||
});
|
||||
}
|
||||
|
||||
private extractOptionFromMessages(
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions
|
||||
) {
|
||||
const params: Record<string, string | string[]> = {};
|
||||
for (const message of messages) {
|
||||
if (message.params) {
|
||||
Object.assign(params, message.params);
|
||||
}
|
||||
}
|
||||
if (params.jsonMode && options) {
|
||||
options.jsonMode = String(params.jsonMode).toLowerCase() === 'true';
|
||||
}
|
||||
}
|
||||
|
||||
protected checkParams({
|
||||
messages,
|
||||
embeddings,
|
||||
model,
|
||||
options = {},
|
||||
}: {
|
||||
messages?: PromptMessage[];
|
||||
embeddings?: string[];
|
||||
model: string;
|
||||
options: CopilotChatOptions;
|
||||
}) {
|
||||
if (!this.availableModels.includes(model)) {
|
||||
throw new Error(`Invalid model: ${model}`);
|
||||
}
|
||||
if (Array.isArray(messages) && messages.length > 0) {
|
||||
this.extractOptionFromMessages(messages, options);
|
||||
if (
|
||||
messages.some(
|
||||
m =>
|
||||
@@ -158,6 +176,14 @@ export class OpenAIProvider
|
||||
) {
|
||||
throw new Error('Invalid message role');
|
||||
}
|
||||
// json mode need 'json' keyword in content
|
||||
// ref: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
|
||||
if (
|
||||
options.jsonMode &&
|
||||
!messages.some(m => m.content.toLowerCase().includes('json'))
|
||||
) {
|
||||
throw new Error('Prompt not support json mode');
|
||||
}
|
||||
} else if (
|
||||
Array.isArray(embeddings) &&
|
||||
embeddings.some(e => typeof e !== 'string' || !e || !e.trim())
|
||||
@@ -173,13 +199,16 @@ export class OpenAIProvider
|
||||
model: string = 'gpt-3.5-turbo',
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
this.checkParams({ messages, model });
|
||||
this.checkParams({ messages, model, options });
|
||||
const result = await this.instance.chat.completions.create(
|
||||
{
|
||||
messages: this.chatToGPTMessage(messages),
|
||||
model: model,
|
||||
temperature: options.temperature || 0,
|
||||
max_tokens: options.maxTokens || 4096,
|
||||
response_format: {
|
||||
type: options.jsonMode ? 'json_object' : 'text',
|
||||
},
|
||||
user: options.user,
|
||||
},
|
||||
{ signal: options.signal }
|
||||
@@ -196,7 +225,7 @@ export class OpenAIProvider
|
||||
model: string = 'gpt-3.5-turbo',
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
this.checkParams({ messages, model });
|
||||
this.checkParams({ messages, model, options });
|
||||
const result = await this.instance.chat.completions.create(
|
||||
{
|
||||
stream: true,
|
||||
@@ -204,6 +233,9 @@ export class OpenAIProvider
|
||||
model: model,
|
||||
temperature: options.temperature || 0,
|
||||
max_tokens: options.maxTokens || 4096,
|
||||
response_format: {
|
||||
type: options.jsonMode ? 'json_object' : 'text',
|
||||
},
|
||||
user: options.user,
|
||||
},
|
||||
{
|
||||
@@ -231,7 +263,7 @@ export class OpenAIProvider
|
||||
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
|
||||
): Promise<number[][]> {
|
||||
messages = Array.isArray(messages) ? messages : [messages];
|
||||
this.checkParams({ embeddings: messages, model });
|
||||
this.checkParams({ embeddings: messages, model, options });
|
||||
|
||||
const result = await this.instance.embeddings.create({
|
||||
model: model,
|
||||
|
||||
@@ -137,6 +137,7 @@ const CopilotProviderOptionsSchema = z.object({
|
||||
});
|
||||
|
||||
const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
|
||||
jsonMode: z.boolean().optional(),
|
||||
temperature: z.number().optional(),
|
||||
maxTokens: z.number().optional(),
|
||||
}).optional();
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { ChatPrompt, PromptService } from '../../prompt';
|
||||
import { CopilotProviderService } from '../../providers';
|
||||
import { CopilotChatOptions, CopilotImageProvider } from '../../types';
|
||||
import { WorkflowNodeData, WorkflowNodeType } from '../types';
|
||||
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
|
||||
import { AutoRegisteredWorkflowExecutor } from './utils';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotChatImageExecutor extends AutoRegisteredWorkflowExecutor {
|
||||
constructor(
|
||||
private readonly promptService: PromptService,
|
||||
private readonly providerService: CopilotProviderService
|
||||
) {
|
||||
super();
|
||||
}
|
||||
|
||||
private async initExecutor(
|
||||
data: WorkflowNodeData
|
||||
): Promise<
|
||||
[
|
||||
WorkflowNodeData & { nodeType: WorkflowNodeType.Basic },
|
||||
ChatPrompt,
|
||||
CopilotImageProvider,
|
||||
]
|
||||
> {
|
||||
if (data.nodeType !== WorkflowNodeType.Basic) {
|
||||
throw new Error(
|
||||
`Executor ${this.type} not support ${data.nodeType} node`
|
||||
);
|
||||
}
|
||||
|
||||
if (!data.promptName) {
|
||||
throw new Error(
|
||||
`Prompt name not found when running workflow node ${data.name}`
|
||||
);
|
||||
}
|
||||
const prompt = await this.promptService.get(data.promptName);
|
||||
if (!prompt) {
|
||||
throw new Error(
|
||||
`Prompt ${data.promptName} not found when running workflow node ${data.name}`
|
||||
);
|
||||
}
|
||||
const provider = await this.providerService.getProviderByModel(
|
||||
prompt.model
|
||||
);
|
||||
if (provider && 'generateImages' in provider) {
|
||||
return [data, prompt, provider];
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`Provider not found for model ${prompt.model} when running workflow node ${data.name}`
|
||||
);
|
||||
}
|
||||
|
||||
override get type() {
|
||||
return NodeExecutorType.ChatImage;
|
||||
}
|
||||
|
||||
override async *next(
|
||||
data: WorkflowNodeData,
|
||||
params: Record<string, string>,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<NodeExecuteResult> {
|
||||
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
|
||||
|
||||
const finalMessage = prompt.finish(params);
|
||||
if (paramKey) {
|
||||
// update params with custom key
|
||||
yield {
|
||||
type: NodeExecuteState.Params,
|
||||
params: {
|
||||
[paramKey]: await provider.generateImages(
|
||||
finalMessage,
|
||||
prompt.model,
|
||||
options
|
||||
),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
for await (const content of provider.generateImagesStream(
|
||||
finalMessage,
|
||||
prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield { type: NodeExecuteState.Content, nodeId: id, content };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,13 +3,8 @@ import { Injectable } from '@nestjs/common';
|
||||
import { ChatPrompt, PromptService } from '../../prompt';
|
||||
import { CopilotProviderService } from '../../providers';
|
||||
import { CopilotChatOptions, CopilotTextProvider } from '../../types';
|
||||
import {
|
||||
NodeData,
|
||||
WorkflowNodeType,
|
||||
WorkflowResult,
|
||||
WorkflowResultType,
|
||||
} from '../types';
|
||||
import { WorkflowExecutorType } from './types';
|
||||
import { WorkflowNodeData, WorkflowNodeType } from '../types';
|
||||
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
|
||||
import { AutoRegisteredWorkflowExecutor } from './utils';
|
||||
|
||||
@Injectable()
|
||||
@@ -22,10 +17,10 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
|
||||
}
|
||||
|
||||
private async initExecutor(
|
||||
data: NodeData
|
||||
data: WorkflowNodeData
|
||||
): Promise<
|
||||
[
|
||||
NodeData & { nodeType: WorkflowNodeType.Basic },
|
||||
WorkflowNodeData & { nodeType: WorkflowNodeType.Basic },
|
||||
ChatPrompt,
|
||||
CopilotTextProvider,
|
||||
]
|
||||
@@ -36,6 +31,11 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
|
||||
);
|
||||
}
|
||||
|
||||
if (!data.promptName) {
|
||||
throw new Error(
|
||||
`Prompt name not found when running workflow node ${data.name}`
|
||||
);
|
||||
}
|
||||
const prompt = await this.promptService.get(data.promptName);
|
||||
if (!prompt) {
|
||||
throw new Error(
|
||||
@@ -55,21 +55,21 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
|
||||
}
|
||||
|
||||
override get type() {
|
||||
return WorkflowExecutorType.ChatText;
|
||||
return NodeExecutorType.ChatText;
|
||||
}
|
||||
|
||||
override async *next(
|
||||
data: NodeData,
|
||||
data: WorkflowNodeData,
|
||||
params: Record<string, string>,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<WorkflowResult> {
|
||||
): AsyncIterable<NodeExecuteResult> {
|
||||
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
|
||||
|
||||
const finalMessage = prompt.finish(params);
|
||||
if (paramKey) {
|
||||
// update params with custom key
|
||||
yield {
|
||||
type: WorkflowResultType.Params,
|
||||
type: NodeExecuteState.Params,
|
||||
params: {
|
||||
[paramKey]: await provider.generateText(
|
||||
finalMessage,
|
||||
@@ -84,11 +84,7 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
|
||||
prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: id,
|
||||
content,
|
||||
};
|
||||
yield { type: NodeExecuteState.Content, nodeId: id, content };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { XMLValidator } from 'fast-xml-parser';
|
||||
import { HtmlValidate } from 'html-validate/node';
|
||||
|
||||
import { WorkflowNodeData, WorkflowNodeType, WorkflowParams } from '../types';
|
||||
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
|
||||
import { AutoRegisteredWorkflowExecutor } from './utils';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotCheckHtmlExecutor extends AutoRegisteredWorkflowExecutor {
|
||||
private readonly html = new HtmlValidate();
|
||||
|
||||
private async initExecutor(
|
||||
data: WorkflowNodeData
|
||||
): Promise<WorkflowNodeData & { nodeType: WorkflowNodeType.Basic }> {
|
||||
if (data.nodeType !== WorkflowNodeType.Basic) {
|
||||
throw new Error(
|
||||
`Executor ${this.type} not support ${data.nodeType} node`
|
||||
);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
override get type() {
|
||||
return NodeExecutorType.CheckHtml;
|
||||
}
|
||||
|
||||
private async checkHtml(
|
||||
content?: string | string[],
|
||||
strict?: boolean
|
||||
): Promise<boolean> {
|
||||
try {
|
||||
if (content && typeof content === 'string') {
|
||||
const ret = XMLValidator.validate(content);
|
||||
if (ret === true) {
|
||||
if (strict) {
|
||||
const report = await this.html.validateString(content, {
|
||||
extends: ['html-validate:standard'],
|
||||
});
|
||||
return report.valid;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
override async *next(
|
||||
data: WorkflowNodeData,
|
||||
params: WorkflowParams
|
||||
): AsyncIterable<NodeExecuteResult> {
|
||||
const { paramKey, id } = await this.initExecutor(data);
|
||||
|
||||
const ret = String(await this.checkHtml(params.content, !!params.strict));
|
||||
if (paramKey) {
|
||||
yield { type: NodeExecuteState.Params, params: { [paramKey]: ret } };
|
||||
} else {
|
||||
yield { type: NodeExecuteState.Content, nodeId: id, content: ret };
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { WorkflowNodeData, WorkflowNodeType, WorkflowParams } from '../types';
|
||||
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
|
||||
import { AutoRegisteredWorkflowExecutor } from './utils';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotCheckJsonExecutor extends AutoRegisteredWorkflowExecutor {
|
||||
constructor() {
|
||||
super();
|
||||
}
|
||||
|
||||
private async initExecutor(
|
||||
data: WorkflowNodeData
|
||||
): Promise<WorkflowNodeData & { nodeType: WorkflowNodeType.Basic }> {
|
||||
if (data.nodeType !== WorkflowNodeType.Basic) {
|
||||
throw new Error(
|
||||
`Executor ${this.type} not support ${data.nodeType} node`
|
||||
);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
override get type() {
|
||||
return NodeExecutorType.CheckJson;
|
||||
}
|
||||
|
||||
private checkJson(content?: string | string[]): boolean {
|
||||
try {
|
||||
if (content && typeof content === 'string') {
|
||||
JSON.parse(content);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
override async *next(
|
||||
data: WorkflowNodeData,
|
||||
params: WorkflowParams
|
||||
): AsyncIterable<NodeExecuteResult> {
|
||||
const { paramKey, id } = await this.initExecutor(data);
|
||||
|
||||
const ret = String(this.checkJson(params.content));
|
||||
if (paramKey) {
|
||||
yield { type: NodeExecuteState.Params, params: { [paramKey]: ret } };
|
||||
} else {
|
||||
yield { type: NodeExecuteState.Content, nodeId: id, content: ret };
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,21 @@
|
||||
import { CopilotChatImageExecutor } from './chat-image';
|
||||
import { CopilotChatTextExecutor } from './chat-text';
|
||||
import { CopilotCheckHtmlExecutor } from './check-html';
|
||||
import { CopilotCheckJsonExecutor } from './check-json';
|
||||
|
||||
export const CopilotWorkflowExecutors = [CopilotChatTextExecutor];
|
||||
export const CopilotWorkflowExecutors = [
|
||||
CopilotChatImageExecutor,
|
||||
CopilotChatTextExecutor,
|
||||
CopilotCheckHtmlExecutor,
|
||||
CopilotCheckJsonExecutor,
|
||||
];
|
||||
|
||||
export { type WorkflowExecutor, WorkflowExecutorType } from './types';
|
||||
export type { NodeExecuteResult, NodeExecutor } from './types';
|
||||
export { NodeExecuteState, NodeExecutorType } from './types';
|
||||
export { getWorkflowExecutor } from './utils';
|
||||
export { CopilotChatTextExecutor };
|
||||
export {
|
||||
CopilotChatImageExecutor,
|
||||
CopilotChatTextExecutor,
|
||||
CopilotCheckHtmlExecutor,
|
||||
CopilotCheckJsonExecutor,
|
||||
};
|
||||
|
||||
@@ -1,15 +1,32 @@
|
||||
import { CopilotChatOptions } from '../../types';
|
||||
import { NodeData, WorkflowResult } from '../types';
|
||||
import type { WorkflowNode } from '../node';
|
||||
import { WorkflowNodeData, WorkflowParams } from '../types';
|
||||
|
||||
export enum WorkflowExecutorType {
|
||||
export enum NodeExecutorType {
|
||||
ChatText = 'ChatText',
|
||||
ChatImage = 'ChatImage',
|
||||
CheckJson = 'CheckJson',
|
||||
CheckHtml = 'CheckHtml',
|
||||
}
|
||||
|
||||
export abstract class WorkflowExecutor {
|
||||
abstract get type(): WorkflowExecutorType;
|
||||
abstract next(
|
||||
data: NodeData,
|
||||
params: Record<string, string | string[]>,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<WorkflowResult>;
|
||||
export enum NodeExecuteState {
|
||||
StartRun,
|
||||
EndRun,
|
||||
Params,
|
||||
Content,
|
||||
}
|
||||
|
||||
export type NodeExecuteResult =
|
||||
| { type: NodeExecuteState.StartRun; nodeId: string }
|
||||
| { type: NodeExecuteState.EndRun; nextNode?: WorkflowNode }
|
||||
| { type: NodeExecuteState.Params; params: WorkflowParams }
|
||||
| { type: NodeExecuteState.Content; nodeId: string; content: string };
|
||||
|
||||
export abstract class NodeExecutor {
|
||||
abstract get type(): NodeExecutorType;
|
||||
abstract next(
|
||||
data: WorkflowNodeData,
|
||||
params: WorkflowParams,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<NodeExecuteResult>;
|
||||
}
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import { Logger, OnModuleInit } from '@nestjs/common';
|
||||
|
||||
import { WorkflowExecutor, type WorkflowExecutorType } from './types';
|
||||
import { NodeExecutor, type NodeExecutorType } from './types';
|
||||
|
||||
const WORKFLOW_EXECUTOR: Map<string, WorkflowExecutor> = new Map();
|
||||
const WORKFLOW_EXECUTOR: Map<string, NodeExecutor> = new Map();
|
||||
|
||||
function registerWorkflowExecutor(e: WorkflowExecutor) {
|
||||
function registerWorkflowExecutor(e: NodeExecutor) {
|
||||
const existing = WORKFLOW_EXECUTOR.get(e.type);
|
||||
if (existing && existing === e) return false;
|
||||
WORKFLOW_EXECUTOR.set(e.type, e);
|
||||
return true;
|
||||
}
|
||||
|
||||
export function getWorkflowExecutor(
|
||||
type: WorkflowExecutorType
|
||||
): WorkflowExecutor {
|
||||
export function getWorkflowExecutor(type: NodeExecutorType): NodeExecutor {
|
||||
const executor = WORKFLOW_EXECUTOR.get(type);
|
||||
if (!executor) {
|
||||
throw new Error(`Executor ${type} not defined`);
|
||||
@@ -23,7 +21,7 @@ export function getWorkflowExecutor(
|
||||
}
|
||||
|
||||
export abstract class AutoRegisteredWorkflowExecutor
|
||||
extends WorkflowExecutor
|
||||
extends NodeExecutor
|
||||
implements OnModuleInit
|
||||
{
|
||||
onModuleInit() {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { WorkflowExecutorType } from './executor';
|
||||
import { NodeExecutorType } from './executor';
|
||||
import type { WorkflowGraphs } from './types';
|
||||
import { WorkflowNodeState, WorkflowNodeType } from './types';
|
||||
|
||||
@@ -10,7 +10,7 @@ export const WorkflowGraphList: WorkflowGraphs = [
|
||||
id: 'start',
|
||||
name: 'Start: check language',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
type: NodeExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step1',
|
||||
paramKey: 'language',
|
||||
edges: ['step2'],
|
||||
@@ -19,38 +19,44 @@ export const WorkflowGraphList: WorkflowGraphs = [
|
||||
id: 'step2',
|
||||
name: 'Step 2: generate presentation',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
type: NodeExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step2',
|
||||
edges: ['step3'],
|
||||
},
|
||||
{
|
||||
id: 'step3',
|
||||
name: 'Step 3: check format',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step3',
|
||||
paramKey: 'needFormat',
|
||||
edges: ['step4'],
|
||||
name: 'Step 3: format presentation if needed',
|
||||
nodeType: WorkflowNodeType.Decision,
|
||||
condition: (nodeIds: string[], params: WorkflowNodeState) => {
|
||||
const lines = params.content?.split('\n') || [];
|
||||
return nodeIds[
|
||||
Number(
|
||||
!lines.some(line => {
|
||||
try {
|
||||
if (line.trim()) {
|
||||
JSON.parse(line);
|
||||
}
|
||||
return false;
|
||||
} catch {
|
||||
return true;
|
||||
}
|
||||
})
|
||||
)
|
||||
];
|
||||
},
|
||||
edges: ['step4', 'step5'],
|
||||
},
|
||||
{
|
||||
id: 'step4',
|
||||
name: 'Step 4: format presentation if needed',
|
||||
nodeType: WorkflowNodeType.Decision,
|
||||
condition: (nodeIds: string[], params: WorkflowNodeState) =>
|
||||
nodeIds[Number(String(params.needFormat).toLowerCase() !== 'true')],
|
||||
edges: ['step5', 'step6'],
|
||||
name: 'Step 4: format presentation',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: NodeExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step4',
|
||||
edges: ['step5'],
|
||||
},
|
||||
{
|
||||
id: 'step5',
|
||||
name: 'Step 5: format presentation',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step5',
|
||||
edges: ['step6'],
|
||||
},
|
||||
{
|
||||
id: 'step6',
|
||||
name: 'Step 6: finish',
|
||||
name: 'Step 5: finish',
|
||||
nodeType: WorkflowNodeType.Nope,
|
||||
edges: [],
|
||||
},
|
||||
|
||||
@@ -1,72 +1,8 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { WorkflowGraphList } from './graph';
|
||||
import { WorkflowNode } from './node';
|
||||
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
|
||||
import { CopilotWorkflow } from './workflow';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotWorkflowService {
|
||||
private readonly logger = new Logger(CopilotWorkflowService.name);
|
||||
constructor() {}
|
||||
|
||||
private initWorkflow(graph: WorkflowGraph) {
|
||||
const workflow = new Map<string, WorkflowNode>();
|
||||
for (const nodeData of graph.graph) {
|
||||
const { edges: _, ...data } = nodeData;
|
||||
const node = new WorkflowNode(graph, data);
|
||||
workflow.set(node.id, node);
|
||||
}
|
||||
|
||||
// add edges
|
||||
for (const nodeData of graph.graph) {
|
||||
const node = workflow.get(nodeData.id);
|
||||
if (!node) {
|
||||
this.logger.error(
|
||||
`Failed to init workflow ${name}: node ${nodeData.id} not found`
|
||||
);
|
||||
throw new Error(`Node ${nodeData.id} not found`);
|
||||
}
|
||||
for (const edgeId of nodeData.edges) {
|
||||
const edge = workflow.get(edgeId);
|
||||
if (!edge) {
|
||||
this.logger.error(
|
||||
`Failed to init workflow ${name}: edge ${edgeId} not found in node ${nodeData.id}`
|
||||
);
|
||||
throw new Error(`Edge ${edgeId} not found`);
|
||||
}
|
||||
node.addEdge(edge);
|
||||
}
|
||||
}
|
||||
return workflow;
|
||||
}
|
||||
|
||||
// TODO(@darkskygit): get workflow from database
|
||||
private async getWorkflow(
|
||||
graphName: string
|
||||
): Promise<WorkflowGraphInstances> {
|
||||
const graph = WorkflowGraphList.find(g => g.name === graphName);
|
||||
if (!graph) {
|
||||
throw new Error(`Graph ${graphName} not found`);
|
||||
}
|
||||
|
||||
return this.initWorkflow(graph);
|
||||
}
|
||||
|
||||
async *runGraph(
|
||||
params: Record<string, string>,
|
||||
graphName: string,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string> {
|
||||
const workflowGraph = await this.getWorkflow(graphName);
|
||||
const workflow = new CopilotWorkflow(workflowGraph);
|
||||
|
||||
for await (const result of workflow.runGraph(params, options)) {
|
||||
yield result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export { CopilotChatTextExecutor, CopilotWorkflowExecutors } from './executor';
|
||||
export { WorkflowNodeType } from './types';
|
||||
export { CopilotWorkflowService } from './service';
|
||||
export {
|
||||
type WorkflowGraph,
|
||||
type WorkflowNodeData,
|
||||
WorkflowNodeType,
|
||||
} from './types';
|
||||
export { GraphExecutorState, WorkflowGraphExecutor } from './workflow';
|
||||
|
||||
@@ -5,33 +5,33 @@ import { Logger } from '@nestjs/common';
|
||||
import Piscina from 'piscina';
|
||||
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { getWorkflowExecutor, WorkflowExecutor } from './executor';
|
||||
import type { NodeExecuteResult, NodeExecutor } from './executor';
|
||||
import { getWorkflowExecutor, NodeExecuteState } from './executor';
|
||||
import type {
|
||||
NodeData,
|
||||
WorkflowGraph,
|
||||
WorkflowNodeData,
|
||||
WorkflowNodeState,
|
||||
WorkflowResult,
|
||||
} from './types';
|
||||
import { WorkflowNodeType, WorkflowResultType } from './types';
|
||||
import { WorkflowNodeType } from './types';
|
||||
|
||||
export class WorkflowNode {
|
||||
private readonly logger = new Logger(WorkflowNode.name);
|
||||
private readonly edges: WorkflowNode[] = [];
|
||||
private readonly parents: WorkflowNode[] = [];
|
||||
private readonly executor: WorkflowExecutor | null = null;
|
||||
private readonly executor: NodeExecutor | null = null;
|
||||
private readonly condition:
|
||||
| ((params: WorkflowNodeState) => Promise<any>)
|
||||
| null = null;
|
||||
|
||||
constructor(
|
||||
graph: WorkflowGraph,
|
||||
private readonly data: NodeData
|
||||
private readonly data: WorkflowNodeData
|
||||
) {
|
||||
if (data.nodeType === WorkflowNodeType.Basic) {
|
||||
this.executor = getWorkflowExecutor(data.type);
|
||||
} else if (data.nodeType === WorkflowNodeType.Decision) {
|
||||
// prepare decision condition, reused in each run
|
||||
const iife = `(${data.condition})(nodeIds, params)`;
|
||||
const iife = `return (${data.condition})(nodeIds, params)`;
|
||||
// only eval the condition in worker if graph has been modified
|
||||
if (graph.modified) {
|
||||
const worker = new Piscina({
|
||||
@@ -55,11 +55,7 @@ export class WorkflowNode {
|
||||
const func =
|
||||
typeof data.condition === 'function'
|
||||
? data.condition
|
||||
: new Function(
|
||||
'nodeIds',
|
||||
'params',
|
||||
`(${data.condition})(nodeIds, params)`
|
||||
);
|
||||
: new Function('nodeIds', 'params', iife);
|
||||
this.condition = (params: WorkflowNodeState) =>
|
||||
func(
|
||||
this.edges.map(node => node.id),
|
||||
@@ -77,7 +73,7 @@ export class WorkflowNode {
|
||||
return this.data.name;
|
||||
}
|
||||
|
||||
get config(): NodeData {
|
||||
get config(): WorkflowNodeData {
|
||||
return Object.assign({}, this.data);
|
||||
}
|
||||
|
||||
@@ -106,6 +102,8 @@ export class WorkflowNode {
|
||||
!this.data.condition
|
||||
) {
|
||||
throw new Error(`Decision block must have a condition`);
|
||||
} else if (this.data.nodeType === WorkflowNodeType.Nope) {
|
||||
throw new Error(`Nope block cannot have edges`);
|
||||
}
|
||||
node.parent = this;
|
||||
this.edges.push(node);
|
||||
@@ -133,8 +131,8 @@ export class WorkflowNode {
|
||||
async *next(
|
||||
params: WorkflowNodeState,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<WorkflowResult> {
|
||||
yield { type: WorkflowResultType.StartRun, nodeId: this.id };
|
||||
): AsyncIterable<NodeExecuteResult> {
|
||||
yield { type: NodeExecuteState.StartRun, nodeId: this.id };
|
||||
|
||||
// choose next node in graph
|
||||
let nextNode: WorkflowNode | undefined = this.edges[0];
|
||||
@@ -155,12 +153,12 @@ export class WorkflowNode {
|
||||
yield* this.executor.next(this.data, params, options);
|
||||
} else {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
type: NodeExecuteState.Content,
|
||||
nodeId: this.id,
|
||||
content: params.content,
|
||||
};
|
||||
}
|
||||
|
||||
yield { type: WorkflowResultType.EndRun, nextNode };
|
||||
yield { type: NodeExecuteState.EndRun, nextNode };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { WorkflowGraphList } from './graph';
|
||||
import { WorkflowNode } from './node';
|
||||
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
|
||||
import { type GraphExecutorStatus, WorkflowGraphExecutor } from './workflow';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotWorkflowService {
|
||||
private readonly logger = new Logger(CopilotWorkflowService.name);
|
||||
|
||||
initWorkflow(graph: WorkflowGraph) {
|
||||
const workflow = new Map<string, WorkflowNode>();
|
||||
for (const nodeData of graph.graph) {
|
||||
const { edges: _, ...data } = nodeData;
|
||||
const node = new WorkflowNode(graph, data);
|
||||
workflow.set(node.id, node);
|
||||
}
|
||||
|
||||
// add edges
|
||||
for (const nodeData of graph.graph) {
|
||||
const node = workflow.get(nodeData.id);
|
||||
if (!node) {
|
||||
this.logger.error(
|
||||
`Failed to init workflow ${name}: node ${nodeData.id} not found`
|
||||
);
|
||||
throw new Error(`Node ${nodeData.id} not found`);
|
||||
}
|
||||
for (const edgeId of nodeData.edges) {
|
||||
const edge = workflow.get(edgeId);
|
||||
if (!edge) {
|
||||
this.logger.error(
|
||||
`Failed to init workflow ${name}: edge ${edgeId} not found in node ${nodeData.id}`
|
||||
);
|
||||
throw new Error(`Edge ${edgeId} not found`);
|
||||
}
|
||||
node.addEdge(edge);
|
||||
}
|
||||
}
|
||||
return workflow;
|
||||
}
|
||||
|
||||
// TODO(@darkskygit): get workflow from database
|
||||
private async getWorkflow(
|
||||
graphName: string
|
||||
): Promise<WorkflowGraphInstances> {
|
||||
const graph = WorkflowGraphList.find(g => g.name === graphName);
|
||||
if (!graph) {
|
||||
throw new Error(`Graph ${graphName} not found`);
|
||||
}
|
||||
|
||||
return this.initWorkflow(graph);
|
||||
}
|
||||
|
||||
async *runGraph(
|
||||
params: Record<string, string>,
|
||||
graphName: string,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<GraphExecutorStatus> {
|
||||
const workflowGraph = await this.getWorkflow(graphName);
|
||||
const executor = new WorkflowGraphExecutor(workflowGraph);
|
||||
|
||||
for await (const result of executor.runGraph(params, options)) {
|
||||
yield result;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,19 @@
|
||||
import type { WorkflowExecutorType } from './executor';
|
||||
import type { NodeExecutorType } from './executor';
|
||||
import type { WorkflowNode } from './node';
|
||||
|
||||
// ===================== node =====================
|
||||
|
||||
export enum WorkflowNodeType {
|
||||
Basic = 'basic',
|
||||
Decision = 'decision',
|
||||
Nope = 'nope',
|
||||
}
|
||||
|
||||
export type NodeData = { id: string; name: string } & (
|
||||
export type WorkflowNodeData = { id: string; name: string } & (
|
||||
| {
|
||||
nodeType: WorkflowNodeType.Basic;
|
||||
promptName: string;
|
||||
type: WorkflowExecutorType;
|
||||
type: NodeExecutorType;
|
||||
promptName?: string;
|
||||
// update the prompt params by output with the custom key
|
||||
paramKey?: string;
|
||||
}
|
||||
@@ -25,35 +27,22 @@ export type NodeData = { id: string; name: string } & (
|
||||
| { nodeType: WorkflowNodeType.Nope }
|
||||
);
|
||||
|
||||
export type WorkflowNodeState = Record<string, string>;
|
||||
export type WorkflowGraphInstances = Map<string, WorkflowNode>;
|
||||
|
||||
export type WorkflowGraphData = Array<NodeData & { edges: string[] }>;
|
||||
// ===================== graph =====================
|
||||
|
||||
export type WorkflowGraphDefinition = Array<
|
||||
WorkflowNodeData & { edges: string[] }
|
||||
>;
|
||||
export type WorkflowGraph = {
|
||||
name: string;
|
||||
// true if the graph has been modified
|
||||
modified?: boolean;
|
||||
graph: WorkflowGraphData;
|
||||
graph: WorkflowGraphDefinition;
|
||||
};
|
||||
export type WorkflowGraphs = Array<WorkflowGraph>;
|
||||
|
||||
export enum WorkflowResultType {
|
||||
StartRun,
|
||||
EndRun,
|
||||
Params,
|
||||
Content,
|
||||
}
|
||||
// ===================== executor =====================
|
||||
|
||||
export type WorkflowResult =
|
||||
| { type: WorkflowResultType.StartRun; nodeId: string }
|
||||
| { type: WorkflowResultType.EndRun; nextNode?: WorkflowNode }
|
||||
| {
|
||||
type: WorkflowResultType.Params;
|
||||
params: Record<string, string | string[]>;
|
||||
}
|
||||
| {
|
||||
type: WorkflowResultType.Content;
|
||||
nodeId: string;
|
||||
content: string;
|
||||
};
|
||||
|
||||
export type WorkflowGraphInstances = Map<string, WorkflowNode>;
|
||||
export type WorkflowParams = Record<string, string | string[]>;
|
||||
export type WorkflowNodeState = Record<string, string>;
|
||||
|
||||
@@ -1,16 +1,25 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { NodeExecuteState } from './executor';
|
||||
import { WorkflowNode } from './node';
|
||||
import {
|
||||
type WorkflowGraphInstances,
|
||||
type WorkflowNodeState,
|
||||
WorkflowNodeType,
|
||||
WorkflowResultType,
|
||||
} from './types';
|
||||
import type { WorkflowGraphInstances, WorkflowNodeState } from './types';
|
||||
import { WorkflowNodeType } from './types';
|
||||
|
||||
export class CopilotWorkflow {
|
||||
private readonly logger = new Logger(CopilotWorkflow.name);
|
||||
export enum GraphExecutorState {
|
||||
EnterNode = 'EnterNode',
|
||||
EmitContent = 'EmitContent',
|
||||
ExitNode = 'ExitNode',
|
||||
}
|
||||
|
||||
export type GraphExecutorStatus = { status: GraphExecutorState } & (
|
||||
| { status: GraphExecutorState.EnterNode; node: WorkflowNode }
|
||||
| { status: GraphExecutorState.EmitContent; content: string }
|
||||
| { status: GraphExecutorState.ExitNode; node: WorkflowNode }
|
||||
);
|
||||
|
||||
export class WorkflowGraphExecutor {
|
||||
private readonly logger = new Logger(WorkflowGraphExecutor.name);
|
||||
private readonly rootNode: WorkflowNode;
|
||||
|
||||
constructor(workflow: WorkflowGraphInstances) {
|
||||
@@ -24,7 +33,7 @@ export class CopilotWorkflow {
|
||||
async *runGraph(
|
||||
params: Record<string, string>,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string> {
|
||||
): AsyncIterable<GraphExecutorStatus> {
|
||||
let currentNode: WorkflowNode | undefined = this.rootNode;
|
||||
const lastParams: WorkflowNodeState = { ...params };
|
||||
|
||||
@@ -33,10 +42,13 @@ export class CopilotWorkflow {
|
||||
let nextNode: WorkflowNode | undefined;
|
||||
|
||||
for await (const ret of currentNode.next(lastParams, options)) {
|
||||
if (ret.type === WorkflowResultType.EndRun) {
|
||||
if (ret.type === NodeExecuteState.StartRun) {
|
||||
yield { status: GraphExecutorState.EnterNode, node: currentNode };
|
||||
} else if (ret.type === NodeExecuteState.EndRun) {
|
||||
yield { status: GraphExecutorState.ExitNode, node: currentNode };
|
||||
nextNode = ret.nextNode;
|
||||
break;
|
||||
} else if (ret.type === WorkflowResultType.Params) {
|
||||
} else if (ret.type === NodeExecuteState.Params) {
|
||||
Object.assign(lastParams, ret.params);
|
||||
if (currentNode.config.nodeType === WorkflowNodeType.Basic) {
|
||||
const { type, promptName } = currentNode.config;
|
||||
@@ -44,10 +56,13 @@ export class CopilotWorkflow {
|
||||
`[${currentNode.name}][${type}][${promptName}]: update params - '${JSON.stringify(ret.params)}'`
|
||||
);
|
||||
}
|
||||
} else if (ret.type === WorkflowResultType.Content) {
|
||||
} else if (ret.type === NodeExecuteState.Content) {
|
||||
if (!currentNode.hasEdges) {
|
||||
// pass through content as a stream response if node is end node
|
||||
yield ret.content;
|
||||
yield {
|
||||
status: GraphExecutorState.EmitContent,
|
||||
content: ret.content,
|
||||
};
|
||||
} else {
|
||||
result += ret.content;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user