feat: workflow executor (#7159)

fix AFF-1221 AFF-1232
This commit is contained in:
darkskygit
2024-06-25 08:40:47 +00:00
parent 45b3b833d4
commit fe89ecb1d3
16 changed files with 573 additions and 201 deletions

View File

@@ -80,12 +80,14 @@
"on-headers": "^1.0.2",
"openai": "^4.33.0",
"parse-duration": "^1.1.0",
"piscina": "^4.5.1",
"pretty-time": "^1.1.0",
"prisma": "^5.12.1",
"prom-client": "^15.1.1",
"reflect-metadata": "^0.2.2",
"rxjs": "^7.8.1",
"semver": "^7.6.0",
"ses": "^1.4.1",
"socket.io": "^4.7.5",
"stripe": "^15.0.0",
"ts-node": "^10.9.2",

View File

@@ -514,8 +514,8 @@ content: {{content}}`,
],
},
{
name: 'workflow:presentation:step4',
action: 'workflow:presentation:step4',
name: 'workflow:presentation:step5',
action: 'workflow:presentation:step5',
model: 'gpt-4o',
messages: [
{

View File

@@ -22,7 +22,7 @@ import {
} from './resolver';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotWorkflowService } from './workflow';
import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow';
registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
@@ -41,6 +41,7 @@ registerCopilotProvider(OpenAIProvider);
CopilotStorage,
PromptsManagementResolver,
CopilotWorkflowService,
...CopilotWorkflowExecutors,
],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot,

View File

@@ -0,0 +1,95 @@
import { Injectable } from '@nestjs/common';
import { ChatPrompt, PromptService } from '../../prompt';
import { CopilotProviderService } from '../../providers';
import { CopilotChatOptions, CopilotTextProvider } from '../../types';
import {
NodeData,
WorkflowNodeType,
WorkflowResult,
WorkflowResultType,
} from '../types';
import { WorkflowExecutorType } from './types';
import { AutoRegisteredWorkflowExecutor } from './utils';
@Injectable()
export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
constructor(
private readonly promptService: PromptService,
private readonly providerService: CopilotProviderService
) {
super();
}
private async initExecutor(
data: NodeData
): Promise<
[
NodeData & { nodeType: WorkflowNodeType.Basic },
ChatPrompt,
CopilotTextProvider,
]
> {
if (data.nodeType !== WorkflowNodeType.Basic) {
throw new Error(
`Executor ${this.type} not support ${data.nodeType} node`
);
}
const prompt = await this.promptService.get(data.promptName);
if (!prompt) {
throw new Error(
`Prompt ${data.promptName} not found when running workflow node ${data.name}`
);
}
const provider = await this.providerService.getProviderByModel(
prompt.model
);
if (provider && 'generateText' in provider) {
return [data, prompt, provider];
}
throw new Error(
`Provider not found for model ${prompt.model} when running workflow node ${data.name}`
);
}
override get type() {
return WorkflowExecutorType.ChatText;
}
override async *next(
data: NodeData,
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult> {
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
const finalMessage = prompt.finish(params);
if (paramKey) {
// update params with custom key
yield {
type: WorkflowResultType.Params,
params: {
[paramKey]: await provider.generateText(
finalMessage,
prompt.model,
options
),
},
};
} else {
for await (const content of provider.generateTextStream(
finalMessage,
prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: id,
content,
};
}
}
}
}

View File

@@ -0,0 +1,7 @@
import { CopilotChatTextExecutor } from './chat-text';
export const CopilotWorkflowExecutors = [CopilotChatTextExecutor];
export { type WorkflowExecutor, WorkflowExecutorType } from './types';
export { getWorkflowExecutor } from './utils';
export { CopilotChatTextExecutor };

View File

@@ -0,0 +1,15 @@
import { CopilotChatOptions } from '../../types';
import { NodeData, WorkflowResult } from '../types';
export enum WorkflowExecutorType {
ChatText = 'ChatText',
}
export abstract class WorkflowExecutor {
abstract get type(): WorkflowExecutorType;
abstract next(
data: NodeData,
params: Record<string, string | string[]>,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult>;
}

View File

@@ -0,0 +1,40 @@
import { Logger, OnModuleInit } from '@nestjs/common';
import { WorkflowExecutor, type WorkflowExecutorType } from './types';
const WORKFLOW_EXECUTOR: Map<string, WorkflowExecutor> = new Map();
function registerWorkflowExecutor(e: WorkflowExecutor) {
const existing = WORKFLOW_EXECUTOR.get(e.type);
if (existing && existing === e) return false;
WORKFLOW_EXECUTOR.set(e.type, e);
return true;
}
export function getWorkflowExecutor(
type: WorkflowExecutorType
): WorkflowExecutor {
const executor = WORKFLOW_EXECUTOR.get(type);
if (!executor) {
throw new Error(`Executor ${type} not defined`);
}
return executor;
}
export abstract class AutoRegisteredWorkflowExecutor
extends WorkflowExecutor
implements OnModuleInit
{
onModuleInit() {
this.register();
}
register() {
if (registerWorkflowExecutor(this)) {
new Logger(`CopilotWorkflowExecutor:${this.type}`).log(
'Workflow executor registered.'
);
}
}
}

View File

@@ -1,6 +1,8 @@
import { type WorkflowGraphList, WorkflowNodeType } from './types';
import { WorkflowExecutorType } from './executor';
import type { WorkflowGraphs } from './types';
import { WorkflowNodeState, WorkflowNodeType } from './types';
export const WorkflowGraphs: WorkflowGraphList = [
export const WorkflowGraphList: WorkflowGraphs = [
{
name: 'presentation',
graph: [
@@ -8,7 +10,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: 'text',
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step1',
paramKey: 'language',
edges: ['step2'],
@@ -17,49 +19,41 @@ export const WorkflowGraphs: WorkflowGraphList = [
id: 'step2',
name: 'Step 2: generate presentation',
nodeType: WorkflowNodeType.Basic,
type: 'text',
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step2',
edges: [],
// edges: ['step3'],
edges: ['step3'],
},
{
id: 'step3',
name: 'Step 3: check format',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step3',
paramKey: 'needFormat',
edges: ['step4'],
},
{
id: 'step4',
name: 'Step 4: format presentation if needed',
nodeType: WorkflowNodeType.Decision,
condition: (nodeIds: string[], params: WorkflowNodeState) =>
nodeIds[Number(String(params.needFormat).toLowerCase() !== 'true')],
edges: ['step5', 'step6'],
},
{
id: 'step5',
name: 'Step 5: format presentation',
nodeType: WorkflowNodeType.Basic,
type: WorkflowExecutorType.ChatText,
promptName: 'workflow:presentation:step5',
edges: ['step6'],
},
{
id: 'step6',
name: 'Step 6: finish',
nodeType: WorkflowNodeType.Nope,
edges: [],
},
// {
// id: 'step3',
// name: 'Step 3: check format',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'workflow:presentation:step3',
// paramKey: 'needFormat',
// edges: ['step4'],
// },
// {
// id: 'step4',
// name: 'Step 4: format presentation if needed',
// nodeType: WorkflowNodeType.Decision,
// condition: ((
// nodeIds: string[],
// params: WorkflowNodeState
// ) =>
// nodeIds[
// Number(String(params.needFormat).toLowerCase() === 'true')
// ]).toString(),
// edges: ['step5', 'step6'],
// },
// {
// id: 'step5',
// name: 'Step 5: format presentation',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'workflow:presentation:step5',
// edges: ['step6'],
// },
// {
// id: 'step6',
// name: 'Step 6: finish',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'workflow:presentation:step6',
// edges: [],
// },
],
},
];

View File

@@ -1,31 +1,26 @@
import { Injectable, Logger } from '@nestjs/common';
import { PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { CopilotChatOptions } from '../types';
import { WorkflowGraphs } from './graph';
import { WorkflowGraphList } from './graph';
import { WorkflowNode } from './node';
import { WorkflowGraph, WorkflowGraphList } from './types';
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
import { CopilotWorkflow } from './workflow';
@Injectable()
export class CopilotWorkflowService {
private readonly logger = new Logger(CopilotWorkflowService.name);
constructor(
private readonly prompt: PromptService,
private readonly provider: CopilotProviderService
) {}
constructor() {}
private initWorkflow({ name, graph }: WorkflowGraphList[number]) {
const workflow = new Map();
for (const nodeData of graph) {
private initWorkflow(graph: WorkflowGraph) {
const workflow = new Map<string, WorkflowNode>();
for (const nodeData of graph.graph) {
const { edges: _, ...data } = nodeData;
const node = new WorkflowNode(data);
const node = new WorkflowNode(graph, data);
workflow.set(node.id, node);
}
// add edges
for (const nodeData of graph) {
for (const nodeData of graph.graph) {
const node = workflow.get(nodeData.id);
if (!node) {
this.logger.error(
@@ -47,9 +42,11 @@ export class CopilotWorkflowService {
return workflow;
}
// TODO(@darksky): get workflow from database
private async getWorkflow(graphName: string): Promise<WorkflowGraph> {
const graph = WorkflowGraphs.find(g => g.name === graphName);
// TODO(@darkskygit): get workflow from database
private async getWorkflow(
graphName: string
): Promise<WorkflowGraphInstances> {
const graph = WorkflowGraphList.find(g => g.name === graphName);
if (!graph) {
throw new Error(`Graph ${graphName} not found`);
}
@@ -63,14 +60,13 @@ export class CopilotWorkflowService {
options?: CopilotChatOptions
): AsyncIterable<string> {
const workflowGraph = await this.getWorkflow(graphName);
const workflow = new CopilotWorkflow(
this.prompt,
this.provider,
workflowGraph
);
const workflow = new CopilotWorkflow(workflowGraph);
for await (const result of workflow.runGraph(params, options)) {
yield result;
}
}
}
export { CopilotChatTextExecutor, CopilotWorkflowExecutors } from './executor';
export { WorkflowNodeType } from './types';

View File

@@ -1,21 +1,73 @@
import { ChatPrompt, PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { CopilotAllProvider, CopilotChatOptions } from '../types';
import {
import path, { dirname } from 'node:path';
import { fileURLToPath } from 'node:url';
import { Logger } from '@nestjs/common';
import Piscina from 'piscina';
import { CopilotChatOptions } from '../types';
import { getWorkflowExecutor, WorkflowExecutor } from './executor';
import type {
NodeData,
WorkflowGraph,
WorkflowNodeState,
WorkflowNodeType,
WorkflowResult,
WorkflowResultType,
} from './types';
import { WorkflowNodeType, WorkflowResultType } from './types';
export class WorkflowNode {
private readonly logger = new Logger(WorkflowNode.name);
private readonly edges: WorkflowNode[] = [];
private readonly parents: WorkflowNode[] = [];
private prompt: ChatPrompt | null = null;
private provider: CopilotAllProvider | null = null;
private readonly executor: WorkflowExecutor | null = null;
private readonly condition:
| ((params: WorkflowNodeState) => Promise<any>)
| null = null;
constructor(private readonly data: NodeData) {}
constructor(
graph: WorkflowGraph,
private readonly data: NodeData
) {
if (data.nodeType === WorkflowNodeType.Basic) {
this.executor = getWorkflowExecutor(data.type);
} else if (data.nodeType === WorkflowNodeType.Decision) {
// prepare decision condition, reused in each run
const iife = `(${data.condition})(nodeIds, params)`;
// only eval the condition in worker if graph has been modified
if (graph.modified) {
const worker = new Piscina({
filename: path.resolve(
dirname(fileURLToPath(import.meta.url)),
'worker.mjs'
),
minThreads: 2,
// empty envs from parent process
env: {},
argv: [],
execArgv: [],
});
this.condition = (params: WorkflowNodeState) =>
worker.run({
iife,
nodeIds: this.edges.map(node => node.id),
params,
});
} else {
const func =
typeof data.condition === 'function'
? data.condition
: new Function(
'nodeIds',
'params',
`(${data.condition})(nodeIds, params)`
);
this.condition = (params: WorkflowNodeState) =>
func(
this.edges.map(node => node.id),
params
);
}
}
}
get id(): string {
return this.data.id;
@@ -33,6 +85,11 @@ export class WorkflowNode {
return this.parents;
}
// if is the end of the workflow, pass through the content to stream response
get hasEdges(): boolean {
return !!this.edges.length;
}
private set parent(node: WorkflowNode) {
if (!this.parents.includes(node)) {
this.parents.push(node);
@@ -44,7 +101,10 @@ export class WorkflowNode {
if (this.edges.length > 0) {
throw new Error(`Basic block can only have one edge`);
}
} else if (!this.data.condition) {
} else if (
this.data.nodeType === WorkflowNodeType.Decision &&
!this.data.condition
) {
throw new Error(`Decision block must have a condition`);
}
node.parent = this;
@@ -52,84 +112,34 @@ export class WorkflowNode {
return this.edges.length;
}
async initNode(prompt: PromptService, provider: CopilotProviderService) {
if (this.prompt && this.provider) return;
if (this.data.nodeType === WorkflowNodeType.Basic) {
this.prompt = await prompt.get(this.data.promptName);
if (!this.prompt) {
throw new Error(
`Prompt ${this.data.promptName} not found when running workflow node ${this.name}`
);
}
this.provider = await provider.getProviderByModel(this.prompt.model);
if (!this.provider) {
throw new Error(
`Provider not found for model ${this.prompt.model} when running workflow node ${this.name}`
);
}
}
}
private async evaluateCondition(
_condition?: string
params: WorkflowNodeState
): Promise<string | undefined> {
// TODO(@darksky): evaluate condition to impl decision block
return this.edges[0]?.id;
}
private getStreamProvider() {
if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) {
if (
this.data.type === 'text' &&
'generateText' in this.provider &&
!this.data.paramKey
) {
return this.provider.generateTextStream.bind(this.provider);
} else if (
this.data.type === 'image' &&
'generateImages' in this.provider &&
!this.data.paramKey
) {
return this.provider.generateImagesStream.bind(this.provider);
}
// early return if no edges
if (this.edges.length === 0) return undefined;
try {
const result = await this.condition?.(params);
if (typeof result === 'string') return result;
// choose default edge if condition falsy
return this.edges[0].id;
} catch (e) {
this.logger.error(
`Failed to evaluate condition for node ${this.name}: ${e}`
);
throw e;
}
throw new Error(`Stream Provider not found for node ${this.name}`);
}
private getProvider() {
if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) {
if (
this.data.type === 'text' &&
'generateText' in this.provider &&
this.data.paramKey
) {
return this.provider.generateText.bind(this.provider);
} else if (
this.data.type === 'image' &&
'generateImages' in this.provider &&
this.data.paramKey
) {
return this.provider.generateImages.bind(this.provider);
}
}
throw new Error(`Provider not found for node ${this.name}`);
}
async *next(
params: WorkflowNodeState,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult> {
if (!this.prompt || !this.provider) {
throw new Error(`Node ${this.name} not initialized`);
}
yield { type: WorkflowResultType.StartRun, nodeId: this.id };
// choose next node in graph
let nextNode: WorkflowNode | undefined = this.edges[0];
if (this.data.nodeType === WorkflowNodeType.Decision) {
const nextNodeId = await this.evaluateCondition(this.data.condition);
const nextNodeId = await this.evaluateCondition(params);
// return empty to choose default edge
if (nextNodeId) {
nextNode = this.edges.find(node => node.id === nextNodeId);
@@ -137,37 +147,18 @@ export class WorkflowNode {
throw new Error(`No edge found for condition ${this.data.condition}`);
}
}
} else {
const finalMessage = this.prompt.finish(params);
if (this.data.paramKey) {
const provider = this.getProvider();
// update params with custom key
yield {
type: WorkflowResultType.Params,
params: {
[this.data.paramKey]: await provider(
finalMessage,
this.prompt.model,
options
),
},
};
} else {
const provider = this.getStreamProvider();
for await (const content of provider(
finalMessage,
this.prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: this.id,
content,
// pass through content as a stream response if no next node
passthrough: !nextNode,
};
}
} else if (this.data.nodeType === WorkflowNodeType.Basic) {
if (!this.executor) {
throw new Error(`Node ${this.name} not initialized`);
}
yield* this.executor.next(this.data, params, options);
} else {
yield {
type: WorkflowResultType.Content,
nodeId: this.id,
content: params.content,
};
}
yield { type: WorkflowResultType.EndRun, nextNode };

View File

@@ -1,28 +1,40 @@
import type { WorkflowExecutorType } from './executor';
import type { WorkflowNode } from './node';
export enum WorkflowNodeType {
Basic,
Decision,
Basic = 'basic',
Decision = 'decision',
Nope = 'nope',
}
export type NodeData = { id: string; name: string } & (
| {
nodeType: WorkflowNodeType.Basic;
promptName: string;
type: 'text' | 'image';
type: WorkflowExecutorType;
// update the prompt params by output with the custom key
paramKey?: string;
}
| { nodeType: WorkflowNodeType.Decision; condition: string }
| {
nodeType: WorkflowNodeType.Decision;
condition:
| ((nodeIds: string[], params: WorkflowNodeState) => string)
| string;
}
// do nothing node
| { nodeType: WorkflowNodeType.Nope }
);
export type WorkflowNodeState = Record<string, string>;
export type WorkflowGraphData = Array<NodeData & { edges: string[] }>;
export type WorkflowGraphList = Array<{
export type WorkflowGraph = {
name: string;
// true if the graph has been modified
modified?: boolean;
graph: WorkflowGraphData;
}>;
};
export type WorkflowGraphs = Array<WorkflowGraph>;
export enum WorkflowResultType {
StartRun,
@@ -33,7 +45,7 @@ export enum WorkflowResultType {
export type WorkflowResult =
| { type: WorkflowResultType.StartRun; nodeId: string }
| { type: WorkflowResultType.EndRun; nextNode: WorkflowNode }
| { type: WorkflowResultType.EndRun; nextNode?: WorkflowNode }
| {
type: WorkflowResultType.Params;
params: Record<string, string | string[]>;
@@ -42,8 +54,6 @@ export type WorkflowResult =
type: WorkflowResultType.Content;
nodeId: string;
content: string;
// if is the end of the workflow, pass through the content to stream response
passthrough?: boolean;
};
export type WorkflowGraph = Map<string, WorkflowNode>;
export type WorkflowGraphInstances = Map<string, WorkflowNode>;

View File

@@ -0,0 +1,11 @@
import 'ses';
lockdown();
const sandbox = new Compartment();
export default ({ iife, nodeIds, params }) => {
sandbox.globalThis.nodeIds = harden(nodeIds);
sandbox.globalThis.params = harden(params);
return sandbox.evaluate(iife);
};

View File

@@ -1,12 +1,10 @@
import { Logger } from '@nestjs/common';
import { PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { CopilotChatOptions } from '../types';
import { WorkflowNode } from './node';
import {
WorkflowGraph,
WorkflowNodeState,
type WorkflowGraphInstances,
type WorkflowNodeState,
WorkflowNodeType,
WorkflowResultType,
} from './types';
@@ -15,11 +13,7 @@ export class CopilotWorkflow {
private readonly logger = new Logger(CopilotWorkflow.name);
private readonly rootNode: WorkflowNode;
constructor(
private readonly prompt: PromptService,
private readonly provider: CopilotProviderService,
workflow: WorkflowGraph
) {
constructor(workflow: WorkflowGraphInstances) {
const startNode = workflow.get('start');
if (!startNode) {
throw new Error(`No start node found in graph`);
@@ -38,8 +32,6 @@ export class CopilotWorkflow {
let result = '';
let nextNode: WorkflowNode | undefined;
await currentNode.initNode(this.prompt, this.provider);
for await (const ret of currentNode.next(lastParams, options)) {
if (ret.type === WorkflowResultType.EndRun) {
nextNode = ret.nextNode;
@@ -53,8 +45,8 @@ export class CopilotWorkflow {
);
}
} else if (ret.type === WorkflowResultType.Content) {
if (ret.passthrough) {
// pass through content as a stream response
if (!currentNode.hasEdges) {
// pass through content as a stream response if node is end node
yield ret.content;
} else {
result += ret.content;
@@ -70,7 +62,9 @@ export class CopilotWorkflow {
}
currentNode = nextNode;
if (result) lastParams.content = result;
if (result && lastParams.content !== result) {
lastParams.content = result;
}
}
}
}

View File

@@ -259,7 +259,7 @@ test('should be able to chat with api by workflow', async t => {
const ret = await chatWithWorkflow(app, token, sessionId, messageId);
t.is(
ret,
textToEventStream('generate text to text stream', messageId),
textToEventStream(['generate text to text stream'], messageId),
'should be able to chat with workflow'
);
});

View File

@@ -3,6 +3,7 @@
import { TestingModule } from '@nestjs/testing';
import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
import { AuthService } from '../src/core/auth';
import { QuotaModule } from '../src/core/quota';
@@ -21,7 +22,20 @@ import {
CopilotCapability,
CopilotProviderType,
} from '../src/plugins/copilot/types';
import { CopilotWorkflowService } from '../src/plugins/copilot/workflow';
import {
CopilotChatTextExecutor,
CopilotWorkflowService,
WorkflowNodeType,
} from '../src/plugins/copilot/workflow';
import {
getWorkflowExecutor,
WorkflowExecutorType,
} from '../src/plugins/copilot/workflow/executor';
import { WorkflowGraphList } from '../src/plugins/copilot/workflow/graph';
import {
NodeData,
WorkflowResultType,
} from '../src/plugins/copilot/workflow/types';
import { createTestingModule } from './utils';
import { MockCopilotTestProvider } from './utils/copilot';
@@ -32,6 +46,7 @@ const test = ava as TestFn<{
provider: CopilotProviderService;
session: ChatSessionService;
workflow: CopilotWorkflowService;
textWorkflowExecutor: CopilotChatTextExecutor;
}>;
test.beforeEach(async t => {
@@ -59,6 +74,7 @@ test.beforeEach(async t => {
const provider = module.get(CopilotProviderService);
const session = module.get(ChatSessionService);
const workflow = module.get(CopilotWorkflowService);
const textWorkflowExecutor = module.get(CopilotChatTextExecutor);
t.context.module = module;
t.context.auth = auth;
@@ -66,6 +82,7 @@ test.beforeEach(async t => {
t.context.provider = provider;
t.context.session = session;
t.context.workflow = workflow;
t.context.textWorkflowExecutor = textWorkflowExecutor;
});
test.afterEach.always(async t => {
@@ -541,10 +558,14 @@ test('should be able to register test provider', async t => {
await assertProvider(CopilotCapability.ImageToText);
});
// ==================== workflow ====================
// this test used to preview the final result of the workflow
// for the functional test of the API itself, refer to the follow tests
test.skip('should be able to preview workflow', async t => {
const { prompt, workflow } = t.context;
const { prompt, workflow, textWorkflowExecutor } = t.context;
textWorkflowExecutor.register();
registerCopilotProvider(OpenAIProvider);
for (const p of prompts) {
@@ -554,13 +575,174 @@ test.skip('should be able to preview workflow', async t => {
let result = '';
for await (const ret of workflow.runGraph(
{ content: 'apple company' },
'workflow:presentation'
'presentation'
)) {
result += ret;
console.log('stream result:', ret);
}
console.log('final stream result:', result);
t.truthy(result, 'should return result');
unregisterCopilotProvider(OpenAIProvider.type);
t.pass();
});
test('should be able to run workflow', async t => {
const { prompt, workflow, textWorkflowExecutor } = t.context;
textWorkflowExecutor.register();
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
const executor = Sinon.spy(textWorkflowExecutor, 'next');
for (const p of prompts) {
await prompt.set(p.name, p.model, p.messages);
}
const graphName = 'presentation';
const graph = WorkflowGraphList.find(g => g.name === graphName);
t.truthy(graph, `graph ${graphName} not defined`);
// todo: use Array.fromAsync
let result = '';
for await (const ret of workflow.runGraph(
{ content: 'apple company' },
graphName
)) {
result += ret;
}
t.assert(result, 'generate text to text stream');
// presentation workflow has condition node, it will always false
// so the latest 2 nodes will not be executed
const callCount = graph!.graph.length - 3;
t.is(
executor.callCount,
callCount,
`should call executor ${callCount} times`
);
for (const [idx, node] of graph!.graph
.filter(g => g.nodeType === WorkflowNodeType.Basic)
.entries()) {
const params = executor.getCall(idx);
if (idx < callCount) {
t.is(params.args[0].id, node.id, 'graph id should correct');
t.is(
params.args[1].content,
'generate text to text stream',
'graph params should correct'
);
t.is(
params.args[1].language,
'generate text to text',
'graph params should correct'
);
}
}
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});
// ==================== workflow executor ====================
const wrapAsyncIter = async <T>(iter: AsyncIterable<T>) => {
const result: T[] = [];
for await (const r of iter) {
result.push(r);
}
return result;
};
test('should be able to run executor', async t => {
const { textWorkflowExecutor } = t.context;
textWorkflowExecutor.register();
const executor = getWorkflowExecutor(textWorkflowExecutor.type);
t.is(executor.type, textWorkflowExecutor.type, 'should get executor');
await t.throwsAsync(
wrapAsyncIter(
executor.next(
{ id: 'nope', name: 'nope', nodeType: WorkflowNodeType.Nope },
{}
)
),
{ instanceOf: Error },
'should throw error if run non basic node'
);
});
test('should be able to run text executor', async t => {
const { textWorkflowExecutor, provider, prompt } = t.context;
textWorkflowExecutor.register();
const executor = getWorkflowExecutor(textWorkflowExecutor.type);
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set('test', 'test', [
{ role: 'system', content: 'hello {{word}}' },
]);
// mock provider
const testProvider =
(await provider.getProviderByModel<CopilotCapability.TextToText>('test'))!;
const text = Sinon.spy(testProvider, 'generateText');
const textStream = Sinon.spy(testProvider, 'generateTextStream');
const nodeData: NodeData = {
id: 'basic',
name: 'basic',
nodeType: WorkflowNodeType.Basic,
promptName: 'test',
type: WorkflowExecutorType.ChatText,
};
// text
{
const ret = await wrapAsyncIter(
executor.next({ ...nodeData, paramKey: 'key' }, { word: 'world' })
);
t.deepEqual(ret, [
{
type: WorkflowResultType.Params,
params: { key: 'generate text to text' },
},
]);
t.deepEqual(
text.lastCall.args[0][0].content,
'hello world',
'should render the prompt with params'
);
}
// text stream with attachment
{
const ret = await wrapAsyncIter(
executor.next(nodeData, {
attachments: ['https://affine.pro/example.jpg'],
})
);
t.deepEqual(
ret,
Array.from('generate text to text stream').map(t => ({
content: t,
nodeId: 'basic',
type: WorkflowResultType.Content,
}))
);
t.deepEqual(
textStream.lastCall.args[0][0].params?.attachments,
['https://affine.pro/example.jpg'],
'should pass attachments to provider'
);
}
Sinon.restore();
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});