diff --git a/packages/backend/server/src/__tests__/models/feature.spec.ts b/packages/backend/server/src/__tests__/models/feature.spec.ts index 289c1b33ab..10a8e61401 100644 --- a/packages/backend/server/src/__tests__/models/feature.spec.ts +++ b/packages/backend/server/src/__tests__/models/feature.spec.ts @@ -13,9 +13,7 @@ interface Context { const test = ava as TestFn; test.before(async t => { - const module = await createTestingModule({ - providers: [FeatureModel], - }); + const module = await createTestingModule({}); t.context.feature = module.get(FeatureModel); t.context.module = module; diff --git a/packages/backend/server/src/__tests__/models/session.spec.ts b/packages/backend/server/src/__tests__/models/session.spec.ts index baca5ab852..bc91a58b07 100644 --- a/packages/backend/server/src/__tests__/models/session.spec.ts +++ b/packages/backend/server/src/__tests__/models/session.spec.ts @@ -18,9 +18,7 @@ interface Context { const test = ava as TestFn; test.before(async t => { - const module = await createTestingModule({ - providers: [SessionModel], - }); + const module = await createTestingModule({}); t.context.session = module.get(SessionModel); t.context.user = module.get(UserModel); diff --git a/packages/backend/server/src/__tests__/models/user.spec.ts b/packages/backend/server/src/__tests__/models/user.spec.ts index aacf4e7756..5d4730ed4b 100644 --- a/packages/backend/server/src/__tests__/models/user.spec.ts +++ b/packages/backend/server/src/__tests__/models/user.spec.ts @@ -17,9 +17,7 @@ interface Context { const test = ava as TestFn; test.before(async t => { - const module = await createTestingModule({ - providers: [UserModel], - }); + const module = await createTestingModule({}); t.context.user = module.get(UserModel); t.context.module = module; diff --git a/packages/backend/server/src/__tests__/models/verification-token.spec.ts b/packages/backend/server/src/__tests__/models/verification-token.spec.ts index 03680abef1..30814dc06d 100644 --- a/packages/backend/server/src/__tests__/models/verification-token.spec.ts +++ b/packages/backend/server/src/__tests__/models/verification-token.spec.ts @@ -17,9 +17,7 @@ interface Context { const test = ava as TestFn; test.before(async t => { - const module = await createTestingModule({ - providers: [VerificationTokenModel], - }); + const module = await createTestingModule({}); t.context.verificationToken = module.get(VerificationTokenModel); t.context.db = module.get(PrismaClient); diff --git a/packages/backend/server/src/base/index.ts b/packages/backend/server/src/base/index.ts index 368d457b28..af1b7f2335 100644 --- a/packages/backend/server/src/base/index.ts +++ b/packages/backend/server/src/base/index.ts @@ -41,4 +41,4 @@ export { getRequestResponseFromHost, parseCookies, } from './utils/request'; -export type * from './utils/types'; +export * from './utils/types'; diff --git a/packages/backend/server/src/models/base.ts b/packages/backend/server/src/models/base.ts new file mode 100644 index 0000000000..2b0645cd29 --- /dev/null +++ b/packages/backend/server/src/models/base.ts @@ -0,0 +1,19 @@ +import { Inject, Logger } from '@nestjs/common'; +import { PrismaClient } from '@prisma/client'; + +import { Config } from '../base'; +import type { Models } from '.'; +import { MODELS_SYMBOL } from './provider'; + +export class BaseModel { + protected readonly logger = new Logger(this.constructor.name); + + @Inject(MODELS_SYMBOL) + protected readonly models!: Models; + + @Inject(Config) + protected readonly config!: Config; + + @Inject(PrismaClient) + protected readonly db!: PrismaClient; +} diff --git a/packages/backend/server/src/models/feature.ts b/packages/backend/server/src/models/feature.ts index c32b713177..1e3d3097a7 100644 --- a/packages/backend/server/src/models/feature.ts +++ b/packages/backend/server/src/models/feature.ts @@ -1,8 +1,9 @@ -import { Injectable, Logger } from '@nestjs/common'; -import { Feature, PrismaClient } from '@prisma/client'; +import { Injectable } from '@nestjs/common'; +import { Feature } from '@prisma/client'; import { z } from 'zod'; import { PrismaTransaction } from '../base'; +import { BaseModel } from './base'; import { Features, FeatureType } from './common'; type FeatureNames = keyof typeof Features; @@ -17,11 +18,7 @@ type FeatureConfigs = z.infer< // We have to manually update all the users and workspaces binding to the latest version, which are thousands of handreds. // This is a huge burden for us and we should remove it. @Injectable() -export class FeatureModel { - private readonly logger = new Logger(FeatureModel.name); - - constructor(private readonly db: PrismaClient) {} - +export class FeatureModel extends BaseModel { async get(name: T) { const feature = await this.getLatest(this.db, name); diff --git a/packages/backend/server/src/models/index.ts b/packages/backend/server/src/models/index.ts index 37ba5f8e6d..8da29e9681 100644 --- a/packages/backend/server/src/models/index.ts +++ b/packages/backend/server/src/models/index.ts @@ -1,31 +1,71 @@ -import { Global, Injectable, Module } from '@nestjs/common'; +import { + ExistingProvider, + FactoryProvider, + Global, + Module, +} from '@nestjs/common'; +import { ModuleRef } from '@nestjs/core'; +import { ApplyType } from '../base'; import { FeatureModel } from './feature'; +import { MODELS_SYMBOL } from './provider'; import { SessionModel } from './session'; import { UserModel } from './user'; import { VerificationTokenModel } from './verification-token'; -const models = [ - UserModel, - SessionModel, - VerificationTokenModel, - FeatureModel, -] as const; +const MODELS = { + user: UserModel, + session: SessionModel, + verificationToken: VerificationTokenModel, + feature: FeatureModel, +}; -@Injectable() -export class Models { - constructor( - public readonly user: UserModel, - public readonly session: SessionModel, - public readonly verificationToken: VerificationTokenModel, - public readonly feature: FeatureModel - ) {} -} +type ModelsType = { + [K in keyof typeof MODELS]: InstanceType<(typeof MODELS)[K]>; +}; + +export class Models extends ApplyType() {} + +const ModelsProvider: FactoryProvider = { + provide: Models, + useFactory: (ref: ModuleRef) => { + return new Proxy({} as any, { + get: (target, prop) => { + // cache + if (prop in target) { + return target[prop]; + } + + // find the model instance + // @ts-expect-error null detection happens right after + const Model = MODELS[prop]; + if (!Model) { + return undefined; + } + + const model = ref.get(Model); + + if (!model) { + throw new Error(`Failed to initialize model ${Model.name}`); + } + + target[prop] = model; + return model; + }, + }); + }, + inject: [ModuleRef], +}; + +const ModelsSymbolProvider: ExistingProvider = { + provide: MODELS_SYMBOL, + useExisting: Models, +}; @Global() @Module({ - providers: [...models, Models], - exports: [Models], + providers: [...Object.values(MODELS), ModelsProvider, ModelsSymbolProvider], + exports: [ModelsProvider], }) export class ModelModules {} diff --git a/packages/backend/server/src/models/provider.ts b/packages/backend/server/src/models/provider.ts new file mode 100644 index 0000000000..c43ac9c593 --- /dev/null +++ b/packages/backend/server/src/models/provider.ts @@ -0,0 +1 @@ +export const MODELS_SYMBOL = Symbol('AFFINE_MODELS'); diff --git a/packages/backend/server/src/models/session.ts b/packages/backend/server/src/models/session.ts index 804b305de0..c60cae75c7 100644 --- a/packages/backend/server/src/models/session.ts +++ b/packages/backend/server/src/models/session.ts @@ -1,25 +1,18 @@ -import { Injectable, Logger } from '@nestjs/common'; +import { Injectable } from '@nestjs/common'; import { Prisma, - PrismaClient, type Session, type User, type UserSession, } from '@prisma/client'; -import { Config } from '../base'; +import { BaseModel } from './base'; export type { Session, UserSession }; export type UserSessionWithUser = UserSession & { user: User }; @Injectable() -export class SessionModel { - private readonly logger = new Logger(SessionModel.name); - constructor( - private readonly db: PrismaClient, - private readonly config: Config - ) {} - +export class SessionModel extends BaseModel { async createSession() { return await this.db.session.create({ data: {}, diff --git a/packages/backend/server/src/models/user.ts b/packages/backend/server/src/models/user.ts index d366793fa5..281cd1e472 100644 --- a/packages/backend/server/src/models/user.ts +++ b/packages/backend/server/src/models/user.ts @@ -1,9 +1,8 @@ -import { Injectable, Logger } from '@nestjs/common'; -import { Prisma, PrismaClient, type User, Workspace } from '@prisma/client'; +import { Injectable } from '@nestjs/common'; +import { Prisma, type User, Workspace } from '@prisma/client'; import { pick } from 'lodash-es'; import { - Config, CryptoHelper, EmailAlreadyUsed, EventEmitter, @@ -15,6 +14,7 @@ import { import type { Payload } from '../base/event/def'; import { Permission } from '../core/permission'; import { Quota_FreePlanV1_1 } from '../core/quota/schema'; +import { BaseModel } from './base'; const publicUserSelect = { id: true, @@ -64,14 +64,13 @@ export type PublicUser = Pick; export type { User }; @Injectable() -export class UserModel { - private readonly logger = new Logger(UserModel.name); +export class UserModel extends BaseModel { constructor( - private readonly db: PrismaClient, private readonly crypto: CryptoHelper, - private readonly event: EventEmitter, - private readonly config: Config - ) {} + private readonly event: EventEmitter + ) { + super(); + } async get(id: string) { return this.db.user.findUnique({ diff --git a/packages/backend/server/src/models/verification-token.ts b/packages/backend/server/src/models/verification-token.ts index 85be612b5e..9956b35140 100644 --- a/packages/backend/server/src/models/verification-token.ts +++ b/packages/backend/server/src/models/verification-token.ts @@ -1,9 +1,10 @@ import { randomUUID } from 'node:crypto'; -import { Injectable, Logger } from '@nestjs/common'; -import { PrismaClient, type VerificationToken } from '@prisma/client'; +import { Injectable } from '@nestjs/common'; +import { type VerificationToken } from '@prisma/client'; import { CryptoHelper } from '../base/helpers'; +import { BaseModel } from './base'; export type { VerificationToken }; @@ -16,12 +17,10 @@ export enum TokenType { } @Injectable() -export class VerificationTokenModel { - private readonly logger = new Logger(VerificationTokenModel.name); - constructor( - private readonly db: PrismaClient, - private readonly crypto: CryptoHelper - ) {} +export class VerificationTokenModel extends BaseModel { + constructor(private readonly crypto: CryptoHelper) { + super(); + } /** * create token by type and credential (optional) with ttl in seconds (default 30 minutes)