refactor(server): auth (#5895)

Remove `next-auth` and implement our own Authorization/Authentication system from scratch.

## Server

- [x] tokens
  - [x] function
  - [x] encryption

- [x] AuthController
  - [x] /api/auth/sign-in
  - [x] /api/auth/sign-out
  - [x] /api/auth/session
  - [x] /api/auth/session (WE SUPPORT MULTI-ACCOUNT!)

- [x] OAuthPlugin
  - [x] OAuthController
  - [x] /oauth/login
  - [x] /oauth/callback
  - [x] Providers
    - [x] Google
    - [x] GitHub

## Client

- [x] useSession
- [x] cloudSignIn
- [x] cloudSignOut

## NOTE:

Tests will be adding in the future
This commit is contained in:
liuyi
2024-03-12 10:00:09 +00:00
parent af49e8cc41
commit fb3a0e7b8f
148 changed files with 3407 additions and 2851 deletions

View File

@@ -1,4 +1,5 @@
import { GCloudConfig } from './gcloud/config';
import { OAuthConfig } from './oauth';
import { PaymentConfig } from './payment';
import { RedisOptions } from './redis';
import { R2StorageConfig, S3StorageConfig } from './storage';
@@ -10,13 +11,14 @@ declare module '../fundamentals/config' {
readonly gcloud: GCloudConfig;
readonly 'cloudflare-r2': R2StorageConfig;
readonly 'aws-s3': S3StorageConfig;
readonly oauth: OAuthConfig;
}
export type AvailablePlugins = keyof PluginsConfig;
interface AFFiNEConfig {
readonly plugins: {
enabled: AvailablePlugins[];
enabled: Set<AvailablePlugins>;
use<Plugin extends AvailablePlugins>(
plugin: Plugin,
config?: DeepPartial<PluginsConfig[Plugin]>

View File

@@ -1,10 +1,11 @@
import { Global } from '@nestjs/common';
import { OptionalModule } from '../../fundamentals';
import { Plugin } from '../registry';
import { GCloudMetrics } from './metrics';
@Global()
@OptionalModule({
@Plugin({
name: 'gcloud',
imports: [GCloudMetrics],
})
export class GCloudModule {}

View File

@@ -1,13 +1,7 @@
import type { AvailablePlugins } from '../fundamentals/config';
import { GCloudModule } from './gcloud';
import { PaymentModule } from './payment';
import { RedisModule } from './redis';
import { AwsS3Module, CloudflareR2Module } from './storage';
import './gcloud';
import './oauth';
import './payment';
import './redis';
import './storage';
export const pluginsMap = new Map<AvailablePlugins, AFFiNEModule>([
['payment', PaymentModule],
['redis', RedisModule],
['gcloud', GCloudModule],
['cloudflare-r2', CloudflareR2Module],
['aws-s3', AwsS3Module],
]);
export { REGISTERED_PLUGINS } from './registry';

View File

@@ -0,0 +1,230 @@
import {
BadRequestException,
Controller,
Get,
Query,
Req,
Res,
} from '@nestjs/common';
import { ConnectedAccount, PrismaClient } from '@prisma/client';
import type { Request, Response } from 'express';
import { AuthService, Public } from '../../core/auth';
import { UserService } from '../../core/user';
import { URLHelper } from '../../fundamentals';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthProviderFactory } from './register';
import { OAuthService } from './service';
import { OAuthProviderName } from './types';
@Controller('/oauth')
export class OAuthController {
constructor(
private readonly auth: AuthService,
private readonly oauth: OAuthService,
private readonly user: UserService,
private readonly providerFactory: OAuthProviderFactory,
private readonly url: URLHelper,
private readonly db: PrismaClient
) {}
@Public()
@Get('/login')
async login(
@Res() res: Response,
@Query('provider') unknownProviderName: string,
@Query('redirect_uri') redirectUri?: string
) {
// @ts-expect-error safe
const providerName = OAuthProviderName[unknownProviderName];
const provider = this.providerFactory.get(providerName);
if (!provider) {
throw new BadRequestException('Invalid provider');
}
const state = await this.oauth.saveOAuthState({
redirectUri: redirectUri ?? this.url.home,
provider: providerName,
});
return res.redirect(provider.getAuthUrl(state));
}
@Public()
@Get('/callback')
async callback(
@Req() req: Request,
@Res() res: Response,
@Query('code') code?: string,
@Query('state') stateStr?: string
) {
if (!code) {
throw new BadRequestException('Missing query parameter `code`');
}
if (!stateStr) {
throw new BadRequestException('Invalid callback state parameter');
}
const state = await this.oauth.getOAuthState(stateStr);
if (!state) {
throw new BadRequestException('OAuth state expired, please try again.');
}
if (!state.provider) {
throw new BadRequestException(
'Missing callback state parameter `provider`'
);
}
const provider = this.providerFactory.get(state.provider);
if (!provider) {
throw new BadRequestException('Invalid provider');
}
const tokens = await provider.getToken(code);
const externAccount = await provider.getUser(tokens.accessToken);
const user = req.user;
try {
if (!user) {
// if user not found, login
const user = await this.loginFromOauth(
state.provider,
externAccount,
tokens
);
const session = await this.auth.createUserSession(
user,
req.cookies[AuthService.sessionCookieName]
);
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: session.expiresAt ?? void 0, // expiredAt is `string | null`
...this.auth.cookieOptions,
});
} else {
// if user is found, connect the account to this user
await this.connectAccountFromOauth(
user,
state.provider,
externAccount,
tokens
);
}
} catch (e: any) {
return res.redirect(
this.url.link('/signIn', {
redirect_uri: state.redirectUri,
error: e.message,
})
);
}
this.url.safeRedirect(res, state.redirectUri);
}
private async loginFromOauth(
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedUser = await this.db.connectedAccount.findFirst({
where: {
provider,
providerAccountId: externalAccount.id,
},
include: {
user: true,
},
});
if (connectedUser) {
// already connected
await this.updateConnectedAccount(connectedUser, tokens);
return connectedUser.user;
}
let user = await this.user.findUserByEmail(externalAccount.email);
if (user) {
// we can't directly connect the external account with given email in sign in scenario for safety concern.
// let user manually connect in account sessions instead.
throw new BadRequestException(
'The account with provided email is not register in the same way.'
);
} else {
user = await this.createUserWithConnectedAccount(
provider,
externalAccount,
tokens
);
}
return user;
}
updateConnectedAccount(connectedUser: ConnectedAccount, tokens: Tokens) {
return this.db.connectedAccount.update({
where: {
id: connectedUser.id,
},
data: tokens,
});
}
async createUserWithConnectedAccount(
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
return this.user.createUser({
email: externalAccount.email,
name: 'Unnamed',
avatarUrl: externalAccount.avatarUrl,
emailVerifiedAt: new Date(),
connectedAccounts: {
create: {
provider,
providerAccountId: externalAccount.id,
...tokens,
},
},
});
}
private async connectAccountFromOauth(
user: { id: string },
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedUser = await this.db.connectedAccount.findFirst({
where: {
provider,
providerAccountId: externalAccount.id,
},
});
if (connectedUser) {
if (connectedUser.id !== user.id) {
throw new BadRequestException(
'The third-party account has already been connected to another user.'
);
}
} else {
await this.db.connectedAccount.create({
data: {
userId: user.id,
provider,
providerAccountId: externalAccount.id,
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
},
});
}
}
}

View File

@@ -0,0 +1,25 @@
import { AuthModule } from '../../core/auth';
import { ServerFeature } from '../../core/config';
import { UserModule } from '../../core/user';
import { Plugin } from '../registry';
import { OAuthController } from './controller';
import { OAuthProviders } from './providers';
import { OAuthProviderFactory } from './register';
import { OAuthResolver } from './resolver';
import { OAuthService } from './service';
@Plugin({
name: 'oauth',
imports: [AuthModule, UserModule],
providers: [
OAuthProviderFactory,
OAuthService,
OAuthResolver,
...OAuthProviders,
],
controllers: [OAuthController],
contributesTo: ServerFeature.OAuth,
if: config => !!config.plugins.oauth,
})
export class OAuthModule {}
export type { OAuthConfig } from './types';

View File

@@ -0,0 +1,21 @@
import { OAuthProviderName } from '../types';
export interface OAuthAccount {
id: string;
email: string;
avatarUrl?: string;
}
export interface Tokens {
accessToken: string;
scope?: string;
refreshToken?: string;
expiresAt?: Date;
}
export abstract class OAuthProvider {
abstract provider: OAuthProviderName;
abstract getAuthUrl(state?: string): string;
abstract getToken(code: string): Promise<Tokens>;
abstract getUser(token: string): Promise<OAuthAccount>;
}

View File

@@ -0,0 +1,113 @@
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
import { Config, URLHelper } from '../../../fundamentals';
import { AutoRegisteredOAuthProvider } from '../register';
import { OAuthProviderName } from '../types';
interface AuthTokenResponse {
access_token: string;
scope: string;
token_type: string;
}
export interface UserInfo {
login: string;
email: string;
avatar_url: string;
name: string;
}
@Injectable()
export class GithubOAuthProvider extends AutoRegisteredOAuthProvider {
provider = OAuthProviderName.GitHub;
constructor(
protected readonly AFFiNEConfig: Config,
private readonly url: URLHelper
) {
super();
}
getAuthUrl(state: string) {
return `https://github.com/login/oauth/authorize?${this.url.stringify({
client_id: this.config.clientId,
redirect_uri: this.url.link('/oauth/callback'),
scope: 'user',
...this.config.args,
state,
})}`;
}
async getToken(code: string) {
try {
const response = await fetch(
'https://github.com/login/oauth/access_token',
{
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
}
);
if (response.ok) {
const ghToken = (await response.json()) as AuthTokenResponse;
return {
accessToken: ghToken.access_token,
scope: ghToken.scope,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
}, ${JSON.stringify(await response.json())}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get access_token, err: ${(e as Error).message}`,
HttpStatus.BAD_REQUEST
);
}
}
async getUser(token: string) {
try {
const response = await fetch('https://api.github.com/user', {
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
},
});
if (response.ok) {
const user = (await response.json()) as UserInfo;
return {
id: user.login,
avatarUrl: user.avatar_url,
email: user.email,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
} ${await response.text()}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get user information, err: ${(e as Error).stack}`,
HttpStatus.BAD_REQUEST
);
}
}
}

View File

@@ -0,0 +1,121 @@
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
import { Config, URLHelper } from '../../../fundamentals';
import { AutoRegisteredOAuthProvider } from '../register';
import { OAuthProviderName } from '../types';
interface GoogleOAuthTokenResponse {
access_token: string;
expires_in: number;
refresh_token: string;
scope: string;
token_type: string;
}
export interface UserInfo {
id: string;
email: string;
picture: string;
name: string;
}
@Injectable()
export class GoogleOAuthProvider extends AutoRegisteredOAuthProvider {
override provider = OAuthProviderName.Google;
constructor(
protected readonly AFFiNEConfig: Config,
private readonly url: URLHelper
) {
super();
}
getAuthUrl(state: string) {
return `https://accounts.google.com/o/oauth2/v2/auth?${this.url.stringify({
client_id: this.config.clientId,
redirect_uri: this.url.link('/oauth/callback'),
response_type: 'code',
scope: 'openid email profile',
promot: 'select_account',
access_type: 'offline',
...this.config.args,
state,
})}`;
}
async getToken(code: string) {
try {
const response = await fetch('https://oauth2.googleapis.com/token', {
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
grant_type: 'authorization_code',
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
});
if (response.ok) {
const ghToken = (await response.json()) as GoogleOAuthTokenResponse;
return {
accessToken: ghToken.access_token,
refreshToken: ghToken.refresh_token,
expiresAt: new Date(Date.now() + ghToken.expires_in * 1000),
scope: ghToken.scope,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
}, ${JSON.stringify(await response.json())}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get access_token, err: ${(e as Error).message}`,
HttpStatus.BAD_REQUEST
);
}
}
async getUser(token: string) {
try {
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
},
}
);
if (response.ok) {
const user = (await response.json()) as UserInfo;
return {
id: user.id,
avatarUrl: user.picture,
email: user.email,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
} ${await response.text()}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get user information, err: ${(e as Error).stack}`,
HttpStatus.BAD_REQUEST
);
}
}
}

View File

@@ -0,0 +1,4 @@
import { GithubOAuthProvider } from './github';
import { GoogleOAuthProvider } from './google';
export const OAuthProviders = [GoogleOAuthProvider, GithubOAuthProvider];

View File

@@ -0,0 +1,58 @@
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { Config } from '../../fundamentals';
import { OAuthProvider } from './providers/def';
import { OAuthProviderName } from './types';
const PROVIDERS: Map<OAuthProviderName, OAuthProvider> = new Map();
export function registerOAuthProvider(
name: OAuthProviderName,
provider: OAuthProvider
) {
PROVIDERS.set(name, provider);
}
@Injectable()
export class OAuthProviderFactory {
get providers() {
return PROVIDERS.keys();
}
get(name: OAuthProviderName): OAuthProvider | undefined {
return PROVIDERS.get(name);
}
}
export abstract class AutoRegisteredOAuthProvider
extends OAuthProvider
implements OnModuleInit
{
protected abstract AFFiNEConfig: Config;
get optionalConfig() {
return this.AFFiNEConfig.plugins.oauth?.providers?.[this.provider];
}
get config() {
const config = this.optionalConfig;
if (!config) {
throw new Error(
`OAuthProvider Config should not be used before registered`
);
}
return config;
}
onModuleInit() {
const config = this.optionalConfig;
if (config && config.clientId && config.clientSecret) {
registerOAuthProvider(this.provider, this);
new Logger(`OAuthProvider:${this.provider}`).log(
'OAuth provider registered.'
);
}
}
}

View File

@@ -0,0 +1,17 @@
import { registerEnumType, ResolveField, Resolver } from '@nestjs/graphql';
import { ServerConfigType } from '../../core/config';
import { OAuthProviderFactory } from './register';
import { OAuthProviderName } from './types';
registerEnumType(OAuthProviderName, { name: 'OAuthProviderType' });
@Resolver(() => ServerConfigType)
export class OAuthResolver {
constructor(private readonly factory: OAuthProviderFactory) {}
@ResolveField(() => [OAuthProviderName])
oauthProviders() {
return this.factory.providers;
}
}

View File

@@ -0,0 +1,39 @@
import { randomUUID } from 'node:crypto';
import { Injectable } from '@nestjs/common';
import { SessionCache } from '../../fundamentals';
import { OAuthProviderFactory } from './register';
import { OAuthProviderName } from './types';
const OAUTH_STATE_KEY = 'OAUTH_STATE';
interface OAuthState {
redirectUri: string;
provider: OAuthProviderName;
}
@Injectable()
export class OAuthService {
constructor(
private readonly providerFactory: OAuthProviderFactory,
private readonly cache: SessionCache
) {}
async saveOAuthState(state: OAuthState) {
const token = randomUUID();
await this.cache.set(`${OAUTH_STATE_KEY}:${token}`, state, {
ttl: 3600 * 3 * 1000 /* 3 hours */,
});
return token;
}
async getOAuthState(token: string) {
return this.cache.get<OAuthState>(`${OAUTH_STATE_KEY}:${token}`);
}
availableOAuthProviders() {
return this.providerFactory.providers;
}
}

View File

@@ -0,0 +1,15 @@
export interface OAuthProviderConfig {
clientId: string;
clientSecret: string;
args?: Record<string, string>;
}
export enum OAuthProviderName {
Google = 'google',
GitHub = 'github',
}
export interface OAuthConfig {
enabled: boolean;
providers: Partial<{ [key in OAuthProviderName]: OAuthProviderConfig }>;
}

View File

@@ -1,13 +1,14 @@
import { ServerFeature } from '../../core/config';
import { FeatureModule } from '../../core/features';
import { OptionalModule } from '../../fundamentals';
import { Plugin } from '../registry';
import { SubscriptionResolver, UserSubscriptionResolver } from './resolver';
import { ScheduleManager } from './schedule';
import { SubscriptionService } from './service';
import { StripeProvider } from './stripe';
import { StripeWebhook } from './webhook';
@OptionalModule({
@Plugin({
name: 'payment',
imports: [FeatureModule],
providers: [
ScheduleManager,

View File

@@ -21,8 +21,8 @@ import type { User, UserInvoice, UserSubscription } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import { groupBy } from 'lodash-es';
import { Auth, CurrentUser, Public } from '../../core/auth';
import { UserType } from '../../core/users';
import { CurrentUser, Public } from '../../core/auth';
import { UserType } from '../../core/user';
import { Config } from '../../fundamentals';
import { decodeLookupKey, SubscriptionService } from './service';
import {
@@ -155,7 +155,6 @@ class CreateCheckoutSessionInput {
idempotencyKey!: string;
}
@Auth()
@Resolver(() => UserSubscriptionType)
export class SubscriptionResolver {
constructor(
@@ -217,7 +216,7 @@ export class SubscriptionResolver {
description: 'Create a subscription checkout link of stripe',
})
async checkout(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring,
@Args('idempotencyKey') idempotencyKey: string
@@ -241,7 +240,7 @@ export class SubscriptionResolver {
description: 'Create a subscription checkout link of stripe',
})
async createCheckoutSession(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args({ name: 'input', type: () => CreateCheckoutSessionInput })
input: CreateCheckoutSessionInput
) {
@@ -265,13 +264,13 @@ export class SubscriptionResolver {
@Mutation(() => String, {
description: 'Create a stripe customer portal to manage payment methods',
})
async createCustomerPortal(@CurrentUser() user: User) {
async createCustomerPortal(@CurrentUser() user: CurrentUser) {
return this.service.createCustomerPortal(user.id);
}
@Mutation(() => UserSubscriptionType)
async cancelSubscription(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args('idempotencyKey') idempotencyKey: string
) {
return this.service.cancelSubscription(idempotencyKey, user.id);
@@ -279,7 +278,7 @@ export class SubscriptionResolver {
@Mutation(() => UserSubscriptionType)
async resumeSubscription(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args('idempotencyKey') idempotencyKey: string
) {
return this.service.resumeCanceledSubscription(idempotencyKey, user.id);
@@ -287,7 +286,7 @@ export class SubscriptionResolver {
@Mutation(() => UserSubscriptionType)
async updateSubscriptionRecurring(
@CurrentUser() user: User,
@CurrentUser() user: CurrentUser,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring,
@Args('idempotencyKey') idempotencyKey: string

View File

@@ -10,6 +10,7 @@ import type {
import { PrismaClient } from '@prisma/client';
import Stripe from 'stripe';
import { CurrentUser } from '../../core/auth';
import { FeatureManagementService } from '../../core/features';
import { EventEmitter } from '../../fundamentals';
import { ScheduleManager } from './schedule';
@@ -75,7 +76,7 @@ export class SubscriptionService {
redirectUrl,
idempotencyKey,
}: {
user: User;
user: CurrentUser;
recurring: SubscriptionRecurring;
plan: SubscriptionPlan;
promotionCode?: string | null;
@@ -549,7 +550,7 @@ export class SubscriptionService {
private async getOrCreateCustomer(
idempotencyKey: string,
user: User
user: CurrentUser
): Promise<UserStripeCustomer> {
const customer = await this.db.userStripeCustomer.findUnique({
where: {
@@ -649,7 +650,7 @@ export class SubscriptionService {
}
private async getAvailableCoupon(
user: User,
user: CurrentUser,
couponType: CouponType
): Promise<string | null> {
const earlyAccess = await this.features.isEarlyAccessUser(user.email);

View File

@@ -2,9 +2,10 @@ import { Global, Provider, Type } from '@nestjs/common';
import { Redis, type RedisOptions } from 'ioredis';
import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis';
import { Cache, OptionalModule, SessionCache } from '../../fundamentals';
import { Cache, SessionCache } from '../../fundamentals';
import { ThrottlerStorage } from '../../fundamentals/throttler';
import { SocketIoAdapterImpl } from '../../fundamentals/websocket';
import { Plugin } from '../registry';
import { RedisCache } from './cache';
import {
CacheRedis,
@@ -47,7 +48,8 @@ const socketIoRedisAdapterProvider: Provider = {
};
@Global()
@OptionalModule({
@Plugin({
name: 'redis',
providers: [CacheRedis, SessionRedis, ThrottlerRedis, SocketIoRedis],
overrides: [
cacheProvider,

View File

@@ -0,0 +1,22 @@
import { omit } from 'lodash-es';
import { AvailablePlugins } from '../fundamentals/config';
import { OptionalModule, OptionalModuleMetadata } from '../fundamentals/nestjs';
export const REGISTERED_PLUGINS = new Map<AvailablePlugins, AFFiNEModule>();
function register(plugin: AvailablePlugins, module: AFFiNEModule) {
REGISTERED_PLUGINS.set(plugin, module);
}
interface PluginModuleMetadata extends OptionalModuleMetadata {
name: AvailablePlugins;
}
export const Plugin = (options: PluginModuleMetadata) => {
return (target: any) => {
register(options.name, target);
return OptionalModule(omit(options, 'name'))(target);
};
};

View File

@@ -1,5 +1,5 @@
import { OptionalModule } from '../../fundamentals';
import { registerStorageProvider } from '../../fundamentals/storage';
import { Plugin } from '../registry';
import { R2StorageProvider } from './providers/r2';
import { S3StorageProvider } from './providers/s3';
@@ -18,7 +18,8 @@ registerStorageProvider('aws-s3', (config, bucket) => {
return new S3StorageProvider(config.plugins['aws-s3'], bucket);
});
@OptionalModule({
@Plugin({
name: 'cloudflare-r2',
requires: [
'plugins.cloudflare-r2.accountId',
'plugins.cloudflare-r2.credentials.accessKeyId',
@@ -28,7 +29,8 @@ registerStorageProvider('aws-s3', (config, bucket) => {
})
export class CloudflareR2Module {}
@OptionalModule({
@Plugin({
name: 'aws-s3',
requires: [
'plugins.aws-s3.credentials.accessKeyId',
'plugins.aws-s3.credentials.secretAccessKey',