diff --git a/packages/backend/server/src/plugins/copilot/tools/exa-search.ts b/packages/backend/server/src/plugins/copilot/tools/exa-search.ts index 3fce2441a2..19e3fa4c6f 100644 --- a/packages/backend/server/src/plugins/copilot/tools/exa-search.ts +++ b/packages/backend/server/src/plugins/copilot/tools/exa-search.ts @@ -7,7 +7,8 @@ import { defineTool } from './tool'; export const createExaSearchTool = (config: Config) => { return defineTool({ - description: 'Search the web using Exa, one of the best web search APIs for AI', + description: + 'Search the web using Exa, one of the best web search APIs for AI', inputSchema: z.object({ query: z.string().describe('The query to search the web for.'), mode: z diff --git a/packages/frontend/apps/electron/src/helper/dialog/dialog.ts b/packages/frontend/apps/electron/src/helper/dialog/dialog.ts index bc5a5a0acd..4f0a41ec23 100644 --- a/packages/frontend/apps/electron/src/helper/dialog/dialog.ts +++ b/packages/frontend/apps/electron/src/helper/dialog/dialog.ts @@ -1,4 +1,4 @@ -import { parse } from 'node:path'; +import { parse, resolve } from 'node:path'; import { DocStorage, ValidationResult } from '@affine/native'; import { parseUniversalId } from '@affine/nbstore'; @@ -71,10 +71,34 @@ function getDefaultDBFileName(name: string, id: string) { return fileName.replace(/[/\\?%*:|"<>]/g, '-'); } +async function resolveExistingPath(path: string) { + if (!(await fs.pathExists(path))) { + return null; + } + try { + return await fs.realpath(path); + } catch { + return resolve(path); + } +} + +async function isSameFilePath(sourcePath: string, targetPath: string) { + if (resolve(sourcePath) === resolve(targetPath)) { + return true; + } + + const [sourceRealPath, targetRealPath] = await Promise.all([ + resolveExistingPath(sourcePath), + resolveExistingPath(targetPath), + ]); + + return !!sourceRealPath && sourceRealPath === targetRealPath; +} + /** * This function is called when the user clicks the "Save" button in the "Save 
Workspace" dialog. * - * It will just copy the file to the given path + * It will export a compacted database file to the given path */ export async function saveDBFileAs( universalId: string, @@ -115,12 +139,26 @@ export async function saveDBFileAs( const filePath = ret.filePath; if (ret.canceled || !filePath) { - return { - canceled: true, - }; + return { canceled: true }; } - await fs.copyFile(dbPath, filePath); + if (await isSameFilePath(dbPath, filePath)) { + return { error: 'DB_FILE_PATH_INVALID' }; + } + + const tempFilePath = `${filePath}.${nanoid(6)}.tmp`; + if (await fs.pathExists(tempFilePath)) { + await fs.remove(tempFilePath); + } + + try { + await pool.vacuumInto(universalId, tempFilePath); + await fs.move(tempFilePath, filePath, { overwrite: true }); + } finally { + if (await fs.pathExists(tempFilePath)) { + await fs.remove(tempFilePath); + } + } logger.log('saved', filePath); if (!fakedResult) { mainRPC.showItemInFolder(filePath).catch(err => { @@ -183,11 +221,7 @@ export async function loadDBFile( const provided = getFakedResult() ?? (dbFilePath - ? { - filePath: dbFilePath, - filePaths: [dbFilePath], - canceled: false, - } + ? { filePath: dbFilePath, filePaths: [dbFilePath], canceled: false } : undefined); const ret = provided ?? 
@@ -224,6 +258,10 @@ export async function loadDBFile( return await cpV1DBFile(originalPath, workspaceId); } + if (!(await storage.validateImportSchema())) { + return { error: 'DB_FILE_INVALID' }; + } + // v2 import logic const internalFilePath = await getSpaceDBPath( 'local', @@ -231,8 +269,8 @@ export async function loadDBFile( workspaceId ); await fs.ensureDir(parse(internalFilePath).dir); - await fs.copy(originalPath, internalFilePath); - logger.info(`loadDBFile, copy: ${originalPath} -> ${internalFilePath}`); + await storage.vacuumInto(internalFilePath); + logger.info(`loadDBFile, vacuum: ${originalPath} -> ${internalFilePath}`); storage = new DocStorage(internalFilePath); await storage.setSpaceId(workspaceId); @@ -260,17 +298,16 @@ async function cpV1DBFile( return { error: 'DB_FILE_INVALID' }; // invalid db file } - // checkout to make sure wal is flushed const connection = new SqliteConnection(originalPath); - await connection.connect(); - await connection.checkpoint(); - await connection.close(); + if (!(await connection.validateImportSchema())) { + return { error: 'DB_FILE_INVALID' }; + } const internalFilePath = await getWorkspaceDBPath('workspace', workspaceId); - await fs.ensureDir(await getWorkspacesBasePath()); - await fs.copy(originalPath, internalFilePath); - logger.info(`loadDBFile, copy: ${originalPath} -> ${internalFilePath}`); + await fs.ensureDir(parse(internalFilePath).dir); + await connection.vacuumInto(internalFilePath); + logger.info(`loadDBFile, vacuum: ${originalPath} -> ${internalFilePath}`); await storeWorkspaceMeta(workspaceId, { id: workspaceId, diff --git a/packages/frontend/apps/electron/test/dialog/dialog.spec.ts b/packages/frontend/apps/electron/test/dialog/dialog.spec.ts new file mode 100644 index 0000000000..b1941b0201 --- /dev/null +++ b/packages/frontend/apps/electron/test/dialog/dialog.spec.ts @@ -0,0 +1,268 @@ +import { afterEach, describe, expect, test, vi } from 'vitest'; + +const connect = vi.fn(); +const checkpoint = 
vi.fn(); +const poolVacuumInto = vi.fn(); +const pathExists = vi.fn(); +const remove = vi.fn(); +const move = vi.fn(); +const realpath = vi.fn(); +const copyFile = vi.fn(); +const ensureDir = vi.fn(); +const copy = vi.fn(); +const storeWorkspaceMeta = vi.fn(); +const getSpaceDBPath = vi.fn(); +const getWorkspaceDBPath = vi.fn(); +const getWorkspacesBasePath = vi.fn(); +const docValidate = vi.fn(); +const docValidateImportSchema = vi.fn(); +const docVacuumInto = vi.fn(); +const docSetSpaceId = vi.fn(); +const sqliteValidate = vi.fn(); +const sqliteValidateImportSchema = vi.fn(); +const sqliteVacuumInto = vi.fn(); + +vi.doMock('nanoid', () => ({ + nanoid: () => 'workspace-1', +})); + +vi.doMock('@affine/native', () => { + const ValidationResult = { + MissingTables: 'MissingTables', + MissingDocIdColumn: 'MissingDocIdColumn', + MissingVersionColumn: 'MissingVersionColumn', + GeneralError: 'GeneralError', + Valid: 'Valid', + }; + + return { + ValidationResult, + DocStorage: class { + constructor(private readonly path: string) {} + + validate() { + return docValidate(this.path); + } + + validateImportSchema() { + return docValidateImportSchema(this.path); + } + + vacuumInto(path: string) { + return docVacuumInto(this.path, path); + } + + setSpaceId(spaceId: string) { + return docSetSpaceId(this.path, spaceId); + } + }, + SqliteConnection: class { + static validate(path: string) { + return sqliteValidate(path); + } + + constructor(private readonly path: string) {} + + validateImportSchema() { + return sqliteValidateImportSchema(this.path); + } + + vacuumInto(path: string) { + return sqliteVacuumInto(this.path, path); + } + }, + }; +}); + +vi.doMock('@affine/electron/helper/nbstore', () => ({ + getDocStoragePool: () => ({ + connect, + checkpoint, + vacuumInto: poolVacuumInto, + }), +})); + +vi.doMock('@affine/electron/helper/main-rpc', () => ({ + mainRPC: { + showItemInFolder: vi.fn(), + }, +})); + +vi.doMock('@affine/electron/helper/workspace/meta', () => ({ + 
getSpaceDBPath, + getWorkspaceDBPath, + getWorkspacesBasePath, +})); + +vi.doMock('@affine/electron/helper/workspace', () => ({ + storeWorkspaceMeta, +})); + +vi.doMock('fs-extra', () => ({ + default: { + pathExists, + remove, + move, + realpath, + copyFile, + ensureDir, + copy, + }, +})); + +afterEach(() => { + vi.clearAllMocks(); + vi.resetModules(); +}); + +describe('dialog export', () => { + test('saveDBFileAs exports a vacuumed backup instead of copying the live db', async () => { + const dbPath = '/tmp/workspace/storage.db'; + const exportPath = '/tmp/export.affine'; + const tempExportPath = '/tmp/export.affine.workspace-1.tmp'; + const id = '@peer(local);@type(workspace);@id(workspace-1);'; + + pathExists.mockImplementation(async path => path === dbPath); + realpath.mockImplementation(async path => path); + getSpaceDBPath.mockResolvedValue(dbPath); + move.mockResolvedValue(undefined); + + const { saveDBFileAs, setFakeDialogResult } = + await import('@affine/electron/helper/dialog/dialog'); + + setFakeDialogResult({ filePath: exportPath }); + + const result = await saveDBFileAs(id, 'My Space'); + + expect(result).toEqual({ filePath: exportPath }); + expect(connect).toHaveBeenCalledWith(id, dbPath); + expect(checkpoint).toHaveBeenCalledWith(id); + expect(poolVacuumInto).toHaveBeenCalledWith(id, tempExportPath); + expect(move).toHaveBeenCalledWith(tempExportPath, exportPath, { + overwrite: true, + }); + expect(remove).not.toHaveBeenCalledWith(exportPath); + expect(copyFile).not.toHaveBeenCalled(); + }); + + test('saveDBFileAs rejects exporting over the live database path', async () => { + const dbPath = '/tmp/workspace/storage.db'; + const id = '@peer(local);@type(workspace);@id(workspace-1);'; + + pathExists.mockResolvedValue(false); + getSpaceDBPath.mockResolvedValue(dbPath); + + const { saveDBFileAs, setFakeDialogResult } = + await import('@affine/electron/helper/dialog/dialog'); + + setFakeDialogResult({ filePath: dbPath }); + + const result = await 
saveDBFileAs(id, 'My Space'); + + expect(result).toEqual({ error: 'DB_FILE_PATH_INVALID' }); + expect(poolVacuumInto).not.toHaveBeenCalled(); + expect(copyFile).not.toHaveBeenCalled(); + }); + + test('saveDBFileAs rejects exporting to a symlink alias of the live database', async () => { + const dbPath = '/tmp/workspace/storage.db'; + const exportPath = '/tmp/alias.affine'; + const id = '@peer(local);@type(workspace);@id(workspace-1);'; + + pathExists.mockResolvedValue(true); + realpath.mockImplementation(async path => + path === exportPath ? dbPath : path + ); + getSpaceDBPath.mockResolvedValue(dbPath); + + const { saveDBFileAs, setFakeDialogResult } = + await import('@affine/electron/helper/dialog/dialog'); + + setFakeDialogResult({ filePath: exportPath }); + + const result = await saveDBFileAs(id, 'My Space'); + + expect(result).toEqual({ error: 'DB_FILE_PATH_INVALID' }); + expect(poolVacuumInto).not.toHaveBeenCalled(); + expect(move).not.toHaveBeenCalled(); + }); +}); + +describe('dialog import', () => { + test('loadDBFile validates schema and vacuums v2 imports into internal storage', async () => { + const originalPath = '/tmp/import.affine'; + const internalPath = '/app/workspaces/local/workspace-1/storage.db'; + + getWorkspacesBasePath.mockResolvedValue('/app/workspaces'); + getSpaceDBPath.mockResolvedValue(internalPath); + docValidate.mockResolvedValue(true); + docValidateImportSchema.mockResolvedValue(true); + docVacuumInto.mockResolvedValue(undefined); + docSetSpaceId.mockResolvedValue(undefined); + ensureDir.mockResolvedValue(undefined); + + const { loadDBFile, setFakeDialogResult } = + await import('@affine/electron/helper/dialog/dialog'); + + setFakeDialogResult({ filePath: originalPath }); + + const result = await loadDBFile(); + + expect(result).toEqual({ workspaceId: 'workspace-1' }); + expect(docValidate).toHaveBeenCalledWith(originalPath); + expect(docValidateImportSchema).toHaveBeenCalledWith(originalPath); + 
expect(docVacuumInto).toHaveBeenCalledWith(originalPath, internalPath); + expect(docSetSpaceId).toHaveBeenCalledWith(internalPath, 'workspace-1'); + expect(copy).not.toHaveBeenCalled(); + }); + + test('loadDBFile rejects v2 imports with unexpected schema objects', async () => { + const originalPath = '/tmp/import.affine'; + + getWorkspacesBasePath.mockResolvedValue('/app/workspaces'); + docValidate.mockResolvedValue(true); + docValidateImportSchema.mockResolvedValue(false); + + const { loadDBFile, setFakeDialogResult } = + await import('@affine/electron/helper/dialog/dialog'); + + setFakeDialogResult({ filePath: originalPath }); + + const result = await loadDBFile(); + + expect(result).toEqual({ error: 'DB_FILE_INVALID' }); + expect(docVacuumInto).not.toHaveBeenCalled(); + expect(copy).not.toHaveBeenCalled(); + }); + + test('loadDBFile validates schema and vacuums v1 imports into internal storage', async () => { + const originalPath = '/tmp/import-v1.affine'; + const internalPath = '/app/workspaces/workspace-1/storage.db'; + + getWorkspacesBasePath.mockResolvedValue('/app/workspaces'); + getWorkspaceDBPath.mockResolvedValue(internalPath); + docValidate.mockResolvedValue(false); + sqliteValidate.mockResolvedValue('Valid'); + sqliteValidateImportSchema.mockResolvedValue(true); + sqliteVacuumInto.mockResolvedValue(undefined); + ensureDir.mockResolvedValue(undefined); + + const { loadDBFile, setFakeDialogResult } = + await import('@affine/electron/helper/dialog/dialog'); + + setFakeDialogResult({ filePath: originalPath }); + + const result = await loadDBFile(); + + expect(result).toEqual({ workspaceId: 'workspace-1' }); + expect(sqliteValidate).toHaveBeenCalledWith(originalPath); + expect(sqliteValidateImportSchema).toHaveBeenCalledWith(originalPath); + expect(ensureDir).toHaveBeenCalledWith('/app/workspaces/workspace-1'); + expect(sqliteVacuumInto).toHaveBeenCalledWith(originalPath, internalPath); + expect(storeWorkspaceMeta).toHaveBeenCalledWith('workspace-1', { + 
id: 'workspace-1', + mainDBPath: internalPath, + }); + expect(copy).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/frontend/core/src/desktop/dialogs/import/index.tsx b/packages/frontend/core/src/desktop/dialogs/import/index.tsx index 780fa1da13..921c383634 100644 --- a/packages/frontend/core/src/desktop/dialogs/import/index.tsx +++ b/packages/frontend/core/src/desktop/dialogs/import/index.tsx @@ -204,6 +204,7 @@ type ImportResult = { entryId?: string; isWorkspaceFile?: boolean; rootFolderId?: string; + importedWorkspace?: WorkspaceMetadata; }; type ImportedWorkspacePayload = { @@ -554,11 +555,12 @@ const importConfigs: Record = { _organizeService, _explorerIconService ) => { - await handleImportAffineFile(); + const workspace = await handleImportAffineFile(); return { docIds: [], entryId: undefined, isWorkspaceFile: true, + importedWorkspace: workspace, }; }, }, @@ -773,7 +775,6 @@ export const ImportDialog = ({ undefined, (payload?: ImportedWorkspacePayload) => { if (payload) { - handleCreatedWorkspace({ metadata: payload.workspace }); resolve(payload.workspace); } else { reject(new Error('No workspace imported')); @@ -782,7 +783,7 @@ export const ImportDialog = ({ ); }); }; - }, [globalDialogService, handleCreatedWorkspace]); + }, [globalDialogService]); const handleImport = useAsyncCallback( async (type: ImportType) => { @@ -812,16 +813,27 @@ export const ImportDialog = ({ }); } - const { docIds, entryId, isWorkspaceFile, rootFolderId } = - await importConfig.importFunction( - docCollection, - files, - handleImportAffineFile, - organizeService, - explorerIconService - ); + const { + docIds, + entryId, + isWorkspaceFile, + rootFolderId, + importedWorkspace, + } = await importConfig.importFunction( + docCollection, + files, + handleImportAffineFile, + organizeService, + explorerIconService + ); - setImportResult({ docIds, entryId, isWorkspaceFile, rootFolderId }); + setImportResult({ + docIds, + entryId, + isWorkspaceFile, + rootFolderId, + 
importedWorkspace, + }); setStatus('success'); track.$.importModal.$.import({ type, @@ -855,9 +867,21 @@ ] ); + const finishImport = useCallback(() => { + if (importResult?.importedWorkspace) { + handleCreatedWorkspace({ metadata: importResult.importedWorkspace }); + } + if (!importResult) { + close(); + return; + } + const { importedWorkspace: _workspace, ...result } = importResult; + close(result); + }, [close, handleCreatedWorkspace, importResult]); + const handleComplete = useCallback(() => { - close(importResult || undefined); - }, [importResult, close]); + finishImport(); + }, [finishImport]); const handleRetry = () => { setStatus('idle'); @@ -875,7 +899,7 @@ open onOpenChange={(open: boolean) => { if (!open) { - close(importResult || undefined); + finishImport(); } }} width={480} diff --git a/packages/frontend/native/index.d.ts b/packages/frontend/native/index.d.ts index e85784fa3b..f27f9c6518 100644 --- a/packages/frontend/native/index.d.ts +++ b/packages/frontend/native/index.d.ts @@ -46,6 +46,8 @@ export declare function verifyChallengeResponse(response: string, bits: number, export declare class DocStorage { constructor(path: string) validate(): Promise<boolean> + validateImportSchema(): Promise<boolean> + vacuumInto(path: string): Promise<void> setSpaceId(spaceId: string): Promise<void> } @@ -55,6 +57,7 @@ connect(universalId: string, path: string): Promise<void> disconnect(universalId: string): Promise<void> checkpoint(universalId: string): Promise<void> + vacuumInto(universalId: string, path: string): Promise<void> crawlDocData(universalId: string, docId: string): Promise setSpaceId(universalId: string, spaceId: string): Promise<void> pushUpdate(universalId: string, docId: string, update: Uint8Array): Promise @@ -196,11 +199,13 @@ export declare class SqliteConnection { close(): Promise<void> get isClose(): boolean static validate(path: string): Promise<ValidationResult> + validateImportSchema(): Promise<boolean> migrateAddDocId(): Promise<void> /** 
* Flush the WAL file to the database file. * See https://www.sqlite.org/pragma.html#pragma_wal_checkpoint:~:text=PRAGMA%20schema.wal_checkpoint%3B */ checkpoint(): Promise<void> + vacuumInto(path: string): Promise<void> } export interface BlobRow { diff --git a/packages/frontend/native/nbstore/src/lib.rs b/packages/frontend/native/nbstore/src/lib.rs index e294b63234..13963fc2cf 100644 --- a/packages/frontend/native/nbstore/src/lib.rs +++ b/packages/frontend/native/nbstore/src/lib.rs @@ -129,6 +129,12 @@ impl DocStoragePool { Ok(()) } + #[napi] + pub async fn vacuum_into(&self, universal_id: String, path: String) -> Result<()> { + self.pool.get(universal_id).await?.vacuum_into(path).await?; + Ok(()) + } + #[napi] pub async fn crawl_doc_data(&self, universal_id: String, doc_id: String) -> Result { let result = self.get(universal_id).await?.crawl_doc_data(&doc_id).await?; @@ -485,6 +491,17 @@ impl DocStorage { Ok(self.storage.validate().await?) } + #[napi] + pub async fn validate_import_schema(&self) -> Result<bool> { + Ok(self.storage.validate_import_schema().await?) 
+ } + + #[napi] + pub async fn vacuum_into(&self, path: String) -> Result<()> { + self.storage.vacuum_into(path).await?; + Ok(()) + } + #[napi] pub async fn set_space_id(&self, space_id: String) -> Result<()> { self.storage.connect().await?; diff --git a/packages/frontend/native/nbstore/src/storage.rs b/packages/frontend/native/nbstore/src/storage.rs index 0e5622845e..cb82dd2eb7 100644 --- a/packages/frontend/native/nbstore/src/storage.rs +++ b/packages/frontend/native/nbstore/src/storage.rs @@ -1,6 +1,9 @@ use std::sync::Arc; -use affine_schema::get_migrator; +use affine_schema::{ + get_migrator, + import_validation::{V2_IMPORT_SCHEMA_RULES, validate_import_schema, validate_required_schema}, +}; use memory_indexer::InMemoryIndex; use sqlx::{ Pool, Row, @@ -49,17 +52,27 @@ impl SqliteDocStorage { } pub async fn validate(&self) -> Result<bool> { - let record = sqlx::query("SELECT * FROM _sqlx_migrations ORDER BY installed_on ASC LIMIT 1;") - .fetch_optional(&self.pool) - .await; - - match record { - Ok(Some(row)) => { - let name: &str = row.try_get("description")?; - Ok(name == "init_v2") - } - _ => Ok(false), + if self.path == ":memory:" { + return Ok(validate_required_schema(&self.pool, &V2_IMPORT_SCHEMA_RULES).await?); } + + let Ok(pool) = self.open_readonly_pool().await else { + return Ok(false); + }; + + Ok(validate_required_schema(&pool, &V2_IMPORT_SCHEMA_RULES).await?) + } + + pub async fn validate_import_schema(&self) -> Result<bool> { + if self.path == ":memory:" { + return Ok(validate_import_schema(&self.pool, &V2_IMPORT_SCHEMA_RULES).await?); + } + + let Ok(pool) = self.open_readonly_pool().await else { + return Ok(false); + }; + + Ok(validate_import_schema(&pool, &V2_IMPORT_SCHEMA_RULES).await?) 
} pub async fn connect(&self) -> Result<()> { @@ -159,14 +172,41 @@ impl SqliteDocStorage { Ok(()) } + + pub async fn vacuum_into(&self, path: String) -> Result<()> { + if self.path == ":memory:" { + sqlx::query("VACUUM INTO ?;").bind(path).execute(&self.pool).await?; + return Ok(()); + } + + let pool = self.open_readonly_pool().await?; + sqlx::query("VACUUM INTO ?;").bind(path).execute(&pool).await?; + + Ok(()) + } + + async fn open_readonly_pool(&self) -> Result<Pool<Sqlite>> { + let sqlite_options = SqliteConnectOptions::new() + .filename(&self.path) + .foreign_keys(false) + .read_only(true); + + Ok( + SqlitePoolOptions::new() + .max_connections(1) + .connect_with(sqlite_options) + .await?, + ) + } } #[cfg(test)] mod tests { - use std::borrow::Cow; + use std::{borrow::Cow, fs, path::Path}; use affine_schema::get_migrator; use sqlx::migrate::{Migration, Migrator}; + use uuid::Uuid; use super::*; @@ -256,4 +296,105 @@ assert_eq!(checksum, expected_checksum); } + + #[tokio::test] + async fn vacuum_into_exports_a_compacted_database() { + let base = std::env::temp_dir().join(format!("nbstore-vacuum-{}", Uuid::new_v4())); + fs::create_dir_all(&base).unwrap(); + + let source = base.join("storage.db"); + let export = base.join("backup.affine"); + + let storage = SqliteDocStorage::new(path_string(&source)); + storage.connect().await.unwrap(); + + storage + .set_blob(crate::SetBlob { + key: "large-blob".to_string(), + data: vec![7; 1024 * 1024], + mime: "application/octet-stream".to_string(), + }) + .await + .unwrap(); + storage.delete_blob("large-blob".to_string(), true).await.unwrap(); + storage.checkpoint().await.unwrap(); + + let source_len = fs::metadata(&source).unwrap().len(); + assert!(source_len > 0); + + storage.vacuum_into(path_string(&export)).await.unwrap(); + + let export_len = fs::metadata(&export).unwrap().len(); + assert!(export_len < source_len); + + let exported = SqliteDocStorage::new(path_string(&export)); + exported.connect().await.unwrap(); + 
assert!(exported.list_blobs().await.unwrap().is_empty()); + exported.close().await; + storage.close().await; + + fs::remove_dir_all(base).unwrap(); + } + + #[tokio::test] + async fn validate_import_schema_rejects_unexpected_schema_objects() { + let base = std::env::temp_dir().join(format!("nbstore-schema-{}", Uuid::new_v4())); + fs::create_dir_all(&base).unwrap(); + + let source = base.join("storage.db"); + fs::File::create(&source).unwrap(); + let storage = SqliteDocStorage::new(path_string(&source)); + storage.connect().await.unwrap(); + + sqlx::query("CREATE VIEW rogue_view AS SELECT space_id FROM meta") + .execute(&storage.pool) + .await + .unwrap(); + + assert!(!storage.validate_import_schema().await.unwrap()); + + storage.close().await; + fs::remove_dir_all(base).unwrap(); + } + + #[tokio::test] + async fn validate_import_schema_accepts_initial_v2_schema() { + let base = std::env::temp_dir().join(format!("nbstore-v2-schema-{}", Uuid::new_v4())); + fs::create_dir_all(&base).unwrap(); + + let source = base.join("storage.db"); + let source_path = path_string(&source); + let setup_pool = SqlitePoolOptions::new() + .max_connections(1) + .connect_with( + SqliteConnectOptions::new() + .filename(&source_path) + .create_if_missing(true) + .foreign_keys(false), + ) + .await + .unwrap(); + + let mut migrations = get_migrator().migrations.to_vec(); + migrations.truncate(1); + let migrator = Migrator { + migrations: Cow::Owned(migrations), + ..Migrator::DEFAULT + }; + + migrator.run(&setup_pool).await.unwrap(); + setup_pool.close().await; + + let storage = SqliteDocStorage::new(source_path); + + assert!(storage.validate().await.unwrap()); + assert!(storage.validate_import_schema().await.unwrap()); + + storage.close().await; + fs::remove_dir_all(base).unwrap(); + } + + fn path_string(path: &Path) -> String { + path.to_string_lossy().into_owned() + } } diff --git a/packages/frontend/native/schema/src/import_validation.rs 
b/packages/frontend/native/schema/src/import_validation.rs new file mode 100644 index 0000000000..cf2e69f85c --- /dev/null +++ b/packages/frontend/native/schema/src/import_validation.rs @@ -0,0 +1,261 @@ +use std::collections::BTreeSet; + +use sqlx::{Pool, Row, sqlite::Sqlite}; + +pub struct ImportSchemaRules { + pub tables: &'static [ImportTableRule], + pub indexes: &'static [ImportIndexRule], +} + +pub struct ImportTableRule { + pub name: &'static str, + pub columns: &'static [&'static str], + pub enforce_columns: bool, + pub required: bool, +} + +pub struct ImportIndexRule { + pub name: &'static str, + pub table: &'static str, + pub columns: &'static [&'static str], + pub required: bool, +} + +pub const V2_IMPORT_SCHEMA_RULES: ImportSchemaRules = ImportSchemaRules { + tables: &[ + ImportTableRule { + name: "meta", + columns: &["space_id"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "snapshots", + columns: &["doc_id", "data", "created_at", "updated_at"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "updates", + columns: &["doc_id", "created_at", "data"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "clocks", + columns: &["doc_id", "timestamp"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "blobs", + columns: &["key", "data", "mime", "size", "created_at", "deleted_at"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "peer_clocks", + columns: &["peer", "doc_id", "remote_clock", "pulled_remote_clock", "pushed_clock"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "peer_blob_sync", + columns: &["peer", "blob_id", "uploaded_at"], + enforce_columns: true, + required: false, + }, + ImportTableRule { + name: "idx_snapshots", + columns: &["index_name", "data", "created_at"], + enforce_columns: true, + required: false, + }, + ImportTableRule { + name: "indexer_sync", + columns: &["doc_id", 
"indexed_clock", "indexer_version"], + enforce_columns: true, + required: false, + }, + ImportTableRule { + name: "_sqlx_migrations", + columns: &[], + enforce_columns: false, + required: false, + }, + ], + indexes: &[ + ImportIndexRule { + name: "peer_clocks_doc_id", + table: "peer_clocks", + columns: &["doc_id"], + required: true, + }, + ImportIndexRule { + name: "peer_blob_sync_peer", + table: "peer_blob_sync", + columns: &["peer"], + required: false, + }, + ], +}; + +pub const V1_IMPORT_SCHEMA_RULES: ImportSchemaRules = ImportSchemaRules { + tables: &[ + ImportTableRule { + name: "updates", + columns: &["id", "timestamp", "data", "doc_id"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "blobs", + columns: &["key", "data", "timestamp"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "version_info", + columns: &["version", "timestamp"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "server_clock", + columns: &["key", "data", "timestamp"], + enforce_columns: true, + required: true, + }, + ImportTableRule { + name: "sync_metadata", + columns: &["key", "data", "timestamp"], + enforce_columns: true, + required: true, + }, + ], + indexes: &[ImportIndexRule { + name: "idx_doc_id", + table: "updates", + columns: &["doc_id"], + required: false, + }], +}; + +pub async fn validate_import_schema(pool: &Pool<Sqlite>, rules: &ImportSchemaRules) -> sqlx::Result<bool> { + validate_schema(pool, rules, ValidationMode::Strict).await +} + +pub async fn validate_required_schema(pool: &Pool<Sqlite>, rules: &ImportSchemaRules) -> sqlx::Result<bool> { + validate_schema(pool, rules, ValidationMode::RequiredOnly).await +} + +#[derive(Clone, Copy)] +enum ValidationMode { + Strict, + RequiredOnly, +} + +async fn validate_schema(pool: &Pool<Sqlite>, rules: &ImportSchemaRules, mode: ValidationMode) -> sqlx::Result<bool> { + let query = match mode { + ValidationMode::Strict => { + "SELECT type, name, tbl_name FROM sqlite_schema WHERE type IN 
('table', 'index', 'view', 'trigger')" + } + ValidationMode::RequiredOnly => "SELECT type, name, tbl_name FROM sqlite_schema WHERE type IN ('table', 'index')", + }; + let rows = sqlx::query(query).fetch_all(pool).await?; + + let mut seen_tables = BTreeSet::new(); + let mut seen_indexes = BTreeSet::new(); + + for row in rows { + let object_type: String = row.try_get("type")?; + let name: String = row.try_get("name")?; + let table_name: String = row.try_get("tbl_name")?; + + if name.starts_with("sqlite_") { + continue; + } + + match object_type.as_str() { + "table" => { + let Some(rule) = rules.tables.iter().find(|rule| rule.name == name) else { + if matches!(mode, ValidationMode::Strict) { + return Ok(false); + } + continue; + }; + if rule.enforce_columns && !table_columns_match(pool, rule).await? { + return Ok(false); + } + seen_tables.insert(name); + } + "index" => { + let Some(rule) = rules + .indexes + .iter() + .find(|rule| rule.name == name && rule.table == table_name) + else { + if matches!(mode, ValidationMode::Strict) { + return Ok(false); + } + continue; + }; + if !index_columns_match(pool, rule).await? 
{ + return Ok(false); + } + seen_indexes.insert(name); + } + "view" | "trigger" => return Ok(false), + _ => return Ok(false), + } + } + + if rules + .tables + .iter() + .filter(|rule| rule.required) + .any(|rule| !seen_tables.contains(rule.name)) + { + return Ok(false); + } + + if rules + .indexes + .iter() + .filter(|rule| rule.required) + .any(|rule| !seen_indexes.contains(rule.name)) + { + return Ok(false); + } + + Ok(true) +} + +async fn table_columns_match(pool: &Pool<Sqlite>, rule: &ImportTableRule) -> sqlx::Result<bool> { + let pragma = format!("PRAGMA table_info(\"{}\")", rule.name); + let rows = sqlx::query(&pragma).fetch_all(pool).await?; + let columns = rows + .into_iter() + .map(|row| row.try_get::<String, _>("name")) + .collect::<Result<Vec<String>, _>>()?; + + Ok(columns == rule.columns.iter().map(|column| (*column).to_string()).collect::<Vec<String>>()) +} + +async fn index_columns_match(pool: &Pool<Sqlite>, rule: &ImportIndexRule) -> sqlx::Result<bool> { + let pragma = format!("PRAGMA index_info(\"{}\")", rule.name); + let rows = sqlx::query(&pragma).fetch_all(pool).await?; + let columns = rows + .into_iter() + .map(|row| row.try_get::<String, _>("name")) + .collect::<Result<Vec<String>, _>>()?; + + Ok( + columns + == rule + .columns + .iter() + .map(|column| (*column).to_string()) + .collect::<Vec<String>>(), + ) +} diff --git a/packages/frontend/native/schema/src/lib.rs b/packages/frontend/native/schema/src/lib.rs index d636dd8aae..6b3d2bcf03 100644 --- a/packages/frontend/native/schema/src/lib.rs +++ b/packages/frontend/native/schema/src/lib.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use sqlx::migrate::{Migration, MigrationType, Migrator}; +pub mod import_validation; pub mod v1; type SimpleMigration = ( diff --git a/packages/frontend/native/sqlite_v1/src/lib.rs b/packages/frontend/native/sqlite_v1/src/lib.rs index 2aeea46e18..ee20d7ac77 100644 --- a/packages/frontend/native/sqlite_v1/src/lib.rs +++ b/packages/frontend/native/sqlite_v1/src/lib.rs @@ -1,3 +1,4 @@ +use affine_schema::import_validation::{V1_IMPORT_SCHEMA_RULES, validate_import_schema}; use 
chrono::NaiveDateTime; use napi::bindgen_prelude::{Buffer, Uint8Array}; use napi_derive::napi; @@ -423,7 +424,7 @@ impl SqliteConnection { #[napi] pub async fn validate(path: String) -> ValidationResult { - let pool = match SqlitePoolOptions::new().max_connections(1).connect(&path).await { + let pool = match open_readonly_pool(&path).await { Ok(pool) => pool, Err(_) => return ValidationResult::GeneralError, }; @@ -473,6 +474,16 @@ impl SqliteConnection { } } + #[napi] + pub async fn validate_import_schema(&self) -> napi::Result<bool> { + let pool = open_readonly_pool(&self.path).await?; + Ok( + validate_import_schema(&pool, &V1_IMPORT_SCHEMA_RULES) + .await + .map_err(anyhow::Error::from)?, + ) + } + #[napi] pub async fn migrate_add_doc_id(&self) -> napi::Result<()> { // ignore errors @@ -504,6 +515,17 @@ Ok(()) } + #[napi] + pub async fn vacuum_into(&self, path: String) -> napi::Result<()> { + let pool = open_readonly_pool(&self.path).await?; + sqlx::query("VACUUM INTO ?;") + .bind(path) + .execute(&pool) + .await + .map_err(anyhow::Error::from)?; + Ok(()) + } + pub async fn migrate_add_doc_id_index(&self) -> napi::Result<()> { // ignore errors match sqlx::query("CREATE INDEX IF NOT EXISTS idx_doc_id ON updates(doc_id);") @@ -517,3 +539,64 @@ } } } + +async fn open_readonly_pool(path: &str) -> anyhow::Result<Pool<Sqlite>> { + let options = SqliteConnectOptions::new() + .filename(path) + .foreign_keys(false) + .read_only(true); + + Ok( + SqlitePoolOptions::new() + .max_connections(1) + .connect_with(options) + .await?, + ) +} + +#[cfg(test)] +mod tests { + use std::{ + fs, + time::{SystemTime, UNIX_EPOCH}, + }; + + use super::*; + + #[tokio::test] + async fn validate_import_schema_accepts_current_v1_schema() { + let unique = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos(); + let base = std::env::temp_dir().join(format!("sqlite-v1-schema-valid-{unique}")); + fs::create_dir_all(&base).unwrap(); + + let source = 
base.join("storage.db"); + let connection = SqliteConnection::new(source.to_string_lossy().into_owned()).unwrap(); + connection.connect().await.unwrap(); + + assert!(connection.validate_import_schema().await.unwrap()); + + connection.close().await; + fs::remove_dir_all(base).unwrap(); + } + + #[tokio::test] + async fn validate_import_schema_rejects_unexpected_schema_objects() { + let unique = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos(); + let base = std::env::temp_dir().join(format!("sqlite-v1-schema-{unique}")); + fs::create_dir_all(&base).unwrap(); + + let source = base.join("storage.db"); + let connection = SqliteConnection::new(source.to_string_lossy().into_owned()).unwrap(); + connection.connect().await.unwrap(); + + sqlx::query("CREATE TRIGGER rogue_trigger AFTER INSERT ON updates BEGIN SELECT 1; END;") + .execute(&connection.pool) + .await + .unwrap(); + + assert!(!connection.validate_import_schema().await.unwrap()); + + connection.close().await; + fs::remove_dir_all(base).unwrap(); + } +}