refactor(infra): directory structure (#4615)

This commit is contained in:
Joooye_34
2023-10-18 23:30:08 +08:00
committed by GitHub
parent 814d552be8
commit bed9310519
1150 changed files with 539 additions and 584 deletions

View File

@@ -0,0 +1,13 @@
import { Controller, Get } from '@nestjs/common';
@Controller('/')
export class AppController {
@Get()
info() {
const version = AFFiNE.version;
return {
compatibility: version,
message: `AFFiNE ${version} Server`,
};
}
}

View File

@@ -0,0 +1,26 @@
import { Module } from '@nestjs/common';
import { AppController } from './app.controller';
import { ConfigModule } from './config';
import { MetricsModule } from './metrics';
import { BusinessModules } from './modules';
import { AuthModule } from './modules/auth';
import { PrismaModule } from './prisma';
import { SessionModule } from './session';
import { StorageModule } from './storage';
import { RateLimiterModule } from './throttler';
@Module({
imports: [
PrismaModule,
ConfigModule.forRoot(),
StorageModule.forRoot(),
MetricsModule,
SessionModule,
RateLimiterModule,
AuthModule,
...BusinessModules,
],
controllers: [AppController],
})
export class AppModule {}

View File

@@ -0,0 +1,366 @@
import type { ApolloDriverConfig } from '@nestjs/apollo';
import type { LeafPaths } from '../utils/types';
declare global {
// eslint-disable-next-line @typescript-eslint/no-namespace
namespace globalThis {
// eslint-disable-next-line no-var
var AFFiNE: AFFiNEConfig;
}
}
export enum ExternalAccount {
github = 'github',
google = 'google',
firebase = 'firebase',
}
type EnvConfigType = 'string' | 'int' | 'float' | 'boolean';
type ConfigPaths = LeafPaths<
Omit<
AFFiNEConfig,
| 'ENV_MAP'
| 'version'
| 'baseUrl'
| 'origin'
| 'prod'
| 'dev'
| 'test'
| 'deploy'
>,
'',
'....'
>;
/**
* parse number value from environment variables
*/
function int(value: string) {
const n = parseInt(value);
return Number.isNaN(n) ? undefined : n;
}
function float(value: string) {
const n = parseFloat(value);
return Number.isNaN(n) ? undefined : n;
}
function boolean(value: string) {
return value === '1' || value.toLowerCase() === 'true';
}
export function parseEnvValue(value: string | undefined, type?: EnvConfigType) {
if (typeof value === 'undefined') {
return;
}
return type === 'int'
? int(value)
: type === 'float'
? float(value)
: type === 'boolean'
? boolean(value)
: value;
}
/**
* All Configurations that would control AFFiNE server behaviors
*
*/
export interface AFFiNEConfig {
ENV_MAP: Record<string, ConfigPaths | [ConfigPaths, EnvConfigType?]>;
/**
* Server Identity
*/
readonly serverId: string;
/**
* System version
*/
readonly version: string;
/**
* Deployment environment
*/
readonly affineEnv: 'dev' | 'beta' | 'production';
/**
* alias to `process.env.NODE_ENV`
*
* @default 'production'
* @env NODE_ENV
*/
readonly env: string;
/**
* fast AFFiNE environment judge
*/
get affine(): {
canary: boolean;
beta: boolean;
stable: boolean;
};
/**
* fast environment judge
*/
get node(): {
prod: boolean;
dev: boolean;
test: boolean;
};
get deploy(): boolean;
/**
* Whether the server is hosted on a ssl enabled domain
*/
https: boolean;
/**
* where the server get deployed.
*
* @default 'localhost'
* @env AFFINE_SERVER_HOST
*/
host: string;
/**
* which port the server will listen on
*
* @default 3010
* @env AFFINE_SERVER_PORT
*/
port: number;
/**
* subpath where the server get deployed if there is.
*
* @default '' // empty string
* @env AFFINE_SERVER_SUB_PATH
*/
path: string;
/**
* Readonly property `baseUrl` is the full url of the server consists of `https://HOST:PORT/PATH`.
*
* if `host` is not `localhost` then the port will be ignored
*/
get baseUrl(): string;
/**
* Readonly property `origin` is domain origin in the form of `https://HOST:PORT` without subpath.
*
* if `host` is not `localhost` then the port will be ignored
*/
get origin(): string;
/**
* the database config
*/
db: {
url: string;
};
/**
* the apollo driver config
*/
graphql: ApolloDriverConfig;
/**
* app features flag
*/
featureFlags: {
earlyAccessPreview: boolean;
};
/**
* object storage Config
*
* all artifacts and logs will be stored on instance disk,
* and can not shared between instances if not configured
*/
objectStorage: {
/**
* whether use remote object storage
*/
r2: {
enabled: boolean;
accountId: string;
bucket: string;
accessKeyId: string;
secretAccessKey: string;
};
/**
* Only used when `enable` is `false`
*/
fs: {
path: string;
};
/**
* default storage quota
* @default 10 * 1024 * 1024 * 1024 (10GB)
*/
quota: number;
};
/**
* Rate limiter config
*/
rateLimiter: {
/**
* How long each request will be throttled (seconds)
* @default 60
* @env THROTTLE_TTL
*/
ttl: number;
/**
* How many requests can be made in the given time frame
* @default 120
* @env THROTTLE_LIMIT
*/
limit: number;
};
/**
* Redis Config
*
* whether to use redis as Socket.IO adapter
*/
redis: {
/**
* if not enabled, use in-memory adapter by default
*/
enabled: boolean;
/**
* url of redis host
*/
host: string;
/**
* port of redis
*/
port: number;
username: string;
password: string;
/**
* redis database index
*
* Rate Limiter scope: database + 1
*
* Session scope: database + 2
*
* @default 0
*/
database: number;
};
/**
* authentication config
*/
auth: {
/**
* Application access token expiration time
*/
readonly accessTokenExpiresIn: number;
/**
* Application refresh token expiration time
*/
readonly refreshTokenExpiresIn: number;
/**
* Add some leeway (in seconds) to the exp and nbf validation to account for clock skew.
* Defaults to 60 if omitted.
*/
readonly leeway: number;
/**
* Application public key
*
*/
readonly publicKey: string;
/**
* Application private key
*
*/
readonly privateKey: string;
/**
* whether allow user to signup with email directly
*/
enableSignup: boolean;
/**
* whether allow user to signup by oauth providers
*/
enableOauth: boolean;
/**
* NEXTAUTH_SECRET
*/
nextAuthSecret: string;
/**
* all available oauth providers
*/
oauthProviders: Partial<
Record<
ExternalAccount,
{
enabled: boolean;
clientId: string;
clientSecret: string;
/**
* uri to start oauth flow
*/
authorizationUri?: string;
/**
* uri to authenticate `access_token` when user is redirected back from oauth provider with `code`
*/
accessTokenUri?: string;
/**
* uri to get user info with authenticated `access_token`
*/
userInfoUri?: string;
args?: Record<string, any>;
}
>
>;
/**
* whether to use local email service to send email
* local debug only
*/
localEmail: boolean;
email: {
server: string;
port: number;
login: string;
sender: string;
password: string;
};
captcha: {
/**
* whether to enable captcha
*/
enable: boolean;
turnstile: {
/**
* Cloudflare Turnstile CAPTCHA secret
* default value is demo api key, witch always return success
*/
secret: string;
};
challenge: {
/**
* challenge bits length
* default value is 20, which can resolve in 0.5-3 second in M2 MacBook Air in single thread
* @default 20
*/
bits: number;
};
};
};
doc: {
manager: {
/**
* How often the [DocManager] will start a new turn of merging pending updates into doc snapshot.
*
* This is not the latency a new joint client will take to see the latest doc,
* but the buffer time we introduced to reduce the load of our service.
*
* in {ms}
*/
updatePollInterval: number;
/**
* Use JwstCodec to merge updates at the same time when merging using Yjs.
*
* This is an experimental feature, and aimed to check the correctness of JwstCodec.
*/
experimentalMergeWithJwstCodec: boolean;
};
};
}

View File

@@ -0,0 +1,215 @@
/// <reference types="../global.d.ts" />
import { createPrivateKey, createPublicKey } from 'node:crypto';
import { homedir } from 'node:os';
import { join } from 'node:path';
import parse from 'parse-duration';
import pkg from '../../package.json' assert { type: 'json' };
import type { AFFiNEConfig } from './def';
import { applyEnvToConfig } from './env';
// Don't use this in production
export const examplePrivateKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEtyAJLIULkphVhqXqxk4Nr8Ggty3XLwUJWBxzAWCWTMoAoGCCqGSM49
AwEHoUQDQgAEF3U/0wIeJ3jRKXeFKqQyBKlr9F7xaAUScRrAuSP33rajm3cdfihI
3JvMxVNsS2lE8PSGQrvDrJZaDo0L+Lq9Gg==
-----END EC PRIVATE KEY-----`;
const jwtKeyPair = (function () {
const AUTH_PRIVATE_KEY = process.env.AUTH_PRIVATE_KEY ?? examplePrivateKey;
const privateKey = createPrivateKey({
key: Buffer.from(AUTH_PRIVATE_KEY),
format: 'pem',
type: 'sec1',
})
.export({
format: 'pem',
type: 'pkcs8',
})
.toString('utf8');
const publicKey = createPublicKey({
key: Buffer.from(AUTH_PRIVATE_KEY),
format: 'pem',
type: 'spki',
})
.export({
format: 'pem',
type: 'spki',
})
.toString('utf8');
return {
publicKey,
privateKey,
};
})();
export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => {
const defaultConfig = {
serverId: 'affine-nestjs-server',
version: pkg.version,
ENV_MAP: {
AFFINE_SERVER_PORT: ['port', 'int'],
AFFINE_SERVER_HOST: 'host',
AFFINE_SERVER_SUB_PATH: 'path',
AFFINE_ENV: 'affineEnv',
AFFINE_FREE_USER_QUOTA: 'objectStorage.quota',
DATABASE_URL: 'db.url',
ENABLE_R2_OBJECT_STORAGE: ['objectStorage.r2.enabled', 'boolean'],
R2_OBJECT_STORAGE_ACCOUNT_ID: 'objectStorage.r2.accountId',
R2_OBJECT_STORAGE_ACCESS_KEY_ID: 'objectStorage.r2.accessKeyId',
R2_OBJECT_STORAGE_SECRET_ACCESS_KEY: 'objectStorage.r2.secretAccessKey',
R2_OBJECT_STORAGE_BUCKET: 'objectStorage.r2.bucket',
ENABLE_CAPTCHA: ['auth.captcha.enable', 'boolean'],
CAPTCHA_TURNSTILE_SECRET: ['auth.captcha.turnstile.secret', 'string'],
OAUTH_GOOGLE_ENABLED: ['auth.oauthProviders.google.enabled', 'boolean'],
OAUTH_GOOGLE_CLIENT_ID: 'auth.oauthProviders.google.clientId',
OAUTH_GOOGLE_CLIENT_SECRET: 'auth.oauthProviders.google.clientSecret',
OAUTH_GITHUB_ENABLED: ['auth.oauthProviders.github.enabled', 'boolean'],
OAUTH_GITHUB_CLIENT_ID: 'auth.oauthProviders.github.clientId',
OAUTH_GITHUB_CLIENT_SECRET: 'auth.oauthProviders.github.clientSecret',
OAUTH_EMAIL_LOGIN: 'auth.email.login',
OAUTH_EMAIL_SENDER: 'auth.email.sender',
OAUTH_EMAIL_SERVER: 'auth.email.server',
OAUTH_EMAIL_PORT: ['auth.email.port', 'int'],
OAUTH_EMAIL_PASSWORD: 'auth.email.password',
THROTTLE_TTL: ['rateLimiter.ttl', 'int'],
THROTTLE_LIMIT: ['rateLimiter.limit', 'int'],
REDIS_SERVER_ENABLED: ['redis.enabled', 'boolean'],
REDIS_SERVER_HOST: 'redis.host',
REDIS_SERVER_PORT: ['redis.port', 'int'],
REDIS_SERVER_USER: 'redis.username',
REDIS_SERVER_PASSWORD: 'redis.password',
REDIS_SERVER_DATABASE: ['redis.database', 'int'],
DOC_MERGE_INTERVAL: ['doc.manager.updatePollInterval', 'int'],
DOC_MERGE_USE_JWST_CODEC: [
'doc.manager.experimentalMergeWithJwstCodec',
'boolean',
],
ENABLE_LOCAL_EMAIL: ['auth.localEmail', 'boolean'],
} satisfies AFFiNEConfig['ENV_MAP'],
affineEnv: 'dev',
get affine() {
const env = this.affineEnv;
return {
canary: env === 'dev',
beta: env === 'beta',
stable: env === 'production',
};
},
env: process.env.NODE_ENV ?? 'development',
get node() {
const env = this.env;
return {
prod: env === 'production',
dev: env === 'development',
test: env === 'test',
};
},
get deploy() {
return !this.node.dev && !this.node.test;
},
get featureFlags() {
return {
earlyAccessPreview:
this.node.prod && (this.affine.beta || this.affine.canary),
};
},
get https() {
return !this.node.dev;
},
host: 'localhost',
port: 3010,
path: '',
db: {
url: '',
},
get origin() {
return this.node.dev
? 'http://localhost:8080'
: `${this.https ? 'https' : 'http'}://${this.host}${
this.host === 'localhost' ? `:${this.port}` : ''
}`;
},
get baseUrl() {
return `${this.origin}${this.path}`;
},
graphql: {
buildSchemaOptions: {
numberScalarMode: 'integer',
},
introspection: true,
playground: true,
},
auth: {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
accessTokenExpiresIn: parse('1h')! / 1000,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
refreshTokenExpiresIn: parse('7d')! / 1000,
leeway: 60,
captcha: {
enable: false,
turnstile: {
secret: '1x0000000000000000000000000000000AA',
},
challenge: {
bits: 20,
},
},
privateKey: jwtKeyPair.privateKey,
publicKey: jwtKeyPair.publicKey,
enableSignup: true,
enableOauth: false,
get nextAuthSecret() {
return this.privateKey;
},
oauthProviders: {},
localEmail: false,
email: {
server: 'smtp.gmail.com',
port: 465,
login: '',
sender: '',
password: '',
},
},
objectStorage: {
r2: {
enabled: false,
bucket: '',
accountId: '',
accessKeyId: '',
secretAccessKey: '',
},
fs: {
path: join(homedir(), '.affine-storage'),
},
// 10GB
quota: 10 * 1024 * 1024 * 1024,
},
rateLimiter: {
ttl: 60,
limit: 60,
},
redis: {
enabled: false,
host: '127.0.0.1',
port: 6379,
username: '',
password: '',
database: 0,
},
doc: {
manager: {
updatePollInterval: 3000,
experimentalMergeWithJwstCodec: false,
},
},
} satisfies AFFiNEConfig;
applyEnvToConfig(defaultConfig);
return defaultConfig;
};

View File

@@ -0,0 +1,17 @@
import { set } from 'lodash-es';
import { type AFFiNEConfig, parseEnvValue } from './def';
export function applyEnvToConfig(rawConfig: AFFiNEConfig) {
for (const env in rawConfig.ENV_MAP) {
const config = rawConfig.ENV_MAP[env];
const [path, value] =
typeof config === 'string'
? [config, process.env[env]]
: [config[0], parseEnvValue(process.env[env], config[1])];
if (typeof value !== 'undefined') {
set(rawConfig, path, value);
}
}
}

View File

@@ -0,0 +1,75 @@
// eslint-disable-next-line simple-import-sort/imports
import type { DynamicModule, FactoryProvider } from '@nestjs/common';
import { merge } from 'lodash-es';
import type { DeepPartial } from '../utils/types';
import type { AFFiNEConfig } from './def';
import '../prelude';
type ConstructorOf<T> = {
new (): T;
};
function ApplyType<T>(): ConstructorOf<T> {
// @ts-expect-error used to fake the type of config
return class Inner implements T {
constructor() {}
};
}
/**
* usage:
* ```
* import { Config } from '@affine/server'
*
* class TestConfig {
* constructor(private readonly config: Config) {}
* test() {
* return this.config.env
* }
* }
* ```
*/
export class Config extends ApplyType<AFFiNEConfig>() {}
function createConfigProvider(
override?: DeepPartial<Config>
): FactoryProvider<Config> {
return {
provide: Config,
useFactory: () => {
const wrapper = new Config();
const config = merge({}, globalThis.AFFiNE, override);
const proxy: Config = new Proxy(wrapper, {
get: (_target, property: keyof Config) => {
const desc = Object.getOwnPropertyDescriptor(
globalThis.AFFiNE,
property
);
if (desc?.get) {
return desc.get.call(proxy);
}
return config[property];
},
});
return proxy;
},
};
}
export class ConfigModule {
static forRoot = (override?: DeepPartial<Config>): DynamicModule => {
const provider = createConfigProvider(override);
return {
global: true,
module: ConfigModule,
providers: [provider],
exports: [provider],
};
};
}
export type { AFFiNEConfig } from './def';

View File

@@ -0,0 +1,3 @@
export const OPERATION_NAME = 'x-operation-name';
export const REQUEST_ID = 'x-request-id';

View File

@@ -0,0 +1,5 @@
declare namespace Express {
interface Request {
user?: import('@prisma/client').User | null;
}
}

View File

@@ -0,0 +1,41 @@
import type { ApolloDriverConfig } from '@nestjs/apollo';
import { ApolloDriver } from '@nestjs/apollo';
import { Global, Module } from '@nestjs/common';
import { GraphQLModule } from '@nestjs/graphql';
import { Request, Response } from 'express';
import { join } from 'path';
import { fileURLToPath } from 'url';
import { Config } from './config';
import { GQLLoggerPlugin } from './graphql/logger-plugin';
import { Metrics } from './metrics/metrics';
@Global()
@Module({
imports: [
GraphQLModule.forRootAsync<ApolloDriverConfig>({
driver: ApolloDriver,
useFactory: (config: Config, metrics: Metrics) => {
return {
...config.graphql,
path: `${config.path}/graphql`,
csrfPrevention: {
requestHeaders: ['content-type'],
},
autoSchemaFile: join(
fileURLToPath(import.meta.url),
'..',
'schema.gql'
),
context: ({ req, res }: { req: Request; res: Response }) => ({
req,
res,
}),
plugins: [new GQLLoggerPlugin(metrics)],
};
},
inject: [Config, Metrics],
}),
],
})
export class GqlModule {}

View File

@@ -0,0 +1,48 @@
import {
ApolloServerPlugin,
GraphQLRequestContext,
GraphQLRequestListener,
} from '@apollo/server';
import { Plugin } from '@nestjs/apollo';
import { Logger } from '@nestjs/common';
import { Response } from 'express';
import { Metrics } from '../metrics/metrics';
import { ReqContext } from '../types';
@Plugin()
export class GQLLoggerPlugin implements ApolloServerPlugin {
protected logger = new Logger(GQLLoggerPlugin.name);
constructor(private readonly metrics: Metrics) {}
requestDidStart(
reqContext: GraphQLRequestContext<ReqContext>
): Promise<GraphQLRequestListener<GraphQLRequestContext<ReqContext>>> {
const res = reqContext.contextValue.req.res as Response;
const operation = reqContext.request.operationName;
this.metrics.gqlRequest(1, { operation });
const timer = this.metrics.gqlTimer({ operation });
return Promise.resolve({
willSendResponse: () => {
const costInMilliseconds = timer() * 1000;
res.setHeader(
'Server-Timing',
`gql;dur=${costInMilliseconds};desc="GraphQL"`
);
return Promise.resolve();
},
didEncounterErrors: () => {
this.metrics.gqlError(1, { operation });
const costInMilliseconds = timer() * 1000;
res.setHeader(
'Server-Timing',
`gql;dur=${costInMilliseconds};desc="GraphQL ${operation}"`
);
return Promise.resolve();
},
});
}
}

View File

@@ -0,0 +1,104 @@
/// <reference types="./global.d.ts" />
import { MetricExporter } from '@google-cloud/opentelemetry-cloud-monitoring-exporter';
import { TraceExporter } from '@google-cloud/opentelemetry-cloud-trace-exporter';
import { NestFactory } from '@nestjs/core';
import type { NestExpressApplication } from '@nestjs/platform-express';
import {
CompositePropagator,
W3CBaggagePropagator,
W3CTraceContextPropagator,
} from '@opentelemetry/core';
import gql from '@opentelemetry/instrumentation-graphql';
import { HttpInstrumentation } from '@opentelemetry/instrumentation-http';
import ioredis from '@opentelemetry/instrumentation-ioredis';
import { NestInstrumentation } from '@opentelemetry/instrumentation-nestjs-core';
import socketIO from '@opentelemetry/instrumentation-socket.io';
import { PeriodicExportingMetricReader } from '@opentelemetry/sdk-metrics';
import { NodeSDK } from '@opentelemetry/sdk-node';
import { BatchSpanProcessor } from '@opentelemetry/sdk-trace-node';
import { PrismaInstrumentation } from '@prisma/instrumentation';
import cookieParser from 'cookie-parser';
import { static as staticMiddleware } from 'express';
import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs';
import { AppModule } from './app';
import { Config } from './config';
import { ExceptionLogger } from './middleware/exception-logger';
import { serverTimingAndCache } from './middleware/timing';
import { RedisIoAdapter } from './modules/sync/redis-adapter';
const { NODE_ENV, AFFINE_ENV } = process.env;
if (NODE_ENV === 'production') {
const traceExporter = new TraceExporter();
const tracing = new NodeSDK({
traceExporter,
metricReader: new PeriodicExportingMetricReader({
exporter: new MetricExporter(),
}),
spanProcessor: new BatchSpanProcessor(traceExporter),
textMapPropagator: new CompositePropagator({
propagators: [
new W3CBaggagePropagator(),
new W3CTraceContextPropagator(),
],
}),
instrumentations: [
new NestInstrumentation(),
new ioredis.IORedisInstrumentation(),
new socketIO.SocketIoInstrumentation({ traceReserved: true }),
new gql.GraphQLInstrumentation({ mergeItems: true }),
new HttpInstrumentation(),
new PrismaInstrumentation(),
],
serviceName: 'affine-cloud',
});
tracing.start();
}
const app = await NestFactory.create<NestExpressApplication>(AppModule, {
cors: true,
bodyParser: true,
logger:
NODE_ENV !== 'production' || AFFINE_ENV !== 'production'
? ['verbose']
: ['log'],
});
app.use(serverTimingAndCache);
app.use(
graphqlUploadExpress({
maxFileSize: 10 * 1024 * 1024,
maxFiles: 5,
})
);
app.useGlobalFilters(new ExceptionLogger());
app.use(cookieParser());
const config = app.get(Config);
const host = config.node.prod ? '0.0.0.0' : 'localhost';
const port = config.port ?? 3010;
if (!config.objectStorage.r2.enabled) {
app.use('/assets', staticMiddleware(config.objectStorage.fs.path));
}
if (config.redis.enabled) {
const redisIoAdapter = new RedisIoAdapter(app);
await redisIoAdapter.connectToRedis(
config.redis.host,
config.redis.port,
config.redis.username,
config.redis.password,
config.redis.database
);
app.useWebSocketAdapter(redisIoAdapter);
}
await app.listen(port, host);
console.log(`Listening on http://${host}:${port}`);

View File

@@ -0,0 +1,18 @@
import { Controller, Get, Res } from '@nestjs/common';
import type { Response } from 'express';
import { register } from 'prom-client';
import { PrismaService } from '../prisma';
@Controller()
export class MetricsController {
constructor(private readonly prisma: PrismaService) {}
@Get('/metrics')
async index(@Res() res: Response): Promise<void> {
res.header('Content-Type', register.contentType);
const prismaMetrics = await this.prisma.$metrics.prometheus();
const appMetrics = await register.metrics();
res.send(appMetrics + prismaMetrics);
}
}

View File

@@ -0,0 +1,12 @@
import { Global, Module } from '@nestjs/common';
import { MetricsController } from '../metrics/controller';
import { Metrics } from './metrics';
@Global()
@Module({
providers: [Metrics],
exports: [Metrics],
controllers: [MetricsController],
})
export class MetricsModule {}

View File

@@ -0,0 +1,28 @@
import { Injectable, OnModuleDestroy } from '@nestjs/common';
import { register } from 'prom-client';
import { metricsCreator } from './utils';
@Injectable()
export class Metrics implements OnModuleDestroy {
onModuleDestroy(): void {
register.clear();
}
socketIOEventCounter = metricsCreator.counter('socket_io_counter', ['event']);
socketIOEventTimer = metricsCreator.timer('socket_io_timer', ['event']);
socketIOConnectionGauge = metricsCreator.gauge(
'socket_io_connection_counter'
);
gqlRequest = metricsCreator.counter('gql_request', ['operation']);
gqlError = metricsCreator.counter('gql_error', ['operation']);
gqlTimer = metricsCreator.timer('gql_timer', ['operation']);
jwstCodecMerge = metricsCreator.counter('jwst_codec_merge');
jwstCodecDidnotMatch = metricsCreator.counter('jwst_codec_didnot_match');
jwstCodecFail = metricsCreator.counter('jwst_codec_fail');
authCounter = metricsCreator.counter('auth');
authFailCounter = metricsCreator.counter('auth_fail', ['reason']);
}

View File

@@ -0,0 +1,73 @@
import { Counter, Gauge, Summary } from 'prom-client';
type LabelValues<T extends string> = Partial<Record<T, string | number>>;
type MetricsCreator<T extends string> = (
value: number,
labels: LabelValues<T>
) => void;
type TimerMetricsCreator<T extends string> = (
labels: LabelValues<T>
) => () => number;
export const metricsCreatorGenerator = () => {
const counterCreator = <T extends string>(
name: string,
labelNames?: T[]
): MetricsCreator<T> => {
const counter = new Counter({
name,
help: name,
...(labelNames ? { labelNames } : {}),
});
return (value: number, labels: LabelValues<T>) => {
counter.inc(labels, value);
};
};
const gaugeCreator = <T extends string>(
name: string,
labelNames?: T[]
): MetricsCreator<T> => {
const gauge = new Gauge({
name,
help: name,
...(labelNames ? { labelNames } : {}),
});
return (value: number, labels: LabelValues<T>) => {
gauge.set(labels, value);
};
};
const timerCreator = <T extends string>(
name: string,
labelNames?: T[]
): TimerMetricsCreator<T> => {
const summary = new Summary({
name,
help: name,
...(labelNames ? { labelNames } : {}),
});
return (labels: LabelValues<T>) => {
const now = process.hrtime();
return () => {
const delta = process.hrtime(now);
const value = delta[0] + delta[1] / 1e9;
summary.observe(labels, value);
return value;
};
};
};
return {
counter: counterCreator,
gauge: gaugeCreator,
timer: timerCreator,
};
};
export const metricsCreator = metricsCreatorGenerator();

View File

@@ -0,0 +1,52 @@
import {
ArgumentsHost,
Catch,
ExceptionFilter,
HttpException,
Logger,
NotFoundException,
} from '@nestjs/common';
import { GqlContextType } from '@nestjs/graphql';
import { Request, Response } from 'express';
import { REQUEST_ID } from '../constants';
const TrivialExceptions = [NotFoundException];
@Catch()
export class ExceptionLogger implements ExceptionFilter {
private logger = new Logger('ExceptionLogger');
catch(exception: Error, host: ArgumentsHost) {
// with useGlobalFilters, the context is always HTTP
const ctx = host.switchToHttp();
const request = ctx.getRequest<Request>();
const requestId = request?.header(REQUEST_ID);
const shouldVerboseLog = !TrivialExceptions.some(
e => exception instanceof e
);
this.logger.error(
new Error(
`${requestId ? `requestId-${requestId}: ` : ''}${exception.message}${
shouldVerboseLog ? '\n' + exception.stack : ''
}}`,
{ cause: exception }
)
);
if (host.getType<GqlContextType>() === 'graphql') {
return;
}
const response = ctx.getResponse<Response>();
if (exception instanceof HttpException) {
response.status(exception.getStatus()).json(exception.getResponse());
} else {
response.status(500).json({
statusCode: 500,
error: exception.message,
});
}
}
}

View File

@@ -0,0 +1,27 @@
import { NextFunction, Request, Response } from 'express';
import onHeaders from 'on-headers';
export const serverTimingAndCache = (
req: Request,
res: Response,
next: NextFunction
) => {
req.res = res;
const now = process.hrtime();
onHeaders(res, () => {
const delta = process.hrtime(now);
const costInMilliseconds = (delta[0] + delta[1] / 1e9) * 1000;
const serverTiming = res.getHeader('Server-Timing') as string | undefined;
const serverTimingValue = `${
serverTiming ? `${serverTiming}, ` : ''
}total;dur=${costInMilliseconds}`;
res.setHeader('Server-Timing', serverTimingValue);
});
res.setHeader('Cache-Control', 'max-age=0, private, must-revalidate');
next();
};

View File

@@ -0,0 +1,150 @@
import type { CanActivate, ExecutionContext } from '@nestjs/common';
import {
createParamDecorator,
Inject,
Injectable,
SetMetadata,
UseGuards,
} from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import type { NextAuthOptions } from 'next-auth';
import { AuthHandler } from 'next-auth/core';
import { PrismaService } from '../../prisma';
import { getRequestResponseFromContext } from '../../utils/nestjs';
import { NextAuthOptionsProvide } from './next-auth-options';
import { AuthService } from './service';
export function getUserFromContext(context: ExecutionContext) {
return getRequestResponseFromContext(context).req.user;
}
/**
* Used to fetch current user from the request context.
*
* > The user may be undefined if authorization token is not provided.
*
* @example
*
* ```typescript
* // Graphql Query
* \@Query(() => UserType)
* user(@CurrentUser() user?: User) {
* return user;
* }
* ```
*
* ```typescript
* // HTTP Controller
* \@Get('/user)
* user(@CurrentUser() user?: User) {
* return user;
* }
* ```
*/
export const CurrentUser = createParamDecorator(
(_: unknown, context: ExecutionContext) => {
return getUserFromContext(context);
}
);
@Injectable()
class AuthGuard implements CanActivate {
constructor(
@Inject(NextAuthOptionsProvide)
private readonly nextAuthOptions: NextAuthOptions,
private auth: AuthService,
private prisma: PrismaService,
private readonly reflector: Reflector
) {}
async canActivate(context: ExecutionContext) {
const { req, res } = getRequestResponseFromContext(context);
const token = req.headers.authorization;
// api is public
const isPublic = this.reflector.get<boolean>(
'isPublic',
context.getHandler()
);
// api can be public, but if user is logged in, we can get user info
const isPublicable = this.reflector.get<boolean>(
'isPublicable',
context.getHandler()
);
if (isPublic) {
return true;
} else if (!token) {
if (!req.cookies) {
return isPublicable;
}
const session = await AuthHandler({
req: {
cookies: req.cookies,
action: 'session',
method: 'GET',
headers: req.headers,
},
options: this.nextAuthOptions,
});
const { body = {}, cookies, status = 200 } = session;
if (!body && !isPublicable) {
return false;
}
// @ts-expect-error body is user here
req.user = body.user;
if (cookies && res) {
for (const cookie of cookies) {
res.cookie(cookie.name, cookie.value, cookie.options);
}
}
return Boolean(
status === 200 &&
typeof body !== 'string' &&
// ignore body if api is publicable
(Object.keys(body).length || isPublicable)
);
} else {
const [type, jwt] = token.split(' ') ?? [];
if (type === 'Bearer') {
const claims = await this.auth.verify(jwt);
req.user = await this.prisma.user.findUnique({
where: { id: claims.id },
});
return !!req.user;
}
}
return false;
}
}
/**
* This guard is used to protect routes/queries/mutations that require a user to be logged in.
*
* The `@CurrentUser()` parameter decorator used in a `Auth` guarded queries would always give us the user because the `Auth` guard will
* fast throw if user is not logged in.
*
* @example
*
* ```typescript
* \@Auth()
* \@Query(() => UserType)
* user(@CurrentUser() user: User) {
* return user;
* }
* ```
*/
export const Auth = () => {
return UseGuards(AuthGuard);
};
// api is public accessible
export const Public = () => SetMetadata('isPublic', true);
// api is public accessible, but if user is logged in, we can get user info
export const Publicable = () => SetMetadata('isPublicable', true);

View File

@@ -0,0 +1,26 @@
import { Global, Module } from '@nestjs/common';
import { SessionModule } from '../../session';
import { MAILER, MailService } from './mailer';
import { NextAuthController } from './next-auth.controller';
import { NextAuthOptionsProvider } from './next-auth-options';
import { AuthResolver } from './resolver';
import { AuthService } from './service';
@Global()
@Module({
imports: [SessionModule],
providers: [
AuthService,
AuthResolver,
NextAuthOptionsProvider,
MAILER,
MailService,
],
exports: [AuthService, NextAuthOptionsProvider, MailService],
controllers: [NextAuthController],
})
export class AuthModule {}
export * from './guard';
export { TokenType } from './resolver';

View File

@@ -0,0 +1,2 @@
export { MailService } from './mail.service';
export { MAILER } from './mailer';

View File

@@ -0,0 +1,236 @@
import { Inject, Injectable } from '@nestjs/common';
import { Config } from '../../../config';
import {
MAILER_SERVICE,
type MailerService,
type Options,
type Response,
} from './mailer';
import { emailTemplate } from './template';
@Injectable()
export class MailService {
constructor(
@Inject(MAILER_SERVICE) private readonly mailer: MailerService,
private readonly config: Config
) {}
async sendMail(options: Options): Promise<Response> {
return this.mailer.sendMail(options);
}
hasConfigured() {
return (
!!this.config.auth.email.login &&
!!this.config.auth.email.password &&
!!this.config.auth.email.sender
);
}
async sendInviteEmail(
to: string,
inviteId: string,
invitationInfo: {
workspace: {
id: string;
name: string;
avatar: string;
};
user: {
avatar: string;
name: string;
};
}
) {
// TODO: use callback url when need support desktop app
const buttonUrl = `${this.config.origin}/invite/${inviteId}`;
const workspaceAvatar = invitationInfo.workspace.avatar;
const content = `<p style="margin:0">${
invitationInfo.user.avatar
? `<img
src="${invitationInfo.user.avatar}"
alt=""
width="24px"
height="24px"
style="width:24px; height:24px; border-radius: 12px;object-fit: cover;vertical-align: middle"
/>`
: ''
}
<span style="font-weight:500;margin-right: 4px;">${
invitationInfo.user.name
}</span>
<span>invited you to join</span>
<img
src="cid:workspaceAvatar"
alt=""
width="24px"
height="24px"
style="width:24px; height:24px; margin-left:4px;border-radius: 12px;object-fit: cover;vertical-align: middle"
/>
<span style="font-weight:500;margin-right: 4px;">${
invitationInfo.workspace.name
}</span></p><p style="margin-top:8px;margin-bottom:0;">Click button to join this workspace</p>`;
const subContent =
'Currently, AFFiNE Cloud is in the early access stage. Only Early Access Sponsors can register and log in to AFFiNE Cloud. <a href="https://community.affine.pro/c/insider-general/" style="color: #1e67af" >Please click here for more information.</a>';
const html = emailTemplate({
title: 'You are invited!',
content,
buttonContent: 'Accept & Join',
buttonUrl,
subContent,
});
return this.sendMail({
from: this.config.auth.email.sender,
to,
subject: `${invitationInfo.user.name} invited you to join ${invitationInfo.workspace.name}`,
html,
attachments: [
{
cid: 'workspaceAvatar',
filename: 'image.png',
content: workspaceAvatar,
encoding: 'base64',
},
],
});
}
async sendSignInEmail(url: string, options: Options) {
const html = emailTemplate({
title: 'Sign in to AFFiNE',
content:
'Click the button below to securely sign in. The magic link will expire in 30 minutes.',
buttonContent: 'Sign in to AFFiNE',
buttonUrl: url,
});
return this.sendMail({
html,
subject: 'Sign in to AFFiNE',
...options,
});
}
async sendChangePasswordEmail(to: string, url: string) {
const html = emailTemplate({
title: 'Modify your AFFiNE password',
content:
'Click the button below to reset your password. The magic link will expire in 30 minutes.',
buttonContent: 'Set new password',
buttonUrl: url,
});
return this.sendMail({
from: this.config.auth.email.sender,
to,
subject: `Modify your AFFiNE password`,
html,
});
}
async sendSetPasswordEmail(to: string, url: string) {
const html = emailTemplate({
title: 'Set your AFFiNE password',
content:
'Click the button below to set your password. The magic link will expire in 30 minutes.',
buttonContent: 'Set your password',
buttonUrl: url,
});
return this.sendMail({
from: this.config.auth.email.sender,
to,
subject: `Set your AFFiNE password`,
html,
});
}
async sendChangeEmail(to: string, url: string) {
const html = emailTemplate({
title: 'Verify your current email for AFFiNE',
content:
'You recently requested to change the email address associated with your AFFiNE account. To complete this process, please click on the verification link below. This magic link will expire in 30 minutes.',
buttonContent: 'Verify and set up a new email address',
buttonUrl: url,
});
return this.sendMail({
from: this.config.auth.email.sender,
to,
subject: `Verify your current email for AFFiNE`,
html,
});
}
async sendVerifyChangeEmail(to: string, url: string) {
const html = emailTemplate({
title: 'Verify your new email address',
content:
'You recently requested to change the email address associated with your AFFiNE account. To complete this process, please click on the verification link below. This magic link will expire in 30 minutes.',
buttonContent: 'Verify your new email address',
buttonUrl: url,
});
return this.sendMail({
from: this.config.auth.email.sender,
to,
subject: `Verify your new email for AFFiNE`,
html,
});
}
async sendNotificationChangeEmail(to: string) {
const html = emailTemplate({
title: 'Email change successful',
content: `As per your request, we have changed your email. Please make sure you're using ${to} when you log in the next time. `,
});
return this.sendMail({
from: this.config.auth.email.sender,
to,
subject: `Your email has been changed`,
html,
});
}
async sendAcceptedEmail(
to: string,
{
inviteeName,
workspaceName,
}: {
inviteeName: string;
workspaceName: string;
}
) {
const title = `${inviteeName} accepted your invitation`;
const html = emailTemplate({
title,
content: `${inviteeName} has joined ${workspaceName}`,
});
return this.sendMail({
from: this.config.auth.email.sender,
to,
subject: title,
html,
});
}
async sendLeaveWorkspaceEmail(
to: string,
{
inviteeName,
workspaceName,
}: {
inviteeName: string;
workspaceName: string;
}
) {
const title = `${inviteeName} left ${workspaceName}`;
const html = emailTemplate({
title,
content: `${inviteeName} has left your workspace`,
});
return this.sendMail({
from: this.config.auth.email.sender,
to,
subject: title,
html,
});
}
}

View File

@@ -0,0 +1,38 @@
import { FactoryProvider } from '@nestjs/common';
import { createTransport, Transporter } from 'nodemailer';
import SMTPTransport from 'nodemailer/lib/smtp-transport';
import { Config } from '../../../config';
export const MAILER_SERVICE = Symbol('MAILER_SERVICE');
export type MailerService = Transporter<SMTPTransport.SentMessageInfo>;
export type Response = SMTPTransport.SentMessageInfo;
export type Options = SMTPTransport.Options;
export const MAILER: FactoryProvider<
Transporter<SMTPTransport.SentMessageInfo>
> = {
provide: MAILER_SERVICE,
useFactory: (config: Config) => {
if (config.auth.localEmail) {
return createTransport({
host: '0.0.0.0',
port: 1025,
secure: false,
auth: {
user: config.auth.email.login,
pass: config.auth.email.password,
},
});
}
return createTransport({
service: 'gmail',
auth: {
user: config.auth.email.login,
pass: config.auth.email.password,
},
});
},
inject: [Config],
};

View File

@@ -0,0 +1,221 @@
export const emailTemplate = ({
title,
content,
buttonContent,
buttonUrl,
subContent,
}: {
title: string;
content: string;
buttonContent?: string;
buttonUrl?: string;
subContent?: string;
}) => {
return `<body style="background: #f6f7fb; overflow: hidden">
<table
width="100%"
border="0"
cellpadding="24px"
style="
background: #fff;
max-width: 450px;
margin: 32px auto 0 auto;
border-radius: 16px 16px 0 0;
box-shadow: 0px 0px 20px 0px rgba(66, 65, 73, 0.04);
"
>
<tr>
<td>
<a href="https://affine.pro" target="_blank">
<img
src="https://cdn.affine.pro/mail/2023-8-9/affine-logo.png"
alt="AFFiNE log"
height="32px"
/>
</a>
</td>
</tr>
<tr>
<td
style="
font-size: 20px;
font-weight: 600;
line-height: 28px;
font-family: inter, Arial, Helvetica, sans-serif;
color: #444;
padding-top: 0;
"
>${title}</td>
</tr>
<tr>
<td
style="
font-size: 15px;
font-weight: 400;
line-height: 24px;
font-family: inter, Arial, Helvetica, sans-serif;
color: #444;
padding-top: 0;
"
>${content}</td>
</tr>
${
buttonContent && buttonUrl
? `<tr>
<td style="margin-left: 24px; padding-top: 0; padding-bottom: ${
subContent ? '0' : '64px'
}">
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td style="border-radius: 8px" bgcolor="#1E96EB">
<a
href="${buttonUrl}"
target="_blank"
style="
font-size: 15px;
font-family: inter, Arial, Helvetica, sans-serif;
font-weight: 600;
line-height: 24px;
color: #fff;
text-decoration: none;
border-radius: 8px;
padding: 8px 18px;
border: 1px solid rgba(0,0,0,.1);
display: inline-block;
font-weight: bold;
"
>${buttonContent}</a
>
</td>
</tr>
</table>
</td>
</tr>`
: ''
}
${
subContent
? `<tr>
<td
style="
font-size: 12px;
font-weight: 400;
line-height: 20px;
font-family: inter, Arial, Helvetica, sans-serif;
color: #444;
padding-top: 24px;
"
>
${subContent}
</td>
</tr>`
: ''
}
</table>
<table
width="100%"
border="0"
style="
background: #fafafa;
max-width: 450px;
margin: 0 auto 32px auto;
border-radius: 0 0 16px 16px;
box-shadow: 0px 0px 20px 0px rgba(66, 65, 73, 0.04);
padding: 20px;
"
>
<tr align="center">
<td>
<table cellpadding="0">
<tr>
<td style="padding: 0 10px">
<a
href="https://github.com/toeverything/AFFiNE"
target="_blank"
><img
src="https://cdn.affine.pro/mail/2023-8-9/Github.png"
alt="AFFiNE github link"
height="16px"
/></a>
</td>
<td style="padding: 0 10px">
<a href="https://twitter.com/AffineOfficial" target="_blank">
<img
src="https://cdn.affine.pro/mail/2023-8-9/Twitter.png"
alt="AFFiNE twitter link"
height="16px"
/>
</a>
</td>
<td style="padding: 0 10px">
<a href="https://discord.gg/Arn7TqJBvG" target="_blank"
><img
src="https://cdn.affine.pro/mail/2023-8-9/Discord.png"
alt="AFFiNE discord link"
height="16px"
/></a>
</td>
<td style="padding: 0 10px">
<a href="https://www.youtube.com/@affinepro" target="_blank"
><img
src="https://cdn.affine.pro/mail/2023-8-9/Youtube.png"
alt="AFFiNE youtube link"
height="16px"
/></a>
</td>
<td style="padding: 0 10px">
<a href="https://t.me/affineworkos" target="_blank"
><img
src="https://cdn.affine.pro/mail/2023-8-9/Telegram.png"
alt="AFFiNE telegram link"
height="16px"
/></a>
</td>
<td style="padding: 0 10px">
<a href="https://www.reddit.com/r/Affine/" target="_blank"
><img
src="https://cdn.affine.pro/mail/2023-8-9/Reddit.png"
alt="AFFiNE reddit link"
height="16px"
/></a>
</td>
</tr>
</table>
</td>
</tr>
<tr align="center">
<td
style="
font-size: 12px;
font-weight: 400;
line-height: 20px;
font-family: inter, Arial, Helvetica, sans-serif;
color: #8e8d91;
padding-top: 8px;
"
>
One hyper-fused platform for wildly creative minds
</td>
</tr>
<tr align="center">
<td
style="
font-size: 12px;
font-weight: 400;
line-height: 20px;
font-family: inter, Arial, Helvetica, sans-serif;
color: #8e8d91;
padding-top: 8px;
"
>
Copyright<img
src="https://cdn.affine.pro/mail/2023-8-9/copyright.png"
alt="copyright"
height="14px"
style="vertical-align: middle; margin: 0 4px"
/>2023 Toeverything
</td>
</tr>
</table>
</body>`;
};

View File

@@ -0,0 +1,229 @@
import { PrismaAdapter } from '@auth/prisma-adapter';
import { FactoryProvider, Logger } from '@nestjs/common';
import { verify } from '@node-rs/argon2';
import { assign, omit } from 'lodash-es';
import { NextAuthOptions } from 'next-auth';
import Credentials from 'next-auth/providers/credentials';
import Email from 'next-auth/providers/email';
import Github from 'next-auth/providers/github';
import Google from 'next-auth/providers/google';
import { Config } from '../../config';
import { PrismaService } from '../../prisma';
import { SessionService } from '../../session';
import { NewFeaturesKind } from '../users/types';
import { isStaff } from '../users/utils';
import { MailService } from './mailer';
import {
decode,
encode,
sendVerificationRequest,
SendVerificationRequestParams,
} from './utils';
export const NextAuthOptionsProvide = Symbol('NextAuthOptions');
export const NextAuthOptionsProvider: FactoryProvider<NextAuthOptions> = {
provide: NextAuthOptionsProvide,
useFactory(
config: Config,
prisma: PrismaService,
mailer: MailService,
session: SessionService
) {
const logger = new Logger('NextAuth');
const prismaAdapter = PrismaAdapter(prisma);
// createUser exists in the adapter
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const createUser = prismaAdapter.createUser!.bind(prismaAdapter);
prismaAdapter.createUser = async data => {
const userData = {
name: data.name,
email: data.email,
avatarUrl: '',
emailVerified: data.emailVerified,
};
if (data.email && !data.name) {
userData.name = data.email.split('@')[0];
}
if (data.image) {
userData.avatarUrl = data.image;
}
return createUser(userData);
};
// getUser exists in the adapter
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const getUser = prismaAdapter.getUser!.bind(prismaAdapter)!;
prismaAdapter.getUser = async id => {
const result = await getUser(id);
if (result) {
// @ts-expect-error Third part library type mismatch
result.image = result.avatarUrl;
// @ts-expect-error Third part library type mismatch
result.hasPassword = Boolean(result.password);
}
return result;
};
const nextAuthOptions: NextAuthOptions = {
providers: [
// @ts-expect-error esm interop issue
Email.default({
server: {
host: config.auth.email.server,
port: config.auth.email.port,
auth: {
user: config.auth.email.login,
pass: config.auth.email.password,
},
},
from: config.auth.email.sender,
sendVerificationRequest: (params: SendVerificationRequestParams) =>
sendVerificationRequest(config, logger, mailer, session, params),
}),
],
adapter: prismaAdapter,
debug: !config.node.prod,
session: {
strategy: 'database',
},
logger: {
debug(code, metadata) {
logger.debug(`${code}: ${JSON.stringify(metadata)}`);
},
error(code, metadata) {
if (metadata instanceof Error) {
// @ts-expect-error assign code to error
metadata.code = code;
logger.error(metadata);
} else if (metadata.error instanceof Error) {
assign(metadata.error, omit(metadata, 'error'), { code });
logger.error(metadata.error);
}
},
warn(code) {
logger.warn(code);
},
},
};
nextAuthOptions.providers.push(
// @ts-expect-error esm interop issue
Credentials.default({
name: 'Password',
credentials: {
email: {
label: 'Email',
type: 'text',
placeholder: 'torvalds@osdl.org',
},
password: { label: 'Password', type: 'password' },
},
async authorize(
credentials:
| Record<'email' | 'password' | 'hashedPassword', string>
| undefined
) {
if (!credentials) {
return null;
}
const { password, hashedPassword } = credentials;
if (!password || !hashedPassword) {
return null;
}
if (!(await verify(hashedPassword, password))) {
return null;
}
return credentials;
},
})
);
if (config.auth.oauthProviders.github) {
nextAuthOptions.providers.push(
// @ts-expect-error esm interop issue
Github.default({
clientId: config.auth.oauthProviders.github.clientId,
clientSecret: config.auth.oauthProviders.github.clientSecret,
allowDangerousEmailAccountLinking: true,
})
);
}
if (config.auth.oauthProviders.google) {
nextAuthOptions.providers.push(
// @ts-expect-error esm interop issue
Google.default({
clientId: config.auth.oauthProviders.google.clientId,
clientSecret: config.auth.oauthProviders.google.clientSecret,
checks: 'nonce',
allowDangerousEmailAccountLinking: true,
authorization: {
params: { scope: 'openid email profile', prompt: 'select_account' },
},
})
);
}
nextAuthOptions.jwt = {
encode: async ({ token, maxAge }) =>
encode(config, prisma, token, maxAge),
decode: async ({ token }) => decode(config, token),
};
nextAuthOptions.secret ??= config.auth.nextAuthSecret;
nextAuthOptions.callbacks = {
session: async ({ session, user, token }) => {
if (session.user) {
if (user) {
// @ts-expect-error Third part library type mismatch
session.user.id = user.id;
// @ts-expect-error Third part library type mismatch
session.user.image = user.image ?? user.avatarUrl;
// @ts-expect-error Third part library type mismatch
session.user.emailVerified = user.emailVerified;
// @ts-expect-error Third part library type mismatch
session.user.hasPassword = Boolean(user.password);
} else {
// technically the sub should be the same as id
// @ts-expect-error Third part library type mismatch
session.user.id = token.sub;
// @ts-expect-error Third part library type mismatch
session.user.emailVerified = token.emailVerified;
// @ts-expect-error Third part library type mismatch
session.user.hasPassword = token.hasPassword;
}
if (token && token.picture) {
session.user.image = token.picture;
}
}
return session;
},
signIn: async ({ profile, user }) => {
if (!config.featureFlags.earlyAccessPreview) {
return true;
}
const email = profile?.email ?? user.email;
if (email) {
if (isStaff(email)) {
return true;
}
return prisma.newFeaturesWaitingList
.findUnique({
where: {
email,
type: NewFeaturesKind.EarlyAccess,
},
})
.then(user => !!user)
.catch(() => false);
}
return false;
},
redirect({ url }) {
return url;
},
};
return nextAuthOptions;
},
inject: [Config, PrismaService, MailService, SessionService],
};

View File

@@ -0,0 +1,401 @@
import { URLSearchParams } from 'node:url';
import {
All,
BadRequestException,
Controller,
Get,
Inject,
Logger,
Next,
NotFoundException,
Query,
Req,
Res,
UseGuards,
} from '@nestjs/common';
import { hash, verify } from '@node-rs/argon2';
import type { User } from '@prisma/client';
import type { NextFunction, Request, Response } from 'express';
import { pick } from 'lodash-es';
import { nanoid } from 'nanoid';
import type { AuthAction, CookieOption, NextAuthOptions } from 'next-auth';
import { AuthHandler } from 'next-auth/core';
import { Config } from '../../config';
import { Metrics } from '../../metrics/metrics';
import { PrismaService } from '../../prisma/service';
import { SessionService } from '../../session';
import { AuthThrottlerGuard, Throttle } from '../../throttler';
import { NextAuthOptionsProvide } from './next-auth-options';
import { AuthService } from './service';
const BASE_URL = '/api/auth/';
const DEFAULT_SESSION_EXPIRE_DATE = 2592000 * 1000; // 30 days
@Controller(BASE_URL)
export class NextAuthController {
private readonly callbackSession;
private readonly logger = new Logger('NextAuthController');
constructor(
readonly config: Config,
readonly prisma: PrismaService,
private readonly authService: AuthService,
@Inject(NextAuthOptionsProvide)
private readonly nextAuthOptions: NextAuthOptions,
private readonly metrics: Metrics,
private readonly session: SessionService
) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
this.callbackSession = nextAuthOptions.callbacks!.session;
}
@UseGuards(AuthThrottlerGuard)
@Throttle({
default: {
limit: 60,
ttl: 60,
},
})
@Get('/challenge')
async getChallenge(@Res() res: Response) {
const challenge = nanoid();
const resource = nanoid();
await this.session.set(challenge, resource, 5 * 60 * 1000);
res.json({ challenge, resource });
}
@UseGuards(AuthThrottlerGuard)
@Throttle({
default: {
limit: 60,
ttl: 60,
},
})
@All('*')
async auth(
@Req() req: Request,
@Res() res: Response,
@Query() query: Record<string, any>,
@Next() next: NextFunction
) {
if (req.path === '/api/auth/signin' && req.method === 'GET') {
const query = req.query
? // @ts-expect-error req.query is satisfy with the Record<string, any>
`?${new URLSearchParams(req.query).toString()}`
: '';
res.redirect(`/signin${query}`);
return;
}
this.metrics.authCounter(1, {});
const [action, providerId] = req.url // start with request url
.slice(BASE_URL.length) // make relative to baseUrl
.replace(/\?.*/, '') // remove query part, use only path part
.split('/') as [AuthAction, string]; // as array of strings;
const credentialsSignIn =
req.method === 'POST' && providerId === 'credentials';
let userId: string | undefined;
if (credentialsSignIn) {
const { email } = req.body;
if (email) {
const user = await this.prisma.user.findFirst({
where: {
email,
},
});
if (!user) {
req.statusCode = 401;
req.statusMessage = 'User not found';
req.body = null;
throw new NotFoundException(`User not found`);
} else {
userId = user.id;
req.body = {
...req.body,
name: user.name,
email: user.email,
image: user.avatarUrl,
hashedPassword: user.password,
};
}
}
}
const options = this.nextAuthOptions;
if (req.method === 'POST' && action === 'session') {
if (typeof req.body !== 'object' || typeof req.body.data !== 'object') {
this.metrics.authFailCounter(1, { reason: 'invalid_session_data' });
throw new BadRequestException(`Invalid new session data`);
}
const user = await this.updateSession(req, req.body.data);
// callbacks.session existed
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
options.callbacks!.session = ({ session }) => {
return {
user: {
...pick(user, 'id', 'name', 'email'),
image: user.avatarUrl,
hasPassword: !!user.password,
},
expires: session.expires,
};
};
} else {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
options.callbacks!.session = this.callbackSession;
}
if (
this.config.auth.captcha.enable &&
req.method === 'POST' &&
action === 'signin'
) {
const isVerified = await this.verifyChallenge(req, res);
if (!isVerified) return;
}
const { status, headers, body, redirect, cookies } = await AuthHandler({
req: {
body: req.body,
query: query,
method: req.method,
action,
providerId,
error: query.error ?? providerId,
cookies: req.cookies,
},
options,
});
if (headers) {
for (const { key, value } of headers) {
res.setHeader(key, value);
}
}
if (cookies) {
for (const cookie of cookies) {
res.cookie(cookie.name, cookie.value, cookie.options);
}
}
let nextAuthTokenCookie: (CookieOption & { value: string }) | undefined;
const cookiePrefix = this.config.node.prod ? '__Secure-' : '';
const sessionCookieName = `${cookiePrefix}next-auth.session-token`;
// next-auth credentials login only support JWT strategy
// https://next-auth.js.org/configuration/providers/credentials
// let's store the session token in the database
if (
credentialsSignIn &&
(nextAuthTokenCookie = cookies?.find(
({ name }) => name === sessionCookieName
))
) {
const cookieExpires = new Date();
cookieExpires.setTime(
cookieExpires.getTime() + DEFAULT_SESSION_EXPIRE_DATE
);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
await this.nextAuthOptions.adapter!.createSession!({
sessionToken: nextAuthTokenCookie.value,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
userId: userId!,
expires: cookieExpires,
});
}
if (redirect?.endsWith('api/auth/error?error=AccessDenied')) {
this.logger.log(`Early access redirect headers: ${req.headers}`);
this.metrics.authFailCounter(1, {
reason: 'no_early_access_permission',
});
if (
!req.headers?.referer ||
checkUrlOrigin(req.headers.referer, 'https://accounts.google.com')
) {
res.redirect('https://community.affine.pro/c/insider-general/');
} else {
res.status(403);
res.json({
url: 'https://community.affine.pro/c/insider-general/',
error: `You don't have early access permission`,
});
}
return;
}
if (status) {
res.status(status);
}
if (redirect) {
if (providerId === 'credentials') {
res.send(JSON.stringify({ ok: true, url: redirect }));
} else if (
action === 'callback' ||
action === 'error' ||
(providerId !== 'credentials' &&
// login in the next-auth page, /api/auth/signin, auto redirect.
// otherwise, return the json value to allow frontend to handle the redirect.
req.headers?.referer?.includes?.('/api/auth/signin'))
) {
res.redirect(redirect);
} else {
res.json({ url: redirect });
}
} else if (typeof body === 'string') {
res.send(body);
} else if (body && typeof body === 'object') {
res.json(body);
} else {
next();
}
}
private async updateSession(
req: Request,
newSession: Partial<Omit<User, 'id'>> & { oldPassword?: string }
): Promise<User> {
const { name, email, password, oldPassword } = newSession;
if (!name && !email && !password) {
throw new BadRequestException(`Invalid new session data`);
}
if (password) {
const user = await this.verifyUserFromRequest(req);
const { password: userPassword } = user;
if (!oldPassword) {
if (userPassword) {
throw new BadRequestException(
`Old password is required to update password`
);
}
} else {
if (!userPassword) {
throw new BadRequestException(`No existed password`);
}
if (await verify(userPassword, oldPassword)) {
await this.prisma.user.update({
where: {
id: user.id,
},
data: {
...pick(newSession, 'email', 'name'),
password: await hash(password),
},
});
}
}
return user;
} else {
const user = await this.verifyUserFromRequest(req);
return this.prisma.user.update({
where: {
id: user.id,
},
data: pick(newSession, 'name', 'email'),
});
}
}
private async verifyChallenge(req: Request, res: Response): Promise<boolean> {
const challenge = req.query?.challenge;
if (typeof challenge === 'string' && challenge) {
const resource = await this.session.get(challenge);
if (!resource) {
this.rejectResponse(res, 'Invalid Challenge');
return false;
}
const isChallengeVerified =
await this.authService.verifyChallengeResponse(
req.query?.token,
resource
);
this.logger.debug(
`Challenge: ${challenge}, Resource: ${resource}, Response: ${req.query?.token}, isChallengeVerified: ${isChallengeVerified}`
);
if (!isChallengeVerified) {
this.rejectResponse(res, 'Invalid Challenge Response');
return false;
}
} else {
const isTokenVerified = await this.authService.verifyCaptchaToken(
req.query?.token,
req.headers['CF-Connecting-IP'] as string
);
if (!isTokenVerified) {
this.rejectResponse(res, 'Invalid Captcha Response');
return false;
}
}
return true;
}
private async verifyUserFromRequest(req: Request): Promise<User> {
const token = req.headers.authorization;
if (!token) {
const session = await AuthHandler({
req: {
cookies: req.cookies,
action: 'session',
method: 'GET',
headers: req.headers,
},
options: this.nextAuthOptions,
});
const { body } = session;
// @ts-expect-error check if body.user exists
if (body && body.user && body.user.id) {
const user = await this.prisma.user.findUnique({
where: {
// @ts-expect-error body.user.id exists
id: body.user.id,
},
});
if (user) {
return user;
}
}
} else {
const [type, jwt] = token.split(' ') ?? [];
if (type === 'Bearer') {
const claims = await this.authService.verify(jwt);
const user = await this.prisma.user.findUnique({
where: { id: claims.id },
});
if (user) {
return user;
}
}
}
throw new BadRequestException(`User not found`);
}
rejectResponse(res: Response, error: string, status = 400) {
res.status(status);
res.json({
url: `https://${this.config.baseUrl}/api/auth/error?${new URLSearchParams(
{
error,
}
).toString()}`,
error,
});
}
}
const checkUrlOrigin = (url: string, origin: string) => {
try {
return new URL(url).origin === origin;
} catch (e) {
return false;
}
};

View File

@@ -0,0 +1,288 @@
import {
BadRequestException,
ForbiddenException,
UseGuards,
} from '@nestjs/common';
import {
Args,
Context,
Field,
Mutation,
ObjectType,
Parent,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import type { Request } from 'express';
import { nanoid } from 'nanoid';
import { Config } from '../../config';
import { SessionService } from '../../session';
import { CloudThrottlerGuard, Throttle } from '../../throttler';
import { UserType } from '../users/resolver';
import { Auth, CurrentUser } from './guard';
import { AuthService } from './service';
@ObjectType()
export class TokenType {
@Field()
token!: string;
@Field()
refresh!: string;
@Field({ nullable: true })
sessionToken?: string;
}
/**
* Auth resolver
* Token rate limit: 20 req/m
* Sign up/in rate limit: 10 req/m
* Other rate limit: 5 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Resolver(() => UserType)
export class AuthResolver {
constructor(
private readonly config: Config,
private readonly auth: AuthService,
private readonly session: SessionService
) {}
@Throttle({
default: {
limit: 20,
ttl: 60,
},
})
@ResolveField(() => TokenType)
async token(
@Context() ctx: { req: Request },
@CurrentUser() currentUser: UserType,
@Parent() user: UserType
) {
if (user.id !== currentUser.id) {
throw new BadRequestException('Invalid user');
}
let sessionToken: string | undefined;
// only return session if the request is from the same origin & path == /open-app
if (
ctx.req.headers.referer &&
ctx.req.headers.host &&
new URL(ctx.req.headers.referer).pathname.startsWith('/open-app') &&
ctx.req.headers.host === new URL(this.config.origin).host
) {
const cookiePrefix = this.config.node.prod ? '__Secure-' : '';
const sessionCookieName = `${cookiePrefix}next-auth.session-token`;
sessionToken = ctx.req.cookies?.[sessionCookieName];
}
return {
sessionToken,
token: this.auth.sign(user),
refresh: this.auth.refresh(user),
};
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType)
async signUp(
@Context() ctx: { req: Request },
@Args('name') name: string,
@Args('email') email: string,
@Args('password') password: string
) {
const user = await this.auth.signUp(name, email, password);
ctx.req.user = user;
return user;
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType)
async signIn(
@Context() ctx: { req: Request },
@Args('email') email: string,
@Args('password') password: string
) {
const user = await this.auth.signIn(email, password);
ctx.req.user = user;
return user;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => UserType)
@Auth()
async changePassword(
@CurrentUser() user: UserType,
@Args('token') token: string,
@Args('newPassword') newPassword: string
) {
const id = await this.session.get(token);
if (!id || id !== user.id) {
throw new ForbiddenException('Invalid token');
}
await this.auth.changePassword(id, newPassword);
await this.session.delete(token);
return user;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => UserType)
@Auth()
async changeEmail(
@CurrentUser() user: UserType,
@Args('token') token: string
) {
// email has set token in `sendVerifyChangeEmail`
const [id, email] = (await this.session.get(token)).split(',');
if (!id || id !== user.id || !email) {
throw new ForbiddenException('Invalid token');
}
await this.auth.changeEmail(id, email);
await this.session.delete(token);
await this.auth.sendNotificationChangeEmail(email);
return user;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
@Auth()
async sendChangePasswordEmail(
@CurrentUser() user: UserType,
@Args('email') email: string,
@Args('callbackUrl') callbackUrl: string
) {
const token = nanoid();
await this.session.set(token, user.id);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', token);
const res = await this.auth.sendChangePasswordEmail(email, url.toString());
return !res.rejected.length;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
@Auth()
async sendSetPasswordEmail(
@CurrentUser() user: UserType,
@Args('email') email: string,
@Args('callbackUrl') callbackUrl: string
) {
const token = nanoid();
await this.session.set(token, user.id);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', token);
const res = await this.auth.sendSetPasswordEmail(email, url.toString());
return !res.rejected.length;
}
// The change email step is:
// 1. send email to primitive email `sendChangeEmail`
// 2. user open change email page from email
// 3. send verify email to new email `sendVerifyChangeEmail`
// 4. user open confirm email page from new email
// 5. user click confirm button
// 6. send notification email
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
@Auth()
async sendChangeEmail(
@CurrentUser() user: UserType,
@Args('email') email: string,
@Args('callbackUrl') callbackUrl: string
) {
const token = nanoid();
await this.session.set(token, user.id);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', token);
const res = await this.auth.sendChangeEmail(email, url.toString());
return !res.rejected.length;
}
@Throttle({
default: {
limit: 5,
ttl: 60,
},
})
@Mutation(() => Boolean)
@Auth()
async sendVerifyChangeEmail(
@CurrentUser() user: UserType,
@Args('token') token: string,
@Args('email') email: string,
@Args('callbackUrl') callbackUrl: string
) {
const id = await this.session.get(token);
if (!id || id !== user.id) {
throw new ForbiddenException('Invalid token');
}
const hasRegistered = await this.auth.getUserByEmail(email);
if (hasRegistered) {
throw new BadRequestException(`Invalid user email`);
}
const withEmailToken = nanoid();
await this.session.set(withEmailToken, `${user.id},${email}`);
const url = new URL(callbackUrl, this.config.baseUrl);
url.searchParams.set('token', withEmailToken);
const res = await this.auth.sendVerifyChangeEmail(email, url.toString());
await this.session.delete(token);
return !res.rejected.length;
}
}

View File

@@ -0,0 +1,294 @@
import { randomUUID } from 'node:crypto';
import {
BadRequestException,
Injectable,
InternalServerErrorException,
UnauthorizedException,
} from '@nestjs/common';
import { hash, verify } from '@node-rs/argon2';
import { Algorithm, sign, verify as jwtVerify } from '@node-rs/jsonwebtoken';
import type { User } from '@prisma/client';
import { nanoid } from 'nanoid';
import { Config } from '../../config';
import { PrismaService } from '../../prisma';
import { verifyChallengeResponse } from '../../storage';
import { MailService } from './mailer';
export type UserClaim = Pick<
User,
'id' | 'name' | 'email' | 'emailVerified' | 'createdAt' | 'avatarUrl'
> & {
hasPassword?: boolean;
};
export const getUtcTimestamp = () => Math.floor(new Date().getTime() / 1000);
@Injectable()
export class AuthService {
constructor(
private config: Config,
private prisma: PrismaService,
private mailer: MailService
) {}
sign(user: UserClaim) {
const now = getUtcTimestamp();
return sign(
{
data: {
id: user.id,
name: user.name,
email: user.email,
emailVerified: user.emailVerified?.toISOString(),
image: user.avatarUrl,
hasPassword: Boolean(user.hasPassword),
createdAt: user.createdAt.toISOString(),
},
iat: now,
exp: now + this.config.auth.accessTokenExpiresIn,
iss: this.config.serverId,
sub: user.id,
aud: user.name,
jti: randomUUID({
disableEntropyCache: true,
}),
},
this.config.auth.privateKey,
{
algorithm: Algorithm.ES256,
}
);
}
refresh(user: UserClaim) {
const now = getUtcTimestamp();
return sign(
{
data: {
id: user.id,
name: user.name,
email: user.email,
emailVerified: user.emailVerified?.toISOString(),
image: user.avatarUrl,
hasPassword: Boolean(user.hasPassword),
createdAt: user.createdAt.toISOString(),
},
exp: now + this.config.auth.refreshTokenExpiresIn,
iat: now,
iss: this.config.serverId,
sub: user.id,
aud: user.name,
jti: randomUUID({
disableEntropyCache: true,
}),
},
this.config.auth.privateKey,
{
algorithm: Algorithm.ES256,
}
);
}
async verify(token: string) {
try {
const data = (
await jwtVerify(token, this.config.auth.publicKey, {
algorithms: [Algorithm.ES256],
iss: [this.config.serverId],
leeway: this.config.auth.leeway,
requiredSpecClaims: ['exp', 'iat', 'iss', 'sub'],
})
).data as UserClaim;
return {
...data,
emailVerified: data.emailVerified ? new Date(data.emailVerified) : null,
createdAt: new Date(data.createdAt),
};
} catch (e) {
throw new UnauthorizedException('Invalid token');
}
}
async verifyCaptchaToken(token: any, ip: string) {
if (typeof token !== 'string' || !token) return false;
const formData = new FormData();
formData.append('secret', this.config.auth.captcha.turnstile.secret);
formData.append('response', token);
formData.append('remoteip', ip);
// prevent replay attack
formData.append('idempotency_key', nanoid());
const url = 'https://challenges.cloudflare.com/turnstile/v0/siteverify';
const result = await fetch(url, {
body: formData,
method: 'POST',
});
const outcome = await result.json();
return (
!!outcome.success &&
// skip hostname check in dev mode
(this.config.affineEnv === 'dev' || outcome.hostname === this.config.host)
);
}
async verifyChallengeResponse(response: any, resource: string) {
return verifyChallengeResponse(
response,
this.config.auth.captcha.challenge.bits,
resource
);
}
async signIn(email: string, password: string): Promise<User> {
const user = await this.prisma.user.findFirst({
where: {
email,
},
});
if (!user) {
throw new BadRequestException('Invalid email');
}
if (!user.password) {
throw new BadRequestException('User has no password');
}
let equal = false;
try {
equal = await verify(user.password, password);
} catch (e) {
console.error(e);
throw new InternalServerErrorException(e, 'Verify password failed');
}
if (!equal) {
throw new UnauthorizedException('Invalid password');
}
return user;
}
async signUp(name: string, email: string, password: string): Promise<User> {
const user = await this.prisma.user.findFirst({
where: {
email,
},
});
if (user) {
throw new BadRequestException('Email already exists');
}
const hashedPassword = await hash(password);
return this.prisma.user.create({
data: {
name,
email,
password: hashedPassword,
},
});
}
async createAnonymousUser(email: string): Promise<User> {
const user = await this.prisma.user.findFirst({
where: {
email,
},
});
if (user) {
throw new BadRequestException('Email already exists');
}
return this.prisma.user.create({
data: {
name: 'Unnamed',
email,
},
});
}
async getUserByEmail(email: string): Promise<User | null> {
return this.prisma.user.findUnique({
where: {
email,
},
});
}
async isUserHasPassword(email: string): Promise<boolean> {
const user = await this.prisma.user.findFirst({
where: {
email,
},
});
if (!user) {
throw new BadRequestException('Invalid email');
}
return Boolean(user.password);
}
async changePassword(id: string, newPassword: string): Promise<User> {
const user = await this.prisma.user.findUnique({
where: {
id,
},
});
if (!user) {
throw new BadRequestException('Invalid email');
}
const hashedPassword = await hash(newPassword);
return this.prisma.user.update({
where: {
id,
},
data: {
password: hashedPassword,
},
});
}
async changeEmail(id: string, newEmail: string): Promise<User> {
const user = await this.prisma.user.findUnique({
where: {
id,
},
});
if (!user) {
throw new BadRequestException('Invalid email');
}
return this.prisma.user.update({
where: {
id,
},
data: {
email: newEmail,
},
});
}
async sendChangePasswordEmail(email: string, callbackUrl: string) {
return this.mailer.sendChangePasswordEmail(email, callbackUrl);
}
async sendSetPasswordEmail(email: string, callbackUrl: string) {
return this.mailer.sendSetPasswordEmail(email, callbackUrl);
}
async sendChangeEmail(email: string, callbackUrl: string) {
return this.mailer.sendChangeEmail(email, callbackUrl);
}
async sendVerifyChangeEmail(email: string, callbackUrl: string) {
return this.mailer.sendVerifyChangeEmail(email, callbackUrl);
}
async sendNotificationChangeEmail(email: string) {
return this.mailer.sendNotificationChangeEmail(email);
}
}

View File

@@ -0,0 +1,3 @@
export { jwtDecode as decode, jwtEncode as encode } from './jwt';
export { sendVerificationRequest } from './send-mail';
export type { SendVerificationRequestParams } from 'next-auth/providers/email';

View File

@@ -0,0 +1,76 @@
import { randomUUID } from 'node:crypto';
import { BadRequestException } from '@nestjs/common';
import { Algorithm, sign, verify as jwtVerify } from '@node-rs/jsonwebtoken';
import { JWT } from 'next-auth/jwt';
import { Config } from '../../../config';
import { PrismaService } from '../../../prisma';
import { getUtcTimestamp, UserClaim } from '../service';
export const jwtEncode = async (
config: Config,
prisma: PrismaService,
token: JWT | undefined,
maxAge: number | undefined
) => {
if (!token?.email) {
throw new BadRequestException('Missing email in jwt token');
}
const user = await prisma.user.findFirstOrThrow({
where: {
email: token.email,
},
});
const now = getUtcTimestamp();
return sign(
{
data: {
id: user.id,
name: user.name,
email: user.email,
emailVerified: user.emailVerified?.toISOString(),
picture: user.avatarUrl,
createdAt: user.createdAt.toISOString(),
hasPassword: Boolean(user.password),
},
iat: now,
exp: now + (maxAge ?? config.auth.accessTokenExpiresIn),
iss: config.serverId,
sub: user.id,
aud: user.name,
jti: randomUUID({
disableEntropyCache: true,
}),
},
config.auth.privateKey,
{
algorithm: Algorithm.ES256,
}
);
};
export const jwtDecode = async (config: Config, token: string | undefined) => {
if (!token) {
return null;
}
const { name, email, emailVerified, id, picture, hasPassword } = (
await jwtVerify(token, config.auth.publicKey, {
algorithms: [Algorithm.ES256],
iss: [config.serverId],
leeway: config.auth.leeway,
requiredSpecClaims: ['exp', 'iat', 'iss', 'sub'],
})
).data as Omit<UserClaim, 'avatarUrl'> & {
picture: string | undefined;
};
return {
name,
email,
emailVerified,
picture,
sub: id,
id,
hasPassword,
};
};

View File

@@ -0,0 +1,41 @@
import { Logger } from '@nestjs/common';
import { nanoid } from 'nanoid';
import type { SendVerificationRequestParams } from 'next-auth/providers/email';
import { Config } from '../../../config';
import { SessionService } from '../../../session';
import { MailService } from '../mailer';
export async function sendVerificationRequest(
config: Config,
logger: Logger,
mailer: MailService,
session: SessionService,
params: SendVerificationRequestParams
) {
const { identifier, url, provider } = params;
const urlWithToken = new URL(url);
const callbackUrl = urlWithToken.searchParams.get('callbackUrl') || '';
if (!callbackUrl) {
throw new Error('callbackUrl is not set');
} else {
const newCallbackUrl = new URL(callbackUrl, config.origin);
const token = nanoid();
await session.set(token, identifier);
newCallbackUrl.searchParams.set('token', token);
urlWithToken.searchParams.set('callbackUrl', newCallbackUrl.toString());
}
const result = await mailer.sendSignInEmail(urlWithToken.toString(), {
to: identifier,
from: provider.from,
});
logger.log(`send verification email success: ${result.accepted.join(', ')}`);
const failed = result.rejected.concat(result.pending).filter(Boolean);
if (failed.length) {
throw new Error(`Email (${failed.join(', ')}) could not be sent`);
}
}

View File

@@ -0,0 +1,42 @@
import { DynamicModule } from '@nestjs/common';
import { DocManager } from './manager';
import { RedisDocManager } from './redis-manager';
export class DocModule {
/**
* @param automation whether enable update merging automation logic
*/
private static defModule(automation = true): DynamicModule {
return {
module: DocModule,
providers: [
{
provide: 'DOC_MANAGER_AUTOMATION',
useValue: automation,
},
{
provide: DocManager,
useClass: globalThis.AFFiNE.redis.enabled
? RedisDocManager
: DocManager,
},
],
exports: [DocManager],
};
}
static forRoot() {
return this.defModule();
}
static forSync(): DynamicModule {
return this.defModule(false);
}
static forFeature(): DynamicModule {
return this.defModule(false);
}
}
export { DocManager };

View File

@@ -0,0 +1,466 @@
import {
Inject,
Injectable,
Logger,
OnApplicationBootstrap,
OnModuleDestroy,
OnModuleInit,
} from '@nestjs/common';
import { Snapshot, Update } from '@prisma/client';
import { defer, retry } from 'rxjs';
import { applyUpdate, Doc, encodeStateAsUpdate, encodeStateVector } from 'yjs';
import { Config } from '../../config';
import { Metrics } from '../../metrics/metrics';
import { PrismaService } from '../../prisma';
import { mergeUpdatesInApplyWay as jwstMergeUpdates } from '../../storage';
import { DocID } from '../../utils/doc';
function compare(yBinary: Buffer, jwstBinary: Buffer, strict = false): boolean {
if (yBinary.equals(jwstBinary)) {
return true;
}
if (strict) {
return false;
}
const doc = new Doc();
applyUpdate(doc, jwstBinary);
const yBinary2 = Buffer.from(encodeStateAsUpdate(doc));
return compare(yBinary, yBinary2, true);
}
const MAX_SEQ_NUM = 0x3fffffff; // u31
/**
* Since we can't directly save all client updates into database, in which way the database will overload,
* we need to buffer the updates and merge them to reduce db write.
*
* And also, if a new client join, it would be nice to see the latest doc asap,
* so we need to at least store a snapshot of the doc and return quickly,
* along side all the updates that have not been applies to that snapshot(timestamp).
*/
@Injectable()
export class DocManager
implements OnModuleInit, OnModuleDestroy, OnApplicationBootstrap
{
protected logger = new Logger(DocManager.name);
private job: NodeJS.Timeout | null = null;
private seqMap = new Map<string, number>();
private busy = false;
constructor(
protected readonly db: PrismaService,
@Inject('DOC_MANAGER_AUTOMATION')
protected readonly automation: boolean,
protected readonly config: Config,
protected readonly metrics: Metrics
) {}
async onApplicationBootstrap() {
if (!this.config.node.test) {
await this.refreshDocGuid();
}
}
onModuleInit() {
if (this.automation) {
this.logger.log('Use Database');
this.setup();
}
}
onModuleDestroy() {
this.destroy();
}
protected recoverDoc(...updates: Buffer[]): Doc {
const doc = new Doc();
updates.forEach((update, i) => {
try {
if (update.length) {
applyUpdate(doc, update);
}
} catch (e) {
this.logger.error(
`Failed to apply updates, index: ${i}\nUpdate: ${updates
.map(u => u.toString('hex'))
.join('\n')}`
);
}
});
return doc;
}
protected applyUpdates(guid: string, ...updates: Buffer[]): Doc {
const doc = this.recoverDoc(...updates);
this.metrics.jwstCodecMerge(1, {});
// test jwst codec
if (this.config.doc.manager.experimentalMergeWithJwstCodec) {
const yjsResult = Buffer.from(encodeStateAsUpdate(doc));
let log = false;
try {
const jwstResult = jwstMergeUpdates(updates);
if (!compare(yjsResult, jwstResult)) {
this.metrics.jwstCodecDidnotMatch(1, {});
this.logger.warn(
`jwst codec result doesn't match yjs codec result for: ${guid}`
);
log = true;
if (this.config.node.dev) {
this.logger.warn(`Expected:\n ${yjsResult.toString('hex')}`);
this.logger.warn(`Result:\n ${jwstResult.toString('hex')}`);
}
}
} catch (e) {
this.metrics.jwstCodecFail(1, {});
this.logger.warn(`jwst apply update failed for ${guid}: ${e}`);
log = true;
} finally {
if (log) {
this.logger.warn(
`Updates: ${updates.map(u => u.toString('hex')).join('\n')}`
);
}
}
}
return doc;
}
/**
* setup pending update processing loop
*/
setup() {
this.job = setInterval(() => {
if (!this.busy) {
this.busy = true;
this.autoSquash()
.catch(() => {
/* we handle all errors in work itself */
})
.finally(() => {
this.busy = false;
});
}
}, this.config.doc.manager.updatePollInterval);
this.logger.log('Automation started');
if (this.config.doc.manager.experimentalMergeWithJwstCodec) {
this.logger.warn(
'Experimental feature enabled: merge updates with jwst codec is enabled'
);
}
}
/**
* stop pending update processing loop
*/
destroy() {
if (this.job) {
clearInterval(this.job);
this.job = null;
this.logger.log('Automation stopped');
}
}
/**
* add update to manager for later processing.
*/
async push(workspaceId: string, guid: string, update: Buffer) {
await new Promise<void>((resolve, reject) => {
defer(async () => {
const seq = await this.getUpdateSeq(workspaceId, guid);
await this.db.update.create({
data: {
workspaceId,
id: guid,
seq,
blob: update,
},
});
})
.pipe(retry(MAX_SEQ_NUM)) // retry until seq num not conflict
.subscribe({
next: () => {
this.logger.verbose(
`pushed update for workspace: ${workspaceId}, guid: ${guid}`
);
resolve();
},
error: reject,
});
});
}
/**
* get the latest doc with all update applied.
*/
async get(workspaceId: string, guid: string): Promise<Doc | null> {
const result = await this._get(workspaceId, guid);
if (result) {
if ('doc' in result) {
return result.doc;
} else if ('snapshot' in result) {
return this.recoverDoc(result.snapshot);
}
}
return null;
}
/**
* get the latest doc binary with all update applied.
*/
async getBinary(workspaceId: string, guid: string): Promise<Buffer | null> {
const result = await this._get(workspaceId, guid);
if (result) {
if ('doc' in result) {
return Buffer.from(encodeStateAsUpdate(result.doc));
} else if ('snapshot' in result) {
return result.snapshot;
}
}
return null;
}
/**
* get the latest doc state vector with all update applied.
*/
async getState(workspaceId: string, guid: string): Promise<Buffer | null> {
const snapshot = await this.getSnapshot(workspaceId, guid);
const updates = await this.getUpdates(workspaceId, guid);
if (updates.length) {
const doc = await this.squash(updates, snapshot);
return Buffer.from(encodeStateVector(doc));
}
return snapshot ? snapshot.state : null;
}
/**
* get the snapshot of the doc we've seen.
*/
async getSnapshot(workspaceId: string, guid: string) {
return this.db.snapshot.findUnique({
where: {
id_workspaceId: {
workspaceId,
id: guid,
},
},
});
}
/**
* get pending updates
*/
async getUpdates(workspaceId: string, guid: string) {
return this.db.update.findMany({
where: {
workspaceId,
id: guid,
},
orderBy: {
seq: 'asc',
},
});
}
/**
* apply pending updates to snapshot
*/
protected async autoSquash() {
// find the first update and batch process updates with same id
const first = await this.db.update.findFirst({
orderBy: {
createdAt: 'asc',
},
});
// no pending updates
if (!first) {
return;
}
const { id, workspaceId } = first;
try {
await this._get(workspaceId, id);
} catch (e) {
this.logger.error(
`Failed to apply updates for workspace: ${workspaceId}, guid: ${id}`
);
this.logger.error(e);
}
}
protected async upsert(
workspaceId: string,
guid: string,
doc: Doc,
seq?: number
) {
const blob = Buffer.from(encodeStateAsUpdate(doc));
const state = Buffer.from(encodeStateVector(doc));
return this.db.snapshot.upsert({
where: {
id_workspaceId: {
id: guid,
workspaceId,
},
},
create: {
id: guid,
workspaceId,
blob,
state,
seq,
},
update: {
blob,
state,
},
});
}
protected async _get(
workspaceId: string,
guid: string
): Promise<{ doc: Doc } | { snapshot: Buffer } | null> {
const snapshot = await this.getSnapshot(workspaceId, guid);
const updates = await this.getUpdates(workspaceId, guid);
if (updates.length) {
return {
doc: await this.squash(updates, snapshot),
};
}
return snapshot ? { snapshot: snapshot.blob } : null;
}
/**
* Squash updates into a single update and save it as snapshot,
* and delete the updates records at the same time.
*/
protected async squash(updates: Update[], snapshot: Snapshot | null) {
if (!updates.length) {
throw new Error('No updates to squash');
}
const first = updates[0];
const last = updates[updates.length - 1];
const doc = this.applyUpdates(
first.id,
snapshot ? snapshot.blob : Buffer.from([0, 0]),
...updates.map(u => u.blob)
);
const { id, workspaceId } = first;
await this.upsert(workspaceId, id, doc, last.seq);
await this.db.update.deleteMany({
where: {
id,
workspaceId,
seq: {
in: updates.map(u => u.seq),
},
},
});
return doc;
}
private async getUpdateSeq(workspaceId: string, guid: string) {
try {
const { seq } = await this.db.snapshot.update({
select: {
seq: true,
},
where: {
id_workspaceId: {
workspaceId,
id: guid,
},
},
data: {
seq: {
increment: 1,
},
},
});
// reset
if (seq === MAX_SEQ_NUM) {
await this.db.snapshot.update({
where: {
id_workspaceId: {
workspaceId,
id: guid,
},
},
data: {
seq: 0,
},
});
}
return seq;
} catch {
const last = this.seqMap.get(workspaceId + guid) ?? 0;
this.seqMap.set(workspaceId + guid, last + 1);
return last + 1;
}
}
/**
* deal with old records that has wrong guid format
* correct guid with `${non-wsId}:${variant}:${subId}` to `${subId}`
*
* @TODO delete in next release
* @deprecated
*/
private async refreshDocGuid() {
let turn = 0;
let lastTurnCount = 100;
while (lastTurnCount === 100) {
const docs = await this.db.snapshot.findMany({
skip: turn * 100,
take: 100,
orderBy: {
createdAt: 'asc',
},
});
lastTurnCount = docs.length;
for (const doc of docs) {
const docId = new DocID(doc.id, doc.workspaceId);
if (docId && !docId.isWorkspace && docId.guid !== doc.id) {
await this.db.snapshot.update({
where: {
id_workspaceId: {
id: doc.id,
workspaceId: doc.workspaceId,
},
},
data: {
id: docId.guid,
},
});
}
}
turn++;
}
}
}

View File

@@ -0,0 +1,129 @@
import { Inject, Injectable } from '@nestjs/common';
import Redis from 'ioredis';
import { Config } from '../../config';
import { Metrics } from '../../metrics/metrics';
import { PrismaService } from '../../prisma';
import { DocID } from '../../utils/doc';
import { DocManager } from './manager';
function makeKey(prefix: string) {
return (parts: TemplateStringsArray, ...args: any[]) => {
return parts.reduce((prev, curr, i) => {
return prev + curr + (args[i] || '');
}, prefix);
};
}
const pending = 'um_pending:';
const updates = makeKey('um_u:');
const lock = makeKey('um_l:');
const pushUpdateLua = `
redis.call('sadd', KEYS[1], ARGV[1])
redis.call('rpush', KEYS[2], ARGV[2])
`;
/**
* @deprecated unstable
*/
@Injectable()
export class RedisDocManager extends DocManager {
private readonly redis: Redis;
constructor(
protected override readonly db: PrismaService,
@Inject('DOC_MANAGER_AUTOMATION')
protected override readonly automation: boolean,
protected override readonly config: Config,
protected override readonly metrics: Metrics
) {
super(db, automation, config, metrics);
this.redis = new Redis(config.redis);
this.redis.defineCommand('pushDocUpdate', {
numberOfKeys: 2,
lua: pushUpdateLua,
});
}
override onModuleInit(): void {
if (this.automation) {
this.setup();
}
}
override async autoSquash(): Promise<void> {
// incase some update fallback to db
await super.autoSquash();
// consume rest updates in redis queue
const pendingDoc = await this.redis.spop(pending).catch(() => null); // safe
if (!pendingDoc) {
return;
}
const docId = new DocID(pendingDoc);
const updateKey = updates`${pendingDoc}`;
const lockKey = lock`${pendingDoc}`;
// acquire the lock
const lockResult = await this.redis
.set(
lockKey,
'1',
'EX',
// 10mins, incase progress exit in between lock require & release, which is a rare.
// if the lock is really hold more then 10mins, we should check the merge logic correctness
600,
'NX'
)
.catch(() => null); // safe;
if (!lockResult) {
// we failed to acquire the lock, put the pending doc back to queue.
await this.redis.sadd(pending, pendingDoc).catch(() => null); // safe
return;
}
try {
// fetch pending updates
const updates = await this.redis
.lrangeBuffer(updateKey, 0, -1)
.catch(() => []); // safe
if (!updates.length) {
return;
}
this.logger.verbose(
`applying ${updates.length} updates for workspace: ${docId}`
);
const snapshot = await this.getSnapshot(docId.workspace, docId.guid);
// merge
const doc = snapshot
? this.applyUpdates(docId.full, snapshot.blob, ...updates)
: this.applyUpdates(docId.full, ...updates);
// update snapshot
await this.upsert(docId.workspace, docId.guid, doc, snapshot?.seq);
// delete merged updates
await this.redis
.ltrim(updateKey, updates.length, -1)
// safe, fallback to mergeUpdates
.catch(e => {
this.logger.error(`Failed to remove merged updates from Redis: ${e}`);
});
} catch (e) {
this.logger.error(
`Failed to merge updates with snapshot for ${docId}: ${e}`
);
await this.redis.sadd(pending, docId.toString()).catch(() => null); // safe
} finally {
await this.redis.del(lockKey);
}
}
}

View File

@@ -0,0 +1,40 @@
import { DynamicModule, Type } from '@nestjs/common';
import { GqlModule } from '../graphql.module';
import { AuthModule } from './auth';
import { DocModule } from './doc';
import { SyncModule } from './sync';
import { UsersModule } from './users';
import { WorkspaceModule } from './workspaces';
const { SERVER_FLAVOR } = process.env;
const BusinessModules: (Type | DynamicModule)[] = [];
switch (SERVER_FLAVOR) {
case 'sync':
BusinessModules.push(SyncModule, DocModule.forSync());
break;
case 'graphql':
BusinessModules.push(
GqlModule,
WorkspaceModule,
UsersModule,
AuthModule,
DocModule.forRoot()
);
break;
case 'allinone':
default:
BusinessModules.push(
GqlModule,
WorkspaceModule,
UsersModule,
AuthModule,
SyncModule,
DocModule.forRoot()
);
break;
}
export { BusinessModules };

View File

@@ -0,0 +1,30 @@
import { randomUUID } from 'node:crypto';
import { createWriteStream } from 'node:fs';
import { mkdir } from 'node:fs/promises';
import { join } from 'node:path';
import { pipeline } from 'node:stream/promises';
import { Injectable } from '@nestjs/common';
import { Config } from '../../config';
import { FileUpload } from '../../types';
@Injectable()
export class FSService {
constructor(private readonly config: Config) {}
async writeFile(key: string, file: FileUpload) {
const dest = this.config.objectStorage.fs.path;
const fileName = `${key}-${randomUUID()}`;
const prefix = this.config.node.dev
? `${this.config.https ? 'https' : 'http'}://${this.config.host}:${
this.config.port
}`
: '';
await mkdir(dest, { recursive: true });
const destFile = join(dest, fileName);
await pipeline(file.createReadStream(), createWriteStream(destFile));
return `${prefix}/assets/${fileName}`;
}
}

View File

@@ -0,0 +1,11 @@
import { Module } from '@nestjs/common';
import { FSService } from './fs';
import { S3 } from './s3';
import { StorageService } from './storage.service';
@Module({
providers: [S3, StorageService, FSService],
exports: [StorageService],
})
export class StorageModule {}

View File

@@ -0,0 +1,22 @@
import { S3Client } from '@aws-sdk/client-s3';
import { FactoryProvider } from '@nestjs/common';
import { Config } from '../../config';
export const S3_SERVICE = Symbol('S3_SERVICE');
export const S3: FactoryProvider<S3Client> = {
provide: S3_SERVICE,
useFactory: (config: Config) => {
const s3 = new S3Client({
region: 'auto',
endpoint: `https://${config.objectStorage.r2.accountId}.r2.cloudflarestorage.com`,
credentials: {
accessKeyId: config.objectStorage.r2.accessKeyId,
secretAccessKey: config.objectStorage.r2.secretAccessKey,
},
});
return s3;
},
inject: [Config],
};

View File

@@ -0,0 +1,43 @@
import { PutObjectCommand, S3Client } from '@aws-sdk/client-s3';
import { Inject, Injectable } from '@nestjs/common';
import { crc32 } from '@node-rs/crc32';
import { fileTypeFromBuffer } from 'file-type';
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore - no types
import { getStreamAsBuffer } from 'get-stream';
import { Config } from '../../config';
import { FileUpload } from '../../types';
import { FSService } from './fs';
import { S3_SERVICE } from './s3';
@Injectable()
export class StorageService {
constructor(
@Inject(S3_SERVICE) private readonly s3: S3Client,
private readonly fs: FSService,
private readonly config: Config
) {}
async uploadFile(key: string, file: FileUpload) {
if (this.config.objectStorage.r2.enabled) {
const readableFile = file.createReadStream();
const fileBuffer = await getStreamAsBuffer(readableFile);
const mime = (await fileTypeFromBuffer(fileBuffer))?.mime;
const crc32Value = crc32(fileBuffer);
const keyWithCrc32 = `${crc32Value}-${key}`;
await this.s3.send(
new PutObjectCommand({
Body: fileBuffer,
Bucket: this.config.objectStorage.r2.bucket,
Key: keyWithCrc32,
ContentLength: fileBuffer.length,
ContentType: mime,
})
);
return `https://avatar.affineassets.com/${keyWithCrc32}`;
} else {
return this.fs.writeFile(key, file);
}
}
}

View File

@@ -0,0 +1,212 @@
import { Logger } from '@nestjs/common';
import {
ConnectedSocket,
MessageBody,
OnGatewayConnection,
OnGatewayDisconnect,
SubscribeMessage,
WebSocketGateway,
WebSocketServer,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import { encodeStateAsUpdate, encodeStateVector } from 'yjs';
import { Metrics } from '../../../metrics/metrics';
import { DocID } from '../../../utils/doc';
import { Auth, CurrentUser } from '../../auth';
import { DocManager } from '../../doc';
import { UserType } from '../../users';
import { PermissionService } from '../../workspaces/permission';
import { Permission } from '../../workspaces/types';
@WebSocketGateway({
cors: process.env.NODE_ENV !== 'production',
transports: ['websocket'],
})
export class EventsGateway implements OnGatewayConnection, OnGatewayDisconnect {
protected logger = new Logger(EventsGateway.name);
private connectionCount = 0;
constructor(
private readonly docManager: DocManager,
private readonly metric: Metrics,
private readonly permissions: PermissionService
) {}
@WebSocketServer()
server!: Server;
handleConnection() {
this.connectionCount++;
this.metric.socketIOConnectionGauge(this.connectionCount, {});
}
handleDisconnect() {
this.connectionCount--;
this.metric.socketIOConnectionGauge(this.connectionCount, {});
}
@Auth()
@SubscribeMessage('client-handshake')
async handleClientHandShake(
@CurrentUser() user: UserType,
@MessageBody() workspaceId: string,
@ConnectedSocket() client: Socket
) {
this.metric.socketIOEventCounter(1, { event: 'client-handshake' });
const endTimer = this.metric.socketIOEventTimer({
event: 'client-handshake',
});
const canWrite = await this.permissions.tryCheck(
workspaceId,
user.id,
Permission.Write
);
if (canWrite) await client.join(workspaceId);
endTimer();
return canWrite;
}
@SubscribeMessage('client-leave')
async handleClientLeave(
@MessageBody() workspaceId: string,
@ConnectedSocket() client: Socket
) {
this.metric.socketIOEventCounter(1, { event: 'client-leave' });
const endTimer = this.metric.socketIOEventTimer({
event: 'client-leave',
});
await client.leave(workspaceId);
endTimer();
}
@SubscribeMessage('client-update')
async handleClientUpdate(
@MessageBody()
{
workspaceId,
guid,
update,
}: {
workspaceId: string;
guid: string;
update: string;
},
@ConnectedSocket() client: Socket
) {
this.metric.socketIOEventCounter(1, { event: 'client-update' });
const endTimer = this.metric.socketIOEventTimer({ event: 'client-update' });
if (!client.rooms.has(workspaceId)) {
this.logger.verbose(
`Client ${client.id} tried to push update to workspace ${workspaceId} without joining it first`
);
endTimer();
return;
}
const docId = new DocID(guid, workspaceId);
client
.to(docId.workspace)
.emit('server-update', { workspaceId, guid, update });
const buf = Buffer.from(update, 'base64');
await this.docManager.push(docId.workspace, docId.guid, buf);
endTimer();
}
@Auth()
@SubscribeMessage('doc-load')
async loadDoc(
@ConnectedSocket() client: Socket,
@CurrentUser() user: UserType,
@MessageBody()
{
workspaceId,
guid,
stateVector,
}: {
workspaceId: string;
guid: string;
stateVector?: string;
}
): Promise<{ missing: string; state?: string } | false> {
this.metric.socketIOEventCounter(1, { event: 'doc-load' });
const endTimer = this.metric.socketIOEventTimer({ event: 'doc-load' });
if (!client.rooms.has(workspaceId)) {
const canRead = await this.permissions.tryCheck(workspaceId, user.id);
if (!canRead) {
endTimer();
return false;
}
}
const docId = new DocID(guid, workspaceId);
const doc = await this.docManager.get(docId.workspace, docId.guid);
if (!doc) {
endTimer();
return false;
}
const missing = Buffer.from(
encodeStateAsUpdate(
doc,
stateVector ? Buffer.from(stateVector, 'base64') : undefined
)
).toString('base64');
const state = Buffer.from(encodeStateVector(doc)).toString('base64');
endTimer();
return {
missing,
state,
};
}
@SubscribeMessage('awareness-init')
async handleInitAwareness(
@MessageBody() workspaceId: string,
@ConnectedSocket() client: Socket
) {
this.metric.socketIOEventCounter(1, { event: 'awareness-init' });
const endTimer = this.metric.socketIOEventTimer({
event: 'init-awareness',
});
if (client.rooms.has(workspaceId)) {
client.to(workspaceId).emit('new-client-awareness-init');
} else {
this.logger.verbose(
`Client ${client.id} tried to init awareness for workspace ${workspaceId} without joining it first`
);
}
endTimer();
}
@SubscribeMessage('awareness-update')
async handleHelpGatheringAwareness(
@MessageBody() message: { workspaceId: string; awarenessUpdate: string },
@ConnectedSocket() client: Socket
) {
this.metric.socketIOEventCounter(1, { event: 'awareness-update' });
const endTimer = this.metric.socketIOEventTimer({
event: 'awareness-update',
});
if (client.rooms.has(message.workspaceId)) {
client.to(message.workspaceId).emit('server-awareness-broadcast', {
...message,
});
} else {
this.logger.verbose(
`Client ${client.id} tried to update awareness for workspace ${message.workspaceId} without joining it first`
);
}
endTimer();
return 'ack';
}
}

View File

@@ -0,0 +1,11 @@
import { Module } from '@nestjs/common';
import { DocModule } from '../../doc';
import { PermissionService } from '../../workspaces/permission';
import { EventsGateway } from './events.gateway';
@Module({
imports: [DocModule.forFeature()],
providers: [EventsGateway, PermissionService],
})
export class EventsModule {}

View File

@@ -0,0 +1,8 @@
import { Module } from '@nestjs/common';
import { EventsModule } from './events/events.module';
@Module({
imports: [EventsModule],
})
export class SyncModule {}

View File

@@ -0,0 +1,37 @@
import { IoAdapter } from '@nestjs/platform-socket.io';
import { createAdapter } from '@socket.io/redis-adapter';
import { Redis } from 'ioredis';
import { ServerOptions } from 'socket.io';
export class RedisIoAdapter extends IoAdapter {
private adapterConstructor: ReturnType<typeof createAdapter> | undefined;
async connectToRedis(
host: string,
port: number,
username: string,
password: string,
db: number
): Promise<void> {
const pubClient = new Redis(port, host, {
username,
password,
db,
});
pubClient.on('error', err => {
console.error(err);
});
const subClient = pubClient.duplicate();
subClient.on('error', err => {
console.error(err);
});
this.adapterConstructor = createAdapter(pubClient, subClient);
}
override createIOServer(port: number, options?: ServerOptions): any {
const server = super.createIOServer(port, options);
server.adapter(this.adapterConstructor);
return server;
}
}

View File

@@ -0,0 +1,11 @@
export function assertExists<T>(
val: T | null | undefined,
message: string | Error = 'val does not exist'
): asserts val is T {
if (val === null || val === undefined) {
if (message instanceof Error) {
throw message;
}
throw new Error(message);
}
}

View File

@@ -0,0 +1,42 @@
type FeatureEarlyAccessPreview = {
whitelist: RegExp[];
};
type FeatureStorageLimit = {
storageQuota: number;
};
type UserFeatureGate = {
earlyAccessPreview: FeatureEarlyAccessPreview;
freeUser: FeatureStorageLimit;
proUser: FeatureStorageLimit;
};
const UserLevel = {
freeUser: {
storageQuota: 10 * 1024 * 1024 * 1024,
},
proUser: {
storageQuota: 100 * 1024 * 1024 * 1024,
},
} satisfies Pick<UserFeatureGate, 'freeUser' | 'proUser'>;
export function getStorageQuota(features: string[]) {
for (const feature of features) {
if (feature in UserLevel) {
return UserLevel[feature as keyof typeof UserLevel].storageQuota;
}
}
return null;
}
const UserType = {
earlyAccessPreview: {
whitelist: [/@toeverything\.info$/],
},
} satisfies Pick<UserFeatureGate, 'earlyAccessPreview'>;
export const FeatureGates = {
...UserType,
...UserLevel,
} satisfies UserFeatureGate;

View File

@@ -0,0 +1,14 @@
import { Module } from '@nestjs/common';
import { StorageModule } from '../storage';
import { UserResolver } from './resolver';
import { UsersService } from './users';
@Module({
imports: [StorageModule],
providers: [UserResolver, UsersService],
})
export class UsersModule {}
export { UserType } from './resolver';
export { UsersService } from './users';

View File

@@ -0,0 +1,240 @@
import {
BadRequestException,
ForbiddenException,
HttpStatus,
UseGuards,
} from '@nestjs/common';
import {
Args,
Field,
ID,
Mutation,
ObjectType,
Query,
registerEnumType,
Resolver,
} from '@nestjs/graphql';
import type { User } from '@prisma/client';
import { GraphQLError } from 'graphql';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import { PrismaService } from '../../prisma/service';
import { CloudThrottlerGuard, Throttle } from '../../throttler';
import type { FileUpload } from '../../types';
import { Auth, CurrentUser, Public } from '../auth/guard';
import { StorageService } from '../storage/storage.service';
import { NewFeaturesKind } from './types';
import { UsersService } from './users';
import { isStaff } from './utils';
registerEnumType(NewFeaturesKind, {
name: 'NewFeaturesKind',
});
@ObjectType()
export class UserType implements Partial<User> {
@Field(() => ID)
id!: string;
@Field({ description: 'User name' })
name!: string;
@Field({ description: 'User email' })
email!: string;
@Field(() => String, { description: 'User avatar url', nullable: true })
avatarUrl: string | null = null;
@Field(() => Date, { description: 'User email verified', nullable: true })
emailVerified: Date | null = null;
@Field({ description: 'User created date', nullable: true })
createdAt!: Date;
@Field(() => Boolean, {
description: 'User password has been set',
nullable: true,
})
hasPassword?: boolean;
}
@ObjectType()
export class DeleteAccount {
@Field()
success!: boolean;
}
@ObjectType()
export class RemoveAvatar {
@Field()
success!: boolean;
}
@ObjectType()
export class AddToNewFeaturesWaitingList {
@Field()
email!: string;
@Field(() => NewFeaturesKind, { description: 'New features kind' })
type!: NewFeaturesKind;
}
/**
* User resolver
* All op rate limit: 10 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Auth()
@Resolver(() => UserType)
export class UserResolver {
constructor(
private readonly prisma: PrismaService,
private readonly storage: StorageService,
private readonly users: UsersService
) {}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Query(() => UserType, {
name: 'currentUser',
description: 'Get current user',
})
async currentUser(@CurrentUser() user: UserType) {
const storedUser = await this.users.findUserById(user.id);
if (!storedUser) {
throw new BadRequestException(`User ${user.id} not found in db`);
}
return {
id: storedUser.id,
name: storedUser.name,
email: storedUser.email,
emailVerified: storedUser.emailVerified,
avatarUrl: storedUser.avatarUrl,
createdAt: storedUser.createdAt,
hasPassword: !!storedUser.password,
};
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Query(() => UserType, {
name: 'user',
description: 'Get user by email',
nullable: true,
})
@Public()
async user(@Args('email') email: string) {
if (!(await this.users.canEarlyAccess(email))) {
return new GraphQLError(
`You don't have early access permission\nVisit https://community.affine.pro/c/insider-general/ for more information`,
{
extensions: {
status: HttpStatus[HttpStatus.PAYMENT_REQUIRED],
code: HttpStatus.PAYMENT_REQUIRED,
},
}
);
}
// TODO: need to limit a user can only get another user witch is in the same workspace
const user = await this.users.findUserByEmail(email);
if (user?.password) {
const userResponse: UserType = user;
userResponse.hasPassword = true;
}
return user;
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => UserType, {
name: 'uploadAvatar',
description: 'Upload user avatar',
})
async uploadAvatar(
@CurrentUser() user: UserType,
@Args({ name: 'avatar', type: () => GraphQLUpload })
avatar: FileUpload
) {
if (!user) {
throw new BadRequestException(`User not found`);
}
const url = await this.storage.uploadFile(`${user.id}-avatar`, avatar);
return this.prisma.user.update({
where: { id: user.id },
data: { avatarUrl: url },
});
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => RemoveAvatar, {
name: 'removeAvatar',
description: 'Remove user avatar',
})
async removeAvatar(@CurrentUser() user: UserType) {
if (!user) {
throw new BadRequestException(`User not found`);
}
await this.prisma.user.update({
where: { id: user.id },
data: { avatarUrl: null },
});
return { success: true };
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => DeleteAccount)
async deleteAccount(@CurrentUser() user: UserType): Promise<DeleteAccount> {
await this.users.deleteUser(user.id);
return { success: true };
}
@Throttle({
default: {
limit: 10,
ttl: 60,
},
})
@Mutation(() => AddToNewFeaturesWaitingList)
async addToNewFeaturesWaitingList(
@CurrentUser() user: UserType,
@Args('type', {
type: () => NewFeaturesKind,
})
type: NewFeaturesKind,
@Args('email') email: string
): Promise<AddToNewFeaturesWaitingList> {
if (!isStaff(user.email)) {
throw new ForbiddenException('You are not allowed to do this');
}
await this.prisma.newFeaturesWaitingList.create({
data: {
email,
type,
},
});
return {
email,
type,
};
}
}

View File

@@ -0,0 +1,3 @@
export enum NewFeaturesKind {
EarlyAccess,
}

View File

@@ -0,0 +1,68 @@
import { Injectable } from '@nestjs/common';
import { Config } from '../../config';
import { PrismaService } from '../../prisma';
import { getStorageQuota } from './gates';
import { NewFeaturesKind } from './types';
import { isStaff } from './utils';
@Injectable()
export class UsersService {
constructor(
private readonly prisma: PrismaService,
private readonly config: Config
) {}
async canEarlyAccess(email: string) {
if (this.config.featureFlags.earlyAccessPreview && !isStaff(email)) {
return this.prisma.newFeaturesWaitingList
.findUnique({
where: { email, type: NewFeaturesKind.EarlyAccess },
})
.catch(() => false);
} else {
return true;
}
}
async getStorageQuotaById(id: string) {
const features = await this.prisma.user
.findUnique({
where: { id },
select: {
features: {
select: {
feature: true,
},
},
},
})
.then(user => user?.features.map(f => f.feature) ?? []);
return getStorageQuota(features) || this.config.objectStorage.quota;
}
async findUserByEmail(email: string) {
return this.prisma.user
.findUnique({
where: { email },
})
.catch(() => {
return null;
});
}
async findUserById(id: string) {
return this.prisma.user
.findUnique({
where: { id },
})
.catch(() => {
return null;
});
}
async deleteUser(id: string) {
return this.prisma.user.delete({ where: { id } });
}
}

View File

@@ -0,0 +1,3 @@
export function isStaff(email: string) {
return email.endsWith('@toeverything.info');
}

View File

@@ -0,0 +1,89 @@
import type { Storage } from '@affine/storage';
import {
Controller,
ForbiddenException,
Get,
Inject,
Logger,
NotFoundException,
Param,
Res,
} from '@nestjs/common';
import type { Response } from 'express';
import format from 'pretty-time';
import { StorageProvide } from '../../storage';
import { DocID } from '../../utils/doc';
import { Auth, CurrentUser, Publicable } from '../auth';
import { DocManager } from '../doc';
import { UserType } from '../users';
import { PermissionService } from './permission';
@Controller('/api/workspaces')
export class WorkspacesController {
private readonly logger = new Logger('WorkspacesController');
constructor(
@Inject(StorageProvide) private readonly storage: Storage,
private readonly permission: PermissionService,
private readonly docManager: DocManager
) {}
// get workspace blob
//
// NOTE: because graphql can't represent a File, so we have to use REST API to get blob
@Get('/:id/blobs/:name')
async blob(
@Param('id') workspaceId: string,
@Param('name') name: string,
@Res() res: Response
) {
const blob = await this.storage.getBlob(workspaceId, name);
if (!blob) {
throw new NotFoundException(
`Blob not found in workspace ${workspaceId}: ${name}`
);
}
res.setHeader('content-type', blob.contentType);
res.setHeader('last-modified', blob.lastModified);
res.setHeader('content-length', blob.size);
res.send(blob.data);
}
// get doc binary
@Get('/:id/docs/:guid')
@Auth()
@Publicable()
async doc(
@CurrentUser() user: UserType | undefined,
@Param('id') ws: string,
@Param('guid') guid: string,
@Res() res: Response
) {
const start = process.hrtime();
const docId = new DocID(guid, ws);
if (
// if a user has the permission
!(await this.permission.isAccessible(
docId.workspace,
docId.guid,
user?.id
))
) {
throw new ForbiddenException('Permission denied');
}
const update = await this.docManager.getBinary(docId.workspace, docId.guid);
if (!update) {
throw new NotFoundException('Doc not found');
}
res.setHeader('content-type', 'application/octet-stream');
res.send(update);
this.logger.debug(`workspaces doc api: ${format(process.hrtime(start))}`);
}
}

View File

@@ -0,0 +1,16 @@
import { Module } from '@nestjs/common';
import { DocModule } from '../doc';
import { UsersService } from '../users';
import { WorkspacesController } from './controller';
import { PermissionService } from './permission';
import { WorkspaceResolver } from './resolver';
@Module({
imports: [DocModule.forFeature()],
controllers: [WorkspacesController],
providers: [WorkspaceResolver, PermissionService, UsersService],
exports: [PermissionService],
})
export class WorkspaceModule {}
export { InvitationType, WorkspaceType } from './resolver';

View File

@@ -0,0 +1,298 @@
import { ForbiddenException, Injectable } from '@nestjs/common';
import { Prisma } from '@prisma/client';
import { PrismaService } from '../../prisma';
import { Permission } from './types';
@Injectable()
export class PermissionService {
constructor(private readonly prisma: PrismaService) {}
async get(ws: string, user: string) {
const data = await this.prisma.userWorkspacePermission.findFirst({
where: {
workspaceId: ws,
subPageId: null,
userId: user,
accepted: true,
},
});
return data?.type as Permission;
}
async getWorkspaceOwner(workspaceId: string) {
return this.prisma.userWorkspacePermission.findFirstOrThrow({
where: {
workspaceId,
type: Permission.Owner,
},
include: {
user: true,
},
});
}
async tryGetWorkspaceOwner(workspaceId: string) {
return this.prisma.userWorkspacePermission.findFirst({
where: {
workspaceId,
type: Permission.Owner,
},
include: {
user: true,
},
});
}
async isAccessible(ws: string, id: string, user?: string): Promise<boolean> {
if (user) {
const hasPermission = await this.tryCheck(ws, user);
if (hasPermission) return true;
}
// check if this is a public workspace
const count = await this.prisma.workspace.count({
where: { id: ws, public: true },
});
if (count > 0) {
return true;
}
// check whether this is a public subpage
const workspace = await this.prisma.userWorkspacePermission.findMany({
where: {
workspaceId: ws,
userId: null,
},
});
const subpages = workspace
.map(ws => ws.subPageId)
.filter((v): v is string => !!v);
if (subpages.length > 0 && ws === id) {
// rootDoc is always accessible when there is a public subpage
return true;
} else {
// check if this is a public subpage
return subpages.some(subpage => id === subpage);
}
}
async check(
ws: string,
user: string,
permission: Permission = Permission.Read
) {
if (!(await this.tryCheck(ws, user, permission))) {
throw new ForbiddenException('Permission denied');
}
}
async tryCheck(
ws: string,
user: string,
permission: Permission = Permission.Read
) {
// If the permission is read, we should check if the workspace is public
if (permission === Permission.Read) {
const data = await this.prisma.workspace.count({
where: { id: ws, public: true },
});
if (data > 0) {
return true;
}
}
const data = await this.prisma.userWorkspacePermission.count({
where: {
workspaceId: ws,
subPageId: null,
userId: user,
accepted: true,
type: {
gte: permission,
},
},
});
return data > 0;
}
async grant(
ws: string,
user: string,
permission: Permission = Permission.Read
): Promise<string> {
const data = await this.prisma.userWorkspacePermission.findFirst({
where: {
workspaceId: ws,
subPageId: null,
userId: user,
accepted: true,
},
});
if (data) {
const [p] = await this.prisma.$transaction(
[
this.prisma.userWorkspacePermission.update({
where: {
id: data.id,
},
data: {
type: permission,
},
}),
// If the new permission is owner, we need to revoke old owner
permission === Permission.Owner
? this.prisma.userWorkspacePermission.updateMany({
where: {
workspaceId: ws,
type: Permission.Owner,
userId: {
not: user,
},
},
data: {
type: Permission.Admin,
},
})
: null,
].filter(Boolean) as Prisma.PrismaPromise<any>[]
);
return p.id;
}
return this.prisma.userWorkspacePermission
.create({
data: {
workspaceId: ws,
subPageId: null,
userId: user,
type: permission,
},
})
.then(p => p.id);
}
async getInvitationById(inviteId: string, workspaceId: string) {
return this.prisma.userWorkspacePermission.findUniqueOrThrow({
where: {
id: inviteId,
workspaceId,
},
include: {
user: true,
},
});
}
async acceptById(ws: string, id: string) {
const result = await this.prisma.userWorkspacePermission.updateMany({
where: {
id,
workspaceId: ws,
},
data: {
accepted: true,
},
});
return result.count > 0;
}
async accept(ws: string, user: string) {
const result = await this.prisma.userWorkspacePermission.updateMany({
where: {
workspaceId: ws,
subPageId: null,
userId: user,
accepted: false,
},
data: {
accepted: true,
},
});
return result.count > 0;
}
async revoke(ws: string, user: string) {
const result = await this.prisma.userWorkspacePermission.deleteMany({
where: {
workspaceId: ws,
subPageId: null,
userId: user,
type: {
// We shouldn't revoke owner permission, should auto deleted by workspace/user delete cascading
not: Permission.Owner,
},
},
});
return result.count > 0;
}
async isPageAccessible(ws: string, page: string, user?: string) {
const data = await this.prisma.userWorkspacePermission.findFirst({
where: {
workspaceId: ws,
subPageId: page,
userId: user,
},
});
return data?.accepted || false;
}
async grantPage(
ws: string,
page: string,
user?: string,
permission: Permission = Permission.Read
) {
const data = await this.prisma.userWorkspacePermission.findFirst({
where: {
workspaceId: ws,
subPageId: page,
userId: user,
},
});
if (data) {
return data.accepted;
}
return this.prisma.userWorkspacePermission
.create({
data: {
workspaceId: ws,
subPageId: page,
userId: user,
// if provide user id, user need to accept the invitation
accepted: user ? false : true,
type: permission,
},
})
.then(ret => ret.accepted);
}
async revokePage(ws: string, page: string, user?: string) {
const result = await this.prisma.userWorkspacePermission.deleteMany({
where: {
workspaceId: ws,
subPageId: page,
userId: user,
type: {
// We shouldn't revoke owner permission, should auto deleted by workspace/user delete cascading
not: Permission.Owner,
},
},
});
return result.count > 0;
}
}

View File

@@ -0,0 +1,838 @@
import type { Storage } from '@affine/storage';
import {
ForbiddenException,
Inject,
InternalServerErrorException,
Logger,
NotFoundException,
UseGuards,
} from '@nestjs/common';
import {
Args,
Field,
Float,
ID,
InputType,
Int,
Mutation,
ObjectType,
OmitType,
Parent,
PartialType,
PickType,
Query,
registerEnumType,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import type { User, Workspace } from '@prisma/client';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import { applyUpdate, Doc } from 'yjs';
import { PrismaService } from '../../prisma';
import { StorageProvide } from '../../storage';
import { CloudThrottlerGuard, Throttle } from '../../throttler';
import type { FileUpload } from '../../types';
import { DocID } from '../../utils/doc';
import { Auth, CurrentUser, Public } from '../auth';
import { MailService } from '../auth/mailer';
import { AuthService } from '../auth/service';
import { UsersService } from '../users';
import { UserType } from '../users/resolver';
import { PermissionService } from './permission';
import { Permission } from './types';
import { defaultWorkspaceAvatar } from './utils';
registerEnumType(Permission, {
name: 'Permission',
description: 'User permission in workspace',
});
@ObjectType()
export class InviteUserType extends OmitType(
PartialType(UserType),
['id'],
ObjectType
) {
@Field(() => ID)
id!: string;
@Field(() => Permission, { description: 'User permission in workspace' })
permission!: Permission;
@Field({ description: 'Invite id' })
inviteId!: string;
@Field({ description: 'User accepted' })
accepted!: boolean;
}
@ObjectType()
export class WorkspaceType implements Partial<Workspace> {
@Field(() => ID)
id!: string;
@Field({ description: 'is Public workspace' })
public!: boolean;
@Field({ description: 'Workspace created date' })
createdAt!: Date;
@Field(() => [InviteUserType], {
description: 'Members of workspace',
})
members!: InviteUserType[];
}
@ObjectType()
export class InvitationWorkspaceType {
@Field(() => ID)
id!: string;
@Field({ description: 'Workspace name' })
name!: string;
@Field(() => String, {
// nullable: true,
description: 'Base64 encoded avatar',
})
avatar!: string;
}
@ObjectType()
export class WorkspaceBlobSizes {
@Field(() => Float)
size!: number;
}
@ObjectType()
export class InvitationType {
@Field({ description: 'Workspace information' })
workspace!: InvitationWorkspaceType;
@Field({ description: 'User information' })
user!: UserType;
@Field({ description: 'Invitee information' })
invitee!: UserType;
}
@InputType()
export class UpdateWorkspaceInput extends PickType(
PartialType(WorkspaceType),
['public'],
InputType
) {
@Field(() => ID)
id!: string;
}
/**
* Workspace resolver
* Public apis rate limit: 10 req/m
* Other rate limit: 120 req/m
*/
@UseGuards(CloudThrottlerGuard)
@Auth()
@Resolver(() => WorkspaceType)
export class WorkspaceResolver {
private readonly logger = new Logger('WorkspaceResolver');
constructor(
private readonly auth: AuthService,
private readonly mailer: MailService,
private readonly prisma: PrismaService,
private readonly permissions: PermissionService,
private readonly users: UsersService,
@Inject(StorageProvide) private readonly storage: Storage
) {}
@ResolveField(() => Permission, {
description: 'Permission of current signed in user in workspace',
complexity: 2,
})
async permission(
@CurrentUser() user: UserType,
@Parent() workspace: WorkspaceType
) {
// may applied in workspaces query
if ('permission' in workspace) {
return workspace.permission;
}
const permission = await this.permissions.get(workspace.id, user.id);
if (!permission) {
throw new ForbiddenException();
}
return permission;
}
@ResolveField(() => Int, {
description: 'member count of workspace',
complexity: 2,
})
memberCount(@Parent() workspace: WorkspaceType) {
return this.prisma.userWorkspacePermission.count({
where: {
workspaceId: workspace.id,
userId: {
not: null,
},
},
});
}
@ResolveField(() => [String], {
description: 'Shared pages of workspace',
complexity: 2,
})
async sharedPages(@Parent() workspace: WorkspaceType) {
const data = await this.prisma.userWorkspacePermission.findMany({
where: {
workspaceId: workspace.id,
},
});
return data.map(item => item.subPageId).filter(Boolean);
}
@ResolveField(() => UserType, {
description: 'Owner of workspace',
complexity: 2,
})
async owner(@Parent() workspace: WorkspaceType) {
const data = await this.permissions.getWorkspaceOwner(workspace.id);
return data.user;
}
@ResolveField(() => [InviteUserType], {
description: 'Members of workspace',
complexity: 2,
})
async members(
@Parent() workspace: WorkspaceType,
@Args('skip', { type: () => Int, nullable: true }) skip?: number,
@Args('take', { type: () => Int, nullable: true }) take?: number
) {
const data = await this.prisma.userWorkspacePermission.findMany({
where: {
workspaceId: workspace.id,
userId: {
not: null,
},
},
skip,
take: take || 8,
orderBy: [
{
createdAt: 'asc',
},
{
type: 'desc',
},
],
include: {
user: true,
},
});
return data
.filter(({ user }) => !!user)
.map(({ id, accepted, type, user }) => ({
...user,
permission: type,
inviteId: id,
accepted,
}));
}
@Query(() => Boolean, {
description: 'Get is owner of workspace',
complexity: 2,
})
async isOwner(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string
) {
const data = await this.permissions.tryGetWorkspaceOwner(workspaceId);
return data?.user?.id === user.id;
}
@Query(() => [WorkspaceType], {
description: 'Get all accessible workspaces for current user',
complexity: 2,
})
async workspaces(@CurrentUser() user: User) {
const data = await this.prisma.userWorkspacePermission.findMany({
where: {
userId: user.id,
accepted: true,
},
include: {
workspace: true,
},
});
return data.map(({ workspace, type }) => {
return {
...workspace,
permission: type,
};
});
}
@Throttle({
default: {
limit: 10,
ttl: 30,
},
})
@Public()
@Query(() => WorkspaceType, {
description: 'Get public workspace by id',
})
async publicWorkspace(@Args('id') id: string) {
const workspace = await this.prisma.workspace.findUnique({
where: { id },
});
if (workspace?.public) {
return workspace;
}
throw new NotFoundException("Workspace doesn't exist");
}
@Query(() => WorkspaceType, {
description: 'Get workspace by id',
})
async workspace(@CurrentUser() user: UserType, @Args('id') id: string) {
await this.permissions.check(id, user.id);
const workspace = await this.prisma.workspace.findUnique({ where: { id } });
if (!workspace) {
throw new NotFoundException("Workspace doesn't exist");
}
return workspace;
}
@Mutation(() => WorkspaceType, {
description: 'Create a new workspace',
})
async createWorkspace(
@CurrentUser() user: UserType,
@Args({ name: 'init', type: () => GraphQLUpload })
update: FileUpload
) {
// convert stream to buffer
const buffer = await new Promise<Buffer>((resolve, reject) => {
const stream = update.createReadStream();
const chunks: Uint8Array[] = [];
stream.on('data', chunk => {
chunks.push(chunk);
});
stream.on('error', reject);
stream.on('end', () => {
resolve(Buffer.concat(chunks));
});
});
const workspace = await this.prisma.workspace.create({
data: {
public: false,
users: {
create: {
type: Permission.Owner,
user: {
connect: {
id: user.id,
},
},
accepted: true,
},
},
},
});
if (buffer.length) {
await this.prisma.snapshot.create({
data: {
id: workspace.id,
workspaceId: workspace.id,
blob: buffer,
},
});
}
return workspace;
}
@Mutation(() => WorkspaceType, {
description: 'Update workspace',
})
async updateWorkspace(
@CurrentUser() user: UserType,
@Args({ name: 'input', type: () => UpdateWorkspaceInput })
{ id, ...updates }: UpdateWorkspaceInput
) {
await this.permissions.check(id, user.id, Permission.Admin);
return this.prisma.workspace.update({
where: {
id,
},
data: updates,
});
}
@Mutation(() => Boolean)
async deleteWorkspace(@CurrentUser() user: UserType, @Args('id') id: string) {
await this.permissions.check(id, user.id, Permission.Owner);
await this.prisma.workspace.delete({
where: {
id,
},
});
await this.prisma.$transaction([
this.prisma.update.deleteMany({
where: {
workspaceId: id,
},
}),
this.prisma.snapshot.deleteMany({
where: {
workspaceId: id,
},
}),
]);
return true;
}
@Mutation(() => String)
async invite(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string,
@Args('email') email: string,
@Args('permission', { type: () => Permission }) permission: Permission,
@Args('sendInviteMail', { nullable: true }) sendInviteMail: boolean
) {
await this.permissions.check(workspaceId, user.id, Permission.Admin);
if (permission === Permission.Owner) {
throw new ForbiddenException('Cannot change owner');
}
const target = await this.users.findUserByEmail(email);
if (target) {
const originRecord = await this.prisma.userWorkspacePermission.findFirst({
where: {
workspaceId,
userId: target.id,
},
});
if (originRecord) {
return originRecord.id;
}
const inviteId = await this.permissions.grant(
workspaceId,
target.id,
permission
);
if (sendInviteMail) {
const inviteInfo = await this.getInviteInfo(inviteId);
try {
await this.mailer.sendInviteEmail(email, inviteId, {
workspace: {
id: inviteInfo.workspace.id,
name: inviteInfo.workspace.name,
avatar: inviteInfo.workspace.avatar,
},
user: {
avatar: inviteInfo.user?.avatarUrl || '',
name: inviteInfo.user?.name || '',
},
});
} catch (e) {
const ret = await this.permissions.revoke(workspaceId, target.id);
if (!ret) {
this.logger.fatal(
`failed to send ${workspaceId} invite email to ${email} and failed to revoke permission: ${inviteId}, ${e}`
);
} else {
this.logger.warn(
`failed to send ${workspaceId} invite email to ${email}, but successfully revoked permission: ${e}`
);
}
return new InternalServerErrorException(e);
}
}
return inviteId;
} else {
const user = await this.auth.createAnonymousUser(email);
const inviteId = await this.permissions.grant(
workspaceId,
user.id,
permission
);
if (sendInviteMail) {
const inviteInfo = await this.getInviteInfo(inviteId);
try {
await this.mailer.sendInviteEmail(email, inviteId, {
workspace: {
id: inviteInfo.workspace.id,
name: inviteInfo.workspace.name,
avatar: inviteInfo.workspace.avatar,
},
user: {
avatar: inviteInfo.user?.avatarUrl || '',
name: inviteInfo.user?.name || '',
},
});
} catch (e) {
const ret = await this.permissions.revoke(workspaceId, user.id);
if (!ret) {
this.logger.fatal(
`failed to send ${workspaceId} invite email to ${email} and failed to revoke permission: ${inviteId}, ${e}`
);
} else {
this.logger.warn(
`failed to send ${workspaceId} invite email to ${email}, but successfully revoked permission: ${e}`
);
}
return new InternalServerErrorException(e);
}
}
return inviteId;
}
}
@Throttle({
default: {
limit: 10,
ttl: 30,
},
})
@Public()
@Query(() => InvitationType, {
description: 'Update workspace',
})
async getInviteInfo(@Args('inviteId') inviteId: string) {
const workspaceId = await this.prisma.userWorkspacePermission
.findUniqueOrThrow({
where: {
id: inviteId,
},
select: {
workspaceId: true,
},
})
.then(({ workspaceId }) => workspaceId);
const snapshot = await this.prisma.snapshot.findFirstOrThrow({
where: {
id: workspaceId,
workspaceId,
},
});
const doc = new Doc();
applyUpdate(doc, new Uint8Array(snapshot.blob));
const metaJSON = doc.getMap('meta').toJSON();
const owner = await this.permissions.getWorkspaceOwner(workspaceId);
const invitee = await this.permissions.getInvitationById(
inviteId,
workspaceId
);
let avatar = '';
if (metaJSON.avatar) {
const avatarBlob = await this.storage.getBlob(
workspaceId,
metaJSON.avatar
);
avatar = avatarBlob?.data.toString('base64') || '';
}
return {
workspace: {
name: metaJSON.name || '',
avatar: avatar || defaultWorkspaceAvatar,
id: workspaceId,
},
user: owner.user,
invitee: invitee.user,
};
}
@Mutation(() => Boolean)
async revoke(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string,
@Args('userId') userId: string
) {
await this.permissions.check(workspaceId, user.id, Permission.Admin);
return this.permissions.revoke(workspaceId, userId);
}
@Mutation(() => Boolean)
@Public()
async acceptInviteById(
@Args('workspaceId') workspaceId: string,
@Args('inviteId') inviteId: string,
@Args('sendAcceptMail', { nullable: true }) sendAcceptMail: boolean
) {
const {
invitee,
user: inviter,
workspace,
} = await this.getInviteInfo(inviteId);
if (!inviter || !invitee) {
throw new ForbiddenException(
`can not find inviter/invitee by inviteId: ${inviteId}`
);
}
if (sendAcceptMail) {
await this.mailer.sendAcceptedEmail(inviter.email, {
inviteeName: invitee.name,
workspaceName: workspace.name,
});
}
return this.permissions.acceptById(workspaceId, inviteId);
}
@Mutation(() => Boolean)
async acceptInvite(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string
) {
return this.permissions.accept(workspaceId, user.id);
}
@Mutation(() => Boolean)
async leaveWorkspace(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string,
@Args('workspaceName') workspaceName: string,
@Args('sendLeaveMail', { nullable: true }) sendLeaveMail: boolean
) {
await this.permissions.check(workspaceId, user.id);
const owner = await this.permissions.getWorkspaceOwner(workspaceId);
if (!owner.user) {
throw new ForbiddenException(
`can not find owner by workspaceId: ${workspaceId}`
);
}
if (sendLeaveMail) {
await this.mailer.sendLeaveWorkspaceEmail(owner.user.email, {
workspaceName,
inviteeName: user.name,
});
}
return this.permissions.revoke(workspaceId, user.id);
}
@Mutation(() => Boolean)
async sharePage(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string,
@Args('pageId') pageId: string
) {
const docId = new DocID(pageId, workspaceId);
if (docId.isWorkspace) {
throw new ForbiddenException('Expect page not to be workspace');
}
const userWorkspace = await this.prisma.userWorkspacePermission.findFirst({
where: {
userId: user.id,
workspaceId: docId.workspace,
},
});
if (!userWorkspace?.accepted) {
throw new ForbiddenException('Permission denied');
}
return this.permissions.grantPage(docId.workspace, docId.guid);
}
@Mutation(() => Boolean)
async revokePage(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string,
@Args('pageId') pageId: string
) {
const docId = new DocID(pageId, workspaceId);
if (docId.isWorkspace) {
throw new ForbiddenException('Expect page not to be workspace');
}
await this.permissions.check(docId.workspace, user.id, Permission.Admin);
return this.permissions.revokePage(docId.workspace, docId.guid);
}
@Query(() => [String], {
description: 'List blobs of workspace',
})
async listBlobs(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string
) {
await this.permissions.check(workspaceId, user.id);
return this.storage.listBlobs(workspaceId);
}
@Query(() => WorkspaceBlobSizes)
async collectBlobSizes(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string
) {
await this.permissions.check(workspaceId, user.id);
return this.storage.blobsSize([workspaceId]).then(size => ({ size }));
}
@Query(() => WorkspaceBlobSizes)
async collectAllBlobSizes(@CurrentUser() user: UserType) {
const workspaces = await this.prisma.userWorkspacePermission
.findMany({
where: {
userId: user.id,
accepted: true,
type: Permission.Owner,
},
select: {
workspace: {
select: {
id: true,
},
},
},
})
.then(data => data.map(({ workspace }) => workspace.id));
const size = await this.storage.blobsSize(workspaces);
return { size };
}
@Query(() => WorkspaceBlobSizes)
async checkBlobSize(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string,
@Args('size', { type: () => Float }) size: number
) {
const canWrite = await this.permissions.tryCheck(
workspaceId,
user.id,
Permission.Write
);
if (canWrite) {
const { user } = await this.permissions.getWorkspaceOwner(workspaceId);
if (user) {
const quota = await this.users.getStorageQuotaById(user.id);
const { size: currentSize } = await this.collectAllBlobSizes(user);
return { size: quota - (size + currentSize) };
}
}
return false;
}
@Mutation(() => String)
async setBlob(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string,
@Args({ name: 'blob', type: () => GraphQLUpload })
blob: FileUpload
) {
await this.permissions.check(workspaceId, user.id, Permission.Write);
// quota was apply to owner's account
const { user: owner } =
await this.permissions.getWorkspaceOwner(workspaceId);
if (!owner) return new NotFoundException('Workspace owner not found');
const quota = await this.users.getStorageQuotaById(owner.id);
const { size } = await this.collectAllBlobSizes(owner);
const checkExceeded = (recvSize: number) => {
if (size + recvSize > quota) {
this.logger.log(
`storage size limit exceeded: ${size + recvSize} > ${quota}`
);
return true;
} else {
return false;
}
};
if (checkExceeded(0)) {
throw new ForbiddenException('storage size limit exceeded');
}
const buffer = await new Promise<Buffer>((resolve, reject) => {
const stream = blob.createReadStream();
const chunks: Uint8Array[] = [];
stream.on('data', chunk => {
chunks.push(chunk);
// check size after receive each chunk to avoid unnecessary memory usage
const bufferSize = chunks.reduce((acc, cur) => acc + cur.length, 0);
if (checkExceeded(bufferSize)) {
reject(new ForbiddenException('storage size limit exceeded'));
}
});
stream.on('error', reject);
stream.on('end', () => {
const buffer = Buffer.concat(chunks);
if (checkExceeded(buffer.length)) {
reject(new ForbiddenException('storage size limit exceeded'));
} else {
resolve(buffer);
}
});
});
return this.storage.uploadBlob(workspaceId, buffer);
}
@Mutation(() => Boolean)
async deleteBlob(
@CurrentUser() user: UserType,
@Args('workspaceId') workspaceId: string,
@Args('hash') hash: string
) {
await this.permissions.check(workspaceId, user.id);
return this.storage.deleteBlob(workspaceId, hash);
}
}

View File

@@ -0,0 +1,6 @@
export enum Permission {
Read = 0,
Write = 1,
Admin = 10,
Owner = 99,
}

View File

@@ -0,0 +1,2 @@
export const defaultWorkspaceAvatar =
'iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAQtSURBVHgBfVa9jhxFEK6q7rkf+4T2AgdIIC0ZoXkBuNQJtngBuIzs1hIRye1FhL438D0CRgKRGUeE6wwkhHYlkE2AtGdkbN/MdJe/qu7Z27PWnnG5Znq7v/rqd47pHddkNh/918tR1/FBamXc9zxOPVFKfJ4yP86qD1LD3/986/3F2zB40+LXv83HrHq/6+gAoNS1kF4odUz2nhJRTkI5E6mD6Bk1crLJkLy5cHc+P4ohzxLng8RKLqKUq6hkUtBSe8Zvdmfir7TT2a0fnkzeaeCbv/44ztSfZskjP2ygVRM0mbYTpgHMMMS8CsIIj/c+//Hp8UYD3z758whQUwdeEwPjAZQLqJhI0VxB2MVco+kXP/0zuZKD6dP5uM397ELzqEtMba/UJ4t7iXeq8U94z52Q+js09qjlIXMxAEsRDJpI59dVPzlDTooHko7BdlR2FcYmAtbGMmAt2mFI4yDQkIjtEQkxUAMKAPD9SiOK4b578N0S7Nt+fqFKbTbmRD1YGXurEmdtnjjz4kFuIV0gtWewV62hMHBY2gpEOw3Rnmztx9jnO72xzTV/YkzgNmgkiypeYJdCLjonqyAAg7VCshVpjTbD08HbxrySdhKxcDvoJTA5gLvpeXVQ+K340WKea9UkNeZVqGSba/IbF6athj+LUeRmRCyiAVnlAKhJJQfmugGZ28ZWna24RGzwNUNUqpWGf6HkajvAgNA4NsSjHgcb9obx+k5c3DUttcwd3NcHxpVurXQ2d4MZACGw9TwEHsdtbEwytL1xywAGcxavjoH1quLVywuGi+aBhFWexRilFSwK0QzgdUdkkVMeKw4wijrgxjzz2CefCRZn+21ViOWW4Ym9nNnyFLMbMS8ivNhGP8RdlgUojBkuBLDpEPi+5LpWiDURgFkKOIIckJTgN/sZ84KtKkKpDnsOZiTQ47jD4ZGwHghbw6AXIL3lo5Zg6Tp2AwIAyYJ8BRzGfmfPl6kI7HOLUdN2LIg+4IfL5SiFdvkK4blI6h50qda7jQI0CUMLdEhFIkqtQciMvXsgpaZ1pWtVUfrIa+TX5/8+RBcftAhTa91r8ycXA5ZxBqhAh2zgVagUAddxMkxfF/JxfvbpB+8d2jhBtsPhtuqsE0HJlhxYeHKdkCU8xUCos8dmkDdnGaOlJ1yy9dM52J2spqldvz9fTgB4z+aQd2kqjUY2KU2s4dTT7ezD0AqDAbvZiKF/VO9+fGPv9IoBu+b/P5ti6djDY+JlSg4ug1jc6fJbMAx9/3b4CNGTD/evT698D9avv188m4gKvko8MiMeJC3jmOvU9MSuHXZohAVpOrmxd+10HW/jR3/58uU45TRFt35ZR2XpY61DzW+tH3z/7xdM8sP93d3Fm1gbDawbEtU7CMtt/JVxEw01Kh7RAmoBE4+u7eycYv38bRivAZbdHBtPrwOHAAAAAElFTkSuQmCC';

View File

@@ -0,0 +1,10 @@
import 'reflect-metadata';
import 'dotenv/config';
import { getDefaultAFFiNEConfig } from './config/default';
globalThis.AFFiNE = getDefaultAFFiNEConfig();
if (process.env.NODE_ENV === 'development') {
console.log('AFFiNE Config:', globalThis.AFFiNE);
}

View File

@@ -0,0 +1,11 @@
import { Global, Module } from '@nestjs/common';
import { PrismaService } from './service';
@Global()
@Module({
providers: [PrismaService],
exports: [PrismaService],
})
export class PrismaModule {}
export { PrismaService } from './service';

View File

@@ -0,0 +1,17 @@
import type { OnModuleDestroy, OnModuleInit } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
@Injectable()
export class PrismaService
extends PrismaClient
implements OnModuleInit, OnModuleDestroy
{
async onModuleInit() {
await this.$connect();
}
async onModuleDestroy(): Promise<void> {
await this.$disconnect();
}
}

View File

@@ -0,0 +1,217 @@
# ------------------------------------------------------
# THIS FILE WAS AUTOMATICALLY GENERATED (DO NOT MODIFY)
# ------------------------------------------------------
type UserType {
id: ID!
"""User name"""
name: String!
"""User email"""
email: String!
"""User avatar url"""
avatarUrl: String
"""User email verified"""
emailVerified: DateTime
"""User created date"""
createdAt: DateTime
"""User password has been set"""
hasPassword: Boolean
token: TokenType!
}
"""
A date-time string at UTC, such as 2019-12-03T09:54:33Z, compliant with the date-time format.
"""
scalar DateTime
type DeleteAccount {
success: Boolean!
}
type RemoveAvatar {
success: Boolean!
}
type AddToNewFeaturesWaitingList {
email: String!
"""New features kind"""
type: NewFeaturesKind!
}
enum NewFeaturesKind {
EarlyAccess
}
type TokenType {
token: String!
refresh: String!
sessionToken: String
}
type InviteUserType {
"""User name"""
name: String
"""User email"""
email: String
"""User avatar url"""
avatarUrl: String
"""User email verified"""
emailVerified: DateTime
"""User created date"""
createdAt: DateTime
"""User password has been set"""
hasPassword: Boolean
id: ID!
"""User permission in workspace"""
permission: Permission!
"""Invite id"""
inviteId: String!
"""User accepted"""
accepted: Boolean!
}
"""User permission in workspace"""
enum Permission {
Read
Write
Admin
Owner
}
type WorkspaceType {
id: ID!
"""is Public workspace"""
public: Boolean!
"""Workspace created date"""
createdAt: DateTime!
"""Members of workspace"""
members(skip: Int, take: Int): [InviteUserType!]!
"""Permission of current signed in user in workspace"""
permission: Permission!
"""member count of workspace"""
memberCount: Int!
"""Shared pages of workspace"""
sharedPages: [String!]!
"""Owner of workspace"""
owner: UserType!
}
type InvitationWorkspaceType {
id: ID!
"""Workspace name"""
name: String!
"""Base64 encoded avatar"""
avatar: String!
}
type WorkspaceBlobSizes {
size: Float!
}
type InvitationType {
"""Workspace information"""
workspace: InvitationWorkspaceType!
"""User information"""
user: UserType!
"""Invitee information"""
invitee: UserType!
}
type Query {
"""Get is owner of workspace"""
isOwner(workspaceId: String!): Boolean!
"""Get all accessible workspaces for current user"""
workspaces: [WorkspaceType!]!
"""Get public workspace by id"""
publicWorkspace(id: String!): WorkspaceType!
"""Get workspace by id"""
workspace(id: String!): WorkspaceType!
"""Update workspace"""
getInviteInfo(inviteId: String!): InvitationType!
"""List blobs of workspace"""
listBlobs(workspaceId: String!): [String!]!
collectBlobSizes(workspaceId: String!): WorkspaceBlobSizes!
collectAllBlobSizes: WorkspaceBlobSizes!
checkBlobSize(workspaceId: String!, size: Float!): WorkspaceBlobSizes!
"""Get current user"""
currentUser: UserType!
"""Get user by email"""
user(email: String!): UserType
}
type Mutation {
signUp(name: String!, email: String!, password: String!): UserType!
signIn(email: String!, password: String!): UserType!
changePassword(token: String!, newPassword: String!): UserType!
changeEmail(token: String!): UserType!
sendChangePasswordEmail(email: String!, callbackUrl: String!): Boolean!
sendSetPasswordEmail(email: String!, callbackUrl: String!): Boolean!
sendChangeEmail(email: String!, callbackUrl: String!): Boolean!
sendVerifyChangeEmail(token: String!, email: String!, callbackUrl: String!): Boolean!
"""Create a new workspace"""
createWorkspace(init: Upload!): WorkspaceType!
"""Update workspace"""
updateWorkspace(input: UpdateWorkspaceInput!): WorkspaceType!
deleteWorkspace(id: String!): Boolean!
invite(workspaceId: String!, email: String!, permission: Permission!, sendInviteMail: Boolean): String!
revoke(workspaceId: String!, userId: String!): Boolean!
acceptInviteById(workspaceId: String!, inviteId: String!, sendAcceptMail: Boolean): Boolean!
acceptInvite(workspaceId: String!): Boolean!
leaveWorkspace(workspaceId: String!, workspaceName: String!, sendLeaveMail: Boolean): Boolean!
sharePage(workspaceId: String!, pageId: String!): Boolean!
revokePage(workspaceId: String!, pageId: String!): Boolean!
setBlob(workspaceId: String!, blob: Upload!): String!
deleteBlob(workspaceId: String!, hash: String!): Boolean!
"""Upload user avatar"""
uploadAvatar(avatar: Upload!): UserType!
"""Remove user avatar"""
removeAvatar: RemoveAvatar!
deleteAccount: DeleteAccount!
addToNewFeaturesWaitingList(type: NewFeaturesKind!, email: String!): AddToNewFeaturesWaitingList!
}
"""The `Upload` scalar type represents a file upload."""
scalar Upload
input UpdateWorkspaceInput {
"""is Public workspace"""
public: Boolean
id: ID!
}

View File

@@ -0,0 +1,73 @@
import KeyvRedis from '@keyv/redis';
import {
FactoryProvider,
Global,
Inject,
Injectable,
Module,
} from '@nestjs/common';
import Redis from 'ioredis';
import Keyv from 'keyv';
import { Config } from './config';
export const KeyvProvide = Symbol('KeyvProvide');
export const KeyvProvider: FactoryProvider<Keyv> = {
provide: KeyvProvide,
useFactory(config: Config) {
if (config.redis.enabled) {
return new Keyv({
store: new KeyvRedis(
new Redis(config.redis.port, config.redis.host, {
username: config.redis.username,
password: config.redis.password,
db: config.redis.database + 2,
})
),
});
} else {
return new Keyv();
}
},
inject: [Config],
};
@Injectable()
export class SessionService {
private readonly prefix = 'session:';
private readonly sessionTtl = 30 * 60 * 1000; // 30 min
constructor(@Inject(KeyvProvide) private readonly cache: Keyv) {}
/**
* get session
* @param key session key
* @returns
*/
async get(key: string) {
return this.cache.get(this.prefix + key);
}
/**
* set session
* @param key session key
* @param value session value
* @param sessionTtl session ttl (ms), default 30 min
* @returns return true if success
*/
async set(key: string, value?: any, sessionTtl = this.sessionTtl) {
return this.cache.set(this.prefix + key, value, sessionTtl);
}
async delete(key: string) {
return this.cache.delete(this.prefix + key);
}
}
@Global()
@Module({
providers: [KeyvProvider, SessionService],
exports: [KeyvProvider, SessionService],
})
export class SessionModule {}

View File

@@ -0,0 +1,50 @@
import { createRequire } from 'node:module';
import { type DynamicModule, type FactoryProvider } from '@nestjs/common';
import { Config } from '../config';
export const StorageProvide = Symbol('Storage');
let storageModule: typeof import('@affine/storage');
try {
storageModule = await import('@affine/storage');
} catch {
const require = createRequire(import.meta.url);
storageModule = require('../../storage.node');
}
export class StorageModule {
static forRoot(): DynamicModule {
const storageProvider: FactoryProvider = {
provide: StorageProvide,
useFactory: async (config: Config) => {
return storageModule.Storage.connect(config.db.url);
},
inject: [Config],
};
return {
global: true,
module: StorageModule,
providers: [storageProvider],
exports: [storageProvider],
};
}
}
export const mergeUpdatesInApplyWay = storageModule.mergeUpdatesInApplyWay;
export const verifyChallengeResponse = async (
response: any,
bits: number,
resource: string
) => {
if (typeof response !== 'string' || !response || !resource) return false;
return storageModule.verifyChallengeResponse(response, bits, resource);
};
export const mintChallengeResponse = async (resource: string, bits: number) => {
if (!resource) return null;
return storageModule.mintChallengeResponse(resource, bits);
};

View File

@@ -0,0 +1,87 @@
import { ExecutionContext, Injectable, Logger } from '@nestjs/common';
import { Global, Module } from '@nestjs/common';
import {
Throttle,
ThrottlerGuard,
ThrottlerModule,
ThrottlerModuleOptions,
} from '@nestjs/throttler';
import Redis from 'ioredis';
import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis';
import { Config, ConfigModule } from './config';
import { getRequestResponseFromContext } from './utils/nestjs';
@Global()
@Module({
imports: [
ThrottlerModule.forRootAsync({
imports: [ConfigModule],
inject: [Config],
useFactory: (config: Config): ThrottlerModuleOptions => {
const options: ThrottlerModuleOptions = {
throttlers: [
{
ttl: config.rateLimiter.ttl,
limit: config.rateLimiter.limit,
},
],
skipIf: () => {
return !config.node.prod || config.affine.canary;
},
};
if (config.redis.enabled) {
new Logger(RateLimiterModule.name).log('Use Redis');
options.storage = new ThrottlerStorageRedisService(
new Redis(config.redis.port, config.redis.host, {
username: config.redis.username,
password: config.redis.password,
db: config.redis.database + 1,
})
);
}
return options;
},
}),
],
})
export class RateLimiterModule {}
@Injectable()
export class CloudThrottlerGuard extends ThrottlerGuard {
override getRequestResponse(context: ExecutionContext) {
return getRequestResponseFromContext(context) as any;
}
protected override getTracker(req: Record<string, any>): Promise<string> {
return Promise.resolve(
req?.get('CF-Connecting-IP') ?? req?.get('CF-ray') ?? req?.ip
);
}
}
@Injectable()
export class AuthThrottlerGuard extends CloudThrottlerGuard {
override async handleRequest(
context: ExecutionContext,
limit: number,
ttl: number
): Promise<boolean> {
const { req } = this.getRequestResponse(context);
if (req?.url === '/api/auth/session') {
// relax throttle for session auto renew
return super.handleRequest(context, limit * 20, ttl, {
ttl: ttl * 20,
limit: limit * 20,
});
}
return super.handleRequest(context, limit, ttl, {
ttl,
limit,
});
}
}
export { Throttle };

View File

@@ -0,0 +1,14 @@
import { Readable } from 'node:stream';
export interface FileUpload {
filename: string;
mimetype: string;
encoding: string;
createReadStream: () => Readable;
}
export interface ReqContext {
req: Express.Request & {
res: Express.Response;
};
}

View File

@@ -0,0 +1,72 @@
import test from 'ava';
import { DocID, DocVariant } from '../doc';
test('can parse', t => {
// workspace only
let id = new DocID('ws');
t.is(id.workspace, 'ws');
t.assert(id.isWorkspace);
// full id
id = new DocID('ws:space:sub');
t.is(id.workspace, 'ws');
t.is(id.variant, DocVariant.Space);
t.is(id.guid, 'sub');
// variant only
id = new DocID('space:sub', 'ws');
t.is(id.workspace, 'ws');
t.is(id.variant, DocVariant.Space);
t.is(id.guid, 'sub');
// sub id only
id = new DocID('sub', 'ws');
t.is(id.workspace, 'ws');
t.is(id.variant, DocVariant.Unknown);
t.is(id.guid, 'sub');
});
test('fail', t => {
t.throws(() => new DocID('a:b:c:d'), {
message: 'Invalid format of Doc ID: a:b:c:d',
});
t.throws(() => new DocID(':space:sub'), { message: 'Workspace is required' });
t.throws(() => new DocID('space:sub'), { message: 'Workspace is required' });
t.throws(() => new DocID('ws:any:sub'), {
message: 'Invalid ID variant: any',
});
t.throws(() => new DocID('ws:space:'), {
message: 'ID is required for non-workspace doc',
});
t.throws(() => new DocID('ws::space'), {
message: 'Variant is required for non-workspace doc',
});
});
test('fix', t => {
let id = new DocID('ws');
// can't fix because the doc variant is [Workspace]
id.fixWorkspace('ws2');
t.is(id.workspace, 'ws');
t.is(id.toString(), 'ws');
id = new DocID('ws:space:sub');
id.fixWorkspace('ws2');
t.is(id.workspace, 'ws2');
t.is(id.toString(), 'ws2:space:sub');
id = new DocID('space:sub', 'ws');
t.is(id.workspace, 'ws');
t.is(id.toString(), 'ws:space:sub');
id = new DocID('ws2:space:sub', 'ws');
t.is(id.workspace, 'ws');
t.is(id.toString(), 'ws:space:sub');
});

View File

@@ -0,0 +1,115 @@
import { registerEnumType } from '@nestjs/graphql';
export enum DocVariant {
Workspace = 'workspace',
Space = 'space',
Settings = 'settings',
Unknown = 'unknown',
}
registerEnumType(DocVariant, {
name: 'DocVariant',
});
export class DocID {
raw: string;
workspace: string;
variant: DocVariant;
private sub: string | null;
static parse(raw: string): DocID | null {
try {
return new DocID(raw);
} catch (e) {
return null;
}
}
/**
* pure guid for workspace and subdoc without any prefix
*/
get guid(): string {
return this.variant === DocVariant.Workspace
? this.workspace
: // sub is always truthy when variant is not workspace
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
this.sub!;
}
get full(): string {
return this.variant === DocVariant.Workspace
? this.workspace
: `${this.workspace}:${this.variant}:${this.sub}`;
}
get isWorkspace(): boolean {
return this.variant === DocVariant.Workspace;
}
constructor(raw: string, workspaceId?: string) {
if (!raw.length) {
throw new Error('Invalid Empty Doc ID');
}
let parts = raw.split(':');
if (parts.length > 3) {
throw new Error(`Invalid format of Doc ID: ${raw}`);
} else if (parts.length === 2) {
// `${variant}:${guid}`
if (!workspaceId) {
throw new Error('Workspace is required');
}
parts.unshift(workspaceId);
} else if (parts.length === 1) {
// ${ws} or ${pageId}
if (workspaceId && parts[0] !== workspaceId) {
parts = [workspaceId, DocVariant.Unknown, parts[0]];
} else {
// parts:[ws] equals [workspaceId]
}
}
let workspace = parts.at(0);
// fix for `${non-workspaceId}:${variant}:${guid}`
if (workspaceId) {
workspace = workspaceId;
}
const variant = parts.at(1);
const docId = parts.at(2);
if (!workspace) {
throw new Error('Workspace is required');
}
if (variant) {
if (!Object.values(DocVariant).includes(variant as any)) {
throw new Error(`Invalid ID variant: ${variant}`);
}
if (!docId) {
throw new Error('ID is required for non-workspace doc');
}
} else if (docId) {
throw new Error('Variant is required for non-workspace doc');
}
this.raw = raw;
this.workspace = workspace;
this.variant = (variant as DocVariant | undefined) ?? DocVariant.Workspace;
this.sub = docId || null;
}
toString() {
return this.full;
}
fixWorkspace(workspaceId: string) {
if (!this.isWorkspace && this.workspace !== workspaceId) {
this.workspace = workspaceId;
}
}
}

View File

@@ -0,0 +1,77 @@
import type { ArgumentsHost, ExecutionContext } from '@nestjs/common';
import type { GqlContextType } from '@nestjs/graphql';
import { GqlArgumentsHost, GqlExecutionContext } from '@nestjs/graphql';
import type { Request, Response } from 'express';
export function getRequestResponseFromContext(context: ExecutionContext) {
switch (context.getType<GqlContextType>()) {
case 'graphql': {
const gqlContext = GqlExecutionContext.create(context).getContext<{
req: Request;
}>();
return {
req: gqlContext.req,
res: gqlContext.req.res,
};
}
case 'http': {
const http = context.switchToHttp();
return {
req: http.getRequest<Request>(),
res: http.getResponse<Response>(),
};
}
case 'ws': {
const ws = context.switchToWs();
const req = ws.getClient().handshake;
const cookies = req?.headers?.cookie;
// patch cookies to match auth guard logic
if (typeof cookies === 'string') {
req.cookies = cookies
.split(';')
.map(v => v.split('='))
.reduce(
(acc, v) => {
acc[decodeURIComponent(v[0].trim())] = decodeURIComponent(
v[1].trim()
);
return acc;
},
{} as Record<string, string>
);
}
return { req };
}
default:
throw new Error('Unknown context type for getting request and response');
}
}
export function getRequestResponseFromHost(host: ArgumentsHost) {
switch (host.getType<GqlContextType>()) {
case 'graphql': {
const gqlContext = GqlArgumentsHost.create(host).getContext<{
req: Request;
}>();
return {
req: gqlContext.req,
res: gqlContext.req.res,
};
}
case 'http': {
const http = host.switchToHttp();
return {
req: http.getRequest<Request>(),
res: http.getResponse<Response>(),
};
}
default:
throw new Error('Unknown host type for getting request and response');
}
}
export function getRequestFromHost(host: ArgumentsHost) {
return getRequestResponseFromHost(host).req;
}

View File

@@ -0,0 +1,42 @@
export type DeepPartial<T> = T extends Array<infer U>
? DeepPartial<U>[]
: T extends ReadonlyArray<infer U>
? ReadonlyArray<DeepPartial<U>>
: T extends object
? {
[K in keyof T]?: DeepPartial<T[K]>;
}
: T;
type Join<Prefix, Suffixes> = Prefix extends string | number
? Suffixes extends string | number
? Prefix extends ''
? Suffixes
: `${Prefix}.${Suffixes}`
: never
: never;
export type PrimitiveType =
| string
| number
| boolean
| symbol
| null
| undefined;
export type LeafPaths<
T,
Path extends string = '',
MaxDepth extends string = '...',
Depth extends string = '',
> = Depth extends MaxDepth
? never
: T extends Record<string | number, any>
? {
[K in keyof T]-?: K extends string | number
? T[K] extends PrimitiveType
? K
: Join<K, LeafPaths<T[K], Path, MaxDepth, `${Depth}.`>>
: never;
}[keyof T]
: never;