feat(editor): extract chat runtime (#14937)

#### PR Dependency Tree


* **PR #14937** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Centralized AI event system and a runtime powering chat sessions and
actions.

* **Improvements**
* Chat UI (composer, messages, toolbar, tabs, panels) now syncs with
runtime snapshots for more consistent state.
* Improved session/tab lifecycle (create, fork, delete), context
embedding status, and history handling.
* More reliable send/stop/retry flows, better telemetry scoping, and
clearer upgrade/login/insert-template prompts.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
DarkSky
2026-05-13 21:57:50 +08:00
committed by GitHub
parent 322f2ba986
commit e222f06e94
64 changed files with 5348 additions and 4510 deletions
@@ -42,7 +42,9 @@ import type { TemplateResult } from 'lit';
import { insertFromMarkdown } from '../../utils';
import type { ChatMessage } from '../components/ai-chat-messages';
import { AIProvider, type AIUserInfo } from '../provider';
import { AIAppEvents, type AIUserInfo } from '../provider';
import { AIChatRuntime, ForkAIChatSessionStrategy } from '../runtime/chat';
import { getAIRequestService } from '../runtime/request';
import { reportResponse } from '../utils/action-reporter';
import { insertBelow } from '../utils/editor-actions';
@@ -72,7 +74,7 @@ export async function queryHistoryMessages(
docId?: string
) {
// Get fork session messages
const histories = await AIProvider.histories?.chats(
const histories = await getAIRequestService().histories.chats(
workspaceId,
forkSessionId,
docId
@@ -114,7 +116,7 @@ export async function constructRootChatBlockMessages(
forkSessionId: string
) {
// Convert chat messages to AI chat block messages
const userInfo = await AIProvider.userInfo;
const userInfo = AIAppEvents.userInfo.value;
const forkMessages = (await queryHistoryMessages(
doc.workspace.id,
forkSessionId,
@@ -247,7 +249,7 @@ async function insertBelowBlock(
): Promise<boolean> {
if (!block) return false;
reportResponse('result:insert');
reportResponse('result:insert', host);
await insertBelow(host, content, block);
return true;
}
@@ -376,13 +378,20 @@ const SAVE_AS_BLOCK: ChatAction = {
});
}
try {
const newSessionId = await AIProvider.forkChat?.({
const runtime = new AIChatRuntime({
request: getAIRequestService(),
scope: {
kind: 'fork',
workspaceId: host.store.workspace.id,
docId: host.store.id,
sessionId: parentSessionId,
parentSessionId,
latestMessageId: messageId,
});
},
strategy: new ForkAIChatSessionStrategy(),
});
try {
const newSession = await runtime.createSession();
const newSessionId = newSession?.sessionId;
if (!newSessionId) {
return false;
@@ -424,6 +433,8 @@ const SAVE_AS_BLOCK: ChatAction = {
onClose: function (): void {},
});
return false;
} finally {
runtime.dispose();
}
},
};
@@ -439,7 +450,7 @@ const ADD_TO_EDGELESS_AS_NOTE = {
},
toast: 'New note created',
handler: async (host: EditorHost, content: string): Promise<boolean> => {
reportResponse('result:add-note');
reportResponse('result:add-note', host);
const { store } = host;
const gfx = host.std.get(GfxControllerIdentifier);
@@ -475,7 +486,7 @@ export const SAVE_AS_DOC = {
showWhen: () => true,
toast: 'New doc created',
handler: (host: EditorHost, content: string) => {
reportResponse('result:add-page');
reportResponse('result:add-page', host);
const doc = host.store.workspace.createDoc();
const newDoc = doc.getStore();
newDoc.load();
@@ -518,7 +529,7 @@ const CREATE_AS_LINKED_DOC = {
},
toast: 'New doc created',
handler: async (host: EditorHost, content: string) => {
reportResponse('result:add-page');
reportResponse('result:add-page', host);
const { store } = host;
const surfaceBlock = store
@@ -36,7 +36,7 @@ import type {
AIItemGroupConfig,
AISubItemConfig,
} from '../components/ai-item/types';
import { AIProvider } from '../provider';
import { AIAppEvents } from '../provider';
import { getAIPanelWidget } from '../utils/ai-widgets';
import { getEdgelessCopilotWidget } from '../utils/get-edgeless-copilot-widget';
import {
@@ -386,7 +386,7 @@ const OthersAIGroup: AIItemGroupConfig = {
handler: host => {
const panel = getAIPanelWidget(host);
const edgelessCopilot = getEdgelessCopilotWidget(host);
AIProvider.slots.requestOpenWithChat.next({
AIAppEvents.requestOpenWithChat.next({
host,
autoSelect: true,
});
@@ -12,7 +12,8 @@ import {
} from '../ai-panel';
import { StreamObjectSchema } from '../components/ai-chat-messages';
import { type AIItemGroupConfig } from '../components/ai-item/types';
import { type AIError, AIProvider } from '../provider';
import { type AIError } from '../provider';
import { getAIRequestService } from '../runtime/request';
import { reportResponse } from '../utils/action-reporter';
import { getAIPanelWidget } from '../utils/ai-widgets';
import { AIContext } from '../utils/context';
@@ -33,10 +34,12 @@ export function bindTextStream(
update,
finish,
signal,
host,
}: {
update: (answer: AIActionAnswer) => void;
finish: (state: 'success' | 'error' | 'aborted', err?: AIError) => void;
signal?: AbortSignal;
host?: EditorHost;
}
) {
(async () => {
@@ -45,7 +48,7 @@ export function bindTextStream(
};
signal?.addEventListener('abort', () => {
finish('aborted');
reportResponse('aborted:stop');
reportResponse('aborted:stop', host);
});
for await (const data of stream) {
if (signal?.aborted) {
@@ -88,9 +91,6 @@ function actionToStream<T extends keyof BlockSuitePresets.AIActions>(
>,
trackerOptions?: BlockSuitePresets.TrackerOptions
): BlockSuitePresets.TextStream | undefined {
const action = AIProvider.actions[id];
if (!action || typeof action !== 'function') return;
let stream: BlockSuitePresets.TextStream | undefined;
return {
async *[Symbol.asyncIterator]() {
@@ -123,9 +123,11 @@ function actionToStream<T extends keyof BlockSuitePresets.AIActions>(
where,
docId: host.store.id,
workspaceId: host.store.workspace.id,
} as Parameters<typeof action>[0];
// @ts-expect-error TODO(@Peng): maybe fix this
stream = await action(options);
} as BlockSuitePresets.AITextActionOptions & Record<string, unknown>;
stream = (await getAIRequestService().executeAction(
id,
options
)) as BlockSuitePresets.TextStream;
if (!stream) return;
yield* stream;
},
@@ -163,7 +165,7 @@ function actionToGenerateAnswer<T extends keyof BlockSuitePresets.AIActions>(
trackerOptions
);
if (!stream) return;
bindTextStream(stream, { update, finish, signal });
bindTextStream(stream, { update, finish, signal, host });
};
}
@@ -198,7 +200,7 @@ function updateAIPanelConfig<T extends keyof BlockSuitePresets.AIActions>(
config.errorStateConfig = buildErrorConfig(aiPanel);
config.copy = buildCopyConfig(aiPanel);
config.discardCallback = () => {
reportResponse('result:discard');
reportResponse('result:discard', host);
};
}
@@ -21,7 +21,8 @@ import type { TemplateResult } from 'lit';
import { getContentFromSlice } from '../../utils';
import { AIChatBlockModel } from '../blocks';
import { type AIError, AIProvider } from '../provider';
import { type AIError } from '../provider';
import { getAIRequestService } from '../runtime/request';
import { reportResponse } from '../utils/action-reporter';
import { getAIPanelWidget } from '../utils/ai-widgets';
import { AIContext } from '../utils/context';
@@ -178,10 +179,6 @@ function actionToStream<T extends keyof BlockSuitePresets.AIActions>(
trackerOptions?: BlockSuitePresets.TrackerOptions,
panelInput?: string
) {
const action = AIProvider.actions[id];
if (!action || typeof action !== 'function') return;
if (extract && typeof extract === 'function') {
return (host: EditorHost, ctx: AIContext): BlockSuitePresets.TextStream => {
let stream: BlockSuitePresets.TextStream | undefined;
@@ -201,7 +198,7 @@ function actionToStream<T extends keyof BlockSuitePresets.AIActions>(
host,
docId: host.store.id,
workspaceId: host.store.workspace.id,
} as Parameters<typeof action>[0];
} as BlockSuitePresets.AITextActionOptions & Record<string, unknown>;
const content = ctx.get().content;
if (typeof content === 'string' && !content.length && panelInput) {
@@ -214,8 +211,10 @@ function actionToStream<T extends keyof BlockSuitePresets.AIActions>(
Object.assign(options, data);
}
// @ts-expect-error TODO(@Peng): maybe fix this
stream = await action(options);
stream = (await getAIRequestService().executeAction(
id,
options
)) as BlockSuitePresets.TextStream;
if (!stream) return;
yield* stream;
},
@@ -242,10 +241,12 @@ function actionToStream<T extends keyof BlockSuitePresets.AIActions>(
host,
docId: host.store.id,
workspaceId: host.store.workspace.id,
} as Parameters<typeof action>[0];
} as BlockSuitePresets.AITextActionOptions & Record<string, unknown>;
// @ts-expect-error TODO(@Peng): maybe fix this
stream = await action(options);
stream = (await getAIRequestService().executeAction(
id,
options
)) as BlockSuitePresets.TextStream;
if (!stream) return;
yield* stream;
},
@@ -351,7 +352,7 @@ function updateEdgelessAIPanelConfig<
},
};
config.discardCallback = () => {
reportResponse('result:discard');
reportResponse('result:discard', host);
};
config.hideCallback = () => {
aiPanel.updateComplete
@@ -36,7 +36,7 @@ import { styleMap } from 'lit/directives/style-map.js';
import { insertFromMarkdown } from '../../utils';
import type { ChatContextValue } from '../components/ai-chat-content/type';
import type { AIItemConfig } from '../components/ai-item/types';
import { AIProvider } from '../provider';
import { AIAppEvents } from '../provider';
import { reportResponse } from '../utils/action-reporter';
import { getAIPanelWidget } from '../utils/ai-widgets';
import type { AIContext } from '../utils/context';
@@ -98,7 +98,7 @@ export function retry(panel: AffineAIPanelWidget): AIItemConfig {
icon: ResetIcon(),
testId: 'answer-retry',
handler: () => {
reportResponse('result:retry');
reportResponse('result:retry', panel.host);
panel.generate();
},
};
@@ -152,7 +152,7 @@ export function createInsertItems<T extends keyof BlockSuitePresets.AIActions>(
);
},
handler: () => {
reportResponse('result:insert');
reportResponse('result:insert', host);
edgelessResponseHandler(id, host, ctx).catch(console.error);
const panel = getAIPanelWidget(host);
panel.hide();
@@ -201,7 +201,7 @@ export function asCaption<T extends keyof BlockSuitePresets.AIActions>(
return id === 'generateCaption' && !!panel.answer;
},
handler: () => {
reportResponse('result:use-as-caption');
reportResponse('result:use-as-caption', host);
const panel = getAIPanelWidget(host);
const caption = panel.answer;
if (!caption) return;
@@ -589,7 +589,7 @@ export function actionToResponse<T extends keyof BlockSuitePresets.AIActions>(
testId: 'answer-continue-in-chat',
icon: ChatWithAiIcon({}),
handler: () => {
reportResponse('result:continue-in-chat');
reportResponse('result:continue-in-chat', host);
edgelesContinueResponseHandler(id, host, ctx).catch(
console.error
);
@@ -700,7 +700,7 @@ async function edgelesContinueResponseHandler<
}
const panel = getAIPanelWidget(host);
AIProvider.slots.requestOpenWithChat.next({
AIAppEvents.requestOpenWithChat.next({
host,
context,
fromAnswer: true,
@@ -732,11 +732,11 @@ export function actionToErrorResponse<
): ErrorConfig {
return {
upgrade: () => {
AIProvider.slots.requestUpgradePlan.next({ host: panel.host });
AIAppEvents.requestUpgradePlan.next({ host: panel.host });
panel.hide();
},
login: () => {
AIProvider.slots.requestLogin.next({ host: panel.host });
AIAppEvents.requestLogin.next({ host: panel.host });
panel.hide();
},
cancel: () => {
@@ -35,7 +35,7 @@ import {
} from './actions/page-response';
import type { AIItemConfig } from './components/ai-item/types';
import { createAIScrollableTextRenderer } from './components/ai-scrollable-text-renderer';
import { AIProvider } from './provider';
import { AIAppEvents } from './provider';
import { reportResponse } from './utils/action-reporter';
import { getAIPanelWidget } from './utils/ai-widgets';
import { AIContext } from './utils/context';
@@ -58,7 +58,7 @@ function asCaption<T extends keyof BlockSuitePresets.AIActions>(
return id === 'generateCaption' && !!panel.answer;
},
handler: () => {
reportResponse('result:use-as-caption');
reportResponse('result:use-as-caption', host);
const panel = getAIPanelWidget(host);
const caption = panel.answer;
if (!caption) return;
@@ -85,7 +85,7 @@ function createNewNote(host: EditorHost): AIItemConfig {
return !!panel.answer && isInsideEdgelessEditor(host);
},
handler: () => {
reportResponse('result:add-note');
reportResponse('result:add-note', host);
// get the note block
const { selectedBlocks } = getSelections(host);
if (!selectedBlocks || !selectedBlocks.length) return;
@@ -157,7 +157,7 @@ function buildPageResponseConfig<T extends keyof BlockSuitePresets.AIActions>(
showWhen: () =>
!!panel.answer && (!id || !INSERT_ABOVE_ACTIONS.includes(id)),
handler: () => {
reportResponse('result:insert');
reportResponse('result:insert', host);
pageResponseHandler(id, host, ctx, 'after').catch(console.error);
panel.hide();
},
@@ -169,7 +169,7 @@ function buildPageResponseConfig<T extends keyof BlockSuitePresets.AIActions>(
showWhen: () =>
!!panel.answer && !!id && INSERT_ABOVE_ACTIONS.includes(id),
handler: () => {
reportResponse('result:insert');
reportResponse('result:insert', host);
pageResponseHandler(id, host, ctx, 'before').catch(console.error);
panel.hide();
},
@@ -182,7 +182,7 @@ function buildPageResponseConfig<T extends keyof BlockSuitePresets.AIActions>(
showWhen: () =>
!!panel.answer && !EXCLUDING_REPLACE_ACTIONS.includes(id),
handler: () => {
reportResponse('result:replace');
reportResponse('result:replace', host);
replaceWithMarkdown(host).catch(console.error);
panel.hide();
},
@@ -199,8 +199,8 @@ function buildPageResponseConfig<T extends keyof BlockSuitePresets.AIActions>(
icon: ChatWithAiIcon(),
testId: 'answer-continue-in-chat',
handler: () => {
reportResponse('result:continue-in-chat');
AIProvider.slots.requestOpenWithChat.next({ host });
reportResponse('result:continue-in-chat', host);
AIAppEvents.requestOpenWithChat.next({ host });
panel.hide();
},
},
@@ -209,7 +209,7 @@ function buildPageResponseConfig<T extends keyof BlockSuitePresets.AIActions>(
icon: ResetIcon(),
testId: 'answer-regenerate',
handler: () => {
reportResponse('result:retry');
reportResponse('result:retry', host);
panel.generate();
},
},
@@ -237,7 +237,7 @@ export function buildErrorResponseConfig(panel: AffineAIPanelWidget) {
testId: 'error-retry',
showWhen: () => true,
handler: () => {
reportResponse('result:retry');
reportResponse('result:retry', panel.host);
panel.generate();
},
},
@@ -269,11 +269,11 @@ export function buildFinishConfig<T extends keyof BlockSuitePresets.AIActions>(
export function buildErrorConfig(panel: AffineAIPanelWidget) {
return {
upgrade: () => {
AIProvider.slots.requestUpgradePlan.next({ host: panel.host });
AIAppEvents.requestUpgradePlan.next({ host: panel.host });
panel.hide();
},
login: () => {
AIProvider.slots.requestLogin.next({ host: panel.host });
AIAppEvents.requestLogin.next({ host: panel.host });
panel.hide();
},
cancel: () => {
@@ -10,14 +10,7 @@ import type {
SubscriptionService,
} from '@affine/core/modules/cloud';
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type {
ContextEmbedStatus,
ContextWorkspaceEmbeddingStatus,
CopilotChatHistoryFragment,
CopilotContextBlob,
CopilotContextDoc,
CopilotContextFile,
} from '@affine/graphql';
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit';
import type { EditorHost } from '@blocksuite/affine/std';
import { ShadowlessElement } from '@blocksuite/affine/std';
@@ -30,14 +23,16 @@ import { css, html, type PropertyValues } from 'lit';
import { property, state } from 'lit/decorators.js';
import {
AIAppEvents,
type AIChatParams,
AIProvider,
type AISendParams,
} from '../../provider';
import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat';
import type { SearchMenuConfig } from '../ai-chat-add-context';
import type {
AttachmentChip,
ChatChip,
ChipState,
CollectionChip,
DocChip,
DocDisplayConfig,
@@ -58,8 +53,6 @@ import {
import type { AIChatInputContext, AIReasoningConfig } from '../ai-chat-input';
import { MAX_IMAGE_COUNT } from '../ai-chat-input/const';
export const EMBEDDING_STATUS_CHECK_INTERVAL = 10000;
export class AIChatComposer extends SignalWatcher(
WithDisposable(ShadowlessElement)
) {
@@ -91,9 +84,10 @@ export class AIChatComposer extends SignalWatcher(
accessor session!: CopilotChatHistoryFragment | null | undefined;
@property({ attribute: false })
accessor createSession!: () => Promise<
CopilotChatHistoryFragment | undefined
>;
accessor runtime: AIChatRuntime | null | undefined;
@property({ attribute: false })
accessor runtimeSnapshot: AIChatSnapshot | null | undefined;
@property({ attribute: false })
accessor chatContextValue!: AIChatInputContext;
@@ -101,11 +95,6 @@ export class AIChatComposer extends SignalWatcher(
@property({ attribute: false })
accessor updateContext!: (context: Partial<AIChatInputContext>) => void;
@property({ attribute: false })
accessor onEmbeddingProgressChange:
| ((count: Record<ContextEmbedStatus, number>) => void)
| undefined;
@property({ attribute: false })
accessor docDisplayConfig!: DocDisplayConfig;
@@ -160,12 +149,6 @@ export class AIChatComposer extends SignalWatcher(
@state()
accessor embeddingCompleted = false;
private _contextId: string | undefined = undefined;
private _pollAbortController: AbortController | null = null;
private _pollEmbeddingStatusAbortController: AbortController | null = null;
override render() {
return html`
<chat-panel-chips
@@ -186,10 +169,11 @@ export class AIChatComposer extends SignalWatcher(
.workspaceId=${this.workspaceId}
.docId=${this.docId}
.session=${this.session}
.runtime=${this.runtime}
.runtimeSnapshot=${this.runtimeSnapshot}
.chips=${this.chips}
.addChip=${this.addChip}
.addImages=${this.addImages}
.createSession=${this.createSession}
.chatContextValue=${this.chatContextValue}
.updateContext=${this.updateContext}
.reasoningConfig=${this.reasoningConfig}
@@ -222,31 +206,40 @@ export class AIChatComposer extends SignalWatcher(
override connectedCallback() {
super.connectedCallback();
this._disposables.add(
AIProvider.slots.requestOpenWithChat.subscribe(this.beforeChatContextSend)
AIAppEvents.requestOpenWithChat.subscribe(this.beforeChatContextSend)
);
this._disposables.add(
AIProvider.slots.requestSendWithChat.subscribe(this.beforeChatContextSend)
AIAppEvents.requestSendWithChat.subscribe(this.beforeChatContextSend)
);
this.initComposer().catch(console.error);
}
override disconnectedCallback() {
super.disconnectedCallback();
this._abortPoll();
this._abortPollEmbeddingStatus();
this.runtime?.dispatch({ type: 'stopContextPolling' }).catch(console.error);
}
protected override willUpdate(changedProperties: PropertyValues): void {
const previousSnapshot = changedProperties.get('runtimeSnapshot') as
| AIChatSnapshot
| null
| undefined;
if (
changedProperties.has('chatContextValue') &&
changedProperties.get('chatContextValue')?.status !== 'loading' &&
this.chatContextValue.status === 'loading' &&
changedProperties.has('runtimeSnapshot') &&
previousSnapshot?.status !== 'loading' &&
this.runtimeSnapshot?.status === 'loading' &&
this.isChipsCollapsed === false
) {
this.isChipsCollapsed = true;
}
}
protected override updated(changedProperties: PropertyValues): void {
if (changedProperties.has('runtimeSnapshot')) {
this.syncChipsFromRuntimeSnapshot();
}
}
private readonly beforeChatContextSend = (
params: AISendParams | AIChatParams | null
) => {
@@ -276,100 +269,115 @@ export class AIChatComposer extends SignalWatcher(
return this.chips.some(chip => chip.state === 'processing');
}
private readonly _getContextId = async () => {
if (this._contextId) {
return this._contextId;
private readonly toChipState = (state?: string): ChipState => {
if (state === 'finished' || state === 'processing' || state === 'failed') {
return state;
}
const sessionId = this.session?.sessionId;
if (!sessionId) return;
const contextId = await AIProvider.context?.getContextId(
this.workspaceId,
sessionId
);
this._contextId = contextId;
return this._contextId;
return 'processing';
};
private readonly createContextId = async () => {
if (this._contextId) {
return this._contextId;
private readonly runtimeItemToChip = (
item: AIChatSnapshot['composer']['context']['items'][number]
): ChatChip => {
switch (item.kind) {
case 'doc':
return {
docId: item.docId,
state: this.toChipState(item.state),
createdAt: item.createdAt,
tooltip: item.tooltip,
};
case 'file':
return {
file: item.file,
fileId: item.fileId,
blobId: item.blobId,
state: this.toChipState(item.state),
createdAt: item.createdAt,
tooltip: item.tooltip,
};
case 'tag':
return {
tagId: item.tagId,
state: this.toChipState(item.state),
createdAt: item.createdAt,
tooltip: item.tooltip,
};
case 'collection':
return {
collectionId: item.collectionId,
state: this.toChipState(item.state),
createdAt: item.createdAt,
tooltip: item.tooltip,
};
case 'blob':
return {
sourceId: item.blobId,
name: item.blobId,
state: this.toChipState(item.state),
createdAt: item.createdAt,
tooltip: item.tooltip,
};
}
};
const sessionId = (await this.createSession())?.sessionId;
if (!sessionId) return;
private readonly syncChipsFromRuntimeSnapshot = (
snapshot = this.runtimeSnapshot
) => {
const context = snapshot?.composer.context;
if (!context) return;
this.embeddingCompleted = context.embeddingCompleted;
const selectedChips = this.chips.filter(isSelectedContextChip);
this.updateChips([
...context.items.map(this.runtimeItemToChip),
...selectedChips,
]);
};
this._contextId = await AIProvider.context?.createContext(
this.workspaceId,
sessionId
);
return this._contextId;
private readonly syncChipsFromRuntime = () => {
this.syncChipsFromRuntimeSnapshot(this.runtime?.getSnapshot());
};
private readonly chipToContextItem = (
chip: ChatChip
): AIChatSnapshot['composer']['context']['items'][number] | null => {
if (isDocChip(chip)) {
return { kind: 'doc', docId: chip.docId, state: chip.state };
}
if (isFileChip(chip)) {
return {
kind: 'file',
file: chip.file,
fileId: chip.fileId ?? undefined,
blobId: chip.blobId ?? undefined,
state: chip.state,
};
}
if (isTagChip(chip)) {
return {
kind: 'tag',
tagId: chip.tagId,
docIds: this.docDisplayConfig.getTagPageIds(chip.tagId),
state: chip.state,
};
}
if (isCollectionChip(chip)) {
return {
kind: 'collection',
collectionId: chip.collectionId,
docIds: this.docDisplayConfig.getCollectionPageIds(chip.collectionId),
state: chip.state,
};
}
if (isAttachmentChip(chip)) {
return { kind: 'blob', blobId: chip.sourceId, state: chip.state };
}
return null;
};
private readonly initChips = async () => {
// context not initialized
const sessionId = this.session?.sessionId;
const contextId = await this._getContextId();
if (!sessionId || !contextId) {
return;
}
// context initialized, show the chips
const {
docs = [],
files = [],
tags = [],
collections = [],
} = (await AIProvider.context?.getContextDocsAndFiles(
this.workspaceId,
sessionId,
contextId
)) || {};
const docChips: DocChip[] = docs.map(doc => ({
docId: doc.id,
state: doc.status || 'processing',
createdAt: doc.createdAt,
}));
const fileChips: FileChip[] = await Promise.all(
files.map(async file => {
return {
file: new File([], file.name),
blobId: file.blobId,
fileId: file.id,
state: file.status,
tooltip: file.error,
createdAt: file.createdAt,
};
})
);
const tagChips: TagChip[] = tags.map(tag => ({
tagId: tag.id,
state: 'finished',
createdAt: tag.createdAt,
}));
const collectionChips: CollectionChip[] = collections.map(collection => ({
collectionId: collection.id,
state: 'finished',
createdAt: collection.createdAt,
}));
const chips: ChatChip[] = [
...docChips,
...fileChips,
...tagChips,
...collectionChips,
].sort((a, b) => {
const aTime = a.createdAt ?? Date.now();
const bTime = b.createdAt ?? Date.now();
return aTime - bTime;
});
this.updateChips(chips);
await this.runtime?.dispatch({ type: 'loadContext' });
this.syncChipsFromRuntime();
};
private readonly updateChips = (chips: ChatChip[]) => {
@@ -487,14 +495,11 @@ export class AIChatComposer extends SignalWatcher(
private readonly addDocToContext = async (chip: DocChip) => {
try {
const contextId = await this.createContextId();
if (!contextId || !AIProvider.context) {
throw new Error('Context not found');
}
await AIProvider.context.addContextDoc({
contextId,
docId: chip.docId,
await this.runtime?.dispatch({
type: 'addContextItem',
item: { kind: 'doc', docId: chip.docId, state: chip.state },
});
this.syncChipsFromRuntime();
} catch (e) {
this.updateChip(chip, {
state: 'failed',
@@ -505,18 +510,17 @@ export class AIChatComposer extends SignalWatcher(
private readonly addFileToContext = async (chip: FileChip) => {
try {
const contextId = await this.createContextId();
if (!contextId || !AIProvider.context) {
throw new Error('Context not found');
}
const contextFile = await AIProvider.context.addContextFile(chip.file, {
contextId,
});
this.updateChip(chip, {
state: contextFile.status,
blobId: contextFile.blobId,
fileId: contextFile.id,
await this.runtime?.dispatch({
type: 'addContextItem',
item: {
kind: 'file',
file: chip.file,
fileId: chip.fileId ?? undefined,
blobId: chip.blobId ?? undefined,
state: chip.state,
},
});
this.syncChipsFromRuntime();
} catch (e) {
this.updateChip(chip, {
state: 'failed',
@@ -527,20 +531,16 @@ export class AIChatComposer extends SignalWatcher(
private readonly addTagToContext = async (chip: TagChip) => {
try {
const contextId = await this.createContextId();
if (!contextId || !AIProvider.context) {
throw new Error('Context not found');
}
// TODO: server side docIds calculation
const docIds = this.docDisplayConfig.getTagPageIds(chip.tagId);
await AIProvider.context.addContextTag({
contextId,
tagId: chip.tagId,
docIds,
});
this.updateChip(chip, {
state: 'finished',
await this.runtime?.dispatch({
type: 'addContextItem',
item: {
kind: 'tag',
tagId: chip.tagId,
docIds: this.docDisplayConfig.getTagPageIds(chip.tagId),
state: chip.state,
},
});
this.syncChipsFromRuntime();
} catch (e) {
this.updateChip(chip, {
state: 'failed',
@@ -551,22 +551,16 @@ export class AIChatComposer extends SignalWatcher(
private readonly addCollectionToContext = async (chip: CollectionChip) => {
try {
const contextId = await this.createContextId();
if (!contextId || !AIProvider.context) {
throw new Error('Context not found');
}
// TODO: server side docIds calculation
const docIds = this.docDisplayConfig.getCollectionPageIds(
chip.collectionId
);
await AIProvider.context.addContextCollection({
contextId,
collectionId: chip.collectionId,
docIds,
});
this.updateChip(chip, {
state: 'finished',
await this.runtime?.dispatch({
type: 'addContextItem',
item: {
kind: 'collection',
collectionId: chip.collectionId,
docIds: this.docDisplayConfig.getCollectionPageIds(chip.collectionId),
state: chip.state,
},
});
this.syncChipsFromRuntime();
} catch (e) {
this.updateChip(chip, {
state: 'failed',
@@ -579,19 +573,12 @@ export class AIChatComposer extends SignalWatcher(
private readonly addAttachmentChipToContext = async (
chip: AttachmentChip
) => {
const contextId = await this.createContextId();
if (!contextId || !AIProvider.context) {
throw new Error('Context not found');
}
try {
const contextBlob = await AIProvider.context.addContextBlob({
blobId: chip.sourceId,
contextId,
});
this.updateChip(chip, {
state: contextBlob.status || 'processing',
blobId: chip.sourceId,
await this.runtime?.dispatch({
type: 'addContextItem',
item: { kind: 'blob', blobId: chip.sourceId, state: chip.state },
});
this.syncChipsFromRuntime();
} catch (e) {
this.updateChip(chip, {
state: 'failed',
@@ -604,48 +591,19 @@ export class AIChatComposer extends SignalWatcher(
private readonly removeFromContext = async (
chip: ChatChip
): Promise<boolean> => {
if (isSelectedContextChip(chip)) {
this.updateContext({
...this.chatContextValue,
snapshot: null,
combinedElementsMarkdown: null,
});
return true;
}
const item = this.chipToContextItem(chip);
if (!item) return true;
try {
const contextId = await this.createContextId();
if (!contextId || !AIProvider.context) {
return true;
}
if (isDocChip(chip)) {
return await AIProvider.context.removeContextDoc({
contextId,
docId: chip.docId,
});
}
if (isFileChip(chip) && chip.fileId) {
return await AIProvider.context.removeContextFile({
contextId,
fileId: chip.fileId,
});
}
if (isTagChip(chip)) {
return await AIProvider.context.removeContextTag({
contextId,
tagId: chip.tagId,
});
}
if (isCollectionChip(chip)) {
return await AIProvider.context.removeContextCollection({
contextId,
collectionId: chip.collectionId,
});
}
if (isAttachmentChip(chip)) {
return await AIProvider.context.removeContextBlob({
contextId,
blobId: chip.sourceId,
});
}
if (isSelectedContextChip(chip)) {
this.updateContext({
...this.chatContextValue,
snapshot: null,
combinedElementsMarkdown: null,
});
}
await this.runtime?.dispatch({ type: 'removeContextItem', item });
this.syncChipsFromRuntime();
return true;
} catch {
return true;
@@ -669,136 +627,19 @@ export class AIChatComposer extends SignalWatcher(
};
private readonly pollContextDocsAndFiles = async () => {
const sessionId = this.session?.sessionId;
const contextId = await this._getContextId();
if (!sessionId || !contextId || !AIProvider.context) {
return;
}
if (this._pollAbortController) {
// already polling, reset timer
this._abortPoll();
}
this._pollAbortController = new AbortController();
await AIProvider.context.pollContextDocsAndFiles(
this.workspaceId,
sessionId,
contextId,
this._onPoll,
this._pollAbortController.signal
);
if (!this.runtime) return;
await this.runtime.dispatch({ type: 'startContextPolling' });
this.syncChipsFromRuntime();
};
private readonly pollEmbeddingStatus = async () => {
if (this._pollEmbeddingStatusAbortController) {
this._pollEmbeddingStatusAbortController.abort();
}
this._pollEmbeddingStatusAbortController = new AbortController();
const signal = this._pollEmbeddingStatusAbortController.signal;
try {
await AIProvider.context?.pollEmbeddingStatus(
this.workspaceId,
(status: ContextWorkspaceEmbeddingStatus) => {
if (!status) {
this.embeddingCompleted = false;
return;
}
const prevCompleted = this.embeddingCompleted;
const completed = status.embedded === status.total;
this.embeddingCompleted = completed;
if (prevCompleted !== completed) {
this.requestUpdate();
}
},
signal
);
} catch {
this.embeddingCompleted = false;
}
};
private readonly _onPoll = (
result?: BlockSuitePresets.AIDocsAndFilesContext
) => {
if (!result) {
this._abortPoll();
return;
}
const {
docs: sDocs = [],
files = [],
tags = [],
collections = [],
blobs = [],
} = result;
const docs = [
...sDocs,
...tags.flatMap(tag => tag.docs),
...collections.flatMap(collection => collection.docs),
];
const hashMap = new Map<
string,
CopilotContextDoc | CopilotContextFile | CopilotContextBlob
>();
const count: Record<ContextEmbedStatus, number> = {
finished: 0,
processing: 0,
failed: 0,
};
docs.forEach(doc => {
hashMap.set(doc.id, doc);
doc.status && count[doc.status]++;
});
files.forEach(file => {
hashMap.set(file.id, file);
file.status && count[file.status]++;
});
blobs.forEach(blob => {
hashMap.set(blob.id, blob);
blob.status && count[blob.status]++;
});
const nextChips = this.chips.map(chip => {
if (isTagChip(chip) || isCollectionChip(chip)) {
return chip;
}
const id = isDocChip(chip)
? chip.docId
: isFileChip(chip)
? chip.fileId
: isAttachmentChip(chip)
? chip.sourceId
: isSelectedContextChip(chip)
? chip.uuid
: undefined;
const item = id && hashMap.get(id);
if (item && item.status) {
return {
...chip,
state: item.status,
tooltip: 'error' in item ? item.error : undefined,
};
}
return chip;
});
this.updateChips(nextChips);
this.onEmbeddingProgressChange?.(count);
if (count.processing === 0) {
this._abortPoll();
}
};
private readonly _abortPoll = () => {
this._pollAbortController?.abort();
this._pollAbortController = null;
};
private readonly _abortPollEmbeddingStatus = () => {
this._pollEmbeddingStatusAbortController?.abort();
this._pollEmbeddingStatusAbortController = null;
if (!this.runtime) return;
await this.runtime.dispatch({ type: 'pollEmbeddingStatus' });
this.syncChipsFromRuntime();
};
private readonly initComposer = async () => {
const userId = (await AIProvider.userInfo)?.id;
const userId = AIAppEvents.userInfo.value?.id;
if (!userId || !this.session) return;
await this.initChips();
@@ -1,17 +1,10 @@
/**
* @vitest-environment happy-dom
*/
import { afterEach, describe, expect, test, vi } from 'vitest';
import { describe, expect, test, vi } from 'vitest';
import { AIProvider } from '../../provider';
import { AIChatContent } from './ai-chat-content';
const originalHistories = AIProvider.histories;
afterEach(() => {
AIProvider.provide('histories', originalHistories as any);
});
describe('AIChatContent pinned scroll tracking', () => {
test('records scroll position from the chat messages host', async () => {
let scrollEndHandler: (() => void) | undefined;
@@ -47,103 +40,44 @@ describe('AIChatContent pinned scroll tracking', () => {
});
});
describe('AIChatContent history loading', () => {
test('replaces messages when the active session changes', async () => {
const histories = {
chats: vi.fn(async (_workspaceId: string, sessionId: string) => [
{
messages: [
{
id: `${sessionId}-message`,
role: 'user',
content: sessionId,
createdAt: '2026-01-01T00:00:00.000Z',
},
],
},
]),
actions: vi.fn(async () => []),
cleanup: vi.fn(),
ids: vi.fn(),
describe('AIChatContent runtime snapshot sync', () => {
test('derives messages and loading state from runtime snapshot', () => {
const runtimeMessage = {
id: 'message-1',
role: 'user',
content: 'hello',
createdAt: '2026-01-01T00:00:00.000Z',
};
AIProvider.provide('histories', histories as any);
const content: {
updateHistoryCounter: number;
historyKey: string | undefined;
workspaceId: string;
docId: string;
session: { sessionId: string };
chatContextValue: { messages: unknown[]; status?: string };
updateContext: (context: { messages: unknown[] }) => void;
} = {
updateHistoryCounter: 0,
historyKey: undefined,
workspaceId: 'ws-1',
docId: 'doc-1',
session: { sessionId: 'session-1' },
chatContextValue: { messages: [] },
updateContext(context: { messages: unknown[] }) {
this.chatContextValue = {
...this.chatContextValue,
...context,
};
},
};
await (AIChatContent.prototype as any).updateHistory.call(content);
expect(
content.chatContextValue.messages.map((message: any) => message.id)
).toEqual(['session-1-message']);
content.session = { sessionId: 'session-2' };
await (AIChatContent.prototype as any).updateHistory.call(content);
expect(
content.chatContextValue.messages.map((message: any) => message.id)
).toEqual(['session-2-message']);
});
test('does not overwrite in-flight optimistic messages when a session is created', async () => {
const histories = {
chats: vi.fn(async () => [{ messages: [] }]),
actions: vi.fn(async () => []),
cleanup: vi.fn(),
ids: vi.fn(),
};
AIProvider.provide('histories', histories as any);
const optimisticMessages = [
{
id: '',
role: 'user',
content: 'hello',
createdAt: '2026-01-01T00:00:00.000Z',
},
{
id: '',
role: 'assistant',
content: '',
createdAt: '2026-01-01T00:00:01.000Z',
},
];
const updateContext = vi.fn();
const content = {
updateHistoryCounter: 0,
historyKey: 'ws-1:doc-1:',
workspaceId: 'ws-1',
docId: 'doc-1',
session: { sessionId: 'session-1' },
chatContextValue: {
messages: optimisticMessages,
const content = Object.create(AIChatContent.prototype) as AIChatContent;
Object.defineProperty(content, 'runtimeSnapshot', {
configurable: true,
value: {
messages: [runtimeMessage],
status: 'loading',
error: null,
readiness: 'initializing',
history: { loading: false },
},
updateContext,
};
});
Object.defineProperty(content, 'chatContextValue', {
configurable: true,
value: {
messages: [],
status: 'idle',
error: null,
abortController: null,
},
});
Object.defineProperty(content, '_initializeScrollListeners', {
configurable: true,
value: vi.fn(),
});
await (AIChatContent.prototype as any).updateHistory.call(content);
(content as any).updated(new Map([['runtimeSnapshot', null]]));
expect(updateContext).not.toHaveBeenCalled();
expect(content.chatContextValue.messages).toBe(optimisticMessages);
expect(content.messages).toEqual([runtimeMessage]);
expect(content.isHistoryLoading).toBe(true);
expect(content.chatContextValue.messages).toEqual([]);
expect(content.chatContextValue.status).toBe('idle');
});
});
@@ -12,10 +12,7 @@ import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
import type { PeekViewService } from '@affine/core/modules/peek-view';
import type { AppThemeService } from '@affine/core/modules/theme';
import type {
ContextEmbedStatus,
CopilotChatHistoryFragment,
} from '@affine/graphql';
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit';
import { type EditorHost, ShadowlessElement } from '@blocksuite/affine/std';
import type { ExtensionType } from '@blocksuite/affine/store';
@@ -28,7 +25,9 @@ import { createRef, type Ref, ref } from 'lit/directives/ref.js';
import { styleMap } from 'lit/directives/style-map.js';
import { pick } from 'lodash-es';
import { type AIChatParams, AIProvider } from '../../provider/ai-provider';
import { AIAppEvents } from '../../provider/ai-app-events';
import type { AIChatParams } from '../../provider/ai-provider';
import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat';
import { extractSelectedContent } from '../../utils/extract';
import { HISTORY_IMAGE_ACTIONS } from '../../utils/history-image-actions';
import type { SearchMenuConfig } from '../ai-chat-add-context';
@@ -36,8 +35,6 @@ import type { DocDisplayConfig } from '../ai-chat-chips';
import type { AIReasoningConfig } from '../ai-chat-input';
import {
type AIChatMessages,
type ChatAction,
type ChatMessage,
type HistoryMessage,
isChatMessage,
} from '../ai-chat-messages';
@@ -126,9 +123,10 @@ export class AIChatContent extends SignalWatcher(
accessor session!: CopilotChatHistoryFragment | null | undefined;
@property({ attribute: false })
accessor createSession!: () => Promise<
CopilotChatHistoryFragment | undefined
>;
accessor runtime: AIChatRuntime | null | undefined;
@property({ attribute: false })
accessor runtimeSnapshot: AIChatSnapshot | null | undefined;
@property({ attribute: false })
accessor workspaceId!: string;
@@ -172,14 +170,6 @@ export class AIChatContent extends SignalWatcher(
@property({ attribute: false })
accessor aiModelService!: AIModelService;
@property({ attribute: false })
accessor onEmbeddingProgressChange:
| ((count: Record<ContextEmbedStatus, number>) => void)
| undefined;
@property({ attribute: false })
accessor onContextChange!: (context: Partial<ChatContextValue>) => void;
@property({ attribute: false })
accessor onOpenDoc!: (docId: string, sessionId?: string) => void;
@@ -198,9 +188,6 @@ export class AIChatContent extends SignalWatcher(
@state()
accessor chatContextValue: ChatContextValue = DEFAULT_CHAT_CONTEXT_VALUE;
@state()
accessor isHistoryLoading = false;
@state()
private accessor showPreviewPanel = false;
@@ -210,15 +197,13 @@ export class AIChatContent extends SignalWatcher(
private readonly chatMessagesRef: Ref<AIChatMessages> =
createRef<AIChatMessages>();
// request counter to track the latest request
private updateHistoryCounter = 0;
private historyKey: string | undefined;
private lastScrollTop: number | undefined;
get messages() {
return this.chatContextValue.messages.filter(item => {
const messages =
(this.runtimeSnapshot?.messages as HistoryMessage[] | undefined) ??
this.chatContextValue.messages;
return messages.filter(item => {
return (
isChatMessage(item) ||
item.messages?.length === 3 ||
@@ -232,86 +217,15 @@ export class AIChatContent extends SignalWatcher(
return false;
}
private async updateHistory() {
const currentRequest = ++this.updateHistoryCounter;
if (!AIProvider.histories) {
return;
}
const sessionId = this.session?.sessionId;
const nextHistoryKey = `${this.workspaceId}:${this.docId ?? ''}:${
sessionId ?? ''
}`;
const previousHistoryKey = this.historyKey;
const preserveCurrentMessages = previousHistoryKey === nextHistoryKey;
this.historyKey = nextHistoryKey;
const [histories, actions] = await Promise.all([
sessionId
? AIProvider.histories.chats(this.workspaceId, sessionId)
: Promise.resolve([]),
this.docId && this.showActions
? AIProvider.histories.actions(this.workspaceId, this.docId)
: Promise.resolve([]),
]);
// Check if this is still the latest request
if (currentRequest !== this.updateHistoryCounter) {
return;
}
if (
!preserveCurrentMessages &&
(this.chatContextValue.status === 'loading' ||
this.chatContextValue.status === 'transmitting') &&
this.chatContextValue.messages.length
) {
return;
}
const messages: HistoryMessage[] = preserveCurrentMessages
? this.chatContextValue.messages.slice().filter(isChatMessage)
: [];
const chatActions = (actions || []) as ChatAction[];
messages.push(...chatActions);
const chatMessages = (histories?.[0]?.messages || []) as ChatMessage[];
messages.push(...chatMessages);
this.updateContext({
messages: messages.sort(
(a, b) =>
new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime()
),
});
}
private readonly updateActions = async () => {
if (!this.docId || !AIProvider.histories || !this.showActions) {
return;
}
const actions = await AIProvider.histories.actions(
this.workspaceId,
this.docId
get isHistoryLoading() {
const snapshot = this.runtimeSnapshot;
return (
snapshot?.readiness === 'initializing' || !!snapshot?.history.loading
);
if (actions && actions.length) {
const chatMessages = this.chatContextValue.messages.filter(message =>
isChatMessage(message)
);
const chatActions = actions as ChatAction[];
const messages: HistoryMessage[] = [...chatMessages, ...chatActions];
this.updateContext({
messages: messages.sort(
(a, b) =>
new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime()
),
});
}
};
}
private readonly updateContext = (context: Partial<ChatContextValue>) => {
this.chatContextValue = { ...this.chatContextValue, ...context };
this.onContextChange?.(context);
this.updateDraft(context).catch(console.error);
};
@@ -331,9 +245,7 @@ export class AIChatContent extends SignalWatcher(
};
private readonly initChatContent = async () => {
this.isHistoryLoading = true;
await this.updateHistory();
this.isHistoryLoading = false;
this._initializeScrollListeners();
};
protected override firstUpdated(): void {}
@@ -382,13 +294,13 @@ export class AIChatContent extends SignalWatcher(
public openPreviewPanel(content?: TemplateResult<1>) {
this.showPreviewPanel = true;
if (content) this.previewPanelContent = content;
AIProvider.slots.previewPanelOpenChange.next(true);
AIAppEvents.previewPanelOpenChange.next(true);
}
public closePreviewPanel(destroyContent: boolean = false) {
this.showPreviewPanel = false;
if (destroyContent) this.previewPanelContent = null;
AIProvider.slots.previewPanelOpenChange.next(false);
AIAppEvents.previewPanelOpenChange.next(false);
}
public get isPreviewPanelOpen() {
@@ -416,19 +328,7 @@ export class AIChatContent extends SignalWatcher(
this.subscriptionService.subscription.revalidate();
this._disposables.add(
AIProvider.slots.actions.subscribe(({ event }) => {
const { status } = this.chatContextValue;
if (
event === 'finished' &&
(status === 'idle' || status === 'success')
) {
this.updateActions().catch(console.error);
}
})
);
this._disposables.add(
AIProvider.slots.requestOpenWithChat.subscribe(
AIAppEvents.requestOpenWithChat.subscribe(
(params: AIChatParams | null) => {
if (!params) {
return;
@@ -445,7 +345,7 @@ export class AIChatContent extends SignalWatcher(
.catch(console.error);
}
}
AIProvider.slots.requestOpenWithChat.next(null);
AIAppEvents.requestOpenWithChat.next(null);
}
)
);
@@ -463,7 +363,8 @@ export class AIChatContent extends SignalWatcher(
.workspaceId=${this.workspaceId}
.docId=${this.docId}
.session=${this.session}
.createSession=${this.createSession}
.runtime=${this.runtime}
.runtimeSnapshot=${this.runtimeSnapshot}
.chatContextValue=${this.chatContextValue}
.updateContext=${this.updateContext}
.isHistoryLoading=${this.isHistoryLoading}
@@ -491,10 +392,10 @@ export class AIChatContent extends SignalWatcher(
.workspaceId=${this.workspaceId}
.docId=${this.docId}
.session=${this.session}
.createSession=${this.createSession}
.runtime=${this.runtime}
.runtimeSnapshot=${this.runtimeSnapshot}
.chatContextValue=${this.chatContextValue}
.updateContext=${this.updateContext}
.onEmbeddingProgressChange=${this.onEmbeddingProgressChange}
.reasoningConfig=${this.reasoningConfig}
.docDisplayConfig=${this.docDisplayConfig}
.searchMenuConfig=${this.searchMenuConfig}
@@ -24,19 +24,14 @@ import { repeat } from 'lit/directives/repeat.js';
import { styleMap } from 'lit/directives/style-map.js';
import { ChatAbortIcon } from '../../_common/icons';
import { type AIError, AIProvider, type AISendParams } from '../../provider';
import { AIAppEvents, type AISendParams } from '../../provider';
import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat';
import { reportResponse } from '../../utils/action-reporter';
import { readBlobAsURL } from '../../utils/image';
import { mergeStreamObjects } from '../../utils/stream-objects';
import type { SearchMenuConfig } from '../ai-chat-add-context';
import { addFilesToChat } from '../ai-chat-chips/attachment-utils';
import type { ChatChip, DocDisplayConfig } from '../ai-chat-chips/type';
import { isDocChip } from '../ai-chat-chips/utils';
import {
type ChatMessage,
isChatMessage,
StreamObjectSchema,
} from '../ai-chat-messages';
import type { AIChatInputContext, AIReasoningConfig } from './type';
function getFirstTwoLines(text: string) {
@@ -340,6 +335,12 @@ export class AIChatInput extends SignalWatcher(
@property({ attribute: false })
accessor session!: CopilotChatHistoryFragment | null | undefined;
@property({ attribute: false })
accessor runtime: AIChatRuntime | null | undefined;
@property({ attribute: false })
accessor runtimeSnapshot: AIChatSnapshot | null | undefined;
@property({ attribute: false })
accessor isContextProcessing!: boolean | undefined;
@@ -371,11 +372,6 @@ export class AIChatInput extends SignalWatcher(
@property({ attribute: false })
accessor chips: ChatChip[] = [];
@property({ attribute: false })
accessor createSession!: () => Promise<
CopilotChatHistoryFragment | undefined
>;
@property({ attribute: false })
accessor updateContext!: (context: Partial<AIChatInputContext>) => void;
@@ -441,7 +437,7 @@ export class AIChatInput extends SignalWatcher(
super.connectedCallback();
this._disposables.add(
AIProvider.slots.requestSendWithChat.subscribe(
AIAppEvents.requestSendWithChat.subscribe(
(params: AISendParams | null) => {
if (!params) {
return;
@@ -455,13 +451,13 @@ export class AIChatInput extends SignalWatcher(
this.send(input).catch(console.error);
}, 0);
}
AIProvider.slots.requestSendWithChat.next(null);
AIAppEvents.requestSendWithChat.next(null);
}
)
);
this._disposables.add(
AIProvider.slots.requestOpenWithChat.subscribe(params => {
AIAppEvents.requestOpenWithChat.subscribe(params => {
if (!params) return;
const { input, host } = params;
@@ -552,7 +548,8 @@ export class AIChatInput extends SignalWatcher(
}
protected override render() {
const { images, status } = this.chatContextValue;
const { images } = this.chatContextValue;
const status = this.runtimeSnapshot?.status ?? this.chatContextValue.status;
const hasImages = images.length > 0;
const maxHeight = hasImages ? 272 + 2 : 200 + 2;
@@ -664,6 +661,10 @@ export class AIChatInput extends SignalWatcher(
return true;
}
if (this.runtimeSnapshot && !this.runtimeSnapshot.uiPolicy.canSend) {
return true;
}
if (this.isContextProcessing) {
return true;
}
@@ -782,9 +783,11 @@ export class AIChatInput extends SignalWatcher(
};
private readonly _handleAbort = () => {
this.chatContextValue.abortController?.abort();
this.updateContext({ status: 'success' });
reportResponse('aborted:stop');
if (this.runtime) {
this.runtime.dispatch({ type: 'stop' }).catch(console.error);
reportResponse('aborted:stop', this.host);
return;
}
};
private readonly _toggleReasoning = (extendedThinking: boolean) => {
@@ -817,153 +820,52 @@ export class AIChatInput extends SignalWatcher(
};
send = async (text: string) => {
try {
const {
status,
markdown,
images,
snapshot,
combinedElementsMarkdown,
html,
} = this.chatContextValue;
if (!this.runtime) return;
const { markdown, images, snapshot, combinedElementsMarkdown, html } =
this.chatContextValue;
const userInput = (markdown ? `${markdown}\n` : '') + text;
const imageAttachments = await Promise.all(
images?.map(image => readBlobAsURL(image))
);
const contexts = await this._getMatchedContexts();
const enableSendDetailedObject =
this.affineFeatureFlagService.flags.enable_send_detailed_object_to_ai
.value;
const userInfo = AIAppEvents.userInfo.value;
if (status === 'loading' || status === 'transmitting') return;
if (!text) return;
if (!AIProvider.actions.chat) return;
const abortController = new AbortController();
this.updateContext({
images: [],
status: 'loading',
error: null,
quote: '',
markdown: '',
abortController,
});
const imageAttachments = await Promise.all(
images?.map(image => readBlobAsURL(image))
);
const userInput = (markdown ? `${markdown}\n` : '') + text;
// optimistic update messages
await this._preUpdateMessages(userInput, imageAttachments);
const sessionId = (await this.createSession())?.sessionId;
let contexts = await this._getMatchedContexts();
if (abortController.signal.aborted) {
return;
}
const enableSendDetailedObject =
this.affineFeatureFlagService.flags.enable_send_detailed_object_to_ai
.value;
const modelId = this.aiModelService.modelId.value;
const stream = await AIProvider.actions.chat({
sessionId,
input: userInput,
contexts: {
...contexts,
selectedSnapshot:
snapshot && enableSendDetailedObject ? snapshot : undefined,
selectedMarkdown:
combinedElementsMarkdown && enableSendDetailedObject
? combinedElementsMarkdown
: undefined,
html: html || undefined,
},
docId: this.docId,
attachments: images,
workspaceId: this.workspaceId,
stream: true,
signal: abortController.signal,
isRootSession: this.isRootSession,
where: this.trackOptions?.where,
control: this.trackOptions?.control,
reasoning: this._isReasoningActive,
toolsConfig: this.aiToolsConfigService.config.value,
modelId,
});
for await (const text of stream) {
const messages = this.chatContextValue.messages.slice(0);
const last = messages.at(-1);
if (last && isChatMessage(last)) {
try {
const parsed = StreamObjectSchema.parse(JSON.parse(text));
const streamObjects = mergeStreamObjects([
...(last.streamObjects ?? []),
parsed,
]);
messages[messages.length - 1] = {
...last,
streamObjects,
};
} catch {
messages[messages.length - 1] = {
...last,
content: last.content + text,
};
}
this.updateContext({ messages, status: 'transmitting' });
}
}
this.updateContext({ status: 'success' });
this.onChatSuccess?.();
// update message id from server
await this._postUpdateMessages();
} catch (error) {
this.updateContext({ status: 'error', error: error as AIError });
} finally {
this.updateContext({ abortController: null });
}
};
private readonly _preUpdateMessages = async (
userInput: string,
attachments: string[]
) => {
const userInfo = await AIProvider.userInfo;
this.updateContext({
messages: [
...this.chatContextValue.messages,
{
id: '',
role: 'user',
content: userInput,
createdAt: new Date().toISOString(),
attachments,
userId: userInfo?.id,
userName: userInfo?.name,
avatarUrl: userInfo?.avatarUrl ?? undefined,
},
{
id: '',
role: 'assistant',
content: '',
createdAt: new Date().toISOString(),
},
],
images: [],
quote: '',
markdown: '',
});
};
private readonly _postUpdateMessages = async () => {
const sessionId = this.session?.sessionId;
if (!sessionId || !AIProvider.histories) return;
const { messages } = this.chatContextValue;
const last = messages[messages.length - 1] as ChatMessage;
if (!last.id) {
const historyIds = await AIProvider.histories.ids(
this.workspaceId,
this.docId,
{ sessionId, withMessages: true }
);
if (!historyIds || !historyIds[0]) return;
last.id = historyIds[0].messages.at(-1)?.id ?? '';
}
await this.runtime.dispatch({
type: 'send',
input: userInput,
contexts: {
...contexts,
selectedSnapshot:
snapshot && enableSendDetailedObject ? snapshot : undefined,
selectedMarkdown:
combinedElementsMarkdown && enableSendDetailedObject
? combinedElementsMarkdown
: undefined,
html: html || undefined,
},
attachments: images,
attachmentPreviews: imageAttachments,
isRootSession: this.isRootSession,
where: this.trackOptions?.where,
control: this.trackOptions?.control,
reasoning: this._isReasoningActive,
toolsConfig: this.aiToolsConfigService.config.value,
modelId: this.aiModelService.modelId.value,
userInfo: {
userId: userInfo?.id,
userName: userInfo?.name,
avatarUrl: userInfo?.avatarUrl ?? undefined,
},
});
this.onChatSuccess?.();
};
private async _getMatchedContexts() {
@@ -93,4 +93,35 @@ describe('AIChatMessages scrolling', () => {
expect(element.canScrollDown).toBe(false);
expect(scrollToEnd).toHaveBeenCalled();
});
test('message keys are scoped by active tab', () => {
const element = {} as AIChatMessages;
const message = {
id: 'message-1',
role: 'assistant',
content: 'reply',
createdAt: new Date().toISOString(),
};
element.runtimeSnapshot = {
activeTabId: 'session-1',
} as AIChatMessages['runtimeSnapshot'];
const firstKey = (AIChatMessages.prototype as any)._getMessageKey.call(
element,
message,
0
);
element.runtimeSnapshot = {
activeTabId: 'session-2',
} as AIChatMessages['runtimeSnapshot'];
const secondKey = (AIChatMessages.prototype as any)._getMessageKey.call(
element,
message,
0
);
expect(firstKey).toBe('session-1:message-1');
expect(secondKey).toBe('session-2:message-1');
});
});
@@ -19,8 +19,13 @@ import { repeat } from 'lit/directives/repeat.js';
import { debounce, throttle } from 'lodash-es';
import { AffineIcon } from '../../_common/icons';
import { type AIError, AIProvider, UnauthorizedError } from '../../provider';
import { mergeStreamObjects } from '../../utils/stream-objects';
import {
AIAppEvents,
type AIError,
GeneralNetworkError,
UnauthorizedError,
} from '../../provider';
import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat';
import type { DocDisplayConfig } from '../ai-chat-chips';
import { type ChatContextValue } from '../ai-chat-content/type';
import type { AIReasoningConfig } from '../ai-chat-input';
@@ -30,12 +35,7 @@ import {
AI_CHAT_SCROLL_DOWN_INDICATOR_THRESHOLD,
} from './auto-scroll';
import { AIPreloadConfig } from './preload-config';
import {
type HistoryMessage,
isChatAction,
isChatMessage,
StreamObjectSchema,
} from './type';
import { type HistoryMessage, isChatAction, isChatMessage } from './type';
export class AIChatMessages extends WithDisposable(ShadowlessElement) {
static override styles = css`
@@ -179,9 +179,10 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) {
accessor session!: CopilotChatHistoryFragment | null | undefined;
@property({ attribute: false })
accessor createSession!: () => Promise<
CopilotChatHistoryFragment | undefined
>;
accessor runtime: AIChatRuntime | null | undefined;
@property({ attribute: false })
accessor runtimeSnapshot: AIChatSnapshot | null | undefined;
@property({ attribute: false })
accessor updateContext!: (context: Partial<ChatContextValue>) => void;
@@ -227,8 +228,21 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) {
private _lastObservedScrollTop = 0;
private get _isReasoningActive() {
return !!this.reasoningConfig.enabled.value;
private get chatStatus() {
return this.runtimeSnapshot?.status ?? this.chatContextValue.status;
}
private get chatError(): AIError | null {
return this.toAIError(
this.runtimeSnapshot?.error ?? this.chatContextValue.error
);
}
private toAIError(error: Error | null): AIError | null {
if (!error) return null;
return 'type' in error
? (error as AIError)
: new GeneralNetworkError(error.message);
}
private _renderAIOnboarding() {
@@ -296,8 +310,20 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) {
return scrollHeight - scrollTop - clientHeight;
}
private _getMessageKey(item: HistoryMessage, index: number) {
const tabKey =
this.runtimeSnapshot?.activeTabId ??
this.runtimeSnapshot?.activeSessionId ??
'legacy';
if (isChatMessage(item)) {
return `${tabKey}:${item.id || index}`;
}
return `${tabKey}:${index}`;
}
protected override render() {
const { status, error } = this.chatContextValue;
const status = this.chatStatus;
const error = this.chatError;
const { isHistoryLoading } = this;
const filteredItems = this.messages;
@@ -337,7 +363,7 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) {
</div> `
: repeat(
filteredItems,
(_, index) => index,
(item, index) => this._getMessageKey(item, index),
(item, index) => {
const isLast = index === filteredItems.length - 1;
if (isChatMessage(item) && item.role === 'user') {
@@ -389,21 +415,19 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) {
super.connectedCallback();
const { disposables } = this;
Promise.resolve(AIProvider.userInfo)
.then(res => {
this.avatarUrl = res?.avatarUrl ?? '';
})
.catch(console.error);
this.avatarUrl = AIAppEvents.userInfo.value?.avatarUrl ?? '';
disposables.add(
AIProvider.slots.userInfo.subscribe(userInfo => {
const { status, error } = this.chatContextValue;
AIAppEvents.userInfo.subscribe(userInfo => {
const status = this.chatStatus;
const error = this.chatError;
this.avatarUrl = userInfo?.avatarUrl ?? '';
if (
status === 'error' &&
error instanceof UnauthorizedError &&
userInfo
) {
this.runtime?.dispatch({ type: 'clearError' }).catch(console.error);
this.updateContext({ status: 'idle', error: null });
}
})
@@ -452,7 +476,7 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) {
if (changedProperties.has('messages')) {
this._onScroll();
if (this.chatContextValue.status === 'transmitting') {
if (this.chatStatus === 'transmitting') {
this._throttledScrollToEnd();
} else if (this._autoScrollEnabled) {
this.scrollToEnd();
@@ -460,9 +484,9 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) {
}
if (
changedProperties.has('chatContextValue') &&
(this.chatContextValue.status === 'success' ||
this.chatContextValue.status === 'error')
(changedProperties.has('chatContextValue') ||
changedProperties.has('runtimeSnapshot')) &&
(this.chatStatus === 'success' || this.chatStatus === 'error')
) {
this._onScroll();
}
@@ -484,69 +508,11 @@ export class AIChatMessages extends WithDisposable(ShadowlessElement) {
}
retry = async () => {
try {
const sessionId = (await this.createSession())?.sessionId;
if (!sessionId) return;
if (!AIProvider.actions.chat) return;
const abortController = new AbortController();
const messages = [...this.chatContextValue.messages];
const last = messages[messages.length - 1];
if ('content' in last) {
last.content = '';
last.streamObjects = [];
last.createdAt = new Date().toISOString();
}
this.updateContext({
messages,
status: 'loading',
error: null,
abortController,
});
const stream = await AIProvider.actions.chat({
sessionId,
retry: true,
docId: this.docId,
workspaceId: this.workspaceId,
stream: true,
signal: abortController.signal,
where: 'chat-panel',
control: 'chat-send',
isRootSession: true,
reasoning: this._isReasoningActive,
toolsConfig: this.aiToolsConfigService.config.value,
});
for await (const text of stream) {
const messages = this.chatContextValue.messages.slice(0);
const last = messages.at(-1);
if (last && isChatMessage(last)) {
try {
const parsed = StreamObjectSchema.parse(JSON.parse(text));
const streamObjects = mergeStreamObjects([
...(last.streamObjects ?? []),
parsed,
]);
messages[messages.length - 1] = {
...last,
streamObjects,
};
} catch {
messages[messages.length - 1] = {
...last,
content: last.content + text,
};
}
this.updateContext({ messages, status: 'transmitting' });
}
}
this.updateContext({ status: 'success' });
} catch (error) {
this.updateContext({ status: 'error', error: error as AIError });
} finally {
this.updateContext({ abortController: null });
}
if (!this.runtime) return;
const last = this.messages.at(-1);
await this.runtime.dispatch({
type: 'retry',
messageId: last && isChatMessage(last) ? last.id : '',
});
};
}
@@ -6,7 +6,7 @@ import {
SendIcon,
} from '@blocksuite/icons/lit';
import { AIProvider } from '../../provider/ai-provider.js';
import { AIAppEvents } from '../../provider/ai-app-events.js';
import completeWritingWithAI from './templates/completeWritingWithAI.zip';
import freelyCommunicateWithAI from './templates/freelyCommunicateWithAI.zip';
import readAforeign from './templates/readAforeign.zip';
@@ -19,7 +19,7 @@ export const AIPreloadConfig = [
text: 'Read a foreign language article with AI',
testId: 'read-foreign-language-article-with-ai',
handler: () => {
AIProvider.slots.requestInsertTemplate.next({
AIAppEvents.requestInsertTemplate.next({
template: readAforeign,
mode: 'edgeless',
});
@@ -30,7 +30,7 @@ export const AIPreloadConfig = [
text: 'Tidy an article with AI MindMap Action',
testId: 'tidy-an-article-with-ai-mindmap-action',
handler: () => {
AIProvider.slots.requestInsertTemplate.next({
AIAppEvents.requestInsertTemplate.next({
template: TidyMindMapV3,
mode: 'edgeless',
});
@@ -41,7 +41,7 @@ export const AIPreloadConfig = [
text: 'Add illustrations to the article',
testId: 'add-illustrations-to-the-article',
handler: () => {
AIProvider.slots.requestInsertTemplate.next({
AIAppEvents.requestInsertTemplate.next({
template: redHat,
mode: 'edgeless',
});
@@ -52,7 +52,7 @@ export const AIPreloadConfig = [
text: 'Complete writing with AI',
testId: 'complete-writing-with-ai',
handler: () => {
AIProvider.slots.requestInsertTemplate.next({
AIAppEvents.requestInsertTemplate.next({
template: completeWritingWithAI,
mode: 'edgeless',
});
@@ -63,7 +63,7 @@ export const AIPreloadConfig = [
text: 'Freely communicate with AI',
testId: 'freely-communicate-with-ai',
handler: () => {
AIProvider.slots.requestInsertTemplate.next({
AIAppEvents.requestInsertTemplate.next({
template: freelyCommunicateWithAI,
mode: 'edgeless',
});
@@ -7,9 +7,15 @@ import { css, html, type PropertyValues } from 'lit';
import { property } from 'lit/decorators.js';
import { repeat } from 'lit/directives/repeat.js';
import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat';
const DEFAULT_TAB_TITLE = 'New chat';
const TITLE_MAX_LENGTH = 28;
type RenderTabItem =
| { kind: 'draft' }
| { kind: 'session'; session: CopilotChatHistoryFragment };
function truncate(text: string): string {
if (text.length <= TITLE_MAX_LENGTH) return text;
return `${text.slice(0, TITLE_MAX_LENGTH).trimEnd()}`;
@@ -27,19 +33,10 @@ function deriveTabTitle(session: CopilotChatHistoryFragment): string {
export class AIChatTabs extends WithDisposable(ShadowlessElement) {
@property({ attribute: false })
accessor sessions: CopilotChatHistoryFragment[] = [];
accessor runtimeSnapshot: AIChatSnapshot | null | undefined;
@property({ attribute: false })
accessor activeSessionId: string | undefined;
@property({ attribute: false })
accessor showDraftTab = false;
@property({ attribute: false })
accessor onSelectTab!: (sessionId: string) => void;
@property({ attribute: false })
accessor onCloseTab!: (sessionId: string) => void;
accessor runtime!: AIChatRuntime;
static override styles = css`
ai-chat-tabs {
@@ -100,6 +97,10 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) {
color: ${unsafeCSSVarV2('text/primary')};
}
.tab[data-kind='draft'] {
padding: 0 10px;
}
.tab-title {
white-space: nowrap;
overflow: hidden;
@@ -132,21 +133,46 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) {
`;
override render() {
if (!this.sessions.length && !this.showDraftTab) return html``;
const items = this._getRenderItems();
if (!items.length) return html``;
return html`
<div class="ai-chat-tabs" data-testid="ai-chat-tabs">
<div class="tabs-scroll" @wheel=${this._handleWheel}>
${this.showDraftTab ? this._renderDraftTab() : null}
${repeat(
this.sessions,
session => session.sessionId,
session => this._renderTab(session)
items,
item => (item.kind === 'draft' ? 'draft' : item.session.sessionId),
item =>
item.kind === 'draft'
? this._renderDraftTab()
: this._renderTab(item.session)
)}
</div>
</div>
`;
}
private _getRenderItems(): RenderTabItem[] {
const snapshot = this.runtimeSnapshot;
if (!snapshot) return [];
const sessions = new Map(
snapshot.sessions.map(session => [session.sessionId, session])
);
const items: RenderTabItem[] = [];
for (const tab of snapshot.tabs) {
if (tab.kind === 'draft') {
if (snapshot.uiPolicy.showDraftTab) {
items.push({ kind: 'draft' });
}
continue;
}
const session = sessions.get(tab.sessionId);
if (session) {
items.push({ kind: 'session', session });
}
}
return items;
}
private readonly _handleWheel = (e: WheelEvent) => {
const el = e.currentTarget as HTMLElement;
if (el.scrollWidth <= el.clientWidth) return;
@@ -157,7 +183,7 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) {
};
private _renderTab(session: CopilotChatHistoryFragment) {
const active = session.sessionId === this.activeSessionId;
const active = session.sessionId === this.runtimeSnapshot?.activeSessionId;
const title = deriveTabTitle(session);
return html`
<div
@@ -182,10 +208,12 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) {
}
private _renderDraftTab() {
const active = this.runtimeSnapshot?.activeTabId?.startsWith('draft:');
return html`
<div
class="tab"
data-active="true"
data-kind="draft"
data-active=${active}
data-testid="ai-chat-draft-tab"
title=${DEFAULT_TAB_TITLE}
>
@@ -195,23 +223,27 @@ export class AIChatTabs extends WithDisposable(ShadowlessElement) {
}
private readonly _handleSelect = (sessionId: string) => {
if (sessionId === this.activeSessionId) return;
this.onSelectTab(sessionId);
if (sessionId === this.runtimeSnapshot?.activeSessionId) return;
this.runtime
.dispatch({ type: 'openSession', sessionId })
.catch(console.error);
};
private readonly _handleClose = (e: Event, sessionId: string) => {
e.stopPropagation();
this.onCloseTab(sessionId);
this.runtime
.dispatch({ type: 'closeTab', tabId: sessionId })
.catch(console.error);
};
override updated(changedProps: PropertyValues) {
super.updated(changedProps);
if (
(changedProps.has('activeSessionId') || changedProps.has('sessions')) &&
this.activeSessionId
) {
if (changedProps.has('runtimeSnapshot')) {
const activeSessionId = this.runtimeSnapshot?.activeSessionId;
const activeTab = this.renderRoot.querySelector(
`[data-session-id="${this.activeSessionId}"]`
activeSessionId
? `[data-session-id="${activeSessionId}"]`
: `[data-testid="ai-chat-draft-tab"]`
);
activeTab?.scrollIntoView({
behavior: 'smooth',
@@ -14,34 +14,22 @@ import { flip, offset } from '@floating-ui/dom';
import { css, html } from 'lit';
import { property, query } from 'lit/decorators.js';
import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat';
import type { DocDisplayConfig } from '../ai-chat-chips';
import type { ChatStatus } from '../ai-chat-messages';
export class AIChatToolbar extends WithDisposable(ShadowlessElement) {
@property({ attribute: false })
accessor session!: CopilotChatHistoryFragment | null | undefined;
@property({ attribute: false })
accessor workspaceId!: string;
accessor runtime!: AIChatRuntime;
@property({ attribute: false })
accessor runtimeSnapshot!: AIChatSnapshot;
@property({ attribute: false })
accessor docId: string | undefined;
@property({ attribute: false })
accessor status!: ChatStatus;
@property({ attribute: false })
accessor onNewSession!: () => void;
@property({ attribute: false })
accessor canCreateNewSession = true;
@property({ attribute: false })
accessor onTogglePin!: () => Promise<void>;
@property({ attribute: false })
accessor onOpenSession!: (sessionId: string) => void;
@property({ attribute: false })
accessor onOpenDoc!: (docId: string, sessionId: string) => void;
@@ -62,7 +50,12 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) {
private abortController: AbortController | null = null;
get isGenerating() {
return this.status === 'transmitting' || this.status === 'loading';
const status = this.runtimeSnapshot.status;
return status === 'transmitting' || status === 'loading';
}
get canCreateNewSession() {
return this.runtimeSnapshot.uiPolicy.canCreateNewSession;
}
static override styles = css`
@@ -141,7 +134,7 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) {
);
return;
}
await this.onTogglePin();
await this.runtime.dispatch({ type: 'togglePinActiveSession' });
};
private readonly unpinConfirm = async () => {
@@ -157,7 +150,7 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) {
if (!confirm) {
return false;
}
await this.onTogglePin();
await this.runtime.dispatch({ type: 'togglePinActiveSession' });
} catch {
this.notificationService.toast('Failed to unpin the chat');
}
@@ -168,7 +161,7 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) {
private readonly onPlusClick = async () => {
const confirm = await this.unpinConfirm();
if (confirm) {
this.onNewSession();
await this.runtime.dispatch({ type: 'createNewSession' });
}
};
@@ -179,7 +172,11 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) {
}
const confirm = await this.unpinConfirm();
if (confirm) {
this.onOpenSession(sessionId);
await this.runtime.dispatch({
type: 'openSession',
sessionId,
});
this.closeHistoryMenu();
}
};
@@ -191,7 +188,7 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) {
this.onOpenDoc(docId, sessionId);
};
private readonly toggleHistoryMenu = () => {
private readonly toggleHistoryMenu = async () => {
if (this.abortController) {
this.abortController.abort();
return;
@@ -202,12 +199,23 @@ export class AIChatToolbar extends WithDisposable(ShadowlessElement) {
this.abortController = null;
});
try {
await this.runtime.dispatch({ type: 'refreshHistory' });
} catch (error) {
console.error(error);
}
if (this.abortController.signal.aborted) {
return;
}
createLitPortal({
template: html`
<ai-session-history
.session=${this.session}
.workspaceId=${this.workspaceId}
.docId=${this.docId}
.recentSessions=${this.runtime.getSnapshot().history.recent}
.currentDocSessions=${this.runtime.getSnapshot().history.currentDoc}
.loading=${this.runtime.getSnapshot().history.loading}
.docDisplayConfig=${this.docDisplayConfig}
.onSessionClick=${this.onSessionClick}
.onSessionDelete=${this.onSessionDelete}
@@ -5,9 +5,8 @@ import { unsafeCSSVar, unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
import { ShadowlessElement } from '@blocksuite/affine/std';
import { DeleteIcon } from '@blocksuite/icons/lit';
import { css, html, nothing, type PropertyValues } from 'lit';
import { property, query, state } from 'lit/decorators.js';
import { property, state } from 'lit/decorators.js';
import { AIProvider } from '../../provider';
import type { DocDisplayConfig } from '../ai-chat-chips';
interface GroupedSessions {
@@ -17,6 +16,31 @@ interface GroupedSessions {
older: BlockSuitePresets.AIRecentSession[];
}
type HistorySessionWithMessages = BlockSuitePresets.AIRecentSession &
Partial<Pick<CopilotChatHistoryFragment, 'messages'>>;
const DEFAULT_SESSION_TITLE = 'New chat';
const TITLE_MAX_LENGTH = 28;
function truncateSessionTitle(text: string) {
if (text.length <= TITLE_MAX_LENGTH) return text;
return `${text.slice(0, TITLE_MAX_LENGTH).trimEnd()}`;
}
function deriveSessionTitle(session: HistorySessionWithMessages) {
const explicit = session.title?.trim();
if (explicit) return truncateSessionTitle(explicit);
const firstUserMessage = session.messages?.find(
message => message.role === 'user'
);
const raw = firstUserMessage?.content?.trim();
if (!raw) return DEFAULT_SESSION_TITLE;
const newlineIdx = raw.indexOf('\n');
return truncateSessionTitle(
newlineIdx === -1 ? raw : raw.slice(0, newlineIdx)
);
}
export class AISessionHistory extends WithDisposable(ShadowlessElement) {
static override styles = css`
.ai-session-history {
@@ -157,10 +181,16 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
accessor session!: CopilotChatHistoryFragment | null | undefined;
@property({ attribute: false })
accessor workspaceId!: string;
accessor docId: string | undefined;
@property({ attribute: false })
accessor docId: string | undefined;
accessor recentSessions: BlockSuitePresets.AIRecentSession[] = [];
@property({ attribute: false })
accessor currentDocSessions: BlockSuitePresets.AIRecentSession[] = [];
@property({ attribute: false })
accessor loading = false;
@property({ attribute: false })
accessor docDisplayConfig!: DocDisplayConfig;
@@ -176,30 +206,9 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
@property({ attribute: false })
accessor onDocClick!: (docId: string, sessionId: string) => void;
@query('.ai-session-history')
accessor scrollContainer!: HTMLElement;
@state()
private accessor sessions: BlockSuitePresets.AIRecentSession[] | undefined;
@state()
private accessor currentDocSessions:
| BlockSuitePresets.AIRecentSession[]
| undefined;
@state()
private accessor loadingMore = false;
@state()
private accessor hasMore = true;
@state()
private accessor selectedSessionId: string | undefined;
private accessor currentOffset = 0;
private readonly pageSize = 10;
private groupSessionsByTime(
sessions: BlockSuitePresets.AIRecentSession[]
): GroupedSessions {
@@ -248,51 +257,9 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
return grouped;
}
private async getRecentSessions() {
this.loadingMore = true;
const moreSessions =
(await AIProvider.session?.getRecentSessions(
this.workspaceId,
this.pageSize,
this.currentOffset
)) || [];
this.sessions = [...(this.sessions || []), ...moreSessions];
this.currentOffset += moreSessions.length;
this.hasMore = moreSessions.length === this.pageSize;
this.loadingMore = false;
}
private async getCurrentDocSessions() {
if (!this.docId) {
this.currentDocSessions = [];
return;
}
this.currentDocSessions =
(await AIProvider.session?.getSessions(this.workspaceId, this.docId, {
action: false,
fork: false,
})) || [];
}
private readonly onScroll = () => {
if (!this.hasMore || this.loadingMore) {
return;
}
// load more when within 50px of bottom
const { scrollTop, scrollHeight, clientHeight } = this.scrollContainer;
const threshold = 50;
if (scrollTop + clientHeight >= scrollHeight - threshold) {
this.getRecentSessions().catch(console.error);
}
};
override connectedCallback() {
super.connectedCallback();
this.selectedSessionId = this.session?.sessionId ?? undefined;
this.getCurrentDocSessions().catch(console.error);
this.getRecentSessions().catch(console.error);
}
protected override willUpdate(changedProperties: PropertyValues) {
@@ -301,14 +268,6 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
}
}
override firstUpdated(changedProperties: PropertyValues) {
super.firstUpdated(changedProperties);
this.disposables.add(() => {
this.scrollContainer.removeEventListener('scroll', this.onScroll);
});
this.scrollContainer.addEventListener('scroll', this.onScroll);
}
private renderSessionGroup(
title: string,
sessions: BlockSuitePresets.AIRecentSession[]
@@ -320,6 +279,7 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
<div class="ai-session-group">
<div class="ai-session-group-title">${title}</div>
${sessions.map(session => {
const sessionTitle = deriveSessionTitle(session);
return html`
<div
class="ai-session-item"
@@ -336,7 +296,7 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
data-session-id=${session.sessionId}
>
<div class="ai-session-title">
${session.title || 'New chat'}
${sessionTitle}
<affine-tooltip .offsetX=${60}>
Click to open this chat
</affine-tooltip>
@@ -395,15 +355,15 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
}
private renderHistory() {
if (!this.sessions) {
if (this.loading) {
return this.renderLoading();
}
const currentDocSessions = this.currentDocSessions ?? [];
const currentDocSessions = this.currentDocSessions;
const currentDocSessionIds = new Set(
currentDocSessions.map(session => session.sessionId)
);
const otherSessions = this.sessions.filter(
const otherSessions = this.recentSessions.filter(
session =>
!currentDocSessionIds.has(session.sessionId) &&
(!this.docId || session.docId !== this.docId)
@@ -1,21 +1,17 @@
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import type { NotificationService } from '@blocksuite/affine/shared/services';
import type { AIChatRuntime, AIChatSnapshot } from '../../runtime/chat';
import type { DocDisplayConfig } from '../ai-chat-chips';
import type { ChatStatus } from '../ai-chat-messages';
import { AIChatToolbar } from './ai-chat-toolbar';
export type ConfigureAIChatToolbarOptions = {
session: CopilotChatHistoryFragment | null | undefined;
workspaceId: string;
runtime: AIChatRuntime;
runtimeSnapshot: AIChatSnapshot;
docId?: string;
status: ChatStatus;
docDisplayConfig: DocDisplayConfig;
notificationService: NotificationService;
onNewSession: () => void;
canCreateNewSession?: boolean;
onTogglePin: () => Promise<void>;
onOpenSession: (sessionId: string) => void;
onOpenDoc: (docId: string, sessionId: string) => void;
onSessionDelete: (session: BlockSuitePresets.AIRecentSession) => void;
};
@@ -31,15 +27,11 @@ export function configureAIChatToolbar(
options: ConfigureAIChatToolbarOptions
): AIChatToolbar {
tool.session = options.session;
tool.workspaceId = options.workspaceId;
tool.runtime = options.runtime;
tool.runtimeSnapshot = options.runtimeSnapshot;
tool.docId = options.docId;
tool.status = options.status;
tool.docDisplayConfig = options.docDisplayConfig;
tool.notificationService = options.notificationService;
tool.onNewSession = options.onNewSession;
tool.canCreateNewSession = options.canCreateNewSession ?? true;
tool.onTogglePin = options.onTogglePin;
tool.onOpenSession = options.onOpenSession;
tool.onOpenDoc = options.onOpenDoc;
tool.onSessionDelete = options.onSessionDelete;
return tool;
@@ -7,7 +7,6 @@ import type { Store } from '@blocksuite/affine/store';
import { css, html } from 'lit';
import { property } from 'lit/decorators.js';
import { AIProvider } from '../../provider';
import type { ChatContextValue } from '../ai-chat-content';
export class AIHistoryClear extends WithDisposable(ShadowlessElement) {
@@ -26,6 +25,9 @@ export class AIHistoryClear extends WithDisposable(ShadowlessElement) {
@property({ attribute: false })
accessor onHistoryCleared!: () => void;
@property({ attribute: false })
accessor onClearHistory!: (sessionIds: string[]) => Promise<void> | void;
static override styles = css`
.chat-history-clear {
cursor: pointer;
@@ -64,11 +66,10 @@ export class AIHistoryClear extends WithDisposable(ShadowlessElement) {
const actionIds = this.chatContextValue.messages
.filter(item => 'sessionId' in item)
.map(item => item.sessionId);
await AIProvider.histories?.cleanup(
this.doc.workspace.id,
this.doc.id,
[...(sessionId ? [sessionId] : []), ...(actionIds || [])]
);
await this.onClearHistory([
...(sessionId ? [sessionId] : []),
...(actionIds || []),
]);
this.notificationService.toast('History cleared');
this.onHistoryCleared?.();
}
@@ -9,7 +9,7 @@ import { flip, offset } from '@floating-ui/dom';
import { css, html, LitElement } from 'lit';
import { property } from 'lit/decorators.js';
import { AIProvider } from '../provider';
import { AIAppEvents } from '../provider';
import { getAIPanelWidget } from '../utils/ai-widgets';
import { extractSelectedContent } from '../utils/extract';
import type { AffineAIPanelWidgetConfig } from '../widgets/ai-panel/type';
@@ -75,7 +75,7 @@ export class AskAIToolbarButton extends WithDisposable(LitElement) {
aiPanel.hide();
extractSelectedContent(this.host)
.then(context => {
AIProvider.slots.requestSendWithChat.next({
AIAppEvents.requestSendWithChat.next({
input,
context,
host: this.host,
@@ -7,10 +7,7 @@ import type {
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
import type { AppThemeService } from '@affine/core/modules/theme';
import type {
ContextEmbedStatus,
CopilotChatHistoryFragment,
} from '@affine/graphql';
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit';
import { type NotificationService } from '@blocksuite/affine/shared/services';
import { unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
@@ -24,7 +21,13 @@ import { createRef, type Ref, ref } from 'lit/directives/ref.js';
import { throttle } from 'lodash-es';
import type { AppSidebarConfig } from '../../chat-panel/chat-config';
import { AIProvider } from '../../provider';
import { AIAppEvents, type AIError } from '../../provider';
import {
AIChatRuntime,
type AIChatSnapshot,
PlaygroundAIChatSessionStrategy,
} from '../../runtime/chat';
import { getAIRequestService } from '../../runtime/request';
import { HISTORY_IMAGE_ACTIONS } from '../../utils/history-image-actions';
import type { SearchMenuConfig } from '../ai-chat-add-context';
import type { DocDisplayConfig } from '../ai-chat-chips';
@@ -32,8 +35,6 @@ import type { ChatContextValue } from '../ai-chat-content';
import type { AIPlaygroundConfig, AIReasoningConfig } from '../ai-chat-input';
import {
type AIChatMessages,
type ChatAction,
type ChatMessage,
type HistoryMessage,
isChatMessage,
} from '../ai-chat-messages';
@@ -202,16 +203,20 @@ export class PlaygroundChat extends SignalWatcher(
accessor chatContextValue: ChatContextValue = DEFAULT_CHAT_CONTEXT_VALUE;
@state()
accessor embeddingProgress: [number, number] = [0, 0];
accessor runtimeSnapshot: AIChatSnapshot | null = null;
private readonly _chatMessagesRef: Ref<AIChatMessages> =
createRef<AIChatMessages>();
// request counter to track the latest request
private _updateHistoryCounter = 0;
private runtime: AIChatRuntime | null = null;
private disposeRuntime: (() => void) | null = null;
get messages() {
return this.chatContextValue.messages.filter(item => {
const messages =
(this.runtimeSnapshot?.messages as HistoryMessage[] | undefined) ??
this.chatContextValue.messages;
return messages.filter(item => {
return (
isChatMessage(item) ||
item.messages?.length === 3 ||
@@ -226,66 +231,46 @@ export class PlaygroundChat extends SignalWatcher(
}
private readonly _initPanel = async () => {
const userId = (await AIProvider.userInfo)?.id;
const userId = AIAppEvents.userInfo.value?.id;
if (!userId) return;
this.isLoading = true;
await this._updateHistory();
this.isLoading = false;
this.ensureRuntime();
};
private readonly _createSession = async () => {
return this.session;
};
private readonly _updateHistory = async () => {
if (!AIProvider.histories) {
return;
}
const currentRequest = ++this._updateHistoryCounter;
const sessionId = this.session?.sessionId;
const [histories, actions] = await Promise.all([
sessionId
? AIProvider.histories.chats(
this.doc.workspace.id,
sessionId,
this.doc.id
)
: Promise.resolve([]),
this.doc.id && this.showActions
? AIProvider.histories.actions(this.doc.workspace.id, this.doc.id)
: Promise.resolve([]),
]);
// Check if this is still the latest request
if (currentRequest !== this._updateHistoryCounter) {
return;
}
const chatActions = (actions || []) as ChatAction[];
const messages: HistoryMessage[] = chatActions;
const chatMessages = (histories?.[0]?.messages || []) as ChatMessage[];
messages.push(...chatMessages);
private readonly syncContextFromRuntime = () => {
const snapshot = this.runtimeSnapshot;
if (!snapshot) return;
this.chatContextValue = {
...this.chatContextValue,
messages: messages.sort(
(a, b) =>
new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime()
),
messages: snapshot.messages as HistoryMessage[],
status: snapshot.status,
error: snapshot.error as AIError | null,
};
this._scrollToEnd();
};
private readonly onEmbeddingProgressChange = (
count: Record<ContextEmbedStatus, number>
) => {
const total = count.finished + count.processing + count.failed;
this.embeddingProgress = [count.finished, total];
private readonly ensureRuntime = () => {
if (!this.session || this.runtime) return;
this.runtime = new AIChatRuntime({
request: getAIRequestService(),
scope: {
kind: 'fork',
workspaceId: this.doc.workspace.id,
docId: this.doc.id,
parentSessionId: this.session.parentSessionId ?? this.session.sessionId,
},
strategy: new PlaygroundAIChatSessionStrategy(),
});
this.disposeRuntime = this.runtime.subscribe(() => {
this.runtimeSnapshot = this.runtime?.getSnapshot() ?? null;
this.syncContextFromRuntime();
});
this.runtimeSnapshot = this.runtime.getSnapshot();
this.runtime
.dispatch({
type: 'openSessionObject',
session: this.session,
})
.catch(console.error);
};
private readonly updateContext = (context: Partial<ChatContextValue>) => {
@@ -303,7 +288,23 @@ export class PlaygroundChat extends SignalWatcher(
this._initPanel().catch(console.error);
}
override disconnectedCallback() {
super.disconnectedCallback();
this.disposeRuntime?.();
this.runtime?.dispose();
this.runtime = null;
this.disposeRuntime = null;
}
protected override updated(_changedProperties: PropertyValues) {
if (_changedProperties.has('session')) {
this.disposeRuntime?.();
this.runtime?.dispose();
this.runtime = null;
this.disposeRuntime = null;
this.ensureRuntime();
}
if (
_changedProperties.has('chatContextValue') &&
(this.chatContextValue.status === 'loading' ||
@@ -322,7 +323,11 @@ export class PlaygroundChat extends SignalWatcher(
}
override render() {
const [done, total] = this.embeddingProgress;
const embeddingCount =
this.runtimeSnapshot?.composer.context.embeddingCount;
const done = embeddingCount?.finished ?? 0;
const total =
done + (embeddingCount?.processing ?? 0) + (embeddingCount?.failed ?? 0);
const isEmbedding = total > 0 && done < total;
return html`<div class="chat-panel-container">
@@ -342,7 +347,23 @@ export class PlaygroundChat extends SignalWatcher(
.doc=${this.doc}
.session=${this.session}
.notificationService=${this.notificationService}
.onHistoryCleared=${this._updateHistory}
.onClearHistory=${async (sessionIds: string[]) => {
for (const sessionId of sessionIds) {
await this.runtime?.dispatch({
type: 'deleteSession',
sessionId,
});
}
}}
.onHistoryCleared=${() =>
this.session
? this.runtime
?.dispatch({
type: 'openSessionObject',
session: this.session,
})
.catch(console.error)
: undefined}
.chatContextValue=${this.chatContextValue}
></ai-history-clear>
<div class="chat-panel-delete">${DeleteIcon()}</div>
@@ -355,7 +376,8 @@ export class PlaygroundChat extends SignalWatcher(
.isHistoryLoading=${this.isLoading}
.chatContextValue=${this.chatContextValue}
.session=${this.session}
.createSession=${this._createSession}
.runtime=${this.runtime}
.runtimeSnapshot=${this.runtimeSnapshot}
.updateContext=${this.updateContext}
.extensions=${this.extensions}
.affineFeatureFlagService=${this.affineFeatureFlagService}
@@ -370,10 +392,10 @@ export class PlaygroundChat extends SignalWatcher(
.workspaceId=${this.doc.workspace.id}
.docId=${this.doc.id}
.session=${this.session}
.createSession=${this._createSession}
.runtime=${this.runtime}
.runtimeSnapshot=${this.runtimeSnapshot}
.chatContextValue=${this.chatContextValue}
.updateContext=${this.updateContext}
.onEmbeddingProgressChange=${this.onEmbeddingProgressChange}
.reasoningConfig=${this.reasoningConfig}
.playgroundConfig=${this.playgroundConfig}
.docDisplayConfig=${this.docDisplayConfig}
@@ -18,7 +18,12 @@ import { property, state } from 'lit/decorators.js';
import { repeat } from 'lit/directives/repeat.js';
import type { AppSidebarConfig } from '../../chat-panel/chat-config';
import { AIProvider } from '../../provider';
import {
AIChatRuntime,
type AIChatScope,
PlaygroundAIChatSessionStrategy,
} from '../../runtime/chat';
import { getAIRequestService } from '../../runtime/request';
import type { SearchMenuConfig } from '../ai-chat-add-context';
import type { DocDisplayConfig } from '../ai-chat-chips';
import type { AIPlaygroundConfig, AIReasoningConfig } from '../ai-chat-input';
@@ -119,27 +124,39 @@ export class PlaygroundContent extends SignalWatcher(
private isSending = false;
private createSessionRuntime(scope: AIChatScope) {
return new AIChatRuntime({
request: getAIRequestService(),
scope,
strategy: new PlaygroundAIChatSessionStrategy(),
});
}
private readonly getSessions = async () => {
const sessions =
(await AIProvider.session?.getSessions(
(await getAIRequestService().getSessions(
this.doc.workspace.id,
this.doc.id,
{ action: false }
)) || [];
const rootSession = sessions?.findLast(session => !session.parentSessionId);
if (!rootSession) {
// Create a new session
const rootSessionId = await AIProvider.session?.createSession({
docId: this.doc.id,
const runtime = this.createSessionRuntime({
kind: 'playground',
workspaceId: this.doc.workspace.id,
promptName: 'Chat With AFFiNE AI',
docId: this.doc.id,
});
if (rootSessionId) {
this.rootSessionId = rootSessionId;
const forkSession = await this.forkSession(rootSessionId);
if (forkSession) {
this.sessions = [forkSession];
try {
const rootSession = await runtime.createSession();
if (rootSession) {
this.rootSessionId = rootSession.sessionId;
const forkSession = await this.forkSession(rootSession.sessionId);
if (forkSession) {
this.sessions = [forkSession];
}
}
} finally {
runtime.dispose();
}
} else {
this.rootSessionId = rootSession.sessionId;
@@ -158,19 +175,17 @@ export class PlaygroundContent extends SignalWatcher(
};
private readonly forkSession = async (parentSessionId: string) => {
const forkSessionId = await AIProvider.forkChat?.({
const runtime = this.createSessionRuntime({
kind: 'playground',
workspaceId: this.doc.workspace.id,
docId: this.doc.id,
sessionId: parentSessionId,
latestMessageId: '',
parentSessionId,
});
if (!forkSessionId) {
return;
try {
return (await runtime.createSession()) ?? undefined;
} finally {
runtime.dispose();
}
return await AIProvider.session?.getSession(
this.doc.workspace.id,
forkSessionId
);
};
private readonly addChat = async () => {
@@ -49,7 +49,7 @@ import {
translateLangs,
} from '../../actions/types';
import type { AIItemGroupConfig } from '../../components/ai-item/types';
import { AIProvider } from '../../provider';
import { AIAppEvents } from '../../provider';
import { getAIPanelWidget } from '../../utils/ai-widgets';
import {
getEdgelessCopilotWidget,
@@ -121,7 +121,7 @@ const othersGroup: AIItemGroupConfig = {
const edgelessCopilot = getEdgelessCopilotWidget(host);
extractSelectedContent(host)
.then(context => {
AIProvider.slots.requestOpenWithChat.next({
AIAppEvents.requestOpenWithChat.next({
host,
mode: 'edgeless',
autoSelect: true,
@@ -7,7 +7,7 @@ import {
import { html } from 'lit';
import type { AIItemGroupConfig } from '../../components/ai-item/types';
import { AIProvider } from '../../provider';
import { AIAppEvents } from '../../provider';
import { getAIPanelWidget } from '../../utils/ai-widgets';
import { getEdgelessCopilotWidget } from '../../utils/edgeless';
import { extractSelectedContent } from '../../utils/extract';
@@ -58,14 +58,14 @@ export function edgelessToolbarAIEntryConfig(): ToolbarModuleConfig {
extractSelectedContent(host)
.then(context => {
if (context?.attachments?.length || context?.docs?.length) {
AIProvider.slots.requestOpenWithChat.next({
AIAppEvents.requestOpenWithChat.next({
input,
host,
context,
autoSelect: true,
});
} else {
AIProvider.slots.requestSendWithChat.next({
AIAppEvents.requestSendWithChat.next({
input,
context,
host,
@@ -2,7 +2,7 @@ import type { RichText } from '@blocksuite/affine/rich-text';
import { type EditorHost, TextSelection } from '@blocksuite/affine/std';
import { handleInlineAskAIAction } from '../../actions/doc-handler';
import { AIProvider } from '../../provider';
import { hasAIRequestService } from '../../runtime/request';
import type { AffineAIPanelWidget } from '../../widgets/ai-panel/ai-panel';
function isSpaceEvent(event: KeyboardEvent) {
@@ -43,7 +43,7 @@ export function setupSpaceAIEntry(panel: AffineAIPanelWidget) {
const host = panel.host;
const keyboardState = ctx.get('keyboardState');
const event = keyboardState.raw;
if (AIProvider.actions.chat && isSpaceEvent(event)) {
if (hasAIRequestService() && isSpaceEvent(event)) {
// If the AI panel is in the input state and the input content is empty,
// insert a space back into the editor.
if (panel.state === 'input') {
@@ -5,4 +5,6 @@ export * from './entries/edgeless/actions-config';
export * from './messages';
export { AIChatBlockPeekViewTemplate } from './peek-view/chat-block-peek-view';
export * from './provider';
export * from './runtime/chat';
export * from './runtime/request';
export * from './utils/edgeless';
@@ -9,8 +9,8 @@ import { css, html, LitElement, nothing, unsafeCSS } from 'lit';
import { property } from 'lit/decorators.js';
import {
AIAppEvents,
type AIError,
AIProvider,
PaymentRequiredError,
UnauthorizedError,
} from '../provider';
@@ -190,7 +190,7 @@ const PaymentRequiredErrorRenderer = (host?: EditorHost | null) => html`
<ai-error-wrapper
.text=${"You've reached the current usage cap for AFFiNE AI. You can subscribe to AFFiNE AI(with free 7-day-trial) to continue the AI experience!"}
.actionText=${'Upgrade'}
.onClick=${() => AIProvider.slots.requestUpgradePlan.next({ host })}
.onClick=${() => AIAppEvents.requestUpgradePlan.next({ host })}
></ai-error-wrapper>
`;
@@ -198,7 +198,7 @@ const LoginRequiredErrorRenderer = (host?: EditorHost | null) => html`
<ai-error-wrapper
.text=${'You need to login to AFFiNE Cloud to continue using AFFiNE AI.'}
.actionText=${'Login'}
.onClick=${() => AIProvider.slots.requestLogin.next({ host })}
.onClick=${() => AIAppEvents.requestLogin.next({ host })}
></ai-error-wrapper>
`;
@@ -9,10 +9,7 @@ import type {
} from '@affine/core/modules/cloud';
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
import type {
ContextEmbedStatus,
CopilotChatHistoryFragment,
} from '@affine/graphql';
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import {
CanvasElementType,
EdgelessCRUDIdentifier,
@@ -42,18 +39,17 @@ import type { SearchMenuConfig } from '../components/ai-chat-add-context';
import type { DocDisplayConfig } from '../components/ai-chat-chips';
import type { AIReasoningConfig } from '../components/ai-chat-input';
import type { ChatMessage } from '../components/ai-chat-messages';
import {
ChatMessagesSchema,
isChatMessage,
StreamObjectSchema,
} from '../components/ai-chat-messages';
import { ChatMessagesSchema } from '../components/ai-chat-messages';
import type { TextRendererOptions } from '../components/text-renderer';
import { AIChatErrorRenderer } from '../messages/error';
import { type AIError, AIProvider } from '../provider';
import { AIAppEvents, type AIError } from '../provider';
import {
mergeStreamContent,
mergeStreamObjects,
} from '../utils/stream-objects';
AIChatRuntime,
type AIChatSnapshot,
ChatBlockAIChatSessionStrategy,
} from '../runtime/chat';
import { getAIRequestService } from '../runtime/request';
import { mergeStreamContent } from '../utils/stream-objects';
import { PeekViewStyles } from './styles';
import type { ChatContext } from './types';
import { calcChildBound } from './utils';
@@ -85,14 +81,14 @@ export class AIChatBlockPeekView extends LitElement {
return this.blockModel.props.rootWorkspaceId;
}
private get _isReasoningActive() {
return !!this.reasoningConfig.enabled.value;
}
private _textRendererOptions: TextRendererOptions = {};
private _forkBlockId: string | undefined = undefined;
private runtime: AIChatRuntime | null = null;
private disposeRuntime: (() => void) | null = null;
private readonly _deserializeHistoryChatMessages = (
historyMessagesString: string
) => {
@@ -115,7 +111,7 @@ export class AIChatBlockPeekView extends LitElement {
forkSessionId: string,
docId?: string
) => {
const currentUserInfo = await AIProvider.userInfo;
const currentUserInfo = AIAppEvents.userInfo.value;
const forkMessages = (await queryHistoryMessages(
rootWorkspaceId,
forkSessionId,
@@ -160,36 +156,28 @@ export class AIChatBlockPeekView extends LitElement {
this._forkBlockId = undefined;
};
private readonly initSession = async () => {
const session = await AIProvider.session?.getSession(
this.rootWorkspaceId,
this._sessionId
);
this.session = session ?? null;
};
private readonly createForkSession = async () => {
if (this.forkSession) {
return this.forkSession;
}
const lastMessage = this._historyMessages.at(-1);
if (!lastMessage) return;
const { store } = this.host;
const forkSessionId = await AIProvider.forkChat?.({
workspaceId: store.workspace.id,
docId: store.id,
sessionId: this._sessionId,
latestMessageId: lastMessage.id,
private createRuntime() {
return new AIChatRuntime({
request: getAIRequestService(),
scope: {
kind: 'chat-block',
workspaceId: this.rootWorkspaceId,
docId: this.rootDocId,
blockId: this.blockId,
parentSessionId: this._sessionId,
latestMessageId: this._historyMessages.at(-1)?.id,
},
strategy: new ChatBlockAIChatSessionStrategy(),
});
if (forkSessionId) {
const session = await AIProvider.session?.getSession(
this.rootWorkspaceId,
forkSessionId
);
this.forkSession = session ?? null;
}
private readonly initSession = async () => {
const runtime = this.createRuntime();
try {
this.session = (await runtime.loadInitialSession()) ?? null;
} finally {
runtime.dispose();
}
return this.forkSession;
};
private readonly _onChatSuccess = async () => {
@@ -312,11 +300,37 @@ export class AIChatBlockPeekView extends LitElement {
this.chatContext = { ...this.chatContext, ...context };
};
private readonly onEmbeddingProgressChange = (
count: Record<ContextEmbedStatus, number>
) => {
const total = count.finished + count.processing + count.failed;
this.embeddingProgress = [count.finished, total];
private readonly syncContextFromRuntime = () => {
const snapshot = this.runtimeSnapshot;
if (!snapshot) return;
const activeSession = snapshot.sessions.find(
session => session.sessionId === snapshot.activeSessionId
);
this.forkSession = activeSession ?? this.forkSession;
this.chatContext = {
...this.chatContext,
messages: snapshot.messages as ChatMessage[],
status: snapshot.status,
error: snapshot.error as AIError | null,
};
};
private readonly ensureRuntime = () => {
if (this.runtime || !this.session || !this._historyMessages.length) return;
this.runtime = this.createRuntime();
this.disposeRuntime = this.runtime.subscribe(() => {
this.runtimeSnapshot = this.runtime?.getSnapshot() ?? null;
this.syncContextFromRuntime();
});
this.runtimeSnapshot = this.runtime.getSnapshot();
if (this.forkSession) {
this.runtime
.dispatch({
type: 'openSessionObject',
session: this.forkSession,
})
.catch(console.error);
}
};
/**
@@ -359,72 +373,16 @@ export class AIChatBlockPeekView extends LitElement {
* Retry the last chat message
*/
retry = async () => {
try {
const forkSessionId = this.forkSession?.sessionId;
if (!this._forkBlockId || !forkSessionId) return;
if (!AIProvider.actions.chat) return;
const abortController = new AbortController();
const messages = [...this.chatContext.messages];
const last = messages[messages.length - 1];
if ('content' in last) {
last.content = '';
last.streamObjects = [];
last.createdAt = new Date().toISOString();
}
this.updateContext({
messages,
status: 'loading',
error: null,
abortController,
if (this.runtime) {
const lastAssistantMessage = this.chatContext.messages.findLast(
message => message.role === 'assistant'
);
await this.runtime.dispatch({
type: 'retry',
messageId: lastAssistantMessage?.id ?? '',
});
const { store } = this.host;
const stream = await AIProvider.actions.chat({
sessionId: forkSessionId,
retry: true,
docId: store.id,
workspaceId: store.workspace.id,
host: this.host,
stream: true,
signal: abortController.signal,
where: 'ai-chat-block',
control: 'chat-send',
reasoning: this._isReasoningActive,
toolsConfig: this.aiToolsConfigService.config.value,
});
for await (const text of stream) {
const messages = this.chatContext.messages.slice(0);
const last = messages.at(-1);
if (last && isChatMessage(last)) {
try {
const parsed = StreamObjectSchema.parse(JSON.parse(text));
const streamObjects = mergeStreamObjects([
...(last.streamObjects ?? []),
parsed,
]);
messages[messages.length - 1] = {
...last,
streamObjects,
};
} catch {
messages[messages.length - 1] = {
...last,
content: last.content + text,
};
}
this.updateContext({ messages, status: 'transmitting' });
}
}
this.updateContext({ status: 'success' });
// Update new chat block messages if there are contents returned from AI
await this.updateChatBlockMessages();
} catch (error) {
this.updateContext({ status: 'error', error: error as AIError });
} finally {
this.updateContext({ abortController: null });
return;
}
};
@@ -526,17 +484,33 @@ export class AIChatBlockPeekView extends LitElement {
attachments: messages[idx]?.attachments ?? [],
};
});
this.ensureRuntime();
})
.catch((err: Error) => {
console.error('Query history messages failed', err);
});
}
override disconnectedCallback() {
super.disconnectedCallback();
this.disposeRuntime?.();
this.runtime?.dispose();
this.disposeRuntime = null;
this.runtime = null;
}
override firstUpdated() {
this._scrollToEnd();
}
protected override updated(changedProperties: PropertyValues) {
if (
changedProperties.has('session') ||
changedProperties.has('_historyMessages')
) {
this.ensureRuntime();
}
if (
changedProperties.has('chatContext') &&
(this.chatContext.status === 'loading' ||
@@ -572,6 +546,14 @@ export class AIChatBlockPeekView extends LitElement {
<ai-history-clear
.doc=${this.host.store}
.session=${this.forkSession}
.onClearHistory=${async (sessionIds: string[]) => {
for (const sessionId of sessionIds) {
await this.runtime?.dispatch({
type: 'deleteSession',
sessionId,
});
}
}}
.onHistoryCleared=${this._onHistoryCleared}
.chatContextValue=${chatContext}
.notificationService=${notificationService}
@@ -593,10 +575,10 @@ export class AIChatBlockPeekView extends LitElement {
.workspaceId=${this.rootWorkspaceId}
.docId=${this.rootDocId}
.session=${this.forkSession ?? this.session}
.createSession=${this.createForkSession}
.runtime=${this.runtime}
.runtimeSnapshot=${this.runtimeSnapshot}
.chatContextValue=${chatContext}
.updateContext=${updateContext}
.onEmbeddingProgressChange=${this.onEmbeddingProgressChange}
.docDisplayConfig=${this.docDisplayConfig}
.searchMenuConfig=${this.searchMenuConfig}
.affineWorkspaceDialogService=${this.affineWorkspaceDialogService}
@@ -673,7 +655,7 @@ export class AIChatBlockPeekView extends LitElement {
};
@state()
accessor embeddingProgress: [number, number] = [0, 0];
accessor runtimeSnapshot: AIChatSnapshot | null = null;
@state()
accessor session: CopilotChatHistoryFragment | null | undefined;
@@ -0,0 +1,19 @@
import type { EditorHost } from '@blocksuite/affine/std';
import { BehaviorSubject, Subject } from 'rxjs';
import type { AIChatParams, AISendParams, AIUserInfo } from './ai-provider';
export const AIAppEvents = {
/* eslint-disable rxjs/finnish */
requestOpenWithChat: new BehaviorSubject<AIChatParams | null>(null),
requestSendWithChat: new BehaviorSubject<AISendParams | null>(null),
requestInsertTemplate: new Subject<{
template: string;
mode: 'page' | 'edgeless';
}>(),
requestLogin: new Subject<{ host?: EditorHost | null }>(),
requestUpgradePlan: new Subject<{ host?: EditorHost | null }>(),
userInfo: new BehaviorSubject<AIUserInfo | null>(null),
previewPanelOpenChange: new Subject<boolean>(),
/* eslint-enable rxjs/finnish */
};
@@ -1,13 +1,6 @@
import type { EditorHost } from '@blocksuite/affine/std';
import { captureException } from '@sentry/react';
import { BehaviorSubject, Subject } from 'rxjs';
import type { ChatContextValue } from '../components/ai-chat-content';
import {
PaymentRequiredError,
RequestTimeoutError,
UnauthorizedError,
} from './error';
export interface AIUserInfo {
id: string;
@@ -64,14 +57,6 @@ export type ActionEventType =
* TODO: breakdown into different parts?
*/
export class AIProvider {
static get slots() {
return AIProvider.instance.slots;
}
static get actions() {
return AIProvider.instance.actions;
}
static get userInfo() {
return AIProvider.instance.userInfoFn();
}
@@ -80,275 +65,34 @@ export class AIProvider {
return AIProvider.instance.photoEngine;
}
static get histories() {
return AIProvider.instance.histories;
}
static get session() {
return AIProvider.instance.session;
}
static get context() {
return AIProvider.instance.context;
}
static get actionHistory() {
return AIProvider.instance.actionHistory;
}
static get toggleGeneralAIOnboarding() {
return AIProvider.instance.toggleGeneralAIOnboarding;
}
static get forkChat() {
return AIProvider.instance.forkChat;
}
static get embedding() {
return AIProvider.instance.embedding;
}
private static readonly instance = new AIProvider();
static LAST_ACTION_SESSIONID = '';
static MAX_LOCAL_HISTORY = 10;
private readonly actions: Partial<BlockSuitePresets.AIActions> = {};
private photoEngine: BlockSuitePresets.AIPhotoEngineService | null = null;
private histories: BlockSuitePresets.AIHistoryService | null = null;
private session: BlockSuitePresets.AISessionService | null = null;
private context: BlockSuitePresets.AIContextService | null = null;
private toggleGeneralAIOnboarding: ((value: boolean) => void) | null = null;
private forkChat:
| ((
options: BlockSuitePresets.AIForkChatSessionOptions
) => string | Promise<string>)
| null = null;
private readonly slots = {
// use case: when user selects "continue in chat" in an ask ai result panel
// do we need to pass the context to the chat panel?
/* eslint-disable rxjs/finnish */
requestOpenWithChat: new BehaviorSubject<AIChatParams | null>(null),
requestSendWithChat: new BehaviorSubject<AISendParams | null>(null),
requestInsertTemplate: new Subject<{
template: string;
mode: 'page' | 'edgeless';
}>(),
requestLogin: new Subject<{ host?: EditorHost | null }>(),
requestUpgradePlan: new Subject<{ host?: EditorHost | null }>(),
// stream of AI actions triggered by users
actions: new Subject<{
action: keyof BlockSuitePresets.AIActions;
options: BlockSuitePresets.AITextActionOptions;
event: ActionEventType;
}>(),
// downstream can emit this slot to notify ai presets that user info has been updated
userInfo: new Subject<AIUserInfo | null>(),
sessionReady: new BehaviorSubject<boolean>(false),
previewPanelOpenChange: new Subject<boolean>(),
/* eslint-enable rxjs/finnish */
};
// track the history of triggered actions (in memory only)
private readonly actionHistory: {
action: keyof BlockSuitePresets.AIActions;
options: BlockSuitePresets.AITextActionOptions;
}[] = [];
private userInfoFn: () => AIUserInfo | Promise<AIUserInfo> | null = () =>
null;
private embedding: BlockSuitePresets.AIEmbeddingService | null = null;
private provideAction<T extends keyof BlockSuitePresets.AIActions>(
id: T,
action: (
...options: Parameters<BlockSuitePresets.AIActions[T]>
) => Promise<ReturnType<BlockSuitePresets.AIActions[T]>>
): void {
// @ts-expect-error TODO: maybe fix this
this.actions[id] = async (
...args: Parameters<BlockSuitePresets.AIActions[T]>
) => {
const options = args[0];
const slots = this.slots;
slots.actions.next({
action: id,
options,
event: 'started',
});
this.actionHistory.push({ action: id, options });
if (this.actionHistory.length > AIProvider.MAX_LOCAL_HISTORY) {
this.actionHistory.shift();
}
// wrap the action with slot actions
const result: BlockSuitePresets.TextStream | Promise<string> =
await action(...args);
const isTextStream = (
m: BlockSuitePresets.TextStream | Promise<string>
): m is BlockSuitePresets.TextStream =>
Reflect.has(m, Symbol.asyncIterator);
if (isTextStream(result)) {
return {
[Symbol.asyncIterator]: async function* () {
let user = null;
try {
user = await AIProvider.userInfo;
yield* result;
slots.actions.next({
action: id,
options,
event: 'finished',
});
} catch (err) {
slots.actions.next({
action: id,
options,
event: 'error',
});
if (err instanceof RequestTimeoutError) {
slots.actions.next({
action: id,
options,
event: 'aborted:timeout',
});
} else if (err instanceof PaymentRequiredError) {
slots.actions.next({
action: id,
options,
event: 'aborted:paywall',
});
} else if (err instanceof UnauthorizedError) {
slots.actions.next({
action: id,
options,
event: 'aborted:login-required',
});
} else {
slots.actions.next({
action: id,
options,
event: 'aborted:server-error',
});
captureException(err, {
user: { id: user?.id },
extra: {
action: id,
session: AIProvider.LAST_ACTION_SESSIONID,
},
});
}
throw err;
}
},
};
} else {
let user: any = null;
return result
.then(async result => {
user = await AIProvider.userInfo;
slots.actions.next({
action: id,
options,
event: 'finished',
});
return result;
})
.catch(err => {
slots.actions.next({
action: id,
options,
event: 'error',
});
if (err instanceof PaymentRequiredError) {
slots.actions.next({
action: id,
options,
event: 'aborted:paywall',
});
} else {
captureException(err, {
user: { id: user?.id },
extra: {
action: id,
session: AIProvider.LAST_ACTION_SESSIONID,
},
});
}
throw err;
});
}
};
}
static provide(
id: 'userInfo',
fn: () => AIUserInfo | Promise<AIUserInfo> | null
): void;
static provide(
id: 'session',
service: BlockSuitePresets.AISessionService
): void;
static provide(
id: 'context',
service: BlockSuitePresets.AIContextService
): void;
static provide(
id: 'histories',
service: BlockSuitePresets.AIHistoryService
): void;
static provide(
id: 'photoEngine',
engine: BlockSuitePresets.AIPhotoEngineService
): void;
static provide(
id: 'forkChat',
fn: (
options: BlockSuitePresets.AIForkChatSessionOptions
) => string | Promise<string>
): void;
static provide(id: 'onboarding', fn: (value: boolean) => void): void;
static provide(
id: 'embedding',
service: BlockSuitePresets.AIEmbeddingService
): void;
// actions:
static provide<T extends keyof BlockSuitePresets.AIActions>(
id: T,
action: (
...options: Parameters<BlockSuitePresets.AIActions[T]>
) => Promise<ReturnType<BlockSuitePresets.AIActions[T]>>
): void;
static provide(id: unknown, action: unknown) {
if (id === 'userInfo') {
AIProvider.instance.userInfoFn = action as () => AIUserInfo;
} else if (id === 'histories') {
AIProvider.instance.histories =
action as BlockSuitePresets.AIHistoryService;
} else if (id === 'session') {
AIProvider.instance.session =
action as BlockSuitePresets.AISessionService;
AIProvider.instance.slots.sessionReady.next(true);
} else if (id === 'context') {
AIProvider.instance.context =
action as BlockSuitePresets.AIContextService;
} else if (id === 'photoEngine') {
AIProvider.instance.photoEngine =
action as BlockSuitePresets.AIPhotoEngineService;
@@ -356,15 +100,6 @@ export class AIProvider {
AIProvider.instance.toggleGeneralAIOnboarding = action as (
value: boolean
) => void;
} else if (id === 'forkChat') {
AIProvider.instance.forkChat = action as (
options: BlockSuitePresets.AIForkChatSessionOptions
) => string | Promise<string>;
} else if (id === 'embedding') {
AIProvider.instance.embedding =
action as BlockSuitePresets.AIEmbeddingService;
} else {
AIProvider.instance.provideAction(id as any, action as any);
}
}
}
@@ -1,4 +1,4 @@
import { handleError } from './copilot-client';
import { handleError } from '../runtime/request/copilot-client';
import { RequestTimeoutError } from './error';
export function delay(ms: number) {
@@ -1,4 +1,4 @@
export * from './ai-app-events';
export * from './ai-provider';
export * from './copilot-client';
export * from './error';
export * from './setup-provider';
@@ -1,284 +0,0 @@
/**
* @vitest-environment happy-dom
*/
import { UserFriendlyError } from '@affine/error';
import { beforeEach, describe, expect, test, vi } from 'vitest';
import { type CopilotClient, Endpoint } from './copilot-client';
import { textToText, toImage } from './request';
const electronApis = vi.hoisted(() => ({
byokStorage: undefined as
| {
isSupported: () => Promise<boolean>;
getWorkspaceLeaseProviders: (workspaceId: string) => Promise<
Array<{
provider: string;
name: string;
apiKey: string;
description?: string | null;
endpoint?: string | null;
sortOrder?: number | null;
enabled?: boolean | null;
}>
>;
}
| undefined,
}));
const createWorkspaceByokLocalLeaseMutation = vi.hoisted(() =>
Symbol('createWorkspaceByokLocalLeaseMutation')
);
vi.mock('@affine/electron-api', () => ({
apis: electronApis,
}));
vi.mock('@affine/graphql', () => ({
ByokProvider: {
openai: 'openai',
anthropic: 'anthropic',
gemini: 'gemini',
fal: 'fal',
},
createWorkspaceByokLocalLeaseMutation,
}));
function createClient(
overrides: Partial<
Pick<
CopilotClient,
'gql' | 'createMessage' | 'chatTextStream' | 'imagesStream'
>
> = {}
) {
return {
gql: vi.fn().mockResolvedValue({
createWorkspaceByokLocalLease: { leaseId: 'lease-1' },
}),
createMessage: vi.fn().mockResolvedValue('message-1'),
chatTextStream: vi.fn(),
imagesStream: vi.fn(),
...overrides,
} as unknown as CopilotClient;
}
async function drain(stream: AsyncIterable<unknown>) {
for await (const chunk of stream) {
void chunk;
}
}
describe('AI request BYOK local lease handling', () => {
beforeEach(() => {
vi.stubGlobal('BUILD_CONFIG', { isElectron: true });
electronApis.byokStorage = {
isSupported: vi.fn().mockResolvedValue(true),
getWorkspaceLeaseProviders: vi.fn().mockResolvedValue([
{
provider: 'openai',
name: 'OpenAI',
apiKey: 'sk-local',
},
]),
};
});
test('fails closed when local BYOK providers exist but lease creation fails', async () => {
const client = createClient({
gql: vi.fn().mockRejectedValue(new Error('mutation failed')),
});
const result = textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
}) as Promise<string>;
await expect(result).rejects.toThrow('mutation failed');
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('wraps local BYOK storage support failures as user friendly errors', async () => {
electronApis.byokStorage = {
isSupported: vi.fn().mockRejectedValue(new Error('support check failed')),
getWorkspaceLeaseProviders: vi.fn(),
};
const client = createClient();
const result = textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
}) as Promise<string>;
await expect(result).rejects.toThrow('support check failed');
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('wraps local BYOK provider loading failures as user friendly errors', async () => {
electronApis.byokStorage = {
isSupported: vi.fn().mockResolvedValue(true),
getWorkspaceLeaseProviders: vi
.fn()
.mockRejectedValue(new Error('provider load failed')),
};
const client = createClient();
const result = textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
}) as Promise<string>;
await expect(result).rejects.toThrow('provider load failed');
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create local BYOK lease after cancellation', async () => {
const controller = new AbortController();
const client = createClient({
createMessage: vi.fn().mockImplementation(async () => {
controller.abort();
return 'message-1';
}),
});
await expect(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
signal: controller.signal,
}) as Promise<string>
).resolves.toBe('');
expect(client.gql).not.toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create stream local BYOK lease after cancellation', async () => {
const controller = new AbortController();
const client = createClient({
createMessage: vi.fn().mockImplementation(async () => {
controller.abort();
return 'message-1';
}),
});
await drain(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
stream: true,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).not.toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create text stream when cancelled while creating local BYOK lease', async () => {
const controller = new AbortController();
const client = createClient({
gql: vi.fn().mockImplementation(async () => {
controller.abort();
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
}),
});
await drain(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
stream: true,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create text request when cancelled while creating local BYOK lease', async () => {
const controller = new AbortController();
const client = createClient({
gql: vi.fn().mockImplementation(async () => {
controller.abort();
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
}),
});
await expect(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
signal: controller.signal,
}) as Promise<string>
).resolves.toBe('');
expect(client.gql).toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create image local BYOK lease after cancellation', async () => {
const controller = new AbortController();
const client = createClient({
createMessage: vi.fn().mockImplementation(async () => {
controller.abort();
return 'message-1';
}),
});
await drain(
toImage({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'image',
endpoint: Endpoint.Images,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).not.toHaveBeenCalled();
expect(client.imagesStream).not.toHaveBeenCalled();
});
test('does not create image stream when cancelled while creating local BYOK lease', async () => {
const controller = new AbortController();
const client = createClient({
gql: vi.fn().mockImplementation(async () => {
controller.abort();
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
}),
});
await drain(
toImage({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'image',
endpoint: Endpoint.Images,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).toHaveBeenCalled();
expect(client.imagesStream).not.toHaveBeenCalled();
});
});
@@ -1,120 +0,0 @@
/**
* @vitest-environment happy-dom
*/
import { BehaviorSubject } from 'rxjs';
import { describe, expect, test, vi } from 'vitest';
import { AIProvider } from './ai-provider';
import { CopilotClient, Endpoint } from './copilot-client';
import { setupAIProvider } from './setup-provider';
Object.defineProperty(globalThis, 'EventSource', {
configurable: true,
value: {
CLOSED: 2,
},
});
type SetupAIProviderArgs = Parameters<typeof setupAIProvider>;
type ActionInput<T extends keyof BlockSuitePresets.AIActions> = Parameters<
NonNullable<BlockSuitePresets.AIActions[T]>
>[0];
async function drain(stream: AsyncIterable<unknown>) {
for await (const chunk of stream) {
void chunk;
}
}
async function drainActionResult(
stream: string | AsyncIterable<unknown> | undefined
) {
expect(stream).toBeDefined();
expect(typeof stream).not.toBe('string');
await drain(stream as AsyncIterable<unknown>);
}
function createClosedEventSource(): EventSource {
return {
readyState: EventSource.CLOSED,
addEventListener: vi.fn(),
close: vi.fn(),
} as unknown as EventSource;
}
describe('setupAIProvider action migrations', () => {
test('routes mindmap, slides and image filter through action API', async () => {
const createdSessions: unknown[] = [];
const textStreams: unknown[] = [];
const client = new CopilotClient(
vi.fn(),
vi.fn(() => createClosedEventSource())
);
vi.spyOn(client, 'createSession').mockImplementation(async options => {
createdSessions.push(options);
return `session:${options.promptName}`;
});
vi.spyOn(client, 'createMessage').mockResolvedValue('message-1');
vi.spyOn(client, 'chatTextStream').mockImplementation(
(options, endpoint) => {
textStreams.push({ options, endpoint });
return createClosedEventSource();
}
);
vi.spyOn(client, 'imagesStream').mockReturnValue(createClosedEventSource());
setupAIProvider(
client,
{ open: vi.fn() } as unknown as SetupAIProviderArgs[1],
{
session: {
account$: new BehaviorSubject(null),
},
} as unknown as SetupAIProviderArgs[2]
);
await drainActionResult(
await AIProvider.actions.brainstormMindmap?.({
workspaceId: 'workspace-1',
input: 'make a map',
stream: true,
} satisfies ActionInput<'brainstormMindmap'>)
);
await drainActionResult(
await AIProvider.actions.createSlides?.({
workspaceId: 'workspace-1',
input: 'make slides',
stream: true,
} satisfies ActionInput<'createSlides'>)
);
await drainActionResult(
await AIProvider.actions.filterImage?.({
workspaceId: 'workspace-1',
input: 'convert',
attachments: ['blob-1'],
style: 'Sketch style',
} satisfies ActionInput<'filterImage'>)
);
expect(createdSessions).toEqual([
expect.objectContaining({ promptName: 'mindmap.generate' }),
expect.objectContaining({ promptName: 'slides.outline' }),
expect.objectContaining({ promptName: 'image.filter.sketch' }),
]);
expect(textStreams).toEqual([
expect.objectContaining({
endpoint: Endpoint.Action,
options: expect.objectContaining({ actionId: 'mindmap.generate' }),
}),
expect.objectContaining({
endpoint: Endpoint.Action,
options: expect.objectContaining({ actionId: 'slides.outline' }),
}),
expect.objectContaining({
endpoint: Endpoint.Action,
options: expect.objectContaining({ actionId: 'image.filter.sketch' }),
}),
]);
expect(client.imagesStream).not.toHaveBeenCalled();
});
});
@@ -1,20 +1,11 @@
import { toggleGeneralAIOnboarding } from '@affine/core/components/affine/ai-onboarding/apis';
import type { AuthAccountInfo, AuthService } from '@affine/core/modules/cloud';
import type { GlobalDialogService } from '@affine/core/modules/dialogs';
import {
type AddContextFileInput,
ContextCategories,
type ContextWorkspaceEmbeddingStatus,
type getCopilotHistoriesQuery,
type QueryChatSessionsInput,
type RequestOptions,
type UpdateChatSessionInput,
} from '@affine/graphql';
import type { AIRequestService } from '../runtime/request';
import { setAIRequestService } from '../runtime/request';
import { AIAppEvents } from './ai-app-events';
import { AIProvider } from './ai-provider';
import { type CopilotClient, Endpoint } from './copilot-client';
import type { PromptKey } from './prompt';
import { textToText, toImage } from './request';
import { setupTracker } from './tracker';
function toAIUserInfo(account: AuthAccountInfo | null) {
@@ -27,48 +18,12 @@ function toAIUserInfo(account: AuthAccountInfo | null) {
};
}
const filterStyleToPromptName = new Map<string, PromptKey>(
Object.entries({
'Clay style': 'image.filter.clay',
'Pixel style': 'image.filter.pixel',
'Sketch style': 'image.filter.sketch',
'Anime style': 'image.filter.anime',
})
);
const processTypeToPromptName = new Map<string, PromptKey>(
Object.entries({
Clearer: 'Upscale image',
'Remove background': 'Remove background',
'Convert to sticker': 'Convert to sticker',
})
);
export function setupAIProvider(
client: CopilotClient,
requestService: AIRequestService,
globalDialogService: GlobalDialogService,
authService: AuthService
) {
async function createSession({
promptName,
workspaceId,
docId,
sessionId,
retry,
pinned,
reuseLatestChat,
}: BlockSuitePresets.AICreateSessionOptions) {
if (sessionId) return sessionId;
if (retry) return AIProvider.LAST_ACTION_SESSIONID;
return client.createSession({
workspaceId,
docId,
promptName,
pinned,
reuseLatestChat,
});
}
setAIRequestService(requestService);
AIProvider.provide('userInfo', () => {
return toAIUserInfo(authService.session.account$.value);
@@ -76,715 +31,10 @@ export function setupAIProvider(
const accountSubscription = authService.session.account$.subscribe(
account => {
AIProvider.slots.userInfo.next(toAIUserInfo(account));
AIAppEvents.userInfo.next(toAIUserInfo(account));
}
);
//#region actions
AIProvider.provide('chat', async options => {
const { input, contexts } = options;
const sessionId = await createSession({
promptName: 'Chat With AFFiNE AI',
...options,
});
return textToText({
...options,
modelId: options.modelId,
client,
sessionId,
content: input,
timeout: 5 * 60 * 1000, // 5 minutes
params: {
docs: contexts?.docs,
files: contexts?.files,
selectedSnapshot: contexts?.selectedSnapshot,
selectedMarkdown: contexts?.selectedMarkdown,
html: contexts?.html,
...(options.docId ? { currentDocId: options.docId } : {}),
},
endpoint: Endpoint.StreamObject,
});
});
AIProvider.provide('summary', async options => {
const sessionId = await createSession({
promptName: 'Summary',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('translate', async options => {
const sessionId = await createSession({
promptName: 'Translate to',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
params: {
language: options.lang,
},
});
});
AIProvider.provide('changeTone', async options => {
const sessionId = await createSession({
promptName: 'Change tone to',
...options,
});
return textToText({
...options,
client,
sessionId,
params: {
tone: options.tone.toLowerCase(),
},
content: options.input,
});
});
AIProvider.provide('improveWriting', async options => {
const sessionId = await createSession({
promptName: 'Improve writing for it',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('improveGrammar', async options => {
const sessionId = await createSession({
promptName: 'Improve grammar for it',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('fixSpelling', async options => {
const sessionId = await createSession({
promptName: 'Fix spelling for it',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('createHeadings', async options => {
const sessionId = await createSession({
promptName: 'Create headings',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('makeLonger', async options => {
const sessionId = await createSession({
promptName: 'Make it longer',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('makeShorter', async options => {
const sessionId = await createSession({
promptName: 'Make it shorter',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('checkCodeErrors', async options => {
const sessionId = await createSession({
promptName: 'Check code error',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('explainCode', async options => {
const sessionId = await createSession({
promptName: 'Explain this code',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('writeArticle', async options => {
const sessionId = await createSession({
promptName: 'Write an article about this',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('writeTwitterPost', async options => {
const sessionId = await createSession({
promptName: 'Write a twitter about this',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('writePoem', async options => {
const sessionId = await createSession({
promptName: 'Write a poem about this',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('writeOutline', async options => {
const sessionId = await createSession({
promptName: 'Write outline',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('writeBlogPost', async options => {
const sessionId = await createSession({
promptName: 'Write a blog post about this',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('brainstorm', async options => {
const sessionId = await createSession({
promptName: 'Brainstorm ideas about this',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('findActions', async options => {
const sessionId = await createSession({
promptName: 'Find action items from it',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('brainstormMindmap', async options => {
const sessionId = await createSession({
promptName: 'mindmap.generate',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
// 3 minutes
timeout: 180000,
endpoint: Endpoint.Action,
actionId: 'mindmap.generate',
actionVersion: 'v1',
});
});
AIProvider.provide('expandMindmap', async options => {
if (!options.input) {
throw new Error('expandMindmap action requires input');
}
const sessionId = await createSession({
promptName: 'Expand mind map',
...options,
});
return textToText({
...options,
client,
sessionId,
params: {
mindmap: options.mindmap,
node: options.input,
},
content: options.input,
});
});
AIProvider.provide('explain', async options => {
const sessionId = await createSession({
promptName: 'Explain this',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('explainImage', async options => {
const sessionId = await createSession({
promptName: 'Explain this image',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('makeItReal', async options => {
let promptName: PromptKey = 'Make it real';
let content = options.input || '';
// wireframes
if (options.attachments?.length) {
content = `Here are the latest wireframes. Could you make a new website based on these wireframes and notes and send back just the html file?
Here are our design notes:\n ${content}.`;
} else {
// notes
promptName = 'Make it real with text';
content = `Here are the latest notes: \n ${content}.
Could you make a new website based on these notes and send back just the html file?`;
}
const sessionId = await createSession({
promptName,
...options,
});
return textToText({
...options,
client,
sessionId,
content,
});
});
AIProvider.provide('createSlides', async options => {
const sessionId = await createSession({
promptName: 'slides.outline',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
// 3 minutes
timeout: 180000,
endpoint: Endpoint.Action,
actionId: 'slides.outline',
actionVersion: 'v1',
});
});
AIProvider.provide('createImage', async options => {
const sessionId = await createSession({
promptName: 'Generate image',
...options,
});
return toImage({
...options,
client,
sessionId,
content:
!options.input && options.attachments
? 'Make the image more detailed.'
: options.input,
// 5 minutes
timeout: 300000,
});
});
AIProvider.provide('filterImage', async options => {
// test to image
const promptName: PromptKey | undefined = filterStyleToPromptName.get(
options.style
);
if (!promptName) {
throw new Error('filterImage requires a promptName');
}
const sessionId = await createSession({
promptName,
...options,
});
return toImage({
...options,
client,
sessionId,
content: options.input,
timeout: 180000,
endpoint: Endpoint.Action,
actionId: promptName,
actionVersion: 'v1',
});
});
AIProvider.provide('processImage', async options => {
// test to image
const promptName: PromptKey | undefined = processTypeToPromptName.get(
options.type
);
if (!promptName) {
throw new Error('processImage requires a promptName');
}
const sessionId = await createSession({
promptName,
...options,
});
return toImage({
...options,
client,
sessionId,
content: options.input,
timeout: 180000,
});
});
AIProvider.provide('generateCaption', async options => {
const sessionId = await createSession({
promptName: 'Generate a caption',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
AIProvider.provide('continueWriting', async options => {
const sessionId = await createSession({
promptName: 'Continue writing',
...options,
});
return textToText({
...options,
client,
sessionId,
content: options.input,
});
});
//#endregion
AIProvider.provide('session', {
createSession,
createSessionWithHistory: async options => {
if (!options.sessionId && !options.retry) {
return client.createSessionWithHistory({
workspaceId: options.workspaceId,
docId: options.docId,
promptName: options.promptName,
pinned: options.pinned,
reuseLatestChat: options.reuseLatestChat,
});
}
const sessionId = await createSession(options);
if (!sessionId) return undefined;
return client.getSession(options.workspaceId, sessionId);
},
getSession: async (workspaceId: string, sessionId: string) => {
return client.getSession(workspaceId, sessionId);
},
getSessions: async (
workspaceId: string,
docId?: string,
options?: QueryChatSessionsInput
) => {
return client.getSessions(workspaceId, {}, docId, options);
},
getRecentSessions: async (
workspaceId: string,
limit?: number,
offset?: number
) => {
return client.getRecentSessions(workspaceId, limit, offset);
},
updateSession: async (options: UpdateChatSessionInput) => {
return client.updateSession(options);
},
});
AIProvider.provide('context', {
createContext: async (workspaceId: string, sessionId: string) => {
return client.createContext(workspaceId, sessionId);
},
getContextId: async (workspaceId: string, sessionId: string) => {
return client.getContextId(workspaceId, sessionId);
},
addContextDoc: async (options: { contextId: string; docId: string }) => {
return client.addContextDoc(options);
},
removeContextDoc: async (options: { contextId: string; docId: string }) => {
return client.removeContextDoc(options);
},
addContextFile: async (file: File, options: AddContextFileInput) => {
return client.addContextFile(file, options);
},
removeContextFile: async (options: {
contextId: string;
fileId: string;
}) => {
return client.removeContextFile(options);
},
addContextTag: async (options: {
contextId: string;
tagId: string;
docIds: string[];
}) => {
return client.addContextCategory({
contextId: options.contextId,
type: ContextCategories.Tag,
categoryId: options.tagId,
docs: options.docIds,
});
},
removeContextTag: async (options: { contextId: string; tagId: string }) => {
return client.removeContextCategory({
contextId: options.contextId,
type: ContextCategories.Tag,
categoryId: options.tagId,
});
},
addContextCollection: async (options: {
contextId: string;
collectionId: string;
docIds: string[];
}) => {
return client.addContextCategory({
contextId: options.contextId,
type: ContextCategories.Collection,
categoryId: options.collectionId,
docs: options.docIds,
});
},
removeContextCollection: async (options: {
contextId: string;
collectionId: string;
}) => {
return client.removeContextCategory({
contextId: options.contextId,
type: ContextCategories.Collection,
categoryId: options.collectionId,
});
},
getContextDocsAndFiles: async (
workspaceId: string,
sessionId: string,
contextId: string
) => {
return client.getContextDocsAndFiles(workspaceId, sessionId, contextId);
},
pollContextDocsAndFiles: async (
workspaceId: string,
sessionId: string,
contextId: string,
onPoll: (
result: BlockSuitePresets.AIDocsAndFilesContext | undefined
) => void,
abortSignal: AbortSignal
) => {
const poll = async () => {
const result = await client.getContextDocsAndFiles(
workspaceId,
sessionId,
contextId
);
onPoll(result);
};
let attempts = 0;
const MIN_INTERVAL = 1000;
const MAX_INTERVAL = 30 * 1000;
while (!abortSignal.aborted) {
await poll();
const interval = Math.min(
MIN_INTERVAL * Math.pow(1.5, attempts),
MAX_INTERVAL
);
attempts++;
await new Promise(resolve => setTimeout(resolve, interval));
}
},
pollEmbeddingStatus: async (
workspaceId: string,
onPoll: (result: ContextWorkspaceEmbeddingStatus) => void,
abortSignal: AbortSignal
) => {
const poll = async () => {
const result = await client.getEmbeddingStatus(workspaceId);
onPoll(result);
};
const INTERVAL = 10 * 1000;
while (!abortSignal.aborted) {
await poll();
await new Promise(resolve => setTimeout(resolve, INTERVAL));
}
},
matchContext: async (
content: string,
contextId?: string,
workspaceId?: string,
limit?: number,
scopedThreshold?: number,
threshold?: number
) => {
return client.matchContext(
content,
contextId,
workspaceId,
limit,
scopedThreshold,
threshold
);
},
addContextBlob: async (options: { blobId: string; contextId: string }) => {
return client.addContextBlob({
contextId: options.contextId,
blobId: options.blobId,
});
},
removeContextBlob: async (options: {
blobId: string;
contextId: string;
}) => {
return client.removeContextBlob({
contextId: options.contextId,
blobId: options.blobId,
});
},
});
AIProvider.provide('histories', {
actions: async (
workspaceId: string,
docId: string
): Promise<BlockSuitePresets.AIHistory[]> => {
// @ts-expect-error - 'action' is missing in server impl
return (
(await client.getHistories(workspaceId, {}, docId, {
action: true,
withPrompt: true,
withMessages: true,
})) ?? []
);
},
chats: async (
workspaceId: string,
sessionId: string,
docId?: string
): Promise<BlockSuitePresets.AIHistory[]> => {
// @ts-expect-error - 'action' is missing in server impl
return (
(await client.getHistories(workspaceId, {}, docId, {
sessionId,
withMessages: true,
})) ?? []
);
},
cleanup: async (
workspaceId: string,
docId: string | undefined,
sessionIds: string[]
) => {
await client.cleanupSessions({ workspaceId, docId, sessionIds });
},
ids: async (
workspaceId: string,
docId?: string,
options?: RequestOptions<
typeof getCopilotHistoriesQuery
>['variables']['options']
): Promise<BlockSuitePresets.AIHistoryIds[]> => {
// @ts-expect-error - 'action' is missing in server impl
return await client.getHistoryIds(workspaceId, {}, docId, options);
},
});
AIProvider.provide('photoEngine', {
async searchImages(options): Promise<string[]> {
let url = '/api/copilot/unsplash/photos';
@@ -813,19 +63,15 @@ Could you make a new website based on these notes and send back just the html fi
AIProvider.provide('onboarding', toggleGeneralAIOnboarding);
AIProvider.provide('forkChat', options => {
return client.forkSession(options);
const disposeRequestLoginHandler = AIAppEvents.requestLogin.subscribe(() => {
globalDialogService.open('sign-in', {});
});
const disposeRequestLoginHandler = AIProvider.slots.requestLogin.subscribe(
() => {
globalDialogService.open('sign-in', {});
}
);
setupTracker();
const trackerDisposer = setupTracker(requestService);
return () => {
setAIRequestService(null);
trackerDisposer();
disposeRequestLoginHandler.unsubscribe();
accountSubscription.unsubscribe();
};
@@ -3,9 +3,12 @@ import type { EditorHost } from '@blocksuite/affine/std';
import type { GfxPrimitiveElementModel } from '@blocksuite/affine/std/gfx';
import type { BlockModel } from '@blocksuite/affine/store';
import { lowerCase, omit } from 'lodash-es';
import type { Subject } from 'rxjs';
import { AIProvider } from './ai-provider';
import type {
AIRequestActionEvent,
AIRequestService,
} from '../runtime/request';
import { AIAppEvents } from './ai-app-events';
type ElementModel = GfxPrimitiveElementModel;
@@ -62,10 +65,6 @@ type AIActionEventProperties = {
workspaceId: string;
};
type SubjectValue<T> = T extends Subject<infer U> ? U : never;
type BlocksuiteActionEvent = SubjectValue<typeof AIProvider.slots.actions>;
const trackAction = ({
eventName,
properties,
@@ -106,7 +105,7 @@ function isBlockModel(model: BlockModel | ElementModel): model is BlockModel {
return 'flavour' in model;
}
function inferObjectType(event: BlocksuiteActionEvent) {
function inferObjectType(event: AIRequestActionEvent) {
const models: (BlockModel | ElementModel)[] | undefined =
event.options.models;
if (!models) {
@@ -135,7 +134,7 @@ function inferObjectType(event: BlocksuiteActionEvent) {
}
function inferSegment(
event: BlocksuiteActionEvent
event: AIRequestActionEvent
): AIActionEventProperties['segment'] {
if (event.options.where === 'inline-chat-panel') {
return 'inline chat panel';
@@ -151,7 +150,7 @@ function inferSegment(
}
function inferModule(
event: BlocksuiteActionEvent
event: AIRequestActionEvent
): AIActionEventProperties['module'] {
if (event.options.where === 'chat-panel') {
return 'AI chat panel';
@@ -168,9 +167,7 @@ function inferModule(
}
}
function inferEventName(
event: BlocksuiteActionEvent
): AIActionEventName | null {
function inferEventName(event: AIRequestActionEvent): AIActionEventName | null {
if (['result:discard', 'result:retry'].includes(event.event)) {
return 'AI result discarded';
} else if (event.event.startsWith('result:')) {
@@ -184,7 +181,7 @@ function inferEventName(
}
function inferControl(
event: BlocksuiteActionEvent
event: AIRequestActionEvent
): AIActionEventProperties['control'] {
if (event.event === 'aborted:stop') {
return 'stop button';
@@ -222,7 +219,7 @@ function inferControl(
}
const toTrackedOptions = (
event: BlocksuiteActionEvent
event: AIRequestActionEvent
): {
eventName: AIActionEventName;
properties: AIActionEventProperties;
@@ -257,19 +254,25 @@ const toTrackedOptions = (
};
};
export function setupTracker() {
AIProvider.slots.requestUpgradePlan.subscribe(() => {
export function setupTracker(requestService: AIRequestService) {
const upgradeSubscription = AIAppEvents.requestUpgradePlan.subscribe(() => {
track.$.paywall.aiAction.viewPlans();
});
AIProvider.slots.requestLogin.subscribe(() => {
const loginSubscription = AIAppEvents.requestLogin.subscribe(() => {
track.doc.editor.aiActions.requestSignIn();
});
AIProvider.slots.actions.subscribe(event => {
const actionSubscription = requestService.actionEvents$.subscribe(event => {
const properties = toTrackedOptions(event);
if (properties) {
trackAction(properties);
}
});
return () => {
upgradeSubscription.unsubscribe();
loginSubscription.unsubscribe();
actionSubscription.unsubscribe();
};
}
@@ -0,0 +1,60 @@
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import type { AIChatContextItem, AIChatScope } from './state';
export type AIChatSendOptions = {
input?: string;
contexts?: {
docs?: unknown;
files?: unknown;
selectedSnapshot?: unknown;
selectedMarkdown?: unknown;
html?: unknown;
};
attachments?: (string | Blob | File)[];
attachmentPreviews?: string[];
isRootSession?: boolean;
where?: BlockSuitePresets.TrackerWhere;
control?: BlockSuitePresets.TrackerControl;
reasoning?: boolean;
toolsConfig?: unknown;
modelId?: string;
userInfo?: {
userId?: string;
userName?: string;
avatarUrl?: string;
};
};
export type AIChatAction =
| { type: 'initialize'; scope?: AIChatScope }
| { type: 'setScope'; scope: AIChatScope }
| { type: 'refreshHistory' }
| {
type: 'openSession';
sessionId: string;
}
| {
type: 'openSessionObject';
session: CopilotChatHistoryFragment;
}
| { type: 'closeTab'; tabId: string }
| { type: 'createNewSession'; pinned?: boolean }
| { type: 'togglePinActiveSession' }
| { type: 'deleteSession'; sessionId: string }
| { type: 'clearError' }
| { type: 'setComposerText'; text: string }
| { type: 'setReasoning'; reasoning: boolean }
| { type: 'setModel'; modelId?: string }
| { type: 'addAttachment'; attachment: string | Blob | File }
| { type: 'removeAttachment'; index: number }
| { type: 'addContextItem'; item: AIChatContextItem }
| { type: 'removeContextItem'; item: AIChatContextItem }
| { type: 'loadContext' }
| { type: 'pollContext' }
| { type: 'startContextPolling' }
| { type: 'stopContextPolling' }
| { type: 'pollEmbeddingStatus' }
| ({ type: 'send' } & AIChatSendOptions)
| { type: 'retry'; messageId: string }
| { type: 'stop' };
@@ -0,0 +1,6 @@
export * from './actions';
export * from './runtime';
export * from './session-strategy';
export * from './state';
export * from './use-element';
export * from './use-runtime';
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,225 @@
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import type { AIRequestService } from '../request';
import { type AIChatScope, type AIChatTab, createDraftTab } from './state';
export type OpenSessionResult =
| {
type: 'opened';
session: CopilotChatHistoryFragment;
}
| {
type: 'navigate';
target: {
workspaceId: string;
docId: string;
sessionId: string;
};
resetTabs: true;
};
export interface AIChatSessionStrategy {
loadInitialSession(
scope: AIChatScope,
request: AIRequestService
): Promise<CopilotChatHistoryFragment | null>;
createDraftSession(scope: AIChatScope): AIChatTab;
createSession(
scope: AIChatScope,
request: AIRequestService,
options?: { pinned?: boolean }
): Promise<CopilotChatHistoryFragment | null | undefined>;
canOpenAsTab(
session: CopilotChatHistoryFragment,
scope: AIChatScope
): boolean;
openSession(
session: CopilotChatHistoryFragment,
scope: AIChatScope
): OpenSessionResult;
}
export class DocAIChatSessionStrategy implements AIChatSessionStrategy {
async loadInitialSession(scope: AIChatScope, request: AIRequestService) {
if (scope.kind !== 'doc') return null;
const pinned = await request.getSessions(scope.workspaceId, undefined, {
pinned: true,
limit: 1,
});
if (Array.isArray(pinned) && pinned[0]) {
return (
(await request.getSession(scope.workspaceId, pinned[0].sessionId)) ??
pinned[0]
);
}
if (scope.pendingSessionId) {
return (
(await request.getSession(scope.workspaceId, scope.pendingSessionId)) ??
null
);
}
const docSessions = await request.getSessions(
scope.workspaceId,
scope.docId,
{
action: false,
fork: false,
limit: 1,
}
);
const session = docSessions?.[0];
if (!session) return null;
return (
(await request.getSession(scope.workspaceId, session.sessionId)) ??
session
);
}
createDraftSession(scope: AIChatScope) {
return createDraftTab(scope);
}
createSession(
scope: AIChatScope,
request: AIRequestService,
options: { pinned?: boolean } = {}
) {
if (scope.kind !== 'doc') return Promise.resolve(null);
return request.createSessionWithHistory({
workspaceId: scope.workspaceId,
docId: scope.docId,
promptName: 'Chat With AFFiNE AI',
reuseLatestChat: false,
pinned: options.pinned,
});
}
canOpenAsTab(session: CopilotChatHistoryFragment, scope: AIChatScope) {
return (
scope.kind === 'doc' && (!session.docId || session.docId === scope.docId)
);
}
openSession(session: CopilotChatHistoryFragment, scope: AIChatScope) {
if (this.canOpenAsTab(session, scope)) {
return { type: 'opened' as const, session };
}
if (scope.kind === 'doc' && session.docId) {
return {
type: 'navigate' as const,
target: {
workspaceId: session.workspaceId,
docId: session.docId,
sessionId: session.sessionId,
},
resetTabs: true as const,
};
}
return { type: 'opened' as const, session };
}
}
export class WorkspaceAIChatSessionStrategy implements AIChatSessionStrategy {
async loadInitialSession(scope: AIChatScope, request: AIRequestService) {
const sessions = await request.getSessions(scope.workspaceId, undefined, {
pinned: true,
limit: 1,
});
const session = sessions?.[0];
if (!session) return null;
return (
(await request.getSession(scope.workspaceId, session.sessionId)) ??
session
);
}
createDraftSession(scope: AIChatScope) {
return createDraftTab(scope);
}
createSession(
scope: AIChatScope,
request: AIRequestService,
options: { pinned?: boolean } = {}
) {
return request.createSessionWithHistory({
workspaceId: scope.workspaceId,
promptName: 'Chat With AFFiNE AI',
reuseLatestChat: false,
pinned: options.pinned,
});
}
canOpenAsTab(session: CopilotChatHistoryFragment, scope: AIChatScope) {
return session.workspaceId === scope.workspaceId && !session.docId;
}
openSession(session: CopilotChatHistoryFragment) {
return { type: 'opened' as const, session };
}
}
export class ForkAIChatSessionStrategy implements AIChatSessionStrategy {
async loadInitialSession(scope: AIChatScope, request: AIRequestService) {
if (
scope.kind !== 'fork' &&
scope.kind !== 'chat-block' &&
scope.kind !== 'playground'
) {
return null;
}
const parentSessionId =
'parentSessionId' in scope ? scope.parentSessionId : undefined;
if (!parentSessionId) return null;
return (
(await request.getSession(scope.workspaceId, parentSessionId)) ?? null
);
}
createDraftSession(scope: AIChatScope) {
return createDraftTab(scope);
}
async createSession(
scope: AIChatScope,
request: AIRequestService,
options: { pinned?: boolean } = {}
) {
const docId = 'docId' in scope ? scope.docId : undefined;
const parentSessionId =
'parentSessionId' in scope ? scope.parentSessionId : undefined;
if (!parentSessionId) {
return request.createSessionWithHistory({
workspaceId: scope.workspaceId,
docId,
promptName: 'Chat With AFFiNE AI',
reuseLatestChat: false,
pinned: options.pinned,
});
}
const latestMessageId =
'latestMessageId' in scope ? scope.latestMessageId : undefined;
const forkSessionId = await request.forkChat({
workspaceId: scope.workspaceId,
docId: docId ?? '',
sessionId: parentSessionId,
...(latestMessageId ? { latestMessageId } : {}),
});
if (!forkSessionId) return null;
return request.getSession(scope.workspaceId, forkSessionId);
}
canOpenAsTab(session: CopilotChatHistoryFragment, scope: AIChatScope) {
return session.workspaceId === scope.workspaceId;
}
openSession(session: CopilotChatHistoryFragment) {
return { type: 'opened' as const, session };
}
}
export class ChatBlockAIChatSessionStrategy extends ForkAIChatSessionStrategy {}
export class PlaygroundAIChatSessionStrategy extends ForkAIChatSessionStrategy {}
@@ -0,0 +1,212 @@
import type { AIToolsConfig } from '@affine/core/modules/ai-button';
import type { CopilotChatHistoryFragment } from '@affine/graphql';
export type AIChatScope =
| {
kind: 'doc';
workspaceId: string;
docId: string;
pendingSessionId?: string;
}
| {
kind: 'workspace';
workspaceId: string;
}
| {
kind: 'fork';
workspaceId: string;
parentSessionId: string;
latestMessageId?: string;
docId?: string;
}
| {
kind: 'chat-block';
workspaceId: string;
docId: string;
blockId: string;
parentSessionId?: string;
latestMessageId?: string;
}
| {
kind: 'playground';
workspaceId: string;
docId?: string;
parentSessionId?: string;
latestMessageId?: string;
};
export type AIChatMessage = {
id: string;
role: 'user' | 'assistant';
content: string;
createdAt: string;
streamObjects?: unknown[];
attachments?: string[];
userId?: string;
userName?: string;
avatarUrl?: string;
};
export type AIChatTab =
| {
kind: 'draft';
id: string;
title: string;
scope: AIChatScope;
hasMessages: false;
}
| {
kind: 'session';
id: string;
sessionId: string;
title: string;
docId: string | null;
pinned: boolean;
hasMessages: boolean;
};
export type AIChatStatus =
| 'idle'
| 'loading'
| 'transmitting'
| 'success'
| 'error';
export type AIChatHistoryGroups = {
currentDoc: CopilotChatHistoryFragment[];
recent: CopilotChatHistoryFragment[];
loading: boolean;
error: Error | null;
};
export type AIChatContextItem =
| {
kind: 'doc';
docId: string;
state?: string;
createdAt?: number;
tooltip?: string;
}
| {
kind: 'file';
file: File;
blobId?: string;
fileId?: string;
state?: string;
createdAt?: number;
tooltip?: string;
}
| {
kind: 'tag';
tagId: string;
docIds: string[];
state?: string;
createdAt?: number;
tooltip?: string;
}
| {
kind: 'collection';
collectionId: string;
docIds: string[];
state?: string;
createdAt?: number;
tooltip?: string;
}
| {
kind: 'blob';
blobId: string;
state?: string;
createdAt?: number;
tooltip?: string;
};
export type AIChatContextState = {
contextId: string | null;
items: AIChatContextItem[];
loading: boolean;
polling: boolean;
error: Error | null;
embeddingCompleted: boolean;
embeddingCount: Record<'finished' | 'processing' | 'failed', number>;
};
export type AIChatComposerState = {
text: string;
attachments: (string | Blob | File)[];
context: AIChatContextState;
reasoning: boolean;
toolsConfig?: AIToolsConfig;
modelId?: string;
};
export type AIChatNavigationRequest = {
workspaceId: string;
docId: string;
sessionId: string;
resetTabs: true;
};
export type AIChatSnapshot = {
scope: AIChatScope;
readiness: 'initializing' | 'ready' | 'unavailable';
activeSessionId: string | null;
activeTabId: string | null;
tabs: AIChatTab[];
sessions: CopilotChatHistoryFragment[];
history: AIChatHistoryGroups;
messages: AIChatMessage[];
status: AIChatStatus;
error: Error | null;
composer: AIChatComposerState;
navigationRequest: AIChatNavigationRequest | null;
uiPolicy: {
showDraftTab: boolean;
canCreateNewSession: boolean;
canCloseActiveTab: boolean;
canPinActiveSession: boolean;
canSend: boolean;
};
};
export function createInitialComposerState(): AIChatComposerState {
return {
text: '',
attachments: [],
context: {
contextId: null,
items: [],
loading: false,
polling: false,
error: null,
embeddingCompleted: false,
embeddingCount: {
finished: 0,
processing: 0,
failed: 0,
},
},
reasoning: false,
};
}
export function sessionToTab(session: CopilotChatHistoryFragment): AIChatTab {
return {
kind: 'session',
id: session.sessionId,
sessionId: session.sessionId,
title: session.title || 'New chat',
docId: session.docId ?? null,
pinned: !!session.pinned,
hasMessages: !!session.messages?.length,
};
}
export function createDraftTab(scope: AIChatScope): AIChatTab {
return {
kind: 'draft',
id: `draft:${scope.kind}:${'docId' in scope ? (scope.docId ?? '') : ''}`,
title: 'New chat',
scope,
hasMessages: false,
};
}
@@ -0,0 +1,90 @@
/**
* @vitest-environment happy-dom
*/
import { renderHook, waitFor } from '@testing-library/react';
import { describe, expect, test, vi } from 'vitest';
import { useAIChatElement } from './use-element';
class TestAIChatElement extends HTMLElement {
accessor value = '';
}
if (!customElements.get('test-ai-chat-element')) {
customElements.define('test-ai-chat-element', TestAIChatElement);
}
function createContainerRef() {
const container = document.createElement('div');
document.body.append(container);
return { current: container };
}
function createElement() {
return document.createElement('test-ai-chat-element') as TestAIChatElement;
}
describe('useAIChatElement', () => {
test('creates one element and keeps its properties in sync', async () => {
const containerRef = createContainerRef();
const { rerender, result } = renderHook(
({ value }) =>
useAIChatElement({
containerRef,
selector: 'test-ai-chat-element',
enabled: true,
createElement,
configureElement: element => {
element.value = value;
},
}),
{ initialProps: { value: 'first' } }
);
await waitFor(() => {
expect(
containerRef.current.querySelectorAll('test-ai-chat-element')
).toHaveLength(1);
expect(result.current?.value).toBe('first');
});
rerender({ value: 'next' });
await waitFor(() => {
expect(
containerRef.current.querySelectorAll('test-ai-chat-element')
).toHaveLength(1);
expect(result.current?.value).toBe('next');
});
});
test('reuses an existing element and removes duplicates', async () => {
const containerRef = createContainerRef();
const first = createElement();
const duplicate = createElement();
containerRef.current.append(first, duplicate);
const createElementSpy = vi.fn(createElement);
const { result } = renderHook(() =>
useAIChatElement({
containerRef,
selector: 'test-ai-chat-element',
enabled: true,
createElement: createElementSpy,
configureElement: element => {
element.value = 'reused';
},
})
);
await waitFor(() => {
expect(
containerRef.current.querySelectorAll('test-ai-chat-element')
).toHaveLength(1);
expect(result.current).toBe(first);
expect(first.value).toBe('reused');
});
expect(createElementSpy).not.toHaveBeenCalled();
});
});
@@ -0,0 +1,62 @@
import type { RefObject } from 'react';
import { useEffect, useRef, useState } from 'react';
export type UseAIChatElementOptions<T extends HTMLElement> = {
containerRef: RefObject<HTMLElement | null>;
selector: string;
enabled: boolean;
createElement: () => T;
configureElement: (element: T) => void;
onElementReady?: (element: T) => void;
};
export function useAIChatElement<T extends HTMLElement>({
containerRef,
selector,
enabled,
createElement,
configureElement,
onElementReady,
}: UseAIChatElementOptions<T>) {
const [element, setElement] = useState<T | null>(null);
const readyElementsRef = useRef(new WeakSet<T>());
useEffect(() => {
const container = containerRef.current;
if (!enabled || !container) return;
const existingElements = Array.from(
container.querySelectorAll(selector)
) as T[];
const nextElement = element ?? existingElements[0] ?? createElement();
existingElements
.filter(existingElement => existingElement !== nextElement)
.forEach(existingElement => existingElement.remove());
configureElement(nextElement);
if (nextElement.parentElement !== container) {
container.append(nextElement);
}
if (!readyElementsRef.current.has(nextElement)) {
readyElementsRef.current.add(nextElement);
onElementReady?.(nextElement);
}
if (element !== nextElement) {
setElement(nextElement);
}
}, [
configureElement,
containerRef,
createElement,
element,
enabled,
onElementReady,
selector,
]);
return element;
}
@@ -0,0 +1,22 @@
import { useEffect, useSyncExternalStore } from 'react';
import type { AIChatRuntime } from './runtime';
/**
* Initializes and owns the passed runtime for the current React mount.
*/
export function useAIChatRuntime(runtime: AIChatRuntime | null) {
const snapshot = useSyncExternalStore(
runtime?.subscribe ?? (() => () => {}),
runtime?.getSnapshot ?? (() => null),
runtime?.getSnapshot ?? (() => null)
);
useEffect(() => {
if (!runtime) return;
runtime.dispatch({ type: 'initialize' }).catch(console.error);
return () => runtime.dispose();
}, [runtime]);
return snapshot;
}
@@ -0,0 +1,225 @@
import type { PromptKey } from '../../provider/prompt';
import { Endpoint } from './copilot-client';
import type { TextToTextOptions } from './message-transport';
export type AIActionId = keyof BlockSuitePresets.AIActions;
export type AIActionOptions = BlockSuitePresets.AITextActionOptions &
Record<string, unknown>;
export type AIActionDefinition = {
id: AIActionId;
promptName: PromptKey | ((options: AIActionOptions) => PromptKey);
responseType: 'text' | 'image';
endpoint?: Endpoint;
actionId?: string | ((options: AIActionOptions) => string | undefined);
actionVersion?: string | ((options: AIActionOptions) => string | undefined);
timeout?: number;
buildContent?: (options: AIActionOptions) => string | undefined;
buildParams?: (options: AIActionOptions) => TextToTextOptions['params'];
validate?: (options: AIActionOptions) => void;
};
const filterStyleToPromptName = new Map<string, PromptKey>(
Object.entries({
'Clay style': 'image.filter.clay',
'Pixel style': 'image.filter.pixel',
'Sketch style': 'image.filter.sketch',
'Anime style': 'image.filter.anime',
})
);
const processTypeToPromptName = new Map<string, PromptKey>(
Object.entries({
Clearer: 'Upscale image',
'Remove background': 'Remove background',
'Convert to sticker': 'Convert to sticker',
})
);
const textAction = (
id: AIActionId,
promptName: PromptKey
): AIActionDefinition => ({
id,
promptName,
responseType: 'text',
buildContent: options => options.input,
});
export const actionDefinitions = {
chat: {
id: 'chat',
promptName: 'Chat With AFFiNE AI',
responseType: 'text',
timeout: 5 * 60 * 1000,
endpoint: Endpoint.StreamObject,
buildContent: options => options.input,
buildParams: options => {
const contexts = options.contexts as
| {
docs?: unknown;
files?: unknown;
selectedSnapshot?: unknown;
selectedMarkdown?: unknown;
html?: unknown;
}
| undefined;
return {
docs: contexts?.docs,
files: contexts?.files,
selectedSnapshot: contexts?.selectedSnapshot,
selectedMarkdown: contexts?.selectedMarkdown,
html: contexts?.html,
...(options.docId ? { currentDocId: options.docId } : {}),
};
},
},
summary: textAction('summary', 'Summary'),
translate: {
...textAction('translate', 'Translate to'),
buildParams: options => ({ language: options.lang }),
},
changeTone: {
...textAction('changeTone', 'Change tone to'),
buildParams: options => ({
tone: typeof options.tone === 'string' ? options.tone.toLowerCase() : '',
}),
},
improveWriting: textAction('improveWriting', 'Improve writing for it'),
improveGrammar: textAction('improveGrammar', 'Improve grammar for it'),
fixSpelling: textAction('fixSpelling', 'Fix spelling for it'),
createHeadings: textAction('createHeadings', 'Create headings'),
makeLonger: textAction('makeLonger', 'Make it longer'),
makeShorter: textAction('makeShorter', 'Make it shorter'),
checkCodeErrors: textAction('checkCodeErrors', 'Check code error'),
explainCode: textAction('explainCode', 'Explain this code'),
writeArticle: textAction('writeArticle', 'Write an article about this'),
writeTwitterPost: textAction(
'writeTwitterPost',
'Write a twitter about this'
),
writePoem: textAction('writePoem', 'Write a poem about this'),
writeOutline: textAction('writeOutline', 'Write outline'),
writeBlogPost: textAction('writeBlogPost', 'Write a blog post about this'),
brainstorm: textAction('brainstorm', 'Brainstorm ideas about this'),
findActions: textAction('findActions', 'Find action items from it'),
brainstormMindmap: {
id: 'brainstormMindmap',
promptName: 'mindmap.generate',
responseType: 'text',
timeout: 180000,
endpoint: Endpoint.Action,
actionId: 'mindmap.generate',
actionVersion: 'v1',
buildContent: options => options.input,
},
expandMindmap: {
...textAction('expandMindmap', 'Expand mind map'),
validate: options => {
if (!options.input) {
throw new Error('expandMindmap action requires input');
}
},
buildParams: options => ({
mindmap: options.mindmap,
node: options.input,
}),
},
explain: textAction('explain', 'Explain this'),
explainImage: textAction('explainImage', 'Explain this image'),
makeItReal: {
id: 'makeItReal',
promptName: options =>
options.attachments && Array.isArray(options.attachments)
? 'Make it real'
: 'Make it real with text',
responseType: 'text',
buildContent: options => {
const input = options.input ?? '';
if (options.attachments && Array.isArray(options.attachments)) {
return `Here are the latest wireframes. Could you make a new website based on these wireframes and notes and send back just the html file?
Here are our design notes:\n ${input}.`;
}
return `Here are the latest notes: \n ${input}.
Could you make a new website based on these notes and send back just the html file?`;
},
},
createSlides: {
id: 'createSlides',
promptName: 'slides.outline',
responseType: 'text',
timeout: 180000,
endpoint: Endpoint.Action,
actionId: 'slides.outline',
actionVersion: 'v1',
buildContent: options => options.input,
},
createImage: {
id: 'createImage',
promptName: 'Generate image',
responseType: 'image',
timeout: 300000,
buildContent: options =>
!options.input && options.attachments
? 'Make the image more detailed.'
: options.input,
},
filterImage: {
id: 'filterImage',
promptName: options => {
const promptName =
typeof options.style === 'string'
? filterStyleToPromptName.get(options.style)
: undefined;
if (!promptName) {
throw new Error('filterImage requires a promptName');
}
return promptName;
},
responseType: 'image',
timeout: 180000,
endpoint: Endpoint.Action,
actionId: options =>
typeof options.style === 'string'
? filterStyleToPromptName.get(options.style)
: undefined,
actionVersion: 'v1',
buildContent: options => options.input,
},
processImage: {
id: 'processImage',
promptName: options => {
const promptName =
typeof options.type === 'string'
? processTypeToPromptName.get(options.type)
: undefined;
if (!promptName) {
throw new Error('processImage requires a promptName');
}
return promptName;
},
responseType: 'image',
timeout: 180000,
buildContent: options => options.input,
},
generateCaption: textAction('generateCaption', 'Generate a caption'),
continueWriting: textAction('continueWriting', 'Continue writing'),
} satisfies Partial<Record<AIActionId, AIActionDefinition>>;
export function getActionDefinition(id: AIActionId): AIActionDefinition {
const definition = actionDefinitions[id];
if (!definition) {
throw new Error(`AI action ${String(id)} is not defined`);
}
return definition;
}
export function resolveDefinitionValue(
value:
| string
| ((options: AIActionOptions) => string | undefined)
| undefined,
options: AIActionOptions
) {
return typeof value === 'function' ? value(options) : value;
}
@@ -0,0 +1,97 @@
import { apis, type ClientHandler } from '@affine/electron-api';
import { UserFriendlyError } from '@affine/error';
import {
ByokProvider,
createWorkspaceByokLocalLeaseMutation,
} from '@affine/graphql';
import type { CopilotClient } from './copilot-client';
function isElectronBuild() {
return typeof BUILD_CONFIG !== 'undefined' && BUILD_CONFIG.isElectron;
}
function byokStorageApi(): ClientHandler['byokStorage'] | undefined {
return isElectronBuild() ? apis?.byokStorage : undefined;
}
function toGraphqlByokProvider(provider: string): ByokProvider | null {
switch (provider) {
case ByokProvider.openai:
return ByokProvider.openai;
case ByokProvider.anthropic:
return ByokProvider.anthropic;
case ByokProvider.gemini:
return ByokProvider.gemini;
case ByokProvider.fal:
return ByokProvider.fal;
default:
return null;
}
}
function errorMetadata(error: unknown) {
if (!error || typeof error !== 'object') {
return { kind: typeof error };
}
const record = error as Record<string, unknown>;
return {
name: typeof record.name === 'string' ? record.name : undefined,
code: typeof record.code === 'string' ? record.code : undefined,
status:
typeof record.status === 'number' || typeof record.status === 'string'
? record.status
: undefined,
type: typeof record.type === 'string' ? record.type : undefined,
};
}
export async function createWorkspaceByokLocalLease(
client: CopilotClient,
workspaceId?: string
) {
const storage = byokStorageApi();
if (!workspaceId || !storage) {
return undefined;
}
try {
if (!(await storage.isSupported())) return undefined;
const providers = await storage.getWorkspaceLeaseProviders(workspaceId);
if (!providers.length) return undefined;
const leaseProviders = providers.flatMap(provider => {
const gqlProvider = toGraphqlByokProvider(provider.provider);
return gqlProvider
? [
{
provider: gqlProvider,
name: provider.name,
description: provider.description ?? null,
apiKey: provider.apiKey,
endpoint: provider.endpoint ?? null,
sortOrder: provider.sortOrder ?? 0,
enabled: provider.enabled ?? true,
},
]
: [];
});
if (!leaseProviders.length) return undefined;
const result = await client.gql({
query: createWorkspaceByokLocalLeaseMutation,
variables: {
input: {
workspaceId,
providers: leaseProviders,
},
},
});
return result.createWorkspaceByokLocalLease.leaseId;
} catch (error) {
console.warn(
'Failed to create workspace BYOK local lease',
errorMetadata(error)
);
throw UserFriendlyError.fromAny(error);
}
}
@@ -38,7 +38,7 @@ import {
GeneralNetworkError,
PaymentRequiredError,
UnauthorizedError,
} from './error';
} from '../../provider/error';
export enum Endpoint {
Action = 'action',
@@ -0,0 +1,4 @@
export * from './action-definitions';
export * from './message-transport';
export * from './provider';
export * from './service';
@@ -1,114 +1,19 @@
import type { AIToolsConfig } from '@affine/core/modules/ai-button';
import { apis, type ClientHandler } from '@affine/electron-api';
import { UserFriendlyError } from '@affine/error';
import {
ByokProvider,
createWorkspaceByokLocalLeaseMutation,
} from '@affine/graphql';
import { partition } from 'lodash-es';
import { AIProvider } from './ai-provider';
import { toTextStream } from '../../provider/event-source';
import { createWorkspaceByokLocalLease } from './byok-local-lease';
import { type CopilotClient, Endpoint } from './copilot-client';
import { toTextStream } from './event-source';
const TIMEOUT = 50000;
function isElectronBuild() {
return typeof BUILD_CONFIG !== 'undefined' && BUILD_CONFIG.isElectron;
}
function byokStorageApi(): ClientHandler['byokStorage'] | undefined {
return isElectronBuild() ? apis?.byokStorage : undefined;
}
function toGraphqlByokProvider(provider: string): ByokProvider | null {
switch (provider) {
case ByokProvider.openai:
return ByokProvider.openai;
case ByokProvider.anthropic:
return ByokProvider.anthropic;
case ByokProvider.gemini:
return ByokProvider.gemini;
case ByokProvider.fal:
return ByokProvider.fal;
default:
return null;
}
}
function errorMetadata(error: unknown) {
if (!error || typeof error !== 'object') {
return { kind: typeof error };
}
const record = error as Record<string, unknown>;
return {
name: typeof record.name === 'string' ? record.name : undefined,
code: typeof record.code === 'string' ? record.code : undefined,
status:
typeof record.status === 'number' || typeof record.status === 'string'
? record.status
: undefined,
type: typeof record.type === 'string' ? record.type : undefined,
};
}
async function createWorkspaceByokLocalLease(
client: CopilotClient,
workspaceId?: string
) {
const storage = byokStorageApi();
if (!workspaceId || !storage) {
return undefined;
}
try {
if (!(await storage.isSupported())) return undefined;
const providers = await storage.getWorkspaceLeaseProviders(workspaceId);
if (!providers.length) return undefined;
const leaseProviders = providers.flatMap(provider => {
const gqlProvider = toGraphqlByokProvider(provider.provider);
return gqlProvider
? [
{
provider: gqlProvider,
name: provider.name,
description: provider.description ?? null,
apiKey: provider.apiKey,
endpoint: provider.endpoint ?? null,
sortOrder: provider.sortOrder ?? 0,
enabled: provider.enabled ?? true,
},
]
: [];
});
if (!leaseProviders.length) return undefined;
const result = await client.gql({
query: createWorkspaceByokLocalLeaseMutation,
variables: {
input: {
workspaceId,
providers: leaseProviders,
},
},
});
return result.createWorkspaceByokLocalLease.leaseId;
} catch (error) {
console.warn(
'Failed to create workspace BYOK local lease',
errorMetadata(error)
);
throw UserFriendlyError.fromAny(error);
}
}
export type TextToTextOptions = {
client: CopilotClient;
sessionId: string;
workspaceId?: string;
content?: string;
attachments?: (string | Blob | File)[];
params?: Record<string, any>;
params?: Record<string, unknown>;
timeout?: number;
stream?: boolean;
signal?: AbortSignal;
@@ -138,7 +43,6 @@ async function resizeImage(blob: Blob | File): Promise<Blob | null> {
});
const canvas = document.createElement('canvas');
// keep aspect ratio
const scale = Math.min(1024 / img.width, 1024 / img.height);
canvas.width = Math.floor(img.width * scale);
canvas.height = Math.floor(img.height * scale);
@@ -164,7 +68,7 @@ interface CreateMessageOptions {
sessionId: string;
content?: string;
attachments?: (string | Blob | File)[];
params?: Record<string, any>;
params?: Record<string, unknown>;
timeout?: number;
signal?: AbortSignal;
}
@@ -182,7 +86,7 @@ async function createMessage({
const options: Parameters<CopilotClient['createMessage']>[0] = {
sessionId,
content,
params,
params: params as Parameters<CopilotClient['createMessage']>[0]['params'],
};
if (hasAttachments) {
@@ -267,7 +171,6 @@ export function textToText({
},
endpoint
);
AIProvider.LAST_ACTION_SESSIONID = sessionId;
let onAbort: (() => void) | undefined;
try {
@@ -336,7 +239,6 @@ export function textToText({
},
endpoint
);
AIProvider.LAST_ACTION_SESSIONID = sessionId;
let onAbort: (() => void) | undefined;
try {
@@ -361,8 +263,7 @@ export function textToText({
}
}
const result = messages.join('');
return result;
return messages.join('');
} finally {
eventSource.close();
if (signal && onAbort) {
@@ -373,7 +274,6 @@ export function textToText({
}
}
// Only one image is currently being processed
export function toImage({
content,
sessionId,
@@ -435,7 +335,6 @@ export function toImage({
endpoint,
byokLeaseId
);
AIProvider.LAST_ACTION_SESSIONID = sessionId;
for await (const event of toTextStream(eventSource, {
timeout,
@@ -0,0 +1,18 @@
import type { AIRequestService } from './service';
let currentRequestService: AIRequestService | null = null;
export function setAIRequestService(service: AIRequestService | null) {
currentRequestService = service;
}
export function getAIRequestService() {
if (!currentRequestService) {
throw new Error('AIRequestService is not initialized');
}
return currentRequestService;
}
export function hasAIRequestService() {
return !!currentRequestService;
}
@@ -0,0 +1,379 @@
/**
* @vitest-environment happy-dom
*/
import { UserFriendlyError } from '@affine/error';
import { beforeEach, describe, expect, test, vi } from 'vitest';
import { type CopilotClient, Endpoint } from './copilot-client';
import { textToText, toImage } from './message-transport';
import { AIRequestService } from './service';
Object.defineProperty(globalThis, 'EventSource', {
configurable: true,
value: {
CLOSED: 2,
},
});
const electronApis = vi.hoisted(() => ({
byokStorage: undefined as
| {
isSupported: () => Promise<boolean>;
getWorkspaceLeaseProviders: (workspaceId: string) => Promise<
Array<{
provider: string;
name: string;
apiKey: string;
description?: string | null;
endpoint?: string | null;
sortOrder?: number | null;
enabled?: boolean | null;
}>
>;
}
| undefined,
}));
const createWorkspaceByokLocalLeaseMutation = vi.hoisted(() =>
Symbol('createWorkspaceByokLocalLeaseMutation')
);
vi.mock('@affine/electron-api', () => ({
apis: electronApis,
}));
vi.mock('@affine/graphql', () => ({
ByokProvider: {
openai: 'openai',
anthropic: 'anthropic',
gemini: 'gemini',
fal: 'fal',
},
ContextCategories: {
Tag: 'tag',
Collection: 'collection',
},
createWorkspaceByokLocalLeaseMutation,
}));
function createClosedEventSource(): EventSource {
return {
readyState: EventSource.CLOSED,
addEventListener: vi.fn(),
close: vi.fn(),
} as unknown as EventSource;
}
function createClient(
overrides: Partial<
Pick<
CopilotClient,
| 'gql'
| 'createSession'
| 'createMessage'
| 'getSessions'
| 'getHistories'
| 'chatTextStream'
| 'imagesStream'
>
> = {}
) {
return {
gql: vi.fn().mockResolvedValue({
createWorkspaceByokLocalLease: { leaseId: 'lease-1' },
}),
createSession: vi.fn().mockImplementation(async options => {
return `session:${options.promptName}`;
}),
createMessage: vi.fn().mockResolvedValue('message-1'),
getSessions: vi.fn().mockResolvedValue([]),
getHistories: vi.fn().mockResolvedValue([]),
chatTextStream: vi.fn(() => createClosedEventSource()),
imagesStream: vi.fn(() => createClosedEventSource()),
...overrides,
} as unknown as CopilotClient;
}
async function drain(stream: AsyncIterable<unknown>) {
for await (const chunk of stream) {
void chunk;
}
}
async function drainActionResult(
stream: string | AsyncIterable<unknown> | undefined
) {
expect(stream).toBeDefined();
expect(typeof stream).not.toBe('string');
await drain(stream as AsyncIterable<unknown>);
}
describe('runtime request transport BYOK local lease handling', () => {
beforeEach(() => {
vi.stubGlobal('BUILD_CONFIG', { isElectron: true });
electronApis.byokStorage = {
isSupported: vi.fn().mockResolvedValue(true),
getWorkspaceLeaseProviders: vi.fn().mockResolvedValue([
{
provider: 'openai',
name: 'OpenAI',
apiKey: 'sk-local',
},
]),
};
});
test('fails closed when local BYOK providers exist but lease creation fails', async () => {
const client = createClient({
gql: vi.fn().mockRejectedValue(new Error('mutation failed')),
});
const result = textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
}) as Promise<string>;
await expect(result).rejects.toThrow('mutation failed');
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create stream local BYOK lease after cancellation', async () => {
const controller = new AbortController();
const client = createClient({
createMessage: vi.fn().mockImplementation(async () => {
controller.abort();
return 'message-1';
}),
});
await drain(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
stream: true,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).not.toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create image stream when cancelled while creating local BYOK lease', async () => {
const controller = new AbortController();
const client = createClient({
gql: vi.fn().mockImplementation(async () => {
controller.abort();
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
}),
});
await drain(
toImage({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'image',
endpoint: Endpoint.Images,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).toHaveBeenCalled();
expect(client.imagesStream).not.toHaveBeenCalled();
});
});
describe('AIRequestService action definitions', () => {
beforeEach(() => {
vi.stubGlobal('BUILD_CONFIG', { isElectron: false });
electronApis.byokStorage = undefined;
});
test('routes action-stream requests through action endpoint', async () => {
const client = createClient();
const service = new AIRequestService(client);
await drainActionResult(
(await service.executeAction('brainstormMindmap', {
workspaceId: 'workspace-1',
input: 'make a map',
stream: true,
})) as AsyncIterable<unknown>
);
await drainActionResult(
(await service.executeAction('createSlides', {
workspaceId: 'workspace-1',
input: 'make slides',
stream: true,
})) as AsyncIterable<unknown>
);
await drainActionResult(
(await service.executeAction('filterImage', {
workspaceId: 'workspace-1',
input: 'convert',
attachments: ['blob-1'],
style: 'Sketch style',
})) as AsyncIterable<unknown>
);
expect(client.createSession).toHaveBeenCalledWith(
expect.objectContaining({ promptName: 'mindmap.generate' })
);
expect(client.createSession).toHaveBeenCalledWith(
expect.objectContaining({ promptName: 'slides.outline' })
);
expect(client.createSession).toHaveBeenCalledWith(
expect.objectContaining({ promptName: 'image.filter.sketch' })
);
expect(client.chatTextStream).toHaveBeenCalledWith(
expect.objectContaining({ actionId: 'mindmap.generate' }),
Endpoint.Action
);
expect(client.chatTextStream).toHaveBeenCalledWith(
expect.objectContaining({ actionId: 'slides.outline' }),
Endpoint.Action
);
expect(client.chatTextStream).toHaveBeenCalledWith(
expect.objectContaining({ actionId: 'image.filter.sketch' }),
Endpoint.Action
);
expect(client.imagesStream).not.toHaveBeenCalled();
});
test('reuses the last action session for retry', async () => {
const client = createClient();
const service = new AIRequestService(client);
await drainActionResult(
(await service.executeAction('summary', {
workspaceId: 'workspace-1',
input: 'summarize',
stream: true,
})) as AsyncIterable<unknown>
);
await drainActionResult(
(await service.executeAction('summary', {
workspaceId: 'workspace-1',
input: 'summarize again',
retry: true,
stream: true,
})) as AsyncIterable<unknown>
);
expect(client.createSession).toHaveBeenCalledTimes(1);
expect(client.createMessage).toHaveBeenCalledTimes(1);
expect(client.chatTextStream).toHaveBeenLastCalledWith(
expect.objectContaining({
sessionId: 'session:Summary',
retry: true,
}),
Endpoint.StreamObject
);
});
test('reports action result against the matching host action', async () => {
const client = createClient();
const service = new AIRequestService(client);
const events: string[] = [];
const hostOne = {} as NonNullable<
BlockSuitePresets.AITextActionOptions['host']
>;
const hostTwo = {} as NonNullable<
BlockSuitePresets.AITextActionOptions['host']
>;
const subscription = service.actionEvents$.subscribe(event => {
events.push(
`${event.options.host === hostOne ? 'one' : 'two'}:${event.event}`
);
});
await drainActionResult(
(await service.executeAction('summary', {
workspaceId: 'workspace-1',
input: 'first',
host: hostOne,
stream: true,
})) as AsyncIterable<unknown>
);
await drainActionResult(
(await service.executeAction('translate', {
workspaceId: 'workspace-1',
input: 'second',
lang: 'French',
host: hostTwo,
stream: true,
})) as AsyncIterable<unknown>
);
service.reportLastAction('result:insert', hostOne);
subscription.unsubscribe();
expect(events).toContain('one:result:insert');
});
test('loads sessions through history query with messages', async () => {
const history = {
sessionId: 'session-1',
workspaceId: 'workspace-1',
docId: 'doc-1',
messages: [{ id: 'message-1', role: 'user', content: 'hello' }],
};
const client = createClient({
getHistories: vi.fn().mockResolvedValue([history]),
});
const service = new AIRequestService(client);
const session = await service.getSession('workspace-1', 'session-1');
expect(client.getHistories).toHaveBeenCalledWith(
'workspace-1',
{},
undefined,
expect.objectContaining({
sessionId: 'session-1',
withMessages: true,
})
);
expect(session?.messages).toEqual(history.messages);
});
test('loads chat history lists with messages for title derivation', async () => {
const client = createClient();
const service = new AIRequestService(client);
await service.getSessions('workspace-1', 'doc-1', {
action: false,
fork: false,
});
await service.getRecentSessions('workspace-1', 10, 20);
expect(client.getSessions).toHaveBeenCalledWith(
'workspace-1',
{},
'doc-1',
expect.objectContaining({
action: false,
fork: false,
withMessages: true,
}),
undefined
);
expect(client.getHistories).toHaveBeenCalledWith(
'workspace-1',
{ first: 10, offset: 20 },
undefined,
expect.objectContaining({
action: false,
fork: false,
sessionOrder: 'desc',
withMessages: true,
})
);
});
});
@@ -0,0 +1,396 @@
import {
ContextCategories,
type CopilotChatHistoryFragment,
type getCopilotHistoriesQuery,
type GraphQLQuery,
type QueryChatSessionsInput,
type QueryOptions,
type QueryResponse,
type RequestOptions,
type UpdateChatSessionInput,
} from '@affine/graphql';
import { Subject } from 'rxjs';
import type { ActionEventType } from '../../provider';
import {
type AIActionId,
type AIActionOptions,
getActionDefinition,
resolveDefinitionValue,
} from './action-definitions';
import {
CopilotClient,
type CopilotClient as CopilotClientType,
} from './copilot-client';
import { textToText, toImage } from './message-transport';
type CreateSessionOptions = BlockSuitePresets.AICreateSessionOptions;
export type AIRequestActionEvent = {
action: AIActionId;
options: AIActionOptions;
event: ActionEventType;
};
export class AIRequestService {
private lastActionSessionId = '';
private readonly actionHistory: {
action: AIActionId;
options: AIActionOptions;
}[] = [];
readonly actionEvents$ = new Subject<AIRequestActionEvent>();
constructor(readonly client: CopilotClientType) {}
isReady() {
return true;
}
async createSession(options: CreateSessionOptions) {
if (options.sessionId) return options.sessionId;
if (options.retry) return this.lastActionSessionId;
return this.client.createSession({
workspaceId: options.workspaceId,
docId: options.docId,
promptName: options.promptName,
pinned: options.pinned,
reuseLatestChat: options.reuseLatestChat,
});
}
async createSessionWithHistory(options: CreateSessionOptions) {
if (!options.sessionId && !options.retry) {
return this.client.createSessionWithHistory({
workspaceId: options.workspaceId,
docId: options.docId,
promptName: options.promptName,
pinned: options.pinned,
reuseLatestChat: options.reuseLatestChat,
});
}
const sessionId = await this.createSession(options);
if (!sessionId) return undefined;
return this.getSession(options.workspaceId, sessionId);
}
getSession(workspaceId: string, sessionId: string) {
return this.client
.getHistories(workspaceId, {}, undefined, {
sessionId,
withMessages: true,
} as RequestOptions<
typeof getCopilotHistoriesQuery
>['variables']['options'])
.then(
histories =>
(histories?.[0] ?? null) as CopilotChatHistoryFragment | null
);
}
getSessions(
workspaceId: string,
docId?: string,
options?: QueryChatSessionsInput,
signal?: AbortSignal
) {
return this.client.getSessions(
workspaceId,
{},
docId,
{ ...options, withMessages: true },
signal
);
}
getRecentSessions(workspaceId: string, limit?: number, offset?: number) {
return this.client.getHistories(
workspaceId,
{ first: limit, offset },
undefined,
{
action: false,
fork: false,
sessionOrder: 'desc',
withMessages: true,
} as RequestOptions<
typeof getCopilotHistoriesQuery
>['variables']['options']
);
}
updateSession(options: UpdateChatSessionInput) {
return this.client.updateSession(options);
}
cleanupSessions(input: {
workspaceId: string;
docId: string | undefined;
sessionIds: string[];
}) {
return this.client.cleanupSessions(input);
}
histories = {
actions: async (
workspaceId: string,
docId: string
): Promise<BlockSuitePresets.AIHistory[]> => {
return ((await this.client.getHistories(workspaceId, {}, docId, {
action: true,
withPrompt: true,
withMessages: true,
} as RequestOptions<
typeof getCopilotHistoriesQuery
>['variables']['options'])) ?? []) as BlockSuitePresets.AIHistory[];
},
chats: async (
workspaceId: string,
sessionId: string,
docId?: string
): Promise<BlockSuitePresets.AIHistory[]> => {
return ((await this.client.getHistories(workspaceId, {}, docId, {
sessionId,
withMessages: true,
} as RequestOptions<
typeof getCopilotHistoriesQuery
>['variables']['options'])) ?? []) as BlockSuitePresets.AIHistory[];
},
cleanup: async (
workspaceId: string,
docId: string | undefined,
sessionIds: string[]
) => {
await this.cleanupSessions({ workspaceId, docId, sessionIds });
},
ids: async (
workspaceId: string,
docId?: string,
options?: RequestOptions<
typeof getCopilotHistoriesQuery
>['variables']['options']
): Promise<BlockSuitePresets.AIHistoryIds[]> => {
return (await this.client.getHistoryIds(
workspaceId,
{},
docId,
options
)) as unknown as BlockSuitePresets.AIHistoryIds[];
},
};
context = {
createContext: (workspaceId: string, sessionId: string) =>
this.client.createContext(workspaceId, sessionId),
getContextId: (workspaceId: string, sessionId: string) =>
this.client.getContextId(workspaceId, sessionId),
addContextDoc: (options: { contextId: string; docId: string }) =>
this.client.addContextDoc(options),
removeContextDoc: (options: { contextId: string; docId: string }) =>
this.client.removeContextDoc(options),
addContextFile: (
file: File,
options: Parameters<CopilotClient['addContextFile']>[1]
) => this.client.addContextFile(file, options),
removeContextFile: (options: { contextId: string; fileId: string }) =>
this.client.removeContextFile(options),
addContextTag: (options: {
contextId: string;
tagId: string;
docIds: string[];
}) =>
this.client.addContextCategory({
contextId: options.contextId,
type: ContextCategories.Tag,
categoryId: options.tagId,
docs: options.docIds,
}),
removeContextTag: (options: { contextId: string; tagId: string }) =>
this.client.removeContextCategory({
contextId: options.contextId,
type: ContextCategories.Tag,
categoryId: options.tagId,
}),
addContextCollection: (options: {
contextId: string;
collectionId: string;
docIds: string[];
}) =>
this.client.addContextCategory({
contextId: options.contextId,
type: ContextCategories.Collection,
categoryId: options.collectionId,
docs: options.docIds,
}),
removeContextCollection: (options: {
contextId: string;
collectionId: string;
}) =>
this.client.removeContextCategory({
contextId: options.contextId,
type: ContextCategories.Collection,
categoryId: options.collectionId,
}),
getContextDocsAndFiles: (
workspaceId: string,
sessionId: string,
contextId: string
) => this.client.getContextDocsAndFiles(workspaceId, sessionId, contextId),
matchContext: (
content: string,
contextId?: string,
workspaceId?: string,
limit?: number,
scopedThreshold?: number,
threshold?: number
) =>
this.client.matchContext(
content,
contextId,
workspaceId,
limit,
scopedThreshold,
threshold
),
addContextBlob: (options: { blobId: string; contextId: string }) =>
this.client.addContextBlob({
contextId: options.contextId,
blobId: options.blobId,
}),
removeContextBlob: (options: { blobId: string; contextId: string }) =>
this.client.removeContextBlob({
contextId: options.contextId,
blobId: options.blobId,
}),
pollContextDocsAndFiles: async (
workspaceId: string,
sessionId: string,
contextId: string,
onPoll: (
result: BlockSuitePresets.AIDocsAndFilesContext | undefined
) => void,
abortSignal: AbortSignal
) => {
let attempts = 0;
const minInterval = 1000;
const maxInterval = 30 * 1000;
while (!abortSignal.aborted) {
const result = await this.client.getContextDocsAndFiles(
workspaceId,
sessionId,
contextId
);
onPoll(result);
const interval = Math.min(
minInterval * Math.pow(1.5, attempts),
maxInterval
);
attempts++;
await new Promise(resolve => setTimeout(resolve, interval));
}
},
pollEmbeddingStatus: async (
workspaceId: string,
onPoll: (
result: Awaited<ReturnType<CopilotClientType['getEmbeddingStatus']>>
) => void,
abortSignal: AbortSignal
) => {
const interval = 10 * 1000;
while (!abortSignal.aborted) {
onPoll(await this.client.getEmbeddingStatus(workspaceId));
await new Promise(resolve => setTimeout(resolve, interval));
}
},
};
forkChat(options: BlockSuitePresets.AIForkChatSessionOptions) {
return this.client.forkSession(options);
}
reportLastAction(event: ActionEventType, host?: unknown) {
const lastAction = host
? this.actionHistory.findLast(item => item.options.host === host)
: this.actionHistory.at(-1);
if (!lastAction) return;
this.actionEvents$.next({
action: lastAction.action,
options: lastAction.options,
event,
});
}
private wrapTextStream(
stream: AsyncIterable<string>,
id: AIActionId,
options: AIActionOptions
): AsyncIterable<string> {
const actionEvents$ = this.actionEvents$;
return {
async *[Symbol.asyncIterator]() {
try {
yield* stream;
actionEvents$.next({ action: id, options, event: 'finished' });
} catch (error) {
actionEvents$.next({ action: id, options, event: 'error' });
throw error;
}
},
};
}
async executeAction(id: AIActionId, options: AIActionOptions) {
this.actionHistory.push({ action: id, options });
if (this.actionHistory.length > 10) {
this.actionHistory.shift();
}
this.actionEvents$.next({ action: id, options, event: 'started' });
const definition = getActionDefinition(id);
definition.validate?.(options);
const promptName = resolveDefinitionValue(
definition.promptName,
options
) as CreateSessionOptions['promptName'];
const sessionId = await this.createSession({
promptName,
...options,
} as CreateSessionOptions);
this.lastActionSessionId = sessionId;
const actionId = resolveDefinitionValue(definition.actionId, options);
const actionVersion = resolveDefinitionValue(
definition.actionVersion,
options
);
const transportOptions = {
...options,
client: this.client,
sessionId,
content: definition.buildContent?.(options) ?? options.input,
params: definition.buildParams?.(options),
timeout: definition.timeout,
endpoint: definition.endpoint,
actionId,
actionVersion,
};
const stream =
definition.responseType === 'image'
? toImage(transportOptions)
: textToText(transportOptions);
return this.wrapTextStream(stream as AsyncIterable<string>, id, options);
}
}
export function createAIRequestService(
gql: <Query extends GraphQLQuery>(
options: QueryOptions<Query>
) => Promise<QueryResponse<Query>>,
eventSource: (
url: string,
eventSourceInitDict?: EventSourceInit
) => EventSource
) {
return new AIRequestService(new CopilotClient(gql, eventSource));
}
@@ -1,12 +1,6 @@
import { type ActionEventType, AIProvider } from '../provider';
import type { ActionEventType } from '../provider';
import { getAIRequestService } from '../runtime/request';
export function reportResponse(event: ActionEventType) {
const lastAction = AIProvider.actionHistory.at(-1);
if (!lastAction) return;
AIProvider.slots.actions.next({
action: lastAction.action,
options: lastAction.options,
event,
});
export function reportResponse(event: ActionEventType, host?: unknown) {
getAIRequestService().reportLastAction(event, host);
}
@@ -27,7 +27,7 @@ import { styleMap } from 'lit/directives/style-map.js';
import { literal, unsafeStatic } from 'lit/static-html.js';
import type { AIItemGroupConfig } from '../../components/ai-item/types.js';
import { AIProvider } from '../../provider/index.js';
import { AIAppEvents } from '../../provider/index.js';
import { extractSelectedContent } from '../../utils/extract.js';
import {
AFFINE_AI_PANEL_WIDGET,
@@ -106,7 +106,7 @@ export class EdgelessCopilotWidget extends WidgetComponent<RootBlockModel> {
aiPanel.hide();
extractSelectedContent(this.host)
.then(context => {
AIProvider.slots.requestSendWithChat.next({
AIAppEvents.requestSendWithChat.next({
input,
context,
host: this.host,
@@ -4,8 +4,8 @@ import {
resolveGlobalLoadingEventAtom,
} from '@affine/component/global-loading';
import {
AIProvider,
CopilotClient,
AIAppEvents,
createAIRequestService,
setupAIProvider,
} from '@affine/core/blocksuite/ai';
import { useRegisterFindInPageCommands } from '@affine/core/components/hooks/affine/use-register-find-in-page-commands';
@@ -103,7 +103,7 @@ export const WorkspaceSideEffects = () => {
})
);
const disposable = AIProvider.slots.requestInsertTemplate.subscribe(
const disposable = AIAppEvents.requestInsertTemplate.subscribe(
({ template, mode }) => {
insertTemplate({ template, mode });
}
@@ -126,7 +126,7 @@ export const WorkspaceSideEffects = () => {
const globalDialogService = useService(GlobalDialogService);
useEffect(() => {
const disposable = AIProvider.slots.requestUpgradePlan.subscribe(() => {
const disposable = AIAppEvents.requestUpgradePlan.subscribe(() => {
workspaceDialogService.open('setting', {
activeTab: 'billing',
});
@@ -143,7 +143,10 @@ export const WorkspaceSideEffects = () => {
useEffect(() => {
const dispose = setupAIProvider(
new CopilotClient(graphqlService.gql, eventSourceService.eventSource),
createAIRequestService(
graphqlService.gql,
eventSourceService.eventSource
),
globalDialogService,
authService
);
@@ -1,142 +0,0 @@
import { WorkspaceLocalState } from '@affine/core/modules/workspace';
import type { I18nInstance } from '@affine/i18n';
import type { NotificationService } from '@blocksuite/affine/shared/services';
import { useService } from '@toeverything/infra';
import {
type Dispatch,
type SetStateAction,
useCallback,
useEffect,
useRef,
useState,
} from 'react';
const AI_CHAT_OPEN_TABS_KEY = 'aiChatOpenTabs';
// Pass `null` for `loadSession` to defer hydration until a real loader is ready.
export function useAIChatOpenTabs<T extends { sessionId: string }>(
loadSession: ((sessionId: string) => Promise<T | null | undefined>) | null
): {
openTabs: T[];
setOpenTabs: Dispatch<SetStateAction<T[]>>;
} {
const workspaceLocalState = useService(WorkspaceLocalState);
const [openTabs, setOpenTabsState] = useState<T[]>([]);
// Ref so persist gate isn't subject to React state-batch ordering.
const hydratedRef = useRef(false);
useEffect(() => {
if (!loadSession) return;
hydratedRef.current = false;
setOpenTabsState([]);
const ids = workspaceLocalState.get<string[]>(AI_CHAT_OPEN_TABS_KEY) ?? [];
if (!ids.length) {
hydratedRef.current = true;
return;
}
let cancelled = false;
Promise.all(ids.map(id => loadSession(id).catch(() => null)))
.then(results => {
if (cancelled) return;
const valid = (results as (T | null | undefined)[]).filter(
(entry): entry is T => !!entry && !!entry.sessionId
);
if (valid.length) setOpenTabsState(valid);
hydratedRef.current = true;
})
.catch(error => {
console.error(error);
if (!cancelled) hydratedRef.current = true;
});
return () => {
cancelled = true;
};
}, [loadSession, workspaceLocalState]);
const setOpenTabs = useCallback<Dispatch<SetStateAction<T[]>>>(
updater => {
setOpenTabsState(prev => {
const next =
typeof updater === 'function'
? (updater as (p: T[]) => T[])(prev)
: updater;
if (hydratedRef.current) {
if (next.length) {
workspaceLocalState.set(
AI_CHAT_OPEN_TABS_KEY,
next.map(tab => tab.sessionId)
);
} else {
workspaceLocalState.del(AI_CHAT_OPEN_TABS_KEY);
}
}
return next;
});
},
[workspaceLocalState]
);
return { openTabs, setOpenTabs };
}
export type SessionDeleteCleanupFn = (
session: BlockSuitePresets.AIRecentSession
) => Promise<void>;
export type CreateSessionDeleteHandlerOptions = {
t: I18nInstance;
notificationService: NotificationService;
cleanupSession: SessionDeleteCleanupFn;
canDeleteSession?: (session: BlockSuitePresets.AIRecentSession) => boolean;
isActiveSession?: (session: BlockSuitePresets.AIRecentSession) => boolean;
onActiveSessionDeleted?: () => void;
};
export function createSessionDeleteHandler({
t,
notificationService,
cleanupSession,
canDeleteSession,
isActiveSession,
onActiveSessionDeleted,
}: CreateSessionDeleteHandlerOptions) {
return async (sessionToDelete: BlockSuitePresets.AIRecentSession) => {
if (canDeleteSession && !canDeleteSession(sessionToDelete)) {
notificationService.toast(
t['com.affine.ai.chat-panel.session.delete.toast.failed']()
);
return;
}
const confirm = await notificationService.confirm({
title: t['com.affine.ai.chat-panel.session.delete.confirm.title'](),
message: t['com.affine.ai.chat-panel.session.delete.confirm.message'](),
confirmText: t['Delete'](),
cancelText: t['Cancel'](),
});
if (!confirm) {
return;
}
try {
await cleanupSession(sessionToDelete);
notificationService.toast(
t['com.affine.ai.chat-panel.session.delete.toast.success']()
);
} catch (error) {
console.error(error);
notificationService.toast(
t['com.affine.ai.chat-panel.session.delete.toast.failed']()
);
return;
}
if (isActiveSession?.(sessionToDelete)) {
onActiveSessionDeleted?.();
}
};
}
@@ -1,17 +1,17 @@
import { observeResize, useConfirmModal } from '@affine/component';
import { CopilotClient } from '@affine/core/blocksuite/ai';
import {
AIChatContent,
type ChatContextValue,
} from '@affine/core/blocksuite/ai/components/ai-chat-content';
import type { ChatStatus } from '@affine/core/blocksuite/ai/components/ai-chat-messages';
import type { AIChatToolbar } from '@affine/core/blocksuite/ai/components/ai-chat-toolbar';
AIChatRuntime,
createAIRequestService,
useAIChatElement,
useAIChatRuntime,
WorkspaceAIChatSessionStrategy,
} from '@affine/core/blocksuite/ai';
import { AIChatContent } from '@affine/core/blocksuite/ai/components/ai-chat-content';
import {
AIChatTabs,
AIChatToolbar,
configureAIChatToolbar,
getOrCreateAIChatToolbar,
} from '@affine/core/blocksuite/ai/components/ai-chat-toolbar';
import type { PromptKey } from '@affine/core/blocksuite/ai/provider/prompt';
import { getViewManager } from '@affine/core/blocksuite/manager/view';
import { NotificationServiceImpl } from '@affine/core/blocksuite/view-extensions/editor-view/notification-service';
import { useAIChatConfig } from '@affine/core/components/hooks/affine/use-ai-chat-config';
@@ -50,20 +50,18 @@ import { useFramework, useService } from '@toeverything/infra';
import { nanoid } from 'nanoid';
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import {
createSessionDeleteHandler,
useAIChatOpenTabs,
} from '../chat-panel-utils';
import * as styles from './index.css';
type CopilotSession = Awaited<ReturnType<CopilotClient['getSession']>>;
function useCopilotClient() {
function useAIRequestService() {
const graphqlService = useService(GraphQLService);
const eventSourceService = useService(EventSourceService);
return useMemo(
() => new CopilotClient(graphqlService.gql, eventSourceService.eventSource),
() =>
createAIRequestService(
graphqlService.gql,
eventSourceService.eventSource
),
[graphqlService, eventSourceService]
);
}
@@ -95,164 +93,32 @@ export const Component = () => {
const framework = useFramework();
const [isBodyProvided, setIsBodyProvided] = useState(false);
const [isHeaderProvided, setIsHeaderProvided] = useState(false);
const [chatContent, setChatContent] = useState<AIChatContent | null>(null);
const [chatTool, setChatTool] = useState<AIChatToolbar | null>(null);
const [chatTabs, setChatTabs] = useState<AIChatTabs | null>(null);
const [currentSession, setCurrentSession] = useState<CopilotSession | null>(
null
);
const [status, setStatus] = useState<ChatStatus>('idle');
const [isTogglingPin, setIsTogglingPin] = useState(false);
const [isOpeningSession, setIsOpeningSession] = useState(false);
const hasRestoredPinnedSessionRef = useRef(false);
const chatContainerRef = useRef<HTMLDivElement>(null);
const chatToolContainerRef = useRef<HTMLDivElement>(null);
const chatTabsContainerRef = useRef<HTMLDivElement | null>(null);
const widthSignalRef = useRef<Signal<number>>(signal(0));
const client = useCopilotClient();
const requestService = useAIRequestService();
const workbench = useService(WorkbenchService).workbench;
const workspaceId = useService(WorkspaceService).workspace.id;
const loadSession = useCallback(
(sessionId: string) => client.getSession(workspaceId, sessionId),
[client, workspaceId]
const runtime = useMemo(
() =>
new AIChatRuntime({
request: requestService,
scope: { kind: 'workspace', workspaceId },
strategy: new WorkspaceAIChatSessionStrategy(),
}),
[requestService, workspaceId]
);
const { openTabs, setOpenTabs } = useAIChatOpenTabs(loadSession);
useEffect(() => {
hasRestoredPinnedSessionRef.current = false;
}, [workspaceId]);
const snapshot = useAIChatRuntime(runtime);
const session =
snapshot?.sessions.find(
session => session.sessionId === snapshot.activeSessionId
) ?? null;
const { docDisplayConfig, searchMenuConfig, reasoningConfig } =
useAIChatConfig();
const createSession = useCallback(
async (options: Partial<BlockSuitePresets.AICreateSessionOptions> = {}) => {
if (currentSession) {
return currentSession;
}
const session = await client.createSessionWithHistory({
workspaceId,
promptName: 'Chat With AFFiNE AI' satisfies PromptKey,
reuseLatestChat: false,
...options,
});
setCurrentSession(session);
return session;
},
[client, currentSession, workspaceId]
);
const togglePin = useCallback(async () => {
if (isTogglingPin) return;
setIsTogglingPin(true);
try {
const pinned = !currentSession?.pinned;
if (!currentSession) {
await createSession({ pinned });
} else {
await client.updateSession({
sessionId: currentSession.sessionId,
pinned,
});
// retrieve the latest session and update the state
const session = await client.getSession(
workspaceId,
currentSession.sessionId
);
setCurrentSession(session);
}
} finally {
setIsTogglingPin(false);
}
}, [client, createSession, currentSession, isTogglingPin, workspaceId]);
// remove the old content to trigger re-mount
// to avoid infinitely load and mount, should not make `chatContent` as dependency
const reMountChatContent = useCallback(() => {
setChatContent(prev => {
prev?.remove();
return null;
});
}, []);
const createFreshSession = useCallback(async () => {
if (isOpeningSession) {
return;
}
setIsOpeningSession(true);
try {
setCurrentSession(null);
reMountChatContent();
const session = await client.createSessionWithHistory({
workspaceId,
promptName: 'Chat With AFFiNE AI' satisfies PromptKey,
reuseLatestChat: false,
});
setCurrentSession(session);
} catch (error) {
console.error(error);
} finally {
setIsOpeningSession(false);
}
}, [client, isOpeningSession, reMountChatContent, workspaceId]);
const onOpenSession = useCallback(
async (sessionId: string) => {
if (isOpeningSession || currentSession?.sessionId === sessionId) return;
setIsOpeningSession(true);
try {
const session = await client.getSession(workspaceId, sessionId);
if (!session) {
// Drop stale tab if session no longer exists.
setOpenTabs(prev => prev.filter(tab => tab.sessionId !== sessionId));
return;
}
setCurrentSession(session);
reMountChatContent();
chatTool?.closeHistoryMenu();
} catch (error) {
console.error(error);
} finally {
setIsOpeningSession(false);
}
},
[
chatTool,
client,
currentSession?.sessionId,
isOpeningSession,
reMountChatContent,
setOpenTabs,
workspaceId,
]
);
const closeTab = useCallback(
(sessionId: string) => {
let fallback: NonNullable<CopilotSession> | undefined;
setOpenTabs(prev => {
const idx = prev.findIndex(tab => tab.sessionId === sessionId);
if (idx === -1) return prev;
const next = prev.filter(tab => tab.sessionId !== sessionId);
fallback = next[idx] ?? next[idx - 1] ?? next[0];
return next;
});
if (currentSession?.sessionId !== sessionId) return;
if (fallback) {
onOpenSession(fallback.sessionId).catch(console.error);
} else {
createFreshSession().catch(console.error);
}
},
[createFreshSession, currentSession?.sessionId, onOpenSession, setOpenTabs]
);
const onContextChange = useCallback((context: Partial<ChatContextValue>) => {
setStatus(context.status ?? 'idle');
}, []);
const onOpenDoc = useCallback(
(docId: string) => {
workbench.openDoc(docId, { at: 'active' });
@@ -283,143 +149,86 @@ export const Component = () => {
const mockStd = useMockStd();
const handleAISubscribe = useAISubscribe();
const deleteSession = useMemo(
() =>
createSessionDeleteHandler({
t,
notificationService,
cleanupSession: async sessionToDelete => {
await client.cleanupSessions({
workspaceId: sessionToDelete.workspaceId,
docId: sessionToDelete.docId || undefined,
sessionIds: [sessionToDelete.sessionId],
});
},
isActiveSession: sessionToDelete =>
sessionToDelete.sessionId === currentSession?.sessionId,
onActiveSessionDeleted: () => {
setCurrentSession(null);
reMountChatContent();
},
}),
[
client,
currentSession?.sessionId,
notificationService,
reMountChatContent,
t,
]
const deleteSession = useCallback(
async (sessionToDelete: BlockSuitePresets.AIRecentSession) => {
const confirm = await notificationService.confirm({
title: t['com.affine.ai.chat-panel.session.delete.confirm.title'](),
message: t['com.affine.ai.chat-panel.session.delete.confirm.message'](),
confirmText: t['Delete'](),
cancelText: t['Cancel'](),
});
if (!confirm) return;
await runtime.dispatch({
type: 'deleteSession',
sessionId: sessionToDelete.sessionId,
});
notificationService.toast(
t['com.affine.ai.chat-panel.session.delete.toast.success'](),
{}
);
},
[notificationService, runtime, t]
);
// init or update ai-chat-content
useEffect(() => {
if (!isBodyProvided) {
return;
}
let content = chatContent;
if (!content) {
content = new AIChatContent();
}
content.session = currentSession;
content.workspaceId = workspaceId;
content.extensions = specs;
content.host = mockStd?.host;
content.docDisplayConfig = docDisplayConfig;
content.searchMenuConfig = searchMenuConfig;
content.reasoningConfig = reasoningConfig;
content.onContextChange = onContextChange;
content.affineFeatureFlagService = framework.get(FeatureFlagService);
content.affineWorkspaceDialogService = framework.get(
WorkspaceDialogService
);
content.peekViewService = framework.get(PeekViewService);
content.affineThemeService = framework.get(AppThemeService);
content.notificationService = notificationService;
content.aiDraftService = framework.get(AIDraftService);
content.aiToolsConfigService = framework.get(AIToolsConfigService);
content.serverService = framework.get(ServerService);
content.subscriptionService = framework.get(SubscriptionService);
content.aiModelService = framework.get(AIModelService);
content.onAISubscribe = handleAISubscribe;
content.createSession = createSession;
content.onOpenDoc = onOpenDoc;
if (!chatContent) {
// initial values that won't change
useAIChatElement({
containerRef: chatContainerRef,
selector: 'ai-chat-content',
enabled: isBodyProvided,
createElement: () => new AIChatContent(),
configureElement: content => {
content.session = session;
content.runtime = runtime;
content.runtimeSnapshot = snapshot;
content.workspaceId = workspaceId;
content.extensions = specs;
content.host = mockStd?.host;
content.docDisplayConfig = docDisplayConfig;
content.searchMenuConfig = searchMenuConfig;
content.reasoningConfig = reasoningConfig;
content.affineFeatureFlagService = framework.get(FeatureFlagService);
content.affineWorkspaceDialogService = framework.get(
WorkspaceDialogService
);
content.peekViewService = framework.get(PeekViewService);
content.affineThemeService = framework.get(AppThemeService);
content.notificationService = notificationService;
content.aiDraftService = framework.get(AIDraftService);
content.aiToolsConfigService = framework.get(AIToolsConfigService);
content.serverService = framework.get(ServerService);
content.subscriptionService = framework.get(SubscriptionService);
content.aiModelService = framework.get(AIModelService);
content.onAISubscribe = handleAISubscribe;
content.onOpenDoc = onOpenDoc;
},
onElementReady: content => {
content.independentMode = true;
content.onboardingOffsetY = -100;
chatContainerRef.current?.append(content);
setChatContent(content);
}
}, [
chatContent,
createSession,
currentSession,
docDisplayConfig,
framework,
isBodyProvided,
mockStd,
reasoningConfig,
searchMenuConfig,
workspaceId,
onContextChange,
notificationService,
specs,
onOpenDoc,
handleAISubscribe,
]);
},
});
// init or update header ai-chat-toolbar
useEffect(() => {
if (!isHeaderProvided || !chatToolContainerRef.current) {
return;
}
const tool = getOrCreateAIChatToolbar(chatTool);
configureAIChatToolbar(tool, {
session: currentSession,
workspaceId,
status,
docDisplayConfig,
notificationService,
onOpenSession: sessionId => {
onOpenSession(sessionId).catch(console.error);
},
onNewSession: () => {
createFreshSession().catch(console.error);
},
onTogglePin: togglePin,
onOpenDoc: (docId: string, sessionId: string) => {
onOpenSessionDoc(docId, sessionId);
},
onSessionDelete: (sessionToDelete: BlockSuitePresets.AIRecentSession) => {
deleteSession(sessionToDelete).catch(console.error);
},
});
// initial props
if (!chatTool) {
// mount
chatToolContainerRef.current.append(tool);
setChatTool(tool);
}
}, [
chatTool,
currentSession,
docDisplayConfig,
isHeaderProvided,
onOpenSession,
togglePin,
workspaceId,
onOpenSessionDoc,
deleteSession,
status,
notificationService,
createFreshSession,
]);
useAIChatElement({
containerRef: chatToolContainerRef,
selector: 'ai-chat-toolbar',
enabled: isHeaderProvided,
createElement: () => new AIChatToolbar(),
configureElement: tool => {
configureAIChatToolbar(tool, {
session,
runtime,
runtimeSnapshot: snapshot ?? runtime.getSnapshot(),
docDisplayConfig,
notificationService,
onOpenDoc: (docId: string, sessionId: string) => {
onOpenSessionDoc(docId, sessionId);
},
onSessionDelete: (
sessionToDelete: BlockSuitePresets.AIRecentSession
) => {
deleteSession(sessionToDelete).catch(console.error);
},
});
},
});
useEffect(() => {
const refNodeSlots = mockStd?.getOptional(RefNodeSlotsProvider);
@@ -437,87 +246,16 @@ export const Component = () => {
return () => sub.unsubscribe();
}, [framework, mockStd]);
useEffect(() => {
if (!currentSession?.sessionId) return;
setOpenTabs(prev => {
const existing = prev.findIndex(
tab => tab.sessionId === currentSession.sessionId
);
if (existing !== -1) {
if (prev[existing] === currentSession) return prev;
const next = prev.slice();
next[existing] = currentSession;
return next;
}
return [...prev, currentSession];
});
}, [currentSession, setOpenTabs]);
useEffect(() => {
if (!chatTabsContainerRef.current) return;
let tabs = chatTabs;
if (!tabs) {
tabs = new AIChatTabs();
chatTabsContainerRef.current.append(tabs);
setChatTabs(tabs);
}
tabs.sessions = openTabs;
tabs.activeSessionId = currentSession?.sessionId;
tabs.onSelectTab = (sessionId: string) => {
onOpenSession(sessionId).catch(console.error);
};
tabs.onCloseTab = (sessionId: string) => {
closeTab(sessionId);
};
}, [chatTabs, closeTab, currentSession?.sessionId, onOpenSession, openTabs]);
// restore pinned session
useEffect(() => {
if (hasRestoredPinnedSessionRef.current || currentSession) return;
hasRestoredPinnedSessionRef.current = true;
const controller = new AbortController();
const loadPinnedSession = async () => {
try {
const sessions = await client.getSessions(
workspaceId,
{},
undefined,
{ pinned: true, limit: 1 },
controller.signal
);
if (controller.signal.aborted || !Array.isArray(sessions)) {
return;
}
const pinnedSession = sessions[0];
if (!pinnedSession) {
return;
}
let shouldRemount = false;
setCurrentSession(prev => {
if (prev) return prev;
shouldRemount = true;
return pinnedSession;
});
if (shouldRemount) reMountChatContent();
} catch (error) {
if (controller.signal.aborted) {
return;
}
console.error(error);
}
};
loadPinnedSession().catch(error => {
if (controller.signal.aborted) return;
console.error(error);
});
// abort the request
return () => {
controller.abort();
};
}, [client, currentSession, reMountChatContent, workspaceId]);
useAIChatElement({
containerRef: chatTabsContainerRef,
selector: 'ai-chat-tabs',
enabled: true,
createElement: () => new AIChatTabs(),
configureElement: tabs => {
tabs.runtime = runtime;
tabs.runtimeSnapshot = snapshot;
},
});
const onChatContainerRef = useCallback((node: HTMLDivElement) => {
if (node) {
@@ -1,7 +1,6 @@
import { Scrollable } from '@affine/component';
import { PageDetailLoading } from '@affine/component/page-detail-skeleton';
import type { AIChatParams } from '@affine/core/blocksuite/ai';
import { AIProvider } from '@affine/core/blocksuite/ai';
import { AIAppEvents, type AIChatParams } from '@affine/core/blocksuite/ai';
import type { AffineEditorContainer } from '@affine/core/blocksuite/block-suite-editor';
import { EditorOutlineViewer } from '@affine/core/blocksuite/outline-viewer';
import { AffineErrorBoundary } from '@affine/core/components/affine/affine-error-boundary';
@@ -145,12 +144,8 @@ const DetailPageImpl = memo(function DetailPageImpl() {
workbench.openSidebar();
view.activeSidebarTab('chat');
};
disposables.push(
AIProvider.slots.requestOpenWithChat.subscribe(openHandler)
);
disposables.push(
AIProvider.slots.requestSendWithChat.subscribe(openHandler)
);
disposables.push(AIAppEvents.requestOpenWithChat.subscribe(openHandler));
disposables.push(AIAppEvents.requestSendWithChat.subscribe(openHandler));
return () => disposables.forEach(d => d.unsubscribe());
}, [activeSidebarTab, view, workbench]);
@@ -378,7 +373,7 @@ const DetailPageImpl = memo(function DetailPageImpl() {
icon={<AiIcon />}
unmountOnInactive={false}
>
<EditorChatPanel editor={editorContainer} />
<EditorChatPanel editor={editorContainer} doc={doc.blockSuiteDoc} />
</ViewSidebarTab>
)}
@@ -1,341 +0,0 @@
/* eslint-disable rxjs/finnish */
import { type CopilotChatHistoryFragment } from '@affine/graphql';
import { describe, expect, test, vi } from 'vitest';
import {
canCreateNewDocPanelSession,
filterDocPanelTabs,
getChatContentKey,
hasSessionMessages,
isSessionAvailableInDocPanel,
resolveInitialSession,
type SessionService,
shouldResetChatPanelOnUserInfoChange,
type WorkbenchLike,
} from './chat-panel-session';
const createWorkbench = (search: string) => {
const updateQueryString = vi.fn();
const workbench = {
location$: { value: { search } },
activeView$: { value: { updateQueryString } },
} satisfies WorkbenchLike;
return { workbench, updateQueryString };
};
const doc = { id: 'doc-1', workspace: { id: 'ws-1' } };
describe('getChatContentKey', () => {
const cases = [
{
name: 'uses doc id before a session is created',
input: {
docId: 'doc-1',
hasPinned: false,
session: null,
},
expected: 'doc-1',
},
{
name: 'keeps a new empty doc session on the doc key',
input: {
docId: 'doc-2',
hasPinned: false,
previousSessionDocId: 'doc-1',
previousSessionId: 'session-1',
session: {
sessionId: 'session-2',
docId: 'doc-2',
messages: [],
},
},
expected: 'doc-2',
},
{
name: 'uses session id for a session with history',
input: {
docId: 'doc-1',
hasPinned: false,
session: {
sessionId: 'session-1',
docId: 'doc-1',
messages: [{ id: 'message-1' }],
},
},
expected: 'session-1',
},
{
name: 'uses session id for a pinned session',
input: {
docId: 'doc-1',
hasPinned: true,
session: {
sessionId: 'session-1',
docId: 'doc-1',
messages: [],
},
},
expected: 'session-1',
},
{
name: 'uses session id for same-doc session switch',
input: {
docId: 'doc-1',
hasPinned: false,
previousSessionDocId: 'doc-1',
previousSessionId: 'session-1',
session: {
sessionId: 'session-2',
docId: 'doc-1',
messages: [],
},
},
expected: 'session-2',
},
{
name: 'keeps generating draft session on the doc key',
input: {
docId: 'doc-1',
hasPinned: false,
isGenerating: true,
previousSessionDocId: 'doc-1',
previousSessionId: 'session-1',
session: {
sessionId: 'session-2',
docId: 'doc-1',
messages: [],
},
},
expected: 'doc-1',
},
] satisfies {
name: string;
input: Parameters<typeof getChatContentKey>[0];
expected: string;
}[];
test.each(cases)('$name', ({ input, expected }) => {
expect(getChatContentKey(input)).toBe(expected);
});
});
describe('shouldResetChatPanelOnUserInfoChange', () => {
const cases = [
{
name: 'ignores the initial user info emission',
input: {
previousUserId: undefined,
nextUserId: 'user-1',
},
expected: false,
},
{
name: 'ignores same-user refreshes',
input: {
previousUserId: 'user-1',
nextUserId: 'user-1',
},
expected: false,
},
{
name: 'resets when the effective user changes',
input: {
previousUserId: 'user-1',
nextUserId: 'user-2',
},
expected: true,
},
{
name: 'resets when the effective user signs out',
input: {
previousUserId: 'user-1',
nextUserId: null,
},
expected: true,
},
] satisfies {
name: string;
input: Parameters<typeof shouldResetChatPanelOnUserInfoChange>[0];
expected: boolean;
}[];
test.each(cases)('$name', ({ input, expected }) => {
expect(shouldResetChatPanelOnUserInfoChange(input)).toBe(expected);
});
});
describe('doc panel tabs', () => {
const sessions = [
{ sessionId: 'current-doc-session', docId: 'doc-1' },
{ sessionId: 'workspace-session', docId: null },
{ sessionId: 'other-doc-session', docId: 'doc-2' },
];
test('allows only current doc or workspace sessions', () => {
expect(filterDocPanelTabs(sessions, 'doc-1')).toEqual([
sessions[0],
sessions[1],
]);
});
test('rejects other document sessions', () => {
expect(isSessionAvailableInDocPanel(sessions[2], 'doc-1')).toBe(false);
});
});
describe('new session guard', () => {
test('allows a new session only after the current chat has messages', () => {
expect(
canCreateNewDocPanelSession({
hasContextMessages: false,
session: { messages: [] },
status: 'idle',
})
).toBe(false);
expect(
canCreateNewDocPanelSession({
hasContextMessages: true,
session: { messages: [] },
status: 'idle',
})
).toBe(true);
expect(hasSessionMessages({ messages: [{ id: 'message-1' }] })).toBe(true);
});
test('does not allow a new session while generating', () => {
expect(
canCreateNewDocPanelSession({
hasContextMessages: true,
session: null,
status: 'loading',
})
).toBe(false);
});
});
test('returns undefined without session service or doc', async () => {
await expect(
resolveInitialSession({ sessionService: null, doc, workbench: null })
).resolves.toBeUndefined();
await expect(
resolveInitialSession({
sessionService: {
getSessions: vi.fn(),
getSession: vi.fn(),
},
doc: null,
workbench: null,
})
).resolves.toBeUndefined();
});
describe('resolveInitialSession', () => {
test('prefers pinned session and clears sessionId from url', async () => {
const pinnedSession = {
sessionId: 'pinned-session',
pinned: true,
} as CopilotChatHistoryFragment;
const sessionService: SessionService = {
getSessions: vi.fn().mockResolvedValueOnce([pinnedSession]),
getSession: vi.fn(),
};
const { workbench, updateQueryString } = createWorkbench(
'?sessionId=from-url'
);
const result = await resolveInitialSession({
sessionService,
doc,
workbench,
});
expect(result).toBe(pinnedSession);
expect(updateQueryString).toHaveBeenCalledWith(
{ sessionId: undefined },
{ replace: true }
);
expect(sessionService.getSession).not.toHaveBeenCalled();
});
test('loads session from url when no pinned session', async () => {
const sessionFromUrl = {
sessionId: 'url-session',
pinned: false,
} as CopilotChatHistoryFragment;
const sessionService: SessionService = {
getSessions: vi.fn().mockResolvedValueOnce([]),
getSession: vi.fn().mockResolvedValueOnce(sessionFromUrl),
};
const { workbench, updateQueryString } = createWorkbench(
'?sessionId=url-session'
);
const result = await resolveInitialSession({
sessionService,
doc,
workbench,
});
expect(result).toBe(sessionFromUrl);
expect(sessionService.getSession).toHaveBeenCalledWith(
doc.workspace.id,
'url-session'
);
expect(updateQueryString).toHaveBeenCalledWith(
{ sessionId: undefined },
{ replace: true }
);
});
test('falls back to latest doc session', async () => {
const docSession = {
sessionId: 'doc-session',
pinned: false,
} as CopilotChatHistoryFragment;
const sessionService: SessionService = {
getSessions: vi
.fn()
.mockResolvedValueOnce([])
.mockResolvedValueOnce([docSession]),
getSession: vi.fn(),
};
const { workbench } = createWorkbench('');
const result = await resolveInitialSession({
sessionService,
doc,
workbench,
});
expect(result).toBe(docSession);
expect(sessionService.getSessions).toHaveBeenCalledWith(
doc.workspace.id,
doc.id,
{ action: false, fork: false, limit: 1 }
);
});
test('returns null when url session is missing', async () => {
const sessionService: SessionService = {
getSessions: vi.fn().mockResolvedValueOnce([]),
getSession: vi.fn().mockResolvedValueOnce(null),
};
const { workbench } = createWorkbench('?sessionId=missing');
const result = await resolveInitialSession({
sessionService,
doc,
workbench,
});
expect(result).toBeNull();
});
});
@@ -1,208 +0,0 @@
/* eslint-disable rxjs/finnish */
import type { CopilotChatHistoryFragment } from '@affine/graphql';
type SessionListOptions = {
pinned?: boolean;
action?: boolean;
fork?: boolean;
limit?: number;
};
export interface SessionService {
getSessions: (
workspaceId: string,
docId?: string,
options?: SessionListOptions
) => Promise<CopilotChatHistoryFragment[] | null | undefined>;
getSession: (
workspaceId: string,
sessionId: string
) => Promise<CopilotChatHistoryFragment | null | undefined>;
}
export interface WorkbenchLike {
location$: {
value: {
search: string;
};
};
activeView$: {
value: {
updateQueryString: (
patch: Record<string, unknown>,
options?: { replace?: boolean }
) => void;
};
};
}
export interface DocLike {
id: string;
workspace: {
id: string;
};
}
interface ChatContentKeySession {
sessionId?: string | null;
docId?: string | null;
messages?: readonly unknown[] | null;
}
type TabSession = {
sessionId: string;
docId?: string | null;
messages?: readonly unknown[] | null;
};
export const shouldResetChatPanelOnUserInfoChange = ({
previousUserId,
nextUserId,
}: {
previousUserId?: string | null;
nextUserId?: string | null;
}) => {
return previousUserId !== undefined && previousUserId !== nextUserId;
};
export const getChatContentKey = ({
docId,
hasPinned,
isGenerating,
previousSessionDocId,
previousSessionId,
session,
}: {
docId?: string | null;
hasPinned: boolean;
isGenerating?: boolean;
previousSessionDocId?: string | null;
previousSessionId?: string | null;
session?: ChatContentKeySession | null;
}) => {
const fallbackKey = docId ?? 'chat-panel';
const sessionId = session?.sessionId;
if (!sessionId) {
return fallbackKey;
}
const sessionDocId = session.docId ?? docId ?? null;
const hasSessionHistory = !!session.messages?.length;
const shouldPreserveTransientMessages = isGenerating && !hasSessionHistory;
const sessionSwitchedWithinDoc = !!(
previousSessionId &&
previousSessionId !== sessionId &&
previousSessionDocId &&
sessionDocId &&
previousSessionDocId === sessionDocId &&
sessionDocId === docId
);
return hasPinned ||
hasSessionHistory ||
(sessionSwitchedWithinDoc && !shouldPreserveTransientMessages)
? sessionId
: fallbackKey;
};
export const isSessionAvailableInDocPanel = (
session: TabSession,
docId?: string | null
) => {
return !session.docId || session.docId === docId;
};
export const filterDocPanelTabs = <T extends TabSession>(
sessions: T[],
docId?: string | null
) => {
return sessions.filter(session =>
isSessionAvailableInDocPanel(session, docId)
);
};
export const hasSessionMessages = (
session?: Pick<TabSession, 'messages'> | null
) => {
return !!session?.messages?.length;
};
export const canCreateNewDocPanelSession = ({
hasContextMessages,
session,
status,
}: {
hasContextMessages: boolean;
session?: Pick<TabSession, 'messages'> | null;
status?: string | null;
}) => {
return (
(hasContextMessages || hasSessionMessages(session)) &&
status !== 'loading' &&
status !== 'transmitting'
);
};
export const getSessionIdFromUrl = (workbench?: WorkbenchLike | null) => {
if (!workbench) {
return undefined;
}
const searchParams = new URLSearchParams(workbench.location$.value.search);
const sessionId = searchParams.get('sessionId');
if (sessionId) {
workbench.activeView$.value.updateQueryString(
{ sessionId: undefined },
{ replace: true }
);
}
return sessionId ?? undefined;
};
export const resolveInitialSession = async ({
sessionService,
doc,
workbench,
}: {
sessionService?: SessionService | null;
doc?: DocLike | null;
workbench?: WorkbenchLike | null;
}): Promise<CopilotChatHistoryFragment | null | undefined> => {
if (!sessionService || !doc) {
return undefined;
}
const sessionId = getSessionIdFromUrl(workbench);
const pinSessions = await sessionService.getSessions(
doc.workspace.id,
undefined,
{
pinned: true,
limit: 1,
}
);
if (Array.isArray(pinSessions) && pinSessions[0]) {
return pinSessions[0];
}
if (sessionId) {
const session = await sessionService.getSession(
doc.workspace.id,
sessionId
);
return session ?? null;
}
const docSessions = await sessionService.getSessions(
doc.workspace.id,
doc.id,
{
action: false,
fork: false,
limit: 1,
}
);
return docSessions?.[0] ?? null;
};
@@ -1,16 +1,18 @@
import { useConfirmModal } from '@affine/component';
import { AIProvider } from '@affine/core/blocksuite/ai';
import type { AppSidebarConfig } from '@affine/core/blocksuite/ai/chat-panel/chat-config';
import {
AIChatContent,
type ChatContextValue,
} from '@affine/core/blocksuite/ai/components/ai-chat-content';
import type { ChatStatus } from '@affine/core/blocksuite/ai/components/ai-chat-messages';
import type { AIChatToolbar } from '@affine/core/blocksuite/ai/components/ai-chat-toolbar';
AIAppEvents,
AIChatRuntime,
createAIRequestService,
DocAIChatSessionStrategy,
useAIChatElement,
useAIChatRuntime,
} from '@affine/core/blocksuite/ai';
import type { AppSidebarConfig } from '@affine/core/blocksuite/ai/chat-panel/chat-config';
import { AIChatContent } from '@affine/core/blocksuite/ai/components/ai-chat-content';
import {
AIChatTabs,
AIChatToolbar,
configureAIChatToolbar,
getOrCreateAIChatToolbar,
} from '@affine/core/blocksuite/ai/components/ai-chat-toolbar';
import { createPlaygroundModal } from '@affine/core/blocksuite/ai/components/playground/modal';
import { registerAIAppEffects } from '@affine/core/blocksuite/ai/effects/app';
@@ -24,52 +26,55 @@ import {
AIToolsConfigService,
} from '@affine/core/modules/ai-button';
import { AIModelService } from '@affine/core/modules/ai-button/services/models';
import { ServerService, SubscriptionService } from '@affine/core/modules/cloud';
import {
EventSourceService,
GraphQLService,
ServerService,
SubscriptionService,
} from '@affine/core/modules/cloud';
import { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import { useSignalValue } from '@affine/core/modules/doc-info/utils';
import { FeatureFlagService } from '@affine/core/modules/feature-flag';
import { PeekViewService } from '@affine/core/modules/peek-view';
import { AppThemeService } from '@affine/core/modules/theme';
import { WorkbenchService } from '@affine/core/modules/workbench';
import type {
ContextEmbedStatus,
CopilotChatHistoryFragment,
UpdateChatSessionInput,
} from '@affine/graphql';
import { useI18n } from '@affine/i18n';
import { RefNodeSlotsProvider } from '@blocksuite/affine/inlines/reference';
import { DocModeProvider } from '@blocksuite/affine/shared/services';
import { createSignalFromObservable } from '@blocksuite/affine/shared/utils';
import type { Store } from '@blocksuite/affine/store';
import { CenterPeekIcon, Logo1Icon } from '@blocksuite/icons/rc';
import type { Signal } from '@preact/signals-core';
import { useFramework, useService } from '@toeverything/infra';
import { html } from 'lit';
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import {
createSessionDeleteHandler,
useAIChatOpenTabs,
} from '../../chat-panel-utils';
import * as styles from './chat.css';
import {
canCreateNewDocPanelSession,
filterDocPanelTabs,
getChatContentKey,
isSessionAvailableInDocPanel,
resolveInitialSession,
shouldResetChatPanelOnUserInfoChange,
type WorkbenchLike,
} from './chat-panel-session';
registerAIAppEffects();
const shouldResetChatPanelOnUserInfoChange = ({
previousUserId,
nextUserId,
}: {
previousUserId?: string | null;
nextUserId?: string | null;
}) => previousUserId !== undefined && previousUserId !== nextUserId;
export interface SidebarTabProps {
editor: AffineEditorContainer | null;
doc: Store;
onLoad?: ((component: HTMLElement) => void) | null;
}
export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
export const EditorChatPanel = ({
editor,
doc: fallbackDoc,
onLoad,
}: SidebarTabProps) => {
const framework = useFramework();
const graphqlService = useService(GraphQLService);
const eventSourceService = useService(EventSourceService);
const workbench = useService(WorkbenchService).workbench;
const t = useI18n();
@@ -89,81 +94,57 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
} = useAIChatConfig();
const playgroundVisible = useSignalValue(playgroundConfig.visible) ?? false;
const [session, setSession] = useState<
CopilotChatHistoryFragment | null | undefined
>(undefined);
const [embeddingProgress, setEmbeddingProgress] = useState<[number, number]>([
0, 0,
]);
const [status, setStatus] = useState<ChatStatus>('idle');
const [hasPinned, setHasPinned] = useState(false);
const [chatContent, setChatContent] = useState<AIChatContent | null>(null);
const [chatToolbar, setChatToolbar] = useState<AIChatToolbar | null>(null);
const [chatTabs, setChatTabs] = useState<AIChatTabs | null>(null);
const [isBodyProvided, setIsBodyProvided] = useState(false);
const [isHeaderProvided, setIsHeaderProvided] = useState(false);
const chatContainerRef = useRef<HTMLDivElement | null>(null);
const chatToolbarContainerRef = useRef<HTMLDivElement | null>(null);
const chatTabsContainerRef = useRef<HTMLDivElement | null>(null);
const contentKeyRef = useRef<string | null>(null);
const prevSessionIdRef = useRef<string | null>(null);
const prevSessionDocIdRef = useRef<string | null>(null);
const lastDocIdRef = useRef<string | null>(null);
const sessionLoadSeqRef = useRef(0);
const creatingSessionRef = useRef<{
docId: string;
promise: Promise<CopilotChatHistoryFragment | undefined>;
} | null>(null);
const creatingFreshSessionRef = useRef<{
docId: string;
promise: Promise<void>;
} | null>(null);
const userIdRef = useRef<string | null | undefined>(undefined);
const doc = editor?.doc;
const doc = editor?.doc ?? fallbackDoc;
const host = editor?.host;
const workspaceId = doc?.workspace.id;
const [sessionServiceReady, setSessionServiceReady] = useState(
() => !!AIProvider.session
const requestService = useMemo(
() =>
createAIRequestService(
graphqlService.gql,
eventSourceService.eventSource
),
[eventSourceService.eventSource, graphqlService.gql]
);
const [hasContextMessages, setHasContextMessages] = useState(false);
useEffect(() => {
if (sessionServiceReady) return;
if (AIProvider.session) {
setSessionServiceReady(true);
return;
}
const sub = AIProvider.slots.sessionReady.subscribe(ready => {
if (ready) setSessionServiceReady(true);
});
return () => sub.unsubscribe();
}, [sessionServiceReady]);
const loadSession = useMemo(() => {
if (!sessionServiceReady || !workspaceId) return null;
const sessionService = AIProvider.session;
if (!sessionService) return null;
return async (
sessionId: string
): Promise<CopilotChatHistoryFragment | null | undefined> =>
sessionService.getSession(workspaceId, sessionId);
}, [sessionServiceReady, workspaceId]);
const { openTabs, setOpenTabs } =
useAIChatOpenTabs<CopilotChatHistoryFragment>(loadSession);
const visibleOpenTabs = useMemo(
() => filterDocPanelTabs(openTabs, doc?.id),
[doc?.id, openTabs]
);
const canCreateNewSession = canCreateNewDocPanelSession({
hasContextMessages,
session,
status,
const [pendingSessionId] = useState(() => {
const searchParams = new URLSearchParams(workbench.location$.value.search);
return searchParams.get('sessionId') ?? undefined;
});
useEffect(() => {
if (pendingSessionId) {
workbench.activeView$.value.updateQueryString(
{ sessionId: undefined },
{ replace: true }
);
}
}, [pendingSessionId, workbench]);
const runtime = useMemo(() => {
if (!doc || !workspaceId) return null;
return new AIChatRuntime({
request: requestService,
scope: {
kind: 'doc',
workspaceId,
docId: doc.id,
pendingSessionId,
},
strategy: new DocAIChatSessionStrategy(),
});
}, [doc, pendingSessionId, requestService, workspaceId]);
const snapshot = useAIChatRuntime(runtime);
const session =
snapshot?.sessions.find(
item => item.sessionId === snapshot.activeSessionId
) ?? null;
const appSidebarConfig = useMemo<AppSidebarConfig>(() => {
return {
getWidth: () =>
@@ -188,159 +169,14 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
return cleanup;
}, [appSidebarConfig]);
const resetPanel = useCallback(() => {
sessionLoadSeqRef.current += 1;
setSession(undefined);
setEmbeddingProgress([0, 0]);
setHasPinned(false);
}, []);
const initPanel = useCallback(async () => {
const requestSeq = ++sessionLoadSeqRef.current;
try {
const nextSession = await resolveInitialSession({
sessionService: AIProvider.session ?? undefined,
doc,
workbench: workbench as WorkbenchLike,
});
if (requestSeq !== sessionLoadSeqRef.current) return;
if (nextSession === undefined) {
return;
}
setSession(nextSession);
setHasPinned(!!nextSession?.pinned);
} catch (error) {
console.error(error);
}
}, [doc, workbench]);
const createSession = useCallback(
async (options: Partial<BlockSuitePresets.AICreateSessionOptions> = {}) => {
if (session || !AIProvider.session || !doc) {
return session ?? undefined;
}
if (creatingSessionRef.current?.docId === doc.id) {
return creatingSessionRef.current.promise;
}
const requestSeq = ++sessionLoadSeqRef.current;
let promise: Promise<CopilotChatHistoryFragment | undefined>;
promise = AIProvider.session
.createSessionWithHistory({
docId: doc.id,
workspaceId: doc.workspace.id,
promptName: 'Chat With AFFiNE AI',
reuseLatestChat: false,
...options,
})
.then(nextSession => {
if (requestSeq !== sessionLoadSeqRef.current) return undefined;
setSession(nextSession ?? null);
setHasPinned(!!nextSession?.pinned);
return nextSession ?? undefined;
})
.finally(() => {
if (creatingSessionRef.current?.promise === promise) {
creatingSessionRef.current = null;
}
});
creatingSessionRef.current = { docId: doc.id, promise };
return promise;
},
[doc, session]
);
const updateSession = useCallback(
async (options: UpdateChatSessionInput) => {
if (!AIProvider.session || !doc) {
return undefined;
}
const requestSeq = ++sessionLoadSeqRef.current;
await AIProvider.session.updateSession(options);
const nextSession = await AIProvider.session.getSession(
doc.workspace.id,
options.sessionId
);
if (requestSeq !== sessionLoadSeqRef.current) return undefined;
setSession(nextSession ?? null);
setHasPinned(!!nextSession?.pinned);
return nextSession ?? undefined;
},
[doc]
);
const newSession = useCallback(async () => {
if (!canCreateNewSession) {
return;
}
if (doc && creatingFreshSessionRef.current?.docId === doc.id) {
return creatingFreshSessionRef.current.promise;
}
resetPanel();
const requestSeq = sessionLoadSeqRef.current;
setSession(null);
setHasContextMessages(false);
if (!AIProvider.session || !doc) {
return;
}
let promise: Promise<void>;
promise = AIProvider.session
.createSessionWithHistory({
docId: doc.id,
workspaceId: doc.workspace.id,
promptName: 'Chat With AFFiNE AI',
reuseLatestChat: false,
})
.then(nextSession => {
if (requestSeq === sessionLoadSeqRef.current) {
setSession(nextSession ?? null);
setHasPinned(!!nextSession?.pinned);
}
})
.catch(console.error)
.finally(() => {
if (creatingFreshSessionRef.current?.promise === promise) {
creatingFreshSessionRef.current = null;
}
});
creatingFreshSessionRef.current = { docId: doc.id, promise };
return promise;
}, [canCreateNewSession, doc, resetPanel]);
const openSession = useCallback(
async (sessionId: string) => {
if (session?.sessionId === sessionId || !AIProvider.session || !doc) {
if (session?.sessionId === sessionId || !runtime) {
return;
}
const requestSeq = ++sessionLoadSeqRef.current;
try {
const nextSession = await AIProvider.session.getSession(
doc.workspace.id,
sessionId
);
if (requestSeq !== sessionLoadSeqRef.current) return;
if (!nextSession) {
// Drop stale tab if session no longer exists.
setOpenTabs(prev => prev.filter(tab => tab.sessionId !== sessionId));
return;
}
if (!isSessionAvailableInDocPanel(nextSession, doc.id)) {
setOpenTabs([]);
workbench.open(`/${nextSession.docId}?sessionId=${sessionId}`, {
at: 'active',
});
return;
}
setSession(nextSession);
setHasPinned(!!nextSession.pinned);
} catch (error) {
console.error(error);
}
await runtime.dispatch({ type: 'openSession', sessionId });
},
[doc, session?.sessionId, setOpenTabs, workbench]
[runtime, session?.sessionId]
);
const openDoc = useCallback(
@@ -361,159 +197,47 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
workbench.open(`/${docId}`, { at: 'active' });
return;
}
setOpenTabs([]);
workbench.open(`/${docId}?sessionId=${sessionId}`, { at: 'active' });
},
[
doc,
openSession,
session?.pinned,
session?.sessionId,
setOpenTabs,
workbench,
]
);
const deleteSession = useMemo(
() =>
createSessionDeleteHandler({
t,
notificationService,
canDeleteSession: () => Boolean(AIProvider.histories),
cleanupSession: async sessionToDelete => {
await AIProvider.histories?.cleanup(
sessionToDelete.workspaceId,
sessionToDelete.docId || undefined,
[sessionToDelete.sessionId]
);
},
isActiveSession: sessionToDelete =>
sessionToDelete.sessionId === session?.sessionId,
onActiveSessionDeleted: () => {
resetPanel();
setSession(null);
setHasContextMessages(false);
},
}),
[notificationService, resetPanel, session?.sessionId, t]
);
const closeTab = useCallback(
(sessionId: string) => {
let fallback: CopilotChatHistoryFragment | undefined;
setOpenTabs(prev => {
const idx = prev.findIndex(tab => tab.sessionId === sessionId);
if (idx === -1) return prev;
const next = prev.filter(tab => tab.sessionId !== sessionId);
const visibleNext = filterDocPanelTabs(next, doc?.id);
fallback = visibleNext[idx] ?? visibleNext[idx - 1] ?? visibleNext[0];
return next;
});
if (session?.sessionId !== sessionId) return;
if (fallback) {
openSession(fallback.sessionId).catch(console.error);
} else {
resetPanel();
setSession(null);
setHasContextMessages(false);
}
},
[doc?.id, openSession, resetPanel, session?.sessionId, setOpenTabs]
);
const togglePin = useCallback(async () => {
const pinned = !session?.pinned;
setHasPinned(true);
if (!session) {
await createSession({ pinned });
return;
}
setSession(prev => (prev ? { ...prev, pinned } : prev));
await updateSession({
sessionId: session.sessionId,
pinned,
});
}, [createSession, session, updateSession]);
const rebindSession = useCallback(async () => {
if (!session || !doc) {
return;
}
if (session.docId !== doc.id) {
await updateSession({
sessionId: session.sessionId,
docId: doc.id,
});
}
}, [doc, session, updateSession]);
const onEmbeddingProgressChange = useCallback(
(count: Record<ContextEmbedStatus, number>) => {
const total = count.finished + count.processing + count.failed;
setEmbeddingProgress([count.finished, total]);
},
[]
);
const onContextChange = useCallback(
(context: Partial<ChatContextValue>) => {
if (context.status) {
setStatus(context.status);
}
if (context.messages) {
setHasContextMessages(context.messages.length > 0);
}
if (context.status === 'success') {
rebindSession().catch(console.error);
}
},
[rebindSession]
[doc, openSession, session?.pinned, session?.sessionId, workbench]
);
useEffect(() => {
if (session !== undefined) {
const navigationRequest = snapshot?.navigationRequest;
if (!navigationRequest) {
return;
}
if (chatContent) {
chatContent.remove();
setChatContent(null);
}
if (chatToolbar) {
chatToolbar.remove();
setChatToolbar(null);
}
if (chatTabs) {
chatTabs.remove();
setChatTabs(null);
}
}, [chatContent, chatTabs, chatToolbar, session]);
workbench.open(
`/${navigationRequest.docId}?sessionId=${navigationRequest.sessionId}`,
{ at: 'active' }
);
}, [snapshot?.navigationRequest, workbench]);
useEffect(() => {
if (!session?.sessionId) return;
setOpenTabs(prev => {
const existing = prev.findIndex(
tab => tab.sessionId === session.sessionId
const deleteSession = useCallback(
async (sessionToDelete: BlockSuitePresets.AIRecentSession) => {
if (!runtime) return;
const confirm = await notificationService.confirm({
title: t['com.affine.ai.chat-panel.session.delete.confirm.title'](),
message: t['com.affine.ai.chat-panel.session.delete.confirm.message'](),
confirmText: t['Delete'](),
cancelText: t['Cancel'](),
});
if (!confirm) return;
await runtime.dispatch({
type: 'deleteSession',
sessionId: sessionToDelete.sessionId,
});
notificationService.toast(
t['com.affine.ai.chat-panel.session.delete.toast.success'](),
{}
);
if (existing !== -1) {
if (prev[existing] === session) return prev;
const next = prev.slice();
next[existing] = session;
return next;
}
return [...prev, session];
});
}, [session, setOpenTabs]);
},
[notificationService, runtime, t]
);
useEffect(() => {
let disposed = false;
Promise.resolve(AIProvider.userInfo)
.then(userInfo => {
if (!disposed && userIdRef.current === undefined) {
userIdRef.current = userInfo?.id ?? null;
}
})
.catch(console.error);
const subscription = AIProvider.slots.userInfo.subscribe(userInfo => {
userIdRef.current ??= AIAppEvents.userInfo.value?.id ?? null;
const subscription = AIAppEvents.userInfo.subscribe(userInfo => {
const nextUserId = userInfo?.id ?? null;
const shouldReset = shouldResetChatPanelOnUserInfoChange({
previousUserId: userIdRef.current,
@@ -523,220 +247,90 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
if (!shouldReset) {
return;
}
resetPanel();
initPanel().catch(console.error);
runtime?.dispatch({ type: 'initialize' }).catch(console.error);
});
return () => {
disposed = true;
subscription.unsubscribe();
};
}, [initPanel, resetPanel]);
}, [runtime]);
useEffect(() => {
const docId = doc?.id;
if (!docId) {
return;
}
if (
lastDocIdRef.current &&
lastDocIdRef.current !== docId &&
!session?.pinned
) {
resetPanel();
setHasContextMessages(false);
}
lastDocIdRef.current = docId;
}, [doc?.id, resetPanel, session?.pinned]);
useEffect(() => {
if (!doc || session !== undefined) {
return;
}
if (AIProvider.session) {
initPanel().catch(console.error);
return;
}
const subscription = AIProvider.slots.sessionReady.subscribe(ready => {
if (!ready || session !== undefined) return;
initPanel().catch(console.error);
});
return () => subscription.unsubscribe();
}, [doc, initPanel, session]);
const contentKey = getChatContentKey({
docId: doc?.id,
hasPinned,
isGenerating: status === 'loading' || status === 'transmitting',
previousSessionDocId: prevSessionDocIdRef.current,
previousSessionId: prevSessionIdRef.current,
session,
const chatContent = useAIChatElement({
containerRef: chatContainerRef,
selector: 'ai-chat-content',
enabled: isBodyProvided && !!runtime && !!snapshot,
createElement: () => new AIChatContent(),
configureElement: content => {
if (!runtime || !snapshot) return;
content.host = host;
content.session = session;
content.runtime = runtime;
content.runtimeSnapshot = snapshot;
content.workspaceId = doc.workspace.id;
content.docId = doc.id;
content.reasoningConfig = reasoningConfig;
content.searchMenuConfig = searchMenuConfig;
content.docDisplayConfig = docDisplayConfig;
content.extensions = specs;
content.serverService = framework.get(ServerService);
content.affineFeatureFlagService = framework.get(FeatureFlagService);
content.affineWorkspaceDialogService = framework.get(
WorkspaceDialogService
);
content.affineThemeService = framework.get(AppThemeService);
content.notificationService = notificationService;
content.aiDraftService = framework.get(AIDraftService);
content.aiToolsConfigService = framework.get(AIToolsConfigService);
content.peekViewService = framework.get(PeekViewService);
content.subscriptionService = framework.get(SubscriptionService);
content.aiModelService = framework.get(AIModelService);
content.onAISubscribe = handleAISubscribe;
content.width = sidebarWidthSignal;
content.onOpenDoc = (docId: string, sessionId?: string) => {
openDoc(docId, sessionId).catch(console.error);
};
},
onElementReady: content => {
onLoad?.(content);
},
});
useEffect(() => {
if (session?.sessionId) {
prevSessionIdRef.current = session.sessionId;
prevSessionDocIdRef.current = session.docId ?? doc?.id ?? null;
}
}, [doc?.id, session?.docId, session?.sessionId]);
useAIChatElement({
containerRef: chatToolbarContainerRef,
selector: 'ai-chat-toolbar',
enabled: isHeaderProvided && !!runtime && !!snapshot,
createElement: () => new AIChatToolbar(),
configureElement: tool => {
if (!runtime || !snapshot) return;
configureAIChatToolbar(tool, {
session,
runtime,
runtimeSnapshot: snapshot,
docId: doc.id,
docDisplayConfig,
notificationService,
onOpenDoc: (docId: string, sessionId: string) => {
openDoc(docId, sessionId).catch(console.error);
},
onSessionDelete: (
sessionToDelete: BlockSuitePresets.AIRecentSession
) => {
deleteSession(sessionToDelete).catch(console.error);
},
});
},
});
useEffect(() => {
if (!chatContent) {
contentKeyRef.current = contentKey;
return;
}
if (contentKeyRef.current && contentKeyRef.current !== contentKey) {
chatContent.remove();
setChatContent(null);
}
contentKeyRef.current = contentKey;
}, [chatContent, contentKey]);
useEffect(() => {
if (!isBodyProvided || !chatContainerRef.current || !doc || !host) {
return;
}
if (session === undefined) {
return;
}
let content = chatContent;
if (!content) {
content = new AIChatContent();
}
content.host = host;
content.session = session;
content.createSession = createSession;
content.workspaceId = doc.workspace.id;
content.docId = doc.id;
content.reasoningConfig = reasoningConfig;
content.searchMenuConfig = searchMenuConfig;
content.docDisplayConfig = docDisplayConfig;
content.extensions = specs;
content.serverService = framework.get(ServerService);
content.affineFeatureFlagService = framework.get(FeatureFlagService);
content.affineWorkspaceDialogService = framework.get(
WorkspaceDialogService
);
content.affineThemeService = framework.get(AppThemeService);
content.notificationService = notificationService;
content.aiDraftService = framework.get(AIDraftService);
content.aiToolsConfigService = framework.get(AIToolsConfigService);
content.peekViewService = framework.get(PeekViewService);
content.subscriptionService = framework.get(SubscriptionService);
content.aiModelService = framework.get(AIModelService);
content.onAISubscribe = handleAISubscribe;
content.onEmbeddingProgressChange = onEmbeddingProgressChange;
content.onContextChange = onContextChange;
content.width = sidebarWidthSignal;
content.onOpenDoc = (docId: string, sessionId?: string) => {
openDoc(docId, sessionId).catch(console.error);
};
if (!chatContent) {
chatContainerRef.current.append(content);
setChatContent(content);
onLoad?.(content);
}
}, [
chatContent,
createSession,
doc,
docDisplayConfig,
framework,
handleAISubscribe,
host,
isBodyProvided,
notificationService,
onContextChange,
onEmbeddingProgressChange,
onLoad,
openDoc,
reasoningConfig,
searchMenuConfig,
session,
sidebarWidthSignal,
specs,
]);
useEffect(() => {
if (!isHeaderProvided || !chatToolbarContainerRef.current || !doc) {
return;
}
if (session === undefined) {
return;
}
const tool = getOrCreateAIChatToolbar(chatToolbar);
configureAIChatToolbar(tool, {
session,
workspaceId: doc.workspace.id,
docId: doc.id,
status,
canCreateNewSession,
docDisplayConfig,
notificationService,
onNewSession: () => {
newSession().catch(console.error);
},
onTogglePin: togglePin,
onOpenSession: (sessionId: string) => {
openSession(sessionId).catch(console.error);
},
onOpenDoc: (docId: string, sessionId: string) => {
openDoc(docId, sessionId).catch(console.error);
},
onSessionDelete: (sessionToDelete: BlockSuitePresets.AIRecentSession) => {
deleteSession(sessionToDelete).catch(console.error);
},
});
if (!chatToolbar) {
chatToolbarContainerRef.current.append(tool);
setChatToolbar(tool);
}
}, [
chatToolbar,
canCreateNewSession,
deleteSession,
doc,
docDisplayConfig,
isHeaderProvided,
newSession,
notificationService,
openDoc,
openSession,
session,
status,
togglePin,
]);
useEffect(() => {
if (!chatTabsContainerRef.current || !doc) {
return;
}
if (session === undefined) {
return;
}
let tabs = chatTabs;
if (!tabs) {
tabs = new AIChatTabs();
chatTabsContainerRef.current.append(tabs);
setChatTabs(tabs);
}
tabs.sessions = visibleOpenTabs;
tabs.activeSessionId = session?.sessionId;
tabs.showDraftTab =
visibleOpenTabs.length === 0 && !session?.sessionId && !!doc;
tabs.onSelectTab = (sessionId: string) => {
openSession(sessionId).catch(console.error);
};
tabs.onCloseTab = (sessionId: string) => {
closeTab(sessionId);
};
}, [chatTabs, closeTab, doc, openSession, session, visibleOpenTabs]);
useAIChatElement({
containerRef: chatTabsContainerRef,
selector: 'ai-chat-tabs',
enabled: !!runtime && !!snapshot,
createElement: () => new AIChatTabs(),
configureElement: tabs => {
if (!runtime || !snapshot) return;
tabs.runtime = runtime;
tabs.runtimeSnapshot = snapshot;
},
});
useEffect(() => {
if (!editor?.host || !chatContent) {
@@ -766,19 +360,17 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
if (autoResized) {
return;
}
const subscription = AIProvider.slots.previewPanelOpenChange.subscribe(
open => {
if (!open) {
return;
}
const sidebarWidth = workbench.sidebarWidth$.value;
const minSidebarWidth = 1080;
if (!sidebarWidth || sidebarWidth < minSidebarWidth) {
workbench.setSidebarWidth(minSidebarWidth);
setAutoResized(true);
}
const subscription = AIAppEvents.previewPanelOpenChange.subscribe(open => {
if (!open) {
return;
}
);
const sidebarWidth = workbench.sidebarWidth$.value;
const minSidebarWidth = 1080;
if (!sidebarWidth || sidebarWidth < minSidebarWidth) {
workbench.setSidebarWidth(minSidebarWidth);
setAutoResized(true);
}
});
return () => {
subscription.unsubscribe();
};
@@ -843,14 +435,16 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
chatTabsContainerRef.current = node;
}, []);
const isEmbedding =
embeddingProgress[1] > 0 && embeddingProgress[0] < embeddingProgress[1];
const [done, total] = embeddingProgress;
const isInitialized = session !== undefined;
const embeddingCount = snapshot?.composer.context.embeddingCount;
const done = embeddingCount?.finished ?? 0;
const total =
done + (embeddingCount?.processing ?? 0) + (embeddingCount?.failed ?? 0);
const isEmbedding = total > 0 && done < total;
const hasRuntimeSnapshot = !!snapshot;
return (
<div className={styles.root}>
{!isInitialized ? (
{!hasRuntimeSnapshot ? (
<div className={styles.loadingContainer}>
<div className={styles.loading}>
<Logo1Icon className={styles.loadingIcon} />
@@ -1,6 +1,6 @@
import { Scrollable } from '@affine/component';
import { PageDetailLoading } from '@affine/component/page-detail-skeleton';
import { type AIChatParams, AIProvider } from '@affine/core/blocksuite/ai';
import { AIAppEvents, type AIChatParams } from '@affine/core/blocksuite/ai';
import type { AffineEditorContainer } from '@affine/core/blocksuite/block-suite-editor';
import { EditorOutlineViewer } from '@affine/core/blocksuite/outline-viewer';
import { AffineErrorBoundary } from '@affine/core/components/affine/affine-error-boundary';
@@ -137,12 +137,8 @@ function DocPeekPreviewEditor({
// chat panel open is already handled in <DetailPageImpl />
}
};
disposables.push(
AIProvider.slots.requestOpenWithChat.subscribe(openHandler)
);
disposables.push(
AIProvider.slots.requestSendWithChat.subscribe(openHandler)
);
disposables.push(AIAppEvents.requestOpenWithChat.subscribe(openHandler));
disposables.push(AIAppEvents.requestSendWithChat.subscribe(openHandler));
return () => disposables.forEach(d => d.unsubscribe());
}, [doc, peekView, workbench, workspace.id]);
@@ -226,8 +226,8 @@ export class SettingsPanelUtils {
timeout: number,
status = 'synced'
) {
await cleanupWorkspace(page.url().split('/').slice(-2)[0] || '');
await expect(async () => {
await cleanupWorkspace(page.url().split('/').slice(-2)[0] || '');
await this.openSettingsPanel(page);
const title = page.getByTestId('embedding-progress-title');
// oxlint-disable-next-line prefer-dom-node-dataset