feat: improve model resolve (#13601)

fix AI-419
This commit is contained in:
DarkSky
2025-09-18 18:51:12 +08:00
committed by GitHub
parent 89646869e4
commit a0b73cdcec
11 changed files with 271 additions and 30 deletions

View File

@@ -25,6 +25,8 @@ import {
type UpdateChatSession,
UpdateChatSessionOptions,
} from '../../models';
import { SubscriptionService } from '../payment/service';
import { SubscriptionPlan, SubscriptionStatus } from '../payment/types';
import { ChatMessageCache } from './message';
import { ChatPrompt, PromptService } from './prompt';
import {
@@ -58,6 +60,7 @@ declare global {
export class ChatSession implements AsyncDisposable {
private stashMessageCount = 0;
constructor(
private readonly moduleRef: ModuleRef,
private readonly messageCache: ChatMessageCache,
private readonly state: ChatSessionState,
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
@@ -72,6 +75,10 @@ export class ChatSession implements AsyncDisposable {
return this.state.prompt.optionalModels;
}
get proModels() {
return this.state.prompt.config?.proModels || [];
}
get config() {
const {
sessionId,
@@ -93,6 +100,50 @@ export class ChatSession implements AsyncDisposable {
return this.state.messages.findLast(m => m.role === 'user');
}
async resolveModel(
hasPayment: boolean,
requestedModelId?: string
): Promise<string> {
const defaultModel = this.model;
const normalize = (m?: string) =>
!!m && this.optionalModels.includes(m) ? m : defaultModel;
const isPro = (m?: string) => !!m && this.proModels.includes(m);
// try resolve payment subscription service lazily
let paymentEnabled = hasPayment;
let isUserAIPro = false;
try {
if (paymentEnabled) {
const sub = this.moduleRef.get(SubscriptionService, {
strict: false,
});
const subscription = await sub
.select(SubscriptionPlan.AI)
.getSubscription({
userId: this.config.userId,
plan: SubscriptionPlan.AI,
} as any);
isUserAIPro = subscription?.status === SubscriptionStatus.Active;
}
} catch {
// payment not available -> skip checks
paymentEnabled = false;
}
if (paymentEnabled) {
if (isUserAIPro) {
if (!requestedModelId) {
const firstPro = this.proModels[0];
return normalize(firstPro);
}
} else if (isPro(requestedModelId)) {
return defaultModel;
}
}
return normalize(requestedModelId);
}
push(message: ChatMessage) {
if (
this.state.prompt.action &&
@@ -539,12 +590,17 @@ export class ChatSessionService {
async get(sessionId: string): Promise<ChatSession | null> {
const state = await this.getSessionInfo(sessionId);
if (state) {
return new ChatSession(this.messageCache, state, async state => {
await this.models.copilotSession.updateMessages(state);
if (!state.prompt.action) {
await this.jobs.add('copilot.session.generateTitle', { sessionId });
return new ChatSession(
this.moduleRef,
this.messageCache,
state,
async state => {
await this.models.copilotSession.updateMessages(state);
if (!state.prompt.action) {
await this.jobs.add('copilot.session.generateTitle', { sessionId });
}
}
});
);
}
return null;
}