mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 20:38:52 +00:00
feat(server): introduce user friendly server errors (#7111)
This commit is contained in:
@@ -1,11 +1,7 @@
|
||||
import {
|
||||
BadRequestException,
|
||||
Controller,
|
||||
Get,
|
||||
HttpException,
|
||||
InternalServerErrorException,
|
||||
Logger,
|
||||
NotFoundException,
|
||||
Param,
|
||||
Query,
|
||||
Req,
|
||||
@@ -23,14 +19,21 @@ import {
|
||||
merge,
|
||||
mergeMap,
|
||||
Observable,
|
||||
of,
|
||||
switchMap,
|
||||
toArray,
|
||||
} from 'rxjs';
|
||||
|
||||
import { Public } from '../../core/auth';
|
||||
import { CurrentUser } from '../../core/auth/current-user';
|
||||
import { Config } from '../../fundamentals';
|
||||
import {
|
||||
BlobNotFound,
|
||||
Config,
|
||||
CopilotFailedToGenerateText,
|
||||
CopilotSessionNotFound,
|
||||
mapSseError,
|
||||
NoCopilotProviderAvailable,
|
||||
UnsplashIsNotConfigured,
|
||||
} from '../../fundamentals';
|
||||
import { CopilotProviderService } from './providers';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
@@ -40,7 +43,7 @@ import { CopilotWorkflowService } from './workflow';
|
||||
export interface ChatEvent {
|
||||
type: 'attachment' | 'message' | 'error';
|
||||
id?: string;
|
||||
data: string;
|
||||
data: string | object;
|
||||
}
|
||||
|
||||
type CheckResult = {
|
||||
@@ -68,7 +71,7 @@ export class CopilotController {
|
||||
await this.chatSession.checkQuota(userId);
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session || session.config.userId !== userId) {
|
||||
throw new BadRequestException('Session not found');
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
const ret: CheckResult = { model: session.model };
|
||||
@@ -104,7 +107,7 @@ export class CopilotController {
|
||||
);
|
||||
}
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
throw new NoCopilotProviderAvailable();
|
||||
}
|
||||
|
||||
return provider;
|
||||
@@ -116,7 +119,7 @@ export class CopilotController {
|
||||
): Promise<ChatSession> {
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
if (messageId) {
|
||||
@@ -148,20 +151,6 @@ export class CopilotController {
|
||||
return num;
|
||||
}
|
||||
|
||||
private handleError(err: any) {
|
||||
if (err instanceof Error) {
|
||||
const ret = {
|
||||
message: err.message,
|
||||
status: (err as any).status,
|
||||
};
|
||||
if (err instanceof HttpException) {
|
||||
ret.status = err.getStatus();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
return err;
|
||||
}
|
||||
|
||||
@Get('/chat/:sessionId')
|
||||
async chat(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@@ -200,9 +189,7 @@ export class CopilotController {
|
||||
|
||||
return content;
|
||||
} catch (e: any) {
|
||||
throw new InternalServerErrorException(
|
||||
e.message || "Couldn't generate text"
|
||||
);
|
||||
throw new CopilotFailedToGenerateText(e.message);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,18 +240,10 @@ export class CopilotController {
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
catchError(mapSseError)
|
||||
);
|
||||
} catch (err) {
|
||||
return of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
});
|
||||
return mapSseError(err);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,18 +297,10 @@ export class CopilotController {
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
catchError(mapSseError)
|
||||
);
|
||||
} catch (err) {
|
||||
return of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
});
|
||||
return mapSseError(err);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -356,7 +327,7 @@ export class CopilotController {
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
throw new NoCopilotProviderAvailable();
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
@@ -402,18 +373,10 @@ export class CopilotController {
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
catchError(mapSseError)
|
||||
);
|
||||
} catch (err) {
|
||||
return of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
});
|
||||
return mapSseError(err);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -425,7 +388,7 @@ export class CopilotController {
|
||||
) {
|
||||
const { unsplashKey } = this.config.plugins.copilot || {};
|
||||
if (!unsplashKey) {
|
||||
throw new InternalServerErrorException('Unsplash key is not configured');
|
||||
throw new UnsplashIsNotConfigured();
|
||||
}
|
||||
|
||||
const query = new URLSearchParams(params);
|
||||
@@ -458,9 +421,10 @@ export class CopilotController {
|
||||
const { body, metadata } = await this.storage.get(userId, workspaceId, key);
|
||||
|
||||
if (!body) {
|
||||
throw new NotFoundException(
|
||||
`Blob not found in ${userId}'s workspace ${workspaceId}: ${key}`
|
||||
);
|
||||
throw new BlobNotFound({
|
||||
workspaceId,
|
||||
blobId: key,
|
||||
});
|
||||
}
|
||||
|
||||
// metadata should always exists if body is not null
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createHash } from 'node:crypto';
|
||||
|
||||
import { BadRequestException, Logger, NotFoundException } from '@nestjs/common';
|
||||
import { BadRequestException, NotFoundException } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
@@ -23,6 +23,7 @@ import { Admin } from '../../core/common';
|
||||
import { UserType } from '../../core/user';
|
||||
import { PermissionService } from '../../core/workspaces/permission';
|
||||
import {
|
||||
CopilotFailedToCreateMessage,
|
||||
FileUpload,
|
||||
MutexService,
|
||||
Throttle,
|
||||
@@ -201,8 +202,6 @@ export class CopilotType {
|
||||
@Throttle()
|
||||
@Resolver(() => CopilotType)
|
||||
export class CopilotResolver {
|
||||
private readonly logger = new Logger(CopilotResolver.name);
|
||||
|
||||
constructor(
|
||||
private readonly permissions: PermissionService,
|
||||
private readonly mutex: MutexService,
|
||||
@@ -385,8 +384,7 @@ export class CopilotResolver {
|
||||
try {
|
||||
return await this.chatSession.createMessage(options);
|
||||
} catch (e: any) {
|
||||
this.logger.error(`Failed to create chat message: ${e.message}`);
|
||||
throw new Error('Failed to create chat message');
|
||||
throw new CopilotFailedToCreateMessage(e.message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,14 @@ import { AiPromptRole, PrismaClient } from '@prisma/client';
|
||||
|
||||
import { FeatureManagementService } from '../../core/features';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import { PaymentRequiredException } from '../../fundamentals';
|
||||
import {
|
||||
CopilotActionTaken,
|
||||
CopilotMessageNotFound,
|
||||
CopilotPromptNotFound,
|
||||
CopilotQuotaExceeded,
|
||||
CopilotSessionDeleted,
|
||||
CopilotSessionNotFound,
|
||||
} from '../../fundamentals';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { PromptService } from './prompt';
|
||||
import {
|
||||
@@ -58,7 +65,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
this.state.messages.length > 0 &&
|
||||
message.role === 'user'
|
||||
) {
|
||||
throw new Error('Action has been taken, no more messages allowed');
|
||||
throw new CopilotActionTaken();
|
||||
}
|
||||
this.state.messages.push(message);
|
||||
this.stashMessageCount += 1;
|
||||
@@ -74,7 +81,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
async getMessageById(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
throw new Error(`Message not found: ${messageId}`);
|
||||
throw new CopilotMessageNotFound();
|
||||
}
|
||||
return message;
|
||||
}
|
||||
@@ -82,7 +89,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
async pushByMessageId(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
throw new Error(`Message not found: ${messageId}`);
|
||||
throw new CopilotMessageNotFound();
|
||||
}
|
||||
|
||||
this.push({
|
||||
@@ -196,7 +203,7 @@ export class ChatSessionService {
|
||||
},
|
||||
select: { id: true, deletedAt: true },
|
||||
})) || {};
|
||||
if (deletedAt) throw new Error(`Session is deleted: ${id}`);
|
||||
if (deletedAt) throw new CopilotSessionDeleted();
|
||||
if (id) sessionId = id;
|
||||
}
|
||||
|
||||
@@ -274,7 +281,8 @@ export class ChatSessionService {
|
||||
.then(async session => {
|
||||
if (!session) return;
|
||||
const prompt = await this.prompt.get(session.promptName);
|
||||
if (!prompt) throw new Error(`Prompt not found: ${session.promptName}`);
|
||||
if (!prompt)
|
||||
throw new CopilotPromptNotFound({ name: session.promptName });
|
||||
|
||||
const messages = ChatMessageSchema.array().safeParse(session.messages);
|
||||
|
||||
@@ -300,7 +308,7 @@ export class ChatSessionService {
|
||||
})
|
||||
.then(session => session?.id);
|
||||
if (!id) {
|
||||
throw new Error(`Session not found: ${sessionId}`);
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const ids = await tx.aiSessionMessage
|
||||
.findMany({
|
||||
@@ -412,7 +420,7 @@ export class ChatSessionService {
|
||||
if (ret.success) {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new Error(`Prompt not found: ${promptName}`);
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
|
||||
// render system prompt
|
||||
@@ -471,9 +479,7 @@ export class ChatSessionService {
|
||||
async checkQuota(userId: string) {
|
||||
const { limit, used } = await this.getQuota(userId);
|
||||
if (limit && Number.isFinite(limit) && used >= limit) {
|
||||
throw new PaymentRequiredException(
|
||||
`You have reached the limit of actions in this workspace, please upgrade your plan.`
|
||||
);
|
||||
throw new CopilotQuotaExceeded();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -482,7 +488,7 @@ export class ChatSessionService {
|
||||
const prompt = await this.prompt.get(options.promptName);
|
||||
if (!prompt) {
|
||||
this.logger.error(`Prompt not found: ${options.promptName}`);
|
||||
throw new Error('Prompt not found');
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
}
|
||||
return await this.setSession({
|
||||
...options,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { createHash } from 'node:crypto';
|
||||
|
||||
import { Injectable, PayloadTooLargeException } from '@nestjs/common';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { QuotaManagementService } from '../../core/quota';
|
||||
import {
|
||||
type BlobInputType,
|
||||
BlobQuotaExceeded,
|
||||
Config,
|
||||
type FileUpload,
|
||||
type StorageProvider,
|
||||
@@ -54,9 +55,7 @@ export class CopilotStorage {
|
||||
const checkExceeded = await this.quota.getQuotaCalculator(userId);
|
||||
|
||||
if (checkExceeded(0)) {
|
||||
throw new PayloadTooLargeException(
|
||||
'Storage or blob size limit exceeded.'
|
||||
);
|
||||
throw new BlobQuotaExceeded();
|
||||
}
|
||||
const buffer = await new Promise<Buffer>((resolve, reject) => {
|
||||
const stream = blob.createReadStream();
|
||||
@@ -67,9 +66,7 @@ export class CopilotStorage {
|
||||
// check size after receive each chunk to avoid unnecessary memory usage
|
||||
const bufferSize = chunks.reduce((acc, cur) => acc + cur.length, 0);
|
||||
if (checkExceeded(bufferSize)) {
|
||||
reject(
|
||||
new PayloadTooLargeException('Storage or blob size limit exceeded.')
|
||||
);
|
||||
reject(new BlobQuotaExceeded());
|
||||
}
|
||||
});
|
||||
stream.on('error', reject);
|
||||
@@ -77,7 +74,7 @@ export class CopilotStorage {
|
||||
const buffer = Buffer.concat(chunks);
|
||||
|
||||
if (checkExceeded(buffer.length)) {
|
||||
reject(new PayloadTooLargeException('Storage limit exceeded.'));
|
||||
reject(new BlobQuotaExceeded());
|
||||
} else {
|
||||
resolve(buffer);
|
||||
}
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import {
|
||||
BadRequestException,
|
||||
Controller,
|
||||
Get,
|
||||
Query,
|
||||
Req,
|
||||
Res,
|
||||
} from '@nestjs/common';
|
||||
import { Controller, Get, Query, Req, Res } from '@nestjs/common';
|
||||
import { ConnectedAccount, PrismaClient } from '@prisma/client';
|
||||
import type { Request, Response } from 'express';
|
||||
|
||||
import { AuthService, Public } from '../../core/auth';
|
||||
import { UserService } from '../../core/user';
|
||||
import { URLHelper } from '../../fundamentals';
|
||||
import {
|
||||
InvalidOauthCallbackState,
|
||||
MissingOauthQueryParameter,
|
||||
OauthAccountAlreadyConnected,
|
||||
OauthStateExpired,
|
||||
UnknownOauthProvider,
|
||||
URLHelper,
|
||||
WrongSignInMethod,
|
||||
} from '../../fundamentals';
|
||||
import { OAuthProviderName } from './config';
|
||||
import { OAuthAccount, Tokens } from './providers/def';
|
||||
import { OAuthProviderFactory } from './register';
|
||||
@@ -35,12 +36,15 @@ export class OAuthController {
|
||||
@Query('provider') unknownProviderName: string,
|
||||
@Query('redirect_uri') redirectUri?: string
|
||||
) {
|
||||
if (!unknownProviderName) {
|
||||
throw new MissingOauthQueryParameter({ name: 'provider' });
|
||||
}
|
||||
// @ts-expect-error safe
|
||||
const providerName = OAuthProviderName[unknownProviderName];
|
||||
const provider = this.providerFactory.get(providerName);
|
||||
|
||||
if (!provider) {
|
||||
throw new BadRequestException('Invalid OAuth provider');
|
||||
throw new UnknownOauthProvider({ name: unknownProviderName });
|
||||
}
|
||||
|
||||
const state = await this.oauth.saveOAuthState({
|
||||
@@ -60,29 +64,31 @@ export class OAuthController {
|
||||
@Query('state') stateStr?: string
|
||||
) {
|
||||
if (!code) {
|
||||
throw new BadRequestException('Missing query parameter `code`');
|
||||
throw new MissingOauthQueryParameter({ name: 'code' });
|
||||
}
|
||||
|
||||
if (!stateStr) {
|
||||
throw new BadRequestException('Invalid callback state parameter');
|
||||
throw new MissingOauthQueryParameter({ name: 'state' });
|
||||
}
|
||||
|
||||
if (typeof stateStr !== 'string' || !this.oauth.isValidState(stateStr)) {
|
||||
throw new InvalidOauthCallbackState();
|
||||
}
|
||||
|
||||
const state = await this.oauth.getOAuthState(stateStr);
|
||||
|
||||
if (!state) {
|
||||
throw new BadRequestException('OAuth state expired, please try again.');
|
||||
throw new OauthStateExpired();
|
||||
}
|
||||
|
||||
if (!state.provider) {
|
||||
throw new BadRequestException(
|
||||
'Missing callback state parameter `provider`'
|
||||
);
|
||||
throw new MissingOauthQueryParameter({ name: 'provider' });
|
||||
}
|
||||
|
||||
const provider = this.providerFactory.get(state.provider);
|
||||
|
||||
if (!provider) {
|
||||
throw new BadRequestException('Invalid provider');
|
||||
throw new UnknownOauthProvider({ name: state.provider ?? 'unknown' });
|
||||
}
|
||||
|
||||
const tokens = await provider.getToken(code);
|
||||
@@ -154,15 +160,9 @@ export class OAuthController {
|
||||
// we can't directly connect the external account with given email in sign in scenario for safety concern.
|
||||
// let user manually connect in account sessions instead.
|
||||
if (user.registered) {
|
||||
throw new BadRequestException(
|
||||
'The account with provided email is not register in the same way.'
|
||||
);
|
||||
throw new WrongSignInMethod();
|
||||
}
|
||||
|
||||
await this.user.fulfillUser(externalAccount.email, {
|
||||
emailVerifiedAt: new Date(),
|
||||
registered: true,
|
||||
});
|
||||
await this.db.connectedAccount.create({
|
||||
data: {
|
||||
userId: user.id,
|
||||
@@ -228,9 +228,7 @@ export class OAuthController {
|
||||
|
||||
if (connectedUser) {
|
||||
if (connectedUser.id !== user.id) {
|
||||
throw new BadRequestException(
|
||||
'The third-party account has already been connected to another user.'
|
||||
);
|
||||
throw new OauthAccountAlreadyConnected();
|
||||
}
|
||||
} else {
|
||||
await this.db.connectedAccount.create({
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { Config, URLHelper } from '../../../fundamentals';
|
||||
import { OAuthProviderName } from '../config';
|
||||
@@ -39,74 +39,60 @@ export class GithubOAuthProvider extends AutoRegisteredOAuthProvider {
|
||||
}
|
||||
|
||||
async getToken(code: string) {
|
||||
try {
|
||||
const response = await fetch(
|
||||
'https://github.com/login/oauth/access_token',
|
||||
{
|
||||
method: 'POST',
|
||||
body: this.url.stringify({
|
||||
code,
|
||||
client_id: this.config.clientId,
|
||||
client_secret: this.config.clientSecret,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
}),
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
const ghToken = (await response.json()) as AuthTokenResponse;
|
||||
|
||||
return {
|
||||
accessToken: ghToken.access_token,
|
||||
scope: ghToken.scope,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
}, ${JSON.stringify(await response.json())}`
|
||||
);
|
||||
const response = await fetch(
|
||||
'https://github.com/login/oauth/access_token',
|
||||
{
|
||||
method: 'POST',
|
||||
body: this.url.stringify({
|
||||
code,
|
||||
client_id: this.config.clientId,
|
||||
client_secret: this.config.clientSecret,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
}),
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
}
|
||||
} catch (e) {
|
||||
throw new HttpException(
|
||||
`Failed to get access_token, err: ${(e as Error).message}`,
|
||||
HttpStatus.BAD_REQUEST
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
const ghToken = (await response.json()) as AuthTokenResponse;
|
||||
|
||||
return {
|
||||
accessToken: ghToken.access_token,
|
||||
scope: ghToken.scope,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
}, ${JSON.stringify(await response.json())}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async getUser(token: string) {
|
||||
try {
|
||||
const response = await fetch('https://api.github.com/user', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
const response = await fetch('https://api.github.com/user', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const user = (await response.json()) as UserInfo;
|
||||
if (response.ok) {
|
||||
const user = (await response.json()) as UserInfo;
|
||||
|
||||
return {
|
||||
id: user.login,
|
||||
avatarUrl: user.avatar_url,
|
||||
email: user.email,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
} ${await response.text()}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
throw new HttpException(
|
||||
`Failed to get user information, err: ${(e as Error).stack}`,
|
||||
HttpStatus.BAD_REQUEST
|
||||
return {
|
||||
id: user.login,
|
||||
avatarUrl: user.avatar_url,
|
||||
email: user.email,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
} ${await response.text()}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { HttpException, HttpStatus, Injectable } from '@nestjs/common';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { Config, URLHelper } from '../../../fundamentals';
|
||||
import { OAuthProviderName } from '../config';
|
||||
@@ -44,77 +44,63 @@ export class GoogleOAuthProvider extends AutoRegisteredOAuthProvider {
|
||||
}
|
||||
|
||||
async getToken(code: string) {
|
||||
try {
|
||||
const response = await fetch('https://oauth2.googleapis.com/token', {
|
||||
method: 'POST',
|
||||
body: this.url.stringify({
|
||||
code,
|
||||
client_id: this.config.clientId,
|
||||
client_secret: this.config.clientSecret,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
grant_type: 'authorization_code',
|
||||
}),
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
});
|
||||
const response = await fetch('https://oauth2.googleapis.com/token', {
|
||||
method: 'POST',
|
||||
body: this.url.stringify({
|
||||
code,
|
||||
client_id: this.config.clientId,
|
||||
client_secret: this.config.clientSecret,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
grant_type: 'authorization_code',
|
||||
}),
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const ghToken = (await response.json()) as GoogleOAuthTokenResponse;
|
||||
if (response.ok) {
|
||||
const ghToken = (await response.json()) as GoogleOAuthTokenResponse;
|
||||
|
||||
return {
|
||||
accessToken: ghToken.access_token,
|
||||
refreshToken: ghToken.refresh_token,
|
||||
expiresAt: new Date(Date.now() + ghToken.expires_in * 1000),
|
||||
scope: ghToken.scope,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
}, ${JSON.stringify(await response.json())}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
throw new HttpException(
|
||||
`Failed to get access_token, err: ${(e as Error).message}`,
|
||||
HttpStatus.BAD_REQUEST
|
||||
return {
|
||||
accessToken: ghToken.access_token,
|
||||
refreshToken: ghToken.refresh_token,
|
||||
expiresAt: new Date(Date.now() + ghToken.expires_in * 1000),
|
||||
scope: ghToken.scope,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
}, ${JSON.stringify(await response.json())}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async getUser(token: string) {
|
||||
try {
|
||||
const response = await fetch(
|
||||
'https://www.googleapis.com/oauth2/v2/userinfo',
|
||||
{
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
const user = (await response.json()) as UserInfo;
|
||||
|
||||
return {
|
||||
id: user.id,
|
||||
avatarUrl: user.picture,
|
||||
email: user.email,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
} ${await response.text()}`
|
||||
);
|
||||
const response = await fetch(
|
||||
'https://www.googleapis.com/oauth2/v2/userinfo',
|
||||
{
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
}
|
||||
} catch (e) {
|
||||
throw new HttpException(
|
||||
`Failed to get user information, err: ${(e as Error).stack}`,
|
||||
HttpStatus.BAD_REQUEST
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
const user = (await response.json()) as UserInfo;
|
||||
|
||||
return {
|
||||
id: user.id,
|
||||
avatarUrl: user.picture,
|
||||
email: user.email,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
`Server responded with non-success code ${
|
||||
response.status
|
||||
} ${await response.text()}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
import {
|
||||
BadRequestException,
|
||||
Injectable,
|
||||
InternalServerErrorException,
|
||||
OnModuleInit,
|
||||
} from '@nestjs/common';
|
||||
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { Config, URLHelper } from '../../../fundamentals';
|
||||
@@ -44,6 +39,8 @@ const OIDCConfigurationSchema = z.object({
|
||||
|
||||
type OIDCConfiguration = z.infer<typeof OIDCConfigurationSchema>;
|
||||
|
||||
const logger = new Logger('OIDCClient');
|
||||
|
||||
class OIDCClient {
|
||||
private static async fetch<T = any>(
|
||||
url: string,
|
||||
@@ -53,17 +50,8 @@ class OIDCClient {
|
||||
const response = await fetch(url, options);
|
||||
|
||||
if (!response.ok) {
|
||||
if (response.status >= 400 && response.status < 500) {
|
||||
throw new BadRequestException(`Invalid OIDC configuration`, {
|
||||
cause: await response.json(),
|
||||
description: response.statusText,
|
||||
});
|
||||
} else {
|
||||
throw new InternalServerErrorException(`Failed to configure client`, {
|
||||
cause: await response.json(),
|
||||
description: response.statusText,
|
||||
});
|
||||
}
|
||||
logger.error('Failed to fetch OIDC configuration', await response.json());
|
||||
throw new Error(`Failed to configure client`);
|
||||
}
|
||||
const data = await response.json();
|
||||
return verifier.parse(data);
|
||||
|
||||
@@ -20,6 +20,10 @@ export class OAuthService {
|
||||
private readonly cache: SessionCache
|
||||
) {}
|
||||
|
||||
isValidState(stateStr: string) {
|
||||
return stateStr.length === 36;
|
||||
}
|
||||
|
||||
async saveOAuthState(state: OAuthState) {
|
||||
const token = randomUUID();
|
||||
await this.cache.set(`${OAUTH_STATE_KEY}:${token}`, state, {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { BadGatewayException, ForbiddenException } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Context,
|
||||
@@ -19,7 +18,12 @@ import { groupBy } from 'lodash-es';
|
||||
|
||||
import { CurrentUser, Public } from '../../core/auth';
|
||||
import { UserType } from '../../core/user';
|
||||
import { Config, URLHelper } from '../../fundamentals';
|
||||
import {
|
||||
AccessDenied,
|
||||
Config,
|
||||
FailedToCheckout,
|
||||
URLHelper,
|
||||
} from '../../fundamentals';
|
||||
import { decodeLookupKey, SubscriptionService } from './service';
|
||||
import {
|
||||
InvoiceStatus,
|
||||
@@ -227,7 +231,7 @@ export class SubscriptionResolver {
|
||||
});
|
||||
|
||||
if (!session.url) {
|
||||
throw new BadGatewayException('Failed to create checkout session.');
|
||||
throw new FailedToCheckout();
|
||||
}
|
||||
|
||||
return session.url;
|
||||
@@ -322,9 +326,7 @@ export class UserSubscriptionResolver {
|
||||
) {
|
||||
// allow admin to query other user's subscription
|
||||
if (!ctx.isAdminQuery && me.id !== user.id) {
|
||||
throw new ForbiddenException(
|
||||
'You are not allowed to access this subscription.'
|
||||
);
|
||||
throw new AccessDenied();
|
||||
}
|
||||
|
||||
// @FIXME(@forehalo): should not mock any api for selfhosted server
|
||||
@@ -363,9 +365,7 @@ export class UserSubscriptionResolver {
|
||||
@Parent() user: User
|
||||
): Promise<UserSubscription[]> {
|
||||
if (me.id !== user.id) {
|
||||
throw new ForbiddenException(
|
||||
'You are not allowed to access this subscription.'
|
||||
);
|
||||
throw new AccessDenied();
|
||||
}
|
||||
|
||||
return this.db.userSubscription.findMany({
|
||||
@@ -385,9 +385,7 @@ export class UserSubscriptionResolver {
|
||||
@Args('skip', { type: () => Int, nullable: true }) skip?: number
|
||||
) {
|
||||
if (me.id !== user.id) {
|
||||
throw new ForbiddenException(
|
||||
'You are not allowed to access this invoices'
|
||||
);
|
||||
throw new AccessDenied();
|
||||
}
|
||||
|
||||
return this.db.userInvoice.findMany({
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { BadRequestException, Injectable, Logger } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { OnEvent as RawOnEvent } from '@nestjs/event-emitter';
|
||||
import type {
|
||||
Prisma,
|
||||
@@ -14,7 +14,20 @@ import Stripe from 'stripe';
|
||||
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { EarlyAccessType, FeatureManagementService } from '../../core/features';
|
||||
import { Config, EventEmitter, OnEvent } from '../../fundamentals';
|
||||
import {
|
||||
ActionForbidden,
|
||||
Config,
|
||||
CustomerPortalCreateFailed,
|
||||
EventEmitter,
|
||||
OnEvent,
|
||||
SameSubscriptionRecurring,
|
||||
SubscriptionAlreadyExists,
|
||||
SubscriptionExpired,
|
||||
SubscriptionHasBeenCanceled,
|
||||
SubscriptionNotExists,
|
||||
SubscriptionPlanNotFound,
|
||||
UserNotFound,
|
||||
} from '../../fundamentals';
|
||||
import { ScheduleManager } from './schedule';
|
||||
import {
|
||||
InvoiceStatus,
|
||||
@@ -160,7 +173,7 @@ export class SubscriptionService {
|
||||
this.config.affine.canary &&
|
||||
!this.features.isStaff(user.email)
|
||||
) {
|
||||
throw new BadRequestException('You are not allowed to do this.');
|
||||
throw new ActionForbidden();
|
||||
}
|
||||
|
||||
const currentSubscription = await this.db.userSubscription.findFirst({
|
||||
@@ -172,9 +185,7 @@ export class SubscriptionService {
|
||||
});
|
||||
|
||||
if (currentSubscription) {
|
||||
throw new BadRequestException(
|
||||
`You've already subscribed to the ${plan} plan`
|
||||
);
|
||||
throw new SubscriptionAlreadyExists({ plan });
|
||||
}
|
||||
|
||||
const customer = await this.getOrCreateCustomer(
|
||||
@@ -245,18 +256,16 @@ export class SubscriptionService {
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
throw new BadRequestException('Unknown user');
|
||||
throw new UserNotFound();
|
||||
}
|
||||
|
||||
const subscriptionInDB = user?.subscriptions.find(s => s.plan === plan);
|
||||
if (!subscriptionInDB) {
|
||||
throw new BadRequestException(`You didn't subscribe to the ${plan} plan`);
|
||||
throw new SubscriptionNotExists({ plan });
|
||||
}
|
||||
|
||||
if (subscriptionInDB.canceledAt) {
|
||||
throw new BadRequestException(
|
||||
'Your subscription has already been canceled'
|
||||
);
|
||||
throw new SubscriptionHasBeenCanceled();
|
||||
}
|
||||
|
||||
// should release the schedule first
|
||||
@@ -298,22 +307,20 @@ export class SubscriptionService {
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
throw new BadRequestException('Unknown user');
|
||||
throw new UserNotFound();
|
||||
}
|
||||
|
||||
const subscriptionInDB = user?.subscriptions.find(s => s.plan === plan);
|
||||
if (!subscriptionInDB) {
|
||||
throw new BadRequestException(`You didn't subscribe to the ${plan} plan`);
|
||||
throw new SubscriptionNotExists({ plan });
|
||||
}
|
||||
|
||||
if (!subscriptionInDB.canceledAt) {
|
||||
throw new BadRequestException('Your subscription has not been canceled');
|
||||
throw new SubscriptionHasBeenCanceled();
|
||||
}
|
||||
|
||||
if (subscriptionInDB.end < new Date()) {
|
||||
throw new BadRequestException(
|
||||
'Your subscription is expired, please checkout again.'
|
||||
);
|
||||
throw new SubscriptionExpired();
|
||||
}
|
||||
|
||||
if (subscriptionInDB.stripeScheduleId) {
|
||||
@@ -354,23 +361,19 @@ export class SubscriptionService {
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
throw new BadRequestException('Unknown user');
|
||||
throw new UserNotFound();
|
||||
}
|
||||
const subscriptionInDB = user?.subscriptions.find(s => s.plan === plan);
|
||||
if (!subscriptionInDB) {
|
||||
throw new BadRequestException(`You didn't subscribe to the ${plan} plan`);
|
||||
throw new SubscriptionNotExists({ plan });
|
||||
}
|
||||
|
||||
if (subscriptionInDB.canceledAt) {
|
||||
throw new BadRequestException(
|
||||
'Your subscription has already been canceled'
|
||||
);
|
||||
throw new SubscriptionHasBeenCanceled();
|
||||
}
|
||||
|
||||
if (subscriptionInDB.recurring === recurring) {
|
||||
throw new BadRequestException(
|
||||
`You are already in ${recurring} recurring`
|
||||
);
|
||||
throw new SameSubscriptionRecurring({ recurring });
|
||||
}
|
||||
|
||||
const price = await this.getPrice(
|
||||
@@ -404,7 +407,7 @@ export class SubscriptionService {
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
throw new BadRequestException('Unknown user');
|
||||
throw new UserNotFound();
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -415,7 +418,7 @@ export class SubscriptionService {
|
||||
return portal.url;
|
||||
} catch (e) {
|
||||
this.logger.error('Failed to create customer portal.', e);
|
||||
throw new BadRequestException('Failed to create customer portal');
|
||||
throw new CustomerPortalCreateFailed();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -751,9 +754,10 @@ export class SubscriptionService {
|
||||
});
|
||||
|
||||
if (!prices.data.length) {
|
||||
throw new BadRequestException(
|
||||
`Unknown subscription plan ${plan} with ${recurring} recurring`
|
||||
);
|
||||
throw new SubscriptionPlanNotFound({
|
||||
plan,
|
||||
recurring,
|
||||
});
|
||||
}
|
||||
|
||||
return prices.data[0].id;
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
import assert from 'node:assert';
|
||||
|
||||
import type { RawBodyRequest } from '@nestjs/common';
|
||||
import {
|
||||
Controller,
|
||||
Logger,
|
||||
NotAcceptableException,
|
||||
Post,
|
||||
Req,
|
||||
} from '@nestjs/common';
|
||||
import { Controller, Logger, Post, Req } from '@nestjs/common';
|
||||
import { EventEmitter2 } from '@nestjs/event-emitter';
|
||||
import type { Request } from 'express';
|
||||
import Stripe from 'stripe';
|
||||
|
||||
import { Public } from '../../core/auth';
|
||||
import { Config } from '../../fundamentals';
|
||||
import { Config, InternalServerError } from '../../fundamentals';
|
||||
|
||||
@Controller('/api/stripe')
|
||||
export class StripeWebhook {
|
||||
@@ -55,9 +49,8 @@ export class StripeWebhook {
|
||||
this.logger.error('Failed to handle Stripe Webhook event.', e);
|
||||
});
|
||||
});
|
||||
} catch (err) {
|
||||
this.logger.error('Stripe Webhook error', err);
|
||||
throw new NotAcceptableException();
|
||||
} catch (err: any) {
|
||||
throw new InternalServerError(err.message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user