mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
@@ -35,6 +35,7 @@ import { CopilotProviderService } from './providers';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { CopilotCapability, CopilotTextProvider } from './types';
|
||||
import { CopilotWorkflowService } from './workflow';
|
||||
|
||||
export interface ChatEvent {
|
||||
type: 'attachment' | 'message' | 'error';
|
||||
@@ -55,6 +56,7 @@ export class CopilotController {
|
||||
private readonly config: Config,
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly provider: CopilotProviderService,
|
||||
private readonly workflow: CopilotWorkflowService,
|
||||
private readonly storage: CopilotStorage
|
||||
) {}
|
||||
|
||||
@@ -266,6 +268,71 @@ export class CopilotController {
|
||||
}
|
||||
}
|
||||
|
||||
@Sse('/chat/:sessionId/workflow')
|
||||
async chatWorkflow(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
const latestMessage = session.stashMessages.findLast(
|
||||
m => m.role === 'user'
|
||||
);
|
||||
if (latestMessage) {
|
||||
params = Object.assign({}, params, latestMessage.params, {
|
||||
content: latestMessage.content,
|
||||
});
|
||||
}
|
||||
|
||||
return from(
|
||||
this.workflow.runGraph(params, session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(data => ({ type: 'message' as const, id: messageId, data }))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(values => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: values.join(''),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
} catch (err) {
|
||||
return of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Sse('/chat/:sessionId/images')
|
||||
async chatImagesStream(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
|
||||
@@ -2,14 +2,14 @@ import { type WorkflowGraphList, WorkflowNodeType } from './types';
|
||||
|
||||
export const WorkflowGraphs: WorkflowGraphList = [
|
||||
{
|
||||
name: 'Create a presentation',
|
||||
name: 'presentation',
|
||||
graph: [
|
||||
{
|
||||
id: 'start',
|
||||
name: 'Start: check language',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: 'text',
|
||||
promptName: 'Create a presentation:step1',
|
||||
promptName: 'workflow:presentation:step1',
|
||||
paramKey: 'language',
|
||||
edges: ['step2'],
|
||||
},
|
||||
@@ -18,7 +18,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
|
||||
name: 'Step 2: generate presentation',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: 'text',
|
||||
promptName: 'Create a presentation:step2',
|
||||
promptName: 'workflow:presentation:step2',
|
||||
edges: [],
|
||||
// edges: ['step3'],
|
||||
},
|
||||
@@ -27,7 +27,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
|
||||
// name: 'Step 3: check format',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'Create a presentation:step3',
|
||||
// promptName: 'workflow:presentation:step3',
|
||||
// paramKey: 'needFormat',
|
||||
// edges: ['step4'],
|
||||
// },
|
||||
@@ -49,7 +49,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
|
||||
// name: 'Step 5: format presentation',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'Create a presentation:step5',
|
||||
// promptName: 'workflow:presentation:step5',
|
||||
// edges: ['step6'],
|
||||
// },
|
||||
// {
|
||||
@@ -57,7 +57,7 @@ export const WorkflowGraphs: WorkflowGraphList = [
|
||||
// name: 'Step 6: finish',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'Create a presentation:step6',
|
||||
// promptName: 'workflow:presentation:step6',
|
||||
// edges: [],
|
||||
// },
|
||||
],
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderService } from '../providers';
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { WorkflowGraphs } from './graph';
|
||||
import { WorkflowNode } from './node';
|
||||
import { WorkflowGraph, WorkflowGraphList } from './types';
|
||||
@@ -57,9 +58,10 @@ export class CopilotWorkflowService {
|
||||
}
|
||||
|
||||
async *runGraph(
|
||||
params: Record<string, string>,
|
||||
graphName: string,
|
||||
initContent: string
|
||||
): AsyncIterable<string | undefined> {
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string> {
|
||||
const workflowGraph = await this.getWorkflow(graphName);
|
||||
const workflow = new CopilotWorkflow(
|
||||
this.prompt,
|
||||
@@ -67,7 +69,7 @@ export class CopilotWorkflowService {
|
||||
workflowGraph
|
||||
);
|
||||
|
||||
for await (const result of workflow.runGraph(initContent)) {
|
||||
for await (const result of workflow.runGraph(params, options)) {
|
||||
yield result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,6 +78,44 @@ export class WorkflowNode {
|
||||
return this.edges[0]?.id;
|
||||
}
|
||||
|
||||
private getStreamProvider() {
|
||||
if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) {
|
||||
if (
|
||||
this.data.type === 'text' &&
|
||||
'generateText' in this.provider &&
|
||||
!this.data.paramKey
|
||||
) {
|
||||
return this.provider.generateTextStream.bind(this.provider);
|
||||
} else if (
|
||||
this.data.type === 'image' &&
|
||||
'generateImages' in this.provider &&
|
||||
!this.data.paramKey
|
||||
) {
|
||||
return this.provider.generateImagesStream.bind(this.provider);
|
||||
}
|
||||
}
|
||||
throw new Error(`Stream Provider not found for node ${this.name}`);
|
||||
}
|
||||
|
||||
private getProvider() {
|
||||
if (this.data.nodeType === WorkflowNodeType.Basic && this.provider) {
|
||||
if (
|
||||
this.data.type === 'text' &&
|
||||
'generateText' in this.provider &&
|
||||
this.data.paramKey
|
||||
) {
|
||||
return this.provider.generateText.bind(this.provider);
|
||||
} else if (
|
||||
this.data.type === 'image' &&
|
||||
'generateImages' in this.provider &&
|
||||
this.data.paramKey
|
||||
) {
|
||||
return this.provider.generateImages.bind(this.provider);
|
||||
}
|
||||
}
|
||||
throw new Error(`Provider not found for node ${this.name}`);
|
||||
}
|
||||
|
||||
async *next(
|
||||
params: WorkflowNodeState,
|
||||
options?: CopilotChatOptions
|
||||
@@ -100,63 +138,34 @@ export class WorkflowNode {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// pass through content as a stream response if no next node
|
||||
const passthrough = !nextNode;
|
||||
if (this.data.type === 'text' && 'generateText' in this.provider) {
|
||||
if (this.data.paramKey) {
|
||||
// update params with custom key
|
||||
const finalMessage = this.prompt.finish(params);
|
||||
if (this.data.paramKey) {
|
||||
const provider = this.getProvider();
|
||||
// update params with custom key
|
||||
yield {
|
||||
type: WorkflowResultType.Params,
|
||||
params: {
|
||||
[this.data.paramKey]: await provider(
|
||||
finalMessage,
|
||||
this.prompt.model,
|
||||
options
|
||||
),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
const provider = this.getStreamProvider();
|
||||
for await (const content of provider(
|
||||
finalMessage,
|
||||
this.prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield {
|
||||
type: WorkflowResultType.Params,
|
||||
params: {
|
||||
[this.data.paramKey]: await this.provider.generateText(
|
||||
this.prompt.finish(params),
|
||||
this.prompt.model,
|
||||
options
|
||||
),
|
||||
},
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: this.id,
|
||||
content,
|
||||
// pass through content as a stream response if no next node
|
||||
passthrough: !nextNode,
|
||||
};
|
||||
} else {
|
||||
for await (const content of this.provider.generateTextStream(
|
||||
this.prompt.finish(params),
|
||||
this.prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: this.id,
|
||||
content,
|
||||
passthrough,
|
||||
};
|
||||
}
|
||||
}
|
||||
} else if (
|
||||
this.data.type === 'image' &&
|
||||
'generateImages' in this.provider
|
||||
) {
|
||||
if (this.data.paramKey) {
|
||||
yield {
|
||||
type: WorkflowResultType.Params,
|
||||
params: {
|
||||
[this.data.paramKey]: await this.provider.generateImages(
|
||||
this.prompt.finish(params),
|
||||
this.prompt.model,
|
||||
options
|
||||
),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
for await (const content of this.provider.generateImagesStream(
|
||||
this.prompt.finish(params),
|
||||
this.prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: this.id,
|
||||
content,
|
||||
passthrough,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Logger } from '@nestjs/common';
|
||||
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderService } from '../providers';
|
||||
import { CopilotChatOptions } from '../types';
|
||||
import { WorkflowNode } from './node';
|
||||
import {
|
||||
WorkflowGraph,
|
||||
@@ -26,9 +27,12 @@ export class CopilotWorkflow {
|
||||
this.rootNode = startNode;
|
||||
}
|
||||
|
||||
async *runGraph(initContent: string): AsyncIterable<string | undefined> {
|
||||
async *runGraph(
|
||||
params: Record<string, string>,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string> {
|
||||
let currentNode: WorkflowNode | undefined = this.rootNode;
|
||||
const lastParams: WorkflowNodeState = { content: initContent };
|
||||
const lastParams: WorkflowNodeState = { ...params };
|
||||
|
||||
while (currentNode) {
|
||||
let result = '';
|
||||
@@ -36,7 +40,7 @@ export class CopilotWorkflow {
|
||||
|
||||
await currentNode.initNode(this.prompt, this.provider);
|
||||
|
||||
for await (const ret of currentNode.next(lastParams)) {
|
||||
for await (const ret of currentNode.next(lastParams, options)) {
|
||||
if (ret.type === WorkflowResultType.EndRun) {
|
||||
nextNode = ret.nextNode;
|
||||
break;
|
||||
@@ -49,8 +53,8 @@ export class CopilotWorkflow {
|
||||
);
|
||||
}
|
||||
} else if (ret.type === WorkflowResultType.Content) {
|
||||
// pass through content as a stream response
|
||||
if (ret.passthrough) {
|
||||
// pass through content as a stream response
|
||||
yield ret.content;
|
||||
} else {
|
||||
result += ret.content;
|
||||
|
||||
Reference in New Issue
Block a user