diff --git a/packages/backend/server/src/plugins/payment/resolver.ts b/packages/backend/server/src/plugins/payment/resolver.ts index e17895d24f..9735afcb47 100644 --- a/packages/backend/server/src/plugins/payment/resolver.ts +++ b/packages/backend/server/src/plugins/payment/resolver.ts @@ -7,6 +7,7 @@ import { Args, Context, Field, + InputType, Int, Mutation, ObjectType, @@ -128,6 +129,31 @@ class UserInvoiceType implements Partial { updatedAt!: Date; } +@InputType() +class CreateCheckoutSessionInput { + @Field(() => SubscriptionRecurring, { + nullable: true, + defaultValue: SubscriptionRecurring.Yearly, + }) + recurring!: SubscriptionRecurring; + + @Field(() => SubscriptionPlan, { + nullable: true, + defaultValue: SubscriptionPlan.Pro, + }) + plan!: SubscriptionPlan; + + @Field(() => String, { nullable: true }) + coupon!: string | null; + + @Field(() => String, { nullable: true }) + successCallbackLink!: string | null; + + // @FIXME(forehalo): we should put this field in the header instead of as a explicity args + @Field(() => String) + idempotencyKey!: string; +} + @Auth() @Resolver(() => UserSubscriptionType) export class SubscriptionResolver { @@ -182,7 +208,11 @@ export class SubscriptionResolver { }); } + /** + * @deprecated + */ @Mutation(() => String, { + deprecationReason: 'use `createCheckoutSession` instead', description: 'Create a subscription checkout link of stripe', }) async checkout( @@ -193,6 +223,7 @@ export class SubscriptionResolver { ) { const session = await this.service.createCheckoutSession({ user, + plan: SubscriptionPlan.Pro, recurring, redirectUrl: `${this.config.baseUrl}/upgrade-success`, idempotencyKey, @@ -205,6 +236,36 @@ export class SubscriptionResolver { return session.url; } + @Mutation(() => String, { + description: 'Create a subscription checkout link of stripe', + }) + async createCheckoutSession( + @CurrentUser() user: User, + @Args({ name: 'input', type: () => CreateCheckoutSessionInput }) + input: CreateCheckoutSessionInput + ) { + const session = await this.service.createCheckoutSession({ + user, + plan: input.plan, + recurring: input.recurring, + promotionCode: input.coupon, + redirectUrl: + input.successCallbackLink ?? `${this.config.baseUrl}/upgrade-success`, + idempotencyKey: input.idempotencyKey, + }); + + if (!session.url) { + throw new GraphQLError('Failed to create checkout session', { + extensions: { + status: HttpStatus[HttpStatus.BAD_GATEWAY], + code: HttpStatus.BAD_GATEWAY, + }, + }); + } + + return session.url; + } + @Mutation(() => String, { description: 'Create a stripe customer portal to manage payment methods', }) diff --git a/packages/backend/server/src/plugins/payment/service.ts b/packages/backend/server/src/plugins/payment/service.ts index b7b6e62994..069763ffef 100644 --- a/packages/backend/server/src/plugins/payment/service.ts +++ b/packages/backend/server/src/plugins/payment/service.ts @@ -69,13 +69,15 @@ export class SubscriptionService { async createCheckoutSession({ user, recurring, + plan, + promotionCode, redirectUrl, idempotencyKey, - plan = SubscriptionPlan.Pro, }: { user: User; - plan?: SubscriptionPlan; recurring: SubscriptionRecurring; + plan: SubscriptionPlan; + promotionCode?: string | null; redirectUrl: string; idempotencyKey: string; }) { @@ -95,7 +97,28 @@ export class SubscriptionService { `${idempotencyKey}-getOrCreateCustomer`, user ); - const coupon = await this.getAvailableCoupon(user, CouponType.EarlyAccess); + + let discount: { coupon?: string; promotion_code?: string } | undefined; + + if (promotionCode) { + const code = await this.getAvailablePromotionCode( + promotionCode, + customer.stripeCustomerId + ); + if (code) { + discount ??= {}; + discount.promotion_code = code; + } + } else { + const coupon = await this.getAvailableCoupon( + user, + CouponType.EarlyAccess + ); + if (coupon) { + discount ??= {}; + discount.coupon = coupon; + } + } return await this.stripe.checkout.sessions.create( { @@ -108,13 +131,11 @@ export class SubscriptionService { tax_id_collection: { enabled: true, }, - ...(coupon + ...(discount ? { - discounts: [{ coupon }], + discounts: [discount], } - : { - allow_promotion_codes: true, - }), + : { allow_promotion_codes: true }), mode: 'subscription', success_url: redirectUrl, customer: customer.stripeCustomerId, @@ -643,4 +664,33 @@ export class SubscriptionService { return null; } + + private async getAvailablePromotionCode( + userFacingPromotionCode: string, + customer?: string + ) { + const list = await this.stripe.promotionCodes.list({ + code: userFacingPromotionCode, + active: true, + limit: 1, + }); + + const code = list.data[0]; + if (!code) { + return null; + } + + let available = false; + + if (code.customer) { + available = + typeof code.customer === 'string' + ? code.customer === customer + : code.customer.id === customer; + } else { + available = true; + } + + return available ? code.id : null; + } } diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 70cca83979..beb0bbdce7 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -2,6 +2,14 @@ # THIS FILE WAS AUTOMATICALLY GENERATED (DO NOT MODIFY) # ------------------------------------------------------ +input CreateCheckoutSessionInput { + coupon: String + idempotencyKey: String! + plan: SubscriptionPlan = Pro + recurring: SubscriptionRecurring = Yearly + successCallbackLink: String +} + """ A date-time string at UTC, such as 2019-12-03T09:54:33Z, compliant with the date-time format. """ @@ -107,7 +115,10 @@ type Mutation { changePassword(newPassword: String!, token: String!): UserType! """Create a subscription checkout link of stripe""" - checkout(idempotencyKey: String!, recurring: SubscriptionRecurring!): String! + checkout(idempotencyKey: String!, recurring: SubscriptionRecurring!): String! @deprecated(reason: "use `createCheckoutSession` instead") + + """Create a subscription checkout link of stripe""" + createCheckoutSession(input: CreateCheckoutSessionInput!): String! """Create a stripe customer portal to manage payment methods""" createCustomerPortal: String!