From 181ccf5a45e0b9ba120fb28a2c9276bb74388b50 Mon Sep 17 00:00:00 2001 From: DarkSky <25152247+darkskygit@users.noreply.github.com> Date: Mon, 7 Jul 2025 23:05:02 +0800 Subject: [PATCH] fix(server): rerank scores calc (#13016) fix AI-257 --- .../src/__tests__/copilot-provider.spec.ts | 52 +++++++++ .../src/plugins/copilot/context/service.ts | 4 +- .../src/plugins/copilot/embedding/client.ts | 11 +- .../src/plugins/copilot/embedding/types.ts | 22 ++-- .../src/plugins/copilot/prompt/prompts.ts | 4 +- .../copilot/providers/gemini/gemini.ts | 36 +++++++ .../copilot/providers/gemini/generative.ts | 16 ++- .../copilot/providers/gemini/vertex.ts | 11 ++ .../src/plugins/copilot/providers/openai.ts | 37 ++++--- .../src/plugins/copilot/providers/utils.ts | 6 ++ .../copilot/tools/doc-keyword-search.ts | 8 +- .../copilot/tools/doc-semantic-search.ts | 14 +-- .../entities/embedding-progress.ts | 2 +- .../view/embedding-settings.tsx | 31 ++++-- .../e2e/settings/embedding.spec.ts | 100 ++++++++++-------- .../e2e/utils/chat-panel-utils.ts | 29 +++-- .../e2e/utils/settings-panel-utils.ts | 49 +++++++++ tests/kit/src/utils/workspace.ts | 6 +- 18 files changed, 329 insertions(+), 109 deletions(-) diff --git a/packages/backend/server/src/__tests__/copilot-provider.spec.ts b/packages/backend/server/src/__tests__/copilot-provider.spec.ts index 42bd10fd46..6a5c9880cf 100644 --- a/packages/backend/server/src/__tests__/copilot-provider.spec.ts +++ b/packages/backend/server/src/__tests__/copilot-provider.spec.ts @@ -829,3 +829,55 @@ for (const { name, content, verifier } of workflows) { } ); } + +// ==================== rerank ==================== + +test( + 'should be able to rerank message chunks', + runIfCopilotConfigured, + async t => { + const { factory, prompt } = t.context; + + await retry('rerank', t, async t => { + const query = 'Is this content relevant to programming?'; + const embeddings = [ + 'How to write JavaScript code for web development.', + 'Today is a beautiful sunny day for walking in the park.', + 'Python is a popular programming language for data science.', + 'The weather forecast predicts rain for the weekend.', + 'JavaScript frameworks like React and Angular are widely used.', + 'Cooking recipes can be found in many online blogs.', + 'Machine learning algorithms are essential for AI development.', + 'The latest smartphone models have impressive camera features.', + 'Learning to code can open up many career opportunities.', + 'The stock market is experiencing significant fluctuations.', + ]; + + const p = (await prompt.get('Rerank results'))!; + t.assert(p, 'should have prompt for rerank'); + const provider = (await factory.getProviderByModel(p.model))!; + t.assert(provider, 'should have provider for rerank'); + + const scores = await provider.rerank( + { modelId: p.model }, + embeddings.map(e => p.finish({ query, doc: e })) + ); + + t.is(scores.length, 10, 'should return scores for all chunks'); + + for (const score of scores) { + t.assert( + typeof score === 'number' && score >= 0 && score <= 1, + `score should be a number between 0 and 1, got ${score}` + ); + } + + t.log('Rerank scores:', scores); + t.is( + scores.filter(s => s > 0.5).length, + 4, + 'should have 4 related chunks' + ); + }); + } +); diff --git a/packages/backend/server/src/plugins/copilot/context/service.ts b/packages/backend/server/src/plugins/copilot/context/service.ts index e20e619c94..9d2e55658c 100644 --- a/packages/backend/server/src/plugins/copilot/context/service.ts +++ b/packages/backend/server/src/plugins/copilot/context/service.ts @@ -215,7 +215,6 @@ export class CopilotContextService implements OnApplicationBootstrap { topK * 2, threshold ), - this.models.copilotContext.matchWorkspaceEmbedding( embedding, workspaceId, @@ -237,8 +236,9 @@ export class CopilotContextService implements OnApplicationBootstrap { !fileChunks.length && !workspaceChunks.length && !scopedWorkspaceChunks?.length - ) + ) { return []; + } return await this.embeddingClient.reRank( content, diff --git a/packages/backend/server/src/plugins/copilot/embedding/client.ts b/packages/backend/server/src/plugins/copilot/embedding/client.ts index 720e9e927c..cdc345e2f2 100644 --- a/packages/backend/server/src/plugins/copilot/embedding/client.ts +++ b/packages/backend/server/src/plugins/copilot/embedding/client.ts @@ -20,6 +20,7 @@ import { type ReRankResult, } from './types'; +const EMBEDDING_MODEL = 'gemini-embedding-001'; const RERANK_PROMPT = 'Rerank results'; class ProductionEmbeddingClient extends EmbeddingClient { @@ -34,6 +35,7 @@ class ProductionEmbeddingClient extends EmbeddingClient { override async configured(): Promise { const embedding = await this.providerFactory.getProvider({ + modelId: EMBEDDING_MODEL, outputType: ModelOutputType.Embedding, }); const result = Boolean(embedding); @@ -60,6 +62,7 @@ class ProductionEmbeddingClient extends EmbeddingClient { async getEmbeddings(input: string[]): Promise { const provider = await this.getProvider({ + modelId: EMBEDDING_MODEL, outputType: ModelOutputType.Embedding, }); this.logger.verbose( @@ -109,9 +112,9 @@ class ProductionEmbeddingClient extends EmbeddingClient { ); try { - return ranks.map((score, i) => ({ - chunk: embeddings[i].content, - targetId: this.getTargetId(embeddings[i]), + return ranks.map((score, chunk) => ({ + chunk, + targetId: this.getTargetId(embeddings[chunk]), score, })); } catch (error) { @@ -171,7 +174,7 @@ class ProductionEmbeddingClient extends EmbeddingClient { const highConfidenceChunks = ranks .flat() .toSorted((a, b) => b.score - a.score) - .filter(r => r.score > 5) + .filter(r => r.score > 0.5) .map(r => chunks[`${r.targetId}:${r.chunk}`]) .filter(Boolean); diff --git a/packages/backend/server/src/plugins/copilot/embedding/types.ts b/packages/backend/server/src/plugins/copilot/embedding/types.ts index fc9368fe44..751f23f75c 100644 --- a/packages/backend/server/src/plugins/copilot/embedding/types.ts +++ b/packages/backend/server/src/plugins/copilot/embedding/types.ts @@ -176,17 +176,15 @@ export abstract class EmbeddingClient { } const ReRankItemSchema = z.object({ - scores: z.object({ - chunk: z.string().describe('The chunk index of the search result.'), - targetId: z.string().describe('The id of the target.'), - score: z - .number() - .min(0) - .max(10) - .describe( - 'The relevance score of the results should be 0-10, with 0 being the least relevant and 10 being the most relevant.' - ), - }), + chunk: z.number().describe('The chunk index of the search result.'), + targetId: z.string().describe('The id of the target.'), + score: z + .number() + .min(0) + .max(10) + .describe( + 'The relevance score of the results should be 0-10, with 0 being the least relevant and 10 being the most relevant.' + ), }); -export type ReRankResult = z.infer['scores'][]; +export type ReRankResult = z.infer[]; diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index c939bbb677..7607aaff3b 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -346,7 +346,7 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr }, { role: 'user', - content: `: Given a web search query, retrieve relevant passages that answer the query\n: {query}\n: {doc}`, + content: `: Given a document search result, determine whether the result is relevant to the query.\n: {{query}}\n: {{doc}}`, }, ], }, @@ -1676,7 +1676,7 @@ This sentence contains information from the first source[^1]. This sentence refe Before starting Tool calling, you need to follow: - DO NOT embed a tool call mid-sentence. -- When searching for information, searching web & searching the user's Workspace information. +- When searching for unknown information or keyword, prioritize searching the user's workspace. - Depending on the complexity of the question and the information returned by the search tools, you can call different tools multiple times to search. diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts index 7cb11fece4..9822653437 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts @@ -5,6 +5,7 @@ import type { import type { GoogleVertexProvider } from '@ai-sdk/google-vertex'; import { AISDKError, + embedMany, generateObject, generateText, JSONParseError, @@ -20,6 +21,7 @@ import { import { CopilotProvider } from '../provider'; import type { CopilotChatOptions, + CopilotEmbeddingOptions, CopilotImageOptions, CopilotProviderModel, ModelConditions, @@ -211,6 +213,40 @@ export abstract class GeminiProvider extends CopilotProvider { } } + override async embedding( + cond: ModelConditions, + messages: string | string[], + options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } + ): Promise { + messages = Array.isArray(messages) ? messages : [messages]; + const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; + await this.checkParams({ embeddings: messages, cond: fullCond, options }); + const model = this.selectModel(fullCond); + + try { + metrics.ai + .counter('generate_embedding_calls') + .add(1, { model: model.id }); + + const modelInstance = this.instance.textEmbeddingModel(model.id, { + outputDimensionality: options.dimensions || DEFAULT_DIMENSIONS, + taskType: 'RETRIEVAL_DOCUMENT', + }); + + const { embeddings } = await embedMany({ + model: modelInstance, + values: messages, + }); + + return embeddings.filter(v => v && Array.isArray(v)); + } catch (e: any) { + metrics.ai + .counter('generate_embedding_errors') + .add(1, { model: model.id }); + throw this.handleError(e); + } + } + private async getFullStream( model: CopilotProviderModel, messages: PromptMessage[], diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts index 3e4f0d7528..32b4575917 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts @@ -71,8 +71,8 @@ export class GeminiGenerativeProvider extends GeminiProvider { }, ], }, + { + name: 'Gemini Embedding', + id: 'gemini-embedding-001', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + defaultForOutputType: true, + }, + ], + }, ]; protected instance!: GoogleVertexProvider; diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 8f6e3d232e..f7d8d9c8be 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -450,7 +450,8 @@ export class OpenAIProvider extends CopilotProvider { const fullCond = { ...cond, outputType: ModelOutputType.Text }; await this.checkParams({ messages: [], cond: fullCond, options }); const model = this.selectModel(fullCond); - const instance = this.#instance.responses(model.id); + // get the log probability of "yes"/"no" + const instance = this.#instance(model.id, { logprobs: 16 }); const scores = await Promise.all( chunkMessages.map(async messages => { @@ -461,29 +462,37 @@ export class OpenAIProvider extends CopilotProvider { system, messages: msgs, temperature: 0, - maxTokens: 1, + maxTokens: 16, providerOptions: { openai: { ...this.getOpenAIOptions(options, model.id), - // get the log probability of "yes"/"no" - logprobs: 2, }, }, - maxSteps: 1, abortSignal: options.signal, }); - const top = (logprobs?.[0]?.topLogprobs ?? []).reduce( - (acc, item) => { - acc[item.token] = item.logprob; - return acc; - }, - {} as Record + const topMap: Record = ( + logprobs?.[0]?.topLogprobs ?? [] + ).reduce>( + (acc, { token, logprob }) => ({ ...acc, [token]: logprob }), + {} ); - // OpenAI often includes a leading space, so try matching both ' yes' and 'yes' - const logYes = top[' yes'] ?? top['yes'] ?? Number.NEGATIVE_INFINITY; - const logNo = top[' no'] ?? top['no'] ?? Number.NEGATIVE_INFINITY; + const findLogProb = (token: string): number => { + // OpenAI often includes a leading space, so try matching '.yes', '_yes', ' yes' and 'yes' + return [`.${token}`, `_${token}`, ` ${token}`, token] + .flatMap(v => [v, v.toLowerCase(), v.toUpperCase()]) + .reduce( + (best, key) => + (topMap[key] ?? Number.NEGATIVE_INFINITY) > best + ? topMap[key] + : best, + Number.NEGATIVE_INFINITY + ); + }; + + const logYes = findLogProb('Yes'); + const logNo = findLogProb('No'); const pYes = Math.exp(logYes); const pNo = Math.exp(logNo); diff --git a/packages/backend/server/src/plugins/copilot/providers/utils.ts b/packages/backend/server/src/plugins/copilot/providers/utils.ts index fee397bd32..8539a038f5 100644 --- a/packages/backend/server/src/plugins/copilot/providers/utils.ts +++ b/packages/backend/server/src/plugins/copilot/providers/utils.ts @@ -501,6 +501,12 @@ export class TextStreamParser { case 'doc_semantic_search': { if (Array.isArray(chunk.result)) { result += `\nFound ${chunk.result.length} document${chunk.result.length !== 1 ? 's' : ''} related to “${chunk.args.query}”.\n`; + } else if (typeof chunk.result === 'string') { + result += `\n${chunk.result}\n`; + } else { + this.logger.warn( + `Unexpected result type for doc_semantic_search: ${chunk.result?.message || 'Unknown error'}` + ); } break; } diff --git a/packages/backend/server/src/plugins/copilot/tools/doc-keyword-search.ts b/packages/backend/server/src/plugins/copilot/tools/doc-keyword-search.ts index 74571c1e7f..a565b81856 100644 --- a/packages/backend/server/src/plugins/copilot/tools/doc-keyword-search.ts +++ b/packages/backend/server/src/plugins/copilot/tools/doc-keyword-search.ts @@ -39,9 +39,13 @@ export const createDocKeywordSearchTool = ( ) => { return tool({ description: - 'Search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based lookup is sufficient.', + 'Fuzzy search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based or keyword-base lookup is sufficient.', parameters: z.object({ - query: z.string().describe('The query to search for'), + query: z + .string() + .describe( + 'The query to search for, e.g. "meeting notes" or "project plan".' + ), }), execute: async ({ query }) => { try { diff --git a/packages/backend/server/src/plugins/copilot/tools/doc-semantic-search.ts b/packages/backend/server/src/plugins/copilot/tools/doc-semantic-search.ts index c2257c0f2f..299b4dd2d2 100644 --- a/packages/backend/server/src/plugins/copilot/tools/doc-semantic-search.ts +++ b/packages/backend/server/src/plugins/copilot/tools/doc-semantic-search.ts @@ -19,13 +19,14 @@ export const buildDocSearchGetter = ( abortSignal?: AbortSignal ) => { if (!options || !query?.trim() || !options.user || !options.workspace) { - return undefined; + return `Invalid search parameters.`; } const canAccess = await ac .user(options.user) .workspace(options.workspace) .can('Workspace.Read'); - if (!canAccess) return undefined; + if (!canAccess) + return 'You do not have permission to access this workspace.'; const [chunks, contextChunks] = await Promise.all([ context.matchWorkspaceAll(options.workspace, query, 10, abortSignal), docContext?.matchFiles(query, 10, abortSignal) ?? [], @@ -42,7 +43,8 @@ export const buildDocSearchGetter = ( if (contextChunks.length) { fileChunks.push(...contextChunks); } - if (!docChunks.length && !fileChunks.length) return undefined; + if (!docChunks.length && !fileChunks.length) + return `No results found for "${query}".`; return [...fileChunks, ...docChunks]; }; return searchDocs; @@ -52,16 +54,16 @@ export const createDocSemanticSearchTool = ( searchDocs: ( query: string, abortSignal?: AbortSignal - ) => Promise + ) => Promise ) => { return tool({ description: - 'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; call this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts).', + 'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; use this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts).', parameters: z.object({ query: z .string() .describe( - 'The query statement to search for, e.g. "What is the capital of France?"' + 'The query statement to search for, e.g. "What is the capital of France?"\nWhen querying specific terms or IDs, you should provide the complete string instead of separating it with delimiters.\nFor example, if a user wants to look up the ID "sicDoe1is", use "What is sicDoe1is" instead of "si code 1is".' ), }), execute: async ({ query }, options) => { diff --git a/packages/frontend/core/src/modules/workspace-indexer-embedding/entities/embedding-progress.ts b/packages/frontend/core/src/modules/workspace-indexer-embedding/entities/embedding-progress.ts index 7e66e9c4cc..eb803c9fee 100644 --- a/packages/frontend/core/src/modules/workspace-indexer-embedding/entities/embedding-progress.ts +++ b/packages/frontend/core/src/modules/workspace-indexer-embedding/entities/embedding-progress.ts @@ -60,7 +60,7 @@ export class EmbeddingProgress extends Entity { smartRetry(), mergeMap(value => { this.progress$.next(value); - if (value && value.embedded === value.total) { + if (value && value.embedded === value.total && value.total > 0) { this.stopEmbeddingProgressPolling(); } return EMPTY; diff --git a/packages/frontend/core/src/modules/workspace-indexer-embedding/view/embedding-settings.tsx b/packages/frontend/core/src/modules/workspace-indexer-embedding/view/embedding-settings.tsx index 689d900bc1..7d56c40b94 100644 --- a/packages/frontend/core/src/modules/workspace-indexer-embedding/view/embedding-settings.tsx +++ b/packages/frontend/core/src/modules/workspace-indexer-embedding/view/embedding-settings.tsx @@ -62,18 +62,25 @@ const EmbeddingCloud: React.FC<{ disabled: boolean }> = ({ disabled }) => { option: checked ? 'on' : 'off', }); - embeddingService.embeddingEnabled.setEnabled(checked).catch(error => { - const err = UserFriendlyError.fromAny(error); - notify.error({ - title: - t[ - 'com.affine.settings.workspace.indexer-embedding.embedding.switch.error' - ](), - message: t[`error.${err.name}`](err.data), + embeddingService.embeddingEnabled + .setEnabled(checked) + .then(() => { + if (checked) { + embeddingService.embeddingProgress.startEmbeddingProgressPolling(); + } + }) + .catch(error => { + const err = UserFriendlyError.fromAny(error); + notify.error({ + title: + t[ + 'com.affine.settings.workspace.indexer-embedding.embedding.switch.error' + ](), + message: t[`error.${err.name}`](err.data), + }); }); - }); }, - [embeddingService.embeddingEnabled, t] + [embeddingService.embeddingEnabled, embeddingService.embeddingProgress, t] ); const handleAttachmentUpload = useCallback( @@ -84,8 +91,10 @@ const EmbeddingCloud: React.FC<{ disabled: boolean }> = ({ disabled }) => { docType: file.type, }); embeddingService.additionalAttachments.addAttachments([file]); + // Restart polling to track progress of newly uploaded files + embeddingService.embeddingProgress.startEmbeddingProgressPolling(); }, - [embeddingService.additionalAttachments] + [embeddingService.additionalAttachments, embeddingService.embeddingProgress] ); const handleAttachmentsDelete = useCallback( diff --git a/tests/affine-cloud-copilot/e2e/settings/embedding.spec.ts b/tests/affine-cloud-copilot/e2e/settings/embedding.spec.ts index d4c86297f8..080556cf46 100644 --- a/tests/affine-cloud-copilot/e2e/settings/embedding.spec.ts +++ b/tests/affine-cloud-copilot/e2e/settings/embedding.spec.ts @@ -1,15 +1,13 @@ import { createLocalWorkspace } from '@affine-test/kit/utils/workspace'; +import { faker } from '@faker-js/faker'; import { expect } from '@playwright/test'; import { test } from '../base/base-test'; -test.describe.configure({ mode: 'serial' }); - test.describe('AISettings/Embedding', () => { test.beforeEach(async ({ loggedInPage: page, utils }) => { await utils.testUtils.setupTestEnvironment(page); await utils.chatPanel.openChatPanel(page); - await utils.settings.openSettingsPanel(page); }); test.afterEach(async ({ loggedInPage: page, utils }) => { @@ -23,6 +21,7 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await utils.settings.openSettingsPanel(page); await utils.settings.waitForWorkspaceEmbeddingSwitchToBe(page, true); }); @@ -30,6 +29,7 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); await utils.settings.disableWorkspaceEmbedding(page); await utils.settings.waitForWorkspaceEmbeddingSwitchToBe(page, false); @@ -39,6 +39,7 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await utils.settings.openSettingsPanel(page); await utils.settings.disableWorkspaceEmbedding(page); await utils.settings.enableWorkspaceEmbedding(page); await utils.settings.waitForWorkspaceEmbeddingSwitchToBe(page, true); @@ -99,6 +100,7 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); await utils.settings.disableWorkspaceEmbedding(page); await utils.settings.waitForWorkspaceEmbeddingSwitchToBe(page, false); @@ -116,6 +118,7 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); await page.getByTestId('embedding-progress-wrapper'); @@ -134,9 +137,13 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await createLocalWorkspace({ name: 'test' }, page, false, 'affine-cloud'); + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); - const textContent1 = 'WorkspaceEBEEE is a cute cat'; - const textContent2 = 'WorkspaceEBFFF is a cute dog'; + const randomStr1 = Math.random().toString(36).substring(2, 6); + const randomStr2 = Math.random().toString(36).substring(2, 6); + const textContent1 = `Workspace${randomStr1} is a cute cat`; + const textContent2 = `Workspace${randomStr2} is a cute dog`; const buffer1 = Buffer.from(textContent1); const buffer2 = Buffer.from(textContent2); const attachments = [ @@ -164,15 +171,6 @@ test.describe('AISettings/Embedding', () => { await utils.settings.uploadWorkspaceEmbedding(page, attachments); - const attachmentList = await page.getByTestId( - 'workspace-embedding-setting-attachment-list' - ); - - // Persisted - await expect( - attachmentList.getByTestId('workspace-embedding-setting-attachment-item') - ).toHaveCount(2); - await client.send('Network.emulateNetworkConditions', { offline: false, latency: 0, @@ -180,19 +178,19 @@ test.describe('AISettings/Embedding', () => { uploadThroughput: -1, }); - await utils.settings.closeSettingsPanel(page); + await utils.settings.waitForFileEmbeddingReadiness(page, 2); - await page.waitForTimeout(5000); // wait for the embedding to be ready + await utils.settings.closeSettingsPanel(page); await utils.chatPanel.makeChat( page, - 'What is WorkspaceEBEEE? What is WorkspaceEBFFF?' + `What is Workspace${randomStr1}? What is Workspace${randomStr2}?` ); await utils.chatPanel.waitForHistory(page, [ { role: 'user', - content: 'What is WorkspaceEBEEE? What is WorkspaceEBFFF?', + content: `What is Workspace${randomStr1}? What is Workspace${randomStr2}?`, }, { role: 'assistant', @@ -203,8 +201,8 @@ test.describe('AISettings/Embedding', () => { await expect(async () => { const { content, message } = await utils.chatPanel.getLatestAssistantMessage(page); - expect(content).toMatch(/WorkspaceEBEEE.*cat/); - expect(content).toMatch(/WorkspaceEBFFF.*dog/); + expect(content).toMatch(new RegExp(`Workspace${randomStr1}.*cat`)); + expect(content).toMatch(new RegExp(`Workspace${randomStr2}.*dog`)); expect(await message.locator('affine-footnote-node').count()).toBe(2); }).toPass({ timeout: 20000 }); }); @@ -213,6 +211,8 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await createLocalWorkspace({ name: 'test' }, page, false, 'affine-cloud'); + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); const attachments = [ { @@ -226,7 +226,7 @@ test.describe('AISettings/Embedding', () => { await utils.settings.uploadWorkspaceEmbedding(page, attachments); - const attachmentList = await page.getByTestId( + const attachmentList = page.getByTestId( 'workspace-embedding-setting-attachment-list' ); @@ -243,45 +243,42 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await createLocalWorkspace({ name: 'test' }, page, false, 'affine-cloud'); + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); - const hobby1 = Buffer.from('Jerry-Affine love climbing'); - const hobby2 = Buffer.from('Jerry-Affine love skating'); + const person = faker.person.fullName(); + + const hobby1 = Buffer.from(`${person} love climbing`); + const hobby2 = Buffer.from(`${person} love skating`); const attachments = [ { - name: 'jerry-affine-hobby.txt', + name: 'hobby.txt', mimeType: 'text/plain', buffer: hobby1, }, ]; await utils.settings.uploadWorkspaceEmbedding(page, attachments); - const attachmentList = await page.getByTestId( - 'workspace-embedding-setting-attachment-list' - ); - await expect( - attachmentList.getByTestId('workspace-embedding-setting-attachment-item') - ).toHaveCount(1); + await utils.settings.waitForFileEmbeddingReadiness(page, 1); await utils.settings.closeSettingsPanel(page); - await page.waitForTimeout(5000); // wait for the embedding to be ready - await utils.chatPanel.chatWithAttachments( page, [ { - name: 'jerry-affine-hobby2.txt', + name: 'hobby2.txt', mimeType: 'text/plain', buffer: hobby2, }, ], - 'What is Jerry-Affine hobby?' + `What is ${person}'s hobby?` ); await utils.chatPanel.waitForHistory(page, [ { role: 'user', - content: 'What is Jerry-Affine hobby?', + content: `What is ${person}'s hobby?`, }, { role: 'assistant', @@ -302,6 +299,8 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await createLocalWorkspace({ name: 'test' }, page, false, 'affine-cloud'); + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); const attachments = Array.from({ length: 11 }, (_, i) => ({ name: `document${i + 1}.txt`, @@ -318,11 +317,11 @@ test.describe('AISettings/Embedding', () => { await expect( attachmentList.getByTestId('workspace-embedding-setting-attachment-item') ).toHaveCount(10); - const pagination = await attachmentList.getByRole('navigation'); - const currentPage = await pagination.locator('li.active'); + const pagination = attachmentList.getByRole('navigation'); + const currentPage = pagination.locator('li.active'); await expect(currentPage).toHaveText('1'); - const page2 = await pagination.locator('li').nth(2); + const page2 = pagination.locator('li').nth(2); await page2.click(); await expect( @@ -339,8 +338,11 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await createLocalWorkspace({ name: 'test' }, page, false, 'affine-cloud'); + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); - const textContent = 'WorkspaceEBEEE is a cute cat'; + const randomStr1 = Math.random().toString(36).substring(2, 6); + const textContent = `Workspace${randomStr1} is a cute cat`; const attachments = [ { name: 'document1.txt', @@ -350,7 +352,7 @@ test.describe('AISettings/Embedding', () => { ]; await utils.settings.uploadWorkspaceEmbedding(page, attachments); - const attachmentList = await page.getByTestId( + const attachmentList = page.getByTestId( 'workspace-embedding-setting-attachment-list' ); await expect( @@ -363,8 +365,11 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await createLocalWorkspace({ name: 'test' }, page, false, 'affine-cloud'); + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); - const textContent = 'WorkspaceEBEEE is a cute cat'; + const randomStr1 = Math.random().toString(36).substring(2, 6); + const textContent = `Workspace${randomStr1} is a cute cat`; const attachments = [ { name: 'document1.txt', @@ -374,7 +379,7 @@ test.describe('AISettings/Embedding', () => { ]; await utils.settings.uploadWorkspaceEmbedding(page, attachments); - const attachmentList = await page.getByTestId( + const attachmentList = page.getByTestId( 'workspace-embedding-setting-attachment-list' ); await expect( @@ -393,8 +398,11 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await createLocalWorkspace({ name: 'test' }, page, false, 'affine-cloud'); + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); - const textContent = 'WorkspaceEBEEE is a cute cat'; + const randomStr1 = Math.random().toString(36).substring(2, 6); + const textContent = `Workspace${randomStr1} is a cute cat`; const attachments = [ { name: 'document1.txt', @@ -413,6 +421,7 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); await utils.settings.closeSettingsPanel(page); await utils.editor.createDoc( @@ -426,7 +435,9 @@ test.describe('AISettings/Embedding', () => { 'WBIgnoreFFF is a cute dog' ); - await page.waitForTimeout(5000); // wait for the embedding to be ready + await utils.settings.openSettingsPanel(page); + await utils.settings.waitForEmbeddingComplete(page); + await utils.settings.closeSettingsPanel(page); await utils.chatPanel.makeChat( page, @@ -487,6 +498,7 @@ test.describe('AISettings/Embedding', () => { loggedInPage: page, utils, }) => { + await utils.settings.openSettingsPanel(page); await utils.settings.enableWorkspaceEmbedding(page); await utils.settings.closeSettingsPanel(page); diff --git a/tests/affine-cloud-copilot/e2e/utils/chat-panel-utils.ts b/tests/affine-cloud-copilot/e2e/utils/chat-panel-utils.ts index 1e97025cba..2397a9d605 100644 --- a/tests/affine-cloud-copilot/e2e/utils/chat-panel-utils.ts +++ b/tests/affine-cloud-copilot/e2e/utils/chat-panel-utils.ts @@ -199,9 +199,11 @@ export class ChatPanelUtils { } public static async chatWithDoc(page: Page, docName: string) { - const withButton = await page.getByTestId('chat-panel-with-button'); + const withButton = page.getByTestId('chat-panel-with-button'); + await withButton.hover(); await withButton.click(); - const withMenu = await page.getByTestId('ai-add-popover'); + const withMenu = page.getByTestId('ai-add-popover'); + await withMenu.waitFor({ state: 'visible' }); await withMenu.getByText(docName).click(); await page.getByTestId('chat-panel-chips').getByText(docName); } @@ -221,9 +223,20 @@ export class ChatPanelUtils { await withButton.hover(); await withButton.click(); const withMenu = page.getByTestId('ai-add-popover'); + await withMenu.waitFor({ state: 'visible' }); await withMenu.getByTestId('ai-chat-with-files').click(); const fileChooser = await fileChooserPromise; await fileChooser.setFiles(attachment); + + await expect(async () => { + const states = await page + .getByTestId('chat-panel-chip') + .evaluateAll(elements => + elements.map(el => el.getAttribute('data-state')) + ); + + expect(states.every(state => state === 'finished')).toBe(true); + }).toPass({ timeout: 20000 }); } await expect(async () => { const states = await page @@ -267,9 +280,11 @@ export class ChatPanelUtils { public static async chatWithTags(page: Page, tags: string[]) { for (const tag of tags) { - const withButton = await page.getByTestId('chat-panel-with-button'); + const withButton = page.getByTestId('chat-panel-with-button'); + await withButton.hover(); await withButton.click(); - const withMenu = await page.getByTestId('ai-add-popover'); + const withMenu = page.getByTestId('ai-add-popover'); + await withMenu.waitFor({ state: 'visible' }); await withMenu.getByTestId('ai-chat-with-tags').click(); await withMenu.getByText(tag).click(); await page.getByTestId('chat-panel-chips').getByText(tag); @@ -282,9 +297,11 @@ export class ChatPanelUtils { public static async chatWithCollections(page: Page, collections: string[]) { for (const collection of collections) { - const withButton = await page.getByTestId('chat-panel-with-button'); + const withButton = page.getByTestId('chat-panel-with-button'); + await withButton.hover(); await withButton.click(); - const withMenu = await page.getByTestId('ai-add-popover'); + const withMenu = page.getByTestId('ai-add-popover'); + await withMenu.waitFor({ state: 'visible' }); await withMenu.getByTestId('ai-chat-with-collections').click(); await withMenu.getByText(collection).click(); await page.getByTestId('chat-panel-chips').getByText(collection); diff --git a/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts b/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts index e3d2ac7d4d..e6e47585c9 100644 --- a/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts +++ b/tests/affine-cloud-copilot/e2e/utils/settings-panel-utils.ts @@ -207,4 +207,53 @@ export class SettingsPanelUtils { await searcher.getByTestId('doc-selector-confirm-button').click(); } } + + private static async waitForEmbeddingStatus( + page: Page, + timeout: number, + status = 'synced' + ) { + await expect(async () => { + await this.openSettingsPanel(page); + const title = page.getByTestId('embedding-progress-title'); + // oxlint-disable-next-line prefer-dom-node-dataset + const progressAttr = await title.getAttribute('data-progress'); + expect(progressAttr).not.toBe('loading'); + + expect(progressAttr).toBe(status); + }).toPass({ timeout }); + } + + public static async waitForEmbeddingComplete(page: Page, timeout = 30000) { + await this.waitForEmbeddingStatus(page, timeout); + + // check embedding progress count + await expect(async () => { + const count = page.getByTestId('embedding-progress-count'); + const countText = await count.textContent(); + if (countText) { + const [embedded, total] = countText.split('/').map(Number); + expect(embedded).toBe(total); + expect(embedded).toBeGreaterThan(0); + } + }).toPass({ timeout }); + } + + public static async waitForFileEmbeddingReadiness( + page: Page, + expectedFileCount: number, + timeout = 30000 + ) { + await expect(async () => { + const attachmentList = page.getByTestId( + 'workspace-embedding-setting-attachment-list' + ); + const attachmentItems = attachmentList.getByTestId( + 'workspace-embedding-setting-attachment-item' + ); + await expect(attachmentItems).toHaveCount(expectedFileCount); + }).toPass({ timeout }); + + await this.waitForEmbeddingComplete(page, timeout); + } } diff --git a/tests/kit/src/utils/workspace.ts b/tests/kit/src/utils/workspace.ts index 32d8c73c64..1abca252f6 100644 --- a/tests/kit/src/utils/workspace.ts +++ b/tests/kit/src/utils/workspace.ts @@ -16,7 +16,8 @@ export async function openWorkspaceListModal(page: Page) { export async function createLocalWorkspace( params: CreateWorkspaceParams, page: Page, - skipOpenWorkspaceListModal = false + skipOpenWorkspaceListModal = false, + serverId?: string ) { if (!skipOpenWorkspaceListModal) { await openWorkspaceListModal(page); @@ -33,10 +34,9 @@ export async function createLocalWorkspace( await page.getByPlaceholder('Set a Workspace name').click(); await page.getByPlaceholder('Set a Workspace name').fill(params.name); - // select local server await page.getByTestId('server-selector-trigger').click(); const serverSelectorList = page.getByTestId('server-selector-list'); - await serverSelectorList.getByTestId('local').click(); + await serverSelectorList.getByTestId(serverId ?? 'local').click(); // click create button await page.getByTestId('create-workspace-create-button').click({