/*
 * Mirror of https://github.com/toeverything/AFFiNE.git
 * Synced 2026-07-04 19:15:33 +08:00 at commit 44e1eb503f ("fix AI-109").
 * 205 lines, 5.5 KiB, TypeScript.
 */
import { Inject, Injectable, Logger } from '@nestjs/common';
|
|
import { z } from 'zod';
|
|
|
|
import {
|
|
Config,
|
|
CopilotPromptInvalid,
|
|
CopilotProviderNotSupported,
|
|
OnEvent,
|
|
} from '../../../base';
|
|
import { CopilotProviderFactory } from './factory';
|
|
import {
|
|
type CopilotChatOptions,
|
|
type CopilotEmbeddingOptions,
|
|
type CopilotImageOptions,
|
|
CopilotProviderModel,
|
|
CopilotProviderType,
|
|
CopilotStructuredOptions,
|
|
EmbeddingMessage,
|
|
ModelCapability,
|
|
ModelConditions,
|
|
ModelFullConditions,
|
|
ModelInputType,
|
|
type PromptMessage,
|
|
PromptMessageSchema,
|
|
} from './types';
|
|
|
|
/**
 * Base class for copilot model providers.
 *
 * Subclasses declare their provider `type`, the `models` they expose and a
 * `configured()` check. On `config.init` and on any `config.changed` event that
 * touches `copilot`, the provider registers itself with (or unregisters itself
 * from) the `CopilotProviderFactory` depending on `configured()`.
 *
 * @typeParam C - shape of this provider's entry in
 *   `AFFiNEConfig.copilot.providers`.
 */
@Injectable()
export abstract class CopilotProvider<C = any> {
  protected readonly logger = new Logger(this.constructor.name);
  abstract readonly type: CopilotProviderType;
  abstract readonly models: CopilotProviderModel[];
  abstract configured(): boolean;

  @Inject() protected readonly AFFiNEConfig!: Config;
  @Inject() protected readonly factory!: CopilotProviderFactory;

  /** This provider's config section, looked up by `this.type`. */
  get config(): C {
    return this.AFFiNEConfig.copilot.providers[this.type] as C;
  }

  @OnEvent('config.init')
  async onConfigInit() {
    this.setup();
  }

  @OnEvent('config.changed')
  async onConfigChanged(event: Events['config.changed']) {
    // Only copilot-related updates can change whether this provider is usable.
    if ('copilot' in event.updates) {
      this.setup();
    }
  }

  /** (Re-)register with the factory when configured, unregister otherwise. */
  protected setup() {
    if (this.configured()) {
      this.factory.register(this);
    } else {
      this.factory.unregister(this);
    }
  }

  /**
   * Find a model satisfying `cond`.
   *
   * - With `modelId`: the model with that id, provided one of its capabilities
   *   matches the requested output/input types.
   * - Without `modelId`: the first model having a matching capability flagged
   *   `defaultForOutputType`; requires `outputType`, otherwise `undefined`.
   *
   * An absent `outputType` or empty/absent `inputTypes` matches any capability
   * on that axis.
   */
  private findValidModel(
    cond: ModelFullConditions
  ): CopilotProviderModel | undefined {
    const { modelId, outputType, inputTypes } = cond;
    const matcher = (cap: ModelCapability) =>
      (!outputType || cap.output.includes(outputType)) &&
      (!inputTypes?.length ||
        inputTypes.every(type => cap.input.includes(type)));

    if (modelId) {
      return this.models.find(
        m => m.id === modelId && m.capabilities.some(matcher)
      );
    }
    if (!outputType) return undefined;

    return this.models.find(m =>
      m.capabilities.some(c => matcher(c) && c.defaultForOutputType)
    );
  }

  // make it async to allow dynamic check available models in some providers
  /** Whether this provider is configured and can serve `cond`. */
  async match(cond: ModelFullConditions = {}): Promise<boolean> {
    return this.configured() && !!this.findValidModel(cond);
  }

  /**
   * Resolve the model for `cond`.
   *
   * @throws CopilotPromptInvalid when no model satisfies `cond`, or when neither
   *   `modelId` nor `outputType` was given.
   */
  protected selectModel(cond: ModelFullConditions): CopilotProviderModel {
    const model = this.findValidModel(cond);
    if (model) return model;

    const { modelId, outputType, inputTypes } = cond;
    // An empty list means "any input" in findValidModel, so describe it as such
    // instead of interpolating an empty string.
    const inputDesc = inputTypes?.length ? inputTypes.join(',') : '<any>';
    throw new CopilotPromptInvalid(
      modelId
        ? `Model ${modelId} does not support ${outputType ?? '<any>'} output with ${inputDesc} input`
        : outputType
          ? `No model supports ${outputType} output with ${inputDesc} input for provider ${this.type}`
          : 'Output type is required when modelId is not provided'
    );
  }

  /**
   * Convert a failed zod parse into a `CopilotPromptInvalid`.
   * Each issue is rendered as `<message> at <path>`, where the path looks like
   * `root[0].content`. Successful parses are a no-op.
   */
  private handleZodError(ret: z.SafeParseReturnType<any, any>) {
    if (ret.success) return;
    const issues = ret.error.issues.map(issue => {
      // Numeric segments become index accessors, all others dotted keys.
      const path = issue.path.reduce<string>(
        (acc, seg) => acc + (typeof seg === 'number' ? `[${seg}]` : `.${seg}`),
        'root'
      );
      return `${issue.message} at ${path}`;
    });
    throw new CopilotPromptInvalid(issues.join('; '));
  }

  /**
   * Validate a request before it is sent to the provider.
   *
   * @param cond - model selection conditions; must resolve via `selectModel`.
   * @param messages - when given, each message must have non-empty (trimmed)
   *   string `content` unless `options.requireContent` is `false`.
   * @param embeddings - when given, validated against `EmbeddingMessage`.
   * @param options - reads `requireContent` (default `true`) and
   *   `requireAttachment` (default `false`).
   * @throws CopilotPromptInvalid on any validation failure.
   */
  protected async checkParams({
    cond,
    messages,
    embeddings,
    options = {},
  }: {
    cond: ModelFullConditions;
    messages?: PromptMessage[];
    embeddings?: string[];
    options?: CopilotChatOptions;
  }) {
    const model = this.selectModel(cond);
    // A model counts as multimodal if any capability accepts image or audio.
    const multimodal = model.capabilities.some(c =>
      [ModelInputType.Image, ModelInputType.Audio].some(t =>
        c.input.includes(t)
      )
    );

    if (messages) {
      const { requireContent = true, requireAttachment = false } = options;

      const MessageSchema = z
        .array(
          PromptMessageSchema.extend({
            content: requireContent
              ? z.string().trim().min(1)
              : z.string().optional().nullable(),
          })
            .passthrough()
            .catchall(z.union([z.string(), z.number(), z.date(), z.null()]))
            .refine(
              // NOTE(review): a user message with no `attachments` field at all
              // passes this check; only an explicitly empty array fails.
              // Confirm this leniency is intended.
              m =>
                !(multimodal && requireAttachment && m.role === 'user') ||
                (m.attachments ? m.attachments.length > 0 : true),
              { message: 'attachments required in multimodal mode' }
            )
        )
        .optional();

      this.handleZodError(MessageSchema.safeParse(messages));
    }
    if (embeddings) {
      this.handleZodError(EmbeddingMessage.safeParse(embeddings));
    }
  }

  /** Generate a complete text response for `messages`. */
  abstract text(
    model: ModelConditions,
    messages: PromptMessage[],
    options?: CopilotChatOptions
  ): Promise<string>;

  /** Stream a text response for `messages` chunk by chunk. */
  abstract streamText(
    model: ModelConditions,
    messages: PromptMessage[],
    options?: CopilotChatOptions
  ): AsyncIterable<string>;

  /**
   * Structured-output generation.
   * The base implementation throws `CopilotProviderNotSupported`; providers
   * that support it are expected to override.
   */
  structure(
    _cond: ModelConditions,
    _messages: PromptMessage[],
    _options?: CopilotStructuredOptions
  ): Promise<string> {
    throw new CopilotProviderNotSupported({
      provider: this.type,
      kind: 'structure',
    });
  }

  /**
   * Image generation.
   * The base implementation throws `CopilotProviderNotSupported`
   * synchronously; providers that support it are expected to override.
   */
  streamImages(
    _model: ModelConditions,
    _messages: PromptMessage[],
    _options?: CopilotImageOptions
  ): AsyncIterable<string> {
    throw new CopilotProviderNotSupported({
      provider: this.type,
      kind: 'image',
    });
  }

  /**
   * Embedding generation.
   * The base implementation throws `CopilotProviderNotSupported`; providers
   * that support it are expected to override.
   */
  embedding(
    _model: ModelConditions,
    _text: string | string[],
    _options?: CopilotEmbeddingOptions
  ): Promise<number[][]> {
    throw new CopilotProviderNotSupported({
      provider: this.type,
      kind: 'embedding',
    });
  }
}
|