mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
@@ -444,3 +444,37 @@ Generated by [AVA](https://avajs.dev).
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
## should resolve model correctly based on subscription status and prompt config
|
||||
|
||||
> should honor requested pro model
|
||||
|
||||
'gemini-2.5-pro'
|
||||
|
||||
> should fallback to default model
|
||||
|
||||
'gemini-2.5-flash'
|
||||
|
||||
> should fallback to default model when requesting pro model during trialing
|
||||
|
||||
'gemini-2.5-flash'
|
||||
|
||||
> should honor requested non-pro model during trialing
|
||||
|
||||
'gemini-2.5-flash'
|
||||
|
||||
> should pick default model when no requested model during trialing
|
||||
|
||||
'gemini-2.5-flash'
|
||||
|
||||
> should pick first pro model when no requested model during active
|
||||
|
||||
'gemini-2.5-pro'
|
||||
|
||||
> should honor requested pro model during active
|
||||
|
||||
'claude-sonnet-4@20250514'
|
||||
|
||||
> should fallback to default model when requesting non-optional model during active
|
||||
|
||||
'gemini-2.5-flash'
|
||||
|
||||
Binary file not shown.
@@ -60,6 +60,9 @@ import {
|
||||
import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils';
|
||||
import { WorkflowGraphList } from '../plugins/copilot/workflow/graph';
|
||||
import { CopilotWorkspaceService } from '../plugins/copilot/workspace';
|
||||
import { PaymentModule } from '../plugins/payment';
|
||||
import { SubscriptionService } from '../plugins/payment/service';
|
||||
import { SubscriptionStatus } from '../plugins/payment/types';
|
||||
import { MockCopilotProvider } from './mocks';
|
||||
import { createTestingModule, TestingModule } from './utils';
|
||||
import { WorkflowTestCases } from './utils/copilot';
|
||||
@@ -82,6 +85,7 @@ type Context = {
|
||||
storage: CopilotStorage;
|
||||
workflow: CopilotWorkflowService;
|
||||
cronJobs: CopilotCronJobs;
|
||||
subscription: SubscriptionService;
|
||||
executors: {
|
||||
image: CopilotChatImageExecutor;
|
||||
text: CopilotChatTextExecutor;
|
||||
@@ -116,6 +120,7 @@ test.before(async t => {
|
||||
},
|
||||
},
|
||||
}),
|
||||
PaymentModule,
|
||||
QuotaModule,
|
||||
StorageModule,
|
||||
CopilotModule,
|
||||
@@ -124,6 +129,13 @@ test.before(async t => {
|
||||
// use real JobQueue for testing
|
||||
builder.overrideProvider(JobQueue).useClass(JobQueue);
|
||||
builder.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider);
|
||||
builder.overrideProvider(SubscriptionService).useClass(
|
||||
class {
|
||||
select() {
|
||||
return { getSubscription: async () => undefined };
|
||||
}
|
||||
}
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
@@ -145,6 +157,7 @@ test.before(async t => {
|
||||
const transcript = module.get(CopilotTranscriptionService);
|
||||
const workspaceEmbedding = module.get(CopilotWorkspaceService);
|
||||
const cronJobs = module.get(CopilotCronJobs);
|
||||
const subscription = module.get(SubscriptionService);
|
||||
|
||||
t.context.module = module;
|
||||
t.context.auth = auth;
|
||||
@@ -163,6 +176,7 @@ test.before(async t => {
|
||||
t.context.transcript = transcript;
|
||||
t.context.workspaceEmbedding = workspaceEmbedding;
|
||||
t.context.cronJobs = cronJobs;
|
||||
t.context.subscription = subscription;
|
||||
|
||||
t.context.executors = {
|
||||
image: module.get(CopilotChatImageExecutor),
|
||||
@@ -2047,3 +2061,90 @@ test('should handle copilot cron jobs correctly', async t => {
|
||||
toBeGenerateStub.restore();
|
||||
jobAddStub.restore();
|
||||
});
|
||||
|
||||
test('should resolve model correctly based on subscription status and prompt config', async t => {
|
||||
const { db, session, subscription } = t.context;
|
||||
|
||||
// 1) Seed a prompt that has optionalModels and proModels in config
|
||||
const promptName = 'resolve-model-test';
|
||||
await db.aiPrompt.create({
|
||||
data: {
|
||||
name: promptName,
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: {
|
||||
create: [{ idx: 0, role: 'system', content: 'test' }],
|
||||
},
|
||||
config: { proModels: ['gemini-2.5-pro', 'claude-sonnet-4@20250514'] },
|
||||
optionalModels: [
|
||||
'gemini-2.5-flash',
|
||||
'gemini-2.5-pro',
|
||||
'claude-sonnet-4@20250514',
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
// 2) Create a chat session with this prompt
|
||||
const sessionId = await session.create({
|
||||
promptName,
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
pinned: false,
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
const mockStatus = (status?: SubscriptionStatus) => {
|
||||
Sinon.restore();
|
||||
Sinon.stub(subscription, 'select').callsFake(() => ({
|
||||
// @ts-expect-error mock
|
||||
getSubscription: async () => (status ? { status } : null),
|
||||
}));
|
||||
};
|
||||
|
||||
// payment disabled -> allow requested if in optional; pro not blocked
|
||||
{
|
||||
const model1 = await s.resolveModel(false, 'gemini-2.5-pro');
|
||||
t.snapshot(model1, 'should honor requested pro model');
|
||||
|
||||
const model2 = await s.resolveModel(false, 'not-in-optional');
|
||||
t.snapshot(model2, 'should fallback to default model');
|
||||
}
|
||||
|
||||
// payment enabled + trialing: requesting pro should fallback to default
|
||||
{
|
||||
mockStatus(SubscriptionStatus.Trialing);
|
||||
const model3 = await s.resolveModel(true, 'gemini-2.5-pro');
|
||||
t.snapshot(
|
||||
model3,
|
||||
'should fallback to default model when requesting pro model during trialing'
|
||||
);
|
||||
|
||||
const model4 = await s.resolveModel(true, 'gemini-2.5-flash');
|
||||
t.snapshot(model4, 'should honor requested non-pro model during trialing');
|
||||
|
||||
const model5 = await s.resolveModel(true);
|
||||
t.snapshot(
|
||||
model5,
|
||||
'should pick default model when no requested model during trialing'
|
||||
);
|
||||
}
|
||||
|
||||
// payment enabled + active: without requested -> first pro; requested pro should be honored
|
||||
{
|
||||
mockStatus(SubscriptionStatus.Active);
|
||||
const model6 = await s.resolveModel(true);
|
||||
t.snapshot(
|
||||
model6,
|
||||
'should pick first pro model when no requested model during active'
|
||||
);
|
||||
|
||||
const model7 = await s.resolveModel(true, 'claude-sonnet-4@20250514');
|
||||
t.snapshot(model7, 'should honor requested pro model during active');
|
||||
|
||||
const model8 = await s.resolveModel(true, 'not-in-optional');
|
||||
t.snapshot(
|
||||
model8,
|
||||
'should fallback to default model when requesting non-optional model during active'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -51,7 +51,7 @@ defineModuleConfig('copilot', {
|
||||
override_enabled: false,
|
||||
scenarios: {
|
||||
audio_transcribing: 'gemini-2.5-flash',
|
||||
chat: 'claude-sonnet-4@20250514',
|
||||
chat: 'gemini-2.5-flash',
|
||||
embedding: 'gemini-embedding-001',
|
||||
image: 'gpt-image-1',
|
||||
rerank: 'gpt-4.1',
|
||||
|
||||
@@ -44,6 +44,7 @@ import {
|
||||
NoCopilotProviderAvailable,
|
||||
UnsplashIsNotConfigured,
|
||||
} from '../../base';
|
||||
import { ServerFeature, ServerService } from '../../core';
|
||||
import { CurrentUser, Public } from '../../core/auth';
|
||||
import { CopilotContextService } from './context';
|
||||
import {
|
||||
@@ -75,6 +76,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly server: ServerService,
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly context: CopilotContextService,
|
||||
private readonly provider: CopilotProviderFactory,
|
||||
@@ -112,10 +114,10 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
const model =
|
||||
modelId && session.optionalModels.includes(modelId)
|
||||
? modelId
|
||||
: session.model;
|
||||
const model = await session.resolveModel(
|
||||
this.server.features.includes(ServerFeature.Payment),
|
||||
modelId
|
||||
);
|
||||
|
||||
const hasAttachment = messageId
|
||||
? !!(await session.getMessageById(messageId)).attachments?.length
|
||||
|
||||
@@ -1928,7 +1928,7 @@ Now apply the \`updates\` to the \`content\`, following the intent in \`op\`, an
|
||||
];
|
||||
|
||||
const CHAT_PROMPT: Omit<Prompt, 'name'> = {
|
||||
model: 'claude-sonnet-4@20250514',
|
||||
model: 'gemini-2.5-flash',
|
||||
optionalModels: [
|
||||
'gpt-4.1',
|
||||
'gpt-5',
|
||||
@@ -2099,6 +2099,13 @@ Below is the user's query. Please respond in the user's preferred language witho
|
||||
'codeArtifact',
|
||||
'blobRead',
|
||||
],
|
||||
proModels: [
|
||||
'gemini-2.5-pro',
|
||||
'claude-opus-4@20250514',
|
||||
'claude-sonnet-4@20250514',
|
||||
'claude-3-7-sonnet@20250219',
|
||||
'claude-3-5-sonnet-v2@20241022',
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@ import {
|
||||
type OpenAIProvider as VercelOpenAIProvider,
|
||||
OpenAIResponsesProviderOptions,
|
||||
} from '@ai-sdk/openai';
|
||||
import {
|
||||
createOpenAICompatible,
|
||||
type OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
|
||||
} from '@ai-sdk/openai-compatible';
|
||||
import {
|
||||
AISDKError,
|
||||
embedMany,
|
||||
@@ -18,6 +22,7 @@ import { z } from 'zod';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderNotSupported,
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
@@ -47,6 +52,7 @@ export const DEFAULT_DIMENSIONS = 256;
|
||||
export type OpenAIConfig = {
|
||||
apiKey: string;
|
||||
baseURL?: string;
|
||||
oldApiStyle?: boolean;
|
||||
};
|
||||
|
||||
const ModelListSchema = z.object({
|
||||
@@ -296,7 +302,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
},
|
||||
];
|
||||
|
||||
#instance!: VercelOpenAIProvider;
|
||||
#instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
@@ -304,10 +310,17 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance = createOpenAI({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
this.#instance =
|
||||
this.config.oldApiStyle && this.config.baseURL
|
||||
? createOpenAICompatible({
|
||||
name: 'openai-compatible-old-style',
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
})
|
||||
: createOpenAI({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(
|
||||
@@ -341,7 +354,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
|
||||
if (baseUrl && !this.onlineModelList.length) {
|
||||
if (this.config.apiKey && baseUrl && !this.onlineModelList.length) {
|
||||
const { data } = await fetch(`${baseUrl}/models`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
@@ -361,7 +374,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
toolName: CopilotChatTools,
|
||||
model: string
|
||||
): [string, Tool?] | undefined {
|
||||
if (toolName === 'webSearch' && !this.isReasoningModel(model)) {
|
||||
if (
|
||||
toolName === 'webSearch' &&
|
||||
'responses' in this.#instance &&
|
||||
!this.isReasoningModel(model)
|
||||
) {
|
||||
return ['web_search_preview', openai.tools.webSearchPreview({})];
|
||||
} else if (toolName === 'docEdit') {
|
||||
return ['doc_edit', undefined];
|
||||
@@ -374,10 +391,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
};
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
@@ -386,7 +400,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance.responses(model.id);
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const { text } = await generateText({
|
||||
model: modelInstance,
|
||||
@@ -507,7 +524,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
throw new CopilotPromptInvalid('Schema is required');
|
||||
}
|
||||
|
||||
const modelInstance = this.#instance.responses(model.id);
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const { object } = await generateObject({
|
||||
model: modelInstance,
|
||||
@@ -539,7 +559,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ messages: [], cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
// get the log probability of "yes"/"no"
|
||||
const instance = this.#instance.chat(model.id);
|
||||
const instance =
|
||||
'chat' in this.#instance
|
||||
? this.#instance.chat(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const scores = await Promise.all(
|
||||
chunkMessages.map(async messages => {
|
||||
@@ -600,7 +623,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
options: CopilotChatOptions = {}
|
||||
) {
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
const modelInstance = this.#instance.responses(model.id);
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
const { fullStream } = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
@@ -685,6 +711,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
if (!('image' in this.#instance)) {
|
||||
throw new CopilotProviderNotSupported({
|
||||
provider: this.type,
|
||||
kind: 'image',
|
||||
});
|
||||
}
|
||||
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
@@ -735,6 +768,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ embeddings: messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
if (!('embedding' in this.#instance)) {
|
||||
throw new CopilotProviderNotSupported({
|
||||
provider: this.type,
|
||||
kind: 'embedding',
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_calls')
|
||||
@@ -775,6 +815,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
private isReasoningModel(model: string) {
|
||||
// o series reasoning models
|
||||
return model.startsWith('o');
|
||||
return model.startsWith('o') || model.startsWith('gpt-5');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +80,7 @@ export const PromptToolsSchema = z
|
||||
|
||||
export const PromptConfigStrictSchema = z.object({
|
||||
tools: PromptToolsSchema.nullable().optional(),
|
||||
proModels: z.array(z.string()).nullable().optional(),
|
||||
// params requirements
|
||||
requireContent: z.boolean().nullable().optional(),
|
||||
requireAttachment: z.boolean().nullable().optional(),
|
||||
|
||||
@@ -25,6 +25,8 @@ import {
|
||||
type UpdateChatSession,
|
||||
UpdateChatSessionOptions,
|
||||
} from '../../models';
|
||||
import { SubscriptionService } from '../payment/service';
|
||||
import { SubscriptionPlan, SubscriptionStatus } from '../payment/types';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { ChatPrompt, PromptService } from './prompt';
|
||||
import {
|
||||
@@ -58,6 +60,7 @@ declare global {
|
||||
export class ChatSession implements AsyncDisposable {
|
||||
private stashMessageCount = 0;
|
||||
constructor(
|
||||
private readonly moduleRef: ModuleRef,
|
||||
private readonly messageCache: ChatMessageCache,
|
||||
private readonly state: ChatSessionState,
|
||||
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
|
||||
@@ -72,6 +75,10 @@ export class ChatSession implements AsyncDisposable {
|
||||
return this.state.prompt.optionalModels;
|
||||
}
|
||||
|
||||
get proModels() {
|
||||
return this.state.prompt.config?.proModels || [];
|
||||
}
|
||||
|
||||
get config() {
|
||||
const {
|
||||
sessionId,
|
||||
@@ -93,6 +100,50 @@ export class ChatSession implements AsyncDisposable {
|
||||
return this.state.messages.findLast(m => m.role === 'user');
|
||||
}
|
||||
|
||||
async resolveModel(
|
||||
hasPayment: boolean,
|
||||
requestedModelId?: string
|
||||
): Promise<string> {
|
||||
const defaultModel = this.model;
|
||||
const normalize = (m?: string) =>
|
||||
!!m && this.optionalModels.includes(m) ? m : defaultModel;
|
||||
const isPro = (m?: string) => !!m && this.proModels.includes(m);
|
||||
|
||||
// try resolve payment subscription service lazily
|
||||
let paymentEnabled = hasPayment;
|
||||
let isUserAIPro = false;
|
||||
try {
|
||||
if (paymentEnabled) {
|
||||
const sub = this.moduleRef.get(SubscriptionService, {
|
||||
strict: false,
|
||||
});
|
||||
const subscription = await sub
|
||||
.select(SubscriptionPlan.AI)
|
||||
.getSubscription({
|
||||
userId: this.config.userId,
|
||||
plan: SubscriptionPlan.AI,
|
||||
} as any);
|
||||
isUserAIPro = subscription?.status === SubscriptionStatus.Active;
|
||||
}
|
||||
} catch {
|
||||
// payment not available -> skip checks
|
||||
paymentEnabled = false;
|
||||
}
|
||||
|
||||
if (paymentEnabled) {
|
||||
if (isUserAIPro) {
|
||||
if (!requestedModelId) {
|
||||
const firstPro = this.proModels[0];
|
||||
return normalize(firstPro);
|
||||
}
|
||||
} else if (isPro(requestedModelId)) {
|
||||
return defaultModel;
|
||||
}
|
||||
}
|
||||
|
||||
return normalize(requestedModelId);
|
||||
}
|
||||
|
||||
push(message: ChatMessage) {
|
||||
if (
|
||||
this.state.prompt.action &&
|
||||
@@ -539,12 +590,17 @@ export class ChatSessionService {
|
||||
async get(sessionId: string): Promise<ChatSession | null> {
|
||||
const state = await this.getSessionInfo(sessionId);
|
||||
if (state) {
|
||||
return new ChatSession(this.messageCache, state, async state => {
|
||||
await this.models.copilotSession.updateMessages(state);
|
||||
if (!state.prompt.action) {
|
||||
await this.jobs.add('copilot.session.generateTitle', { sessionId });
|
||||
return new ChatSession(
|
||||
this.moduleRef,
|
||||
this.messageCache,
|
||||
state,
|
||||
async state => {
|
||||
await this.models.copilotSession.updateMessages(state);
|
||||
if (!state.prompt.action) {
|
||||
await this.jobs.add('copilot.session.generateTitle', { sessionId });
|
||||
}
|
||||
}
|
||||
});
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ export class SubscriptionService {
|
||||
return this.stripeProvider.stripe;
|
||||
}
|
||||
|
||||
private select(plan: SubscriptionPlan): SubscriptionManager {
|
||||
select(plan: SubscriptionPlan): SubscriptionManager {
|
||||
switch (plan) {
|
||||
case SubscriptionPlan.Team:
|
||||
return this.workspaceManager;
|
||||
|
||||
Reference in New Issue
Block a user