From c69e542b983a08543197cd76bec9f2815f2eae78 Mon Sep 17 00:00:00 2001 From: liuyi Date: Wed, 22 Nov 2023 04:08:59 +0000 Subject: [PATCH] feat(server): add cache module (#4973) --- packages/backend/server/src/app.ts | 23 +- packages/backend/server/src/cache/cache.ts | 330 ++++++++++++++++++++ packages/backend/server/src/cache/index.ts | 24 ++ packages/backend/server/src/cache/redis.ts | 194 ++++++++++++ packages/backend/server/tests/cache.spec.ts | 108 +++++++ 5 files changed, 669 insertions(+), 10 deletions(-) create mode 100644 packages/backend/server/src/cache/cache.ts create mode 100644 packages/backend/server/src/cache/index.ts create mode 100644 packages/backend/server/src/cache/redis.ts create mode 100644 packages/backend/server/tests/cache.spec.ts diff --git a/packages/backend/server/src/app.ts b/packages/backend/server/src/app.ts index de7954ba06..8441983272 100644 --- a/packages/backend/server/src/app.ts +++ b/packages/backend/server/src/app.ts @@ -1,6 +1,7 @@ import { Module } from '@nestjs/common'; import { AppController } from './app.controller'; +import { CacheModule } from './cache'; import { ConfigModule } from './config'; import { MetricsModule } from './metrics'; import { BusinessModules } from './modules'; @@ -10,17 +11,19 @@ import { SessionModule } from './session'; import { StorageModule } from './storage'; import { RateLimiterModule } from './throttler'; +const BasicModules = [ + PrismaModule, + ConfigModule.forRoot(), + CacheModule, + StorageModule.forRoot(), + MetricsModule, + SessionModule, + RateLimiterModule, + AuthModule, +]; + @Module({ - imports: [ - PrismaModule, - ConfigModule.forRoot(), - StorageModule.forRoot(), - MetricsModule, - SessionModule, - RateLimiterModule, - AuthModule, - ...BusinessModules, - ], + imports: [...BasicModules, ...BusinessModules], controllers: [AppController], }) export class AppModule {} diff --git a/packages/backend/server/src/cache/cache.ts b/packages/backend/server/src/cache/cache.ts new file mode 100644 index 0000000000..24df6ffe76 --- /dev/null +++ b/packages/backend/server/src/cache/cache.ts @@ -0,0 +1,330 @@ +import Keyv from 'keyv'; + +export interface CacheSetOptions { + // in milliseconds + ttl?: number; +} + +// extends if needed +export interface Cache { + // standard operation + get(key: string): Promise; + set( + key: string, + value: T, + opts?: CacheSetOptions + ): Promise; + setnx( + key: string, + value: T, + opts?: CacheSetOptions + ): Promise; + increase(key: string, count?: number): Promise; + decrease(key: string, count?: number): Promise; + delete(key: string): Promise; + has(key: string): Promise; + ttl(key: string): Promise; + expire(key: string, ttl: number): Promise; + + // list operations + pushBack(key: string, ...values: T[]): Promise; + pushFront(key: string, ...values: T[]): Promise; + len(key: string): Promise; + list(key: string, start: number, end: number): Promise; + popFront(key: string, count?: number): Promise; + popBack(key: string, count?: number): Promise; + + // map operations + mapSet( + map: string, + key: string, + value: T, + opts: CacheSetOptions + ): Promise; + mapIncrease(map: string, key: string, count?: number): Promise; + mapDecrease(map: string, key: string, count?: number): Promise; + mapGet(map: string, key: string): Promise; + mapDelete(map: string, key: string): Promise; + mapKeys(map: string): Promise; + mapRandomKey(map: string): Promise; + mapLen(map: string): Promise; +} + +export class LocalCache implements Cache { + private readonly kv: Keyv; + + constructor() { + this.kv = new Keyv(); + } + + // standard operation + async get(key: string): Promise { + return this.kv.get(key).catch(() => undefined); + } + + async set( + key: string, + value: T, + opts: CacheSetOptions = {} + ): Promise { + return this.kv + .set(key, value, opts.ttl) + .then(() => true) + .catch(() => false); + } + + async setnx( + key: string, + value: T, + opts?: CacheSetOptions | undefined + ): Promise { + if (!(await this.has(key))) { + return this.set(key, value, opts); + } + return false; + } + + async increase(key: string, count: number = 1): Promise { + const prev = (await this.get(key)) ?? 0; + if (typeof prev !== 'number') { + throw new Error( + `Expect a Number keyed by ${key}, but found ${typeof prev}` + ); + } + + const curr = prev + count; + return (await this.set(key, curr)) ? curr : prev; + } + + async decrease(key: string, count: number = 1): Promise { + return this.increase(key, -count); + } + + async delete(key: string): Promise { + return this.kv.delete(key).catch(() => false); + } + + async has(key: string): Promise { + return this.kv.has(key).catch(() => false); + } + + async ttl(key: string): Promise { + return this.kv + .get(key, { raw: true }) + .then(raw => (raw?.expires ? raw.expires - Date.now() : Infinity)) + .catch(() => 0); + } + + async expire(key: string, ttl: number): Promise { + const value = await this.kv.get(key); + return this.set(key, value, { ttl }); + } + + // list operations + private async getArray(key: string) { + const raw = await this.kv.get(key, { raw: true }); + if (raw && !Array.isArray(raw.value)) { + throw new Error( + `Expect an Array keyed by ${key}, but found ${raw.value}` + ); + } + + return raw as Keyv.DeserializedData; + } + + private async setArray( + key: string, + value: T[], + opts: CacheSetOptions = {} + ) { + return this.set(key, value, opts).then(() => value.length); + } + + async pushBack(key: string, ...values: T[]): Promise { + let list: any[] = []; + let ttl: number | undefined = undefined; + const raw = await this.getArray(key); + if (raw) { + list = raw.value; + if (raw.expires) { + ttl = raw.expires - Date.now(); + } + } + + list = list.concat(values); + return this.setArray(key, list, { ttl }); + } + + async pushFront(key: string, ...values: T[]): Promise { + let list: any[] = []; + let ttl: number | undefined = undefined; + const raw = await this.getArray(key); + if (raw) { + list = raw.value; + if (raw.expires) { + ttl = raw.expires - Date.now(); + } + } + + list = values.concat(list); + return this.setArray(key, list, { ttl }); + } + + async len(key: string): Promise { + return this.getArray(key).then(v => v?.value.length ?? 0); + } + + /** + * list array elements with `[start, end]` + * the end indice is inclusive + */ + async list( + key: string, + start: number, + end: number + ): Promise { + const raw = await this.getArray(key); + if (raw?.value) { + start = (raw.value.length + start) % raw.value.length; + end = ((raw.value.length + end) % raw.value.length) + 1; + return raw.value.slice(start, end); + } else { + return []; + } + } + + private async trim(key: string, start: number, end: number) { + const raw = await this.getArray(key); + if (raw) { + start = (raw.value.length + start) % raw.value.length; + // make negative end index work, and end indice is inclusive + end = ((raw.value.length + end) % raw.value.length) + 1; + const result = raw.value.splice(start, end); + + await this.set(key, raw.value, { + ttl: raw.expires ? raw.expires - Date.now() : undefined, + }); + + return result; + } + + return []; + } + + async popFront(key: string, count: number = 1) { + return this.trim(key, 0, count - 1); + } + + async popBack(key: string, count: number = 1) { + return this.trim(key, -count, count - 1); + } + + // map operations + private async getMap(map: string) { + const raw = await this.kv.get(map, { raw: true }); + + if (raw) { + if (typeof raw.value !== 'object') { + throw new Error( + `Expect an Object keyed by ${map}, but found ${typeof raw}` + ); + } + + if (Array.isArray(raw.value)) { + throw new Error(`Expect an Object keyed by ${map}, but found an Array`); + } + } + + return raw as Keyv.DeserializedData>; + } + + private async setMap( + map: string, + value: Record, + opts: CacheSetOptions = {} + ) { + return this.kv.set(map, value, opts.ttl).then(() => true); + } + + async mapGet(map: string, key: string): Promise { + const raw = await this.getMap(map); + if (raw?.value) { + return raw.value[key]; + } + + return undefined; + } + + async mapSet( + map: string, + key: string, + value: T + ): Promise { + const raw = await this.getMap(map); + const data = raw?.value ?? {}; + + data[key] = value; + + return this.setMap(map, data, { + ttl: raw?.expires ? raw.expires - Date.now() : undefined, + }); + } + + async mapDelete(map: string, key: string): Promise { + const raw = await this.getMap(map); + + if (raw?.value) { + delete raw.value[key]; + return this.setMap(map, raw.value, { + ttl: raw.expires ? raw.expires - Date.now() : undefined, + }); + } + + return false; + } + + async mapIncrease( + map: string, + key: string, + count: number = 1 + ): Promise { + const prev = (await this.mapGet(map, key)) ?? 0; + + if (typeof prev !== 'number') { + throw new Error( + `Expect a Number keyed by ${key}, but found ${typeof prev}` + ); + } + + const curr = prev + count; + + return (await this.mapSet(map, key, curr)) ? curr : prev; + } + + async mapDecrease( + map: string, + key: string, + count: number = 1 + ): Promise { + return this.mapIncrease(map, key, -count); + } + + async mapKeys(map: string): Promise { + const raw = await this.getMap(map); + if (raw) { + return Object.keys(raw.value); + } + + return []; + } + + async mapRandomKey(map: string): Promise { + const keys = await this.mapKeys(map); + return keys[Math.floor(Math.random() * keys.length)]; + } + + async mapLen(map: string): Promise { + const raw = await this.getMap(map); + return raw ? Object.keys(raw.value).length : 0; + } +} diff --git a/packages/backend/server/src/cache/index.ts b/packages/backend/server/src/cache/index.ts new file mode 100644 index 0000000000..621407f031 --- /dev/null +++ b/packages/backend/server/src/cache/index.ts @@ -0,0 +1,24 @@ +import { FactoryProvider, Global, Module } from '@nestjs/common'; +import { Redis } from 'ioredis'; + +import { Config } from '../config'; +import { LocalCache } from './cache'; +import { RedisCache } from './redis'; + +const CacheProvider: FactoryProvider = { + provide: LocalCache, + useFactory: (config: Config) => { + return config.redis.enabled + ? new RedisCache(new Redis(config.redis)) + : new LocalCache(); + }, + inject: [Config], +}; + +@Global() +@Module({ + providers: [CacheProvider], + exports: [CacheProvider], +}) +export class CacheModule {} +export { LocalCache as Cache }; diff --git a/packages/backend/server/src/cache/redis.ts b/packages/backend/server/src/cache/redis.ts new file mode 100644 index 0000000000..8774c894e2 --- /dev/null +++ b/packages/backend/server/src/cache/redis.ts @@ -0,0 +1,194 @@ +import { Redis } from 'ioredis'; + +import { Cache, CacheSetOptions } from './cache'; + +export class RedisCache implements Cache { + constructor(private readonly redis: Redis) {} + + // standard operation + async get(key: string): Promise { + return this.redis + .get(key) + .then(v => { + if (v) { + return JSON.parse(v); + } + return undefined; + }) + .catch(() => undefined); + } + + async set( + key: string, + value: T, + opts: CacheSetOptions = {} + ): Promise { + if (opts.ttl) { + return this.redis + .set(key, JSON.stringify(value), 'PX', opts.ttl) + .then(() => true) + .catch(() => false); + } + + return this.redis + .set(key, JSON.stringify(value)) + .then(() => true) + .catch(() => false); + } + + async increase(key: string, count: number = 1): Promise { + return this.redis.incrby(key, count).catch(() => 0); + } + + async decrease(key: string, count: number = 1): Promise { + return this.redis.decrby(key, count).catch(() => 0); + } + + async setnx( + key: string, + value: T, + opts: CacheSetOptions = {} + ): Promise { + if (opts.ttl) { + return this.redis + .set(key, JSON.stringify(value), 'PX', opts.ttl, 'NX') + .then(v => !!v) + .catch(() => false); + } + + return this.redis + .set(key, JSON.stringify(value), 'NX') + .then(v => !!v) + .catch(() => false); + } + + async delete(key: string): Promise { + return this.redis + .del(key) + .then(v => v > 0) + .catch(() => false); + } + + async has(key: string): Promise { + return this.redis + .exists(key) + .then(v => v > 0) + .catch(() => false); + } + + async ttl(key: string): Promise { + return this.redis.ttl(key).catch(() => 0); + } + + async expire(key: string, ttl: number): Promise { + return this.redis + .pexpire(key, ttl) + .then(v => v > 0) + .catch(() => false); + } + + // list operations + async pushBack(key: string, ...values: T[]): Promise { + return this.redis + .rpush(key, ...values.map(v => JSON.stringify(v))) + .catch(() => 0); + } + + async pushFront(key: string, ...values: T[]): Promise { + return this.redis + .lpush(key, ...values.map(v => JSON.stringify(v))) + .catch(() => 0); + } + + async len(key: string): Promise { + return this.redis.llen(key).catch(() => 0); + } + + async list( + key: string, + start: number, + end: number + ): Promise { + return this.redis + .lrange(key, start, end) + .then(data => data.map(v => JSON.parse(v))) + .catch(() => []); + } + + async popFront(key: string, count: number = 1): Promise { + return this.redis + .lpop(key, count) + .then(data => (data ?? []).map(v => JSON.parse(v))) + .catch(() => []); + } + + async popBack(key: string, count: number = 1): Promise { + return this.redis + .rpop(key, count) + .then(data => (data ?? []).map(v => JSON.parse(v))) + .catch(() => []); + } + + // map operations + async mapSet( + map: string, + key: string, + value: T + ): Promise { + return this.redis + .hset(map, key, JSON.stringify(value)) + .then(v => v > 0) + .catch(() => false); + } + + async mapIncrease( + map: string, + key: string, + count: number = 1 + ): Promise { + return this.redis.hincrby(map, key, count); + } + + async mapDecrease( + map: string, + key: string, + count: number = 1 + ): Promise { + return this.redis.hincrby(map, key, -count); + } + + async mapGet(map: string, key: string): Promise { + return this.redis + .hget(map, key) + .then(v => (v ? JSON.parse(v) : undefined)) + .catch(() => undefined); + } + + async mapDelete(map: string, key: string): Promise { + return this.redis + .hdel(map, key) + .then(v => v > 0) + .catch(() => false); + } + + async mapKeys(map: string): Promise { + return this.redis.hkeys(map).catch(() => []); + } + + async mapRandomKey(map: string): Promise { + return this.redis + .hrandfield(map, 1) + .then(v => + typeof v === 'string' + ? v + : Array.isArray(v) + ? (v[0] as string) + : undefined + ) + .catch(() => undefined); + } + + async mapLen(map: string): Promise { + return this.redis.hlen(map).catch(() => 0); + } +} diff --git a/packages/backend/server/tests/cache.spec.ts b/packages/backend/server/tests/cache.spec.ts new file mode 100644 index 0000000000..6ffc2c2f00 --- /dev/null +++ b/packages/backend/server/tests/cache.spec.ts @@ -0,0 +1,108 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import test from 'ava'; + +import { Cache, CacheModule } from '../src/cache'; +import { ConfigModule } from '../src/config'; + +let cache: Cache; +let module: TestingModule; +test.beforeEach(async () => { + module = await Test.createTestingModule({ + imports: [ConfigModule.forRoot(), CacheModule], + }).compile(); + const prefix = Math.random().toString(36).slice(2, 7); + cache = new Proxy(module.get(Cache), { + get(target, prop) { + // @ts-expect-error safe + const fn = target[prop]; + if (typeof fn === 'function') { + // replase first parameter of fn with prefix + return (...args: any[]) => + fn.call(target, `${prefix}:${args[0]}`, ...args.slice(1)); + } + + return fn; + }, + }); +}); + +test.afterEach(async () => { + await module.close(); +}); + +test('should be able to set normal cache', async t => { + t.true(await cache.set('test', 1)); + t.is(await cache.get('test'), 1); + + t.true(await cache.has('test')); + t.true(await cache.delete('test')); + t.is(await cache.get('test'), undefined); + + t.true(await cache.set('test', { a: 1 })); + t.deepEqual(await cache.get('test'), { a: 1 }); +}); + +test('should be able to set cache with non-exiting flag', async t => { + t.true(await cache.setnx('test', 1)); + t.false(await cache.setnx('test', 2)); + t.is(await cache.get('test'), 1); +}); + +test('should be able to set cache with ttl', async t => { + t.true(await cache.set('test', 1)); + t.is(await cache.get('test'), 1); + + t.true(await cache.expire('test', 1 * 1000)); + const ttl = await cache.ttl('test'); + t.true(ttl <= 1 * 1000); + t.true(ttl > 0); +}); + +test('should be able to incr/decr number cache', async t => { + t.true(await cache.set('test', 1)); + t.is(await cache.increase('test'), 2); + t.is(await cache.increase('test'), 3); + t.is(await cache.decrease('test'), 2); + t.is(await cache.decrease('test'), 1); + + // increase an nonexists number + t.is(await cache.increase('test2'), 1); + t.is(await cache.increase('test2'), 2); +}); + +test('should be able to manipulate list cache', async t => { + t.is(await cache.pushBack('test', 1), 1); + t.is(await cache.pushBack('test', 2, 3, 4), 4); + t.is(await cache.len('test'), 4); + + t.deepEqual(await cache.list('test', 1, -1), [2, 3, 4]); + + t.deepEqual(await cache.popFront('test', 2), [1, 2]); + t.deepEqual(await cache.popBack('test', 1), [4]); + + t.is(await cache.pushBack('test2', { a: 1 }), 1); + t.deepEqual(await cache.popFront('test2', 1), [{ a: 1 }]); +}); + +test('should be able to manipulate map cache', async t => { + t.is(await cache.mapSet('test', 'a', 1), true); + t.is(await cache.mapSet('test', 'b', 2), true); + t.is(await cache.mapLen('test'), 2); + + t.is(await cache.mapGet('test', 'a'), 1); + t.is(await cache.mapGet('test', 'b'), 2); + + t.is(await cache.mapIncrease('test', 'a'), 2); + t.is(await cache.mapIncrease('test', 'a'), 3); + t.is(await cache.mapDecrease('test', 'b', 3), -1); + + const keys = await cache.mapKeys('test'); + t.deepEqual(keys, ['a', 'b']); + + const randomKey = await cache.mapRandomKey('test'); + t.truthy(randomKey); + t.true(keys.includes(randomKey!)); + + t.is(await cache.mapDelete('test', 'a'), true); + t.is(await cache.mapGet('test', 'a'), undefined); +});