diff --git a/apps/core/src/pages/auth.tsx b/apps/core/src/pages/auth.tsx index cdf2c9bbab..9c9ac07e55 100644 --- a/apps/core/src/pages/auth.tsx +++ b/apps/core/src/pages/auth.tsx @@ -9,7 +9,12 @@ import { changeEmailMutation, changePasswordMutation } from '@affine/graphql'; import { useMutation } from '@affine/workspace/affine/gql'; import type { ReactElement } from 'react'; import { useCallback } from 'react'; -import { type LoaderFunction, redirect, useParams } from 'react-router-dom'; +import { + type LoaderFunction, + redirect, + useParams, + useSearchParams, +} from 'react-router-dom'; import { z } from 'zod'; import { useCurrentLoginStatus } from '../hooks/affine/use-current-login-status'; @@ -27,6 +32,7 @@ const authTypeSchema = z.enum([ export const AuthPage = (): ReactElement | null => { const user = useCurrentUser(); const { authType } = useParams(); + const [searchParams] = useSearchParams(); const { trigger: changePassword } = useMutation({ mutation: changePasswordMutation, }); @@ -39,22 +45,22 @@ export const AuthPage = (): ReactElement | null => { const onChangeEmail = useCallback( async (email: string) => { const res = await changeEmail({ - id: user.id, + token: searchParams.get('token') || '', newEmail: email, }); return !!res?.changeEmail; }, - [changeEmail, user.id] + [changeEmail, searchParams] ); const onSetPassword = useCallback( (password: string) => { changePassword({ - id: user.id, + token: searchParams.get('token') || '', newPassword: password, }).catch(console.error); }, - [changePassword, user.id] + [changePassword, searchParams] ); const onOpenAffine = useCallback(() => { jumpToIndex(RouteLogic.REPLACE); diff --git a/apps/server/package.json b/apps/server/package.json index bba57394a4..b7b6cffe17 100644 --- a/apps/server/package.json +++ b/apps/server/package.json @@ -24,6 +24,7 @@ "@aws-sdk/client-s3": "^3.400.0", "@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.17.0", "@google-cloud/opentelemetry-cloud-trace-exporter": "^2.1.0", + "@keyv/redis": "^2.7.0", "@nestjs/apollo": "^12.0.7", "@nestjs/common": "^10.2.4", "@nestjs/core": "^10.2.4", @@ -57,6 +58,7 @@ "graphql-type-json": "^0.3.2", "graphql-upload": "^16.0.2", "ioredis": "^5.3.2", + "keyv": "^4.5.3", "lodash-es": "^4.17.21", "nestjs-throttler-storage-redis": "^0.3.3", "next-auth": "4.22.5", diff --git a/apps/server/src/app.ts b/apps/server/src/app.ts index 679c4981a2..822e559d3c 100644 --- a/apps/server/src/app.ts +++ b/apps/server/src/app.ts @@ -5,6 +5,7 @@ import { ConfigModule } from './config'; import { MetricsModule } from './metrics'; import { BusinessModules, Providers } from './modules'; import { PrismaModule } from './prisma'; +import { SessionModule } from './session'; import { StorageModule } from './storage'; import { RateLimiterModule } from './throttler'; @@ -14,6 +15,7 @@ import { RateLimiterModule } from './throttler'; ConfigModule.forRoot(), StorageModule.forRoot(), MetricsModule, + SessionModule, RateLimiterModule, ...BusinessModules, ], diff --git a/apps/server/src/config/def.ts b/apps/server/src/config/def.ts index 2a395f722c..e930474237 100644 --- a/apps/server/src/config/def.ts +++ b/apps/server/src/config/def.ts @@ -231,6 +231,15 @@ export interface AFFiNEConfig { port: number; username: string; password: string; + /** + * redis database index + * + * Rate Limiter scope: database + 1 + * + * Session scope: database + 2 + * + * @default 0 + */ database: number; }; diff --git a/apps/server/src/modules/auth/next-auth-options.ts b/apps/server/src/modules/auth/next-auth-options.ts index 1fea5d9b0b..3c7508d8a4 100644 --- a/apps/server/src/modules/auth/next-auth-options.ts +++ b/apps/server/src/modules/auth/next-auth-options.ts @@ -4,6 +4,7 @@ import { PrismaAdapter } from '@auth/prisma-adapter'; import { BadRequestException, FactoryProvider, Logger } from '@nestjs/common'; import { verify } from '@node-rs/argon2'; import { Algorithm, sign, verify as jwtVerify } from '@node-rs/jsonwebtoken'; +import { nanoid } from 'nanoid'; import { NextAuthOptions } from 'next-auth'; import Credentials from 'next-auth/providers/credentials'; import Email, { @@ -14,6 +15,7 @@ import Google from 'next-auth/providers/google'; import { Config } from '../../config'; import { PrismaService } from '../../prisma'; +import { SessionService } from '../../session'; import { NewFeaturesKind } from '../users/types'; import { isStaff } from '../users/utils'; import { MailService } from './mailer'; @@ -23,7 +25,12 @@ export const NextAuthOptionsProvide = Symbol('NextAuthOptions'); export const NextAuthOptionsProvider: FactoryProvider = { provide: NextAuthOptionsProvide, - useFactory(config: Config, prisma: PrismaService, mailer: MailService) { + useFactory( + config: Config, + prisma: PrismaService, + mailer: MailService, + session: SessionService + ) { const logger = new Logger('NextAuth'); const prismaAdapter = PrismaAdapter(prisma); // createUser exists in the adapter @@ -72,15 +79,31 @@ export const NextAuthOptionsProvider: FactoryProvider = { from: config.auth.email.sender, async sendVerificationRequest(params: SendVerificationRequestParams) { const { identifier, url, provider } = params; - const { searchParams } = new URL(url); - const callbackUrl = searchParams.get('callbackUrl') || ''; + const urlWithToken = new URL(url); + const callbackUrl = + urlWithToken.searchParams.get('callbackUrl') || ''; if (!callbackUrl) { throw new Error('callbackUrl is not set'); + } else { + const newCallbackUrl = new URL(callbackUrl, config.origin); + + const token = nanoid(); + await session.set(token, identifier); + newCallbackUrl.searchParams.set('token', token); + + urlWithToken.searchParams.set( + 'callbackUrl', + newCallbackUrl.toString() + ); } - const result = await mailer.sendSignInEmail(url, { - to: identifier, - from: provider.from, - }); + + const result = await mailer.sendSignInEmail( + urlWithToken.toString(), + { + to: identifier, + from: provider.from, + } + ); logger.log( `send verification email success: ${result.accepted.join(', ')}` ); @@ -277,5 +300,5 @@ export const NextAuthOptionsProvider: FactoryProvider = { }; return nextAuthOptions; }, - inject: [Config, PrismaService, MailService], + inject: [Config, PrismaService, MailService, SessionService], }; diff --git a/apps/server/src/modules/auth/next-auth.controller.ts b/apps/server/src/modules/auth/next-auth.controller.ts index f931af4c34..c59b283855 100644 --- a/apps/server/src/modules/auth/next-auth.controller.ts +++ b/apps/server/src/modules/auth/next-auth.controller.ts @@ -127,7 +127,6 @@ export class NextAuthController { } if (redirect?.endsWith('api/auth/error?error=AccessDenied')) { - this.logger.debug(req.headers); if (!req.headers?.referer) { res.redirect('https://community.affine.pro/c/insider-general/'); } else { @@ -145,7 +144,6 @@ export class NextAuthController { } if (redirect) { - this.logger.debug(providerId, action, req.headers); if (providerId === 'credentials') { res.send(JSON.stringify({ ok: true, url: redirect })); } else if ( diff --git a/apps/server/src/modules/auth/resolver.ts b/apps/server/src/modules/auth/resolver.ts index 68b74b4ab3..55bd32be1b 100644 --- a/apps/server/src/modules/auth/resolver.ts +++ b/apps/server/src/modules/auth/resolver.ts @@ -1,4 +1,8 @@ -import { ForbiddenException, UseGuards } from '@nestjs/common'; +import { + BadRequestException, + ForbiddenException, + UseGuards, +} from '@nestjs/common'; import { Args, Context, @@ -10,11 +14,13 @@ import { Resolver, } from '@nestjs/graphql'; import type { Request } from 'express'; +import { nanoid } from 'nanoid'; import { Config } from '../../config'; +import { SessionService } from '../../session'; import { CloudThrottlerGuard, Throttle } from '../../throttler'; import { UserType } from '../users/resolver'; -import { CurrentUser } from './guard'; +import { Auth, CurrentUser } from './guard'; import { AuthService } from './service'; @ObjectType() @@ -37,14 +43,15 @@ export class TokenType { export class AuthResolver { constructor( private readonly config: Config, - private auth: AuthService + private auth: AuthService, + private readonly session: SessionService ) {} @Throttle(20, 60) @ResolveField(() => TokenType) token(@CurrentUser() currentUser: UserType, @Parent() user: UserType) { if (user.id !== currentUser.id) { - throw new ForbiddenException(); + throw new BadRequestException('Invalid user'); } return { @@ -80,58 +87,93 @@ export class AuthResolver { @Throttle(5, 60) @Mutation(() => UserType) + @Auth() async changePassword( - @Context() ctx: { req: Request }, - @Args('id') id: string, + @CurrentUser() user: UserType, + @Args('token') token: string, @Args('newPassword') newPassword: string ) { - const user = await this.auth.changePassword(id, newPassword); - ctx.req.user = user; + const id = await this.session.get(token); + if (!id || id !== user.id) { + throw new ForbiddenException('Invalid token'); + } + + await this.auth.changePassword(id, newPassword); + await this.session.delete(token); + return user; } @Throttle(5, 60) @Mutation(() => UserType) + @Auth() async changeEmail( - @Context() ctx: { req: Request }, - @Args('id') id: string, + @CurrentUser() user: UserType, + @Args('token') token: string, @Args('email') email: string ) { - const user = await this.auth.changeEmail(id, email); - ctx.req.user = user; + const id = await this.session.get(token); + if (!id || id !== user.id) { + throw new ForbiddenException('Invalid token'); + } + + await this.auth.changeEmail(id, email); + await this.session.delete(token); + return user; } @Throttle(5, 60) @Mutation(() => Boolean) + @Auth() async sendChangePasswordEmail( + @CurrentUser() user: UserType, @Args('email') email: string, @Args('callbackUrl') callbackUrl: string ) { - const url = `${this.config.baseUrl}${callbackUrl}`; - const res = await this.auth.sendChangePasswordEmail(email, url); + const token = nanoid(); + await this.session.set(token, user.id); + + const url = new URL(callbackUrl, this.config.baseUrl); + url.searchParams.set('token', token); + + const res = await this.auth.sendChangePasswordEmail(email, url.toString()); return !res.rejected.length; } @Throttle(5, 60) @Mutation(() => Boolean) + @Auth() async sendSetPasswordEmail( + @CurrentUser() user: UserType, @Args('email') email: string, @Args('callbackUrl') callbackUrl: string ) { - const url = `${this.config.baseUrl}${callbackUrl}`; - const res = await this.auth.sendSetPasswordEmail(email, url); + const token = nanoid(); + await this.session.set(token, user.id); + + const url = new URL(callbackUrl, this.config.baseUrl); + url.searchParams.set('token', token); + + const res = await this.auth.sendSetPasswordEmail(email, url.toString()); return !res.rejected.length; } @Throttle(5, 60) @Mutation(() => Boolean) + @Auth() async sendChangeEmail( + @CurrentUser() user: UserType, @Args('email') email: string, @Args('callbackUrl') callbackUrl: string ) { - const url = `${this.config.baseUrl}${callbackUrl}`; - const res = await this.auth.sendChangeEmail(email, url); + const token = nanoid(); + await this.session.set(token, user.id); + + const url = new URL(callbackUrl, this.config.baseUrl); + url.searchParams.set('token', token); + + const res = await this.auth.sendChangeEmail(email, url.toString()); return !res.rejected.length; } } diff --git a/apps/server/src/modules/users/resolver.ts b/apps/server/src/modules/users/resolver.ts index ffcf39a600..f5bcdf74e3 100644 --- a/apps/server/src/modules/users/resolver.ts +++ b/apps/server/src/modules/users/resolver.ts @@ -91,7 +91,7 @@ export class UserResolver { name: 'currentUser', description: 'Get current user', }) - async currentUser(@CurrentUser() user: User) { + async currentUser(@CurrentUser() user: UserType) { const storedUser = await this.prisma.user.findUnique({ where: { id: user.id }, }); diff --git a/apps/server/src/schema.gql b/apps/server/src/schema.gql index 09dda6bfdc..dd2570aa6b 100644 --- a/apps/server/src/schema.gql +++ b/apps/server/src/schema.gql @@ -186,8 +186,8 @@ type Mutation { addToNewFeaturesWaitingList(type: NewFeaturesKind!, email: String!): AddToNewFeaturesWaitingList! signUp(name: String!, email: String!, password: String!): UserType! signIn(email: String!, password: String!): UserType! - changePassword(id: String!, newPassword: String!): UserType! - changeEmail(id: String!, email: String!): UserType! + changePassword(token: String!, newPassword: String!): UserType! + changeEmail(token: String!, email: String!): UserType! sendChangePasswordEmail(email: String!, callbackUrl: String!): Boolean! sendSetPasswordEmail(email: String!, callbackUrl: String!): Boolean! sendChangeEmail(email: String!, callbackUrl: String!): Boolean! diff --git a/apps/server/src/session.ts b/apps/server/src/session.ts new file mode 100644 index 0000000000..ab1a657f8b --- /dev/null +++ b/apps/server/src/session.ts @@ -0,0 +1,60 @@ +import KeyvRedis from '@keyv/redis'; +import { Global, Injectable, Module } from '@nestjs/common'; +import Redis from 'ioredis'; +import Keyv from 'keyv'; + +import { Config } from './config'; + +@Injectable() +export class SessionService { + private readonly cache: Keyv; + private readonly prefix = 'session:'; + private readonly sessionTtl = 30 * 60 * 1000; // 30 min + + constructor(protected readonly config: Config) { + if (config.redis.enabled) { + this.cache = new Keyv({ + store: new KeyvRedis( + new Redis(config.redis.port, config.redis.host, { + username: config.redis.username, + password: config.redis.password, + db: config.redis.database + 2, + }) + ), + }); + } else { + this.cache = new Keyv(); + } + } + + /** + * get session + * @param key session key + * @returns + */ + async get(key: string) { + return this.cache.get(this.prefix + key); + } + + /** + * set session + * @param key session key + * @param value session value + * @param sessionTtl session ttl (ms), default 30 min + * @returns return true if success + */ + async set(key: string, value?: any, sessionTtl = this.sessionTtl) { + return this.cache.set(this.prefix + key, value, sessionTtl); + } + + async delete(key: string) { + return this.cache.delete(this.prefix + key); + } +} + +@Global() +@Module({ + providers: [SessionService], + exports: [SessionService], +}) +export class SessionModule {} diff --git a/apps/server/src/tests/session.spec.ts b/apps/server/src/tests/session.spec.ts new file mode 100644 index 0000000000..8fd674b41e --- /dev/null +++ b/apps/server/src/tests/session.spec.ts @@ -0,0 +1,42 @@ +/// +import { equal } from 'node:assert'; +import { afterEach, beforeEach, test } from 'node:test'; + +import { Test, TestingModule } from '@nestjs/testing'; +import { PrismaClient } from '@prisma/client'; + +import { ConfigModule } from '../config'; +import { SessionModule, SessionService } from '../session'; + +let session: SessionService; +let module: TestingModule; + +// cleanup database before each test +beforeEach(async () => { + const client = new PrismaClient(); + await client.$connect(); + await client.user.deleteMany({}); +}); + +beforeEach(async () => { + module = await Test.createTestingModule({ + imports: [ConfigModule.forRoot(), SessionModule], + }).compile(); + session = module.get(SessionService); +}); + +afterEach(async () => { + await module.close(); +}); + +test('should be able to set session', async () => { + await session.set('test', 'value'); + equal(await session.get('test'), 'value'); +}); + +test('should be expired by ttl', async () => { + await session.set('test', 'value', 100); + equal(await session.get('test'), 'value'); + await new Promise(resolve => setTimeout(resolve, 500)); + equal(await session.get('test'), undefined); +}); diff --git a/packages/graphql/src/graphql/change-email.gql b/packages/graphql/src/graphql/change-email.gql index 78daad9fd3..e2efdf343e 100644 --- a/packages/graphql/src/graphql/change-email.gql +++ b/packages/graphql/src/graphql/change-email.gql @@ -1,5 +1,5 @@ -mutation changeEmail($id: String!, $newEmail: String!) { - changeEmail(id: $id, email: $newEmail) { +mutation changeEmail($token: String!, $newEmail: String!) { + changeEmail(token: $token, email: $newEmail) { id name avatarUrl diff --git a/packages/graphql/src/graphql/change-password.gql b/packages/graphql/src/graphql/change-password.gql index bc69c92437..e1f86b3ffd 100644 --- a/packages/graphql/src/graphql/change-password.gql +++ b/packages/graphql/src/graphql/change-password.gql @@ -1,5 +1,5 @@ -mutation changePassword($id: String!, $newPassword: String!) { - changePassword(id: $id, newPassword: $newPassword) { +mutation changePassword($token: String!, $newPassword: String!) { + changePassword(token: $token, newPassword: $newPassword) { id name avatarUrl diff --git a/packages/graphql/src/graphql/index.ts b/packages/graphql/src/graphql/index.ts index c1ee7d0b45..c687a92e79 100644 --- a/packages/graphql/src/graphql/index.ts +++ b/packages/graphql/src/graphql/index.ts @@ -72,8 +72,8 @@ export const changeEmailMutation = { definitionName: 'changeEmail', containsFile: false, query: ` -mutation changeEmail($id: String!, $newEmail: String!) { - changeEmail(id: $id, email: $newEmail) { +mutation changeEmail($token: String!, $newEmail: String!) { + changeEmail(token: $token, email: $newEmail) { id name avatarUrl @@ -88,8 +88,8 @@ export const changePasswordMutation = { definitionName: 'changePassword', containsFile: false, query: ` -mutation changePassword($id: String!, $newPassword: String!) { - changePassword(id: $id, newPassword: $newPassword) { +mutation changePassword($token: String!, $newPassword: String!) { + changePassword(token: $token, newPassword: $newPassword) { id name avatarUrl diff --git a/packages/graphql/src/schema.ts b/packages/graphql/src/schema.ts index f8177becea..a0c2cf739e 100644 --- a/packages/graphql/src/schema.ts +++ b/packages/graphql/src/schema.ts @@ -90,7 +90,7 @@ export type AllBlobSizesQuery = { }; export type ChangeEmailMutationVariables = Exact<{ - id: Scalars['String']['input']; + token: Scalars['String']['input']; newEmail: Scalars['String']['input']; }>; @@ -106,7 +106,7 @@ export type ChangeEmailMutation = { }; export type ChangePasswordMutationVariables = Exact<{ - id: Scalars['String']['input']; + token: Scalars['String']['input']; newPassword: Scalars['String']['input']; }>; diff --git a/yarn.lock b/yarn.lock index 7339797c0d..23f7679278 100644 --- a/yarn.lock +++ b/yarn.lock @@ -661,6 +661,7 @@ __metadata: "@aws-sdk/client-s3": ^3.400.0 "@google-cloud/opentelemetry-cloud-monitoring-exporter": ^0.17.0 "@google-cloud/opentelemetry-cloud-trace-exporter": ^2.1.0 + "@keyv/redis": ^2.7.0 "@napi-rs/image": ^1.6.1 "@nestjs/apollo": ^12.0.7 "@nestjs/common": ^10.2.4 @@ -708,6 +709,7 @@ __metadata: graphql-type-json: ^0.3.2 graphql-upload: ^16.0.2 ioredis: ^5.3.2 + keyv: ^4.5.3 lodash-es: ^4.17.21 nestjs-throttler-storage-redis: ^0.3.3 next-auth: 4.22.5 @@ -6386,6 +6388,15 @@ __metadata: languageName: node linkType: hard +"@keyv/redis@npm:^2.7.0": + version: 2.7.0 + resolution: "@keyv/redis@npm:2.7.0" + dependencies: + ioredis: ^5.3.2 + checksum: 2bf16d99f54fa5177c375eb170f46076715e6a17fd65840d10638e441af8a4dd065927a18b45abb9531746ddab54f865884347c80f7e188c981a40f8245269ab + languageName: node + linkType: hard + "@kwsites/file-exists@npm:^1.1.1": version: 1.1.1 resolution: "@kwsites/file-exists@npm:1.1.1"