refactor(server): config system (#11081)

This commit is contained in:
forehalo
2025-03-27 12:32:28 +00:00
parent 7091111f85
commit 0ea38680fa
274 changed files with 7583 additions and 5841 deletions

View File

@@ -0,0 +1,13 @@
-- CreateTable
CREATE TABLE "app_configs" (
"id" VARCHAR NOT NULL,
"value" JSONB NOT NULL,
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(3) NOT NULL,
"last_updated_by" VARCHAR,
CONSTRAINT "app_configs_pkey" PRIMARY KEY ("id")
);
-- AddForeignKey
ALTER TABLE "app_configs" ADD CONSTRAINT "app_configs_last_updated_by_fkey" FOREIGN KEY ("last_updated_by") REFERENCES "users"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -8,7 +8,7 @@
"run-test": "./scripts/run-test.ts"
},
"scripts": {
"build": "tsc",
"build": "tsc -b",
"dev": "nodemon ./src/index.ts",
"dev:mail": "email dev -d src/mails",
"test": "ava --concurrency 1 --serial",
@@ -20,6 +20,7 @@
"data-migration": "cross-env NODE_ENV=development r ./src/data/index.ts",
"init": "yarn prisma migrate dev && yarn data-migration run",
"seed": "r ./src/seed/index.ts",
"genconfig": "r ./scripts/genconfig.ts",
"predeploy": "yarn prisma migrate deploy && node --import ./scripts/register.js ./dist/data/index.js run",
"postinstall": "prisma generate"
},
@@ -139,7 +140,8 @@
"nodemon": "^3.1.7",
"react-email": "3.0.7",
"sinon": "^19.0.2",
"supertest": "^7.0.0"
"supertest": "^7.0.0",
"why-is-node-running": "^3.2.2"
},
"nodemonConfig": {
"exec": "node",

File diff suppressed because it is too large Load Diff

View File

@@ -24,23 +24,25 @@ model User {
registered Boolean @default(true)
disabled Boolean @default(false)
features UserFeature[]
userStripeCustomer UserStripeCustomer?
workspacePermissions WorkspaceUserRole[]
docPermissions WorkspaceDocUserRole[]
connectedAccounts ConnectedAccount[]
sessions UserSession[]
aiSessions AiSession[]
updatedRuntimeConfigs RuntimeConfig[]
userSnapshots UserSnapshot[]
createdSnapshot Snapshot[] @relation("createdSnapshot")
updatedSnapshot Snapshot[] @relation("updatedSnapshot")
createdUpdate Update[] @relation("createdUpdate")
createdHistory SnapshotHistory[] @relation("createdHistory")
createdAiJobs AiJobs[] @relation("createdAiJobs")
features UserFeature[]
userStripeCustomer UserStripeCustomer?
workspacePermissions WorkspaceUserRole[]
docPermissions WorkspaceDocUserRole[]
connectedAccounts ConnectedAccount[]
sessions UserSession[]
aiSessions AiSession[]
/// @deprecated
deprecatedAppRuntimeSettings DeprecatedAppRuntimeSettings[]
appConfigs AppConfig[]
userSnapshots UserSnapshot[]
createdSnapshot Snapshot[] @relation("createdSnapshot")
updatedSnapshot Snapshot[] @relation("updatedSnapshot")
createdUpdate Update[] @relation("createdUpdate")
createdHistory SnapshotHistory[] @relation("createdHistory")
createdAiJobs AiJobs[] @relation("createdAiJobs")
// receive notifications
notifications Notification[] @relation("user_notifications")
settings UserSettings?
notifications Notification[] @relation("user_notifications")
settings UserSettings?
@@index([email])
@@map("users")
@@ -438,12 +440,12 @@ model AiContext {
}
model AiContextEmbedding {
id String @id @default(uuid()) @db.VarChar
contextId String @map("context_id") @db.VarChar
fileId String @map("file_id") @db.VarChar
id String @id @default(uuid()) @db.VarChar
contextId String @map("context_id") @db.VarChar
fileId String @map("file_id") @db.VarChar
// a file can be divided into multiple chunks and embedded separately.
chunk Int @db.Integer
content String @db.VarChar
chunk Int @db.Integer
content String @db.VarChar
embedding Unsupported("vector(1024)")
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
@@ -457,11 +459,11 @@ model AiContextEmbedding {
}
model AiWorkspaceEmbedding {
workspaceId String @map("workspace_id") @db.VarChar
docId String @map("doc_id") @db.VarChar
workspaceId String @map("workspace_id") @db.VarChar
docId String @map("doc_id") @db.VarChar
// a doc can be divided into multiple chunks and embedded separately.
chunk Int @db.Integer
content String @db.VarChar
chunk Int @db.Integer
content String @db.VarChar
embedding Unsupported("vector(1024)")
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
@@ -527,7 +529,8 @@ enum RuntimeConfigType {
Array
}
model RuntimeConfig {
/// @deprecated use AppConfig instead
model DeprecatedAppRuntimeSettings {
id String @id @db.VarChar
type RuntimeConfigType
module String @db.VarChar
@@ -544,6 +547,18 @@ model RuntimeConfig {
@@map("app_runtime_settings")
}
model AppConfig {
id String @id @db.VarChar
value Json @db.JsonB
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
lastUpdatedBy String? @map("last_updated_by") @db.VarChar
lastUpdatedByUser User? @relation(fields: [lastUpdatedBy], references: [id], onDelete: SetNull)
@@map("app_configs")
}
model DeprecatedUserSubscription {
id Int @id @default(autoincrement()) @db.Integer
userId String @map("user_id") @db.VarChar

View File

@@ -0,0 +1,101 @@
/* eslint-disable */
import '../src/prelude';
import '../src/app.module';
import fs from 'node:fs';
import { ProjectRoot } from '@affine-tools/utils/path';
import { Package } from '@affine-tools/utils/workspace';
import { getDescriptors, ConfigDescriptor } from '../src/base/config/register';
import { pick } from 'lodash-es';
interface PropertySchema {
description: string;
type?: 'array' | 'boolean' | 'integer' | 'number' | 'object' | 'string';
default?: any;
}
function convertDescriptorToSchemaProperty(descriptor: ConfigDescriptor<any>) {
const property: PropertySchema = {
...descriptor.schema,
description:
descriptor.schema.description +
`\n@default ${JSON.stringify(descriptor.default)}` +
(descriptor.env ? `\n@environment \`${descriptor.env[0]}\`` : '') +
(descriptor.link ? `\n@link ${descriptor.link}` : ''),
default: descriptor.default,
};
return property;
}
function generateJsonSchema(outputPath: string) {
const schema = {
$schema: 'http://json-schema.org/draft-07/schema#',
title: 'AFFiNE Application Configuration',
type: 'object',
properties: {},
};
getDescriptors().forEach(({ module, descriptors }) => {
schema.properties[module] = {
type: 'object',
description: `Configuration for ${module} module`,
properties: {},
};
descriptors.forEach(({ key, descriptor }) => {
schema.properties[module].properties[key] =
convertDescriptorToSchemaProperty(descriptor);
});
});
fs.writeFileSync(outputPath, JSON.stringify(schema, null, 2));
console.log(`Config schema generated at: ${outputPath}`);
}
function generateAdminConfigJson(outputPath: string) {
const config = {};
getDescriptors().forEach(({ module, descriptors }) => {
const modulizedConfig = {};
config[module] = modulizedConfig;
descriptors.forEach(({ key, descriptor }) => {
let type: string;
switch (descriptor.schema?.type) {
case 'number':
type = 'Number';
break;
case 'boolean':
type = 'Boolean';
break;
case 'array':
type = 'Array';
break;
case 'object':
type = 'Object';
break;
default:
type = 'String';
}
modulizedConfig[key] = {
type,
desc: descriptor.desc,
link: descriptor.link,
env: descriptor.env?.[0],
};
});
});
fs.writeFileSync(outputPath, JSON.stringify(config, null, 2));
}
function main() {
generateJsonSchema(
ProjectRoot.join('.docker', 'selfhost', 'schema.json').toString()
);
generateAdminConfigJson(
new Package('@affine/admin').join('src/config.json').toString()
);
}
main();

View File

@@ -1,17 +1,10 @@
import { execSync } from 'node:child_process';
import { generateKeyPairSync } from 'node:crypto';
import fs from 'node:fs';
import { homedir } from 'node:os';
import path from 'node:path';
const SELF_HOST_CONFIG_DIR = '/root/.affine/config';
function generateConfigFile() {
const content = fs.readFileSync('./dist/config/affine.js', 'utf-8');
return content.replace(
/(^\/\/#.*$)|(^\/\/\s+TODO.*$)|("use\sstrict";?)|(^.*lint-disable.*$)/gm,
''
);
}
const SELF_HOST_CONFIG_DIR = `${homedir()}/.affine/config`;
function generatePrivateKey() {
const key = generateKeyPairSync('ec', {
@@ -31,15 +24,12 @@ function generatePrivateKey() {
/**
* @type {Array<{ to: string; generator: () => string }>}
*/
const configFiles = [
{ to: 'affine.js', generator: generateConfigFile },
{ to: 'private.key', generator: generatePrivateKey },
];
const files = [{ to: 'private.key', generator: generatePrivateKey }];
function prepare() {
fs.mkdirSync(SELF_HOST_CONFIG_DIR, { recursive: true });
for (const { to, generator } of configFiles) {
for (const { to, generator } of files) {
const targetFilePath = path.join(SELF_HOST_CONFIG_DIR, to);
if (!fs.existsSync(targetFilePath)) {
console.log(`creating config file [${targetFilePath}].`);

View File

@@ -1,36 +0,0 @@
import type { INestApplication } from '@nestjs/common';
import type { TestFn } from 'ava';
import ava from 'ava';
import request from 'supertest';
import { buildAppModule } from '../../app.module';
import { createTestingApp } from '../utils';
const test = ava as TestFn<{
app: INestApplication;
}>;
test.before('start app', async t => {
// @ts-expect-error override
AFFiNE.flavor = {
type: 'doc',
doc: true,
} as typeof AFFiNE.flavor;
const app = await createTestingApp({
imports: [buildAppModule()],
});
t.context.app = app;
});
test.after.always(async t => {
await t.context.app.close();
});
test('should init app', async t => {
const res = await request(t.context.app.getHttpServer())
.get('/info')
.expect(200);
t.is(res.body.flavor, 'doc');
});

View File

@@ -1,137 +0,0 @@
import { Args, Mutation, Resolver } from '@nestjs/graphql';
import type { TestFn } from 'ava';
import ava from 'ava';
import GraphQLUpload, {
type FileUpload,
} from 'graphql-upload/GraphQLUpload.mjs';
import request from 'supertest';
import { buildAppModule } from '../../app.module';
import { Public } from '../../core/auth';
import { createTestingApp, TestingApp } from '../utils';
const gql = '/graphql';
const test = ava as TestFn<{
app: TestingApp;
}>;
@Resolver(() => String)
class TestResolver {
@Public()
@Mutation(() => Number)
async upload(
@Args({ name: 'body', type: () => GraphQLUpload })
body: FileUpload
): Promise<number> {
const size = await new Promise<number>((resolve, reject) => {
const stream = body.createReadStream();
let size = 0;
stream.on('data', chunk => (size += chunk.length));
stream.on('error', reject);
stream.on('end', () => resolve(size));
});
return size;
}
}
test.before('start app', async t => {
// @ts-expect-error override
AFFiNE.flavor = {
type: 'graphql',
graphql: true,
} as typeof AFFiNE.flavor;
const app = await createTestingApp({
imports: [buildAppModule()],
providers: [TestResolver],
});
t.context.app = app;
});
test.after.always(async t => {
await t.context.app.close();
});
test('should init app', async t => {
await request(t.context.app.getHttpServer())
.post(gql)
.send({
query: `
query {
error
}
`,
})
.expect(400);
const response = await request(t.context.app.getHttpServer())
.post(gql)
.send({
query: `query {
serverConfig {
name
version
type
features
}
}`,
})
.expect(200);
const config = response.body.data.serverConfig;
t.is(config.type, 'Affine');
t.true(Array.isArray(config.features));
// make sure the request id is set
t.truthy(response.headers['x-request-id']);
});
test('should return 404 for unknown path', async t => {
await request(t.context.app.getHttpServer()).get('/unknown').expect(404);
t.pass();
});
test('should be able to call apis', async t => {
const res = await request(t.context.app.getHttpServer())
.get('/info')
.expect(200);
t.is(res.body.flavor, 'graphql');
// make sure the request id is set
t.truthy(res.headers['x-request-id']);
});
test('should not throw internal error when graphql call with invalid params', async t => {
await t.throwsAsync(t.context.app.gql(`query { workspace("1") }`), {
message: /Failed to execute gql: query { workspace\("1"\) \}, status: 400/,
});
});
test('should can send maximum size of body', async t => {
const { app } = t.context;
const body = Buffer.from('a'.repeat(10 * 1024 * 1024 - 1));
const res = await app
.POST('/graphql')
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
.field(
'operations',
JSON.stringify({
name: 'upload',
query: `mutation upload($body: Upload!) { upload(body: $body) }`,
variables: { body: null },
})
)
.field('map', JSON.stringify({ '0': ['variables.body'] }))
.attach(
'0',
body,
`body-${Math.random().toString(16).substring(2, 10)}.data`
)
.expect(200);
t.is(Number(res.body.data.upload), body.length);
});

View File

@@ -1,36 +0,0 @@
import type { INestApplication } from '@nestjs/common';
import type { TestFn } from 'ava';
import ava from 'ava';
import request from 'supertest';
import { buildAppModule } from '../../app.module';
import { createTestingApp } from '../utils';
const test = ava as TestFn<{
app: INestApplication;
}>;
test.before('start app', async t => {
// @ts-expect-error override
AFFiNE.flavor = {
type: 'renderer',
renderer: true,
} as typeof AFFiNE.flavor;
const app = await createTestingApp({
imports: [buildAppModule()],
});
t.context.app = app;
});
test.after.always(async t => {
await t.context.app.close();
});
test('should init app', async t => {
const res = await request(t.context.app.getHttpServer())
.get('/info')
.expect(200);
t.is(res.body.flavor, 'renderer');
});

View File

@@ -8,7 +8,6 @@ import ava from 'ava';
import request from 'supertest';
import { buildAppModule } from '../../app.module';
import { Config } from '../../base';
import { Public } from '../../core/auth';
import { ServerService } from '../../core/config';
import { createTestingApp, type TestingApp } from '../utils';
@@ -49,18 +48,16 @@ export class TestResolver {
test.before('init selfhost server', async t => {
// @ts-expect-error override
AFFiNE.isSelfhosted = true;
AFFiNE.flavor.renderer = true;
globalThis.env.DEPLOYMENT_TYPE = 'selfhosted';
const app = await createTestingApp({
imports: [buildAppModule()],
imports: [buildAppModule(globalThis.env)],
controllers: [TestResolver],
});
t.context.app = app;
t.context.db = t.context.app.get(PrismaClient);
const config = app.get(Config);
const staticPath = path.join(config.projectRoot, 'static');
const staticPath = path.join(env.projectRoot, 'static');
initTestStaticFiles(staticPath);
});

View File

@@ -1,36 +0,0 @@
import type { INestApplication } from '@nestjs/common';
import type { TestFn } from 'ava';
import ava from 'ava';
import request from 'supertest';
import { buildAppModule } from '../../app.module';
import { createTestingApp } from '../utils';
const test = ava as TestFn<{
app: INestApplication;
}>;
test.before('start app', async t => {
// @ts-expect-error override
AFFiNE.flavor = {
type: 'sync',
sync: true,
} as typeof AFFiNE.flavor;
const app = await createTestingApp({
imports: [buildAppModule()],
});
t.context.app = app;
});
test.after.always(async t => {
await t.context.app.close();
});
test('should init app', async t => {
const res = await request(t.context.app.getHttpServer())
.get('/info')
.expect(200);
t.is(res.body.flavor, 'sync');
});

View File

@@ -5,10 +5,7 @@ import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { AuthModule } from '../../core/auth';
import { AuthService } from '../../core/auth/service';
import { FeatureModule } from '../../core/features';
import { UserModule } from '../../core/user';
import {
createTestingApp,
currentUser,
@@ -23,9 +20,7 @@ const test = ava as TestFn<{
}>;
test.before(async t => {
const app = await createTestingApp({
imports: [FeatureModule, UserModule, AuthModule],
});
const app = await createTestingApp();
t.context.auth = app.get(AuthService);
t.context.db = app.get(PrismaClient);

View File

@@ -1,39 +0,0 @@
import { TestingModule } from '@nestjs/testing';
import test from 'ava';
import { Config, ConfigModule } from '../base/config';
import { createTestingModule } from './utils';
let config: Config;
let module: TestingModule;
test.beforeEach(async () => {
module = await createTestingModule({}, false);
config = module.get(Config);
});
test.afterEach.always(async () => {
await module.close();
});
test('should be able to get config', t => {
t.true(typeof config.server.host === 'string');
t.is(config.projectRoot, process.cwd());
t.is(config.NODE_ENV, 'test');
});
test('should be able to override config', async t => {
const module = await createTestingModule({
imports: [
ConfigModule.forRoot({
server: {
host: 'testing',
},
}),
],
});
const config = module.get(Config);
t.is(config.server.host, 'testing');
await module.close();
});

View File

@@ -7,14 +7,7 @@ import { AuthService } from '../core/auth';
import { QuotaModule } from '../core/quota';
import { CopilotModule } from '../plugins/copilot';
import { prompts, PromptService } from '../plugins/copilot/prompt';
import {
CopilotProviderService,
FalProvider,
OpenAIProvider,
PerplexityProvider,
registerCopilotProvider,
unregisterCopilotProvider,
} from '../plugins/copilot/providers';
import { CopilotProviderFactory } from '../plugins/copilot/providers';
import {
CopilotChatTextExecutor,
CopilotWorkflowService,
@@ -32,7 +25,7 @@ type Tester = {
auth: AuthService;
module: TestingModule;
prompt: PromptService;
provider: CopilotProviderService;
factory: CopilotProviderFactory;
workflow: CopilotWorkflowService;
executors: {
image: CopilotChatImageExecutor;
@@ -67,9 +60,9 @@ const runIfCopilotConfigured = test.macro(
test.serial.before(async t => {
const module = await createTestingModule({
imports: [
ConfigModule.forRoot({
plugins: {
copilot: {
ConfigModule.override({
copilot: {
providers: {
openai: {
apiKey: process.env.COPILOT_OPENAI_API_KEY,
},
@@ -79,6 +72,9 @@ test.serial.before(async t => {
perplexity: {
apiKey: process.env.COPILOT_PERPLEXITY_API_KEY,
},
gemini: {
apiKey: process.env.COPILOT_GOOGLE_API_KEY,
},
},
},
}),
@@ -89,13 +85,13 @@ test.serial.before(async t => {
const auth = module.get(AuthService);
const prompt = module.get(PromptService);
const provider = module.get(CopilotProviderService);
const factory = module.get(CopilotProviderFactory);
const workflow = module.get(CopilotWorkflowService);
t.context.module = module;
t.context.auth = auth;
t.context.prompt = prompt;
t.context.provider = provider;
t.context.factory = factory;
t.context.workflow = workflow;
t.context.executors = {
image: module.get(CopilotChatImageExecutor),
@@ -113,10 +109,6 @@ test.serial.before(async t => {
executors.html.register();
executors.json.register();
registerCopilotProvider(OpenAIProvider);
registerCopilotProvider(FalProvider);
registerCopilotProvider(PerplexityProvider);
for (const name of await prompt.listNames()) {
await prompt.delete(name);
}
@@ -126,12 +118,6 @@ test.serial.before(async t => {
}
});
test.after(async _ => {
unregisterCopilotProvider(OpenAIProvider.type);
unregisterCopilotProvider(FalProvider.type);
unregisterCopilotProvider(PerplexityProvider.type);
});
test.after(async t => {
await t.context.module.close();
});
@@ -523,12 +509,10 @@ for (const { name, promptName, messages, verifier, type } of actions) {
`should be able to run action: ${promptName}${name ? ` - ${name}` : ''}`,
runIfCopilotConfigured,
async t => {
const { provider: providerService, prompt: promptService } = t.context;
const { factory, prompt: promptService } = t.context;
const prompt = (await promptService.get(promptName))!;
t.truthy(prompt, 'should have prompt');
const provider = (await providerService.getProviderByModel(
prompt.model
))!;
const provider = (await factory.getProviderByModel(prompt.model))!;
t.truthy(provider, 'should have provider');
await retry(`action: ${promptName}`, t, async t => {
if (type === 'text' && 'generateText' in provider) {

View File

@@ -6,12 +6,11 @@ import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
import { AppModule } from '../app.module';
import { JobQueue } from '../base';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { DocReader } from '../core/doc';
import { WorkspaceModule } from '../core/workspaces';
import { CopilotModule } from '../plugins/copilot';
import {
CopilotContextDocJob,
CopilotContextService,
@@ -19,14 +18,11 @@ import {
import { MockEmbeddingClient } from '../plugins/copilot/context/embedding';
import { prompts, PromptService } from '../plugins/copilot/prompt';
import {
CopilotProviderService,
FalProvider,
CopilotProviderFactory,
OpenAIProvider,
PerplexityProvider,
registerCopilotProvider,
unregisterCopilotProvider,
} from '../plugins/copilot/providers';
import { CopilotStorage } from '../plugins/copilot/storage';
import { MockCopilotProvider } from './mocks';
import {
acceptInviteById,
createTestingApp,
@@ -53,7 +49,6 @@ import {
listContextDocAndFiles,
matchFiles,
matchWorkspaceDocs,
MockCopilotTestProvider,
sse2array,
textToEventStream,
unsplashSearch,
@@ -67,7 +62,7 @@ const test = ava as TestFn<{
context: CopilotContextService;
jobs: CopilotContextDocJob;
prompt: PromptService;
provider: CopilotProviderService;
factory: CopilotProviderFactory;
storage: CopilotStorage;
u1: TestUser;
}>;
@@ -75,24 +70,19 @@ const test = ava as TestFn<{
test.before(async t => {
const app = await createTestingApp({
imports: [
ConfigModule.forRoot({
plugins: {
copilot: {
openai: {
apiKey: '1',
},
fal: {
apiKey: '1',
},
perplexity: {
apiKey: '1',
},
unsplashKey: process.env.UNSPLASH_ACCESS_KEY || '1',
ConfigModule.override({
copilot: {
providers: {
openai: { apiKey: '1' },
fal: {},
perplexity: {},
},
unsplash: {
key: process.env.UNSPLASH_ACCESS_KEY || '1',
},
},
}),
WorkspaceModule,
CopilotModule,
AppModule,
],
tapModule: m => {
// use real JobQueue for testing
@@ -105,6 +95,7 @@ test.before(async t => {
};
},
});
m.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider);
},
});
@@ -129,14 +120,9 @@ test.beforeEach(async t => {
Sinon.restore();
const { app, prompt } = t.context;
await app.initTestingDB();
await prompt.onModuleInit();
await prompt.onApplicationBootstrap();
t.context.u1 = await app.signupV1('u1@affine.pro');
unregisterCopilotProvider(OpenAIProvider.type);
unregisterCopilotProvider(FalProvider.type);
unregisterCopilotProvider(PerplexityProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set(promptName, 'test', [
{ role: 'system', content: 'hello {{word}}' },
]);
@@ -761,13 +747,12 @@ test('should be able to manage context', async t => {
'should throw error if create context with invalid session id'
);
const context = createCopilotContext(app, workspaceId, sessionId);
await t.notThrowsAsync(context, 'should create context with chat session');
const context = await createCopilotContext(app, workspaceId, sessionId);
const list = await listContext(app, workspaceId, sessionId);
t.deepEqual(
list.map(f => ({ id: f.id })),
[{ id: await context }],
[{ id: context }],
'should list context'
);
}

View File

@@ -19,18 +19,14 @@ import {
import { MockEmbeddingClient } from '../plugins/copilot/context/embedding';
import { prompts, PromptService } from '../plugins/copilot/prompt';
import {
CopilotProviderService,
CopilotCapability,
CopilotProviderFactory,
CopilotProviderType,
OpenAIProvider,
registerCopilotProvider,
unregisterCopilotProvider,
} from '../plugins/copilot/providers';
import { CitationParser } from '../plugins/copilot/providers/perplexity';
import { ChatSessionService } from '../plugins/copilot/session';
import { CopilotStorage } from '../plugins/copilot/storage';
import {
CopilotCapability,
CopilotProviderType,
} from '../plugins/copilot/types';
import {
CopilotChatTextExecutor,
CopilotWorkflowService,
@@ -50,8 +46,9 @@ import {
} from '../plugins/copilot/workflow/executor';
import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils';
import { WorkflowGraphList } from '../plugins/copilot/workflow/graph';
import { MockCopilotProvider } from './mocks';
import { createTestingModule, TestingModule } from './utils';
import { MockCopilotTestProvider, WorkflowTestCases } from './utils/copilot';
import { WorkflowTestCases } from './utils/copilot';
const test = ava as TestFn<{
auth: AuthService;
@@ -60,7 +57,7 @@ const test = ava as TestFn<{
event: EventBus;
context: CopilotContextService;
prompt: PromptService;
provider: CopilotProviderService;
factory: CopilotProviderFactory;
session: ChatSessionService;
jobs: CopilotContextDocJob;
storage: CopilotStorage;
@@ -77,9 +74,9 @@ let userId: string;
test.before(async t => {
const module = await createTestingModule({
imports: [
ConfigModule.forRoot({
plugins: {
copilot: {
ConfigModule.override({
copilot: {
providers: {
openai: {
apiKey: process.env.COPILOT_OPENAI_API_KEY ?? '1',
},
@@ -95,6 +92,9 @@ test.before(async t => {
QuotaModule,
CopilotModule,
],
tapModule: builder => {
builder.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider);
},
});
const auth = module.get(AuthService);
@@ -102,7 +102,7 @@ test.before(async t => {
const event = module.get(EventBus);
const context = module.get(CopilotContextService);
const prompt = module.get(PromptService);
const provider = module.get(CopilotProviderService);
const factory = module.get(CopilotProviderFactory);
const session = module.get(ChatSessionService);
const workflow = module.get(CopilotWorkflowService);
const jobs = module.get(CopilotContextDocJob);
@@ -114,7 +114,7 @@ test.before(async t => {
t.context.event = event;
t.context.context = context;
t.context.prompt = prompt;
t.context.provider = provider;
t.context.factory = factory;
t.context.session = session;
t.context.workflow = workflow;
t.context.jobs = jobs;
@@ -131,7 +131,7 @@ test.beforeEach(async t => {
Sinon.restore();
const { module, auth, prompt } = t.context;
await module.initTestingDB();
await prompt.onModuleInit();
await prompt.onApplicationBootstrap();
const user = await auth.signUp('test@affine.pro', '123456');
userId = user.id;
});
@@ -730,10 +730,10 @@ test('should handle params correctly in chat session', async t => {
// ==================== provider ====================
test('should be able to get provider', async t => {
const { provider } = t.context;
const { factory } = t.context;
{
const p = await provider.getProviderByCapability(
const p = await factory.getProviderByCapability(
CopilotCapability.TextToText
);
t.is(
@@ -744,108 +744,40 @@ test('should be able to get provider', async t => {
}
{
const p = await provider.getProviderByCapability(
CopilotCapability.TextToEmbedding
const p = await factory.getProviderByCapability(
CopilotCapability.ImageToImage,
{ model: 'lora/image-to-image' }
);
t.is(
p?.type.toString(),
'openai',
'fal',
'should get provider support text-to-embedding'
);
}
{
const p = await provider.getProviderByCapability(
CopilotCapability.TextToImage
);
t.is(
p?.type.toString(),
'fal',
'should get provider support text-to-image'
);
}
{
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToImage
);
t.is(
p?.type.toString(),
'fal',
'should get provider support image-to-image'
);
}
{
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToText
);
t.is(
p?.type.toString(),
'fal',
'should get provider support image-to-text'
);
}
// text-to-image use fal by default, but this case can use
// model dall-e-3 to select openai provider
{
const p = await provider.getProviderByCapability(
CopilotCapability.TextToImage,
'dall-e-3'
);
t.is(
p?.type.toString(),
'openai',
'should get provider support text-to-image and model'
);
}
// gpt4o is not defined now, but it already published by openai
// we should check from online api if it is available
{
const p = await provider.getProviderByCapability(
const p = await factory.getProviderByCapability(
CopilotCapability.ImageToText,
'gpt-4o-2024-08-06'
{ prefer: CopilotProviderType.FAL }
);
t.is(
p?.type.toString(),
'openai',
'should get provider support text-to-image and model'
'fal',
'should get provider support text-to-embedding'
);
}
// if a model is not defined and not available in online api
// it should return null
{
const p = await provider.getProviderByCapability(
const p = await factory.getProviderByCapability(
CopilotCapability.ImageToText,
'gpt-4-not-exist'
{ model: 'gpt-4-not-exist' }
);
t.falsy(p, 'should not get provider');
}
});
test('should be able to register test provider', async t => {
const { provider } = t.context;
registerCopilotProvider(MockCopilotTestProvider);
const assertProvider = async (cap: CopilotCapability) => {
const p = await provider.getProviderByCapability(cap, 'test');
t.is(
p?.type,
CopilotProviderType.Test,
`should get test provider with ${cap}`
);
};
await assertProvider(CopilotCapability.TextToText);
await assertProvider(CopilotCapability.TextToEmbedding);
await assertProvider(CopilotCapability.TextToImage);
await assertProvider(CopilotCapability.ImageToImage);
await assertProvider(CopilotCapability.ImageToText);
});
// ==================== workflow ====================
// this test used to preview the final result of the workflow
@@ -854,7 +786,6 @@ test.skip('should be able to preview workflow', async t => {
const { prompt, workflow, executors } = t.context;
executors.text.register();
registerCopilotProvider(OpenAIProvider);
for (const p of prompts) {
await prompt.set(p.name, p.model, p.messages, p.config);
@@ -878,8 +809,6 @@ test.skip('should be able to preview workflow', async t => {
}
console.log('final stream result:', result);
t.truthy(result, 'should return result');
unregisterCopilotProvider(OpenAIProvider.type);
});
const runWorkflow = async function* runWorkflow(
@@ -900,8 +829,6 @@ test('should be able to run pre defined workflow', async t => {
executors.text.register();
executors.html.register();
executors.json.register();
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
const executor = Sinon.spy(executors.text, 'next');
@@ -941,17 +868,12 @@ test('should be able to run pre defined workflow', async t => {
}
}
}
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});
test('should be able to run workflow', async t => {
const { workflow, executors } = t.context;
executors.text.register();
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
const executor = Sinon.spy(executors.text, 'next');
@@ -998,9 +920,6 @@ test('should be able to run workflow', async t => {
'graph params should correct'
);
}
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});
// ==================== workflow executor ====================
@@ -1037,18 +956,16 @@ test('should be able to run executor', async t => {
});
test('should be able to run text executor', async t => {
const { executors, provider, prompt } = t.context;
const { executors, factory, prompt } = t.context;
executors.text.register();
const executor = getWorkflowExecutor(executors.text.type);
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set('test', 'test', [
{ role: 'system', content: 'hello {{word}}' },
]);
// mock provider
const testProvider =
(await provider.getProviderByModel<CopilotCapability.TextToText>('test'))!;
(await factory.getProviderByModel<CopilotCapability.TextToText>('test'))!;
const text = Sinon.spy(testProvider, 'generateText');
const textStream = Sinon.spy(testProvider, 'generateTextStream');
@@ -1103,23 +1020,19 @@ test('should be able to run text executor', async t => {
}
Sinon.restore();
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});
test('should be able to run image executor', async t => {
const { executors, provider, prompt } = t.context;
const { executors, factory, prompt } = t.context;
executors.image.register();
const executor = getWorkflowExecutor(executors.image.type);
unregisterCopilotProvider(OpenAIProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set('test', 'test', [
{ role: 'user', content: 'tag1, tag2, tag3, {{#tags}}{{.}}, {{/tags}}' },
]);
// mock provider
const testProvider =
(await provider.getProviderByModel<CopilotCapability.TextToImage>('test'))!;
(await factory.getProviderByModel<CopilotCapability.TextToImage>('test'))!;
const image = Sinon.spy(testProvider, 'generateImages');
const imageStream = Sinon.spy(testProvider, 'generateImagesStream');
@@ -1184,8 +1097,6 @@ test('should be able to run image executor', async t => {
}
Sinon.restore();
unregisterCopilotProvider(MockCopilotTestProvider.type);
registerCopilotProvider(OpenAIProvider);
});
test('CitationParser should replace citation placeholders with URLs', t => {

View File

@@ -4,9 +4,11 @@ import {
TestingModule as NestjsTestingModule,
TestingModuleBuilder,
} from '@nestjs/testing';
import { PrismaClient } from '@prisma/client';
import { FunctionalityModules } from '../app.module';
import { AFFiNELogger } from '../base';
import { AFFiNELogger, EventBus, JobQueue } from '../base';
import { createFactory, MockEventBus, MockJobQueue } from './mocks';
import { TEST_LOG_LEVEL } from './utils';
interface TestingModuleMetadata extends ModuleMetadata {
@@ -15,10 +17,13 @@ interface TestingModuleMetadata extends ModuleMetadata {
export interface TestingModule extends NestjsTestingModule {
[Symbol.asyncDispose](): Promise<void>;
create: ReturnType<typeof createFactory>;
queue: MockJobQueue;
event: MockEventBus;
}
export async function createModule(
metadata: TestingModuleMetadata
metadata: TestingModuleMetadata = {}
): Promise<TestingModule> {
const { tapModule, ...meta } = metadata;
@@ -27,6 +32,12 @@ export async function createModule(
imports: [...FunctionalityModules, ...(meta.imports ?? [])],
});
builder
.overrideProvider(JobQueue)
.useValue(new MockJobQueue())
.overrideProvider(EventBus)
.useValue(new MockEventBus());
// when custom override happens
if (tapModule) {
tapModule(builder);
@@ -44,6 +55,9 @@ export async function createModule(
module[Symbol.asyncDispose] = async () => {
await module.close();
};
module.create = createFactory(module.get(PrismaClient));
module.queue = module.get(JobQueue);
module.event = module.get(EventBus);
return module;
}

View File

@@ -7,7 +7,6 @@ import type { TestFn } from 'ava';
import ava from 'ava';
import request from 'supertest';
import { DocRendererModule } from '../../core/doc-renderer';
import { createTestingApp } from '../utils';
const test = ava as TestFn<{
@@ -45,13 +44,11 @@ function initTestStaticFiles(staticPath: string) {
}
}
test.before('init selfhost server', async t => {
test.before(async t => {
const staticPath = new Package('@affine/server').join('static').value;
initTestStaticFiles(staticPath);
const app = await createTestingApp({
imports: [DocRendererModule],
});
const app = await createTestingApp();
t.context.app = app;
});

View File

@@ -1,7 +1,7 @@
import { getCurrentUserQuery } from '@affine/graphql';
import { Mockers } from '../mocks';
import { app, e2e } from './test';
import { Mockers } from '../../mocks';
import { app, e2e } from '../test';
e2e('should create test app correctly', async t => {
t.truthy(app);
@@ -18,12 +18,7 @@ e2e('should mock queue work', async t => {
e2e('should handle http request', async t => {
const res = await app.GET('/info');
t.is(res.status, 200);
t.is(res.body.compatibility, AFFiNE.version);
});
e2e('should handle gql request', async t => {
const user = await app.gql({ query: getCurrentUserQuery });
t.is(user.currentUser, null);
t.is(res.body.compatibility, env.version);
});
e2e('should create workspace with owner', async t => {

View File

@@ -0,0 +1,46 @@
import { getCurrentUserQuery } from '@affine/graphql';
import { createApp } from '../create-app';
import { e2e } from '../test';
e2e('should init doc service', async t => {
// @ts-expect-error override
globalThis.env.FLAVOR = 'doc';
await using app = await createApp();
const res = await app.GET('/info').expect(200);
t.is(res.body.flavor, 'doc');
await t.throwsAsync(app.gql({ query: getCurrentUserQuery }));
});
e2e('should init graphql service', async t => {
// @ts-expect-error override
globalThis.env.FLAVOR = 'graphql';
await using app = await createApp();
const res = await app.GET('/info').expect(200);
t.is(res.body.flavor, 'graphql');
const user = await app.gql({ query: getCurrentUserQuery });
t.is(user.currentUser, null);
});
e2e('should init sync service', async t => {
// @ts-expect-error override
globalThis.env.FLAVOR = 'sync';
await using app = await createApp();
const res = await app.GET('/info').expect(200);
t.is(res.body.flavor, 'sync');
});
e2e('should init renderer service', async t => {
// @ts-expect-error override
globalThis.env.FLAVOR = 'renderer';
await using app = await createApp();
const res = await app.GET('/info').expect(200);
t.is(res.body.flavor, 'renderer');
});

View File

@@ -13,6 +13,7 @@ import {
AFFiNELogger,
CacheInterceptor,
CloudThrottlerGuard,
EventBus,
GlobalExceptionFilter,
JobQueue,
OneMB,
@@ -23,6 +24,7 @@ import { Mailer } from '../../core/mail';
import {
createFactory,
MockedUser,
MockEventBus,
MockJobQueue,
MockMailer,
MockUser,
@@ -181,23 +183,19 @@ export class TestingApp extends NestApplication {
}
}
let GLOBAL_APP_INSTANCE: TestingApp | null = null;
export async function createApp(
metadata: TestingAppMetadata = {}
): Promise<TestingApp> {
if (GLOBAL_APP_INSTANCE) {
return GLOBAL_APP_INSTANCE;
}
const { buildAppModule } = await import('../../app.module');
const { tapModule, tapApp } = metadata;
const builder = Test.createTestingModule({
imports: [buildAppModule()],
imports: [buildAppModule(globalThis.env)],
});
builder.overrideProvider(Mailer).useValue(new MockMailer());
builder.overrideProvider(JobQueue).useValue(new MockJobQueue());
builder.overrideProvider(EventBus).useValue(new MockEventBus());
// when custom override happens
if (tapModule) {
@@ -240,6 +238,5 @@ export async function createApp(
await app.init();
GLOBAL_APP_INSTANCE = app;
return app;
}

View File

@@ -0,0 +1,156 @@
import test from 'ava';
import { Env } from '../env';
const envs = { ...process.env };
test.beforeEach(() => {
process.env = { ...envs };
});
test('should init env', t => {
t.true(globalThis.env.testing);
});
test('should read NODE_ENV', t => {
process.env.NODE_ENV = 'test';
t.deepEqual(
['test', 'development', 'production'].map(envVal => {
process.env.NODE_ENV = envVal;
const env = new Env();
return env.NODE_ENV;
}),
['test', 'development', 'production']
);
t.throws(
() => {
process.env.NODE_ENV = 'unknown';
new Env();
},
{
message:
'Invalid value "unknown" for environment variable NODE_ENV, expected one of ["development","test","production"]',
}
);
});
test('should read NAMESPACE', t => {
t.deepEqual(
['dev', 'beta', 'production'].map(envVal => {
process.env.AFFINE_ENV = envVal;
const env = new Env();
return env.NAMESPACE;
}),
['dev', 'beta', 'production']
);
t.throws(() => {
process.env.AFFINE_ENV = 'unknown';
new Env();
});
});
test('should read DEPLOYMENT_TYPE', t => {
t.deepEqual(
['affine', 'selfhosted'].map(envVal => {
process.env.DEPLOYMENT_TYPE = envVal;
const env = new Env();
return env.DEPLOYMENT_TYPE;
}),
['affine', 'selfhosted']
);
t.throws(() => {
process.env.DEPLOYMENT_TYPE = 'unknown';
new Env();
});
});
test('should read FLAVOR', t => {
t.deepEqual(
['allinone', 'graphql', 'sync', 'renderer', 'doc', 'script'].map(envVal => {
process.env.SERVER_FLAVOR = envVal;
const env = new Env();
return env.FLAVOR;
}),
['allinone', 'graphql', 'sync', 'renderer', 'doc', 'script']
);
t.throws(
() => {
process.env.SERVER_FLAVOR = 'unknown';
new Env();
},
{
message:
'Invalid value "unknown" for environment variable SERVER_FLAVOR, expected one of ["allinone","graphql","sync","renderer","doc","script"]',
}
);
});
test('should read platform', t => {
t.deepEqual(
['gcp', 'unknown'].map(envVal => {
process.env.DEPLOYMENT_PLATFORM = envVal;
const env = new Env();
return env.platform;
}),
['gcp', 'unknown']
);
t.notThrows(() => {
process.env.PLATFORM = 'unknown';
new Env();
});
});
test('should tell flavors correctly', t => {
process.env.SERVER_FLAVOR = 'allinone';
t.deepEqual(new Env().flavors, {
graphql: true,
sync: true,
renderer: true,
doc: true,
script: true,
});
process.env.SERVER_FLAVOR = 'graphql';
t.deepEqual(new Env().flavors, {
graphql: true,
sync: false,
renderer: false,
doc: false,
script: false,
});
});
test('should tell selfhosted correctly', t => {
process.env.DEPLOYMENT_TYPE = 'selfhosted';
t.true(new Env().selfhosted);
process.env.DEPLOYMENT_TYPE = 'affine';
t.false(new Env().selfhosted);
});
test('should tell namespaces correctly', t => {
process.env.AFFINE_ENV = 'dev';
t.deepEqual(new Env().namespaces, {
canary: true,
beta: false,
production: false,
});
process.env.AFFINE_ENV = 'beta';
t.deepEqual(new Env().namespaces, {
canary: false,
beta: true,
production: false,
});
process.env.AFFINE_ENV = 'production';
t.deepEqual(new Env().namespaces, {
canary: false,
beta: false,
production: true,
});
});

View File

@@ -0,0 +1,113 @@
import { randomBytes } from 'node:crypto';
import {
CopilotCapability,
CopilotChatOptions,
CopilotEmbeddingOptions,
PromptMessage,
} from '../../plugins/copilot/providers';
import {
DEFAULT_DIMENSIONS,
OpenAIProvider,
} from '../../plugins/copilot/providers/openai';
import { sleep } from '../utils/utils';
export class MockCopilotProvider extends OpenAIProvider {
override readonly models = [
'test',
'gpt-4o',
'gpt-4o-2024-08-06',
'fast-sdxl/image-to-image',
'lcm-sd15-i2i',
'clarity-upscaler',
'imageutils/rembg',
];
override readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
CopilotCapability.TextToImage,
CopilotCapability.ImageToImage,
CopilotCapability.ImageToText,
];
// ====== text to text ======
override async generateText(
messages: PromptMessage[],
model: string = 'test',
options: CopilotChatOptions = {}
): Promise<string> {
this.checkParams({ messages, model, options });
// make some time gap for history test case
await sleep(100);
return 'generate text to text';
}
override async *generateTextStream(
messages: PromptMessage[],
model: string = 'gpt-4o-mini',
options: CopilotChatOptions = {}
): AsyncIterable<string> {
this.checkParams({ messages, model, options });
// make some time gap for history test case
await sleep(100);
const result = 'generate text to text stream';
for (const message of result) {
yield message;
if (options.signal?.aborted) {
break;
}
}
}
// ====== text to embedding ======
override async generateEmbedding(
messages: string | string[],
model: string,
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model, options });
// make some time gap for history test case
await sleep(100);
return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
}
// ====== text to image ======
override async generateImages(
messages: PromptMessage[],
model: string = 'test',
_options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content: prompt } = messages[0] || {};
if (!prompt) {
throw new Error('Prompt is required');
}
// make some time gap for history test case
await sleep(100);
// just let test case can easily verify the final prompt
return [`https://example.com/${model}.jpg`, prompt];
}
override async *generateImagesStream(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}

View File

@@ -0,0 +1,35 @@
import Sinon from 'sinon';
import { EventBus } from '../../base';
import { EventName } from '../../base/event/def';
export class MockEventBus {
private readonly stub = Sinon.createStubInstance(EventBus);
emit = this.stub.emitAsync;
emitAsync = this.stub.emitAsync;
broadcast = this.stub.broadcast;
last<Event extends EventName>(
name: Event
): { name: Event; payload: Events[Event] } {
const call = this.emitAsync
.getCalls()
.reverse()
.find(call => call.args[0] === name);
if (!call) {
throw new Error(`Event ${name} never called`);
}
// @ts-expect-error allow
return {
name,
payload: call.args[1],
};
}
count(name: EventName) {
return this.emitAsync.getCalls().filter(call => call.args[0] === name)
.length;
}
}

View File

@@ -4,7 +4,9 @@ export * from './user.mock';
export * from './workspace.mock';
export * from './workspace-user.mock';
import { MockCopilotProvider } from './copilot.mock';
import { MockDocMeta } from './doc-meta.mock';
import { MockEventBus } from './eventbus.mock';
import { MockMailer } from './mailer.mock';
import { MockJobQueue } from './queue.mock';
import { MockTeamWorkspace } from './team-workspace.mock';
@@ -22,4 +24,4 @@ export const Mockers = {
DocMeta: MockDocMeta,
};
export { MockJobQueue, MockMailer };
export { MockCopilotProvider, MockEventBus, MockJobQueue, MockMailer };

View File

@@ -1,7 +1,6 @@
import { User } from '@prisma/client';
import ava, { TestFn } from 'ava';
import { ConfigModule } from '../../base/config';
import { FeatureType, Models, UserFeatureModel, UserModel } from '../../models';
import { createTestingModule, TestingModule } from '../utils';
@@ -126,13 +125,9 @@ test('should not switch user quota if the new quota is the same as the current o
});
test('should use pro plan as free for selfhost instance', async t => {
await using module = await createTestingModule({
imports: [
ConfigModule.forRoot({
isSelfhosted: true,
}),
],
});
// @ts-expect-error
env.DEPLOYMENT_TYPE = 'selfhosted';
await using module = await createTestingModule();
const models = module.get(Models);
const u1 = await models.user.create({

View File

@@ -1,5 +1,3 @@
import '../../plugins/config';
import { Controller, Get, HttpStatus, UseGuards } from '@nestjs/common';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
@@ -89,11 +87,13 @@ class NonThrottledController {
test.before(async t => {
const app = await createTestingApp({
imports: [
ConfigModule.forRoot({
throttler: {
default: {
ttl: 60,
limit: 120,
ConfigModule.override({
throttle: {
throttlers: {
default: {
ttl: 60,
limit: 120,
},
},
},
}),

View File

@@ -1,5 +1,3 @@
import '../../plugins/config';
import { randomUUID } from 'node:crypto';
import { HttpStatus } from '@nestjs/common';
@@ -30,14 +28,12 @@ const test = ava as TestFn<{
test.before(async t => {
const app = await createTestingApp({
imports: [
ConfigModule.forRoot({
plugins: {
oauth: {
providers: {
google: {
clientId: 'google-client-id',
clientSecret: 'google-client-secret',
},
ConfigModule.override({
oauth: {
providers: {
google: {
clientId: 'google-client-id',
clientSecret: 'google-client-secret',
},
},
},

View File

@@ -6,12 +6,13 @@ import Sinon from 'sinon';
import Stripe from 'stripe';
import { AppModule } from '../../app.module';
import { EventBus, Runtime } from '../../base';
import { ConfigModule } from '../../base/config';
import { EventBus } from '../../base';
import { ConfigFactory, ConfigModule } from '../../base/config';
import { CurrentUser } from '../../core/auth';
import { AuthService } from '../../core/auth/service';
import { EarlyAccessType, FeatureService } from '../../core/features';
import { SubscriptionService } from '../../plugins/payment/service';
import { StripeFactory } from '../../plugins/payment/stripe';
import {
CouponType,
encodeLookupKey,
@@ -159,7 +160,6 @@ const test = ava as TestFn<{
service: SubscriptionService;
event: Sinon.SinonStubbedInstance<EventBus>;
feature: Sinon.SinonStubbedInstance<FeatureService>;
runtime: Sinon.SinonStubbedInstance<Runtime>;
stripe: {
customers: Sinon.SinonStubbedInstance<Stripe.CustomersResource>;
prices: Sinon.SinonStubbedInstance<Stripe.PricesResource>;
@@ -184,16 +184,12 @@ function getLastCheckoutPrice(checkoutStub: Sinon.SinonStub) {
test.before(async t => {
const app = await createTestingApp({
imports: [
ConfigModule.forRoot({
plugins: {
payment: {
stripe: {
keys: {
APIKey: '1',
webhookKey: '1',
},
},
},
ConfigModule.override({
payment: {
enabled: true,
showLifetimePrice: true,
apiKey: '1',
webhookKey: '1',
},
}),
AppModule,
@@ -203,18 +199,19 @@ test.before(async t => {
Sinon.createStubInstance(FeatureService)
);
m.overrideProvider(EventBus).useValue(Sinon.createStubInstance(EventBus));
m.overrideProvider(Runtime).useValue(Sinon.createStubInstance(Runtime));
},
});
t.context.event = app.get(EventBus);
t.context.service = app.get(SubscriptionService);
t.context.feature = app.get(FeatureService);
t.context.runtime = app.get(Runtime);
t.context.db = app.get(PrismaClient);
t.context.app = app;
const stripe = app.get(Stripe);
const stripeFactory = app.get(StripeFactory);
await stripeFactory.onConfigInit();
const stripe = stripeFactory.stripe;
const stripeStubs = {
customers: Sinon.stub(stripe.customers),
prices: Sinon.stub(stripe.prices),
@@ -234,6 +231,12 @@ test.beforeEach(async t => {
await t.context.app.initTestingDB();
t.context.u1 = await app.get(AuthService).signUp('u1@affine.pro', '1');
app.get(ConfigFactory).override({
payment: {
showLifetimePrice: true,
},
});
await db.workspace.create({
data: {
id: 'ws_1',
@@ -249,11 +252,6 @@ test.beforeEach(async t => {
Sinon.reset();
// default stubs
t.context.runtime.fetch
.withArgs('plugins.payment/showLifetimePrice')
.resolves(true);
// @ts-expect-error stub
stripe.prices.list.callsFake((params: Stripe.PriceListParams) => {
if (params.lookup_keys) {
@@ -294,8 +292,13 @@ test('should list normal prices for authenticated user', async t => {
});
test('should not show lifetime price if not enabled', async t => {
const { service, runtime } = t.context;
runtime.fetch.withArgs('plugins.payment/showLifetimePrice').resolves(false);
const { service, app } = t.context;
app.get(ConfigFactory).override({
payment: {
showLifetimePrice: false,
},
});
const prices = await service.listPrices(t.context.u1);
@@ -539,8 +542,11 @@ test('should get correct pro plan price for checking out', async t => {
// any user, lifetime recurring
{
feature.isEarlyAccessUser.resolves(false);
const runtime = app.get(Runtime);
await runtime.set('plugins.payment/showLifetimePrice', true);
app.get(ConfigFactory).override({
payment: {
showLifetimePrice: true,
},
});
await service.checkout(
{
@@ -1181,8 +1187,12 @@ const onetimeYearlyInvoice: Stripe.Invoice = {
};
test('should not be able to checkout for lifetime recurring if not enabled', async t => {
const { service, u1, runtime } = t.context;
runtime.fetch.withArgs('plugins.payment/showLifetimePrice').resolves(false);
const { service, u1, app } = t.context;
app.get(ConfigFactory).override({
payment: {
showLifetimePrice: false,
},
});
await t.throwsAsync(
() =>
@@ -1202,7 +1212,13 @@ test('should not be able to checkout for lifetime recurring if not enabled', asy
});
test('should be able to checkout for lifetime recurring', async t => {
const { service, u1, stripe } = t.context;
const { service, u1, stripe, app } = t.context;
app.get(ConfigFactory).override({
payment: {
showLifetimePrice: true,
},
});
await service.checkout(
{

View File

@@ -1 +0,0 @@
export const gql = '/graphql';

View File

@@ -1,160 +1,11 @@
import { randomBytes } from 'node:crypto';
import {
DEFAULT_DIMENSIONS,
OpenAIProvider,
} from '../../plugins/copilot/providers/openai';
import {
CopilotCapability,
CopilotChatOptions,
CopilotEmbeddingOptions,
CopilotImageToImageProvider,
CopilotImageToTextProvider,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotTextToTextProvider,
PromptConfig,
PromptMessage,
} from '../../plugins/copilot/types';
import { PromptConfig, PromptMessage } from '../../plugins/copilot/providers';
import { NodeExecutorType } from '../../plugins/copilot/workflow/executor';
import {
WorkflowGraph,
WorkflowNodeType,
WorkflowParams,
} from '../../plugins/copilot/workflow/types';
import { gql } from './common';
import { TestingApp } from './testing-app';
import { sleep } from './utils';
// @ts-expect-error no error
export class MockCopilotTestProvider
extends OpenAIProvider
implements
CopilotTextToTextProvider,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotImageToImageProvider,
CopilotImageToTextProvider
{
static override readonly type = CopilotProviderType.Test;
override readonly availableModels = [
'test',
'gpt-4o',
'gpt-4o-2024-08-06',
'fast-sdxl/image-to-image',
'lcm-sd15-i2i',
'clarity-upscaler',
'imageutils/rembg',
];
static override readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
CopilotCapability.TextToImage,
CopilotCapability.ImageToImage,
CopilotCapability.ImageToText,
];
constructor() {
super({ apiKey: '1' });
}
override getCapabilities(): CopilotCapability[] {
return MockCopilotTestProvider.capabilities;
}
static override assetsConfig(_config: any) {
return true;
}
override get type(): CopilotProviderType {
return CopilotProviderType.Test;
}
override async isModelAvailable(model: string): Promise<boolean> {
return this.availableModels.includes(model);
}
// ====== text to text ======
override async generateText(
messages: PromptMessage[],
model: string = 'test',
options: CopilotChatOptions = {}
): Promise<string> {
this.checkParams({ messages, model, options });
// make some time gap for history test case
await sleep(100);
return 'generate text to text';
}
override async *generateTextStream(
messages: PromptMessage[],
model: string = 'gpt-4o-mini',
options: CopilotChatOptions = {}
): AsyncIterable<string> {
this.checkParams({ messages, model, options });
// make some time gap for history test case
await sleep(100);
const result = 'generate text to text stream';
for (const message of result) {
yield message;
if (options.signal?.aborted) {
break;
}
}
}
// ====== text to embedding ======
override async generateEmbedding(
messages: string | string[],
model: string,
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model, options });
// make some time gap for history test case
await sleep(100);
return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
}
// ====== text to image ======
override async generateImages(
messages: PromptMessage[],
model: string = 'test',
_options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content: prompt } = messages[0] || {};
if (!prompt) {
throw new Error('Prompt is required');
}
// make some time gap for history test case
await sleep(100);
// just let test case can easily verify the final prompt
return [`https://example.com/${model}.jpg`, prompt];
}
override async *generateImagesStream(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}
export const cleanObject = (
obj: any[] | undefined,
@@ -342,7 +193,7 @@ export async function addContextFile(
content: Buffer
): Promise<{ id: string }> {
const res = await app
.POST(gql)
.POST('/graphql')
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
.field(
'operations',

View File

@@ -41,20 +41,17 @@ export async function createTestingApp(
moduleDef: TestingAppMetadata = {}
): Promise<TestingApp> {
const module = await createTestingModule(moduleDef, false);
const logger = new AFFiNELogger();
logger.setLogLevels([TEST_LOG_LEVEL]);
const app = module.createNestApplication<NestExpressApplication>({
cors: true,
bodyParser: true,
rawBody: true,
logger,
});
app.useBodyParser('raw', { limit: 1 * OneMB });
const logger = new AFFiNELogger();
logger.setLogLevels([TEST_LOG_LEVEL]);
app.useLogger(logger);
app.useGlobalFilters(new GlobalExceptionFilter(app.getHttpAdapter()));
app.use(
graphqlUploadExpress({

View File

@@ -8,9 +8,10 @@ import {
} from '@nestjs/testing';
import { PrismaClient } from '@prisma/client';
import { AppModule, FunctionalityModules } from '../../app.module';
import { AFFiNELogger, JobQueue, Runtime } from '../../base';
import { buildAppModule, FunctionalityModules } from '../../app.module';
import { AFFiNELogger, JobQueue } from '../../base';
import { GqlModule } from '../../base/graphql';
import { ServerConfigModule } from '../../core';
import { AuthGuard, AuthModule } from '../../core/auth';
import { Mailer, MailModule } from '../../core/mail';
import { ModelsModule } from '../../models';
@@ -63,16 +64,18 @@ export async function createTestingModule(
autoInitialize = true
): Promise<TestingModule> {
// setting up
let imports = moduleDef.imports ?? [AppModule];
let imports = moduleDef.imports ?? [buildAppModule(globalThis.env)];
imports =
imports[0] === AppModule
? [AppModule]
// @ts-expect-error
imports[0].module?.name === 'AppModule'
? imports
: dedupeModules([
...FunctionalityModules,
ModelsModule,
AuthModule,
GqlModule,
MailModule,
ServerConfigModule,
...imports,
]);
@@ -101,10 +104,6 @@ export async function createTestingModule(
testingModule.initTestingDB = async () => {
await initTestingDB(module);
const runtime = module.get(Runtime);
// by pass password min length validation
await runtime.set('auth/password.min', 1);
};
testingModule.create = createFactory(

View File

@@ -1,6 +1,7 @@
import { INestApplicationContext, LogLevel } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { PrismaClient } from '@prisma/client';
import whywhywhy from 'why-is-node-running';
import { RefreshFeatures0001 } from '../../data/migrations/0001-refresh-features';
@@ -32,3 +33,21 @@ export async function initTestingDB(context: INestApplicationContext) {
export async function sleep(ms: number) {
return new Promise(resolve => setTimeout(resolve, ms));
}
export function debugProcessHolding(ignorePrismaStack = true) {
setImmediate(() => {
whywhywhy({
error: message => {
// ignore prisma error
if (
ignorePrismaStack &&
(message.includes('Prisma') || message.includes('prisma'))
) {
return;
}
console.error(message);
},
});
});
}

View File

@@ -3,7 +3,7 @@ import test from 'ava';
import Sinon from 'sinon';
import { AppModule } from '../app.module';
import { Runtime, UseNamedGuard } from '../base';
import { ConfigFactory, UseNamedGuard } from '../base';
import { Public } from '../core/auth/guard';
import { VersionService } from '../core/version/service';
import { createTestingApp, TestingApp } from './utils';
@@ -19,28 +19,28 @@ class GuardedController {
}
let app: TestingApp;
let runtime: Sinon.SinonStubbedInstance<Runtime>;
let config: ConfigFactory;
let version: VersionService;
function checkVersion(enabled = true) {
runtime.fetch.withArgs('client/versionControl.enabled').resolves(enabled);
runtime.fetch
.withArgs('client/versionControl.requiredVersion')
.resolves('>=0.20.0');
config.override({
client: {
versionControl: {
enabled,
requiredVersion: '>=0.20.0',
},
},
});
}
test.before(async () => {
app = await createTestingApp({
imports: [AppModule],
controllers: [GuardedController],
tapModule: m => {
m.overrideProvider(Runtime).useValue(Sinon.createStubInstance(Runtime));
},
});
runtime = app.get(Runtime);
version = app.get(VersionService, { strict: false });
config = app.get(ConfigFactory, { strict: false });
});
test.beforeEach(async () => {
@@ -74,9 +74,13 @@ test('should passthrough if version check is not enabled', async t => {
});
test('should passthrough is version range is invalid', async t => {
runtime.fetch
.withArgs('client/versionControl.requiredVersion')
.resolves('invalid');
config.override({
client: {
versionControl: {
requiredVersion: 'invalid',
},
},
});
let res = await app.GET('/guarded/test').set('x-affine-version', 'invalid');
@@ -92,9 +96,13 @@ test('should pass if client version is allowed', async t => {
t.is(res.status, 200);
runtime.fetch
.withArgs('client/versionControl.requiredVersion')
.resolves('>=0.19.0');
config.override({
client: {
versionControl: {
requiredVersion: '>=0.19.0',
},
},
});
res = await app.GET('/guarded/test').set('x-affine-version', '0.19.0');
@@ -120,9 +128,13 @@ test('should fail if client version is not set or invalid', async t => {
});
test('should tell upgrade if client version is lower than allowed', async t => {
runtime.fetch
.withArgs('client/versionControl.requiredVersion')
.resolves('>=0.21.0 <=0.22.0');
config.override({
client: {
versionControl: {
requiredVersion: '>=0.21.0 <=0.22.0',
},
},
});
let res = await app.GET('/guarded/test').set('x-affine-version', '0.20.0');
@@ -134,9 +146,13 @@ test('should tell upgrade if client version is lower than allowed', async t => {
});
test('should tell downgrade if client version is higher than allowed', async t => {
runtime.fetch
.withArgs('client/versionControl.requiredVersion')
.resolves('>=0.20.0 <=0.22.0');
config.override({
client: {
versionControl: {
requiredVersion: '>=0.20.0 <=0.22.0',
},
},
});
let res = await app.GET('/guarded/test').set('x-affine-version', '0.23.0');
@@ -148,9 +164,13 @@ test('should tell downgrade if client version is higher than allowed', async t =
});
test('should test prerelease version', async t => {
runtime.fetch
.withArgs('client/versionControl.requiredVersion')
.resolves('>=0.19.0');
config.override({
client: {
versionControl: {
requiredVersion: '>=0.19.0',
},
},
});
let res = await app
.GET('/guarded/test')

View File

@@ -3,7 +3,6 @@ import ava from 'ava';
import Sinon from 'sinon';
import type { Response } from 'supertest';
import { WorkerModule } from '../plugins/worker';
import { createTestingApp, TestingApp } from './utils';
type TestContext = {
@@ -13,9 +12,9 @@ type TestContext = {
const test = ava as TestFn<TestContext>;
test.before(async t => {
const app = await createTestingApp({
imports: [WorkerModule],
});
// @ts-expect-error test
env.DEPLOYMENT_TYPE = 'selfhosted';
const app = await createTestingApp();
t.context.app = app;
});

View File

@@ -1,21 +1,19 @@
import { Controller, Get } from '@nestjs/common';
import { Config, SkipThrottle } from './base';
import { SkipThrottle } from './base';
import { Public } from './core/auth';
@Controller('/info')
export class AppController {
constructor(private readonly config: Config) {}
@SkipThrottle()
@Public()
@Get()
info() {
return {
compatibility: this.config.version,
message: `AFFiNE ${this.config.version} Server`,
type: this.config.type,
flavor: this.config.flavor.type,
compatibility: env.version,
message: `AFFiNE ${env.version} Server`,
type: env.DEPLOYMENT_TYPE,
flavor: env.FLAVOR,
};
}
}

View File

@@ -1,27 +1,19 @@
import {
DynamicModule,
ExecutionContext,
ForwardReference,
Logger,
Module,
} from '@nestjs/common';
import { DynamicModule, ExecutionContext } from '@nestjs/common';
import { ScheduleModule } from '@nestjs/schedule';
import { ClsPluginTransactional } from '@nestjs-cls/transactional';
import { TransactionalAdapterPrisma } from '@nestjs-cls/transactional-adapter-prisma';
import { PrismaClient } from '@prisma/client';
import { Request, Response } from 'express';
import { get } from 'lodash-es';
import { ClsModule } from 'nestjs-cls';
import { AppController } from './app.controller';
import {
getOptionalModuleMetadata,
getRequestIdFromHost,
getRequestIdFromRequest,
ScannerModule,
} from './base';
import { CacheModule } from './base/cache';
import { AFFiNEConfig, ConfigModule, mergeConfigOverride } from './base/config';
import { ConfigModule } from './base/config';
import { ErrorModule } from './base/error';
import { EventModule } from './base/event';
import { GqlModule } from './base/graphql';
@@ -32,12 +24,11 @@ import { MetricsModule } from './base/metrics';
import { MutexModule } from './base/mutex';
import { PrismaModule } from './base/prisma';
import { RedisModule } from './base/redis';
import { RuntimeModule } from './base/runtime';
import { StorageProviderModule } from './base/storage';
import { RateLimiterModule } from './base/throttler';
import { WebSocketModule } from './base/websocket';
import { AuthModule } from './core/auth';
import { ADD_ENABLED_FEATURES, ServerConfigModule } from './core/config';
import { ServerConfigModule, ServerConfigResolverModule } from './core/config';
import { DocStorageModule } from './core/doc';
import { DocRendererModule } from './core/doc-renderer';
import { DocServiceModule } from './core/doc-service';
@@ -52,10 +43,16 @@ import { SyncModule } from './core/sync';
import { UserModule } from './core/user';
import { VersionModule } from './core/version';
import { WorkspaceModule } from './core/workspaces';
import { Env } from './env';
import { ModelsModule } from './models';
import { REGISTERED_PLUGINS } from './plugins';
import { CaptchaModule } from './plugins/captcha';
import { CopilotModule } from './plugins/copilot';
import { CustomerIoModule } from './plugins/customerio';
import { GCloudModule } from './plugins/gcloud';
import { LicenseModule } from './plugins/license';
import { ENABLED_PLUGINS } from './plugins/registry';
import { OAuthModule } from './plugins/oauth';
import { PaymentModule } from './plugins/payment';
import { WorkerModule } from './plugins/worker';
export const FunctionalityModules = [
ClsModule.forRoot({
@@ -91,126 +88,64 @@ export const FunctionalityModules = [
}),
],
}),
ConfigModule.forRoot(),
RuntimeModule,
LoggerModule,
ScannerModule,
PrismaModule,
EventModule,
ConfigModule,
RedisModule,
CacheModule,
MutexModule,
PrismaModule,
MetricsModule,
RateLimiterModule,
StorageProviderModule,
HelpersModule,
ErrorModule,
LoggerModule,
WebSocketModule,
JobModule.forRoot(),
ModelsModule,
];
function filterOptionalModule(
config: AFFiNEConfig,
module: AFFiNEModule | Promise<DynamicModule> | ForwardReference<any>
) {
// can't deal with promise or forward reference
if (module instanceof Promise || 'forwardRef' in module) {
return module;
}
const requirements = getOptionalModuleMetadata(module, 'requires');
// if condition not set or condition met, include the module
if (requirements?.length) {
const nonMetRequirements = requirements.filter(c => {
const value = get(config, c);
return (
value === undefined ||
value === null ||
(typeof value === 'string' && value.trim().length === 0)
);
});
if (nonMetRequirements.length) {
const name = 'module' in module ? module.module.name : module.name;
if (!config.node.test) {
new Logger(name).warn(
`${name} is not enabled because of the required configuration is not satisfied.`,
'Unsatisfied configuration:',
...nonMetRequirements.map(config => ` AFFiNE.${config}`)
);
}
return null;
}
}
const predicator = getOptionalModuleMetadata(module, 'if');
if (predicator && !predicator(config)) {
return null;
}
const contribution = getOptionalModuleMetadata(module, 'contributesTo');
if (contribution) {
ADD_ENABLED_FEATURES(contribution);
}
const subModules = getOptionalModuleMetadata(module, 'imports');
const filteredSubModules = subModules
?.map(subModule => filterOptionalModule(config, subModule))
.filter(Boolean);
Reflect.defineMetadata('imports', filteredSubModules, module);
return module;
}
export class AppModuleBuilder {
private readonly modules: AFFiNEModule[] = [];
constructor(private readonly config: AFFiNEConfig) {}
use(...modules: AFFiNEModule[]): this {
modules.forEach(m => {
const result = filterOptionalModule(this.config, m);
if (result) {
this.modules.push(m);
}
this.modules.push(m);
});
return this;
}
useIf(
predicator: (config: AFFiNEConfig) => boolean,
...modules: AFFiNEModule[]
): this {
if (predicator(this.config)) {
useIf(predicator: () => boolean, ...modules: AFFiNEModule[]): this {
if (predicator()) {
this.use(...modules);
}
return this;
}
compile() {
@Module({
imports: this.modules,
controllers: [AppController],
})
compile(): DynamicModule {
class AppModule {}
return AppModule;
return {
module: AppModule,
imports: this.modules,
controllers: [AppController],
};
}
}
export function buildAppModule() {
AFFiNE = mergeConfigOverride(AFFiNE);
const factor = new AppModuleBuilder(AFFiNE);
export function buildAppModule(env: Env) {
const factor = new AppModuleBuilder();
factor
// basic
.use(...FunctionalityModules)
.use(ModelsModule)
// enable schedule module on graphql server and doc service
.useIf(
config => config.flavor.graphql || config.flavor.doc,
() => env.flavors.graphql || env.flavors.doc,
ScheduleModule.forRoot()
)
@@ -219,46 +154,41 @@ export function buildAppModule() {
// business modules
.use(
ServerConfigModule,
FeatureModule,
QuotaModule,
DocStorageModule,
NotificationModule,
MailModule
)
// renderer server only
.useIf(() => env.flavors.renderer, DocRendererModule)
// sync server only
.useIf(config => config.flavor.sync, SyncModule)
.useIf(() => env.flavors.sync, SyncModule)
// graphql server only
.useIf(
config => config.flavor.graphql,
VersionModule,
() => env.flavors.graphql,
GqlModule,
VersionModule,
StorageModule,
ServerConfigModule,
ServerConfigResolverModule,
WorkspaceModule,
LicenseModule
LicenseModule,
PaymentModule,
CopilotModule,
CaptchaModule,
OAuthModule,
CustomerIoModule
)
// doc service only
.useIf(config => config.flavor.doc, DocServiceModule)
.useIf(() => env.flavors.doc, DocServiceModule)
// self hosted server only
.useIf(config => config.isSelfhosted, SelfhostModule)
.useIf(config => config.flavor.renderer, DocRendererModule);
.useIf(() => env.selfhosted, WorkerModule, SelfhostModule)
// plugin modules
ENABLED_PLUGINS.forEach(name => {
const plugin = REGISTERED_PLUGINS.get(name);
if (!plugin) {
new Logger('AppBuilder').warn(`Unknown plugin ${name}`);
return;
}
factor.use(plugin);
});
// gcloud
.useIf(() => env.gcp, GCloudModule);
return factor.compile();
}
export const AppModule = buildAppModule();
export const AppModule = buildAppModule(env);

View File

@@ -7,11 +7,11 @@ import {
AFFiNELogger,
CacheInterceptor,
CloudThrottlerGuard,
Config,
GlobalExceptionFilter,
} from './base';
import { SocketIoAdapter } from './base/websocket';
import { AuthGuard } from './core/auth';
import { ENABLED_FEATURES } from './core/config/server-feature';
import { serverTimingAndCache } from './middleware/timing';
const OneMB = 1024 * 1024;
@@ -29,9 +29,10 @@ export async function createApp() {
app.useBodyParser('raw', { limit: 100 * OneMB });
app.useLogger(app.get(AFFiNELogger));
const config = app.get(Config);
if (AFFiNE.server.path) {
app.setGlobalPrefix(AFFiNE.server.path);
if (config.server.path) {
app.setGlobalPrefix(config.server.path);
}
app.use(serverTimingAndCache);
@@ -49,22 +50,12 @@ export async function createApp() {
app.use(cookieParser());
// only enable shutdown hooks in production
// https://docs.nestjs.com/fundamentals/lifecycle-events#application-shutdown
if (AFFiNE.NODE_ENV === 'production') {
if (env.prod) {
app.enableShutdownHooks();
}
const adapter = new SocketIoAdapter(app);
app.useWebSocketAdapter(adapter);
if (AFFiNE.isSelfhosted && AFFiNE.metrics.telemetry.enabled) {
const mixpanel = await import('mixpanel');
mixpanel
.init(AFFiNE.metrics.telemetry.token)
.track('selfhost-server-started', {
version: AFFiNE.version,
features: Array.from(ENABLED_FEATURES),
});
}
return app;
}

View File

@@ -0,0 +1,90 @@
import test from 'ava';
import { createModule } from '../../../__tests__/create-module';
import { ConfigFactory, ConfigModule } from '..';
import { Config } from '../config';
const module = await createModule();
test.after.always(async () => {
await module.close();
});
test('should create config', t => {
const config = module.get(Config);
t.is(typeof config.auth.passwordRequirements.max, 'number');
t.is(typeof config.job.queue, 'object');
});
test('should override config', async t => {
await using module = await createModule({
imports: [
ConfigModule.override({
auth: {
passwordRequirements: {
max: 100,
min: 6,
},
},
job: {
queues: {
notification: {
concurrency: 1000,
},
},
},
}),
],
});
const config = module.get(Config);
const configFactory = module.get(ConfigFactory);
t.deepEqual(config.auth.passwordRequirements, {
max: 100,
min: 6,
});
configFactory.override({
auth: {
passwordRequirements: {
max: 10,
},
},
});
t.deepEqual(config.auth.passwordRequirements, {
max: 10,
min: 6,
});
});
test('should validate config', t => {
const config = module.get(ConfigFactory);
t.notThrows(() =>
config.validate([
{
module: 'auth',
key: 'passwordRequirements',
value: { max: 10, min: 6 },
},
])
);
t.throws(
() =>
config.validate([
{
module: 'auth',
key: 'passwordRequirements',
value: { max: 10, min: 10 },
},
]),
{
message: `Invalid config for module [auth] with key [passwordRequirements]
Value: {"max":10,"min":10}
Error: Minimum length of password must be less than maximum length`,
}
);
});

View File

@@ -0,0 +1,3 @@
import { ApplyType } from '../utils';
export class Config extends ApplyType<AppConfig>() {}

View File

@@ -1,55 +0,0 @@
import type { LeafPaths } from '../utils/types';
import { AppStartupConfig } from './types';
export type EnvConfigType = 'string' | 'int' | 'float' | 'boolean';
export type ServerFlavor =
| 'allinone'
| 'graphql'
| 'sync'
| 'renderer'
| 'doc'
| 'script';
export type AFFINE_ENV = 'dev' | 'beta' | 'production';
export type NODE_ENV = 'development' | 'test' | 'production';
export enum DeploymentType {
Affine = 'affine',
Selfhosted = 'selfhosted',
}
export type ConfigPaths = LeafPaths<AppStartupConfig, '', '......'>;
export interface PreDefinedAFFiNEConfig {
ENV_MAP: Record<string, ConfigPaths | [ConfigPaths, EnvConfigType?]>;
serverId: string;
serverName: string;
readonly projectRoot: string;
readonly AFFINE_ENV: AFFINE_ENV;
readonly NODE_ENV: NODE_ENV;
readonly version: string;
readonly type: DeploymentType;
readonly isSelfhosted: boolean;
readonly flavor: { type: string } & { [key in ServerFlavor]: boolean };
readonly affine: { canary: boolean; beta: boolean; stable: boolean };
readonly node: {
prod: boolean;
dev: boolean;
test: boolean;
};
readonly deploy: boolean;
}
export interface AppPluginsConfig {}
export type AFFiNEConfig = PreDefinedAFFiNEConfig &
AppStartupConfig &
AppPluginsConfig;
declare global {
// oxlint-disable-next-line @typescript-eslint/no-namespace
namespace globalThis {
// oxlint-disable-next-line no-var
var AFFiNE: AFFiNEConfig;
}
}

View File

@@ -1,142 +0,0 @@
import { resolve } from 'node:path';
import { fileURLToPath } from 'node:url';
import pkg from '../../../package.json' with { type: 'json' };
import {
AFFINE_ENV,
AFFiNEConfig,
DeploymentType,
NODE_ENV,
PreDefinedAFFiNEConfig,
ServerFlavor,
} from './def';
import { readEnv } from './env';
import { defaultStartupConfig } from './register';
function expectFlavor(flavor: ServerFlavor, expected: ServerFlavor) {
return flavor === expected || flavor === 'allinone';
}
function getPredefinedAFFiNEConfig(): PreDefinedAFFiNEConfig {
const NODE_ENV = readEnv<NODE_ENV>('NODE_ENV', 'production', [
'development',
'test',
'production',
]);
const AFFINE_ENV = readEnv<AFFINE_ENV>('AFFINE_ENV', 'production', [
'dev',
'beta',
'production',
]);
const flavor = readEnv<ServerFlavor>('SERVER_FLAVOR', 'allinone', [
'allinone',
'graphql',
'sync',
'renderer',
'doc',
'script',
]);
const deploymentType = readEnv<DeploymentType>(
'DEPLOYMENT_TYPE',
NODE_ENV === 'development'
? DeploymentType.Affine
: DeploymentType.Selfhosted,
Object.values(DeploymentType)
);
const isSelfhosted = deploymentType === DeploymentType.Selfhosted;
const affine = {
canary: AFFINE_ENV === 'dev',
beta: AFFINE_ENV === 'beta',
stable: AFFINE_ENV === 'production',
};
const node = {
prod: NODE_ENV === 'production',
dev: NODE_ENV === 'development',
test: NODE_ENV === 'test',
};
return {
ENV_MAP: {},
NODE_ENV,
AFFINE_ENV,
serverId: 'some-randome-uuid',
serverName: isSelfhosted ? 'Self-Host Cloud' : 'AFFiNE Cloud',
version: pkg.version,
type: deploymentType,
isSelfhosted,
flavor: {
type: flavor,
allinone: flavor === 'allinone',
graphql: expectFlavor(flavor, 'graphql'),
sync: expectFlavor(flavor, 'sync'),
renderer: expectFlavor(flavor, 'renderer'),
doc: expectFlavor(flavor, 'doc'),
script: expectFlavor(flavor, 'script'),
},
affine,
node,
deploy: !node.dev && !node.test,
projectRoot: resolve(fileURLToPath(import.meta.url), '../../../../'),
};
}
export function getAFFiNEConfigModifier(): AFFiNEConfig {
const predefined = getPredefinedAFFiNEConfig() as AFFiNEConfig;
return chainableProxy(predefined);
}
function merge(a: any, b: any) {
if (typeof b !== 'object' || b instanceof Map || b instanceof Set) {
return b;
}
if (Array.isArray(b)) {
if (Array.isArray(a)) {
return a.concat(b);
}
return b;
}
const result = { ...a };
Object.keys(b).forEach(key => {
result[key] = merge(result[key], b[key]);
});
return result;
}
export function mergeConfigOverride(override: any) {
return merge(defaultStartupConfig, override);
}
function chainableProxy(obj: any) {
const keys: Set<string> = new Set(Object.keys(obj));
return new Proxy(obj, {
get(target, prop) {
if (!(prop in target)) {
keys.add(prop as string);
target[prop] = chainableProxy({});
}
return target[prop];
},
set(target, prop, value) {
keys.add(prop as string);
if (
typeof value === 'object' &&
!(
value instanceof Map ||
value instanceof Set ||
value instanceof Array
)
) {
value = chainableProxy(value);
}
target[prop] = value;
return true;
},
ownKeys() {
return Array.from(keys);
},
});
}

View File

@@ -1,11 +1,9 @@
import { set } from 'lodash-es';
import type { AFFiNEConfig, EnvConfigType } from './def';
export type EnvConfigType = 'string' | 'integer' | 'float' | 'boolean';
/**
* parse number value from environment variables
*/
function int(value: string) {
function integer(value: string) {
const n = parseInt(value);
return Number.isNaN(n) ? undefined : n;
}
@@ -20,7 +18,7 @@ function boolean(value: string) {
}
const envParsers: Record<EnvConfigType, (value: string) => unknown> = {
int,
integer,
float,
boolean,
string: value => value,
@@ -33,38 +31,3 @@ export function parseEnvValue(value: string | undefined, type: EnvConfigType) {
return envParsers[type](value);
}
export function applyEnvToConfig(rawConfig: AFFiNEConfig) {
for (const env in rawConfig.ENV_MAP) {
const config = rawConfig.ENV_MAP[env];
const [path, value] =
typeof config === 'string'
? [config, parseEnvValue(process.env[env], 'string')]
: [config[0], parseEnvValue(process.env[env], config[1] ?? 'string')];
if (value !== undefined) {
set(rawConfig, path, value);
}
}
}
export function readEnv<T>(
env: string,
defaultValue: T,
availableValues?: T[]
) {
const value = process.env[env];
if (value === undefined) {
return defaultValue;
}
if (availableValues && !availableValues.includes(value as any)) {
throw new Error(
`Invalid value '${value}' for environment variable ${env}, expected one of [${availableValues.join(
', '
)}]`
);
}
return value as T;
}

View File

@@ -0,0 +1,60 @@
import { Inject, Injectable, Optional } from '@nestjs/common';
import { merge } from 'lodash-es';
import { InvalidAppConfig } from '../error';
import { APP_CONFIG_DESCRIPTORS, getDefaultConfig } from './register';
export const OVERRIDE_CONFIG_TOKEN = Symbol('OVERRIDE_CONFIG_TOKEN');
@Injectable()
export class ConfigFactory {
readonly #config: DeepReadonly<AppConfig>;
constructor(
@Inject(OVERRIDE_CONFIG_TOKEN)
@Optional()
private readonly overrides: DeepPartial<AppConfig> = {}
) {
this.#config = this.loadDefault();
}
get config() {
return this.#config;
}
override(updates: DeepPartial<AppConfig>) {
merge(this.#config, updates);
}
validate(updates: Array<{ module: string; key: string; value: any }>) {
const errors: string[] = [];
updates.forEach(update => {
const descriptor = APP_CONFIG_DESCRIPTORS[update.module]?.[update.key];
if (!descriptor) {
errors.push(
`Invalid config for module [${update.module}] with unknown key [${update.key}]`
);
return;
}
const { success, error } = descriptor.validate(update.value);
if (!success) {
error.issues.forEach(issue => {
errors.push(`Invalid config for module [${update.module}] with key [${update.key}]
Value: ${JSON.stringify(update.value)}
Error: ${issue.message}`);
});
}
});
if (errors.length > 0) {
throw new InvalidAppConfig(errors.join('\n'));
}
}
private loadDefault(): DeepReadonly<AppConfig> {
const config = getDefaultConfig();
return merge(config, this.overrides);
}
}

View File

@@ -1,37 +1,29 @@
import { DynamicModule, FactoryProvider } from '@nestjs/common';
import { merge } from 'lodash-es';
import { DynamicModule, Global, Module, Provider } from '@nestjs/common';
import { AFFiNEConfig } from './def';
import { Config } from './provider';
export * from './def';
export * from './default';
export { applyEnvToConfig, parseEnvValue } from './env';
export * from './provider';
export { defineRuntimeConfig, defineStartupConfig } from './register';
export type { AppConfig, ConfigItem, ModuleConfig } from './types';
function createConfigProvider(
override?: DeepPartial<Config>
): FactoryProvider<Config> {
return {
provide: Config,
useFactory: () => {
return Object.freeze(merge({}, globalThis.AFFiNE, override));
},
inject: [],
};
}
import { Config } from './config';
import { ConfigFactory, OVERRIDE_CONFIG_TOKEN } from './factory';
import { ConfigProvider } from './provider';
@Global()
@Module({
providers: [ConfigProvider, ConfigFactory],
exports: [ConfigProvider, ConfigFactory],
})
export class ConfigModule {
static forRoot = (override?: DeepPartial<AFFiNEConfig>): DynamicModule => {
const provider = createConfigProvider(override);
static override(overrides: DeepPartial<AppConfigSchema> = {}): DynamicModule {
const provider: Provider = {
provide: OVERRIDE_CONFIG_TOKEN,
useValue: overrides,
};
return {
global: true,
module: ConfigModule,
module: class ConfigOverrideModule {},
providers: [provider],
exports: [provider],
};
};
}
}
export { Config, ConfigFactory };
export { defineModuleConfig, type JSONSchema } from './register';

View File

@@ -1,16 +1,12 @@
import { ApplyType } from '../utils/types';
import { AFFiNEConfig } from './def';
import { FactoryProvider } from '@nestjs/common';
/**
* @example
*
* import { Config } from '@affine/server'
*
* class TestConfig {
* constructor(private readonly config: Config) {}
* test() {
* return this.config.env
* }
* }
*/
export class Config extends ApplyType<AFFiNEConfig>() {}
import { Config } from './config';
import { ConfigFactory } from './factory';
export const ConfigProvider: FactoryProvider = {
provide: Config,
useFactory: (factory: ConfigFactory) => {
return factory.config;
},
inject: [ConfigFactory],
};

View File

@@ -1,66 +1,239 @@
import { Prisma, RuntimeConfigType } from '@prisma/client';
import { get, merge, set } from 'lodash-es';
import { once, set } from 'lodash-es';
import { z } from 'zod';
import {
AppModulesConfigDef,
AppStartupConfig,
ModuleRuntimeConfigDescriptions,
ModuleStartupConfigDescriptions,
} from './types';
import { type EnvConfigType, parseEnvValue } from './env';
import { AppConfigByPath } from './types';
export const defaultStartupConfig: AppStartupConfig = {} as any;
export const defaultRuntimeConfig: Record<
string,
Prisma.RuntimeConfigCreateInput
> = {} as any;
export type JSONSchema = { description?: string } & (
| { type?: undefined; oneOf?: JSONSchema[] }
| {
type: 'string' | 'number' | 'boolean';
enum?: string[];
}
| {
type: 'array';
items?: JSONSchema;
}
| {
type: 'object';
properties?: Record<string, JSONSchema>;
}
);
export function runtimeConfigType(val: any): RuntimeConfigType {
if (Array.isArray(val)) {
return RuntimeConfigType.Array;
}
type ConfigType = EnvConfigType | 'array' | 'object' | 'any';
export type ConfigDescriptor<T> = {
desc: string;
type: ConfigType;
validate: (value: T) => z.SafeParseReturnType<T, T>;
schema: JSONSchema;
default: T;
env?: [string, EnvConfigType];
link?: string;
};
switch (typeof val) {
case 'string':
return RuntimeConfigType.String;
case 'number':
return RuntimeConfigType.Number;
case 'boolean':
return RuntimeConfigType.Boolean;
type ConfigDefineDescriptor<T> = {
desc: string;
default: T;
validate?: (value: T) => boolean;
shape?: z.ZodType<T>;
env?: string | [string, EnvConfigType];
link?: string;
schema?: JSONSchema;
};
function typeFromShape(shape: z.ZodType<any>): ConfigType {
switch (shape.constructor) {
case z.ZodString:
return 'string';
case z.ZodNumber:
return 'float';
case z.ZodBoolean:
return 'boolean';
case z.ZodArray:
return 'array';
case z.ZodObject:
return 'object';
default:
return RuntimeConfigType.Object;
return 'any';
}
}
function registerRuntimeConfig<T extends keyof AppModulesConfigDef>(
module: T,
configs: ModuleRuntimeConfigDescriptions<T>
) {
Object.entries(configs).forEach(([key, value]) => {
defaultRuntimeConfig[`${module}/${key}`] = {
id: `${module}/${key}`,
function shapeFromType(type: ConfigType): z.ZodType<any> {
switch (type) {
case 'string':
return z.string();
case 'float':
return z.number();
case 'boolean':
return z.boolean();
case 'integer':
return z.number().int();
case 'array':
return z.array(z.any());
case 'object':
return z.object({});
default:
return z.any();
}
}
function typeFromSchema(schema: JSONSchema): ConfigType {
if ('type' in schema) {
switch (schema.type) {
case 'string':
return 'string';
case 'number':
return 'float';
case 'boolean':
return 'boolean';
case 'array':
return 'array';
case 'object':
return 'object';
}
}
return 'any';
}
function schemaFromType(type: ConfigType): JSONSchema['type'] {
switch (type) {
case 'any':
return undefined;
case 'float':
case 'integer':
return 'number';
default:
return type;
}
}
function typeFromDefault<T>(defaultValue: T): ConfigType {
if (Array.isArray(defaultValue)) {
return 'array';
}
switch (typeof defaultValue) {
case 'string':
return 'string';
case 'number':
return 'float';
case 'boolean':
return 'boolean';
case 'object':
return 'object';
default:
return 'any';
}
}
function standardizeDescriptor<T>(
desc: ConfigDefineDescriptor<T>
): ConfigDescriptor<T> {
const env = desc.env
? Array.isArray(desc.env)
? desc.env
: ([desc.env, 'string'] as [string, EnvConfigType])
: undefined;
let type: ConfigType = 'any';
if (desc.default !== undefined && desc.default !== null) {
type = typeFromDefault(desc.default);
} else if (env) {
type = env[1];
} else if (desc.shape) {
type = typeFromShape(desc.shape);
} else if (desc.schema) {
type = typeFromSchema(desc.schema);
}
const shape = desc.shape ?? shapeFromType(type);
return {
desc: desc.desc,
default: desc.default,
type,
validate: (value: T) => {
return shape.safeParse(value);
},
env,
link: desc.link,
schema: {
type: schemaFromType(type),
description: desc.desc,
...desc.schema,
},
};
}
type ModuleConfigDescriptors<T> = {
[K in keyof T]: ConfigDefineDescriptor<T[K]>;
};
export const APP_CONFIG_DESCRIPTORS: Record<
string,
Record<string, ConfigDescriptor<any>>
> = {};
export const getDescriptors = once(() => {
return Object.entries(APP_CONFIG_DESCRIPTORS).map(
([module, descriptors]) => ({
module,
key,
description: value.desc,
value: value.default,
type: runtimeConfigType(value.default),
};
});
}
export function defineStartupConfig<T extends keyof AppModulesConfigDef>(
module: T,
configs: ModuleStartupConfigDescriptions<AppModulesConfigDef[T]>
) {
set(
defaultStartupConfig,
module,
merge(get(defaultStartupConfig, module, {}), configs)
descriptors: Object.entries(descriptors).map(([key, descriptor]) => ({
key,
descriptor,
})),
})
);
});
export function defineModuleConfig<T extends keyof AppConfigSchema>(
module: T,
defs: ModuleConfigDescriptors<AppConfigByPath<T>>
) {
const descriptors: Record<string, ConfigDescriptor<any>> = {};
Object.entries(defs).forEach(([key, desc]) => {
descriptors[key] = standardizeDescriptor(
desc as ConfigDefineDescriptor<any>
);
});
APP_CONFIG_DESCRIPTORS[module] = {
...APP_CONFIG_DESCRIPTORS[module],
...descriptors,
};
}
export function defineRuntimeConfig<T extends keyof AppModulesConfigDef>(
module: T,
configs: ModuleRuntimeConfigDescriptions<T>
) {
registerRuntimeConfig(module, configs);
export function getDefaultConfig(): AppConfigSchema {
const config: Record<string, any> = {};
const envs = process.env;
for (const [module, defs] of Object.entries(APP_CONFIG_DESCRIPTORS)) {
const modulizedConfig = {};
for (const [key, desc] of Object.entries(defs)) {
let defaultValue = desc.default;
if (desc.env) {
const [env, parser] = desc.env;
const envValue = envs[env];
if (envValue) {
defaultValue = parseEnvValue(envValue, parser);
}
}
const { success, error } = desc.validate(defaultValue);
if (!success) {
throw error;
}
set(modulizedConfig, key, defaultValue);
}
config[module] = modulizedConfig;
}
return config as AppConfigSchema;
}

View File

@@ -1,127 +1,20 @@
import { Join, PathType } from '../utils/types';
import { LeafPaths, PathType } from '../utils';
export type ConfigItem<T> = T & { __type: 'ConfigItem' };
type ConfigDef = Record<string, any> | never;
export interface ModuleConfig<
Startup extends ConfigDef = never,
Runtime extends ConfigDef = never,
> {
startup: Startup;
runtime: Runtime;
declare global {
type ConfigItem<T> = Leaf<T>;
interface AppConfigSchema {}
type AppConfig = DeeplyEraseLeaf<AppConfigSchema>;
}
export type RuntimeConfigDescription<T> = {
desc: string;
default: T;
};
type ConfigItemLeaves<T, P extends string = ''> =
T extends Record<string, any>
export type AppConfigByPath<Module extends keyof AppConfigSchema> =
AppConfigSchema[Module] extends infer Config
? {
[K in keyof T]: K extends string
? T[K] extends { __type: 'ConfigItem' }
? K
: T[K] extends PrimitiveType
? K
: Join<K, ConfigItemLeaves<T[K], P>>
: never;
}[keyof T]
: never;
type StartupConfigDescriptions<T extends ConfigDef> = {
[K in keyof T]: T[K] extends Record<string, any>
? T[K] extends ConfigItem<infer V>
? V
: T[K]
: T[K];
};
type ModuleConfigLeaves<T, P extends string = ''> =
T extends Record<string, any>
? {
[K in keyof T]: K extends string
? T[K] extends ModuleConfig<any, any>
? K
: Join<K, ModuleConfigLeaves<T[K], P>>
: never;
}[keyof T]
: never;
type FlattenModuleConfigs<T extends Record<string, any>> = {
// @ts-expect-error allow
[K in ModuleConfigLeaves<T>]: PathType<T, K>;
};
type _AppStartupConfig<T extends Record<string, any>> = {
[K in keyof T]: T[K] extends ModuleConfig<infer S, any>
? S
: _AppStartupConfig<T[K]>;
};
// for extending
export interface AppConfig {}
export type AppModulesConfigDef = FlattenModuleConfigs<AppConfig>;
export type AppConfigModules = keyof AppModulesConfigDef;
export type AppStartupConfig = _AppStartupConfig<AppConfig>;
// app runtime config keyed by module names
export type AppRuntimeConfigByModules = {
[Module in keyof AppModulesConfigDef]: AppModulesConfigDef[Module] extends ModuleConfig<
any,
infer Runtime
>
? Runtime extends never
? never
: {
// @ts-expect-error allow
[K in ConfigItemLeaves<Runtime>]: PathType<
Runtime,
K
> extends infer Config
? Config extends ConfigItem<infer V>
[Path in LeafPaths<Config>]: Path extends string
? PathType<Config, Path> extends infer Item
? Item extends Leaf<infer V>
? V
: Config
: never;
}
: Item
: never
: never;
}
: never;
};
// names of modules that have runtime config
export type AppRuntimeConfigModules = {
[Module in keyof AppRuntimeConfigByModules]: AppRuntimeConfigByModules[Module] extends never
? never
: Module;
}[keyof AppRuntimeConfigByModules];
// runtime config keyed by module names flattened into config names
// { auth: { allowSignup: boolean } } => { 'auth/allowSignup': boolean }
export type FlattenedAppRuntimeConfig = UnionToIntersection<
{
[Module in keyof AppRuntimeConfigByModules]: AppModulesConfigDef[Module] extends never
? never
: {
[K in keyof AppRuntimeConfigByModules[Module] as K extends string
? `${Module}/${K}`
: never]: AppRuntimeConfigByModules[Module][K];
};
}[keyof AppRuntimeConfigByModules]
>;
export type ModuleStartupConfigDescriptions<T extends ModuleConfig<any, any>> =
T extends ModuleConfig<infer S, any>
? S extends never
? undefined
: StartupConfigDescriptions<S>
: never;
export type ModuleRuntimeConfigDescriptions<
Module extends keyof AppRuntimeConfigByModules,
> = AppModulesConfigDef[Module] extends never
? never
: {
[K in keyof AppRuntimeConfigByModules[Module]]: RuntimeConfigDescription<
AppRuntimeConfigByModules[Module][K]
>;
};

View File

@@ -815,4 +815,10 @@ export const USER_FRIENDLY_ERRORS = {
type: 'action_forbidden',
message: 'You can not mention yourself.',
},
// app config
invalid_app_config: {
type: 'invalid_input',
message: 'Invalid app config.',
},
} satisfies Record<string, UserFriendlyErrorOptions>;

View File

@@ -917,6 +917,12 @@ export class MentionUserOneselfDenied extends UserFriendlyError {
super('action_forbidden', 'mention_user_oneself_denied', message);
}
}
export class InvalidAppConfig extends UserFriendlyError {
constructor(message?: string) {
super('invalid_input', 'invalid_app_config', message);
}
}
export enum ErrorNames {
INTERNAL_SERVER_ERROR,
NETWORK_ERROR,
@@ -1034,7 +1040,8 @@ export enum ErrorNames {
UNSUPPORTED_CLIENT_VERSION,
NOTIFICATION_NOT_FOUND,
MENTION_USER_DOC_ACCESS_DENIED,
MENTION_USER_ONESELF_DENIED
MENTION_USER_ONESELF_DENIED,
INVALID_APP_CONFIG
}
registerEnumType(ErrorNames, {
name: 'ErrorNames'

View File

@@ -5,7 +5,6 @@ import { fileURLToPath } from 'node:url';
import { Logger, Module, OnModuleInit } from '@nestjs/common';
import { Args, Query, Resolver } from '@nestjs/graphql';
import { Config } from '../config/provider';
import { generateUserFriendlyErrors } from './def';
import { ActionForbidden, ErrorDataUnionType, ErrorNames } from './errors.gen';
@@ -23,9 +22,8 @@ class ErrorResolver {
})
export class ErrorModule implements OnModuleInit {
logger = new Logger('ErrorModule');
constructor(private readonly config: Config) {}
onModuleInit() {
if (!this.config.node.dev) {
if (!env.dev) {
return;
}
this.logger.log('Generating UserFriendlyError classes');

View File

@@ -1,6 +1,6 @@
import { OnOptions } from 'eventemitter2';
import { PushMetadata, sliceMetadata } from '../nestjs';
import { PushMetadata, sliceMetadata } from '../nestjs/decorator';
declare global {
/**

View File

@@ -6,7 +6,10 @@ import { EventHandlerScanner } from './scanner';
const EmitProvider = {
provide: EventEmitter2,
useFactory: () => new EventEmitter2(),
useFactory: () =>
new EventEmitter2({
maxListeners: 100,
}),
};
@Global()

View File

@@ -1,7 +1,7 @@
import { Injectable } from '@nestjs/common';
import { once } from 'lodash-es';
import { ModuleScanner } from '../nestjs';
import { ModuleScanner } from '../nestjs/scanner';
import {
type EventName,
type EventOptions,

View File

@@ -1,17 +1,27 @@
import { ApolloDriverConfig } from '@nestjs/apollo';
import { defineStartupConfig, ModuleConfig } from '../../base/config';
import { defineModuleConfig } from '../config';
declare module '../../base/config' {
interface AppConfig {
graphql: ModuleConfig<ApolloDriverConfig>;
declare global {
interface AppConfigSchema {
graphql: {
apolloDriverConfig: ConfigItem<ApolloDriverConfig>;
};
}
}
defineStartupConfig('graphql', {
buildSchemaOptions: {
numberScalarMode: 'integer',
defineModuleConfig('graphql', {
apolloDriverConfig: {
desc: 'The config for underlying nestjs GraphQL and apollo driver engine.',
default: {
buildSchemaOptions: {
numberScalarMode: 'integer',
},
useGlobalPrefix: true,
playground: true,
introspection: true,
sortSchema: true,
},
link: 'https://docs.nestjs.com/graphql/quick-start',
},
introspection: true,
playground: true,
});

View File

@@ -1,13 +1,12 @@
import './config';
import { join } from 'node:path';
import { fileURLToPath } from 'node:url';
import type { ApolloDriverConfig } from '@nestjs/apollo';
import { ApolloDriver } from '@nestjs/apollo';
import { Global, Module } from '@nestjs/common';
import { GraphQLModule } from '@nestjs/graphql';
import { Request, Response } from 'express';
import type { Request, Response } from 'express';
import { Config } from '../config';
import { mapAnyError } from '../nestjs/exception';
@@ -26,18 +25,17 @@ export type GraphqlContext = {
driver: ApolloDriver,
useFactory: (config: Config) => {
return {
...config.graphql,
path: `${config.server.path}/graphql`,
...config.graphql.apolloDriverConfig,
autoSchemaFile: join(
env.projectRoot,
env.testing
? './node_modules/.cache/schema.gql'
: './src/schema.gql'
),
path: '/graphql',
csrfPrevention: {
requestHeaders: ['content-type'],
},
autoSchemaFile: join(
fileURLToPath(import.meta.url),
config.node.dev
? '../../../schema.gql'
: '../../../../node_modules/.cache/schema.gql'
),
sortSchema: true,
context: ({
req,
res,
@@ -55,7 +53,7 @@ export type GraphqlContext = {
// @ts-expect-error allow assign
formattedError.extensions = ufe.toJSON();
if (config.affine.canary) {
if (env.namespaces.canary) {
formattedError.extensions.stacktrace = ufe.stacktrace;
}
return formattedError;

View File

@@ -1,4 +1,3 @@
import { Type } from '@nestjs/common';
import { Field, FieldOptions, ObjectType } from '@nestjs/graphql';
import { ApplyType } from '../utils/types';
@@ -7,7 +6,7 @@ export function registerObjectType<T>(
fields: Record<
string,
{
type: () => Type<any>;
type: () => any;
options?: FieldOptions;
}
>,

View File

@@ -1,5 +1,3 @@
import { createPrivateKey, createPublicKey } from 'node:crypto';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
@@ -9,42 +7,19 @@ const test = ava as TestFn<{
crypto: CryptoHelper;
}>;
const key = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEtyAJLIULkphVhqXqxk4Nr8Ggty3XLwUJWBxzAWCWTMoAoGCCqGSM49
AwEHoUQDQgAEF3U/0wIeJ3jRKXeFKqQyBKlr9F7xaAUScRrAuSP33rajm3cdfihI
3JvMxVNsS2lE8PSGQrvDrJZaDo0L+Lq9Gg==
-----END EC PRIVATE KEY-----`;
const privateKey = createPrivateKey({
key,
format: 'pem',
type: 'sec1',
})
.export({
type: 'pkcs8',
format: 'pem',
})
.toString('utf8');
const publicKey = createPublicKey({
key,
format: 'pem',
type: 'spki',
})
.export({
format: 'pem',
type: 'spki',
})
.toString('utf8');
const privateKey = `-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgS3IAkshQuSmFWGpe
rGTg2vwaC3LdcvBQlYHHMBYJZMyhRANCAAQXdT/TAh4neNEpd4UqpDIEqWv0XvFo
BRJxGsC5I/fetqObdx1+KEjcm8zFU2xLaUTw9IZCu8OslloOjQv4ur0a
-----END PRIVATE KEY-----`;
test.beforeEach(async t => {
t.context.crypto = new CryptoHelper({
crypto: {
secret: {
publicKey,
privateKey,
},
privateKey,
},
} as any);
t.context.crypto.onConfigInit();
});
test('should be able to sign and verify', t => {

View File

@@ -1,53 +1,18 @@
import { createPrivateKey, createPublicKey } from 'node:crypto';
import { defineModuleConfig } from '../config';
import { defineStartupConfig, ModuleConfig } from '../config';
declare module '../config' {
interface AppConfig {
crypto: ModuleConfig<{
secret: {
publicKey: string;
privateKey: string;
};
}>;
declare global {
interface AppConfigSchema {
crypto: {
privateKey: string;
};
}
}
// Don't use this in production
const examplePrivateKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEtyAJLIULkphVhqXqxk4Nr8Ggty3XLwUJWBxzAWCWTMoAoGCCqGSM49
AwEHoUQDQgAEF3U/0wIeJ3jRKXeFKqQyBKlr9F7xaAUScRrAuSP33rajm3cdfihI
3JvMxVNsS2lE8PSGQrvDrJZaDo0L+Lq9Gg==
-----END EC PRIVATE KEY-----`;
defineStartupConfig('crypto', {
secret: (function () {
const AFFINE_PRIVATE_KEY =
process.env.AFFINE_PRIVATE_KEY ?? examplePrivateKey;
const privateKey = createPrivateKey({
key: Buffer.from(AFFINE_PRIVATE_KEY),
format: 'pem',
type: 'sec1',
})
.export({
format: 'pem',
type: 'pkcs8',
})
.toString('utf8');
const publicKey = createPublicKey({
key: Buffer.from(AFFINE_PRIVATE_KEY),
format: 'pem',
type: 'spki',
})
.export({
format: 'pem',
type: 'spki',
})
.toString('utf8');
return {
publicKey,
privateKey,
};
})(),
defineModuleConfig('crypto', {
privateKey: {
desc: 'The private key for used by the crypto module to create signed tokens or encrypt data.',
env: 'AFFINE_PRIVATE_KEY',
default: '',
schema: { type: 'string' },
},
});

View File

@@ -2,8 +2,11 @@ import {
createCipheriv,
createDecipheriv,
createHash,
createPrivateKey,
createPublicKey,
createSign,
createVerify,
generateKeyPairSync,
randomBytes,
randomInt,
timingSafeEqual,
@@ -16,13 +19,48 @@ import {
} from '@node-rs/argon2';
import { Config } from '../config';
import { OnEvent } from '../event';
const NONCE_LENGTH = 12;
const AUTH_TAG_LENGTH = 12;
function generatePrivateKey(): string {
const { privateKey } = generateKeyPairSync('ec', {
namedCurve: 'prime256v1',
});
const key = privateKey.export({
type: 'sec1',
format: 'pem',
});
return key.toString('utf8');
}
function readPrivateKey(privateKey: string) {
return createPrivateKey({
key: Buffer.from(privateKey),
format: 'pem',
type: 'sec1',
})
.export({
format: 'pem',
type: 'pkcs8',
})
.toString('utf8');
}
function readPublicKey(privateKey: string) {
return createPublicKey({
key: Buffer.from(privateKey),
})
.export({ format: 'pem', type: 'spki' })
.toString('utf8');
}
@Injectable()
export class CryptoHelper {
keyPair: {
keyPair!: {
publicKey: Buffer;
privateKey: Buffer;
sha256: {
@@ -31,13 +69,31 @@ export class CryptoHelper {
};
};
constructor(config: Config) {
constructor(private readonly config: Config) {}
@OnEvent('config.init')
onConfigInit() {
this.setup();
}
@OnEvent('config.changed')
onConfigChanged(event: Events['config.changed']) {
if (event.updates.crypto?.privateKey) {
this.setup();
}
}
private setup() {
const key = this.config.crypto.privateKey || generatePrivateKey();
const privateKey = readPrivateKey(key);
const publicKey = readPublicKey(key);
this.keyPair = {
publicKey: Buffer.from(config.crypto.secret.publicKey, 'utf8'),
privateKey: Buffer.from(config.crypto.secret.privateKey, 'utf8'),
publicKey: Buffer.from(publicKey),
privateKey: Buffer.from(privateKey),
sha256: {
publicKey: this.sha256(config.crypto.secret.publicKey),
privateKey: this.sha256(config.crypto.secret.privateKey),
publicKey: this.sha256(publicKey),
privateKey: this.sha256(privateKey),
},
};
}

View File

@@ -4,16 +4,23 @@ import { Injectable } from '@nestjs/common';
import type { Response } from 'express';
import { Config } from '../config';
import { OnEvent } from '../event';
@Injectable()
export class URLHelper {
private readonly redirectAllowHosts: string[];
redirectAllowHosts!: string[];
readonly origin: string;
readonly baseUrl: string;
readonly home: string;
origin!: string;
baseUrl!: string;
home!: string;
constructor(private readonly config: Config) {
this.init();
}
@OnEvent('config.changed')
@OnEvent('config.init')
init() {
if (this.config.server.externalUrl) {
if (!this.verify(this.config.server.externalUrl)) {
throw new Error(

View File

@@ -6,17 +6,14 @@ export {
SessionCache,
} from './cache';
export {
type AFFiNEConfig,
applyEnvToConfig,
Config,
type ConfigPaths,
DeploymentType,
getAFFiNEConfigModifier,
ConfigFactory,
defineModuleConfig,
type JSONSchema,
} from './config';
export * from './error';
export { EventBus, OnEvent } from './event';
export {
type GraphqlContext,
paginate,
Paginated,
PaginationInput,
@@ -30,8 +27,12 @@ export { CallMetric, metrics } from './metrics';
export { Lock, Locker, Mutex, RequestMutex } from './mutex';
export * from './nestjs';
export { type PrismaTransaction } from './prisma';
export { Runtime } from './runtime';
export * from './storage';
export { type StorageProvider, StorageProviderFactory } from './storage';
export {
autoMetadata,
type StorageProvider,
type StorageProviderConfig,
StorageProviderFactory,
} from './storage';
export { CloudThrottlerGuard, SkipThrottle, Throttle } from './throttler';
export * from './utils';

View File

@@ -52,13 +52,15 @@ class JobHandlers {
test.before(async () => {
module = await createTestingModule({
imports: [
ConfigModule.forRoot({
ConfigModule.override({
job: {
worker: {
// NOTE(@forehalo):
// bullmq will hold the connection to check stalled jobs,
// which will keep the test process alive to timeout.
stalledInterval: 100,
defaultWorkerOptions: {
// NOTE(@forehalo):
// bullmq will hold the connection to check stalled jobs,
// which will keep the test process alive to timeout.
stalledInterval: 100,
},
},
queue: {
defaultJobOptions: { delay: 1000 },
@@ -82,7 +84,7 @@ test.afterEach(async () => {
// @ts-expect-error private api
const inner = queue.getQueue('nightly');
await inner.obliterate({ force: true });
inner.resume();
await inner.resume();
});
test.after.always(async () => {
@@ -132,7 +134,7 @@ test('should remove job from queue', async t => {
// #region executor
test('should start workers', async t => {
// @ts-expect-error private api
const worker = executor.workers['nightly'];
const worker = executor.workers.get('nightly')!;
t.truthy(worker);
t.true(worker.isRunning());

View File

@@ -1,61 +1,86 @@
import { QueueOptions, WorkerOptions } from 'bullmq';
import {
defineRuntimeConfig,
defineStartupConfig,
ModuleConfig,
} from '../../config';
import { defineModuleConfig, JSONSchema } from '../../config';
import { Queue } from './def';
declare module '../../config' {
interface AppConfig {
job: ModuleConfig<
{
queue: Omit<QueueOptions, 'connection'>;
worker: Omit<WorkerOptions, 'connection'>;
},
{
queues: {
[key in Queue]: {
concurrency: number;
};
};
}
>;
declare global {
interface AppConfigSchema {
job: {
queue: ConfigItem<Omit<QueueOptions, 'connection' | 'telemetry'>>;
worker: ConfigItem<{
defaultWorkerOptions: Omit<WorkerOptions, 'connection' | 'telemetry'>;
}>;
queues: {
[key in Queue]: ConfigItem<{
concurrency: number;
}>;
};
};
}
}
defineStartupConfig('job', {
const schema: JSONSchema = {
type: 'object',
properties: {
concurrency: { type: 'number' },
},
};
defineModuleConfig('job', {
queue: {
prefix: AFFiNE.node.test ? 'affine_job_test' : 'affine_job',
defaultJobOptions: {
attempts: 5,
// should remove job after it's completed, because we will add a new job with the same job id
removeOnComplete: true,
removeOnFail: {
age: 24 * 3600 /* 1 day */,
count: 500,
desc: 'The config for job queues',
default: {
prefix: env.testing ? 'affine_job_test' : 'affine_job',
defaultJobOptions: {
attempts: 5,
// should remove job after it's completed, because we will add a new job with the same job id
removeOnComplete: true,
removeOnFail: {
age: 24 * 3600 /* 1 day */,
count: 500,
},
},
},
link: 'https://api.docs.bullmq.io/interfaces/v5.QueueOptions.html',
},
worker: {},
});
defineRuntimeConfig('job', {
'queues.nightly.concurrency': {
default: 1,
desc: 'Concurrency of worker consuming of nightly checking job queue',
worker: {
desc: 'The config for job workers',
default: {
defaultWorkerOptions: {},
},
link: 'https://api.docs.bullmq.io/interfaces/v5.WorkerOptions.html',
},
'queues.notification.concurrency': {
default: 10,
desc: 'Concurrency of worker consuming of notification job queue',
'queues.copilot': {
desc: 'The config for copilot job queue',
default: {
concurrency: 1,
},
schema,
},
'queues.doc.concurrency': {
default: 1,
desc: 'Concurrency of worker consuming of doc job queue',
'queues.doc': {
desc: 'The config for doc job queue',
default: {
concurrency: 1,
},
schema,
},
'queues.copilot.concurrency': {
default: 1,
desc: 'Concurrency of worker consuming of copilot job queue',
'queues.notification': {
desc: 'The config for notification job queue',
default: {
concurrency: 10,
},
schema,
},
'queues.nightly': {
desc: 'The config for nightly job queue',
default: {
concurrency: 1,
},
schema,
},
});

View File

@@ -49,7 +49,7 @@ export const OnJob = (job: JobName) => {
if (!QUEUES.includes(ns as Queue)) {
throw new Error(
`Invalid job queue: ${ns}, must be one of [${QUEUES.join(', ')}].
If you want to introduce new job queue, please modify the Queue enum first in ${join(AFFiNE.projectRoot, 'src/base/job/queue/def.ts')}`
If you want to introduce new job queue, please modify the Queue enum first in ${join(env.projectRoot, 'src/base/job/queue/def.ts')}`
);
}

View File

@@ -1,49 +1,51 @@
import {
Injectable,
Logger,
OnApplicationBootstrap,
OnApplicationShutdown,
} from '@nestjs/common';
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { Worker } from 'bullmq';
import { difference } from 'lodash-es';
import { difference, merge } from 'lodash-es';
import { CLS_ID, ClsServiceManager } from 'nestjs-cls';
import { Config } from '../../config';
import { OnEvent } from '../../event';
import { metrics, wrapCallMetric } from '../../metrics';
import { QueueRedis } from '../../redis';
import { Runtime } from '../../runtime';
import { genRequestId } from '../../utils';
import { JOB_SIGNAL, namespace, Queue, QUEUES } from './def';
import { JobHandlerScanner } from './scanner';
@Injectable()
export class JobExecutor
implements OnApplicationBootstrap, OnApplicationShutdown
{
export class JobExecutor implements OnModuleDestroy {
private readonly logger = new Logger('job');
private readonly workers: Record<string, Worker> = {};
private readonly workers: Map<Queue, Worker> = new Map();
constructor(
private readonly config: Config,
private readonly redis: QueueRedis,
private readonly scanner: JobHandlerScanner,
private readonly runtime: Runtime
private readonly scanner: JobHandlerScanner
) {}
async onApplicationBootstrap() {
const queues = this.config.flavor.graphql
? difference(QUEUES, [Queue.DOC])
: [];
@OnEvent('config.init')
async onConfigInit() {
const queues = env.flavors.graphql ? difference(QUEUES, [Queue.DOC]) : [];
// NOTE(@forehalo): only enable doc queue in doc service
if (this.config.flavor.doc) {
if (env.flavors.doc) {
queues.push(Queue.DOC);
}
await this.startWorkers(queues);
}
async onApplicationShutdown() {
@OnEvent('config.changed')
async onConfigChanged({ updates }: Events['config.changed']) {
if (updates.job?.queues) {
Object.entries(updates.job.queues).forEach(([queue, options]) => {
if (options.concurrency) {
this.setConcurrency(queue as Queue, options.concurrency);
}
});
}
}
async onModuleDestroy() {
await this.stopWorkers();
}
@@ -98,38 +100,35 @@ export class JobExecutor
}
}
private async startWorkers(queues: Queue[]) {
const configs =
(await this.runtime.fetchAll(
queues.reduce(
(ret, queue) => {
ret[`job/queues.${queue}.concurrency`] = true;
return ret;
},
{} as {
[key in `job/queues.${Queue}.concurrency`]: true;
}
)
// TODO(@forehalo): fix the override by [payment/service.spec.ts]
)) ?? {};
setConcurrency(queue: Queue, concurrency: number) {
const worker = this.workers.get(queue);
if (!worker) {
throw new Error(`Worker for [${queue}] not found.`);
}
worker.concurrency = concurrency;
}
private async startWorkers(queues: Queue[]) {
for (const queue of queues) {
const concurrency =
(configs[`job/queues.${queue}.concurrency`] as number) ??
this.config.job.worker.concurrency ??
1;
const queueOptions = this.config.job.queues[queue];
const concurrency = queueOptions.concurrency ?? 1;
const worker = new Worker(
queue,
async job => {
await this.run(job.name as JobName, job.data);
},
{
...this.config.job.queue,
...this.config.job.worker,
connection: this.redis,
concurrency,
}
merge(
{},
this.config.job.queue,
this.config.job.worker.defaultWorkerOptions,
queueOptions,
{
concurrency,
connection: this.redis,
}
)
);
worker.on('error', error => {
@@ -140,13 +139,13 @@ export class JobExecutor
`Queue Worker [${queue}] started; concurrency=${concurrency};`
);
this.workers[queue] = worker;
this.workers.set(queue, worker);
}
}
private async stopWorkers() {
await Promise.all(
Object.values(this.workers).map(async worker => {
Array.from(this.workers.values()).map(async worker => {
await worker.close(true);
})
);

View File

@@ -1,33 +1,16 @@
import { defineStartupConfig, ModuleConfig } from '../config';
import { defineModuleConfig } from '../config';
declare module '../config' {
interface AppConfig {
metrics: ModuleConfig<{
/**
* Enable metric and tracing collection
*/
declare global {
interface AppConfigSchema {
metrics: {
enabled: boolean;
/**
* Enable telemetry
*/
telemetry: {
enabled: boolean;
token: string;
};
customerIo: {
token: string;
};
}>;
};
}
}
defineStartupConfig('metrics', {
enabled: false,
telemetry: {
enabled: false,
token: '',
},
customerIo: {
token: '',
defineModuleConfig('metrics', {
enabled: {
desc: 'Enable metric and tracing collection',
default: false,
},
});

View File

@@ -1,54 +1,14 @@
import './config';
import {
Global,
Module,
OnModuleDestroy,
OnModuleInit,
Provider,
} from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { NodeSDK } from '@opentelemetry/sdk-node';
import { Global, Module } from '@nestjs/common';
import { Config } from '../config';
import {
LocalOpentelemetryFactory,
OpentelemetryFactory,
registerCustomMetrics,
} from './opentelemetry';
const factorProvider: Provider = {
provide: OpentelemetryFactory,
useFactory: (config: Config) => {
return config.metrics.enabled ? new LocalOpentelemetryFactory() : null;
},
inject: [Config],
};
import { OpentelemetryFactory } from './opentelemetry';
@Global()
@Module({
providers: [factorProvider],
exports: [factorProvider],
providers: [OpentelemetryFactory],
})
export class MetricsModule implements OnModuleInit, OnModuleDestroy {
private sdk: NodeSDK | null = null;
constructor(private readonly ref: ModuleRef) {}
onModuleInit() {
const factor = this.ref.get(OpentelemetryFactory, { strict: false });
if (factor) {
this.sdk = factor.create();
this.sdk.start();
registerCustomMetrics();
}
}
async onModuleDestroy() {
if (this.sdk) {
await this.sdk.shutdown();
}
}
}
export class MetricsModule {}
export * from './metrics';
export * from './utils';

View File

@@ -2,11 +2,28 @@ import {
Gauge,
Histogram,
Meter,
MeterProvider,
MetricOptions,
metrics as otelMetrics,
UpDownCounter,
} from '@opentelemetry/api';
import { HostMetrics } from '@opentelemetry/host-metrics';
import { getMeter } from './opentelemetry';
function getMeterProvider() {
return otelMetrics.getMeterProvider();
}
export function registerCustomMetrics() {
const hostMetricsMonitoring = new HostMetrics({
name: 'instance-host-metrics',
meterProvider: getMeterProvider() as MeterProvider,
});
hostMetricsMonitoring.start();
}
export function getMeter(name = 'business') {
return getMeterProvider().getMeter(name);
}
type MetricType = 'counter' | 'gauge' | 'histogram';
type Metric<T extends MetricType> = T extends 'counter'
@@ -122,5 +139,3 @@ export const metrics = new Proxy<Record<KnownMetricScopes, ScopedMetrics>>(
},
}
);
export function stopMetrics() {}

View File

@@ -1,5 +1,4 @@
import { OnModuleDestroy } from '@nestjs/common';
import { metrics } from '@opentelemetry/api';
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import {
CompositePropagator,
W3CBaggagePropagator,
@@ -7,7 +6,6 @@ import {
} from '@opentelemetry/core';
import { PrometheusExporter } from '@opentelemetry/exporter-prometheus';
import { ZipkinExporter } from '@opentelemetry/exporter-zipkin';
import { HostMetrics } from '@opentelemetry/host-metrics';
import { Instrumentation } from '@opentelemetry/instrumentation';
import { GraphQLInstrumentation } from '@opentelemetry/instrumentation-graphql';
import { HttpInstrumentation } from '@opentelemetry/instrumentation-http';
@@ -15,7 +13,6 @@ import { IORedisInstrumentation } from '@opentelemetry/instrumentation-ioredis';
import { NestInstrumentation } from '@opentelemetry/instrumentation-nestjs-core';
import { SocketIoInstrumentation } from '@opentelemetry/instrumentation-socket.io';
import { Resource } from '@opentelemetry/resources';
import type { MeterProvider } from '@opentelemetry/sdk-metrics';
import { MetricProducer, MetricReader } from '@opentelemetry/sdk-metrics';
import { NodeSDK } from '@opentelemetry/sdk-node';
import {
@@ -30,11 +27,14 @@ import {
} from '@opentelemetry/semantic-conventions/incubating';
import prismaInstrument from '@prisma/instrumentation';
import { Config } from '../config';
import { OnEvent } from '../event/def';
import { registerCustomMetrics } from './metrics';
import { PrismaMetricProducer } from './prisma';
const { PrismaInstrumentation } = prismaInstrument;
export abstract class OpentelemetryFactory {
export abstract class BaseOpentelemetryFactory {
abstract getMetricReader(): MetricReader;
abstract getSpanExporter(): SpanExporter;
@@ -55,9 +55,9 @@ export abstract class OpentelemetryFactory {
getResource() {
return new Resource({
[ATTR_K8S_NAMESPACE_NAME]: AFFiNE.AFFINE_ENV,
[ATTR_SERVICE_NAME]: AFFiNE.flavor.type,
[ATTR_SERVICE_VERSION]: AFFiNE.version,
[ATTR_K8S_NAMESPACE_NAME]: env.NAMESPACE,
[ATTR_SERVICE_NAME]: env.FLAVOR,
[ATTR_SERVICE_VERSION]: env.version,
});
}
@@ -81,39 +81,58 @@ export abstract class OpentelemetryFactory {
}
}
export class LocalOpentelemetryFactory
extends OpentelemetryFactory
@Injectable()
export class OpentelemetryFactory
extends BaseOpentelemetryFactory
implements OnModuleDestroy
{
private readonly metricsExporter = new PrometheusExporter({
metricProducers: this.getMetricsProducers(),
});
private readonly logger = new Logger(OpentelemetryFactory.name);
#sdk: NodeSDK | null = null;
constructor(private readonly config: Config) {
super();
}
@OnEvent('config.init')
async init(event: Events['config.init']) {
if (event.config.metrics.enabled) {
await this.setup();
registerCustomMetrics();
}
}
@OnEvent('config.changed')
async onConfigChanged(event: Events['config.changed']) {
if ('metrics' in event.updates) {
await this.setup();
}
}
async onModuleDestroy() {
await this.metricsExporter.shutdown();
await this.#sdk?.shutdown();
}
override getMetricReader(): MetricReader {
return this.metricsExporter;
return new PrometheusExporter({
metricProducers: this.getMetricsProducers(),
});
}
override getSpanExporter(): SpanExporter {
return new ZipkinExporter();
}
}
function getMeterProvider() {
return metrics.getMeterProvider();
}
export function registerCustomMetrics() {
const hostMetricsMonitoring = new HostMetrics({
name: 'instance-host-metrics',
meterProvider: getMeterProvider() as MeterProvider,
});
hostMetricsMonitoring.start();
}
export function getMeter(name = 'business') {
return getMeterProvider().getMeter(name);
private async setup() {
if (this.config.metrics.enabled) {
if (!this.#sdk) {
this.#sdk = this.create();
}
this.#sdk.start();
this.logger.log('OpenTelemetry SDK started');
} else {
await this.#sdk?.shutdown();
this.#sdk = null;
this.logger.log('OpenTelemetry SDK stopped');
}
}
}

View File

@@ -10,7 +10,7 @@ import {
ScopeMetrics,
} from '@opentelemetry/sdk-metrics';
import { PrismaService } from '../prisma';
import { PrismaFactory } from '../prisma/factory';
function transformPrismaKey(key: string) {
// replace first '_' to '/' as a scope prefix
@@ -30,11 +30,11 @@ export class PrismaMetricProducer implements MetricProducer {
errors: [],
};
if (!PrismaService.INSTANCE) {
if (!PrismaFactory.INSTANCE) {
return result;
}
const prisma = PrismaService.INSTANCE;
const prisma = PrismaFactory.INSTANCE;
const endTime = hrTime();

View File

@@ -1,39 +0,0 @@
import { defineStartupConfig, ModuleConfig } from '../../base/config';
export interface ServerStartupConfigurations {
/**
* Base url of AFFiNE server, used for generating external urls.
* default to be `[AFFiNE.protocol]://[AFFiNE.host][:AFFiNE.port]/[AFFiNE.path]` if not specified
*/
externalUrl: string;
/**
* Whether the server is hosted on a ssl enabled domain
*/
https: boolean;
/**
* where the server get deployed(FQDN).
*/
host: string;
/**
* which port the server will listen on
*/
port: number;
/**
* subpath where the server get deployed if there is.
*/
path: string;
}
declare module '../../base/config' {
interface AppConfig {
server: ModuleConfig<ServerStartupConfigurations>;
}
}
defineStartupConfig('server', {
externalUrl: '',
https: false,
host: 'localhost',
port: 3010,
path: '',
});

View File

@@ -1,5 +1,3 @@
import './config';
export * from './decorator';
export * from './exception';
export * from './optional-module';
export * from './scanner';

View File

@@ -1,72 +0,0 @@
import {
DynamicModule,
Module,
ModuleMetadata,
Provider,
Type,
} from '@nestjs/common';
import { omit } from 'lodash-es';
import type { AFFiNEConfig, ConfigPaths } from '../config';
export interface OptionalModuleMetadata extends ModuleMetadata {
/**
* Only install module if given config paths are defined in AFFiNE config.
*/
requires?: ConfigPaths[];
/**
* Only install module if the predication returns true.
*/
if?: (config: AFFiNEConfig) => boolean;
/**
* Defines which feature will be enabled if the module installed.
*/
contributesTo?: import('../../core/config').ServerFeature; // avoid circlar dependency
/**
* Defines which providers provided by other modules will be overridden if the module installed.
*/
overrides?: Provider[];
}
const additionalOptions = [
'contributesTo',
'requires',
'if',
'overrides',
] as const satisfies Array<keyof OptionalModuleMetadata>;
type OptionalDynamicModule = DynamicModule & OptionalModuleMetadata;
export function OptionalModule(metadata: OptionalModuleMetadata) {
return (target: Type) => {
additionalOptions.forEach(option => {
if (Object.hasOwn(metadata, option)) {
Reflect.defineMetadata(option, metadata[option], target);
}
});
if (metadata.overrides) {
metadata.providers = (metadata.providers ?? []).concat(
metadata.overrides
);
// eslint-disable-next-line @typescript-eslint/no-floating-promises
metadata.exports = (metadata.exports ?? []).concat(metadata.overrides);
}
const nestMetadata = omit(metadata, additionalOptions);
Module(nestMetadata)(target);
};
}
export function getOptionalModuleMetadata<
T extends keyof OptionalModuleMetadata,
>(target: Type | OptionalDynamicModule, key: T): OptionalModuleMetadata[T] {
if ('module' in target) {
return target[key];
} else {
return Reflect.getMetadata(key, target);
}
}

View File

@@ -1,17 +1,27 @@
import type { Prisma } from '@prisma/client';
import { z } from 'zod';
import { defineStartupConfig, ModuleConfig } from '../config';
import { defineModuleConfig } from '../config';
interface PrismaStartupConfiguration extends Prisma.PrismaClientOptions {
datasourceUrl: string;
}
declare module '../config' {
interface AppConfig {
prisma: ModuleConfig<PrismaStartupConfiguration>;
declare global {
interface AppConfigSchema {
db: {
datasourceUrl: string;
prisma: ConfigItem<Prisma.PrismaClientOptions>;
};
}
}
defineStartupConfig('prisma', {
datasourceUrl: '',
defineModuleConfig('db', {
datasourceUrl: {
desc: 'The datasource url for the prisma client.',
default: 'postgresql://localhost:5432/affine',
env: 'DATABASE_URL',
shape: z.string().url(),
},
prisma: {
desc: 'The config for the prisma client.',
default: {},
link: 'https://www.prisma.io/docs/reference/api-reference/prisma-client-reference',
},
});

View File

@@ -0,0 +1,25 @@
import type { OnModuleDestroy } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { Config } from '../config';
@Injectable()
export class PrismaFactory implements OnModuleDestroy {
static INSTANCE: PrismaClient | null = null;
readonly #instance: PrismaClient;
constructor(config: Config) {
this.#instance = new PrismaClient(config.db.prisma);
PrismaFactory.INSTANCE = this.#instance;
}
get() {
return this.#instance;
}
async onModuleDestroy() {
await PrismaFactory.INSTANCE?.$disconnect();
PrismaFactory.INSTANCE = null;
}
}

View File

@@ -3,29 +3,24 @@ import './config';
import { Global, Module, Provider } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { Config } from '../config';
import { PrismaService } from './service';
import { PrismaFactory } from './factory';
// only `PrismaClient` can be injected
const clientProvider: Provider = {
provide: PrismaClient,
useFactory: (config: Config) => {
if (PrismaService.INSTANCE) {
return PrismaService.INSTANCE;
}
return new PrismaService(config.prisma);
useFactory: (factory: PrismaFactory) => {
return factory.get();
},
inject: [Config],
inject: [PrismaFactory],
};
@Global()
@Module({
providers: [clientProvider],
providers: [PrismaFactory, clientProvider],
exports: [clientProvider],
})
export class PrismaModule {}
export { PrismaService } from './service';
export { PrismaFactory };
export type PrismaTransaction = Parameters<
Parameters<PrismaClient['$transaction']>[0]

View File

@@ -1,27 +0,0 @@
import type { OnApplicationShutdown, OnModuleInit } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { Prisma, PrismaClient } from '@prisma/client';
@Injectable()
export class PrismaService
extends PrismaClient
implements OnModuleInit, OnApplicationShutdown
{
static INSTANCE: PrismaService | null = null;
constructor(opts: Prisma.PrismaClientOptions) {
super(opts);
PrismaService.INSTANCE = this;
}
async onModuleInit() {
await this.$connect();
}
async onApplicationShutdown(): Promise<void> {
if (!AFFiNE.node.test) {
await this.$disconnect();
PrismaService.INSTANCE = null;
}
}
}

View File

@@ -1,11 +1,54 @@
import { RedisOptions } from 'ioredis';
import { z } from 'zod';
import { defineStartupConfig, ModuleConfig } from '../../base/config';
import { defineModuleConfig } from '../config';
declare module '../config' {
interface AppConfig {
redis: ModuleConfig<RedisOptions>;
declare global {
interface AppConfigSchema {
redis: {
host: string;
port: number;
db: number;
username: string;
password: string;
ioredis: ConfigItem<
Omit<RedisOptions, 'host' | 'port' | 'db' | 'username' | 'password'>
>;
};
}
}
defineStartupConfig('redis', {});
defineModuleConfig('redis', {
db: {
desc: 'The database index of redis server to be used(Must be less than 10).',
default: 0,
env: ['REDIS_DATABASE', 'integer'],
validate: val => val >= 0 && val < 10,
},
host: {
desc: 'The host of the redis server.',
default: 'localhost',
env: ['REDIS_HOST', 'string'],
},
port: {
desc: 'The port of the redis server.',
default: 6379,
env: ['REDIS_PORT', 'integer'],
shape: z.number().positive(),
},
username: {
desc: 'The username of the redis server.',
default: '',
env: ['REDIS_USERNAME', 'string'],
},
password: {
desc: 'The password of the redis server.',
default: '',
env: ['REDIS_PASSWORD', 'string'],
},
ioredis: {
desc: 'The config for the ioredis client.',
default: {},
link: 'https://github.com/luin/ioredis',
},
});

View File

@@ -6,13 +6,10 @@ import {
} from '@nestjs/common';
import { Redis as IORedis, RedisOptions } from 'ioredis';
import { Config } from '../../base/config';
import { Config } from '../config';
class Redis extends IORedis implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(this.constructor.name);
constructor(opts: RedisOptions) {
super(opts);
}
errorHandler = (err: Error) => {
this.logger.error(err);
@@ -46,21 +43,29 @@ class Redis extends IORedis implements OnModuleInit, OnModuleDestroy {
@Injectable()
export class CacheRedis extends Redis {
constructor(config: Config) {
super(config.redis);
super({ ...config.redis, ...config.redis.ioredis });
}
}
@Injectable()
export class SessionRedis extends Redis {
constructor(config: Config) {
super({ ...config.redis, db: (config.redis.db ?? 0) + 2 });
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 2,
});
}
}
@Injectable()
export class SocketIoRedis extends Redis {
constructor(config: Config) {
super({ ...config.redis, db: (config.redis.db ?? 0) + 3 });
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 3,
});
}
}
@@ -69,6 +74,7 @@ export class QueueRedis extends Redis {
constructor(config: Config) {
super({
...config.redis,
...config.redis.ioredis,
db: (config.redis.db ?? 0) + 4,
// required explicitly set to `null` by bullmq
maxRetriesPerRequest: null,

View File

@@ -1,7 +0,0 @@
import { FlattenedAppRuntimeConfig } from '../config/types';
declare global {
interface Events {
'runtime.changed__NOT_IMPLEMENTED__': Partial<FlattenedAppRuntimeConfig>;
}
}

View File

@@ -1,11 +0,0 @@
import { Global, Module } from '@nestjs/common';
import { Runtime } from './service';
@Global()
@Module({
providers: [Runtime],
exports: [Runtime],
})
export class RuntimeModule {}
export { Runtime };

View File

@@ -1,258 +0,0 @@
import {
forwardRef,
Inject,
Injectable,
Logger,
OnModuleInit,
} from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { difference, keyBy } from 'lodash-es';
import { Cache } from '../cache';
import { defaultRuntimeConfig, runtimeConfigType } from '../config/register';
import {
AppRuntimeConfigModules,
FlattenedAppRuntimeConfig,
} from '../config/types';
import { InvalidRuntimeConfigType, RuntimeConfigNotFound } from '../error';
import { defer } from '../utils/promise';
function validateConfigType<K extends keyof FlattenedAppRuntimeConfig>(
key: K,
value: any
) {
const config = defaultRuntimeConfig[key];
if (!config) {
throw new RuntimeConfigNotFound({ key });
}
const want = config.type;
const get = runtimeConfigType(value);
if (get !== want) {
throw new InvalidRuntimeConfigType({
key,
want,
get,
});
}
}
/**
* runtime.fetch(k) // v1
* runtime.fetchAll(k1, k2, k3) // [v1, v2, v3]
* runtime.set(k, v)
* runtime.update(k, (v) => {
* v.xxx = 'yyy';
* return v
* })
*/
@Injectable()
export class Runtime implements OnModuleInit {
private readonly logger = new Logger('App:RuntimeConfig');
constructor(
private readonly db: PrismaClient,
// circular deps: runtime => cache => redis(maybe) => config => runtime
@Inject(forwardRef(() => Cache)) private readonly cache: Cache
) {}
async onModuleInit() {
await this.upgradeDB();
}
async fetch<K extends keyof FlattenedAppRuntimeConfig>(
k: K
): Promise<FlattenedAppRuntimeConfig[K]> {
const cached = await this.loadCache<K>(k);
if (cached !== undefined) {
return cached;
}
const dbValue = await this.loadDb<K>(k);
if (dbValue === undefined) {
throw new RuntimeConfigNotFound({ key: k });
}
await this.setCache(k, dbValue);
return dbValue;
}
async fetchAll<
Selector extends { [Key in keyof FlattenedAppRuntimeConfig]?: true },
>(
selector: Selector
): Promise<{
// @ts-expect-error allow
[Key in keyof Selector]: FlattenedAppRuntimeConfig[Key];
}> {
const keys = Object.keys(selector);
if (keys.length === 0) {
return {} as any;
}
const records = await this.db.runtimeConfig.findMany({
select: {
id: true,
value: true,
},
where: {
id: {
in: keys,
},
deletedAt: null,
},
});
const keyed = keyBy(records, 'id');
return keys.reduce((ret, key) => {
ret[key] = keyed[key]?.value ?? defaultRuntimeConfig[key].value;
return ret;
}, {} as any);
}
async list(module?: AppRuntimeConfigModules) {
return await this.db.runtimeConfig.findMany({
where: module ? { module, deletedAt: null } : { deletedAt: null },
});
}
async set<
K extends keyof FlattenedAppRuntimeConfig,
V = FlattenedAppRuntimeConfig[K],
>(key: K, value: V) {
validateConfigType(key, value);
const config = await this.db.runtimeConfig.upsert({
where: {
id: key,
deletedAt: null,
},
create: {
...defaultRuntimeConfig[key],
value: value as any,
},
update: {
value: value as any,
deletedAt: null,
},
});
await this.setCache(key, config.value as FlattenedAppRuntimeConfig[K]);
return config;
}
async update<
K extends keyof FlattenedAppRuntimeConfig,
V = FlattenedAppRuntimeConfig[K],
>(k: K, modifier: (v: V) => V | Promise<V>) {
const data = await this.fetch<K>(k);
const updated = await modifier(data as V);
await this.set(k, updated);
return updated;
}
async loadDb<K extends keyof FlattenedAppRuntimeConfig>(
k: K
): Promise<FlattenedAppRuntimeConfig[K] | undefined> {
const v = await this.db.runtimeConfig.findFirst({
where: {
id: k,
deletedAt: null,
},
});
if (v) {
return v.value as FlattenedAppRuntimeConfig[K];
} else {
const record = await this.db.runtimeConfig.create({
data: defaultRuntimeConfig[k],
});
return record.value as any;
}
}
async loadCache<K extends keyof FlattenedAppRuntimeConfig>(
k: K
): Promise<FlattenedAppRuntimeConfig[K] | undefined> {
return this.cache.get<FlattenedAppRuntimeConfig[K]>(`SERVER_RUNTIME:${k}`);
}
async setCache<K extends keyof FlattenedAppRuntimeConfig>(
k: K,
v: FlattenedAppRuntimeConfig[K]
): Promise<boolean> {
return this.cache.set<FlattenedAppRuntimeConfig[K]>(
`SERVER_RUNTIME:${k}`,
v,
{ ttl: 60 * 1000 }
);
}
/**
* Upgrade the DB with latest runtime configs
*/
private async upgradeDB() {
const existingConfig = await this.db.runtimeConfig.findMany({
select: {
id: true,
},
where: {
deletedAt: null,
},
});
const defined = Object.keys(defaultRuntimeConfig);
const existing = existingConfig.map(c => c.id);
const newConfigs = difference(defined, existing);
const deleteConfigs = difference(existing, defined);
if (!newConfigs.length && !deleteConfigs.length) {
return;
}
this.logger.log(`Found runtime config changes, upgrading...`);
const acquired = await this.cache.setnx('runtime:upgrade', 1, {
ttl: 10 * 60 * 1000,
});
await using _ = defer(async () => {
await this.cache.delete('runtime:upgrade');
});
if (acquired) {
for (const key of newConfigs) {
await this.db.runtimeConfig.upsert({
create: defaultRuntimeConfig[key],
// old deleted setting should be restored
update: {
...defaultRuntimeConfig[key],
deletedAt: null,
},
where: {
id: key,
},
});
}
await this.db.runtimeConfig.updateMany({
where: {
id: {
in: deleteConfigs,
},
},
data: {
deletedAt: new Date(),
},
});
}
this.logger.log('Upgrade completed');
}
}

View File

@@ -1,75 +0,0 @@
import { homedir } from 'node:os';
import { join } from 'node:path';
import { defineStartupConfig, ModuleConfig } from '../config';
export interface FsStorageConfig {
path: string;
}
export interface StorageProvidersConfig {
fs?: FsStorageConfig;
}
declare module '../config' {
interface AppConfig {
storageProviders: ModuleConfig<StorageProvidersConfig>;
}
}
defineStartupConfig('storageProviders', {
fs: {
path: join(homedir(), '.affine/storage'),
},
});
export type StorageProviderType = keyof StorageProvidersConfig;
export type StorageConfig<Ext = unknown> = {
provider: StorageProviderType;
bucket: string;
} & Ext;
export interface StoragesConfig {
avatar: StorageConfig<{ publicLinkFactory: (key: string) => string }>;
blob: StorageConfig;
copilot: StorageConfig;
}
export interface AFFiNEStorageConfig {
/**
* All providers for object storage
*
* Support different providers for different usage at the same time.
*/
providers: StorageProvidersConfig;
storages: StoragesConfig;
}
export type StorageProviders = AFFiNEStorageConfig['providers'];
export type Storages = keyof AFFiNEStorageConfig['storages'];
export function getDefaultAFFiNEStorageConfig(): AFFiNEStorageConfig {
return {
providers: {
fs: {
path: join(homedir(), '.affine/storage'),
},
},
storages: {
avatar: {
provider: 'fs',
bucket: 'avatars',
publicLinkFactory: key => `/api/avatars/${key}`,
},
blob: {
provider: 'fs',
bucket: 'blobs',
},
copilot: {
provider: 'fs',
bucket: 'copilot',
},
},
};
}

View File

@@ -0,0 +1,20 @@
import { Injectable } from '@nestjs/common';
import {
StorageProvider,
StorageProviderConfig,
StorageProviders,
} from './providers';
@Injectable()
export class StorageProviderFactory {
create(config: StorageProviderConfig): StorageProvider {
const Provider = StorageProviders[config.provider];
if (!Provider) {
throw new Error(`Unknown storage provider type: ${config.provider}`);
}
return new Provider(config.config, config.bucket);
}
}

View File

@@ -1,17 +1,6 @@
import './config';
import { Global, Module } from '@nestjs/common';
import { registerStorageProvider, StorageProviderFactory } from './providers';
import { FsStorageProvider } from './providers/fs';
registerStorageProvider('fs', (config, bucket) => {
if (!config.storageProviders.fs) {
throw new Error('Missing fs storage provider configuration');
}
return new FsStorageProvider(config.storageProviders.fs, bucket);
});
import { StorageProviderFactory } from './factory';
@Global()
@Module({
@@ -19,16 +8,5 @@ registerStorageProvider('fs', (config, bucket) => {
exports: [StorageProviderFactory],
})
export class StorageProviderModule {}
export * from '../../native';
export type { StorageProviderType } from './config';
export type {
BlobInputType,
BlobOutputType,
GetObjectMetadata,
ListObjectsMetadata,
PutObjectMetadata,
StorageProvider,
} from './providers';
export { registerStorageProvider, StorageProviderFactory } from './providers';
export { autoMetadata, toBuffer } from './providers/utils';
export { StorageProviderFactory } from './factory';
export * from './providers';

View File

@@ -10,12 +10,12 @@ import {
statSync,
writeFileSync,
} from 'node:fs';
import { join, parse, resolve } from 'node:path';
import { homedir } from 'node:os';
import { join, parse } from 'node:path';
import { Readable } from 'node:stream';
import { Logger } from '@nestjs/common';
import { FsStorageConfig } from '../config';
import {
BlobInputType,
GetObjectMetadata,
@@ -30,6 +30,10 @@ function escapeKey(key: string): string {
return key.replace(/\.?\.[/\\]/g, '%');
}
export interface FsStorageConfig {
path: string;
}
export class FsStorageProvider implements StorageProvider {
private readonly path: string;
private readonly logger: Logger;
@@ -40,7 +44,9 @@ export class FsStorageProvider implements StorageProvider {
config: FsStorageConfig,
public readonly bucket: string
) {
this.path = resolve(config.path, bucket);
this.path = config.path.startsWith('~/')
? join(homedir(), config.path.slice(2), bucket)
: join(config.path, bucket);
this.ensureAvailability();
this.logger = new Logger(`${FsStorageProvider.name}:${bucket}`);

View File

@@ -1,34 +1,116 @@
import { Injectable } from '@nestjs/common';
import { Type } from '@nestjs/common';
import { Config } from '../../config';
import { StorageConfig, StorageProviderType } from '../config';
import type { StorageProvider } from './provider';
import { JSONSchema } from '../../config';
import { FsStorageConfig, FsStorageProvider } from './fs';
import { StorageProvider } from './provider';
import { R2StorageConfig, R2StorageProvider } from './r2';
import { S3StorageConfig, S3StorageProvider } from './s3';
const availableProviders = new Map<
StorageProviderType,
(config: Config, bucket: string) => StorageProvider
>();
export type StorageProviderName = 'fs' | 'aws-s3' | 'cloudflare-r2';
export const StorageProviders: Record<
StorageProviderName,
Type<StorageProvider>
> = {
fs: FsStorageProvider,
'aws-s3': S3StorageProvider,
'cloudflare-r2': R2StorageProvider,
};
export function registerStorageProvider(
type: StorageProviderType,
providerFactory: (config: Config, bucket: string) => StorageProvider
) {
availableProviders.set(type, providerFactory);
}
@Injectable()
export class StorageProviderFactory {
constructor(private readonly config: Config) {}
create(storage: StorageConfig): StorageProvider {
const providerFactory = availableProviders.get(storage.provider);
if (!providerFactory) {
throw new Error(`Unknown storage provider type: ${storage.provider}`);
export type StorageProviderConfig = { bucket: string } & (
| {
provider: 'fs';
config: FsStorageConfig;
}
| {
provider: 'aws-s3';
config: S3StorageConfig;
}
| {
provider: 'cloudflare-r2';
config: R2StorageConfig;
}
);
return providerFactory(this.config, storage.bucket);
}
}
const S3ConfigSchema: JSONSchema = {
type: 'object',
description:
'The config for the s3 compatible storage provider. directly passed to aws-sdk client.\n@link https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html',
properties: {
credentials: {
type: 'object',
description: 'The credentials for the s3 compatible storage provider.',
properties: {
accessKeyId: {
type: 'string',
},
secretAccessKey: {
type: 'string',
},
},
},
},
};
export const StorageJSONSchema: JSONSchema = {
oneOf: [
{
type: 'object',
properties: {
provider: {
type: 'string',
enum: ['fs'],
},
bucket: {
type: 'string',
},
config: {
type: 'object',
properties: {
path: {
type: 'string',
},
},
},
},
},
{
type: 'object',
properties: {
provider: {
type: 'string',
enum: ['aws-s3'],
},
bucket: {
type: 'string',
},
config: S3ConfigSchema,
},
},
{
type: 'object',
properties: {
provider: {
type: 'string',
enum: ['cloudflare-r2'],
},
bucket: {
type: 'string',
},
config: {
...S3ConfigSchema,
properties: {
...S3ConfigSchema.properties,
accountId: {
type: 'string' as const,
description:
'The account id for the cloudflare r2 storage provider.',
},
},
},
},
},
],
};
export type * from './provider';
export { autoMetadata, toBuffer } from './utils';

View File

@@ -1,7 +1,5 @@
import type { Readable } from 'node:stream';
import { StorageProviderType } from '../config';
export interface GetObjectMetadata {
/**
* @default 'application/octet-stream'
@@ -28,7 +26,6 @@ export type BlobInputType = Buffer | Readable | string;
export type BlobOutputType = Readable;
export interface StorageProvider {
readonly type: StorageProviderType;
put(
key: string,
body: BlobInputType,

View File

@@ -2,12 +2,13 @@ import assert from 'node:assert';
import { Logger } from '@nestjs/common';
import type { R2StorageConfig } from '../config';
import { S3StorageProvider } from './s3';
import { S3StorageConfig, S3StorageProvider } from './s3';
export interface R2StorageConfig extends S3StorageConfig {
accountId: string;
}
export class R2StorageProvider extends S3StorageProvider {
override readonly type = 'cloudflare-r2' as any /* cast 'r2' to 's3' */;
constructor(config: R2StorageConfig, bucket: string) {
assert(config.accountId, 'accountId is required for R2 storage provider');
super(
@@ -15,6 +16,9 @@ export class R2StorageProvider extends S3StorageProvider {
...config,
forcePathStyle: true,
endpoint: `https://${config.accountId}.r2.cloudflarestorage.com`,
// see https://github.com/aws/aws-sdk-js-v3/issues/6810
requestChecksumCalculation: 'WHEN_REQUIRED',
responseChecksumValidation: 'WHEN_REQUIRED',
},
bucket
);

View File

@@ -9,26 +9,25 @@ import {
NoSuchKey,
PutObjectCommand,
S3Client,
S3ClientConfig,
} from '@aws-sdk/client-s3';
import { Logger } from '@nestjs/common';
import {
autoMetadata,
BlobInputType,
GetObjectMetadata,
ListObjectsMetadata,
PutObjectMetadata,
StorageProvider,
toBuffer,
} from '../../../base/storage';
import type { S3StorageConfig } from '../config';
} from './provider';
import { autoMetadata, toBuffer } from './utils';
export type S3StorageConfig = S3ClientConfig;
export class S3StorageProvider implements StorageProvider {
protected logger: Logger;
protected client: S3Client;
readonly type = 'aws-s3';
constructor(
config: S3StorageConfig,
public readonly bucket: string

View File

@@ -1,27 +1,38 @@
import { defineStartupConfig, ModuleConfig } from '../config';
import { defineModuleConfig } from '../config';
export type ThrottlerType = 'default' | 'strict';
type ThrottlerStartupConfigurations = {
[key in ThrottlerType]: {
ttl: number;
limit: number;
};
};
declare module '../config' {
interface AppConfig {
throttler: ModuleConfig<ThrottlerStartupConfigurations>;
declare global {
interface AppConfigSchema {
throttle: {
enabled: boolean;
throttlers: {
[key in ThrottlerType]: ConfigItem<{
ttl: number;
limit: number;
}>;
};
};
}
}
defineStartupConfig('throttler', {
default: {
ttl: 60,
limit: 120,
defineModuleConfig('throttle', {
enabled: {
desc: 'Whether the throttler is enabled.',
default: true,
},
strict: {
ttl: 60,
limit: 20,
'throttlers.default': {
desc: 'The config for the default throttler.',
default: {
ttl: 60,
limit: 120,
},
},
'throttlers.strict': {
desc: 'The config for the strict throttler.',
default: {
ttl: 60,
limit: 20,
},
},
});

View File

@@ -24,14 +24,19 @@ export class ThrottlerStorage extends ThrottlerStorageService {}
@Injectable()
class CustomOptionsFactory implements ThrottlerOptionsFactory {
constructor(private readonly storage: ThrottlerStorage) {}
constructor(
private readonly config: Config,
private readonly storage: ThrottlerStorage
) {}
createThrottlerOptions() {
const options: ThrottlerModuleOptions = {
throttlers: Object.entries(AFFiNE.throttler).map(([name, config]) => ({
name,
...config,
})),
throttlers: Object.entries(this.config.throttle.throttlers).map(
([name, config]) => ({
name,
...config,
})
),
storage: this.storage,
};
@@ -84,6 +89,7 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
ttl,
blockDuration,
} = request;
let limit = request.limit;
// give it 'default' if no throttler is specified,
@@ -110,13 +116,9 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
let tracker = await this.getTracker(req);
if (this.config.node.dev) {
limit = Number.MAX_SAFE_INTEGER;
} else {
// custom limit or ttl APIs will be treated standalone
if (limit !== throttlerOptions.limit || ttl !== throttlerOptions.ttl) {
tracker += ';custom';
}
// custom limit or ttl APIs will be treated standalone
if (limit !== throttlerOptions.limit || ttl !== throttlerOptions.ttl) {
tracker += ';custom';
}
const key = this.generateKey(
@@ -151,6 +153,10 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
}
override async canActivate(context: ExecutionContext): Promise<boolean> {
if (!this.config.throttle.enabled) {
return true;
}
const { req } = this.getRequestResponse(context);
const throttler = this.getSpecifiedThrottler(context);

View File

@@ -94,7 +94,7 @@ export function parseCookies(
export type RequestType = GqlContextType | 'event' | 'job';
export function genRequestId(type: RequestType) {
return `${AFFiNE.flavor.type}:${type}:${randomUUID()}`;
return `${env.DEPLOYMENT_TYPE}:${type}:${randomUUID()}`;
}
export function getOrGenRequestId(type: RequestType) {

View File

@@ -2,9 +2,7 @@ import { Readable } from 'node:stream';
export function ApplyType<T>(): ConstructorOf<T> {
// @ts-expect-error used to fake the type of config
return class Inner implements T {
constructor() {}
};
return class Inner implements T {};
}
export type PathType<T, Path extends string> =
@@ -30,7 +28,7 @@ export type Join<Prefix, Suffixes> = Prefix extends string | number
export type LeafPaths<
T,
Path extends string = '',
Prefix extends string = '',
MaxDepth extends string = '.....',
Depth extends string = '',
> = Depth extends MaxDepth
@@ -40,7 +38,9 @@ export type LeafPaths<
[K in keyof T]-?: K extends string | number
? T[K] extends PrimitiveType
? K
: Join<K, LeafPaths<T[K], Path, MaxDepth, `${Depth}.`>>
: T[K] extends { __leaf: true }
? K
: Join<K, LeafPaths<T[K], Prefix, MaxDepth, `${Depth}.`>>
: never;
}[keyof T]
: never;

View File

@@ -1,7 +1,7 @@
import { INestApplication } from '@nestjs/common';
import { IoAdapter } from '@nestjs/platform-socket.io';
import { createAdapter } from '@socket.io/redis-adapter';
import { Server } from 'socket.io';
import { Server, Socket } from 'socket.io';
import { Config } from '../config';
import { AuthenticationRequired } from '../error';
@@ -14,7 +14,9 @@ export class SocketIoAdapter extends IoAdapter {
}
override createIOServer(port: number, options?: any): Server {
const config = this.app.get(WEBSOCKET_OPTIONS) as Config['websocket'];
const config = this.app.get(WEBSOCKET_OPTIONS) as Config['websocket'] & {
canActivate: (socket: Socket) => Promise<boolean>;
};
const server: Server = super.createIOServer(port, {
...config,
...options,
@@ -22,7 +24,6 @@ export class SocketIoAdapter extends IoAdapter {
if (config.canActivate) {
server.use((socket, next) => {
// @ts-expect-error checked
config
.canActivate(socket)
.then(pass => {

View File

@@ -1,20 +1,34 @@
import { GatewayMetadata } from '@nestjs/websockets';
import { Socket } from 'socket.io';
import { z } from 'zod';
import { defineStartupConfig, ModuleConfig } from '../config';
import { defineModuleConfig } from '../config';
declare module '../config' {
interface AppConfig {
websocket: ModuleConfig<
GatewayMetadata & {
canActivate?: (socket: Socket) => Promise<boolean>;
}
>;
declare global {
interface AppConfigSchema {
websocket: {
transports: ConfigItem<GatewayMetadata['transports']>;
maxHttpBufferSize: number;
};
}
}
defineStartupConfig('websocket', {
transports: ['websocket', 'polling'],
// see: https://socket.io/docs/v4/server-options/#maxhttpbuffersize
maxHttpBufferSize: 1e8, // 100 MB
defineModuleConfig('websocket', {
transports: {
desc: 'The enabled transports for accepting websocket traffics.',
default: ['websocket', 'polling'],
shape: z.array(z.enum(['websocket', 'polling'])),
schema: {
type: 'array',
items: {
type: 'string',
enum: ['websocket', 'polling'],
},
},
link: 'https://docs.nestjs.com/websockets/gateways#transports',
},
maxHttpBufferSize: {
desc: 'How many bytes or characters a message can be, before closing the session (to avoid DoS).',
default: 1e8, // 100 MB
shape: z.number().int().positive(),
},
});

View File

@@ -1,42 +0,0 @@
// Convenient way to map environment variables to config values.
AFFiNE.ENV_MAP = {
AFFINE_SERVER_EXTERNAL_URL: ['server.externalUrl'],
AFFINE_SERVER_PORT: ['server.port', 'int'],
AFFINE_SERVER_HOST: 'server.host',
AFFINE_SERVER_SUB_PATH: 'server.path',
AFFINE_SERVER_HTTPS: ['server.https', 'boolean'],
ENABLE_TELEMETRY: ['metrics.telemetry.enabled', 'boolean'],
MAILER_HOST: 'mailer.host',
MAILER_PORT: ['mailer.port', 'int'],
MAILER_USER: 'mailer.auth.user',
MAILER_PASSWORD: 'mailer.auth.pass',
MAILER_SENDER: 'mailer.from.address',
MAILER_SECURE: ['mailer.secure', 'boolean'],
DATABASE_URL: 'prisma.datasourceUrl',
OAUTH_GOOGLE_CLIENT_ID: 'plugins.oauth.providers.google.clientId',
OAUTH_GOOGLE_CLIENT_SECRET: 'plugins.oauth.providers.google.clientSecret',
OAUTH_GITHUB_CLIENT_ID: 'plugins.oauth.providers.github.clientId',
OAUTH_GITHUB_CLIENT_SECRET: 'plugins.oauth.providers.github.clientSecret',
OAUTH_OIDC_ISSUER: 'plugins.oauth.providers.oidc.issuer',
OAUTH_OIDC_CLIENT_ID: 'plugins.oauth.providers.oidc.clientId',
OAUTH_OIDC_CLIENT_SECRET: 'plugins.oauth.providers.oidc.clientSecret',
OAUTH_OIDC_SCOPE: 'plugins.oauth.providers.oidc.args.scope',
OAUTH_OIDC_CLAIM_MAP_USERNAME: 'plugins.oauth.providers.oidc.args.claim_id',
OAUTH_OIDC_CLAIM_MAP_EMAIL: 'plugins.oauth.providers.oidc.args.claim_email',
OAUTH_OIDC_CLAIM_MAP_NAME: 'plugins.oauth.providers.oidc.args.claim_name',
METRICS_CUSTOMER_IO_TOKEN: ['metrics.customerIo.token', 'string'],
CAPTCHA_TURNSTILE_SECRET: ['plugins.captcha.turnstile.secret', 'string'],
COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey',
COPILOT_FAL_API_KEY: 'plugins.copilot.fal.apiKey',
COPILOT_GOOGLE_API_KEY: 'plugins.copilot.google.apiKey',
COPILOT_PERPLEXITY_API_KEY: 'plugins.copilot.perplexity.apiKey',
COPILOT_UNSPLASH_API_KEY: 'plugins.copilot.unsplashKey',
REDIS_SERVER_HOST: 'redis.host',
REDIS_SERVER_PORT: ['redis.port', 'int'],
REDIS_SERVER_USER: 'redis.username',
REDIS_SERVER_PASSWORD: 'redis.password',
REDIS_SERVER_DATABASE: ['redis.db', 'int'],
DOC_SERVICE_ENDPOINT: 'docService.endpoint',
STRIPE_API_KEY: 'plugins.payment.stripe.keys.APIKey',
STRIPE_WEBHOOK_KEY: 'plugins.payment.stripe.keys.webhookKey',
};

View File

@@ -1,95 +0,0 @@
/* oxlint-disable @typescript-eslint/no-non-null-assertion */
// Custom configurations for AFFiNE Cloud
// ====================================================================================
// Q: WHY THIS FILE EXISTS?
// A: AFFiNE deployment environment may have a lot of custom environment variables,
// which are not suitable to be put in the `affine.ts` file.
// For example, AFFiNE Cloud Clusters are deployed on Google Cloud Platform.
// We need to enable the `gcloud` plugin to make sure the nodes working well,
// but the default selfhost version may not require it.
// So it's not a good idea to put such logic in the common `affine.ts` file.
//
// ```
// if (AFFiNE.deploy) {
// AFFiNE.plugins.use('gcloud');
// }
// ```
// ====================================================================================
const env = process.env;
AFFiNE.serverName = AFFiNE.affine.canary
? 'AFFiNE Canary Cloud'
: AFFiNE.affine.beta
? 'AFFiNE Beta Cloud'
: 'AFFiNE Cloud';
AFFiNE.metrics.enabled = !AFFiNE.node.test;
if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {
AFFiNE.use('cloudflare-r2', {
accountId: env.R2_OBJECT_STORAGE_ACCOUNT_ID,
credentials: {
accessKeyId: env.R2_OBJECT_STORAGE_ACCESS_KEY_ID!,
secretAccessKey: env.R2_OBJECT_STORAGE_SECRET_ACCESS_KEY!,
},
requestChecksumCalculation: 'WHEN_REQUIRED',
responseChecksumValidation: 'WHEN_REQUIRED',
});
AFFiNE.storages.avatar.provider = 'cloudflare-r2';
AFFiNE.storages.avatar.bucket = 'account-avatar';
AFFiNE.storages.avatar.publicLinkFactory = key =>
`https://avatar.affineassets.com/${key}`;
AFFiNE.storages.blob.provider = 'cloudflare-r2';
AFFiNE.storages.blob.bucket = `workspace-blobs-${
AFFiNE.affine.canary ? 'canary' : 'prod'
}`;
AFFiNE.use('copilot', {
storage: {
provider: 'cloudflare-r2',
bucket: `workspace-copilot-${AFFiNE.affine.canary ? 'canary' : 'prod'}`,
},
});
}
AFFiNE.use('copilot', {
openai: {
apiKey: '',
},
fal: {
apiKey: '',
},
});
AFFiNE.use('payment', {
stripe: {
keys: {
// fake the key to ensure the server generate full GraphQL Schema even env vars are not set
APIKey: '1',
webhookKey: '1',
},
},
});
AFFiNE.use('oauth');
/* Captcha Plugin Default Config */
AFFiNE.use('captcha', {
turnstile: {},
challenge: {
bits: 20,
},
});
if (AFFiNE.deploy) {
AFFiNE.mailer = {
service: 'gmail',
auth: {
user: env.MAILER_USER,
pass: env.MAILER_PASSWORD,
},
};
AFFiNE.use('gcloud');
} else {
// only enable dev mode
AFFiNE.use('worker');
}

Some files were not shown because too many files have changed in this diff Show More