mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 21:05:19 +00:00
@@ -51,7 +51,7 @@ defineModuleConfig('copilot', {
|
||||
override_enabled: false,
|
||||
scenarios: {
|
||||
audio_transcribing: 'gemini-2.5-flash',
|
||||
chat: 'claude-sonnet-4@20250514',
|
||||
chat: 'gemini-2.5-flash',
|
||||
embedding: 'gemini-embedding-001',
|
||||
image: 'gpt-image-1',
|
||||
rerank: 'gpt-4.1',
|
||||
|
||||
@@ -44,6 +44,7 @@ import {
|
||||
NoCopilotProviderAvailable,
|
||||
UnsplashIsNotConfigured,
|
||||
} from '../../base';
|
||||
import { ServerFeature, ServerService } from '../../core';
|
||||
import { CurrentUser, Public } from '../../core/auth';
|
||||
import { CopilotContextService } from './context';
|
||||
import {
|
||||
@@ -75,6 +76,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly server: ServerService,
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly context: CopilotContextService,
|
||||
private readonly provider: CopilotProviderFactory,
|
||||
@@ -112,10 +114,10 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
const model =
|
||||
modelId && session.optionalModels.includes(modelId)
|
||||
? modelId
|
||||
: session.model;
|
||||
const model = await session.resolveModel(
|
||||
this.server.features.includes(ServerFeature.Payment),
|
||||
modelId
|
||||
);
|
||||
|
||||
const hasAttachment = messageId
|
||||
? !!(await session.getMessageById(messageId)).attachments?.length
|
||||
|
||||
@@ -1928,7 +1928,7 @@ Now apply the \`updates\` to the \`content\`, following the intent in \`op\`, an
|
||||
];
|
||||
|
||||
const CHAT_PROMPT: Omit<Prompt, 'name'> = {
|
||||
model: 'claude-sonnet-4@20250514',
|
||||
model: 'gemini-2.5-flash',
|
||||
optionalModels: [
|
||||
'gpt-4.1',
|
||||
'gpt-5',
|
||||
@@ -2099,6 +2099,13 @@ Below is the user's query. Please respond in the user's preferred language witho
|
||||
'codeArtifact',
|
||||
'blobRead',
|
||||
],
|
||||
proModels: [
|
||||
'gemini-2.5-pro',
|
||||
'claude-opus-4@20250514',
|
||||
'claude-sonnet-4@20250514',
|
||||
'claude-3-7-sonnet@20250219',
|
||||
'claude-3-5-sonnet-v2@20241022',
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@ import {
|
||||
type OpenAIProvider as VercelOpenAIProvider,
|
||||
OpenAIResponsesProviderOptions,
|
||||
} from '@ai-sdk/openai';
|
||||
import {
|
||||
createOpenAICompatible,
|
||||
type OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
|
||||
} from '@ai-sdk/openai-compatible';
|
||||
import {
|
||||
AISDKError,
|
||||
embedMany,
|
||||
@@ -18,6 +22,7 @@ import { z } from 'zod';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderNotSupported,
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
@@ -47,6 +52,7 @@ export const DEFAULT_DIMENSIONS = 256;
|
||||
export type OpenAIConfig = {
|
||||
apiKey: string;
|
||||
baseURL?: string;
|
||||
oldApiStyle?: boolean;
|
||||
};
|
||||
|
||||
const ModelListSchema = z.object({
|
||||
@@ -296,7 +302,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
},
|
||||
];
|
||||
|
||||
#instance!: VercelOpenAIProvider;
|
||||
#instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
@@ -304,10 +310,17 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance = createOpenAI({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
this.#instance =
|
||||
this.config.oldApiStyle && this.config.baseURL
|
||||
? createOpenAICompatible({
|
||||
name: 'openai-compatible-old-style',
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
})
|
||||
: createOpenAI({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(
|
||||
@@ -341,7 +354,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
|
||||
if (baseUrl && !this.onlineModelList.length) {
|
||||
if (this.config.apiKey && baseUrl && !this.onlineModelList.length) {
|
||||
const { data } = await fetch(`${baseUrl}/models`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
@@ -361,7 +374,11 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
toolName: CopilotChatTools,
|
||||
model: string
|
||||
): [string, Tool?] | undefined {
|
||||
if (toolName === 'webSearch' && !this.isReasoningModel(model)) {
|
||||
if (
|
||||
toolName === 'webSearch' &&
|
||||
'responses' in this.#instance &&
|
||||
!this.isReasoningModel(model)
|
||||
) {
|
||||
return ['web_search_preview', openai.tools.webSearchPreview({})];
|
||||
} else if (toolName === 'docEdit') {
|
||||
return ['doc_edit', undefined];
|
||||
@@ -374,10 +391,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
};
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
@@ -386,7 +400,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance.responses(model.id);
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const { text } = await generateText({
|
||||
model: modelInstance,
|
||||
@@ -507,7 +524,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
throw new CopilotPromptInvalid('Schema is required');
|
||||
}
|
||||
|
||||
const modelInstance = this.#instance.responses(model.id);
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const { object } = await generateObject({
|
||||
model: modelInstance,
|
||||
@@ -539,7 +559,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ messages: [], cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
// get the log probability of "yes"/"no"
|
||||
const instance = this.#instance.chat(model.id);
|
||||
const instance =
|
||||
'chat' in this.#instance
|
||||
? this.#instance.chat(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const scores = await Promise.all(
|
||||
chunkMessages.map(async messages => {
|
||||
@@ -600,7 +623,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
options: CopilotChatOptions = {}
|
||||
) {
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
const modelInstance = this.#instance.responses(model.id);
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
const { fullStream } = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
@@ -685,6 +711,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
if (!('image' in this.#instance)) {
|
||||
throw new CopilotProviderNotSupported({
|
||||
provider: this.type,
|
||||
kind: 'image',
|
||||
});
|
||||
}
|
||||
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
@@ -735,6 +768,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ embeddings: messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
if (!('embedding' in this.#instance)) {
|
||||
throw new CopilotProviderNotSupported({
|
||||
provider: this.type,
|
||||
kind: 'embedding',
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_calls')
|
||||
@@ -775,6 +815,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
private isReasoningModel(model: string) {
|
||||
// o series reasoning models
|
||||
return model.startsWith('o');
|
||||
return model.startsWith('o') || model.startsWith('gpt-5');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +80,7 @@ export const PromptToolsSchema = z
|
||||
|
||||
export const PromptConfigStrictSchema = z.object({
|
||||
tools: PromptToolsSchema.nullable().optional(),
|
||||
proModels: z.array(z.string()).nullable().optional(),
|
||||
// params requirements
|
||||
requireContent: z.boolean().nullable().optional(),
|
||||
requireAttachment: z.boolean().nullable().optional(),
|
||||
|
||||
@@ -25,6 +25,8 @@ import {
|
||||
type UpdateChatSession,
|
||||
UpdateChatSessionOptions,
|
||||
} from '../../models';
|
||||
import { SubscriptionService } from '../payment/service';
|
||||
import { SubscriptionPlan, SubscriptionStatus } from '../payment/types';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { ChatPrompt, PromptService } from './prompt';
|
||||
import {
|
||||
@@ -58,6 +60,7 @@ declare global {
|
||||
export class ChatSession implements AsyncDisposable {
|
||||
private stashMessageCount = 0;
|
||||
constructor(
|
||||
private readonly moduleRef: ModuleRef,
|
||||
private readonly messageCache: ChatMessageCache,
|
||||
private readonly state: ChatSessionState,
|
||||
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
|
||||
@@ -72,6 +75,10 @@ export class ChatSession implements AsyncDisposable {
|
||||
return this.state.prompt.optionalModels;
|
||||
}
|
||||
|
||||
get proModels() {
|
||||
return this.state.prompt.config?.proModels || [];
|
||||
}
|
||||
|
||||
get config() {
|
||||
const {
|
||||
sessionId,
|
||||
@@ -93,6 +100,50 @@ export class ChatSession implements AsyncDisposable {
|
||||
return this.state.messages.findLast(m => m.role === 'user');
|
||||
}
|
||||
|
||||
async resolveModel(
|
||||
hasPayment: boolean,
|
||||
requestedModelId?: string
|
||||
): Promise<string> {
|
||||
const defaultModel = this.model;
|
||||
const normalize = (m?: string) =>
|
||||
!!m && this.optionalModels.includes(m) ? m : defaultModel;
|
||||
const isPro = (m?: string) => !!m && this.proModels.includes(m);
|
||||
|
||||
// try resolve payment subscription service lazily
|
||||
let paymentEnabled = hasPayment;
|
||||
let isUserAIPro = false;
|
||||
try {
|
||||
if (paymentEnabled) {
|
||||
const sub = this.moduleRef.get(SubscriptionService, {
|
||||
strict: false,
|
||||
});
|
||||
const subscription = await sub
|
||||
.select(SubscriptionPlan.AI)
|
||||
.getSubscription({
|
||||
userId: this.config.userId,
|
||||
plan: SubscriptionPlan.AI,
|
||||
} as any);
|
||||
isUserAIPro = subscription?.status === SubscriptionStatus.Active;
|
||||
}
|
||||
} catch {
|
||||
// payment not available -> skip checks
|
||||
paymentEnabled = false;
|
||||
}
|
||||
|
||||
if (paymentEnabled) {
|
||||
if (isUserAIPro) {
|
||||
if (!requestedModelId) {
|
||||
const firstPro = this.proModels[0];
|
||||
return normalize(firstPro);
|
||||
}
|
||||
} else if (isPro(requestedModelId)) {
|
||||
return defaultModel;
|
||||
}
|
||||
}
|
||||
|
||||
return normalize(requestedModelId);
|
||||
}
|
||||
|
||||
push(message: ChatMessage) {
|
||||
if (
|
||||
this.state.prompt.action &&
|
||||
@@ -539,12 +590,17 @@ export class ChatSessionService {
|
||||
async get(sessionId: string): Promise<ChatSession | null> {
|
||||
const state = await this.getSessionInfo(sessionId);
|
||||
if (state) {
|
||||
return new ChatSession(this.messageCache, state, async state => {
|
||||
await this.models.copilotSession.updateMessages(state);
|
||||
if (!state.prompt.action) {
|
||||
await this.jobs.add('copilot.session.generateTitle', { sessionId });
|
||||
return new ChatSession(
|
||||
this.moduleRef,
|
||||
this.messageCache,
|
||||
state,
|
||||
async state => {
|
||||
await this.models.copilotSession.updateMessages(state);
|
||||
if (!state.prompt.action) {
|
||||
await this.jobs.add('copilot.session.generateTitle', { sessionId });
|
||||
}
|
||||
}
|
||||
});
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ export class SubscriptionService {
|
||||
return this.stripeProvider.stripe;
|
||||
}
|
||||
|
||||
private select(plan: SubscriptionPlan): SubscriptionManager {
|
||||
select(plan: SubscriptionPlan): SubscriptionManager {
|
||||
switch (plan) {
|
||||
case SubscriptionPlan.Team:
|
||||
return this.workspaceManager;
|
||||
|
||||
Reference in New Issue
Block a user