diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-chips/type.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-chips/type.ts index 7757eb8d97..a7384d93a0 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-chips/type.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-chips/type.ts @@ -36,7 +36,16 @@ export interface CollectionChip extends BaseChip { collectionId: string; } -export type ChatChip = DocChip | FileChip | TagChip | CollectionChip; +export interface SelectedContextChip extends FileChip { + isSelectedContext: true; +} + +export type ChatChip = + | DocChip + | FileChip + | TagChip + | CollectionChip + | SelectedContextChip; export interface DocDisplayConfig { getIcon: (docId: string) => any; diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-chips/utils.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-chips/utils.ts index c62079875d..2e3d9a87a3 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-chips/utils.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-chips/utils.ts @@ -8,6 +8,7 @@ import type { CollectionChip, DocChip, FileChip, + SelectedContextChip, TagChip, } from './type'; @@ -62,6 +63,12 @@ export function isCollectionChip(chip: ChatChip): chip is CollectionChip { return 'collectionId' in chip; } +export function isSelectedContextChip( + chip: ChatChip +): chip is SelectedContextChip { + return 'isSelectedContext' in chip && chip.isSelectedContext; +} + export function getChipKey(chip: ChatChip) { if (isDocChip(chip)) { return chip.docId; diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-composer/ai-chat-composer.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-composer/ai-chat-composer.ts index d2d521957a..9698305d3a 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-composer/ai-chat-composer.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-composer/ai-chat-composer.ts @@ -32,6 +32,7 @@ import { isCollectionChip, isDocChip, isFileChip, + isSelectedContextChip, isTagChip, omitChip, } from '../ai-chat-chips'; @@ -157,6 +158,8 @@ export class AIChatComposer extends SignalWatcher( .session=${this.session} .chips=${this.chips} .addChip=${this.addChip} + .removeSelectedContextChips=${this.removeSelectedContextChips} + .waitForSelectedContextChipsFinished=${this.waitForSelectedContextChipsFinished} .addImages=${this.addImages} .createSession=${this.createSession} .chatContextValue=${this.chatContextValue} @@ -329,12 +332,17 @@ export class AIChatComposer extends SignalWatcher( ]); }; - private readonly addChip = async (chip: ChatChip) => { + private readonly addChip = async ( + chip: ChatChip, + silent: boolean = false + ) => { this.isChipsCollapsed = false; // if already exists const index = findChipIndex(this.chips, chip); if (index !== -1) { - this.notificationService.toast('chip already exists'); + if (!silent) { + this.notificationService.toast('chip already exists'); + } return; } this.updateChips([...this.chips, chip]); @@ -348,6 +356,15 @@ export class AIChatComposer extends SignalWatcher( await this.removeFromContext(chip); }; + private readonly removeSelectedContextChips = async () => { + const selectedContextChips = this.chips.filter(c => + isSelectedContextChip(c) + ); + for (const chip of selectedContextChips) { + await this.removeChip(chip); + } + }; + private readonly addToContext = async (chip: ChatChip) => { if (isDocChip(chip)) { return await this.addDocToContext(chip); @@ -639,4 +656,27 @@ export class AIChatComposer extends SignalWatcher( } await this.pollEmbeddingStatus(); }; + + private readonly waitForSelectedContextChipsFinished = async ( + timeout = 10000, + interval = 500 + ): Promise => { + const start = Date.now(); + return new Promise((resolve, reject) => { + const check = () => { + const selectedChips = this.chips.filter(c => isSelectedContextChip(c)); + const allFinished = selectedChips.every(c => c.state === 'finished'); + if (allFinished) { + resolve(); + } else if (Date.now() - start >= timeout) { + reject( + new Error('Timeout waiting for selected context chips to finish') + ); + } else { + setTimeout(check, interval); + } + }; + check(); + }); + }; } diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.ts index 79c110e97e..7d72ea3a2a 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/ai-chat-content.ts @@ -45,6 +45,9 @@ const DEFAULT_CHAT_CONTEXT_VALUE: ChatContextValue = { status: 'idle', error: null, markdown: '', + attachments: [], + snapshot: null, + markdownFile: null, }; export class AIChatContent extends SignalWatcher( diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/type.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/type.ts index 628c194339..3d6ed821b2 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/type.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-content/type.ts @@ -1,5 +1,3 @@ -import type { DocSnapshot } from '@blocksuite/affine/store'; - import type { AIError } from '../../provider'; import type { ChatStatus, HistoryMessage } from '../ai-chat-messages'; @@ -15,7 +13,10 @@ export type ChatContextValue = { // images of the selected content or user uploaded images: File[]; // snapshot of the selected content - snapshot: DocSnapshot; + snapshot: File | null; + // attachments of the selected content attachments: File[]; + // markdown file of the selected content + markdownFile: File | null; abortController: AbortController | null; }; diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts index 80e3a56fb4..a1d472fa8b 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts @@ -338,7 +338,13 @@ export class AIChatInput extends SignalWatcher( accessor addImages!: (images: File[]) => void; @property({ attribute: false }) - accessor addChip!: (chip: ChatChip) => Promise; + accessor addChip!: (chip: ChatChip, silent?: boolean) => Promise; + + @property({ attribute: false }) + accessor removeSelectedContextChips!: () => Promise; + + @property({ attribute: false }) + accessor waitForSelectedContextChipsFinished!: () => Promise; @property({ attribute: false }) accessor networkSearchConfig!: AINetworkSearchConfig; @@ -603,9 +609,47 @@ export class AIChatInput extends SignalWatcher( this.modelId = modelId; }; + private readonly addSelectedContextChips = async () => { + const { snapshot, markdownFile, attachments } = this.chatContextValue; + await this.removeSelectedContextChips(); + for (const attachment of attachments) { + await this.addChip( + { + file: attachment, + state: 'processing', + isSelectedContext: true, + }, + true + ); + if (snapshot) { + await this.addChip( + { + file: snapshot, + state: 'processing', + isSelectedContext: true, + }, + true + ); + } + if (markdownFile) { + await this.addChip( + { + file: markdownFile, + state: 'processing', + isSelectedContext: true, + }, + true + ); + } + } + await this.waitForSelectedContextChipsFinished(); + }; + send = async (text: string) => { try { const { status, markdown, images } = this.chatContextValue; + await this.addSelectedContextChips(); + if (status === 'loading' || status === 'transmitting') return; if (!text) return; if (!AIProvider.actions.chat) return; @@ -620,13 +664,13 @@ export class AIChatInput extends SignalWatcher( abortController, }); - const attachments = await Promise.all( + const imageAttachments = await Promise.all( images?.map(image => readBlobAsURL(image)) ); const userInput = (markdown ? `${markdown}\n` : '') + text; // optimistic update messages - await this._preUpdateMessages(userInput, attachments); + await this._preUpdateMessages(userInput, imageAttachments); const sessionId = (await this.createSession())?.sessionId; let contexts = await this._getMatchedContexts(); diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/type.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/type.ts index 22b09806b9..bb22b77b75 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/type.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/type.ts @@ -26,5 +26,8 @@ export type AIChatInputContext = { quote?: string; markdown?: string; images: File[]; + attachments: File[]; + snapshot: File | null; + markdownFile: File | null; abortController: AbortController | null; }; diff --git a/packages/frontend/core/src/blocksuite/ai/utils/attachment.ts b/packages/frontend/core/src/blocksuite/ai/utils/attachment.ts new file mode 100644 index 0000000000..6f3d3ab652 --- /dev/null +++ b/packages/frontend/core/src/blocksuite/ai/utils/attachment.ts @@ -0,0 +1,9 @@ +import { AttachmentBlockModel } from '@blocksuite/affine/model'; +import type { BlockModel } from '@blocksuite/affine/store'; +import type { GfxModel } from '@blocksuite/std/gfx'; + +export function isAttachment( + model: GfxModel | BlockModel +): model is AttachmentBlockModel { + return model instanceof AttachmentBlockModel; +} diff --git a/packages/frontend/core/src/blocksuite/ai/utils/extract.ts b/packages/frontend/core/src/blocksuite/ai/utils/extract.ts index 3f3ef39484..a0ca86c156 100644 --- a/packages/frontend/core/src/blocksuite/ai/utils/extract.ts +++ b/packages/frontend/core/src/blocksuite/ai/utils/extract.ts @@ -1,6 +1,5 @@ import { WorkspaceImpl } from '@affine/core/modules/workspace/impls/workspace'; import { - AttachmentBlockModel, DatabaseBlockModel, ImageBlockModel, NoteBlockModel, @@ -32,6 +31,7 @@ import { Doc as YDoc } from 'yjs'; import { getStoreManager } from '../../manager/store'; import type { ChatContextValue } from '../components/ai-chat-content'; +import { isAttachment } from './attachment'; import { getSelectedAttachmentsAsBlobs, getSelectedImagesAsBlobs, @@ -63,8 +63,6 @@ async function extractEdgelessSelected( const attachments: ChatContextValue['attachments'] = []; if (selectedElements.length) { - console.log('selectedElements', selectedElements); - const transformer = host.store.getTransformer(); const markdownAdapter = new MarkdownAdapter( transformer, @@ -76,6 +74,8 @@ async function extractEdgelessSelected( }); collection.meta.initialize(); + let needSnapshot = false; + let needMarkdown = false; try { const fragmentDoc = collection.createDoc(); const fragment = fragmentDoc.getStore(); @@ -85,19 +85,13 @@ async function extractEdgelessSelected( const surfaceId = fragment.addBlock('affine:surface', {}, rootId); const noteId = fragment.addBlock('affine:note', {}, rootId); for (const element of selectedElements) { - if (element instanceof GfxBlockElementModel) { - const props = getBlockProps(element); - fragment.addBlock(element.flavour, props, surfaceId); - } - if (element instanceof NoteBlockModel) { + needMarkdown = true; for (const child of element.children) { const props = getBlockProps(child); fragment.addBlock(child.flavour, props, noteId); } - } - - if (element instanceof AttachmentBlockModel) { + } else if (isAttachment(element)) { const { name, sourceId } = element.props; if (name && sourceId) { const blob = await host.store.blobSync.get(sourceId); @@ -105,11 +99,19 @@ async function extractEdgelessSelected( attachments.push(new File([blob], name)); } } + } else if (element instanceof GfxBlockElementModel) { + const props = getBlockProps(element); + needSnapshot = true; + fragment.addBlock(element.flavour, props, surfaceId); } } - snapshot = transformer.docToSnapshot(fragment) ?? null; - markdown = (await markdownAdapter.fromDoc(fragment))?.file ?? ''; + if (needSnapshot) { + snapshot = transformer.docToSnapshot(fragment) ?? null; + } + if (needMarkdown) { + markdown = (await markdownAdapter.fromDoc(fragment))?.file ?? ''; + } } finally { collection.dispose(); } @@ -125,8 +127,10 @@ async function extractEdgelessSelected( return { images: [new File([blob], 'selected.png')], - snapshot: snapshot ?? undefined, - markdown: markdown.length ? markdown : undefined, + snapshot: snapshot + ? new File([JSON.stringify(snapshot)], 'selected.json') + : null, + markdownFile: markdown.length ? new File([markdown], 'selected.md') : null, attachments, }; } diff --git a/packages/frontend/core/src/blocksuite/ai/utils/selection-utils.ts b/packages/frontend/core/src/blocksuite/ai/utils/selection-utils.ts index b782affe0e..6393280a82 100644 --- a/packages/frontend/core/src/blocksuite/ai/utils/selection-utils.ts +++ b/packages/frontend/core/src/blocksuite/ai/utils/selection-utils.ts @@ -6,11 +6,7 @@ import { getSurfaceBlock, type SurfaceBlockComponent, } from '@blocksuite/affine/blocks/surface'; -import { - AttachmentBlockModel, - DatabaseBlockModel, - ImageBlockModel, -} from '@blocksuite/affine/model'; +import { DatabaseBlockModel, ImageBlockModel } from '@blocksuite/affine/model'; import { getBlockSelectionsCommand, getImageSelectionsCommand, @@ -33,6 +29,7 @@ import { import { getContentFromSlice } from '../../utils'; import type { CopilotTool } from '../tool/copilot-tool'; +import { isAttachment } from './attachment'; import { getEdgelessCopilotWidget } from './get-edgeless-copilot-widget'; export async function selectedToCanvas(host: EditorHost) { @@ -245,7 +242,7 @@ export const getSelectedAttachmentsAsBlobs = async (host: EditorHost) => { const attachments: { sourceId: string; name: string }[] = []; for (const block of blocks) { - if (block.model instanceof AttachmentBlockModel) { + if (isAttachment(block.model)) { const { sourceId, name } = block.model.props; if (sourceId && name) { attachments.push({ sourceId, name });