mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-03-23 07:40:46 +08:00
Compare commits
27 Commits
renovate/c
...
darksky/im
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad6470db82 | ||
|
|
adf8955e3f | ||
|
|
6a93566422 | ||
|
|
7ac8b14b65 | ||
|
|
16a8f17717 | ||
|
|
1ffb8c922c | ||
|
|
daf536f77a | ||
|
|
0d2d4bb6a1 | ||
|
|
cb9897d493 | ||
|
|
8ca8333cd6 | ||
|
|
3bf2503f55 | ||
|
|
59fd942f40 | ||
|
|
d6d5ae6182 | ||
|
|
c1a09b951f | ||
|
|
4ce68d74f1 | ||
|
|
fbfcc01d14 | ||
|
|
1112a06623 | ||
|
|
bbcb7e69fe | ||
|
|
cc2f23339e | ||
|
|
31101a69e7 | ||
|
|
0b1a44863f | ||
|
|
8406f9656e | ||
|
|
121c0d172d | ||
|
|
8f03090780 | ||
|
|
8125cc0e75 | ||
|
|
f537a75f01 | ||
|
|
9456a07889 |
4
.github/renovate.json
vendored
4
.github/renovate.json
vendored
@@ -63,7 +63,7 @@
|
||||
"groupName": "opentelemetry",
|
||||
"matchPackageNames": [
|
||||
"/^@opentelemetry/",
|
||||
"/^@google-cloud\/opentelemetry-/"
|
||||
"/^@google-cloud/opentelemetry-/"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -79,7 +79,7 @@
|
||||
"customManagers": [
|
||||
{
|
||||
"customType": "regex",
|
||||
"fileMatch": ["^rust-toolchain\\.toml?$"],
|
||||
"managerFilePatterns": ["/^rust-toolchain\\.toml?$/"],
|
||||
"matchStrings": [
|
||||
"channel\\s*=\\s*\"(?<currentValue>\\d+\\.\\d+(\\.\\d+)?)\""
|
||||
],
|
||||
|
||||
942
.yarn/releases/yarn-4.12.0.cjs
vendored
942
.yarn/releases/yarn-4.12.0.cjs
vendored
File diff suppressed because one or more lines are too long
940
.yarn/releases/yarn-4.13.0.cjs
vendored
Executable file
940
.yarn/releases/yarn-4.13.0.cjs
vendored
Executable file
File diff suppressed because one or more lines are too long
@@ -12,4 +12,4 @@ npmPublishAccess: public
|
||||
|
||||
npmRegistryServer: "https://registry.npmjs.org"
|
||||
|
||||
yarnPath: .yarn/releases/yarn-4.12.0.cjs
|
||||
yarnPath: .yarn/releases/yarn-4.13.0.cjs
|
||||
|
||||
2703
Cargo.lock
generated
2703
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
11
Cargo.toml
11
Cargo.toml
@@ -36,7 +36,7 @@ resolver = "3"
|
||||
criterion2 = { version = "3", default-features = false }
|
||||
crossbeam-channel = "0.5"
|
||||
dispatch2 = "0.3"
|
||||
docx-parser = { git = "https://github.com/toeverything/docx-parser" }
|
||||
docx-parser = { git = "https://github.com/toeverything/docx-parser", rev = "380beea" }
|
||||
dotenvy = "0.15"
|
||||
file-format = { version = "0.28", features = ["reader"] }
|
||||
homedir = "0.3"
|
||||
@@ -59,6 +59,7 @@ resolver = "3"
|
||||
lru = "0.16"
|
||||
matroska = "0.30"
|
||||
memory-indexer = "0.3.0"
|
||||
mermaid-rs-renderer = { git = "https://github.com/toeverything/mermaid-rs-renderer", rev = "fba9097", default-features = false }
|
||||
mimalloc = "0.1"
|
||||
mp4parse = "0.17"
|
||||
nanoid = "0.4"
|
||||
@@ -122,6 +123,14 @@ resolver = "3"
|
||||
tree-sitter-rust = { version = "0.24" }
|
||||
tree-sitter-scala = { version = "0.24" }
|
||||
tree-sitter-typescript = { version = "0.23" }
|
||||
typst = "0.14.2"
|
||||
typst-as-lib = { version = "0.15.4", default-features = false, features = [
|
||||
"packages",
|
||||
"typst-kit-embed-fonts",
|
||||
"typst-kit-fonts",
|
||||
"ureq",
|
||||
] }
|
||||
typst-svg = "0.14.2"
|
||||
uniffi = "0.29"
|
||||
url = { version = "2.5" }
|
||||
uuid = "1.8"
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
|
||||
|
||||
exports[`snapshot to markdown > imports obsidian vault fixtures 1`] = `
|
||||
{
|
||||
"entry": {
|
||||
"children": [
|
||||
{
|
||||
"children": [
|
||||
{
|
||||
"children": [
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"insert": "Panel
|
||||
Body line",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "text",
|
||||
},
|
||||
],
|
||||
"emoji": "💡",
|
||||
"flavour": "affine:callout",
|
||||
},
|
||||
{
|
||||
"flavour": "affine:attachment",
|
||||
"name": "archive.zip",
|
||||
"style": "horizontalThin",
|
||||
},
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"footnote": {
|
||||
"label": "1",
|
||||
"reference": {
|
||||
"title": "reference body",
|
||||
"type": "url",
|
||||
},
|
||||
},
|
||||
"insert": " ",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"flavour": "affine:divider",
|
||||
},
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"insert": "after note",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"insert": " ",
|
||||
"reference": {
|
||||
"page": "linked",
|
||||
"type": "LinkedPage",
|
||||
},
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"insert": "Sources",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "h6",
|
||||
},
|
||||
{
|
||||
"flavour": "affine:bookmark",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:note",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:page",
|
||||
},
|
||||
"titles": [
|
||||
"entry",
|
||||
"linked",
|
||||
],
|
||||
}
|
||||
`;
|
||||
@@ -0,0 +1,14 @@
|
||||
> [!custom] Panel
|
||||
> Body line
|
||||
|
||||
![[archive.zip]]
|
||||
|
||||
[^1]
|
||||
|
||||
---
|
||||
|
||||
after note
|
||||
|
||||
[[linked]]
|
||||
|
||||
[^1]: reference body
|
||||
@@ -0,0 +1 @@
|
||||
plain linked page
|
||||
@@ -1,4 +1,10 @@
|
||||
import { MarkdownTransformer } from '@blocksuite/affine/widgets/linked-doc';
|
||||
import { readFileSync } from 'node:fs';
|
||||
import { basename, resolve } from 'node:path';
|
||||
|
||||
import {
|
||||
MarkdownTransformer,
|
||||
ObsidianTransformer,
|
||||
} from '@blocksuite/affine/widgets/linked-doc';
|
||||
import {
|
||||
DefaultTheme,
|
||||
NoteDisplayMode,
|
||||
@@ -8,13 +14,18 @@ import {
|
||||
CalloutAdmonitionType,
|
||||
CalloutExportStyle,
|
||||
calloutMarkdownExportMiddleware,
|
||||
docLinkBaseURLMiddleware,
|
||||
embedSyncedDocMiddleware,
|
||||
MarkdownAdapter,
|
||||
titleMiddleware,
|
||||
} from '@blocksuite/affine-shared/adapters';
|
||||
import type { AffineTextAttributes } from '@blocksuite/affine-shared/types';
|
||||
import type {
|
||||
BlockSnapshot,
|
||||
DeltaInsert,
|
||||
DocSnapshot,
|
||||
SliceSnapshot,
|
||||
Store,
|
||||
TransformerMiddleware,
|
||||
} from '@blocksuite/store';
|
||||
import { AssetsManager, MemoryBlobCRUD, Schema } from '@blocksuite/store';
|
||||
@@ -29,6 +40,138 @@ import { testStoreExtensions } from '../utils/store.js';
|
||||
|
||||
const provider = getProvider();
|
||||
|
||||
function withRelativePath(file: File, relativePath: string): File {
|
||||
Object.defineProperty(file, 'webkitRelativePath', {
|
||||
value: relativePath,
|
||||
writable: false,
|
||||
});
|
||||
return file;
|
||||
}
|
||||
|
||||
function markdownFixture(relativePath: string): File {
|
||||
return withRelativePath(
|
||||
new File(
|
||||
[
|
||||
readFileSync(
|
||||
resolve(import.meta.dirname, 'fixtures/obsidian', relativePath),
|
||||
'utf8'
|
||||
),
|
||||
],
|
||||
basename(relativePath),
|
||||
{ type: 'text/markdown' }
|
||||
),
|
||||
`vault/${relativePath}`
|
||||
);
|
||||
}
|
||||
|
||||
function exportSnapshot(doc: Store): DocSnapshot {
|
||||
const job = doc.getTransformer([
|
||||
docLinkBaseURLMiddleware(doc.workspace.id),
|
||||
titleMiddleware(doc.workspace.meta.docMetas),
|
||||
]);
|
||||
const snapshot = job.docToSnapshot(doc);
|
||||
expect(snapshot).toBeTruthy();
|
||||
return snapshot!;
|
||||
}
|
||||
|
||||
function normalizeDeltaForSnapshot(
|
||||
delta: DeltaInsert<AffineTextAttributes>[],
|
||||
titleById: ReadonlyMap<string, string>
|
||||
) {
|
||||
return delta.map(item => {
|
||||
const normalized: Record<string, unknown> = {
|
||||
insert: item.insert,
|
||||
};
|
||||
|
||||
if (item.attributes?.link) {
|
||||
normalized.link = item.attributes.link;
|
||||
}
|
||||
|
||||
if (item.attributes?.reference?.type === 'LinkedPage') {
|
||||
normalized.reference = {
|
||||
type: 'LinkedPage',
|
||||
page: titleById.get(item.attributes.reference.pageId) ?? '<missing>',
|
||||
...(item.attributes.reference.title
|
||||
? { title: item.attributes.reference.title }
|
||||
: {}),
|
||||
};
|
||||
}
|
||||
|
||||
if (item.attributes?.footnote) {
|
||||
const reference = item.attributes.footnote.reference;
|
||||
normalized.footnote = {
|
||||
label: item.attributes.footnote.label,
|
||||
reference:
|
||||
reference.type === 'doc'
|
||||
? {
|
||||
type: 'doc',
|
||||
page: reference.docId
|
||||
? (titleById.get(reference.docId) ?? '<missing>')
|
||||
: '<missing>',
|
||||
}
|
||||
: {
|
||||
type: reference.type,
|
||||
...(reference.title ? { title: reference.title } : {}),
|
||||
...(reference.fileName ? { fileName: reference.fileName } : {}),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return normalized;
|
||||
});
|
||||
}
|
||||
|
||||
function simplifyBlockForSnapshot(
|
||||
block: BlockSnapshot,
|
||||
titleById: ReadonlyMap<string, string>
|
||||
): Record<string, unknown> {
|
||||
const simplified: Record<string, unknown> = {
|
||||
flavour: block.flavour,
|
||||
};
|
||||
|
||||
if (block.flavour === 'affine:paragraph' || block.flavour === 'affine:list') {
|
||||
simplified.type = block.props.type;
|
||||
const text = block.props.text as
|
||||
| { delta?: DeltaInsert<AffineTextAttributes>[] }
|
||||
| undefined;
|
||||
simplified.delta = normalizeDeltaForSnapshot(text?.delta ?? [], titleById);
|
||||
}
|
||||
|
||||
if (block.flavour === 'affine:callout') {
|
||||
simplified.emoji = block.props.emoji;
|
||||
}
|
||||
|
||||
if (block.flavour === 'affine:attachment') {
|
||||
simplified.name = block.props.name;
|
||||
simplified.style = block.props.style;
|
||||
}
|
||||
|
||||
if (block.flavour === 'affine:image') {
|
||||
simplified.sourceId = '<asset>';
|
||||
}
|
||||
|
||||
const children = (block.children ?? [])
|
||||
.filter(child => child.flavour !== 'affine:surface')
|
||||
.map(child => simplifyBlockForSnapshot(child, titleById));
|
||||
if (children.length) {
|
||||
simplified.children = children;
|
||||
}
|
||||
|
||||
return simplified;
|
||||
}
|
||||
|
||||
function snapshotDocByTitle(
|
||||
collection: TestWorkspace,
|
||||
title: string,
|
||||
titleById: ReadonlyMap<string, string>
|
||||
) {
|
||||
const meta = collection.meta.docMetas.find(meta => meta.title === title);
|
||||
expect(meta).toBeTruthy();
|
||||
const doc = collection.getDoc(meta!.id)?.getStore({ id: meta!.id });
|
||||
expect(doc).toBeTruthy();
|
||||
return simplifyBlockForSnapshot(exportSnapshot(doc!).blocks, titleById);
|
||||
}
|
||||
|
||||
describe('snapshot to markdown', () => {
|
||||
test('code', async () => {
|
||||
const blockSnapshot: BlockSnapshot = {
|
||||
@@ -127,6 +270,46 @@ Hello world
|
||||
expect(meta?.tags).toEqual(['a', 'b']);
|
||||
});
|
||||
|
||||
test('imports obsidian vault fixtures', async () => {
|
||||
const schema = new Schema().register(AffineSchemas);
|
||||
const collection = new TestWorkspace();
|
||||
collection.storeExtensions = testStoreExtensions;
|
||||
collection.meta.initialize();
|
||||
|
||||
const attachment = withRelativePath(
|
||||
new File([new Uint8Array([80, 75, 3, 4])], 'archive.zip', {
|
||||
type: 'application/zip',
|
||||
}),
|
||||
'vault/archive.zip'
|
||||
);
|
||||
|
||||
const { docIds } = await ObsidianTransformer.importObsidianVault({
|
||||
collection,
|
||||
schema,
|
||||
importedFiles: [
|
||||
markdownFixture('entry.md'),
|
||||
markdownFixture('linked.md'),
|
||||
attachment,
|
||||
],
|
||||
extensions: testStoreExtensions,
|
||||
});
|
||||
expect(docIds).toHaveLength(2);
|
||||
|
||||
const titleById = new Map(
|
||||
collection.meta.docMetas.map(meta => [
|
||||
meta.id,
|
||||
meta.title ?? '<untitled>',
|
||||
])
|
||||
);
|
||||
|
||||
expect({
|
||||
titles: collection.meta.docMetas
|
||||
.map(meta => meta.title)
|
||||
.sort((a, b) => (a ?? '').localeCompare(b ?? '')),
|
||||
entry: snapshotDocByTitle(collection, 'entry', titleById),
|
||||
}).toMatchSnapshot();
|
||||
});
|
||||
|
||||
test('paragraph', async () => {
|
||||
const blockSnapshot: BlockSnapshot = {
|
||||
type: 'block',
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
import {
|
||||
BlockMarkdownAdapterExtension,
|
||||
type BlockMarkdownAdapterMatcher,
|
||||
createAttachmentBlockSnapshot,
|
||||
FOOTNOTE_DEFINITION_PREFIX,
|
||||
getFootnoteDefinitionText,
|
||||
isFootnoteDefinitionNode,
|
||||
@@ -56,18 +57,15 @@ export const attachmentBlockMarkdownAdapterMatcher: BlockMarkdownAdapterMatcher
|
||||
}
|
||||
walkerContext
|
||||
.openNode(
|
||||
{
|
||||
type: 'block',
|
||||
createAttachmentBlockSnapshot({
|
||||
id: nanoid(),
|
||||
flavour: AttachmentBlockSchema.model.flavour,
|
||||
props: {
|
||||
name: fileName,
|
||||
sourceId: blobId,
|
||||
footnoteIdentifier,
|
||||
style: 'citation',
|
||||
},
|
||||
children: [],
|
||||
},
|
||||
}),
|
||||
'children'
|
||||
)
|
||||
.closeNode();
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@vitest/browser-playwright": "^4.0.18",
|
||||
"playwright": "=1.58.2",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"exports": {
|
||||
|
||||
@@ -516,6 +516,9 @@ export const EdgelessNoteInteraction =
|
||||
}
|
||||
})
|
||||
.catch(console.error);
|
||||
} else if (multiSelect && alreadySelected && editing) {
|
||||
// range selection using Shift-click when editing
|
||||
return;
|
||||
} else {
|
||||
context.default(context);
|
||||
}
|
||||
|
||||
@@ -83,9 +83,9 @@ export class RecordField extends SignalWatcher(
|
||||
border: 1px solid transparent;
|
||||
}
|
||||
|
||||
.field-content .affine-database-number {
|
||||
.field-content affine-database-number-cell .number {
|
||||
text-align: left;
|
||||
justify-content: start;
|
||||
justify-content: flex-start;
|
||||
}
|
||||
|
||||
.field-content:hover {
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@vitest/browser-playwright": "^4.0.18",
|
||||
"playwright": "=1.58.2",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"exports": {
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import { AttachmentBlockSchema } from '@blocksuite/affine-model';
|
||||
import {
|
||||
type AttachmentBlockProps,
|
||||
AttachmentBlockSchema,
|
||||
} from '@blocksuite/affine-model';
|
||||
import { BlockSuiteError, ErrorCode } from '@blocksuite/global/exceptions';
|
||||
import {
|
||||
type AssetsManager,
|
||||
@@ -23,6 +26,24 @@ import { AdapterFactoryIdentifier } from './types/adapter';
|
||||
|
||||
export type Attachment = File[];
|
||||
|
||||
type CreateAttachmentBlockSnapshotOptions = {
|
||||
id?: string;
|
||||
props: Partial<AttachmentBlockProps> & Pick<AttachmentBlockProps, 'name'>;
|
||||
};
|
||||
|
||||
export function createAttachmentBlockSnapshot({
|
||||
id = nanoid(),
|
||||
props,
|
||||
}: CreateAttachmentBlockSnapshotOptions): BlockSnapshot {
|
||||
return {
|
||||
type: 'block',
|
||||
id,
|
||||
flavour: AttachmentBlockSchema.model.flavour,
|
||||
props,
|
||||
children: [],
|
||||
};
|
||||
}
|
||||
|
||||
type AttachmentToSliceSnapshotPayload = {
|
||||
file: Attachment;
|
||||
assets?: AssetsManager;
|
||||
@@ -97,8 +118,6 @@ export class AttachmentAdapter extends BaseAdapter<Attachment> {
|
||||
if (files.length === 0) return null;
|
||||
|
||||
const content: SliceSnapshot['content'] = [];
|
||||
const flavour = AttachmentBlockSchema.model.flavour;
|
||||
|
||||
for (const blob of files) {
|
||||
const id = nanoid();
|
||||
const { name, size, type } = blob;
|
||||
@@ -108,22 +127,21 @@ export class AttachmentAdapter extends BaseAdapter<Attachment> {
|
||||
mapInto: sourceId => ({ sourceId }),
|
||||
});
|
||||
|
||||
content.push({
|
||||
type: 'block',
|
||||
flavour,
|
||||
id,
|
||||
props: {
|
||||
name,
|
||||
size,
|
||||
type,
|
||||
embed: false,
|
||||
style: 'horizontalThin',
|
||||
index: 'a0',
|
||||
xywh: '[0,0,0,0]',
|
||||
rotate: 0,
|
||||
},
|
||||
children: [],
|
||||
});
|
||||
content.push(
|
||||
createAttachmentBlockSnapshot({
|
||||
id,
|
||||
props: {
|
||||
name,
|
||||
size,
|
||||
type,
|
||||
embed: false,
|
||||
style: 'horizontalThin',
|
||||
index: 'a0',
|
||||
xywh: '[0,0,0,0]',
|
||||
rotate: 0,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
function safeDecodePathReference(path: string): string {
|
||||
try {
|
||||
return decodeURIComponent(path);
|
||||
} catch {
|
||||
return path;
|
||||
}
|
||||
}
|
||||
|
||||
export function normalizeFilePathReference(path: string): string {
|
||||
return safeDecodePathReference(path)
|
||||
.trim()
|
||||
.replace(/\\/g, '/')
|
||||
.replace(/^\.\/+/, '')
|
||||
.replace(/^\/+/, '')
|
||||
.replace(/\/+/g, '/');
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalizes a relative path by resolving all relative path segments
|
||||
* @param basePath The base path (markdown file's directory)
|
||||
@@ -40,7 +57,7 @@ export function getImageFullPath(
|
||||
imageReference: string
|
||||
): string {
|
||||
// Decode the image reference in case it contains URL-encoded characters
|
||||
const decodedReference = decodeURIComponent(imageReference);
|
||||
const decodedReference = safeDecodePathReference(imageReference);
|
||||
|
||||
// Get the directory of the file path
|
||||
const markdownDir = filePath.substring(0, filePath.lastIndexOf('/'));
|
||||
|
||||
@@ -20,9 +20,30 @@ declare global {
|
||||
showOpenFilePicker?: (
|
||||
options?: OpenFilePickerOptions
|
||||
) => Promise<FileSystemFileHandle[]>;
|
||||
// Window API: showDirectoryPicker
|
||||
showDirectoryPicker?: (options?: {
|
||||
id?: string;
|
||||
mode?: 'read' | 'readwrite';
|
||||
startIn?: FileSystemHandle | string;
|
||||
}) => Promise<FileSystemDirectoryHandle>;
|
||||
}
|
||||
}
|
||||
|
||||
// Minimal polyfill for FileSystemDirectoryHandle to iterate over files
|
||||
interface FileSystemDirectoryHandle {
|
||||
kind: 'directory';
|
||||
name: string;
|
||||
values(): AsyncIterableIterator<
|
||||
FileSystemFileHandle | FileSystemDirectoryHandle
|
||||
>;
|
||||
}
|
||||
|
||||
interface FileSystemFileHandle {
|
||||
kind: 'file';
|
||||
name: string;
|
||||
getFile(): Promise<File>;
|
||||
}
|
||||
|
||||
// See [Common MIME types](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types)
|
||||
const FileTypes: NonNullable<OpenFilePickerOptions['types']> = [
|
||||
{
|
||||
@@ -121,21 +142,27 @@ type AcceptTypes =
|
||||
| 'Docx'
|
||||
| 'MindMap';
|
||||
|
||||
export async function openFilesWith(
|
||||
acceptType: AcceptTypes = 'Any',
|
||||
multiple: boolean = true
|
||||
): Promise<File[] | null> {
|
||||
// Feature detection. The API needs to be supported
|
||||
// and the app not run in an iframe.
|
||||
const supportsFileSystemAccess =
|
||||
'showOpenFilePicker' in window &&
|
||||
function canUseFileSystemAccessAPI(
|
||||
api: 'showOpenFilePicker' | 'showDirectoryPicker'
|
||||
) {
|
||||
return (
|
||||
api in window &&
|
||||
(() => {
|
||||
try {
|
||||
return window.self === window.top;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
})();
|
||||
})()
|
||||
);
|
||||
}
|
||||
|
||||
export async function openFilesWith(
|
||||
acceptType: AcceptTypes = 'Any',
|
||||
multiple: boolean = true
|
||||
): Promise<File[] | null> {
|
||||
const supportsFileSystemAccess =
|
||||
canUseFileSystemAccessAPI('showOpenFilePicker');
|
||||
|
||||
// If the File System Access API is supported…
|
||||
if (supportsFileSystemAccess && window.showOpenFilePicker) {
|
||||
@@ -194,6 +221,75 @@ export async function openFilesWith(
|
||||
});
|
||||
}
|
||||
|
||||
export async function openDirectory(): Promise<File[] | null> {
|
||||
const supportsFileSystemAccess = canUseFileSystemAccessAPI(
|
||||
'showDirectoryPicker'
|
||||
);
|
||||
|
||||
if (supportsFileSystemAccess && window.showDirectoryPicker) {
|
||||
try {
|
||||
const dirHandle = await window.showDirectoryPicker();
|
||||
const files: File[] = [];
|
||||
|
||||
const readDirectory = async (
|
||||
directoryHandle: FileSystemDirectoryHandle,
|
||||
path: string
|
||||
) => {
|
||||
for await (const handle of directoryHandle.values()) {
|
||||
const relativePath = path ? `${path}/${handle.name}` : handle.name;
|
||||
if (handle.kind === 'file') {
|
||||
const fileHandle = handle as FileSystemFileHandle;
|
||||
if (fileHandle.getFile) {
|
||||
const file = await fileHandle.getFile();
|
||||
Object.defineProperty(file, 'webkitRelativePath', {
|
||||
value: relativePath,
|
||||
writable: false,
|
||||
});
|
||||
files.push(file);
|
||||
}
|
||||
} else if (handle.kind === 'directory') {
|
||||
await readDirectory(
|
||||
handle as FileSystemDirectoryHandle,
|
||||
relativePath
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
await readDirectory(dirHandle, '');
|
||||
return files;
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return new Promise(resolve => {
|
||||
const input = document.createElement('input');
|
||||
input.classList.add('affine-upload-input');
|
||||
input.style.display = 'none';
|
||||
input.type = 'file';
|
||||
|
||||
input.setAttribute('webkitdirectory', '');
|
||||
input.setAttribute('directory', '');
|
||||
|
||||
document.body.append(input);
|
||||
|
||||
input.addEventListener('change', () => {
|
||||
input.remove();
|
||||
resolve(input.files ? Array.from(input.files) : null);
|
||||
});
|
||||
|
||||
input.addEventListener('cancel', () => resolve(null));
|
||||
|
||||
if ('showPicker' in HTMLInputElement.prototype) {
|
||||
input.showPicker();
|
||||
} else {
|
||||
input.click();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
export async function openSingleFileWith(
|
||||
acceptType?: AcceptTypes
|
||||
): Promise<File | null> {
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
NotionIcon,
|
||||
} from '@blocksuite/affine-components/icons';
|
||||
import {
|
||||
openDirectory,
|
||||
openFilesWith,
|
||||
openSingleFileWith,
|
||||
} from '@blocksuite/affine-shared/utils';
|
||||
@@ -18,11 +19,16 @@ import { query, state } from 'lit/decorators.js';
|
||||
import { HtmlTransformer } from '../transformers/html.js';
|
||||
import { MarkdownTransformer } from '../transformers/markdown.js';
|
||||
import { NotionHtmlTransformer } from '../transformers/notion-html.js';
|
||||
import { ObsidianTransformer } from '../transformers/obsidian.js';
|
||||
import { styles } from './styles.js';
|
||||
|
||||
export type OnSuccessHandler = (
|
||||
pageIds: string[],
|
||||
options: { isWorkspaceFile: boolean; importedCount: number }
|
||||
options: {
|
||||
isWorkspaceFile: boolean;
|
||||
importedCount: number;
|
||||
docEmojis?: Map<string, string>;
|
||||
}
|
||||
) => void;
|
||||
|
||||
export type OnFailHandler = (message: string) => void;
|
||||
@@ -140,6 +146,29 @@ export class ImportDoc extends WithDisposable(LitElement) {
|
||||
});
|
||||
}
|
||||
|
||||
private async _importObsidian() {
|
||||
const files = await openDirectory();
|
||||
if (!files || files.length === 0) return;
|
||||
const needLoading =
|
||||
files.reduce((acc, f) => acc + f.size, 0) > SHOW_LOADING_SIZE;
|
||||
if (needLoading) {
|
||||
this.hidden = false;
|
||||
this._loading = true;
|
||||
} else {
|
||||
this.abortController.abort();
|
||||
}
|
||||
const { docIds, docEmojis } = await ObsidianTransformer.importObsidianVault(
|
||||
{
|
||||
collection: this.collection,
|
||||
schema: this.schema,
|
||||
importedFiles: files,
|
||||
extensions: this.extensions,
|
||||
}
|
||||
);
|
||||
needLoading && this.abortController.abort();
|
||||
this._onImportSuccess(docIds, { docEmojis });
|
||||
}
|
||||
|
||||
private _onCloseClick(event: MouseEvent) {
|
||||
event.stopPropagation();
|
||||
this.abortController.abort();
|
||||
@@ -151,15 +180,21 @@ export class ImportDoc extends WithDisposable(LitElement) {
|
||||
|
||||
private _onImportSuccess(
|
||||
pageIds: string[],
|
||||
options: { isWorkspaceFile?: boolean; importedCount?: number } = {}
|
||||
options: {
|
||||
isWorkspaceFile?: boolean;
|
||||
importedCount?: number;
|
||||
docEmojis?: Map<string, string>;
|
||||
} = {}
|
||||
) {
|
||||
const {
|
||||
isWorkspaceFile = false,
|
||||
importedCount: pagesImportedCount = pageIds.length,
|
||||
docEmojis,
|
||||
} = options;
|
||||
this.onSuccess?.(pageIds, {
|
||||
isWorkspaceFile,
|
||||
importedCount: pagesImportedCount,
|
||||
docEmojis,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -258,6 +293,13 @@ export class ImportDoc extends WithDisposable(LitElement) {
|
||||
</affine-tooltip>
|
||||
</div>
|
||||
</icon-button>
|
||||
<icon-button
|
||||
class="button-item"
|
||||
text="Obsidian"
|
||||
@click="${this._importObsidian}"
|
||||
>
|
||||
${ExportToMarkdownIcon}
|
||||
</icon-button>
|
||||
<icon-button class="button-item" text="Coming soon..." disabled>
|
||||
${NewIcon}
|
||||
</icon-button>
|
||||
|
||||
@@ -2,6 +2,7 @@ export { DocxTransformer } from './docx.js';
|
||||
export { HtmlTransformer } from './html.js';
|
||||
export { MarkdownTransformer } from './markdown.js';
|
||||
export { NotionHtmlTransformer } from './notion-html.js';
|
||||
export { ObsidianTransformer } from './obsidian.js';
|
||||
export { PdfTransformer } from './pdf.js';
|
||||
export { createAssetsArchive, download } from './utils.js';
|
||||
export { ZipTransformer } from './zip.js';
|
||||
|
||||
@@ -21,8 +21,11 @@ import { extMimeMap, Transformer } from '@blocksuite/store';
|
||||
import type { AssetMap, ImportedFileEntry, PathBlobIdMap } from './type.js';
|
||||
import { createAssetsArchive, download, parseMatter, Unzip } from './utils.js';
|
||||
|
||||
type ParsedFrontmatterMeta = Partial<
|
||||
Pick<DocMeta, 'title' | 'createDate' | 'updatedDate' | 'tags' | 'favorite'>
|
||||
export type ParsedFrontmatterMeta = Partial<
|
||||
Pick<
|
||||
DocMeta,
|
||||
'title' | 'createDate' | 'updatedDate' | 'tags' | 'favorite' | 'trash'
|
||||
>
|
||||
>;
|
||||
|
||||
const FRONTMATTER_KEYS = {
|
||||
@@ -150,11 +153,18 @@ function buildMetaFromFrontmatter(
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (FRONTMATTER_KEYS.trash.includes(key)) {
|
||||
const trash = parseBoolean(value);
|
||||
if (trash !== undefined) {
|
||||
meta.trash = trash;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return meta;
|
||||
}
|
||||
|
||||
function parseFrontmatter(markdown: string): {
|
||||
export function parseFrontmatter(markdown: string): {
|
||||
content: string;
|
||||
meta: ParsedFrontmatterMeta;
|
||||
} {
|
||||
@@ -176,7 +186,7 @@ function parseFrontmatter(markdown: string): {
|
||||
}
|
||||
}
|
||||
|
||||
function applyMetaPatch(
|
||||
export function applyMetaPatch(
|
||||
collection: Workspace,
|
||||
docId: string,
|
||||
meta: ParsedFrontmatterMeta
|
||||
@@ -187,13 +197,14 @@ function applyMetaPatch(
|
||||
if (meta.updatedDate !== undefined) metaPatch.updatedDate = meta.updatedDate;
|
||||
if (meta.tags) metaPatch.tags = meta.tags;
|
||||
if (meta.favorite !== undefined) metaPatch.favorite = meta.favorite;
|
||||
if (meta.trash !== undefined) metaPatch.trash = meta.trash;
|
||||
|
||||
if (Object.keys(metaPatch).length) {
|
||||
collection.meta.setDocMeta(docId, metaPatch);
|
||||
}
|
||||
}
|
||||
|
||||
function getProvider(extensions: ExtensionType[]) {
|
||||
export function getProvider(extensions: ExtensionType[]) {
|
||||
const container = new Container();
|
||||
extensions.forEach(ext => {
|
||||
ext.setup(container);
|
||||
@@ -223,6 +234,103 @@ type ImportMarkdownZipOptions = {
|
||||
extensions: ExtensionType[];
|
||||
};
|
||||
|
||||
/**
|
||||
* Filters hidden/system entries that should never participate in imports.
|
||||
*/
|
||||
export function isSystemImportPath(path: string) {
|
||||
return path.includes('__MACOSX') || path.includes('.DS_Store');
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the doc CRUD bridge used by importer transformers.
|
||||
*/
|
||||
export function createCollectionDocCRUD(collection: Workspace) {
|
||||
return {
|
||||
create: (id: string) => collection.createDoc(id).getStore({ id }),
|
||||
get: (id: string) => collection.getDoc(id)?.getStore({ id }) ?? null,
|
||||
delete: (id: string) => collection.removeDoc(id),
|
||||
};
|
||||
}
|
||||
|
||||
type CreateMarkdownImportJobOptions = {
|
||||
collection: Workspace;
|
||||
schema: Schema;
|
||||
preferredTitle?: string;
|
||||
fullPath?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a markdown import job with the standard collection middlewares.
|
||||
*/
|
||||
export function createMarkdownImportJob({
|
||||
collection,
|
||||
schema,
|
||||
preferredTitle,
|
||||
fullPath,
|
||||
}: CreateMarkdownImportJobOptions) {
|
||||
return new Transformer({
|
||||
schema,
|
||||
blobCRUD: collection.blobSync,
|
||||
docCRUD: createCollectionDocCRUD(collection),
|
||||
middlewares: [
|
||||
defaultImageProxyMiddleware,
|
||||
fileNameMiddleware(preferredTitle),
|
||||
docLinkBaseURLMiddleware(collection.id),
|
||||
...(fullPath ? [filePathMiddleware(fullPath)] : []),
|
||||
],
|
||||
});
|
||||
}
|
||||
|
||||
type StageImportedAssetOptions = {
|
||||
pendingAssets: AssetMap;
|
||||
pendingPathBlobIdMap: PathBlobIdMap;
|
||||
path: string;
|
||||
content: Blob;
|
||||
fileName: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Hashes a non-markdown import file and stages it into the shared asset maps.
|
||||
*/
|
||||
export async function stageImportedAsset({
|
||||
pendingAssets,
|
||||
pendingPathBlobIdMap,
|
||||
path,
|
||||
content,
|
||||
fileName,
|
||||
}: StageImportedAssetOptions) {
|
||||
const ext = path.split('.').at(-1) ?? '';
|
||||
const mime = extMimeMap.get(ext.toLowerCase()) ?? '';
|
||||
const key = await sha(await content.arrayBuffer());
|
||||
pendingPathBlobIdMap.set(path, key);
|
||||
pendingAssets.set(key, new File([content], fileName, { type: mime }));
|
||||
}
|
||||
|
||||
/**
|
||||
* Binds previously staged asset files into a transformer job before import.
|
||||
*/
|
||||
export function bindImportedAssetsToJob(
|
||||
job: Transformer,
|
||||
pendingAssets: AssetMap,
|
||||
pendingPathBlobIdMap: PathBlobIdMap
|
||||
) {
|
||||
const pathBlobIdMap = job.assetsManager.getPathBlobIdMap();
|
||||
// Iterate over all assets to be imported
|
||||
for (const [assetPath, key] of pendingPathBlobIdMap.entries()) {
|
||||
// Get the relative path of the asset to the markdown file
|
||||
// Store the path to blobId map
|
||||
pathBlobIdMap.set(assetPath, key);
|
||||
// Store the asset to assets, the key is the blobId, the value is the file object
|
||||
// In block adapter, it will use the blobId to get the file object
|
||||
const assetFile = pendingAssets.get(key);
|
||||
if (assetFile) {
|
||||
job.assets.set(key, assetFile);
|
||||
}
|
||||
}
|
||||
|
||||
return pathBlobIdMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Exports a doc to a Markdown file or a zip archive containing Markdown and assets.
|
||||
* @param doc The doc to export
|
||||
@@ -329,19 +437,10 @@ async function importMarkdownToDoc({
|
||||
const { content, meta } = parseFrontmatter(markdown);
|
||||
const preferredTitle = meta.title ?? fileName;
|
||||
const provider = getProvider(extensions);
|
||||
const job = new Transformer({
|
||||
const job = createMarkdownImportJob({
|
||||
collection,
|
||||
schema,
|
||||
blobCRUD: collection.blobSync,
|
||||
docCRUD: {
|
||||
create: (id: string) => collection.createDoc(id).getStore({ id }),
|
||||
get: (id: string) => collection.getDoc(id)?.getStore({ id }) ?? null,
|
||||
delete: (id: string) => collection.removeDoc(id),
|
||||
},
|
||||
middlewares: [
|
||||
defaultImageProxyMiddleware,
|
||||
fileNameMiddleware(preferredTitle),
|
||||
docLinkBaseURLMiddleware(collection.id),
|
||||
],
|
||||
preferredTitle,
|
||||
});
|
||||
const mdAdapter = new MarkdownAdapter(job, provider);
|
||||
const page = await mdAdapter.toDoc({
|
||||
@@ -381,7 +480,7 @@ async function importMarkdownZip({
|
||||
// Iterate over all files in the zip
|
||||
for (const { path, content: blob } of unzip) {
|
||||
// Skip the files that are not markdown files
|
||||
if (path.includes('__MACOSX') || path.includes('.DS_Store')) {
|
||||
if (isSystemImportPath(path)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -395,12 +494,13 @@ async function importMarkdownZip({
|
||||
fullPath: path,
|
||||
});
|
||||
} else {
|
||||
// If the file is not a markdown file, store it to pendingAssets
|
||||
const ext = path.split('.').at(-1) ?? '';
|
||||
const mime = extMimeMap.get(ext) ?? '';
|
||||
const key = await sha(await blob.arrayBuffer());
|
||||
pendingPathBlobIdMap.set(path, key);
|
||||
pendingAssets.set(key, new File([blob], fileName, { type: mime }));
|
||||
await stageImportedAsset({
|
||||
pendingAssets,
|
||||
pendingPathBlobIdMap,
|
||||
path,
|
||||
content: blob,
|
||||
fileName,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,34 +511,13 @@ async function importMarkdownZip({
|
||||
const markdown = await contentBlob.text();
|
||||
const { content, meta } = parseFrontmatter(markdown);
|
||||
const preferredTitle = meta.title ?? fileNameWithoutExt;
|
||||
const job = new Transformer({
|
||||
const job = createMarkdownImportJob({
|
||||
collection,
|
||||
schema,
|
||||
blobCRUD: collection.blobSync,
|
||||
docCRUD: {
|
||||
create: (id: string) => collection.createDoc(id).getStore({ id }),
|
||||
get: (id: string) => collection.getDoc(id)?.getStore({ id }) ?? null,
|
||||
delete: (id: string) => collection.removeDoc(id),
|
||||
},
|
||||
middlewares: [
|
||||
defaultImageProxyMiddleware,
|
||||
fileNameMiddleware(preferredTitle),
|
||||
docLinkBaseURLMiddleware(collection.id),
|
||||
filePathMiddleware(fullPath),
|
||||
],
|
||||
preferredTitle,
|
||||
fullPath,
|
||||
});
|
||||
const assets = job.assets;
|
||||
const pathBlobIdMap = job.assetsManager.getPathBlobIdMap();
|
||||
// Iterate over all assets to be imported
|
||||
for (const [assetPath, key] of pendingPathBlobIdMap.entries()) {
|
||||
// Get the relative path of the asset to the markdown file
|
||||
// Store the path to blobId map
|
||||
pathBlobIdMap.set(assetPath, key);
|
||||
// Store the asset to assets, the key is the blobId, the value is the file object
|
||||
// In block adapter, it will use the blobId to get the file object
|
||||
if (pendingAssets.get(key)) {
|
||||
assets.set(key, pendingAssets.get(key)!);
|
||||
}
|
||||
}
|
||||
bindImportedAssetsToJob(job, pendingAssets, pendingPathBlobIdMap);
|
||||
|
||||
const mdAdapter = new MarkdownAdapter(job, provider);
|
||||
const doc = await mdAdapter.toDoc({
|
||||
|
||||
@@ -0,0 +1,732 @@
|
||||
import { FootNoteReferenceParamsSchema } from '@blocksuite/affine-model';
|
||||
import {
|
||||
BlockMarkdownAdapterExtension,
|
||||
createAttachmentBlockSnapshot,
|
||||
FULL_FILE_PATH_KEY,
|
||||
getImageFullPath,
|
||||
MarkdownAdapter,
|
||||
type MarkdownAST,
|
||||
MarkdownASTToDeltaExtension,
|
||||
normalizeFilePathReference,
|
||||
} from '@blocksuite/affine-shared/adapters';
|
||||
import type { AffineTextAttributes } from '@blocksuite/affine-shared/types';
|
||||
import type {
|
||||
DeltaInsert,
|
||||
ExtensionType,
|
||||
Schema,
|
||||
Workspace,
|
||||
} from '@blocksuite/store';
|
||||
import { extMimeMap, nanoid } from '@blocksuite/store';
|
||||
import type { Html, Text } from 'mdast';
|
||||
|
||||
import {
|
||||
applyMetaPatch,
|
||||
bindImportedAssetsToJob,
|
||||
createMarkdownImportJob,
|
||||
getProvider,
|
||||
isSystemImportPath,
|
||||
parseFrontmatter,
|
||||
stageImportedAsset,
|
||||
} from './markdown.js';
|
||||
import type {
|
||||
AssetMap,
|
||||
MarkdownFileImportEntry,
|
||||
PathBlobIdMap,
|
||||
} from './type.js';
|
||||
|
||||
const CALLOUT_TYPE_MAP: Record<string, string> = {
|
||||
note: '💡',
|
||||
info: 'ℹ️',
|
||||
tip: '🔥',
|
||||
hint: '✅',
|
||||
important: '‼️',
|
||||
warning: '⚠️',
|
||||
caution: '⚠️',
|
||||
attention: '⚠️',
|
||||
danger: '⚠️',
|
||||
error: '🚨',
|
||||
bug: '🐛',
|
||||
example: '📌',
|
||||
quote: '💬',
|
||||
cite: '💬',
|
||||
abstract: '📋',
|
||||
summary: '📋',
|
||||
todo: '☑️',
|
||||
success: '✅',
|
||||
check: '✅',
|
||||
done: '✅',
|
||||
failure: '❌',
|
||||
fail: '❌',
|
||||
missing: '❌',
|
||||
question: '❓',
|
||||
help: '❓',
|
||||
faq: '❓',
|
||||
};
|
||||
|
||||
const AMBIGUOUS_PAGE_LOOKUP = '__ambiguous__';
|
||||
const DEFAULT_CALLOUT_EMOJI = '💡';
|
||||
const OBSIDIAN_TEXT_FOOTNOTE_URL_PREFIX = 'data:text/plain;charset=utf-8,';
|
||||
const OBSIDIAN_ATTACHMENT_EMBED_TAG = 'obsidian-attachment';
|
||||
|
||||
function normalizeLookupKey(value: string): string {
|
||||
return normalizeFilePathReference(value).toLowerCase();
|
||||
}
|
||||
|
||||
function stripMarkdownExtension(value: string): string {
|
||||
return value.replace(/\.md$/i, '');
|
||||
}
|
||||
|
||||
function basename(value: string): string {
|
||||
return normalizeFilePathReference(value).split('/').pop() ?? value;
|
||||
}
|
||||
|
||||
function parseObsidianTarget(rawTarget: string): {
|
||||
path: string;
|
||||
fragment: string | null;
|
||||
} {
|
||||
const normalizedTarget = normalizeFilePathReference(rawTarget);
|
||||
const match = normalizedTarget.match(/^([^#^]+)([#^].*)?$/);
|
||||
|
||||
return {
|
||||
path: match?.[1]?.trim() ?? normalizedTarget,
|
||||
fragment: match?.[2] ?? null,
|
||||
};
|
||||
}
|
||||
|
||||
function extractTitleAndEmoji(rawTitle: string): {
|
||||
title: string;
|
||||
emoji: string | null;
|
||||
} {
|
||||
const SINGLE_LEADING_EMOJI_RE =
|
||||
/^[\s\u200b]*((?:[\p{Emoji_Presentation}\p{Extended_Pictographic}\u200b]|\u200d|\ufe0f)+)/u;
|
||||
|
||||
let currentTitle = rawTitle;
|
||||
let extractedEmojiClusters = '';
|
||||
let emojiMatch;
|
||||
|
||||
while ((emojiMatch = currentTitle.match(SINGLE_LEADING_EMOJI_RE))) {
|
||||
const matchedCluster = emojiMatch[1].trim();
|
||||
extractedEmojiClusters +=
|
||||
(extractedEmojiClusters ? ' ' : '') + matchedCluster;
|
||||
currentTitle = currentTitle.slice(emojiMatch[0].length);
|
||||
}
|
||||
|
||||
return {
|
||||
title: currentTitle.trim(),
|
||||
emoji: extractedEmojiClusters || null,
|
||||
};
|
||||
}
|
||||
|
||||
function preprocessTitleHeader(markdown: string): string {
|
||||
return markdown.replace(
|
||||
/^(\s*#\s+)(.*)$/m,
|
||||
(_, headerPrefix, titleContent) => {
|
||||
const { title: cleanTitle } = extractTitleAndEmoji(titleContent);
|
||||
return `${headerPrefix}${cleanTitle}`;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function preprocessObsidianCallouts(markdown: string): string {
|
||||
return markdown.replace(
|
||||
/^(> *)\[!([^\]\n]+)\]([+-]?)([^\n]*)/gm,
|
||||
(_, prefix, type, _fold, rest) => {
|
||||
const calloutToken =
|
||||
CALLOUT_TYPE_MAP[type.trim().toLowerCase()] ?? DEFAULT_CALLOUT_EMOJI;
|
||||
const title = rest.trim();
|
||||
return title
|
||||
? `${prefix}[!${calloutToken}] ${title}`
|
||||
: `${prefix}[!${calloutToken}]`;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function isStructuredFootnoteDefinition(content: string): boolean {
|
||||
try {
|
||||
return FootNoteReferenceParamsSchema.safeParse(JSON.parse(content.trim()))
|
||||
.success;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function splitFootnoteTextContent(content: string): {
|
||||
title: string;
|
||||
description?: string;
|
||||
} {
|
||||
const lines = content
|
||||
.split('\n')
|
||||
.map(line => line.trim())
|
||||
.filter(Boolean);
|
||||
const title = lines[0] ?? content.trim();
|
||||
const description = lines.slice(1).join('\n').trim();
|
||||
|
||||
return {
|
||||
title,
|
||||
...(description ? { description } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
function createTextFootnoteDefinition(content: string): string {
|
||||
const normalizedContent = content.trim();
|
||||
const { title, description } = splitFootnoteTextContent(normalizedContent);
|
||||
|
||||
return JSON.stringify({
|
||||
type: 'url',
|
||||
url: encodeURIComponent(
|
||||
`${OBSIDIAN_TEXT_FOOTNOTE_URL_PREFIX}${encodeURIComponent(
|
||||
normalizedContent
|
||||
)}`
|
||||
),
|
||||
title,
|
||||
...(description ? { description } : {}),
|
||||
});
|
||||
}
|
||||
|
||||
function extractObsidianFootnotes(markdown: string): {
|
||||
content: string;
|
||||
footnotes: string[];
|
||||
} {
|
||||
const lines = markdown.split('\n');
|
||||
const output: string[] = [];
|
||||
const footnotes: string[] = [];
|
||||
|
||||
for (let index = 0; index < lines.length; index += 1) {
|
||||
const line = lines[index];
|
||||
const match = line.match(/^\[\^([^\]]+)\]:\s*(.*)$/);
|
||||
if (!match) {
|
||||
output.push(line);
|
||||
continue;
|
||||
}
|
||||
|
||||
const identifier = match[1];
|
||||
const contentLines = [match[2]];
|
||||
|
||||
while (index + 1 < lines.length) {
|
||||
const nextLine = lines[index + 1];
|
||||
if (/^(?: {1,4}|\t)/.test(nextLine)) {
|
||||
contentLines.push(nextLine.replace(/^(?: {1,4}|\t)/, ''));
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (
|
||||
nextLine.trim() === '' &&
|
||||
index + 2 < lines.length &&
|
||||
/^(?: {1,4}|\t)/.test(lines[index + 2])
|
||||
) {
|
||||
contentLines.push('');
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
const content = contentLines.join('\n').trim();
|
||||
footnotes.push(
|
||||
`[^${identifier}]: ${
|
||||
!content || isStructuredFootnoteDefinition(content)
|
||||
? content
|
||||
: createTextFootnoteDefinition(content)
|
||||
}`
|
||||
);
|
||||
}
|
||||
|
||||
return { content: output.join('\n'), footnotes };
|
||||
}
|
||||
|
||||
function buildLookupKeys(
|
||||
targetPath: string,
|
||||
currentFilePath?: string
|
||||
): string[] {
|
||||
const parsedTargetPath = normalizeFilePathReference(targetPath);
|
||||
if (!parsedTargetPath) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const keys = new Set<string>();
|
||||
const addPathVariants = (value: string) => {
|
||||
const normalizedValue = normalizeFilePathReference(value);
|
||||
if (!normalizedValue) {
|
||||
return;
|
||||
}
|
||||
|
||||
keys.add(normalizedValue);
|
||||
keys.add(stripMarkdownExtension(normalizedValue));
|
||||
|
||||
const fileName = basename(normalizedValue);
|
||||
keys.add(fileName);
|
||||
keys.add(stripMarkdownExtension(fileName));
|
||||
|
||||
const cleanTitle = extractTitleAndEmoji(
|
||||
stripMarkdownExtension(fileName)
|
||||
).title;
|
||||
if (cleanTitle) {
|
||||
keys.add(cleanTitle);
|
||||
}
|
||||
};
|
||||
|
||||
addPathVariants(parsedTargetPath);
|
||||
|
||||
if (currentFilePath) {
|
||||
addPathVariants(getImageFullPath(currentFilePath, parsedTargetPath));
|
||||
}
|
||||
|
||||
return Array.from(keys).map(normalizeLookupKey);
|
||||
}
|
||||
|
||||
function registerPageLookup(
|
||||
pageLookupMap: Map<string, string>,
|
||||
key: string,
|
||||
pageId: string
|
||||
) {
|
||||
const normalizedKey = normalizeLookupKey(key);
|
||||
if (!normalizedKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const existing = pageLookupMap.get(normalizedKey);
|
||||
if (existing && existing !== pageId) {
|
||||
pageLookupMap.set(normalizedKey, AMBIGUOUS_PAGE_LOOKUP);
|
||||
return;
|
||||
}
|
||||
|
||||
pageLookupMap.set(normalizedKey, pageId);
|
||||
}
|
||||
|
||||
function resolvePageIdFromLookup(
|
||||
pageLookupMap: Pick<ReadonlyMap<string, string>, 'get'>,
|
||||
rawTarget: string,
|
||||
currentFilePath?: string
|
||||
): string | null {
|
||||
const { path } = parseObsidianTarget(rawTarget);
|
||||
for (const key of buildLookupKeys(path, currentFilePath)) {
|
||||
const targetPageId = pageLookupMap.get(key);
|
||||
if (!targetPageId || targetPageId === AMBIGUOUS_PAGE_LOOKUP) {
|
||||
continue;
|
||||
}
|
||||
return targetPageId;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function resolveWikilinkDisplayTitle(
|
||||
rawAlias: string | undefined,
|
||||
pageEmoji: string | undefined
|
||||
): string | undefined {
|
||||
if (!rawAlias) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const { title: aliasTitle, emoji: aliasEmoji } =
|
||||
extractTitleAndEmoji(rawAlias);
|
||||
|
||||
if (aliasEmoji && aliasEmoji === pageEmoji) {
|
||||
return aliasTitle;
|
||||
}
|
||||
|
||||
return rawAlias;
|
||||
}
|
||||
|
||||
function isImageAssetPath(path: string): boolean {
|
||||
const extension = path.split('.').at(-1)?.toLowerCase() ?? '';
|
||||
return extMimeMap.get(extension)?.startsWith('image/') ?? false;
|
||||
}
|
||||
|
||||
function encodeMarkdownPath(path: string): string {
|
||||
return encodeURI(path).replaceAll('(', '%28').replaceAll(')', '%29');
|
||||
}
|
||||
|
||||
function escapeMarkdownLabel(label: string): string {
|
||||
return label.replace(/[[\]\\]/g, '\\$&');
|
||||
}
|
||||
|
||||
function isObsidianSizeAlias(alias: string | undefined): boolean {
|
||||
return !!alias && /^\d+(?:x\d+)?$/i.test(alias.trim());
|
||||
}
|
||||
|
||||
function getEmbedLabel(
|
||||
rawAlias: string | undefined,
|
||||
targetPath: string,
|
||||
fallbackToFileName: boolean
|
||||
): string {
|
||||
if (!rawAlias || isObsidianSizeAlias(rawAlias)) {
|
||||
return fallbackToFileName
|
||||
? stripMarkdownExtension(basename(targetPath))
|
||||
: '';
|
||||
}
|
||||
|
||||
return rawAlias.trim();
|
||||
}
|
||||
|
||||
type ObsidianAttachmentEmbed = {
|
||||
blobId: string;
|
||||
fileName: string;
|
||||
fileType: string;
|
||||
};
|
||||
|
||||
function createObsidianAttach(embed: ObsidianAttachmentEmbed): string {
|
||||
return `<!-- ${OBSIDIAN_ATTACHMENT_EMBED_TAG} ${encodeURIComponent(
|
||||
JSON.stringify(embed)
|
||||
)} -->`;
|
||||
}
|
||||
|
||||
function parseObsidianAttach(value: string): ObsidianAttachmentEmbed | null {
|
||||
const match = value.match(
|
||||
new RegExp(`^<!-- ${OBSIDIAN_ATTACHMENT_EMBED_TAG} ([^ ]+) -->$`)
|
||||
);
|
||||
if (!match?.[1]) return null;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(
|
||||
decodeURIComponent(match[1])
|
||||
) as ObsidianAttachmentEmbed;
|
||||
if (!parsed.blobId || !parsed.fileName) {
|
||||
return null;
|
||||
}
|
||||
return parsed;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function preprocessObsidianEmbeds(
|
||||
markdown: string,
|
||||
filePath: string,
|
||||
pageLookupMap: ReadonlyMap<string, string>,
|
||||
pathBlobIdMap: ReadonlyMap<string, string>
|
||||
): string {
|
||||
return markdown.replace(
|
||||
/!\[\[([^\]|]+)(?:\|([^\]]+))?\]\]/g,
|
||||
(match, rawTarget: string, rawAlias?: string) => {
|
||||
const targetPageId = resolvePageIdFromLookup(
|
||||
pageLookupMap,
|
||||
rawTarget,
|
||||
filePath
|
||||
);
|
||||
if (targetPageId) {
|
||||
return `[[${rawTarget}${rawAlias ? `|${rawAlias}` : ''}]]`;
|
||||
}
|
||||
|
||||
const { path } = parseObsidianTarget(rawTarget);
|
||||
if (!path) {
|
||||
return match;
|
||||
}
|
||||
|
||||
const assetPath = getImageFullPath(filePath, path);
|
||||
const encodedPath = encodeMarkdownPath(assetPath);
|
||||
|
||||
if (isImageAssetPath(path)) {
|
||||
const alt = getEmbedLabel(rawAlias, path, false);
|
||||
return ``;
|
||||
}
|
||||
|
||||
const label = getEmbedLabel(rawAlias, path, true);
|
||||
const blobId = pathBlobIdMap.get(assetPath);
|
||||
if (!blobId) return `[${escapeMarkdownLabel(label)}](${encodedPath})`;
|
||||
|
||||
const extension = path.split('.').at(-1)?.toLowerCase() ?? '';
|
||||
return createObsidianAttach({
|
||||
blobId,
|
||||
fileName: basename(path),
|
||||
fileType: extMimeMap.get(extension) ?? '',
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function preprocessObsidianMarkdown(
|
||||
markdown: string,
|
||||
filePath: string,
|
||||
pageLookupMap: ReadonlyMap<string, string>,
|
||||
pathBlobIdMap: ReadonlyMap<string, string>
|
||||
): string {
|
||||
const { content: contentWithoutFootnotes, footnotes: extractedFootnotes } =
|
||||
extractObsidianFootnotes(markdown);
|
||||
const content = preprocessObsidianEmbeds(
|
||||
contentWithoutFootnotes,
|
||||
filePath,
|
||||
pageLookupMap,
|
||||
pathBlobIdMap
|
||||
);
|
||||
const normalizedMarkdown = preprocessTitleHeader(
|
||||
preprocessObsidianCallouts(content)
|
||||
);
|
||||
|
||||
if (extractedFootnotes.length === 0) {
|
||||
return normalizedMarkdown;
|
||||
}
|
||||
|
||||
const trimmedMarkdown = normalizedMarkdown.replace(/\s+$/, '');
|
||||
return `${trimmedMarkdown}\n\n${extractedFootnotes.join('\n\n')}\n`;
|
||||
}
|
||||
|
||||
function isObsidianAttachmentEmbedNode(node: MarkdownAST): node is Html {
|
||||
return node.type === 'html' && !!parseObsidianAttach(node.value);
|
||||
}
|
||||
|
||||
export const obsidianAttachmentEmbedMarkdownAdapterMatcher =
|
||||
BlockMarkdownAdapterExtension({
|
||||
flavour: 'obsidian:attachment-embed',
|
||||
toMatch: o => isObsidianAttachmentEmbedNode(o.node),
|
||||
fromMatch: () => false,
|
||||
toBlockSnapshot: {
|
||||
enter: (o, context) => {
|
||||
if (!isObsidianAttachmentEmbedNode(o.node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const attachment = parseObsidianAttach(o.node.value);
|
||||
if (!attachment) {
|
||||
return;
|
||||
}
|
||||
|
||||
const assetFile = context.assets?.getAssets().get(attachment.blobId);
|
||||
context.walkerContext
|
||||
.openNode(
|
||||
createAttachmentBlockSnapshot({
|
||||
id: nanoid(),
|
||||
props: {
|
||||
name: attachment.fileName,
|
||||
size: assetFile?.size ?? 0,
|
||||
type:
|
||||
attachment.fileType ||
|
||||
assetFile?.type ||
|
||||
'application/octet-stream',
|
||||
sourceId: attachment.blobId,
|
||||
embed: false,
|
||||
style: 'horizontalThin',
|
||||
footnoteIdentifier: null,
|
||||
},
|
||||
}),
|
||||
'children'
|
||||
)
|
||||
.closeNode();
|
||||
(o.node as unknown as { type: string }).type =
|
||||
'obsidianAttachmentEmbed';
|
||||
},
|
||||
},
|
||||
fromBlockSnapshot: {},
|
||||
});
|
||||
|
||||
export const obsidianWikilinkToDeltaMatcher = MarkdownASTToDeltaExtension({
|
||||
name: 'obsidian-wikilink',
|
||||
match: ast => ast.type === 'text',
|
||||
toDelta: (ast, context) => {
|
||||
const textNode = ast as Text;
|
||||
if (!textNode.value) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const nodeContent = textNode.value;
|
||||
const wikilinkRegex = /\[\[([^\]|]+)(?:\|([^\]]+))?\]\]/g;
|
||||
const deltas: DeltaInsert<AffineTextAttributes>[] = [];
|
||||
|
||||
let lastProcessedIndex = 0;
|
||||
let linkMatch;
|
||||
|
||||
while ((linkMatch = wikilinkRegex.exec(nodeContent)) !== null) {
|
||||
if (linkMatch.index > lastProcessedIndex) {
|
||||
deltas.push({
|
||||
insert: nodeContent.substring(lastProcessedIndex, linkMatch.index),
|
||||
});
|
||||
}
|
||||
|
||||
const targetPageName = linkMatch[1].trim();
|
||||
const alias = linkMatch[2]?.trim();
|
||||
const currentFilePath = context.configs.get(FULL_FILE_PATH_KEY);
|
||||
const targetPageId = resolvePageIdFromLookup(
|
||||
{ get: key => context.configs.get(`obsidian:pageId:${key}`) },
|
||||
targetPageName,
|
||||
typeof currentFilePath === 'string' ? currentFilePath : undefined
|
||||
);
|
||||
|
||||
if (targetPageId) {
|
||||
const pageEmoji = context.configs.get(
|
||||
'obsidian:pageEmoji:' + targetPageId
|
||||
);
|
||||
const displayTitle = resolveWikilinkDisplayTitle(alias, pageEmoji);
|
||||
|
||||
deltas.push({
|
||||
insert: ' ',
|
||||
attributes: {
|
||||
reference: {
|
||||
type: 'LinkedPage',
|
||||
pageId: targetPageId,
|
||||
...(displayTitle ? { title: displayTitle } : {}),
|
||||
},
|
||||
},
|
||||
});
|
||||
} else {
|
||||
deltas.push({ insert: linkMatch[0] });
|
||||
}
|
||||
|
||||
lastProcessedIndex = wikilinkRegex.lastIndex;
|
||||
}
|
||||
|
||||
if (lastProcessedIndex < nodeContent.length) {
|
||||
deltas.push({ insert: nodeContent.substring(lastProcessedIndex) });
|
||||
}
|
||||
|
||||
return deltas;
|
||||
},
|
||||
});
|
||||
|
||||
export type ImportObsidianVaultOptions = {
|
||||
collection: Workspace;
|
||||
schema: Schema;
|
||||
importedFiles: File[];
|
||||
extensions: ExtensionType[];
|
||||
};
|
||||
|
||||
export type ImportObsidianVaultResult = {
|
||||
docIds: string[];
|
||||
docEmojis: Map<string, string>;
|
||||
};
|
||||
|
||||
export async function importObsidianVault({
|
||||
collection,
|
||||
schema,
|
||||
importedFiles,
|
||||
extensions,
|
||||
}: ImportObsidianVaultOptions): Promise<ImportObsidianVaultResult> {
|
||||
const provider = getProvider([
|
||||
obsidianWikilinkToDeltaMatcher,
|
||||
obsidianAttachmentEmbedMarkdownAdapterMatcher,
|
||||
...extensions,
|
||||
]);
|
||||
|
||||
const docIds: string[] = [];
|
||||
const docEmojis = new Map<string, string>();
|
||||
const pendingAssets: AssetMap = new Map();
|
||||
const pendingPathBlobIdMap: PathBlobIdMap = new Map();
|
||||
const markdownBlobs: MarkdownFileImportEntry[] = [];
|
||||
const pageLookupMap = new Map<string, string>();
|
||||
|
||||
for (const file of importedFiles) {
|
||||
const filePath = file.webkitRelativePath || file.name;
|
||||
if (isSystemImportPath(filePath)) continue;
|
||||
|
||||
if (file.name.endsWith('.md')) {
|
||||
const fileNameWithoutExt = file.name.replace(/\.[^/.]+$/, '');
|
||||
const markdown = await file.text();
|
||||
const { content, meta } = parseFrontmatter(markdown);
|
||||
|
||||
const documentTitleCandidate = meta.title ?? fileNameWithoutExt;
|
||||
const { title: preferredTitle, emoji: leadingEmoji } =
|
||||
extractTitleAndEmoji(documentTitleCandidate);
|
||||
|
||||
const newPageId = collection.idGenerator();
|
||||
registerPageLookup(pageLookupMap, filePath, newPageId);
|
||||
registerPageLookup(
|
||||
pageLookupMap,
|
||||
stripMarkdownExtension(filePath),
|
||||
newPageId
|
||||
);
|
||||
registerPageLookup(pageLookupMap, file.name, newPageId);
|
||||
registerPageLookup(pageLookupMap, fileNameWithoutExt, newPageId);
|
||||
registerPageLookup(pageLookupMap, documentTitleCandidate, newPageId);
|
||||
registerPageLookup(pageLookupMap, preferredTitle, newPageId);
|
||||
|
||||
if (leadingEmoji) {
|
||||
docEmojis.set(newPageId, leadingEmoji);
|
||||
}
|
||||
|
||||
markdownBlobs.push({
|
||||
filename: file.name,
|
||||
contentBlob: file,
|
||||
fullPath: filePath,
|
||||
pageId: newPageId,
|
||||
preferredTitle,
|
||||
content,
|
||||
meta,
|
||||
});
|
||||
} else {
|
||||
await stageImportedAsset({
|
||||
pendingAssets,
|
||||
pendingPathBlobIdMap,
|
||||
path: filePath,
|
||||
content: file,
|
||||
fileName: file.name,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for (const existingDocMeta of collection.meta.docMetas) {
|
||||
if (existingDocMeta.title) {
|
||||
registerPageLookup(
|
||||
pageLookupMap,
|
||||
existingDocMeta.title,
|
||||
existingDocMeta.id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
await Promise.all(
|
||||
markdownBlobs.map(async markdownFile => {
|
||||
const {
|
||||
fullPath,
|
||||
pageId: predefinedId,
|
||||
preferredTitle,
|
||||
content,
|
||||
meta,
|
||||
} = markdownFile;
|
||||
|
||||
const job = createMarkdownImportJob({
|
||||
collection,
|
||||
schema,
|
||||
preferredTitle,
|
||||
fullPath,
|
||||
});
|
||||
|
||||
for (const [lookupKey, id] of pageLookupMap.entries()) {
|
||||
if (id === AMBIGUOUS_PAGE_LOOKUP) {
|
||||
continue;
|
||||
}
|
||||
job.adapterConfigs.set(`obsidian:pageId:${lookupKey}`, id);
|
||||
}
|
||||
for (const [id, emoji] of docEmojis.entries()) {
|
||||
job.adapterConfigs.set('obsidian:pageEmoji:' + id, emoji);
|
||||
}
|
||||
|
||||
const pathBlobIdMap = bindImportedAssetsToJob(
|
||||
job,
|
||||
pendingAssets,
|
||||
pendingPathBlobIdMap
|
||||
);
|
||||
|
||||
const preprocessedMarkdown = preprocessObsidianMarkdown(
|
||||
content,
|
||||
fullPath,
|
||||
pageLookupMap,
|
||||
pathBlobIdMap
|
||||
);
|
||||
const mdAdapter = new MarkdownAdapter(job, provider);
|
||||
const snapshot = await mdAdapter.toDocSnapshot({
|
||||
file: preprocessedMarkdown,
|
||||
assets: job.assetsManager,
|
||||
});
|
||||
|
||||
if (snapshot) {
|
||||
snapshot.meta.id = predefinedId;
|
||||
const doc = await job.snapshotToDoc(snapshot);
|
||||
if (doc) {
|
||||
applyMetaPatch(collection, doc.id, {
|
||||
...meta,
|
||||
title: preferredTitle,
|
||||
trash: false,
|
||||
});
|
||||
docIds.push(doc.id);
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
return { docIds, docEmojis };
|
||||
}
|
||||
|
||||
export const ObsidianTransformer = {
|
||||
importObsidianVault,
|
||||
};
|
||||
@@ -1,3 +1,5 @@
|
||||
import type { ParsedFrontmatterMeta } from './markdown.js';
|
||||
|
||||
/**
|
||||
* Represents an imported file entry in the zip archive
|
||||
*/
|
||||
@@ -10,6 +12,13 @@ export type ImportedFileEntry = {
|
||||
fullPath: string;
|
||||
};
|
||||
|
||||
export type MarkdownFileImportEntry = ImportedFileEntry & {
|
||||
pageId: string;
|
||||
preferredTitle: string;
|
||||
content: string;
|
||||
meta: ParsedFrontmatterMeta;
|
||||
};
|
||||
|
||||
/**
|
||||
* Map of asset hash to File object for all media files in the zip
|
||||
* Key: SHA hash of the file content (blobId)
|
||||
|
||||
@@ -162,10 +162,11 @@ export class AffineToolbarWidget extends WidgetComponent {
|
||||
}
|
||||
|
||||
setReferenceElementWithElements(gfx: GfxController, elements: GfxModel[]) {
|
||||
const surfaceBounds = getCommonBoundWithRotation(elements);
|
||||
|
||||
const getBoundingClientRect = () => {
|
||||
const bounds = getCommonBoundWithRotation(elements);
|
||||
const { x: offsetX, y: offsetY } = this.getBoundingClientRect();
|
||||
const [x, y, w, h] = gfx.viewport.toViewBound(bounds).toXYWH();
|
||||
const [x, y, w, h] = gfx.viewport.toViewBound(surfaceBounds).toXYWH();
|
||||
const rect = new DOMRect(x + offsetX, y + offsetY, w, h);
|
||||
return rect;
|
||||
};
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@vitest/browser-playwright": "^4.0.18",
|
||||
"playwright": "=1.58.2",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"exports": {
|
||||
|
||||
@@ -103,8 +103,9 @@ export abstract class GfxPrimitiveElementModel<
|
||||
}
|
||||
|
||||
get deserializedXYWH() {
|
||||
if (!this._lastXYWH || this.xywh !== this._lastXYWH) {
|
||||
const xywh = this.xywh;
|
||||
const xywh = this.xywh;
|
||||
|
||||
if (!this._lastXYWH || xywh !== this._lastXYWH) {
|
||||
this._local.set('deserializedXYWH', deserializeXYWH(xywh));
|
||||
this._lastXYWH = xywh;
|
||||
}
|
||||
@@ -386,6 +387,8 @@ export abstract class GfxGroupLikeElementModel<
|
||||
{
|
||||
private _childIds: string[] = [];
|
||||
|
||||
private _xywhDirty = true;
|
||||
|
||||
private readonly _mutex = createMutex();
|
||||
|
||||
abstract children: Y.Map<any>;
|
||||
@@ -420,24 +423,9 @@ export abstract class GfxGroupLikeElementModel<
|
||||
|
||||
get xywh() {
|
||||
this._mutex(() => {
|
||||
const curXYWH =
|
||||
(this._local.get('xywh') as SerializedXYWH) ?? '[0,0,0,0]';
|
||||
const newXYWH = this._getXYWH().serialize();
|
||||
|
||||
if (curXYWH !== newXYWH || !this._local.has('xywh')) {
|
||||
this._local.set('xywh', newXYWH);
|
||||
|
||||
if (curXYWH !== newXYWH) {
|
||||
this._onChange({
|
||||
props: {
|
||||
xywh: newXYWH,
|
||||
},
|
||||
oldValues: {
|
||||
xywh: curXYWH,
|
||||
},
|
||||
local: true,
|
||||
});
|
||||
}
|
||||
if (this._xywhDirty || !this._local.has('xywh')) {
|
||||
this._local.set('xywh', this._getXYWH().serialize());
|
||||
this._xywhDirty = false;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -457,15 +445,41 @@ export abstract class GfxGroupLikeElementModel<
|
||||
bound = bound ? bound.unite(child.elementBound) : child.elementBound;
|
||||
});
|
||||
|
||||
if (bound) {
|
||||
this._local.set('xywh', bound.serialize());
|
||||
} else {
|
||||
this._local.delete('xywh');
|
||||
}
|
||||
|
||||
return bound ?? new Bound(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
invalidateXYWH() {
|
||||
this._xywhDirty = true;
|
||||
this._local.delete('deserializedXYWH');
|
||||
}
|
||||
|
||||
refreshXYWH(local: boolean) {
|
||||
this._mutex(() => {
|
||||
const oldXYWH =
|
||||
(this._local.get('xywh') as SerializedXYWH) ?? '[0,0,0,0]';
|
||||
const nextXYWH = this._getXYWH().serialize();
|
||||
|
||||
this._xywhDirty = false;
|
||||
|
||||
if (oldXYWH === nextXYWH && this._local.has('xywh')) {
|
||||
return;
|
||||
}
|
||||
|
||||
this._local.set('xywh', nextXYWH);
|
||||
this._local.delete('deserializedXYWH');
|
||||
|
||||
this._onChange({
|
||||
props: {
|
||||
xywh: nextXYWH,
|
||||
},
|
||||
oldValues: {
|
||||
xywh: oldXYWH,
|
||||
},
|
||||
local,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
abstract addChild(element: GfxModel): void;
|
||||
|
||||
/**
|
||||
@@ -496,6 +510,7 @@ export abstract class GfxGroupLikeElementModel<
|
||||
setChildIds(value: string[], fromLocal: boolean) {
|
||||
const oldChildIds = this.childIds;
|
||||
this._childIds = value;
|
||||
this.invalidateXYWH();
|
||||
|
||||
this._onChange({
|
||||
props: {
|
||||
|
||||
@@ -52,6 +52,12 @@ export type MiddlewareCtx = {
|
||||
export type SurfaceMiddleware = (ctx: MiddlewareCtx) => void;
|
||||
|
||||
export class SurfaceBlockModel extends BlockModel<SurfaceBlockProps> {
|
||||
private static readonly _groupBoundImpactKeys = new Set([
|
||||
'xywh',
|
||||
'rotate',
|
||||
'hidden',
|
||||
]);
|
||||
|
||||
protected _decoratorState = createDecoratorState();
|
||||
|
||||
protected _elementCtorMap: Record<
|
||||
@@ -308,6 +314,42 @@ export class SurfaceBlockModel extends BlockModel<SurfaceBlockProps> {
|
||||
Object.keys(payload.props).forEach(key => {
|
||||
model.propsUpdated.next({ key });
|
||||
});
|
||||
|
||||
this._refreshParentGroupBoundsForElement(model, payload);
|
||||
}
|
||||
|
||||
private _refreshParentGroupBounds(id: string, local: boolean) {
|
||||
const group = this.getGroup(id);
|
||||
|
||||
if (group instanceof GfxGroupLikeElementModel) {
|
||||
group.refreshXYWH(local);
|
||||
}
|
||||
}
|
||||
|
||||
private _refreshParentGroupBoundsForElement(
|
||||
model: GfxPrimitiveElementModel,
|
||||
payload: ElementUpdatedData
|
||||
) {
|
||||
if (
|
||||
model instanceof GfxGroupLikeElementModel &&
|
||||
('childIds' in payload.props || 'childIds' in payload.oldValues)
|
||||
) {
|
||||
model.refreshXYWH(payload.local);
|
||||
return;
|
||||
}
|
||||
|
||||
const affectedKeys = new Set([
|
||||
...Object.keys(payload.props),
|
||||
...Object.keys(payload.oldValues),
|
||||
]);
|
||||
|
||||
if (
|
||||
Array.from(affectedKeys).some(key =>
|
||||
SurfaceBlockModel._groupBoundImpactKeys.has(key)
|
||||
)
|
||||
) {
|
||||
this._refreshParentGroupBounds(model.id, payload.local);
|
||||
}
|
||||
}
|
||||
|
||||
private _initElementModels() {
|
||||
@@ -458,6 +500,10 @@ export class SurfaceBlockModel extends BlockModel<SurfaceBlockProps> {
|
||||
);
|
||||
}
|
||||
|
||||
if (payload.model instanceof BlockModel) {
|
||||
this._refreshParentGroupBounds(payload.id, payload.isLocal);
|
||||
}
|
||||
|
||||
break;
|
||||
case 'delete':
|
||||
if (isGfxGroupCompatibleModel(payload.model)) {
|
||||
@@ -482,6 +528,13 @@ export class SurfaceBlockModel extends BlockModel<SurfaceBlockProps> {
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
payload.props.key &&
|
||||
SurfaceBlockModel._groupBoundImpactKeys.has(payload.props.key)
|
||||
) {
|
||||
this._refreshParentGroupBounds(payload.id, payload.isLocal);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
"devDependencies": {
|
||||
"@vanilla-extract/vite-plugin": "^5.0.0",
|
||||
"@vitest/browser-playwright": "^4.0.18",
|
||||
"playwright": "=1.58.2",
|
||||
"vite": "^7.2.7",
|
||||
"vite-plugin-istanbul": "^7.2.1",
|
||||
"vite-plugin-wasm": "^3.5.0",
|
||||
|
||||
@@ -4,6 +4,7 @@ import type {
|
||||
ConnectorElementModel,
|
||||
GroupElementModel,
|
||||
} from '@blocksuite/affine/model';
|
||||
import { serializeXYWH } from '@blocksuite/global/gfx';
|
||||
import { beforeEach, describe, expect, test } from 'vitest';
|
||||
|
||||
import { wait } from '../utils/common.js';
|
||||
@@ -138,6 +139,29 @@ describe('group', () => {
|
||||
|
||||
expect(group.childIds).toEqual([id]);
|
||||
});
|
||||
|
||||
test('group xywh should update when child xywh changes', () => {
|
||||
const shapeId = model.addElement({
|
||||
type: 'shape',
|
||||
xywh: serializeXYWH(0, 0, 100, 100),
|
||||
});
|
||||
const groupId = model.addElement({
|
||||
type: 'group',
|
||||
children: {
|
||||
[shapeId]: true,
|
||||
},
|
||||
});
|
||||
|
||||
const group = model.getElementById(groupId) as GroupElementModel;
|
||||
|
||||
expect(group.xywh).toBe(serializeXYWH(0, 0, 100, 100));
|
||||
|
||||
model.updateElement(shapeId, {
|
||||
xywh: serializeXYWH(50, 60, 100, 100),
|
||||
});
|
||||
|
||||
expect(group.xywh).toBe(serializeXYWH(50, 60, 100, 100));
|
||||
});
|
||||
});
|
||||
|
||||
describe('connector', () => {
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 25 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 25 KiB |
48
deny.toml
Normal file
48
deny.toml
Normal file
@@ -0,0 +1,48 @@
|
||||
[graph]
|
||||
all-features = true
|
||||
exclude-dev = true
|
||||
targets = [
|
||||
"x86_64-unknown-linux-gnu",
|
||||
"aarch64-apple-darwin",
|
||||
"x86_64-apple-darwin",
|
||||
"x86_64-pc-windows-msvc",
|
||||
"aarch64-linux-android",
|
||||
"aarch64-apple-ios",
|
||||
"aarch64-apple-ios-sim",
|
||||
]
|
||||
|
||||
[licenses]
|
||||
allow = [
|
||||
"0BSD",
|
||||
"Apache-2.0",
|
||||
"Apache-2.0 WITH LLVM-exception",
|
||||
"BSD-2-Clause",
|
||||
"BSD-3-Clause",
|
||||
"BSL-1.0",
|
||||
"CC0-1.0",
|
||||
"CDLA-Permissive-2.0",
|
||||
"ISC",
|
||||
"MIT",
|
||||
"MPL-2.0",
|
||||
"Unicode-3.0",
|
||||
"Unlicense",
|
||||
"Zlib",
|
||||
]
|
||||
confidence-threshold = 0.93
|
||||
unused-allowed-license = "allow"
|
||||
version = 2
|
||||
|
||||
[[licenses.exceptions]]
|
||||
allow = ["AGPL-3.0-only"]
|
||||
crate = "llm_adapter"
|
||||
|
||||
[[licenses.exceptions]]
|
||||
allow = ["AGPL-3.0-or-later"]
|
||||
crate = "memory-indexer"
|
||||
|
||||
[[licenses.exceptions]]
|
||||
allow = ["AGPL-3.0-or-later"]
|
||||
crate = "path-ext"
|
||||
|
||||
[licenses.private]
|
||||
ignore = true
|
||||
@@ -92,7 +92,7 @@
|
||||
"vite": "^7.2.7",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"packageManager": "yarn@4.12.0",
|
||||
"packageManager": "yarn@4.13.0",
|
||||
"resolutions": {
|
||||
"array-buffer-byte-length": "npm:@nolyfill/array-buffer-byte-length@^1",
|
||||
"array-includes": "npm:@nolyfill/array-includes@^1",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
edition = "2024"
|
||||
license-file = "LICENSE"
|
||||
name = "affine_server_native"
|
||||
publish = false
|
||||
version = "1.0.0"
|
||||
|
||||
[lib]
|
||||
|
||||
@@ -33,30 +33,30 @@
|
||||
"@nestjs-cls/transactional-adapter-prisma": "^1.3.4",
|
||||
"@nestjs/apollo": "^13.0.4",
|
||||
"@nestjs/bullmq": "^11.0.4",
|
||||
"@nestjs/common": "^11.0.21",
|
||||
"@nestjs/core": "^11.1.14",
|
||||
"@nestjs/common": "^11.1.17",
|
||||
"@nestjs/core": "^11.1.17",
|
||||
"@nestjs/graphql": "^13.0.4",
|
||||
"@nestjs/platform-express": "^11.1.14",
|
||||
"@nestjs/platform-socket.io": "^11.1.14",
|
||||
"@nestjs/platform-express": "^11.1.17",
|
||||
"@nestjs/platform-socket.io": "^11.1.17",
|
||||
"@nestjs/schedule": "^6.1.1",
|
||||
"@nestjs/throttler": "^6.5.0",
|
||||
"@nestjs/websockets": "^11.1.14",
|
||||
"@nestjs/websockets": "^11.1.17",
|
||||
"@node-rs/argon2": "^2.0.2",
|
||||
"@node-rs/crc32": "^1.10.6",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/core": "^2.2.0",
|
||||
"@opentelemetry/exporter-prometheus": "^0.212.0",
|
||||
"@opentelemetry/exporter-zipkin": "^2.2.0",
|
||||
"@opentelemetry/host-metrics": "^0.38.0",
|
||||
"@opentelemetry/instrumentation": "^0.212.0",
|
||||
"@opentelemetry/instrumentation-graphql": "^0.60.0",
|
||||
"@opentelemetry/instrumentation-http": "^0.212.0",
|
||||
"@opentelemetry/instrumentation-ioredis": "^0.60.0",
|
||||
"@opentelemetry/instrumentation-nestjs-core": "^0.58.0",
|
||||
"@opentelemetry/instrumentation-socket.io": "^0.59.0",
|
||||
"@opentelemetry/exporter-prometheus": "^0.213.0",
|
||||
"@opentelemetry/exporter-zipkin": "^2.6.0",
|
||||
"@opentelemetry/host-metrics": "^0.38.3",
|
||||
"@opentelemetry/instrumentation": "^0.213.0",
|
||||
"@opentelemetry/instrumentation-graphql": "^0.61.0",
|
||||
"@opentelemetry/instrumentation-http": "^0.213.0",
|
||||
"@opentelemetry/instrumentation-ioredis": "^0.61.0",
|
||||
"@opentelemetry/instrumentation-nestjs-core": "^0.59.0",
|
||||
"@opentelemetry/instrumentation-socket.io": "^0.60.0",
|
||||
"@opentelemetry/resources": "^2.2.0",
|
||||
"@opentelemetry/sdk-metrics": "^2.2.0",
|
||||
"@opentelemetry/sdk-node": "^0.212.0",
|
||||
"@opentelemetry/sdk-node": "^0.213.0",
|
||||
"@opentelemetry/sdk-trace-node": "^2.2.0",
|
||||
"@opentelemetry/semantic-conventions": "^1.38.0",
|
||||
"@prisma/client": "^6.6.0",
|
||||
@@ -72,7 +72,7 @@
|
||||
"eventemitter2": "^6.4.9",
|
||||
"exa-js": "^2.4.0",
|
||||
"express": "^5.0.1",
|
||||
"fast-xml-parser": "^5.3.4",
|
||||
"fast-xml-parser": "^5.5.7",
|
||||
"get-stream": "^9.0.1",
|
||||
"google-auth-library": "^10.2.0",
|
||||
"graphql": "^16.9.0",
|
||||
|
||||
@@ -1,12 +1,35 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { DocReader } from '../../core/doc';
|
||||
import type { AccessController } from '../../core/permission';
|
||||
import type { Models } from '../../models';
|
||||
import { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
|
||||
import {
|
||||
ToolCallAccumulator,
|
||||
ToolCallLoop,
|
||||
ToolSchemaExtractor,
|
||||
} from '../../plugins/copilot/providers/loop';
|
||||
import {
|
||||
buildBlobContentGetter,
|
||||
createBlobReadTool,
|
||||
} from '../../plugins/copilot/tools/blob-read';
|
||||
import {
|
||||
buildDocKeywordSearchGetter,
|
||||
createDocKeywordSearchTool,
|
||||
} from '../../plugins/copilot/tools/doc-keyword-search';
|
||||
import {
|
||||
buildDocContentGetter,
|
||||
createDocReadTool,
|
||||
} from '../../plugins/copilot/tools/doc-read';
|
||||
import {
|
||||
buildDocSearchGetter,
|
||||
createDocSemanticSearchTool,
|
||||
} from '../../plugins/copilot/tools/doc-semantic-search';
|
||||
import {
|
||||
DOCUMENT_SYNC_PENDING_MESSAGE,
|
||||
LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE,
|
||||
} from '../../plugins/copilot/tools/doc-sync';
|
||||
|
||||
test('ToolCallAccumulator should merge deltas and complete tool call', t => {
|
||||
const accumulator = new ToolCallAccumulator();
|
||||
@@ -286,3 +309,210 @@ test('ToolCallLoop should surface invalid JSON as tool error without executing',
|
||||
is_error: true,
|
||||
});
|
||||
});
|
||||
|
||||
test('doc_read should return specific sync errors for unavailable docs', async t => {
|
||||
const cases = [
|
||||
{
|
||||
name: 'local workspace without cloud sync',
|
||||
workspace: null,
|
||||
authors: null,
|
||||
markdown: null,
|
||||
expected: {
|
||||
type: 'error',
|
||||
name: 'Workspace Sync Required',
|
||||
message: LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE,
|
||||
},
|
||||
docReaderCalled: false,
|
||||
},
|
||||
{
|
||||
name: 'cloud workspace document not synced to server yet',
|
||||
workspace: { id: 'ws-1' },
|
||||
authors: null,
|
||||
markdown: null,
|
||||
expected: {
|
||||
type: 'error',
|
||||
name: 'Document Sync Pending',
|
||||
message: DOCUMENT_SYNC_PENDING_MESSAGE('doc-1'),
|
||||
},
|
||||
docReaderCalled: false,
|
||||
},
|
||||
{
|
||||
name: 'cloud workspace document markdown not ready yet',
|
||||
workspace: { id: 'ws-1' },
|
||||
authors: {
|
||||
createdAt: new Date('2026-01-01T00:00:00.000Z'),
|
||||
updatedAt: new Date('2026-01-01T00:00:00.000Z'),
|
||||
createdByUser: null,
|
||||
updatedByUser: null,
|
||||
},
|
||||
markdown: null,
|
||||
expected: {
|
||||
type: 'error',
|
||||
name: 'Document Sync Pending',
|
||||
message: DOCUMENT_SYNC_PENDING_MESSAGE('doc-1'),
|
||||
},
|
||||
docReaderCalled: true,
|
||||
},
|
||||
] as const;
|
||||
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({ doc: () => ({ can: async () => true }) }),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
for (const testCase of cases) {
|
||||
let docReaderCalled = false;
|
||||
const docReader = {
|
||||
getDocMarkdown: async () => {
|
||||
docReaderCalled = true;
|
||||
return testCase.markdown;
|
||||
},
|
||||
} as unknown as DocReader;
|
||||
|
||||
const models = {
|
||||
workspace: {
|
||||
get: async () => testCase.workspace,
|
||||
},
|
||||
doc: {
|
||||
getAuthors: async () => testCase.authors,
|
||||
},
|
||||
} as unknown as Models;
|
||||
|
||||
const getDoc = buildDocContentGetter(ac, docReader, models);
|
||||
const tool = createDocReadTool(
|
||||
getDoc.bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const result = await tool.execute?.({ doc_id: 'doc-1' }, {});
|
||||
|
||||
t.is(docReaderCalled, testCase.docReaderCalled, testCase.name);
|
||||
t.deepEqual(result, testCase.expected, testCase.name);
|
||||
}
|
||||
});
|
||||
|
||||
test('document search tools should return sync error for local workspace', async t => {
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({
|
||||
can: async () => true,
|
||||
docs: async () => [],
|
||||
}),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
const models = {
|
||||
workspace: {
|
||||
get: async () => null,
|
||||
},
|
||||
} as unknown as Models;
|
||||
|
||||
let keywordSearchCalled = false;
|
||||
const indexerService = {
|
||||
searchDocsByKeyword: async () => {
|
||||
keywordSearchCalled = true;
|
||||
return [];
|
||||
},
|
||||
} as unknown as Parameters<typeof buildDocKeywordSearchGetter>[1];
|
||||
|
||||
let semanticSearchCalled = false;
|
||||
const contextService = {
|
||||
matchWorkspaceAll: async () => {
|
||||
semanticSearchCalled = true;
|
||||
return [];
|
||||
},
|
||||
} as unknown as Parameters<typeof buildDocSearchGetter>[1];
|
||||
|
||||
const keywordTool = createDocKeywordSearchTool(
|
||||
buildDocKeywordSearchGetter(ac, indexerService, models).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const semanticTool = createDocSemanticSearchTool(
|
||||
buildDocSearchGetter(ac, contextService, null, models).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const keywordResult = await keywordTool.execute?.({ query: 'hello' }, {});
|
||||
const semanticResult = await semanticTool.execute?.({ query: 'hello' }, {});
|
||||
|
||||
t.false(keywordSearchCalled);
|
||||
t.false(semanticSearchCalled);
|
||||
t.deepEqual(keywordResult, {
|
||||
type: 'error',
|
||||
name: 'Workspace Sync Required',
|
||||
message: LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE,
|
||||
});
|
||||
t.deepEqual(semanticResult, {
|
||||
type: 'error',
|
||||
name: 'Workspace Sync Required',
|
||||
message: LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE,
|
||||
});
|
||||
});
|
||||
|
||||
test('doc_semantic_search should return empty array when nothing matches', async t => {
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({
|
||||
can: async () => true,
|
||||
docs: async () => [],
|
||||
}),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
const models = {
|
||||
workspace: {
|
||||
get: async () => ({ id: 'workspace-1' }),
|
||||
},
|
||||
} as unknown as Models;
|
||||
|
||||
const contextService = {
|
||||
matchWorkspaceAll: async () => [],
|
||||
} as unknown as Parameters<typeof buildDocSearchGetter>[1];
|
||||
|
||||
const semanticTool = createDocSemanticSearchTool(
|
||||
buildDocSearchGetter(ac, contextService, null, models).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const result = await semanticTool.execute?.({ query: 'hello' }, {});
|
||||
|
||||
t.deepEqual(result, []);
|
||||
});
|
||||
|
||||
test('blob_read should return explicit error when attachment context is missing', async t => {
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({
|
||||
allowLocal: () => ({
|
||||
can: async () => true,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
const blobTool = createBlobReadTool(
|
||||
buildBlobContentGetter(ac, null).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const result = await blobTool.execute?.({ blob_id: 'blob-1' }, {});
|
||||
|
||||
t.deepEqual(result, {
|
||||
type: 'error',
|
||||
name: 'Blob Read Failed',
|
||||
message:
|
||||
'Missing workspace, user, blob id, or copilot context for blob_read.',
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2,10 +2,12 @@ import {
|
||||
acceptInviteByInviteIdMutation,
|
||||
approveWorkspaceTeamMemberMutation,
|
||||
createInviteLinkMutation,
|
||||
deleteBlobMutation,
|
||||
getInviteInfoQuery,
|
||||
getMembersByWorkspaceIdQuery,
|
||||
inviteByEmailsMutation,
|
||||
leaveWorkspaceMutation,
|
||||
releaseDeletedBlobsMutation,
|
||||
revokeMemberPermissionMutation,
|
||||
WorkspaceInviteLinkExpireTime,
|
||||
WorkspaceMemberStatus,
|
||||
@@ -13,6 +15,11 @@ import {
|
||||
import { faker } from '@faker-js/faker';
|
||||
|
||||
import { Models } from '../../../models';
|
||||
import { FeatureConfigs } from '../../../models/common/feature';
|
||||
import {
|
||||
SubscriptionPlan,
|
||||
SubscriptionRecurring,
|
||||
} from '../../../plugins/payment/types';
|
||||
import { Mockers } from '../../mocks';
|
||||
import { app, e2e } from '../test';
|
||||
|
||||
@@ -81,6 +88,175 @@ e2e('should invite a user', async t => {
|
||||
t.is(getInviteInfo2.status, WorkspaceMemberStatus.Accepted);
|
||||
});
|
||||
|
||||
e2e('should re-check seat when accepting an email invitation', async t => {
|
||||
const { owner, workspace } = await createWorkspace();
|
||||
const member = await app.create(Mockers.User);
|
||||
await app.create(Mockers.TeamWorkspace, {
|
||||
id: workspace.id,
|
||||
quantity: 4,
|
||||
});
|
||||
|
||||
await app.create(Mockers.WorkspaceUser, {
|
||||
workspaceId: workspace.id,
|
||||
userId: (await app.create(Mockers.User)).id,
|
||||
});
|
||||
await app.create(Mockers.WorkspaceUser, {
|
||||
workspaceId: workspace.id,
|
||||
userId: (await app.create(Mockers.User)).id,
|
||||
});
|
||||
|
||||
await app.login(owner);
|
||||
const invite = await app.gql({
|
||||
query: inviteByEmailsMutation,
|
||||
variables: {
|
||||
emails: [member.email],
|
||||
workspaceId: workspace.id,
|
||||
},
|
||||
});
|
||||
|
||||
await app.eventBus.emitAsync('workspace.members.allocateSeats', {
|
||||
workspaceId: workspace.id,
|
||||
quantity: 4,
|
||||
});
|
||||
|
||||
await app.models.workspaceFeature.remove(workspace.id, 'team_plan_v1');
|
||||
|
||||
await app.login(member);
|
||||
await t.throwsAsync(
|
||||
app.gql({
|
||||
query: acceptInviteByInviteIdMutation,
|
||||
variables: {
|
||||
workspaceId: workspace.id,
|
||||
inviteId: invite.inviteMembers[0].inviteId!,
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const { getInviteInfo } = await app.gql({
|
||||
query: getInviteInfoQuery,
|
||||
variables: {
|
||||
inviteId: invite.inviteMembers[0].inviteId!,
|
||||
},
|
||||
});
|
||||
|
||||
t.is(getInviteInfo.status, WorkspaceMemberStatus.Pending);
|
||||
});
|
||||
|
||||
e2e.serial(
|
||||
'should block accepting pending invitations in readonly mode and recover after blob cleanup',
|
||||
async t => {
|
||||
const { owner, workspace } = await createWorkspace();
|
||||
const member = await app.create(Mockers.User);
|
||||
const freeStorageQuota = FeatureConfigs.free_plan_v1.configs.storageQuota;
|
||||
const lifetimeStorageQuota =
|
||||
FeatureConfigs.lifetime_pro_plan_v1.configs.storageQuota;
|
||||
|
||||
FeatureConfigs.free_plan_v1.configs.storageQuota = 1;
|
||||
FeatureConfigs.lifetime_pro_plan_v1.configs.storageQuota = 2;
|
||||
t.teardown(() => {
|
||||
FeatureConfigs.free_plan_v1.configs.storageQuota = freeStorageQuota;
|
||||
FeatureConfigs.lifetime_pro_plan_v1.configs.storageQuota =
|
||||
lifetimeStorageQuota;
|
||||
});
|
||||
|
||||
await app.models.userFeature.switchQuota(
|
||||
owner.id,
|
||||
'lifetime_pro_plan_v1',
|
||||
'test setup'
|
||||
);
|
||||
|
||||
await app.login(owner);
|
||||
const invite = await app.gql({
|
||||
query: inviteByEmailsMutation,
|
||||
variables: {
|
||||
emails: [member.email],
|
||||
workspaceId: workspace.id,
|
||||
},
|
||||
});
|
||||
|
||||
await app.models.blob.upsert({
|
||||
workspaceId: workspace.id,
|
||||
key: 'overflow-blob',
|
||||
mime: 'application/octet-stream',
|
||||
size: 2,
|
||||
status: 'completed',
|
||||
uploadId: null,
|
||||
});
|
||||
|
||||
await app.eventBus.emitAsync('user.subscription.canceled', {
|
||||
userId: owner.id,
|
||||
plan: SubscriptionPlan.Pro,
|
||||
recurring: SubscriptionRecurring.Lifetime,
|
||||
});
|
||||
|
||||
t.true(
|
||||
await app.models.workspaceFeature.has(
|
||||
workspace.id,
|
||||
'quota_exceeded_readonly_workspace_v1'
|
||||
)
|
||||
);
|
||||
|
||||
await app.login(member);
|
||||
await t.throwsAsync(
|
||||
app.gql({
|
||||
query: acceptInviteByInviteIdMutation,
|
||||
variables: {
|
||||
workspaceId: workspace.id,
|
||||
inviteId: invite.inviteMembers[0].inviteId!,
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const { getInviteInfo: pendingInvite } = await app.gql({
|
||||
query: getInviteInfoQuery,
|
||||
variables: {
|
||||
inviteId: invite.inviteMembers[0].inviteId!,
|
||||
},
|
||||
});
|
||||
t.is(pendingInvite.status, WorkspaceMemberStatus.Pending);
|
||||
|
||||
await app.login(owner);
|
||||
await app.gql({
|
||||
query: deleteBlobMutation,
|
||||
variables: {
|
||||
workspaceId: workspace.id,
|
||||
key: 'overflow-blob',
|
||||
permanently: false,
|
||||
},
|
||||
});
|
||||
await app.gql({
|
||||
query: releaseDeletedBlobsMutation,
|
||||
variables: {
|
||||
workspaceId: workspace.id,
|
||||
},
|
||||
});
|
||||
|
||||
t.false(
|
||||
await app.models.workspaceFeature.has(
|
||||
workspace.id,
|
||||
'quota_exceeded_readonly_workspace_v1'
|
||||
)
|
||||
);
|
||||
|
||||
await app.login(member);
|
||||
await app.gql({
|
||||
query: acceptInviteByInviteIdMutation,
|
||||
variables: {
|
||||
workspaceId: workspace.id,
|
||||
inviteId: invite.inviteMembers[0].inviteId!,
|
||||
},
|
||||
});
|
||||
|
||||
const { getInviteInfo: acceptedInvite } = await app.gql({
|
||||
query: getInviteInfoQuery,
|
||||
variables: {
|
||||
inviteId: invite.inviteMembers[0].inviteId!,
|
||||
},
|
||||
});
|
||||
t.is(acceptedInvite.status, WorkspaceMemberStatus.Accepted);
|
||||
}
|
||||
);
|
||||
|
||||
e2e('should leave a workspace', async t => {
|
||||
const { owner, workspace } = await createWorkspace();
|
||||
const u2 = await app.create(Mockers.User);
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
import {
|
||||
getInviteInfoQuery,
|
||||
inviteByEmailsMutation,
|
||||
publishPageMutation,
|
||||
revokeMemberPermissionMutation,
|
||||
revokePublicPageMutation,
|
||||
WorkspaceMemberStatus,
|
||||
} from '@affine/graphql';
|
||||
|
||||
import { QuotaService } from '../../../core/quota/service';
|
||||
import { WorkspaceRole } from '../../../models';
|
||||
import {
|
||||
SubscriptionPlan,
|
||||
SubscriptionRecurring,
|
||||
} from '../../../plugins/payment/types';
|
||||
import { Mockers } from '../../mocks';
|
||||
import { app, e2e } from '../test';
|
||||
|
||||
@@ -54,6 +62,42 @@ const getInvitationInfo = async (inviteId: string) => {
|
||||
return result.getInviteInfo;
|
||||
};
|
||||
|
||||
const publishDoc = async (workspaceId: string, docId: string) => {
|
||||
const { publishDoc } = await app.gql({
|
||||
query: publishPageMutation,
|
||||
variables: {
|
||||
workspaceId,
|
||||
pageId: docId,
|
||||
},
|
||||
});
|
||||
|
||||
return publishDoc;
|
||||
};
|
||||
|
||||
const revokePublicDoc = async (workspaceId: string, docId: string) => {
|
||||
const { revokePublicDoc } = await app.gql({
|
||||
query: revokePublicPageMutation,
|
||||
variables: {
|
||||
workspaceId,
|
||||
pageId: docId,
|
||||
},
|
||||
});
|
||||
|
||||
return revokePublicDoc;
|
||||
};
|
||||
|
||||
const revokeMember = async (workspaceId: string, userId: string) => {
|
||||
const { revokeMember } = await app.gql({
|
||||
query: revokeMemberPermissionMutation,
|
||||
variables: {
|
||||
workspaceId,
|
||||
userId,
|
||||
},
|
||||
});
|
||||
|
||||
return revokeMember;
|
||||
};
|
||||
|
||||
e2e('should set new invited users to AllocatingSeat', async t => {
|
||||
const { owner, workspace } = await createTeamWorkspace();
|
||||
await app.login(owner);
|
||||
@@ -165,3 +209,136 @@ e2e('should set all rests to NeedMoreSeat', async t => {
|
||||
WorkspaceMemberStatus.NeedMoreSeat
|
||||
);
|
||||
});
|
||||
|
||||
e2e(
|
||||
'should cleanup non-accepted members when team workspace is downgraded',
|
||||
async t => {
|
||||
const { workspace } = await createTeamWorkspace();
|
||||
|
||||
const pending = await app.create(Mockers.User);
|
||||
await app.create(Mockers.WorkspaceUser, {
|
||||
userId: pending.id,
|
||||
workspaceId: workspace.id,
|
||||
status: WorkspaceMemberStatus.Pending,
|
||||
});
|
||||
|
||||
const allocating = await app.create(Mockers.User);
|
||||
await app.create(Mockers.WorkspaceUser, {
|
||||
userId: allocating.id,
|
||||
workspaceId: workspace.id,
|
||||
status: WorkspaceMemberStatus.AllocatingSeat,
|
||||
source: 'Email',
|
||||
});
|
||||
|
||||
const underReview = await app.create(Mockers.User);
|
||||
await app.create(Mockers.WorkspaceUser, {
|
||||
userId: underReview.id,
|
||||
workspaceId: workspace.id,
|
||||
status: WorkspaceMemberStatus.UnderReview,
|
||||
});
|
||||
|
||||
await app.eventBus.emitAsync('workspace.subscription.canceled', {
|
||||
workspaceId: workspace.id,
|
||||
plan: SubscriptionPlan.Team,
|
||||
recurring: SubscriptionRecurring.Monthly,
|
||||
});
|
||||
|
||||
const [members] = await app.models.workspaceUser.paginate(workspace.id, {
|
||||
first: 20,
|
||||
offset: 0,
|
||||
});
|
||||
|
||||
t.deepEqual(
|
||||
members.map(member => member.status),
|
||||
[
|
||||
WorkspaceMemberStatus.Accepted,
|
||||
WorkspaceMemberStatus.Accepted,
|
||||
WorkspaceMemberStatus.Accepted,
|
||||
]
|
||||
);
|
||||
t.false(await app.models.workspace.isTeamWorkspace(workspace.id));
|
||||
}
|
||||
);
|
||||
|
||||
e2e(
|
||||
'should demote accepted admins and keep workspace writable when downgrade stays within owner quota',
|
||||
async t => {
|
||||
const { workspace, owner, admin } = await createTeamWorkspace();
|
||||
|
||||
await app.eventBus.emitAsync('workspace.subscription.canceled', {
|
||||
workspaceId: workspace.id,
|
||||
plan: SubscriptionPlan.Team,
|
||||
recurring: SubscriptionRecurring.Monthly,
|
||||
});
|
||||
|
||||
t.false(await app.models.workspace.isTeamWorkspace(workspace.id));
|
||||
t.false(
|
||||
await app.models.workspaceFeature.has(
|
||||
workspace.id,
|
||||
'quota_exceeded_readonly_workspace_v1'
|
||||
)
|
||||
);
|
||||
t.is(
|
||||
(await app.models.workspaceUser.get(workspace.id, admin.id))?.type,
|
||||
WorkspaceRole.Collaborator
|
||||
);
|
||||
|
||||
await app.login(owner);
|
||||
await t.notThrowsAsync(publishDoc(workspace.id, 'doc-1'));
|
||||
}
|
||||
);
|
||||
|
||||
e2e(
|
||||
'should enter readonly mode on over-quota team downgrade and recover through cleanup actions',
|
||||
async t => {
|
||||
const { workspace, owner, admin } = await createTeamWorkspace(20);
|
||||
const extraMembers = await Promise.all(
|
||||
Array.from({ length: 8 }).map(async () => {
|
||||
const member = await app.create(Mockers.User);
|
||||
await app.create(Mockers.WorkspaceUser, {
|
||||
workspaceId: workspace.id,
|
||||
userId: member.id,
|
||||
});
|
||||
return member;
|
||||
})
|
||||
);
|
||||
|
||||
await app.login(owner);
|
||||
await publishDoc(workspace.id, 'published-doc');
|
||||
|
||||
await app.eventBus.emitAsync('workspace.subscription.canceled', {
|
||||
workspaceId: workspace.id,
|
||||
plan: SubscriptionPlan.Team,
|
||||
recurring: SubscriptionRecurring.Monthly,
|
||||
});
|
||||
|
||||
t.false(await app.models.workspace.isTeamWorkspace(workspace.id));
|
||||
t.true(
|
||||
await app.models.workspaceFeature.has(
|
||||
workspace.id,
|
||||
'quota_exceeded_readonly_workspace_v1'
|
||||
)
|
||||
);
|
||||
t.is(
|
||||
(await app.models.workspaceUser.get(workspace.id, admin.id))?.type,
|
||||
WorkspaceRole.Collaborator
|
||||
);
|
||||
|
||||
await t.throwsAsync(publishDoc(workspace.id, 'blocked-doc'));
|
||||
await t.notThrowsAsync(revokePublicDoc(workspace.id, 'published-doc'));
|
||||
|
||||
const quota = await app
|
||||
.get(QuotaService)
|
||||
.getWorkspaceQuotaWithUsage(workspace.id);
|
||||
for (const member of extraMembers.slice(0, quota.overcapacityMemberCount)) {
|
||||
await revokeMember(workspace.id, member.id);
|
||||
}
|
||||
|
||||
t.false(
|
||||
await app.models.workspaceFeature.has(
|
||||
workspace.id,
|
||||
'quota_exceeded_readonly_workspace_v1'
|
||||
)
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
@@ -65,6 +65,28 @@ test('should transfer workespace owner', async t => {
|
||||
|
||||
const owner2 = await models.workspaceUser.getOwner(workspace.id);
|
||||
t.is(owner2.id, user2.id);
|
||||
const oldOwnerRole = await models.workspaceUser.get(workspace.id, user.id);
|
||||
t.is(oldOwnerRole?.type, WorkspaceRole.Collaborator);
|
||||
});
|
||||
|
||||
test('should keep old owner as admin when transferring a team workspace', async t => {
|
||||
const [user, user2] = await module.create(Mockers.User, 2);
|
||||
const workspace = await module.create(Mockers.Workspace, {
|
||||
owner: { id: user.id },
|
||||
});
|
||||
await module.create(Mockers.TeamWorkspace, {
|
||||
id: workspace.id,
|
||||
quantity: 10,
|
||||
});
|
||||
await module.create(Mockers.WorkspaceUser, {
|
||||
workspaceId: workspace.id,
|
||||
userId: user2.id,
|
||||
});
|
||||
|
||||
await models.workspaceUser.setOwner(workspace.id, user2.id);
|
||||
|
||||
const oldOwnerRole = await models.workspaceUser.get(workspace.id, user.id);
|
||||
t.is(oldOwnerRole?.type, WorkspaceRole.Admin);
|
||||
});
|
||||
|
||||
test('should throw if transfer owner to non-active member', async t => {
|
||||
|
||||
@@ -6,13 +6,17 @@ import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { AppModule } from '../../app.module';
|
||||
import { ConfigFactory, URLHelper } from '../../base';
|
||||
import { ConfigFactory, InvalidOauthResponse, URLHelper } from '../../base';
|
||||
import { ConfigModule } from '../../base/config';
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { AuthService } from '../../core/auth/service';
|
||||
import { ServerFeature } from '../../core/config/types';
|
||||
import { Models } from '../../models';
|
||||
import { OAuthProviderName } from '../../plugins/oauth/config';
|
||||
import { OAuthProviderFactory } from '../../plugins/oauth/factory';
|
||||
import { GithubOAuthProvider } from '../../plugins/oauth/providers/github';
|
||||
import { GoogleOAuthProvider } from '../../plugins/oauth/providers/google';
|
||||
import { OIDCProvider } from '../../plugins/oauth/providers/oidc';
|
||||
import { OAuthService } from '../../plugins/oauth/service';
|
||||
import { createTestingApp, currentUser, TestingApp } from '../utils';
|
||||
|
||||
@@ -35,6 +39,16 @@ test.before(async t => {
|
||||
clientId: 'google-client-id',
|
||||
clientSecret: 'google-client-secret',
|
||||
},
|
||||
github: {
|
||||
clientId: 'github-client-id',
|
||||
clientSecret: 'github-client-secret',
|
||||
},
|
||||
oidc: {
|
||||
clientId: '',
|
||||
clientSecret: '',
|
||||
issuer: '',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
},
|
||||
server: {
|
||||
@@ -284,7 +298,7 @@ test('should be able to get registered oauth providers', async t => {
|
||||
|
||||
const providers = oauth.availableOAuthProviders();
|
||||
|
||||
t.deepEqual(providers, [OAuthProviderName.Google]);
|
||||
t.deepEqual(providers, [OAuthProviderName.Google, OAuthProviderName.GitHub]);
|
||||
});
|
||||
|
||||
test('should throw if code is missing in callback uri', async t => {
|
||||
@@ -432,6 +446,105 @@ function mockOAuthProvider(
|
||||
return clientNonce;
|
||||
}
|
||||
|
||||
function mockGithubOAuthProvider(
|
||||
app: TestingApp,
|
||||
clientNonce: string = randomUUID()
|
||||
) {
|
||||
const provider = app.get(GithubOAuthProvider);
|
||||
const oauth = app.get(OAuthService);
|
||||
|
||||
Sinon.stub(oauth, 'isValidState').resolves(true);
|
||||
Sinon.stub(oauth, 'getOAuthState').resolves({
|
||||
provider: OAuthProviderName.GitHub,
|
||||
clientNonce,
|
||||
});
|
||||
|
||||
Sinon.stub(provider, 'getToken').resolves({ accessToken: '1' });
|
||||
|
||||
return { provider, clientNonce };
|
||||
}
|
||||
|
||||
function mockOidcProvider(
|
||||
provider: OIDCProvider,
|
||||
{
|
||||
args = {},
|
||||
idTokenClaims,
|
||||
userinfo,
|
||||
}: {
|
||||
args?: Record<string, string>;
|
||||
idTokenClaims: Record<string, unknown>;
|
||||
userinfo: Record<string, unknown>;
|
||||
}
|
||||
) {
|
||||
Sinon.stub(provider, 'config').get(() => ({
|
||||
clientId: '',
|
||||
clientSecret: '',
|
||||
issuer: '',
|
||||
args,
|
||||
}));
|
||||
Sinon.stub(
|
||||
provider as unknown as { endpoints: { userinfo_endpoint: string } },
|
||||
'endpoints'
|
||||
).get(() => ({
|
||||
userinfo_endpoint: 'https://oidc.affine.dev/userinfo',
|
||||
}));
|
||||
Sinon.stub(
|
||||
provider as unknown as { verifyIdToken: () => unknown },
|
||||
'verifyIdToken'
|
||||
).resolves(idTokenClaims);
|
||||
Sinon.stub(
|
||||
provider as unknown as { fetchJson: () => unknown },
|
||||
'fetchJson'
|
||||
).resolves(userinfo);
|
||||
}
|
||||
|
||||
function createOidcRegistrationHarness(config?: {
|
||||
clientId?: string;
|
||||
clientSecret?: string;
|
||||
issuer?: string;
|
||||
}) {
|
||||
const server = {
|
||||
enableFeature: Sinon.spy(),
|
||||
disableFeature: Sinon.spy(),
|
||||
};
|
||||
const factory = new OAuthProviderFactory(server as any);
|
||||
const affineConfig = {
|
||||
server: {
|
||||
externalUrl: 'https://affine.example',
|
||||
host: 'localhost',
|
||||
path: '',
|
||||
https: true,
|
||||
hosts: [],
|
||||
},
|
||||
oauth: {
|
||||
providers: {
|
||||
oidc: {
|
||||
clientId: config?.clientId ?? 'oidc-client-id',
|
||||
clientSecret: config?.clientSecret ?? 'oidc-client-secret',
|
||||
issuer: config?.issuer ?? 'https://issuer.affine.dev',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
const provider = new OIDCProvider(new URLHelper(affineConfig as any));
|
||||
|
||||
(provider as any).factory = factory;
|
||||
(provider as any).AFFiNEConfig = affineConfig;
|
||||
|
||||
return {
|
||||
provider,
|
||||
factory,
|
||||
server,
|
||||
};
|
||||
}
|
||||
|
||||
async function flushAsyncWork(iterations = 5) {
|
||||
for (let i = 0; i < iterations; i++) {
|
||||
await new Promise(resolve => setImmediate(resolve));
|
||||
}
|
||||
}
|
||||
|
||||
test('should be able to sign up with oauth', async t => {
|
||||
const { app, db } = t.context;
|
||||
|
||||
@@ -554,3 +667,279 @@ test('should be able to fullfil user with oauth sign in', async t => {
|
||||
t.truthy(account);
|
||||
t.is(account!.user.id, u3.id);
|
||||
});
|
||||
|
||||
test('github oauth should resolve private email from emails api', async t => {
|
||||
const { app, db } = t.context;
|
||||
|
||||
const email = 'github-private@affine.pro';
|
||||
const { clientNonce, provider } = mockGithubOAuthProvider(app);
|
||||
const fetchJson = Sinon.stub(provider as any, 'fetchJson');
|
||||
|
||||
fetchJson.onFirstCall().resolves({
|
||||
login: 'github-user',
|
||||
email: null,
|
||||
avatar_url: 'avatar',
|
||||
name: 'DarkSky',
|
||||
});
|
||||
fetchJson.onSecondCall().resolves([
|
||||
{ email: 'unverified@affine.pro', primary: true, verified: false },
|
||||
{ email, primary: false, verified: true },
|
||||
]);
|
||||
|
||||
await app
|
||||
.POST('/api/oauth/callback')
|
||||
.send({ code: '1', state: '1', client_nonce: clientNonce })
|
||||
.expect(HttpStatus.OK);
|
||||
|
||||
const sessionUser = await currentUser(app);
|
||||
t.truthy(sessionUser);
|
||||
t.is(sessionUser!.email, email);
|
||||
|
||||
const user = await db.user.findFirst({
|
||||
select: {
|
||||
email: true,
|
||||
connectedAccounts: true,
|
||||
},
|
||||
where: {
|
||||
email,
|
||||
},
|
||||
});
|
||||
|
||||
t.truthy(user);
|
||||
t.is(user!.connectedAccounts[0].provider, OAuthProviderName.GitHub);
|
||||
t.is(user!.connectedAccounts[0].providerAccountId, 'github-user');
|
||||
});
|
||||
|
||||
test('github oauth should reject responses without a verified email', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(GithubOAuthProvider);
|
||||
const fetchJson = Sinon.stub(provider as any, 'fetchJson');
|
||||
|
||||
fetchJson.onFirstCall().resolves({
|
||||
login: 'github-user',
|
||||
email: null,
|
||||
avatar_url: 'avatar',
|
||||
name: 'DarkSky',
|
||||
});
|
||||
fetchJson
|
||||
.onSecondCall()
|
||||
.resolves([
|
||||
{ email: 'private@affine.pro', primary: true, verified: false },
|
||||
]);
|
||||
|
||||
const error = await t.throwsAsync(
|
||||
provider.getUser(
|
||||
{ accessToken: 'token' },
|
||||
{ token: 'state', provider: OAuthProviderName.GitHub }
|
||||
)
|
||||
);
|
||||
|
||||
t.true(error instanceof InvalidOauthResponse);
|
||||
});
|
||||
|
||||
test('oidc should accept email from id token when userinfo email is missing', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
email: 'oidc-id-token@affine.pro',
|
||||
name: 'OIDC User',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
name: 'OIDC User',
|
||||
},
|
||||
});
|
||||
|
||||
const user = await provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
);
|
||||
|
||||
t.is(user.id, 'oidc-user');
|
||||
t.is(user.email, 'oidc-id-token@affine.pro');
|
||||
t.is(user.name, 'OIDC User');
|
||||
});
|
||||
|
||||
test('oidc should resolve custom email claim from userinfo', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
args: { claim_email: 'mail', claim_name: 'display_name' },
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
mail: 'oidc-userinfo@affine.pro',
|
||||
display_name: 'OIDC Custom',
|
||||
},
|
||||
});
|
||||
|
||||
const user = await provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
);
|
||||
|
||||
t.is(user.id, 'oidc-user');
|
||||
t.is(user.email, 'oidc-userinfo@affine.pro');
|
||||
t.is(user.name, 'OIDC Custom');
|
||||
});
|
||||
|
||||
test('oidc should resolve custom email claim from id token', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
args: { claim_email: 'mail', claim_email_verified: 'mail_verified' },
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
mail: 'oidc-custom-id-token@affine.pro',
|
||||
mail_verified: 'true',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
},
|
||||
});
|
||||
|
||||
const user = await provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
);
|
||||
|
||||
t.is(user.id, 'oidc-user');
|
||||
t.is(user.email, 'oidc-custom-id-token@affine.pro');
|
||||
});
|
||||
|
||||
test('oidc should reject responses without a usable email claim', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
args: { claim_email: 'mail' },
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
mail: 'not-an-email',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
mail: 'still-not-an-email',
|
||||
},
|
||||
});
|
||||
|
||||
const error = await t.throwsAsync(
|
||||
provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
)
|
||||
);
|
||||
|
||||
t.true(error instanceof InvalidOauthResponse);
|
||||
t.true(
|
||||
error.message.includes(
|
||||
'Missing valid email claim in OIDC response. Tried userinfo and ID token claims: "mail"'
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('oidc should not fall back to default email claim when custom claim is configured', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
args: { claim_email: 'mail' },
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
email: 'fallback@affine.pro',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
email: 'userinfo-fallback@affine.pro',
|
||||
},
|
||||
});
|
||||
|
||||
const error = await t.throwsAsync(
|
||||
provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
)
|
||||
);
|
||||
|
||||
t.true(error instanceof InvalidOauthResponse);
|
||||
t.true(
|
||||
error.message.includes(
|
||||
'Missing valid email claim in OIDC response. Tried userinfo and ID token claims: "mail"'
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('oidc discovery should remove oauth feature on failure and restore it after backoff retry succeeds', async t => {
|
||||
const { provider, factory, server } = createOidcRegistrationHarness();
|
||||
const fetchStub = Sinon.stub(globalThis, 'fetch');
|
||||
const scheduledRetries: Array<() => void> = [];
|
||||
const retryDelays: number[] = [];
|
||||
const setTimeoutStub = Sinon.stub(globalThis, 'setTimeout').callsFake(((
|
||||
callback: Parameters<typeof setTimeout>[0],
|
||||
delay?: number
|
||||
) => {
|
||||
retryDelays.push(Number(delay));
|
||||
scheduledRetries.push(callback as () => void);
|
||||
return Symbol('timeout') as unknown as ReturnType<typeof setTimeout>;
|
||||
}) as typeof setTimeout);
|
||||
t.teardown(() => {
|
||||
provider.onModuleDestroy();
|
||||
fetchStub.restore();
|
||||
setTimeoutStub.restore();
|
||||
});
|
||||
|
||||
fetchStub
|
||||
.onFirstCall()
|
||||
.rejects(new Error('temporary discovery failure'))
|
||||
.onSecondCall()
|
||||
.rejects(new Error('temporary discovery failure'))
|
||||
.onThirdCall()
|
||||
.resolves(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
authorization_endpoint: 'https://issuer.affine.dev/auth',
|
||||
token_endpoint: 'https://issuer.affine.dev/token',
|
||||
userinfo_endpoint: 'https://issuer.affine.dev/userinfo',
|
||||
issuer: 'https://issuer.affine.dev',
|
||||
jwks_uri: 'https://issuer.affine.dev/jwks',
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
(provider as any).setup();
|
||||
|
||||
await flushAsyncWork();
|
||||
t.deepEqual(factory.providers, []);
|
||||
t.true(server.disableFeature.calledWith(ServerFeature.OAuth));
|
||||
t.is(fetchStub.callCount, 1);
|
||||
t.deepEqual(retryDelays, [1000]);
|
||||
|
||||
const firstRetry = scheduledRetries.shift();
|
||||
t.truthy(firstRetry);
|
||||
firstRetry!();
|
||||
await flushAsyncWork();
|
||||
t.is(fetchStub.callCount, 2);
|
||||
t.deepEqual(factory.providers, []);
|
||||
t.deepEqual(retryDelays, [1000, 2000]);
|
||||
|
||||
const secondRetry = scheduledRetries.shift();
|
||||
t.truthy(secondRetry);
|
||||
secondRetry!();
|
||||
await flushAsyncWork();
|
||||
t.is(fetchStub.callCount, 3);
|
||||
t.deepEqual(factory.providers, [OAuthProviderName.OIDC]);
|
||||
t.true(server.enableFeature.calledWith(ServerFeature.OAuth));
|
||||
t.is(scheduledRetries.length, 0);
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ import { ConfigFactory, ConfigModule } from '../../base/config';
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { AuthService } from '../../core/auth/service';
|
||||
import { EarlyAccessType, FeatureService } from '../../core/features';
|
||||
import { SubscriptionCronJobs } from '../../plugins/payment/cron';
|
||||
import { SubscriptionService } from '../../plugins/payment/service';
|
||||
import { StripeFactory } from '../../plugins/payment/stripe';
|
||||
import {
|
||||
@@ -871,6 +872,34 @@ test('should be able to cancel subscription', async t => {
|
||||
t.truthy(subInDB.canceledAt);
|
||||
});
|
||||
|
||||
test('should reconcile canceled stripe subscriptions and revoke local entitlement', async t => {
|
||||
const { app, db, event, service, stripe, u1 } = t.context;
|
||||
const cron = app.get(SubscriptionCronJobs);
|
||||
|
||||
await service.saveStripeSubscription(sub);
|
||||
event.emit.resetHistory();
|
||||
|
||||
stripe.subscriptions.retrieve.resolves({
|
||||
...sub,
|
||||
status: SubscriptionStatus.Canceled,
|
||||
} as any);
|
||||
|
||||
await cron.reconcileStripeSubscriptions();
|
||||
|
||||
const subInDB = await db.subscription.findFirst({
|
||||
where: { targetId: u1.id, stripeSubscriptionId: sub.id },
|
||||
});
|
||||
|
||||
t.is(subInDB, null);
|
||||
t.true(
|
||||
event.emit.calledWith('user.subscription.canceled', {
|
||||
userId: u1.id,
|
||||
plan: SubscriptionPlan.Pro,
|
||||
recurring: SubscriptionRecurring.Monthly,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
test('should be able to resume subscription', async t => {
|
||||
const { service, db, u1, stripe } = t.context;
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import test from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import {
|
||||
exponentialBackoffDelay,
|
||||
ExponentialBackoffScheduler,
|
||||
} from '../promise';
|
||||
|
||||
test('exponentialBackoffDelay should cap exponential growth at maxDelayMs', t => {
|
||||
t.is(exponentialBackoffDelay(0, { baseDelayMs: 100, maxDelayMs: 500 }), 100);
|
||||
t.is(exponentialBackoffDelay(1, { baseDelayMs: 100, maxDelayMs: 500 }), 200);
|
||||
t.is(exponentialBackoffDelay(3, { baseDelayMs: 100, maxDelayMs: 500 }), 500);
|
||||
});
|
||||
|
||||
test('ExponentialBackoffScheduler should track pending callback and increase delay per attempt', async t => {
|
||||
const clock = Sinon.useFakeTimers();
|
||||
t.teardown(() => {
|
||||
clock.restore();
|
||||
});
|
||||
|
||||
const calls: number[] = [];
|
||||
const scheduler = new ExponentialBackoffScheduler({
|
||||
baseDelayMs: 100,
|
||||
maxDelayMs: 500,
|
||||
});
|
||||
|
||||
t.is(
|
||||
scheduler.schedule(() => {
|
||||
calls.push(1);
|
||||
}),
|
||||
100
|
||||
);
|
||||
t.true(scheduler.pending);
|
||||
t.is(
|
||||
scheduler.schedule(() => {
|
||||
calls.push(2);
|
||||
}),
|
||||
null
|
||||
);
|
||||
|
||||
await clock.tickAsync(100);
|
||||
t.deepEqual(calls, [1]);
|
||||
t.false(scheduler.pending);
|
||||
|
||||
t.is(
|
||||
scheduler.schedule(() => {
|
||||
calls.push(3);
|
||||
}),
|
||||
200
|
||||
);
|
||||
await clock.tickAsync(200);
|
||||
t.deepEqual(calls, [1, 3]);
|
||||
});
|
||||
|
||||
test('ExponentialBackoffScheduler reset should clear pending work and restart from the base delay', t => {
|
||||
const scheduler = new ExponentialBackoffScheduler({
|
||||
baseDelayMs: 100,
|
||||
maxDelayMs: 500,
|
||||
});
|
||||
|
||||
t.is(
|
||||
scheduler.schedule(() => {}),
|
||||
100
|
||||
);
|
||||
t.true(scheduler.pending);
|
||||
|
||||
scheduler.reset();
|
||||
t.false(scheduler.pending);
|
||||
t.is(
|
||||
scheduler.schedule(() => {}),
|
||||
100
|
||||
);
|
||||
|
||||
scheduler.clear();
|
||||
});
|
||||
@@ -1,4 +1,4 @@
|
||||
import { setTimeout } from 'node:timers/promises';
|
||||
import { setTimeout as delay } from 'node:timers/promises';
|
||||
|
||||
import { defer as rxjsDefer, retry } from 'rxjs';
|
||||
|
||||
@@ -52,5 +52,61 @@ export function defer(dispose: () => Promise<void>) {
|
||||
}
|
||||
|
||||
export function sleep(ms: number): Promise<void> {
|
||||
return setTimeout(ms);
|
||||
return delay(ms);
|
||||
}
|
||||
|
||||
export function exponentialBackoffDelay(
|
||||
attempt: number,
|
||||
{
|
||||
baseDelayMs,
|
||||
maxDelayMs,
|
||||
factor = 2,
|
||||
}: { baseDelayMs: number; maxDelayMs: number; factor?: number }
|
||||
): number {
|
||||
return Math.min(
|
||||
baseDelayMs * Math.pow(factor, Math.max(0, attempt)),
|
||||
maxDelayMs
|
||||
);
|
||||
}
|
||||
|
||||
export class ExponentialBackoffScheduler {
|
||||
#attempt = 0;
|
||||
#timer: ReturnType<typeof globalThis.setTimeout> | null = null;
|
||||
|
||||
constructor(
|
||||
private readonly options: {
|
||||
baseDelayMs: number;
|
||||
maxDelayMs: number;
|
||||
factor?: number;
|
||||
}
|
||||
) {}
|
||||
|
||||
get pending() {
|
||||
return this.#timer !== null;
|
||||
}
|
||||
|
||||
clear() {
|
||||
if (this.#timer) {
|
||||
clearTimeout(this.#timer);
|
||||
this.#timer = null;
|
||||
}
|
||||
}
|
||||
|
||||
reset() {
|
||||
this.#attempt = 0;
|
||||
this.clear();
|
||||
}
|
||||
|
||||
schedule(callback: () => void) {
|
||||
if (this.#timer) return null;
|
||||
|
||||
const timeout = exponentialBackoffDelay(this.#attempt, this.options);
|
||||
this.#timer = globalThis.setTimeout(() => {
|
||||
this.#timer = null;
|
||||
callback();
|
||||
}, timeout);
|
||||
this.#attempt += 1;
|
||||
|
||||
return timeout;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import test from 'ava';
|
||||
|
||||
import { createTestingModule, TestingModule } from '../../../__tests__/utils';
|
||||
@@ -10,11 +12,13 @@ import {
|
||||
} from '../../../models';
|
||||
import { DocAccessController } from '../doc';
|
||||
import { PermissionModule } from '../index';
|
||||
import { WorkspacePolicyService } from '../policy';
|
||||
import { DocRole, mapDocRoleToPermissions } from '../types';
|
||||
|
||||
let module: TestingModule;
|
||||
let models: Models;
|
||||
let ac: DocAccessController;
|
||||
let policy: WorkspacePolicyService;
|
||||
let user: User;
|
||||
let ws: Workspace;
|
||||
|
||||
@@ -22,11 +26,12 @@ test.before(async () => {
|
||||
module = await createTestingModule({ imports: [PermissionModule] });
|
||||
models = module.get<Models>(Models);
|
||||
ac = module.get(DocAccessController);
|
||||
policy = module.get(WorkspacePolicyService);
|
||||
});
|
||||
|
||||
test.beforeEach(async () => {
|
||||
await module.initTestingDB();
|
||||
user = await models.user.create({ email: 'u1@affine.pro' });
|
||||
user = await models.user.create({ email: `${randomUUID()}@affine.pro` });
|
||||
ws = await models.workspace.create(user.id);
|
||||
});
|
||||
|
||||
@@ -45,7 +50,7 @@ test('should get null role', async t => {
|
||||
});
|
||||
|
||||
test('should return null if workspace role is not accepted', async t => {
|
||||
const u2 = await models.user.create({ email: 'u2@affine.pro' });
|
||||
const u2 = await models.user.create({ email: `${randomUUID()}@affine.pro` });
|
||||
await models.workspaceUser.set(ws.id, u2.id, WorkspaceRole.Collaborator, {
|
||||
status: WorkspaceMemberStatus.UnderReview,
|
||||
});
|
||||
@@ -162,7 +167,7 @@ test('should assert action', async t => {
|
||||
)
|
||||
);
|
||||
|
||||
const u2 = await models.user.create({ email: 'u2@affine.pro' });
|
||||
const u2 = await models.user.create({ email: `${randomUUID()}@affine.pro` });
|
||||
|
||||
await t.throwsAsync(
|
||||
ac.assert(
|
||||
@@ -184,3 +189,37 @@ test('should assert action', async t => {
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('should apply readonly doc restrictions while keeping cleanup actions', async t => {
|
||||
for (let index = 0; index < 10; index++) {
|
||||
const member = await models.user.create({
|
||||
email: `${randomUUID()}@affine.pro`,
|
||||
});
|
||||
await models.workspaceUser.set(
|
||||
ws.id,
|
||||
member.id,
|
||||
WorkspaceRole.Collaborator,
|
||||
{
|
||||
status: WorkspaceMemberStatus.Accepted,
|
||||
}
|
||||
);
|
||||
}
|
||||
await policy.reconcileWorkspaceQuotaState(ws.id);
|
||||
|
||||
const { permissions } = await ac.role({
|
||||
workspaceId: ws.id,
|
||||
docId: 'doc1',
|
||||
userId: user.id,
|
||||
});
|
||||
|
||||
t.false(permissions['Doc.Update']);
|
||||
t.false(permissions['Doc.Publish']);
|
||||
t.false(permissions['Doc.Duplicate']);
|
||||
t.false(permissions['Doc.Comments.Create']);
|
||||
t.false(permissions['Doc.Comments.Update']);
|
||||
t.false(permissions['Doc.Comments.Resolve']);
|
||||
t.true(permissions['Doc.Read']);
|
||||
t.true(permissions['Doc.Delete']);
|
||||
t.true(permissions['Doc.Trash']);
|
||||
t.true(permissions['Doc.TransferOwner']);
|
||||
});
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import {
|
||||
createTestingModule,
|
||||
type TestingModule,
|
||||
} from '../../../__tests__/utils';
|
||||
import { SpaceAccessDenied } from '../../../base';
|
||||
import {
|
||||
Models,
|
||||
User,
|
||||
Workspace,
|
||||
WorkspaceMemberStatus,
|
||||
WorkspaceRole,
|
||||
} from '../../../models';
|
||||
import { QuotaService } from '../../quota/service';
|
||||
import { PermissionModule } from '../index';
|
||||
import { WorkspacePolicyService } from '../policy';
|
||||
|
||||
interface Context {
|
||||
module: TestingModule;
|
||||
models: Models;
|
||||
policy: WorkspacePolicyService;
|
||||
}
|
||||
|
||||
const test = ava as TestFn<Context>;
|
||||
|
||||
const READONLY_FEATURE = 'quota_exceeded_readonly_workspace_v1' as const;
|
||||
type WorkspaceQuotaSnapshot = Awaited<
|
||||
ReturnType<QuotaService['getWorkspaceQuotaWithUsage']>
|
||||
> & {
|
||||
ownerQuota?: string;
|
||||
};
|
||||
async function addAcceptedMembers(
|
||||
models: Models,
|
||||
workspaceId: string,
|
||||
count: number
|
||||
) {
|
||||
for (let index = 0; index < count; index++) {
|
||||
const member = await models.user.create({
|
||||
email: `${randomUUID()}@affine.pro`,
|
||||
});
|
||||
await models.workspaceUser.set(
|
||||
workspaceId,
|
||||
member.id,
|
||||
WorkspaceRole.Collaborator,
|
||||
{
|
||||
status: WorkspaceMemberStatus.Accepted,
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let owner: User;
|
||||
let workspace: Workspace;
|
||||
|
||||
test.before(async t => {
|
||||
const module = await createTestingModule({ imports: [PermissionModule] });
|
||||
t.context.module = module;
|
||||
t.context.models = module.get(Models);
|
||||
t.context.policy = module.get(WorkspacePolicyService);
|
||||
});
|
||||
|
||||
test.beforeEach(async t => {
|
||||
Sinon.restore();
|
||||
await t.context.module.initTestingDB();
|
||||
owner = await t.context.models.user.create({
|
||||
email: `${randomUUID()}@affine.pro`,
|
||||
});
|
||||
workspace = await t.context.models.workspace.create(owner.id);
|
||||
});
|
||||
|
||||
test.after.always(async t => {
|
||||
await t.context.module.close();
|
||||
});
|
||||
|
||||
test('should keep owned workspace writable when quota is within limit', async t => {
|
||||
const state = await t.context.policy.reconcileWorkspaceQuotaState(
|
||||
workspace.id
|
||||
);
|
||||
|
||||
t.false(state.isReadonly);
|
||||
t.deepEqual(state.readonlyReasons, []);
|
||||
t.false(
|
||||
await t.context.models.workspaceFeature.has(workspace.id, READONLY_FEATURE)
|
||||
);
|
||||
});
|
||||
|
||||
test('should enter readonly mode when fallback owner member quota overflows', async t => {
|
||||
await addAcceptedMembers(t.context.models, workspace.id, 10);
|
||||
|
||||
const state = await t.context.policy.reconcileWorkspaceQuotaState(
|
||||
workspace.id
|
||||
);
|
||||
|
||||
t.true(state.isReadonly);
|
||||
t.true(state.canRecoverByRemovingMembers);
|
||||
t.false(state.canRecoverByDeletingBlobs);
|
||||
t.deepEqual(state.readonlyReasons, ['member_overflow']);
|
||||
t.true(
|
||||
await t.context.models.workspaceFeature.has(workspace.id, READONLY_FEATURE)
|
||||
);
|
||||
await t.throwsAsync(t.context.policy.assertCanInviteMembers(workspace.id), {
|
||||
instanceOf: SpaceAccessDenied,
|
||||
});
|
||||
});
|
||||
|
||||
test('should enter readonly mode when fallback owner storage quota overflows', async t => {
|
||||
const quota = Sinon.stub(
|
||||
Reflect.get(t.context.policy, 'quota') as QuotaService,
|
||||
'getWorkspaceQuotaWithUsage'
|
||||
);
|
||||
quota.resolves({
|
||||
name: 'Free',
|
||||
blobLimit: 1,
|
||||
storageQuota: 1,
|
||||
usedStorageQuota: 2,
|
||||
historyPeriod: 1,
|
||||
memberLimit: 3,
|
||||
memberCount: 1,
|
||||
overcapacityMemberCount: 0,
|
||||
usedSize: 2,
|
||||
ownerQuota: owner.id,
|
||||
} satisfies WorkspaceQuotaSnapshot);
|
||||
|
||||
const state = await t.context.policy.reconcileWorkspaceQuotaState(
|
||||
workspace.id
|
||||
);
|
||||
|
||||
t.true(state.isReadonly);
|
||||
t.false(state.canRecoverByRemovingMembers);
|
||||
t.true(state.canRecoverByDeletingBlobs);
|
||||
t.deepEqual(state.readonlyReasons, ['storage_overflow']);
|
||||
t.true(
|
||||
await t.context.models.workspaceFeature.has(workspace.id, READONLY_FEATURE)
|
||||
);
|
||||
});
|
||||
|
||||
test('should leave readonly mode after workspace usage recovers', async t => {
|
||||
const quota = Sinon.stub(
|
||||
Reflect.get(t.context.policy, 'quota') as QuotaService,
|
||||
'getWorkspaceQuotaWithUsage'
|
||||
);
|
||||
quota.onFirstCall().resolves({
|
||||
name: 'Free',
|
||||
blobLimit: 1,
|
||||
storageQuota: 1,
|
||||
usedStorageQuota: 2,
|
||||
historyPeriod: 1,
|
||||
memberLimit: 3,
|
||||
memberCount: 1,
|
||||
overcapacityMemberCount: 0,
|
||||
usedSize: 2,
|
||||
ownerQuota: owner.id,
|
||||
} satisfies WorkspaceQuotaSnapshot);
|
||||
quota.onSecondCall().resolves({
|
||||
name: 'Free',
|
||||
blobLimit: 1,
|
||||
storageQuota: 1,
|
||||
usedStorageQuota: 0,
|
||||
historyPeriod: 1,
|
||||
memberLimit: 3,
|
||||
memberCount: 1,
|
||||
overcapacityMemberCount: 0,
|
||||
usedSize: 0,
|
||||
ownerQuota: owner.id,
|
||||
} satisfies WorkspaceQuotaSnapshot);
|
||||
quota.onThirdCall().resolves({
|
||||
name: 'Free',
|
||||
blobLimit: 1,
|
||||
storageQuota: 1,
|
||||
usedStorageQuota: 0,
|
||||
historyPeriod: 1,
|
||||
memberLimit: 3,
|
||||
memberCount: 1,
|
||||
overcapacityMemberCount: 0,
|
||||
usedSize: 0,
|
||||
ownerQuota: owner.id,
|
||||
} satisfies WorkspaceQuotaSnapshot);
|
||||
|
||||
await t.context.policy.reconcileWorkspaceQuotaState(workspace.id);
|
||||
t.true(
|
||||
await t.context.models.workspaceFeature.has(workspace.id, READONLY_FEATURE)
|
||||
);
|
||||
|
||||
const recovered = await t.context.policy.reconcileWorkspaceQuotaState(
|
||||
workspace.id
|
||||
);
|
||||
|
||||
t.false(recovered.isReadonly);
|
||||
t.deepEqual(recovered.readonlyReasons, []);
|
||||
t.false(
|
||||
await t.context.models.workspaceFeature.has(workspace.id, READONLY_FEATURE)
|
||||
);
|
||||
await t.notThrowsAsync(t.context.policy.assertCanInviteMembers(workspace.id));
|
||||
});
|
||||
@@ -1,3 +1,5 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import test from 'ava';
|
||||
|
||||
import { createTestingModule, TestingModule } from '../../../__tests__/utils';
|
||||
@@ -9,24 +11,27 @@ import {
|
||||
WorkspaceRole,
|
||||
} from '../../../models';
|
||||
import { PermissionModule } from '../index';
|
||||
import { WorkspacePolicyService } from '../policy';
|
||||
import { mapWorkspaceRoleToPermissions } from '../types';
|
||||
import { WorkspaceAccessController } from '../workspace';
|
||||
|
||||
let module: TestingModule;
|
||||
let models: Models;
|
||||
let ac: WorkspaceAccessController;
|
||||
let policy: WorkspacePolicyService;
|
||||
let user: User;
|
||||
let ws: Workspace;
|
||||
|
||||
test.before(async () => {
|
||||
module = await createTestingModule({ imports: [PermissionModule] });
|
||||
models = module.get<Models>(Models);
|
||||
ac = new WorkspaceAccessController(models);
|
||||
ac = module.get(WorkspaceAccessController);
|
||||
policy = module.get(WorkspacePolicyService);
|
||||
});
|
||||
|
||||
test.beforeEach(async () => {
|
||||
await module.initTestingDB();
|
||||
user = await models.user.create({ email: 'u1@affine.pro' });
|
||||
user = await models.user.create({ email: `${randomUUID()}@affine.pro` });
|
||||
ws = await models.workspace.create(user.id);
|
||||
});
|
||||
|
||||
@@ -44,7 +49,7 @@ test('should get null role', async t => {
|
||||
});
|
||||
|
||||
test('should return null if role is not accepted', async t => {
|
||||
const u2 = await models.user.create({ email: 'u2@affine.pro' });
|
||||
const u2 = await models.user.create({ email: `${randomUUID()}@affine.pro` });
|
||||
await models.workspaceUser.set(ws.id, u2.id, WorkspaceRole.Collaborator, {
|
||||
status: WorkspaceMemberStatus.UnderReview,
|
||||
});
|
||||
@@ -183,3 +188,38 @@ test('should assert action', async t => {
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('should apply readonly workspace restrictions while keeping cleanup actions', async t => {
|
||||
for (let index = 0; index < 10; index++) {
|
||||
const member = await models.user.create({
|
||||
email: `${randomUUID()}@affine.pro`,
|
||||
});
|
||||
await models.workspaceUser.set(
|
||||
ws.id,
|
||||
member.id,
|
||||
WorkspaceRole.Collaborator,
|
||||
{
|
||||
status: WorkspaceMemberStatus.Accepted,
|
||||
}
|
||||
);
|
||||
}
|
||||
await policy.reconcileWorkspaceQuotaState(ws.id);
|
||||
|
||||
const { permissions } = await ac.role({
|
||||
workspaceId: ws.id,
|
||||
userId: user.id,
|
||||
});
|
||||
|
||||
t.false(permissions['Workspace.CreateDoc']);
|
||||
t.false(permissions['Workspace.Settings.Update']);
|
||||
t.false(permissions['Workspace.Properties.Create']);
|
||||
t.false(permissions['Workspace.Properties.Update']);
|
||||
t.false(permissions['Workspace.Properties.Delete']);
|
||||
t.false(permissions['Workspace.Blobs.Write']);
|
||||
t.true(permissions['Workspace.Read']);
|
||||
t.true(permissions['Workspace.Sync']);
|
||||
t.true(permissions['Workspace.Users.Manage']);
|
||||
t.true(permissions['Workspace.Blobs.List']);
|
||||
t.true(permissions['Workspace.TransferOwner']);
|
||||
t.true(permissions['Workspace.Payment.Manage']);
|
||||
});
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Injectable } from '@nestjs/common';
|
||||
import { DocActionDenied } from '../../base';
|
||||
import { Models } from '../../models';
|
||||
import { AccessController, getAccessController } from './controller';
|
||||
import { WorkspacePolicyService } from './policy';
|
||||
import type { Resource } from './resource';
|
||||
import {
|
||||
DocAction,
|
||||
@@ -15,13 +16,19 @@ import { WorkspaceAccessController } from './workspace';
|
||||
@Injectable()
|
||||
export class DocAccessController extends AccessController<'doc'> {
|
||||
protected readonly type = 'doc';
|
||||
constructor(private readonly models: Models) {
|
||||
constructor(
|
||||
private readonly models: Models,
|
||||
private readonly policy: WorkspacePolicyService
|
||||
) {
|
||||
super();
|
||||
}
|
||||
|
||||
async role(resource: Resource<'doc'>) {
|
||||
const role = await this.getRole(resource);
|
||||
const permissions = mapDocRoleToPermissions(role);
|
||||
const permissions = await this.policy.applyDocPermissions(
|
||||
resource.workspaceId,
|
||||
mapDocRoleToPermissions(role)
|
||||
);
|
||||
const sharingAllowed = await this.models.workspace.allowSharing(
|
||||
resource.workspaceId
|
||||
);
|
||||
|
||||
@@ -1,22 +1,29 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { QuotaService } from '../quota/service';
|
||||
import { StorageModule } from '../storage';
|
||||
import { AccessControllerBuilder } from './builder';
|
||||
import { DocAccessController } from './doc';
|
||||
import { EventsListener } from './event';
|
||||
import { WorkspacePolicyService } from './policy';
|
||||
import { WorkspaceAccessController } from './workspace';
|
||||
|
||||
@Module({
|
||||
imports: [StorageModule],
|
||||
providers: [
|
||||
QuotaService,
|
||||
WorkspaceAccessController,
|
||||
DocAccessController,
|
||||
AccessControllerBuilder,
|
||||
EventsListener,
|
||||
WorkspacePolicyService,
|
||||
],
|
||||
exports: [AccessControllerBuilder],
|
||||
exports: [AccessControllerBuilder, WorkspacePolicyService],
|
||||
})
|
||||
export class PermissionModule {}
|
||||
|
||||
export { AccessControllerBuilder as AccessController } from './builder';
|
||||
export { WorkspacePolicyService } from './policy';
|
||||
export {
|
||||
DOC_ACTIONS,
|
||||
type DocAction,
|
||||
|
||||
328
packages/backend/server/src/core/permission/policy.ts
Normal file
328
packages/backend/server/src/core/permission/policy.ts
Normal file
@@ -0,0 +1,328 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { DocActionDenied, OnEvent, SpaceAccessDenied } from '../../base';
|
||||
import { Models, WorkspaceRole } from '../../models';
|
||||
import { QuotaService } from '../quota/service';
|
||||
import { getAccessController } from './controller';
|
||||
import type { Resource } from './resource';
|
||||
import {
|
||||
type DocAction,
|
||||
type DocActionPermissions,
|
||||
mapWorkspaceRoleToPermissions,
|
||||
type WorkspaceAction,
|
||||
type WorkspaceActionPermissions,
|
||||
} from './types';
|
||||
|
||||
export type WorkspaceReadonlyReason = 'member_overflow' | 'storage_overflow';
|
||||
type WorkspaceQuotaSnapshot = Awaited<
|
||||
ReturnType<QuotaService['getWorkspaceQuotaWithUsage']>
|
||||
> & {
|
||||
ownerQuota?: string;
|
||||
};
|
||||
|
||||
export type WorkspaceState = {
|
||||
isTeamWorkspace: boolean;
|
||||
isReadonly: boolean;
|
||||
readonlyReasons: WorkspaceReadonlyReason[];
|
||||
canRecoverByRemovingMembers: boolean;
|
||||
canRecoverByDeletingBlobs: boolean;
|
||||
usesFallbackOwnerQuota: boolean;
|
||||
};
|
||||
|
||||
const READONLY_WORKSPACE_ACTIONS: WorkspaceAction[] = [
|
||||
'Workspace.CreateDoc',
|
||||
'Workspace.Settings.Update',
|
||||
'Workspace.Properties.Create',
|
||||
'Workspace.Properties.Update',
|
||||
'Workspace.Properties.Delete',
|
||||
'Workspace.Blobs.Write',
|
||||
];
|
||||
|
||||
const READONLY_DOC_ACTIONS: DocAction[] = [
|
||||
'Doc.Update',
|
||||
'Doc.Duplicate',
|
||||
'Doc.Publish',
|
||||
'Doc.Comments.Create',
|
||||
'Doc.Comments.Update',
|
||||
'Doc.Comments.Resolve',
|
||||
];
|
||||
|
||||
const READONLY_WORKSPACE_FEATURE =
|
||||
'quota_exceeded_readonly_workspace_v1' as const;
|
||||
|
||||
type WorkspaceRoleChecker = {
|
||||
getRole(resource: Resource<'ws'>): Promise<WorkspaceRole | null>;
|
||||
docRoles(
|
||||
resource: Resource<'ws'>,
|
||||
docIds: string[]
|
||||
): Promise<Array<{ role: unknown; permissions: Record<DocAction, boolean> }>>;
|
||||
};
|
||||
|
||||
declare global {
|
||||
interface Events {
|
||||
'workspace.blobs.updated': {
|
||||
workspaceId: string;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class WorkspacePolicyService {
|
||||
constructor(
|
||||
private readonly models: Models,
|
||||
private readonly quota: QuotaService
|
||||
) {}
|
||||
|
||||
async getWorkspaceState(workspaceId: string): Promise<WorkspaceState> {
|
||||
const [isTeamWorkspace, isUnlimitedWorkspace, quota] = await Promise.all([
|
||||
this.models.workspace.isTeamWorkspace(workspaceId),
|
||||
this.models.workspaceFeature.has(workspaceId, 'unlimited_workspace'),
|
||||
this.quota.getWorkspaceQuotaWithUsage(workspaceId),
|
||||
]);
|
||||
const quotaSnapshot = quota as WorkspaceQuotaSnapshot;
|
||||
|
||||
const readonlyReasons: WorkspaceReadonlyReason[] = [];
|
||||
const usesFallbackOwnerQuota =
|
||||
!!quotaSnapshot.ownerQuota && !isUnlimitedWorkspace;
|
||||
|
||||
if (usesFallbackOwnerQuota && quotaSnapshot.overcapacityMemberCount > 0) {
|
||||
readonlyReasons.push('member_overflow');
|
||||
}
|
||||
|
||||
if (
|
||||
usesFallbackOwnerQuota &&
|
||||
quotaSnapshot.usedStorageQuota > quotaSnapshot.storageQuota
|
||||
) {
|
||||
readonlyReasons.push('storage_overflow');
|
||||
}
|
||||
|
||||
return {
|
||||
isTeamWorkspace,
|
||||
isReadonly: readonlyReasons.length > 0,
|
||||
readonlyReasons,
|
||||
canRecoverByRemovingMembers: readonlyReasons.includes('member_overflow'),
|
||||
canRecoverByDeletingBlobs: readonlyReasons.includes('storage_overflow'),
|
||||
usesFallbackOwnerQuota,
|
||||
};
|
||||
}
|
||||
|
||||
async reconcileOwnedWorkspaces(userId: string) {
|
||||
const workspaces = await this.models.workspaceUser.getUserActiveRoles(
|
||||
userId,
|
||||
{ role: WorkspaceRole.Owner }
|
||||
);
|
||||
|
||||
await Promise.all(
|
||||
workspaces.map(({ workspaceId }) =>
|
||||
this.reconcileWorkspaceQuotaState(workspaceId)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
async reconcileWorkspaceQuotaState(workspaceId: string) {
|
||||
const [state, isReadonlyFeatureEnabled] = await Promise.all([
|
||||
this.getWorkspaceState(workspaceId),
|
||||
this.models.workspaceFeature.has(workspaceId, READONLY_WORKSPACE_FEATURE),
|
||||
]);
|
||||
|
||||
if (state.isReadonly && !isReadonlyFeatureEnabled) {
|
||||
await this.models.workspaceFeature.add(
|
||||
workspaceId,
|
||||
READONLY_WORKSPACE_FEATURE,
|
||||
`workspace recovery mode: ${state.readonlyReasons.join(',')}`
|
||||
);
|
||||
} else if (!state.isReadonly && isReadonlyFeatureEnabled) {
|
||||
await this.models.workspaceFeature.remove(
|
||||
workspaceId,
|
||||
READONLY_WORKSPACE_FEATURE
|
||||
);
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
async isWorkspaceReadonly(workspaceId: string) {
|
||||
const hasReadonlyFeature = await this.models.workspaceFeature.has(
|
||||
workspaceId,
|
||||
READONLY_WORKSPACE_FEATURE
|
||||
);
|
||||
|
||||
if (!hasReadonlyFeature) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const state = await this.getWorkspaceState(workspaceId);
|
||||
if (!state.isReadonly) {
|
||||
await this.models.workspaceFeature.remove(
|
||||
workspaceId,
|
||||
READONLY_WORKSPACE_FEATURE
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
async applyWorkspacePermissions(
|
||||
workspaceId: string,
|
||||
permissions: WorkspaceActionPermissions
|
||||
) {
|
||||
if (!(await this.isWorkspaceReadonly(workspaceId))) {
|
||||
return permissions;
|
||||
}
|
||||
|
||||
const next = { ...permissions };
|
||||
READONLY_WORKSPACE_ACTIONS.forEach(action => {
|
||||
next[action] = false;
|
||||
});
|
||||
return next;
|
||||
}
|
||||
|
||||
async applyDocPermissions(
|
||||
workspaceId: string,
|
||||
permissions: DocActionPermissions
|
||||
) {
|
||||
if (!(await this.isWorkspaceReadonly(workspaceId))) {
|
||||
return permissions;
|
||||
}
|
||||
|
||||
const next = { ...permissions };
|
||||
READONLY_DOC_ACTIONS.forEach(action => {
|
||||
next[action] = false;
|
||||
});
|
||||
return next;
|
||||
}
|
||||
|
||||
async assertWorkspaceActionAllowed(
|
||||
workspaceId: string,
|
||||
action: WorkspaceAction
|
||||
) {
|
||||
if (
|
||||
READONLY_WORKSPACE_ACTIONS.includes(action) &&
|
||||
(await this.isWorkspaceReadonly(workspaceId))
|
||||
) {
|
||||
throw new SpaceAccessDenied({ spaceId: workspaceId });
|
||||
}
|
||||
}
|
||||
|
||||
async assertDocActionAllowed(
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
action: DocAction
|
||||
) {
|
||||
if (
|
||||
READONLY_DOC_ACTIONS.includes(action) &&
|
||||
(await this.isWorkspaceReadonly(workspaceId))
|
||||
) {
|
||||
throw new DocActionDenied({
|
||||
action,
|
||||
docId,
|
||||
spaceId: workspaceId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async assertWorkspaceRoleAction(
|
||||
userId: string,
|
||||
workspaceId: string,
|
||||
action: WorkspaceAction
|
||||
) {
|
||||
const checker = getAccessController(
|
||||
'ws'
|
||||
) as unknown as WorkspaceRoleChecker;
|
||||
const role = await checker.getRole({ userId, workspaceId });
|
||||
const permissions = mapWorkspaceRoleToPermissions(role);
|
||||
|
||||
if (!permissions[action]) {
|
||||
throw new SpaceAccessDenied({ spaceId: workspaceId });
|
||||
}
|
||||
}
|
||||
|
||||
async assertDocRoleAction(
|
||||
userId: string,
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
action: DocAction
|
||||
) {
|
||||
const checker = getAccessController(
|
||||
'ws'
|
||||
) as unknown as WorkspaceRoleChecker;
|
||||
const [role] = await checker.docRoles({ userId, workspaceId }, [docId]);
|
||||
|
||||
if (!role?.permissions[action]) {
|
||||
throw new DocActionDenied({
|
||||
action,
|
||||
docId,
|
||||
spaceId: workspaceId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async assertCanUploadBlob(workspaceId: string) {
|
||||
await this.assertWorkspaceActionAllowed(
|
||||
workspaceId,
|
||||
'Workspace.Blobs.Write'
|
||||
);
|
||||
}
|
||||
|
||||
async assertCanDeleteBlob(userId: string, workspaceId: string) {
|
||||
await this.assertWorkspaceRoleAction(
|
||||
userId,
|
||||
workspaceId,
|
||||
'Workspace.Blobs.Write'
|
||||
);
|
||||
}
|
||||
|
||||
async assertCanInviteMembers(workspaceId: string) {
|
||||
if (await this.isWorkspaceReadonly(workspaceId)) {
|
||||
throw new SpaceAccessDenied({ spaceId: workspaceId });
|
||||
}
|
||||
}
|
||||
|
||||
async assertCanRevokeMember(
|
||||
userId: string,
|
||||
workspaceId: string,
|
||||
role: WorkspaceRole
|
||||
) {
|
||||
await this.assertWorkspaceRoleAction(
|
||||
userId,
|
||||
workspaceId,
|
||||
role === WorkspaceRole.Admin
|
||||
? 'Workspace.Administrators.Manage'
|
||||
: 'Workspace.Users.Manage'
|
||||
);
|
||||
}
|
||||
|
||||
async assertCanPublishDoc(workspaceId: string, docId: string) {
|
||||
await this.assertDocActionAllowed(workspaceId, docId, 'Doc.Publish');
|
||||
}
|
||||
|
||||
async assertCanUnpublishDoc(
|
||||
userId: string,
|
||||
workspaceId: string,
|
||||
docId: string
|
||||
) {
|
||||
await this.assertDocRoleAction(userId, workspaceId, docId, 'Doc.Publish');
|
||||
}
|
||||
|
||||
@OnEvent('workspace.members.updated')
|
||||
async onWorkspaceMembersUpdated({
|
||||
workspaceId,
|
||||
}: Events['workspace.members.updated']) {
|
||||
await this.reconcileWorkspaceQuotaState(workspaceId);
|
||||
}
|
||||
|
||||
@OnEvent('workspace.owner.changed')
|
||||
async onWorkspaceOwnerChanged({
|
||||
workspaceId,
|
||||
}: Events['workspace.owner.changed']) {
|
||||
await this.reconcileWorkspaceQuotaState(workspaceId);
|
||||
}
|
||||
|
||||
@OnEvent('workspace.blobs.updated')
|
||||
async onWorkspaceBlobsUpdated({
|
||||
workspaceId,
|
||||
}: Events['workspace.blobs.updated']) {
|
||||
await this.reconcileWorkspaceQuotaState(workspaceId);
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ import { Injectable } from '@nestjs/common';
|
||||
import { SpaceAccessDenied } from '../../base';
|
||||
import { DocRole, Models } from '../../models';
|
||||
import { AccessController } from './controller';
|
||||
import { WorkspacePolicyService } from './policy';
|
||||
import type { Resource } from './resource';
|
||||
import {
|
||||
fixupDocRole,
|
||||
@@ -17,7 +18,10 @@ import {
|
||||
export class WorkspaceAccessController extends AccessController<'ws'> {
|
||||
protected readonly type = 'ws';
|
||||
|
||||
constructor(private readonly models: Models) {
|
||||
constructor(
|
||||
private readonly models: Models,
|
||||
private readonly policy: WorkspacePolicyService
|
||||
) {
|
||||
super();
|
||||
}
|
||||
|
||||
@@ -37,7 +41,10 @@ export class WorkspaceAccessController extends AccessController<'ws'> {
|
||||
|
||||
return {
|
||||
role,
|
||||
permissions: mapWorkspaceRoleToPermissions(role),
|
||||
permissions: await this.policy.applyWorkspacePermissions(
|
||||
resource.workspaceId,
|
||||
mapWorkspaceRoleToPermissions(role)
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { PermissionModule } from '../permission';
|
||||
import { StorageModule } from '../storage';
|
||||
import { QuotaResolver } from './resolver';
|
||||
import { QuotaService } from './service';
|
||||
@@ -12,7 +11,7 @@ import { QuotaService } from './service';
|
||||
* - quota statistics
|
||||
*/
|
||||
@Module({
|
||||
imports: [StorageModule, PermissionModule],
|
||||
imports: [StorageModule],
|
||||
providers: [QuotaService, QuotaResolver],
|
||||
exports: [QuotaService],
|
||||
})
|
||||
|
||||
@@ -20,7 +20,10 @@ type UserQuotaWithUsage = Omit<UserQuotaType, 'humanReadable'>;
|
||||
type WorkspaceQuota = Omit<BaseWorkspaceQuota, 'seatQuota'> & {
|
||||
ownerQuota?: string;
|
||||
};
|
||||
type WorkspaceQuotaWithUsage = Omit<WorkspaceQuotaType, 'humanReadable'>;
|
||||
export type WorkspaceQuotaWithUsage = Omit<
|
||||
WorkspaceQuotaType,
|
||||
'humanReadable'
|
||||
> & { ownerQuota?: string };
|
||||
|
||||
@Injectable()
|
||||
export class QuotaService {
|
||||
|
||||
@@ -26,6 +26,9 @@ declare global {
|
||||
workspaceId: string;
|
||||
key: string;
|
||||
};
|
||||
'workspace.blobs.updated': {
|
||||
workspaceId: string;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,6 +258,9 @@ export class WorkspaceBlobStorage {
|
||||
await this.provider.delete(`${workspaceId}/${key}`);
|
||||
}
|
||||
await this.models.blob.delete(workspaceId, key, permanently);
|
||||
if (!permanently) {
|
||||
await this.event.emitAsync('workspace.blobs.updated', { workspaceId });
|
||||
}
|
||||
}
|
||||
|
||||
async release(workspaceId: string) {
|
||||
@@ -270,6 +276,8 @@ export class WorkspaceBlobStorage {
|
||||
this.logger.log(
|
||||
`released ${deletedBlobs.length} blobs for workspace ${workspaceId}`
|
||||
);
|
||||
|
||||
await this.event.emitAsync('workspace.blobs.updated', { workspaceId });
|
||||
}
|
||||
|
||||
async totalSize(workspaceId: string) {
|
||||
|
||||
@@ -624,6 +624,7 @@ export class SpaceSyncGateway
|
||||
const { spaceType, spaceId, docId, update } = message;
|
||||
const adapter = this.selectAdapter(client, spaceType);
|
||||
|
||||
// Quota recovery mode is intentionally not applied to sync in this phase.
|
||||
// TODO(@forehalo): enable after frontend supporting doc revert
|
||||
// await this.ac.user(user.id).doc(spaceId, docId).assert('Doc.Update');
|
||||
const timestamp = await adapter.push(
|
||||
|
||||
@@ -25,7 +25,7 @@ import {
|
||||
} from '../../../base';
|
||||
import { Models } from '../../../models';
|
||||
import { CurrentUser } from '../../auth';
|
||||
import { AccessController } from '../../permission';
|
||||
import { AccessController, WorkspacePolicyService } from '../../permission';
|
||||
import { QuotaService } from '../../quota';
|
||||
import { WorkspaceBlobStorage } from '../../storage';
|
||||
import {
|
||||
@@ -126,6 +126,7 @@ export class WorkspaceBlobResolver {
|
||||
logger = new Logger(WorkspaceBlobResolver.name);
|
||||
constructor(
|
||||
private readonly ac: AccessController,
|
||||
private readonly policy: WorkspacePolicyService,
|
||||
private readonly quota: QuotaService,
|
||||
private readonly storage: WorkspaceBlobStorage,
|
||||
private readonly models: Models
|
||||
@@ -466,10 +467,7 @@ export class WorkspaceBlobResolver {
|
||||
return false;
|
||||
}
|
||||
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.assert('Workspace.Blobs.Write');
|
||||
await this.policy.assertCanDeleteBlob(user.id, workspaceId);
|
||||
|
||||
await this.storage.delete(workspaceId, key, permanently);
|
||||
|
||||
@@ -481,10 +479,7 @@ export class WorkspaceBlobResolver {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('workspaceId') workspaceId: string
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.assert('Workspace.Blobs.Write');
|
||||
await this.policy.assertCanDeleteBlob(user.id, workspaceId);
|
||||
|
||||
await this.storage.release(workspaceId);
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ import {
|
||||
DOC_ACTIONS,
|
||||
DocAction,
|
||||
DocRole,
|
||||
WorkspacePolicyService,
|
||||
} from '../../permission';
|
||||
import { PublicUserType, WorkspaceUserType } from '../../user';
|
||||
import { WorkspaceType } from '../types';
|
||||
@@ -295,6 +296,7 @@ export class WorkspaceDocResolver {
|
||||
*/
|
||||
private readonly prisma: PrismaClient,
|
||||
private readonly ac: AccessController,
|
||||
private readonly policy: WorkspacePolicyService,
|
||||
private readonly models: Models,
|
||||
private readonly cache: Cache
|
||||
) {}
|
||||
@@ -437,7 +439,7 @@ export class WorkspaceDocResolver {
|
||||
throw new ExpectToRevokePublicDoc('Expect doc not to be workspace');
|
||||
}
|
||||
|
||||
await this.ac.user(user.id).doc(workspaceId, docId).assert('Doc.Publish');
|
||||
await this.policy.assertCanUnpublishDoc(user.id, workspaceId, docId);
|
||||
|
||||
const doc = await this.models.doc.unpublish(workspaceId, docId);
|
||||
|
||||
|
||||
@@ -36,7 +36,11 @@ import {
|
||||
} from '../../../base';
|
||||
import { Models } from '../../../models';
|
||||
import { CurrentUser, Public } from '../../auth';
|
||||
import { AccessController, WorkspaceRole } from '../../permission';
|
||||
import {
|
||||
AccessController,
|
||||
WorkspacePolicyService,
|
||||
WorkspaceRole,
|
||||
} from '../../permission';
|
||||
import { QuotaService } from '../../quota';
|
||||
import { UserType } from '../../user';
|
||||
import { validators } from '../../utils/validators';
|
||||
@@ -64,6 +68,7 @@ export class WorkspaceMemberResolver {
|
||||
private readonly ac: AccessController,
|
||||
private readonly models: Models,
|
||||
private readonly mutex: RequestMutex,
|
||||
private readonly policy: WorkspacePolicyService,
|
||||
private readonly workspaceService: WorkspaceService,
|
||||
private readonly quota: QuotaService
|
||||
) {}
|
||||
@@ -304,10 +309,11 @@ export class WorkspaceMemberResolver {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('workspaceId') workspaceId: string
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.assert('Workspace.Users.Manage');
|
||||
await this.policy.assertWorkspaceRoleAction(
|
||||
user.id,
|
||||
workspaceId,
|
||||
'Workspace.Users.Manage'
|
||||
);
|
||||
|
||||
const cacheId = `workspace:inviteLink:${workspaceId}`;
|
||||
return await this.cache.delete(cacheId);
|
||||
@@ -359,6 +365,7 @@ export class WorkspaceMemberResolver {
|
||||
role.id,
|
||||
me.id
|
||||
);
|
||||
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
@@ -453,14 +460,7 @@ export class WorkspaceMemberResolver {
|
||||
throw new MemberNotFoundInSpace({ spaceId: workspaceId });
|
||||
}
|
||||
|
||||
await this.ac
|
||||
.user(me.id)
|
||||
.workspace(workspaceId)
|
||||
.assert(
|
||||
role.type === WorkspaceRole.Admin
|
||||
? 'Workspace.Administrators.Manage'
|
||||
: 'Workspace.Users.Manage'
|
||||
);
|
||||
await this.policy.assertCanRevokeMember(me.id, workspaceId, role.type);
|
||||
|
||||
await this.models.workspaceUser.delete(workspaceId, userId);
|
||||
|
||||
@@ -480,6 +480,7 @@ export class WorkspaceMemberResolver {
|
||||
this.event.emit('workspace.members.updated', {
|
||||
workspaceId,
|
||||
});
|
||||
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -580,11 +581,20 @@ export class WorkspaceMemberResolver {
|
||||
this.event.emit('workspace.members.updated', {
|
||||
workspaceId,
|
||||
});
|
||||
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private async acceptInvitationByEmail(role: WorkspaceUserRole) {
|
||||
await this.policy.assertCanInviteMembers(role.workspaceId);
|
||||
|
||||
const hasSeat = await this.quota.tryCheckSeat(role.workspaceId, true);
|
||||
|
||||
if (!hasSeat) {
|
||||
throw new NoMoreSeat({ spaceId: role.workspaceId });
|
||||
}
|
||||
|
||||
await this.models.workspaceUser.setStatus(
|
||||
role.workspaceId,
|
||||
role.userId,
|
||||
@@ -596,6 +606,7 @@ export class WorkspaceMemberResolver {
|
||||
(await this.models.workspaceUser.getOwner(role.workspaceId)).id,
|
||||
role.id
|
||||
);
|
||||
await this.policy.reconcileWorkspaceQuotaState(role.workspaceId);
|
||||
}
|
||||
|
||||
private async acceptInvitationByLink(
|
||||
@@ -603,6 +614,8 @@ export class WorkspaceMemberResolver {
|
||||
workspaceId: string,
|
||||
inviterId: string
|
||||
) {
|
||||
await this.policy.assertCanInviteMembers(workspaceId);
|
||||
|
||||
let inviter = await this.models.user.getPublicUser(inviterId);
|
||||
if (!inviter) {
|
||||
inviter = await this.models.workspaceUser.getOwner(workspaceId);
|
||||
|
||||
@@ -53,6 +53,7 @@ export enum Feature {
|
||||
// workspace
|
||||
UnlimitedWorkspace = 'unlimited_workspace',
|
||||
TeamPlan = 'team_plan_v1',
|
||||
QuotaExceededReadonlyWorkspace = 'quota_exceeded_readonly_workspace_v1',
|
||||
}
|
||||
|
||||
// TODO(@forehalo): may merge `FeatureShapes` and `FeatureConfigs`?
|
||||
@@ -66,6 +67,7 @@ export const FeaturesShapes = {
|
||||
pro_plan_v1: UserPlanQuotaConfig,
|
||||
lifetime_pro_plan_v1: UserPlanQuotaConfig,
|
||||
team_plan_v1: WorkspaceQuotaConfig,
|
||||
quota_exceeded_readonly_workspace_v1: EMPTY_CONFIG,
|
||||
} satisfies Record<Feature, z.ZodObject<any>>;
|
||||
|
||||
export type UserFeatureName = keyof Pick<
|
||||
@@ -80,7 +82,9 @@ export type UserFeatureName = keyof Pick<
|
||||
>;
|
||||
export type WorkspaceFeatureName = keyof Pick<
|
||||
typeof FeaturesShapes,
|
||||
'unlimited_workspace' | 'team_plan_v1'
|
||||
| 'unlimited_workspace'
|
||||
| 'team_plan_v1'
|
||||
| 'quota_exceeded_readonly_workspace_v1'
|
||||
>;
|
||||
|
||||
export type FeatureName = UserFeatureName | WorkspaceFeatureName;
|
||||
@@ -162,6 +166,7 @@ export const FeatureConfigs: {
|
||||
team_plan_v1: TeamFeature,
|
||||
early_access: WhitelistFeature,
|
||||
unlimited_workspace: EmptyFeature,
|
||||
quota_exceeded_readonly_workspace_v1: EmptyFeature,
|
||||
unlimited_copilot: EmptyFeature,
|
||||
ai_early_access: EmptyFeature,
|
||||
administrator: EmptyFeature,
|
||||
|
||||
@@ -36,7 +36,8 @@ export class WorkspaceUserModel extends BaseModel {
|
||||
|
||||
/**
|
||||
* Set or update the [Owner] of a workspace.
|
||||
* The old [Owner] will be changed to [Admin] if there is already an [Owner].
|
||||
* The old [Owner] will be changed to [Admin] for team workspace and
|
||||
* [Collaborator] for owned workspace if there is already an [Owner].
|
||||
*/
|
||||
@Transactional()
|
||||
async setOwner(workspaceId: string, userId: string) {
|
||||
@@ -63,12 +64,18 @@ export class WorkspaceUserModel extends BaseModel {
|
||||
throw new NewOwnerIsNotActiveMember();
|
||||
}
|
||||
|
||||
const fallbackRole = (await this.models.workspace.isTeamWorkspace(
|
||||
workspaceId
|
||||
))
|
||||
? WorkspaceRole.Admin
|
||||
: WorkspaceRole.Collaborator;
|
||||
|
||||
await this.db.workspaceUserRole.update({
|
||||
where: {
|
||||
id: oldOwner.id,
|
||||
},
|
||||
data: {
|
||||
type: WorkspaceRole.Admin,
|
||||
type: fallbackRole,
|
||||
},
|
||||
});
|
||||
await this.db.workspaceUserRole.update({
|
||||
@@ -201,6 +208,25 @@ export class WorkspaceUserModel extends BaseModel {
|
||||
});
|
||||
}
|
||||
|
||||
async deleteNonAccepted(workspaceId: string) {
|
||||
return await this.db.workspaceUserRole.deleteMany({
|
||||
where: { workspaceId, status: { not: WorkspaceMemberStatus.Accepted } },
|
||||
});
|
||||
}
|
||||
|
||||
async demoteAcceptedAdmins(workspaceId: string) {
|
||||
return await this.db.workspaceUserRole.updateMany({
|
||||
where: {
|
||||
workspaceId,
|
||||
status: WorkspaceMemberStatus.Accepted,
|
||||
type: WorkspaceRole.Admin,
|
||||
},
|
||||
data: {
|
||||
type: WorkspaceRole.Collaborator,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async get(workspaceId: string, userId: string) {
|
||||
return await this.db.workspaceUserRole.findUnique({
|
||||
where: {
|
||||
|
||||
@@ -37,7 +37,10 @@ import {
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import { CurrentUser } from '../../../core/auth';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import {
|
||||
AccessController,
|
||||
WorkspacePolicyService,
|
||||
} from '../../../core/permission';
|
||||
import {
|
||||
ContextBlob,
|
||||
ContextCategories,
|
||||
@@ -408,6 +411,7 @@ export class CopilotContextRootResolver {
|
||||
export class CopilotContextResolver {
|
||||
constructor(
|
||||
private readonly ac: AccessController,
|
||||
private readonly policy: WorkspacePolicyService,
|
||||
private readonly models: Models,
|
||||
private readonly mutex: RequestMutex,
|
||||
private readonly context: CopilotContextService,
|
||||
@@ -667,6 +671,7 @@ export class CopilotContextResolver {
|
||||
const blobId = createHash('sha256').update(buffer).digest('base64url');
|
||||
const { filename, mimetype } = content;
|
||||
|
||||
await this.policy.assertCanUploadBlob(session.workspaceId);
|
||||
await this.storage.put(user.id, session.workspaceId, blobId, buffer);
|
||||
const file = await session.addFile(
|
||||
blobId,
|
||||
|
||||
@@ -258,7 +258,7 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
const model = this.selectModel(cond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
|
||||
// by default, image prompt assumes there is only one message
|
||||
const prompt = this.extractPrompt(messages[messages.length - 1]);
|
||||
@@ -283,7 +283,9 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
}
|
||||
return data.output;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -296,12 +298,16 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
const model = this.selectModel(cond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const result = await this.text(cond, messages, options);
|
||||
|
||||
yield result;
|
||||
} catch (e) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
@@ -319,7 +325,7 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
|
||||
// by default, image prompt assumes there is only one message
|
||||
const prompt = this.extractPrompt(
|
||||
@@ -376,7 +382,7 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
} catch (e) {
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_errors')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -664,7 +664,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const backendConfig = this.createNativeConfig();
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Structured);
|
||||
@@ -687,7 +687,9 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const validated = schema.parse(parsed);
|
||||
return JSON.stringify(validated);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -983,7 +985,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
|
||||
const { content: prompt, attachments } = [...messages].pop() || {};
|
||||
if (!prompt) throw new CopilotPromptInvalid('Prompt is required');
|
||||
@@ -1021,7 +1023,9 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
}
|
||||
return;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('generate_images_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('generate_images_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -470,7 +470,8 @@ export abstract class CopilotProvider<C = any> {
|
||||
});
|
||||
const searchDocs = buildDocKeywordSearchGetter(
|
||||
ac,
|
||||
indexerService
|
||||
indexerService,
|
||||
models
|
||||
);
|
||||
tools.doc_keyword_search = createDocKeywordSearchTool(
|
||||
searchDocs.bind(null, options)
|
||||
|
||||
@@ -37,7 +37,11 @@ import {
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { Admin } from '../../core/common';
|
||||
import { DocReader } from '../../core/doc';
|
||||
import { AccessController, DocAction } from '../../core/permission';
|
||||
import {
|
||||
AccessController,
|
||||
DocAction,
|
||||
WorkspacePolicyService,
|
||||
} from '../../core/permission';
|
||||
import { UserType } from '../../core/user';
|
||||
import type { ListSessionOptions, UpdateChatSession } from '../../models';
|
||||
import { processImage } from '../../native';
|
||||
@@ -378,6 +382,7 @@ export class CopilotResolver {
|
||||
constructor(
|
||||
private readonly ac: AccessController,
|
||||
private readonly mutex: RequestMutex,
|
||||
private readonly policy: WorkspacePolicyService,
|
||||
private readonly prompt: PromptService,
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly storage: CopilotStorage,
|
||||
@@ -778,6 +783,10 @@ export class CopilotResolver {
|
||||
delete options.blob;
|
||||
delete options.blobs;
|
||||
|
||||
if (blobs.length) {
|
||||
await this.policy.assertCanUploadBlob(workspaceId);
|
||||
}
|
||||
|
||||
for (const blob of blobs) {
|
||||
const uploaded = await this.storage.handleUpload(user.id, blob);
|
||||
const detectedMime =
|
||||
|
||||
@@ -18,7 +18,10 @@ export const buildBlobContentGetter = (
|
||||
chunk?: number
|
||||
) => {
|
||||
if (!options?.user || !options?.workspace || !blobId || !context) {
|
||||
return;
|
||||
return toolError(
|
||||
'Blob Read Failed',
|
||||
'Missing workspace, user, blob id, or copilot context for blob_read.'
|
||||
);
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
@@ -29,7 +32,10 @@ export const buildBlobContentGetter = (
|
||||
logger.warn(
|
||||
`User ${options.user} does not have access workspace ${options.workspace}`
|
||||
);
|
||||
return;
|
||||
return toolError(
|
||||
'Blob Read Failed',
|
||||
'You do not have permission to access this workspace attachment.'
|
||||
);
|
||||
}
|
||||
|
||||
const contextFile = context.files.find(
|
||||
@@ -42,7 +48,12 @@ export const buildBlobContentGetter = (
|
||||
context.getBlobContent(canonicalBlobId, chunk),
|
||||
]);
|
||||
const content = file?.trim() || blob?.trim();
|
||||
if (!content) return;
|
||||
if (!content) {
|
||||
return toolError(
|
||||
'Blob Read Failed',
|
||||
`Attachment ${canonicalBlobId} is not available for reading in the current copilot context.`
|
||||
);
|
||||
}
|
||||
const info = contextFile
|
||||
? { fileName: contextFile.name, fileType: contextFile.mimeType }
|
||||
: {};
|
||||
@@ -53,10 +64,7 @@ export const buildBlobContentGetter = (
|
||||
};
|
||||
|
||||
export const createBlobReadTool = (
|
||||
getBlobContent: (
|
||||
targetId?: string,
|
||||
chunk?: number
|
||||
) => Promise<object | undefined>
|
||||
getBlobContent: (targetId?: string, chunk?: number) => Promise<object>
|
||||
) => {
|
||||
return defineTool({
|
||||
description:
|
||||
@@ -73,13 +81,10 @@ export const createBlobReadTool = (
|
||||
execute: async ({ blob_id, chunk }) => {
|
||||
try {
|
||||
const blob = await getBlobContent(blob_id, chunk);
|
||||
if (!blob) {
|
||||
return;
|
||||
}
|
||||
return { ...blob };
|
||||
} catch (err: any) {
|
||||
logger.error(`Failed to read the blob ${blob_id} in context`, err);
|
||||
return toolError('Blob Read Failed', err.message);
|
||||
return toolError('Blob Read Failed', err.message ?? String(err));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,27 +1,43 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { AccessController } from '../../../core/permission';
|
||||
import type { Models } from '../../../models';
|
||||
import type { IndexerService, SearchDoc } from '../../indexer';
|
||||
import { workspaceSyncRequiredError } from './doc-sync';
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotChatOptions } from './types';
|
||||
|
||||
export const buildDocKeywordSearchGetter = (
|
||||
ac: AccessController,
|
||||
indexerService: IndexerService
|
||||
indexerService: IndexerService,
|
||||
models: Models
|
||||
) => {
|
||||
const searchDocs = async (options: CopilotChatOptions, query?: string) => {
|
||||
if (!options || !query?.trim() || !options.user || !options.workspace) {
|
||||
return undefined;
|
||||
const queryTrimmed = query?.trim();
|
||||
if (!options || !queryTrimmed || !options.user || !options.workspace) {
|
||||
return toolError(
|
||||
'Doc Keyword Search Failed',
|
||||
'Missing workspace, user, or query for doc_keyword_search.'
|
||||
);
|
||||
}
|
||||
const workspace = await models.workspace.get(options.workspace);
|
||||
if (!workspace) {
|
||||
return workspaceSyncRequiredError();
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.can('Workspace.Read');
|
||||
if (!canAccess) return undefined;
|
||||
if (!canAccess) {
|
||||
return toolError(
|
||||
'Doc Keyword Search Failed',
|
||||
'You do not have permission to access this workspace.'
|
||||
);
|
||||
}
|
||||
const docs = await indexerService.searchDocsByKeyword(
|
||||
options.workspace,
|
||||
query
|
||||
queryTrimmed
|
||||
);
|
||||
|
||||
// filter current user readable docs
|
||||
@@ -29,13 +45,15 @@ export const buildDocKeywordSearchGetter = (
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.docs(docs, 'Doc.Read');
|
||||
return readableDocs;
|
||||
return readableDocs ?? [];
|
||||
};
|
||||
return searchDocs;
|
||||
};
|
||||
|
||||
export const createDocKeywordSearchTool = (
|
||||
searchDocs: (query: string) => Promise<SearchDoc[] | undefined>
|
||||
searchDocs: (
|
||||
query: string
|
||||
) => Promise<SearchDoc[] | ReturnType<typeof toolError>>
|
||||
) => {
|
||||
return defineTool({
|
||||
description:
|
||||
@@ -50,8 +68,8 @@ export const createDocKeywordSearchTool = (
|
||||
execute: async ({ query }) => {
|
||||
try {
|
||||
const docs = await searchDocs(query);
|
||||
if (!docs) {
|
||||
return;
|
||||
if (!Array.isArray(docs)) {
|
||||
return docs;
|
||||
}
|
||||
return docs.map(doc => ({
|
||||
docId: doc.docId,
|
||||
|
||||
@@ -3,13 +3,20 @@ import { z } from 'zod';
|
||||
|
||||
import { DocReader } from '../../../core/doc';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { Models, publicUserSelect } from '../../../models';
|
||||
import { toolError } from './error';
|
||||
import { Models } from '../../../models';
|
||||
import {
|
||||
documentSyncPendingError,
|
||||
workspaceSyncRequiredError,
|
||||
} from './doc-sync';
|
||||
import { type ToolError, toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotChatOptions } from './types';
|
||||
|
||||
const logger = new Logger('DocReadTool');
|
||||
|
||||
const isToolError = (result: ToolError | object): result is ToolError =>
|
||||
'type' in result && result.type === 'error';
|
||||
|
||||
export const buildDocContentGetter = (
|
||||
ac: AccessController,
|
||||
docReader: DocReader,
|
||||
@@ -17,8 +24,17 @@ export const buildDocContentGetter = (
|
||||
) => {
|
||||
const getDoc = async (options: CopilotChatOptions, docId?: string) => {
|
||||
if (!options?.user || !options?.workspace || !docId) {
|
||||
return;
|
||||
return toolError(
|
||||
'Doc Read Failed',
|
||||
'Missing workspace, user, or document id for doc_read.'
|
||||
);
|
||||
}
|
||||
|
||||
const workspace = await models.workspace.get(options.workspace);
|
||||
if (!workspace) {
|
||||
return workspaceSyncRequiredError();
|
||||
}
|
||||
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
@@ -28,23 +44,15 @@ export const buildDocContentGetter = (
|
||||
logger.warn(
|
||||
`User ${options.user} does not have access to doc ${docId} in workspace ${options.workspace}`
|
||||
);
|
||||
return;
|
||||
return toolError(
|
||||
'Doc Read Failed',
|
||||
`You do not have permission to read document ${docId} in this workspace.`
|
||||
);
|
||||
}
|
||||
|
||||
const docMeta = await models.doc.getSnapshot(options.workspace, docId, {
|
||||
select: {
|
||||
createdAt: true,
|
||||
updatedAt: true,
|
||||
createdByUser: {
|
||||
select: publicUserSelect,
|
||||
},
|
||||
updatedByUser: {
|
||||
select: publicUserSelect,
|
||||
},
|
||||
},
|
||||
});
|
||||
const docMeta = await models.doc.getAuthors(options.workspace, docId);
|
||||
if (!docMeta) {
|
||||
return;
|
||||
return documentSyncPendingError(docId);
|
||||
}
|
||||
|
||||
const content = await docReader.getDocMarkdown(
|
||||
@@ -53,7 +61,7 @@ export const buildDocContentGetter = (
|
||||
true
|
||||
);
|
||||
if (!content) {
|
||||
return;
|
||||
return documentSyncPendingError(docId);
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -69,8 +77,12 @@ export const buildDocContentGetter = (
|
||||
return getDoc;
|
||||
};
|
||||
|
||||
type DocReadToolResult = Awaited<
|
||||
ReturnType<ReturnType<typeof buildDocContentGetter>>
|
||||
>;
|
||||
|
||||
export const createDocReadTool = (
|
||||
getDoc: (targetId?: string) => Promise<object | undefined>
|
||||
getDoc: (targetId?: string) => Promise<DocReadToolResult>
|
||||
) => {
|
||||
return defineTool({
|
||||
description:
|
||||
@@ -81,13 +93,10 @@ export const createDocReadTool = (
|
||||
execute: async ({ doc_id }) => {
|
||||
try {
|
||||
const doc = await getDoc(doc_id);
|
||||
if (!doc) {
|
||||
return;
|
||||
}
|
||||
return { ...doc };
|
||||
return isToolError(doc) ? doc : { ...doc };
|
||||
} catch (err: any) {
|
||||
logger.error(`Failed to read the doc ${doc_id}`, err);
|
||||
return toolError('Doc Read Failed', err.message);
|
||||
return toolError('Doc Read Failed', err.message ?? String(err));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
clearEmbeddingChunk,
|
||||
type Models,
|
||||
} from '../../../models';
|
||||
import { workspaceSyncRequiredError } from './doc-sync';
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type {
|
||||
@@ -27,14 +28,24 @@ export const buildDocSearchGetter = (
|
||||
signal?: AbortSignal
|
||||
) => {
|
||||
if (!options || !query?.trim() || !options.user || !options.workspace) {
|
||||
return `Invalid search parameters.`;
|
||||
return toolError(
|
||||
'Doc Semantic Search Failed',
|
||||
'Missing workspace, user, or query for doc_semantic_search.'
|
||||
);
|
||||
}
|
||||
const workspace = await models.workspace.get(options.workspace);
|
||||
if (!workspace) {
|
||||
return workspaceSyncRequiredError();
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.can('Workspace.Read');
|
||||
if (!canAccess)
|
||||
return 'You do not have permission to access this workspace.';
|
||||
return toolError(
|
||||
'Doc Semantic Search Failed',
|
||||
'You do not have permission to access this workspace.'
|
||||
);
|
||||
const [chunks, contextChunks] = await Promise.all([
|
||||
context.matchWorkspaceAll(options.workspace, query, 10, signal),
|
||||
docContext?.matchFiles(query, 10, signal) ?? [],
|
||||
@@ -53,7 +64,7 @@ export const buildDocSearchGetter = (
|
||||
fileChunks.push(...contextChunks);
|
||||
}
|
||||
if (!blobChunks.length && !docChunks.length && !fileChunks.length) {
|
||||
return `No results found for "${query}".`;
|
||||
return [];
|
||||
}
|
||||
|
||||
const docIds = docChunks.map(c => ({
|
||||
@@ -101,7 +112,7 @@ export const createDocSemanticSearchTool = (
|
||||
searchDocs: (
|
||||
query: string,
|
||||
signal?: AbortSignal
|
||||
) => Promise<ChunkSimilarity[] | string | undefined>
|
||||
) => Promise<ChunkSimilarity[] | ReturnType<typeof toolError>>
|
||||
) => {
|
||||
return defineTool({
|
||||
description:
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
import { toolError } from './error';
|
||||
|
||||
export const LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE =
|
||||
'This workspace is local-only and does not have AFFiNE Cloud sync enabled yet. Ask the user to enable workspace sync, then try again.';
|
||||
|
||||
export const DOCUMENT_SYNC_PENDING_MESSAGE = (docId: string) =>
|
||||
`Document ${docId} is not available on AFFiNE Cloud yet. Ask the user to wait for workspace sync to finish, then try again.`;
|
||||
|
||||
export const workspaceSyncRequiredError = () =>
|
||||
toolError('Workspace Sync Required', LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE);
|
||||
|
||||
export const documentSyncPendingError = (docId: string) =>
|
||||
toolError('Document Sync Pending', DOCUMENT_SYNC_PENDING_MESSAGE(docId));
|
||||
@@ -7,7 +7,8 @@ import { defineTool } from './tool';
|
||||
|
||||
export const createExaSearchTool = (config: Config) => {
|
||||
return defineTool({
|
||||
description: 'Search the web for information',
|
||||
description:
|
||||
'Search the web using Exa, one of the best web search APIs for AI',
|
||||
inputSchema: z.object({
|
||||
query: z.string().describe('The query to search the web for.'),
|
||||
mode: z
|
||||
|
||||
@@ -24,7 +24,10 @@ import {
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import { CurrentUser } from '../../../core/auth';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import {
|
||||
AccessController,
|
||||
WorkspacePolicyService,
|
||||
} from '../../../core/permission';
|
||||
import { WorkspaceType } from '../../../core/workspaces';
|
||||
import { COPILOT_LOCKER } from '../resolver';
|
||||
import { MAX_EMBEDDABLE_SIZE } from '../utils';
|
||||
@@ -72,6 +75,7 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
|
||||
constructor(
|
||||
private readonly ac: AccessController,
|
||||
private readonly mutex: Mutex,
|
||||
private readonly policy: WorkspacePolicyService,
|
||||
private readonly copilotWorkspace: CopilotWorkspaceService
|
||||
) {}
|
||||
|
||||
@@ -215,10 +219,11 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
|
||||
@Args('fileId', { type: () => String })
|
||||
fileId: string
|
||||
): Promise<boolean> {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.assert('Workspace.Settings.Update');
|
||||
await this.policy.assertWorkspaceRoleAction(
|
||||
user.id,
|
||||
workspaceId,
|
||||
'Workspace.Settings.Update'
|
||||
);
|
||||
|
||||
return await this.copilotWorkspace.removeFile(workspaceId, fileId);
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
UserFriendlyError,
|
||||
WorkspaceLicenseAlreadyExists,
|
||||
} from '../../base';
|
||||
import { WorkspacePolicyService } from '../../core/permission';
|
||||
import { Models } from '../../models';
|
||||
import {
|
||||
SubscriptionPlan,
|
||||
@@ -59,7 +60,8 @@ export class LicenseService {
|
||||
private readonly db: PrismaClient,
|
||||
private readonly event: EventBus,
|
||||
private readonly models: Models,
|
||||
private readonly crypto: CryptoHelper
|
||||
private readonly crypto: CryptoHelper,
|
||||
private readonly policy: WorkspacePolicyService
|
||||
) {}
|
||||
|
||||
@OnEvent('workspace.subscription.activated')
|
||||
@@ -83,6 +85,7 @@ export class LicenseService {
|
||||
workspaceId,
|
||||
quantity,
|
||||
});
|
||||
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -96,7 +99,10 @@ export class LicenseService {
|
||||
}: Events['workspace.subscription.canceled']) {
|
||||
switch (plan) {
|
||||
case SubscriptionPlan.SelfHostedTeam:
|
||||
await this.models.workspaceUser.deleteNonAccepted(workspaceId);
|
||||
await this.models.workspaceUser.demoteAcceptedAdmins(workspaceId);
|
||||
await this.models.workspaceFeature.remove(workspaceId, 'team_plan_v1');
|
||||
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { URLHelper } from '../../../base';
|
||||
import { InvalidOauthResponse, URLHelper } from '../../../base';
|
||||
import { OAuthProviderName } from '../config';
|
||||
import type { OAuthState } from '../types';
|
||||
import { OAuthAccount, OAuthProvider, Tokens } from './def';
|
||||
@@ -13,11 +13,17 @@ interface AuthTokenResponse {
|
||||
|
||||
export interface UserInfo {
|
||||
login: string;
|
||||
email: string;
|
||||
email: string | null;
|
||||
avatar_url: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
interface UserEmailInfo {
|
||||
email: string;
|
||||
primary: boolean;
|
||||
verified: boolean;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class GithubOAuthProvider extends OAuthProvider {
|
||||
provider = OAuthProviderName.GitHub;
|
||||
@@ -30,7 +36,7 @@ export class GithubOAuthProvider extends OAuthProvider {
|
||||
return `https://github.com/login/oauth/authorize?${this.url.stringify({
|
||||
client_id: this.config.clientId,
|
||||
redirect_uri: this.url.link('/oauth/callback'),
|
||||
scope: 'user',
|
||||
scope: 'read:user user:email',
|
||||
...this.config.args,
|
||||
state,
|
||||
})}`;
|
||||
@@ -56,16 +62,36 @@ export class GithubOAuthProvider extends OAuthProvider {
|
||||
async getUser(tokens: Tokens, _state: OAuthState): Promise<OAuthAccount> {
|
||||
const user = await this.fetchJson<UserInfo>('https://api.github.com/user', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${tokens.accessToken}`,
|
||||
},
|
||||
headers: { Authorization: `Bearer ${tokens.accessToken}` },
|
||||
});
|
||||
|
||||
const email = user.email ?? (await this.getVerifiedEmail(tokens));
|
||||
if (!email) {
|
||||
throw new InvalidOauthResponse({
|
||||
reason: 'GitHub account did not have a verified email address.',
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
id: user.login,
|
||||
avatarUrl: user.avatar_url,
|
||||
email: user.email,
|
||||
email,
|
||||
name: user.name,
|
||||
};
|
||||
}
|
||||
|
||||
private async getVerifiedEmail(tokens: Tokens) {
|
||||
const emails = await this.fetchJson<UserEmailInfo[]>(
|
||||
'https://api.github.com/user/emails',
|
||||
{
|
||||
method: 'GET',
|
||||
headers: { Authorization: `Bearer ${tokens.accessToken}` },
|
||||
}
|
||||
);
|
||||
|
||||
return (
|
||||
emails.find(email => email.primary && email.verified)?.email ??
|
||||
emails.find(email => email.verified)?.email
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, OnModuleDestroy } from '@nestjs/common';
|
||||
import { createRemoteJWKSet, type JWTPayload, jwtVerify } from 'jose';
|
||||
import { omit } from 'lodash-es';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
ExponentialBackoffScheduler,
|
||||
InvalidAuthState,
|
||||
InvalidOauthResponse,
|
||||
URLHelper,
|
||||
@@ -35,7 +36,7 @@ const OIDCUserInfoSchema = z
|
||||
.object({
|
||||
sub: z.string(),
|
||||
preferred_username: z.string().optional(),
|
||||
email: z.string().email(),
|
||||
email: z.string().optional(),
|
||||
name: z.string().optional(),
|
||||
email_verified: z
|
||||
.union([z.boolean(), z.enum(['true', 'false', '1', '0', 'yes', 'no'])])
|
||||
@@ -44,6 +45,8 @@ const OIDCUserInfoSchema = z
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
const OIDCEmailSchema = z.string().email();
|
||||
|
||||
const OIDCConfigurationSchema = z.object({
|
||||
authorization_endpoint: z.string().url(),
|
||||
token_endpoint: z.string().url(),
|
||||
@@ -54,16 +57,28 @@ const OIDCConfigurationSchema = z.object({
|
||||
|
||||
type OIDCConfiguration = z.infer<typeof OIDCConfigurationSchema>;
|
||||
|
||||
const OIDC_DISCOVERY_INITIAL_RETRY_DELAY = 1000;
|
||||
const OIDC_DISCOVERY_MAX_RETRY_DELAY = 60_000;
|
||||
|
||||
@Injectable()
|
||||
export class OIDCProvider extends OAuthProvider {
|
||||
export class OIDCProvider extends OAuthProvider implements OnModuleDestroy {
|
||||
override provider = OAuthProviderName.OIDC;
|
||||
#endpoints: OIDCConfiguration | null = null;
|
||||
#jwks: ReturnType<typeof createRemoteJWKSet> | null = null;
|
||||
readonly #retryScheduler = new ExponentialBackoffScheduler({
|
||||
baseDelayMs: OIDC_DISCOVERY_INITIAL_RETRY_DELAY,
|
||||
maxDelayMs: OIDC_DISCOVERY_MAX_RETRY_DELAY,
|
||||
});
|
||||
#validationGeneration = 0;
|
||||
|
||||
constructor(private readonly url: URLHelper) {
|
||||
super();
|
||||
}
|
||||
|
||||
onModuleDestroy() {
|
||||
this.#retryScheduler.clear();
|
||||
}
|
||||
|
||||
override get requiresPkce() {
|
||||
return true;
|
||||
}
|
||||
@@ -87,58 +102,109 @@ export class OIDCProvider extends OAuthProvider {
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
const validate = async () => {
|
||||
this.#endpoints = null;
|
||||
this.#jwks = null;
|
||||
const generation = ++this.#validationGeneration;
|
||||
this.#retryScheduler.clear();
|
||||
|
||||
if (super.configured) {
|
||||
const config = this.config as OAuthOIDCProviderConfig;
|
||||
if (!config.issuer) {
|
||||
this.logger.error('Missing OIDC issuer configuration');
|
||||
super.setup();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${config.issuer}/.well-known/openid-configuration`,
|
||||
{
|
||||
method: 'GET',
|
||||
headers: { Accept: 'application/json' },
|
||||
}
|
||||
);
|
||||
|
||||
if (res.ok) {
|
||||
const configuration = OIDCConfigurationSchema.parse(
|
||||
await res.json()
|
||||
);
|
||||
if (
|
||||
this.normalizeIssuer(config.issuer) !==
|
||||
this.normalizeIssuer(configuration.issuer)
|
||||
) {
|
||||
this.logger.error(
|
||||
`OIDC issuer mismatch, expected ${config.issuer}, got ${configuration.issuer}`
|
||||
);
|
||||
} else {
|
||||
this.#endpoints = configuration;
|
||||
this.#jwks = createRemoteJWKSet(new URL(configuration.jwks_uri));
|
||||
}
|
||||
} else {
|
||||
this.logger.error(`Invalid OIDC issuer ${config.issuer}`);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Failed to validate OIDC configuration', e);
|
||||
}
|
||||
}
|
||||
|
||||
super.setup();
|
||||
};
|
||||
|
||||
validate().catch(() => {
|
||||
this.validateAndSync(generation).catch(() => {
|
||||
/* noop */
|
||||
});
|
||||
}
|
||||
|
||||
private async validateAndSync(generation: number) {
|
||||
if (generation !== this.#validationGeneration) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!super.configured) {
|
||||
this.resetState();
|
||||
this.#retryScheduler.reset();
|
||||
super.setup();
|
||||
return;
|
||||
}
|
||||
|
||||
const config = this.config as OAuthOIDCProviderConfig;
|
||||
if (!config.issuer) {
|
||||
this.logger.error('Missing OIDC issuer configuration');
|
||||
this.resetState();
|
||||
this.#retryScheduler.reset();
|
||||
super.setup();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${config.issuer}/.well-known/openid-configuration`,
|
||||
{
|
||||
method: 'GET',
|
||||
headers: { Accept: 'application/json' },
|
||||
}
|
||||
);
|
||||
|
||||
if (generation !== this.#validationGeneration) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!res.ok) {
|
||||
this.logger.error(`Invalid OIDC issuer ${config.issuer}`);
|
||||
this.onValidationFailure(generation);
|
||||
return;
|
||||
}
|
||||
|
||||
const configuration = OIDCConfigurationSchema.parse(await res.json());
|
||||
if (
|
||||
this.normalizeIssuer(config.issuer) !==
|
||||
this.normalizeIssuer(configuration.issuer)
|
||||
) {
|
||||
this.logger.error(
|
||||
`OIDC issuer mismatch, expected ${config.issuer}, got ${configuration.issuer}`
|
||||
);
|
||||
this.onValidationFailure(generation);
|
||||
return;
|
||||
}
|
||||
|
||||
this.#endpoints = configuration;
|
||||
this.#jwks = createRemoteJWKSet(new URL(configuration.jwks_uri));
|
||||
this.#retryScheduler.reset();
|
||||
super.setup();
|
||||
} catch (e) {
|
||||
if (generation !== this.#validationGeneration) {
|
||||
return;
|
||||
}
|
||||
this.logger.error('Failed to validate OIDC configuration', e);
|
||||
this.onValidationFailure(generation);
|
||||
}
|
||||
}
|
||||
|
||||
private onValidationFailure(generation: number) {
|
||||
this.resetState();
|
||||
super.setup();
|
||||
this.scheduleRetry(generation);
|
||||
}
|
||||
|
||||
private scheduleRetry(generation: number) {
|
||||
if (generation !== this.#validationGeneration) {
|
||||
return;
|
||||
}
|
||||
|
||||
const delay = this.#retryScheduler.schedule(() => {
|
||||
this.validateAndSync(generation).catch(() => {
|
||||
/* noop */
|
||||
});
|
||||
});
|
||||
if (delay === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.logger.warn(
|
||||
`OIDC discovery validation failed, retrying in ${delay}ms`
|
||||
);
|
||||
}
|
||||
|
||||
private resetState() {
|
||||
this.#endpoints = null;
|
||||
this.#jwks = null;
|
||||
}
|
||||
|
||||
getAuthUrl(state: string): string {
|
||||
const parsedState = this.parseStatePayload(state);
|
||||
const nonce = parsedState?.state ?? state;
|
||||
@@ -291,6 +357,68 @@ export class OIDCProvider extends OAuthProvider {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private claimCandidates(
|
||||
configuredClaim: string | undefined,
|
||||
defaultClaim: string
|
||||
) {
|
||||
if (typeof configuredClaim === 'string' && configuredClaim.length > 0) {
|
||||
return [configuredClaim];
|
||||
}
|
||||
return [defaultClaim];
|
||||
}
|
||||
|
||||
private formatClaimCandidates(claims: string[]) {
|
||||
return claims.map(claim => `"${claim}"`).join(', ');
|
||||
}
|
||||
|
||||
private resolveStringClaim(
|
||||
claims: string[],
|
||||
...sources: Array<Record<string, unknown>>
|
||||
) {
|
||||
for (const claim of claims) {
|
||||
for (const source of sources) {
|
||||
const value = this.extractString(source[claim]);
|
||||
if (value) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private resolveBooleanClaim(
|
||||
claims: string[],
|
||||
...sources: Array<Record<string, unknown>>
|
||||
) {
|
||||
for (const claim of claims) {
|
||||
for (const source of sources) {
|
||||
const value = this.extractBoolean(source[claim]);
|
||||
if (value !== undefined) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private resolveEmailClaim(
|
||||
claims: string[],
|
||||
...sources: Array<Record<string, unknown>>
|
||||
) {
|
||||
for (const claim of claims) {
|
||||
for (const source of sources) {
|
||||
const value = this.extractString(source[claim]);
|
||||
if (value && OIDCEmailSchema.safeParse(value).success) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async getUser(tokens: Tokens, state: OAuthState): Promise<OAuthAccount> {
|
||||
if (!tokens.idToken) {
|
||||
throw new InvalidOauthResponse({
|
||||
@@ -315,6 +443,8 @@ export class OIDCProvider extends OAuthProvider {
|
||||
{ treatServerErrorAsInvalid: true }
|
||||
);
|
||||
const user = OIDCUserInfoSchema.parse(rawUser);
|
||||
const userClaims = user as Record<string, unknown>;
|
||||
const idTokenClaimsRecord = idTokenClaims as Record<string, unknown>;
|
||||
|
||||
if (!user.sub || !idTokenClaims.sub) {
|
||||
throw new InvalidOauthResponse({
|
||||
@@ -327,22 +457,29 @@ export class OIDCProvider extends OAuthProvider {
|
||||
}
|
||||
|
||||
const args = this.config.args ?? {};
|
||||
const idClaims = this.claimCandidates(args.claim_id, 'sub');
|
||||
const emailClaims = this.claimCandidates(args.claim_email, 'email');
|
||||
const nameClaims = this.claimCandidates(args.claim_name, 'name');
|
||||
const emailVerifiedClaims = this.claimCandidates(
|
||||
args.claim_email_verified,
|
||||
'email_verified'
|
||||
);
|
||||
|
||||
const claimsMap = {
|
||||
id: args.claim_id || 'sub',
|
||||
email: args.claim_email || 'email',
|
||||
name: args.claim_name || 'name',
|
||||
emailVerified: args.claim_email_verified || 'email_verified',
|
||||
};
|
||||
|
||||
const accountId =
|
||||
this.extractString(user[claimsMap.id]) ?? idTokenClaims.sub;
|
||||
const email =
|
||||
this.extractString(user[claimsMap.email]) ||
|
||||
this.extractString(idTokenClaims.email);
|
||||
const emailVerified =
|
||||
this.extractBoolean(user[claimsMap.emailVerified]) ??
|
||||
this.extractBoolean(idTokenClaims.email_verified);
|
||||
const accountId = this.resolveStringClaim(
|
||||
idClaims,
|
||||
userClaims,
|
||||
idTokenClaimsRecord
|
||||
);
|
||||
const email = this.resolveEmailClaim(
|
||||
emailClaims,
|
||||
userClaims,
|
||||
idTokenClaimsRecord
|
||||
);
|
||||
const emailVerified = this.resolveBooleanClaim(
|
||||
emailVerifiedClaims,
|
||||
userClaims,
|
||||
idTokenClaimsRecord
|
||||
);
|
||||
|
||||
if (!accountId) {
|
||||
throw new InvalidOauthResponse({
|
||||
@@ -352,7 +489,7 @@ export class OIDCProvider extends OAuthProvider {
|
||||
|
||||
if (!email) {
|
||||
throw new InvalidOauthResponse({
|
||||
reason: 'Missing required claim for email',
|
||||
reason: `Missing valid email claim in OIDC response. Tried userinfo and ID token claims: ${this.formatClaimCandidates(emailClaims)}`,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -367,9 +504,11 @@ export class OIDCProvider extends OAuthProvider {
|
||||
email,
|
||||
};
|
||||
|
||||
const name =
|
||||
this.extractString(user[claimsMap.name]) ||
|
||||
this.extractString(idTokenClaims.name);
|
||||
const name = this.resolveStringClaim(
|
||||
nameClaims,
|
||||
userClaims,
|
||||
idTokenClaimsRecord
|
||||
);
|
||||
if (name) {
|
||||
account.name = name;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { Cron, CronExpression } from '@nestjs/schedule';
|
||||
import { PrismaClient, Provider } from '@prisma/client';
|
||||
|
||||
@@ -18,6 +18,7 @@ declare global {
|
||||
'nightly.cleanExpiredOnetimeSubscriptions': {};
|
||||
'nightly.notifyAboutToExpireWorkspaceSubscriptions': {};
|
||||
'nightly.reconcileRevenueCatSubscriptions': {};
|
||||
'nightly.reconcileStripeSubscriptions': {};
|
||||
'nightly.reconcileStripeRefunds': {};
|
||||
'nightly.revenuecat.syncUser': { userId: string };
|
||||
}
|
||||
@@ -25,6 +26,8 @@ declare global {
|
||||
|
||||
@Injectable()
|
||||
export class SubscriptionCronJobs {
|
||||
private readonly logger = new Logger(SubscriptionCronJobs.name);
|
||||
|
||||
constructor(
|
||||
private readonly db: PrismaClient,
|
||||
private readonly event: EventBus,
|
||||
@@ -61,6 +64,12 @@ export class SubscriptionCronJobs {
|
||||
{ jobId: 'nightly-payment-reconcile-revenuecat-subscriptions' }
|
||||
);
|
||||
|
||||
await this.queue.add(
|
||||
'nightly.reconcileStripeSubscriptions',
|
||||
{},
|
||||
{ jobId: 'nightly-payment-reconcile-stripe-subscriptions' }
|
||||
);
|
||||
|
||||
await this.queue.add(
|
||||
'nightly.reconcileStripeRefunds',
|
||||
{},
|
||||
@@ -202,6 +211,48 @@ export class SubscriptionCronJobs {
|
||||
await this.rcHandler.syncAppUser(payload.userId);
|
||||
}
|
||||
|
||||
@OnJob('nightly.reconcileStripeSubscriptions')
|
||||
async reconcileStripeSubscriptions() {
|
||||
const stripe = this.stripeFactory.stripe;
|
||||
const subs = await this.db.subscription.findMany({
|
||||
where: {
|
||||
provider: Provider.stripe,
|
||||
stripeSubscriptionId: { not: null },
|
||||
status: {
|
||||
in: [
|
||||
SubscriptionStatus.Active,
|
||||
SubscriptionStatus.Trialing,
|
||||
SubscriptionStatus.PastDue,
|
||||
],
|
||||
},
|
||||
},
|
||||
select: { stripeSubscriptionId: true },
|
||||
});
|
||||
|
||||
const subscriptionIds = Array.from(
|
||||
new Set(
|
||||
subs
|
||||
.map(sub => sub.stripeSubscriptionId)
|
||||
.filter((id): id is string => !!id)
|
||||
)
|
||||
);
|
||||
|
||||
for (const subscriptionId of subscriptionIds) {
|
||||
try {
|
||||
const subscription = await stripe.subscriptions.retrieve(
|
||||
subscriptionId,
|
||||
{ expand: ['customer'] }
|
||||
);
|
||||
await this.subscription.saveStripeSubscription(subscription);
|
||||
} catch (e) {
|
||||
this.logger.error(
|
||||
`Failed to reconcile stripe subscription ${subscriptionId}`,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@OnJob('nightly.reconcileStripeRefunds')
|
||||
async reconcileStripeRefunds() {
|
||||
const stripe = this.stripeFactory.stripe;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { EventBus, OnEvent } from '../../base';
|
||||
import { WorkspacePolicyService } from '../../core/permission';
|
||||
import { WorkspaceService } from '../../core/workspaces';
|
||||
import { Models } from '../../models';
|
||||
import { SubscriptionPlan, SubscriptionRecurring } from './types';
|
||||
@@ -10,7 +11,8 @@ export class PaymentEventHandlers {
|
||||
constructor(
|
||||
private readonly workspace: WorkspaceService,
|
||||
private readonly models: Models,
|
||||
private readonly event: EventBus
|
||||
private readonly event: EventBus,
|
||||
private readonly policy: WorkspacePolicyService
|
||||
) {}
|
||||
|
||||
@OnEvent('workspace.subscription.activated')
|
||||
@@ -40,6 +42,7 @@ export class PaymentEventHandlers {
|
||||
// we only send emails when the team workspace is activated
|
||||
await this.workspace.sendTeamWorkspaceUpgradedEmail(workspaceId);
|
||||
}
|
||||
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -54,7 +57,10 @@ export class PaymentEventHandlers {
|
||||
}: Events['workspace.subscription.canceled']) {
|
||||
switch (plan) {
|
||||
case SubscriptionPlan.Team:
|
||||
await this.models.workspaceUser.deleteNonAccepted(workspaceId);
|
||||
await this.models.workspaceUser.demoteAcceptedAdmins(workspaceId);
|
||||
await this.models.workspaceFeature.remove(workspaceId, 'team_plan_v1');
|
||||
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -81,6 +87,7 @@ export class PaymentEventHandlers {
|
||||
recurring === 'lifetime' ? 'lifetime_pro_plan_v1' : 'pro_plan_v1',
|
||||
'subscription activated'
|
||||
);
|
||||
await this.policy.reconcileOwnedWorkspaces(userId);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -105,6 +112,7 @@ export class PaymentEventHandlers {
|
||||
'free_plan_v1',
|
||||
'lifetime subscription canceled'
|
||||
);
|
||||
await this.policy.reconcileOwnedWorkspaces(userId);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -121,6 +129,7 @@ export class PaymentEventHandlers {
|
||||
'free_plan_v1',
|
||||
'subscription canceled'
|
||||
);
|
||||
await this.policy.reconcileOwnedWorkspaces(userId);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ interface TestOps extends OpSchema {
|
||||
add: [{ a: number; b: number }, number];
|
||||
bin: [Uint8Array, Uint8Array];
|
||||
sub: [Uint8Array, number];
|
||||
init: [{ fastText?: boolean } | undefined, { ok: true }];
|
||||
}
|
||||
|
||||
declare module 'vitest' {
|
||||
@@ -84,6 +85,55 @@ describe('op client', () => {
|
||||
expect(data.byteLength).toBe(0);
|
||||
});
|
||||
|
||||
it('should send optional payload call with abort signal', async ctx => {
|
||||
const abortController = new AbortController();
|
||||
const result = ctx.producer.call(
|
||||
'init',
|
||||
{ fastText: true },
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
expect(ctx.postMessage.mock.calls[0][0]).toMatchInlineSnapshot(`
|
||||
{
|
||||
"id": "init:1",
|
||||
"name": "init",
|
||||
"payload": {
|
||||
"fastText": true,
|
||||
},
|
||||
"type": "call",
|
||||
}
|
||||
`);
|
||||
|
||||
ctx.handlers.return({
|
||||
type: 'return',
|
||||
id: 'init:1',
|
||||
data: { ok: true },
|
||||
});
|
||||
|
||||
await expect(result).resolves.toEqual({ ok: true });
|
||||
});
|
||||
|
||||
it('should send undefined payload for optional input call', async ctx => {
|
||||
const result = ctx.producer.call('init', undefined);
|
||||
|
||||
expect(ctx.postMessage.mock.calls[0][0]).toMatchInlineSnapshot(`
|
||||
{
|
||||
"id": "init:1",
|
||||
"name": "init",
|
||||
"payload": undefined,
|
||||
"type": "call",
|
||||
}
|
||||
`);
|
||||
|
||||
ctx.handlers.return({
|
||||
type: 'return',
|
||||
id: 'init:1',
|
||||
data: { ok: true },
|
||||
});
|
||||
|
||||
await expect(result).resolves.toEqual({ ok: true });
|
||||
});
|
||||
|
||||
it('should cancel call', async ctx => {
|
||||
const promise = ctx.producer.call('add', { a: 1, b: 2 });
|
||||
|
||||
|
||||
@@ -40,18 +40,14 @@ describe('op consumer', () => {
|
||||
it('should throw if no handler registered', async ctx => {
|
||||
ctx.handlers.call({ type: 'call', id: 'add:1', name: 'add', payload: {} });
|
||||
await vi.advanceTimersToNextTimerAsync();
|
||||
expect(ctx.postMessage.mock.lastCall).toMatchInlineSnapshot(`
|
||||
[
|
||||
{
|
||||
"error": {
|
||||
"message": "Handler for operation [add] is not registered.",
|
||||
"name": "Error",
|
||||
},
|
||||
"id": "add:1",
|
||||
"type": "return",
|
||||
},
|
||||
]
|
||||
`);
|
||||
expect(ctx.postMessage.mock.lastCall?.[0]).toMatchObject({
|
||||
type: 'return',
|
||||
id: 'add:1',
|
||||
error: {
|
||||
message: 'Handler for operation [add] is not registered.',
|
||||
name: 'Error',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle call message', async ctx => {
|
||||
@@ -73,6 +69,38 @@ describe('op consumer', () => {
|
||||
`);
|
||||
});
|
||||
|
||||
it('should serialize string errors with message', async ctx => {
|
||||
ctx.consumer.register('any', () => {
|
||||
throw 'worker panic';
|
||||
});
|
||||
|
||||
ctx.handlers.call({ type: 'call', id: 'any:1', name: 'any', payload: {} });
|
||||
await vi.advanceTimersToNextTimerAsync();
|
||||
|
||||
expect(ctx.postMessage.mock.calls[0][0]).toMatchObject({
|
||||
type: 'return',
|
||||
id: 'any:1',
|
||||
error: {
|
||||
name: 'Error',
|
||||
message: 'worker panic',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should serialize plain object errors with fallback message', async ctx => {
|
||||
ctx.consumer.register('any', () => {
|
||||
throw { reason: 'panic', code: 'E_PANIC' };
|
||||
});
|
||||
|
||||
ctx.handlers.call({ type: 'call', id: 'any:1', name: 'any', payload: {} });
|
||||
await vi.advanceTimersToNextTimerAsync();
|
||||
|
||||
const message = ctx.postMessage.mock.calls[0][0]?.error?.message;
|
||||
expect(typeof message).toBe('string');
|
||||
expect(message).toContain('"reason":"panic"');
|
||||
expect(message).toContain('"code":"E_PANIC"');
|
||||
});
|
||||
|
||||
it('should handle cancel message', async ctx => {
|
||||
ctx.consumer.register('add', ({ a, b }, { signal }) => {
|
||||
const { reject, resolve, promise } = Promise.withResolvers<number>();
|
||||
|
||||
@@ -16,6 +16,96 @@ import {
|
||||
} from './message';
|
||||
import type { OpInput, OpNames, OpOutput, OpSchema } from './types';
|
||||
|
||||
const SERIALIZABLE_ERROR_FIELDS = [
|
||||
'name',
|
||||
'message',
|
||||
'code',
|
||||
'type',
|
||||
'status',
|
||||
'data',
|
||||
'stacktrace',
|
||||
] as const;
|
||||
|
||||
type SerializableErrorShape = Partial<
|
||||
Record<(typeof SERIALIZABLE_ERROR_FIELDS)[number], unknown>
|
||||
> & {
|
||||
name?: string;
|
||||
message?: string;
|
||||
};
|
||||
|
||||
function getFallbackErrorMessage(error: unknown): string {
|
||||
if (typeof error === 'string') {
|
||||
return error;
|
||||
}
|
||||
|
||||
if (error instanceof Error && error.message) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
if (
|
||||
typeof error === 'number' ||
|
||||
typeof error === 'boolean' ||
|
||||
typeof error === 'bigint' ||
|
||||
typeof error === 'symbol'
|
||||
) {
|
||||
return String(error);
|
||||
}
|
||||
|
||||
if (error === null || error === undefined) {
|
||||
return 'Unknown error';
|
||||
}
|
||||
|
||||
try {
|
||||
const jsonMessage = JSON.stringify(error);
|
||||
if (jsonMessage && jsonMessage !== '{}') {
|
||||
return jsonMessage;
|
||||
}
|
||||
} catch {
|
||||
return 'Unknown error';
|
||||
}
|
||||
|
||||
return 'Unknown error';
|
||||
}
|
||||
|
||||
function serializeError(error: unknown): Error {
|
||||
const valueToPick =
|
||||
error && typeof error === 'object'
|
||||
? error
|
||||
: ({} as Record<string, unknown>);
|
||||
const serialized = pick(
|
||||
valueToPick,
|
||||
SERIALIZABLE_ERROR_FIELDS
|
||||
) as SerializableErrorShape;
|
||||
|
||||
if (!serialized.message || typeof serialized.message !== 'string') {
|
||||
serialized.message = getFallbackErrorMessage(error);
|
||||
}
|
||||
|
||||
if (!serialized.name || typeof serialized.name !== 'string') {
|
||||
if (error instanceof Error && error.name) {
|
||||
serialized.name = error.name;
|
||||
} else if (error && typeof error === 'object') {
|
||||
const constructorName = error.constructor?.name;
|
||||
serialized.name =
|
||||
typeof constructorName === 'string' && constructorName.length > 0
|
||||
? constructorName
|
||||
: 'Error';
|
||||
} else {
|
||||
serialized.name = 'Error';
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
!serialized.stacktrace &&
|
||||
error instanceof Error &&
|
||||
typeof error.stack === 'string'
|
||||
) {
|
||||
serialized.stacktrace = error.stack;
|
||||
}
|
||||
|
||||
return serialized as Error;
|
||||
}
|
||||
|
||||
interface OpCallContext {
|
||||
signal: AbortSignal;
|
||||
}
|
||||
@@ -71,15 +161,7 @@ export class OpConsumer<Ops extends OpSchema> extends AutoMessageHandler {
|
||||
this.port.postMessage({
|
||||
type: 'return',
|
||||
id: msg.id,
|
||||
error: pick(error, [
|
||||
'name',
|
||||
'message',
|
||||
'code',
|
||||
'type',
|
||||
'status',
|
||||
'data',
|
||||
'stacktrace',
|
||||
]),
|
||||
error: serializeError(error),
|
||||
} satisfies ReturnMessage);
|
||||
},
|
||||
complete: () => {
|
||||
@@ -109,15 +191,7 @@ export class OpConsumer<Ops extends OpSchema> extends AutoMessageHandler {
|
||||
this.port.postMessage({
|
||||
type: 'error',
|
||||
id: msg.id,
|
||||
error: pick(error, [
|
||||
'name',
|
||||
'message',
|
||||
'code',
|
||||
'type',
|
||||
'status',
|
||||
'data',
|
||||
'stacktrace',
|
||||
]),
|
||||
error: serializeError(error),
|
||||
} satisfies SubscriptionErrorMessage);
|
||||
},
|
||||
complete: () => {
|
||||
|
||||
@@ -12,7 +12,16 @@ export interface OpSchema {
|
||||
[key: string]: [any, any?];
|
||||
}
|
||||
|
||||
type RequiredInput<In> = In extends void ? [] : In extends never ? [] : [In];
|
||||
type IsAny<T> = 0 extends 1 & T ? true : false;
|
||||
|
||||
type RequiredInput<In> =
|
||||
IsAny<In> extends true
|
||||
? [In]
|
||||
: [In] extends [never]
|
||||
? []
|
||||
: [In] extends [void]
|
||||
? []
|
||||
: [In];
|
||||
|
||||
export type OpNames<T extends OpSchema> = ValuesOf<KeyToKey<T>>;
|
||||
export type OpInput<
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
edition = "2024"
|
||||
license-file = "LICENSE"
|
||||
name = "affine_common"
|
||||
publish = false
|
||||
version = "0.1.0"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -1,18 +1,235 @@
|
||||
import 'fake-indexeddb/auto';
|
||||
|
||||
import { expect, test } from 'vitest';
|
||||
import * as reader from '@affine/reader';
|
||||
import { NEVER } from 'rxjs';
|
||||
import { afterEach, expect, test, vi } from 'vitest';
|
||||
import { Doc as YDoc, encodeStateAsUpdate } from 'yjs';
|
||||
|
||||
import { DummyConnection } from '../connection';
|
||||
import {
|
||||
IndexedDBBlobStorage,
|
||||
IndexedDBBlobSyncStorage,
|
||||
IndexedDBDocStorage,
|
||||
IndexedDBDocSyncStorage,
|
||||
} from '../impls/idb';
|
||||
import { SpaceStorage } from '../storage';
|
||||
import {
|
||||
type AggregateOptions,
|
||||
type AggregateResult,
|
||||
type CrawlResult,
|
||||
type DocClock,
|
||||
type DocClocks,
|
||||
type DocDiff,
|
||||
type DocIndexedClock,
|
||||
type DocRecord,
|
||||
type DocStorage,
|
||||
type DocUpdate,
|
||||
type IndexerDocument,
|
||||
type IndexerSchema,
|
||||
IndexerStorageBase,
|
||||
IndexerSyncStorageBase,
|
||||
type Query,
|
||||
type SearchOptions,
|
||||
type SearchResult,
|
||||
SpaceStorage,
|
||||
} from '../storage';
|
||||
import { Sync } from '../sync';
|
||||
import { IndexerSyncImpl } from '../sync/indexer';
|
||||
import { expectYjsEqual } from './utils';
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
function deferred<T = void>() {
|
||||
let resolve!: (value: T | PromiseLike<T>) => void;
|
||||
let reject!: (reason?: unknown) => void;
|
||||
const promise = new Promise<T>((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
return { promise, resolve, reject };
|
||||
}
|
||||
|
||||
class TestDocStorage implements DocStorage {
|
||||
readonly storageType = 'doc' as const;
|
||||
readonly connection = new DummyConnection();
|
||||
readonly isReadonly = false;
|
||||
private readonly subscribers = new Set<
|
||||
(update: DocRecord, origin?: string) => void
|
||||
>();
|
||||
|
||||
constructor(
|
||||
readonly spaceId: string,
|
||||
private readonly timestamps: Map<string, Date>,
|
||||
private readonly crawlDocDataImpl: (
|
||||
docId: string
|
||||
) => Promise<CrawlResult | null>
|
||||
) {}
|
||||
|
||||
async getDoc(_docId: string): Promise<DocRecord | null> {
|
||||
return null;
|
||||
}
|
||||
|
||||
async getDocDiff(
|
||||
_docId: string,
|
||||
_state?: Uint8Array
|
||||
): Promise<DocDiff | null> {
|
||||
return null;
|
||||
}
|
||||
|
||||
async pushDocUpdate(update: DocUpdate, origin?: string): Promise<DocClock> {
|
||||
const timestamp = this.timestamps.get(update.docId) ?? new Date();
|
||||
const record = { ...update, timestamp };
|
||||
this.timestamps.set(update.docId, timestamp);
|
||||
for (const subscriber of this.subscribers) {
|
||||
subscriber(record, origin);
|
||||
}
|
||||
return { docId: update.docId, timestamp };
|
||||
}
|
||||
|
||||
async getDocTimestamp(docId: string): Promise<DocClock | null> {
|
||||
const timestamp = this.timestamps.get(docId);
|
||||
return timestamp ? { docId, timestamp } : null;
|
||||
}
|
||||
|
||||
async getDocTimestamps(): Promise<DocClocks> {
|
||||
return Object.fromEntries(this.timestamps);
|
||||
}
|
||||
|
||||
async deleteDoc(docId: string): Promise<void> {
|
||||
this.timestamps.delete(docId);
|
||||
}
|
||||
|
||||
subscribeDocUpdate(callback: (update: DocRecord, origin?: string) => void) {
|
||||
this.subscribers.add(callback);
|
||||
return () => {
|
||||
this.subscribers.delete(callback);
|
||||
};
|
||||
}
|
||||
|
||||
async crawlDocData(docId: string): Promise<CrawlResult | null> {
|
||||
return this.crawlDocDataImpl(docId);
|
||||
}
|
||||
}
|
||||
|
||||
class TrackingIndexerStorage extends IndexerStorageBase {
|
||||
override readonly connection = new DummyConnection();
|
||||
override readonly isReadonly = false;
|
||||
|
||||
constructor(
|
||||
private readonly calls: string[],
|
||||
override readonly recommendRefreshInterval: number
|
||||
) {
|
||||
super();
|
||||
}
|
||||
|
||||
override async search<
|
||||
T extends keyof IndexerSchema,
|
||||
const O extends SearchOptions<T>,
|
||||
>(_table: T, _query: Query<T>, _options?: O): Promise<SearchResult<T, O>> {
|
||||
return {
|
||||
pagination: { count: 0, limit: 0, skip: 0, hasMore: false },
|
||||
nodes: [],
|
||||
} as SearchResult<T, O>;
|
||||
}
|
||||
|
||||
override async aggregate<
|
||||
T extends keyof IndexerSchema,
|
||||
const O extends AggregateOptions<T>,
|
||||
>(
|
||||
_table: T,
|
||||
_query: Query<T>,
|
||||
_field: keyof IndexerSchema[T],
|
||||
_options?: O
|
||||
): Promise<AggregateResult<T, O>> {
|
||||
return {
|
||||
pagination: { count: 0, limit: 0, skip: 0, hasMore: false },
|
||||
buckets: [],
|
||||
} as AggregateResult<T, O>;
|
||||
}
|
||||
|
||||
override search$<
|
||||
T extends keyof IndexerSchema,
|
||||
const O extends SearchOptions<T>,
|
||||
>(_table: T, _query: Query<T>, _options?: O) {
|
||||
return NEVER;
|
||||
}
|
||||
|
||||
override aggregate$<
|
||||
T extends keyof IndexerSchema,
|
||||
const O extends AggregateOptions<T>,
|
||||
>(_table: T, _query: Query<T>, _field: keyof IndexerSchema[T], _options?: O) {
|
||||
return NEVER;
|
||||
}
|
||||
|
||||
override async deleteByQuery<T extends keyof IndexerSchema>(
|
||||
table: T,
|
||||
_query: Query<T>
|
||||
): Promise<void> {
|
||||
this.calls.push(`deleteByQuery:${String(table)}`);
|
||||
}
|
||||
|
||||
override async insert<T extends keyof IndexerSchema>(
|
||||
table: T,
|
||||
document: IndexerDocument<T>
|
||||
): Promise<void> {
|
||||
this.calls.push(`insert:${String(table)}:${document.id}`);
|
||||
}
|
||||
|
||||
override async delete<T extends keyof IndexerSchema>(
|
||||
table: T,
|
||||
id: string
|
||||
): Promise<void> {
|
||||
this.calls.push(`delete:${String(table)}:${id}`);
|
||||
}
|
||||
|
||||
override async update<T extends keyof IndexerSchema>(
|
||||
table: T,
|
||||
document: IndexerDocument<T>
|
||||
): Promise<void> {
|
||||
this.calls.push(`update:${String(table)}:${document.id}`);
|
||||
}
|
||||
|
||||
override async refresh<T extends keyof IndexerSchema>(
|
||||
_table: T
|
||||
): Promise<void> {
|
||||
return;
|
||||
}
|
||||
|
||||
override async refreshIfNeed(): Promise<void> {
|
||||
this.calls.push('refresh');
|
||||
}
|
||||
|
||||
override async indexVersion(): Promise<number> {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
class TrackingIndexerSyncStorage extends IndexerSyncStorageBase {
|
||||
override readonly connection = new DummyConnection();
|
||||
private readonly clocks = new Map<string, DocIndexedClock>();
|
||||
|
||||
constructor(private readonly calls: string[]) {
|
||||
super();
|
||||
}
|
||||
|
||||
override async getDocIndexedClock(
|
||||
docId: string
|
||||
): Promise<DocIndexedClock | null> {
|
||||
return this.clocks.get(docId) ?? null;
|
||||
}
|
||||
|
||||
override async setDocIndexedClock(clock: DocIndexedClock): Promise<void> {
|
||||
this.calls.push(`setClock:${clock.docId}`);
|
||||
this.clocks.set(clock.docId, clock);
|
||||
}
|
||||
|
||||
override async clearDocIndexedClock(docId: string): Promise<void> {
|
||||
this.calls.push(`clearClock:${docId}`);
|
||||
this.clocks.delete(docId);
|
||||
}
|
||||
}
|
||||
|
||||
test('doc', async () => {
|
||||
const doc = new YDoc();
|
||||
doc.getMap('test').set('hello', 'world');
|
||||
@@ -207,3 +424,114 @@ test('blob', async () => {
|
||||
expect(c?.data).toEqual(new Uint8Array([4, 3, 2, 1]));
|
||||
}
|
||||
});
|
||||
|
||||
test('indexer defers indexed clock persistence until a refresh happens on delayed refresh storages', async () => {
|
||||
const calls: string[] = [];
|
||||
const docsInRootDoc = new Map([['doc1', { title: 'Doc 1' }]]);
|
||||
const docStorage = new TestDocStorage(
|
||||
'workspace-id',
|
||||
new Map([['doc1', new Date('2026-01-01T00:00:00.000Z')]]),
|
||||
async () => ({
|
||||
title: 'Doc 1',
|
||||
summary: 'summary',
|
||||
blocks: [
|
||||
{ blockId: 'block-1', flavour: 'affine:image', blob: ['blob-1'] },
|
||||
],
|
||||
})
|
||||
);
|
||||
const indexer = new TrackingIndexerStorage(calls, 30_000);
|
||||
const indexerSyncStorage = new TrackingIndexerSyncStorage(calls);
|
||||
const sync = new IndexerSyncImpl(
|
||||
docStorage,
|
||||
{
|
||||
local: indexer,
|
||||
remotes: {},
|
||||
},
|
||||
indexerSyncStorage
|
||||
);
|
||||
|
||||
vi.spyOn(reader, 'readAllDocsFromRootDoc').mockImplementation(
|
||||
() => new Map(docsInRootDoc)
|
||||
);
|
||||
|
||||
try {
|
||||
sync.start();
|
||||
await sync.waitForCompleted();
|
||||
|
||||
expect(calls).not.toContain('setClock:doc1');
|
||||
|
||||
sync.stop();
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(calls).toContain('setClock:doc1');
|
||||
});
|
||||
|
||||
const lastRefreshIndex = calls.lastIndexOf('refresh');
|
||||
const setClockIndex = calls.indexOf('setClock:doc1');
|
||||
|
||||
expect(lastRefreshIndex).toBeGreaterThanOrEqual(0);
|
||||
expect(setClockIndex).toBeGreaterThan(lastRefreshIndex);
|
||||
} finally {
|
||||
sync.stop();
|
||||
}
|
||||
});
|
||||
|
||||
test('indexer completion waits for the current job to finish', async () => {
|
||||
const docsInRootDoc = new Map([['doc1', { title: 'Doc 1' }]]);
|
||||
const crawlStarted = deferred<void>();
|
||||
const releaseCrawl = deferred<void>();
|
||||
const docStorage = new TestDocStorage(
|
||||
'workspace-id',
|
||||
new Map([['doc1', new Date('2026-01-01T00:00:00.000Z')]]),
|
||||
async () => {
|
||||
crawlStarted.resolve();
|
||||
await releaseCrawl.promise;
|
||||
return {
|
||||
title: 'Doc 1',
|
||||
summary: 'summary',
|
||||
blocks: [
|
||||
{ blockId: 'block-1', flavour: 'affine:image', blob: ['blob-1'] },
|
||||
],
|
||||
};
|
||||
}
|
||||
);
|
||||
const sync = new IndexerSyncImpl(
|
||||
docStorage,
|
||||
{
|
||||
local: new TrackingIndexerStorage([], 30_000),
|
||||
remotes: {},
|
||||
},
|
||||
new TrackingIndexerSyncStorage([])
|
||||
);
|
||||
|
||||
vi.spyOn(reader, 'readAllDocsFromRootDoc').mockImplementation(
|
||||
() => new Map(docsInRootDoc)
|
||||
);
|
||||
|
||||
try {
|
||||
sync.start();
|
||||
await crawlStarted.promise;
|
||||
|
||||
let completed = false;
|
||||
let docCompleted = false;
|
||||
|
||||
const waitForCompleted = sync.waitForCompleted().then(() => {
|
||||
completed = true;
|
||||
});
|
||||
const waitForDocCompleted = sync.waitForDocCompleted('doc1').then(() => {
|
||||
docCompleted = true;
|
||||
});
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 20));
|
||||
|
||||
expect(completed).toBe(false);
|
||||
expect(docCompleted).toBe(false);
|
||||
|
||||
releaseCrawl.resolve();
|
||||
|
||||
await waitForCompleted;
|
||||
await waitForDocCompleted;
|
||||
} finally {
|
||||
sync.stop();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -112,6 +112,10 @@ export class IndexerSyncImpl implements IndexerSync {
|
||||
|
||||
private readonly indexer: IndexerStorage;
|
||||
private readonly remote?: IndexerStorage;
|
||||
private readonly pendingIndexedClocks = new Map<
|
||||
string,
|
||||
{ docId: string; timestamp: Date; indexerVersion: number }
|
||||
>();
|
||||
|
||||
private lastRefreshed = Date.now();
|
||||
|
||||
@@ -372,12 +376,13 @@ export class IndexerSyncImpl implements IndexerSync {
|
||||
field: 'docId',
|
||||
match: docId,
|
||||
});
|
||||
this.pendingIndexedClocks.delete(docId);
|
||||
await this.indexerSync.clearDocIndexedClock(docId);
|
||||
this.status.docsInIndexer.delete(docId);
|
||||
this.status.statusUpdatedSubject$.next(docId);
|
||||
}
|
||||
}
|
||||
await this.refreshIfNeed();
|
||||
await this.refreshIfNeed(true);
|
||||
// #endregion
|
||||
} else {
|
||||
// #region crawl doc
|
||||
@@ -394,7 +399,8 @@ export class IndexerSyncImpl implements IndexerSync {
|
||||
}
|
||||
|
||||
const docIndexedClock =
|
||||
await this.indexerSync.getDocIndexedClock(docId);
|
||||
this.pendingIndexedClocks.get(docId) ??
|
||||
(await this.indexerSync.getDocIndexedClock(docId));
|
||||
if (
|
||||
docIndexedClock &&
|
||||
docIndexedClock.timestamp.getTime() ===
|
||||
@@ -460,13 +466,12 @@ export class IndexerSyncImpl implements IndexerSync {
|
||||
);
|
||||
}
|
||||
|
||||
await this.refreshIfNeed();
|
||||
|
||||
await this.indexerSync.setDocIndexedClock({
|
||||
this.pendingIndexedClocks.set(docId, {
|
||||
docId,
|
||||
timestamp: docClock.timestamp,
|
||||
indexerVersion: indexVersion,
|
||||
});
|
||||
await this.refreshIfNeed();
|
||||
// #endregion
|
||||
}
|
||||
|
||||
@@ -476,7 +481,7 @@ export class IndexerSyncImpl implements IndexerSync {
|
||||
this.status.completeJob();
|
||||
}
|
||||
} finally {
|
||||
await this.refreshIfNeed();
|
||||
await this.refreshIfNeed(true);
|
||||
unsubscribe();
|
||||
}
|
||||
}
|
||||
@@ -484,18 +489,27 @@ export class IndexerSyncImpl implements IndexerSync {
|
||||
// ensure the indexer is refreshed according to recommendRefreshInterval
|
||||
// recommendRefreshInterval <= 0 means force refresh on each operation
|
||||
// recommendRefreshInterval > 0 means refresh if the last refresh is older than recommendRefreshInterval
|
||||
private async refreshIfNeed(): Promise<void> {
|
||||
private async refreshIfNeed(force = false): Promise<void> {
|
||||
const recommendRefreshInterval = this.indexer.recommendRefreshInterval ?? 0;
|
||||
const needRefresh =
|
||||
recommendRefreshInterval > 0 &&
|
||||
this.lastRefreshed + recommendRefreshInterval < Date.now();
|
||||
const forceRefresh = recommendRefreshInterval <= 0;
|
||||
if (needRefresh || forceRefresh) {
|
||||
if (force || needRefresh || forceRefresh) {
|
||||
await this.indexer.refreshIfNeed();
|
||||
await this.flushPendingIndexedClocks();
|
||||
this.lastRefreshed = Date.now();
|
||||
}
|
||||
}
|
||||
|
||||
private async flushPendingIndexedClocks() {
|
||||
if (this.pendingIndexedClocks.size === 0) return;
|
||||
for (const [docId, clock] of this.pendingIndexedClocks) {
|
||||
await this.indexerSync.setDocIndexedClock(clock);
|
||||
this.pendingIndexedClocks.delete(docId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all docs from the root doc, without deleted docs
|
||||
*/
|
||||
@@ -706,7 +720,10 @@ class IndexerSyncStatus {
|
||||
indexing: this.jobs.length() + (this.currentJob ? 1 : 0),
|
||||
total: this.docsInRootDoc.size + 1,
|
||||
errorMessage: this.errorMessage,
|
||||
completed: this.rootDocReady && this.jobs.length() === 0,
|
||||
completed:
|
||||
this.rootDocReady &&
|
||||
this.jobs.length() === 0 &&
|
||||
this.currentJob === null,
|
||||
batterySaveMode: this.batterySaveMode,
|
||||
paused: this.paused !== null,
|
||||
});
|
||||
@@ -734,9 +751,10 @@ class IndexerSyncStatus {
|
||||
completed: true,
|
||||
});
|
||||
} else {
|
||||
const indexing = this.jobs.has(docId) || this.currentJob === docId;
|
||||
subscribe.next({
|
||||
indexing: this.jobs.has(docId),
|
||||
completed: this.docsInIndexer.has(docId) && !this.jobs.has(docId),
|
||||
indexing,
|
||||
completed: this.docsInIndexer.has(docId) && !indexing,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
export const encodeLink = (link: string) =>
|
||||
encodeURI(link)
|
||||
.replace(/\(/g, '%28')
|
||||
.replace(/\)/g, '%29')
|
||||
.replaceAll('(', '%28')
|
||||
.replaceAll(')', '%29')
|
||||
.replace(/(\?|&)response-content-disposition=attachment.*$/, '');
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"aws4": "^1.13.2",
|
||||
"fast-xml-parser": "^5.3.4",
|
||||
"fast-xml-parser": "^5.5.7",
|
||||
"s3mini": "^0.9.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -19,6 +19,7 @@ import app.affine.pro.plugin.AFFiNEThemePlugin
|
||||
import app.affine.pro.plugin.AuthPlugin
|
||||
import app.affine.pro.plugin.HashCashPlugin
|
||||
import app.affine.pro.plugin.NbStorePlugin
|
||||
import app.affine.pro.plugin.PreviewPlugin
|
||||
import app.affine.pro.service.GraphQLService
|
||||
import app.affine.pro.service.SSEService
|
||||
import app.affine.pro.service.WebService
|
||||
@@ -52,6 +53,7 @@ class MainActivity : BridgeActivity(), AIButtonPlugin.Callback, AFFiNEThemePlugi
|
||||
AuthPlugin::class.java,
|
||||
HashCashPlugin::class.java,
|
||||
NbStorePlugin::class.java,
|
||||
PreviewPlugin::class.java,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package app.affine.pro.ai.chat
|
||||
|
||||
import com.affine.pro.graphql.GetCopilotHistoriesQuery
|
||||
import com.affine.pro.graphql.fragment.CopilotChatHistory
|
||||
import com.affine.pro.graphql.fragment.CopilotChatMessage
|
||||
import kotlinx.datetime.Clock
|
||||
import kotlinx.datetime.Instant
|
||||
|
||||
@@ -53,7 +51,7 @@ data class ChatMessage(
|
||||
createAt = Clock.System.now(),
|
||||
)
|
||||
|
||||
fun from(message: CopilotChatMessage) = ChatMessage(
|
||||
fun from(message: CopilotChatHistory.Message) = ChatMessage(
|
||||
id = message.id,
|
||||
role = Role.fromValue(message.role),
|
||||
content = message.content,
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
package app.affine.pro.plugin
|
||||
|
||||
import android.net.Uri
|
||||
import com.getcapacitor.JSObject
|
||||
import com.getcapacitor.Plugin
|
||||
import com.getcapacitor.PluginCall
|
||||
import com.getcapacitor.PluginMethod
|
||||
import com.getcapacitor.annotation.CapacitorPlugin
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import timber.log.Timber
|
||||
import uniffi.affine_mobile_native.renderMermaidPreviewSvg
|
||||
import uniffi.affine_mobile_native.renderTypstPreviewSvg
|
||||
import java.io.File
|
||||
|
||||
private fun JSObject.getOptionalString(key: String): String? {
|
||||
return if (has(key) && !isNull(key)) getString(key) else null
|
||||
}
|
||||
|
||||
private fun JSObject.getOptionalDouble(key: String): Double? {
|
||||
return if (has(key) && !isNull(key)) getDouble(key) else null
|
||||
}
|
||||
|
||||
private fun resolveLocalFontDir(fontUrl: String): String? {
|
||||
val uri = Uri.parse(fontUrl)
|
||||
val path = when {
|
||||
uri.scheme == null -> {
|
||||
val file = File(fontUrl)
|
||||
if (!file.isAbsolute) {
|
||||
return null
|
||||
}
|
||||
file.path
|
||||
}
|
||||
uri.scheme == "file" -> uri.path
|
||||
else -> null
|
||||
} ?: return null
|
||||
|
||||
val file = File(path)
|
||||
val directory = if (file.isDirectory) file else file.parentFile ?: return null
|
||||
return directory.absolutePath
|
||||
}
|
||||
|
||||
private fun JSObject.resolveTypstFontDirs(): List<String>? {
|
||||
if (!has("fontUrls") || isNull("fontUrls")) {
|
||||
return null
|
||||
}
|
||||
|
||||
val fontUrls = optJSONArray("fontUrls")
|
||||
?: throw IllegalArgumentException("Typst preview fontUrls must be an array of strings.")
|
||||
val fontDirs = buildList(fontUrls.length()) {
|
||||
repeat(fontUrls.length()) { index ->
|
||||
val fontUrl = fontUrls.optString(index, null)
|
||||
?: throw IllegalArgumentException("Typst preview fontUrls must be strings.")
|
||||
val fontDir = resolveLocalFontDir(fontUrl)
|
||||
?: throw IllegalArgumentException("Typst preview on mobile only supports local font file URLs or absolute font directories.")
|
||||
add(fontDir)
|
||||
}
|
||||
}
|
||||
return fontDirs.distinct()
|
||||
}
|
||||
|
||||
@CapacitorPlugin(name = "Preview")
|
||||
class PreviewPlugin : Plugin() {
|
||||
|
||||
@PluginMethod
|
||||
fun renderMermaidSvg(call: PluginCall) {
|
||||
launch(Dispatchers.IO) {
|
||||
try {
|
||||
val code = call.getStringEnsure("code")
|
||||
val options = call.getObject("options")
|
||||
val svg = renderMermaidPreviewSvg(
|
||||
code = code,
|
||||
theme = options?.getOptionalString("theme"),
|
||||
fontFamily = options?.getOptionalString("fontFamily"),
|
||||
fontSize = options?.getOptionalDouble("fontSize"),
|
||||
)
|
||||
call.resolve(JSObject().apply {
|
||||
put("svg", svg)
|
||||
})
|
||||
} catch (e: Exception) {
|
||||
Timber.e(e, "Failed to render Mermaid preview.")
|
||||
call.reject("Failed to render Mermaid preview.", null, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@PluginMethod
|
||||
fun renderTypstSvg(call: PluginCall) {
|
||||
launch(Dispatchers.IO) {
|
||||
try {
|
||||
val code = call.getStringEnsure("code")
|
||||
val options = call.getObject("options")
|
||||
val svg = renderTypstPreviewSvg(
|
||||
code = code,
|
||||
fontDirs = options?.resolveTypstFontDirs(),
|
||||
cacheDir = context.cacheDir.absolutePath,
|
||||
)
|
||||
call.resolve(JSObject().apply {
|
||||
put("svg", svg)
|
||||
})
|
||||
} catch (e: Exception) {
|
||||
Timber.e(e, "Failed to render Typst preview.")
|
||||
call.reject("Failed to render Typst preview.", null, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -72,7 +72,7 @@ class GraphQLService @Inject constructor() {
|
||||
).mapCatching { data ->
|
||||
data.currentUser?.copilot?.chats?.paginatedCopilotChats?.edges?.map { item -> item.node.copilotChatHistory }?.firstOrNull { history ->
|
||||
history.sessionId == sessionId
|
||||
}?.messages?.map { msg -> msg.copilotChatMessage } ?: emptyList()
|
||||
}?.messages ?: emptyList()
|
||||
}
|
||||
|
||||
suspend fun getCopilotHistoryIds(
|
||||
|
||||
@@ -792,6 +792,10 @@ internal interface UniffiForeignFutureCompleteVoid : com.sun.jna.Callback {
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -816,6 +820,10 @@ internal interface IntegrityCheckingUniffiLib : Library {
|
||||
): Short
|
||||
fun uniffi_affine_mobile_native_checksum_func_new_doc_storage_pool(
|
||||
): Short
|
||||
fun uniffi_affine_mobile_native_checksum_func_render_mermaid_preview_svg(
|
||||
): Short
|
||||
fun uniffi_affine_mobile_native_checksum_func_render_typst_preview_svg(
|
||||
): Short
|
||||
fun uniffi_affine_mobile_native_checksum_method_docstoragepool_clear_clocks(
|
||||
): Short
|
||||
fun uniffi_affine_mobile_native_checksum_method_docstoragepool_connect(
|
||||
@@ -1017,6 +1025,10 @@ fun uniffi_affine_mobile_native_fn_func_hashcash_mint(`resource`: RustBuffer.ByV
|
||||
): RustBuffer.ByValue
|
||||
fun uniffi_affine_mobile_native_fn_func_new_doc_storage_pool(uniffi_out_err: UniffiRustCallStatus,
|
||||
): Pointer
|
||||
fun uniffi_affine_mobile_native_fn_func_render_mermaid_preview_svg(`code`: RustBuffer.ByValue,`theme`: RustBuffer.ByValue,`fontFamily`: RustBuffer.ByValue,`fontSize`: RustBuffer.ByValue,uniffi_out_err: UniffiRustCallStatus,
|
||||
): RustBuffer.ByValue
|
||||
fun uniffi_affine_mobile_native_fn_func_render_typst_preview_svg(`code`: RustBuffer.ByValue,`fontDirs`: RustBuffer.ByValue,`cacheDir`: RustBuffer.ByValue,uniffi_out_err: UniffiRustCallStatus,
|
||||
): RustBuffer.ByValue
|
||||
fun ffi_affine_mobile_native_rustbuffer_alloc(`size`: Long,uniffi_out_err: UniffiRustCallStatus,
|
||||
): RustBuffer.ByValue
|
||||
fun ffi_affine_mobile_native_rustbuffer_from_bytes(`bytes`: ForeignBytes.ByValue,uniffi_out_err: UniffiRustCallStatus,
|
||||
@@ -1149,6 +1161,12 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
|
||||
if (lib.uniffi_affine_mobile_native_checksum_func_new_doc_storage_pool() != 32882.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_affine_mobile_native_checksum_func_render_mermaid_preview_svg() != 54334.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_affine_mobile_native_checksum_func_render_typst_preview_svg() != 42796.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_affine_mobile_native_checksum_method_docstoragepool_clear_clocks() != 51151.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
@@ -3178,6 +3196,38 @@ public object FfiConverterOptionalLong: FfiConverterRustBuffer<kotlin.Long?> {
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @suppress
|
||||
*/
|
||||
public object FfiConverterOptionalDouble: FfiConverterRustBuffer<kotlin.Double?> {
|
||||
override fun read(buf: ByteBuffer): kotlin.Double? {
|
||||
if (buf.get().toInt() == 0) {
|
||||
return null
|
||||
}
|
||||
return FfiConverterDouble.read(buf)
|
||||
}
|
||||
|
||||
override fun allocationSize(value: kotlin.Double?): ULong {
|
||||
if (value == null) {
|
||||
return 1UL
|
||||
} else {
|
||||
return 1UL + FfiConverterDouble.allocationSize(value)
|
||||
}
|
||||
}
|
||||
|
||||
override fun write(value: kotlin.Double?, buf: ByteBuffer) {
|
||||
if (value == null) {
|
||||
buf.put(0)
|
||||
} else {
|
||||
buf.put(1)
|
||||
FfiConverterDouble.write(value, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @suppress
|
||||
*/
|
||||
@@ -3584,4 +3634,24 @@ public object FfiConverterSequenceTypeSearchHit: FfiConverterRustBuffer<List<Sea
|
||||
}
|
||||
|
||||
|
||||
@Throws(UniffiException::class) fun `renderMermaidPreviewSvg`(`code`: kotlin.String, `theme`: kotlin.String?, `fontFamily`: kotlin.String?, `fontSize`: kotlin.Double?): kotlin.String {
|
||||
return FfiConverterString.lift(
|
||||
uniffiRustCallWithError(UniffiException) { _status ->
|
||||
UniffiLib.INSTANCE.uniffi_affine_mobile_native_fn_func_render_mermaid_preview_svg(
|
||||
FfiConverterString.lower(`code`),FfiConverterOptionalString.lower(`theme`),FfiConverterOptionalString.lower(`fontFamily`),FfiConverterOptionalDouble.lower(`fontSize`),_status)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@Throws(UniffiException::class) fun `renderTypstPreviewSvg`(`code`: kotlin.String, `fontDirs`: List<kotlin.String>?, `cacheDir`: kotlin.String?): kotlin.String {
|
||||
return FfiConverterString.lift(
|
||||
uniffiRustCallWithError(UniffiException) { _status ->
|
||||
UniffiLib.INSTANCE.uniffi_affine_mobile_native_fn_func_render_typst_preview_svg(
|
||||
FfiConverterString.lower(`code`),FfiConverterOptionalSequenceString.lower(`fontDirs`),FfiConverterOptionalString.lower(`cacheDir`),_status)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
ServersService,
|
||||
ValidatorProvider,
|
||||
} from '@affine/core/modules/cloud';
|
||||
import { registerNativePreviewHandlers } from '@affine/core/modules/code-block-preview-renderer';
|
||||
import { DocsService } from '@affine/core/modules/doc';
|
||||
import { GlobalContextService } from '@affine/core/modules/global-context';
|
||||
import { I18nProvider } from '@affine/core/modules/i18n';
|
||||
@@ -54,6 +55,7 @@ import { AIButton } from './plugins/ai-button';
|
||||
import { Auth } from './plugins/auth';
|
||||
import { HashCash } from './plugins/hashcash';
|
||||
import { NbStoreNativeDBApis } from './plugins/nbstore';
|
||||
import { Preview } from './plugins/preview';
|
||||
import { writeEndpointToken } from './proxy';
|
||||
|
||||
const storeManagerClient = createStoreManagerClient();
|
||||
@@ -85,6 +87,11 @@ framework.impl(NbstoreProvider, {
|
||||
});
|
||||
const frameworkProvider = framework.provider();
|
||||
|
||||
registerNativePreviewHandlers({
|
||||
renderMermaidSvg: request => Preview.renderMermaidSvg(request),
|
||||
renderTypstSvg: request => Preview.renderTypstSvg(request),
|
||||
});
|
||||
|
||||
framework.impl(PopupWindowProvider, {
|
||||
open: (url: string) => {
|
||||
InAppBrowser.open({
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
export interface PreviewPlugin {
|
||||
renderMermaidSvg(options: {
|
||||
code: string;
|
||||
options?: {
|
||||
theme?: string;
|
||||
fontFamily?: string;
|
||||
fontSize?: number;
|
||||
};
|
||||
}): Promise<{ svg: string }>;
|
||||
renderTypstSvg(options: {
|
||||
code: string;
|
||||
options?: {
|
||||
fontUrls?: string[];
|
||||
};
|
||||
}): Promise<{ svg: string }>;
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
import { registerPlugin } from '@capacitor/core';
|
||||
|
||||
import type { PreviewPlugin } from './definitions';
|
||||
|
||||
const Preview = registerPlugin<PreviewPlugin>('Preview');
|
||||
|
||||
export * from './definitions';
|
||||
export { Preview };
|
||||
@@ -46,7 +46,10 @@ export function setupEvents(frameworkProvider: FrameworkProvider) {
|
||||
const { workspace } = currentWorkspace;
|
||||
const docsService = workspace.scope.get(DocsService);
|
||||
|
||||
const page = docsService.createDoc({ primaryMode: type });
|
||||
const page =
|
||||
type === 'default'
|
||||
? docsService.createDoc()
|
||||
: docsService.createDoc({ primaryMode: type });
|
||||
workspace.scope.get(WorkbenchService).workbench.openDoc(page.id);
|
||||
})
|
||||
.catch(err => {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { parse } from 'node:path';
|
||||
import { parse, resolve } from 'node:path';
|
||||
|
||||
import { DocStorage, ValidationResult } from '@affine/native';
|
||||
import { parseUniversalId } from '@affine/nbstore';
|
||||
@@ -71,10 +71,34 @@ function getDefaultDBFileName(name: string, id: string) {
|
||||
return fileName.replace(/[/\\?%*:|"<>]/g, '-');
|
||||
}
|
||||
|
||||
async function resolveExistingPath(path: string) {
|
||||
if (!(await fs.pathExists(path))) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return await fs.realpath(path);
|
||||
} catch {
|
||||
return resolve(path);
|
||||
}
|
||||
}
|
||||
|
||||
async function isSameFilePath(sourcePath: string, targetPath: string) {
|
||||
if (resolve(sourcePath) === resolve(targetPath)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const [sourceRealPath, targetRealPath] = await Promise.all([
|
||||
resolveExistingPath(sourcePath),
|
||||
resolveExistingPath(targetPath),
|
||||
]);
|
||||
|
||||
return !!sourceRealPath && sourceRealPath === targetRealPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* This function is called when the user clicks the "Save" button in the "Save Workspace" dialog.
|
||||
*
|
||||
* It will just copy the file to the given path
|
||||
* It will export a compacted database file to the given path
|
||||
*/
|
||||
export async function saveDBFileAs(
|
||||
universalId: string,
|
||||
@@ -115,12 +139,26 @@ export async function saveDBFileAs(
|
||||
|
||||
const filePath = ret.filePath;
|
||||
if (ret.canceled || !filePath) {
|
||||
return {
|
||||
canceled: true,
|
||||
};
|
||||
return { canceled: true };
|
||||
}
|
||||
|
||||
await fs.copyFile(dbPath, filePath);
|
||||
if (await isSameFilePath(dbPath, filePath)) {
|
||||
return { error: 'DB_FILE_PATH_INVALID' };
|
||||
}
|
||||
|
||||
const tempFilePath = `${filePath}.${nanoid(6)}.tmp`;
|
||||
if (await fs.pathExists(tempFilePath)) {
|
||||
await fs.remove(tempFilePath);
|
||||
}
|
||||
|
||||
try {
|
||||
await pool.vacuumInto(universalId, tempFilePath);
|
||||
await fs.move(tempFilePath, filePath, { overwrite: true });
|
||||
} finally {
|
||||
if (await fs.pathExists(tempFilePath)) {
|
||||
await fs.remove(tempFilePath);
|
||||
}
|
||||
}
|
||||
logger.log('saved', filePath);
|
||||
if (!fakedResult) {
|
||||
mainRPC.showItemInFolder(filePath).catch(err => {
|
||||
@@ -183,11 +221,7 @@ export async function loadDBFile(
|
||||
const provided =
|
||||
getFakedResult() ??
|
||||
(dbFilePath
|
||||
? {
|
||||
filePath: dbFilePath,
|
||||
filePaths: [dbFilePath],
|
||||
canceled: false,
|
||||
}
|
||||
? { filePath: dbFilePath, filePaths: [dbFilePath], canceled: false }
|
||||
: undefined);
|
||||
const ret =
|
||||
provided ??
|
||||
@@ -224,6 +258,10 @@ export async function loadDBFile(
|
||||
return await cpV1DBFile(originalPath, workspaceId);
|
||||
}
|
||||
|
||||
if (!(await storage.validateImportSchema())) {
|
||||
return { error: 'DB_FILE_INVALID' };
|
||||
}
|
||||
|
||||
// v2 import logic
|
||||
const internalFilePath = await getSpaceDBPath(
|
||||
'local',
|
||||
@@ -231,8 +269,8 @@ export async function loadDBFile(
|
||||
workspaceId
|
||||
);
|
||||
await fs.ensureDir(parse(internalFilePath).dir);
|
||||
await fs.copy(originalPath, internalFilePath);
|
||||
logger.info(`loadDBFile, copy: ${originalPath} -> ${internalFilePath}`);
|
||||
await storage.vacuumInto(internalFilePath);
|
||||
logger.info(`loadDBFile, vacuum: ${originalPath} -> ${internalFilePath}`);
|
||||
|
||||
storage = new DocStorage(internalFilePath);
|
||||
await storage.setSpaceId(workspaceId);
|
||||
@@ -260,17 +298,16 @@ async function cpV1DBFile(
|
||||
return { error: 'DB_FILE_INVALID' }; // invalid db file
|
||||
}
|
||||
|
||||
// checkout to make sure wal is flushed
|
||||
const connection = new SqliteConnection(originalPath);
|
||||
await connection.connect();
|
||||
await connection.checkpoint();
|
||||
await connection.close();
|
||||
if (!(await connection.validateImportSchema())) {
|
||||
return { error: 'DB_FILE_INVALID' };
|
||||
}
|
||||
|
||||
const internalFilePath = await getWorkspaceDBPath('workspace', workspaceId);
|
||||
|
||||
await fs.ensureDir(await getWorkspacesBasePath());
|
||||
await fs.copy(originalPath, internalFilePath);
|
||||
logger.info(`loadDBFile, copy: ${originalPath} -> ${internalFilePath}`);
|
||||
await fs.ensureDir(parse(internalFilePath).dir);
|
||||
await connection.vacuumInto(internalFilePath);
|
||||
logger.info(`loadDBFile, vacuum: ${originalPath} -> ${internalFilePath}`);
|
||||
|
||||
await storeWorkspaceMeta(workspaceId, {
|
||||
id: workspaceId,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { dialogHandlers } from './dialog';
|
||||
import { dbEventsV1, dbHandlersV1, nbstoreHandlers } from './nbstore';
|
||||
import { previewHandlers } from './preview';
|
||||
import { provideExposed } from './provide';
|
||||
import { workspaceEvents, workspaceHandlers } from './workspace';
|
||||
|
||||
@@ -8,6 +9,7 @@ export const handlers = {
|
||||
nbstore: nbstoreHandlers,
|
||||
workspace: workspaceHandlers,
|
||||
dialog: dialogHandlers,
|
||||
preview: previewHandlers,
|
||||
};
|
||||
|
||||
export const events = {
|
||||
|
||||
69
packages/frontend/apps/electron/src/helper/preview/index.ts
Normal file
69
packages/frontend/apps/electron/src/helper/preview/index.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import fs from 'node:fs';
|
||||
import path from 'node:path';
|
||||
|
||||
import {
|
||||
type MermaidRenderRequest,
|
||||
type MermaidRenderResult,
|
||||
renderMermaidSvg,
|
||||
renderTypstSvg,
|
||||
type TypstRenderRequest,
|
||||
type TypstRenderResult,
|
||||
} from '@affine/native';
|
||||
|
||||
const TYPST_FONT_DIRS_ENV = 'AFFINE_TYPST_FONT_DIRS';
|
||||
|
||||
function parseTypstFontDirsFromEnv() {
|
||||
const value = process.env[TYPST_FONT_DIRS_ENV];
|
||||
if (!value) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return value
|
||||
.split(path.delimiter)
|
||||
.map(dir => dir.trim())
|
||||
.filter(Boolean);
|
||||
}
|
||||
|
||||
function getTypstFontDirCandidates() {
|
||||
const resourcesPath = process.resourcesPath ?? '';
|
||||
|
||||
return [
|
||||
...parseTypstFontDirsFromEnv(),
|
||||
path.join(resourcesPath, 'fonts'),
|
||||
path.join(resourcesPath, 'js', 'fonts'),
|
||||
path.join(resourcesPath, 'app.asar.unpacked', 'fonts'),
|
||||
path.join(resourcesPath, 'app.asar.unpacked', 'js', 'fonts'),
|
||||
];
|
||||
}
|
||||
|
||||
function resolveTypstFontDirs() {
|
||||
return Array.from(
|
||||
new Set(getTypstFontDirCandidates().map(dir => path.resolve(dir)))
|
||||
).filter(dir => fs.statSync(dir, { throwIfNoEntry: false })?.isDirectory());
|
||||
}
|
||||
|
||||
function withTypstFontDirs(
|
||||
request: TypstRenderRequest,
|
||||
fontDirs: string[]
|
||||
): TypstRenderRequest {
|
||||
const nextOptions = request.options ? { ...request.options } : {};
|
||||
if (!nextOptions.fontDirs?.length) {
|
||||
nextOptions.fontDirs = fontDirs;
|
||||
}
|
||||
return { ...request, options: nextOptions };
|
||||
}
|
||||
|
||||
const typstFontDirs = resolveTypstFontDirs();
|
||||
|
||||
export const previewHandlers = {
|
||||
renderMermaidSvg: async (
|
||||
request: MermaidRenderRequest
|
||||
): Promise<MermaidRenderResult> => {
|
||||
return renderMermaidSvg(request);
|
||||
},
|
||||
renderTypstSvg: async (
|
||||
request: TypstRenderRequest
|
||||
): Promise<TypstRenderResult> => {
|
||||
return renderTypstSvg(withTypstFontDirs(request, typstFontDirs));
|
||||
},
|
||||
};
|
||||
@@ -67,7 +67,7 @@ export function createApplicationMenu() {
|
||||
click: async () => {
|
||||
await initAndShowMainWindow();
|
||||
// fixme: if the window is just created, the new page action will not be triggered
|
||||
applicationMenuSubjects.newPageAction$.next('page');
|
||||
applicationMenuSubjects.newPageAction$.next('default');
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { MainEventRegister } from '../type';
|
||||
import { applicationMenuSubjects } from './subject';
|
||||
import { applicationMenuSubjects, type NewPageAction } from './subject';
|
||||
|
||||
export * from './create';
|
||||
export * from './subject';
|
||||
@@ -11,7 +11,7 @@ export const applicationMenuEvents = {
|
||||
/**
|
||||
* File -> New Doc
|
||||
*/
|
||||
onNewPageAction: (fn: (type: 'page' | 'edgeless') => void) => {
|
||||
onNewPageAction: (fn: (type: NewPageAction) => void) => {
|
||||
const sub = applicationMenuSubjects.newPageAction$.subscribe(fn);
|
||||
return () => {
|
||||
sub.unsubscribe();
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import { Subject } from 'rxjs';
|
||||
|
||||
export type NewPageAction = 'page' | 'edgeless' | 'default';
|
||||
|
||||
export const applicationMenuSubjects = {
|
||||
newPageAction$: new Subject<'page' | 'edgeless'>(),
|
||||
newPageAction$: new Subject<NewPageAction>(),
|
||||
openJournal$: new Subject<void>(),
|
||||
openInSettingModal$: new Subject<{
|
||||
activeTab: string;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user