diff --git a/packages/backend/server/.env.example b/packages/backend/server/.env.example index 55d11ef5fb..749873dc40 100644 --- a/packages/backend/server/.env.example +++ b/packages/backend/server/.env.example @@ -3,4 +3,6 @@ NEXTAUTH_URL="http://localhost:8080" OAUTH_EMAIL_SENDER="noreply@toeverything.info" OAUTH_EMAIL_LOGIN="" OAUTH_EMAIL_PASSWORD="" -ENABLE_LOCAL_EMAIL="true" \ No newline at end of file +ENABLE_LOCAL_EMAIL="true" +STRIPE_API_KEY= +STRIPE_WEBHOOK_KEY= diff --git a/packages/backend/server/migrations/20231018074747_payment/migration.sql b/packages/backend/server/migrations/20231018074747_payment/migration.sql new file mode 100644 index 0000000000..abd017259d --- /dev/null +++ b/packages/backend/server/migrations/20231018074747_payment/migration.sql @@ -0,0 +1,68 @@ +-- CreateTable +CREATE TABLE "user_stripe_customers" ( + "user_id" VARCHAR NOT NULL, + "stripe_customer_id" VARCHAR NOT NULL, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "user_stripe_customers_pkey" PRIMARY KEY ("user_id") +); + +-- CreateTable +CREATE TABLE "user_subscriptions" ( + "id" SERIAL NOT NULL, + "user_id" VARCHAR(36) NOT NULL, + "plan" VARCHAR(20) NOT NULL, + "recurring" VARCHAR(20) NOT NULL, + "stripe_subscription_id" TEXT NOT NULL, + "status" VARCHAR(20) NOT NULL, + "start" TIMESTAMPTZ(6) NOT NULL, + "end" TIMESTAMPTZ(6) NOT NULL, + "next_bill_at" TIMESTAMPTZ(6), + "canceled_at" TIMESTAMPTZ(6), + "trial_start" TIMESTAMPTZ(6), + "trial_end" TIMESTAMPTZ(6), + "stripe_schedule_id" VARCHAR, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMPTZ(6) NOT NULL, + + CONSTRAINT "user_subscriptions_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "user_invoices" ( + "id" SERIAL NOT NULL, + "user_id" VARCHAR(36) NOT NULL, + "stripe_invoice_id" TEXT NOT NULL, + "currency" VARCHAR(3) NOT NULL, + "amount" INTEGER NOT NULL, + "status" VARCHAR(20) NOT NULL, + "plan" VARCHAR(20) NOT NULL, + "recurring" VARCHAR(20) NOT NULL, + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMPTZ(6) NOT NULL, + "reason" VARCHAR NOT NULL, + "last_payment_error" TEXT, + + CONSTRAINT "user_invoices_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "user_stripe_customers_stripe_customer_id_key" ON "user_stripe_customers"("stripe_customer_id"); + +-- CreateIndex +CREATE UNIQUE INDEX "user_subscriptions_user_id_key" ON "user_subscriptions"("user_id"); + +-- CreateIndex +CREATE UNIQUE INDEX "user_subscriptions_stripe_subscription_id_key" ON "user_subscriptions"("stripe_subscription_id"); + +-- CreateIndex +CREATE UNIQUE INDEX "user_invoices_stripe_invoice_id_key" ON "user_invoices"("stripe_invoice_id"); + +-- AddForeignKey +ALTER TABLE "user_stripe_customers" ADD CONSTRAINT "user_stripe_customers_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "user_subscriptions" ADD CONSTRAINT "user_subscriptions_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "user_invoices" ADD CONSTRAINT "user_invoices_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 2567bd6c9d..2571db2fb4 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -25,6 +25,7 @@ "@nestjs/apollo": "^12.0.9", "@nestjs/common": "^10.2.7", "@nestjs/core": "^10.2.7", + "@nestjs/event-emitter": "^2.0.2", "@nestjs/graphql": "^12.0.9", "@nestjs/platform-express": "^10.2.7", "@nestjs/platform-socket.io": "^10.2.7", @@ -71,6 +72,7 @@ "rxjs": "^7.8.1", "semver": "^7.5.4", "socket.io": "^4.7.2", + "stripe": "^13.6.0", "ws": "^8.14.2", "yjs": "^13.6.8" }, diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 70c6edf834..bea9b7f8d5 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -49,6 +49,9 @@ model User { /// Not available if user signed up through OAuth providers password String? @db.VarChar features UserFeatureGates[] + customer UserStripeCustomer? + subscription UserSubscription? + invoices UserInvoice[] @@map("users") } @@ -164,3 +167,65 @@ model NewFeaturesWaitingList { @@map("new_features_waiting_list") } + +model UserStripeCustomer { + userId String @id @map("user_id") @db.VarChar + stripeCustomerId String @unique @map("stripe_customer_id") @db.VarChar + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@map("user_stripe_customers") +} + +model UserSubscription { + id Int @id @default(autoincrement()) @db.Integer + userId String @unique @map("user_id") @db.VarChar(36) + plan String @db.VarChar(20) + // yearly/monthly + recurring String @db.VarChar(20) + // subscription.id + stripeSubscriptionId String @unique @map("stripe_subscription_id") + // subscription.status, active/past_due/canceled/unpaid... + status String @db.VarChar(20) + // subscription.current_period_start + start DateTime @map("start") @db.Timestamptz(6) + // subscription.current_period_end + end DateTime @map("end") @db.Timestamptz(6) + // subscription.billing_cycle_anchor + nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(6) + // subscription.canceled_at + canceledAt DateTime? @map("canceled_at") @db.Timestamptz(6) + // subscription.trial_start + trialStart DateTime? @map("trial_start") @db.Timestamptz(6) + // subscription.trial_end + trialEnd DateTime? @map("trial_end") @db.Timestamptz(6) + stripeScheduleId String? @map("stripe_schedule_id") @db.VarChar + + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@map("user_subscriptions") +} + +model UserInvoice { + id Int @id @default(autoincrement()) @db.Integer + userId String @map("user_id") @db.VarChar(36) + stripeInvoiceId String @unique @map("stripe_invoice_id") + currency String @db.VarChar(3) + // CNY 12.50 stored as 1250 + amount Int @db.Integer + status String @db.VarChar(20) + plan String @db.VarChar(20) + recurring String @db.VarChar(20) + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) + // billing reason + reason String @db.VarChar + lastPaymentError String? @map("last_payment_error") @db.Text + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@map("user_invoices") +} diff --git a/packages/backend/server/src/config/def.ts b/packages/backend/server/src/config/def.ts index 345b3942e1..b7d1493d4a 100644 --- a/packages/backend/server/src/config/def.ts +++ b/packages/backend/server/src/config/def.ts @@ -363,4 +363,13 @@ export interface AFFiNEConfig { experimentalMergeWithJwstCodec: boolean; }; }; + + payment: { + stripe: { + keys: { + APIKey: string; + webhookKey: string; + }; + } & import('stripe').Stripe.StripeConfig; + }; } diff --git a/packages/backend/server/src/config/default.ts b/packages/backend/server/src/config/default.ts index b23b915f79..ae46469847 100644 --- a/packages/backend/server/src/config/default.ts +++ b/packages/backend/server/src/config/default.ts @@ -89,6 +89,8 @@ export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => { 'boolean', ], ENABLE_LOCAL_EMAIL: ['auth.localEmail', 'boolean'], + STRIPE_API_KEY: 'payment.stripe.keys.APIKey', + STRIPE_WEBHOOK_KEY: 'payment.stripe.keys.webhookKey', } satisfies AFFiNEConfig['ENV_MAP'], affineEnv: 'dev', get affine() { @@ -207,6 +209,15 @@ export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => { experimentalMergeWithJwstCodec: false, }, }, + payment: { + stripe: { + keys: { + APIKey: '', + webhookKey: '', + }, + apiVersion: '2023-08-16', + }, + }, } satisfies AFFiNEConfig; applyEnvToConfig(defaultConfig); diff --git a/packages/backend/server/src/index.ts b/packages/backend/server/src/index.ts index 4b83dbd446..446b82c60d 100644 --- a/packages/backend/server/src/index.ts +++ b/packages/backend/server/src/index.ts @@ -59,6 +59,7 @@ if (NODE_ENV === 'production') { const app = await NestFactory.create(AppModule, { cors: true, + rawBody: true, bodyParser: true, logger: NODE_ENV !== 'production' || AFFINE_ENV !== 'production' diff --git a/packages/backend/server/src/modules/index.ts b/packages/backend/server/src/modules/index.ts index c11e3428b1..600c5154b2 100644 --- a/packages/backend/server/src/modules/index.ts +++ b/packages/backend/server/src/modules/index.ts @@ -1,8 +1,10 @@ import { DynamicModule, Type } from '@nestjs/common'; +import { EventEmitterModule } from '@nestjs/event-emitter'; import { GqlModule } from '../graphql.module'; import { AuthModule } from './auth'; import { DocModule } from './doc'; +import { PaymentModule } from './payment'; import { SyncModule } from './sync'; import { UsersModule } from './users'; import { WorkspaceModule } from './workspaces'; @@ -17,22 +19,30 @@ switch (SERVER_FLAVOR) { break; case 'graphql': BusinessModules.push( + EventEmitterModule.forRoot({ + global: true, + }), GqlModule, WorkspaceModule, UsersModule, AuthModule, - DocModule.forRoot() + DocModule.forRoot(), + PaymentModule ); break; case 'allinone': default: BusinessModules.push( + EventEmitterModule.forRoot({ + global: true, + }), GqlModule, WorkspaceModule, UsersModule, AuthModule, SyncModule, - DocModule.forRoot() + DocModule.forRoot(), + PaymentModule ); break; } diff --git a/packages/backend/server/src/modules/payment/index.ts b/packages/backend/server/src/modules/payment/index.ts new file mode 100644 index 0000000000..1a51678848 --- /dev/null +++ b/packages/backend/server/src/modules/payment/index.ts @@ -0,0 +1,17 @@ +import { Module } from '@nestjs/common'; + +import { SubscriptionResolver, UserSubscriptionResolver } from './resolver'; +import { SubscriptionService } from './service'; +import { StripeProvider } from './stripe'; +import { StripeWebhook } from './webhook'; + +@Module({ + providers: [ + StripeProvider, + SubscriptionService, + SubscriptionResolver, + UserSubscriptionResolver, + ], + controllers: [StripeWebhook], +}) +export class PaymentModule {} diff --git a/packages/backend/server/src/modules/payment/resolver.ts b/packages/backend/server/src/modules/payment/resolver.ts new file mode 100644 index 0000000000..e5e1a77a3c --- /dev/null +++ b/packages/backend/server/src/modules/payment/resolver.ts @@ -0,0 +1,246 @@ +import { + BadGatewayException, + ForbiddenException, + InternalServerErrorException, +} from '@nestjs/common'; +import { + Args, + Field, + Int, + Mutation, + ObjectType, + Parent, + Query, + registerEnumType, + ResolveField, + Resolver, +} from '@nestjs/graphql'; +import type { User, UserInvoice, UserSubscription } from '@prisma/client'; + +import { Config } from '../../config'; +import { PrismaService } from '../../prisma'; +import { Auth, CurrentUser, Public } from '../auth'; +import { UserType } from '../users'; +import { + InvoiceStatus, + SubscriptionPlan, + SubscriptionRecurring, + SubscriptionService, + SubscriptionStatus, +} from './service'; + +registerEnumType(SubscriptionStatus, { name: 'SubscriptionStatus' }); +registerEnumType(SubscriptionRecurring, { name: 'SubscriptionRecurring' }); +registerEnumType(SubscriptionPlan, { name: 'SubscriptionPlan' }); +registerEnumType(InvoiceStatus, { name: 'InvoiceStatus' }); + +@ObjectType() +class SubscriptionPrice { + @Field(() => String) + type!: 'fixed'; + + @Field(() => SubscriptionPlan) + plan!: SubscriptionPlan; + + @Field() + currency!: string; + + @Field() + amount!: number; + + @Field() + yearlyAmount!: number; +} + +@ObjectType('UserSubscription') +class UserSubscriptionType implements Partial { + @Field({ name: 'id' }) + stripeSubscriptionId!: string; + + @Field(() => SubscriptionPlan) + plan!: SubscriptionPlan; + + @Field(() => SubscriptionRecurring) + recurring!: SubscriptionRecurring; + + @Field(() => SubscriptionStatus) + status!: SubscriptionStatus; + + @Field(() => Date) + start!: Date; + + @Field(() => Date) + end!: Date; + + @Field(() => Date, { nullable: true }) + trialStart?: Date | null; + + @Field(() => Date, { nullable: true }) + trialEnd?: Date | null; + + @Field(() => Date, { nullable: true }) + nextBillAt?: Date | null; + + @Field(() => Date, { nullable: true }) + canceledAt?: Date | null; + + @Field(() => Date) + createdAt!: Date; + + @Field(() => Date) + updatedAt!: Date; +} + +@ObjectType('UserInvoice') +class UserInvoiceType implements Partial { + @Field({ name: 'id' }) + stripeInvoiceId!: string; + + @Field(() => SubscriptionPlan) + plan!: SubscriptionPlan; + + @Field(() => SubscriptionRecurring) + recurring!: SubscriptionRecurring; + + @Field() + currency!: string; + + @Field() + amount!: number; + + @Field(() => InvoiceStatus) + status!: InvoiceStatus; + + @Field() + reason!: string; + + @Field(() => String, { nullable: true }) + lastPaymentError?: string | null; + + @Field(() => Date) + createdAt!: Date; + + @Field(() => Date) + updatedAt!: Date; +} + +@Auth() +@Resolver(() => UserSubscriptionType) +export class SubscriptionResolver { + constructor( + private readonly service: SubscriptionService, + private readonly config: Config + ) {} + + @Public() + @Query(() => [SubscriptionPrice]) + async prices(): Promise { + const prices = await this.service.listPrices(); + + const yearly = prices.data.find( + price => price.lookup_key === SubscriptionRecurring.Yearly + ); + const monthly = prices.data.find( + price => price.lookup_key === SubscriptionRecurring.Monthly + ); + + if (!yearly || !monthly) { + throw new BadGatewayException('The prices are not configured correctly'); + } + + return [ + { + type: 'fixed', + plan: SubscriptionPlan.Pro, + currency: monthly.currency, + amount: monthly.unit_amount ?? 0, + yearlyAmount: yearly.unit_amount ?? 0, + }, + ]; + } + + @Mutation(() => String, { + description: 'Create a subscription checkout link of stripe', + }) + async checkout( + @CurrentUser() user: User, + @Args({ name: 'recurring', type: () => SubscriptionRecurring }) + recurring: SubscriptionRecurring + ) { + const session = await this.service.createCheckoutSession({ + user, + recurring, + // TODO: replace with frontend url + redirectUrl: `${this.config.baseUrl}/api/stripe/success`, + }); + + if (!session.url) { + throw new InternalServerErrorException( + 'Failed to create checkout session' + ); + } + + return session.url; + } + + @Mutation(() => UserSubscriptionType) + async cancelSubscription(@CurrentUser() user: User) { + return this.service.cancelSubscription(user.id); + } + + @Mutation(() => UserSubscriptionType) + async resumeSubscription(@CurrentUser() user: User) { + return this.service.resumeCanceledSubscriptin(user.id); + } + + @Mutation(() => UserSubscriptionType) + async updateSubscriptionRecurring( + @CurrentUser() user: User, + @Args({ name: 'recurring', type: () => SubscriptionRecurring }) + recurring: SubscriptionRecurring + ) { + return this.service.updateSubscriptionRecurring(user.id, recurring); + } +} + +@Resolver(() => UserType) +export class UserSubscriptionResolver { + constructor(private readonly db: PrismaService) {} + + @ResolveField(() => UserSubscriptionType, { nullable: true }) + async subscription(@CurrentUser() me: User, @Parent() user: User) { + if (me.id !== user.id) { + throw new ForbiddenException(); + } + + return this.db.userSubscription.findUnique({ + where: { + userId: user.id, + }, + }); + } + + @ResolveField(() => [UserInvoiceType]) + async invoices( + @CurrentUser() me: User, + @Parent() user: User, + @Args('take', { type: () => Int, nullable: true, defaultValue: 8 }) + take: number, + @Args('skip', { type: () => Int, nullable: true }) skip?: number + ) { + if (me.id !== user.id) { + throw new ForbiddenException(); + } + + return this.db.userInvoice.findMany({ + where: { + userId: user.id, + }, + take, + skip, + orderBy: { + id: 'desc', + }, + }); + } +} diff --git a/packages/backend/server/src/modules/payment/service.ts b/packages/backend/server/src/modules/payment/service.ts new file mode 100644 index 0000000000..8f9771b1bd --- /dev/null +++ b/packages/backend/server/src/modules/payment/service.ts @@ -0,0 +1,576 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { OnEvent as RawOnEvent } from '@nestjs/event-emitter'; +import type { + Prisma, + User, + UserInvoice, + UserStripeCustomer, + UserSubscription, +} from '@prisma/client'; +import Stripe from 'stripe'; + +import { Config } from '../../config'; +import { PrismaService } from '../../prisma'; + +const OnEvent = ( + event: Stripe.Event.Type, + opts?: Parameters[1] +) => RawOnEvent(event, opts); + +// also used as lookup key for stripe prices +export enum SubscriptionRecurring { + Monthly = 'monthly', + Yearly = 'yearly', +} + +export enum SubscriptionPlan { + Free = 'free', + Pro = 'pro', + Team = 'team', + Enterprise = 'enterprise', +} + +// see https://stripe.com/docs/api/subscriptions/object#subscription_object-status +export enum SubscriptionStatus { + Active = 'active', + PastDue = 'past_due', + Unpaid = 'unpaid', + Canceled = 'canceled', + Incomplete = 'incomplete', + Paused = 'paused', + IncompleteExpired = 'incomplete_expired', + Trialing = 'trialing', +} + +export enum InvoiceStatus { + Draft = 'draft', + Open = 'open', + Void = 'void', + Paid = 'paid', + Uncollectible = 'uncollectible', +} + +@Injectable() +export class SubscriptionService { + private readonly paymentConfig: Config['payment']; + private readonly logger = new Logger(SubscriptionService.name); + + constructor( + config: Config, + private readonly stripe: Stripe, + private readonly db: PrismaService + ) { + this.paymentConfig = config.payment; + + if ( + !this.paymentConfig.stripe.keys.APIKey || + !this.paymentConfig.stripe.keys.webhookKey /* default empty string */ + ) { + this.logger.warn('Stripe API key not set, Stripe will be disabled'); + this.logger.warn('Set STRIPE_API_KEY to enable Stripe'); + } + } + + async listPrices() { + return this.stripe.prices.list({ + lookup_keys: Object.values(SubscriptionRecurring), + }); + } + + async createCheckoutSession({ + user, + recurring, + redirectUrl, + }: { + user: User; + recurring: SubscriptionRecurring; + redirectUrl: string; + }) { + const currentSubscription = await this.db.userSubscription.findUnique({ + where: { + userId: user.id, + }, + }); + + if (currentSubscription && currentSubscription.end < new Date()) { + throw new Error('User already has a subscription'); + } + + const prices = await this.stripe.prices.list({ + lookup_keys: [recurring], + }); + + if (!prices.data.length) { + throw new Error(`Unknown subscription recurring: ${recurring}`); + } + + const customer = await this.getOrCreateCustomer(user); + return await this.stripe.checkout.sessions.create({ + line_items: [ + { + price: prices.data[0].id, + quantity: 1, + }, + ], + allow_promotion_codes: true, + tax_id_collection: { + enabled: true, + }, + mode: 'subscription', + success_url: redirectUrl, + customer: customer.stripeCustomerId, + customer_update: { + address: 'auto', + name: 'auto', + }, + }); + } + + async cancelSubscription(userId: string): Promise { + const user = await this.db.user.findUnique({ + where: { + id: userId, + }, + include: { + subscription: true, + }, + }); + + if (!user?.subscription) { + throw new Error('User has no subscription'); + } + + if (user.subscription.canceledAt) { + throw new Error('User subscription has already been canceled '); + } + + // should release the schedule first + if (user.subscription.stripeScheduleId) { + await this.stripe.subscriptionSchedules.release( + user.subscription.stripeScheduleId + ); + } + + // let customer contact support if they want to cancel immediately + // see https://stripe.com/docs/billing/subscriptions/cancel + const subscription = await this.stripe.subscriptions.update( + user.subscription.stripeSubscriptionId, + { + cancel_at_period_end: true, + } + ); + + return await this.saveSubscription(user, subscription); + } + + async resumeCanceledSubscriptin(userId: string): Promise { + const user = await this.db.user.findUnique({ + where: { + id: userId, + }, + include: { + subscription: true, + }, + }); + + if (!user?.subscription) { + throw new Error('User has no subscription'); + } + + if (!user.subscription.canceledAt) { + throw new Error('User subscription is not canceled'); + } + + if (user.subscription.end < new Date()) { + throw new Error( + 'User subscription has already expired, please checkout again.' + ); + } + + const subscription = await this.stripe.subscriptions.update( + user.subscription.stripeSubscriptionId, + { + cancel_at_period_end: false, + } + ); + + return await this.saveSubscription(user, subscription); + } + + async updateSubscriptionRecurring( + userId: string, + recurring: string + ): Promise { + const user = await this.db.user.findUnique({ + where: { + id: userId, + }, + include: { + subscription: true, + }, + }); + + if (!user?.subscription) { + throw new Error('User has no subscription'); + } + + if (user.subscription.recurring === recurring) { + throw new Error('User has already subscribed to this plan'); + } + + const prices = await this.stripe.prices.list({ + lookup_keys: [recurring], + }); + + if (!prices.data.length) { + throw new Error(`Unknown subscription recurring: ${recurring}`); + } + + const newPrice = prices.data[0]; + + // a schedule existing + if (user.subscription.stripeScheduleId) { + const schedule = await this.stripe.subscriptionSchedules.retrieve( + user.subscription.stripeScheduleId + ); + + // a scheduled subscription's old price equals the change + if ( + schedule.phases[0] && + (schedule.phases[0].items[0].price as string) === newPrice.id + ) { + await this.stripe.subscriptionSchedules.release( + user.subscription.stripeScheduleId + ); + + return await this.db.userSubscription.update({ + where: { + id: user.subscription.id, + }, + data: { + recurring, + }, + }); + } else { + throw new Error( + 'Unexpected subscription scheduled, please contact the supporters' + ); + } + } else { + const schedule = await this.stripe.subscriptionSchedules.create({ + from_subscription: user.subscription.stripeSubscriptionId, + }); + + await this.stripe.subscriptionSchedules.update(schedule.id, { + phases: [ + { + items: [ + { + price: schedule.phases[0].items[0].price as string, + quantity: 1, + }, + ], + start_date: schedule.phases[0].start_date, + end_date: schedule.phases[0].end_date, + }, + { + items: [ + { + price: newPrice.id, + quantity: 1, + }, + ], + }, + ], + }); + + return await this.db.userSubscription.update({ + where: { + id: user.subscription.id, + }, + data: { + recurring, + stripeScheduleId: schedule.id, + }, + }); + } + } + + @OnEvent('customer.subscription.created') + @OnEvent('customer.subscription.updated') + async onSubscriptionChanges(subscription: Stripe.Subscription) { + const user = await this.retrieveUserFromCustomer( + subscription.customer as string + ); + + await this.saveSubscription(user, subscription); + } + + @OnEvent('customer.subscription.deleted') + async onSubscriptionDeleted(subscription: Stripe.Subscription) { + const user = await this.retrieveUserFromCustomer( + subscription.customer as string + ); + + await this.db.userSubscription.deleteMany({ + where: { + stripeSubscriptionId: subscription.id, + userId: user.id, + }, + }); + } + + @OnEvent('invoice.created') + async onInvoiceCreated(invoice: Stripe.Invoice) { + await this.saveInvoice(invoice); + } + + @OnEvent('invoice.paid') + async onInvoicePaid(invoice: Stripe.Invoice) { + await this.saveInvoice(invoice); + } + + @OnEvent('invoice.finalization_failed') + async onInvoiceFinalizeFailed(invoice: Stripe.Invoice) { + await this.saveInvoice(invoice); + } + + @OnEvent('invoice.payment_failed') + async onInvoicePaymentFailed(invoice: Stripe.Invoice) { + await this.saveInvoice(invoice); + } + + private async saveSubscription( + user: User, + subscription: Stripe.Subscription + ): Promise { + // get next bill date from upcoming invoice + // see https://stripe.com/docs/api/invoices/upcoming + let nextBillAt: Date | null = null; + if ( + (subscription.status === SubscriptionStatus.Active || + subscription.status === SubscriptionStatus.Trialing) && + !subscription.canceled_at + ) { + try { + const nextInvoice = await this.stripe.invoices.retrieveUpcoming({ + customer: subscription.customer as string, + subscription: subscription.id, + }); + + nextBillAt = new Date(nextInvoice.created * 1000); + } catch (e) { + // no upcoming invoice + // safe to ignore + } + } + + const price = subscription.items.data[0].price; + + const commonData = { + start: new Date(subscription.current_period_start * 1000), + end: new Date(subscription.current_period_end * 1000), + trialStart: subscription.trial_start + ? new Date(subscription.trial_start * 1000) + : null, + trialEnd: subscription.trial_end + ? new Date(subscription.trial_end * 1000) + : null, + nextBillAt, + canceledAt: subscription.canceled_at + ? new Date(subscription.canceled_at * 1000) + : null, + stripeSubscriptionId: subscription.id, + recurring: price.lookup_key ?? price.id, + // TODO: dynamic plans + plan: SubscriptionPlan.Pro, + status: subscription.status, + stripeScheduleId: subscription.schedule as string | null, + }; + + const currentSubscription = await this.db.userSubscription.findUnique({ + where: { + userId: user.id, + }, + }); + + if (currentSubscription) { + const update: Prisma.UserSubscriptionUpdateInput = { + ...commonData, + }; + + // a schedule exists, update the recurring to scheduled one + if (update.stripeScheduleId) { + delete update.recurring; + } + + return await this.db.userSubscription.update({ + where: { + id: currentSubscription.id, + }, + data: update, + }); + } else { + return await this.db.userSubscription.create({ + data: { + userId: user.id, + ...commonData, + }, + }); + } + } + + private async getOrCreateCustomer(user: User): Promise { + const customer = await this.db.userStripeCustomer.findUnique({ + where: { + userId: user.id, + }, + }); + + if (customer) { + return customer; + } + + const stripeCustomersList = await this.stripe.customers.list({ + email: user.email, + limit: 1, + }); + + let stripeCustomer: Stripe.Customer | undefined; + if (stripeCustomersList.data.length) { + stripeCustomer = stripeCustomersList.data[0]; + } else { + stripeCustomer = await this.stripe.customers.create({ + email: user.email, + }); + } + + return await this.db.userStripeCustomer.create({ + data: { + userId: user.id, + stripeCustomerId: stripeCustomer.id, + }, + }); + } + + private async retrieveUserFromCustomer(customerId: string) { + const customer = await this.db.userStripeCustomer.findUnique({ + where: { + stripeCustomerId: customerId, + }, + include: { + user: true, + }, + }); + + if (customer?.user) { + return customer.user; + } + + // customer may not saved is db, check it with stripe + const stripeCustomer = await this.stripe.customers.retrieve(customerId); + + if (stripeCustomer.deleted) { + throw new Error('Unexpected subscription created with deleted customer'); + } + + if (!stripeCustomer.email) { + throw new Error('Unexpected subscription created with no email customer'); + } + + const user = await this.db.user.findUnique({ + where: { + email: stripeCustomer.email, + }, + }); + + if (!user) { + throw new Error( + `Unexpected subscription created with unknown customer ${stripeCustomer.email}` + ); + } + + await this.db.userStripeCustomer.create({ + data: { + userId: user.id, + stripeCustomerId: stripeCustomer.id, + }, + }); + + return user; + } + + private async saveInvoice(stripeInvoice: Stripe.Invoice) { + if (!stripeInvoice.customer) { + throw new Error('Unexpected invoice with no customer'); + } + + const user = await this.retrieveUserFromCustomer( + stripeInvoice.customer as string + ); + + const invoice = await this.db.userInvoice.findUnique({ + where: { + stripeInvoiceId: stripeInvoice.id, + }, + }); + + const data: Partial = { + currency: stripeInvoice.currency, + amount: stripeInvoice.total, + status: stripeInvoice.status ?? InvoiceStatus.Void, + }; + + // handle payment error + if (stripeInvoice.attempt_count > 1) { + const paymentIntent = await this.stripe.paymentIntents.retrieve( + stripeInvoice.payment_intent as string + ); + + if (paymentIntent.last_payment_error) { + if (paymentIntent.last_payment_error.type === 'card_error') { + data.lastPaymentError = + paymentIntent.last_payment_error.message ?? 'Failed to pay'; + } else { + data.lastPaymentError = 'Internal Payment error'; + } + } + } else if (stripeInvoice.last_finalization_error) { + if (stripeInvoice.last_finalization_error.type === 'card_error') { + data.lastPaymentError = + stripeInvoice.last_finalization_error.message ?? + 'Failed to finalize invoice'; + } else { + data.lastPaymentError = 'Internal Payment error'; + } + } + + // update invoice + if (invoice) { + await this.db.userInvoice.update({ + where: { + stripeInvoiceId: stripeInvoice.id, + }, + data, + }); + } else { + // create invoice + const price = stripeInvoice.lines.data[0].price; + + if (!price || price.type !== 'recurring') { + throw new Error('Unexpected invoice with no recurring price'); + } + + await this.db.userInvoice.create({ + data: { + userId: user.id, + stripeInvoiceId: stripeInvoice.id, + plan: SubscriptionPlan.Pro, + recurring: price.lookup_key ?? price.id, + reason: stripeInvoice.billing_reason ?? 'contact support', + ...(data as any), + }, + }); + } + } +} diff --git a/packages/backend/server/src/modules/payment/stripe.ts b/packages/backend/server/src/modules/payment/stripe.ts new file mode 100644 index 0000000000..4538471121 --- /dev/null +++ b/packages/backend/server/src/modules/payment/stripe.ts @@ -0,0 +1,18 @@ +import { FactoryProvider } from '@nestjs/common'; +import { omit } from 'lodash-es'; +import Stripe from 'stripe'; + +import { Config } from '../../config'; + +export const StripeProvider: FactoryProvider = { + provide: Stripe, + useFactory: (config: Config) => { + const stripeConfig = config.payment.stripe; + + return new Stripe( + stripeConfig.keys.APIKey, + omit(config.payment.stripe, 'keys', 'prices') + ); + }, + inject: [Config], +}; diff --git a/packages/backend/server/src/modules/payment/webhook.ts b/packages/backend/server/src/modules/payment/webhook.ts new file mode 100644 index 0000000000..785ea97832 --- /dev/null +++ b/packages/backend/server/src/modules/payment/webhook.ts @@ -0,0 +1,75 @@ +import type { RawBodyRequest } from '@nestjs/common'; +import { + Controller, + Get, + Logger, + NotAcceptableException, + Post, + Req, +} from '@nestjs/common'; +import { EventEmitter2 } from '@nestjs/event-emitter'; +import type { User } from '@prisma/client'; +import type { Request } from 'express'; +import Stripe from 'stripe'; + +import { Config } from '../../config'; +import { PrismaService } from '../../prisma'; +import { Auth, CurrentUser } from '../auth'; + +@Controller('/api/stripe') +export class StripeWebhook { + private readonly config: Config['payment']; + private readonly logger = new Logger(StripeWebhook.name); + + constructor( + config: Config, + private readonly stripe: Stripe, + private readonly event: EventEmitter2, + private readonly db: PrismaService + ) { + this.config = config.payment; + } + + // just for test + @Auth() + @Get('/success') + async handleSuccess(@CurrentUser() user: User) { + return this.db.userSubscription.findUnique({ + where: { + userId: user.id, + }, + }); + } + + @Post('/webhook') + async handleWebhook(@Req() req: RawBodyRequest) { + // Check if webhook signing is configured. + if (!this.config.stripe.keys.webhookKey) { + this.logger.error( + 'Stripe Webhook key is not set, but a webhook was received.' + ); + throw new NotAcceptableException(); + } + + // Retrieve the event by verifying the signature using the raw body and secret. + const signature = req.headers['stripe-signature']; + try { + const event = this.stripe.webhooks.constructEvent( + req.rawBody ?? '', + signature ?? '', + this.config.stripe.keys.webhookKey + ); + + this.logger.debug( + `[${event.id}] Stripe Webhook {${event.type}} received.` + ); + + // handle duplicated events? + // see https://stripe.com/docs/webhooks#handle-duplicate-events + await this.event.emitAsync(event.type, event.data.object); + } catch (err) { + this.logger.error('Stripe Webhook error', err); + throw new NotAcceptableException(); + } + } +} diff --git a/packages/backend/server/src/modules/users/resolver.ts b/packages/backend/server/src/modules/users/resolver.ts index b735a4c1f1..cc3a529f44 100644 --- a/packages/backend/server/src/modules/users/resolver.ts +++ b/packages/backend/server/src/modules/users/resolver.ts @@ -21,7 +21,7 @@ import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs'; import { PrismaService } from '../../prisma/service'; import { CloudThrottlerGuard, Throttle } from '../../throttler'; import type { FileUpload } from '../../types'; -import { Auth, CurrentUser, Public } from '../auth/guard'; +import { Auth, CurrentUser, Public, Publicable } from '../auth/guard'; import { StorageService } from '../storage/storage.service'; import { NewFeaturesKind } from './types'; import { UsersService } from './users'; @@ -97,11 +97,17 @@ export class UserResolver { ttl: 60, }, }) + @Publicable() @Query(() => UserType, { name: 'currentUser', description: 'Get current user', + nullable: true, }) - async currentUser(@CurrentUser() user: UserType) { + async currentUser(@CurrentUser() user?: UserType) { + if (!user) { + return null; + } + const storedUser = await this.users.findUserById(user.id); if (!storedUser) { throw new BadRequestException(`User ${user.id} not found in db`); diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 63d8248ede..6be3ba621e 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -23,6 +23,8 @@ type UserType { """User password has been set""" hasPassword: Boolean token: TokenType! + subscription: UserSubscription + invoices(take: Int = 8, skip: Int): [UserInvoice!]! } """ @@ -55,6 +57,73 @@ type TokenType { sessionToken: String } +type SubscriptionPrice { + type: String! + plan: SubscriptionPlan! + currency: String! + amount: Int! + yearlyAmount: Int! +} + +enum SubscriptionPlan { + Free + Pro + Team + Enterprise +} + +type UserSubscription { + id: String! + plan: SubscriptionPlan! + recurring: SubscriptionRecurring! + status: SubscriptionStatus! + start: DateTime! + end: DateTime! + trialStart: DateTime + trialEnd: DateTime + nextBillAt: DateTime + canceledAt: DateTime + createdAt: DateTime! + updatedAt: DateTime! +} + +enum SubscriptionRecurring { + Monthly + Yearly +} + +enum SubscriptionStatus { + Active + PastDue + Unpaid + Canceled + Incomplete + Paused + IncompleteExpired + Trialing +} + +type UserInvoice { + id: String! + plan: SubscriptionPlan! + recurring: SubscriptionRecurring! + currency: String! + amount: Int! + status: InvoiceStatus! + reason: String! + lastPaymentError: String + createdAt: DateTime! + updatedAt: DateTime! +} + +enum InvoiceStatus { + Draft + Open + Void + Paid + Uncollectible +} + type InviteUserType { """User name""" name: String @@ -166,10 +235,11 @@ type Query { checkBlobSize(workspaceId: String!, size: Float!): WorkspaceBlobSizes! """Get current user""" - currentUser: UserType! + currentUser: UserType """Get user by email""" user(email: String!): UserType + prices: [SubscriptionPrice!]! } type Mutation { @@ -205,6 +275,12 @@ type Mutation { removeAvatar: RemoveAvatar! deleteAccount: DeleteAccount! addToNewFeaturesWaitingList(type: NewFeaturesKind!, email: String!): AddToNewFeaturesWaitingList! + + """Create a subscription checkout link of stripe""" + checkout(recurring: SubscriptionRecurring!): String! + cancelSubscription: UserSubscription! + resumeSubscription: UserSubscription! + updateSubscriptionRecurring(recurring: SubscriptionRecurring!): UserSubscription! } """The `Upload` scalar type represents a file upload.""" diff --git a/packages/backend/server/tests/user.e2e.ts b/packages/backend/server/tests/user.e2e.ts index da06bc4322..065925af21 100644 --- a/packages/backend/server/tests/user.e2e.ts +++ b/packages/backend/server/tests/user.e2e.ts @@ -67,6 +67,6 @@ test('should be able to delete user', async t => { `, }) .expect(200); - await t.throwsAsync(() => currentUser(app, user.token.token)); + t.is(await currentUser(app, user.token.token), null); t.pass(); }); diff --git a/packages/frontend/graphql/src/schema.ts b/packages/frontend/graphql/src/schema.ts index 1a47becf06..3a84e42fe4 100644 --- a/packages/frontend/graphql/src/schema.ts +++ b/packages/frontend/graphql/src/schema.ts @@ -32,6 +32,14 @@ export interface Scalars { Upload: { input: File; output: File }; } +export enum InvoiceStatus { + Draft = 'Draft', + Open = 'Open', + Paid = 'Paid', + Uncollectible = 'Uncollectible', + Void = 'Void', +} + export enum NewFeaturesKind { EarlyAccess = 'EarlyAccess', } @@ -44,6 +52,29 @@ export enum Permission { Write = 'Write', } +export enum SubscriptionPlan { + Enterprise = 'Enterprise', + Free = 'Free', + Pro = 'Pro', + Team = 'Team', +} + +export enum SubscriptionRecurring { + Monthly = 'Monthly', + Yearly = 'Yearly', +} + +export enum SubscriptionStatus { + Active = 'Active', + Canceled = 'Canceled', + Incomplete = 'Incomplete', + IncompleteExpired = 'IncompleteExpired', + PastDue = 'PastDue', + Paused = 'Paused', + Trialing = 'Trialing', + Unpaid = 'Unpaid', +} + export interface UpdateWorkspaceInput { id: Scalars['ID']['input']; /** is Public workspace */ @@ -173,7 +204,7 @@ export type GetCurrentUserQuery = { avatarUrl: string | null; createdAt: string | null; token: { __typename?: 'TokenType'; sessionToken: string | null }; - }; + } | null; }; export type GetInviteInfoQueryVariables = Exact<{ diff --git a/yarn.lock b/yarn.lock index 9445b44dcf..40dfb5747c 100644 --- a/yarn.lock +++ b/yarn.lock @@ -671,6 +671,7 @@ __metadata: "@nestjs/apollo": "npm:^12.0.9" "@nestjs/common": "npm:^10.2.7" "@nestjs/core": "npm:^10.2.7" + "@nestjs/event-emitter": "npm:^2.0.2" "@nestjs/graphql": "npm:^12.0.9" "@nestjs/platform-express": "npm:^10.2.7" "@nestjs/platform-socket.io": "npm:^10.2.7" @@ -735,6 +736,7 @@ __metadata: semver: "npm:^7.5.4" sinon: "npm:^16.1.0" socket.io: "npm:^4.7.2" + stripe: "npm:^13.6.0" supertest: "npm:^6.3.3" ts-node: "npm:^10.9.1" typescript: "npm:^5.2.2" @@ -7280,6 +7282,19 @@ __metadata: languageName: node linkType: hard +"@nestjs/event-emitter@npm:^2.0.2": + version: 2.0.2 + resolution: "@nestjs/event-emitter@npm:2.0.2" + dependencies: + eventemitter2: "npm:6.4.9" + peerDependencies: + "@nestjs/common": ^8.0.0 || ^9.0.0 || ^10.0.0 + "@nestjs/core": ^8.0.0 || ^9.0.0 || ^10.0.0 + reflect-metadata: ^0.1.12 + checksum: 9c7d2645b14bef5a9d26a8fbafb5963e18c9c15e267980c55abd913c8af9215ae363b8c0fc78711c22126e0a973f80aec8b8e962a64e699f523128d11c033894 + languageName: node + linkType: hard + "@nestjs/graphql@npm:^12.0.9": version: 12.0.9 resolution: "@nestjs/graphql@npm:12.0.9" @@ -13432,6 +13447,13 @@ __metadata: languageName: node linkType: hard +"@types/node@npm:>=8.1.0": + version: 20.6.2 + resolution: "@types/node@npm:20.6.2" + checksum: 4b150698cf90c211d4f2f021618f06c33a337d74e9a0ce10ec2e7123f02aacc231eff62118101f56de75f7be309c2da6eb0edb8388d501d4195c50bb919c7a05 + languageName: node + linkType: hard + "@types/node@npm:^16.0.0": version: 16.18.58 resolution: "@types/node@npm:16.18.58" @@ -20137,6 +20159,13 @@ __metadata: languageName: node linkType: hard +"eventemitter2@npm:6.4.9": + version: 6.4.9 + resolution: "eventemitter2@npm:6.4.9" + checksum: b829b1c6b11e15926b635092b5ad62b4463d1c928859831dcae606e988cf41893059e3541f5a8209d21d2f15314422ddd4d84d20830b4bf44978608d15b06b08 + languageName: node + linkType: hard + "eventemitter3@npm:^3.1.0": version: 3.1.2 resolution: "eventemitter3@npm:3.1.2" @@ -31941,6 +31970,16 @@ __metadata: languageName: node linkType: hard +"stripe@npm:^13.6.0": + version: 13.6.0 + resolution: "stripe@npm:13.6.0" + dependencies: + "@types/node": "npm:>=8.1.0" + qs: "npm:^6.11.0" + checksum: 3fae1ed3dc845166c36fb28e4297ec770bb1f1b35e88b0166c465a31d41216203341b1055bf63b653fa3c66cd5d2eb72fdfaec9b58a7d467d207645a12b2cde0 + languageName: node + linkType: hard + "strnum@npm:^1.0.5": version: 1.0.5 resolution: "strnum@npm:1.0.5"