fix(server): event handler bindings (#10165)

This commit is contained in:
forehalo
2025-02-14 11:29:02 +00:00
parent 42e0563d2e
commit 3dde47dd08
18 changed files with 486 additions and 260 deletions

View File

@@ -1,59 +1,44 @@
import {
applyDecorators,
Injectable,
Logger,
OnApplicationBootstrap,
OnModuleInit,
} from '@nestjs/common';
import {
EventEmitter2,
EventEmitterReadinessWatcher,
OnEvent as RawOnEvent,
OnEventMetadata,
} from '@nestjs/event-emitter';
import { DiscoveryService, MetadataScanner } from '@nestjs/core';
import {
OnGatewayConnection,
WebSocketGateway,
WebSocketServer,
} from '@nestjs/websockets';
import { CLS_ID, ClsService } from 'nestjs-cls';
import EventEmitter2, { type OnOptions } from 'eventemitter2';
import { CLS_ID, ClsService, ClsServiceManager } from 'nestjs-cls';
import type { Server, Socket } from 'socket.io';
import { CallMetric } from '../metrics';
import { wrapCallMetric } from '../metrics';
import { PushMetadata, sliceMetadata } from '../nestjs';
import { genRequestId } from '../utils';
import type { EventName } from './def';
const EventHandlerWrapper = (event: EventName): MethodDecorator => {
// @ts-expect-error allow
return (
_target,
key,
desc: TypedPropertyDescriptor<(...args: any[]) => any>
) => {
const originalMethod = desc.value;
if (!originalMethod) {
return desc;
}
const EVENT_LISTENER_METADATA = Symbol('event_listener');
interface EventHandlerMetadata {
namespace: string;
event: EventName;
opts?: OnOptions;
}
desc.value = function (...args: any[]) {
new Logger(EventBus.name).log(
`Event handler: ${event} (${key.toString()})`
);
return originalMethod.apply(this, args);
};
};
};
interface EventOptions extends OnOptions {
prepend?: boolean;
name?: string;
suppressError?: boolean;
}
export const OnEvent = (
event: EventName,
opts?: OnEventMetadata['options']
) => {
export const OnEvent = (event: EventName, opts?: EventOptions) => {
const namespace = event.split('.')[0];
return applyDecorators(
EventHandlerWrapper(event),
CallMetric('event', 'event_handler', undefined, { event, namespace }),
RawOnEvent(event, opts)
);
return PushMetadata<EventHandlerMetadata>(EVENT_LISTENER_METADATA, {
namespace,
event,
opts,
});
};
/**
@@ -63,7 +48,9 @@ export const OnEvent = (
namespace: 's2s',
})
@Injectable()
export class EventBus implements OnGatewayConnection, OnApplicationBootstrap {
export class EventBus
implements OnGatewayConnection, OnApplicationBootstrap, OnModuleInit
{
private readonly logger = new Logger(EventBus.name);
@WebSocketServer()
@@ -71,8 +58,9 @@ export class EventBus implements OnGatewayConnection, OnApplicationBootstrap {
constructor(
private readonly emitter: EventEmitter2,
private readonly watcher: EventEmitterReadinessWatcher,
private readonly cls: ClsService
private readonly cls: ClsService,
private readonly discovery: DiscoveryService,
private readonly scanner: MetadataScanner
) {}
handleConnection(client: Socket) {
@@ -83,27 +71,21 @@ export class EventBus implements OnGatewayConnection, OnApplicationBootstrap {
client.disconnect();
}
async onModuleInit() {
this.bindEventHandlers();
}
async onApplicationBootstrap() {
this.watcher
.waitUntilReady()
.then(() => {
const events = this.emitter.eventNames() as EventName[];
events.forEach(event => {
// Proxy all events received from server(trigger by `server.serverSideEmit`)
// to internal event system
this.server?.on(event, (payload, requestId?: string) => {
this.cls.run(() => {
requestId = requestId ?? genRequestId('se');
this.cls.set(CLS_ID, requestId);
this.logger.log(`Server Event: ${event} (Received)`);
this.emit(event, payload);
});
});
});
})
.catch(() => {
// startup time promise, never throw at runtime
// Proxy all events received from server(trigger by `server.serverSideEmit`)
// to internal event system
this.server?.on('broadcast', (event, payload, requestId?: string) => {
this.cls.run(() => {
requestId = requestId ?? genRequestId('event');
this.cls.set(CLS_ID, requestId);
this.logger.log(`Server Event: ${event} (Received)`);
this.emit(event, payload);
});
});
}
/**
@@ -127,22 +109,122 @@ export class EventBus implements OnGatewayConnection, OnApplicationBootstrap {
*/
broadcast<T extends EventName>(event: T, payload: Events[T]) {
this.logger.log(`Server Event: ${event} (Send)`);
this.server?.serverSideEmit(event, payload, this.cls.getId());
this.server?.serverSideEmit('broadcast', event, payload, this.cls.getId());
}
on<T extends EventName>(
event: T,
listener: (payload: Events[T]) => void | Promise<any>,
opts?: OnEventMetadata['options']
opts: EventOptions = {}
) {
this.emitter.on(event, listener as any, opts);
const namespace = event.split('.')[0];
const { name, prepend, suppressError } = opts;
let signature = name ?? listener.name ?? 'anonymous fn';
const add = prepend ? this.emitter.prependListener : this.emitter.on;
const handler = wrapCallMetric(
async (payload: any) => {
this.logger.verbose(`Handle event [${event}] (${signature})`);
const cls = ClsServiceManager.getClsService();
return await cls.run({ ifNested: 'reuse' }, async () => {
const requestId = cls.getId();
if (!requestId) {
cls.set(CLS_ID, genRequestId('event'));
}
try {
return await listener(payload);
} catch (e) {
if (suppressError) {
this.logger.error(
`Error happened when handling event [${event}] (${signature})`,
e
);
} else {
throw e;
}
}
});
},
'event',
'event_handler',
{
event,
namespace,
handler: signature,
}
);
add.call(this.emitter, event, handler as any, opts);
this.logger.verbose(
`Event handler for [${event}] registered ${name ? `in [${name}]` : ''}`
);
return () => {
this.emitter.off(event, listener as any);
this.emitter.off(event, handler as any);
};
}
waitFor<T extends EventName>(name: T, timeout?: number) {
return this.emitter.waitFor(name, timeout);
}
private bindEventHandlers() {
// make sure all our job handlers defined in [Providers] to make the code organization clean.
// const providers = [...this.discovery.getProviders(), this.discovery.getControllers()]
const providers = this.discovery.getProviders();
providers.forEach(wrapper => {
const { instance, name } = wrapper;
if (!instance || wrapper.isAlias) {
return;
}
const proto = Object.getPrototypeOf(instance);
const methods = this.scanner.getAllMethodNames(proto);
methods.forEach(method => {
const fn = instance[method];
let defs = sliceMetadata<EventHandlerMetadata>(
EVENT_LISTENER_METADATA,
fn
);
if (defs.length === 0) {
return;
}
const signature = `${name}.${method}`;
if (typeof fn !== 'function') {
throw new Error(`Event handler [${signature}] is not a function.`);
}
if (!wrapper.isDependencyTreeStatic()) {
throw new Error(
`Provider [${name}] could not be RequestScoped or TransientScoped injectable if it contains event handlers.`
);
}
defs.forEach(({ event, opts }) => {
this.on(
event,
(payload: any) => {
// NOTE(@forehalo):
// we might create spies on the event handlers when testing,
// avoid reusing `fn` variable to fail the spies or stubs
return instance[method](payload);
},
{
...opts,
name: signature,
}
);
});
});
});
}
}

View File

@@ -1,12 +1,18 @@
import { Global, Module } from '@nestjs/common';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { DiscoveryModule } from '@nestjs/core';
import EventEmitter2 from 'eventemitter2';
import { EventBus, OnEvent } from './eventbus';
const EmitProvider = {
provide: EventEmitter2,
useValue: new EventEmitter2(),
};
@Global()
@Module({
imports: [EventEmitterModule.forRoot({ global: false })],
providers: [EventBus],
imports: [DiscoveryModule],
providers: [EventBus, EmitProvider],
exports: [EventBus],
})
export class EventModule {}

View File

@@ -28,14 +28,7 @@ export { AFFiNELogger } from './logger';
export { MailService } from './mailer';
export { CallMetric, metrics } from './metrics';
export { Lock, Locker, Mutex, RequestMutex } from './mutex';
export {
GatewayErrorWrapper,
getOptionalModuleMetadata,
GlobalExceptionFilter,
mapAnyError,
mapSseError,
OptionalModule,
} from './nestjs';
export * from './nestjs';
export { type PrismaTransaction } from './prisma';
export { Runtime } from './runtime';
export * from './storage';

View File

@@ -1,5 +1,6 @@
import type { Attributes } from '@opentelemetry/api';
import { makeMethodDecorator } from '../nestjs/decorator';
import { type KnownMetricScopes, metrics } from './metrics';
/**
@@ -9,57 +10,41 @@ import { type KnownMetricScopes, metrics } from './metrics';
* @param attrs attributes
* @returns
*/
export const CallMetric = (
export const CallMetric = makeMethodDecorator(
(scope: KnownMetricScopes, name: string, attrs?: Attributes) => {
return (_target, _key, fn) => {
return wrapCallMetric(fn, scope, name, attrs);
};
}
);
export function wrapCallMetric<Fn extends (...args: any[]) => any>(
fn: Fn,
scope: KnownMetricScopes,
name: string,
record?: { timer?: boolean; count?: boolean; error?: boolean },
attrs?: Attributes
): MethodDecorator => {
// @ts-expect-error allow
return (
_target,
_key,
desc: TypedPropertyDescriptor<(...args: any[]) => any>
) => {
const originalMethod = desc.value;
if (!originalMethod) {
return desc;
) {
return async function (this: any, ...args: any[]) {
const start = Date.now();
let error = false;
try {
return await fn.call(this, ...args);
} catch (err) {
error = true;
throw err;
} finally {
const count = metrics[scope].counter('function_calls', {
description: 'function call counter',
});
const timer = metrics[scope].histogram('function_timer', {
description: 'function call time costs',
unit: 'ms',
});
count.add(1, { ...attrs, name, error });
timer.record(Date.now() - start, { ...attrs, name, error });
}
const timer = metrics[scope].histogram('function_timer', {
description: 'function call time costs',
unit: 'ms',
});
const count = metrics[scope].counter('function_calls', {
description: 'function call counter',
});
desc.value = async function (...args: any[]) {
const start = Date.now();
let error = false;
const end = () => {
timer?.record(Date.now() - start, { ...attrs, name, error });
};
try {
if (!record || !!record.count) {
count.add(1, attrs);
}
return await originalMethod.apply(this, args);
} catch (err) {
if (!record || !!record.error) {
error = true;
}
throw err;
} finally {
count.add(1, { ...attrs, name, error });
if (!record || !!record.timer) {
end();
}
}
};
return desc;
};
};
}

View File

@@ -0,0 +1,45 @@
export function makeMethodDecorator<
T extends any[],
Fn extends (...args: any[]) => any,
>(
decorator: (...args: T) => (target: any, key: string | symbol, fn: Fn) => Fn
) {
return (...args: T) => {
return (
target: any,
key: string | symbol,
desc: TypedPropertyDescriptor<any>
) => {
const originalFn = desc.value;
if (!originalFn || typeof originalFn !== 'function') {
throw new Error(
`MethodDecorator must be applied to a function but got ${typeof originalFn}`
);
}
const decoratedFn = decorator(...args)(target, key, originalFn);
desc.value = decoratedFn;
return desc;
};
};
}
export function PushMetadata<T>(key: string | symbol, value: T) {
const decorator: ClassDecorator | MethodDecorator = (
target,
_,
descriptor
) => {
const metadataTarget = descriptor?.value ?? target;
const metadataArray = Reflect.getMetadata(key, metadataTarget) || [];
metadataArray.push(value);
Reflect.defineMetadata(key, metadataArray, metadataTarget);
};
return decorator;
}
export function sliceMetadata<T>(key: string | symbol, target: any): T[] {
return Reflect.getMetadata(key, target) || [];
}

View File

@@ -1,3 +1,4 @@
import './config';
export * from './decorator';
export * from './exception';
export * from './optional-module';

View File

@@ -87,11 +87,11 @@ export function parseCookies(
* - `graphql`: graphql request
* - `http`: http request
* - `ws`: websocket request
* - `se`: server event
* - `event`: event
* - `job`: cron job
* - `rpc`: rpc request
*/
export type RequestType = GqlContextType | 'se' | 'job';
export type RequestType = GqlContextType | 'event' | 'job';
export function genRequestId(type: RequestType) {
return `${AFFiNE.flavor.type}:${type}/${randomUUID()}`;