From e222f06e940270bd7d3334d4eeaa0b628b78e546 Mon Sep 17 00:00:00 2001 From: DarkSky <25152247+darkskygit@users.noreply.github.com> Date: Wed, 13 May 2026 21:57:50 +0800 Subject: [PATCH] feat(editor): extract chat runtime (#14937) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #### PR Dependency Tree * **PR #14937** 👈 This tree was auto-generated by [Charcoal](https://github.com/danerwilliams/charcoal) ## Summary by CodeRabbit * **New Features** * Centralized AI event system and a runtime powering chat sessions and actions. * **Improvements** * Chat UI (composer, messages, toolbar, tabs, panels) now syncs with runtime snapshots for more consistent state. * Improved session/tab lifecycle (create, fork, delete), context embedding status, and history handling. * More reliable send/stop/retry flows, better telemetry scoping, and clearer upgrade/login/insert-template prompts. --- .../ai/_common/chat-actions-handle.ts | 33 +- .../core/src/blocksuite/ai/_common/config.ts | 4 +- .../src/blocksuite/ai/actions/doc-handler.ts | 22 +- .../blocksuite/ai/actions/edgeless-handler.ts | 25 +- .../ai/actions/edgeless-response.ts | 16 +- .../core/src/blocksuite/ai/ai-panel.ts | 24 +- .../ai-chat-composer/ai-chat-composer.ts | 521 +++---- .../ai-chat-content/ai-chat-content.spec.ts | 136 +- .../ai-chat-content/ai-chat-content.ts | 151 +- .../components/ai-chat-input/ai-chat-input.ts | 228 +-- .../ai-chat-messages/ai-chat-messages.spec.ts | 31 + .../ai-chat-messages/ai-chat-messages.ts | 146 +- .../ai-chat-messages/preload-config.ts | 12 +- .../ai-chat-toolbar/ai-chat-tabs.ts | 84 +- .../ai-chat-toolbar/ai-chat-toolbar.ts | 56 +- .../ai-chat-toolbar/ai-session-history.ts | 118 +- .../configure-ai-chat-toolbar.ts | 18 +- .../ai-history-clear/ai-history-clear.ts | 13 +- .../ai/components/ask-ai-toolbar.ts | 4 +- .../ai/components/playground/chat.ts | 156 +- .../ai/components/playground/content.ts | 55 +- .../ai/entries/edgeless/actions-config.ts | 4 +- .../blocksuite/ai/entries/edgeless/index.ts | 6 +- .../ai/entries/space/setup-space.ts | 4 +- .../frontend/core/src/blocksuite/ai/index.ts | 2 + .../core/src/blocksuite/ai/messages/error.ts | 6 +- .../ai/peek-view/chat-block-peek-view.ts | 218 ++- .../blocksuite/ai/provider/ai-app-events.ts | 19 + .../src/blocksuite/ai/provider/ai-provider.ts | 265 ---- .../blocksuite/ai/provider/event-source.ts | 2 +- .../core/src/blocksuite/ai/provider/index.ts | 2 +- .../blocksuite/ai/provider/request.spec.ts | 284 ---- .../ai/provider/setup-provider.spec.ts | 120 -- .../blocksuite/ai/provider/setup-provider.tsx | 776 +--------- .../src/blocksuite/ai/provider/tracker.ts | 39 +- .../src/blocksuite/ai/runtime/chat/actions.ts | 60 + .../src/blocksuite/ai/runtime/chat/index.ts | 6 + .../ai/runtime/chat/runtime.spec.ts | 1001 +++++++++++++ .../src/blocksuite/ai/runtime/chat/runtime.ts | 1311 +++++++++++++++++ .../ai/runtime/chat/session-strategy.ts | 225 +++ .../src/blocksuite/ai/runtime/chat/state.ts | 212 +++ .../ai/runtime/chat/use-element.spec.tsx | 90 ++ .../blocksuite/ai/runtime/chat/use-element.ts | 62 + .../blocksuite/ai/runtime/chat/use-runtime.ts | 22 + .../ai/runtime/request/action-definitions.ts | 225 +++ .../ai/runtime/request/byok-local-lease.ts | 97 ++ .../request}/copilot-client.spec.ts | 0 .../request}/copilot-client.ts | 2 +- .../blocksuite/ai/runtime/request/index.ts | 4 + .../request/message-transport.ts} | 113 +- .../blocksuite/ai/runtime/request/provider.ts | 18 + .../ai/runtime/request/service.spec.ts | 379 +++++ .../blocksuite/ai/runtime/request/service.ts | 396 +++++ .../blocksuite/ai/utils/action-reporter.ts | 14 +- .../ai/widgets/edgeless-copilot/index.ts | 4 +- .../providers/workspace-side-effects.tsx | 13 +- .../pages/workspace/chat-panel-utils.ts | 142 -- .../desktop/pages/workspace/chat/index.tsx | 488 ++---- .../workspace/detail-page/detail-page.tsx | 13 +- .../tabs/chat-panel-session.spec.ts | 341 ----- .../detail-page/tabs/chat-panel-session.ts | 208 --- .../pages/workspace/detail-page/tabs/chat.tsx | 800 +++------- .../view/doc-preview/doc-peek-view.tsx | 10 +- .../e2e/utils/settings-panel-utils.ts | 2 +- 64 files changed, 5348 insertions(+), 4510 deletions(-) create mode 100644 packages/frontend/core/src/blocksuite/ai/provider/ai-app-events.ts delete mode 100644 packages/frontend/core/src/blocksuite/ai/provider/request.spec.ts delete mode 100644 packages/frontend/core/src/blocksuite/ai/provider/setup-provider.spec.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/actions.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/index.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/runtime.spec.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/runtime.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/session-strategy.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/state.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/use-element.spec.tsx create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/use-element.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/chat/use-runtime.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/request/action-definitions.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/request/byok-local-lease.ts rename packages/frontend/core/src/blocksuite/ai/{provider => runtime/request}/copilot-client.spec.ts (100%) rename packages/frontend/core/src/blocksuite/ai/{provider => runtime/request}/copilot-client.ts (99%) create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/request/index.ts rename packages/frontend/core/src/blocksuite/ai/{provider/request.ts => runtime/request/message-transport.ts} (70%) create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/request/provider.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/request/service.spec.ts create mode 100644 packages/frontend/core/src/blocksuite/ai/runtime/request/service.ts delete mode 100644 packages/frontend/core/src/desktop/pages/workspace/chat-panel-utils.ts delete mode 100644 packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat-panel-session.spec.ts delete mode 100644 packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat-panel-session.ts diff --git a/packages/frontend/core/src/blocksuite/ai/_common/chat-actions-handle.ts b/packages/frontend/core/src/blocksuite/ai/_common/chat-actions-handle.ts index 431d1e2c21..3e88b53658 100644 --- a/packages/frontend/core/src/blocksuite/ai/_common/chat-actions-handle.ts +++ b/packages/frontend/core/src/blocksuite/ai/_common/chat-actions-handle.ts @@ -42,7 +42,9 @@ import type { TemplateResult } from 'lit'; import { insertFromMarkdown } from '../../utils'; import type { ChatMessage } from '../components/ai-chat-messages'; -import { AIProvider, type AIUserInfo } from '../provider'; +import { AIAppEvents, type AIUserInfo } from '../provider'; +import { AIChatRuntime, ForkAIChatSessionStrategy } from '../runtime/chat'; +import { getAIRequestService } from '../runtime/request'; import { reportResponse } from '../utils/action-reporter'; import { insertBelow } from '../utils/editor-actions'; @@ -72,7 +74,7 @@ export async function queryHistoryMessages( docId?: string ) { // Get fork session messages - const histories = await AIProvider.histories?.chats( + const histories = await getAIRequestService().histories.chats( workspaceId, forkSessionId, docId @@ -114,7 +116,7 @@ export async function constructRootChatBlockMessages( forkSessionId: string ) { // Convert chat messages to AI chat block messages - const userInfo = await AIProvider.userInfo; + const userInfo = AIAppEvents.userInfo.value; const forkMessages = (await queryHistoryMessages( doc.workspace.id, forkSessionId, @@ -247,7 +249,7 @@ async function insertBelowBlock( ): Promise { if (!block) return false; - reportResponse('result:insert'); + reportResponse('result:insert', host); await insertBelow(host, content, block); return true; } @@ -376,13 +378,20 @@ const SAVE_AS_BLOCK: ChatAction = { }); } - try { - const newSessionId = await AIProvider.forkChat?.({ + const runtime = new AIChatRuntime({ + request: getAIRequestService(), + scope: { + kind: 'fork', workspaceId: host.store.workspace.id, docId: host.store.id, - sessionId: parentSessionId, + parentSessionId, latestMessageId: messageId, - }); + }, + strategy: new ForkAIChatSessionStrategy(), + }); + try { + const newSession = await runtime.createSession(); + const newSessionId = newSession?.sessionId; if (!newSessionId) { return false; @@ -424,6 +433,8 @@ const SAVE_AS_BLOCK: ChatAction = { onClose: function (): void {}, }); return false; + } finally { + runtime.dispose(); } }, }; @@ -439,7 +450,7 @@ const ADD_TO_EDGELESS_AS_NOTE = { }, toast: 'New note created', handler: async (host: EditorHost, content: string): Promise => { - reportResponse('result:add-note'); + reportResponse('result:add-note', host); const { store } = host; const gfx = host.std.get(GfxControllerIdentifier); @@ -475,7 +486,7 @@ export const SAVE_AS_DOC = { showWhen: () => true, toast: 'New doc created', handler: (host: EditorHost, content: string) => { - reportResponse('result:add-page'); + reportResponse('result:add-page', host); const doc = host.store.workspace.createDoc(); const newDoc = doc.getStore(); newDoc.load(); @@ -518,7 +529,7 @@ const CREATE_AS_LINKED_DOC = { }, toast: 'New doc created', handler: async (host: EditorHost, content: string) => { - reportResponse('result:add-page'); + reportResponse('result:add-page', host); const { store } = host; const surfaceBlock = store diff --git a/packages/frontend/core/src/blocksuite/ai/_common/config.ts b/packages/frontend/core/src/blocksuite/ai/_common/config.ts index 7de82072b9..28af0c3529 100644 --- a/packages/frontend/core/src/blocksuite/ai/_common/config.ts +++ b/packages/frontend/core/src/blocksuite/ai/_common/config.ts @@ -36,7 +36,7 @@ import type { AIItemGroupConfig, AISubItemConfig, } from '../components/ai-item/types'; -import { AIProvider } from '../provider'; +import { AIAppEvents } from '../provider'; import { getAIPanelWidget } from '../utils/ai-widgets'; import { getEdgelessCopilotWidget } from '../utils/get-edgeless-copilot-widget'; import { @@ -386,7 +386,7 @@ const OthersAIGroup: AIItemGroupConfig = { handler: host => { const panel = getAIPanelWidget(host); const edgelessCopilot = getEdgelessCopilotWidget(host); - AIProvider.slots.requestOpenWithChat.next({ + AIAppEvents.requestOpenWithChat.next({ host, autoSelect: true, }); diff --git a/packages/frontend/core/src/blocksuite/ai/actions/doc-handler.ts b/packages/frontend/core/src/blocksuite/ai/actions/doc-handler.ts index 094ab1f667..585999e06a 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/doc-handler.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/doc-handler.ts @@ -12,7 +12,8 @@ import { } from '../ai-panel'; import { StreamObjectSchema } from '../components/ai-chat-messages'; import { type AIItemGroupConfig } from '../components/ai-item/types'; -import { type AIError, AIProvider } from '../provider'; +import { type AIError } from '../provider'; +import { getAIRequestService } from '../runtime/request'; import { reportResponse } from '../utils/action-reporter'; import { getAIPanelWidget } from '../utils/ai-widgets'; import { AIContext } from '../utils/context'; @@ -33,10 +34,12 @@ export function bindTextStream( update, finish, signal, + host, }: { update: (answer: AIActionAnswer) => void; finish: (state: 'success' | 'error' | 'aborted', err?: AIError) => void; signal?: AbortSignal; + host?: EditorHost; } ) { (async () => { @@ -45,7 +48,7 @@ export function bindTextStream( }; signal?.addEventListener('abort', () => { finish('aborted'); - reportResponse('aborted:stop'); + reportResponse('aborted:stop', host); }); for await (const data of stream) { if (signal?.aborted) { @@ -88,9 +91,6 @@ function actionToStream( >, trackerOptions?: BlockSuitePresets.TrackerOptions ): BlockSuitePresets.TextStream | undefined { - const action = AIProvider.actions[id]; - if (!action || typeof action !== 'function') return; - let stream: BlockSuitePresets.TextStream | undefined; return { async *[Symbol.asyncIterator]() { @@ -123,9 +123,11 @@ function actionToStream( where, docId: host.store.id, workspaceId: host.store.workspace.id, - } as Parameters[0]; - // @ts-expect-error TODO(@Peng): maybe fix this - stream = await action(options); + } as BlockSuitePresets.AITextActionOptions & Record; + stream = (await getAIRequestService().executeAction( + id, + options + )) as BlockSuitePresets.TextStream; if (!stream) return; yield* stream; }, @@ -163,7 +165,7 @@ function actionToGenerateAnswer( trackerOptions ); if (!stream) return; - bindTextStream(stream, { update, finish, signal }); + bindTextStream(stream, { update, finish, signal, host }); }; } @@ -198,7 +200,7 @@ function updateAIPanelConfig( config.errorStateConfig = buildErrorConfig(aiPanel); config.copy = buildCopyConfig(aiPanel); config.discardCallback = () => { - reportResponse('result:discard'); + reportResponse('result:discard', host); }; } diff --git a/packages/frontend/core/src/blocksuite/ai/actions/edgeless-handler.ts b/packages/frontend/core/src/blocksuite/ai/actions/edgeless-handler.ts index fc34c2eb17..fcfe9a1034 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/edgeless-handler.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/edgeless-handler.ts @@ -21,7 +21,8 @@ import type { TemplateResult } from 'lit'; import { getContentFromSlice } from '../../utils'; import { AIChatBlockModel } from '../blocks'; -import { type AIError, AIProvider } from '../provider'; +import { type AIError } from '../provider'; +import { getAIRequestService } from '../runtime/request'; import { reportResponse } from '../utils/action-reporter'; import { getAIPanelWidget } from '../utils/ai-widgets'; import { AIContext } from '../utils/context'; @@ -178,10 +179,6 @@ function actionToStream( trackerOptions?: BlockSuitePresets.TrackerOptions, panelInput?: string ) { - const action = AIProvider.actions[id]; - - if (!action || typeof action !== 'function') return; - if (extract && typeof extract === 'function') { return (host: EditorHost, ctx: AIContext): BlockSuitePresets.TextStream => { let stream: BlockSuitePresets.TextStream | undefined; @@ -201,7 +198,7 @@ function actionToStream( host, docId: host.store.id, workspaceId: host.store.workspace.id, - } as Parameters[0]; + } as BlockSuitePresets.AITextActionOptions & Record; const content = ctx.get().content; if (typeof content === 'string' && !content.length && panelInput) { @@ -214,8 +211,10 @@ function actionToStream( Object.assign(options, data); } - // @ts-expect-error TODO(@Peng): maybe fix this - stream = await action(options); + stream = (await getAIRequestService().executeAction( + id, + options + )) as BlockSuitePresets.TextStream; if (!stream) return; yield* stream; }, @@ -242,10 +241,12 @@ function actionToStream( host, docId: host.store.id, workspaceId: host.store.workspace.id, - } as Parameters[0]; + } as BlockSuitePresets.AITextActionOptions & Record; - // @ts-expect-error TODO(@Peng): maybe fix this - stream = await action(options); + stream = (await getAIRequestService().executeAction( + id, + options + )) as BlockSuitePresets.TextStream; if (!stream) return; yield* stream; }, @@ -351,7 +352,7 @@ function updateEdgelessAIPanelConfig< }, }; config.discardCallback = () => { - reportResponse('result:discard'); + reportResponse('result:discard', host); }; config.hideCallback = () => { aiPanel.updateComplete diff --git a/packages/frontend/core/src/blocksuite/ai/actions/edgeless-response.ts b/packages/frontend/core/src/blocksuite/ai/actions/edgeless-response.ts index e8e7ac0856..4ee6fadb0a 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/edgeless-response.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/edgeless-response.ts @@ -36,7 +36,7 @@ import { styleMap } from 'lit/directives/style-map.js'; import { insertFromMarkdown } from '../../utils'; import type { ChatContextValue } from '../components/ai-chat-content/type'; import type { AIItemConfig } from '../components/ai-item/types'; -import { AIProvider } from '../provider'; +import { AIAppEvents } from '../provider'; import { reportResponse } from '../utils/action-reporter'; import { getAIPanelWidget } from '../utils/ai-widgets'; import type { AIContext } from '../utils/context'; @@ -98,7 +98,7 @@ export function retry(panel: AffineAIPanelWidget): AIItemConfig { icon: ResetIcon(), testId: 'answer-retry', handler: () => { - reportResponse('result:retry'); + reportResponse('result:retry', panel.host); panel.generate(); }, }; @@ -152,7 +152,7 @@ export function createInsertItems( ); }, handler: () => { - reportResponse('result:insert'); + reportResponse('result:insert', host); edgelessResponseHandler(id, host, ctx).catch(console.error); const panel = getAIPanelWidget(host); panel.hide(); @@ -201,7 +201,7 @@ export function asCaption( return id === 'generateCaption' && !!panel.answer; }, handler: () => { - reportResponse('result:use-as-caption'); + reportResponse('result:use-as-caption', host); const panel = getAIPanelWidget(host); const caption = panel.answer; if (!caption) return; @@ -589,7 +589,7 @@ export function actionToResponse( testId: 'answer-continue-in-chat', icon: ChatWithAiIcon({}), handler: () => { - reportResponse('result:continue-in-chat'); + reportResponse('result:continue-in-chat', host); edgelesContinueResponseHandler(id, host, ctx).catch( console.error ); @@ -700,7 +700,7 @@ async function edgelesContinueResponseHandler< } const panel = getAIPanelWidget(host); - AIProvider.slots.requestOpenWithChat.next({ + AIAppEvents.requestOpenWithChat.next({ host, context, fromAnswer: true, @@ -732,11 +732,11 @@ export function actionToErrorResponse< ): ErrorConfig { return { upgrade: () => { - AIProvider.slots.requestUpgradePlan.next({ host: panel.host }); + AIAppEvents.requestUpgradePlan.next({ host: panel.host }); panel.hide(); }, login: () => { - AIProvider.slots.requestLogin.next({ host: panel.host }); + AIAppEvents.requestLogin.next({ host: panel.host }); panel.hide(); }, cancel: () => { diff --git a/packages/frontend/core/src/blocksuite/ai/ai-panel.ts b/packages/frontend/core/src/blocksuite/ai/ai-panel.ts index 5d4414d9f4..9f35b86e89 100644 --- a/packages/frontend/core/src/blocksuite/ai/ai-panel.ts +++ b/packages/frontend/core/src/blocksuite/ai/ai-panel.ts @@ -35,7 +35,7 @@ import { } from './actions/page-response'; import type { AIItemConfig } from './components/ai-item/types'; import { createAIScrollableTextRenderer } from './components/ai-scrollable-text-renderer'; -import { AIProvider } from './provider'; +import { AIAppEvents } from './provider'; import { reportResponse } from './utils/action-reporter'; import { getAIPanelWidget } from './utils/ai-widgets'; import { AIContext } from './utils/context'; @@ -58,7 +58,7 @@ function asCaption( return id === 'generateCaption' && !!panel.answer; }, handler: () => { - reportResponse('result:use-as-caption'); + reportResponse('result:use-as-caption', host); const panel = getAIPanelWidget(host); const caption = panel.answer; if (!caption) return; @@ -85,7 +85,7 @@ function createNewNote(host: EditorHost): AIItemConfig { return !!panel.answer && isInsideEdgelessEditor(host); }, handler: () => { - reportResponse('result:add-note'); + reportResponse('result:add-note', host); // get the note block const { selectedBlocks } = getSelections(host); if (!selectedBlocks || !selectedBlocks.length) return; @@ -157,7 +157,7 @@ function buildPageResponseConfig( showWhen: () => !!panel.answer && (!id || !INSERT_ABOVE_ACTIONS.includes(id)), handler: () => { - reportResponse('result:insert'); + reportResponse('result:insert', host); pageResponseHandler(id, host, ctx, 'after').catch(console.error); panel.hide(); }, @@ -169,7 +169,7 @@ function buildPageResponseConfig( showWhen: () => !!panel.answer && !!id && INSERT_ABOVE_ACTIONS.includes(id), handler: () => { - reportResponse('result:insert'); + reportResponse('result:insert', host); pageResponseHandler(id, host, ctx, 'before').catch(console.error); panel.hide(); }, @@ -182,7 +182,7 @@ function buildPageResponseConfig( showWhen: () => !!panel.answer && !EXCLUDING_REPLACE_ACTIONS.includes(id), handler: () => { - reportResponse('result:replace'); + reportResponse('result:replace', host); replaceWithMarkdown(host).catch(console.error); panel.hide(); }, @@ -199,8 +199,8 @@ function buildPageResponseConfig( icon: ChatWithAiIcon(), testId: 'answer-continue-in-chat', handler: () => { - reportResponse('result:continue-in-chat'); - AIProvider.slots.requestOpenWithChat.next({ host }); + reportResponse('result:continue-in-chat', host); + AIAppEvents.requestOpenWithChat.next({ host }); panel.hide(); }, }, @@ -209,7 +209,7 @@ function buildPageResponseConfig( icon: ResetIcon(), testId: 'answer-regenerate', handler: () => { - reportResponse('result:retry'); + reportResponse('result:retry', host); panel.generate(); }, }, @@ -237,7 +237,7 @@ export function buildErrorResponseConfig(panel: AffineAIPanelWidget) { testId: 'error-retry', showWhen: () => true, handler: () => { - reportResponse('result:retry'); + reportResponse('result:retry', panel.host); panel.generate(); }, }, @@ -269,11 +269,11 @@ export function buildFinishConfig( export function buildErrorConfig(panel: AffineAIPanelWidget) { return { upgrade: () => { - AIProvider.slots.requestUpgradePlan.next({ host: panel.host }); + AIAppEvents.requestUpgradePlan.next({ host: panel.host }); panel.hide(); }, login: () => { - AIProvider.slots.requestLogin.next({ host: panel.host }); + AIAppEvents.requestLogin.next({ host: panel.host }); panel.hide(); }, cancel: () => { diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-composer/ai-chat-composer.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-composer/ai-chat-composer.ts index bed0858a20..3508c9b09d 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-composer/ai-chat-composer.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-composer/ai-chat-composer.ts @@ -10,14 +10,7 @@ import type { SubscriptionService, } from '@affine/core/modules/cloud'; import type { WorkspaceDialogService } from '@affine/core/modules/dialogs'; -import type { - ContextEmbedStatus, - ContextWorkspaceEmbeddingStatus, - CopilotChatHistoryFragment, - CopilotContextBlob, - CopilotContextDoc, - CopilotContextFile, -} from '@affine/graphql'; +import type { CopilotChatHistoryFragment } from '@affine/graphql'; import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit'; import type { EditorHost } from '@blocksuite/affine/std'; import { ShadowlessElement } from '@blocksuite/affine/std'; @@ -30,14 +23,16 @@ import { css, html, type PropertyValues } from 'lit'; import { property, state } from 'lit/decorators.js'; import { + AIAppEvents, type AIChatParams, - AIProvider, type AISendParams, } from '../../provider'; +import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat'; import type { SearchMenuConfig } from '../ai-chat-add-context'; import type { AttachmentChip, ChatChip, + ChipState, CollectionChip, DocChip, DocDisplayConfig, @@ -58,8 +53,6 @@ import { import type { AIChatInputContext, AIReasoningConfig } from '../ai-chat-input'; import { MAX_IMAGE_COUNT } from '../ai-chat-input/const'; -export const EMBEDDING_STATUS_CHECK_INTERVAL = 10000; - export class AIChatComposer extends SignalWatcher( WithDisposable(ShadowlessElement) ) { @@ -91,9 +84,10 @@ export class AIChatComposer extends SignalWatcher( accessor session!: CopilotChatHistoryFragment | null | undefined; @property({ attribute: false }) - accessor createSession!: () => Promise< - CopilotChatHistoryFragment | undefined - >; + accessor runtime: AIChatRuntime | null | undefined; + + @property({ attribute: false }) + accessor runtimeSnapshot: AIChatSnapshot | null | undefined; @property({ attribute: false }) accessor chatContextValue!: AIChatInputContext; @@ -101,11 +95,6 @@ export class AIChatComposer extends SignalWatcher( @property({ attribute: false }) accessor updateContext!: (context: Partial) => void; - @property({ attribute: false }) - accessor onEmbeddingProgressChange: - | ((count: Record) => void) - | undefined; - @property({ attribute: false }) accessor docDisplayConfig!: DocDisplayConfig; @@ -160,12 +149,6 @@ export class AIChatComposer extends SignalWatcher( @state() accessor embeddingCompleted = false; - private _contextId: string | undefined = undefined; - - private _pollAbortController: AbortController | null = null; - - private _pollEmbeddingStatusAbortController: AbortController | null = null; - override render() { return html` { @@ -276,100 +269,115 @@ export class AIChatComposer extends SignalWatcher( return this.chips.some(chip => chip.state === 'processing'); } - private readonly _getContextId = async () => { - if (this._contextId) { - return this._contextId; + private readonly toChipState = (state?: string): ChipState => { + if (state === 'finished' || state === 'processing' || state === 'failed') { + return state; } - - const sessionId = this.session?.sessionId; - if (!sessionId) return; - - const contextId = await AIProvider.context?.getContextId( - this.workspaceId, - sessionId - ); - this._contextId = contextId; - return this._contextId; + return 'processing'; }; - private readonly createContextId = async () => { - if (this._contextId) { - return this._contextId; + private readonly runtimeItemToChip = ( + item: AIChatSnapshot['composer']['context']['items'][number] + ): ChatChip => { + switch (item.kind) { + case 'doc': + return { + docId: item.docId, + state: this.toChipState(item.state), + createdAt: item.createdAt, + tooltip: item.tooltip, + }; + case 'file': + return { + file: item.file, + fileId: item.fileId, + blobId: item.blobId, + state: this.toChipState(item.state), + createdAt: item.createdAt, + tooltip: item.tooltip, + }; + case 'tag': + return { + tagId: item.tagId, + state: this.toChipState(item.state), + createdAt: item.createdAt, + tooltip: item.tooltip, + }; + case 'collection': + return { + collectionId: item.collectionId, + state: this.toChipState(item.state), + createdAt: item.createdAt, + tooltip: item.tooltip, + }; + case 'blob': + return { + sourceId: item.blobId, + name: item.blobId, + state: this.toChipState(item.state), + createdAt: item.createdAt, + tooltip: item.tooltip, + }; } + }; - const sessionId = (await this.createSession())?.sessionId; - if (!sessionId) return; + private readonly syncChipsFromRuntimeSnapshot = ( + snapshot = this.runtimeSnapshot + ) => { + const context = snapshot?.composer.context; + if (!context) return; + this.embeddingCompleted = context.embeddingCompleted; + const selectedChips = this.chips.filter(isSelectedContextChip); + this.updateChips([ + ...context.items.map(this.runtimeItemToChip), + ...selectedChips, + ]); + }; - this._contextId = await AIProvider.context?.createContext( - this.workspaceId, - sessionId - ); - return this._contextId; + private readonly syncChipsFromRuntime = () => { + this.syncChipsFromRuntimeSnapshot(this.runtime?.getSnapshot()); + }; + + private readonly chipToContextItem = ( + chip: ChatChip + ): AIChatSnapshot['composer']['context']['items'][number] | null => { + if (isDocChip(chip)) { + return { kind: 'doc', docId: chip.docId, state: chip.state }; + } + if (isFileChip(chip)) { + return { + kind: 'file', + file: chip.file, + fileId: chip.fileId ?? undefined, + blobId: chip.blobId ?? undefined, + state: chip.state, + }; + } + if (isTagChip(chip)) { + return { + kind: 'tag', + tagId: chip.tagId, + docIds: this.docDisplayConfig.getTagPageIds(chip.tagId), + state: chip.state, + }; + } + if (isCollectionChip(chip)) { + return { + kind: 'collection', + collectionId: chip.collectionId, + docIds: this.docDisplayConfig.getCollectionPageIds(chip.collectionId), + state: chip.state, + }; + } + if (isAttachmentChip(chip)) { + return { kind: 'blob', blobId: chip.sourceId, state: chip.state }; + } + return null; }; private readonly initChips = async () => { - // context not initialized - const sessionId = this.session?.sessionId; - const contextId = await this._getContextId(); - if (!sessionId || !contextId) { - return; - } - - // context initialized, show the chips - const { - docs = [], - files = [], - tags = [], - collections = [], - } = (await AIProvider.context?.getContextDocsAndFiles( - this.workspaceId, - sessionId, - contextId - )) || {}; - - const docChips: DocChip[] = docs.map(doc => ({ - docId: doc.id, - state: doc.status || 'processing', - createdAt: doc.createdAt, - })); - - const fileChips: FileChip[] = await Promise.all( - files.map(async file => { - return { - file: new File([], file.name), - blobId: file.blobId, - fileId: file.id, - state: file.status, - tooltip: file.error, - createdAt: file.createdAt, - }; - }) - ); - - const tagChips: TagChip[] = tags.map(tag => ({ - tagId: tag.id, - state: 'finished', - createdAt: tag.createdAt, - })); - - const collectionChips: CollectionChip[] = collections.map(collection => ({ - collectionId: collection.id, - state: 'finished', - createdAt: collection.createdAt, - })); - - const chips: ChatChip[] = [ - ...docChips, - ...fileChips, - ...tagChips, - ...collectionChips, - ].sort((a, b) => { - const aTime = a.createdAt ?? Date.now(); - const bTime = b.createdAt ?? Date.now(); - return aTime - bTime; - }); - - this.updateChips(chips); + await this.runtime?.dispatch({ type: 'loadContext' }); + this.syncChipsFromRuntime(); }; private readonly updateChips = (chips: ChatChip[]) => { @@ -487,14 +495,11 @@ export class AIChatComposer extends SignalWatcher( private readonly addDocToContext = async (chip: DocChip) => { try { - const contextId = await this.createContextId(); - if (!contextId || !AIProvider.context) { - throw new Error('Context not found'); - } - await AIProvider.context.addContextDoc({ - contextId, - docId: chip.docId, + await this.runtime?.dispatch({ + type: 'addContextItem', + item: { kind: 'doc', docId: chip.docId, state: chip.state }, }); + this.syncChipsFromRuntime(); } catch (e) { this.updateChip(chip, { state: 'failed', @@ -505,18 +510,17 @@ export class AIChatComposer extends SignalWatcher( private readonly addFileToContext = async (chip: FileChip) => { try { - const contextId = await this.createContextId(); - if (!contextId || !AIProvider.context) { - throw new Error('Context not found'); - } - const contextFile = await AIProvider.context.addContextFile(chip.file, { - contextId, - }); - this.updateChip(chip, { - state: contextFile.status, - blobId: contextFile.blobId, - fileId: contextFile.id, + await this.runtime?.dispatch({ + type: 'addContextItem', + item: { + kind: 'file', + file: chip.file, + fileId: chip.fileId ?? undefined, + blobId: chip.blobId ?? undefined, + state: chip.state, + }, }); + this.syncChipsFromRuntime(); } catch (e) { this.updateChip(chip, { state: 'failed', @@ -527,20 +531,16 @@ export class AIChatComposer extends SignalWatcher( private readonly addTagToContext = async (chip: TagChip) => { try { - const contextId = await this.createContextId(); - if (!contextId || !AIProvider.context) { - throw new Error('Context not found'); - } - // TODO: server side docIds calculation - const docIds = this.docDisplayConfig.getTagPageIds(chip.tagId); - await AIProvider.context.addContextTag({ - contextId, - tagId: chip.tagId, - docIds, - }); - this.updateChip(chip, { - state: 'finished', + await this.runtime?.dispatch({ + type: 'addContextItem', + item: { + kind: 'tag', + tagId: chip.tagId, + docIds: this.docDisplayConfig.getTagPageIds(chip.tagId), + state: chip.state, + }, }); + this.syncChipsFromRuntime(); } catch (e) { this.updateChip(chip, { state: 'failed', @@ -551,22 +551,16 @@ export class AIChatComposer extends SignalWatcher( private readonly addCollectionToContext = async (chip: CollectionChip) => { try { - const contextId = await this.createContextId(); - if (!contextId || !AIProvider.context) { - throw new Error('Context not found'); - } - // TODO: server side docIds calculation - const docIds = this.docDisplayConfig.getCollectionPageIds( - chip.collectionId - ); - await AIProvider.context.addContextCollection({ - contextId, - collectionId: chip.collectionId, - docIds, - }); - this.updateChip(chip, { - state: 'finished', + await this.runtime?.dispatch({ + type: 'addContextItem', + item: { + kind: 'collection', + collectionId: chip.collectionId, + docIds: this.docDisplayConfig.getCollectionPageIds(chip.collectionId), + state: chip.state, + }, }); + this.syncChipsFromRuntime(); } catch (e) { this.updateChip(chip, { state: 'failed', @@ -579,19 +573,12 @@ export class AIChatComposer extends SignalWatcher( private readonly addAttachmentChipToContext = async ( chip: AttachmentChip ) => { - const contextId = await this.createContextId(); - if (!contextId || !AIProvider.context) { - throw new Error('Context not found'); - } try { - const contextBlob = await AIProvider.context.addContextBlob({ - blobId: chip.sourceId, - contextId, - }); - this.updateChip(chip, { - state: contextBlob.status || 'processing', - blobId: chip.sourceId, + await this.runtime?.dispatch({ + type: 'addContextItem', + item: { kind: 'blob', blobId: chip.sourceId, state: chip.state }, }); + this.syncChipsFromRuntime(); } catch (e) { this.updateChip(chip, { state: 'failed', @@ -604,48 +591,19 @@ export class AIChatComposer extends SignalWatcher( private readonly removeFromContext = async ( chip: ChatChip ): Promise => { + if (isSelectedContextChip(chip)) { + this.updateContext({ + ...this.chatContextValue, + snapshot: null, + combinedElementsMarkdown: null, + }); + return true; + } + const item = this.chipToContextItem(chip); + if (!item) return true; try { - const contextId = await this.createContextId(); - if (!contextId || !AIProvider.context) { - return true; - } - if (isDocChip(chip)) { - return await AIProvider.context.removeContextDoc({ - contextId, - docId: chip.docId, - }); - } - if (isFileChip(chip) && chip.fileId) { - return await AIProvider.context.removeContextFile({ - contextId, - fileId: chip.fileId, - }); - } - if (isTagChip(chip)) { - return await AIProvider.context.removeContextTag({ - contextId, - tagId: chip.tagId, - }); - } - if (isCollectionChip(chip)) { - return await AIProvider.context.removeContextCollection({ - contextId, - collectionId: chip.collectionId, - }); - } - if (isAttachmentChip(chip)) { - return await AIProvider.context.removeContextBlob({ - contextId, - blobId: chip.sourceId, - }); - } - if (isSelectedContextChip(chip)) { - this.updateContext({ - ...this.chatContextValue, - snapshot: null, - combinedElementsMarkdown: null, - }); - } + await this.runtime?.dispatch({ type: 'removeContextItem', item }); + this.syncChipsFromRuntime(); return true; } catch { return true; @@ -669,136 +627,19 @@ export class AIChatComposer extends SignalWatcher( }; private readonly pollContextDocsAndFiles = async () => { - const sessionId = this.session?.sessionId; - const contextId = await this._getContextId(); - if (!sessionId || !contextId || !AIProvider.context) { - return; - } - if (this._pollAbortController) { - // already polling, reset timer - this._abortPoll(); - } - this._pollAbortController = new AbortController(); - await AIProvider.context.pollContextDocsAndFiles( - this.workspaceId, - sessionId, - contextId, - this._onPoll, - this._pollAbortController.signal - ); + if (!this.runtime) return; + await this.runtime.dispatch({ type: 'startContextPolling' }); + this.syncChipsFromRuntime(); }; private readonly pollEmbeddingStatus = async () => { - if (this._pollEmbeddingStatusAbortController) { - this._pollEmbeddingStatusAbortController.abort(); - } - this._pollEmbeddingStatusAbortController = new AbortController(); - const signal = this._pollEmbeddingStatusAbortController.signal; - - try { - await AIProvider.context?.pollEmbeddingStatus( - this.workspaceId, - (status: ContextWorkspaceEmbeddingStatus) => { - if (!status) { - this.embeddingCompleted = false; - return; - } - const prevCompleted = this.embeddingCompleted; - const completed = status.embedded === status.total; - this.embeddingCompleted = completed; - if (prevCompleted !== completed) { - this.requestUpdate(); - } - }, - signal - ); - } catch { - this.embeddingCompleted = false; - } - }; - - private readonly _onPoll = ( - result?: BlockSuitePresets.AIDocsAndFilesContext - ) => { - if (!result) { - this._abortPoll(); - return; - } - const { - docs: sDocs = [], - files = [], - tags = [], - collections = [], - blobs = [], - } = result; - const docs = [ - ...sDocs, - ...tags.flatMap(tag => tag.docs), - ...collections.flatMap(collection => collection.docs), - ]; - const hashMap = new Map< - string, - CopilotContextDoc | CopilotContextFile | CopilotContextBlob - >(); - const count: Record = { - finished: 0, - processing: 0, - failed: 0, - }; - docs.forEach(doc => { - hashMap.set(doc.id, doc); - doc.status && count[doc.status]++; - }); - files.forEach(file => { - hashMap.set(file.id, file); - file.status && count[file.status]++; - }); - blobs.forEach(blob => { - hashMap.set(blob.id, blob); - blob.status && count[blob.status]++; - }); - const nextChips = this.chips.map(chip => { - if (isTagChip(chip) || isCollectionChip(chip)) { - return chip; - } - const id = isDocChip(chip) - ? chip.docId - : isFileChip(chip) - ? chip.fileId - : isAttachmentChip(chip) - ? chip.sourceId - : isSelectedContextChip(chip) - ? chip.uuid - : undefined; - const item = id && hashMap.get(id); - if (item && item.status) { - return { - ...chip, - state: item.status, - tooltip: 'error' in item ? item.error : undefined, - }; - } - return chip; - }); - this.updateChips(nextChips); - this.onEmbeddingProgressChange?.(count); - if (count.processing === 0) { - this._abortPoll(); - } - }; - - private readonly _abortPoll = () => { - this._pollAbortController?.abort(); - this._pollAbortController = null; - }; - - private readonly _abortPollEmbeddingStatus = () => { - this._pollEmbeddingStatusAbortController?.abort(); - this._pollEmbeddingStatusAbortController = null; + if (!this.runtime) return; + await this.runtime.dispatch({ type: 'pollEmbeddingStatus' }); + this.syncChipsFromRuntime(); }; private readonly initComposer = async () => { - const userId = (await AIProvider.userInfo)?.id; + const userId = AIAppEvents.userInfo.value?.id; if (!userId || !this.session) return; await this.initChips(); diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.spec.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.spec.ts index fd0fb52ef5..18a05462b0 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.spec.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.spec.ts @@ -1,17 +1,10 @@ /** * @vitest-environment happy-dom */ -import { afterEach, describe, expect, test, vi } from 'vitest'; +import { describe, expect, test, vi } from 'vitest'; -import { AIProvider } from '../../provider'; import { AIChatContent } from './ai-chat-content'; -const originalHistories = AIProvider.histories; - -afterEach(() => { - AIProvider.provide('histories', originalHistories as any); -}); - describe('AIChatContent pinned scroll tracking', () => { test('records scroll position from the chat messages host', async () => { let scrollEndHandler: (() => void) | undefined; @@ -47,103 +40,44 @@ describe('AIChatContent pinned scroll tracking', () => { }); }); -describe('AIChatContent history loading', () => { - test('replaces messages when the active session changes', async () => { - const histories = { - chats: vi.fn(async (_workspaceId: string, sessionId: string) => [ - { - messages: [ - { - id: `${sessionId}-message`, - role: 'user', - content: sessionId, - createdAt: '2026-01-01T00:00:00.000Z', - }, - ], - }, - ]), - actions: vi.fn(async () => []), - cleanup: vi.fn(), - ids: vi.fn(), +describe('AIChatContent runtime snapshot sync', () => { + test('derives messages and loading state from runtime snapshot', () => { + const runtimeMessage = { + id: 'message-1', + role: 'user', + content: 'hello', + createdAt: '2026-01-01T00:00:00.000Z', }; - AIProvider.provide('histories', histories as any); - - const content: { - updateHistoryCounter: number; - historyKey: string | undefined; - workspaceId: string; - docId: string; - session: { sessionId: string }; - chatContextValue: { messages: unknown[]; status?: string }; - updateContext: (context: { messages: unknown[] }) => void; - } = { - updateHistoryCounter: 0, - historyKey: undefined, - workspaceId: 'ws-1', - docId: 'doc-1', - session: { sessionId: 'session-1' }, - chatContextValue: { messages: [] }, - updateContext(context: { messages: unknown[] }) { - this.chatContextValue = { - ...this.chatContextValue, - ...context, - }; - }, - }; - - await (AIChatContent.prototype as any).updateHistory.call(content); - expect( - content.chatContextValue.messages.map((message: any) => message.id) - ).toEqual(['session-1-message']); - - content.session = { sessionId: 'session-2' }; - await (AIChatContent.prototype as any).updateHistory.call(content); - - expect( - content.chatContextValue.messages.map((message: any) => message.id) - ).toEqual(['session-2-message']); - }); - - test('does not overwrite in-flight optimistic messages when a session is created', async () => { - const histories = { - chats: vi.fn(async () => [{ messages: [] }]), - actions: vi.fn(async () => []), - cleanup: vi.fn(), - ids: vi.fn(), - }; - AIProvider.provide('histories', histories as any); - - const optimisticMessages = [ - { - id: '', - role: 'user', - content: 'hello', - createdAt: '2026-01-01T00:00:00.000Z', - }, - { - id: '', - role: 'assistant', - content: '', - createdAt: '2026-01-01T00:00:01.000Z', - }, - ]; - const updateContext = vi.fn(); - const content = { - updateHistoryCounter: 0, - historyKey: 'ws-1:doc-1:', - workspaceId: 'ws-1', - docId: 'doc-1', - session: { sessionId: 'session-1' }, - chatContextValue: { - messages: optimisticMessages, + const content = Object.create(AIChatContent.prototype) as AIChatContent; + Object.defineProperty(content, 'runtimeSnapshot', { + configurable: true, + value: { + messages: [runtimeMessage], status: 'loading', + error: null, + readiness: 'initializing', + history: { loading: false }, }, - updateContext, - }; + }); + Object.defineProperty(content, 'chatContextValue', { + configurable: true, + value: { + messages: [], + status: 'idle', + error: null, + abortController: null, + }, + }); + Object.defineProperty(content, '_initializeScrollListeners', { + configurable: true, + value: vi.fn(), + }); - await (AIChatContent.prototype as any).updateHistory.call(content); + (content as any).updated(new Map([['runtimeSnapshot', null]])); - expect(updateContext).not.toHaveBeenCalled(); - expect(content.chatContextValue.messages).toBe(optimisticMessages); + expect(content.messages).toEqual([runtimeMessage]); + expect(content.isHistoryLoading).toBe(true); + expect(content.chatContextValue.messages).toEqual([]); + expect(content.chatContextValue.status).toBe('idle'); }); }); diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.ts index b6216d92af..7402424f26 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.ts @@ -12,10 +12,7 @@ import type { WorkspaceDialogService } from '@affine/core/modules/dialogs'; import type { FeatureFlagService } from '@affine/core/modules/feature-flag'; import type { PeekViewService } from '@affine/core/modules/peek-view'; import type { AppThemeService } from '@affine/core/modules/theme'; -import type { - ContextEmbedStatus, - CopilotChatHistoryFragment, -} from '@affine/graphql'; +import type { CopilotChatHistoryFragment } from '@affine/graphql'; import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit'; import { type EditorHost, ShadowlessElement } from '@blocksuite/affine/std'; import type { ExtensionType } from '@blocksuite/affine/store'; @@ -28,7 +25,9 @@ import { createRef, type Ref, ref } from 'lit/directives/ref.js'; import { styleMap } from 'lit/directives/style-map.js'; import { pick } from 'lodash-es'; -import { type AIChatParams, AIProvider } from '../../provider/ai-provider'; +import { AIAppEvents } from '../../provider/ai-app-events'; +import type { AIChatParams } from '../../provider/ai-provider'; +import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat'; import { extractSelectedContent } from '../../utils/extract'; import { HISTORY_IMAGE_ACTIONS } from '../../utils/history-image-actions'; import type { SearchMenuConfig } from '../ai-chat-add-context'; @@ -36,8 +35,6 @@ import type { DocDisplayConfig } from '../ai-chat-chips'; import type { AIReasoningConfig } from '../ai-chat-input'; import { type AIChatMessages, - type ChatAction, - type ChatMessage, type HistoryMessage, isChatMessage, } from '../ai-chat-messages'; @@ -126,9 +123,10 @@ export class AIChatContent extends SignalWatcher( accessor session!: CopilotChatHistoryFragment | null | undefined; @property({ attribute: false }) - accessor createSession!: () => Promise< - CopilotChatHistoryFragment | undefined - >; + accessor runtime: AIChatRuntime | null | undefined; + + @property({ attribute: false }) + accessor runtimeSnapshot: AIChatSnapshot | null | undefined; @property({ attribute: false }) accessor workspaceId!: string; @@ -172,14 +170,6 @@ export class AIChatContent extends SignalWatcher( @property({ attribute: false }) accessor aiModelService!: AIModelService; - @property({ attribute: false }) - accessor onEmbeddingProgressChange: - | ((count: Record) => void) - | undefined; - - @property({ attribute: false }) - accessor onContextChange!: (context: Partial) => void; - @property({ attribute: false }) accessor onOpenDoc!: (docId: string, sessionId?: string) => void; @@ -198,9 +188,6 @@ export class AIChatContent extends SignalWatcher( @state() accessor chatContextValue: ChatContextValue = DEFAULT_CHAT_CONTEXT_VALUE; - @state() - accessor isHistoryLoading = false; - @state() private accessor showPreviewPanel = false; @@ -210,15 +197,13 @@ export class AIChatContent extends SignalWatcher( private readonly chatMessagesRef: Ref = createRef(); - // request counter to track the latest request - private updateHistoryCounter = 0; - - private historyKey: string | undefined; - private lastScrollTop: number | undefined; get messages() { - return this.chatContextValue.messages.filter(item => { + const messages = + (this.runtimeSnapshot?.messages as HistoryMessage[] | undefined) ?? + this.chatContextValue.messages; + return messages.filter(item => { return ( isChatMessage(item) || item.messages?.length === 3 || @@ -232,86 +217,15 @@ export class AIChatContent extends SignalWatcher( return false; } - private async updateHistory() { - const currentRequest = ++this.updateHistoryCounter; - if (!AIProvider.histories) { - return; - } - - const sessionId = this.session?.sessionId; - const nextHistoryKey = `${this.workspaceId}:${this.docId ?? ''}:${ - sessionId ?? '' - }`; - const previousHistoryKey = this.historyKey; - const preserveCurrentMessages = previousHistoryKey === nextHistoryKey; - this.historyKey = nextHistoryKey; - const [histories, actions] = await Promise.all([ - sessionId - ? AIProvider.histories.chats(this.workspaceId, sessionId) - : Promise.resolve([]), - this.docId && this.showActions - ? AIProvider.histories.actions(this.workspaceId, this.docId) - : Promise.resolve([]), - ]); - - // Check if this is still the latest request - if (currentRequest !== this.updateHistoryCounter) { - return; - } - - if ( - !preserveCurrentMessages && - (this.chatContextValue.status === 'loading' || - this.chatContextValue.status === 'transmitting') && - this.chatContextValue.messages.length - ) { - return; - } - - const messages: HistoryMessage[] = preserveCurrentMessages - ? this.chatContextValue.messages.slice().filter(isChatMessage) - : []; - - const chatActions = (actions || []) as ChatAction[]; - messages.push(...chatActions); - - const chatMessages = (histories?.[0]?.messages || []) as ChatMessage[]; - messages.push(...chatMessages); - - this.updateContext({ - messages: messages.sort( - (a, b) => - new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime() - ), - }); - } - - private readonly updateActions = async () => { - if (!this.docId || !AIProvider.histories || !this.showActions) { - return; - } - const actions = await AIProvider.histories.actions( - this.workspaceId, - this.docId + get isHistoryLoading() { + const snapshot = this.runtimeSnapshot; + return ( + snapshot?.readiness === 'initializing' || !!snapshot?.history.loading ); - if (actions && actions.length) { - const chatMessages = this.chatContextValue.messages.filter(message => - isChatMessage(message) - ); - const chatActions = actions as ChatAction[]; - const messages: HistoryMessage[] = [...chatMessages, ...chatActions]; - this.updateContext({ - messages: messages.sort( - (a, b) => - new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime() - ), - }); - } - }; + } private readonly updateContext = (context: Partial) => { this.chatContextValue = { ...this.chatContextValue, ...context }; - this.onContextChange?.(context); this.updateDraft(context).catch(console.error); }; @@ -331,9 +245,7 @@ export class AIChatContent extends SignalWatcher( }; private readonly initChatContent = async () => { - this.isHistoryLoading = true; - await this.updateHistory(); - this.isHistoryLoading = false; + this._initializeScrollListeners(); }; protected override firstUpdated(): void {} @@ -382,13 +294,13 @@ export class AIChatContent extends SignalWatcher( public openPreviewPanel(content?: TemplateResult<1>) { this.showPreviewPanel = true; if (content) this.previewPanelContent = content; - AIProvider.slots.previewPanelOpenChange.next(true); + AIAppEvents.previewPanelOpenChange.next(true); } public closePreviewPanel(destroyContent: boolean = false) { this.showPreviewPanel = false; if (destroyContent) this.previewPanelContent = null; - AIProvider.slots.previewPanelOpenChange.next(false); + AIAppEvents.previewPanelOpenChange.next(false); } public get isPreviewPanelOpen() { @@ -416,19 +328,7 @@ export class AIChatContent extends SignalWatcher( this.subscriptionService.subscription.revalidate(); this._disposables.add( - AIProvider.slots.actions.subscribe(({ event }) => { - const { status } = this.chatContextValue; - if ( - event === 'finished' && - (status === 'idle' || status === 'success') - ) { - this.updateActions().catch(console.error); - } - }) - ); - - this._disposables.add( - AIProvider.slots.requestOpenWithChat.subscribe( + AIAppEvents.requestOpenWithChat.subscribe( (params: AIChatParams | null) => { if (!params) { return; @@ -445,7 +345,7 @@ export class AIChatContent extends SignalWatcher( .catch(console.error); } } - AIProvider.slots.requestOpenWithChat.next(null); + AIAppEvents.requestOpenWithChat.next(null); } ) ); @@ -463,7 +363,8 @@ export class AIChatContent extends SignalWatcher( .workspaceId=${this.workspaceId} .docId=${this.docId} .session=${this.session} - .createSession=${this.createSession} + .runtime=${this.runtime} + .runtimeSnapshot=${this.runtimeSnapshot} .chatContextValue=${this.chatContextValue} .updateContext=${this.updateContext} .isHistoryLoading=${this.isHistoryLoading} @@ -491,10 +392,10 @@ export class AIChatContent extends SignalWatcher( .workspaceId=${this.workspaceId} .docId=${this.docId} .session=${this.session} - .createSession=${this.createSession} + .runtime=${this.runtime} + .runtimeSnapshot=${this.runtimeSnapshot} .chatContextValue=${this.chatContextValue} .updateContext=${this.updateContext} - .onEmbeddingProgressChange=${this.onEmbeddingProgressChange} .reasoningConfig=${this.reasoningConfig} .docDisplayConfig=${this.docDisplayConfig} .searchMenuConfig=${this.searchMenuConfig} diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts index a27aaeee3c..2d2dcab6f8 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts @@ -24,19 +24,14 @@ import { repeat } from 'lit/directives/repeat.js'; import { styleMap } from 'lit/directives/style-map.js'; import { ChatAbortIcon } from '../../_common/icons'; -import { type AIError, AIProvider, type AISendParams } from '../../provider'; +import { AIAppEvents, type AISendParams } from '../../provider'; +import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat'; import { reportResponse } from '../../utils/action-reporter'; import { readBlobAsURL } from '../../utils/image'; -import { mergeStreamObjects } from '../../utils/stream-objects'; import type { SearchMenuConfig } from '../ai-chat-add-context'; import { addFilesToChat } from '../ai-chat-chips/attachment-utils'; import type { ChatChip, DocDisplayConfig } from '../ai-chat-chips/type'; import { isDocChip } from '../ai-chat-chips/utils'; -import { - type ChatMessage, - isChatMessage, - StreamObjectSchema, -} from '../ai-chat-messages'; import type { AIChatInputContext, AIReasoningConfig } from './type'; function getFirstTwoLines(text: string) { @@ -340,6 +335,12 @@ export class AIChatInput extends SignalWatcher( @property({ attribute: false }) accessor session!: CopilotChatHistoryFragment | null | undefined; + @property({ attribute: false }) + accessor runtime: AIChatRuntime | null | undefined; + + @property({ attribute: false }) + accessor runtimeSnapshot: AIChatSnapshot | null | undefined; + @property({ attribute: false }) accessor isContextProcessing!: boolean | undefined; @@ -371,11 +372,6 @@ export class AIChatInput extends SignalWatcher( @property({ attribute: false }) accessor chips: ChatChip[] = []; - @property({ attribute: false }) - accessor createSession!: () => Promise< - CopilotChatHistoryFragment | undefined - >; - @property({ attribute: false }) accessor updateContext!: (context: Partial) => void; @@ -441,7 +437,7 @@ export class AIChatInput extends SignalWatcher( super.connectedCallback(); this._disposables.add( - AIProvider.slots.requestSendWithChat.subscribe( + AIAppEvents.requestSendWithChat.subscribe( (params: AISendParams | null) => { if (!params) { return; @@ -455,13 +451,13 @@ export class AIChatInput extends SignalWatcher( this.send(input).catch(console.error); }, 0); } - AIProvider.slots.requestSendWithChat.next(null); + AIAppEvents.requestSendWithChat.next(null); } ) ); this._disposables.add( - AIProvider.slots.requestOpenWithChat.subscribe(params => { + AIAppEvents.requestOpenWithChat.subscribe(params => { if (!params) return; const { input, host } = params; @@ -552,7 +548,8 @@ export class AIChatInput extends SignalWatcher( } protected override render() { - const { images, status } = this.chatContextValue; + const { images } = this.chatContextValue; + const status = this.runtimeSnapshot?.status ?? this.chatContextValue.status; const hasImages = images.length > 0; const maxHeight = hasImages ? 272 + 2 : 200 + 2; @@ -664,6 +661,10 @@ export class AIChatInput extends SignalWatcher( return true; } + if (this.runtimeSnapshot && !this.runtimeSnapshot.uiPolicy.canSend) { + return true; + } + if (this.isContextProcessing) { return true; } @@ -782,9 +783,11 @@ export class AIChatInput extends SignalWatcher( }; private readonly _handleAbort = () => { - this.chatContextValue.abortController?.abort(); - this.updateContext({ status: 'success' }); - reportResponse('aborted:stop'); + if (this.runtime) { + this.runtime.dispatch({ type: 'stop' }).catch(console.error); + reportResponse('aborted:stop', this.host); + return; + } }; private readonly _toggleReasoning = (extendedThinking: boolean) => { @@ -817,153 +820,52 @@ export class AIChatInput extends SignalWatcher( }; send = async (text: string) => { - try { - const { - status, - markdown, - images, - snapshot, - combinedElementsMarkdown, - html, - } = this.chatContextValue; + if (!this.runtime) return; + const { markdown, images, snapshot, combinedElementsMarkdown, html } = + this.chatContextValue; + const userInput = (markdown ? `${markdown}\n` : '') + text; + const imageAttachments = await Promise.all( + images?.map(image => readBlobAsURL(image)) + ); + const contexts = await this._getMatchedContexts(); + const enableSendDetailedObject = + this.affineFeatureFlagService.flags.enable_send_detailed_object_to_ai + .value; + const userInfo = AIAppEvents.userInfo.value; - if (status === 'loading' || status === 'transmitting') return; - if (!text) return; - if (!AIProvider.actions.chat) return; - - const abortController = new AbortController(); - this.updateContext({ - images: [], - status: 'loading', - error: null, - quote: '', - markdown: '', - abortController, - }); - - const imageAttachments = await Promise.all( - images?.map(image => readBlobAsURL(image)) - ); - const userInput = (markdown ? `${markdown}\n` : '') + text; - - // optimistic update messages - await this._preUpdateMessages(userInput, imageAttachments); - - const sessionId = (await this.createSession())?.sessionId; - let contexts = await this._getMatchedContexts(); - if (abortController.signal.aborted) { - return; - } - - const enableSendDetailedObject = - this.affineFeatureFlagService.flags.enable_send_detailed_object_to_ai - .value; - - const modelId = this.aiModelService.modelId.value; - const stream = await AIProvider.actions.chat({ - sessionId, - input: userInput, - contexts: { - ...contexts, - selectedSnapshot: - snapshot && enableSendDetailedObject ? snapshot : undefined, - selectedMarkdown: - combinedElementsMarkdown && enableSendDetailedObject - ? combinedElementsMarkdown - : undefined, - html: html || undefined, - }, - docId: this.docId, - attachments: images, - workspaceId: this.workspaceId, - stream: true, - signal: abortController.signal, - isRootSession: this.isRootSession, - where: this.trackOptions?.where, - control: this.trackOptions?.control, - reasoning: this._isReasoningActive, - toolsConfig: this.aiToolsConfigService.config.value, - modelId, - }); - - for await (const text of stream) { - const messages = this.chatContextValue.messages.slice(0); - const last = messages.at(-1); - if (last && isChatMessage(last)) { - try { - const parsed = StreamObjectSchema.parse(JSON.parse(text)); - const streamObjects = mergeStreamObjects([ - ...(last.streamObjects ?? []), - parsed, - ]); - messages[messages.length - 1] = { - ...last, - streamObjects, - }; - } catch { - messages[messages.length - 1] = { - ...last, - content: last.content + text, - }; - } - this.updateContext({ messages, status: 'transmitting' }); - } - } - - this.updateContext({ status: 'success' }); - this.onChatSuccess?.(); - // update message id from server - await this._postUpdateMessages(); - } catch (error) { - this.updateContext({ status: 'error', error: error as AIError }); - } finally { - this.updateContext({ abortController: null }); - } - }; - - private readonly _preUpdateMessages = async ( - userInput: string, - attachments: string[] - ) => { - const userInfo = await AIProvider.userInfo; this.updateContext({ - messages: [ - ...this.chatContextValue.messages, - { - id: '', - role: 'user', - content: userInput, - createdAt: new Date().toISOString(), - attachments, - userId: userInfo?.id, - userName: userInfo?.name, - avatarUrl: userInfo?.avatarUrl ?? undefined, - }, - { - id: '', - role: 'assistant', - content: '', - createdAt: new Date().toISOString(), - }, - ], + images: [], + quote: '', + markdown: '', }); - }; - - private readonly _postUpdateMessages = async () => { - const sessionId = this.session?.sessionId; - if (!sessionId || !AIProvider.histories) return; - - const { messages } = this.chatContextValue; - const last = messages[messages.length - 1] as ChatMessage; - if (!last.id) { - const historyIds = await AIProvider.histories.ids( - this.workspaceId, - this.docId, - { sessionId, withMessages: true } - ); - if (!historyIds || !historyIds[0]) return; - last.id = historyIds[0].messages.at(-1)?.id ?? ''; - } + await this.runtime.dispatch({ + type: 'send', + input: userInput, + contexts: { + ...contexts, + selectedSnapshot: + snapshot && enableSendDetailedObject ? snapshot : undefined, + selectedMarkdown: + combinedElementsMarkdown && enableSendDetailedObject + ? combinedElementsMarkdown + : undefined, + html: html || undefined, + }, + attachments: images, + attachmentPreviews: imageAttachments, + isRootSession: this.isRootSession, + where: this.trackOptions?.where, + control: this.trackOptions?.control, + reasoning: this._isReasoningActive, + toolsConfig: this.aiToolsConfigService.config.value, + modelId: this.aiModelService.modelId.value, + userInfo: { + userId: userInfo?.id, + userName: userInfo?.name, + avatarUrl: userInfo?.avatarUrl ?? undefined, + }, + }); + this.onChatSuccess?.(); }; private async _getMatchedContexts() { diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/ai-chat-messages.spec.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/ai-chat-messages.spec.ts index c3f7e819cd..36b4bac4a3 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/ai-chat-messages.spec.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/ai-chat-messages.spec.ts @@ -93,4 +93,35 @@ describe('AIChatMessages scrolling', () => { expect(element.canScrollDown).toBe(false); expect(scrollToEnd).toHaveBeenCalled(); }); + + test('message keys are scoped by active tab', () => { + const element = {} as AIChatMessages; + const message = { + id: 'message-1', + role: 'assistant', + content: 'reply', + createdAt: new Date().toISOString(), + }; + + element.runtimeSnapshot = { + activeTabId: 'session-1', + } as AIChatMessages['runtimeSnapshot']; + const firstKey = (AIChatMessages.prototype as any)._getMessageKey.call( + element, + message, + 0 + ); + + element.runtimeSnapshot = { + activeTabId: 'session-2', + } as AIChatMessages['runtimeSnapshot']; + const secondKey = (AIChatMessages.prototype as any)._getMessageKey.call( + element, + message, + 0 + ); + + expect(firstKey).toBe('session-1:message-1'); + expect(secondKey).toBe('session-2:message-1'); + }); }); diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/ai-chat-messages.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/ai-chat-messages.ts index 17cff0c3b8..d6bab64503 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/ai-chat-messages.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/ai-chat-messages.ts @@ -19,8 +19,13 @@ import { repeat } from 'lit/directives/repeat.js'; import { debounce, throttle } from 'lodash-es'; import { AffineIcon } from '../../_common/icons'; -import { type AIError, AIProvider, UnauthorizedError } from '../../provider'; -import { mergeStreamObjects } from '../../utils/stream-objects'; +import { + AIAppEvents, + type AIError, + GeneralNetworkError, + UnauthorizedError, +} from '../../provider'; +import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat'; import type { DocDisplayConfig } from '../ai-chat-chips'; import { type ChatContextValue } from '../ai-chat-content/type'; import type { AIReasoningConfig } from '../ai-chat-input'; @@ -30,12 +35,7 @@ import { AI_CHAT_SCROLL_DOWN_INDICATOR_THRESHOLD, } from './auto-scroll'; import { AIPreloadConfig } from './preload-config'; -import { - type HistoryMessage, - isChatAction, - isChatMessage, - StreamObjectSchema, -} from './type'; +import { type HistoryMessage, isChatAction, isChatMessage } from './type'; export class AIChatMessages extends WithDisposable(ShadowlessElement) { static override styles = css` @@ -179,9 +179,10 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) { accessor session!: CopilotChatHistoryFragment | null | undefined; @property({ attribute: false }) - accessor createSession!: () => Promise< - CopilotChatHistoryFragment | undefined - >; + accessor runtime: AIChatRuntime | null | undefined; + + @property({ attribute: false }) + accessor runtimeSnapshot: AIChatSnapshot | null | undefined; @property({ attribute: false }) accessor updateContext!: (context: Partial) => void; @@ -227,8 +228,21 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) { private _lastObservedScrollTop = 0; - private get _isReasoningActive() { - return !!this.reasoningConfig.enabled.value; + private get chatStatus() { + return this.runtimeSnapshot?.status ?? this.chatContextValue.status; + } + + private get chatError(): AIError | null { + return this.toAIError( + this.runtimeSnapshot?.error ?? this.chatContextValue.error + ); + } + + private toAIError(error: Error | null): AIError | null { + if (!error) return null; + return 'type' in error + ? (error as AIError) + : new GeneralNetworkError(error.message); } private _renderAIOnboarding() { @@ -296,8 +310,20 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) { return scrollHeight - scrollTop - clientHeight; } + private _getMessageKey(item: HistoryMessage, index: number) { + const tabKey = + this.runtimeSnapshot?.activeTabId ?? + this.runtimeSnapshot?.activeSessionId ?? + 'legacy'; + if (isChatMessage(item)) { + return `${tabKey}:${item.id || index}`; + } + return `${tabKey}:${index}`; + } + protected override render() { - const { status, error } = this.chatContextValue; + const status = this.chatStatus; + const error = this.chatError; const { isHistoryLoading } = this; const filteredItems = this.messages; @@ -337,7 +363,7 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) { ` : repeat( filteredItems, - (_, index) => index, + (item, index) => this._getMessageKey(item, index), (item, index) => { const isLast = index === filteredItems.length - 1; if (isChatMessage(item) && item.role === 'user') { @@ -389,21 +415,19 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) { super.connectedCallback(); const { disposables } = this; - Promise.resolve(AIProvider.userInfo) - .then(res => { - this.avatarUrl = res?.avatarUrl ?? ''; - }) - .catch(console.error); + this.avatarUrl = AIAppEvents.userInfo.value?.avatarUrl ?? ''; disposables.add( - AIProvider.slots.userInfo.subscribe(userInfo => { - const { status, error } = this.chatContextValue; + AIAppEvents.userInfo.subscribe(userInfo => { + const status = this.chatStatus; + const error = this.chatError; this.avatarUrl = userInfo?.avatarUrl ?? ''; if ( status === 'error' && error instanceof UnauthorizedError && userInfo ) { + this.runtime?.dispatch({ type: 'clearError' }).catch(console.error); this.updateContext({ status: 'idle', error: null }); } }) @@ -452,7 +476,7 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) { if (changedProperties.has('messages')) { this._onScroll(); - if (this.chatContextValue.status === 'transmitting') { + if (this.chatStatus === 'transmitting') { this._throttledScrollToEnd(); } else if (this._autoScrollEnabled) { this.scrollToEnd(); @@ -460,9 +484,9 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) { } if ( - changedProperties.has('chatContextValue') && - (this.chatContextValue.status === 'success' || - this.chatContextValue.status === 'error') + (changedProperties.has('chatContextValue') || + changedProperties.has('runtimeSnapshot')) && + (this.chatStatus === 'success' || this.chatStatus === 'error') ) { this._onScroll(); } @@ -484,69 +508,11 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) { } retry = async () => { - try { - const sessionId = (await this.createSession())?.sessionId; - if (!sessionId) return; - if (!AIProvider.actions.chat) return; - - const abortController = new AbortController(); - const messages = [...this.chatContextValue.messages]; - const last = messages[messages.length - 1]; - if ('content' in last) { - last.content = ''; - last.streamObjects = []; - last.createdAt = new Date().toISOString(); - } - this.updateContext({ - messages, - status: 'loading', - error: null, - abortController, - }); - - const stream = await AIProvider.actions.chat({ - sessionId, - retry: true, - docId: this.docId, - workspaceId: this.workspaceId, - stream: true, - signal: abortController.signal, - where: 'chat-panel', - control: 'chat-send', - isRootSession: true, - reasoning: this._isReasoningActive, - toolsConfig: this.aiToolsConfigService.config.value, - }); - - for await (const text of stream) { - const messages = this.chatContextValue.messages.slice(0); - const last = messages.at(-1); - if (last && isChatMessage(last)) { - try { - const parsed = StreamObjectSchema.parse(JSON.parse(text)); - const streamObjects = mergeStreamObjects([ - ...(last.streamObjects ?? []), - parsed, - ]); - messages[messages.length - 1] = { - ...last, - streamObjects, - }; - } catch { - messages[messages.length - 1] = { - ...last, - content: last.content + text, - }; - } - this.updateContext({ messages, status: 'transmitting' }); - } - } - - this.updateContext({ status: 'success' }); - } catch (error) { - this.updateContext({ status: 'error', error: error as AIError }); - } finally { - this.updateContext({ abortController: null }); - } + if (!this.runtime) return; + const last = this.messages.at(-1); + await this.runtime.dispatch({ + type: 'retry', + messageId: last && isChatMessage(last) ? last.id : '', + }); }; } diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/preload-config.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/preload-config.ts index fad63136f3..93ae3b4957 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/preload-config.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-messages/preload-config.ts @@ -6,7 +6,7 @@ import { SendIcon, } from '@blocksuite/icons/lit'; -import { AIProvider } from '../../provider/ai-provider.js'; +import { AIAppEvents } from '../../provider/ai-app-events.js'; import completeWritingWithAI from './templates/completeWritingWithAI.zip'; import freelyCommunicateWithAI from './templates/freelyCommunicateWithAI.zip'; import readAforeign from './templates/readAforeign.zip'; @@ -19,7 +19,7 @@ export const AIPreloadConfig = [ text: 'Read a foreign language article with AI', testId: 'read-foreign-language-article-with-ai', handler: () => { - AIProvider.slots.requestInsertTemplate.next({ + AIAppEvents.requestInsertTemplate.next({ template: readAforeign, mode: 'edgeless', }); @@ -30,7 +30,7 @@ export const AIPreloadConfig = [ text: 'Tidy an article with AI MindMap Action', testId: 'tidy-an-article-with-ai-mindmap-action', handler: () => { - AIProvider.slots.requestInsertTemplate.next({ + AIAppEvents.requestInsertTemplate.next({ template: TidyMindMapV3, mode: 'edgeless', }); @@ -41,7 +41,7 @@ export const AIPreloadConfig = [ text: 'Add illustrations to the article', testId: 'add-illustrations-to-the-article', handler: () => { - AIProvider.slots.requestInsertTemplate.next({ + AIAppEvents.requestInsertTemplate.next({ template: redHat, mode: 'edgeless', }); @@ -52,7 +52,7 @@ export const AIPreloadConfig = [ text: 'Complete writing with AI', testId: 'complete-writing-with-ai', handler: () => { - AIProvider.slots.requestInsertTemplate.next({ + AIAppEvents.requestInsertTemplate.next({ template: completeWritingWithAI, mode: 'edgeless', }); @@ -63,7 +63,7 @@ export const AIPreloadConfig = [ text: 'Freely communicate with AI', testId: 'freely-communicate-with-ai', handler: () => { - AIProvider.slots.requestInsertTemplate.next({ + AIAppEvents.requestInsertTemplate.next({ template: freelyCommunicateWithAI, mode: 'edgeless', }); diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/ai-chat-tabs.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/ai-chat-tabs.ts index 277de4481a..5b448bc2bb 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/ai-chat-tabs.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/ai-chat-tabs.ts @@ -7,9 +7,15 @@ import { css, html, type PropertyValues } from 'lit'; import { property } from 'lit/decorators.js'; import { repeat } from 'lit/directives/repeat.js'; +import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat'; + const DEFAULT_TAB_TITLE = 'New chat'; const TITLE_MAX_LENGTH = 28; +type RenderTabItem = + | { kind: 'draft' } + | { kind: 'session'; session: CopilotChatHistoryFragment }; + function truncate(text: string): string { if (text.length <= TITLE_MAX_LENGTH) return text; return `${text.slice(0, TITLE_MAX_LENGTH).trimEnd()}…`; @@ -27,19 +33,10 @@ function deriveTabTitle(session: CopilotChatHistoryFragment): string { export class AIChatTabs extends WithDisposable(ShadowlessElement) { @property({ attribute: false }) - accessor sessions: CopilotChatHistoryFragment[] = []; + accessor runtimeSnapshot: AIChatSnapshot | null | undefined; @property({ attribute: false }) - accessor activeSessionId: string | undefined; - - @property({ attribute: false }) - accessor showDraftTab = false; - - @property({ attribute: false }) - accessor onSelectTab!: (sessionId: string) => void; - - @property({ attribute: false }) - accessor onCloseTab!: (sessionId: string) => void; + accessor runtime!: AIChatRuntime; static override styles = css` ai-chat-tabs { @@ -100,6 +97,10 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) { color: ${unsafeCSSVarV2('text/primary')}; } + .tab[data-kind='draft'] { + padding: 0 10px; + } + .tab-title { white-space: nowrap; overflow: hidden; @@ -132,21 +133,46 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) { `; override render() { - if (!this.sessions.length && !this.showDraftTab) return html``; + const items = this._getRenderItems(); + if (!items.length) return html``; return html`
- ${this.showDraftTab ? this._renderDraftTab() : null} ${repeat( - this.sessions, - session => session.sessionId, - session => this._renderTab(session) + items, + item => (item.kind === 'draft' ? 'draft' : item.session.sessionId), + item => + item.kind === 'draft' + ? this._renderDraftTab() + : this._renderTab(item.session) )}
`; } + private _getRenderItems(): RenderTabItem[] { + const snapshot = this.runtimeSnapshot; + if (!snapshot) return []; + const sessions = new Map( + snapshot.sessions.map(session => [session.sessionId, session]) + ); + const items: RenderTabItem[] = []; + for (const tab of snapshot.tabs) { + if (tab.kind === 'draft') { + if (snapshot.uiPolicy.showDraftTab) { + items.push({ kind: 'draft' }); + } + continue; + } + const session = sessions.get(tab.sessionId); + if (session) { + items.push({ kind: 'session', session }); + } + } + return items; + } + private readonly _handleWheel = (e: WheelEvent) => { const el = e.currentTarget as HTMLElement; if (el.scrollWidth <= el.clientWidth) return; @@ -157,7 +183,7 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) { }; private _renderTab(session: CopilotChatHistoryFragment) { - const active = session.sessionId === this.activeSessionId; + const active = session.sessionId === this.runtimeSnapshot?.activeSessionId; const title = deriveTabTitle(session); return html`
@@ -195,23 +223,27 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) { } private readonly _handleSelect = (sessionId: string) => { - if (sessionId === this.activeSessionId) return; - this.onSelectTab(sessionId); + if (sessionId === this.runtimeSnapshot?.activeSessionId) return; + this.runtime + .dispatch({ type: 'openSession', sessionId }) + .catch(console.error); }; private readonly _handleClose = (e: Event, sessionId: string) => { e.stopPropagation(); - this.onCloseTab(sessionId); + this.runtime + .dispatch({ type: 'closeTab', tabId: sessionId }) + .catch(console.error); }; override updated(changedProps: PropertyValues) { super.updated(changedProps); - if ( - (changedProps.has('activeSessionId') || changedProps.has('sessions')) && - this.activeSessionId - ) { + if (changedProps.has('runtimeSnapshot')) { + const activeSessionId = this.runtimeSnapshot?.activeSessionId; const activeTab = this.renderRoot.querySelector( - `[data-session-id="${this.activeSessionId}"]` + activeSessionId + ? `[data-session-id="${activeSessionId}"]` + : `[data-testid="ai-chat-draft-tab"]` ); activeTab?.scrollIntoView({ behavior: 'smooth', diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/ai-chat-toolbar.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/ai-chat-toolbar.ts index e4757abbcf..f453236335 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/ai-chat-toolbar.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/ai-chat-toolbar.ts @@ -14,34 +14,22 @@ import { flip, offset } from '@floating-ui/dom'; import { css, html } from 'lit'; import { property, query } from 'lit/decorators.js'; +import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat'; import type { DocDisplayConfig } from '../ai-chat-chips'; -import type { ChatStatus } from '../ai-chat-messages'; export class AIChatToolbar extends WithDisposable(ShadowlessElement) { @property({ attribute: false }) accessor session!: CopilotChatHistoryFragment | null | undefined; @property({ attribute: false }) - accessor workspaceId!: string; + accessor runtime!: AIChatRuntime; + + @property({ attribute: false }) + accessor runtimeSnapshot!: AIChatSnapshot; @property({ attribute: false }) accessor docId: string | undefined; - @property({ attribute: false }) - accessor status!: ChatStatus; - - @property({ attribute: false }) - accessor onNewSession!: () => void; - - @property({ attribute: false }) - accessor canCreateNewSession = true; - - @property({ attribute: false }) - accessor onTogglePin!: () => Promise; - - @property({ attribute: false }) - accessor onOpenSession!: (sessionId: string) => void; - @property({ attribute: false }) accessor onOpenDoc!: (docId: string, sessionId: string) => void; @@ -62,7 +50,12 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) { private abortController: AbortController | null = null; get isGenerating() { - return this.status === 'transmitting' || this.status === 'loading'; + const status = this.runtimeSnapshot.status; + return status === 'transmitting' || status === 'loading'; + } + + get canCreateNewSession() { + return this.runtimeSnapshot.uiPolicy.canCreateNewSession; } static override styles = css` @@ -141,7 +134,7 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) { ); return; } - await this.onTogglePin(); + await this.runtime.dispatch({ type: 'togglePinActiveSession' }); }; private readonly unpinConfirm = async () => { @@ -157,7 +150,7 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) { if (!confirm) { return false; } - await this.onTogglePin(); + await this.runtime.dispatch({ type: 'togglePinActiveSession' }); } catch { this.notificationService.toast('Failed to unpin the chat'); } @@ -168,7 +161,7 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) { private readonly onPlusClick = async () => { const confirm = await this.unpinConfirm(); if (confirm) { - this.onNewSession(); + await this.runtime.dispatch({ type: 'createNewSession' }); } }; @@ -179,7 +172,11 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) { } const confirm = await this.unpinConfirm(); if (confirm) { - this.onOpenSession(sessionId); + await this.runtime.dispatch({ + type: 'openSession', + sessionId, + }); + this.closeHistoryMenu(); } }; @@ -191,7 +188,7 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) { this.onOpenDoc(docId, sessionId); }; - private readonly toggleHistoryMenu = () => { + private readonly toggleHistoryMenu = async () => { if (this.abortController) { this.abortController.abort(); return; @@ -202,12 +199,23 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) { this.abortController = null; }); + try { + await this.runtime.dispatch({ type: 'refreshHistory' }); + } catch (error) { + console.error(error); + } + if (this.abortController.signal.aborted) { + return; + } + createLitPortal({ template: html` >; + +const DEFAULT_SESSION_TITLE = 'New chat'; +const TITLE_MAX_LENGTH = 28; + +function truncateSessionTitle(text: string) { + if (text.length <= TITLE_MAX_LENGTH) return text; + return `${text.slice(0, TITLE_MAX_LENGTH).trimEnd()}…`; +} + +function deriveSessionTitle(session: HistorySessionWithMessages) { + const explicit = session.title?.trim(); + if (explicit) return truncateSessionTitle(explicit); + const firstUserMessage = session.messages?.find( + message => message.role === 'user' + ); + const raw = firstUserMessage?.content?.trim(); + if (!raw) return DEFAULT_SESSION_TITLE; + const newlineIdx = raw.indexOf('\n'); + return truncateSessionTitle( + newlineIdx === -1 ? raw : raw.slice(0, newlineIdx) + ); +} + export class AISessionHistory extends WithDisposable(ShadowlessElement) { static override styles = css` .ai-session-history { @@ -157,10 +181,16 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) { accessor session!: CopilotChatHistoryFragment | null | undefined; @property({ attribute: false }) - accessor workspaceId!: string; + accessor docId: string | undefined; @property({ attribute: false }) - accessor docId: string | undefined; + accessor recentSessions: BlockSuitePresets.AIRecentSession[] = []; + + @property({ attribute: false }) + accessor currentDocSessions: BlockSuitePresets.AIRecentSession[] = []; + + @property({ attribute: false }) + accessor loading = false; @property({ attribute: false }) accessor docDisplayConfig!: DocDisplayConfig; @@ -176,30 +206,9 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) { @property({ attribute: false }) accessor onDocClick!: (docId: string, sessionId: string) => void; - @query('.ai-session-history') - accessor scrollContainer!: HTMLElement; - - @state() - private accessor sessions: BlockSuitePresets.AIRecentSession[] | undefined; - - @state() - private accessor currentDocSessions: - | BlockSuitePresets.AIRecentSession[] - | undefined; - - @state() - private accessor loadingMore = false; - - @state() - private accessor hasMore = true; - @state() private accessor selectedSessionId: string | undefined; - private accessor currentOffset = 0; - - private readonly pageSize = 10; - private groupSessionsByTime( sessions: BlockSuitePresets.AIRecentSession[] ): GroupedSessions { @@ -248,51 +257,9 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) { return grouped; } - private async getRecentSessions() { - this.loadingMore = true; - - const moreSessions = - (await AIProvider.session?.getRecentSessions( - this.workspaceId, - this.pageSize, - this.currentOffset - )) || []; - this.sessions = [...(this.sessions || []), ...moreSessions]; - - this.currentOffset += moreSessions.length; - this.hasMore = moreSessions.length === this.pageSize; - this.loadingMore = false; - } - - private async getCurrentDocSessions() { - if (!this.docId) { - this.currentDocSessions = []; - return; - } - this.currentDocSessions = - (await AIProvider.session?.getSessions(this.workspaceId, this.docId, { - action: false, - fork: false, - })) || []; - } - - private readonly onScroll = () => { - if (!this.hasMore || this.loadingMore) { - return; - } - // load more when within 50px of bottom - const { scrollTop, scrollHeight, clientHeight } = this.scrollContainer; - const threshold = 50; - if (scrollTop + clientHeight >= scrollHeight - threshold) { - this.getRecentSessions().catch(console.error); - } - }; - override connectedCallback() { super.connectedCallback(); this.selectedSessionId = this.session?.sessionId ?? undefined; - this.getCurrentDocSessions().catch(console.error); - this.getRecentSessions().catch(console.error); } protected override willUpdate(changedProperties: PropertyValues) { @@ -301,14 +268,6 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) { } } - override firstUpdated(changedProperties: PropertyValues) { - super.firstUpdated(changedProperties); - this.disposables.add(() => { - this.scrollContainer.removeEventListener('scroll', this.onScroll); - }); - this.scrollContainer.addEventListener('scroll', this.onScroll); - } - private renderSessionGroup( title: string, sessions: BlockSuitePresets.AIRecentSession[] @@ -320,6 +279,7 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
${title}
${sessions.map(session => { + const sessionTitle = deriveSessionTitle(session); return html`
- ${session.title || 'New chat'} + ${sessionTitle} Click to open this chat @@ -395,15 +355,15 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) { } private renderHistory() { - if (!this.sessions) { + if (this.loading) { return this.renderLoading(); } - const currentDocSessions = this.currentDocSessions ?? []; + const currentDocSessions = this.currentDocSessions; const currentDocSessionIds = new Set( currentDocSessions.map(session => session.sessionId) ); - const otherSessions = this.sessions.filter( + const otherSessions = this.recentSessions.filter( session => !currentDocSessionIds.has(session.sessionId) && (!this.docId || session.docId !== this.docId) diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/configure-ai-chat-toolbar.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/configure-ai-chat-toolbar.ts index ba89b6fd4c..b2c9e59873 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/configure-ai-chat-toolbar.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-toolbar/configure-ai-chat-toolbar.ts @@ -1,21 +1,17 @@ import type { CopilotChatHistoryFragment } from '@affine/graphql'; import type { NotificationService } from '@blocksuite/affine/shared/services'; +import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat'; import type { DocDisplayConfig } from '../ai-chat-chips'; -import type { ChatStatus } from '../ai-chat-messages'; import { AIChatToolbar } from './ai-chat-toolbar'; export type ConfigureAIChatToolbarOptions = { session: CopilotChatHistoryFragment | null | undefined; - workspaceId: string; + runtime: AIChatRuntime; + runtimeSnapshot: AIChatSnapshot; docId?: string; - status: ChatStatus; docDisplayConfig: DocDisplayConfig; notificationService: NotificationService; - onNewSession: () => void; - canCreateNewSession?: boolean; - onTogglePin: () => Promise; - onOpenSession: (sessionId: string) => void; onOpenDoc: (docId: string, sessionId: string) => void; onSessionDelete: (session: BlockSuitePresets.AIRecentSession) => void; }; @@ -31,15 +27,11 @@ export function configureAIChatToolbar( options: ConfigureAIChatToolbarOptions ): AIChatToolbar { tool.session = options.session; - tool.workspaceId = options.workspaceId; + tool.runtime = options.runtime; + tool.runtimeSnapshot = options.runtimeSnapshot; tool.docId = options.docId; - tool.status = options.status; tool.docDisplayConfig = options.docDisplayConfig; tool.notificationService = options.notificationService; - tool.onNewSession = options.onNewSession; - tool.canCreateNewSession = options.canCreateNewSession ?? true; - tool.onTogglePin = options.onTogglePin; - tool.onOpenSession = options.onOpenSession; tool.onOpenDoc = options.onOpenDoc; tool.onSessionDelete = options.onSessionDelete; return tool; diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-history-clear/ai-history-clear.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-history-clear/ai-history-clear.ts index c17cc0baaf..2560f277fb 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-history-clear/ai-history-clear.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-history-clear/ai-history-clear.ts @@ -7,7 +7,6 @@ import type { Store } from '@blocksuite/affine/store'; import { css, html } from 'lit'; import { property } from 'lit/decorators.js'; -import { AIProvider } from '../../provider'; import type { ChatContextValue } from '../ai-chat-content'; export class AIHistoryClear extends WithDisposable(ShadowlessElement) { @@ -26,6 +25,9 @@ export class AIHistoryClear extends WithDisposable(ShadowlessElement) { @property({ attribute: false }) accessor onHistoryCleared!: () => void; + @property({ attribute: false }) + accessor onClearHistory!: (sessionIds: string[]) => Promise | void; + static override styles = css` .chat-history-clear { cursor: pointer; @@ -64,11 +66,10 @@ export class AIHistoryClear extends WithDisposable(ShadowlessElement) { const actionIds = this.chatContextValue.messages .filter(item => 'sessionId' in item) .map(item => item.sessionId); - await AIProvider.histories?.cleanup( - this.doc.workspace.id, - this.doc.id, - [...(sessionId ? [sessionId] : []), ...(actionIds || [])] - ); + await this.onClearHistory([ + ...(sessionId ? [sessionId] : []), + ...(actionIds || []), + ]); this.notificationService.toast('History cleared'); this.onHistoryCleared?.(); } diff --git a/packages/frontend/core/src/blocksuite/ai/components/ask-ai-toolbar.ts b/packages/frontend/core/src/blocksuite/ai/components/ask-ai-toolbar.ts index 5def982cd0..632f36a34c 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ask-ai-toolbar.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ask-ai-toolbar.ts @@ -9,7 +9,7 @@ import { flip, offset } from '@floating-ui/dom'; import { css, html, LitElement } from 'lit'; import { property } from 'lit/decorators.js'; -import { AIProvider } from '../provider'; +import { AIAppEvents } from '../provider'; import { getAIPanelWidget } from '../utils/ai-widgets'; import { extractSelectedContent } from '../utils/extract'; import type { AffineAIPanelWidgetConfig } from '../widgets/ai-panel/type'; @@ -75,7 +75,7 @@ export class AskAIToolbarButton extends WithDisposable(LitElement) { aiPanel.hide(); extractSelectedContent(this.host) .then(context => { - AIProvider.slots.requestSendWithChat.next({ + AIAppEvents.requestSendWithChat.next({ input, context, host: this.host, diff --git a/packages/frontend/core/src/blocksuite/ai/components/playground/chat.ts b/packages/frontend/core/src/blocksuite/ai/components/playground/chat.ts index 15f7e8b8fe..00aa2e0780 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/playground/chat.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/playground/chat.ts @@ -7,10 +7,7 @@ import type { import type { WorkspaceDialogService } from '@affine/core/modules/dialogs'; import type { FeatureFlagService } from '@affine/core/modules/feature-flag'; import type { AppThemeService } from '@affine/core/modules/theme'; -import type { - ContextEmbedStatus, - CopilotChatHistoryFragment, -} from '@affine/graphql'; +import type { CopilotChatHistoryFragment } from '@affine/graphql'; import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit'; import { type NotificationService } from '@blocksuite/affine/shared/services'; import { unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme'; @@ -24,7 +21,13 @@ import { createRef, type Ref, ref } from 'lit/directives/ref.js'; import { throttle } from 'lodash-es'; import type { AppSidebarConfig } from '../../chat-panel/chat-config'; -import { AIProvider } from '../../provider'; +import { AIAppEvents, type AIError } from '../../provider'; +import { + AIChatRuntime, + type AIChatSnapshot, + PlaygroundAIChatSessionStrategy, +} from '../../runtime/chat'; +import { getAIRequestService } from '../../runtime/request'; import { HISTORY_IMAGE_ACTIONS } from '../../utils/history-image-actions'; import type { SearchMenuConfig } from '../ai-chat-add-context'; import type { DocDisplayConfig } from '../ai-chat-chips'; @@ -32,8 +35,6 @@ import type { ChatContextValue } from '../ai-chat-content'; import type { AIPlaygroundConfig, AIReasoningConfig } from '../ai-chat-input'; import { type AIChatMessages, - type ChatAction, - type ChatMessage, type HistoryMessage, isChatMessage, } from '../ai-chat-messages'; @@ -202,16 +203,20 @@ export class PlaygroundChat extends SignalWatcher( accessor chatContextValue: ChatContextValue = DEFAULT_CHAT_CONTEXT_VALUE; @state() - accessor embeddingProgress: [number, number] = [0, 0]; + accessor runtimeSnapshot: AIChatSnapshot | null = null; private readonly _chatMessagesRef: Ref = createRef(); - // request counter to track the latest request - private _updateHistoryCounter = 0; + private runtime: AIChatRuntime | null = null; + + private disposeRuntime: (() => void) | null = null; get messages() { - return this.chatContextValue.messages.filter(item => { + const messages = + (this.runtimeSnapshot?.messages as HistoryMessage[] | undefined) ?? + this.chatContextValue.messages; + return messages.filter(item => { return ( isChatMessage(item) || item.messages?.length === 3 || @@ -226,66 +231,46 @@ export class PlaygroundChat extends SignalWatcher( } private readonly _initPanel = async () => { - const userId = (await AIProvider.userInfo)?.id; + const userId = AIAppEvents.userInfo.value?.id; if (!userId) return; - this.isLoading = true; - await this._updateHistory(); - this.isLoading = false; + this.ensureRuntime(); }; - private readonly _createSession = async () => { - return this.session; - }; - - private readonly _updateHistory = async () => { - if (!AIProvider.histories) { - return; - } - - const currentRequest = ++this._updateHistoryCounter; - - const sessionId = this.session?.sessionId; - const [histories, actions] = await Promise.all([ - sessionId - ? AIProvider.histories.chats( - this.doc.workspace.id, - sessionId, - this.doc.id - ) - : Promise.resolve([]), - this.doc.id && this.showActions - ? AIProvider.histories.actions(this.doc.workspace.id, this.doc.id) - : Promise.resolve([]), - ]); - - // Check if this is still the latest request - if (currentRequest !== this._updateHistoryCounter) { - return; - } - - const chatActions = (actions || []) as ChatAction[]; - const messages: HistoryMessage[] = chatActions; - - const chatMessages = (histories?.[0]?.messages || []) as ChatMessage[]; - messages.push(...chatMessages); - + private readonly syncContextFromRuntime = () => { + const snapshot = this.runtimeSnapshot; + if (!snapshot) return; this.chatContextValue = { ...this.chatContextValue, - messages: messages.sort( - (a, b) => - new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime() - ), + messages: snapshot.messages as HistoryMessage[], + status: snapshot.status, + error: snapshot.error as AIError | null, }; - - this._scrollToEnd(); }; - private readonly onEmbeddingProgressChange = ( - count: Record - ) => { - const total = count.finished + count.processing + count.failed; - this.embeddingProgress = [count.finished, total]; + private readonly ensureRuntime = () => { + if (!this.session || this.runtime) return; + this.runtime = new AIChatRuntime({ + request: getAIRequestService(), + scope: { + kind: 'fork', + workspaceId: this.doc.workspace.id, + docId: this.doc.id, + parentSessionId: this.session.parentSessionId ?? this.session.sessionId, + }, + strategy: new PlaygroundAIChatSessionStrategy(), + }); + this.disposeRuntime = this.runtime.subscribe(() => { + this.runtimeSnapshot = this.runtime?.getSnapshot() ?? null; + this.syncContextFromRuntime(); + }); + this.runtimeSnapshot = this.runtime.getSnapshot(); + this.runtime + .dispatch({ + type: 'openSessionObject', + session: this.session, + }) + .catch(console.error); }; private readonly updateContext = (context: Partial) => { @@ -303,7 +288,23 @@ export class PlaygroundChat extends SignalWatcher( this._initPanel().catch(console.error); } + override disconnectedCallback() { + super.disconnectedCallback(); + this.disposeRuntime?.(); + this.runtime?.dispose(); + this.runtime = null; + this.disposeRuntime = null; + } + protected override updated(_changedProperties: PropertyValues) { + if (_changedProperties.has('session')) { + this.disposeRuntime?.(); + this.runtime?.dispose(); + this.runtime = null; + this.disposeRuntime = null; + this.ensureRuntime(); + } + if ( _changedProperties.has('chatContextValue') && (this.chatContextValue.status === 'loading' || @@ -322,7 +323,11 @@ export class PlaygroundChat extends SignalWatcher( } override render() { - const [done, total] = this.embeddingProgress; + const embeddingCount = + this.runtimeSnapshot?.composer.context.embeddingCount; + const done = embeddingCount?.finished ?? 0; + const total = + done + (embeddingCount?.processing ?? 0) + (embeddingCount?.failed ?? 0); const isEmbedding = total > 0 && done < total; return html`
@@ -342,7 +347,23 @@ export class PlaygroundChat extends SignalWatcher( .doc=${this.doc} .session=${this.session} .notificationService=${this.notificationService} - .onHistoryCleared=${this._updateHistory} + .onClearHistory=${async (sessionIds: string[]) => { + for (const sessionId of sessionIds) { + await this.runtime?.dispatch({ + type: 'deleteSession', + sessionId, + }); + } + }} + .onHistoryCleared=${() => + this.session + ? this.runtime + ?.dispatch({ + type: 'openSessionObject', + session: this.session, + }) + .catch(console.error) + : undefined} .chatContextValue=${this.chatContextValue} >
${DeleteIcon()}
@@ -355,7 +376,8 @@ export class PlaygroundChat extends SignalWatcher( .isHistoryLoading=${this.isLoading} .chatContextValue=${this.chatContextValue} .session=${this.session} - .createSession=${this._createSession} + .runtime=${this.runtime} + .runtimeSnapshot=${this.runtimeSnapshot} .updateContext=${this.updateContext} .extensions=${this.extensions} .affineFeatureFlagService=${this.affineFeatureFlagService} @@ -370,10 +392,10 @@ export class PlaygroundChat extends SignalWatcher( .workspaceId=${this.doc.workspace.id} .docId=${this.doc.id} .session=${this.session} - .createSession=${this._createSession} + .runtime=${this.runtime} + .runtimeSnapshot=${this.runtimeSnapshot} .chatContextValue=${this.chatContextValue} .updateContext=${this.updateContext} - .onEmbeddingProgressChange=${this.onEmbeddingProgressChange} .reasoningConfig=${this.reasoningConfig} .playgroundConfig=${this.playgroundConfig} .docDisplayConfig=${this.docDisplayConfig} diff --git a/packages/frontend/core/src/blocksuite/ai/components/playground/content.ts b/packages/frontend/core/src/blocksuite/ai/components/playground/content.ts index 2335498b29..b10332594b 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/playground/content.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/playground/content.ts @@ -18,7 +18,12 @@ import { property, state } from 'lit/decorators.js'; import { repeat } from 'lit/directives/repeat.js'; import type { AppSidebarConfig } from '../../chat-panel/chat-config'; -import { AIProvider } from '../../provider'; +import { + AIChatRuntime, + type AIChatScope, + PlaygroundAIChatSessionStrategy, +} from '../../runtime/chat'; +import { getAIRequestService } from '../../runtime/request'; import type { SearchMenuConfig } from '../ai-chat-add-context'; import type { DocDisplayConfig } from '../ai-chat-chips'; import type { AIPlaygroundConfig, AIReasoningConfig } from '../ai-chat-input'; @@ -119,27 +124,39 @@ export class PlaygroundContent extends SignalWatcher( private isSending = false; + private createSessionRuntime(scope: AIChatScope) { + return new AIChatRuntime({ + request: getAIRequestService(), + scope, + strategy: new PlaygroundAIChatSessionStrategy(), + }); + } + private readonly getSessions = async () => { const sessions = - (await AIProvider.session?.getSessions( + (await getAIRequestService().getSessions( this.doc.workspace.id, this.doc.id, { action: false } )) || []; const rootSession = sessions?.findLast(session => !session.parentSessionId); if (!rootSession) { - // Create a new session - const rootSessionId = await AIProvider.session?.createSession({ - docId: this.doc.id, + const runtime = this.createSessionRuntime({ + kind: 'playground', workspaceId: this.doc.workspace.id, - promptName: 'Chat With AFFiNE AI', + docId: this.doc.id, }); - if (rootSessionId) { - this.rootSessionId = rootSessionId; - const forkSession = await this.forkSession(rootSessionId); - if (forkSession) { - this.sessions = [forkSession]; + try { + const rootSession = await runtime.createSession(); + if (rootSession) { + this.rootSessionId = rootSession.sessionId; + const forkSession = await this.forkSession(rootSession.sessionId); + if (forkSession) { + this.sessions = [forkSession]; + } } + } finally { + runtime.dispose(); } } else { this.rootSessionId = rootSession.sessionId; @@ -158,19 +175,17 @@ export class PlaygroundContent extends SignalWatcher( }; private readonly forkSession = async (parentSessionId: string) => { - const forkSessionId = await AIProvider.forkChat?.({ + const runtime = this.createSessionRuntime({ + kind: 'playground', workspaceId: this.doc.workspace.id, docId: this.doc.id, - sessionId: parentSessionId, - latestMessageId: '', + parentSessionId, }); - if (!forkSessionId) { - return; + try { + return (await runtime.createSession()) ?? undefined; + } finally { + runtime.dispose(); } - return await AIProvider.session?.getSession( - this.doc.workspace.id, - forkSessionId - ); }; private readonly addChat = async () => { diff --git a/packages/frontend/core/src/blocksuite/ai/entries/edgeless/actions-config.ts b/packages/frontend/core/src/blocksuite/ai/entries/edgeless/actions-config.ts index 1962b6bf1f..04589440a6 100644 --- a/packages/frontend/core/src/blocksuite/ai/entries/edgeless/actions-config.ts +++ b/packages/frontend/core/src/blocksuite/ai/entries/edgeless/actions-config.ts @@ -49,7 +49,7 @@ import { translateLangs, } from '../../actions/types'; import type { AIItemGroupConfig } from '../../components/ai-item/types'; -import { AIProvider } from '../../provider'; +import { AIAppEvents } from '../../provider'; import { getAIPanelWidget } from '../../utils/ai-widgets'; import { getEdgelessCopilotWidget, @@ -121,7 +121,7 @@ const othersGroup: AIItemGroupConfig = { const edgelessCopilot = getEdgelessCopilotWidget(host); extractSelectedContent(host) .then(context => { - AIProvider.slots.requestOpenWithChat.next({ + AIAppEvents.requestOpenWithChat.next({ host, mode: 'edgeless', autoSelect: true, diff --git a/packages/frontend/core/src/blocksuite/ai/entries/edgeless/index.ts b/packages/frontend/core/src/blocksuite/ai/entries/edgeless/index.ts index c0d35e90fc..65d5775687 100644 --- a/packages/frontend/core/src/blocksuite/ai/entries/edgeless/index.ts +++ b/packages/frontend/core/src/blocksuite/ai/entries/edgeless/index.ts @@ -7,7 +7,7 @@ import { import { html } from 'lit'; import type { AIItemGroupConfig } from '../../components/ai-item/types'; -import { AIProvider } from '../../provider'; +import { AIAppEvents } from '../../provider'; import { getAIPanelWidget } from '../../utils/ai-widgets'; import { getEdgelessCopilotWidget } from '../../utils/edgeless'; import { extractSelectedContent } from '../../utils/extract'; @@ -58,14 +58,14 @@ export function edgelessToolbarAIEntryConfig(): ToolbarModuleConfig { extractSelectedContent(host) .then(context => { if (context?.attachments?.length || context?.docs?.length) { - AIProvider.slots.requestOpenWithChat.next({ + AIAppEvents.requestOpenWithChat.next({ input, host, context, autoSelect: true, }); } else { - AIProvider.slots.requestSendWithChat.next({ + AIAppEvents.requestSendWithChat.next({ input, context, host, diff --git a/packages/frontend/core/src/blocksuite/ai/entries/space/setup-space.ts b/packages/frontend/core/src/blocksuite/ai/entries/space/setup-space.ts index 3ff47198c0..67cc202d07 100644 --- a/packages/frontend/core/src/blocksuite/ai/entries/space/setup-space.ts +++ b/packages/frontend/core/src/blocksuite/ai/entries/space/setup-space.ts @@ -2,7 +2,7 @@ import type { RichText } from '@blocksuite/affine/rich-text'; import { type EditorHost, TextSelection } from '@blocksuite/affine/std'; import { handleInlineAskAIAction } from '../../actions/doc-handler'; -import { AIProvider } from '../../provider'; +import { hasAIRequestService } from '../../runtime/request'; import type { AffineAIPanelWidget } from '../../widgets/ai-panel/ai-panel'; function isSpaceEvent(event: KeyboardEvent) { @@ -43,7 +43,7 @@ export function setupSpaceAIEntry(panel: AffineAIPanelWidget) { const host = panel.host; const keyboardState = ctx.get('keyboardState'); const event = keyboardState.raw; - if (AIProvider.actions.chat && isSpaceEvent(event)) { + if (hasAIRequestService() && isSpaceEvent(event)) { // If the AI panel is in the input state and the input content is empty, // insert a space back into the editor. if (panel.state === 'input') { diff --git a/packages/frontend/core/src/blocksuite/ai/index.ts b/packages/frontend/core/src/blocksuite/ai/index.ts index f897aca83b..ac2e318849 100644 --- a/packages/frontend/core/src/blocksuite/ai/index.ts +++ b/packages/frontend/core/src/blocksuite/ai/index.ts @@ -5,4 +5,6 @@ export * from './entries/edgeless/actions-config'; export * from './messages'; export { AIChatBlockPeekViewTemplate } from './peek-view/chat-block-peek-view'; export * from './provider'; +export * from './runtime/chat'; +export * from './runtime/request'; export * from './utils/edgeless'; diff --git a/packages/frontend/core/src/blocksuite/ai/messages/error.ts b/packages/frontend/core/src/blocksuite/ai/messages/error.ts index 8e9c2fb916..d9d70deafa 100644 --- a/packages/frontend/core/src/blocksuite/ai/messages/error.ts +++ b/packages/frontend/core/src/blocksuite/ai/messages/error.ts @@ -9,8 +9,8 @@ import { css, html, LitElement, nothing, unsafeCSS } from 'lit'; import { property } from 'lit/decorators.js'; import { + AIAppEvents, type AIError, - AIProvider, PaymentRequiredError, UnauthorizedError, } from '../provider'; @@ -190,7 +190,7 @@ const PaymentRequiredErrorRenderer = (host?: EditorHost | null) => html` AIProvider.slots.requestUpgradePlan.next({ host })} + .onClick=${() => AIAppEvents.requestUpgradePlan.next({ host })} > `; @@ -198,7 +198,7 @@ const LoginRequiredErrorRenderer = (host?: EditorHost | null) => html` AIProvider.slots.requestLogin.next({ host })} + .onClick=${() => AIAppEvents.requestLogin.next({ host })} > `; diff --git a/packages/frontend/core/src/blocksuite/ai/peek-view/chat-block-peek-view.ts b/packages/frontend/core/src/blocksuite/ai/peek-view/chat-block-peek-view.ts index e40403b3dd..8132f0d038 100644 --- a/packages/frontend/core/src/blocksuite/ai/peek-view/chat-block-peek-view.ts +++ b/packages/frontend/core/src/blocksuite/ai/peek-view/chat-block-peek-view.ts @@ -9,10 +9,7 @@ import type { } from '@affine/core/modules/cloud'; import type { WorkspaceDialogService } from '@affine/core/modules/dialogs'; import type { FeatureFlagService } from '@affine/core/modules/feature-flag'; -import type { - ContextEmbedStatus, - CopilotChatHistoryFragment, -} from '@affine/graphql'; +import type { CopilotChatHistoryFragment } from '@affine/graphql'; import { CanvasElementType, EdgelessCRUDIdentifier, @@ -42,18 +39,17 @@ import type { SearchMenuConfig } from '../components/ai-chat-add-context'; import type { DocDisplayConfig } from '../components/ai-chat-chips'; import type { AIReasoningConfig } from '../components/ai-chat-input'; import type { ChatMessage } from '../components/ai-chat-messages'; -import { - ChatMessagesSchema, - isChatMessage, - StreamObjectSchema, -} from '../components/ai-chat-messages'; +import { ChatMessagesSchema } from '../components/ai-chat-messages'; import type { TextRendererOptions } from '../components/text-renderer'; import { AIChatErrorRenderer } from '../messages/error'; -import { type AIError, AIProvider } from '../provider'; +import { AIAppEvents, type AIError } from '../provider'; import { - mergeStreamContent, - mergeStreamObjects, -} from '../utils/stream-objects'; + AIChatRuntime, + type AIChatSnapshot, + ChatBlockAIChatSessionStrategy, +} from '../runtime/chat'; +import { getAIRequestService } from '../runtime/request'; +import { mergeStreamContent } from '../utils/stream-objects'; import { PeekViewStyles } from './styles'; import type { ChatContext } from './types'; import { calcChildBound } from './utils'; @@ -85,14 +81,14 @@ export class AIChatBlockPeekView extends LitElement { return this.blockModel.props.rootWorkspaceId; } - private get _isReasoningActive() { - return !!this.reasoningConfig.enabled.value; - } - private _textRendererOptions: TextRendererOptions = {}; private _forkBlockId: string | undefined = undefined; + private runtime: AIChatRuntime | null = null; + + private disposeRuntime: (() => void) | null = null; + private readonly _deserializeHistoryChatMessages = ( historyMessagesString: string ) => { @@ -115,7 +111,7 @@ export class AIChatBlockPeekView extends LitElement { forkSessionId: string, docId?: string ) => { - const currentUserInfo = await AIProvider.userInfo; + const currentUserInfo = AIAppEvents.userInfo.value; const forkMessages = (await queryHistoryMessages( rootWorkspaceId, forkSessionId, @@ -160,36 +156,28 @@ export class AIChatBlockPeekView extends LitElement { this._forkBlockId = undefined; }; - private readonly initSession = async () => { - const session = await AIProvider.session?.getSession( - this.rootWorkspaceId, - this._sessionId - ); - this.session = session ?? null; - }; - - private readonly createForkSession = async () => { - if (this.forkSession) { - return this.forkSession; - } - const lastMessage = this._historyMessages.at(-1); - if (!lastMessage) return; - - const { store } = this.host; - const forkSessionId = await AIProvider.forkChat?.({ - workspaceId: store.workspace.id, - docId: store.id, - sessionId: this._sessionId, - latestMessageId: lastMessage.id, + private createRuntime() { + return new AIChatRuntime({ + request: getAIRequestService(), + scope: { + kind: 'chat-block', + workspaceId: this.rootWorkspaceId, + docId: this.rootDocId, + blockId: this.blockId, + parentSessionId: this._sessionId, + latestMessageId: this._historyMessages.at(-1)?.id, + }, + strategy: new ChatBlockAIChatSessionStrategy(), }); - if (forkSessionId) { - const session = await AIProvider.session?.getSession( - this.rootWorkspaceId, - forkSessionId - ); - this.forkSession = session ?? null; + } + + private readonly initSession = async () => { + const runtime = this.createRuntime(); + try { + this.session = (await runtime.loadInitialSession()) ?? null; + } finally { + runtime.dispose(); } - return this.forkSession; }; private readonly _onChatSuccess = async () => { @@ -312,11 +300,37 @@ export class AIChatBlockPeekView extends LitElement { this.chatContext = { ...this.chatContext, ...context }; }; - private readonly onEmbeddingProgressChange = ( - count: Record - ) => { - const total = count.finished + count.processing + count.failed; - this.embeddingProgress = [count.finished, total]; + private readonly syncContextFromRuntime = () => { + const snapshot = this.runtimeSnapshot; + if (!snapshot) return; + const activeSession = snapshot.sessions.find( + session => session.sessionId === snapshot.activeSessionId + ); + this.forkSession = activeSession ?? this.forkSession; + this.chatContext = { + ...this.chatContext, + messages: snapshot.messages as ChatMessage[], + status: snapshot.status, + error: snapshot.error as AIError | null, + }; + }; + + private readonly ensureRuntime = () => { + if (this.runtime || !this.session || !this._historyMessages.length) return; + this.runtime = this.createRuntime(); + this.disposeRuntime = this.runtime.subscribe(() => { + this.runtimeSnapshot = this.runtime?.getSnapshot() ?? null; + this.syncContextFromRuntime(); + }); + this.runtimeSnapshot = this.runtime.getSnapshot(); + if (this.forkSession) { + this.runtime + .dispatch({ + type: 'openSessionObject', + session: this.forkSession, + }) + .catch(console.error); + } }; /** @@ -359,72 +373,16 @@ export class AIChatBlockPeekView extends LitElement { * Retry the last chat message */ retry = async () => { - try { - const forkSessionId = this.forkSession?.sessionId; - if (!this._forkBlockId || !forkSessionId) return; - if (!AIProvider.actions.chat) return; - - const abortController = new AbortController(); - const messages = [...this.chatContext.messages]; - const last = messages[messages.length - 1]; - if ('content' in last) { - last.content = ''; - last.streamObjects = []; - last.createdAt = new Date().toISOString(); - } - this.updateContext({ - messages, - status: 'loading', - error: null, - abortController, + if (this.runtime) { + const lastAssistantMessage = this.chatContext.messages.findLast( + message => message.role === 'assistant' + ); + await this.runtime.dispatch({ + type: 'retry', + messageId: lastAssistantMessage?.id ?? '', }); - - const { store } = this.host; - const stream = await AIProvider.actions.chat({ - sessionId: forkSessionId, - retry: true, - docId: store.id, - workspaceId: store.workspace.id, - host: this.host, - stream: true, - signal: abortController.signal, - where: 'ai-chat-block', - control: 'chat-send', - reasoning: this._isReasoningActive, - toolsConfig: this.aiToolsConfigService.config.value, - }); - - for await (const text of stream) { - const messages = this.chatContext.messages.slice(0); - const last = messages.at(-1); - if (last && isChatMessage(last)) { - try { - const parsed = StreamObjectSchema.parse(JSON.parse(text)); - const streamObjects = mergeStreamObjects([ - ...(last.streamObjects ?? []), - parsed, - ]); - messages[messages.length - 1] = { - ...last, - streamObjects, - }; - } catch { - messages[messages.length - 1] = { - ...last, - content: last.content + text, - }; - } - this.updateContext({ messages, status: 'transmitting' }); - } - } - - this.updateContext({ status: 'success' }); - // Update new chat block messages if there are contents returned from AI await this.updateChatBlockMessages(); - } catch (error) { - this.updateContext({ status: 'error', error: error as AIError }); - } finally { - this.updateContext({ abortController: null }); + return; } }; @@ -526,17 +484,33 @@ export class AIChatBlockPeekView extends LitElement { attachments: messages[idx]?.attachments ?? [], }; }); + this.ensureRuntime(); }) .catch((err: Error) => { console.error('Query history messages failed', err); }); } + override disconnectedCallback() { + super.disconnectedCallback(); + this.disposeRuntime?.(); + this.runtime?.dispose(); + this.disposeRuntime = null; + this.runtime = null; + } + override firstUpdated() { this._scrollToEnd(); } protected override updated(changedProperties: PropertyValues) { + if ( + changedProperties.has('session') || + changedProperties.has('_historyMessages') + ) { + this.ensureRuntime(); + } + if ( changedProperties.has('chatContext') && (this.chatContext.status === 'loading' || @@ -572,6 +546,14 @@ export class AIChatBlockPeekView extends LitElement { { + for (const sessionId of sessionIds) { + await this.runtime?.dispatch({ + type: 'deleteSession', + sessionId, + }); + } + }} .onHistoryCleared=${this._onHistoryCleared} .chatContextValue=${chatContext} .notificationService=${notificationService} @@ -593,10 +575,10 @@ export class AIChatBlockPeekView extends LitElement { .workspaceId=${this.rootWorkspaceId} .docId=${this.rootDocId} .session=${this.forkSession ?? this.session} - .createSession=${this.createForkSession} + .runtime=${this.runtime} + .runtimeSnapshot=${this.runtimeSnapshot} .chatContextValue=${chatContext} .updateContext=${updateContext} - .onEmbeddingProgressChange=${this.onEmbeddingProgressChange} .docDisplayConfig=${this.docDisplayConfig} .searchMenuConfig=${this.searchMenuConfig} .affineWorkspaceDialogService=${this.affineWorkspaceDialogService} @@ -673,7 +655,7 @@ export class AIChatBlockPeekView extends LitElement { }; @state() - accessor embeddingProgress: [number, number] = [0, 0]; + accessor runtimeSnapshot: AIChatSnapshot | null = null; @state() accessor session: CopilotChatHistoryFragment | null | undefined; diff --git a/packages/frontend/core/src/blocksuite/ai/provider/ai-app-events.ts b/packages/frontend/core/src/blocksuite/ai/provider/ai-app-events.ts new file mode 100644 index 0000000000..3d7895a39c --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/provider/ai-app-events.ts @@ -0,0 +1,19 @@ +import type { EditorHost } from '@blocksuite/affine/std'; +import { BehaviorSubject, Subject } from 'rxjs'; + +import type { AIChatParams, AISendParams, AIUserInfo } from './ai-provider'; + +export const AIAppEvents = { + /* eslint-disable rxjs/finnish */ + requestOpenWithChat: new BehaviorSubject(null), + requestSendWithChat: new BehaviorSubject(null), + requestInsertTemplate: new Subject<{ + template: string; + mode: 'page' | 'edgeless'; + }>(), + requestLogin: new Subject<{ host?: EditorHost | null }>(), + requestUpgradePlan: new Subject<{ host?: EditorHost | null }>(), + userInfo: new BehaviorSubject(null), + previewPanelOpenChange: new Subject(), + /* eslint-enable rxjs/finnish */ +}; diff --git a/packages/frontend/core/src/blocksuite/ai/provider/ai-provider.ts b/packages/frontend/core/src/blocksuite/ai/provider/ai-provider.ts index f39f7490f0..b62a4f88aa 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/ai-provider.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/ai-provider.ts @@ -1,13 +1,6 @@ import type { EditorHost } from '@blocksuite/affine/std'; -import { captureException } from '@sentry/react'; -import { BehaviorSubject, Subject } from 'rxjs'; import type { ChatContextValue } from '../components/ai-chat-content'; -import { - PaymentRequiredError, - RequestTimeoutError, - UnauthorizedError, -} from './error'; export interface AIUserInfo { id: string; @@ -64,14 +57,6 @@ export type ActionEventType = * TODO: breakdown into different parts? */ export class AIProvider { - static get slots() { - return AIProvider.instance.slots; - } - - static get actions() { - return AIProvider.instance.actions; - } - static get userInfo() { return AIProvider.instance.userInfoFn(); } @@ -80,275 +65,34 @@ export class AIProvider { return AIProvider.instance.photoEngine; } - static get histories() { - return AIProvider.instance.histories; - } - - static get session() { - return AIProvider.instance.session; - } - - static get context() { - return AIProvider.instance.context; - } - - static get actionHistory() { - return AIProvider.instance.actionHistory; - } - static get toggleGeneralAIOnboarding() { return AIProvider.instance.toggleGeneralAIOnboarding; } - static get forkChat() { - return AIProvider.instance.forkChat; - } - - static get embedding() { - return AIProvider.instance.embedding; - } - private static readonly instance = new AIProvider(); - static LAST_ACTION_SESSIONID = ''; - - static MAX_LOCAL_HISTORY = 10; - - private readonly actions: Partial = {}; - private photoEngine: BlockSuitePresets.AIPhotoEngineService | null = null; - private histories: BlockSuitePresets.AIHistoryService | null = null; - - private session: BlockSuitePresets.AISessionService | null = null; - - private context: BlockSuitePresets.AIContextService | null = null; - private toggleGeneralAIOnboarding: ((value: boolean) => void) | null = null; - private forkChat: - | (( - options: BlockSuitePresets.AIForkChatSessionOptions - ) => string | Promise) - | null = null; - - private readonly slots = { - // use case: when user selects "continue in chat" in an ask ai result panel - // do we need to pass the context to the chat panel? - /* eslint-disable rxjs/finnish */ - requestOpenWithChat: new BehaviorSubject(null), - requestSendWithChat: new BehaviorSubject(null), - requestInsertTemplate: new Subject<{ - template: string; - mode: 'page' | 'edgeless'; - }>(), - requestLogin: new Subject<{ host?: EditorHost | null }>(), - requestUpgradePlan: new Subject<{ host?: EditorHost | null }>(), - // stream of AI actions triggered by users - actions: new Subject<{ - action: keyof BlockSuitePresets.AIActions; - options: BlockSuitePresets.AITextActionOptions; - event: ActionEventType; - }>(), - // downstream can emit this slot to notify ai presets that user info has been updated - userInfo: new Subject(), - sessionReady: new BehaviorSubject(false), - previewPanelOpenChange: new Subject(), - /* eslint-enable rxjs/finnish */ - }; - - // track the history of triggered actions (in memory only) - private readonly actionHistory: { - action: keyof BlockSuitePresets.AIActions; - options: BlockSuitePresets.AITextActionOptions; - }[] = []; - private userInfoFn: () => AIUserInfo | Promise | null = () => null; - private embedding: BlockSuitePresets.AIEmbeddingService | null = null; - - private provideAction( - id: T, - action: ( - ...options: Parameters - ) => Promise> - ): void { - // @ts-expect-error TODO: maybe fix this - this.actions[id] = async ( - ...args: Parameters - ) => { - const options = args[0]; - const slots = this.slots; - slots.actions.next({ - action: id, - options, - event: 'started', - }); - this.actionHistory.push({ action: id, options }); - if (this.actionHistory.length > AIProvider.MAX_LOCAL_HISTORY) { - this.actionHistory.shift(); - } - // wrap the action with slot actions - const result: BlockSuitePresets.TextStream | Promise = - await action(...args); - const isTextStream = ( - m: BlockSuitePresets.TextStream | Promise - ): m is BlockSuitePresets.TextStream => - Reflect.has(m, Symbol.asyncIterator); - if (isTextStream(result)) { - return { - [Symbol.asyncIterator]: async function* () { - let user = null; - try { - user = await AIProvider.userInfo; - yield* result; - slots.actions.next({ - action: id, - options, - event: 'finished', - }); - } catch (err) { - slots.actions.next({ - action: id, - options, - event: 'error', - }); - if (err instanceof RequestTimeoutError) { - slots.actions.next({ - action: id, - options, - event: 'aborted:timeout', - }); - } else if (err instanceof PaymentRequiredError) { - slots.actions.next({ - action: id, - options, - event: 'aborted:paywall', - }); - } else if (err instanceof UnauthorizedError) { - slots.actions.next({ - action: id, - options, - event: 'aborted:login-required', - }); - } else { - slots.actions.next({ - action: id, - options, - event: 'aborted:server-error', - }); - captureException(err, { - user: { id: user?.id }, - extra: { - action: id, - session: AIProvider.LAST_ACTION_SESSIONID, - }, - }); - } - throw err; - } - }, - }; - } else { - let user: any = null; - return result - .then(async result => { - user = await AIProvider.userInfo; - slots.actions.next({ - action: id, - options, - event: 'finished', - }); - return result; - }) - .catch(err => { - slots.actions.next({ - action: id, - options, - event: 'error', - }); - if (err instanceof PaymentRequiredError) { - slots.actions.next({ - action: id, - options, - event: 'aborted:paywall', - }); - } else { - captureException(err, { - user: { id: user?.id }, - extra: { - action: id, - session: AIProvider.LAST_ACTION_SESSIONID, - }, - }); - } - throw err; - }); - } - }; - } - static provide( id: 'userInfo', fn: () => AIUserInfo | Promise | null ): void; - static provide( - id: 'session', - service: BlockSuitePresets.AISessionService - ): void; - - static provide( - id: 'context', - service: BlockSuitePresets.AIContextService - ): void; - - static provide( - id: 'histories', - service: BlockSuitePresets.AIHistoryService - ): void; - static provide( id: 'photoEngine', engine: BlockSuitePresets.AIPhotoEngineService ): void; - static provide( - id: 'forkChat', - fn: ( - options: BlockSuitePresets.AIForkChatSessionOptions - ) => string | Promise - ): void; - static provide(id: 'onboarding', fn: (value: boolean) => void): void; - static provide( - id: 'embedding', - service: BlockSuitePresets.AIEmbeddingService - ): void; - - // actions: - static provide( - id: T, - action: ( - ...options: Parameters - ) => Promise> - ): void; - static provide(id: unknown, action: unknown) { if (id === 'userInfo') { AIProvider.instance.userInfoFn = action as () => AIUserInfo; - } else if (id === 'histories') { - AIProvider.instance.histories = - action as BlockSuitePresets.AIHistoryService; - } else if (id === 'session') { - AIProvider.instance.session = - action as BlockSuitePresets.AISessionService; - AIProvider.instance.slots.sessionReady.next(true); - } else if (id === 'context') { - AIProvider.instance.context = - action as BlockSuitePresets.AIContextService; } else if (id === 'photoEngine') { AIProvider.instance.photoEngine = action as BlockSuitePresets.AIPhotoEngineService; @@ -356,15 +100,6 @@ export class AIProvider { AIProvider.instance.toggleGeneralAIOnboarding = action as ( value: boolean ) => void; - } else if (id === 'forkChat') { - AIProvider.instance.forkChat = action as ( - options: BlockSuitePresets.AIForkChatSessionOptions - ) => string | Promise; - } else if (id === 'embedding') { - AIProvider.instance.embedding = - action as BlockSuitePresets.AIEmbeddingService; - } else { - AIProvider.instance.provideAction(id as any, action as any); } } } diff --git a/packages/frontend/core/src/blocksuite/ai/provider/event-source.ts b/packages/frontend/core/src/blocksuite/ai/provider/event-source.ts index 4fd55c565b..15f2d1edda 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/event-source.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/event-source.ts @@ -1,4 +1,4 @@ -import { handleError } from './copilot-client'; +import { handleError } from '../runtime/request/copilot-client'; import { RequestTimeoutError } from './error'; export function delay(ms: number) { diff --git a/packages/frontend/core/src/blocksuite/ai/provider/index.ts b/packages/frontend/core/src/blocksuite/ai/provider/index.ts index 4dee66f374..6d62123338 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/index.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/index.ts @@ -1,4 +1,4 @@ +export * from './ai-app-events'; export * from './ai-provider'; -export * from './copilot-client'; export * from './error'; export * from './setup-provider'; diff --git a/packages/frontend/core/src/blocksuite/ai/provider/request.spec.ts b/packages/frontend/core/src/blocksuite/ai/provider/request.spec.ts deleted file mode 100644 index 4da3f48e41..0000000000 --- a/packages/frontend/core/src/blocksuite/ai/provider/request.spec.ts +++ /dev/null @@ -1,284 +0,0 @@ -/** - * @vitest-environment happy-dom - */ -import { UserFriendlyError } from '@affine/error'; -import { beforeEach, describe, expect, test, vi } from 'vitest'; - -import { type CopilotClient, Endpoint } from './copilot-client'; -import { textToText, toImage } from './request'; - -const electronApis = vi.hoisted(() => ({ - byokStorage: undefined as - | { - isSupported: () => Promise; - getWorkspaceLeaseProviders: (workspaceId: string) => Promise< - Array<{ - provider: string; - name: string; - apiKey: string; - description?: string | null; - endpoint?: string | null; - sortOrder?: number | null; - enabled?: boolean | null; - }> - >; - } - | undefined, -})); - -const createWorkspaceByokLocalLeaseMutation = vi.hoisted(() => - Symbol('createWorkspaceByokLocalLeaseMutation') -); - -vi.mock('@affine/electron-api', () => ({ - apis: electronApis, -})); - -vi.mock('@affine/graphql', () => ({ - ByokProvider: { - openai: 'openai', - anthropic: 'anthropic', - gemini: 'gemini', - fal: 'fal', - }, - createWorkspaceByokLocalLeaseMutation, -})); - -function createClient( - overrides: Partial< - Pick< - CopilotClient, - 'gql' | 'createMessage' | 'chatTextStream' | 'imagesStream' - > - > = {} -) { - return { - gql: vi.fn().mockResolvedValue({ - createWorkspaceByokLocalLease: { leaseId: 'lease-1' }, - }), - createMessage: vi.fn().mockResolvedValue('message-1'), - chatTextStream: vi.fn(), - imagesStream: vi.fn(), - ...overrides, - } as unknown as CopilotClient; -} - -async function drain(stream: AsyncIterable) { - for await (const chunk of stream) { - void chunk; - } -} - -describe('AI request BYOK local lease handling', () => { - beforeEach(() => { - vi.stubGlobal('BUILD_CONFIG', { isElectron: true }); - electronApis.byokStorage = { - isSupported: vi.fn().mockResolvedValue(true), - getWorkspaceLeaseProviders: vi.fn().mockResolvedValue([ - { - provider: 'openai', - name: 'OpenAI', - apiKey: 'sk-local', - }, - ]), - }; - }); - - test('fails closed when local BYOK providers exist but lease creation fails', async () => { - const client = createClient({ - gql: vi.fn().mockRejectedValue(new Error('mutation failed')), - }); - - const result = textToText({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'hello', - }) as Promise; - - await expect(result).rejects.toThrow('mutation failed'); - await expect(result).rejects.toBeInstanceOf(UserFriendlyError); - expect(client.chatTextStream).not.toHaveBeenCalled(); - }); - - test('wraps local BYOK storage support failures as user friendly errors', async () => { - electronApis.byokStorage = { - isSupported: vi.fn().mockRejectedValue(new Error('support check failed')), - getWorkspaceLeaseProviders: vi.fn(), - }; - const client = createClient(); - - const result = textToText({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'hello', - }) as Promise; - - await expect(result).rejects.toThrow('support check failed'); - await expect(result).rejects.toBeInstanceOf(UserFriendlyError); - expect(client.chatTextStream).not.toHaveBeenCalled(); - }); - - test('wraps local BYOK provider loading failures as user friendly errors', async () => { - electronApis.byokStorage = { - isSupported: vi.fn().mockResolvedValue(true), - getWorkspaceLeaseProviders: vi - .fn() - .mockRejectedValue(new Error('provider load failed')), - }; - const client = createClient(); - - const result = textToText({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'hello', - }) as Promise; - - await expect(result).rejects.toThrow('provider load failed'); - await expect(result).rejects.toBeInstanceOf(UserFriendlyError); - expect(client.chatTextStream).not.toHaveBeenCalled(); - }); - - test('does not create local BYOK lease after cancellation', async () => { - const controller = new AbortController(); - const client = createClient({ - createMessage: vi.fn().mockImplementation(async () => { - controller.abort(); - return 'message-1'; - }), - }); - - await expect( - textToText({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'hello', - signal: controller.signal, - }) as Promise - ).resolves.toBe(''); - expect(client.gql).not.toHaveBeenCalled(); - expect(client.chatTextStream).not.toHaveBeenCalled(); - }); - - test('does not create stream local BYOK lease after cancellation', async () => { - const controller = new AbortController(); - const client = createClient({ - createMessage: vi.fn().mockImplementation(async () => { - controller.abort(); - return 'message-1'; - }), - }); - - await drain( - textToText({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'hello', - stream: true, - signal: controller.signal, - }) as AsyncIterable - ); - - expect(client.gql).not.toHaveBeenCalled(); - expect(client.chatTextStream).not.toHaveBeenCalled(); - }); - - test('does not create text stream when cancelled while creating local BYOK lease', async () => { - const controller = new AbortController(); - const client = createClient({ - gql: vi.fn().mockImplementation(async () => { - controller.abort(); - return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } }; - }), - }); - - await drain( - textToText({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'hello', - stream: true, - signal: controller.signal, - }) as AsyncIterable - ); - - expect(client.gql).toHaveBeenCalled(); - expect(client.chatTextStream).not.toHaveBeenCalled(); - }); - - test('does not create text request when cancelled while creating local BYOK lease', async () => { - const controller = new AbortController(); - const client = createClient({ - gql: vi.fn().mockImplementation(async () => { - controller.abort(); - return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } }; - }), - }); - - await expect( - textToText({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'hello', - signal: controller.signal, - }) as Promise - ).resolves.toBe(''); - - expect(client.gql).toHaveBeenCalled(); - expect(client.chatTextStream).not.toHaveBeenCalled(); - }); - - test('does not create image local BYOK lease after cancellation', async () => { - const controller = new AbortController(); - const client = createClient({ - createMessage: vi.fn().mockImplementation(async () => { - controller.abort(); - return 'message-1'; - }), - }); - - await drain( - toImage({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'image', - endpoint: Endpoint.Images, - signal: controller.signal, - }) as AsyncIterable - ); - - expect(client.gql).not.toHaveBeenCalled(); - expect(client.imagesStream).not.toHaveBeenCalled(); - }); - - test('does not create image stream when cancelled while creating local BYOK lease', async () => { - const controller = new AbortController(); - const client = createClient({ - gql: vi.fn().mockImplementation(async () => { - controller.abort(); - return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } }; - }), - }); - - await drain( - toImage({ - client, - sessionId: 'session-1', - workspaceId: 'workspace-1', - content: 'image', - endpoint: Endpoint.Images, - signal: controller.signal, - }) as AsyncIterable - ); - - expect(client.gql).toHaveBeenCalled(); - expect(client.imagesStream).not.toHaveBeenCalled(); - }); -}); diff --git a/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.spec.ts b/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.spec.ts deleted file mode 100644 index 0dd57c9040..0000000000 --- a/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.spec.ts +++ /dev/null @@ -1,120 +0,0 @@ -/** - * @vitest-environment happy-dom - */ -import { BehaviorSubject } from 'rxjs'; -import { describe, expect, test, vi } from 'vitest'; - -import { AIProvider } from './ai-provider'; -import { CopilotClient, Endpoint } from './copilot-client'; -import { setupAIProvider } from './setup-provider'; - -Object.defineProperty(globalThis, 'EventSource', { - configurable: true, - value: { - CLOSED: 2, - }, -}); - -type SetupAIProviderArgs = Parameters; -type ActionInput = Parameters< - NonNullable ->[0]; - -async function drain(stream: AsyncIterable) { - for await (const chunk of stream) { - void chunk; - } -} - -async function drainActionResult( - stream: string | AsyncIterable | undefined -) { - expect(stream).toBeDefined(); - expect(typeof stream).not.toBe('string'); - await drain(stream as AsyncIterable); -} - -function createClosedEventSource(): EventSource { - return { - readyState: EventSource.CLOSED, - addEventListener: vi.fn(), - close: vi.fn(), - } as unknown as EventSource; -} - -describe('setupAIProvider action migrations', () => { - test('routes mindmap, slides and image filter through action API', async () => { - const createdSessions: unknown[] = []; - const textStreams: unknown[] = []; - const client = new CopilotClient( - vi.fn(), - vi.fn(() => createClosedEventSource()) - ); - vi.spyOn(client, 'createSession').mockImplementation(async options => { - createdSessions.push(options); - return `session:${options.promptName}`; - }); - vi.spyOn(client, 'createMessage').mockResolvedValue('message-1'); - vi.spyOn(client, 'chatTextStream').mockImplementation( - (options, endpoint) => { - textStreams.push({ options, endpoint }); - return createClosedEventSource(); - } - ); - vi.spyOn(client, 'imagesStream').mockReturnValue(createClosedEventSource()); - - setupAIProvider( - client, - { open: vi.fn() } as unknown as SetupAIProviderArgs[1], - { - session: { - account$: new BehaviorSubject(null), - }, - } as unknown as SetupAIProviderArgs[2] - ); - - await drainActionResult( - await AIProvider.actions.brainstormMindmap?.({ - workspaceId: 'workspace-1', - input: 'make a map', - stream: true, - } satisfies ActionInput<'brainstormMindmap'>) - ); - await drainActionResult( - await AIProvider.actions.createSlides?.({ - workspaceId: 'workspace-1', - input: 'make slides', - stream: true, - } satisfies ActionInput<'createSlides'>) - ); - await drainActionResult( - await AIProvider.actions.filterImage?.({ - workspaceId: 'workspace-1', - input: 'convert', - attachments: ['blob-1'], - style: 'Sketch style', - } satisfies ActionInput<'filterImage'>) - ); - - expect(createdSessions).toEqual([ - expect.objectContaining({ promptName: 'mindmap.generate' }), - expect.objectContaining({ promptName: 'slides.outline' }), - expect.objectContaining({ promptName: 'image.filter.sketch' }), - ]); - expect(textStreams).toEqual([ - expect.objectContaining({ - endpoint: Endpoint.Action, - options: expect.objectContaining({ actionId: 'mindmap.generate' }), - }), - expect.objectContaining({ - endpoint: Endpoint.Action, - options: expect.objectContaining({ actionId: 'slides.outline' }), - }), - expect.objectContaining({ - endpoint: Endpoint.Action, - options: expect.objectContaining({ actionId: 'image.filter.sketch' }), - }), - ]); - expect(client.imagesStream).not.toHaveBeenCalled(); - }); -}); diff --git a/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.tsx b/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.tsx index 900e573f5f..7315f795dc 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.tsx +++ b/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.tsx @@ -1,20 +1,11 @@ import { toggleGeneralAIOnboarding } from '@affine/core/components/affine/ai-onboarding/apis'; import type { AuthAccountInfo, AuthService } from '@affine/core/modules/cloud'; import type { GlobalDialogService } from '@affine/core/modules/dialogs'; -import { - type AddContextFileInput, - ContextCategories, - type ContextWorkspaceEmbeddingStatus, - type getCopilotHistoriesQuery, - type QueryChatSessionsInput, - type RequestOptions, - type UpdateChatSessionInput, -} from '@affine/graphql'; +import type { AIRequestService } from '../runtime/request'; +import { setAIRequestService } from '../runtime/request'; +import { AIAppEvents } from './ai-app-events'; import { AIProvider } from './ai-provider'; -import { type CopilotClient, Endpoint } from './copilot-client'; -import type { PromptKey } from './prompt'; -import { textToText, toImage } from './request'; import { setupTracker } from './tracker'; function toAIUserInfo(account: AuthAccountInfo | null) { @@ -27,48 +18,12 @@ function toAIUserInfo(account: AuthAccountInfo | null) { }; } -const filterStyleToPromptName = new Map( - Object.entries({ - 'Clay style': 'image.filter.clay', - 'Pixel style': 'image.filter.pixel', - 'Sketch style': 'image.filter.sketch', - 'Anime style': 'image.filter.anime', - }) -); - -const processTypeToPromptName = new Map( - Object.entries({ - Clearer: 'Upscale image', - 'Remove background': 'Remove background', - 'Convert to sticker': 'Convert to sticker', - }) -); - export function setupAIProvider( - client: CopilotClient, + requestService: AIRequestService, globalDialogService: GlobalDialogService, authService: AuthService ) { - async function createSession({ - promptName, - workspaceId, - docId, - sessionId, - retry, - pinned, - reuseLatestChat, - }: BlockSuitePresets.AICreateSessionOptions) { - if (sessionId) return sessionId; - if (retry) return AIProvider.LAST_ACTION_SESSIONID; - - return client.createSession({ - workspaceId, - docId, - promptName, - pinned, - reuseLatestChat, - }); - } + setAIRequestService(requestService); AIProvider.provide('userInfo', () => { return toAIUserInfo(authService.session.account$.value); @@ -76,715 +31,10 @@ export function setupAIProvider( const accountSubscription = authService.session.account$.subscribe( account => { - AIProvider.slots.userInfo.next(toAIUserInfo(account)); + AIAppEvents.userInfo.next(toAIUserInfo(account)); } ); - //#region actions - AIProvider.provide('chat', async options => { - const { input, contexts } = options; - - const sessionId = await createSession({ - promptName: 'Chat With AFFiNE AI', - ...options, - }); - return textToText({ - ...options, - modelId: options.modelId, - client, - sessionId, - content: input, - timeout: 5 * 60 * 1000, // 5 minutes - params: { - docs: contexts?.docs, - files: contexts?.files, - selectedSnapshot: contexts?.selectedSnapshot, - selectedMarkdown: contexts?.selectedMarkdown, - html: contexts?.html, - ...(options.docId ? { currentDocId: options.docId } : {}), - }, - endpoint: Endpoint.StreamObject, - }); - }); - - AIProvider.provide('summary', async options => { - const sessionId = await createSession({ - promptName: 'Summary', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('translate', async options => { - const sessionId = await createSession({ - promptName: 'Translate to', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - params: { - language: options.lang, - }, - }); - }); - - AIProvider.provide('changeTone', async options => { - const sessionId = await createSession({ - promptName: 'Change tone to', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - params: { - tone: options.tone.toLowerCase(), - }, - content: options.input, - }); - }); - - AIProvider.provide('improveWriting', async options => { - const sessionId = await createSession({ - promptName: 'Improve writing for it', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('improveGrammar', async options => { - const sessionId = await createSession({ - promptName: 'Improve grammar for it', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('fixSpelling', async options => { - const sessionId = await createSession({ - promptName: 'Fix spelling for it', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('createHeadings', async options => { - const sessionId = await createSession({ - promptName: 'Create headings', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('makeLonger', async options => { - const sessionId = await createSession({ - promptName: 'Make it longer', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('makeShorter', async options => { - const sessionId = await createSession({ - promptName: 'Make it shorter', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('checkCodeErrors', async options => { - const sessionId = await createSession({ - promptName: 'Check code error', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('explainCode', async options => { - const sessionId = await createSession({ - promptName: 'Explain this code', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('writeArticle', async options => { - const sessionId = await createSession({ - promptName: 'Write an article about this', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('writeTwitterPost', async options => { - const sessionId = await createSession({ - promptName: 'Write a twitter about this', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('writePoem', async options => { - const sessionId = await createSession({ - promptName: 'Write a poem about this', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('writeOutline', async options => { - const sessionId = await createSession({ - promptName: 'Write outline', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('writeBlogPost', async options => { - const sessionId = await createSession({ - promptName: 'Write a blog post about this', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('brainstorm', async options => { - const sessionId = await createSession({ - promptName: 'Brainstorm ideas about this', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('findActions', async options => { - const sessionId = await createSession({ - promptName: 'Find action items from it', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('brainstormMindmap', async options => { - const sessionId = await createSession({ - promptName: 'mindmap.generate', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - // 3 minutes - timeout: 180000, - endpoint: Endpoint.Action, - actionId: 'mindmap.generate', - actionVersion: 'v1', - }); - }); - - AIProvider.provide('expandMindmap', async options => { - if (!options.input) { - throw new Error('expandMindmap action requires input'); - } - const sessionId = await createSession({ - promptName: 'Expand mind map', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - params: { - mindmap: options.mindmap, - node: options.input, - }, - content: options.input, - }); - }); - - AIProvider.provide('explain', async options => { - const sessionId = await createSession({ - promptName: 'Explain this', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('explainImage', async options => { - const sessionId = await createSession({ - promptName: 'Explain this image', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('makeItReal', async options => { - let promptName: PromptKey = 'Make it real'; - let content = options.input || ''; - - // wireframes - if (options.attachments?.length) { - content = `Here are the latest wireframes. Could you make a new website based on these wireframes and notes and send back just the html file? -Here are our design notes:\n ${content}.`; - } else { - // notes - promptName = 'Make it real with text'; - content = `Here are the latest notes: \n ${content}. -Could you make a new website based on these notes and send back just the html file?`; - } - - const sessionId = await createSession({ - promptName, - ...options, - }); - - return textToText({ - ...options, - client, - sessionId, - content, - }); - }); - - AIProvider.provide('createSlides', async options => { - const sessionId = await createSession({ - promptName: 'slides.outline', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - // 3 minutes - timeout: 180000, - endpoint: Endpoint.Action, - actionId: 'slides.outline', - actionVersion: 'v1', - }); - }); - - AIProvider.provide('createImage', async options => { - const sessionId = await createSession({ - promptName: 'Generate image', - ...options, - }); - return toImage({ - ...options, - client, - sessionId, - content: - !options.input && options.attachments - ? 'Make the image more detailed.' - : options.input, - // 5 minutes - timeout: 300000, - }); - }); - - AIProvider.provide('filterImage', async options => { - // test to image - const promptName: PromptKey | undefined = filterStyleToPromptName.get( - options.style - ); - if (!promptName) { - throw new Error('filterImage requires a promptName'); - } - const sessionId = await createSession({ - promptName, - ...options, - }); - return toImage({ - ...options, - client, - sessionId, - content: options.input, - timeout: 180000, - endpoint: Endpoint.Action, - actionId: promptName, - actionVersion: 'v1', - }); - }); - - AIProvider.provide('processImage', async options => { - // test to image - const promptName: PromptKey | undefined = processTypeToPromptName.get( - options.type - ); - if (!promptName) { - throw new Error('processImage requires a promptName'); - } - const sessionId = await createSession({ - promptName, - ...options, - }); - return toImage({ - ...options, - client, - sessionId, - content: options.input, - timeout: 180000, - }); - }); - - AIProvider.provide('generateCaption', async options => { - const sessionId = await createSession({ - promptName: 'Generate a caption', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - - AIProvider.provide('continueWriting', async options => { - const sessionId = await createSession({ - promptName: 'Continue writing', - ...options, - }); - return textToText({ - ...options, - client, - sessionId, - content: options.input, - }); - }); - //#endregion - - AIProvider.provide('session', { - createSession, - createSessionWithHistory: async options => { - if (!options.sessionId && !options.retry) { - return client.createSessionWithHistory({ - workspaceId: options.workspaceId, - docId: options.docId, - promptName: options.promptName, - pinned: options.pinned, - reuseLatestChat: options.reuseLatestChat, - }); - } - - const sessionId = await createSession(options); - if (!sessionId) return undefined; - return client.getSession(options.workspaceId, sessionId); - }, - getSession: async (workspaceId: string, sessionId: string) => { - return client.getSession(workspaceId, sessionId); - }, - getSessions: async ( - workspaceId: string, - docId?: string, - options?: QueryChatSessionsInput - ) => { - return client.getSessions(workspaceId, {}, docId, options); - }, - getRecentSessions: async ( - workspaceId: string, - limit?: number, - offset?: number - ) => { - return client.getRecentSessions(workspaceId, limit, offset); - }, - updateSession: async (options: UpdateChatSessionInput) => { - return client.updateSession(options); - }, - }); - - AIProvider.provide('context', { - createContext: async (workspaceId: string, sessionId: string) => { - return client.createContext(workspaceId, sessionId); - }, - getContextId: async (workspaceId: string, sessionId: string) => { - return client.getContextId(workspaceId, sessionId); - }, - addContextDoc: async (options: { contextId: string; docId: string }) => { - return client.addContextDoc(options); - }, - removeContextDoc: async (options: { contextId: string; docId: string }) => { - return client.removeContextDoc(options); - }, - addContextFile: async (file: File, options: AddContextFileInput) => { - return client.addContextFile(file, options); - }, - removeContextFile: async (options: { - contextId: string; - fileId: string; - }) => { - return client.removeContextFile(options); - }, - addContextTag: async (options: { - contextId: string; - tagId: string; - docIds: string[]; - }) => { - return client.addContextCategory({ - contextId: options.contextId, - type: ContextCategories.Tag, - categoryId: options.tagId, - docs: options.docIds, - }); - }, - removeContextTag: async (options: { contextId: string; tagId: string }) => { - return client.removeContextCategory({ - contextId: options.contextId, - type: ContextCategories.Tag, - categoryId: options.tagId, - }); - }, - addContextCollection: async (options: { - contextId: string; - collectionId: string; - docIds: string[]; - }) => { - return client.addContextCategory({ - contextId: options.contextId, - type: ContextCategories.Collection, - categoryId: options.collectionId, - docs: options.docIds, - }); - }, - removeContextCollection: async (options: { - contextId: string; - collectionId: string; - }) => { - return client.removeContextCategory({ - contextId: options.contextId, - type: ContextCategories.Collection, - categoryId: options.collectionId, - }); - }, - getContextDocsAndFiles: async ( - workspaceId: string, - sessionId: string, - contextId: string - ) => { - return client.getContextDocsAndFiles(workspaceId, sessionId, contextId); - }, - pollContextDocsAndFiles: async ( - workspaceId: string, - sessionId: string, - contextId: string, - onPoll: ( - result: BlockSuitePresets.AIDocsAndFilesContext | undefined - ) => void, - abortSignal: AbortSignal - ) => { - const poll = async () => { - const result = await client.getContextDocsAndFiles( - workspaceId, - sessionId, - contextId - ); - onPoll(result); - }; - - let attempts = 0; - const MIN_INTERVAL = 1000; - const MAX_INTERVAL = 30 * 1000; - - while (!abortSignal.aborted) { - await poll(); - const interval = Math.min( - MIN_INTERVAL * Math.pow(1.5, attempts), - MAX_INTERVAL - ); - attempts++; - await new Promise(resolve => setTimeout(resolve, interval)); - } - }, - pollEmbeddingStatus: async ( - workspaceId: string, - onPoll: (result: ContextWorkspaceEmbeddingStatus) => void, - abortSignal: AbortSignal - ) => { - const poll = async () => { - const result = await client.getEmbeddingStatus(workspaceId); - onPoll(result); - }; - - const INTERVAL = 10 * 1000; - - while (!abortSignal.aborted) { - await poll(); - await new Promise(resolve => setTimeout(resolve, INTERVAL)); - } - }, - matchContext: async ( - content: string, - contextId?: string, - workspaceId?: string, - limit?: number, - scopedThreshold?: number, - threshold?: number - ) => { - return client.matchContext( - content, - contextId, - workspaceId, - limit, - scopedThreshold, - threshold - ); - }, - addContextBlob: async (options: { blobId: string; contextId: string }) => { - return client.addContextBlob({ - contextId: options.contextId, - blobId: options.blobId, - }); - }, - removeContextBlob: async (options: { - blobId: string; - contextId: string; - }) => { - return client.removeContextBlob({ - contextId: options.contextId, - blobId: options.blobId, - }); - }, - }); - - AIProvider.provide('histories', { - actions: async ( - workspaceId: string, - docId: string - ): Promise => { - // @ts-expect-error - 'action' is missing in server impl - return ( - (await client.getHistories(workspaceId, {}, docId, { - action: true, - withPrompt: true, - withMessages: true, - })) ?? [] - ); - }, - chats: async ( - workspaceId: string, - sessionId: string, - docId?: string - ): Promise => { - // @ts-expect-error - 'action' is missing in server impl - return ( - (await client.getHistories(workspaceId, {}, docId, { - sessionId, - withMessages: true, - })) ?? [] - ); - }, - cleanup: async ( - workspaceId: string, - docId: string | undefined, - sessionIds: string[] - ) => { - await client.cleanupSessions({ workspaceId, docId, sessionIds }); - }, - ids: async ( - workspaceId: string, - docId?: string, - options?: RequestOptions< - typeof getCopilotHistoriesQuery - >['variables']['options'] - ): Promise => { - // @ts-expect-error - 'action' is missing in server impl - return await client.getHistoryIds(workspaceId, {}, docId, options); - }, - }); - AIProvider.provide('photoEngine', { async searchImages(options): Promise { let url = '/api/copilot/unsplash/photos'; @@ -813,19 +63,15 @@ Could you make a new website based on these notes and send back just the html fi AIProvider.provide('onboarding', toggleGeneralAIOnboarding); - AIProvider.provide('forkChat', options => { - return client.forkSession(options); + const disposeRequestLoginHandler = AIAppEvents.requestLogin.subscribe(() => { + globalDialogService.open('sign-in', {}); }); - const disposeRequestLoginHandler = AIProvider.slots.requestLogin.subscribe( - () => { - globalDialogService.open('sign-in', {}); - } - ); - - setupTracker(); + const trackerDisposer = setupTracker(requestService); return () => { + setAIRequestService(null); + trackerDisposer(); disposeRequestLoginHandler.unsubscribe(); accountSubscription.unsubscribe(); }; diff --git a/packages/frontend/core/src/blocksuite/ai/provider/tracker.ts b/packages/frontend/core/src/blocksuite/ai/provider/tracker.ts index cf087768d6..4b03065455 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/tracker.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/tracker.ts @@ -3,9 +3,12 @@ import type { EditorHost } from '@blocksuite/affine/std'; import type { GfxPrimitiveElementModel } from '@blocksuite/affine/std/gfx'; import type { BlockModel } from '@blocksuite/affine/store'; import { lowerCase, omit } from 'lodash-es'; -import type { Subject } from 'rxjs'; -import { AIProvider } from './ai-provider'; +import type { + AIRequestActionEvent, + AIRequestService, +} from '../runtime/request'; +import { AIAppEvents } from './ai-app-events'; type ElementModel = GfxPrimitiveElementModel; @@ -62,10 +65,6 @@ type AIActionEventProperties = { workspaceId: string; }; -type SubjectValue = T extends Subject ? U : never; - -type BlocksuiteActionEvent = SubjectValue; - const trackAction = ({ eventName, properties, @@ -106,7 +105,7 @@ function isBlockModel(model: BlockModel | ElementModel): model is BlockModel { return 'flavour' in model; } -function inferObjectType(event: BlocksuiteActionEvent) { +function inferObjectType(event: AIRequestActionEvent) { const models: (BlockModel | ElementModel)[] | undefined = event.options.models; if (!models) { @@ -135,7 +134,7 @@ function inferObjectType(event: BlocksuiteActionEvent) { } function inferSegment( - event: BlocksuiteActionEvent + event: AIRequestActionEvent ): AIActionEventProperties['segment'] { if (event.options.where === 'inline-chat-panel') { return 'inline chat panel'; @@ -151,7 +150,7 @@ function inferSegment( } function inferModule( - event: BlocksuiteActionEvent + event: AIRequestActionEvent ): AIActionEventProperties['module'] { if (event.options.where === 'chat-panel') { return 'AI chat panel'; @@ -168,9 +167,7 @@ function inferModule( } } -function inferEventName( - event: BlocksuiteActionEvent -): AIActionEventName | null { +function inferEventName(event: AIRequestActionEvent): AIActionEventName | null { if (['result:discard', 'result:retry'].includes(event.event)) { return 'AI result discarded'; } else if (event.event.startsWith('result:')) { @@ -184,7 +181,7 @@ function inferEventName( } function inferControl( - event: BlocksuiteActionEvent + event: AIRequestActionEvent ): AIActionEventProperties['control'] { if (event.event === 'aborted:stop') { return 'stop button'; @@ -222,7 +219,7 @@ function inferControl( } const toTrackedOptions = ( - event: BlocksuiteActionEvent + event: AIRequestActionEvent ): { eventName: AIActionEventName; properties: AIActionEventProperties; @@ -257,19 +254,25 @@ const toTrackedOptions = ( }; }; -export function setupTracker() { - AIProvider.slots.requestUpgradePlan.subscribe(() => { +export function setupTracker(requestService: AIRequestService) { + const upgradeSubscription = AIAppEvents.requestUpgradePlan.subscribe(() => { track.$.paywall.aiAction.viewPlans(); }); - AIProvider.slots.requestLogin.subscribe(() => { + const loginSubscription = AIAppEvents.requestLogin.subscribe(() => { track.doc.editor.aiActions.requestSignIn(); }); - AIProvider.slots.actions.subscribe(event => { + const actionSubscription = requestService.actionEvents$.subscribe(event => { const properties = toTrackedOptions(event); if (properties) { trackAction(properties); } }); + + return () => { + upgradeSubscription.unsubscribe(); + loginSubscription.unsubscribe(); + actionSubscription.unsubscribe(); + }; } diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/actions.ts b/packages/frontend/core/src/blocksuite/ai/runtime/chat/actions.ts new file mode 100644 index 0000000000..6d71546639 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/actions.ts @@ -0,0 +1,60 @@ +import type { CopilotChatHistoryFragment } from '@affine/graphql'; + +import type { AIChatContextItem, AIChatScope } from './state'; + +export type AIChatSendOptions = { + input?: string; + contexts?: { + docs?: unknown; + files?: unknown; + selectedSnapshot?: unknown; + selectedMarkdown?: unknown; + html?: unknown; + }; + attachments?: (string | Blob | File)[]; + attachmentPreviews?: string[]; + isRootSession?: boolean; + where?: BlockSuitePresets.TrackerWhere; + control?: BlockSuitePresets.TrackerControl; + reasoning?: boolean; + toolsConfig?: unknown; + modelId?: string; + userInfo?: { + userId?: string; + userName?: string; + avatarUrl?: string; + }; +}; + +export type AIChatAction = + | { type: 'initialize'; scope?: AIChatScope } + | { type: 'setScope'; scope: AIChatScope } + | { type: 'refreshHistory' } + | { + type: 'openSession'; + sessionId: string; + } + | { + type: 'openSessionObject'; + session: CopilotChatHistoryFragment; + } + | { type: 'closeTab'; tabId: string } + | { type: 'createNewSession'; pinned?: boolean } + | { type: 'togglePinActiveSession' } + | { type: 'deleteSession'; sessionId: string } + | { type: 'clearError' } + | { type: 'setComposerText'; text: string } + | { type: 'setReasoning'; reasoning: boolean } + | { type: 'setModel'; modelId?: string } + | { type: 'addAttachment'; attachment: string | Blob | File } + | { type: 'removeAttachment'; index: number } + | { type: 'addContextItem'; item: AIChatContextItem } + | { type: 'removeContextItem'; item: AIChatContextItem } + | { type: 'loadContext' } + | { type: 'pollContext' } + | { type: 'startContextPolling' } + | { type: 'stopContextPolling' } + | { type: 'pollEmbeddingStatus' } + | ({ type: 'send' } & AIChatSendOptions) + | { type: 'retry'; messageId: string } + | { type: 'stop' }; diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/index.ts b/packages/frontend/core/src/blocksuite/ai/runtime/chat/index.ts new file mode 100644 index 0000000000..47a6bd4fed --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/index.ts @@ -0,0 +1,6 @@ +export * from './actions'; +export * from './runtime'; +export * from './session-strategy'; +export * from './state'; +export * from './use-element'; +export * from './use-runtime'; diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/runtime.spec.ts b/packages/frontend/core/src/blocksuite/ai/runtime/chat/runtime.spec.ts new file mode 100644 index 0000000000..c5e73788c4 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/runtime.spec.ts @@ -0,0 +1,1001 @@ +/** + * @vitest-environment happy-dom + */ +import type { CopilotChatHistoryFragment } from '@affine/graphql'; +import { describe, expect, test, vi } from 'vitest'; + +import type { AIRequestService } from '../request'; +import { AIChatRuntime } from './runtime'; +import { + DocAIChatSessionStrategy, + ForkAIChatSessionStrategy, + PlaygroundAIChatSessionStrategy, +} from './session-strategy'; +import type { AIChatScope } from './state'; + +const docScope: AIChatScope = { + kind: 'doc', + workspaceId: 'workspace-1', + docId: 'doc-1', +}; + +function session( + overrides: Partial = {} +): CopilotChatHistoryFragment { + return { + sessionId: 'session-1', + workspaceId: 'workspace-1', + docId: 'doc-1', + title: 'Session 1', + pinned: false, + messages: [], + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + parentSessionId: null, + promptName: 'Chat With AFFiNE AI', + action: null, + optionalModels: null, + tokens: 0, + ...overrides, + } as CopilotChatHistoryFragment; +} + +async function* stream(chunks: string[]) { + for (const chunk of chunks) { + yield chunk; + } +} + +async function waitUntil(assertion: () => void) { + for (let i = 0; i < 10; i++) { + try { + assertion(); + return; + } catch { + await Promise.resolve(); + } + } + assertion(); +} + +function createRequest( + overrides: Partial = {} +): AIRequestService { + return { + getSessions: vi.fn().mockResolvedValue([]), + getRecentSessions: vi.fn().mockResolvedValue([]), + getSession: vi.fn().mockResolvedValue(null), + createSessionWithHistory: vi.fn().mockResolvedValue(session()), + updateSession: vi.fn().mockResolvedValue(undefined), + cleanupSessions: vi.fn().mockResolvedValue(undefined), + executeAction: vi.fn().mockResolvedValue(stream(['hello'])), + histories: { + ids: vi.fn().mockResolvedValue([]), + }, + context: { + createContext: vi.fn().mockResolvedValue('context-1'), + getContextId: vi.fn().mockResolvedValue(undefined), + addContextDoc: vi.fn().mockResolvedValue(undefined), + removeContextDoc: vi.fn().mockResolvedValue(undefined), + addContextFile: vi + .fn() + .mockResolvedValue({ id: 'file-1', status: 'processing' }), + removeContextFile: vi.fn().mockResolvedValue(undefined), + addContextTag: vi.fn().mockResolvedValue(undefined), + removeContextTag: vi.fn().mockResolvedValue(undefined), + addContextCollection: vi.fn().mockResolvedValue(undefined), + removeContextCollection: vi.fn().mockResolvedValue(undefined), + getContextDocsAndFiles: vi.fn().mockResolvedValue(undefined), + matchContext: vi.fn().mockResolvedValue({ files: [], docs: [] }), + addContextBlob: vi + .fn() + .mockResolvedValue({ id: 'blob-1', status: 'processing' }), + removeContextBlob: vi.fn().mockResolvedValue(undefined), + pollContextDocsAndFiles: vi.fn(), + pollEmbeddingStatus: vi.fn(), + }, + ...overrides, + } as unknown as AIRequestService; +} + +function createRuntime(request = createRequest()) { + return new AIChatRuntime({ + request, + scope: docScope, + strategy: new DocAIChatSessionStrategy(), + }); +} + +describe('AIChatRuntime', () => { + test('initializes doc scope with a draft tab when no session exists', async () => { + const runtime = createRuntime(); + + await runtime.dispatch({ type: 'initialize' }); + + const snapshot = runtime.getSnapshot(); + expect(snapshot.readiness).toBe('ready'); + expect(snapshot.activeSessionId).toBeNull(); + expect(snapshot.tabs).toEqual([ + expect.objectContaining({ kind: 'draft', hasMessages: false }), + ]); + expect(snapshot.uiPolicy.showDraftTab).toBe(true); + }); + + test('initializes doc scope with full messages for the latest session', async () => { + const listedSession = session({ sessionId: 'session-1', messages: [] }); + const fullSession = session({ + sessionId: 'session-1', + messages: [ + { + id: 'message-1', + role: 'user', + content: 'previous chat', + attachments: [], + streamObjects: [], + createdAt: new Date().toISOString(), + }, + ], + }); + const request = createRequest({ + getSessions: vi + .fn() + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([listedSession]), + getSession: vi.fn().mockResolvedValue(fullSession), + }); + const runtime = createRuntime(request); + + await runtime.dispatch({ type: 'initialize' }); + + expect(request.getSession).toHaveBeenCalledWith('workspace-1', 'session-1'); + expect(runtime.getSnapshot().messages).toEqual(fullSession.messages); + }); + + test('send creates a session once and ignores duplicate sends while transmitting', async () => { + let release!: () => void; + const blockedStream = { + async *[Symbol.asyncIterator]() { + await new Promise(resolve => { + release = resolve; + }); + yield 'done'; + }, + }; + const request = createRequest({ + executeAction: vi.fn().mockResolvedValue(blockedStream), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + + const firstSend = runtime.dispatch({ type: 'send', input: 'hello' }); + await waitUntil(() => { + expect(request.executeAction).toHaveBeenCalled(); + }); + await runtime.dispatch({ type: 'send', input: 'again' }); + release(); + await firstSend; + + expect(request.createSessionWithHistory).toHaveBeenCalledTimes(1); + expect(request.executeAction).toHaveBeenCalledTimes(1); + expect(runtime.getSnapshot().messages.at(-1)?.content).toBe('done'); + expect(runtime.getSnapshot().uiPolicy.canCreateNewSession).toBe(true); + }); + + test('send binds an unbound session to the active doc after success', async () => { + const unboundSession = session({ docId: null }); + const boundSession = session({ docId: 'doc-1' }); + const request = createRequest({ + getSession: vi.fn().mockResolvedValue(boundSession), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ + type: 'openSessionObject', + session: unboundSession, + }); + + await runtime.dispatch({ type: 'send', input: 'hello' }); + + expect(request.updateSession).toHaveBeenCalledWith({ + sessionId: 'session-1', + docId: 'doc-1', + }); + expect(request.getSession).toHaveBeenCalledWith('workspace-1', 'session-1'); + expect(runtime.getSnapshot().sessions[0].docId).toBe('doc-1'); + }); + + test('new session opens a draft tab and persists on first send', async () => { + const request = createRequest({ + createSessionWithHistory: vi + .fn() + .mockResolvedValue(session({ sessionId: 'session-2', messages: [] })), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ + sessionId: 'session-1', + title: 'One', + messages: [ + { + id: 'message-1', + role: 'user', + content: 'existing chat', + attachments: [], + streamObjects: [], + createdAt: new Date().toISOString(), + }, + ], + }), + }); + + await runtime.dispatch({ type: 'createNewSession' }); + + expect(request.createSessionWithHistory).not.toHaveBeenCalled(); + expect(runtime.getSnapshot().activeSessionId).toBeNull(); + expect(runtime.getSnapshot().messages).toEqual([]); + expect(runtime.getSnapshot().uiPolicy.showDraftTab).toBe(true); + expect(runtime.getSnapshot().tabs).toEqual([ + expect.objectContaining({ kind: 'session', sessionId: 'session-1' }), + expect.objectContaining({ kind: 'draft' }), + ]); + + await runtime.dispatch({ type: 'send', input: 'hello' }); + + expect(request.createSessionWithHistory).toHaveBeenCalledTimes(1); + expect(runtime.getSnapshot().activeSessionId).toBe('session-2'); + expect(runtime.getSnapshot().uiPolicy.showDraftTab).toBe(false); + expect(runtime.getSnapshot().tabs).toEqual([ + expect.objectContaining({ kind: 'session', sessionId: 'session-1' }), + expect.objectContaining({ kind: 'session', sessionId: 'session-2' }), + ]); + }); + + test('toggle pin updates tab and session snapshots', async () => { + const request = createRequest(); + const runtime = createRuntime(request); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ pinned: false }), + }); + + await runtime.dispatch({ type: 'togglePinActiveSession' }); + + expect(request.updateSession).toHaveBeenCalledWith({ + sessionId: 'session-1', + pinned: true, + }); + expect(runtime.getSnapshot().tabs[0]).toEqual( + expect.objectContaining({ pinned: true }) + ); + expect(runtime.getSnapshot().sessions[0]).toEqual( + expect.objectContaining({ pinned: true }) + ); + }); + + test('new session inserts the draft tab after the active tab', async () => { + const runtime = createRuntime(); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ + sessionId: 'session-1', + messages: [ + { + id: 'message-1', + role: 'user', + content: 'first chat', + attachments: [], + streamObjects: [], + createdAt: new Date().toISOString(), + }, + ], + }), + }); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ + sessionId: 'session-2', + messages: [ + { + id: 'message-2', + role: 'user', + content: 'second chat', + attachments: [], + streamObjects: [], + createdAt: new Date().toISOString(), + }, + ], + }), + }); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ + sessionId: 'session-1', + messages: [ + { + id: 'message-1', + role: 'user', + content: 'first chat', + attachments: [], + streamObjects: [], + createdAt: new Date().toISOString(), + }, + ], + }), + }); + + await runtime.dispatch({ type: 'createNewSession' }); + + expect(runtime.getSnapshot().tabs.map(tab => tab.kind)).toEqual([ + 'session', + 'draft', + 'session', + ]); + expect(runtime.getSnapshot().tabs.map(tab => tab.id)).toEqual([ + 'session-1', + expect.stringContaining('draft:'), + 'session-2', + ]); + }); + + test('close active tab falls back to the previous session tab', async () => { + const runtime = createRuntime(); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ sessionId: 'session-1', title: 'One' }), + }); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ sessionId: 'session-2', title: 'Two' }), + }); + + await runtime.dispatch({ type: 'closeTab', tabId: 'session-2' }); + + expect(runtime.getSnapshot().activeSessionId).toBe('session-1'); + }); + + test('close active tab reloads fallback session messages', async () => { + const fallbackSession = session({ + sessionId: 'session-1', + title: 'One', + messages: [ + { + id: 'message-1', + role: 'user', + content: 'old chat', + attachments: [], + streamObjects: [], + createdAt: new Date().toISOString(), + }, + ], + }); + const request = createRequest({ + getSession: vi.fn().mockResolvedValue(fallbackSession), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ sessionId: 'session-1', title: 'One', messages: [] }), + }); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ sessionId: 'session-2', title: 'Two', messages: [] }), + }); + + await runtime.dispatch({ type: 'closeTab', tabId: 'session-2' }); + + expect(request.getSession).toHaveBeenCalledWith('workspace-1', 'session-1'); + expect(runtime.getSnapshot().activeSessionId).toBe('session-1'); + expect(runtime.getSnapshot().messages).toEqual(fallbackSession.messages); + }); + + test('refreshHistory keeps current doc sessions separate from recent sessions', async () => { + const currentDoc = [session({ sessionId: 'doc-session' })]; + const recent = [session({ sessionId: 'recent-session', docId: null })]; + const request = createRequest({ + getSessions: vi.fn().mockResolvedValue(currentDoc), + getRecentSessions: vi.fn().mockResolvedValue(recent), + }); + const runtime = createRuntime(request); + + await runtime.dispatch({ type: 'refreshHistory' }); + + expect(runtime.getSnapshot().history.currentDoc).toEqual(currentDoc); + expect(runtime.getSnapshot().history.recent).toEqual(recent); + }); + + test('other-doc session returns a navigation request instead of opening a tab', async () => { + const runtime = createRuntime(); + + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ sessionId: 'session-2', docId: 'doc-2' }), + }); + + expect(runtime.getSnapshot().navigationRequest).toEqual({ + workspaceId: 'workspace-1', + docId: 'doc-2', + sessionId: 'session-2', + resetTabs: true, + }); + expect(runtime.getSnapshot().activeSessionId).toBeNull(); + }); + + test('stale stream result does not commit after scope switch', async () => { + let release!: () => void; + const delayedStream = { + async *[Symbol.asyncIterator]() { + await new Promise(resolve => { + release = resolve; + }); + yield 'late'; + }, + }; + const request = createRequest({ + executeAction: vi.fn().mockResolvedValue(delayedStream), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + const send = runtime.dispatch({ type: 'send', input: 'hello' }); + await waitUntil(() => { + expect(request.executeAction).toHaveBeenCalled(); + }); + + await runtime.dispatch({ + type: 'setScope', + scope: { kind: 'doc', workspaceId: 'workspace-1', docId: 'doc-2' }, + }); + release(); + await send; + + expect(runtime.getSnapshot().scope).toEqual({ + kind: 'doc', + workspaceId: 'workspace-1', + docId: 'doc-2', + }); + expect(runtime.getSnapshot().messages).toEqual([]); + }); + + test('stale session creation does not open after scope switch', async () => { + let releaseSession!: (value: CopilotChatHistoryFragment) => void; + const request = createRequest({ + createSessionWithHistory: vi.fn().mockReturnValue( + new Promise(resolve => { + releaseSession = resolve; + }) + ), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + + const send = runtime.dispatch({ type: 'send', input: 'hello' }); + await runtime.dispatch({ + type: 'setScope', + scope: { kind: 'doc', workspaceId: 'workspace-1', docId: 'doc-2' }, + }); + releaseSession(session({ sessionId: 'late-session' })); + await send; + + expect(runtime.getSnapshot().activeSessionId).toBeNull(); + expect(runtime.getSnapshot().sessions).toEqual([]); + }); + + test('send failure commits error status without throwing', async () => { + const error = new Error('network failed'); + const request = createRequest({ + executeAction: vi.fn().mockRejectedValue(error), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + + await runtime.dispatch({ type: 'send', input: 'hello' }); + + expect(runtime.getSnapshot().status).toBe('error'); + expect(runtime.getSnapshot().error).toBe(error); + expect(runtime.getSnapshot().messages).toEqual([ + expect.objectContaining({ role: 'user', content: 'hello' }), + expect.objectContaining({ role: 'assistant', content: '' }), + ]); + }); + + test('send remains successful when refreshing the assistant message id fails', async () => { + const error = new Error('history unavailable'); + const consoleError = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + const request = createRequest(); + (request.histories.ids as ReturnType).mockRejectedValue( + error + ); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + + await runtime.dispatch({ type: 'send', input: 'hello' }); + + expect(runtime.getSnapshot().status).toBe('success'); + expect(runtime.getSnapshot().error).toBeNull(); + expect(runtime.getSnapshot().messages.at(-1)).toEqual( + expect.objectContaining({ role: 'assistant', content: 'hello' }) + ); + expect(consoleError).toHaveBeenCalledWith(error); + consoleError.mockRestore(); + }); + + test('stop marks the active assistant response as complete', async () => { + let release!: () => void; + const blockedStream = { + async *[Symbol.asyncIterator]() { + yield 'partial'; + await new Promise(resolve => { + release = resolve; + }); + yield 'late'; + }, + }; + const request = createRequest({ + executeAction: vi.fn().mockResolvedValue(blockedStream), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + + const send = runtime.dispatch({ type: 'send', input: 'hello' }); + await waitUntil(() => { + expect(runtime.getSnapshot().status).toBe('transmitting'); + }); + + await runtime.dispatch({ type: 'stop' }); + release(); + await send; + + expect(runtime.getSnapshot().status).toBe('success'); + expect(runtime.getSnapshot().messages.at(-1)).toEqual( + expect.objectContaining({ role: 'assistant', content: 'partial' }) + ); + }); + + test('clearError resets error status', async () => { + const error = new Error('network failed'); + const runtime = createRuntime( + createRequest({ + executeAction: vi.fn().mockRejectedValue(error), + }) + ); + await runtime.dispatch({ type: 'initialize' }); + await runtime.dispatch({ type: 'send', input: 'hello' }); + + await runtime.dispatch({ type: 'clearError' }); + + expect(runtime.getSnapshot().status).toBe('idle'); + expect(runtime.getSnapshot().error).toBeNull(); + }); + + test('retry failure commits error status and keeps the retried assistant placeholder', async () => { + const error = new Error('retry failed'); + const request = createRequest({ + executeAction: vi.fn().mockRejectedValue(error), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ + messages: [ + { + id: 'user-1', + role: 'user', + content: 'hello', + createdAt: new Date().toISOString(), + attachments: null, + streamObjects: null, + }, + { + id: 'assistant-1', + role: 'assistant', + content: 'old', + createdAt: new Date().toISOString(), + attachments: null, + streamObjects: null, + }, + ], + }), + }); + + await runtime.dispatch({ type: 'retry', messageId: 'assistant-1' }); + + expect(runtime.getSnapshot().status).toBe('error'); + expect(runtime.getSnapshot().error).toBe(error); + expect(runtime.getSnapshot().messages[1]).toEqual( + expect.objectContaining({ role: 'assistant', content: '' }) + ); + }); + + test('stale openSession result does not commit after scope switch', async () => { + let release!: (value: CopilotChatHistoryFragment) => void; + const request = createRequest({ + getSession: vi.fn().mockReturnValue( + new Promise(resolve => { + release = resolve; + }) + ), + }); + const runtime = createRuntime(request); + + const open = runtime.dispatch({ + type: 'openSession', + sessionId: 'session-2', + }); + await runtime.dispatch({ + type: 'setScope', + scope: { kind: 'doc', workspaceId: 'workspace-1', docId: 'doc-2' }, + }); + release(session({ sessionId: 'session-2' })); + await open; + + expect(runtime.getSnapshot().scope).toEqual({ + kind: 'doc', + workspaceId: 'workspace-1', + docId: 'doc-2', + }); + expect(runtime.getSnapshot().activeSessionId).toBeNull(); + }); + + test('retry uses existing session and preserves user messages', async () => { + const request = createRequest({ + executeAction: vi.fn().mockResolvedValue(stream(['retry'])), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ + messages: [ + { + id: 'user-1', + role: 'user', + content: 'hello', + createdAt: new Date().toISOString(), + attachments: null, + streamObjects: null, + }, + { + id: 'assistant-1', + role: 'assistant', + content: 'old', + createdAt: new Date().toISOString(), + attachments: null, + streamObjects: null, + }, + ], + }), + }); + + await runtime.dispatch({ type: 'retry', messageId: 'assistant-1' }); + + expect(runtime.getSnapshot().messages[0].content).toBe('hello'); + expect(runtime.getSnapshot().messages[1].content).toBe('retry'); + expect(request.executeAction).toHaveBeenCalledWith( + 'chat', + expect.objectContaining({ retry: true, sessionId: 'session-1' }) + ); + }); + + test('retry remains successful when refreshing the assistant message id fails', async () => { + const error = new Error('history unavailable'); + const consoleError = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + const request = createRequest({ + executeAction: vi.fn().mockResolvedValue(stream(['retry'])), + }); + (request.histories.ids as ReturnType).mockRejectedValue( + error + ); + const runtime = createRuntime(request); + await runtime.dispatch({ + type: 'openSessionObject', + session: session({ + messages: [ + { + id: 'user-1', + role: 'user', + content: 'hello', + createdAt: new Date().toISOString(), + attachments: null, + streamObjects: null, + }, + { + id: '', + role: 'assistant', + content: 'old', + createdAt: new Date().toISOString(), + attachments: null, + streamObjects: null, + }, + ], + }), + }); + + await runtime.dispatch({ type: 'retry', messageId: '' }); + + expect(runtime.getSnapshot().status).toBe('success'); + expect(runtime.getSnapshot().error).toBeNull(); + expect(runtime.getSnapshot().messages[1]).toEqual( + expect.objectContaining({ role: 'assistant', content: 'retry' }) + ); + expect(consoleError).toHaveBeenCalledWith(error); + consoleError.mockRestore(); + }); + + test('retry reuses failed initial messages when no session was created', async () => { + const error = new Error('create session failed'); + const request = createRequest({ + createSessionWithHistory: vi + .fn() + .mockRejectedValueOnce(error) + .mockResolvedValueOnce(session({ sessionId: 'session-2' })), + executeAction: vi.fn().mockResolvedValue(stream(['retry'])), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + await runtime.dispatch({ type: 'send', input: 'hello' }); + + expect(runtime.getSnapshot().activeSessionId).toBeNull(); + expect(runtime.getSnapshot().status).toBe('error'); + + await runtime.dispatch({ type: 'retry', messageId: '' }); + + expect(runtime.getSnapshot().activeSessionId).toBe('session-2'); + expect(runtime.getSnapshot().status).toBe('success'); + expect(runtime.getSnapshot().messages).toEqual([ + expect.objectContaining({ role: 'user', content: 'hello' }), + expect.objectContaining({ role: 'assistant', content: 'retry' }), + ]); + }); + + test('history refresh does not stale an active stream', async () => { + let release!: () => void; + const delayedStream = { + async *[Symbol.asyncIterator]() { + await new Promise(resolve => { + release = resolve; + }); + yield 'late'; + }, + }; + const request = createRequest({ + executeAction: vi.fn().mockResolvedValue(delayedStream), + getRecentSessions: vi + .fn() + .mockResolvedValue([session({ sessionId: 'recent' })]), + }); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + + const send = runtime.dispatch({ type: 'send', input: 'hello' }); + await waitUntil(() => { + expect(request.executeAction).toHaveBeenCalled(); + }); + await runtime.dispatch({ type: 'refreshHistory' }); + release(); + await send; + + expect(runtime.getSnapshot().messages.at(-1)?.content).toBe('late'); + expect(runtime.getSnapshot().history.recent[0].sessionId).toBe('recent'); + }); + + test('context add remove and poll preserve operation order', async () => { + const request = createRequest(); + const runtime = createRuntime(request); + await runtime.dispatch({ type: 'initialize' }); + + await runtime.dispatch({ + type: 'addContextItem', + item: { kind: 'doc', docId: 'doc-2' }, + }); + await runtime.dispatch({ + type: 'addContextItem', + item: { kind: 'blob', blobId: 'blob-1' }, + }); + await runtime.dispatch({ + type: 'removeContextItem', + item: { kind: 'doc', docId: 'doc-2' }, + }); + ( + request.context.getContextDocsAndFiles as ReturnType + ).mockResolvedValue({ + blobs: [{ blobId: 'blob-1', status: 'finished' }], + }); + await runtime.dispatch({ type: 'pollContext' }); + + expect(request.context.createContext).toHaveBeenCalledTimes(1); + expect(request.context.addContextDoc).toHaveBeenCalledWith({ + contextId: 'context-1', + docId: 'doc-2', + }); + expect(request.context.removeContextDoc).toHaveBeenCalledWith({ + contextId: 'context-1', + docId: 'doc-2', + }); + expect(runtime.getSnapshot().composer.context.items).toEqual([ + { kind: 'blob', blobId: 'blob-1', state: 'finished' }, + ]); + }); + + test('loadContext restores existing session context without creating a new context', async () => { + const request = createRequest(); + ( + request.context.getContextId as ReturnType + ).mockResolvedValue('context-1'); + ( + request.context.getContextDocsAndFiles as ReturnType + ).mockResolvedValue({ + docs: [{ id: 'doc-2', status: 'finished', createdAt: 2 }], + files: [ + { + id: 'file-1', + blobId: 'blob-file-1', + name: 'note.pdf', + status: 'processing', + createdAt: 1, + }, + ], + tags: [ + { + id: 'tag-1', + docs: [{ id: 'tag-doc', status: 'failed' }], + createdAt: 3, + }, + ], + collections: [], + blobs: [], + }); + const runtime = createRuntime(request); + await runtime.dispatch({ + type: 'openSessionObject', + session: session(), + }); + + await runtime.dispatch({ type: 'loadContext' }); + + expect(request.context.createContext).not.toHaveBeenCalled(); + expect(runtime.getSnapshot().composer.context.contextId).toBe('context-1'); + expect(runtime.getSnapshot().composer.context.items).toEqual([ + expect.objectContaining({ + kind: 'file', + fileId: 'file-1', + blobId: 'blob-file-1', + state: 'processing', + }), + { kind: 'doc', docId: 'doc-2', state: 'finished', createdAt: 2 }, + { + kind: 'tag', + tagId: 'tag-1', + docIds: ['tag-doc'], + state: 'finished', + createdAt: 3, + tooltip: undefined, + }, + ]); + expect(runtime.getSnapshot().composer.context.embeddingCount).toEqual({ + finished: 1, + processing: 1, + failed: 1, + }); + }); + + test('pollEmbeddingStatus updates composer embedding completion state', async () => { + const request = createRequest(); + (request.context.pollEmbeddingStatus as ReturnType) + .mockImplementationOnce(async (_workspaceId, onPoll) => { + onPoll({ embedded: 1, total: 2 }); + }) + .mockImplementationOnce(async (_workspaceId, onPoll) => { + onPoll({ embedded: 2, total: 2 }); + }); + const runtime = createRuntime(request); + + await runtime.dispatch({ type: 'pollEmbeddingStatus' }); + expect(runtime.getSnapshot().composer.context.embeddingCompleted).toBe( + false + ); + + await runtime.dispatch({ type: 'pollEmbeddingStatus' }); + expect(runtime.getSnapshot().composer.context.embeddingCompleted).toBe( + true + ); + }); + + test('startContextPolling owns context polling lifecycle', async () => { + const request = createRequest(); + ( + request.context.getContextDocsAndFiles as ReturnType + ).mockResolvedValue({ + docs: [{ docId: 'doc-2', status: 'finished' }], + }); + ( + request.context.getContextId as ReturnType + ).mockResolvedValue('context-1'); + const runtime = createRuntime(request); + + await runtime.dispatch({ + type: 'openSessionObject', + session: session(), + }); + await runtime.dispatch({ type: 'loadContext' }); + await runtime.dispatch({ type: 'startContextPolling' }); + await waitUntil(() => { + expect(request.context.getContextDocsAndFiles).toHaveBeenCalledTimes(2); + }); + + expect(runtime.getSnapshot().composer.context.polling).toBe(false); + expect(runtime.getSnapshot().composer.context.embeddingCount).toEqual({ + finished: 1, + processing: 0, + failed: 0, + }); + }); + + test('fork strategy creates child session from parent without doc tab restrictions', async () => { + const request = createRequest({ + forkChat: vi.fn().mockResolvedValue('fork-session'), + getSession: vi.fn().mockResolvedValue( + session({ + sessionId: 'fork-session', + docId: 'another-doc', + parentSessionId: 'parent-session', + }) + ), + }); + const runtime = new AIChatRuntime({ + request, + scope: { + kind: 'fork', + workspaceId: 'workspace-1', + docId: 'doc-1', + parentSessionId: 'parent-session', + latestMessageId: 'message-1', + }, + strategy: new ForkAIChatSessionStrategy(), + }); + + await runtime.dispatch({ type: 'send', input: 'hello' }); + + expect(request.forkChat).toHaveBeenCalledWith({ + workspaceId: 'workspace-1', + docId: 'doc-1', + sessionId: 'parent-session', + latestMessageId: 'message-1', + }); + expect(runtime.getSnapshot().navigationRequest).toBeNull(); + expect(runtime.getSnapshot().activeSessionId).toBe('fork-session'); + }); + + test('playground strategy creates fork sessions from parent scope', async () => { + const request = createRequest({ + forkChat: vi.fn().mockResolvedValue('playground-fork'), + getSession: vi.fn().mockResolvedValue( + session({ + sessionId: 'playground-fork', + docId: 'doc-1', + parentSessionId: 'root-session', + }) + ), + }); + const runtime = new AIChatRuntime({ + request, + scope: { + kind: 'playground', + workspaceId: 'workspace-1', + docId: 'doc-1', + parentSessionId: 'root-session', + }, + strategy: new PlaygroundAIChatSessionStrategy(), + }); + + const forkSession = await runtime.createSession(); + + expect(request.forkChat).toHaveBeenCalledWith({ + workspaceId: 'workspace-1', + docId: 'doc-1', + sessionId: 'root-session', + }); + expect(forkSession?.sessionId).toBe('playground-fork'); + }); +}); diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/runtime.ts b/packages/frontend/core/src/blocksuite/ai/runtime/chat/runtime.ts new file mode 100644 index 0000000000..c1c3473de4 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/runtime.ts @@ -0,0 +1,1311 @@ +import type { CopilotChatHistoryFragment } from '@affine/graphql'; + +import type { AIRequestService } from '../request'; +import type { AIChatAction, AIChatSendOptions } from './actions'; +import type { AIChatSessionStrategy } from './session-strategy'; +import { + type AIChatContextItem, + type AIChatMessage, + type AIChatScope, + type AIChatSnapshot, + type AIChatStatus, + type AIChatTab, + createDraftTab, + createInitialComposerState, + sessionToTab, +} from './state'; + +type RuntimeOptions = { + request: AIRequestService; + scope: AIChatScope; + strategy: AIChatSessionStrategy; +}; + +type ContextStatus = 'finished' | 'processing' | 'failed'; + +type ContextObject = { + id?: string; + docId?: string; + blobId?: string; + name?: string; + status?: ContextStatus; + error?: string | null; + createdAt?: number | null; + docs?: ContextObject[]; +}; + +type ContextData = { + docs?: ContextObject[]; + files?: ContextObject[]; + tags?: ContextObject[]; + collections?: ContextObject[]; + blobs?: ContextObject[]; +}; + +type EmbeddingStatus = { + embedded: number; + total: number; +}; + +const CONTEXT_POLLING_INTERVAL = 10000; + +export class AIChatRuntime { + private readonly listeners = new Set<() => void>(); + private requestSeq = 0; + private historyRequestSeq = 0; + private contextRequestSeq = 0; + private streamAbortController: AbortController | null = null; + private contextPollingAbortController: AbortController | null = null; + private embeddingStatusAbortController: AbortController | null = null; + private createSessionPromiseKey: string | null = null; + private createSessionPromise: Promise< + CopilotChatHistoryFragment | null | undefined + > | null = null; + private snapshot: AIChatSnapshot; + + constructor(private readonly options: RuntimeOptions) { + this.snapshot = this.createInitialSnapshot(options.scope); + } + + getSnapshot = () => this.snapshot; + + subscribe = (listener: () => void) => { + this.listeners.add(listener); + return () => { + this.listeners.delete(listener); + }; + }; + + async createSession(options: { pinned?: boolean } = {}) { + const session = await this.options.strategy.createSession( + this.snapshot.scope, + this.options.request, + options + ); + if (session) { + this.openSessionObject(session, true); + } + return session ?? undefined; + } + + loadInitialSession() { + return this.options.strategy.loadInitialSession( + this.snapshot.scope, + this.options.request + ); + } + + dispose() { + this.requestSeq++; + this.createSessionPromise = null; + this.createSessionPromiseKey = null; + this.streamAbortController?.abort(); + this.contextPollingAbortController?.abort(); + this.embeddingStatusAbortController?.abort(); + this.listeners.clear(); + } + + async dispatch(action: AIChatAction) { + switch (action.type) { + case 'initialize': + await this.initialize(action.scope ?? this.snapshot.scope); + return; + case 'setScope': + await this.setScope(action.scope); + return; + case 'refreshHistory': + await this.refreshHistory(); + return; + case 'openSession': + await this.openSession(action.sessionId); + return; + case 'openSessionObject': + this.openSessionObject(action.session); + return; + case 'closeTab': + await this.closeTab(action.tabId); + return; + case 'createNewSession': + await this.createNewSession(action.pinned); + return; + case 'togglePinActiveSession': + await this.togglePinActiveSession(); + return; + case 'send': + await this.send(action); + return; + case 'retry': + await this.retry(); + return; + case 'stop': + this.stop(); + return; + case 'deleteSession': + await this.deleteSession(action.sessionId); + return; + case 'clearError': + this.commit({ status: 'idle', error: null }); + return; + case 'setComposerText': + this.updateComposer({ text: action.text }); + return; + case 'setReasoning': + this.updateComposer({ reasoning: action.reasoning }); + return; + case 'setModel': + this.updateComposer({ modelId: action.modelId }); + return; + case 'addAttachment': + this.updateComposer({ + attachments: [ + ...this.snapshot.composer.attachments, + action.attachment, + ], + }); + return; + case 'removeAttachment': + this.updateComposer({ + attachments: this.snapshot.composer.attachments.filter( + (_, index) => index !== action.index + ), + }); + return; + case 'addContextItem': + await this.addContextItem(action.item); + return; + case 'removeContextItem': + await this.removeContextItem(action.item); + return; + case 'loadContext': + await this.loadContext(); + return; + case 'startContextPolling': + this.startContextPolling(); + return; + case 'stopContextPolling': + this.stopContextPolling(); + return; + case 'pollContext': + await this.pollContext(); + return; + case 'pollEmbeddingStatus': + this.pollEmbeddingStatus(); + return; + } + } + + private createInitialSnapshot(scope: AIChatScope): AIChatSnapshot { + const draft = createDraftTab(scope); + return { + scope, + readiness: 'initializing', + activeSessionId: null, + activeTabId: draft.id, + tabs: [draft], + sessions: [], + history: { + currentDoc: [], + recent: [], + loading: false, + error: null, + }, + messages: [], + status: 'idle', + error: null, + composer: createInitialComposerState(), + navigationRequest: null, + uiPolicy: this.createUiPolicy('idle', [draft], draft.id), + }; + } + + private commit(patch: Partial) { + const next = { + ...this.snapshot, + ...patch, + }; + this.snapshot = { + ...next, + uiPolicy: this.createUiPolicy(next.status, next.tabs, next.activeTabId), + }; + this.listeners.forEach(listener => listener()); + } + + private createUiPolicy( + status: AIChatStatus, + tabs: AIChatTab[], + activeTabId: string | null + ): AIChatSnapshot['uiPolicy'] { + const activeTab = tabs.find(tab => tab.id === activeTabId); + const isGenerating = status === 'loading' || status === 'transmitting'; + return { + showDraftTab: activeTab?.kind === 'draft', + canCreateNewSession: !!activeTab?.hasMessages && !isGenerating, + canCloseActiveTab: activeTab?.kind === 'session' && tabs.length > 1, + canPinActiveSession: activeTab?.kind === 'session', + canSend: !isGenerating, + }; + } + + private getScopeKey(scope: AIChatScope) { + switch (scope.kind) { + case 'doc': + return `${scope.kind}:${scope.workspaceId}:${scope.docId}`; + case 'workspace': + return `${scope.kind}:${scope.workspaceId}`; + case 'fork': + return `${scope.kind}:${scope.workspaceId}:${scope.parentSessionId}:${scope.latestMessageId ?? ''}:${scope.docId ?? ''}`; + case 'chat-block': + return `${scope.kind}:${scope.workspaceId}:${scope.docId}:${scope.blockId}:${scope.parentSessionId ?? ''}:${scope.latestMessageId ?? ''}`; + case 'playground': + return `${scope.kind}:${scope.workspaceId}:${scope.docId ?? ''}:${scope.parentSessionId ?? ''}:${scope.latestMessageId ?? ''}`; + } + } + + private markActiveTabHasMessages(tabs: AIChatTab[]) { + return tabs.map(tab => + tab.kind === 'session' && tab.id === this.snapshot.activeTabId + ? { ...tab, hasMessages: true } + : tab + ); + } + + private async initialize(scope: AIChatScope) { + const seq = ++this.requestSeq; + this.commit({ + ...this.createInitialSnapshot(scope), + readiness: 'initializing', + }); + const session = await this.options.strategy.loadInitialSession( + scope, + this.options.request + ); + if (seq !== this.requestSeq) return; + if (!session) { + const draft = this.options.strategy.createDraftSession(scope); + this.commit({ + readiness: 'ready', + tabs: [draft], + sessions: [], + activeTabId: draft.id, + activeSessionId: null, + messages: [], + }); + return; + } + this.openSessionObject(session); + this.commit({ readiness: 'ready' }); + } + + private async setScope(scope: AIChatScope) { + this.stop(); + this.createSessionPromise = null; + this.createSessionPromiseKey = null; + await this.initialize(scope); + } + + private async ensureSession() { + const activeSessionId = this.snapshot.activeSessionId; + if (activeSessionId) { + return ( + this.snapshot.sessions.find( + session => session.sessionId === activeSessionId + ) ?? + this.options.request.getSession( + this.snapshot.scope.workspaceId, + activeSessionId + ) + ); + } + const scopeKey = this.getScopeKey(this.snapshot.scope); + if ( + !this.createSessionPromise || + this.createSessionPromiseKey !== scopeKey + ) { + this.createSessionPromiseKey = scopeKey; + this.createSessionPromise = this.options.strategy + .createSession(this.snapshot.scope, this.options.request) + .finally(() => { + this.createSessionPromise = null; + this.createSessionPromiseKey = null; + }); + } + return this.createSessionPromise; + } + + private resetLastAssistantMessage(messages: AIChatMessage[]) { + return messages.map((message, index) => + index === messages.length - 1 && message.role === 'assistant' + ? { ...message, content: '', createdAt: new Date().toISOString() } + : message + ); + } + + private getLastUserMessage() { + return this.snapshot.messages.findLast(message => message.role === 'user'); + } + + private async send(options: AIChatSendOptions, retryExisting = false) { + const content = options.input || this.snapshot.composer.text; + if (!content.trim() || !this.snapshot.uiPolicy.canSend) return; + const seq = ++this.requestSeq; + this.streamAbortController?.abort(); + this.streamAbortController = new AbortController(); + this.commit({ + status: 'loading', + error: null, + messages: retryExisting + ? this.resetLastAssistantMessage(this.snapshot.messages) + : [ + ...this.snapshot.messages, + this.createMessage('user', content, { + attachments: options.attachmentPreviews, + ...options.userInfo, + }), + this.createMessage('assistant', ''), + ], + }); + try { + const session = await this.ensureSession(); + if (seq !== this.requestSeq) return; + if (!session) { + this.commit({ status: 'error', error: new Error('Session not found') }); + return; + } + if (!this.snapshot.activeSessionId) { + this.openSessionObject(session, true); + } + + const stream = (await this.options.request.executeAction('chat', { + workspaceId: this.snapshot.scope.workspaceId, + docId: + 'docId' in this.snapshot.scope + ? this.snapshot.scope.docId + : undefined, + sessionId: session.sessionId, + input: content, + contexts: options.contexts, + attachments: options.attachments ?? this.snapshot.composer.attachments, + contextId: this.snapshot.composer.context.contextId, + reasoning: options.reasoning ?? this.snapshot.composer.reasoning, + toolsConfig: options.toolsConfig ?? this.snapshot.composer.toolsConfig, + modelId: options.modelId ?? this.snapshot.composer.modelId, + isRootSession: options.isRootSession, + where: options.where, + control: options.control, + stream: true, + signal: this.streamAbortController.signal, + })) as AsyncIterable; + + for await (const chunk of stream) { + if (seq !== this.requestSeq) return; + this.appendAssistantContent(chunk); + this.commit({ status: 'transmitting' }); + } + if (seq !== this.requestSeq) return; + this.commit({ + status: 'success', + tabs: this.markActiveTabHasMessages(this.snapshot.tabs), + composer: { + ...this.snapshot.composer, + text: '', + attachments: [], + }, + }); + await this.refreshLastMessageId(session.sessionId).catch(console.error); + await this.bindActiveSessionToDoc().catch(console.error); + } catch (error) { + if (seq !== this.requestSeq) return; + this.commit({ status: 'error', error: this.toError(error) }); + } + } + + private async retry() { + if (!this.snapshot.uiPolicy.canSend) { + return; + } + if (!this.snapshot.activeSessionId) { + const lastUserMessage = this.getLastUserMessage(); + if (lastUserMessage) { + await this.send({ input: lastUserMessage.content }, true); + } + return; + } + const seq = ++this.requestSeq; + this.streamAbortController?.abort(); + this.streamAbortController = new AbortController(); + this.commit({ + status: 'loading', + error: null, + messages: this.resetLastAssistantMessage(this.snapshot.messages), + }); + try { + const stream = (await this.options.request.executeAction('chat', { + workspaceId: this.snapshot.scope.workspaceId, + sessionId: this.snapshot.activeSessionId, + retry: true, + stream: true, + signal: this.streamAbortController.signal, + })) as AsyncIterable; + for await (const chunk of stream) { + if (seq !== this.requestSeq) return; + this.appendAssistantContent(chunk); + this.commit({ status: 'transmitting' }); + } + if (seq === this.requestSeq) { + this.commit({ status: 'success' }); + await this.refreshLastMessageId(this.snapshot.activeSessionId).catch( + console.error + ); + await this.bindActiveSessionToDoc().catch(console.error); + } + } catch (error) { + if (seq !== this.requestSeq) return; + this.commit({ status: 'error', error: this.toError(error) }); + } + } + + private stop() { + this.requestSeq++; + this.streamAbortController?.abort(); + this.streamAbortController = null; + if ( + this.snapshot.status === 'loading' || + this.snapshot.status === 'transmitting' + ) { + this.commit({ status: 'success' }); + } + } + + private async refreshHistory() { + const seq = ++this.historyRequestSeq; + this.commit({ + history: { ...this.snapshot.history, loading: true, error: null }, + }); + try { + const currentDoc = + this.snapshot.scope.kind === 'doc' + ? ((await this.options.request.getSessions( + this.snapshot.scope.workspaceId, + this.snapshot.scope.docId, + { action: false, fork: false } + )) ?? []) + : []; + const recent = + (await this.options.request.getRecentSessions( + this.snapshot.scope.workspaceId + )) ?? []; + if (seq !== this.historyRequestSeq) return; + this.commit({ + history: { + currentDoc, + recent, + loading: false, + error: null, + }, + }); + } catch (error) { + if (seq !== this.historyRequestSeq) return; + this.commit({ + history: { + ...this.snapshot.history, + loading: false, + error: error instanceof Error ? error : new Error(String(error)), + }, + }); + } + } + + private updateComposer(patch: Partial) { + this.commit({ + composer: { + ...this.snapshot.composer, + ...patch, + }, + }); + } + + private updateContextState( + patch: Partial + ) { + this.updateComposer({ + context: { + ...this.snapshot.composer.context, + ...patch, + }, + }); + } + + private async getContextId() { + const createdSession = this.snapshot.activeSessionId + ? null + : await this.ensureSession(); + if (createdSession) { + this.openSessionObject(createdSession, true); + } + const sessionId = + this.snapshot.activeSessionId ?? createdSession?.sessionId ?? null; + if (!sessionId) return null; + + const cached = this.snapshot.composer.context.contextId; + if (cached) return cached; + + const { workspaceId } = this.snapshot.scope; + const existing = await this.options.request.context.getContextId( + workspaceId, + sessionId + ); + const contextId = + existing ?? + (await this.options.request.context.createContext( + workspaceId, + sessionId + )); + this.updateContextState({ contextId }); + return contextId; + } + + private async addContextItem(item: AIChatContextItem) { + const seq = ++this.contextRequestSeq; + this.updateContextState({ loading: true, error: null }); + try { + const contextId = await this.getContextId(); + if (!contextId) throw new Error('Context not found'); + + const nextItem = await this.persistContextItem(contextId, item); + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ + loading: false, + items: [...this.snapshot.composer.context.items, nextItem], + }); + } catch (error) { + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ loading: false, error: this.toError(error) }); + } + } + + private async removeContextItem(item: AIChatContextItem) { + const seq = ++this.contextRequestSeq; + this.updateContextState({ loading: true, error: null }); + try { + const contextId = this.snapshot.composer.context.contextId; + if (contextId) { + await this.deleteContextItem(contextId, item); + } + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ + loading: false, + items: this.snapshot.composer.context.items.filter( + existing => + this.getContextItemKey(existing) !== this.getContextItemKey(item) + ), + }); + } catch (error) { + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ loading: false, error: this.toError(error) }); + } + } + + private async pollContext() { + const seq = ++this.contextRequestSeq; + const sessionId = this.snapshot.activeSessionId; + const contextId = this.snapshot.composer.context.contextId; + if (!sessionId || !contextId) return; + + this.updateContextState({ polling: true, error: null }); + try { + const context = await this.options.request.context.getContextDocsAndFiles( + this.snapshot.scope.workspaceId, + sessionId, + contextId + ); + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ + polling: false, + items: this.mergePolledContextItems(context), + embeddingCount: this.getContextEmbeddingCount(context), + }); + } catch (error) { + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ polling: false, error: this.toError(error) }); + } + } + + private startContextPolling() { + this.stopContextPolling(); + this.contextPollingAbortController = new AbortController(); + const signal = this.contextPollingAbortController.signal; + void this.pollContextUntilIdle(signal).catch(error => { + if (signal.aborted) return; + this.updateContextState({ polling: false, error: this.toError(error) }); + }); + } + + private stopContextPolling() { + this.contextPollingAbortController?.abort(); + this.contextPollingAbortController = null; + } + + private async pollContextUntilIdle(signal: AbortSignal) { + while (!signal.aborted) { + await this.pollContext(); + if (signal.aborted) return; + if (this.snapshot.composer.context.embeddingCount.processing === 0) { + this.stopContextPolling(); + return; + } + await this.waitForContextPollingInterval(signal); + } + } + + private waitForContextPollingInterval(signal: AbortSignal) { + return new Promise(resolve => { + const timeout = setTimeout(resolve, CONTEXT_POLLING_INTERVAL); + signal.addEventListener( + 'abort', + () => { + clearTimeout(timeout); + resolve(); + }, + { once: true } + ); + }); + } + + private async loadContext() { + const seq = ++this.contextRequestSeq; + const sessionId = this.snapshot.activeSessionId; + if (!sessionId) return; + + this.updateContextState({ loading: true, error: null }); + try { + const { workspaceId } = this.snapshot.scope; + const contextId = await this.options.request.context.getContextId( + workspaceId, + sessionId + ); + if (!contextId) { + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ + contextId: null, + items: [], + loading: false, + embeddingCount: { finished: 0, processing: 0, failed: 0 }, + }); + return; + } + const context = await this.options.request.context.getContextDocsAndFiles( + workspaceId, + sessionId, + contextId + ); + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ + contextId, + loading: false, + items: this.contextDataToItems(context), + embeddingCount: this.getContextEmbeddingCount(context), + }); + } catch (error) { + if (seq !== this.contextRequestSeq) return; + this.updateContextState({ loading: false, error: this.toError(error) }); + } + } + + private pollEmbeddingStatus() { + this.embeddingStatusAbortController?.abort(); + this.embeddingStatusAbortController = new AbortController(); + const signal = this.embeddingStatusAbortController.signal; + void this.options.request.context + .pollEmbeddingStatus( + this.snapshot.scope.workspaceId, + status => { + if (signal.aborted) return; + this.updateContextState({ + embeddingCompleted: this.isEmbeddingCompleted(status), + }); + }, + signal + ) + .catch(error => { + if (signal.aborted) return; + this.updateContextState({ + embeddingCompleted: false, + error: this.toError(error), + }); + }); + } + + private async persistContextItem( + contextId: string, + item: AIChatContextItem + ): Promise { + switch (item.kind) { + case 'doc': + await this.options.request.context.addContextDoc({ + contextId, + docId: item.docId, + }); + return item; + case 'file': { + const file = await this.options.request.context.addContextFile( + item.file, + { contextId } + ); + return { + ...item, + fileId: file.id, + blobId: file.blobId ?? item.blobId, + state: file.status, + createdAt: file.createdAt, + tooltip: file.error ?? undefined, + }; + } + case 'tag': + await this.options.request.context.addContextTag({ + contextId, + tagId: item.tagId, + docIds: item.docIds, + }); + return item; + case 'collection': + await this.options.request.context.addContextCollection({ + contextId, + collectionId: item.collectionId, + docIds: item.docIds, + }); + return item; + case 'blob': { + const blob = await this.options.request.context.addContextBlob({ + contextId, + blobId: item.blobId, + }); + return { + ...item, + state: blob.status || item.state, + createdAt: blob.createdAt, + }; + } + } + } + + private deleteContextItem(contextId: string, item: AIChatContextItem) { + switch (item.kind) { + case 'doc': + return this.options.request.context.removeContextDoc({ + contextId, + docId: item.docId, + }); + case 'file': + if (!item.fileId) return Promise.resolve(); + return this.options.request.context.removeContextFile({ + contextId, + fileId: item.fileId, + }); + case 'tag': + return this.options.request.context.removeContextTag({ + contextId, + tagId: item.tagId, + }); + case 'collection': + return this.options.request.context.removeContextCollection({ + contextId, + collectionId: item.collectionId, + }); + case 'blob': + return this.options.request.context.removeContextBlob({ + contextId, + blobId: item.blobId, + }); + } + } + + private mergePolledContextItems(context: unknown) { + if (!context || typeof context !== 'object') { + return this.snapshot.composer.context.items; + } + const data = context as ContextData; + const docs = [ + ...(data.docs ?? []), + ...(data.tags ?? []).flatMap(tag => tag.docs ?? []), + ...(data.collections ?? []).flatMap(collection => collection.docs ?? []), + ]; + + return this.snapshot.composer.context.items.map(item => { + if (item.kind === 'doc') { + const doc = docs.find( + candidate => + candidate.docId === item.docId || candidate.id === item.docId + ); + return doc?.status + ? { ...item, state: doc.status, tooltip: doc.error ?? undefined } + : item; + } + if (item.kind === 'file') { + const file = data.files?.find( + candidate => + candidate.id === item.fileId || + candidate.blobId === item.blobId || + candidate.blobId === item.fileId + ); + return file?.status + ? { ...item, state: file.status, tooltip: file.error ?? undefined } + : item; + } + if (item.kind === 'blob') { + const blob = data.blobs?.find( + candidate => + candidate.blobId === item.blobId || candidate.id === item.blobId + ); + return blob?.status + ? { ...item, state: blob.status, tooltip: blob.error ?? undefined } + : item; + } + return item; + }); + } + + private contextDataToItems(context: unknown): AIChatContextItem[] { + if (!context || typeof context !== 'object') return []; + const data = context as ContextData; + const items: AIChatContextItem[] = [ + ...(data.docs ?? []).flatMap(doc => + doc.id + ? [ + { + kind: 'doc' as const, + docId: doc.id, + state: doc.status, + createdAt: doc.createdAt ?? undefined, + tooltip: doc.error ?? undefined, + }, + ] + : [] + ), + ...(data.files ?? []).flatMap(file => + file.id && file.name + ? [ + { + kind: 'file' as const, + file: new File([], file.name), + fileId: file.id, + blobId: file.blobId, + state: file.status, + createdAt: file.createdAt ?? undefined, + tooltip: file.error ?? undefined, + }, + ] + : [] + ), + ...(data.tags ?? []).flatMap(tag => + tag.id + ? [ + { + kind: 'tag' as const, + tagId: tag.id, + docIds: (tag.docs ?? []).flatMap(doc => + doc.id ? [doc.id] : [] + ), + state: 'finished', + createdAt: tag.createdAt ?? undefined, + tooltip: tag.error ?? undefined, + }, + ] + : [] + ), + ...(data.collections ?? []).flatMap(collection => + collection.id + ? [ + { + kind: 'collection' as const, + collectionId: collection.id, + docIds: (collection.docs ?? []).flatMap(doc => + doc.id ? [doc.id] : [] + ), + state: 'finished', + createdAt: collection.createdAt ?? undefined, + tooltip: collection.error ?? undefined, + }, + ] + : [] + ), + ...(data.blobs ?? []).flatMap(blob => + (blob.blobId ?? blob.id) + ? [ + { + kind: 'blob' as const, + blobId: blob.blobId ?? blob.id ?? '', + state: blob.status, + createdAt: blob.createdAt ?? undefined, + tooltip: blob.error ?? undefined, + }, + ] + : [] + ), + ]; + return items.sort((a, b) => (a.createdAt ?? 0) - (b.createdAt ?? 0)); + } + + private getContextEmbeddingCount( + context: unknown + ): AIChatSnapshot['composer']['context']['embeddingCount'] { + const count = { finished: 0, processing: 0, failed: 0 }; + if (!context || typeof context !== 'object') return count; + const data = context as ContextData; + const docs = [ + ...(data.docs ?? []), + ...(data.tags ?? []).flatMap(tag => tag.docs ?? []), + ...(data.collections ?? []).flatMap(collection => collection.docs ?? []), + ]; + for (const item of [ + ...docs, + ...(data.files ?? []), + ...(data.blobs ?? []), + ]) { + if (item.status) count[item.status]++; + } + return count; + } + + private isEmbeddingCompleted(status: unknown) { + if (!status || typeof status !== 'object') return false; + const { embedded, total } = status as EmbeddingStatus; + return embedded === total; + } + + private getContextItemKey(item: AIChatContextItem) { + switch (item.kind) { + case 'doc': + return `doc:${item.docId}`; + case 'file': + return `file:${item.fileId ?? item.file.name}`; + case 'tag': + return `tag:${item.tagId}`; + case 'collection': + return `collection:${item.collectionId}`; + case 'blob': + return `blob:${item.blobId}`; + } + } + + private async openSession(sessionId: string) { + const seq = ++this.requestSeq; + this.streamAbortController?.abort(); + const session = await this.options.request.getSession( + this.snapshot.scope.workspaceId, + sessionId + ); + if (seq !== this.requestSeq) return; + if (session) { + this.openSessionObject(session); + } else { + this.commit({ + tabs: this.snapshot.tabs.filter(tab => tab.id !== sessionId), + sessions: this.snapshot.sessions.filter( + session => session.sessionId !== sessionId + ), + messages: + this.snapshot.activeSessionId === sessionId + ? [] + : this.snapshot.messages, + }); + } + } + + private openSessionObject( + session: CopilotChatHistoryFragment, + preserveMessages = false + ) { + const result = this.options.strategy.openSession( + session, + this.snapshot.scope + ); + if (result.type === 'navigate') { + this.commit({ + navigationRequest: { + ...result.target, + resetTabs: true, + }, + tabs: [], + sessions: [], + activeSessionId: null, + activeTabId: null, + }); + return; + } + const tab = sessionToTab(result.session); + const existing = this.snapshot.tabs.findIndex(item => item.id === tab.id); + const tabs = + existing === -1 + ? [...this.snapshot.tabs.filter(item => item.kind !== 'draft'), tab] + : this.snapshot.tabs.map(item => (item.id === tab.id ? tab : item)); + const sessionExisting = this.snapshot.sessions.findIndex( + item => item.sessionId === result.session.sessionId + ); + const sessions = + sessionExisting === -1 + ? [...this.snapshot.sessions, result.session] + : this.snapshot.sessions.map(item => + item.sessionId === result.session.sessionId ? result.session : item + ); + this.commit({ + tabs, + sessions, + activeTabId: tab.id, + activeSessionId: result.session.sessionId, + navigationRequest: null, + messages: preserveMessages + ? this.snapshot.messages + : ((result.session.messages ?? []) as AIChatMessage[]).slice(), + }); + } + + private async closeTab(tabId: string) { + const tabs = this.snapshot.tabs.filter(tab => tab.id !== tabId); + const sessions = this.snapshot.sessions.filter( + session => session.sessionId !== tabId + ); + if (this.snapshot.activeTabId !== tabId) { + this.commit({ tabs, sessions }); + return; + } + const fallback = + tabs.at(-1) ?? + this.options.strategy.createDraftSession(this.snapshot.scope); + const seq = ++this.requestSeq; + if (fallback.kind === 'session') { + const session = await this.options.request.getSession( + this.snapshot.scope.workspaceId, + fallback.sessionId + ); + if (seq !== this.requestSeq) return; + if (session) { + this.commit({ tabs: tabs.length ? tabs : [fallback], sessions }); + this.openSessionObject(session); + return; + } + } + this.commit({ + tabs: tabs.length ? tabs : [fallback], + sessions, + activeTabId: fallback.id, + activeSessionId: fallback.kind === 'session' ? fallback.sessionId : null, + messages: [], + }); + } + + private async createNewSession(pinned?: boolean) { + if (!this.snapshot.uiPolicy.canCreateNewSession) return; + const seq = ++this.requestSeq; + const draft = this.options.strategy.createDraftSession(this.snapshot.scope); + const activeTabIndex = this.snapshot.tabs.findIndex( + tab => tab.id === this.snapshot.activeTabId + ); + const insertIndex = + activeTabIndex === -1 ? this.snapshot.tabs.length : activeTabIndex + 1; + const tabs = [ + ...this.snapshot.tabs.slice(0, insertIndex), + draft, + ...this.snapshot.tabs.slice(insertIndex), + ]; + this.commit({ + tabs, + sessions: this.snapshot.sessions, + activeTabId: draft.id, + activeSessionId: null, + messages: [], + }); + if (pinned) { + const session = await this.options.strategy.createSession( + this.snapshot.scope, + this.options.request, + { pinned } + ); + if (seq !== this.requestSeq) return; + if (session) this.openSessionObject(session); + } + } + + private async togglePinActiveSession() { + const activeSessionId = this.snapshot.activeSessionId; + if (!activeSessionId) return; + const active = this.snapshot.tabs.find( + tab => tab.kind === 'session' && tab.sessionId === activeSessionId + ); + if (!active || active.kind !== 'session') return; + const nextPinned = !active.pinned; + await this.options.request.updateSession({ + sessionId: activeSessionId, + pinned: nextPinned, + }); + this.commit({ + tabs: this.snapshot.tabs.map(tab => + tab.kind === 'session' && tab.sessionId === activeSessionId + ? { ...tab, pinned: nextPinned } + : tab + ), + sessions: this.snapshot.sessions.map(session => + session.sessionId === activeSessionId + ? { ...session, pinned: nextPinned } + : session + ), + }); + } + + private async deleteSession(sessionId: string) { + await this.options.request.cleanupSessions({ + workspaceId: this.snapshot.scope.workspaceId, + docId: + 'docId' in this.snapshot.scope ? this.snapshot.scope.docId : undefined, + sessionIds: [sessionId], + }); + await this.closeTab(sessionId); + } + + private async bindActiveSessionToDoc() { + if (this.snapshot.scope.kind !== 'doc' || !this.snapshot.activeSessionId) { + return; + } + const active = this.snapshot.sessions.find( + session => session.sessionId === this.snapshot.activeSessionId + ); + if (!active || active.docId === this.snapshot.scope.docId) { + return; + } + await this.options.request.updateSession({ + sessionId: active.sessionId, + docId: this.snapshot.scope.docId, + }); + const session = await this.options.request.getSession( + this.snapshot.scope.workspaceId, + active.sessionId + ); + if (session) { + this.openSessionObject(session, true); + } + } + + private createMessage( + role: AIChatMessage['role'], + content: string, + options: Partial = {} + ): AIChatMessage { + return { + id: '', + role, + content, + createdAt: new Date().toISOString(), + ...options, + }; + } + + private appendAssistantContent(content: string) { + const messages = this.snapshot.messages.slice(); + const last = messages.at(-1); + if (last?.role === 'assistant') { + const streamObject = this.parseStreamObject(content); + messages[messages.length - 1] = { + ...last, + ...(streamObject + ? { + streamObjects: this.mergeStreamObjects([ + ...(last.streamObjects ?? []), + streamObject, + ]), + } + : { content: last.content + content }), + }; + } + this.commit({ messages }); + } + + private parseStreamObject(content: string): unknown | null { + try { + const parsed = JSON.parse(content) as unknown; + if ( + parsed && + typeof parsed === 'object' && + 'type' in parsed && + typeof (parsed as { type?: unknown }).type === 'string' + ) { + return parsed; + } + } catch { + return null; + } + return null; + } + + private mergeStreamObjects(chunks: unknown[]): unknown[] { + return chunks.reduce((acc, curr) => { + if (!curr || typeof curr !== 'object' || !('type' in curr)) { + return acc; + } + const current = curr as { + type: string; + textDelta?: string; + toolCallId?: string; + toolName?: string; + }; + const previous = acc.at(-1) as + | { + type?: string; + textDelta?: string; + } + | undefined; + if ( + (current.type === 'reasoning' || current.type === 'text-delta') && + previous?.type === current.type + ) { + acc[acc.length - 1] = { + ...previous, + textDelta: `${previous.textDelta ?? ''}${current.textDelta ?? ''}`, + }; + return acc; + } + if (current.type === 'tool-result') { + const index = acc.findIndex(item => { + if (!item || typeof item !== 'object') return false; + const existing = item as { + type?: string; + toolCallId?: string; + toolName?: string; + }; + return ( + existing.type === 'tool-call' && + existing.toolCallId === current.toolCallId && + existing.toolName === current.toolName + ); + }); + if (index !== -1) { + acc[index] = curr; + return acc; + } + } + acc.push(curr); + return acc; + }, []); + } + + private async refreshLastMessageId(sessionId: string | null) { + if (!sessionId) return; + const last = this.snapshot.messages.at(-1); + if (!last || last.id) return; + const historyIds = await this.options.request.histories.ids( + this.snapshot.scope.workspaceId, + 'docId' in this.snapshot.scope ? this.snapshot.scope.docId : undefined, + { sessionId, withMessages: true } + ); + const lastId = historyIds?.[0]?.messages?.at(-1)?.id; + if (!lastId) return; + const messages = this.snapshot.messages.slice(); + messages[messages.length - 1] = { + ...last, + id: lastId, + }; + this.commit({ messages }); + } + + private toError(error: unknown) { + return error instanceof Error ? error : new Error(String(error)); + } +} diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/session-strategy.ts b/packages/frontend/core/src/blocksuite/ai/runtime/chat/session-strategy.ts new file mode 100644 index 0000000000..ba33edc9f7 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/session-strategy.ts @@ -0,0 +1,225 @@ +import type { CopilotChatHistoryFragment } from '@affine/graphql'; + +import type { AIRequestService } from '../request'; +import { type AIChatScope, type AIChatTab, createDraftTab } from './state'; + +export type OpenSessionResult = + | { + type: 'opened'; + session: CopilotChatHistoryFragment; + } + | { + type: 'navigate'; + target: { + workspaceId: string; + docId: string; + sessionId: string; + }; + resetTabs: true; + }; + +export interface AIChatSessionStrategy { + loadInitialSession( + scope: AIChatScope, + request: AIRequestService + ): Promise; + createDraftSession(scope: AIChatScope): AIChatTab; + createSession( + scope: AIChatScope, + request: AIRequestService, + options?: { pinned?: boolean } + ): Promise; + canOpenAsTab( + session: CopilotChatHistoryFragment, + scope: AIChatScope + ): boolean; + openSession( + session: CopilotChatHistoryFragment, + scope: AIChatScope + ): OpenSessionResult; +} + +export class DocAIChatSessionStrategy implements AIChatSessionStrategy { + async loadInitialSession(scope: AIChatScope, request: AIRequestService) { + if (scope.kind !== 'doc') return null; + const pinned = await request.getSessions(scope.workspaceId, undefined, { + pinned: true, + limit: 1, + }); + if (Array.isArray(pinned) && pinned[0]) { + return ( + (await request.getSession(scope.workspaceId, pinned[0].sessionId)) ?? + pinned[0] + ); + } + if (scope.pendingSessionId) { + return ( + (await request.getSession(scope.workspaceId, scope.pendingSessionId)) ?? + null + ); + } + + const docSessions = await request.getSessions( + scope.workspaceId, + scope.docId, + { + action: false, + fork: false, + limit: 1, + } + ); + const session = docSessions?.[0]; + if (!session) return null; + return ( + (await request.getSession(scope.workspaceId, session.sessionId)) ?? + session + ); + } + + createDraftSession(scope: AIChatScope) { + return createDraftTab(scope); + } + + createSession( + scope: AIChatScope, + request: AIRequestService, + options: { pinned?: boolean } = {} + ) { + if (scope.kind !== 'doc') return Promise.resolve(null); + return request.createSessionWithHistory({ + workspaceId: scope.workspaceId, + docId: scope.docId, + promptName: 'Chat With AFFiNE AI', + reuseLatestChat: false, + pinned: options.pinned, + }); + } + + canOpenAsTab(session: CopilotChatHistoryFragment, scope: AIChatScope) { + return ( + scope.kind === 'doc' && (!session.docId || session.docId === scope.docId) + ); + } + + openSession(session: CopilotChatHistoryFragment, scope: AIChatScope) { + if (this.canOpenAsTab(session, scope)) { + return { type: 'opened' as const, session }; + } + if (scope.kind === 'doc' && session.docId) { + return { + type: 'navigate' as const, + target: { + workspaceId: session.workspaceId, + docId: session.docId, + sessionId: session.sessionId, + }, + resetTabs: true as const, + }; + } + return { type: 'opened' as const, session }; + } +} + +export class WorkspaceAIChatSessionStrategy implements AIChatSessionStrategy { + async loadInitialSession(scope: AIChatScope, request: AIRequestService) { + const sessions = await request.getSessions(scope.workspaceId, undefined, { + pinned: true, + limit: 1, + }); + const session = sessions?.[0]; + if (!session) return null; + return ( + (await request.getSession(scope.workspaceId, session.sessionId)) ?? + session + ); + } + + createDraftSession(scope: AIChatScope) { + return createDraftTab(scope); + } + + createSession( + scope: AIChatScope, + request: AIRequestService, + options: { pinned?: boolean } = {} + ) { + return request.createSessionWithHistory({ + workspaceId: scope.workspaceId, + promptName: 'Chat With AFFiNE AI', + reuseLatestChat: false, + pinned: options.pinned, + }); + } + + canOpenAsTab(session: CopilotChatHistoryFragment, scope: AIChatScope) { + return session.workspaceId === scope.workspaceId && !session.docId; + } + + openSession(session: CopilotChatHistoryFragment) { + return { type: 'opened' as const, session }; + } +} + +export class ForkAIChatSessionStrategy implements AIChatSessionStrategy { + async loadInitialSession(scope: AIChatScope, request: AIRequestService) { + if ( + scope.kind !== 'fork' && + scope.kind !== 'chat-block' && + scope.kind !== 'playground' + ) { + return null; + } + const parentSessionId = + 'parentSessionId' in scope ? scope.parentSessionId : undefined; + if (!parentSessionId) return null; + return ( + (await request.getSession(scope.workspaceId, parentSessionId)) ?? null + ); + } + + createDraftSession(scope: AIChatScope) { + return createDraftTab(scope); + } + + async createSession( + scope: AIChatScope, + request: AIRequestService, + options: { pinned?: boolean } = {} + ) { + const docId = 'docId' in scope ? scope.docId : undefined; + const parentSessionId = + 'parentSessionId' in scope ? scope.parentSessionId : undefined; + if (!parentSessionId) { + return request.createSessionWithHistory({ + workspaceId: scope.workspaceId, + docId, + promptName: 'Chat With AFFiNE AI', + reuseLatestChat: false, + pinned: options.pinned, + }); + } + + const latestMessageId = + 'latestMessageId' in scope ? scope.latestMessageId : undefined; + const forkSessionId = await request.forkChat({ + workspaceId: scope.workspaceId, + docId: docId ?? '', + sessionId: parentSessionId, + ...(latestMessageId ? { latestMessageId } : {}), + }); + if (!forkSessionId) return null; + return request.getSession(scope.workspaceId, forkSessionId); + } + + canOpenAsTab(session: CopilotChatHistoryFragment, scope: AIChatScope) { + return session.workspaceId === scope.workspaceId; + } + + openSession(session: CopilotChatHistoryFragment) { + return { type: 'opened' as const, session }; + } +} + +export class ChatBlockAIChatSessionStrategy extends ForkAIChatSessionStrategy {} + +export class PlaygroundAIChatSessionStrategy extends ForkAIChatSessionStrategy {} diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/state.ts b/packages/frontend/core/src/blocksuite/ai/runtime/chat/state.ts new file mode 100644 index 0000000000..0840e427a9 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/state.ts @@ -0,0 +1,212 @@ +import type { AIToolsConfig } from '@affine/core/modules/ai-button'; +import type { CopilotChatHistoryFragment } from '@affine/graphql'; + +export type AIChatScope = + | { + kind: 'doc'; + workspaceId: string; + docId: string; + pendingSessionId?: string; + } + | { + kind: 'workspace'; + workspaceId: string; + } + | { + kind: 'fork'; + workspaceId: string; + parentSessionId: string; + latestMessageId?: string; + docId?: string; + } + | { + kind: 'chat-block'; + workspaceId: string; + docId: string; + blockId: string; + parentSessionId?: string; + latestMessageId?: string; + } + | { + kind: 'playground'; + workspaceId: string; + docId?: string; + parentSessionId?: string; + latestMessageId?: string; + }; + +export type AIChatMessage = { + id: string; + role: 'user' | 'assistant'; + content: string; + createdAt: string; + streamObjects?: unknown[]; + attachments?: string[]; + userId?: string; + userName?: string; + avatarUrl?: string; +}; + +export type AIChatTab = + | { + kind: 'draft'; + id: string; + title: string; + scope: AIChatScope; + hasMessages: false; + } + | { + kind: 'session'; + id: string; + sessionId: string; + title: string; + docId: string | null; + pinned: boolean; + hasMessages: boolean; + }; + +export type AIChatStatus = + | 'idle' + | 'loading' + | 'transmitting' + | 'success' + | 'error'; + +export type AIChatHistoryGroups = { + currentDoc: CopilotChatHistoryFragment[]; + recent: CopilotChatHistoryFragment[]; + loading: boolean; + error: Error | null; +}; + +export type AIChatContextItem = + | { + kind: 'doc'; + docId: string; + state?: string; + createdAt?: number; + tooltip?: string; + } + | { + kind: 'file'; + file: File; + blobId?: string; + fileId?: string; + state?: string; + createdAt?: number; + tooltip?: string; + } + | { + kind: 'tag'; + tagId: string; + docIds: string[]; + state?: string; + createdAt?: number; + tooltip?: string; + } + | { + kind: 'collection'; + collectionId: string; + docIds: string[]; + state?: string; + createdAt?: number; + tooltip?: string; + } + | { + kind: 'blob'; + blobId: string; + state?: string; + createdAt?: number; + tooltip?: string; + }; + +export type AIChatContextState = { + contextId: string | null; + items: AIChatContextItem[]; + loading: boolean; + polling: boolean; + error: Error | null; + embeddingCompleted: boolean; + embeddingCount: Record<'finished' | 'processing' | 'failed', number>; +}; + +export type AIChatComposerState = { + text: string; + attachments: (string | Blob | File)[]; + context: AIChatContextState; + reasoning: boolean; + toolsConfig?: AIToolsConfig; + modelId?: string; +}; + +export type AIChatNavigationRequest = { + workspaceId: string; + docId: string; + sessionId: string; + resetTabs: true; +}; + +export type AIChatSnapshot = { + scope: AIChatScope; + readiness: 'initializing' | 'ready' | 'unavailable'; + activeSessionId: string | null; + activeTabId: string | null; + tabs: AIChatTab[]; + sessions: CopilotChatHistoryFragment[]; + history: AIChatHistoryGroups; + messages: AIChatMessage[]; + status: AIChatStatus; + error: Error | null; + composer: AIChatComposerState; + navigationRequest: AIChatNavigationRequest | null; + uiPolicy: { + showDraftTab: boolean; + canCreateNewSession: boolean; + canCloseActiveTab: boolean; + canPinActiveSession: boolean; + canSend: boolean; + }; +}; + +export function createInitialComposerState(): AIChatComposerState { + return { + text: '', + attachments: [], + context: { + contextId: null, + items: [], + loading: false, + polling: false, + error: null, + embeddingCompleted: false, + embeddingCount: { + finished: 0, + processing: 0, + failed: 0, + }, + }, + reasoning: false, + }; +} + +export function sessionToTab(session: CopilotChatHistoryFragment): AIChatTab { + return { + kind: 'session', + id: session.sessionId, + sessionId: session.sessionId, + title: session.title || 'New chat', + docId: session.docId ?? null, + pinned: !!session.pinned, + hasMessages: !!session.messages?.length, + }; +} + +export function createDraftTab(scope: AIChatScope): AIChatTab { + return { + kind: 'draft', + id: `draft:${scope.kind}:${'docId' in scope ? (scope.docId ?? '') : ''}`, + title: 'New chat', + scope, + hasMessages: false, + }; +} diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-element.spec.tsx b/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-element.spec.tsx new file mode 100644 index 0000000000..f9f6e9c295 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-element.spec.tsx @@ -0,0 +1,90 @@ +/** + * @vitest-environment happy-dom + */ +import { renderHook, waitFor } from '@testing-library/react'; +import { describe, expect, test, vi } from 'vitest'; + +import { useAIChatElement } from './use-element'; + +class TestAIChatElement extends HTMLElement { + accessor value = ''; +} + +if (!customElements.get('test-ai-chat-element')) { + customElements.define('test-ai-chat-element', TestAIChatElement); +} + +function createContainerRef() { + const container = document.createElement('div'); + document.body.append(container); + return { current: container }; +} + +function createElement() { + return document.createElement('test-ai-chat-element') as TestAIChatElement; +} + +describe('useAIChatElement', () => { + test('creates one element and keeps its properties in sync', async () => { + const containerRef = createContainerRef(); + + const { rerender, result } = renderHook( + ({ value }) => + useAIChatElement({ + containerRef, + selector: 'test-ai-chat-element', + enabled: true, + createElement, + configureElement: element => { + element.value = value; + }, + }), + { initialProps: { value: 'first' } } + ); + + await waitFor(() => { + expect( + containerRef.current.querySelectorAll('test-ai-chat-element') + ).toHaveLength(1); + expect(result.current?.value).toBe('first'); + }); + + rerender({ value: 'next' }); + + await waitFor(() => { + expect( + containerRef.current.querySelectorAll('test-ai-chat-element') + ).toHaveLength(1); + expect(result.current?.value).toBe('next'); + }); + }); + + test('reuses an existing element and removes duplicates', async () => { + const containerRef = createContainerRef(); + const first = createElement(); + const duplicate = createElement(); + containerRef.current.append(first, duplicate); + const createElementSpy = vi.fn(createElement); + + const { result } = renderHook(() => + useAIChatElement({ + containerRef, + selector: 'test-ai-chat-element', + enabled: true, + createElement: createElementSpy, + configureElement: element => { + element.value = 'reused'; + }, + }) + ); + + await waitFor(() => { + expect( + containerRef.current.querySelectorAll('test-ai-chat-element') + ).toHaveLength(1); + expect(result.current).toBe(first); + expect(first.value).toBe('reused'); + }); + expect(createElementSpy).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-element.ts b/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-element.ts new file mode 100644 index 0000000000..107ccd819c --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-element.ts @@ -0,0 +1,62 @@ +import type { RefObject } from 'react'; +import { useEffect, useRef, useState } from 'react'; + +export type UseAIChatElementOptions = { + containerRef: RefObject; + selector: string; + enabled: boolean; + createElement: () => T; + configureElement: (element: T) => void; + onElementReady?: (element: T) => void; +}; + +export function useAIChatElement({ + containerRef, + selector, + enabled, + createElement, + configureElement, + onElementReady, +}: UseAIChatElementOptions) { + const [element, setElement] = useState(null); + const readyElementsRef = useRef(new WeakSet()); + + useEffect(() => { + const container = containerRef.current; + if (!enabled || !container) return; + + const existingElements = Array.from( + container.querySelectorAll(selector) + ) as T[]; + const nextElement = element ?? existingElements[0] ?? createElement(); + + existingElements + .filter(existingElement => existingElement !== nextElement) + .forEach(existingElement => existingElement.remove()); + + configureElement(nextElement); + + if (nextElement.parentElement !== container) { + container.append(nextElement); + } + + if (!readyElementsRef.current.has(nextElement)) { + readyElementsRef.current.add(nextElement); + onElementReady?.(nextElement); + } + + if (element !== nextElement) { + setElement(nextElement); + } + }, [ + configureElement, + containerRef, + createElement, + element, + enabled, + onElementReady, + selector, + ]); + + return element; +} diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-runtime.ts b/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-runtime.ts new file mode 100644 index 0000000000..db383f10a3 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/chat/use-runtime.ts @@ -0,0 +1,22 @@ +import { useEffect, useSyncExternalStore } from 'react'; + +import type { AIChatRuntime } from './runtime'; + +/** + * Initializes and owns the passed runtime for the current React mount. + */ +export function useAIChatRuntime(runtime: AIChatRuntime | null) { + const snapshot = useSyncExternalStore( + runtime?.subscribe ?? (() => () => {}), + runtime?.getSnapshot ?? (() => null), + runtime?.getSnapshot ?? (() => null) + ); + + useEffect(() => { + if (!runtime) return; + runtime.dispatch({ type: 'initialize' }).catch(console.error); + return () => runtime.dispose(); + }, [runtime]); + + return snapshot; +} diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/request/action-definitions.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/action-definitions.ts new file mode 100644 index 0000000000..c414b36bca --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/request/action-definitions.ts @@ -0,0 +1,225 @@ +import type { PromptKey } from '../../provider/prompt'; +import { Endpoint } from './copilot-client'; +import type { TextToTextOptions } from './message-transport'; + +export type AIActionId = keyof BlockSuitePresets.AIActions; +export type AIActionOptions = BlockSuitePresets.AITextActionOptions & + Record; + +export type AIActionDefinition = { + id: AIActionId; + promptName: PromptKey | ((options: AIActionOptions) => PromptKey); + responseType: 'text' | 'image'; + endpoint?: Endpoint; + actionId?: string | ((options: AIActionOptions) => string | undefined); + actionVersion?: string | ((options: AIActionOptions) => string | undefined); + timeout?: number; + buildContent?: (options: AIActionOptions) => string | undefined; + buildParams?: (options: AIActionOptions) => TextToTextOptions['params']; + validate?: (options: AIActionOptions) => void; +}; + +const filterStyleToPromptName = new Map( + Object.entries({ + 'Clay style': 'image.filter.clay', + 'Pixel style': 'image.filter.pixel', + 'Sketch style': 'image.filter.sketch', + 'Anime style': 'image.filter.anime', + }) +); + +const processTypeToPromptName = new Map( + Object.entries({ + Clearer: 'Upscale image', + 'Remove background': 'Remove background', + 'Convert to sticker': 'Convert to sticker', + }) +); + +const textAction = ( + id: AIActionId, + promptName: PromptKey +): AIActionDefinition => ({ + id, + promptName, + responseType: 'text', + buildContent: options => options.input, +}); + +export const actionDefinitions = { + chat: { + id: 'chat', + promptName: 'Chat With AFFiNE AI', + responseType: 'text', + timeout: 5 * 60 * 1000, + endpoint: Endpoint.StreamObject, + buildContent: options => options.input, + buildParams: options => { + const contexts = options.contexts as + | { + docs?: unknown; + files?: unknown; + selectedSnapshot?: unknown; + selectedMarkdown?: unknown; + html?: unknown; + } + | undefined; + return { + docs: contexts?.docs, + files: contexts?.files, + selectedSnapshot: contexts?.selectedSnapshot, + selectedMarkdown: contexts?.selectedMarkdown, + html: contexts?.html, + ...(options.docId ? { currentDocId: options.docId } : {}), + }; + }, + }, + summary: textAction('summary', 'Summary'), + translate: { + ...textAction('translate', 'Translate to'), + buildParams: options => ({ language: options.lang }), + }, + changeTone: { + ...textAction('changeTone', 'Change tone to'), + buildParams: options => ({ + tone: typeof options.tone === 'string' ? options.tone.toLowerCase() : '', + }), + }, + improveWriting: textAction('improveWriting', 'Improve writing for it'), + improveGrammar: textAction('improveGrammar', 'Improve grammar for it'), + fixSpelling: textAction('fixSpelling', 'Fix spelling for it'), + createHeadings: textAction('createHeadings', 'Create headings'), + makeLonger: textAction('makeLonger', 'Make it longer'), + makeShorter: textAction('makeShorter', 'Make it shorter'), + checkCodeErrors: textAction('checkCodeErrors', 'Check code error'), + explainCode: textAction('explainCode', 'Explain this code'), + writeArticle: textAction('writeArticle', 'Write an article about this'), + writeTwitterPost: textAction( + 'writeTwitterPost', + 'Write a twitter about this' + ), + writePoem: textAction('writePoem', 'Write a poem about this'), + writeOutline: textAction('writeOutline', 'Write outline'), + writeBlogPost: textAction('writeBlogPost', 'Write a blog post about this'), + brainstorm: textAction('brainstorm', 'Brainstorm ideas about this'), + findActions: textAction('findActions', 'Find action items from it'), + brainstormMindmap: { + id: 'brainstormMindmap', + promptName: 'mindmap.generate', + responseType: 'text', + timeout: 180000, + endpoint: Endpoint.Action, + actionId: 'mindmap.generate', + actionVersion: 'v1', + buildContent: options => options.input, + }, + expandMindmap: { + ...textAction('expandMindmap', 'Expand mind map'), + validate: options => { + if (!options.input) { + throw new Error('expandMindmap action requires input'); + } + }, + buildParams: options => ({ + mindmap: options.mindmap, + node: options.input, + }), + }, + explain: textAction('explain', 'Explain this'), + explainImage: textAction('explainImage', 'Explain this image'), + makeItReal: { + id: 'makeItReal', + promptName: options => + options.attachments && Array.isArray(options.attachments) + ? 'Make it real' + : 'Make it real with text', + responseType: 'text', + buildContent: options => { + const input = options.input ?? ''; + if (options.attachments && Array.isArray(options.attachments)) { + return `Here are the latest wireframes. Could you make a new website based on these wireframes and notes and send back just the html file? +Here are our design notes:\n ${input}.`; + } + return `Here are the latest notes: \n ${input}. +Could you make a new website based on these notes and send back just the html file?`; + }, + }, + createSlides: { + id: 'createSlides', + promptName: 'slides.outline', + responseType: 'text', + timeout: 180000, + endpoint: Endpoint.Action, + actionId: 'slides.outline', + actionVersion: 'v1', + buildContent: options => options.input, + }, + createImage: { + id: 'createImage', + promptName: 'Generate image', + responseType: 'image', + timeout: 300000, + buildContent: options => + !options.input && options.attachments + ? 'Make the image more detailed.' + : options.input, + }, + filterImage: { + id: 'filterImage', + promptName: options => { + const promptName = + typeof options.style === 'string' + ? filterStyleToPromptName.get(options.style) + : undefined; + if (!promptName) { + throw new Error('filterImage requires a promptName'); + } + return promptName; + }, + responseType: 'image', + timeout: 180000, + endpoint: Endpoint.Action, + actionId: options => + typeof options.style === 'string' + ? filterStyleToPromptName.get(options.style) + : undefined, + actionVersion: 'v1', + buildContent: options => options.input, + }, + processImage: { + id: 'processImage', + promptName: options => { + const promptName = + typeof options.type === 'string' + ? processTypeToPromptName.get(options.type) + : undefined; + if (!promptName) { + throw new Error('processImage requires a promptName'); + } + return promptName; + }, + responseType: 'image', + timeout: 180000, + buildContent: options => options.input, + }, + generateCaption: textAction('generateCaption', 'Generate a caption'), + continueWriting: textAction('continueWriting', 'Continue writing'), +} satisfies Partial>; + +export function getActionDefinition(id: AIActionId): AIActionDefinition { + const definition = actionDefinitions[id]; + if (!definition) { + throw new Error(`AI action ${String(id)} is not defined`); + } + return definition; +} + +export function resolveDefinitionValue( + value: + | string + | ((options: AIActionOptions) => string | undefined) + | undefined, + options: AIActionOptions +) { + return typeof value === 'function' ? value(options) : value; +} diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/request/byok-local-lease.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/byok-local-lease.ts new file mode 100644 index 0000000000..e6a3bb2b32 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/request/byok-local-lease.ts @@ -0,0 +1,97 @@ +import { apis, type ClientHandler } from '@affine/electron-api'; +import { UserFriendlyError } from '@affine/error'; +import { + ByokProvider, + createWorkspaceByokLocalLeaseMutation, +} from '@affine/graphql'; + +import type { CopilotClient } from './copilot-client'; + +function isElectronBuild() { + return typeof BUILD_CONFIG !== 'undefined' && BUILD_CONFIG.isElectron; +} + +function byokStorageApi(): ClientHandler['byokStorage'] | undefined { + return isElectronBuild() ? apis?.byokStorage : undefined; +} + +function toGraphqlByokProvider(provider: string): ByokProvider | null { + switch (provider) { + case ByokProvider.openai: + return ByokProvider.openai; + case ByokProvider.anthropic: + return ByokProvider.anthropic; + case ByokProvider.gemini: + return ByokProvider.gemini; + case ByokProvider.fal: + return ByokProvider.fal; + default: + return null; + } +} + +function errorMetadata(error: unknown) { + if (!error || typeof error !== 'object') { + return { kind: typeof error }; + } + const record = error as Record; + return { + name: typeof record.name === 'string' ? record.name : undefined, + code: typeof record.code === 'string' ? record.code : undefined, + status: + typeof record.status === 'number' || typeof record.status === 'string' + ? record.status + : undefined, + type: typeof record.type === 'string' ? record.type : undefined, + }; +} + +export async function createWorkspaceByokLocalLease( + client: CopilotClient, + workspaceId?: string +) { + const storage = byokStorageApi(); + if (!workspaceId || !storage) { + return undefined; + } + + try { + if (!(await storage.isSupported())) return undefined; + const providers = await storage.getWorkspaceLeaseProviders(workspaceId); + if (!providers.length) return undefined; + const leaseProviders = providers.flatMap(provider => { + const gqlProvider = toGraphqlByokProvider(provider.provider); + return gqlProvider + ? [ + { + provider: gqlProvider, + name: provider.name, + description: provider.description ?? null, + apiKey: provider.apiKey, + endpoint: provider.endpoint ?? null, + sortOrder: provider.sortOrder ?? 0, + enabled: provider.enabled ?? true, + }, + ] + : []; + }); + if (!leaseProviders.length) return undefined; + + const result = await client.gql({ + query: createWorkspaceByokLocalLeaseMutation, + variables: { + input: { + workspaceId, + providers: leaseProviders, + }, + }, + }); + return result.createWorkspaceByokLocalLease.leaseId; + } catch (error) { + console.warn( + 'Failed to create workspace BYOK local lease', + errorMetadata(error) + ); + throw UserFriendlyError.fromAny(error); + } +} diff --git a/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.spec.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/copilot-client.spec.ts similarity index 100% rename from packages/frontend/core/src/blocksuite/ai/provider/copilot-client.spec.ts rename to packages/frontend/core/src/blocksuite/ai/runtime/request/copilot-client.spec.ts diff --git a/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/copilot-client.ts similarity index 99% rename from packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts rename to packages/frontend/core/src/blocksuite/ai/runtime/request/copilot-client.ts index 4a2c4df571..5c16c65f6b 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts +++ b/packages/frontend/core/src/blocksuite/ai/runtime/request/copilot-client.ts @@ -38,7 +38,7 @@ import { GeneralNetworkError, PaymentRequiredError, UnauthorizedError, -} from './error'; +} from '../../provider/error'; export enum Endpoint { Action = 'action', diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/request/index.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/index.ts new file mode 100644 index 0000000000..30489bf97c --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/request/index.ts @@ -0,0 +1,4 @@ +export * from './action-definitions'; +export * from './message-transport'; +export * from './provider'; +export * from './service'; diff --git a/packages/frontend/core/src/blocksuite/ai/provider/request.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/message-transport.ts similarity index 70% rename from packages/frontend/core/src/blocksuite/ai/provider/request.ts rename to packages/frontend/core/src/blocksuite/ai/runtime/request/message-transport.ts index 1a2de873f7..09e73cb2d9 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/request.ts +++ b/packages/frontend/core/src/blocksuite/ai/runtime/request/message-transport.ts @@ -1,114 +1,19 @@ import type { AIToolsConfig } from '@affine/core/modules/ai-button'; -import { apis, type ClientHandler } from '@affine/electron-api'; -import { UserFriendlyError } from '@affine/error'; -import { - ByokProvider, - createWorkspaceByokLocalLeaseMutation, -} from '@affine/graphql'; import { partition } from 'lodash-es'; -import { AIProvider } from './ai-provider'; +import { toTextStream } from '../../provider/event-source'; +import { createWorkspaceByokLocalLease } from './byok-local-lease'; import { type CopilotClient, Endpoint } from './copilot-client'; -import { toTextStream } from './event-source'; const TIMEOUT = 50000; -function isElectronBuild() { - return typeof BUILD_CONFIG !== 'undefined' && BUILD_CONFIG.isElectron; -} - -function byokStorageApi(): ClientHandler['byokStorage'] | undefined { - return isElectronBuild() ? apis?.byokStorage : undefined; -} - -function toGraphqlByokProvider(provider: string): ByokProvider | null { - switch (provider) { - case ByokProvider.openai: - return ByokProvider.openai; - case ByokProvider.anthropic: - return ByokProvider.anthropic; - case ByokProvider.gemini: - return ByokProvider.gemini; - case ByokProvider.fal: - return ByokProvider.fal; - default: - return null; - } -} - -function errorMetadata(error: unknown) { - if (!error || typeof error !== 'object') { - return { kind: typeof error }; - } - const record = error as Record; - return { - name: typeof record.name === 'string' ? record.name : undefined, - code: typeof record.code === 'string' ? record.code : undefined, - status: - typeof record.status === 'number' || typeof record.status === 'string' - ? record.status - : undefined, - type: typeof record.type === 'string' ? record.type : undefined, - }; -} - -async function createWorkspaceByokLocalLease( - client: CopilotClient, - workspaceId?: string -) { - const storage = byokStorageApi(); - if (!workspaceId || !storage) { - return undefined; - } - - try { - if (!(await storage.isSupported())) return undefined; - const providers = await storage.getWorkspaceLeaseProviders(workspaceId); - if (!providers.length) return undefined; - const leaseProviders = providers.flatMap(provider => { - const gqlProvider = toGraphqlByokProvider(provider.provider); - return gqlProvider - ? [ - { - provider: gqlProvider, - name: provider.name, - description: provider.description ?? null, - apiKey: provider.apiKey, - endpoint: provider.endpoint ?? null, - sortOrder: provider.sortOrder ?? 0, - enabled: provider.enabled ?? true, - }, - ] - : []; - }); - if (!leaseProviders.length) return undefined; - - const result = await client.gql({ - query: createWorkspaceByokLocalLeaseMutation, - variables: { - input: { - workspaceId, - providers: leaseProviders, - }, - }, - }); - return result.createWorkspaceByokLocalLease.leaseId; - } catch (error) { - console.warn( - 'Failed to create workspace BYOK local lease', - errorMetadata(error) - ); - throw UserFriendlyError.fromAny(error); - } -} - export type TextToTextOptions = { client: CopilotClient; sessionId: string; workspaceId?: string; content?: string; attachments?: (string | Blob | File)[]; - params?: Record; + params?: Record; timeout?: number; stream?: boolean; signal?: AbortSignal; @@ -138,7 +43,6 @@ async function resizeImage(blob: Blob | File): Promise { }); const canvas = document.createElement('canvas'); - // keep aspect ratio const scale = Math.min(1024 / img.width, 1024 / img.height); canvas.width = Math.floor(img.width * scale); canvas.height = Math.floor(img.height * scale); @@ -164,7 +68,7 @@ interface CreateMessageOptions { sessionId: string; content?: string; attachments?: (string | Blob | File)[]; - params?: Record; + params?: Record; timeout?: number; signal?: AbortSignal; } @@ -182,7 +86,7 @@ async function createMessage({ const options: Parameters[0] = { sessionId, content, - params, + params: params as Parameters[0]['params'], }; if (hasAttachments) { @@ -267,7 +171,6 @@ export function textToText({ }, endpoint ); - AIProvider.LAST_ACTION_SESSIONID = sessionId; let onAbort: (() => void) | undefined; try { @@ -336,7 +239,6 @@ export function textToText({ }, endpoint ); - AIProvider.LAST_ACTION_SESSIONID = sessionId; let onAbort: (() => void) | undefined; try { @@ -361,8 +263,7 @@ export function textToText({ } } - const result = messages.join(''); - return result; + return messages.join(''); } finally { eventSource.close(); if (signal && onAbort) { @@ -373,7 +274,6 @@ export function textToText({ } } -// Only one image is currently being processed export function toImage({ content, sessionId, @@ -435,7 +335,6 @@ export function toImage({ endpoint, byokLeaseId ); - AIProvider.LAST_ACTION_SESSIONID = sessionId; for await (const event of toTextStream(eventSource, { timeout, diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/request/provider.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/provider.ts new file mode 100644 index 0000000000..da538c8cbc --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/request/provider.ts @@ -0,0 +1,18 @@ +import type { AIRequestService } from './service'; + +let currentRequestService: AIRequestService | null = null; + +export function setAIRequestService(service: AIRequestService | null) { + currentRequestService = service; +} + +export function getAIRequestService() { + if (!currentRequestService) { + throw new Error('AIRequestService is not initialized'); + } + return currentRequestService; +} + +export function hasAIRequestService() { + return !!currentRequestService; +} diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/request/service.spec.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/service.spec.ts new file mode 100644 index 0000000000..f1ce035074 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/request/service.spec.ts @@ -0,0 +1,379 @@ +/** + * @vitest-environment happy-dom + */ +import { UserFriendlyError } from '@affine/error'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; + +import { type CopilotClient, Endpoint } from './copilot-client'; +import { textToText, toImage } from './message-transport'; +import { AIRequestService } from './service'; + +Object.defineProperty(globalThis, 'EventSource', { + configurable: true, + value: { + CLOSED: 2, + }, +}); + +const electronApis = vi.hoisted(() => ({ + byokStorage: undefined as + | { + isSupported: () => Promise; + getWorkspaceLeaseProviders: (workspaceId: string) => Promise< + Array<{ + provider: string; + name: string; + apiKey: string; + description?: string | null; + endpoint?: string | null; + sortOrder?: number | null; + enabled?: boolean | null; + }> + >; + } + | undefined, +})); + +const createWorkspaceByokLocalLeaseMutation = vi.hoisted(() => + Symbol('createWorkspaceByokLocalLeaseMutation') +); + +vi.mock('@affine/electron-api', () => ({ + apis: electronApis, +})); + +vi.mock('@affine/graphql', () => ({ + ByokProvider: { + openai: 'openai', + anthropic: 'anthropic', + gemini: 'gemini', + fal: 'fal', + }, + ContextCategories: { + Tag: 'tag', + Collection: 'collection', + }, + createWorkspaceByokLocalLeaseMutation, +})); + +function createClosedEventSource(): EventSource { + return { + readyState: EventSource.CLOSED, + addEventListener: vi.fn(), + close: vi.fn(), + } as unknown as EventSource; +} + +function createClient( + overrides: Partial< + Pick< + CopilotClient, + | 'gql' + | 'createSession' + | 'createMessage' + | 'getSessions' + | 'getHistories' + | 'chatTextStream' + | 'imagesStream' + > + > = {} +) { + return { + gql: vi.fn().mockResolvedValue({ + createWorkspaceByokLocalLease: { leaseId: 'lease-1' }, + }), + createSession: vi.fn().mockImplementation(async options => { + return `session:${options.promptName}`; + }), + createMessage: vi.fn().mockResolvedValue('message-1'), + getSessions: vi.fn().mockResolvedValue([]), + getHistories: vi.fn().mockResolvedValue([]), + chatTextStream: vi.fn(() => createClosedEventSource()), + imagesStream: vi.fn(() => createClosedEventSource()), + ...overrides, + } as unknown as CopilotClient; +} + +async function drain(stream: AsyncIterable) { + for await (const chunk of stream) { + void chunk; + } +} + +async function drainActionResult( + stream: string | AsyncIterable | undefined +) { + expect(stream).toBeDefined(); + expect(typeof stream).not.toBe('string'); + await drain(stream as AsyncIterable); +} + +describe('runtime request transport BYOK local lease handling', () => { + beforeEach(() => { + vi.stubGlobal('BUILD_CONFIG', { isElectron: true }); + electronApis.byokStorage = { + isSupported: vi.fn().mockResolvedValue(true), + getWorkspaceLeaseProviders: vi.fn().mockResolvedValue([ + { + provider: 'openai', + name: 'OpenAI', + apiKey: 'sk-local', + }, + ]), + }; + }); + + test('fails closed when local BYOK providers exist but lease creation fails', async () => { + const client = createClient({ + gql: vi.fn().mockRejectedValue(new Error('mutation failed')), + }); + + const result = textToText({ + client, + sessionId: 'session-1', + workspaceId: 'workspace-1', + content: 'hello', + }) as Promise; + + await expect(result).rejects.toThrow('mutation failed'); + await expect(result).rejects.toBeInstanceOf(UserFriendlyError); + expect(client.chatTextStream).not.toHaveBeenCalled(); + }); + + test('does not create stream local BYOK lease after cancellation', async () => { + const controller = new AbortController(); + const client = createClient({ + createMessage: vi.fn().mockImplementation(async () => { + controller.abort(); + return 'message-1'; + }), + }); + + await drain( + textToText({ + client, + sessionId: 'session-1', + workspaceId: 'workspace-1', + content: 'hello', + stream: true, + signal: controller.signal, + }) as AsyncIterable + ); + + expect(client.gql).not.toHaveBeenCalled(); + expect(client.chatTextStream).not.toHaveBeenCalled(); + }); + + test('does not create image stream when cancelled while creating local BYOK lease', async () => { + const controller = new AbortController(); + const client = createClient({ + gql: vi.fn().mockImplementation(async () => { + controller.abort(); + return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } }; + }), + }); + + await drain( + toImage({ + client, + sessionId: 'session-1', + workspaceId: 'workspace-1', + content: 'image', + endpoint: Endpoint.Images, + signal: controller.signal, + }) as AsyncIterable + ); + + expect(client.gql).toHaveBeenCalled(); + expect(client.imagesStream).not.toHaveBeenCalled(); + }); +}); + +describe('AIRequestService action definitions', () => { + beforeEach(() => { + vi.stubGlobal('BUILD_CONFIG', { isElectron: false }); + electronApis.byokStorage = undefined; + }); + + test('routes action-stream requests through action endpoint', async () => { + const client = createClient(); + const service = new AIRequestService(client); + + await drainActionResult( + (await service.executeAction('brainstormMindmap', { + workspaceId: 'workspace-1', + input: 'make a map', + stream: true, + })) as AsyncIterable + ); + await drainActionResult( + (await service.executeAction('createSlides', { + workspaceId: 'workspace-1', + input: 'make slides', + stream: true, + })) as AsyncIterable + ); + await drainActionResult( + (await service.executeAction('filterImage', { + workspaceId: 'workspace-1', + input: 'convert', + attachments: ['blob-1'], + style: 'Sketch style', + })) as AsyncIterable + ); + + expect(client.createSession).toHaveBeenCalledWith( + expect.objectContaining({ promptName: 'mindmap.generate' }) + ); + expect(client.createSession).toHaveBeenCalledWith( + expect.objectContaining({ promptName: 'slides.outline' }) + ); + expect(client.createSession).toHaveBeenCalledWith( + expect.objectContaining({ promptName: 'image.filter.sketch' }) + ); + expect(client.chatTextStream).toHaveBeenCalledWith( + expect.objectContaining({ actionId: 'mindmap.generate' }), + Endpoint.Action + ); + expect(client.chatTextStream).toHaveBeenCalledWith( + expect.objectContaining({ actionId: 'slides.outline' }), + Endpoint.Action + ); + expect(client.chatTextStream).toHaveBeenCalledWith( + expect.objectContaining({ actionId: 'image.filter.sketch' }), + Endpoint.Action + ); + expect(client.imagesStream).not.toHaveBeenCalled(); + }); + + test('reuses the last action session for retry', async () => { + const client = createClient(); + const service = new AIRequestService(client); + + await drainActionResult( + (await service.executeAction('summary', { + workspaceId: 'workspace-1', + input: 'summarize', + stream: true, + })) as AsyncIterable + ); + await drainActionResult( + (await service.executeAction('summary', { + workspaceId: 'workspace-1', + input: 'summarize again', + retry: true, + stream: true, + })) as AsyncIterable + ); + + expect(client.createSession).toHaveBeenCalledTimes(1); + expect(client.createMessage).toHaveBeenCalledTimes(1); + expect(client.chatTextStream).toHaveBeenLastCalledWith( + expect.objectContaining({ + sessionId: 'session:Summary', + retry: true, + }), + Endpoint.StreamObject + ); + }); + + test('reports action result against the matching host action', async () => { + const client = createClient(); + const service = new AIRequestService(client); + const events: string[] = []; + const hostOne = {} as NonNullable< + BlockSuitePresets.AITextActionOptions['host'] + >; + const hostTwo = {} as NonNullable< + BlockSuitePresets.AITextActionOptions['host'] + >; + const subscription = service.actionEvents$.subscribe(event => { + events.push( + `${event.options.host === hostOne ? 'one' : 'two'}:${event.event}` + ); + }); + + await drainActionResult( + (await service.executeAction('summary', { + workspaceId: 'workspace-1', + input: 'first', + host: hostOne, + stream: true, + })) as AsyncIterable + ); + await drainActionResult( + (await service.executeAction('translate', { + workspaceId: 'workspace-1', + input: 'second', + lang: 'French', + host: hostTwo, + stream: true, + })) as AsyncIterable + ); + + service.reportLastAction('result:insert', hostOne); + subscription.unsubscribe(); + + expect(events).toContain('one:result:insert'); + }); + + test('loads sessions through history query with messages', async () => { + const history = { + sessionId: 'session-1', + workspaceId: 'workspace-1', + docId: 'doc-1', + messages: [{ id: 'message-1', role: 'user', content: 'hello' }], + }; + const client = createClient({ + getHistories: vi.fn().mockResolvedValue([history]), + }); + const service = new AIRequestService(client); + + const session = await service.getSession('workspace-1', 'session-1'); + + expect(client.getHistories).toHaveBeenCalledWith( + 'workspace-1', + {}, + undefined, + expect.objectContaining({ + sessionId: 'session-1', + withMessages: true, + }) + ); + expect(session?.messages).toEqual(history.messages); + }); + + test('loads chat history lists with messages for title derivation', async () => { + const client = createClient(); + const service = new AIRequestService(client); + + await service.getSessions('workspace-1', 'doc-1', { + action: false, + fork: false, + }); + await service.getRecentSessions('workspace-1', 10, 20); + + expect(client.getSessions).toHaveBeenCalledWith( + 'workspace-1', + {}, + 'doc-1', + expect.objectContaining({ + action: false, + fork: false, + withMessages: true, + }), + undefined + ); + expect(client.getHistories).toHaveBeenCalledWith( + 'workspace-1', + { first: 10, offset: 20 }, + undefined, + expect.objectContaining({ + action: false, + fork: false, + sessionOrder: 'desc', + withMessages: true, + }) + ); + }); +}); diff --git a/packages/frontend/core/src/blocksuite/ai/runtime/request/service.ts b/packages/frontend/core/src/blocksuite/ai/runtime/request/service.ts new file mode 100644 index 0000000000..293a220803 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/runtime/request/service.ts @@ -0,0 +1,396 @@ +import { + ContextCategories, + type CopilotChatHistoryFragment, + type getCopilotHistoriesQuery, + type GraphQLQuery, + type QueryChatSessionsInput, + type QueryOptions, + type QueryResponse, + type RequestOptions, + type UpdateChatSessionInput, +} from '@affine/graphql'; +import { Subject } from 'rxjs'; + +import type { ActionEventType } from '../../provider'; +import { + type AIActionId, + type AIActionOptions, + getActionDefinition, + resolveDefinitionValue, +} from './action-definitions'; +import { + CopilotClient, + type CopilotClient as CopilotClientType, +} from './copilot-client'; +import { textToText, toImage } from './message-transport'; + +type CreateSessionOptions = BlockSuitePresets.AICreateSessionOptions; + +export type AIRequestActionEvent = { + action: AIActionId; + options: AIActionOptions; + event: ActionEventType; +}; + +export class AIRequestService { + private lastActionSessionId = ''; + private readonly actionHistory: { + action: AIActionId; + options: AIActionOptions; + }[] = []; + readonly actionEvents$ = new Subject(); + + constructor(readonly client: CopilotClientType) {} + + isReady() { + return true; + } + + async createSession(options: CreateSessionOptions) { + if (options.sessionId) return options.sessionId; + if (options.retry) return this.lastActionSessionId; + return this.client.createSession({ + workspaceId: options.workspaceId, + docId: options.docId, + promptName: options.promptName, + pinned: options.pinned, + reuseLatestChat: options.reuseLatestChat, + }); + } + + async createSessionWithHistory(options: CreateSessionOptions) { + if (!options.sessionId && !options.retry) { + return this.client.createSessionWithHistory({ + workspaceId: options.workspaceId, + docId: options.docId, + promptName: options.promptName, + pinned: options.pinned, + reuseLatestChat: options.reuseLatestChat, + }); + } + + const sessionId = await this.createSession(options); + if (!sessionId) return undefined; + return this.getSession(options.workspaceId, sessionId); + } + + getSession(workspaceId: string, sessionId: string) { + return this.client + .getHistories(workspaceId, {}, undefined, { + sessionId, + withMessages: true, + } as RequestOptions< + typeof getCopilotHistoriesQuery + >['variables']['options']) + .then( + histories => + (histories?.[0] ?? null) as CopilotChatHistoryFragment | null + ); + } + + getSessions( + workspaceId: string, + docId?: string, + options?: QueryChatSessionsInput, + signal?: AbortSignal + ) { + return this.client.getSessions( + workspaceId, + {}, + docId, + { ...options, withMessages: true }, + signal + ); + } + + getRecentSessions(workspaceId: string, limit?: number, offset?: number) { + return this.client.getHistories( + workspaceId, + { first: limit, offset }, + undefined, + { + action: false, + fork: false, + sessionOrder: 'desc', + withMessages: true, + } as RequestOptions< + typeof getCopilotHistoriesQuery + >['variables']['options'] + ); + } + + updateSession(options: UpdateChatSessionInput) { + return this.client.updateSession(options); + } + + cleanupSessions(input: { + workspaceId: string; + docId: string | undefined; + sessionIds: string[]; + }) { + return this.client.cleanupSessions(input); + } + + histories = { + actions: async ( + workspaceId: string, + docId: string + ): Promise => { + return ((await this.client.getHistories(workspaceId, {}, docId, { + action: true, + withPrompt: true, + withMessages: true, + } as RequestOptions< + typeof getCopilotHistoriesQuery + >['variables']['options'])) ?? []) as BlockSuitePresets.AIHistory[]; + }, + chats: async ( + workspaceId: string, + sessionId: string, + docId?: string + ): Promise => { + return ((await this.client.getHistories(workspaceId, {}, docId, { + sessionId, + withMessages: true, + } as RequestOptions< + typeof getCopilotHistoriesQuery + >['variables']['options'])) ?? []) as BlockSuitePresets.AIHistory[]; + }, + cleanup: async ( + workspaceId: string, + docId: string | undefined, + sessionIds: string[] + ) => { + await this.cleanupSessions({ workspaceId, docId, sessionIds }); + }, + ids: async ( + workspaceId: string, + docId?: string, + options?: RequestOptions< + typeof getCopilotHistoriesQuery + >['variables']['options'] + ): Promise => { + return (await this.client.getHistoryIds( + workspaceId, + {}, + docId, + options + )) as unknown as BlockSuitePresets.AIHistoryIds[]; + }, + }; + + context = { + createContext: (workspaceId: string, sessionId: string) => + this.client.createContext(workspaceId, sessionId), + getContextId: (workspaceId: string, sessionId: string) => + this.client.getContextId(workspaceId, sessionId), + addContextDoc: (options: { contextId: string; docId: string }) => + this.client.addContextDoc(options), + removeContextDoc: (options: { contextId: string; docId: string }) => + this.client.removeContextDoc(options), + addContextFile: ( + file: File, + options: Parameters[1] + ) => this.client.addContextFile(file, options), + removeContextFile: (options: { contextId: string; fileId: string }) => + this.client.removeContextFile(options), + addContextTag: (options: { + contextId: string; + tagId: string; + docIds: string[]; + }) => + this.client.addContextCategory({ + contextId: options.contextId, + type: ContextCategories.Tag, + categoryId: options.tagId, + docs: options.docIds, + }), + removeContextTag: (options: { contextId: string; tagId: string }) => + this.client.removeContextCategory({ + contextId: options.contextId, + type: ContextCategories.Tag, + categoryId: options.tagId, + }), + addContextCollection: (options: { + contextId: string; + collectionId: string; + docIds: string[]; + }) => + this.client.addContextCategory({ + contextId: options.contextId, + type: ContextCategories.Collection, + categoryId: options.collectionId, + docs: options.docIds, + }), + removeContextCollection: (options: { + contextId: string; + collectionId: string; + }) => + this.client.removeContextCategory({ + contextId: options.contextId, + type: ContextCategories.Collection, + categoryId: options.collectionId, + }), + getContextDocsAndFiles: ( + workspaceId: string, + sessionId: string, + contextId: string + ) => this.client.getContextDocsAndFiles(workspaceId, sessionId, contextId), + matchContext: ( + content: string, + contextId?: string, + workspaceId?: string, + limit?: number, + scopedThreshold?: number, + threshold?: number + ) => + this.client.matchContext( + content, + contextId, + workspaceId, + limit, + scopedThreshold, + threshold + ), + addContextBlob: (options: { blobId: string; contextId: string }) => + this.client.addContextBlob({ + contextId: options.contextId, + blobId: options.blobId, + }), + removeContextBlob: (options: { blobId: string; contextId: string }) => + this.client.removeContextBlob({ + contextId: options.contextId, + blobId: options.blobId, + }), + pollContextDocsAndFiles: async ( + workspaceId: string, + sessionId: string, + contextId: string, + onPoll: ( + result: BlockSuitePresets.AIDocsAndFilesContext | undefined + ) => void, + abortSignal: AbortSignal + ) => { + let attempts = 0; + const minInterval = 1000; + const maxInterval = 30 * 1000; + + while (!abortSignal.aborted) { + const result = await this.client.getContextDocsAndFiles( + workspaceId, + sessionId, + contextId + ); + onPoll(result); + const interval = Math.min( + minInterval * Math.pow(1.5, attempts), + maxInterval + ); + attempts++; + await new Promise(resolve => setTimeout(resolve, interval)); + } + }, + pollEmbeddingStatus: async ( + workspaceId: string, + onPoll: ( + result: Awaited> + ) => void, + abortSignal: AbortSignal + ) => { + const interval = 10 * 1000; + while (!abortSignal.aborted) { + onPoll(await this.client.getEmbeddingStatus(workspaceId)); + await new Promise(resolve => setTimeout(resolve, interval)); + } + }, + }; + + forkChat(options: BlockSuitePresets.AIForkChatSessionOptions) { + return this.client.forkSession(options); + } + + reportLastAction(event: ActionEventType, host?: unknown) { + const lastAction = host + ? this.actionHistory.findLast(item => item.options.host === host) + : this.actionHistory.at(-1); + if (!lastAction) return; + this.actionEvents$.next({ + action: lastAction.action, + options: lastAction.options, + event, + }); + } + + private wrapTextStream( + stream: AsyncIterable, + id: AIActionId, + options: AIActionOptions + ): AsyncIterable { + const actionEvents$ = this.actionEvents$; + return { + async *[Symbol.asyncIterator]() { + try { + yield* stream; + actionEvents$.next({ action: id, options, event: 'finished' }); + } catch (error) { + actionEvents$.next({ action: id, options, event: 'error' }); + throw error; + } + }, + }; + } + + async executeAction(id: AIActionId, options: AIActionOptions) { + this.actionHistory.push({ action: id, options }); + if (this.actionHistory.length > 10) { + this.actionHistory.shift(); + } + this.actionEvents$.next({ action: id, options, event: 'started' }); + const definition = getActionDefinition(id); + definition.validate?.(options); + const promptName = resolveDefinitionValue( + definition.promptName, + options + ) as CreateSessionOptions['promptName']; + const sessionId = await this.createSession({ + promptName, + ...options, + } as CreateSessionOptions); + this.lastActionSessionId = sessionId; + + const actionId = resolveDefinitionValue(definition.actionId, options); + const actionVersion = resolveDefinitionValue( + definition.actionVersion, + options + ); + const transportOptions = { + ...options, + client: this.client, + sessionId, + content: definition.buildContent?.(options) ?? options.input, + params: definition.buildParams?.(options), + timeout: definition.timeout, + endpoint: definition.endpoint, + actionId, + actionVersion, + }; + + const stream = + definition.responseType === 'image' + ? toImage(transportOptions) + : textToText(transportOptions); + return this.wrapTextStream(stream as AsyncIterable, id, options); + } +} + +export function createAIRequestService( + gql: ( + options: QueryOptions + ) => Promise>, + eventSource: ( + url: string, + eventSourceInitDict?: EventSourceInit + ) => EventSource +) { + return new AIRequestService(new CopilotClient(gql, eventSource)); +} diff --git a/packages/frontend/core/src/blocksuite/ai/utils/action-reporter.ts b/packages/frontend/core/src/blocksuite/ai/utils/action-reporter.ts index 4257879ef0..8aa9b9e5b1 100644 --- a/packages/frontend/core/src/blocksuite/ai/utils/action-reporter.ts +++ b/packages/frontend/core/src/blocksuite/ai/utils/action-reporter.ts @@ -1,12 +1,6 @@ -import { type ActionEventType, AIProvider } from '../provider'; +import type { ActionEventType } from '../provider'; +import { getAIRequestService } from '../runtime/request'; -export function reportResponse(event: ActionEventType) { - const lastAction = AIProvider.actionHistory.at(-1); - if (!lastAction) return; - - AIProvider.slots.actions.next({ - action: lastAction.action, - options: lastAction.options, - event, - }); +export function reportResponse(event: ActionEventType, host?: unknown) { + getAIRequestService().reportLastAction(event, host); } diff --git a/packages/frontend/core/src/blocksuite/ai/widgets/edgeless-copilot/index.ts b/packages/frontend/core/src/blocksuite/ai/widgets/edgeless-copilot/index.ts index 529909cade..31a119afdd 100644 --- a/packages/frontend/core/src/blocksuite/ai/widgets/edgeless-copilot/index.ts +++ b/packages/frontend/core/src/blocksuite/ai/widgets/edgeless-copilot/index.ts @@ -27,7 +27,7 @@ import { styleMap } from 'lit/directives/style-map.js'; import { literal, unsafeStatic } from 'lit/static-html.js'; import type { AIItemGroupConfig } from '../../components/ai-item/types.js'; -import { AIProvider } from '../../provider/index.js'; +import { AIAppEvents } from '../../provider/index.js'; import { extractSelectedContent } from '../../utils/extract.js'; import { AFFINE_AI_PANEL_WIDGET, @@ -106,7 +106,7 @@ export class EdgelessCopilotWidget extends WidgetComponent { aiPanel.hide(); extractSelectedContent(this.host) .then(context => { - AIProvider.slots.requestSendWithChat.next({ + AIAppEvents.requestSendWithChat.next({ input, context, host: this.host, diff --git a/packages/frontend/core/src/components/providers/workspace-side-effects.tsx b/packages/frontend/core/src/components/providers/workspace-side-effects.tsx index 7349389609..867841b52e 100644 --- a/packages/frontend/core/src/components/providers/workspace-side-effects.tsx +++ b/packages/frontend/core/src/components/providers/workspace-side-effects.tsx @@ -4,8 +4,8 @@ import { resolveGlobalLoadingEventAtom, } from '@affine/component/global-loading'; import { - AIProvider, - CopilotClient, + AIAppEvents, + createAIRequestService, setupAIProvider, } from '@affine/core/blocksuite/ai'; import { useRegisterFindInPageCommands } from '@affine/core/components/hooks/affine/use-register-find-in-page-commands'; @@ -103,7 +103,7 @@ export const WorkspaceSideEffects = () => { }) ); - const disposable = AIProvider.slots.requestInsertTemplate.subscribe( + const disposable = AIAppEvents.requestInsertTemplate.subscribe( ({ template, mode }) => { insertTemplate({ template, mode }); } @@ -126,7 +126,7 @@ export const WorkspaceSideEffects = () => { const globalDialogService = useService(GlobalDialogService); useEffect(() => { - const disposable = AIProvider.slots.requestUpgradePlan.subscribe(() => { + const disposable = AIAppEvents.requestUpgradePlan.subscribe(() => { workspaceDialogService.open('setting', { activeTab: 'billing', }); @@ -143,7 +143,10 @@ export const WorkspaceSideEffects = () => { useEffect(() => { const dispose = setupAIProvider( - new CopilotClient(graphqlService.gql, eventSourceService.eventSource), + createAIRequestService( + graphqlService.gql, + eventSourceService.eventSource + ), globalDialogService, authService ); diff --git a/packages/frontend/core/src/desktop/pages/workspace/chat-panel-utils.ts b/packages/frontend/core/src/desktop/pages/workspace/chat-panel-utils.ts deleted file mode 100644 index 13c692c7e2..0000000000 --- a/packages/frontend/core/src/desktop/pages/workspace/chat-panel-utils.ts +++ /dev/null @@ -1,142 +0,0 @@ -import { WorkspaceLocalState } from '@affine/core/modules/workspace'; -import type { I18nInstance } from '@affine/i18n'; -import type { NotificationService } from '@blocksuite/affine/shared/services'; -import { useService } from '@toeverything/infra'; -import { - type Dispatch, - type SetStateAction, - useCallback, - useEffect, - useRef, - useState, -} from 'react'; - -const AI_CHAT_OPEN_TABS_KEY = 'aiChatOpenTabs'; - -// Pass `null` for `loadSession` to defer hydration until a real loader is ready. -export function useAIChatOpenTabs( - loadSession: ((sessionId: string) => Promise) | null -): { - openTabs: T[]; - setOpenTabs: Dispatch>; -} { - const workspaceLocalState = useService(WorkspaceLocalState); - const [openTabs, setOpenTabsState] = useState([]); - // Ref so persist gate isn't subject to React state-batch ordering. - const hydratedRef = useRef(false); - - useEffect(() => { - if (!loadSession) return; - hydratedRef.current = false; - setOpenTabsState([]); - - const ids = workspaceLocalState.get(AI_CHAT_OPEN_TABS_KEY) ?? []; - if (!ids.length) { - hydratedRef.current = true; - return; - } - - let cancelled = false; - Promise.all(ids.map(id => loadSession(id).catch(() => null))) - .then(results => { - if (cancelled) return; - const valid = (results as (T | null | undefined)[]).filter( - (entry): entry is T => !!entry && !!entry.sessionId - ); - if (valid.length) setOpenTabsState(valid); - hydratedRef.current = true; - }) - .catch(error => { - console.error(error); - if (!cancelled) hydratedRef.current = true; - }); - - return () => { - cancelled = true; - }; - }, [loadSession, workspaceLocalState]); - - const setOpenTabs = useCallback>>( - updater => { - setOpenTabsState(prev => { - const next = - typeof updater === 'function' - ? (updater as (p: T[]) => T[])(prev) - : updater; - if (hydratedRef.current) { - if (next.length) { - workspaceLocalState.set( - AI_CHAT_OPEN_TABS_KEY, - next.map(tab => tab.sessionId) - ); - } else { - workspaceLocalState.del(AI_CHAT_OPEN_TABS_KEY); - } - } - return next; - }); - }, - [workspaceLocalState] - ); - - return { openTabs, setOpenTabs }; -} - -export type SessionDeleteCleanupFn = ( - session: BlockSuitePresets.AIRecentSession -) => Promise; - -export type CreateSessionDeleteHandlerOptions = { - t: I18nInstance; - notificationService: NotificationService; - cleanupSession: SessionDeleteCleanupFn; - canDeleteSession?: (session: BlockSuitePresets.AIRecentSession) => boolean; - isActiveSession?: (session: BlockSuitePresets.AIRecentSession) => boolean; - onActiveSessionDeleted?: () => void; -}; - -export function createSessionDeleteHandler({ - t, - notificationService, - cleanupSession, - canDeleteSession, - isActiveSession, - onActiveSessionDeleted, -}: CreateSessionDeleteHandlerOptions) { - return async (sessionToDelete: BlockSuitePresets.AIRecentSession) => { - if (canDeleteSession && !canDeleteSession(sessionToDelete)) { - notificationService.toast( - t['com.affine.ai.chat-panel.session.delete.toast.failed']() - ); - return; - } - - const confirm = await notificationService.confirm({ - title: t['com.affine.ai.chat-panel.session.delete.confirm.title'](), - message: t['com.affine.ai.chat-panel.session.delete.confirm.message'](), - confirmText: t['Delete'](), - cancelText: t['Cancel'](), - }); - - if (!confirm) { - return; - } - - try { - await cleanupSession(sessionToDelete); - notificationService.toast( - t['com.affine.ai.chat-panel.session.delete.toast.success']() - ); - } catch (error) { - console.error(error); - notificationService.toast( - t['com.affine.ai.chat-panel.session.delete.toast.failed']() - ); - return; - } - - if (isActiveSession?.(sessionToDelete)) { - onActiveSessionDeleted?.(); - } - }; -} diff --git a/packages/frontend/core/src/desktop/pages/workspace/chat/index.tsx b/packages/frontend/core/src/desktop/pages/workspace/chat/index.tsx index 0f1e909335..aefd79c4b2 100644 --- a/packages/frontend/core/src/desktop/pages/workspace/chat/index.tsx +++ b/packages/frontend/core/src/desktop/pages/workspace/chat/index.tsx @@ -1,17 +1,17 @@ import { observeResize, useConfirmModal } from '@affine/component'; -import { CopilotClient } from '@affine/core/blocksuite/ai'; import { - AIChatContent, - type ChatContextValue, -} from '@affine/core/blocksuite/ai/components/ai-chat-content'; -import type { ChatStatus } from '@affine/core/blocksuite/ai/components/ai-chat-messages'; -import type { AIChatToolbar } from '@affine/core/blocksuite/ai/components/ai-chat-toolbar'; + AIChatRuntime, + createAIRequestService, + useAIChatElement, + useAIChatRuntime, + WorkspaceAIChatSessionStrategy, +} from '@affine/core/blocksuite/ai'; +import { AIChatContent } from '@affine/core/blocksuite/ai/components/ai-chat-content'; import { AIChatTabs, + AIChatToolbar, configureAIChatToolbar, - getOrCreateAIChatToolbar, } from '@affine/core/blocksuite/ai/components/ai-chat-toolbar'; -import type { PromptKey } from '@affine/core/blocksuite/ai/provider/prompt'; import { getViewManager } from '@affine/core/blocksuite/manager/view'; import { NotificationServiceImpl } from '@affine/core/blocksuite/view-extensions/editor-view/notification-service'; import { useAIChatConfig } from '@affine/core/components/hooks/affine/use-ai-chat-config'; @@ -50,20 +50,18 @@ import { useFramework, useService } from '@toeverything/infra'; import { nanoid } from 'nanoid'; import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import { - createSessionDeleteHandler, - useAIChatOpenTabs, -} from '../chat-panel-utils'; import * as styles from './index.css'; -type CopilotSession = Awaited>; - -function useCopilotClient() { +function useAIRequestService() { const graphqlService = useService(GraphQLService); const eventSourceService = useService(EventSourceService); return useMemo( - () => new CopilotClient(graphqlService.gql, eventSourceService.eventSource), + () => + createAIRequestService( + graphqlService.gql, + eventSourceService.eventSource + ), [graphqlService, eventSourceService] ); } @@ -95,164 +93,32 @@ export const Component = () => { const framework = useFramework(); const [isBodyProvided, setIsBodyProvided] = useState(false); const [isHeaderProvided, setIsHeaderProvided] = useState(false); - const [chatContent, setChatContent] = useState(null); - const [chatTool, setChatTool] = useState(null); - const [chatTabs, setChatTabs] = useState(null); - const [currentSession, setCurrentSession] = useState( - null - ); - const [status, setStatus] = useState('idle'); - const [isTogglingPin, setIsTogglingPin] = useState(false); - const [isOpeningSession, setIsOpeningSession] = useState(false); - const hasRestoredPinnedSessionRef = useRef(false); const chatContainerRef = useRef(null); const chatToolContainerRef = useRef(null); const chatTabsContainerRef = useRef(null); const widthSignalRef = useRef>(signal(0)); - const client = useCopilotClient(); + const requestService = useAIRequestService(); const workbench = useService(WorkbenchService).workbench; const workspaceId = useService(WorkspaceService).workspace.id; - const loadSession = useCallback( - (sessionId: string) => client.getSession(workspaceId, sessionId), - [client, workspaceId] + const runtime = useMemo( + () => + new AIChatRuntime({ + request: requestService, + scope: { kind: 'workspace', workspaceId }, + strategy: new WorkspaceAIChatSessionStrategy(), + }), + [requestService, workspaceId] ); - const { openTabs, setOpenTabs } = useAIChatOpenTabs(loadSession); - - useEffect(() => { - hasRestoredPinnedSessionRef.current = false; - }, [workspaceId]); - + const snapshot = useAIChatRuntime(runtime); + const session = + snapshot?.sessions.find( + session => session.sessionId === snapshot.activeSessionId + ) ?? null; const { docDisplayConfig, searchMenuConfig, reasoningConfig } = useAIChatConfig(); - const createSession = useCallback( - async (options: Partial = {}) => { - if (currentSession) { - return currentSession; - } - const session = await client.createSessionWithHistory({ - workspaceId, - promptName: 'Chat With AFFiNE AI' satisfies PromptKey, - reuseLatestChat: false, - ...options, - }); - setCurrentSession(session); - return session; - }, - [client, currentSession, workspaceId] - ); - - const togglePin = useCallback(async () => { - if (isTogglingPin) return; - setIsTogglingPin(true); - try { - const pinned = !currentSession?.pinned; - if (!currentSession) { - await createSession({ pinned }); - } else { - await client.updateSession({ - sessionId: currentSession.sessionId, - pinned, - }); - // retrieve the latest session and update the state - const session = await client.getSession( - workspaceId, - currentSession.sessionId - ); - setCurrentSession(session); - } - } finally { - setIsTogglingPin(false); - } - }, [client, createSession, currentSession, isTogglingPin, workspaceId]); - - // remove the old content to trigger re-mount - // to avoid infinitely load and mount, should not make `chatContent` as dependency - const reMountChatContent = useCallback(() => { - setChatContent(prev => { - prev?.remove(); - return null; - }); - }, []); - - const createFreshSession = useCallback(async () => { - if (isOpeningSession) { - return; - } - setIsOpeningSession(true); - try { - setCurrentSession(null); - reMountChatContent(); - const session = await client.createSessionWithHistory({ - workspaceId, - promptName: 'Chat With AFFiNE AI' satisfies PromptKey, - reuseLatestChat: false, - }); - setCurrentSession(session); - } catch (error) { - console.error(error); - } finally { - setIsOpeningSession(false); - } - }, [client, isOpeningSession, reMountChatContent, workspaceId]); - - const onOpenSession = useCallback( - async (sessionId: string) => { - if (isOpeningSession || currentSession?.sessionId === sessionId) return; - setIsOpeningSession(true); - try { - const session = await client.getSession(workspaceId, sessionId); - if (!session) { - // Drop stale tab if session no longer exists. - setOpenTabs(prev => prev.filter(tab => tab.sessionId !== sessionId)); - return; - } - setCurrentSession(session); - reMountChatContent(); - chatTool?.closeHistoryMenu(); - } catch (error) { - console.error(error); - } finally { - setIsOpeningSession(false); - } - }, - [ - chatTool, - client, - currentSession?.sessionId, - isOpeningSession, - reMountChatContent, - setOpenTabs, - workspaceId, - ] - ); - - const closeTab = useCallback( - (sessionId: string) => { - let fallback: NonNullable | undefined; - setOpenTabs(prev => { - const idx = prev.findIndex(tab => tab.sessionId === sessionId); - if (idx === -1) return prev; - const next = prev.filter(tab => tab.sessionId !== sessionId); - fallback = next[idx] ?? next[idx - 1] ?? next[0]; - return next; - }); - if (currentSession?.sessionId !== sessionId) return; - if (fallback) { - onOpenSession(fallback.sessionId).catch(console.error); - } else { - createFreshSession().catch(console.error); - } - }, - [createFreshSession, currentSession?.sessionId, onOpenSession, setOpenTabs] - ); - - const onContextChange = useCallback((context: Partial) => { - setStatus(context.status ?? 'idle'); - }, []); - const onOpenDoc = useCallback( (docId: string) => { workbench.openDoc(docId, { at: 'active' }); @@ -283,143 +149,86 @@ export const Component = () => { const mockStd = useMockStd(); const handleAISubscribe = useAISubscribe(); - const deleteSession = useMemo( - () => - createSessionDeleteHandler({ - t, - notificationService, - cleanupSession: async sessionToDelete => { - await client.cleanupSessions({ - workspaceId: sessionToDelete.workspaceId, - docId: sessionToDelete.docId || undefined, - sessionIds: [sessionToDelete.sessionId], - }); - }, - isActiveSession: sessionToDelete => - sessionToDelete.sessionId === currentSession?.sessionId, - onActiveSessionDeleted: () => { - setCurrentSession(null); - reMountChatContent(); - }, - }), - [ - client, - currentSession?.sessionId, - notificationService, - reMountChatContent, - t, - ] + const deleteSession = useCallback( + async (sessionToDelete: BlockSuitePresets.AIRecentSession) => { + const confirm = await notificationService.confirm({ + title: t['com.affine.ai.chat-panel.session.delete.confirm.title'](), + message: t['com.affine.ai.chat-panel.session.delete.confirm.message'](), + confirmText: t['Delete'](), + cancelText: t['Cancel'](), + }); + if (!confirm) return; + await runtime.dispatch({ + type: 'deleteSession', + sessionId: sessionToDelete.sessionId, + }); + notificationService.toast( + t['com.affine.ai.chat-panel.session.delete.toast.success'](), + {} + ); + }, + [notificationService, runtime, t] ); - // init or update ai-chat-content - useEffect(() => { - if (!isBodyProvided) { - return; - } - - let content = chatContent; - - if (!content) { - content = new AIChatContent(); - } - - content.session = currentSession; - content.workspaceId = workspaceId; - content.extensions = specs; - content.host = mockStd?.host; - content.docDisplayConfig = docDisplayConfig; - content.searchMenuConfig = searchMenuConfig; - content.reasoningConfig = reasoningConfig; - content.onContextChange = onContextChange; - content.affineFeatureFlagService = framework.get(FeatureFlagService); - content.affineWorkspaceDialogService = framework.get( - WorkspaceDialogService - ); - content.peekViewService = framework.get(PeekViewService); - content.affineThemeService = framework.get(AppThemeService); - content.notificationService = notificationService; - content.aiDraftService = framework.get(AIDraftService); - content.aiToolsConfigService = framework.get(AIToolsConfigService); - content.serverService = framework.get(ServerService); - content.subscriptionService = framework.get(SubscriptionService); - content.aiModelService = framework.get(AIModelService); - content.onAISubscribe = handleAISubscribe; - - content.createSession = createSession; - content.onOpenDoc = onOpenDoc; - - if (!chatContent) { - // initial values that won't change + useAIChatElement({ + containerRef: chatContainerRef, + selector: 'ai-chat-content', + enabled: isBodyProvided, + createElement: () => new AIChatContent(), + configureElement: content => { + content.session = session; + content.runtime = runtime; + content.runtimeSnapshot = snapshot; + content.workspaceId = workspaceId; + content.extensions = specs; + content.host = mockStd?.host; + content.docDisplayConfig = docDisplayConfig; + content.searchMenuConfig = searchMenuConfig; + content.reasoningConfig = reasoningConfig; + content.affineFeatureFlagService = framework.get(FeatureFlagService); + content.affineWorkspaceDialogService = framework.get( + WorkspaceDialogService + ); + content.peekViewService = framework.get(PeekViewService); + content.affineThemeService = framework.get(AppThemeService); + content.notificationService = notificationService; + content.aiDraftService = framework.get(AIDraftService); + content.aiToolsConfigService = framework.get(AIToolsConfigService); + content.serverService = framework.get(ServerService); + content.subscriptionService = framework.get(SubscriptionService); + content.aiModelService = framework.get(AIModelService); + content.onAISubscribe = handleAISubscribe; + content.onOpenDoc = onOpenDoc; + }, + onElementReady: content => { content.independentMode = true; content.onboardingOffsetY = -100; - chatContainerRef.current?.append(content); - setChatContent(content); - } - }, [ - chatContent, - createSession, - currentSession, - docDisplayConfig, - framework, - isBodyProvided, - mockStd, - reasoningConfig, - searchMenuConfig, - workspaceId, - onContextChange, - notificationService, - specs, - onOpenDoc, - handleAISubscribe, - ]); + }, + }); - // init or update header ai-chat-toolbar - useEffect(() => { - if (!isHeaderProvided || !chatToolContainerRef.current) { - return; - } - const tool = getOrCreateAIChatToolbar(chatTool); - configureAIChatToolbar(tool, { - session: currentSession, - workspaceId, - status, - docDisplayConfig, - notificationService, - onOpenSession: sessionId => { - onOpenSession(sessionId).catch(console.error); - }, - onNewSession: () => { - createFreshSession().catch(console.error); - }, - onTogglePin: togglePin, - onOpenDoc: (docId: string, sessionId: string) => { - onOpenSessionDoc(docId, sessionId); - }, - onSessionDelete: (sessionToDelete: BlockSuitePresets.AIRecentSession) => { - deleteSession(sessionToDelete).catch(console.error); - }, - }); - - // initial props - if (!chatTool) { - // mount - chatToolContainerRef.current.append(tool); - setChatTool(tool); - } - }, [ - chatTool, - currentSession, - docDisplayConfig, - isHeaderProvided, - onOpenSession, - togglePin, - workspaceId, - onOpenSessionDoc, - deleteSession, - status, - notificationService, - createFreshSession, - ]); + useAIChatElement({ + containerRef: chatToolContainerRef, + selector: 'ai-chat-toolbar', + enabled: isHeaderProvided, + createElement: () => new AIChatToolbar(), + configureElement: tool => { + configureAIChatToolbar(tool, { + session, + runtime, + runtimeSnapshot: snapshot ?? runtime.getSnapshot(), + docDisplayConfig, + notificationService, + onOpenDoc: (docId: string, sessionId: string) => { + onOpenSessionDoc(docId, sessionId); + }, + onSessionDelete: ( + sessionToDelete: BlockSuitePresets.AIRecentSession + ) => { + deleteSession(sessionToDelete).catch(console.error); + }, + }); + }, + }); useEffect(() => { const refNodeSlots = mockStd?.getOptional(RefNodeSlotsProvider); @@ -437,87 +246,16 @@ export const Component = () => { return () => sub.unsubscribe(); }, [framework, mockStd]); - useEffect(() => { - if (!currentSession?.sessionId) return; - setOpenTabs(prev => { - const existing = prev.findIndex( - tab => tab.sessionId === currentSession.sessionId - ); - if (existing !== -1) { - if (prev[existing] === currentSession) return prev; - const next = prev.slice(); - next[existing] = currentSession; - return next; - } - return [...prev, currentSession]; - }); - }, [currentSession, setOpenTabs]); - - useEffect(() => { - if (!chatTabsContainerRef.current) return; - let tabs = chatTabs; - if (!tabs) { - tabs = new AIChatTabs(); - chatTabsContainerRef.current.append(tabs); - setChatTabs(tabs); - } - tabs.sessions = openTabs; - tabs.activeSessionId = currentSession?.sessionId; - tabs.onSelectTab = (sessionId: string) => { - onOpenSession(sessionId).catch(console.error); - }; - tabs.onCloseTab = (sessionId: string) => { - closeTab(sessionId); - }; - }, [chatTabs, closeTab, currentSession?.sessionId, onOpenSession, openTabs]); - - // restore pinned session - useEffect(() => { - if (hasRestoredPinnedSessionRef.current || currentSession) return; - hasRestoredPinnedSessionRef.current = true; - - const controller = new AbortController(); - const loadPinnedSession = async () => { - try { - const sessions = await client.getSessions( - workspaceId, - {}, - undefined, - { pinned: true, limit: 1 }, - controller.signal - ); - if (controller.signal.aborted || !Array.isArray(sessions)) { - return; - } - const pinnedSession = sessions[0]; - if (!pinnedSession) { - return; - } - - let shouldRemount = false; - setCurrentSession(prev => { - if (prev) return prev; - shouldRemount = true; - return pinnedSession; - }); - if (shouldRemount) reMountChatContent(); - } catch (error) { - if (controller.signal.aborted) { - return; - } - console.error(error); - } - }; - loadPinnedSession().catch(error => { - if (controller.signal.aborted) return; - console.error(error); - }); - - // abort the request - return () => { - controller.abort(); - }; - }, [client, currentSession, reMountChatContent, workspaceId]); + useAIChatElement({ + containerRef: chatTabsContainerRef, + selector: 'ai-chat-tabs', + enabled: true, + createElement: () => new AIChatTabs(), + configureElement: tabs => { + tabs.runtime = runtime; + tabs.runtimeSnapshot = snapshot; + }, + }); const onChatContainerRef = useCallback((node: HTMLDivElement) => { if (node) { diff --git a/packages/frontend/core/src/desktop/pages/workspace/detail-page/detail-page.tsx b/packages/frontend/core/src/desktop/pages/workspace/detail-page/detail-page.tsx index 66852275ef..f70bd5ca84 100644 --- a/packages/frontend/core/src/desktop/pages/workspace/detail-page/detail-page.tsx +++ b/packages/frontend/core/src/desktop/pages/workspace/detail-page/detail-page.tsx @@ -1,7 +1,6 @@ import { Scrollable } from '@affine/component'; import { PageDetailLoading } from '@affine/component/page-detail-skeleton'; -import type { AIChatParams } from '@affine/core/blocksuite/ai'; -import { AIProvider } from '@affine/core/blocksuite/ai'; +import { AIAppEvents, type AIChatParams } from '@affine/core/blocksuite/ai'; import type { AffineEditorContainer } from '@affine/core/blocksuite/block-suite-editor'; import { EditorOutlineViewer } from '@affine/core/blocksuite/outline-viewer'; import { AffineErrorBoundary } from '@affine/core/components/affine/affine-error-boundary'; @@ -145,12 +144,8 @@ const DetailPageImpl = memo(function DetailPageImpl() { workbench.openSidebar(); view.activeSidebarTab('chat'); }; - disposables.push( - AIProvider.slots.requestOpenWithChat.subscribe(openHandler) - ); - disposables.push( - AIProvider.slots.requestSendWithChat.subscribe(openHandler) - ); + disposables.push(AIAppEvents.requestOpenWithChat.subscribe(openHandler)); + disposables.push(AIAppEvents.requestSendWithChat.subscribe(openHandler)); return () => disposables.forEach(d => d.unsubscribe()); }, [activeSidebarTab, view, workbench]); @@ -378,7 +373,7 @@ const DetailPageImpl = memo(function DetailPageImpl() { icon={} unmountOnInactive={false} > - + )} diff --git a/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat-panel-session.spec.ts b/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat-panel-session.spec.ts deleted file mode 100644 index d37ade8aa4..0000000000 --- a/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat-panel-session.spec.ts +++ /dev/null @@ -1,341 +0,0 @@ -/* eslint-disable rxjs/finnish */ -import { type CopilotChatHistoryFragment } from '@affine/graphql'; -import { describe, expect, test, vi } from 'vitest'; - -import { - canCreateNewDocPanelSession, - filterDocPanelTabs, - getChatContentKey, - hasSessionMessages, - isSessionAvailableInDocPanel, - resolveInitialSession, - type SessionService, - shouldResetChatPanelOnUserInfoChange, - type WorkbenchLike, -} from './chat-panel-session'; - -const createWorkbench = (search: string) => { - const updateQueryString = vi.fn(); - const workbench = { - location$: { value: { search } }, - activeView$: { value: { updateQueryString } }, - } satisfies WorkbenchLike; - - return { workbench, updateQueryString }; -}; - -const doc = { id: 'doc-1', workspace: { id: 'ws-1' } }; - -describe('getChatContentKey', () => { - const cases = [ - { - name: 'uses doc id before a session is created', - input: { - docId: 'doc-1', - hasPinned: false, - session: null, - }, - expected: 'doc-1', - }, - { - name: 'keeps a new empty doc session on the doc key', - input: { - docId: 'doc-2', - hasPinned: false, - previousSessionDocId: 'doc-1', - previousSessionId: 'session-1', - session: { - sessionId: 'session-2', - docId: 'doc-2', - messages: [], - }, - }, - expected: 'doc-2', - }, - { - name: 'uses session id for a session with history', - input: { - docId: 'doc-1', - hasPinned: false, - session: { - sessionId: 'session-1', - docId: 'doc-1', - messages: [{ id: 'message-1' }], - }, - }, - expected: 'session-1', - }, - { - name: 'uses session id for a pinned session', - input: { - docId: 'doc-1', - hasPinned: true, - session: { - sessionId: 'session-1', - docId: 'doc-1', - messages: [], - }, - }, - expected: 'session-1', - }, - { - name: 'uses session id for same-doc session switch', - input: { - docId: 'doc-1', - hasPinned: false, - previousSessionDocId: 'doc-1', - previousSessionId: 'session-1', - session: { - sessionId: 'session-2', - docId: 'doc-1', - messages: [], - }, - }, - expected: 'session-2', - }, - { - name: 'keeps generating draft session on the doc key', - input: { - docId: 'doc-1', - hasPinned: false, - isGenerating: true, - previousSessionDocId: 'doc-1', - previousSessionId: 'session-1', - session: { - sessionId: 'session-2', - docId: 'doc-1', - messages: [], - }, - }, - expected: 'doc-1', - }, - ] satisfies { - name: string; - input: Parameters[0]; - expected: string; - }[]; - - test.each(cases)('$name', ({ input, expected }) => { - expect(getChatContentKey(input)).toBe(expected); - }); -}); - -describe('shouldResetChatPanelOnUserInfoChange', () => { - const cases = [ - { - name: 'ignores the initial user info emission', - input: { - previousUserId: undefined, - nextUserId: 'user-1', - }, - expected: false, - }, - { - name: 'ignores same-user refreshes', - input: { - previousUserId: 'user-1', - nextUserId: 'user-1', - }, - expected: false, - }, - { - name: 'resets when the effective user changes', - input: { - previousUserId: 'user-1', - nextUserId: 'user-2', - }, - expected: true, - }, - { - name: 'resets when the effective user signs out', - input: { - previousUserId: 'user-1', - nextUserId: null, - }, - expected: true, - }, - ] satisfies { - name: string; - input: Parameters[0]; - expected: boolean; - }[]; - - test.each(cases)('$name', ({ input, expected }) => { - expect(shouldResetChatPanelOnUserInfoChange(input)).toBe(expected); - }); -}); - -describe('doc panel tabs', () => { - const sessions = [ - { sessionId: 'current-doc-session', docId: 'doc-1' }, - { sessionId: 'workspace-session', docId: null }, - { sessionId: 'other-doc-session', docId: 'doc-2' }, - ]; - - test('allows only current doc or workspace sessions', () => { - expect(filterDocPanelTabs(sessions, 'doc-1')).toEqual([ - sessions[0], - sessions[1], - ]); - }); - - test('rejects other document sessions', () => { - expect(isSessionAvailableInDocPanel(sessions[2], 'doc-1')).toBe(false); - }); -}); - -describe('new session guard', () => { - test('allows a new session only after the current chat has messages', () => { - expect( - canCreateNewDocPanelSession({ - hasContextMessages: false, - session: { messages: [] }, - status: 'idle', - }) - ).toBe(false); - expect( - canCreateNewDocPanelSession({ - hasContextMessages: true, - session: { messages: [] }, - status: 'idle', - }) - ).toBe(true); - expect(hasSessionMessages({ messages: [{ id: 'message-1' }] })).toBe(true); - }); - - test('does not allow a new session while generating', () => { - expect( - canCreateNewDocPanelSession({ - hasContextMessages: true, - session: null, - status: 'loading', - }) - ).toBe(false); - }); -}); - -test('returns undefined without session service or doc', async () => { - await expect( - resolveInitialSession({ sessionService: null, doc, workbench: null }) - ).resolves.toBeUndefined(); - await expect( - resolveInitialSession({ - sessionService: { - getSessions: vi.fn(), - getSession: vi.fn(), - }, - doc: null, - workbench: null, - }) - ).resolves.toBeUndefined(); -}); - -describe('resolveInitialSession', () => { - test('prefers pinned session and clears sessionId from url', async () => { - const pinnedSession = { - sessionId: 'pinned-session', - pinned: true, - } as CopilotChatHistoryFragment; - - const sessionService: SessionService = { - getSessions: vi.fn().mockResolvedValueOnce([pinnedSession]), - getSession: vi.fn(), - }; - - const { workbench, updateQueryString } = createWorkbench( - '?sessionId=from-url' - ); - - const result = await resolveInitialSession({ - sessionService, - doc, - workbench, - }); - - expect(result).toBe(pinnedSession); - expect(updateQueryString).toHaveBeenCalledWith( - { sessionId: undefined }, - { replace: true } - ); - expect(sessionService.getSession).not.toHaveBeenCalled(); - }); - - test('loads session from url when no pinned session', async () => { - const sessionFromUrl = { - sessionId: 'url-session', - pinned: false, - } as CopilotChatHistoryFragment; - - const sessionService: SessionService = { - getSessions: vi.fn().mockResolvedValueOnce([]), - getSession: vi.fn().mockResolvedValueOnce(sessionFromUrl), - }; - - const { workbench, updateQueryString } = createWorkbench( - '?sessionId=url-session' - ); - - const result = await resolveInitialSession({ - sessionService, - doc, - workbench, - }); - - expect(result).toBe(sessionFromUrl); - expect(sessionService.getSession).toHaveBeenCalledWith( - doc.workspace.id, - 'url-session' - ); - expect(updateQueryString).toHaveBeenCalledWith( - { sessionId: undefined }, - { replace: true } - ); - }); - - test('falls back to latest doc session', async () => { - const docSession = { - sessionId: 'doc-session', - pinned: false, - } as CopilotChatHistoryFragment; - - const sessionService: SessionService = { - getSessions: vi - .fn() - .mockResolvedValueOnce([]) - .mockResolvedValueOnce([docSession]), - getSession: vi.fn(), - }; - - const { workbench } = createWorkbench(''); - - const result = await resolveInitialSession({ - sessionService, - doc, - workbench, - }); - - expect(result).toBe(docSession); - expect(sessionService.getSessions).toHaveBeenCalledWith( - doc.workspace.id, - doc.id, - { action: false, fork: false, limit: 1 } - ); - }); - - test('returns null when url session is missing', async () => { - const sessionService: SessionService = { - getSessions: vi.fn().mockResolvedValueOnce([]), - getSession: vi.fn().mockResolvedValueOnce(null), - }; - - const { workbench } = createWorkbench('?sessionId=missing'); - - const result = await resolveInitialSession({ - sessionService, - doc, - workbench, - }); - - expect(result).toBeNull(); - }); -}); diff --git a/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat-panel-session.ts b/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat-panel-session.ts deleted file mode 100644 index 4875a5666b..0000000000 --- a/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat-panel-session.ts +++ /dev/null @@ -1,208 +0,0 @@ -/* eslint-disable rxjs/finnish */ -import type { CopilotChatHistoryFragment } from '@affine/graphql'; - -type SessionListOptions = { - pinned?: boolean; - action?: boolean; - fork?: boolean; - limit?: number; -}; - -export interface SessionService { - getSessions: ( - workspaceId: string, - docId?: string, - options?: SessionListOptions - ) => Promise; - getSession: ( - workspaceId: string, - sessionId: string - ) => Promise; -} - -export interface WorkbenchLike { - location$: { - value: { - search: string; - }; - }; - activeView$: { - value: { - updateQueryString: ( - patch: Record, - options?: { replace?: boolean } - ) => void; - }; - }; -} - -export interface DocLike { - id: string; - workspace: { - id: string; - }; -} - -interface ChatContentKeySession { - sessionId?: string | null; - docId?: string | null; - messages?: readonly unknown[] | null; -} - -type TabSession = { - sessionId: string; - docId?: string | null; - messages?: readonly unknown[] | null; -}; - -export const shouldResetChatPanelOnUserInfoChange = ({ - previousUserId, - nextUserId, -}: { - previousUserId?: string | null; - nextUserId?: string | null; -}) => { - return previousUserId !== undefined && previousUserId !== nextUserId; -}; - -export const getChatContentKey = ({ - docId, - hasPinned, - isGenerating, - previousSessionDocId, - previousSessionId, - session, -}: { - docId?: string | null; - hasPinned: boolean; - isGenerating?: boolean; - previousSessionDocId?: string | null; - previousSessionId?: string | null; - session?: ChatContentKeySession | null; -}) => { - const fallbackKey = docId ?? 'chat-panel'; - const sessionId = session?.sessionId; - if (!sessionId) { - return fallbackKey; - } - - const sessionDocId = session.docId ?? docId ?? null; - const hasSessionHistory = !!session.messages?.length; - const shouldPreserveTransientMessages = isGenerating && !hasSessionHistory; - const sessionSwitchedWithinDoc = !!( - previousSessionId && - previousSessionId !== sessionId && - previousSessionDocId && - sessionDocId && - previousSessionDocId === sessionDocId && - sessionDocId === docId - ); - - return hasPinned || - hasSessionHistory || - (sessionSwitchedWithinDoc && !shouldPreserveTransientMessages) - ? sessionId - : fallbackKey; -}; - -export const isSessionAvailableInDocPanel = ( - session: TabSession, - docId?: string | null -) => { - return !session.docId || session.docId === docId; -}; - -export const filterDocPanelTabs = ( - sessions: T[], - docId?: string | null -) => { - return sessions.filter(session => - isSessionAvailableInDocPanel(session, docId) - ); -}; - -export const hasSessionMessages = ( - session?: Pick | null -) => { - return !!session?.messages?.length; -}; - -export const canCreateNewDocPanelSession = ({ - hasContextMessages, - session, - status, -}: { - hasContextMessages: boolean; - session?: Pick | null; - status?: string | null; -}) => { - return ( - (hasContextMessages || hasSessionMessages(session)) && - status !== 'loading' && - status !== 'transmitting' - ); -}; - -export const getSessionIdFromUrl = (workbench?: WorkbenchLike | null) => { - if (!workbench) { - return undefined; - } - const searchParams = new URLSearchParams(workbench.location$.value.search); - const sessionId = searchParams.get('sessionId'); - if (sessionId) { - workbench.activeView$.value.updateQueryString( - { sessionId: undefined }, - { replace: true } - ); - } - return sessionId ?? undefined; -}; - -export const resolveInitialSession = async ({ - sessionService, - doc, - workbench, -}: { - sessionService?: SessionService | null; - doc?: DocLike | null; - workbench?: WorkbenchLike | null; -}): Promise => { - if (!sessionService || !doc) { - return undefined; - } - - const sessionId = getSessionIdFromUrl(workbench); - - const pinSessions = await sessionService.getSessions( - doc.workspace.id, - undefined, - { - pinned: true, - limit: 1, - } - ); - - if (Array.isArray(pinSessions) && pinSessions[0]) { - return pinSessions[0]; - } - - if (sessionId) { - const session = await sessionService.getSession( - doc.workspace.id, - sessionId - ); - return session ?? null; - } - - const docSessions = await sessionService.getSessions( - doc.workspace.id, - doc.id, - { - action: false, - fork: false, - limit: 1, - } - ); - - return docSessions?.[0] ?? null; -}; diff --git a/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat.tsx b/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat.tsx index 7f0d6b4218..1c22cb0b15 100644 --- a/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat.tsx +++ b/packages/frontend/core/src/desktop/pages/workspace/detail-page/tabs/chat.tsx @@ -1,16 +1,18 @@ import { useConfirmModal } from '@affine/component'; -import { AIProvider } from '@affine/core/blocksuite/ai'; -import type { AppSidebarConfig } from '@affine/core/blocksuite/ai/chat-panel/chat-config'; import { - AIChatContent, - type ChatContextValue, -} from '@affine/core/blocksuite/ai/components/ai-chat-content'; -import type { ChatStatus } from '@affine/core/blocksuite/ai/components/ai-chat-messages'; -import type { AIChatToolbar } from '@affine/core/blocksuite/ai/components/ai-chat-toolbar'; + AIAppEvents, + AIChatRuntime, + createAIRequestService, + DocAIChatSessionStrategy, + useAIChatElement, + useAIChatRuntime, +} from '@affine/core/blocksuite/ai'; +import type { AppSidebarConfig } from '@affine/core/blocksuite/ai/chat-panel/chat-config'; +import { AIChatContent } from '@affine/core/blocksuite/ai/components/ai-chat-content'; import { AIChatTabs, + AIChatToolbar, configureAIChatToolbar, - getOrCreateAIChatToolbar, } from '@affine/core/blocksuite/ai/components/ai-chat-toolbar'; import { createPlaygroundModal } from '@affine/core/blocksuite/ai/components/playground/modal'; import { registerAIAppEffects } from '@affine/core/blocksuite/ai/effects/app'; @@ -24,52 +26,55 @@ import { AIToolsConfigService, } from '@affine/core/modules/ai-button'; import { AIModelService } from '@affine/core/modules/ai-button/services/models'; -import { ServerService, SubscriptionService } from '@affine/core/modules/cloud'; +import { + EventSourceService, + GraphQLService, + ServerService, + SubscriptionService, +} from '@affine/core/modules/cloud'; import { WorkspaceDialogService } from '@affine/core/modules/dialogs'; import { useSignalValue } from '@affine/core/modules/doc-info/utils'; import { FeatureFlagService } from '@affine/core/modules/feature-flag'; import { PeekViewService } from '@affine/core/modules/peek-view'; import { AppThemeService } from '@affine/core/modules/theme'; import { WorkbenchService } from '@affine/core/modules/workbench'; -import type { - ContextEmbedStatus, - CopilotChatHistoryFragment, - UpdateChatSessionInput, -} from '@affine/graphql'; import { useI18n } from '@affine/i18n'; import { RefNodeSlotsProvider } from '@blocksuite/affine/inlines/reference'; import { DocModeProvider } from '@blocksuite/affine/shared/services'; import { createSignalFromObservable } from '@blocksuite/affine/shared/utils'; +import type { Store } from '@blocksuite/affine/store'; import { CenterPeekIcon, Logo1Icon } from '@blocksuite/icons/rc'; import type { Signal } from '@preact/signals-core'; import { useFramework, useService } from '@toeverything/infra'; import { html } from 'lit'; import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import { - createSessionDeleteHandler, - useAIChatOpenTabs, -} from '../../chat-panel-utils'; import * as styles from './chat.css'; -import { - canCreateNewDocPanelSession, - filterDocPanelTabs, - getChatContentKey, - isSessionAvailableInDocPanel, - resolveInitialSession, - shouldResetChatPanelOnUserInfoChange, - type WorkbenchLike, -} from './chat-panel-session'; registerAIAppEffects(); +const shouldResetChatPanelOnUserInfoChange = ({ + previousUserId, + nextUserId, +}: { + previousUserId?: string | null; + nextUserId?: string | null; +}) => previousUserId !== undefined && previousUserId !== nextUserId; + export interface SidebarTabProps { editor: AffineEditorContainer | null; + doc: Store; onLoad?: ((component: HTMLElement) => void) | null; } -export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => { +export const EditorChatPanel = ({ + editor, + doc: fallbackDoc, + onLoad, +}: SidebarTabProps) => { const framework = useFramework(); + const graphqlService = useService(GraphQLService); + const eventSourceService = useService(EventSourceService); const workbench = useService(WorkbenchService).workbench; const t = useI18n(); @@ -89,81 +94,57 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => { } = useAIChatConfig(); const playgroundVisible = useSignalValue(playgroundConfig.visible) ?? false; - const [session, setSession] = useState< - CopilotChatHistoryFragment | null | undefined - >(undefined); - const [embeddingProgress, setEmbeddingProgress] = useState<[number, number]>([ - 0, 0, - ]); - const [status, setStatus] = useState('idle'); - const [hasPinned, setHasPinned] = useState(false); - - const [chatContent, setChatContent] = useState(null); - const [chatToolbar, setChatToolbar] = useState(null); - const [chatTabs, setChatTabs] = useState(null); const [isBodyProvided, setIsBodyProvided] = useState(false); const [isHeaderProvided, setIsHeaderProvided] = useState(false); const chatContainerRef = useRef(null); const chatToolbarContainerRef = useRef(null); const chatTabsContainerRef = useRef(null); - const contentKeyRef = useRef(null); - const prevSessionIdRef = useRef(null); - const prevSessionDocIdRef = useRef(null); - const lastDocIdRef = useRef(null); - const sessionLoadSeqRef = useRef(0); - const creatingSessionRef = useRef<{ - docId: string; - promise: Promise; - } | null>(null); - const creatingFreshSessionRef = useRef<{ - docId: string; - promise: Promise; - } | null>(null); const userIdRef = useRef(undefined); - const doc = editor?.doc; + const doc = editor?.doc ?? fallbackDoc; const host = editor?.host; const workspaceId = doc?.workspace.id; - - const [sessionServiceReady, setSessionServiceReady] = useState( - () => !!AIProvider.session + const requestService = useMemo( + () => + createAIRequestService( + graphqlService.gql, + eventSourceService.eventSource + ), + [eventSourceService.eventSource, graphqlService.gql] ); - const [hasContextMessages, setHasContextMessages] = useState(false); - useEffect(() => { - if (sessionServiceReady) return; - if (AIProvider.session) { - setSessionServiceReady(true); - return; - } - const sub = AIProvider.slots.sessionReady.subscribe(ready => { - if (ready) setSessionServiceReady(true); - }); - return () => sub.unsubscribe(); - }, [sessionServiceReady]); - - const loadSession = useMemo(() => { - if (!sessionServiceReady || !workspaceId) return null; - const sessionService = AIProvider.session; - if (!sessionService) return null; - return async ( - sessionId: string - ): Promise => - sessionService.getSession(workspaceId, sessionId); - }, [sessionServiceReady, workspaceId]); - - const { openTabs, setOpenTabs } = - useAIChatOpenTabs(loadSession); - const visibleOpenTabs = useMemo( - () => filterDocPanelTabs(openTabs, doc?.id), - [doc?.id, openTabs] - ); - const canCreateNewSession = canCreateNewDocPanelSession({ - hasContextMessages, - session, - status, + const [pendingSessionId] = useState(() => { + const searchParams = new URLSearchParams(workbench.location$.value.search); + return searchParams.get('sessionId') ?? undefined; }); + useEffect(() => { + if (pendingSessionId) { + workbench.activeView$.value.updateQueryString( + { sessionId: undefined }, + { replace: true } + ); + } + }, [pendingSessionId, workbench]); + + const runtime = useMemo(() => { + if (!doc || !workspaceId) return null; + return new AIChatRuntime({ + request: requestService, + scope: { + kind: 'doc', + workspaceId, + docId: doc.id, + pendingSessionId, + }, + strategy: new DocAIChatSessionStrategy(), + }); + }, [doc, pendingSessionId, requestService, workspaceId]); + const snapshot = useAIChatRuntime(runtime); + const session = + snapshot?.sessions.find( + item => item.sessionId === snapshot.activeSessionId + ) ?? null; const appSidebarConfig = useMemo(() => { return { getWidth: () => @@ -188,159 +169,14 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => { return cleanup; }, [appSidebarConfig]); - const resetPanel = useCallback(() => { - sessionLoadSeqRef.current += 1; - setSession(undefined); - setEmbeddingProgress([0, 0]); - setHasPinned(false); - }, []); - - const initPanel = useCallback(async () => { - const requestSeq = ++sessionLoadSeqRef.current; - try { - const nextSession = await resolveInitialSession({ - sessionService: AIProvider.session ?? undefined, - doc, - workbench: workbench as WorkbenchLike, - }); - - if (requestSeq !== sessionLoadSeqRef.current) return; - if (nextSession === undefined) { - return; - } - - setSession(nextSession); - setHasPinned(!!nextSession?.pinned); - } catch (error) { - console.error(error); - } - }, [doc, workbench]); - - const createSession = useCallback( - async (options: Partial = {}) => { - if (session || !AIProvider.session || !doc) { - return session ?? undefined; - } - if (creatingSessionRef.current?.docId === doc.id) { - return creatingSessionRef.current.promise; - } - const requestSeq = ++sessionLoadSeqRef.current; - let promise: Promise; - promise = AIProvider.session - .createSessionWithHistory({ - docId: doc.id, - workspaceId: doc.workspace.id, - promptName: 'Chat With AFFiNE AI', - reuseLatestChat: false, - ...options, - }) - .then(nextSession => { - if (requestSeq !== sessionLoadSeqRef.current) return undefined; - setSession(nextSession ?? null); - setHasPinned(!!nextSession?.pinned); - return nextSession ?? undefined; - }) - .finally(() => { - if (creatingSessionRef.current?.promise === promise) { - creatingSessionRef.current = null; - } - }); - creatingSessionRef.current = { docId: doc.id, promise }; - return promise; - }, - [doc, session] - ); - - const updateSession = useCallback( - async (options: UpdateChatSessionInput) => { - if (!AIProvider.session || !doc) { - return undefined; - } - const requestSeq = ++sessionLoadSeqRef.current; - await AIProvider.session.updateSession(options); - const nextSession = await AIProvider.session.getSession( - doc.workspace.id, - options.sessionId - ); - if (requestSeq !== sessionLoadSeqRef.current) return undefined; - setSession(nextSession ?? null); - setHasPinned(!!nextSession?.pinned); - return nextSession ?? undefined; - }, - [doc] - ); - - const newSession = useCallback(async () => { - if (!canCreateNewSession) { - return; - } - if (doc && creatingFreshSessionRef.current?.docId === doc.id) { - return creatingFreshSessionRef.current.promise; - } - resetPanel(); - const requestSeq = sessionLoadSeqRef.current; - setSession(null); - setHasContextMessages(false); - - if (!AIProvider.session || !doc) { - return; - } - - let promise: Promise; - promise = AIProvider.session - .createSessionWithHistory({ - docId: doc.id, - workspaceId: doc.workspace.id, - promptName: 'Chat With AFFiNE AI', - reuseLatestChat: false, - }) - .then(nextSession => { - if (requestSeq === sessionLoadSeqRef.current) { - setSession(nextSession ?? null); - setHasPinned(!!nextSession?.pinned); - } - }) - .catch(console.error) - .finally(() => { - if (creatingFreshSessionRef.current?.promise === promise) { - creatingFreshSessionRef.current = null; - } - }); - creatingFreshSessionRef.current = { docId: doc.id, promise }; - return promise; - }, [canCreateNewSession, doc, resetPanel]); - const openSession = useCallback( async (sessionId: string) => { - if (session?.sessionId === sessionId || !AIProvider.session || !doc) { + if (session?.sessionId === sessionId || !runtime) { return; } - const requestSeq = ++sessionLoadSeqRef.current; - try { - const nextSession = await AIProvider.session.getSession( - doc.workspace.id, - sessionId - ); - if (requestSeq !== sessionLoadSeqRef.current) return; - if (!nextSession) { - // Drop stale tab if session no longer exists. - setOpenTabs(prev => prev.filter(tab => tab.sessionId !== sessionId)); - return; - } - if (!isSessionAvailableInDocPanel(nextSession, doc.id)) { - setOpenTabs([]); - workbench.open(`/${nextSession.docId}?sessionId=${sessionId}`, { - at: 'active', - }); - return; - } - setSession(nextSession); - setHasPinned(!!nextSession.pinned); - } catch (error) { - console.error(error); - } + await runtime.dispatch({ type: 'openSession', sessionId }); }, - [doc, session?.sessionId, setOpenTabs, workbench] + [runtime, session?.sessionId] ); const openDoc = useCallback( @@ -361,159 +197,47 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => { workbench.open(`/${docId}`, { at: 'active' }); return; } - setOpenTabs([]); workbench.open(`/${docId}?sessionId=${sessionId}`, { at: 'active' }); }, - [ - doc, - openSession, - session?.pinned, - session?.sessionId, - setOpenTabs, - workbench, - ] - ); - - const deleteSession = useMemo( - () => - createSessionDeleteHandler({ - t, - notificationService, - canDeleteSession: () => Boolean(AIProvider.histories), - cleanupSession: async sessionToDelete => { - await AIProvider.histories?.cleanup( - sessionToDelete.workspaceId, - sessionToDelete.docId || undefined, - [sessionToDelete.sessionId] - ); - }, - isActiveSession: sessionToDelete => - sessionToDelete.sessionId === session?.sessionId, - onActiveSessionDeleted: () => { - resetPanel(); - setSession(null); - setHasContextMessages(false); - }, - }), - [notificationService, resetPanel, session?.sessionId, t] - ); - - const closeTab = useCallback( - (sessionId: string) => { - let fallback: CopilotChatHistoryFragment | undefined; - setOpenTabs(prev => { - const idx = prev.findIndex(tab => tab.sessionId === sessionId); - if (idx === -1) return prev; - const next = prev.filter(tab => tab.sessionId !== sessionId); - const visibleNext = filterDocPanelTabs(next, doc?.id); - fallback = visibleNext[idx] ?? visibleNext[idx - 1] ?? visibleNext[0]; - return next; - }); - if (session?.sessionId !== sessionId) return; - if (fallback) { - openSession(fallback.sessionId).catch(console.error); - } else { - resetPanel(); - setSession(null); - setHasContextMessages(false); - } - }, - [doc?.id, openSession, resetPanel, session?.sessionId, setOpenTabs] - ); - - const togglePin = useCallback(async () => { - const pinned = !session?.pinned; - setHasPinned(true); - if (!session) { - await createSession({ pinned }); - return; - } - setSession(prev => (prev ? { ...prev, pinned } : prev)); - await updateSession({ - sessionId: session.sessionId, - pinned, - }); - }, [createSession, session, updateSession]); - - const rebindSession = useCallback(async () => { - if (!session || !doc) { - return; - } - if (session.docId !== doc.id) { - await updateSession({ - sessionId: session.sessionId, - docId: doc.id, - }); - } - }, [doc, session, updateSession]); - - const onEmbeddingProgressChange = useCallback( - (count: Record) => { - const total = count.finished + count.processing + count.failed; - setEmbeddingProgress([count.finished, total]); - }, - [] - ); - - const onContextChange = useCallback( - (context: Partial) => { - if (context.status) { - setStatus(context.status); - } - if (context.messages) { - setHasContextMessages(context.messages.length > 0); - } - if (context.status === 'success') { - rebindSession().catch(console.error); - } - }, - [rebindSession] + [doc, openSession, session?.pinned, session?.sessionId, workbench] ); useEffect(() => { - if (session !== undefined) { + const navigationRequest = snapshot?.navigationRequest; + if (!navigationRequest) { return; } - if (chatContent) { - chatContent.remove(); - setChatContent(null); - } - if (chatToolbar) { - chatToolbar.remove(); - setChatToolbar(null); - } - if (chatTabs) { - chatTabs.remove(); - setChatTabs(null); - } - }, [chatContent, chatTabs, chatToolbar, session]); + workbench.open( + `/${navigationRequest.docId}?sessionId=${navigationRequest.sessionId}`, + { at: 'active' } + ); + }, [snapshot?.navigationRequest, workbench]); - useEffect(() => { - if (!session?.sessionId) return; - setOpenTabs(prev => { - const existing = prev.findIndex( - tab => tab.sessionId === session.sessionId + const deleteSession = useCallback( + async (sessionToDelete: BlockSuitePresets.AIRecentSession) => { + if (!runtime) return; + const confirm = await notificationService.confirm({ + title: t['com.affine.ai.chat-panel.session.delete.confirm.title'](), + message: t['com.affine.ai.chat-panel.session.delete.confirm.message'](), + confirmText: t['Delete'](), + cancelText: t['Cancel'](), + }); + if (!confirm) return; + await runtime.dispatch({ + type: 'deleteSession', + sessionId: sessionToDelete.sessionId, + }); + notificationService.toast( + t['com.affine.ai.chat-panel.session.delete.toast.success'](), + {} ); - if (existing !== -1) { - if (prev[existing] === session) return prev; - const next = prev.slice(); - next[existing] = session; - return next; - } - return [...prev, session]; - }); - }, [session, setOpenTabs]); + }, + [notificationService, runtime, t] + ); useEffect(() => { - let disposed = false; - Promise.resolve(AIProvider.userInfo) - .then(userInfo => { - if (!disposed && userIdRef.current === undefined) { - userIdRef.current = userInfo?.id ?? null; - } - }) - .catch(console.error); - const subscription = AIProvider.slots.userInfo.subscribe(userInfo => { + userIdRef.current ??= AIAppEvents.userInfo.value?.id ?? null; + const subscription = AIAppEvents.userInfo.subscribe(userInfo => { const nextUserId = userInfo?.id ?? null; const shouldReset = shouldResetChatPanelOnUserInfoChange({ previousUserId: userIdRef.current, @@ -523,220 +247,90 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => { if (!shouldReset) { return; } - resetPanel(); - initPanel().catch(console.error); + runtime?.dispatch({ type: 'initialize' }).catch(console.error); }); return () => { - disposed = true; subscription.unsubscribe(); }; - }, [initPanel, resetPanel]); + }, [runtime]); - useEffect(() => { - const docId = doc?.id; - if (!docId) { - return; - } - if ( - lastDocIdRef.current && - lastDocIdRef.current !== docId && - !session?.pinned - ) { - resetPanel(); - setHasContextMessages(false); - } - lastDocIdRef.current = docId; - }, [doc?.id, resetPanel, session?.pinned]); - - useEffect(() => { - if (!doc || session !== undefined) { - return; - } - if (AIProvider.session) { - initPanel().catch(console.error); - return; - } - const subscription = AIProvider.slots.sessionReady.subscribe(ready => { - if (!ready || session !== undefined) return; - initPanel().catch(console.error); - }); - return () => subscription.unsubscribe(); - }, [doc, initPanel, session]); - - const contentKey = getChatContentKey({ - docId: doc?.id, - hasPinned, - isGenerating: status === 'loading' || status === 'transmitting', - previousSessionDocId: prevSessionDocIdRef.current, - previousSessionId: prevSessionIdRef.current, - session, + const chatContent = useAIChatElement({ + containerRef: chatContainerRef, + selector: 'ai-chat-content', + enabled: isBodyProvided && !!runtime && !!snapshot, + createElement: () => new AIChatContent(), + configureElement: content => { + if (!runtime || !snapshot) return; + content.host = host; + content.session = session; + content.runtime = runtime; + content.runtimeSnapshot = snapshot; + content.workspaceId = doc.workspace.id; + content.docId = doc.id; + content.reasoningConfig = reasoningConfig; + content.searchMenuConfig = searchMenuConfig; + content.docDisplayConfig = docDisplayConfig; + content.extensions = specs; + content.serverService = framework.get(ServerService); + content.affineFeatureFlagService = framework.get(FeatureFlagService); + content.affineWorkspaceDialogService = framework.get( + WorkspaceDialogService + ); + content.affineThemeService = framework.get(AppThemeService); + content.notificationService = notificationService; + content.aiDraftService = framework.get(AIDraftService); + content.aiToolsConfigService = framework.get(AIToolsConfigService); + content.peekViewService = framework.get(PeekViewService); + content.subscriptionService = framework.get(SubscriptionService); + content.aiModelService = framework.get(AIModelService); + content.onAISubscribe = handleAISubscribe; + content.width = sidebarWidthSignal; + content.onOpenDoc = (docId: string, sessionId?: string) => { + openDoc(docId, sessionId).catch(console.error); + }; + }, + onElementReady: content => { + onLoad?.(content); + }, }); - useEffect(() => { - if (session?.sessionId) { - prevSessionIdRef.current = session.sessionId; - prevSessionDocIdRef.current = session.docId ?? doc?.id ?? null; - } - }, [doc?.id, session?.docId, session?.sessionId]); + useAIChatElement({ + containerRef: chatToolbarContainerRef, + selector: 'ai-chat-toolbar', + enabled: isHeaderProvided && !!runtime && !!snapshot, + createElement: () => new AIChatToolbar(), + configureElement: tool => { + if (!runtime || !snapshot) return; + configureAIChatToolbar(tool, { + session, + runtime, + runtimeSnapshot: snapshot, + docId: doc.id, + docDisplayConfig, + notificationService, + onOpenDoc: (docId: string, sessionId: string) => { + openDoc(docId, sessionId).catch(console.error); + }, + onSessionDelete: ( + sessionToDelete: BlockSuitePresets.AIRecentSession + ) => { + deleteSession(sessionToDelete).catch(console.error); + }, + }); + }, + }); - useEffect(() => { - if (!chatContent) { - contentKeyRef.current = contentKey; - return; - } - if (contentKeyRef.current && contentKeyRef.current !== contentKey) { - chatContent.remove(); - setChatContent(null); - } - contentKeyRef.current = contentKey; - }, [chatContent, contentKey]); - - useEffect(() => { - if (!isBodyProvided || !chatContainerRef.current || !doc || !host) { - return; - } - if (session === undefined) { - return; - } - - let content = chatContent; - - if (!content) { - content = new AIChatContent(); - } - - content.host = host; - content.session = session; - content.createSession = createSession; - content.workspaceId = doc.workspace.id; - content.docId = doc.id; - content.reasoningConfig = reasoningConfig; - content.searchMenuConfig = searchMenuConfig; - content.docDisplayConfig = docDisplayConfig; - content.extensions = specs; - content.serverService = framework.get(ServerService); - content.affineFeatureFlagService = framework.get(FeatureFlagService); - content.affineWorkspaceDialogService = framework.get( - WorkspaceDialogService - ); - content.affineThemeService = framework.get(AppThemeService); - content.notificationService = notificationService; - content.aiDraftService = framework.get(AIDraftService); - content.aiToolsConfigService = framework.get(AIToolsConfigService); - content.peekViewService = framework.get(PeekViewService); - content.subscriptionService = framework.get(SubscriptionService); - content.aiModelService = framework.get(AIModelService); - content.onAISubscribe = handleAISubscribe; - content.onEmbeddingProgressChange = onEmbeddingProgressChange; - content.onContextChange = onContextChange; - content.width = sidebarWidthSignal; - content.onOpenDoc = (docId: string, sessionId?: string) => { - openDoc(docId, sessionId).catch(console.error); - }; - - if (!chatContent) { - chatContainerRef.current.append(content); - setChatContent(content); - onLoad?.(content); - } - }, [ - chatContent, - createSession, - doc, - docDisplayConfig, - framework, - handleAISubscribe, - host, - isBodyProvided, - notificationService, - onContextChange, - onEmbeddingProgressChange, - onLoad, - openDoc, - reasoningConfig, - searchMenuConfig, - session, - sidebarWidthSignal, - specs, - ]); - - useEffect(() => { - if (!isHeaderProvided || !chatToolbarContainerRef.current || !doc) { - return; - } - if (session === undefined) { - return; - } - - const tool = getOrCreateAIChatToolbar(chatToolbar); - configureAIChatToolbar(tool, { - session, - workspaceId: doc.workspace.id, - docId: doc.id, - status, - canCreateNewSession, - docDisplayConfig, - notificationService, - onNewSession: () => { - newSession().catch(console.error); - }, - onTogglePin: togglePin, - onOpenSession: (sessionId: string) => { - openSession(sessionId).catch(console.error); - }, - onOpenDoc: (docId: string, sessionId: string) => { - openDoc(docId, sessionId).catch(console.error); - }, - onSessionDelete: (sessionToDelete: BlockSuitePresets.AIRecentSession) => { - deleteSession(sessionToDelete).catch(console.error); - }, - }); - - if (!chatToolbar) { - chatToolbarContainerRef.current.append(tool); - setChatToolbar(tool); - } - }, [ - chatToolbar, - canCreateNewSession, - deleteSession, - doc, - docDisplayConfig, - isHeaderProvided, - newSession, - notificationService, - openDoc, - openSession, - session, - status, - togglePin, - ]); - - useEffect(() => { - if (!chatTabsContainerRef.current || !doc) { - return; - } - if (session === undefined) { - return; - } - - let tabs = chatTabs; - if (!tabs) { - tabs = new AIChatTabs(); - chatTabsContainerRef.current.append(tabs); - setChatTabs(tabs); - } - tabs.sessions = visibleOpenTabs; - tabs.activeSessionId = session?.sessionId; - tabs.showDraftTab = - visibleOpenTabs.length === 0 && !session?.sessionId && !!doc; - tabs.onSelectTab = (sessionId: string) => { - openSession(sessionId).catch(console.error); - }; - tabs.onCloseTab = (sessionId: string) => { - closeTab(sessionId); - }; - }, [chatTabs, closeTab, doc, openSession, session, visibleOpenTabs]); + useAIChatElement({ + containerRef: chatTabsContainerRef, + selector: 'ai-chat-tabs', + enabled: !!runtime && !!snapshot, + createElement: () => new AIChatTabs(), + configureElement: tabs => { + if (!runtime || !snapshot) return; + tabs.runtime = runtime; + tabs.runtimeSnapshot = snapshot; + }, + }); useEffect(() => { if (!editor?.host || !chatContent) { @@ -766,19 +360,17 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => { if (autoResized) { return; } - const subscription = AIProvider.slots.previewPanelOpenChange.subscribe( - open => { - if (!open) { - return; - } - const sidebarWidth = workbench.sidebarWidth$.value; - const minSidebarWidth = 1080; - if (!sidebarWidth || sidebarWidth < minSidebarWidth) { - workbench.setSidebarWidth(minSidebarWidth); - setAutoResized(true); - } + const subscription = AIAppEvents.previewPanelOpenChange.subscribe(open => { + if (!open) { + return; } - ); + const sidebarWidth = workbench.sidebarWidth$.value; + const minSidebarWidth = 1080; + if (!sidebarWidth || sidebarWidth < minSidebarWidth) { + workbench.setSidebarWidth(minSidebarWidth); + setAutoResized(true); + } + }); return () => { subscription.unsubscribe(); }; @@ -843,14 +435,16 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => { chatTabsContainerRef.current = node; }, []); - const isEmbedding = - embeddingProgress[1] > 0 && embeddingProgress[0] < embeddingProgress[1]; - const [done, total] = embeddingProgress; - const isInitialized = session !== undefined; + const embeddingCount = snapshot?.composer.context.embeddingCount; + const done = embeddingCount?.finished ?? 0; + const total = + done + (embeddingCount?.processing ?? 0) + (embeddingCount?.failed ?? 0); + const isEmbedding = total > 0 && done < total; + const hasRuntimeSnapshot = !!snapshot; return (
- {!isInitialized ? ( + {!hasRuntimeSnapshot ? (
diff --git a/packages/frontend/core/src/modules/peek-view/view/doc-preview/doc-peek-view.tsx b/packages/frontend/core/src/modules/peek-view/view/doc-preview/doc-peek-view.tsx index edea2098cb..0de0609414 100644 --- a/packages/frontend/core/src/modules/peek-view/view/doc-preview/doc-peek-view.tsx +++ b/packages/frontend/core/src/modules/peek-view/view/doc-preview/doc-peek-view.tsx @@ -1,6 +1,6 @@ import { Scrollable } from '@affine/component'; import { PageDetailLoading } from '@affine/component/page-detail-skeleton'; -import { type AIChatParams, AIProvider } from '@affine/core/blocksuite/ai'; +import { AIAppEvents, type AIChatParams } from '@affine/core/blocksuite/ai'; import type { AffineEditorContainer } from '@affine/core/blocksuite/block-suite-editor'; import { EditorOutlineViewer } from '@affine/core/blocksuite/outline-viewer'; import { AffineErrorBoundary } from '@affine/core/components/affine/affine-error-boundary'; @@ -137,12 +137,8 @@ function DocPeekPreviewEditor({ // chat panel open is already handled in } }; - disposables.push( - AIProvider.slots.requestOpenWithChat.subscribe(openHandler) - ); - disposables.push( - AIProvider.slots.requestSendWithChat.subscribe(openHandler) - ); + disposables.push(AIAppEvents.requestOpenWithChat.subscribe(openHandler)); + disposables.push(AIAppEvents.requestSendWithChat.subscribe(openHandler)); return () => disposables.forEach(d => d.unsubscribe()); }, [doc, peekView, workbench, workspace.id]); diff --git a/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts b/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts index 5625d9f9d7..167252f990 100644 --- a/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts +++ b/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts @@ -226,8 +226,8 @@ export class SettingsPanelUtils { timeout: number, status = 'synced' ) { + await cleanupWorkspace(page.url().split('/').slice(-2)[0] || ''); await expect(async () => { - await cleanupWorkspace(page.url().split('/').slice(-2)[0] || ''); await this.openSettingsPanel(page); const title = page.getByTestId('embedding-progress-title'); // oxlint-disable-next-line prefer-dom-node-dataset