mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-03-24 16:18:39 +08:00
Compare commits
31 Commits
v2026.3.9-
...
d7adbb99c9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d7adbb99c9 | ||
|
|
6a93566422 | ||
|
|
7ac8b14b65 | ||
|
|
16a8f17717 | ||
|
|
1ffb8c922c | ||
|
|
daf536f77a | ||
|
|
0d2d4bb6a1 | ||
|
|
cb9897d493 | ||
|
|
8ca8333cd6 | ||
|
|
3bf2503f55 | ||
|
|
59fd942f40 | ||
|
|
d6d5ae6182 | ||
|
|
c1a09b951f | ||
|
|
4ce68d74f1 | ||
|
|
fbfcc01d14 | ||
|
|
1112a06623 | ||
|
|
bbcb7e69fe | ||
|
|
cc2f23339e | ||
|
|
31101a69e7 | ||
|
|
0b1a44863f | ||
|
|
8406f9656e | ||
|
|
121c0d172d | ||
|
|
8f03090780 | ||
|
|
8125cc0e75 | ||
|
|
f537a75f01 | ||
|
|
9456a07889 | ||
|
|
8f571ddc30 | ||
|
|
13ad1beb10 | ||
|
|
9844ca4d54 | ||
|
|
d7d67841b8 | ||
|
|
29a27b561b |
@@ -19,3 +19,8 @@ rustflags = [
|
||||
# pthread_key_create() destructors and segfault after a DSO unloading
|
||||
[target.'cfg(all(target_env = "gnu", not(target_os = "windows")))']
|
||||
rustflags = ["-C", "link-args=-Wl,-z,nodelete"]
|
||||
|
||||
# Temporary local llm_adapter override.
|
||||
# Uncomment when verifying AFFiNE against the sibling llm_adapter workspace.
|
||||
# [patch.crates-io]
|
||||
# llm_adapter = { path = "../llm_adapter" }
|
||||
|
||||
4
.github/helm/affine/charts/front/values.yaml
vendored
4
.github/helm/affine/charts/front/values.yaml
vendored
@@ -31,10 +31,10 @@ podSecurityContext:
|
||||
resources:
|
||||
limits:
|
||||
cpu: '1'
|
||||
memory: 4Gi
|
||||
memory: 6Gi
|
||||
requests:
|
||||
cpu: '1'
|
||||
memory: 2Gi
|
||||
memory: 4Gi
|
||||
|
||||
probe:
|
||||
initialDelaySeconds: 20
|
||||
|
||||
4
.github/renovate.json
vendored
4
.github/renovate.json
vendored
@@ -63,7 +63,7 @@
|
||||
"groupName": "opentelemetry",
|
||||
"matchPackageNames": [
|
||||
"/^@opentelemetry/",
|
||||
"/^@google-cloud\/opentelemetry-/"
|
||||
"/^@google-cloud/opentelemetry-/"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -79,7 +79,7 @@
|
||||
"customManagers": [
|
||||
{
|
||||
"customType": "regex",
|
||||
"fileMatch": ["^rust-toolchain\\.toml?$"],
|
||||
"managerFilePatterns": ["/^rust-toolchain\\.toml?$/"],
|
||||
"matchStrings": [
|
||||
"channel\\s*=\\s*\"(?<currentValue>\\d+\\.\\d+(\\.\\d+)?)\""
|
||||
],
|
||||
|
||||
942
.yarn/releases/yarn-4.12.0.cjs
vendored
942
.yarn/releases/yarn-4.12.0.cjs
vendored
File diff suppressed because one or more lines are too long
940
.yarn/releases/yarn-4.13.0.cjs
vendored
Executable file
940
.yarn/releases/yarn-4.13.0.cjs
vendored
Executable file
File diff suppressed because one or more lines are too long
@@ -12,4 +12,4 @@ npmPublishAccess: public
|
||||
|
||||
npmRegistryServer: "https://registry.npmjs.org"
|
||||
|
||||
yarnPath: .yarn/releases/yarn-4.12.0.cjs
|
||||
yarnPath: .yarn/releases/yarn-4.13.0.cjs
|
||||
|
||||
3130
Cargo.lock
generated
3130
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
14
Cargo.toml
14
Cargo.toml
@@ -36,7 +36,7 @@ resolver = "3"
|
||||
criterion2 = { version = "3", default-features = false }
|
||||
crossbeam-channel = "0.5"
|
||||
dispatch2 = "0.3"
|
||||
docx-parser = { git = "https://github.com/toeverything/docx-parser" }
|
||||
docx-parser = { git = "https://github.com/toeverything/docx-parser", rev = "380beea" }
|
||||
dotenvy = "0.15"
|
||||
file-format = { version = "0.28", features = ["reader"] }
|
||||
homedir = "0.3"
|
||||
@@ -53,11 +53,13 @@ resolver = "3"
|
||||
libc = "0.2"
|
||||
libwebp-sys = "0.14.2"
|
||||
little_exif = "0.6.23"
|
||||
llm_adapter = "0.1.1"
|
||||
llm_adapter = { version = "0.1.3", default-features = false }
|
||||
log = "0.4"
|
||||
loom = { version = "0.7", features = ["checkpoint"] }
|
||||
lru = "0.16"
|
||||
matroska = "0.30"
|
||||
memory-indexer = "0.3.0"
|
||||
mermaid-rs-renderer = { git = "https://github.com/toeverything/mermaid-rs-renderer", rev = "fba9097", default-features = false }
|
||||
mimalloc = "0.1"
|
||||
mp4parse = "0.17"
|
||||
nanoid = "0.4"
|
||||
@@ -121,6 +123,14 @@ resolver = "3"
|
||||
tree-sitter-rust = { version = "0.24" }
|
||||
tree-sitter-scala = { version = "0.24" }
|
||||
tree-sitter-typescript = { version = "0.23" }
|
||||
typst = "0.14.2"
|
||||
typst-as-lib = { version = "0.15.4", default-features = false, features = [
|
||||
"packages",
|
||||
"typst-kit-embed-fonts",
|
||||
"typst-kit-fonts",
|
||||
"ureq",
|
||||
] }
|
||||
typst-svg = "0.14.2"
|
||||
uniffi = "0.29"
|
||||
url = { version = "2.5" }
|
||||
uuid = "1.8"
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
|
||||
|
||||
exports[`snapshot to markdown > imports obsidian vault fixtures 1`] = `
|
||||
{
|
||||
"entry": {
|
||||
"children": [
|
||||
{
|
||||
"children": [
|
||||
{
|
||||
"children": [
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"insert": "Panel
|
||||
Body line",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "text",
|
||||
},
|
||||
],
|
||||
"emoji": "💡",
|
||||
"flavour": "affine:callout",
|
||||
},
|
||||
{
|
||||
"flavour": "affine:attachment",
|
||||
"name": "archive.zip",
|
||||
"style": "horizontalThin",
|
||||
},
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"footnote": {
|
||||
"label": "1",
|
||||
"reference": {
|
||||
"title": "reference body",
|
||||
"type": "url",
|
||||
},
|
||||
},
|
||||
"insert": " ",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"flavour": "affine:divider",
|
||||
},
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"insert": "after note",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"insert": " ",
|
||||
"reference": {
|
||||
"page": "linked",
|
||||
"type": "LinkedPage",
|
||||
},
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"delta": [
|
||||
{
|
||||
"insert": "Sources",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:paragraph",
|
||||
"type": "h6",
|
||||
},
|
||||
{
|
||||
"flavour": "affine:bookmark",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:note",
|
||||
},
|
||||
],
|
||||
"flavour": "affine:page",
|
||||
},
|
||||
"titles": [
|
||||
"entry",
|
||||
"linked",
|
||||
],
|
||||
}
|
||||
`;
|
||||
@@ -0,0 +1,14 @@
|
||||
> [!custom] Panel
|
||||
> Body line
|
||||
|
||||
![[archive.zip]]
|
||||
|
||||
[^1]
|
||||
|
||||
---
|
||||
|
||||
after note
|
||||
|
||||
[[linked]]
|
||||
|
||||
[^1]: reference body
|
||||
@@ -0,0 +1 @@
|
||||
plain linked page
|
||||
@@ -1,4 +1,10 @@
|
||||
import { MarkdownTransformer } from '@blocksuite/affine/widgets/linked-doc';
|
||||
import { readFileSync } from 'node:fs';
|
||||
import { basename, resolve } from 'node:path';
|
||||
|
||||
import {
|
||||
MarkdownTransformer,
|
||||
ObsidianTransformer,
|
||||
} from '@blocksuite/affine/widgets/linked-doc';
|
||||
import {
|
||||
DefaultTheme,
|
||||
NoteDisplayMode,
|
||||
@@ -8,13 +14,18 @@ import {
|
||||
CalloutAdmonitionType,
|
||||
CalloutExportStyle,
|
||||
calloutMarkdownExportMiddleware,
|
||||
docLinkBaseURLMiddleware,
|
||||
embedSyncedDocMiddleware,
|
||||
MarkdownAdapter,
|
||||
titleMiddleware,
|
||||
} from '@blocksuite/affine-shared/adapters';
|
||||
import type { AffineTextAttributes } from '@blocksuite/affine-shared/types';
|
||||
import type {
|
||||
BlockSnapshot,
|
||||
DeltaInsert,
|
||||
DocSnapshot,
|
||||
SliceSnapshot,
|
||||
Store,
|
||||
TransformerMiddleware,
|
||||
} from '@blocksuite/store';
|
||||
import { AssetsManager, MemoryBlobCRUD, Schema } from '@blocksuite/store';
|
||||
@@ -29,6 +40,138 @@ import { testStoreExtensions } from '../utils/store.js';
|
||||
|
||||
const provider = getProvider();
|
||||
|
||||
function withRelativePath(file: File, relativePath: string): File {
|
||||
Object.defineProperty(file, 'webkitRelativePath', {
|
||||
value: relativePath,
|
||||
writable: false,
|
||||
});
|
||||
return file;
|
||||
}
|
||||
|
||||
function markdownFixture(relativePath: string): File {
|
||||
return withRelativePath(
|
||||
new File(
|
||||
[
|
||||
readFileSync(
|
||||
resolve(import.meta.dirname, 'fixtures/obsidian', relativePath),
|
||||
'utf8'
|
||||
),
|
||||
],
|
||||
basename(relativePath),
|
||||
{ type: 'text/markdown' }
|
||||
),
|
||||
`vault/${relativePath}`
|
||||
);
|
||||
}
|
||||
|
||||
function exportSnapshot(doc: Store): DocSnapshot {
|
||||
const job = doc.getTransformer([
|
||||
docLinkBaseURLMiddleware(doc.workspace.id),
|
||||
titleMiddleware(doc.workspace.meta.docMetas),
|
||||
]);
|
||||
const snapshot = job.docToSnapshot(doc);
|
||||
expect(snapshot).toBeTruthy();
|
||||
return snapshot!;
|
||||
}
|
||||
|
||||
function normalizeDeltaForSnapshot(
|
||||
delta: DeltaInsert<AffineTextAttributes>[],
|
||||
titleById: ReadonlyMap<string, string>
|
||||
) {
|
||||
return delta.map(item => {
|
||||
const normalized: Record<string, unknown> = {
|
||||
insert: item.insert,
|
||||
};
|
||||
|
||||
if (item.attributes?.link) {
|
||||
normalized.link = item.attributes.link;
|
||||
}
|
||||
|
||||
if (item.attributes?.reference?.type === 'LinkedPage') {
|
||||
normalized.reference = {
|
||||
type: 'LinkedPage',
|
||||
page: titleById.get(item.attributes.reference.pageId) ?? '<missing>',
|
||||
...(item.attributes.reference.title
|
||||
? { title: item.attributes.reference.title }
|
||||
: {}),
|
||||
};
|
||||
}
|
||||
|
||||
if (item.attributes?.footnote) {
|
||||
const reference = item.attributes.footnote.reference;
|
||||
normalized.footnote = {
|
||||
label: item.attributes.footnote.label,
|
||||
reference:
|
||||
reference.type === 'doc'
|
||||
? {
|
||||
type: 'doc',
|
||||
page: reference.docId
|
||||
? (titleById.get(reference.docId) ?? '<missing>')
|
||||
: '<missing>',
|
||||
}
|
||||
: {
|
||||
type: reference.type,
|
||||
...(reference.title ? { title: reference.title } : {}),
|
||||
...(reference.fileName ? { fileName: reference.fileName } : {}),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return normalized;
|
||||
});
|
||||
}
|
||||
|
||||
function simplifyBlockForSnapshot(
|
||||
block: BlockSnapshot,
|
||||
titleById: ReadonlyMap<string, string>
|
||||
): Record<string, unknown> {
|
||||
const simplified: Record<string, unknown> = {
|
||||
flavour: block.flavour,
|
||||
};
|
||||
|
||||
if (block.flavour === 'affine:paragraph' || block.flavour === 'affine:list') {
|
||||
simplified.type = block.props.type;
|
||||
const text = block.props.text as
|
||||
| { delta?: DeltaInsert<AffineTextAttributes>[] }
|
||||
| undefined;
|
||||
simplified.delta = normalizeDeltaForSnapshot(text?.delta ?? [], titleById);
|
||||
}
|
||||
|
||||
if (block.flavour === 'affine:callout') {
|
||||
simplified.emoji = block.props.emoji;
|
||||
}
|
||||
|
||||
if (block.flavour === 'affine:attachment') {
|
||||
simplified.name = block.props.name;
|
||||
simplified.style = block.props.style;
|
||||
}
|
||||
|
||||
if (block.flavour === 'affine:image') {
|
||||
simplified.sourceId = '<asset>';
|
||||
}
|
||||
|
||||
const children = (block.children ?? [])
|
||||
.filter(child => child.flavour !== 'affine:surface')
|
||||
.map(child => simplifyBlockForSnapshot(child, titleById));
|
||||
if (children.length) {
|
||||
simplified.children = children;
|
||||
}
|
||||
|
||||
return simplified;
|
||||
}
|
||||
|
||||
function snapshotDocByTitle(
|
||||
collection: TestWorkspace,
|
||||
title: string,
|
||||
titleById: ReadonlyMap<string, string>
|
||||
) {
|
||||
const meta = collection.meta.docMetas.find(meta => meta.title === title);
|
||||
expect(meta).toBeTruthy();
|
||||
const doc = collection.getDoc(meta!.id)?.getStore({ id: meta!.id });
|
||||
expect(doc).toBeTruthy();
|
||||
return simplifyBlockForSnapshot(exportSnapshot(doc!).blocks, titleById);
|
||||
}
|
||||
|
||||
describe('snapshot to markdown', () => {
|
||||
test('code', async () => {
|
||||
const blockSnapshot: BlockSnapshot = {
|
||||
@@ -127,6 +270,46 @@ Hello world
|
||||
expect(meta?.tags).toEqual(['a', 'b']);
|
||||
});
|
||||
|
||||
test('imports obsidian vault fixtures', async () => {
|
||||
const schema = new Schema().register(AffineSchemas);
|
||||
const collection = new TestWorkspace();
|
||||
collection.storeExtensions = testStoreExtensions;
|
||||
collection.meta.initialize();
|
||||
|
||||
const attachment = withRelativePath(
|
||||
new File([new Uint8Array([80, 75, 3, 4])], 'archive.zip', {
|
||||
type: 'application/zip',
|
||||
}),
|
||||
'vault/archive.zip'
|
||||
);
|
||||
|
||||
const { docIds } = await ObsidianTransformer.importObsidianVault({
|
||||
collection,
|
||||
schema,
|
||||
importedFiles: [
|
||||
markdownFixture('entry.md'),
|
||||
markdownFixture('linked.md'),
|
||||
attachment,
|
||||
],
|
||||
extensions: testStoreExtensions,
|
||||
});
|
||||
expect(docIds).toHaveLength(2);
|
||||
|
||||
const titleById = new Map(
|
||||
collection.meta.docMetas.map(meta => [
|
||||
meta.id,
|
||||
meta.title ?? '<untitled>',
|
||||
])
|
||||
);
|
||||
|
||||
expect({
|
||||
titles: collection.meta.docMetas
|
||||
.map(meta => meta.title)
|
||||
.sort((a, b) => (a ?? '').localeCompare(b ?? '')),
|
||||
entry: snapshotDocByTitle(collection, 'entry', titleById),
|
||||
}).toMatchSnapshot();
|
||||
});
|
||||
|
||||
test('paragraph', async () => {
|
||||
const blockSnapshot: BlockSnapshot = {
|
||||
type: 'block',
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
import {
|
||||
BlockMarkdownAdapterExtension,
|
||||
type BlockMarkdownAdapterMatcher,
|
||||
createAttachmentBlockSnapshot,
|
||||
FOOTNOTE_DEFINITION_PREFIX,
|
||||
getFootnoteDefinitionText,
|
||||
isFootnoteDefinitionNode,
|
||||
@@ -56,18 +57,15 @@ export const attachmentBlockMarkdownAdapterMatcher: BlockMarkdownAdapterMatcher
|
||||
}
|
||||
walkerContext
|
||||
.openNode(
|
||||
{
|
||||
type: 'block',
|
||||
createAttachmentBlockSnapshot({
|
||||
id: nanoid(),
|
||||
flavour: AttachmentBlockSchema.model.flavour,
|
||||
props: {
|
||||
name: fileName,
|
||||
sourceId: blobId,
|
||||
footnoteIdentifier,
|
||||
style: 'citation',
|
||||
},
|
||||
children: [],
|
||||
},
|
||||
}),
|
||||
'children'
|
||||
)
|
||||
.closeNode();
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@vitest/browser-playwright": "^4.0.18",
|
||||
"playwright": "=1.58.2",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"exports": {
|
||||
|
||||
@@ -516,6 +516,9 @@ export const EdgelessNoteInteraction =
|
||||
}
|
||||
})
|
||||
.catch(console.error);
|
||||
} else if (multiSelect && alreadySelected && editing) {
|
||||
// range selection using Shift-click when editing
|
||||
return;
|
||||
} else {
|
||||
context.default(context);
|
||||
}
|
||||
|
||||
@@ -83,9 +83,9 @@ export class RecordField extends SignalWatcher(
|
||||
border: 1px solid transparent;
|
||||
}
|
||||
|
||||
.field-content .affine-database-number {
|
||||
.field-content affine-database-number-cell .number {
|
||||
text-align: left;
|
||||
justify-content: start;
|
||||
justify-content: flex-start;
|
||||
}
|
||||
|
||||
.field-content:hover {
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@vitest/browser-playwright": "^4.0.18",
|
||||
"playwright": "=1.58.2",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"exports": {
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import { AttachmentBlockSchema } from '@blocksuite/affine-model';
|
||||
import {
|
||||
type AttachmentBlockProps,
|
||||
AttachmentBlockSchema,
|
||||
} from '@blocksuite/affine-model';
|
||||
import { BlockSuiteError, ErrorCode } from '@blocksuite/global/exceptions';
|
||||
import {
|
||||
type AssetsManager,
|
||||
@@ -23,6 +26,24 @@ import { AdapterFactoryIdentifier } from './types/adapter';
|
||||
|
||||
export type Attachment = File[];
|
||||
|
||||
type CreateAttachmentBlockSnapshotOptions = {
|
||||
id?: string;
|
||||
props: Partial<AttachmentBlockProps> & Pick<AttachmentBlockProps, 'name'>;
|
||||
};
|
||||
|
||||
export function createAttachmentBlockSnapshot({
|
||||
id = nanoid(),
|
||||
props,
|
||||
}: CreateAttachmentBlockSnapshotOptions): BlockSnapshot {
|
||||
return {
|
||||
type: 'block',
|
||||
id,
|
||||
flavour: AttachmentBlockSchema.model.flavour,
|
||||
props,
|
||||
children: [],
|
||||
};
|
||||
}
|
||||
|
||||
type AttachmentToSliceSnapshotPayload = {
|
||||
file: Attachment;
|
||||
assets?: AssetsManager;
|
||||
@@ -97,8 +118,6 @@ export class AttachmentAdapter extends BaseAdapter<Attachment> {
|
||||
if (files.length === 0) return null;
|
||||
|
||||
const content: SliceSnapshot['content'] = [];
|
||||
const flavour = AttachmentBlockSchema.model.flavour;
|
||||
|
||||
for (const blob of files) {
|
||||
const id = nanoid();
|
||||
const { name, size, type } = blob;
|
||||
@@ -108,22 +127,21 @@ export class AttachmentAdapter extends BaseAdapter<Attachment> {
|
||||
mapInto: sourceId => ({ sourceId }),
|
||||
});
|
||||
|
||||
content.push({
|
||||
type: 'block',
|
||||
flavour,
|
||||
id,
|
||||
props: {
|
||||
name,
|
||||
size,
|
||||
type,
|
||||
embed: false,
|
||||
style: 'horizontalThin',
|
||||
index: 'a0',
|
||||
xywh: '[0,0,0,0]',
|
||||
rotate: 0,
|
||||
},
|
||||
children: [],
|
||||
});
|
||||
content.push(
|
||||
createAttachmentBlockSnapshot({
|
||||
id,
|
||||
props: {
|
||||
name,
|
||||
size,
|
||||
type,
|
||||
embed: false,
|
||||
style: 'horizontalThin',
|
||||
index: 'a0',
|
||||
xywh: '[0,0,0,0]',
|
||||
rotate: 0,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
function safeDecodePathReference(path: string): string {
|
||||
try {
|
||||
return decodeURIComponent(path);
|
||||
} catch {
|
||||
return path;
|
||||
}
|
||||
}
|
||||
|
||||
export function normalizeFilePathReference(path: string): string {
|
||||
return safeDecodePathReference(path)
|
||||
.trim()
|
||||
.replace(/\\/g, '/')
|
||||
.replace(/^\.\/+/, '')
|
||||
.replace(/^\/+/, '')
|
||||
.replace(/\/+/g, '/');
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalizes a relative path by resolving all relative path segments
|
||||
* @param basePath The base path (markdown file's directory)
|
||||
@@ -40,7 +57,7 @@ export function getImageFullPath(
|
||||
imageReference: string
|
||||
): string {
|
||||
// Decode the image reference in case it contains URL-encoded characters
|
||||
const decodedReference = decodeURIComponent(imageReference);
|
||||
const decodedReference = safeDecodePathReference(imageReference);
|
||||
|
||||
// Get the directory of the file path
|
||||
const markdownDir = filePath.substring(0, filePath.lastIndexOf('/'));
|
||||
|
||||
@@ -20,9 +20,30 @@ declare global {
|
||||
showOpenFilePicker?: (
|
||||
options?: OpenFilePickerOptions
|
||||
) => Promise<FileSystemFileHandle[]>;
|
||||
// Window API: showDirectoryPicker
|
||||
showDirectoryPicker?: (options?: {
|
||||
id?: string;
|
||||
mode?: 'read' | 'readwrite';
|
||||
startIn?: FileSystemHandle | string;
|
||||
}) => Promise<FileSystemDirectoryHandle>;
|
||||
}
|
||||
}
|
||||
|
||||
// Minimal polyfill for FileSystemDirectoryHandle to iterate over files
|
||||
interface FileSystemDirectoryHandle {
|
||||
kind: 'directory';
|
||||
name: string;
|
||||
values(): AsyncIterableIterator<
|
||||
FileSystemFileHandle | FileSystemDirectoryHandle
|
||||
>;
|
||||
}
|
||||
|
||||
interface FileSystemFileHandle {
|
||||
kind: 'file';
|
||||
name: string;
|
||||
getFile(): Promise<File>;
|
||||
}
|
||||
|
||||
// See [Common MIME types](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types)
|
||||
const FileTypes: NonNullable<OpenFilePickerOptions['types']> = [
|
||||
{
|
||||
@@ -121,21 +142,27 @@ type AcceptTypes =
|
||||
| 'Docx'
|
||||
| 'MindMap';
|
||||
|
||||
export async function openFilesWith(
|
||||
acceptType: AcceptTypes = 'Any',
|
||||
multiple: boolean = true
|
||||
): Promise<File[] | null> {
|
||||
// Feature detection. The API needs to be supported
|
||||
// and the app not run in an iframe.
|
||||
const supportsFileSystemAccess =
|
||||
'showOpenFilePicker' in window &&
|
||||
function canUseFileSystemAccessAPI(
|
||||
api: 'showOpenFilePicker' | 'showDirectoryPicker'
|
||||
) {
|
||||
return (
|
||||
api in window &&
|
||||
(() => {
|
||||
try {
|
||||
return window.self === window.top;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
})();
|
||||
})()
|
||||
);
|
||||
}
|
||||
|
||||
export async function openFilesWith(
|
||||
acceptType: AcceptTypes = 'Any',
|
||||
multiple: boolean = true
|
||||
): Promise<File[] | null> {
|
||||
const supportsFileSystemAccess =
|
||||
canUseFileSystemAccessAPI('showOpenFilePicker');
|
||||
|
||||
// If the File System Access API is supported…
|
||||
if (supportsFileSystemAccess && window.showOpenFilePicker) {
|
||||
@@ -194,6 +221,75 @@ export async function openFilesWith(
|
||||
});
|
||||
}
|
||||
|
||||
export async function openDirectory(): Promise<File[] | null> {
|
||||
const supportsFileSystemAccess = canUseFileSystemAccessAPI(
|
||||
'showDirectoryPicker'
|
||||
);
|
||||
|
||||
if (supportsFileSystemAccess && window.showDirectoryPicker) {
|
||||
try {
|
||||
const dirHandle = await window.showDirectoryPicker();
|
||||
const files: File[] = [];
|
||||
|
||||
const readDirectory = async (
|
||||
directoryHandle: FileSystemDirectoryHandle,
|
||||
path: string
|
||||
) => {
|
||||
for await (const handle of directoryHandle.values()) {
|
||||
const relativePath = path ? `${path}/${handle.name}` : handle.name;
|
||||
if (handle.kind === 'file') {
|
||||
const fileHandle = handle as FileSystemFileHandle;
|
||||
if (fileHandle.getFile) {
|
||||
const file = await fileHandle.getFile();
|
||||
Object.defineProperty(file, 'webkitRelativePath', {
|
||||
value: relativePath,
|
||||
writable: false,
|
||||
});
|
||||
files.push(file);
|
||||
}
|
||||
} else if (handle.kind === 'directory') {
|
||||
await readDirectory(
|
||||
handle as FileSystemDirectoryHandle,
|
||||
relativePath
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
await readDirectory(dirHandle, '');
|
||||
return files;
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return new Promise(resolve => {
|
||||
const input = document.createElement('input');
|
||||
input.classList.add('affine-upload-input');
|
||||
input.style.display = 'none';
|
||||
input.type = 'file';
|
||||
|
||||
input.setAttribute('webkitdirectory', '');
|
||||
input.setAttribute('directory', '');
|
||||
|
||||
document.body.append(input);
|
||||
|
||||
input.addEventListener('change', () => {
|
||||
input.remove();
|
||||
resolve(input.files ? Array.from(input.files) : null);
|
||||
});
|
||||
|
||||
input.addEventListener('cancel', () => resolve(null));
|
||||
|
||||
if ('showPicker' in HTMLInputElement.prototype) {
|
||||
input.showPicker();
|
||||
} else {
|
||||
input.click();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
export async function openSingleFileWith(
|
||||
acceptType?: AcceptTypes
|
||||
): Promise<File | null> {
|
||||
|
||||
@@ -17,7 +17,14 @@ export async function printToPdf(
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
const iframe = document.createElement('iframe');
|
||||
document.body.append(iframe);
|
||||
iframe.style.display = 'none';
|
||||
// Use a hidden but rendering-enabled state instead of display: none
|
||||
Object.assign(iframe.style, {
|
||||
visibility: 'hidden',
|
||||
position: 'absolute',
|
||||
width: '0',
|
||||
height: '0',
|
||||
border: 'none',
|
||||
});
|
||||
iframe.srcdoc = '<!DOCTYPE html>';
|
||||
iframe.onload = async () => {
|
||||
if (!iframe.contentWindow) {
|
||||
@@ -28,6 +35,44 @@ export async function printToPdf(
|
||||
reject(new Error('Root element not defined, unable to print pdf'));
|
||||
return;
|
||||
}
|
||||
|
||||
const doc = iframe.contentWindow.document;
|
||||
|
||||
doc.write(`<!DOCTYPE html><html><head><style>@media print {
|
||||
html, body {
|
||||
height: initial !important;
|
||||
overflow: initial !important;
|
||||
print-color-adjust: exact;
|
||||
-webkit-print-color-adjust: exact;
|
||||
color: #000 !important;
|
||||
background: #fff !important;
|
||||
color-scheme: light !important;
|
||||
}
|
||||
::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
:root, body {
|
||||
--affine-text-primary: #000 !important;
|
||||
--affine-text-secondary: #111 !important;
|
||||
--affine-text-tertiary: #333 !important;
|
||||
--affine-background-primary: #fff !important;
|
||||
--affine-background-secondary: #fff !important;
|
||||
--affine-background-tertiary: #fff !important;
|
||||
}
|
||||
body, [data-theme='dark'] {
|
||||
color: #000 !important;
|
||||
background: #fff !important;
|
||||
}
|
||||
body * {
|
||||
color: #000 !important;
|
||||
-webkit-text-fill-color: #000 !important;
|
||||
}
|
||||
:root {
|
||||
--affine-note-shadow-box: none !important;
|
||||
--affine-note-shadow-sticker: none !important;
|
||||
}
|
||||
}</style></head><body></body></html>`);
|
||||
doc.close();
|
||||
iframe.contentWindow.document
|
||||
.write(`<!DOCTYPE html><html><head><style>@media print {
|
||||
html, body {
|
||||
@@ -49,6 +94,9 @@ export async function printToPdf(
|
||||
--affine-background-primary: #fff !important;
|
||||
--affine-background-secondary: #fff !important;
|
||||
--affine-background-tertiary: #fff !important;
|
||||
--affine-background-code-block: #f5f5f5 !important;
|
||||
--affine-quote-color: #e3e3e3 !important;
|
||||
--affine-border-color: #e3e3e3 !important;
|
||||
}
|
||||
body, [data-theme='dark'] {
|
||||
color: #000 !important;
|
||||
@@ -68,7 +116,7 @@ export async function printToPdf(
|
||||
for (const element of document.styleSheets) {
|
||||
try {
|
||||
for (const cssRule of element.cssRules) {
|
||||
const target = iframe.contentWindow.document.styleSheets[0];
|
||||
const target = doc.styleSheets[0];
|
||||
target.insertRule(cssRule.cssText, target.cssRules.length);
|
||||
}
|
||||
} catch (e) {
|
||||
@@ -83,12 +131,33 @@ export async function printToPdf(
|
||||
}
|
||||
}
|
||||
|
||||
// Recursive function to find all canvases, including those in shadow roots
|
||||
const findAllCanvases = (root: Node): HTMLCanvasElement[] => {
|
||||
const canvases: HTMLCanvasElement[] = [];
|
||||
const traverse = (node: Node) => {
|
||||
if (node instanceof HTMLCanvasElement) {
|
||||
canvases.push(node);
|
||||
}
|
||||
if (node instanceof HTMLElement || node instanceof ShadowRoot) {
|
||||
node.childNodes.forEach(traverse);
|
||||
}
|
||||
if (node instanceof HTMLElement && node.shadowRoot) {
|
||||
traverse(node.shadowRoot);
|
||||
}
|
||||
};
|
||||
traverse(root);
|
||||
return canvases;
|
||||
};
|
||||
|
||||
// convert all canvas to image
|
||||
const canvasImgObjectUrlMap = new Map<string, string>();
|
||||
const allCanvas = rootElement.getElementsByTagName('canvas');
|
||||
const allCanvas = findAllCanvases(rootElement);
|
||||
let canvasKey = 1;
|
||||
const canvasToKeyMap = new Map<HTMLCanvasElement, string>();
|
||||
|
||||
for (const canvas of allCanvas) {
|
||||
canvas.dataset['printToPdfCanvasKey'] = canvasKey.toString();
|
||||
const key = canvasKey.toString();
|
||||
canvasToKeyMap.set(canvas, key);
|
||||
canvasKey++;
|
||||
const canvasImgObjectUrl = await new Promise<Blob | null>(resolve => {
|
||||
try {
|
||||
@@ -103,20 +172,42 @@ export async function printToPdf(
|
||||
);
|
||||
continue;
|
||||
}
|
||||
canvasImgObjectUrlMap.set(
|
||||
canvas.dataset['printToPdfCanvasKey'],
|
||||
URL.createObjectURL(canvasImgObjectUrl)
|
||||
);
|
||||
canvasImgObjectUrlMap.set(key, URL.createObjectURL(canvasImgObjectUrl));
|
||||
}
|
||||
|
||||
const importedRoot = iframe.contentWindow.document.importNode(
|
||||
rootElement,
|
||||
true
|
||||
) as HTMLDivElement;
|
||||
// Recursive deep clone that flattens Shadow DOM into Light DOM
|
||||
const deepCloneWithShadows = (node: Node): Node => {
|
||||
const clone = doc.importNode(node, false);
|
||||
|
||||
if (
|
||||
clone instanceof HTMLCanvasElement &&
|
||||
node instanceof HTMLCanvasElement
|
||||
) {
|
||||
const key = canvasToKeyMap.get(node);
|
||||
if (key) {
|
||||
clone.dataset['printToPdfCanvasKey'] = key;
|
||||
}
|
||||
}
|
||||
|
||||
const appendChildren = (source: Node) => {
|
||||
source.childNodes.forEach(child => {
|
||||
(clone as Element).append(deepCloneWithShadows(child));
|
||||
});
|
||||
};
|
||||
|
||||
if (node instanceof HTMLElement && node.shadowRoot) {
|
||||
appendChildren(node.shadowRoot);
|
||||
}
|
||||
appendChildren(node);
|
||||
|
||||
return clone;
|
||||
};
|
||||
|
||||
const importedRoot = deepCloneWithShadows(rootElement) as HTMLDivElement;
|
||||
|
||||
// force light theme in print iframe
|
||||
iframe.contentWindow.document.documentElement.dataset.theme = 'light';
|
||||
iframe.contentWindow.document.body.dataset.theme = 'light';
|
||||
doc.documentElement.dataset.theme = 'light';
|
||||
doc.body.dataset.theme = 'light';
|
||||
importedRoot.dataset.theme = 'light';
|
||||
|
||||
// draw saved canvas image to canvas
|
||||
@@ -135,17 +226,67 @@ export async function printToPdf(
|
||||
}
|
||||
}
|
||||
|
||||
// append to iframe and print
|
||||
iframe.contentWindow.document.body.append(importedRoot);
|
||||
// Remove lazy loading from all images and force reload
|
||||
const allImages = importedRoot.querySelectorAll('img');
|
||||
allImages.forEach(img => {
|
||||
img.removeAttribute('loading');
|
||||
const src = img.getAttribute('src');
|
||||
if (src) img.setAttribute('src', src);
|
||||
});
|
||||
|
||||
// append to iframe
|
||||
doc.body.append(importedRoot);
|
||||
|
||||
await options.beforeprint?.(iframe);
|
||||
|
||||
// browser may take some time to load font
|
||||
await new Promise<void>(resolve => {
|
||||
setTimeout(() => {
|
||||
resolve();
|
||||
}, 1000);
|
||||
});
|
||||
// Robust image waiting logic
|
||||
const waitForImages = async (container: HTMLElement) => {
|
||||
const images: HTMLImageElement[] = [];
|
||||
const view = container.ownerDocument.defaultView;
|
||||
if (!view) return;
|
||||
|
||||
const findImages = (root: Node) => {
|
||||
if (root instanceof view.HTMLImageElement) {
|
||||
images.push(root);
|
||||
}
|
||||
if (
|
||||
root instanceof view.HTMLElement ||
|
||||
root instanceof view.ShadowRoot
|
||||
) {
|
||||
root.childNodes.forEach(findImages);
|
||||
}
|
||||
if (root instanceof view.HTMLElement && root.shadowRoot) {
|
||||
findImages(root.shadowRoot);
|
||||
}
|
||||
};
|
||||
|
||||
findImages(container);
|
||||
|
||||
await Promise.all(
|
||||
images.map(img => {
|
||||
if (img.complete) {
|
||||
if (img.naturalWidth === 0) {
|
||||
console.warn('Image failed to load:', img.src);
|
||||
}
|
||||
return Promise.resolve();
|
||||
}
|
||||
return new Promise(resolve => {
|
||||
img.onload = resolve;
|
||||
img.onerror = resolve;
|
||||
});
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
await waitForImages(importedRoot);
|
||||
|
||||
// browser may take some time to load font or other resources
|
||||
await (doc.fonts?.ready ??
|
||||
new Promise<void>(resolve => {
|
||||
setTimeout(() => {
|
||||
resolve();
|
||||
}, 1000);
|
||||
}));
|
||||
|
||||
iframe.contentWindow.onafterprint = async () => {
|
||||
iframe.remove();
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
NotionIcon,
|
||||
} from '@blocksuite/affine-components/icons';
|
||||
import {
|
||||
openDirectory,
|
||||
openFilesWith,
|
||||
openSingleFileWith,
|
||||
} from '@blocksuite/affine-shared/utils';
|
||||
@@ -18,11 +19,16 @@ import { query, state } from 'lit/decorators.js';
|
||||
import { HtmlTransformer } from '../transformers/html.js';
|
||||
import { MarkdownTransformer } from '../transformers/markdown.js';
|
||||
import { NotionHtmlTransformer } from '../transformers/notion-html.js';
|
||||
import { ObsidianTransformer } from '../transformers/obsidian.js';
|
||||
import { styles } from './styles.js';
|
||||
|
||||
export type OnSuccessHandler = (
|
||||
pageIds: string[],
|
||||
options: { isWorkspaceFile: boolean; importedCount: number }
|
||||
options: {
|
||||
isWorkspaceFile: boolean;
|
||||
importedCount: number;
|
||||
docEmojis?: Map<string, string>;
|
||||
}
|
||||
) => void;
|
||||
|
||||
export type OnFailHandler = (message: string) => void;
|
||||
@@ -140,6 +146,29 @@ export class ImportDoc extends WithDisposable(LitElement) {
|
||||
});
|
||||
}
|
||||
|
||||
private async _importObsidian() {
|
||||
const files = await openDirectory();
|
||||
if (!files || files.length === 0) return;
|
||||
const needLoading =
|
||||
files.reduce((acc, f) => acc + f.size, 0) > SHOW_LOADING_SIZE;
|
||||
if (needLoading) {
|
||||
this.hidden = false;
|
||||
this._loading = true;
|
||||
} else {
|
||||
this.abortController.abort();
|
||||
}
|
||||
const { docIds, docEmojis } = await ObsidianTransformer.importObsidianVault(
|
||||
{
|
||||
collection: this.collection,
|
||||
schema: this.schema,
|
||||
importedFiles: files,
|
||||
extensions: this.extensions,
|
||||
}
|
||||
);
|
||||
needLoading && this.abortController.abort();
|
||||
this._onImportSuccess(docIds, { docEmojis });
|
||||
}
|
||||
|
||||
private _onCloseClick(event: MouseEvent) {
|
||||
event.stopPropagation();
|
||||
this.abortController.abort();
|
||||
@@ -151,15 +180,21 @@ export class ImportDoc extends WithDisposable(LitElement) {
|
||||
|
||||
private _onImportSuccess(
|
||||
pageIds: string[],
|
||||
options: { isWorkspaceFile?: boolean; importedCount?: number } = {}
|
||||
options: {
|
||||
isWorkspaceFile?: boolean;
|
||||
importedCount?: number;
|
||||
docEmojis?: Map<string, string>;
|
||||
} = {}
|
||||
) {
|
||||
const {
|
||||
isWorkspaceFile = false,
|
||||
importedCount: pagesImportedCount = pageIds.length,
|
||||
docEmojis,
|
||||
} = options;
|
||||
this.onSuccess?.(pageIds, {
|
||||
isWorkspaceFile,
|
||||
importedCount: pagesImportedCount,
|
||||
docEmojis,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -258,6 +293,13 @@ export class ImportDoc extends WithDisposable(LitElement) {
|
||||
</affine-tooltip>
|
||||
</div>
|
||||
</icon-button>
|
||||
<icon-button
|
||||
class="button-item"
|
||||
text="Obsidian"
|
||||
@click="${this._importObsidian}"
|
||||
>
|
||||
${ExportToMarkdownIcon}
|
||||
</icon-button>
|
||||
<icon-button class="button-item" text="Coming soon..." disabled>
|
||||
${NewIcon}
|
||||
</icon-button>
|
||||
|
||||
@@ -2,6 +2,7 @@ export { DocxTransformer } from './docx.js';
|
||||
export { HtmlTransformer } from './html.js';
|
||||
export { MarkdownTransformer } from './markdown.js';
|
||||
export { NotionHtmlTransformer } from './notion-html.js';
|
||||
export { ObsidianTransformer } from './obsidian.js';
|
||||
export { PdfTransformer } from './pdf.js';
|
||||
export { createAssetsArchive, download } from './utils.js';
|
||||
export { ZipTransformer } from './zip.js';
|
||||
|
||||
@@ -21,8 +21,11 @@ import { extMimeMap, Transformer } from '@blocksuite/store';
|
||||
import type { AssetMap, ImportedFileEntry, PathBlobIdMap } from './type.js';
|
||||
import { createAssetsArchive, download, parseMatter, Unzip } from './utils.js';
|
||||
|
||||
type ParsedFrontmatterMeta = Partial<
|
||||
Pick<DocMeta, 'title' | 'createDate' | 'updatedDate' | 'tags' | 'favorite'>
|
||||
export type ParsedFrontmatterMeta = Partial<
|
||||
Pick<
|
||||
DocMeta,
|
||||
'title' | 'createDate' | 'updatedDate' | 'tags' | 'favorite' | 'trash'
|
||||
>
|
||||
>;
|
||||
|
||||
const FRONTMATTER_KEYS = {
|
||||
@@ -150,11 +153,18 @@ function buildMetaFromFrontmatter(
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (FRONTMATTER_KEYS.trash.includes(key)) {
|
||||
const trash = parseBoolean(value);
|
||||
if (trash !== undefined) {
|
||||
meta.trash = trash;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return meta;
|
||||
}
|
||||
|
||||
function parseFrontmatter(markdown: string): {
|
||||
export function parseFrontmatter(markdown: string): {
|
||||
content: string;
|
||||
meta: ParsedFrontmatterMeta;
|
||||
} {
|
||||
@@ -176,7 +186,7 @@ function parseFrontmatter(markdown: string): {
|
||||
}
|
||||
}
|
||||
|
||||
function applyMetaPatch(
|
||||
export function applyMetaPatch(
|
||||
collection: Workspace,
|
||||
docId: string,
|
||||
meta: ParsedFrontmatterMeta
|
||||
@@ -187,13 +197,14 @@ function applyMetaPatch(
|
||||
if (meta.updatedDate !== undefined) metaPatch.updatedDate = meta.updatedDate;
|
||||
if (meta.tags) metaPatch.tags = meta.tags;
|
||||
if (meta.favorite !== undefined) metaPatch.favorite = meta.favorite;
|
||||
if (meta.trash !== undefined) metaPatch.trash = meta.trash;
|
||||
|
||||
if (Object.keys(metaPatch).length) {
|
||||
collection.meta.setDocMeta(docId, metaPatch);
|
||||
}
|
||||
}
|
||||
|
||||
function getProvider(extensions: ExtensionType[]) {
|
||||
export function getProvider(extensions: ExtensionType[]) {
|
||||
const container = new Container();
|
||||
extensions.forEach(ext => {
|
||||
ext.setup(container);
|
||||
@@ -223,6 +234,103 @@ type ImportMarkdownZipOptions = {
|
||||
extensions: ExtensionType[];
|
||||
};
|
||||
|
||||
/**
|
||||
* Filters hidden/system entries that should never participate in imports.
|
||||
*/
|
||||
export function isSystemImportPath(path: string) {
|
||||
return path.includes('__MACOSX') || path.includes('.DS_Store');
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the doc CRUD bridge used by importer transformers.
|
||||
*/
|
||||
export function createCollectionDocCRUD(collection: Workspace) {
|
||||
return {
|
||||
create: (id: string) => collection.createDoc(id).getStore({ id }),
|
||||
get: (id: string) => collection.getDoc(id)?.getStore({ id }) ?? null,
|
||||
delete: (id: string) => collection.removeDoc(id),
|
||||
};
|
||||
}
|
||||
|
||||
type CreateMarkdownImportJobOptions = {
|
||||
collection: Workspace;
|
||||
schema: Schema;
|
||||
preferredTitle?: string;
|
||||
fullPath?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a markdown import job with the standard collection middlewares.
|
||||
*/
|
||||
export function createMarkdownImportJob({
|
||||
collection,
|
||||
schema,
|
||||
preferredTitle,
|
||||
fullPath,
|
||||
}: CreateMarkdownImportJobOptions) {
|
||||
return new Transformer({
|
||||
schema,
|
||||
blobCRUD: collection.blobSync,
|
||||
docCRUD: createCollectionDocCRUD(collection),
|
||||
middlewares: [
|
||||
defaultImageProxyMiddleware,
|
||||
fileNameMiddleware(preferredTitle),
|
||||
docLinkBaseURLMiddleware(collection.id),
|
||||
...(fullPath ? [filePathMiddleware(fullPath)] : []),
|
||||
],
|
||||
});
|
||||
}
|
||||
|
||||
type StageImportedAssetOptions = {
|
||||
pendingAssets: AssetMap;
|
||||
pendingPathBlobIdMap: PathBlobIdMap;
|
||||
path: string;
|
||||
content: Blob;
|
||||
fileName: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Hashes a non-markdown import file and stages it into the shared asset maps.
|
||||
*/
|
||||
export async function stageImportedAsset({
|
||||
pendingAssets,
|
||||
pendingPathBlobIdMap,
|
||||
path,
|
||||
content,
|
||||
fileName,
|
||||
}: StageImportedAssetOptions) {
|
||||
const ext = path.split('.').at(-1) ?? '';
|
||||
const mime = extMimeMap.get(ext.toLowerCase()) ?? '';
|
||||
const key = await sha(await content.arrayBuffer());
|
||||
pendingPathBlobIdMap.set(path, key);
|
||||
pendingAssets.set(key, new File([content], fileName, { type: mime }));
|
||||
}
|
||||
|
||||
/**
|
||||
* Binds previously staged asset files into a transformer job before import.
|
||||
*/
|
||||
export function bindImportedAssetsToJob(
|
||||
job: Transformer,
|
||||
pendingAssets: AssetMap,
|
||||
pendingPathBlobIdMap: PathBlobIdMap
|
||||
) {
|
||||
const pathBlobIdMap = job.assetsManager.getPathBlobIdMap();
|
||||
// Iterate over all assets to be imported
|
||||
for (const [assetPath, key] of pendingPathBlobIdMap.entries()) {
|
||||
// Get the relative path of the asset to the markdown file
|
||||
// Store the path to blobId map
|
||||
pathBlobIdMap.set(assetPath, key);
|
||||
// Store the asset to assets, the key is the blobId, the value is the file object
|
||||
// In block adapter, it will use the blobId to get the file object
|
||||
const assetFile = pendingAssets.get(key);
|
||||
if (assetFile) {
|
||||
job.assets.set(key, assetFile);
|
||||
}
|
||||
}
|
||||
|
||||
return pathBlobIdMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Exports a doc to a Markdown file or a zip archive containing Markdown and assets.
|
||||
* @param doc The doc to export
|
||||
@@ -329,19 +437,10 @@ async function importMarkdownToDoc({
|
||||
const { content, meta } = parseFrontmatter(markdown);
|
||||
const preferredTitle = meta.title ?? fileName;
|
||||
const provider = getProvider(extensions);
|
||||
const job = new Transformer({
|
||||
const job = createMarkdownImportJob({
|
||||
collection,
|
||||
schema,
|
||||
blobCRUD: collection.blobSync,
|
||||
docCRUD: {
|
||||
create: (id: string) => collection.createDoc(id).getStore({ id }),
|
||||
get: (id: string) => collection.getDoc(id)?.getStore({ id }) ?? null,
|
||||
delete: (id: string) => collection.removeDoc(id),
|
||||
},
|
||||
middlewares: [
|
||||
defaultImageProxyMiddleware,
|
||||
fileNameMiddleware(preferredTitle),
|
||||
docLinkBaseURLMiddleware(collection.id),
|
||||
],
|
||||
preferredTitle,
|
||||
});
|
||||
const mdAdapter = new MarkdownAdapter(job, provider);
|
||||
const page = await mdAdapter.toDoc({
|
||||
@@ -381,7 +480,7 @@ async function importMarkdownZip({
|
||||
// Iterate over all files in the zip
|
||||
for (const { path, content: blob } of unzip) {
|
||||
// Skip the files that are not markdown files
|
||||
if (path.includes('__MACOSX') || path.includes('.DS_Store')) {
|
||||
if (isSystemImportPath(path)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -395,12 +494,13 @@ async function importMarkdownZip({
|
||||
fullPath: path,
|
||||
});
|
||||
} else {
|
||||
// If the file is not a markdown file, store it to pendingAssets
|
||||
const ext = path.split('.').at(-1) ?? '';
|
||||
const mime = extMimeMap.get(ext) ?? '';
|
||||
const key = await sha(await blob.arrayBuffer());
|
||||
pendingPathBlobIdMap.set(path, key);
|
||||
pendingAssets.set(key, new File([blob], fileName, { type: mime }));
|
||||
await stageImportedAsset({
|
||||
pendingAssets,
|
||||
pendingPathBlobIdMap,
|
||||
path,
|
||||
content: blob,
|
||||
fileName,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,34 +511,13 @@ async function importMarkdownZip({
|
||||
const markdown = await contentBlob.text();
|
||||
const { content, meta } = parseFrontmatter(markdown);
|
||||
const preferredTitle = meta.title ?? fileNameWithoutExt;
|
||||
const job = new Transformer({
|
||||
const job = createMarkdownImportJob({
|
||||
collection,
|
||||
schema,
|
||||
blobCRUD: collection.blobSync,
|
||||
docCRUD: {
|
||||
create: (id: string) => collection.createDoc(id).getStore({ id }),
|
||||
get: (id: string) => collection.getDoc(id)?.getStore({ id }) ?? null,
|
||||
delete: (id: string) => collection.removeDoc(id),
|
||||
},
|
||||
middlewares: [
|
||||
defaultImageProxyMiddleware,
|
||||
fileNameMiddleware(preferredTitle),
|
||||
docLinkBaseURLMiddleware(collection.id),
|
||||
filePathMiddleware(fullPath),
|
||||
],
|
||||
preferredTitle,
|
||||
fullPath,
|
||||
});
|
||||
const assets = job.assets;
|
||||
const pathBlobIdMap = job.assetsManager.getPathBlobIdMap();
|
||||
// Iterate over all assets to be imported
|
||||
for (const [assetPath, key] of pendingPathBlobIdMap.entries()) {
|
||||
// Get the relative path of the asset to the markdown file
|
||||
// Store the path to blobId map
|
||||
pathBlobIdMap.set(assetPath, key);
|
||||
// Store the asset to assets, the key is the blobId, the value is the file object
|
||||
// In block adapter, it will use the blobId to get the file object
|
||||
if (pendingAssets.get(key)) {
|
||||
assets.set(key, pendingAssets.get(key)!);
|
||||
}
|
||||
}
|
||||
bindImportedAssetsToJob(job, pendingAssets, pendingPathBlobIdMap);
|
||||
|
||||
const mdAdapter = new MarkdownAdapter(job, provider);
|
||||
const doc = await mdAdapter.toDoc({
|
||||
|
||||
@@ -0,0 +1,732 @@
|
||||
import { FootNoteReferenceParamsSchema } from '@blocksuite/affine-model';
|
||||
import {
|
||||
BlockMarkdownAdapterExtension,
|
||||
createAttachmentBlockSnapshot,
|
||||
FULL_FILE_PATH_KEY,
|
||||
getImageFullPath,
|
||||
MarkdownAdapter,
|
||||
type MarkdownAST,
|
||||
MarkdownASTToDeltaExtension,
|
||||
normalizeFilePathReference,
|
||||
} from '@blocksuite/affine-shared/adapters';
|
||||
import type { AffineTextAttributes } from '@blocksuite/affine-shared/types';
|
||||
import type {
|
||||
DeltaInsert,
|
||||
ExtensionType,
|
||||
Schema,
|
||||
Workspace,
|
||||
} from '@blocksuite/store';
|
||||
import { extMimeMap, nanoid } from '@blocksuite/store';
|
||||
import type { Html, Text } from 'mdast';
|
||||
|
||||
import {
|
||||
applyMetaPatch,
|
||||
bindImportedAssetsToJob,
|
||||
createMarkdownImportJob,
|
||||
getProvider,
|
||||
isSystemImportPath,
|
||||
parseFrontmatter,
|
||||
stageImportedAsset,
|
||||
} from './markdown.js';
|
||||
import type {
|
||||
AssetMap,
|
||||
MarkdownFileImportEntry,
|
||||
PathBlobIdMap,
|
||||
} from './type.js';
|
||||
|
||||
const CALLOUT_TYPE_MAP: Record<string, string> = {
|
||||
note: '💡',
|
||||
info: 'ℹ️',
|
||||
tip: '🔥',
|
||||
hint: '✅',
|
||||
important: '‼️',
|
||||
warning: '⚠️',
|
||||
caution: '⚠️',
|
||||
attention: '⚠️',
|
||||
danger: '⚠️',
|
||||
error: '🚨',
|
||||
bug: '🐛',
|
||||
example: '📌',
|
||||
quote: '💬',
|
||||
cite: '💬',
|
||||
abstract: '📋',
|
||||
summary: '📋',
|
||||
todo: '☑️',
|
||||
success: '✅',
|
||||
check: '✅',
|
||||
done: '✅',
|
||||
failure: '❌',
|
||||
fail: '❌',
|
||||
missing: '❌',
|
||||
question: '❓',
|
||||
help: '❓',
|
||||
faq: '❓',
|
||||
};
|
||||
|
||||
const AMBIGUOUS_PAGE_LOOKUP = '__ambiguous__';
|
||||
const DEFAULT_CALLOUT_EMOJI = '💡';
|
||||
const OBSIDIAN_TEXT_FOOTNOTE_URL_PREFIX = 'data:text/plain;charset=utf-8,';
|
||||
const OBSIDIAN_ATTACHMENT_EMBED_TAG = 'obsidian-attachment';
|
||||
|
||||
function normalizeLookupKey(value: string): string {
|
||||
return normalizeFilePathReference(value).toLowerCase();
|
||||
}
|
||||
|
||||
function stripMarkdownExtension(value: string): string {
|
||||
return value.replace(/\.md$/i, '');
|
||||
}
|
||||
|
||||
function basename(value: string): string {
|
||||
return normalizeFilePathReference(value).split('/').pop() ?? value;
|
||||
}
|
||||
|
||||
function parseObsidianTarget(rawTarget: string): {
|
||||
path: string;
|
||||
fragment: string | null;
|
||||
} {
|
||||
const normalizedTarget = normalizeFilePathReference(rawTarget);
|
||||
const match = normalizedTarget.match(/^([^#^]+)([#^].*)?$/);
|
||||
|
||||
return {
|
||||
path: match?.[1]?.trim() ?? normalizedTarget,
|
||||
fragment: match?.[2] ?? null,
|
||||
};
|
||||
}
|
||||
|
||||
function extractTitleAndEmoji(rawTitle: string): {
|
||||
title: string;
|
||||
emoji: string | null;
|
||||
} {
|
||||
const SINGLE_LEADING_EMOJI_RE =
|
||||
/^[\s\u200b]*((?:[\p{Emoji_Presentation}\p{Extended_Pictographic}\u200b]|\u200d|\ufe0f)+)/u;
|
||||
|
||||
let currentTitle = rawTitle;
|
||||
let extractedEmojiClusters = '';
|
||||
let emojiMatch;
|
||||
|
||||
while ((emojiMatch = currentTitle.match(SINGLE_LEADING_EMOJI_RE))) {
|
||||
const matchedCluster = emojiMatch[1].trim();
|
||||
extractedEmojiClusters +=
|
||||
(extractedEmojiClusters ? ' ' : '') + matchedCluster;
|
||||
currentTitle = currentTitle.slice(emojiMatch[0].length);
|
||||
}
|
||||
|
||||
return {
|
||||
title: currentTitle.trim(),
|
||||
emoji: extractedEmojiClusters || null,
|
||||
};
|
||||
}
|
||||
|
||||
function preprocessTitleHeader(markdown: string): string {
|
||||
return markdown.replace(
|
||||
/^(\s*#\s+)(.*)$/m,
|
||||
(_, headerPrefix, titleContent) => {
|
||||
const { title: cleanTitle } = extractTitleAndEmoji(titleContent);
|
||||
return `${headerPrefix}${cleanTitle}`;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function preprocessObsidianCallouts(markdown: string): string {
|
||||
return markdown.replace(
|
||||
/^(> *)\[!([^\]\n]+)\]([+-]?)([^\n]*)/gm,
|
||||
(_, prefix, type, _fold, rest) => {
|
||||
const calloutToken =
|
||||
CALLOUT_TYPE_MAP[type.trim().toLowerCase()] ?? DEFAULT_CALLOUT_EMOJI;
|
||||
const title = rest.trim();
|
||||
return title
|
||||
? `${prefix}[!${calloutToken}] ${title}`
|
||||
: `${prefix}[!${calloutToken}]`;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function isStructuredFootnoteDefinition(content: string): boolean {
|
||||
try {
|
||||
return FootNoteReferenceParamsSchema.safeParse(JSON.parse(content.trim()))
|
||||
.success;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function splitFootnoteTextContent(content: string): {
|
||||
title: string;
|
||||
description?: string;
|
||||
} {
|
||||
const lines = content
|
||||
.split('\n')
|
||||
.map(line => line.trim())
|
||||
.filter(Boolean);
|
||||
const title = lines[0] ?? content.trim();
|
||||
const description = lines.slice(1).join('\n').trim();
|
||||
|
||||
return {
|
||||
title,
|
||||
...(description ? { description } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
function createTextFootnoteDefinition(content: string): string {
|
||||
const normalizedContent = content.trim();
|
||||
const { title, description } = splitFootnoteTextContent(normalizedContent);
|
||||
|
||||
return JSON.stringify({
|
||||
type: 'url',
|
||||
url: encodeURIComponent(
|
||||
`${OBSIDIAN_TEXT_FOOTNOTE_URL_PREFIX}${encodeURIComponent(
|
||||
normalizedContent
|
||||
)}`
|
||||
),
|
||||
title,
|
||||
...(description ? { description } : {}),
|
||||
});
|
||||
}
|
||||
|
||||
function extractObsidianFootnotes(markdown: string): {
|
||||
content: string;
|
||||
footnotes: string[];
|
||||
} {
|
||||
const lines = markdown.split('\n');
|
||||
const output: string[] = [];
|
||||
const footnotes: string[] = [];
|
||||
|
||||
for (let index = 0; index < lines.length; index += 1) {
|
||||
const line = lines[index];
|
||||
const match = line.match(/^\[\^([^\]]+)\]:\s*(.*)$/);
|
||||
if (!match) {
|
||||
output.push(line);
|
||||
continue;
|
||||
}
|
||||
|
||||
const identifier = match[1];
|
||||
const contentLines = [match[2]];
|
||||
|
||||
while (index + 1 < lines.length) {
|
||||
const nextLine = lines[index + 1];
|
||||
if (/^(?: {1,4}|\t)/.test(nextLine)) {
|
||||
contentLines.push(nextLine.replace(/^(?: {1,4}|\t)/, ''));
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (
|
||||
nextLine.trim() === '' &&
|
||||
index + 2 < lines.length &&
|
||||
/^(?: {1,4}|\t)/.test(lines[index + 2])
|
||||
) {
|
||||
contentLines.push('');
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
const content = contentLines.join('\n').trim();
|
||||
footnotes.push(
|
||||
`[^${identifier}]: ${
|
||||
!content || isStructuredFootnoteDefinition(content)
|
||||
? content
|
||||
: createTextFootnoteDefinition(content)
|
||||
}`
|
||||
);
|
||||
}
|
||||
|
||||
return { content: output.join('\n'), footnotes };
|
||||
}
|
||||
|
||||
function buildLookupKeys(
|
||||
targetPath: string,
|
||||
currentFilePath?: string
|
||||
): string[] {
|
||||
const parsedTargetPath = normalizeFilePathReference(targetPath);
|
||||
if (!parsedTargetPath) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const keys = new Set<string>();
|
||||
const addPathVariants = (value: string) => {
|
||||
const normalizedValue = normalizeFilePathReference(value);
|
||||
if (!normalizedValue) {
|
||||
return;
|
||||
}
|
||||
|
||||
keys.add(normalizedValue);
|
||||
keys.add(stripMarkdownExtension(normalizedValue));
|
||||
|
||||
const fileName = basename(normalizedValue);
|
||||
keys.add(fileName);
|
||||
keys.add(stripMarkdownExtension(fileName));
|
||||
|
||||
const cleanTitle = extractTitleAndEmoji(
|
||||
stripMarkdownExtension(fileName)
|
||||
).title;
|
||||
if (cleanTitle) {
|
||||
keys.add(cleanTitle);
|
||||
}
|
||||
};
|
||||
|
||||
addPathVariants(parsedTargetPath);
|
||||
|
||||
if (currentFilePath) {
|
||||
addPathVariants(getImageFullPath(currentFilePath, parsedTargetPath));
|
||||
}
|
||||
|
||||
return Array.from(keys).map(normalizeLookupKey);
|
||||
}
|
||||
|
||||
function registerPageLookup(
|
||||
pageLookupMap: Map<string, string>,
|
||||
key: string,
|
||||
pageId: string
|
||||
) {
|
||||
const normalizedKey = normalizeLookupKey(key);
|
||||
if (!normalizedKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const existing = pageLookupMap.get(normalizedKey);
|
||||
if (existing && existing !== pageId) {
|
||||
pageLookupMap.set(normalizedKey, AMBIGUOUS_PAGE_LOOKUP);
|
||||
return;
|
||||
}
|
||||
|
||||
pageLookupMap.set(normalizedKey, pageId);
|
||||
}
|
||||
|
||||
function resolvePageIdFromLookup(
|
||||
pageLookupMap: Pick<ReadonlyMap<string, string>, 'get'>,
|
||||
rawTarget: string,
|
||||
currentFilePath?: string
|
||||
): string | null {
|
||||
const { path } = parseObsidianTarget(rawTarget);
|
||||
for (const key of buildLookupKeys(path, currentFilePath)) {
|
||||
const targetPageId = pageLookupMap.get(key);
|
||||
if (!targetPageId || targetPageId === AMBIGUOUS_PAGE_LOOKUP) {
|
||||
continue;
|
||||
}
|
||||
return targetPageId;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function resolveWikilinkDisplayTitle(
|
||||
rawAlias: string | undefined,
|
||||
pageEmoji: string | undefined
|
||||
): string | undefined {
|
||||
if (!rawAlias) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const { title: aliasTitle, emoji: aliasEmoji } =
|
||||
extractTitleAndEmoji(rawAlias);
|
||||
|
||||
if (aliasEmoji && aliasEmoji === pageEmoji) {
|
||||
return aliasTitle;
|
||||
}
|
||||
|
||||
return rawAlias;
|
||||
}
|
||||
|
||||
function isImageAssetPath(path: string): boolean {
|
||||
const extension = path.split('.').at(-1)?.toLowerCase() ?? '';
|
||||
return extMimeMap.get(extension)?.startsWith('image/') ?? false;
|
||||
}
|
||||
|
||||
function encodeMarkdownPath(path: string): string {
|
||||
return encodeURI(path).replaceAll('(', '%28').replaceAll(')', '%29');
|
||||
}
|
||||
|
||||
function escapeMarkdownLabel(label: string): string {
|
||||
return label.replace(/[[\]\\]/g, '\\$&');
|
||||
}
|
||||
|
||||
function isObsidianSizeAlias(alias: string | undefined): boolean {
|
||||
return !!alias && /^\d+(?:x\d+)?$/i.test(alias.trim());
|
||||
}
|
||||
|
||||
function getEmbedLabel(
|
||||
rawAlias: string | undefined,
|
||||
targetPath: string,
|
||||
fallbackToFileName: boolean
|
||||
): string {
|
||||
if (!rawAlias || isObsidianSizeAlias(rawAlias)) {
|
||||
return fallbackToFileName
|
||||
? stripMarkdownExtension(basename(targetPath))
|
||||
: '';
|
||||
}
|
||||
|
||||
return rawAlias.trim();
|
||||
}
|
||||
|
||||
type ObsidianAttachmentEmbed = {
|
||||
blobId: string;
|
||||
fileName: string;
|
||||
fileType: string;
|
||||
};
|
||||
|
||||
function createObsidianAttach(embed: ObsidianAttachmentEmbed): string {
|
||||
return `<!-- ${OBSIDIAN_ATTACHMENT_EMBED_TAG} ${encodeURIComponent(
|
||||
JSON.stringify(embed)
|
||||
)} -->`;
|
||||
}
|
||||
|
||||
function parseObsidianAttach(value: string): ObsidianAttachmentEmbed | null {
|
||||
const match = value.match(
|
||||
new RegExp(`^<!-- ${OBSIDIAN_ATTACHMENT_EMBED_TAG} ([^ ]+) -->$`)
|
||||
);
|
||||
if (!match?.[1]) return null;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(
|
||||
decodeURIComponent(match[1])
|
||||
) as ObsidianAttachmentEmbed;
|
||||
if (!parsed.blobId || !parsed.fileName) {
|
||||
return null;
|
||||
}
|
||||
return parsed;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function preprocessObsidianEmbeds(
|
||||
markdown: string,
|
||||
filePath: string,
|
||||
pageLookupMap: ReadonlyMap<string, string>,
|
||||
pathBlobIdMap: ReadonlyMap<string, string>
|
||||
): string {
|
||||
return markdown.replace(
|
||||
/!\[\[([^\]|]+)(?:\|([^\]]+))?\]\]/g,
|
||||
(match, rawTarget: string, rawAlias?: string) => {
|
||||
const targetPageId = resolvePageIdFromLookup(
|
||||
pageLookupMap,
|
||||
rawTarget,
|
||||
filePath
|
||||
);
|
||||
if (targetPageId) {
|
||||
return `[[${rawTarget}${rawAlias ? `|${rawAlias}` : ''}]]`;
|
||||
}
|
||||
|
||||
const { path } = parseObsidianTarget(rawTarget);
|
||||
if (!path) {
|
||||
return match;
|
||||
}
|
||||
|
||||
const assetPath = getImageFullPath(filePath, path);
|
||||
const encodedPath = encodeMarkdownPath(assetPath);
|
||||
|
||||
if (isImageAssetPath(path)) {
|
||||
const alt = getEmbedLabel(rawAlias, path, false);
|
||||
return ``;
|
||||
}
|
||||
|
||||
const label = getEmbedLabel(rawAlias, path, true);
|
||||
const blobId = pathBlobIdMap.get(assetPath);
|
||||
if (!blobId) return `[${escapeMarkdownLabel(label)}](${encodedPath})`;
|
||||
|
||||
const extension = path.split('.').at(-1)?.toLowerCase() ?? '';
|
||||
return createObsidianAttach({
|
||||
blobId,
|
||||
fileName: basename(path),
|
||||
fileType: extMimeMap.get(extension) ?? '',
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function preprocessObsidianMarkdown(
|
||||
markdown: string,
|
||||
filePath: string,
|
||||
pageLookupMap: ReadonlyMap<string, string>,
|
||||
pathBlobIdMap: ReadonlyMap<string, string>
|
||||
): string {
|
||||
const { content: contentWithoutFootnotes, footnotes: extractedFootnotes } =
|
||||
extractObsidianFootnotes(markdown);
|
||||
const content = preprocessObsidianEmbeds(
|
||||
contentWithoutFootnotes,
|
||||
filePath,
|
||||
pageLookupMap,
|
||||
pathBlobIdMap
|
||||
);
|
||||
const normalizedMarkdown = preprocessTitleHeader(
|
||||
preprocessObsidianCallouts(content)
|
||||
);
|
||||
|
||||
if (extractedFootnotes.length === 0) {
|
||||
return normalizedMarkdown;
|
||||
}
|
||||
|
||||
const trimmedMarkdown = normalizedMarkdown.replace(/\s+$/, '');
|
||||
return `${trimmedMarkdown}\n\n${extractedFootnotes.join('\n\n')}\n`;
|
||||
}
|
||||
|
||||
function isObsidianAttachmentEmbedNode(node: MarkdownAST): node is Html {
|
||||
return node.type === 'html' && !!parseObsidianAttach(node.value);
|
||||
}
|
||||
|
||||
export const obsidianAttachmentEmbedMarkdownAdapterMatcher =
|
||||
BlockMarkdownAdapterExtension({
|
||||
flavour: 'obsidian:attachment-embed',
|
||||
toMatch: o => isObsidianAttachmentEmbedNode(o.node),
|
||||
fromMatch: () => false,
|
||||
toBlockSnapshot: {
|
||||
enter: (o, context) => {
|
||||
if (!isObsidianAttachmentEmbedNode(o.node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const attachment = parseObsidianAttach(o.node.value);
|
||||
if (!attachment) {
|
||||
return;
|
||||
}
|
||||
|
||||
const assetFile = context.assets?.getAssets().get(attachment.blobId);
|
||||
context.walkerContext
|
||||
.openNode(
|
||||
createAttachmentBlockSnapshot({
|
||||
id: nanoid(),
|
||||
props: {
|
||||
name: attachment.fileName,
|
||||
size: assetFile?.size ?? 0,
|
||||
type:
|
||||
attachment.fileType ||
|
||||
assetFile?.type ||
|
||||
'application/octet-stream',
|
||||
sourceId: attachment.blobId,
|
||||
embed: false,
|
||||
style: 'horizontalThin',
|
||||
footnoteIdentifier: null,
|
||||
},
|
||||
}),
|
||||
'children'
|
||||
)
|
||||
.closeNode();
|
||||
(o.node as unknown as { type: string }).type =
|
||||
'obsidianAttachmentEmbed';
|
||||
},
|
||||
},
|
||||
fromBlockSnapshot: {},
|
||||
});
|
||||
|
||||
export const obsidianWikilinkToDeltaMatcher = MarkdownASTToDeltaExtension({
|
||||
name: 'obsidian-wikilink',
|
||||
match: ast => ast.type === 'text',
|
||||
toDelta: (ast, context) => {
|
||||
const textNode = ast as Text;
|
||||
if (!textNode.value) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const nodeContent = textNode.value;
|
||||
const wikilinkRegex = /\[\[([^\]|]+)(?:\|([^\]]+))?\]\]/g;
|
||||
const deltas: DeltaInsert<AffineTextAttributes>[] = [];
|
||||
|
||||
let lastProcessedIndex = 0;
|
||||
let linkMatch;
|
||||
|
||||
while ((linkMatch = wikilinkRegex.exec(nodeContent)) !== null) {
|
||||
if (linkMatch.index > lastProcessedIndex) {
|
||||
deltas.push({
|
||||
insert: nodeContent.substring(lastProcessedIndex, linkMatch.index),
|
||||
});
|
||||
}
|
||||
|
||||
const targetPageName = linkMatch[1].trim();
|
||||
const alias = linkMatch[2]?.trim();
|
||||
const currentFilePath = context.configs.get(FULL_FILE_PATH_KEY);
|
||||
const targetPageId = resolvePageIdFromLookup(
|
||||
{ get: key => context.configs.get(`obsidian:pageId:${key}`) },
|
||||
targetPageName,
|
||||
typeof currentFilePath === 'string' ? currentFilePath : undefined
|
||||
);
|
||||
|
||||
if (targetPageId) {
|
||||
const pageEmoji = context.configs.get(
|
||||
'obsidian:pageEmoji:' + targetPageId
|
||||
);
|
||||
const displayTitle = resolveWikilinkDisplayTitle(alias, pageEmoji);
|
||||
|
||||
deltas.push({
|
||||
insert: ' ',
|
||||
attributes: {
|
||||
reference: {
|
||||
type: 'LinkedPage',
|
||||
pageId: targetPageId,
|
||||
...(displayTitle ? { title: displayTitle } : {}),
|
||||
},
|
||||
},
|
||||
});
|
||||
} else {
|
||||
deltas.push({ insert: linkMatch[0] });
|
||||
}
|
||||
|
||||
lastProcessedIndex = wikilinkRegex.lastIndex;
|
||||
}
|
||||
|
||||
if (lastProcessedIndex < nodeContent.length) {
|
||||
deltas.push({ insert: nodeContent.substring(lastProcessedIndex) });
|
||||
}
|
||||
|
||||
return deltas;
|
||||
},
|
||||
});
|
||||
|
||||
export type ImportObsidianVaultOptions = {
|
||||
collection: Workspace;
|
||||
schema: Schema;
|
||||
importedFiles: File[];
|
||||
extensions: ExtensionType[];
|
||||
};
|
||||
|
||||
export type ImportObsidianVaultResult = {
|
||||
docIds: string[];
|
||||
docEmojis: Map<string, string>;
|
||||
};
|
||||
|
||||
export async function importObsidianVault({
|
||||
collection,
|
||||
schema,
|
||||
importedFiles,
|
||||
extensions,
|
||||
}: ImportObsidianVaultOptions): Promise<ImportObsidianVaultResult> {
|
||||
const provider = getProvider([
|
||||
obsidianWikilinkToDeltaMatcher,
|
||||
obsidianAttachmentEmbedMarkdownAdapterMatcher,
|
||||
...extensions,
|
||||
]);
|
||||
|
||||
const docIds: string[] = [];
|
||||
const docEmojis = new Map<string, string>();
|
||||
const pendingAssets: AssetMap = new Map();
|
||||
const pendingPathBlobIdMap: PathBlobIdMap = new Map();
|
||||
const markdownBlobs: MarkdownFileImportEntry[] = [];
|
||||
const pageLookupMap = new Map<string, string>();
|
||||
|
||||
for (const file of importedFiles) {
|
||||
const filePath = file.webkitRelativePath || file.name;
|
||||
if (isSystemImportPath(filePath)) continue;
|
||||
|
||||
if (file.name.endsWith('.md')) {
|
||||
const fileNameWithoutExt = file.name.replace(/\.[^/.]+$/, '');
|
||||
const markdown = await file.text();
|
||||
const { content, meta } = parseFrontmatter(markdown);
|
||||
|
||||
const documentTitleCandidate = meta.title ?? fileNameWithoutExt;
|
||||
const { title: preferredTitle, emoji: leadingEmoji } =
|
||||
extractTitleAndEmoji(documentTitleCandidate);
|
||||
|
||||
const newPageId = collection.idGenerator();
|
||||
registerPageLookup(pageLookupMap, filePath, newPageId);
|
||||
registerPageLookup(
|
||||
pageLookupMap,
|
||||
stripMarkdownExtension(filePath),
|
||||
newPageId
|
||||
);
|
||||
registerPageLookup(pageLookupMap, file.name, newPageId);
|
||||
registerPageLookup(pageLookupMap, fileNameWithoutExt, newPageId);
|
||||
registerPageLookup(pageLookupMap, documentTitleCandidate, newPageId);
|
||||
registerPageLookup(pageLookupMap, preferredTitle, newPageId);
|
||||
|
||||
if (leadingEmoji) {
|
||||
docEmojis.set(newPageId, leadingEmoji);
|
||||
}
|
||||
|
||||
markdownBlobs.push({
|
||||
filename: file.name,
|
||||
contentBlob: file,
|
||||
fullPath: filePath,
|
||||
pageId: newPageId,
|
||||
preferredTitle,
|
||||
content,
|
||||
meta,
|
||||
});
|
||||
} else {
|
||||
await stageImportedAsset({
|
||||
pendingAssets,
|
||||
pendingPathBlobIdMap,
|
||||
path: filePath,
|
||||
content: file,
|
||||
fileName: file.name,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for (const existingDocMeta of collection.meta.docMetas) {
|
||||
if (existingDocMeta.title) {
|
||||
registerPageLookup(
|
||||
pageLookupMap,
|
||||
existingDocMeta.title,
|
||||
existingDocMeta.id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
await Promise.all(
|
||||
markdownBlobs.map(async markdownFile => {
|
||||
const {
|
||||
fullPath,
|
||||
pageId: predefinedId,
|
||||
preferredTitle,
|
||||
content,
|
||||
meta,
|
||||
} = markdownFile;
|
||||
|
||||
const job = createMarkdownImportJob({
|
||||
collection,
|
||||
schema,
|
||||
preferredTitle,
|
||||
fullPath,
|
||||
});
|
||||
|
||||
for (const [lookupKey, id] of pageLookupMap.entries()) {
|
||||
if (id === AMBIGUOUS_PAGE_LOOKUP) {
|
||||
continue;
|
||||
}
|
||||
job.adapterConfigs.set(`obsidian:pageId:${lookupKey}`, id);
|
||||
}
|
||||
for (const [id, emoji] of docEmojis.entries()) {
|
||||
job.adapterConfigs.set('obsidian:pageEmoji:' + id, emoji);
|
||||
}
|
||||
|
||||
const pathBlobIdMap = bindImportedAssetsToJob(
|
||||
job,
|
||||
pendingAssets,
|
||||
pendingPathBlobIdMap
|
||||
);
|
||||
|
||||
const preprocessedMarkdown = preprocessObsidianMarkdown(
|
||||
content,
|
||||
fullPath,
|
||||
pageLookupMap,
|
||||
pathBlobIdMap
|
||||
);
|
||||
const mdAdapter = new MarkdownAdapter(job, provider);
|
||||
const snapshot = await mdAdapter.toDocSnapshot({
|
||||
file: preprocessedMarkdown,
|
||||
assets: job.assetsManager,
|
||||
});
|
||||
|
||||
if (snapshot) {
|
||||
snapshot.meta.id = predefinedId;
|
||||
const doc = await job.snapshotToDoc(snapshot);
|
||||
if (doc) {
|
||||
applyMetaPatch(collection, doc.id, {
|
||||
...meta,
|
||||
title: preferredTitle,
|
||||
trash: false,
|
||||
});
|
||||
docIds.push(doc.id);
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
return { docIds, docEmojis };
|
||||
}
|
||||
|
||||
export const ObsidianTransformer = {
|
||||
importObsidianVault,
|
||||
};
|
||||
@@ -1,3 +1,5 @@
|
||||
import type { ParsedFrontmatterMeta } from './markdown.js';
|
||||
|
||||
/**
|
||||
* Represents an imported file entry in the zip archive
|
||||
*/
|
||||
@@ -10,6 +12,13 @@ export type ImportedFileEntry = {
|
||||
fullPath: string;
|
||||
};
|
||||
|
||||
export type MarkdownFileImportEntry = ImportedFileEntry & {
|
||||
pageId: string;
|
||||
preferredTitle: string;
|
||||
content: string;
|
||||
meta: ParsedFrontmatterMeta;
|
||||
};
|
||||
|
||||
/**
|
||||
* Map of asset hash to File object for all media files in the zip
|
||||
* Key: SHA hash of the file content (blobId)
|
||||
|
||||
@@ -162,10 +162,11 @@ export class AffineToolbarWidget extends WidgetComponent {
|
||||
}
|
||||
|
||||
setReferenceElementWithElements(gfx: GfxController, elements: GfxModel[]) {
|
||||
const surfaceBounds = getCommonBoundWithRotation(elements);
|
||||
|
||||
const getBoundingClientRect = () => {
|
||||
const bounds = getCommonBoundWithRotation(elements);
|
||||
const { x: offsetX, y: offsetY } = this.getBoundingClientRect();
|
||||
const [x, y, w, h] = gfx.viewport.toViewBound(bounds).toXYWH();
|
||||
const [x, y, w, h] = gfx.viewport.toViewBound(surfaceBounds).toXYWH();
|
||||
const rect = new DOMRect(x + offsetX, y + offsetY, w, h);
|
||||
return rect;
|
||||
};
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@vitest/browser-playwright": "^4.0.18",
|
||||
"playwright": "=1.58.2",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"exports": {
|
||||
|
||||
@@ -103,8 +103,9 @@ export abstract class GfxPrimitiveElementModel<
|
||||
}
|
||||
|
||||
get deserializedXYWH() {
|
||||
if (!this._lastXYWH || this.xywh !== this._lastXYWH) {
|
||||
const xywh = this.xywh;
|
||||
const xywh = this.xywh;
|
||||
|
||||
if (!this._lastXYWH || xywh !== this._lastXYWH) {
|
||||
this._local.set('deserializedXYWH', deserializeXYWH(xywh));
|
||||
this._lastXYWH = xywh;
|
||||
}
|
||||
@@ -386,6 +387,8 @@ export abstract class GfxGroupLikeElementModel<
|
||||
{
|
||||
private _childIds: string[] = [];
|
||||
|
||||
private _xywhDirty = true;
|
||||
|
||||
private readonly _mutex = createMutex();
|
||||
|
||||
abstract children: Y.Map<any>;
|
||||
@@ -420,24 +423,9 @@ export abstract class GfxGroupLikeElementModel<
|
||||
|
||||
get xywh() {
|
||||
this._mutex(() => {
|
||||
const curXYWH =
|
||||
(this._local.get('xywh') as SerializedXYWH) ?? '[0,0,0,0]';
|
||||
const newXYWH = this._getXYWH().serialize();
|
||||
|
||||
if (curXYWH !== newXYWH || !this._local.has('xywh')) {
|
||||
this._local.set('xywh', newXYWH);
|
||||
|
||||
if (curXYWH !== newXYWH) {
|
||||
this._onChange({
|
||||
props: {
|
||||
xywh: newXYWH,
|
||||
},
|
||||
oldValues: {
|
||||
xywh: curXYWH,
|
||||
},
|
||||
local: true,
|
||||
});
|
||||
}
|
||||
if (this._xywhDirty || !this._local.has('xywh')) {
|
||||
this._local.set('xywh', this._getXYWH().serialize());
|
||||
this._xywhDirty = false;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -457,15 +445,41 @@ export abstract class GfxGroupLikeElementModel<
|
||||
bound = bound ? bound.unite(child.elementBound) : child.elementBound;
|
||||
});
|
||||
|
||||
if (bound) {
|
||||
this._local.set('xywh', bound.serialize());
|
||||
} else {
|
||||
this._local.delete('xywh');
|
||||
}
|
||||
|
||||
return bound ?? new Bound(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
invalidateXYWH() {
|
||||
this._xywhDirty = true;
|
||||
this._local.delete('deserializedXYWH');
|
||||
}
|
||||
|
||||
refreshXYWH(local: boolean) {
|
||||
this._mutex(() => {
|
||||
const oldXYWH =
|
||||
(this._local.get('xywh') as SerializedXYWH) ?? '[0,0,0,0]';
|
||||
const nextXYWH = this._getXYWH().serialize();
|
||||
|
||||
this._xywhDirty = false;
|
||||
|
||||
if (oldXYWH === nextXYWH && this._local.has('xywh')) {
|
||||
return;
|
||||
}
|
||||
|
||||
this._local.set('xywh', nextXYWH);
|
||||
this._local.delete('deserializedXYWH');
|
||||
|
||||
this._onChange({
|
||||
props: {
|
||||
xywh: nextXYWH,
|
||||
},
|
||||
oldValues: {
|
||||
xywh: oldXYWH,
|
||||
},
|
||||
local,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
abstract addChild(element: GfxModel): void;
|
||||
|
||||
/**
|
||||
@@ -496,6 +510,7 @@ export abstract class GfxGroupLikeElementModel<
|
||||
setChildIds(value: string[], fromLocal: boolean) {
|
||||
const oldChildIds = this.childIds;
|
||||
this._childIds = value;
|
||||
this.invalidateXYWH();
|
||||
|
||||
this._onChange({
|
||||
props: {
|
||||
|
||||
@@ -52,6 +52,12 @@ export type MiddlewareCtx = {
|
||||
export type SurfaceMiddleware = (ctx: MiddlewareCtx) => void;
|
||||
|
||||
export class SurfaceBlockModel extends BlockModel<SurfaceBlockProps> {
|
||||
private static readonly _groupBoundImpactKeys = new Set([
|
||||
'xywh',
|
||||
'rotate',
|
||||
'hidden',
|
||||
]);
|
||||
|
||||
protected _decoratorState = createDecoratorState();
|
||||
|
||||
protected _elementCtorMap: Record<
|
||||
@@ -308,6 +314,42 @@ export class SurfaceBlockModel extends BlockModel<SurfaceBlockProps> {
|
||||
Object.keys(payload.props).forEach(key => {
|
||||
model.propsUpdated.next({ key });
|
||||
});
|
||||
|
||||
this._refreshParentGroupBoundsForElement(model, payload);
|
||||
}
|
||||
|
||||
private _refreshParentGroupBounds(id: string, local: boolean) {
|
||||
const group = this.getGroup(id);
|
||||
|
||||
if (group instanceof GfxGroupLikeElementModel) {
|
||||
group.refreshXYWH(local);
|
||||
}
|
||||
}
|
||||
|
||||
private _refreshParentGroupBoundsForElement(
|
||||
model: GfxPrimitiveElementModel,
|
||||
payload: ElementUpdatedData
|
||||
) {
|
||||
if (
|
||||
model instanceof GfxGroupLikeElementModel &&
|
||||
('childIds' in payload.props || 'childIds' in payload.oldValues)
|
||||
) {
|
||||
model.refreshXYWH(payload.local);
|
||||
return;
|
||||
}
|
||||
|
||||
const affectedKeys = new Set([
|
||||
...Object.keys(payload.props),
|
||||
...Object.keys(payload.oldValues),
|
||||
]);
|
||||
|
||||
if (
|
||||
Array.from(affectedKeys).some(key =>
|
||||
SurfaceBlockModel._groupBoundImpactKeys.has(key)
|
||||
)
|
||||
) {
|
||||
this._refreshParentGroupBounds(model.id, payload.local);
|
||||
}
|
||||
}
|
||||
|
||||
private _initElementModels() {
|
||||
@@ -458,6 +500,10 @@ export class SurfaceBlockModel extends BlockModel<SurfaceBlockProps> {
|
||||
);
|
||||
}
|
||||
|
||||
if (payload.model instanceof BlockModel) {
|
||||
this._refreshParentGroupBounds(payload.id, payload.isLocal);
|
||||
}
|
||||
|
||||
break;
|
||||
case 'delete':
|
||||
if (isGfxGroupCompatibleModel(payload.model)) {
|
||||
@@ -482,6 +528,13 @@ export class SurfaceBlockModel extends BlockModel<SurfaceBlockProps> {
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
payload.props.key &&
|
||||
SurfaceBlockModel._groupBoundImpactKeys.has(payload.props.key)
|
||||
) {
|
||||
this._refreshParentGroupBounds(payload.id, payload.isLocal);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
"devDependencies": {
|
||||
"@vanilla-extract/vite-plugin": "^5.0.0",
|
||||
"@vitest/browser-playwright": "^4.0.18",
|
||||
"playwright": "=1.58.2",
|
||||
"vite": "^7.2.7",
|
||||
"vite-plugin-istanbul": "^7.2.1",
|
||||
"vite-plugin-wasm": "^3.5.0",
|
||||
|
||||
@@ -4,6 +4,7 @@ import type {
|
||||
ConnectorElementModel,
|
||||
GroupElementModel,
|
||||
} from '@blocksuite/affine/model';
|
||||
import { serializeXYWH } from '@blocksuite/global/gfx';
|
||||
import { beforeEach, describe, expect, test } from 'vitest';
|
||||
|
||||
import { wait } from '../utils/common.js';
|
||||
@@ -138,6 +139,29 @@ describe('group', () => {
|
||||
|
||||
expect(group.childIds).toEqual([id]);
|
||||
});
|
||||
|
||||
test('group xywh should update when child xywh changes', () => {
|
||||
const shapeId = model.addElement({
|
||||
type: 'shape',
|
||||
xywh: serializeXYWH(0, 0, 100, 100),
|
||||
});
|
||||
const groupId = model.addElement({
|
||||
type: 'group',
|
||||
children: {
|
||||
[shapeId]: true,
|
||||
},
|
||||
});
|
||||
|
||||
const group = model.getElementById(groupId) as GroupElementModel;
|
||||
|
||||
expect(group.xywh).toBe(serializeXYWH(0, 0, 100, 100));
|
||||
|
||||
model.updateElement(shapeId, {
|
||||
xywh: serializeXYWH(50, 60, 100, 100),
|
||||
});
|
||||
|
||||
expect(group.xywh).toBe(serializeXYWH(50, 60, 100, 100));
|
||||
});
|
||||
});
|
||||
|
||||
describe('connector', () => {
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 25 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 25 KiB |
48
deny.toml
Normal file
48
deny.toml
Normal file
@@ -0,0 +1,48 @@
|
||||
[graph]
|
||||
all-features = true
|
||||
exclude-dev = true
|
||||
targets = [
|
||||
"x86_64-unknown-linux-gnu",
|
||||
"aarch64-apple-darwin",
|
||||
"x86_64-apple-darwin",
|
||||
"x86_64-pc-windows-msvc",
|
||||
"aarch64-linux-android",
|
||||
"aarch64-apple-ios",
|
||||
"aarch64-apple-ios-sim",
|
||||
]
|
||||
|
||||
[licenses]
|
||||
allow = [
|
||||
"0BSD",
|
||||
"Apache-2.0",
|
||||
"Apache-2.0 WITH LLVM-exception",
|
||||
"BSD-2-Clause",
|
||||
"BSD-3-Clause",
|
||||
"BSL-1.0",
|
||||
"CC0-1.0",
|
||||
"CDLA-Permissive-2.0",
|
||||
"ISC",
|
||||
"MIT",
|
||||
"MPL-2.0",
|
||||
"Unicode-3.0",
|
||||
"Unlicense",
|
||||
"Zlib",
|
||||
]
|
||||
confidence-threshold = 0.93
|
||||
unused-allowed-license = "allow"
|
||||
version = 2
|
||||
|
||||
[[licenses.exceptions]]
|
||||
allow = ["AGPL-3.0-only"]
|
||||
crate = "llm_adapter"
|
||||
|
||||
[[licenses.exceptions]]
|
||||
allow = ["AGPL-3.0-or-later"]
|
||||
crate = "memory-indexer"
|
||||
|
||||
[[licenses.exceptions]]
|
||||
allow = ["AGPL-3.0-or-later"]
|
||||
crate = "path-ext"
|
||||
|
||||
[licenses.private]
|
||||
ignore = true
|
||||
@@ -92,7 +92,7 @@
|
||||
"vite": "^7.2.7",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"packageManager": "yarn@4.12.0",
|
||||
"packageManager": "yarn@4.13.0",
|
||||
"resolutions": {
|
||||
"array-buffer-byte-length": "npm:@nolyfill/array-buffer-byte-length@^1",
|
||||
"array-includes": "npm:@nolyfill/array-includes@^1",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
edition = "2024"
|
||||
license-file = "LICENSE"
|
||||
name = "affine_server_native"
|
||||
publish = false
|
||||
version = "1.0.0"
|
||||
|
||||
[lib]
|
||||
@@ -21,7 +22,10 @@ image = { workspace = true }
|
||||
infer = { workspace = true }
|
||||
libwebp-sys = { workspace = true }
|
||||
little_exif = { workspace = true }
|
||||
llm_adapter = { workspace = true }
|
||||
llm_adapter = { workspace = true, default-features = false, features = [
|
||||
"ureq-client",
|
||||
] }
|
||||
matroska = { workspace = true }
|
||||
mp4parse = { workspace = true }
|
||||
napi = { workspace = true, features = ["async"] }
|
||||
napi-derive = { workspace = true }
|
||||
|
||||
BIN
packages/backend/native/fixtures/audio-only.mka
Normal file
BIN
packages/backend/native/fixtures/audio-only.mka
Normal file
Binary file not shown.
BIN
packages/backend/native/fixtures/audio-only.webm
Normal file
BIN
packages/backend/native/fixtures/audio-only.webm
Normal file
Binary file not shown.
BIN
packages/backend/native/fixtures/audio-video.webm
Normal file
BIN
packages/backend/native/fixtures/audio-video.webm
Normal file
Binary file not shown.
6
packages/backend/native/index.d.ts
vendored
6
packages/backend/native/index.d.ts
vendored
@@ -54,6 +54,12 @@ export declare function llmDispatch(protocol: string, backendConfigJson: string,
|
||||
|
||||
export declare function llmDispatchStream(protocol: string, backendConfigJson: string, requestJson: string, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle
|
||||
|
||||
export declare function llmEmbeddingDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
|
||||
|
||||
export declare function llmRerankDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
|
||||
|
||||
export declare function llmStructuredDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
|
||||
|
||||
/**
|
||||
* Merge updates in form like `Y.applyUpdate(doc, update)` way and return the
|
||||
* result binary.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use matroska::Matroska;
|
||||
use mp4parse::{TrackType, read_mp4};
|
||||
use napi_derive::napi;
|
||||
|
||||
@@ -8,7 +9,13 @@ pub fn get_mime(input: &[u8]) -> String {
|
||||
} else {
|
||||
file_format::FileFormat::from_bytes(input).media_type().to_string()
|
||||
};
|
||||
if mimetype == "video/mp4" {
|
||||
if let Some(container) = matroska_container_kind(input).or(match mimetype.as_str() {
|
||||
"video/webm" | "application/webm" => Some(ContainerKind::WebM),
|
||||
"video/x-matroska" | "application/x-matroska" => Some(ContainerKind::Matroska),
|
||||
_ => None,
|
||||
}) {
|
||||
detect_matroska_flavor(input, container, &mimetype)
|
||||
} else if mimetype == "video/mp4" {
|
||||
detect_mp4_flavor(input)
|
||||
} else {
|
||||
mimetype
|
||||
@@ -37,3 +44,68 @@ fn detect_mp4_flavor(input: &[u8]) -> String {
|
||||
Err(_) => "video/mp4".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum ContainerKind {
|
||||
WebM,
|
||||
Matroska,
|
||||
}
|
||||
|
||||
impl ContainerKind {
|
||||
fn audio_mime(&self) -> &'static str {
|
||||
match self {
|
||||
ContainerKind::WebM => "audio/webm",
|
||||
ContainerKind::Matroska => "audio/x-matroska",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_matroska_flavor(input: &[u8], container: ContainerKind, fallback: &str) -> String {
|
||||
match Matroska::open(std::io::Cursor::new(input)) {
|
||||
Ok(file) => {
|
||||
let has_video = file.video_tracks().next().is_some();
|
||||
let has_audio = file.audio_tracks().next().is_some();
|
||||
if !has_video && has_audio {
|
||||
container.audio_mime().to_string()
|
||||
} else {
|
||||
fallback.to_string()
|
||||
}
|
||||
}
|
||||
Err(_) => fallback.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn matroska_container_kind(input: &[u8]) -> Option<ContainerKind> {
|
||||
let header = &input[..1024.min(input.len())];
|
||||
if header.windows(4).any(|window| window.eq_ignore_ascii_case(b"webm")) {
|
||||
Some(ContainerKind::WebM)
|
||||
} else if header.windows(8).any(|window| window.eq_ignore_ascii_case(b"matroska")) {
|
||||
Some(ContainerKind::Matroska)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const AUDIO_ONLY_WEBM: &[u8] = include_bytes!("../fixtures/audio-only.webm");
|
||||
const AUDIO_VIDEO_WEBM: &[u8] = include_bytes!("../fixtures/audio-video.webm");
|
||||
const AUDIO_ONLY_MATROSKA: &[u8] = include_bytes!("../fixtures/audio-only.mka");
|
||||
|
||||
#[test]
|
||||
fn detects_audio_only_webm_as_audio() {
|
||||
assert_eq!(get_mime(AUDIO_ONLY_WEBM), "audio/webm");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserves_video_webm() {
|
||||
assert_eq!(get_mime(AUDIO_VIDEO_WEBM), "video/webm");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_audio_only_matroska_as_audio() {
|
||||
assert_eq!(get_mime(AUDIO_ONLY_MATROSKA), "audio/x-matroska");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,9 +5,10 @@ use std::sync::{
|
||||
|
||||
use llm_adapter::{
|
||||
backend::{
|
||||
BackendConfig, BackendError, BackendProtocol, ReqwestHttpClient, dispatch_request, dispatch_stream_events_with,
|
||||
BackendConfig, BackendError, BackendProtocol, DefaultHttpClient, dispatch_embedding_request, dispatch_request,
|
||||
dispatch_rerank_request, dispatch_stream_events_with, dispatch_structured_request,
|
||||
},
|
||||
core::{CoreRequest, StreamEvent},
|
||||
core::{CoreRequest, EmbeddingRequest, RerankRequest, StreamEvent, StructuredRequest},
|
||||
middleware::{
|
||||
MiddlewareConfig, PipelineContext, RequestMiddleware, StreamMiddleware, citation_indexing, clamp_max_tokens,
|
||||
normalize_messages, run_request_middleware_chain, run_stream_middleware_chain, stream_event_normalize,
|
||||
@@ -40,6 +41,20 @@ struct LlmDispatchPayload {
|
||||
middleware: LlmMiddlewarePayload,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct LlmStructuredDispatchPayload {
|
||||
#[serde(flatten)]
|
||||
request: StructuredRequest,
|
||||
#[serde(default)]
|
||||
middleware: LlmMiddlewarePayload,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct LlmRerankDispatchPayload {
|
||||
#[serde(flatten)]
|
||||
request: RerankRequest,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct LlmStreamHandle {
|
||||
aborted: Arc<AtomicBool>,
|
||||
@@ -61,7 +76,44 @@ pub fn llm_dispatch(protocol: String, backend_config_json: String, request_json:
|
||||
let request = apply_request_middlewares(payload.request, &payload.middleware)?;
|
||||
|
||||
let response =
|
||||
dispatch_request(&ReqwestHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?;
|
||||
dispatch_request(&DefaultHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?;
|
||||
|
||||
serde_json::to_string(&response).map_err(map_json_error)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub fn llm_structured_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
|
||||
let protocol = parse_protocol(&protocol)?;
|
||||
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
|
||||
let payload: LlmStructuredDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
|
||||
let request = apply_structured_request_middlewares(payload.request, &payload.middleware)?;
|
||||
|
||||
let response = dispatch_structured_request(&DefaultHttpClient::default(), &config, protocol, &request)
|
||||
.map_err(map_backend_error)?;
|
||||
|
||||
serde_json::to_string(&response).map_err(map_json_error)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub fn llm_embedding_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
|
||||
let protocol = parse_protocol(&protocol)?;
|
||||
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
|
||||
let request: EmbeddingRequest = serde_json::from_str(&request_json).map_err(map_json_error)?;
|
||||
|
||||
let response = dispatch_embedding_request(&DefaultHttpClient::default(), &config, protocol, &request)
|
||||
.map_err(map_backend_error)?;
|
||||
|
||||
serde_json::to_string(&response).map_err(map_json_error)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub fn llm_rerank_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
|
||||
let protocol = parse_protocol(&protocol)?;
|
||||
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
|
||||
let payload: LlmRerankDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
|
||||
|
||||
let response = dispatch_rerank_request(&DefaultHttpClient::default(), &config, protocol, &payload.request)
|
||||
.map_err(map_backend_error)?;
|
||||
|
||||
serde_json::to_string(&response).map_err(map_json_error)
|
||||
}
|
||||
@@ -98,7 +150,7 @@ pub fn llm_dispatch_stream(
|
||||
let mut aborted_by_user = false;
|
||||
let mut callback_dispatch_failed = false;
|
||||
|
||||
let result = dispatch_stream_events_with(&ReqwestHttpClient::default(), &config, protocol, &request, |event| {
|
||||
let result = dispatch_stream_events_with(&DefaultHttpClient::default(), &config, protocol, &request, |event| {
|
||||
if aborted_in_worker.load(Ordering::Relaxed) {
|
||||
aborted_by_user = true;
|
||||
return Err(BackendError::Http(STREAM_ABORTED_REASON.to_string()));
|
||||
@@ -155,6 +207,27 @@ fn apply_request_middlewares(request: CoreRequest, middleware: &LlmMiddlewarePay
|
||||
Ok(run_request_middleware_chain(request, &middleware.config, &chain))
|
||||
}
|
||||
|
||||
fn apply_structured_request_middlewares(
|
||||
request: StructuredRequest,
|
||||
middleware: &LlmMiddlewarePayload,
|
||||
) -> Result<StructuredRequest> {
|
||||
let mut core = request.as_core_request();
|
||||
core = apply_request_middlewares(core, middleware)?;
|
||||
|
||||
Ok(StructuredRequest {
|
||||
model: core.model,
|
||||
messages: core.messages,
|
||||
schema: core
|
||||
.response_schema
|
||||
.ok_or_else(|| Error::new(Status::InvalidArg, "Structured request schema is required"))?,
|
||||
max_tokens: core.max_tokens,
|
||||
temperature: core.temperature,
|
||||
reasoning: core.reasoning,
|
||||
strict: request.strict,
|
||||
response_mime_type: request.response_mime_type,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamPipeline {
|
||||
chain: Vec<StreamMiddleware>,
|
||||
@@ -268,6 +341,7 @@ fn parse_protocol(protocol: &str) -> Result<BackendProtocol> {
|
||||
}
|
||||
"openai_responses" | "openai-responses" | "responses" => Ok(BackendProtocol::OpenaiResponses),
|
||||
"anthropic" | "anthropic_messages" | "anthropic-messages" => Ok(BackendProtocol::AnthropicMessages),
|
||||
"gemini" | "gemini_generate_content" | "gemini-generate-content" => Ok(BackendProtocol::GeminiGenerateContent),
|
||||
other => Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("Unsupported llm backend protocol: {other}"),
|
||||
@@ -293,6 +367,7 @@ mod tests {
|
||||
assert!(parse_protocol("chat-completions").is_ok());
|
||||
assert!(parse_protocol("responses").is_ok());
|
||||
assert!(parse_protocol("anthropic").is_ok());
|
||||
assert!(parse_protocol("gemini").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -25,8 +25,6 @@
|
||||
"dependencies": {
|
||||
"@affine/s3-compat": "workspace:*",
|
||||
"@affine/server-native": "workspace:*",
|
||||
"@ai-sdk/google": "^3.0.46",
|
||||
"@ai-sdk/google-vertex": "^4.0.83",
|
||||
"@apollo/server": "^4.13.0",
|
||||
"@fal-ai/serverless-client": "^0.15.0",
|
||||
"@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0",
|
||||
@@ -35,30 +33,30 @@
|
||||
"@nestjs-cls/transactional-adapter-prisma": "^1.3.4",
|
||||
"@nestjs/apollo": "^13.0.4",
|
||||
"@nestjs/bullmq": "^11.0.4",
|
||||
"@nestjs/common": "^11.0.21",
|
||||
"@nestjs/core": "^11.1.14",
|
||||
"@nestjs/common": "^11.1.17",
|
||||
"@nestjs/core": "^11.1.17",
|
||||
"@nestjs/graphql": "^13.0.4",
|
||||
"@nestjs/platform-express": "^11.1.14",
|
||||
"@nestjs/platform-socket.io": "^11.1.14",
|
||||
"@nestjs/platform-express": "^11.1.17",
|
||||
"@nestjs/platform-socket.io": "^11.1.17",
|
||||
"@nestjs/schedule": "^6.1.1",
|
||||
"@nestjs/throttler": "^6.5.0",
|
||||
"@nestjs/websockets": "^11.1.14",
|
||||
"@nestjs/websockets": "^11.1.17",
|
||||
"@node-rs/argon2": "^2.0.2",
|
||||
"@node-rs/crc32": "^1.10.6",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/core": "^2.2.0",
|
||||
"@opentelemetry/exporter-prometheus": "^0.212.0",
|
||||
"@opentelemetry/exporter-zipkin": "^2.2.0",
|
||||
"@opentelemetry/host-metrics": "^0.38.0",
|
||||
"@opentelemetry/instrumentation": "^0.212.0",
|
||||
"@opentelemetry/instrumentation-graphql": "^0.60.0",
|
||||
"@opentelemetry/instrumentation-http": "^0.212.0",
|
||||
"@opentelemetry/instrumentation-ioredis": "^0.60.0",
|
||||
"@opentelemetry/instrumentation-nestjs-core": "^0.58.0",
|
||||
"@opentelemetry/instrumentation-socket.io": "^0.59.0",
|
||||
"@opentelemetry/exporter-prometheus": "^0.213.0",
|
||||
"@opentelemetry/exporter-zipkin": "^2.6.0",
|
||||
"@opentelemetry/host-metrics": "^0.38.3",
|
||||
"@opentelemetry/instrumentation": "^0.213.0",
|
||||
"@opentelemetry/instrumentation-graphql": "^0.61.0",
|
||||
"@opentelemetry/instrumentation-http": "^0.213.0",
|
||||
"@opentelemetry/instrumentation-ioredis": "^0.61.0",
|
||||
"@opentelemetry/instrumentation-nestjs-core": "^0.59.0",
|
||||
"@opentelemetry/instrumentation-socket.io": "^0.60.0",
|
||||
"@opentelemetry/resources": "^2.2.0",
|
||||
"@opentelemetry/sdk-metrics": "^2.2.0",
|
||||
"@opentelemetry/sdk-node": "^0.212.0",
|
||||
"@opentelemetry/sdk-node": "^0.213.0",
|
||||
"@opentelemetry/sdk-trace-node": "^2.2.0",
|
||||
"@opentelemetry/semantic-conventions": "^1.38.0",
|
||||
"@prisma/client": "^6.6.0",
|
||||
@@ -66,7 +64,6 @@
|
||||
"@queuedash/api": "^3.16.0",
|
||||
"@react-email/components": "^0.5.7",
|
||||
"@socket.io/redis-adapter": "^8.3.0",
|
||||
"ai": "^6.0.118",
|
||||
"bullmq": "^5.40.2",
|
||||
"cookie-parser": "^1.4.7",
|
||||
"cross-env": "^10.1.0",
|
||||
@@ -75,7 +72,7 @@
|
||||
"eventemitter2": "^6.4.9",
|
||||
"exa-js": "^2.4.0",
|
||||
"express": "^5.0.1",
|
||||
"fast-xml-parser": "^5.3.4",
|
||||
"fast-xml-parser": "^5.5.7",
|
||||
"get-stream": "^9.0.1",
|
||||
"google-auth-library": "^10.2.0",
|
||||
"graphql": "^16.9.0",
|
||||
|
||||
@@ -225,6 +225,20 @@ const checkStreamObjects = (result: string) => {
|
||||
}
|
||||
};
|
||||
|
||||
const parseStreamObjects = (result: string): StreamObject[] => {
|
||||
const streamObjects = JSON.parse(result);
|
||||
return z.array(StreamObjectSchema).parse(streamObjects);
|
||||
};
|
||||
|
||||
const getStreamObjectText = (result: string) =>
|
||||
parseStreamObjects(result)
|
||||
.filter(
|
||||
(chunk): chunk is Extract<StreamObject, { type: 'text-delta' }> =>
|
||||
chunk.type === 'text-delta'
|
||||
)
|
||||
.map(chunk => chunk.textDelta)
|
||||
.join('');
|
||||
|
||||
const retry = async (
|
||||
action: string,
|
||||
t: ExecutionContext<Tester>,
|
||||
@@ -444,6 +458,49 @@ The term **“CRDT”** was first introduced by Marc Shapiro, Nuno Preguiça, Ca
|
||||
},
|
||||
type: 'object' as const,
|
||||
},
|
||||
{
|
||||
name: 'Gemini native text',
|
||||
promptName: ['Chat With AFFiNE AI'],
|
||||
messages: [
|
||||
{
|
||||
role: 'user' as const,
|
||||
content:
|
||||
'In one short sentence, explain what AFFiNE AI is and mention AFFiNE by name.',
|
||||
},
|
||||
],
|
||||
config: { model: 'gemini-2.5-flash' },
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
assertNotWrappedInCodeBlock(t, result);
|
||||
t.assert(
|
||||
result.toLowerCase().includes('affine'),
|
||||
'should mention AFFiNE'
|
||||
);
|
||||
},
|
||||
prefer: CopilotProviderType.Gemini,
|
||||
type: 'text' as const,
|
||||
},
|
||||
{
|
||||
name: 'Gemini native stream objects',
|
||||
promptName: ['Chat With AFFiNE AI'],
|
||||
messages: [
|
||||
{
|
||||
role: 'user' as const,
|
||||
content:
|
||||
'Respond with one short sentence about AFFiNE AI and mention AFFiNE by name.',
|
||||
},
|
||||
],
|
||||
config: { model: 'gemini-2.5-flash' },
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
t.truthy(checkStreamObjects(result), 'should be valid stream objects');
|
||||
const assembledText = getStreamObjectText(result);
|
||||
t.assert(
|
||||
assembledText.toLowerCase().includes('affine'),
|
||||
'should mention AFFiNE'
|
||||
);
|
||||
},
|
||||
prefer: CopilotProviderType.Gemini,
|
||||
type: 'object' as const,
|
||||
},
|
||||
{
|
||||
name: 'Should transcribe short audio',
|
||||
promptName: ['Transcript audio'],
|
||||
@@ -716,14 +773,13 @@ for (const {
|
||||
const { factory, prompt: promptService } = t.context;
|
||||
const prompt = (await promptService.get(promptName))!;
|
||||
t.truthy(prompt, 'should have prompt');
|
||||
const provider = (await factory.getProviderByModel(prompt.model, {
|
||||
const finalConfig = Object.assign({}, prompt.config, config);
|
||||
const modelId = finalConfig.model || prompt.model;
|
||||
const provider = (await factory.getProviderByModel(modelId, {
|
||||
prefer,
|
||||
}))!;
|
||||
t.truthy(provider, 'should have provider');
|
||||
await retry(`action: ${promptName}`, t, async t => {
|
||||
const finalConfig = Object.assign({}, prompt.config, config);
|
||||
const modelId = finalConfig.model || prompt.model;
|
||||
|
||||
switch (type) {
|
||||
case 'text': {
|
||||
const result = await provider.text(
|
||||
@@ -891,7 +947,7 @@ test(
|
||||
'should be able to rerank message chunks',
|
||||
runIfCopilotConfigured,
|
||||
async t => {
|
||||
const { factory, prompt } = t.context;
|
||||
const { factory } = t.context;
|
||||
|
||||
await retry('rerank', t, async t => {
|
||||
const query = 'Is this content relevant to programming?';
|
||||
@@ -908,14 +964,18 @@ test(
|
||||
'The stock market is experiencing significant fluctuations.',
|
||||
];
|
||||
|
||||
const p = (await prompt.get('Rerank results'))!;
|
||||
t.assert(p, 'should have prompt for rerank');
|
||||
const provider = (await factory.getProviderByModel(p.model))!;
|
||||
const provider = (await factory.getProviderByModel('gpt-5.2'))!;
|
||||
t.assert(provider, 'should have provider for rerank');
|
||||
|
||||
const scores = await provider.rerank(
|
||||
{ modelId: p.model },
|
||||
embeddings.map(e => p.finish({ query, doc: e }))
|
||||
{ modelId: 'gpt-5.2' },
|
||||
{
|
||||
query,
|
||||
candidates: embeddings.map((text, index) => ({
|
||||
id: String(index),
|
||||
text,
|
||||
})),
|
||||
}
|
||||
);
|
||||
|
||||
t.is(scores.length, 10, 'should return scores for all chunks');
|
||||
|
||||
@@ -33,10 +33,7 @@ import {
|
||||
ModelOutputType,
|
||||
OpenAIProvider,
|
||||
} from '../../plugins/copilot/providers';
|
||||
import {
|
||||
CitationParser,
|
||||
TextStreamParser,
|
||||
} from '../../plugins/copilot/providers/utils';
|
||||
import { TextStreamParser } from '../../plugins/copilot/providers/utils';
|
||||
import { ChatSessionService } from '../../plugins/copilot/session';
|
||||
import { CopilotStorage } from '../../plugins/copilot/storage';
|
||||
import { CopilotTranscriptionService } from '../../plugins/copilot/transcript';
|
||||
@@ -660,6 +657,55 @@ test('should be able to generate with message id', async t => {
|
||||
}
|
||||
});
|
||||
|
||||
test('should preserve file handle attachments when merging user content into prompt', async t => {
|
||||
const { prompt, session } = t.context;
|
||||
|
||||
await prompt.set(promptName, 'model', [
|
||||
{ role: 'user', content: '{{content}}' },
|
||||
]);
|
||||
|
||||
const sessionId = await session.create({
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName,
|
||||
pinned: false,
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
const message = await session.createMessage({
|
||||
sessionId,
|
||||
content: 'Summarize this file',
|
||||
attachments: [
|
||||
{
|
||||
kind: 'file_handle',
|
||||
fileHandle: 'file_123',
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
await s.pushByMessageId(message);
|
||||
const finalMessages = s.finish({});
|
||||
|
||||
t.deepEqual(finalMessages, [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Summarize this file',
|
||||
attachments: [
|
||||
{
|
||||
kind: 'file_handle',
|
||||
fileHandle: 'file_123',
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
],
|
||||
params: {
|
||||
content: 'Summarize this file',
|
||||
},
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test('should save message correctly', async t => {
|
||||
const { prompt, session } = t.context;
|
||||
|
||||
@@ -1225,149 +1271,6 @@ test('should be able to run image executor', async t => {
|
||||
Sinon.restore();
|
||||
});
|
||||
|
||||
test('CitationParser should replace citation placeholders with URLs', t => {
|
||||
const content =
|
||||
'This is [a] test sentence with [citations [1]] and [[2]] and [3].';
|
||||
const citations = ['https://example1.com', 'https://example2.com'];
|
||||
|
||||
const parser = new CitationParser();
|
||||
for (const citation of citations) {
|
||||
parser.push(citation);
|
||||
}
|
||||
|
||||
const result = parser.parse(content) + parser.end();
|
||||
|
||||
const expected = [
|
||||
'This is [a] test sentence with [citations [^1]] and [^2] and [3].',
|
||||
`[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`,
|
||||
`[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`,
|
||||
].join('\n');
|
||||
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
test('CitationParser should replace chunks of citation placeholders with URLs', t => {
|
||||
const contents = [
|
||||
'[[]]',
|
||||
'This is [',
|
||||
'a] test sentence ',
|
||||
'with citations [1',
|
||||
'] and [',
|
||||
'[2]] and [[',
|
||||
'3]] and [[4',
|
||||
']] and [[5]',
|
||||
'] and [[6]]',
|
||||
' and [7',
|
||||
];
|
||||
const citations = [
|
||||
'https://example1.com',
|
||||
'https://example2.com',
|
||||
'https://example3.com',
|
||||
'https://example4.com',
|
||||
'https://example5.com',
|
||||
'https://example6.com',
|
||||
'https://example7.com',
|
||||
];
|
||||
|
||||
const parser = new CitationParser();
|
||||
for (const citation of citations) {
|
||||
parser.push(citation);
|
||||
}
|
||||
|
||||
let result = contents.reduce((acc, current) => {
|
||||
return acc + parser.parse(current);
|
||||
}, '');
|
||||
result += parser.end();
|
||||
|
||||
const expected = [
|
||||
'[[]]This is [a] test sentence with citations [^1] and [^2] and [^3] and [^4] and [^5] and [^6] and [7',
|
||||
`[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`,
|
||||
`[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`,
|
||||
`[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`,
|
||||
`[^4]: {"type":"url","url":"${encodeURIComponent(citations[3])}"}`,
|
||||
`[^5]: {"type":"url","url":"${encodeURIComponent(citations[4])}"}`,
|
||||
`[^6]: {"type":"url","url":"${encodeURIComponent(citations[5])}"}`,
|
||||
`[^7]: {"type":"url","url":"${encodeURIComponent(citations[6])}"}`,
|
||||
].join('\n');
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
test('CitationParser should not replace citation already with URLs', t => {
|
||||
const content =
|
||||
'This is [a] test sentence with citations [1](https://example1.com) and [[2]](https://example2.com) and [[3](https://example3.com)].';
|
||||
const citations = [
|
||||
'https://example4.com',
|
||||
'https://example5.com',
|
||||
'https://example6.com',
|
||||
];
|
||||
|
||||
const parser = new CitationParser();
|
||||
for (const citation of citations) {
|
||||
parser.push(citation);
|
||||
}
|
||||
|
||||
const result = parser.parse(content) + parser.end();
|
||||
|
||||
const expected = [
|
||||
content,
|
||||
`[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`,
|
||||
`[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`,
|
||||
`[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`,
|
||||
].join('\n');
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
test('CitationParser should not replace chunks of citation already with URLs', t => {
|
||||
const contents = [
|
||||
'This is [a] test sentence with citations [1',
|
||||
'](https://example1.com) and [[2]',
|
||||
'](https://example2.com) and [[3](https://example3.com)].',
|
||||
];
|
||||
const citations = [
|
||||
'https://example4.com',
|
||||
'https://example5.com',
|
||||
'https://example6.com',
|
||||
];
|
||||
|
||||
const parser = new CitationParser();
|
||||
for (const citation of citations) {
|
||||
parser.push(citation);
|
||||
}
|
||||
|
||||
let result = contents.reduce((acc, current) => {
|
||||
return acc + parser.parse(current);
|
||||
}, '');
|
||||
result += parser.end();
|
||||
|
||||
const expected = [
|
||||
contents.join(''),
|
||||
`[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`,
|
||||
`[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`,
|
||||
`[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`,
|
||||
].join('\n');
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
test('CitationParser should replace openai style reference chunks', t => {
|
||||
const contents = [
|
||||
'This is [a] test sentence with citations ',
|
||||
'([example1.com](https://example1.com))',
|
||||
];
|
||||
|
||||
const parser = new CitationParser();
|
||||
|
||||
let result = contents.reduce((acc, current) => {
|
||||
return acc + parser.parse(current);
|
||||
}, '');
|
||||
result += parser.end();
|
||||
|
||||
const expected = [
|
||||
contents[0] + '[^1]',
|
||||
`[^1]: {"type":"url","url":"${encodeURIComponent('https://example1.com')}"}`,
|
||||
].join('\n');
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
test('TextStreamParser should format different types of chunks correctly', t => {
|
||||
// Define interfaces for fixtures
|
||||
interface BaseFixture {
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
|
||||
import {
|
||||
buildNativeRequest,
|
||||
NativeProviderAdapter,
|
||||
} from '../../plugins/copilot/providers/native';
|
||||
|
||||
const mockDispatch = () =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
yield { type: 'text_delta', text: 'Use [^1] now' };
|
||||
yield { type: 'citation', index: 1, url: 'https://affine.pro' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
test('NativeProviderAdapter streamText should append citation footnotes', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
|
||||
const chunks: string[] = [];
|
||||
for await (const chunk of adapter.streamText({
|
||||
model: 'gpt-5-mini',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.true(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should append citation footnotes', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
|
||||
const chunks = [];
|
||||
for await (const chunk of adapter.streamObject({
|
||||
model: 'gpt-5-mini',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
t.deepEqual(
|
||||
chunks.map(chunk => chunk.type),
|
||||
['text-delta', 'text-delta']
|
||||
);
|
||||
const text = chunks
|
||||
.filter(chunk => chunk.type === 'text-delta')
|
||||
.map(chunk => chunk.textDelta)
|
||||
.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.true(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should append fallback attachment footnotes', async t => {
|
||||
const dispatch = () =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
yield {
|
||||
type: 'tool_result',
|
||||
call_id: 'call_1',
|
||||
name: 'blob_read',
|
||||
arguments: { blob_id: 'blob_1' },
|
||||
output: {
|
||||
blobId: 'blob_1',
|
||||
fileName: 'a.txt',
|
||||
fileType: 'text/plain',
|
||||
content: 'A',
|
||||
},
|
||||
};
|
||||
yield {
|
||||
type: 'tool_result',
|
||||
call_id: 'call_2',
|
||||
name: 'blob_read',
|
||||
arguments: { blob_id: 'blob_2' },
|
||||
output: {
|
||||
blobId: 'blob_2',
|
||||
fileName: 'b.txt',
|
||||
fileType: 'text/plain',
|
||||
content: 'B',
|
||||
},
|
||||
};
|
||||
yield { type: 'text_delta', text: 'Answer from files.' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
const adapter = new NativeProviderAdapter(dispatch, {}, 3);
|
||||
const chunks = [];
|
||||
for await (const chunk of adapter.streamObject({
|
||||
model: 'gpt-5-mini',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks
|
||||
.filter(chunk => chunk.type === 'text-delta')
|
||||
.map(chunk => chunk.textDelta)
|
||||
.join('');
|
||||
t.true(text.includes('Answer from files.'));
|
||||
t.true(text.includes('[^1][^2]'));
|
||||
t.true(
|
||||
text.includes(
|
||||
'[^1]: {"type":"attachment","blobId":"blob_1","fileName":"a.txt","fileType":"text/plain"}'
|
||||
)
|
||||
);
|
||||
t.true(
|
||||
text.includes(
|
||||
'[^2]: {"type":"attachment","blobId":"blob_2","fileName":"b.txt","fileType":"text/plain"}'
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should map tool and text events', async t => {
|
||||
let round = 0;
|
||||
const dispatch = (_request: NativeLlmRequest) =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
round += 1;
|
||||
if (round === 1) {
|
||||
yield {
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
};
|
||||
yield { type: 'done', finish_reason: 'tool_calls' };
|
||||
return;
|
||||
}
|
||||
yield { type: 'text_delta', text: 'ok' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
const adapter = new NativeProviderAdapter(
|
||||
dispatch,
|
||||
{
|
||||
doc_read: {
|
||||
inputSchema: z.object({ doc_id: z.string() }),
|
||||
execute: async () => ({ markdown: '# a1' }),
|
||||
},
|
||||
},
|
||||
4
|
||||
);
|
||||
|
||||
const events = [];
|
||||
for await (const event of adapter.streamObject({
|
||||
model: 'gpt-5-mini',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'read' }] }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
t.deepEqual(
|
||||
events.map(event => event.type),
|
||||
['tool-call', 'tool-result', 'text-delta']
|
||||
);
|
||||
t.deepEqual(events[0], {
|
||||
type: 'tool-call',
|
||||
toolCallId: 'call_1',
|
||||
toolName: 'doc_read',
|
||||
args: { doc_id: 'a1' },
|
||||
});
|
||||
});
|
||||
|
||||
test('buildNativeRequest should include rust middleware from profile', async t => {
|
||||
const { request } = await buildNativeRequest({
|
||||
model: 'gpt-5-mini',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
tools: {},
|
||||
middleware: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['callout'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
t.deepEqual(request.middleware, {
|
||||
request: ['normalize_messages', 'clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
});
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamText should skip citation footnotes when disabled', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3, {
|
||||
nodeTextMiddleware: ['callout'],
|
||||
});
|
||||
const chunks: string[] = [];
|
||||
for await (const chunk of adapter.streamText({
|
||||
model: 'gpt-5-mini',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.false(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,13 @@
|
||||
import serverNativeModule from '@affine/server-native';
|
||||
import test from 'ava';
|
||||
|
||||
import type { NativeLlmRerankRequest } from '../../native';
|
||||
import { ProviderMiddlewareConfig } from '../../plugins/copilot/config';
|
||||
import { normalizeOpenAIOptionsForModel } from '../../plugins/copilot/providers/openai';
|
||||
import {
|
||||
normalizeOpenAIOptionsForModel,
|
||||
OpenAIProvider,
|
||||
} from '../../plugins/copilot/providers/openai';
|
||||
import { CopilotProvider } from '../../plugins/copilot/providers/provider';
|
||||
import { normalizeRerankModel } from '../../plugins/copilot/providers/rerank';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
ModelInputType,
|
||||
@@ -46,6 +50,33 @@ class TestOpenAIProvider extends CopilotProvider<{ apiKey: string }> {
|
||||
}
|
||||
}
|
||||
|
||||
class NativeRerankProtocolProvider extends OpenAIProvider {
|
||||
override readonly models = [
|
||||
{
|
||||
id: 'gpt-5.2',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Rerank],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
override get config() {
|
||||
return {
|
||||
apiKey: 'test-key',
|
||||
baseURL: 'https://api.openai.com/v1',
|
||||
oldApiStyle: false,
|
||||
};
|
||||
}
|
||||
|
||||
override configured() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
function createProvider(profileMiddleware?: ProviderMiddlewareConfig) {
|
||||
const provider = new TestOpenAIProvider();
|
||||
(provider as any).AFFiNEConfig = {
|
||||
@@ -126,14 +157,44 @@ test('normalizeOpenAIOptionsForModel should keep options for gpt-4.1', t => {
|
||||
);
|
||||
});
|
||||
|
||||
test('normalizeOpenAIRerankModel should keep supported rerank models', t => {
|
||||
t.is(normalizeRerankModel('gpt-4.1'), 'gpt-4.1');
|
||||
t.is(normalizeRerankModel('gpt-4.1-mini'), 'gpt-4.1-mini');
|
||||
t.is(normalizeRerankModel('gpt-5.2'), 'gpt-5.2');
|
||||
});
|
||||
test('OpenAI rerank should always use chat-completions native protocol', async t => {
|
||||
const provider = new NativeRerankProtocolProvider();
|
||||
let capturedProtocol: string | undefined;
|
||||
let capturedRequest: NativeLlmRerankRequest | undefined;
|
||||
|
||||
test('normalizeOpenAIRerankModel should fall back for unsupported models', t => {
|
||||
t.is(normalizeRerankModel('gpt-5-mini'), 'gpt-5.2');
|
||||
t.is(normalizeRerankModel('gemini-2.5-flash'), 'gpt-5.2');
|
||||
t.is(normalizeRerankModel(undefined), 'gpt-5.2');
|
||||
const original = (serverNativeModule as any).llmRerankDispatch;
|
||||
(serverNativeModule as any).llmRerankDispatch = (
|
||||
protocol: string,
|
||||
_backendConfigJson: string,
|
||||
requestJson: string
|
||||
) => {
|
||||
capturedProtocol = protocol;
|
||||
capturedRequest = JSON.parse(requestJson) as NativeLlmRerankRequest;
|
||||
return JSON.stringify({ model: 'gpt-5.2', scores: [0.9, 0.1] });
|
||||
};
|
||||
t.teardown(() => {
|
||||
(serverNativeModule as any).llmRerankDispatch = original;
|
||||
});
|
||||
|
||||
const scores = await provider.rerank(
|
||||
{ modelId: 'gpt-5.2' },
|
||||
{
|
||||
query: 'programming',
|
||||
candidates: [
|
||||
{ id: 'react', text: 'React is a UI library.' },
|
||||
{ id: 'weather', text: 'The weather is sunny today.' },
|
||||
],
|
||||
}
|
||||
);
|
||||
|
||||
t.deepEqual(scores, [0.9, 0.1]);
|
||||
t.is(capturedProtocol, 'openai_chat');
|
||||
t.deepEqual(capturedRequest, {
|
||||
model: 'gpt-5.2',
|
||||
query: 'programming',
|
||||
candidates: [
|
||||
{ id: 'react', text: 'React is a UI library.' },
|
||||
{ id: 'weather', text: 'The weather is sunny today.' },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,12 +1,35 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { DocReader } from '../../core/doc';
|
||||
import type { AccessController } from '../../core/permission';
|
||||
import type { Models } from '../../models';
|
||||
import { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
|
||||
import {
|
||||
ToolCallAccumulator,
|
||||
ToolCallLoop,
|
||||
ToolSchemaExtractor,
|
||||
} from '../../plugins/copilot/providers/loop';
|
||||
import {
|
||||
buildBlobContentGetter,
|
||||
createBlobReadTool,
|
||||
} from '../../plugins/copilot/tools/blob-read';
|
||||
import {
|
||||
buildDocKeywordSearchGetter,
|
||||
createDocKeywordSearchTool,
|
||||
} from '../../plugins/copilot/tools/doc-keyword-search';
|
||||
import {
|
||||
buildDocContentGetter,
|
||||
createDocReadTool,
|
||||
} from '../../plugins/copilot/tools/doc-read';
|
||||
import {
|
||||
buildDocSearchGetter,
|
||||
createDocSemanticSearchTool,
|
||||
} from '../../plugins/copilot/tools/doc-semantic-search';
|
||||
import {
|
||||
DOCUMENT_SYNC_PENDING_MESSAGE,
|
||||
LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE,
|
||||
} from '../../plugins/copilot/tools/doc-sync';
|
||||
|
||||
test('ToolCallAccumulator should merge deltas and complete tool call', t => {
|
||||
const accumulator = new ToolCallAccumulator();
|
||||
@@ -34,6 +57,56 @@ test('ToolCallAccumulator should merge deltas and complete tool call', t => {
|
||||
id: 'call_1',
|
||||
name: 'doc_read',
|
||||
args: { doc_id: 'a1' },
|
||||
rawArgumentsText: '{"doc_id":"a1"}',
|
||||
thought: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
test('ToolCallAccumulator should preserve invalid JSON instead of swallowing it', t => {
|
||||
const accumulator = new ToolCallAccumulator();
|
||||
|
||||
accumulator.feedDelta({
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments_delta: '{"doc_id":',
|
||||
});
|
||||
|
||||
const pending = accumulator.drainPending();
|
||||
|
||||
t.is(pending.length, 1);
|
||||
t.deepEqual(pending[0]?.id, 'call_1');
|
||||
t.deepEqual(pending[0]?.name, 'doc_read');
|
||||
t.deepEqual(pending[0]?.args, {});
|
||||
t.is(pending[0]?.rawArgumentsText, '{"doc_id":');
|
||||
t.truthy(pending[0]?.argumentParseError);
|
||||
});
|
||||
|
||||
test('ToolCallAccumulator should prefer native canonical tool arguments metadata', t => {
|
||||
const accumulator = new ToolCallAccumulator();
|
||||
|
||||
accumulator.feedDelta({
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments_delta: '{"stale":true}',
|
||||
});
|
||||
|
||||
const completed = accumulator.complete({
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: {},
|
||||
arguments_text: '{"doc_id":"a1"}',
|
||||
arguments_error: 'invalid json',
|
||||
});
|
||||
|
||||
t.deepEqual(completed, {
|
||||
id: 'call_1',
|
||||
name: 'doc_read',
|
||||
args: {},
|
||||
rawArgumentsText: '{"doc_id":"a1"}',
|
||||
argumentParseError: 'invalid json',
|
||||
thought: undefined,
|
||||
});
|
||||
});
|
||||
@@ -71,6 +144,8 @@ test('ToolSchemaExtractor should convert zod schema to json schema', t => {
|
||||
|
||||
test('ToolCallLoop should execute tool call and continue to next round', async t => {
|
||||
const dispatchRequests: NativeLlmRequest[] = [];
|
||||
const originalMessages = [{ role: 'user', content: 'read doc' }] as const;
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
const dispatch = (request: NativeLlmRequest) => {
|
||||
dispatchRequests.push(request);
|
||||
@@ -100,13 +175,17 @@ test('ToolCallLoop should execute tool call and continue to next round', async t
|
||||
};
|
||||
|
||||
let executedArgs: Record<string, unknown> | null = null;
|
||||
let executedMessages: unknown;
|
||||
let executedSignal: AbortSignal | undefined;
|
||||
const loop = new ToolCallLoop(
|
||||
dispatch,
|
||||
{
|
||||
doc_read: {
|
||||
inputSchema: z.object({ doc_id: z.string() }),
|
||||
execute: async args => {
|
||||
execute: async (args, options) => {
|
||||
executedArgs = args;
|
||||
executedMessages = options.messages;
|
||||
executedSignal = options.signal;
|
||||
return { markdown: '# doc' };
|
||||
},
|
||||
},
|
||||
@@ -114,6 +193,92 @@ test('ToolCallLoop should execute tool call and continue to next round', async t
|
||||
4
|
||||
);
|
||||
|
||||
const events: NativeLlmStreamEvent[] = [];
|
||||
for await (const event of loop.run(
|
||||
{
|
||||
model: 'gpt-5-mini',
|
||||
stream: true,
|
||||
messages: [
|
||||
{ role: 'user', content: [{ type: 'text', text: 'read doc' }] },
|
||||
],
|
||||
},
|
||||
signal,
|
||||
[...originalMessages]
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
t.deepEqual(executedArgs, { doc_id: 'a1' });
|
||||
t.deepEqual(executedMessages, originalMessages);
|
||||
t.is(executedSignal, signal);
|
||||
t.true(
|
||||
dispatchRequests[1]?.messages.some(message => message.role === 'tool')
|
||||
);
|
||||
t.deepEqual(dispatchRequests[1]?.messages[1]?.content, [
|
||||
{
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
arguments_text: '{"doc_id":"a1"}',
|
||||
arguments_error: undefined,
|
||||
thought: undefined,
|
||||
},
|
||||
]);
|
||||
t.deepEqual(dispatchRequests[1]?.messages[2]?.content, [
|
||||
{
|
||||
type: 'tool_result',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
arguments_text: '{"doc_id":"a1"}',
|
||||
arguments_error: undefined,
|
||||
output: { markdown: '# doc' },
|
||||
is_error: undefined,
|
||||
},
|
||||
]);
|
||||
t.deepEqual(
|
||||
events.map(event => event.type),
|
||||
['tool_call', 'tool_result', 'text_delta', 'done']
|
||||
);
|
||||
});
|
||||
|
||||
test('ToolCallLoop should surface invalid JSON as tool error without executing', async t => {
|
||||
let executed = false;
|
||||
let round = 0;
|
||||
const loop = new ToolCallLoop(
|
||||
request => {
|
||||
round += 1;
|
||||
const hasToolResult = request.messages.some(
|
||||
message => message.role === 'tool'
|
||||
);
|
||||
return (async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
if (!hasToolResult && round === 1) {
|
||||
yield {
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments_delta: '{"doc_id":',
|
||||
};
|
||||
yield { type: 'done', finish_reason: 'tool_calls' };
|
||||
return;
|
||||
}
|
||||
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
},
|
||||
{
|
||||
doc_read: {
|
||||
inputSchema: z.object({ doc_id: z.string() }),
|
||||
execute: async () => {
|
||||
executed = true;
|
||||
return { markdown: '# doc' };
|
||||
},
|
||||
},
|
||||
},
|
||||
2
|
||||
);
|
||||
|
||||
const events: NativeLlmStreamEvent[] = [];
|
||||
for await (const event of loop.run({
|
||||
model: 'gpt-5-mini',
|
||||
@@ -123,12 +288,231 @@ test('ToolCallLoop should execute tool call and continue to next round', async t
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
t.deepEqual(executedArgs, { doc_id: 'a1' });
|
||||
t.true(
|
||||
dispatchRequests[1]?.messages.some(message => message.role === 'tool')
|
||||
);
|
||||
t.deepEqual(
|
||||
events.map(event => event.type),
|
||||
['tool_call', 'tool_result', 'text_delta', 'done']
|
||||
);
|
||||
t.false(executed);
|
||||
t.true(events[0]?.type === 'tool_result');
|
||||
t.deepEqual(events[0], {
|
||||
type: 'tool_result',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: {},
|
||||
arguments_text: '{"doc_id":',
|
||||
arguments_error:
|
||||
events[0]?.type === 'tool_result' ? events[0].arguments_error : undefined,
|
||||
output: {
|
||||
message: 'Invalid tool arguments JSON',
|
||||
rawArguments: '{"doc_id":',
|
||||
error:
|
||||
events[0]?.type === 'tool_result'
|
||||
? events[0].arguments_error
|
||||
: undefined,
|
||||
},
|
||||
is_error: true,
|
||||
});
|
||||
});
|
||||
|
||||
test('doc_read should return specific sync errors for unavailable docs', async t => {
|
||||
const cases = [
|
||||
{
|
||||
name: 'local workspace without cloud sync',
|
||||
workspace: null,
|
||||
authors: null,
|
||||
markdown: null,
|
||||
expected: {
|
||||
type: 'error',
|
||||
name: 'Workspace Sync Required',
|
||||
message: LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE,
|
||||
},
|
||||
docReaderCalled: false,
|
||||
},
|
||||
{
|
||||
name: 'cloud workspace document not synced to server yet',
|
||||
workspace: { id: 'ws-1' },
|
||||
authors: null,
|
||||
markdown: null,
|
||||
expected: {
|
||||
type: 'error',
|
||||
name: 'Document Sync Pending',
|
||||
message: DOCUMENT_SYNC_PENDING_MESSAGE('doc-1'),
|
||||
},
|
||||
docReaderCalled: false,
|
||||
},
|
||||
{
|
||||
name: 'cloud workspace document markdown not ready yet',
|
||||
workspace: { id: 'ws-1' },
|
||||
authors: {
|
||||
createdAt: new Date('2026-01-01T00:00:00.000Z'),
|
||||
updatedAt: new Date('2026-01-01T00:00:00.000Z'),
|
||||
createdByUser: null,
|
||||
updatedByUser: null,
|
||||
},
|
||||
markdown: null,
|
||||
expected: {
|
||||
type: 'error',
|
||||
name: 'Document Sync Pending',
|
||||
message: DOCUMENT_SYNC_PENDING_MESSAGE('doc-1'),
|
||||
},
|
||||
docReaderCalled: true,
|
||||
},
|
||||
] as const;
|
||||
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({ doc: () => ({ can: async () => true }) }),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
for (const testCase of cases) {
|
||||
let docReaderCalled = false;
|
||||
const docReader = {
|
||||
getDocMarkdown: async () => {
|
||||
docReaderCalled = true;
|
||||
return testCase.markdown;
|
||||
},
|
||||
} as unknown as DocReader;
|
||||
|
||||
const models = {
|
||||
workspace: {
|
||||
get: async () => testCase.workspace,
|
||||
},
|
||||
doc: {
|
||||
getAuthors: async () => testCase.authors,
|
||||
},
|
||||
} as unknown as Models;
|
||||
|
||||
const getDoc = buildDocContentGetter(ac, docReader, models);
|
||||
const tool = createDocReadTool(
|
||||
getDoc.bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const result = await tool.execute?.({ doc_id: 'doc-1' }, {});
|
||||
|
||||
t.is(docReaderCalled, testCase.docReaderCalled, testCase.name);
|
||||
t.deepEqual(result, testCase.expected, testCase.name);
|
||||
}
|
||||
});
|
||||
|
||||
test('document search tools should return sync error for local workspace', async t => {
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({
|
||||
can: async () => true,
|
||||
docs: async () => [],
|
||||
}),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
const models = {
|
||||
workspace: {
|
||||
get: async () => null,
|
||||
},
|
||||
} as unknown as Models;
|
||||
|
||||
let keywordSearchCalled = false;
|
||||
const indexerService = {
|
||||
searchDocsByKeyword: async () => {
|
||||
keywordSearchCalled = true;
|
||||
return [];
|
||||
},
|
||||
} as unknown as Parameters<typeof buildDocKeywordSearchGetter>[1];
|
||||
|
||||
let semanticSearchCalled = false;
|
||||
const contextService = {
|
||||
matchWorkspaceAll: async () => {
|
||||
semanticSearchCalled = true;
|
||||
return [];
|
||||
},
|
||||
} as unknown as Parameters<typeof buildDocSearchGetter>[1];
|
||||
|
||||
const keywordTool = createDocKeywordSearchTool(
|
||||
buildDocKeywordSearchGetter(ac, indexerService, models).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const semanticTool = createDocSemanticSearchTool(
|
||||
buildDocSearchGetter(ac, contextService, null, models).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const keywordResult = await keywordTool.execute?.({ query: 'hello' }, {});
|
||||
const semanticResult = await semanticTool.execute?.({ query: 'hello' }, {});
|
||||
|
||||
t.false(keywordSearchCalled);
|
||||
t.false(semanticSearchCalled);
|
||||
t.deepEqual(keywordResult, {
|
||||
type: 'error',
|
||||
name: 'Workspace Sync Required',
|
||||
message: LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE,
|
||||
});
|
||||
t.deepEqual(semanticResult, {
|
||||
type: 'error',
|
||||
name: 'Workspace Sync Required',
|
||||
message: LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE,
|
||||
});
|
||||
});
|
||||
|
||||
test('doc_semantic_search should return empty array when nothing matches', async t => {
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({
|
||||
can: async () => true,
|
||||
docs: async () => [],
|
||||
}),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
const models = {
|
||||
workspace: {
|
||||
get: async () => ({ id: 'workspace-1' }),
|
||||
},
|
||||
} as unknown as Models;
|
||||
|
||||
const contextService = {
|
||||
matchWorkspaceAll: async () => [],
|
||||
} as unknown as Parameters<typeof buildDocSearchGetter>[1];
|
||||
|
||||
const semanticTool = createDocSemanticSearchTool(
|
||||
buildDocSearchGetter(ac, contextService, null, models).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const result = await semanticTool.execute?.({ query: 'hello' }, {});
|
||||
|
||||
t.deepEqual(result, []);
|
||||
});
|
||||
|
||||
test('blob_read should return explicit error when attachment context is missing', async t => {
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({
|
||||
allowLocal: () => ({
|
||||
can: async () => true,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
const blobTool = createBlobReadTool(
|
||||
buildBlobContentGetter(ac, null).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
})
|
||||
);
|
||||
|
||||
const result = await blobTool.execute?.({ blob_id: 'blob-1' }, {});
|
||||
|
||||
t.deepEqual(result, {
|
||||
type: 'error',
|
||||
name: 'Blob Read Failed',
|
||||
message:
|
||||
'Missing workspace, user, blob id, or copilot context for blob_read.',
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
CitationFootnoteFormatter,
|
||||
CitationParser,
|
||||
StreamPatternParser,
|
||||
} from '../../plugins/copilot/providers/utils';
|
||||
import { CitationFootnoteFormatter } from '../../plugins/copilot/providers/utils';
|
||||
|
||||
test('CitationFootnoteFormatter should format sorted footnotes from citation events', t => {
|
||||
const formatter = new CitationFootnoteFormatter();
|
||||
@@ -50,67 +44,3 @@ test('CitationFootnoteFormatter should overwrite duplicated index with latest ur
|
||||
'[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fnew"}'
|
||||
);
|
||||
});
|
||||
|
||||
test('StreamPatternParser should keep state across chunks', t => {
|
||||
const parser = new StreamPatternParser(pattern => {
|
||||
if (pattern.kind === 'wrappedLink') {
|
||||
return `[^${pattern.url}]`;
|
||||
}
|
||||
if (pattern.kind === 'index') {
|
||||
return `[#${pattern.value}]`;
|
||||
}
|
||||
return `[${pattern.text}](${pattern.url})`;
|
||||
});
|
||||
|
||||
const first = parser.write('ref ([AFFiNE](https://affine.pro');
|
||||
const second = parser.write(')) and [2]');
|
||||
|
||||
t.is(first, 'ref ');
|
||||
t.is(second, '[^https://affine.pro] and [#2]');
|
||||
t.is(parser.end(), '');
|
||||
});
|
||||
|
||||
test('CitationParser should convert wrapped links to numbered footnotes', t => {
|
||||
const parser = new CitationParser();
|
||||
|
||||
const output = parser.parse('Use ([AFFiNE](https://affine.pro)) now');
|
||||
t.is(output, 'Use [^1] now');
|
||||
t.regex(
|
||||
parser.end(),
|
||||
/\[\^1\]: \{"type":"url","url":"https%3A%2F%2Faffine.pro"\}/
|
||||
);
|
||||
});
|
||||
|
||||
test('chatToGPTMessage should not mutate input and should keep system schema', async t => {
|
||||
const schema = z.object({
|
||||
query: z.string(),
|
||||
});
|
||||
const messages = [
|
||||
{
|
||||
role: 'system' as const,
|
||||
content: 'You are helper',
|
||||
params: { schema },
|
||||
},
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: '',
|
||||
attachments: ['https://example.com/a.png'],
|
||||
},
|
||||
];
|
||||
const firstRef = messages[0];
|
||||
const secondRef = messages[1];
|
||||
const [system, normalized, parsedSchema] = await chatToGPTMessage(
|
||||
messages,
|
||||
false
|
||||
);
|
||||
|
||||
t.is(system, 'You are helper');
|
||||
t.is(parsedSchema, schema);
|
||||
t.is(messages.length, 2);
|
||||
t.is(messages[0], firstRef);
|
||||
t.is(messages[1], secondRef);
|
||||
t.deepEqual(normalized[0], {
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: '[no content]' }],
|
||||
});
|
||||
});
|
||||
|
||||
@@ -33,7 +33,7 @@ export class MockCopilotProvider extends OpenAIProvider {
|
||||
id: 'test-image',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text],
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Image],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
|
||||
@@ -6,13 +6,16 @@ import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { AppModule } from '../../app.module';
|
||||
import { ConfigFactory, URLHelper } from '../../base';
|
||||
import { ConfigFactory, InvalidOauthResponse, URLHelper } from '../../base';
|
||||
import { ConfigModule } from '../../base/config';
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { AuthService } from '../../core/auth/service';
|
||||
import { ServerFeature } from '../../core/config/types';
|
||||
import { Models } from '../../models';
|
||||
import { OAuthProviderName } from '../../plugins/oauth/config';
|
||||
import { OAuthProviderFactory } from '../../plugins/oauth/factory';
|
||||
import { GoogleOAuthProvider } from '../../plugins/oauth/providers/google';
|
||||
import { OIDCProvider } from '../../plugins/oauth/providers/oidc';
|
||||
import { OAuthService } from '../../plugins/oauth/service';
|
||||
import { createTestingApp, currentUser, TestingApp } from '../utils';
|
||||
|
||||
@@ -35,6 +38,12 @@ test.before(async t => {
|
||||
clientId: 'google-client-id',
|
||||
clientSecret: 'google-client-secret',
|
||||
},
|
||||
oidc: {
|
||||
clientId: '',
|
||||
clientSecret: '',
|
||||
issuer: '',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
},
|
||||
server: {
|
||||
@@ -432,6 +441,87 @@ function mockOAuthProvider(
|
||||
return clientNonce;
|
||||
}
|
||||
|
||||
function mockOidcProvider(
|
||||
provider: OIDCProvider,
|
||||
{
|
||||
args = {},
|
||||
idTokenClaims,
|
||||
userinfo,
|
||||
}: {
|
||||
args?: Record<string, string>;
|
||||
idTokenClaims: Record<string, unknown>;
|
||||
userinfo: Record<string, unknown>;
|
||||
}
|
||||
) {
|
||||
Sinon.stub(provider, 'config').get(() => ({
|
||||
clientId: '',
|
||||
clientSecret: '',
|
||||
issuer: '',
|
||||
args,
|
||||
}));
|
||||
Sinon.stub(
|
||||
provider as unknown as { endpoints: { userinfo_endpoint: string } },
|
||||
'endpoints'
|
||||
).get(() => ({
|
||||
userinfo_endpoint: 'https://oidc.affine.dev/userinfo',
|
||||
}));
|
||||
Sinon.stub(
|
||||
provider as unknown as { verifyIdToken: () => unknown },
|
||||
'verifyIdToken'
|
||||
).resolves(idTokenClaims);
|
||||
Sinon.stub(
|
||||
provider as unknown as { fetchJson: () => unknown },
|
||||
'fetchJson'
|
||||
).resolves(userinfo);
|
||||
}
|
||||
|
||||
function createOidcRegistrationHarness(config?: {
|
||||
clientId?: string;
|
||||
clientSecret?: string;
|
||||
issuer?: string;
|
||||
}) {
|
||||
const server = {
|
||||
enableFeature: Sinon.spy(),
|
||||
disableFeature: Sinon.spy(),
|
||||
};
|
||||
const factory = new OAuthProviderFactory(server as any);
|
||||
const affineConfig = {
|
||||
server: {
|
||||
externalUrl: 'https://affine.example',
|
||||
host: 'localhost',
|
||||
path: '',
|
||||
https: true,
|
||||
hosts: [],
|
||||
},
|
||||
oauth: {
|
||||
providers: {
|
||||
oidc: {
|
||||
clientId: config?.clientId ?? 'oidc-client-id',
|
||||
clientSecret: config?.clientSecret ?? 'oidc-client-secret',
|
||||
issuer: config?.issuer ?? 'https://issuer.affine.dev',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
const provider = new OIDCProvider(new URLHelper(affineConfig as any));
|
||||
|
||||
(provider as any).factory = factory;
|
||||
(provider as any).AFFiNEConfig = affineConfig;
|
||||
|
||||
return {
|
||||
provider,
|
||||
factory,
|
||||
server,
|
||||
};
|
||||
}
|
||||
|
||||
async function flushAsyncWork(iterations = 5) {
|
||||
for (let i = 0; i < iterations; i++) {
|
||||
await new Promise(resolve => setImmediate(resolve));
|
||||
}
|
||||
}
|
||||
|
||||
test('should be able to sign up with oauth', async t => {
|
||||
const { app, db } = t.context;
|
||||
|
||||
@@ -554,3 +644,209 @@ test('should be able to fullfil user with oauth sign in', async t => {
|
||||
t.truthy(account);
|
||||
t.is(account!.user.id, u3.id);
|
||||
});
|
||||
|
||||
test('oidc should accept email from id token when userinfo email is missing', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
email: 'oidc-id-token@affine.pro',
|
||||
name: 'OIDC User',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
name: 'OIDC User',
|
||||
},
|
||||
});
|
||||
|
||||
const user = await provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
);
|
||||
|
||||
t.is(user.id, 'oidc-user');
|
||||
t.is(user.email, 'oidc-id-token@affine.pro');
|
||||
t.is(user.name, 'OIDC User');
|
||||
});
|
||||
|
||||
test('oidc should resolve custom email claim from userinfo', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
args: { claim_email: 'mail', claim_name: 'display_name' },
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
mail: 'oidc-userinfo@affine.pro',
|
||||
display_name: 'OIDC Custom',
|
||||
},
|
||||
});
|
||||
|
||||
const user = await provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
);
|
||||
|
||||
t.is(user.id, 'oidc-user');
|
||||
t.is(user.email, 'oidc-userinfo@affine.pro');
|
||||
t.is(user.name, 'OIDC Custom');
|
||||
});
|
||||
|
||||
test('oidc should resolve custom email claim from id token', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
args: { claim_email: 'mail', claim_email_verified: 'mail_verified' },
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
mail: 'oidc-custom-id-token@affine.pro',
|
||||
mail_verified: 'true',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
},
|
||||
});
|
||||
|
||||
const user = await provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
);
|
||||
|
||||
t.is(user.id, 'oidc-user');
|
||||
t.is(user.email, 'oidc-custom-id-token@affine.pro');
|
||||
});
|
||||
|
||||
test('oidc should reject responses without a usable email claim', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
args: { claim_email: 'mail' },
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
mail: 'not-an-email',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
mail: 'still-not-an-email',
|
||||
},
|
||||
});
|
||||
|
||||
const error = await t.throwsAsync(
|
||||
provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
)
|
||||
);
|
||||
|
||||
t.true(error instanceof InvalidOauthResponse);
|
||||
t.true(
|
||||
error.message.includes(
|
||||
'Missing valid email claim in OIDC response. Tried userinfo and ID token claims: "mail"'
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('oidc should not fall back to default email claim when custom claim is configured', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const provider = app.get(OIDCProvider);
|
||||
mockOidcProvider(provider, {
|
||||
args: { claim_email: 'mail' },
|
||||
idTokenClaims: {
|
||||
sub: 'oidc-user',
|
||||
email: 'fallback@affine.pro',
|
||||
},
|
||||
userinfo: {
|
||||
sub: 'oidc-user',
|
||||
email: 'userinfo-fallback@affine.pro',
|
||||
},
|
||||
});
|
||||
|
||||
const error = await t.throwsAsync(
|
||||
provider.getUser(
|
||||
{ accessToken: 'token', idToken: 'id-token' },
|
||||
{ token: 'nonce', provider: OAuthProviderName.OIDC }
|
||||
)
|
||||
);
|
||||
|
||||
t.true(error instanceof InvalidOauthResponse);
|
||||
t.true(
|
||||
error.message.includes(
|
||||
'Missing valid email claim in OIDC response. Tried userinfo and ID token claims: "mail"'
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('oidc discovery should remove oauth feature on failure and restore it after backoff retry succeeds', async t => {
|
||||
const { provider, factory, server } = createOidcRegistrationHarness();
|
||||
const fetchStub = Sinon.stub(globalThis, 'fetch');
|
||||
const scheduledRetries: Array<() => void> = [];
|
||||
const retryDelays: number[] = [];
|
||||
const setTimeoutStub = Sinon.stub(globalThis, 'setTimeout').callsFake(((
|
||||
callback: Parameters<typeof setTimeout>[0],
|
||||
delay?: number
|
||||
) => {
|
||||
retryDelays.push(Number(delay));
|
||||
scheduledRetries.push(callback as () => void);
|
||||
return Symbol('timeout') as unknown as ReturnType<typeof setTimeout>;
|
||||
}) as typeof setTimeout);
|
||||
t.teardown(() => {
|
||||
provider.onModuleDestroy();
|
||||
fetchStub.restore();
|
||||
setTimeoutStub.restore();
|
||||
});
|
||||
|
||||
fetchStub
|
||||
.onFirstCall()
|
||||
.rejects(new Error('temporary discovery failure'))
|
||||
.onSecondCall()
|
||||
.rejects(new Error('temporary discovery failure'))
|
||||
.onThirdCall()
|
||||
.resolves(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
authorization_endpoint: 'https://issuer.affine.dev/auth',
|
||||
token_endpoint: 'https://issuer.affine.dev/token',
|
||||
userinfo_endpoint: 'https://issuer.affine.dev/userinfo',
|
||||
issuer: 'https://issuer.affine.dev',
|
||||
jwks_uri: 'https://issuer.affine.dev/jwks',
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
(provider as any).setup();
|
||||
|
||||
await flushAsyncWork();
|
||||
t.deepEqual(factory.providers, []);
|
||||
t.true(server.disableFeature.calledWith(ServerFeature.OAuth));
|
||||
t.is(fetchStub.callCount, 1);
|
||||
t.deepEqual(retryDelays, [1000]);
|
||||
|
||||
const firstRetry = scheduledRetries.shift();
|
||||
t.truthy(firstRetry);
|
||||
firstRetry!();
|
||||
await flushAsyncWork();
|
||||
t.is(fetchStub.callCount, 2);
|
||||
t.deepEqual(factory.providers, []);
|
||||
t.deepEqual(retryDelays, [1000, 2000]);
|
||||
|
||||
const secondRetry = scheduledRetries.shift();
|
||||
t.truthy(secondRetry);
|
||||
secondRetry!();
|
||||
await flushAsyncWork();
|
||||
t.is(fetchStub.callCount, 3);
|
||||
t.deepEqual(factory.providers, [OAuthProviderName.OIDC]);
|
||||
t.true(server.enableFeature.calledWith(ServerFeature.OAuth));
|
||||
t.is(scheduledRetries.length, 0);
|
||||
});
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import test from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import {
|
||||
exponentialBackoffDelay,
|
||||
ExponentialBackoffScheduler,
|
||||
} from '../promise';
|
||||
|
||||
test('exponentialBackoffDelay should cap exponential growth at maxDelayMs', t => {
|
||||
t.is(exponentialBackoffDelay(0, { baseDelayMs: 100, maxDelayMs: 500 }), 100);
|
||||
t.is(exponentialBackoffDelay(1, { baseDelayMs: 100, maxDelayMs: 500 }), 200);
|
||||
t.is(exponentialBackoffDelay(3, { baseDelayMs: 100, maxDelayMs: 500 }), 500);
|
||||
});
|
||||
|
||||
test('ExponentialBackoffScheduler should track pending callback and increase delay per attempt', async t => {
|
||||
const clock = Sinon.useFakeTimers();
|
||||
t.teardown(() => {
|
||||
clock.restore();
|
||||
});
|
||||
|
||||
const calls: number[] = [];
|
||||
const scheduler = new ExponentialBackoffScheduler({
|
||||
baseDelayMs: 100,
|
||||
maxDelayMs: 500,
|
||||
});
|
||||
|
||||
t.is(
|
||||
scheduler.schedule(() => {
|
||||
calls.push(1);
|
||||
}),
|
||||
100
|
||||
);
|
||||
t.true(scheduler.pending);
|
||||
t.is(
|
||||
scheduler.schedule(() => {
|
||||
calls.push(2);
|
||||
}),
|
||||
null
|
||||
);
|
||||
|
||||
await clock.tickAsync(100);
|
||||
t.deepEqual(calls, [1]);
|
||||
t.false(scheduler.pending);
|
||||
|
||||
t.is(
|
||||
scheduler.schedule(() => {
|
||||
calls.push(3);
|
||||
}),
|
||||
200
|
||||
);
|
||||
await clock.tickAsync(200);
|
||||
t.deepEqual(calls, [1, 3]);
|
||||
});
|
||||
|
||||
test('ExponentialBackoffScheduler reset should clear pending work and restart from the base delay', t => {
|
||||
const scheduler = new ExponentialBackoffScheduler({
|
||||
baseDelayMs: 100,
|
||||
maxDelayMs: 500,
|
||||
});
|
||||
|
||||
t.is(
|
||||
scheduler.schedule(() => {}),
|
||||
100
|
||||
);
|
||||
t.true(scheduler.pending);
|
||||
|
||||
scheduler.reset();
|
||||
t.false(scheduler.pending);
|
||||
t.is(
|
||||
scheduler.schedule(() => {}),
|
||||
100
|
||||
);
|
||||
|
||||
scheduler.clear();
|
||||
});
|
||||
@@ -1,4 +1,4 @@
|
||||
import { setTimeout } from 'node:timers/promises';
|
||||
import { setTimeout as delay } from 'node:timers/promises';
|
||||
|
||||
import { defer as rxjsDefer, retry } from 'rxjs';
|
||||
|
||||
@@ -52,5 +52,61 @@ export function defer(dispose: () => Promise<void>) {
|
||||
}
|
||||
|
||||
export function sleep(ms: number): Promise<void> {
|
||||
return setTimeout(ms);
|
||||
return delay(ms);
|
||||
}
|
||||
|
||||
export function exponentialBackoffDelay(
|
||||
attempt: number,
|
||||
{
|
||||
baseDelayMs,
|
||||
maxDelayMs,
|
||||
factor = 2,
|
||||
}: { baseDelayMs: number; maxDelayMs: number; factor?: number }
|
||||
): number {
|
||||
return Math.min(
|
||||
baseDelayMs * Math.pow(factor, Math.max(0, attempt)),
|
||||
maxDelayMs
|
||||
);
|
||||
}
|
||||
|
||||
export class ExponentialBackoffScheduler {
|
||||
#attempt = 0;
|
||||
#timer: ReturnType<typeof globalThis.setTimeout> | null = null;
|
||||
|
||||
constructor(
|
||||
private readonly options: {
|
||||
baseDelayMs: number;
|
||||
maxDelayMs: number;
|
||||
factor?: number;
|
||||
}
|
||||
) {}
|
||||
|
||||
get pending() {
|
||||
return this.#timer !== null;
|
||||
}
|
||||
|
||||
clear() {
|
||||
if (this.#timer) {
|
||||
clearTimeout(this.#timer);
|
||||
this.#timer = null;
|
||||
}
|
||||
}
|
||||
|
||||
reset() {
|
||||
this.#attempt = 0;
|
||||
this.clear();
|
||||
}
|
||||
|
||||
schedule(callback: () => void) {
|
||||
if (this.#timer) return null;
|
||||
|
||||
const timeout = exponentialBackoffDelay(this.#attempt, this.options);
|
||||
this.#timer = globalThis.setTimeout(() => {
|
||||
this.#timer = null;
|
||||
callback();
|
||||
}, timeout);
|
||||
this.#attempt += 1;
|
||||
|
||||
return timeout;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
CopilotSessionNotFound,
|
||||
} from '../base';
|
||||
import { getTokenEncoder } from '../native';
|
||||
import type { PromptAttachment } from '../plugins/copilot/providers/types';
|
||||
import { BaseModel } from './base';
|
||||
|
||||
export enum SessionType {
|
||||
@@ -24,7 +25,7 @@ type ChatPrompt = {
|
||||
model: string;
|
||||
};
|
||||
|
||||
type ChatAttachment = { attachment: string; mimeType: string } | string;
|
||||
type ChatAttachment = PromptAttachment;
|
||||
|
||||
type ChatStreamObject = {
|
||||
type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result';
|
||||
@@ -173,22 +174,105 @@ export class CopilotSessionModel extends BaseModel {
|
||||
}
|
||||
|
||||
return attachments
|
||||
.map(attachment =>
|
||||
typeof attachment === 'string'
|
||||
? (this.sanitizeString(attachment) ?? '')
|
||||
: {
|
||||
attachment:
|
||||
this.sanitizeString(attachment.attachment) ??
|
||||
attachment.attachment,
|
||||
.map(attachment => {
|
||||
if (typeof attachment === 'string') {
|
||||
return this.sanitizeString(attachment) ?? '';
|
||||
}
|
||||
|
||||
if ('attachment' in attachment) {
|
||||
return {
|
||||
attachment:
|
||||
this.sanitizeString(attachment.attachment) ??
|
||||
attachment.attachment,
|
||||
mimeType:
|
||||
this.sanitizeString(attachment.mimeType) ?? attachment.mimeType,
|
||||
};
|
||||
}
|
||||
|
||||
switch (attachment.kind) {
|
||||
case 'url':
|
||||
return {
|
||||
...attachment,
|
||||
url: this.sanitizeString(attachment.url) ?? attachment.url,
|
||||
mimeType:
|
||||
this.sanitizeString(attachment.mimeType) ?? attachment.mimeType,
|
||||
}
|
||||
)
|
||||
fileName:
|
||||
this.sanitizeString(attachment.fileName) ?? attachment.fileName,
|
||||
providerHint: attachment.providerHint
|
||||
? {
|
||||
provider:
|
||||
this.sanitizeString(attachment.providerHint.provider) ??
|
||||
attachment.providerHint.provider,
|
||||
kind:
|
||||
this.sanitizeString(attachment.providerHint.kind) ??
|
||||
attachment.providerHint.kind,
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
case 'data':
|
||||
case 'bytes':
|
||||
return {
|
||||
...attachment,
|
||||
data: this.sanitizeString(attachment.data) ?? attachment.data,
|
||||
mimeType:
|
||||
this.sanitizeString(attachment.mimeType) ?? attachment.mimeType,
|
||||
fileName:
|
||||
this.sanitizeString(attachment.fileName) ?? attachment.fileName,
|
||||
providerHint: attachment.providerHint
|
||||
? {
|
||||
provider:
|
||||
this.sanitizeString(attachment.providerHint.provider) ??
|
||||
attachment.providerHint.provider,
|
||||
kind:
|
||||
this.sanitizeString(attachment.providerHint.kind) ??
|
||||
attachment.providerHint.kind,
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
case 'file_handle':
|
||||
return {
|
||||
...attachment,
|
||||
fileHandle:
|
||||
this.sanitizeString(attachment.fileHandle) ??
|
||||
attachment.fileHandle,
|
||||
mimeType:
|
||||
this.sanitizeString(attachment.mimeType) ?? attachment.mimeType,
|
||||
fileName:
|
||||
this.sanitizeString(attachment.fileName) ?? attachment.fileName,
|
||||
providerHint: attachment.providerHint
|
||||
? {
|
||||
provider:
|
||||
this.sanitizeString(attachment.providerHint.provider) ??
|
||||
attachment.providerHint.provider,
|
||||
kind:
|
||||
this.sanitizeString(attachment.providerHint.kind) ??
|
||||
attachment.providerHint.kind,
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return attachment;
|
||||
})
|
||||
.filter(attachment => {
|
||||
if (typeof attachment === 'string') {
|
||||
return !!attachment;
|
||||
}
|
||||
return !!attachment.attachment && !!attachment.mimeType;
|
||||
if ('attachment' in attachment) {
|
||||
return !!attachment.attachment && !!attachment.mimeType;
|
||||
}
|
||||
|
||||
switch (attachment.kind) {
|
||||
case 'url':
|
||||
return !!attachment.url;
|
||||
case 'data':
|
||||
case 'bytes':
|
||||
return !!attachment.data && !!attachment.mimeType;
|
||||
case 'file_handle':
|
||||
return !!attachment.fileHandle;
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -65,6 +65,21 @@ type NativeLlmModule = {
|
||||
backendConfigJson: string,
|
||||
requestJson: string
|
||||
) => string | Promise<string>;
|
||||
llmStructuredDispatch?: (
|
||||
protocol: string,
|
||||
backendConfigJson: string,
|
||||
requestJson: string
|
||||
) => string | Promise<string>;
|
||||
llmEmbeddingDispatch?: (
|
||||
protocol: string,
|
||||
backendConfigJson: string,
|
||||
requestJson: string
|
||||
) => string | Promise<string>;
|
||||
llmRerankDispatch?: (
|
||||
protocol: string,
|
||||
backendConfigJson: string,
|
||||
requestJson: string
|
||||
) => string | Promise<string>;
|
||||
llmDispatchStream?: (
|
||||
protocol: string,
|
||||
backendConfigJson: string,
|
||||
@@ -79,12 +94,20 @@ const nativeLlmModule = serverNativeModule as typeof serverNativeModule &
|
||||
export type NativeLlmProtocol =
|
||||
| 'openai_chat'
|
||||
| 'openai_responses'
|
||||
| 'anthropic';
|
||||
| 'anthropic'
|
||||
| 'gemini';
|
||||
|
||||
export type NativeLlmBackendConfig = {
|
||||
base_url: string;
|
||||
auth_token: string;
|
||||
request_layer?: 'anthropic' | 'chat_completions' | 'responses' | 'vertex';
|
||||
request_layer?:
|
||||
| 'anthropic'
|
||||
| 'chat_completions'
|
||||
| 'responses'
|
||||
| 'vertex'
|
||||
| 'vertex_anthropic'
|
||||
| 'gemini_api'
|
||||
| 'gemini_vertex';
|
||||
headers?: Record<string, string>;
|
||||
no_streaming?: boolean;
|
||||
timeout_ms?: number;
|
||||
@@ -100,6 +123,8 @@ export type NativeLlmCoreContent =
|
||||
call_id: string;
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
arguments_text?: string;
|
||||
arguments_error?: string;
|
||||
thought?: string;
|
||||
}
|
||||
| {
|
||||
@@ -109,8 +134,12 @@ export type NativeLlmCoreContent =
|
||||
is_error?: boolean;
|
||||
name?: string;
|
||||
arguments?: Record<string, unknown>;
|
||||
arguments_text?: string;
|
||||
arguments_error?: string;
|
||||
}
|
||||
| { type: 'image'; source: Record<string, unknown> | string };
|
||||
| { type: 'image'; source: Record<string, unknown> | string }
|
||||
| { type: 'audio'; source: Record<string, unknown> | string }
|
||||
| { type: 'file'; source: Record<string, unknown> | string };
|
||||
|
||||
export type NativeLlmCoreMessage = {
|
||||
role: NativeLlmCoreRole;
|
||||
@@ -133,22 +162,54 @@ export type NativeLlmRequest = {
|
||||
tool_choice?: 'auto' | 'none' | 'required' | { name: string };
|
||||
include?: string[];
|
||||
reasoning?: Record<string, unknown>;
|
||||
response_schema?: Record<string, unknown>;
|
||||
middleware?: {
|
||||
request?: Array<
|
||||
'normalize_messages' | 'clamp_max_tokens' | 'tool_schema_rewrite'
|
||||
>;
|
||||
stream?: Array<'stream_event_normalize' | 'citation_indexing'>;
|
||||
config?: {
|
||||
no_additional_properties?: boolean;
|
||||
drop_property_format?: boolean;
|
||||
drop_property_min_length?: boolean;
|
||||
drop_array_min_items?: boolean;
|
||||
drop_array_max_items?: boolean;
|
||||
additional_properties_policy?: 'preserve' | 'forbid';
|
||||
property_format_policy?: 'preserve' | 'drop';
|
||||
property_min_length_policy?: 'preserve' | 'drop';
|
||||
array_min_items_policy?: 'preserve' | 'drop';
|
||||
array_max_items_policy?: 'preserve' | 'drop';
|
||||
max_tokens_cap?: number;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export type NativeLlmStructuredRequest = {
|
||||
model: string;
|
||||
messages: NativeLlmCoreMessage[];
|
||||
schema: Record<string, unknown>;
|
||||
max_tokens?: number;
|
||||
temperature?: number;
|
||||
reasoning?: Record<string, unknown>;
|
||||
strict?: boolean;
|
||||
response_mime_type?: string;
|
||||
middleware?: NativeLlmRequest['middleware'];
|
||||
};
|
||||
|
||||
export type NativeLlmEmbeddingRequest = {
|
||||
model: string;
|
||||
inputs: string[];
|
||||
dimensions?: number;
|
||||
task_type?: string;
|
||||
};
|
||||
|
||||
export type NativeLlmRerankCandidate = {
|
||||
id?: string;
|
||||
text: string;
|
||||
};
|
||||
|
||||
export type NativeLlmRerankRequest = {
|
||||
model: string;
|
||||
query: string;
|
||||
candidates: NativeLlmRerankCandidate[];
|
||||
top_n?: number;
|
||||
};
|
||||
|
||||
export type NativeLlmDispatchResponse = {
|
||||
id: string;
|
||||
model: string;
|
||||
@@ -159,10 +220,39 @@ export type NativeLlmDispatchResponse = {
|
||||
total_tokens: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
finish_reason: string;
|
||||
finish_reason:
|
||||
| 'stop'
|
||||
| 'length'
|
||||
| 'tool_calls'
|
||||
| 'content_filter'
|
||||
| 'error'
|
||||
| string;
|
||||
reasoning_details?: unknown;
|
||||
};
|
||||
|
||||
export type NativeLlmStructuredResponse = {
|
||||
id: string;
|
||||
model: string;
|
||||
output_text: string;
|
||||
usage: NativeLlmDispatchResponse['usage'];
|
||||
finish_reason: NativeLlmDispatchResponse['finish_reason'];
|
||||
reasoning_details?: unknown;
|
||||
};
|
||||
|
||||
export type NativeLlmEmbeddingResponse = {
|
||||
model: string;
|
||||
embeddings: number[][];
|
||||
usage?: {
|
||||
prompt_tokens: number;
|
||||
total_tokens: number;
|
||||
};
|
||||
};
|
||||
|
||||
export type NativeLlmRerankResponse = {
|
||||
model: string;
|
||||
scores: number[];
|
||||
};
|
||||
|
||||
export type NativeLlmStreamEvent =
|
||||
| { type: 'message_start'; id?: string; model?: string }
|
||||
| { type: 'text_delta'; text: string }
|
||||
@@ -178,6 +268,8 @@ export type NativeLlmStreamEvent =
|
||||
call_id: string;
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
arguments_text?: string;
|
||||
arguments_error?: string;
|
||||
thought?: string;
|
||||
}
|
||||
| {
|
||||
@@ -187,6 +279,8 @@ export type NativeLlmStreamEvent =
|
||||
is_error?: boolean;
|
||||
name?: string;
|
||||
arguments?: Record<string, unknown>;
|
||||
arguments_text?: string;
|
||||
arguments_error?: string;
|
||||
}
|
||||
| { type: 'citation'; index: number; url: string }
|
||||
| {
|
||||
@@ -200,7 +294,7 @@ export type NativeLlmStreamEvent =
|
||||
}
|
||||
| {
|
||||
type: 'done';
|
||||
finish_reason?: string;
|
||||
finish_reason?: NativeLlmDispatchResponse['finish_reason'];
|
||||
usage?: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
@@ -228,6 +322,57 @@ export async function llmDispatch(
|
||||
return JSON.parse(responseText) as NativeLlmDispatchResponse;
|
||||
}
|
||||
|
||||
export async function llmStructuredDispatch(
|
||||
protocol: NativeLlmProtocol,
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
request: NativeLlmStructuredRequest
|
||||
): Promise<NativeLlmStructuredResponse> {
|
||||
if (!nativeLlmModule.llmStructuredDispatch) {
|
||||
throw new Error('native llm structured dispatch is not available');
|
||||
}
|
||||
const response = nativeLlmModule.llmStructuredDispatch(
|
||||
protocol,
|
||||
JSON.stringify(backendConfig),
|
||||
JSON.stringify(request)
|
||||
);
|
||||
const responseText = await Promise.resolve(response);
|
||||
return JSON.parse(responseText) as NativeLlmStructuredResponse;
|
||||
}
|
||||
|
||||
export async function llmEmbeddingDispatch(
|
||||
protocol: NativeLlmProtocol,
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
request: NativeLlmEmbeddingRequest
|
||||
): Promise<NativeLlmEmbeddingResponse> {
|
||||
if (!nativeLlmModule.llmEmbeddingDispatch) {
|
||||
throw new Error('native llm embedding dispatch is not available');
|
||||
}
|
||||
const response = nativeLlmModule.llmEmbeddingDispatch(
|
||||
protocol,
|
||||
JSON.stringify(backendConfig),
|
||||
JSON.stringify(request)
|
||||
);
|
||||
const responseText = await Promise.resolve(response);
|
||||
return JSON.parse(responseText) as NativeLlmEmbeddingResponse;
|
||||
}
|
||||
|
||||
export async function llmRerankDispatch(
|
||||
protocol: NativeLlmProtocol,
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
request: NativeLlmRerankRequest
|
||||
): Promise<NativeLlmRerankResponse> {
|
||||
if (!nativeLlmModule.llmRerankDispatch) {
|
||||
throw new Error('native llm rerank dispatch is not available');
|
||||
}
|
||||
const response = nativeLlmModule.llmRerankDispatch(
|
||||
protocol,
|
||||
JSON.stringify(backendConfig),
|
||||
JSON.stringify(request)
|
||||
);
|
||||
const responseText = await Promise.resolve(response);
|
||||
return JSON.parse(responseText) as NativeLlmRerankResponse;
|
||||
}
|
||||
|
||||
export class NativeStreamAdapter<T> implements AsyncIterableIterator<T> {
|
||||
readonly #queue: T[] = [];
|
||||
readonly #waiters: ((result: IteratorResult<T>) => void)[] = [];
|
||||
|
||||
@@ -81,7 +81,7 @@ export type CopilotProviderProfile = CopilotProviderProfileCommon &
|
||||
}[CopilotProviderType];
|
||||
|
||||
export type CopilotProviderDefaults = Partial<
|
||||
Record<ModelOutputType, string>
|
||||
Record<Exclude<ModelOutputType, ModelOutputType.Rerank>, string>
|
||||
> & {
|
||||
fallback?: string;
|
||||
};
|
||||
@@ -184,6 +184,7 @@ const CopilotProviderDefaultsShape = z.object({
|
||||
[ModelOutputType.Object]: z.string().optional(),
|
||||
[ModelOutputType.Embedding]: z.string().optional(),
|
||||
[ModelOutputType.Image]: z.string().optional(),
|
||||
[ModelOutputType.Rerank]: z.string().optional(),
|
||||
[ModelOutputType.Structured]: z.string().optional(),
|
||||
fallback: z.string().optional(),
|
||||
});
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import type { ModuleRef } from '@nestjs/core';
|
||||
|
||||
import {
|
||||
Config,
|
||||
CopilotPromptNotFound,
|
||||
CopilotProviderNotSupported,
|
||||
} from '../../../base';
|
||||
import { Config, CopilotProviderNotSupported } from '../../../base';
|
||||
import { CopilotFailedToGenerateEmbedding } from '../../../base/error/errors.gen';
|
||||
import {
|
||||
ChunkSimilarity,
|
||||
Embedding,
|
||||
EMBEDDING_DIMENSIONS,
|
||||
} from '../../../models';
|
||||
import { PromptService } from '../prompt/service';
|
||||
import { CopilotProviderFactory } from '../providers/factory';
|
||||
import type { CopilotProvider } from '../providers/provider';
|
||||
import {
|
||||
DEFAULT_RERANK_MODEL,
|
||||
normalizeRerankModel,
|
||||
} from '../providers/rerank';
|
||||
import {
|
||||
type CopilotRerankRequest,
|
||||
type ModelFullConditions,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
@@ -27,24 +19,20 @@ import {
|
||||
import { EmbeddingClient, type ReRankResult } from './types';
|
||||
|
||||
const EMBEDDING_MODEL = 'gemini-embedding-001';
|
||||
const RERANK_PROMPT = 'Rerank results';
|
||||
|
||||
const RERANK_MODEL = 'gpt-5.2';
|
||||
class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
private readonly logger = new Logger(ProductionEmbeddingClient.name);
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly providerFactory: CopilotProviderFactory,
|
||||
private readonly prompt: PromptService
|
||||
private readonly providerFactory: CopilotProviderFactory
|
||||
) {
|
||||
super();
|
||||
}
|
||||
|
||||
override async configured(): Promise<boolean> {
|
||||
const embedding = await this.providerFactory.getProvider({
|
||||
modelId: this.config.copilot?.scenarios?.override_enabled
|
||||
? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL
|
||||
: EMBEDDING_MODEL,
|
||||
modelId: this.getEmbeddingModelId(),
|
||||
outputType: ModelOutputType.Embedding,
|
||||
});
|
||||
const result = Boolean(embedding);
|
||||
@@ -69,9 +57,15 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
return provider;
|
||||
}
|
||||
|
||||
private getEmbeddingModelId() {
|
||||
return this.config.copilot?.scenarios?.override_enabled
|
||||
? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL
|
||||
: EMBEDDING_MODEL;
|
||||
}
|
||||
|
||||
async getEmbeddings(input: string[]): Promise<Embedding[]> {
|
||||
const provider = await this.getProvider({
|
||||
modelId: EMBEDDING_MODEL,
|
||||
modelId: this.getEmbeddingModelId(),
|
||||
outputType: ModelOutputType.Embedding,
|
||||
});
|
||||
this.logger.verbose(
|
||||
@@ -114,21 +108,22 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
): Promise<ReRankResult> {
|
||||
if (!embeddings.length) return [];
|
||||
|
||||
const prompt = await this.prompt.get(RERANK_PROMPT);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: RERANK_PROMPT });
|
||||
}
|
||||
const rerankModel = normalizeRerankModel(prompt.model);
|
||||
if (prompt.model !== rerankModel) {
|
||||
this.logger.warn(
|
||||
`Unsupported rerank model "${prompt.model}" configured, falling back to "${DEFAULT_RERANK_MODEL}".`
|
||||
);
|
||||
}
|
||||
const provider = await this.getProvider({ modelId: rerankModel });
|
||||
const provider = await this.getProvider({
|
||||
modelId: RERANK_MODEL,
|
||||
outputType: ModelOutputType.Rerank,
|
||||
});
|
||||
|
||||
const rerankRequest: CopilotRerankRequest = {
|
||||
query,
|
||||
candidates: embeddings.map((embedding, index) => ({
|
||||
id: String(index),
|
||||
text: embedding.content,
|
||||
})),
|
||||
};
|
||||
|
||||
const ranks = await provider.rerank(
|
||||
{ modelId: rerankModel },
|
||||
embeddings.map(e => prompt.finish({ query, doc: e.content })),
|
||||
{ modelId: RERANK_MODEL },
|
||||
rerankRequest,
|
||||
{ signal }
|
||||
);
|
||||
|
||||
@@ -227,9 +222,7 @@ export async function getEmbeddingClient(
|
||||
const providerFactory = moduleRef.get(CopilotProviderFactory, {
|
||||
strict: false,
|
||||
});
|
||||
const prompt = moduleRef.get(PromptService, { strict: false });
|
||||
|
||||
const client = new ProductionEmbeddingClient(config, providerFactory, prompt);
|
||||
const client = new ProductionEmbeddingClient(config, providerFactory);
|
||||
if (await client.configured()) {
|
||||
EMBEDDING_CLIENT = client;
|
||||
}
|
||||
|
||||
@@ -418,21 +418,6 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr
|
||||
maxRetries: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'Rerank results',
|
||||
action: 'Rerank results',
|
||||
model: 'gpt-5.2',
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: `Judge whether the Document meets the requirements based on the Query and the Instruct provided. The answer must be "yes" or "no".`,
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: `<Instruct>: Given a document search result, determine whether the result is relevant to the query.\n<Query>: {{query}}\n<Document>: {{doc}}`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'Generate a caption',
|
||||
action: 'Generate a caption',
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
|
||||
import {
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
@@ -11,6 +9,7 @@ import {
|
||||
type NativeLlmRequest,
|
||||
} from '../../../../native';
|
||||
import type { NodeTextMiddleware } from '../../config';
|
||||
import type { CopilotToolSet } from '../../tools';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from '../native';
|
||||
import { CopilotProvider } from '../provider';
|
||||
import type {
|
||||
@@ -20,7 +19,11 @@ import type {
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { CopilotProviderType, ModelOutputType } from '../types';
|
||||
import { getGoogleAuth, getVertexAnthropicBaseUrl } from '../utils';
|
||||
import {
|
||||
getGoogleAuth,
|
||||
getVertexAnthropicBaseUrl,
|
||||
type VertexAnthropicProviderConfig,
|
||||
} from '../utils';
|
||||
|
||||
export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
private handleError(e: any) {
|
||||
@@ -36,22 +39,16 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
|
||||
private async createNativeConfig(): Promise<NativeLlmBackendConfig> {
|
||||
if (this.type === CopilotProviderType.AnthropicVertex) {
|
||||
const auth = await getGoogleAuth(this.config as any, 'anthropic');
|
||||
const headers = auth.headers();
|
||||
const authorization =
|
||||
headers.Authorization ||
|
||||
(headers as Record<string, string | undefined>).authorization;
|
||||
const token =
|
||||
typeof authorization === 'string'
|
||||
? authorization.replace(/^Bearer\s+/i, '')
|
||||
: '';
|
||||
const baseUrl =
|
||||
getVertexAnthropicBaseUrl(this.config as any) || auth.baseUrl;
|
||||
const config = this.config as VertexAnthropicProviderConfig;
|
||||
const auth = await getGoogleAuth(config, 'anthropic');
|
||||
const { Authorization: authHeader } = auth.headers();
|
||||
const token = authHeader.replace(/^Bearer\s+/i, '');
|
||||
const baseUrl = getVertexAnthropicBaseUrl(config) || auth.baseUrl;
|
||||
return {
|
||||
base_url: baseUrl || '',
|
||||
auth_token: token,
|
||||
request_layer: 'vertex',
|
||||
headers,
|
||||
request_layer: 'vertex_anthropic',
|
||||
headers: { Authorization: authHeader },
|
||||
};
|
||||
}
|
||||
|
||||
@@ -65,7 +62,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
|
||||
private createAdapter(
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
tools: ToolSet,
|
||||
tools: CopilotToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
@@ -93,8 +90,12 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
@@ -102,11 +103,13 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const reasoning = this.getReasoning(options, model.id);
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Text);
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
reasoning,
|
||||
middleware,
|
||||
});
|
||||
@@ -115,7 +118,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
return await adapter.text(request, options.signal);
|
||||
return await adapter.text(request, options.signal, messages);
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
@@ -130,8 +133,12 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
@@ -140,11 +147,13 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Text);
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
@@ -153,7 +162,11 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
for await (const chunk of adapter.streamText(
|
||||
request,
|
||||
options.signal,
|
||||
messages
|
||||
)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
@@ -170,8 +183,12 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<StreamObject> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Object };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
@@ -180,11 +197,13 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Object);
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
@@ -193,7 +212,11 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
for await (const chunk of adapter.streamObject(request, options.signal)) {
|
||||
for await (const chunk of adapter.streamObject(
|
||||
request,
|
||||
options.signal,
|
||||
messages
|
||||
)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import z from 'zod';
|
||||
|
||||
import { IMAGE_ATTACHMENT_CAPABILITY } from '../attachments';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import { AnthropicProvider } from './anthropic';
|
||||
|
||||
@@ -23,6 +24,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
attachments: IMAGE_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -33,6 +35,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
attachments: IMAGE_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -43,6 +46,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
attachments: IMAGE_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import {
|
||||
createVertexAnthropic,
|
||||
type GoogleVertexAnthropicProvider,
|
||||
type GoogleVertexAnthropicProviderSettings,
|
||||
} from '@ai-sdk/google-vertex/anthropic';
|
||||
|
||||
import { IMAGE_ATTACHMENT_CAPABILITY } from '../attachments';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import {
|
||||
getGoogleAuth,
|
||||
getVertexAnthropicBaseUrl,
|
||||
VertexModelListSchema,
|
||||
type VertexProviderConfig,
|
||||
} from '../utils';
|
||||
import { AnthropicProvider } from './anthropic';
|
||||
|
||||
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings;
|
||||
export type AnthropicVertexConfig = VertexProviderConfig;
|
||||
|
||||
export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexConfig> {
|
||||
override readonly type = CopilotProviderType.AnthropicVertex;
|
||||
@@ -25,6 +21,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
attachments: IMAGE_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -35,6 +32,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
attachments: IMAGE_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -45,23 +43,17 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
attachments: IMAGE_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
protected instance!: GoogleVertexAnthropicProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
if (!this.config.location || !this.config.googleAuthOptions) return false;
|
||||
return !!this.config.project || !!getVertexAnthropicBaseUrl(this.config);
|
||||
}
|
||||
|
||||
override setup() {
|
||||
super.setup();
|
||||
this.instance = createVertexAnthropic(this.config);
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const { baseUrl, headers } = await getGoogleAuth(
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
import type {
|
||||
ModelAttachmentCapability,
|
||||
PromptAttachment,
|
||||
PromptAttachmentKind,
|
||||
PromptAttachmentSourceKind,
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { inferMimeType } from './utils';
|
||||
|
||||
export const IMAGE_ATTACHMENT_CAPABILITY: ModelAttachmentCapability = {
|
||||
kinds: ['image'],
|
||||
sourceKinds: ['url', 'data'],
|
||||
allowRemoteUrls: true,
|
||||
};
|
||||
|
||||
export const GEMINI_ATTACHMENT_CAPABILITY: ModelAttachmentCapability = {
|
||||
kinds: ['image', 'audio', 'file'],
|
||||
sourceKinds: ['url', 'data', 'bytes', 'file_handle'],
|
||||
allowRemoteUrls: true,
|
||||
};
|
||||
|
||||
export type CanonicalPromptAttachment = {
|
||||
kind: PromptAttachmentKind;
|
||||
sourceKind: PromptAttachmentSourceKind;
|
||||
mediaType?: string;
|
||||
source: Record<string, unknown>;
|
||||
isRemote: boolean;
|
||||
};
|
||||
|
||||
function parseDataUrl(url: string) {
|
||||
if (!url.startsWith('data:')) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const commaIndex = url.indexOf(',');
|
||||
if (commaIndex === -1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const meta = url.slice(5, commaIndex);
|
||||
const payload = url.slice(commaIndex + 1);
|
||||
const parts = meta.split(';');
|
||||
const mediaType = parts[0] || 'text/plain;charset=US-ASCII';
|
||||
const isBase64 = parts.includes('base64');
|
||||
|
||||
return {
|
||||
mediaType,
|
||||
data: isBase64
|
||||
? payload
|
||||
: Buffer.from(decodeURIComponent(payload), 'utf8').toString('base64'),
|
||||
};
|
||||
}
|
||||
|
||||
function attachmentTypeFromMediaType(mediaType: string): PromptAttachmentKind {
|
||||
if (mediaType.startsWith('image/')) {
|
||||
return 'image';
|
||||
}
|
||||
if (mediaType.startsWith('audio/')) {
|
||||
return 'audio';
|
||||
}
|
||||
return 'file';
|
||||
}
|
||||
|
||||
function attachmentKindFromHintOrMediaType(
|
||||
hint: PromptAttachmentKind | undefined,
|
||||
mediaType: string | undefined
|
||||
): PromptAttachmentKind {
|
||||
if (hint) return hint;
|
||||
return attachmentTypeFromMediaType(mediaType || '');
|
||||
}
|
||||
|
||||
function toBase64Data(data: string, encoding: 'base64' | 'utf8' = 'base64') {
|
||||
return encoding === 'base64'
|
||||
? data
|
||||
: Buffer.from(data, 'utf8').toString('base64');
|
||||
}
|
||||
|
||||
function appendAttachMetadata(
|
||||
source: Record<string, unknown>,
|
||||
attachment: Exclude<PromptAttachment, string> & Record<string, unknown>
|
||||
) {
|
||||
if (attachment.fileName) {
|
||||
source.file_name = attachment.fileName;
|
||||
}
|
||||
if (attachment.providerHint) {
|
||||
source.provider_hint = attachment.providerHint;
|
||||
}
|
||||
return source;
|
||||
}
|
||||
|
||||
export function promptAttachmentHasSource(
|
||||
attachment: PromptAttachment
|
||||
): boolean {
|
||||
if (typeof attachment === 'string') {
|
||||
return !!attachment.trim();
|
||||
}
|
||||
|
||||
if ('attachment' in attachment) {
|
||||
return !!attachment.attachment;
|
||||
}
|
||||
|
||||
switch (attachment.kind) {
|
||||
case 'url':
|
||||
return !!attachment.url;
|
||||
case 'data':
|
||||
case 'bytes':
|
||||
return !!attachment.data;
|
||||
case 'file_handle':
|
||||
return !!attachment.fileHandle;
|
||||
}
|
||||
}
|
||||
|
||||
export async function canonicalizePromptAttachment(
|
||||
attachment: PromptAttachment,
|
||||
message: Pick<PromptMessage, 'params'>
|
||||
): Promise<CanonicalPromptAttachment> {
|
||||
const fallbackMimeType =
|
||||
typeof message.params?.mimetype === 'string'
|
||||
? message.params.mimetype
|
||||
: undefined;
|
||||
|
||||
if (typeof attachment === 'string') {
|
||||
const dataUrl = parseDataUrl(attachment);
|
||||
const mediaType =
|
||||
fallbackMimeType ??
|
||||
dataUrl?.mediaType ??
|
||||
(await inferMimeType(attachment));
|
||||
const kind = attachmentKindFromHintOrMediaType(undefined, mediaType);
|
||||
if (dataUrl) {
|
||||
return {
|
||||
kind,
|
||||
sourceKind: 'data',
|
||||
mediaType,
|
||||
isRemote: false,
|
||||
source: {
|
||||
media_type: mediaType || dataUrl.mediaType,
|
||||
data: dataUrl.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
kind,
|
||||
sourceKind: 'url',
|
||||
mediaType,
|
||||
isRemote: /^https?:\/\//.test(attachment),
|
||||
source: { url: attachment, media_type: mediaType },
|
||||
};
|
||||
}
|
||||
|
||||
if ('attachment' in attachment) {
|
||||
return await canonicalizePromptAttachment(
|
||||
{
|
||||
kind: 'url',
|
||||
url: attachment.attachment,
|
||||
mimeType: attachment.mimeType,
|
||||
},
|
||||
message
|
||||
);
|
||||
}
|
||||
|
||||
if (attachment.kind === 'url') {
|
||||
const dataUrl = parseDataUrl(attachment.url);
|
||||
const mediaType =
|
||||
attachment.mimeType ??
|
||||
fallbackMimeType ??
|
||||
dataUrl?.mediaType ??
|
||||
(await inferMimeType(attachment.url));
|
||||
const kind = attachmentKindFromHintOrMediaType(
|
||||
attachment.providerHint?.kind,
|
||||
mediaType
|
||||
);
|
||||
if (dataUrl) {
|
||||
return {
|
||||
kind,
|
||||
sourceKind: 'data',
|
||||
mediaType,
|
||||
isRemote: false,
|
||||
source: appendAttachMetadata(
|
||||
{ media_type: mediaType || dataUrl.mediaType, data: dataUrl.data },
|
||||
attachment
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
kind,
|
||||
sourceKind: 'url',
|
||||
mediaType,
|
||||
isRemote: /^https?:\/\//.test(attachment.url),
|
||||
source: appendAttachMetadata(
|
||||
{ url: attachment.url, media_type: mediaType },
|
||||
attachment
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
if (attachment.kind === 'data' || attachment.kind === 'bytes') {
|
||||
return {
|
||||
kind: attachmentKindFromHintOrMediaType(
|
||||
attachment.providerHint?.kind,
|
||||
attachment.mimeType
|
||||
),
|
||||
sourceKind: attachment.kind,
|
||||
mediaType: attachment.mimeType,
|
||||
isRemote: false,
|
||||
source: appendAttachMetadata(
|
||||
{
|
||||
media_type: attachment.mimeType,
|
||||
data: toBase64Data(
|
||||
attachment.data,
|
||||
attachment.kind === 'data' ? attachment.encoding : 'base64'
|
||||
),
|
||||
},
|
||||
attachment
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
kind: attachmentKindFromHintOrMediaType(
|
||||
attachment.providerHint?.kind,
|
||||
attachment.mimeType
|
||||
),
|
||||
sourceKind: 'file_handle',
|
||||
mediaType: attachment.mimeType,
|
||||
isRemote: false,
|
||||
source: appendAttachMetadata(
|
||||
{ file_handle: attachment.fileHandle, media_type: attachment.mimeType },
|
||||
attachment
|
||||
),
|
||||
};
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import type {
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
|
||||
import { promptAttachmentMimeType, promptAttachmentToUrl } from './utils';
|
||||
|
||||
export type FalConfig = {
|
||||
apiKey: string;
|
||||
@@ -183,13 +184,14 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
return {
|
||||
model_name: options.modelName || undefined,
|
||||
image_url: attachments
|
||||
?.map(v =>
|
||||
typeof v === 'string'
|
||||
? v
|
||||
: v.mimeType.startsWith('image/')
|
||||
? v.attachment
|
||||
: undefined
|
||||
)
|
||||
?.map(v => {
|
||||
const url = promptAttachmentToUrl(v);
|
||||
const mediaType = promptAttachmentMimeType(
|
||||
v,
|
||||
typeof params?.mimetype === 'string' ? params.mimetype : undefined
|
||||
);
|
||||
return url && mediaType?.startsWith('image/') ? url : undefined;
|
||||
})
|
||||
.find(v => !!v),
|
||||
prompt: content.trim(),
|
||||
loras: lora.length ? lora : undefined,
|
||||
@@ -256,7 +258,7 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
const model = this.selectModel(cond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
|
||||
// by default, image prompt assumes there is only one message
|
||||
const prompt = this.extractPrompt(messages[messages.length - 1]);
|
||||
@@ -281,7 +283,9 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
}
|
||||
return data.output;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -294,12 +298,16 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
const model = this.selectModel(cond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const result = await this.text(cond, messages, options);
|
||||
|
||||
yield result;
|
||||
} catch (e) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
@@ -317,7 +325,7 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
|
||||
// by default, image prompt assumes there is only one message
|
||||
const prompt = this.extractPrompt(
|
||||
@@ -374,7 +382,7 @@ export class FalProvider extends CopilotProvider<FalConfig> {
|
||||
} catch (e) {
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_errors')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,87 +1,94 @@
|
||||
import type {
|
||||
GoogleGenerativeAIProvider,
|
||||
GoogleGenerativeAIProviderOptions,
|
||||
} from '@ai-sdk/google';
|
||||
import type { GoogleVertexProvider } from '@ai-sdk/google-vertex';
|
||||
import {
|
||||
AISDKError,
|
||||
type EmbeddingModel,
|
||||
embedMany,
|
||||
generateObject,
|
||||
generateText,
|
||||
JSONParseError,
|
||||
stepCountIs,
|
||||
streamText,
|
||||
} from 'ai';
|
||||
import { setTimeout as delay } from 'node:timers/promises';
|
||||
|
||||
import { ZodError } from 'zod';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
OneMB,
|
||||
readResponseBufferWithLimit,
|
||||
safeFetch,
|
||||
UserFriendlyError,
|
||||
} from '../../../../base';
|
||||
import { sniffMime } from '../../../../base/storage/providers/utils';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
llmEmbeddingDispatch,
|
||||
llmStructuredDispatch,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmEmbeddingRequest,
|
||||
type NativeLlmRequest,
|
||||
type NativeLlmStructuredRequest,
|
||||
} from '../../../../native';
|
||||
import type { NodeTextMiddleware } from '../../config';
|
||||
import type { CopilotToolSet } from '../../tools';
|
||||
import {
|
||||
buildNativeEmbeddingRequest,
|
||||
buildNativeRequest,
|
||||
buildNativeStructuredRequest,
|
||||
NativeProviderAdapter,
|
||||
parseNativeStructuredOutput,
|
||||
StructuredResponseParseError,
|
||||
} from '../native';
|
||||
import { CopilotProvider } from '../provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotProviderModel,
|
||||
CopilotStructuredOptions,
|
||||
ModelConditions,
|
||||
PromptAttachment,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { ModelOutputType } from '../types';
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
StreamObjectParser,
|
||||
TextStreamParser,
|
||||
} from '../utils';
|
||||
import { promptAttachmentMimeType, promptAttachmentToUrl } from '../utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
const GEMINI_REMOTE_ATTACHMENT_MAX_BYTES = 64 * OneMB;
|
||||
const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro'];
|
||||
const GEMINI_RETRY_INITIAL_DELAY_MS = 2_000;
|
||||
|
||||
function normalizeMimeType(mediaType?: string) {
|
||||
return mediaType?.split(';', 1)[0]?.trim() || 'application/octet-stream';
|
||||
}
|
||||
|
||||
function isYoutubeUrl(url: URL) {
|
||||
const hostname = url.hostname.toLowerCase();
|
||||
if (hostname === 'youtu.be') {
|
||||
return /^\/[\w-]+$/.test(url.pathname);
|
||||
}
|
||||
|
||||
if (hostname !== 'youtube.com' && hostname !== 'www.youtube.com') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (url.pathname !== '/watch') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !!url.searchParams.get('v');
|
||||
}
|
||||
|
||||
function isGeminiFileUrl(url: URL, baseUrl: string) {
|
||||
try {
|
||||
const base = new URL(baseUrl);
|
||||
const basePath = base.pathname.replace(/\/+$/, '');
|
||||
return (
|
||||
url.origin === base.origin &&
|
||||
url.pathname.startsWith(`${basePath}/files/`)
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
protected abstract instance:
|
||||
| GoogleGenerativeAIProvider
|
||||
| GoogleVertexProvider;
|
||||
|
||||
private getThinkingConfig(
|
||||
model: string,
|
||||
options: { includeThoughts: boolean; useDynamicBudget?: boolean }
|
||||
): NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']> {
|
||||
if (this.isGemini3Model(model)) {
|
||||
return {
|
||||
includeThoughts: options.includeThoughts,
|
||||
thinkingLevel: 'high',
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
includeThoughts: options.includeThoughts,
|
||||
thinkingBudget: options.useDynamicBudget ? -1 : 12000,
|
||||
};
|
||||
}
|
||||
|
||||
private getEmbeddingModel(model: string) {
|
||||
const provider = this.instance as typeof this.instance & {
|
||||
embeddingModel?: (modelId: string) => EmbeddingModel;
|
||||
textEmbeddingModel?: (modelId: string) => EmbeddingModel;
|
||||
};
|
||||
|
||||
return (
|
||||
provider.embeddingModel?.(model) ?? provider.textEmbeddingModel?.(model)
|
||||
);
|
||||
}
|
||||
protected abstract createNativeConfig(): Promise<NativeLlmBackendConfig>;
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof AISDKError) {
|
||||
this.logger.error('Throw error from ai sdk:', e);
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.name || 'unknown',
|
||||
message: e.message,
|
||||
});
|
||||
} else {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
@@ -91,37 +98,261 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
}
|
||||
}
|
||||
|
||||
protected createNativeDispatch(backendConfig: NativeLlmBackendConfig) {
|
||||
return (request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream('gemini', backendConfig, request, signal);
|
||||
}
|
||||
|
||||
protected createNativeStructuredDispatch(
|
||||
backendConfig: NativeLlmBackendConfig
|
||||
) {
|
||||
return (request: NativeLlmStructuredRequest) =>
|
||||
llmStructuredDispatch('gemini', backendConfig, request);
|
||||
}
|
||||
|
||||
protected createNativeEmbeddingDispatch(
|
||||
backendConfig: NativeLlmBackendConfig
|
||||
) {
|
||||
return (request: NativeLlmEmbeddingRequest) =>
|
||||
llmEmbeddingDispatch('gemini', backendConfig, request);
|
||||
}
|
||||
|
||||
protected createNativeAdapter(
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
tools: CopilotToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
this.createNativeDispatch(backendConfig),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
}
|
||||
|
||||
protected async fetchRemoteAttach(url: string, signal?: AbortSignal) {
|
||||
const parsed = new URL(url);
|
||||
const response = await safeFetch(
|
||||
parsed,
|
||||
{ method: 'GET', signal },
|
||||
this.buildAttachFetchOptions(parsed)
|
||||
);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Failed to fetch attachment: ${response.status} ${response.statusText}`
|
||||
);
|
||||
}
|
||||
const buffer = await readResponseBufferWithLimit(
|
||||
response,
|
||||
GEMINI_REMOTE_ATTACHMENT_MAX_BYTES
|
||||
);
|
||||
const headerMimeType = normalizeMimeType(
|
||||
response.headers.get('content-type') || ''
|
||||
);
|
||||
return {
|
||||
data: buffer.toString('base64'),
|
||||
mimeType: normalizeMimeType(sniffMime(buffer, headerMimeType)),
|
||||
};
|
||||
}
|
||||
|
||||
private buildAttachFetchOptions(url: URL) {
|
||||
const baseOptions = { timeoutMs: 15_000, maxRedirects: 3 } as const;
|
||||
if (!env.prod) {
|
||||
return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) };
|
||||
}
|
||||
|
||||
const trustedOrigins = new Set<string>();
|
||||
const protocol = this.AFFiNEConfig.server.https ? 'https:' : 'http:';
|
||||
const port = this.AFFiNEConfig.server.port;
|
||||
const isDefaultPort =
|
||||
(protocol === 'https:' && port === 443) ||
|
||||
(protocol === 'http:' && port === 80);
|
||||
|
||||
const addHostOrigin = (host: string) => {
|
||||
if (!host) return;
|
||||
try {
|
||||
const parsed = new URL(`${protocol}//${host}`);
|
||||
if (!parsed.port && !isDefaultPort) {
|
||||
parsed.port = String(port);
|
||||
}
|
||||
trustedOrigins.add(parsed.origin);
|
||||
} catch {
|
||||
// ignore invalid host config entries
|
||||
}
|
||||
};
|
||||
|
||||
if (this.AFFiNEConfig.server.externalUrl) {
|
||||
try {
|
||||
trustedOrigins.add(
|
||||
new URL(this.AFFiNEConfig.server.externalUrl).origin
|
||||
);
|
||||
} catch {
|
||||
// ignore invalid external URL
|
||||
}
|
||||
}
|
||||
|
||||
addHostOrigin(this.AFFiNEConfig.server.host);
|
||||
for (const host of this.AFFiNEConfig.server.hosts) {
|
||||
addHostOrigin(host);
|
||||
}
|
||||
|
||||
const hostname = url.hostname.toLowerCase();
|
||||
const trustedByHost = TRUSTED_ATTACHMENT_HOST_SUFFIXES.some(
|
||||
suffix => hostname === suffix || hostname.endsWith(`.${suffix}`)
|
||||
);
|
||||
if (trustedOrigins.has(url.origin) || trustedByHost) {
|
||||
return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) };
|
||||
}
|
||||
|
||||
return baseOptions;
|
||||
}
|
||||
|
||||
private shouldInlineRemoteAttach(url: URL, config: NativeLlmBackendConfig) {
|
||||
switch (config.request_layer) {
|
||||
case 'gemini_api':
|
||||
if (url.protocol !== 'http:' && url.protocol !== 'https:') return false;
|
||||
return !(isGeminiFileUrl(url, config.base_url) || isYoutubeUrl(url));
|
||||
case 'gemini_vertex':
|
||||
return false;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private toInlineAttach(
|
||||
attachment: PromptAttachment,
|
||||
mimeType: string,
|
||||
data: string
|
||||
): PromptAttachment {
|
||||
if (typeof attachment === 'string' || !('kind' in attachment)) {
|
||||
return { kind: 'bytes', data, mimeType };
|
||||
}
|
||||
|
||||
if (attachment.kind !== 'url') {
|
||||
return attachment;
|
||||
}
|
||||
|
||||
return {
|
||||
kind: 'bytes',
|
||||
data,
|
||||
mimeType,
|
||||
fileName: attachment.fileName,
|
||||
providerHint: attachment.providerHint,
|
||||
};
|
||||
}
|
||||
|
||||
protected async prepareMessages(
|
||||
messages: PromptMessage[],
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
signal?: AbortSignal
|
||||
): Promise<PromptMessage[]> {
|
||||
const prepared: PromptMessage[] = [];
|
||||
|
||||
for (const message of messages) {
|
||||
signal?.throwIfAborted();
|
||||
if (!Array.isArray(message.attachments) || !message.attachments.length) {
|
||||
prepared.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
const attachments: PromptAttachment[] = [];
|
||||
let changed = false;
|
||||
for (const attachment of message.attachments) {
|
||||
signal?.throwIfAborted();
|
||||
const rawUrl = promptAttachmentToUrl(attachment);
|
||||
if (!rawUrl || rawUrl.startsWith('data:')) {
|
||||
attachments.push(attachment);
|
||||
continue;
|
||||
}
|
||||
|
||||
let parsed: URL;
|
||||
try {
|
||||
parsed = new URL(rawUrl);
|
||||
} catch {
|
||||
attachments.push(attachment);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!this.shouldInlineRemoteAttach(parsed, backendConfig)) {
|
||||
attachments.push(attachment);
|
||||
continue;
|
||||
}
|
||||
|
||||
const declaredMimeType = promptAttachmentMimeType(
|
||||
attachment,
|
||||
typeof message.params?.mimetype === 'string'
|
||||
? message.params.mimetype
|
||||
: undefined
|
||||
);
|
||||
const downloaded = await this.fetchRemoteAttach(rawUrl, signal);
|
||||
attachments.push(
|
||||
this.toInlineAttach(
|
||||
attachment,
|
||||
declaredMimeType
|
||||
? normalizeMimeType(declaredMimeType)
|
||||
: downloaded.mimeType,
|
||||
downloaded.data
|
||||
)
|
||||
);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
prepared.push(changed ? { ...message, attachments } : message);
|
||||
}
|
||||
|
||||
return prepared;
|
||||
}
|
||||
|
||||
protected async waitForStructuredRetry(
|
||||
delayMs: number,
|
||||
signal?: AbortSignal
|
||||
) {
|
||||
await delay(delayMs, undefined, signal ? { signal } : undefined);
|
||||
}
|
||||
|
||||
async text(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.instance(model.id);
|
||||
const { text } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
providerOptions: {
|
||||
google: this.getGeminiOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const msg = await this.prepareMessages(
|
||||
messages,
|
||||
backendConfig,
|
||||
options.signal
|
||||
);
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Text);
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages: msg,
|
||||
options,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
|
||||
if (!text) throw new Error('Failed to generate text');
|
||||
return text.trim();
|
||||
const adapter = this.createNativeAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
return await adapter.text(request, options.signal, messages);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -129,55 +360,65 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
override async structure(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
options: CopilotStructuredOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Structured };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs, schema] = await chatToGPTMessage(messages);
|
||||
if (!schema) {
|
||||
throw new CopilotPromptInvalid('Schema is required');
|
||||
}
|
||||
|
||||
const modelInstance = this.instance(model.id);
|
||||
const { object } = await generateObject({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
schema,
|
||||
providerOptions: {
|
||||
google: {
|
||||
thinkingConfig: this.getThinkingConfig(model.id, {
|
||||
includeThoughts: false,
|
||||
useDynamicBudget: true,
|
||||
}),
|
||||
},
|
||||
},
|
||||
abortSignal: options.signal,
|
||||
maxRetries: options.maxRetries || 3,
|
||||
experimental_repairText: async ({ text, error }) => {
|
||||
if (error instanceof JSONParseError) {
|
||||
// strange fixed response, temporarily replace it
|
||||
const ret = text.replaceAll(/^ny\n/g, ' ').trim();
|
||||
if (ret.startsWith('```') || ret.endsWith('```')) {
|
||||
return ret
|
||||
.replace(/```[\w\s]+\n/g, '')
|
||||
.replace(/\n```/g, '')
|
||||
.trim();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
return null;
|
||||
},
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const msg = await this.prepareMessages(
|
||||
messages,
|
||||
backendConfig,
|
||||
options.signal
|
||||
);
|
||||
const structuredDispatch =
|
||||
this.createNativeStructuredDispatch(backendConfig);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Structured);
|
||||
const { request, schema } = await buildNativeStructuredRequest({
|
||||
model: model.id,
|
||||
messages: msg,
|
||||
options,
|
||||
attachmentCapability: cap,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
responseSchema: options.schema,
|
||||
middleware,
|
||||
});
|
||||
|
||||
return JSON.stringify(object);
|
||||
const maxRetries = Math.max(options.maxRetries ?? 3, 0);
|
||||
for (let attempt = 0; ; attempt++) {
|
||||
try {
|
||||
const response = await structuredDispatch(request);
|
||||
const parsed = parseNativeStructuredOutput(response);
|
||||
const validated = schema.parse(parsed);
|
||||
return JSON.stringify(validated);
|
||||
} catch (error) {
|
||||
const isParsingError =
|
||||
error instanceof StructuredResponseParseError ||
|
||||
error instanceof ZodError;
|
||||
const retryableError =
|
||||
isParsingError || !(error instanceof UserFriendlyError);
|
||||
if (!retryableError || attempt >= maxRetries) {
|
||||
throw error;
|
||||
}
|
||||
if (!isParsingError) {
|
||||
await this.waitForStructuredRetry(
|
||||
GEMINI_RETRY_INITIAL_DELAY_MS * 2 ** attempt,
|
||||
options.signal
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -188,29 +429,54 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
options: CopilotChatOptions | CopilotImageOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
yield result;
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!options.signal?.aborted) {
|
||||
const footnotes = parser.end();
|
||||
if (footnotes.length) {
|
||||
yield `\n\n${footnotes}`;
|
||||
}
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const preparedMessages = await this.prepareMessages(
|
||||
messages,
|
||||
backendConfig,
|
||||
options.signal
|
||||
);
|
||||
const tools = await this.getTools(
|
||||
options as CopilotChatOptions,
|
||||
model.id
|
||||
);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Text);
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages: preparedMessages,
|
||||
options: options as CopilotChatOptions,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
for await (const chunk of adapter.streamText(
|
||||
request,
|
||||
options.signal,
|
||||
messages
|
||||
)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -221,29 +487,51 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<StreamObject> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Object };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new StreamObjectParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
if (result) {
|
||||
yield result;
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const msg = await this.prepareMessages(
|
||||
messages,
|
||||
backendConfig,
|
||||
options.signal
|
||||
);
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Object);
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages: msg,
|
||||
options,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
for await (const chunk of adapter.streamObject(
|
||||
request,
|
||||
options.signal,
|
||||
messages
|
||||
)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_errors')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -253,76 +541,53 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
|
||||
messages: string | string[],
|
||||
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
|
||||
): Promise<number[][]> {
|
||||
messages = Array.isArray(messages) ? messages : [messages];
|
||||
const values = Array.isArray(messages) ? messages : [messages];
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Embedding };
|
||||
await this.checkParams({ embeddings: messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
embeddings: values,
|
||||
cond: fullCond,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_calls')
|
||||
.add(1, { model: model.id });
|
||||
|
||||
const modelInstance = this.getEmbeddingModel(model.id);
|
||||
if (!modelInstance) {
|
||||
throw new Error(`Embedding model is not available for ${model.id}`);
|
||||
}
|
||||
|
||||
const embeddings = await Promise.allSettled(
|
||||
messages.map(m =>
|
||||
embedMany({
|
||||
model: modelInstance,
|
||||
values: [m],
|
||||
maxRetries: 3,
|
||||
providerOptions: {
|
||||
google: {
|
||||
outputDimensionality: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
taskType: 'RETRIEVAL_DOCUMENT',
|
||||
},
|
||||
},
|
||||
})
|
||||
)
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const response = await this.createNativeEmbeddingDispatch(backendConfig)(
|
||||
buildNativeEmbeddingRequest({
|
||||
model: model.id,
|
||||
inputs: values,
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
taskType: 'RETRIEVAL_DOCUMENT',
|
||||
})
|
||||
);
|
||||
|
||||
return embeddings
|
||||
.flatMap(e => (e.status === 'fulfilled' ? e.value.embeddings : null))
|
||||
.filter((v): v is number[] => !!v && Array.isArray(v));
|
||||
return response.embeddings;
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_errors')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private async getFullStream(
|
||||
model: CopilotProviderModel,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
) {
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
const { fullStream } = streamText({
|
||||
model: this.instance(model.id),
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
providerOptions: {
|
||||
google: this.getGeminiOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
});
|
||||
return fullStream;
|
||||
}
|
||||
|
||||
private getGeminiOptions(options: CopilotChatOptions, model: string) {
|
||||
const result: GoogleGenerativeAIProviderOptions = {};
|
||||
if (options?.reasoning && this.isReasoningModel(model)) {
|
||||
result.thinkingConfig = this.getThinkingConfig(model, {
|
||||
includeThoughts: true,
|
||||
});
|
||||
protected getReasoning(
|
||||
options: CopilotChatOptions | CopilotImageOptions,
|
||||
model: string
|
||||
): Record<string, unknown> | undefined {
|
||||
if (
|
||||
options &&
|
||||
'reasoning' in options &&
|
||||
options.reasoning &&
|
||||
this.isReasoningModel(model)
|
||||
) {
|
||||
return this.isGemini3Model(model)
|
||||
? { include_thoughts: true, thinking_level: 'high' }
|
||||
: { include_thoughts: true, thinking_budget: 12000 };
|
||||
}
|
||||
return result;
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private isGemini3Model(model: string) {
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import {
|
||||
createGoogleGenerativeAI,
|
||||
type GoogleGenerativeAIProvider,
|
||||
} from '@ai-sdk/google';
|
||||
import z from 'zod';
|
||||
|
||||
import type { NativeLlmBackendConfig } from '../../../../native';
|
||||
import { GEMINI_ATTACHMENT_CAPABILITY } from '../attachments';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import { GeminiProvider } from './gemini';
|
||||
|
||||
@@ -29,12 +27,15 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
ModelInputType.Text,
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
ModelInputType.File,
|
||||
],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
attachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -47,12 +48,15 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
ModelInputType.Text,
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
ModelInputType.File,
|
||||
],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
attachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -65,12 +69,15 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
ModelInputType.Text,
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
ModelInputType.File,
|
||||
],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
attachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -86,21 +93,10 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
protected instance!: GoogleGenerativeAIProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.instance = createGoogleGenerativeAI({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const baseUrl =
|
||||
@@ -120,4 +116,15 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
|
||||
this.logger.error('Failed to fetch available models', e);
|
||||
}
|
||||
}
|
||||
|
||||
protected override async createNativeConfig(): Promise<NativeLlmBackendConfig> {
|
||||
return {
|
||||
base_url: (
|
||||
this.config.baseURL ||
|
||||
'https://generativelanguage.googleapis.com/v1beta'
|
||||
).replace(/\/$/, ''),
|
||||
auth_token: this.config.apiKey,
|
||||
request_layer: 'gemini_api',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import {
|
||||
createVertex,
|
||||
type GoogleVertexProvider,
|
||||
type GoogleVertexProviderSettings,
|
||||
} from '@ai-sdk/google-vertex';
|
||||
|
||||
import type { NativeLlmBackendConfig } from '../../../../native';
|
||||
import { GEMINI_ATTACHMENT_CAPABILITY } from '../attachments';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import { getGoogleAuth, VertexModelListSchema } from '../utils';
|
||||
import {
|
||||
getGoogleAuth,
|
||||
VertexModelListSchema,
|
||||
type VertexProviderConfig,
|
||||
} from '../utils';
|
||||
import { GeminiProvider } from './gemini';
|
||||
|
||||
export type GeminiVertexConfig = GoogleVertexProviderSettings;
|
||||
export type GeminiVertexConfig = VertexProviderConfig;
|
||||
|
||||
export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
override readonly type = CopilotProviderType.GeminiVertex;
|
||||
@@ -23,12 +23,15 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
ModelInputType.Text,
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
ModelInputType.File,
|
||||
],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
attachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -41,12 +44,15 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
ModelInputType.Text,
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
ModelInputType.File,
|
||||
],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
attachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -59,12 +65,15 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
ModelInputType.Text,
|
||||
ModelInputType.Image,
|
||||
ModelInputType.Audio,
|
||||
ModelInputType.File,
|
||||
],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
attachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -80,21 +89,13 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
protected instance!: GoogleVertexProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.location && !!this.config.googleAuthOptions;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.instance = createVertex(this.config);
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
try {
|
||||
const { baseUrl, headers } = await getGoogleAuth(this.config, 'google');
|
||||
const { baseUrl, headers } = await this.resolveVertexAuth();
|
||||
if (baseUrl && !this.onlineModelList.length) {
|
||||
const { publisherModels } = await fetch(`${baseUrl}/models`, {
|
||||
headers: headers(),
|
||||
@@ -109,4 +110,19 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
|
||||
this.logger.error('Failed to fetch available models', e);
|
||||
}
|
||||
}
|
||||
|
||||
protected async resolveVertexAuth() {
|
||||
return await getGoogleAuth(this.config, 'google');
|
||||
}
|
||||
|
||||
protected override async createNativeConfig(): Promise<NativeLlmBackendConfig> {
|
||||
const auth = await this.resolveVertexAuth();
|
||||
const { Authorization: authHeader } = auth.headers();
|
||||
|
||||
return {
|
||||
base_url: auth.baseUrl || '',
|
||||
auth_token: authHeader.replace(/^Bearer\s+/i, ''),
|
||||
request_layer: 'gemini_vertex',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type {
|
||||
@@ -6,6 +5,11 @@ import type {
|
||||
NativeLlmStreamEvent,
|
||||
NativeLlmToolDefinition,
|
||||
} from '../../../native';
|
||||
import type {
|
||||
CopilotTool,
|
||||
CopilotToolExecuteOptions,
|
||||
CopilotToolSet,
|
||||
} from '../tools';
|
||||
|
||||
export type NativeDispatchFn = (
|
||||
request: NativeLlmRequest,
|
||||
@@ -16,6 +20,8 @@ export type NativeToolCall = {
|
||||
id: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
rawArgumentsText?: string;
|
||||
argumentParseError?: string;
|
||||
thought?: string;
|
||||
};
|
||||
|
||||
@@ -28,10 +34,18 @@ type ToolExecutionResult = {
|
||||
callId: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
rawArgumentsText?: string;
|
||||
argumentParseError?: string;
|
||||
output: unknown;
|
||||
isError?: boolean;
|
||||
};
|
||||
|
||||
type ParsedToolArguments = {
|
||||
args: Record<string, unknown>;
|
||||
rawArgumentsText?: string;
|
||||
argumentParseError?: string;
|
||||
};
|
||||
|
||||
export class ToolCallAccumulator {
|
||||
readonly #states = new Map<string, ToolCallState>();
|
||||
|
||||
@@ -51,12 +65,20 @@ export class ToolCallAccumulator {
|
||||
complete(event: Extract<NativeLlmStreamEvent, { type: 'tool_call' }>) {
|
||||
const state = this.#states.get(event.call_id);
|
||||
this.#states.delete(event.call_id);
|
||||
const parsed =
|
||||
event.arguments_text !== undefined || event.arguments_error !== undefined
|
||||
? {
|
||||
args: event.arguments ?? {},
|
||||
rawArgumentsText: event.arguments_text ?? state?.argumentsText,
|
||||
argumentParseError: event.arguments_error,
|
||||
}
|
||||
: event.arguments
|
||||
? this.parseArgs(event.arguments, state?.argumentsText)
|
||||
: this.parseJson(state?.argumentsText ?? '{}');
|
||||
return {
|
||||
id: event.call_id,
|
||||
name: event.name || state?.name || '',
|
||||
args: this.parseArgs(
|
||||
event.arguments ?? this.parseJson(state?.argumentsText ?? '{}')
|
||||
),
|
||||
...parsed,
|
||||
thought: event.thought,
|
||||
} satisfies NativeToolCall;
|
||||
}
|
||||
@@ -70,51 +92,61 @@ export class ToolCallAccumulator {
|
||||
pending.push({
|
||||
id: callId,
|
||||
name: state.name,
|
||||
args: this.parseArgs(this.parseJson(state.argumentsText)),
|
||||
...this.parseJson(state.argumentsText),
|
||||
});
|
||||
}
|
||||
this.#states.clear();
|
||||
return pending;
|
||||
}
|
||||
|
||||
private parseJson(jsonText: string): unknown {
|
||||
private parseJson(jsonText: string): ParsedToolArguments {
|
||||
if (!jsonText.trim()) {
|
||||
return {};
|
||||
return { args: {} };
|
||||
}
|
||||
try {
|
||||
return JSON.parse(jsonText);
|
||||
} catch {
|
||||
return {};
|
||||
return this.parseArgs(JSON.parse(jsonText), jsonText);
|
||||
} catch (error) {
|
||||
return {
|
||||
args: {},
|
||||
rawArgumentsText: jsonText,
|
||||
argumentParseError:
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: 'Invalid tool arguments JSON',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private parseArgs(value: unknown): Record<string, unknown> {
|
||||
private parseArgs(
|
||||
value: unknown,
|
||||
rawArgumentsText?: string
|
||||
): ParsedToolArguments {
|
||||
if (value && typeof value === 'object' && !Array.isArray(value)) {
|
||||
return value as Record<string, unknown>;
|
||||
return {
|
||||
args: value as Record<string, unknown>,
|
||||
rawArgumentsText,
|
||||
};
|
||||
}
|
||||
return {};
|
||||
return {
|
||||
args: {},
|
||||
rawArgumentsText,
|
||||
argumentParseError: 'Tool arguments must be a JSON object',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export class ToolSchemaExtractor {
|
||||
static extract(toolSet: ToolSet): NativeLlmToolDefinition[] {
|
||||
static extract(toolSet: CopilotToolSet): NativeLlmToolDefinition[] {
|
||||
return Object.entries(toolSet).map(([name, tool]) => {
|
||||
const unknownTool = tool as Record<string, unknown>;
|
||||
const inputSchema =
|
||||
unknownTool.inputSchema ?? unknownTool.parameters ?? z.object({});
|
||||
|
||||
return {
|
||||
name,
|
||||
description:
|
||||
typeof unknownTool.description === 'string'
|
||||
? unknownTool.description
|
||||
: undefined,
|
||||
parameters: this.toJsonSchema(inputSchema),
|
||||
description: tool.description,
|
||||
parameters: this.toJsonSchema(tool.inputSchema ?? z.object({})),
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
private static toJsonSchema(schema: unknown): Record<string, unknown> {
|
||||
static toJsonSchema(schema: unknown): Record<string, unknown> {
|
||||
if (!(schema instanceof z.ZodType)) {
|
||||
if (schema && typeof schema === 'object' && !Array.isArray(schema)) {
|
||||
return schema as Record<string, unknown>;
|
||||
@@ -228,14 +260,45 @@ export class ToolSchemaExtractor {
|
||||
export class ToolCallLoop {
|
||||
constructor(
|
||||
private readonly dispatch: NativeDispatchFn,
|
||||
private readonly tools: ToolSet,
|
||||
private readonly tools: CopilotToolSet,
|
||||
private readonly maxSteps = 20
|
||||
) {}
|
||||
|
||||
private normalizeToolExecuteOptions(
|
||||
signalOrOptions?: AbortSignal | CopilotToolExecuteOptions,
|
||||
maybeMessages?: CopilotToolExecuteOptions['messages']
|
||||
): CopilotToolExecuteOptions {
|
||||
if (
|
||||
signalOrOptions &&
|
||||
typeof signalOrOptions === 'object' &&
|
||||
'aborted' in signalOrOptions
|
||||
) {
|
||||
return {
|
||||
signal: signalOrOptions,
|
||||
messages: maybeMessages,
|
||||
};
|
||||
}
|
||||
|
||||
if (!signalOrOptions) {
|
||||
return maybeMessages ? { messages: maybeMessages } : {};
|
||||
}
|
||||
|
||||
return {
|
||||
...signalOrOptions,
|
||||
signal: signalOrOptions.signal,
|
||||
messages: signalOrOptions.messages ?? maybeMessages,
|
||||
};
|
||||
}
|
||||
|
||||
async *run(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
signalOrOptions?: AbortSignal | CopilotToolExecuteOptions,
|
||||
maybeMessages?: CopilotToolExecuteOptions['messages']
|
||||
): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
const toolExecuteOptions = this.normalizeToolExecuteOptions(
|
||||
signalOrOptions,
|
||||
maybeMessages
|
||||
);
|
||||
const messages = request.messages.map(message => ({
|
||||
...message,
|
||||
content: [...message.content],
|
||||
@@ -253,7 +316,7 @@ export class ToolCallLoop {
|
||||
stream: true,
|
||||
messages,
|
||||
},
|
||||
signal
|
||||
toolExecuteOptions.signal
|
||||
)) {
|
||||
switch (event.type) {
|
||||
case 'tool_call_delta': {
|
||||
@@ -291,7 +354,10 @@ export class ToolCallLoop {
|
||||
throw new Error('ToolCallLoop max steps reached');
|
||||
}
|
||||
|
||||
const toolResults = await this.executeTools(toolCalls);
|
||||
const toolResults = await this.executeTools(
|
||||
toolCalls,
|
||||
toolExecuteOptions
|
||||
);
|
||||
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
@@ -300,6 +366,8 @@ export class ToolCallLoop {
|
||||
call_id: call.id,
|
||||
name: call.name,
|
||||
arguments: call.args,
|
||||
arguments_text: call.rawArgumentsText,
|
||||
arguments_error: call.argumentParseError,
|
||||
thought: call.thought,
|
||||
})),
|
||||
});
|
||||
@@ -311,6 +379,10 @@ export class ToolCallLoop {
|
||||
{
|
||||
type: 'tool_result',
|
||||
call_id: result.callId,
|
||||
name: result.name,
|
||||
arguments: result.args,
|
||||
arguments_text: result.rawArgumentsText,
|
||||
arguments_error: result.argumentParseError,
|
||||
output: result.output,
|
||||
is_error: result.isError,
|
||||
},
|
||||
@@ -321,6 +393,8 @@ export class ToolCallLoop {
|
||||
call_id: result.callId,
|
||||
name: result.name,
|
||||
arguments: result.args,
|
||||
arguments_text: result.rawArgumentsText,
|
||||
arguments_error: result.argumentParseError,
|
||||
output: result.output,
|
||||
is_error: result.isError,
|
||||
};
|
||||
@@ -328,24 +402,28 @@ export class ToolCallLoop {
|
||||
}
|
||||
}
|
||||
|
||||
private async executeTools(calls: NativeToolCall[]) {
|
||||
return await Promise.all(calls.map(call => this.executeTool(call)));
|
||||
private async executeTools(
|
||||
calls: NativeToolCall[],
|
||||
options: CopilotToolExecuteOptions
|
||||
) {
|
||||
return await Promise.all(
|
||||
calls.map(call => this.executeTool(call, options))
|
||||
);
|
||||
}
|
||||
|
||||
private async executeTool(
|
||||
call: NativeToolCall
|
||||
call: NativeToolCall,
|
||||
options: CopilotToolExecuteOptions
|
||||
): Promise<ToolExecutionResult> {
|
||||
const tool = this.tools[call.name] as
|
||||
| {
|
||||
execute?: (args: Record<string, unknown>) => Promise<unknown>;
|
||||
}
|
||||
| undefined;
|
||||
const tool = this.tools[call.name] as CopilotTool | undefined;
|
||||
|
||||
if (!tool?.execute) {
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
rawArgumentsText: call.rawArgumentsText,
|
||||
argumentParseError: call.argumentParseError,
|
||||
isError: true,
|
||||
output: {
|
||||
message: `Tool not found: ${call.name}`,
|
||||
@@ -353,12 +431,30 @@ export class ToolCallLoop {
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const output = await tool.execute(call.args);
|
||||
if (call.argumentParseError) {
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
rawArgumentsText: call.rawArgumentsText,
|
||||
argumentParseError: call.argumentParseError,
|
||||
isError: true,
|
||||
output: {
|
||||
message: 'Invalid tool arguments JSON',
|
||||
rawArguments: call.rawArgumentsText,
|
||||
error: call.argumentParseError,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const output = await tool.execute(call.args, options);
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
rawArgumentsText: call.rawArgumentsText,
|
||||
argumentParseError: call.argumentParseError,
|
||||
output: output ?? null,
|
||||
};
|
||||
} catch (error) {
|
||||
@@ -371,6 +467,8 @@ export class ToolCallLoop {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
rawArgumentsText: call.rawArgumentsText,
|
||||
argumentParseError: call.argumentParseError,
|
||||
isError: true,
|
||||
output: {
|
||||
message: 'Tool execution failed',
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
|
||||
import {
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
@@ -11,6 +9,7 @@ import {
|
||||
type NativeLlmRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import type { CopilotToolSet } from '../tools';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import type {
|
||||
@@ -86,7 +85,7 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
tools: CopilotToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
@@ -108,12 +107,14 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
};
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
const model = this.selectModel(
|
||||
await this.checkParams({
|
||||
messages,
|
||||
cond: fullCond,
|
||||
options,
|
||||
})
|
||||
);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
@@ -127,7 +128,7 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
return await adapter.text(request, options.signal, messages);
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
@@ -141,12 +142,14 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
};
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
const model = this.selectModel(
|
||||
await this.checkParams({
|
||||
messages,
|
||||
cond: fullCond,
|
||||
options,
|
||||
})
|
||||
);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
@@ -162,7 +165,11 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
for await (const chunk of adapter.streamText(
|
||||
request,
|
||||
options.signal,
|
||||
messages
|
||||
)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
|
||||
@@ -1,31 +1,41 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import { ZodType } from 'zod';
|
||||
|
||||
import { CopilotPromptInvalid } from '../../../base';
|
||||
import type {
|
||||
NativeLlmCoreContent,
|
||||
NativeLlmCoreMessage,
|
||||
NativeLlmEmbeddingRequest,
|
||||
NativeLlmRequest,
|
||||
NativeLlmStreamEvent,
|
||||
NativeLlmStructuredRequest,
|
||||
NativeLlmStructuredResponse,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware, ProviderMiddlewareConfig } from '../config';
|
||||
import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop';
|
||||
import type { CopilotChatOptions, PromptMessage, StreamObject } from './types';
|
||||
import type { CopilotToolSet } from '../tools';
|
||||
import {
|
||||
CitationFootnoteFormatter,
|
||||
inferMimeType,
|
||||
TextStreamParser,
|
||||
} from './utils';
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
canonicalizePromptAttachment,
|
||||
type CanonicalPromptAttachment,
|
||||
} from './attachments';
|
||||
import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotStructuredOptions,
|
||||
ModelAttachmentCapability,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from './types';
|
||||
import { CitationFootnoteFormatter, TextStreamParser } from './utils';
|
||||
|
||||
type BuildNativeRequestOptions = {
|
||||
model: string;
|
||||
messages: PromptMessage[];
|
||||
options?: CopilotChatOptions;
|
||||
tools?: ToolSet;
|
||||
options?: CopilotChatOptions | CopilotStructuredOptions;
|
||||
tools?: CopilotToolSet;
|
||||
withAttachment?: boolean;
|
||||
attachmentCapability?: ModelAttachmentCapability;
|
||||
include?: string[];
|
||||
reasoning?: Record<string, unknown>;
|
||||
responseSchema?: unknown;
|
||||
middleware?: ProviderMiddlewareConfig;
|
||||
};
|
||||
|
||||
@@ -34,6 +44,11 @@ type BuildNativeRequestResult = {
|
||||
schema?: ZodType;
|
||||
};
|
||||
|
||||
type BuildNativeStructuredRequestResult = {
|
||||
request: NativeLlmStructuredRequest;
|
||||
schema: ZodType;
|
||||
};
|
||||
|
||||
type ToolCallMeta = {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
@@ -68,9 +83,121 @@ function roleToCore(role: PromptMessage['role']) {
|
||||
}
|
||||
}
|
||||
|
||||
function ensureAttachmentSupported(
|
||||
attachment: CanonicalPromptAttachment,
|
||||
attachmentCapability?: ModelAttachmentCapability
|
||||
) {
|
||||
if (!attachmentCapability) return;
|
||||
|
||||
if (!attachmentCapability.kinds.includes(attachment.kind)) {
|
||||
throw new CopilotPromptInvalid(
|
||||
`Native path does not support ${attachment.kind} attachments${
|
||||
attachment.mediaType ? ` (${attachment.mediaType})` : ''
|
||||
}`
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
attachmentCapability.sourceKinds?.length &&
|
||||
!attachmentCapability.sourceKinds.includes(attachment.sourceKind)
|
||||
) {
|
||||
throw new CopilotPromptInvalid(
|
||||
`Native path does not support ${attachment.sourceKind} attachment sources`
|
||||
);
|
||||
}
|
||||
|
||||
if (attachment.isRemote && attachmentCapability.allowRemoteUrls === false) {
|
||||
throw new CopilotPromptInvalid(
|
||||
'Native path does not support remote attachment urls'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function resolveResponseSchema(
|
||||
systemMessage: PromptMessage | undefined,
|
||||
responseSchema?: unknown
|
||||
): ZodType | undefined {
|
||||
if (responseSchema instanceof ZodType) {
|
||||
return responseSchema;
|
||||
}
|
||||
|
||||
if (systemMessage?.responseFormat?.schema instanceof ZodType) {
|
||||
return systemMessage.responseFormat.schema;
|
||||
}
|
||||
|
||||
return systemMessage?.params?.schema instanceof ZodType
|
||||
? systemMessage.params.schema
|
||||
: undefined;
|
||||
}
|
||||
|
||||
function resolveResponseStrict(
|
||||
systemMessage: PromptMessage | undefined,
|
||||
options?: CopilotStructuredOptions
|
||||
) {
|
||||
return options?.strict ?? systemMessage?.responseFormat?.strict ?? true;
|
||||
}
|
||||
|
||||
export class StructuredResponseParseError extends Error {}
|
||||
|
||||
function normalizeStructuredText(text: string) {
|
||||
const trimmed = text.replaceAll(/^ny\n/g, ' ').trim();
|
||||
if (trimmed.startsWith('```') || trimmed.endsWith('```')) {
|
||||
return trimmed
|
||||
.replace(/```[\w\s-]*\n/g, '')
|
||||
.replace(/\n```/g, '')
|
||||
.trim();
|
||||
}
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
export function parseNativeStructuredOutput(
|
||||
response: Pick<NativeLlmStructuredResponse, 'output_text'> & {
|
||||
output_json?: unknown;
|
||||
}
|
||||
) {
|
||||
if (response.output_json !== undefined) {
|
||||
return response.output_json;
|
||||
}
|
||||
|
||||
const normalized = normalizeStructuredText(response.output_text);
|
||||
const candidates = [
|
||||
() => normalized,
|
||||
() => {
|
||||
const objectStart = normalized.indexOf('{');
|
||||
const objectEnd = normalized.lastIndexOf('}');
|
||||
return objectStart !== -1 && objectEnd > objectStart
|
||||
? normalized.slice(objectStart, objectEnd + 1)
|
||||
: null;
|
||||
},
|
||||
() => {
|
||||
const arrayStart = normalized.indexOf('[');
|
||||
const arrayEnd = normalized.lastIndexOf(']');
|
||||
return arrayStart !== -1 && arrayEnd > arrayStart
|
||||
? normalized.slice(arrayStart, arrayEnd + 1)
|
||||
: null;
|
||||
},
|
||||
];
|
||||
|
||||
for (const candidate of candidates) {
|
||||
try {
|
||||
const candidateText = candidate();
|
||||
if (typeof candidateText === 'string') {
|
||||
return JSON.parse(candidateText);
|
||||
}
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
throw new StructuredResponseParseError(
|
||||
`Unexpected structured response: ${normalized.slice(0, 200)}`
|
||||
);
|
||||
}
|
||||
|
||||
async function toCoreContents(
|
||||
message: PromptMessage,
|
||||
withAttachment: boolean
|
||||
withAttachment: boolean,
|
||||
attachmentCapability?: ModelAttachmentCapability
|
||||
): Promise<NativeLlmCoreContent[]> {
|
||||
const contents: NativeLlmCoreContent[] = [];
|
||||
|
||||
@@ -81,24 +208,12 @@ async function toCoreContents(
|
||||
if (!withAttachment || !Array.isArray(message.attachments)) return contents;
|
||||
|
||||
for (const entry of message.attachments) {
|
||||
let attachmentUrl: string;
|
||||
let mediaType: string;
|
||||
|
||||
if (typeof entry === 'string') {
|
||||
attachmentUrl = entry;
|
||||
mediaType =
|
||||
typeof message.params?.mimetype === 'string'
|
||||
? message.params.mimetype
|
||||
: await inferMimeType(entry);
|
||||
} else {
|
||||
attachmentUrl = entry.attachment;
|
||||
mediaType = entry.mimeType;
|
||||
}
|
||||
|
||||
if (!SIMPLE_IMAGE_URL_REGEX.test(attachmentUrl)) continue;
|
||||
if (!mediaType.startsWith('image/')) continue;
|
||||
|
||||
contents.push({ type: 'image', source: { url: attachmentUrl } });
|
||||
const normalized = await canonicalizePromptAttachment(entry, message);
|
||||
ensureAttachmentSupported(normalized, attachmentCapability);
|
||||
contents.push({
|
||||
type: normalized.kind,
|
||||
source: normalized.source,
|
||||
});
|
||||
}
|
||||
|
||||
return contents;
|
||||
@@ -110,8 +225,10 @@ export async function buildNativeRequest({
|
||||
options = {},
|
||||
tools = {},
|
||||
withAttachment = true,
|
||||
attachmentCapability,
|
||||
include,
|
||||
reasoning,
|
||||
responseSchema,
|
||||
middleware,
|
||||
}: BuildNativeRequestOptions): Promise<BuildNativeRequestResult> {
|
||||
const copiedMessages = messages.map(message => ({
|
||||
@@ -123,10 +240,7 @@ export async function buildNativeRequest({
|
||||
|
||||
const systemMessage =
|
||||
copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined;
|
||||
const schema =
|
||||
systemMessage?.params?.schema instanceof ZodType
|
||||
? systemMessage.params.schema
|
||||
: undefined;
|
||||
const schema = resolveResponseSchema(systemMessage, responseSchema);
|
||||
|
||||
const coreMessages: NativeLlmCoreMessage[] = [];
|
||||
if (systemMessage?.content?.length) {
|
||||
@@ -138,7 +252,11 @@ export async function buildNativeRequest({
|
||||
|
||||
for (const message of copiedMessages) {
|
||||
if (message.role === 'system') continue;
|
||||
const content = await toCoreContents(message, withAttachment);
|
||||
const content = await toCoreContents(
|
||||
message,
|
||||
withAttachment,
|
||||
attachmentCapability
|
||||
);
|
||||
coreMessages.push({ role: roleToCore(message.role), content });
|
||||
}
|
||||
|
||||
@@ -153,6 +271,9 @@ export async function buildNativeRequest({
|
||||
tool_choice: Object.keys(tools).length ? 'auto' : undefined,
|
||||
include,
|
||||
reasoning,
|
||||
response_schema: schema
|
||||
? ToolSchemaExtractor.toJsonSchema(schema)
|
||||
: undefined,
|
||||
middleware: middleware?.rust
|
||||
? { request: middleware.rust.request, stream: middleware.rust.stream }
|
||||
: undefined,
|
||||
@@ -161,6 +282,90 @@ export async function buildNativeRequest({
|
||||
};
|
||||
}
|
||||
|
||||
export async function buildNativeStructuredRequest({
|
||||
model,
|
||||
messages,
|
||||
options = {},
|
||||
withAttachment = true,
|
||||
attachmentCapability,
|
||||
reasoning,
|
||||
responseSchema,
|
||||
middleware,
|
||||
}: Omit<
|
||||
BuildNativeRequestOptions,
|
||||
'tools' | 'include'
|
||||
>): Promise<BuildNativeStructuredRequestResult> {
|
||||
const copiedMessages = messages.map(message => ({
|
||||
...message,
|
||||
attachments: message.attachments
|
||||
? [...message.attachments]
|
||||
: message.attachments,
|
||||
}));
|
||||
|
||||
const systemMessage =
|
||||
copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined;
|
||||
const schema = resolveResponseSchema(systemMessage, responseSchema);
|
||||
const strict = resolveResponseStrict(systemMessage, options);
|
||||
|
||||
if (!schema) {
|
||||
throw new CopilotPromptInvalid('Schema is required');
|
||||
}
|
||||
|
||||
const coreMessages: NativeLlmCoreMessage[] = [];
|
||||
if (systemMessage?.content?.length) {
|
||||
coreMessages.push({
|
||||
role: 'system',
|
||||
content: [{ type: 'text', text: systemMessage.content }],
|
||||
});
|
||||
}
|
||||
|
||||
for (const message of copiedMessages) {
|
||||
if (message.role === 'system') continue;
|
||||
const content = await toCoreContents(
|
||||
message,
|
||||
withAttachment,
|
||||
attachmentCapability
|
||||
);
|
||||
coreMessages.push({ role: roleToCore(message.role), content });
|
||||
}
|
||||
|
||||
return {
|
||||
request: {
|
||||
model,
|
||||
messages: coreMessages,
|
||||
schema: ToolSchemaExtractor.toJsonSchema(schema),
|
||||
max_tokens: options.maxTokens ?? undefined,
|
||||
temperature: options.temperature ?? undefined,
|
||||
reasoning,
|
||||
strict,
|
||||
response_mime_type: 'application/json',
|
||||
middleware: middleware?.rust
|
||||
? { request: middleware.rust.request }
|
||||
: undefined,
|
||||
},
|
||||
schema,
|
||||
};
|
||||
}
|
||||
|
||||
export function buildNativeEmbeddingRequest({
|
||||
model,
|
||||
inputs,
|
||||
dimensions,
|
||||
taskType = 'RETRIEVAL_DOCUMENT',
|
||||
}: {
|
||||
model: string;
|
||||
inputs: string[];
|
||||
dimensions?: number;
|
||||
taskType?: string;
|
||||
}): NativeLlmEmbeddingRequest {
|
||||
return {
|
||||
model,
|
||||
inputs,
|
||||
dimensions,
|
||||
task_type: taskType,
|
||||
};
|
||||
}
|
||||
|
||||
function ensureToolResultMeta(
|
||||
event: Extract<NativeLlmStreamEvent, { type: 'tool_result' }>,
|
||||
toolCalls: Map<string, ToolCallMeta>
|
||||
@@ -244,7 +449,7 @@ export class NativeProviderAdapter {
|
||||
|
||||
constructor(
|
||||
dispatch: NativeDispatchFn,
|
||||
tools: ToolSet,
|
||||
tools: CopilotToolSet,
|
||||
maxSteps = 20,
|
||||
options: NativeProviderAdapterOptions = {}
|
||||
) {
|
||||
@@ -259,9 +464,13 @@ export class NativeProviderAdapter {
|
||||
enabledNodeTextMiddlewares.has('citation_footnote');
|
||||
}
|
||||
|
||||
async text(request: NativeLlmRequest, signal?: AbortSignal) {
|
||||
async text(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal,
|
||||
messages?: PromptMessage[]
|
||||
) {
|
||||
let output = '';
|
||||
for await (const chunk of this.streamText(request, signal)) {
|
||||
for await (const chunk of this.streamText(request, signal, messages)) {
|
||||
output += chunk;
|
||||
}
|
||||
return output.trim();
|
||||
@@ -269,7 +478,8 @@ export class NativeProviderAdapter {
|
||||
|
||||
async *streamText(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
signal?: AbortSignal,
|
||||
messages?: PromptMessage[]
|
||||
): AsyncIterableIterator<string> {
|
||||
const textParser = this.#enableCallout ? new TextStreamParser() : null;
|
||||
const citationFormatter = this.#enableCitationFootnote
|
||||
@@ -278,7 +488,7 @@ export class NativeProviderAdapter {
|
||||
const toolCalls = new Map<string, ToolCallMeta>();
|
||||
let streamPartId = 0;
|
||||
|
||||
for await (const event of this.#loop.run(request, signal)) {
|
||||
for await (const event of this.#loop.run(request, signal, messages)) {
|
||||
switch (event.type) {
|
||||
case 'text_delta': {
|
||||
if (textParser) {
|
||||
@@ -364,7 +574,8 @@ export class NativeProviderAdapter {
|
||||
|
||||
async *streamObject(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
signal?: AbortSignal,
|
||||
messages?: PromptMessage[]
|
||||
): AsyncIterableIterator<StreamObject> {
|
||||
const toolCalls = new Map<string, ToolCallMeta>();
|
||||
const citationFormatter = this.#enableCitationFootnote
|
||||
@@ -373,7 +584,7 @@ export class NativeProviderAdapter {
|
||||
const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>();
|
||||
let hasFootnoteReference = false;
|
||||
|
||||
for await (const event of this.#loop.run(request, signal)) {
|
||||
for await (const event of this.#loop.run(request, signal, messages)) {
|
||||
switch (event.type) {
|
||||
case 'text_delta': {
|
||||
if (event.text.includes('[^')) {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { Tool, ToolSet } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
@@ -12,30 +11,41 @@ import {
|
||||
} from '../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
llmEmbeddingDispatch,
|
||||
llmRerankDispatch,
|
||||
llmStructuredDispatch,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmEmbeddingRequest,
|
||||
type NativeLlmRequest,
|
||||
type NativeLlmRerankRequest,
|
||||
type NativeLlmRerankResponse,
|
||||
type NativeLlmStructuredRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import type { CopilotTool, CopilotToolSet } from '../tools';
|
||||
import { IMAGE_ATTACHMENT_CAPABILITY } from './attachments';
|
||||
import {
|
||||
normalizeRerankModel,
|
||||
OPENAI_RERANK_MAX_COMPLETION_TOKENS,
|
||||
OPENAI_RERANK_TOP_LOGPROBS_LIMIT,
|
||||
usesRerankReasoning,
|
||||
} from './rerank';
|
||||
buildNativeEmbeddingRequest,
|
||||
buildNativeRequest,
|
||||
buildNativeStructuredRequest,
|
||||
NativeProviderAdapter,
|
||||
parseNativeStructuredOutput,
|
||||
} from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotChatTools,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotRerankRequest,
|
||||
CopilotStructuredOptions,
|
||||
ModelCapability,
|
||||
ModelConditions,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from './types';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
|
||||
import { chatToGPTMessage } from './utils';
|
||||
import { promptAttachmentToUrl } from './utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
@@ -91,19 +101,6 @@ const ImageResponseSchema = z.union([
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
const LogProbsSchema = z.array(
|
||||
z.object({
|
||||
token: z.string(),
|
||||
logprob: z.number(),
|
||||
top_logprobs: z.array(
|
||||
z.object({
|
||||
token: z.string(),
|
||||
logprob: z.number(),
|
||||
})
|
||||
),
|
||||
})
|
||||
);
|
||||
|
||||
const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro'];
|
||||
|
||||
function normalizeImageFormatToMime(format?: string) {
|
||||
@@ -136,6 +133,34 @@ function normalizeImageResponseData(
|
||||
.filter((value): value is string => typeof value === 'string');
|
||||
}
|
||||
|
||||
function buildOpenAIRerankRequest(
|
||||
model: string,
|
||||
request: CopilotRerankRequest
|
||||
): NativeLlmRerankRequest {
|
||||
return {
|
||||
model,
|
||||
query: request.query,
|
||||
candidates: request.candidates.map(candidate => ({
|
||||
...(candidate.id ? { id: candidate.id } : {}),
|
||||
text: candidate.text,
|
||||
})),
|
||||
...(request.topK ? { top_n: request.topK } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
function createOpenAIMultimodalCapability(
|
||||
output: ModelCapability['output'],
|
||||
options: Pick<ModelCapability, 'defaultForOutputType'> = {}
|
||||
): ModelCapability {
|
||||
return {
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output,
|
||||
attachments: IMAGE_ATTACHMENT_CAPABILITY,
|
||||
structuredAttachments: IMAGE_ATTACHMENT_CAPABILITY,
|
||||
...options,
|
||||
};
|
||||
}
|
||||
|
||||
export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
readonly type = CopilotProviderType.OpenAI;
|
||||
|
||||
@@ -145,10 +170,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
name: 'GPT 4o',
|
||||
id: 'gpt-4o',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
]),
|
||||
],
|
||||
},
|
||||
// FIXME(@darkskygit): deprecated
|
||||
@@ -156,20 +181,20 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
name: 'GPT 4o 2024-08-06',
|
||||
id: 'gpt-4o-2024-08-06',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 4o Mini',
|
||||
id: 'gpt-4o-mini',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
]),
|
||||
],
|
||||
},
|
||||
// FIXME(@darkskygit): deprecated
|
||||
@@ -177,181 +202,158 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
name: 'GPT 4o Mini 2024-07-18',
|
||||
id: 'gpt-4o-mini-2024-07-18',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 4.1',
|
||||
id: 'gpt-4.1',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
createOpenAIMultimodalCapability(
|
||||
[
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Rerank,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
{ defaultForOutputType: true }
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 4.1 2025-04-14',
|
||||
id: 'gpt-4.1-2025-04-14',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Rerank,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 4.1 Mini',
|
||||
id: 'gpt-4.1-mini',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Rerank,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 4.1 Nano',
|
||||
id: 'gpt-4.1-nano',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Rerank,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 5',
|
||||
id: 'gpt-5',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 5 2025-08-07',
|
||||
id: 'gpt-5-2025-08-07',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 5 Mini',
|
||||
id: 'gpt-5-mini',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 5.2',
|
||||
id: 'gpt-5.2',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Rerank,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 5.2 2025-12-11',
|
||||
id: 'gpt-5.2-2025-12-11',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT 5 Nano',
|
||||
id: 'gpt-5-nano',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
ModelOutputType.Structured,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT O1',
|
||||
id: 'o1',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT O3',
|
||||
id: 'o3',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
]),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'GPT O4 Mini',
|
||||
id: 'o4-mini',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Text, ModelOutputType.Object],
|
||||
},
|
||||
createOpenAIMultimodalCapability([
|
||||
ModelOutputType.Text,
|
||||
ModelOutputType.Object,
|
||||
]),
|
||||
],
|
||||
},
|
||||
// Embedding models
|
||||
@@ -387,11 +389,9 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
{
|
||||
id: 'gpt-image-1',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text, ModelInputType.Image],
|
||||
output: [ModelOutputType.Image],
|
||||
createOpenAIMultimodalCapability([ModelOutputType.Image], {
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
}),
|
||||
],
|
||||
},
|
||||
];
|
||||
@@ -437,7 +437,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
override getProviderSpecificTools(
|
||||
toolName: CopilotChatTools,
|
||||
_model: string
|
||||
): [string, Tool?] | undefined {
|
||||
): [string, CopilotTool?] | undefined {
|
||||
if (toolName === 'docEdit') {
|
||||
return ['doc_edit', undefined];
|
||||
}
|
||||
@@ -452,14 +452,18 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
};
|
||||
}
|
||||
|
||||
private getNativeProtocol() {
|
||||
return this.config.oldApiStyle ? 'openai_chat' : 'openai_responses';
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
tools: CopilotToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream(
|
||||
this.config.oldApiStyle ? 'openai_chat' : 'openai_responses',
|
||||
this.getNativeProtocol(),
|
||||
this.createNativeConfig(),
|
||||
request,
|
||||
signal
|
||||
@@ -470,6 +474,27 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
);
|
||||
}
|
||||
|
||||
protected createNativeStructuredDispatch(
|
||||
backendConfig: NativeLlmBackendConfig
|
||||
) {
|
||||
return (request: NativeLlmStructuredRequest) =>
|
||||
llmStructuredDispatch(this.getNativeProtocol(), backendConfig, request);
|
||||
}
|
||||
|
||||
protected createNativeEmbeddingDispatch(
|
||||
backendConfig: NativeLlmBackendConfig
|
||||
) {
|
||||
return (request: NativeLlmEmbeddingRequest) =>
|
||||
llmEmbeddingDispatch(this.getNativeProtocol(), backendConfig, request);
|
||||
}
|
||||
|
||||
protected createNativeRerankDispatch(backendConfig: NativeLlmBackendConfig) {
|
||||
return (
|
||||
request: NativeLlmRerankRequest
|
||||
): Promise<NativeLlmRerankResponse> =>
|
||||
llmRerankDispatch('openai_chat', backendConfig, request);
|
||||
}
|
||||
|
||||
private getReasoning(
|
||||
options: NonNullable<CopilotChatOptions>,
|
||||
model: string
|
||||
@@ -486,13 +511,18 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
messages,
|
||||
cond: fullCond,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Text);
|
||||
const normalizedOptions = normalizeOpenAIOptionsForModel(
|
||||
options,
|
||||
model.id
|
||||
@@ -502,12 +532,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages,
|
||||
options: normalizedOptions,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
return await adapter.text(request, options.signal, messages);
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
@@ -525,8 +556,12 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
...cond,
|
||||
outputType: ModelOutputType.Text,
|
||||
};
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
messages,
|
||||
cond: fullCond,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
@@ -534,6 +569,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Text);
|
||||
const normalizedOptions = normalizeOpenAIOptionsForModel(
|
||||
options,
|
||||
model.id
|
||||
@@ -543,12 +579,17 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages,
|
||||
options: normalizedOptions,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
for await (const chunk of adapter.streamText(
|
||||
request,
|
||||
options.signal,
|
||||
messages
|
||||
)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
@@ -565,8 +606,12 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<StreamObject> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Object };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
@@ -574,6 +619,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Object);
|
||||
const normalizedOptions = normalizeOpenAIOptionsForModel(
|
||||
options,
|
||||
model.id
|
||||
@@ -583,12 +629,17 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages,
|
||||
options: normalizedOptions,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamObject(request, options.signal)) {
|
||||
for await (const chunk of adapter.streamObject(
|
||||
request,
|
||||
options.signal,
|
||||
messages
|
||||
)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
@@ -605,106 +656,66 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
options: CopilotStructuredOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Structured };
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
messages,
|
||||
cond: fullCond,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
const tools = await this.getTools(options, model.id);
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const backendConfig = this.createNativeConfig();
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const cap = this.getAttachCapability(model, ModelOutputType.Structured);
|
||||
const normalizedOptions = normalizeOpenAIOptionsForModel(
|
||||
options,
|
||||
model.id
|
||||
);
|
||||
const { request, schema } = await buildNativeRequest({
|
||||
const { request, schema } = await buildNativeStructuredRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options: normalizedOptions,
|
||||
tools,
|
||||
attachmentCapability: cap,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
responseSchema: options.schema,
|
||||
middleware,
|
||||
});
|
||||
if (!schema) {
|
||||
throw new CopilotPromptInvalid('Schema is required');
|
||||
}
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
const text = await adapter.text(request, options.signal);
|
||||
const parsed = JSON.parse(text);
|
||||
const response =
|
||||
await this.createNativeStructuredDispatch(backendConfig)(request);
|
||||
const parsed = parseNativeStructuredOutput(response);
|
||||
const validated = schema.parse(parsed);
|
||||
return JSON.stringify(validated);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
override async rerank(
|
||||
cond: ModelConditions,
|
||||
chunkMessages: PromptMessage[][],
|
||||
request: CopilotRerankRequest,
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<number[]> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ messages: [], cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Rerank };
|
||||
const normalizedCond = await this.checkParams({
|
||||
messages: [],
|
||||
cond: fullCond,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
const scores = await Promise.all(
|
||||
chunkMessages.map(async messages => {
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
const rerankModel = normalizeRerankModel(model.id);
|
||||
const response = await this.requestOpenAIJson(
|
||||
'/chat/completions',
|
||||
{
|
||||
model: rerankModel,
|
||||
messages: this.toOpenAIChatMessages(system, msgs),
|
||||
temperature: 0,
|
||||
logprobs: true,
|
||||
top_logprobs: OPENAI_RERANK_TOP_LOGPROBS_LIMIT,
|
||||
...(usesRerankReasoning(rerankModel)
|
||||
? {
|
||||
reasoning_effort: 'none' as const,
|
||||
max_completion_tokens: OPENAI_RERANK_MAX_COMPLETION_TOKENS,
|
||||
}
|
||||
: { max_tokens: OPENAI_RERANK_MAX_COMPLETION_TOKENS }),
|
||||
},
|
||||
options.signal
|
||||
);
|
||||
|
||||
const logprobs = response?.choices?.[0]?.logprobs?.content;
|
||||
if (!Array.isArray(logprobs) || logprobs.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const parsedLogprobs = LogProbsSchema.parse(logprobs);
|
||||
const topMap = parsedLogprobs[0].top_logprobs.reduce(
|
||||
(acc, { token, logprob }) => ({ ...acc, [token]: logprob }),
|
||||
{} as Record<string, number>
|
||||
);
|
||||
|
||||
const findLogProb = (token: string): number => {
|
||||
// OpenAI often includes a leading space, so try matching '.yes', '_yes', ' yes' and 'yes'
|
||||
return [...'_:. "-\t,(=_“'.split('').map(c => c + token), token]
|
||||
.flatMap(v => [v, v.toLowerCase(), v.toUpperCase()])
|
||||
.reduce<number>(
|
||||
(best, key) =>
|
||||
(topMap[key] ?? Number.NEGATIVE_INFINITY) > best
|
||||
? topMap[key]
|
||||
: best,
|
||||
Number.NEGATIVE_INFINITY
|
||||
);
|
||||
};
|
||||
|
||||
const logYes = findLogProb('Yes');
|
||||
const logNo = findLogProb('No');
|
||||
|
||||
const pYes = Math.exp(logYes);
|
||||
const pNo = Math.exp(logNo);
|
||||
const prob = pYes + pNo === 0 ? 0 : pYes / (pYes + pNo);
|
||||
|
||||
return prob;
|
||||
})
|
||||
);
|
||||
|
||||
return scores;
|
||||
try {
|
||||
const backendConfig = this.createNativeConfig();
|
||||
const nativeRequest = buildOpenAIRerankRequest(model.id, request);
|
||||
const response =
|
||||
await this.createNativeRerankDispatch(backendConfig)(nativeRequest);
|
||||
return response.scores;
|
||||
} catch (e: any) {
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
// ====== text to image ======
|
||||
@@ -906,7 +917,8 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
form.set('output_format', outputFormat);
|
||||
|
||||
for (const [idx, entry] of attachments.entries()) {
|
||||
const url = typeof entry === 'string' ? entry : entry.attachment;
|
||||
const url = promptAttachmentToUrl(entry);
|
||||
if (!url) continue;
|
||||
try {
|
||||
const attachment = await this.fetchImage(url, maxBytes, signal);
|
||||
if (!attachment) continue;
|
||||
@@ -964,12 +976,16 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
options: CopilotImageOptions = {}
|
||||
) {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Image };
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
messages,
|
||||
cond: fullCond,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
|
||||
const { content: prompt, attachments } = [...messages].pop() || {};
|
||||
if (!prompt) throw new CopilotPromptInvalid('Prompt is required');
|
||||
@@ -1007,7 +1023,9 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
}
|
||||
return;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('generate_images_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('generate_images_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -1017,65 +1035,36 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
messages: string | string[],
|
||||
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
|
||||
): Promise<number[][]> {
|
||||
messages = Array.isArray(messages) ? messages : [messages];
|
||||
const input = Array.isArray(messages) ? messages : [messages];
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Embedding };
|
||||
await this.checkParams({ embeddings: messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
embeddings: input,
|
||||
cond: fullCond,
|
||||
options,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_calls')
|
||||
.add(1, { model: model.id });
|
||||
const response = await this.requestOpenAIJson('/embeddings', {
|
||||
model: model.id,
|
||||
input: messages,
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
});
|
||||
const data = Array.isArray(response?.data) ? response.data : [];
|
||||
return data
|
||||
.map((item: any) => item?.embedding)
|
||||
.filter((embedding: unknown) => Array.isArray(embedding)) as number[][];
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const backendConfig = this.createNativeConfig();
|
||||
const response = await this.createNativeEmbeddingDispatch(backendConfig)(
|
||||
buildNativeEmbeddingRequest({
|
||||
model: model.id,
|
||||
inputs: input,
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
})
|
||||
);
|
||||
return response.embeddings;
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_errors')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private toOpenAIChatMessages(
|
||||
system: string | undefined,
|
||||
messages: Awaited<ReturnType<typeof chatToGPTMessage>>[1]
|
||||
) {
|
||||
const result: Array<{ role: string; content: string }> = [];
|
||||
if (system) {
|
||||
result.push({ role: 'system', content: system });
|
||||
}
|
||||
|
||||
for (const message of messages) {
|
||||
if (typeof message.content === 'string') {
|
||||
result.push({ role: message.role, content: message.content });
|
||||
continue;
|
||||
}
|
||||
|
||||
const text = message.content
|
||||
.filter(
|
||||
part =>
|
||||
part &&
|
||||
typeof part === 'object' &&
|
||||
'type' in part &&
|
||||
part.type === 'text' &&
|
||||
'text' in part
|
||||
)
|
||||
.map(part => String((part as { text: string }).text))
|
||||
.join('\n');
|
||||
|
||||
result.push({ role: message.role, content: text || '[no content]' });
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private async requestOpenAIJson(
|
||||
path: string,
|
||||
body: Record<string, unknown>,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
|
||||
import { CopilotProviderSideError, metrics } from '../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
@@ -7,6 +5,7 @@ import {
|
||||
type NativeLlmRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import type { CopilotToolSet } from '../tools';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import {
|
||||
@@ -87,7 +86,7 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
tools: CopilotToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
@@ -110,8 +109,13 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
withAttachment: false,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
@@ -128,7 +132,7 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
return await adapter.text(request, options.signal, messages);
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
@@ -143,8 +147,13 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ cond: fullCond, messages, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
const normalizedCond = await this.checkParams({
|
||||
cond: fullCond,
|
||||
messages,
|
||||
options,
|
||||
withAttachment: false,
|
||||
});
|
||||
const model = this.selectModel(normalizedCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
@@ -163,7 +172,11 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
for await (const chunk of adapter.streamText(
|
||||
request,
|
||||
options.signal,
|
||||
messages
|
||||
)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
|
||||
@@ -51,13 +51,21 @@ const DEFAULT_MIDDLEWARE_BY_TYPE: Record<
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Gemini]: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'tool_schema_rewrite'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['callout'],
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.GeminiVertex]: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'tool_schema_rewrite'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['callout'],
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.FAL]: {},
|
||||
|
||||
@@ -5,7 +5,7 @@ import type {
|
||||
ProviderMiddlewareConfig,
|
||||
} from '../config';
|
||||
import { resolveProviderMiddleware } from './provider-middleware';
|
||||
import { CopilotProviderType, type ModelOutputType } from './types';
|
||||
import { CopilotProviderType, ModelOutputType } from './types';
|
||||
|
||||
const PROVIDER_ID_PATTERN = /^[a-zA-Z0-9-_]+$/;
|
||||
|
||||
@@ -239,8 +239,13 @@ export function resolveModel({
|
||||
};
|
||||
}
|
||||
|
||||
const defaultProviderId =
|
||||
outputType && outputType !== ModelOutputType.Rerank
|
||||
? registry.defaults[outputType]
|
||||
: undefined;
|
||||
|
||||
const fallbackOrder = [
|
||||
...(outputType ? [registry.defaults[outputType]] : []),
|
||||
...(defaultProviderId ? [defaultProviderId] : []),
|
||||
registry.defaults.fallback,
|
||||
...registry.order,
|
||||
].filter((id): id is string => !!id);
|
||||
|
||||
@@ -2,7 +2,6 @@ import { AsyncLocalStorage } from 'node:async_hooks';
|
||||
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
import { Tool, ToolSet } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
@@ -27,6 +26,8 @@ import {
|
||||
buildDocSearchGetter,
|
||||
buildDocUpdateHandler,
|
||||
buildDocUpdateMetaHandler,
|
||||
type CopilotTool,
|
||||
type CopilotToolSet,
|
||||
createBlobReadTool,
|
||||
createCodeArtifactTool,
|
||||
createConversationSummaryTool,
|
||||
@@ -42,6 +43,7 @@ import {
|
||||
createExaSearchTool,
|
||||
createSectionEditTool,
|
||||
} from '../tools';
|
||||
import { canonicalizePromptAttachment } from './attachments';
|
||||
import { CopilotProviderFactory } from './factory';
|
||||
import { resolveProviderMiddleware } from './provider-middleware';
|
||||
import { buildProviderRegistry } from './provider-registry';
|
||||
@@ -52,12 +54,17 @@ import {
|
||||
type CopilotImageOptions,
|
||||
CopilotProviderModel,
|
||||
CopilotProviderType,
|
||||
type CopilotRerankRequest,
|
||||
CopilotStructuredOptions,
|
||||
EmbeddingMessage,
|
||||
type ModelAttachmentCapability,
|
||||
ModelCapability,
|
||||
ModelConditions,
|
||||
ModelFullConditions,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
type PromptAttachmentKind,
|
||||
type PromptAttachmentSourceKind,
|
||||
type PromptMessage,
|
||||
PromptMessageSchema,
|
||||
StreamObject,
|
||||
@@ -163,6 +170,163 @@ export abstract class CopilotProvider<C = any> {
|
||||
|
||||
async refreshOnlineModels() {}
|
||||
|
||||
private unique<T>(values: Iterable<T>) {
|
||||
return Array.from(new Set(values));
|
||||
}
|
||||
|
||||
private attachmentKindToInputType(
|
||||
kind: PromptAttachmentKind
|
||||
): ModelInputType {
|
||||
switch (kind) {
|
||||
case 'image':
|
||||
return ModelInputType.Image;
|
||||
case 'audio':
|
||||
return ModelInputType.Audio;
|
||||
default:
|
||||
return ModelInputType.File;
|
||||
}
|
||||
}
|
||||
|
||||
protected async inferModelConditionsFromMessages(
|
||||
messages?: PromptMessage[],
|
||||
withAttachment = true
|
||||
): Promise<Partial<ModelFullConditions>> {
|
||||
if (!messages?.length || !withAttachment) return {};
|
||||
|
||||
const attachmentKinds: PromptAttachmentKind[] = [];
|
||||
const attachmentSourceKinds: PromptAttachmentSourceKind[] = [];
|
||||
const inputTypes: ModelInputType[] = [];
|
||||
let hasRemoteAttachments = false;
|
||||
|
||||
for (const message of messages) {
|
||||
if (!Array.isArray(message.attachments)) continue;
|
||||
|
||||
for (const attachment of message.attachments) {
|
||||
const normalized = await canonicalizePromptAttachment(
|
||||
attachment,
|
||||
message
|
||||
);
|
||||
attachmentKinds.push(normalized.kind);
|
||||
inputTypes.push(this.attachmentKindToInputType(normalized.kind));
|
||||
attachmentSourceKinds.push(normalized.sourceKind);
|
||||
hasRemoteAttachments = hasRemoteAttachments || normalized.isRemote;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
...(attachmentKinds.length
|
||||
? { attachmentKinds: this.unique(attachmentKinds) }
|
||||
: {}),
|
||||
...(attachmentSourceKinds.length
|
||||
? { attachmentSourceKinds: this.unique(attachmentSourceKinds) }
|
||||
: {}),
|
||||
...(inputTypes.length ? { inputTypes: this.unique(inputTypes) } : {}),
|
||||
...(hasRemoteAttachments ? { hasRemoteAttachments } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
private mergeModelConditions(
|
||||
cond: ModelFullConditions,
|
||||
inferredCond: Partial<ModelFullConditions>
|
||||
): ModelFullConditions {
|
||||
return {
|
||||
...inferredCond,
|
||||
...cond,
|
||||
inputTypes: this.unique([
|
||||
...(inferredCond.inputTypes ?? []),
|
||||
...(cond.inputTypes ?? []),
|
||||
]),
|
||||
attachmentKinds: this.unique([
|
||||
...(inferredCond.attachmentKinds ?? []),
|
||||
...(cond.attachmentKinds ?? []),
|
||||
]),
|
||||
attachmentSourceKinds: this.unique([
|
||||
...(inferredCond.attachmentSourceKinds ?? []),
|
||||
...(cond.attachmentSourceKinds ?? []),
|
||||
]),
|
||||
hasRemoteAttachments:
|
||||
cond.hasRemoteAttachments ?? inferredCond.hasRemoteAttachments,
|
||||
};
|
||||
}
|
||||
|
||||
protected getAttachCapability(
|
||||
model: CopilotProviderModel,
|
||||
outputType: ModelOutputType
|
||||
): ModelAttachmentCapability | undefined {
|
||||
const capability =
|
||||
model.capabilities.find(cap => cap.output.includes(outputType)) ??
|
||||
model.capabilities[0];
|
||||
if (!capability) {
|
||||
return;
|
||||
}
|
||||
return this.resolveAttachmentCapability(capability, outputType);
|
||||
}
|
||||
|
||||
private resolveAttachmentCapability(
|
||||
cap: ModelCapability,
|
||||
outputType?: ModelOutputType
|
||||
): ModelAttachmentCapability | undefined {
|
||||
if (outputType === ModelOutputType.Structured) {
|
||||
return cap.structuredAttachments ?? cap.attachments;
|
||||
}
|
||||
return cap.attachments;
|
||||
}
|
||||
|
||||
private matchesAttachCapability(
|
||||
cap: ModelCapability,
|
||||
cond: ModelFullConditions
|
||||
) {
|
||||
const {
|
||||
attachmentKinds,
|
||||
attachmentSourceKinds,
|
||||
hasRemoteAttachments,
|
||||
outputType,
|
||||
} = cond;
|
||||
|
||||
if (
|
||||
!attachmentKinds?.length &&
|
||||
!attachmentSourceKinds?.length &&
|
||||
!hasRemoteAttachments
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const attachmentCapability = this.resolveAttachmentCapability(
|
||||
cap,
|
||||
outputType
|
||||
);
|
||||
if (!attachmentCapability) {
|
||||
return !attachmentKinds?.some(
|
||||
kind => !cap.input.includes(this.attachmentKindToInputType(kind))
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
attachmentKinds?.some(kind => !attachmentCapability.kinds.includes(kind))
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
attachmentSourceKinds?.length &&
|
||||
attachmentCapability.sourceKinds?.length &&
|
||||
attachmentSourceKinds.some(
|
||||
kind => !attachmentCapability.sourceKinds?.includes(kind)
|
||||
)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
hasRemoteAttachments &&
|
||||
attachmentCapability.allowRemoteUrls === false
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private findValidModel(
|
||||
cond: ModelFullConditions
|
||||
): CopilotProviderModel | undefined {
|
||||
@@ -170,7 +334,8 @@ export abstract class CopilotProvider<C = any> {
|
||||
const matcher = (cap: ModelCapability) =>
|
||||
(!outputType || cap.output.includes(outputType)) &&
|
||||
(!inputTypes?.length ||
|
||||
inputTypes.every(type => cap.input.includes(type)));
|
||||
inputTypes.every(type => cap.input.includes(type))) &&
|
||||
this.matchesAttachCapability(cap, cond);
|
||||
|
||||
if (modelId) {
|
||||
const hasOnlineModel = this.onlineModelList.includes(modelId);
|
||||
@@ -213,7 +378,7 @@ export abstract class CopilotProvider<C = any> {
|
||||
protected getProviderSpecificTools(
|
||||
_toolName: CopilotChatTools,
|
||||
_model: string
|
||||
): [string, Tool?] | undefined {
|
||||
): [string, CopilotTool?] | undefined {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -221,8 +386,8 @@ export abstract class CopilotProvider<C = any> {
|
||||
protected async getTools(
|
||||
options: CopilotChatOptions,
|
||||
model: string
|
||||
): Promise<ToolSet> {
|
||||
const tools: ToolSet = {};
|
||||
): Promise<CopilotToolSet> {
|
||||
const tools: CopilotToolSet = {};
|
||||
if (options?.tools?.length) {
|
||||
this.logger.debug(`getTools: ${JSON.stringify(options.tools)}`);
|
||||
const ac = this.moduleRef.get(AccessController, { strict: false });
|
||||
@@ -305,7 +470,8 @@ export abstract class CopilotProvider<C = any> {
|
||||
});
|
||||
const searchDocs = buildDocKeywordSearchGetter(
|
||||
ac,
|
||||
indexerService
|
||||
indexerService,
|
||||
models
|
||||
);
|
||||
tools.doc_keyword_search = createDocKeywordSearchTool(
|
||||
searchDocs.bind(null, options)
|
||||
@@ -377,19 +543,14 @@ export abstract class CopilotProvider<C = any> {
|
||||
messages,
|
||||
embeddings,
|
||||
options = {},
|
||||
withAttachment = true,
|
||||
}: {
|
||||
cond: ModelFullConditions;
|
||||
messages?: PromptMessage[];
|
||||
embeddings?: string[];
|
||||
options?: CopilotChatOptions;
|
||||
}) {
|
||||
const model = this.selectModel(cond);
|
||||
const multimodal = model.capabilities.some(c =>
|
||||
[ModelInputType.Image, ModelInputType.Audio].some(t =>
|
||||
c.input.includes(t)
|
||||
)
|
||||
);
|
||||
|
||||
options?: CopilotChatOptions | CopilotStructuredOptions;
|
||||
withAttachment?: boolean;
|
||||
}): Promise<ModelFullConditions> {
|
||||
if (messages) {
|
||||
const { requireContent = true, requireAttachment = false } = options;
|
||||
|
||||
@@ -402,20 +563,56 @@ export abstract class CopilotProvider<C = any> {
|
||||
})
|
||||
.passthrough()
|
||||
.catchall(z.union([z.string(), z.number(), z.date(), z.null()]))
|
||||
.refine(
|
||||
m =>
|
||||
!(multimodal && requireAttachment && m.role === 'user') ||
|
||||
(m.attachments ? m.attachments.length > 0 : true),
|
||||
{ message: 'attachments required in multimodal mode' }
|
||||
)
|
||||
)
|
||||
.optional();
|
||||
|
||||
this.handleZodError(MessageSchema.safeParse(messages));
|
||||
|
||||
const inferredCond = await this.inferModelConditionsFromMessages(
|
||||
messages,
|
||||
withAttachment
|
||||
);
|
||||
const mergedCond = this.mergeModelConditions(cond, inferredCond);
|
||||
const model = this.selectModel(mergedCond);
|
||||
const multimodal = model.capabilities.some(c =>
|
||||
[ModelInputType.Image, ModelInputType.Audio, ModelInputType.File].some(
|
||||
t => c.input.includes(t)
|
||||
)
|
||||
);
|
||||
|
||||
if (
|
||||
multimodal &&
|
||||
requireAttachment &&
|
||||
!messages.some(
|
||||
message =>
|
||||
message.role === 'user' &&
|
||||
Array.isArray(message.attachments) &&
|
||||
message.attachments.length > 0
|
||||
)
|
||||
) {
|
||||
throw new CopilotPromptInvalid(
|
||||
'attachments required in multimodal mode'
|
||||
);
|
||||
}
|
||||
|
||||
if (embeddings) {
|
||||
this.handleZodError(EmbeddingMessage.safeParse(embeddings));
|
||||
}
|
||||
|
||||
return mergedCond;
|
||||
}
|
||||
|
||||
const inferredCond = await this.inferModelConditionsFromMessages(
|
||||
messages,
|
||||
withAttachment
|
||||
);
|
||||
const mergedCond = this.mergeModelConditions(cond, inferredCond);
|
||||
|
||||
if (embeddings) {
|
||||
this.handleZodError(EmbeddingMessage.safeParse(embeddings));
|
||||
}
|
||||
|
||||
return mergedCond;
|
||||
}
|
||||
|
||||
abstract text(
|
||||
@@ -476,7 +673,7 @@ export abstract class CopilotProvider<C = any> {
|
||||
|
||||
async rerank(
|
||||
_model: ModelConditions,
|
||||
_messages: PromptMessage[][],
|
||||
_request: CopilotRerankRequest,
|
||||
_options?: CopilotChatOptions
|
||||
): Promise<number[]> {
|
||||
throw new CopilotProviderNotSupported({
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
const GPT_4_RERANK_MODELS = /^(gpt-4(?:$|[.-]))/;
|
||||
const GPT_5_RERANK_LOGPROBS_MODELS = /^(gpt-5\.2(?:$|-))/;
|
||||
|
||||
export const DEFAULT_RERANK_MODEL = 'gpt-5.2';
|
||||
export const OPENAI_RERANK_TOP_LOGPROBS_LIMIT = 5;
|
||||
export const OPENAI_RERANK_MAX_COMPLETION_TOKENS = 16;
|
||||
|
||||
export function supportsRerankModel(model: string): boolean {
|
||||
return (
|
||||
GPT_4_RERANK_MODELS.test(model) || GPT_5_RERANK_LOGPROBS_MODELS.test(model)
|
||||
);
|
||||
}
|
||||
|
||||
export function usesRerankReasoning(model: string): boolean {
|
||||
return GPT_5_RERANK_LOGPROBS_MODELS.test(model);
|
||||
}
|
||||
|
||||
export function normalizeRerankModel(model?: string | null): string {
|
||||
if (model && supportsRerankModel(model)) {
|
||||
return model;
|
||||
}
|
||||
return DEFAULT_RERANK_MODEL;
|
||||
}
|
||||
@@ -124,14 +124,97 @@ export const ChatMessageRole = Object.values(AiPromptRole) as [
|
||||
'user',
|
||||
];
|
||||
|
||||
const AttachmentUrlSchema = z.string().refine(value => {
|
||||
if (value.startsWith('data:')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return (
|
||||
url.protocol === 'http:' ||
|
||||
url.protocol === 'https:' ||
|
||||
url.protocol === 'gs:'
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}, 'attachments must use https?://, gs:// or data: urls');
|
||||
|
||||
export const PromptAttachmentSourceKindSchema = z.enum([
|
||||
'url',
|
||||
'data',
|
||||
'bytes',
|
||||
'file_handle',
|
||||
]);
|
||||
|
||||
export const PromptAttachmentKindSchema = z.enum(['image', 'audio', 'file']);
|
||||
|
||||
const AttachmentProviderHintSchema = z
|
||||
.object({
|
||||
provider: z.nativeEnum(CopilotProviderType).optional(),
|
||||
kind: PromptAttachmentKindSchema.optional(),
|
||||
})
|
||||
.strict();
|
||||
|
||||
const PromptAttachmentSchema = z.discriminatedUnion('kind', [
|
||||
z
|
||||
.object({
|
||||
kind: z.literal('url'),
|
||||
url: AttachmentUrlSchema,
|
||||
mimeType: z.string().optional(),
|
||||
fileName: z.string().optional(),
|
||||
providerHint: AttachmentProviderHintSchema.optional(),
|
||||
})
|
||||
.strict(),
|
||||
z
|
||||
.object({
|
||||
kind: z.literal('data'),
|
||||
data: z.string(),
|
||||
mimeType: z.string(),
|
||||
encoding: z.enum(['base64', 'utf8']).optional(),
|
||||
fileName: z.string().optional(),
|
||||
providerHint: AttachmentProviderHintSchema.optional(),
|
||||
})
|
||||
.strict(),
|
||||
z
|
||||
.object({
|
||||
kind: z.literal('bytes'),
|
||||
data: z.string(),
|
||||
mimeType: z.string(),
|
||||
encoding: z.literal('base64').optional(),
|
||||
fileName: z.string().optional(),
|
||||
providerHint: AttachmentProviderHintSchema.optional(),
|
||||
})
|
||||
.strict(),
|
||||
z
|
||||
.object({
|
||||
kind: z.literal('file_handle'),
|
||||
fileHandle: z.string().trim().min(1),
|
||||
mimeType: z.string().optional(),
|
||||
fileName: z.string().optional(),
|
||||
providerHint: AttachmentProviderHintSchema.optional(),
|
||||
})
|
||||
.strict(),
|
||||
]);
|
||||
|
||||
export const ChatMessageAttachment = z.union([
|
||||
z.string().url(),
|
||||
AttachmentUrlSchema,
|
||||
z.object({
|
||||
attachment: z.string(),
|
||||
attachment: AttachmentUrlSchema,
|
||||
mimeType: z.string(),
|
||||
}),
|
||||
PromptAttachmentSchema,
|
||||
]);
|
||||
|
||||
export const PromptResponseFormatSchema = z
|
||||
.object({
|
||||
type: z.literal('json_schema'),
|
||||
schema: z.any(),
|
||||
strict: z.boolean().optional(),
|
||||
})
|
||||
.strict();
|
||||
|
||||
export const StreamObjectSchema = z.discriminatedUnion('type', [
|
||||
z.object({
|
||||
type: z.literal('text-delta'),
|
||||
@@ -161,6 +244,7 @@ export const PureMessageSchema = z.object({
|
||||
streamObjects: z.array(StreamObjectSchema).optional().nullable(),
|
||||
attachments: z.array(ChatMessageAttachment).optional().nullable(),
|
||||
params: z.record(z.any()).optional().nullable(),
|
||||
responseFormat: PromptResponseFormatSchema.optional().nullable(),
|
||||
});
|
||||
|
||||
export const PromptMessageSchema = PureMessageSchema.extend({
|
||||
@@ -169,6 +253,12 @@ export const PromptMessageSchema = PureMessageSchema.extend({
|
||||
export type PromptMessage = z.infer<typeof PromptMessageSchema>;
|
||||
export type PromptParams = NonNullable<PromptMessage['params']>;
|
||||
export type StreamObject = z.infer<typeof StreamObjectSchema>;
|
||||
export type PromptAttachment = z.infer<typeof ChatMessageAttachment>;
|
||||
export type PromptAttachmentSourceKind = z.infer<
|
||||
typeof PromptAttachmentSourceKindSchema
|
||||
>;
|
||||
export type PromptAttachmentKind = z.infer<typeof PromptAttachmentKindSchema>;
|
||||
export type PromptResponseFormat = z.infer<typeof PromptResponseFormatSchema>;
|
||||
|
||||
// ========== options ==========
|
||||
|
||||
@@ -194,7 +284,9 @@ export type CopilotChatTools = NonNullable<
|
||||
>[number];
|
||||
|
||||
export const CopilotStructuredOptionsSchema =
|
||||
CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema).optional();
|
||||
CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema)
|
||||
.extend({ schema: z.any().optional(), strict: z.boolean().optional() })
|
||||
.optional();
|
||||
|
||||
export type CopilotStructuredOptions = z.infer<
|
||||
typeof CopilotStructuredOptionsSchema
|
||||
@@ -220,10 +312,22 @@ export type CopilotEmbeddingOptions = z.infer<
|
||||
typeof CopilotEmbeddingOptionsSchema
|
||||
>;
|
||||
|
||||
export type CopilotRerankCandidate = {
|
||||
id?: string;
|
||||
text: string;
|
||||
};
|
||||
|
||||
export type CopilotRerankRequest = {
|
||||
query: string;
|
||||
candidates: CopilotRerankCandidate[];
|
||||
topK?: number;
|
||||
};
|
||||
|
||||
export enum ModelInputType {
|
||||
Text = 'text',
|
||||
Image = 'image',
|
||||
Audio = 'audio',
|
||||
File = 'file',
|
||||
}
|
||||
|
||||
export enum ModelOutputType {
|
||||
@@ -231,12 +335,21 @@ export enum ModelOutputType {
|
||||
Object = 'object',
|
||||
Embedding = 'embedding',
|
||||
Image = 'image',
|
||||
Rerank = 'rerank',
|
||||
Structured = 'structured',
|
||||
}
|
||||
|
||||
export interface ModelAttachmentCapability {
|
||||
kinds: PromptAttachmentKind[];
|
||||
sourceKinds?: PromptAttachmentSourceKind[];
|
||||
allowRemoteUrls?: boolean;
|
||||
}
|
||||
|
||||
export interface ModelCapability {
|
||||
input: ModelInputType[];
|
||||
output: ModelOutputType[];
|
||||
attachments?: ModelAttachmentCapability;
|
||||
structuredAttachments?: ModelAttachmentCapability;
|
||||
defaultForOutputType?: boolean;
|
||||
}
|
||||
|
||||
@@ -248,6 +361,9 @@ export interface CopilotProviderModel {
|
||||
|
||||
export type ModelConditions = {
|
||||
inputTypes?: ModelInputType[];
|
||||
attachmentKinds?: PromptAttachmentKind[];
|
||||
attachmentSourceKinds?: PromptAttachmentSourceKind[];
|
||||
hasRemoteAttachments?: boolean;
|
||||
modelId?: string;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,34 +1,39 @@
|
||||
import { GoogleVertexProviderSettings } from '@ai-sdk/google-vertex';
|
||||
import { GoogleVertexAnthropicProviderSettings } from '@ai-sdk/google-vertex/anthropic';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import {
|
||||
AssistantModelMessage,
|
||||
FilePart,
|
||||
ImagePart,
|
||||
TextPart,
|
||||
TextStreamPart,
|
||||
UserModelMessage,
|
||||
} from 'ai';
|
||||
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
|
||||
import z, { ZodType } from 'zod';
|
||||
import z from 'zod';
|
||||
|
||||
import {
|
||||
bufferToArrayBuffer,
|
||||
fetchBuffer,
|
||||
OneMinute,
|
||||
ResponseTooLargeError,
|
||||
safeFetch,
|
||||
SsrfBlockedError,
|
||||
} from '../../../base';
|
||||
import { CustomAITools } from '../tools';
|
||||
import { PromptMessage, StreamObject } from './types';
|
||||
import { OneMinute, safeFetch } from '../../../base';
|
||||
import { PromptAttachment, StreamObject } from './types';
|
||||
|
||||
type ChatMessage = UserModelMessage | AssistantModelMessage;
|
||||
export type VertexProviderConfig = {
|
||||
location?: string;
|
||||
project?: string;
|
||||
baseURL?: string;
|
||||
googleAuthOptions?: GoogleAuthOptions;
|
||||
fetch?: typeof fetch;
|
||||
};
|
||||
|
||||
export type VertexAnthropicProviderConfig = VertexProviderConfig;
|
||||
|
||||
type CopilotTextStreamPart =
|
||||
| { type: 'text-delta'; text: string; id?: string }
|
||||
| { type: 'reasoning-delta'; text: string; id?: string }
|
||||
| {
|
||||
type: 'tool-call';
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
input: Record<string, unknown>;
|
||||
}
|
||||
| {
|
||||
type: 'tool-result';
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
input: Record<string, unknown>;
|
||||
output: unknown;
|
||||
}
|
||||
| { type: 'error'; error: unknown };
|
||||
|
||||
const ATTACHMENT_MAX_BYTES = 20 * 1024 * 1024;
|
||||
const ATTACH_HEAD_PARAMS = { timeoutMs: OneMinute / 12, maxRedirects: 3 };
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
const FORMAT_INFER_MAP: Record<string, string> = {
|
||||
pdf: 'application/pdf',
|
||||
mp3: 'audio/mpeg',
|
||||
@@ -53,9 +58,39 @@ const FORMAT_INFER_MAP: Record<string, string> = {
|
||||
flv: 'video/flv',
|
||||
};
|
||||
|
||||
async function fetchArrayBuffer(url: string): Promise<ArrayBuffer> {
|
||||
const { buffer } = await fetchBuffer(url, ATTACHMENT_MAX_BYTES);
|
||||
return bufferToArrayBuffer(buffer);
|
||||
function toBase64Data(data: string, encoding: 'base64' | 'utf8' = 'base64') {
|
||||
return encoding === 'base64'
|
||||
? data
|
||||
: Buffer.from(data, 'utf8').toString('base64');
|
||||
}
|
||||
|
||||
export function promptAttachmentToUrl(
|
||||
attachment: PromptAttachment
|
||||
): string | undefined {
|
||||
if (typeof attachment === 'string') return attachment;
|
||||
if ('attachment' in attachment) return attachment.attachment;
|
||||
switch (attachment.kind) {
|
||||
case 'url':
|
||||
return attachment.url;
|
||||
case 'data':
|
||||
return `data:${attachment.mimeType};base64,${toBase64Data(
|
||||
attachment.data,
|
||||
attachment.encoding
|
||||
)}`;
|
||||
case 'bytes':
|
||||
return `data:${attachment.mimeType};base64,${attachment.data}`;
|
||||
case 'file_handle':
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
export function promptAttachmentMimeType(
|
||||
attachment: PromptAttachment,
|
||||
fallbackMimeType?: string
|
||||
): string | undefined {
|
||||
if (typeof attachment === 'string') return fallbackMimeType;
|
||||
if ('attachment' in attachment) return attachment.mimeType;
|
||||
return attachment.mimeType ?? fallbackMimeType;
|
||||
}
|
||||
|
||||
export async function inferMimeType(url: string) {
|
||||
@@ -69,346 +104,21 @@ export async function inferMimeType(url: string) {
|
||||
if (ext) {
|
||||
return ext;
|
||||
}
|
||||
try {
|
||||
const mimeType = await safeFetch(
|
||||
url,
|
||||
{ method: 'HEAD' },
|
||||
ATTACH_HEAD_PARAMS
|
||||
).then(res => res.headers.get('content-type'));
|
||||
if (mimeType) return mimeType;
|
||||
} catch {
|
||||
// ignore and fallback to default
|
||||
}
|
||||
}
|
||||
try {
|
||||
const mimeType = await safeFetch(
|
||||
url,
|
||||
{ method: 'HEAD' },
|
||||
ATTACH_HEAD_PARAMS
|
||||
).then(res => res.headers.get('content-type'));
|
||||
if (mimeType) return mimeType;
|
||||
} catch {
|
||||
// ignore and fallback to default
|
||||
}
|
||||
return 'application/octet-stream';
|
||||
}
|
||||
|
||||
export async function chatToGPTMessage(
|
||||
messages: PromptMessage[],
|
||||
// TODO(@darkskygit): move this logic in interface refactoring
|
||||
withAttachment: boolean = true,
|
||||
// NOTE: some providers in vercel ai sdk are not able to handle url attachments yet
|
||||
// so we need to use base64 encoded attachments instead
|
||||
useBase64Attachment: boolean = false
|
||||
): Promise<[string | undefined, ChatMessage[], ZodType?]> {
|
||||
const hasSystem = messages[0]?.role === 'system';
|
||||
const system = hasSystem ? messages[0] : undefined;
|
||||
const normalizedMessages = hasSystem ? messages.slice(1) : messages;
|
||||
const schema =
|
||||
system?.params?.schema && system.params.schema instanceof ZodType
|
||||
? system.params.schema
|
||||
: undefined;
|
||||
|
||||
// filter redundant fields
|
||||
const msgs: ChatMessage[] = [];
|
||||
for (let { role, content, attachments, params } of normalizedMessages.filter(
|
||||
m => m.role !== 'system'
|
||||
)) {
|
||||
content = content.trim();
|
||||
role = role as 'user' | 'assistant';
|
||||
const mimetype = params?.mimetype;
|
||||
if (Array.isArray(attachments)) {
|
||||
const contents: (TextPart | ImagePart | FilePart)[] = [];
|
||||
if (content.length) {
|
||||
contents.push({ type: 'text', text: content });
|
||||
}
|
||||
|
||||
if (withAttachment) {
|
||||
for (let attachment of attachments) {
|
||||
let mediaType: string;
|
||||
if (typeof attachment === 'string') {
|
||||
mediaType =
|
||||
typeof mimetype === 'string'
|
||||
? mimetype
|
||||
: await inferMimeType(attachment);
|
||||
} else {
|
||||
({ attachment, mimeType: mediaType } = attachment);
|
||||
}
|
||||
if (SIMPLE_IMAGE_URL_REGEX.test(attachment)) {
|
||||
const data =
|
||||
attachment.startsWith('data:') || useBase64Attachment
|
||||
? await fetchArrayBuffer(attachment).catch(error => {
|
||||
// Avoid leaking internal details for blocked URLs.
|
||||
if (
|
||||
error instanceof SsrfBlockedError ||
|
||||
error instanceof ResponseTooLargeError
|
||||
) {
|
||||
throw new Error('Attachment URL is not allowed');
|
||||
}
|
||||
throw error;
|
||||
})
|
||||
: new URL(attachment);
|
||||
if (mediaType.startsWith('image/')) {
|
||||
contents.push({ type: 'image', image: data, mediaType });
|
||||
} else {
|
||||
contents.push({ type: 'file' as const, data, mediaType });
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (!content.length) {
|
||||
// temp fix for pplx
|
||||
contents.push({ type: 'text', text: '[no content]' });
|
||||
}
|
||||
|
||||
msgs.push({ role, content: contents } as ChatMessage);
|
||||
} else {
|
||||
msgs.push({ role, content });
|
||||
}
|
||||
}
|
||||
|
||||
return [system?.content, msgs, schema];
|
||||
}
|
||||
|
||||
// pattern types the callback will receive
|
||||
type Pattern =
|
||||
| { kind: 'index'; value: number } // [123]
|
||||
| { kind: 'link'; text: string; url: string } // [text](url)
|
||||
| { kind: 'wrappedLink'; text: string; url: string }; // ([text](url))
|
||||
|
||||
type NeedMore = { kind: 'needMore' };
|
||||
type Failed = { kind: 'fail'; nextPos: number };
|
||||
type Finished =
|
||||
| { kind: 'ok'; endPos: number; text: string; url: string }
|
||||
| { kind: 'index'; endPos: number; value: number };
|
||||
type ParseStatus = Finished | NeedMore | Failed;
|
||||
|
||||
type PatternCallback = (m: Pattern) => string;
|
||||
|
||||
export class StreamPatternParser {
|
||||
#buffer = '';
|
||||
|
||||
constructor(private readonly callback: PatternCallback) {}
|
||||
|
||||
write(chunk: string): string {
|
||||
this.#buffer += chunk;
|
||||
const output: string[] = [];
|
||||
let i = 0;
|
||||
|
||||
while (i < this.#buffer.length) {
|
||||
const ch = this.#buffer[i];
|
||||
|
||||
// [[[number]]] or [text](url) or ([text](url))
|
||||
if (ch === '[' || (ch === '(' && this.peek(i + 1) === '[')) {
|
||||
const isWrapped = ch === '(';
|
||||
const startPos = isWrapped ? i + 1 : i;
|
||||
const res = this.tryParse(startPos);
|
||||
if (res.kind === 'needMore') break;
|
||||
const { output: out, nextPos } = this.handlePattern(
|
||||
res,
|
||||
isWrapped,
|
||||
startPos,
|
||||
i
|
||||
);
|
||||
output.push(out);
|
||||
i = nextPos;
|
||||
continue;
|
||||
}
|
||||
output.push(ch);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
this.#buffer = this.#buffer.slice(i);
|
||||
return output.join('');
|
||||
}
|
||||
|
||||
end(): string {
|
||||
const rest = this.#buffer;
|
||||
this.#buffer = '';
|
||||
return rest;
|
||||
}
|
||||
|
||||
// =========== helpers ===========
|
||||
|
||||
private peek(pos: number): string | undefined {
|
||||
return pos < this.#buffer.length ? this.#buffer[pos] : undefined;
|
||||
}
|
||||
|
||||
private tryParse(pos: number): ParseStatus {
|
||||
const nestedRes = this.tryParseNestedIndex(pos);
|
||||
if (nestedRes) return nestedRes;
|
||||
return this.tryParseBracketPattern(pos);
|
||||
}
|
||||
|
||||
private tryParseNestedIndex(pos: number): ParseStatus | null {
|
||||
if (this.peek(pos + 1) !== '[') return null;
|
||||
|
||||
let i = pos;
|
||||
let bracketCount = 0;
|
||||
|
||||
while (i < this.#buffer.length && this.#buffer[i] === '[') {
|
||||
bracketCount++;
|
||||
i++;
|
||||
}
|
||||
|
||||
if (bracketCount >= 2) {
|
||||
if (i >= this.#buffer.length) {
|
||||
return { kind: 'needMore' };
|
||||
}
|
||||
|
||||
let content = '';
|
||||
while (i < this.#buffer.length && this.#buffer[i] !== ']') {
|
||||
content += this.#buffer[i++];
|
||||
}
|
||||
|
||||
let rightBracketCount = 0;
|
||||
while (i < this.#buffer.length && this.#buffer[i] === ']') {
|
||||
rightBracketCount++;
|
||||
i++;
|
||||
}
|
||||
|
||||
if (i >= this.#buffer.length && rightBracketCount < bracketCount) {
|
||||
return { kind: 'needMore' };
|
||||
}
|
||||
|
||||
if (
|
||||
rightBracketCount === bracketCount &&
|
||||
content.length > 0 &&
|
||||
this.isNumeric(content)
|
||||
) {
|
||||
if (this.peek(i) === '(') {
|
||||
return { kind: 'fail', nextPos: i };
|
||||
}
|
||||
return { kind: 'index', endPos: i, value: Number(content) };
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private tryParseBracketPattern(pos: number): ParseStatus {
|
||||
let i = pos + 1; // skip '['
|
||||
if (i >= this.#buffer.length) {
|
||||
return { kind: 'needMore' };
|
||||
}
|
||||
|
||||
let content = '';
|
||||
while (i < this.#buffer.length && this.#buffer[i] !== ']') {
|
||||
const nextChar = this.#buffer[i];
|
||||
if (nextChar === '[') {
|
||||
return { kind: 'fail', nextPos: i };
|
||||
}
|
||||
content += nextChar;
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if (i >= this.#buffer.length) {
|
||||
return { kind: 'needMore' };
|
||||
}
|
||||
const after = i + 1;
|
||||
const afterChar = this.peek(after);
|
||||
|
||||
if (content.length > 0 && this.isNumeric(content) && afterChar !== '(') {
|
||||
// [number] pattern
|
||||
return { kind: 'index', endPos: after, value: Number(content) };
|
||||
} else if (afterChar !== '(') {
|
||||
// [text](url) pattern
|
||||
return { kind: 'fail', nextPos: after };
|
||||
}
|
||||
|
||||
i = after + 1; // skip '('
|
||||
if (i >= this.#buffer.length) {
|
||||
return { kind: 'needMore' };
|
||||
}
|
||||
|
||||
let url = '';
|
||||
while (i < this.#buffer.length && this.#buffer[i] !== ')') {
|
||||
url += this.#buffer[i++];
|
||||
}
|
||||
if (i >= this.#buffer.length) {
|
||||
return { kind: 'needMore' };
|
||||
}
|
||||
return { kind: 'ok', endPos: i + 1, text: content, url };
|
||||
}
|
||||
|
||||
private isNumeric(str: string): boolean {
|
||||
return !Number.isNaN(Number(str)) && str.trim() !== '';
|
||||
}
|
||||
|
||||
private handlePattern(
|
||||
pattern: Finished | Failed,
|
||||
isWrapped: boolean,
|
||||
start: number,
|
||||
current: number
|
||||
): { output: string; nextPos: number } {
|
||||
if (pattern.kind === 'fail') {
|
||||
return {
|
||||
output: this.#buffer.slice(current, pattern.nextPos),
|
||||
nextPos: pattern.nextPos,
|
||||
};
|
||||
}
|
||||
|
||||
if (isWrapped) {
|
||||
const afterLinkPos = pattern.endPos;
|
||||
if (this.peek(afterLinkPos) !== ')') {
|
||||
if (afterLinkPos >= this.#buffer.length) {
|
||||
return { output: '', nextPos: current };
|
||||
}
|
||||
return { output: '(', nextPos: start };
|
||||
}
|
||||
|
||||
const out =
|
||||
pattern.kind === 'index'
|
||||
? this.callback({ ...pattern, kind: 'index' })
|
||||
: this.callback({ ...pattern, kind: 'wrappedLink' });
|
||||
return { output: out, nextPos: afterLinkPos + 1 };
|
||||
} else {
|
||||
const out =
|
||||
pattern.kind === 'ok'
|
||||
? this.callback({ ...pattern, kind: 'link' })
|
||||
: this.callback({ ...pattern, kind: 'index' });
|
||||
return { output: out, nextPos: pattern.endPos };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class CitationParser {
|
||||
private readonly citations: string[] = [];
|
||||
|
||||
private readonly parser = new StreamPatternParser(p => {
|
||||
switch (p.kind) {
|
||||
case 'index': {
|
||||
if (p.value <= this.citations.length) {
|
||||
return `[^${p.value}]`;
|
||||
}
|
||||
return `[${p.value}]`;
|
||||
}
|
||||
case 'wrappedLink': {
|
||||
const index = this.citations.indexOf(p.url);
|
||||
if (index === -1) {
|
||||
this.citations.push(p.url);
|
||||
return `[^${this.citations.length}]`;
|
||||
}
|
||||
return `[^${index + 1}]`;
|
||||
}
|
||||
case 'link': {
|
||||
return `[${p.text}](${p.url})`;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
public push(citation: string) {
|
||||
this.citations.push(citation);
|
||||
}
|
||||
|
||||
public parse(content: string) {
|
||||
return this.parser.write(content);
|
||||
}
|
||||
|
||||
public end() {
|
||||
return this.parser.end() + '\n' + this.getFootnotes();
|
||||
}
|
||||
|
||||
private getFootnotes() {
|
||||
const footnotes = this.citations.map((citation, index) => {
|
||||
return `[^${index + 1}]: {"type":"url","url":"${encodeURIComponent(
|
||||
citation
|
||||
)}"}`;
|
||||
});
|
||||
return footnotes.join('\n');
|
||||
}
|
||||
}
|
||||
|
||||
export type CitationIndexedEvent = {
|
||||
type CitationIndexedEvent = {
|
||||
type: 'citation';
|
||||
index: number;
|
||||
url: string;
|
||||
@@ -436,7 +146,7 @@ export class CitationFootnoteFormatter {
|
||||
}
|
||||
}
|
||||
|
||||
type ChunkType = TextStreamPart<CustomAITools>['type'];
|
||||
type ChunkType = CopilotTextStreamPart['type'];
|
||||
|
||||
export function toError(error: unknown): Error {
|
||||
if (typeof error === 'string') {
|
||||
@@ -458,6 +168,14 @@ type DocEditFootnote = {
|
||||
intent: string;
|
||||
result: string;
|
||||
};
|
||||
|
||||
function asRecord(value: unknown): Record<string, unknown> | null {
|
||||
if (value && typeof value === 'object' && !Array.isArray(value)) {
|
||||
return value as Record<string, unknown>;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export class TextStreamParser {
|
||||
private readonly logger = new Logger(TextStreamParser.name);
|
||||
private readonly CALLOUT_PREFIX = '\n[!]\n';
|
||||
@@ -468,7 +186,7 @@ export class TextStreamParser {
|
||||
|
||||
private readonly docEditFootnotes: DocEditFootnote[] = [];
|
||||
|
||||
public parse(chunk: TextStreamPart<CustomAITools>) {
|
||||
public parse(chunk: CopilotTextStreamPart) {
|
||||
let result = '';
|
||||
switch (chunk.type) {
|
||||
case 'text-delta': {
|
||||
@@ -517,7 +235,7 @@ export class TextStreamParser {
|
||||
}
|
||||
case 'doc_edit': {
|
||||
this.docEditFootnotes.push({
|
||||
intent: chunk.input.instructions,
|
||||
intent: String(chunk.input.instructions ?? ''),
|
||||
result: '',
|
||||
});
|
||||
break;
|
||||
@@ -533,14 +251,12 @@ export class TextStreamParser {
|
||||
result = this.addPrefix(result);
|
||||
switch (chunk.toolName) {
|
||||
case 'doc_edit': {
|
||||
const array =
|
||||
chunk.output && typeof chunk.output === 'object'
|
||||
? chunk.output.result
|
||||
: undefined;
|
||||
const output = asRecord(chunk.output);
|
||||
const array = output?.result;
|
||||
if (Array.isArray(array)) {
|
||||
result += array
|
||||
.map(item => {
|
||||
return `\n${item.changedContent}\n`;
|
||||
return `\n${String(asRecord(item)?.changedContent ?? '')}\n`;
|
||||
})
|
||||
.join('');
|
||||
this.docEditFootnotes[this.docEditFootnotes.length - 1].result =
|
||||
@@ -557,8 +273,11 @@ export class TextStreamParser {
|
||||
} else if (typeof output === 'string') {
|
||||
result += `\n${output}\n`;
|
||||
} else {
|
||||
const message = asRecord(output)?.message;
|
||||
this.logger.warn(
|
||||
`Unexpected result type for doc_semantic_search: ${output?.message || 'Unknown error'}`
|
||||
`Unexpected result type for doc_semantic_search: ${
|
||||
typeof message === 'string' ? message : 'Unknown error'
|
||||
}`
|
||||
);
|
||||
}
|
||||
break;
|
||||
@@ -572,9 +291,11 @@ export class TextStreamParser {
|
||||
break;
|
||||
}
|
||||
case 'doc_compose': {
|
||||
const output = chunk.output;
|
||||
if (output && typeof output === 'object' && 'title' in output) {
|
||||
result += `\nDocument "${output.title}" created successfully with ${output.wordCount} words.\n`;
|
||||
const output = asRecord(chunk.output);
|
||||
if (output && typeof output.title === 'string') {
|
||||
result += `\nDocument "${output.title}" created successfully with ${String(
|
||||
output.wordCount ?? 0
|
||||
)} words.\n`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -654,7 +375,7 @@ export class TextStreamParser {
|
||||
}
|
||||
|
||||
export class StreamObjectParser {
|
||||
public parse(chunk: TextStreamPart<CustomAITools>) {
|
||||
public parse(chunk: CopilotTextStreamPart) {
|
||||
switch (chunk.type) {
|
||||
case 'reasoning-delta': {
|
||||
return { type: 'reasoning' as const, textDelta: chunk.text };
|
||||
@@ -747,9 +468,7 @@ function normalizeUrl(baseURL?: string) {
|
||||
}
|
||||
}
|
||||
|
||||
export function getVertexAnthropicBaseUrl(
|
||||
options: GoogleVertexAnthropicProviderSettings
|
||||
) {
|
||||
export function getVertexAnthropicBaseUrl(options: VertexProviderConfig) {
|
||||
const normalizedBaseUrl = normalizeUrl(options.baseURL);
|
||||
if (normalizedBaseUrl) return normalizedBaseUrl;
|
||||
const { location, project } = options;
|
||||
@@ -758,7 +477,7 @@ export function getVertexAnthropicBaseUrl(
|
||||
}
|
||||
|
||||
export async function getGoogleAuth(
|
||||
options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings,
|
||||
options: VertexProviderConfig,
|
||||
publisher: 'anthropic' | 'google'
|
||||
) {
|
||||
function getBaseUrl() {
|
||||
@@ -777,7 +496,7 @@ export async function getGoogleAuth(
|
||||
}
|
||||
const auth = new GoogleAuth({
|
||||
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
|
||||
...(options.googleAuthOptions as GoogleAuthOptions),
|
||||
...options.googleAuthOptions,
|
||||
});
|
||||
const client = await auth.getClient();
|
||||
const token = await client.getAccessToken();
|
||||
|
||||
@@ -31,6 +31,7 @@ import { SubscriptionPlan, SubscriptionStatus } from '../payment/types';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { ChatPrompt } from './prompt/chat-prompt';
|
||||
import { PromptService } from './prompt/service';
|
||||
import { promptAttachmentHasSource } from './providers/attachments';
|
||||
import { CopilotProviderFactory } from './providers/factory';
|
||||
import { buildProviderRegistry } from './providers/provider-registry';
|
||||
import {
|
||||
@@ -38,6 +39,7 @@ import {
|
||||
type PromptMessage,
|
||||
type PromptParams,
|
||||
} from './providers/types';
|
||||
import { promptAttachmentToUrl } from './providers/utils';
|
||||
import {
|
||||
type ChatHistory,
|
||||
type ChatMessage,
|
||||
@@ -272,11 +274,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
lastMessage.attachments || [],
|
||||
]
|
||||
.flat()
|
||||
.filter(v =>
|
||||
typeof v === 'string'
|
||||
? !!v.trim()
|
||||
: v && v.attachment.trim() && v.mimeType
|
||||
);
|
||||
.filter(v => promptAttachmentHasSource(v));
|
||||
//insert all previous user message content before first user message
|
||||
finished.splice(firstUserMessageIndex, 0, ...messages);
|
||||
|
||||
@@ -466,8 +464,8 @@ export class ChatSessionService {
|
||||
messages: preload.concat(messages).map(m => ({
|
||||
...m,
|
||||
attachments: m.attachments
|
||||
?.map(a => (typeof a === 'string' ? a : a.attachment))
|
||||
.filter(a => !!a),
|
||||
?.map(a => promptAttachmentToUrl(a))
|
||||
.filter((a): a is string => !!a),
|
||||
})),
|
||||
};
|
||||
} else {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { ContextSession, CopilotChatOptions } from './types';
|
||||
|
||||
const logger = new Logger('ContextBlobReadTool');
|
||||
@@ -18,7 +18,10 @@ export const buildBlobContentGetter = (
|
||||
chunk?: number
|
||||
) => {
|
||||
if (!options?.user || !options?.workspace || !blobId || !context) {
|
||||
return;
|
||||
return toolError(
|
||||
'Blob Read Failed',
|
||||
'Missing workspace, user, blob id, or copilot context for blob_read.'
|
||||
);
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
@@ -29,7 +32,10 @@ export const buildBlobContentGetter = (
|
||||
logger.warn(
|
||||
`User ${options.user} does not have access workspace ${options.workspace}`
|
||||
);
|
||||
return;
|
||||
return toolError(
|
||||
'Blob Read Failed',
|
||||
'You do not have permission to access this workspace attachment.'
|
||||
);
|
||||
}
|
||||
|
||||
const contextFile = context.files.find(
|
||||
@@ -42,7 +48,12 @@ export const buildBlobContentGetter = (
|
||||
context.getBlobContent(canonicalBlobId, chunk),
|
||||
]);
|
||||
const content = file?.trim() || blob?.trim();
|
||||
if (!content) return;
|
||||
if (!content) {
|
||||
return toolError(
|
||||
'Blob Read Failed',
|
||||
`Attachment ${canonicalBlobId} is not available for reading in the current copilot context.`
|
||||
);
|
||||
}
|
||||
const info = contextFile
|
||||
? { fileName: contextFile.name, fileType: contextFile.mimeType }
|
||||
: {};
|
||||
@@ -53,12 +64,9 @@ export const buildBlobContentGetter = (
|
||||
};
|
||||
|
||||
export const createBlobReadTool = (
|
||||
getBlobContent: (
|
||||
targetId?: string,
|
||||
chunk?: number
|
||||
) => Promise<object | undefined>
|
||||
getBlobContent: (targetId?: string, chunk?: number) => Promise<object>
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Return the content and basic metadata of a single attachment identified by blobId; more inclined to use search tools rather than this tool.',
|
||||
inputSchema: z.object({
|
||||
@@ -73,13 +81,10 @@ export const createBlobReadTool = (
|
||||
execute: async ({ blob_id, chunk }) => {
|
||||
try {
|
||||
const blob = await getBlobContent(blob_id, chunk);
|
||||
if (!blob) {
|
||||
return;
|
||||
}
|
||||
return { ...blob };
|
||||
} catch (err: any) {
|
||||
logger.error(`Failed to read the blob ${blob_id} in context`, err);
|
||||
return toolError('Blob Read Failed', err.message);
|
||||
return toolError('Blob Read Failed', err.message ?? String(err));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotProviderFactory, PromptService } from './types';
|
||||
|
||||
const logger = new Logger('CodeArtifactTool');
|
||||
@@ -16,7 +16,7 @@ export const createCodeArtifactTool = (
|
||||
promptService: PromptService,
|
||||
factory: CopilotProviderFactory
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Generate a single-file HTML snippet (with inline <style> and <script>) that accomplishes the requested functionality. The final HTML should be runnable when saved as an .html file and opened in a browser. Do NOT reference external resources (CSS, JS, images) except through data URIs.',
|
||||
inputSchema: z.object({
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotProviderFactory, PromptService } from './types';
|
||||
|
||||
const logger = new Logger('ConversationSummaryTool');
|
||||
@@ -12,7 +12,7 @@ export const createConversationSummaryTool = (
|
||||
promptService: PromptService,
|
||||
factory: CopilotProviderFactory
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Create a concise, AI-generated summary of the conversation so far—capturing key topics, decisions, and critical details. Use this tool whenever the context becomes lengthy to preserve essential information that might otherwise be lost to truncation in future turns.',
|
||||
inputSchema: z.object({
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotProviderFactory, PromptService } from './types';
|
||||
|
||||
const logger = new Logger('DocComposeTool');
|
||||
@@ -11,7 +11,7 @@ export const createDocComposeTool = (
|
||||
promptService: PromptService,
|
||||
factory: CopilotProviderFactory
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Write a new document with markdown content. This tool creates structured markdown content for documents including titles, sections, and formatting.',
|
||||
inputSchema: z.object({
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { DocReader } from '../../../core/doc';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { defineTool } from './tool';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotProviderFactory,
|
||||
@@ -50,7 +50,7 @@ export const createDocEditTool = (
|
||||
prompt: PromptService,
|
||||
getContent: (targetId?: string) => Promise<string | undefined>
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description: `
|
||||
Use this tool to propose an edit to a structured Markdown document with identifiable blocks.
|
||||
Each block begins with a comment like <!-- block_id=... -->, and represents a unit of editable content such as a heading, paragraph, list, or code snippet.
|
||||
|
||||
@@ -1,27 +1,43 @@
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { AccessController } from '../../../core/permission';
|
||||
import type { Models } from '../../../models';
|
||||
import type { IndexerService, SearchDoc } from '../../indexer';
|
||||
import { workspaceSyncRequiredError } from './doc-sync';
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotChatOptions } from './types';
|
||||
|
||||
export const buildDocKeywordSearchGetter = (
|
||||
ac: AccessController,
|
||||
indexerService: IndexerService
|
||||
indexerService: IndexerService,
|
||||
models: Models
|
||||
) => {
|
||||
const searchDocs = async (options: CopilotChatOptions, query?: string) => {
|
||||
if (!options || !query?.trim() || !options.user || !options.workspace) {
|
||||
return undefined;
|
||||
const queryTrimmed = query?.trim();
|
||||
if (!options || !queryTrimmed || !options.user || !options.workspace) {
|
||||
return toolError(
|
||||
'Doc Keyword Search Failed',
|
||||
'Missing workspace, user, or query for doc_keyword_search.'
|
||||
);
|
||||
}
|
||||
const workspace = await models.workspace.get(options.workspace);
|
||||
if (!workspace) {
|
||||
return workspaceSyncRequiredError();
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.can('Workspace.Read');
|
||||
if (!canAccess) return undefined;
|
||||
if (!canAccess) {
|
||||
return toolError(
|
||||
'Doc Keyword Search Failed',
|
||||
'You do not have permission to access this workspace.'
|
||||
);
|
||||
}
|
||||
const docs = await indexerService.searchDocsByKeyword(
|
||||
options.workspace,
|
||||
query
|
||||
queryTrimmed
|
||||
);
|
||||
|
||||
// filter current user readable docs
|
||||
@@ -29,15 +45,17 @@ export const buildDocKeywordSearchGetter = (
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.docs(docs, 'Doc.Read');
|
||||
return readableDocs;
|
||||
return readableDocs ?? [];
|
||||
};
|
||||
return searchDocs;
|
||||
};
|
||||
|
||||
export const createDocKeywordSearchTool = (
|
||||
searchDocs: (query: string) => Promise<SearchDoc[] | undefined>
|
||||
searchDocs: (
|
||||
query: string
|
||||
) => Promise<SearchDoc[] | ReturnType<typeof toolError>>
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Fuzzy search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based or keyword-base lookup is sufficient.',
|
||||
inputSchema: z.object({
|
||||
@@ -50,8 +68,8 @@ export const createDocKeywordSearchTool = (
|
||||
execute: async ({ query }) => {
|
||||
try {
|
||||
const docs = await searchDocs(query);
|
||||
if (!docs) {
|
||||
return;
|
||||
if (!Array.isArray(docs)) {
|
||||
return docs;
|
||||
}
|
||||
return docs.map(doc => ({
|
||||
docId: doc.docId,
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { DocReader } from '../../../core/doc';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { Models, publicUserSelect } from '../../../models';
|
||||
import { toolError } from './error';
|
||||
import { Models } from '../../../models';
|
||||
import {
|
||||
documentSyncPendingError,
|
||||
workspaceSyncRequiredError,
|
||||
} from './doc-sync';
|
||||
import { type ToolError, toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotChatOptions } from './types';
|
||||
|
||||
const logger = new Logger('DocReadTool');
|
||||
|
||||
const isToolError = (result: ToolError | object): result is ToolError =>
|
||||
'type' in result && result.type === 'error';
|
||||
|
||||
export const buildDocContentGetter = (
|
||||
ac: AccessController,
|
||||
docReader: DocReader,
|
||||
@@ -17,8 +24,17 @@ export const buildDocContentGetter = (
|
||||
) => {
|
||||
const getDoc = async (options: CopilotChatOptions, docId?: string) => {
|
||||
if (!options?.user || !options?.workspace || !docId) {
|
||||
return;
|
||||
return toolError(
|
||||
'Doc Read Failed',
|
||||
'Missing workspace, user, or document id for doc_read.'
|
||||
);
|
||||
}
|
||||
|
||||
const workspace = await models.workspace.get(options.workspace);
|
||||
if (!workspace) {
|
||||
return workspaceSyncRequiredError();
|
||||
}
|
||||
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
@@ -28,23 +44,15 @@ export const buildDocContentGetter = (
|
||||
logger.warn(
|
||||
`User ${options.user} does not have access to doc ${docId} in workspace ${options.workspace}`
|
||||
);
|
||||
return;
|
||||
return toolError(
|
||||
'Doc Read Failed',
|
||||
`You do not have permission to read document ${docId} in this workspace.`
|
||||
);
|
||||
}
|
||||
|
||||
const docMeta = await models.doc.getSnapshot(options.workspace, docId, {
|
||||
select: {
|
||||
createdAt: true,
|
||||
updatedAt: true,
|
||||
createdByUser: {
|
||||
select: publicUserSelect,
|
||||
},
|
||||
updatedByUser: {
|
||||
select: publicUserSelect,
|
||||
},
|
||||
},
|
||||
});
|
||||
const docMeta = await models.doc.getAuthors(options.workspace, docId);
|
||||
if (!docMeta) {
|
||||
return;
|
||||
return documentSyncPendingError(docId);
|
||||
}
|
||||
|
||||
const content = await docReader.getDocMarkdown(
|
||||
@@ -53,7 +61,7 @@ export const buildDocContentGetter = (
|
||||
true
|
||||
);
|
||||
if (!content) {
|
||||
return;
|
||||
return documentSyncPendingError(docId);
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -69,10 +77,14 @@ export const buildDocContentGetter = (
|
||||
return getDoc;
|
||||
};
|
||||
|
||||
type DocReadToolResult = Awaited<
|
||||
ReturnType<ReturnType<typeof buildDocContentGetter>>
|
||||
>;
|
||||
|
||||
export const createDocReadTool = (
|
||||
getDoc: (targetId?: string) => Promise<object | undefined>
|
||||
getDoc: (targetId?: string) => Promise<DocReadToolResult>
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Return the complete text and basic metadata of a single document identified by docId; use this when the user needs the full content of a specific file rather than a search result.',
|
||||
inputSchema: z.object({
|
||||
@@ -81,13 +93,10 @@ export const createDocReadTool = (
|
||||
execute: async ({ doc_id }) => {
|
||||
try {
|
||||
const doc = await getDoc(doc_id);
|
||||
if (!doc) {
|
||||
return;
|
||||
}
|
||||
return { ...doc };
|
||||
return isToolError(doc) ? doc : { ...doc };
|
||||
} catch (err: any) {
|
||||
logger.error(`Failed to read the doc ${doc_id}`, err);
|
||||
return toolError('Doc Read Failed', err.message);
|
||||
return toolError('Doc Read Failed', err.message ?? String(err));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { tool } from 'ai';
|
||||
import { omit } from 'lodash-es';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -8,7 +7,9 @@ import {
|
||||
clearEmbeddingChunk,
|
||||
type Models,
|
||||
} from '../../../models';
|
||||
import { workspaceSyncRequiredError } from './doc-sync';
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type {
|
||||
ContextSession,
|
||||
CopilotChatOptions,
|
||||
@@ -24,20 +25,30 @@ export const buildDocSearchGetter = (
|
||||
const searchDocs = async (
|
||||
options: CopilotChatOptions,
|
||||
query?: string,
|
||||
abortSignal?: AbortSignal
|
||||
signal?: AbortSignal
|
||||
) => {
|
||||
if (!options || !query?.trim() || !options.user || !options.workspace) {
|
||||
return `Invalid search parameters.`;
|
||||
return toolError(
|
||||
'Doc Semantic Search Failed',
|
||||
'Missing workspace, user, or query for doc_semantic_search.'
|
||||
);
|
||||
}
|
||||
const workspace = await models.workspace.get(options.workspace);
|
||||
if (!workspace) {
|
||||
return workspaceSyncRequiredError();
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.can('Workspace.Read');
|
||||
if (!canAccess)
|
||||
return 'You do not have permission to access this workspace.';
|
||||
return toolError(
|
||||
'Doc Semantic Search Failed',
|
||||
'You do not have permission to access this workspace.'
|
||||
);
|
||||
const [chunks, contextChunks] = await Promise.all([
|
||||
context.matchWorkspaceAll(options.workspace, query, 10, abortSignal),
|
||||
docContext?.matchFiles(query, 10, abortSignal) ?? [],
|
||||
context.matchWorkspaceAll(options.workspace, query, 10, signal),
|
||||
docContext?.matchFiles(query, 10, signal) ?? [],
|
||||
]);
|
||||
|
||||
const docChunks = await ac
|
||||
@@ -53,7 +64,7 @@ export const buildDocSearchGetter = (
|
||||
fileChunks.push(...contextChunks);
|
||||
}
|
||||
if (!blobChunks.length && !docChunks.length && !fileChunks.length) {
|
||||
return `No results found for "${query}".`;
|
||||
return [];
|
||||
}
|
||||
|
||||
const docIds = docChunks.map(c => ({
|
||||
@@ -100,10 +111,10 @@ export const buildDocSearchGetter = (
|
||||
export const createDocSemanticSearchTool = (
|
||||
searchDocs: (
|
||||
query: string,
|
||||
abortSignal?: AbortSignal
|
||||
) => Promise<ChunkSimilarity[] | string | undefined>
|
||||
signal?: AbortSignal
|
||||
) => Promise<ChunkSimilarity[] | ReturnType<typeof toolError>>
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; use this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts, recent documents).',
|
||||
inputSchema: z.object({
|
||||
@@ -115,7 +126,7 @@ export const createDocSemanticSearchTool = (
|
||||
}),
|
||||
execute: async ({ query }, options) => {
|
||||
try {
|
||||
return await searchDocs(query, options.abortSignal);
|
||||
return await searchDocs(query, options.signal);
|
||||
} catch (e: any) {
|
||||
return toolError('Doc Semantic Search Failed', e.message);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
import { toolError } from './error';
|
||||
|
||||
export const LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE =
|
||||
'This workspace is local-only and does not have AFFiNE Cloud sync enabled yet. Ask the user to enable workspace sync, then try again.';
|
||||
|
||||
export const DOCUMENT_SYNC_PENDING_MESSAGE = (docId: string) =>
|
||||
`Document ${docId} is not available on AFFiNE Cloud yet. Ask the user to wait for workspace sync to finish, then try again.`;
|
||||
|
||||
export const workspaceSyncRequiredError = () =>
|
||||
toolError('Workspace Sync Required', LOCAL_WORKSPACE_SYNC_REQUIRED_MESSAGE);
|
||||
|
||||
export const documentSyncPendingError = (docId: string) =>
|
||||
toolError('Document Sync Pending', DOCUMENT_SYNC_PENDING_MESSAGE(docId));
|
||||
@@ -1,10 +1,10 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { DocWriter } from '../../../core/doc';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotChatOptions } from './types';
|
||||
|
||||
const logger = new Logger('DocWriteTool');
|
||||
@@ -141,7 +141,7 @@ export const buildDocUpdateMetaHandler = (
|
||||
export const createDocCreateTool = (
|
||||
createDoc: (title: string, content: string) => Promise<object>
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Create a new document in the workspace with the given title and markdown content. Returns the ID of the created document. This tool not support insert or update database block and image yet.',
|
||||
inputSchema: z.object({
|
||||
@@ -164,7 +164,7 @@ export const createDocCreateTool = (
|
||||
export const createDocUpdateTool = (
|
||||
updateDoc: (docId: string, content: string) => Promise<object>
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Update an existing document with new markdown content (body only). Uses structural diffing to apply minimal changes. This does NOT update the document title. This tool not support insert or update database block and image yet.',
|
||||
inputSchema: z.object({
|
||||
@@ -189,7 +189,7 @@ export const createDocUpdateTool = (
|
||||
export const createDocUpdateMetaTool = (
|
||||
updateDocMeta: (docId: string, title: string) => Promise<object>
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description: 'Update document metadata (currently title only).',
|
||||
inputSchema: z.object({
|
||||
doc_id: z.string().describe('The ID of the document to update'),
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { tool } from 'ai';
|
||||
import Exa from 'exa-js';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { Config } from '../../../base';
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
|
||||
export const createExaCrawlTool = (config: Config) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description: 'Crawl the web url for information',
|
||||
inputSchema: z.object({
|
||||
url: z
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import { tool } from 'ai';
|
||||
import Exa from 'exa-js';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { Config } from '../../../base';
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
|
||||
export const createExaSearchTool = (config: Config) => {
|
||||
return tool({
|
||||
description: 'Search the web for information',
|
||||
return defineTool({
|
||||
description:
|
||||
'Search the web using Exa, one of the best web search APIs for AI',
|
||||
inputSchema: z.object({
|
||||
query: z.string().describe('The query to search the web for.'),
|
||||
mode: z
|
||||
|
||||
@@ -1,39 +1,3 @@
|
||||
import { ToolSet } from 'ai';
|
||||
|
||||
import { createBlobReadTool } from './blob-read';
|
||||
import { createCodeArtifactTool } from './code-artifact';
|
||||
import { createConversationSummaryTool } from './conversation-summary';
|
||||
import { createDocComposeTool } from './doc-compose';
|
||||
import { createDocEditTool } from './doc-edit';
|
||||
import { createDocKeywordSearchTool } from './doc-keyword-search';
|
||||
import { createDocReadTool } from './doc-read';
|
||||
import { createDocSemanticSearchTool } from './doc-semantic-search';
|
||||
import {
|
||||
createDocCreateTool,
|
||||
createDocUpdateMetaTool,
|
||||
createDocUpdateTool,
|
||||
} from './doc-write';
|
||||
import { createExaCrawlTool } from './exa-crawl';
|
||||
import { createExaSearchTool } from './exa-search';
|
||||
import { createSectionEditTool } from './section-edit';
|
||||
|
||||
export interface CustomAITools extends ToolSet {
|
||||
blob_read: ReturnType<typeof createBlobReadTool>;
|
||||
code_artifact: ReturnType<typeof createCodeArtifactTool>;
|
||||
conversation_summary: ReturnType<typeof createConversationSummaryTool>;
|
||||
doc_edit: ReturnType<typeof createDocEditTool>;
|
||||
doc_semantic_search: ReturnType<typeof createDocSemanticSearchTool>;
|
||||
doc_keyword_search: ReturnType<typeof createDocKeywordSearchTool>;
|
||||
doc_read: ReturnType<typeof createDocReadTool>;
|
||||
doc_create: ReturnType<typeof createDocCreateTool>;
|
||||
doc_update: ReturnType<typeof createDocUpdateTool>;
|
||||
doc_update_meta: ReturnType<typeof createDocUpdateMetaTool>;
|
||||
doc_compose: ReturnType<typeof createDocComposeTool>;
|
||||
section_edit: ReturnType<typeof createSectionEditTool>;
|
||||
web_search_exa: ReturnType<typeof createExaSearchTool>;
|
||||
web_crawl_exa: ReturnType<typeof createExaCrawlTool>;
|
||||
}
|
||||
|
||||
export * from './blob-read';
|
||||
export * from './code-artifact';
|
||||
export * from './conversation-summary';
|
||||
@@ -47,3 +11,4 @@ export * from './error';
|
||||
export * from './exa-crawl';
|
||||
export * from './exa-search';
|
||||
export * from './section-edit';
|
||||
export * from './tool';
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { tool } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotProviderFactory, PromptService } from './types';
|
||||
|
||||
const logger = new Logger('SectionEditTool');
|
||||
@@ -11,7 +11,7 @@ export const createSectionEditTool = (
|
||||
promptService: PromptService,
|
||||
factory: CopilotProviderFactory
|
||||
) => {
|
||||
return tool({
|
||||
return defineTool({
|
||||
description:
|
||||
'Intelligently edit and modify a specific section of a document based on user instructions, with full document context awareness. This tool can refine, rewrite, translate, restructure, or enhance any part of markdown content while preserving formatting, maintaining contextual coherence, and ensuring consistency with the entire document. Perfect for targeted improvements that consider the broader document context.',
|
||||
inputSchema: z.object({
|
||||
|
||||
33
packages/backend/server/src/plugins/copilot/tools/tool.ts
Normal file
33
packages/backend/server/src/plugins/copilot/tools/tool.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import type { ZodTypeAny } from 'zod';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { PromptMessage } from '../providers/types';
|
||||
|
||||
export type CopilotToolExecuteOptions = {
|
||||
signal?: AbortSignal;
|
||||
messages?: PromptMessage[];
|
||||
};
|
||||
|
||||
export type CopilotTool = {
|
||||
description?: string;
|
||||
inputSchema?: ZodTypeAny | Record<string, unknown>;
|
||||
execute?: {
|
||||
bivarianceHack: (
|
||||
args: Record<string, unknown>,
|
||||
options: CopilotToolExecuteOptions
|
||||
) => Promise<unknown> | unknown;
|
||||
}['bivarianceHack'];
|
||||
};
|
||||
|
||||
export type CopilotToolSet = Record<string, CopilotTool>;
|
||||
|
||||
export function defineTool<TSchema extends ZodTypeAny, TResult>(tool: {
|
||||
description?: string;
|
||||
inputSchema: TSchema;
|
||||
execute: (
|
||||
args: z.infer<TSchema>,
|
||||
options: CopilotToolExecuteOptions
|
||||
) => Promise<TResult> | TResult;
|
||||
}): CopilotTool {
|
||||
return tool;
|
||||
}
|
||||
@@ -224,11 +224,10 @@ export class CopilotTranscriptionService {
|
||||
const config = Object.assign({}, prompt.config);
|
||||
if (schema) {
|
||||
const provider = await this.getProvider(prompt.model, true, prefer);
|
||||
return provider.structure(
|
||||
cond,
|
||||
[...prompt.finish({ schema }), msg],
|
||||
config
|
||||
);
|
||||
return provider.structure(cond, [...prompt.finish({}), msg], {
|
||||
...config,
|
||||
schema,
|
||||
});
|
||||
} else {
|
||||
const provider = await this.getProvider(prompt.model, false);
|
||||
return provider.text(cond, [...prompt.finish({}), msg], config);
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, OnModuleDestroy } from '@nestjs/common';
|
||||
import { createRemoteJWKSet, type JWTPayload, jwtVerify } from 'jose';
|
||||
import { omit } from 'lodash-es';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
ExponentialBackoffScheduler,
|
||||
InvalidAuthState,
|
||||
InvalidOauthResponse,
|
||||
URLHelper,
|
||||
@@ -35,7 +36,7 @@ const OIDCUserInfoSchema = z
|
||||
.object({
|
||||
sub: z.string(),
|
||||
preferred_username: z.string().optional(),
|
||||
email: z.string().email(),
|
||||
email: z.string().optional(),
|
||||
name: z.string().optional(),
|
||||
email_verified: z
|
||||
.union([z.boolean(), z.enum(['true', 'false', '1', '0', 'yes', 'no'])])
|
||||
@@ -44,6 +45,8 @@ const OIDCUserInfoSchema = z
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
const OIDCEmailSchema = z.string().email();
|
||||
|
||||
const OIDCConfigurationSchema = z.object({
|
||||
authorization_endpoint: z.string().url(),
|
||||
token_endpoint: z.string().url(),
|
||||
@@ -54,16 +57,28 @@ const OIDCConfigurationSchema = z.object({
|
||||
|
||||
type OIDCConfiguration = z.infer<typeof OIDCConfigurationSchema>;
|
||||
|
||||
const OIDC_DISCOVERY_INITIAL_RETRY_DELAY = 1000;
|
||||
const OIDC_DISCOVERY_MAX_RETRY_DELAY = 60_000;
|
||||
|
||||
@Injectable()
|
||||
export class OIDCProvider extends OAuthProvider {
|
||||
export class OIDCProvider extends OAuthProvider implements OnModuleDestroy {
|
||||
override provider = OAuthProviderName.OIDC;
|
||||
#endpoints: OIDCConfiguration | null = null;
|
||||
#jwks: ReturnType<typeof createRemoteJWKSet> | null = null;
|
||||
readonly #retryScheduler = new ExponentialBackoffScheduler({
|
||||
baseDelayMs: OIDC_DISCOVERY_INITIAL_RETRY_DELAY,
|
||||
maxDelayMs: OIDC_DISCOVERY_MAX_RETRY_DELAY,
|
||||
});
|
||||
#validationGeneration = 0;
|
||||
|
||||
constructor(private readonly url: URLHelper) {
|
||||
super();
|
||||
}
|
||||
|
||||
onModuleDestroy() {
|
||||
this.#retryScheduler.clear();
|
||||
}
|
||||
|
||||
override get requiresPkce() {
|
||||
return true;
|
||||
}
|
||||
@@ -87,58 +102,109 @@ export class OIDCProvider extends OAuthProvider {
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
const validate = async () => {
|
||||
this.#endpoints = null;
|
||||
this.#jwks = null;
|
||||
const generation = ++this.#validationGeneration;
|
||||
this.#retryScheduler.clear();
|
||||
|
||||
if (super.configured) {
|
||||
const config = this.config as OAuthOIDCProviderConfig;
|
||||
if (!config.issuer) {
|
||||
this.logger.error('Missing OIDC issuer configuration');
|
||||
super.setup();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${config.issuer}/.well-known/openid-configuration`,
|
||||
{
|
||||
method: 'GET',
|
||||
headers: { Accept: 'application/json' },
|
||||
}
|
||||
);
|
||||
|
||||
if (res.ok) {
|
||||
const configuration = OIDCConfigurationSchema.parse(
|
||||
await res.json()
|
||||
);
|
||||
if (
|
||||
this.normalizeIssuer(config.issuer) !==
|
||||
this.normalizeIssuer(configuration.issuer)
|
||||
) {
|
||||
this.logger.error(
|
||||
`OIDC issuer mismatch, expected ${config.issuer}, got ${configuration.issuer}`
|
||||
);
|
||||
} else {
|
||||
this.#endpoints = configuration;
|
||||
this.#jwks = createRemoteJWKSet(new URL(configuration.jwks_uri));
|
||||
}
|
||||
} else {
|
||||
this.logger.error(`Invalid OIDC issuer ${config.issuer}`);
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.error('Failed to validate OIDC configuration', e);
|
||||
}
|
||||
}
|
||||
|
||||
super.setup();
|
||||
};
|
||||
|
||||
validate().catch(() => {
|
||||
this.validateAndSync(generation).catch(() => {
|
||||
/* noop */
|
||||
});
|
||||
}
|
||||
|
||||
private async validateAndSync(generation: number) {
|
||||
if (generation !== this.#validationGeneration) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!super.configured) {
|
||||
this.resetState();
|
||||
this.#retryScheduler.reset();
|
||||
super.setup();
|
||||
return;
|
||||
}
|
||||
|
||||
const config = this.config as OAuthOIDCProviderConfig;
|
||||
if (!config.issuer) {
|
||||
this.logger.error('Missing OIDC issuer configuration');
|
||||
this.resetState();
|
||||
this.#retryScheduler.reset();
|
||||
super.setup();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${config.issuer}/.well-known/openid-configuration`,
|
||||
{
|
||||
method: 'GET',
|
||||
headers: { Accept: 'application/json' },
|
||||
}
|
||||
);
|
||||
|
||||
if (generation !== this.#validationGeneration) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!res.ok) {
|
||||
this.logger.error(`Invalid OIDC issuer ${config.issuer}`);
|
||||
this.onValidationFailure(generation);
|
||||
return;
|
||||
}
|
||||
|
||||
const configuration = OIDCConfigurationSchema.parse(await res.json());
|
||||
if (
|
||||
this.normalizeIssuer(config.issuer) !==
|
||||
this.normalizeIssuer(configuration.issuer)
|
||||
) {
|
||||
this.logger.error(
|
||||
`OIDC issuer mismatch, expected ${config.issuer}, got ${configuration.issuer}`
|
||||
);
|
||||
this.onValidationFailure(generation);
|
||||
return;
|
||||
}
|
||||
|
||||
this.#endpoints = configuration;
|
||||
this.#jwks = createRemoteJWKSet(new URL(configuration.jwks_uri));
|
||||
this.#retryScheduler.reset();
|
||||
super.setup();
|
||||
} catch (e) {
|
||||
if (generation !== this.#validationGeneration) {
|
||||
return;
|
||||
}
|
||||
this.logger.error('Failed to validate OIDC configuration', e);
|
||||
this.onValidationFailure(generation);
|
||||
}
|
||||
}
|
||||
|
||||
private onValidationFailure(generation: number) {
|
||||
this.resetState();
|
||||
super.setup();
|
||||
this.scheduleRetry(generation);
|
||||
}
|
||||
|
||||
private scheduleRetry(generation: number) {
|
||||
if (generation !== this.#validationGeneration) {
|
||||
return;
|
||||
}
|
||||
|
||||
const delay = this.#retryScheduler.schedule(() => {
|
||||
this.validateAndSync(generation).catch(() => {
|
||||
/* noop */
|
||||
});
|
||||
});
|
||||
if (delay === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.logger.warn(
|
||||
`OIDC discovery validation failed, retrying in ${delay}ms`
|
||||
);
|
||||
}
|
||||
|
||||
private resetState() {
|
||||
this.#endpoints = null;
|
||||
this.#jwks = null;
|
||||
}
|
||||
|
||||
getAuthUrl(state: string): string {
|
||||
const parsedState = this.parseStatePayload(state);
|
||||
const nonce = parsedState?.state ?? state;
|
||||
@@ -291,6 +357,68 @@ export class OIDCProvider extends OAuthProvider {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private claimCandidates(
|
||||
configuredClaim: string | undefined,
|
||||
defaultClaim: string
|
||||
) {
|
||||
if (typeof configuredClaim === 'string' && configuredClaim.length > 0) {
|
||||
return [configuredClaim];
|
||||
}
|
||||
return [defaultClaim];
|
||||
}
|
||||
|
||||
private formatClaimCandidates(claims: string[]) {
|
||||
return claims.map(claim => `"${claim}"`).join(', ');
|
||||
}
|
||||
|
||||
private resolveStringClaim(
|
||||
claims: string[],
|
||||
...sources: Array<Record<string, unknown>>
|
||||
) {
|
||||
for (const claim of claims) {
|
||||
for (const source of sources) {
|
||||
const value = this.extractString(source[claim]);
|
||||
if (value) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private resolveBooleanClaim(
|
||||
claims: string[],
|
||||
...sources: Array<Record<string, unknown>>
|
||||
) {
|
||||
for (const claim of claims) {
|
||||
for (const source of sources) {
|
||||
const value = this.extractBoolean(source[claim]);
|
||||
if (value !== undefined) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private resolveEmailClaim(
|
||||
claims: string[],
|
||||
...sources: Array<Record<string, unknown>>
|
||||
) {
|
||||
for (const claim of claims) {
|
||||
for (const source of sources) {
|
||||
const value = this.extractString(source[claim]);
|
||||
if (value && OIDCEmailSchema.safeParse(value).success) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async getUser(tokens: Tokens, state: OAuthState): Promise<OAuthAccount> {
|
||||
if (!tokens.idToken) {
|
||||
throw new InvalidOauthResponse({
|
||||
@@ -315,6 +443,8 @@ export class OIDCProvider extends OAuthProvider {
|
||||
{ treatServerErrorAsInvalid: true }
|
||||
);
|
||||
const user = OIDCUserInfoSchema.parse(rawUser);
|
||||
const userClaims = user as Record<string, unknown>;
|
||||
const idTokenClaimsRecord = idTokenClaims as Record<string, unknown>;
|
||||
|
||||
if (!user.sub || !idTokenClaims.sub) {
|
||||
throw new InvalidOauthResponse({
|
||||
@@ -327,22 +457,29 @@ export class OIDCProvider extends OAuthProvider {
|
||||
}
|
||||
|
||||
const args = this.config.args ?? {};
|
||||
const idClaims = this.claimCandidates(args.claim_id, 'sub');
|
||||
const emailClaims = this.claimCandidates(args.claim_email, 'email');
|
||||
const nameClaims = this.claimCandidates(args.claim_name, 'name');
|
||||
const emailVerifiedClaims = this.claimCandidates(
|
||||
args.claim_email_verified,
|
||||
'email_verified'
|
||||
);
|
||||
|
||||
const claimsMap = {
|
||||
id: args.claim_id || 'sub',
|
||||
email: args.claim_email || 'email',
|
||||
name: args.claim_name || 'name',
|
||||
emailVerified: args.claim_email_verified || 'email_verified',
|
||||
};
|
||||
|
||||
const accountId =
|
||||
this.extractString(user[claimsMap.id]) ?? idTokenClaims.sub;
|
||||
const email =
|
||||
this.extractString(user[claimsMap.email]) ||
|
||||
this.extractString(idTokenClaims.email);
|
||||
const emailVerified =
|
||||
this.extractBoolean(user[claimsMap.emailVerified]) ??
|
||||
this.extractBoolean(idTokenClaims.email_verified);
|
||||
const accountId = this.resolveStringClaim(
|
||||
idClaims,
|
||||
userClaims,
|
||||
idTokenClaimsRecord
|
||||
);
|
||||
const email = this.resolveEmailClaim(
|
||||
emailClaims,
|
||||
userClaims,
|
||||
idTokenClaimsRecord
|
||||
);
|
||||
const emailVerified = this.resolveBooleanClaim(
|
||||
emailVerifiedClaims,
|
||||
userClaims,
|
||||
idTokenClaimsRecord
|
||||
);
|
||||
|
||||
if (!accountId) {
|
||||
throw new InvalidOauthResponse({
|
||||
@@ -352,7 +489,7 @@ export class OIDCProvider extends OAuthProvider {
|
||||
|
||||
if (!email) {
|
||||
throw new InvalidOauthResponse({
|
||||
reason: 'Missing required claim for email',
|
||||
reason: `Missing valid email claim in OIDC response. Tried userinfo and ID token claims: ${this.formatClaimCandidates(emailClaims)}`,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -367,9 +504,11 @@ export class OIDCProvider extends OAuthProvider {
|
||||
email,
|
||||
};
|
||||
|
||||
const name =
|
||||
this.extractString(user[claimsMap.name]) ||
|
||||
this.extractString(idTokenClaims.name);
|
||||
const name = this.resolveStringClaim(
|
||||
nameClaims,
|
||||
userClaims,
|
||||
idTokenClaimsRecord
|
||||
);
|
||||
if (name) {
|
||||
account.name = name;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ interface TestOps extends OpSchema {
|
||||
add: [{ a: number; b: number }, number];
|
||||
bin: [Uint8Array, Uint8Array];
|
||||
sub: [Uint8Array, number];
|
||||
init: [{ fastText?: boolean } | undefined, { ok: true }];
|
||||
}
|
||||
|
||||
declare module 'vitest' {
|
||||
@@ -84,6 +85,55 @@ describe('op client', () => {
|
||||
expect(data.byteLength).toBe(0);
|
||||
});
|
||||
|
||||
it('should send optional payload call with abort signal', async ctx => {
|
||||
const abortController = new AbortController();
|
||||
const result = ctx.producer.call(
|
||||
'init',
|
||||
{ fastText: true },
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
expect(ctx.postMessage.mock.calls[0][0]).toMatchInlineSnapshot(`
|
||||
{
|
||||
"id": "init:1",
|
||||
"name": "init",
|
||||
"payload": {
|
||||
"fastText": true,
|
||||
},
|
||||
"type": "call",
|
||||
}
|
||||
`);
|
||||
|
||||
ctx.handlers.return({
|
||||
type: 'return',
|
||||
id: 'init:1',
|
||||
data: { ok: true },
|
||||
});
|
||||
|
||||
await expect(result).resolves.toEqual({ ok: true });
|
||||
});
|
||||
|
||||
it('should send undefined payload for optional input call', async ctx => {
|
||||
const result = ctx.producer.call('init', undefined);
|
||||
|
||||
expect(ctx.postMessage.mock.calls[0][0]).toMatchInlineSnapshot(`
|
||||
{
|
||||
"id": "init:1",
|
||||
"name": "init",
|
||||
"payload": undefined,
|
||||
"type": "call",
|
||||
}
|
||||
`);
|
||||
|
||||
ctx.handlers.return({
|
||||
type: 'return',
|
||||
id: 'init:1',
|
||||
data: { ok: true },
|
||||
});
|
||||
|
||||
await expect(result).resolves.toEqual({ ok: true });
|
||||
});
|
||||
|
||||
it('should cancel call', async ctx => {
|
||||
const promise = ctx.producer.call('add', { a: 1, b: 2 });
|
||||
|
||||
|
||||
@@ -40,18 +40,14 @@ describe('op consumer', () => {
|
||||
it('should throw if no handler registered', async ctx => {
|
||||
ctx.handlers.call({ type: 'call', id: 'add:1', name: 'add', payload: {} });
|
||||
await vi.advanceTimersToNextTimerAsync();
|
||||
expect(ctx.postMessage.mock.lastCall).toMatchInlineSnapshot(`
|
||||
[
|
||||
{
|
||||
"error": {
|
||||
"message": "Handler for operation [add] is not registered.",
|
||||
"name": "Error",
|
||||
},
|
||||
"id": "add:1",
|
||||
"type": "return",
|
||||
},
|
||||
]
|
||||
`);
|
||||
expect(ctx.postMessage.mock.lastCall?.[0]).toMatchObject({
|
||||
type: 'return',
|
||||
id: 'add:1',
|
||||
error: {
|
||||
message: 'Handler for operation [add] is not registered.',
|
||||
name: 'Error',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle call message', async ctx => {
|
||||
@@ -73,6 +69,38 @@ describe('op consumer', () => {
|
||||
`);
|
||||
});
|
||||
|
||||
it('should serialize string errors with message', async ctx => {
|
||||
ctx.consumer.register('any', () => {
|
||||
throw 'worker panic';
|
||||
});
|
||||
|
||||
ctx.handlers.call({ type: 'call', id: 'any:1', name: 'any', payload: {} });
|
||||
await vi.advanceTimersToNextTimerAsync();
|
||||
|
||||
expect(ctx.postMessage.mock.calls[0][0]).toMatchObject({
|
||||
type: 'return',
|
||||
id: 'any:1',
|
||||
error: {
|
||||
name: 'Error',
|
||||
message: 'worker panic',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should serialize plain object errors with fallback message', async ctx => {
|
||||
ctx.consumer.register('any', () => {
|
||||
throw { reason: 'panic', code: 'E_PANIC' };
|
||||
});
|
||||
|
||||
ctx.handlers.call({ type: 'call', id: 'any:1', name: 'any', payload: {} });
|
||||
await vi.advanceTimersToNextTimerAsync();
|
||||
|
||||
const message = ctx.postMessage.mock.calls[0][0]?.error?.message;
|
||||
expect(typeof message).toBe('string');
|
||||
expect(message).toContain('"reason":"panic"');
|
||||
expect(message).toContain('"code":"E_PANIC"');
|
||||
});
|
||||
|
||||
it('should handle cancel message', async ctx => {
|
||||
ctx.consumer.register('add', ({ a, b }, { signal }) => {
|
||||
const { reject, resolve, promise } = Promise.withResolvers<number>();
|
||||
|
||||
@@ -16,6 +16,96 @@ import {
|
||||
} from './message';
|
||||
import type { OpInput, OpNames, OpOutput, OpSchema } from './types';
|
||||
|
||||
const SERIALIZABLE_ERROR_FIELDS = [
|
||||
'name',
|
||||
'message',
|
||||
'code',
|
||||
'type',
|
||||
'status',
|
||||
'data',
|
||||
'stacktrace',
|
||||
] as const;
|
||||
|
||||
type SerializableErrorShape = Partial<
|
||||
Record<(typeof SERIALIZABLE_ERROR_FIELDS)[number], unknown>
|
||||
> & {
|
||||
name?: string;
|
||||
message?: string;
|
||||
};
|
||||
|
||||
function getFallbackErrorMessage(error: unknown): string {
|
||||
if (typeof error === 'string') {
|
||||
return error;
|
||||
}
|
||||
|
||||
if (error instanceof Error && error.message) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
if (
|
||||
typeof error === 'number' ||
|
||||
typeof error === 'boolean' ||
|
||||
typeof error === 'bigint' ||
|
||||
typeof error === 'symbol'
|
||||
) {
|
||||
return String(error);
|
||||
}
|
||||
|
||||
if (error === null || error === undefined) {
|
||||
return 'Unknown error';
|
||||
}
|
||||
|
||||
try {
|
||||
const jsonMessage = JSON.stringify(error);
|
||||
if (jsonMessage && jsonMessage !== '{}') {
|
||||
return jsonMessage;
|
||||
}
|
||||
} catch {
|
||||
return 'Unknown error';
|
||||
}
|
||||
|
||||
return 'Unknown error';
|
||||
}
|
||||
|
||||
function serializeError(error: unknown): Error {
|
||||
const valueToPick =
|
||||
error && typeof error === 'object'
|
||||
? error
|
||||
: ({} as Record<string, unknown>);
|
||||
const serialized = pick(
|
||||
valueToPick,
|
||||
SERIALIZABLE_ERROR_FIELDS
|
||||
) as SerializableErrorShape;
|
||||
|
||||
if (!serialized.message || typeof serialized.message !== 'string') {
|
||||
serialized.message = getFallbackErrorMessage(error);
|
||||
}
|
||||
|
||||
if (!serialized.name || typeof serialized.name !== 'string') {
|
||||
if (error instanceof Error && error.name) {
|
||||
serialized.name = error.name;
|
||||
} else if (error && typeof error === 'object') {
|
||||
const constructorName = error.constructor?.name;
|
||||
serialized.name =
|
||||
typeof constructorName === 'string' && constructorName.length > 0
|
||||
? constructorName
|
||||
: 'Error';
|
||||
} else {
|
||||
serialized.name = 'Error';
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
!serialized.stacktrace &&
|
||||
error instanceof Error &&
|
||||
typeof error.stack === 'string'
|
||||
) {
|
||||
serialized.stacktrace = error.stack;
|
||||
}
|
||||
|
||||
return serialized as Error;
|
||||
}
|
||||
|
||||
interface OpCallContext {
|
||||
signal: AbortSignal;
|
||||
}
|
||||
@@ -71,15 +161,7 @@ export class OpConsumer<Ops extends OpSchema> extends AutoMessageHandler {
|
||||
this.port.postMessage({
|
||||
type: 'return',
|
||||
id: msg.id,
|
||||
error: pick(error, [
|
||||
'name',
|
||||
'message',
|
||||
'code',
|
||||
'type',
|
||||
'status',
|
||||
'data',
|
||||
'stacktrace',
|
||||
]),
|
||||
error: serializeError(error),
|
||||
} satisfies ReturnMessage);
|
||||
},
|
||||
complete: () => {
|
||||
@@ -109,15 +191,7 @@ export class OpConsumer<Ops extends OpSchema> extends AutoMessageHandler {
|
||||
this.port.postMessage({
|
||||
type: 'error',
|
||||
id: msg.id,
|
||||
error: pick(error, [
|
||||
'name',
|
||||
'message',
|
||||
'code',
|
||||
'type',
|
||||
'status',
|
||||
'data',
|
||||
'stacktrace',
|
||||
]),
|
||||
error: serializeError(error),
|
||||
} satisfies SubscriptionErrorMessage);
|
||||
},
|
||||
complete: () => {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user