From 67b6c28d6797d6b7014273f245f5c78cab174946 Mon Sep 17 00:00:00 2001 From: fengmk2 Date: Mon, 10 Feb 2025 12:01:14 +0000 Subject: [PATCH] refactor(server): use user model on oauth plugin (#10031) close CLOUD-117 --- .../server/src/__tests__/auth/auth.e2e.ts | 2 +- .../server/src/__tests__/models/user.spec.ts | 14 ++-- .../src/__tests__/oauth/controller.spec.ts | 12 +-- packages/backend/server/src/models/user.ts | 6 +- .../server/src/plugins/oauth/controller.ts | 78 ++++++++----------- 5 files changed, 48 insertions(+), 64 deletions(-) diff --git a/packages/backend/server/src/__tests__/auth/auth.e2e.ts b/packages/backend/server/src/__tests__/auth/auth.e2e.ts index 8628b60267..2e1355a44c 100644 --- a/packages/backend/server/src/__tests__/auth/auth.e2e.ts +++ b/packages/backend/server/src/__tests__/auth/auth.e2e.ts @@ -187,7 +187,7 @@ test('should revoke token after change user identify', async t => { // change password { - const u3Email = 'u3@affine.pro'; + const u3Email = 'u3333@affine.pro'; await app.logout(); const u3 = await app.signup(u3Email); diff --git a/packages/backend/server/src/__tests__/models/user.spec.ts b/packages/backend/server/src/__tests__/models/user.spec.ts index e1370eb076..a7a590c51c 100644 --- a/packages/backend/server/src/__tests__/models/user.spec.ts +++ b/packages/backend/server/src/__tests__/models/user.spec.ts @@ -1,4 +1,3 @@ -import { PrismaClient } from '@prisma/client'; import ava, { TestFn } from 'ava'; import Sinon from 'sinon'; @@ -276,16 +275,13 @@ test('should trigger user.deleted event', async t => { }); test('should paginate users', async t => { - const db = t.context.module.get(PrismaClient); const now = Date.now(); await Promise.all( Array.from({ length: 100 }).map((_, i) => - db.user.create({ - data: { - name: `test${i}`, - email: `test${i}@affine.pro`, - createdAt: new Date(now + i), - }, + t.context.user.create({ + name: `test-paginate-${i}`, + email: `test-paginate-${i}@affine.pro`, + createdAt: new Date(now + i), }) ) ); @@ -294,7 +290,7 @@ test('should paginate users', async t => { t.is(users.length, 10); t.deepEqual( users.map(user => user.email), - Array.from({ length: 10 }).map((_, i) => `test${i}@affine.pro`) + Array.from({ length: 10 }).map((_, i) => `test-paginate-${i}@affine.pro`) ); }); diff --git a/packages/backend/server/src/__tests__/oauth/controller.spec.ts b/packages/backend/server/src/__tests__/oauth/controller.spec.ts index f9ebfd09db..b71bf4a112 100644 --- a/packages/backend/server/src/__tests__/oauth/controller.spec.ts +++ b/packages/backend/server/src/__tests__/oauth/controller.spec.ts @@ -308,7 +308,7 @@ test('should not throw if account registered', async t => { }); test('should be able to fullfil user with oauth sign in', async t => { - const { app, db } = t.context; + const { app, models } = t.context; const u3 = await app.createUser('u3@affine.pro'); @@ -321,11 +321,11 @@ test('should be able to fullfil user with oauth sign in', async t => { t.truthy(sessionUser); t.is(sessionUser!.email, u3.email); - const account = await db.connectedAccount.findFirst({ - where: { - userId: u3.id, - }, - }); + const account = await models.user.getConnectedAccount( + OAuthProviderName.Google, + '1' + ); t.truthy(account); + t.is(account!.user.id, u3.id); }); diff --git a/packages/backend/server/src/models/user.ts b/packages/backend/server/src/models/user.ts index c3fcf5378f..21a0df4197 100644 --- a/packages/backend/server/src/models/user.ts +++ b/packages/backend/server/src/models/user.ts @@ -241,9 +241,13 @@ export class UserModel extends BaseModel { // #region ConnectedAccount async createConnectedAccount(data: CreateConnectedAccountInput) { - return await this.db.connectedAccount.create({ + const account = await this.db.connectedAccount.create({ data, }); + this.logger.log( + `Connected account ${account.provider}:${account.id} created` + ); + return account; } async getConnectedAccount(provider: string, providerAccountId: string) { diff --git a/packages/backend/server/src/plugins/oauth/controller.ts b/packages/backend/server/src/plugins/oauth/controller.ts index 9e8871045e..2524e9c2cd 100644 --- a/packages/backend/server/src/plugins/oauth/controller.ts +++ b/packages/backend/server/src/plugins/oauth/controller.ts @@ -7,7 +7,7 @@ import { Req, Res, } from '@nestjs/common'; -import { ConnectedAccount, PrismaClient } from '@prisma/client'; +import { ConnectedAccount } from '@prisma/client'; import type { Request, Response } from 'express'; import { @@ -30,8 +30,7 @@ export class OAuthController { private readonly auth: AuthService, private readonly oauth: OAuthService, private readonly models: Models, - private readonly providerFactory: OAuthProviderFactory, - private readonly db: PrismaClient + private readonly providerFactory: OAuthProviderFactory ) {} @Public() @@ -120,48 +119,39 @@ export class OAuthController { externalAccount: OAuthAccount, tokens: Tokens ) { - const connectedUser = await this.db.connectedAccount.findFirst({ - where: { - provider, - providerAccountId: externalAccount.id, - }, - include: { - user: true, - }, - }); + const connectedAccount = await this.models.user.getConnectedAccount( + provider, + externalAccount.id + ); - if (connectedUser) { + if (connectedAccount) { // already connected - await this.updateConnectedAccount(connectedUser, tokens); - - return connectedUser.user; + await this.updateConnectedAccount(connectedAccount, tokens); + return connectedAccount.user; } const user = await this.models.user.fulfill(externalAccount.email, { avatarUrl: externalAccount.avatarUrl, }); - await this.db.connectedAccount.create({ - data: { - userId: user.id, - provider, - providerAccountId: externalAccount.id, - ...tokens, - }, + await this.models.user.createConnectedAccount({ + userId: user.id, + provider, + providerAccountId: externalAccount.id, + ...tokens, }); + return user; } private async updateConnectedAccount( - connectedUser: ConnectedAccount, + connectedAccount: ConnectedAccount, tokens: Tokens ) { - return this.db.connectedAccount.update({ - where: { - id: connectedUser.id, - }, - data: tokens, - }); + return await this.models.user.updateConnectedAccount( + connectedAccount.id, + tokens + ); } /** @@ -175,26 +165,20 @@ export class OAuthController { externalAccount: OAuthAccount, tokens: Tokens ) { - const connectedUser = await this.db.connectedAccount.findFirst({ - where: { - provider, - providerAccountId: externalAccount.id, - }, - }); - - if (connectedUser) { - if (connectedUser.id !== user.id) { + const connectedAccount = await this.models.user.getConnectedAccount( + provider, + externalAccount.id + ); + if (connectedAccount) { + if (connectedAccount.userId !== user.id) { throw new OauthAccountAlreadyConnected(); } } else { - await this.db.connectedAccount.create({ - data: { - userId: user.id, - provider, - providerAccountId: externalAccount.id, - accessToken: tokens.accessToken, - refreshToken: tokens.refreshToken, - }, + await this.models.user.createConnectedAccount({ + userId: user.id, + provider, + providerAccountId: externalAccount.id, + ...tokens, }); } }