feat: no branches workflow support (#7119)

fix AFF-1165 AFF-1164
This commit is contained in:
darkskygit
2024-06-07 05:53:39 +00:00
parent b75da1f3e0
commit 44b0ea2b6c
14 changed files with 599 additions and 12 deletions

View File

@@ -34,11 +34,7 @@ import { Config } from '../../fundamentals';
import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
CopilotCapability,
CopilotImageToTextProvider,
CopilotTextToTextProvider,
} from './types';
import { CopilotCapability, CopilotTextProvider } from './types';
export interface ChatEvent {
type: 'attachment' | 'message' | 'error';
@@ -88,7 +84,7 @@ export class CopilotController {
userId: string,
sessionId: string,
messageId?: string
): Promise<CopilotTextToTextProvider | CopilotImageToTextProvider> {
): Promise<CopilotTextProvider> {
const { hasAttachment, model } = await this.checkRequest(
userId,
sessionId,

View File

@@ -22,6 +22,7 @@ import {
} from './resolver';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotWorkflowService } from './workflow';
registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
@@ -39,6 +40,7 @@ registerCopilotProvider(OpenAIProvider);
CopilotProviderService,
CopilotStorage,
PromptsManagementResolver,
CopilotWorkflowService,
],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot,

View File

@@ -166,6 +166,34 @@ export class CopilotProviderService {
}
return null;
}
async getProviderByModel<C extends CopilotCapability>(
model: string,
prefer?: CopilotProviderType
): Promise<CapabilityToCopilotProvider[C] | null> {
const providers = Array.from(COPILOT_PROVIDER.keys());
if (providers.length) {
let selectedProvider: CopilotProviderType | undefined = prefer;
let currentIndex = -1;
if (!selectedProvider) {
currentIndex = 0;
selectedProvider = providers[currentIndex];
}
while (selectedProvider) {
const provider = this.getProvider(selectedProvider);
if (await provider.isModelAvailable(model)) {
return provider as CapabilityToCopilotProvider[C];
}
currentIndex += 1;
selectedProvider = providers[currentIndex];
}
}
return null;
}
}
export { FalProvider } from './fal';

View File

@@ -1,5 +1,3 @@
import assert from 'node:assert';
import { Logger } from '@nestjs/common';
import { ClientOptions, OpenAI } from 'openai';
@@ -58,12 +56,11 @@ export class OpenAIProvider
private existsModels: string[] | undefined;
constructor(config: ClientOptions) {
assert(OpenAIProvider.assetsConfig(config));
this.instance = new OpenAI(config);
}
static assetsConfig(config: ClientOptions) {
return !!config.apiKey;
return !!config?.apiKey;
}
get type(): CopilotProviderType {

View File

@@ -230,3 +230,14 @@ export type CapabilityToCopilotProvider = {
[CopilotCapability.ImageToText]: CopilotImageToTextProvider;
[CopilotCapability.ImageToImage]: CopilotImageToImageProvider;
};
export type CopilotTextProvider =
| CopilotTextToTextProvider
| CopilotImageToTextProvider;
export type CopilotImageProvider =
| CopilotTextToImageProvider
| CopilotImageToImageProvider;
export type CopilotAllProvider =
| CopilotTextProvider
| CopilotImageProvider
| CopilotTextToEmbeddingProvider;

View File

@@ -0,0 +1,65 @@
import { type WorkflowGraphList, WorkflowNodeType } from './types';
export const WorkflowGraphs: WorkflowGraphList = [
{
name: 'Create a presentation',
graph: [
{
id: 'start',
name: 'Start: check language',
nodeType: WorkflowNodeType.Basic,
type: 'text',
promptName: 'Create a presentation:step1',
paramKey: 'language',
edges: ['step2'],
},
{
id: 'step2',
name: 'Step 2: generate presentation',
nodeType: WorkflowNodeType.Basic,
type: 'text',
promptName: 'Create a presentation:step2',
edges: [],
// edges: ['step3'],
},
// {
// id: 'step3',
// name: 'Step 3: check format',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'Create a presentation:step3',
// paramKey: 'needFormat',
// edges: ['step4'],
// },
// {
// id: 'step4',
// name: 'Step 4: format presentation if needed',
// nodeType: WorkflowNodeType.Decision,
// condition: ((
// nodeIds: string[],
// params: WorkflowNodeState
// ) =>
// nodeIds[
// Number(String(params.needFormat).toLowerCase() === 'true')
// ]).toString(),
// edges: ['step5', 'step6'],
// },
// {
// id: 'step5',
// name: 'Step 5: format presentation',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'Create a presentation:step5',
// edges: ['step6'],
// },
// {
// id: 'step6',
// name: 'Step 6: finish',
// nodeType: WorkflowNodeType.Basic,
// type: 'text',
// promptName: 'Create a presentation:step6',
// edges: [],
// },
],
},
];

View File

@@ -0,0 +1,74 @@
import { Injectable, Logger } from '@nestjs/common';
import { PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { WorkflowGraphs } from './graph';
import { WorkflowNode } from './node';
import { WorkflowGraph, WorkflowGraphList } from './types';
import { CopilotWorkflow } from './workflow';
@Injectable()
export class CopilotWorkflowService {
private readonly logger = new Logger(CopilotWorkflowService.name);
constructor(
private readonly prompt: PromptService,
private readonly provider: CopilotProviderService
) {}
private initWorkflow({ name, graph }: WorkflowGraphList[number]) {
const workflow = new Map();
for (const nodeData of graph) {
const { edges: _, ...data } = nodeData;
const node = new WorkflowNode(data);
workflow.set(node.id, node);
}
// add edges
for (const nodeData of graph) {
const node = workflow.get(nodeData.id);
if (!node) {
this.logger.error(
`Failed to init workflow ${name}: node ${nodeData.id} not found`
);
throw new Error(`Node ${nodeData.id} not found`);
}
for (const edgeId of nodeData.edges) {
const edge = workflow.get(edgeId);
if (!edge) {
this.logger.error(
`Failed to init workflow ${name}: edge ${edgeId} not found in node ${nodeData.id}`
);
throw new Error(`Edge ${edgeId} not found`);
}
node.addEdge(edge);
}
}
return workflow;
}
// todo: get workflow from database
private async getWorkflow(graphName: string): Promise<WorkflowGraph> {
const graph = WorkflowGraphs.find(g => g.name === graphName);
if (!graph) {
throw new Error(`Graph ${graphName} not found`);
}
return this.initWorkflow(graph);
}
async *runGraph(
graphName: string,
initContent: string
): AsyncIterable<string | undefined> {
const workflowGraph = await this.getWorkflow(graphName);
const workflow = new CopilotWorkflow(
this.prompt,
this.provider,
workflowGraph
);
for await (const result of workflow.runGraph(initContent)) {
yield result;
}
}
}

View File

@@ -0,0 +1,166 @@
import { ChatPrompt, PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { CopilotAllProvider, CopilotChatOptions } from '../types';
import {
NodeData,
WorkflowNodeState,
WorkflowNodeType,
WorkflowResult,
WorkflowResultType,
} from './types';
export class WorkflowNode {
private readonly edges: WorkflowNode[] = [];
private readonly parents: WorkflowNode[] = [];
private prompt: ChatPrompt | null = null;
private provider: CopilotAllProvider | null = null;
constructor(private readonly data: NodeData) {}
get id(): string {
return this.data.id;
}
get name(): string {
return this.data.name;
}
get config(): NodeData {
return Object.assign({}, this.data);
}
get parent(): WorkflowNode[] {
return this.parents;
}
private set parent(node: WorkflowNode) {
if (!this.parents.includes(node)) {
this.parents.push(node);
}
}
addEdge(node: WorkflowNode): number {
if (this.data.nodeType === WorkflowNodeType.Basic) {
if (this.edges.length > 0) {
throw new Error(`Basic block can only have one edge`);
}
} else if (!this.data.condition) {
throw new Error(`Decision block must have a condition`);
}
node.parent = this;
this.edges.push(node);
return this.edges.length;
}
async initNode(prompt: PromptService, provider: CopilotProviderService) {
if (this.prompt && this.provider) return;
if (this.data.nodeType === WorkflowNodeType.Basic) {
this.prompt = await prompt.get(this.data.promptName);
if (!this.prompt) {
throw new Error(
`Prompt ${this.data.promptName} not found when running workflow node ${this.name}`
);
}
this.provider = await provider.getProviderByModel(this.prompt.model);
if (!this.provider) {
throw new Error(
`Provider not found for model ${this.prompt.model} when running workflow node ${this.name}`
);
}
}
}
private async evaluateCondition(
_condition?: string
): Promise<string | undefined> {
// todo: evaluate condition to impl decision block
return this.edges[0]?.id;
}
async *next(
params: WorkflowNodeState,
options?: CopilotChatOptions
): AsyncIterable<WorkflowResult> {
if (!this.prompt || !this.provider) {
throw new Error(`Node ${this.name} not initialized`);
}
yield { type: WorkflowResultType.StartRun, nodeId: this.id };
// choose next node in graph
let nextNode: WorkflowNode | undefined = this.edges[0];
if (this.data.nodeType === WorkflowNodeType.Decision) {
const nextNodeId = await this.evaluateCondition(this.data.condition);
// return empty to choose default edge
if (nextNodeId) {
nextNode = this.edges.find(node => node.id === nextNodeId);
if (!nextNode) {
throw new Error(`No edge found for condition ${this.data.condition}`);
}
}
} else {
// pass through content as a stream response if no next node
const passthrough = !nextNode;
if (this.data.type === 'text' && 'generateText' in this.provider) {
if (this.data.paramKey) {
// update params with custom key
yield {
type: WorkflowResultType.Params,
params: {
[this.data.paramKey]: await this.provider.generateText(
this.prompt.finish(params),
this.prompt.model,
options
),
},
};
} else {
for await (const content of this.provider.generateTextStream(
this.prompt.finish(params),
this.prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: this.id,
content,
passthrough,
};
}
}
} else if (
this.data.type === 'image' &&
'generateImages' in this.provider
) {
if (this.data.paramKey) {
yield {
type: WorkflowResultType.Params,
params: {
[this.data.paramKey]: await this.provider.generateImages(
this.prompt.finish(params),
this.prompt.model,
options
),
},
};
} else {
for await (const content of this.provider.generateImagesStream(
this.prompt.finish(params),
this.prompt.model,
options
)) {
yield {
type: WorkflowResultType.Content,
nodeId: this.id,
content,
passthrough,
};
}
}
}
}
yield { type: WorkflowResultType.EndRun, nextNode };
}
}

View File

@@ -0,0 +1,49 @@
import type { WorkflowNode } from './node';
export enum WorkflowNodeType {
Basic,
Decision,
}
export type NodeData = { id: string; name: string } & (
| {
nodeType: WorkflowNodeType.Basic;
promptName: string;
type: 'text' | 'image';
// update the prompt params by output with the custom key
paramKey?: string;
}
| { nodeType: WorkflowNodeType.Decision; condition: string }
);
export type WorkflowNodeState = Record<string, string>;
export type WorkflowGraphData = Array<NodeData & { edges: string[] }>;
export type WorkflowGraphList = Array<{
name: string;
graph: WorkflowGraphData;
}>;
export enum WorkflowResultType {
StartRun,
EndRun,
Params,
Content,
}
export type WorkflowResult =
| { type: WorkflowResultType.StartRun; nodeId: string }
| { type: WorkflowResultType.EndRun; nextNode: WorkflowNode }
| {
type: WorkflowResultType.Params;
params: Record<string, string | string[]>;
}
| {
type: WorkflowResultType.Content;
nodeId: string;
content: string;
// if is the end of the workflow, pass through the content to stream response
passthrough?: boolean;
};
export type WorkflowGraph = Map<string, WorkflowNode>;

View File

@@ -0,0 +1,72 @@
import { Logger } from '@nestjs/common';
import { PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { WorkflowNode } from './node';
import {
WorkflowGraph,
WorkflowNodeState,
WorkflowNodeType,
WorkflowResultType,
} from './types';
export class CopilotWorkflow {
private readonly logger = new Logger(CopilotWorkflow.name);
private readonly rootNode: WorkflowNode;
constructor(
private readonly prompt: PromptService,
private readonly provider: CopilotProviderService,
workflow: WorkflowGraph
) {
const startNode = workflow.get('start');
if (!startNode) {
throw new Error(`No start node found in graph`);
}
this.rootNode = startNode;
}
async *runGraph(initContent: string): AsyncIterable<string | undefined> {
let currentNode: WorkflowNode | undefined = this.rootNode;
const lastParams: WorkflowNodeState = { content: initContent };
while (currentNode) {
let result = '';
let nextNode: WorkflowNode | undefined;
await currentNode.initNode(this.prompt, this.provider);
for await (const ret of currentNode.next(lastParams)) {
if (ret.type === WorkflowResultType.EndRun) {
nextNode = ret.nextNode;
break;
} else if (ret.type === WorkflowResultType.Params) {
Object.assign(lastParams, ret.params);
if (currentNode.config.nodeType === WorkflowNodeType.Basic) {
const { type, promptName } = currentNode.config;
this.logger.verbose(
`[${currentNode.name}][${type}][${promptName}]: update params - '${JSON.stringify(ret.params)}'`
);
}
} else if (ret.type === WorkflowResultType.Content) {
// pass through content as a stream response
if (ret.passthrough) {
yield ret.content;
} else {
result += ret.content;
}
}
}
if (currentNode.config.nodeType === WorkflowNodeType.Basic && result) {
const { type, promptName } = currentNode.config;
this.logger.verbose(
`[${currentNode.name}][${type}][${promptName}]: update content - '${lastParams.content}' -> '${result}'`
);
}
currentNode = nextNode;
if (result) lastParams.content = result;
}
}
}