diff --git a/packages/backend/server/src/__tests__/models/session.spec.ts b/packages/backend/server/src/__tests__/models/session.spec.ts new file mode 100644 index 0000000000..baca5ab852 --- /dev/null +++ b/packages/backend/server/src/__tests__/models/session.spec.ts @@ -0,0 +1,276 @@ +import { TestingModule } from '@nestjs/testing'; +import { PrismaClient } from '@prisma/client'; +import ava, { TestFn } from 'ava'; + +import { Config } from '../../base/config'; +import { SessionModel } from '../../models/session'; +import { UserModel } from '../../models/user'; +import { createTestingModule, initTestingDB } from '../utils'; + +interface Context { + config: Config; + module: TestingModule; + db: PrismaClient; + session: SessionModel; + user: UserModel; +} + +const test = ava as TestFn; + +test.before(async t => { + const module = await createTestingModule({ + providers: [SessionModel], + }); + + t.context.session = module.get(SessionModel); + t.context.user = module.get(UserModel); + t.context.db = module.get(PrismaClient); + t.context.config = module.get(Config); + t.context.module = module; +}); + +test.beforeEach(async t => { + await initTestingDB(t.context.db); +}); + +test.after(async t => { + await t.context.module.close(); +}); + +test('should create a new session', async t => { + const session = await t.context.session.createSession(); + t.truthy(session.id); + t.truthy(session.createdAt); + t.is(session.deprecated_expiresAt, null); +}); + +test('should get a exists session', async t => { + const session = await t.context.session.createSession(); + const existsSession = await t.context.session.getSession(session.id); + t.deepEqual(session, existsSession); +}); + +test('should get null when session id not exists', async t => { + const session = await t.context.session.getSession('not-exists'); + t.is(session, null); +}); + +test('should delete a exists session', async t => { + const session = await t.context.session.createSession(); + const count = await t.context.session.deleteSession(session.id); + t.is(count, 1); + const existsSession = await t.context.session.getSession(session.id); + t.is(existsSession, null); +}); + +test('should not delete a not exists session', async t => { + const count = await t.context.session.deleteSession('not-exists'); + t.is(count, 0); +}); + +test('should create a new userSession', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.db.session.create({ + data: {}, + }); + const userSession = await t.context.session.createOrRefreshUserSession( + user.id, + session.id + ); + t.is(userSession.sessionId, session.id); + t.is(userSession.userId, user.id); + t.not(userSession.expiresAt, null); +}); + +test('should auto create a new session when sessionId not exists in database', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const userSession = await t.context.session.createOrRefreshUserSession( + user.id, + 'not-exists-session-id' + ); + t.not(userSession.sessionId, 'not-exists-session-id'); + t.truthy(userSession.sessionId); + t.is(userSession.userId, user.id); + t.not(userSession.expiresAt, null); +}); + +test('should refresh exists userSession', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.db.session.create({ + data: {}, + }); + const userSession = await t.context.session.createOrRefreshUserSession( + user.id, + session.id + ); + t.is(userSession.sessionId, session.id); + t.is(userSession.userId, user.id); + t.not(userSession.expiresAt, null); + + const existsUserSession = await t.context.session.createOrRefreshUserSession( + user.id, + session.id + ); + t.is(existsUserSession.sessionId, session.id); + t.is(existsUserSession.userId, user.id); + t.not(existsUserSession.expiresAt, null); + t.is(existsUserSession.id, userSession.id); + t.assert( + existsUserSession.expiresAt!.getTime() > userSession.expiresAt!.getTime() + ); +}); + +test('should not refresh userSession when expires time not hit ttr', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.db.session.create({ + data: {}, + }); + const userSession = await t.context.session.createOrRefreshUserSession( + user.id, + session.id + ); + let newExpiresAt = + await t.context.session.refreshUserSessionIfNeeded(userSession); + t.is(newExpiresAt, undefined); + userSession.expiresAt = new Date( + userSession.expiresAt!.getTime() - t.context.config.auth.session.ttr * 1000 + ); + newExpiresAt = + await t.context.session.refreshUserSessionIfNeeded(userSession); + t.is(newExpiresAt, undefined); +}); + +test('should not refresh userSession when expires time hit ttr', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.session.createSession(); + const userSession = await t.context.session.createOrRefreshUserSession( + user.id, + session.id + ); + const ttr = t.context.config.auth.session.ttr * 2; + userSession.expiresAt = new Date( + userSession.expiresAt!.getTime() - ttr * 1000 + ); + const newExpiresAt = + await t.context.session.refreshUserSessionIfNeeded(userSession); + t.not(newExpiresAt, undefined); +}); + +test('should find userSessions without user property by default', async t => { + const session = await t.context.db.session.create({ + data: {}, + }); + const count = 10; + for (let i = 0; i < count; i++) { + const user = await t.context.user.create({ + email: `test${i}@affine.pro`, + }); + await t.context.session.createOrRefreshUserSession(user.id, session.id); + } + const userSessions = await t.context.session.findUserSessionsBySessionId( + session.id + ); + t.is(userSessions.length, count); + for (const userSession of userSessions) { + t.is(userSession.sessionId, session.id); + t.is(Reflect.get(userSession, 'user'), undefined); + } +}); + +test('should find userSessions include user property', async t => { + const session = await t.context.db.session.create({ + data: {}, + }); + const count = 10; + for (let i = 0; i < count; i++) { + const user = await t.context.user.create({ + email: `test${i}@affine.pro`, + }); + await t.context.session.createOrRefreshUserSession(user.id, session.id); + } + const userSessions = await t.context.session.findUserSessionsBySessionId( + session.id, + { user: true } + ); + t.is(userSessions.length, count); + for (const userSession of userSessions) { + t.is(userSession.sessionId, session.id); + t.truthy(userSession.user.id); + } +}); + +test('should delete userSession success by userId', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.db.session.create({ + data: {}, + }); + await t.context.session.createOrRefreshUserSession(user.id, session.id); + let count = await t.context.session.deleteUserSession(user.id); + t.is(count, 1); + count = await t.context.session.deleteUserSession(user.id); + t.is(count, 0); +}); + +test('should delete userSession success by userId and sessionId', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.db.session.create({ + data: {}, + }); + await t.context.session.createOrRefreshUserSession(user.id, session.id); + const count = await t.context.session.deleteUserSession(user.id, session.id); + t.is(count, 1); +}); + +test('should delete userSession fail when sessionId not match', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.db.session.create({ + data: {}, + }); + await t.context.session.createOrRefreshUserSession(user.id, session.id); + const count = await t.context.session.deleteUserSession( + user.id, + 'not-exists-session-id' + ); + t.is(count, 0); +}); + +test('should cleanup expired userSessions', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const session = await t.context.db.session.create({ + data: {}, + }); + const userSession = await t.context.session.createOrRefreshUserSession( + user.id, + session.id + ); + await t.context.session.cleanExpiredUserSessions(); + let count = await t.context.db.userSession.count(); + t.is(count, 1); + + // Set expiresAt to past time + await t.context.db.userSession.update({ + where: { id: userSession.id }, + data: { expiresAt: new Date('2022-01-01') }, + }); + await t.context.session.cleanExpiredUserSessions(); + count = await t.context.db.userSession.count(); + t.is(count, 0); +}); diff --git a/packages/backend/server/src/models/index.ts b/packages/backend/server/src/models/index.ts index 8dbe778014..32f50132ba 100644 --- a/packages/backend/server/src/models/index.ts +++ b/packages/backend/server/src/models/index.ts @@ -1,12 +1,16 @@ import { Global, Injectable, Module } from '@nestjs/common'; +import { SessionModel } from './session'; import { UserModel } from './user'; -const models = [UserModel] as const; +const models = [UserModel, SessionModel] as const; @Injectable() export class Models { - constructor(public readonly user: UserModel) {} + constructor( + public readonly user: UserModel, + public readonly session: SessionModel + ) {} } @Global() diff --git a/packages/backend/server/src/models/session.ts b/packages/backend/server/src/models/session.ts new file mode 100644 index 0000000000..804b305de0 --- /dev/null +++ b/packages/backend/server/src/models/session.ts @@ -0,0 +1,156 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { + Prisma, + PrismaClient, + type Session, + type User, + type UserSession, +} from '@prisma/client'; + +import { Config } from '../base'; + +export type { Session, UserSession }; +export type UserSessionWithUser = UserSession & { user: User }; + +@Injectable() +export class SessionModel { + private readonly logger = new Logger(SessionModel.name); + constructor( + private readonly db: PrismaClient, + private readonly config: Config + ) {} + + async createSession() { + return await this.db.session.create({ + data: {}, + }); + } + + async getSession(id: string) { + return await this.db.session.findFirst({ + where: { + id, + }, + }); + } + + async deleteSession(id: string) { + const { count } = await this.db.session.deleteMany({ + where: { + id, + }, + }); + this.logger.log(`Deleted session success by id: ${id}`); + return count; + } + + async createOrRefreshUserSession( + userId: string, + sessionId?: string, + ttl = this.config.auth.session.ttl + ) { + // check whether given session is valid + if (sessionId) { + const session = await this.db.session.findFirst({ + where: { + id: sessionId, + }, + }); + + if (!session) { + sessionId = undefined; + } + } + + if (!sessionId) { + const session = await this.createSession(); + sessionId = session.id; + } + + const expiresAt = new Date(Date.now() + ttl * 1000); + return await this.db.userSession.upsert({ + where: { + sessionId_userId: { + sessionId, + userId, + }, + }, + update: { + expiresAt, + }, + create: { + sessionId, + userId, + expiresAt, + }, + }); + } + + async refreshUserSessionIfNeeded( + userSession: UserSession, + ttr = this.config.auth.session.ttr + ): Promise { + if ( + userSession.expiresAt && + userSession.expiresAt.getTime() - Date.now() > ttr * 1000 + ) { + // no need to refresh + return; + } + + const newExpiresAt = new Date( + Date.now() + this.config.auth.session.ttl * 1000 + ); + await this.db.userSession.update({ + where: { + id: userSession.id, + }, + data: { + expiresAt: newExpiresAt, + }, + }); + + // return the new expiresAt after refresh + return newExpiresAt; + } + + async findUserSessionsBySessionId( + sessionId: string, + include?: T + ): Promise<(T extends { user: true } ? UserSessionWithUser : UserSession)[]> { + return await this.db.userSession.findMany({ + where: { + sessionId, + OR: [{ expiresAt: { gt: new Date() } }, { expiresAt: null }], + }, + orderBy: { + createdAt: 'asc', + }, + include: include as Prisma.UserSessionInclude, + }); + } + + async deleteUserSession(userId: string, sessionId?: string) { + const { count } = await this.db.userSession.deleteMany({ + where: { + userId, + sessionId, + }, + }); + this.logger.log( + `Deleted user session success by userId: ${userId} and sessionId: ${sessionId}` + ); + return count; + } + + async cleanExpiredUserSessions() { + const result = await this.db.userSession.deleteMany({ + where: { + expiresAt: { + lte: new Date(), + }, + }, + }); + this.logger.log(`Cleaned ${result.count} expired user sessions`); + } +}