mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 20:38:52 +00:00
refactor(server): auth (#5895)
Remove `next-auth` and implement our own Authorization/Authentication system from scratch.
## Server
- [x] tokens
- [x] function
- [x] encryption
- [x] AuthController
- [x] /api/auth/sign-in
- [x] /api/auth/sign-out
- [x] /api/auth/session
- [x] /api/auth/session (WE SUPPORT MULTI-ACCOUNT!)
- [x] OAuthPlugin
- [x] OAuthController
- [x] /oauth/login
- [x] /oauth/callback
- [x] Providers
- [x] Google
- [x] GitHub
## Client
- [x] useSession
- [x] cloudSignIn
- [x] cloudSignOut
## NOTE:
Tests will be adding in the future
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import { GCloudConfig } from './gcloud/config';
|
||||
import { OAuthConfig } from './oauth';
|
||||
import { PaymentConfig } from './payment';
|
||||
import { RedisOptions } from './redis';
|
||||
import { R2StorageConfig, S3StorageConfig } from './storage';
|
||||
@@ -10,13 +11,14 @@ declare module '../fundamentals/config' {
|
||||
readonly gcloud: GCloudConfig;
|
||||
readonly 'cloudflare-r2': R2StorageConfig;
|
||||
readonly 'aws-s3': S3StorageConfig;
|
||||
readonly oauth: OAuthConfig;
|
||||
}
|
||||
|
||||
export type AvailablePlugins = keyof PluginsConfig;
|
||||
|
||||
interface AFFiNEConfig {
|
||||
readonly plugins: {
|
||||
enabled: AvailablePlugins[];
|
||||
enabled: Set<AvailablePlugins>;
|
||||
use<Plugin extends AvailablePlugins>(
|
||||
plugin: Plugin,
|
||||
config?: DeepPartial<PluginsConfig[Plugin]>
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { Global } from '@nestjs/common';
|
||||
|
||||
import { OptionalModule } from '../../fundamentals';
|
||||
import { Plugin } from '../registry';
|
||||
import { GCloudMetrics } from './metrics';
|
||||
|
||||
@Global()
|
||||
@OptionalModule({
|
||||
@Plugin({
|
||||
name: 'gcloud',
|
||||
imports: [GCloudMetrics],
|
||||
})
|
||||
export class GCloudModule {}
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
import type { AvailablePlugins } from '../fundamentals/config';
|
||||
import { GCloudModule } from './gcloud';
|
||||
import { PaymentModule } from './payment';
|
||||
import { RedisModule } from './redis';
|
||||
import { AwsS3Module, CloudflareR2Module } from './storage';
|
||||
import './gcloud';
|
||||
import './oauth';
|
||||
import './payment';
|
||||
import './redis';
|
||||
import './storage';
|
||||
|
||||
export const pluginsMap = new Map<AvailablePlugins, AFFiNEModule>([
|
||||
['payment', PaymentModule],
|
||||
['redis', RedisModule],
|
||||
['gcloud', GCloudModule],
|
||||
['cloudflare-r2', CloudflareR2Module],
|
||||
['aws-s3', AwsS3Module],
|
||||
]);
|
||||
export { REGISTERED_PLUGINS } from './registry';
|
||||
|
||||
230
packages/backend/server/src/plugins/oauth/controller.ts
Normal file
230
packages/backend/server/src/plugins/oauth/controller.ts
Normal file
@@ -0,0 +1,230 @@
|
||||
import {
|
||||
BadRequestException,
|
||||
Controller,
|
||||
Get,
|
||||
Query,
|
||||
Req,
|
||||
Res,
|
||||
} from '@nestjs/common';
|
||||
import { ConnectedAccount, PrismaClient } from '@prisma/client';
|
||||
import type { Request, Response } from 'express';
|
||||
|
||||
import { AuthService, Public } from '../../core/auth';
|
||||
import { UserService } from '../../core/user';
|
||||
import { URLHelper } from '../../fundamentals';
|
||||
import { OAuthAccount, Tokens } from './providers/def';
|
||||
import { OAuthProviderFactory } from './register';
|
||||
import { OAuthService } from './service';
|
||||
import { OAuthProviderName } from './types';
|
||||
|
||||
@Controller('/oauth')
|
||||
export class OAuthController {
|
||||
constructor(
|
||||
private readonly auth: AuthService,
|
||||
private readonly oauth: OAuthService,
|
||||
private readonly user: UserService,
|
||||
private readonly providerFactory: OAuthProviderFactory,
|
||||
private readonly url: URLHelper,
|
||||
private readonly db: PrismaClient
|
||||
) {}
|
||||
|
||||
@Public()
|
||||
@Get('/login')
|
||||
async login(
|
||||
@Res() res: Response,
|
||||
@Query('provider') unknownProviderName: string,
|
||||
@Query('redirect_uri') redirectUri?: string
|
||||
) {
|
||||
// @ts-expect-error safe
|
||||
const providerName = OAuthProviderName[unknownProviderName];
|
||||
const provider = this.providerFactory.get(providerName);
|
||||
|
||||
if (!provider) {
|
||||
throw new BadRequestException('Invalid provider');
|
||||
}
|
||||
|
||||
const state = await this.oauth.saveOAuthState({
|
||||
redirectUri: redirectUri ?? this.url.home,
|
||||
provider: providerName,
|
||||
});
|
||||
|
||||
return res.redirect(provider.getAuthUrl(state));
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Get('/callback')
|
||||
async callback(
|
||||
@Req() req: Request,
|
||||
@Res() res: Response,
|
||||
@Query('code') code?: string,
|
||||
@Query('state') stateStr?: string
|
||||
) {
|
||||
if (!code) {
|
||||
throw new BadRequestException('Missing query parameter `code`');
|
||||
}
|
||||
|
||||
if (!stateStr) {
|
||||
throw new BadRequestException('Invalid callback state parameter');
|
||||
}
|
||||
|
||||
const state = await this.oauth.getOAuthState(stateStr);
|
||||
|
||||
if (!state) {
|
||||
throw new BadRequestException('OAuth state expired, please try again.');
|
||||
}
|
||||
|
||||
if (!state.provider) {
|
||||
throw new BadRequestException(
|
||||
'Missing callback state parameter `provider`'
|
||||
);
|
||||
}
|
||||
|
||||
const provider = this.providerFactory.get(state.provider);
|
||||
|
||||
if (!provider) {
|
||||
throw new BadRequestException('Invalid provider');
|
||||
}
|
||||
|
||||
const tokens = await provider.getToken(code);
|
||||
const externAccount = await provider.getUser(tokens.accessToken);
|
||||
const user = req.user;
|
||||
|
||||
try {
|
||||
if (!user) {
|
||||
// if user not found, login
|
||||
const user = await this.loginFromOauth(
|
||||
state.provider,
|
||||
externAccount,
|
||||
tokens
|
||||
);
|
||||
const session = await this.auth.createUserSession(
|
||||
user,
|
||||
req.cookies[AuthService.sessionCookieName]
|
||||
);
|
||||
res.cookie(AuthService.sessionCookieName, session.sessionId, {
|
||||
expires: session.expiresAt ?? void 0, // expiredAt is `string | null`
|
||||
...this.auth.cookieOptions,
|
||||
});
|
||||
} else {
|
||||
// if user is found, connect the account to this user
|
||||
await this.connectAccountFromOauth(
|
||||
user,
|
||||
state.provider,
|
||||
externAccount,
|
||||
tokens
|
||||
);
|
||||
}
|
||||
} catch (e: any) {
|
||||
return res.redirect(
|
||||
this.url.link('/signIn', {
|
||||
redirect_uri: state.redirectUri,
|
||||
error: e.message,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
this.url.safeRedirect(res, state.redirectUri);
|
||||
}
|
||||
|
||||
private async loginFromOauth(
|
||||
provider: OAuthProviderName,
|
||||
externalAccount: OAuthAccount,
|
||||
tokens: Tokens
|
||||
) {
|
||||
const connectedUser = await this.db.connectedAccount.findFirst({
|
||||
where: {
|
||||
provider,
|
||||
providerAccountId: externalAccount.id,
|
||||
},
|
||||
include: {
|
||||
user: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (connectedUser) {
|
||||
// already connected
|
||||
await this.updateConnectedAccount(connectedUser, tokens);
|
||||
|
||||
return connectedUser.user;
|
||||
}
|
||||
|
||||
let user = await this.user.findUserByEmail(externalAccount.email);
|
||||
|
||||
if (user) {
|
||||
// we can't directly connect the external account with given email in sign in scenario for safety concern.
|
||||
// let user manually connect in account sessions instead.
|
||||
throw new BadRequestException(
|
||||
'The account with provided email is not register in the same way.'
|
||||
);
|
||||
} else {
|
||||
user = await this.createUserWithConnectedAccount(
|
||||
provider,
|
||||
externalAccount,
|
||||
tokens
|
||||
);
|
||||
}
|
||||
|
||||
return user;
|
||||
}
|
||||
|
||||
updateConnectedAccount(connectedUser: ConnectedAccount, tokens: Tokens) {
|
||||
return this.db.connectedAccount.update({
|
||||
where: {
|
||||
id: connectedUser.id,
|
||||
},
|
||||
data: tokens,
|
||||
});
|
||||
}
|
||||
|
||||
async createUserWithConnectedAccount(
|
||||
provider: OAuthProviderName,
|
||||
externalAccount: OAuthAccount,
|
||||
tokens: Tokens
|
||||
) {
|
||||
return this.user.createUser({
|
||||
email: externalAccount.email,
|
||||
name: 'Unnamed',
|
||||
avatarUrl: externalAccount.avatarUrl,
|
||||
emailVerifiedAt: new Date(),
|
||||
connectedAccounts: {
|
||||
create: {
|
||||
provider,
|
||||
providerAccountId: externalAccount.id,
|
||||
...tokens,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
private async connectAccountFromOauth(
|
||||
user: { id: string },
|
||||
provider: OAuthProviderName,
|
||||
externalAccount: OAuthAccount,
|
||||
tokens: Tokens
|
||||
) {
|
||||
const connectedUser = await this.db.connectedAccount.findFirst({
|
||||
where: {
|
||||
provider,
|
||||
providerAccountId: externalAccount.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (connectedUser) {
|
||||
if (connectedUser.id !== user.id) {
|
||||
throw new BadRequestException(
|
||||
'The third-party account has already been connected to another user.'
|
||||
);
|
||||
}
|
||||
} else {
|
||||
await this.db.connectedAccount.create({
|
||||
data: {
|
||||
userId: user.id,
|
||||
provider,
|
||||
providerAccountId: externalAccount.id,
|
||||
accessToken: tokens.accessToken,
|
||||
refreshToken: tokens.refreshToken,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
25
packages/backend/server/src/plugins/oauth/index.ts
Normal file
25
packages/backend/server/src/plugins/oauth/index.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { AuthModule } from '../../core/auth';
|
||||
import { ServerFeature } from '../../core/config';
|
||||
import { UserModule } from '../../core/user';
|
||||
import { Plugin } from '../registry';
|
||||
import { OAuthController } from './controller';
|
||||
import { OAuthProviders } from './providers';
|
||||
import { OAuthProviderFactory } from './register';
|
||||
import { OAuthResolver } from './resolver';
|
||||
import { OAuthService } from './service';
|
||||
|
||||
@Plugin({
|
||||
name: 'oauth',
|
||||
imports: [AuthModule, UserModule],
|
||||
providers: [
|
||||
OAuthProviderFactory,
|
||||
OAuthService,
|
||||
OAuthResolver,
|
||||
...OAuthProviders,
|
||||
],
|
||||
controllers: [OAuthController],
|
||||
contributesTo: ServerFeature.OAuth,
|
||||
if: config => !!config.plugins.oauth,
|
||||
})
|
||||
export class OAuthModule {}
|
||||
export type { OAuthConfig } from './types';
|
||||
21
packages/backend/server/src/plugins/oauth/providers/def.ts
Normal file
21
packages/backend/server/src/plugins/oauth/providers/def.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import { OAuthProviderName } from '../types';
|
||||
|
||||
export interface OAuthAccount {
|
||||
id: string;
|
||||
email: string;
|
||||
avatarUrl?: string;
|
||||
}
|
||||
|
||||
export interface Tokens {
|
||||
accessToken: string;
|
||||
scope?: string;
|
||||
refreshToken?: string;
|
||||
expiresAt?: Date;
|
||||
}
|
||||
|
||||
export abstract class OAuthProvider {
|
||||
abstract provider: OAuthProviderName;
|
||||
abstract getAuthUrl(state?: string): string;
|
||||
abstract getToken(code: string): Promise<Tokens>;
|
||||
abstract getUser(token: string): Promise<OAuthAccount>;
|
||||
}
|
||||
113
packages/backend/server/src/plugins/oauth/providers/github.ts
Normal file
113
packages/backend/server/src/plugins/oauth/providers/github.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
|
||||
|
||||
import { Config, URLHelper } from '../../../fundamentals';
|
||||
import { AutoRegisteredOAuthProvider } from '../register';
|
||||
import { OAuthProviderName } from '../types';
|
||||
|
||||
interface AuthTokenResponse {
|
||||
access_token: string;
|
||||
scope: string;
|
||||
token_type: string;
|
||||
}
|
||||
|
||||
export interface UserInfo {
|
||||
login: string;
|
||||
email: string;
|
||||
avatar_url: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class GithubOAuthProvider extends AutoRegisteredOAuthProvider {
|
||||
provider = OAuthProviderName.GitHub;
|
||||
|
||||
constructor(
|
||||
protected readonly AFFiNEConfig: Config,
|
||||
private readonly url: URLHelper
|
||||
) {
|
||||
super();
|
||||
}
|
||||
|
||||
getAuthUrl(state: string) {
|
||||
return `https://github.com/login/oauth/authorize?${this.url.stringify({
|
||||
client_id: this.config.clientId,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
scope: 'user',
|
||||
...this.config.args,
|
||||
state,
|
||||
})}`;
|
||||
}
|
||||
|
||||
async getToken(code: string) {
|
||||
try {
|
||||
const response = await fetch(
|
||||
'https://github.com/login/oauth/access_token',
|
||||
{
|
||||
method: 'POST',
|
||||
body: this.url.stringify({
|
||||
code,
|
||||
client_id: this.config.clientId,
|
||||
client_secret: this.config.clientSecret,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
}),
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
const ghToken = (await response.json()) as AuthTokenResponse;
|
||||
|
||||
return {
|
||||
accessToken: ghToken.access_token,
|
||||
scope: ghToken.scope,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
}, ${JSON.stringify(await response.json())}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
throw new HttpException(
|
||||
`Failed to get access_token, err: ${(e as Error).message}`,
|
||||
HttpStatus.BAD_REQUEST
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async getUser(token: string) {
|
||||
try {
|
||||
const response = await fetch('https://api.github.com/user', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const user = (await response.json()) as UserInfo;
|
||||
|
||||
return {
|
||||
id: user.login,
|
||||
avatarUrl: user.avatar_url,
|
||||
email: user.email,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
} ${await response.text()}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
throw new HttpException(
|
||||
`Failed to get user information, err: ${(e as Error).stack}`,
|
||||
HttpStatus.BAD_REQUEST
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
121
packages/backend/server/src/plugins/oauth/providers/google.ts
Normal file
121
packages/backend/server/src/plugins/oauth/providers/google.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
|
||||
|
||||
import { Config, URLHelper } from '../../../fundamentals';
|
||||
import { AutoRegisteredOAuthProvider } from '../register';
|
||||
import { OAuthProviderName } from '../types';
|
||||
|
||||
interface GoogleOAuthTokenResponse {
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
scope: string;
|
||||
token_type: string;
|
||||
}
|
||||
|
||||
export interface UserInfo {
|
||||
id: string;
|
||||
email: string;
|
||||
picture: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class GoogleOAuthProvider extends AutoRegisteredOAuthProvider {
|
||||
override provider = OAuthProviderName.Google;
|
||||
|
||||
constructor(
|
||||
protected readonly AFFiNEConfig: Config,
|
||||
private readonly url: URLHelper
|
||||
) {
|
||||
super();
|
||||
}
|
||||
|
||||
getAuthUrl(state: string) {
|
||||
return `https://accounts.google.com/o/oauth2/v2/auth?${this.url.stringify({
|
||||
client_id: this.config.clientId,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
response_type: 'code',
|
||||
scope: 'openid email profile',
|
||||
promot: 'select_account',
|
||||
access_type: 'offline',
|
||||
...this.config.args,
|
||||
state,
|
||||
})}`;
|
||||
}
|
||||
|
||||
async getToken(code: string) {
|
||||
try {
|
||||
const response = await fetch('https://oauth2.googleapis.com/token', {
|
||||
method: 'POST',
|
||||
body: this.url.stringify({
|
||||
code,
|
||||
client_id: this.config.clientId,
|
||||
client_secret: this.config.clientSecret,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
grant_type: 'authorization_code',
|
||||
}),
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const ghToken = (await response.json()) as GoogleOAuthTokenResponse;
|
||||
|
||||
return {
|
||||
accessToken: ghToken.access_token,
|
||||
refreshToken: ghToken.refresh_token,
|
||||
expiresAt: new Date(Date.now() + ghToken.expires_in * 1000),
|
||||
scope: ghToken.scope,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
}, ${JSON.stringify(await response.json())}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
throw new HttpException(
|
||||
`Failed to get access_token, err: ${(e as Error).message}`,
|
||||
HttpStatus.BAD_REQUEST
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async getUser(token: string) {
|
||||
try {
|
||||
const response = await fetch(
|
||||
'https://www.googleapis.com/oauth2/v2/userinfo',
|
||||
{
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
const user = (await response.json()) as UserInfo;
|
||||
|
||||
return {
|
||||
id: user.id,
|
||||
avatarUrl: user.picture,
|
||||
email: user.email,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
} ${await response.text()}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
throw new HttpException(
|
||||
`Failed to get user information, err: ${(e as Error).stack}`,
|
||||
HttpStatus.BAD_REQUEST
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
import { GithubOAuthProvider } from './github';
|
||||
import { GoogleOAuthProvider } from './google';
|
||||
|
||||
export const OAuthProviders = [GoogleOAuthProvider, GithubOAuthProvider];
|
||||
58
packages/backend/server/src/plugins/oauth/register.ts
Normal file
58
packages/backend/server/src/plugins/oauth/register.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
|
||||
|
||||
import { Config } from '../../fundamentals';
|
||||
import { OAuthProvider } from './providers/def';
|
||||
import { OAuthProviderName } from './types';
|
||||
|
||||
const PROVIDERS: Map<OAuthProviderName, OAuthProvider> = new Map();
|
||||
|
||||
export function registerOAuthProvider(
|
||||
name: OAuthProviderName,
|
||||
provider: OAuthProvider
|
||||
) {
|
||||
PROVIDERS.set(name, provider);
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OAuthProviderFactory {
|
||||
get providers() {
|
||||
return PROVIDERS.keys();
|
||||
}
|
||||
|
||||
get(name: OAuthProviderName): OAuthProvider | undefined {
|
||||
return PROVIDERS.get(name);
|
||||
}
|
||||
}
|
||||
|
||||
export abstract class AutoRegisteredOAuthProvider
|
||||
extends OAuthProvider
|
||||
implements OnModuleInit
|
||||
{
|
||||
protected abstract AFFiNEConfig: Config;
|
||||
|
||||
get optionalConfig() {
|
||||
return this.AFFiNEConfig.plugins.oauth?.providers?.[this.provider];
|
||||
}
|
||||
|
||||
get config() {
|
||||
const config = this.optionalConfig;
|
||||
|
||||
if (!config) {
|
||||
throw new Error(
|
||||
`OAuthProvider Config should not be used before registered`
|
||||
);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
onModuleInit() {
|
||||
const config = this.optionalConfig;
|
||||
if (config && config.clientId && config.clientSecret) {
|
||||
registerOAuthProvider(this.provider, this);
|
||||
new Logger(`OAuthProvider:${this.provider}`).log(
|
||||
'OAuth provider registered.'
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
17
packages/backend/server/src/plugins/oauth/resolver.ts
Normal file
17
packages/backend/server/src/plugins/oauth/resolver.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import { registerEnumType, ResolveField, Resolver } from '@nestjs/graphql';
|
||||
|
||||
import { ServerConfigType } from '../../core/config';
|
||||
import { OAuthProviderFactory } from './register';
|
||||
import { OAuthProviderName } from './types';
|
||||
|
||||
registerEnumType(OAuthProviderName, { name: 'OAuthProviderType' });
|
||||
|
||||
@Resolver(() => ServerConfigType)
|
||||
export class OAuthResolver {
|
||||
constructor(private readonly factory: OAuthProviderFactory) {}
|
||||
|
||||
@ResolveField(() => [OAuthProviderName])
|
||||
oauthProviders() {
|
||||
return this.factory.providers;
|
||||
}
|
||||
}
|
||||
39
packages/backend/server/src/plugins/oauth/service.ts
Normal file
39
packages/backend/server/src/plugins/oauth/service.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { SessionCache } from '../../fundamentals';
|
||||
import { OAuthProviderFactory } from './register';
|
||||
import { OAuthProviderName } from './types';
|
||||
|
||||
const OAUTH_STATE_KEY = 'OAUTH_STATE';
|
||||
|
||||
interface OAuthState {
|
||||
redirectUri: string;
|
||||
provider: OAuthProviderName;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OAuthService {
|
||||
constructor(
|
||||
private readonly providerFactory: OAuthProviderFactory,
|
||||
private readonly cache: SessionCache
|
||||
) {}
|
||||
|
||||
async saveOAuthState(state: OAuthState) {
|
||||
const token = randomUUID();
|
||||
await this.cache.set(`${OAUTH_STATE_KEY}:${token}`, state, {
|
||||
ttl: 3600 * 3 * 1000 /* 3 hours */,
|
||||
});
|
||||
|
||||
return token;
|
||||
}
|
||||
|
||||
async getOAuthState(token: string) {
|
||||
return this.cache.get<OAuthState>(`${OAUTH_STATE_KEY}:${token}`);
|
||||
}
|
||||
|
||||
availableOAuthProviders() {
|
||||
return this.providerFactory.providers;
|
||||
}
|
||||
}
|
||||
15
packages/backend/server/src/plugins/oauth/types.ts
Normal file
15
packages/backend/server/src/plugins/oauth/types.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
export interface OAuthProviderConfig {
|
||||
clientId: string;
|
||||
clientSecret: string;
|
||||
args?: Record<string, string>;
|
||||
}
|
||||
|
||||
export enum OAuthProviderName {
|
||||
Google = 'google',
|
||||
GitHub = 'github',
|
||||
}
|
||||
|
||||
export interface OAuthConfig {
|
||||
enabled: boolean;
|
||||
providers: Partial<{ [key in OAuthProviderName]: OAuthProviderConfig }>;
|
||||
}
|
||||
@@ -1,13 +1,14 @@
|
||||
import { ServerFeature } from '../../core/config';
|
||||
import { FeatureModule } from '../../core/features';
|
||||
import { OptionalModule } from '../../fundamentals';
|
||||
import { Plugin } from '../registry';
|
||||
import { SubscriptionResolver, UserSubscriptionResolver } from './resolver';
|
||||
import { ScheduleManager } from './schedule';
|
||||
import { SubscriptionService } from './service';
|
||||
import { StripeProvider } from './stripe';
|
||||
import { StripeWebhook } from './webhook';
|
||||
|
||||
@OptionalModule({
|
||||
@Plugin({
|
||||
name: 'payment',
|
||||
imports: [FeatureModule],
|
||||
providers: [
|
||||
ScheduleManager,
|
||||
|
||||
@@ -21,8 +21,8 @@ import type { User, UserInvoice, UserSubscription } from '@prisma/client';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { groupBy } from 'lodash-es';
|
||||
|
||||
import { Auth, CurrentUser, Public } from '../../core/auth';
|
||||
import { UserType } from '../../core/users';
|
||||
import { CurrentUser, Public } from '../../core/auth';
|
||||
import { UserType } from '../../core/user';
|
||||
import { Config } from '../../fundamentals';
|
||||
import { decodeLookupKey, SubscriptionService } from './service';
|
||||
import {
|
||||
@@ -155,7 +155,6 @@ class CreateCheckoutSessionInput {
|
||||
idempotencyKey!: string;
|
||||
}
|
||||
|
||||
@Auth()
|
||||
@Resolver(() => UserSubscriptionType)
|
||||
export class SubscriptionResolver {
|
||||
constructor(
|
||||
@@ -217,7 +216,7 @@ export class SubscriptionResolver {
|
||||
description: 'Create a subscription checkout link of stripe',
|
||||
})
|
||||
async checkout(
|
||||
@CurrentUser() user: User,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
|
||||
recurring: SubscriptionRecurring,
|
||||
@Args('idempotencyKey') idempotencyKey: string
|
||||
@@ -241,7 +240,7 @@ export class SubscriptionResolver {
|
||||
description: 'Create a subscription checkout link of stripe',
|
||||
})
|
||||
async createCheckoutSession(
|
||||
@CurrentUser() user: User,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'input', type: () => CreateCheckoutSessionInput })
|
||||
input: CreateCheckoutSessionInput
|
||||
) {
|
||||
@@ -265,13 +264,13 @@ export class SubscriptionResolver {
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a stripe customer portal to manage payment methods',
|
||||
})
|
||||
async createCustomerPortal(@CurrentUser() user: User) {
|
||||
async createCustomerPortal(@CurrentUser() user: CurrentUser) {
|
||||
return this.service.createCustomerPortal(user.id);
|
||||
}
|
||||
|
||||
@Mutation(() => UserSubscriptionType)
|
||||
async cancelSubscription(
|
||||
@CurrentUser() user: User,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('idempotencyKey') idempotencyKey: string
|
||||
) {
|
||||
return this.service.cancelSubscription(idempotencyKey, user.id);
|
||||
@@ -279,7 +278,7 @@ export class SubscriptionResolver {
|
||||
|
||||
@Mutation(() => UserSubscriptionType)
|
||||
async resumeSubscription(
|
||||
@CurrentUser() user: User,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('idempotencyKey') idempotencyKey: string
|
||||
) {
|
||||
return this.service.resumeCanceledSubscription(idempotencyKey, user.id);
|
||||
@@ -287,7 +286,7 @@ export class SubscriptionResolver {
|
||||
|
||||
@Mutation(() => UserSubscriptionType)
|
||||
async updateSubscriptionRecurring(
|
||||
@CurrentUser() user: User,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
|
||||
recurring: SubscriptionRecurring,
|
||||
@Args('idempotencyKey') idempotencyKey: string
|
||||
|
||||
@@ -10,6 +10,7 @@ import type {
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import Stripe from 'stripe';
|
||||
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { FeatureManagementService } from '../../core/features';
|
||||
import { EventEmitter } from '../../fundamentals';
|
||||
import { ScheduleManager } from './schedule';
|
||||
@@ -75,7 +76,7 @@ export class SubscriptionService {
|
||||
redirectUrl,
|
||||
idempotencyKey,
|
||||
}: {
|
||||
user: User;
|
||||
user: CurrentUser;
|
||||
recurring: SubscriptionRecurring;
|
||||
plan: SubscriptionPlan;
|
||||
promotionCode?: string | null;
|
||||
@@ -549,7 +550,7 @@ export class SubscriptionService {
|
||||
|
||||
private async getOrCreateCustomer(
|
||||
idempotencyKey: string,
|
||||
user: User
|
||||
user: CurrentUser
|
||||
): Promise<UserStripeCustomer> {
|
||||
const customer = await this.db.userStripeCustomer.findUnique({
|
||||
where: {
|
||||
@@ -649,7 +650,7 @@ export class SubscriptionService {
|
||||
}
|
||||
|
||||
private async getAvailableCoupon(
|
||||
user: User,
|
||||
user: CurrentUser,
|
||||
couponType: CouponType
|
||||
): Promise<string | null> {
|
||||
const earlyAccess = await this.features.isEarlyAccessUser(user.email);
|
||||
|
||||
@@ -2,9 +2,10 @@ import { Global, Provider, Type } from '@nestjs/common';
|
||||
import { Redis, type RedisOptions } from 'ioredis';
|
||||
import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis';
|
||||
|
||||
import { Cache, OptionalModule, SessionCache } from '../../fundamentals';
|
||||
import { Cache, SessionCache } from '../../fundamentals';
|
||||
import { ThrottlerStorage } from '../../fundamentals/throttler';
|
||||
import { SocketIoAdapterImpl } from '../../fundamentals/websocket';
|
||||
import { Plugin } from '../registry';
|
||||
import { RedisCache } from './cache';
|
||||
import {
|
||||
CacheRedis,
|
||||
@@ -47,7 +48,8 @@ const socketIoRedisAdapterProvider: Provider = {
|
||||
};
|
||||
|
||||
@Global()
|
||||
@OptionalModule({
|
||||
@Plugin({
|
||||
name: 'redis',
|
||||
providers: [CacheRedis, SessionRedis, ThrottlerRedis, SocketIoRedis],
|
||||
overrides: [
|
||||
cacheProvider,
|
||||
|
||||
22
packages/backend/server/src/plugins/registry.ts
Normal file
22
packages/backend/server/src/plugins/registry.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { omit } from 'lodash-es';
|
||||
|
||||
import { AvailablePlugins } from '../fundamentals/config';
|
||||
import { OptionalModule, OptionalModuleMetadata } from '../fundamentals/nestjs';
|
||||
|
||||
export const REGISTERED_PLUGINS = new Map<AvailablePlugins, AFFiNEModule>();
|
||||
|
||||
function register(plugin: AvailablePlugins, module: AFFiNEModule) {
|
||||
REGISTERED_PLUGINS.set(plugin, module);
|
||||
}
|
||||
|
||||
interface PluginModuleMetadata extends OptionalModuleMetadata {
|
||||
name: AvailablePlugins;
|
||||
}
|
||||
|
||||
export const Plugin = (options: PluginModuleMetadata) => {
|
||||
return (target: any) => {
|
||||
register(options.name, target);
|
||||
|
||||
return OptionalModule(omit(options, 'name'))(target);
|
||||
};
|
||||
};
|
||||
@@ -1,5 +1,5 @@
|
||||
import { OptionalModule } from '../../fundamentals';
|
||||
import { registerStorageProvider } from '../../fundamentals/storage';
|
||||
import { Plugin } from '../registry';
|
||||
import { R2StorageProvider } from './providers/r2';
|
||||
import { S3StorageProvider } from './providers/s3';
|
||||
|
||||
@@ -18,7 +18,8 @@ registerStorageProvider('aws-s3', (config, bucket) => {
|
||||
return new S3StorageProvider(config.plugins['aws-s3'], bucket);
|
||||
});
|
||||
|
||||
@OptionalModule({
|
||||
@Plugin({
|
||||
name: 'cloudflare-r2',
|
||||
requires: [
|
||||
'plugins.cloudflare-r2.accountId',
|
||||
'plugins.cloudflare-r2.credentials.accessKeyId',
|
||||
@@ -28,7 +29,8 @@ registerStorageProvider('aws-s3', (config, bucket) => {
|
||||
})
|
||||
export class CloudflareR2Module {}
|
||||
|
||||
@OptionalModule({
|
||||
@Plugin({
|
||||
name: 'aws-s3',
|
||||
requires: [
|
||||
'plugins.aws-s3.credentials.accessKeyId',
|
||||
'plugins.aws-s3.credentials.secretAccessKey',
|
||||
|
||||
Reference in New Issue
Block a user