mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 20:38:52 +00:00
@@ -34,11 +34,7 @@ import { Config } from '../../fundamentals';
|
||||
import { CopilotProviderService } from './providers';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotImageToTextProvider,
|
||||
CopilotTextToTextProvider,
|
||||
} from './types';
|
||||
import { CopilotCapability, CopilotTextProvider } from './types';
|
||||
|
||||
export interface ChatEvent {
|
||||
type: 'attachment' | 'message' | 'error';
|
||||
@@ -88,7 +84,7 @@ export class CopilotController {
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
messageId?: string
|
||||
): Promise<CopilotTextToTextProvider | CopilotImageToTextProvider> {
|
||||
): Promise<CopilotTextProvider> {
|
||||
const { hasAttachment, model } = await this.checkRequest(
|
||||
userId,
|
||||
sessionId,
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
} from './resolver';
|
||||
import { ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { CopilotWorkflowService } from './workflow';
|
||||
|
||||
registerCopilotProvider(FalProvider);
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
@@ -39,6 +40,7 @@ registerCopilotProvider(OpenAIProvider);
|
||||
CopilotProviderService,
|
||||
CopilotStorage,
|
||||
PromptsManagementResolver,
|
||||
CopilotWorkflowService,
|
||||
],
|
||||
controllers: [CopilotController],
|
||||
contributesTo: ServerFeature.Copilot,
|
||||
|
||||
@@ -166,6 +166,34 @@ export class CopilotProviderService {
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async getProviderByModel<C extends CopilotCapability>(
|
||||
model: string,
|
||||
prefer?: CopilotProviderType
|
||||
): Promise<CapabilityToCopilotProvider[C] | null> {
|
||||
const providers = Array.from(COPILOT_PROVIDER.keys());
|
||||
if (providers.length) {
|
||||
let selectedProvider: CopilotProviderType | undefined = prefer;
|
||||
let currentIndex = -1;
|
||||
|
||||
if (!selectedProvider) {
|
||||
currentIndex = 0;
|
||||
selectedProvider = providers[currentIndex];
|
||||
}
|
||||
|
||||
while (selectedProvider) {
|
||||
const provider = this.getProvider(selectedProvider);
|
||||
|
||||
if (await provider.isModelAvailable(model)) {
|
||||
return provider as CapabilityToCopilotProvider[C];
|
||||
}
|
||||
|
||||
currentIndex += 1;
|
||||
selectedProvider = providers[currentIndex];
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export { FalProvider } from './fal';
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import assert from 'node:assert';
|
||||
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { ClientOptions, OpenAI } from 'openai';
|
||||
|
||||
@@ -58,12 +56,11 @@ export class OpenAIProvider
|
||||
private existsModels: string[] | undefined;
|
||||
|
||||
constructor(config: ClientOptions) {
|
||||
assert(OpenAIProvider.assetsConfig(config));
|
||||
this.instance = new OpenAI(config);
|
||||
}
|
||||
|
||||
static assetsConfig(config: ClientOptions) {
|
||||
return !!config.apiKey;
|
||||
return !!config?.apiKey;
|
||||
}
|
||||
|
||||
get type(): CopilotProviderType {
|
||||
|
||||
@@ -230,3 +230,14 @@ export type CapabilityToCopilotProvider = {
|
||||
[CopilotCapability.ImageToText]: CopilotImageToTextProvider;
|
||||
[CopilotCapability.ImageToImage]: CopilotImageToImageProvider;
|
||||
};
|
||||
|
||||
export type CopilotTextProvider =
|
||||
| CopilotTextToTextProvider
|
||||
| CopilotImageToTextProvider;
|
||||
export type CopilotImageProvider =
|
||||
| CopilotTextToImageProvider
|
||||
| CopilotImageToImageProvider;
|
||||
export type CopilotAllProvider =
|
||||
| CopilotTextProvider
|
||||
| CopilotImageProvider
|
||||
| CopilotTextToEmbeddingProvider;
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import { type WorkflowGraphList, WorkflowNodeType } from './types';
|
||||
|
||||
export const WorkflowGraphs: WorkflowGraphList = [
|
||||
{
|
||||
name: 'Create a presentation',
|
||||
graph: [
|
||||
{
|
||||
id: 'start',
|
||||
name: 'Start: check language',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: 'text',
|
||||
promptName: 'Create a presentation:step1',
|
||||
paramKey: 'language',
|
||||
edges: ['step2'],
|
||||
},
|
||||
{
|
||||
id: 'step2',
|
||||
name: 'Step 2: generate presentation',
|
||||
nodeType: WorkflowNodeType.Basic,
|
||||
type: 'text',
|
||||
promptName: 'Create a presentation:step2',
|
||||
edges: [],
|
||||
// edges: ['step3'],
|
||||
},
|
||||
// {
|
||||
// id: 'step3',
|
||||
// name: 'Step 3: check format',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'Create a presentation:step3',
|
||||
// paramKey: 'needFormat',
|
||||
// edges: ['step4'],
|
||||
// },
|
||||
// {
|
||||
// id: 'step4',
|
||||
// name: 'Step 4: format presentation if needed',
|
||||
// nodeType: WorkflowNodeType.Decision,
|
||||
// condition: ((
|
||||
// nodeIds: string[],
|
||||
// params: WorkflowNodeState
|
||||
// ) =>
|
||||
// nodeIds[
|
||||
// Number(String(params.needFormat).toLowerCase() === 'true')
|
||||
// ]).toString(),
|
||||
// edges: ['step5', 'step6'],
|
||||
// },
|
||||
// {
|
||||
// id: 'step5',
|
||||
// name: 'Step 5: format presentation',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'Create a presentation:step5',
|
||||
// edges: ['step6'],
|
||||
// },
|
||||
// {
|
||||
// id: 'step6',
|
||||
// name: 'Step 6: finish',
|
||||
// nodeType: WorkflowNodeType.Basic,
|
||||
// type: 'text',
|
||||
// promptName: 'Create a presentation:step6',
|
||||
// edges: [],
|
||||
// },
|
||||
],
|
||||
},
|
||||
];
|
||||
@@ -0,0 +1,74 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderService } from '../providers';
|
||||
import { WorkflowGraphs } from './graph';
|
||||
import { WorkflowNode } from './node';
|
||||
import { WorkflowGraph, WorkflowGraphList } from './types';
|
||||
import { CopilotWorkflow } from './workflow';
|
||||
|
||||
@Injectable()
|
||||
export class CopilotWorkflowService {
|
||||
private readonly logger = new Logger(CopilotWorkflowService.name);
|
||||
constructor(
|
||||
private readonly prompt: PromptService,
|
||||
private readonly provider: CopilotProviderService
|
||||
) {}
|
||||
|
||||
private initWorkflow({ name, graph }: WorkflowGraphList[number]) {
|
||||
const workflow = new Map();
|
||||
for (const nodeData of graph) {
|
||||
const { edges: _, ...data } = nodeData;
|
||||
const node = new WorkflowNode(data);
|
||||
workflow.set(node.id, node);
|
||||
}
|
||||
|
||||
// add edges
|
||||
for (const nodeData of graph) {
|
||||
const node = workflow.get(nodeData.id);
|
||||
if (!node) {
|
||||
this.logger.error(
|
||||
`Failed to init workflow ${name}: node ${nodeData.id} not found`
|
||||
);
|
||||
throw new Error(`Node ${nodeData.id} not found`);
|
||||
}
|
||||
for (const edgeId of nodeData.edges) {
|
||||
const edge = workflow.get(edgeId);
|
||||
if (!edge) {
|
||||
this.logger.error(
|
||||
`Failed to init workflow ${name}: edge ${edgeId} not found in node ${nodeData.id}`
|
||||
);
|
||||
throw new Error(`Edge ${edgeId} not found`);
|
||||
}
|
||||
node.addEdge(edge);
|
||||
}
|
||||
}
|
||||
return workflow;
|
||||
}
|
||||
|
||||
// todo: get workflow from database
|
||||
private async getWorkflow(graphName: string): Promise<WorkflowGraph> {
|
||||
const graph = WorkflowGraphs.find(g => g.name === graphName);
|
||||
if (!graph) {
|
||||
throw new Error(`Graph ${graphName} not found`);
|
||||
}
|
||||
|
||||
return this.initWorkflow(graph);
|
||||
}
|
||||
|
||||
async *runGraph(
|
||||
graphName: string,
|
||||
initContent: string
|
||||
): AsyncIterable<string | undefined> {
|
||||
const workflowGraph = await this.getWorkflow(graphName);
|
||||
const workflow = new CopilotWorkflow(
|
||||
this.prompt,
|
||||
this.provider,
|
||||
workflowGraph
|
||||
);
|
||||
|
||||
for await (const result of workflow.runGraph(initContent)) {
|
||||
yield result;
|
||||
}
|
||||
}
|
||||
}
|
||||
166
packages/backend/server/src/plugins/copilot/workflow/node.ts
Normal file
166
packages/backend/server/src/plugins/copilot/workflow/node.ts
Normal file
@@ -0,0 +1,166 @@
|
||||
import { ChatPrompt, PromptService } from '../prompt';
|
||||
import { CopilotProviderService } from '../providers';
|
||||
import { CopilotAllProvider, CopilotChatOptions } from '../types';
|
||||
import {
|
||||
NodeData,
|
||||
WorkflowNodeState,
|
||||
WorkflowNodeType,
|
||||
WorkflowResult,
|
||||
WorkflowResultType,
|
||||
} from './types';
|
||||
|
||||
export class WorkflowNode {
|
||||
private readonly edges: WorkflowNode[] = [];
|
||||
private readonly parents: WorkflowNode[] = [];
|
||||
private prompt: ChatPrompt | null = null;
|
||||
private provider: CopilotAllProvider | null = null;
|
||||
|
||||
constructor(private readonly data: NodeData) {}
|
||||
|
||||
get id(): string {
|
||||
return this.data.id;
|
||||
}
|
||||
|
||||
get name(): string {
|
||||
return this.data.name;
|
||||
}
|
||||
|
||||
get config(): NodeData {
|
||||
return Object.assign({}, this.data);
|
||||
}
|
||||
|
||||
get parent(): WorkflowNode[] {
|
||||
return this.parents;
|
||||
}
|
||||
|
||||
private set parent(node: WorkflowNode) {
|
||||
if (!this.parents.includes(node)) {
|
||||
this.parents.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
addEdge(node: WorkflowNode): number {
|
||||
if (this.data.nodeType === WorkflowNodeType.Basic) {
|
||||
if (this.edges.length > 0) {
|
||||
throw new Error(`Basic block can only have one edge`);
|
||||
}
|
||||
} else if (!this.data.condition) {
|
||||
throw new Error(`Decision block must have a condition`);
|
||||
}
|
||||
node.parent = this;
|
||||
this.edges.push(node);
|
||||
return this.edges.length;
|
||||
}
|
||||
|
||||
async initNode(prompt: PromptService, provider: CopilotProviderService) {
|
||||
if (this.prompt && this.provider) return;
|
||||
|
||||
if (this.data.nodeType === WorkflowNodeType.Basic) {
|
||||
this.prompt = await prompt.get(this.data.promptName);
|
||||
if (!this.prompt) {
|
||||
throw new Error(
|
||||
`Prompt ${this.data.promptName} not found when running workflow node ${this.name}`
|
||||
);
|
||||
}
|
||||
this.provider = await provider.getProviderByModel(this.prompt.model);
|
||||
if (!this.provider) {
|
||||
throw new Error(
|
||||
`Provider not found for model ${this.prompt.model} when running workflow node ${this.name}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async evaluateCondition(
|
||||
_condition?: string
|
||||
): Promise<string | undefined> {
|
||||
// todo: evaluate condition to impl decision block
|
||||
return this.edges[0]?.id;
|
||||
}
|
||||
|
||||
async *next(
|
||||
params: WorkflowNodeState,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<WorkflowResult> {
|
||||
if (!this.prompt || !this.provider) {
|
||||
throw new Error(`Node ${this.name} not initialized`);
|
||||
}
|
||||
|
||||
yield { type: WorkflowResultType.StartRun, nodeId: this.id };
|
||||
|
||||
// choose next node in graph
|
||||
let nextNode: WorkflowNode | undefined = this.edges[0];
|
||||
if (this.data.nodeType === WorkflowNodeType.Decision) {
|
||||
const nextNodeId = await this.evaluateCondition(this.data.condition);
|
||||
// return empty to choose default edge
|
||||
if (nextNodeId) {
|
||||
nextNode = this.edges.find(node => node.id === nextNodeId);
|
||||
if (!nextNode) {
|
||||
throw new Error(`No edge found for condition ${this.data.condition}`);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// pass through content as a stream response if no next node
|
||||
const passthrough = !nextNode;
|
||||
if (this.data.type === 'text' && 'generateText' in this.provider) {
|
||||
if (this.data.paramKey) {
|
||||
// update params with custom key
|
||||
yield {
|
||||
type: WorkflowResultType.Params,
|
||||
params: {
|
||||
[this.data.paramKey]: await this.provider.generateText(
|
||||
this.prompt.finish(params),
|
||||
this.prompt.model,
|
||||
options
|
||||
),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
for await (const content of this.provider.generateTextStream(
|
||||
this.prompt.finish(params),
|
||||
this.prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: this.id,
|
||||
content,
|
||||
passthrough,
|
||||
};
|
||||
}
|
||||
}
|
||||
} else if (
|
||||
this.data.type === 'image' &&
|
||||
'generateImages' in this.provider
|
||||
) {
|
||||
if (this.data.paramKey) {
|
||||
yield {
|
||||
type: WorkflowResultType.Params,
|
||||
params: {
|
||||
[this.data.paramKey]: await this.provider.generateImages(
|
||||
this.prompt.finish(params),
|
||||
this.prompt.model,
|
||||
options
|
||||
),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
for await (const content of this.provider.generateImagesStream(
|
||||
this.prompt.finish(params),
|
||||
this.prompt.model,
|
||||
options
|
||||
)) {
|
||||
yield {
|
||||
type: WorkflowResultType.Content,
|
||||
nodeId: this.id,
|
||||
content,
|
||||
passthrough,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
yield { type: WorkflowResultType.EndRun, nextNode };
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
import type { WorkflowNode } from './node';
|
||||
|
||||
export enum WorkflowNodeType {
|
||||
Basic,
|
||||
Decision,
|
||||
}
|
||||
|
||||
export type NodeData = { id: string; name: string } & (
|
||||
| {
|
||||
nodeType: WorkflowNodeType.Basic;
|
||||
promptName: string;
|
||||
type: 'text' | 'image';
|
||||
// update the prompt params by output with the custom key
|
||||
paramKey?: string;
|
||||
}
|
||||
| { nodeType: WorkflowNodeType.Decision; condition: string }
|
||||
);
|
||||
|
||||
export type WorkflowNodeState = Record<string, string>;
|
||||
|
||||
export type WorkflowGraphData = Array<NodeData & { edges: string[] }>;
|
||||
export type WorkflowGraphList = Array<{
|
||||
name: string;
|
||||
graph: WorkflowGraphData;
|
||||
}>;
|
||||
|
||||
export enum WorkflowResultType {
|
||||
StartRun,
|
||||
EndRun,
|
||||
Params,
|
||||
Content,
|
||||
}
|
||||
|
||||
export type WorkflowResult =
|
||||
| { type: WorkflowResultType.StartRun; nodeId: string }
|
||||
| { type: WorkflowResultType.EndRun; nextNode: WorkflowNode }
|
||||
| {
|
||||
type: WorkflowResultType.Params;
|
||||
params: Record<string, string | string[]>;
|
||||
}
|
||||
| {
|
||||
type: WorkflowResultType.Content;
|
||||
nodeId: string;
|
||||
content: string;
|
||||
// if is the end of the workflow, pass through the content to stream response
|
||||
passthrough?: boolean;
|
||||
};
|
||||
|
||||
export type WorkflowGraph = Map<string, WorkflowNode>;
|
||||
@@ -0,0 +1,72 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderService } from '../providers';
|
||||
import { WorkflowNode } from './node';
|
||||
import {
|
||||
WorkflowGraph,
|
||||
WorkflowNodeState,
|
||||
WorkflowNodeType,
|
||||
WorkflowResultType,
|
||||
} from './types';
|
||||
|
||||
export class CopilotWorkflow {
|
||||
private readonly logger = new Logger(CopilotWorkflow.name);
|
||||
private readonly rootNode: WorkflowNode;
|
||||
|
||||
constructor(
|
||||
private readonly prompt: PromptService,
|
||||
private readonly provider: CopilotProviderService,
|
||||
workflow: WorkflowGraph
|
||||
) {
|
||||
const startNode = workflow.get('start');
|
||||
if (!startNode) {
|
||||
throw new Error(`No start node found in graph`);
|
||||
}
|
||||
this.rootNode = startNode;
|
||||
}
|
||||
|
||||
async *runGraph(initContent: string): AsyncIterable<string | undefined> {
|
||||
let currentNode: WorkflowNode | undefined = this.rootNode;
|
||||
const lastParams: WorkflowNodeState = { content: initContent };
|
||||
|
||||
while (currentNode) {
|
||||
let result = '';
|
||||
let nextNode: WorkflowNode | undefined;
|
||||
|
||||
await currentNode.initNode(this.prompt, this.provider);
|
||||
|
||||
for await (const ret of currentNode.next(lastParams)) {
|
||||
if (ret.type === WorkflowResultType.EndRun) {
|
||||
nextNode = ret.nextNode;
|
||||
break;
|
||||
} else if (ret.type === WorkflowResultType.Params) {
|
||||
Object.assign(lastParams, ret.params);
|
||||
if (currentNode.config.nodeType === WorkflowNodeType.Basic) {
|
||||
const { type, promptName } = currentNode.config;
|
||||
this.logger.verbose(
|
||||
`[${currentNode.name}][${type}][${promptName}]: update params - '${JSON.stringify(ret.params)}'`
|
||||
);
|
||||
}
|
||||
} else if (ret.type === WorkflowResultType.Content) {
|
||||
// pass through content as a stream response
|
||||
if (ret.passthrough) {
|
||||
yield ret.content;
|
||||
} else {
|
||||
result += ret.content;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (currentNode.config.nodeType === WorkflowNodeType.Basic && result) {
|
||||
const { type, promptName } = currentNode.config;
|
||||
this.logger.verbose(
|
||||
`[${currentNode.name}][${type}][${promptName}]: update content - '${lastParams.content}' -> '${result}'`
|
||||
);
|
||||
}
|
||||
|
||||
currentNode = nextNode;
|
||||
if (result) lastParams.content = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user