From 74452a0aaba7960d7089d7df4705f52978e7fbc3 Mon Sep 17 00:00:00 2001 From: akumatus Date: Tue, 13 May 2025 09:17:34 +0000 Subject: [PATCH] feat(core): add optionalModels field in AiPrompt and support the front-end modelId param (#12224) Close [AI-116](https://linear.app/affine-design/issue/AI-116) ## Summary by CodeRabbit - **New Features** - Added support for specifying alternative AI models in chat prompts, enabling users to select from multiple available models. - Expanded AI model options with new additions: 'gpt-4.1', 'o3', and 'claude-3-5-sonnet-20241022'. - **Enhancements** - Users can now optionally choose a specific AI model during chat interactions. - Prompts and chat sessions reflect and support selection of alternative models where applicable. - **Bug Fixes** - Improved handling of prompt configuration defaults for better reliability. --- .../migration.sql | 2 + packages/backend/server/schema.prisma | 1 + .../server/src/plugins/copilot/controller.ts | 56 ++++++++++++------- .../src/plugins/copilot/prompt/chat-prompt.ts | 2 + .../src/plugins/copilot/prompt/prompts.ts | 25 +++++++-- .../src/plugins/copilot/prompt/service.ts | 1 + .../plugins/copilot/providers/anthropic.ts | 5 +- .../src/plugins/copilot/providers/openai.ts | 1 + .../src/plugins/copilot/providers/types.ts | 12 ++-- .../server/src/plugins/copilot/session.ts | 4 ++ .../blocksuite/ai/provider/copilot-client.ts | 6 ++ 11 files changed, 83 insertions(+), 32 deletions(-) create mode 100644 packages/backend/server/migrations/20250512031140_add_optional_models/migration.sql diff --git a/packages/backend/server/migrations/20250512031140_add_optional_models/migration.sql b/packages/backend/server/migrations/20250512031140_add_optional_models/migration.sql new file mode 100644 index 0000000000..39e96d2208 --- /dev/null +++ b/packages/backend/server/migrations/20250512031140_add_optional_models/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "ai_prompts_metadata" ADD COLUMN "optional_models" VARCHAR[] DEFAULT ARRAY[]::VARCHAR[]; diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 650bec0e6a..80560c7bab 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -397,6 +397,7 @@ model AiPrompt { // it is only used in the frontend and does not affect the backend action String? @db.VarChar model String @db.VarChar + optionalModels String[] @default([]) @db.VarChar @map("optional_models") config Json? @db.Json createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) updatedAt DateTime @default(now()) @map("updated_at") @db.Timestamptz(3) diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 8b31427f06..b11762a2d3 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -62,7 +62,7 @@ export interface ChatEvent { } type CheckResult = { - model: string | undefined; + model: string; hasAttachment?: boolean; }; @@ -94,7 +94,8 @@ export class CopilotController implements BeforeApplicationShutdown { private async checkRequest( userId: string, sessionId: string, - messageId?: string + messageId?: string, + modelId?: string ): Promise { await this.chatSession.checkQuota(userId); const session = await this.chatSession.get(sessionId); @@ -102,7 +103,13 @@ export class CopilotController implements BeforeApplicationShutdown { throw new CopilotSessionNotFound(); } - const ret: CheckResult = { model: session.model }; + const ret: CheckResult = { + model: session.model, + }; + + if (modelId && session.optionalModels.includes(modelId)) { + ret.model = modelId; + } if (messageId && typeof messageId === 'string') { const message = await session.getMessageById(messageId); @@ -116,13 +123,16 @@ export class CopilotController implements BeforeApplicationShutdown { private async chooseTextProvider( userId: string, sessionId: string, - messageId?: string - ): Promise { + messageId?: string, + modelId?: string + ): Promise<{ provider: CopilotTextProvider; model: string }> { const { hasAttachment, model } = await this.checkRequest( userId, sessionId, - messageId + messageId, + modelId ); + let provider = await this.provider.getProviderByCapability( CopilotCapability.TextToText, { model } @@ -138,7 +148,7 @@ export class CopilotController implements BeforeApplicationShutdown { throw new NoCopilotProviderAvailable(); } - return provider; + return { provider, model }; } private async appendSessionMessage( @@ -182,13 +192,17 @@ export class CopilotController implements BeforeApplicationShutdown { const webSearch = Array.isArray(params.webSearch) ? Boolean(params.webSearch[0]) : Boolean(params.webSearch); + const modelId = Array.isArray(params.modelId) + ? params.modelId[0] + : params.modelId; delete params.messageId; delete params.retry; delete params.reasoning; delete params.webSearch; + delete params.modelId; - return { messageId, retry, reasoning, webSearch, params }; + return { messageId, retry, reasoning, webSearch, modelId, params }; } private getSignal(req: Request) { @@ -236,13 +250,14 @@ export class CopilotController implements BeforeApplicationShutdown { const info: any = { sessionId, params }; try { - const { messageId, retry, reasoning, webSearch } = + const { messageId, retry, reasoning, webSearch, modelId } = this.prepareParams(params); - const provider = await this.chooseTextProvider( + const { provider, model } = await this.chooseTextProvider( user.id, sessionId, - messageId + messageId, + modelId ); const [latestMessage, session] = await this.appendSessionMessage( @@ -251,8 +266,8 @@ export class CopilotController implements BeforeApplicationShutdown { retry ); - info.model = session.model; - metrics.ai.counter('chat_calls').add(1, { model: session.model }); + info.model = model; + metrics.ai.counter('chat_calls').add(1, { model }); if (latestMessage) { params = Object.assign({}, params, latestMessage.params, { @@ -264,7 +279,7 @@ export class CopilotController implements BeforeApplicationShutdown { const finalMessage = session.finish(params); info.finalMessage = finalMessage.filter(m => m.role !== 'system'); - const content = await provider.generateText(finalMessage, session.model, { + const content = await provider.generateText(finalMessage, model, { ...session.config.promptConfig, signal: this.getSignal(req), user: user.id, @@ -302,13 +317,14 @@ export class CopilotController implements BeforeApplicationShutdown { const info: any = { sessionId, params, throwInStream: false }; try { - const { messageId, retry, reasoning, webSearch } = + const { messageId, retry, reasoning, webSearch, modelId } = this.prepareParams(params); - const provider = await this.chooseTextProvider( + const { provider, model } = await this.chooseTextProvider( user.id, sessionId, - messageId + messageId, + modelId ); const [latestMessage, session] = await this.appendSessionMessage( @@ -317,8 +333,8 @@ export class CopilotController implements BeforeApplicationShutdown { retry ); - info.model = session.model; - metrics.ai.counter('chat_stream_calls').add(1, { model: session.model }); + info.model = model; + metrics.ai.counter('chat_stream_calls').add(1, { model }); if (latestMessage) { params = Object.assign({}, params, latestMessage.params, { @@ -332,7 +348,7 @@ export class CopilotController implements BeforeApplicationShutdown { info.finalMessage = finalMessage.filter(m => m.role !== 'system'); const source$ = from( - provider.generateTextStream(finalMessage, session.model, { + provider.generateTextStream(finalMessage, model, { ...session.config.promptConfig, signal: this.getSignal(req), user: user.id, diff --git a/packages/backend/server/src/plugins/copilot/prompt/chat-prompt.ts b/packages/backend/server/src/plugins/copilot/prompt/chat-prompt.ts index 233791571a..86d9fcae13 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/chat-prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/chat-prompt.ts @@ -41,6 +41,7 @@ export class ChatPrompt { options.name, options.action || undefined, options.model, + options.optionalModels, options.config, options.messages ); @@ -50,6 +51,7 @@ export class ChatPrompt { public readonly name: string, public readonly action: string | undefined, public readonly model: string, + public readonly optionalModels: string[], public readonly config: PromptConfig | undefined, private readonly messages: PromptMessage[] ) { diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index 6c29859500..017187e1bd 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -5,8 +5,15 @@ import { PromptConfig, PromptMessage } from '../providers'; type Prompt = Omit< AiPrompt, - 'id' | 'createdAt' | 'updatedAt' | 'modified' | 'action' | 'config' + | 'id' + | 'createdAt' + | 'updatedAt' + | 'modified' + | 'action' + | 'config' + | 'optionalModels' > & { + optionalModels?: string[]; action?: string; messages: PromptMessage[]; config?: PromptConfig; @@ -1037,7 +1044,13 @@ Finally, please only send us the content of your continuation in Markdown Format const chat: Prompt[] = [ { name: 'Chat With AFFiNE AI', - model: 'o4-mini', + model: 'gpt-4.1', + optionalModels: [ + 'o3', + 'o4-mini', + 'claude-3-7-sonnet-20250219', + 'claude-3-5-sonnet-20241022', + ], messages: [ { role: 'system', @@ -1161,14 +1174,15 @@ export async function refreshPrompts(db: PrismaClient) { create: { name: prompt.name, action: prompt.action, - config: prompt.config || undefined, + config: prompt.config ?? undefined, model: prompt.model, + optionalModels: prompt.optionalModels, messages: { create: prompt.messages.map((message, idx) => ({ idx, role: message.role, content: message.content, - params: message.params || undefined, + params: message.params ?? undefined, })), }, }, @@ -1177,6 +1191,7 @@ export async function refreshPrompts(db: PrismaClient) { action: prompt.action, config: prompt.config ?? undefined, model: prompt.model, + optionalModels: prompt.optionalModels, updatedAt: new Date(), messages: { deleteMany: {}, @@ -1184,7 +1199,7 @@ export async function refreshPrompts(db: PrismaClient) { idx, role: message.role, content: message.content, - params: message.params || undefined, + params: message.params ?? undefined, })), }, }, diff --git a/packages/backend/server/src/plugins/copilot/prompt/service.ts b/packages/backend/server/src/plugins/copilot/prompt/service.ts index 52d21f6f30..c4d19bf5ac 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/service.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/service.ts @@ -64,6 +64,7 @@ export class PromptService implements OnApplicationBootstrap { name: true, action: true, model: true, + optionalModels: true, config: true, messages: { select: { diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic.ts index e7cbd748d5..2c89de5ae5 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic.ts @@ -34,7 +34,10 @@ export class AnthropicProvider { override readonly type = CopilotProviderType.Anthropic; override readonly capabilities = [CopilotCapability.TextToText]; - override readonly models = ['claude-3-7-sonnet-20250219']; + override readonly models = [ + 'claude-3-7-sonnet-20250219', + 'claude-3-5-sonnet-20241022', + ]; private readonly MAX_STEPS = 20; diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index f9edfc5f54..9bcf2a377b 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -74,6 +74,7 @@ export class OpenAIProvider 'gpt-4.1-2025-04-14', 'gpt-4.1-mini', 'o1', + 'o3', 'o4-mini', // embeddings 'text-embedding-3-large', diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index f83609a448..d975d30739 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -110,12 +110,12 @@ export type CopilotImageOptions = z.infer; export interface CopilotTextToTextProvider extends CopilotProvider { generateText( messages: PromptMessage[], - model?: string, + model: string, options?: CopilotChatOptions ): Promise; generateTextStream( messages: PromptMessage[], - model?: string, + model: string, options?: CopilotChatOptions ): AsyncIterable; } @@ -136,7 +136,7 @@ export interface CopilotTextToImageProvider extends CopilotProvider { ): Promise>; generateImagesStream( messages: PromptMessage[], - model?: string, + model: string, options?: CopilotImageOptions ): AsyncIterable; } @@ -145,12 +145,12 @@ export interface CopilotImageToTextProvider extends CopilotProvider { generateText( messages: PromptMessage[], model: string, - options?: CopilotChatOptions + options: CopilotChatOptions ): Promise; generateTextStream( messages: PromptMessage[], model: string, - options?: CopilotChatOptions + options: CopilotChatOptions ): AsyncIterable; } @@ -162,7 +162,7 @@ export interface CopilotImageToImageProvider extends CopilotProvider { ): Promise>; generateImagesStream( messages: PromptMessage[], - model?: string, + model: string, options?: CopilotImageOptions ): AsyncIterable; } diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index ca71a41b12..5ddd448463 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -45,6 +45,10 @@ export class ChatSession implements AsyncDisposable { return this.state.prompt.model; } + get optionalModels() { + return this.state.prompt.optionalModels; + } + get config() { const { sessionId, diff --git a/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts b/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts index c4c67da79c..ee3c910bc5 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts @@ -352,12 +352,14 @@ export class CopilotClient { messageId, reasoning, webSearch, + modelId, signal, }: { sessionId: string; messageId?: string; reasoning?: boolean; webSearch?: boolean; + modelId?: string; signal?: AbortSignal; }) { let url = `/api/copilot/chat/${sessionId}`; @@ -365,6 +367,7 @@ export class CopilotClient { messageId, reasoning, webSearch, + modelId, }); if (queryString) { url += `?${queryString}`; @@ -380,11 +383,13 @@ export class CopilotClient { messageId, reasoning, webSearch, + modelId, }: { sessionId: string; messageId?: string; reasoning?: boolean; webSearch?: boolean; + modelId?: string; }, endpoint = 'stream' ) { @@ -393,6 +398,7 @@ export class CopilotClient { messageId, reasoning, webSearch, + modelId, }); if (queryString) { url += `?${queryString}`;