mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 04:18:54 +00:00
@@ -80,12 +80,14 @@
|
||||
"on-headers": "^1.0.2",
|
||||
"openai": "^4.33.0",
|
||||
"parse-duration": "^1.1.0",
|
||||
"piscina": "^4.5.1",
|
||||
"pretty-time": "^1.1.0",
|
||||
"prisma": "^5.12.1",
|
||||
"prom-client": "^15.1.1",
|
||||
"reflect-metadata": "^0.2.2",
|
||||
"rxjs": "^7.8.1",
|
||||
"semver": "^7.6.0",
|
||||
"ses": "^1.4.1",
|
||||
"socket.io": "^4.7.5",
|
||||
"stripe": "^15.0.0",
|
||||
"ts-node": "^10.9.2",
|
||||
|
||||
@@ -514,8 +514,8 @@ content: {{content}}`,
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'workflow:presentation:step4',
|
||||
action: 'workflow:presentation:step4',
|
||||
name: 'workflow:presentation:step5',
|
||||
action: 'workflow:presentation:step5',
|
||||
model: 'gpt-4o',
|
||||
messages: [
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ import {
|
||||
} from './resolver';
|
||||
import { ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { CopilotWorkflowService } from './workflow';
|
||||
import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow';
|
||||
|
||||
registerCopilotProvider(FalProvider);
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
@@ -41,6 +41,7 @@ registerCopilotProvider(OpenAIProvider);
|
||||
CopilotStorage,
|
||||
PromptsManagementResolver,
|
||||
CopilotWorkflowService,
|
||||
...CopilotWorkflowExecutors,
|
||||
],
|
||||
controllers: [CopilotController],
|
||||
contributesTo: ServerFeature.Copilot,
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { ChatPrompt, PromptService } from '../../prompt';
|
||||
import { CopilotProviderService } from '../../providers';
|
||||
import { CopilotChatOptions, CopilotTextProvider } from '../../types';
|
||||
import {
|
||||
NodeData,
|
||||
WorkflowNodeType,
|
||||
WorkflowResult,
|
||||
WorkflowResultType,
|
||||
} from '../types';
|
||||
import { WorkflowExecutorType } from './types';
|
||||
import { AutoRegisteredWorkflowExecutor } from './utils';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotChatTextExecutor extends AutoRegisteredWorkflowExecutor {
|
||||
constructor(
|
||||
private readonly promptService: PromptService,
|
||||
private readonly providerService: CopilotProviderService
|
||||
) {
|
||||
super();
|
||||
}
|
||||
|
||||
private async initExecutor(
|
||||
data: NodeData
|
||||
): Promise<
|
||||
[
|
||||
NodeData & { nodeType: WorkflowNodeType.Basic },
|
||||
ChatPrompt,
|
||||
CopilotTextProvider,
|
||||
]
|
||||
> {
|
||||
if (data.nodeType !== WorkflowNodeType.Basic) {
|
||||
throw new Error(
|
||||
`Executor ${this.type} not support ${data.nodeType} node`
|
||||
);
|
||||
}
|
||||
|
||||
const prompt = await this.promptService.get(data.promptName);
|
||||
if (!prompt) {
|
||||
throw new Error(
|
||||
`Prompt ${data.promptName} not found when running workflow node ${data.name}`
|
||||
);
|
||||
}
|
||||
const provider = await this.providerService.getProviderByModel(
|
||||
prompt.model
|
||||
);
|
||||
if (provider && 'generateText' in provider) {
|
||||
return [data, prompt, provider];
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`Provider not found for model ${prompt.model} when running workflow node ${data.name}`
|
||||
);
|
||||
}
|
||||
|
||||
override get type() {
|
||||
return WorkflowExecutorType.ChatText;
|
||||
}
|
||||
|
||||
override async *next(
|
||||
data: NodeData,
|
||||
params: Record<string, string>,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<WorkflowResult> {
|
||||
const [{ paramKey, id }, prompt, provider] = await this.initExecutor(data);
|
||||
|
||||
const finalMessage = prompt.finish(params);
|
||||
if (paramKey) {
|
||||
// update params with custom key
|
||||
yield {
|
||||
type: WorkflowResultType.Params,
|
||||
params: {
|
||||
[paramKey]: await provider.generateText(
|
||||
finalMessage,
|
||||
prompt.model,
|
||||
options
|
||||
),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
for await (const content of provider.generateTextStream(
|
||||
finalMessage,
|
||||
prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: id,
|
||||
content,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
import { CopilotChatTextExecutor } from './chat-text';
|
||||
|
||||
export const CopilotWorkflowExecutors = [CopilotChatTextExecutor];
|
||||
|
||||
export { type WorkflowExecutor, WorkflowExecutorType } from './types';
|
||||
export { getWorkflowExecutor } from './utils';
|
||||
export { CopilotChatTextExecutor };
|
||||
@@ -0,0 +1,15 @@
|
||||
import { CopilotChatOptions } from '../../types';
|
||||
import { NodeData, WorkflowResult } from '../types';
|
||||
|
||||
export enum WorkflowExecutorType {
|
||||
ChatText = 'ChatText',
|
||||
}
|
||||
|
||||
export abstract class WorkflowExecutor {
|
||||
abstract get type(): WorkflowExecutorType;
|
||||
abstract next(
|
||||
data: NodeData,
|
||||
params: Record<string, string | string[]>,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<WorkflowResult>;
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
import { Logger, OnModuleInit } from '@nestjs/common';
|
||||
|
||||
import { WorkflowExecutor, type WorkflowExecutorType } from './types';
|
||||
|
||||
const WORKFLOW_EXECUTOR: Map<string, WorkflowExecutor> = new Map();
|
||||
|
||||
function registerWorkflowExecutor(e: WorkflowExecutor) {
|
||||
const existing = WORKFLOW_EXECUTOR.get(e.type);
|
||||
if (existing && existing === e) return false;
|
||||
WORKFLOW_EXECUTOR.set(e.type, e);
|
||||
return true;
|
||||
}
|
||||
|
||||
export function getWorkflowExecutor(
|
||||
type: WorkflowExecutorType
|
||||
): WorkflowExecutor {
|
||||
const executor = WORKFLOW_EXECUTOR.get(type);
|
||||
if (!executor) {
|
||||
throw new Error(`Executor ${type} not defined`);
|
||||
}
|
||||
|
||||
return executor;
|
||||
}
|
||||
|
||||
export abstract class AutoRegisteredWorkflowExecutor
|
||||
extends WorkflowExecutor
|
||||
implements OnModuleInit
|
||||
{
|
||||
onModuleInit() {
|
||||
this.register();
|
||||
}
|
||||
|
||||
register() {
|
||||
if (registerWorkflowExecutor(this)) {
|
||||
new Logger(`CopilotWorkflowExecutor:${this.type}`).log(
|
||||
'Workflow executor registered.'
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
import { type WorkflowGraphList, WorkflowNodeType } from './types';
|
||||
import { WorkflowExecutorType } from './executor';
|
||||
import type { WorkflowGraphs } from './types';
|
||||
import { WorkflowNodeState, WorkflowNodeType } from './types';
|
||||
|
||||
export const WorkflowGraphs: WorkflowGraphList = [
|
||||
export const WorkflowGraphList: WorkflowGraphs = [
|
||||
{
|
||||
name: 'presentation',
|
||||
graph: [
|
||||
@@ -8,7 +10,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
|
||||
id: 'start',
|
||||
name: 'Start: check language',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: 'text',
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step1',
|
||||
paramKey: 'language',
|
||||
edges: ['step2'],
|
||||
@@ -17,49 +19,41 @@ export const WorkflowGraphs: WorkflowGraphList = [
|
||||
id: 'step2',
|
||||
name: 'Step 2: generate presentation',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: 'text',
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step2',
|
||||
edges: [],
|
||||
// edges: ['step3'],
|
||||
edges: ['step3'],
|
||||
},
|
||||
{
|
||||
id: 'step3',
|
||||
name: 'Step 3: check format',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step3',
|
||||
paramKey: 'needFormat',
|
||||
edges: ['step4'],
|
||||
},
|
||||
{
|
||||
id: 'step4',
|
||||
name: 'Step 4: format presentation if needed',
|
||||
nodeType: WorkflowNodeType.Decision,
|
||||
condition: (nodeIds: string[], params: WorkflowNodeState) =>
|
||||
nodeIds[Number(String(params.needFormat).toLowerCase() !== 'true')],
|
||||
edges: ['step5', 'step6'],
|
||||
},
|
||||
{
|
||||
id: 'step5',
|
||||
name: 'Step 5: format presentation',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
promptName: 'workflow:presentation:step5',
|
||||
edges: ['step6'],
|
||||
},
|
||||
{
|
||||
id: 'step6',
|
||||
name: 'Step 6: finish',
|
||||
nodeType: WorkflowNodeType.Nope,
|
||||
edges: [],
|
||||
},
|
||||
// {
|
||||
// id: 'step3',
|
||||
// name: 'Step 3: check format',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'workflow:presentation:step3',
|
||||
// paramKey: 'needFormat',
|
||||
// edges: ['step4'],
|
||||
// },
|
||||
// {
|
||||
// id: 'step4',
|
||||
// name: 'Step 4: format presentation if needed',
|
||||
// nodeType: WorkflowNodeType.Decision,
|
||||
// condition: ((
|
||||
// nodeIds: string[],
|
||||
// params: WorkflowNodeState
|
||||
// ) =>
|
||||
// nodeIds[
|
||||
// Number(String(params.needFormat).toLowerCase() === 'true')
|
||||
// ]).toString(),
|
||||
// edges: ['step5', 'step6'],
|
||||
// },
|
||||
// {
|
||||
// id: 'step5',
|
||||
// name: 'Step 5: format presentation',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'workflow:presentation:step5',
|
||||
// edges: ['step6'],
|
||||
// },
|
||||
// {
|
||||
// id: 'step6',
|
||||
// name: 'Step 6: finish',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'workflow:presentation:step6',
|
||||
// edges: [],
|
||||
// },
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
@@ -1,31 +1,26 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderService } from '../providers';
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { WorkflowGraphs } from './graph';
|
||||
import { WorkflowGraphList } from './graph';
|
||||
import { WorkflowNode } from './node';
|
||||
import { WorkflowGraph, WorkflowGraphList } from './types';
|
||||
import type { WorkflowGraph, WorkflowGraphInstances } from './types';
|
||||
import { CopilotWorkflow } from './workflow';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotWorkflowService {
|
||||
private readonly logger = new Logger(CopilotWorkflowService.name);
|
||||
constructor(
|
||||
private readonly prompt: PromptService,
|
||||
private readonly provider: CopilotProviderService
|
||||
) {}
|
||||
constructor() {}
|
||||
|
||||
private initWorkflow({ name, graph }: WorkflowGraphList[number]) {
|
||||
const workflow = new Map();
|
||||
for (const nodeData of graph) {
|
||||
private initWorkflow(graph: WorkflowGraph) {
|
||||
const workflow = new Map<string, WorkflowNode>();
|
||||
for (const nodeData of graph.graph) {
|
||||
const { edges: _, ...data } = nodeData;
|
||||
const node = new WorkflowNode(data);
|
||||
const node = new WorkflowNode(graph, data);
|
||||
workflow.set(node.id, node);
|
||||
}
|
||||
|
||||
// add edges
|
||||
for (const nodeData of graph) {
|
||||
for (const nodeData of graph.graph) {
|
||||
const node = workflow.get(nodeData.id);
|
||||
if (!node) {
|
||||
this.logger.error(
|
||||
@@ -47,9 +42,11 @@ export class CopilotWorkflowService {
|
||||
return workflow;
|
||||
}
|
||||
|
||||
// TODO(@darksky): get workflow from database
|
||||
private async getWorkflow(graphName: string): Promise<WorkflowGraph> {
|
||||
const graph = WorkflowGraphs.find(g => g.name === graphName);
|
||||
// TODO(@darkskygit): get workflow from database
|
||||
private async getWorkflow(
|
||||
graphName: string
|
||||
): Promise<WorkflowGraphInstances> {
|
||||
const graph = WorkflowGraphList.find(g => g.name === graphName);
|
||||
if (!graph) {
|
||||
throw new Error(`Graph ${graphName} not found`);
|
||||
}
|
||||
@@ -63,14 +60,13 @@ export class CopilotWorkflowService {
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string> {
|
||||
const workflowGraph = await this.getWorkflow(graphName);
|
||||
const workflow = new CopilotWorkflow(
|
||||
this.prompt,
|
||||
this.provider,
|
||||
workflowGraph
|
||||
);
|
||||
const workflow = new CopilotWorkflow(workflowGraph);
|
||||
|
||||
for await (const result of workflow.runGraph(params, options)) {
|
||||
yield result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export { CopilotChatTextExecutor, CopilotWorkflowExecutors } from './executor';
|
||||
export { WorkflowNodeType } from './types';
|
||||
|
||||
@@ -1,21 +1,73 @@
|
||||
import { ChatPrompt, PromptService } from '../prompt';
|
||||
import { CopilotProviderService } from '../providers';
|
||||
import { CopilotAllProvider, CopilotChatOptions } from '../types';
|
||||
import {
|
||||
import path, { dirname } from 'node:path';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
|
||||
import { Logger } from '@nestjs/common';
|
||||
import Piscina from 'piscina';
|
||||
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { getWorkflowExecutor, WorkflowExecutor } from './executor';
|
||||
import type {
|
||||
NodeData,
|
||||
WorkflowGraph,
|
||||
WorkflowNodeState,
|
||||
WorkflowNodeType,
|
||||
WorkflowResult,
|
||||
WorkflowResultType,
|
||||
} from './types';
|
||||
import { WorkflowNodeType, WorkflowResultType } from './types';
|
||||
|
||||
export class WorkflowNode {
|
||||
private readonly logger = new Logger(WorkflowNode.name);
|
||||
private readonly edges: WorkflowNode[] = [];
|
||||
private readonly parents: WorkflowNode[] = [];
|
||||
private prompt: ChatPrompt | null = null;
|
||||
private provider: CopilotAllProvider | null = null;
|
||||
private readonly executor: WorkflowExecutor | null = null;
|
||||
private readonly condition:
|
||||
| ((params: WorkflowNodeState) => Promise<any>)
|
||||
| null = null;
|
||||
|
||||
constructor(private readonly data: NodeData) {}
|
||||
constructor(
|
||||
graph: WorkflowGraph,
|
||||
private readonly data: NodeData
|
||||
) {
|
||||
if (data.nodeType === WorkflowNodeType.Basic) {
|
||||
this.executor = getWorkflowExecutor(data.type);
|
||||
} else if (data.nodeType === WorkflowNodeType.Decision) {
|
||||
// prepare decision condition, reused in each run
|
||||
const iife = `(${data.condition})(nodeIds, params)`;
|
||||
// only eval the condition in worker if graph has been modified
|
||||
if (graph.modified) {
|
||||
const worker = new Piscina({
|
||||
filename: path.resolve(
|
||||
dirname(fileURLToPath(import.meta.url)),
|
||||
'worker.mjs'
|
||||
),
|
||||
minThreads: 2,
|
||||
// empty envs from parent process
|
||||
env: {},
|
||||
argv: [],
|
||||
execArgv: [],
|
||||
});
|
||||
this.condition = (params: WorkflowNodeState) =>
|
||||
worker.run({
|
||||
iife,
|
||||
nodeIds: this.edges.map(node => node.id),
|
||||
params,
|
||||
});
|
||||
} else {
|
||||
const func =
|
||||
typeof data.condition === 'function'
|
||||
? data.condition
|
||||
: new Function(
|
||||
'nodeIds',
|
||||
'params',
|
||||
`(${data.condition})(nodeIds, params)`
|
||||
);
|
||||
this.condition = (params: WorkflowNodeState) =>
|
||||
func(
|
||||
this.edges.map(node => node.id),
|
||||
params
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get id(): string {
|
||||
return this.data.id;
|
||||
@@ -33,6 +85,11 @@ export class WorkflowNode {
|
||||
return this.parents;
|
||||
}
|
||||
|
||||
// if is the end of the workflow, pass through the content to stream response
|
||||
get hasEdges(): boolean {
|
||||
return !!this.edges.length;
|
||||
}
|
||||
|
||||
private set parent(node: WorkflowNode) {
|
||||
if (!this.parents.includes(node)) {
|
||||
this.parents.push(node);
|
||||
@@ -44,7 +101,10 @@ export class WorkflowNode {
|
||||
if (this.edges.length > 0) {
|
||||
throw new Error(`Basic block can only have one edge`);
|
||||
}
|
||||
} else if (!this.data.condition) {
|
||||
} else if (
|
||||
this.data.nodeType === WorkflowNodeType.Decision &&
|
||||
!this.data.condition
|
||||
) {
|
||||
throw new Error(`Decision block must have a condition`);
|
||||
}
|
||||
node.parent = this;
|
||||
@@ -52,84 +112,34 @@ export class WorkflowNode {
|
||||
return this.edges.length;
|
||||
}
|
||||
|
||||
async initNode(prompt: PromptService, provider: CopilotProviderService) {
|
||||
if (this.prompt && this.provider) return;
|
||||
|
||||
if (this.data.nodeType === WorkflowNodeType.Basic) {
|
||||
this.prompt = await prompt.get(this.data.promptName);
|
||||
if (!this.prompt) {
|
||||
throw new Error(
|
||||
`Prompt ${this.data.promptName} not found when running workflow node ${this.name}`
|
||||
);
|
||||
}
|
||||
this.provider = await provider.getProviderByModel(this.prompt.model);
|
||||
if (!this.provider) {
|
||||
throw new Error(
|
||||
`Provider not found for model ${this.prompt.model} when running workflow node ${this.name}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async evaluateCondition(
|
||||
_condition?: string
|
||||
params: WorkflowNodeState
|
||||
): Promise<string | undefined> {
|
||||
// TODO(@darksky): evaluate condition to impl decision block
|
||||
return this.edges[0]?.id;
|
||||
}
|
||||
|
||||
private getStreamProvider() {
|
||||
if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) {
|
||||
if (
|
||||
this.data.type === 'text' &&
|
||||
'generateText' in this.provider &&
|
||||
!this.data.paramKey
|
||||
) {
|
||||
return this.provider.generateTextStream.bind(this.provider);
|
||||
} else if (
|
||||
this.data.type === 'image' &&
|
||||
'generateImages' in this.provider &&
|
||||
!this.data.paramKey
|
||||
) {
|
||||
return this.provider.generateImagesStream.bind(this.provider);
|
||||
}
|
||||
// early return if no edges
|
||||
if (this.edges.length === 0) return undefined;
|
||||
try {
|
||||
const result = await this.condition?.(params);
|
||||
if (typeof result === 'string') return result;
|
||||
// choose default edge if condition falsy
|
||||
return this.edges[0].id;
|
||||
} catch (e) {
|
||||
this.logger.error(
|
||||
`Failed to evaluate condition for node ${this.name}: ${e}`
|
||||
);
|
||||
throw e;
|
||||
}
|
||||
throw new Error(`Stream Provider not found for node ${this.name}`);
|
||||
}
|
||||
|
||||
private getProvider() {
|
||||
if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) {
|
||||
if (
|
||||
this.data.type === 'text' &&
|
||||
'generateText' in this.provider &&
|
||||
this.data.paramKey
|
||||
) {
|
||||
return this.provider.generateText.bind(this.provider);
|
||||
} else if (
|
||||
this.data.type === 'image' &&
|
||||
'generateImages' in this.provider &&
|
||||
this.data.paramKey
|
||||
) {
|
||||
return this.provider.generateImages.bind(this.provider);
|
||||
}
|
||||
}
|
||||
throw new Error(`Provider not found for node ${this.name}`);
|
||||
}
|
||||
|
||||
async *next(
|
||||
params: WorkflowNodeState,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<WorkflowResult> {
|
||||
if (!this.prompt || !this.provider) {
|
||||
throw new Error(`Node ${this.name} not initialized`);
|
||||
}
|
||||
|
||||
yield { type: WorkflowResultType.StartRun, nodeId: this.id };
|
||||
|
||||
// choose next node in graph
|
||||
let nextNode: WorkflowNode | undefined = this.edges[0];
|
||||
if (this.data.nodeType === WorkflowNodeType.Decision) {
|
||||
const nextNodeId = await this.evaluateCondition(this.data.condition);
|
||||
const nextNodeId = await this.evaluateCondition(params);
|
||||
// return empty to choose default edge
|
||||
if (nextNodeId) {
|
||||
nextNode = this.edges.find(node => node.id === nextNodeId);
|
||||
@@ -137,37 +147,18 @@ export class WorkflowNode {
|
||||
throw new Error(`No edge found for condition ${this.data.condition}`);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const finalMessage = this.prompt.finish(params);
|
||||
if (this.data.paramKey) {
|
||||
const provider = this.getProvider();
|
||||
// update params with custom key
|
||||
yield {
|
||||
type: WorkflowResultType.Params,
|
||||
params: {
|
||||
[this.data.paramKey]: await provider(
|
||||
finalMessage,
|
||||
this.prompt.model,
|
||||
options
|
||||
),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
const provider = this.getStreamProvider();
|
||||
for await (const content of provider(
|
||||
finalMessage,
|
||||
this.prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: this.id,
|
||||
content,
|
||||
// pass through content as a stream response if no next node
|
||||
passthrough: !nextNode,
|
||||
};
|
||||
}
|
||||
} else if (this.data.nodeType === WorkflowNodeType.Basic) {
|
||||
if (!this.executor) {
|
||||
throw new Error(`Node ${this.name} not initialized`);
|
||||
}
|
||||
|
||||
yield* this.executor.next(this.data, params, options);
|
||||
} else {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: this.id,
|
||||
content: params.content,
|
||||
};
|
||||
}
|
||||
|
||||
yield { type: WorkflowResultType.EndRun, nextNode };
|
||||
|
||||
@@ -1,28 +1,40 @@
|
||||
import type { WorkflowExecutorType } from './executor';
|
||||
import type { WorkflowNode } from './node';
|
||||
|
||||
export enum WorkflowNodeType {
|
||||
Basic,
|
||||
Decision,
|
||||
Basic = 'basic',
|
||||
Decision = 'decision',
|
||||
Nope = 'nope',
|
||||
}
|
||||
|
||||
export type NodeData = { id: string; name: string } & (
|
||||
| {
|
||||
nodeType: WorkflowNodeType.Basic;
|
||||
promptName: string;
|
||||
type: 'text' | 'image';
|
||||
type: WorkflowExecutorType;
|
||||
// update the prompt params by output with the custom key
|
||||
paramKey?: string;
|
||||
}
|
||||
| { nodeType: WorkflowNodeType.Decision; condition: string }
|
||||
| {
|
||||
nodeType: WorkflowNodeType.Decision;
|
||||
condition:
|
||||
| ((nodeIds: string[], params: WorkflowNodeState) => string)
|
||||
| string;
|
||||
}
|
||||
// do nothing node
|
||||
| { nodeType: WorkflowNodeType.Nope }
|
||||
);
|
||||
|
||||
export type WorkflowNodeState = Record<string, string>;
|
||||
|
||||
export type WorkflowGraphData = Array<NodeData & { edges: string[] }>;
|
||||
export type WorkflowGraphList = Array<{
|
||||
export type WorkflowGraph = {
|
||||
name: string;
|
||||
// true if the graph has been modified
|
||||
modified?: boolean;
|
||||
graph: WorkflowGraphData;
|
||||
}>;
|
||||
};
|
||||
export type WorkflowGraphs = Array<WorkflowGraph>;
|
||||
|
||||
export enum WorkflowResultType {
|
||||
StartRun,
|
||||
@@ -33,7 +45,7 @@ export enum WorkflowResultType {
|
||||
|
||||
export type WorkflowResult =
|
||||
| { type: WorkflowResultType.StartRun; nodeId: string }
|
||||
| { type: WorkflowResultType.EndRun; nextNode: WorkflowNode }
|
||||
| { type: WorkflowResultType.EndRun; nextNode?: WorkflowNode }
|
||||
| {
|
||||
type: WorkflowResultType.Params;
|
||||
params: Record<string, string | string[]>;
|
||||
@@ -42,8 +54,6 @@ export type WorkflowResult =
|
||||
type: WorkflowResultType.Content;
|
||||
nodeId: string;
|
||||
content: string;
|
||||
// if is the end of the workflow, pass through the content to stream response
|
||||
passthrough?: boolean;
|
||||
};
|
||||
|
||||
export type WorkflowGraph = Map<string, WorkflowNode>;
|
||||
export type WorkflowGraphInstances = Map<string, WorkflowNode>;
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import 'ses';
|
||||
|
||||
lockdown();
|
||||
|
||||
const sandbox = new Compartment();
|
||||
|
||||
export default ({ iife, nodeIds, params }) => {
|
||||
sandbox.globalThis.nodeIds = harden(nodeIds);
|
||||
sandbox.globalThis.params = harden(params);
|
||||
return sandbox.evaluate(iife);
|
||||
};
|
||||
@@ -1,12 +1,10 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderService } from '../providers';
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { WorkflowNode } from './node';
|
||||
import {
|
||||
WorkflowGraph,
|
||||
WorkflowNodeState,
|
||||
type WorkflowGraphInstances,
|
||||
type WorkflowNodeState,
|
||||
WorkflowNodeType,
|
||||
WorkflowResultType,
|
||||
} from './types';
|
||||
@@ -15,11 +13,7 @@ export class CopilotWorkflow {
|
||||
private readonly logger = new Logger(CopilotWorkflow.name);
|
||||
private readonly rootNode: WorkflowNode;
|
||||
|
||||
constructor(
|
||||
private readonly prompt: PromptService,
|
||||
private readonly provider: CopilotProviderService,
|
||||
workflow: WorkflowGraph
|
||||
) {
|
||||
constructor(workflow: WorkflowGraphInstances) {
|
||||
const startNode = workflow.get('start');
|
||||
if (!startNode) {
|
||||
throw new Error(`No start node found in graph`);
|
||||
@@ -38,8 +32,6 @@ export class CopilotWorkflow {
|
||||
let result = '';
|
||||
let nextNode: WorkflowNode | undefined;
|
||||
|
||||
await currentNode.initNode(this.prompt, this.provider);
|
||||
|
||||
for await (const ret of currentNode.next(lastParams, options)) {
|
||||
if (ret.type === WorkflowResultType.EndRun) {
|
||||
nextNode = ret.nextNode;
|
||||
@@ -53,8 +45,8 @@ export class CopilotWorkflow {
|
||||
);
|
||||
}
|
||||
} else if (ret.type === WorkflowResultType.Content) {
|
||||
if (ret.passthrough) {
|
||||
// pass through content as a stream response
|
||||
if (!currentNode.hasEdges) {
|
||||
// pass through content as a stream response if node is end node
|
||||
yield ret.content;
|
||||
} else {
|
||||
result += ret.content;
|
||||
@@ -70,7 +62,9 @@ export class CopilotWorkflow {
|
||||
}
|
||||
|
||||
currentNode = nextNode;
|
||||
if (result) lastParams.content = result;
|
||||
if (result && lastParams.content !== result) {
|
||||
lastParams.content = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,7 +259,7 @@ test('should be able to chat with api by workflow', async t => {
|
||||
const ret = await chatWithWorkflow(app, token, sessionId, messageId);
|
||||
t.is(
|
||||
ret,
|
||||
textToEventStream('generate text to text stream', messageId),
|
||||
textToEventStream(['generate text to text stream'], messageId),
|
||||
'should be able to chat with workflow'
|
||||
);
|
||||
});
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { TestingModule } from '@nestjs/testing';
|
||||
import type { TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { AuthService } from '../src/core/auth';
|
||||
import { QuotaModule } from '../src/core/quota';
|
||||
@@ -21,7 +22,20 @@ import {
|
||||
CopilotCapability,
|
||||
CopilotProviderType,
|
||||
} from '../src/plugins/copilot/types';
|
||||
import { CopilotWorkflowService } from '../src/plugins/copilot/workflow';
|
||||
import {
|
||||
CopilotChatTextExecutor,
|
||||
CopilotWorkflowService,
|
||||
WorkflowNodeType,
|
||||
} from '../src/plugins/copilot/workflow';
|
||||
import {
|
||||
getWorkflowExecutor,
|
||||
WorkflowExecutorType,
|
||||
} from '../src/plugins/copilot/workflow/executor';
|
||||
import { WorkflowGraphList } from '../src/plugins/copilot/workflow/graph';
|
||||
import {
|
||||
NodeData,
|
||||
WorkflowResultType,
|
||||
} from '../src/plugins/copilot/workflow/types';
|
||||
import { createTestingModule } from './utils';
|
||||
import { MockCopilotTestProvider } from './utils/copilot';
|
||||
|
||||
@@ -32,6 +46,7 @@ const test = ava as TestFn<{
|
||||
provider: CopilotProviderService;
|
||||
session: ChatSessionService;
|
||||
workflow: CopilotWorkflowService;
|
||||
textWorkflowExecutor: CopilotChatTextExecutor;
|
||||
}>;
|
||||
|
||||
test.beforeEach(async t => {
|
||||
@@ -59,6 +74,7 @@ test.beforeEach(async t => {
|
||||
const provider = module.get(CopilotProviderService);
|
||||
const session = module.get(ChatSessionService);
|
||||
const workflow = module.get(CopilotWorkflowService);
|
||||
const textWorkflowExecutor = module.get(CopilotChatTextExecutor);
|
||||
|
||||
t.context.module = module;
|
||||
t.context.auth = auth;
|
||||
@@ -66,6 +82,7 @@ test.beforeEach(async t => {
|
||||
t.context.provider = provider;
|
||||
t.context.session = session;
|
||||
t.context.workflow = workflow;
|
||||
t.context.textWorkflowExecutor = textWorkflowExecutor;
|
||||
});
|
||||
|
||||
test.afterEach.always(async t => {
|
||||
@@ -541,10 +558,14 @@ test('should be able to register test provider', async t => {
|
||||
await assertProvider(CopilotCapability.ImageToText);
|
||||
});
|
||||
|
||||
// ==================== workflow ====================
|
||||
|
||||
// this test used to preview the final result of the workflow
|
||||
// for the functional test of the API itself, refer to the follow tests
|
||||
test.skip('should be able to preview workflow', async t => {
|
||||
const { prompt, workflow } = t.context;
|
||||
const { prompt, workflow, textWorkflowExecutor } = t.context;
|
||||
|
||||
textWorkflowExecutor.register();
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
|
||||
for (const p of prompts) {
|
||||
@@ -554,13 +575,174 @@ test.skip('should be able to preview workflow', async t => {
|
||||
let result = '';
|
||||
for await (const ret of workflow.runGraph(
|
||||
{ content: 'apple company' },
|
||||
'workflow:presentation'
|
||||
'presentation'
|
||||
)) {
|
||||
result += ret;
|
||||
console.log('stream result:', ret);
|
||||
}
|
||||
console.log('final stream result:', result);
|
||||
t.truthy(result, 'should return result');
|
||||
|
||||
unregisterCopilotProvider(OpenAIProvider.type);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should be able to run workflow', async t => {
|
||||
const { prompt, workflow, textWorkflowExecutor } = t.context;
|
||||
|
||||
textWorkflowExecutor.register();
|
||||
unregisterCopilotProvider(OpenAIProvider.type);
|
||||
registerCopilotProvider(MockCopilotTestProvider);
|
||||
|
||||
const executor = Sinon.spy(textWorkflowExecutor, 'next');
|
||||
|
||||
for (const p of prompts) {
|
||||
await prompt.set(p.name, p.model, p.messages);
|
||||
}
|
||||
|
||||
const graphName = 'presentation';
|
||||
const graph = WorkflowGraphList.find(g => g.name === graphName);
|
||||
t.truthy(graph, `graph ${graphName} not defined`);
|
||||
|
||||
// todo: use Array.fromAsync
|
||||
let result = '';
|
||||
for await (const ret of workflow.runGraph(
|
||||
{ content: 'apple company' },
|
||||
graphName
|
||||
)) {
|
||||
result += ret;
|
||||
}
|
||||
t.assert(result, 'generate text to text stream');
|
||||
|
||||
// presentation workflow has condition node, it will always false
|
||||
// so the latest 2 nodes will not be executed
|
||||
const callCount = graph!.graph.length - 3;
|
||||
t.is(
|
||||
executor.callCount,
|
||||
callCount,
|
||||
`should call executor ${callCount} times`
|
||||
);
|
||||
|
||||
for (const [idx, node] of graph!.graph
|
||||
.filter(g => g.nodeType === WorkflowNodeType.Basic)
|
||||
.entries()) {
|
||||
const params = executor.getCall(idx);
|
||||
|
||||
if (idx < callCount) {
|
||||
t.is(params.args[0].id, node.id, 'graph id should correct');
|
||||
|
||||
t.is(
|
||||
params.args[1].content,
|
||||
'generate text to text stream',
|
||||
'graph params should correct'
|
||||
);
|
||||
t.is(
|
||||
params.args[1].language,
|
||||
'generate text to text',
|
||||
'graph params should correct'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unregisterCopilotProvider(MockCopilotTestProvider.type);
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
});
|
||||
|
||||
// ==================== workflow executor ====================
|
||||
|
||||
const wrapAsyncIter = async <T>(iter: AsyncIterable<T>) => {
|
||||
const result: T[] = [];
|
||||
for await (const r of iter) {
|
||||
result.push(r);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
test('should be able to run executor', async t => {
|
||||
const { textWorkflowExecutor } = t.context;
|
||||
|
||||
textWorkflowExecutor.register();
|
||||
const executor = getWorkflowExecutor(textWorkflowExecutor.type);
|
||||
t.is(executor.type, textWorkflowExecutor.type, 'should get executor');
|
||||
|
||||
await t.throwsAsync(
|
||||
wrapAsyncIter(
|
||||
executor.next(
|
||||
{ id: 'nope', name: 'nope', nodeType: WorkflowNodeType.Nope },
|
||||
{}
|
||||
)
|
||||
),
|
||||
{ instanceOf: Error },
|
||||
'should throw error if run non basic node'
|
||||
);
|
||||
});
|
||||
|
||||
test('should be able to run text executor', async t => {
|
||||
const { textWorkflowExecutor, provider, prompt } = t.context;
|
||||
|
||||
textWorkflowExecutor.register();
|
||||
const executor = getWorkflowExecutor(textWorkflowExecutor.type);
|
||||
unregisterCopilotProvider(OpenAIProvider.type);
|
||||
registerCopilotProvider(MockCopilotTestProvider);
|
||||
await prompt.set('test', 'test', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
// mock provider
|
||||
const testProvider =
|
||||
(await provider.getProviderByModel<CopilotCapability.TextToText>('test'))!;
|
||||
const text = Sinon.spy(testProvider, 'generateText');
|
||||
const textStream = Sinon.spy(testProvider, 'generateTextStream');
|
||||
|
||||
const nodeData: NodeData = {
|
||||
id: 'basic',
|
||||
name: 'basic',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
promptName: 'test',
|
||||
type: WorkflowExecutorType.ChatText,
|
||||
};
|
||||
|
||||
// text
|
||||
{
|
||||
const ret = await wrapAsyncIter(
|
||||
executor.next({ ...nodeData, paramKey: 'key' }, { word: 'world' })
|
||||
);
|
||||
|
||||
t.deepEqual(ret, [
|
||||
{
|
||||
type: WorkflowResultType.Params,
|
||||
params: { key: 'generate text to text' },
|
||||
},
|
||||
]);
|
||||
t.deepEqual(
|
||||
text.lastCall.args[0][0].content,
|
||||
'hello world',
|
||||
'should render the prompt with params'
|
||||
);
|
||||
}
|
||||
|
||||
// text stream with attachment
|
||||
{
|
||||
const ret = await wrapAsyncIter(
|
||||
executor.next(nodeData, {
|
||||
attachments: ['https://affine.pro/example.jpg'],
|
||||
})
|
||||
);
|
||||
|
||||
t.deepEqual(
|
||||
ret,
|
||||
Array.from('generate text to text stream').map(t => ({
|
||||
content: t,
|
||||
nodeId: 'basic',
|
||||
type: WorkflowResultType.Content,
|
||||
}))
|
||||
);
|
||||
t.deepEqual(
|
||||
textStream.lastCall.args[0][0].params?.attachments,
|
||||
['https://affine.pro/example.jpg'],
|
||||
'should pass attachments to provider'
|
||||
);
|
||||
}
|
||||
|
||||
Sinon.restore();
|
||||
unregisterCopilotProvider(MockCopilotTestProvider.type);
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user