From ca9a16b72858b39d3b204f57ba3920259a7a1611 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Fri, 7 Jun 2024 05:53:44 +0000 Subject: [PATCH] feat: add workflow resolver (#7123) fix AFF-1166 --- .../src/data/migrations/utils/prompts.ts | 23 ++-- .../server/src/plugins/copilot/controller.ts | 67 ++++++++++ .../src/plugins/copilot/workflow/graph.ts | 12 +- .../src/plugins/copilot/workflow/index.ts | 8 +- .../src/plugins/copilot/workflow/node.ts | 119 ++++++++++-------- .../src/plugins/copilot/workflow/workflow.ts | 12 +- packages/backend/server/tests/copilot.e2e.ts | 26 ++++ packages/backend/server/tests/copilot.spec.ts | 4 +- .../backend/server/tests/utils/copilot.ts | 10 ++ 9 files changed, 203 insertions(+), 78 deletions(-) diff --git a/packages/backend/server/src/data/migrations/utils/prompts.ts b/packages/backend/server/src/data/migrations/utils/prompts.ts index 288715ed20..c4bdd7c866 100644 --- a/packages/backend/server/src/data/migrations/utils/prompts.ts +++ b/packages/backend/server/src/data/migrations/utils/prompts.ts @@ -455,8 +455,15 @@ content: {{content}}`, ], }, { - name: 'Create a presentation:step1', - action: 'Create a presentation:step1', + name: 'workflow:presentation', + action: 'workflow:presentation', + // used only in workflow, point to workflow graph name + model: 'presentation', + messages: [], + }, + { + name: 'workflow:presentation:step1', + action: 'workflow:presentation:step1', model: 'gpt-4o', messages: [ { @@ -471,8 +478,8 @@ content: {{content}}`, ], }, { - name: 'Create a presentation:step2', - action: 'Create a presentation:step2', + name: 'workflow:presentation:step2', + action: 'workflow:presentation:step2', model: 'gpt-4o', messages: [ { @@ -491,8 +498,8 @@ content: {{content}}`, ], }, { - name: 'Create a presentation:step3', - action: 'Create a presentation:step3', + name: 'workflow:presentation:step3', + action: 'workflow:presentation:step3', model: 'gpt-4o', messages: [ { @@ -507,8 +514,8 @@ content: {{content}}`, ], }, { - name: 'Create a presentation:step4', - action: 'Create a presentation:step4', + name: 'workflow:presentation:step4', + action: 'workflow:presentation:step4', model: 'gpt-4o', messages: [ { diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index a6713a118c..9bbe6c1a64 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -35,6 +35,7 @@ import { CopilotProviderService } from './providers'; import { ChatSession, ChatSessionService } from './session'; import { CopilotStorage } from './storage'; import { CopilotCapability, CopilotTextProvider } from './types'; +import { CopilotWorkflowService } from './workflow'; export interface ChatEvent { type: 'attachment' | 'message' | 'error'; @@ -55,6 +56,7 @@ export class CopilotController { private readonly config: Config, private readonly chatSession: ChatSessionService, private readonly provider: CopilotProviderService, + private readonly workflow: CopilotWorkflowService, private readonly storage: CopilotStorage ) {} @@ -266,6 +268,71 @@ export class CopilotController { } } + @Sse('/chat/:sessionId/workflow') + async chatWorkflow( + @CurrentUser() user: CurrentUser, + @Req() req: Request, + @Param('sessionId') sessionId: string, + @Query() params: Record + ): Promise> { + try { + const messageId = Array.isArray(params.messageId) + ? params.messageId[0] + : params.messageId; + + const session = await this.appendSessionMessage(sessionId, messageId); + delete params.messageId; + const latestMessage = session.stashMessages.findLast( + m => m.role === 'user' + ); + if (latestMessage) { + params = Object.assign({}, params, latestMessage.params, { + content: latestMessage.content, + }); + } + + return from( + this.workflow.runGraph(params, session.model, { + signal: this.getSignal(req), + user: user.id, + }) + ).pipe( + connect(shared$ => + merge( + // actual chat event stream + shared$.pipe( + map(data => ({ type: 'message' as const, id: messageId, data })) + ), + // save the generated text to the session + shared$.pipe( + toArray(), + concatMap(values => { + session.push({ + role: 'assistant', + content: values.join(''), + createdAt: new Date(), + }); + return from(session.save()); + }), + switchMap(() => EMPTY) + ) + ) + ), + catchError(err => + of({ + type: 'error' as const, + data: this.handleError(err), + }) + ) + ); + } catch (err) { + return of({ + type: 'error' as const, + data: this.handleError(err), + }); + } + } + @Sse('/chat/:sessionId/images') async chatImagesStream( @CurrentUser() user: CurrentUser, diff --git a/packages/backend/server/src/plugins/copilot/workflow/graph.ts b/packages/backend/server/src/plugins/copilot/workflow/graph.ts index 4858e2f794..2b995b03e1 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/graph.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/graph.ts @@ -2,14 +2,14 @@ import { type WorkflowGraphList, WorkflowNodeType } from './types'; export const WorkflowGraphs: WorkflowGraphList = [ { - name: 'Create a presentation', + name: 'presentation', graph: [ { id: 'start', name: 'Start: check language', nodeType: WorkflowNodeType.Basic, type: 'text', - promptName: 'Create a presentation:step1', + promptName: 'workflow:presentation:step1', paramKey: 'language', edges: ['step2'], }, @@ -18,7 +18,7 @@ export const WorkflowGraphs: WorkflowGraphList = [ name: 'Step 2: generate presentation', nodeType: WorkflowNodeType.Basic, type: 'text', - promptName: 'Create a presentation:step2', + promptName: 'workflow:presentation:step2', edges: [], // edges: ['step3'], }, @@ -27,7 +27,7 @@ export const WorkflowGraphs: WorkflowGraphList = [ // name: 'Step 3: check format', // nodeType: WorkflowNodeType.Basic, // type: 'text', - // promptName: 'Create a presentation:step3', + // promptName: 'workflow:presentation:step3', // paramKey: 'needFormat', // edges: ['step4'], // }, @@ -49,7 +49,7 @@ export const WorkflowGraphs: WorkflowGraphList = [ // name: 'Step 5: format presentation', // nodeType: WorkflowNodeType.Basic, // type: 'text', - // promptName: 'Create a presentation:step5', + // promptName: 'workflow:presentation:step5', // edges: ['step6'], // }, // { @@ -57,7 +57,7 @@ export const WorkflowGraphs: WorkflowGraphList = [ // name: 'Step 6: finish', // nodeType: WorkflowNodeType.Basic, // type: 'text', - // promptName: 'Create a presentation:step6', + // promptName: 'workflow:presentation:step6', // edges: [], // }, ], diff --git a/packages/backend/server/src/plugins/copilot/workflow/index.ts b/packages/backend/server/src/plugins/copilot/workflow/index.ts index 79a24fad37..675e643b31 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/index.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/index.ts @@ -2,6 +2,7 @@ import { Injectable, Logger } from '@nestjs/common'; import { PromptService } from '../prompt'; import { CopilotProviderService } from '../providers'; +import { CopilotChatOptions } from '../types'; import { WorkflowGraphs } from './graph'; import { WorkflowNode } from './node'; import { WorkflowGraph, WorkflowGraphList } from './types'; @@ -57,9 +58,10 @@ export class CopilotWorkflowService { } async *runGraph( + params: Record, graphName: string, - initContent: string - ): AsyncIterable { + options?: CopilotChatOptions + ): AsyncIterable { const workflowGraph = await this.getWorkflow(graphName); const workflow = new CopilotWorkflow( this.prompt, @@ -67,7 +69,7 @@ export class CopilotWorkflowService { workflowGraph ); - for await (const result of workflow.runGraph(initContent)) { + for await (const result of workflow.runGraph(params, options)) { yield result; } } diff --git a/packages/backend/server/src/plugins/copilot/workflow/node.ts b/packages/backend/server/src/plugins/copilot/workflow/node.ts index 42ce29b5b4..6af694d425 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/node.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/node.ts @@ -78,6 +78,44 @@ export class WorkflowNode { return this.edges[0]?.id; } + private getStreamProvider() { + if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) { + if ( + this.data.type === 'text' && + 'generateText' in this.provider && + !this.data.paramKey + ) { + return this.provider.generateTextStream.bind(this.provider); + } else if ( + this.data.type === 'image' && + 'generateImages' in this.provider && + !this.data.paramKey + ) { + return this.provider.generateImagesStream.bind(this.provider); + } + } + throw new Error(`Stream Provider not found for node ${this.name}`); + } + + private getProvider() { + if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) { + if ( + this.data.type === 'text' && + 'generateText' in this.provider && + this.data.paramKey + ) { + return this.provider.generateText.bind(this.provider); + } else if ( + this.data.type === 'image' && + 'generateImages' in this.provider && + this.data.paramKey + ) { + return this.provider.generateImages.bind(this.provider); + } + } + throw new Error(`Provider not found for node ${this.name}`); + } + async *next( params: WorkflowNodeState, options?: CopilotChatOptions @@ -100,63 +138,34 @@ export class WorkflowNode { } } } else { - // pass through content as a stream response if no next node - const passthrough = !nextNode; - if (this.data.type === 'text' && 'generateText' in this.provider) { - if (this.data.paramKey) { - // update params with custom key + const finalMessage = this.prompt.finish(params); + if (this.data.paramKey) { + const provider = this.getProvider(); + // update params with custom key + yield { + type: WorkflowResultType.Params, + params: { + [this.data.paramKey]: await provider( + finalMessage, + this.prompt.model, + options + ), + }, + }; + } else { + const provider = this.getStreamProvider(); + for await (const content of provider( + finalMessage, + this.prompt.model, + options + )) { yield { - type: WorkflowResultType.Params, - params: { - [this.data.paramKey]: await this.provider.generateText( - this.prompt.finish(params), - this.prompt.model, - options - ), - }, + type: WorkflowResultType.Content, + nodeId: this.id, + content, + // pass through content as a stream response if no next node + passthrough: !nextNode, }; - } else { - for await (const content of this.provider.generateTextStream( - this.prompt.finish(params), - this.prompt.model, - options - )) { - yield { - type: WorkflowResultType.Content, - nodeId: this.id, - content, - passthrough, - }; - } - } - } else if ( - this.data.type === 'image' && - 'generateImages' in this.provider - ) { - if (this.data.paramKey) { - yield { - type: WorkflowResultType.Params, - params: { - [this.data.paramKey]: await this.provider.generateImages( - this.prompt.finish(params), - this.prompt.model, - options - ), - }, - }; - } else { - for await (const content of this.provider.generateImagesStream( - this.prompt.finish(params), - this.prompt.model, - options - )) { - yield { - type: WorkflowResultType.Content, - nodeId: this.id, - content, - passthrough, - }; - } } } } diff --git a/packages/backend/server/src/plugins/copilot/workflow/workflow.ts b/packages/backend/server/src/plugins/copilot/workflow/workflow.ts index 7cdfaf09ca..b6011cd328 100644 --- a/packages/backend/server/src/plugins/copilot/workflow/workflow.ts +++ b/packages/backend/server/src/plugins/copilot/workflow/workflow.ts @@ -2,6 +2,7 @@ import { Logger } from '@nestjs/common'; import { PromptService } from '../prompt'; import { CopilotProviderService } from '../providers'; +import { CopilotChatOptions } from '../types'; import { WorkflowNode } from './node'; import { WorkflowGraph, @@ -26,9 +27,12 @@ export class CopilotWorkflow { this.rootNode = startNode; } - async *runGraph(initContent: string): AsyncIterable { + async *runGraph( + params: Record, + options?: CopilotChatOptions + ): AsyncIterable { let currentNode: WorkflowNode | undefined = this.rootNode; - const lastParams: WorkflowNodeState = { content: initContent }; + const lastParams: WorkflowNodeState = { ...params }; while (currentNode) { let result = ''; @@ -36,7 +40,7 @@ export class CopilotWorkflow { await currentNode.initNode(this.prompt, this.provider); - for await (const ret of currentNode.next(lastParams)) { + for await (const ret of currentNode.next(lastParams, options)) { if (ret.type === WorkflowResultType.EndRun) { nextNode = ret.nextNode; break; @@ -49,8 +53,8 @@ export class CopilotWorkflow { ); } } else if (ret.type === WorkflowResultType.Content) { - // pass through content as a stream response if (ret.passthrough) { + // pass through content as a stream response yield ret.content; } else { result += ret.content; diff --git a/packages/backend/server/tests/copilot.e2e.ts b/packages/backend/server/tests/copilot.e2e.ts index 22ce793c5c..5007688006 100644 --- a/packages/backend/server/tests/copilot.e2e.ts +++ b/packages/backend/server/tests/copilot.e2e.ts @@ -32,6 +32,7 @@ import { chatWithImages, chatWithText, chatWithTextStream, + chatWithWorkflow, createCopilotMessage, createCopilotSession, getHistories, @@ -238,6 +239,31 @@ test('should be able to chat with api', async t => { Sinon.restore(); }); +test('should be able to chat with api by workflow', async t => { + const { app } = t.context; + + const { id } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + 'workflow:presentation' + ); + const messageId = await createCopilotMessage( + app, + token, + sessionId, + 'apple company' + ); + const ret = await chatWithWorkflow(app, token, sessionId, messageId); + t.is( + ret, + textToEventStream('generate text to text stream', messageId), + 'should be able to chat with workflow' + ); +}); + test('should be able to chat with special image model', async t => { const { app, storage } = t.context; diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 416f9eb1aa..7fd9f40931 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -553,8 +553,8 @@ test.skip('should be able to preview workflow', async t => { let result = ''; for await (const ret of workflow.runGraph( - 'Create a presentation', - 'apple company' + { content: 'apple company' }, + 'workflow:presentation' )) { result += ret; console.log('stream result:', ret); diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts index 0b8641534f..bd3727cf48 100644 --- a/packages/backend/server/tests/utils/copilot.ts +++ b/packages/backend/server/tests/utils/copilot.ts @@ -33,6 +33,7 @@ export class MockCopilotTestProvider static override readonly type = CopilotProviderType.Test; override readonly availableModels = [ 'test', + 'gpt-4o', 'fast-sdxl/image-to-image', 'lcm-sd15-i2i', 'clarity-upscaler', @@ -234,6 +235,15 @@ export async function chatWithTextStream( return chatWithText(app, userToken, sessionId, messageId, '/stream'); } +export async function chatWithWorkflow( + app: INestApplication, + userToken: string, + sessionId: string, + messageId?: string +) { + return chatWithText(app, userToken, sessionId, messageId, '/workflow'); +} + export async function chatWithImages( app: INestApplication, userToken: string,