diff --git a/packages/backend/server/src/__tests__/payment/revenuecat.spec.ts b/packages/backend/server/src/__tests__/payment/revenuecat.spec.ts index 21f37e9c77..33f3045430 100644 --- a/packages/backend/server/src/__tests__/payment/revenuecat.spec.ts +++ b/packages/backend/server/src/__tests__/payment/revenuecat.spec.ts @@ -1,4 +1,4 @@ -import { PrismaClient, User } from '@prisma/client'; +import { PrismaClient, type User } from '@prisma/client'; import ava, { TestFn } from 'ava'; import { omit } from 'lodash-es'; import Sinon from 'sinon'; @@ -14,6 +14,7 @@ import { Models } from '../../models'; import { PaymentModule } from '../../plugins/payment'; import { SubscriptionCronJobs } from '../../plugins/payment/cron'; import { UserSubscriptionManager } from '../../plugins/payment/manager'; +import { UserSubscriptionResolver } from '../../plugins/payment/resolver'; import { RcEvent, resolveProductMapping, @@ -39,6 +40,7 @@ type Ctx = { rc: RevenueCatService; webhook: RevenueCatWebhookHandler; controller: RevenueCatWebhookController; + subResolver: UserSubscriptionResolver; mockSub: (subs: Subscription[]) => Sinon.SinonStub; mockSubSeq: (sequences: Subscription[][]) => Sinon.SinonStub; @@ -85,6 +87,7 @@ test.beforeEach(async t => { const rc = app.get(RevenueCatService); const webhook = app.get(RevenueCatWebhookHandler); const controller = app.get(RevenueCatWebhookController); + const subResolver = app.get(UserSubscriptionResolver); t.context.module = app; t.context.db = db; @@ -95,6 +98,7 @@ test.beforeEach(async t => { t.context.rc = rc; t.context.webhook = webhook; t.context.controller = controller; + t.context.subResolver = subResolver; t.context.mockSub = subs => Sinon.stub(rc, 'getSubscriptions').resolves(subs); t.context.mockSubSeq = sequences => { @@ -927,3 +931,90 @@ test('should not dispatch webhook event when authorization header is missing or const after = event.emitAsync.getCalls()?.length || 0; t.is(after - before, 0, 'should not emit event'); }); + +test('should refresh user subscriptions (empty / revenuecat / stripe-only)', async t => { + const { subResolver, db, mockSubSeq } = t.context; + + const currentUser = { + id: user.id, + email: user.email, + avatarUrl: '', + name: '', + disabled: false, + hasPassword: true, + emailVerified: true, + }; + + // prepare mocks: + // first call returns Pro subscription + // second call returns AI subscription. + const stub = mockSubSeq([ + [ + { + identifier: 'Pro', + isTrial: false, + isActive: true, + latestPurchaseDate: new Date('2025-09-01T00:00:00.000Z'), + expirationDate: new Date('2026-09-01T00:00:00.000Z'), + productId: 'app.affine.pro.Annual', + store: 'app_store', + willRenew: true, + duration: null, + }, + ], + [ + { + identifier: 'AI', + isTrial: false, + isActive: true, + latestPurchaseDate: new Date('2025-09-02T00:00:00.000Z'), + expirationDate: new Date('2026-09-02T00:00:00.000Z'), + productId: 'app.affine.pro.ai.Annual', + store: 'play_store', + willRenew: true, + duration: null, + }, + ], + ]); + + // case1: empty -> should sync (first sequence) + { + const subs = await subResolver.refreshUserSubscriptions(currentUser); + t.is(stub.callCount, 1, 'Scenario1: RC API called once'); + t.truthy( + subs.find(s => s.plan === 'pro'), + 'case1: pro saved' + ); + } + + // case2: existing revenuecat -> should sync again (second sequence) + { + const subs = await subResolver.refreshUserSubscriptions(currentUser); + t.is(stub.callCount, 2, 'Scenario2: RC API called second time'); + t.truthy( + subs.find(s => s.plan === 'ai'), + 'case2: ai saved' + ); + } + + // case3: only stripe subscription -> should NOT sync (call count remains 2) + { + await db.subscription.deleteMany({ + where: { targetId: user.id, provider: 'revenuecat' }, + }); + await db.subscription.create({ + data: { + targetId: user.id, + plan: 'pro', + provider: 'stripe', + status: 'active', + recurring: 'monthly', + start: new Date('2025-01-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_123', + }, + }); + const subs = await subResolver.refreshUserSubscriptions(currentUser); + t.is(stub.callCount, 2, 'case3: RC API not called again'); + t.is(subs.length, 1, 'case3: only stripe subscription returned'); + } +}); diff --git a/packages/backend/server/src/plugins/payment/resolver.ts b/packages/backend/server/src/plugins/payment/resolver.ts index 43ede3fbe1..4e492bc941 100644 --- a/packages/backend/server/src/plugins/payment/resolver.ts +++ b/packages/backend/server/src/plugins/payment/resolver.ts @@ -12,8 +12,7 @@ import { ResolveField, Resolver, } from '@nestjs/graphql'; -import type { User } from '@prisma/client'; -import { PrismaClient } from '@prisma/client'; +import { PrismaClient, Provider, type User } from '@prisma/client'; import { GraphQLJSONObject } from 'graphql-scalars'; import { groupBy } from 'lodash-es'; import Stripe from 'stripe'; @@ -31,6 +30,7 @@ import { AccessController } from '../../core/permission'; import { UserType } from '../../core/user'; import { WorkspaceType } from '../../core/workspaces'; import { Invoice, Subscription, WorkspaceSubscriptionManager } from './manager'; +import { RevenueCatWebhookHandler } from './revenuecat'; import { CheckoutParams, SubscriptionService } from './service'; import { InvoiceStatus, @@ -463,7 +463,22 @@ export class SubscriptionResolver { @Resolver(() => UserType) export class UserSubscriptionResolver { - constructor(private readonly db: PrismaClient) {} + constructor( + private readonly db: PrismaClient, + private readonly rcHandler: RevenueCatWebhookHandler + ) {} + + private normalizeSubscription(s: Subscription) { + if ( + s.variant && + ![SubscriptionVariant.EA, SubscriptionVariant.Onetime].includes( + s.variant as SubscriptionVariant + ) + ) { + s.variant = null; + } + return s; + } @ResolveField(() => [SubscriptionType]) async subscriptions( @@ -487,16 +502,9 @@ export class UserSubscriptionResolver { }, }); - subscriptions.forEach(subscription => { - if ( - subscription.variant && - ![SubscriptionVariant.EA, SubscriptionVariant.Onetime].includes( - subscription.variant as SubscriptionVariant - ) - ) { - subscription.variant = null; - } - }); + subscriptions.forEach(subscription => + this.normalizeSubscription(subscription) + ); return subscriptions; } @@ -534,6 +542,71 @@ export class UserSubscriptionResolver { }, }); } + + @Throttle('strict') + @Mutation(() => [SubscriptionType], { + description: 'Refresh current user subscriptions and return latest.', + }) + async refreshUserSubscriptions( + @CurrentUser() user: CurrentUser + ): Promise { + if (!user) { + throw new AuthenticationRequired(); + } + + let current = await this.db.subscription.findMany({ + where: { + targetId: user.id, + status: { + in: [ + SubscriptionStatus.Active, + SubscriptionStatus.Trialing, + SubscriptionStatus.PastDue, + ], + }, + }, + }); + + const existsPlans = Object.values(SubscriptionPlan); + const subscriptions = current.reduce( + (r, s) => { + if (existsPlans.includes(s.plan as SubscriptionPlan)) { + r[s.plan as SubscriptionPlan] = s.provider; + } + return r; + }, + {} as Record + ); + + // has revenuecat subscription or no subscription at all + const shouldSync = + current.length === 0 || + subscriptions.pro === Provider.revenuecat || + subscriptions.ai === Provider.revenuecat; + + if (shouldSync) { + try { + await this.rcHandler.syncAppUser(user.id); + current = await this.db.subscription.findMany({ + where: { + targetId: user.id, + status: { + in: [ + SubscriptionStatus.Active, + SubscriptionStatus.Trialing, + SubscriptionStatus.PastDue, + ], + }, + }, + }); + // ignore errors + } catch {} + } + + current.forEach(subscription => this.normalizeSubscription(subscription)); + + return current; + } } @Resolver(() => WorkspaceType) diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 1c8d8519e4..b1f8180547 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -1299,6 +1299,9 @@ type Mutation { """mark notification as read""" readNotification(id: String!): Boolean! recoverDoc(guid: String!, timestamp: DateTime!, workspaceId: String!): DateTime! + + """Refresh current user subscriptions and return latest.""" + refreshUserSubscriptions: [SubscriptionType!]! releaseDeletedBlobs(workspaceId: String!): Boolean! """Remove user avatar""" diff --git a/packages/common/graphql/src/graphql/index.ts b/packages/common/graphql/src/graphql/index.ts index 4f43d4dbfa..31e46ebc8b 100644 --- a/packages/common/graphql/src/graphql/index.ts +++ b/packages/common/graphql/src/graphql/index.ts @@ -2218,6 +2218,25 @@ export const setWorkspacePublicByIdMutation = { }`, }; +export const refreshSubscriptionMutation = { + id: 'refreshSubscriptionMutation' as const, + op: 'refreshSubscription', + query: `mutation refreshSubscription { + refreshUserSubscriptions { + id + status + plan + recurring + start + end + nextBillAt + canceledAt + variant + } +}`, + deprecations: ["'id' is deprecated: removed"], +}; + export const subscriptionQuery = { id: 'subscriptionQuery' as const, op: 'subscription', diff --git a/packages/common/graphql/src/graphql/subscription-refresh.gql b/packages/common/graphql/src/graphql/subscription-refresh.gql new file mode 100644 index 0000000000..e862de432e --- /dev/null +++ b/packages/common/graphql/src/graphql/subscription-refresh.gql @@ -0,0 +1,13 @@ +mutation refreshSubscription { + refreshUserSubscriptions { + id + status + plan + recurring + start + end + nextBillAt + canceledAt + variant + } +} diff --git a/packages/common/graphql/src/schema.ts b/packages/common/graphql/src/schema.ts index 5f07a20489..5a94ce8038 100644 --- a/packages/common/graphql/src/schema.ts +++ b/packages/common/graphql/src/schema.ts @@ -1451,6 +1451,8 @@ export interface Mutation { /** mark notification as read */ readNotification: Scalars['Boolean']['output']; recoverDoc: Scalars['DateTime']['output']; + /** Refresh current user subscriptions and return latest. */ + refreshUserSubscriptions: Array; releaseDeletedBlobs: Scalars['Boolean']['output']; /** Remove user avatar */ removeAvatar: RemoveAvatar; @@ -5996,6 +5998,26 @@ export type SetWorkspacePublicByIdMutation = { updateWorkspace: { __typename?: 'WorkspaceType'; id: string }; }; +export type RefreshSubscriptionMutationVariables = Exact<{ + [key: string]: never; +}>; + +export type RefreshSubscriptionMutation = { + __typename?: 'Mutation'; + refreshUserSubscriptions: Array<{ + __typename?: 'SubscriptionType'; + id: string | null; + status: SubscriptionStatus; + plan: SubscriptionPlan; + recurring: SubscriptionRecurring; + start: string; + end: string | null; + nextBillAt: string | null; + canceledAt: string | null; + variant: SubscriptionVariant | null; + }>; +}; + export type SubscriptionQueryVariables = Exact<{ [key: string]: never }>; export type SubscriptionQuery = { @@ -7081,6 +7103,11 @@ export type Mutations = variables: SetWorkspacePublicByIdMutationVariables; response: SetWorkspacePublicByIdMutation; } + | { + name: 'refreshSubscriptionMutation'; + variables: RefreshSubscriptionMutationVariables; + response: RefreshSubscriptionMutation; + } | { name: 'updateDocDefaultRoleMutation'; variables: UpdateDocDefaultRoleMutationVariables;