mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
feat(server): refactor provider interface (#11665)
fix AI-4 fix AI-18 better provider/model choose to allow fallback to similar models (e.g., self-hosted) when the provider is not fully configured split functions of different output types
This commit is contained in:
@@ -46,13 +46,14 @@ import {
|
||||
} from '../../base';
|
||||
import { CurrentUser, Public } from '../../core/auth';
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotProvider,
|
||||
CopilotProviderFactory,
|
||||
CopilotTextProvider,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
} from './providers';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import { ChatMessage } from './types';
|
||||
import { ChatMessage, ChatQuerySchema } from './types';
|
||||
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
|
||||
|
||||
export interface ChatEvent {
|
||||
@@ -61,11 +62,6 @@ export interface ChatEvent {
|
||||
data: string | object;
|
||||
}
|
||||
|
||||
type CheckResult = {
|
||||
model: string;
|
||||
hasAttachment?: boolean;
|
||||
};
|
||||
|
||||
const PING_INTERVAL = 5000;
|
||||
|
||||
@Controller('/api/copilot')
|
||||
@@ -91,64 +87,44 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
this.ongoingStreamCount$.complete();
|
||||
}
|
||||
|
||||
private async checkRequest(
|
||||
private async chooseProvider(
|
||||
outputType: ModelOutputType,
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
messageId?: string,
|
||||
modelId?: string
|
||||
): Promise<CheckResult> {
|
||||
await this.chatSession.checkQuota(userId);
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
): Promise<{
|
||||
provider: CopilotProvider;
|
||||
model: string;
|
||||
hasAttachment: boolean;
|
||||
}> {
|
||||
const [, session] = await Promise.all([
|
||||
this.chatSession.checkQuota(userId),
|
||||
this.chatSession.get(sessionId),
|
||||
]);
|
||||
|
||||
if (!session || session.config.userId !== userId) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
const ret: CheckResult = {
|
||||
model: session.model,
|
||||
};
|
||||
const model =
|
||||
modelId && session.optionalModels.includes(modelId)
|
||||
? modelId
|
||||
: session.model;
|
||||
|
||||
if (modelId && session.optionalModels.includes(modelId)) {
|
||||
ret.model = modelId;
|
||||
}
|
||||
const hasAttachment = messageId
|
||||
? !!(await session.getMessageById(messageId)).attachments?.length
|
||||
: false;
|
||||
|
||||
if (messageId && typeof messageId === 'string') {
|
||||
const message = await session.getMessageById(messageId);
|
||||
ret.hasAttachment =
|
||||
Array.isArray(message.attachments) && !!message.attachments.length;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
private async chooseTextProvider(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
messageId?: string,
|
||||
modelId?: string
|
||||
): Promise<{ provider: CopilotTextProvider; model: string }> {
|
||||
const { hasAttachment, model } = await this.checkRequest(
|
||||
userId,
|
||||
sessionId,
|
||||
messageId,
|
||||
modelId
|
||||
);
|
||||
|
||||
let provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
{ model }
|
||||
);
|
||||
// fallback to image to text if text to text is not available
|
||||
if (!provider && hasAttachment) {
|
||||
provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.ImageToText,
|
||||
{ model }
|
||||
);
|
||||
}
|
||||
const provider = await this.provider.getProvider({
|
||||
outputType,
|
||||
modelId: model,
|
||||
});
|
||||
if (!provider) {
|
||||
throw new NoCopilotProviderAvailable();
|
||||
}
|
||||
|
||||
return { provider, model };
|
||||
return { provider, model, hasAttachment };
|
||||
}
|
||||
|
||||
private async appendSessionMessage(
|
||||
@@ -179,32 +155,6 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
return [latestMessage, session];
|
||||
}
|
||||
|
||||
private prepareParams(params: Record<string, string | string[]>) {
|
||||
const messageId = Array.isArray(params.messageId)
|
||||
? params.messageId[0]
|
||||
: params.messageId;
|
||||
const retry = Array.isArray(params.retry)
|
||||
? Boolean(params.retry[0])
|
||||
: Boolean(params.retry);
|
||||
const reasoning = Array.isArray(params.reasoning)
|
||||
? Boolean(params.reasoning[0])
|
||||
: Boolean(params.reasoning);
|
||||
const webSearch = Array.isArray(params.webSearch)
|
||||
? Boolean(params.webSearch[0])
|
||||
: Boolean(params.webSearch);
|
||||
const modelId = Array.isArray(params.modelId)
|
||||
? params.modelId[0]
|
||||
: params.modelId;
|
||||
|
||||
delete params.messageId;
|
||||
delete params.retry;
|
||||
delete params.reasoning;
|
||||
delete params.webSearch;
|
||||
delete params.modelId;
|
||||
|
||||
return { messageId, retry, reasoning, webSearch, modelId, params };
|
||||
}
|
||||
|
||||
private getSignal(req: Request) {
|
||||
const controller = new AbortController();
|
||||
req.socket.on('close', hasError => {
|
||||
@@ -245,15 +195,16 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
@Query() query: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const info: any = { sessionId, params };
|
||||
const info: any = { sessionId, params: query };
|
||||
|
||||
try {
|
||||
const { messageId, retry, reasoning, webSearch, modelId } =
|
||||
this.prepareParams(params);
|
||||
let { messageId, retry, reasoning, webSearch, modelId, params } =
|
||||
ChatQuerySchema.parse(query);
|
||||
|
||||
const { provider, model } = await this.chooseTextProvider(
|
||||
const { provider, model } = await this.chooseProvider(
|
||||
ModelOutputType.Text,
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId,
|
||||
@@ -279,7 +230,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
const finalMessage = session.finish(params);
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
|
||||
const content = await provider.generateText(finalMessage, model, {
|
||||
const content = await provider.text({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
@@ -312,15 +263,16 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() params: Record<string, string>
|
||||
@Query() query: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const info: any = { sessionId, params, throwInStream: false };
|
||||
const info: any = { sessionId, params: query, throwInStream: false };
|
||||
|
||||
try {
|
||||
const { messageId, retry, reasoning, webSearch, modelId } =
|
||||
this.prepareParams(params);
|
||||
let { messageId, retry, reasoning, webSearch, modelId, params } =
|
||||
ChatQuerySchema.parse(query);
|
||||
|
||||
const { provider, model } = await this.chooseTextProvider(
|
||||
const { provider, model } = await this.chooseProvider(
|
||||
ModelOutputType.Text,
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId,
|
||||
@@ -348,7 +300,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
|
||||
const source$ = from(
|
||||
provider.generateTextStream(finalMessage, model, {
|
||||
provider.streamText({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
@@ -387,7 +339,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
})
|
||||
);
|
||||
|
||||
return this.mergePingStream(messageId, source$);
|
||||
return this.mergePingStream(messageId || '', source$);
|
||||
} catch (err) {
|
||||
metrics.ai.counter('chat_stream_errors').add(1, info);
|
||||
return mapSseError(err, info);
|
||||
@@ -400,11 +352,11 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() params: Record<string, string>
|
||||
@Query() query: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const info: any = { sessionId, params, throwInStream: false };
|
||||
const info: any = { sessionId, params: query, throwInStream: false };
|
||||
try {
|
||||
const { messageId } = this.prepareParams(params);
|
||||
let { messageId, params } = ChatQuerySchema.parse(query);
|
||||
|
||||
const [, session] = await this.appendSessionMessage(sessionId, messageId);
|
||||
info.model = session.model;
|
||||
@@ -487,7 +439,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
)
|
||||
);
|
||||
|
||||
return this.mergePingStream(messageId, source$);
|
||||
return this.mergePingStream(messageId || '', source$);
|
||||
} catch (err) {
|
||||
metrics.ai.counter('workflow_errors').add(1, info);
|
||||
return mapSseError(err, info);
|
||||
@@ -500,35 +452,25 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() params: Record<string, string>
|
||||
@Query() query: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const info: any = { sessionId, params, throwInStream: false };
|
||||
const info: any = { sessionId, params: query, throwInStream: false };
|
||||
try {
|
||||
const { messageId } = this.prepareParams(params);
|
||||
let { messageId, params } = ChatQuerySchema.parse(query);
|
||||
|
||||
const { model, hasAttachment } = await this.checkRequest(
|
||||
const { provider, model, hasAttachment } = await this.chooseProvider(
|
||||
ModelOutputType.Image,
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId
|
||||
);
|
||||
const provider = await this.provider.getProviderByCapability(
|
||||
hasAttachment
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage,
|
||||
{ model }
|
||||
);
|
||||
if (!provider) {
|
||||
throw new NoCopilotProviderAvailable();
|
||||
}
|
||||
|
||||
const [latestMessage, session] = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
messageId
|
||||
);
|
||||
info.model = session.model;
|
||||
metrics.ai
|
||||
.counter('images_stream_calls')
|
||||
.add(1, { model: session.model });
|
||||
info.model = model;
|
||||
metrics.ai.counter('images_stream_calls').add(1, { model });
|
||||
|
||||
if (latestMessage) {
|
||||
params = Object.assign({}, params, latestMessage.params, {
|
||||
@@ -544,13 +486,22 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
);
|
||||
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
|
||||
const source$ = from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
...session.config.promptConfig,
|
||||
quality: params.quality || undefined,
|
||||
seed: this.parseNumber(params.seed),
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
provider.streamImages(
|
||||
{
|
||||
modelId: model,
|
||||
inputTypes: hasAttachment
|
||||
? [ModelInputType.Image]
|
||||
: [ModelInputType.Text],
|
||||
},
|
||||
session.finish(params),
|
||||
{
|
||||
...session.config.promptConfig,
|
||||
quality: params.quality || undefined,
|
||||
seed: this.parseNumber(params.seed),
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
}
|
||||
)
|
||||
).pipe(
|
||||
mergeMap(handleRemoteLink),
|
||||
connect(shared$ =>
|
||||
@@ -589,7 +540,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
)
|
||||
);
|
||||
|
||||
return this.mergePingStream(messageId, source$);
|
||||
return this.mergePingStream(messageId || '', source$);
|
||||
} catch (err) {
|
||||
metrics.ai.counter('images_stream_errors').add(1, info);
|
||||
return mapSseError(err, info);
|
||||
|
||||
Reference in New Issue
Block a user