mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-07-02 02:00:49 +08:00
feat(core): support gemini model switch in ai (#13631)
<img width="757" height="447" alt="截屏2025-09-22 17 49 34" src="https://github.com/user-attachments/assets/bab96f45-112e-4d74-bc38-54429d8a54ab" /> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Subscription-aware AI model picker in chat: browse models with version and category, see active selection, switch models, and receive notifications when choosing pro models without a subscription. Selections persist across sessions. - Central AI model service wired into chat UI for consistent model selection and availability. - Changes - Streamlined AI model availability: reduced to a curated set for a more focused experience. - Context menu buttons can display supplemental info next to labels. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -193,6 +193,7 @@ export const menuButtonItems = {
|
||||
(config: {
|
||||
name: string;
|
||||
label?: () => TemplateResult;
|
||||
info?: TemplateResult;
|
||||
prefix?: TemplateResult;
|
||||
postfix?: TemplateResult;
|
||||
isSelected?: boolean;
|
||||
@@ -211,7 +212,7 @@ export const menuButtonItems = {
|
||||
return html`
|
||||
${config.prefix}
|
||||
<div class="affine-menu-action-text">
|
||||
${config.label?.() ?? config.name}
|
||||
${config.label?.() ?? config.name} ${config.info}
|
||||
</div>
|
||||
${config.postfix ?? (config.isSelected ? DoneIcon() : undefined)}
|
||||
`;
|
||||
|
||||
@@ -1930,16 +1930,9 @@ Now apply the \`updates\` to the \`content\`, following the intent in \`op\`, an
|
||||
const CHAT_PROMPT: Omit<Prompt, 'name'> = {
|
||||
model: 'gemini-2.5-flash',
|
||||
optionalModels: [
|
||||
'gpt-4.1',
|
||||
'gpt-5',
|
||||
'o3',
|
||||
'o4-mini',
|
||||
'gemini-2.5-flash',
|
||||
'gemini-2.5-pro',
|
||||
'claude-opus-4@20250514',
|
||||
'claude-sonnet-4@20250514',
|
||||
'claude-3-7-sonnet@20250219',
|
||||
'claude-3-5-sonnet-v2@20241022',
|
||||
],
|
||||
messages: [
|
||||
{
|
||||
@@ -2099,13 +2092,7 @@ Below is the user's query. Please respond in the user's preferred language witho
|
||||
'codeArtifact',
|
||||
'blobRead',
|
||||
],
|
||||
proModels: [
|
||||
'gemini-2.5-pro',
|
||||
'claude-opus-4@20250514',
|
||||
'claude-sonnet-4@20250514',
|
||||
'claude-3-7-sonnet@20250219',
|
||||
'claude-3-5-sonnet-v2@20241022',
|
||||
],
|
||||
proModels: ['gemini-2.5-pro', 'claude-sonnet-4@20250514'],
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ import type {
|
||||
AIDraftService,
|
||||
AIToolsConfigService,
|
||||
} from '@affine/core/modules/ai-button';
|
||||
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
|
||||
import type { SubscriptionService } from '@affine/core/modules/cloud';
|
||||
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
|
||||
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
|
||||
import type { PeekViewService } from '@affine/core/modules/peek-view';
|
||||
@@ -129,6 +131,12 @@ export class ChatPanel extends SignalWatcher(
|
||||
@property({ attribute: false })
|
||||
accessor peekViewService!: PeekViewService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor subscriptionService!: SubscriptionService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor aiModelService!: AIModelService;
|
||||
|
||||
@state()
|
||||
accessor session: CopilotChatHistoryFragment | null | undefined;
|
||||
|
||||
@@ -426,6 +434,8 @@ export class ChatPanel extends SignalWatcher(
|
||||
.aiDraftService=${this.aiDraftService}
|
||||
.aiToolsConfigService=${this.aiToolsConfigService}
|
||||
.peekViewService=${this.peekViewService}
|
||||
.subscriptionService=${this.subscriptionService}
|
||||
.aiModelService=${this.aiModelService}
|
||||
.onEmbeddingProgressChange=${this.onEmbeddingProgressChange}
|
||||
.onContextChange=${this.onContextChange}
|
||||
.width=${this.sidebarWidth}
|
||||
|
||||
+11
@@ -4,6 +4,8 @@ import type {
|
||||
AIDraftService,
|
||||
AIToolsConfigService,
|
||||
} from '@affine/core/modules/ai-button';
|
||||
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
|
||||
import type { SubscriptionService } from '@affine/core/modules/cloud';
|
||||
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
|
||||
import type {
|
||||
ContextEmbedStatus,
|
||||
@@ -141,6 +143,12 @@ export class AIChatComposer extends SignalWatcher(
|
||||
@property({ attribute: false })
|
||||
accessor affineFeatureFlagService!: FeatureFlagService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor subscriptionService!: SubscriptionService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor aiModelService!: AIModelService;
|
||||
|
||||
@state()
|
||||
accessor chips: ChatChip[] = [];
|
||||
|
||||
@@ -189,6 +197,9 @@ export class AIChatComposer extends SignalWatcher(
|
||||
.affineFeatureFlagService=${this.affineFeatureFlagService}
|
||||
.aiDraftService=${this.aiDraftService}
|
||||
.aiToolsConfigService=${this.aiToolsConfigService}
|
||||
.notificationService=${this.notificationService}
|
||||
.subscriptionService=${this.subscriptionService}
|
||||
.aiModelService=${this.aiModelService}
|
||||
.portalContainer=${this.portalContainer}
|
||||
.onChatSuccess=${this.onChatSuccess}
|
||||
.trackOptions=${this.trackOptions}
|
||||
|
||||
+10
@@ -3,6 +3,8 @@ import type {
|
||||
AIToolsConfigService,
|
||||
} from '@affine/core/modules/ai-button';
|
||||
import type { AIDraftState } from '@affine/core/modules/ai-button/services/ai-draft';
|
||||
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
|
||||
import type { SubscriptionService } from '@affine/core/modules/cloud';
|
||||
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
|
||||
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
|
||||
import type { PeekViewService } from '@affine/core/modules/peek-view';
|
||||
@@ -167,6 +169,9 @@ export class AIChatContent extends SignalWatcher(
|
||||
@property({ attribute: false })
|
||||
accessor aiToolsConfigService!: AIToolsConfigService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor aiModelService!: AIModelService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor onEmbeddingProgressChange:
|
||||
| ((count: Record<ContextEmbedStatus, number>) => void)
|
||||
@@ -184,6 +189,9 @@ export class AIChatContent extends SignalWatcher(
|
||||
@property({ attribute: false })
|
||||
accessor peekViewService!: PeekViewService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor subscriptionService!: SubscriptionService;
|
||||
|
||||
@state()
|
||||
accessor chatContextValue: ChatContextValue = DEFAULT_CHAT_CONTEXT_VALUE;
|
||||
|
||||
@@ -462,6 +470,8 @@ export class AIChatContent extends SignalWatcher(
|
||||
.notificationService=${this.notificationService}
|
||||
.aiDraftService=${this.aiDraftService}
|
||||
.aiToolsConfigService=${this.aiToolsConfigService}
|
||||
.subscriptionService=${this.subscriptionService}
|
||||
.aiModelService=${this.aiModelService}
|
||||
.trackOptions=${{
|
||||
where: 'chat-panel',
|
||||
control: 'chat-send',
|
||||
|
||||
+17
-10
@@ -2,12 +2,15 @@ import type {
|
||||
AIDraftService,
|
||||
AIToolsConfigService,
|
||||
} from '@affine/core/modules/ai-button';
|
||||
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
|
||||
import type { SubscriptionService } from '@affine/core/modules/cloud';
|
||||
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
|
||||
import type { CopilotChatHistoryFragment } from '@affine/graphql';
|
||||
import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit';
|
||||
import { unsafeCSSVar, unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
|
||||
import type { EditorHost } from '@blocksuite/affine/std';
|
||||
import { ShadowlessElement } from '@blocksuite/affine/std';
|
||||
import type { NotificationService } from '@blocksuite/affine-shared/services';
|
||||
import { ArrowUpBigIcon, CloseIcon } from '@blocksuite/icons/lit';
|
||||
import { css, html, nothing, type PropertyValues } from 'lit';
|
||||
import { property, query, state } from 'lit/decorators.js';
|
||||
@@ -324,9 +327,6 @@ export class AIChatInput extends SignalWatcher(
|
||||
@state()
|
||||
accessor focused = false;
|
||||
|
||||
@state()
|
||||
accessor modelId: string | undefined = undefined;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor chatContextValue!: AIChatInputContext;
|
||||
|
||||
@@ -368,6 +368,15 @@ export class AIChatInput extends SignalWatcher(
|
||||
@property({ attribute: false })
|
||||
accessor affineFeatureFlagService!: FeatureFlagService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor notificationService!: NotificationService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor subscriptionService!: SubscriptionService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor aiModelService!: AIModelService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor isRootSession: boolean = true;
|
||||
|
||||
@@ -516,14 +525,15 @@ export class AIChatInput extends SignalWatcher(
|
||||
<div class="chat-input-footer-spacer"></div>
|
||||
<chat-input-preference
|
||||
.session=${this.session}
|
||||
.onModelChange=${this._handleModelChange}
|
||||
.modelId=${this.modelId}
|
||||
.extendedThinking=${this._isReasoningActive}
|
||||
.onExtendedThinkingChange=${this._toggleReasoning}
|
||||
.networkSearchVisible=${!!this.networkSearchConfig.visible.value}
|
||||
.isNetworkActive=${this._isNetworkActive}
|
||||
.onNetworkActiveChange=${this._toggleNetworkSearch}
|
||||
.toolsConfigService=${this.aiToolsConfigService}
|
||||
.notificationService=${this.notificationService}
|
||||
.subscriptionService=${this.subscriptionService}
|
||||
.aiModelService=${this.aiModelService}
|
||||
></chat-input-preference>
|
||||
${status === 'transmitting' || status === 'loading'
|
||||
? html`<button
|
||||
@@ -646,10 +656,6 @@ export class AIChatInput extends SignalWatcher(
|
||||
await this.send(value);
|
||||
};
|
||||
|
||||
private readonly _handleModelChange = (modelId: string) => {
|
||||
this.modelId = modelId;
|
||||
};
|
||||
|
||||
send = async (text: string) => {
|
||||
try {
|
||||
const {
|
||||
@@ -693,6 +699,7 @@ export class AIChatInput extends SignalWatcher(
|
||||
this.affineFeatureFlagService.flags.enable_send_detailed_object_to_ai
|
||||
.value;
|
||||
|
||||
const modelId = this.aiModelService.modelId.value;
|
||||
const stream = await AIProvider.actions.chat({
|
||||
sessionId,
|
||||
input: userInput,
|
||||
@@ -717,7 +724,7 @@ export class AIChatInput extends SignalWatcher(
|
||||
webSearch: this._isNetworkActive,
|
||||
reasoning: this._isReasoningActive,
|
||||
toolsConfig: this.aiToolsConfigService.config.value,
|
||||
modelId: this.modelId,
|
||||
modelId,
|
||||
});
|
||||
|
||||
for await (const text of stream) {
|
||||
|
||||
+95
-25
@@ -1,18 +1,29 @@
|
||||
import type { AIToolsConfigService } from '@affine/core/modules/ai-button';
|
||||
import type { CopilotChatHistoryFragment } from '@affine/graphql';
|
||||
import type { AIModelService } from '@affine/core/modules/ai-button/services/models';
|
||||
import type { SubscriptionService } from '@affine/core/modules/cloud';
|
||||
import {
|
||||
type CopilotChatHistoryFragment,
|
||||
SubscriptionStatus,
|
||||
} from '@affine/graphql';
|
||||
import {
|
||||
menu,
|
||||
popMenu,
|
||||
popupTargetFromElement,
|
||||
} from '@blocksuite/affine/components/context-menu';
|
||||
import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit';
|
||||
import { unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
|
||||
import type { NotificationService } from '@blocksuite/affine-shared/services';
|
||||
import {
|
||||
AiOutlineIcon,
|
||||
ArrowDownSmallIcon,
|
||||
CloudWorkspaceIcon,
|
||||
DoneIcon,
|
||||
LockIcon,
|
||||
ThinkingIcon,
|
||||
WebIcon,
|
||||
} from '@blocksuite/icons/lit';
|
||||
import { ShadowlessElement } from '@blocksuite/std';
|
||||
import { computed } from '@preact/signals-core';
|
||||
import { css, html } from 'lit';
|
||||
import { property } from 'lit/decorators.js';
|
||||
|
||||
@@ -48,16 +59,29 @@ export class ChatInputPreference extends SignalWatcher(
|
||||
white-space: nowrap;
|
||||
min-width: 220px;
|
||||
}
|
||||
.ai-active-model-name {
|
||||
font-size: 14px;
|
||||
color: ${unsafeCSSVarV2('text/secondary')};
|
||||
line-height: 22px;
|
||||
margin-left: 40px;
|
||||
}
|
||||
.ai-model-prefix {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
}
|
||||
.ai-model-prefix svg {
|
||||
color: ${unsafeCSSVarV2('icon/activated')};
|
||||
}
|
||||
.ai-model-version {
|
||||
font-size: 12px;
|
||||
color: ${unsafeCSSVarV2('text/tertiary')};
|
||||
line-height: 20px;
|
||||
margin-right: 40px;
|
||||
}
|
||||
`;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor session!: CopilotChatHistoryFragment | null | undefined;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor onModelChange: ((modelId: string) => void) | undefined;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor modelId: string | undefined = undefined;
|
||||
// --------- model props end ---------
|
||||
|
||||
// --------- extended thinking props start ---------
|
||||
@@ -86,9 +110,25 @@ export class ChatInputPreference extends SignalWatcher(
|
||||
@property({ attribute: false })
|
||||
accessor toolsConfigService!: AIToolsConfigService;
|
||||
|
||||
// private readonly _onModelChange = (modelId: string) => {
|
||||
// this.onModelChange?.(modelId);
|
||||
// };
|
||||
@property({ attribute: false })
|
||||
accessor notificationService!: NotificationService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor subscriptionService!: SubscriptionService;
|
||||
|
||||
@property({ attribute: false })
|
||||
accessor aiModelService!: AIModelService;
|
||||
|
||||
model = computed(() => {
|
||||
const modelId = this.aiModelService.modelId.value;
|
||||
const activeModel = this.aiModelService.models.value.find(
|
||||
model => model.id === modelId
|
||||
);
|
||||
const defaultModel = this.aiModelService.models.value.find(
|
||||
model => model.isDefault
|
||||
);
|
||||
return activeModel || defaultModel;
|
||||
});
|
||||
|
||||
openPreference(e: Event) {
|
||||
const element = e.currentTarget;
|
||||
@@ -97,20 +137,48 @@ export class ChatInputPreference extends SignalWatcher(
|
||||
const searchItems = [];
|
||||
|
||||
// model switch
|
||||
// modelItems.push(
|
||||
// menu.subMenu({
|
||||
// name: 'Model',
|
||||
// prefix: AiOutlineIcon(),
|
||||
// options: {
|
||||
// items: (this.session?.optionalModels ?? []).map(modelId => {
|
||||
// return menu.action({
|
||||
// name: modelId,
|
||||
// select: () => this._onModelChange(modelId),
|
||||
// });
|
||||
// }),
|
||||
// },
|
||||
// })
|
||||
// );
|
||||
modelItems.push(
|
||||
menu.subMenu({
|
||||
name: 'Model',
|
||||
prefix: AiOutlineIcon(),
|
||||
postfix: html`
|
||||
<span class="ai-active-model-name"> ${this.model.value?.name} </span>
|
||||
`,
|
||||
options: {
|
||||
items: this.aiModelService.models.value.map(model => {
|
||||
const isSelected = model.id === this.model.value?.id;
|
||||
const status =
|
||||
this.subscriptionService.subscription.ai$.value?.status;
|
||||
const isSubscribed = status === SubscriptionStatus.Active;
|
||||
return menu.action({
|
||||
name: model.category,
|
||||
info: html`
|
||||
<span class="ai-model-version">${model.version}</span>
|
||||
`,
|
||||
prefix: html`
|
||||
<div class="ai-model-prefix">
|
||||
${isSelected ? DoneIcon() : undefined}
|
||||
</div>
|
||||
`,
|
||||
postfix: html`
|
||||
<div>
|
||||
${model.isPro && !isSubscribed ? LockIcon() : undefined}
|
||||
</div>
|
||||
`,
|
||||
select: () => {
|
||||
if (model.isPro && !isSubscribed) {
|
||||
this.notificationService.toast(
|
||||
`Pro models require an AFFiNE AI subscription.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
this.aiModelService.setModel(model.id);
|
||||
},
|
||||
});
|
||||
}),
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
modelItems.push(
|
||||
menu.toggleSwitch({
|
||||
@@ -169,7 +237,9 @@ export class ChatInputPreference extends SignalWatcher(
|
||||
data-testid="chat-input-preference-trigger"
|
||||
class="chat-input-preference-trigger"
|
||||
>
|
||||
<span class="chat-input-preference-trigger-label"> Claude </span>
|
||||
<span class="chat-input-preference-trigger-label">
|
||||
${this.model.value?.category}
|
||||
</span>
|
||||
<span class="chat-input-preference-trigger-icon">
|
||||
${ArrowDownSmallIcon()}
|
||||
</span>
|
||||
|
||||
@@ -15,10 +15,12 @@ import {
|
||||
AIDraftService,
|
||||
AIToolsConfigService,
|
||||
} from '@affine/core/modules/ai-button';
|
||||
import { AIModelService } from '@affine/core/modules/ai-button/services/models';
|
||||
import {
|
||||
EventSourceService,
|
||||
FetchService,
|
||||
GraphQLService,
|
||||
SubscriptionService,
|
||||
} from '@affine/core/modules/cloud';
|
||||
import { WorkspaceDialogService } from '@affine/core/modules/dialogs';
|
||||
import { FeatureFlagService } from '@affine/core/modules/feature-flag';
|
||||
@@ -229,6 +231,8 @@ export const Component = () => {
|
||||
);
|
||||
content.aiDraftService = framework.get(AIDraftService);
|
||||
content.aiToolsConfigService = framework.get(AIToolsConfigService);
|
||||
content.subscriptionService = framework.get(SubscriptionService);
|
||||
content.aiModelService = framework.get(AIModelService);
|
||||
content.createSession = createSession;
|
||||
content.onOpenDoc = onOpenDoc;
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import {
|
||||
AIDraftService,
|
||||
AIToolsConfigService,
|
||||
} from '@affine/core/modules/ai-button';
|
||||
import { AIModelService } from '@affine/core/modules/ai-button/services/models';
|
||||
import { SubscriptionService } from '@affine/core/modules/cloud';
|
||||
import { WorkspaceDialogService } from '@affine/core/modules/dialogs';
|
||||
import { FeatureFlagService } from '@affine/core/modules/feature-flag';
|
||||
import { PeekViewService } from '@affine/core/modules/peek-view';
|
||||
@@ -104,6 +106,9 @@ export const EditorChatPanel = forwardRef(function EditorChatPanel(
|
||||
chatPanelRef.current.aiDraftService = framework.get(AIDraftService);
|
||||
chatPanelRef.current.aiToolsConfigService =
|
||||
framework.get(AIToolsConfigService);
|
||||
chatPanelRef.current.subscriptionService =
|
||||
framework.get(SubscriptionService);
|
||||
chatPanelRef.current.aiModelService = framework.get(AIModelService);
|
||||
|
||||
containerRef.current?.append(chatPanelRef.current);
|
||||
} else {
|
||||
|
||||
@@ -8,12 +8,14 @@ export {
|
||||
|
||||
import type { Framework } from '@toeverything/infra';
|
||||
|
||||
import { GraphQLService, ServerScope, SubscriptionService } from '../cloud';
|
||||
import { FeatureFlagService } from '../feature-flag';
|
||||
import { CacheStorage, GlobalStateService } from '../storage';
|
||||
import { WorkspaceScope } from '../workspace';
|
||||
import { AIButtonProvider } from './provider/ai-button';
|
||||
import { AIButtonService } from './services/ai-button';
|
||||
import { AIDraftService } from './services/ai-draft';
|
||||
import { AIModelService } from './services/models';
|
||||
import { AINetworkSearchService } from './services/network-search';
|
||||
import { AIPlaygroundService } from './services/playground';
|
||||
import { AIReasoningService } from './services/reasoning';
|
||||
@@ -49,3 +51,13 @@ export function configureAIDraftModule(framework: Framework) {
|
||||
export function configureAIToolsConfigModule(framework: Framework) {
|
||||
framework.service(AIToolsConfigService, [GlobalStateService]);
|
||||
}
|
||||
|
||||
export function configureAIModelModule(framework: Framework) {
|
||||
framework
|
||||
.scope(ServerScope)
|
||||
.service(AIModelService, [
|
||||
GlobalStateService,
|
||||
GraphQLService,
|
||||
SubscriptionService,
|
||||
]);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
import { getPromptModelsQuery, SubscriptionStatus } from '@affine/graphql';
|
||||
import {
|
||||
createSignalFromObservable,
|
||||
type Signal,
|
||||
} from '@blocksuite/affine/shared/utils';
|
||||
import { signal } from '@preact/signals-core';
|
||||
import { LiveData, Service } from '@toeverything/infra';
|
||||
|
||||
import type { GraphQLService, SubscriptionService } from '../../cloud';
|
||||
import type { GlobalStateService } from '../../storage';
|
||||
|
||||
const AI_MODEL_ID_KEY = 'AIModelId';
|
||||
|
||||
export interface AIModel {
|
||||
name: string;
|
||||
id: string;
|
||||
version: string;
|
||||
category: string;
|
||||
isPro: boolean;
|
||||
isDefault: boolean;
|
||||
}
|
||||
|
||||
export class AIModelService extends Service {
|
||||
modelId: Signal<string | undefined>;
|
||||
|
||||
models: Signal<AIModel[]> = signal([]);
|
||||
|
||||
private readonly modelId$ = LiveData.from(
|
||||
this.globalStateService.globalState.watch<string>(AI_MODEL_ID_KEY),
|
||||
undefined
|
||||
);
|
||||
|
||||
constructor(
|
||||
private readonly globalStateService: GlobalStateService,
|
||||
private readonly gqlService: GraphQLService,
|
||||
private readonly subscriptionService: SubscriptionService
|
||||
) {
|
||||
super();
|
||||
|
||||
const { signal: modelId, cleanup } = createSignalFromObservable<
|
||||
string | undefined
|
||||
>(this.modelId$, undefined);
|
||||
this.modelId = modelId;
|
||||
this.disposables.push(cleanup);
|
||||
|
||||
this.init().catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
}
|
||||
|
||||
resetModel = () => {
|
||||
this.globalStateService.globalState.set(AI_MODEL_ID_KEY, undefined);
|
||||
};
|
||||
|
||||
setModel = (modelId: string) => {
|
||||
const isSubscribed =
|
||||
this.subscriptionService.subscription.ai$.value?.status ===
|
||||
SubscriptionStatus.Active;
|
||||
const model = this.models.value.find(model => model.id === modelId);
|
||||
if (!isSubscribed && model?.isPro) {
|
||||
return;
|
||||
}
|
||||
this.globalStateService.globalState.set(AI_MODEL_ID_KEY, modelId);
|
||||
};
|
||||
|
||||
private readonly init = async () => {
|
||||
await this.initModels();
|
||||
|
||||
// subscribe to ai purchase status
|
||||
const sub = this.subscriptionService.subscription.ai$.subscribe(
|
||||
subscription => {
|
||||
const isSubscribed = subscription?.status === SubscriptionStatus.Active;
|
||||
const model = this.models.value.find(
|
||||
model => model.id === this.modelId.value
|
||||
);
|
||||
if (!isSubscribed && model?.isPro) {
|
||||
this.resetModel();
|
||||
}
|
||||
}
|
||||
);
|
||||
this.disposables.push(() => sub.unsubscribe());
|
||||
};
|
||||
|
||||
private readonly initModels = async (prompt?: string) => {
|
||||
const promptName = prompt || 'Chat With AFFiNE AI';
|
||||
const models = await this.getModelsByPrompt(promptName);
|
||||
if (models) {
|
||||
const { defaultModel, optionalModels, proModels } = models;
|
||||
this.models.value = optionalModels.map(model => {
|
||||
const [category] = model.name.split(' ');
|
||||
const version = model.name.slice(category.length + 1);
|
||||
return {
|
||||
name: model.name,
|
||||
id: model.id,
|
||||
version,
|
||||
category,
|
||||
isPro: proModels.some(proModel => proModel.id === model.id),
|
||||
isDefault: model.id === defaultModel,
|
||||
};
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
private readonly getModelsByPrompt = async (promptName: string) => {
|
||||
return this.gqlService
|
||||
.gql({
|
||||
query: getPromptModelsQuery,
|
||||
variables: { promptName },
|
||||
})
|
||||
.then(res => res.currentUser?.copilot?.models);
|
||||
};
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import { type Framework } from '@toeverything/infra';
|
||||
import {
|
||||
configureAIButtonModule,
|
||||
configureAIDraftModule,
|
||||
configureAIModelModule,
|
||||
configureAINetworkSearchModule,
|
||||
configureAIPlaygroundModule,
|
||||
configureAIReasoningModule,
|
||||
@@ -117,6 +118,7 @@ export function configureCommonModules(framework: Framework) {
|
||||
configureAIButtonModule(framework);
|
||||
configureAIDraftModule(framework);
|
||||
configureAIToolsConfigModule(framework);
|
||||
configureAIModelModule(framework);
|
||||
configureTemplateDocModule(framework);
|
||||
configureBlobManagementModule(framework);
|
||||
configureMediaModule(framework);
|
||||
|
||||
Reference in New Issue
Block a user