feat(server): introduce user friendly server errors (#7111)

This commit is contained in:
liuyi
2024-06-17 11:30:58 +08:00
committed by GitHub
parent 5307a55f8a
commit 54fc1197ad
65 changed files with 3170 additions and 924 deletions

View File

@@ -1,11 +1,7 @@
import {
BadRequestException,
Controller,
Get,
HttpException,
InternalServerErrorException,
Logger,
NotFoundException,
Param,
Query,
Req,
@@ -23,14 +19,21 @@ import {
merge,
mergeMap,
Observable,
of,
switchMap,
toArray,
} from 'rxjs';
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { Config } from '../../fundamentals';
import {
BlobNotFound,
Config,
CopilotFailedToGenerateText,
CopilotSessionNotFound,
mapSseError,
NoCopilotProviderAvailable,
UnsplashIsNotConfigured,
} from '../../fundamentals';
import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
@@ -40,7 +43,7 @@ import { CopilotWorkflowService } from './workflow';
export interface ChatEvent {
type: 'attachment' | 'message' | 'error';
id?: string;
data: string;
data: string | object;
}
type CheckResult = {
@@ -68,7 +71,7 @@ export class CopilotController {
await this.chatSession.checkQuota(userId);
const session = await this.chatSession.get(sessionId);
if (!session || session.config.userId !== userId) {
throw new BadRequestException('Session not found');
throw new CopilotSessionNotFound();
}
const ret: CheckResult = { model: session.model };
@@ -104,7 +107,7 @@ export class CopilotController {
);
}
if (!provider) {
throw new InternalServerErrorException('No provider available');
throw new NoCopilotProviderAvailable();
}
return provider;
@@ -116,7 +119,7 @@ export class CopilotController {
): Promise<ChatSession> {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
throw new CopilotSessionNotFound();
}
if (messageId) {
@@ -148,20 +151,6 @@ export class CopilotController {
return num;
}
private handleError(err: any) {
if (err instanceof Error) {
const ret = {
message: err.message,
status: (err as any).status,
};
if (err instanceof HttpException) {
ret.status = err.getStatus();
}
return ret;
}
return err;
}
@Get('/chat/:sessionId')
async chat(
@CurrentUser() user: CurrentUser,
@@ -200,9 +189,7 @@ export class CopilotController {
return content;
} catch (e: any) {
throw new InternalServerErrorException(
e.message || "Couldn't generate text"
);
throw new CopilotFailedToGenerateText(e.message);
}
}
@@ -253,18 +240,10 @@ export class CopilotController {
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
catchError(mapSseError)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
return mapSseError(err);
}
}
@@ -318,18 +297,10 @@ export class CopilotController {
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
catchError(mapSseError)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
return mapSseError(err);
}
}
@@ -356,7 +327,7 @@ export class CopilotController {
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
throw new NoCopilotProviderAvailable();
}
const session = await this.appendSessionMessage(sessionId, messageId);
@@ -402,18 +373,10 @@ export class CopilotController {
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
catchError(mapSseError)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
return mapSseError(err);
}
}
@@ -425,7 +388,7 @@ export class CopilotController {
) {
const { unsplashKey } = this.config.plugins.copilot || {};
if (!unsplashKey) {
throw new InternalServerErrorException('Unsplash key is not configured');
throw new UnsplashIsNotConfigured();
}
const query = new URLSearchParams(params);
@@ -458,9 +421,10 @@ export class CopilotController {
const { body, metadata } = await this.storage.get(userId, workspaceId, key);
if (!body) {
throw new NotFoundException(
`Blob not found in ${userId}'s workspace ${workspaceId}: ${key}`
);
throw new BlobNotFound({
workspaceId,
blobId: key,
});
}
// metadata should always exists if body is not null

View File

@@ -1,6 +1,6 @@
import { createHash } from 'node:crypto';
import { BadRequestException, Logger, NotFoundException } from '@nestjs/common';
import { BadRequestException, NotFoundException } from '@nestjs/common';
import {
Args,
Field,
@@ -23,6 +23,7 @@ import { Admin } from '../../core/common';
import { UserType } from '../../core/user';
import { PermissionService } from '../../core/workspaces/permission';
import {
CopilotFailedToCreateMessage,
FileUpload,
MutexService,
Throttle,
@@ -201,8 +202,6 @@ export class CopilotType {
@Throttle()
@Resolver(() => CopilotType)
export class CopilotResolver {
private readonly logger = new Logger(CopilotResolver.name);
constructor(
private readonly permissions: PermissionService,
private readonly mutex: MutexService,
@@ -385,8 +384,7 @@ export class CopilotResolver {
try {
return await this.chatSession.createMessage(options);
} catch (e: any) {
this.logger.error(`Failed to create chat message: ${e.message}`);
throw new Error('Failed to create chat message');
throw new CopilotFailedToCreateMessage(e.message);
}
}
}

View File

@@ -5,7 +5,14 @@ import { AiPromptRole, PrismaClient } from '@prisma/client';
import { FeatureManagementService } from '../../core/features';
import { QuotaService } from '../../core/quota';
import { PaymentRequiredException } from '../../fundamentals';
import {
CopilotActionTaken,
CopilotMessageNotFound,
CopilotPromptNotFound,
CopilotQuotaExceeded,
CopilotSessionDeleted,
CopilotSessionNotFound,
} from '../../fundamentals';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
import {
@@ -58,7 +65,7 @@ export class ChatSession implements AsyncDisposable {
this.state.messages.length > 0 &&
message.role === 'user'
) {
throw new Error('Action has been taken, no more messages allowed');
throw new CopilotActionTaken();
}
this.state.messages.push(message);
this.stashMessageCount += 1;
@@ -74,7 +81,7 @@ export class ChatSession implements AsyncDisposable {
async getMessageById(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new Error(`Message not found: ${messageId}`);
throw new CopilotMessageNotFound();
}
return message;
}
@@ -82,7 +89,7 @@ export class ChatSession implements AsyncDisposable {
async pushByMessageId(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new Error(`Message not found: ${messageId}`);
throw new CopilotMessageNotFound();
}
this.push({
@@ -196,7 +203,7 @@ export class ChatSessionService {
},
select: { id: true, deletedAt: true },
})) || {};
if (deletedAt) throw new Error(`Session is deleted: ${id}`);
if (deletedAt) throw new CopilotSessionDeleted();
if (id) sessionId = id;
}
@@ -274,7 +281,8 @@ export class ChatSessionService {
.then(async session => {
if (!session) return;
const prompt = await this.prompt.get(session.promptName);
if (!prompt) throw new Error(`Prompt not found: ${session.promptName}`);
if (!prompt)
throw new CopilotPromptNotFound({ name: session.promptName });
const messages = ChatMessageSchema.array().safeParse(session.messages);
@@ -300,7 +308,7 @@ export class ChatSessionService {
})
.then(session => session?.id);
if (!id) {
throw new Error(`Session not found: ${sessionId}`);
throw new CopilotSessionNotFound();
}
const ids = await tx.aiSessionMessage
.findMany({
@@ -412,7 +420,7 @@ export class ChatSessionService {
if (ret.success) {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new Error(`Prompt not found: ${promptName}`);
throw new CopilotPromptNotFound({ name: promptName });
}
// render system prompt
@@ -471,9 +479,7 @@ export class ChatSessionService {
async checkQuota(userId: string) {
const { limit, used } = await this.getQuota(userId);
if (limit && Number.isFinite(limit) && used >= limit) {
throw new PaymentRequiredException(
`You have reached the limit of actions in this workspace, please upgrade your plan.`
);
throw new CopilotQuotaExceeded();
}
}
@@ -482,7 +488,7 @@ export class ChatSessionService {
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
this.logger.error(`Prompt not found: ${options.promptName}`);
throw new Error('Prompt not found');
throw new CopilotPromptNotFound({ name: options.promptName });
}
return await this.setSession({
...options,

View File

@@ -1,10 +1,11 @@
import { createHash } from 'node:crypto';
import { Injectable, PayloadTooLargeException } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { QuotaManagementService } from '../../core/quota';
import {
type BlobInputType,
BlobQuotaExceeded,
Config,
type FileUpload,
type StorageProvider,
@@ -54,9 +55,7 @@ export class CopilotStorage {
const checkExceeded = await this.quota.getQuotaCalculator(userId);
if (checkExceeded(0)) {
throw new PayloadTooLargeException(
'Storage or blob size limit exceeded.'
);
throw new BlobQuotaExceeded();
}
const buffer = await new Promise<Buffer>((resolve, reject) => {
const stream = blob.createReadStream();
@@ -67,9 +66,7 @@ export class CopilotStorage {
// check size after receive each chunk to avoid unnecessary memory usage
const bufferSize = chunks.reduce((acc, cur) => acc + cur.length, 0);
if (checkExceeded(bufferSize)) {
reject(
new PayloadTooLargeException('Storage or blob size limit exceeded.')
);
reject(new BlobQuotaExceeded());
}
});
stream.on('error', reject);
@@ -77,7 +74,7 @@ export class CopilotStorage {
const buffer = Buffer.concat(chunks);
if (checkExceeded(buffer.length)) {
reject(new PayloadTooLargeException('Storage limit exceeded.'));
reject(new BlobQuotaExceeded());
} else {
resolve(buffer);
}

View File

@@ -1,17 +1,18 @@
import {
BadRequestException,
Controller,
Get,
Query,
Req,
Res,
} from '@nestjs/common';
import { Controller, Get, Query, Req, Res } from '@nestjs/common';
import { ConnectedAccount, PrismaClient } from '@prisma/client';
import type { Request, Response } from 'express';
import { AuthService, Public } from '../../core/auth';
import { UserService } from '../../core/user';
import { URLHelper } from '../../fundamentals';
import {
InvalidOauthCallbackState,
MissingOauthQueryParameter,
OauthAccountAlreadyConnected,
OauthStateExpired,
UnknownOauthProvider,
URLHelper,
WrongSignInMethod,
} from '../../fundamentals';
import { OAuthProviderName } from './config';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthProviderFactory } from './register';
@@ -35,12 +36,15 @@ export class OAuthController {
@Query('provider') unknownProviderName: string,
@Query('redirect_uri') redirectUri?: string
) {
if (!unknownProviderName) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}
// @ts-expect-error safe
const providerName = OAuthProviderName[unknownProviderName];
const provider = this.providerFactory.get(providerName);
if (!provider) {
throw new BadRequestException('Invalid OAuth provider');
throw new UnknownOauthProvider({ name: unknownProviderName });
}
const state = await this.oauth.saveOAuthState({
@@ -60,29 +64,31 @@ export class OAuthController {
@Query('state') stateStr?: string
) {
if (!code) {
throw new BadRequestException('Missing query parameter `code`');
throw new MissingOauthQueryParameter({ name: 'code' });
}
if (!stateStr) {
throw new BadRequestException('Invalid callback state parameter');
throw new MissingOauthQueryParameter({ name: 'state' });
}
if (typeof stateStr !== 'string' || !this.oauth.isValidState(stateStr)) {
throw new InvalidOauthCallbackState();
}
const state = await this.oauth.getOAuthState(stateStr);
if (!state) {
throw new BadRequestException('OAuth state expired, please try again.');
throw new OauthStateExpired();
}
if (!state.provider) {
throw new BadRequestException(
'Missing callback state parameter `provider`'
);
throw new MissingOauthQueryParameter({ name: 'provider' });
}
const provider = this.providerFactory.get(state.provider);
if (!provider) {
throw new BadRequestException('Invalid provider');
throw new UnknownOauthProvider({ name: state.provider ?? 'unknown' });
}
const tokens = await provider.getToken(code);
@@ -154,15 +160,9 @@ export class OAuthController {
// we can't directly connect the external account with given email in sign in scenario for safety concern.
// let user manually connect in account sessions instead.
if (user.registered) {
throw new BadRequestException(
'The account with provided email is not register in the same way.'
);
throw new WrongSignInMethod();
}
await this.user.fulfillUser(externalAccount.email, {
emailVerifiedAt: new Date(),
registered: true,
});
await this.db.connectedAccount.create({
data: {
userId: user.id,
@@ -228,9 +228,7 @@ export class OAuthController {
if (connectedUser) {
if (connectedUser.id !== user.id) {
throw new BadRequestException(
'The third-party account has already been connected to another user.'
);
throw new OauthAccountAlreadyConnected();
}
} else {
await this.db.connectedAccount.create({

View File

@@ -1,4 +1,4 @@
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { Config, URLHelper } from '../../../fundamentals';
import { OAuthProviderName } from '../config';
@@ -39,74 +39,60 @@ export class GithubOAuthProvider extends AutoRegisteredOAuthProvider {
}
async getToken(code: string) {
try {
const response = await fetch(
'https://github.com/login/oauth/access_token',
{
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
}
);
if (response.ok) {
const ghToken = (await response.json()) as AuthTokenResponse;
return {
accessToken: ghToken.access_token,
scope: ghToken.scope,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
}, ${JSON.stringify(await response.json())}`
);
const response = await fetch(
'https://github.com/login/oauth/access_token',
{
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
}
} catch (e) {
throw new HttpException(
`Failed to get access_token, err: ${(e as Error).message}`,
HttpStatus.BAD_REQUEST
);
if (response.ok) {
const ghToken = (await response.json()) as AuthTokenResponse;
return {
accessToken: ghToken.access_token,
scope: ghToken.scope,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
}, ${JSON.stringify(await response.json())}`
);
}
}
async getUser(token: string) {
try {
const response = await fetch('https://api.github.com/user', {
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
},
});
const response = await fetch('https://api.github.com/user', {
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
},
});
if (response.ok) {
const user = (await response.json()) as UserInfo;
if (response.ok) {
const user = (await response.json()) as UserInfo;
return {
id: user.login,
avatarUrl: user.avatar_url,
email: user.email,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
} ${await response.text()}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get user information, err: ${(e as Error).stack}`,
HttpStatus.BAD_REQUEST
return {
id: user.login,
avatarUrl: user.avatar_url,
email: user.email,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
} ${await response.text()}`
);
}
}

View File

@@ -1,4 +1,4 @@
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { Config, URLHelper } from '../../../fundamentals';
import { OAuthProviderName } from '../config';
@@ -44,77 +44,63 @@ export class GoogleOAuthProvider extends AutoRegisteredOAuthProvider {
}
async getToken(code: string) {
try {
const response = await fetch('https://oauth2.googleapis.com/token', {
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
grant_type: 'authorization_code',
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
});
const response = await fetch('https://oauth2.googleapis.com/token', {
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
grant_type: 'authorization_code',
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
});
if (response.ok) {
const ghToken = (await response.json()) as GoogleOAuthTokenResponse;
if (response.ok) {
const ghToken = (await response.json()) as GoogleOAuthTokenResponse;
return {
accessToken: ghToken.access_token,
refreshToken: ghToken.refresh_token,
expiresAt: new Date(Date.now() + ghToken.expires_in * 1000),
scope: ghToken.scope,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
}, ${JSON.stringify(await response.json())}`
);
}
} catch (e) {
throw new HttpException(
`Failed to get access_token, err: ${(e as Error).message}`,
HttpStatus.BAD_REQUEST
return {
accessToken: ghToken.access_token,
refreshToken: ghToken.refresh_token,
expiresAt: new Date(Date.now() + ghToken.expires_in * 1000),
scope: ghToken.scope,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
}, ${JSON.stringify(await response.json())}`
);
}
}
async getUser(token: string) {
try {
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
},
}
);
if (response.ok) {
const user = (await response.json()) as UserInfo;
return {
id: user.id,
avatarUrl: user.picture,
email: user.email,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
} ${await response.text()}`
);
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
},
}
} catch (e) {
throw new HttpException(
`Failed to get user information, err: ${(e as Error).stack}`,
HttpStatus.BAD_REQUEST
);
if (response.ok) {
const user = (await response.json()) as UserInfo;
return {
id: user.id,
avatarUrl: user.picture,
email: user.email,
};
} else {
throw new Error(
`Server responded with non-success code ${
response.status
} ${await response.text()}`
);
}
}

View File

@@ -1,9 +1,4 @@
import {
BadRequestException,
Injectable,
InternalServerErrorException,
OnModuleInit,
} from '@nestjs/common';
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { z } from 'zod';
import { Config, URLHelper } from '../../../fundamentals';
@@ -44,6 +39,8 @@ const OIDCConfigurationSchema = z.object({
type OIDCConfiguration = z.infer<typeof OIDCConfigurationSchema>;
const logger = new Logger('OIDCClient');
class OIDCClient {
private static async fetch<T = any>(
url: string,
@@ -53,17 +50,8 @@ class OIDCClient {
const response = await fetch(url, options);
if (!response.ok) {
if (response.status >= 400 && response.status < 500) {
throw new BadRequestException(`Invalid OIDC configuration`, {
cause: await response.json(),
description: response.statusText,
});
} else {
throw new InternalServerErrorException(`Failed to configure client`, {
cause: await response.json(),
description: response.statusText,
});
}
logger.error('Failed to fetch OIDC configuration', await response.json());
throw new Error(`Failed to configure client`);
}
const data = await response.json();
return verifier.parse(data);

View File

@@ -20,6 +20,10 @@ export class OAuthService {
private readonly cache: SessionCache
) {}
isValidState(stateStr: string) {
return stateStr.length === 36;
}
async saveOAuthState(state: OAuthState) {
const token = randomUUID();
await this.cache.set(`${OAUTH_STATE_KEY}:${token}`, state, {

View File

@@ -1,4 +1,3 @@
import { BadGatewayException, ForbiddenException } from '@nestjs/common';
import {
Args,
Context,
@@ -19,7 +18,12 @@ import { groupBy } from 'lodash-es';
import { CurrentUser, Public } from '../../core/auth';
import { UserType } from '../../core/user';
import { Config, URLHelper } from '../../fundamentals';
import {
AccessDenied,
Config,
FailedToCheckout,
URLHelper,
} from '../../fundamentals';
import { decodeLookupKey, SubscriptionService } from './service';
import {
InvoiceStatus,
@@ -227,7 +231,7 @@ export class SubscriptionResolver {
});
if (!session.url) {
throw new BadGatewayException('Failed to create checkout session.');
throw new FailedToCheckout();
}
return session.url;
@@ -322,9 +326,7 @@ export class UserSubscriptionResolver {
) {
// allow admin to query other user's subscription
if (!ctx.isAdminQuery && me.id !== user.id) {
throw new ForbiddenException(
'You are not allowed to access this subscription.'
);
throw new AccessDenied();
}
// @FIXME(@forehalo): should not mock any api for selfhosted server
@@ -363,9 +365,7 @@ export class UserSubscriptionResolver {
@Parent() user: User
): Promise<UserSubscription[]> {
if (me.id !== user.id) {
throw new ForbiddenException(
'You are not allowed to access this subscription.'
);
throw new AccessDenied();
}
return this.db.userSubscription.findMany({
@@ -385,9 +385,7 @@ export class UserSubscriptionResolver {
@Args('skip', { type: () => Int, nullable: true }) skip?: number
) {
if (me.id !== user.id) {
throw new ForbiddenException(
'You are not allowed to access this invoices'
);
throw new AccessDenied();
}
return this.db.userInvoice.findMany({

View File

@@ -1,6 +1,6 @@
import { randomUUID } from 'node:crypto';
import { BadRequestException, Injectable, Logger } from '@nestjs/common';
import { Injectable, Logger } from '@nestjs/common';
import { OnEvent as RawOnEvent } from '@nestjs/event-emitter';
import type {
Prisma,
@@ -14,7 +14,20 @@ import Stripe from 'stripe';
import { CurrentUser } from '../../core/auth';
import { EarlyAccessType, FeatureManagementService } from '../../core/features';
import { Config, EventEmitter, OnEvent } from '../../fundamentals';
import {
ActionForbidden,
Config,
CustomerPortalCreateFailed,
EventEmitter,
OnEvent,
SameSubscriptionRecurring,
SubscriptionAlreadyExists,
SubscriptionExpired,
SubscriptionHasBeenCanceled,
SubscriptionNotExists,
SubscriptionPlanNotFound,
UserNotFound,
} from '../../fundamentals';
import { ScheduleManager } from './schedule';
import {
InvoiceStatus,
@@ -160,7 +173,7 @@ export class SubscriptionService {
this.config.affine.canary &&
!this.features.isStaff(user.email)
) {
throw new BadRequestException('You are not allowed to do this.');
throw new ActionForbidden();
}
const currentSubscription = await this.db.userSubscription.findFirst({
@@ -172,9 +185,7 @@ export class SubscriptionService {
});
if (currentSubscription) {
throw new BadRequestException(
`You've already subscribed to the ${plan} plan`
);
throw new SubscriptionAlreadyExists({ plan });
}
const customer = await this.getOrCreateCustomer(
@@ -245,18 +256,16 @@ export class SubscriptionService {
});
if (!user) {
throw new BadRequestException('Unknown user');
throw new UserNotFound();
}
const subscriptionInDB = user?.subscriptions.find(s => s.plan === plan);
if (!subscriptionInDB) {
throw new BadRequestException(`You didn't subscribe to the ${plan} plan`);
throw new SubscriptionNotExists({ plan });
}
if (subscriptionInDB.canceledAt) {
throw new BadRequestException(
'Your subscription has already been canceled'
);
throw new SubscriptionHasBeenCanceled();
}
// should release the schedule first
@@ -298,22 +307,20 @@ export class SubscriptionService {
});
if (!user) {
throw new BadRequestException('Unknown user');
throw new UserNotFound();
}
const subscriptionInDB = user?.subscriptions.find(s => s.plan === plan);
if (!subscriptionInDB) {
throw new BadRequestException(`You didn't subscribe to the ${plan} plan`);
throw new SubscriptionNotExists({ plan });
}
if (!subscriptionInDB.canceledAt) {
throw new BadRequestException('Your subscription has not been canceled');
throw new SubscriptionHasBeenCanceled();
}
if (subscriptionInDB.end < new Date()) {
throw new BadRequestException(
'Your subscription is expired, please checkout again.'
);
throw new SubscriptionExpired();
}
if (subscriptionInDB.stripeScheduleId) {
@@ -354,23 +361,19 @@ export class SubscriptionService {
});
if (!user) {
throw new BadRequestException('Unknown user');
throw new UserNotFound();
}
const subscriptionInDB = user?.subscriptions.find(s => s.plan === plan);
if (!subscriptionInDB) {
throw new BadRequestException(`You didn't subscribe to the ${plan} plan`);
throw new SubscriptionNotExists({ plan });
}
if (subscriptionInDB.canceledAt) {
throw new BadRequestException(
'Your subscription has already been canceled'
);
throw new SubscriptionHasBeenCanceled();
}
if (subscriptionInDB.recurring === recurring) {
throw new BadRequestException(
`You are already in ${recurring} recurring`
);
throw new SameSubscriptionRecurring({ recurring });
}
const price = await this.getPrice(
@@ -404,7 +407,7 @@ export class SubscriptionService {
});
if (!user) {
throw new BadRequestException('Unknown user');
throw new UserNotFound();
}
try {
@@ -415,7 +418,7 @@ export class SubscriptionService {
return portal.url;
} catch (e) {
this.logger.error('Failed to create customer portal.', e);
throw new BadRequestException('Failed to create customer portal');
throw new CustomerPortalCreateFailed();
}
}
@@ -751,9 +754,10 @@ export class SubscriptionService {
});
if (!prices.data.length) {
throw new BadRequestException(
`Unknown subscription plan ${plan} with ${recurring} recurring`
);
throw new SubscriptionPlanNotFound({
plan,
recurring,
});
}
return prices.data[0].id;

View File

@@ -1,19 +1,13 @@
import assert from 'node:assert';
import type { RawBodyRequest } from '@nestjs/common';
import {
Controller,
Logger,
NotAcceptableException,
Post,
Req,
} from '@nestjs/common';
import { Controller, Logger, Post, Req } from '@nestjs/common';
import { EventEmitter2 } from '@nestjs/event-emitter';
import type { Request } from 'express';
import Stripe from 'stripe';
import { Public } from '../../core/auth';
import { Config } from '../../fundamentals';
import { Config, InternalServerError } from '../../fundamentals';
@Controller('/api/stripe')
export class StripeWebhook {
@@ -55,9 +49,8 @@ export class StripeWebhook {
this.logger.error('Failed to handle Stripe Webhook event.', e);
});
});
} catch (err) {
this.logger.error('Stripe Webhook error', err);
throw new NotAcceptableException();
} catch (err: any) {
throw new InternalServerError(err.message);
}
}
}