feat(core): support gemini model switch in ai (#13631)

<img width="757" height="447" alt="截屏2025-09-22 17 49 34"
src="https://github.com/user-attachments/assets/bab96f45-112e-4d74-bc38-54429d8a54ab"
/>


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- New Features
- Subscription-aware AI model picker in chat: browse models with version
and category, see active selection, switch models, and receive
notifications when choosing pro models without a subscription.
Selections persist across sessions.
- Central AI model service wired into chat UI for consistent model
selection and availability.

- Changes
- Streamlined AI model availability: reduced to a curated set for a more
focused experience.
  - Context menu buttons can display supplemental info next to labels.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Wu Yue
2025-09-22 21:25:11 +08:00
committed by GitHub
parent da3e3eb3fa
commit b25759c264
12 changed files with 281 additions and 50 deletions
@@ -193,6 +193,7 @@ export const menuButtonItems = {
(config: {
name: string;
label?: () => TemplateResult;
info?: TemplateResult;
prefix?: TemplateResult;
postfix?: TemplateResult;
isSelected?: boolean;
@@ -211,7 +212,7 @@ export const menuButtonItems = {
return html`
${config.prefix}
<div class="affine-menu-action-text">
${config.label?.() ?? config.name}
${config.label?.() ?? config.name} ${config.info}
</div>
${config.postfix ?? (config.isSelected ? DoneIcon() : undefined)}
`;
@@ -1930,16 +1930,9 @@ Now apply the \`updates\` to the \`content\`, following the intent in \`op\`, an
const CHAT_PROMPT: Omit<Prompt, 'name'> = {
model: 'gemini-2.5-flash',
optionalModels: [
'gpt-4.1',
'gpt-5',
'o3',
'o4-mini',
'gemini-2.5-flash',
'gemini-2.5-pro',
'claude-opus-4@20250514',
'claude-sonnet-4@20250514',
'claude-3-7-sonnet@20250219',
'claude-3-5-sonnet-v2@20241022',
],
messages: [
{
@@ -2099,13 +2092,7 @@ Below is the user's query. Please respond in the user's preferred language witho
'codeArtifact',
'blobRead',
],
proModels: [
'gemini-2.5-pro',
'claude-opus-4@20250514',
'claude-sonnet-4@20250514',
'claude-3-7-sonnet@20250219',
'claude-3-5-sonnet-v2@20241022',
],
proModels: ['gemini-2.5-pro', 'claude-sonnet-4@20250514'],
},
};
@@ -2,6 +2,8 @@ import type {
AIDraftService,
AIToolsConfigService,
} from '@affine/core/modules/ai-button';
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
import type { SubscriptionService } from '@affine/core/modules/cloud';
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
import type { PeekViewService } from '@affine/core/modules/peek-view';
@@ -129,6 +131,12 @@ export class ChatPanel extends SignalWatcher(
@property({ attribute: false })
accessor peekViewService!: PeekViewService;
@property({ attribute: false })
accessor subscriptionService!: SubscriptionService;
@property({ attribute: false })
accessor aiModelService!: AIModelService;
@state()
accessor session: CopilotChatHistoryFragment | null | undefined;
@@ -426,6 +434,8 @@ export class ChatPanel extends SignalWatcher(
.aiDraftService=${this.aiDraftService}
.aiToolsConfigService=${this.aiToolsConfigService}
.peekViewService=${this.peekViewService}
.subscriptionService=${this.subscriptionService}
.aiModelService=${this.aiModelService}
.onEmbeddingProgressChange=${this.onEmbeddingProgressChange}
.onContextChange=${this.onContextChange}
.width=${this.sidebarWidth}
@@ -4,6 +4,8 @@ import type {
AIDraftService,
AIToolsConfigService,
} from '@affine/core/modules/ai-button';
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
import type { SubscriptionService } from '@affine/core/modules/cloud';
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type {
ContextEmbedStatus,
@@ -141,6 +143,12 @@ export class AIChatComposer extends SignalWatcher(
@property({ attribute: false })
accessor affineFeatureFlagService!: FeatureFlagService;
@property({ attribute: false })
accessor subscriptionService!: SubscriptionService;
@property({ attribute: false })
accessor aiModelService!: AIModelService;
@state()
accessor chips: ChatChip[] = [];
@@ -189,6 +197,9 @@ export class AIChatComposer extends SignalWatcher(
.affineFeatureFlagService=${this.affineFeatureFlagService}
.aiDraftService=${this.aiDraftService}
.aiToolsConfigService=${this.aiToolsConfigService}
.notificationService=${this.notificationService}
.subscriptionService=${this.subscriptionService}
.aiModelService=${this.aiModelService}
.portalContainer=${this.portalContainer}
.onChatSuccess=${this.onChatSuccess}
.trackOptions=${this.trackOptions}
@@ -3,6 +3,8 @@ import type {
AIToolsConfigService,
} from '@affine/core/modules/ai-button';
import type { AIDraftState } from '@affine/core/modules/ai-button/services/ai-draft';
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
import type { SubscriptionService } from '@affine/core/modules/cloud';
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
import type { PeekViewService } from '@affine/core/modules/peek-view';
@@ -167,6 +169,9 @@ export class AIChatContent extends SignalWatcher(
@property({ attribute: false })
accessor aiToolsConfigService!: AIToolsConfigService;
@property({ attribute: false })
accessor aiModelService!: AIModelService;
@property({ attribute: false })
accessor onEmbeddingProgressChange:
| ((count: Record<ContextEmbedStatus, number>) => void)
@@ -184,6 +189,9 @@ export class AIChatContent extends SignalWatcher(
@property({ attribute: false })
accessor peekViewService!: PeekViewService;
@property({ attribute: false })
accessor subscriptionService!: SubscriptionService;
@state()
accessor chatContextValue: ChatContextValue = DEFAULT_CHAT_CONTEXT_VALUE;
@@ -462,6 +470,8 @@ export class AIChatContent extends SignalWatcher(
.notificationService=${this.notificationService}
.aiDraftService=${this.aiDraftService}
.aiToolsConfigService=${this.aiToolsConfigService}
.subscriptionService=${this.subscriptionService}
.aiModelService=${this.aiModelService}
.trackOptions=${{
where: 'chat-panel',
control: 'chat-send',
@@ -2,12 +2,15 @@ import type {
AIDraftService,
AIToolsConfigService,
} from '@affine/core/modules/ai-button';
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
import type { SubscriptionService } from '@affine/core/modules/cloud';
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit';
import { unsafeCSSVar, unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
import type { EditorHost } from '@blocksuite/affine/std';
import { ShadowlessElement } from '@blocksuite/affine/std';
import type { NotificationService } from '@blocksuite/affine-shared/services';
import { ArrowUpBigIcon, CloseIcon } from '@blocksuite/icons/lit';
import { css, html, nothing, type PropertyValues } from 'lit';
import { property, query, state } from 'lit/decorators.js';
@@ -324,9 +327,6 @@ export class AIChatInput extends SignalWatcher(
@state()
accessor focused = false;
@state()
accessor modelId: string | undefined = undefined;
@property({ attribute: false })
accessor chatContextValue!: AIChatInputContext;
@@ -368,6 +368,15 @@ export class AIChatInput extends SignalWatcher(
@property({ attribute: false })
accessor affineFeatureFlagService!: FeatureFlagService;
@property({ attribute: false })
accessor notificationService!: NotificationService;
@property({ attribute: false })
accessor subscriptionService!: SubscriptionService;
@property({ attribute: false })
accessor aiModelService!: AIModelService;
@property({ attribute: false })
accessor isRootSession: boolean = true;
@@ -516,14 +525,15 @@ export class AIChatInput extends SignalWatcher(
<div class="chat-input-footer-spacer"></div>
<chat-input-preference
.session=${this.session}
.onModelChange=${this._handleModelChange}
.modelId=${this.modelId}
.extendedThinking=${this._isReasoningActive}
.onExtendedThinkingChange=${this._toggleReasoning}
.networkSearchVisible=${!!this.networkSearchConfig.visible.value}
.isNetworkActive=${this._isNetworkActive}
.onNetworkActiveChange=${this._toggleNetworkSearch}
.toolsConfigService=${this.aiToolsConfigService}
.notificationService=${this.notificationService}
.subscriptionService=${this.subscriptionService}
.aiModelService=${this.aiModelService}
></chat-input-preference>
${status === 'transmitting' || status === 'loading'
? html`<button
@@ -646,10 +656,6 @@ export class AIChatInput extends SignalWatcher(
await this.send(value);
};
private readonly _handleModelChange = (modelId: string) => {
this.modelId = modelId;
};
send = async (text: string) => {
try {
const {
@@ -693,6 +699,7 @@ export class AIChatInput extends SignalWatcher(
this.affineFeatureFlagService.flags.enable_send_detailed_object_to_ai
.value;
const modelId = this.aiModelService.modelId.value;
const stream = await AIProvider.actions.chat({
sessionId,
input: userInput,
@@ -717,7 +724,7 @@ export class AIChatInput extends SignalWatcher(
webSearch: this._isNetworkActive,
reasoning: this._isReasoningActive,
toolsConfig: this.aiToolsConfigService.config.value,
modelId: this.modelId,
modelId,
});
for await (const text of stream) {
@@ -1,18 +1,29 @@
import type { AIToolsConfigService } from '@affine/core/modules/ai-button';
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
import type { SubscriptionService } from '@affine/core/modules/cloud';
import {
type CopilotChatHistoryFragment,
SubscriptionStatus,
} from '@affine/graphql';
import {
menu,
popMenu,
popupTargetFromElement,
} from '@blocksuite/affine/components/context-menu';
import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit';
import { unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
import type { NotificationService } from '@blocksuite/affine-shared/services';
import {
AiOutlineIcon,
ArrowDownSmallIcon,
CloudWorkspaceIcon,
DoneIcon,
LockIcon,
ThinkingIcon,
WebIcon,
} from '@blocksuite/icons/lit';
import { ShadowlessElement } from '@blocksuite/std';
import { computed } from '@preact/signals-core';
import { css, html } from 'lit';
import { property } from 'lit/decorators.js';
@@ -48,16 +59,29 @@ export class ChatInputPreference extends SignalWatcher(
white-space: nowrap;
min-width: 220px;
}
.ai-active-model-name {
font-size: 14px;
color: ${unsafeCSSVarV2('text/secondary')};
line-height: 22px;
margin-left: 40px;
}
.ai-model-prefix {
width: 20px;
height: 20px;
}
.ai-model-prefix svg {
color: ${unsafeCSSVarV2('icon/activated')};
}
.ai-model-version {
font-size: 12px;
color: ${unsafeCSSVarV2('text/tertiary')};
line-height: 20px;
margin-right: 40px;
}
`;
@property({ attribute: false })
accessor session!: CopilotChatHistoryFragment | null | undefined;
@property({ attribute: false })
accessor onModelChange: ((modelId: string) => void) | undefined;
@property({ attribute: false })
accessor modelId: string | undefined = undefined;
// --------- model props end ---------
// --------- extended thinking props start ---------
@@ -86,9 +110,25 @@ export class ChatInputPreference extends SignalWatcher(
@property({ attribute: false })
accessor toolsConfigService!: AIToolsConfigService;
// private readonly _onModelChange = (modelId: string) => {
// this.onModelChange?.(modelId);
// };
@property({ attribute: false })
accessor notificationService!: NotificationService;
@property({ attribute: false })
accessor subscriptionService!: SubscriptionService;
@property({ attribute: false })
accessor aiModelService!: AIModelService;
model = computed(() => {
const modelId = this.aiModelService.modelId.value;
const activeModel = this.aiModelService.models.value.find(
model => model.id === modelId
);
const defaultModel = this.aiModelService.models.value.find(
model => model.isDefault
);
return activeModel || defaultModel;
});
openPreference(e: Event) {
const element = e.currentTarget;
@@ -97,20 +137,48 @@ export class ChatInputPreference extends SignalWatcher(
const searchItems = [];
// model switch
// modelItems.push(
// menu.subMenu({
// name: 'Model',
// prefix: AiOutlineIcon(),
// options: {
// items: (this.session?.optionalModels ?? []).map(modelId => {
// return menu.action({
// name: modelId,
// select: () => this._onModelChange(modelId),
// });
// }),
// },
// })
// );
modelItems.push(
menu.subMenu({
name: 'Model',
prefix: AiOutlineIcon(),
postfix: html`
<span class="ai-active-model-name"> ${this.model.value?.name} </span>
`,
options: {
items: this.aiModelService.models.value.map(model => {
const isSelected = model.id === this.model.value?.id;
const status =
this.subscriptionService.subscription.ai$.value?.status;
const isSubscribed = status === SubscriptionStatus.Active;
return menu.action({
name: model.category,
info: html`
<span class="ai-model-version">${model.version}</span>
`,
prefix: html`
<div class="ai-model-prefix">
${isSelected ? DoneIcon() : undefined}
</div>
`,
postfix: html`
<div>
${model.isPro && !isSubscribed ? LockIcon() : undefined}
</div>
`,
select: () => {
if (model.isPro && !isSubscribed) {
this.notificationService.toast(
`Pro models require an AFFiNE AI subscription.`
);
return;
}
this.aiModelService.setModel(model.id);
},
});
}),
},
})
);
modelItems.push(
menu.toggleSwitch({
@@ -169,7 +237,9 @@ export class ChatInputPreference extends SignalWatcher(
data-testid="chat-input-preference-trigger"
class="chat-input-preference-trigger"
>
<span class="chat-input-preference-trigger-label"> Claude </span>
<span class="chat-input-preference-trigger-label">
${this.model.value?.category}
</span>
<span class="chat-input-preference-trigger-icon">
${ArrowDownSmallIcon()}
</span>
@@ -15,10 +15,12 @@ import {
AIDraftService,
AIToolsConfigService,
} from '@affine/core/modules/ai-button';
import { AIModelService } from '@affine/core/modules/ai-button/services/models';
import {
EventSourceService,
FetchService,
GraphQLService,
SubscriptionService,
} from '@affine/core/modules/cloud';
import { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import { FeatureFlagService } from '@affine/core/modules/feature-flag';
@@ -229,6 +231,8 @@ export const Component = () => {
);
content.aiDraftService = framework.get(AIDraftService);
content.aiToolsConfigService = framework.get(AIToolsConfigService);
content.subscriptionService = framework.get(SubscriptionService);
content.aiModelService = framework.get(AIModelService);
content.createSession = createSession;
content.onOpenDoc = onOpenDoc;
@@ -8,6 +8,8 @@ import {
AIDraftService,
AIToolsConfigService,
} from '@affine/core/modules/ai-button';
import { AIModelService } from '@affine/core/modules/ai-button/services/models';
import { SubscriptionService } from '@affine/core/modules/cloud';
import { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import { FeatureFlagService } from '@affine/core/modules/feature-flag';
import { PeekViewService } from '@affine/core/modules/peek-view';
@@ -104,6 +106,9 @@ export const EditorChatPanel = forwardRef(function EditorChatPanel(
chatPanelRef.current.aiDraftService = framework.get(AIDraftService);
chatPanelRef.current.aiToolsConfigService =
framework.get(AIToolsConfigService);
chatPanelRef.current.subscriptionService =
framework.get(SubscriptionService);
chatPanelRef.current.aiModelService = framework.get(AIModelService);
containerRef.current?.append(chatPanelRef.current);
} else {
@@ -8,12 +8,14 @@ export {
import type { Framework } from '@toeverything/infra';
import { GraphQLService, ServerScope, SubscriptionService } from '../cloud';
import { FeatureFlagService } from '../feature-flag';
import { CacheStorage, GlobalStateService } from '../storage';
import { WorkspaceScope } from '../workspace';
import { AIButtonProvider } from './provider/ai-button';
import { AIButtonService } from './services/ai-button';
import { AIDraftService } from './services/ai-draft';
import { AIModelService } from './services/models';
import { AINetworkSearchService } from './services/network-search';
import { AIPlaygroundService } from './services/playground';
import { AIReasoningService } from './services/reasoning';
@@ -49,3 +51,13 @@ export function configureAIDraftModule(framework: Framework) {
export function configureAIToolsConfigModule(framework: Framework) {
framework.service(AIToolsConfigService, [GlobalStateService]);
}
export function configureAIModelModule(framework: Framework) {
framework
.scope(ServerScope)
.service(AIModelService, [
GlobalStateService,
GraphQLService,
SubscriptionService,
]);
}
@@ -0,0 +1,112 @@
import { getPromptModelsQuery, SubscriptionStatus } from '@affine/graphql';
import {
createSignalFromObservable,
type Signal,
} from '@blocksuite/affine/shared/utils';
import { signal } from '@preact/signals-core';
import { LiveData, Service } from '@toeverything/infra';
import type { GraphQLService, SubscriptionService } from '../../cloud';
import type { GlobalStateService } from '../../storage';
const AI_MODEL_ID_KEY = 'AIModelId';
export interface AIModel {
name: string;
id: string;
version: string;
category: string;
isPro: boolean;
isDefault: boolean;
}
export class AIModelService extends Service {
modelId: Signal<string | undefined>;
models: Signal<AIModel[]> = signal([]);
private readonly modelId$ = LiveData.from(
this.globalStateService.globalState.watch<string>(AI_MODEL_ID_KEY),
undefined
);
constructor(
private readonly globalStateService: GlobalStateService,
private readonly gqlService: GraphQLService,
private readonly subscriptionService: SubscriptionService
) {
super();
const { signal: modelId, cleanup } = createSignalFromObservable<
string | undefined
>(this.modelId$, undefined);
this.modelId = modelId;
this.disposables.push(cleanup);
this.init().catch(err => {
console.error(err);
});
}
resetModel = () => {
this.globalStateService.globalState.set(AI_MODEL_ID_KEY, undefined);
};
setModel = (modelId: string) => {
const isSubscribed =
this.subscriptionService.subscription.ai$.value?.status ===
SubscriptionStatus.Active;
const model = this.models.value.find(model => model.id === modelId);
if (!isSubscribed && model?.isPro) {
return;
}
this.globalStateService.globalState.set(AI_MODEL_ID_KEY, modelId);
};
private readonly init = async () => {
await this.initModels();
// subscribe to ai purchase status
const sub = this.subscriptionService.subscription.ai$.subscribe(
subscription => {
const isSubscribed = subscription?.status === SubscriptionStatus.Active;
const model = this.models.value.find(
model => model.id === this.modelId.value
);
if (!isSubscribed && model?.isPro) {
this.resetModel();
}
}
);
this.disposables.push(() => sub.unsubscribe());
};
private readonly initModels = async (prompt?: string) => {
const promptName = prompt || 'Chat With AFFiNE AI';
const models = await this.getModelsByPrompt(promptName);
if (models) {
const { defaultModel, optionalModels, proModels } = models;
this.models.value = optionalModels.map(model => {
const [category] = model.name.split(' ');
const version = model.name.slice(category.length + 1);
return {
name: model.name,
id: model.id,
version,
category,
isPro: proModels.some(proModel => proModel.id === model.id),
isDefault: model.id === defaultModel,
};
});
}
};
private readonly getModelsByPrompt = async (promptName: string) => {
return this.gqlService
.gql({
query: getPromptModelsQuery,
variables: { promptName },
})
.then(res => res.currentUser?.copilot?.models);
};
}
@@ -4,6 +4,7 @@ import { type Framework } from '@toeverything/infra';
import {
configureAIButtonModule,
configureAIDraftModule,
configureAIModelModule,
configureAINetworkSearchModule,
configureAIPlaygroundModule,
configureAIReasoningModule,
@@ -117,6 +118,7 @@ export function configureCommonModules(framework: Framework) {
configureAIButtonModule(framework);
configureAIDraftModule(framework);
configureAIToolsConfigModule(framework);
configureAIModelModule(framework);
configureTemplateDocModule(framework);
configureBlobManagementModule(framework);
configureMediaModule(framework);