feat(server): implement doc service (#9961)

close CLOUD-94
Author: fengmk2
Date: 2025-02-08 03:37:41 +00:00
parent 5ae5fd88f1
commit 5d62c5e85c
37 changed files with 914 additions and 20 deletions

View File

@@ -47,18 +47,21 @@ const replicaConfig = {
graphql: Number(process.env.PRODUCTION_GRAPHQL_REPLICA) || 3,
sync: Number(process.env.PRODUCTION_SYNC_REPLICA) || 3,
renderer: Number(process.env.PRODUCTION_RENDERER_REPLICA) || 3,
doc: Number(process.env.PRODUCTION_DOC_REPLICA) || 3,
},
beta: {
web: 2,
graphql: Number(process.env.BETA_GRAPHQL_REPLICA) || 2,
sync: Number(process.env.BETA_SYNC_REPLICA) || 2,
renderer: Number(process.env.BETA_RENDERER_REPLICA) || 2,
doc: Number(process.env.BETA_DOC_REPLICA) || 2,
},
canary: {
web: 2,
graphql: 2,
sync: 2,
renderer: 2,
doc: 2,
},
};
@@ -67,12 +70,14 @@ const cpuConfig = {
web: '300m',
graphql: '1',
sync: '1',
doc: '1',
renderer: '300m',
},
canary: {
web: '300m',
graphql: '1',
sync: '1',
doc: '1',
renderer: '300m',
},
};
@@ -111,6 +116,7 @@ const createHelmCommand = ({ isDryRun }) => {
`--set web.resources.requests.cpu="${cpu.web}"`,
`--set graphql.resources.requests.cpu="${cpu.graphql}"`,
`--set sync.resources.requests.cpu="${cpu.sync}"`,
`--set doc.resources.requests.cpu="${cpu.doc}"`,
]
: [];
@@ -168,6 +174,9 @@ const createHelmCommand = ({ isDryRun }) => {
`--set-string renderer.image.tag="${imageTag}"`,
`--set renderer.app.host=${host}`,
`--set renderer.replicaCount=${replica.renderer}`,
`--set-string doc.image.tag="${imageTag}"`,
`--set doc.app.host=${host}`,
`--set doc.replicaCount=${replica.doc}`,
...serviceAnnotations,
...resources,
`--timeout 10m`,

View File

@@ -0,0 +1,11 @@
apiVersion: v2
name: doc
description: AFFiNE doc server
type: application
version: 0.0.0
appVersion: "0.20.0"
dependencies:
  - name: gcloud-sql-proxy
    version: 0.0.0
    repository: "file://../gcloud-sql-proxy"
    condition: .global.database.gcloud.enabled

View File

@@ -0,0 +1,16 @@
1. Get the application URL by running these commands:
{{- if contains "NodePort" .Values.service.type }}
export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "doc.fullname" . }})
export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}")
echo http://$NODE_IP:$NODE_PORT
{{- else if contains "LoadBalancer" .Values.service.type }}
NOTE: It may take a few minutes for the LoadBalancer IP to be available.
You can watch its status by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "doc.fullname" . }}'
export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "doc.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}")
echo http://$SERVICE_IP:{{ .Values.service.port }}
{{- else if contains "ClusterIP" .Values.service.type }}
export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "doc.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}")
export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}")
echo "Visit http://127.0.0.1:8080 to use your application"
kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8080:$CONTAINER_PORT
{{- end }}

View File

@@ -0,0 +1,63 @@
{{/*
Expand the name of the chart.
*/}}
{{- define "doc.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Create a default fully qualified app name.
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
If release name contains chart name it will be used as a full name.
*/}}
{{- define "doc.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Create chart name and version as used by the chart label.
*/}}
{{- define "doc.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Common labels
*/}}
{{- define "doc.labels" -}}
helm.sh/chart: {{ include "doc.chart" . }}
{{ include "doc.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
monitoring: enabled
{{- end }}
{{/*
Selector labels
*/}}
{{- define "doc.selectorLabels" -}}
app.kubernetes.io/name: {{ include "doc.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}
{{/*
Create the name of the service account to use
*/}}
{{- define "doc.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "doc.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,105 @@
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "doc.fullname" . }}
  labels:
    {{- include "doc.labels" . | nindent 4 }}
spec:
  replicas: {{ .Values.replicaCount }}
  selector:
    matchLabels:
      {{- include "doc.selectorLabels" . | nindent 6 }}
  template:
    metadata:
      {{- with .Values.podAnnotations }}
      annotations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      labels:
        {{- include "doc.selectorLabels" . | nindent 8 }}
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "doc.serviceAccountName" . }}
      containers:
        - name: {{ .Chart.Name }}
          image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.image.pullPolicy }}
          env:
            - name: AFFINE_PRIVATE_KEY
              valueFrom:
                secretKeyRef:
                  name: "{{ .Values.global.secret.secretName }}"
                  key: key
            - name: NODE_ENV
              value: "{{ .Values.env }}"
            - name: NODE_OPTIONS
              value: "--max-old-space-size=4096"
            - name: NO_COLOR
              value: "1"
            - name: DEPLOYMENT_TYPE
              value: "affine"
            - name: SERVER_FLAVOR
              value: "doc"
            - name: AFFINE_ENV
              value: "{{ .Release.Namespace }}"
            - name: DATABASE_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: pg-postgresql
                  key: postgres-password
            - name: DATABASE_URL
              value: postgres://{{ .Values.global.database.user }}:$(DATABASE_PASSWORD)@{{ .Values.global.database.url }}:{{ .Values.global.database.port }}/{{ .Values.global.database.name }}
            - name: REDIS_SERVER_ENABLED
              value: "true"
            - name: REDIS_SERVER_HOST
              value: "{{ .Values.global.redis.host }}"
            - name: REDIS_SERVER_PORT
              value: "{{ .Values.global.redis.port }}"
            - name: REDIS_SERVER_USER
              value: "{{ .Values.global.redis.username }}"
            - name: REDIS_SERVER_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: redis
                  key: redis-password
            - name: REDIS_SERVER_DATABASE
              value: "{{ .Values.global.redis.database }}"
            - name: AFFINE_SERVER_PORT
              value: "{{ .Values.service.port }}"
            - name: AFFINE_SERVER_SUB_PATH
              value: "{{ .Values.app.path }}"
            - name: AFFINE_SERVER_HOST
              value: "{{ .Values.app.host }}"
            - name: AFFINE_SERVER_HTTPS
              value: "{{ .Values.app.https }}"
          ports:
            - name: http
              containerPort: {{ .Values.service.port }}
              protocol: TCP
          livenessProbe:
            httpGet:
              path: /info
              port: http
            initialDelaySeconds: {{ .Values.probe.initialDelaySeconds }}
          readinessProbe:
            httpGet:
              path: /info
              port: http
            initialDelaySeconds: {{ .Values.probe.initialDelaySeconds }}
          resources:
            {{- toYaml .Values.resources | nindent 12 }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}

View File

@@ -0,0 +1,19 @@
apiVersion: v1
kind: Service
metadata:
  name: {{ include "doc.fullname" . }}
  labels:
    {{- include "doc.labels" . | nindent 4 }}
  {{- with .Values.service.annotations }}
  annotations:
    {{- toYaml . | nindent 4 }}
  {{- end }}
spec:
  type: {{ .Values.service.type }}
  ports:
    - port: {{ .Values.service.port }}
      targetPort: http
      protocol: TCP
      name: http
  selector:
    {{- include "doc.selectorLabels" . | nindent 4 }}

View File

@@ -0,0 +1,12 @@
{{- if .Values.serviceAccount.create -}}
apiVersion: v1
kind: ServiceAccount
metadata:
  name: {{ include "doc.serviceAccountName" . }}
  labels:
    {{- include "doc.labels" . | nindent 4 }}
  {{- with .Values.serviceAccount.annotations }}
  annotations:
    {{- toYaml . | nindent 4 }}
  {{- end }}
{{- end }}

View File

@@ -0,0 +1,15 @@
apiVersion: v1
kind: Pod
metadata:
  name: "{{ include "doc.fullname" . }}-test-connection"
  labels:
    {{- include "doc.labels" . | nindent 4 }}
  annotations:
    "helm.sh/hook": test
spec:
  containers:
    - name: wget
      image: busybox
      command: ['wget']
      args: ['{{ include "doc.fullname" . }}:{{ .Values.service.port }}']
  restartPolicy: Never

View File

@@ -0,0 +1,38 @@
replicaCount: 1
image:
  repository: ghcr.io/toeverything/affine-graphql
  pullPolicy: IfNotPresent
  tag: ''
imagePullSecrets: []
nameOverride: ''
fullnameOverride: ''
# map to NODE_ENV environment variable
env: 'production'
app:
  # AFFINE_SERVER_SUB_PATH
  path: ''
  # AFFINE_SERVER_HOST
  host: '0.0.0.0'
  https: true
serviceAccount:
  create: true
  annotations: {}
  name: 'affine-doc'
podAnnotations: {}
podSecurityContext:
  fsGroup: 2000
resources:
  requests:
    cpu: '2'
    memory: 4Gi
probe:
  initialDelaySeconds: 20
nodeSelector: {}
tolerations: []
affinity: {}

View File

@@ -118,6 +118,8 @@ spec:
key: stripeWebhookKey
- name: DOC_MERGE_INTERVAL
value: "{{ .Values.app.doc.mergeInterval }}"
- name: DOC_SERVICE_ENDPOINT
value: "{{ .Values.global.docService.endpoint }}"
{{ if .Values.app.experimental.enableJwstCodec }}
- name: DOC_MERGE_USE_JWST_CODEC
value: "true"

View File

@@ -94,6 +94,8 @@ spec:
name: "{{ .Values.global.objectStorage.r2.secretName }}" name: "{{ .Values.global.objectStorage.r2.secretName }}"
key: secretAccessKey key: secretAccessKey
{{ end }} {{ end }}
- name: DOC_SERVICE_ENDPOINT
value: "{{ .Values.global.docService.endpoint }}"
ports: ports:
- name: http - name: http
containerPort: {{ .Values.service.port }} containerPort: {{ .Values.service.port }}

View File

@@ -73,6 +73,8 @@ spec:
value: "{{ .Values.service.port }}" value: "{{ .Values.service.port }}"
- name: AFFINE_SERVER_HOST - name: AFFINE_SERVER_HOST
value: "{{ .Values.app.host }}" value: "{{ .Values.app.host }}"
- name: DOC_SERVICE_ENDPOINT
value: "{{ .Values.global.docService.endpoint }}"
ports: ports:
- name: http - name: http
containerPort: {{ .Values.service.port }} containerPort: {{ .Values.service.port }}

View File

@@ -39,6 +39,8 @@ global:
secretAccessKey: ''
gke:
enabled: true
docService:
endpoint: 'http://affine-doc:3020'
graphql:
service:
@@ -61,6 +63,13 @@ renderer:
annotations:
cloud.google.com/backend-config: '{"default": "affine-api-backendconfig"}'
doc:
service:
type: ClusterIP
port: 3020
annotations:
cloud.google.com/backend-config: '{"default": "affine-api-backendconfig"}'
web:
service:
type: ClusterIP

View File

@@ -0,0 +1,40 @@
import type { INestApplication } from '@nestjs/common';
import type { TestFn } from 'ava';
import ava from 'ava';
import request from 'supertest';
import { buildAppModule } from '../../app.module';
import { createTestingApp } from '../utils';
const test = ava as TestFn<{
app: INestApplication;
}>;
test.before('start app', async t => {
// @ts-expect-error override
AFFiNE.flavor = {
type: 'doc',
allinone: false,
graphql: false,
sync: false,
renderer: false,
doc: true,
} satisfies typeof AFFiNE.flavor;
const { app } = await createTestingApp({
imports: [buildAppModule()],
});
t.context.app = app;
});
test.after.always(async t => {
await t.context.app.close();
});
test('should init app', async t => {
const res = await request(t.context.app.getHttpServer())
.get('/info')
.expect(200);
t.is(res.body.flavor, 'doc');
});

View File

@@ -20,6 +20,7 @@ test.before('start app', async t => {
graphql: true,
sync: false,
renderer: false,
doc: false,
} satisfies typeof AFFiNE.flavor;
const { app } = await createTestingApp({
imports: [buildAppModule()],

View File

@@ -18,6 +18,7 @@ test.before('start app', async t => {
graphql: false,
sync: false,
renderer: true,
doc: false,
} satisfies typeof AFFiNE.flavor;
const { app } = await createTestingApp({
imports: [buildAppModule()],

View File

@@ -18,6 +18,7 @@ test.before('start app', async t => {
graphql: false,
sync: true,
renderer: false,
doc: false,
} satisfies typeof AFFiNE.flavor;
const { app } = await createTestingApp({
imports: [buildAppModule()],

View File

@@ -58,6 +58,8 @@ export type TestingModule = BaseTestingModule & {
export type TestingApp = INestApplication & {
initTestingDB(): Promise<void>;
[Symbol.asyncDispose](): Promise<void>;
// get the url of the http server, e.g. http://localhost:random-port
getHttpServerUrl(): string;
};
function dedupeModules(modules: NonNullable<ModuleMetadata['imports']>) {
@@ -180,6 +182,15 @@ export async function createTestingApp(
await m[Symbol.asyncDispose]();
await app.close();
};
app.getHttpServerUrl = () => {
const server = app.getHttpServer();
if (!server.address()) {
server.listen();
}
return `http://localhost:${server.address().port}`;
};
return {
module: m,
app: app,
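getHttpServerUrl() lets a test point a plain HTTP client at the in-process Nest server on its randomly assigned port. A minimal usage sketch, assuming the helpers added in this commit; the import paths and the test itself are illustrative:

import test from 'ava';

import { buildAppModule } from '../../app.module';
import { createTestingApp } from '../utils';

test('info endpoint is reachable over plain HTTP', async t => {
  const { app } = await createTestingApp({ imports: [buildAppModule()] });
  // Starts listening on an ephemeral port on first use if needed.
  const baseUrl = app.getHttpServerUrl(); // e.g. http://localhost:54321

  const res = await fetch(`${baseUrl}/info`);
  t.is(res.status, 200);

  await app.close();
});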

View File

@@ -10,7 +10,7 @@ import { ScheduleModule } from '@nestjs/schedule';
import { ClsPluginTransactional } from '@nestjs-cls/transactional';
import { TransactionalAdapterPrisma } from '@nestjs-cls/transactional-adapter-prisma';
import { PrismaClient } from '@prisma/client';
-import { Response } from 'express';
import { Request, Response } from 'express';
import { get } from 'lodash-es';
import { ClsModule } from 'nestjs-cls';
@@ -36,6 +36,7 @@ import { AuthModule } from './core/auth';
import { ADD_ENABLED_FEATURES, ServerConfigModule } from './core/config';
import { DocStorageModule } from './core/doc';
import { DocRendererModule } from './core/doc-renderer';
import { DocServiceModule } from './core/doc-service';
import { FeatureModule } from './core/features';
import { PermissionModule } from './core/permission';
import { QuotaModule } from './core/quota';
@@ -56,9 +57,9 @@ export const FunctionalityModules = [
middleware: {
mount: true,
generateId: true,
-idGenerator() {
idGenerator(req: Request) {
// make every request has a unique id to tracing
-return `req-${randomUUID()}`;
return req.get('x-rpc-trace-id') ?? `req-${randomUUID()}`;
},
setup(cls, _req, res: Response) {
res.setHeader('X-Request-Id', cls.getId());
@@ -219,6 +220,9 @@ export function buildAppModule() {
LicenseModule
)
// doc service only
.useIf(config => config.flavor.doc, DocServiceModule)
// self hosted server only
.useIf(config => config.isSelfhosted, SelfhostModule)
.useIf(config => config.flavor.renderer, DocRendererModule);
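With this change the CLS request id is taken from an incoming x-rpc-trace-id header when present, so a caller and the doc service share one identifier across their logs. A small sketch of the propagation, assuming global fetch on Node 18+; the helper below is illustrative and not part of this commit:

import { randomUUID } from 'node:crypto';

// Illustrative helper: call the doc service with an explicit trace id.
export async function fetchDocWithTrace(
  endpoint: string,
  workspaceId: string,
  docId: string
) {
  const traceId = `req-${randomUUID()}`;
  const res = await fetch(
    `${endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}`,
    { headers: { 'x-rpc-trace-id': traceId } }
  );
  // The callee adopts the incoming id as its request id, so its X-Request-Id
  // response header (and its logs) carry the same value as the caller's.
  return { traceId, requestId: res.headers.get('x-request-id'), res };
}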

View File

@@ -2,7 +2,7 @@ import type { LeafPaths } from '../utils/types';
import { AppStartupConfig } from './types';
export type EnvConfigType = 'string' | 'int' | 'float' | 'boolean';
-export type ServerFlavor = 'allinone' | 'graphql' | 'sync' | 'renderer';
export type ServerFlavor = 'allinone' | 'graphql' | 'sync' | 'renderer' | 'doc';
export type AFFINE_ENV = 'dev' | 'beta' | 'production';
export type NODE_ENV = 'development' | 'test' | 'production' | 'script';
@@ -41,9 +41,9 @@ export type AFFiNEConfig = PreDefinedAFFiNEConfig &
AppPluginsConfig;
declare global {
-// eslint-disable-next-line @typescript-eslint/no-namespace
// oxlint-disable-next-line @typescript-eslint/no-namespace
namespace globalThis {
-// eslint-disable-next-line no-var
// oxlint-disable-next-line no-var
var AFFiNE: AFFiNEConfig;
}
}

View File

@@ -30,6 +30,7 @@ function getPredefinedAFFiNEConfig(): PreDefinedAFFiNEConfig {
'graphql',
'sync',
'renderer',
'doc',
]);
const deploymentType = readEnv<DeploymentType>(
'DEPLOYMENT_TYPE',
@@ -66,6 +67,7 @@ function getPredefinedAFFiNEConfig(): PreDefinedAFFiNEConfig {
graphql: flavor === 'graphql' || flavor === 'allinone',
sync: flavor === 'sync' || flavor === 'allinone',
renderer: flavor === 'renderer' || flavor === 'allinone',
doc: flavor === 'doc' || flavor === 'allinone',
},
affine,
node,

View File

@@ -101,6 +101,15 @@ export class UserFriendlyError extends Error {
this.requestId = ClsServiceManager.getClsService()?.getId();
}
static fromUserFriendlyErrorJSON(body: UserFriendlyError) {
return new UserFriendlyError(
body.type as UserFriendlyErrorBaseType,
body.name.toLowerCase() as keyof typeof USER_FRIENDLY_ERRORS,
body.message,
body.data
);
}
toJSON() {
return {
status: this.status,

View File

@@ -50,8 +50,18 @@ test.beforeEach(async t => {
test('should be able to sign and verify', t => {
const data = 'hello world';
const signature = t.context.crypto.sign(data);
-t.true(t.context.crypto.verify(data, signature));
-t.false(t.context.crypto.verify(data, 'fake-signature'));
t.true(t.context.crypto.verify(signature));
t.false(t.context.crypto.verify('fake-signature'));
t.false(t.context.crypto.verify(`${data},fake-signature`));
});
test('should same data should get different signature', t => {
const data = 'hello world';
const signature = t.context.crypto.sign(data);
const signature2 = t.context.crypto.sign(data);
t.not(signature2, signature);
t.true(t.context.crypto.verify(signature));
t.true(t.context.crypto.verify(signature2));
});
test('should be able to encrypt and decrypt', t => {

View File

@@ -46,10 +46,14 @@ export class CryptoHelper {
const sign = createSign('rsa-sha256');
sign.update(data, 'utf-8');
sign.end();
-return sign.sign(this.keyPair.privateKey, 'base64');
return `${data},${sign.sign(this.keyPair.privateKey, 'base64')}`;
}
-verify(data: string, signature: string) {
verify(signatureWithData: string) {
const [data, signature] = signatureWithData.split(',');
if (!signature) {
return false;
}
const verify = createVerify('rsa-sha256');
verify.update(data, 'utf-8');
verify.end();
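sign() now returns a self-contained `data,signature` token and verify() accepts that single string; this is the value the auth guard later checks in the x-access-token header. A minimal sketch of the new contract, with behavior taken from the tests above; the wrapper function is illustrative:

import { CryptoHelper } from '../../base';

// Illustrative wrapper around the changed API surface.
export function issueAndCheck(crypto: CryptoHelper, docId: string) {
  const token = crypto.sign(docId);    // "<docId>,<base64 signature>"
  crypto.verify(token);                // true: payload and signature match
  crypto.verify(`${docId},tampered`);  // false: signature does not verify
  crypto.verify('no-signature-part');  // false: nothing after the comma
  return token;                        // later sent as the x-access-token header
}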

View File

@@ -36,6 +36,7 @@ AFFiNE.ENV_MAP = {
REDIS_SERVER_PASSWORD: 'redis.password',
REDIS_SERVER_DATABASE: ['redis.db', 'int'],
DOC_MERGE_INTERVAL: ['doc.manager.updatePollInterval', 'int'],
DOC_SERVICE_ENDPOINT: 'docService.endpoint',
STRIPE_API_KEY: 'plugins.payment.stripe.keys.APIKey',
STRIPE_WEBHOOK_KEY: 'plugins.payment.stripe.keys.webhookKey',
};

View File

@@ -10,8 +10,10 @@ import type { Request, Response } from 'express';
import { Socket } from 'socket.io';
import {
AccessDenied,
AuthenticationRequired,
Config,
CryptoHelper,
getRequestResponseFromContext,
parseCookies,
} from '../../base';
@@ -20,12 +22,14 @@ import { AuthService } from './service';
import { Session } from './session';
const PUBLIC_ENTRYPOINT_SYMBOL = Symbol('public');
const INTERNAL_ENTRYPOINT_SYMBOL = Symbol('internal');
@Injectable()
export class AuthGuard implements CanActivate, OnModuleInit {
private auth!: AuthService;
constructor(
private readonly crypto: CryptoHelper,
private readonly ref: ModuleRef,
private readonly reflector: Reflector
) {}
@@ -36,6 +40,21 @@ export class AuthGuard implements CanActivate, OnModuleInit {
async canActivate(context: ExecutionContext) {
const { req, res } = getRequestResponseFromContext(context);
const clazz = context.getClass();
const handler = context.getHandler();
// rpc request is internal
const isInternal = this.reflector.getAllAndOverride<boolean>(
INTERNAL_ENTRYPOINT_SYMBOL,
[clazz, handler]
);
if (isInternal) {
// check access token: data,signature
const accessToken = req.get('x-access-token');
if (accessToken && this.crypto.verify(accessToken)) {
return true;
}
throw new AccessDenied('Invalid internal request');
}
const userSession = await this.signIn(req, res);
if (res && userSession && userSession.expiresAt) {
@@ -45,7 +64,7 @@ export class AuthGuard implements CanActivate, OnModuleInit {
// api is public
const isPublic = this.reflector.getAllAndOverride<boolean>(
PUBLIC_ENTRYPOINT_SYMBOL,
-[context.getClass(), context.getHandler()]
[clazz, handler]
);
if (isPublic) {
@@ -85,6 +104,11 @@ export class AuthGuard implements CanActivate, OnModuleInit {
*/
export const Public = () => SetMetadata(PUBLIC_ENTRYPOINT_SYMBOL, true);
/**
* Mark rpc api to be internal accessible
*/
export const Internal = () => SetMetadata(INTERNAL_ENTRYPOINT_SYMBOL, true);
export const AuthWebsocketOptionsProvider: FactoryProvider = {
provide: WEBSOCKET_OPTIONS,
useFactory: (config: Config, guard: AuthGuard) => {
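Internal() pairs with the guard branch above: a handler marked with it skips session auth and instead requires a valid signed x-access-token. A compact sketch of an internal-only endpoint, assuming import paths like those of the DocRpcController added later in this commit; the controller below is illustrative:

import { Controller, Get } from '@nestjs/common';

import { SkipThrottle } from '../../base';
import { Internal } from '../auth';

@Controller('/rpc')
export class ExampleRpcController {
  @SkipThrottle()
  @Internal()
  @Get('/ping')
  ping() {
    // Only reachable with a valid "data,signature" token in x-access-token;
    // anything else is rejected with ACCESS_DENIED.
    return 'pong';
  }
}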

View File

@@ -0,0 +1,122 @@
import { randomUUID } from 'node:crypto';
import { User, Workspace } from '@prisma/client';
import ava, { TestFn } from 'ava';
import request from 'supertest';
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
import { AppModule } from '../../../app.module';
import { CryptoHelper } from '../../../base';
import { ConfigModule } from '../../../base/config';
import { Models } from '../../../models';
const test = ava as TestFn<{
models: Models;
app: TestingApp;
crypto: CryptoHelper;
}>;
test.before(async t => {
const { app } = await createTestingApp({
imports: [ConfigModule.forRoot(), AppModule],
});
t.context.models = app.get(Models);
t.context.crypto = app.get(CryptoHelper);
t.context.app = app;
});
let user: User;
let workspace: Workspace;
test.beforeEach(async t => {
await t.context.app.initTestingDB();
user = await t.context.models.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.models.workspace.create(user.id);
});
test.after.always(async t => {
await t.context.app.close();
});
test('should forbid access to rpc api without access token', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/rpc/workspaces/123/docs/123')
.expect({
status: 403,
code: 'Forbidden',
type: 'NO_PERMISSION',
name: 'ACCESS_DENIED',
message: 'Invalid internal request',
})
.expect(403);
t.pass();
});
test('should forbid access to rpc api with invalid access token', async t => {
const { app } = t.context;
await request(app.getHttpServer())
.get('/rpc/workspaces/123/docs/123')
.set('x-access-token', 'invalid,wrong-signature')
.expect({
status: 403,
code: 'Forbidden',
type: 'NO_PERMISSION',
name: 'ACCESS_DENIED',
message: 'Invalid internal request',
})
.expect(403);
t.pass();
});
test('should 404 when doc not found', async t => {
const { app } = t.context;
const workspaceId = '123';
const docId = '123';
await request(app.getHttpServer())
.get(`/rpc/workspaces/${workspaceId}/docs/${docId}`)
.set('x-access-token', t.context.crypto.sign(docId))
.expect({
status: 404,
code: 'Not Found',
type: 'RESOURCE_NOT_FOUND',
name: 'NOT_FOUND',
message: 'Doc not found',
})
.expect(404);
t.pass();
});
test('should return doc when found', async t => {
const { app } = t.context;
const docId = randomUUID();
const timestamp = Date.now();
await t.context.models.doc.createUpdates([
{
spaceId: workspace.id,
docId,
blob: Buffer.from('blob1 data'),
timestamp,
editorId: user.id,
},
]);
const res = await request(app.getHttpServer())
.get(`/rpc/workspaces/${workspace.id}/docs/${docId}`)
.set('x-access-token', t.context.crypto.sign(docId))
.set('x-rpc-trace-id', 'test-trace-id')
.expect(200)
.expect('x-request-id', 'test-trace-id')
.expect('Content-Type', 'application/octet-stream');
const bin = res.body as Buffer;
t.is(bin.toString(), 'blob1 data');
t.is(res.headers['x-doc-timestamp'], timestamp.toString());
t.is(res.headers['x-doc-editor-id'], user.id);
});

View File

@@ -0,0 +1,19 @@
import { defineStartupConfig, ModuleConfig } from '../../base/config';
interface DocServiceStartupConfigurations {
/**
* The endpoint of the doc service.
* Example: http://doc-service:3020
*/
endpoint: string;
}
declare module '../../base/config' {
interface AppConfig {
docService: ModuleConfig<DocServiceStartupConfigurations>;
}
}
defineStartupConfig('docService', {
endpoint: '',
});
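The endpoint is empty by default and is normally supplied through the DOC_SERVICE_ENDPOINT environment variable (mapped in ENV_MAP in this commit) or the Helm values. A sketch of overriding it in a test, mirroring the rpc reader spec later in this diff; the import paths and helper are illustrative:

import { createTestingApp } from '../../../__tests__/utils';
import { AppModule } from '../../../app.module';
import { ConfigModule } from '../../../base/config';

// Boot an app whose DocReader talks to a doc service at the given endpoint.
export async function bootWithDocService(endpoint: string) {
  const { app } = await createTestingApp({
    imports: [
      ConfigModule.forRoot({
        docService: { endpoint }, // same value DOC_SERVICE_ENDPOINT would set
      }),
      AppModule,
    ],
  });
  return app;
}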

View File

@@ -0,0 +1,30 @@
import { Controller, Get, Param, Res } from '@nestjs/common';
import type { Response } from 'express';
import { NotFound, SkipThrottle } from '../../base';
import { Internal } from '../auth';
import { PgWorkspaceDocStorageAdapter } from '../doc';
@Controller('/rpc')
export class DocRpcController {
constructor(private readonly workspace: PgWorkspaceDocStorageAdapter) {}
@SkipThrottle()
@Internal()
@Get('/workspaces/:workspaceId/docs/:docId')
async render(
@Param('workspaceId') workspaceId: string,
@Param('docId') docId: string,
@Res() res: Response
) {
const doc = await this.workspace.getDoc(workspaceId, docId);
if (!doc) {
throw new NotFound('Doc not found');
}
res.setHeader('x-doc-timestamp', doc.timestamp.toString());
if (doc.editor) {
res.setHeader('x-doc-editor-id', doc.editor);
}
res.send(doc.bin);
}
}

View File

@@ -0,0 +1,10 @@
import { Module } from '@nestjs/common';
import { DocStorageModule } from '../doc';
import { DocRpcController } from './controller';
@Module({
imports: [DocStorageModule],
controllers: [DocRpcController],
})
export class DocServiceModule {}

View File

@@ -0,0 +1,72 @@
import { randomUUID } from 'node:crypto';
import { User, Workspace } from '@prisma/client';
import ava, { TestFn } from 'ava';
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
import { AppModule } from '../../../app.module';
import { ConfigModule } from '../../../base/config';
import { Models } from '../../../models';
import { DocReader } from '..';
import { DatabaseDocReader } from '../reader';
const test = ava as TestFn<{
models: Models;
app: TestingApp;
docReader: DocReader;
}>;
test.before(async t => {
const { app } = await createTestingApp({
imports: [ConfigModule.forRoot(), AppModule],
});
t.context.models = app.get(Models);
t.context.docReader = app.get(DocReader);
t.context.app = app;
});
let user: User;
let workspace: Workspace;
test.beforeEach(async t => {
await t.context.app.initTestingDB();
user = await t.context.models.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.models.workspace.create(user.id);
});
test.after.always(async t => {
await t.context.app.close();
});
test('should return null when doc not found', async t => {
const { docReader } = t.context;
const docId = randomUUID();
const doc = await docReader.getDoc(workspace.id, docId);
t.is(doc, null);
});
test('should return doc when found', async t => {
const { docReader } = t.context;
t.true(docReader instanceof DatabaseDocReader);
const docId = randomUUID();
const timestamp = Date.now();
await t.context.models.doc.createUpdates([
{
spaceId: workspace.id,
docId,
blob: Buffer.from('blob1 data'),
timestamp,
editorId: user.id,
},
]);
const doc = await docReader.getDoc(workspace.id, docId);
t.truthy(doc);
t.is(doc!.bin.toString(), 'blob1 data');
t.is(doc!.timestamp, timestamp);
t.is(doc!.editor, user.id);
});

View File

@@ -0,0 +1,124 @@
import { randomUUID } from 'node:crypto';
import { mock } from 'node:test';
import { User, Workspace } from '@prisma/client';
import ava, { TestFn } from 'ava';
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
import { AppModule } from '../../../app.module';
import { Config, InternalServerError } from '../../../base';
import { ConfigModule } from '../../../base/config';
import { Models } from '../../../models';
import { DocReader } from '..';
import { RpcDocReader } from '../reader';
const test = ava as TestFn<{
models: Models;
app: TestingApp;
docReader: DocReader;
config: Config;
}>;
test.before(async t => {
const { app } = await createTestingApp({
imports: [
ConfigModule.forRoot({
flavor: {
doc: false,
},
docService: {
endpoint: '',
},
}),
AppModule,
],
});
t.context.models = app.get(Models);
t.context.docReader = app.get(DocReader);
t.context.config = app.get(Config);
t.context.app = app;
});
let user: User;
let workspace: Workspace;
test.beforeEach(async t => {
t.context.config.docService.endpoint = t.context.app.getHttpServerUrl();
await t.context.app.initTestingDB();
user = await t.context.models.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.models.workspace.create(user.id);
});
test.afterEach.always(() => {
mock.reset();
});
test.after.always(async t => {
await t.context.app.close();
});
test('should return null when doc not found', async t => {
const { docReader } = t.context;
const docId = randomUUID();
const doc = await docReader.getDoc(workspace.id, docId);
t.is(doc, null);
});
test('should throw error when doc service internal error', async t => {
const { docReader } = t.context;
const docId = randomUUID();
mock.method(docReader, 'getDoc', async () => {
throw new InternalServerError('mock doc service internal error');
});
await t.throwsAsync(docReader.getDoc(workspace.id, docId), {
instanceOf: InternalServerError,
});
});
test('should fallback to database doc service when endpoint network error', async t => {
const { docReader } = t.context;
t.context.config.docService.endpoint = 'http://localhost:13010';
const docId = randomUUID();
const timestamp = Date.now();
await t.context.models.doc.createUpdates([
{
spaceId: workspace.id,
docId,
blob: Buffer.from('blob1 data'),
timestamp,
editorId: user.id,
},
]);
const doc = await docReader.getDoc(workspace.id, docId);
t.truthy(doc);
t.is(doc!.bin.toString(), 'blob1 data');
t.is(doc!.timestamp, timestamp);
t.is(doc!.editor, user.id);
});
test('should return doc when found', async t => {
const { docReader } = t.context;
t.true(docReader instanceof RpcDocReader);
const docId = randomUUID();
const timestamp = Date.now();
await t.context.models.doc.createUpdates([
{
spaceId: workspace.id,
docId,
blob: Buffer.from('blob1 data'),
timestamp,
editorId: user.id,
},
]);
const doc = await docReader.getDoc(workspace.id, docId);
t.truthy(doc);
t.is(doc!.bin.toString(), 'blob1 data');
t.is(doc!.timestamp, timestamp);
t.is(doc!.editor, user.id);
});

View File

@@ -41,7 +41,9 @@ declare global {
}
@Injectable()
export class PgWorkspaceDocStorageAdapter extends DocStorageAdapter {
-private readonly logger = new Logger(PgWorkspaceDocStorageAdapter.name);
protected override readonly logger = new Logger(
PgWorkspaceDocStorageAdapter.name
);
constructor(
private readonly models: Models,

View File

@@ -8,6 +8,7 @@ import { PgUserspaceDocStorageAdapter } from './adapters/userspace';
import { PgWorkspaceDocStorageAdapter } from './adapters/workspace';
import { DocStorageCronJob } from './job';
import { DocStorageOptions } from './options';
import { DocReader, DocReaderProvider } from './reader';
@Module({
imports: [QuotaModule, PermissionModule],
@@ -16,10 +17,15 @@ import { DocStorageOptions } from './options';
PgWorkspaceDocStorageAdapter,
PgUserspaceDocStorageAdapter,
DocStorageCronJob,
DocReaderProvider,
],
exports: [PgWorkspaceDocStorageAdapter, PgUserspaceDocStorageAdapter],
})
export class DocStorageModule {}
-export { PgUserspaceDocStorageAdapter, PgWorkspaceDocStorageAdapter };
export {
DocReader,
PgUserspaceDocStorageAdapter,
PgWorkspaceDocStorageAdapter,
};
export { DocStorageAdapter, type Editor } from './storage';

View File

@@ -0,0 +1,93 @@
import { FactoryProvider, Injectable, Logger } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { ClsService } from 'nestjs-cls';
import { Config, CryptoHelper, UserFriendlyError } from '../../base';
import { PgWorkspaceDocStorageAdapter } from './adapters/workspace';
import { type DocRecord } from './storage';
export abstract class DocReader {
abstract getDoc(
workspaceId: string,
docId: string
): Promise<DocRecord | null>;
}
@Injectable()
export class DatabaseDocReader extends DocReader {
constructor(protected readonly workspace: PgWorkspaceDocStorageAdapter) {
super();
}
async getDoc(workspaceId: string, docId: string): Promise<DocRecord | null> {
return await this.workspace.getDoc(workspaceId, docId);
}
}
@Injectable()
export class RpcDocReader extends DatabaseDocReader {
private readonly logger = new Logger(DocReader.name);
constructor(
private readonly config: Config,
private readonly crypto: CryptoHelper,
private readonly cls: ClsService,
protected override readonly workspace: PgWorkspaceDocStorageAdapter
) {
super(workspace);
}
override async getDoc(
workspaceId: string,
docId: string
): Promise<DocRecord | null> {
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}`;
try {
const res = await fetch(url, {
headers: {
'x-access-token': this.crypto.sign(docId),
'x-rpc-trace-id': this.cls.getId(),
},
});
if (!res.ok) {
if (res.status === 404) {
return null;
}
const body = (await res.json()) as UserFriendlyError;
throw UserFriendlyError.fromUserFriendlyErrorJSON(body);
}
const timestamp = res.headers.get('x-doc-timestamp') as string;
const editor = res.headers.get('x-doc-editor-id') as string;
const bin = await res.arrayBuffer();
return {
spaceId: workspaceId,
docId,
bin: Buffer.from(bin),
timestamp: parseInt(timestamp),
editor,
};
} catch (err) {
if (err instanceof UserFriendlyError) {
throw err;
}
// other error
this.logger.error(
`Failed to fetch doc ${url}, error: ${err}`,
(err as Error).stack
);
// fallback to database doc service if the error is not user friendly, like network error
return await super.getDoc(workspaceId, docId);
}
}
}
export const DocReaderProvider: FactoryProvider = {
provide: DocReader,
useFactory: (config: Config, ref: ModuleRef) => {
if (config.flavor.doc) {
return ref.create(DatabaseDocReader);
}
return ref.create(RpcDocReader);
},
inject: [Config, ModuleRef],
};
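Consumers depend on the abstract DocReader and let the factory provider above pick the implementation: DatabaseDocReader inside the doc flavor, RpcDocReader everywhere else. A usage sketch; the service below and its import path are illustrative, not part of this commit:

import { Injectable } from '@nestjs/common';

import { DocReader } from '../doc';

@Injectable()
export class DocSnapshotService {
  constructor(private readonly docReader: DocReader) {}

  async getDocBinary(workspaceId: string, docId: string) {
    // doc flavor: read straight from Postgres (DatabaseDocReader).
    // other flavors: fetch over /rpc with a database fallback (RpcDocReader).
    const doc = await this.docReader.getDoc(workspaceId, docId);
    return doc?.bin ?? null;
  }
}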

View File

@@ -1,3 +1,4 @@
import { Logger } from '@nestjs/common';
import {
applyUpdate,
diffUpdate,
@@ -49,6 +50,7 @@ export interface DocStorageOptions {
export abstract class DocStorageAdapter extends Connection {
private readonly locker = new SingletonLocker();
protected readonly logger = new Logger(DocStorageAdapter.name);
constructor(
protected readonly options: DocStorageOptions = {
@@ -76,6 +78,9 @@ export abstract class DocStorageAdapter extends Connection {
const updates = await this.getDocUpdates(spaceId, docId);
if (updates.length) {
this.logger.log(
`Squashing updates, spaceId: ${spaceId}, docId: ${docId}, updates: ${updates.length}`
);
const { timestamp, bin, editor } = await this.squash(
snapshot ? [snapshot, ...updates] : updates
);
@@ -96,7 +101,12 @@ export abstract class DocStorageAdapter extends Connection {
}
// always mark updates as merged unless throws
-await this.markUpdatesMerged(spaceId, docId, updates);
const count = await this.markUpdatesMerged(spaceId, docId, updates);
if (count > 0) {
this.logger.log(
`Marked ${count} updates as merged, spaceId: ${spaceId}, docId: ${docId}`
);
}
return newSnapshot;
}

View File

@@ -345,12 +345,7 @@ export class LicenseService implements OnModuleInit {
if (!res.ok) {
const body = (await res.json()) as UserFriendlyError;
-throw new UserFriendlyError(
-body.type as any,
-body.name.toLowerCase() as any,
-body.message,
-body.data
-);
throw UserFriendlyError.fromUserFriendlyErrorJSON(body);
}
const data = (await res.json()) as T;