refactor(server): use user model on oauth plugin (#10031)

close CLOUD-117
This commit is contained in:
fengmk2
2025-02-10 12:01:14 +00:00
parent 23364b59a0
commit 67b6c28d67
5 changed files with 48 additions and 64 deletions

View File

@@ -187,7 +187,7 @@ test('should revoke token after change user identify', async t => {
// change password
{
const u3Email = 'u3@affine.pro';
const u3Email = 'u3333@affine.pro';
await app.logout();
const u3 = await app.signup(u3Email);

View File

@@ -1,4 +1,3 @@
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
@@ -276,16 +275,13 @@ test('should trigger user.deleted event', async t => {
});
test('should paginate users', async t => {
const db = t.context.module.get(PrismaClient);
const now = Date.now();
await Promise.all(
Array.from({ length: 100 }).map((_, i) =>
db.user.create({
data: {
name: `test${i}`,
email: `test${i}@affine.pro`,
createdAt: new Date(now + i),
},
t.context.user.create({
name: `test-paginate-${i}`,
email: `test-paginate-${i}@affine.pro`,
createdAt: new Date(now + i),
})
)
);
@@ -294,7 +290,7 @@ test('should paginate users', async t => {
t.is(users.length, 10);
t.deepEqual(
users.map(user => user.email),
Array.from({ length: 10 }).map((_, i) => `test${i}@affine.pro`)
Array.from({ length: 10 }).map((_, i) => `test-paginate-${i}@affine.pro`)
);
});

View File

@@ -308,7 +308,7 @@ test('should not throw if account registered', async t => {
});
test('should be able to fullfil user with oauth sign in', async t => {
const { app, db } = t.context;
const { app, models } = t.context;
const u3 = await app.createUser('u3@affine.pro');
@@ -321,11 +321,11 @@ test('should be able to fullfil user with oauth sign in', async t => {
t.truthy(sessionUser);
t.is(sessionUser!.email, u3.email);
const account = await db.connectedAccount.findFirst({
where: {
userId: u3.id,
},
});
const account = await models.user.getConnectedAccount(
OAuthProviderName.Google,
'1'
);
t.truthy(account);
t.is(account!.user.id, u3.id);
});

View File

@@ -241,9 +241,13 @@ export class UserModel extends BaseModel {
// #region ConnectedAccount
async createConnectedAccount(data: CreateConnectedAccountInput) {
return await this.db.connectedAccount.create({
const account = await this.db.connectedAccount.create({
data,
});
this.logger.log(
`Connected account ${account.provider}:${account.id} created`
);
return account;
}
async getConnectedAccount(provider: string, providerAccountId: string) {

View File

@@ -7,7 +7,7 @@ import {
Req,
Res,
} from '@nestjs/common';
import { ConnectedAccount, PrismaClient } from '@prisma/client';
import { ConnectedAccount } from '@prisma/client';
import type { Request, Response } from 'express';
import {
@@ -30,8 +30,7 @@ export class OAuthController {
private readonly auth: AuthService,
private readonly oauth: OAuthService,
private readonly models: Models,
private readonly providerFactory: OAuthProviderFactory,
private readonly db: PrismaClient
private readonly providerFactory: OAuthProviderFactory
) {}
@Public()
@@ -120,48 +119,39 @@ export class OAuthController {
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedUser = await this.db.connectedAccount.findFirst({
where: {
provider,
providerAccountId: externalAccount.id,
},
include: {
user: true,
},
});
const connectedAccount = await this.models.user.getConnectedAccount(
provider,
externalAccount.id
);
if (connectedUser) {
if (connectedAccount) {
// already connected
await this.updateConnectedAccount(connectedUser, tokens);
return connectedUser.user;
await this.updateConnectedAccount(connectedAccount, tokens);
return connectedAccount.user;
}
const user = await this.models.user.fulfill(externalAccount.email, {
avatarUrl: externalAccount.avatarUrl,
});
await this.db.connectedAccount.create({
data: {
userId: user.id,
provider,
providerAccountId: externalAccount.id,
...tokens,
},
await this.models.user.createConnectedAccount({
userId: user.id,
provider,
providerAccountId: externalAccount.id,
...tokens,
});
return user;
}
private async updateConnectedAccount(
connectedUser: ConnectedAccount,
connectedAccount: ConnectedAccount,
tokens: Tokens
) {
return this.db.connectedAccount.update({
where: {
id: connectedUser.id,
},
data: tokens,
});
return await this.models.user.updateConnectedAccount(
connectedAccount.id,
tokens
);
}
/**
@@ -175,26 +165,20 @@ export class OAuthController {
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedUser = await this.db.connectedAccount.findFirst({
where: {
provider,
providerAccountId: externalAccount.id,
},
});
if (connectedUser) {
if (connectedUser.id !== user.id) {
const connectedAccount = await this.models.user.getConnectedAccount(
provider,
externalAccount.id
);
if (connectedAccount) {
if (connectedAccount.userId !== user.id) {
throw new OauthAccountAlreadyConnected();
}
} else {
await this.db.connectedAccount.create({
data: {
userId: user.id,
provider,
providerAccountId: externalAccount.id,
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
},
await this.models.user.createConnectedAccount({
userId: user.id,
provider,
providerAccountId: externalAccount.id,
...tokens,
});
}
}