feat: add workflow resolver (#7123)

fix AFF-1166
This commit is contained in:
darkskygit
2024-06-07 05:53:44 +00:00
parent 44b0ea2b6c
commit ca9a16b728
9 changed files with 203 additions and 78 deletions

View File

@@ -35,6 +35,7 @@ import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotCapability, CopilotTextProvider } from './types';
import { CopilotWorkflowService } from './workflow';
export interface ChatEvent {
type: 'attachment' | 'message' | 'error';
@@ -55,6 +56,7 @@ export class CopilotController {
private readonly config: Config,
private readonly chatSession: ChatSessionService,
private readonly provider: CopilotProviderService,
private readonly workflow: CopilotWorkflowService,
private readonly storage: CopilotStorage
) {}
@@ -266,6 +268,71 @@ export class CopilotController {
}
}
@Sse('/chat/:sessionId/workflow')
async chatWorkflow(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const latestMessage = session.stashMessages.findLast(
m => m.role === 'user'
);
if (latestMessage) {
params = Object.assign({}, params, latestMessage.params, {
content: latestMessage.content,
});
}
return from(
this.workflow.runGraph(params, session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: messageId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}
}
@Sse('/chat/:sessionId/images')
async chatImagesStream(
@CurrentUser() user: CurrentUser,

View File

@@ -2,14 +2,14 @@ import { type WorkflowGraphList, WorkflowNodeType } from './types';
export const WorkflowGraphs: WorkflowGraphList = [
{
name: 'Create a presentation',
name: 'presentation',
graph: [
{
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: 'text',
promptName: 'Create a presentation:step1',
promptName: 'workflow:presentation:step1',
paramKey: 'language',
edges: ['step2'],
},
@@ -18,7 +18,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
name: 'Step 2: generate presentation',
nodeType: WorkflowNodeType.Basic,
type: 'text',
promptName: 'Create a presentation:step2',
promptName: 'workflow:presentation:step2',
edges: [],
// edges: ['step3'],
},
@@ -27,7 +27,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
// name: 'Step 3: check format',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'Create a presentation:step3',
// promptName: 'workflow:presentation:step3',
// paramKey: 'needFormat',
// edges: ['step4'],
// },
@@ -49,7 +49,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
// name: 'Step 5: format presentation',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'Create a presentation:step5',
// promptName: 'workflow:presentation:step5',
// edges: ['step6'],
// },
// {
@@ -57,7 +57,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
// name: 'Step 6: finish',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'Create a presentation:step6',
// promptName: 'workflow:presentation:step6',
// edges: [],
// },
],

View File

@@ -2,6 +2,7 @@ import { Injectable, Logger } from '@nestjs/common';
import { PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { CopilotChatOptions } from '../types';
import { WorkflowGraphs } from './graph';
import { WorkflowNode } from './node';
import { WorkflowGraph, WorkflowGraphList } from './types';
@@ -57,9 +58,10 @@ export class CopilotWorkflowService {
}
async *runGraph(
params: Record<string, string>,
graphName: string,
initContent: string
): AsyncIterable<string | undefined> {
options?: CopilotChatOptions
): AsyncIterable<string> {
const workflowGraph = await this.getWorkflow(graphName);
const workflow = new CopilotWorkflow(
this.prompt,
@@ -67,7 +69,7 @@ export class CopilotWorkflowService {
workflowGraph
);
for await (const result of workflow.runGraph(initContent)) {
for await (const result of workflow.runGraph(params, options)) {
yield result;
}
}

View File

@@ -78,6 +78,44 @@ export class WorkflowNode {
return this.edges[0]?.id;
}
private getStreamProvider() {
if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) {
if (
this.data.type === 'text' &&
'generateText' in this.provider &&
!this.data.paramKey
) {
return this.provider.generateTextStream.bind(this.provider);
} else if (
this.data.type === 'image' &&
'generateImages' in this.provider &&
!this.data.paramKey
) {
return this.provider.generateImagesStream.bind(this.provider);
}
}
throw new Error(`Stream Provider not found for node ${this.name}`);
}
private getProvider() {
if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) {
if (
this.data.type === 'text' &&
'generateText' in this.provider &&
this.data.paramKey
) {
return this.provider.generateText.bind(this.provider);
} else if (
this.data.type === 'image' &&
'generateImages' in this.provider &&
this.data.paramKey
) {
return this.provider.generateImages.bind(this.provider);
}
}
throw new Error(`Provider not found for node ${this.name}`);
}
async *next(
params: WorkflowNodeState,
options?: CopilotChatOptions
@@ -100,63 +138,34 @@ export class WorkflowNode {
}
}
} else {
// pass through content as a stream response if no next node
const passthrough = !nextNode;
if (this.data.type === 'text' && 'generateText' in this.provider) {
if (this.data.paramKey) {
// update params with custom key
const finalMessage = this.prompt.finish(params);
if (this.data.paramKey) {
const provider = this.getProvider();
// update params with custom key
yield {
type: WorkflowResultType.Params,
params: {
[this.data.paramKey]: await provider(
finalMessage,
this.prompt.model,
options
),
},
};
} else {
const provider = this.getStreamProvider();
for await (const content of provider(
finalMessage,
this.prompt.model,
options
)) {
yield {
type: WorkflowResultType.Params,
params: {
[this.data.paramKey]: await this.provider.generateText(
this.prompt.finish(params),
this.prompt.model,
options
),
},
type: WorkflowResultType.Content,
nodeId: this.id,
content,
// pass through content as a stream response if no next node
passthrough: !nextNode,
};
} else {
for await (const content of this.provider.generateTextStream(
this.prompt.finish(params),
this.prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: this.id,
content,
passthrough,
};
}
}
} else if (
this.data.type === 'image' &&
'generateImages' in this.provider
) {
if (this.data.paramKey) {
yield {
type: WorkflowResultType.Params,
params: {
[this.data.paramKey]: await this.provider.generateImages(
this.prompt.finish(params),
this.prompt.model,
options
),
},
};
} else {
for await (const content of this.provider.generateImagesStream(
this.prompt.finish(params),
this.prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: this.id,
content,
passthrough,
};
}
}
}
}

View File

@@ -2,6 +2,7 @@ import { Logger } from '@nestjs/common';
import { PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { CopilotChatOptions } from '../types';
import { WorkflowNode } from './node';
import {
WorkflowGraph,
@@ -26,9 +27,12 @@ export class CopilotWorkflow {
this.rootNode = startNode;
}
async *runGraph(initContent: string): AsyncIterable<string | undefined> {
async *runGraph(
params: Record<string, string>,
options?: CopilotChatOptions
): AsyncIterable<string> {
let currentNode: WorkflowNode | undefined = this.rootNode;
const lastParams: WorkflowNodeState = { content: initContent };
const lastParams: WorkflowNodeState = { ...params };
while (currentNode) {
let result = '';
@@ -36,7 +40,7 @@ export class CopilotWorkflow {
await currentNode.initNode(this.prompt, this.provider);
for await (const ret of currentNode.next(lastParams)) {
for await (const ret of currentNode.next(lastParams, options)) {
if (ret.type === WorkflowResultType.EndRun) {
nextNode = ret.nextNode;
break;
@@ -49,8 +53,8 @@ export class CopilotWorkflow {
);
}
} else if (ret.type === WorkflowResultType.Content) {
// pass through content as a stream response
if (ret.passthrough) {
// pass through content as a stream response
yield ret.content;
} else {
result += ret.content;