feat(core): payment backend

This commit is contained in:
forehalo
2023-10-19 10:06:34 +08:00
parent 493b815b7b
commit df054ac7f6
18 changed files with 1260 additions and 8 deletions

View File

@@ -3,4 +3,6 @@ NEXTAUTH_URL="http://localhost:8080"
OAUTH_EMAIL_SENDER="noreply@toeverything.info"
OAUTH_EMAIL_LOGIN=""
OAUTH_EMAIL_PASSWORD=""
ENABLE_LOCAL_EMAIL="true"
ENABLE_LOCAL_EMAIL="true"
STRIPE_API_KEY=
STRIPE_WEBHOOK_KEY=

View File

@@ -0,0 +1,68 @@
-- CreateTable
CREATE TABLE "user_stripe_customers" (
"user_id" VARCHAR NOT NULL,
"stripe_customer_id" VARCHAR NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "user_stripe_customers_pkey" PRIMARY KEY ("user_id")
);
-- CreateTable
CREATE TABLE "user_subscriptions" (
"id" SERIAL NOT NULL,
"user_id" VARCHAR(36) NOT NULL,
"plan" VARCHAR(20) NOT NULL,
"recurring" VARCHAR(20) NOT NULL,
"stripe_subscription_id" TEXT NOT NULL,
"status" VARCHAR(20) NOT NULL,
"start" TIMESTAMPTZ(6) NOT NULL,
"end" TIMESTAMPTZ(6) NOT NULL,
"next_bill_at" TIMESTAMPTZ(6),
"canceled_at" TIMESTAMPTZ(6),
"trial_start" TIMESTAMPTZ(6),
"trial_end" TIMESTAMPTZ(6),
"stripe_schedule_id" VARCHAR,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(6) NOT NULL,
CONSTRAINT "user_subscriptions_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "user_invoices" (
"id" SERIAL NOT NULL,
"user_id" VARCHAR(36) NOT NULL,
"stripe_invoice_id" TEXT NOT NULL,
"currency" VARCHAR(3) NOT NULL,
"amount" INTEGER NOT NULL,
"status" VARCHAR(20) NOT NULL,
"plan" VARCHAR(20) NOT NULL,
"recurring" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(6) NOT NULL,
"reason" VARCHAR NOT NULL,
"last_payment_error" TEXT,
CONSTRAINT "user_invoices_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "user_stripe_customers_stripe_customer_id_key" ON "user_stripe_customers"("stripe_customer_id");
-- CreateIndex
CREATE UNIQUE INDEX "user_subscriptions_user_id_key" ON "user_subscriptions"("user_id");
-- CreateIndex
CREATE UNIQUE INDEX "user_subscriptions_stripe_subscription_id_key" ON "user_subscriptions"("stripe_subscription_id");
-- CreateIndex
CREATE UNIQUE INDEX "user_invoices_stripe_invoice_id_key" ON "user_invoices"("stripe_invoice_id");
-- AddForeignKey
ALTER TABLE "user_stripe_customers" ADD CONSTRAINT "user_stripe_customers_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "user_subscriptions" ADD CONSTRAINT "user_subscriptions_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "user_invoices" ADD CONSTRAINT "user_invoices_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -25,6 +25,7 @@
"@nestjs/apollo": "^12.0.9",
"@nestjs/common": "^10.2.7",
"@nestjs/core": "^10.2.7",
"@nestjs/event-emitter": "^2.0.2",
"@nestjs/graphql": "^12.0.9",
"@nestjs/platform-express": "^10.2.7",
"@nestjs/platform-socket.io": "^10.2.7",
@@ -71,6 +72,7 @@
"rxjs": "^7.8.1",
"semver": "^7.5.4",
"socket.io": "^4.7.2",
"stripe": "^13.6.0",
"ws": "^8.14.2",
"yjs": "^13.6.8"
},

View File

@@ -49,6 +49,9 @@ model User {
/// Not available if user signed up through OAuth providers
password String? @db.VarChar
features UserFeatureGates[]
customer UserStripeCustomer?
subscription UserSubscription?
invoices UserInvoice[]
@@map("users")
}
@@ -164,3 +167,65 @@ model NewFeaturesWaitingList {
@@map("new_features_waiting_list")
}
model UserStripeCustomer {
userId String @id @map("user_id") @db.VarChar
stripeCustomerId String @unique @map("stripe_customer_id") @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@map("user_stripe_customers")
}
model UserSubscription {
id Int @id @default(autoincrement()) @db.Integer
userId String @unique @map("user_id") @db.VarChar(36)
plan String @db.VarChar(20)
// yearly/monthly
recurring String @db.VarChar(20)
// subscription.id
stripeSubscriptionId String @unique @map("stripe_subscription_id")
// subscription.status, active/past_due/canceled/unpaid...
status String @db.VarChar(20)
// subscription.current_period_start
start DateTime @map("start") @db.Timestamptz(6)
// subscription.current_period_end
end DateTime @map("end") @db.Timestamptz(6)
// subscription.billing_cycle_anchor
nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(6)
// subscription.canceled_at
canceledAt DateTime? @map("canceled_at") @db.Timestamptz(6)
// subscription.trial_start
trialStart DateTime? @map("trial_start") @db.Timestamptz(6)
// subscription.trial_end
trialEnd DateTime? @map("trial_end") @db.Timestamptz(6)
stripeScheduleId String? @map("stripe_schedule_id") @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@map("user_subscriptions")
}
model UserInvoice {
id Int @id @default(autoincrement()) @db.Integer
userId String @map("user_id") @db.VarChar(36)
stripeInvoiceId String @unique @map("stripe_invoice_id")
currency String @db.VarChar(3)
// CNY 12.50 stored as 1250
amount Int @db.Integer
status String @db.VarChar(20)
plan String @db.VarChar(20)
recurring String @db.VarChar(20)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
// billing reason
reason String @db.VarChar
lastPaymentError String? @map("last_payment_error") @db.Text
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@map("user_invoices")
}

View File

@@ -363,4 +363,13 @@ export interface AFFiNEConfig {
experimentalMergeWithJwstCodec: boolean;
};
};
payment: {
stripe: {
keys: {
APIKey: string;
webhookKey: string;
};
} & import('stripe').Stripe.StripeConfig;
};
}

View File

@@ -89,6 +89,8 @@ export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => {
'boolean',
],
ENABLE_LOCAL_EMAIL: ['auth.localEmail', 'boolean'],
STRIPE_API_KEY: 'payment.stripe.keys.APIKey',
STRIPE_WEBHOOK_KEY: 'payment.stripe.keys.webhookKey',
} satisfies AFFiNEConfig['ENV_MAP'],
affineEnv: 'dev',
get affine() {
@@ -207,6 +209,15 @@ export const getDefaultAFFiNEConfig: () => AFFiNEConfig = () => {
experimentalMergeWithJwstCodec: false,
},
},
payment: {
stripe: {
keys: {
APIKey: '',
webhookKey: '',
},
apiVersion: '2023-08-16',
},
},
} satisfies AFFiNEConfig;
applyEnvToConfig(defaultConfig);

View File

@@ -59,6 +59,7 @@ if (NODE_ENV === 'production') {
const app = await NestFactory.create<NestExpressApplication>(AppModule, {
cors: true,
rawBody: true,
bodyParser: true,
logger:
NODE_ENV !== 'production' || AFFINE_ENV !== 'production'

View File

@@ -1,8 +1,10 @@
import { DynamicModule, Type } from '@nestjs/common';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { GqlModule } from '../graphql.module';
import { AuthModule } from './auth';
import { DocModule } from './doc';
import { PaymentModule } from './payment';
import { SyncModule } from './sync';
import { UsersModule } from './users';
import { WorkspaceModule } from './workspaces';
@@ -17,22 +19,30 @@ switch (SERVER_FLAVOR) {
break;
case 'graphql':
BusinessModules.push(
EventEmitterModule.forRoot({
global: true,
}),
GqlModule,
WorkspaceModule,
UsersModule,
AuthModule,
DocModule.forRoot()
DocModule.forRoot(),
PaymentModule
);
break;
case 'allinone':
default:
BusinessModules.push(
EventEmitterModule.forRoot({
global: true,
}),
GqlModule,
WorkspaceModule,
UsersModule,
AuthModule,
SyncModule,
DocModule.forRoot()
DocModule.forRoot(),
PaymentModule
);
break;
}

View File

@@ -0,0 +1,17 @@
import { Module } from '@nestjs/common';
import { SubscriptionResolver, UserSubscriptionResolver } from './resolver';
import { SubscriptionService } from './service';
import { StripeProvider } from './stripe';
import { StripeWebhook } from './webhook';
@Module({
providers: [
StripeProvider,
SubscriptionService,
SubscriptionResolver,
UserSubscriptionResolver,
],
controllers: [StripeWebhook],
})
export class PaymentModule {}

View File

@@ -0,0 +1,246 @@
import {
BadGatewayException,
ForbiddenException,
InternalServerErrorException,
} from '@nestjs/common';
import {
Args,
Field,
Int,
Mutation,
ObjectType,
Parent,
Query,
registerEnumType,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import type { User, UserInvoice, UserSubscription } from '@prisma/client';
import { Config } from '../../config';
import { PrismaService } from '../../prisma';
import { Auth, CurrentUser, Public } from '../auth';
import { UserType } from '../users';
import {
InvoiceStatus,
SubscriptionPlan,
SubscriptionRecurring,
SubscriptionService,
SubscriptionStatus,
} from './service';
registerEnumType(SubscriptionStatus, { name: 'SubscriptionStatus' });
registerEnumType(SubscriptionRecurring, { name: 'SubscriptionRecurring' });
registerEnumType(SubscriptionPlan, { name: 'SubscriptionPlan' });
registerEnumType(InvoiceStatus, { name: 'InvoiceStatus' });
@ObjectType()
class SubscriptionPrice {
@Field(() => String)
type!: 'fixed';
@Field(() => SubscriptionPlan)
plan!: SubscriptionPlan;
@Field()
currency!: string;
@Field()
amount!: number;
@Field()
yearlyAmount!: number;
}
@ObjectType('UserSubscription')
class UserSubscriptionType implements Partial<UserSubscription> {
@Field({ name: 'id' })
stripeSubscriptionId!: string;
@Field(() => SubscriptionPlan)
plan!: SubscriptionPlan;
@Field(() => SubscriptionRecurring)
recurring!: SubscriptionRecurring;
@Field(() => SubscriptionStatus)
status!: SubscriptionStatus;
@Field(() => Date)
start!: Date;
@Field(() => Date)
end!: Date;
@Field(() => Date, { nullable: true })
trialStart?: Date | null;
@Field(() => Date, { nullable: true })
trialEnd?: Date | null;
@Field(() => Date, { nullable: true })
nextBillAt?: Date | null;
@Field(() => Date, { nullable: true })
canceledAt?: Date | null;
@Field(() => Date)
createdAt!: Date;
@Field(() => Date)
updatedAt!: Date;
}
@ObjectType('UserInvoice')
class UserInvoiceType implements Partial<UserInvoice> {
@Field({ name: 'id' })
stripeInvoiceId!: string;
@Field(() => SubscriptionPlan)
plan!: SubscriptionPlan;
@Field(() => SubscriptionRecurring)
recurring!: SubscriptionRecurring;
@Field()
currency!: string;
@Field()
amount!: number;
@Field(() => InvoiceStatus)
status!: InvoiceStatus;
@Field()
reason!: string;
@Field(() => String, { nullable: true })
lastPaymentError?: string | null;
@Field(() => Date)
createdAt!: Date;
@Field(() => Date)
updatedAt!: Date;
}
@Auth()
@Resolver(() => UserSubscriptionType)
export class SubscriptionResolver {
constructor(
private readonly service: SubscriptionService,
private readonly config: Config
) {}
@Public()
@Query(() => [SubscriptionPrice])
async prices(): Promise<SubscriptionPrice[]> {
const prices = await this.service.listPrices();
const yearly = prices.data.find(
price => price.lookup_key === SubscriptionRecurring.Yearly
);
const monthly = prices.data.find(
price => price.lookup_key === SubscriptionRecurring.Monthly
);
if (!yearly || !monthly) {
throw new BadGatewayException('The prices are not configured correctly');
}
return [
{
type: 'fixed',
plan: SubscriptionPlan.Pro,
currency: monthly.currency,
amount: monthly.unit_amount ?? 0,
yearlyAmount: yearly.unit_amount ?? 0,
},
];
}
@Mutation(() => String, {
description: 'Create a subscription checkout link of stripe',
})
async checkout(
@CurrentUser() user: User,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring
) {
const session = await this.service.createCheckoutSession({
user,
recurring,
// TODO: replace with frontend url
redirectUrl: `${this.config.baseUrl}/api/stripe/success`,
});
if (!session.url) {
throw new InternalServerErrorException(
'Failed to create checkout session'
);
}
return session.url;
}
@Mutation(() => UserSubscriptionType)
async cancelSubscription(@CurrentUser() user: User) {
return this.service.cancelSubscription(user.id);
}
@Mutation(() => UserSubscriptionType)
async resumeSubscription(@CurrentUser() user: User) {
return this.service.resumeCanceledSubscriptin(user.id);
}
@Mutation(() => UserSubscriptionType)
async updateSubscriptionRecurring(
@CurrentUser() user: User,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring
) {
return this.service.updateSubscriptionRecurring(user.id, recurring);
}
}
@Resolver(() => UserType)
export class UserSubscriptionResolver {
constructor(private readonly db: PrismaService) {}
@ResolveField(() => UserSubscriptionType, { nullable: true })
async subscription(@CurrentUser() me: User, @Parent() user: User) {
if (me.id !== user.id) {
throw new ForbiddenException();
}
return this.db.userSubscription.findUnique({
where: {
userId: user.id,
},
});
}
@ResolveField(() => [UserInvoiceType])
async invoices(
@CurrentUser() me: User,
@Parent() user: User,
@Args('take', { type: () => Int, nullable: true, defaultValue: 8 })
take: number,
@Args('skip', { type: () => Int, nullable: true }) skip?: number
) {
if (me.id !== user.id) {
throw new ForbiddenException();
}
return this.db.userInvoice.findMany({
where: {
userId: user.id,
},
take,
skip,
orderBy: {
id: 'desc',
},
});
}
}

View File

@@ -0,0 +1,576 @@
import { Injectable, Logger } from '@nestjs/common';
import { OnEvent as RawOnEvent } from '@nestjs/event-emitter';
import type {
Prisma,
User,
UserInvoice,
UserStripeCustomer,
UserSubscription,
} from '@prisma/client';
import Stripe from 'stripe';
import { Config } from '../../config';
import { PrismaService } from '../../prisma';
const OnEvent = (
event: Stripe.Event.Type,
opts?: Parameters<typeof RawOnEvent>[1]
) => RawOnEvent(event, opts);
// also used as lookup key for stripe prices
export enum SubscriptionRecurring {
Monthly = 'monthly',
Yearly = 'yearly',
}
export enum SubscriptionPlan {
Free = 'free',
Pro = 'pro',
Team = 'team',
Enterprise = 'enterprise',
}
// see https://stripe.com/docs/api/subscriptions/object#subscription_object-status
export enum SubscriptionStatus {
Active = 'active',
PastDue = 'past_due',
Unpaid = 'unpaid',
Canceled = 'canceled',
Incomplete = 'incomplete',
Paused = 'paused',
IncompleteExpired = 'incomplete_expired',
Trialing = 'trialing',
}
export enum InvoiceStatus {
Draft = 'draft',
Open = 'open',
Void = 'void',
Paid = 'paid',
Uncollectible = 'uncollectible',
}
@Injectable()
export class SubscriptionService {
private readonly paymentConfig: Config['payment'];
private readonly logger = new Logger(SubscriptionService.name);
constructor(
config: Config,
private readonly stripe: Stripe,
private readonly db: PrismaService
) {
this.paymentConfig = config.payment;
if (
!this.paymentConfig.stripe.keys.APIKey ||
!this.paymentConfig.stripe.keys.webhookKey /* default empty string */
) {
this.logger.warn('Stripe API key not set, Stripe will be disabled');
this.logger.warn('Set STRIPE_API_KEY to enable Stripe');
}
}
async listPrices() {
return this.stripe.prices.list({
lookup_keys: Object.values(SubscriptionRecurring),
});
}
async createCheckoutSession({
user,
recurring,
redirectUrl,
}: {
user: User;
recurring: SubscriptionRecurring;
redirectUrl: string;
}) {
const currentSubscription = await this.db.userSubscription.findUnique({
where: {
userId: user.id,
},
});
if (currentSubscription && currentSubscription.end < new Date()) {
throw new Error('User already has a subscription');
}
const prices = await this.stripe.prices.list({
lookup_keys: [recurring],
});
if (!prices.data.length) {
throw new Error(`Unknown subscription recurring: ${recurring}`);
}
const customer = await this.getOrCreateCustomer(user);
return await this.stripe.checkout.sessions.create({
line_items: [
{
price: prices.data[0].id,
quantity: 1,
},
],
allow_promotion_codes: true,
tax_id_collection: {
enabled: true,
},
mode: 'subscription',
success_url: redirectUrl,
customer: customer.stripeCustomerId,
customer_update: {
address: 'auto',
name: 'auto',
},
});
}
async cancelSubscription(userId: string): Promise<UserSubscription> {
const user = await this.db.user.findUnique({
where: {
id: userId,
},
include: {
subscription: true,
},
});
if (!user?.subscription) {
throw new Error('User has no subscription');
}
if (user.subscription.canceledAt) {
throw new Error('User subscription has already been canceled ');
}
// should release the schedule first
if (user.subscription.stripeScheduleId) {
await this.stripe.subscriptionSchedules.release(
user.subscription.stripeScheduleId
);
}
// let customer contact support if they want to cancel immediately
// see https://stripe.com/docs/billing/subscriptions/cancel
const subscription = await this.stripe.subscriptions.update(
user.subscription.stripeSubscriptionId,
{
cancel_at_period_end: true,
}
);
return await this.saveSubscription(user, subscription);
}
async resumeCanceledSubscriptin(userId: string): Promise<UserSubscription> {
const user = await this.db.user.findUnique({
where: {
id: userId,
},
include: {
subscription: true,
},
});
if (!user?.subscription) {
throw new Error('User has no subscription');
}
if (!user.subscription.canceledAt) {
throw new Error('User subscription is not canceled');
}
if (user.subscription.end < new Date()) {
throw new Error(
'User subscription has already expired, please checkout again.'
);
}
const subscription = await this.stripe.subscriptions.update(
user.subscription.stripeSubscriptionId,
{
cancel_at_period_end: false,
}
);
return await this.saveSubscription(user, subscription);
}
async updateSubscriptionRecurring(
userId: string,
recurring: string
): Promise<UserSubscription> {
const user = await this.db.user.findUnique({
where: {
id: userId,
},
include: {
subscription: true,
},
});
if (!user?.subscription) {
throw new Error('User has no subscription');
}
if (user.subscription.recurring === recurring) {
throw new Error('User has already subscribed to this plan');
}
const prices = await this.stripe.prices.list({
lookup_keys: [recurring],
});
if (!prices.data.length) {
throw new Error(`Unknown subscription recurring: ${recurring}`);
}
const newPrice = prices.data[0];
// a schedule existing
if (user.subscription.stripeScheduleId) {
const schedule = await this.stripe.subscriptionSchedules.retrieve(
user.subscription.stripeScheduleId
);
// a scheduled subscription's old price equals the change
if (
schedule.phases[0] &&
(schedule.phases[0].items[0].price as string) === newPrice.id
) {
await this.stripe.subscriptionSchedules.release(
user.subscription.stripeScheduleId
);
return await this.db.userSubscription.update({
where: {
id: user.subscription.id,
},
data: {
recurring,
},
});
} else {
throw new Error(
'Unexpected subscription scheduled, please contact the supporters'
);
}
} else {
const schedule = await this.stripe.subscriptionSchedules.create({
from_subscription: user.subscription.stripeSubscriptionId,
});
await this.stripe.subscriptionSchedules.update(schedule.id, {
phases: [
{
items: [
{
price: schedule.phases[0].items[0].price as string,
quantity: 1,
},
],
start_date: schedule.phases[0].start_date,
end_date: schedule.phases[0].end_date,
},
{
items: [
{
price: newPrice.id,
quantity: 1,
},
],
},
],
});
return await this.db.userSubscription.update({
where: {
id: user.subscription.id,
},
data: {
recurring,
stripeScheduleId: schedule.id,
},
});
}
}
@OnEvent('customer.subscription.created')
@OnEvent('customer.subscription.updated')
async onSubscriptionChanges(subscription: Stripe.Subscription) {
const user = await this.retrieveUserFromCustomer(
subscription.customer as string
);
await this.saveSubscription(user, subscription);
}
@OnEvent('customer.subscription.deleted')
async onSubscriptionDeleted(subscription: Stripe.Subscription) {
const user = await this.retrieveUserFromCustomer(
subscription.customer as string
);
await this.db.userSubscription.deleteMany({
where: {
stripeSubscriptionId: subscription.id,
userId: user.id,
},
});
}
@OnEvent('invoice.created')
async onInvoiceCreated(invoice: Stripe.Invoice) {
await this.saveInvoice(invoice);
}
@OnEvent('invoice.paid')
async onInvoicePaid(invoice: Stripe.Invoice) {
await this.saveInvoice(invoice);
}
@OnEvent('invoice.finalization_failed')
async onInvoiceFinalizeFailed(invoice: Stripe.Invoice) {
await this.saveInvoice(invoice);
}
@OnEvent('invoice.payment_failed')
async onInvoicePaymentFailed(invoice: Stripe.Invoice) {
await this.saveInvoice(invoice);
}
private async saveSubscription(
user: User,
subscription: Stripe.Subscription
): Promise<UserSubscription> {
// get next bill date from upcoming invoice
// see https://stripe.com/docs/api/invoices/upcoming
let nextBillAt: Date | null = null;
if (
(subscription.status === SubscriptionStatus.Active ||
subscription.status === SubscriptionStatus.Trialing) &&
!subscription.canceled_at
) {
try {
const nextInvoice = await this.stripe.invoices.retrieveUpcoming({
customer: subscription.customer as string,
subscription: subscription.id,
});
nextBillAt = new Date(nextInvoice.created * 1000);
} catch (e) {
// no upcoming invoice
// safe to ignore
}
}
const price = subscription.items.data[0].price;
const commonData = {
start: new Date(subscription.current_period_start * 1000),
end: new Date(subscription.current_period_end * 1000),
trialStart: subscription.trial_start
? new Date(subscription.trial_start * 1000)
: null,
trialEnd: subscription.trial_end
? new Date(subscription.trial_end * 1000)
: null,
nextBillAt,
canceledAt: subscription.canceled_at
? new Date(subscription.canceled_at * 1000)
: null,
stripeSubscriptionId: subscription.id,
recurring: price.lookup_key ?? price.id,
// TODO: dynamic plans
plan: SubscriptionPlan.Pro,
status: subscription.status,
stripeScheduleId: subscription.schedule as string | null,
};
const currentSubscription = await this.db.userSubscription.findUnique({
where: {
userId: user.id,
},
});
if (currentSubscription) {
const update: Prisma.UserSubscriptionUpdateInput = {
...commonData,
};
// a schedule exists, update the recurring to scheduled one
if (update.stripeScheduleId) {
delete update.recurring;
}
return await this.db.userSubscription.update({
where: {
id: currentSubscription.id,
},
data: update,
});
} else {
return await this.db.userSubscription.create({
data: {
userId: user.id,
...commonData,
},
});
}
}
private async getOrCreateCustomer(user: User): Promise<UserStripeCustomer> {
const customer = await this.db.userStripeCustomer.findUnique({
where: {
userId: user.id,
},
});
if (customer) {
return customer;
}
const stripeCustomersList = await this.stripe.customers.list({
email: user.email,
limit: 1,
});
let stripeCustomer: Stripe.Customer | undefined;
if (stripeCustomersList.data.length) {
stripeCustomer = stripeCustomersList.data[0];
} else {
stripeCustomer = await this.stripe.customers.create({
email: user.email,
});
}
return await this.db.userStripeCustomer.create({
data: {
userId: user.id,
stripeCustomerId: stripeCustomer.id,
},
});
}
private async retrieveUserFromCustomer(customerId: string) {
const customer = await this.db.userStripeCustomer.findUnique({
where: {
stripeCustomerId: customerId,
},
include: {
user: true,
},
});
if (customer?.user) {
return customer.user;
}
// customer may not saved is db, check it with stripe
const stripeCustomer = await this.stripe.customers.retrieve(customerId);
if (stripeCustomer.deleted) {
throw new Error('Unexpected subscription created with deleted customer');
}
if (!stripeCustomer.email) {
throw new Error('Unexpected subscription created with no email customer');
}
const user = await this.db.user.findUnique({
where: {
email: stripeCustomer.email,
},
});
if (!user) {
throw new Error(
`Unexpected subscription created with unknown customer ${stripeCustomer.email}`
);
}
await this.db.userStripeCustomer.create({
data: {
userId: user.id,
stripeCustomerId: stripeCustomer.id,
},
});
return user;
}
private async saveInvoice(stripeInvoice: Stripe.Invoice) {
if (!stripeInvoice.customer) {
throw new Error('Unexpected invoice with no customer');
}
const user = await this.retrieveUserFromCustomer(
stripeInvoice.customer as string
);
const invoice = await this.db.userInvoice.findUnique({
where: {
stripeInvoiceId: stripeInvoice.id,
},
});
const data: Partial<UserInvoice> = {
currency: stripeInvoice.currency,
amount: stripeInvoice.total,
status: stripeInvoice.status ?? InvoiceStatus.Void,
};
// handle payment error
if (stripeInvoice.attempt_count > 1) {
const paymentIntent = await this.stripe.paymentIntents.retrieve(
stripeInvoice.payment_intent as string
);
if (paymentIntent.last_payment_error) {
if (paymentIntent.last_payment_error.type === 'card_error') {
data.lastPaymentError =
paymentIntent.last_payment_error.message ?? 'Failed to pay';
} else {
data.lastPaymentError = 'Internal Payment error';
}
}
} else if (stripeInvoice.last_finalization_error) {
if (stripeInvoice.last_finalization_error.type === 'card_error') {
data.lastPaymentError =
stripeInvoice.last_finalization_error.message ??
'Failed to finalize invoice';
} else {
data.lastPaymentError = 'Internal Payment error';
}
}
// update invoice
if (invoice) {
await this.db.userInvoice.update({
where: {
stripeInvoiceId: stripeInvoice.id,
},
data,
});
} else {
// create invoice
const price = stripeInvoice.lines.data[0].price;
if (!price || price.type !== 'recurring') {
throw new Error('Unexpected invoice with no recurring price');
}
await this.db.userInvoice.create({
data: {
userId: user.id,
stripeInvoiceId: stripeInvoice.id,
plan: SubscriptionPlan.Pro,
recurring: price.lookup_key ?? price.id,
reason: stripeInvoice.billing_reason ?? 'contact support',
...(data as any),
},
});
}
}
}

View File

@@ -0,0 +1,18 @@
import { FactoryProvider } from '@nestjs/common';
import { omit } from 'lodash-es';
import Stripe from 'stripe';
import { Config } from '../../config';
export const StripeProvider: FactoryProvider = {
provide: Stripe,
useFactory: (config: Config) => {
const stripeConfig = config.payment.stripe;
return new Stripe(
stripeConfig.keys.APIKey,
omit(config.payment.stripe, 'keys', 'prices')
);
},
inject: [Config],
};

View File

@@ -0,0 +1,75 @@
import type { RawBodyRequest } from '@nestjs/common';
import {
Controller,
Get,
Logger,
NotAcceptableException,
Post,
Req,
} from '@nestjs/common';
import { EventEmitter2 } from '@nestjs/event-emitter';
import type { User } from '@prisma/client';
import type { Request } from 'express';
import Stripe from 'stripe';
import { Config } from '../../config';
import { PrismaService } from '../../prisma';
import { Auth, CurrentUser } from '../auth';
@Controller('/api/stripe')
export class StripeWebhook {
private readonly config: Config['payment'];
private readonly logger = new Logger(StripeWebhook.name);
constructor(
config: Config,
private readonly stripe: Stripe,
private readonly event: EventEmitter2,
private readonly db: PrismaService
) {
this.config = config.payment;
}
// just for test
@Auth()
@Get('/success')
async handleSuccess(@CurrentUser() user: User) {
return this.db.userSubscription.findUnique({
where: {
userId: user.id,
},
});
}
@Post('/webhook')
async handleWebhook(@Req() req: RawBodyRequest<Request>) {
// Check if webhook signing is configured.
if (!this.config.stripe.keys.webhookKey) {
this.logger.error(
'Stripe Webhook key is not set, but a webhook was received.'
);
throw new NotAcceptableException();
}
// Retrieve the event by verifying the signature using the raw body and secret.
const signature = req.headers['stripe-signature'];
try {
const event = this.stripe.webhooks.constructEvent(
req.rawBody ?? '',
signature ?? '',
this.config.stripe.keys.webhookKey
);
this.logger.debug(
`[${event.id}] Stripe Webhook {${event.type}} received.`
);
// handle duplicated events?
// see https://stripe.com/docs/webhooks#handle-duplicate-events
await this.event.emitAsync(event.type, event.data.object);
} catch (err) {
this.logger.error('Stripe Webhook error', err);
throw new NotAcceptableException();
}
}
}

View File

@@ -21,7 +21,7 @@ import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import { PrismaService } from '../../prisma/service';
import { CloudThrottlerGuard, Throttle } from '../../throttler';
import type { FileUpload } from '../../types';
import { Auth, CurrentUser, Public } from '../auth/guard';
import { Auth, CurrentUser, Public, Publicable } from '../auth/guard';
import { StorageService } from '../storage/storage.service';
import { NewFeaturesKind } from './types';
import { UsersService } from './users';
@@ -97,11 +97,17 @@ export class UserResolver {
ttl: 60,
},
})
@Publicable()
@Query(() => UserType, {
name: 'currentUser',
description: 'Get current user',
nullable: true,
})
async currentUser(@CurrentUser() user: UserType) {
async currentUser(@CurrentUser() user?: UserType) {
if (!user) {
return null;
}
const storedUser = await this.users.findUserById(user.id);
if (!storedUser) {
throw new BadRequestException(`User ${user.id} not found in db`);

View File

@@ -23,6 +23,8 @@ type UserType {
"""User password has been set"""
hasPassword: Boolean
token: TokenType!
subscription: UserSubscription
invoices(take: Int = 8, skip: Int): [UserInvoice!]!
}
"""
@@ -55,6 +57,73 @@ type TokenType {
sessionToken: String
}
type SubscriptionPrice {
type: String!
plan: SubscriptionPlan!
currency: String!
amount: Int!
yearlyAmount: Int!
}
enum SubscriptionPlan {
Free
Pro
Team
Enterprise
}
type UserSubscription {
id: String!
plan: SubscriptionPlan!
recurring: SubscriptionRecurring!
status: SubscriptionStatus!
start: DateTime!
end: DateTime!
trialStart: DateTime
trialEnd: DateTime
nextBillAt: DateTime
canceledAt: DateTime
createdAt: DateTime!
updatedAt: DateTime!
}
enum SubscriptionRecurring {
Monthly
Yearly
}
enum SubscriptionStatus {
Active
PastDue
Unpaid
Canceled
Incomplete
Paused
IncompleteExpired
Trialing
}
type UserInvoice {
id: String!
plan: SubscriptionPlan!
recurring: SubscriptionRecurring!
currency: String!
amount: Int!
status: InvoiceStatus!
reason: String!
lastPaymentError: String
createdAt: DateTime!
updatedAt: DateTime!
}
enum InvoiceStatus {
Draft
Open
Void
Paid
Uncollectible
}
type InviteUserType {
"""User name"""
name: String
@@ -166,10 +235,11 @@ type Query {
checkBlobSize(workspaceId: String!, size: Float!): WorkspaceBlobSizes!
"""Get current user"""
currentUser: UserType!
currentUser: UserType
"""Get user by email"""
user(email: String!): UserType
prices: [SubscriptionPrice!]!
}
type Mutation {
@@ -205,6 +275,12 @@ type Mutation {
removeAvatar: RemoveAvatar!
deleteAccount: DeleteAccount!
addToNewFeaturesWaitingList(type: NewFeaturesKind!, email: String!): AddToNewFeaturesWaitingList!
"""Create a subscription checkout link of stripe"""
checkout(recurring: SubscriptionRecurring!): String!
cancelSubscription: UserSubscription!
resumeSubscription: UserSubscription!
updateSubscriptionRecurring(recurring: SubscriptionRecurring!): UserSubscription!
}
"""The `Upload` scalar type represents a file upload."""

View File

@@ -67,6 +67,6 @@ test('should be able to delete user', async t => {
`,
})
.expect(200);
await t.throwsAsync(() => currentUser(app, user.token.token));
t.is(await currentUser(app, user.token.token), null);
t.pass();
});