diff --git a/blocksuite/affine/model/src/elements/mindmap/mindmap.ts b/blocksuite/affine/model/src/elements/mindmap/mindmap.ts index fe3cd36c57..a8c9f3b736 100644 --- a/blocksuite/affine/model/src/elements/mindmap/mindmap.ts +++ b/blocksuite/affine/model/src/elements/mindmap/mindmap.ts @@ -743,7 +743,7 @@ export class MindmapElementModel extends GfxGroupLikeElementModel { this.childElements.forEach(el => { diff --git a/packages/frontend/core/src/blocksuite/ai/actions/doc-handler.ts b/packages/frontend/core/src/blocksuite/ai/actions/doc-handler.ts index cd6e0c32af..c1bb946196 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/doc-handler.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/doc-handler.ts @@ -108,7 +108,7 @@ function actionToStream( workspaceId: host.doc.workspace.id, } as Parameters[0]; // @ts-expect-error TODO(@Peng): maybe fix this - stream = action(options); + stream = await action(options); if (!stream) return; yield* stream; }, diff --git a/packages/frontend/core/src/blocksuite/ai/actions/edgeless-handler.ts b/packages/frontend/core/src/blocksuite/ai/actions/edgeless-handler.ts index 6464b3ac69..f1fbce5a32 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/edgeless-handler.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/edgeless-handler.ts @@ -207,7 +207,7 @@ function actionToStream( } // @ts-expect-error TODO(@Peng): maybe fix this - stream = action(options); + stream = await action(options); if (!stream) return; yield* stream; }, @@ -237,7 +237,7 @@ function actionToStream( } as Parameters[0]; // @ts-expect-error TODO(@Peng): maybe fix this - stream = action(options); + stream = await action(options); if (!stream) return; yield* stream; }, diff --git a/packages/frontend/core/src/blocksuite/ai/actions/page-response.ts b/packages/frontend/core/src/blocksuite/ai/actions/page-response.ts index e96d4d5c4d..f6377f5878 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/page-response.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/page-response.ts @@ -5,7 +5,11 @@ import { } from '@blocksuite/affine/blocks/surface'; import { fitContent } from '@blocksuite/affine/gfx/shape'; import { createTemplateJob } from '@blocksuite/affine/gfx/template'; -import { Bound, getCommonBound } from '@blocksuite/affine/global/gfx'; +import { + Bound, + getCommonBound, + type XYWH, +} from '@blocksuite/affine/global/gfx'; import type { MindmapElementModel, ShapeElementModel, @@ -83,7 +87,10 @@ function responseToBrainstormMindmap( }); // wait for mindmap xywh update setTimeout(() => { - const frameBound = expandBound(mindmap.elementBound, PADDING); + const { x, y, w, h } = mindmap.elementBound; + const targetBound: XYWH = [x, y + h / 2 + PADDING - 15, w, h]; + mindmap.moveTo(targetBound); + const frameBound = expandBound(new Bound(...targetBound), PADDING); addSurfaceRefBlock(host, frameBound, place); }, 0); }); @@ -107,7 +114,7 @@ function responseToMakeItReal(host: EditorHost, ctx: AIContext, place: Place) { const bound = getEdgelessContentBound(host); const x = bound ? bound.x + bound.w + PADDING * 2 : 0; const y = bound ? bound.y : 0; - const htmlBound = new Bound(x, y, width || 800, height || 600); + const htmlBound = new Bound(x, y + PADDING, width || 800, height || 600); const html = preprocessHtml(aiPanel.answer); host.doc.transact(() => { host.doc.addBlock( diff --git a/packages/frontend/core/src/blocksuite/ai/actions/types.ts b/packages/frontend/core/src/blocksuite/ai/actions/types.ts index 99a819db0d..d109a6bc54 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/types.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/types.ts @@ -13,6 +13,8 @@ import type { EditorHost } from '@blocksuite/affine/std'; import type { GfxModel } from '@blocksuite/affine/std/gfx'; import type { BlockModel } from '@blocksuite/affine/store'; +import type { PromptKey } from '../provider/prompt'; + export const translateLangs = [ 'English', 'Spanish', @@ -131,6 +133,7 @@ declare global { interface ChatOptions extends AITextActionOptions { sessionId?: string; isRootSession?: boolean; + networkSearch?: boolean; contexts?: { docs: AIDocContextOption[]; files: AIFileContextOption[]; @@ -155,107 +158,107 @@ declare global { interface AIActions { // chat is a bit special because it's has a internally maintained session - chat(options: T): AIActionTextResponse; + chat(options: T): Promise>; summary( options: T - ): AIActionTextResponse; + ): Promise>; improveWriting( options: T - ): AIActionTextResponse; + ): Promise>; improveGrammar( options: T - ): AIActionTextResponse; + ): Promise>; fixSpelling( options: T - ): AIActionTextResponse; + ): Promise>; createHeadings( options: T - ): AIActionTextResponse; + ): Promise>; makeLonger( options: T - ): AIActionTextResponse; + ): Promise>; makeShorter( options: T - ): AIActionTextResponse; + ): Promise>; continueWriting( options: T - ): AIActionTextResponse; + ): Promise>; checkCodeErrors( options: T - ): AIActionTextResponse; + ): Promise>; explainCode( options: T - ): AIActionTextResponse; + ): Promise>; writeArticle( options: T - ): AIActionTextResponse; + ): Promise>; writeTwitterPost( options: T - ): AIActionTextResponse; + ): Promise>; writePoem( options: T - ): AIActionTextResponse; + ): Promise>; writeBlogPost( options: T - ): AIActionTextResponse; + ): Promise>; brainstorm( options: T - ): AIActionTextResponse; + ): Promise>; writeOutline( options: T - ): AIActionTextResponse; + ): Promise>; explainImage( options: T - ): AIActionTextResponse; + ): Promise>; findActions( options: T - ): AIActionTextResponse; + ): Promise>; // mindmap brainstormMindmap( options: T - ): AIActionTextResponse; + ): Promise>; expandMindmap( options: T - ): AIActionTextResponse; + ): Promise>; // presentation createSlides( options: T - ): AIActionTextResponse; + ): Promise>; // explain this explain( options: T - ): AIActionTextResponse; + ): Promise>; // actions with variants translate( options: T - ): AIActionTextResponse; + ): Promise>; changeTone( options: T - ): AIActionTextResponse; + ): Promise>; // make it real, image to text makeItReal( options: T - ): AIActionTextResponse; + ): Promise>; createImage( options: T - ): AIActionTextResponse; + ): Promise>; processImage( options: T - ): AIActionTextResponse; + ): Promise>; filterImage( options: T - ): AIActionTextResponse; + ): Promise>; generateCaption( options: T - ): AIActionTextResponse; + ): Promise>; } type AIDocsAndFilesContext = { @@ -357,12 +360,16 @@ declare global { >[]; }; + interface CreateSessionOptions { + docId: string; + workspaceId: string; + promptName: PromptKey; + sessionId?: string; + retry?: boolean; + } + interface AISessionService { - createSession: ( - workspaceId: string, - docId: string, - promptName?: string - ) => Promise; + createSession: (options: CreateSessionOptions) => Promise; getSessions: ( workspaceId: string, docId?: string, diff --git a/packages/frontend/core/src/blocksuite/ai/chat-panel/chat-panel-messages.ts b/packages/frontend/core/src/blocksuite/ai/chat-panel/chat-panel-messages.ts index 2f25a215ec..e93ca62eee 100644 --- a/packages/frontend/core/src/blocksuite/ai/chat-panel/chat-panel-messages.ts +++ b/packages/frontend/core/src/blocksuite/ai/chat-panel/chat-panel-messages.ts @@ -348,10 +348,10 @@ export class ChatPanelMessages extends WithDisposable(ShadowlessElement) { } retry = async () => { - const { doc } = this.host; try { const sessionId = await this.createSessionId(); if (!sessionId) return; + if (!AIProvider.actions.chat) return; const abortController = new AbortController(); const messages = [...this.chatContextValue.messages]; @@ -362,7 +362,8 @@ export class ChatPanelMessages extends WithDisposable(ShadowlessElement) { } this.updateContext({ messages, status: 'loading', error: null }); - const stream = AIProvider.actions.chat?.({ + const { doc } = this.host; + const stream = await AIProvider.actions.chat({ sessionId, retry: true, docId: doc.id, @@ -374,18 +375,15 @@ export class ChatPanelMessages extends WithDisposable(ShadowlessElement) { control: 'chat-send', isRootSession: true, }); - - if (stream) { - this.updateContext({ abortController }); - for await (const text of stream) { - const messages = [...this.chatContextValue.messages]; - const last = messages[messages.length - 1] as ChatMessage; - last.content += text; - this.updateContext({ messages, status: 'transmitting' }); - } - - this.updateContext({ status: 'success' }); + this.updateContext({ abortController }); + for await (const text of stream) { + const messages = [...this.chatContextValue.messages]; + const last = messages[messages.length - 1] as ChatMessage; + last.content += text; + this.updateContext({ messages, status: 'transmitting' }); } + + this.updateContext({ status: 'success' }); } catch (error) { this.updateContext({ status: 'error', error: error as AIError }); } finally { diff --git a/packages/frontend/core/src/blocksuite/ai/chat-panel/index.ts b/packages/frontend/core/src/blocksuite/ai/chat-panel/index.ts index 17d1676263..146331259f 100644 --- a/packages/frontend/core/src/blocksuite/ai/chat-panel/index.ts +++ b/packages/frontend/core/src/blocksuite/ai/chat-panel/index.ts @@ -141,7 +141,6 @@ export class ChatPanel extends SignalWatcher( const history = histories?.find(history => history.sessionId === sessionId); if (history) { messages.push(...history.messages); - AIProvider.LAST_ROOT_SESSION_ID = history.sessionId; } this.chatContextValue = { @@ -182,10 +181,11 @@ export class ChatPanel extends SignalWatcher( if (this._sessionId) { return this._sessionId; } - this._sessionId = await AIProvider.session?.createSession( - this.doc.workspace.id, - this.doc.id - ); + this._sessionId = await AIProvider.session?.createSession({ + docId: this.doc.id, + workspaceId: this.doc.workspace.id, + promptName: 'Chat With AFFiNE AI', + }); return this._sessionId; }; diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts index 3ea8407a4b..eb19a694a6 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/ai-chat-input.ts @@ -25,7 +25,6 @@ import type { } from '../ai-chat-chips/type'; import { isDocChip, isFileChip } from '../ai-chat-chips/utils'; import type { ChatMessage } from '../ai-chat-messages'; -import { PROMPT_NAME_AFFINE_AI, PROMPT_NAME_NETWORK_SEARCH } from './const'; import type { AIChatInputContext, AINetworkSearchConfig } from './type'; const MaximumImageCount = 32; @@ -255,7 +254,8 @@ export class AIChatInput extends SignalWatcher(WithDisposable(LitElement)) { private get _isNetworkActive() { return ( !!this.networkSearchConfig.visible.value && - !!this.networkSearchConfig.enabled.value + !!this.networkSearchConfig.enabled.value && + !this._isNetworkDisabled ); } @@ -274,22 +274,6 @@ export class AIChatInput extends SignalWatcher(WithDisposable(LitElement)) { ); } - private _getPromptName() { - if (this._isNetworkDisabled) { - return PROMPT_NAME_AFFINE_AI; - } - return this._isNetworkActive - ? PROMPT_NAME_NETWORK_SEARCH - : PROMPT_NAME_AFFINE_AI; - } - - private async _updatePromptName(promptName: string) { - const sessionId = await this.createSessionId(); - if (sessionId && AIProvider.session) { - await AIProvider.session.updateSession(sessionId, promptName); - } - } - override connectedCallback() { super.connectedCallback(); this._disposables.add( @@ -313,7 +297,7 @@ export class AIChatInput extends SignalWatcher(WithDisposable(LitElement)) { const { images, status } = this.chatContextValue; const hasImages = images.length > 0; const maxHeight = hasImages ? 272 + 2 : 200 + 2; - const uploadDisabled = this._isNetworkActive && !this._isNetworkDisabled; + const uploadDisabled = this._isNetworkActive; return html`
${PublishIcon()} @@ -473,6 +455,9 @@ export class AIChatInput extends SignalWatcher(WithDisposable(LitElement)) { e.preventDefault(); e.stopPropagation(); + if (this._isNetworkDisabled) { + return; + } const enable = this.networkSearchConfig.enabled.value; this.networkSearchConfig.setEnabled(!enable); }; @@ -514,13 +499,12 @@ export class AIChatInput extends SignalWatcher(WithDisposable(LitElement)) { }; send = async (text: string) => { - const { status, markdown, images } = this.chatContextValue; - if (status === 'loading' || status === 'transmitting') return; - if (!text) return; - if (!AIProvider.actions.chat) return; - try { - const promptName = this._getPromptName(); + const { status, markdown, images } = this.chatContextValue; + if (status === 'loading' || status === 'transmitting') return; + if (!text) return; + if (!AIProvider.actions.chat) return; + const abortController = new AbortController(); this.updateContext({ images: [], @@ -538,16 +522,13 @@ export class AIChatInput extends SignalWatcher(WithDisposable(LitElement)) { // optimistic update messages await this._preUpdateMessages(userInput, attachments); - // must update prompt name after local chat message is updated - // otherwise, the unauthorized error can not be rendered properly - await this._updatePromptName(promptName); const sessionId = await this.createSessionId(); const contexts = await this._getMatchedContexts(userInput); if (abortController.signal.aborted) { return; } - const stream = AIProvider.actions.chat({ + const stream = await AIProvider.actions.chat({ sessionId, input: userInput, contexts, @@ -560,6 +541,7 @@ export class AIChatInput extends SignalWatcher(WithDisposable(LitElement)) { isRootSession: this.isRootSession, where: this.trackOptions.where, control: this.trackOptions.control, + networkSearch: this._isNetworkActive, }); for await (const text of stream) { diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/const.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/const.ts deleted file mode 100644 index 0bcb4b739e..0000000000 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/const.ts +++ /dev/null @@ -1,2 +0,0 @@ -export const PROMPT_NAME_AFFINE_AI = 'Chat With AFFiNE AI'; -export const PROMPT_NAME_NETWORK_SEARCH = 'Search With AFFiNE AI'; diff --git a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/index.ts b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/index.ts index 5d7c7cb47b..21af31fca1 100644 --- a/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/index.ts +++ b/packages/frontend/core/src/blocksuite/ai/components/ai-chat-input/index.ts @@ -1,3 +1,2 @@ export * from './ai-chat-input'; -export * from './const'; export * from './type'; diff --git a/packages/frontend/core/src/blocksuite/ai/peek-view/chat-block-peek-view.ts b/packages/frontend/core/src/blocksuite/ai/peek-view/chat-block-peek-view.ts index 95a9cd88e4..fcfde665bc 100644 --- a/packages/frontend/core/src/blocksuite/ai/peek-view/chat-block-peek-view.ts +++ b/packages/frontend/core/src/blocksuite/ai/peek-view/chat-block-peek-view.ts @@ -320,16 +320,12 @@ export class AIChatBlockPeekView extends LitElement { * Retry the last chat message */ retry = async () => { - const { doc } = this.host; - const { _forkBlockId, _forkSessionId } = this; - if (!_forkBlockId || !_forkSessionId) { - return; - } - - let content = ''; try { - const abortController = new AbortController(); + const { _forkBlockId, _forkSessionId } = this; + if (!_forkBlockId || !_forkSessionId) return; + if (!AIProvider.actions.chat) return; + const abortController = new AbortController(); const messages = [...this.chatContext.messages]; const last = messages[messages.length - 1]; if ('content' in last) { @@ -339,7 +335,8 @@ export class AIChatBlockPeekView extends LitElement { } this.updateContext({ messages, status: 'loading', error: null }); - const stream = AIProvider.actions.chat?.({ + const { doc } = this.host; + const stream = await AIProvider.actions.chat({ sessionId: _forkSessionId, retry: true, docId: doc.id, @@ -351,26 +348,21 @@ export class AIChatBlockPeekView extends LitElement { control: 'chat-send', }); - if (stream) { - this.updateContext({ abortController }); - for await (const text of stream) { - const messages = [...this.chatContext.messages]; - const last = messages[messages.length - 1] as ChatMessage; - last.content += text; - this.updateContext({ messages, status: 'transmitting' }); - content += text; - } - - this.updateContext({ status: 'success' }); + this.updateContext({ abortController }); + for await (const text of stream) { + const messages = [...this.chatContext.messages]; + const last = messages[messages.length - 1] as ChatMessage; + last.content += text; + this.updateContext({ messages, status: 'transmitting' }); } + + this.updateContext({ status: 'success' }); + // Update new chat block messages if there are contents returned from AI + await this.updateChatBlockMessages(); } catch (error) { this.updateContext({ status: 'error', error: error as AIError }); } finally { this.updateContext({ abortController: null }); - if (content) { - // Update new chat block messages if there are contents returned from AI - await this.updateChatBlockMessages(); - } } }; diff --git a/packages/frontend/core/src/blocksuite/ai/provider/ai-provider.ts b/packages/frontend/core/src/blocksuite/ai/provider/ai-provider.ts index 9fecaaa00f..53e5acb624 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/ai-provider.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/ai-provider.ts @@ -100,8 +100,6 @@ export class AIProvider { static LAST_ACTION_SESSIONID = ''; - static LAST_ROOT_SESSION_ID = ''; - static MAX_LOCAL_HISTORY = 10; private readonly actions: Partial = {}; @@ -158,10 +156,10 @@ export class AIProvider { id: T, action: ( ...options: Parameters - ) => ReturnType + ) => Promise> ): void { // @ts-expect-error TODO: maybe fix this - this.actions[id] = ( + this.actions[id] = async ( ...args: Parameters ) => { const options = args[0]; @@ -176,9 +174,8 @@ export class AIProvider { this.actionHistory.shift(); } // wrap the action with slot actions - const result: BlockSuitePresets.TextStream | Promise = action( - ...args - ); + const result: BlockSuitePresets.TextStream | Promise = + await action(...args); const isTextStream = ( m: BlockSuitePresets.TextStream | Promise ): m is BlockSuitePresets.TextStream => @@ -315,7 +312,7 @@ export class AIProvider { id: T, action: ( ...options: Parameters - ) => ReturnType + ) => Promise> ): void; static provide(id: unknown, action: unknown) { diff --git a/packages/frontend/core/src/blocksuite/ai/provider/request.ts b/packages/frontend/core/src/blocksuite/ai/provider/request.ts index 5590faae5c..764d479c9b 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/request.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/request.ts @@ -3,16 +3,12 @@ import { partition } from 'lodash-es'; import { AIProvider } from './ai-provider'; import type { CopilotClient } from './copilot-client'; import { delay, toTextStream } from './event-source'; -import type { PromptKey } from './prompt'; const TIMEOUT = 50000; export type TextToTextOptions = { client: CopilotClient; - docId: string; - workspaceId: string; - promptName?: PromptKey; - sessionId?: string | Promise; + sessionId: string; content?: string; attachments?: (string | Blob | File)[]; params?: Record; @@ -61,30 +57,22 @@ async function resizeImage(blob: Blob | File): Promise { return null; } -async function createSessionMessage({ +interface CreateMessageOptions { + client: CopilotClient; + sessionId: string; + content?: string; + attachments?: (string | Blob | File)[]; + params?: Record; +} + +async function createMessage({ client, - docId, - workspaceId, - promptName = 'Chat With AFFiNE AI', + sessionId, content, - sessionId: providedSessionId, attachments, params, -}: TextToTextOptions): Promise<{ - sessionId: string; - messageId: string; -}> { - if (!promptName && !providedSessionId) { - throw new Error('promptName or sessionId is required'); - } +}: CreateMessageOptions): Promise { const hasAttachments = attachments && attachments.length > 0; - const sessionId = await (providedSessionId ?? - client.createSession({ - workspaceId, - docId, - promptName, - })); - const options: Parameters[0] = { sessionId, content, @@ -110,67 +98,44 @@ async function createSessionMessage({ ).filter(Boolean) as File[]; } - const messageId = await client.createMessage(options); - return { - messageId, - sessionId, - }; + return await client.createMessage(options); } export function textToText({ client, - docId, - workspaceId, - promptName, + sessionId, content, attachments, params, - sessionId, stream, signal, timeout = TIMEOUT, retry = false, workflow = false, - isRootSession = false, postfix, }: TextToTextOptions) { - let _sessionId: string; - let _messageId: string | undefined; + let messageId: string | undefined; if (stream) { return { [Symbol.asyncIterator]: async function* () { - if (retry) { - const retrySessionId = - (await sessionId) ?? AIProvider.LAST_ACTION_SESSIONID; - _sessionId = retrySessionId; - _messageId = undefined; - } else { - const message = await createSessionMessage({ + if (!retry) { + messageId = await createMessage({ client, - docId, - workspaceId, - promptName, + sessionId, content, attachments, params, - sessionId, }); - _sessionId = message.sessionId; - _messageId = message.messageId; } - const eventSource = client.chatTextStream( { - sessionId: _sessionId, - messageId: _messageId, + sessionId, + messageId, }, workflow ? 'workflow' : undefined ); - AIProvider.LAST_ACTION_SESSIONID = _sessionId; - if (isRootSession) { - AIProvider.LAST_ROOT_SESSION_ID = _sessionId; - } + AIProvider.LAST_ACTION_SESSIONID = sessionId; if (signal) { if (signal.aborted) { @@ -212,34 +177,20 @@ export function textToText({ }) : null, (async function () { - if (retry) { - const retrySessionId = - (await sessionId) ?? AIProvider.LAST_ACTION_SESSIONID; - _sessionId = retrySessionId; - _messageId = undefined; - } else { - const message = await createSessionMessage({ + if (!retry) { + messageId = await createMessage({ client, - docId, - workspaceId, - promptName, + sessionId, content, attachments, params, - sessionId, }); - _sessionId = message.sessionId; - _messageId = message.messageId; - } - - AIProvider.LAST_ACTION_SESSIONID = _sessionId; - if (isRootSession) { - AIProvider.LAST_ROOT_SESSION_ID = _sessionId; } + AIProvider.LAST_ACTION_SESSIONID = sessionId; return client.chatText({ - sessionId: _sessionId, - messageId: _messageId, + sessionId, + messageId, }); })(), ]); @@ -248,50 +199,36 @@ export function textToText({ // Only one image is currently being processed export function toImage({ - docId, - workspaceId, - promptName, content, + sessionId, attachments, params, seed, - sessionId, signal, timeout = TIMEOUT, retry = false, workflow = false, client, }: ToImageOptions) { - let _sessionId: string; - let _messageId: string | undefined; + let messageId: string | undefined; return { [Symbol.asyncIterator]: async function* () { - if (retry) { - const retrySessionId = - (await sessionId) ?? AIProvider.LAST_ACTION_SESSIONID; - _sessionId = retrySessionId; - _messageId = undefined; - } else { - const { messageId, sessionId } = await createSessionMessage({ - docId, - workspaceId, - promptName, + if (!retry) { + messageId = await createMessage({ + client, + sessionId, content, attachments, params, - client, }); - _sessionId = sessionId; - _messageId = messageId; } - const eventSource = client.imagesStream( - _sessionId, - _messageId, + sessionId, + messageId, seed, workflow ? 'workflow' : undefined ); - AIProvider.LAST_ACTION_SESSIONID = _sessionId; + AIProvider.LAST_ACTION_SESSIONID = sessionId; for await (const event of toTextStream(eventSource, { timeout, diff --git a/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.tsx b/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.tsx index 652cf4ce18..62f4fedfd2 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.tsx +++ b/packages/frontend/core/src/blocksuite/ai/provider/setup-provider.tsx @@ -14,7 +14,7 @@ import type { PromptKey } from './prompt'; import { textToText, toImage } from './request'; import { setupTracker } from './tracker'; -const filterStyleToPromptName = new Map( +const filterStyleToPromptName = new Map( Object.entries({ 'Clay style': 'workflow:image-clay', 'Pixel style': 'workflow:image-pixel', @@ -23,7 +23,7 @@ const filterStyleToPromptName = new Map( }) ); -const processTypeToPromptName = new Map( +const processTypeToPromptName = new Map( Object.entries({ Clearer: 'debug:action:fal-upscaler', 'Remove background': 'debug:action:fal-remove-bg', @@ -35,31 +35,78 @@ export function setupAIProvider( client: CopilotClient, globalDialogService: GlobalDialogService ) { + async function createSession({ + workspaceId, + docId, + promptName, + sessionId, + retry, + }: { + workspaceId: string; + docId: string; + promptName: PromptKey; + sessionId?: string; + retry?: boolean; + }) { + if (sessionId) return sessionId; + if (retry) return AIProvider.LAST_ACTION_SESSIONID; + + return client.createSession({ + workspaceId, + docId, + promptName, + }); + } + //#region actions - AIProvider.provide('chat', options => { - const { input, contexts, ...rest } = options; + AIProvider.provide('chat', async options => { + const { input, contexts, attachments, networkSearch, retry } = options; + const disableSearch = + !!contexts?.files.length || + !!contexts?.docs.length || + !!attachments?.length; + const promptName = + networkSearch && !disableSearch + ? 'Search With AFFiNE AI' + : 'Chat With AFFiNE AI'; + const sessionId = await createSession({ + promptName, + ...options, + }); + if (!retry) { + await AIProvider.session?.updateSession(sessionId, promptName); + } return textToText({ - ...rest, + ...options, client, + sessionId, content: input, params: contexts, }); }); - AIProvider.provide('summary', options => { + AIProvider.provide('summary', async options => { + const sessionId = await createSession({ + promptName: 'Summary', + ...options, + }); return textToText({ ...options, client, + sessionId, content: options.input, - promptName: 'Summary', }); }); - AIProvider.provide('translate', options => { + AIProvider.provide('translate', async options => { + const sessionId = await createSession({ + promptName: 'Translate to', + ...options, + }); return textToText({ ...options, client, - promptName: 'Translate to', + sessionId, content: options.input, params: { language: options.lang, @@ -67,200 +114,280 @@ export function setupAIProvider( }); }); - AIProvider.provide('changeTone', options => { + AIProvider.provide('changeTone', async options => { + const sessionId = await createSession({ + promptName: 'Change tone to', + ...options, + }); return textToText({ ...options, client, + sessionId, params: { tone: options.tone.toLowerCase(), }, content: options.input, - promptName: 'Change tone to', }); }); - AIProvider.provide('improveWriting', options => { - return textToText({ - ...options, - client, - content: options.input, + AIProvider.provide('improveWriting', async options => { + const sessionId = await createSession({ promptName: 'Improve writing for it', + ...options, }); - }); - - AIProvider.provide('improveGrammar', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('improveGrammar', async options => { + const sessionId = await createSession({ promptName: 'Improve grammar for it', + ...options, }); - }); - - AIProvider.provide('fixSpelling', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('fixSpelling', async options => { + const sessionId = await createSession({ promptName: 'Fix spelling for it', + ...options, }); - }); - - AIProvider.provide('createHeadings', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('createHeadings', async options => { + const sessionId = await createSession({ promptName: 'Create headings', + ...options, }); - }); - - AIProvider.provide('makeLonger', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('makeLonger', async options => { + const sessionId = await createSession({ promptName: 'Make it longer', + ...options, }); - }); - - AIProvider.provide('makeShorter', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('makeShorter', async options => { + const sessionId = await createSession({ promptName: 'Make it shorter', + ...options, }); - }); - - AIProvider.provide('checkCodeErrors', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('checkCodeErrors', async options => { + const sessionId = await createSession({ promptName: 'Check code error', + ...options, }); - }); - - AIProvider.provide('explainCode', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('explainCode', async options => { + const sessionId = await createSession({ promptName: 'Explain this code', + ...options, }); - }); - - AIProvider.provide('writeArticle', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('writeArticle', async options => { + const sessionId = await createSession({ promptName: 'Write an article about this', + ...options, }); - }); - - AIProvider.provide('writeTwitterPost', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('writeTwitterPost', async options => { + const sessionId = await createSession({ promptName: 'Write a twitter about this', + ...options, }); - }); - - AIProvider.provide('writePoem', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('writePoem', async options => { + const sessionId = await createSession({ promptName: 'Write a poem about this', + ...options, }); - }); - - AIProvider.provide('writeOutline', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('writeOutline', async options => { + const sessionId = await createSession({ promptName: 'Write outline', + ...options, }); - }); - - AIProvider.provide('writeBlogPost', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('writeBlogPost', async options => { + const sessionId = await createSession({ promptName: 'Write a blog post about this', + ...options, }); - }); - - AIProvider.provide('brainstorm', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('brainstorm', async options => { + const sessionId = await createSession({ promptName: 'Brainstorm ideas about this', + ...options, }); - }); - - AIProvider.provide('findActions', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('findActions', async options => { + const sessionId = await createSession({ promptName: 'Find action items from it', + ...options, }); - }); - - AIProvider.provide('brainstormMindmap', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('brainstormMindmap', async options => { + const sessionId = await createSession({ promptName: 'workflow:brainstorm', + ...options, + }); + return textToText({ + ...options, + client, + sessionId, + content: options.input, // 3 minutes timeout: 180000, workflow: true, }); }); - AIProvider.provide('expandMindmap', options => { + AIProvider.provide('expandMindmap', async options => { if (!options.input) { throw new Error('expandMindmap action requires input'); } + const sessionId = await createSession({ + promptName: 'Expand mind map', + ...options, + }); return textToText({ ...options, client, + sessionId, params: { mindmap: options.mindmap, node: options.input, }, content: options.input, - promptName: 'Expand mind map', }); }); - AIProvider.provide('explain', options => { - return textToText({ - ...options, - client, - content: options.input, + AIProvider.provide('explain', async options => { + const sessionId = await createSession({ promptName: 'Explain this', + ...options, }); - }); - - AIProvider.provide('explainImage', options => { return textToText({ ...options, client, + sessionId, content: options.input, - promptName: 'Explain this image', }); }); - AIProvider.provide('makeItReal', options => { + AIProvider.provide('explainImage', async options => { + const sessionId = await createSession({ + promptName: 'Explain this image', + ...options, + }); + return textToText({ + ...options, + client, + sessionId, + content: options.input, + }); + }); + + AIProvider.provide('makeItReal', async options => { let promptName: PromptKey = 'Make it real'; let content = options.input || ''; @@ -275,15 +402,20 @@ Here are our design notes:\n ${content}.`; Could you make a new website based on these notes and send back just the html file?`; } + const sessionId = await createSession({ + promptName, + ...options, + }); + return textToText({ ...options, client, + sessionId, content, - promptName, }); }); - AIProvider.provide('createSlides', options => { + AIProvider.provide('createSlides', async options => { const SlideSchema = z.object({ page: z.number(), type: z.enum(['name', 'title', 'content']), @@ -320,11 +452,15 @@ Could you make a new website based on these notes and send back just the html fi }) .join('\n'); }; + const sessionId = await createSession({ + promptName: 'workflow:presentation', + ...options, + }); return textToText({ ...options, client, + sessionId, content: options.input, - promptName: 'workflow:presentation', // 3 minutes timeout: 180000, workflow: true, @@ -332,79 +468,98 @@ Could you make a new website based on these notes and send back just the html fi }); }); - AIProvider.provide('createImage', options => { + AIProvider.provide('createImage', async options => { // test to image let promptName: PromptKey = 'debug:action:dalle3'; // image to image if (options.attachments?.length) { promptName = 'debug:action:fal-sd15'; } + + const sessionId = await createSession({ + promptName, + ...options, + }); return toImage({ ...options, client, + sessionId, content: options.input, - promptName, }); }); - AIProvider.provide('filterImage', options => { + AIProvider.provide('filterImage', async options => { // test to image - const promptName = filterStyleToPromptName.get(options.style as string); + const promptName: PromptKey | undefined = filterStyleToPromptName.get( + options.style + ); + if (!promptName) { + throw new Error('filterImage requires a promptName'); + } + const sessionId = await createSession({ + promptName, + ...options, + }); return toImage({ ...options, client, + sessionId, content: options.input, timeout: 180000, - promptName: promptName as PromptKey, workflow: !!promptName?.startsWith('workflow:'), }); }); - AIProvider.provide('processImage', options => { + AIProvider.provide('processImage', async options => { // test to image - const promptName = processTypeToPromptName.get( - options.type as string - ) as PromptKey; + const promptName: PromptKey | undefined = processTypeToPromptName.get( + options.type + ); + if (!promptName) { + throw new Error('processImage requires a promptName'); + } + const sessionId = await createSession({ + promptName, + ...options, + }); return toImage({ ...options, client, + sessionId, content: options.input, timeout: 180000, - promptName, }); }); - AIProvider.provide('generateCaption', options => { - return textToText({ - ...options, - client, - content: options.input, + AIProvider.provide('generateCaption', async options => { + const sessionId = await createSession({ promptName: 'Generate a caption', + ...options, }); - }); - - AIProvider.provide('continueWriting', options => { return textToText({ ...options, client, + sessionId, content: options.input, + }); + }); + + AIProvider.provide('continueWriting', async options => { + const sessionId = await createSession({ promptName: 'Continue writing', + ...options, + }); + return textToText({ + ...options, + client, + sessionId, + content: options.input, }); }); //#endregion AIProvider.provide('session', { - createSession: async ( - workspaceId: string, - docId: string, - promptName = 'Chat With AFFiNE AI' - ) => { - return client.createSession({ - workspaceId, - docId, - promptName, - }); - }, + createSession, getSessions: async ( workspaceId: string, docId?: string,