refactor(server): plugin modules (#5630)

- [x] separates modules into `fundamental`, `core`, `plugins`
- [x] optional modules with `@OptionalModule` decorator to install modules with requirements met(`requires`, `if`)
- [x] `module.contributesTo` defines optional features that will be enabled if module registered
- [x] `AFFiNE.plugins.use('payment', {})` to enable a optional/plugin module
- [x] `PaymentModule` is the first plugin module
- [x] GraphQLSchema will not be generated for non-included modules
- [x] Frontend can use `ServerConfigType` query to detect which features are enabled
- [x] override existing provider globally
This commit is contained in:
liuyi
2024-01-22 07:40:28 +00:00
parent ae8401b6f4
commit e516e0db23
130 changed files with 1297 additions and 974 deletions

View File

@@ -0,0 +1,21 @@
import { PaymentConfig } from './payment';
import { RedisOptions } from './redis';
declare module '../fundamentals/config' {
interface PluginsConfig {
readonly payment: PaymentConfig;
readonly redis: RedisOptions;
}
export type AvailablePlugins = keyof PluginsConfig;
interface AFFiNEConfig {
readonly plugins: {
enabled: AvailablePlugins[];
use<Plugin extends AvailablePlugins>(
plugin: Plugin,
config?: DeepPartial<PluginsConfig[Plugin]>
): void;
} & Partial<PluginsConfig>;
}
}

View File

@@ -0,0 +1,8 @@
import type { AvailablePlugins } from '../fundamentals/config';
import { PaymentModule } from './payment';
import { RedisModule } from './redis';
export const pluginsMap = new Map<AvailablePlugins, AFFiNEModule>([
['payment', PaymentModule],
['redis', RedisModule],
]);

View File

@@ -0,0 +1,28 @@
import { ServerFeature } from '../../core/config';
import { FeatureModule } from '../../core/features';
import { OptionalModule } from '../../fundamentals';
import { SubscriptionResolver, UserSubscriptionResolver } from './resolver';
import { ScheduleManager } from './schedule';
import { SubscriptionService } from './service';
import { StripeProvider } from './stripe';
import { StripeWebhook } from './webhook';
@OptionalModule({
imports: [FeatureModule],
providers: [
ScheduleManager,
StripeProvider,
SubscriptionService,
SubscriptionResolver,
UserSubscriptionResolver,
],
controllers: [StripeWebhook],
requires: [
'plugins.payment.stripe.keys.APIKey',
'plugins.payment.stripe.keys.webhookKey',
],
contributesTo: ServerFeature.Payment,
})
export class PaymentModule {}
export type { PaymentConfig } from './types';

View File

@@ -0,0 +1,332 @@
import { HttpStatus } from '@nestjs/common';
import {
Args,
Context,
Field,
Int,
Mutation,
ObjectType,
Parent,
Query,
registerEnumType,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import type { User, UserInvoice, UserSubscription } from '@prisma/client';
import { GraphQLError } from 'graphql';
import { groupBy } from 'lodash-es';
import { Auth, CurrentUser, Public } from '../../core/auth';
import { UserType } from '../../core/users';
import { Config, PrismaService } from '../../fundamentals';
import { decodeLookupKey, SubscriptionService } from './service';
import {
InvoiceStatus,
SubscriptionPlan,
SubscriptionRecurring,
SubscriptionStatus,
} from './types';
registerEnumType(SubscriptionStatus, { name: 'SubscriptionStatus' });
registerEnumType(SubscriptionRecurring, { name: 'SubscriptionRecurring' });
registerEnumType(SubscriptionPlan, { name: 'SubscriptionPlan' });
registerEnumType(InvoiceStatus, { name: 'InvoiceStatus' });
@ObjectType()
class SubscriptionPrice {
@Field(() => String)
type!: 'fixed';
@Field(() => SubscriptionPlan)
plan!: SubscriptionPlan;
@Field()
currency!: string;
@Field()
amount!: number;
@Field()
yearlyAmount!: number;
}
@ObjectType('UserSubscription')
export class UserSubscriptionType implements Partial<UserSubscription> {
@Field({ name: 'id' })
stripeSubscriptionId!: string;
@Field(() => SubscriptionPlan)
plan!: SubscriptionPlan;
@Field(() => SubscriptionRecurring)
recurring!: SubscriptionRecurring;
@Field(() => SubscriptionStatus)
status!: SubscriptionStatus;
@Field(() => Date)
start!: Date;
@Field(() => Date)
end!: Date;
@Field(() => Date, { nullable: true })
trialStart?: Date | null;
@Field(() => Date, { nullable: true })
trialEnd?: Date | null;
@Field(() => Date, { nullable: true })
nextBillAt?: Date | null;
@Field(() => Date, { nullable: true })
canceledAt?: Date | null;
@Field(() => Date)
createdAt!: Date;
@Field(() => Date)
updatedAt!: Date;
}
@ObjectType('UserInvoice')
class UserInvoiceType implements Partial<UserInvoice> {
@Field({ name: 'id' })
stripeInvoiceId!: string;
@Field(() => SubscriptionPlan)
plan!: SubscriptionPlan;
@Field(() => SubscriptionRecurring)
recurring!: SubscriptionRecurring;
@Field()
currency!: string;
@Field()
amount!: number;
@Field(() => InvoiceStatus)
status!: InvoiceStatus;
@Field()
reason!: string;
@Field(() => String, { nullable: true })
lastPaymentError?: string | null;
@Field(() => String, { nullable: true })
link?: string | null;
@Field(() => Date)
createdAt!: Date;
@Field(() => Date)
updatedAt!: Date;
}
@Auth()
@Resolver(() => UserSubscriptionType)
export class SubscriptionResolver {
constructor(
private readonly service: SubscriptionService,
private readonly config: Config
) {}
@Public()
@Query(() => [SubscriptionPrice])
async prices(): Promise<SubscriptionPrice[]> {
const prices = await this.service.listPrices();
const group = groupBy(
prices.data.filter(price => !!price.lookup_key),
price => {
// @ts-expect-error empty lookup key is filtered out
const [plan] = decodeLookupKey(price.lookup_key);
return plan;
}
);
return Object.entries(group).map(([plan, prices]) => {
const yearly = prices.find(
price =>
decodeLookupKey(
// @ts-expect-error empty lookup key is filtered out
price.lookup_key
)[1] === SubscriptionRecurring.Yearly
);
const monthly = prices.find(
price =>
decodeLookupKey(
// @ts-expect-error empty lookup key is filtered out
price.lookup_key
)[1] === SubscriptionRecurring.Monthly
);
if (!yearly || !monthly) {
throw new GraphQLError('The prices are not configured correctly', {
extensions: {
status: HttpStatus[HttpStatus.BAD_GATEWAY],
code: HttpStatus.BAD_GATEWAY,
},
});
}
return {
type: 'fixed',
plan: plan as SubscriptionPlan,
currency: monthly.currency,
amount: monthly.unit_amount ?? 0,
yearlyAmount: yearly.unit_amount ?? 0,
};
});
}
@Mutation(() => String, {
description: 'Create a subscription checkout link of stripe',
})
async checkout(
@CurrentUser() user: User,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring,
@Args('idempotencyKey') idempotencyKey: string
) {
const session = await this.service.createCheckoutSession({
user,
recurring,
redirectUrl: `${this.config.baseUrl}/upgrade-success`,
idempotencyKey,
});
if (!session.url) {
throw new GraphQLError('Failed to create checkout session', {
extensions: {
status: HttpStatus[HttpStatus.BAD_GATEWAY],
code: HttpStatus.BAD_GATEWAY,
},
});
}
return session.url;
}
@Mutation(() => String, {
description: 'Create a stripe customer portal to manage payment methods',
})
async createCustomerPortal(@CurrentUser() user: User) {
return this.service.createCustomerPortal(user.id);
}
@Mutation(() => UserSubscriptionType)
async cancelSubscription(
@CurrentUser() user: User,
@Args('idempotencyKey') idempotencyKey: string
) {
return this.service.cancelSubscription(idempotencyKey, user.id);
}
@Mutation(() => UserSubscriptionType)
async resumeSubscription(
@CurrentUser() user: User,
@Args('idempotencyKey') idempotencyKey: string
) {
return this.service.resumeCanceledSubscription(idempotencyKey, user.id);
}
@Mutation(() => UserSubscriptionType)
async updateSubscriptionRecurring(
@CurrentUser() user: User,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring,
@Args('idempotencyKey') idempotencyKey: string
) {
return this.service.updateSubscriptionRecurring(
idempotencyKey,
user.id,
recurring
);
}
}
@Resolver(() => UserType)
export class UserSubscriptionResolver {
constructor(
private readonly config: Config,
private readonly db: PrismaService
) {}
@ResolveField(() => UserSubscriptionType, { nullable: true })
async subscription(
@Context() ctx: { isAdminQuery: boolean },
@CurrentUser() me: User,
@Parent() user: User
) {
// allow admin to query other user's subscription
if (!ctx.isAdminQuery && me.id !== user.id) {
throw new GraphQLError(
'You are not allowed to access this subscription',
{
extensions: {
status: HttpStatus[HttpStatus.FORBIDDEN],
code: HttpStatus.FORBIDDEN,
},
}
);
}
// @FIXME(@forehalo): should not mock any api for selfhosted server
// the frontend should avoid calling such api if feature is not enabled
if (this.config.flavor.selfhosted) {
const start = new Date();
const end = new Date();
end.setFullYear(start.getFullYear() + 1);
return {
stripeSubscriptionId: 'dummy',
plan: SubscriptionPlan.SelfHosted,
recurring: SubscriptionRecurring.Yearly,
status: SubscriptionStatus.Active,
start,
end,
createdAt: start,
updatedAt: start,
};
}
return this.db.userSubscription.findUnique({
where: {
userId: user.id,
status: SubscriptionStatus.Active,
},
});
}
@ResolveField(() => [UserInvoiceType])
async invoices(
@CurrentUser() me: User,
@Parent() user: User,
@Args('take', { type: () => Int, nullable: true, defaultValue: 8 })
take: number,
@Args('skip', { type: () => Int, nullable: true }) skip?: number
) {
if (me.id !== user.id) {
throw new GraphQLError('You are not allowed to access this invoices', {
extensions: {
status: HttpStatus[HttpStatus.FORBIDDEN],
code: HttpStatus.FORBIDDEN,
},
});
}
return this.db.userInvoice.findMany({
where: {
userId: user.id,
},
take,
skip,
orderBy: {
id: 'desc',
},
});
}
}

View File

@@ -0,0 +1,238 @@
import { Injectable, Logger } from '@nestjs/common';
import Stripe from 'stripe';
@Injectable()
export class ScheduleManager {
private _schedule: Stripe.SubscriptionSchedule | null = null;
private readonly logger = new Logger(ScheduleManager.name);
constructor(private readonly stripe: Stripe) {}
static create(stripe: Stripe, schedule?: Stripe.SubscriptionSchedule) {
const manager = new ScheduleManager(stripe);
if (schedule) {
manager._schedule = schedule;
}
return manager;
}
get schedule() {
return this._schedule;
}
get currentPhase() {
if (!this._schedule) {
return null;
}
return this._schedule.phases.find(
phase =>
phase.start_date * 1000 < Date.now() &&
phase.end_date * 1000 > Date.now()
);
}
get nextPhase() {
if (!this._schedule) {
return null;
}
return this._schedule.phases.find(
phase => phase.start_date * 1000 > Date.now()
);
}
get isActive() {
return this._schedule?.status === 'active';
}
async fromSchedule(schedule: string | Stripe.SubscriptionSchedule) {
if (typeof schedule === 'string') {
const s = await this.stripe.subscriptionSchedules
.retrieve(schedule)
.catch(e => {
this.logger.error('Failed to retrieve subscription schedule', e);
return undefined;
});
return ScheduleManager.create(this.stripe, s);
} else {
return ScheduleManager.create(this.stripe, schedule);
}
}
async fromSubscription(
idempotencyKey: string,
subscription: string | Stripe.Subscription
) {
if (typeof subscription === 'string') {
subscription = await this.stripe.subscriptions.retrieve(subscription, {
expand: ['schedule'],
});
}
if (subscription.schedule) {
return await this.fromSchedule(subscription.schedule);
} else {
const schedule = await this.stripe.subscriptionSchedules.create(
{ from_subscription: subscription.id },
{ idempotencyKey }
);
return await this.fromSchedule(schedule);
}
}
/**
* Cancel a subscription by marking schedule's end behavior to `cancel`.
* At the same time, the coming phase's price and coupon will be saved to metadata for later resuming to correction subscription.
*/
async cancel(idempotencyKey: string) {
if (!this._schedule) {
throw new Error('No schedule');
}
if (!this.isActive || !this.currentPhase) {
throw new Error('Unexpected subscription schedule status');
}
const phases: Stripe.SubscriptionScheduleUpdateParams.Phase = {
items: [
{
price: this.currentPhase.items[0].price as string,
quantity: 1,
},
],
coupon: (this.currentPhase.coupon as string | null) ?? undefined,
start_date: this.currentPhase.start_date,
end_date: this.currentPhase.end_date,
};
if (this.nextPhase) {
// cancel a subscription with a schedule exiting will delete the upcoming phase,
// it's hard to recover the subscription to the original state if user wan't to resume before due.
// so we manually save the next phase's key information to metadata for later easy resuming.
phases.metadata = {
next_coupon: (this.nextPhase.coupon as string | null) || null, // avoid empty string
next_price: this.nextPhase.items[0].price as string,
};
}
await this.stripe.subscriptionSchedules.update(
this._schedule.id,
{
phases: [phases],
end_behavior: 'cancel',
},
{ idempotencyKey }
);
}
async resume(idempotencyKey: string) {
if (!this._schedule) {
throw new Error('No schedule');
}
if (!this.isActive || !this.currentPhase) {
throw new Error('Unexpected subscription schedule status');
}
const phases: Stripe.SubscriptionScheduleUpdateParams.Phase[] = [
{
items: [
{
price: this.currentPhase.items[0].price as string,
quantity: 1,
},
],
coupon: (this.currentPhase.coupon as string | null) ?? undefined,
start_date: this.currentPhase.start_date,
end_date: this.currentPhase.end_date,
metadata: {
next_coupon: null,
next_price: null,
},
},
];
if (this.currentPhase.metadata && this.currentPhase.metadata.next_price) {
phases.push({
items: [
{
price: this.currentPhase.metadata.next_price,
quantity: 1,
},
],
coupon: this.currentPhase.metadata.next_coupon || undefined,
});
}
await this.stripe.subscriptionSchedules.update(
this._schedule.id,
{
phases: phases,
end_behavior: 'release',
},
{ idempotencyKey }
);
}
async release(idempotencyKey: string) {
if (!this._schedule) {
throw new Error('No schedule');
}
await this.stripe.subscriptionSchedules.release(this._schedule.id, {
idempotencyKey,
});
}
async update(idempotencyKey: string, price: string, coupon?: string) {
if (!this._schedule) {
throw new Error('No schedule');
}
if (!this.isActive || !this.currentPhase) {
throw new Error('Unexpected subscription schedule status');
}
// if current phase's plan matches target, and no coupon change, just release the schedule
if (
this.currentPhase.items[0].price === price &&
(!coupon || this.currentPhase.coupon === coupon)
) {
await this.stripe.subscriptionSchedules.release(this._schedule.id, {
idempotencyKey,
});
this._schedule = null;
} else {
await this.stripe.subscriptionSchedules.update(
this._schedule.id,
{
phases: [
{
items: [
{
price: this.currentPhase.items[0].price as string,
},
],
start_date: this.currentPhase.start_date,
end_date: this.currentPhase.end_date,
},
{
items: [
{
price: price,
quantity: 1,
},
],
coupon,
},
],
},
{ idempotencyKey }
);
}
}
}

View File

@@ -0,0 +1,646 @@
import { Injectable, Logger } from '@nestjs/common';
import { OnEvent as RawOnEvent } from '@nestjs/event-emitter';
import type {
Prisma,
User,
UserInvoice,
UserStripeCustomer,
UserSubscription,
} from '@prisma/client';
import Stripe from 'stripe';
import { FeatureManagementService } from '../../core/features';
import { EventEmitter, PrismaService } from '../../fundamentals';
import { ScheduleManager } from './schedule';
import {
InvoiceStatus,
SubscriptionPlan,
SubscriptionRecurring,
SubscriptionStatus,
} from './types';
const OnEvent = (
event: Stripe.Event.Type,
opts?: Parameters<typeof RawOnEvent>[1]
) => RawOnEvent(event, opts);
// Plan x Recurring make a stripe price lookup key
export function encodeLookupKey(
plan: SubscriptionPlan,
recurring: SubscriptionRecurring
): string {
return plan + '_' + recurring;
}
export function decodeLookupKey(
key: string
): [SubscriptionPlan, SubscriptionRecurring] {
const [plan, recurring] = key.split('_');
return [plan as SubscriptionPlan, recurring as SubscriptionRecurring];
}
const SubscriptionActivated: Stripe.Subscription.Status[] = [
SubscriptionStatus.Active,
SubscriptionStatus.Trialing,
];
export enum CouponType {
EarlyAccess = 'earlyaccess',
EarlyAccessRenew = 'earlyaccessrenew',
}
@Injectable()
export class SubscriptionService {
private readonly logger = new Logger(SubscriptionService.name);
constructor(
private readonly stripe: Stripe,
private readonly db: PrismaService,
private readonly scheduleManager: ScheduleManager,
private readonly event: EventEmitter,
private readonly features: FeatureManagementService
) {}
async listPrices() {
return this.stripe.prices.list();
}
async createCheckoutSession({
user,
recurring,
redirectUrl,
idempotencyKey,
plan = SubscriptionPlan.Pro,
}: {
user: User;
plan?: SubscriptionPlan;
recurring: SubscriptionRecurring;
redirectUrl: string;
idempotencyKey: string;
}) {
const currentSubscription = await this.db.userSubscription.findFirst({
where: {
userId: user.id,
status: SubscriptionStatus.Active,
},
});
if (currentSubscription) {
throw new Error('You already have a subscription');
}
const price = await this.getPrice(plan, recurring);
const customer = await this.getOrCreateCustomer(
`${idempotencyKey}-getOrCreateCustomer`,
user
);
const coupon = await this.getAvailableCoupon(user, CouponType.EarlyAccess);
return await this.stripe.checkout.sessions.create(
{
line_items: [
{
price,
quantity: 1,
},
],
tax_id_collection: {
enabled: true,
},
...(coupon
? {
discounts: [{ coupon }],
}
: {
allow_promotion_codes: true,
}),
mode: 'subscription',
success_url: redirectUrl,
customer: customer.stripeCustomerId,
customer_update: {
address: 'auto',
name: 'auto',
},
},
{ idempotencyKey: `${idempotencyKey}-checkoutSession` }
);
}
async cancelSubscription(
idempotencyKey: string,
userId: string
): Promise<UserSubscription> {
const user = await this.db.user.findUnique({
where: {
id: userId,
},
include: {
subscription: true,
},
});
if (!user?.subscription) {
throw new Error('You do not have any subscription');
}
if (user.subscription.canceledAt) {
throw new Error('Your subscription has already been canceled');
}
// should release the schedule first
if (user.subscription.stripeScheduleId) {
const manager = await this.scheduleManager.fromSchedule(
user.subscription.stripeScheduleId
);
await manager.cancel(idempotencyKey);
return this.saveSubscription(
user,
await this.stripe.subscriptions.retrieve(
user.subscription.stripeSubscriptionId
),
false
);
} else {
// let customer contact support if they want to cancel immediately
// see https://stripe.com/docs/billing/subscriptions/cancel
const subscription = await this.stripe.subscriptions.update(
user.subscription.stripeSubscriptionId,
{ cancel_at_period_end: true },
{ idempotencyKey }
);
return await this.saveSubscription(user, subscription);
}
}
async resumeCanceledSubscription(
idempotencyKey: string,
userId: string
): Promise<UserSubscription> {
const user = await this.db.user.findUnique({
where: {
id: userId,
},
include: {
subscription: true,
},
});
if (!user?.subscription) {
throw new Error('You do not have any subscription');
}
if (!user.subscription.canceledAt) {
throw new Error('Your subscription has not been canceled');
}
if (user.subscription.end < new Date()) {
throw new Error('Your subscription is expired, please checkout again.');
}
if (user.subscription.stripeScheduleId) {
const manager = await this.scheduleManager.fromSchedule(
user.subscription.stripeScheduleId
);
await manager.resume(idempotencyKey);
return this.saveSubscription(
user,
await this.stripe.subscriptions.retrieve(
user.subscription.stripeSubscriptionId
),
false
);
} else {
const subscription = await this.stripe.subscriptions.update(
user.subscription.stripeSubscriptionId,
{ cancel_at_period_end: false },
{ idempotencyKey }
);
return await this.saveSubscription(user, subscription);
}
}
async updateSubscriptionRecurring(
idempotencyKey: string,
userId: string,
recurring: SubscriptionRecurring
): Promise<UserSubscription> {
const user = await this.db.user.findUnique({
where: {
id: userId,
},
include: {
subscription: true,
},
});
if (!user?.subscription) {
throw new Error('You do not have any subscription');
}
if (user.subscription.canceledAt) {
throw new Error('Your subscription has already been canceled ');
}
if (user.subscription.recurring === recurring) {
throw new Error('You have already subscribed to this plan');
}
const price = await this.getPrice(
user.subscription.plan as SubscriptionPlan,
recurring
);
const manager = await this.scheduleManager.fromSubscription(
`${idempotencyKey}-fromSubscription`,
user.subscription.stripeSubscriptionId
);
await manager.update(
`${idempotencyKey}-update`,
price,
// if user is early access user, use early access coupon
manager.currentPhase?.coupon === CouponType.EarlyAccess ||
manager.currentPhase?.coupon === CouponType.EarlyAccessRenew ||
manager.nextPhase?.coupon === CouponType.EarlyAccessRenew
? CouponType.EarlyAccessRenew
: undefined
);
return await this.db.userSubscription.update({
where: {
id: user.subscription.id,
},
data: {
stripeScheduleId: manager.schedule?.id ?? null, // update schedule id or set to null(undefined means untouched)
recurring,
},
});
}
async createCustomerPortal(id: string) {
const user = await this.db.userStripeCustomer.findUnique({
where: {
userId: id,
},
});
if (!user) {
throw new Error('Unknown user');
}
try {
const portal = await this.stripe.billingPortal.sessions.create({
customer: user.stripeCustomerId,
});
return portal.url;
} catch (e) {
this.logger.error('Failed to create customer portal.', e);
throw new Error('Failed to create customer portal');
}
}
@OnEvent('customer.subscription.created')
@OnEvent('customer.subscription.updated')
async onSubscriptionChanges(subscription: Stripe.Subscription) {
const user = await this.retrieveUserFromCustomer(
subscription.customer as string
);
await this.saveSubscription(user, subscription);
}
@OnEvent('customer.subscription.deleted')
async onSubscriptionDeleted(subscription: Stripe.Subscription) {
const user = await this.retrieveUserFromCustomer(
subscription.customer as string
);
await this.db.userSubscription.deleteMany({
where: {
stripeSubscriptionId: subscription.id,
userId: user.id,
},
});
}
@OnEvent('invoice.paid')
async onInvoicePaid(stripeInvoice: Stripe.Invoice) {
await this.saveInvoice(stripeInvoice);
const line = stripeInvoice.lines.data[0];
if (!line.price || line.price.type !== 'recurring') {
throw new Error('Unknown invoice with no recurring price');
}
// deal with early access user
if (stripeInvoice.discount?.coupon.id === CouponType.EarlyAccess) {
const idempotencyKey = stripeInvoice.id + '_earlyaccess';
const manager = await this.scheduleManager.fromSubscription(
`${idempotencyKey}-fromSubscription`,
line.subscription as string
);
await manager.update(
`${idempotencyKey}-update`,
line.price.id,
CouponType.EarlyAccessRenew
);
}
}
@OnEvent('invoice.created')
@OnEvent('invoice.finalization_failed')
@OnEvent('invoice.payment_failed')
async saveInvoice(stripeInvoice: Stripe.Invoice) {
if (!stripeInvoice.customer) {
throw new Error('Unexpected invoice with no customer');
}
const user = await this.retrieveUserFromCustomer(
typeof stripeInvoice.customer === 'string'
? stripeInvoice.customer
: stripeInvoice.customer.id
);
const invoice = await this.db.userInvoice.findUnique({
where: {
stripeInvoiceId: stripeInvoice.id,
},
});
const data: Partial<UserInvoice> = {
currency: stripeInvoice.currency,
amount: stripeInvoice.total,
status: stripeInvoice.status ?? InvoiceStatus.Void,
link: stripeInvoice.hosted_invoice_url,
};
// handle payment error
if (stripeInvoice.attempt_count > 1) {
const paymentIntent = await this.stripe.paymentIntents.retrieve(
stripeInvoice.payment_intent as string
);
if (paymentIntent.last_payment_error) {
if (paymentIntent.last_payment_error.type === 'card_error') {
data.lastPaymentError =
paymentIntent.last_payment_error.message ?? 'Failed to pay';
} else {
data.lastPaymentError = 'Internal Payment error';
}
}
} else if (stripeInvoice.last_finalization_error) {
if (stripeInvoice.last_finalization_error.type === 'card_error') {
data.lastPaymentError =
stripeInvoice.last_finalization_error.message ??
'Failed to finalize invoice';
} else {
data.lastPaymentError = 'Internal Payment error';
}
}
// update invoice
if (invoice) {
await this.db.userInvoice.update({
where: {
stripeInvoiceId: stripeInvoice.id,
},
data,
});
} else {
// create invoice
const price = stripeInvoice.lines.data[0].price;
if (!price || price.type !== 'recurring') {
throw new Error('Unexpected invoice with no recurring price');
}
if (!price.lookup_key) {
throw new Error('Unexpected subscription with no key');
}
const [plan, recurring] = decodeLookupKey(price.lookup_key);
await this.db.userInvoice.create({
data: {
userId: user.id,
stripeInvoiceId: stripeInvoice.id,
plan,
recurring,
reason: stripeInvoice.billing_reason ?? 'contact support',
...(data as any),
},
});
}
}
private async saveSubscription(
user: User,
subscription: Stripe.Subscription,
fromWebhook = true
): Promise<UserSubscription> {
// webhook events may not in sequential order
// always fetch the latest subscription and save
// see https://stripe.com/docs/webhooks#behaviors
if (fromWebhook) {
subscription = await this.stripe.subscriptions.retrieve(subscription.id);
}
const price = subscription.items.data[0].price;
if (!price.lookup_key) {
throw new Error('Unexpected subscription with no key');
}
const [plan, recurring] = decodeLookupKey(price.lookup_key);
const planActivated = SubscriptionActivated.includes(subscription.status);
let nextBillAt: Date | null = null;
if (planActivated) {
this.event.emit('user.subscription.activated', {
userId: user.id,
plan,
});
// get next bill date from upcoming invoice
// see https://stripe.com/docs/api/invoices/upcoming
if (!subscription.canceled_at) {
nextBillAt = new Date(subscription.current_period_end * 1000);
}
} else {
this.event.emit('user.subscription.canceled', user.id);
}
const commonData = {
start: new Date(subscription.current_period_start * 1000),
end: new Date(subscription.current_period_end * 1000),
trialStart: subscription.trial_start
? new Date(subscription.trial_start * 1000)
: null,
trialEnd: subscription.trial_end
? new Date(subscription.trial_end * 1000)
: null,
nextBillAt,
canceledAt: subscription.canceled_at
? new Date(subscription.canceled_at * 1000)
: null,
stripeSubscriptionId: subscription.id,
plan,
recurring,
status: subscription.status,
stripeScheduleId: subscription.schedule as string | null,
};
const currentSubscription = await this.db.userSubscription.findUnique({
where: {
userId: user.id,
},
});
if (currentSubscription) {
const update: Prisma.UserSubscriptionUpdateInput = {
...commonData,
};
// a schedule exists, update the recurring to scheduled one
if (update.stripeScheduleId) {
delete update.recurring;
}
return await this.db.userSubscription.update({
where: {
id: currentSubscription.id,
},
data: update,
});
} else {
return await this.db.userSubscription.create({
data: {
userId: user.id,
...commonData,
},
});
}
}
private async getOrCreateCustomer(
idempotencyKey: string,
user: User
): Promise<UserStripeCustomer> {
const customer = await this.db.userStripeCustomer.findUnique({
where: {
userId: user.id,
},
});
if (customer) {
return customer;
}
const stripeCustomersList = await this.stripe.customers.list({
email: user.email,
limit: 1,
});
let stripeCustomer: Stripe.Customer | undefined;
if (stripeCustomersList.data.length) {
stripeCustomer = stripeCustomersList.data[0];
} else {
stripeCustomer = await this.stripe.customers.create(
{ email: user.email },
{ idempotencyKey }
);
}
return await this.db.userStripeCustomer.create({
data: {
userId: user.id,
stripeCustomerId: stripeCustomer.id,
},
});
}
private async retrieveUserFromCustomer(customerId: string) {
const customer = await this.db.userStripeCustomer.findUnique({
where: {
stripeCustomerId: customerId,
},
include: {
user: true,
},
});
if (customer?.user) {
return customer.user;
}
// customer may not saved is db, check it with stripe
const stripeCustomer = await this.stripe.customers.retrieve(customerId);
if (stripeCustomer.deleted) {
throw new Error('Unexpected subscription created with deleted customer');
}
if (!stripeCustomer.email) {
throw new Error('Unexpected subscription created with no email customer');
}
const user = await this.db.user.findUnique({
where: {
email: stripeCustomer.email,
},
});
if (!user) {
throw new Error(
`Unexpected subscription created with unknown customer ${stripeCustomer.email}`
);
}
await this.db.userStripeCustomer.create({
data: {
userId: user.id,
stripeCustomerId: stripeCustomer.id,
},
});
return user;
}
private async getPrice(
plan: SubscriptionPlan,
recurring: SubscriptionRecurring
): Promise<string> {
const prices = await this.stripe.prices.list({
lookup_keys: [encodeLookupKey(plan, recurring)],
});
if (!prices.data.length) {
throw new Error(
`Unknown subscription plan ${plan} with recurring ${recurring}`
);
}
return prices.data[0].id;
}
private async getAvailableCoupon(
user: User,
couponType: CouponType
): Promise<string | null> {
const earlyAccess = await this.features.isEarlyAccessUser(user.email);
if (earlyAccess) {
try {
const coupon = await this.stripe.coupons.retrieve(couponType);
return coupon.valid ? coupon.id : null;
} catch (e) {
this.logger.error('Failed to get early access coupon', e);
return null;
}
}
return null;
}
}

View File

@@ -0,0 +1,18 @@
import assert from 'node:assert';
import { FactoryProvider } from '@nestjs/common';
import { omit } from 'lodash-es';
import Stripe from 'stripe';
import { Config } from '../../fundamentals';
export const StripeProvider: FactoryProvider = {
provide: Stripe,
useFactory: (config: Config) => {
assert(config.plugins.payment);
const stripeConfig = config.plugins.payment.stripe;
return new Stripe(stripeConfig.keys.APIKey, omit(stripeConfig, 'keys'));
},
inject: [Config],
};

View File

@@ -0,0 +1,58 @@
import { type User } from '@prisma/client';
import { type Stripe } from 'stripe';
import type { Payload } from '../../fundamentals/event/def';
export interface PaymentConfig {
stripe: {
keys: {
APIKey: string;
webhookKey: string;
};
} & Stripe.StripeConfig;
}
export enum SubscriptionRecurring {
Monthly = 'monthly',
Yearly = 'yearly',
}
export enum SubscriptionPlan {
Free = 'free',
Pro = 'pro',
Team = 'team',
Enterprise = 'enterprise',
SelfHosted = 'selfhosted',
}
// see https://stripe.com/docs/api/subscriptions/object#subscription_object-status
export enum SubscriptionStatus {
Active = 'active',
PastDue = 'past_due',
Unpaid = 'unpaid',
Canceled = 'canceled',
Incomplete = 'incomplete',
Paused = 'paused',
IncompleteExpired = 'incomplete_expired',
Trialing = 'trialing',
}
export enum InvoiceStatus {
Draft = 'draft',
Open = 'open',
Void = 'void',
Paid = 'paid',
Uncollectible = 'uncollectible',
}
declare module '../../fundamentals/event/def' {
interface UserEvents {
subscription: {
activated: Payload<{
userId: User['id'];
plan: SubscriptionPlan;
}>;
canceled: Payload<User['id']>;
};
}
}

View File

@@ -0,0 +1,61 @@
import assert from 'node:assert';
import type { RawBodyRequest } from '@nestjs/common';
import {
Controller,
Logger,
NotAcceptableException,
Post,
Req,
} from '@nestjs/common';
import { EventEmitter2 } from '@nestjs/event-emitter';
import type { Request } from 'express';
import Stripe from 'stripe';
import { Config } from '../../fundamentals';
@Controller('/api/stripe')
export class StripeWebhook {
private readonly webhookKey: string;
private readonly logger = new Logger(StripeWebhook.name);
constructor(
config: Config,
private readonly stripe: Stripe,
private readonly event: EventEmitter2
) {
assert(config.plugins.payment);
this.webhookKey = config.plugins.payment.stripe.keys.webhookKey;
}
@Post('/webhook')
async handleWebhook(@Req() req: RawBodyRequest<Request>) {
// Check if webhook signing is configured.
// Retrieve the event by verifying the signature using the raw body and secret.
const signature = req.headers['stripe-signature'];
try {
const event = this.stripe.webhooks.constructEvent(
req.rawBody ?? '',
signature ?? '',
this.webhookKey
);
this.logger.debug(
`[${event.id}] Stripe Webhook {${event.type}} received.`
);
// Stripe requires responseing webhook immediately and handle event asynchronously.
setImmediate(() => {
// handle duplicated events?
// see https://stripe.com/docs/webhooks#handle-duplicate-events
this.event.emitAsync(event.type, event.data.object).catch(e => {
this.logger.error('Failed to handle Stripe Webhook event.', e);
});
});
} catch (err) {
this.logger.error('Stripe Webhook error', err);
throw new NotAcceptableException();
}
}
}

View File

@@ -0,0 +1,194 @@
import { Redis } from 'ioredis';
import type { Cache, CacheSetOptions } from '../../fundamentals/cache/def';
export class RedisCache implements Cache {
constructor(private readonly redis: Redis) {}
// standard operation
async get<T = unknown>(key: string): Promise<T> {
return this.redis
.get(key)
.then(v => {
if (v) {
return JSON.parse(v);
}
return undefined;
})
.catch(() => undefined);
}
async set<T = unknown>(
key: string,
value: T,
opts: CacheSetOptions = {}
): Promise<boolean> {
if (opts.ttl) {
return this.redis
.set(key, JSON.stringify(value), 'PX', opts.ttl)
.then(() => true)
.catch(() => false);
}
return this.redis
.set(key, JSON.stringify(value))
.then(() => true)
.catch(() => false);
}
async increase(key: string, count: number = 1): Promise<number> {
return this.redis.incrby(key, count).catch(() => 0);
}
async decrease(key: string, count: number = 1): Promise<number> {
return this.redis.decrby(key, count).catch(() => 0);
}
async setnx<T = unknown>(
key: string,
value: T,
opts: CacheSetOptions = {}
): Promise<boolean> {
if (opts.ttl) {
return this.redis
.set(key, JSON.stringify(value), 'PX', opts.ttl, 'NX')
.then(v => !!v)
.catch(() => false);
}
return this.redis
.set(key, JSON.stringify(value), 'NX')
.then(v => !!v)
.catch(() => false);
}
async delete(key: string): Promise<boolean> {
return this.redis
.del(key)
.then(v => v > 0)
.catch(() => false);
}
async has(key: string): Promise<boolean> {
return this.redis
.exists(key)
.then(v => v > 0)
.catch(() => false);
}
async ttl(key: string): Promise<number> {
return this.redis.ttl(key).catch(() => 0);
}
async expire(key: string, ttl: number): Promise<boolean> {
return this.redis
.pexpire(key, ttl)
.then(v => v > 0)
.catch(() => false);
}
// list operations
async pushBack<T = unknown>(key: string, ...values: T[]): Promise<number> {
return this.redis
.rpush(key, ...values.map(v => JSON.stringify(v)))
.catch(() => 0);
}
async pushFront<T = unknown>(key: string, ...values: T[]): Promise<number> {
return this.redis
.lpush(key, ...values.map(v => JSON.stringify(v)))
.catch(() => 0);
}
async len(key: string): Promise<number> {
return this.redis.llen(key).catch(() => 0);
}
async list<T = unknown>(
key: string,
start: number,
end: number
): Promise<T[]> {
return this.redis
.lrange(key, start, end)
.then(data => data.map(v => JSON.parse(v)))
.catch(() => []);
}
async popFront<T = unknown>(key: string, count: number = 1): Promise<T[]> {
return this.redis
.lpop(key, count)
.then(data => (data ?? []).map(v => JSON.parse(v)))
.catch(() => []);
}
async popBack<T = unknown>(key: string, count: number = 1): Promise<T[]> {
return this.redis
.rpop(key, count)
.then(data => (data ?? []).map(v => JSON.parse(v)))
.catch(() => []);
}
// map operations
async mapSet<T = unknown>(
map: string,
key: string,
value: T
): Promise<boolean> {
return this.redis
.hset(map, key, JSON.stringify(value))
.then(v => v > 0)
.catch(() => false);
}
async mapIncrease(
map: string,
key: string,
count: number = 1
): Promise<number> {
return this.redis.hincrby(map, key, count);
}
async mapDecrease(
map: string,
key: string,
count: number = 1
): Promise<number> {
return this.redis.hincrby(map, key, -count);
}
async mapGet<T = unknown>(map: string, key: string): Promise<T | undefined> {
return this.redis
.hget(map, key)
.then(v => (v ? JSON.parse(v) : undefined))
.catch(() => undefined);
}
async mapDelete(map: string, key: string): Promise<boolean> {
return this.redis
.hdel(map, key)
.then(v => v > 0)
.catch(() => false);
}
async mapKeys(map: string): Promise<string[]> {
return this.redis.hkeys(map).catch(() => []);
}
async mapRandomKey(map: string): Promise<string | undefined> {
return this.redis
.hrandfield(map, 1)
.then(v =>
typeof v === 'string'
? v
: Array.isArray(v)
? (v[0] as string)
: undefined
)
.catch(() => undefined);
}
async mapLen(map: string): Promise<number> {
return this.redis.hlen(map).catch(() => 0);
}
}

View File

@@ -0,0 +1,62 @@
import { Global, Provider, Type } from '@nestjs/common';
import { Redis, type RedisOptions } from 'ioredis';
import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis';
import { Cache, OptionalModule, SessionCache } from '../../fundamentals';
import { ThrottlerStorage } from '../../fundamentals/throttler';
import { SocketIoAdapterImpl } from '../../fundamentals/websocket';
import { RedisCache } from './cache';
import {
CacheRedis,
SessionRedis,
SocketIoRedis,
ThrottlerRedis,
} from './instances';
import { createSockerIoAdapterImpl } from './ws-adapter';
function makeProvider(token: Type, impl: Type<Redis>): Provider {
return {
provide: token,
useFactory: (redis: Redis) => {
return new RedisCache(redis);
},
inject: [impl],
};
}
// cache
const cacheProvider = makeProvider(Cache, CacheRedis);
const sessionCacheProvider = makeProvider(SessionCache, SessionRedis);
// throttler
const throttlerStorageProvider: Provider = {
provide: ThrottlerStorage,
useFactory: (redis: Redis) => {
return new ThrottlerStorageRedisService(redis);
},
inject: [ThrottlerRedis],
};
// socket io
const socketIoRedisAdapterProvider: Provider = {
provide: SocketIoAdapterImpl,
useFactory: (redis: Redis) => {
return createSockerIoAdapterImpl(redis);
},
inject: [SocketIoRedis],
};
@Global()
@OptionalModule({
providers: [CacheRedis, SessionRedis, ThrottlerRedis, SocketIoRedis],
overrides: [
cacheProvider,
sessionCacheProvider,
socketIoRedisAdapterProvider,
throttlerStorageProvider,
],
requires: ['plugins.redis.host'],
})
export class RedisModule {}
export { RedisOptions };

View File

@@ -0,0 +1,56 @@
import {
Injectable,
Logger,
OnModuleDestroy,
OnModuleInit,
} from '@nestjs/common';
import { Redis as IORedis, RedisOptions } from 'ioredis';
import { Config } from '../../fundamentals/config';
class Redis extends IORedis implements OnModuleDestroy, OnModuleInit {
logger = new Logger(Redis.name);
constructor(opts: RedisOptions) {
super({
...opts,
lazyConnect: true,
});
}
async onModuleInit() {
await this.connect().catch(() => {
this.logger.error('Failed to connect to Redis server.');
});
}
onModuleDestroy() {
this.disconnect();
}
}
@Injectable()
export class CacheRedis extends Redis {
constructor(config: Config) {
super(config.plugins.redis ?? {});
}
}
@Injectable()
export class ThrottlerRedis extends Redis {
constructor(config: Config) {
super({ ...config.plugins.redis, db: (config.plugins.redis?.db ?? 0) + 1 });
}
}
@Injectable()
export class SessionRedis extends Redis {
constructor(config: Config) {
super({ ...config.plugins.redis, db: (config.plugins.redis?.db ?? 0) + 2 });
}
}
@Injectable()
export class SocketIoRedis extends Redis {
constructor(config: Config) {
super({ ...config.plugins.redis, db: (config.plugins.redis?.db ?? 0) + 3 });
}
}

View File

@@ -0,0 +1,3 @@
import { RedisOptions } from 'ioredis';
export type { RedisOptions };

View File

@@ -0,0 +1,28 @@
import { createAdapter } from '@socket.io/redis-adapter';
import { Redis } from 'ioredis';
import { Server, ServerOptions } from 'socket.io';
import { SocketIoAdapter } from '../../fundamentals';
export function createSockerIoAdapterImpl(
redis: Redis
): typeof SocketIoAdapter {
class RedisIoAdapter extends SocketIoAdapter {
override createIOServer(port: number, options?: ServerOptions): Server {
const pubClient = redis;
pubClient.on('error', err => {
console.error(err);
});
const subClient = pubClient.duplicate();
subClient.on('error', err => {
console.error(err);
});
const server = super.createIOServer(port, options) as Server;
server.adapter(createAdapter(pubClient, subClient));
return server;
}
}
return RedisIoAdapter;
}