diff --git a/packages/backend/server/src/__tests__/models/user.spec.ts b/packages/backend/server/src/__tests__/models/user.spec.ts index d48d629c7d..e1370eb076 100644 --- a/packages/backend/server/src/__tests__/models/user.spec.ts +++ b/packages/backend/server/src/__tests__/models/user.spec.ts @@ -297,3 +297,49 @@ test('should paginate users', async t => { Array.from({ length: 10 }).map((_, i) => `test${i}@affine.pro`) ); }); + +// #region ConnectedAccount + +test('should create, get, update, delete connected account', async t => { + const user = await t.context.user.create({ + email: 'test@affine.pro', + }); + const connectedAccount = await t.context.user.createConnectedAccount({ + userId: user.id, + provider: 'test-provider', + providerAccountId: 'test-provider-account-id', + accessToken: 'test-access-token', + }); + t.truthy(connectedAccount); + + const connectedAccount2 = await t.context.user.getConnectedAccount( + connectedAccount.provider, + connectedAccount.providerAccountId + ); + t.truthy(connectedAccount2); + t.is(connectedAccount2!.id, connectedAccount.id); + t.is(connectedAccount2!.user.id, user.id); + + const updatedConnectedAccount = await t.context.user.updateConnectedAccount( + connectedAccount.id, + { + accessToken: 'new-access-token', + } + ); + t.is(updatedConnectedAccount.accessToken, 'new-access-token'); + // get the updated connected account + const connectedAccount3 = await t.context.user.getConnectedAccount( + connectedAccount.provider, + connectedAccount.providerAccountId + ); + t.is(connectedAccount3!.accessToken, 'new-access-token'); + + await t.context.user.deleteConnectedAccount(connectedAccount.id); + const connectedAccount4 = await t.context.user.getConnectedAccount( + connectedAccount.provider, + connectedAccount.providerAccountId + ); + t.is(connectedAccount4, null); +}); + +// #endregion diff --git a/packages/backend/server/src/models/user.ts b/packages/backend/server/src/models/user.ts index fcaf1b693c..c3fcf5378f 100644 --- a/packages/backend/server/src/models/user.ts +++ b/packages/backend/server/src/models/user.ts @@ -1,5 +1,5 @@ import { Injectable } from '@nestjs/common'; -import { Prisma, type User } from '@prisma/client'; +import { type ConnectedAccount, Prisma, type User } from '@prisma/client'; import { pick } from 'lodash-es'; import { @@ -21,6 +21,15 @@ const publicUserSelect = { type CreateUserInput = Omit & { name?: string }; type UpdateUserInput = Omit, 'id'>; +type CreateConnectedAccountInput = Omit< + Prisma.ConnectedAccountUncheckedCreateInput, + 'id' +> & { accessToken: string }; +type UpdateConnectedAccountInput = Omit< + Prisma.ConnectedAccountUncheckedUpdateInput, + 'id' +>; + declare global { interface Events { 'user.created': User; @@ -35,7 +44,7 @@ declare global { } export type PublicUser = Pick; -export type { User }; +export type { ConnectedAccount, User }; @Injectable() export class UserModel extends BaseModel { @@ -228,4 +237,40 @@ export class UserModel extends BaseModel { async count() { return this.db.user.count(); } + + // #region ConnectedAccount + + async createConnectedAccount(data: CreateConnectedAccountInput) { + return await this.db.connectedAccount.create({ + data, + }); + } + + async getConnectedAccount(provider: string, providerAccountId: string) { + return await this.db.connectedAccount.findFirst({ + where: { provider, providerAccountId }, + include: { + user: true, + }, + }); + } + + async updateConnectedAccount(id: string, data: UpdateConnectedAccountInput) { + return await this.db.connectedAccount.update({ + where: { id }, + data, + }); + } + + async deleteConnectedAccount(id: string) { + const { count } = await this.db.connectedAccount.deleteMany({ + where: { id }, + }); + if (count > 0) { + this.logger.log(`Deleted connected account ${id}`); + } + return count; + } + + // #endregion }