feat(core): add ai draft service (#13252)

Close [AI-244](https://linear.app/affine-design/issue/AI-244)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added AI chat draft persistence, allowing your chat input, quotes,
markdown, and images to be automatically saved and restored across
sessions.
* Drafts are now synchronized across chat components, so you won’t lose
your progress if you navigate away or refresh the page.

* **Improvements**
* Enhanced chat experience with seamless restoration of previously
entered content and attachments.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Wu Yue
2025-07-17 17:42:01 +08:00
committed by GitHub
parent 4018b3aeca
commit 0770b109cb
9 changed files with 244 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import type { AIDraftService } from '@affine/core/modules/ai-button';
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
import type { AppThemeService } from '@affine/core/modules/theme';
@@ -115,6 +116,9 @@ export class ChatPanel extends SignalWatcher(
@property({ attribute: false })
accessor notificationService!: NotificationService;
@property({ attribute: false })
accessor aiDraftService!: AIDraftService;
@state()
accessor session: CopilotChatHistoryFragment | null | undefined;
@@ -408,6 +412,7 @@ export class ChatPanel extends SignalWatcher(
.affineWorkspaceDialogService=${this.affineWorkspaceDialogService}
.affineThemeService=${this.affineThemeService}
.notificationService=${this.notificationService}
.aiDraftService=${this.aiDraftService}
.onEmbeddingProgressChange=${this.onEmbeddingProgressChange}
.onContextChange=${this.onContextChange}
.width=${this.sidebarWidth}

View File

@@ -1,5 +1,6 @@
import './ai-chat-composer-tip';
import type { AIDraftService } from '@affine/core/modules/ai-button';
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type {
ContextEmbedStatus,
@@ -116,6 +117,9 @@ export class AIChatComposer extends SignalWatcher(
@property({ attribute: false })
accessor notificationService!: NotificationService;
@property({ attribute: false })
accessor aiDraftService!: AIDraftService;
@state()
accessor chips: ChatChip[] = [];
@@ -161,6 +165,7 @@ export class AIChatComposer extends SignalWatcher(
.reasoningConfig=${this.reasoningConfig}
.docDisplayConfig=${this.docDisplayConfig}
.searchMenuConfig=${this.searchMenuConfig}
.aiDraftService=${this.aiDraftService}
.portalContainer=${this.portalContainer}
.onChatSuccess=${this.onChatSuccess}
.trackOptions=${this.trackOptions}

View File

@@ -1,3 +1,5 @@
import type { AIDraftService } from '@affine/core/modules/ai-button';
import type { AIDraftState } from '@affine/core/modules/ai-button/services/ai-draft';
import type { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import type { FeatureFlagService } from '@affine/core/modules/feature-flag';
import type { AppThemeService } from '@affine/core/modules/theme';
@@ -15,6 +17,7 @@ import { property, state } from 'lit/decorators.js';
import { classMap } from 'lit/directives/class-map.js';
import { createRef, type Ref, ref } from 'lit/directives/ref.js';
import { styleMap } from 'lit/directives/style-map.js';
import { pick } from 'lodash-es';
import { HISTORY_IMAGE_ACTIONS } from '../../chat-panel/const';
import { type AIChatParams, AIProvider } from '../../provider/ai-provider';
@@ -149,6 +152,9 @@ export class AIChatContent extends SignalWatcher(
@property({ attribute: false })
accessor notificationService!: NotificationService;
@property({ attribute: false })
accessor aiDraftService!: AIDraftService;
@property({ attribute: false })
accessor onEmbeddingProgressChange:
| ((count: Record<ContextEmbedStatus, number>) => void)
@@ -263,6 +269,19 @@ export class AIChatContent extends SignalWatcher(
private readonly updateContext = (context: Partial<ChatContextValue>) => {
this.chatContextValue = { ...this.chatContextValue, ...context };
this.onContextChange?.(context);
this.updateDraft(context).catch(console.error);
};
private readonly updateDraft = async (context: Partial<ChatContextValue>) => {
const draft: Partial<AIDraftState> = pick(context, [
'quote',
'images',
'markdown',
]);
if (!Object.keys(draft).length) {
return;
}
await this.aiDraftService.setDraft(draft);
};
private readonly initChatContent = async () => {
@@ -322,8 +341,19 @@ export class AIChatContent extends SignalWatcher(
override connectedCallback() {
super.connectedCallback();
this.initChatContent().catch(console.error);
this.aiDraftService
.getDraft()
.then(draft => {
this.chatContextValue = {
...this.chatContextValue,
...draft,
};
})
.catch(console.error);
this._disposables.add(
AIProvider.slots.actions.subscribe(({ event }) => {
const { status } = this.chatContextValue;
@@ -403,6 +433,7 @@ export class AIChatContent extends SignalWatcher(
.searchMenuConfig=${this.searchMenuConfig}
.affineWorkspaceDialogService=${this.affineWorkspaceDialogService}
.notificationService=${this.notificationService}
.aiDraftService=${this.aiDraftService}
.trackOptions=${{
where: 'chat-panel',
control: 'chat-send',

View File

@@ -1,10 +1,11 @@
import type { AIDraftService } from '@affine/core/modules/ai-button';
import type { CopilotChatHistoryFragment } from '@affine/graphql';
import { SignalWatcher, WithDisposable } from '@blocksuite/affine/global/lit';
import { unsafeCSSVar, unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
import type { EditorHost } from '@blocksuite/affine/std';
import { ShadowlessElement } from '@blocksuite/affine/std';
import { ArrowUpBigIcon, CloseIcon } from '@blocksuite/icons/lit';
import { css, html, nothing } from 'lit';
import { css, html, nothing, type PropertyValues } from 'lit';
import { property, query, state } from 'lit/decorators.js';
import { repeat } from 'lit/directives/repeat.js';
import { styleMap } from 'lit/directives/style-map.js';
@@ -351,6 +352,9 @@ export class AIChatInput extends SignalWatcher(
@property({ attribute: false })
accessor searchMenuConfig!: SearchMenuConfig;
@property({ attribute: false })
accessor aiDraftService!: AIDraftService;
@property({ attribute: false })
accessor isRootSession: boolean = true;
@@ -379,6 +383,7 @@ export class AIChatInput extends SignalWatcher(
override connectedCallback() {
super.connectedCallback();
this._disposables.add(
AIProvider.slots.requestSendWithChat.subscribe(
(params: AISendParams | null) => {
@@ -399,6 +404,17 @@ export class AIChatInput extends SignalWatcher(
);
}
protected override firstUpdated(changedProperties: PropertyValues): void {
super.firstUpdated(changedProperties);
this.aiDraftService
.getDraft()
.then(draft => {
this.textarea.value = draft.input;
this.isInputEmpty = !this.textarea.value.trim();
})
.catch(console.error);
}
protected override render() {
const { images, status } = this.chatContextValue;
const hasImages = images.length > 0;
@@ -506,9 +522,11 @@ export class AIChatInput extends SignalWatcher(
}
};
private readonly _handleInput = () => {
private readonly _handleInput = async () => {
const { textarea } = this;
this.isInputEmpty = !textarea.value.trim();
const value = textarea.value.trim();
this.isInputEmpty = !value;
textarea.style.height = 'auto';
textarea.style.height = textarea.scrollHeight + 'px';
let imagesHeight = this.imagePreviewGrid?.scrollHeight ?? 0;
@@ -517,6 +535,10 @@ export class AIChatInput extends SignalWatcher(
textarea.style.height = '148px';
textarea.style.overflowY = 'scroll';
}
await this.aiDraftService.setDraft({
input: value,
});
};
private readonly _handleKeyDown = async (evt: KeyboardEvent) => {
@@ -572,6 +594,9 @@ export class AIChatInput extends SignalWatcher(
this.textarea.style.height = 'unset';
await this.send(value);
await this.aiDraftService.setDraft({
input: '',
});
};
private readonly _handleModelChange = (modelId: string) => {

View File

@@ -11,6 +11,7 @@ import { getViewManager } from '@affine/core/blocksuite/manager/view';
import { NotificationServiceImpl } from '@affine/core/blocksuite/view-extensions/editor-view/notification-service';
import { useAIChatConfig } from '@affine/core/components/hooks/affine/use-ai-chat-config';
import { useAISpecs } from '@affine/core/components/hooks/affine/use-ai-specs';
import { AIDraftService } from '@affine/core/modules/ai-button';
import {
EventSourceService,
FetchService,
@@ -221,6 +222,7 @@ export const Component = () => {
confirmModal.closeConfirmModal,
confirmModal.openConfirmModal
);
content.aiDraftService = framework.get(AIDraftService);
content.createSession = createSession;
content.onOpenDoc = onOpenDoc;

View File

@@ -4,6 +4,7 @@ import type { AffineEditorContainer } from '@affine/core/blocksuite/block-suite-
import { NotificationServiceImpl } from '@affine/core/blocksuite/view-extensions/editor-view/notification-service';
import { useAIChatConfig } from '@affine/core/components/hooks/affine/use-ai-chat-config';
import { useAISpecs } from '@affine/core/components/hooks/affine/use-ai-specs';
import { AIDraftService } from '@affine/core/modules/ai-button';
import { WorkspaceDialogService } from '@affine/core/modules/dialogs';
import { FeatureFlagService } from '@affine/core/modules/feature-flag';
import { AppThemeService } from '@affine/core/modules/theme';
@@ -95,6 +96,7 @@ export const EditorChatPanel = forwardRef(function EditorChatPanel(
confirmModal.closeConfirmModal,
confirmModal.openConfirmModal
);
chatPanelRef.current.aiDraftService = framework.get(AIDraftService);
containerRef.current?.append(chatPanelRef.current);
} else {

View File

@@ -1,12 +1,15 @@
export { AIButtonProvider } from './provider/ai-button';
export { AIButtonService } from './services/ai-button';
export { AIDraftService } from './services/ai-draft';
import type { Framework } from '@toeverything/infra';
import { FeatureFlagService } from '../feature-flag';
import { GlobalStateService } from '../storage';
import { CacheStorage, GlobalStateService } from '../storage';
import { WorkspaceScope } from '../workspace';
import { AIButtonProvider } from './provider/ai-button';
import { AIButtonService } from './services/ai-button';
import { AIDraftService } from './services/ai-draft';
import { AINetworkSearchService } from './services/network-search';
import { AIPlaygroundService } from './services/playground';
import { AIReasoningService } from './services/reasoning';
@@ -31,3 +34,9 @@ export function configureAIReasoningModule(framework: Framework) {
export function configureAIPlaygroundModule(framework: Framework) {
framework.service(AIPlaygroundService, [FeatureFlagService]);
}
export function configureAIDraftModule(framework: Framework) {
framework
.scope(WorkspaceScope)
.service(AIDraftService, [GlobalStateService, CacheStorage]);
}

View File

@@ -0,0 +1,159 @@
import { Service } from '@toeverything/infra';
import type { CacheStorage, GlobalStateService } from '../../storage';
const AI_DRAFTS_KEY = 'AIDrafts';
const AI_DRAFT_FILES_PREFIX = 'AIDraftFile:';
export interface CacheFile {
name: string;
size: number;
type: string;
cacheKey: string;
}
export interface AIDraftState {
input: string;
quote: string;
markdown: string;
images: File[];
}
export interface AIDraftGlobal {
input: string;
quote: string;
markdown: string;
images: CacheFile[];
}
const DEFAULT_VALUE = {
input: '',
quote: '',
markdown: '',
images: [],
};
export class AIDraftService extends Service {
private state: AIDraftState | null = null;
constructor(
private readonly globalStateService: GlobalStateService,
private readonly cacheStorage: CacheStorage
) {
super();
}
setDraft = async (data: Partial<AIDraftState>) => {
const state = await this.getState();
const newState = {
...state,
...data,
};
this.state = newState;
await this.saveDraft(newState);
};
getDraft = async () => {
const state = await this.getState();
return state;
};
private readonly saveDraft = async (state: AIDraftState) => {
const draft =
this.globalStateService.globalState.get<AIDraftGlobal>(AI_DRAFTS_KEY) ||
DEFAULT_VALUE;
const addedImages = state.images.filter(image => {
return !draft.images.some(cacheImage => {
return cacheImage.cacheKey === this.getCacheKey(image);
});
});
const removedImages = draft.images.filter(cacheImage => {
return !state.images.some(image => {
return cacheImage.cacheKey === this.getCacheKey(image);
});
});
const cacheKeys = removedImages.map(image => image.cacheKey);
await this.removeFilesFromCache(cacheKeys);
await this.addFilesToCache(addedImages);
this.globalStateService.globalState.set<AIDraftGlobal>(AI_DRAFTS_KEY, {
input: state.input,
quote: state.quote,
markdown: state.markdown,
images: state.images.map(image => {
return {
name: image.name,
size: image.size,
type: image.type,
cacheKey: this.getCacheKey(image),
};
}),
});
};
private readonly initState = async () => {
if (this.state) {
return;
}
const draft =
this.globalStateService.globalState.get<AIDraftGlobal>(AI_DRAFTS_KEY);
if (draft) {
const images = await this.restoreFilesFromData(draft.images);
this.state = {
input: draft.input,
quote: draft.quote,
markdown: draft.markdown,
images,
};
} else {
this.state = DEFAULT_VALUE;
}
};
private readonly getState = async () => {
await this.initState();
return this.state as AIDraftState;
};
private readonly getCacheKey = (file: File) => {
return AI_DRAFT_FILES_PREFIX + file.name + file.size;
};
private readonly addFilesToCache = async (files: File[]) => {
for (const file of files) {
const arrayBuffer = await file.arrayBuffer();
const cacheKey = this.getCacheKey(file);
await this.cacheStorage.set(cacheKey, arrayBuffer);
}
};
private readonly removeFilesFromCache = async (cacheKeys: string[]) => {
for (const cacheKey of cacheKeys) {
await this.cacheStorage.del(cacheKey);
}
};
private readonly restoreFilesFromData = async (
cacheFiles: CacheFile[]
): Promise<File[]> => {
const files: File[] = [];
for (const cacheFile of cacheFiles) {
try {
const arrayBuffer = await this.cacheStorage.get<ArrayBuffer>(
cacheFile.cacheKey
);
if (arrayBuffer) {
const file = new File([arrayBuffer], cacheFile.name, {
type: cacheFile.type,
});
files.push(file);
}
} catch (error) {
console.warn(`Failed to restore file ${cacheFile.name}:`, error);
}
}
return files;
};
}

View File

@@ -3,6 +3,7 @@ import { type Framework } from '@toeverything/infra';
import {
configureAIButtonModule,
configureAIDraftModule,
configureAINetworkSearchModule,
configureAIPlaygroundModule,
configureAIReasoningModule,
@@ -110,6 +111,7 @@ export function configureCommonModules(framework: Framework) {
configureAIReasoningModule(framework);
configureAIPlaygroundModule(framework);
configureAIButtonModule(framework);
configureAIDraftModule(framework);
configureTemplateDocModule(framework);
configureBlobManagementModule(framework);
configureMediaModule(framework);