diff --git a/.docker/selfhost/schema.json b/.docker/selfhost/schema.json index 0978362246..d8356b1dc1 100644 --- a/.docker/selfhost/schema.json +++ b/.docker/selfhost/schema.json @@ -669,12 +669,12 @@ }, "scenarios": { "type": "object", - "description": "Use custom models in scenarios and override default settings.\n@default {\"override_enabled\":false,\"scenarios\":{\"audio_transcribing\":\"gemini-2.5-flash\",\"chat\":\"claude-sonnet-4@20250514\",\"embedding\":\"gemini-embedding-001\",\"image\":\"gpt-image-1\",\"rerank\":\"gpt-4.1\",\"coding\":\"claude-sonnet-4@20250514\",\"complex_text_generation\":\"gpt-4o-2024-08-06\",\"quick_decision_making\":\"gpt-5-mini\",\"quick_text_generation\":\"gemini-2.5-flash\",\"polish_and_summarize\":\"gemini-2.5-flash\"}}", + "description": "Use custom models in scenarios and override default settings.\n@default {\"override_enabled\":false,\"scenarios\":{\"audio_transcribing\":\"gemini-2.5-flash\",\"chat\":\"gemini-2.5-flash\",\"embedding\":\"gemini-embedding-001\",\"image\":\"gpt-image-1\",\"rerank\":\"gpt-4.1\",\"coding\":\"claude-sonnet-4@20250514\",\"complex_text_generation\":\"gpt-4o-2024-08-06\",\"quick_decision_making\":\"gpt-5-mini\",\"quick_text_generation\":\"gemini-2.5-flash\",\"polish_and_summarize\":\"gemini-2.5-flash\"}}", "default": { "override_enabled": false, "scenarios": { "audio_transcribing": "gemini-2.5-flash", - "chat": "claude-sonnet-4@20250514", + "chat": "gemini-2.5-flash", "embedding": "gemini-embedding-001", "image": "gpt-image-1", "rerank": "gpt-4.1", diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md index 1809ad50c0..0797e8989a 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md @@ -444,3 +444,37 @@ Generated by [AVA](https://avajs.dev). }, ], } + +## should resolve model correctly based on subscription status and prompt config + +> should honor requested pro model + + 'gemini-2.5-pro' + +> should fallback to default model + + 'gemini-2.5-flash' + +> should fallback to default model when requesting pro model during trialing + + 'gemini-2.5-flash' + +> should honor requested non-pro model during trialing + + 'gemini-2.5-flash' + +> should pick default model when no requested model during trialing + + 'gemini-2.5-flash' + +> should pick first pro model when no requested model during active + + 'gemini-2.5-pro' + +> should honor requested pro model during active + + 'claude-sonnet-4@20250514' + +> should fallback to default model when requesting non-optional model during active + + 'gemini-2.5-flash' diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap index 11157ea0b6..cd677a05b9 100644 Binary files a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap and b/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap differ diff --git a/packages/backend/server/src/__tests__/copilot.spec.ts b/packages/backend/server/src/__tests__/copilot.spec.ts index 978cee77f5..6a3b69e755 100644 --- a/packages/backend/server/src/__tests__/copilot.spec.ts +++ b/packages/backend/server/src/__tests__/copilot.spec.ts @@ -60,6 +60,9 @@ import { import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils'; import { WorkflowGraphList } from '../plugins/copilot/workflow/graph'; import { CopilotWorkspaceService } from '../plugins/copilot/workspace'; +import { PaymentModule } from '../plugins/payment'; +import { SubscriptionService } from '../plugins/payment/service'; +import { SubscriptionStatus } from '../plugins/payment/types'; import { MockCopilotProvider } from './mocks'; import { createTestingModule, TestingModule } from './utils'; import { WorkflowTestCases } from './utils/copilot'; @@ -82,6 +85,7 @@ type Context = { storage: CopilotStorage; workflow: CopilotWorkflowService; cronJobs: CopilotCronJobs; + subscription: SubscriptionService; executors: { image: CopilotChatImageExecutor; text: CopilotChatTextExecutor; @@ -116,6 +120,7 @@ test.before(async t => { }, }, }), + PaymentModule, QuotaModule, StorageModule, CopilotModule, @@ -124,6 +129,13 @@ test.before(async t => { // use real JobQueue for testing builder.overrideProvider(JobQueue).useClass(JobQueue); builder.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider); + builder.overrideProvider(SubscriptionService).useClass( + class { + select() { + return { getSubscription: async () => undefined }; + } + } + ); }, }); @@ -145,6 +157,7 @@ test.before(async t => { const transcript = module.get(CopilotTranscriptionService); const workspaceEmbedding = module.get(CopilotWorkspaceService); const cronJobs = module.get(CopilotCronJobs); + const subscription = module.get(SubscriptionService); t.context.module = module; t.context.auth = auth; @@ -163,6 +176,7 @@ test.before(async t => { t.context.transcript = transcript; t.context.workspaceEmbedding = workspaceEmbedding; t.context.cronJobs = cronJobs; + t.context.subscription = subscription; t.context.executors = { image: module.get(CopilotChatImageExecutor), @@ -2047,3 +2061,90 @@ test('should handle copilot cron jobs correctly', async t => { toBeGenerateStub.restore(); jobAddStub.restore(); }); + +test('should resolve model correctly based on subscription status and prompt config', async t => { + const { db, session, subscription } = t.context; + + // 1) Seed a prompt that has optionalModels and proModels in config + const promptName = 'resolve-model-test'; + await db.aiPrompt.create({ + data: { + name: promptName, + model: 'gemini-2.5-flash', + messages: { + create: [{ idx: 0, role: 'system', content: 'test' }], + }, + config: { proModels: ['gemini-2.5-pro', 'claude-sonnet-4@20250514'] }, + optionalModels: [ + 'gemini-2.5-flash', + 'gemini-2.5-pro', + 'claude-sonnet-4@20250514', + ], + }, + }); + + // 2) Create a chat session with this prompt + const sessionId = await session.create({ + promptName, + docId: 'test', + workspaceId: 'test', + userId, + pinned: false, + }); + const s = (await session.get(sessionId))!; + + const mockStatus = (status?: SubscriptionStatus) => { + Sinon.restore(); + Sinon.stub(subscription, 'select').callsFake(() => ({ + // @ts-expect-error mock + getSubscription: async () => (status ? { status } : null), + })); + }; + + // payment disabled -> allow requested if in optional; pro not blocked + { + const model1 = await s.resolveModel(false, 'gemini-2.5-pro'); + t.snapshot(model1, 'should honor requested pro model'); + + const model2 = await s.resolveModel(false, 'not-in-optional'); + t.snapshot(model2, 'should fallback to default model'); + } + + // payment enabled + trialing: requesting pro should fallback to default + { + mockStatus(SubscriptionStatus.Trialing); + const model3 = await s.resolveModel(true, 'gemini-2.5-pro'); + t.snapshot( + model3, + 'should fallback to default model when requesting pro model during trialing' + ); + + const model4 = await s.resolveModel(true, 'gemini-2.5-flash'); + t.snapshot(model4, 'should honor requested non-pro model during trialing'); + + const model5 = await s.resolveModel(true); + t.snapshot( + model5, + 'should pick default model when no requested model during trialing' + ); + } + + // payment enabled + active: without requested -> first pro; requested pro should be honored + { + mockStatus(SubscriptionStatus.Active); + const model6 = await s.resolveModel(true); + t.snapshot( + model6, + 'should pick first pro model when no requested model during active' + ); + + const model7 = await s.resolveModel(true, 'claude-sonnet-4@20250514'); + t.snapshot(model7, 'should honor requested pro model during active'); + + const model8 = await s.resolveModel(true, 'not-in-optional'); + t.snapshot( + model8, + 'should fallback to default model when requesting non-optional model during active' + ); + } +}); diff --git a/packages/backend/server/src/plugins/copilot/config.ts b/packages/backend/server/src/plugins/copilot/config.ts index 813e746017..bc025f6612 100644 --- a/packages/backend/server/src/plugins/copilot/config.ts +++ b/packages/backend/server/src/plugins/copilot/config.ts @@ -51,7 +51,7 @@ defineModuleConfig('copilot', { override_enabled: false, scenarios: { audio_transcribing: 'gemini-2.5-flash', - chat: 'claude-sonnet-4@20250514', + chat: 'gemini-2.5-flash', embedding: 'gemini-embedding-001', image: 'gpt-image-1', rerank: 'gpt-4.1', diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 240a1051a7..c54f76ecc9 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -44,6 +44,7 @@ import { NoCopilotProviderAvailable, UnsplashIsNotConfigured, } from '../../base'; +import { ServerFeature, ServerService } from '../../core'; import { CurrentUser, Public } from '../../core/auth'; import { CopilotContextService } from './context'; import { @@ -75,6 +76,7 @@ export class CopilotController implements BeforeApplicationShutdown { constructor( private readonly config: Config, + private readonly server: ServerService, private readonly chatSession: ChatSessionService, private readonly context: CopilotContextService, private readonly provider: CopilotProviderFactory, @@ -112,10 +114,10 @@ export class CopilotController implements BeforeApplicationShutdown { throw new CopilotSessionNotFound(); } - const model = - modelId && session.optionalModels.includes(modelId) - ? modelId - : session.model; + const model = await session.resolveModel( + this.server.features.includes(ServerFeature.Payment), + modelId + ); const hasAttachment = messageId ? !!(await session.getMessageById(messageId)).attachments?.length diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index 1b53ea9836..07e9d1ad5d 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -1928,7 +1928,7 @@ Now apply the \`updates\` to the \`content\`, following the intent in \`op\`, an ]; const CHAT_PROMPT: Omit = { - model: 'claude-sonnet-4@20250514', + model: 'gemini-2.5-flash', optionalModels: [ 'gpt-4.1', 'gpt-5', @@ -2099,6 +2099,13 @@ Below is the user's query. Please respond in the user's preferred language witho 'codeArtifact', 'blobRead', ], + proModels: [ + 'gemini-2.5-pro', + 'claude-opus-4@20250514', + 'claude-sonnet-4@20250514', + 'claude-3-7-sonnet@20250219', + 'claude-3-5-sonnet-v2@20241022', + ], }, }; diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 12160a1adb..ba872bab8c 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -4,6 +4,10 @@ import { type OpenAIProvider as VercelOpenAIProvider, OpenAIResponsesProviderOptions, } from '@ai-sdk/openai'; +import { + createOpenAICompatible, + type OpenAICompatibleProvider as VercelOpenAICompatibleProvider, +} from '@ai-sdk/openai-compatible'; import { AISDKError, embedMany, @@ -18,6 +22,7 @@ import { z } from 'zod'; import { CopilotPromptInvalid, + CopilotProviderNotSupported, CopilotProviderSideError, metrics, UserFriendlyError, @@ -47,6 +52,7 @@ export const DEFAULT_DIMENSIONS = 256; export type OpenAIConfig = { apiKey: string; baseURL?: string; + oldApiStyle?: boolean; }; const ModelListSchema = z.object({ @@ -296,7 +302,7 @@ export class OpenAIProvider extends CopilotProvider { }, ]; - #instance!: VercelOpenAIProvider; + #instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider; override configured(): boolean { return !!this.config.apiKey; @@ -304,10 +310,17 @@ export class OpenAIProvider extends CopilotProvider { protected override setup() { super.setup(); - this.#instance = createOpenAI({ - apiKey: this.config.apiKey, - baseURL: this.config.baseURL, - }); + this.#instance = + this.config.oldApiStyle && this.config.baseURL + ? createOpenAICompatible({ + name: 'openai-compatible-old-style', + apiKey: this.config.apiKey, + baseURL: this.config.baseURL, + }) + : createOpenAI({ + apiKey: this.config.apiKey, + baseURL: this.config.baseURL, + }); } private handleError( @@ -341,7 +354,7 @@ export class OpenAIProvider extends CopilotProvider { override async refreshOnlineModels() { try { const baseUrl = this.config.baseURL || 'https://api.openai.com/v1'; - if (baseUrl && !this.onlineModelList.length) { + if (this.config.apiKey && baseUrl && !this.onlineModelList.length) { const { data } = await fetch(`${baseUrl}/models`, { headers: { Authorization: `Bearer ${this.config.apiKey}`, @@ -361,7 +374,11 @@ export class OpenAIProvider extends CopilotProvider { toolName: CopilotChatTools, model: string ): [string, Tool?] | undefined { - if (toolName === 'webSearch' && !this.isReasoningModel(model)) { + if ( + toolName === 'webSearch' && + 'responses' in this.#instance && + !this.isReasoningModel(model) + ) { return ['web_search_preview', openai.tools.webSearchPreview({})]; } else if (toolName === 'docEdit') { return ['doc_edit', undefined]; @@ -374,10 +391,7 @@ export class OpenAIProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotChatOptions = {} ): Promise { - const fullCond = { - ...cond, - outputType: ModelOutputType.Text, - }; + const fullCond = { ...cond, outputType: ModelOutputType.Text }; await this.checkParams({ messages, cond: fullCond, options }); const model = this.selectModel(fullCond); @@ -386,7 +400,10 @@ export class OpenAIProvider extends CopilotProvider { const [system, msgs] = await chatToGPTMessage(messages); - const modelInstance = this.#instance.responses(model.id); + const modelInstance = + 'responses' in this.#instance + ? this.#instance.responses(model.id) + : this.#instance(model.id); const { text } = await generateText({ model: modelInstance, @@ -507,7 +524,10 @@ export class OpenAIProvider extends CopilotProvider { throw new CopilotPromptInvalid('Schema is required'); } - const modelInstance = this.#instance.responses(model.id); + const modelInstance = + 'responses' in this.#instance + ? this.#instance.responses(model.id) + : this.#instance(model.id); const { object } = await generateObject({ model: modelInstance, @@ -539,7 +559,10 @@ export class OpenAIProvider extends CopilotProvider { await this.checkParams({ messages: [], cond: fullCond, options }); const model = this.selectModel(fullCond); // get the log probability of "yes"/"no" - const instance = this.#instance.chat(model.id); + const instance = + 'chat' in this.#instance + ? this.#instance.chat(model.id) + : this.#instance(model.id); const scores = await Promise.all( chunkMessages.map(async messages => { @@ -600,7 +623,10 @@ export class OpenAIProvider extends CopilotProvider { options: CopilotChatOptions = {} ) { const [system, msgs] = await chatToGPTMessage(messages); - const modelInstance = this.#instance.responses(model.id); + const modelInstance = + 'responses' in this.#instance + ? this.#instance.responses(model.id) + : this.#instance(model.id); const { fullStream } = streamText({ model: modelInstance, system, @@ -685,6 +711,13 @@ export class OpenAIProvider extends CopilotProvider { await this.checkParams({ messages, cond: fullCond, options }); const model = this.selectModel(fullCond); + if (!('image' in this.#instance)) { + throw new CopilotProviderNotSupported({ + provider: this.type, + kind: 'image', + }); + } + metrics.ai .counter('generate_images_stream_calls') .add(1, { model: model.id }); @@ -735,6 +768,13 @@ export class OpenAIProvider extends CopilotProvider { await this.checkParams({ embeddings: messages, cond: fullCond, options }); const model = this.selectModel(fullCond); + if (!('embedding' in this.#instance)) { + throw new CopilotProviderNotSupported({ + provider: this.type, + kind: 'embedding', + }); + } + try { metrics.ai .counter('generate_embedding_calls') @@ -775,6 +815,6 @@ export class OpenAIProvider extends CopilotProvider { private isReasoningModel(model: string) { // o series reasoning models - return model.startsWith('o'); + return model.startsWith('o') || model.startsWith('gpt-5'); } } diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index fb4cb2ae91..e568be80e6 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -80,6 +80,7 @@ export const PromptToolsSchema = z export const PromptConfigStrictSchema = z.object({ tools: PromptToolsSchema.nullable().optional(), + proModels: z.array(z.string()).nullable().optional(), // params requirements requireContent: z.boolean().nullable().optional(), requireAttachment: z.boolean().nullable().optional(), diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index a88b7435fc..6ce10c5add 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -25,6 +25,8 @@ import { type UpdateChatSession, UpdateChatSessionOptions, } from '../../models'; +import { SubscriptionService } from '../payment/service'; +import { SubscriptionPlan, SubscriptionStatus } from '../payment/types'; import { ChatMessageCache } from './message'; import { ChatPrompt, PromptService } from './prompt'; import { @@ -58,6 +60,7 @@ declare global { export class ChatSession implements AsyncDisposable { private stashMessageCount = 0; constructor( + private readonly moduleRef: ModuleRef, private readonly messageCache: ChatMessageCache, private readonly state: ChatSessionState, private readonly dispose?: (state: ChatSessionState) => Promise, @@ -72,6 +75,10 @@ export class ChatSession implements AsyncDisposable { return this.state.prompt.optionalModels; } + get proModels() { + return this.state.prompt.config?.proModels || []; + } + get config() { const { sessionId, @@ -93,6 +100,50 @@ export class ChatSession implements AsyncDisposable { return this.state.messages.findLast(m => m.role === 'user'); } + async resolveModel( + hasPayment: boolean, + requestedModelId?: string + ): Promise { + const defaultModel = this.model; + const normalize = (m?: string) => + !!m && this.optionalModels.includes(m) ? m : defaultModel; + const isPro = (m?: string) => !!m && this.proModels.includes(m); + + // try resolve payment subscription service lazily + let paymentEnabled = hasPayment; + let isUserAIPro = false; + try { + if (paymentEnabled) { + const sub = this.moduleRef.get(SubscriptionService, { + strict: false, + }); + const subscription = await sub + .select(SubscriptionPlan.AI) + .getSubscription({ + userId: this.config.userId, + plan: SubscriptionPlan.AI, + } as any); + isUserAIPro = subscription?.status === SubscriptionStatus.Active; + } + } catch { + // payment not available -> skip checks + paymentEnabled = false; + } + + if (paymentEnabled) { + if (isUserAIPro) { + if (!requestedModelId) { + const firstPro = this.proModels[0]; + return normalize(firstPro); + } + } else if (isPro(requestedModelId)) { + return defaultModel; + } + } + + return normalize(requestedModelId); + } + push(message: ChatMessage) { if ( this.state.prompt.action && @@ -539,12 +590,17 @@ export class ChatSessionService { async get(sessionId: string): Promise { const state = await this.getSessionInfo(sessionId); if (state) { - return new ChatSession(this.messageCache, state, async state => { - await this.models.copilotSession.updateMessages(state); - if (!state.prompt.action) { - await this.jobs.add('copilot.session.generateTitle', { sessionId }); + return new ChatSession( + this.moduleRef, + this.messageCache, + state, + async state => { + await this.models.copilotSession.updateMessages(state); + if (!state.prompt.action) { + await this.jobs.add('copilot.session.generateTitle', { sessionId }); + } } - }); + ); } return null; } diff --git a/packages/backend/server/src/plugins/payment/service.ts b/packages/backend/server/src/plugins/payment/service.ts index a4d13d0871..f08e07c7f7 100644 --- a/packages/backend/server/src/plugins/payment/service.ts +++ b/packages/backend/server/src/plugins/payment/service.ts @@ -89,7 +89,7 @@ export class SubscriptionService { return this.stripeProvider.stripe; } - private select(plan: SubscriptionPlan): SubscriptionManager { + select(plan: SubscriptionPlan): SubscriptionManager { switch (plan) { case SubscriptionPlan.Team: return this.workspaceManager;