mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-04 08:38:34 +00:00
Compare commits
29 Commits
renovate/n
...
76eefcb4f3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76eefcb4f3 | ||
|
|
18b8c7831f | ||
|
|
ae26418281 | ||
|
|
de29e8300a | ||
|
|
e2b26ffb0c | ||
|
|
63e602a6f5 | ||
|
|
12f0a9ae62 | ||
|
|
73d4da192d | ||
|
|
0b648f8613 | ||
|
|
516d72e83f | ||
|
|
a27f8b168a | ||
|
|
7040fe3e75 | ||
|
|
a8211b2e00 | ||
|
|
cce6122a63 | ||
|
|
40a2518ff9 | ||
|
|
345f45d327 | ||
|
|
1f94d7d1bc | ||
|
|
f1a6e409cb | ||
|
|
059d3aa04a | ||
|
|
948951d461 | ||
|
|
0f0bfb9f06 | ||
|
|
b778207af9 | ||
|
|
888f1f39db | ||
|
|
b49e48b467 | ||
|
|
759aa1b684 | ||
|
|
5041578768 | ||
|
|
b8f626513f | ||
|
|
3b4b0bad22 | ||
|
|
d8404e9df8 |
@@ -337,8 +337,42 @@
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "The config for the s3 compatible storage provider. directly passed to aws-sdk client.\n@link https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html",
|
||||
"description": "The config for the S3 compatible storage provider.",
|
||||
"properties": {
|
||||
"endpoint": {
|
||||
"type": "string",
|
||||
"description": "The S3 compatible endpoint. Example: \"https://s3.us-east-1.amazonaws.com\" or \"https://<account>.r2.cloudflarestorage.com\"."
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "The region for the storage provider. Example: \"us-east-1\" or \"auto\" for R2."
|
||||
},
|
||||
"forcePathStyle": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to use path-style bucket addressing."
|
||||
},
|
||||
"requestTimeoutMs": {
|
||||
"type": "number",
|
||||
"description": "Request timeout in milliseconds."
|
||||
},
|
||||
"minPartSize": {
|
||||
"type": "number",
|
||||
"description": "Minimum multipart part size in bytes."
|
||||
},
|
||||
"presign": {
|
||||
"type": "object",
|
||||
"description": "Presigned URL behavior configuration.",
|
||||
"properties": {
|
||||
"expiresInSeconds": {
|
||||
"type": "number",
|
||||
"description": "Expiration time in seconds for presigned URLs."
|
||||
},
|
||||
"signContentTypeForPut": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to sign Content-Type for presigned PUT."
|
||||
}
|
||||
}
|
||||
},
|
||||
"credentials": {
|
||||
"type": "object",
|
||||
"description": "The credentials for the s3 compatible storage provider.",
|
||||
@@ -348,6 +382,9 @@
|
||||
},
|
||||
"secretAccessKey": {
|
||||
"type": "string"
|
||||
},
|
||||
"sessionToken": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -369,8 +406,42 @@
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "The config for the s3 compatible storage provider. directly passed to aws-sdk client.\n@link https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html",
|
||||
"description": "The config for the S3 compatible storage provider.",
|
||||
"properties": {
|
||||
"endpoint": {
|
||||
"type": "string",
|
||||
"description": "The S3 compatible endpoint. Example: \"https://s3.us-east-1.amazonaws.com\" or \"https://<account>.r2.cloudflarestorage.com\"."
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "The region for the storage provider. Example: \"us-east-1\" or \"auto\" for R2."
|
||||
},
|
||||
"forcePathStyle": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to use path-style bucket addressing."
|
||||
},
|
||||
"requestTimeoutMs": {
|
||||
"type": "number",
|
||||
"description": "Request timeout in milliseconds."
|
||||
},
|
||||
"minPartSize": {
|
||||
"type": "number",
|
||||
"description": "Minimum multipart part size in bytes."
|
||||
},
|
||||
"presign": {
|
||||
"type": "object",
|
||||
"description": "Presigned URL behavior configuration.",
|
||||
"properties": {
|
||||
"expiresInSeconds": {
|
||||
"type": "number",
|
||||
"description": "Expiration time in seconds for presigned URLs."
|
||||
},
|
||||
"signContentTypeForPut": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to sign Content-Type for presigned PUT."
|
||||
}
|
||||
}
|
||||
},
|
||||
"credentials": {
|
||||
"type": "object",
|
||||
"description": "The credentials for the s3 compatible storage provider.",
|
||||
@@ -380,6 +451,9 @@
|
||||
},
|
||||
"secretAccessKey": {
|
||||
"type": "string"
|
||||
},
|
||||
"sessionToken": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -458,8 +532,42 @@
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "The config for the s3 compatible storage provider. directly passed to aws-sdk client.\n@link https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html",
|
||||
"description": "The config for the S3 compatible storage provider.",
|
||||
"properties": {
|
||||
"endpoint": {
|
||||
"type": "string",
|
||||
"description": "The S3 compatible endpoint. Example: \"https://s3.us-east-1.amazonaws.com\" or \"https://<account>.r2.cloudflarestorage.com\"."
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "The region for the storage provider. Example: \"us-east-1\" or \"auto\" for R2."
|
||||
},
|
||||
"forcePathStyle": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to use path-style bucket addressing."
|
||||
},
|
||||
"requestTimeoutMs": {
|
||||
"type": "number",
|
||||
"description": "Request timeout in milliseconds."
|
||||
},
|
||||
"minPartSize": {
|
||||
"type": "number",
|
||||
"description": "Minimum multipart part size in bytes."
|
||||
},
|
||||
"presign": {
|
||||
"type": "object",
|
||||
"description": "Presigned URL behavior configuration.",
|
||||
"properties": {
|
||||
"expiresInSeconds": {
|
||||
"type": "number",
|
||||
"description": "Expiration time in seconds for presigned URLs."
|
||||
},
|
||||
"signContentTypeForPut": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to sign Content-Type for presigned PUT."
|
||||
}
|
||||
}
|
||||
},
|
||||
"credentials": {
|
||||
"type": "object",
|
||||
"description": "The credentials for the s3 compatible storage provider.",
|
||||
@@ -469,6 +577,9 @@
|
||||
},
|
||||
"secretAccessKey": {
|
||||
"type": "string"
|
||||
},
|
||||
"sessionToken": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -490,8 +601,42 @@
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "The config for the s3 compatible storage provider. directly passed to aws-sdk client.\n@link https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html",
|
||||
"description": "The config for the S3 compatible storage provider.",
|
||||
"properties": {
|
||||
"endpoint": {
|
||||
"type": "string",
|
||||
"description": "The S3 compatible endpoint. Example: \"https://s3.us-east-1.amazonaws.com\" or \"https://<account>.r2.cloudflarestorage.com\"."
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "The region for the storage provider. Example: \"us-east-1\" or \"auto\" for R2."
|
||||
},
|
||||
"forcePathStyle": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to use path-style bucket addressing."
|
||||
},
|
||||
"requestTimeoutMs": {
|
||||
"type": "number",
|
||||
"description": "Request timeout in milliseconds."
|
||||
},
|
||||
"minPartSize": {
|
||||
"type": "number",
|
||||
"description": "Minimum multipart part size in bytes."
|
||||
},
|
||||
"presign": {
|
||||
"type": "object",
|
||||
"description": "Presigned URL behavior configuration.",
|
||||
"properties": {
|
||||
"expiresInSeconds": {
|
||||
"type": "number",
|
||||
"description": "Expiration time in seconds for presigned URLs."
|
||||
},
|
||||
"signContentTypeForPut": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to sign Content-Type for presigned PUT."
|
||||
}
|
||||
}
|
||||
},
|
||||
"credentials": {
|
||||
"type": "object",
|
||||
"description": "The credentials for the s3 compatible storage provider.",
|
||||
@@ -501,6 +646,9 @@
|
||||
},
|
||||
"secretAccessKey": {
|
||||
"type": "string"
|
||||
},
|
||||
"sessionToken": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -684,8 +832,8 @@
|
||||
},
|
||||
"versionControl.requiredVersion": {
|
||||
"type": "string",
|
||||
"description": "Allowed version range of the app that allowed to access the server. Requires 'client/versionControl.enabled' to be true to take effect.\n@default \">=0.20.0\"",
|
||||
"default": ">=0.20.0"
|
||||
"description": "Allowed version range of the app that allowed to access the server. Requires 'client/versionControl.enabled' to be true to take effect.\n@default \">=0.25.0\"",
|
||||
"default": ">=0.25.0"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -941,8 +1089,42 @@
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "The config for the s3 compatible storage provider. directly passed to aws-sdk client.\n@link https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html",
|
||||
"description": "The config for the S3 compatible storage provider.",
|
||||
"properties": {
|
||||
"endpoint": {
|
||||
"type": "string",
|
||||
"description": "The S3 compatible endpoint. Example: \"https://s3.us-east-1.amazonaws.com\" or \"https://<account>.r2.cloudflarestorage.com\"."
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "The region for the storage provider. Example: \"us-east-1\" or \"auto\" for R2."
|
||||
},
|
||||
"forcePathStyle": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to use path-style bucket addressing."
|
||||
},
|
||||
"requestTimeoutMs": {
|
||||
"type": "number",
|
||||
"description": "Request timeout in milliseconds."
|
||||
},
|
||||
"minPartSize": {
|
||||
"type": "number",
|
||||
"description": "Minimum multipart part size in bytes."
|
||||
},
|
||||
"presign": {
|
||||
"type": "object",
|
||||
"description": "Presigned URL behavior configuration.",
|
||||
"properties": {
|
||||
"expiresInSeconds": {
|
||||
"type": "number",
|
||||
"description": "Expiration time in seconds for presigned URLs."
|
||||
},
|
||||
"signContentTypeForPut": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to sign Content-Type for presigned PUT."
|
||||
}
|
||||
}
|
||||
},
|
||||
"credentials": {
|
||||
"type": "object",
|
||||
"description": "The credentials for the s3 compatible storage provider.",
|
||||
@@ -952,6 +1134,9 @@
|
||||
},
|
||||
"secretAccessKey": {
|
||||
"type": "string"
|
||||
},
|
||||
"sessionToken": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -973,8 +1158,42 @@
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "The config for the s3 compatible storage provider. directly passed to aws-sdk client.\n@link https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html",
|
||||
"description": "The config for the S3 compatible storage provider.",
|
||||
"properties": {
|
||||
"endpoint": {
|
||||
"type": "string",
|
||||
"description": "The S3 compatible endpoint. Example: \"https://s3.us-east-1.amazonaws.com\" or \"https://<account>.r2.cloudflarestorage.com\"."
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "The region for the storage provider. Example: \"us-east-1\" or \"auto\" for R2."
|
||||
},
|
||||
"forcePathStyle": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to use path-style bucket addressing."
|
||||
},
|
||||
"requestTimeoutMs": {
|
||||
"type": "number",
|
||||
"description": "Request timeout in milliseconds."
|
||||
},
|
||||
"minPartSize": {
|
||||
"type": "number",
|
||||
"description": "Minimum multipart part size in bytes."
|
||||
},
|
||||
"presign": {
|
||||
"type": "object",
|
||||
"description": "Presigned URL behavior configuration.",
|
||||
"properties": {
|
||||
"expiresInSeconds": {
|
||||
"type": "number",
|
||||
"description": "Expiration time in seconds for presigned URLs."
|
||||
},
|
||||
"signContentTypeForPut": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to sign Content-Type for presigned PUT."
|
||||
}
|
||||
}
|
||||
},
|
||||
"credentials": {
|
||||
"type": "object",
|
||||
"description": "The credentials for the s3 compatible storage provider.",
|
||||
@@ -984,6 +1203,9 @@
|
||||
},
|
||||
"secretAccessKey": {
|
||||
"type": "string"
|
||||
},
|
||||
"sessionToken": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
2
.github/workflows/release-mobile.yml
vendored
2
.github/workflows/release-mobile.yml
vendored
@@ -112,7 +112,7 @@ jobs:
|
||||
enableScripts: false
|
||||
- uses: maxim-lobanov/setup-xcode@v1
|
||||
with:
|
||||
xcode-version: 16.4
|
||||
xcode-version: 26.2
|
||||
- name: Install Swiftformat
|
||||
run: brew install swiftformat
|
||||
- name: Cap sync
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -47,8 +47,7 @@ testem.log
|
||||
.pnpm-debug.log
|
||||
/typings
|
||||
tsconfig.tsbuildinfo
|
||||
rfc*.md
|
||||
todo.md
|
||||
.context
|
||||
|
||||
# System Files
|
||||
.DS_Store
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -44,6 +44,7 @@ dependencies = [
|
||||
"docx-parser",
|
||||
"infer",
|
||||
"nanoid",
|
||||
"napi",
|
||||
"path-ext",
|
||||
"pdf-extract",
|
||||
"pulldown-cmark",
|
||||
@@ -126,6 +127,7 @@ dependencies = [
|
||||
"affine_nbstore",
|
||||
"affine_sqlite_v1",
|
||||
"chrono",
|
||||
"mimalloc",
|
||||
"napi",
|
||||
"napi-build",
|
||||
"napi-derive",
|
||||
|
||||
@@ -35,9 +35,28 @@ export async function printToPdf(
|
||||
overflow: initial !important;
|
||||
print-color-adjust: exact;
|
||||
-webkit-print-color-adjust: exact;
|
||||
color: #000 !important;
|
||||
background: #fff !important;
|
||||
color-scheme: light !important;
|
||||
}
|
||||
::-webkit-scrollbar {
|
||||
display: none;
|
||||
::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
:root, body {
|
||||
--affine-text-primary: #000 !important;
|
||||
--affine-text-secondary: #111 !important;
|
||||
--affine-text-tertiary: #333 !important;
|
||||
--affine-background-primary: #fff !important;
|
||||
--affine-background-secondary: #fff !important;
|
||||
--affine-background-tertiary: #fff !important;
|
||||
}
|
||||
body, [data-theme='dark'] {
|
||||
color: #000 !important;
|
||||
background: #fff !important;
|
||||
}
|
||||
body * {
|
||||
color: #000 !important;
|
||||
-webkit-text-fill-color: #000 !important;
|
||||
}
|
||||
:root {
|
||||
--affine-note-shadow-box: none !important;
|
||||
@@ -95,6 +114,14 @@ export async function printToPdf(
|
||||
true
|
||||
) as HTMLDivElement;
|
||||
|
||||
// force light theme in print iframe
|
||||
iframe.contentWindow.document.documentElement.setAttribute(
|
||||
'data-theme',
|
||||
'light'
|
||||
);
|
||||
iframe.contentWindow.document.body.setAttribute('data-theme', 'light');
|
||||
importedRoot.setAttribute('data-theme', 'light');
|
||||
|
||||
// draw saved canvas image to canvas
|
||||
const allImportedCanvas = importedRoot.getElementsByTagName('canvas');
|
||||
for (const importedCanvas of allImportedCanvas) {
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { getInlineEditorByModel } from '@blocksuite/affine-rich-text';
|
||||
import type { AffineInlineEditor } from '@blocksuite/affine-shared/types';
|
||||
import { DisposableGroup } from '@blocksuite/global/disposable';
|
||||
import type { UIEventStateContext } from '@blocksuite/std';
|
||||
import { TextSelection, WidgetComponent } from '@blocksuite/std';
|
||||
import {
|
||||
TextSelection,
|
||||
type UIEventStateContext,
|
||||
WidgetComponent,
|
||||
} from '@blocksuite/std';
|
||||
import { InlineEditor } from '@blocksuite/std/inline';
|
||||
import debounce from 'lodash-es/debounce';
|
||||
|
||||
@@ -59,9 +62,7 @@ const showSlashMenu = debounce(
|
||||
);
|
||||
|
||||
export class AffineSlashMenuWidget extends WidgetComponent {
|
||||
private readonly _getInlineEditor = (
|
||||
evt: KeyboardEvent | CompositionEvent
|
||||
) => {
|
||||
private readonly _getInlineEditor = (evt: CompositionEvent | InputEvent) => {
|
||||
if (evt.target instanceof HTMLElement) {
|
||||
const editor = (
|
||||
evt.target.closest('.inline-editor') as {
|
||||
@@ -152,18 +153,27 @@ export class AffineSlashMenuWidget extends WidgetComponent {
|
||||
this._handleInput(inlineEditor, true);
|
||||
};
|
||||
|
||||
private readonly _onKeyDown = (ctx: UIEventStateContext) => {
|
||||
const eventState = ctx.get('keyboardState');
|
||||
const event = eventState.raw;
|
||||
private readonly _onBeforeInput = (ctx: UIEventStateContext) => {
|
||||
const event = ctx.get('defaultState').event;
|
||||
if (!(event instanceof InputEvent)) return;
|
||||
|
||||
const key = event.key;
|
||||
// Skip non-character inputs and IME composition (handled by _onCompositionEnd)
|
||||
if (event.data === null || event.isComposing) return;
|
||||
|
||||
if (event.isComposing || key !== AFFINE_SLASH_MENU_TRIGGER_KEY) return;
|
||||
// Quick check: only proceed if the input contains the trigger key
|
||||
if (!event.data.includes(AFFINE_SLASH_MENU_TRIGGER_KEY)) return;
|
||||
|
||||
const inlineEditor = this._getInlineEditor(event);
|
||||
if (!inlineEditor) return;
|
||||
|
||||
this._handleInput(inlineEditor, false);
|
||||
// Wait for the input to be processed, then handle it
|
||||
// Pass true because after waitForUpdate(), the range is already synced
|
||||
inlineEditor
|
||||
.waitForUpdate()
|
||||
.then(() => {
|
||||
this._handleInput(inlineEditor, true);
|
||||
})
|
||||
.catch(console.error);
|
||||
};
|
||||
|
||||
get config() {
|
||||
@@ -177,8 +187,7 @@ export class AffineSlashMenuWidget extends WidgetComponent {
|
||||
override connectedCallback() {
|
||||
super.connectedCallback();
|
||||
|
||||
// this.handleEvent('beforeInput', this._onBeforeInput);
|
||||
this.handleEvent('keyDown', this._onKeyDown);
|
||||
this.handleEvent('beforeInput', this._onBeforeInput);
|
||||
this.handleEvent('compositionEnd', this._onCompositionEnd);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ crate-type = ["cdylib"]
|
||||
affine_common = { workspace = true, features = [
|
||||
"doc-loader",
|
||||
"hashcash",
|
||||
"napi",
|
||||
"ydoc-loader",
|
||||
] }
|
||||
chrono = { workspace = true }
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use affine_common::doc_parser::{self, BlockInfo, CrawlResult, MarkdownResult, PageDocContent, WorkspaceDocContent};
|
||||
use affine_common::{
|
||||
doc_parser::{self, BlockInfo, CrawlResult, MarkdownResult, PageDocContent, WorkspaceDocContent},
|
||||
napi_utils::map_napi_err,
|
||||
};
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
|
||||
@@ -95,22 +98,25 @@ impl From<CrawlResult> for NativeCrawlResult {
|
||||
|
||||
#[napi]
|
||||
pub fn parse_doc_from_binary(doc_bin: Buffer, doc_id: String) -> Result<NativeCrawlResult> {
|
||||
let result = doc_parser::parse_doc_from_binary(doc_bin.into(), doc_id)
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::parse_doc_from_binary(doc_bin.into(), doc_id),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(result.into())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn parse_page_doc(doc_bin: Buffer, max_summary_length: Option<i32>) -> Result<Option<NativePageDocContent>> {
|
||||
let result = doc_parser::parse_page_doc(doc_bin.into(), max_summary_length.map(|v| v as isize))
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::parse_page_doc(doc_bin.into(), max_summary_length.map(|v| v as isize)),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(result.map(Into::into))
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn parse_workspace_doc(doc_bin: Buffer) -> Result<Option<NativeWorkspaceDocContent>> {
|
||||
let result =
|
||||
doc_parser::parse_workspace_doc(doc_bin.into()).map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(doc_parser::parse_workspace_doc(doc_bin.into()), Status::GenericFailure)?;
|
||||
Ok(result.map(Into::into))
|
||||
}
|
||||
|
||||
@@ -121,15 +127,19 @@ pub fn parse_doc_to_markdown(
|
||||
ai_editable: Option<bool>,
|
||||
doc_url_prefix: Option<String>,
|
||||
) -> Result<NativeMarkdownResult> {
|
||||
let result = doc_parser::parse_doc_to_markdown(doc_bin.into(), doc_id, ai_editable.unwrap_or(false), doc_url_prefix)
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::parse_doc_to_markdown(doc_bin.into(), doc_id, ai_editable.unwrap_or(false), doc_url_prefix),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(result.into())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn read_all_doc_ids_from_root_doc(doc_bin: Buffer, include_trash: Option<bool>) -> Result<Vec<String>> {
|
||||
let result = doc_parser::get_doc_ids_from_binary(doc_bin.into(), include_trash.unwrap_or(false))
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::get_doc_ids_from_binary(doc_bin.into(), include_trash.unwrap_or(false)),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
@@ -144,8 +154,10 @@ pub fn read_all_doc_ids_from_root_doc(doc_bin: Buffer, include_trash: Option<boo
|
||||
/// A Buffer containing the y-octo document update binary
|
||||
#[napi]
|
||||
pub fn create_doc_with_markdown(title: String, markdown: String, doc_id: String) -> Result<Buffer> {
|
||||
let result = doc_parser::build_full_doc(&title, &markdown, &doc_id)
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::build_full_doc(&title, &markdown, &doc_id),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(Buffer::from(result))
|
||||
}
|
||||
|
||||
@@ -161,8 +173,10 @@ pub fn create_doc_with_markdown(title: String, markdown: String, doc_id: String)
|
||||
/// A Buffer containing only the delta (changes) as a y-octo update binary
|
||||
#[napi]
|
||||
pub fn update_doc_with_markdown(existing_binary: Buffer, new_markdown: String, doc_id: String) -> Result<Buffer> {
|
||||
let result = doc_parser::update_doc(&existing_binary, &new_markdown, &doc_id)
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::update_doc(&existing_binary, &new_markdown, &doc_id),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(Buffer::from(result))
|
||||
}
|
||||
|
||||
@@ -177,8 +191,10 @@ pub fn update_doc_with_markdown(existing_binary: Buffer, new_markdown: String, d
|
||||
/// A Buffer containing only the delta (changes) as a y-octo update binary
|
||||
#[napi]
|
||||
pub fn update_doc_title(existing_binary: Buffer, title: String, doc_id: String) -> Result<Buffer> {
|
||||
let result = doc_parser::update_doc_title(&existing_binary, &doc_id, &title)
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::update_doc_title(&existing_binary, &doc_id, &title),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(Buffer::from(result))
|
||||
}
|
||||
|
||||
@@ -202,14 +218,16 @@ pub fn update_doc_properties(
|
||||
created_by: Option<String>,
|
||||
updated_by: Option<String>,
|
||||
) -> Result<Buffer> {
|
||||
let result = doc_parser::update_doc_properties(
|
||||
&existing_binary,
|
||||
&properties_doc_id,
|
||||
&target_doc_id,
|
||||
created_by.as_deref(),
|
||||
updated_by.as_deref(),
|
||||
)
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::update_doc_properties(
|
||||
&existing_binary,
|
||||
&properties_doc_id,
|
||||
&target_doc_id,
|
||||
created_by.as_deref(),
|
||||
updated_by.as_deref(),
|
||||
),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(Buffer::from(result))
|
||||
}
|
||||
|
||||
@@ -225,8 +243,10 @@ pub fn update_doc_properties(
|
||||
/// A Buffer containing the y-octo update binary to apply to the root doc
|
||||
#[napi]
|
||||
pub fn add_doc_to_root_doc(root_doc_bin: Buffer, doc_id: String, title: Option<String>) -> Result<Buffer> {
|
||||
let result = doc_parser::add_doc_to_root_doc(root_doc_bin.into(), &doc_id, title.as_deref())
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::add_doc_to_root_doc(root_doc_bin.into(), &doc_id, title.as_deref()),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(Buffer::from(result))
|
||||
}
|
||||
|
||||
@@ -241,7 +261,9 @@ pub fn add_doc_to_root_doc(root_doc_bin: Buffer, doc_id: String, title: Option<S
|
||||
/// A Buffer containing the y-octo update binary to apply to the root doc
|
||||
#[napi]
|
||||
pub fn update_root_doc_meta_title(root_doc_bin: Buffer, doc_id: String, title: String) -> Result<Buffer> {
|
||||
let result = doc_parser::update_root_doc_meta_title(&root_doc_bin, &doc_id, &title)
|
||||
.map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?;
|
||||
let result = map_napi_err(
|
||||
doc_parser::update_root_doc_meta_title(&root_doc_bin, &doc_id, &title),
|
||||
Status::GenericFailure,
|
||||
)?;
|
||||
Ok(Buffer::from(result))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use affine_common::doc_loader::Doc;
|
||||
use affine_common::{doc_loader::Doc, napi_utils::map_napi_err};
|
||||
use napi::{
|
||||
Env, Result, Task,
|
||||
anyhow::anyhow,
|
||||
Env, Result, Status, Task,
|
||||
bindgen_prelude::{AsyncTask, Buffer},
|
||||
};
|
||||
|
||||
@@ -54,7 +53,7 @@ impl Task for AsyncParseDocResponse {
|
||||
type JsValue = ParsedDoc;
|
||||
|
||||
fn compute(&mut self) -> Result<Self::Output> {
|
||||
let doc = Doc::new(&self.file_path, &self.doc).map_err(|e| anyhow!(e))?;
|
||||
let doc = map_napi_err(Doc::new(&self.file_path, &self.doc), Status::GenericFailure)?;
|
||||
Ok(Document { inner: doc })
|
||||
}
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@ pub mod hashcash;
|
||||
pub mod html_sanitize;
|
||||
pub mod tiktoken;
|
||||
|
||||
use std::fmt::{Debug, Display};
|
||||
|
||||
use napi::{Error, Result, Status, bindgen_prelude::*};
|
||||
use affine_common::napi_utils::map_napi_err;
|
||||
use napi::{Result, Status, bindgen_prelude::*};
|
||||
use y_octo::Doc;
|
||||
|
||||
#[cfg(not(target_arch = "arm"))]
|
||||
@@ -21,35 +20,16 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
#[macro_use]
|
||||
extern crate napi_derive;
|
||||
|
||||
fn map_err_inner<T, E: Display + Debug>(v: std::result::Result<T, E>, status: Status) -> Result<T> {
|
||||
match v {
|
||||
Ok(val) => Ok(val),
|
||||
Err(e) => {
|
||||
dbg!(&e);
|
||||
Err(Error::new(status, e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! map_err {
|
||||
($val: expr) => {
|
||||
map_err_inner($val, Status::GenericFailure)
|
||||
};
|
||||
($val: expr, $stauts: ident) => {
|
||||
map_err_inner($val, $stauts)
|
||||
};
|
||||
}
|
||||
|
||||
/// Merge updates in form like `Y.applyUpdate(doc, update)` way and return the
|
||||
/// result binary.
|
||||
#[napi(catch_unwind)]
|
||||
pub fn merge_updates_in_apply_way(updates: Vec<Buffer>) -> Result<Buffer> {
|
||||
let mut doc = Doc::default();
|
||||
for update in updates {
|
||||
map_err!(doc.apply_update_from_binary_v1(update.as_ref()))?;
|
||||
map_napi_err(doc.apply_update_from_binary_v1(update.as_ref()), Status::GenericFailure)?;
|
||||
}
|
||||
|
||||
let buf = map_err!(doc.encode_update_v1())?;
|
||||
let buf = map_napi_err(doc.encode_update_v1(), Status::GenericFailure)?;
|
||||
|
||||
Ok(buf.into())
|
||||
}
|
||||
@@ -59,3 +39,17 @@ pub const AFFINE_PRO_PUBLIC_KEY: Option<&'static str> = std::option_env!("AFFINE
|
||||
|
||||
#[napi]
|
||||
pub const AFFINE_PRO_LICENSE_AES_KEY: Option<&'static str> = std::option_env!("AFFINE_PRO_LICENSE_AES_KEY");
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn merge_updates_reports_generic_failure() {
|
||||
let err = match merge_updates_in_apply_way(vec![Buffer::from(vec![0])]) {
|
||||
Ok(_) => panic!("expected error"),
|
||||
Err(err) => err,
|
||||
};
|
||||
assert_eq!(err.status, Status::GenericFailure);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
DO $$
|
||||
DECLARE error_message TEXT;
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto";
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
error_message := 'pgcrypto extension not found. access_tokens.token will not be hashed automatically.' || E'\n' ||
|
||||
'Tokens will be lazily migrated on use.';
|
||||
RAISE WARNING '%', error_message;
|
||||
END;
|
||||
END IF;
|
||||
|
||||
IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
|
||||
UPDATE "access_tokens"
|
||||
SET "token" = encode(digest("token", 'sha256'), 'hex')
|
||||
WHERE substr("token", 1, 3) = 'ut_';
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "magic_link_otps" (
|
||||
"id" VARCHAR NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"otp_hash" VARCHAR NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"client_nonce" TEXT,
|
||||
"attempts" INTEGER NOT NULL DEFAULT 0,
|
||||
"expires_at" TIMESTAMPTZ(3) NOT NULL,
|
||||
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "magic_link_otps_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "magic_link_otps_email_key" ON "magic_link_otps"("email");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "magic_link_otps_expires_at_idx" ON "magic_link_otps"("expires_at");
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
"postinstall": "prisma generate"
|
||||
},
|
||||
"dependencies": {
|
||||
"@affine/s3-compat": "workspace:*",
|
||||
"@affine/server-native": "workspace:*",
|
||||
"@ai-sdk/anthropic": "^2.0.54",
|
||||
"@ai-sdk/google": "^2.0.45",
|
||||
@@ -34,8 +35,6 @@
|
||||
"@ai-sdk/openai-compatible": "^1.0.28",
|
||||
"@ai-sdk/perplexity": "^2.0.21",
|
||||
"@apollo/server": "^4.12.2",
|
||||
"@aws-sdk/client-s3": "^3.948.0",
|
||||
"@aws-sdk/s3-request-presigner": "^3.948.0",
|
||||
"@fal-ai/serverless-client": "^0.15.0",
|
||||
"@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0",
|
||||
"@google-cloud/opentelemetry-resource-util": "^3.0.0",
|
||||
@@ -84,7 +83,7 @@
|
||||
"eventemitter2": "^6.4.9",
|
||||
"exa-js": "^1.6.13",
|
||||
"express": "^5.0.1",
|
||||
"fast-xml-parser": "^5.0.0",
|
||||
"fast-xml-parser": "^5.3.4",
|
||||
"get-stream": "^9.0.1",
|
||||
"google-auth-library": "^10.2.0",
|
||||
"graphql": "^16.9.0",
|
||||
@@ -141,18 +140,19 @@
|
||||
"@types/mixpanel": "^2.14.9",
|
||||
"@types/mustache": "^4.2.5",
|
||||
"@types/node": "^22.0.0",
|
||||
"@types/nodemailer": "^6.4.17",
|
||||
"@types/nodemailer": "^7.0.0",
|
||||
"@types/on-headers": "^1.0.3",
|
||||
"@types/react": "^19.0.1",
|
||||
"@types/react-dom": "^19.0.2",
|
||||
"@types/semver": "^7.5.8",
|
||||
"@types/sinon": "^17.0.3",
|
||||
"@types/sinon": "^21.0.0",
|
||||
"@types/supertest": "^6.0.2",
|
||||
"ava": "^6.4.0",
|
||||
"c8": "^10.1.3",
|
||||
"nodemon": "^3.1.11",
|
||||
"react-email": "4.0.11",
|
||||
"sinon": "^21.0.1",
|
||||
"socket.io-client": "^4.8.3",
|
||||
"supertest": "^7.1.4",
|
||||
"why-is-node-running": "^3.2.2"
|
||||
},
|
||||
|
||||
@@ -106,6 +106,21 @@ model VerificationToken {
|
||||
@@map("verification_tokens")
|
||||
}
|
||||
|
||||
model MagicLinkOtp {
|
||||
id String @id @default(uuid()) @db.VarChar
|
||||
email String @unique @db.Text
|
||||
otpHash String @map("otp_hash") @db.VarChar
|
||||
token String @db.Text
|
||||
clientNonce String? @map("client_nonce") @db.Text
|
||||
attempts Int @default(0)
|
||||
expiresAt DateTime @map("expires_at") @db.Timestamptz(3)
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
|
||||
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
|
||||
|
||||
@@index([expiresAt])
|
||||
@@map("magic_link_otps")
|
||||
}
|
||||
|
||||
model Workspace {
|
||||
// NOTE: manually set this column type to identity in migration file
|
||||
sid Int @unique @default(autoincrement())
|
||||
|
||||
@@ -32,6 +32,16 @@ Generated by [AVA](https://avajs.dev).
|
||||
|
||||
> Snapshot 4
|
||||
|
||||
{
|
||||
code: 'Bad Request',
|
||||
message: 'Invalid header',
|
||||
name: 'BAD_REQUEST',
|
||||
status: 400,
|
||||
type: 'BAD_REQUEST',
|
||||
}
|
||||
|
||||
> Snapshot 5
|
||||
|
||||
Buffer @Uint8Array [
|
||||
66616b65 20696d61 6765
|
||||
]
|
||||
@@ -56,7 +66,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
|
||||
{
|
||||
code: 'Bad Request',
|
||||
message: 'Invalid URL',
|
||||
message: 'Invalid header',
|
||||
name: 'BAD_REQUEST',
|
||||
status: 400,
|
||||
type: 'BAD_REQUEST',
|
||||
@@ -64,6 +74,16 @@ Generated by [AVA](https://avajs.dev).
|
||||
|
||||
> Snapshot 4
|
||||
|
||||
{
|
||||
code: 'Bad Request',
|
||||
message: 'Invalid URL',
|
||||
name: 'BAD_REQUEST',
|
||||
status: 400,
|
||||
type: 'BAD_REQUEST',
|
||||
}
|
||||
|
||||
> Snapshot 5
|
||||
|
||||
{
|
||||
description: 'Test Description',
|
||||
favicons: [
|
||||
@@ -77,7 +97,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
videos: [],
|
||||
}
|
||||
|
||||
> Snapshot 5
|
||||
> Snapshot 6
|
||||
|
||||
{
|
||||
charset: 'gbk',
|
||||
@@ -90,7 +110,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
videos: [],
|
||||
}
|
||||
|
||||
> Snapshot 6
|
||||
> Snapshot 7
|
||||
|
||||
{
|
||||
charset: 'shift_jis',
|
||||
@@ -103,7 +123,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
videos: [],
|
||||
}
|
||||
|
||||
> Snapshot 7
|
||||
> Snapshot 8
|
||||
|
||||
{
|
||||
charset: 'big5',
|
||||
@@ -116,7 +136,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
videos: [],
|
||||
}
|
||||
|
||||
> Snapshot 8
|
||||
> Snapshot 9
|
||||
|
||||
{
|
||||
charset: 'euc-kr',
|
||||
|
||||
Binary file not shown.
@@ -33,7 +33,7 @@ test('change email', async t => {
|
||||
const u2Email = 'u2@affine.pro';
|
||||
|
||||
const user = await app.signupV1(u1Email);
|
||||
await sendChangeEmail(app, u1Email, 'affine.pro');
|
||||
await sendChangeEmail(app, u1Email, '/email-change');
|
||||
|
||||
const changeMail = app.mails.last('ChangeEmail');
|
||||
|
||||
@@ -53,7 +53,7 @@ test('change email', async t => {
|
||||
app,
|
||||
changeEmailToken as string,
|
||||
u2Email,
|
||||
'affine.pro'
|
||||
'/email-change-verify'
|
||||
);
|
||||
|
||||
const verifyMail = app.mails.last('VerifyChangeEmail');
|
||||
@@ -94,7 +94,7 @@ test('set and change password', async t => {
|
||||
const u1Email = 'u1@affine.pro';
|
||||
|
||||
const u1 = await app.signupV1(u1Email);
|
||||
await sendSetPasswordEmail(app, u1Email, 'affine.pro');
|
||||
await sendSetPasswordEmail(app, u1Email, '/password-change');
|
||||
|
||||
const setPasswordMail = app.mails.last('ChangePassword');
|
||||
const link = new URL(setPasswordMail.props.url);
|
||||
@@ -131,3 +131,29 @@ test('set and change password', async t => {
|
||||
t.not(user, null, 'failed to get current user');
|
||||
t.is(user?.email, u1Email, 'failed to get current user');
|
||||
});
|
||||
|
||||
test('should forbid graphql callbackUrl to external origin', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const u1Email = 'u1@affine.pro';
|
||||
await app.signupV1(u1Email);
|
||||
|
||||
const res = await app
|
||||
.POST('/graphql')
|
||||
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
|
||||
.send({
|
||||
query: `
|
||||
mutation($email: String!, $callbackUrl: String!) {
|
||||
sendChangeEmail(email: $email, callbackUrl: $callbackUrl)
|
||||
}
|
||||
`,
|
||||
variables: {
|
||||
email: u1Email,
|
||||
callbackUrl: 'https://evil.example',
|
||||
},
|
||||
})
|
||||
.expect(200);
|
||||
|
||||
t.truthy(res.body.errors?.length);
|
||||
t.is(res.body.errors[0].extensions?.name, 'ACTION_FORBIDDEN');
|
||||
});
|
||||
|
||||
@@ -5,6 +5,7 @@ import { HttpStatus } from '@nestjs/common';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
import supertest from 'supertest';
|
||||
|
||||
import { parseCookies as safeParseCookies } from '../../base/utils/request';
|
||||
import { AuthService } from '../../core/auth/service';
|
||||
@@ -126,6 +127,36 @@ test('should not be able to sign in if forbidden', async t => {
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should forbid magic link with external callbackUrl', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const u1 = await app.createUser('u1@affine.pro');
|
||||
|
||||
await app
|
||||
.POST('/api/auth/sign-in')
|
||||
.send({
|
||||
email: u1.email,
|
||||
callbackUrl: 'https://evil.example/magic-link',
|
||||
})
|
||||
.expect(HttpStatus.FORBIDDEN);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should forbid magic link with untrusted redirect_uri in callbackUrl', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const u1 = await app.createUser('u1@affine.pro');
|
||||
|
||||
await app
|
||||
.POST('/api/auth/sign-in')
|
||||
.send({
|
||||
email: u1.email,
|
||||
callbackUrl: '/magic-link?redirect_uri=https://evil.example',
|
||||
})
|
||||
.expect(HttpStatus.FORBIDDEN);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should be able to sign out', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
@@ -136,13 +167,82 @@ test('should be able to sign out', async t => {
|
||||
.send({ email: u1.email, password: u1.password })
|
||||
.expect(200);
|
||||
|
||||
await app.GET('/api/auth/sign-out').expect(200);
|
||||
await app.POST('/api/auth/sign-out').expect(200);
|
||||
|
||||
const session = await currentUser(app);
|
||||
|
||||
t.falsy(session);
|
||||
});
|
||||
|
||||
test('should reject sign out when csrf token mismatched', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const u1 = await app.createUser('u1@affine.pro');
|
||||
|
||||
await app
|
||||
.POST('/api/auth/sign-in')
|
||||
.send({ email: u1.email, password: u1.password })
|
||||
.expect(200);
|
||||
|
||||
await app
|
||||
.POST('/api/auth/sign-out')
|
||||
.set('x-affine-csrf-token', 'invalid')
|
||||
.expect(HttpStatus.FORBIDDEN);
|
||||
|
||||
const session = await currentUser(app);
|
||||
t.is(session?.id, u1.id);
|
||||
});
|
||||
|
||||
test('should sign in desktop app via one-time open-app code', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const u1 = await app.createUser('u1@affine.pro');
|
||||
|
||||
await app
|
||||
.POST('/api/auth/sign-in')
|
||||
.send({ email: u1.email, password: u1.password })
|
||||
.expect(200);
|
||||
|
||||
const codeRes = await app.POST('/api/auth/open-app/sign-in-code').expect(201);
|
||||
|
||||
const code = codeRes.body.code as string;
|
||||
t.truthy(code);
|
||||
|
||||
const exchangeRes = await supertest(app.getHttpServer())
|
||||
.post('/api/auth/open-app/sign-in')
|
||||
.send({ code })
|
||||
.expect(201);
|
||||
|
||||
const exchangedCookies = exchangeRes.get('Set-Cookie') ?? [];
|
||||
t.true(
|
||||
exchangedCookies.some(c =>
|
||||
c.startsWith(`${AuthService.sessionCookieName}=`)
|
||||
)
|
||||
);
|
||||
|
||||
const cookieHeader = exchangedCookies.map(c => c.split(';')[0]).join('; ');
|
||||
const sessionRes = await supertest(app.getHttpServer())
|
||||
.get('/api/auth/session')
|
||||
.set('Cookie', cookieHeader)
|
||||
.expect(200);
|
||||
|
||||
t.is(sessionRes.body.user?.id, u1.id);
|
||||
|
||||
// one-time use
|
||||
await supertest(app.getHttpServer())
|
||||
.post('/api/auth/open-app/sign-in')
|
||||
.send({ code })
|
||||
.expect(400)
|
||||
.expect({
|
||||
status: 400,
|
||||
code: 'Bad Request',
|
||||
type: 'BAD_REQUEST',
|
||||
name: 'INVALID_AUTH_STATE',
|
||||
message:
|
||||
'Invalid auth state. You might start the auth progress from another device.',
|
||||
});
|
||||
});
|
||||
|
||||
test('should be able to correct user id cookie', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
@@ -228,7 +328,7 @@ test('should be able to sign out multiple accounts in one session', async t => {
|
||||
const u2 = await app.signupV1('u2@affine.pro');
|
||||
|
||||
// sign out u2
|
||||
await app.GET(`/api/auth/sign-out?user_id=${u2.id}`).expect(200);
|
||||
await app.POST(`/api/auth/sign-out?user_id=${u2.id}`).expect(200);
|
||||
|
||||
// list [u1]
|
||||
let session = await app.GET('/api/auth/session').expect(200);
|
||||
@@ -241,7 +341,7 @@ test('should be able to sign out multiple accounts in one session', async t => {
|
||||
.expect(200);
|
||||
|
||||
// sign out all account in session
|
||||
await app.GET('/api/auth/sign-out').expect(200);
|
||||
await app.POST('/api/auth/sign-out').expect(200);
|
||||
|
||||
session = await app.GET('/api/auth/session').expect(200);
|
||||
t.falsy(session.body.user);
|
||||
@@ -337,3 +437,56 @@ test('should not be able to sign in if token is invalid', async t => {
|
||||
|
||||
t.is(res.body.message, 'An invalid email token provided.');
|
||||
});
|
||||
|
||||
test('should not allow magic link OTP replay', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const u1 = await app.createUser('u1@affine.pro');
|
||||
|
||||
await app.POST('/api/auth/sign-in').send({ email: u1.email }).expect(200);
|
||||
const signInMail = app.mails.last('SignIn');
|
||||
const url = new URL(signInMail.props.url);
|
||||
const email = url.searchParams.get('email');
|
||||
const token = url.searchParams.get('token');
|
||||
|
||||
await app.POST('/api/auth/magic-link').send({ email, token }).expect(201);
|
||||
|
||||
await app
|
||||
.POST('/api/auth/magic-link')
|
||||
.send({ email, token })
|
||||
.expect(400)
|
||||
.expect({
|
||||
status: 400,
|
||||
code: 'Bad Request',
|
||||
type: 'INVALID_INPUT',
|
||||
name: 'INVALID_EMAIL_TOKEN',
|
||||
message: 'An invalid email token provided.',
|
||||
});
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should lock magic link OTP after too many attempts', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const u1 = await app.createUser('u1@affine.pro');
|
||||
|
||||
await app.POST('/api/auth/sign-in').send({ email: u1.email }).expect(200);
|
||||
const signInMail = app.mails.last('SignIn');
|
||||
const url = new URL(signInMail.props.url);
|
||||
const email = url.searchParams.get('email');
|
||||
const token = url.searchParams.get('token') as string;
|
||||
|
||||
const wrongOtp = token === '000000' ? '000001' : '000000';
|
||||
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await app
|
||||
.POST('/api/auth/magic-link')
|
||||
.send({ email, token: wrongOtp })
|
||||
.expect(400);
|
||||
}
|
||||
|
||||
await app.POST('/api/auth/magic-link').send({ email, token }).expect(400);
|
||||
|
||||
const session = await currentUser(app);
|
||||
t.falsy(session);
|
||||
});
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { TestingModule } from '@nestjs/testing';
|
||||
import test from 'ava';
|
||||
|
||||
@@ -7,6 +9,8 @@ import { createTestingModule } from './utils';
|
||||
|
||||
let cache: Cache;
|
||||
let module: TestingModule;
|
||||
const keyPrefix = `test:${randomUUID()}:`;
|
||||
const key = (name: string) => `${keyPrefix}${name}`;
|
||||
test.before(async () => {
|
||||
module = await createTestingModule({
|
||||
imports: FunctionalityModules,
|
||||
@@ -19,78 +23,78 @@ test.after.always(async () => {
|
||||
});
|
||||
|
||||
test('should be able to set normal cache', async t => {
|
||||
t.true(await cache.set('test', 1));
|
||||
t.is(await cache.get<number>('test'), 1);
|
||||
t.true(await cache.set(key('test'), 1));
|
||||
t.is(await cache.get<number>(key('test')), 1);
|
||||
|
||||
t.true(await cache.has('test'));
|
||||
t.true(await cache.delete('test'));
|
||||
t.is(await cache.get('test'), undefined);
|
||||
t.true(await cache.has(key('test')));
|
||||
t.true(await cache.delete(key('test')));
|
||||
t.is(await cache.get(key('test')), undefined);
|
||||
|
||||
t.true(await cache.set('test', { a: 1 }));
|
||||
t.deepEqual(await cache.get('test'), { a: 1 });
|
||||
t.true(await cache.set(key('test'), { a: 1 }));
|
||||
t.deepEqual(await cache.get(key('test')), { a: 1 });
|
||||
});
|
||||
|
||||
test('should be able to set cache with non-exiting flag', async t => {
|
||||
t.true(await cache.setnx('test-nx', 1));
|
||||
t.false(await cache.setnx('test-nx', 2));
|
||||
t.is(await cache.get('test-nx'), 1);
|
||||
t.true(await cache.setnx(key('test-nx'), 1));
|
||||
t.false(await cache.setnx(key('test-nx'), 2));
|
||||
t.is(await cache.get(key('test-nx')), 1);
|
||||
});
|
||||
|
||||
test('should be able to set cache with ttl', async t => {
|
||||
t.true(await cache.set('test-ttl', 1));
|
||||
t.is(await cache.get('test-ttl'), 1);
|
||||
t.true(await cache.set(key('test-ttl'), 1));
|
||||
t.is(await cache.get(key('test-ttl')), 1);
|
||||
|
||||
t.true(await cache.expire('test-ttl', 1 * 1000));
|
||||
const ttl = await cache.ttl('test-ttl');
|
||||
t.true(await cache.expire(key('test-ttl'), 1 * 1000));
|
||||
const ttl = await cache.ttl(key('test-ttl'));
|
||||
t.true(ttl <= 1 * 1000);
|
||||
t.true(ttl > 0);
|
||||
});
|
||||
|
||||
test('should be able to incr/decr number cache', async t => {
|
||||
t.true(await cache.set('test-incr', 1));
|
||||
t.is(await cache.increase('test-incr'), 2);
|
||||
t.is(await cache.increase('test-incr'), 3);
|
||||
t.is(await cache.decrease('test-incr'), 2);
|
||||
t.is(await cache.decrease('test-incr'), 1);
|
||||
t.true(await cache.set(key('test-incr'), 1));
|
||||
t.is(await cache.increase(key('test-incr')), 2);
|
||||
t.is(await cache.increase(key('test-incr')), 3);
|
||||
t.is(await cache.decrease(key('test-incr')), 2);
|
||||
t.is(await cache.decrease(key('test-incr')), 1);
|
||||
|
||||
// increase an nonexists number
|
||||
t.is(await cache.increase('test-incr2'), 1);
|
||||
t.is(await cache.increase('test-incr2'), 2);
|
||||
t.is(await cache.increase(key('test-incr2')), 1);
|
||||
t.is(await cache.increase(key('test-incr2')), 2);
|
||||
});
|
||||
|
||||
test('should be able to manipulate list cache', async t => {
|
||||
t.is(await cache.pushBack('test-list', 1), 1);
|
||||
t.is(await cache.pushBack('test-list', 2, 3, 4), 4);
|
||||
t.is(await cache.len('test-list'), 4);
|
||||
t.is(await cache.pushBack(key('test-list'), 1), 1);
|
||||
t.is(await cache.pushBack(key('test-list'), 2, 3, 4), 4);
|
||||
t.is(await cache.len(key('test-list')), 4);
|
||||
|
||||
t.deepEqual(await cache.list('test-list', 1, -1), [2, 3, 4]);
|
||||
t.deepEqual(await cache.list(key('test-list'), 1, -1), [2, 3, 4]);
|
||||
|
||||
t.deepEqual(await cache.popFront('test-list', 2), [1, 2]);
|
||||
t.deepEqual(await cache.popBack('test-list', 1), [4]);
|
||||
t.deepEqual(await cache.popFront(key('test-list'), 2), [1, 2]);
|
||||
t.deepEqual(await cache.popBack(key('test-list'), 1), [4]);
|
||||
|
||||
t.is(await cache.pushBack('test-list2', { a: 1 }), 1);
|
||||
t.deepEqual(await cache.popFront('test-list2', 1), [{ a: 1 }]);
|
||||
t.is(await cache.pushBack(key('test-list2'), { a: 1 }), 1);
|
||||
t.deepEqual(await cache.popFront(key('test-list2'), 1), [{ a: 1 }]);
|
||||
});
|
||||
|
||||
test('should be able to manipulate map cache', async t => {
|
||||
t.is(await cache.mapSet('test-map', 'a', 1), true);
|
||||
t.is(await cache.mapSet('test-map', 'b', 2), true);
|
||||
t.is(await cache.mapLen('test-map'), 2);
|
||||
t.is(await cache.mapSet(key('test-map'), 'a', 1), true);
|
||||
t.is(await cache.mapSet(key('test-map'), 'b', 2), true);
|
||||
t.is(await cache.mapLen(key('test-map')), 2);
|
||||
|
||||
t.is(await cache.mapGet('test-map', 'a'), 1);
|
||||
t.is(await cache.mapGet('test-map', 'b'), 2);
|
||||
t.is(await cache.mapGet(key('test-map'), 'a'), 1);
|
||||
t.is(await cache.mapGet(key('test-map'), 'b'), 2);
|
||||
|
||||
t.is(await cache.mapIncrease('test-map', 'a'), 2);
|
||||
t.is(await cache.mapIncrease('test-map', 'a'), 3);
|
||||
t.is(await cache.mapDecrease('test-map', 'b', 3), -1);
|
||||
t.is(await cache.mapIncrease(key('test-map'), 'a'), 2);
|
||||
t.is(await cache.mapIncrease(key('test-map'), 'a'), 3);
|
||||
t.is(await cache.mapDecrease(key('test-map'), 'b', 3), -1);
|
||||
|
||||
const keys = await cache.mapKeys('test-map');
|
||||
const keys = await cache.mapKeys(key('test-map'));
|
||||
t.deepEqual(keys, ['a', 'b']);
|
||||
|
||||
const randomKey = await cache.mapRandomKey('test-map');
|
||||
const randomKey = await cache.mapRandomKey(key('test-map'));
|
||||
t.truthy(randomKey);
|
||||
t.true(keys.includes(randomKey!));
|
||||
|
||||
t.is(await cache.mapDelete('test-map', 'a'), true);
|
||||
t.is(await cache.mapGet('test-map', 'a'), undefined);
|
||||
t.is(await cache.mapDelete(key('test-map'), 'a'), true);
|
||||
t.is(await cache.mapGet(key('test-map'), 'a'), undefined);
|
||||
});
|
||||
|
||||
@@ -922,7 +922,6 @@ test('should be able to manage context', async t => {
|
||||
const { id: fileId } = await addContextFile(
|
||||
app,
|
||||
contextId,
|
||||
'fileId1',
|
||||
'sample.pdf',
|
||||
buffer
|
||||
);
|
||||
|
||||
@@ -41,6 +41,7 @@ interface TestingAppMetadata {
|
||||
export class TestingApp extends NestApplication {
|
||||
private sessionCookie: string | null = null;
|
||||
private currentUserCookie: string | null = null;
|
||||
private csrfCookie: string | null = null;
|
||||
private readonly userCookies: Set<string> = new Set();
|
||||
|
||||
create = createFactory(this.get(PrismaClient, { strict: false }));
|
||||
@@ -65,12 +66,23 @@ export class TestingApp extends NestApplication {
|
||||
method: 'options' | 'get' | 'post' | 'put' | 'delete' | 'patch',
|
||||
path: string
|
||||
): supertest.Test {
|
||||
return supertest(this.getHttpServer())
|
||||
const cookies = [
|
||||
`${AuthService.sessionCookieName}=${this.sessionCookie ?? ''}`,
|
||||
`${AuthService.userCookieName}=${this.currentUserCookie ?? ''}`,
|
||||
];
|
||||
if (this.csrfCookie) {
|
||||
cookies.push(`${AuthService.csrfCookieName}=${this.csrfCookie}`);
|
||||
}
|
||||
|
||||
const req = supertest(this.getHttpServer())
|
||||
[method](path)
|
||||
.set('Cookie', [
|
||||
`${AuthService.sessionCookieName}=${this.sessionCookie ?? ''}`,
|
||||
`${AuthService.userCookieName}=${this.currentUserCookie ?? ''}`,
|
||||
]);
|
||||
.set('Cookie', cookies);
|
||||
|
||||
if (this.csrfCookie) {
|
||||
req.set('x-affine-csrf-token', this.csrfCookie);
|
||||
}
|
||||
|
||||
return req;
|
||||
}
|
||||
|
||||
gql = gqlFetcherFactory('', async (_input, init) => {
|
||||
@@ -123,6 +135,9 @@ export class TestingApp extends NestApplication {
|
||||
|
||||
this.sessionCookie = cookies[AuthService.sessionCookieName];
|
||||
this.currentUserCookie = cookies[AuthService.userCookieName];
|
||||
if (AuthService.csrfCookieName in cookies) {
|
||||
this.csrfCookie = cookies[AuthService.csrfCookieName] || null;
|
||||
}
|
||||
if (this.currentUserCookie) {
|
||||
this.userCookies.add(this.currentUserCookie);
|
||||
}
|
||||
@@ -180,13 +195,17 @@ export class TestingApp extends NestApplication {
|
||||
}
|
||||
|
||||
async logout(userId?: string) {
|
||||
const res = await this.GET(
|
||||
const res = await this.POST(
|
||||
'/api/auth/sign-out' + (userId ? `?user_id=${userId}` : '')
|
||||
).expect(200);
|
||||
const cookies = parseCookies(res);
|
||||
this.sessionCookie = cookies[AuthService.sessionCookieName];
|
||||
if (AuthService.csrfCookieName in cookies) {
|
||||
this.csrfCookie = cookies[AuthService.csrfCookieName] || null;
|
||||
}
|
||||
if (!this.sessionCookie) {
|
||||
this.currentUserCookie = null;
|
||||
this.csrfCookie = null;
|
||||
this.userCookies.clear();
|
||||
} else {
|
||||
this.currentUserCookie = cookies[AuthService.userCookieName];
|
||||
|
||||
@@ -16,9 +16,13 @@ e2e('should get doc markdown success', async t => {
|
||||
user: owner,
|
||||
});
|
||||
|
||||
const path = `/rpc/workspaces/${workspace.id}/docs/${docSnapshot.id}/markdown`;
|
||||
const res = await app
|
||||
.GET(`/rpc/workspaces/${workspace.id}/docs/${docSnapshot.id}/markdown`)
|
||||
.set('x-access-token', crypto.sign(docSnapshot.id))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect(200)
|
||||
.expect('Content-Type', 'application/json; charset=utf-8');
|
||||
|
||||
@@ -32,9 +36,13 @@ e2e('should get doc markdown return null when doc not exists', async t => {
|
||||
});
|
||||
|
||||
const docId = randomUUID();
|
||||
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}/markdown`;
|
||||
const res = await app
|
||||
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/markdown`)
|
||||
.set('x-access-token', crypto.sign(docId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect(404)
|
||||
.expect('Content-Type', 'application/json; charset=utf-8');
|
||||
|
||||
|
||||
@@ -39,31 +39,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
},
|
||||
}
|
||||
|
||||
## should not return apple oauth provider when client version is not specified
|
||||
|
||||
> Snapshot 1
|
||||
|
||||
{
|
||||
serverConfig: {
|
||||
oauthProviders: [
|
||||
'Google',
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
## should not return apple oauth provider in version < 0.22.0
|
||||
|
||||
> Snapshot 1
|
||||
|
||||
{
|
||||
serverConfig: {
|
||||
oauthProviders: [
|
||||
'Google',
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
## should not return apple oauth provider when client version format is not correct
|
||||
## should return apple oauth provider when client version is not specified
|
||||
|
||||
> Snapshot 1
|
||||
|
||||
@@ -71,6 +47,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
serverConfig: {
|
||||
oauthProviders: [
|
||||
'Google',
|
||||
'Apple',
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
Binary file not shown.
@@ -71,7 +71,7 @@ e2e('should return apple oauth provider in version >= 0.22.0', async t => {
|
||||
});
|
||||
|
||||
e2e(
|
||||
'should not return apple oauth provider when client version is not specified',
|
||||
'should return apple oauth provider when client version is not specified',
|
||||
async t => {
|
||||
const res = await app.gql({
|
||||
query: oauthProvidersQuery,
|
||||
@@ -80,32 +80,3 @@ e2e(
|
||||
t.snapshot(res);
|
||||
}
|
||||
);
|
||||
|
||||
e2e('should not return apple oauth provider in version < 0.22.0', async t => {
|
||||
const res = await app.gql({
|
||||
query: oauthProvidersQuery,
|
||||
context: {
|
||||
headers: {
|
||||
'x-affine-version': '0.21.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
t.snapshot(res);
|
||||
});
|
||||
|
||||
e2e(
|
||||
'should not return apple oauth provider when client version format is not correct',
|
||||
async t => {
|
||||
const res = await app.gql({
|
||||
query: oauthProvidersQuery,
|
||||
context: {
|
||||
headers: {
|
||||
'x-affine-version': 'mock-invalid-version',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
t.snapshot(res);
|
||||
}
|
||||
);
|
||||
|
||||
@@ -41,9 +41,7 @@ class MockR2Provider extends R2StorageProvider {
|
||||
super(config, bucket);
|
||||
}
|
||||
|
||||
destroy() {
|
||||
this.client.destroy();
|
||||
}
|
||||
destroy() {}
|
||||
|
||||
// @ts-ignore expect override
|
||||
override async proxyPutObject(
|
||||
@@ -66,7 +64,7 @@ class MockR2Provider extends R2StorageProvider {
|
||||
body: any,
|
||||
options: { contentLength?: number } = {}
|
||||
) {
|
||||
const etag = `"etag-${partNumber}"`;
|
||||
const etag = `etag-${partNumber}`;
|
||||
this.partCalls.push({
|
||||
key,
|
||||
uploadId,
|
||||
@@ -230,11 +228,13 @@ async function getBlobUploadPartUrl(
|
||||
) {
|
||||
const data = await gql(
|
||||
`
|
||||
mutation getBlobUploadPartUrl($workspaceId: String!, $key: String!, $uploadId: String!, $partNumber: Int!) {
|
||||
getBlobUploadPartUrl(workspaceId: $workspaceId, key: $key, uploadId: $uploadId, partNumber: $partNumber) {
|
||||
uploadUrl
|
||||
headers
|
||||
expiresAt
|
||||
query getBlobUploadPartUrl($workspaceId: String!, $key: String!, $uploadId: String!, $partNumber: Int!) {
|
||||
workspace(id: $workspaceId) {
|
||||
blobUploadPartUrl(key: $key, uploadId: $uploadId, partNumber: $partNumber) {
|
||||
uploadUrl
|
||||
headers
|
||||
expiresAt
|
||||
}
|
||||
}
|
||||
}
|
||||
`,
|
||||
@@ -242,7 +242,7 @@ async function getBlobUploadPartUrl(
|
||||
'getBlobUploadPartUrl'
|
||||
);
|
||||
|
||||
return data.getBlobUploadPartUrl;
|
||||
return data.workspace.blobUploadPartUrl;
|
||||
}
|
||||
|
||||
async function setupWorkspace() {
|
||||
@@ -322,7 +322,7 @@ e2e('should proxy multipart upload and return etag', async t => {
|
||||
.send(payload);
|
||||
|
||||
t.is(res.status, 200);
|
||||
t.is(res.get('etag'), '"etag-1"');
|
||||
t.is(res.get('etag'), 'etag-1');
|
||||
|
||||
const calls = getProvider().partCalls;
|
||||
t.is(calls.length, 1);
|
||||
@@ -356,7 +356,7 @@ e2e('should resume multipart upload and return uploaded parts', async t => {
|
||||
const init2 = await createBlobUpload(workspace.id, key, totalSize, 'bin');
|
||||
t.is(init2.method, 'MULTIPART');
|
||||
t.is(init2.uploadId, 'upload-id');
|
||||
t.deepEqual(init2.uploadedParts, [{ partNumber: 1, etag: '"etag-1"' }]);
|
||||
t.deepEqual(init2.uploadedParts, [{ partNumber: 1, etag: 'etag-1' }]);
|
||||
t.is(getProvider().createMultipartCalls, 1);
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
import { getUserQuery } from '@affine/graphql';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { ThrottlerStorage } from '../../../base/throttler';
|
||||
import { app, e2e, Mockers } from '../test';
|
||||
|
||||
e2e('user(email) should return null without auth', async t => {
|
||||
const user = await app.create(Mockers.User);
|
||||
|
||||
await app.logout();
|
||||
|
||||
const res = await app.gql({
|
||||
query: getUserQuery,
|
||||
variables: { email: user.email },
|
||||
});
|
||||
|
||||
t.is(res.user, null);
|
||||
});
|
||||
|
||||
e2e('user(email) should return null outside workspace scope', async t => {
|
||||
await app.logout();
|
||||
const me = await app.signup();
|
||||
const other = await app.create(Mockers.User);
|
||||
|
||||
const res = await app.gql({
|
||||
query: getUserQuery,
|
||||
variables: { email: other.email },
|
||||
});
|
||||
|
||||
t.is(res.user, null);
|
||||
|
||||
// sanity: querying self is always allowed
|
||||
const self = await app.gql({
|
||||
query: getUserQuery,
|
||||
variables: { email: me.email },
|
||||
});
|
||||
t.truthy(self.user);
|
||||
if (!self.user) return;
|
||||
t.is(self.user.__typename, 'UserType');
|
||||
if (self.user.__typename === 'UserType') {
|
||||
t.is(self.user.id, me.id);
|
||||
}
|
||||
});
|
||||
|
||||
e2e('user(email) should return user within workspace scope', async t => {
|
||||
await app.logout();
|
||||
const me = await app.signup();
|
||||
const other = await app.create(Mockers.User);
|
||||
const ws = await app.create(Mockers.Workspace, { owner: me });
|
||||
|
||||
await app.create(Mockers.WorkspaceUser, {
|
||||
workspaceId: ws.id,
|
||||
userId: other.id,
|
||||
});
|
||||
|
||||
const res = await app.gql({
|
||||
query: getUserQuery,
|
||||
variables: { email: other.email },
|
||||
});
|
||||
|
||||
t.truthy(res.user);
|
||||
if (!res.user) return;
|
||||
t.is(res.user.__typename, 'UserType');
|
||||
if (res.user.__typename === 'UserType') {
|
||||
t.is(res.user.id, other.id);
|
||||
}
|
||||
});
|
||||
|
||||
e2e('user(email) should be rate limited', async t => {
|
||||
await app.logout();
|
||||
const me = await app.signup();
|
||||
|
||||
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
|
||||
timeToExpire: 10,
|
||||
totalHits: 21,
|
||||
isBlocked: true,
|
||||
timeToBlockExpire: 10,
|
||||
});
|
||||
|
||||
await t.throwsAsync(
|
||||
app.gql({
|
||||
query: getUserQuery,
|
||||
variables: { email: me.email },
|
||||
}),
|
||||
{ message: /too many requests/i }
|
||||
);
|
||||
|
||||
stub.restore();
|
||||
});
|
||||
@@ -17,17 +17,3 @@ Generated by [AVA](https://avajs.dev).
|
||||
name: 'Free',
|
||||
storageQuota: 10737418240,
|
||||
}
|
||||
|
||||
## should get feature if extra fields exist in feature config
|
||||
|
||||
> Snapshot 1
|
||||
|
||||
{
|
||||
blobLimit: 10485760,
|
||||
businessBlobLimit: 104857600,
|
||||
copilotActionLimit: 10,
|
||||
historyPeriod: 604800000,
|
||||
memberLimit: 3,
|
||||
name: 'Free',
|
||||
storageQuota: 10737418240,
|
||||
}
|
||||
|
||||
Binary file not shown.
@@ -68,7 +68,7 @@ test("should be able to redirect to oauth provider's login page", async t => {
|
||||
|
||||
const res = await app
|
||||
.POST('/api/oauth/preflight')
|
||||
.send({ provider: 'Google' })
|
||||
.send({ provider: 'Google', client_nonce: 'test-nonce' })
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
const { url } = res.body;
|
||||
@@ -100,7 +100,7 @@ test('should be able to redirect to oauth provider with multiple hosts', async t
|
||||
const res = await app
|
||||
.POST('/api/oauth/preflight')
|
||||
.set('host', 'test.affine.dev')
|
||||
.send({ provider: 'Google' })
|
||||
.send({ provider: 'Google', client_nonce: 'test-nonce' })
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
const { url } = res.body;
|
||||
@@ -156,12 +156,45 @@ test('should be able to redirect to oauth provider with client_nonce', async t =
|
||||
t.truthy(state.state);
|
||||
});
|
||||
|
||||
test('should forbid preflight with untrusted redirect_uri', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
await app
|
||||
.POST('/api/oauth/preflight')
|
||||
.send({
|
||||
provider: 'Google',
|
||||
redirect_uri: 'https://evil.example',
|
||||
client_nonce: 'test-nonce',
|
||||
})
|
||||
.expect(HttpStatus.FORBIDDEN);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should throw if client_nonce is missing in preflight', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
await app
|
||||
.POST('/api/oauth/preflight')
|
||||
.send({ provider: 'Google' })
|
||||
.expect(HttpStatus.BAD_REQUEST)
|
||||
.expect({
|
||||
status: 400,
|
||||
code: 'Bad Request',
|
||||
type: 'BAD_REQUEST',
|
||||
name: 'MISSING_OAUTH_QUERY_PARAMETER',
|
||||
message: 'Missing query parameter `client_nonce`.',
|
||||
data: { name: 'client_nonce' },
|
||||
});
|
||||
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should throw if provider is invalid', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
await app
|
||||
.POST('/api/oauth/preflight')
|
||||
.send({ provider: 'Invalid' })
|
||||
.send({ provider: 'Invalid', client_nonce: 'test-nonce' })
|
||||
.expect(HttpStatus.BAD_REQUEST)
|
||||
.expect({
|
||||
status: 400,
|
||||
@@ -320,7 +353,7 @@ test('should throw if provider is invalid in callback uri', async t => {
|
||||
function mockOAuthProvider(
|
||||
app: TestingApp,
|
||||
email: string,
|
||||
clientNonce?: string
|
||||
clientNonce: string = randomUUID()
|
||||
) {
|
||||
const provider = app.get(GoogleOAuthProvider);
|
||||
const oauth = app.get(OAuthService);
|
||||
@@ -337,16 +370,18 @@ function mockOAuthProvider(
|
||||
email,
|
||||
avatarUrl: 'avatar',
|
||||
});
|
||||
|
||||
return clientNonce;
|
||||
}
|
||||
|
||||
test('should be able to sign up with oauth', async t => {
|
||||
const { app, db } = t.context;
|
||||
|
||||
mockOAuthProvider(app, 'u2@affine.pro');
|
||||
const clientNonce = mockOAuthProvider(app, 'u2@affine.pro');
|
||||
|
||||
await app
|
||||
.POST('/api/oauth/callback')
|
||||
.send({ code: '1', state: '1' })
|
||||
.send({ code: '1', state: '1', client_nonce: clientNonce })
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
const sessionUser = await currentUser(app);
|
||||
@@ -427,11 +462,11 @@ test('should throw if client_nonce is invalid', async t => {
|
||||
test('should not throw if account registered', async t => {
|
||||
const { app, u1 } = t.context;
|
||||
|
||||
mockOAuthProvider(app, u1.email);
|
||||
const clientNonce = mockOAuthProvider(app, u1.email);
|
||||
|
||||
const res = await app
|
||||
.POST('/api/oauth/callback')
|
||||
.send({ code: '1', state: '1' })
|
||||
.send({ code: '1', state: '1', client_nonce: clientNonce })
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
t.is(res.body.id, u1.id);
|
||||
@@ -442,9 +477,11 @@ test('should be able to fullfil user with oauth sign in', async t => {
|
||||
|
||||
const u3 = await app.createUser('u3@affine.pro');
|
||||
|
||||
mockOAuthProvider(app, u3.email);
|
||||
const clientNonce = mockOAuthProvider(app, u3.email);
|
||||
|
||||
await app.POST('/api/oauth/callback').send({ code: '1', state: '1' });
|
||||
await app
|
||||
.POST('/api/oauth/callback')
|
||||
.send({ code: '1', state: '1', client_nonce: clientNonce });
|
||||
|
||||
const sessionUser = await currentUser(app);
|
||||
|
||||
|
||||
@@ -1,5 +1,339 @@
|
||||
import test from 'ava';
|
||||
import test, { type ExecutionContext } from 'ava';
|
||||
import { io, type Socket as SocketIOClient } from 'socket.io-client';
|
||||
import { Doc, encodeStateAsUpdate } from 'yjs';
|
||||
|
||||
test('should test through sync gateway', t => {
|
||||
t.pass();
|
||||
import { createTestingApp, TestingApp } from '../utils';
|
||||
|
||||
/**
 * Shape of an acknowledgement payload handled by these tests: either an
 * `error` descriptor (name + message) or a `data` payload of type `T`.
 */
type WebsocketResponse<T> =
  | { error: { name: string; message: string } }
  | { data: T };
|
||||
|
||||
const WS_TIMEOUT_MS = 5_000;
|
||||
|
||||
/**
 * Returns the `data` payload of a websocket acknowledgement.
 *
 * When the acknowledgement carries an `error` instead, the full response is
 * written to the AVA log and an `Error` built from the error's name and
 * message is thrown.
 *
 * @param t - AVA execution context, used only for logging the failed response.
 * @param res - Acknowledgement to unwrap.
 * @returns The `data` field of `res`.
 * @throws Error when `res` has no `data` field.
 */
function unwrapResponse<T>(t: ExecutionContext, res: WebsocketResponse<T>): T {
  if ('data' in res) {
    return res.data;
  }

  t.log(res);
  throw new Error(`Websocket error: ${res.error.name}: ${res.error.message}`);
}
|
||||
|
||||
/**
 * Races `promise` against a timer and rejects if the timer fires first.
 *
 * The timer is always cleared once the race settles, so a fast `promise`
 * does not leave a pending timeout behind.
 *
 * @param promise - The operation to wait for.
 * @param timeoutMs - Maximum time to wait, in milliseconds.
 * @param label - Short description included in the timeout error message.
 * @returns The value `promise` resolves with.
 * @throws Error `Timeout (<timeoutMs>ms): <label>` when the timer wins;
 *   otherwise whatever `promise` rejects with.
 */
async function withTimeout<T>(
  promise: Promise<T>,
  timeoutMs: number,
  label: string
): Promise<T> {
  let timer: ReturnType<typeof setTimeout> | undefined;
  const timeout = new Promise<never>((_, reject) => {
    timer = setTimeout(() => {
      reject(new Error(`Timeout (${timeoutMs}ms): ${label}`));
    }, timeoutMs);
  });

  try {
    return await Promise.race([promise, timeout]);
  } finally {
    // The executor above runs synchronously, so `timer` is set by now.
    clearTimeout(timer);
  }
}
|
||||
|
||||
/**
 * Creates a socket.io client for `url` that sends `cookie` as a request
 * header.
 *
 * The client uses the websocket transport only, never reconnects, and is
 * created with `forceNew: true`.
 *
 * @param url - Server URL to connect to.
 * @param cookie - Value for the `cookie` request header.
 * @returns The socket returned by `io(...)`.
 */
function createClient(url: string, cookie: string): SocketIOClient {
  return io(url, {
    transports: ['websocket'],
    reconnection: false,
    forceNew: true,
    extraHeaders: {
      cookie,
    },
  });
}
|
||||
|
||||
/**
 * Resolves once `socket` is connected.
 *
 * Resolves immediately when the socket is already connected. Otherwise waits
 * for the `connect` event, rejects on `connect_error`, and rejects with a
 * timeout error after `WS_TIMEOUT_MS`. Both listeners are detached once the
 * wait settles so they do not linger on the socket.
 *
 * @param socket - Client socket to wait on.
 */
function waitForConnect(socket: SocketIOClient) {
  if (socket.connected) {
    return Promise.resolve();
  }

  let detach = () => {};
  const connected = new Promise<void>((resolve, reject) => {
    const onConnect = () => resolve();
    const onError = (err: Error) => reject(err);
    detach = () => {
      socket.off('connect', onConnect);
      socket.off('connect_error', onError);
    };
    socket.once('connect', onConnect);
    socket.once('connect_error', onError);
  });

  return withTimeout(connected, WS_TIMEOUT_MS, 'socket connect').finally(
    detach
  );
}
|
||||
|
||||
/**
 * Resolves once `socket` is disconnected.
 *
 * Resolves immediately when the socket is already disconnected; otherwise
 * waits for the `disconnect` event and rejects with a timeout error after
 * `WS_TIMEOUT_MS`.
 *
 * @param socket - Client socket to wait on.
 */
function waitForDisconnect(socket: SocketIOClient) {
  if (socket.disconnected) {
    return Promise.resolve();
  }
  return withTimeout(
    new Promise<void>(resolve => {
      socket.once('disconnect', () => resolve());
    }),
    WS_TIMEOUT_MS,
    'socket disconnect'
  );
}
|
||||
|
||||
/**
 * Emits `event` with `data` and resolves with the acknowledgement passed to
 * the emit callback.
 *
 * Rejects with a timeout error when no acknowledgement arrives within
 * `WS_TIMEOUT_MS`.
 *
 * @param socket - Client socket to emit on.
 * @param event - Event name.
 * @param data - Payload sent with the event.
 * @returns The acknowledgement, still wrapped as a `WebsocketResponse<T>`.
 */
function emitWithAck<T>(socket: SocketIOClient, event: string, data: unknown) {
  return withTimeout(
    new Promise<WebsocketResponse<T>>(resolve => {
      socket.emit(event, data, (res: WebsocketResponse<T>) => resolve(res));
    }),
    WS_TIMEOUT_MS,
    `ack ${event}`
  );
}
|
||||
|
||||
/**
 * Resolves with the payload of the next `event` received on `socket`.
 *
 * Rejects with a timeout error when the event does not arrive within
 * `WS_TIMEOUT_MS`.
 *
 * @param socket - Client socket to listen on.
 * @param event - Event name to wait for.
 */
function waitForEvent<T>(socket: SocketIOClient, event: string) {
  return withTimeout(
    new Promise<T>(resolve => {
      socket.once(event, (payload: T) => resolve(payload));
    }),
    WS_TIMEOUT_MS,
    `event ${event}`
  );
}
|
||||
|
||||
/**
 * Asserts that `event` is NOT received on `socket` during a quiet window.
 *
 * Resolves after `durationMs` without the event; rejects with
 * `Unexpected event received: <event>` as soon as the event arrives. The
 * listener is removed in both cases. The wait is additionally bounded by
 * `WS_TIMEOUT_MS` through `withTimeout`.
 *
 * @param socket - Client socket to observe.
 * @param event - Event name that must not be emitted.
 * @param durationMs - Length of the quiet window in milliseconds.
 */
function expectNoEvent(
  socket: SocketIOClient,
  event: string,
  durationMs = 200
) {
  return withTimeout(
    new Promise<void>((resolve, reject) => {
      let timer: NodeJS.Timeout;
      const onEvent = () => {
        clearTimeout(timer);
        socket.off(event, onEvent);
        reject(new Error(`Unexpected event received: ${event}`));
      };

      timer = setTimeout(() => {
        socket.off(event, onEvent);
        resolve();
      }, durationMs);

      socket.on(event, onEvent);
    }),
    WS_TIMEOUT_MS,
    `expect no event ${event}`
  );
}
|
||||
|
||||
/**
 * Creates the user `u1@affine.pro`, signs in over HTTP and builds a `Cookie`
 * header value from the response.
 *
 * @param app - Testing application used to create the user and sign in.
 * @returns The created user and the cookie header string.
 */
async function login(app: TestingApp) {
  const user = await app.createUser('u1@affine.pro');
  const res = await app
    .POST('/api/auth/sign-in')
    .send({ email: user.email, password: user.password })
    .expect(200);

  const cookies = res.get('Set-Cookie') ?? [];
  // Keep only the leading `name=value` pair of each Set-Cookie entry.
  const cookieHeader = cookies.map(c => c.split(';')[0]).join('; ');
  return { user, cookieHeader };
}
|
||||
|
||||
/**
 * Builds a small Yjs document (map `m` with entry `k` = `v`) and returns its
 * state, encoded as an update, as a base64 string.
 */
function createYjsUpdateBase64() {
  const doc = new Doc();
  doc.getMap('m').set('k', 'v');
  const update = encodeStateAsUpdate(doc);
  return Buffer.from(update).toString('base64');
}
|
||||
|
||||
// Testing application and its URL, shared by every test in this file.
let app: TestingApp;
let url: string;

// Create the application once for the whole file.
test.before(async () => {
  app = await createTestingApp();
  url = app.url();
});

// Re-initialise the testing database before each test.
test.beforeEach(async () => {
  await app.initTestingDB();
});

// Always close the application, even when tests failed.
test.after.always(async () => {
  await app.close();
});
|
||||
|
||||
// A receiver that joined with clientVersion 0.25.0 must get the single-update
// event (`space:broadcast-doc-update`) and must not get the batched
// `space:broadcast-doc-updates` event when another client pushes an update.
test('clientVersion=0.25.0 should only receive space:broadcast-doc-update', async t => {
  const { user, cookieHeader } = await login(app);
  const spaceId = user.id;
  const update = createYjsUpdateBase64();

  const sender = createClient(url, cookieHeader);
  const receiver = createClient(url, cookieHeader);

  try {
    await Promise.all([waitForConnect(sender), waitForConnect(receiver)]);

    const receiverJoin = unwrapResponse(
      t,
      await emitWithAck<{ clientId: string; success: boolean }>(
        receiver,
        'space:join',
        { spaceType: 'userspace', spaceId, clientVersion: '0.25.0' }
      )
    );
    t.true(receiverJoin.success);

    const senderJoin = unwrapResponse(
      t,
      await emitWithAck<{ clientId: string; success: boolean }>(
        sender,
        'space:join',
        { spaceType: 'userspace', spaceId, clientVersion: '0.26.0' }
      )
    );
    t.true(senderJoin.success);

    // Register both expectations before pushing so no event can be missed.
    const onUpdate = waitForEvent<{
      spaceType: string;
      spaceId: string;
      docId: string;
      update: string;
    }>(receiver, 'space:broadcast-doc-update');
    const noUpdates = expectNoEvent(receiver, 'space:broadcast-doc-updates');

    const pushRes = await emitWithAck<{ accepted: true; timestamp?: number }>(
      sender,
      'space:push-doc-update',
      {
        spaceType: 'userspace',
        spaceId,
        docId: 'doc-1',
        update,
      }
    );
    unwrapResponse(t, pushRes);

    const message = await onUpdate;
    t.is(message.spaceType, 'userspace');
    t.is(message.spaceId, spaceId);
    t.is(message.docId, 'doc-1');
    t.is(message.update, update);

    await noUpdates;
  } finally {
    sender.disconnect();
    receiver.disconnect();
  }
});
|
||||
|
||||
// A receiver that joined with clientVersion 0.26.0 must get the batched event
// (`space:broadcast-doc-updates`) and must not get the single-update
// `space:broadcast-doc-update` event when another client pushes an update.
test('clientVersion>=0.26.0 should only receive space:broadcast-doc-updates', async t => {
  const { user, cookieHeader } = await login(app);
  const spaceId = user.id;
  const update = createYjsUpdateBase64();

  const sender = createClient(url, cookieHeader);
  const receiver = createClient(url, cookieHeader);

  try {
    await Promise.all([waitForConnect(sender), waitForConnect(receiver)]);

    const receiverJoin = unwrapResponse(
      t,
      await emitWithAck<{ clientId: string; success: boolean }>(
        receiver,
        'space:join',
        { spaceType: 'userspace', spaceId, clientVersion: '0.26.0' }
      )
    );
    t.true(receiverJoin.success);

    const senderJoin = unwrapResponse(
      t,
      await emitWithAck<{ clientId: string; success: boolean }>(
        sender,
        'space:join',
        { spaceType: 'userspace', spaceId, clientVersion: '0.25.0' }
      )
    );
    t.true(senderJoin.success);

    // Register both expectations before pushing so no event can be missed.
    const onUpdates = waitForEvent<{
      spaceType: string;
      spaceId: string;
      docId: string;
      updates: string[];
    }>(receiver, 'space:broadcast-doc-updates');
    const noUpdate = expectNoEvent(receiver, 'space:broadcast-doc-update');

    const pushRes = await emitWithAck<{ accepted: true; timestamp?: number }>(
      sender,
      'space:push-doc-update',
      {
        spaceType: 'userspace',
        spaceId,
        docId: 'doc-2',
        update,
      }
    );
    unwrapResponse(t, pushRes);

    const message = await onUpdates;
    t.is(message.spaceType, 'userspace');
    t.is(message.spaceId, spaceId);
    t.is(message.docId, 'doc-2');
    t.deepEqual(message.updates, [update]);

    await noUpdate;
  } finally {
    sender.disconnect();
    receiver.disconnect();
  }
});
|
||||
|
||||
// Joining a space with clientVersion 0.24.4 must be acknowledged with
// `success: false`, after which the socket is expected to disconnect.
test('clientVersion<0.25.0 should be rejected and disconnected', async t => {
  const { user, cookieHeader } = await login(app);
  const spaceId = user.id;

  const socket = createClient(url, cookieHeader);
  try {
    await waitForConnect(socket);

    const res = unwrapResponse(
      t,
      await emitWithAck<{ clientId: string; success: boolean }>(
        socket,
        'space:join',
        { spaceType: 'userspace', spaceId, clientVersion: '0.24.4' }
      )
    );
    t.false(res.success);

    await waitForDisconnect(socket);
  } finally {
    socket.disconnect();
  }
});
|
||||
|
||||
// `space:join-awareness` with clientVersion 0.24.4 must be acknowledged with
// `success: false`, after which the socket is expected to disconnect.
test('space:join-awareness should reject clientVersion<0.25.0', async t => {
  const { user, cookieHeader } = await login(app);
  const spaceId = user.id;

  const socket = createClient(url, cookieHeader);
  try {
    await waitForConnect(socket);

    const res = unwrapResponse(
      t,
      await emitWithAck<{ clientId: string; success: boolean }>(
        socket,
        'space:join-awareness',
        {
          spaceType: 'userspace',
          spaceId,
          docId: 'doc-awareness',
          clientVersion: '0.24.4',
        }
      )
    );
    t.false(res.success);

    await waitForDisconnect(socket);
  } finally {
    socket.disconnect();
  }
});
|
||||
|
||||
@@ -152,9 +152,13 @@ export async function getBlobUploadPartUrl(
|
||||
) {
|
||||
const res = await app.gql(
|
||||
`
|
||||
mutation getBlobUploadPartUrl($workspaceId: String!, $key: String!, $uploadId: String!, $partNumber: Int!) {
|
||||
getBlobUploadPartUrl(workspaceId: $workspaceId, key: $key, uploadId: $uploadId, partNumber: $partNumber) {
|
||||
uploadUrl
|
||||
query getBlobUploadPartUrl($workspaceId: String!, $key: String!, $uploadId: String!, $partNumber: Int!) {
|
||||
workspace(id: $workspaceId) {
|
||||
blobUploadPartUrl(key: $key, uploadId: $uploadId, partNumber: $partNumber) {
|
||||
uploadUrl
|
||||
headers
|
||||
expiresAt
|
||||
}
|
||||
}
|
||||
}
|
||||
`,
|
||||
@@ -165,5 +169,5 @@ export async function getBlobUploadPartUrl(
|
||||
partNumber,
|
||||
}
|
||||
);
|
||||
return res.getBlobUploadPartUrl;
|
||||
return res.workspace.blobUploadPartUrl;
|
||||
}
|
||||
|
||||
@@ -250,7 +250,6 @@ export async function listContext(
|
||||
export async function addContextFile(
|
||||
app: TestingApp,
|
||||
contextId: string,
|
||||
blobId: string,
|
||||
fileName: string,
|
||||
content: Buffer
|
||||
): Promise<{ id: string }> {
|
||||
@@ -269,7 +268,7 @@ export async function addContextFile(
|
||||
`,
|
||||
variables: {
|
||||
content: null,
|
||||
options: { contextId, blobId },
|
||||
options: { contextId },
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
@@ -139,11 +139,11 @@ export async function revokeUser(
|
||||
): Promise<boolean> {
|
||||
const res = await app.gql(`
|
||||
mutation {
|
||||
revoke(workspaceId: "${workspaceId}", userId: "${userId}")
|
||||
revokeMember(workspaceId: "${workspaceId}", userId: "${userId}")
|
||||
}
|
||||
`);
|
||||
|
||||
return res.revoke;
|
||||
return res.revokeMember;
|
||||
}
|
||||
|
||||
export async function getInviteInfo(
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
GlobalExceptionFilter,
|
||||
JobQueue,
|
||||
} from '../../base';
|
||||
import { SocketIoAdapter } from '../../base/websocket';
|
||||
import { AuthService } from '../../core/auth';
|
||||
import { Mailer } from '../../core/mail';
|
||||
import { UserModel } from '../../models';
|
||||
@@ -61,6 +62,7 @@ export async function createTestingApp(
|
||||
);
|
||||
|
||||
app.use(cookieParser());
|
||||
app.useWebSocketAdapter(new SocketIoAdapter(app));
|
||||
|
||||
if (moduleDef.tapApp) {
|
||||
moduleDef.tapApp(app);
|
||||
@@ -89,6 +91,7 @@ export function parseCookies(res: supertest.Response) {
|
||||
export class TestingApp extends ApplyType<INestApplication>() {
|
||||
private sessionCookie: string | null = null;
|
||||
private currentUserCookie: string | null = null;
|
||||
private csrfCookie: string | null = null;
|
||||
private readonly userCookies: Set<string> = new Set();
|
||||
|
||||
readonly create!: ReturnType<typeof createFactory>;
|
||||
@@ -103,6 +106,7 @@ export class TestingApp extends ApplyType<INestApplication>() {
|
||||
await initTestingDB(this);
|
||||
this.sessionCookie = null;
|
||||
this.currentUserCookie = null;
|
||||
this.csrfCookie = null;
|
||||
this.userCookies.clear();
|
||||
}
|
||||
|
||||
@@ -118,12 +122,23 @@ export class TestingApp extends ApplyType<INestApplication>() {
|
||||
method: 'options' | 'get' | 'post' | 'put' | 'delete' | 'patch',
|
||||
path: string
|
||||
): supertest.Test {
|
||||
return supertest(this.getHttpServer())
|
||||
const cookies = [
|
||||
`${AuthService.sessionCookieName}=${this.sessionCookie ?? ''}`,
|
||||
`${AuthService.userCookieName}=${this.currentUserCookie ?? ''}`,
|
||||
];
|
||||
if (this.csrfCookie) {
|
||||
cookies.push(`${AuthService.csrfCookieName}=${this.csrfCookie}`);
|
||||
}
|
||||
|
||||
const req = supertest(this.getHttpServer())
|
||||
[method](path)
|
||||
.set('Cookie', [
|
||||
`${AuthService.sessionCookieName}=${this.sessionCookie ?? ''}`,
|
||||
`${AuthService.userCookieName}=${this.currentUserCookie ?? ''}`,
|
||||
]);
|
||||
.set('Cookie', cookies);
|
||||
|
||||
if (this.csrfCookie) {
|
||||
req.set('x-affine-csrf-token', this.csrfCookie);
|
||||
}
|
||||
|
||||
return req;
|
||||
}
|
||||
|
||||
OPTIONS(path: string): supertest.Test {
|
||||
@@ -147,6 +162,9 @@ export class TestingApp extends ApplyType<INestApplication>() {
|
||||
|
||||
this.sessionCookie = cookies[AuthService.sessionCookieName];
|
||||
this.currentUserCookie = cookies[AuthService.userCookieName];
|
||||
if (AuthService.csrfCookieName in cookies) {
|
||||
this.csrfCookie = cookies[AuthService.csrfCookieName] || null;
|
||||
}
|
||||
if (this.currentUserCookie) {
|
||||
this.userCookies.add(this.currentUserCookie);
|
||||
}
|
||||
@@ -270,13 +288,17 @@ export class TestingApp extends ApplyType<INestApplication>() {
|
||||
}
|
||||
|
||||
async logout(userId?: string) {
|
||||
const res = await this.GET(
|
||||
const res = await this.POST(
|
||||
'/api/auth/sign-out' + (userId ? `?user_id=${userId}` : '')
|
||||
).expect(200);
|
||||
const cookies = parseCookies(res);
|
||||
this.sessionCookie = cookies[AuthService.sessionCookieName];
|
||||
if (AuthService.csrfCookieName in cookies) {
|
||||
this.csrfCookie = cookies[AuthService.csrfCookieName] || null;
|
||||
}
|
||||
if (!this.sessionCookie) {
|
||||
this.currentUserCookie = null;
|
||||
this.csrfCookie = null;
|
||||
this.userCookies.clear();
|
||||
} else {
|
||||
this.currentUserCookie = cookies[AuthService.userCookieName];
|
||||
|
||||
@@ -188,10 +188,10 @@ export async function revokeMember(
|
||||
const res = await app.gql(
|
||||
`
|
||||
mutation {
|
||||
revoke(workspaceId: "${workspaceId}", userId: "${userId}")
|
||||
revokeMember(workspaceId: "${workspaceId}", userId: "${userId}")
|
||||
}
|
||||
`
|
||||
);
|
||||
|
||||
return res.revoke;
|
||||
return res.revokeMember;
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ function checkVersion(enabled = true) {
|
||||
client: {
|
||||
versionControl: {
|
||||
enabled,
|
||||
requiredVersion: '>=0.20.0',
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -88,23 +88,23 @@ test('should passthrough is version range is invalid', async t => {
|
||||
});
|
||||
|
||||
test('should pass if client version is allowed', async t => {
|
||||
let res = await app.GET('/guarded/test').set('x-affine-version', '0.20.0');
|
||||
let res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0');
|
||||
|
||||
t.is(res.status, 200);
|
||||
|
||||
res = await app.GET('/guarded/test').set('x-affine-version', '0.21.0');
|
||||
res = await app.GET('/guarded/test').set('x-affine-version', '0.26.0');
|
||||
|
||||
t.is(res.status, 200);
|
||||
|
||||
config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
requiredVersion: '>=0.19.0',
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
res = await app.GET('/guarded/test').set('x-affine-version', '0.19.0');
|
||||
res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0');
|
||||
|
||||
t.is(res.status, 200);
|
||||
});
|
||||
@@ -115,7 +115,7 @@ test('should fail if client version is not set or invalid', async t => {
|
||||
t.is(res.status, 403);
|
||||
t.is(
|
||||
res.body.message,
|
||||
'Unsupported client with version [unset_or_invalid], required version is [>=0.20.0].'
|
||||
'Unsupported client with version [unset_or_invalid], required version is [>=0.25.0].'
|
||||
);
|
||||
|
||||
res = await app.GET('/guarded/test').set('x-affine-version', 'invalid');
|
||||
@@ -123,7 +123,7 @@ test('should fail if client version is not set or invalid', async t => {
|
||||
t.is(res.status, 403);
|
||||
t.is(
|
||||
res.body.message,
|
||||
'Unsupported client with version [invalid], required version is [>=0.20.0].'
|
||||
'Unsupported client with version [invalid], required version is [>=0.25.0].'
|
||||
);
|
||||
});
|
||||
|
||||
@@ -131,17 +131,17 @@ test('should tell upgrade if client version is lower than allowed', async t => {
|
||||
config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
requiredVersion: '>=0.21.0 <=0.22.0',
|
||||
requiredVersion: '>=0.26.0 <=0.27.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
let res = await app.GET('/guarded/test').set('x-affine-version', '0.20.0');
|
||||
let res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0');
|
||||
|
||||
t.is(res.status, 403);
|
||||
t.is(
|
||||
res.body.message,
|
||||
'Unsupported client with version [0.20.0], required version is [>=0.21.0 <=0.22.0].'
|
||||
'Unsupported client with version [0.25.0], required version is [>=0.26.0 <=0.27.0].'
|
||||
);
|
||||
});
|
||||
|
||||
@@ -149,17 +149,17 @@ test('should tell downgrade if client version is higher than allowed', async t =
|
||||
config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
requiredVersion: '>=0.20.0 <=0.22.0',
|
||||
requiredVersion: '>=0.25.0 <=0.26.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
let res = await app.GET('/guarded/test').set('x-affine-version', '0.23.0');
|
||||
let res = await app.GET('/guarded/test').set('x-affine-version', '0.27.0');
|
||||
|
||||
t.is(res.status, 403);
|
||||
t.is(
|
||||
res.body.message,
|
||||
'Unsupported client with version [0.23.0], required version is [>=0.20.0 <=0.22.0].'
|
||||
'Unsupported client with version [0.27.0], required version is [>=0.25.0 <=0.26.0].'
|
||||
);
|
||||
});
|
||||
|
||||
@@ -167,25 +167,25 @@ test('should test prerelease version', async t => {
|
||||
config.override({
|
||||
client: {
|
||||
versionControl: {
|
||||
requiredVersion: '>=0.19.0',
|
||||
requiredVersion: '>=0.25.0',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
let res = await app
|
||||
.GET('/guarded/test')
|
||||
.set('x-affine-version', '0.19.0-canary.1');
|
||||
.set('x-affine-version', '0.25.0-canary.1');
|
||||
|
||||
// 0.19.0-canary.1 is lower than 0.19.0 obviously
|
||||
// 0.25.0-canary.1 is lower than 0.25.0 obviously
|
||||
t.is(res.status, 403);
|
||||
|
||||
res = await app
|
||||
.GET('/guarded/test')
|
||||
.set('x-affine-version', '0.20.0-canary.1');
|
||||
.set('x-affine-version', '0.26.0-canary.1');
|
||||
|
||||
t.is(res.status, 200);
|
||||
|
||||
res = await app.GET('/guarded/test').set('x-affine-version', '0.20.0-beta.2');
|
||||
res = await app.GET('/guarded/test').set('x-affine-version', '0.26.0-beta.2');
|
||||
|
||||
t.is(res.status, 200);
|
||||
});
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import type { ExecutionContext, TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import { LookupAddress } from 'dns';
|
||||
import Sinon from 'sinon';
|
||||
import type { Response } from 'supertest';
|
||||
|
||||
import {
|
||||
__resetDnsLookupForTests,
|
||||
__setDnsLookupForTests,
|
||||
type DnsLookup,
|
||||
} from '../base/utils/ssrf';
|
||||
import { createTestingApp, TestingApp } from './utils';
|
||||
|
||||
type TestContext = {
|
||||
@@ -11,15 +17,30 @@ type TestContext = {
|
||||
|
||||
const test = ava as TestFn<TestContext>;
|
||||
|
||||
/**
 * DNS lookup replacement for tests: every hostname resolves to the fixed IPv4
 * address 76.76.21.21.
 *
 * Returns the whole address list when called with an options object whose
 * `all` flag is truthy, and the single first address otherwise.
 */
const LookupAddressStub = (async (_hostname, options) => {
  const result = [{ address: '76.76.21.21', family: 4 }] as LookupAddress[];
  const isOptions = options && typeof options === 'object';
  if (isOptions && 'all' in options && options.all) {
    return result;
  }
  return result[0];
}) as DnsLookup;
|
||||
|
||||
// Set the deployment type, install the DNS stub, then create the app that
// every test reads from `t.context.app`.
test.before(async t => {
  // @ts-expect-error test
  env.DEPLOYMENT_TYPE = 'selfhosted';

  // Avoid relying on real DNS during tests. SSRF protection uses dns.lookup().
  __setDnsLookupForTests(LookupAddressStub);

  const app = await createTestingApp();

  t.context.app = app;
});

// Undo Sinon stubs and the DNS override, then close the app; runs even when
// tests failed.
test.after.always(async t => {
  Sinon.restore();
  __resetDnsLookupForTests();
  await t.context.app.close();
});
|
||||
|
||||
@@ -29,7 +50,8 @@ const assertAndSnapshotRaw = async (
|
||||
message: string,
|
||||
options?: {
|
||||
status?: number;
|
||||
origin?: string;
|
||||
origin?: string | null;
|
||||
referer?: string | null;
|
||||
method?: 'GET' | 'OPTIONS' | 'POST';
|
||||
body?: any;
|
||||
checker?: (res: Response) => any;
|
||||
@@ -37,16 +59,21 @@ const assertAndSnapshotRaw = async (
|
||||
) => {
|
||||
const {
|
||||
status = 200,
|
||||
origin = 'http://localhost',
|
||||
origin = 'http://localhost:3010',
|
||||
referer,
|
||||
method = 'GET',
|
||||
checker = () => {},
|
||||
} = options || {};
|
||||
const { app } = t.context;
|
||||
const res = app[method](route)
|
||||
.set('Origin', origin)
|
||||
.send(options?.body)
|
||||
.expect(status)
|
||||
.expect(checker);
|
||||
const req = app[method](route);
|
||||
if (origin) {
|
||||
req.set('Origin', origin);
|
||||
}
|
||||
if (referer) {
|
||||
req.set('Referer', referer);
|
||||
}
|
||||
|
||||
const res = req.send(options?.body).expect(status).expect(checker);
|
||||
await t.notThrowsAsync(res, message);
|
||||
t.snapshot((await res).body);
|
||||
};
|
||||
@@ -76,6 +103,14 @@ test('should proxy image', async t => {
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/image-proxy?url=http://example.com/image.png',
|
||||
'should return 400 if origin and referer are missing',
|
||||
{ status: 400, origin: null, referer: null }
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/image-proxy?url=http://example.com/image.png',
|
||||
@@ -86,17 +121,13 @@ test('should proxy image', async t => {
|
||||
|
||||
{
|
||||
const fakeBuffer = Buffer.from('fake image');
|
||||
const fakeResponse = {
|
||||
ok: true,
|
||||
const fakeResponse = new Response(fakeBuffer, {
|
||||
status: 200,
|
||||
headers: {
|
||||
get: (header: string) => {
|
||||
if (header.toLowerCase() === 'content-type') return 'image/png';
|
||||
if (header.toLowerCase() === 'content-disposition') return 'inline';
|
||||
return null;
|
||||
},
|
||||
'content-type': 'image/png',
|
||||
'content-disposition': 'inline',
|
||||
},
|
||||
arrayBuffer: async () => fakeBuffer,
|
||||
} as any;
|
||||
});
|
||||
|
||||
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeResponse);
|
||||
|
||||
@@ -132,6 +163,18 @@ test('should preview link', async t => {
|
||||
{ status: 400, method: 'POST' }
|
||||
);
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should return 400 if origin and referer are missing',
|
||||
{
|
||||
status: 400,
|
||||
method: 'POST',
|
||||
origin: null,
|
||||
referer: null,
|
||||
body: { url: 'http://external.com/page' },
|
||||
}
|
||||
);
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should return 400 if provided URL is from the same origin',
|
||||
|
||||
@@ -141,7 +141,7 @@ test('should override correctly', t => {
|
||||
config: {
|
||||
credentials: {
|
||||
accessKeyId: '1',
|
||||
accessKeySecret: '1',
|
||||
secretAccessKey: '1',
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -169,7 +169,7 @@ test('should override correctly', t => {
|
||||
config: {
|
||||
credentials: {
|
||||
accessKeyId: '1',
|
||||
accessKeySecret: '1',
|
||||
secretAccessKey: '1',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -275,6 +275,26 @@ export const USER_FRIENDLY_ERRORS = {
|
||||
args: { message: 'string' },
|
||||
message: ({ message }) => `HTTP request error, message: ${message}`,
|
||||
},
|
||||
ssrf_blocked_error: {
|
||||
type: 'invalid_input',
|
||||
args: { reason: 'string' },
|
||||
message: ({ reason }) => {
|
||||
switch (reason) {
|
||||
case 'unresolvable_hostname':
|
||||
return 'Failed to resolve hostname';
|
||||
case 'too_many_redirects':
|
||||
return 'Too many redirects';
|
||||
default:
|
||||
return 'Invalid URL';
|
||||
}
|
||||
},
|
||||
},
|
||||
response_too_large_error: {
|
||||
type: 'invalid_input',
|
||||
args: { limitBytes: 'number', receivedBytes: 'number' },
|
||||
message: ({ limitBytes, receivedBytes }) =>
|
||||
`Response too large (${receivedBytes} bytes), limit is ${limitBytes} bytes`,
|
||||
},
|
||||
email_service_not_configured: {
|
||||
type: 'internal_server_error',
|
||||
message: 'Email service is not configured.',
|
||||
|
||||
@@ -54,6 +54,27 @@ export class HttpRequestError extends UserFriendlyError {
|
||||
super('bad_request', 'http_request_error', message, args);
|
||||
}
|
||||
}
|
||||
/** Data payload attached to `SsrfBlockedError`: the `reason` for the block. */
@ObjectType()
class SsrfBlockedErrorDataType {
  @Field() reason!: string
}

/**
 * User-friendly error of type `invalid_input` with name `ssrf_blocked_error`.
 *
 * @param args - Error data (`reason`).
 * @param message - Optional message, or a function building one from `args`.
 */
export class SsrfBlockedError extends UserFriendlyError {
  constructor(args: SsrfBlockedErrorDataType, message?: string | ((args: SsrfBlockedErrorDataType) => string)) {
    super('invalid_input', 'ssrf_blocked_error', message, args);
  }
}
|
||||
/**
 * Data payload attached to `ResponseTooLargeError`: the byte limit and the
 * number of bytes received.
 */
@ObjectType()
class ResponseTooLargeErrorDataType {
  @Field() limitBytes!: number
  @Field() receivedBytes!: number
}

/**
 * User-friendly error of type `invalid_input` with name
 * `response_too_large_error`.
 *
 * @param args - Error data (`limitBytes`, `receivedBytes`).
 * @param message - Optional message, or a function building one from `args`.
 */
export class ResponseTooLargeError extends UserFriendlyError {
  constructor(args: ResponseTooLargeErrorDataType, message?: string | ((args: ResponseTooLargeErrorDataType) => string)) {
    super('invalid_input', 'response_too_large_error', message, args);
  }
}
|
||||
|
||||
export class EmailServiceNotConfigured extends UserFriendlyError {
|
||||
constructor(message?: string) {
|
||||
@@ -1131,6 +1152,8 @@ export enum ErrorNames {
|
||||
BAD_REQUEST,
|
||||
GRAPHQL_BAD_REQUEST,
|
||||
HTTP_REQUEST_ERROR,
|
||||
SSRF_BLOCKED_ERROR,
|
||||
RESPONSE_TOO_LARGE_ERROR,
|
||||
EMAIL_SERVICE_NOT_CONFIGURED,
|
||||
QUERY_TOO_LONG,
|
||||
VALIDATION_ERROR,
|
||||
@@ -1274,5 +1297,5 @@ registerEnumType(ErrorNames, {
|
||||
export const ErrorDataUnionType = createUnionType({
|
||||
name: 'ErrorDataUnion',
|
||||
types: () =>
|
||||
[GraphqlBadRequestDataType, HttpRequestErrorDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const,
|
||||
[GraphqlBadRequestDataType, HttpRequestErrorDataType, SsrfBlockedErrorDataType, ResponseTooLargeErrorDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const,
|
||||
});
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { generateKeyPairSync } from 'node:crypto';
|
||||
|
||||
import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
@@ -7,11 +9,20 @@ const test = ava as TestFn<{
|
||||
crypto: CryptoHelper;
|
||||
}>;
|
||||
|
||||
const privateKey = `-----BEGIN PRIVATE KEY-----
|
||||
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgS3IAkshQuSmFWGpe
|
||||
rGTg2vwaC3LdcvBQlYHHMBYJZMyhRANCAAQXdT/TAh4neNEpd4UqpDIEqWv0XvFo
|
||||
BRJxGsC5I/fetqObdx1+KEjcm8zFU2xLaUTw9IZCu8OslloOjQv4ur0a
|
||||
-----END PRIVATE KEY-----`;
|
||||
/**
 * Generates a fresh EC private key on the `prime256v1` (P-256) curve.
 *
 * Every call creates a new random key pair; only the private half is
 * returned.
 *
 * @returns The private key serialized as a PEM-encoded PKCS#8 string.
 */
function generateTestPrivateKey(): string {
  const { privateKey } = generateKeyPairSync('ec', {
    namedCurve: 'prime256v1',
  });

  // `export` is typed as `string | Buffer`; `toString()` pins it to a string.
  return privateKey.export({ type: 'pkcs8', format: 'pem' }).toString();
}
|
||||
|
||||
const privateKey = generateTestPrivateKey();
|
||||
const privateKey2 = generateTestPrivateKey();
|
||||
|
||||
test.beforeEach(async t => {
|
||||
t.context.crypto = new CryptoHelper({
|
||||
@@ -30,6 +41,21 @@ test('should be able to sign and verify', t => {
|
||||
t.false(t.context.crypto.verify(`${data},fake-signature`));
|
||||
});
|
||||
|
||||
// A signature issued before the private key is swapped must still verify
// after the helper has been told about the new key.
test('should verify signatures across key rotation', t => {
  const data = 'hello world';
  const signatureV1 = t.context.crypto.sign(data);
  t.true(t.context.crypto.verify(signatureV1));

  // Replace the configured key, then notify the helper of the config change.
  (t.context.crypto as any).config.crypto.privateKey = privateKey2;
  t.context.crypto.onConfigChanged({
    updates: { crypto: { privateKey: privateKey2 } },
  } as any);

  const signatureV2 = t.context.crypto.sign(data);
  // Both the pre-rotation and the post-rotation signatures must be accepted.
  t.true(t.context.crypto.verify(signatureV1));
  t.true(t.context.crypto.verify(signatureV2));
});
|
||||
|
||||
test('should same data should get different signature', t => {
|
||||
const data = 'hello world';
|
||||
const signature = t.context.crypto.sign(data);
|
||||
@@ -46,11 +72,12 @@ test('should be able to encrypt and decrypt', t => {
|
||||
);
|
||||
|
||||
const encrypted = t.context.crypto.encrypt(data);
|
||||
const encrypted2 = t.context.crypto.encrypt(data);
|
||||
const decrypted = t.context.crypto.decrypt(encrypted);
|
||||
|
||||
// we are using a stub to make sure the iv is always 0,
|
||||
// the encrypted result will always be the same
|
||||
t.is(encrypted, 'AAAAAAAAAAAAAAAAOXbR/9glITL3BcO3kPd6fGOMasSkPQ==');
|
||||
// the encrypted result will always be the same for the same key+data
|
||||
t.is(encrypted2, encrypted);
|
||||
t.is(decrypted, data);
|
||||
|
||||
stub.restore();
|
||||
@@ -75,6 +102,24 @@ test('should be able to safe compare', t => {
|
||||
t.false(t.context.crypto.compare('abc', 'def'));
|
||||
});
|
||||
|
||||
// Round trip: a token signed with fixed `now`/`nonce` values must parse back
// into the exact payload, with the method stored under `m` and the path
// under `p`.
test('should sign and parse internal access token', t => {
  const token = t.context.crypto.signInternalAccessToken({
    method: 'GET',
    path: '/rpc/workspaces/123/docs/456',
    now: 1700000000000,
    nonce: 'nonce-123',
  });

  const payload = t.context.crypto.parseInternalAccessToken(token);
  t.deepEqual(payload, {
    v: 1,
    ts: 1700000000000,
    nonce: 'nonce-123',
    m: 'GET',
    p: '/rpc/workspaces/123/docs/456',
  });
});
|
||||
|
||||
test('should be able to hash and verify password', async t => {
|
||||
const password = 'mySecurePassword';
|
||||
const hash = await t.context.crypto.encryptPassword(password);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { ActionForbidden } from '../../error';
|
||||
import { URLHelper } from '../url';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
@@ -85,6 +86,30 @@ test('can create link', t => {
|
||||
);
|
||||
});
|
||||
|
||||
// Relative paths and URLs on the app's own origin are accepted as callback
// URLs; a foreign origin is rejected.
test('can validate callbackUrl allowlist', t => {
  t.true(t.context.url.isAllowedCallbackUrl('/magic-link'));
  t.true(
    t.context.url.isAllowedCallbackUrl('https://app.affine.local/magic-link')
  );
  t.false(
    t.context.url.isAllowedCallbackUrl('https://evil.example/magic-link')
  );
});
|
||||
|
||||
// Relative paths and trusted domains are accepted; non-http(s) schemes and
// look-alike hosts (`evilgithub.com` is not a subdomain of `github.com`)
// are rejected.
test('can validate redirect_uri allowlist', t => {
  t.true(t.context.url.isAllowedRedirectUri('/redirect-proxy'));
  t.true(t.context.url.isAllowedRedirectUri('https://github.com'));
  t.false(t.context.url.isAllowedRedirectUri('javascript:alert(1)'));
  t.false(t.context.url.isAllowedRedirectUri('https://evilgithub.com'));
});
|
||||
|
||||
// `safeLink` builds a link for an allowed path and throws `ActionForbidden`
// for a target on a foreign origin.
test('can create safe link', t => {
  t.is(t.context.url.safeLink('/path'), 'https://app.affine.local/path');
  t.throws(() => t.context.url.safeLink('https://evil.example/magic-link'), {
    instanceOf: ActionForbidden,
  });
});
|
||||
|
||||
test('can safe redirect', t => {
|
||||
const res = {
|
||||
redirect: (to: string) => to,
|
||||
|
||||
@@ -76,6 +76,8 @@ export class CryptoHelper implements OnModuleInit {
|
||||
};
|
||||
};
|
||||
|
||||
private previousPublicKeys: KeyObject[] = [];
|
||||
|
||||
AFFiNEProPublicKey: Buffer | null = null;
|
||||
AFFiNEProLicenseAESKey: Buffer | null = null;
|
||||
|
||||
@@ -101,12 +103,23 @@ export class CryptoHelper implements OnModuleInit {
|
||||
}
|
||||
|
||||
private setup() {
|
||||
const prevPublicKey = this.keyPair?.publicKey;
|
||||
const privateKey = this.config.crypto.privateKey || generatePrivateKey();
|
||||
const { priv, pub } = parseKey(privateKey);
|
||||
const publicKey = pub
|
||||
.export({ format: 'pem', type: 'spki' })
|
||||
.toString('utf8');
|
||||
|
||||
if (prevPublicKey) {
|
||||
const prevPem = prevPublicKey
|
||||
.export({ format: 'pem', type: 'spki' })
|
||||
.toString('utf8');
|
||||
if (prevPem !== publicKey) {
|
||||
this.previousPublicKeys.unshift(prevPublicKey);
|
||||
this.previousPublicKeys = this.previousPublicKeys.slice(0, 2);
|
||||
}
|
||||
}
|
||||
|
||||
this.keyPair = {
|
||||
publicKey: pub,
|
||||
privateKey: priv,
|
||||
@@ -143,15 +156,81 @@ export class CryptoHelper implements OnModuleInit {
|
||||
}
|
||||
const input = Buffer.from(data, 'utf-8');
|
||||
const sigBuf = Buffer.from(signature, 'base64');
|
||||
if (this.keyType === 'ed25519') {
|
||||
// Ed25519 verifies the message directly
|
||||
return verify(null, input, this.keyPair.publicKey, sigBuf);
|
||||
} else {
|
||||
// ECDSA with SHA-256
|
||||
const verify = createVerify('sha256');
|
||||
verify.update(input);
|
||||
verify.end();
|
||||
return verify.verify(this.keyPair.publicKey, sigBuf);
|
||||
|
||||
const keys = [this.keyPair.publicKey, ...this.previousPublicKeys];
|
||||
return keys.some(publicKey => {
|
||||
const keyType = (publicKey.asymmetricKeyType as string) || 'ec';
|
||||
if (keyType === 'ed25519') {
|
||||
// Ed25519 verifies the message directly
|
||||
return verify(null, input, publicKey, sigBuf);
|
||||
} else {
|
||||
// ECDSA with SHA-256
|
||||
const verifier = createVerify('sha256');
|
||||
verifier.update(input);
|
||||
verifier.end();
|
||||
return verifier.verify(publicKey, sigBuf);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Builds and signs an internal access token bound to one HTTP method and
 * path.
 *
 * The payload `{ v, ts, nonce, m, p }` is JSON-serialized, base64url-encoded
 * and passed to `this.sign`.
 *
 * @param input.method - HTTP method; stored upper-cased as `m`.
 * @param input.path - Request path; stored verbatim as `p`.
 * @param input.now - Timestamp stored as `ts`; defaults to `Date.now()`.
 * @param input.nonce - Nonce; defaults to 16 random bytes, base64url-encoded.
 * @returns Whatever `this.sign` returns for the encoded payload.
 */
signInternalAccessToken(input: {
  method: string;
  path: string;
  now?: number;
  nonce?: string;
}) {
  const payload = {
    v: 1 as const,
    ts: input.now ?? Date.now(),
    nonce: input.nonce ?? this.randomBytes(16).toString('base64url'),
    m: input.method.toUpperCase(),
    p: input.path,
  };
  const data = Buffer.from(JSON.stringify(payload), 'utf8').toString(
    'base64url'
  );
  return this.sign(data);
}
|
||||
|
||||
/**
 * Verifies an internal access token and decodes its payload.
 *
 * The input is split on `,` into a data part and a signature part. The
 * whole string must pass `this.verify` before the data part is
 * base64url-decoded and parsed as JSON.
 *
 * @param signatureWithData - Token string, data first and signature second.
 * @returns The validated payload, or `null` when the signature is missing or
 *   invalid, the payload is not valid JSON, or any field has the wrong shape.
 */
parseInternalAccessToken(signatureWithData: string): {
  v: 1;
  ts: number;
  nonce: string;
  m: string;
  p: string;
} | null {
  const [data, signature] = signatureWithData.split(',');
  if (!signature) {
    return null;
  }
  // Never decode or trust the payload before the signature checks out.
  if (!this.verify(signatureWithData)) {
    return null;
  }
  try {
    const json = Buffer.from(data, 'base64url').toString('utf8');
    const payload = JSON.parse(json) as unknown;
    if (!payload || typeof payload !== 'object') {
      return null;
    }
    const val = payload as {
      v?: unknown;
      ts?: unknown;
      nonce?: unknown;
      m?: unknown;
      p?: unknown;
    };
    if (
      val.v !== 1 ||
      typeof val.ts !== 'number' ||
      // `JSON.parse` can yield `Infinity` (e.g. from `1e999`); a timestamp
      // must be a finite number.
      !Number.isFinite(val.ts) ||
      typeof val.nonce !== 'string' ||
      typeof val.m !== 'string' ||
      typeof val.p !== 'string'
    ) {
      return null;
    }
    // Return a fresh object so unexpected extra fields are dropped.
    return { v: 1, ts: val.ts, nonce: val.nonce, m: val.m, p: val.p };
  } catch {
    // Malformed JSON is treated as an invalid token.
    return null;
  }
}
|
||||
|
||||
|
||||
@@ -5,8 +5,31 @@ import type { Response } from 'express';
|
||||
import { ClsService } from 'nestjs-cls';
|
||||
|
||||
import { Config } from '../config';
|
||||
import { ActionForbidden } from '../error';
|
||||
import { OnEvent } from '../event';
|
||||
|
||||
// URL protocols (as reported by `URL.protocol`, including the trailing colon)
// that absolute callback/redirect targets may use.
const ALLOWED_REDIRECT_PROTOCOLS = new Set(['http:', 'https:']);
// Keep in sync with frontend /redirect-proxy allowlist.
// Entries are lower-cased so they can be compared with normalized hostnames.
const TRUSTED_REDIRECT_DOMAINS = [
  'google.com',
  'stripe.com',
  'github.com',
  'twitter.com',
  'discord.gg',
  'youtube.com',
  't.me',
  'reddit.com',
  'affine.pro',
].map(d => d.toLowerCase());
|
||||
|
||||
/**
 * Normalizes a hostname for comparison: lower-cases it and removes a single
 * trailing dot (the fully-qualified form, e.g. `example.com.`).
 *
 * @param hostname - Hostname to normalize.
 * @returns The lower-cased hostname without one trailing dot.
 */
function normalizeHostname(hostname: string): string {
  return hostname.toLowerCase().replace(/\.$/, '');
}
|
||||
|
||||
/**
 * Tests whether `hostname` is `domain` itself or one of its subdomains.
 *
 * The subdomain check requires a `.` boundary, so `evilgithub.com` does not
 * match `github.com`. Both arguments are compared as given; callers are
 * expected to normalize case beforehand.
 *
 * @param hostname - Hostname to test.
 * @param domain - Registered domain to match against.
 * @returns `true` on an exact or subdomain match, otherwise `false`.
 */
function hostnameMatchesDomain(hostname: string, domain: string): boolean {
  // An empty domain would otherwise match every hostname ending in `.`.
  if (!domain) {
    return false;
  }
  return hostname === domain || hostname.endsWith(`.${domain}`);
}
|
||||
|
||||
@Injectable()
|
||||
export class URLHelper {
|
||||
redirectAllowHosts!: string[];
|
||||
@@ -110,6 +133,13 @@ export class URLHelper {
|
||||
return this.url(path, query).toString();
|
||||
}
|
||||
|
||||
/**
 * Builds a link like `this.link`, but only for targets accepted by
 * `this.isAllowedCallbackUrl`.
 *
 * @param path - Relative path or absolute URL to link to.
 * @param query - Query parameters forwarded to `this.link`.
 * @returns The link string produced by `this.link`.
 * @throws ActionForbidden when `path` is not an allowed callback target.
 */
safeLink(path: string, query: Record<string, any> = {}) {
  if (!this.isAllowedCallbackUrl(path)) {
    throw new ActionForbidden();
  }
  return this.link(path, query);
}
|
||||
|
||||
safeRedirect(res: Response, to: string) {
|
||||
try {
|
||||
const finalTo = new URL(decodeURIComponent(to), this.requestBaseUrl);
|
||||
@@ -131,6 +161,68 @@ export class URLHelper {
|
||||
return res.redirect(this.baseUrl);
|
||||
}
|
||||
|
||||
/**
 * Decides whether `url` may be used as a callback URL.
 *
 * Accepted targets are:
 * - relative paths that stay on the same origin when resolved, and
 * - absolute `http:`/`https:` URLs without embedded credentials whose
 *   origin is listed in `this.allowedOrigins`.
 *
 * @param url - Candidate callback URL or path.
 * @returns `true` when the target is allowed, otherwise `false`.
 */
isAllowedCallbackUrl(url: string): boolean {
  if (!url) {
    return false;
  }

  // Allow same-app relative paths (e.g. `/magic-link?...`).
  if (url.startsWith('/')) {
    // A prefix check alone is not sufficient: the URL parser treats `\` like
    // `/` and strips tabs/newlines, so `/\evil.example` or `/<TAB>/evil.example`
    // would resolve to another host. Resolve against a placeholder origin
    // and require that the origin is unchanged.
    try {
      const probeOrigin = 'http://relative-path.invalid';
      return new URL(url, probeOrigin).origin === probeOrigin;
    } catch {
      return false;
    }
  }

  try {
    const u = new URL(url);
    if (!ALLOWED_REDIRECT_PROTOCOLS.has(u.protocol)) {
      return false;
    }
    // Reject `user:pass@host` forms, which can disguise the real host.
    if (u.username || u.password) {
      return false;
    }
    return this.allowedOrigins.includes(u.origin);
  } catch {
    // Not parseable as an absolute URL.
    return false;
  }
}
|
||||
|
||||
/**
 * Decides whether `redirectUri` may be used as a redirect target.
 *
 * Accepted targets are:
 * - relative paths that stay on the same origin when resolved,
 * - absolute `http:`/`https:` URLs without embedded credentials whose
 *   hostname equals the hostname of one of `this.allowedOrigins`, and
 * - such URLs whose hostname is one of `TRUSTED_REDIRECT_DOMAINS` or a
 *   subdomain of one.
 *
 * @param redirectUri - Candidate redirect URI or path.
 * @returns `true` when the target is allowed, otherwise `false`.
 */
isAllowedRedirectUri(redirectUri: string): boolean {
  if (!redirectUri) {
    return false;
  }

  // Allow internal navigation (e.g. `/` or `/redirect-proxy?...`).
  if (redirectUri.startsWith('/')) {
    // A prefix check alone is not sufficient: the URL parser treats `\` like
    // `/` and strips tabs/newlines, so `/\evil.example` or `/<TAB>/evil.example`
    // would resolve to another host. Resolve against a placeholder origin
    // and require that the origin is unchanged.
    try {
      const probeOrigin = 'http://relative-path.invalid';
      return new URL(redirectUri, probeOrigin).origin === probeOrigin;
    } catch {
      return false;
    }
  }

  try {
    const u = new URL(redirectUri);
    if (!ALLOWED_REDIRECT_PROTOCOLS.has(u.protocol)) {
      return false;
    }
    // Reject `user:pass@host` forms, which can disguise the real host.
    if (u.username || u.password) {
      return false;
    }

    const hostname = normalizeHostname(u.hostname);

    // Allow server known hosts (hostname comparison only; port and protocol
    // of the allowed origin are not compared here).
    for (const origin of this.allowedOrigins) {
      const allowedHost = normalizeHostname(new URL(origin).hostname);
      if (hostname === allowedHost) {
        return true;
      }
    }

    // Allow known trusted domains (for redirect-proxy).
    return TRUSTED_REDIRECT_DOMAINS.some(domain =>
      hostnameMatchesDomain(hostname, domain)
    );
  } catch {
    // Not parseable as an absolute URL (or an allowed origin is malformed).
    return false;
  }
}
|
||||
|
||||
verify(url: string | URL) {
|
||||
try {
|
||||
if (typeof url === 'string') {
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import { parseListPartsXml } from '@affine/s3-compat';
|
||||
import test from 'ava';
|
||||
|
||||
// Multiple <Part> elements plus pagination fields: the expected result has
// ETag quotes stripped, `isTruncated` set and the next marker kept as a
// string.
test('parseListPartsXml handles array parts and pagination', t => {
  const xml = `<?xml version="1.0" encoding="UTF-8"?>
<ListPartsResult>
  <Bucket>test</Bucket>
  <Key>key</Key>
  <UploadId>upload-id</UploadId>
  <PartNumberMarker>0</PartNumberMarker>
  <NextPartNumberMarker>3</NextPartNumberMarker>
  <IsTruncated>true</IsTruncated>
  <Part>
    <PartNumber>1</PartNumber>
    <ETag>"etag-1"</ETag>
  </Part>
  <Part>
    <PartNumber>2</PartNumber>
    <ETag>etag-2</ETag>
  </Part>
</ListPartsResult>`;

  const result = parseListPartsXml(xml);
  t.deepEqual(result.parts, [
    { partNumber: 1, etag: 'etag-1' },
    { partNumber: 2, etag: 'etag-2' },
  ]);
  t.true(result.isTruncated);
  t.is(result.nextPartNumberMarker, '3');
});
|
||||
|
||||
// A single <Part> element must still come back as a one-element array, and
// a missing <NextPartNumberMarker> must come back as `undefined`.
test('parseListPartsXml handles single part', t => {
  const xml = `<?xml version="1.0" encoding="UTF-8"?>
<ListPartsResult>
  <Bucket>test</Bucket>
  <Key>key</Key>
  <UploadId>upload-id</UploadId>
  <IsTruncated>false</IsTruncated>
  <Part>
    <PartNumber>5</PartNumber>
    <ETag>"etag-5"</ETag>
  </Part>
</ListPartsResult>`;

  const result = parseListPartsXml(xml);
  t.deepEqual(result.parts, [{ partNumber: 5, etag: 'etag-5' }]);
  t.false(result.isTruncated);
  t.is(result.nextPartNumberMarker, undefined);
});
|
||||
@@ -4,7 +4,8 @@ import { S3StorageProvider } from '../providers/s3';
|
||||
import { SIGNED_URL_EXPIRED } from '../providers/utils';
|
||||
|
||||
const config = {
|
||||
region: 'auto',
|
||||
region: 'us-east-1',
|
||||
endpoint: 'https://s3.us-east-1.amazonaws.com',
|
||||
credentials: {
|
||||
accessKeyId: 'test',
|
||||
secretAccessKey: 'test',
|
||||
@@ -24,6 +25,8 @@ test('presignPut should return url and headers', async t => {
|
||||
t.truthy(result);
|
||||
t.true(result!.url.length > 0);
|
||||
t.true(result!.url.includes('X-Amz-Algorithm=AWS4-HMAC-SHA256'));
|
||||
t.true(result!.url.includes('X-Amz-SignedHeaders='));
|
||||
t.true(result!.url.includes('content-type'));
|
||||
t.deepEqual(result!.headers, { 'Content-Type': 'text/plain' });
|
||||
const now = Date.now();
|
||||
t.true(result!.expiresAt.getTime() >= now + SIGNED_URL_EXPIRED * 1000 - 2000);
|
||||
@@ -41,12 +44,15 @@ test('presignUploadPart should return url', async t => {
|
||||
|
||||
test('createMultipartUpload should return uploadId', async t => {
|
||||
const provider = createProvider();
|
||||
let receivedCommand: any;
|
||||
const sendStub = async (command: any) => {
|
||||
receivedCommand = command;
|
||||
return { UploadId: 'upload-1' };
|
||||
let receivedKey: string | undefined;
|
||||
let receivedMeta: any;
|
||||
(provider as any).client = {
|
||||
createMultipartUpload: async (key: string, meta: any) => {
|
||||
receivedKey = key;
|
||||
receivedMeta = meta;
|
||||
return { uploadId: 'upload-1' };
|
||||
},
|
||||
};
|
||||
(provider as any).client = { send: sendStub };
|
||||
|
||||
const now = Date.now();
|
||||
const result = await provider.createMultipartUpload('key', {
|
||||
@@ -56,25 +62,29 @@ test('createMultipartUpload should return uploadId', async t => {
|
||||
t.is(result?.uploadId, 'upload-1');
|
||||
t.true(result!.expiresAt.getTime() >= now + SIGNED_URL_EXPIRED * 1000 - 2000);
|
||||
t.true(result!.expiresAt.getTime() <= now + SIGNED_URL_EXPIRED * 1000 + 2000);
|
||||
t.is(receivedCommand.input.Key, 'key');
|
||||
t.is(receivedCommand.input.ContentType, 'text/plain');
|
||||
t.is(receivedKey, 'key');
|
||||
t.is(receivedMeta.contentType, 'text/plain');
|
||||
});
|
||||
|
||||
test('completeMultipartUpload should order parts', async t => {
|
||||
const provider = createProvider();
|
||||
let called = false;
|
||||
const sendStub = async (command: any) => {
|
||||
called = true;
|
||||
t.deepEqual(command.input.MultipartUpload.Parts, [
|
||||
{ ETag: 'a', PartNumber: 1 },
|
||||
{ ETag: 'b', PartNumber: 2 },
|
||||
]);
|
||||
let receivedParts: any;
|
||||
(provider as any).client = {
|
||||
completeMultipartUpload: async (
|
||||
_key: string,
|
||||
_uploadId: string,
|
||||
parts: any
|
||||
) => {
|
||||
receivedParts = parts;
|
||||
},
|
||||
};
|
||||
(provider as any).client = { send: sendStub };
|
||||
|
||||
await provider.completeMultipartUpload('key', 'upload-1', [
|
||||
{ partNumber: 2, etag: 'b' },
|
||||
{ partNumber: 1, etag: 'a' },
|
||||
]);
|
||||
t.true(called);
|
||||
t.deepEqual(receivedParts, [
|
||||
{ partNumber: 1, etag: 'a' },
|
||||
{ partNumber: 2, etag: 'b' },
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -33,9 +33,44 @@ export type StorageProviderConfig = { bucket: string } & (
|
||||
|
||||
const S3ConfigSchema: JSONSchema = {
|
||||
type: 'object',
|
||||
description:
|
||||
'The config for the s3 compatible storage provider. directly passed to aws-sdk client.\n@link https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/S3.html',
|
||||
description: 'The config for the S3 compatible storage provider.',
|
||||
properties: {
|
||||
endpoint: {
|
||||
type: 'string',
|
||||
description:
|
||||
'The S3 compatible endpoint. Example: "https://s3.us-east-1.amazonaws.com" or "https://<account>.r2.cloudflarestorage.com".',
|
||||
},
|
||||
region: {
|
||||
type: 'string',
|
||||
description:
|
||||
'The region for the storage provider. Example: "us-east-1" or "auto" for R2.',
|
||||
},
|
||||
forcePathStyle: {
|
||||
type: 'boolean',
|
||||
description: 'Whether to use path-style bucket addressing.',
|
||||
},
|
||||
requestTimeoutMs: {
|
||||
type: 'number',
|
||||
description: 'Request timeout in milliseconds.',
|
||||
},
|
||||
minPartSize: {
|
||||
type: 'number',
|
||||
description: 'Minimum multipart part size in bytes.',
|
||||
},
|
||||
presign: {
|
||||
type: 'object',
|
||||
description: 'Presigned URL behavior configuration.',
|
||||
properties: {
|
||||
expiresInSeconds: {
|
||||
type: 'number',
|
||||
description: 'Expiration time in seconds for presigned URLs.',
|
||||
},
|
||||
signContentTypeForPut: {
|
||||
type: 'boolean',
|
||||
description: 'Whether to sign Content-Type for presigned PUT.',
|
||||
},
|
||||
},
|
||||
},
|
||||
credentials: {
|
||||
type: 'object',
|
||||
description: 'The credentials for the s3 compatible storage provider.',
|
||||
@@ -46,6 +81,9 @@ const S3ConfigSchema: JSONSchema = {
|
||||
secretAccessKey: {
|
||||
type: 'string',
|
||||
},
|
||||
sessionToken: {
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import assert from 'node:assert';
|
||||
import { Readable } from 'node:stream';
|
||||
|
||||
import { PutObjectCommand, UploadPartCommand } from '@aws-sdk/client-s3';
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
@@ -39,9 +38,6 @@ export class R2StorageProvider extends S3StorageProvider {
|
||||
...config,
|
||||
forcePathStyle: true,
|
||||
endpoint: `https://${config.accountId}.r2.cloudflarestorage.com`,
|
||||
// see https://github.com/aws/aws-sdk-js-v3/issues/6810
|
||||
requestChecksumCalculation: 'WHEN_REQUIRED',
|
||||
responseChecksumValidation: 'WHEN_REQUIRED',
|
||||
},
|
||||
bucket
|
||||
);
|
||||
@@ -179,15 +175,10 @@ export class R2StorageProvider extends S3StorageProvider {
|
||||
body: Readable | Buffer | Uint8Array | string,
|
||||
options: { contentType?: string; contentLength?: number } = {}
|
||||
) {
|
||||
return this.client.send(
|
||||
new PutObjectCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
Body: body,
|
||||
ContentType: options.contentType,
|
||||
ContentLength: options.contentLength,
|
||||
})
|
||||
);
|
||||
return this.client.putObject(key, body as any, {
|
||||
contentType: options.contentType,
|
||||
contentLength: options.contentLength,
|
||||
});
|
||||
}
|
||||
|
||||
async proxyUploadPart(
|
||||
@@ -197,18 +188,15 @@ export class R2StorageProvider extends S3StorageProvider {
|
||||
body: Readable | Buffer | Uint8Array | string,
|
||||
options: { contentLength?: number } = {}
|
||||
) {
|
||||
const result = await this.client.send(
|
||||
new UploadPartCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
PartNumber: partNumber,
|
||||
Body: body,
|
||||
ContentLength: options.contentLength,
|
||||
})
|
||||
const result = await this.client.uploadPart(
|
||||
key,
|
||||
uploadId,
|
||||
partNumber,
|
||||
body as any,
|
||||
{ contentLength: options.contentLength }
|
||||
);
|
||||
|
||||
return result.ETag;
|
||||
return result.etag;
|
||||
}
|
||||
|
||||
override async get(
|
||||
|
||||
@@ -1,24 +1,12 @@
|
||||
/* oxlint-disable @typescript-eslint/no-non-null-assertion */
|
||||
import { Readable } from 'node:stream';
|
||||
|
||||
import {
|
||||
AbortMultipartUploadCommand,
|
||||
CompleteMultipartUploadCommand,
|
||||
CreateMultipartUploadCommand,
|
||||
DeleteObjectCommand,
|
||||
GetObjectCommand,
|
||||
HeadObjectCommand,
|
||||
ListObjectsV2Command,
|
||||
ListPartsCommand,
|
||||
NoSuchKey,
|
||||
NoSuchUpload,
|
||||
NotFound,
|
||||
PutObjectCommand,
|
||||
S3Client,
|
||||
S3ClientConfig,
|
||||
UploadPartCommand,
|
||||
} from '@aws-sdk/client-s3';
|
||||
import { getSignedUrl } from '@aws-sdk/s3-request-presigner';
|
||||
import type {
|
||||
S3CompatClient,
|
||||
S3CompatConfig,
|
||||
S3CompatCredentials,
|
||||
} from '@affine/s3-compat';
|
||||
import { createS3CompatClient } from '@affine/s3-compat';
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
@@ -33,30 +21,55 @@ import {
|
||||
} from './provider';
|
||||
import { autoMetadata, SIGNED_URL_EXPIRED, toBuffer } from './utils';
|
||||
|
||||
export interface S3StorageConfig extends S3ClientConfig {
|
||||
export interface S3StorageConfig {
|
||||
endpoint?: string;
|
||||
region: string;
|
||||
credentials: S3CompatCredentials;
|
||||
forcePathStyle?: boolean;
|
||||
requestTimeoutMs?: number;
|
||||
minPartSize?: number;
|
||||
presign?: {
|
||||
expiresInSeconds?: number;
|
||||
signContentTypeForPut?: boolean;
|
||||
};
|
||||
usePresignedURL?: {
|
||||
enabled: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Picks the endpoint URL for the storage client.
 *
 * An explicitly configured `endpoint` always wins. Otherwise an AWS S3
 * endpoint is derived from `region`: `us-east-1` maps to the global
 * `https://s3.amazonaws.com`, any other region to
 * `https://s3.<region>.amazonaws.com`.
 *
 * @param config - Storage config; only `endpoint` and `region` are read.
 * @returns The endpoint URL as a string.
 */
function resolveEndpoint(config: S3StorageConfig): string {
  // An empty string counts as "not configured" and falls through.
  if (config.endpoint) {
    return config.endpoint;
  }
  if (config.region === 'us-east-1') {
    return 'https://s3.amazonaws.com';
  }
  return `https://s3.${config.region}.amazonaws.com`;
}
|
||||
|
||||
export class S3StorageProvider implements StorageProvider {
|
||||
protected logger: Logger;
|
||||
protected client: S3Client;
|
||||
protected client: S3CompatClient;
|
||||
private readonly usePresignedURL: boolean;
|
||||
|
||||
constructor(
|
||||
config: S3StorageConfig,
|
||||
public readonly bucket: string
|
||||
) {
|
||||
const { usePresignedURL, ...clientConfig } = config;
|
||||
this.client = new S3Client({
|
||||
region: 'auto',
|
||||
// s3 client uses keep-alive by default to accelerate requests, and max requests queue is 50.
|
||||
// If some of them are long holding or dead without response, the whole queue will block.
|
||||
// By default no timeout is set for requests or connections, so we set them here.
|
||||
requestHandler: { requestTimeout: 60_000, connectionTimeout: 10_000 },
|
||||
const { usePresignedURL, presign, credentials, ...clientConfig } = config;
|
||||
|
||||
const compatConfig: S3CompatConfig = {
|
||||
...clientConfig,
|
||||
});
|
||||
endpoint: resolveEndpoint(config),
|
||||
bucket,
|
||||
requestTimeoutMs: clientConfig.requestTimeoutMs ?? 60_000,
|
||||
presign: {
|
||||
expiresInSeconds: presign?.expiresInSeconds ?? SIGNED_URL_EXPIRED,
|
||||
signContentTypeForPut: presign?.signContentTypeForPut ?? true,
|
||||
},
|
||||
};
|
||||
|
||||
this.client = createS3CompatClient(compatConfig, credentials);
|
||||
this.usePresignedURL = usePresignedURL?.enabled ?? false;
|
||||
this.logger = new Logger(`${S3StorageProvider.name}:${bucket}`);
|
||||
}
|
||||
@@ -71,19 +84,10 @@ export class S3StorageProvider implements StorageProvider {
|
||||
metadata = autoMetadata(blob, metadata);
|
||||
|
||||
try {
|
||||
await this.client.send(
|
||||
new PutObjectCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
Body: blob,
|
||||
|
||||
// metadata
|
||||
ContentType: metadata.contentType,
|
||||
ContentLength: metadata.contentLength,
|
||||
// TODO(@forehalo): Cloudflare doesn't support CRC32, use md5 instead later.
|
||||
// ChecksumCRC32: metadata.checksumCRC32,
|
||||
})
|
||||
);
|
||||
await this.client.putObject(key, blob, {
|
||||
contentType: metadata.contentType,
|
||||
contentLength: metadata.contentLength,
|
||||
});
|
||||
|
||||
this.logger.verbose(`Object \`${key}\` put`);
|
||||
} catch (e) {
|
||||
@@ -104,20 +108,12 @@ export class S3StorageProvider implements StorageProvider {
|
||||
): Promise<PresignedUpload | undefined> {
|
||||
try {
|
||||
const contentType = metadata.contentType ?? 'application/octet-stream';
|
||||
const url = await getSignedUrl(
|
||||
this.client,
|
||||
new PutObjectCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
ContentType: contentType,
|
||||
}),
|
||||
{ expiresIn: SIGNED_URL_EXPIRED }
|
||||
);
|
||||
const result = await this.client.presignPutObject(key, { contentType });
|
||||
|
||||
return {
|
||||
url,
|
||||
headers: { 'Content-Type': contentType },
|
||||
expiresAt: new Date(Date.now() + SIGNED_URL_EXPIRED * 1000),
|
||||
url: result.url,
|
||||
headers: result.headers,
|
||||
expiresAt: result.expiresAt,
|
||||
};
|
||||
} catch (e) {
|
||||
this.logger.error(
|
||||
@@ -137,20 +133,16 @@ export class S3StorageProvider implements StorageProvider {
|
||||
): Promise<MultipartUploadInit | undefined> {
|
||||
try {
|
||||
const contentType = metadata.contentType ?? 'application/octet-stream';
|
||||
const response = await this.client.send(
|
||||
new CreateMultipartUploadCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
ContentType: contentType,
|
||||
})
|
||||
);
|
||||
const response = await this.client.createMultipartUpload(key, {
|
||||
contentType,
|
||||
});
|
||||
|
||||
if (!response.UploadId) {
|
||||
if (!response.uploadId) {
|
||||
return;
|
||||
}
|
||||
|
||||
return {
|
||||
uploadId: response.UploadId,
|
||||
uploadId: response.uploadId,
|
||||
expiresAt: new Date(Date.now() + SIGNED_URL_EXPIRED * 1000),
|
||||
};
|
||||
} catch (e) {
|
||||
@@ -171,20 +163,15 @@ export class S3StorageProvider implements StorageProvider {
|
||||
partNumber: number
|
||||
): Promise<PresignedUpload | undefined> {
|
||||
try {
|
||||
const url = await getSignedUrl(
|
||||
this.client,
|
||||
new UploadPartCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
PartNumber: partNumber,
|
||||
}),
|
||||
{ expiresIn: SIGNED_URL_EXPIRED }
|
||||
const result = await this.client.presignUploadPart(
|
||||
key,
|
||||
uploadId,
|
||||
partNumber
|
||||
);
|
||||
|
||||
return {
|
||||
url,
|
||||
expiresAt: new Date(Date.now() + SIGNED_URL_EXPIRED * 1000),
|
||||
url: result.url,
|
||||
expiresAt: result.expiresAt,
|
||||
};
|
||||
} catch (e) {
|
||||
this.logger.error(
|
||||
@@ -198,47 +185,9 @@ export class S3StorageProvider implements StorageProvider {
|
||||
key: string,
|
||||
uploadId: string
|
||||
): Promise<MultipartUploadPart[] | undefined> {
|
||||
const parts: MultipartUploadPart[] = [];
|
||||
let partNumberMarker: string | undefined;
|
||||
|
||||
try {
|
||||
// ListParts is paginated by part number marker
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListParts.html
|
||||
// R2 follows S3 semantics here.
|
||||
while (true) {
|
||||
const response = await this.client.send(
|
||||
new ListPartsCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
PartNumberMarker: partNumberMarker,
|
||||
})
|
||||
);
|
||||
|
||||
for (const part of response.Parts ?? []) {
|
||||
if (!part.PartNumber || !part.ETag) {
|
||||
continue;
|
||||
}
|
||||
parts.push({ partNumber: part.PartNumber, etag: part.ETag });
|
||||
}
|
||||
|
||||
if (!response.IsTruncated) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (response.NextPartNumberMarker === undefined) {
|
||||
break;
|
||||
}
|
||||
|
||||
partNumberMarker = response.NextPartNumberMarker;
|
||||
}
|
||||
|
||||
return parts;
|
||||
return await this.client.listParts(key, uploadId);
|
||||
} catch (e) {
|
||||
// the upload may have been aborted/expired by provider lifecycle rules
|
||||
if (e instanceof NoSuchUpload || e instanceof NotFound) {
|
||||
return undefined;
|
||||
}
|
||||
this.logger.error(`Failed to list multipart upload parts for \`${key}\``);
|
||||
throw e;
|
||||
}
|
||||
@@ -254,19 +203,7 @@ export class S3StorageProvider implements StorageProvider {
|
||||
(left, right) => left.partNumber - right.partNumber
|
||||
);
|
||||
|
||||
await this.client.send(
|
||||
new CompleteMultipartUploadCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
MultipartUpload: {
|
||||
Parts: orderedParts.map(part => ({
|
||||
ETag: part.etag,
|
||||
PartNumber: part.partNumber,
|
||||
})),
|
||||
},
|
||||
})
|
||||
);
|
||||
await this.client.completeMultipartUpload(key, uploadId, orderedParts);
|
||||
} catch (e) {
|
||||
this.logger.error(`Failed to complete multipart upload for \`${key}\``);
|
||||
throw e;
|
||||
@@ -275,13 +212,7 @@ export class S3StorageProvider implements StorageProvider {
|
||||
|
||||
async abortMultipartUpload(key: string, uploadId: string): Promise<void> {
|
||||
try {
|
||||
await this.client.send(
|
||||
new AbortMultipartUploadCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
})
|
||||
);
|
||||
await this.client.abortMultipartUpload(key, uploadId);
|
||||
} catch (e) {
|
||||
this.logger.error(`Failed to abort multipart upload for \`${key}\``);
|
||||
throw e;
|
||||
@@ -290,25 +221,19 @@ export class S3StorageProvider implements StorageProvider {
|
||||
|
||||
async head(key: string) {
|
||||
try {
|
||||
const obj = await this.client.send(
|
||||
new HeadObjectCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
})
|
||||
);
|
||||
|
||||
return {
|
||||
contentType: obj.ContentType!,
|
||||
contentLength: obj.ContentLength!,
|
||||
lastModified: obj.LastModified!,
|
||||
checksumCRC32: obj.ChecksumCRC32,
|
||||
};
|
||||
} catch (e) {
|
||||
// 404
|
||||
if (e instanceof NoSuchKey || e instanceof NotFound) {
|
||||
const obj = await this.client.headObject(key);
|
||||
if (!obj) {
|
||||
this.logger.verbose(`Object \`${key}\` not found`);
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
contentType: obj.contentType ?? 'application/octet-stream',
|
||||
contentLength: obj.contentLength ?? 0,
|
||||
lastModified: obj.lastModified ?? new Date(0),
|
||||
checksumCRC32: obj.checksumCRC32,
|
||||
};
|
||||
} catch (e) {
|
||||
this.logger.error(`Failed to head object \`${key}\``);
|
||||
throw e;
|
||||
}
|
||||
@@ -323,25 +248,13 @@ export class S3StorageProvider implements StorageProvider {
|
||||
redirectUrl?: string;
|
||||
}> {
|
||||
try {
|
||||
const command = new GetObjectCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
});
|
||||
|
||||
if (this.usePresignedURL && signedUrl) {
|
||||
const metadata = await this.head(key);
|
||||
if (metadata) {
|
||||
const url = await getSignedUrl(
|
||||
this.client,
|
||||
new GetObjectCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
}),
|
||||
{ expiresIn: SIGNED_URL_EXPIRED }
|
||||
);
|
||||
const result = await this.client.presignGetObject(key);
|
||||
|
||||
return {
|
||||
redirectUrl: url,
|
||||
redirectUrl: result.url,
|
||||
metadata,
|
||||
};
|
||||
}
|
||||
@@ -350,68 +263,41 @@ export class S3StorageProvider implements StorageProvider {
|
||||
return {};
|
||||
}
|
||||
|
||||
const obj = await this.client.send(command);
|
||||
|
||||
if (!obj.Body) {
|
||||
const obj = await this.client.getObjectResponse(key);
|
||||
if (!obj || !obj.body) {
|
||||
this.logger.verbose(`Object \`${key}\` not found`);
|
||||
return {};
|
||||
}
|
||||
|
||||
const contentType = obj.headers.get('content-type') ?? undefined;
|
||||
const contentLengthHeader = obj.headers.get('content-length');
|
||||
const contentLength = contentLengthHeader
|
||||
? Number(contentLengthHeader)
|
||||
: undefined;
|
||||
const lastModifiedHeader = obj.headers.get('last-modified');
|
||||
const lastModified = lastModifiedHeader
|
||||
? new Date(lastModifiedHeader)
|
||||
: undefined;
|
||||
|
||||
this.logger.verbose(`Read object \`${key}\``);
|
||||
return {
|
||||
// @ts-expect-errors ignore browser response type `Blob`
|
||||
body: obj.Body,
|
||||
body: Readable.fromWeb(obj.body as any),
|
||||
metadata: {
|
||||
// always set when putting object
|
||||
contentType: obj.ContentType ?? 'application/octet-stream',
|
||||
contentLength: obj.ContentLength!,
|
||||
lastModified: obj.LastModified!,
|
||||
checksumCRC32: obj.ChecksumCRC32,
|
||||
contentType: contentType ?? 'application/octet-stream',
|
||||
contentLength: contentLength ?? 0,
|
||||
lastModified: lastModified ?? new Date(0),
|
||||
checksumCRC32: obj.headers.get('x-amz-checksum-crc32') ?? undefined,
|
||||
},
|
||||
};
|
||||
} catch (e) {
|
||||
// 404
|
||||
if (e instanceof NoSuchKey) {
|
||||
this.logger.verbose(`Object \`${key}\` not found`);
|
||||
return {};
|
||||
}
|
||||
this.logger.error(`Failed to read object \`${key}\``);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
async list(prefix?: string): Promise<ListObjectsMetadata[]> {
|
||||
// continuationToken should be `string | undefined`,
|
||||
// but TypeScript will fail on type infer in the code below.
|
||||
// Seems to be a bug in TypeScript
|
||||
let continuationToken: any = undefined;
|
||||
let hasMore = true;
|
||||
let result: ListObjectsMetadata[] = [];
|
||||
|
||||
try {
|
||||
while (hasMore) {
|
||||
const listResult = await this.client.send(
|
||||
new ListObjectsV2Command({
|
||||
Bucket: this.bucket,
|
||||
Prefix: prefix,
|
||||
ContinuationToken: continuationToken,
|
||||
})
|
||||
);
|
||||
|
||||
if (listResult.Contents?.length) {
|
||||
result = result.concat(
|
||||
listResult.Contents.map(r => ({
|
||||
key: r.Key!,
|
||||
lastModified: r.LastModified!,
|
||||
contentLength: r.Size!,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
// has more items not listed
|
||||
hasMore = !!listResult.IsTruncated;
|
||||
continuationToken = listResult.NextContinuationToken;
|
||||
}
|
||||
const result = await this.client.listObjectsV2(prefix);
|
||||
|
||||
this.logger.verbose(
|
||||
`List ${result.length} objects with prefix \`${prefix}\``
|
||||
@@ -425,12 +311,7 @@ export class S3StorageProvider implements StorageProvider {
|
||||
|
||||
async delete(key: string): Promise<void> {
|
||||
try {
|
||||
await this.client.send(
|
||||
new DeleteObjectCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
})
|
||||
);
|
||||
await this.client.deleteObject(key);
|
||||
|
||||
this.logger.verbose(`Deleted object \`${key}\``);
|
||||
} catch (e) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
export * from './duration';
|
||||
export * from './promise';
|
||||
export * from './request';
|
||||
export * from './ssrf';
|
||||
export * from './stream';
|
||||
export * from './types';
|
||||
export * from './unit';
|
||||
|
||||
364
packages/backend/server/src/base/utils/ssrf.ts
Normal file
364
packages/backend/server/src/base/utils/ssrf.ts
Normal file
@@ -0,0 +1,364 @@
|
||||
import * as dns from 'node:dns/promises';
|
||||
import { BlockList, isIP } from 'node:net';
|
||||
import { Readable } from 'node:stream';
|
||||
|
||||
import { ResponseTooLargeError, SsrfBlockedError } from '../error/errors.gen';
|
||||
import { OneMinute } from './unit';
|
||||
|
||||
/** URL schemes accepted when the caller does not pass `allowedProtocols`. */
const DEFAULT_ALLOWED_PROTOCOLS = new Set<string>(['http:', 'https:']);

/** Addresses and subnets that are refused; filled in at module load below. */
const BLOCKED_IPS = new BlockList();

/** IPv6 ranges that are eligible at all; IPv6 outside of these is refused. */
const ALLOWED_IPV6 = new BlockList();
|
||||
|
||||
/** Signature of the promise-based `dns.lookup`; types the injectable resolver. */
export type DnsLookup = typeof dns.lookup;

/** Resolver used for hostname lookups in this module; replaceable from tests. */
let dnsLookup: DnsLookup = dns.lookup;

/**
 * Test hook: make this module resolve hostnames through `lookup` instead of
 * the real resolver.
 */
export function __setDnsLookupForTests(lookup: DnsLookup): void {
  dnsLookup = lookup;
}

/** Test hook: switch back to the real `dns.lookup`. */
export function __resetDnsLookupForTests(): void {
  dnsLookup = dns.lookup;
}
|
||||
|
||||
/** Machine-readable reason attached to every SSRF rejection. */
export type SSRFBlockReason =
  | 'invalid_url'
  | 'disallowed_protocol'
  | 'url_has_credentials'
  | 'blocked_hostname'
  | 'unresolvable_hostname'
  | 'blocked_ip'
  | 'too_many_redirects';

/** Optional diagnostic details recorded alongside a rejection. */
type SsrfErrorContext = {
  url?: string;
  hostname?: string;
  address?: string;
};
|
||||
|
||||
/**
 * Build the error thrown for a rejected URL.
 *
 * @param reason  Why the URL was rejected; passed to `SsrfBlockedError`.
 * @param context Optional details (url / hostname / address) stored on the
 *                error object as `context`.
 * @returns The error, ready to be thrown.
 */
function createSsrfBlockedError(
  reason: SSRFBlockReason,
  context?: SsrfErrorContext
) {
  const err = new SsrfBlockedError({ reason });
  // For logging/debugging only (not part of UserFriendlyError JSON).
  Object.assign(err, { context });
  return err;
}
|
||||
|
||||
/** Knobs for the URL safety check. */
export interface SSRFProtectionOptions {
  /**
   * URL schemes (including the trailing colon, e.g. `https:`) that may be
   * fetched. When omitted, only `http:` and `https:` are accepted.
   */
  allowedProtocols?: ReadonlySet<string>;
  /**
   * Allow fetching private/reserved IPs when URL.origin is allowlisted.
   * Defaults to an empty allowlist (i.e. private IPs are blocked).
   */
  allowPrivateOrigins?: ReadonlySet<string>;
}
|
||||
|
||||
/**
 * Drop an IPv6 zone identifier (`%eth0` in `fe80::1%eth0`) so the remainder
 * can be handed to `isIP` / `BlockList`.
 *
 * @param address Address text, with or without a `%zone` suffix.
 * @returns Everything before the first `%`, or the input when there is none.
 */
function stripZoneId(address: string) {
  const [withoutZone] = address.split('%', 1);
  return withoutZone;
}
|
||||
|
||||
// IPv4: RFC1918 + loopback + link-local + CGNAT + special/reserved
for (const [network, prefix] of [
  ['0.0.0.0', 8],
  ['10.0.0.0', 8],
  ['127.0.0.0', 8],
  ['169.254.0.0', 16],
  ['172.16.0.0', 12],
  ['192.168.0.0', 16],
  ['100.64.0.0', 10], // CGNAT
  ['192.0.0.0', 24],
  ['192.0.2.0', 24], // TEST-NET-1
  ['198.51.100.0', 24], // TEST-NET-2
  ['203.0.113.0', 24], // TEST-NET-3
  ['198.18.0.0', 15], // benchmark
  ['192.88.99.0', 24], // 6to4 relay
  ['224.0.0.0', 4], // multicast
  ['240.0.0.0', 4], // reserved (includes broadcast)
] as const) {
  BLOCKED_IPS.addSubnet(network, prefix, 'ipv4');
}

// IPv6: block loopback/unspecified/link-local/ULA/multicast/doc; allow only global unicast.
for (const address of ['::', '::1']) {
  BLOCKED_IPS.addAddress(address, 'ipv6');
}
for (const [network, prefix] of [
  ['ff00::', 8], // multicast
  ['fc00::', 7], // unique local
  ['fe80::', 10], // link-local
  ['2001:db8::', 32], // documentation
] as const) {
  BLOCKED_IPS.addSubnet(network, prefix, 'ipv6');
}
ALLOWED_IPV6.addSubnet('2000::', 3, 'ipv6'); // global unicast
|
||||
|
||||
/**
 * Pull a dotted-quad IPv4 tail out of an IPv6 text form such as
 * `::ffff:127.0.0.1`.
 *
 * @param address IPv6 address text.
 * @returns The text after the last `:` when it is a valid IPv4 address;
 *          `null` when the input has no `.`, no `:`, or the tail is not IPv4.
 */
function extractEmbeddedIPv4FromIPv6(address: string): string | null {
  const lastColon = address.lastIndexOf(':');
  if (lastColon === -1 || !address.includes('.')) {
    return null;
  }
  const candidate = address.slice(lastColon + 1);
  return isIP(candidate) === 4 ? candidate : null;
}
|
||||
|
||||
/**
 * Decide whether an IP address must be refused.
 *
 * - Text that is not a valid IP address is refused.
 * - IPv4 is refused when it is in `BLOCKED_IPS`.
 * - IPv6 must be inside `ALLOWED_IPV6` and outside `BLOCKED_IPS`.
 * - IPv6 with a dotted IPv4 tail: a blocked tail always refuses the address.
 *   A public tail only clears the address when it is written as an
 *   IPv4-mapped address (`::ffff:a.b.c.d`). Any other prefix (for example
 *   `fe80::8.8.8.8` or `fc00::1.2.3.4`) still goes through the IPv6 checks,
 *   so a link-local or unique-local address cannot pass as public just
 *   because its last 32 bits are written in dotted form.
 *
 * @param address IP address text; an IPv6 `%zone` suffix is ignored.
 * @returns `true` when the address must not be contacted.
 */
function isBlockedIpAddress(address: string): boolean {
  const ip = stripZoneId(address);
  const family = isIP(ip);

  if (family === 4) {
    return BLOCKED_IPS.check(ip, 'ipv4');
  }

  if (family !== 6) {
    // Not an IP address at all: fail closed.
    return true;
  }

  const embeddedV4 = extractEmbeddedIPv4FromIPv6(ip);
  if (embeddedV4) {
    if (isBlockedIpAddress(embeddedV4)) {
      return true;
    }
    // Only the IPv4-mapped form means "this is really that IPv4 host".
    if (ip.toLowerCase() === `::ffff:${embeddedV4}`) {
      return false;
    }
    // Any other prefix is judged as the IPv6 address it is.
  }

  if (!ALLOWED_IPV6.check(ip, 'ipv6')) {
    return true;
  }
  return BLOCKED_IPS.check(ip, 'ipv6');
}
|
||||
|
||||
/**
 * Resolve a hostname to every address it maps to.
 *
 * `localhost` and any `*.localhost` name (case-insensitive) are answered
 * with the two loopback addresses without consulting DNS. Everything else
 * goes through the module's `dnsLookup` with `all: true, verbatim: true`.
 *
 * @param hostname Host name to resolve.
 * @returns The resolved address strings.
 */
async function resolveHostAddresses(hostname: string): Promise<string[]> {
  // Normalize common localhost aliases without DNS.
  const name = hostname.toLowerCase();
  const isLocalhostAlias = name === 'localhost' || name.endsWith('.localhost');
  if (isLocalhostAlias) {
    return ['127.0.0.1', '::1'];
  }

  const records = await dnsLookup(hostname, { all: true, verbatim: true });
  return records.map(record => record.address);
}
|
||||
|
||||
/**
 * Validate that a URL is safe to fetch from the server.
 *
 * Checks, in order: the URL parses; its protocol is allowed; it carries no
 * username/password; it has a hostname; and the host does not point at a
 * blocked address. An IP-literal host is checked directly; a named host is
 * resolved and every resolved address is checked.
 *
 * IPv6 literals: `URL.hostname` returns them wrapped in brackets (`[::1]`),
 * which `isIP` does not accept. The brackets are removed before the literal
 * check so IPv6 hosts are classified as addresses instead of being sent to
 * DNS and rejected as unresolvable.
 *
 * @param rawUrl  URL text or `URL` object. A `URL` object is returned as-is.
 * @param options Allowed protocols and origins exempt from the IP block list.
 * @returns The parsed `URL` when every check passes.
 * @throws The error from `createSsrfBlockedError` with reason `invalid_url`,
 *         `disallowed_protocol`, `url_has_credentials`, `blocked_hostname`,
 *         `unresolvable_hostname` or `blocked_ip`.
 */
export async function assertSsrFSafeUrl(
  rawUrl: string | URL,
  options: SSRFProtectionOptions = {}
): Promise<URL> {
  const allowedProtocols =
    options.allowedProtocols ?? DEFAULT_ALLOWED_PROTOCOLS;

  let url: URL;
  try {
    url = rawUrl instanceof URL ? rawUrl : new URL(rawUrl);
  } catch {
    throw createSsrfBlockedError('invalid_url', {
      url: typeof rawUrl === 'string' ? rawUrl : undefined,
    });
  }

  if (!allowedProtocols.has(url.protocol)) {
    throw createSsrfBlockedError('disallowed_protocol', {
      url: url.toString(),
    });
  }

  if (url.username || url.password) {
    throw createSsrfBlockedError('url_has_credentials', {
      url: url.toString(),
    });
  }

  const hostname = url.hostname;
  if (!hostname) {
    throw createSsrfBlockedError('blocked_hostname', { url: url.toString() });
  }

  const allowPrivate = options.allowPrivateOrigins?.has(url.origin) ?? false;

  // IP literal. IPv6 hosts arrive as "[addr]"; unwrap before classifying.
  const literal =
    hostname.startsWith('[') && hostname.endsWith(']')
      ? hostname.slice(1, -1)
      : hostname;
  if (isIP(literal)) {
    if (isBlockedIpAddress(literal) && !allowPrivate) {
      throw createSsrfBlockedError('blocked_ip', {
        url: url.toString(),
        address: literal,
      });
    }
    return url;
  }

  let addresses: string[];
  try {
    addresses = await resolveHostAddresses(hostname);
  } catch {
    // Resolution failure is treated the same as "resolves to nothing".
    addresses = [];
  }

  if (addresses.length === 0) {
    throw createSsrfBlockedError('unresolvable_hostname', {
      url: url.toString(),
      hostname,
    });
  }

  // One blocked address is enough to reject the whole host.
  for (const address of addresses) {
    if (isBlockedIpAddress(address) && !allowPrivate) {
      throw createSsrfBlockedError('blocked_ip', {
        url: url.toString(),
        hostname,
        address,
      });
    }
  }

  return url;
}
|
||||
|
||||
/** Options for `safeFetch`: the SSRF checks plus request limits. */
export interface SafeFetchOptions extends SSRFProtectionOptions {
  /** Overall time budget for the request in milliseconds. */
  timeoutMs?: number;
  /** How many redirects may be followed before the request is rejected. */
  maxRedirects?: number;
}
|
||||
|
||||
/** Headers that carry credentials and must not follow a redirect to another origin. */
const CROSS_ORIGIN_STRIPPED_HEADERS = [
  'authorization',
  'cookie',
  'proxy-authorization',
];

/** Headers that describe a request body; meaningless once the body is dropped. */
const REQUEST_BODY_HEADERS = [
  'content-encoding',
  'content-language',
  'content-length',
  'content-location',
  'content-type',
];

/** Copy `init` with the named headers removed. */
function withoutHeaders(
  init: RequestInit,
  names: readonly string[]
): RequestInit {
  const headers = new Headers(init.headers);
  for (const name of names) {
    headers.delete(name);
  }
  return { ...init, headers };
}

/**
 * `fetch` with SSRF checks applied to the initial URL and to every redirect
 * target.
 *
 * Redirects are always followed manually so each hop is validated with
 * `assertSsrFSafeUrl` before it is requested. A 3xx response without a
 * `Location` header is returned to the caller unchanged.
 *
 * Redirect hygiene:
 * - A 303 on a non-GET request switches to GET, drops the body and the
 *   headers that described it.
 * - When a hop changes origin, `Authorization`, `Cookie` and
 *   `Proxy-Authorization` are removed so caller credentials are not handed
 *   to a host the caller never addressed.
 *
 * NOTE(review): the hostname is resolved once for the safety check and
 * again by `fetch` itself, so a DNS answer that changes between the two is
 * not covered here — confirm whether that is handled elsewhere.
 *
 * @param rawUrl  URL to request.
 * @param init    Standard `fetch` init. `redirect` is forced to `manual`;
 *                `signal` is combined with the timeout signal.
 * @param options SSRF options plus `timeoutMs` (default 10000) and
 *                `maxRedirects` (default 3).
 * @returns The first response that is not a followable redirect.
 * @throws The SSRF error for a rejected URL or redirect target, or with
 *         reason `too_many_redirects` once the redirect budget is spent.
 */
export async function safeFetch(
  rawUrl: string | URL,
  init: RequestInit = {},
  options: SafeFetchOptions = {}
): Promise<Response> {
  const timeoutMs = options.timeoutMs ?? 10_000;
  const maxRedirects = options.maxRedirects ?? 3;

  // One timeout covers the whole redirect chain, not each hop.
  const timeoutSignal = AbortSignal.timeout(timeoutMs);
  const signal = init.signal
    ? AbortSignal.any([init.signal, timeoutSignal])
    : timeoutSignal;

  let current = await assertSsrFSafeUrl(rawUrl, options);
  let redirects = 0;

  // Always handle redirects manually (SSRF-safe on each hop).
  let requestInit: RequestInit = {
    ...init,
    redirect: 'manual',
    signal,
  };

  while (true) {
    const response = await fetch(current, requestInit);

    const isRedirect = response.status >= 300 && response.status < 400;
    const location = isRedirect ? response.headers.get('location') : null;
    if (!location) {
      return response;
    }

    // Drain/cancel body before following redirect to avoid leaking resources.
    try {
      await response.body?.cancel();
    } catch {
      // best-effort cleanup; the redirect decision does not depend on it
    }

    if (redirects >= maxRedirects) {
      throw createSsrfBlockedError('too_many_redirects', {
        url: current.toString(),
      });
    }

    const previousOrigin = current.origin;
    const next = new URL(location, current);
    current = await assertSsrFSafeUrl(next, options);
    redirects += 1;

    // 303 forces GET semantics
    if (
      response.status === 303 &&
      requestInit.method &&
      requestInit.method !== 'GET'
    ) {
      requestInit = {
        ...withoutHeaders(requestInit, REQUEST_BODY_HEADERS),
        method: 'GET',
        body: undefined,
      };
    }

    // Never forward caller credentials to a different origin.
    if (current.origin !== previousOrigin) {
      requestInit = withoutHeaders(requestInit, CROSS_ORIGIN_STRIPPED_HEADERS);
    }
  }
}
|
||||
|
||||
/**
 * Read a response body into a `Buffer`, refusing bodies larger than
 * `limitBytes`.
 *
 * The declared `Content-Length` is checked first so an oversized response is
 * rejected without reading it. The body is then streamed and the running
 * total is checked after every chunk, which also covers responses whose
 * header is missing or understates the real size.
 *
 * Once the body has been wrapped by `Readable.fromWeb` the web stream is
 * locked, so cleanup on overflow is done by destroying the Node stream
 * rather than by calling `response.body.cancel()`.
 *
 * @param response   Response whose body should be read.
 * @param limitBytes Maximum number of body bytes to accept.
 * @returns The full body, or an empty buffer when there is no body.
 * @throws ResponseTooLargeError when the declared or actual size exceeds
 *         `limitBytes`.
 */
export async function readResponseBufferWithLimit(
  response: Response,
  limitBytes: number
): Promise<Buffer> {
  const rawLen = response.headers.get('content-length');
  if (rawLen) {
    const len = Number.parseInt(rawLen, 10);
    if (Number.isFinite(len) && len > limitBytes) {
      try {
        // Body is still unlocked here, so cancel() releases the connection.
        await response.body?.cancel();
      } catch {
        // best-effort cleanup; the size error below is what matters
      }
      throw new ResponseTooLargeError({ limitBytes, receivedBytes: len });
    }
  }

  if (!response.body) {
    return Buffer.alloc(0);
  }

  // Convert Web ReadableStream -> Node Readable for consistent limit handling.
  const nodeStream = Readable.fromWeb(response.body);
  const chunks: Buffer[] = [];
  let total = 0;

  for await (const chunk of nodeStream) {
    const buf = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk);
    total += buf.length;
    if (total > limitBytes) {
      // Stop pulling data and release the underlying body.
      nodeStream.destroy();
      throw new ResponseTooLargeError({ limitBytes, receivedBytes: total });
    }
    chunks.push(buf);
  }

  return Buffer.concat(chunks, total);
}
|
||||
|
||||
/** What `fetchBuffer` returns: the downloaded bytes and their content type. */
type FetchBufferResult = {
  buffer: Buffer;
  type: string;
};

/**
 * Limits used for attachment downloads: a sixth of `OneMinute` as the time
 * budget and at most three redirects.
 */
const ATTACH_GET_PARAMS = {
  timeoutMs: OneMinute / 6,
  maxRedirects: 3,
};
|
||||
|
||||
/** Release a response body that will not be read; failures are ignored. */
async function discardResponseBody(response: Response): Promise<void> {
  try {
    await response.body?.cancel();
  } catch {
    // best-effort cleanup; the caller is about to throw its own error
  }
}

/**
 * Download an attachment into memory with a size cap.
 *
 * `data:` URLs are decoded with plain `fetch`; every other URL goes through
 * `safeFetch` with `ATTACH_GET_PARAMS`. When the response is rejected
 * (non-2xx status or unexpected content type) its body is cancelled before
 * throwing, so the unread body is not left open.
 *
 * @param url         URL of the attachment.
 * @param limit       Maximum body size in bytes.
 * @param contentType Optional required prefix of the response content type,
 *                    e.g. `image/`.
 * @returns The body bytes and the response content type
 *          (`application/octet-stream` when the header is absent).
 * @throws Error on a non-OK status or a content-type mismatch; whatever
 *         `safeFetch` / `readResponseBufferWithLimit` throw otherwise.
 */
export async function fetchBuffer(
  url: string,
  limit: number,
  contentType?: string
): Promise<FetchBufferResult> {
  const resp = url.startsWith('data:')
    ? await fetch(url)
    : await safeFetch(url, { method: 'GET' }, ATTACH_GET_PARAMS);

  if (!resp.ok) {
    await discardResponseBody(resp);
    throw new Error(
      `Failed to fetch attachment: ${resp.status} ${resp.statusText}`
    );
  }

  const type = resp.headers.get('content-type') || 'application/octet-stream';
  if (contentType && !type.startsWith(contentType)) {
    await discardResponseBody(resp);
    throw new Error(
      `Attachment content-type mismatch: expected ${contentType} but got ${type}`
    );
  }

  const buffer = await readResponseBufferWithLimit(resp, limit);
  return { buffer, type };
}
|
||||
|
||||
/**
 * Copy a Node `Buffer` into a standalone `ArrayBuffer` of exactly the same
 * length.
 *
 * The result never shares memory with the input, so later writes to the
 * buffer are not visible through it.
 *
 * @param buffer Source bytes.
 * @returns A new `ArrayBuffer` holding a copy of `buffer`.
 */
export function bufferToArrayBuffer(buffer: Buffer): ArrayBuffer {
  const target = new ArrayBuffer(buffer.byteLength);
  new Uint8Array(target).set(buffer);
  return target;
}
|
||||
@@ -28,13 +28,6 @@ class GenerateAccessTokenInput {
|
||||
export class AccessTokenResolver {
|
||||
constructor(private readonly models: Models) {}
|
||||
|
||||
@Query(() => [AccessToken], {
|
||||
deprecationReason: 'use currentUser.accessTokens',
|
||||
})
|
||||
async accessTokens(@CurrentUser() user: CurrentUser): Promise<AccessToken[]> {
|
||||
return await this.models.accessToken.list(user.id);
|
||||
}
|
||||
|
||||
@Query(() => [RevealedAccessToken], {
|
||||
deprecationReason: 'use currentUser.revealedAccessTokens',
|
||||
})
|
||||
|
||||
@@ -16,7 +16,6 @@ import type { Request, Response } from 'express';
|
||||
|
||||
import {
|
||||
ActionForbidden,
|
||||
Cache,
|
||||
Config,
|
||||
CryptoHelper,
|
||||
EmailTokenNotFound,
|
||||
@@ -53,7 +52,9 @@ interface MagicLinkCredential {
|
||||
client_nonce?: string;
|
||||
}
|
||||
|
||||
const OTP_CACHE_KEY = (otp: string) => `magic-link-otp:${otp}`;
|
||||
interface OpenAppSignInCredential {
|
||||
code: string;
|
||||
}
|
||||
|
||||
@Throttle('strict')
|
||||
@Controller('/api/auth')
|
||||
@@ -65,7 +66,6 @@ export class AuthController {
|
||||
private readonly auth: AuthService,
|
||||
private readonly models: Models,
|
||||
private readonly config: Config,
|
||||
private readonly cache: Cache,
|
||||
private readonly crypto: CryptoHelper
|
||||
) {
|
||||
if (env.dev) {
|
||||
@@ -111,11 +111,7 @@ export class AuthController {
|
||||
async signIn(
|
||||
@Req() req: Request,
|
||||
@Res() res: Response,
|
||||
@Body() credential: SignInCredential,
|
||||
/**
|
||||
* @deprecated
|
||||
*/
|
||||
@Query('redirect_uri') redirectUri?: string
|
||||
@Body() credential: SignInCredential
|
||||
) {
|
||||
validators.assertValidEmail(credential.email);
|
||||
const canSignIn = await this.auth.canSignIn(credential.email);
|
||||
@@ -132,11 +128,9 @@ export class AuthController {
|
||||
);
|
||||
} else {
|
||||
await this.sendMagicLink(
|
||||
req,
|
||||
res,
|
||||
credential.email,
|
||||
credential.callbackUrl,
|
||||
redirectUri,
|
||||
credential.client_nonce
|
||||
);
|
||||
}
|
||||
@@ -155,13 +149,25 @@ export class AuthController {
|
||||
}
|
||||
|
||||
async sendMagicLink(
|
||||
_req: Request,
|
||||
res: Response,
|
||||
email: string,
|
||||
callbackUrl = '/magic-link',
|
||||
redirectUrl?: string,
|
||||
clientNonce?: string
|
||||
) {
|
||||
if (!this.url.isAllowedCallbackUrl(callbackUrl)) {
|
||||
throw new ActionForbidden();
|
||||
}
|
||||
|
||||
const callbackUrlObj = this.url.url(callbackUrl);
|
||||
const redirectUriInCallback =
|
||||
callbackUrlObj.searchParams.get('redirect_uri');
|
||||
if (
|
||||
redirectUriInCallback &&
|
||||
!this.url.isAllowedRedirectUri(redirectUriInCallback)
|
||||
) {
|
||||
throw new ActionForbidden();
|
||||
}
|
||||
|
||||
// send email magic link
|
||||
const user = await this.models.user.getUserByEmail(email, {
|
||||
withDisabled: true,
|
||||
@@ -207,23 +213,9 @@ export class AuthController {
|
||||
);
|
||||
|
||||
const otp = this.crypto.otp();
|
||||
// TODO(@forehalo): this is a temporary solution, we should not rely on cache to store the otp
|
||||
const cacheKey = OTP_CACHE_KEY(otp);
|
||||
await this.cache.set(
|
||||
cacheKey,
|
||||
{ token, clientNonce },
|
||||
{ ttl: ttlInSec * 1000 }
|
||||
);
|
||||
await this.models.magicLinkOtp.upsert(email, otp, token, clientNonce);
|
||||
|
||||
const magicLink = this.url.link(callbackUrl, {
|
||||
token: otp,
|
||||
email,
|
||||
...(redirectUrl
|
||||
? {
|
||||
redirect_uri: redirectUrl,
|
||||
}
|
||||
: {}),
|
||||
});
|
||||
const magicLink = this.url.link(callbackUrl, { token: otp, email });
|
||||
if (env.dev) {
|
||||
// make it easier to test in dev mode
|
||||
this.logger.debug(`Magic link: ${magicLink}`);
|
||||
@@ -237,8 +229,9 @@ export class AuthController {
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Get('/sign-out')
|
||||
@Post('/sign-out')
|
||||
async signOut(
|
||||
@Req() req: Request,
|
||||
@Res() res: Response,
|
||||
@Session() session: Session | undefined,
|
||||
@Query('user_id') userId: string | undefined
|
||||
@@ -248,12 +241,63 @@ export class AuthController {
|
||||
return;
|
||||
}
|
||||
|
||||
const csrfCookie = req.cookies?.[AuthService.csrfCookieName] as
|
||||
| string
|
||||
| undefined;
|
||||
const csrfHeader = req.get('x-affine-csrf-token');
|
||||
if (!csrfCookie || !csrfHeader || csrfCookie !== csrfHeader) {
|
||||
throw new ActionForbidden();
|
||||
}
|
||||
|
||||
await this.auth.signOut(session.sessionId, userId);
|
||||
await this.auth.refreshCookies(res, session.sessionId);
|
||||
|
||||
res.status(HttpStatus.OK).send({});
|
||||
}
|
||||
|
||||
@Public()
|
||||
@UseNamedGuard('version')
|
||||
@Post('/open-app/sign-in-code')
|
||||
async openAppSignInCode(@CurrentUser() user?: CurrentUser) {
|
||||
if (!user) {
|
||||
throw new ActionForbidden();
|
||||
}
|
||||
|
||||
// short-lived one-time code for handing off the authenticated session
|
||||
const code = await this.models.verificationToken.create(
|
||||
TokenType.OpenAppSignIn,
|
||||
user.id,
|
||||
5 * 60
|
||||
);
|
||||
|
||||
return { code };
|
||||
}
|
||||
|
||||
@Public()
|
||||
@UseNamedGuard('version')
|
||||
@Post('/open-app/sign-in')
|
||||
async openAppSignIn(
|
||||
@Req() req: Request,
|
||||
@Res() res: Response,
|
||||
@Body() credential: OpenAppSignInCredential
|
||||
) {
|
||||
if (!credential?.code) {
|
||||
throw new InvalidAuthState();
|
||||
}
|
||||
|
||||
const tokenRecord = await this.models.verificationToken.get(
|
||||
TokenType.OpenAppSignIn,
|
||||
credential.code
|
||||
);
|
||||
|
||||
if (!tokenRecord?.credential) {
|
||||
throw new InvalidAuthState();
|
||||
}
|
||||
|
||||
await this.auth.setCookies(req, res, tokenRecord.credential);
|
||||
res.send({ id: tokenRecord.credential });
|
||||
}
|
||||
|
||||
@Public()
|
||||
@UseNamedGuard('version')
|
||||
@Post('/magic-link')
|
||||
@@ -269,23 +313,20 @@ export class AuthController {
|
||||
|
||||
validators.assertValidEmail(email);
|
||||
|
||||
const cacheKey = OTP_CACHE_KEY(otp);
|
||||
const cachedToken = await this.cache.get<{
|
||||
token: string;
|
||||
clientNonce: string;
|
||||
}>(cacheKey);
|
||||
let token: string | undefined;
|
||||
if (cachedToken && typeof cachedToken === 'object') {
|
||||
token = cachedToken.token;
|
||||
if (cachedToken.clientNonce && cachedToken.clientNonce !== clientNonce) {
|
||||
const consumed = await this.models.magicLinkOtp.consume(
|
||||
email,
|
||||
otp,
|
||||
clientNonce
|
||||
);
|
||||
if (!consumed.ok) {
|
||||
if (consumed.reason === 'nonce_mismatch') {
|
||||
throw new InvalidAuthState();
|
||||
}
|
||||
}
|
||||
|
||||
if (!token) {
|
||||
throw new InvalidEmailToken();
|
||||
}
|
||||
|
||||
const token = consumed.token;
|
||||
|
||||
const tokenRecord = await this.models.verificationToken.verify(
|
||||
TokenType.SignIn,
|
||||
token,
|
||||
|
||||
@@ -12,6 +12,7 @@ import { Socket } from 'socket.io';
|
||||
import {
|
||||
AccessDenied,
|
||||
AuthenticationRequired,
|
||||
Cache,
|
||||
Config,
|
||||
CryptoHelper,
|
||||
getRequestResponseFromContext,
|
||||
@@ -23,6 +24,8 @@ import { Session, TokenSession } from './session';
|
||||
|
||||
const PUBLIC_ENTRYPOINT_SYMBOL = Symbol('public');
|
||||
const INTERNAL_ENTRYPOINT_SYMBOL = Symbol('internal');
|
||||
const INTERNAL_ACCESS_TOKEN_TTL_MS = 5 * 60 * 1000;
|
||||
const INTERNAL_ACCESS_TOKEN_CLOCK_SKEW_MS = 30 * 1000;
|
||||
|
||||
@Injectable()
|
||||
export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
@@ -30,6 +33,7 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
|
||||
constructor(
|
||||
private readonly crypto: CryptoHelper,
|
||||
private readonly cache: Cache,
|
||||
private readonly ref: ModuleRef,
|
||||
private readonly reflector: Reflector
|
||||
) {}
|
||||
@@ -48,10 +52,28 @@ export class AuthGuard implements CanActivate, OnModuleInit {
|
||||
[clazz, handler]
|
||||
);
|
||||
if (isInternal) {
|
||||
// check access token: data,signature
|
||||
const accessToken = req.get('x-access-token');
|
||||
if (accessToken && this.crypto.verify(accessToken)) {
|
||||
return true;
|
||||
if (accessToken) {
|
||||
const payload = this.crypto.parseInternalAccessToken(accessToken);
|
||||
if (payload) {
|
||||
const now = Date.now();
|
||||
const method = req.method.toUpperCase();
|
||||
const path = req.path;
|
||||
|
||||
const timestampInRange =
|
||||
payload.ts <= now + INTERNAL_ACCESS_TOKEN_CLOCK_SKEW_MS &&
|
||||
now - payload.ts <= INTERNAL_ACCESS_TOKEN_TTL_MS;
|
||||
|
||||
if (timestampInRange && payload.m === method && payload.p === path) {
|
||||
const nonceKey = `rpc:nonce:${payload.nonce}`;
|
||||
const ok = await this.cache.setnx(nonceKey, 1, {
|
||||
ttl: INTERNAL_ACCESS_TOKEN_TTL_MS,
|
||||
});
|
||||
if (ok) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
throw new AccessDenied('Invalid internal request');
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ export class AuthResolver {
|
||||
user.id
|
||||
);
|
||||
|
||||
const url = this.url.link(callbackUrl, { userId: user.id, token });
|
||||
const url = this.url.safeLink(callbackUrl, { userId: user.id, token });
|
||||
|
||||
return await this.auth.sendChangePasswordEmail(user.email, url);
|
||||
}
|
||||
@@ -200,7 +200,7 @@ export class AuthResolver {
|
||||
user.id
|
||||
);
|
||||
|
||||
const url = this.url.link(callbackUrl, { token });
|
||||
const url = this.url.safeLink(callbackUrl, { token });
|
||||
|
||||
return await this.auth.sendChangeEmail(user.email, url);
|
||||
}
|
||||
@@ -244,7 +244,10 @@ export class AuthResolver {
|
||||
user.id
|
||||
);
|
||||
|
||||
const url = this.url.link(callbackUrl, { token: verifyEmailToken, email });
|
||||
const url = this.url.safeLink(callbackUrl, {
|
||||
token: verifyEmailToken,
|
||||
email,
|
||||
});
|
||||
return await this.auth.sendVerifyChangeEmail(email, url);
|
||||
}
|
||||
|
||||
@@ -258,7 +261,7 @@ export class AuthResolver {
|
||||
user.id
|
||||
);
|
||||
|
||||
const url = this.url.link(callbackUrl, { token });
|
||||
const url = this.url.safeLink(callbackUrl, { token });
|
||||
|
||||
return await this.auth.sendVerifyEmail(user.email, url);
|
||||
}
|
||||
@@ -302,6 +305,6 @@ export class AuthResolver {
|
||||
userId
|
||||
);
|
||||
|
||||
return this.url.link(callbackUrl, { userId, token });
|
||||
return this.url.safeLink(callbackUrl, { userId, token });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
|
||||
import type { CookieOptions, Request, Response } from 'express';
|
||||
import { assign, pick } from 'lodash-es';
|
||||
@@ -39,6 +41,7 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
};
|
||||
static readonly sessionCookieName = 'affine_session';
|
||||
static readonly userCookieName = 'affine_user_id';
|
||||
static readonly csrfCookieName = 'affine_csrf_token';
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
@@ -171,6 +174,11 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
expires: newExpiresAt,
|
||||
...this.cookieOptions,
|
||||
});
|
||||
res.cookie(AuthService.csrfCookieName, randomUUID(), {
|
||||
expires: newExpiresAt,
|
||||
...this.cookieOptions,
|
||||
httpOnly: false,
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -207,6 +215,12 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
expires: userSession.expiresAt ?? void 0,
|
||||
});
|
||||
|
||||
res.cookie(AuthService.csrfCookieName, randomUUID(), {
|
||||
...this.cookieOptions,
|
||||
httpOnly: false,
|
||||
expires: userSession.expiresAt ?? void 0,
|
||||
});
|
||||
|
||||
this.setUserCookie(res, userId);
|
||||
}
|
||||
|
||||
@@ -227,6 +241,7 @@ export class AuthService implements OnApplicationBootstrap {
|
||||
private clearCookies(res: Response<any, Record<string, any>>) {
|
||||
res.clearCookie(AuthService.sessionCookieName);
|
||||
res.clearCookie(AuthService.userCookieName);
|
||||
res.clearCookie(AuthService.csrfCookieName);
|
||||
}
|
||||
|
||||
setUserCookie(res: Response, userId: string) {
|
||||
|
||||
@@ -240,18 +240,6 @@ export class AppConfigResolver {
|
||||
return this.validateConfigInternal(updates);
|
||||
}
|
||||
|
||||
@Mutation(() => [AppConfigValidateResult], {
|
||||
description: 'validate app configuration',
|
||||
deprecationReason: 'use Query.validateAppConfig',
|
||||
name: 'validateAppConfig',
|
||||
})
|
||||
async validateAppConfigMutation(
|
||||
@Args('updates', { type: () => [UpdateAppConfigInput] })
|
||||
updates: UpdateAppConfigInput[]
|
||||
): Promise<AppConfigValidateResult[]> {
|
||||
return this.validateConfigInternal(updates);
|
||||
}
|
||||
|
||||
private validateConfigInternal(
|
||||
updates: UpdateAppConfigInput[]
|
||||
): AppConfigValidateResult[] {
|
||||
|
||||
@@ -77,14 +77,124 @@ test('should forbid access to rpc api with invalid access token', async t => {
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should forbid replayed internal access token', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const workspaceId = '123';
|
||||
const docId = '123';
|
||||
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}`;
|
||||
const token = t.context.crypto.signInternalAccessToken({
|
||||
method: 'GET',
|
||||
path,
|
||||
nonce: `nonce-${randomUUID()}`,
|
||||
});
|
||||
|
||||
await app.GET(path).set('x-access-token', token).expect(404);
|
||||
|
||||
await app
|
||||
.GET(path)
|
||||
.set('x-access-token', token)
|
||||
.expect({
|
||||
status: 403,
|
||||
code: 'Forbidden',
|
||||
type: 'NO_PERMISSION',
|
||||
name: 'ACCESS_DENIED',
|
||||
message: 'Invalid internal request',
|
||||
})
|
||||
.expect(403);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should forbid internal access token when method mismatched', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const workspaceId = '123';
|
||||
const docId = '123';
|
||||
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/diff`;
|
||||
await app
|
||||
.POST(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect({
|
||||
status: 403,
|
||||
code: 'Forbidden',
|
||||
type: 'NO_PERMISSION',
|
||||
name: 'ACCESS_DENIED',
|
||||
message: 'Invalid internal request',
|
||||
})
|
||||
.expect(403);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should forbid internal access token when path mismatched', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const workspaceId = '123';
|
||||
const docId = '123';
|
||||
const wrongPath = `/rpc/workspaces/${workspaceId}/docs/${docId}`;
|
||||
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/content`;
|
||||
await app
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({
|
||||
method: 'GET',
|
||||
path: wrongPath,
|
||||
})
|
||||
)
|
||||
.expect({
|
||||
status: 403,
|
||||
code: 'Forbidden',
|
||||
type: 'NO_PERMISSION',
|
||||
name: 'ACCESS_DENIED',
|
||||
message: 'Invalid internal request',
|
||||
})
|
||||
.expect(403);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should forbid internal access token when expired', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const workspaceId = '123';
|
||||
const docId = '123';
|
||||
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}`;
|
||||
await app
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({
|
||||
method: 'GET',
|
||||
path,
|
||||
now: Date.now() - 10 * 60 * 1000,
|
||||
nonce: `nonce-${randomUUID()}`,
|
||||
})
|
||||
)
|
||||
.expect({
|
||||
status: 403,
|
||||
code: 'Forbidden',
|
||||
type: 'NO_PERMISSION',
|
||||
name: 'ACCESS_DENIED',
|
||||
message: 'Invalid internal request',
|
||||
})
|
||||
.expect(403);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should 404 when doc not found', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const workspaceId = '123';
|
||||
const docId = '123';
|
||||
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}`;
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspaceId}/docs/${docId}`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect({
|
||||
status: 404,
|
||||
code: 'Not Found',
|
||||
@@ -111,9 +221,13 @@ test('should return doc when found', async t => {
|
||||
},
|
||||
]);
|
||||
|
||||
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}`;
|
||||
const res = await app
|
||||
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.set('x-cloud-trace-context', 'test-trace-id/span-id')
|
||||
.expect(200)
|
||||
.expect('x-request-id', 'test-trace-id')
|
||||
@@ -129,9 +243,13 @@ test('should 404 when doc diff not found', async t => {
|
||||
|
||||
const workspaceId = '123';
|
||||
const docId = '123';
|
||||
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/diff`;
|
||||
await app
|
||||
.POST(`/rpc/workspaces/${workspaceId}/docs/${docId}/diff`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.POST(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'POST', path })
|
||||
)
|
||||
.expect({
|
||||
status: 404,
|
||||
code: 'Not Found',
|
||||
@@ -148,9 +266,13 @@ test('should 404 when doc content not found', async t => {
|
||||
|
||||
const workspaceId = '123';
|
||||
const docId = '123';
|
||||
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/content`;
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspaceId}/docs/${docId}/content`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect({
|
||||
status: 404,
|
||||
code: 'Not Found',
|
||||
@@ -172,9 +294,13 @@ test('should get doc content in json format', async t => {
|
||||
});
|
||||
|
||||
const docId = randomUUID();
|
||||
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}/content`;
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/content`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect('Content-Type', 'application/json; charset=utf-8')
|
||||
.expect({
|
||||
title: 'test title',
|
||||
@@ -183,8 +309,11 @@ test('should get doc content in json format', async t => {
|
||||
.expect(200);
|
||||
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/content?full=false`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.GET(`${path}?full=false`)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect('Content-Type', 'application/json; charset=utf-8')
|
||||
.expect({
|
||||
title: 'test title',
|
||||
@@ -204,9 +333,13 @@ test('should get full doc content in json format', async t => {
|
||||
});
|
||||
|
||||
const docId = randomUUID();
|
||||
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}/content`;
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/content?full=true`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.GET(`${path}?full=true`)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect('Content-Type', 'application/json; charset=utf-8')
|
||||
.expect({
|
||||
title: 'test title',
|
||||
@@ -220,9 +353,13 @@ test('should 404 when workspace content not found', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const workspaceId = '123';
|
||||
const path = `/rpc/workspaces/${workspaceId}/content`;
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspaceId}/content`)
|
||||
.set('x-access-token', t.context.crypto.sign(workspaceId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect({
|
||||
status: 404,
|
||||
code: 'Not Found',
|
||||
@@ -244,9 +381,13 @@ test('should get workspace content in json format', async t => {
|
||||
});
|
||||
|
||||
const workspaceId = randomUUID();
|
||||
const path = `/rpc/workspaces/${workspaceId}/content`;
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspaceId}/content`)
|
||||
.set('x-access-token', t.context.crypto.sign(workspaceId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect(200)
|
||||
.expect({
|
||||
name: 'test name',
|
||||
@@ -265,9 +406,13 @@ test('should get doc markdown in json format', async t => {
|
||||
});
|
||||
|
||||
const docId = randomUUID();
|
||||
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}/markdown`;
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/markdown`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect('Content-Type', 'application/json; charset=utf-8')
|
||||
.expect(200)
|
||||
.expect({
|
||||
@@ -282,9 +427,13 @@ test('should 404 when doc markdown not found', async t => {
|
||||
|
||||
const workspaceId = '123';
|
||||
const docId = '123';
|
||||
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/markdown`;
|
||||
await app
|
||||
.GET(`/rpc/workspaces/${workspaceId}/docs/${docId}/markdown`)
|
||||
.set('x-access-token', t.context.crypto.sign(docId))
|
||||
.GET(path)
|
||||
.set(
|
||||
'x-access-token',
|
||||
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
|
||||
)
|
||||
.expect({
|
||||
status: 404,
|
||||
code: 'Not Found',
|
||||
|
||||
@@ -257,12 +257,13 @@ export class RpcDocReader extends DatabaseDocReader {
|
||||
super(cache, models, blobStorage, workspace);
|
||||
}
|
||||
|
||||
private async fetch(
|
||||
accessToken: string,
|
||||
url: string,
|
||||
method: 'GET' | 'POST',
|
||||
body?: Uint8Array
|
||||
) {
|
||||
private async fetch(url: string, method: 'GET' | 'POST', body?: Uint8Array) {
|
||||
const { pathname } = new URL(url);
|
||||
const accessToken = this.crypto.signInternalAccessToken({
|
||||
method,
|
||||
path: pathname,
|
||||
});
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'x-access-token': accessToken,
|
||||
'x-cloud-trace-context': getOrGenRequestId('rpc'),
|
||||
@@ -293,9 +294,8 @@ export class RpcDocReader extends DatabaseDocReader {
|
||||
docId: string
|
||||
): Promise<DocRecord | null> {
|
||||
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}`;
|
||||
const accessToken = this.crypto.sign(docId);
|
||||
try {
|
||||
const res = await this.fetch(accessToken, url, 'GET');
|
||||
const res = await this.fetch(url, 'GET');
|
||||
if (!res) {
|
||||
return null;
|
||||
}
|
||||
@@ -330,9 +330,8 @@ export class RpcDocReader extends DatabaseDocReader {
|
||||
aiEditable: boolean
|
||||
): Promise<DocMarkdown | null> {
|
||||
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}/markdown?aiEditable=${aiEditable}`;
|
||||
const accessToken = this.crypto.sign(docId);
|
||||
try {
|
||||
const res = await this.fetch(accessToken, url, 'GET');
|
||||
const res = await this.fetch(url, 'GET');
|
||||
if (!res) {
|
||||
return null;
|
||||
}
|
||||
@@ -358,9 +357,8 @@ export class RpcDocReader extends DatabaseDocReader {
|
||||
stateVector?: Uint8Array
|
||||
): Promise<DocDiff | null> {
|
||||
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}/diff`;
|
||||
const accessToken = this.crypto.sign(docId);
|
||||
try {
|
||||
const res = await this.fetch(accessToken, url, 'POST', stateVector);
|
||||
const res = await this.fetch(url, 'POST', stateVector);
|
||||
if (!res) {
|
||||
return null;
|
||||
}
|
||||
@@ -399,9 +397,8 @@ export class RpcDocReader extends DatabaseDocReader {
|
||||
fullContent = false
|
||||
): Promise<PageDocContent | null> {
|
||||
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}/content?full=${fullContent}`;
|
||||
const accessToken = this.crypto.sign(docId);
|
||||
try {
|
||||
const res = await this.fetch(accessToken, url, 'GET');
|
||||
const res = await this.fetch(url, 'GET');
|
||||
if (!res) {
|
||||
return null;
|
||||
}
|
||||
@@ -427,9 +424,8 @@ export class RpcDocReader extends DatabaseDocReader {
|
||||
workspaceId: string
|
||||
): Promise<WorkspaceDocInfo | null> {
|
||||
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/content`;
|
||||
const accessToken = this.crypto.sign(workspaceId);
|
||||
try {
|
||||
const res = await this.fetch(accessToken, url, 'GET');
|
||||
const res = await this.fetch(url, 'GET');
|
||||
if (!res) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ export abstract class DocStorageAdapter extends Connection {
|
||||
snapshot: DocRecord | null,
|
||||
finalUpdate: DocUpdate
|
||||
) {
|
||||
this.logger.log(
|
||||
this.logger.verbose(
|
||||
`Squashing updates, spaceId: ${spaceId}, docId: ${docId}, updates: ${updates.length}`
|
||||
);
|
||||
|
||||
@@ -152,7 +152,7 @@ export abstract class DocStorageAdapter extends Connection {
|
||||
|
||||
// always mark updates as merged unless throws
|
||||
const count = await this.markUpdatesMerged(spaceId, docId, updates);
|
||||
this.logger.log(
|
||||
this.logger.verbose(
|
||||
`Marked ${count} updates as merged, spaceId: ${spaceId}, docId: ${docId}, timestamp: ${timestamp}`
|
||||
);
|
||||
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
import ava, { TestFn } from 'ava';
|
||||
|
||||
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
|
||||
import { buildAppModule } from '../../../app.module';
|
||||
import { Models } from '../../../models';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
app: TestingApp;
|
||||
models: Models;
|
||||
allowlistedAdminToken: string;
|
||||
nonAllowlistedAdminToken: string;
|
||||
userToken: string;
|
||||
}>;
|
||||
|
||||
test.before(async t => {
|
||||
const app = await createTestingApp({
|
||||
imports: [buildAppModule(globalThis.env)],
|
||||
});
|
||||
|
||||
t.context.app = app;
|
||||
t.context.models = app.get(Models);
|
||||
});
|
||||
|
||||
test.beforeEach(async t => {
|
||||
await t.context.app.initTestingDB();
|
||||
|
||||
const allowlistedAdmin = await t.context.models.user.create({
|
||||
email: 'admin@affine.pro',
|
||||
password: '1',
|
||||
emailVerifiedAt: new Date(),
|
||||
});
|
||||
await t.context.models.userFeature.add(
|
||||
allowlistedAdmin.id,
|
||||
'administrator',
|
||||
'test'
|
||||
);
|
||||
const allowlistedAdminToken = await t.context.models.accessToken.create({
|
||||
userId: allowlistedAdmin.id,
|
||||
name: 'test',
|
||||
});
|
||||
t.context.allowlistedAdminToken = allowlistedAdminToken.token;
|
||||
|
||||
const nonAllowlistedAdmin = await t.context.models.user.create({
|
||||
email: 'admin2@affine.pro',
|
||||
password: '1',
|
||||
emailVerifiedAt: new Date(),
|
||||
});
|
||||
await t.context.models.userFeature.add(
|
||||
nonAllowlistedAdmin.id,
|
||||
'administrator',
|
||||
'test'
|
||||
);
|
||||
const nonAllowlistedAdminToken = await t.context.models.accessToken.create({
|
||||
userId: nonAllowlistedAdmin.id,
|
||||
name: 'test',
|
||||
});
|
||||
t.context.nonAllowlistedAdminToken = nonAllowlistedAdminToken.token;
|
||||
|
||||
const user = await t.context.models.user.create({
|
||||
email: 'user@affine.pro',
|
||||
password: '1',
|
||||
emailVerifiedAt: new Date(),
|
||||
});
|
||||
const userToken = await t.context.models.accessToken.create({
|
||||
userId: user.id,
|
||||
name: 'test',
|
||||
});
|
||||
t.context.userToken = userToken.token;
|
||||
});
|
||||
|
||||
test.after.always(async t => {
|
||||
await t.context.app.close();
|
||||
});
|
||||
|
||||
test('should return 404 for non-admin user', async t => {
|
||||
await t.context.app
|
||||
.GET('/api/queue')
|
||||
.set('Authorization', `Bearer ${t.context.userToken}`)
|
||||
.expect(404);
|
||||
t.pass();
|
||||
});
|
||||
|
||||
test('should allow allowlisted admin', async t => {
|
||||
await t.context.app
|
||||
.GET('/api/queue')
|
||||
.set('Authorization', `Bearer ${t.context.allowlistedAdminToken}`)
|
||||
.expect(200)
|
||||
.expect('Content-Type', /text\/html/);
|
||||
t.pass();
|
||||
});
|
||||
@@ -53,12 +53,21 @@ class QueueDashboardService implements OnModuleInit {
|
||||
): Promise<void> => {
|
||||
try {
|
||||
const session = await this.authGuard.signIn(req, res);
|
||||
const userId = session?.user?.id;
|
||||
const user = session?.user;
|
||||
const userId = user?.id;
|
||||
const email = user?.email?.toLowerCase();
|
||||
|
||||
const isAdmin = userId ? await this.feature.isAdmin(userId) : false;
|
||||
if (!isAdmin) {
|
||||
res.status(404).end();
|
||||
return;
|
||||
}
|
||||
|
||||
if (req.method === 'GET' && (req.path === '/' || req.path === '')) {
|
||||
this.logger.log(
|
||||
`QueueDash accessed by ${userId} (${email ?? 'n/a'})`
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.warn('QueueDash auth failed', error as Error);
|
||||
res.status(404).end();
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
WebSocketServer,
|
||||
} from '@nestjs/websockets';
|
||||
import { ClsInterceptor } from 'nestjs-cls';
|
||||
import semver from 'semver';
|
||||
import { type Server, Socket } from 'socket.io';
|
||||
|
||||
import {
|
||||
@@ -23,6 +24,7 @@ import {
|
||||
SpaceAccessDenied,
|
||||
} from '../../base';
|
||||
import { Models } from '../../models';
|
||||
import { mergeUpdatesInApplyWay } from '../../native';
|
||||
import { CurrentUser } from '../auth';
|
||||
import {
|
||||
DocReader,
|
||||
@@ -48,9 +50,10 @@ type EventResponse<Data = any> = Data extends never
|
||||
data: Data;
|
||||
};
|
||||
|
||||
// 019 only receives space:broadcast-doc-updates and send space:push-doc-updates
|
||||
// 020 only receives space:broadcast-doc-update and send space:push-doc-update
|
||||
type RoomType = 'sync' | `${string}:awareness` | 'sync-019';
|
||||
// sync: shared room for space membership checks and non-protocol broadcasts.
|
||||
// sync-025: legacy 0.25 doc sync protocol (space:broadcast-doc-update).
|
||||
// sync-026: current doc sync protocol (space:broadcast-doc-updates).
|
||||
type RoomType = 'sync' | 'sync-025' | 'sync-026' | `${string}:awareness`;
|
||||
|
||||
function Room(
|
||||
spaceId: string,
|
||||
@@ -59,6 +62,25 @@ function Room(
|
||||
return `${spaceId}:${type}`;
|
||||
}
|
||||
|
||||
const MIN_WS_CLIENT_VERSION = new semver.Range('>=0.25.0', {
|
||||
includePrerelease: true,
|
||||
});
|
||||
const DOC_UPDATES_PROTOCOL_026 = new semver.Range('>=0.26.0-0', {
|
||||
includePrerelease: true,
|
||||
});
|
||||
|
||||
type SyncProtocolRoomType = Extract<RoomType, 'sync-025' | 'sync-026'>;
|
||||
|
||||
function isSupportedWsClientVersion(clientVersion: string): boolean {
|
||||
return Boolean(
|
||||
semver.valid(clientVersion) && MIN_WS_CLIENT_VERSION.test(clientVersion)
|
||||
);
|
||||
}
|
||||
|
||||
function getSyncProtocolRoomType(clientVersion: string): SyncProtocolRoomType {
|
||||
return DOC_UPDATES_PROTOCOL_026.test(clientVersion) ? 'sync-026' : 'sync-025';
|
||||
}
|
||||
|
||||
enum SpaceType {
|
||||
Workspace = 'workspace',
|
||||
Userspace = 'userspace',
|
||||
@@ -88,16 +110,6 @@ interface LeaveSpaceAwarenessMessage {
|
||||
docId: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated
|
||||
*/
|
||||
interface PushDocUpdatesMessage {
|
||||
spaceType: SpaceType;
|
||||
spaceId: string;
|
||||
docId: string;
|
||||
updates: string[];
|
||||
}
|
||||
|
||||
interface PushDocUpdateMessage {
|
||||
spaceType: SpaceType;
|
||||
spaceId: string;
|
||||
@@ -105,6 +117,25 @@ interface PushDocUpdateMessage {
|
||||
update: string;
|
||||
}
|
||||
|
||||
interface BroadcastDocUpdatesMessage {
|
||||
spaceType: SpaceType;
|
||||
spaceId: string;
|
||||
docId: string;
|
||||
updates: string[];
|
||||
timestamp: number;
|
||||
editor?: string;
|
||||
compressed?: boolean;
|
||||
}
|
||||
|
||||
interface BroadcastDocUpdateMessage {
|
||||
spaceType: SpaceType;
|
||||
spaceId: string;
|
||||
docId: string;
|
||||
update: string;
|
||||
timestamp: number;
|
||||
editor: string;
|
||||
}
|
||||
|
||||
interface LoadDocMessage {
|
||||
spaceType: SpaceType;
|
||||
spaceId: string;
|
||||
@@ -157,6 +188,67 @@ export class SpaceSyncGateway
|
||||
private readonly models: Models
|
||||
) {}
|
||||
|
||||
private encodeUpdates(updates: Uint8Array[]) {
|
||||
return updates.map(update => Buffer.from(update).toString('base64'));
|
||||
}
|
||||
|
||||
private buildBroadcastPayload(
|
||||
spaceType: SpaceType,
|
||||
spaceId: string,
|
||||
docId: string,
|
||||
updates: Uint8Array[],
|
||||
timestamp: number,
|
||||
editor?: string
|
||||
): BroadcastDocUpdatesMessage {
|
||||
const encodedUpdates = this.encodeUpdates(updates);
|
||||
if (updates.length <= 1) {
|
||||
return {
|
||||
spaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
updates: encodedUpdates,
|
||||
timestamp,
|
||||
editor,
|
||||
compressed: false,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const merged = mergeUpdatesInApplyWay(
|
||||
updates.map(update => Buffer.from(update))
|
||||
);
|
||||
metrics.socketio.counter('doc_updates_compressed').add(1);
|
||||
return {
|
||||
spaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
updates: [Buffer.from(merged).toString('base64')],
|
||||
timestamp,
|
||||
editor,
|
||||
compressed: true,
|
||||
};
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
'Failed to merge updates for broadcast, falling back to batch',
|
||||
error as Error
|
||||
);
|
||||
return {
|
||||
spaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
updates: encodedUpdates,
|
||||
timestamp,
|
||||
editor,
|
||||
compressed: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private rejectJoin(client: Socket) {
|
||||
// Give socket.io a chance to flush the ack packet before disconnecting.
|
||||
setImmediate(() => client.disconnect());
|
||||
}
|
||||
|
||||
handleConnection() {
|
||||
this.connectionCount++;
|
||||
this.logger.debug(`New connection, total: ${this.connectionCount}`);
|
||||
@@ -184,31 +276,35 @@ export class SpaceSyncGateway
|
||||
return;
|
||||
}
|
||||
|
||||
const encodedUpdates = updates.map(update =>
|
||||
Buffer.from(update).toString('base64')
|
||||
);
|
||||
|
||||
this.server
|
||||
.to(Room(spaceId, 'sync-019'))
|
||||
.emit('space:broadcast-doc-updates', {
|
||||
spaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
updates: encodedUpdates,
|
||||
timestamp,
|
||||
});
|
||||
|
||||
const room = `${spaceType}:${Room(spaceId)}`;
|
||||
encodedUpdates.forEach(update => {
|
||||
this.server.to(room).emit('space:broadcast-doc-update', {
|
||||
spaceType,
|
||||
const room025 = `${spaceType}:${Room(spaceId, 'sync-025')}`;
|
||||
const encodedUpdates = this.encodeUpdates(updates);
|
||||
for (const update of encodedUpdates) {
|
||||
const payload: BroadcastDocUpdateMessage = {
|
||||
spaceType: spaceType as SpaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
update,
|
||||
timestamp,
|
||||
editor,
|
||||
editor: editor ?? '',
|
||||
};
|
||||
this.server.to(room025).emit('space:broadcast-doc-update', payload);
|
||||
}
|
||||
|
||||
const room026 = `${spaceType}:${Room(spaceId, 'sync-026')}`;
|
||||
const payload = this.buildBroadcastPayload(
|
||||
spaceType as SpaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
updates,
|
||||
timestamp,
|
||||
editor
|
||||
);
|
||||
this.server.to(room026).emit('space:broadcast-doc-updates', payload);
|
||||
metrics.socketio
|
||||
.counter('doc_updates_broadcast')
|
||||
.add(payload.updates.length, {
|
||||
mode: payload.compressed ? 'compressed' : 'batch',
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
selectAdapter(client: Socket, spaceType: SpaceType): SyncSocketAdapter {
|
||||
@@ -240,16 +336,34 @@ export class SpaceSyncGateway
|
||||
@MessageBody()
|
||||
{ spaceType, spaceId, clientVersion }: JoinSpaceMessage
|
||||
): Promise<EventResponse<{ clientId: string; success: boolean }>> {
|
||||
if (
|
||||
![SpaceType.Userspace, SpaceType.Workspace].includes(spaceType) ||
|
||||
/^0.1/.test(clientVersion)
|
||||
) {
|
||||
if (![SpaceType.Userspace, SpaceType.Workspace].includes(spaceType)) {
|
||||
this.rejectJoin(client);
|
||||
return { data: { clientId: client.id, success: false } };
|
||||
} else {
|
||||
if (spaceType === SpaceType.Workspace) {
|
||||
this.event.emit('workspace.embedding', { workspaceId: spaceId });
|
||||
}
|
||||
await this.selectAdapter(client, spaceType).join(user.id, spaceId);
|
||||
}
|
||||
|
||||
if (!isSupportedWsClientVersion(clientVersion)) {
|
||||
this.rejectJoin(client);
|
||||
return { data: { clientId: client.id, success: false } };
|
||||
}
|
||||
|
||||
if (spaceType === SpaceType.Workspace) {
|
||||
this.event.emit('workspace.embedding', { workspaceId: spaceId });
|
||||
}
|
||||
|
||||
const adapter = this.selectAdapter(client, spaceType);
|
||||
await adapter.join(user.id, spaceId);
|
||||
|
||||
const protocolRoomType = getSyncProtocolRoomType(clientVersion);
|
||||
const protocolRoom = adapter.room(spaceId, protocolRoomType);
|
||||
const otherProtocolRoom = adapter.room(
|
||||
spaceId,
|
||||
protocolRoomType === 'sync-025' ? 'sync-026' : 'sync-025'
|
||||
);
|
||||
if (client.rooms.has(otherProtocolRoom)) {
|
||||
await client.leave(otherProtocolRoom);
|
||||
}
|
||||
if (!client.rooms.has(protocolRoom)) {
|
||||
await client.join(protocolRoom);
|
||||
}
|
||||
|
||||
return { data: { clientId: client.id, success: true } };
|
||||
@@ -306,52 +420,8 @@ export class SpaceSyncGateway
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use [space:push-doc-update] instead, client should always merge updates on their own
|
||||
*
|
||||
* only 0.19.x client will send this event
|
||||
* client should always merge updates on their own
|
||||
*/
|
||||
@SubscribeMessage('space:push-doc-updates')
|
||||
async onReceiveDocUpdates(
|
||||
@ConnectedSocket() client: Socket,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@MessageBody()
|
||||
message: PushDocUpdatesMessage
|
||||
): Promise<EventResponse<{ accepted: true; timestamp?: number }>> {
|
||||
const { spaceType, spaceId, docId, updates } = message;
|
||||
const adapter = this.selectAdapter(client, spaceType);
|
||||
const id = new DocID(docId, spaceId);
|
||||
|
||||
// TODO(@forehalo): enable after frontend supporting doc revert
|
||||
// await this.ac.user(user.id).doc(spaceId, id.guid).assert('Doc.Update');
|
||||
const timestamp = await adapter.push(
|
||||
spaceId,
|
||||
id.guid,
|
||||
updates.map(update => Buffer.from(update, 'base64')),
|
||||
user.id
|
||||
);
|
||||
|
||||
// broadcast to 0.19.x clients
|
||||
client
|
||||
.to(Room(spaceId, 'sync-019'))
|
||||
.emit('space:broadcast-doc-updates', { ...message, timestamp });
|
||||
|
||||
// broadcast to new clients
|
||||
updates.forEach(update => {
|
||||
client.to(adapter.room(spaceId)).emit('space:broadcast-doc-update', {
|
||||
...message,
|
||||
update,
|
||||
timestamp,
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
data: {
|
||||
accepted: true,
|
||||
timestamp,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@SubscribeMessage('space:push-doc-update')
|
||||
async onReceiveDocUpdate(
|
||||
@ConnectedSocket() client: Socket,
|
||||
@@ -371,23 +441,33 @@ export class SpaceSyncGateway
|
||||
user.id
|
||||
);
|
||||
|
||||
// broadcast to 0.19.x clients
|
||||
client.to(Room(spaceId, 'sync-019')).emit('space:broadcast-doc-updates', {
|
||||
const payload = this.buildBroadcastPayload(
|
||||
spaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
updates: [update],
|
||||
[Buffer.from(update, 'base64')],
|
||||
timestamp,
|
||||
});
|
||||
user.id
|
||||
);
|
||||
client
|
||||
.to(adapter.room(spaceId, 'sync-026'))
|
||||
.emit('space:broadcast-doc-updates', payload);
|
||||
metrics.socketio
|
||||
.counter('doc_updates_broadcast')
|
||||
.add(payload.updates.length, {
|
||||
mode: payload.compressed ? 'compressed' : 'batch',
|
||||
});
|
||||
|
||||
client.to(adapter.room(spaceId)).emit('space:broadcast-doc-update', {
|
||||
spaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
update,
|
||||
timestamp,
|
||||
editor: user.id,
|
||||
});
|
||||
client
|
||||
.to(adapter.room(spaceId, 'sync-025'))
|
||||
.emit('space:broadcast-doc-update', {
|
||||
spaceType,
|
||||
spaceId,
|
||||
docId,
|
||||
update,
|
||||
timestamp,
|
||||
editor: user.id,
|
||||
} satisfies BroadcastDocUpdateMessage);
|
||||
|
||||
return {
|
||||
data: {
|
||||
@@ -417,8 +497,18 @@ export class SpaceSyncGateway
|
||||
@ConnectedSocket() client: Socket,
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@MessageBody()
|
||||
{ spaceType, spaceId, docId }: JoinSpaceAwarenessMessage
|
||||
{ spaceType, spaceId, docId, clientVersion }: JoinSpaceAwarenessMessage
|
||||
) {
|
||||
if (![SpaceType.Userspace, SpaceType.Workspace].includes(spaceType)) {
|
||||
this.rejectJoin(client);
|
||||
return { data: { clientId: client.id, success: false } };
|
||||
}
|
||||
|
||||
if (!isSupportedWsClientVersion(clientVersion)) {
|
||||
this.rejectJoin(client);
|
||||
return { data: { clientId: client.id, success: false } };
|
||||
}
|
||||
|
||||
await this.selectAdapter(client, spaceType).join(
|
||||
user.id,
|
||||
spaceId,
|
||||
@@ -456,13 +546,6 @@ export class SpaceSyncGateway
|
||||
.to(adapter.room(spaceId, roomType))
|
||||
.emit('space:collect-awareness', { spaceType, spaceId, docId });
|
||||
|
||||
// TODO(@forehalo): remove backward compatibility
|
||||
if (spaceType === SpaceType.Workspace) {
|
||||
client
|
||||
.to(adapter.room(spaceId, roomType))
|
||||
.emit('new-client-awareness-init');
|
||||
}
|
||||
|
||||
return { data: { clientId: client.id } };
|
||||
}
|
||||
|
||||
|
||||
@@ -66,21 +66,27 @@ export class UserResolver {
|
||||
): Promise<typeof UserOrLimitedUser | null> {
|
||||
validators.assertValidEmail(email);
|
||||
|
||||
// TODO(@forehalo): need to limit a user can only get another user witch is in the same workspace
|
||||
// NOTE: prevent user enumeration. Only allow querying users within the same workspace scope.
|
||||
if (!currentUser) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const user = await this.models.user.getUserByEmail(email);
|
||||
|
||||
// return empty response when user not exists
|
||||
if (!user) return null;
|
||||
|
||||
if (currentUser) {
|
||||
if (user.id === currentUser.id) {
|
||||
return sessionUser(user);
|
||||
}
|
||||
|
||||
// only return limited info when not logged in
|
||||
return {
|
||||
email: user.email,
|
||||
hasPassword: !!user.password,
|
||||
};
|
||||
const allowed = await this.models.workspaceUser.hasSharedWorkspace(
|
||||
currentUser.id,
|
||||
user.id
|
||||
);
|
||||
if (!allowed) return null;
|
||||
|
||||
return sessionUser(user);
|
||||
}
|
||||
|
||||
@Throttle('strict')
|
||||
|
||||
@@ -26,6 +26,6 @@ defineModuleConfig('client', {
|
||||
},
|
||||
'versionControl.requiredVersion': {
|
||||
desc: "Allowed version range of the app that allowed to access the server. Requires 'client/versionControl.enabled' to be true to take effect.",
|
||||
default: '>=0.20.0',
|
||||
default: '>=0.25.0',
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, NotFoundException } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
@@ -189,6 +189,12 @@ class AdminUpdateWorkspaceInput extends PartialType(
|
||||
export class AdminWorkspaceResolver {
|
||||
constructor(private readonly models: Models) {}
|
||||
|
||||
private assertCloudOnly() {
|
||||
if (env.selfhosted) {
|
||||
throw new NotFoundException();
|
||||
}
|
||||
}
|
||||
|
||||
@Query(() => [AdminWorkspace], {
|
||||
description: 'List workspaces for admin',
|
||||
})
|
||||
@@ -196,6 +202,7 @@ export class AdminWorkspaceResolver {
|
||||
@Args('filter', { type: () => ListWorkspaceInput })
|
||||
filter: ListWorkspaceInput
|
||||
) {
|
||||
this.assertCloudOnly();
|
||||
const { rows } = await this.models.workspace.adminListWorkspaces({
|
||||
first: filter.first,
|
||||
skip: filter.skip,
|
||||
@@ -219,6 +226,7 @@ export class AdminWorkspaceResolver {
|
||||
@Args('filter', { type: () => ListWorkspaceInput })
|
||||
filter: ListWorkspaceInput
|
||||
) {
|
||||
this.assertCloudOnly();
|
||||
const total = await this.models.workspace.adminCountWorkspaces({
|
||||
keyword: filter.keyword,
|
||||
features: filter.features,
|
||||
@@ -238,6 +246,7 @@ export class AdminWorkspaceResolver {
|
||||
nullable: true,
|
||||
})
|
||||
async adminWorkspace(@Args('id') id: string) {
|
||||
this.assertCloudOnly();
|
||||
const { rows } = await this.models.workspace.adminListWorkspaces({
|
||||
first: 1,
|
||||
skip: 0,
|
||||
@@ -318,6 +327,7 @@ export class AdminWorkspaceResolver {
|
||||
@Args('input', { type: () => AdminUpdateWorkspaceInput })
|
||||
input: AdminUpdateWorkspaceInput
|
||||
) {
|
||||
this.assertCloudOnly();
|
||||
const { id, features, ...updates } = input;
|
||||
|
||||
if (Object.keys(updates).length) {
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
Mutation,
|
||||
ObjectType,
|
||||
Parent,
|
||||
Query,
|
||||
registerEnumType,
|
||||
ResolveField,
|
||||
Resolver,
|
||||
@@ -33,7 +32,7 @@ import {
|
||||
MULTIPART_PART_SIZE,
|
||||
MULTIPART_THRESHOLD,
|
||||
} from '../../storage/constants';
|
||||
import { WorkspaceBlobSizes, WorkspaceType } from '../types';
|
||||
import { WorkspaceType } from '../types';
|
||||
|
||||
enum BlobUploadMethod {
|
||||
GRAPHQL = 'GRAPHQL',
|
||||
@@ -169,14 +168,6 @@ export class WorkspaceBlobResolver {
|
||||
return this.getUploadPart(user, workspace.id, key, uploadId, partNumber);
|
||||
}
|
||||
|
||||
@Query(() => WorkspaceBlobSizes, {
|
||||
deprecationReason: 'use `user.quotaUsage` instead',
|
||||
})
|
||||
async collectAllBlobSizes(@CurrentUser() user: CurrentUser) {
|
||||
const size = await this.quota.getUserStorageUsage(user.id);
|
||||
return { size };
|
||||
}
|
||||
|
||||
@Mutation(() => String)
|
||||
async setBlob(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@@ -412,19 +403,6 @@ export class WorkspaceBlobResolver {
|
||||
return key;
|
||||
}
|
||||
|
||||
@Mutation(() => BlobUploadPart, {
|
||||
deprecationReason: 'use WorkspaceType.blobUploadPartUrl',
|
||||
})
|
||||
async getBlobUploadPartUrl(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('workspaceId') workspaceId: string,
|
||||
@Args('key') key: string,
|
||||
@Args('uploadId') uploadId: string,
|
||||
@Args('partNumber', { type: () => Int }) partNumber: number
|
||||
): Promise<BlobUploadPart> {
|
||||
return this.getUploadPart(user, workspaceId, key, uploadId, partNumber);
|
||||
}
|
||||
|
||||
@Mutation(() => Boolean)
|
||||
async abortBlobUpload(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
|
||||
@@ -238,20 +238,6 @@ export class WorkspaceDocResolver {
|
||||
return this.models.doc.findPublics(workspace.id);
|
||||
}
|
||||
|
||||
@ResolveField(() => DocType, {
|
||||
description: 'Get public page of a workspace by page id.',
|
||||
complexity: 2,
|
||||
nullable: true,
|
||||
deprecationReason: 'use [WorkspaceType.doc] instead',
|
||||
})
|
||||
async publicPage(
|
||||
@CurrentUser() me: CurrentUser,
|
||||
@Parent() workspace: WorkspaceType,
|
||||
@Args('pageId') pageId: string
|
||||
) {
|
||||
return this.doc(me, workspace, pageId);
|
||||
}
|
||||
|
||||
@ResolveField(() => PaginatedDocType)
|
||||
async docs(
|
||||
@Parent() workspace: WorkspaceType,
|
||||
@@ -314,24 +300,6 @@ export class WorkspaceDocResolver {
|
||||
};
|
||||
}
|
||||
|
||||
@Mutation(() => DocType, {
|
||||
deprecationReason: 'use publishDoc instead',
|
||||
})
|
||||
async publishPage(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('workspaceId') workspaceId: string,
|
||||
@Args('pageId') pageId: string,
|
||||
@Args({
|
||||
name: 'mode',
|
||||
type: () => PublicDocMode,
|
||||
nullable: true,
|
||||
defaultValue: PublicDocMode.Page,
|
||||
})
|
||||
mode: PublicDocMode
|
||||
) {
|
||||
return this.publishDoc(user, workspaceId, pageId, mode);
|
||||
}
|
||||
|
||||
@Mutation(() => DocType)
|
||||
async publishDoc(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@@ -364,17 +332,6 @@ export class WorkspaceDocResolver {
|
||||
return doc;
|
||||
}
|
||||
|
||||
/**
 * Deprecated mutation kept for older clients; `revokePublicDoc` is the
 * replacement and receives the same arguments.
 */
@Mutation(() => DocType, { deprecationReason: 'use revokePublicDoc instead' })
async revokePublicPage(
  @CurrentUser() user: CurrentUser,
  @Args('workspaceId') workspaceId: string,
  @Args('docId') docId: string
) {
  const revoked = this.revokePublicDoc(user, workspaceId, docId);
  return revoked;
}
|
||||
|
||||
@Mutation(() => DocType)
|
||||
async revokePublicDoc(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
|
||||
@@ -234,25 +234,6 @@ export class WorkspaceMemberResolver {
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
 * Invites a list of email addresses to a workspace.
 *
 * The `sendInviteMail` argument is still accepted so existing queries keep
 * validating, but its value is ignored.
 *
 * @deprecated use `inviteMembers` instead
 */
@Mutation(() => [InviteResult], {
  deprecationReason: 'use [inviteMembers] instead',
})
async inviteBatch(
  @CurrentUser() user: CurrentUser,
  @Args('workspaceId') workspaceId: string,
  @Args({ name: 'emails', type: () => [String] }) emails: string[],
  @Args('sendInviteMail', { nullable: true, deprecationReason: 'never used' })
  _sendInviteMail: boolean = false
) {
  const results = this.inviteMembers(user, workspaceId, emails);
  return results;
}
|
||||
|
||||
@ResolveField(() => InviteLink, {
|
||||
description: 'invite link for workspace',
|
||||
nullable: true,
|
||||
@@ -456,20 +437,6 @@ export class WorkspaceMemberResolver {
|
||||
return { workspace, user: owner, invitee, status };
|
||||
}
|
||||
|
||||
/**
 * Removes a member from a workspace by forwarding to `revokeMember`.
 *
 * @deprecated use `revokeMember` instead
 */
@Mutation(() => Boolean, { deprecationReason: 'use [revokeMember] instead' })
async revoke(
  @CurrentUser() me: CurrentUser,
  @Args('workspaceId') workspaceId: string,
  @Args('userId') userId: string
) {
  const revoked = this.revokeMember(me, workspaceId, userId);
  return revoked;
}
|
||||
|
||||
@Mutation(() => Boolean)
|
||||
async revokeMember(
|
||||
@CurrentUser() me: CurrentUser,
|
||||
|
||||
@@ -156,40 +156,6 @@ export class WorkspaceResolver {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Deprecated query: reports whether the current user's active role in the
 * given workspace is `Owner`. Resolves to `false` when no active role exists.
 */
@Query(() => Boolean, {
  description: 'Get is owner of workspace',
  complexity: 2,
  deprecationReason: 'use WorkspaceType[role] instead',
})
async isOwner(
  @CurrentUser() user: CurrentUser,
  @Args('workspaceId') workspaceId: string
) {
  const membership = await this.models.workspaceUser.getActive(
    workspaceId,
    user.id
  );
  const roleType = membership?.type;

  return roleType === WorkspaceRole.Owner;
}
|
||||
|
||||
/**
 * Deprecated query: reports whether the current user's active role in the
 * given workspace is exactly `Admin`. Resolves to `false` when no active role
 * exists.
 */
@Query(() => Boolean, {
  description: 'Get is admin of workspace',
  complexity: 2,
  deprecationReason: 'use WorkspaceType[role] instead',
})
async isAdmin(
  @CurrentUser() user: CurrentUser,
  @Args('workspaceId') workspaceId: string
) {
  const membership = await this.models.workspaceUser.getActive(
    workspaceId,
    user.id
  );
  const roleType = membership?.type;

  // Strict equality: any other role value, including Owner, yields false.
  return roleType === WorkspaceRole.Admin;
}
|
||||
|
||||
@Query(() => [WorkspaceType], {
|
||||
description: 'Get all accessible workspaces for current user',
|
||||
complexity: 2,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import test from 'ava';
|
||||
|
||||
import { createModule } from '../../__tests__/create-module';
|
||||
@@ -23,8 +24,16 @@ test('should create access token', async t => {
|
||||
t.is(token.userId, user.id);
|
||||
t.is(token.name, 'test');
|
||||
t.truthy(token.token);
|
||||
t.true(token.token.startsWith('ut_'));
|
||||
t.truthy(token.createdAt);
|
||||
t.is(token.expiresAt, null);
|
||||
|
||||
const row = await module.get(PrismaClient).accessToken.findUnique({
|
||||
where: { id: token.id },
|
||||
});
|
||||
t.truthy(row);
|
||||
t.regex(row!.token, /^[0-9a-f]{64}$/);
|
||||
t.not(row!.token, token.token);
|
||||
});
|
||||
|
||||
test('should create access token with expiration', async t => {
|
||||
@@ -50,6 +59,22 @@ test('should list access tokens without token value', async t => {
|
||||
t.is(listed[0].token, undefined);
|
||||
});
|
||||
|
||||
// Listing with `revealed = true` must return the redaction placeholder rather
// than the plaintext token that `create` handed out.
test('should not reveal access token value after creation', async t => {
  const owner = await module.create(Mockers.User);
  const token = await models.accessToken.create({
    name: 'test',
    userId: owner.id,
  });

  const revealedList = await models.accessToken.list(owner.id, true);
  const found = revealedList.find(entry => entry.id === token.id);

  t.truthy(found);
  t.is(found!.token, '[REDACTED]');
  t.not(found!.token, token.token);
});
|
||||
|
||||
test('should be able to revoke access token', async t => {
|
||||
const user = await module.create(Mockers.User);
|
||||
const token = await module.create(Mockers.AccessToken, { userId: user.id });
|
||||
@@ -62,7 +87,10 @@ test('should be able to revoke access token', async t => {
|
||||
|
||||
test('should be able to get access token by token value', async t => {
|
||||
const user = await module.create(Mockers.User);
|
||||
const token = await module.create(Mockers.AccessToken, { userId: user.id });
|
||||
const token = await models.accessToken.create({
|
||||
userId: user.id,
|
||||
name: 'test',
|
||||
});
|
||||
|
||||
const found = await models.accessToken.getByToken(token.token);
|
||||
t.is(found?.id, token.id);
|
||||
@@ -72,8 +100,9 @@ test('should be able to get access token by token value', async t => {
|
||||
|
||||
test('should not get expired access token', async t => {
|
||||
const user = await module.create(Mockers.User);
|
||||
const token = await module.create(Mockers.AccessToken, {
|
||||
const token = await models.accessToken.create({
|
||||
userId: user.id,
|
||||
name: 'test',
|
||||
expiresAt: Due.before('1s'),
|
||||
});
|
||||
|
||||
|
||||
@@ -3,43 +3,53 @@ import { Injectable } from '@nestjs/common';
|
||||
import { CryptoHelper } from '../base';
|
||||
import { BaseModel } from './base';
|
||||
|
||||
const REDACTED_TOKEN = '[REDACTED]';
|
||||
|
||||
export interface CreateAccessTokenInput {
|
||||
userId: string;
|
||||
name: string;
|
||||
expiresAt?: Date | null;
|
||||
}
|
||||
|
||||
type UserAccessToken = {
|
||||
id: string;
|
||||
name: string;
|
||||
createdAt: Date;
|
||||
expiresAt: Date | null;
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class AccessTokenModel extends BaseModel {
|
||||
constructor(private readonly crypto: CryptoHelper) {
|
||||
super();
|
||||
}
|
||||
|
||||
async list(userId: string, revealed?: false): Promise<UserAccessToken[]>;
|
||||
async list(
|
||||
userId: string,
|
||||
revealed: true
|
||||
): Promise<(UserAccessToken & { token: string })[]>;
|
||||
async list(userId: string, revealed: boolean = false) {
|
||||
return await this.db.accessToken.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
name: true,
|
||||
createdAt: true,
|
||||
expiresAt: true,
|
||||
token: revealed,
|
||||
},
|
||||
where: {
|
||||
userId,
|
||||
},
|
||||
const tokens = await this.db.accessToken.findMany({
|
||||
select: { id: true, name: true, createdAt: true, expiresAt: true },
|
||||
where: { userId },
|
||||
});
|
||||
|
||||
if (!revealed) return tokens;
|
||||
|
||||
return tokens.map(row => ({ ...row, token: REDACTED_TOKEN }));
|
||||
}
|
||||
|
||||
async create(input: CreateAccessTokenInput) {
|
||||
let token = 'ut_' + this.crypto.randomBytes(40).toString('hex');
|
||||
token = token.substring(0, 40);
|
||||
const token = `ut_${this.crypto.randomBytes(32).toString('base64url')}`;
|
||||
const tokenHash = this.crypto.sha256(token).toString('hex');
|
||||
|
||||
return await this.db.accessToken.create({
|
||||
data: {
|
||||
token,
|
||||
...input,
|
||||
},
|
||||
const created = await this.db.accessToken.create({
|
||||
data: { token: tokenHash, ...input },
|
||||
});
|
||||
|
||||
// NOTE: we only return the plaintext token once, at creation time.
|
||||
return { ...created, token };
|
||||
}
|
||||
|
||||
async revoke(id: string, userId: string) {
|
||||
@@ -52,20 +62,27 @@ export class AccessTokenModel extends BaseModel {
|
||||
}
|
||||
|
||||
async getByToken(token: string) {
|
||||
return await this.db.accessToken.findUnique({
|
||||
where: {
|
||||
token,
|
||||
OR: [
|
||||
{
|
||||
expiresAt: null,
|
||||
},
|
||||
{
|
||||
expiresAt: {
|
||||
gt: new Date(),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
const tokenHash = this.crypto.sha256(token).toString('hex');
|
||||
|
||||
const condition = [{ expiresAt: null }, { expiresAt: { gt: new Date() } }];
|
||||
const found = await this.db.accessToken.findUnique({
|
||||
where: { token: tokenHash, OR: condition },
|
||||
});
|
||||
|
||||
if (found) return found;
|
||||
|
||||
// Compatibility: lazy-migrate old plaintext tokens in DB.
|
||||
const legacy = await this.db.accessToken.findUnique({
|
||||
where: { token, OR: condition },
|
||||
});
|
||||
|
||||
if (!legacy) return null;
|
||||
|
||||
await this.db.accessToken.update({
|
||||
where: { id: legacy.id },
|
||||
data: { token: tokenHash },
|
||||
});
|
||||
|
||||
return { ...legacy, token: tokenHash };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ export class DocModel extends BaseModel {
|
||||
},
|
||||
});
|
||||
if (count > 0) {
|
||||
this.logger.log(
|
||||
this.logger.verbose(
|
||||
`Deleted ${count} updates for workspace ${workspaceId} doc ${docId}`
|
||||
);
|
||||
}
|
||||
@@ -159,7 +159,7 @@ export class DocModel extends BaseModel {
|
||||
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
const result: { updatedAt: Date }[] = await this.db.$queryRaw`
|
||||
INSERT INTO "snapshots" ("workspace_id", "guid", "blob", "size", "created_at", "updated_at", "created_by", "updated_by")
|
||||
VALUES (${spaceId}, ${docId}, ${blob}, ${size}, DEFAULT, ${updatedAt}, ${editorId}, ${editorId})
|
||||
VALUES (${spaceId}, ${docId}, ${blob}, ${size}, ${updatedAt}, ${updatedAt}, ${editorId}, ${editorId})
|
||||
ON CONFLICT ("workspace_id", "guid")
|
||||
DO UPDATE SET "blob" = ${blob}, "size" = ${size}, "updated_at" = ${updatedAt}, "updated_by" = ${editorId}
|
||||
WHERE "snapshots"."workspace_id" = ${spaceId} AND "snapshots"."guid" = ${docId} AND "snapshots"."updated_at" <= ${updatedAt}
|
||||
|
||||
@@ -24,6 +24,7 @@ import { DocModel } from './doc';
|
||||
import { DocUserModel } from './doc-user';
|
||||
import { FeatureModel } from './feature';
|
||||
import { HistoryModel } from './history';
|
||||
import { MagicLinkOtpModel } from './magic-link-otp';
|
||||
import { NotificationModel } from './notification';
|
||||
import { MODELS_SYMBOL } from './provider';
|
||||
import { SessionModel } from './session';
|
||||
@@ -41,6 +42,7 @@ const MODELS = {
|
||||
user: UserModel,
|
||||
session: SessionModel,
|
||||
verificationToken: VerificationTokenModel,
|
||||
magicLinkOtp: MagicLinkOtpModel,
|
||||
feature: FeatureModel,
|
||||
workspace: WorkspaceModel,
|
||||
userFeature: UserFeatureModel,
|
||||
@@ -133,6 +135,7 @@ export * from './doc';
|
||||
export * from './doc-user';
|
||||
export * from './feature';
|
||||
export * from './history';
|
||||
export * from './magic-link-otp';
|
||||
export * from './notification';
|
||||
export * from './session';
|
||||
export * from './user';
|
||||
|
||||
86
packages/backend/server/src/models/magic-link-otp.ts
Normal file
86
packages/backend/server/src/models/magic-link-otp.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
|
||||
import { CryptoHelper } from '../base';
|
||||
import { BaseModel } from './base';
|
||||
|
||||
/** Number of failed verifications after which the stored OTP is discarded. */
const MAX_OTP_ATTEMPTS = 10;

/** Lifetime of a stored OTP, in seconds (30 minutes). */
const OTP_TTL_IN_SEC = 30 * 60;

/** Every reason `consume` can report for rejecting an OTP. */
type MagicLinkOtpFailureReason =
  | 'not_found'
  | 'expired'
  | 'invalid_otp'
  | 'locked'
  | 'nonce_mismatch';

/**
 * Outcome of consuming an OTP: on success it carries the token that was
 * stored alongside the OTP, on failure the reason for the rejection.
 */
export type ConsumeMagicLinkOtpResult =
  | { ok: true; token: string }
  | { ok: false; reason: MagicLinkOtpFailureReason };
|
||||
|
||||
/**
 * Stores and verifies the one-time passcodes used for magic-link sign-in.
 *
 * Records are keyed by email. Only the SHA-256 hex digest of the OTP is
 * written to the database, together with the token to hand back on success,
 * an optional client nonce, an expiry time and a failed-attempt counter.
 */
@Injectable()
export class MagicLinkOtpModel extends BaseModel {
  constructor(private readonly crypto: CryptoHelper) {
    super();
  }

  /** Hex-encoded SHA-256 digest of the given OTP. */
  private hash(otp: string) {
    const digest = this.crypto.sha256(otp);
    return digest.toString('hex');
  }

  /**
   * Creates or replaces the OTP record for `email`.
   *
   * Replacing an existing record also resets its attempt counter and expiry.
   *
   * @param email key of the record
   * @param otp plaintext passcode; stored only as a hash
   * @param token value returned by `consume` once the OTP is verified
   * @param clientNonce optional nonce that `consume` must be called with
   */
  async upsert(
    email: string,
    otp: string,
    token: string,
    clientNonce?: string
  ) {
    const otpHash = this.hash(otp);
    const ttlInMs = OTP_TTL_IN_SEC * 1000;
    const expiresAt = new Date(Date.now() + ttlInMs);
    const fields = { otpHash, token, clientNonce, expiresAt, attempts: 0 };

    await this.db.magicLinkOtp.upsert({
      where: { email },
      create: { email, ...fields },
      update: fields,
    });
  }

  /**
   * Verifies `otp` for `email` and removes the record on success.
   *
   * Checks run in this order: record exists, not expired, nonce matches (only
   * when a nonce was stored), not already locked, OTP hash matches. Expired
   * and locked records are deleted. A wrong OTP increments the attempt
   * counter, and the record is deleted once the counter reaches
   * `MAX_OTP_ATTEMPTS`. A nonce mismatch leaves the record untouched.
   *
   * @returns the stored token on success, otherwise the rejection reason
   */
  @Transactional()
  async consume(
    email: string,
    otp: string,
    clientNonce?: string
  ): Promise<ConsumeMagicLinkOtpResult> {
    const now = new Date();
    const otpHash = this.hash(otp);
    const discard = () => this.db.magicLinkOtp.delete({ where: { email } });

    const record = await this.db.magicLinkOtp.findUnique({ where: { email } });
    if (!record) {
      return { ok: false, reason: 'not_found' };
    }

    if (record.expiresAt <= now) {
      await discard();
      return { ok: false, reason: 'expired' };
    }

    // The nonce is only enforced when one was stored with the record.
    if (record.clientNonce && record.clientNonce !== clientNonce) {
      return { ok: false, reason: 'nonce_mismatch' };
    }

    if (record.attempts >= MAX_OTP_ATTEMPTS) {
      await discard();
      return { ok: false, reason: 'locked' };
    }

    if (this.crypto.compare(record.otpHash, otpHash)) {
      // Single use: the record is removed as soon as the OTP is accepted.
      await discard();
      return { ok: true, token: record.token };
    }

    const attempts = record.attempts + 1;
    if (attempts >= MAX_OTP_ATTEMPTS) {
      await discard();
      return { ok: false, reason: 'locked' };
    }

    await this.db.magicLinkOtp.update({
      where: { email },
      data: { attempts },
    });
    return { ok: false, reason: 'invalid_otp' };
  }
}
|
||||
@@ -14,6 +14,7 @@ export enum TokenType {
|
||||
ChangeEmail,
|
||||
ChangePassword,
|
||||
Challenge,
|
||||
OpenAppSignIn,
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
|
||||
@@ -302,6 +302,29 @@ export class WorkspaceUserModel extends BaseModel {
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Tells whether two users have at least one workspace in common.
 *
 * A user always shares a workspace with themselves. Otherwise this looks for
 * an accepted membership of `userId` in a workspace that also has a
 * permission row for `otherUserId`.
 *
 * NOTE(review): only `userId`'s membership is filtered by `Accepted`; the
 * counterpart row for `otherUserId` matches with any status. Confirm that
 * this asymmetry is intended.
 */
async hasSharedWorkspace(userId: string, otherUserId: string) {
  if (userId === otherUserId) {
    return true;
  }

  const counterpartInWorkspace = {
    permissions: { some: { userId: otherUserId } },
  };

  // Only existence matters, so select nothing but the id of the first match.
  const membership = await this.db.workspaceUserRole.findFirst({
    select: { id: true },
    where: {
      userId,
      status: WorkspaceMemberStatus.Accepted,
      workspace: counterpartInWorkspace,
    },
  });

  return Boolean(membership);
}
|
||||
|
||||
async paginate(workspaceId: string, pagination: PaginationInput) {
|
||||
return await Promise.all([
|
||||
this.db.workspaceUserRole.findMany({
|
||||
|
||||
@@ -105,10 +105,6 @@ class RemoveContextDocInput {
|
||||
class AddContextFileInput {
|
||||
@Field(() => String)
|
||||
contextId!: string;
|
||||
|
||||
// @TODO(@darkskygit): remove this after clients lower than 0.22 have been disconnected
|
||||
@Field(() => String, { nullable: true, deprecationReason: 'Never used' })
|
||||
blobId!: string | undefined;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
|
||||
@@ -1672,42 +1672,12 @@ const imageActions: Prompt[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
// TODO(@darkskygit): deprecated, remove it after <0.22 version is outdated
|
||||
{
|
||||
name: 'debug:action:fal-remove-bg',
|
||||
action: 'Remove background',
|
||||
model: 'imageutils/rembg',
|
||||
messages: [],
|
||||
},
|
||||
{
|
||||
name: 'debug:action:fal-face-to-sticker',
|
||||
action: 'Convert to sticker',
|
||||
model: 'face-to-sticker',
|
||||
messages: [],
|
||||
},
|
||||
{
|
||||
name: 'debug:action:fal-teed',
|
||||
action: 'fal-teed',
|
||||
model: 'workflowutils/teed',
|
||||
messages: [{ role: 'user', content: '{{content}}' }],
|
||||
},
|
||||
{
|
||||
name: 'debug:action:fal-sd15',
|
||||
action: 'image',
|
||||
model: 'lcm-sd15-i2i',
|
||||
messages: [],
|
||||
},
|
||||
{
|
||||
name: 'debug:action:fal-upscaler',
|
||||
action: 'Clearer',
|
||||
model: 'clarity-upscaler',
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'best quality, 8K resolution, highres, clarity, {{content}}',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const modelActions: Prompt[] = [
|
||||
|
||||
@@ -24,7 +24,9 @@ import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderNotSupported,
|
||||
CopilotProviderSideError,
|
||||
fetchBuffer,
|
||||
metrics,
|
||||
OneMB,
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import { CopilotProvider } from './provider';
|
||||
@@ -673,14 +675,12 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
for (const [idx, entry] of attachments.entries()) {
|
||||
const url = typeof entry === 'string' ? entry : entry.attachment;
|
||||
const resp = await fetch(url);
|
||||
if (resp.ok) {
|
||||
const type = resp.headers.get('content-type');
|
||||
if (type && type.startsWith('image/')) {
|
||||
const buffer = new Uint8Array(await resp.arrayBuffer());
|
||||
const file = new File([buffer], `${idx}.png`, { type });
|
||||
form.append('image[]', file);
|
||||
}
|
||||
try {
|
||||
const { buffer, type } = await fetchBuffer(url, 10 * OneMB, 'image/');
|
||||
const file = new File([buffer], `${idx}.png`, { type });
|
||||
form.append('image[]', file);
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,22 @@ import {
|
||||
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
|
||||
import z, { ZodType } from 'zod';
|
||||
|
||||
import {
|
||||
bufferToArrayBuffer,
|
||||
fetchBuffer,
|
||||
OneMinute,
|
||||
ResponseTooLargeError,
|
||||
safeFetch,
|
||||
SsrfBlockedError,
|
||||
} from '../../../base';
|
||||
import { CustomAITools } from '../tools';
|
||||
import { PromptMessage, StreamObject } from './types';
|
||||
|
||||
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
|
||||
|
||||
const ATTACHMENT_MAX_BYTES = 20 * 1024 * 1024;
|
||||
const ATTACH_HEAD_PARAMS = { timeoutMs: OneMinute / 12, maxRedirects: 3 };
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
const FORMAT_INFER_MAP: Record<string, string> = {
|
||||
pdf: 'application/pdf',
|
||||
@@ -42,6 +53,11 @@ const FORMAT_INFER_MAP: Record<string, string> = {
|
||||
flv: 'video/flv',
|
||||
};
|
||||
|
||||
/**
 * Downloads `url` through `fetchBuffer`, passing `ATTACHMENT_MAX_BYTES` as
 * the size limit, and returns the payload converted to an `ArrayBuffer`.
 * Errors raised by `fetchBuffer` propagate to the caller.
 */
async function fetchArrayBuffer(url: string): Promise<ArrayBuffer> {
  const fetched = await fetchBuffer(url, ATTACHMENT_MAX_BYTES);
  return bufferToArrayBuffer(fetched.buffer);
}
|
||||
|
||||
export async function inferMimeType(url: string) {
|
||||
if (url.startsWith('data:')) {
|
||||
return url.split(';')[0].split(':')[1];
|
||||
@@ -53,12 +69,15 @@ export async function inferMimeType(url: string) {
|
||||
if (ext) {
|
||||
return ext;
|
||||
}
|
||||
const mimeType = await fetch(url, {
|
||||
method: 'HEAD',
|
||||
redirect: 'follow',
|
||||
}).then(res => res.headers.get('Content-Type'));
|
||||
if (mimeType) {
|
||||
return mimeType;
|
||||
try {
|
||||
const mimeType = await safeFetch(
|
||||
url,
|
||||
{ method: 'HEAD' },
|
||||
ATTACH_HEAD_PARAMS
|
||||
).then(res => res.headers.get('content-type'));
|
||||
if (mimeType) return mimeType;
|
||||
} catch {
|
||||
// ignore and fallback to default
|
||||
}
|
||||
}
|
||||
return 'application/octet-stream';
|
||||
@@ -106,7 +125,16 @@ export async function chatToGPTMessage(
|
||||
if (SIMPLE_IMAGE_URL_REGEX.test(attachment)) {
|
||||
const data =
|
||||
attachment.startsWith('data:') || useBase64Attachment
|
||||
? await fetch(attachment).then(r => r.arrayBuffer())
|
||||
? await fetchArrayBuffer(attachment).catch(error => {
|
||||
// Avoid leaking internal details for blocked URLs.
|
||||
if (
|
||||
error instanceof SsrfBlockedError ||
|
||||
error instanceof ResponseTooLargeError
|
||||
) {
|
||||
throw new Error('Attachment URL is not allowed');
|
||||
}
|
||||
throw error;
|
||||
})
|
||||
: new URL(attachment);
|
||||
if (mediaType.startsWith('image/')) {
|
||||
contents.push({ type: 'image', image: data, mediaType });
|
||||
|
||||
@@ -7,7 +7,9 @@ import {
|
||||
BlobQuotaExceeded,
|
||||
CallMetric,
|
||||
Config,
|
||||
fetchBuffer,
|
||||
type FileUpload,
|
||||
OneMB,
|
||||
OnEvent,
|
||||
readBuffer,
|
||||
type StorageProvider,
|
||||
@@ -16,6 +18,8 @@ import {
|
||||
} from '../../base';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
|
||||
const REMOTE_BLOB_MAX_BYTES = 20 * OneMB;
|
||||
|
||||
@Injectable()
|
||||
export class CopilotStorage {
|
||||
public provider!: StorageProvider;
|
||||
@@ -88,9 +92,8 @@ export class CopilotStorage {
|
||||
|
||||
@CallMetric('ai', 'blob_proxy_remote_url')
|
||||
async handleRemoteLink(userId: string, workspaceId: string, link: string) {
|
||||
const response = await fetch(link);
|
||||
const buffer = new Uint8Array(await response.arrayBuffer());
|
||||
const { buffer } = await fetchBuffer(link, REMOTE_BLOB_MAX_BYTES, 'image/');
|
||||
const filename = createHash('sha256').update(buffer).digest('base64url');
|
||||
return this.put(userId, workspaceId, filename, Buffer.from(buffer));
|
||||
return this.put(userId, workspaceId, filename, buffer);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import { ConnectedAccount } from '@prisma/client';
|
||||
import type { Request, Response } from 'express';
|
||||
|
||||
import {
|
||||
ActionForbidden,
|
||||
Config,
|
||||
InvalidAuthState,
|
||||
InvalidOauthCallbackState,
|
||||
@@ -57,6 +58,9 @@ export class OAuthController {
|
||||
if (!unknownProviderName) {
|
||||
throw new MissingOauthQueryParameter({ name: 'provider' });
|
||||
}
|
||||
if (!clientNonce) {
|
||||
throw new MissingOauthQueryParameter({ name: 'client_nonce' });
|
||||
}
|
||||
|
||||
const providerName = OAuthProviderName[unknownProviderName];
|
||||
const provider = this.providerFactory.get(providerName);
|
||||
@@ -67,6 +71,10 @@ export class OAuthController {
|
||||
|
||||
const pkce = provider.requiresPkce ? this.oauth.createPkcePair() : null;
|
||||
|
||||
if (redirectUri && !this.url.isAllowedRedirectUri(redirectUri)) {
|
||||
throw new ActionForbidden();
|
||||
}
|
||||
|
||||
const state = await this.oauth.saveOAuthState({
|
||||
provider: providerName,
|
||||
redirectUri,
|
||||
@@ -173,16 +181,6 @@ export class OAuthController {
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(@fengmk2): clientNonce should be required after the client version >= 0.21.0
|
||||
if (
|
||||
state.clientNonce &&
|
||||
state.clientNonce !== clientNonce &&
|
||||
// apple sign in with nonce stored in id token
|
||||
state.provider !== OAuthProviderName.Apple
|
||||
) {
|
||||
throw new InvalidAuthState();
|
||||
}
|
||||
|
||||
if (!state.provider) {
|
||||
throw new MissingOauthQueryParameter({ name: 'provider' });
|
||||
}
|
||||
@@ -193,6 +191,13 @@ export class OAuthController {
|
||||
throw new UnknownOauthProvider({ name: state.provider ?? 'unknown' });
|
||||
}
|
||||
|
||||
if (
|
||||
state.provider !== OAuthProviderName.Apple &&
|
||||
(!clientNonce || !state.clientNonce || state.clientNonce !== clientNonce)
|
||||
) {
|
||||
throw new InvalidAuthState();
|
||||
}
|
||||
|
||||
let tokens: Tokens;
|
||||
try {
|
||||
tokens = await provider.getToken(code, state);
|
||||
@@ -221,7 +226,7 @@ export class OAuthController {
|
||||
state.provider === OAuthProviderName.Apple &&
|
||||
(!state.client || state.client === 'web')
|
||||
) {
|
||||
return res.redirect(this.url.link(state.redirectUri ?? '/'));
|
||||
return this.url.safeRedirect(res, state.redirectUri ?? '/');
|
||||
}
|
||||
|
||||
res.send({
|
||||
|
||||
@@ -1,38 +1,17 @@
|
||||
import {
|
||||
Context,
|
||||
registerEnumType,
|
||||
ResolveField,
|
||||
Resolver,
|
||||
} from '@nestjs/graphql';
|
||||
import type { Request } from 'express';
|
||||
import semver from 'semver';
|
||||
import { registerEnumType, ResolveField, Resolver } from '@nestjs/graphql';
|
||||
|
||||
import { getClientVersionFromRequest } from '../../base';
|
||||
import { ServerConfigType } from '../../core/config/types';
|
||||
import { OAuthProviderName } from './config';
|
||||
import { OAuthProviderFactory } from './factory';
|
||||
|
||||
registerEnumType(OAuthProviderName, { name: 'OAuthProviderType' });
|
||||
|
||||
const APPLE_OAUTH_PROVIDER_MIN_VERSION = new semver.Range('>=0.22.0', {
|
||||
includePrerelease: true,
|
||||
});
|
||||
|
||||
@Resolver(() => ServerConfigType)
|
||||
export class OAuthResolver {
|
||||
constructor(private readonly factory: OAuthProviderFactory) {}
|
||||
|
||||
@ResolveField(() => [OAuthProviderName])
|
||||
oauthProviders(@Context() ctx: { req: Request }) {
|
||||
// Apple oauth provider is not supported in client version < 0.22.0
|
||||
const providers = this.factory.providers;
|
||||
if (providers.includes(OAuthProviderName.Apple)) {
|
||||
const version = getClientVersionFromRequest(ctx.req);
|
||||
if (!version || !APPLE_OAUTH_PROVIDER_MIN_VERSION.test(version)) {
|
||||
return providers.filter(p => p !== OAuthProviderName.Apple);
|
||||
}
|
||||
}
|
||||
|
||||
return providers;
|
||||
oauthProviders() {
|
||||
return this.factory.providers;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,23 @@ import {
|
||||
Req,
|
||||
Res,
|
||||
} from '@nestjs/common';
|
||||
import type { Request, Response } from 'express';
|
||||
import type {
|
||||
Request as ExpressRequest,
|
||||
Response as ExpressResponse,
|
||||
} from 'express';
|
||||
import { HTMLRewriter } from 'htmlrewriter';
|
||||
|
||||
import { BadRequest, Cache, URLHelper, UseNamedGuard } from '../../base';
|
||||
import {
|
||||
BadRequest,
|
||||
Cache,
|
||||
readResponseBufferWithLimit,
|
||||
ResponseTooLargeError,
|
||||
safeFetch,
|
||||
SsrfBlockedError,
|
||||
type SSRFBlockReason,
|
||||
URLHelper,
|
||||
UseNamedGuard,
|
||||
} from '../../base';
|
||||
import { Public } from '../../core/auth';
|
||||
import { WorkerService } from './service';
|
||||
import type { LinkPreviewRequest, LinkPreviewResponse } from './types';
|
||||
@@ -28,6 +41,25 @@ import { decodeWithCharset } from './utils/encoding';
|
||||
|
||||
// cache for 30 minutes
|
||||
const CACHE_TTL = 1000 * 60 * 30;
|
||||
const MAX_REDIRECTS = 3;
|
||||
const FETCH_TIMEOUT_MS = 10_000;
|
||||
const IMAGE_PROXY_MAX_BYTES = 10 * 1024 * 1024;
|
||||
const LINK_PREVIEW_MAX_BYTES = 2 * 1024 * 1024;
|
||||
|
||||
/**
 * Maps an SSRF block reason to the message sent back in the 400 response.
 *
 * Reasons that would reveal why a target was rejected (protocol, embedded
 * credentials, blocked host or IP) all collapse into a generic "Invalid URL".
 *
 * The caller obtains `reason` from error data through an unchecked cast, so a
 * value outside the declared union can arrive at runtime. Such values fall
 * back to "Invalid URL" instead of producing an `undefined` message.
 *
 * @param reason block reason reported by the SSRF guard
 * @returns a client-safe error message
 */
function toBadRequestReason(reason: SSRFBlockReason) {
  switch (reason) {
    case 'unresolvable_hostname':
      return 'Failed to resolve hostname';
    case 'too_many_redirects':
      return 'Too many redirects';
    case 'disallowed_protocol':
    case 'url_has_credentials':
    case 'blocked_hostname':
    case 'blocked_ip':
    case 'invalid_url':
    default:
      return 'Invalid URL';
  }
}
|
||||
|
||||
@Public()
|
||||
@UseNamedGuard('selfhost')
|
||||
@@ -45,14 +77,33 @@ export class WorkerController {
|
||||
return this.service.allowedOrigins;
|
||||
}
|
||||
|
||||
/**
 * CORS preflight handler for the image proxy.
 *
 * Replies 204 with an empty body, carrying the CORS headers computed for the
 * request's `Origin` plus the methods and headers this endpoint advertises.
 */
@Options('/image-proxy')
imageProxyOption(
  @Req() request: ExpressRequest,
  @Res() resp: ExpressResponse
) {
  const { origin } = request.headers;

  resp.status(204);
  resp.header({
    ...getCorsHeaders(origin),
    'Access-Control-Allow-Methods': 'GET, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type',
  });
  return resp.send();
}
|
||||
|
||||
@Get('/image-proxy')
|
||||
async imageProxy(@Req() req: Request, @Res() resp: Response) {
|
||||
const origin = req.headers.origin ?? '';
|
||||
async imageProxy(@Req() req: ExpressRequest, @Res() resp: ExpressResponse) {
|
||||
const origin = req.headers.origin;
|
||||
const referer = req.headers.referer;
|
||||
if (
|
||||
(origin && !isOriginAllowed(origin, this.allowedOrigin)) ||
|
||||
(referer && !isRefererAllowed(referer, this.allowedOrigin))
|
||||
) {
|
||||
const originAllowed = origin
|
||||
? isOriginAllowed(origin, this.allowedOrigin)
|
||||
: false;
|
||||
const refererAllowed = referer
|
||||
? isRefererAllowed(referer, this.allowedOrigin)
|
||||
: false;
|
||||
if (!originAllowed && !refererAllowed) {
|
||||
this.logger.error('Invalid Origin', 'ERROR', { origin, referer });
|
||||
throw new BadRequest('Invalid header');
|
||||
}
|
||||
@@ -79,24 +130,66 @@ export class WorkerController {
|
||||
return resp
|
||||
.status(200)
|
||||
.header({
|
||||
'Access-Control-Allow-Origin': origin,
|
||||
Vary: 'Origin',
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
'Content-Type': 'image/*',
|
||||
})
|
||||
.send(buffer);
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
new Request(targetURL.toString(), {
|
||||
method: 'GET',
|
||||
headers: cloneHeader(req.headers),
|
||||
})
|
||||
);
|
||||
let response: Response;
|
||||
try {
|
||||
response = await safeFetch(
|
||||
targetURL.toString(),
|
||||
{ method: 'GET', headers: cloneHeader(req.headers) },
|
||||
{ timeoutMs: FETCH_TIMEOUT_MS, maxRedirects: MAX_REDIRECTS }
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof SsrfBlockedError) {
|
||||
const reason = error.data?.reason as SSRFBlockReason | undefined;
|
||||
this.logger.warn('Blocked image proxy target', {
|
||||
url: imageURL,
|
||||
reason,
|
||||
context: (error as any).context,
|
||||
});
|
||||
throw new BadRequest(toBadRequestReason(reason ?? 'invalid_url'));
|
||||
}
|
||||
if (error instanceof ResponseTooLargeError) {
|
||||
this.logger.warn('Image proxy response too large', {
|
||||
url: imageURL,
|
||||
limitBytes: error.data?.limitBytes,
|
||||
receivedBytes: error.data?.receivedBytes,
|
||||
});
|
||||
throw new BadRequest('Response too large');
|
||||
}
|
||||
this.logger.error('Failed to fetch image', {
|
||||
origin,
|
||||
url: imageURL,
|
||||
error,
|
||||
});
|
||||
throw new BadRequest('Failed to fetch image');
|
||||
}
|
||||
if (response.ok) {
|
||||
const contentType = response.headers.get('Content-Type');
|
||||
if (contentType?.startsWith('image/')) {
|
||||
const buffer = Buffer.from(await response.arrayBuffer());
|
||||
let buffer: Buffer;
|
||||
try {
|
||||
buffer = await readResponseBufferWithLimit(
|
||||
response,
|
||||
IMAGE_PROXY_MAX_BYTES
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof ResponseTooLargeError) {
|
||||
this.logger.warn('Image proxy response too large', {
|
||||
url: imageURL,
|
||||
limitBytes: error.data?.limitBytes,
|
||||
receivedBytes: error.data?.receivedBytes,
|
||||
});
|
||||
throw new BadRequest('Response too large');
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
await this.cache.set(cachedUrl, buffer.toString('base64'), {
|
||||
ttl: CACHE_TTL,
|
||||
});
|
||||
@@ -104,8 +197,8 @@ export class WorkerController {
|
||||
return resp
|
||||
.status(200)
|
||||
.header({
|
||||
'Access-Control-Allow-Origin': origin ?? 'null',
|
||||
Vary: 'Origin',
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
'Content-Type': contentType,
|
||||
'Content-Disposition': contentDisposition,
|
||||
@@ -124,17 +217,20 @@ export class WorkerController {
|
||||
this.logger.error('Failed to fetch image', {
|
||||
origin,
|
||||
url: imageURL,
|
||||
status: resp.status,
|
||||
status: response.status,
|
||||
});
|
||||
throw new BadRequest('Failed to fetch image');
|
||||
}
|
||||
}
|
||||
|
||||
@Options('/link-preview')
|
||||
linkPreviewOption(@Req() request: Request, @Res() resp: Response) {
|
||||
linkPreviewOption(
|
||||
@Req() request: ExpressRequest,
|
||||
@Res() resp: ExpressResponse
|
||||
) {
|
||||
const origin = request.headers.origin;
|
||||
return resp
|
||||
.status(200)
|
||||
.status(204)
|
||||
.header({
|
||||
...getCorsHeaders(origin),
|
||||
'Access-Control-Allow-Methods': 'POST, OPTIONS',
|
||||
@@ -145,15 +241,18 @@ export class WorkerController {
|
||||
|
||||
@Post('/link-preview')
|
||||
async linkPreview(
|
||||
@Req() request: Request,
|
||||
@Res() resp: Response
|
||||
): Promise<Response> {
|
||||
@Req() request: ExpressRequest,
|
||||
@Res() resp: ExpressResponse
|
||||
): Promise<ExpressResponse> {
|
||||
const origin = request.headers.origin;
|
||||
const referer = request.headers.referer;
|
||||
if (
|
||||
(origin && !isOriginAllowed(origin, this.allowedOrigin)) ||
|
||||
(referer && !isRefererAllowed(referer, this.allowedOrigin))
|
||||
) {
|
||||
const originAllowed = origin
|
||||
? isOriginAllowed(origin, this.allowedOrigin)
|
||||
: false;
|
||||
const refererAllowed = referer
|
||||
? isRefererAllowed(referer, this.allowedOrigin)
|
||||
: false;
|
||||
if (!originAllowed && !refererAllowed) {
|
||||
this.logger.error('Invalid Origin', { origin, referer });
|
||||
throw new BadRequest('Invalid header');
|
||||
}
|
||||
@@ -183,9 +282,13 @@ export class WorkerController {
|
||||
.send(cachedResponse);
|
||||
}
|
||||
|
||||
const response = await fetch(targetURL, {
|
||||
headers: cloneHeader(request.headers),
|
||||
});
|
||||
const method: 'GET' | 'HEAD' = requestBody?.head ? 'HEAD' : 'GET';
|
||||
|
||||
const response = await safeFetch(
|
||||
targetURL.toString(),
|
||||
{ method, headers: cloneHeader(request.headers) },
|
||||
{ timeoutMs: FETCH_TIMEOUT_MS, maxRedirects: MAX_REDIRECTS }
|
||||
);
|
||||
this.logger.debug('Fetched URL', {
|
||||
origin,
|
||||
url: targetURL,
|
||||
@@ -211,7 +314,12 @@ export class WorkerController {
|
||||
};
|
||||
|
||||
if (response.body) {
|
||||
const resp = await decodeWithCharset(response, res);
|
||||
const body = await readResponseBufferWithLimit(
|
||||
response,
|
||||
LINK_PREVIEW_MAX_BYTES
|
||||
);
|
||||
const limitedResponse = new Response(body, response);
|
||||
const resp = await decodeWithCharset(limitedResponse, res);
|
||||
|
||||
const rewriter = new HTMLRewriter()
|
||||
.on('meta', {
|
||||
@@ -287,7 +395,11 @@ export class WorkerController {
|
||||
{
|
||||
// head default path of favicon
|
||||
const faviconUrl = new URL('/favicon.ico?v=2', response.url);
|
||||
const faviconResponse = await fetch(faviconUrl, { method: 'HEAD' });
|
||||
const faviconResponse = await safeFetch(
|
||||
faviconUrl.toString(),
|
||||
{ method: 'HEAD' },
|
||||
{ timeoutMs: FETCH_TIMEOUT_MS, maxRedirects: MAX_REDIRECTS }
|
||||
);
|
||||
if (faviconResponse.ok) {
|
||||
appendUrl(faviconUrl.toString(), res.favicons);
|
||||
}
|
||||
@@ -311,6 +423,25 @@ export class WorkerController {
|
||||
})
|
||||
.send(json);
|
||||
} catch (error) {
|
||||
if (error instanceof SsrfBlockedError) {
|
||||
const reason = error.data?.reason as SSRFBlockReason | undefined;
|
||||
this.logger.warn('Blocked link preview target', {
|
||||
origin,
|
||||
url: requestBody?.url,
|
||||
reason,
|
||||
context: (error as any).context,
|
||||
});
|
||||
throw new BadRequest(toBadRequestReason(reason ?? 'invalid_url'));
|
||||
}
|
||||
if (error instanceof ResponseTooLargeError) {
|
||||
this.logger.warn('Link preview response too large', {
|
||||
origin,
|
||||
url: requestBody?.url,
|
||||
limitBytes: error.data?.limitBytes,
|
||||
receivedBytes: error.data?.receivedBytes,
|
||||
});
|
||||
throw new BadRequest('Response too large');
|
||||
}
|
||||
this.logger.error('Error fetching URL', {
|
||||
origin,
|
||||
url: targetURL,
|
||||
|
||||
@@ -27,7 +27,6 @@ input AddContextDocInput {
|
||||
}
|
||||
|
||||
input AddContextFileInput {
|
||||
blobId: String
|
||||
contextId: String!
|
||||
}
|
||||
|
||||
@@ -798,7 +797,7 @@ type EditorType {
|
||||
name: String!
|
||||
}
|
||||
|
||||
union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CalendarProviderRequestErrorDataType | CopilotContextFileNotSupportedDataType | CopilotDocNotFoundDataType | CopilotFailedToAddWorkspaceFileEmbeddingDataType | CopilotFailedToGenerateEmbeddingDataType | CopilotFailedToMatchContextDataType | CopilotFailedToMatchGlobalContextDataType | CopilotFailedToModifyContextDataType | CopilotInvalidContextDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderNotSupportedDataType | CopilotProviderSideErrorDataType | DocActionDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | DocUpdateBlockedDataType | ExpectToGrantDocUserRolesDataType | ExpectToRevokeDocUserRolesDataType | ExpectToUpdateDocUserRoleDataType | GraphqlBadRequestDataType | HttpRequestErrorDataType | InvalidAppConfigDataType | InvalidAppConfigInputDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidIndexerInputDataType | InvalidLicenseToActivateDataType | InvalidLicenseUpdateParamsDataType | InvalidOauthCallbackCodeDataType | InvalidOauthResponseDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | InvalidSearchProviderRequestDataType | MemberNotFoundInSpaceDataType | MentionUserDocAccessDeniedDataType | MissingOauthQueryParameterDataType | NoCopilotProviderAvailableDataType | NoMoreSeatDataType | NotInSpaceDataType | QueryTooLongDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SpaceShouldHaveOnlyOneOwnerDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedClientVersionDataType | UnsupportedSubscriptionPlanDataType | ValidationErrorDataType | VersionRejectedDataType | WorkspacePermissionNotFoundDataType | WrongSignInCredentialsDataType
|
||||
union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CalendarProviderRequestErrorDataType | CopilotContextFileNotSupportedDataType | CopilotDocNotFoundDataType | CopilotFailedToAddWorkspaceFileEmbeddingDataType | CopilotFailedToGenerateEmbeddingDataType | CopilotFailedToMatchContextDataType | CopilotFailedToMatchGlobalContextDataType | CopilotFailedToModifyContextDataType | CopilotInvalidContextDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderNotSupportedDataType | CopilotProviderSideErrorDataType | DocActionDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | DocUpdateBlockedDataType | ExpectToGrantDocUserRolesDataType | ExpectToRevokeDocUserRolesDataType | ExpectToUpdateDocUserRoleDataType | GraphqlBadRequestDataType | HttpRequestErrorDataType | InvalidAppConfigDataType | InvalidAppConfigInputDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidIndexerInputDataType | InvalidLicenseToActivateDataType | InvalidLicenseUpdateParamsDataType | InvalidOauthCallbackCodeDataType | InvalidOauthResponseDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | InvalidSearchProviderRequestDataType | MemberNotFoundInSpaceDataType | MentionUserDocAccessDeniedDataType | MissingOauthQueryParameterDataType | NoCopilotProviderAvailableDataType | NoMoreSeatDataType | NotInSpaceDataType | QueryTooLongDataType | ResponseTooLargeErrorDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SpaceShouldHaveOnlyOneOwnerDataType | SsrfBlockedErrorDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedClientVersionDataType | UnsupportedSubscriptionPlanDataType | ValidationErrorDataType | VersionRejectedDataType | WorkspacePermissionNotFoundDataType | 
WrongSignInCredentialsDataType
|
||||
|
||||
enum ErrorNames {
|
||||
ACCESS_DENIED
|
||||
@@ -912,6 +911,7 @@ enum ErrorNames {
|
||||
PASSWORD_REQUIRED
|
||||
QUERY_TOO_LONG
|
||||
REPLY_NOT_FOUND
|
||||
RESPONSE_TOO_LARGE_ERROR
|
||||
RUNTIME_CONFIG_NOT_FOUND
|
||||
SAME_EMAIL_PROVIDED
|
||||
SAME_SUBSCRIPTION_RECURRING
|
||||
@@ -921,6 +921,7 @@ enum ErrorNames {
|
||||
SPACE_NOT_FOUND
|
||||
SPACE_OWNER_NOT_FOUND
|
||||
SPACE_SHOULD_HAVE_ONLY_ONE_OWNER
|
||||
SSRF_BLOCKED_ERROR
|
||||
STORAGE_QUOTA_EXCEEDED
|
||||
SUBSCRIPTION_ALREADY_EXISTS
|
||||
SUBSCRIPTION_EXPIRED
|
||||
@@ -1453,14 +1454,12 @@ type Mutation {
|
||||
forkCopilotSession(options: ForkChatSessionInput!): String!
|
||||
generateLicenseKey(sessionId: String!): String!
|
||||
generateUserAccessToken(input: GenerateAccessTokenInput!): RevealedAccessToken!
|
||||
getBlobUploadPartUrl(key: String!, partNumber: Int!, uploadId: String!, workspaceId: String!): BlobUploadPart! @deprecated(reason: "use WorkspaceType.blobUploadPartUrl")
|
||||
grantDocUserRoles(input: GrantDocUserRolesInput!): Boolean!
|
||||
grantMember(permission: Permission!, userId: String!, workspaceId: String!): Boolean!
|
||||
|
||||
"""import users"""
|
||||
importUsers(input: ImportUsersInput!): [UserImportResultType!]!
|
||||
installLicense(license: Upload!, workspaceId: String!): License!
|
||||
inviteBatch(emails: [String!]!, sendInviteMail: Boolean @deprecated(reason: "never used"), workspaceId: String!): [InviteResult!]! @deprecated(reason: "use [inviteMembers] instead")
|
||||
inviteMembers(emails: [String!]!, workspaceId: String!): [InviteResult!]!
|
||||
leaveWorkspace(sendLeaveMail: Boolean @deprecated(reason: "no used anymore"), workspaceId: String!, workspaceName: String @deprecated(reason: "no longer used")): Boolean!
|
||||
linkCalendarAccount(input: LinkCalendarAccountInput!): String!
|
||||
@@ -1468,7 +1467,6 @@ type Mutation {
|
||||
"""mention user in a doc"""
|
||||
mentionUser(input: MentionInput!): ID!
|
||||
publishDoc(docId: String!, mode: PublicDocMode = Page, workspaceId: String!): DocType!
|
||||
publishPage(mode: PublicDocMode = Page, pageId: String!, workspaceId: String!): DocType! @deprecated(reason: "use publishDoc instead")
|
||||
|
||||
"""queue workspace doc embedding"""
|
||||
queueWorkspaceEmbedding(docId: [String!]!, workspaceId: String!): Boolean!
|
||||
@@ -1510,12 +1508,10 @@ type Mutation {
|
||||
resolveComment(input: CommentResolveInput!): Boolean!
|
||||
resumeSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, workspaceId: String): SubscriptionType!
|
||||
retryAudioTranscription(jobId: String!, workspaceId: String!): TranscriptionResultType
|
||||
revoke(userId: String!, workspaceId: String!): Boolean! @deprecated(reason: "use [revokeMember] instead")
|
||||
revokeDocUserRoles(input: RevokeDocUserRoleInput!): Boolean!
|
||||
revokeInviteLink(workspaceId: String!): Boolean!
|
||||
revokeMember(userId: String!, workspaceId: String!): Boolean!
|
||||
revokePublicDoc(docId: String!, workspaceId: String!): DocType!
|
||||
revokePublicPage(docId: String!, workspaceId: String!): DocType! @deprecated(reason: "use revokePublicDoc instead")
|
||||
revokeUserAccessToken(id: String!): Boolean!
|
||||
sendChangeEmail(callbackUrl: String!, email: String): Boolean!
|
||||
sendChangePasswordEmail(callbackUrl: String!, email: String @deprecated(reason: "fetched from signed in user")): Boolean!
|
||||
@@ -1574,9 +1570,6 @@ type Mutation {
|
||||
|
||||
"""Upload a comment attachment and return the access url"""
|
||||
uploadCommentAttachment(attachment: Upload!, docId: String!, workspaceId: String!): String!
|
||||
|
||||
"""validate app configuration"""
|
||||
validateAppConfig(updates: [UpdateAppConfigInput!]!): [AppConfigValidateResult!]! @deprecated(reason: "use Query.validateAppConfig")
|
||||
verifyEmail(token: String!): Boolean!
|
||||
}
|
||||
|
||||
@@ -1754,8 +1747,6 @@ type PublicUserType {
|
||||
}
|
||||
|
||||
type Query {
|
||||
accessTokens: [AccessToken!]! @deprecated(reason: "use currentUser.accessTokens")
|
||||
|
||||
"""Get workspace detail for admin"""
|
||||
adminWorkspace(id: String!): AdminWorkspace
|
||||
|
||||
@@ -1770,7 +1761,6 @@ type Query {
|
||||
|
||||
"""Apply updates to a doc using LLM and return the merged markdown."""
|
||||
applyDocUpdates(docId: String!, op: String!, updates: String!, workspaceId: String!): String! @deprecated(reason: "use Mutation.applyDocUpdates")
|
||||
collectAllBlobSizes: WorkspaceBlobSizes! @deprecated(reason: "use `user.quotaUsage` instead")
|
||||
|
||||
"""Get current user"""
|
||||
currentUser: UserType
|
||||
@@ -1779,12 +1769,6 @@ type Query {
|
||||
"""get workspace invitation info"""
|
||||
getInviteInfo(inviteId: String!): InvitationType!
|
||||
|
||||
"""Get is admin of workspace"""
|
||||
isAdmin(workspaceId: String!): Boolean! @deprecated(reason: "use WorkspaceType[role] instead")
|
||||
|
||||
"""Get is owner of workspace"""
|
||||
isOwner(workspaceId: String!): Boolean! @deprecated(reason: "use WorkspaceType[role] instead")
|
||||
|
||||
"""List all copilot prompts"""
|
||||
listCopilotPrompts: [CopilotPromptType!]!
|
||||
prices: [SubscriptionPrice!]!
|
||||
@@ -1918,6 +1902,11 @@ input ReplyUpdateInput {
|
||||
id: ID!
|
||||
}
|
||||
|
||||
type ResponseTooLargeErrorDataType {
|
||||
limitBytes: Int!
|
||||
receivedBytes: Int!
|
||||
}
|
||||
|
||||
type RevealedAccessToken {
|
||||
createdAt: DateTime!
|
||||
expiresAt: DateTime
|
||||
@@ -2104,6 +2093,10 @@ type SpaceShouldHaveOnlyOneOwnerDataType {
|
||||
spaceId: String!
|
||||
}
|
||||
|
||||
type SsrfBlockedErrorDataType {
|
||||
reason: String!
|
||||
}
|
||||
|
||||
type StreamObject {
|
||||
args: JSON
|
||||
result: JSON
|
||||
@@ -2405,10 +2398,6 @@ type VersionRejectedDataType {
|
||||
version: String!
|
||||
}
|
||||
|
||||
type WorkspaceBlobSizes {
|
||||
size: SafeInt!
|
||||
}
|
||||
|
||||
input WorkspaceCalendarItemInput {
|
||||
colorOverride: String
|
||||
sortOrder: Int
|
||||
@@ -2591,9 +2580,6 @@ type WorkspaceType {
|
||||
"""Get public docs of a workspace"""
|
||||
publicDocs: [DocType!]!
|
||||
|
||||
"""Get public page of a workspace by page id."""
|
||||
publicPage(pageId: String!): DocType @deprecated(reason: "use [WorkspaceType.doc] instead")
|
||||
|
||||
"""quota of workspace"""
|
||||
quota: WorkspaceQuotaType!
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
},
|
||||
"include": ["./src"],
|
||||
"references": [
|
||||
{ "path": "../../common/s3-compat" },
|
||||
{ "path": "../native" },
|
||||
{ "path": "../../../tools/cli" },
|
||||
{ "path": "../../../tools/utils" },
|
||||
|
||||
@@ -63,7 +63,6 @@ export interface AddContextDocInput {
|
||||
}
|
||||
|
||||
export interface AddContextFileInput {
|
||||
blobId?: InputMaybe<Scalars['String']['input']>;
|
||||
contextId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
@@ -978,12 +977,14 @@ export type ErrorDataUnion =
|
||||
| NoMoreSeatDataType
|
||||
| NotInSpaceDataType
|
||||
| QueryTooLongDataType
|
||||
| ResponseTooLargeErrorDataType
|
||||
| RuntimeConfigNotFoundDataType
|
||||
| SameSubscriptionRecurringDataType
|
||||
| SpaceAccessDeniedDataType
|
||||
| SpaceNotFoundDataType
|
||||
| SpaceOwnerNotFoundDataType
|
||||
| SpaceShouldHaveOnlyOneOwnerDataType
|
||||
| SsrfBlockedErrorDataType
|
||||
| SubscriptionAlreadyExistsDataType
|
||||
| SubscriptionNotExistsDataType
|
||||
| SubscriptionPlanNotFoundDataType
|
||||
@@ -1107,6 +1108,7 @@ export enum ErrorNames {
|
||||
PASSWORD_REQUIRED = 'PASSWORD_REQUIRED',
|
||||
QUERY_TOO_LONG = 'QUERY_TOO_LONG',
|
||||
REPLY_NOT_FOUND = 'REPLY_NOT_FOUND',
|
||||
RESPONSE_TOO_LARGE_ERROR = 'RESPONSE_TOO_LARGE_ERROR',
|
||||
RUNTIME_CONFIG_NOT_FOUND = 'RUNTIME_CONFIG_NOT_FOUND',
|
||||
SAME_EMAIL_PROVIDED = 'SAME_EMAIL_PROVIDED',
|
||||
SAME_SUBSCRIPTION_RECURRING = 'SAME_SUBSCRIPTION_RECURRING',
|
||||
@@ -1116,6 +1118,7 @@ export enum ErrorNames {
|
||||
SPACE_NOT_FOUND = 'SPACE_NOT_FOUND',
|
||||
SPACE_OWNER_NOT_FOUND = 'SPACE_OWNER_NOT_FOUND',
|
||||
SPACE_SHOULD_HAVE_ONLY_ONE_OWNER = 'SPACE_SHOULD_HAVE_ONLY_ONE_OWNER',
|
||||
SSRF_BLOCKED_ERROR = 'SSRF_BLOCKED_ERROR',
|
||||
STORAGE_QUOTA_EXCEEDED = 'STORAGE_QUOTA_EXCEEDED',
|
||||
SUBSCRIPTION_ALREADY_EXISTS = 'SUBSCRIPTION_ALREADY_EXISTS',
|
||||
SUBSCRIPTION_EXPIRED = 'SUBSCRIPTION_EXPIRED',
|
||||
@@ -1622,23 +1625,17 @@ export interface Mutation {
|
||||
forkCopilotSession: Scalars['String']['output'];
|
||||
generateLicenseKey: Scalars['String']['output'];
|
||||
generateUserAccessToken: RevealedAccessToken;
|
||||
/** @deprecated use WorkspaceType.blobUploadPartUrl */
|
||||
getBlobUploadPartUrl: BlobUploadPart;
|
||||
grantDocUserRoles: Scalars['Boolean']['output'];
|
||||
grantMember: Scalars['Boolean']['output'];
|
||||
/** import users */
|
||||
importUsers: Array<UserImportResultType>;
|
||||
installLicense: License;
|
||||
/** @deprecated use [inviteMembers] instead */
|
||||
inviteBatch: Array<InviteResult>;
|
||||
inviteMembers: Array<InviteResult>;
|
||||
leaveWorkspace: Scalars['Boolean']['output'];
|
||||
linkCalendarAccount: Scalars['String']['output'];
|
||||
/** mention user in a doc */
|
||||
mentionUser: Scalars['ID']['output'];
|
||||
publishDoc: DocType;
|
||||
/** @deprecated use publishDoc instead */
|
||||
publishPage: DocType;
|
||||
/** queue workspace doc embedding */
|
||||
queueWorkspaceEmbedding: Scalars['Boolean']['output'];
|
||||
/** mark all notifications as read */
|
||||
@@ -1668,14 +1665,10 @@ export interface Mutation {
|
||||
resolveComment: Scalars['Boolean']['output'];
|
||||
resumeSubscription: SubscriptionType;
|
||||
retryAudioTranscription: Maybe<TranscriptionResultType>;
|
||||
/** @deprecated use [revokeMember] instead */
|
||||
revoke: Scalars['Boolean']['output'];
|
||||
revokeDocUserRoles: Scalars['Boolean']['output'];
|
||||
revokeInviteLink: Scalars['Boolean']['output'];
|
||||
revokeMember: Scalars['Boolean']['output'];
|
||||
revokePublicDoc: DocType;
|
||||
/** @deprecated use revokePublicDoc instead */
|
||||
revokePublicPage: DocType;
|
||||
revokeUserAccessToken: Scalars['Boolean']['output'];
|
||||
sendChangeEmail: Scalars['Boolean']['output'];
|
||||
sendChangePasswordEmail: Scalars['Boolean']['output'];
|
||||
@@ -1720,11 +1713,6 @@ export interface Mutation {
|
||||
uploadAvatar: UserType;
|
||||
/** Upload a comment attachment and return the access url */
|
||||
uploadCommentAttachment: Scalars['String']['output'];
|
||||
/**
|
||||
* validate app configuration
|
||||
* @deprecated use Query.validateAppConfig
|
||||
*/
|
||||
validateAppConfig: Array<AppConfigValidateResult>;
|
||||
verifyEmail: Scalars['Boolean']['output'];
|
||||
}
|
||||
|
||||
@@ -1925,13 +1913,6 @@ export interface MutationGenerateUserAccessTokenArgs {
|
||||
input: GenerateAccessTokenInput;
|
||||
}
|
||||
|
||||
export interface MutationGetBlobUploadPartUrlArgs {
|
||||
key: Scalars['String']['input'];
|
||||
partNumber: Scalars['Int']['input'];
|
||||
uploadId: Scalars['String']['input'];
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationGrantDocUserRolesArgs {
|
||||
input: GrantDocUserRolesInput;
|
||||
}
|
||||
@@ -1951,12 +1932,6 @@ export interface MutationInstallLicenseArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationInviteBatchArgs {
|
||||
emails: Array<Scalars['String']['input']>;
|
||||
sendInviteMail?: InputMaybe<Scalars['Boolean']['input']>;
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationInviteMembersArgs {
|
||||
emails: Array<Scalars['String']['input']>;
|
||||
workspaceId: Scalars['String']['input'];
|
||||
@@ -1982,12 +1957,6 @@ export interface MutationPublishDocArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationPublishPageArgs {
|
||||
mode?: InputMaybe<PublicDocMode>;
|
||||
pageId: Scalars['String']['input'];
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationQueueWorkspaceEmbeddingArgs {
|
||||
docId: Array<Scalars['String']['input']>;
|
||||
workspaceId: Scalars['String']['input'];
|
||||
@@ -2052,11 +2021,6 @@ export interface MutationRetryAudioTranscriptionArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationRevokeArgs {
|
||||
userId: Scalars['String']['input'];
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationRevokeDocUserRolesArgs {
|
||||
input: RevokeDocUserRoleInput;
|
||||
}
|
||||
@@ -2075,11 +2039,6 @@ export interface MutationRevokePublicDocArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationRevokePublicPageArgs {
|
||||
docId: Scalars['String']['input'];
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationRevokeUserAccessTokenArgs {
|
||||
id: Scalars['String']['input'];
|
||||
}
|
||||
@@ -2212,10 +2171,6 @@ export interface MutationUploadCommentAttachmentArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationValidateAppConfigArgs {
|
||||
updates: Array<UpdateAppConfigInput>;
|
||||
}
|
||||
|
||||
export interface MutationVerifyEmailArgs {
|
||||
token: Scalars['String']['input'];
|
||||
}
|
||||
@@ -2401,8 +2356,6 @@ export interface PublicUserType {
|
||||
|
||||
export interface Query {
|
||||
__typename?: 'Query';
|
||||
/** @deprecated use currentUser.accessTokens */
|
||||
accessTokens: Array<AccessToken>;
|
||||
/** Get workspace detail for admin */
|
||||
adminWorkspace: Maybe<AdminWorkspace>;
|
||||
/** List workspaces for admin */
|
||||
@@ -2416,23 +2369,11 @@ export interface Query {
|
||||
* @deprecated use Mutation.applyDocUpdates
|
||||
*/
|
||||
applyDocUpdates: Scalars['String']['output'];
|
||||
/** @deprecated use `user.quotaUsage` instead */
|
||||
collectAllBlobSizes: WorkspaceBlobSizes;
|
||||
/** Get current user */
|
||||
currentUser: Maybe<UserType>;
|
||||
error: ErrorDataUnion;
|
||||
/** get workspace invitation info */
|
||||
getInviteInfo: InvitationType;
|
||||
/**
|
||||
* Get is admin of workspace
|
||||
* @deprecated use WorkspaceType[role] instead
|
||||
*/
|
||||
isAdmin: Scalars['Boolean']['output'];
|
||||
/**
|
||||
* Get is owner of workspace
|
||||
* @deprecated use WorkspaceType[role] instead
|
||||
*/
|
||||
isOwner: Scalars['Boolean']['output'];
|
||||
/** List all copilot prompts */
|
||||
listCopilotPrompts: Array<CopilotPromptType>;
|
||||
prices: Array<SubscriptionPrice>;
|
||||
@@ -2494,14 +2435,6 @@ export interface QueryGetInviteInfoArgs {
|
||||
inviteId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface QueryIsAdminArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface QueryIsOwnerArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface QueryPublicUserByIdArgs {
|
||||
id: Scalars['String']['input'];
|
||||
}
|
||||
@@ -2630,6 +2563,12 @@ export interface ReplyUpdateInput {
|
||||
id: Scalars['ID']['input'];
|
||||
}
|
||||
|
||||
export interface ResponseTooLargeErrorDataType {
|
||||
__typename?: 'ResponseTooLargeErrorDataType';
|
||||
limitBytes: Scalars['Int']['output'];
|
||||
receivedBytes: Scalars['Int']['output'];
|
||||
}
|
||||
|
||||
export interface RevealedAccessToken {
|
||||
__typename?: 'RevealedAccessToken';
|
||||
createdAt: Scalars['DateTime']['output'];
|
||||
@@ -2812,6 +2751,11 @@ export interface SpaceShouldHaveOnlyOneOwnerDataType {
|
||||
spaceId: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface SsrfBlockedErrorDataType {
|
||||
__typename?: 'SsrfBlockedErrorDataType';
|
||||
reason: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface StreamObject {
|
||||
__typename?: 'StreamObject';
|
||||
args: Maybe<Scalars['JSON']['output']>;
|
||||
@@ -3126,11 +3070,6 @@ export interface VersionRejectedDataType {
|
||||
version: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface WorkspaceBlobSizes {
|
||||
__typename?: 'WorkspaceBlobSizes';
|
||||
size: Scalars['SafeInt']['output'];
|
||||
}
|
||||
|
||||
export interface WorkspaceCalendarItemInput {
|
||||
colorOverride?: InputMaybe<Scalars['String']['input']>;
|
||||
sortOrder?: InputMaybe<Scalars['Int']['input']>;
|
||||
@@ -3308,11 +3247,6 @@ export interface WorkspaceType {
|
||||
public: Scalars['Boolean']['output'];
|
||||
/** Get public docs of a workspace */
|
||||
publicDocs: Array<DocType>;
|
||||
/**
|
||||
* Get public page of a workspace by page id.
|
||||
* @deprecated use [WorkspaceType.doc] instead
|
||||
*/
|
||||
publicPage: Maybe<DocType>;
|
||||
/** quota of workspace */
|
||||
quota: WorkspaceQuotaType;
|
||||
/** Get recently updated docs of a workspace */
|
||||
@@ -3378,10 +3312,6 @@ export interface WorkspaceTypePageMetaArgs {
|
||||
pageId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface WorkspaceTypePublicPageArgs {
|
||||
pageId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface WorkspaceTypeRecentlyUpdatedDocsArgs {
|
||||
pagination: PaginationInput;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ export type AppSetting = {
|
||||
autoDownloadUpdate: boolean;
|
||||
enableTelemetry: boolean;
|
||||
showLinkedDocInSidebar: boolean;
|
||||
disableImageAntialiasing: boolean;
|
||||
};
|
||||
export const windowFrameStyleOptions: AppSetting['windowFrameStyle'][] = [
|
||||
'frameless',
|
||||
@@ -35,6 +36,7 @@ const appSettingBaseAtom = atomWithStorage<AppSetting>(
|
||||
autoDownloadUpdate: true,
|
||||
enableTelemetry: true,
|
||||
showLinkedDocInSidebar: true,
|
||||
disableImageAntialiasing: false,
|
||||
},
|
||||
undefined,
|
||||
{
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import { describe, expect, test } from 'vitest';
|
||||
|
||||
import { isAllowedRedirectTarget } from '../redirect-allowlist';
|
||||
|
||||
// Behavioural spec for `isAllowedRedirectTarget`. Every case pretends the app
// is served from `self.example.com` and checks whether a redirect to the given
// URL would be permitted.
describe('redirect allowlist', () => {
  test('allows same hostname', () => {
    expect(
      isAllowedRedirectTarget('https://self.example.com/path', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);
  });

  test('allows trusted domains and subdomains', () => {
    // Exact match against a trusted domain.
    expect(
      isAllowedRedirectTarget('https://github.com/toeverything/AFFiNE', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);

    // Subdomain of a trusted domain.
    expect(
      isAllowedRedirectTarget('https://sub.github.com/foo', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);
  });

  test('blocks look-alike domains', () => {
    // Shares the suffix "github.com" but is not a dot-separated subdomain.
    expect(
      isAllowedRedirectTarget('https://evilgithub.com', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
  });

  test('blocks disallowed protocols', () => {
    expect(
      isAllowedRedirectTarget('javascript:alert(1)', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
  });

  test('handles port and trailing dot', () => {
    // The port is not part of `URL.hostname`, so it must not affect matching.
    expect(
      isAllowedRedirectTarget('https://github.com:8443', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);

    // A fully-qualified hostname (trailing dot) is treated like the bare one.
    expect(
      isAllowedRedirectTarget('https://affine.pro./', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);
  });

  test('blocks punycode homograph', () => {
    // "а" is Cyrillic small a (U+0430), different from Latin "a"
    expect(
      isAllowedRedirectTarget('https://аffine.pro', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
  });
});
|
||||
@@ -4,6 +4,7 @@ export * from './exhaustmap-with-trailing';
|
||||
export * from './fractional-indexing';
|
||||
export * from './merge-updates';
|
||||
export * from './object-pool';
|
||||
export * from './redirect-allowlist';
|
||||
export * from './stable-hash';
|
||||
export * from './throw-if-aborted';
|
||||
export * from './yjs-observable';
|
||||
|
||||
50
packages/common/infra/src/utils/redirect-allowlist.ts
Normal file
50
packages/common/infra/src/utils/redirect-allowlist.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
/**
 * Domains that are accepted as redirect targets in addition to the current
 * site. A target matches when its hostname equals an entry or is a
 * dot-separated subdomain of one. Entries are stored lower-cased.
 */
export const TRUSTED_REDIRECT_DOMAINS = [
  'google.com',
  'stripe.com',
  'github.com',
  'twitter.com',
  'discord.gg',
  'youtube.com',
  't.me',
  'reddit.com',
  'affine.pro',
].map(d => d.toLowerCase());

/**
 * Protocols a redirect target may use, in the form reported by
 * `URL.protocol` (trailing colon included).
 */
export const ALLOWED_REDIRECT_PROTOCOLS = new Set(['http:', 'https:']);

/** Lower-cases a hostname and strips one trailing dot, if present. */
function normalizeHostname(hostname: string) {
  return hostname.toLowerCase().replace(/\.$/, '');
}

/**
 * Returns true when `hostname` is `domain` itself or ends with `.domain`.
 * The leading dot in the suffix check is what rejects look-alikes such as
 * `evilgithub.com` for the domain `github.com`.
 */
function hostnameMatchesDomain(hostname: string, domain: string) {
  return hostname === domain || hostname.endsWith(`.${domain}`);
}

/**
 * Decides whether redirecting to `redirectUri` is permitted.
 *
 * @param redirectUri - Absolute URL of the redirect target.
 * @param options - `currentHostname` is the hostname of the current site;
 *   targets on that exact hostname are always allowed.
 * @returns `true` when the URL parses, uses an allowed protocol, and its
 *   hostname is either the current hostname or matches a trusted domain;
 *   `false` otherwise, including when `redirectUri` cannot be parsed.
 */
export function isAllowedRedirectTarget(
  redirectUri: string,
  options: {
    currentHostname: string;
  }
): boolean {
  const currentHostname = normalizeHostname(options.currentHostname);

  try {
    // `new URL` throws on relative or malformed input; treated as "not allowed".
    const target = new URL(redirectUri);

    if (!ALLOWED_REDIRECT_PROTOCOLS.has(target.protocol)) {
      return false;
    }

    const hostname = normalizeHostname(target.hostname);

    if (hostname === currentHostname) {
      return true;
    }

    return TRUSTED_REDIRECT_DOMAINS.some(domain =>
      hostnameMatchesDomain(hostname, domain)
    );
  } catch {
    return false;
  }
}
|
||||
@@ -22,6 +22,7 @@ doc-loader = [
|
||||
"url",
|
||||
]
|
||||
hashcash = ["chrono", "sha3", "rand"]
|
||||
napi = ["dep:napi"]
|
||||
tree-sitter = [
|
||||
"cc",
|
||||
"dep:tree-sitter",
|
||||
@@ -53,6 +54,7 @@ chrono = { workspace = true, optional = true }
|
||||
docx-parser = { workspace = true, optional = true }
|
||||
infer = { workspace = true, optional = true }
|
||||
nanoid = { workspace = true, optional = true }
|
||||
napi = { workspace = true, optional = true }
|
||||
path-ext = { workspace = true, optional = true }
|
||||
pdf-extract = { workspace = true, optional = true }
|
||||
pulldown-cmark = { workspace = true, optional = true }
|
||||
|
||||
@@ -4,3 +4,5 @@ pub mod doc_loader;
|
||||
pub mod doc_parser;
|
||||
#[cfg(feature = "hashcash")]
|
||||
pub mod hashcash;
|
||||
#[cfg(feature = "napi")]
|
||||
pub mod napi_utils;
|
||||
|
||||
22
packages/common/native/src/napi_utils.rs
Normal file
22
packages/common/native/src/napi_utils.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use std::fmt::{Debug, Display};
|
||||
|
||||
use napi::{Error, Result, Status};
|
||||
|
||||
pub fn to_napi_error<E: Display + Debug>(err: E, status: Status) -> Error {
|
||||
Error::new(status, err.to_string())
|
||||
}
|
||||
|
||||
pub fn map_napi_err<T, E: Display + Debug>(value: std::result::Result<T, E>, status: Status) -> Result<T> {
|
||||
value.map_err(|err| to_napi_error(err, status))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
  use super::*;

  // The source error's text must still be present in the converted error.
  #[test]
  fn map_napi_err_keeps_message() {
    let err = map_napi_err::<(), _>(Err("boom"), Status::GenericFailure).unwrap_err();
    assert!(err.to_string().contains("boom"));
  }
}
|
||||
@@ -26,7 +26,7 @@
|
||||
"lodash-es": "^4.17.21",
|
||||
"nanoid": "^5.1.6",
|
||||
"rxjs": "^7.8.2",
|
||||
"uuid": "^11.1.0",
|
||||
"uuid": "^13.0.0",
|
||||
"y-protocols": "^1.0.6",
|
||||
"yjs": "^13.6.27"
|
||||
},
|
||||
@@ -36,7 +36,7 @@
|
||||
"@blocksuite/affine": "workspace:*",
|
||||
"fake-indexeddb": "^6.0.0",
|
||||
"idb": "^8.0.0",
|
||||
"socket.io-client": "^4.8.1",
|
||||
"socket.io-client": "^4.8.3",
|
||||
"vitest": "^3.2.4"
|
||||
},
|
||||
"peerDependencies": {
|
||||
@@ -44,6 +44,6 @@
|
||||
"@affine/graphql": "workspace:*",
|
||||
"@blocksuite/affine": "workspace:*",
|
||||
"idb": "^8.0.0",
|
||||
"socket.io-client": "^4.7.5"
|
||||
"socket.io-client": "^4.8.3"
|
||||
}
|
||||
}
|
||||
|
||||
23
packages/common/nbstore/src/__tests__/base64.bench.ts
Normal file
23
packages/common/nbstore/src/__tests__/base64.bench.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { bench, describe } from 'vitest';
|
||||
|
||||
import { base64ToUint8Array, uint8ArrayToBase64 } from '../impls/cloud/socket';
|
||||
|
||||
const data = new Uint8Array(1024 * 256);
|
||||
for (let i = 0; i < data.length; i++) {
|
||||
data[i] = i % 251;
|
||||
}
|
||||
let encoded = '';
|
||||
|
||||
await uint8ArrayToBase64(data).then(result => {
|
||||
encoded = result;
|
||||
});
|
||||
|
||||
describe('base64 helpers', () => {
|
||||
bench('encode Uint8Array to base64', async () => {
|
||||
await uint8ArrayToBase64(data);
|
||||
});
|
||||
|
||||
bench('decode base64 to Uint8Array', () => {
|
||||
base64ToUint8Array(encoded);
|
||||
});
|
||||
});
|
||||
27
packages/common/nbstore/src/__tests__/base64.spec.ts
Normal file
27
packages/common/nbstore/src/__tests__/base64.spec.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { describe, expect, test } from 'vitest';
|
||||
|
||||
import { base64ToUint8Array, uint8ArrayToBase64 } from '../impls/cloud/socket';
|
||||
|
||||
function makeSample(size: number) {
|
||||
const data = new Uint8Array(size);
|
||||
for (let i = 0; i < size; i++) {
|
||||
data[i] = i % 251;
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
describe('base64 helpers', () => {
|
||||
test('roundtrip preserves data', async () => {
|
||||
const input = makeSample(1024);
|
||||
const encoded = await uint8ArrayToBase64(input);
|
||||
const decoded = base64ToUint8Array(encoded);
|
||||
expect(decoded).toEqual(input);
|
||||
});
|
||||
|
||||
test('handles large payloads', async () => {
|
||||
const input = makeSample(256 * 1024);
|
||||
const encoded = await uint8ArrayToBase64(input);
|
||||
const decoded = base64ToUint8Array(encoded);
|
||||
expect(decoded).toEqual(input);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,41 @@
|
||||
import { describe, expect, test } from 'vitest';
|
||||
|
||||
import { CloudDocStorage } from '../impls/cloud/doc';
|
||||
|
||||
const base64UpdateA = 'AQID';
|
||||
const base64UpdateB = 'BAUG';
|
||||
|
||||
describe('CloudDocStorage broadcast updates', () => {
|
||||
test('emits updates from batch payload', () => {
|
||||
const storage = new CloudDocStorage({
|
||||
id: 'space-1',
|
||||
serverBaseUrl: 'http://localhost',
|
||||
isSelfHosted: true,
|
||||
type: 'workspace',
|
||||
readonlyMode: true,
|
||||
});
|
||||
|
||||
(storage as any).connection.idConverter = {
|
||||
oldIdToNewId: (id: string) => id,
|
||||
newIdToOldId: (id: string) => id,
|
||||
};
|
||||
|
||||
const received: Uint8Array[] = [];
|
||||
storage.subscribeDocUpdate(update => {
|
||||
received.push(update.bin);
|
||||
});
|
||||
|
||||
storage.onServerUpdates({
|
||||
spaceType: 'workspace',
|
||||
spaceId: 'space-1',
|
||||
docId: 'doc-1',
|
||||
updates: [base64UpdateA, base64UpdateB],
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
expect(received).toEqual([
|
||||
new Uint8Array([1, 2, 3]),
|
||||
new Uint8Array([4, 5, 6]),
|
||||
]);
|
||||
});
|
||||
});
|
||||
@@ -38,12 +38,32 @@ export class CloudDocStorage extends DocStorageBase<CloudDocStorageOptions> {
|
||||
|
||||
onServerUpdate: ServerEventsMap['space:broadcast-doc-update'] = message => {
|
||||
if (
|
||||
this.spaceType === message.spaceType &&
|
||||
this.spaceId === message.spaceId
|
||||
this.spaceType !== message.spaceType ||
|
||||
this.spaceId !== message.spaceId
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.emit('update', {
|
||||
docId: this.idConverter.oldIdToNewId(message.docId),
|
||||
bin: base64ToUint8Array(message.update),
|
||||
timestamp: new Date(message.timestamp),
|
||||
editor: message.editor,
|
||||
});
|
||||
};
|
||||
|
||||
onServerUpdates: ServerEventsMap['space:broadcast-doc-updates'] = message => {
|
||||
if (
|
||||
this.spaceType !== message.spaceType ||
|
||||
this.spaceId !== message.spaceId
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const update of message.updates) {
|
||||
this.emit('update', {
|
||||
docId: this.idConverter.oldIdToNewId(message.docId),
|
||||
bin: base64ToUint8Array(message.update),
|
||||
bin: base64ToUint8Array(update),
|
||||
timestamp: new Date(message.timestamp),
|
||||
editor: message.editor,
|
||||
});
|
||||
@@ -52,7 +72,8 @@ export class CloudDocStorage extends DocStorageBase<CloudDocStorageOptions> {
|
||||
|
||||
readonly connection = new CloudDocStorageConnection(
|
||||
this.options,
|
||||
this.onServerUpdate
|
||||
this.onServerUpdate,
|
||||
this.onServerUpdates
|
||||
);
|
||||
|
||||
override async getDocSnapshot(docId: string) {
|
||||
@@ -184,7 +205,8 @@ export class CloudDocStorage extends DocStorageBase<CloudDocStorageOptions> {
|
||||
class CloudDocStorageConnection extends SocketConnection {
|
||||
constructor(
|
||||
private readonly options: CloudDocStorageOptions,
|
||||
private readonly onServerUpdate: ServerEventsMap['space:broadcast-doc-update']
|
||||
private readonly onServerUpdate: ServerEventsMap['space:broadcast-doc-update'],
|
||||
private readonly onServerUpdates: ServerEventsMap['space:broadcast-doc-updates']
|
||||
) {
|
||||
super(options.serverBaseUrl, options.isSelfHosted);
|
||||
}
|
||||
@@ -210,6 +232,7 @@ class CloudDocStorageConnection extends SocketConnection {
|
||||
}
|
||||
|
||||
socket.on('space:broadcast-doc-update', this.onServerUpdate);
|
||||
socket.on('space:broadcast-doc-updates', this.onServerUpdates);
|
||||
|
||||
return { socket, disconnect };
|
||||
} catch (e) {
|
||||
@@ -230,6 +253,7 @@ class CloudDocStorageConnection extends SocketConnection {
|
||||
spaceId: this.options.id,
|
||||
});
|
||||
socket.off('space:broadcast-doc-update', this.onServerUpdate);
|
||||
socket.off('space:broadcast-doc-updates', this.onServerUpdates);
|
||||
super.doDisconnect({ socket, disconnect });
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user