diff --git a/packages/backend/server/src/core/entitlement/__tests__/projection.spec.ts b/packages/backend/server/src/core/entitlement/__tests__/projection.spec.ts index 6b424b1506..e6a69d8322 100644 --- a/packages/backend/server/src/core/entitlement/__tests__/projection.spec.ts +++ b/packages/backend/server/src/core/entitlement/__tests__/projection.spec.ts @@ -250,6 +250,52 @@ test('backfill marks selfhost team subscriptions as needing license revalidation ); }); +test('backfill removes dangling legacy subscriptions and entitlements', async t => { + await t.context.db.subscription.createMany({ + data: [ + { + targetId: randomUUID(), + plan: SubscriptionPlan.Pro, + recurring: SubscriptionRecurring.Yearly, + status: SubscriptionStatus.Active, + start: new Date(), + }, + { + targetId: randomUUID(), + plan: SubscriptionPlan.Team, + recurring: SubscriptionRecurring.Yearly, + status: SubscriptionStatus.Active, + start: new Date(), + }, + ], + }); + await t.context.db.entitlement.createMany({ + data: [ + { + targetType: 'user', + targetId: randomUUID(), + source: 'cloud_subscription', + plan: 'pro', + status: 'active', + subjectId: randomUUID(), + }, + { + targetType: 'workspace', + targetId: randomUUID(), + source: 'cloud_subscription', + plan: 'team', + status: 'active', + subjectId: randomUUID(), + }, + ], + }); + + await t.context.projection.backfillEntitlementsAndQuotaStates(); + + t.is(await t.context.db.subscription.count(), 0); + t.is(await t.context.db.entitlement.count(), 0); +}); + test('key based selfhost entitlements without raw payload need reupload', async t => { const owner = await t.context.models.user.create({ email: `${randomUUID()}@affine.pro`, diff --git a/packages/backend/server/src/core/entitlement/projection.ts b/packages/backend/server/src/core/entitlement/projection.ts index d4d4624083..1c08f9c763 100644 --- a/packages/backend/server/src/core/entitlement/projection.ts +++ b/packages/backend/server/src/core/entitlement/projection.ts @@ -34,13 +34,13 @@ export class LegacyEntitlementProjectionService { targetId, }: Events['entitlement.changed']) { if (targetType === 'user') { - await this.projectCloudSubscriptions('user', targetId); - await this.projectUserFeatures(targetId); + await this.#projectCloudSubscriptions('user', targetId); + await this.#projectUserFeatures(targetId); } else if (targetType === 'workspace') { - await this.projectCloudSubscriptions('workspace', targetId); + await this.#projectCloudSubscriptions('workspace', targetId); await Promise.all([ - this.projectWorkspaceFeatures(targetId), - this.projectInstalledLicense(targetId), + this.#projectWorkspaceFeatures(targetId), + this.#projectInstalledLicense(targetId), ]); } } @@ -49,7 +49,7 @@ export class LegacyEntitlementProjectionService { async onWorkspaceQuotaStateChanged({ workspaceId, }: Events['workspace.quota_state.changed']) { - await this.projectReadonlyFeature(workspaceId); + await this.#projectReadonlyFeature(workspaceId); } async scanInstalledLicenses() { @@ -88,6 +88,8 @@ export class LegacyEntitlementProjectionService { } async backfillEntitlementsAndQuotaStates() { + await this.#cleanupDanglingLegacyEntitlements(); + const [subscriptions, users, workspaces] = await Promise.all([ this.db.subscription.findMany(), this.db.user.findMany({ select: { id: true } }), @@ -95,6 +97,9 @@ export class LegacyEntitlementProjectionService { ]); for (const subscription of subscriptions) { + if (!(await this.#subscriptionTargetExists(subscription))) { + continue; + } if (subscription.plan === SubscriptionPlan.SelfHostedTeam) { await this.entitlement.markSelfhostLicenseNeedsReupload({ licenseKey: subscription.targetId, @@ -148,8 +153,74 @@ export class LegacyEntitlementProjectionService { ]); } - private async projectUserFeatures(userId: string) { - const entitlements = await this.activeEntitlements('user', userId); + async #cleanupDanglingLegacyEntitlements() { + await this.db.$executeRaw` + DELETE FROM entitlements entitlement + WHERE ( + entitlement.target_type = 'user' + AND NOT EXISTS ( + SELECT 1 + FROM users + WHERE users.id = entitlement.target_id + ) + ) + OR ( + entitlement.target_type = 'workspace' + AND NOT EXISTS ( + SELECT 1 + FROM workspaces + WHERE workspaces.id = entitlement.target_id + ) + ) + `; + + await this.db.$executeRaw` + DELETE FROM subscriptions subscription + WHERE ( + subscription.plan IN (${SubscriptionPlan.Pro}, ${SubscriptionPlan.AI}) + AND NOT EXISTS ( + SELECT 1 + FROM users + WHERE users.id = subscription.target_id + ) + ) + OR ( + subscription.plan = ${SubscriptionPlan.Team} + AND NOT EXISTS ( + SELECT 1 + FROM workspaces + WHERE workspaces.id = subscription.target_id + ) + ) + `; + } + + async #subscriptionTargetExists(subscription: { + targetId: string; + plan: string; + }) { + if ( + subscription.plan === SubscriptionPlan.Pro || + subscription.plan === SubscriptionPlan.AI + ) { + return !!(await this.db.user.findUnique({ + where: { id: subscription.targetId }, + select: { id: true }, + })); + } + + if (subscription.plan === SubscriptionPlan.Team) { + return !!(await this.db.workspace.findUnique({ + where: { id: subscription.targetId }, + select: { id: true }, + })); + } + + return true; + } + + async #projectUserFeatures(userId: string) { + const entitlements = await this.#activeEntitlements('user', userId); const quotaEntitlement = entitlements.find(entitlement => ['lifetime_pro', 'pro'].includes(entitlement.plan) ); @@ -190,7 +261,7 @@ export class LegacyEntitlementProjectionService { } } - private async projectWorkspaceFeatures(workspaceId: string) { + async #projectWorkspaceFeatures(workspaceId: string) { const [entitlement, resolved] = await Promise.all([ this.entitlement.getBestEntitlement('workspace', workspaceId), this.entitlement.resolveWorkspaceEntitlement(workspaceId), @@ -215,7 +286,7 @@ export class LegacyEntitlementProjectionService { } } - private async projectCloudSubscriptions( + async #projectCloudSubscriptions( targetType: 'user' | 'workspace', targetId: string ) { @@ -229,13 +300,15 @@ export class LegacyEntitlementProjectionService { orderBy: { updatedAt: 'asc' }, }); - for (const entitlement of this.projectableCloudEntitlements(entitlements)) { + for (const entitlement of this.#projectableCloudEntitlements( + entitlements + )) { const metadata = entitlement.metadata as Metadata; await this.db.subscription.upsert({ where: { targetId_plan: { targetId, - plan: this.subscriptionPlan(entitlement.plan), + plan: this.#subscriptionPlan(entitlement.plan), }, }, update: { @@ -243,21 +316,21 @@ export class LegacyEntitlementProjectionService { variant: metadata.variant ?? null, quantity: entitlement.quantity ?? 1, stripeSubscriptionId: metadata.stripeSubscriptionId ?? null, - provider: this.provider(metadata.provider), - status: this.subscriptionStatus(entitlement.status), + provider: this.#provider(metadata.provider), + status: this.#subscriptionStatus(entitlement.status), start: entitlement.startsAt ?? entitlement.createdAt, end: entitlement.expiresAt, trialEnd: entitlement.graceUntil, }, create: { targetId, - plan: this.subscriptionPlan(entitlement.plan), + plan: this.#subscriptionPlan(entitlement.plan), recurring: metadata.recurring ?? SubscriptionRecurring.Monthly, variant: metadata.variant ?? null, quantity: entitlement.quantity ?? 1, stripeSubscriptionId: metadata.stripeSubscriptionId ?? null, - provider: this.provider(metadata.provider), - status: this.subscriptionStatus(entitlement.status), + provider: this.#provider(metadata.provider), + status: this.#subscriptionStatus(entitlement.status), start: entitlement.startsAt ?? entitlement.createdAt, end: entitlement.expiresAt, trialEnd: entitlement.graceUntil, @@ -277,17 +350,17 @@ export class LegacyEntitlementProjectionService { } } - private *projectableCloudEntitlements(entitlements: Entitlement[]) { + *#projectableCloudEntitlements(entitlements: Entitlement[]) { const byPlan = new Map(); for (const entitlement of entitlements) { - const plan = this.subscriptionPlan(entitlement.plan); + const plan = this.#subscriptionPlan(entitlement.plan); const current = byPlan.get(plan); if ( !current || - this.subscriptionProjectionPriority(entitlement) > - this.subscriptionProjectionPriority(current) + this.#subscriptionProjectionPriority(entitlement) > + this.#subscriptionProjectionPriority(current) ) { byPlan.set(plan, entitlement); } @@ -296,7 +369,7 @@ export class LegacyEntitlementProjectionService { yield* byPlan.values(); } - private subscriptionProjectionPriority(entitlement: { + #subscriptionProjectionPriority(entitlement: { status: string; updatedAt: Date; }) { @@ -312,7 +385,7 @@ export class LegacyEntitlementProjectionService { ); } - private async projectInstalledLicense(workspaceId: string) { + async #projectInstalledLicense(workspaceId: string) { const [entitlements, resolved] = await Promise.all([ this.db.entitlement.findMany({ where: { @@ -326,8 +399,8 @@ export class LegacyEntitlementProjectionService { ]); const entitlement = entitlements.sort( (left, right) => - this.installedLicenseStatusPriority(right.status) - - this.installedLicenseStatusPriority(left.status) || + this.#installedLicenseStatusPriority(right.status) - + this.#installedLicenseStatusPriority(left.status) || Number(!!right.signedPayload) - Number(!!left.signedPayload) || right.updatedAt.getTime() - left.updatedAt.getTime() )[0]; @@ -386,7 +459,7 @@ export class LegacyEntitlementProjectionService { }); } - private installedLicenseStatusPriority(status: string) { + #installedLicenseStatusPriority(status: string) { if (status === 'active' || status === 'grace') { return 3; } @@ -399,7 +472,7 @@ export class LegacyEntitlementProjectionService { return 0; } - private async projectReadonlyFeature(workspaceId: string) { + async #projectReadonlyFeature(workspaceId: string) { const state = await this.db.effectiveWorkspaceQuotaState.findUnique({ where: { workspaceId, @@ -420,7 +493,7 @@ export class LegacyEntitlementProjectionService { } } - private async activeEntitlements( + async #activeEntitlements( targetType: 'user' | 'workspace', targetId: string ) { @@ -439,7 +512,7 @@ export class LegacyEntitlementProjectionService { return count > 0; } - private subscriptionPlan(plan: string) { + #subscriptionPlan(plan: string) { if (plan === 'lifetime_pro') { return SubscriptionPlan.Pro; } @@ -449,7 +522,7 @@ export class LegacyEntitlementProjectionService { return plan; } - private subscriptionStatus(status: string) { + #subscriptionStatus(status: string) { if (status === 'active') { return SubscriptionStatus.Active; } @@ -459,7 +532,7 @@ export class LegacyEntitlementProjectionService { return SubscriptionStatus.Canceled; } - private provider(provider: string | null | undefined) { + #provider(provider: string | null | undefined) { return provider === 'revenuecat' ? 'revenuecat' : 'stripe'; } } diff --git a/packages/backend/server/src/data/__tests__/migrations.spec.ts b/packages/backend/server/src/data/__tests__/migrations.spec.ts new file mode 100644 index 0000000000..cc5d1bd5b4 --- /dev/null +++ b/packages/backend/server/src/data/__tests__/migrations.spec.ts @@ -0,0 +1,92 @@ +import { ModuleRef } from '@nestjs/core'; +import { PrismaClient } from '@prisma/client'; +import ava, { TestFn } from 'ava'; + +import { createTestingModule, type TestingModule } from '../../__tests__/utils'; +import { Models } from '../../models'; +import { BackfillPermissionProjection1765500000000 } from '../migrations/1765500000000-backfill-permission-projection'; + +interface Context { + module: TestingModule; + db: PrismaClient; + models: Models; +} + +const test = ava as TestFn; + +test.before(async t => { + t.context.module = await createTestingModule(); + t.context.db = t.context.module.get(PrismaClient); + t.context.models = t.context.module.get(Models); +}); + +test.beforeEach(async t => { + await t.context.module.initTestingDB(); +}); + +test.after.always(async t => { + await t.context.module.close(); +}); + +test('permission backfill repairs ownerless workspaces before runtime state projection', async t => { + const emptyWorkspace = await t.context.db.workspace.create({ + data: { public: false }, + }); + const member = await t.context.models.user.create({ + email: 'member@affine.pro', + }); + const memberWorkspace = await t.context.db.workspace.create({ + data: { public: false }, + }); + await t.context.db.workspaceMember.create({ + data: { + workspaceId: memberWorkspace.id, + userId: member.id, + role: 'member', + state: 'active', + source: 'legacy', + }, + }); + + const ref = { + get(token: unknown) { + if (token === Models) { + return t.context.models; + } + return { + async getWorkspaceState() { + return { + isReadonly: false, + readonlyReasons: [], + }; + }, + }; + }, + } as unknown as ModuleRef; + + await BackfillPermissionProjection1765500000000.up(t.context.db, ref); + + t.is( + await t.context.db.workspace.count({ where: { id: emptyWorkspace.id } }), + 0 + ); + t.like( + await t.context.db.workspaceMember.findFirstOrThrow({ + where: { + workspaceId: memberWorkspace.id, + userId: member.id, + state: 'active', + }, + }), + { role: 'owner' } + ); + t.like( + await t.context.db.workspaceUserRole.findFirstOrThrow({ + where: { + workspaceId: memberWorkspace.id, + userId: member.id, + }, + }), + { type: 99 } + ); +}); diff --git a/packages/backend/server/src/data/migrations/1765500000000-backfill-permission-projection.ts b/packages/backend/server/src/data/migrations/1765500000000-backfill-permission-projection.ts index 2e9d67f63e..1daef6ebe3 100644 --- a/packages/backend/server/src/data/migrations/1765500000000-backfill-permission-projection.ts +++ b/packages/backend/server/src/data/migrations/1765500000000-backfill-permission-projection.ts @@ -5,12 +5,14 @@ import { WorkspacePolicyService } from '../../core/permission/policy'; import { Models } from '../../models'; export class BackfillPermissionProjection1765500000000 { - static async up(_db: PrismaClient, ref: ModuleRef) { + static async up(db: PrismaClient, ref: ModuleRef) { const models = ref.get(Models, { strict: false }); await models.permissionProjection.backfillLegacyProjection(); + await ensureWorkspaceAdminStatsDirtyTriggerGuard(db); + await repairOwnerlessWorkspaces(db); const policy = ref.get(WorkspacePolicyService, { strict: false }); - const workspaces = await _db.workspace.findMany({ + const workspaces = await db.workspace.findMany({ select: { id: true }, }); for (const workspace of workspaces) { @@ -26,3 +28,81 @@ export class BackfillPermissionProjection1765500000000 { static async down(_db: PrismaClient) {} } + +async function ensureWorkspaceAdminStatsDirtyTriggerGuard(db: PrismaClient) { + await db.$executeRaw` + CREATE OR REPLACE FUNCTION workspace_admin_stats_mark_dirty() RETURNS TRIGGER AS $$ + DECLARE + wid VARCHAR; + BEGIN + wid := COALESCE(NEW."workspace_id", OLD."workspace_id"); + IF wid IS NULL THEN + RETURN NULL; + END IF; + + IF NOT EXISTS (SELECT 1 FROM "workspaces" WHERE "id" = wid) THEN + RETURN NULL; + END IF; + + INSERT INTO "workspace_admin_stats_dirty" ("workspace_id", "updated_at") + VALUES (wid, NOW()) + ON CONFLICT ("workspace_id") + DO UPDATE SET "updated_at" = EXCLUDED."updated_at"; + + RETURN NULL; + END; + $$ LANGUAGE plpgsql + `; +} + +async function repairOwnerlessWorkspaces(db: PrismaClient) { + await db.$executeRaw` + WITH ownerless AS ( + SELECT w.id + FROM workspaces w + WHERE NOT EXISTS ( + SELECT 1 + FROM workspace_members owner + WHERE owner.workspace_id = w.id + AND owner.role = 'owner' + AND owner.state = 'active' + ) + ), + accepted_members AS ( + SELECT id + FROM ( + SELECT + wm.id, + row_number() OVER ( + PARTITION BY wm.workspace_id + ORDER BY wm.created_at ASC, wm.id ASC + ) AS rn + FROM workspace_members wm + JOIN ownerless o ON o.id = wm.workspace_id + WHERE wm.state = 'active' + ) ranked + WHERE rn = 1 + ) + UPDATE workspace_members wm + SET role = 'owner', updated_at = now() + FROM accepted_members am + WHERE wm.id = am.id + `; + + await db.$executeRaw` + DELETE FROM workspaces w + WHERE NOT EXISTS ( + SELECT 1 + FROM workspace_members owner + WHERE owner.workspace_id = w.id + AND owner.role = 'owner' + AND owner.state = 'active' + ) + AND NOT EXISTS ( + SELECT 1 + FROM workspace_members member + WHERE member.workspace_id = w.id + AND member.state = 'active' + ) + `; +}