feat(server): support team workspace subscription (#8919)

close AF-1724, AF-1722
This commit is contained in:
forehalo
2024-12-05 08:31:01 +00:00
parent 4055e3aa67
commit 5bf8ed1095
26 changed files with 2208 additions and 785 deletions

View File

@@ -0,0 +1,53 @@
-- DropForeignKey
ALTER TABLE "user_invoices" DROP CONSTRAINT "user_invoices_user_id_fkey";
-- DropForeignKey
ALTER TABLE "user_subscriptions" DROP CONSTRAINT "user_subscriptions_user_id_fkey";
-- CreateTable
CREATE TABLE "subscriptions" (
"id" SERIAL NOT NULL,
"target_id" VARCHAR NOT NULL,
"plan" VARCHAR(20) NOT NULL,
"recurring" VARCHAR(20) NOT NULL,
"variant" VARCHAR(20),
"quantity" INTEGER NOT NULL DEFAULT 1,
"stripe_subscription_id" TEXT,
"stripe_schedule_id" VARCHAR,
"status" VARCHAR(20) NOT NULL,
"start" TIMESTAMPTZ(3) NOT NULL,
"end" TIMESTAMPTZ(3),
"next_bill_at" TIMESTAMPTZ(3),
"canceled_at" TIMESTAMPTZ(3),
"trial_start" TIMESTAMPTZ(3),
"trial_end" TIMESTAMPTZ(3),
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(3) NOT NULL,
CONSTRAINT "subscriptions_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "invoices" (
"stripe_invoice_id" TEXT NOT NULL,
"target_id" VARCHAR NOT NULL,
"currency" VARCHAR(3) NOT NULL,
"amount" INTEGER NOT NULL,
"status" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(3) NOT NULL,
"reason" VARCHAR,
"last_payment_error" TEXT,
"link" TEXT,
CONSTRAINT "invoices_pkey" PRIMARY KEY ("stripe_invoice_id")
);
-- CreateIndex
CREATE UNIQUE INDEX "subscriptions_stripe_subscription_id_key" ON "subscriptions"("stripe_subscription_id");
-- CreateIndex
CREATE UNIQUE INDEX "subscriptions_target_id_plan_key" ON "subscriptions"("target_id", "plan");
-- CreateIndex
CREATE INDEX "invoices_target_id_idx" ON "invoices"("target_id");

View File

@@ -23,9 +23,7 @@ model User {
registered Boolean @default(true)
features UserFeature[]
customer UserStripeCustomer?
subscriptions UserSubscription[]
invoices UserInvoice[]
userStripeCustomer UserStripeCustomer?
workspacePermissions WorkspaceUserPermission[]
pagePermissions WorkspacePageUserPermission[]
connectedAccounts ConnectedAccount[]
@@ -318,77 +316,6 @@ model SnapshotHistory {
@@map("snapshot_histories")
}
model UserStripeCustomer {
userId String @id @map("user_id") @db.VarChar
stripeCustomerId String @unique @map("stripe_customer_id") @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@map("user_stripe_customers")
}
model UserSubscription {
id Int @id @default(autoincrement()) @db.Integer
userId String @map("user_id") @db.VarChar
plan String @db.VarChar(20)
// yearly/monthly/lifetime
recurring String @db.VarChar(20)
// onetime subscription or anything else
variant String? @db.VarChar(20)
// subscription.id, null for linefetime payment or one time payment subscription
stripeSubscriptionId String? @unique @map("stripe_subscription_id")
// subscription.status, active/past_due/canceled/unpaid...
status String @db.VarChar(20)
// subscription.current_period_start
start DateTime @map("start") @db.Timestamptz(3)
// subscription.current_period_end, null for lifetime payment
end DateTime? @map("end") @db.Timestamptz(3)
// subscription.billing_cycle_anchor
nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(3)
// subscription.canceled_at
canceledAt DateTime? @map("canceled_at") @db.Timestamptz(3)
// subscription.trial_start
trialStart DateTime? @map("trial_start") @db.Timestamptz(3)
// subscription.trial_end
trialEnd DateTime? @map("trial_end") @db.Timestamptz(3)
stripeScheduleId String? @map("stripe_schedule_id") @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([userId, plan])
@@map("user_subscriptions")
}
model UserInvoice {
id Int @id @default(autoincrement()) @db.Integer
userId String @map("user_id") @db.VarChar
stripeInvoiceId String @unique @map("stripe_invoice_id")
currency String @db.VarChar(3)
// CNY 12.50 stored as 1250
amount Int @db.Integer
status String @db.VarChar(20)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
// billing reason
reason String? @db.VarChar
lastPaymentError String? @map("last_payment_error") @db.Text
// stripe hosted invoice link
link String? @db.Text
// @deprecated
plan String? @db.VarChar(20)
// @deprecated
recurring String? @db.VarChar(20)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@index([userId])
@@map("user_invoices")
}
enum AiPromptRole {
system
assistant
@@ -503,3 +430,124 @@ model RuntimeConfig {
@@unique([module, key])
@@map("app_runtime_settings")
}
model DeprecatedUserSubscription {
id Int @id @default(autoincrement()) @db.Integer
userId String @map("user_id") @db.VarChar
plan String @db.VarChar(20)
// yearly/monthly/lifetime
recurring String @db.VarChar(20)
// onetime subscription or anything else
variant String? @db.VarChar(20)
// subscription.id, null for lifetime payment or one time payment subscription
stripeSubscriptionId String? @unique @map("stripe_subscription_id")
// subscription.status, active/past_due/canceled/unpaid...
status String @db.VarChar(20)
// subscription.current_period_start
start DateTime @map("start") @db.Timestamptz(3)
// subscription.current_period_end, null for lifetime payment
end DateTime? @map("end") @db.Timestamptz(3)
// subscription.billing_cycle_anchor
nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(3)
// subscription.canceled_at
canceledAt DateTime? @map("canceled_at") @db.Timestamptz(3)
// subscription.trial_start
trialStart DateTime? @map("trial_start") @db.Timestamptz(3)
// subscription.trial_end
trialEnd DateTime? @map("trial_end") @db.Timestamptz(3)
stripeScheduleId String? @map("stripe_schedule_id") @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
@@unique([userId, plan])
@@map("user_subscriptions")
}
model DeprecatedUserInvoice {
id Int @id @default(autoincrement()) @db.Integer
userId String @map("user_id") @db.VarChar
stripeInvoiceId String @unique @map("stripe_invoice_id")
currency String @db.VarChar(3)
// CNY 12.50 stored as 1250
amount Int @db.Integer
status String @db.VarChar(20)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
// billing reason
reason String? @db.VarChar
lastPaymentError String? @map("last_payment_error") @db.Text
// stripe hosted invoice link
link String? @db.Text
// @deprecated
plan String? @db.VarChar(20)
// @deprecated
recurring String? @db.VarChar(20)
@@index([userId])
@@map("user_invoices")
}
model UserStripeCustomer {
userId String @id @map("user_id") @db.VarChar
stripeCustomerId String @unique @map("stripe_customer_id") @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@map("user_stripe_customers")
}
model Subscription {
id Int @id @default(autoincrement()) @db.Integer
targetId String @map("target_id") @db.VarChar
plan String @db.VarChar(20)
// yearly/monthly/lifetime
recurring String @db.VarChar(20)
// onetime subscription or anything else
variant String? @db.VarChar(20)
quantity Int @default(1) @db.Integer
// subscription.id, null for lifetime payment or one time payment subscription
stripeSubscriptionId String? @unique @map("stripe_subscription_id")
// stripe schedule id
stripeScheduleId String? @map("stripe_schedule_id") @db.VarChar
// subscription.status, active/past_due/canceled/unpaid...
status String @db.VarChar(20)
// subscription.current_period_start
start DateTime @map("start") @db.Timestamptz(3)
// subscription.current_period_end, null for lifetime payment
end DateTime? @map("end") @db.Timestamptz(3)
// subscription.billing_cycle_anchor
nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(3)
// subscription.canceled_at
canceledAt DateTime? @map("canceled_at") @db.Timestamptz(3)
// subscription.trial_start
trialStart DateTime? @map("trial_start") @db.Timestamptz(3)
// subscription.trial_end
trialEnd DateTime? @map("trial_end") @db.Timestamptz(3)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
@@unique([targetId, plan])
@@map("subscriptions")
}
model Invoice {
stripeInvoiceId String @id @map("stripe_invoice_id")
targetId String @map("target_id") @db.VarChar
currency String @db.VarChar(3)
// CNY 12.50 stored as 1250
amount Int @db.Integer
status String @db.VarChar(20)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
// billing reason
reason String? @db.VarChar
lastPaymentError String? @map("last_payment_error") @db.Text
// stripe hosted invoice link
link String? @db.Text
@@index([targetId])
@@map("invoices")
}

View File

@@ -5,7 +5,6 @@ import {
Int,
Mutation,
Query,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { PrismaClient } from '@prisma/client';
@@ -37,7 +36,6 @@ import {
@Resolver(() => UserType)
export class UserResolver {
constructor(
private readonly prisma: PrismaClient,
private readonly storage: AvatarStorage,
private readonly users: UserService
) {}
@@ -72,16 +70,6 @@ export class UserResolver {
};
}
@ResolveField(() => Int, {
name: 'invoiceCount',
description: 'Get user invoice count',
})
async invoiceCount(@CurrentUser() user: CurrentUser) {
return this.prisma.userInvoice.count({
where: { userId: user.id },
});
}
@Mutation(() => UserType, {
name: 'uploadAvatar',
description: 'Upload user avatar',

View File

@@ -37,4 +37,4 @@ import {
})
export class WorkspaceModule {}
export type { InvitationType, WorkspaceType } from './types';
export { InvitationType, WorkspaceType } from './types';

View File

@@ -0,0 +1,29 @@
import { PrismaClient } from '@prisma/client';
import { loop } from './utils/loop';
export class UniversalSubscription1733125339942 {
// do the migration
static async up(db: PrismaClient) {
await loop(async (offset, take) => {
const oldSubscriptions = await db.deprecatedUserSubscription.findMany({
skip: offset,
take,
});
await db.subscription.createMany({
data: oldSubscriptions.map(s => ({
targetId: s.userId,
...s,
})),
});
return oldSubscriptions.length;
}, 50);
}
// revert the migration
static async down(_db: PrismaClient) {
// noop
}
}

View File

@@ -36,3 +36,5 @@ export class ConfigModule {
};
};
}
export { Runtime };

View File

@@ -412,15 +412,28 @@ export const USER_FRIENDLY_ERRORS = {
},
// Subscription Errors
unsupported_subscription_plan: {
type: 'invalid_input',
args: { plan: 'string' },
message: ({ plan }) => `Unsupported subscription plan: ${plan}.`,
},
failed_to_checkout: {
type: 'internal_server_error',
message: 'Failed to create checkout session.',
},
invalid_checkout_parameters: {
type: 'invalid_input',
message: 'Invalid checkout parameters provided.',
},
subscription_already_exists: {
type: 'resource_already_exists',
args: { plan: 'string' },
message: ({ plan }) => `You have already subscribed to the ${plan} plan.`,
},
invalid_subscription_parameters: {
type: 'invalid_input',
message: 'Invalid subscription parameters provided.',
},
subscription_not_exists: {
type: 'resource_not_found',
args: { plan: 'string' },
@@ -430,6 +443,10 @@ export const USER_FRIENDLY_ERRORS = {
type: 'action_forbidden',
message: 'Your subscription has already been canceled.',
},
subscription_has_not_been_canceled: {
type: 'action_forbidden',
message: 'Your subscription has not been canceled.',
},
subscription_expired: {
type: 'action_forbidden',
message: 'Your subscription has expired.',
@@ -453,6 +470,14 @@ export const USER_FRIENDLY_ERRORS = {
type: 'action_forbidden',
message: 'You cannot update an onetime payment subscription.',
},
workspace_id_required_for_team_subscription: {
type: 'invalid_input',
message: 'A workspace is required to checkout for team subscription.',
},
workspace_id_required_to_update_team_subscription: {
type: 'invalid_input',
message: 'Workspace id is required to update team subscription.',
},
// Copilot errors
copilot_session_not_found: {

View File

@@ -328,12 +328,28 @@ export class FailedToUpsertSnapshot extends UserFriendlyError {
super('internal_server_error', 'failed_to_upsert_snapshot', message);
}
}
@ObjectType()
class UnsupportedSubscriptionPlanDataType {
@Field() plan!: string
}
export class UnsupportedSubscriptionPlan extends UserFriendlyError {
constructor(args: UnsupportedSubscriptionPlanDataType, message?: string | ((args: UnsupportedSubscriptionPlanDataType) => string)) {
super('invalid_input', 'unsupported_subscription_plan', message, args);
}
}
export class FailedToCheckout extends UserFriendlyError {
constructor(message?: string) {
super('internal_server_error', 'failed_to_checkout', message);
}
}
export class InvalidCheckoutParameters extends UserFriendlyError {
constructor(message?: string) {
super('invalid_input', 'invalid_checkout_parameters', message);
}
}
@ObjectType()
class SubscriptionAlreadyExistsDataType {
@Field() plan!: string
@@ -344,6 +360,12 @@ export class SubscriptionAlreadyExists extends UserFriendlyError {
super('resource_already_exists', 'subscription_already_exists', message, args);
}
}
export class InvalidSubscriptionParameters extends UserFriendlyError {
constructor(message?: string) {
super('invalid_input', 'invalid_subscription_parameters', message);
}
}
@ObjectType()
class SubscriptionNotExistsDataType {
@Field() plan!: string
@@ -361,6 +383,12 @@ export class SubscriptionHasBeenCanceled extends UserFriendlyError {
}
}
export class SubscriptionHasNotBeenCanceled extends UserFriendlyError {
constructor(message?: string) {
super('action_forbidden', 'subscription_has_not_been_canceled', message);
}
}
export class SubscriptionExpired extends UserFriendlyError {
constructor(message?: string) {
super('action_forbidden', 'subscription_expired', message);
@@ -400,6 +428,18 @@ export class CantUpdateOnetimePaymentSubscription extends UserFriendlyError {
}
}
export class WorkspaceIdRequiredForTeamSubscription extends UserFriendlyError {
constructor(message?: string) {
super('invalid_input', 'workspace_id_required_for_team_subscription', message);
}
}
export class WorkspaceIdRequiredToUpdateTeamSubscription extends UserFriendlyError {
constructor(message?: string) {
super('invalid_input', 'workspace_id_required_to_update_team_subscription', message);
}
}
export class CopilotSessionNotFound extends UserFriendlyError {
constructor(message?: string) {
super('resource_not_found', 'copilot_session_not_found', message);
@@ -587,15 +627,21 @@ export enum ErrorNames {
PAGE_IS_NOT_PUBLIC,
FAILED_TO_SAVE_UPDATES,
FAILED_TO_UPSERT_SNAPSHOT,
UNSUPPORTED_SUBSCRIPTION_PLAN,
FAILED_TO_CHECKOUT,
INVALID_CHECKOUT_PARAMETERS,
SUBSCRIPTION_ALREADY_EXISTS,
INVALID_SUBSCRIPTION_PARAMETERS,
SUBSCRIPTION_NOT_EXISTS,
SUBSCRIPTION_HAS_BEEN_CANCELED,
SUBSCRIPTION_HAS_NOT_BEEN_CANCELED,
SUBSCRIPTION_EXPIRED,
SAME_SUBSCRIPTION_RECURRING,
CUSTOMER_PORTAL_CREATE_FAILED,
SUBSCRIPTION_PLAN_NOT_FOUND,
CANT_UPDATE_ONETIME_PAYMENT_SUBSCRIPTION,
WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION,
WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION,
COPILOT_SESSION_NOT_FOUND,
COPILOT_SESSION_DELETED,
NO_COPILOT_PROVIDER_AVAILABLE,
@@ -624,5 +670,5 @@ registerEnumType(ErrorNames, {
export const ErrorDataUnionType = createUnionType({
name: 'ErrorDataUnion',
types: () =>
[UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, SpaceNotFoundDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
[UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, SpaceNotFoundDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
});

View File

@@ -29,6 +29,6 @@ defineStartupConfig('plugins.payment', {});
defineRuntimeConfig('plugins.payment', {
showLifetimePrice: {
desc: 'Whether enable lifetime price and allow user to pay for it.',
default: false,
default: true,
},
});

View File

@@ -19,7 +19,7 @@ export class SubscriptionCronJobs {
@Cron(CronExpression.EVERY_HOUR)
async cleanExpiredOnetimeSubscriptions() {
const subscriptions = await this.db.userSubscription.findMany({
const subscriptions = await this.db.subscription.findMany({
where: {
variant: SubscriptionVariant.Onetime,
end: {
@@ -30,7 +30,7 @@ export class SubscriptionCronJobs {
for (const subscription of subscriptions) {
this.event.emit('user.subscription.canceled', {
userId: subscription.userId,
userId: subscription.targetId,
plan: subscription.plan as SubscriptionPlan,
recurring: subscription.variant as SubscriptionRecurring,
});
@@ -42,10 +42,10 @@ export class SubscriptionCronJobs {
userId,
plan,
}: EventPayload<'user.subscription.canceled'>) {
await this.db.userSubscription.delete({
await this.db.subscription.delete({
where: {
userId_plan: {
userId,
targetId_plan: {
targetId: userId,
plan,
},
},

View File

@@ -2,19 +2,27 @@ import './config';
import { ServerFeature } from '../../core/config';
import { FeatureModule } from '../../core/features';
import { PermissionModule } from '../../core/permission';
import { UserModule } from '../../core/user';
import { Plugin } from '../registry';
import { StripeWebhookController } from './controller';
import { SubscriptionCronJobs } from './cron';
import { UserSubscriptionManager } from './manager';
import { SubscriptionResolver, UserSubscriptionResolver } from './resolver';
import {
UserSubscriptionManager,
WorkspaceSubscriptionManager,
} from './manager';
import {
SubscriptionResolver,
UserSubscriptionResolver,
WorkspaceSubscriptionResolver,
} from './resolver';
import { SubscriptionService } from './service';
import { StripeProvider } from './stripe';
import { StripeWebhook } from './webhook';
@Plugin({
name: 'payment',
imports: [FeatureModule, UserModule],
imports: [FeatureModule, UserModule, PermissionModule],
providers: [
StripeProvider,
SubscriptionService,
@@ -22,7 +30,9 @@ import { StripeWebhook } from './webhook';
UserSubscriptionResolver,
StripeWebhook,
UserSubscriptionManager,
WorkspaceSubscriptionManager,
SubscriptionCronJobs,
WorkspaceSubscriptionResolver,
],
controllers: [StripeWebhookController],
requires: [

View File

@@ -1,13 +1,23 @@
import { UserStripeCustomer } from '@prisma/client';
import { PrismaClient, UserStripeCustomer } from '@prisma/client';
import Stripe from 'stripe';
import { z } from 'zod';
import { UserNotFound } from '../../../fundamentals';
import { ScheduleManager } from '../schedule';
import {
encodeLookupKey,
KnownStripeInvoice,
KnownStripePrice,
KnownStripeSubscription,
LookupKey,
SubscriptionPlan,
SubscriptionRecurring,
SubscriptionVariant,
} from '../types';
export interface Subscription {
stripeSubscriptionId: string | null;
stripeScheduleId: string | null;
status: string;
plan: string;
recurring: string;
@@ -21,36 +31,225 @@ export interface Subscription {
}
export interface Invoice {
stripeInvoiceId: string;
currency: string;
amount: number;
status: string;
createdAt: Date;
reason: string | null;
lastPaymentError: string | null;
link: string | null;
}
export interface SubscriptionManager {
filterPrices(
export const SubscriptionIdentity = z.object({
plan: z.nativeEnum(SubscriptionPlan),
});
export const CheckoutParams = z.object({
plan: z.nativeEnum(SubscriptionPlan),
recurring: z.nativeEnum(SubscriptionRecurring),
variant: z.nativeEnum(SubscriptionVariant).nullable().optional(),
coupon: z.string().nullable().optional(),
quantity: z.number().min(1).nullable().optional(),
successCallbackLink: z.string(),
});
export abstract class SubscriptionManager {
protected readonly scheduleManager = new ScheduleManager(this.stripe);
constructor(
protected readonly stripe: Stripe,
protected readonly db: PrismaClient
) {}
abstract filterPrices(
prices: KnownStripePrice[],
customer?: UserStripeCustomer
): Promise<KnownStripePrice[]>;
): KnownStripePrice[] | Promise<KnownStripePrice[]>;
saveSubscription(
abstract checkout(
price: KnownStripePrice,
params: z.infer<typeof CheckoutParams>,
args: any
): Promise<Stripe.Checkout.Session>;
abstract saveStripeSubscription(
subscription: KnownStripeSubscription
): Promise<Subscription>;
deleteSubscription(subscription: KnownStripeSubscription): Promise<void>;
abstract deleteStripeSubscription(
subscription: KnownStripeSubscription
): Promise<void>;
getSubscription(
id: string,
plan: SubscriptionPlan
abstract getSubscription(
identity: z.infer<typeof SubscriptionIdentity>
): Promise<Subscription | null>;
abstract cancelSubscription(
subscription: Subscription
): Promise<Subscription>;
cancelSubscription(subscription: Subscription): Promise<Subscription>;
abstract resumeSubscription(
subscription: Subscription
): Promise<Subscription>;
resumeSubscription(subscription: Subscription): Promise<Subscription>;
updateSubscriptionRecurring(
abstract updateSubscriptionRecurring(
subscription: Subscription,
recurring: SubscriptionRecurring
): Promise<Subscription>;
abstract saveInvoice(knownInvoice: KnownStripeInvoice): Promise<Invoice>;
transformSubscription({
lookupKey,
stripeSubscription: subscription,
}: KnownStripeSubscription): Subscription {
return {
...lookupKey,
stripeScheduleId: subscription.schedule as string | null,
stripeSubscriptionId: subscription.id,
status: subscription.status,
start: new Date(subscription.current_period_start * 1000),
end: new Date(subscription.current_period_end * 1000),
trialStart: subscription.trial_start
? new Date(subscription.trial_start * 1000)
: null,
trialEnd: subscription.trial_end
? new Date(subscription.trial_end * 1000)
: null,
nextBillAt: !subscription.canceled_at
? new Date(subscription.current_period_end * 1000)
: null,
canceledAt: subscription.canceled_at
? new Date(subscription.canceled_at * 1000)
: null,
};
}
async transformInvoice({
stripeInvoice,
}: KnownStripeInvoice): Promise<Invoice> {
const status = stripeInvoice.status ?? 'void';
let error: string | boolean | null = null;
if (status !== 'paid') {
if (stripeInvoice.last_finalization_error) {
error = stripeInvoice.last_finalization_error.message ?? true;
} else if (
stripeInvoice.attempt_count > 1 &&
stripeInvoice.payment_intent
) {
const paymentIntent =
typeof stripeInvoice.payment_intent === 'string'
? await this.stripe.paymentIntents.retrieve(
stripeInvoice.payment_intent
)
: stripeInvoice.payment_intent;
if (paymentIntent.last_payment_error) {
error = paymentIntent.last_payment_error.message ?? true;
}
}
}
// fallback to generic error message
if (error === true) {
error = 'Payment Error. Please contact support.';
}
return {
stripeInvoiceId: stripeInvoice.id,
status,
link: stripeInvoice.hosted_invoice_url || null,
reason: stripeInvoice.billing_reason,
amount: stripeInvoice.total,
currency: stripeInvoice.currency,
lastPaymentError: error,
};
}
async getOrCreateCustomer(userId: string): Promise<UserStripeCustomer> {
const user = await this.db.user.findUnique({
where: {
id: userId,
},
select: {
email: true,
userStripeCustomer: true,
},
});
if (!user) {
throw new UserNotFound();
}
let customer = user.userStripeCustomer;
if (!customer) {
const stripeCustomersList = await this.stripe.customers.list({
email: user.email,
limit: 1,
});
let stripeCustomer: Stripe.Customer | undefined;
if (stripeCustomersList.data.length) {
stripeCustomer = stripeCustomersList.data[0];
} else {
stripeCustomer = await this.stripe.customers.create({
email: user.email,
});
}
customer = await this.db.userStripeCustomer.create({
data: {
userId,
stripeCustomerId: stripeCustomer.id,
},
});
}
return customer;
}
protected async getPrice(
lookupKey: LookupKey
): Promise<KnownStripePrice | null> {
const prices = await this.stripe.prices.list({
lookup_keys: [encodeLookupKey(lookupKey)],
limit: 1,
});
const price = prices.data[0];
return price
? {
lookupKey,
price,
}
: null;
}
protected async getCouponFromPromotionCode(
userFacingPromotionCode: string,
customer: UserStripeCustomer
) {
const list = await this.stripe.promotionCodes.list({
code: userFacingPromotionCode,
active: true,
limit: 1,
});
const code = list.data[0];
if (!code) {
return null;
}
// the coupons are always bound to products, we need to check it first
// but the logic would be too complicated, and stripe will complain if the code is not applicable when checking out
// It's safe to skip the check here
// code.coupon.applies_to.products.forEach()
// check if the code is bound to a specific customer
return !code.customer ||
(typeof code.customer === 'string'
? code.customer === customer.stripeCustomerId
: code.customer.id === customer.stripeCustomerId)
? code.coupon.id
: null;
}
}

View File

@@ -1,2 +1,3 @@
export * from './common';
export * from './user';
export * from './workspace';

View File

@@ -1,10 +1,8 @@
import { Injectable } from '@nestjs/common';
import {
PrismaClient,
UserStripeCustomer,
UserSubscription,
} from '@prisma/client';
import { PrismaClient, UserStripeCustomer } from '@prisma/client';
import { omit, pick } from 'lodash-es';
import Stripe from 'stripe';
import { z } from 'zod';
import {
EarlyAccessType,
@@ -14,6 +12,9 @@ import {
Config,
EventEmitter,
InternalServerError,
SubscriptionAlreadyExists,
SubscriptionPlanNotFound,
URLHelper,
} from '../../../fundamentals';
import {
CouponType,
@@ -26,7 +27,7 @@ import {
SubscriptionStatus,
SubscriptionVariant,
} from '../types';
import { SubscriptionManager } from './common';
import { CheckoutParams, Subscription, SubscriptionManager } from './common';
interface PriceStrategyStatus {
proEarlyAccess: boolean;
@@ -36,15 +37,30 @@ interface PriceStrategyStatus {
onetime: boolean;
}
export const UserSubscriptionIdentity = z.object({
plan: z.enum([SubscriptionPlan.Pro, SubscriptionPlan.AI]),
userId: z.string(),
});
export const UserSubscriptionCheckoutArgs = z.object({
user: z.object({
id: z.string(),
email: z.string(),
}),
});
@Injectable()
export class UserSubscriptionManager implements SubscriptionManager {
export class UserSubscriptionManager extends SubscriptionManager {
constructor(
private readonly db: PrismaClient,
stripe: Stripe,
db: PrismaClient,
private readonly config: Config,
private readonly stripe: Stripe,
private readonly feature: FeatureManagementService,
private readonly event: EventEmitter
) {}
private readonly event: EventEmitter,
private readonly url: URLHelper
) {
super(stripe, db);
}
async filterPrices(
prices: KnownStripePrice[],
@@ -71,11 +87,105 @@ export class UserSubscriptionManager implements SubscriptionManager {
return availablePrices;
}
async getSubscription(userId: string, plan: SubscriptionPlan) {
return this.db.userSubscription.findFirst({
async checkout(
price: KnownStripePrice,
params: z.infer<typeof CheckoutParams>,
{ user }: z.infer<typeof UserSubscriptionCheckoutArgs>
) {
const lookupKey = price.lookupKey;
const subscription = await this.getSubscription({
// @ts-expect-error filtered already
plan: price.lookupKey.plan,
user,
});
if (
subscription &&
// do not allow to re-subscribe unless
!(
/* current subscription is a onetime subscription and so as the one that's checking out */
(
(subscription.variant === SubscriptionVariant.Onetime &&
lookupKey.variant === SubscriptionVariant.Onetime) ||
/* current subscription is normal subscription and is checking-out a lifetime subscription */
(subscription.recurring !== SubscriptionRecurring.Lifetime &&
subscription.variant !== SubscriptionVariant.Onetime &&
lookupKey.recurring === SubscriptionRecurring.Lifetime)
)
)
) {
throw new SubscriptionAlreadyExists({ plan: lookupKey.plan });
}
const customer = await this.getOrCreateCustomer(user.id);
const strategy = await this.strategyStatus(customer);
const available = await this.isPriceAvailable(price, {
...strategy,
onetime: true,
});
if (!available) {
throw new SubscriptionPlanNotFound({
plan: lookupKey.plan,
recurring: lookupKey.recurring,
});
}
const discounts = await (async () => {
const coupon = await this.getBuildInCoupon(customer, price);
if (coupon) {
return { discounts: [{ coupon }] };
} else if (params.coupon) {
const couponId = await this.getCouponFromPromotionCode(
params.coupon,
customer
);
if (couponId) {
return { discounts: [{ coupon: couponId }] };
}
}
return { allow_promotion_codes: true };
})();
// mode: 'subscription' or 'payment' for lifetime and onetime payment
const mode =
lookupKey.recurring === SubscriptionRecurring.Lifetime ||
lookupKey.variant === SubscriptionVariant.Onetime
? {
mode: 'payment' as const,
invoice_creation: {
enabled: true,
},
}
: {
mode: 'subscription' as const,
};
return this.stripe.checkout.sessions.create({
line_items: [
{
price: price.price.id,
quantity: 1,
},
],
tax_id_collection: {
enabled: true,
},
...discounts,
...mode,
success_url: this.url.link(params.successCallbackLink, {
session_id: '{CHECKOUT_SESSION_ID}',
}),
customer: customer.stripeCustomerId,
});
}
async getSubscription(args: z.infer<typeof UserSubscriptionIdentity>) {
return this.db.subscription.findFirst({
where: {
userId,
plan,
targetId: args.userId,
plan: args.plan,
status: {
in: [SubscriptionStatus.Active, SubscriptionStatus.Trialing],
},
@@ -83,11 +193,8 @@ export class UserSubscriptionManager implements SubscriptionManager {
});
}
async saveSubscription({
userId,
lookupKey,
stripeSubscription: subscription,
}: KnownStripeSubscription) {
async saveStripeSubscription(subscription: KnownStripeSubscription) {
const { userId, lookupKey, stripeSubscription } = subscription;
// update features first, features modify are idempotent
// so there is no need to skip if a subscription already exists.
// TODO(@forehalo):
@@ -99,43 +206,85 @@ export class UserSubscriptionManager implements SubscriptionManager {
recurring: lookupKey.recurring,
});
const commonData = {
status: subscription.status,
stripeScheduleId: subscription.schedule as string | null,
nextBillAt: !subscription.canceled_at
? new Date(subscription.current_period_end * 1000)
: null,
canceledAt: subscription.canceled_at
? new Date(subscription.canceled_at * 1000)
: null,
};
const subscriptionData = this.transformSubscription(subscription);
return await this.db.userSubscription.upsert({
// @deprecated backward compatibility
await this.db.deprecatedUserSubscription.upsert({
where: {
stripeSubscriptionId: subscription.id,
stripeSubscriptionId: stripeSubscription.id,
},
update: commonData,
update: pick(subscriptionData, [
'status',
'stripeScheduleId',
'nextBillAt',
'canceledAt',
]),
create: {
userId,
...lookupKey,
stripeSubscriptionId: subscription.id,
start: new Date(subscription.current_period_start * 1000),
end: new Date(subscription.current_period_end * 1000),
trialStart: subscription.trial_start
? new Date(subscription.trial_start * 1000)
: null,
trialEnd: subscription.trial_end
? new Date(subscription.trial_end * 1000)
: null,
...commonData,
...subscriptionData,
},
});
return this.db.subscription.upsert({
where: {
stripeSubscriptionId: stripeSubscription.id,
},
update: pick(subscriptionData, [
'status',
'stripeScheduleId',
'nextBillAt',
'canceledAt',
]),
create: {
targetId: userId,
...subscriptionData,
},
});
}
async cancelSubscription(subscription: UserSubscription) {
return this.db.userSubscription.update({
async deleteStripeSubscription({
userId,
lookupKey,
stripeSubscription,
}: KnownStripeSubscription) {
const deleted = await this.db.subscription.deleteMany({
where: {
id: subscription.id,
stripeSubscriptionId: stripeSubscription.id,
},
});
// @deprecated backward compatibility
await this.db.deprecatedUserSubscription.deleteMany({
where: {
stripeSubscriptionId: stripeSubscription.id,
},
});
if (deleted.count > 0) {
this.event.emit('user.subscription.canceled', {
userId,
plan: lookupKey.plan,
recurring: lookupKey.recurring,
});
}
}
async cancelSubscription(subscription: Subscription) {
// @deprecated backward compatibility
await this.db.deprecatedUserSubscription.updateMany({
where: {
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: {
canceledAt: new Date(),
nextBillAt: null,
},
});
return this.db.subscription.update({
where: {
// @ts-expect-error checked outside
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: {
canceledAt: new Date(),
@@ -144,9 +293,23 @@ export class UserSubscriptionManager implements SubscriptionManager {
});
}
async resumeSubscription(subscription: UserSubscription) {
return this.db.userSubscription.update({
where: { id: subscription.id },
async resumeSubscription(subscription: Subscription) {
// @deprecated backward compatibility
await this.db.deprecatedUserSubscription.updateMany({
where: {
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: {
canceledAt: null,
nextBillAt: subscription.end,
},
});
return this.db.subscription.update({
where: {
// @ts-expect-error checked outside
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: {
canceledAt: null,
nextBillAt: subscription.end,
@@ -155,34 +318,30 @@ export class UserSubscriptionManager implements SubscriptionManager {
}
async updateSubscriptionRecurring(
subscription: UserSubscription,
subscription: Subscription,
recurring: SubscriptionRecurring
) {
return this.db.userSubscription.update({
where: { id: subscription.id },
// @deprecated backward compatibility
await this.db.deprecatedUserSubscription.updateMany({
where: {
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: { recurring },
});
return this.db.subscription.update({
where: {
// @ts-expect-error checked outside
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: { recurring },
});
}
async deleteSubscription({
userId,
lookupKey,
stripeSubscription,
}: KnownStripeSubscription) {
await this.db.userSubscription.delete({
where: {
stripeSubscriptionId: stripeSubscription.id,
},
});
this.event.emit('user.subscription.canceled', {
userId,
plan: lookupKey.plan,
recurring: lookupKey.recurring,
});
}
async validatePrice(price: KnownStripePrice, customer: UserStripeCustomer) {
private async getBuildInCoupon(
customer: UserStripeCustomer,
price: KnownStripePrice
) {
const strategyStatus = await this.strategyStatus(customer);
// onetime price is allowed for checkout
@@ -192,7 +351,7 @@ export class UserSubscriptionManager implements SubscriptionManager {
return null;
}
let coupon: CouponType | null = null;
let coupon: CouponType | undefined;
if (price.lookupKey.variant === SubscriptionVariant.EA) {
if (price.lookupKey.plan === SubscriptionPlan.Pro) {
@@ -207,69 +366,40 @@ export class UserSubscriptionManager implements SubscriptionManager {
}
}
return {
price,
coupon,
};
return coupon;
}
async saveInvoice(knownInvoice: KnownStripeInvoice) {
const { userId, lookupKey, stripeInvoice } = knownInvoice;
const status = stripeInvoice.status ?? 'void';
let error: string | boolean | null = null;
const invoiceData = await this.transformInvoice(knownInvoice);
if (status !== 'paid') {
if (stripeInvoice.last_finalization_error) {
error = stripeInvoice.last_finalization_error.message ?? true;
} else if (
stripeInvoice.attempt_count > 1 &&
stripeInvoice.payment_intent
) {
const paymentIntent =
typeof stripeInvoice.payment_intent === 'string'
? await this.stripe.paymentIntents.retrieve(
stripeInvoice.payment_intent
)
: stripeInvoice.payment_intent;
if (paymentIntent.last_payment_error) {
error = paymentIntent.last_payment_error.message ?? true;
}
}
}
// fallback to generic error message
if (error === true) {
error = 'Payment Error. Please contact support.';
}
const invoice = this.db.userInvoice.upsert({
// @deprecated backward compatibility
await this.db.deprecatedUserInvoice.upsert({
where: {
stripeInvoiceId: stripeInvoice.id,
},
update: {
status,
link: stripeInvoice.hosted_invoice_url,
amount: stripeInvoice.total,
currency: stripeInvoice.currency,
lastPaymentError: error,
},
update: omit(invoiceData, 'stripeInvoiceId'),
create: {
userId,
...invoiceData,
},
});
const invoice = this.db.invoice.upsert({
where: {
stripeInvoiceId: stripeInvoice.id,
status,
link: stripeInvoice.hosted_invoice_url,
reason: stripeInvoice.billing_reason,
amount: stripeInvoice.total,
currency: stripeInvoice.currency,
lastPaymentError: error,
},
update: omit(invoiceData, 'stripeInvoiceId'),
create: {
targetId: userId,
...invoiceData,
},
});
// onetime and lifetime subscription is a special "subscription" that doesn't get involved with stripe subscription system
// we track the deals by invoice only.
if (status === 'paid') {
if (stripeInvoice.status === 'paid') {
if (lookupKey.recurring === SubscriptionRecurring.Lifetime) {
await this.saveLifetimeSubscription(knownInvoice);
} else if (lookupKey.variant === SubscriptionVariant.Onetime) {
@@ -282,45 +412,49 @@ export class UserSubscriptionManager implements SubscriptionManager {
async saveLifetimeSubscription(
knownInvoice: KnownStripeInvoice
): Promise<UserSubscription> {
): Promise<Subscription> {
// cancel previous non-lifetime subscription
const prevSubscription = await this.db.userSubscription.findUnique({
const prevSubscription = await this.db.subscription.findUnique({
where: {
userId_plan: {
userId: knownInvoice.userId,
targetId_plan: {
targetId: knownInvoice.userId,
plan: SubscriptionPlan.Pro,
},
},
});
let subscription: UserSubscription;
if (prevSubscription && prevSubscription.stripeSubscriptionId) {
subscription = await this.db.userSubscription.update({
where: {
id: prevSubscription.id,
},
data: {
stripeScheduleId: null,
stripeSubscriptionId: null,
plan: knownInvoice.lookupKey.plan,
recurring: SubscriptionRecurring.Lifetime,
start: new Date(),
end: null,
status: SubscriptionStatus.Active,
nextBillAt: null,
},
});
let subscription: Subscription;
if (prevSubscription) {
if (prevSubscription.stripeSubscriptionId) {
subscription = await this.db.subscription.update({
where: {
id: prevSubscription.id,
},
data: {
stripeScheduleId: null,
stripeSubscriptionId: null,
plan: knownInvoice.lookupKey.plan,
recurring: SubscriptionRecurring.Lifetime,
start: new Date(),
end: null,
status: SubscriptionStatus.Active,
nextBillAt: null,
},
});
await this.stripe.subscriptions.cancel(
prevSubscription.stripeSubscriptionId,
{
prorate: true,
}
);
await this.stripe.subscriptions.cancel(
prevSubscription.stripeSubscriptionId,
{
prorate: true,
}
);
} else {
subscription = prevSubscription;
}
} else {
subscription = await this.db.userSubscription.create({
subscription = await this.db.subscription.create({
data: {
userId: knownInvoice.userId,
targetId: knownInvoice.userId,
stripeSubscriptionId: null,
plan: knownInvoice.lookupKey.plan,
recurring: SubscriptionRecurring.Lifetime,
@@ -343,12 +477,13 @@ export class UserSubscriptionManager implements SubscriptionManager {
async saveOnetimePaymentSubscription(
knownInvoice: KnownStripeInvoice
): Promise<UserSubscription> {
): Promise<Subscription> {
// TODO(@forehalo): identify whether the invoice has already been redeemed.
const { userId, lookupKey } = knownInvoice;
const existingSubscription = await this.db.userSubscription.findUnique({
const existingSubscription = await this.db.subscription.findUnique({
where: {
userId_plan: {
userId,
targetId_plan: {
targetId: userId,
plan: lookupKey.plan,
},
},
@@ -362,7 +497,7 @@ export class UserSubscriptionManager implements SubscriptionManager {
60 *
1000;
let subscription: UserSubscription;
let subscription: Subscription;
// extends the subscription time if exists
if (existingSubscription) {
@@ -385,16 +520,16 @@ export class UserSubscriptionManager implements SubscriptionManager {
),
};
subscription = await this.db.userSubscription.update({
subscription = await this.db.subscription.update({
where: {
id: existingSubscription.id,
},
data: period,
});
} else {
subscription = await this.db.userSubscription.create({
subscription = await this.db.subscription.create({
data: {
userId,
targetId: userId,
stripeSubscriptionId: null,
...lookupKey,
start: new Date(),

View File

@@ -0,0 +1,305 @@
import { Injectable } from '@nestjs/common';
import { PrismaClient, UserStripeCustomer } from '@prisma/client';
import { omit, pick } from 'lodash-es';
import Stripe from 'stripe';
import { z } from 'zod';
import {
EventEmitter,
type EventPayload,
OnEvent,
SubscriptionAlreadyExists,
URLHelper,
} from '../../../fundamentals';
import {
KnownStripeInvoice,
KnownStripePrice,
KnownStripeSubscription,
retriveLookupKeyFromStripeSubscription,
SubscriptionPlan,
SubscriptionRecurring,
SubscriptionStatus,
} from '../types';
import {
CheckoutParams,
Invoice,
Subscription,
SubscriptionManager,
} from './common';
export const WorkspaceSubscriptionIdentity = z.object({
plan: z.literal(SubscriptionPlan.Team),
workspaceId: z.string(),
});
export const WorkspaceSubscriptionCheckoutArgs = z.object({
plan: z.literal(SubscriptionPlan.Team),
workspaceId: z.string(),
user: z.object({
id: z.string(),
email: z.string(),
}),
});
@Injectable()
export class WorkspaceSubscriptionManager extends SubscriptionManager {
constructor(
stripe: Stripe,
db: PrismaClient,
private readonly url: URLHelper,
private readonly event: EventEmitter
) {
super(stripe, db);
}
filterPrices(
prices: KnownStripePrice[],
_customer?: UserStripeCustomer
): KnownStripePrice[] {
return prices.filter(
price => price.lookupKey.plan === SubscriptionPlan.Team
);
}
async checkout(
{ price }: KnownStripePrice,
params: z.infer<typeof CheckoutParams>,
args: z.infer<typeof WorkspaceSubscriptionCheckoutArgs>
) {
const subscription = await this.getSubscription({
plan: SubscriptionPlan.Team,
workspaceId: args.workspaceId,
});
if (subscription) {
throw new SubscriptionAlreadyExists({ plan: SubscriptionPlan.Team });
}
const customer = await this.getOrCreateCustomer(args.user.id);
const discounts = await (async () => {
if (params.coupon) {
const couponId = await this.getCouponFromPromotionCode(
params.coupon,
customer
);
if (couponId) {
return { discounts: [{ coupon: couponId }] };
}
}
return { allow_promotion_codes: true };
})();
const count = await this.db.workspaceUserPermission.count({
where: {
workspaceId: args.workspaceId,
// @TODO(darksky): replace with [status: WorkspaceUserPermissionStatus.Accepted]
accepted: true,
},
});
return this.stripe.checkout.sessions.create({
line_items: [
{
price: price.id,
quantity: count,
},
],
tax_id_collection: {
enabled: true,
},
...discounts,
mode: 'subscription',
success_url: this.url.link(params.successCallbackLink),
customer: customer.stripeCustomerId,
subscription_data: {
metadata: {
workspaceId: args.workspaceId,
},
},
});
}
async saveStripeSubscription(subscription: KnownStripeSubscription) {
const { lookupKey, quantity, stripeSubscription } = subscription;
const workspaceId = stripeSubscription.metadata.workspaceId;
if (!workspaceId) {
throw new Error(
'Workspace ID is required in workspace subscription metadata'
);
}
this.event.emit('workspace.subscription.activated', {
workspaceId,
plan: lookupKey.plan,
recurring: lookupKey.recurring,
quantity,
});
const subscriptionData = this.transformSubscription(subscription);
return this.db.subscription.upsert({
where: {
stripeSubscriptionId: stripeSubscription.id,
},
update: {
quantity,
...pick(subscriptionData, [
'status',
'stripeScheduleId',
'nextBillAt',
'canceledAt',
]),
},
create: {
targetId: workspaceId,
quantity,
...subscriptionData,
},
});
}
async deleteStripeSubscription({
lookupKey,
stripeSubscription,
}: KnownStripeSubscription) {
const workspaceId = stripeSubscription.metadata.workspaceId;
if (!workspaceId) {
throw new Error(
'Workspace ID is required in workspace subscription metadata'
);
}
const deleted = await this.db.subscription.deleteMany({
where: { stripeSubscriptionId: stripeSubscription.id },
});
if (deleted.count > 0) {
this.event.emit('workspace.subscription.canceled', {
workspaceId,
plan: lookupKey.plan,
recurring: lookupKey.recurring,
});
}
}
getSubscription(identity: z.infer<typeof WorkspaceSubscriptionIdentity>) {
return this.db.subscription.findFirst({
where: {
targetId: identity.workspaceId,
status: {
in: [SubscriptionStatus.Active, SubscriptionStatus.Trialing],
},
},
});
}
async cancelSubscription(subscription: Subscription) {
return await this.db.subscription.update({
where: {
// @ts-expect-error checked outside
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: {
canceledAt: new Date(),
nextBillAt: null,
},
});
}
resumeSubscription(subscription: Subscription): Promise<Subscription> {
return this.db.subscription.update({
where: {
// @ts-expect-error checked outside
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: {
canceledAt: null,
nextBillAt: subscription.end,
},
});
}
updateSubscriptionRecurring(
subscription: Subscription,
recurring: SubscriptionRecurring
): Promise<Subscription> {
return this.db.subscription.update({
where: {
// @ts-expect-error checked outside
stripeSubscriptionId: subscription.stripeSubscriptionId,
},
data: { recurring },
});
}
async saveInvoice(knownInvoice: KnownStripeInvoice): Promise<Invoice> {
const { metadata, stripeInvoice } = knownInvoice;
const workspaceId = metadata.workspaceId;
if (!workspaceId) {
throw new Error('Workspace ID is required in workspace invoice metadata');
}
const invoiceData = await this.transformInvoice(knownInvoice);
return this.db.invoice.upsert({
where: {
stripeInvoiceId: stripeInvoice.id,
},
update: omit(invoiceData, 'stripeInvoiceId'),
create: {
targetId: workspaceId,
...invoiceData,
},
});
}
@OnEvent('workspace.members.updated')
async onMembersUpdated({
workspaceId,
count,
}: EventPayload<'workspace.members.updated'>) {
const subscription = await this.getSubscription({
plan: SubscriptionPlan.Team,
workspaceId,
});
if (!subscription || !subscription.stripeSubscriptionId) {
return;
}
const stripeSubscription = await this.stripe.subscriptions.retrieve(
subscription.stripeSubscriptionId
);
const lookupKey =
retriveLookupKeyFromStripeSubscription(stripeSubscription);
await this.stripe.subscriptions.update(stripeSubscription.id, {
items: [
{
id: stripeSubscription.items.data[0].id,
quantity: count,
},
],
payment_behavior: 'pending_if_incomplete',
proration_behavior:
lookupKey?.recurring === SubscriptionRecurring.Yearly
? 'always_invoice'
: 'none',
});
if (subscription.stripeScheduleId) {
const schedule = await this.scheduleManager.fromSchedule(
subscription.stripeScheduleId
);
await schedule.updateQuantity(count);
}
}
}

View File

@@ -12,15 +12,23 @@ import {
ResolveField,
Resolver,
} from '@nestjs/graphql';
import type { User, UserSubscription } from '@prisma/client';
import type { User } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import { GraphQLJSONObject } from 'graphql-scalars';
import { groupBy } from 'lodash-es';
import { z } from 'zod';
import { CurrentUser, Public } from '../../core/auth';
import { Permission, PermissionService } from '../../core/permission';
import { UserType } from '../../core/user';
import { AccessDenied, FailedToCheckout, URLHelper } from '../../fundamentals';
import { Invoice, Subscription } from './manager';
import { SubscriptionService } from './service';
import { WorkspaceType } from '../../core/workspaces';
import {
AccessDenied,
FailedToCheckout,
WorkspaceIdRequiredToUpdateTeamSubscription,
} from '../../fundamentals';
import { Invoice, Subscription, WorkspaceSubscriptionManager } from './manager';
import { CheckoutParams, SubscriptionService } from './service';
import {
InvoiceStatus,
SubscriptionPlan,
@@ -57,7 +65,7 @@ class SubscriptionPrice {
}
@ObjectType()
export class SubscriptionType implements Subscription {
export class SubscriptionType implements Partial<Subscription> {
@Field(() => SubscriptionPlan, {
description:
"The 'Free' plan just exists to be a placeholder and for the type convenience of frontend.\nThere won't actually be a subscription with plan 'Free'",
@@ -107,7 +115,7 @@ export class SubscriptionType implements Subscription {
}
@ObjectType()
export class InvoiceType implements Invoice {
export class InvoiceType implements Partial<Invoice> {
@Field()
currency!: string;
@@ -138,7 +146,7 @@ export class InvoiceType implements Invoice {
nullable: true,
deprecationReason: 'removed',
})
stripeInvoiceId!: string | null;
stripeInvoiceId?: string;
@Field(() => SubscriptionPlan, {
nullable: true,
@@ -154,7 +162,7 @@ export class InvoiceType implements Invoice {
}
@InputType()
class CreateCheckoutSessionInput {
class CreateCheckoutSessionInput implements z.infer<typeof CheckoutParams> {
@Field(() => SubscriptionRecurring, {
nullable: true,
defaultValue: SubscriptionRecurring.Yearly,
@@ -170,7 +178,7 @@ class CreateCheckoutSessionInput {
@Field(() => SubscriptionVariant, {
nullable: true,
})
variant?: SubscriptionVariant;
variant!: SubscriptionVariant | null;
@Field(() => String, { nullable: true })
coupon!: string | null;
@@ -180,17 +188,17 @@ class CreateCheckoutSessionInput {
@Field(() => String, {
nullable: true,
deprecationReason: 'use header `Idempotency-Key`',
deprecationReason: 'not required anymore',
})
idempotencyKey?: string;
@Field(() => GraphQLJSONObject, { nullable: true })
args!: { workspaceId?: string };
}
@Resolver(() => SubscriptionType)
export class SubscriptionResolver {
constructor(
private readonly service: SubscriptionService,
private readonly url: URLHelper
) {}
constructor(private readonly service: SubscriptionService) {}
@Public()
@Query(() => [SubscriptionPrice])
@@ -232,7 +240,11 @@ export class SubscriptionResolver {
}
// extend it when new plans are added
const fixedPlans = [SubscriptionPlan.Pro, SubscriptionPlan.AI];
const fixedPlans = [
SubscriptionPlan.Pro,
SubscriptionPlan.AI,
SubscriptionPlan.Team,
];
return fixedPlans.reduce((prices, plan) => {
const price = findPrice(plan);
@@ -255,26 +267,19 @@ export class SubscriptionResolver {
async createCheckoutSession(
@CurrentUser() user: CurrentUser,
@Args({ name: 'input', type: () => CreateCheckoutSessionInput })
input: CreateCheckoutSessionInput,
@Headers('idempotency-key') idempotencyKey?: string
input: CreateCheckoutSessionInput
) {
const session = await this.service.checkout({
const session = await this.service.checkout(input, {
plan: input.plan as any,
user,
lookupKey: {
plan: input.plan,
recurring: input.recurring,
variant: input.variant,
},
promotionCode: input.coupon,
redirectUrl: this.url.link(input.successCallbackLink),
idempotencyKey,
workspaceId: input.args?.workspaceId,
});
if (!session.url) {
throw new FailedToCheckout();
}
return session.url;
return session;
}
@Mutation(() => String, {
@@ -294,6 +299,8 @@ export class SubscriptionResolver {
defaultValue: SubscriptionPlan.Pro,
})
plan: SubscriptionPlan,
@Args({ name: 'workspaceId', type: () => String, nullable: true })
workspaceId: string | null,
@Headers('idempotency-key') idempotencyKey?: string,
@Args('idempotencyKey', {
type: () => String,
@@ -302,7 +309,25 @@ export class SubscriptionResolver {
})
_?: string
) {
return this.service.cancelSubscription(user.id, plan, idempotencyKey);
if (plan === SubscriptionPlan.Team) {
if (!workspaceId) {
throw new WorkspaceIdRequiredToUpdateTeamSubscription();
}
return this.service.cancelSubscription(
{ workspaceId, plan },
idempotencyKey
);
}
return this.service.cancelSubscription(
{
targetId: user.id,
// @ts-expect-error exam inside
plan,
},
idempotencyKey
);
}
@Mutation(() => SubscriptionType)
@@ -315,6 +340,8 @@ export class SubscriptionResolver {
defaultValue: SubscriptionPlan.Pro,
})
plan: SubscriptionPlan,
@Args({ name: 'workspaceId', type: () => String, nullable: true })
workspaceId: string | null,
@Headers('idempotency-key') idempotencyKey?: string,
@Args('idempotencyKey', {
type: () => String,
@@ -323,14 +350,30 @@ export class SubscriptionResolver {
})
_?: string
) {
return this.service.resumeSubscription(user.id, plan, idempotencyKey);
if (plan === SubscriptionPlan.Team) {
if (!workspaceId) {
throw new WorkspaceIdRequiredToUpdateTeamSubscription();
}
return this.service.resumeSubscription(
{ workspaceId, plan },
idempotencyKey
);
}
return this.service.resumeSubscription(
{
targetId: user.id,
// @ts-expect-error exam inside
plan,
},
idempotencyKey
);
}
@Mutation(() => SubscriptionType)
async updateSubscriptionRecurring(
@CurrentUser() user: CurrentUser,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring,
@Args({
name: 'plan',
type: () => SubscriptionPlan,
@@ -338,6 +381,10 @@ export class SubscriptionResolver {
defaultValue: SubscriptionPlan.Pro,
})
plan: SubscriptionPlan,
@Args({ name: 'workspaceId', type: () => String, nullable: true })
workspaceId: string | null,
@Args({ name: 'recurring', type: () => SubscriptionRecurring })
recurring: SubscriptionRecurring,
@Headers('idempotency-key') idempotencyKey?: string,
@Args('idempotencyKey', {
type: () => String,
@@ -346,9 +393,24 @@ export class SubscriptionResolver {
})
_?: string
) {
if (plan === SubscriptionPlan.Team) {
if (!workspaceId) {
throw new WorkspaceIdRequiredToUpdateTeamSubscription();
}
return this.service.updateSubscriptionRecurring(
{ workspaceId, plan },
recurring,
idempotencyKey
);
}
return this.service.updateSubscriptionRecurring(
user.id,
plan,
{
userId: user.id,
// @ts-expect-error exam inside
plan,
},
recurring,
idempotencyKey
);
@@ -363,14 +425,14 @@ export class UserSubscriptionResolver {
async subscriptions(
@CurrentUser() me: User,
@Parent() user: User
): Promise<UserSubscription[]> {
): Promise<Subscription[]> {
if (me.id !== user.id) {
throw new AccessDenied();
}
const subscriptions = await this.db.userSubscription.findMany({
const subscriptions = await this.db.subscription.findMany({
where: {
userId: user.id,
targetId: user.id,
status: SubscriptionStatus.Active,
},
});
@@ -389,6 +451,16 @@ export class UserSubscriptionResolver {
return subscriptions;
}
@ResolveField(() => Int, {
name: 'invoiceCount',
description: 'Get user invoice count',
})
async invoiceCount(@CurrentUser() user: CurrentUser) {
return this.db.invoice.count({
where: { targetId: user.id },
});
}
@ResolveField(() => [InvoiceType])
async invoices(
@CurrentUser() me: User,
@@ -401,14 +473,72 @@ export class UserSubscriptionResolver {
throw new AccessDenied();
}
return this.db.userInvoice.findMany({
return this.db.invoice.findMany({
where: {
userId: user.id,
targetId: user.id,
},
take,
skip,
orderBy: {
id: 'desc',
createdAt: 'desc',
},
});
}
}
@Resolver(() => WorkspaceType)
export class WorkspaceSubscriptionResolver {
constructor(
private readonly service: WorkspaceSubscriptionManager,
private readonly db: PrismaClient,
private readonly permission: PermissionService
) {}
@ResolveField(() => SubscriptionType, {
nullable: true,
description: 'The team subscription of the workspace, if exists.',
})
async subscription(@Parent() workspace: WorkspaceType) {
return this.service.getSubscription({
plan: SubscriptionPlan.Team,
workspaceId: workspace.id,
});
}
@ResolveField(() => Int, {
name: 'invoiceCount',
description: 'Get user invoice count',
})
async invoiceCount(
@CurrentUser() me: CurrentUser,
@Parent() workspace: WorkspaceType
) {
await this.permission.checkWorkspace(workspace.id, me.id, Permission.Owner);
return this.db.invoice.count({
where: {
targetId: workspace.id,
},
});
}
@ResolveField(() => [InvoiceType])
async invoices(
@CurrentUser() me: CurrentUser,
@Parent() workspace: WorkspaceType,
@Args('take', { type: () => Int, nullable: true, defaultValue: 8 })
take: number,
@Args('skip', { type: () => Int, nullable: true }) skip?: number
) {
await this.permission.checkWorkspace(workspace.id, me.id, Permission.Owner);
return this.db.invoice.findMany({
where: {
targetId: workspace.id,
},
take,
skip,
orderBy: {
createdAt: 'desc',
},
});
}

View File

@@ -101,7 +101,7 @@ export class ScheduleManager {
items: [
{
price: this.currentPhase.items[0].price as string,
quantity: 1,
quantity: this.currentPhase.items[0].quantity,
},
],
coupon: (this.currentPhase.coupon as string | null) ?? undefined,
@@ -143,10 +143,9 @@ export class ScheduleManager {
items: [
{
price: this.currentPhase.items[0].price as string,
quantity: 1,
quantity: this.currentPhase.items[0].quantity,
},
],
coupon: (this.currentPhase.coupon as string | null) ?? undefined,
start_date: this.currentPhase.start_date,
end_date: this.currentPhase.end_date,
metadata: {
@@ -161,7 +160,7 @@ export class ScheduleManager {
items: [
{
price: this.currentPhase.metadata.next_price,
quantity: 1,
quantity: this.currentPhase.items[0].quantity,
},
],
coupon: this.currentPhase.metadata.next_coupon || undefined,
@@ -212,6 +211,7 @@ export class ScheduleManager {
items: [
{
price: this.currentPhase.items[0].price as string,
quantity: this.currentPhase.items[0].quantity,
},
],
start_date: this.currentPhase.start_date,
@@ -221,6 +221,7 @@ export class ScheduleManager {
items: [
{
price: price,
quantity: this.currentPhase.items[0].quantity,
},
],
},
@@ -230,4 +231,31 @@ export class ScheduleManager {
);
}
}
async updateQuantity(quantity: number, idempotencyKey?: string) {
if (!this._schedule) {
throw new Error('No schedule');
}
if (!this.isActive || !this.currentPhase) {
throw new Error('Unexpected subscription schedule status');
}
await this.stripe.subscriptionSchedules.update(
this._schedule.id,
{
phases: this._schedule.phases.map(phase => ({
items: [
{
price: phase.items[0].price as string,
quantity,
},
],
start_date: phase.start_date,
end_date: phase.end_date,
})),
},
{ idempotencyKey }
);
}
}

View File

@@ -1,12 +1,8 @@
import { Injectable, Logger } from '@nestjs/common';
import type {
User,
UserInvoice,
UserStripeCustomer,
UserSubscription,
} from '@prisma/client';
import type { User, UserStripeCustomer } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import Stripe from 'stripe';
import { z } from 'zod';
import { CurrentUser } from '../../core/auth';
import { FeatureManagementService } from '../../core/features';
@@ -17,30 +13,56 @@ import {
Config,
CustomerPortalCreateFailed,
InternalServerError,
InvalidCheckoutParameters,
InvalidSubscriptionParameters,
OnEvent,
SameSubscriptionRecurring,
SubscriptionAlreadyExists,
SubscriptionExpired,
SubscriptionHasBeenCanceled,
SubscriptionHasNotBeenCanceled,
SubscriptionNotExists,
SubscriptionPlanNotFound,
UnsupportedSubscriptionPlan,
UserNotFound,
} from '../../fundamentals';
import { UserSubscriptionManager } from './manager';
import {
CheckoutParams,
Invoice,
Subscription,
SubscriptionManager,
UserSubscriptionCheckoutArgs,
UserSubscriptionIdentity,
UserSubscriptionManager,
WorkspaceSubscriptionCheckoutArgs,
WorkspaceSubscriptionIdentity,
WorkspaceSubscriptionManager,
} from './manager';
import { ScheduleManager } from './schedule';
import {
encodeLookupKey,
KnownStripeInvoice,
KnownStripePrice,
KnownStripeSubscription,
LookupKey,
retriveLookupKeyFromStripePrice,
retriveLookupKeyFromStripeSubscription,
SubscriptionPlan,
SubscriptionRecurring,
SubscriptionStatus,
SubscriptionVariant,
} from './types';
export const CheckoutExtraArgs = z.union([
UserSubscriptionCheckoutArgs,
WorkspaceSubscriptionCheckoutArgs,
]);
export const SubscriptionIdentity = z.union([
UserSubscriptionIdentity,
WorkspaceSubscriptionIdentity,
]);
export { CheckoutParams };
@Injectable()
export class SubscriptionService {
private readonly logger = new Logger(SubscriptionService.name);
@@ -52,143 +74,86 @@ export class SubscriptionService {
private readonly db: PrismaClient,
private readonly feature: FeatureManagementService,
private readonly user: UserService,
private readonly userManager: UserSubscriptionManager
private readonly userManager: UserSubscriptionManager,
private readonly workspaceManager: WorkspaceSubscriptionManager
) {}
async listPrices(user?: CurrentUser): Promise<KnownStripePrice[]> {
const customer = user ? await this.getOrCreateCustomer(user) : undefined;
// TODO(@forehalo): cache
const prices = await this.stripe.prices.list({
active: true,
limit: 100,
});
return this.userManager.filterPrices(
prices.data
.map(price => this.parseStripePrice(price))
.filter(Boolean) as KnownStripePrice[],
customer
);
private select(plan: SubscriptionPlan): SubscriptionManager {
switch (plan) {
case SubscriptionPlan.Team:
return this.workspaceManager;
case SubscriptionPlan.Pro:
case SubscriptionPlan.AI:
return this.userManager;
default:
throw new UnsupportedSubscriptionPlan({ plan });
}
}
async checkout({
user,
lookupKey,
promotionCode,
redirectUrl,
idempotencyKey,
}: {
user: CurrentUser;
lookupKey: LookupKey;
promotionCode?: string | null;
redirectUrl: string;
idempotencyKey?: string;
}) {
async listPrices(user?: CurrentUser): Promise<KnownStripePrice[]> {
const prices = await this.listStripePrices();
const customer = user
? await this.getOrCreateCustomer({
userId: user.id,
userEmail: user.email,
})
: undefined;
return [
...(await this.userManager.filterPrices(prices, customer)),
...this.workspaceManager.filterPrices(prices, customer),
];
}
async checkout(
params: z.infer<typeof CheckoutParams>,
args: z.infer<typeof CheckoutExtraArgs>
) {
const { plan, recurring, variant } = params;
if (
this.config.deploy &&
this.config.affine.canary &&
!this.feature.isStaff(user.email)
!this.feature.isStaff(args.user.email)
) {
throw new ActionForbidden();
}
const currentSubscription = await this.userManager.getSubscription(
user.id,
lookupKey.plan
);
const price = await this.getPrice({
plan,
recurring,
variant: variant ?? null,
});
if (
currentSubscription &&
// do not allow to re-subscribe unless
!(
/* current subscription is a onetime subscription and so as the one that's checking out */
(
(currentSubscription.variant === SubscriptionVariant.Onetime &&
lookupKey.variant === SubscriptionVariant.Onetime) ||
/* current subscription is normal subscription and is checking-out a lifetime subscription */
(currentSubscription.recurring !== SubscriptionRecurring.Lifetime &&
currentSubscription.variant !== SubscriptionVariant.Onetime &&
lookupKey.recurring === SubscriptionRecurring.Lifetime)
)
)
) {
throw new SubscriptionAlreadyExists({ plan: lookupKey.plan });
}
const price = await this.getPrice(lookupKey);
const customer = await this.getOrCreateCustomer(user);
const priceAndAutoCoupon = price
? await this.userManager.validatePrice(price, customer)
: null;
if (!priceAndAutoCoupon) {
if (!price) {
throw new SubscriptionPlanNotFound({
plan: lookupKey.plan,
recurring: lookupKey.recurring,
plan,
recurring,
});
}
let discounts: Stripe.Checkout.SessionCreateParams['discounts'] = [];
const manager = this.select(plan);
const result = CheckoutExtraArgs.safeParse(args);
if (priceAndAutoCoupon.coupon) {
discounts = [{ coupon: priceAndAutoCoupon.coupon }];
} else if (promotionCode) {
const coupon = await this.getCouponFromPromotionCode(
promotionCode,
customer
);
if (coupon) {
discounts = [{ coupon }];
}
if (!result.success) {
throw new InvalidCheckoutParameters();
}
return await this.stripe.checkout.sessions.create(
{
line_items: [
{
price: priceAndAutoCoupon.price.price.id,
quantity: 1,
},
],
tax_id_collection: {
enabled: true,
},
// discount
...(discounts.length ? { discounts } : { allow_promotion_codes: true }),
// mode: 'subscription' or 'payment' for lifetime and onetime payment
...(lookupKey.recurring === SubscriptionRecurring.Lifetime ||
lookupKey.variant === SubscriptionVariant.Onetime
? {
mode: 'payment',
invoice_creation: {
enabled: true,
},
}
: {
mode: 'subscription',
}),
success_url: redirectUrl,
customer: customer.stripeCustomerId,
customer_update: {
address: 'auto',
name: 'auto',
},
},
{ idempotencyKey }
);
return manager.checkout(price, params, args);
}
async cancelSubscription(
userId: string,
plan: SubscriptionPlan,
identity: z.infer<typeof SubscriptionIdentity>,
idempotencyKey?: string
): Promise<UserSubscription> {
const subscription = await this.userManager.getSubscription(userId, plan);
): Promise<Subscription> {
this.assertSubscriptionIdentity(identity);
const manager = this.select(identity.plan);
const subscription = await manager.getSubscription(identity);
if (!subscription) {
throw new SubscriptionNotExists({ plan });
throw new SubscriptionNotExists({ plan: identity.plan });
}
if (!subscription.stripeSubscriptionId) {
@@ -202,7 +167,7 @@ export class SubscriptionService {
}
// update the subscription in db optimistically
const newSubscription = this.userManager.cancelSubscription(subscription);
const newSubscription = manager.cancelSubscription(subscription);
// should release the schedule first
if (subscription.stripeScheduleId) {
@@ -224,18 +189,21 @@ export class SubscriptionService {
}
async resumeSubscription(
userId: string,
plan: SubscriptionPlan,
identity: z.infer<typeof SubscriptionIdentity>,
idempotencyKey?: string
): Promise<UserSubscription> {
const subscription = await this.userManager.getSubscription(userId, plan);
): Promise<Subscription> {
this.assertSubscriptionIdentity(identity);
const manager = this.select(identity.plan);
const subscription = await manager.getSubscription(identity);
if (!subscription) {
throw new SubscriptionNotExists({ plan });
throw new SubscriptionNotExists({ plan: identity.plan });
}
if (!subscription.canceledAt) {
throw new SubscriptionHasBeenCanceled();
throw new SubscriptionHasNotBeenCanceled();
}
if (!subscription.stripeSubscriptionId || !subscription.end) {
@@ -249,8 +217,7 @@ export class SubscriptionService {
}
// update the subscription in db optimistically
const newSubscription =
await this.userManager.resumeSubscription(subscription);
const newSubscription = await manager.resumeSubscription(subscription);
if (subscription.stripeScheduleId) {
const manager = await this.scheduleManager.fromSchedule(
@@ -269,15 +236,17 @@ export class SubscriptionService {
}
async updateSubscriptionRecurring(
userId: string,
plan: SubscriptionPlan,
identity: z.infer<typeof SubscriptionIdentity>,
recurring: SubscriptionRecurring,
idempotencyKey?: string
): Promise<UserSubscription> {
const subscription = await this.userManager.getSubscription(userId, plan);
): Promise<Subscription> {
this.assertSubscriptionIdentity(identity);
const manager = this.select(identity.plan);
const subscription = await manager.getSubscription(identity);
if (!subscription) {
throw new SubscriptionNotExists({ plan });
throw new SubscriptionNotExists({ plan: identity.plan });
}
if (!subscription.stripeSubscriptionId) {
@@ -293,25 +262,29 @@ export class SubscriptionService {
}
const price = await this.getPrice({
plan,
plan: identity.plan,
recurring,
variant: null,
});
if (!price) {
throw new SubscriptionPlanNotFound({ plan, recurring });
throw new SubscriptionPlanNotFound({
plan: identity.plan,
recurring,
});
}
// update the subscription in db optimistically
const newSubscription = this.userManager.updateSubscriptionRecurring(
const newSubscription = manager.updateSubscriptionRecurring(
subscription,
recurring
);
const manager = await this.scheduleManager.fromSubscription(
const scheduleManager = await this.scheduleManager.fromSubscription(
subscription.stripeSubscriptionId
);
await manager.update(price.price.id, idempotencyKey);
await scheduleManager.update(price.price.id, idempotencyKey);
return newSubscription;
}
@@ -339,14 +312,14 @@ export class SubscriptionService {
}
}
async saveStripeInvoice(stripeInvoice: Stripe.Invoice): Promise<UserInvoice> {
async saveStripeInvoice(stripeInvoice: Stripe.Invoice): Promise<Invoice> {
const knownInvoice = await this.parseStripeInvoice(stripeInvoice);
if (!knownInvoice) {
throw new InternalServerError('Failed to parse stripe invoice.');
}
return this.userManager.saveInvoice(knownInvoice);
return this.select(knownInvoice.lookupKey.plan).saveInvoice(knownInvoice);
}
async saveStripeSubscription(subscription: Stripe.Subscription) {
@@ -360,10 +333,12 @@ export class SubscriptionService {
subscription.status === SubscriptionStatus.Active ||
subscription.status === SubscriptionStatus.Trialing;
const manager = this.select(knownSubscription.lookupKey.plan);
if (!isPlanActive) {
await this.userManager.deleteSubscription(knownSubscription);
await manager.deleteStripeSubscription(knownSubscription);
} else {
await this.userManager.saveSubscription(knownSubscription);
await manager.saveStripeSubscription(knownSubscription);
}
}
@@ -374,19 +349,26 @@ export class SubscriptionService {
throw new InternalServerError('Failed to parse stripe subscription.');
}
await this.userManager.deleteSubscription(knownSubscription);
const manager = this.select(knownSubscription.lookupKey.plan);
await manager.deleteStripeSubscription(knownSubscription);
}
async getOrCreateCustomer(user: CurrentUser): Promise<UserStripeCustomer> {
async getOrCreateCustomer({
userId,
userEmail,
}: {
userId: string;
userEmail: string;
}): Promise<UserStripeCustomer> {
let customer = await this.db.userStripeCustomer.findUnique({
where: {
userId: user.id,
userId,
},
});
if (!customer) {
const stripeCustomersList = await this.stripe.customers.list({
email: user.email,
email: userEmail,
limit: 1,
});
@@ -395,13 +377,13 @@ export class SubscriptionService {
stripeCustomer = stripeCustomersList.data[0];
} else {
stripeCustomer = await this.stripe.customers.create({
email: user.email,
email: userEmail,
});
}
customer = await this.db.userStripeCustomer.create({
data: {
userId: user.id,
userId,
stripeCustomerId: stripeCustomer.id,
},
});
@@ -467,6 +449,17 @@ export class SubscriptionService {
return user.id;
}
private async listStripePrices(): Promise<KnownStripePrice[]> {
const prices = await this.stripe.prices.list({
active: true,
limit: 100,
});
return prices.data
.map(price => this.parseStripePrice(price))
.filter(Boolean) as KnownStripePrice[];
}
private async getPrice(
lookupKey: LookupKey
): Promise<KnownStripePrice | null> {
@@ -485,35 +478,6 @@ export class SubscriptionService {
: null;
}
private async getCouponFromPromotionCode(
userFacingPromotionCode: string,
customer: UserStripeCustomer
) {
const list = await this.stripe.promotionCodes.list({
code: userFacingPromotionCode,
active: true,
limit: 1,
});
const code = list.data[0];
if (!code) {
return null;
}
// the coupons are always bound to products, we need to check it first
// but the logic would be too complicated, and stripe will complain if the code is not applicable when checking out
// It's safe to skip the check here
// code.coupon.applies_to.products.forEach()
// check if the code is bound to a specific customer
return !code.customer ||
(typeof code.customer === 'string'
? code.customer === customer.stripeCustomerId
: code.customer.id === customer.stripeCustomerId)
? code.coupon.id
: null;
}
private async parseStripeInvoice(
invoice: Stripe.Invoice
): Promise<KnownStripeInvoice | null> {
@@ -549,10 +513,13 @@ export class SubscriptionService {
userId: user.id,
stripeInvoice: invoice,
lookupKey,
metadata: invoice.subscription_details?.metadata ?? {},
};
}
private async parseStripeSubscription(subscription: Stripe.Subscription) {
private async parseStripeSubscription(
subscription: Stripe.Subscription
): Promise<KnownStripeSubscription | null> {
const lookupKey = retriveLookupKeyFromStripeSubscription(subscription);
if (!lookupKey) {
@@ -569,6 +536,8 @@ export class SubscriptionService {
userId,
lookupKey,
stripeSubscription: subscription,
quantity: subscription.items.data[0]?.quantity ?? 1,
metadata: subscription.metadata,
};
}
@@ -582,4 +551,14 @@ export class SubscriptionService {
}
: null;
}
private assertSubscriptionIdentity(
args: z.infer<typeof SubscriptionIdentity>
) {
const result = SubscriptionIdentity.safeParse(args);
if (!result.success) {
throw new InvalidSubscriptionParameters();
}
}
}

View File

@@ -1,4 +1,4 @@
import type { User } from '@prisma/client';
import type { User, Workspace } from '@prisma/client';
import Stripe from 'stripe';
import type { Payload } from '../../fundamentals/event/def';
@@ -64,12 +64,31 @@ declare module '../../fundamentals/event/def' {
}>;
};
}
interface WorkspaceEvents {
subscription: {
activated: Payload<{
workspaceId: Workspace['id'];
plan: SubscriptionPlan;
recurring: SubscriptionRecurring;
quantity: number;
}>;
canceled: Payload<{
workspaceId: Workspace['id'];
plan: SubscriptionPlan;
recurring: SubscriptionRecurring;
}>;
};
members: {
updated: Payload<{ workspaceId: Workspace['id']; count: number }>;
};
}
}
export interface LookupKey {
plan: SubscriptionPlan;
recurring: SubscriptionRecurring;
variant?: SubscriptionVariant;
variant: SubscriptionVariant | null;
}
export interface KnownStripeInvoice {
@@ -87,6 +106,11 @@ export interface KnownStripeInvoice {
* The invoice object from Stripe.
*/
stripeInvoice: Stripe.Invoice;
/**
* The metadata of the subscription related to the invoice.
*/
metadata: Record<string, string>;
}
export interface KnownStripeSubscription {
@@ -104,6 +128,16 @@ export interface KnownStripeSubscription {
* The subscription object from Stripe.
*/
stripeSubscription: Stripe.Subscription;
/**
* The quantity of the subscription items.
*/
quantity: number;
/**
* The metadata of the subscription.
*/
metadata: Record<string, string>;
}
export interface KnownStripePrice {
@@ -167,7 +201,7 @@ export function decodeLookupKey(key: string): LookupKey | null {
return {
plan: plan as SubscriptionPlan,
recurring: recurring as SubscriptionRecurring,
variant: variant as SubscriptionVariant | undefined,
variant: variant as SubscriptionVariant,
};
}

View File

@@ -140,6 +140,7 @@ input CreateChatSessionInput {
}
input CreateCheckoutSessionInput {
args: JSONObject
coupon: String
idempotencyKey: String
plan: SubscriptionPlan = Pro
@@ -208,7 +209,7 @@ type EditorType {
name: String!
}
union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInSpaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType
union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInSpaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedSubscriptionPlanDataType | VersionRejectedDataType
enum ErrorNames {
ACCESS_DENIED
@@ -246,12 +247,14 @@ enum ErrorNames {
FAILED_TO_SAVE_UPDATES
FAILED_TO_UPSERT_SNAPSHOT
INTERNAL_SERVER_ERROR
INVALID_CHECKOUT_PARAMETERS
INVALID_EMAIL
INVALID_EMAIL_TOKEN
INVALID_HISTORY_TIMESTAMP
INVALID_OAUTH_CALLBACK_STATE
INVALID_PASSWORD_LENGTH
INVALID_RUNTIME_CONFIG_TYPE
INVALID_SUBSCRIPTION_PARAMETERS
LINK_EXPIRED
MAILER_SERVICE_IS_NOT_CONFIGURED
MEMBER_QUOTA_EXCEEDED
@@ -273,14 +276,18 @@ enum ErrorNames {
SUBSCRIPTION_ALREADY_EXISTS
SUBSCRIPTION_EXPIRED
SUBSCRIPTION_HAS_BEEN_CANCELED
SUBSCRIPTION_HAS_NOT_BEEN_CANCELED
SUBSCRIPTION_NOT_EXISTS
SUBSCRIPTION_PLAN_NOT_FOUND
TOO_MANY_REQUEST
UNKNOWN_OAUTH_PROVIDER
UNSPLASH_IS_NOT_CONFIGURED
UNSUPPORTED_SUBSCRIPTION_PLAN
USER_AVATAR_NOT_FOUND
USER_NOT_FOUND
VERSION_REJECTED
WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION
WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION
WRONG_SIGN_IN_CREDENTIALS
WRONG_SIGN_IN_METHOD
}
@@ -444,7 +451,7 @@ type MissingOauthQueryParameterDataType {
type Mutation {
acceptInviteById(inviteId: String!, sendAcceptMail: Boolean, workspaceId: String!): Boolean!
addWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Int!
cancelSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro): SubscriptionType!
cancelSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, workspaceId: String): SubscriptionType!
changeEmail(email: String!, token: String!): UserType!
changePassword(newPassword: String!, token: String!, userId: String): Boolean!
@@ -491,7 +498,7 @@ type Mutation {
"""Remove user avatar"""
removeAvatar: RemoveAvatar!
removeWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Int!
resumeSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro): SubscriptionType!
resumeSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, workspaceId: String): SubscriptionType!
revoke(userId: String!, workspaceId: String!): Boolean!
revokePage(pageId: String!, workspaceId: String!): Boolean! @deprecated(reason: "use revokePublicPage")
revokePublicPage(pageId: String!, workspaceId: String!): WorkspacePage!
@@ -513,7 +520,7 @@ type Mutation {
"""update multiple server runtime configurable settings"""
updateRuntimeConfigs(updates: JSONObject!): [ServerRuntimeConfigType!]!
updateSubscriptionRecurring(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, recurring: SubscriptionRecurring!): SubscriptionType!
updateSubscriptionRecurring(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, recurring: SubscriptionRecurring!, workspaceId: String): SubscriptionType!
"""Update a user"""
updateUser(id: String!, input: ManageUserInput!): UserType!
@@ -814,6 +821,10 @@ type UnknownOauthProviderDataType {
name: String!
}
type UnsupportedSubscriptionPlanDataType {
plan: String!
}
input UpdateUserInput {
"""User name"""
name: String
@@ -929,6 +940,10 @@ type WorkspaceType {
"""is current workspace initialized"""
initialized: Boolean!
"""Get user invoice count"""
invoiceCount: Int!
invoices(skip: Int, take: Int = 8): [InvoiceType!]!
"""member count of workspace"""
memberCount: Int!
@@ -958,6 +973,9 @@ type WorkspaceType {
"""Shared pages of workspace"""
sharedPages: [String!]! @deprecated(reason: "use WorkspaceType.publicPages")
"""The team subscription of the workspace, if exists."""
subscription: SubscriptionType
}
type tokenType {

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,10 @@ Generated by [AVA](https://avajs.dev).
[
'pro_monthly',
'pro_yearly',
'pro_lifetime',
'ai_yearly',
'team_monthly',
'team_yearly',
]
## should list normal prices for authenticated user
@@ -21,7 +24,22 @@ Generated by [AVA](https://avajs.dev).
[
'pro_monthly',
'pro_yearly',
'pro_lifetime',
'ai_yearly',
'team_monthly',
'team_yearly',
]
## should not show lifetime price if not enabled
> Snapshot 1
[
'pro_monthly',
'pro_yearly',
'ai_yearly',
'team_monthly',
'team_yearly',
]
## should list early access prices for pro ea user
@@ -30,8 +48,11 @@ Generated by [AVA](https://avajs.dev).
[
'pro_monthly',
'pro_lifetime',
'pro_yearly_earlyaccess',
'ai_yearly',
'team_monthly',
'team_yearly',
]
## should list normal prices for pro ea user with old subscriptions
@@ -41,7 +62,10 @@ Generated by [AVA](https://avajs.dev).
[
'pro_monthly',
'pro_yearly',
'pro_lifetime',
'ai_yearly',
'team_monthly',
'team_yearly',
]
## should list early access prices for ai ea user
@@ -51,7 +75,10 @@ Generated by [AVA](https://avajs.dev).
[
'pro_monthly',
'pro_yearly',
'pro_lifetime',
'ai_yearly_earlyaccess',
'team_monthly',
'team_yearly',
]
## should list early access prices for pro and ai ea user
@@ -60,8 +87,11 @@ Generated by [AVA](https://avajs.dev).
[
'pro_monthly',
'pro_lifetime',
'pro_yearly_earlyaccess',
'ai_yearly_earlyaccess',
'team_monthly',
'team_yearly',
]
## should list normal prices for ai ea user with old subscriptions
@@ -71,5 +101,21 @@ Generated by [AVA](https://avajs.dev).
[
'pro_monthly',
'pro_yearly',
'pro_lifetime',
'ai_yearly',
'team_monthly',
'team_yearly',
]
## should be able to list prices for team
> Snapshot 1
[
'pro_monthly',
'pro_yearly',
'pro_lifetime',
'ai_yearly',
'team_monthly',
'team_yearly',
]