feat: add more workflow executor (#7231)

This commit is contained in:
darkskygit
2024-06-25 10:54:16 +00:00
parent 532a628989
commit cffaf815e1
23 changed files with 946 additions and 324 deletions

View File

@@ -38,10 +38,10 @@ import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotCapability, CopilotTextProvider } from './types';
import { CopilotWorkflowService } from './workflow';
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
export interface ChatEvent {
type: 'attachment' | 'message' | 'error';
type: 'event' | 'attachment' | 'message' | 'error';
id?: string;
data: string | object;
}
@@ -134,6 +134,15 @@ export class CopilotController {
return session;
}
private prepareParams(params: Record<string, string | string[]>) {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const jsonMode = String(params.jsonMode).toLowerCase() === 'true';
delete params.messageId;
return { messageId, jsonMode, params };
}
private getSignal(req: Request) {
const controller = new AbortController();
req.on('close', () => controller.abort());
@@ -158,9 +167,7 @@ export class CopilotController {
@Param('sessionId') sessionId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { messageId, jsonMode } = this.prepareParams(params);
const provider = await this.chooseTextProvider(
user.id,
sessionId,
@@ -170,14 +177,10 @@ export class CopilotController {
const session = await this.appendSessionMessage(sessionId, messageId);
try {
delete params.messageId;
const content = await provider.generateText(
session.finish(params),
session.model,
{
signal: this.getSignal(req),
user: user.id,
}
{ jsonMode, signal: this.getSignal(req), user: user.id }
);
session.push({
@@ -201,9 +204,7 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { messageId, jsonMode } = this.prepareParams(params);
const provider = await this.chooseTextProvider(
user.id,
sessionId,
@@ -211,10 +212,10 @@ export class CopilotController {
);
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
jsonMode,
signal: this.getSignal(req),
user: user.id,
})
@@ -255,12 +256,8 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { messageId, jsonMode } = this.prepareParams(params);
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const latestMessage = session.stashMessages.findLast(
m => m.role === 'user'
);
@@ -272,6 +269,7 @@ export class CopilotController {
return from(
this.workflow.runGraph(params, session.model, {
jsonMode,
signal: this.getSignal(req),
user: user.id,
})
@@ -280,7 +278,23 @@ export class CopilotController {
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: messageId, data }))
map(data =>
data.status === GraphExecutorState.EmitContent
? {
type: 'message' as const,
id: messageId,
data: data.content,
}
: {
type: 'event' as const,
id: messageId,
data: {
status: data.status,
id: data.node.id,
type: data.node.config.nodeType,
} as any,
}
)
),
// save the generated text to the session
shared$.pipe(
@@ -312,9 +326,7 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const { messageId } = this.prepareParams(params);
const { model, hasAttachment } = await this.checkRequest(
user.id,
sessionId,
@@ -331,7 +343,6 @@ export class CopilotController {
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,

View File

@@ -120,19 +120,37 @@ export class OpenAIProvider
});
}
private extractOptionFromMessages(
messages: PromptMessage[],
options: CopilotChatOptions
) {
const params: Record<string, string | string[]> = {};
for (const message of messages) {
if (message.params) {
Object.assign(params, message.params);
}
}
if (params.jsonMode && options) {
options.jsonMode = String(params.jsonMode).toLowerCase() === 'true';
}
}
protected checkParams({
messages,
embeddings,
model,
options = {},
}: {
messages?: PromptMessage[];
embeddings?: string[];
model: string;
options: CopilotChatOptions;
}) {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (Array.isArray(messages) && messages.length > 0) {
this.extractOptionFromMessages(messages, options);
if (
messages.some(
m =>
@@ -158,6 +176,14 @@ export class OpenAIProvider
) {
throw new Error('Invalid message role');
}
// json mode need 'json' keyword in content
// ref: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
if (
options.jsonMode &&
!messages.some(m => m.content.toLowerCase().includes('json'))
) {
throw new Error('Prompt not support json mode');
}
} else if (
Array.isArray(embeddings) &&
embeddings.some(e => typeof e !== 'string' || !e || !e.trim())
@@ -173,13 +199,16 @@ export class OpenAIProvider
model: string = 'gpt-3.5-turbo',
options: CopilotChatOptions = {}
): Promise<string> {
this.checkParams({ messages, model });
this.checkParams({ messages, model, options });
const result = await this.instance.chat.completions.create(
{
messages: this.chatToGPTMessage(messages),
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',
},
user: options.user,
},
{ signal: options.signal }
@@ -196,7 +225,7 @@ export class OpenAIProvider
model: string = 'gpt-3.5-turbo',
options: CopilotChatOptions = {}
): AsyncIterable<string> {
this.checkParams({ messages, model });
this.checkParams({ messages, model, options });
const result = await this.instance.chat.completions.create(
{
stream: true,
@@ -204,6 +233,9 @@ export class OpenAIProvider
model: model,
temperature: options.temperature || 0,
max_tokens: options.maxTokens || 4096,
response_format: {
type: options.jsonMode ? 'json_object' : 'text',
},
user: options.user,
},
{
@@ -231,7 +263,7 @@ export class OpenAIProvider
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model });
this.checkParams({ embeddings: messages, model, options });
const result = await this.instance.embeddings.create({
model: model,

View File

@@ -137,6 +137,7 @@ const CopilotProviderOptionsSchema = z.object({
});
const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
jsonMode: z.boolean().optional(),
temperature: z.number().optional(),
maxTokens: z.number().optional(),
}).optional();

View File

@@ -0,0 +1,91 @@
import { Injectable } from '@nestjs/common';
import { ChatPrompt, PromptService } from '../../prompt';
import { CopilotProviderService } from '../../providers';
import { CopilotChatOptions, CopilotImageProvider } from '../../types';
import { WorkflowNodeData, WorkflowNodeType } from '../types';
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
export class CopilotChatImageExecutor extends AutoRegisteredWorkflowExecutor {
constructor(
private readonly promptService: PromptService,
private readonly providerService: CopilotProviderService
) {
super();
}
private async initExecutor(
data: WorkflowNodeData
): Promise<
[
WorkflowNodeData & { nodeType: WorkflowNodeType.Basic },
ChatPrompt,
CopilotImageProvider,
]
> {
if (data.nodeType !== WorkflowNodeType.Basic) {
throw new Error(
`Executor ${this.type} not support ${data.nodeType} node`
);
}
if (!data.promptName) {
throw new Error(
`Prompt name not found when running workflow node ${data.name}`
);
}
const prompt = await this.promptService.get(data.promptName);
if (!prompt) {
throw new Error(
`Prompt ${data.promptName} not found when running workflow node ${data.name}`
);
}
const provider = await this.providerService.getProviderByModel(
prompt.model
);
if (provider && 'generateImages' in provider) {
return [data, prompt, provider];
}
throw new Error(
`Provider not found for model ${prompt.model} when running workflow node ${data.name}`
);
}
override get type() {
return NodeExecutorType.ChatImage;
}
override async *next(
data: WorkflowNodeData,
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<NodeExecuteResult> {
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
const finalMessage = prompt.finish(params);
if (paramKey) {
// update params with custom key
yield {
type: NodeExecuteState.Params,
params: {
[paramKey]: await provider.generateImages(
finalMessage,
prompt.model,
options
),
},
};
} else {
for await (const content of provider.generateImagesStream(
finalMessage,
prompt.model,
options
)) {
yield { type: NodeExecuteState.Content, nodeId: id, content };
}
}
}
}

View File

@@ -3,13 +3,8 @@ import { Injectable } from '@nestjs/common';
import { ChatPrompt, PromptService } from '../../prompt';
import { CopilotProviderService } from '../../providers';
import { CopilotChatOptions, CopilotTextProvider } from '../../types';
import {
NodeData,
WorkflowNodeType,
WorkflowResult,
WorkflowResultType,
} from '../types';
import { WorkflowExecutorType } from './types';
import { WorkflowNodeData, WorkflowNodeType } from '../types';
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
@@ -22,10 +17,10 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
}
private async initExecutor(
data: NodeData
data: WorkflowNodeData
): Promise<
[
NodeData & { nodeType: WorkflowNodeType.Basic },
WorkflowNodeData & { nodeType: WorkflowNodeType.Basic },
ChatPrompt,
CopilotTextProvider,
]
@@ -36,6 +31,11 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
);
}
if (!data.promptName) {
throw new Error(
`Prompt name not found when running workflow node ${data.name}`
);
}
const prompt = await this.promptService.get(data.promptName);
if (!prompt) {
throw new Error(
@@ -55,21 +55,21 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
}
override get type() {
return WorkflowExecutorType.ChatText;
return NodeExecutorType.ChatText;
}
override async *next(
data: NodeData,
data: WorkflowNodeData,
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult> {
): AsyncIterable<NodeExecuteResult> {
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
const finalMessage = prompt.finish(params);
if (paramKey) {
// update params with custom key
yield {
type: WorkflowResultType.Params,
type: NodeExecuteState.Params,
params: {
[paramKey]: await provider.generateText(
finalMessage,
@@ -84,11 +84,7 @@ export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: id,
content,
};
yield { type: NodeExecuteState.Content, nodeId: id, content };
}
}
}

View File

@@ -0,0 +1,64 @@
import { Injectable } from '@nestjs/common';
import { XMLValidator } from 'fast-xml-parser';
import { HtmlValidate } from 'html-validate/node';
import { WorkflowNodeData, WorkflowNodeType, WorkflowParams } from '../types';
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
export class CopilotCheckHtmlExecutor extends AutoRegisteredWorkflowExecutor {
private readonly html = new HtmlValidate();
private async initExecutor(
data: WorkflowNodeData
): Promise<WorkflowNodeData & { nodeType: WorkflowNodeType.Basic }> {
if (data.nodeType !== WorkflowNodeType.Basic) {
throw new Error(
`Executor ${this.type} not support ${data.nodeType} node`
);
}
return data;
}
override get type() {
return NodeExecutorType.CheckHtml;
}
private async checkHtml(
content?: string | string[],
strict?: boolean
): Promise<boolean> {
try {
if (content && typeof content === 'string') {
const ret = XMLValidator.validate(content);
if (ret === true) {
if (strict) {
const report = await this.html.validateString(content, {
extends: ['html-validate:standard'],
});
return report.valid;
}
return true;
}
}
return false;
} catch (e) {
return false;
}
}
override async *next(
data: WorkflowNodeData,
params: WorkflowParams
): AsyncIterable<NodeExecuteResult> {
const { paramKey, id } = await this.initExecutor(data);
const ret = String(await this.checkHtml(params.content, !!params.strict));
if (paramKey) {
yield { type: NodeExecuteState.Params, params: { [paramKey]: ret } };
} else {
yield { type: NodeExecuteState.Content, nodeId: id, content: ret };
}
}
}

View File

@@ -0,0 +1,53 @@
import { Injectable } from '@nestjs/common';
import { WorkflowNodeData, WorkflowNodeType, WorkflowParams } from '../types';
import { NodeExecuteResult, NodeExecuteState, NodeExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
export class CopilotCheckJsonExecutor extends AutoRegisteredWorkflowExecutor {
constructor() {
super();
}
private async initExecutor(
data: WorkflowNodeData
): Promise<WorkflowNodeData & { nodeType: WorkflowNodeType.Basic }> {
if (data.nodeType !== WorkflowNodeType.Basic) {
throw new Error(
`Executor ${this.type} not support ${data.nodeType} node`
);
}
return data;
}
override get type() {
return NodeExecutorType.CheckJson;
}
private checkJson(content?: string | string[]): boolean {
try {
if (content && typeof content === 'string') {
JSON.parse(content);
return true;
}
return false;
} catch (e) {
return false;
}
}
override async *next(
data: WorkflowNodeData,
params: WorkflowParams
): AsyncIterable<NodeExecuteResult> {
const { paramKey, id } = await this.initExecutor(data);
const ret = String(this.checkJson(params.content));
if (paramKey) {
yield { type: NodeExecuteState.Params, params: { [paramKey]: ret } };
} else {
yield { type: NodeExecuteState.Content, nodeId: id, content: ret };
}
}
}

View File

@@ -1,7 +1,21 @@
import { CopilotChatImageExecutor } from './chat-image';
import { CopilotChatTextExecutor } from './chat-text';
import { CopilotCheckHtmlExecutor } from './check-html';
import { CopilotCheckJsonExecutor } from './check-json';
export const CopilotWorkflowExecutors = [CopilotChatTextExecutor];
export const CopilotWorkflowExecutors = [
CopilotChatImageExecutor,
CopilotChatTextExecutor,
CopilotCheckHtmlExecutor,
CopilotCheckJsonExecutor,
];
export { type WorkflowExecutor, WorkflowExecutorType } from './types';
export type { NodeExecuteResult, NodeExecutor } from './types';
export { NodeExecuteState, NodeExecutorType } from './types';
export { getWorkflowExecutor } from './utils';
export { CopilotChatTextExecutor };
export {
CopilotChatImageExecutor,
CopilotChatTextExecutor,
CopilotCheckHtmlExecutor,
CopilotCheckJsonExecutor,
};

View File

@@ -1,15 +1,32 @@
import { CopilotChatOptions } from '../../types';
import { NodeData, WorkflowResult } from '../types';
import type { WorkflowNode } from '../node';
import { WorkflowNodeData, WorkflowParams } from '../types';
export enum WorkflowExecutorType {
export enum NodeExecutorType {
ChatText = 'ChatText',
ChatImage = 'ChatImage',
CheckJson = 'CheckJson',
CheckHtml = 'CheckHtml',
}
export abstract class WorkflowExecutor {
abstract get type(): WorkflowExecutorType;
abstract next(
data: NodeData,
params: Record<string, string | string[]>,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult>;
export enum NodeExecuteState {
StartRun,
EndRun,
Params,
Content,
}
export type NodeExecuteResult =
| { type: NodeExecuteState.StartRun; nodeId: string }
| { type: NodeExecuteState.EndRun; nextNode?: WorkflowNode }
| { type: NodeExecuteState.Params; params: WorkflowParams }
| { type: NodeExecuteState.Content; nodeId: string; content: string };
export abstract class NodeExecutor {
abstract get type(): NodeExecutorType;
abstract next(
data: WorkflowNodeData,
params: WorkflowParams,
options?: CopilotChatOptions
): AsyncIterable<NodeExecuteResult>;
}

View File

@@ -1,19 +1,17 @@
import { Logger, OnModuleInit } from '@nestjs/common';
import { WorkflowExecutor, type WorkflowExecutorType } from './types';
import { NodeExecutor, type NodeExecutorType } from './types';
const WORKFLOW_EXECUTOR: Map<string, WorkflowExecutor> = new Map();
const WORKFLOW_EXECUTOR: Map<string, NodeExecutor> = new Map();
function registerWorkflowExecutor(e: WorkflowExecutor) {
function registerWorkflowExecutor(e: NodeExecutor) {
const existing = WORKFLOW_EXECUTOR.get(e.type);
if (existing && existing === e) return false;
WORKFLOW_EXECUTOR.set(e.type, e);
return true;
}
export function getWorkflowExecutor(
type: WorkflowExecutorType
): WorkflowExecutor {
export function getWorkflowExecutor(type: NodeExecutorType): NodeExecutor {
const executor = WORKFLOW_EXECUTOR.get(type);
if (!executor) {
throw new Error(`Executor ${type} not defined`);
@@ -23,7 +21,7 @@ export function getWorkflowExecutor(
}
export abstract class AutoRegisteredWorkflowExecutor
extends WorkflowExecutor
extends NodeExecutor
implements OnModuleInit
{
onModuleInit() {

View File

@@ -1,4 +1,4 @@
import { WorkflowExecutorType } from './executor';
import { NodeExecutorType } from './executor';
import type { WorkflowGraphs } from './types';
import { WorkflowNodeState, WorkflowNodeType } from './types';
@@ -10,7 +10,7 @@ export const WorkflowGraphList: WorkflowGraphs = [
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step1',
paramKey: 'language',
edges: ['step2'],
@@ -19,38 +19,44 @@ export const WorkflowGraphList: WorkflowGraphs = [
id: 'step2',
name: 'Step 2: generate presentation',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step2',
edges: ['step3'],
},
{
id: 'step3',
name: 'Step 3: check format',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step3',
paramKey: 'needFormat',
edges: ['step4'],
name: 'Step 3: format presentation if needed',
nodeType: WorkflowNodeType.Decision,
condition: (nodeIds: string[], params: WorkflowNodeState) => {
const lines = params.content?.split('\n') || [];
return nodeIds[
Number(
!lines.some(line => {
try {
if (line.trim()) {
JSON.parse(line);
}
return false;
} catch {
return true;
}
})
)
];
},
edges: ['step4', 'step5'],
},
{
id: 'step4',
name: 'Step 4: format presentation if needed',
nodeType: WorkflowNodeType.Decision,
condition: (nodeIds: string[], params: WorkflowNodeState) =>
nodeIds[Number(String(params.needFormat).toLowerCase() !== 'true')],
edges: ['step5', 'step6'],
name: 'Step 4: format presentation',
nodeType: WorkflowNodeType.Basic,
type: NodeExecutorType.ChatText,
promptName: 'workflow:presentation:step4',
edges: ['step5'],
},
{
id: 'step5',
name: 'Step 5: format presentation',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step5',
edges: ['step6'],
},
{
id: 'step6',
name: 'Step 6: finish',
name: 'Step 5: finish',
nodeType: WorkflowNodeType.Nope,
edges: [],
},

View File

@@ -1,72 +1,8 @@
import { Injectable, Logger } from '@nestjs/common';
import { CopilotChatOptions } from '../types';
import { WorkflowGraphList } from './graph';
import { WorkflowNode } from './node';
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
import { CopilotWorkflow } from './workflow';
@Injectable()
export class CopilotWorkflowService {
private readonly logger = new Logger(CopilotWorkflowService.name);
constructor() {}
private initWorkflow(graph: WorkflowGraph) {
const workflow = new Map<string, WorkflowNode>();
for (const nodeData of graph.graph) {
const { edges: _, ...data } = nodeData;
const node = new WorkflowNode(graph, data);
workflow.set(node.id, node);
}
// add edges
for (const nodeData of graph.graph) {
const node = workflow.get(nodeData.id);
if (!node) {
this.logger.error(
`Failed to init workflow ${name}: node ${nodeData.id} not found`
);
throw new Error(`Node ${nodeData.id} not found`);
}
for (const edgeId of nodeData.edges) {
const edge = workflow.get(edgeId);
if (!edge) {
this.logger.error(
`Failed to init workflow ${name}: edge ${edgeId} not found in node ${nodeData.id}`
);
throw new Error(`Edge ${edgeId} not found`);
}
node.addEdge(edge);
}
}
return workflow;
}
// TODO(@darkskygit): get workflow from database
private async getWorkflow(
graphName: string
): Promise<WorkflowGraphInstances> {
const graph = WorkflowGraphList.find(g => g.name === graphName);
if (!graph) {
throw new Error(`Graph ${graphName} not found`);
}
return this.initWorkflow(graph);
}
async *runGraph(
params: Record<string, string>,
graphName: string,
options?: CopilotChatOptions
): AsyncIterable<string> {
const workflowGraph = await this.getWorkflow(graphName);
const workflow = new CopilotWorkflow(workflowGraph);
for await (const result of workflow.runGraph(params, options)) {
yield result;
}
}
}
export { CopilotChatTextExecutor, CopilotWorkflowExecutors } from './executor';
export { WorkflowNodeType } from './types';
export { CopilotWorkflowService } from './service';
export {
type WorkflowGraph,
type WorkflowNodeData,
WorkflowNodeType,
} from './types';
export { GraphExecutorState, WorkflowGraphExecutor } from './workflow';

View File

@@ -5,33 +5,33 @@ import { Logger } from '@nestjs/common';
import Piscina from 'piscina';
import { CopilotChatOptions } from '../types';
import { getWorkflowExecutor, WorkflowExecutor } from './executor';
import type { NodeExecuteResult, NodeExecutor } from './executor';
import { getWorkflowExecutor, NodeExecuteState } from './executor';
import type {
NodeData,
WorkflowGraph,
WorkflowNodeData,
WorkflowNodeState,
WorkflowResult,
} from './types';
import { WorkflowNodeType, WorkflowResultType } from './types';
import { WorkflowNodeType } from './types';
export class WorkflowNode {
private readonly logger = new Logger(WorkflowNode.name);
private readonly edges: WorkflowNode[] = [];
private readonly parents: WorkflowNode[] = [];
private readonly executor: WorkflowExecutor | null = null;
private readonly executor: NodeExecutor | null = null;
private readonly condition:
| ((params: WorkflowNodeState) => Promise<any>)
| null = null;
constructor(
graph: WorkflowGraph,
private readonly data: NodeData
private readonly data: WorkflowNodeData
) {
if (data.nodeType === WorkflowNodeType.Basic) {
this.executor = getWorkflowExecutor(data.type);
} else if (data.nodeType === WorkflowNodeType.Decision) {
// prepare decision condition, reused in each run
const iife = `(${data.condition})(nodeIds, params)`;
const iife = `return (${data.condition})(nodeIds, params)`;
// only eval the condition in worker if graph has been modified
if (graph.modified) {
const worker = new Piscina({
@@ -55,11 +55,7 @@ export class WorkflowNode {
const func =
typeof data.condition === 'function'
? data.condition
: new Function(
'nodeIds',
'params',
`(${data.condition})(nodeIds, params)`
);
: new Function('nodeIds', 'params', iife);
this.condition = (params: WorkflowNodeState) =>
func(
this.edges.map(node => node.id),
@@ -77,7 +73,7 @@ export class WorkflowNode {
return this.data.name;
}
get config(): NodeData {
get config(): WorkflowNodeData {
return Object.assign({}, this.data);
}
@@ -106,6 +102,8 @@ export class WorkflowNode {
!this.data.condition
) {
throw new Error(`Decision block must have a condition`);
} else if (this.data.nodeType === WorkflowNodeType.Nope) {
throw new Error(`Nope block cannot have edges`);
}
node.parent = this;
this.edges.push(node);
@@ -133,8 +131,8 @@ export class WorkflowNode {
async *next(
params: WorkflowNodeState,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult> {
yield { type: WorkflowResultType.StartRun, nodeId: this.id };
): AsyncIterable<NodeExecuteResult> {
yield { type: NodeExecuteState.StartRun, nodeId: this.id };
// choose next node in graph
let nextNode: WorkflowNode | undefined = this.edges[0];
@@ -155,12 +153,12 @@ export class WorkflowNode {
yield* this.executor.next(this.data, params, options);
} else {
yield {
type: WorkflowResultType.Content,
type: NodeExecuteState.Content,
nodeId: this.id,
content: params.content,
};
}
yield { type: WorkflowResultType.EndRun, nextNode };
yield { type: NodeExecuteState.EndRun, nextNode };
}
}

View File

@@ -0,0 +1,68 @@
import { Injectable, Logger } from '@nestjs/common';
import { CopilotChatOptions } from '../types';
import { WorkflowGraphList } from './graph';
import { WorkflowNode } from './node';
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
import { type GraphExecutorStatus, WorkflowGraphExecutor } from './workflow';
@Injectable()
export class CopilotWorkflowService {
private readonly logger = new Logger(CopilotWorkflowService.name);
initWorkflow(graph: WorkflowGraph) {
const workflow = new Map<string, WorkflowNode>();
for (const nodeData of graph.graph) {
const { edges: _, ...data } = nodeData;
const node = new WorkflowNode(graph, data);
workflow.set(node.id, node);
}
// add edges
for (const nodeData of graph.graph) {
const node = workflow.get(nodeData.id);
if (!node) {
this.logger.error(
`Failed to init workflow ${name}: node ${nodeData.id} not found`
);
throw new Error(`Node ${nodeData.id} not found`);
}
for (const edgeId of nodeData.edges) {
const edge = workflow.get(edgeId);
if (!edge) {
this.logger.error(
`Failed to init workflow ${name}: edge ${edgeId} not found in node ${nodeData.id}`
);
throw new Error(`Edge ${edgeId} not found`);
}
node.addEdge(edge);
}
}
return workflow;
}
// TODO(@darkskygit): get workflow from database
private async getWorkflow(
graphName: string
): Promise<WorkflowGraphInstances> {
const graph = WorkflowGraphList.find(g => g.name === graphName);
if (!graph) {
throw new Error(`Graph ${graphName} not found`);
}
return this.initWorkflow(graph);
}
async *runGraph(
params: Record<string, string>,
graphName: string,
options?: CopilotChatOptions
): AsyncIterable<GraphExecutorStatus> {
const workflowGraph = await this.getWorkflow(graphName);
const executor = new WorkflowGraphExecutor(workflowGraph);
for await (const result of executor.runGraph(params, options)) {
yield result;
}
}
}

View File

@@ -1,17 +1,19 @@
import type { WorkflowExecutorType } from './executor';
import type { NodeExecutorType } from './executor';
import type { WorkflowNode } from './node';
// ===================== node =====================
export enum WorkflowNodeType {
Basic = 'basic',
Decision = 'decision',
Nope = 'nope',
}
export type NodeData = { id: string; name: string } & (
export type WorkflowNodeData = { id: string; name: string } & (
| {
nodeType: WorkflowNodeType.Basic;
promptName: string;
type: WorkflowExecutorType;
type: NodeExecutorType;
promptName?: string;
// update the prompt params by output with the custom key
paramKey?: string;
}
@@ -25,35 +27,22 @@ export type NodeData = { id: string; name: string } & (
| { nodeType: WorkflowNodeType.Nope }
);
export type WorkflowNodeState = Record<string, string>;
export type WorkflowGraphInstances = Map<string, WorkflowNode>;
export type WorkflowGraphData = Array<NodeData & { edges: string[] }>;
// ===================== graph =====================
export type WorkflowGraphDefinition = Array<
WorkflowNodeData & { edges: string[] }
>;
export type WorkflowGraph = {
name: string;
// true if the graph has been modified
modified?: boolean;
graph: WorkflowGraphData;
graph: WorkflowGraphDefinition;
};
export type WorkflowGraphs = Array<WorkflowGraph>;
export enum WorkflowResultType {
StartRun,
EndRun,
Params,
Content,
}
// ===================== executor =====================
export type WorkflowResult =
| { type: WorkflowResultType.StartRun; nodeId: string }
| { type: WorkflowResultType.EndRun; nextNode?: WorkflowNode }
| {
type: WorkflowResultType.Params;
params: Record<string, string | string[]>;
}
| {
type: WorkflowResultType.Content;
nodeId: string;
content: string;
};
export type WorkflowGraphInstances = Map<string, WorkflowNode>;
export type WorkflowParams = Record<string, string | string[]>;
export type WorkflowNodeState = Record<string, string>;

View File

@@ -1,16 +1,25 @@
import { Logger } from '@nestjs/common';
import { CopilotChatOptions } from '../types';
import { NodeExecuteState } from './executor';
import { WorkflowNode } from './node';
import {
type WorkflowGraphInstances,
type WorkflowNodeState,
WorkflowNodeType,
WorkflowResultType,
} from './types';
import type { WorkflowGraphInstances, WorkflowNodeState } from './types';
import { WorkflowNodeType } from './types';
export class CopilotWorkflow {
private readonly logger = new Logger(CopilotWorkflow.name);
export enum GraphExecutorState {
EnterNode = 'EnterNode',
EmitContent = 'EmitContent',
ExitNode = 'ExitNode',
}
export type GraphExecutorStatus = { status: GraphExecutorState } & (
| { status: GraphExecutorState.EnterNode; node: WorkflowNode }
| { status: GraphExecutorState.EmitContent; content: string }
| { status: GraphExecutorState.ExitNode; node: WorkflowNode }
);
export class WorkflowGraphExecutor {
private readonly logger = new Logger(WorkflowGraphExecutor.name);
private readonly rootNode: WorkflowNode;
constructor(workflow: WorkflowGraphInstances) {
@@ -24,7 +33,7 @@ export class CopilotWorkflow {
async *runGraph(
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<string> {
): AsyncIterable<GraphExecutorStatus> {
let currentNode: WorkflowNode | undefined = this.rootNode;
const lastParams: WorkflowNodeState = { ...params };
@@ -33,10 +42,13 @@ export class CopilotWorkflow {
let nextNode: WorkflowNode | undefined;
for await (const ret of currentNode.next(lastParams, options)) {
if (ret.type === WorkflowResultType.EndRun) {
if (ret.type === NodeExecuteState.StartRun) {
yield { status: GraphExecutorState.EnterNode, node: currentNode };
} else if (ret.type === NodeExecuteState.EndRun) {
yield { status: GraphExecutorState.ExitNode, node: currentNode };
nextNode = ret.nextNode;
break;
} else if (ret.type === WorkflowResultType.Params) {
} else if (ret.type === NodeExecuteState.Params) {
Object.assign(lastParams, ret.params);
if (currentNode.config.nodeType === WorkflowNodeType.Basic) {
const { type, promptName } = currentNode.config;
@@ -44,10 +56,13 @@ export class CopilotWorkflow {
`[${currentNode.name}][${type}][${promptName}]: update params - '${JSON.stringify(ret.params)}'`
);
}
} else if (ret.type === WorkflowResultType.Content) {
} else if (ret.type === NodeExecuteState.Content) {
if (!currentNode.hasEdges) {
// pass through content as a stream response if node is end node
yield ret.content;
yield {
status: GraphExecutorState.EmitContent,
content: ret.content,
};
} else {
result += ret.content;
}