From e1eb9257040bd89d6e236c147af5cebc277e8a3a Mon Sep 17 00:00:00 2001 From: pengx17 Date: Mon, 15 Apr 2024 05:31:32 +0000 Subject: [PATCH] refactor(core): remove copilot client from presets (#6546) depends on https://github.com/toeverything/blocksuite/pull/6748 --- .../block-suite-editor/ai/copilot-client.ts | 18 +--- .../block-suite-editor/ai/provider.ts | 93 ++++++++++++++----- .../block-suite-editor/ai/request.ts | 55 ++++++++--- .../core/src/hooks/affine/use-current-user.ts | 2 +- .../e2e/local-first-workspace-list.spec.ts | 23 ++++- 5 files changed, 130 insertions(+), 61 deletions(-) diff --git a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/copilot-client.ts b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/copilot-client.ts index 6ecdf8903e..e96d9290d9 100644 --- a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/copilot-client.ts +++ b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/copilot-client.ts @@ -52,7 +52,9 @@ export class CopilotClient { async getHistories( workspaceId: string, docId?: string, - options?: OptionsField + options?: RequestOptions< + typeof getCopilotHistoriesQuery + >['variables']['options'] ) { const res = await fetcher({ query: getCopilotHistoriesQuery, @@ -66,20 +68,6 @@ export class CopilotClient { return res.currentUser?.copilot?.histories; } - async textToText(message: string, sessionId: string) { - const res = await fetch( - `${this.backendUrl}/api/copilot/chat/${sessionId}?message=${encodeURIComponent(message)}` - ); - if (!res.ok) return; - return res.text(); - } - - textToTextStream(message: string, sessionId: string) { - return new EventSource( - `${this.backendUrl}/api/copilot/chat/${sessionId}/stream?message=${encodeURIComponent(message)}` - ); - } - chatText({ sessionId, messageId, diff --git a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/provider.ts b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/provider.ts index 7b02a9085b..c610d4ccb4 100644 --- a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/provider.ts +++ b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/provider.ts @@ -1,17 +1,39 @@ +import { assertExists } from '@blocksuite/global/utils'; import { AIProvider } from '@blocksuite/presets'; -import { textToText } from './request'; +import { createChatSession, listHistories, textToText } from './request'; export function setupAIProvider() { - AIProvider.provideAction('chat', options => { + // a single workspace should have only a single chat session + // workspace-id:doc-id -> chat session id + const chatSessions = new Map>(); + + async function getChatSessionId(workspaceId: string, docId: string) { + const storeKey = `${workspaceId}:${docId}`; + if (!chatSessions.has(storeKey)) { + chatSessions.set( + storeKey, + createChatSession({ + workspaceId, + docId, + }) + ); + } + const sessionId = await chatSessions.get(storeKey); + assertExists(sessionId); + return sessionId; + } + + AIProvider.provide('chat', options => { + const sessionId = getChatSessionId(options.workspaceId, options.docId); return textToText({ ...options, content: options.input, - promptName: 'debug:chat:gpt4', + sessionId, }); }); - AIProvider.provideAction('summary', options => { + AIProvider.provide('summary', options => { return textToText({ ...options, content: options.input, @@ -19,7 +41,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('translate', options => { + AIProvider.provide('translate', options => { return textToText({ ...options, promptName: 'Translate to', @@ -30,7 +52,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('changeTone', options => { + AIProvider.provide('changeTone', options => { return textToText({ ...options, content: options.input, @@ -38,7 +60,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('improveWriting', options => { + AIProvider.provide('improveWriting', options => { return textToText({ ...options, content: options.input, @@ -46,7 +68,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('improveGrammar', options => { + AIProvider.provide('improveGrammar', options => { return textToText({ ...options, content: options.input, @@ -54,7 +76,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('fixSpelling', options => { + AIProvider.provide('fixSpelling', options => { return textToText({ ...options, content: options.input, @@ -62,7 +84,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('createHeadings', options => { + AIProvider.provide('createHeadings', options => { return textToText({ ...options, content: options.input, @@ -70,7 +92,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('makeLonger', options => { + AIProvider.provide('makeLonger', options => { return textToText({ ...options, content: options.input, @@ -78,7 +100,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('makeShorter', options => { + AIProvider.provide('makeShorter', options => { return textToText({ ...options, content: options.input, @@ -86,7 +108,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('checkCodeErrors', options => { + AIProvider.provide('checkCodeErrors', options => { return textToText({ ...options, content: options.input, @@ -94,7 +116,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('explainCode', options => { + AIProvider.provide('explainCode', options => { return textToText({ ...options, content: options.input, @@ -102,7 +124,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('writeArticle', options => { + AIProvider.provide('writeArticle', options => { return textToText({ ...options, content: options.input, @@ -110,7 +132,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('writeTwitterPost', options => { + AIProvider.provide('writeTwitterPost', options => { return textToText({ ...options, content: options.input, @@ -118,7 +140,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('writePoem', options => { + AIProvider.provide('writePoem', options => { return textToText({ ...options, content: options.input, @@ -126,7 +148,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('writeOutline', options => { + AIProvider.provide('writeOutline', options => { return textToText({ ...options, content: options.input, @@ -134,7 +156,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('writeBlogPost', options => { + AIProvider.provide('writeBlogPost', options => { return textToText({ ...options, content: options.input, @@ -142,7 +164,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('brainstorm', options => { + AIProvider.provide('brainstorm', options => { return textToText({ ...options, content: options.input, @@ -150,7 +172,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('findActions', options => { + AIProvider.provide('findActions', options => { return textToText({ ...options, content: options.input, @@ -158,7 +180,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('brainstormMindmap', options => { + AIProvider.provide('brainstormMindmap', options => { return textToText({ ...options, content: options.input, @@ -166,7 +188,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('explain', options => { + AIProvider.provide('explain', options => { return textToText({ ...options, content: options.input, @@ -174,7 +196,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('explainImage', options => { + AIProvider.provide('explainImage', options => { return textToText({ ...options, content: options.input, @@ -182,7 +204,7 @@ export function setupAIProvider() { }); }); - AIProvider.provideAction('makeItReal', options => { + AIProvider.provide('makeItReal', options => { return textToText({ ...options, promptName: 'Make it real', @@ -192,4 +214,25 @@ export function setupAIProvider() { 'Here are the latest wireframes. Could you make a new website based on these wireframes and notes and send back just the html file?', }); }); + + AIProvider.provide('histories', { + actions: async ( + workspaceId: string, + docId?: string + ): Promise => { + // @ts-expect-error - 'action' is missing in server impl + return ( + (await listHistories(workspaceId, docId, { + action: true, + })) ?? [] + ); + }, + chats: async ( + workspaceId: string, + docId?: string + ): Promise => { + // @ts-expect-error - 'action' is missing in server impl + return (await listHistories(workspaceId, docId)) ?? []; + }, + }); } diff --git a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/request.ts b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/request.ts index eedf20557b..1e4541b7ab 100644 --- a/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/request.ts +++ b/packages/frontend/core/src/components/blocksuite/block-suite-editor/ai/request.ts @@ -7,7 +7,7 @@ const TIMEOUT = 50000; const client = new CopilotClient(); -function readBlobAsURL(blob: Blob) { +function readBlobAsURL(blob: Blob | File) { return new Promise((resolve, reject) => { const reader = new FileReader(); reader.onload = e => { @@ -25,28 +25,48 @@ function readBlobAsURL(blob: Blob) { export type TextToTextOptions = { docId: string; workspaceId: string; - promptName: PromptKey; + promptName?: PromptKey; + sessionId?: string | Promise; content?: string; - attachments?: (string | Blob)[]; + attachments?: (string | Blob | File)[]; params?: Record; timeout?: number; stream?: boolean; }; +export function createChatSession({ + workspaceId, + docId, +}: { + workspaceId: string; + docId: string; +}) { + return client.createSession({ + workspaceId, + docId, + promptName: 'debug:chat:gpt4', + }); +} + async function createSessionMessage({ docId, workspaceId, promptName, content, + sessionId: providedSessionId, attachments, params, }: TextToTextOptions) { const hasAttachments = attachments && attachments.length > 0; - const session = await client.createSession({ - workspaceId, - docId, - promptName, - }); + if (!promptName && !providedSessionId) { + throw new Error('promptName or sessionId is required'); + } + const sessionId = await (providedSessionId ?? + client.createSession({ + workspaceId, + docId, + promptName: promptName as string, + })); if (hasAttachments) { const normalizedAttachments = await Promise.all( attachments.map(async attachment => { @@ -58,19 +78,19 @@ async function createSessionMessage({ }) ); const messageId = await client.createMessage({ - sessionId: session, + sessionId: sessionId, content, attachments: normalizedAttachments, params, }); return { messageId, - session, + sessionId, }; } else if (content) { return { message: content, - session, + sessionId, }; } else { throw new Error('No content or attachments provided'); @@ -84,6 +104,7 @@ export function textToText({ content, attachments, params, + sessionId, stream, timeout = TIMEOUT, }: TextToTextOptions) { @@ -97,10 +118,11 @@ export function textToText({ content, attachments, params, + sessionId, }); const eventSource = client.chatTextStream({ - sessionId: message.session, + sessionId: message.sessionId, messageId: message.messageId, message: message.message, }); @@ -123,9 +145,10 @@ export function textToText({ content, attachments, params, - }).then(message => { - return client.chatText({ - sessionId: message.session, + sessionId, + }).then(async message => { + return await client.chatText({ + sessionId: message.sessionId, messageId: message.messageId, message: message.message, }); @@ -133,3 +156,5 @@ export function textToText({ ]); } } + +export const listHistories = client.getHistories; diff --git a/packages/frontend/core/src/hooks/affine/use-current-user.ts b/packages/frontend/core/src/hooks/affine/use-current-user.ts index 76e1465e58..b74414b2c9 100644 --- a/packages/frontend/core/src/hooks/affine/use-current-user.ts +++ b/packages/frontend/core/src/hooks/affine/use-current-user.ts @@ -153,7 +153,7 @@ export function useCurrentUser(): CheckedUser { const user = session.user; dispatcher({ type: 'update', payload: user }); // todo: move this to a better place! - AIProvider.provideUserInfo(() => { + AIProvider.provide('userInfo', () => { return user; }); } else { diff --git a/tests/affine-local/e2e/local-first-workspace-list.spec.ts b/tests/affine-local/e2e/local-first-workspace-list.spec.ts index 8e74758e68..8052cb5134 100644 --- a/tests/affine-local/e2e/local-first-workspace-list.spec.ts +++ b/tests/affine-local/e2e/local-first-workspace-list.spec.ts @@ -76,9 +76,9 @@ test('create multi workspace in the workspace list', async ({ await page.waitForTimeout(1000); { - //check workspace list length - const workspaceCards = await page.$$('data-testid=workspace-card'); - expect(workspaceCards.length).toBe(3); + // check workspace list length + const workspaceCards = page.getByTestId('workspace-card'); + await expect(workspaceCards).toHaveCount(3); } await page.reload(); @@ -118,7 +118,20 @@ test('create multi workspace in the workspace list', async ({ } ); await page.mouse.up(); - await page.waitForTimeout(1000); + + // check workspace list order + await page.waitForFunction( + () => { + const cards = document.querySelectorAll('[data-testid="workspace-card"]'); + return ( + cards[1].textContent?.includes('New Workspace 3') && + cards[2].textContent?.includes('New Workspace 2') + ); + }, + [], + { timeout: 5000 } + ); + await page.reload(); await openWorkspaceListModal(page); @@ -127,7 +140,7 @@ test('create multi workspace in the workspace list', async ({ { await page.waitForTimeout(1000); const workspaceCards = page.getByTestId('workspace-card'); - expect(await workspaceCards.count()).toBe(3); + await expect(workspaceCards).toHaveCount(3); } const workspaceChangePromise = page.evaluate(() => {