refactor(server): auth (#7994)

This commit is contained in:
forehalo
2024-09-03 09:03:39 +00:00
parent 821de0a3bb
commit 8b0afd6eeb
39 changed files with 639 additions and 775 deletions

View File

@@ -18,6 +18,7 @@ import {
EarlyAccessRequired,
EmailTokenNotFound,
InternalServerError,
InvalidEmail,
InvalidEmailToken,
SignUpForbidden,
Throttle,
@@ -25,19 +26,25 @@ import {
} from '../../fundamentals';
import { UserService } from '../user';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
import { Public } from './guard';
import { AuthService, parseAuthUserSeqNum } from './service';
import { AuthService } from './service';
import { CurrentUser, Session } from './session';
import { TokenService, TokenType } from './token';
class SignInCredential {
email!: string;
password?: string;
interface PreflightResponse {
registered: boolean;
hasPassword: boolean;
}
class MagicLinkCredential {
email!: string;
token!: string;
interface SignInCredential {
email: string;
password?: string;
callbackUrl?: string;
}
interface MagicLinkCredential {
email: string;
token: string;
}
@Throttle('strict')
@@ -51,6 +58,33 @@ export class AuthController {
private readonly config: Config
) {}
@Public()
@Post('/preflight')
async preflight(
@Body() params?: { email: string }
): Promise<PreflightResponse> {
if (!params?.email) {
throw new InvalidEmail();
}
validators.assertValidEmail(params.email);
const user = await this.user.findUserWithHashedPasswordByEmail(
params.email
);
if (!user) {
return {
registered: false,
hasPassword: false,
};
}
return {
registered: user.registered,
hasPassword: !!user.password,
};
}
@Public()
@Post('/sign-in')
@Header('content-type', 'application/json')
@@ -58,7 +92,10 @@ export class AuthController {
@Req() req: Request,
@Res() res: Response,
@Body() credential: SignInCredential,
@Query('redirect_uri') redirectUri = this.url.home
/**
* @deprecated
*/
@Query('redirect_uri') redirectUri?: string
) {
validators.assertValidEmail(credential.email);
const canSignIn = await this.auth.canSignIn(credential.email);
@@ -67,80 +104,83 @@ export class AuthController {
}
if (credential.password) {
const user = await this.auth.signIn(
await this.passwordSignIn(
req,
res,
credential.email,
credential.password
);
await this.auth.setCookie(req, res, user);
res.status(HttpStatus.OK).send(user);
} else {
// send email magic link
const user = await this.user.findUserByEmail(credential.email);
if (!user) {
const allowSignup = await this.config.runtime.fetch('auth/allowSignup');
if (!allowSignup) {
throw new SignUpForbidden();
}
}
const result = await this.sendSignInEmail(
{ email: credential.email, signUp: !user },
await this.sendMagicLink(
req,
res,
credential.email,
credential.callbackUrl,
redirectUri
);
if (result.rejected.length) {
throw new InternalServerError('Failed to send sign-in email.');
}
res.status(HttpStatus.OK).send({
email: credential.email,
});
}
}
async sendSignInEmail(
{ email, signUp }: { email: string; signUp: boolean },
redirectUri: string
async passwordSignIn(
req: Request,
res: Response,
email: string,
password: string
) {
const user = await this.auth.signIn(email, password);
await this.auth.setCookies(req, res, user.id);
res.status(HttpStatus.OK).send(user);
}
async sendMagicLink(
_req: Request,
res: Response,
email: string,
callbackUrl = '/magic-link',
redirectUrl = this.url.home
) {
// send email magic link
const user = await this.user.findUserByEmail(email);
if (!user) {
const allowSignup = await this.config.runtime.fetch('auth/allowSignup');
if (!allowSignup) {
throw new SignUpForbidden();
}
}
const token = await this.token.createToken(TokenType.SignIn, email);
const magicLink = this.url.link('/magic-link', {
const magicLink = this.url.link(callbackUrl, {
token,
email,
redirect_uri: redirectUri,
redirect_uri: redirectUrl,
});
const result = await this.auth.sendSignInEmail(email, magicLink, signUp);
const result = await this.auth.sendSignInEmail(email, magicLink, !user);
return result;
if (result.rejected.length) {
throw new InternalServerError('Failed to send sign-in email.');
}
res.status(HttpStatus.OK).send({
email: email,
});
}
@Get('/sign-out')
async signOut(
@Req() req: Request,
@Res() res: Response,
@Query('redirect_uri') redirectUri?: string
@Session() session: Session,
@Body() { all }: { all: boolean }
) {
const session = await this.auth.signOut(
req.cookies[AuthService.sessionCookieName],
parseAuthUserSeqNum(req.headers[AuthService.authUserSeqHeaderName])
await this.auth.signOut(
session.sessionId,
all ? undefined : session.userId
);
if (session) {
res.cookie(AuthService.sessionCookieName, session.id, {
expires: session.expiresAt ?? void 0, // expiredAt is `string | null`
...this.auth.cookieOptions,
});
} else {
res.clearCookie(AuthService.sessionCookieName);
}
if (redirectUri) {
return this.url.safeRedirect(res, redirectUri);
} else {
return res.send(null);
}
res.status(HttpStatus.OK).send({});
}
@Public()
@@ -156,11 +196,11 @@ export class AuthController {
validators.assertValidEmail(email);
const valid = await this.token.verifyToken(TokenType.SignIn, token, {
const tokenRecord = await this.token.verifyToken(TokenType.SignIn, token, {
credential: email,
});
if (!valid) {
if (!tokenRecord) {
throw new InvalidEmailToken();
}
@@ -169,9 +209,8 @@ export class AuthController {
registered: true,
});
await this.auth.setCookie(req, res, user);
res.send({ id: user.id, email: user.email, name: user.name });
await this.auth.setCookies(req, res, user.id);
res.send({ id: user.id });
}
@Throttle('default', { limit: 1200 })

View File

@@ -4,7 +4,7 @@ import type {
FactoryProvider,
OnModuleInit,
} from '@nestjs/common';
import { Injectable, SetMetadata, UseGuards } from '@nestjs/common';
import { Injectable, SetMetadata } from '@nestjs/common';
import { ModuleRef, Reflector } from '@nestjs/core';
import type { Request } from 'express';
@@ -16,16 +16,8 @@ import {
parseCookies,
} from '../../fundamentals';
import { WEBSOCKET_OPTIONS } from '../../fundamentals/websocket';
import { CurrentUser, UserSession } from './current-user';
import { AuthService, parseAuthUserSeqNum } from './service';
function extractTokenFromHeader(authorization: string) {
if (!/^Bearer\s/i.test(authorization)) {
return;
}
return authorization.substring(7);
}
import { AuthService } from './service';
import { Session } from './session';
const PUBLIC_ENTRYPOINT_SYMBOL = Symbol('public');
@@ -46,8 +38,8 @@ export class AuthGuard implements CanActivate, OnModuleInit {
const { req, res } = getRequestResponseFromContext(context);
const userSession = await this.signIn(req);
if (res && userSession && userSession.session.expiresAt) {
await this.auth.refreshUserSessionIfNeeded(req, res, userSession.session);
if (res && userSession && userSession.expiresAt) {
await this.auth.refreshUserSessionIfNeeded(res, userSession);
}
// api is public
@@ -60,43 +52,31 @@ export class AuthGuard implements CanActivate, OnModuleInit {
return true;
}
if (!req.user) {
if (!userSession) {
throw new AuthenticationRequired();
}
return true;
}
async signIn(
req: Request
): Promise<{ user: CurrentUser; session: UserSession } | null> {
if (req.user && req.session) {
return {
user: req.user,
session: req.session,
};
async signIn(req: Request): Promise<Session | null> {
if (req.session) {
return req.session;
}
// compatibility with websocket request
parseCookies(req);
let sessionToken: string | undefined =
req.cookies[AuthService.sessionCookieName];
if (!sessionToken && req.headers.authorization) {
sessionToken = extractTokenFromHeader(req.headers.authorization);
}
// TODO(@forehalo): a cache for user session
const userSession = await this.auth.getUserSessionFromRequest(req);
if (sessionToken) {
const userSeq = parseAuthUserSeqNum(
req.headers[AuthService.authUserSeqHeaderName]
);
if (userSession) {
req.session = {
...userSession.session,
user: userSession.user,
};
const userSession = await this.auth.getUserSession(sessionToken, userSeq);
if (userSession) {
req.session = userSession.session;
req.user = userSession.user;
}
return userSession;
return req.session;
}
return null;
@@ -104,26 +84,8 @@ export class AuthGuard implements CanActivate, OnModuleInit {
}
/**
* This guard is used to protect routes/queries/mutations that require a user to be logged in.
*
* The `@CurrentUser()` parameter decorator used in a `Auth` guarded queries would always give us the user because the `Auth` guard will
* fast throw if user is not logged in.
*
* @example
*
* ```typescript
* \@Auth()
* \@Query(() => UserType)
* user(@CurrentUser() user: CurrentUser) {
* return user;
* }
* ```
* Mark api to be public accessible
*/
export const Auth = () => {
return UseGuards(AuthGuard);
};
// api is public accessible
export const Public = () => SetMetadata(PUBLIC_ENTRYPOINT_SYMBOL, true);
export const AuthWebsocketOptionsProvider: FactoryProvider = {

View File

@@ -28,4 +28,4 @@ export class AuthModule {}
export * from './guard';
export { ClientTokenType } from './resolver';
export { AuthService, TokenService, TokenType };
export * from './current-user';
export * from './session';

View File

@@ -11,7 +11,6 @@ import {
import {
ActionForbidden,
Config,
EmailAlreadyUsed,
EmailTokenNotFound,
EmailVerificationRequired,
@@ -26,9 +25,9 @@ import { Admin } from '../common';
import { UserService } from '../user';
import { UserType } from '../user/types';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
import { Public } from './guard';
import { AuthService } from './service';
import { CurrentUser } from './session';
import { TokenService, TokenType } from './token';
@ObjectType('tokenType')
@@ -47,7 +46,6 @@ export class ClientTokenType {
@Resolver(() => UserType)
export class AuthResolver {
constructor(
private readonly config: Config,
private readonly url: URLHelper,
private readonly auth: AuthService,
private readonly user: UserService,
@@ -67,7 +65,7 @@ export class AuthResolver {
@ResolveField(() => ClientTokenType, {
name: 'token',
deprecationReason: 'use [/api/auth/authorize]',
deprecationReason: 'use [/api/auth/sign-in?native=true] instead',
})
async clientToken(
@CurrentUser() currentUser: CurrentUser,
@@ -77,15 +75,11 @@ export class AuthResolver {
throw new ActionForbidden();
}
const session = await this.auth.createUserSession(
user,
undefined,
this.config.auth.accessToken.ttl
);
const userSession = await this.auth.createUserSession(user.id);
return {
sessionToken: session.sessionId,
token: session.sessionId,
sessionToken: userSession.sessionId,
token: userSession.sessionId,
refresh: '',
};
}
@@ -101,14 +95,6 @@ export class AuthResolver {
throw new LinkExpired();
}
const config = await this.config.runtime.fetchAll({
'auth/password.max': true,
'auth/password.min': true,
});
validators.assertValidPassword(newPassword, {
min: config['auth/password.min'],
max: config['auth/password.max'],
});
// NOTE: Set & Change password are using the same token type.
const valid = await this.token.verifyToken(
TokenType.ChangePassword,
@@ -134,7 +120,6 @@ export class AuthResolver {
@Args('token') token: string,
@Args('email') email: string
) {
validators.assertValidEmail(email);
// @see [sendChangeEmail]
const valid = await this.token.verifyToken(TokenType.VerifyEmail, token, {
credential: user.id,
@@ -157,8 +142,11 @@ export class AuthResolver {
async sendChangePasswordEmail(
@CurrentUser() user: CurrentUser,
@Args('callbackUrl') callbackUrl: string,
// @deprecated
@Args('email', { nullable: true }) _email?: string
@Args('email', {
nullable: true,
deprecationReason: 'fetched from signed in user',
})
_email?: string
) {
if (!user.emailVerified) {
throw new EmailVerificationRequired();
@@ -180,7 +168,11 @@ export class AuthResolver {
async sendSetPasswordEmail(
@CurrentUser() user: CurrentUser,
@Args('callbackUrl') callbackUrl: string,
@Args('email', { nullable: true }) _email?: string
@Args('email', {
nullable: true,
deprecationReason: 'fetched from signed in user',
})
_email?: string
) {
return this.sendChangePasswordEmail(user, callbackUrl);
}

View File

@@ -5,35 +5,12 @@ import { PrismaClient } from '@prisma/client';
import type { CookieOptions, Request, Response } from 'express';
import { assign, pick } from 'lodash-es';
import { Config, EmailAlreadyUsed, MailService } from '../../fundamentals';
import { Config, MailService, SignUpForbidden } from '../../fundamentals';
import { FeatureManagementService } from '../features/management';
import { QuotaService } from '../quota/service';
import { QuotaType } from '../quota/types';
import { UserService } from '../user/service';
import type { CurrentUser } from './current-user';
export function parseAuthUserSeqNum(value: any) {
let seq: number = 0;
switch (typeof value) {
case 'number': {
seq = value;
break;
}
case 'string': {
const result = value.match(/^([\d{0, 10}])$/);
if (result?.[1]) {
seq = Number(result[1]);
}
break;
}
default: {
seq = 0;
}
}
return Math.max(0, seq);
}
import type { CurrentUser } from './session';
export function sessionUser(
user: Pick<
@@ -48,6 +25,14 @@ export function sessionUser(
});
}
function extractTokenFromHeader(authorization: string) {
if (!/^Bearer\s/i.test(authorization)) {
return;
}
return authorization.substring(7);
}
@Injectable()
export class AuthService implements OnApplicationBootstrap {
readonly cookieOptions: CookieOptions = {
@@ -57,7 +42,7 @@ export class AuthService implements OnApplicationBootstrap {
secure: this.config.server.https,
};
static readonly sessionCookieName = 'affine_session';
static readonly authUserSeqHeaderName = 'x-auth-user';
static readonly userCookieName = 'affine_user_id';
constructor(
private readonly config: Config,
@@ -93,46 +78,69 @@ export class AuthService implements OnApplicationBootstrap {
return this.feature.canEarlyAccess(email);
}
async signUp(
name: string,
email: string,
password: string
): Promise<CurrentUser> {
const user = await this.user.findUserByEmail(email);
if (user) {
throw new EmailAlreadyUsed();
/**
* This is a test only helper to quickly signup a user, do not use in production
*/
async signUp(email: string, password: string): Promise<CurrentUser> {
if (!this.config.node.test) {
throw new SignUpForbidden(
'sign up helper is forbidden for non-test environment'
);
}
return this.user
.createUser({
name,
.createUser_without_verification({
email,
password,
})
.then(sessionUser);
}
async signIn(email: string, password: string) {
const user = await this.user.signIn(email, password);
async signIn(email: string, password: string): Promise<CurrentUser> {
return this.user.signIn(email, password).then(sessionUser);
}
return sessionUser(user);
async signOut(sessionId: string, userId?: string) {
// sign out all users in the session
if (!userId) {
await this.db.session.deleteMany({
where: {
id: sessionId,
},
});
} else {
await this.db.userSession.deleteMany({
where: {
sessionId,
userId,
},
});
}
}
async getUserSession(
token: string,
seq = 0
sessionId: string,
userId?: string
): Promise<{ user: CurrentUser; session: UserSession } | null> {
const session = await this.getSession(token);
const userSession = await this.db.userSession.findFirst({
where: {
sessionId,
userId,
},
select: {
id: true,
sessionId: true,
userId: true,
createdAt: true,
expiresAt: true,
user: true,
},
orderBy: {
createdAt: 'asc',
},
});
// no such session
if (!session) {
return null;
}
const userSession = session.userSessions.at(seq);
// no such user session
if (!userSession) {
return null;
}
@@ -142,112 +150,93 @@ export class AuthService implements OnApplicationBootstrap {
return null;
}
const user = await this.db.user.findUnique({
where: { id: userSession.userId },
});
if (!user) {
return null;
}
return { user: sessionUser(user), session: userSession };
return { user: sessionUser(userSession.user), session: userSession };
}
async getUserList(token: string) {
const session = await this.getSession(token);
if (!session || !session.userSessions.length) {
return [];
}
const users = await this.db.user.findMany({
where: {
id: {
in: session.userSessions.map(({ userId }) => userId),
},
},
});
// TODO(@forehalo): need to separate expired session, same for [getUser]
// Session
// | { user: LimitedUser { email, avatarUrl }, expired: true }
// | { user: User, expired: false }
return session.userSessions
.map(userSession => {
// keep users in the same order as userSessions
const user = users.find(({ id }) => id === userSession.userId);
if (!user) {
return null;
}
return sessionUser(user);
})
.filter(Boolean) as CurrentUser[];
}
async signOut(token: string, seq = 0) {
const session = await this.getSession(token);
if (session) {
// overflow the logged in user
if (session.userSessions.length <= seq) {
return session;
}
await this.db.userSession.deleteMany({
where: { id: session.userSessions[seq].id },
});
// no more user session active, delete the whole session
if (session.userSessions.length === 1) {
await this.db.session.delete({ where: { id: session.id } });
return null;
}
return session;
}
return null;
}
async getSession(token: string) {
if (!token) {
return null;
}
return this.db.$transaction(async tx => {
const session = await tx.session.findUnique({
async createUserSession(
userId: string,
sessionId?: string,
ttl = this.config.auth.session.ttl
) {
// check whether given session is valid
if (sessionId) {
const session = await this.db.session.findFirst({
where: {
id: token,
},
include: {
userSessions: {
orderBy: {
createdAt: 'asc',
},
},
id: sessionId,
},
});
if (!session) {
return null;
sessionId = undefined;
}
}
if (session.expiresAt && session.expiresAt <= new Date()) {
await tx.session.delete({
where: {
id: session.id,
if (!sessionId) {
const session = await this.createSession();
sessionId = session.id;
}
const expiresAt = new Date(Date.now() + ttl * 1000);
return this.db.userSession.upsert({
where: {
sessionId_userId: {
sessionId,
userId,
},
},
update: {
expiresAt,
},
create: {
sessionId,
userId,
expiresAt,
},
});
}
async getUserList(sessionId: string) {
const sessions = await this.db.userSession.findMany({
where: {
sessionId,
OR: [
{
expiresAt: null,
},
});
{
expiresAt: {
gt: new Date(),
},
},
],
},
include: {
user: true,
},
orderBy: {
createdAt: 'asc',
},
});
return null;
}
return sessions.map(({ user }) => sessionUser(user));
}
return session;
async createSession() {
return this.db.session.create({
data: {},
});
}
async getSession(sessionId: string) {
return this.db.session.findFirst({
where: {
id: sessionId,
},
});
}
async refreshUserSessionIfNeeded(
_req: Request,
res: Response,
session: UserSession,
ttr = this.config.auth.session.ttr
@@ -281,70 +270,63 @@ export class AuthService implements OnApplicationBootstrap {
return true;
}
async createUserSession(
user: { id: string },
existingSession?: string,
ttl = this.config.auth.session.ttl
) {
const session = existingSession
? await this.getSession(existingSession)
: null;
const expiresAt = new Date(Date.now() + ttl * 1000);
if (session) {
return this.db.userSession.upsert({
where: {
sessionId_userId: {
sessionId: session.id,
userId: user.id,
},
},
update: {
expiresAt,
},
create: {
sessionId: session.id,
userId: user.id,
expiresAt,
},
});
} else {
return this.db.userSession.create({
data: {
expiresAt,
session: {
create: {},
},
user: {
connect: {
id: user.id,
},
},
},
});
}
}
async revokeUserSessions(userId: string, sessionId?: string) {
async revokeUserSessions(userId: string) {
return this.db.userSession.deleteMany({
where: {
userId,
sessionId,
},
});
}
async setCookie(_req: Request, res: Response, user: { id: string }) {
const session = await this.createUserSession(
user
// TODO(@forehalo): enable multi user session
// req.cookies[AuthService.sessionCookieName]
);
getSessionOptionsFromRequest(req: Request) {
let sessionId: string | undefined =
req.cookies[AuthService.sessionCookieName];
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: session.expiresAt ?? void 0,
if (!sessionId && req.headers.authorization) {
sessionId = extractTokenFromHeader(req.headers.authorization);
}
const userId: string | undefined =
req.cookies[AuthService.userCookieName] ||
req.headers[AuthService.userCookieName];
return {
sessionId,
userId,
};
}
async setCookies(req: Request, res: Response, userId: string) {
const { sessionId } = this.getSessionOptionsFromRequest(req);
const userSession = await this.createUserSession(userId, sessionId);
res.cookie(AuthService.sessionCookieName, userSession.sessionId, {
...this.cookieOptions,
expires: userSession.expiresAt ?? void 0,
});
this.setUserCookie(res, userId);
}
setUserCookie(res: Response, userId: string) {
res.cookie(AuthService.userCookieName, userId, {
...this.cookieOptions,
// user cookie is client readable & writable for fast user switch if there are multiple users in one session
// it safe to be non-secure & non-httpOnly because server will validate it by `cookie[AuthService.sessionCookieName]`
httpOnly: false,
secure: false,
});
}
async getUserSessionFromRequest(req: Request) {
const { sessionId, userId } = this.getSessionOptionsFromRequest(req);
if (!sessionId) {
return null;
}
return this.getUserSession(sessionId, userId);
}
async changePassword(
@@ -393,24 +375,16 @@ export class AuthService implements OnApplicationBootstrap {
async sendSignInEmail(email: string, link: string, signUp: boolean) {
return signUp
? await this.mailer.sendSignUpMail(link.toString(), {
? await this.mailer.sendSignUpMail(link, {
to: email,
})
: await this.mailer.sendSignInMail(link.toString(), {
: await this.mailer.sendSignInMail(link, {
to: email,
});
}
@Cron(CronExpression.EVERY_DAY_AT_MIDNIGHT)
async cleanExpiredSessions() {
await this.db.session.deleteMany({
where: {
expiresAt: {
lte: new Date(),
},
},
});
await this.db.userSession.deleteMany({
where: {
expiresAt: {

View File

@@ -4,10 +4,6 @@ import { User, UserSession } from '@prisma/client';
import { getRequestResponseFromContext } from '../../fundamentals';
function getUserFromContext(context: ExecutionContext) {
return getRequestResponseFromContext(context).req.user;
}
/**
* Used to fetch current user from the request context.
*
@@ -44,7 +40,7 @@ function getUserFromContext(context: ExecutionContext) {
// eslint-disable-next-line no-redeclare
export const CurrentUser = createParamDecorator(
(_: unknown, context: ExecutionContext) => {
return getUserFromContext(context);
return getRequestResponseFromContext(context).req.session?.user;
}
);
@@ -54,4 +50,14 @@ export interface CurrentUser
emailVerified: boolean;
}
export { type UserSession };
// interface and variable don't conflict
// eslint-disable-next-line no-redeclare
export const Session = createParamDecorator(
(_: unknown, context: ExecutionContext) => {
return getRequestResponseFromContext(context).req.session;
}
);
export type Session = UserSession & {
user: CurrentUser;
};

View File

@@ -25,8 +25,8 @@ export class AdminGuard implements CanActivate, OnModuleInit {
async canActivate(context: ExecutionContext) {
const { req } = getRequestResponseFromContext(context);
let allow = false;
if (req.user) {
allow = await this.feature.isAdmin(req.user.id);
if (req.session) {
allow = await this.feature.isAdmin(req.session.user.id);
}
if (!allow) {

View File

@@ -7,7 +7,7 @@ import {
} from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import { CurrentUser } from '../auth/current-user';
import { CurrentUser } from '../auth/session';
import { EarlyAccessType } from '../features';
import { UserType } from '../user';
import { QuotaService } from './service';

View File

@@ -56,7 +56,7 @@ export class CustomSetupController {
try {
await this.event.emitAsync('user.admin.created', user);
await this.auth.setCookie(req, res, user);
await this.auth.setCookies(req, res, user.id);
res.send({ id: user.id, email: user.email, name: user.name });
} catch (e) {
await this.user.deleteUser(user.id);

View File

@@ -21,7 +21,7 @@ import {
SpaceAccessDenied,
VersionRejected,
} from '../../fundamentals';
import { Auth, CurrentUser } from '../auth';
import { CurrentUser } from '../auth';
import {
DocStorageAdapter,
PgUserspaceDocStorageAdapter,
@@ -203,7 +203,6 @@ export class SpaceSyncGateway
}
// v3
@Auth()
@SubscribeMessage('space:join')
async onJoinSpace(
@CurrentUser() user: CurrentUser,
@@ -264,7 +263,6 @@ export class SpaceSyncGateway
};
}
@Auth()
@SubscribeMessage('space:push-doc-updates')
async onReceiveDocUpdates(
@ConnectedSocket() client: Socket,
@@ -324,7 +322,6 @@ export class SpaceSyncGateway
};
}
@Auth()
@SubscribeMessage('space:join-awareness')
async onJoinAwareness(
@ConnectedSocket() client: Socket,
@@ -410,7 +407,6 @@ export class SpaceSyncGateway
// TODO(@forehalo): remove
// deprecated section
@Auth()
@SubscribeMessage('client-handshake-sync')
async handleClientHandshakeSync(
@CurrentUser() user: CurrentUser,
@@ -451,7 +447,6 @@ export class SpaceSyncGateway
});
}
@Auth()
@SubscribeMessage('client-update-v2')
async handleClientUpdateV2(
@CurrentUser() user: CurrentUser,
@@ -499,7 +494,6 @@ export class SpaceSyncGateway
});
}
@Auth()
@SubscribeMessage('client-handshake-awareness')
async handleClientHandshakeAwareness(
@ConnectedSocket() client: Socket,

View File

@@ -18,9 +18,9 @@ import {
Throttle,
UserNotFound,
} from '../../fundamentals';
import { CurrentUser } from '../auth/current-user';
import { Public } from '../auth/guard';
import { sessionUser } from '../auth/service';
import { CurrentUser } from '../auth/session';
import { Admin } from '../common';
import { AvatarStorage } from '../storage';
import { validators } from '../utils/validators';

View File

@@ -56,11 +56,6 @@ export class UserService {
async createUser(data: CreateUserInput) {
validators.assertValidEmail(data.email);
const user = await this.findUserByEmail(data.email);
if (user) {
throw new EmailAlreadyUsed();
}
if (data.password) {
const config = await this.config.runtime.fetchAll({
@@ -77,6 +72,12 @@ export class UserService {
}
async createUser_without_verification(data: CreateUserInput) {
const user = await this.findUserByEmail(data.email);
if (user) {
throw new EmailAlreadyUsed();
}
if (data.password) {
data.password = await this.crypto.encryptPassword(data.password);
}
@@ -158,9 +159,7 @@ export class UserService {
async fulfillUser(
email: string,
data: Partial<
Pick<Prisma.UserCreateInput, 'emailVerifiedAt' | 'registered'>
>
data: Omit<Partial<Prisma.UserCreateInput>, 'id'>
) {
const user = await this.findUserByEmail(email);
if (!user) {
@@ -180,7 +179,6 @@ export class UserService {
if (Object.keys(data).length) {
return await this.prisma.user.update({
select: this.defaultUserSelect,
where: { id: user.id },
data,
});

View File

@@ -8,7 +8,7 @@ import {
import type { User } from '@prisma/client';
import type { Payload } from '../../fundamentals/event/def';
import { CurrentUser } from '../auth/current-user';
import { type CurrentUser } from '../auth/session';
@ObjectType()
export class UserType implements CurrentUser {

View File

@@ -12,7 +12,7 @@ import {
ThrottlerRequest,
ThrottlerStorageService,
} from '@nestjs/throttler';
import type { Request } from 'express';
import type { Request, Response } from 'express';
import { Config } from '../config';
import { getRequestResponseFromContext } from '../utils/request';
@@ -50,7 +50,10 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
super(options, storageService, reflector);
}
override getRequestResponse(context: ExecutionContext) {
override getRequestResponse(context: ExecutionContext): {
req: Request;
res: Response;
} {
return getRequestResponseFromContext(context) as any;
}
@@ -153,7 +156,7 @@ export class CloudThrottlerGuard extends ThrottlerGuard {
const throttler = this.getSpecifiedThrottler(context);
// if user is logged in, bypass non-protected handlers
if (!throttler && req.user) {
if (!throttler && req.session?.user) {
return true;
}

View File

@@ -1,3 +1,5 @@
import { IncomingMessage } from 'node:http';
import type { ArgumentsHost, ExecutionContext } from '@nestjs/common';
import type { GqlContextType } from '@nestjs/graphql';
import { GqlArgumentsHost } from '@nestjs/graphql';
@@ -25,26 +27,7 @@ export function getRequestResponseFromHost(host: ArgumentsHost) {
case 'ws': {
const ws = host.switchToWs();
const req = ws.getClient<Socket>().client.conn.request as Request;
const cookieStr = req?.headers?.cookie ?? '';
// patch cookies to match auth guard logic
if (typeof cookieStr === 'string') {
req.cookies = cookieStr.split(';').reduce(
(cookies, cookie) => {
const [key, val] = cookie.split('=');
if (key) {
cookies[decodeURIComponent(key.trim())] = val
? decodeURIComponent(val.trim())
: val;
}
return cookies;
},
{} as Record<string, string>
);
}
parseCookies(req);
return { req };
}
case 'rpc': {
@@ -71,12 +54,14 @@ export function getRequestResponseFromContext(ctx: ExecutionContext) {
* simple patch for request not protected by `cookie-parser`
* only take effect if `req.cookies` is not defined
*/
export function parseCookies(req: Request) {
export function parseCookies(
req: IncomingMessage & { cookies?: Record<string, string> }
) {
if (req.cookies) {
return;
}
const cookieStr = req?.headers?.cookie ?? '';
const cookieStr = req.headers.cookie ?? '';
req.cookies = cookieStr.split(';').reduce(
(cookies, cookie) => {
const [key, val] = cookie.split('=');

View File

@@ -1,7 +1,6 @@
declare namespace Express {
interface Request {
user?: import('./core/auth/current-user').CurrentUser;
session?: import('./core/auth/current-user').UserSession;
session?: import('./core/auth/session').Session;
}
}

View File

@@ -27,8 +27,7 @@ import {
toArray,
} from 'rxjs';
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { CurrentUser, Public } from '../../core/auth';
import {
BlobNotFound,
Config,

View File

@@ -1,4 +1,12 @@
import { Controller, Get, Query, Req, Res } from '@nestjs/common';
import {
Body,
Controller,
HttpCode,
HttpStatus,
Post,
Req,
Res,
} from '@nestjs/common';
import { ConnectedAccount, PrismaClient } from '@prisma/client';
import type { Request, Response } from 'express';
@@ -11,34 +19,34 @@ import {
OauthStateExpired,
UnknownOauthProvider,
URLHelper,
WrongSignInMethod,
} from '../../fundamentals';
import { OAuthProviderName } from './config';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthProviderFactory } from './register';
import { OAuthService } from './service';
@Controller('/oauth')
@Controller('/api/oauth')
export class OAuthController {
constructor(
private readonly auth: AuthService,
private readonly oauth: OAuthService,
private readonly user: UserService,
private readonly providerFactory: OAuthProviderFactory,
private readonly url: URLHelper,
private readonly providerFactory: OAuthProviderFactory,
private readonly db: PrismaClient
) {}
@Public()
@Get('/login')
async login(
@Res() res: Response,
@Query('provider') unknownProviderName: string,
@Query('redirect_uri') redirectUri?: string
@Post('/preflight')
@HttpCode(HttpStatus.OK)
async preflight(
@Body('provider') unknownProviderName?: string,
@Body('redirect_uri') redirectUri: string = this.url.home
) {
if (!unknownProviderName) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}
// @ts-expect-error safe
const providerName = OAuthProviderName[unknownProviderName];
const provider = this.providerFactory.get(providerName);
@@ -48,20 +56,23 @@ export class OAuthController {
}
const state = await this.oauth.saveOAuthState({
redirectUri: redirectUri ?? this.url.home,
provider: providerName,
redirectUri,
});
return res.redirect(provider.getAuthUrl(state));
return {
url: provider.getAuthUrl(state),
};
}
@Public()
@Get('/callback')
@Post('/callback')
@HttpCode(HttpStatus.OK)
async callback(
@Req() req: Request,
@Res() res: Response,
@Query('code') code?: string,
@Query('state') stateStr?: string
@Body('code') code?: string,
@Body('state') stateStr?: string
) {
if (!code) {
throw new MissingOauthQueryParameter({ name: 'code' });
@@ -93,43 +104,18 @@ export class OAuthController {
const tokens = await provider.getToken(code);
const externAccount = await provider.getUser(tokens.accessToken);
const user = req.user;
const user = await this.loginFromOauth(
state.provider,
externAccount,
tokens
);
try {
if (!user) {
// if user not found, login
const user = await this.loginFromOauth(
state.provider,
externAccount,
tokens
);
const session = await this.auth.createUserSession(
user,
req.cookies[AuthService.sessionCookieName]
);
res.cookie(AuthService.sessionCookieName, session.sessionId, {
expires: session.expiresAt ?? void 0, // expiredAt is `string | null`
...this.auth.cookieOptions,
});
} else {
// if user is found, connect the account to this user
await this.connectAccountFromOauth(
user,
state.provider,
externAccount,
tokens
);
}
} catch (e: any) {
return res.redirect(
this.url.link('/signIn', {
redirect_uri: state.redirectUri,
error: e.message,
})
);
}
this.url.safeRedirect(res, state.redirectUri);
await this.auth.setCookies(req, res, user.id);
res.send({
id: user.id,
/* @deprecated */
redirectUri: state.redirectUri,
});
}
private async loginFromOauth(
@@ -154,37 +140,27 @@ export class OAuthController {
return connectedUser.user;
}
let user = await this.user.findUserByEmail(externalAccount.email);
const user = await this.user.fulfillUser(externalAccount.email, {
emailVerifiedAt: new Date(),
registered: true,
avatarUrl: externalAccount.avatarUrl,
});
if (user) {
// we can't directly connect the external account with given email in sign in scenario for safety concern.
// let user manually connect in account sessions instead.
if (user.registered) {
throw new WrongSignInMethod();
}
await this.db.connectedAccount.create({
data: {
userId: user.id,
provider,
providerAccountId: externalAccount.id,
...tokens,
},
});
return user;
} else {
user = await this.createUserWithConnectedAccount(
await this.db.connectedAccount.create({
data: {
userId: user.id,
provider,
externalAccount,
tokens
);
}
providerAccountId: externalAccount.id,
...tokens,
},
});
return user;
}
updateConnectedAccount(connectedUser: ConnectedAccount, tokens: Tokens) {
private async updateConnectedAccount(
connectedUser: ConnectedAccount,
tokens: Tokens
) {
return this.db.connectedAccount.update({
where: {
id: connectedUser.id,
@@ -193,27 +169,12 @@ export class OAuthController {
});
}
async createUserWithConnectedAccount(
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
return this.user.createUser({
email: externalAccount.email,
name: externalAccount.email.split('@')[0],
avatarUrl: externalAccount.avatarUrl,
emailVerifiedAt: new Date(),
connectedAccounts: {
create: {
provider,
providerAccountId: externalAccount.id,
...tokens,
},
},
});
}
private async connectAccountFromOauth(
/**
* we currently don't support connect oauth account to existing user
* keep it incase we need it in the future
*/
// @ts-expect-error allow unused
private async _connectAccount(
user: { id: string },
provider: OAuthProviderName,
externalAccount: OAuthAccount,

View File

@@ -15,7 +15,7 @@ export interface Tokens {
export abstract class OAuthProvider {
abstract provider: OAuthProviderName;
abstract getAuthUrl(state?: string): string;
abstract getAuthUrl(state: string): string;
abstract getToken(code: string): Promise<Tokens>;
abstract getUser(token: string): Promise<OAuthAccount>;
}

View File

@@ -9,7 +9,7 @@ import { OAuthProviderFactory } from './register';
const OAUTH_STATE_KEY = 'OAUTH_STATE';
interface OAuthState {
redirectUri: string;
redirectUri?: string;
provider: OAuthProviderName;
}

View File

@@ -474,8 +474,8 @@ type Mutation {
revokePage(pageId: String!, workspaceId: String!): Boolean! @deprecated(reason: "use revokePublicPage")
revokePublicPage(pageId: String!, workspaceId: String!): WorkspacePage!
sendChangeEmail(callbackUrl: String!, email: String): Boolean!
sendChangePasswordEmail(callbackUrl: String!, email: String): Boolean!
sendSetPasswordEmail(callbackUrl: String!, email: String): Boolean!
sendChangePasswordEmail(callbackUrl: String!, email: String @deprecated(reason: "fetched from signed in user")): Boolean!
sendSetPasswordEmail(callbackUrl: String!, email: String @deprecated(reason: "fetched from signed in user")): Boolean!
sendVerifyChangeEmail(callbackUrl: String!, email: String!, token: String!): Boolean!
sendVerifyEmail(callbackUrl: String!): Boolean!
setBlob(blob: Upload!, workspaceId: String!): String!
@@ -862,7 +862,7 @@ type UserType {
quota: UserQuota
subscription(plan: SubscriptionPlan = Pro): UserSubscription @deprecated(reason: "use `UserType.subscriptions`")
subscriptions: [UserSubscription!]!
token: tokenType! @deprecated(reason: "use [/api/auth/authorize]")
token: tokenType! @deprecated(reason: "use [/api/auth/sign-in?native=true] instead")
}
type VersionRejectedDataType {