feat: improve model resolve (#13601)

fix AI-419
DarkSky
2025-09-18 18:51:12 +08:00
committed by GitHub
parent 89646869e4
commit a0b73cdcec
11 changed files with 271 additions and 30 deletions

View File

@@ -444,3 +444,37 @@ Generated by [AVA](https://avajs.dev).
},
],
}
## should resolve model correctly based on subscription status and prompt config
> should honor requested pro model
'gemini-2.5-pro'
> should fallback to default model
'gemini-2.5-flash'
> should fallback to default model when requesting pro model during trialing
'gemini-2.5-flash'
> should honor requested non-pro model during trialing
'gemini-2.5-flash'
> should pick default model when no requested model during trialing
'gemini-2.5-flash'
> should pick first pro model when no requested model during active
'gemini-2.5-pro'
> should honor requested pro model during active
'claude-sonnet-4@20250514'
> should fallback to default model when requesting non-optional model during active
'gemini-2.5-flash'
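
Taken together, the snapshots pin down the resolution policy as a small decision table. A sketch of that logic for illustration only, with the model lists hard-coded from the test prompt below (the real implementation is ChatSession.resolveModel, later in this commit):

type Tier = 'disabled' | 'trialing' | 'active';

function expectedModel(tier: Tier, requested?: string): string {
  const optional = [
    'gemini-2.5-flash',
    'gemini-2.5-pro',
    'claude-sonnet-4@20250514',
  ];
  const pro = ['gemini-2.5-pro', 'claude-sonnet-4@20250514'];
  const fallback = 'gemini-2.5-flash'; // the prompt's default model
  // anything outside optionalModels always falls back to the default
  if (requested && !optional.includes(requested)) return fallback;
  // non-active users are downgraded when they request a pro model
  if (tier === 'trialing' && requested && pro.includes(requested)) return fallback;
  // active subscribers with no explicit request get the first pro model
  if (tier === 'active' && !requested) return pro[0];
  return requested ?? fallback;
}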

View File

@@ -60,6 +60,9 @@ import {
import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils';
import { WorkflowGraphList } from '../plugins/copilot/workflow/graph';
import { CopilotWorkspaceService } from '../plugins/copilot/workspace';
import { PaymentModule } from '../plugins/payment';
import { SubscriptionService } from '../plugins/payment/service';
import { SubscriptionStatus } from '../plugins/payment/types';
import { MockCopilotProvider } from './mocks';
import { createTestingModule, TestingModule } from './utils';
import { WorkflowTestCases } from './utils/copilot';
@@ -82,6 +85,7 @@ type Context = {
storage: CopilotStorage;
workflow: CopilotWorkflowService;
cronJobs: CopilotCronJobs;
subscription: SubscriptionService;
executors: {
image: CopilotChatImageExecutor;
text: CopilotChatTextExecutor;
@@ -116,6 +120,7 @@ test.before(async t => {
},
},
}),
PaymentModule,
QuotaModule,
StorageModule,
CopilotModule,
@@ -124,6 +129,13 @@ test.before(async t => {
// use real JobQueue for testing
builder.overrideProvider(JobQueue).useClass(JobQueue);
builder.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider);
builder.overrideProvider(SubscriptionService).useClass(
class {
select() {
return { getSubscription: async () => undefined };
}
}
);
},
});
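
This inline override pins the payment layer to a deterministic "no subscription" answer, so pre-existing copilot tests behave as before; only the new model-resolution test swaps in concrete statuses via Sinon. The stub's contract in isolation (a trivial sketch of the class above):

const stub = { select: () => ({ getSubscription: async () => undefined }) };
stub
  .select()
  .getSubscription()
  .then(sub => console.assert(sub === undefined)); // default: no subscription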
@@ -145,6 +157,7 @@ test.before(async t => {
const transcript = module.get(CopilotTranscriptionService);
const workspaceEmbedding = module.get(CopilotWorkspaceService);
const cronJobs = module.get(CopilotCronJobs);
const subscription = module.get(SubscriptionService);
t.context.module = module;
t.context.auth = auth;
@@ -163,6 +176,7 @@ test.before(async t => {
t.context.transcript = transcript;
t.context.workspaceEmbedding = workspaceEmbedding;
t.context.cronJobs = cronJobs;
t.context.subscription = subscription;
t.context.executors = {
image: module.get(CopilotChatImageExecutor),
@@ -2047,3 +2061,90 @@ test('should handle copilot cron jobs correctly', async t => {
toBeGenerateStub.restore();
jobAddStub.restore();
});
test('should resolve model correctly based on subscription status and prompt config', async t => {
const { db, session, subscription } = t.context;
// 1) Seed a prompt that has optionalModels and proModels in config
const promptName = 'resolve-model-test';
await db.aiPrompt.create({
data: {
name: promptName,
model: 'gemini-2.5-flash',
messages: {
create: [{ idx: 0, role: 'system', content: 'test' }],
},
config: { proModels: ['gemini-2.5-pro', 'claude-sonnet-4@20250514'] },
optionalModels: [
'gemini-2.5-flash',
'gemini-2.5-pro',
'claude-sonnet-4@20250514',
],
},
});
// 2) Create a chat session with this prompt
const sessionId = await session.create({
promptName,
docId: 'test',
workspaceId: 'test',
userId,
pinned: false,
});
const s = (await session.get(sessionId))!;
const mockStatus = (status?: SubscriptionStatus) => {
Sinon.restore();
Sinon.stub(subscription, 'select').callsFake(() => ({
// @ts-expect-error mock
getSubscription: async () => (status ? { status } : null),
}));
};
// payment disabled -> allow requested if in optional; pro not blocked
{
const model1 = await s.resolveModel(false, 'gemini-2.5-pro');
t.snapshot(model1, 'should honor requested pro model');
const model2 = await s.resolveModel(false, 'not-in-optional');
t.snapshot(model2, 'should fallback to default model');
}
// payment enabled + trialing: requesting pro should fallback to default
{
mockStatus(SubscriptionStatus.Trialing);
const model3 = await s.resolveModel(true, 'gemini-2.5-pro');
t.snapshot(
model3,
'should fallback to default model when requesting pro model during trialing'
);
const model4 = await s.resolveModel(true, 'gemini-2.5-flash');
t.snapshot(model4, 'should honor requested non-pro model during trialing');
const model5 = await s.resolveModel(true);
t.snapshot(
model5,
'should pick default model when no requested model during trialing'
);
}
// payment enabled + active: without requested -> first pro; requested pro should be honored
{
mockStatus(SubscriptionStatus.Active);
const model6 = await s.resolveModel(true);
t.snapshot(
model6,
'should pick first pro model when no requested model during active'
);
const model7 = await s.resolveModel(true, 'claude-sonnet-4@20250514');
t.snapshot(model7, 'should honor requested pro model during active');
const model8 = await s.resolveModel(true, 'not-in-optional');
t.snapshot(
model8,
'should fallback to default model when requesting non-optional model during active'
);
}
});

View File

@@ -51,7 +51,7 @@ defineModuleConfig('copilot', {
override_enabled: false,
scenarios: {
audio_transcribing: 'gemini-2.5-flash',
- chat: 'claude-sonnet-4@20250514',
+ chat: 'gemini-2.5-flash',
embedding: 'gemini-embedding-001',
image: 'gpt-image-1',
rerank: 'gpt-4.1',

View File

@@ -44,6 +44,7 @@ import {
NoCopilotProviderAvailable,
UnsplashIsNotConfigured,
} from '../../base';
import { ServerFeature, ServerService } from '../../core';
import { CurrentUser, Public } from '../../core/auth';
import { CopilotContextService } from './context';
import {
@@ -75,6 +76,7 @@ export class CopilotController implements BeforeApplicationShutdown {
constructor(
private readonly config: Config,
private readonly server: ServerService,
private readonly chatSession: ChatSessionService,
private readonly context: CopilotContextService,
private readonly provider: CopilotProviderFactory,
@@ -112,10 +114,10 @@ export class CopilotController implements BeforeApplicationShutdown {
throw new CopilotSessionNotFound();
}
- const model =
-   modelId && session.optionalModels.includes(modelId)
-     ? modelId
-     : session.model;
+ const model = await session.resolveModel(
+   this.server.features.includes(ServerFeature.Payment),
+   modelId
+ );
const hasAttachment = messageId
? !!(await session.getMessageById(messageId)).attachments?.length

View File

@@ -1928,7 +1928,7 @@ Now apply the \`updates\` to the \`content\`, following the intent in \`op\`, an
];
const CHAT_PROMPT: Omit<Prompt, 'name'> = {
- model: 'claude-sonnet-4@20250514',
+ model: 'gemini-2.5-flash',
optionalModels: [
'gpt-4.1',
'gpt-5',
@@ -2099,6 +2099,13 @@ Below is the user's query. Please respond in the user's preferred language witho
'codeArtifact',
'blobRead',
],
proModels: [
'gemini-2.5-pro',
'claude-opus-4@20250514',
'claude-sonnet-4@20250514',
'claude-3-7-sonnet@20250219',
'claude-3-5-sonnet-v2@20241022',
],
},
};
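
One coupling to keep in mind: resolveModel (added later in this commit) normalizes every candidate against optionalModels before returning it, so a pro model is reachable only if it appears in both lists. A quick invariant check of that assumption (hypothetical, not part of the commit; assumes the fields shown above are present on CHAT_PROMPT):

const proModels = CHAT_PROMPT.config?.proModels ?? [];
const missing = proModels.filter(m => !CHAT_PROMPT.optionalModels.includes(m));
console.assert(
  missing.length === 0,
  `pro models missing from optionalModels: ${missing.join(', ')}`
);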

View File

@@ -4,6 +4,10 @@ import {
type OpenAIProvider as VercelOpenAIProvider,
OpenAIResponsesProviderOptions,
} from '@ai-sdk/openai';
import {
createOpenAICompatible,
type OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
} from '@ai-sdk/openai-compatible';
import {
AISDKError,
embedMany,
@@ -18,6 +22,7 @@ import { z } from 'zod';
import {
CopilotPromptInvalid,
CopilotProviderNotSupported,
CopilotProviderSideError,
metrics,
UserFriendlyError,
@@ -47,6 +52,7 @@ export const DEFAULT_DIMENSIONS = 256;
export type OpenAIConfig = {
apiKey: string;
baseURL?: string;
oldApiStyle?: boolean;
};
const ModelListSchema = z.object({
@@ -296,7 +302,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
},
];
- #instance!: VercelOpenAIProvider;
+ #instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider;
override configured(): boolean {
return !!this.config.apiKey;
@@ -304,10 +310,17 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
protected override setup() {
super.setup();
- this.#instance = createOpenAI({
-   apiKey: this.config.apiKey,
-   baseURL: this.config.baseURL,
- });
+ this.#instance =
+   this.config.oldApiStyle && this.config.baseURL
+     ? createOpenAICompatible({
+         name: 'openai-compatible-old-style',
+         apiKey: this.config.apiKey,
+         baseURL: this.config.baseURL,
+       })
+     : createOpenAI({
+         apiKey: this.config.apiKey,
+         baseURL: this.config.baseURL,
+       });
}
private handleError(
@@ -341,7 +354,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
override async refreshOnlineModels() {
try {
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
- if (baseUrl && !this.onlineModelList.length) {
+ if (this.config.apiKey && baseUrl && !this.onlineModelList.length) {
const { data } = await fetch(`${baseUrl}/models`, {
headers: {
Authorization: `Bearer ${this.config.apiKey}`,
@@ -361,7 +374,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
toolName: CopilotChatTools,
model: string
): [string, Tool?] | undefined {
- if (toolName === 'webSearch' && !this.isReasoningModel(model)) {
+ if (
+   toolName === 'webSearch' &&
+   'responses' in this.#instance &&
+   !this.isReasoningModel(model)
+ ) {
return ['web_search_preview', openai.tools.webSearchPreview({})];
} else if (toolName === 'docEdit') {
return ['doc_edit', undefined];
@@ -374,10 +391,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: PromptMessage[],
options: CopilotChatOptions = {}
): Promise<string> {
- const fullCond = {
-   ...cond,
-   outputType: ModelOutputType.Text,
- };
+ const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
@@ -386,7 +400,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const [system, msgs] = await chatToGPTMessage(messages);
- const modelInstance = this.#instance.responses(model.id);
+ const modelInstance =
+   'responses' in this.#instance
+     ? this.#instance.responses(model.id)
+     : this.#instance(model.id);
const { text } = await generateText({
model: modelInstance,
@@ -507,7 +524,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
throw new CopilotPromptInvalid('Schema is required');
}
- const modelInstance = this.#instance.responses(model.id);
+ const modelInstance =
+   'responses' in this.#instance
+     ? this.#instance.responses(model.id)
+     : this.#instance(model.id);
const { object } = await generateObject({
model: modelInstance,
@@ -539,7 +559,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
await this.checkParams({ messages: [], cond: fullCond, options });
const model = this.selectModel(fullCond);
// get the log probability of "yes"/"no"
- const instance = this.#instance.chat(model.id);
+ const instance =
+   'chat' in this.#instance
+     ? this.#instance.chat(model.id)
+     : this.#instance(model.id);
const scores = await Promise.all(
chunkMessages.map(async messages => {
@@ -600,7 +623,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
options: CopilotChatOptions = {}
) {
const [system, msgs] = await chatToGPTMessage(messages);
- const modelInstance = this.#instance.responses(model.id);
+ const modelInstance =
+   'responses' in this.#instance
+     ? this.#instance.responses(model.id)
+     : this.#instance(model.id);
const { fullStream } = streamText({
model: modelInstance,
system,
@@ -685,6 +711,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
if (!('image' in this.#instance)) {
throw new CopilotProviderNotSupported({
provider: this.type,
kind: 'image',
});
}
metrics.ai
.counter('generate_images_stream_calls')
.add(1, { model: model.id });
@@ -735,6 +768,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
await this.checkParams({ embeddings: messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
if (!('embedding' in this.#instance)) {
throw new CopilotProviderNotSupported({
provider: this.type,
kind: 'embedding',
});
}
try {
metrics.ai
.counter('generate_embedding_calls')
@@ -775,6 +815,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
private isReasoningModel(model: string) {
// o series reasoning models
- return model.startsWith('o');
+ return model.startsWith('o') || model.startsWith('gpt-5');
}
}
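
Summary of the provider split in this file: createOpenAI() returns a provider object that also exposes .responses(), .chat(), .image(), and .embedding() factories, while createOpenAICompatible() returns a plain callable, so each 'x' in this.#instance check above doubles as a TypeScript union narrowing, and the new image/embedding guards turn an unsupported call into a structured CopilotProviderNotSupported instead of a runtime TypeError. A minimal standalone sketch of the pattern (endpoint and env vars are made up):

import { createOpenAI } from '@ai-sdk/openai';
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';

const instance = process.env.OLD_API_STYLE // stand-in for config.oldApiStyle
  ? createOpenAICompatible({
      name: 'openai-compatible-old-style',
      apiKey: process.env.LLM_API_KEY,
      baseURL: 'https://llm.internal.example.com/v1',
    })
  : createOpenAI({ apiKey: process.env.LLM_API_KEY });

// 'responses' exists only on the full OpenAI provider, so the `in` check
// narrows the union and picks the matching language-model factory:
const model =
  'responses' in instance ? instance.responses('gpt-4.1') : instance('gpt-4.1');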

View File

@@ -80,6 +80,7 @@ export const PromptToolsSchema = z
export const PromptConfigStrictSchema = z.object({
tools: PromptToolsSchema.nullable().optional(),
proModels: z.array(z.string()).nullable().optional(),
// params requirements
requireContent: z.boolean().nullable().optional(),
requireAttachment: z.boolean().nullable().optional(),
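
proModels joins the strict prompt-config schema as a nullable, optional string array, which is what lets existing prompt rows without the field keep validating. A round-trip sketch (assuming the remaining fields are optional, as the excerpt shows):

const parsed = PromptConfigStrictSchema.parse({
  proModels: ['gemini-2.5-pro', 'claude-sonnet-4@20250514'],
});
console.log(parsed.proModels?.[0]); // 'gemini-2.5-pro'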

View File

@@ -25,6 +25,8 @@ import {
type UpdateChatSession,
UpdateChatSessionOptions,
} from '../../models';
import { SubscriptionService } from '../payment/service';
import { SubscriptionPlan, SubscriptionStatus } from '../payment/types';
import { ChatMessageCache } from './message';
import { ChatPrompt, PromptService } from './prompt';
import {
@@ -58,6 +60,7 @@ declare global {
export class ChatSession implements AsyncDisposable {
private stashMessageCount = 0;
constructor(
private readonly moduleRef: ModuleRef,
private readonly messageCache: ChatMessageCache,
private readonly state: ChatSessionState,
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
@@ -72,6 +75,10 @@ export class ChatSession implements AsyncDisposable {
return this.state.prompt.optionalModels;
}
get proModels() {
return this.state.prompt.config?.proModels || [];
}
get config() {
const {
sessionId,
@@ -93,6 +100,50 @@ export class ChatSession implements AsyncDisposable {
return this.state.messages.findLast(m => m.role === 'user');
}
async resolveModel(
hasPayment: boolean,
requestedModelId?: string
): Promise<string> {
const defaultModel = this.model;
const normalize = (m?: string) =>
!!m && this.optionalModels.includes(m) ? m : defaultModel;
const isPro = (m?: string) => !!m && this.proModels.includes(m);
// try to resolve the payment subscription service lazily
let paymentEnabled = hasPayment;
let isUserAIPro = false;
try {
if (paymentEnabled) {
const sub = this.moduleRef.get(SubscriptionService, {
strict: false,
});
const subscription = await sub
.select(SubscriptionPlan.AI)
.getSubscription({
userId: this.config.userId,
plan: SubscriptionPlan.AI,
} as any);
isUserAIPro = subscription?.status === SubscriptionStatus.Active;
}
} catch {
// payment not available -> skip checks
paymentEnabled = false;
}
if (paymentEnabled) {
if (isUserAIPro) {
if (!requestedModelId) {
const firstPro = this.proModels[0];
return normalize(firstPro);
}
} else if (isPro(requestedModelId)) {
return defaultModel;
}
}
return normalize(requestedModelId);
}
push(message: ChatMessage) {
if (
this.state.prompt.action &&
@@ -539,12 +590,17 @@ export class ChatSessionService {
async get(sessionId: string): Promise<ChatSession | null> {
const state = await this.getSessionInfo(sessionId);
if (state) {
- return new ChatSession(this.messageCache, state, async state => {
-   await this.models.copilotSession.updateMessages(state);
-   if (!state.prompt.action) {
-     await this.jobs.add('copilot.session.generateTitle', { sessionId });
-   }
- });
+ return new ChatSession(
+   this.moduleRef,
+   this.messageCache,
+   state,
+   async state => {
+     await this.models.copilotSession.updateMessages(state);
+     if (!state.prompt.action) {
+       await this.jobs.add('copilot.session.generateTitle', { sessionId });
+     }
+   }
+ );
}
return null;
}
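
resolveModel's precedence, top to bottom: when payment is enabled and the user holds an active AI subscription, an empty request resolves to the first entry of proModels; when payment is enabled and the user is not active, a requested pro model is downgraded to the prompt's default; every remaining case is normalized against optionalModels with the default as fallback, and a ModuleRef failure silently disables the payment checks. A hedged call-site sketch (resolveModel and ServerFeature.Payment come from this commit; chatSessionService, server, and sessionId are assumed bindings):

const session = (await chatSessionService.get(sessionId))!;
const model = await session.resolveModel(
  server.features.includes(ServerFeature.Payment),
  // downgraded to the session default when payment is on and the user
  // has no active AI subscription:
  'gemini-2.5-pro'
);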

View File

@@ -89,7 +89,7 @@ export class SubscriptionService {
return this.stripeProvider.stripe;
}
- private select(plan: SubscriptionPlan): SubscriptionManager {
+ select(plan: SubscriptionPlan): SubscriptionManager {
switch (plan) {
case SubscriptionPlan.Team:
return this.workspaceManager;
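
Widening select() from private to public is what lets ChatSession.resolveModel reach the subscription manager lazily through ModuleRef instead of taking a hard constructor dependency on the payment module. The lookup shape, condensed from the session diff above (userId is an assumed binding):

// strict: false resolves a provider registered by another module; if the
// PaymentModule is not loaded, moduleRef.get throws and resolveModel falls
// back to its no-payment path.
const service = moduleRef.get(SubscriptionService, { strict: false });
const subscription = await service
  .select(SubscriptionPlan.AI)
  .getSubscription({ userId, plan: SubscriptionPlan.AI } as any);
const isUserAIPro = subscription?.status === SubscriptionStatus.Active;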