feat(server): implement doc service (#9961)

close CLOUD-94
This commit is contained in:
fengmk2
2025-02-08 03:37:41 +00:00
parent 5ae5fd88f1
commit 5d62c5e85c
37 changed files with 914 additions and 20 deletions

View File

@@ -10,8 +10,10 @@ import type { Request, Response } from 'express';
import { Socket } from 'socket.io';
import {
AccessDenied,
AuthenticationRequired,
Config,
CryptoHelper,
getRequestResponseFromContext,
parseCookies,
} from '../../base';
@@ -20,12 +22,14 @@ import { AuthService } from './service';
import { Session } from './session';
const PUBLIC_ENTRYPOINT_SYMBOL = Symbol('public');
const INTERNAL_ENTRYPOINT_SYMBOL = Symbol('internal');
@Injectable()
export class AuthGuard implements CanActivate, OnModuleInit {
private auth!: AuthService;
constructor(
private readonly crypto: CryptoHelper,
private readonly ref: ModuleRef,
private readonly reflector: Reflector
) {}
@@ -36,6 +40,21 @@ export class AuthGuard implements CanActivate, OnModuleInit {
async canActivate(context: ExecutionContext) {
const { req, res } = getRequestResponseFromContext(context);
const clazz = context.getClass();
const handler = context.getHandler();
// rpc request is internal
const isInternal = this.reflector.getAllAndOverride<boolean>(
INTERNAL_ENTRYPOINT_SYMBOL,
[clazz, handler]
);
if (isInternal) {
// check access token: data,signature
const accessToken = req.get('x-access-token');
if (accessToken && this.crypto.verify(accessToken)) {
return true;
}
throw new AccessDenied('Invalid internal request');
}
const userSession = await this.signIn(req, res);
if (res && userSession && userSession.expiresAt) {
@@ -45,7 +64,7 @@ export class AuthGuard implements CanActivate, OnModuleInit {
// api is public
const isPublic = this.reflector.getAllAndOverride<boolean>(
PUBLIC_ENTRYPOINT_SYMBOL,
[context.getClass(), context.getHandler()]
[clazz, handler]
);
if (isPublic) {
@@ -85,6 +104,11 @@ export class AuthGuard implements CanActivate, OnModuleInit {
*/
export const Public = () => SetMetadata(PUBLIC_ENTRYPOINT_SYMBOL, true);
/**
* Mark rpc api to be internal accessible
*/
export const Internal = () => SetMetadata(INTERNAL_ENTRYPOINT_SYMBOL, true);
export const AuthWebsocketOptionsProvider: FactoryProvider = {
provide: WEBSOCKET_OPTIONS,
useFactory: (config: Config, guard: AuthGuard) => {

View File

@@ -0,0 +1,122 @@
import { randomUUID } from 'node:crypto';
import { User, Workspace } from '@prisma/client';
import ava, { TestFn } from 'ava';
import request from 'supertest';
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
import { AppModule } from '../../../app.module';
import { CryptoHelper } from '../../../base';
import { ConfigModule } from '../../../base/config';
import { Models } from '../../../models';
const test = ava as TestFn<{
models: Models;
app: TestingApp;
crypto: CryptoHelper;
}>;
test.before(async t => {
const { app } = await createTestingApp({
imports: [ConfigModule.forRoot(), AppModule],
});
t.context.models = app.get(Models);
t.context.crypto = app.get(CryptoHelper);
t.context.app = app;
});
let user: User;
let workspace: Workspace;
test.beforeEach(async t => {
await t.context.app.initTestingDB();
user = await t.context.models.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.models.workspace.create(user.id);
});
test.after.always(async t => {
await t.context.app.close();
});
test('should forbid access to rpc api without access token', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/rpc/workspaces/123/docs/123')
.expect({
status: 403,
code: 'Forbidden',
type: 'NO_PERMISSION',
name: 'ACCESS_DENIED',
message: 'Invalid internal request',
})
.expect(403);
t.pass();
});
test('should forbid access to rpc api with invalid access token', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/rpc/workspaces/123/docs/123')
.set('x-access-token', 'invalid,wrong-signature')
.expect({
status: 403,
code: 'Forbidden',
type: 'NO_PERMISSION',
name: 'ACCESS_DENIED',
message: 'Invalid internal request',
})
.expect(403);
t.pass();
});
test('should 404 when doc not found', async t => {
const { app } = t.context;
const workspaceId = '123';
const docId = '123';
await request(app.getHttpServer())
.get(`/rpc/workspaces/${workspaceId}/docs/${docId}`)
.set('x-access-token', t.context.crypto.sign(docId))
.expect({
status: 404,
code: 'Not Found',
type: 'RESOURCE_NOT_FOUND',
name: 'NOT_FOUND',
message: 'Doc not found',
})
.expect(404);
t.pass();
});
test('should return doc when found', async t => {
const { app } = t.context;
const docId = randomUUID();
const timestamp = Date.now();
await t.context.models.doc.createUpdates([
{
spaceId: workspace.id,
docId,
blob: Buffer.from('blob1 data'),
timestamp,
editorId: user.id,
},
]);
const res = await request(app.getHttpServer())
.get(`/rpc/workspaces/${workspace.id}/docs/${docId}`)
.set('x-access-token', t.context.crypto.sign(docId))
.set('x-rpc-trace-id', 'test-trace-id')
.expect(200)
.expect('x-request-id', 'test-trace-id')
.expect('Content-Type', 'application/octet-stream');
const bin = res.body as Buffer;
t.is(bin.toString(), 'blob1 data');
t.is(res.headers['x-doc-timestamp'], timestamp.toString());
t.is(res.headers['x-doc-editor-id'], user.id);
});

View File

@@ -0,0 +1,19 @@
import { defineStartupConfig, ModuleConfig } from '../../base/config';
interface DocServiceStartupConfigurations {
/**
* The endpoint of the doc service.
* Example: http://doc-service:3020
*/
endpoint: string;
}
declare module '../../base/config' {
interface AppConfig {
docService: ModuleConfig<DocServiceStartupConfigurations>;
}
}
defineStartupConfig('docService', {
endpoint: '',
});

View File

@@ -0,0 +1,30 @@
import { Controller, Get, Param, Res } from '@nestjs/common';
import type { Response } from 'express';
import { NotFound, SkipThrottle } from '../../base';
import { Internal } from '../auth';
import { PgWorkspaceDocStorageAdapter } from '../doc';
@Controller('/rpc')
export class DocRpcController {
constructor(private readonly workspace: PgWorkspaceDocStorageAdapter) {}
@SkipThrottle()
@Internal()
@Get('/workspaces/:workspaceId/docs/:docId')
async render(
@Param('workspaceId') workspaceId: string,
@Param('docId') docId: string,
@Res() res: Response
) {
const doc = await this.workspace.getDoc(workspaceId, docId);
if (!doc) {
throw new NotFound('Doc not found');
}
res.setHeader('x-doc-timestamp', doc.timestamp.toString());
if (doc.editor) {
res.setHeader('x-doc-editor-id', doc.editor);
}
res.send(doc.bin);
}
}

View File

@@ -0,0 +1,10 @@
import { Module } from '@nestjs/common';
import { DocStorageModule } from '../doc';
import { DocRpcController } from './controller';
@Module({
imports: [DocStorageModule],
controllers: [DocRpcController],
})
export class DocServiceModule {}

View File

@@ -0,0 +1,72 @@
import { randomUUID } from 'node:crypto';
import { User, Workspace } from '@prisma/client';
import ava, { TestFn } from 'ava';
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
import { AppModule } from '../../../app.module';
import { ConfigModule } from '../../../base/config';
import { Models } from '../../../models';
import { DocReader } from '..';
import { DatabaseDocReader } from '../reader';
const test = ava as TestFn<{
models: Models;
app: TestingApp;
docReader: DocReader;
}>;
test.before(async t => {
const { app } = await createTestingApp({
imports: [ConfigModule.forRoot(), AppModule],
});
t.context.models = app.get(Models);
t.context.docReader = app.get(DocReader);
t.context.app = app;
});
let user: User;
let workspace: Workspace;
test.beforeEach(async t => {
await t.context.app.initTestingDB();
user = await t.context.models.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.models.workspace.create(user.id);
});
test.after.always(async t => {
await t.context.app.close();
});
test('should return null when doc not found', async t => {
const { docReader } = t.context;
const docId = randomUUID();
const doc = await docReader.getDoc(workspace.id, docId);
t.is(doc, null);
});
test('should return doc when found', async t => {
const { docReader } = t.context;
t.true(docReader instanceof DatabaseDocReader);
const docId = randomUUID();
const timestamp = Date.now();
await t.context.models.doc.createUpdates([
{
spaceId: workspace.id,
docId,
blob: Buffer.from('blob1 data'),
timestamp,
editorId: user.id,
},
]);
const doc = await docReader.getDoc(workspace.id, docId);
t.truthy(doc);
t.is(doc!.bin.toString(), 'blob1 data');
t.is(doc!.timestamp, timestamp);
t.is(doc!.editor, user.id);
});

View File

@@ -0,0 +1,124 @@
import { randomUUID } from 'node:crypto';
import { mock } from 'node:test';
import { User, Workspace } from '@prisma/client';
import ava, { TestFn } from 'ava';
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
import { AppModule } from '../../../app.module';
import { Config, InternalServerError } from '../../../base';
import { ConfigModule } from '../../../base/config';
import { Models } from '../../../models';
import { DocReader } from '..';
import { RpcDocReader } from '../reader';
const test = ava as TestFn<{
models: Models;
app: TestingApp;
docReader: DocReader;
config: Config;
}>;
test.before(async t => {
const { app } = await createTestingApp({
imports: [
ConfigModule.forRoot({
flavor: {
doc: false,
},
docService: {
endpoint: '',
},
}),
AppModule,
],
});
t.context.models = app.get(Models);
t.context.docReader = app.get(DocReader);
t.context.config = app.get(Config);
t.context.app = app;
});
let user: User;
let workspace: Workspace;
test.beforeEach(async t => {
t.context.config.docService.endpoint = t.context.app.getHttpServerUrl();
await t.context.app.initTestingDB();
user = await t.context.models.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.models.workspace.create(user.id);
});
test.afterEach.always(() => {
mock.reset();
});
test.after.always(async t => {
await t.context.app.close();
});
test('should return null when doc not found', async t => {
const { docReader } = t.context;
const docId = randomUUID();
const doc = await docReader.getDoc(workspace.id, docId);
t.is(doc, null);
});
test('should throw error when doc service internal error', async t => {
const { docReader } = t.context;
const docId = randomUUID();
mock.method(docReader, 'getDoc', async () => {
throw new InternalServerError('mock doc service internal error');
});
await t.throwsAsync(docReader.getDoc(workspace.id, docId), {
instanceOf: InternalServerError,
});
});
test('should fallback to database doc service when endpoint network error', async t => {
const { docReader } = t.context;
t.context.config.docService.endpoint = 'http://localhost:13010';
const docId = randomUUID();
const timestamp = Date.now();
await t.context.models.doc.createUpdates([
{
spaceId: workspace.id,
docId,
blob: Buffer.from('blob1 data'),
timestamp,
editorId: user.id,
},
]);
const doc = await docReader.getDoc(workspace.id, docId);
t.truthy(doc);
t.is(doc!.bin.toString(), 'blob1 data');
t.is(doc!.timestamp, timestamp);
t.is(doc!.editor, user.id);
});
test('should return doc when found', async t => {
const { docReader } = t.context;
t.true(docReader instanceof RpcDocReader);
const docId = randomUUID();
const timestamp = Date.now();
await t.context.models.doc.createUpdates([
{
spaceId: workspace.id,
docId,
blob: Buffer.from('blob1 data'),
timestamp,
editorId: user.id,
},
]);
const doc = await docReader.getDoc(workspace.id, docId);
t.truthy(doc);
t.is(doc!.bin.toString(), 'blob1 data');
t.is(doc!.timestamp, timestamp);
t.is(doc!.editor, user.id);
});

View File

@@ -41,7 +41,9 @@ declare global {
}
@Injectable()
export class PgWorkspaceDocStorageAdapter extends DocStorageAdapter {
private readonly logger = new Logger(PgWorkspaceDocStorageAdapter.name);
protected override readonly logger = new Logger(
PgWorkspaceDocStorageAdapter.name
);
constructor(
private readonly models: Models,

View File

@@ -8,6 +8,7 @@ import { PgUserspaceDocStorageAdapter } from './adapters/userspace';
import { PgWorkspaceDocStorageAdapter } from './adapters/workspace';
import { DocStorageCronJob } from './job';
import { DocStorageOptions } from './options';
import { DocReader, DocReaderProvider } from './reader';
@Module({
imports: [QuotaModule, PermissionModule],
@@ -16,10 +17,15 @@ import { DocStorageOptions } from './options';
PgWorkspaceDocStorageAdapter,
PgUserspaceDocStorageAdapter,
DocStorageCronJob,
DocReaderProvider,
],
exports: [PgWorkspaceDocStorageAdapter, PgUserspaceDocStorageAdapter],
})
export class DocStorageModule {}
export { PgUserspaceDocStorageAdapter, PgWorkspaceDocStorageAdapter };
export {
DocReader,
PgUserspaceDocStorageAdapter,
PgWorkspaceDocStorageAdapter,
};
export { DocStorageAdapter, type Editor } from './storage';

View File

@@ -0,0 +1,93 @@
import { FactoryProvider, Injectable, Logger } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { ClsService } from 'nestjs-cls';
import { Config, CryptoHelper, UserFriendlyError } from '../../base';
import { PgWorkspaceDocStorageAdapter } from './adapters/workspace';
import { type DocRecord } from './storage';
export abstract class DocReader {
abstract getDoc(
workspaceId: string,
docId: string
): Promise<DocRecord | null>;
}
@Injectable()
export class DatabaseDocReader extends DocReader {
constructor(protected readonly workspace: PgWorkspaceDocStorageAdapter) {
super();
}
async getDoc(workspaceId: string, docId: string): Promise<DocRecord | null> {
return await this.workspace.getDoc(workspaceId, docId);
}
}
@Injectable()
export class RpcDocReader extends DatabaseDocReader {
private readonly logger = new Logger(DocReader.name);
constructor(
private readonly config: Config,
private readonly crypto: CryptoHelper,
private readonly cls: ClsService,
protected override readonly workspace: PgWorkspaceDocStorageAdapter
) {
super(workspace);
}
override async getDoc(
workspaceId: string,
docId: string
): Promise<DocRecord | null> {
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}`;
try {
const res = await fetch(url, {
headers: {
'x-access-token': this.crypto.sign(docId),
'x-rpc-trace-id': this.cls.getId(),
},
});
if (!res.ok) {
if (res.status === 404) {
return null;
}
const body = (await res.json()) as UserFriendlyError;
throw UserFriendlyError.fromUserFriendlyErrorJSON(body);
}
const timestamp = res.headers.get('x-doc-timestamp') as string;
const editor = res.headers.get('x-doc-editor-id') as string;
const bin = await res.arrayBuffer();
return {
spaceId: workspaceId,
docId,
bin: Buffer.from(bin),
timestamp: parseInt(timestamp),
editor,
};
} catch (err) {
if (err instanceof UserFriendlyError) {
throw err;
}
// other error
this.logger.error(
`Failed to fetch doc ${url}, error: ${err}`,
(err as Error).stack
);
// fallback to database doc service if the error is not user friendly, like network error
return await super.getDoc(workspaceId, docId);
}
}
}
export const DocReaderProvider: FactoryProvider = {
provide: DocReader,
useFactory: (config: Config, ref: ModuleRef) => {
if (config.flavor.doc) {
return ref.create(DatabaseDocReader);
}
return ref.create(RpcDocReader);
},
inject: [Config, ModuleRef],
};

View File

@@ -1,3 +1,4 @@
import { Logger } from '@nestjs/common';
import {
applyUpdate,
diffUpdate,
@@ -49,6 +50,7 @@ export interface DocStorageOptions {
export abstract class DocStorageAdapter extends Connection {
private readonly locker = new SingletonLocker();
protected readonly logger = new Logger(DocStorageAdapter.name);
constructor(
protected readonly options: DocStorageOptions = {
@@ -76,6 +78,9 @@ export abstract class DocStorageAdapter extends Connection {
const updates = await this.getDocUpdates(spaceId, docId);
if (updates.length) {
this.logger.log(
`Squashing updates, spaceId: ${spaceId}, docId: ${docId}, updates: ${updates.length}`
);
const { timestamp, bin, editor } = await this.squash(
snapshot ? [snapshot, ...updates] : updates
);
@@ -96,7 +101,12 @@ export abstract class DocStorageAdapter extends Connection {
}
// always mark updates as merged unless throws
await this.markUpdatesMerged(spaceId, docId, updates);
const count = await this.markUpdatesMerged(spaceId, docId, updates);
if (count > 0) {
this.logger.log(
`Marked ${count} updates as merged, spaceId: ${spaceId}, docId: ${docId}`
);
}
return newSnapshot;
}