mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-27 02:42:25 +08:00
feat: improve selfhosted login (#14502)
fix #13397 fix #14011 #### PR Dependency Tree * **PR #14502** 👈 This tree was auto-generated by [Charcoal](https://github.com/danerwilliams/charcoal) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Centralized CORS policy with dynamic origin validation applied to server and realtime connections * Improved sign-in flows with contextual, localized error hints and toast notifications * Centralized network-error normalization and conditional OAuth provider fetching * **Bug Fixes** * Better feedback for self-hosted connection failures and clearer authentication error handling * More robust handling of network-related failures with user-friendly messages <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
99
packages/backend/server/src/base/cors.ts
Normal file
99
packages/backend/server/src/base/cors.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
import { URLHelper } from './helpers';
|
||||
|
||||
const DEV_LOOPBACK_PROTOCOLS = new Set(['http:', 'https:']);
|
||||
const DEV_LOOPBACK_HOSTS = new Set(['localhost', '127.0.0.1', '::1']);
|
||||
const MOBILE_CLIENT_ORIGINS = new Set([
|
||||
'https://localhost',
|
||||
'capacitor://localhost',
|
||||
'ionic://localhost',
|
||||
]);
|
||||
const DESKTOP_CLIENT_ORIGINS = new Set(['assets://.', 'assets://another-host']);
|
||||
|
||||
export const CORS_ALLOWED_METHODS = [
|
||||
'GET',
|
||||
'HEAD',
|
||||
'PUT',
|
||||
'PATCH',
|
||||
'POST',
|
||||
'DELETE',
|
||||
'OPTIONS',
|
||||
];
|
||||
|
||||
export const CORS_ALLOWED_HEADERS = [
|
||||
'accept',
|
||||
'authorization',
|
||||
'content-type',
|
||||
'x-affine-version',
|
||||
'x-operation-name',
|
||||
'x-request-id',
|
||||
'x-captcha-token',
|
||||
'x-captcha-challenge',
|
||||
'x-affine-csrf-token',
|
||||
'x-requested-with',
|
||||
'range',
|
||||
];
|
||||
|
||||
export const CORS_EXPOSED_HEADERS = [
|
||||
'content-length',
|
||||
'content-range',
|
||||
'x-request-id',
|
||||
];
|
||||
|
||||
function normalizeHostname(hostname: string) {
|
||||
return hostname.toLowerCase().replace(/^\[/, '').replace(/\]$/, '');
|
||||
}
|
||||
|
||||
function isDevLoopbackOrigin(origin: string) {
|
||||
try {
|
||||
const parsed = new URL(origin);
|
||||
return (
|
||||
DEV_LOOPBACK_PROTOCOLS.has(parsed.protocol) &&
|
||||
DEV_LOOPBACK_HOSTS.has(normalizeHostname(parsed.hostname))
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export function buildCorsAllowedOrigins(url: URLHelper) {
|
||||
return new Set<string>([
|
||||
...url.allowedOrigins,
|
||||
...MOBILE_CLIENT_ORIGINS,
|
||||
...DESKTOP_CLIENT_ORIGINS,
|
||||
]);
|
||||
}
|
||||
|
||||
export function isCorsOriginAllowed(
|
||||
origin: string | undefined | null,
|
||||
allowedOrigins: Set<string>
|
||||
) {
|
||||
if (!origin) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (allowedOrigins.has(origin)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if ((env.dev || env.testing) && isDevLoopbackOrigin(origin)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
export function corsOriginCallback(
|
||||
origin: string | undefined,
|
||||
allowedOrigins: Set<string>,
|
||||
onBlocked: (origin: string) => void,
|
||||
callback: (error: Error | null, allow?: boolean) => void
|
||||
) {
|
||||
if (isCorsOriginAllowed(origin, allowedOrigins)) {
|
||||
callback(null, true);
|
||||
return;
|
||||
}
|
||||
|
||||
const blockedOrigin = origin ?? '<empty>';
|
||||
onBlocked(blockedOrigin);
|
||||
callback(null, false);
|
||||
}
|
||||
@@ -11,6 +11,7 @@ export {
|
||||
defineModuleConfig,
|
||||
type JSONSchema,
|
||||
} from './config';
|
||||
export * from './cors';
|
||||
export * from './error';
|
||||
export { EventBus, OnEvent } from './event';
|
||||
export {
|
||||
|
||||
@@ -4,7 +4,15 @@ import { createAdapter } from '@socket.io/redis-adapter';
|
||||
import { Server, Socket } from 'socket.io';
|
||||
|
||||
import { Config } from '../config';
|
||||
import {
|
||||
buildCorsAllowedOrigins,
|
||||
CORS_ALLOWED_HEADERS,
|
||||
CORS_ALLOWED_METHODS,
|
||||
corsOriginCallback,
|
||||
} from '../cors';
|
||||
import { AuthenticationRequired } from '../error';
|
||||
import { URLHelper } from '../helpers';
|
||||
import { AFFiNELogger } from '../logger';
|
||||
import { SocketIoRedis } from '../redis';
|
||||
import { WEBSOCKET_OPTIONS } from './options';
|
||||
|
||||
@@ -14,17 +22,34 @@ export class SocketIoAdapter extends IoAdapter {
|
||||
}
|
||||
|
||||
override createIOServer(port: number, options?: any): Server {
|
||||
const logger = this.app.get(AFFiNELogger);
|
||||
const config = this.app.get(WEBSOCKET_OPTIONS) as Config['websocket'] & {
|
||||
canActivate: (socket: Socket) => Promise<boolean>;
|
||||
};
|
||||
const url = this.app.get(URLHelper);
|
||||
const allowedOrigins = buildCorsAllowedOrigins(url);
|
||||
|
||||
const server: Server = super.createIOServer(port, {
|
||||
...config,
|
||||
...options,
|
||||
// Enable CORS for Socket.IO
|
||||
cors: {
|
||||
origin: true, // Allow all origins
|
||||
credentials: true, // Allow credentials (cookies, auth headers)
|
||||
methods: ['GET', 'POST'],
|
||||
origin: (
|
||||
origin: string | undefined,
|
||||
callback: (error: Error | null, allow?: boolean) => void
|
||||
) => {
|
||||
corsOriginCallback(
|
||||
origin,
|
||||
allowedOrigins,
|
||||
blockedOrigin =>
|
||||
logger.warn(
|
||||
`Blocked WebSocket CORS request from origin: ${blockedOrigin}`
|
||||
),
|
||||
callback
|
||||
);
|
||||
},
|
||||
credentials: true,
|
||||
methods: CORS_ALLOWED_METHODS,
|
||||
allowedHeaders: CORS_ALLOWED_HEADERS,
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -5,9 +5,14 @@ import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs';
|
||||
|
||||
import {
|
||||
AFFiNELogger,
|
||||
buildCorsAllowedOrigins,
|
||||
CacheInterceptor,
|
||||
CloudThrottlerGuard,
|
||||
Config,
|
||||
CORS_ALLOWED_HEADERS,
|
||||
CORS_ALLOWED_METHODS,
|
||||
CORS_EXPOSED_HEADERS,
|
||||
corsOriginCallback,
|
||||
GlobalExceptionFilter,
|
||||
URLHelper,
|
||||
} from './base';
|
||||
@@ -16,12 +21,11 @@ import { AuthGuard } from './core/auth';
|
||||
import { serverTimingAndCache } from './middleware/timing';
|
||||
|
||||
const OneMB = 1024 * 1024;
|
||||
|
||||
export async function run() {
|
||||
const { AppModule } = await import('./app.module');
|
||||
|
||||
const app = await NestFactory.create<NestExpressApplication>(AppModule, {
|
||||
cors: true,
|
||||
cors: false,
|
||||
rawBody: true,
|
||||
bodyParser: true,
|
||||
bufferLogs: true,
|
||||
@@ -32,6 +36,27 @@ export async function run() {
|
||||
const logger = app.get(AFFiNELogger);
|
||||
app.useLogger(logger);
|
||||
const config = app.get(Config);
|
||||
const url = app.get(URLHelper);
|
||||
|
||||
const allowedOrigins = buildCorsAllowedOrigins(url);
|
||||
|
||||
app.enableCors({
|
||||
origin: (origin, callback) => {
|
||||
corsOriginCallback(
|
||||
origin,
|
||||
allowedOrigins,
|
||||
blockedOrigin =>
|
||||
logger.warn(`Blocked CORS request from origin: ${blockedOrigin}`),
|
||||
callback
|
||||
);
|
||||
},
|
||||
credentials: true,
|
||||
methods: CORS_ALLOWED_METHODS,
|
||||
allowedHeaders: CORS_ALLOWED_HEADERS,
|
||||
exposedHeaders: CORS_EXPOSED_HEADERS,
|
||||
maxAge: 86400,
|
||||
optionsSuccessStatus: 204,
|
||||
});
|
||||
|
||||
if (config.server.path) {
|
||||
app.setGlobalPrefix(config.server.path);
|
||||
@@ -74,8 +99,6 @@ export async function run() {
|
||||
});
|
||||
}
|
||||
|
||||
const url = app.get(URLHelper);
|
||||
|
||||
await app.listen(config.server.port, config.server.listenAddr);
|
||||
|
||||
const formattedAddr = config.server.listenAddr.includes(':')
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Button } from '@affine/component';
|
||||
import { Button, notify } from '@affine/component';
|
||||
import {
|
||||
AuthContainer,
|
||||
AuthContent,
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
} from '@affine/component/auth-components';
|
||||
import { useAsyncCallback } from '@affine/core/components/hooks/affine-async-hooks';
|
||||
import { ServersService } from '@affine/core/modules/cloud';
|
||||
import { UserFriendlyError } from '@affine/error';
|
||||
import { Trans, useI18n } from '@affine/i18n';
|
||||
import { useService } from '@toeverything/infra';
|
||||
import {
|
||||
@@ -35,12 +36,14 @@ export const AddSelfhostedStep = ({
|
||||
state: SignInState;
|
||||
changeState: Dispatch<SetStateAction<SignInState>>;
|
||||
}) => {
|
||||
const t = useI18n();
|
||||
const serversService = useService(ServersService);
|
||||
const [baseURL, setBaseURL] = useState(state.initialServerBaseUrl ?? '');
|
||||
const [isConnecting, setIsConnecting] = useState(false);
|
||||
const [error, setError] = useState<boolean>(false);
|
||||
|
||||
const t = useI18n();
|
||||
const [errorHint, setErrorHint] = useState(
|
||||
t['com.affine.auth.sign.add-selfhosted.error']()
|
||||
);
|
||||
|
||||
const urlValid = useMemo(() => {
|
||||
try {
|
||||
@@ -51,10 +54,14 @@ export const AddSelfhostedStep = ({
|
||||
}
|
||||
}, [baseURL]);
|
||||
|
||||
const onBaseURLChange = useCallback((value: string) => {
|
||||
setBaseURL(value);
|
||||
setError(false);
|
||||
}, []);
|
||||
const onBaseURLChange = useCallback(
|
||||
(value: string) => {
|
||||
setBaseURL(value);
|
||||
setError(false);
|
||||
setErrorHint(t['com.affine.auth.sign.add-selfhosted.error']());
|
||||
},
|
||||
[t]
|
||||
);
|
||||
|
||||
const onConnect = useAsyncCallback(async () => {
|
||||
setIsConnecting(true);
|
||||
@@ -69,11 +76,33 @@ export const AddSelfhostedStep = ({
|
||||
}));
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
const userFriendlyError = UserFriendlyError.fromAny(err);
|
||||
setError(true);
|
||||
}
|
||||
if (userFriendlyError.is('TOO_MANY_REQUEST')) {
|
||||
setErrorHint(t['error.TOO_MANY_REQUEST']());
|
||||
} else if (
|
||||
userFriendlyError.is('NETWORK_ERROR') ||
|
||||
userFriendlyError.is('REQUEST_ABORTED')
|
||||
) {
|
||||
setErrorHint(t['error.NETWORK_ERROR']());
|
||||
} else {
|
||||
setErrorHint(t['com.affine.auth.sign.add-selfhosted.error']());
|
||||
}
|
||||
|
||||
setIsConnecting(false);
|
||||
}, [baseURL, changeState, serversService]);
|
||||
notify.error({
|
||||
title: t['com.affine.auth.toast.title.failed'](),
|
||||
message:
|
||||
userFriendlyError.is('REQUEST_ABORTED') ||
|
||||
userFriendlyError.is('NETWORK_ERROR')
|
||||
? t['error.NETWORK_ERROR']()
|
||||
: userFriendlyError.is('TOO_MANY_REQUEST')
|
||||
? t['error.TOO_MANY_REQUEST']()
|
||||
: t[`error.${userFriendlyError.name}`](userFriendlyError.data),
|
||||
});
|
||||
} finally {
|
||||
setIsConnecting(false);
|
||||
}
|
||||
}, [baseURL, changeState, serversService, t]);
|
||||
|
||||
useEffect(() => {
|
||||
if (state.initialServerBaseUrl) {
|
||||
@@ -101,7 +130,7 @@ export const AddSelfhostedStep = ({
|
||||
placeholder="https://your-server.com"
|
||||
error={!!error}
|
||||
disabled={isConnecting}
|
||||
errorHint={t['com.affine.auth.sign.add-selfhosted.error']()}
|
||||
errorHint={errorHint}
|
||||
onEnter={onConnect}
|
||||
/>
|
||||
<Button
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
} from '@affine/core/modules/cloud';
|
||||
import type { AuthSessionStatus } from '@affine/core/modules/cloud/entities/session';
|
||||
import { Unreachable } from '@affine/env/constant';
|
||||
import { UserFriendlyError } from '@affine/error';
|
||||
import { ServerDeploymentType } from '@affine/graphql';
|
||||
import { useI18n } from '@affine/i18n';
|
||||
import { useLiveData, useService } from '@toeverything/infra';
|
||||
@@ -46,6 +47,7 @@ export const SignInWithPasswordStep = ({
|
||||
|
||||
const [password, setPassword] = useState('');
|
||||
const [passwordError, setPasswordError] = useState(false);
|
||||
const [passwordErrorHint, setPasswordErrorHint] = useState('');
|
||||
const captchaService = useService(CaptchaService);
|
||||
const serverService = useService(ServerService);
|
||||
const isSelfhosted = useLiveData(
|
||||
@@ -74,6 +76,10 @@ export const SignInWithPasswordStep = ({
|
||||
onAuthenticated?.(loginStatus);
|
||||
}, [loginStatus, onAuthenticated, t]);
|
||||
|
||||
useEffect(() => {
|
||||
setPasswordErrorHint(t['com.affine.auth.password.error']());
|
||||
}, [t]);
|
||||
|
||||
const onSignIn = useAsyncCallback(async () => {
|
||||
if (isLoading || (!verifyToken && needCaptcha)) return;
|
||||
setIsLoading(true);
|
||||
@@ -88,7 +94,23 @@ export const SignInWithPasswordStep = ({
|
||||
});
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
setPasswordError(true);
|
||||
const error = UserFriendlyError.fromAny(err);
|
||||
|
||||
if (
|
||||
error.is('WRONG_SIGN_IN_CREDENTIALS') ||
|
||||
error.is('PASSWORD_REQUIRED')
|
||||
) {
|
||||
setPasswordError(true);
|
||||
setPasswordErrorHint(t['com.affine.auth.password.error']());
|
||||
} else {
|
||||
setPasswordError(false);
|
||||
notify.error({
|
||||
title: t['com.affine.auth.toast.title.failed'](),
|
||||
message: error.is('REQUEST_ABORTED')
|
||||
? t['error.NETWORK_ERROR']()
|
||||
: t[`error.${error.name}`](error.data),
|
||||
});
|
||||
}
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
@@ -101,6 +123,7 @@ export const SignInWithPasswordStep = ({
|
||||
email,
|
||||
password,
|
||||
challenge,
|
||||
t,
|
||||
]);
|
||||
|
||||
const sendMagicLink = useCallback(() => {
|
||||
@@ -126,11 +149,15 @@ export const SignInWithPasswordStep = ({
|
||||
label={t['com.affine.auth.password']()}
|
||||
value={password}
|
||||
type="password"
|
||||
onChange={useCallback((value: string) => {
|
||||
onChange={(value: string) => {
|
||||
setPassword(value);
|
||||
}, [])}
|
||||
if (passwordError) {
|
||||
setPasswordError(false);
|
||||
setPasswordErrorHint(t['com.affine.auth.password.error']());
|
||||
}
|
||||
}}
|
||||
error={passwordError}
|
||||
errorHint={t['com.affine.auth.password.error']()}
|
||||
errorHint={passwordErrorHint}
|
||||
onEnter={onSignIn}
|
||||
/>
|
||||
{!isSelfhosted && (
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { UserFriendlyError } from '@affine/error';
|
||||
import {
|
||||
gqlFetcherFactory,
|
||||
type OauthProvidersQuery,
|
||||
@@ -11,6 +12,45 @@ import { Store } from '@toeverything/infra';
|
||||
export type ServerConfigType = ServerConfigQuery['serverConfig'] &
|
||||
OauthProvidersQuery['serverConfig'];
|
||||
|
||||
const NETWORK_ERROR_PATTERNS = [
|
||||
/failed to fetch/i,
|
||||
/network request failed/i,
|
||||
/fetch failed/i,
|
||||
/load failed/i,
|
||||
/networkerror/i,
|
||||
/cors/i,
|
||||
/certificate/i,
|
||||
/ssl/i,
|
||||
/err_[a-z_]+/i,
|
||||
];
|
||||
|
||||
function mapServerConfigError(error: unknown) {
|
||||
const userFriendlyError = UserFriendlyError.fromAny(error);
|
||||
if (
|
||||
userFriendlyError.is('NETWORK_ERROR') ||
|
||||
userFriendlyError.is('REQUEST_ABORTED') ||
|
||||
userFriendlyError.is('TOO_MANY_REQUEST')
|
||||
) {
|
||||
return userFriendlyError;
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
const detail = `${error.name}: ${error.message}`;
|
||||
if (NETWORK_ERROR_PATTERNS.some(pattern => pattern.test(detail))) {
|
||||
return new UserFriendlyError({
|
||||
status: 504,
|
||||
code: 'NETWORK_ERROR',
|
||||
type: 'NETWORK_ERROR',
|
||||
name: 'NETWORK_ERROR',
|
||||
message: detail,
|
||||
stacktrace: error.stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return userFriendlyError;
|
||||
}
|
||||
|
||||
export class ServerConfigStore extends Store {
|
||||
constructor() {
|
||||
super();
|
||||
@@ -20,19 +60,13 @@ export class ServerConfigStore extends Store {
|
||||
serverBaseUrl: string,
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<ServerConfigType> {
|
||||
const gql = gqlFetcherFactory(`${serverBaseUrl}/graphql`, globalThis.fetch);
|
||||
const serverConfigData = await gql({
|
||||
query: serverConfigQuery,
|
||||
context: {
|
||||
signal: abortSignal,
|
||||
headers: {
|
||||
'x-affine-version': BUILD_CONFIG.appVersion,
|
||||
},
|
||||
},
|
||||
});
|
||||
if (serverConfigData.serverConfig.features.includes(ServerFeature.OAuth)) {
|
||||
const oauthProvidersData = await gql({
|
||||
query: oauthProvidersQuery,
|
||||
try {
|
||||
const gql = gqlFetcherFactory(
|
||||
`${serverBaseUrl}/graphql`,
|
||||
globalThis.fetch
|
||||
);
|
||||
const serverConfigData = await gql({
|
||||
query: serverConfigQuery,
|
||||
context: {
|
||||
signal: abortSignal,
|
||||
headers: {
|
||||
@@ -40,11 +74,26 @@ export class ServerConfigStore extends Store {
|
||||
},
|
||||
},
|
||||
});
|
||||
return {
|
||||
...serverConfigData.serverConfig,
|
||||
...oauthProvidersData.serverConfig,
|
||||
};
|
||||
if (
|
||||
serverConfigData.serverConfig.features.includes(ServerFeature.OAuth)
|
||||
) {
|
||||
const oauthProvidersData = await gql({
|
||||
query: oauthProvidersQuery,
|
||||
context: {
|
||||
signal: abortSignal,
|
||||
headers: {
|
||||
'x-affine-version': BUILD_CONFIG.appVersion,
|
||||
},
|
||||
},
|
||||
});
|
||||
return {
|
||||
...serverConfigData.serverConfig,
|
||||
...oauthProvidersData.serverConfig,
|
||||
};
|
||||
}
|
||||
return { ...serverConfigData.serverConfig, oauthProviders: [] };
|
||||
} catch (error) {
|
||||
throw mapServerConfigError(error);
|
||||
}
|
||||
return { ...serverConfigData.serverConfig, oauthProviders: [] };
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user