feat: improve model resolve (#13601)

fix AI-419
DarkSky
2025-09-18 18:51:12 +08:00
committed by GitHub
parent 89646869e4
commit a0b73cdcec
11 changed files with 271 additions and 30 deletions

View File

@@ -51,7 +51,7 @@ defineModuleConfig('copilot', {
override_enabled: false,
scenarios: {
audio_transcribing: 'gemini-2.5-flash',
-chat: 'claude-sonnet-4@20250514',
+chat: 'gemini-2.5-flash',
embedding: 'gemini-embedding-001',
image: 'gpt-image-1',
rerank: 'gpt-4.1',
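
The hunk above only swaps the default model for the chat scenario from claude-sonnet-4@20250514 to gemini-2.5-flash. As a rough illustration (the type names and lookup are assumptions, not part of this commit), the scenario map can be read as a plain record keyed by scenario name:

// Hypothetical sketch of the scenario map shape; not the actual AFFiNE config types.
type CopilotScenario =
  | 'audio_transcribing'
  | 'chat'
  | 'embedding'
  | 'image'
  | 'rerank';
type ScenarioModels = Record<CopilotScenario, string>;

const scenarios: ScenarioModels = {
  audio_transcribing: 'gemini-2.5-flash',
  chat: 'gemini-2.5-flash', // was 'claude-sonnet-4@20250514' before this commit
  embedding: 'gemini-embedding-001',
  image: 'gpt-image-1',
  rerank: 'gpt-4.1',
};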

View File

@@ -44,6 +44,7 @@ import {
NoCopilotProviderAvailable,
UnsplashIsNotConfigured,
} from '../../base';
import { ServerFeature, ServerService } from '../../core';
import { CurrentUser, Public } from '../../core/auth';
import { CopilotContextService } from './context';
import {
@@ -75,6 +76,7 @@ export class CopilotController implements BeforeApplicationShutdown {
constructor(
private readonly config: Config,
private readonly server: ServerService,
private readonly chatSession: ChatSessionService,
private readonly context: CopilotContextService,
private readonly provider: CopilotProviderFactory,
@@ -112,10 +114,10 @@ export class CopilotController implements BeforeApplicationShutdown {
throw new CopilotSessionNotFound();
}
-const model =
-modelId && session.optionalModels.includes(modelId)
-? modelId
-: session.model;
+const model = await session.resolveModel(
+this.server.features.includes(ServerFeature.Payment),
+modelId
+);
const hasAttachment = messageId
? !!(await session.getMessageById(messageId)).attachments?.length
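
session.resolveModel replaces the inline optionalModels check: the controller now only decides whether the Payment server feature is enabled and leaves model validation to the session (implemented further down in session.ts). A minimal sketch of the contract the controller relies on; the surrounding wiring is assumed:

// Sketch only, mirroring the new call site above (see resolveModel below):
// - a requested model outside session.optionalModels falls back to session.model;
// - with the Payment feature enabled, a pro-only model requested by a user
//   without an active AI subscription also falls back to session.model.
const model = await session.resolveModel(
  server.features.includes(ServerFeature.Payment), // `server` is the injected ServerService
  modelId
);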

View File

@@ -1928,7 +1928,7 @@ Now apply the \`updates\` to the \`content\`, following the intent in \`op\`, an
];
const CHAT_PROMPT: Omit<Prompt, 'name'> = {
-model: 'claude-sonnet-4@20250514',
+model: 'gemini-2.5-flash',
optionalModels: [
'gpt-4.1',
'gpt-5',
@@ -2099,6 +2099,13 @@ Below is the user's query. Please respond in the user's preferred language witho
'codeArtifact',
'blobRead',
],
proModels: [
'gemini-2.5-pro',
'claude-opus-4@20250514',
'claude-sonnet-4@20250514',
'claude-3-7-sonnet@20250219',
'claude-3-5-sonnet-v2@20241022',
],
},
};
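
proModels is a plain allow-list inside the prompt's config (the schema change further down makes it an optional string array). A small sketch of the kind of check it enables, assuming proModels sits under CHAT_PROMPT.config as that schema suggests; ChatSession exposes the same list via its proModels getter later in this commit:

// Illustrative only; names mirror the prompt definition above.
const proModels = CHAT_PROMPT.config?.proModels ?? [];
const isProModel = (model: string) => proModels.includes(model);
// isProModel('gemini-2.5-pro') === true, isProModel('gpt-4.1') === false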

View File

@@ -4,6 +4,10 @@ import {
type OpenAIProvider as VercelOpenAIProvider,
OpenAIResponsesProviderOptions,
} from '@ai-sdk/openai';
import {
createOpenAICompatible,
type OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
} from '@ai-sdk/openai-compatible';
import {
AISDKError,
embedMany,
@@ -18,6 +22,7 @@ import { z } from 'zod';
import {
CopilotPromptInvalid,
CopilotProviderNotSupported,
CopilotProviderSideError,
metrics,
UserFriendlyError,
@@ -47,6 +52,7 @@ export const DEFAULT_DIMENSIONS = 256;
export type OpenAIConfig = {
apiKey: string;
baseURL?: string;
oldApiStyle?: boolean;
};
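
An example (hypothetical values) of the new oldApiStyle switch: when it is set together with a custom baseURL, the provider is built with createOpenAICompatible instead of createOpenAI, as shown in setup() below.

// Hypothetical config for an OpenAI-compatible gateway that only speaks the
// legacy chat-completions style API; URL and env var are placeholders.
const config: OpenAIConfig = {
  apiKey: process.env.OPENAI_API_KEY ?? '',
  baseURL: 'https://my-gateway.example.com/v1',
  oldApiStyle: true,
};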
const ModelListSchema = z.object({
@@ -296,7 +302,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
},
];
-#instance!: VercelOpenAIProvider;
+#instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider;
override configured(): boolean {
return !!this.config.apiKey;
@@ -304,10 +310,17 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
protected override setup() {
super.setup();
-this.#instance = createOpenAI({
-apiKey: this.config.apiKey,
-baseURL: this.config.baseURL,
-});
+this.#instance =
+this.config.oldApiStyle && this.config.baseURL
+? createOpenAICompatible({
+name: 'openai-compatible-old-style',
+apiKey: this.config.apiKey,
+baseURL: this.config.baseURL,
+})
+: createOpenAI({
+apiKey: this.config.apiKey,
+baseURL: this.config.baseURL,
+});
}
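
Because #instance is now a union of the two provider types, the rest of the file repeatedly narrows it with an `in` check before picking a model. A small helper like the following (not in the commit, purely an illustration of that repeated ternary) captures the idea: the Responses API only exists on the native OpenAI provider, so the compatible provider falls back to plain invocation.

import type { OpenAIProvider as VercelOpenAIProvider } from '@ai-sdk/openai';
import type { OpenAICompatibleProvider as VercelOpenAICompatibleProvider } from '@ai-sdk/openai-compatible';

type Instance = VercelOpenAIProvider | VercelOpenAICompatibleProvider;

// Sketch of the narrowing pattern used by text(), structure() and the
// streaming paths below.
function languageModel(instance: Instance, modelId: string) {
  return 'responses' in instance
    ? instance.responses(modelId)
    : instance(modelId);
}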
private handleError(
@@ -341,7 +354,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
override async refreshOnlineModels() {
try {
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
-if (baseUrl && !this.onlineModelList.length) {
+if (this.config.apiKey && baseUrl && !this.onlineModelList.length) {
const { data } = await fetch(`${baseUrl}/models`, {
headers: {
Authorization: `Bearer ${this.config.apiKey}`,
@@ -361,7 +374,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
toolName: CopilotChatTools,
model: string
): [string, Tool?] | undefined {
-if (toolName === 'webSearch' && !this.isReasoningModel(model)) {
+if (
+toolName === 'webSearch' &&
+'responses' in this.#instance &&
+!this.isReasoningModel(model)
+) {
return ['web_search_preview', openai.tools.webSearchPreview({})];
} else if (toolName === 'docEdit') {
return ['doc_edit', undefined];
@@ -374,10 +391,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: PromptMessage[],
options: CopilotChatOptions = {}
): Promise<string> {
-const fullCond = {
-...cond,
-outputType: ModelOutputType.Text,
-};
+const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
@@ -386,7 +400,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const [system, msgs] = await chatToGPTMessage(messages);
-const modelInstance = this.#instance.responses(model.id);
+const modelInstance =
+'responses' in this.#instance
+? this.#instance.responses(model.id)
+: this.#instance(model.id);
const { text } = await generateText({
model: modelInstance,
@@ -507,7 +524,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
throw new CopilotPromptInvalid('Schema is required');
}
-const modelInstance = this.#instance.responses(model.id);
+const modelInstance =
+'responses' in this.#instance
+? this.#instance.responses(model.id)
+: this.#instance(model.id);
const { object } = await generateObject({
model: modelInstance,
@@ -539,7 +559,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
await this.checkParams({ messages: [], cond: fullCond, options });
const model = this.selectModel(fullCond);
// get the log probability of "yes"/"no"
-const instance = this.#instance.chat(model.id);
+const instance =
+'chat' in this.#instance
+? this.#instance.chat(model.id)
+: this.#instance(model.id);
const scores = await Promise.all(
chunkMessages.map(async messages => {
@@ -600,7 +623,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
options: CopilotChatOptions = {}
) {
const [system, msgs] = await chatToGPTMessage(messages);
-const modelInstance = this.#instance.responses(model.id);
+const modelInstance =
+'responses' in this.#instance
+? this.#instance.responses(model.id)
+: this.#instance(model.id);
const { fullStream } = streamText({
model: modelInstance,
system,
@@ -685,6 +711,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
if (!('image' in this.#instance)) {
throw new CopilotProviderNotSupported({
provider: this.type,
kind: 'image',
});
}
metrics.ai
.counter('generate_images_stream_calls')
.add(1, { model: model.id });
@@ -735,6 +768,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
await this.checkParams({ embeddings: messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
if (!('embedding' in this.#instance)) {
throw new CopilotProviderNotSupported({
provider: this.type,
kind: 'embedding',
});
}
try {
metrics.ai
.counter('generate_embedding_calls')
@@ -775,6 +815,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
private isReasoningModel(model: string) {
// o series reasoning models
-return model.startsWith('o');
+return model.startsWith('o') || model.startsWith('gpt-5');
}
}
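
With the extra startsWith('gpt-5') branch, GPT-5 family models now take the same paths as the o-series reasoning models; for example, the webSearchPreview mapping above is guarded by !isReasoningModel, so it no longer applies to them. Rough examples, assuming plain prefix matching as written:

// isReasoningModel('o3-mini')    -> true
// isReasoningModel('gpt-5-mini') -> true  (new in this commit)
// isReasoningModel('gpt-4.1')    -> false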

View File

@@ -80,6 +80,7 @@ export const PromptToolsSchema = z
export const PromptConfigStrictSchema = z.object({
tools: PromptToolsSchema.nullable().optional(),
proModels: z.array(z.string()).nullable().optional(),
// params requirements
requireContent: z.boolean().nullable().optional(),
requireAttachment: z.boolean().nullable().optional(),
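
Since proModels is nullable and optional, existing prompt configs keep validating unchanged. A hypothetical parse, assuming the remaining schema fields stay optional like the ones shown here:

// Illustrative usage of the extended schema; values are placeholders.
const parsed = PromptConfigStrictSchema.parse({
  tools: null,
  proModels: ['gemini-2.5-pro', 'claude-opus-4@20250514'],
});
// parsed.proModels is string[] | null | undefined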

View File

@@ -25,6 +25,8 @@ import {
type UpdateChatSession,
UpdateChatSessionOptions,
} from '../../models';
import { SubscriptionService } from '../payment/service';
import { SubscriptionPlan, SubscriptionStatus } from '../payment/types';
import { ChatMessageCache } from './message';
import { ChatPrompt, PromptService } from './prompt';
import {
@@ -58,6 +60,7 @@ declare global {
export class ChatSession implements AsyncDisposable {
private stashMessageCount = 0;
constructor(
private readonly moduleRef: ModuleRef,
private readonly messageCache: ChatMessageCache,
private readonly state: ChatSessionState,
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
@@ -72,6 +75,10 @@ export class ChatSession implements AsyncDisposable {
return this.state.prompt.optionalModels;
}
get proModels() {
return this.state.prompt.config?.proModels || [];
}
get config() {
const {
sessionId,
@@ -93,6 +100,50 @@ export class ChatSession implements AsyncDisposable {
return this.state.messages.findLast(m => m.role === 'user');
}
async resolveModel(
hasPayment: boolean,
requestedModelId?: string
): Promise<string> {
const defaultModel = this.model;
const normalize = (m?: string) =>
!!m && this.optionalModels.includes(m) ? m : defaultModel;
const isPro = (m?: string) => !!m && this.proModels.includes(m);
// try resolve payment subscription service lazily
let paymentEnabled = hasPayment;
let isUserAIPro = false;
try {
if (paymentEnabled) {
const sub = this.moduleRef.get(SubscriptionService, {
strict: false,
});
const subscription = await sub
.select(SubscriptionPlan.AI)
.getSubscription({
userId: this.config.userId,
plan: SubscriptionPlan.AI,
} as any);
isUserAIPro = subscription?.status === SubscriptionStatus.Active;
}
} catch {
// payment not available -> skip checks
paymentEnabled = false;
}
if (paymentEnabled) {
if (isUserAIPro) {
if (!requestedModelId) {
const firstPro = this.proModels[0];
return normalize(firstPro);
}
} else if (isPro(requestedModelId)) {
return defaultModel;
}
}
return normalize(requestedModelId);
}
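
The outcomes below follow directly from the method above; the session values are hypothetical and only stand in for the CHAT_PROMPT-style defaults shown earlier:

// Assuming session.model === 'gemini-2.5-flash', optionalModels includes
// 'gpt-5' and 'gemini-2.5-pro', and proModels starts with 'gemini-2.5-pro':
//
// await session.resolveModel(false, 'gemini-2.5-pro');
//   -> 'gemini-2.5-pro'   (payment disabled: only the optionalModels check applies)
// await session.resolveModel(true, 'gemini-2.5-pro');   // user without an active AI plan
//   -> 'gemini-2.5-flash' (pro model downgraded to the session default)
// await session.resolveModel(true, undefined);          // user with an active AI plan
//   -> 'gemini-2.5-pro'   (first proModels entry, since it is also in optionalModels)
// await session.resolveModel(true, 'unknown-model');
//   -> 'gemini-2.5-flash' (normalize falls back to the default)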
push(message: ChatMessage) {
if (
this.state.prompt.action &&
@@ -539,12 +590,17 @@ export class ChatSessionService {
async get(sessionId: string): Promise<ChatSession | null> {
const state = await this.getSessionInfo(sessionId);
if (state) {
-return new ChatSession(this.messageCache, state, async state => {
-await this.models.copilotSession.updateMessages(state);
-if (!state.prompt.action) {
-await this.jobs.add('copilot.session.generateTitle', { sessionId });
+return new ChatSession(
+this.moduleRef,
+this.messageCache,
+state,
+async state => {
+await this.models.copilotSession.updateMessages(state);
+if (!state.prompt.action) {
+await this.jobs.add('copilot.session.generateTitle', { sessionId });
+}
}
-});
+);
}
return null;
}

View File

@@ -89,7 +89,7 @@ export class SubscriptionService {
return this.stripeProvider.stripe;
}
-private select(plan: SubscriptionPlan): SubscriptionManager {
+select(plan: SubscriptionPlan): SubscriptionManager {
switch (plan) {
case SubscriptionPlan.Team:
return this.workspaceManager;
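
Dropping the private modifier is what lets ChatSession.resolveModel (above) reach the AI plan through the lazily resolved SubscriptionService. Roughly, callers outside the payment module can now do the following; subscriptionService and userId are placeholders for the injected service and the session's user:

// Sketch of the call pattern used in resolveModel above.
const subscription = await subscriptionService
  .select(SubscriptionPlan.AI)
  .getSubscription({ userId, plan: SubscriptionPlan.AI } as any);
const isAIPro = subscription?.status === SubscriptionStatus.Active;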