diff --git a/.devcontainer/build.sh b/.devcontainer/build.sh index 47940963b8..fbe8a69b48 100644 --- a/.devcontainer/build.sh +++ b/.devcontainer/build.sh @@ -12,4 +12,4 @@ yarn install yarn affine @affine/server-native build # Create database -yarn affine @affine/server prisma db push +yarn affine @affine/server prisma migrate reset -f diff --git a/.github/actions/server-test-env/action.yml b/.github/actions/server-test-env/action.yml index 5e8959fc4a..5994d68900 100644 --- a/.github/actions/server-test-env/action.yml +++ b/.github/actions/server-test-env/action.yml @@ -19,5 +19,5 @@ runs: NODE_ENV: test run: | yarn affine @affine/server prisma generate - yarn affine @affine/server prisma db push + yarn affine @affine/server prisma migrate reset -f yarn affine @affine/server data-migration run diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index b8be3820f7..0339d48311 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -444,7 +444,7 @@ model AiContextEmbedding { // a file can be divided into multiple chunks and embedded separately. chunk Int @db.Integer content String @db.VarChar - embedding Unsupported("vector(512)") + embedding Unsupported("vector(1024)") createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) @@ -462,7 +462,7 @@ model AiWorkspaceEmbedding { // a doc can be divided into multiple chunks and embedded separately. chunk Int @db.Integer content String @db.VarChar - embedding Unsupported("vector(512)") + embedding Unsupported("vector(1024)") createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md index a2ade51944..3c67be6eae 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md @@ -40,7 +40,10 @@ Generated by [AVA](https://avajs.dev). [ { - id: 'docId1', + blobId: 'fileId1', + chunkSize: 0, + name: 'sample.pdf', + status: 'processing', }, ] @@ -48,9 +51,6 @@ Generated by [AVA](https://avajs.dev). [ { - blobId: 'fileId1', - chunkSize: 0, - name: 'sample.pdf', - status: 'processing', + id: 'docId1', }, ] diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap index 08d34469c9..ae9951c74e 100644 Binary files a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap and b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap differ diff --git a/packages/backend/server/src/__tests__/copilot.e2e.ts b/packages/backend/server/src/__tests__/copilot.e2e.ts index 6c54929fba..1dbcf6f1e0 100644 --- a/packages/backend/server/src/__tests__/copilot.e2e.ts +++ b/packages/backend/server/src/__tests__/copilot.e2e.ts @@ -1,6 +1,7 @@ import { randomUUID } from 'node:crypto'; import { ProjectRoot } from '@affine-tools/utils/path'; +import { PrismaClient } from '@prisma/client'; import type { TestFn } from 'ava'; import ava from 'ava'; import Sinon from 'sinon'; @@ -8,6 +9,7 @@ import Sinon from 'sinon'; import { JobQueue } from '../base'; import { ConfigModule } from '../base/config'; import { AuthService } from '../core/auth'; +import { DocReader } from '../core/doc'; import { WorkspaceModule } from '../core/workspaces'; import { CopilotModule } from '../plugins/copilot'; import { @@ -41,14 +43,16 @@ import { chatWithText, chatWithTextStream, chatWithWorkflow, + cleanObject, createCopilotContext, createCopilotMessage, createCopilotSession, forkCopilotSession, getHistories, listContext, - listContextFiles, - matchContext, + listContextDocAndFiles, + matchFiles, + matchWorkspaceDocs, MockCopilotTestProvider, sse2array, textToEventStream, @@ -59,6 +63,7 @@ import { const test = ava as TestFn<{ auth: AuthService; app: TestingApp; + db: PrismaClient; context: CopilotContextService; jobs: CopilotContextDocJob; prompt: PromptService; @@ -92,16 +97,26 @@ test.before(async t => { tapModule: m => { // use real JobQueue for testing m.overrideProvider(JobQueue).useClass(JobQueue); + m.overrideProvider(DocReader).useValue({ + getFullDocContent() { + return { + title: '1', + summary: '1', + }; + }, + }); }, }); const auth = app.get(AuthService); + const db = app.get(PrismaClient); const context = app.get(CopilotContextService); const prompt = app.get(PromptService); const storage = app.get(CopilotStorage); const jobs = app.get(CopilotContextDocJob); t.context.app = app; + t.context.db = db; t.context.auth = auth; t.context.context = context; t.context.prompt = prompt; @@ -513,15 +528,6 @@ test('should be able to retry with api', async t => { ); } - const cleanObject = (obj: any[]) => - JSON.parse( - JSON.stringify(obj, (k, v) => - ['id', 'sessionId', 'createdAt'].includes(k) || v === null - ? undefined - : v - ) - ); - // retry chat { const { id } = await createWorkspace(app); @@ -771,6 +777,7 @@ test('should be able to manage context', async t => { ProjectRoot.join('packages/common/native/fixtures/sample.pdf').toFileUrl() ); + // match files { const contextId = await createCopilotContext(app, workspaceId, sessionId); @@ -781,34 +788,98 @@ test('should be able to manage context', async t => { 'sample.pdf', buffer ); - await addContextDoc(app, contextId, 'docId1'); - const { docs, files } = - (await listContextFiles(app, workspaceId, sessionId, contextId)) || {}; + const { files } = + (await listContextDocAndFiles(app, workspaceId, sessionId, contextId)) || + {}; t.snapshot( - docs?.map(({ createdAt: _, ...d }) => d), + cleanObject(files, ['id', 'error', 'createdAt']), 'should list context files' ); - t.snapshot( - files?.map(({ createdAt: _, id: __, ...f }) => f), - 'should list context docs' - ); // wait for processing { let { files } = - (await listContextFiles(app, workspaceId, sessionId, contextId)) || {}; + (await listContextDocAndFiles( + app, + workspaceId, + sessionId, + contextId + )) || {}; while (files?.[0].status !== 'finished') { await new Promise(resolve => setTimeout(resolve, 1000)); ({ files } = - (await listContextFiles(app, workspaceId, sessionId, contextId)) || - {}); + (await listContextDocAndFiles( + app, + workspaceId, + sessionId, + contextId + )) || {}); } } - const result = (await matchContext(app, contextId, 'test', 1))!; + const result = (await matchFiles(app, contextId, 'test', 1))!; t.is(result.length, 1, 'should match context'); t.is(result[0].fileId, fileId, 'should match file id'); } + + // match docs + { + const sessionId = await createCopilotSession( + app, + workspaceId, + randomUUID(), + promptName + ); + const contextId = await createCopilotContext(app, workspaceId, sessionId); + + const docId = 'docId1'; + await t.context.db.snapshot.create({ + data: { + workspaceId: workspaceId, + id: docId, + blob: Buffer.from([1, 1]), + state: Buffer.from([1, 1]), + updatedAt: new Date(), + createdAt: new Date(), + }, + }); + + await addContextDoc(app, contextId, docId); + + const { docs } = + (await listContextDocAndFiles(app, workspaceId, sessionId, contextId)) || + {}; + t.snapshot( + cleanObject(docs, ['error', 'createdAt']), + 'should list context docs' + ); + + // wait for processing + { + let { docs } = + (await listContextDocAndFiles( + app, + workspaceId, + sessionId, + contextId + )) || {}; + + while (docs?.[0].status !== 'finished') { + await new Promise(resolve => setTimeout(resolve, 1000)); + ({ docs } = + (await listContextDocAndFiles( + app, + workspaceId, + sessionId, + contextId + )) || {}); + } + } + + const result = (await matchWorkspaceDocs(app, contextId, 'test', 1))!; + t.is(result.length, 1, 'should match context'); + t.is(result[0].docId, docId, 'should match doc id'); + } }); diff --git a/packages/backend/server/src/__tests__/models/copilot-context.spec.ts b/packages/backend/server/src/__tests__/models/copilot-context.spec.ts index 408f9703df..bfc1604e5c 100644 --- a/packages/backend/server/src/__tests__/models/copilot-context.spec.ts +++ b/packages/backend/server/src/__tests__/models/copilot-context.spec.ts @@ -104,14 +104,14 @@ test('should insert embedding by doc id', async t => { { index: 0, content: 'content', - embedding: Array.from({ length: 512 }, () => 1), + embedding: Array.from({ length: 1024 }, () => 1), }, ] ); { const ret = await t.context.copilotContext.matchContentEmbedding( - Array.from({ length: 512 }, () => 0.9), + Array.from({ length: 1024 }, () => 0.9), contextId, 1, 1 @@ -123,7 +123,7 @@ test('should insert embedding by doc id', async t => { { await t.context.copilotContext.deleteEmbedding(contextId, 'file-id'); const ret = await t.context.copilotContext.matchContentEmbedding( - Array.from({ length: 512 }, () => 0.9), + Array.from({ length: 1024 }, () => 0.9), contextId, 1, 1 @@ -151,7 +151,7 @@ test('should insert embedding by doc id', async t => { { index: 0, content: 'content', - embedding: Array.from({ length: 512 }, () => 1), + embedding: Array.from({ length: 1024 }, () => 1), }, ] ); @@ -166,7 +166,7 @@ test('should insert embedding by doc id', async t => { { const ret = await t.context.copilotContext.matchWorkspaceEmbedding( - Array.from({ length: 512 }, () => 0.9), + Array.from({ length: 1024 }, () => 0.9), workspace.id, 1, 1 diff --git a/packages/backend/server/src/__tests__/utils/copilot.ts b/packages/backend/server/src/__tests__/utils/copilot.ts index f4cecbe317..bb6efad5b4 100644 --- a/packages/backend/server/src/__tests__/utils/copilot.ts +++ b/packages/backend/server/src/__tests__/utils/copilot.ts @@ -156,6 +156,16 @@ export class MockCopilotTestProvider } } +export const cleanObject = ( + obj: any[] | undefined, + condition = ['id', 'status', 'error', 'sessionId', 'createdAt'] +) => + JSON.parse( + JSON.stringify(obj || [], (k, v) => + condition.includes(k) || v === null ? undefined : v + ) + ); + export async function createCopilotSession( app: TestingApp, workspaceId: string, @@ -224,7 +234,7 @@ export async function createCopilotContext( return res.createCopilotContext; } -export async function matchContext( +export async function matchFiles( app: TestingApp, contextId: string, content: string, @@ -240,11 +250,11 @@ export async function matchContext( > { const res = await app.gql( ` - query matchContext($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) { + query matchFiles($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) { currentUser { copilot { contexts(contextId: $contextId) { - matchContext(content: $content, limit: $limit, threshold: $threshold) { + matchFiles(content: $content, limit: $limit, threshold: $threshold) { fileId chunk content @@ -258,7 +268,44 @@ export async function matchContext( { contextId, content, limit, threshold: 1 } ); - return res.currentUser?.copilot?.contexts?.[0]?.matchContext; + return res.currentUser?.copilot?.contexts?.[0]?.matchFiles; +} + +export async function matchWorkspaceDocs( + app: TestingApp, + contextId: string, + content: string, + limit: number +): Promise< + | { + docId: string; + chunk: number; + content: string; + distance: number | null; + }[] + | undefined +> { + const res = await app.gql( + ` + query matchWorkspaceDocs($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) { + currentUser { + copilot { + contexts(contextId: $contextId) { + matchWorkspaceDocs(content: $content, limit: $limit, threshold: $threshold) { + docId + chunk + content + distance + } + } + } + } + } + `, + { contextId, content, limit, threshold: 1 } + ); + + return res.currentUser?.copilot?.contexts?.[0]?.matchWorkspaceDocs; } export async function listContext( @@ -376,7 +423,7 @@ export async function removeContextDoc( return res.removeContextDoc; } -export async function listContextFiles( +export async function listContextDocAndFiles( app: TestingApp, workspaceId: string, sessionId: string, @@ -385,6 +432,8 @@ export async function listContextFiles( | { docs: { id: string; + status: string; + error: string | null; createdAt: number; }[]; files: { @@ -393,6 +442,7 @@ export async function listContextFiles( blobId: string; chunkSize: number; status: string; + error: string | null; createdAt: number; }[]; } @@ -405,6 +455,8 @@ export async function listContextFiles( contexts(sessionId: "${sessionId}", contextId: "${contextId}") { docs { id + status + error createdAt } files { @@ -413,6 +465,7 @@ export async function listContextFiles( blobId chunkSize status + error createdAt } } diff --git a/packages/backend/server/src/plugins/copilot/context/embedding.ts b/packages/backend/server/src/plugins/copilot/context/embedding.ts index 40875aa065..3f0c8c411b 100644 --- a/packages/backend/server/src/plugins/copilot/context/embedding.ts +++ b/packages/backend/server/src/plugins/copilot/context/embedding.ts @@ -30,7 +30,7 @@ export class MockEmbeddingClient extends EmbeddingClient { return input.map((_, i) => ({ index: i, content: input[i], - embedding: Array.from({ length: 512 }, () => Math.random()), + embedding: Array.from({ length: 1024 }, () => Math.random()), })); } } diff --git a/packages/backend/server/src/plugins/copilot/context/resolver.ts b/packages/backend/server/src/plugins/copilot/context/resolver.ts index 0e383aa5f0..a9332bdc3c 100644 --- a/packages/backend/server/src/plugins/copilot/context/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/context/resolver.ts @@ -656,10 +656,10 @@ export class CopilotContextResolver { } @ResolveField(() => [ContextMatchedFileChunk], { - description: 'match file context', + description: 'match file in context', }) @CallMetric('ai', 'context_file_remove') - async matchContext( + async matchFiles( @Context() ctx: { req: Request }, @Parent() context: CopilotContextType, @Args('content') content: string, @@ -667,16 +667,11 @@ export class CopilotContextResolver { limit?: number, @Args('threshold', { type: () => Float, nullable: true }) threshold?: number - ) { + ): Promise { if (!this.context.canEmbedding) { return []; } - const lockFlag = `${COPILOT_LOCKER}:context:${context.id}`; - await using lock = await this.mutex.acquire(lockFlag); - if (!lock) { - return new TooManyRequest('Server is busy'); - } const session = await this.context.get(context.id); try { @@ -696,18 +691,20 @@ export class CopilotContextResolver { } } - @ResolveField(() => ContextMatchedDocChunk, { - description: 'match workspace doc content', + @ResolveField(() => [ContextMatchedDocChunk], { + description: 'match workspace docs', }) @CallMetric('ai', 'context_match_workspace_doc') - async matchWorkspaceContext( + async matchWorkspaceDocs( @CurrentUser() user: CurrentUser, @Context() ctx: { req: Request }, @Parent() context: CopilotContextType, @Args('content') content: string, @Args('limit', { type: () => SafeIntResolver, nullable: true }) - limit?: number - ) { + limit?: number, + @Args('threshold', { type: () => Float, nullable: true }) + threshold?: number + ): Promise { if (!this.context.canEmbedding) { return []; } @@ -723,7 +720,8 @@ export class CopilotContextResolver { return await session.matchWorkspaceChunks( content, limit, - this.getSignal(ctx.req) + this.getSignal(ctx.req), + threshold ); } catch (e: any) { throw new CopilotFailedToMatchContext({ diff --git a/packages/backend/server/src/plugins/copilot/context/session.ts b/packages/backend/server/src/plugins/copilot/context/session.ts index 2ed52e1f1b..99b8e97dea 100644 --- a/packages/backend/server/src/plugins/copilot/context/session.ts +++ b/packages/backend/server/src/plugins/copilot/context/session.ts @@ -199,7 +199,7 @@ export class ContextSession implements AsyncDisposable { return this.models.copilotContext.matchWorkspaceEmbedding( embedding, - this.id, + this.workspaceId, topK, threshold ); diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 7d88f0f14e..250c5d28af 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -109,11 +109,11 @@ type CopilotContext { files: [CopilotContextFile!]! id: ID! - """match file context""" - matchContext(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]! + """match file in context""" + matchFiles(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedFileChunk!]! - """match workspace doc content""" - matchWorkspaceContext(content: String!, limit: SafeInt): ContextMatchedDocChunk! + """match workspace docs""" + matchWorkspaceDocs(content: String!, limit: SafeInt, threshold: Float): [ContextMatchedDocChunk!]! """list tags in context""" tags: [CopilotContextCategory!]! diff --git a/packages/common/graphql/src/graphql/copilot-context-match-all.gql b/packages/common/graphql/src/graphql/copilot-context-match-all.gql new file mode 100644 index 0000000000..00d386abe7 --- /dev/null +++ b/packages/common/graphql/src/graphql/copilot-context-match-all.gql @@ -0,0 +1,20 @@ +query matchContext($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) { + currentUser { + copilot { + contexts(contextId: $contextId) { + matchFiles(content: $content, limit: $limit, threshold: $threshold) { + fileId + chunk + content + distance + } + matchWorkspaceDocs(content: $content, limit: $limit, threshold: $threshold) { + docId + chunk + content + distance + } + } + } + } +} diff --git a/packages/common/graphql/src/graphql/copilot-context-workspace-match.gql b/packages/common/graphql/src/graphql/copilot-context-match-docs.gql similarity index 52% rename from packages/common/graphql/src/graphql/copilot-context-workspace-match.gql rename to packages/common/graphql/src/graphql/copilot-context-match-docs.gql index 183b6b541d..464f50517d 100644 --- a/packages/common/graphql/src/graphql/copilot-context-workspace-match.gql +++ b/packages/common/graphql/src/graphql/copilot-context-match-docs.gql @@ -1,8 +1,8 @@ -query matchWorkspaceContext($contextId: String!, $content: String!, $limit: SafeInt) { +query matchWorkspaceDocs($contextId: String!, $content: String!, $limit: SafeInt) { currentUser { copilot { contexts(contextId: $contextId) { - matchWorkspaceContext(content: $content, limit: $limit) { + matchWorkspaceDocs(content: $content, limit: $limit) { docId chunk content diff --git a/packages/common/graphql/src/graphql/copilot-context-file-match.gql b/packages/common/graphql/src/graphql/copilot-context-match-files.gql similarity index 55% rename from packages/common/graphql/src/graphql/copilot-context-file-match.gql rename to packages/common/graphql/src/graphql/copilot-context-match-files.gql index 46432cc1a6..9d8cfd7a82 100644 --- a/packages/common/graphql/src/graphql/copilot-context-file-match.gql +++ b/packages/common/graphql/src/graphql/copilot-context-match-files.gql @@ -1,8 +1,8 @@ -query matchContext($contextId: String!, $content: String!, $limit: SafeInt) { +query matchFiles($contextId: String!, $content: String!, $limit: SafeInt) { currentUser { copilot { contexts(contextId: $contextId) { - matchContext(content: $content, limit: $limit) { + matchFiles(content: $content, limit: $limit) { fileId chunk content diff --git a/packages/common/graphql/src/graphql/index.ts b/packages/common/graphql/src/graphql/index.ts index 6c52569325..b439f096b6 100644 --- a/packages/common/graphql/src/graphql/index.ts +++ b/packages/common/graphql/src/graphql/index.ts @@ -205,25 +205,6 @@ export const addContextFileMutation = { file: true, }; -export const matchContextQuery = { - id: 'matchContextQuery' as const, - op: 'matchContext', - query: `query matchContext($contextId: String!, $content: String!, $limit: SafeInt) { - currentUser { - copilot { - contexts(contextId: $contextId) { - matchContext(content: $content, limit: $limit) { - fileId - chunk - content - distance - } - } - } - } -}`, -}; - export const removeContextFileMutation = { id: 'removeContextFileMutation' as const, op: 'removeContextFile', @@ -295,14 +276,20 @@ export const listContextQuery = { }`, }; -export const matchWorkspaceContextQuery = { - id: 'matchWorkspaceContextQuery' as const, - op: 'matchWorkspaceContext', - query: `query matchWorkspaceContext($contextId: String!, $content: String!, $limit: SafeInt) { +export const matchContextQuery = { + id: 'matchContextQuery' as const, + op: 'matchContext', + query: `query matchContext($contextId: String!, $content: String!, $limit: SafeInt, $threshold: Float) { currentUser { copilot { contexts(contextId: $contextId) { - matchWorkspaceContext(content: $content, limit: $limit) { + matchFiles(content: $content, limit: $limit, threshold: $threshold) { + fileId + chunk + content + distance + } + matchWorkspaceDocs(content: $content, limit: $limit, threshold: $threshold) { docId chunk content @@ -314,6 +301,44 @@ export const matchWorkspaceContextQuery = { }`, }; +export const matchWorkspaceDocsQuery = { + id: 'matchWorkspaceDocsQuery' as const, + op: 'matchWorkspaceDocs', + query: `query matchWorkspaceDocs($contextId: String!, $content: String!, $limit: SafeInt) { + currentUser { + copilot { + contexts(contextId: $contextId) { + matchWorkspaceDocs(content: $content, limit: $limit) { + docId + chunk + content + distance + } + } + } + } +}`, +}; + +export const matchFilesQuery = { + id: 'matchFilesQuery' as const, + op: 'matchFiles', + query: `query matchFiles($contextId: String!, $content: String!, $limit: SafeInt) { + currentUser { + copilot { + contexts(contextId: $contextId) { + matchFiles(content: $content, limit: $limit) { + fileId + chunk + content + distance + } + } + } + } +}`, +}; + export const getWorkspaceEmbeddingStatusQuery = { id: 'getWorkspaceEmbeddingStatusQuery' as const, op: 'getWorkspaceEmbeddingStatus', diff --git a/packages/common/graphql/src/schema.ts b/packages/common/graphql/src/schema.ts index 7e5ed5d340..89354ca078 100644 --- a/packages/common/graphql/src/schema.ts +++ b/packages/common/graphql/src/schema.ts @@ -172,24 +172,25 @@ export interface CopilotContext { /** list files in context */ files: Array; id: Scalars['ID']['output']; - /** match file context */ - matchContext: Array; - /** match workspace doc content */ - matchWorkspaceContext: ContextMatchedDocChunk; + /** match file in context */ + matchFiles: Array; + /** match workspace docs */ + matchWorkspaceDocs: Array; /** list tags in context */ tags: Array; workspaceId: Scalars['String']['output']; } -export interface CopilotContextMatchContextArgs { +export interface CopilotContextMatchFilesArgs { content: Scalars['String']['input']; limit?: InputMaybe; threshold?: InputMaybe; } -export interface CopilotContextMatchWorkspaceContextArgs { +export interface CopilotContextMatchWorkspaceDocsArgs { content: Scalars['String']['input']; limit?: InputMaybe; + threshold?: InputMaybe; } export interface CopilotContextCategory { @@ -2562,32 +2563,6 @@ export type AddContextFileMutation = { }; }; -export type MatchContextQueryVariables = Exact<{ - contextId: Scalars['String']['input']; - content: Scalars['String']['input']; - limit?: InputMaybe; -}>; - -export type MatchContextQuery = { - __typename?: 'Query'; - currentUser: { - __typename?: 'UserType'; - copilot: { - __typename?: 'Copilot'; - contexts: Array<{ - __typename?: 'CopilotContext'; - matchContext: Array<{ - __typename?: 'ContextMatchedFileChunk'; - fileId: string; - chunk: number; - content: string; - distance: number | null; - }>; - }>; - }; - } | null; -}; - export type RemoveContextFileMutationVariables = Exact<{ options: RemoveContextFileInput; }>; @@ -2677,13 +2652,14 @@ export type ListContextQuery = { } | null; }; -export type MatchWorkspaceContextQueryVariables = Exact<{ +export type MatchContextQueryVariables = Exact<{ contextId: Scalars['String']['input']; content: Scalars['String']['input']; limit?: InputMaybe; + threshold?: InputMaybe; }>; -export type MatchWorkspaceContextQuery = { +export type MatchContextQuery = { __typename?: 'Query'; currentUser: { __typename?: 'UserType'; @@ -2691,13 +2667,72 @@ export type MatchWorkspaceContextQuery = { __typename?: 'Copilot'; contexts: Array<{ __typename?: 'CopilotContext'; - matchWorkspaceContext: { + matchFiles: Array<{ + __typename?: 'ContextMatchedFileChunk'; + fileId: string; + chunk: number; + content: string; + distance: number | null; + }>; + matchWorkspaceDocs: Array<{ __typename?: 'ContextMatchedDocChunk'; docId: string; chunk: number; content: string; distance: number | null; - }; + }>; + }>; + }; + } | null; +}; + +export type MatchWorkspaceDocsQueryVariables = Exact<{ + contextId: Scalars['String']['input']; + content: Scalars['String']['input']; + limit?: InputMaybe; +}>; + +export type MatchWorkspaceDocsQuery = { + __typename?: 'Query'; + currentUser: { + __typename?: 'UserType'; + copilot: { + __typename?: 'Copilot'; + contexts: Array<{ + __typename?: 'CopilotContext'; + matchWorkspaceDocs: Array<{ + __typename?: 'ContextMatchedDocChunk'; + docId: string; + chunk: number; + content: string; + distance: number | null; + }>; + }>; + }; + } | null; +}; + +export type MatchFilesQueryVariables = Exact<{ + contextId: Scalars['String']['input']; + content: Scalars['String']['input']; + limit?: InputMaybe; +}>; + +export type MatchFilesQuery = { + __typename?: 'Query'; + currentUser: { + __typename?: 'UserType'; + copilot: { + __typename?: 'Copilot'; + contexts: Array<{ + __typename?: 'CopilotContext'; + matchFiles: Array<{ + __typename?: 'ContextMatchedFileChunk'; + fileId: string; + chunk: number; + content: string; + distance: number | null; + }>; }>; }; } | null; @@ -4315,11 +4350,6 @@ export type Queries = variables: ListBlobsQueryVariables; response: ListBlobsQuery; } - | { - name: 'matchContextQuery'; - variables: MatchContextQueryVariables; - response: MatchContextQuery; - } | { name: 'listContextObjectQuery'; variables: ListContextObjectQueryVariables; @@ -4331,9 +4361,19 @@ export type Queries = response: ListContextQuery; } | { - name: 'matchWorkspaceContextQuery'; - variables: MatchWorkspaceContextQueryVariables; - response: MatchWorkspaceContextQuery; + name: 'matchContextQuery'; + variables: MatchContextQueryVariables; + response: MatchContextQuery; + } + | { + name: 'matchWorkspaceDocsQuery'; + variables: MatchWorkspaceDocsQueryVariables; + response: MatchWorkspaceDocsQuery; + } + | { + name: 'matchFilesQuery'; + variables: MatchFilesQueryVariables; + response: MatchFilesQuery; } | { name: 'getWorkspaceEmbeddingStatusQuery'; diff --git a/packages/frontend/core/src/blocksuite/ai/actions/types.ts b/packages/frontend/core/src/blocksuite/ai/actions/types.ts index 4f38f6c028..8f17bc36c3 100644 --- a/packages/frontend/core/src/blocksuite/ai/actions/types.ts +++ b/packages/frontend/core/src/blocksuite/ai/actions/types.ts @@ -1,5 +1,6 @@ import type { ChatHistoryOrder, + ContextMatchedDocChunk, ContextMatchedFileChunk, CopilotContextCategory, CopilotContextDoc, @@ -312,7 +313,10 @@ declare global { contextId: string, content: string, limit?: number - ) => Promise; + ) => Promise<{ + files?: ContextMatchedFileChunk[]; + docs?: ContextMatchedDocChunk[]; + }>; } // TODO(@Peng): should be refactored to get rid of implement details (like messages, action, role, etc.) diff --git a/packages/frontend/core/src/blocksuite/ai/chat-panel/chat-panel-input.ts b/packages/frontend/core/src/blocksuite/ai/chat-panel/chat-panel-input.ts index 5dbe1a1fa1..1cbe52b70a 100644 --- a/packages/frontend/core/src/blocksuite/ai/chat-panel/chat-panel-input.ts +++ b/packages/frontend/core/src/blocksuite/ai/chat-panel/chat-panel-input.ts @@ -556,9 +556,12 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) { private async _getMatchedContexts(userInput: string) { const contextId = await this.getContextId(); - const matched = contextId - ? (await AIProvider.context?.matchContext(contextId, userInput)) || [] - : []; + // TODO(@akumatus): adapt workspace docs + const { files: matched = [] } = + (contextId && + (await AIProvider.context?.matchContext(contextId, userInput))) || + {}; + const contexts = this.chatContextValue.chips.reduce( (acc, chip, index) => { if (chip.state !== 'finished') { diff --git a/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts b/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts index 4916de9ed1..fef24e4177 100644 --- a/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts +++ b/packages/frontend/core/src/blocksuite/ai/provider/copilot-client.ts @@ -341,7 +341,9 @@ export class CopilotClient { limit, }, }); - return res.currentUser?.copilot?.contexts?.[0]?.matchContext; + const { matchFiles: files, matchWorkspaceDocs: docs } = + res.currentUser?.copilot?.contexts?.[0] || {}; + return { files, docs }; } async chatText({