From 29a27b561b3635e6759ba9747673c0643c9e344c Mon Sep 17 00:00:00 2001 From: DarkSky <25152247+darkskygit@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:55:35 +0800 Subject: [PATCH] feat(server): migrate copilot to native (#14620) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #### PR Dependency Tree * **PR #14620** 👈 This tree was auto-generated by [Charcoal](https://github.com/danerwilliams/charcoal) ## Summary by CodeRabbit * **New Features** * Native LLM workflows: structured outputs, embeddings, and reranking plus richer multimodal attachments (images, audio, files) and improved remote-attachment inlining. * **Refactor** * Tooling API unified behind a local tool-definition helper; provider/adapters reorganized to route through native dispatch paths. * **Chores** * Dependency updates, removed legacy Google SDK integrations, and increased front memory allocation. * **Tests** * Expanded end-to-end and streaming tests exercising native provider flows, attachments, and rerank/structured scenarios. --- .cargo/config.toml | 5 + .github/helm/affine/charts/front/values.yaml | 4 +- Cargo.lock | 485 +----- Cargo.toml | 3 +- packages/backend/native/Cargo.toml | 5 +- .../backend/native/fixtures/audio-only.mka | Bin 0 -> 3385 bytes .../backend/native/fixtures/audio-only.webm | Bin 0 -> 1023 bytes .../backend/native/fixtures/audio-video.webm | Bin 0 -> 1269 bytes packages/backend/native/index.d.ts | 6 + packages/backend/native/src/file_type.rs | 74 +- packages/backend/native/src/llm.rs | 83 +- packages/backend/server/package.json | 3 - .../copilot/copilot-provider.spec.ts | 80 +- .../src/__tests__/copilot/copilot.spec.ts | 197 +-- .../__tests__/copilot/native-adapter.spec.ts | 210 --- .../__tests__/copilot/native-provider.spec.ts | 1431 +++++++++++++++++ .../__tests__/copilot/provider-native.spec.ts | 83 +- .../__tests__/copilot/tool-call-loop.spec.ts | 172 +- .../src/__tests__/copilot/utils.spec.ts | 72 +- .../src/__tests__/mocks/copilot.mock.ts | 2 +- .../server/src/models/copilot-session.ts | 106 +- packages/backend/server/src/native.ts | 165 +- .../server/src/plugins/copilot/config.ts | 3 +- .../src/plugins/copilot/embedding/client.ts | 61 +- .../src/plugins/copilot/prompt/prompts.ts | 15 - .../copilot/providers/anthropic/anthropic.ts | 75 +- .../copilot/providers/anthropic/official.ts | 4 + .../copilot/providers/anthropic/vertex.ts | 20 +- .../plugins/copilot/providers/attachments.ts | 233 +++ .../src/plugins/copilot/providers/fal.ts | 16 +- .../copilot/providers/gemini/gemini.ts | 693 +++++--- .../copilot/providers/gemini/generative.ts | 37 +- .../copilot/providers/gemini/vertex.ts | 50 +- .../src/plugins/copilot/providers/loop.ts | 174 +- .../src/plugins/copilot/providers/morph.ts | 41 +- .../src/plugins/copilot/providers/native.ts | 295 +++- .../src/plugins/copilot/providers/openai.ts | 517 +++--- .../plugins/copilot/providers/perplexity.ts | 31 +- .../copilot/providers/provider-middleware.ts | 12 +- .../copilot/providers/provider-registry.ts | 9 +- .../src/plugins/copilot/providers/provider.ts | 238 ++- .../src/plugins/copilot/providers/rerank.ts | 23 - .../src/plugins/copilot/providers/types.ts | 122 +- .../src/plugins/copilot/providers/utils.ts | 483 ++---- .../server/src/plugins/copilot/session.ts | 12 +- .../src/plugins/copilot/tools/blob-read.ts | 4 +- .../plugins/copilot/tools/code-artifact.ts | 4 +- .../copilot/tools/conversation-summary.ts | 4 +- .../src/plugins/copilot/tools/doc-compose.ts | 4 +- .../src/plugins/copilot/tools/doc-edit.ts | 4 +- .../copilot/tools/doc-keyword-search.ts | 4 +- .../src/plugins/copilot/tools/doc-read.ts | 4 +- .../copilot/tools/doc-semantic-search.ts | 14 +- .../src/plugins/copilot/tools/doc-write.ts | 8 +- .../src/plugins/copilot/tools/exa-crawl.ts | 4 +- .../src/plugins/copilot/tools/exa-search.ts | 4 +- .../server/src/plugins/copilot/tools/index.ts | 37 +- .../src/plugins/copilot/tools/section-edit.ts | 4 +- .../server/src/plugins/copilot/tools/tool.ts | 33 + .../src/plugins/copilot/transcript/service.ts | 9 +- .../e2e/settings/embedding.spec.ts | 17 +- yarn.lock | 147 +- 62 files changed, 4359 insertions(+), 2296 deletions(-) create mode 100644 packages/backend/native/fixtures/audio-only.mka create mode 100644 packages/backend/native/fixtures/audio-only.webm create mode 100644 packages/backend/native/fixtures/audio-video.webm delete mode 100644 packages/backend/server/src/__tests__/copilot/native-adapter.spec.ts create mode 100644 packages/backend/server/src/__tests__/copilot/native-provider.spec.ts create mode 100644 packages/backend/server/src/plugins/copilot/providers/attachments.ts delete mode 100644 packages/backend/server/src/plugins/copilot/providers/rerank.ts create mode 100644 packages/backend/server/src/plugins/copilot/tools/tool.ts diff --git a/.cargo/config.toml b/.cargo/config.toml index 6f84a73999..a3b2334182 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -19,3 +19,8 @@ rustflags = [ # pthread_key_create() destructors and segfault after a DSO unloading [target.'cfg(all(target_env = "gnu", not(target_os = "windows")))'] rustflags = ["-C", "link-args=-Wl,-z,nodelete"] + +# Temporary local llm_adapter override. +# Uncomment when verifying AFFiNE against the sibling llm_adapter workspace. +# [patch.crates-io] +# llm_adapter = { path = "../llm_adapter" } diff --git a/.github/helm/affine/charts/front/values.yaml b/.github/helm/affine/charts/front/values.yaml index 08933d27d1..cc4ac4b6bb 100644 --- a/.github/helm/affine/charts/front/values.yaml +++ b/.github/helm/affine/charts/front/values.yaml @@ -31,10 +31,10 @@ podSecurityContext: resources: limits: cpu: '1' - memory: 4Gi + memory: 6Gi requests: cpu: '1' - memory: 2Gi + memory: 4Gi probe: initialDelaySeconds: 20 diff --git a/Cargo.lock b/Cargo.lock index 33fa44a7a5..9593b0ccee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -186,6 +186,7 @@ dependencies = [ "libwebp-sys", "little_exif", "llm_adapter", + "matroska", "mimalloc", "mp4parse", "napi", @@ -480,12 +481,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "atomic-waker" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" - [[package]] name = "auto_enums" version = "0.8.7" @@ -504,28 +499,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" -[[package]] -name = "aws-lc-rs" -version = "1.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" -dependencies = [ - "aws-lc-sys", - "zeroize", -] - -[[package]] -name = "aws-lc-sys" -version = "0.38.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" -dependencies = [ - "cc", - "cmake", - "dunce", - "fs_extra", -] - [[package]] name = "base64" version = "0.22.1" @@ -649,6 +622,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "bitstream-io" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "680575de65ce8b916b82a447458b94a48776707d9c2681a9d8da351c06886a1f" +dependencies = [ + "core2", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -981,15 +963,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" -[[package]] -name = "cmake" -version = "0.1.57" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" -dependencies = [ - "cc", -] - [[package]] name = "color_quant" version = "1.1.0" @@ -1534,12 +1507,6 @@ version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5" -[[package]] -name = "dunce" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" - [[package]] name = "ecb" version = "0.1.2" @@ -1771,12 +1738,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "fs_extra" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" - [[package]] name = "futf" version = "0.1.5" @@ -1941,11 +1902,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", - "js-sys", "libc", "r-efi", "wasip2", - "wasm-bindgen", ] [[package]] @@ -2138,95 +2097,12 @@ dependencies = [ "itoa", ] -[[package]] -name = "http-body" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" -dependencies = [ - "bytes", - "http", -] - -[[package]] -name = "http-body-util" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" -dependencies = [ - "bytes", - "futures-core", - "http", - "http-body", - "pin-project-lite", -] - [[package]] name = "httparse" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" -[[package]] -name = "hyper" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" -dependencies = [ - "atomic-waker", - "bytes", - "futures-channel", - "futures-core", - "http", - "http-body", - "httparse", - "itoa", - "pin-project-lite", - "pin-utils", - "smallvec", - "tokio", - "want", -] - -[[package]] -name = "hyper-rustls" -version = "0.27.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" -dependencies = [ - "http", - "hyper", - "hyper-util", - "rustls", - "rustls-pki-types", - "tokio", - "tokio-rustls", - "tower-service", -] - -[[package]] -name = "hyper-util" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" -dependencies = [ - "base64", - "bytes", - "futures-channel", - "futures-util", - "http", - "http-body", - "hyper", - "ipnet", - "libc", - "percent-encoding", - "pin-project-lite", - "socket2", - "tokio", - "tower-service", - "tracing", -] - [[package]] name = "iana-time-zone" version = "0.1.64" @@ -2505,22 +2381,6 @@ dependencies = [ "leaky-cow", ] -[[package]] -name = "ipnet" -version = "2.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" - -[[package]] -name = "iri-string" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "is-terminal" version = "0.4.17" @@ -2813,15 +2673,15 @@ dependencies = [ [[package]] name = "llm_adapter" -version = "0.1.1" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dd9a548766bccf8b636695e8d514edee672d180e96a16ab932c971783b4e353" +checksum = "e98485dda5180cc89b993a001688bed93307be6bd8fedcde445b69bbca4f554d" dependencies = [ "base64", - "reqwest", "serde", "serde_json", "thiserror 2.0.17", + "ureq", ] [[package]] @@ -2889,12 +2749,6 @@ dependencies = [ "hashbrown 0.16.1", ] -[[package]] -name = "lru-slab" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" - [[package]] name = "mac" version = "0.1.1" @@ -2954,6 +2808,16 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matroska" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde85cd7fb5cf875c4a46fac0cbd6567d413bea2538cef6788e3a0e52a902b45" +dependencies = [ + "bitstream-io", + "phf 0.11.3", +] + [[package]] name = "md-5" version = "0.10.6" @@ -3396,12 +3260,6 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" -[[package]] -name = "openssl-probe" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" - [[package]] name = "ordered-float" version = "5.1.0" @@ -3884,62 +3742,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "quinn" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" -dependencies = [ - "bytes", - "cfg_aliases", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash 2.1.1", - "rustls", - "socket2", - "thiserror 2.0.17", - "tokio", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-proto" -version = "0.11.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" -dependencies = [ - "aws-lc-rs", - "bytes", - "getrandom 0.3.4", - "lru-slab", - "rand 0.9.2", - "ring", - "rustc-hash 2.1.1", - "rustls", - "rustls-pki-types", - "slab", - "thiserror 2.0.17", - "tinyvec", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-udp" -version = "0.5.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" -dependencies = [ - "cfg_aliases", - "libc", - "once_cell", - "socket2", - "tracing", - "windows-sys 0.60.2", -] - [[package]] name = "quote" version = "1.0.43" @@ -4128,45 +3930,6 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" -[[package]] -name = "reqwest" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" -dependencies = [ - "base64", - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-rustls", - "hyper-util", - "js-sys", - "log", - "percent-encoding", - "pin-project-lite", - "quinn", - "rustls", - "rustls-pki-types", - "rustls-platform-verifier", - "serde", - "serde_json", - "sync_wrapper", - "tokio", - "tokio-rustls", - "tower", - "tower-http", - "tower-service", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "ring" version = "0.17.14" @@ -4307,7 +4070,7 @@ version = "0.23.36" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" dependencies = [ - "aws-lc-rs", + "log", "once_cell", "ring", "rustls-pki-types", @@ -4316,62 +4079,21 @@ dependencies = [ "zeroize", ] -[[package]] -name = "rustls-native-certs" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" -dependencies = [ - "openssl-probe", - "rustls-pki-types", - "schannel", - "security-framework", -] - [[package]] name = "rustls-pki-types" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" dependencies = [ - "web-time", "zeroize", ] -[[package]] -name = "rustls-platform-verifier" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" -dependencies = [ - "core-foundation", - "core-foundation-sys", - "jni", - "log", - "once_cell", - "rustls", - "rustls-native-certs", - "rustls-platform-verifier-android", - "rustls-webpki", - "security-framework", - "security-framework-sys", - "webpki-root-certs", - "windows-sys 0.61.2", -] - -[[package]] -name = "rustls-platform-verifier-android" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" - [[package]] name = "rustls-webpki" version = "0.103.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" dependencies = [ - "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -4410,15 +4132,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schannel" -version = "0.1.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "scoped-tls" version = "1.0.1" @@ -4467,29 +4180,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "security-framework" -version = "3.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" -dependencies = [ - "bitflags 2.11.0", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "1.0.27" @@ -5215,15 +4905,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "sync_wrapper" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" -dependencies = [ - "futures-core", -] - [[package]] name = "synstructure" version = "0.13.2" @@ -5415,16 +5096,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "tokio-rustls" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" -dependencies = [ - "rustls", - "tokio", -] - [[package]] name = "tokio-stream" version = "0.1.18" @@ -5475,51 +5146,6 @@ dependencies = [ "winnow", ] -[[package]] -name = "tower" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" -dependencies = [ - "futures-core", - "futures-util", - "pin-project-lite", - "sync_wrapper", - "tokio", - "tower-layer", - "tower-service", -] - -[[package]] -name = "tower-http" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" -dependencies = [ - "bitflags 2.11.0", - "bytes", - "futures-util", - "http", - "http-body", - "iri-string", - "pin-project-lite", - "tower", - "tower-layer", - "tower-service", -] - -[[package]] -name = "tower-layer" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" - -[[package]] -name = "tower-service" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" - [[package]] name = "tracing" version = "0.1.44" @@ -5722,12 +5348,6 @@ dependencies = [ "tree-sitter-language", ] -[[package]] -name = "try-lock" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" - [[package]] name = "type1-encoding-parser" version = "0.1.0" @@ -5952,6 +5572,35 @@ version = "0.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" +[[package]] +name = "ureq" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc" +dependencies = [ + "base64", + "flate2", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "ureq-proto", + "utf-8", + "webpki-roots 1.0.5", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ -6048,15 +5697,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "want" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" -dependencies = [ - "try-lock", -] - [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -6146,25 +5786,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "web-time" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "webpki-root-certs" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "webpki-roots" version = "0.26.11" diff --git a/Cargo.toml b/Cargo.toml index da26cba572..a547b0f89a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,10 +53,11 @@ resolver = "3" libc = "0.2" libwebp-sys = "0.14.2" little_exif = "0.6.23" - llm_adapter = "0.1.1" + llm_adapter = { version = "0.1.3", default-features = false } log = "0.4" loom = { version = "0.7", features = ["checkpoint"] } lru = "0.16" + matroska = "0.30" memory-indexer = "0.3.0" mimalloc = "0.1" mp4parse = "0.17" diff --git a/packages/backend/native/Cargo.toml b/packages/backend/native/Cargo.toml index caaddf88c6..8309dc13a9 100644 --- a/packages/backend/native/Cargo.toml +++ b/packages/backend/native/Cargo.toml @@ -21,7 +21,10 @@ image = { workspace = true } infer = { workspace = true } libwebp-sys = { workspace = true } little_exif = { workspace = true } -llm_adapter = { workspace = true } +llm_adapter = { workspace = true, default-features = false, features = [ + "ureq-client", +] } +matroska = { workspace = true } mp4parse = { workspace = true } napi = { workspace = true, features = ["async"] } napi-derive = { workspace = true } diff --git a/packages/backend/native/fixtures/audio-only.mka b/packages/backend/native/fixtures/audio-only.mka new file mode 100644 index 0000000000000000000000000000000000000000..0d7e50b25b5275c26f9c893702be8a60673f3530 GIT binary patch literal 3385 zcmeH~e^gRw8pp4$u0JYc%{=9COkICxg@U_MG^vP5f-WMd8MBKDWJaSB8QE-NC}^4n z%7{aH6w^XwM>9N0r3dYY<>)p}k}{4bHCac~4AZH$*!NQB%+CM&-*yjh&U-J<^FGh# z`F`JXt_yxrp@=BrTSTmZ?@PW_gpqF>Q6OT|Pvm^b$)raJ!L35RN8l1C4IydHQ(6R@30IkvR(}15m2-kB6RegHqczu4kY}u)Of|t^i&8xs4HFr2JAM{Uq~G!8=hqgOJF?@ZT&sjTHy} z^3C{X-^$~^9aITVHw5j9EEa{O_L|Fr`+40Xs3P9csy-HdRL8gA+!0kzAO6w#?fWD- zAM(5|KgT_uF5*K)Md5#qB8KD6@hzWTtm9i;7ll0&<-?g$)Ah9JuG{ z2suKKwHsu8)4o6K1q5x?AsjBR3r*<0OJh9KeRcGt66P0nwnrfSX{m?9sY~n4XWf_K z6+(P0E-n%uA1VB8C!8PBde5i*S{mzq-GBdC3@D@l3PC|I26yH$dK&9~JzgP-q{fFQ z5aMDZg%TukxB0U(CTbnLHVX%l5HVuECgBNta=pt7x^}JGURN=Mlr){D2E_J;UtZp6C1_p11PVV zj;-l`_FzfsQ7DYR&=)NkG$0!#6*ICvd9V;8sit_zd{|!oMVy(JJbxdKn3*n)#}Uxl zy7e>IilHo-eD+LbFb{?G75AYFC1NpeV8DmTv&Ug1C=|*&*d zViv}SN$~bUONjn2hHwMx4R*32)&R~|g3|H=QPIg5e_y#=%w$zGqGVv>V2Vfcr48VDaCsCgm=n zHI}nQ?_mlcsQW$&3V(X2@BGAQ#DbQ&w)- zO-?c6or=THHzzZ?Xl!XRb>FJ7FsR=tx!D%s9zI(`XCR(kb}?LO(gdpKRq6*>3Yf{ROz#FRvIO^(n`!qGG?d6Bg>KiAV33FqFfDDA#O5F@krun!GFZpO@AE8aR54mT*>c*ZE!rxv18> zd#Rr=C?o3in*i0r9o_S-9m=y4hIN@DzrI-PvgO9m5qt0FhsdgvC9NHLmHJ_zvhFPp zv~C|bA^_-&BxHvu=le3w4-_-_6d3%d<2?kK>6$%dCBlxZ^M0OsQFXc`v*Y3MZv8Q3 z&ECPM#B3{R#=TniuyaF5Z8XQREZ5Z3vO_)JrtY3o{5Trig;i_-;BB$SQh?@;N6ONi zN&`o$y05IE%xGe<#(p>6T`Ae&=%d=amB+NM6mxI6HFpo{#K}c)@?lWC$HIY4w@=s1 zcj!U<_NMaOqzB{S-qjgkF7C!MWW(Xd9=n(RrF_UXgFj-*m0wkPDI7wL{UYo%(eLZ_ zBe8D*8aVoxO1W=hdR^LmzrMH+yF1~kI~8Z%ZBB$$F$zSrXe_+m??i8QI|W3MQRukZ z=BL|(KbQ}bjv6d zueL|)_TodvV9lzjB)8eDrqF^I^z<-1_APWTOmURPGy=xX9o5rLJX*}$no;EWn!82Q-y)LMg7{-Zj|&lU4gJB#}Kj0r&RVYej{rWComp#pXTR61o z_WP0+I9k75S-y5k`{X4+Gi=_y5HY*MZ*m`=3af$(MD_OOQ+BrX7T8bbv*1?Cfsi3^ z7Nl@T#~4)<)0c4T_?j%frIlMo6HoIm4{ukWE-vL9wW8Xq>NF}RGE04`5k&cGW1Aa49R#M=Xm9KjQlU$RNru-To R>M8**Ny`NmFJJKO{sScE*|7is literal 0 HcmV?d00001 diff --git a/packages/backend/native/fixtures/audio-only.webm b/packages/backend/native/fixtures/audio-only.webm new file mode 100644 index 0000000000000000000000000000000000000000..92cfe023b10c1c506ec5c470c69b481b4a68f40b GIT binary patch literal 1023 zcmb1gy}x+AQ(GgW({~{L)X3uWxsk)EsiizMDc7kT$Zc(8k_c`{XJh~Y=JSHSvrBgQ z?(Pm=-6HC_GA(#b<3b2eEM#3akax{@cMnibDCBT@@R}w@2MAXtcsEe8iFxN6h&c?A zVB=^(fLygf>-xrKekTXth8~~9vNSUzJ!3sX1A}l!6s}8WhePrKup2;*f4UCpwwUXU zjKxPg*14QlG3mHsHK&n5Y3rQk(!7+8MuxUF$9VsM(BQCD&B=NT!gjZ?9NEIq+{pBB z;`~O&tsM@}3&7?jr8Ej8FPiLMP+IJfnwY}KXu;U98!Q6~m_NnE^OFxWq@S4Sy13od z&)MI_HOQqM#cjpK*OL!)CnH>ueqs@_;;BG?B?J9sXbAF{LQZBiFq-Bg5hYjSLJ86-^I=mwgfCu`qh8YZbc5 z!_V^KE_vey=ZCCCe>nr|Z$D0o{ONb_KVQG-(p3x#19FTzesg7TtnF6kj8$8_uaQBt zp<==M1w89|uU)+GOt9Xysrr;&@VBSSj?A~vKjb`Z``bnHs%1{`G4sBkpdeQKe6ra) z*N45*4SkDa#WyeB-pHWYP%(RPob#*h1rm$f?%Vud{x|#KX2$N{cO&i=H5ACCPI$UV zI6u86I_!rz$F;9p7tS$U*L{5P&PE1XB#qs3Ph5+>mSO0bv2WdvMy|6ir#`N*edm1T zleO_&2EpUep1U{9Q!q_sPX;7FJA4@ zPFw4-?$YwBzwf1%JxP2w>#p|ZSn)oP=2Rrj-K$=OecWgt&$zoS+EIyb%iH-^D|Y;} z)x9<2%B|0tLWfU3c)%U~F>~sga9gWKpwOyB(%ji{(A&Y~_KrE1X9!9xm74ndZIt!D z&PUG_|E%}VTXf!L|E>RePML3)e6>k>B@DED!QO>i8yR{TD(0-c67~Fc@mqyb{TaHw zH}eYB-mw^M`r}o5=7-mfv?bG5Onm2aFL&R=Bbx&w-prY}taMScY5D_D7M!@de{&-xrKekTXth8~~9vNSUzJ!3sX1A}l!6s}8WhePrKup2;*cUZd)>NxxB zjf}-dJESM}+-i8_6F;YsL22up=F+^Bjz)&I*0A`n0E^~EM&-xNElf#K3=bx3XcXAh zC@`y$DRe_i=*~t)m^m@m8<~LSoS%Ez>wJH01gbf0j`98hp}}FRnv?Ywgzau&IkJVJ zxsmDN#QBYkTRR+{7l0j>l+q}WylAq2L20o^YGMi_qXlEbZmV(2#y& zrt9K%S3hTe7uO({b`)P07hg+0(4CC%PWp+3$cm=`Lns*-LWYK*AX3Q5Oe!m=0ITV6 z2@P@#@$~m|X`gIhUFpXJ3}^2QC$4_S-;at7Alew-Bf)9>Ja zzJAfAs~8pr!U0Qzi_r27zCyDQ7 z-PPV4E8YikP5>6?0L@HAGP8TttFVt7?c*7Dw?#WD@ojlK|7yjKpSHTUW?Z@TIaBEH z=?4$Eqd#U&T@!9=^#~NUl}MU9TMl|VxZK_`=kg3eiKS9gf4_~g{@3~Fnc|=I{&|be z+w8yfU(YG??UJuHNw0)~Zd$N+;npUGUWSS}Yp+B-zg_%Rp;Ui{Ztu;!LbZ1+Mw|Y4 z)t>p`bt7%b^c54|`P|Fh_wdN(z=$_s6#FX-( G(G>vwtoAGb literal 0 HcmV?d00001 diff --git a/packages/backend/native/index.d.ts b/packages/backend/native/index.d.ts index 8e08b09456..7fe7ecf820 100644 --- a/packages/backend/native/index.d.ts +++ b/packages/backend/native/index.d.ts @@ -54,6 +54,12 @@ export declare function llmDispatch(protocol: string, backendConfigJson: string, export declare function llmDispatchStream(protocol: string, backendConfigJson: string, requestJson: string, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle +export declare function llmEmbeddingDispatch(protocol: string, backendConfigJson: string, requestJson: string): string + +export declare function llmRerankDispatch(protocol: string, backendConfigJson: string, requestJson: string): string + +export declare function llmStructuredDispatch(protocol: string, backendConfigJson: string, requestJson: string): string + /** * Merge updates in form like `Y.applyUpdate(doc, update)` way and return the * result binary. diff --git a/packages/backend/native/src/file_type.rs b/packages/backend/native/src/file_type.rs index 4f87b2412c..dc43ce1ed9 100644 --- a/packages/backend/native/src/file_type.rs +++ b/packages/backend/native/src/file_type.rs @@ -1,3 +1,4 @@ +use matroska::Matroska; use mp4parse::{TrackType, read_mp4}; use napi_derive::napi; @@ -8,7 +9,13 @@ pub fn get_mime(input: &[u8]) -> String { } else { file_format::FileFormat::from_bytes(input).media_type().to_string() }; - if mimetype == "video/mp4" { + if let Some(container) = matroska_container_kind(input).or(match mimetype.as_str() { + "video/webm" | "application/webm" => Some(ContainerKind::WebM), + "video/x-matroska" | "application/x-matroska" => Some(ContainerKind::Matroska), + _ => None, + }) { + detect_matroska_flavor(input, container, &mimetype) + } else if mimetype == "video/mp4" { detect_mp4_flavor(input) } else { mimetype @@ -37,3 +44,68 @@ fn detect_mp4_flavor(input: &[u8]) -> String { Err(_) => "video/mp4".to_string(), } } + +#[derive(Clone, Copy)] +enum ContainerKind { + WebM, + Matroska, +} + +impl ContainerKind { + fn audio_mime(&self) -> &'static str { + match self { + ContainerKind::WebM => "audio/webm", + ContainerKind::Matroska => "audio/x-matroska", + } + } +} + +fn detect_matroska_flavor(input: &[u8], container: ContainerKind, fallback: &str) -> String { + match Matroska::open(std::io::Cursor::new(input)) { + Ok(file) => { + let has_video = file.video_tracks().next().is_some(); + let has_audio = file.audio_tracks().next().is_some(); + if !has_video && has_audio { + container.audio_mime().to_string() + } else { + fallback.to_string() + } + } + Err(_) => fallback.to_string(), + } +} + +fn matroska_container_kind(input: &[u8]) -> Option { + let header = &input[..1024.min(input.len())]; + if header.windows(4).any(|window| window.eq_ignore_ascii_case(b"webm")) { + Some(ContainerKind::WebM) + } else if header.windows(8).any(|window| window.eq_ignore_ascii_case(b"matroska")) { + Some(ContainerKind::Matroska) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const AUDIO_ONLY_WEBM: &[u8] = include_bytes!("../fixtures/audio-only.webm"); + const AUDIO_VIDEO_WEBM: &[u8] = include_bytes!("../fixtures/audio-video.webm"); + const AUDIO_ONLY_MATROSKA: &[u8] = include_bytes!("../fixtures/audio-only.mka"); + + #[test] + fn detects_audio_only_webm_as_audio() { + assert_eq!(get_mime(AUDIO_ONLY_WEBM), "audio/webm"); + } + + #[test] + fn preserves_video_webm() { + assert_eq!(get_mime(AUDIO_VIDEO_WEBM), "video/webm"); + } + + #[test] + fn detects_audio_only_matroska_as_audio() { + assert_eq!(get_mime(AUDIO_ONLY_MATROSKA), "audio/x-matroska"); + } +} diff --git a/packages/backend/native/src/llm.rs b/packages/backend/native/src/llm.rs index 718db3dbdf..26ef80aed2 100644 --- a/packages/backend/native/src/llm.rs +++ b/packages/backend/native/src/llm.rs @@ -5,9 +5,10 @@ use std::sync::{ use llm_adapter::{ backend::{ - BackendConfig, BackendError, BackendProtocol, ReqwestHttpClient, dispatch_request, dispatch_stream_events_with, + BackendConfig, BackendError, BackendProtocol, DefaultHttpClient, dispatch_embedding_request, dispatch_request, + dispatch_rerank_request, dispatch_stream_events_with, dispatch_structured_request, }, - core::{CoreRequest, StreamEvent}, + core::{CoreRequest, EmbeddingRequest, RerankRequest, StreamEvent, StructuredRequest}, middleware::{ MiddlewareConfig, PipelineContext, RequestMiddleware, StreamMiddleware, citation_indexing, clamp_max_tokens, normalize_messages, run_request_middleware_chain, run_stream_middleware_chain, stream_event_normalize, @@ -40,6 +41,20 @@ struct LlmDispatchPayload { middleware: LlmMiddlewarePayload, } +#[derive(Debug, Clone, Deserialize)] +struct LlmStructuredDispatchPayload { + #[serde(flatten)] + request: StructuredRequest, + #[serde(default)] + middleware: LlmMiddlewarePayload, +} + +#[derive(Debug, Clone, Deserialize)] +struct LlmRerankDispatchPayload { + #[serde(flatten)] + request: RerankRequest, +} + #[napi] pub struct LlmStreamHandle { aborted: Arc, @@ -61,7 +76,44 @@ pub fn llm_dispatch(protocol: String, backend_config_json: String, request_json: let request = apply_request_middlewares(payload.request, &payload.middleware)?; let response = - dispatch_request(&ReqwestHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?; + dispatch_request(&DefaultHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?; + + serde_json::to_string(&response).map_err(map_json_error) +} + +#[napi(catch_unwind)] +pub fn llm_structured_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result { + let protocol = parse_protocol(&protocol)?; + let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?; + let payload: LlmStructuredDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?; + let request = apply_structured_request_middlewares(payload.request, &payload.middleware)?; + + let response = dispatch_structured_request(&DefaultHttpClient::default(), &config, protocol, &request) + .map_err(map_backend_error)?; + + serde_json::to_string(&response).map_err(map_json_error) +} + +#[napi(catch_unwind)] +pub fn llm_embedding_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result { + let protocol = parse_protocol(&protocol)?; + let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?; + let request: EmbeddingRequest = serde_json::from_str(&request_json).map_err(map_json_error)?; + + let response = dispatch_embedding_request(&DefaultHttpClient::default(), &config, protocol, &request) + .map_err(map_backend_error)?; + + serde_json::to_string(&response).map_err(map_json_error) +} + +#[napi(catch_unwind)] +pub fn llm_rerank_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result { + let protocol = parse_protocol(&protocol)?; + let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?; + let payload: LlmRerankDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?; + + let response = dispatch_rerank_request(&DefaultHttpClient::default(), &config, protocol, &payload.request) + .map_err(map_backend_error)?; serde_json::to_string(&response).map_err(map_json_error) } @@ -98,7 +150,7 @@ pub fn llm_dispatch_stream( let mut aborted_by_user = false; let mut callback_dispatch_failed = false; - let result = dispatch_stream_events_with(&ReqwestHttpClient::default(), &config, protocol, &request, |event| { + let result = dispatch_stream_events_with(&DefaultHttpClient::default(), &config, protocol, &request, |event| { if aborted_in_worker.load(Ordering::Relaxed) { aborted_by_user = true; return Err(BackendError::Http(STREAM_ABORTED_REASON.to_string())); @@ -155,6 +207,27 @@ fn apply_request_middlewares(request: CoreRequest, middleware: &LlmMiddlewarePay Ok(run_request_middleware_chain(request, &middleware.config, &chain)) } +fn apply_structured_request_middlewares( + request: StructuredRequest, + middleware: &LlmMiddlewarePayload, +) -> Result { + let mut core = request.as_core_request(); + core = apply_request_middlewares(core, middleware)?; + + Ok(StructuredRequest { + model: core.model, + messages: core.messages, + schema: core + .response_schema + .ok_or_else(|| Error::new(Status::InvalidArg, "Structured request schema is required"))?, + max_tokens: core.max_tokens, + temperature: core.temperature, + reasoning: core.reasoning, + strict: request.strict, + response_mime_type: request.response_mime_type, + }) +} + #[derive(Clone)] struct StreamPipeline { chain: Vec, @@ -268,6 +341,7 @@ fn parse_protocol(protocol: &str) -> Result { } "openai_responses" | "openai-responses" | "responses" => Ok(BackendProtocol::OpenaiResponses), "anthropic" | "anthropic_messages" | "anthropic-messages" => Ok(BackendProtocol::AnthropicMessages), + "gemini" | "gemini_generate_content" | "gemini-generate-content" => Ok(BackendProtocol::GeminiGenerateContent), other => Err(Error::new( Status::InvalidArg, format!("Unsupported llm backend protocol: {other}"), @@ -293,6 +367,7 @@ mod tests { assert!(parse_protocol("chat-completions").is_ok()); assert!(parse_protocol("responses").is_ok()); assert!(parse_protocol("anthropic").is_ok()); + assert!(parse_protocol("gemini").is_ok()); } #[test] diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 5bef6c0c10..11e83e33a2 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -25,8 +25,6 @@ "dependencies": { "@affine/s3-compat": "workspace:*", "@affine/server-native": "workspace:*", - "@ai-sdk/google": "^3.0.46", - "@ai-sdk/google-vertex": "^4.0.83", "@apollo/server": "^4.13.0", "@fal-ai/serverless-client": "^0.15.0", "@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0", @@ -66,7 +64,6 @@ "@queuedash/api": "^3.16.0", "@react-email/components": "^0.5.7", "@socket.io/redis-adapter": "^8.3.0", - "ai": "^6.0.118", "bullmq": "^5.40.2", "cookie-parser": "^1.4.7", "cross-env": "^10.1.0", diff --git a/packages/backend/server/src/__tests__/copilot/copilot-provider.spec.ts b/packages/backend/server/src/__tests__/copilot/copilot-provider.spec.ts index 224051aaff..903932683e 100644 --- a/packages/backend/server/src/__tests__/copilot/copilot-provider.spec.ts +++ b/packages/backend/server/src/__tests__/copilot/copilot-provider.spec.ts @@ -225,6 +225,20 @@ const checkStreamObjects = (result: string) => { } }; +const parseStreamObjects = (result: string): StreamObject[] => { + const streamObjects = JSON.parse(result); + return z.array(StreamObjectSchema).parse(streamObjects); +}; + +const getStreamObjectText = (result: string) => + parseStreamObjects(result) + .filter( + (chunk): chunk is Extract => + chunk.type === 'text-delta' + ) + .map(chunk => chunk.textDelta) + .join(''); + const retry = async ( action: string, t: ExecutionContext, @@ -444,6 +458,49 @@ The term **“CRDT”** was first introduced by Marc Shapiro, Nuno Preguiça, Ca }, type: 'object' as const, }, + { + name: 'Gemini native text', + promptName: ['Chat With AFFiNE AI'], + messages: [ + { + role: 'user' as const, + content: + 'In one short sentence, explain what AFFiNE AI is and mention AFFiNE by name.', + }, + ], + config: { model: 'gemini-2.5-flash' }, + verifier: (t: ExecutionContext, result: string) => { + assertNotWrappedInCodeBlock(t, result); + t.assert( + result.toLowerCase().includes('affine'), + 'should mention AFFiNE' + ); + }, + prefer: CopilotProviderType.Gemini, + type: 'text' as const, + }, + { + name: 'Gemini native stream objects', + promptName: ['Chat With AFFiNE AI'], + messages: [ + { + role: 'user' as const, + content: + 'Respond with one short sentence about AFFiNE AI and mention AFFiNE by name.', + }, + ], + config: { model: 'gemini-2.5-flash' }, + verifier: (t: ExecutionContext, result: string) => { + t.truthy(checkStreamObjects(result), 'should be valid stream objects'); + const assembledText = getStreamObjectText(result); + t.assert( + assembledText.toLowerCase().includes('affine'), + 'should mention AFFiNE' + ); + }, + prefer: CopilotProviderType.Gemini, + type: 'object' as const, + }, { name: 'Should transcribe short audio', promptName: ['Transcript audio'], @@ -716,14 +773,13 @@ for (const { const { factory, prompt: promptService } = t.context; const prompt = (await promptService.get(promptName))!; t.truthy(prompt, 'should have prompt'); - const provider = (await factory.getProviderByModel(prompt.model, { + const finalConfig = Object.assign({}, prompt.config, config); + const modelId = finalConfig.model || prompt.model; + const provider = (await factory.getProviderByModel(modelId, { prefer, }))!; t.truthy(provider, 'should have provider'); await retry(`action: ${promptName}`, t, async t => { - const finalConfig = Object.assign({}, prompt.config, config); - const modelId = finalConfig.model || prompt.model; - switch (type) { case 'text': { const result = await provider.text( @@ -891,7 +947,7 @@ test( 'should be able to rerank message chunks', runIfCopilotConfigured, async t => { - const { factory, prompt } = t.context; + const { factory } = t.context; await retry('rerank', t, async t => { const query = 'Is this content relevant to programming?'; @@ -908,14 +964,18 @@ test( 'The stock market is experiencing significant fluctuations.', ]; - const p = (await prompt.get('Rerank results'))!; - t.assert(p, 'should have prompt for rerank'); - const provider = (await factory.getProviderByModel(p.model))!; + const provider = (await factory.getProviderByModel('gpt-5.2'))!; t.assert(provider, 'should have provider for rerank'); const scores = await provider.rerank( - { modelId: p.model }, - embeddings.map(e => p.finish({ query, doc: e })) + { modelId: 'gpt-5.2' }, + { + query, + candidates: embeddings.map((text, index) => ({ + id: String(index), + text, + })), + } ); t.is(scores.length, 10, 'should return scores for all chunks'); diff --git a/packages/backend/server/src/__tests__/copilot/copilot.spec.ts b/packages/backend/server/src/__tests__/copilot/copilot.spec.ts index 30e6999c40..6b6ef9d22e 100644 --- a/packages/backend/server/src/__tests__/copilot/copilot.spec.ts +++ b/packages/backend/server/src/__tests__/copilot/copilot.spec.ts @@ -33,10 +33,7 @@ import { ModelOutputType, OpenAIProvider, } from '../../plugins/copilot/providers'; -import { - CitationParser, - TextStreamParser, -} from '../../plugins/copilot/providers/utils'; +import { TextStreamParser } from '../../plugins/copilot/providers/utils'; import { ChatSessionService } from '../../plugins/copilot/session'; import { CopilotStorage } from '../../plugins/copilot/storage'; import { CopilotTranscriptionService } from '../../plugins/copilot/transcript'; @@ -660,6 +657,55 @@ test('should be able to generate with message id', async t => { } }); +test('should preserve file handle attachments when merging user content into prompt', async t => { + const { prompt, session } = t.context; + + await prompt.set(promptName, 'model', [ + { role: 'user', content: '{{content}}' }, + ]); + + const sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName, + pinned: false, + }); + const s = (await session.get(sessionId))!; + + const message = await session.createMessage({ + sessionId, + content: 'Summarize this file', + attachments: [ + { + kind: 'file_handle', + fileHandle: 'file_123', + mimeType: 'application/pdf', + }, + ], + }); + + await s.pushByMessageId(message); + const finalMessages = s.finish({}); + + t.deepEqual(finalMessages, [ + { + role: 'user', + content: 'Summarize this file', + attachments: [ + { + kind: 'file_handle', + fileHandle: 'file_123', + mimeType: 'application/pdf', + }, + ], + params: { + content: 'Summarize this file', + }, + }, + ]); +}); + test('should save message correctly', async t => { const { prompt, session } = t.context; @@ -1225,149 +1271,6 @@ test('should be able to run image executor', async t => { Sinon.restore(); }); -test('CitationParser should replace citation placeholders with URLs', t => { - const content = - 'This is [a] test sentence with [citations [1]] and [[2]] and [3].'; - const citations = ['https://example1.com', 'https://example2.com']; - - const parser = new CitationParser(); - for (const citation of citations) { - parser.push(citation); - } - - const result = parser.parse(content) + parser.end(); - - const expected = [ - 'This is [a] test sentence with [citations [^1]] and [^2] and [3].', - `[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`, - `[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`, - ].join('\n'); - - t.is(result, expected); -}); - -test('CitationParser should replace chunks of citation placeholders with URLs', t => { - const contents = [ - '[[]]', - 'This is [', - 'a] test sentence ', - 'with citations [1', - '] and [', - '[2]] and [[', - '3]] and [[4', - ']] and [[5]', - '] and [[6]]', - ' and [7', - ]; - const citations = [ - 'https://example1.com', - 'https://example2.com', - 'https://example3.com', - 'https://example4.com', - 'https://example5.com', - 'https://example6.com', - 'https://example7.com', - ]; - - const parser = new CitationParser(); - for (const citation of citations) { - parser.push(citation); - } - - let result = contents.reduce((acc, current) => { - return acc + parser.parse(current); - }, ''); - result += parser.end(); - - const expected = [ - '[[]]This is [a] test sentence with citations [^1] and [^2] and [^3] and [^4] and [^5] and [^6] and [7', - `[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`, - `[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`, - `[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`, - `[^4]: {"type":"url","url":"${encodeURIComponent(citations[3])}"}`, - `[^5]: {"type":"url","url":"${encodeURIComponent(citations[4])}"}`, - `[^6]: {"type":"url","url":"${encodeURIComponent(citations[5])}"}`, - `[^7]: {"type":"url","url":"${encodeURIComponent(citations[6])}"}`, - ].join('\n'); - t.is(result, expected); -}); - -test('CitationParser should not replace citation already with URLs', t => { - const content = - 'This is [a] test sentence with citations [1](https://example1.com) and [[2]](https://example2.com) and [[3](https://example3.com)].'; - const citations = [ - 'https://example4.com', - 'https://example5.com', - 'https://example6.com', - ]; - - const parser = new CitationParser(); - for (const citation of citations) { - parser.push(citation); - } - - const result = parser.parse(content) + parser.end(); - - const expected = [ - content, - `[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`, - `[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`, - `[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`, - ].join('\n'); - t.is(result, expected); -}); - -test('CitationParser should not replace chunks of citation already with URLs', t => { - const contents = [ - 'This is [a] test sentence with citations [1', - '](https://example1.com) and [[2]', - '](https://example2.com) and [[3](https://example3.com)].', - ]; - const citations = [ - 'https://example4.com', - 'https://example5.com', - 'https://example6.com', - ]; - - const parser = new CitationParser(); - for (const citation of citations) { - parser.push(citation); - } - - let result = contents.reduce((acc, current) => { - return acc + parser.parse(current); - }, ''); - result += parser.end(); - - const expected = [ - contents.join(''), - `[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`, - `[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`, - `[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`, - ].join('\n'); - t.is(result, expected); -}); - -test('CitationParser should replace openai style reference chunks', t => { - const contents = [ - 'This is [a] test sentence with citations ', - '([example1.com](https://example1.com))', - ]; - - const parser = new CitationParser(); - - let result = contents.reduce((acc, current) => { - return acc + parser.parse(current); - }, ''); - result += parser.end(); - - const expected = [ - contents[0] + '[^1]', - `[^1]: {"type":"url","url":"${encodeURIComponent('https://example1.com')}"}`, - ].join('\n'); - t.is(result, expected); -}); - test('TextStreamParser should format different types of chunks correctly', t => { // Define interfaces for fixtures interface BaseFixture { diff --git a/packages/backend/server/src/__tests__/copilot/native-adapter.spec.ts b/packages/backend/server/src/__tests__/copilot/native-adapter.spec.ts deleted file mode 100644 index da44e589a1..0000000000 --- a/packages/backend/server/src/__tests__/copilot/native-adapter.spec.ts +++ /dev/null @@ -1,210 +0,0 @@ -import test from 'ava'; -import { z } from 'zod'; - -import type { NativeLlmRequest, NativeLlmStreamEvent } from '../../native'; -import { - buildNativeRequest, - NativeProviderAdapter, -} from '../../plugins/copilot/providers/native'; - -const mockDispatch = () => - (async function* (): AsyncIterableIterator { - yield { type: 'text_delta', text: 'Use [^1] now' }; - yield { type: 'citation', index: 1, url: 'https://affine.pro' }; - yield { type: 'done', finish_reason: 'stop' }; - })(); - -test('NativeProviderAdapter streamText should append citation footnotes', async t => { - const adapter = new NativeProviderAdapter(mockDispatch, {}, 3); - const chunks: string[] = []; - for await (const chunk of adapter.streamText({ - model: 'gpt-5-mini', - stream: true, - messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], - })) { - chunks.push(chunk); - } - - const text = chunks.join(''); - t.true(text.includes('Use [^1] now')); - t.true( - text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') - ); -}); - -test('NativeProviderAdapter streamObject should append citation footnotes', async t => { - const adapter = new NativeProviderAdapter(mockDispatch, {}, 3); - const chunks = []; - for await (const chunk of adapter.streamObject({ - model: 'gpt-5-mini', - stream: true, - messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], - })) { - chunks.push(chunk); - } - - t.deepEqual( - chunks.map(chunk => chunk.type), - ['text-delta', 'text-delta'] - ); - const text = chunks - .filter(chunk => chunk.type === 'text-delta') - .map(chunk => chunk.textDelta) - .join(''); - t.true(text.includes('Use [^1] now')); - t.true( - text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') - ); -}); - -test('NativeProviderAdapter streamObject should append fallback attachment footnotes', async t => { - const dispatch = () => - (async function* (): AsyncIterableIterator { - yield { - type: 'tool_result', - call_id: 'call_1', - name: 'blob_read', - arguments: { blob_id: 'blob_1' }, - output: { - blobId: 'blob_1', - fileName: 'a.txt', - fileType: 'text/plain', - content: 'A', - }, - }; - yield { - type: 'tool_result', - call_id: 'call_2', - name: 'blob_read', - arguments: { blob_id: 'blob_2' }, - output: { - blobId: 'blob_2', - fileName: 'b.txt', - fileType: 'text/plain', - content: 'B', - }, - }; - yield { type: 'text_delta', text: 'Answer from files.' }; - yield { type: 'done', finish_reason: 'stop' }; - })(); - - const adapter = new NativeProviderAdapter(dispatch, {}, 3); - const chunks = []; - for await (const chunk of adapter.streamObject({ - model: 'gpt-5-mini', - stream: true, - messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], - })) { - chunks.push(chunk); - } - - const text = chunks - .filter(chunk => chunk.type === 'text-delta') - .map(chunk => chunk.textDelta) - .join(''); - t.true(text.includes('Answer from files.')); - t.true(text.includes('[^1][^2]')); - t.true( - text.includes( - '[^1]: {"type":"attachment","blobId":"blob_1","fileName":"a.txt","fileType":"text/plain"}' - ) - ); - t.true( - text.includes( - '[^2]: {"type":"attachment","blobId":"blob_2","fileName":"b.txt","fileType":"text/plain"}' - ) - ); -}); - -test('NativeProviderAdapter streamObject should map tool and text events', async t => { - let round = 0; - const dispatch = (_request: NativeLlmRequest) => - (async function* (): AsyncIterableIterator { - round += 1; - if (round === 1) { - yield { - type: 'tool_call', - call_id: 'call_1', - name: 'doc_read', - arguments: { doc_id: 'a1' }, - }; - yield { type: 'done', finish_reason: 'tool_calls' }; - return; - } - yield { type: 'text_delta', text: 'ok' }; - yield { type: 'done', finish_reason: 'stop' }; - })(); - - const adapter = new NativeProviderAdapter( - dispatch, - { - doc_read: { - inputSchema: z.object({ doc_id: z.string() }), - execute: async () => ({ markdown: '# a1' }), - }, - }, - 4 - ); - - const events = []; - for await (const event of adapter.streamObject({ - model: 'gpt-5-mini', - stream: true, - messages: [{ role: 'user', content: [{ type: 'text', text: 'read' }] }], - })) { - events.push(event); - } - - t.deepEqual( - events.map(event => event.type), - ['tool-call', 'tool-result', 'text-delta'] - ); - t.deepEqual(events[0], { - type: 'tool-call', - toolCallId: 'call_1', - toolName: 'doc_read', - args: { doc_id: 'a1' }, - }); -}); - -test('buildNativeRequest should include rust middleware from profile', async t => { - const { request } = await buildNativeRequest({ - model: 'gpt-5-mini', - messages: [{ role: 'user', content: 'hello' }], - tools: {}, - middleware: { - rust: { - request: ['normalize_messages', 'clamp_max_tokens'], - stream: ['stream_event_normalize', 'citation_indexing'], - }, - node: { - text: ['callout'], - }, - }, - }); - - t.deepEqual(request.middleware, { - request: ['normalize_messages', 'clamp_max_tokens'], - stream: ['stream_event_normalize', 'citation_indexing'], - }); -}); - -test('NativeProviderAdapter streamText should skip citation footnotes when disabled', async t => { - const adapter = new NativeProviderAdapter(mockDispatch, {}, 3, { - nodeTextMiddleware: ['callout'], - }); - const chunks: string[] = []; - for await (const chunk of adapter.streamText({ - model: 'gpt-5-mini', - stream: true, - messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], - })) { - chunks.push(chunk); - } - - const text = chunks.join(''); - t.true(text.includes('Use [^1] now')); - t.false( - text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') - ); -}); diff --git a/packages/backend/server/src/__tests__/copilot/native-provider.spec.ts b/packages/backend/server/src/__tests__/copilot/native-provider.spec.ts new file mode 100644 index 0000000000..1f66c5f2d8 --- /dev/null +++ b/packages/backend/server/src/__tests__/copilot/native-provider.spec.ts @@ -0,0 +1,1431 @@ +import test from 'ava'; +import { z } from 'zod'; + +import { CopilotPromptInvalid, CopilotProviderSideError } from '../../base'; +import type { + NativeLlmBackendConfig, + NativeLlmEmbeddingRequest, + NativeLlmEmbeddingResponse, + NativeLlmRequest, + NativeLlmRerankRequest, + NativeLlmRerankResponse, + NativeLlmStreamEvent, + NativeLlmStructuredRequest, + NativeLlmStructuredResponse, +} from '../../native'; +import { ProviderMiddlewareConfig } from '../../plugins/copilot/config'; +import { GeminiProvider } from '../../plugins/copilot/providers/gemini/gemini'; +import { GeminiVertexProvider } from '../../plugins/copilot/providers/gemini/vertex'; +import { + buildNativeRequest, + NativeProviderAdapter, +} from '../../plugins/copilot/providers/native'; +import { OpenAIProvider } from '../../plugins/copilot/providers/openai'; +import { PerplexityProvider } from '../../plugins/copilot/providers/perplexity'; +import { + CopilotProviderType, + ModelInputType, + ModelOutputType, + type PromptMessage, +} from '../../plugins/copilot/providers/types'; +import type { CopilotToolSet } from '../../plugins/copilot/tools'; + +const mockDispatch = () => + (async function* (): AsyncIterableIterator { + yield { type: 'text_delta', text: 'Use [^1] now' }; + yield { type: 'citation', index: 1, url: 'https://affine.pro' }; + yield { type: 'done', finish_reason: 'stop' }; + })(); + +function stream( + factory: () => NativeLlmStreamEvent[] +): AsyncIterableIterator { + return (async function* () { + for (const event of factory()) { + yield event; + } + })(); +} + +class TestGeminiProvider extends GeminiProvider<{ apiKey: string }> { + override readonly type = CopilotProviderType.Gemini; + override readonly models = [ + { + id: 'gemini-2.5-flash', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ModelInputType.File, + ], + output: [ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ], + }, + ], + }, + { + id: 'gemini-embedding-001', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + ]; + readonly dispatchRequests: NativeLlmRequest[] = []; + readonly structuredRequests: NativeLlmStructuredRequest[] = []; + readonly embeddingRequests: NativeLlmEmbeddingRequest[] = []; + readonly remoteAttachmentRequests: string[] = []; + readonly remoteAttachmentSignals: Array = []; + readonly retryDelays: number[] = []; + remoteAttachmentResponses = new Map< + string, + { data: string; mimeType: string } + >(); + testTools: CopilotToolSet = {}; + testMiddleware: ProviderMiddlewareConfig = { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }; + dispatchFactory: (request: NativeLlmRequest) => NativeLlmStreamEvent[] = + () => [ + { type: 'text_delta', text: 'native' }, + { type: 'done', finish_reason: 'stop' }, + ]; + structuredFactory: ( + request: NativeLlmStructuredRequest + ) => NativeLlmStructuredResponse = () => ({ + id: 'structured_1', + model: 'gemini-2.5-flash', + output_text: '{"summary":"AFFiNE native"}', + usage: { + prompt_tokens: 4, + completion_tokens: 3, + total_tokens: 7, + }, + finish_reason: 'stop', + }); + embeddingFactory: ( + request: NativeLlmEmbeddingRequest + ) => NativeLlmEmbeddingResponse = request => ({ + model: request.model, + embeddings: request.inputs.map((_, index) => [index + 0.1, index + 0.2]), + usage: { + prompt_tokens: request.inputs.length, + total_tokens: request.inputs.length, + }, + }); + + override configured() { + return true; + } + + protected override async createNativeConfig(): Promise { + return { + base_url: 'https://generativelanguage.googleapis.com/v1beta', + auth_token: 'api-key', + request_layer: 'gemini_api', + }; + } + + protected override createNativeDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmRequest) => { + this.dispatchRequests.push(request); + return stream(() => this.dispatchFactory(request)); + }; + } + + protected override createNativeStructuredDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmStructuredRequest) => { + this.structuredRequests.push(request); + return this.structuredFactory(request); + }; + } + + protected override createNativeEmbeddingDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmEmbeddingRequest) => { + this.embeddingRequests.push(request); + return this.embeddingFactory(request); + }; + } + + protected override async fetchRemoteAttach( + url: string, + signal?: AbortSignal + ) { + this.remoteAttachmentRequests.push(url); + this.remoteAttachmentSignals.push(signal); + const response = this.remoteAttachmentResponses.get(url); + if (!response) { + throw new Error(`missing remote attachment stub for ${url}`); + } + return response; + } + + protected override async waitForStructuredRetry(delayMs: number) { + this.retryDelays.push(delayMs); + } + + protected override getActiveProviderMiddleware(): ProviderMiddlewareConfig { + return this.testMiddleware; + } + + protected override async getTools(): Promise { + return this.testTools; + } +} + +class TestGeminiVertexProvider extends GeminiVertexProvider { + testConfig = { + location: 'us-central1', + project: 'p1', + googleAuthOptions: {}, + } as any; + readonly dispatchRequests: NativeLlmRequest[] = []; + readonly remoteAttachmentRequests: string[] = []; + readonly remoteAttachmentSignals: Array = []; + remoteAttachmentResponses = new Map< + string, + { data: string; mimeType: string } + >(); + testTools: CopilotToolSet = {}; + testMiddleware: ProviderMiddlewareConfig = { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }; + + override get config() { + return this.testConfig; + } + + override configured() { + return true; + } + + protected override async resolveVertexAuth() { + return { + baseUrl: 'https://vertex.example', + headers: () => ({ + Authorization: 'Bearer vertex-token', + 'x-goog-user-project': 'p1', + }), + fetch: undefined, + } as const; + } + + protected override createNativeDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmRequest) => { + this.dispatchRequests.push(request); + return stream(() => [ + { type: 'text_delta', text: 'vertex native' }, + { type: 'done', finish_reason: 'stop' }, + ]); + }; + } + + // eslint-disable-next-line sonarjs/no-identical-functions + protected override async fetchRemoteAttach( + url: string, + signal?: AbortSignal + ) { + this.remoteAttachmentRequests.push(url); + this.remoteAttachmentSignals.push(signal); + const response = this.remoteAttachmentResponses.get(url); + if (!response) { + throw new Error(`missing remote attachment stub for ${url}`); + } + return response; + } + + protected override getActiveProviderMiddleware(): ProviderMiddlewareConfig { + return this.testMiddleware; + } + + protected override async getTools(): Promise { + return this.testTools; + } + + async exposeNativeConfig() { + return await this.createNativeConfig(); + } +} + +class TestOpenAIProvider extends OpenAIProvider { + override readonly models = [ + { + id: 'gpt-4.1', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ + ModelOutputType.Text, + ModelOutputType.Structured, + ModelOutputType.Rerank, + ], + }, + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + { + id: 'gpt-5.2', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ + ModelOutputType.Text, + ModelOutputType.Structured, + ModelOutputType.Rerank, + ], + }, + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + { + id: 'text-embedding-3-small', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + ]; + + readonly structuredRequests: NativeLlmStructuredRequest[] = []; + readonly embeddingRequests: NativeLlmEmbeddingRequest[] = []; + readonly rerankRequests: NativeLlmRerankRequest[] = []; + testMiddleware: ProviderMiddlewareConfig = { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + }; + + override get config() { + return { + apiKey: 'openai-key', + baseURL: 'https://api.openai.com/v1', + }; + } + + override configured() { + return true; + } + + protected override getActiveProviderMiddleware(): ProviderMiddlewareConfig { + return this.testMiddleware; + } + + protected override createNativeStructuredDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmStructuredRequest) => { + this.structuredRequests.push(request); + return { + id: 'structured_openai_1', + model: request.model, + output_text: '{"summary":"AFFiNE structured"}', + usage: { + prompt_tokens: 4, + completion_tokens: 3, + total_tokens: 7, + }, + finish_reason: 'stop', + }; + }; + } + + protected override createNativeEmbeddingDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmEmbeddingRequest) => { + this.embeddingRequests.push(request); + return { + model: request.model, + embeddings: request.inputs.map(() => [0.4, 0.5]), + usage: { + prompt_tokens: request.inputs.length, + total_tokens: request.inputs.length, + }, + }; + }; + } + + protected override createNativeRerankDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmRerankRequest) => { + this.rerankRequests.push(request); + return { + model: request.model, + scores: request.candidates.map(() => 0.8), + } satisfies NativeLlmRerankResponse; + }; + } +} + +class TestPerplexityProvider extends PerplexityProvider { + override get config() { + return { apiKey: 'perplexity-key' }; + } + + override configured() { + return true; + } +} + +test('NativeProviderAdapter streamText should append citation footnotes', async t => { + const adapter = new NativeProviderAdapter(mockDispatch, {}, 3); + const chunks: string[] = []; + for await (const chunk of adapter.streamText({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + })) { + chunks.push(chunk); + } + + const text = chunks.join(''); + t.true(text.includes('Use [^1] now')); + t.true( + text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') + ); +}); + +test('NativeProviderAdapter streamObject should append citation footnotes', async t => { + const adapter = new NativeProviderAdapter(mockDispatch, {}, 3); + const chunks = []; + for await (const chunk of adapter.streamObject({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + })) { + chunks.push(chunk); + } + + t.deepEqual( + chunks.map(chunk => chunk.type), + ['text-delta', 'text-delta'] + ); + const text = chunks + .filter(chunk => chunk.type === 'text-delta') + .map(chunk => chunk.textDelta) + .join(''); + t.true(text.includes('Use [^1] now')); + t.true( + text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') + ); +}); + +test('NativeProviderAdapter streamObject should append fallback attachment footnotes', async t => { + const dispatch = () => + (async function* (): AsyncIterableIterator { + yield { + type: 'tool_result', + call_id: 'call_1', + name: 'blob_read', + arguments: { blob_id: 'blob_1' }, + output: { + blobId: 'blob_1', + fileName: 'a.txt', + fileType: 'text/plain', + content: 'A', + }, + }; + yield { + type: 'tool_result', + call_id: 'call_2', + name: 'blob_read', + arguments: { blob_id: 'blob_2' }, + output: { + blobId: 'blob_2', + fileName: 'b.txt', + fileType: 'text/plain', + content: 'B', + }, + }; + yield { type: 'text_delta', text: 'Answer from files.' }; + yield { type: 'done', finish_reason: 'stop' }; + })(); + + const adapter = new NativeProviderAdapter(dispatch, {}, 3); + const chunks = []; + for await (const chunk of adapter.streamObject({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + })) { + chunks.push(chunk); + } + + const text = chunks + .filter(chunk => chunk.type === 'text-delta') + .map(chunk => chunk.textDelta) + .join(''); + t.true(text.includes('Answer from files.')); + t.true(text.includes('[^1][^2]')); + t.true( + text.includes( + '[^1]: {"type":"attachment","blobId":"blob_1","fileName":"a.txt","fileType":"text/plain"}' + ) + ); + t.true( + text.includes( + '[^2]: {"type":"attachment","blobId":"blob_2","fileName":"b.txt","fileType":"text/plain"}' + ) + ); +}); + +test('NativeProviderAdapter streamObject should map tool and text events', async t => { + let round = 0; + const dispatch = (_request: NativeLlmRequest) => + (async function* (): AsyncIterableIterator { + round += 1; + if (round === 1) { + yield { + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + }; + yield { type: 'done', finish_reason: 'tool_calls' }; + return; + } + yield { type: 'text_delta', text: 'ok' }; + yield { type: 'done', finish_reason: 'stop' }; + })(); + + const adapter = new NativeProviderAdapter( + dispatch, + { + doc_read: { + inputSchema: z.object({ doc_id: z.string() }), + execute: async () => ({ markdown: '# a1' }), + }, + }, + 4 + ); + + const events = []; + for await (const event of adapter.streamObject({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'read' }] }], + })) { + events.push(event); + } + + t.deepEqual( + events.map(event => event.type), + ['tool-call', 'tool-result', 'text-delta'] + ); + t.deepEqual(events[0], { + type: 'tool-call', + toolCallId: 'call_1', + toolName: 'doc_read', + args: { doc_id: 'a1' }, + }); +}); + +test('buildNativeRequest should include rust middleware from profile', async t => { + const { request } = await buildNativeRequest({ + model: 'gpt-5-mini', + messages: [{ role: 'user', content: 'hello' }], + tools: {}, + middleware: { + rust: { + request: ['normalize_messages', 'clamp_max_tokens'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['callout'], + }, + }, + }); + + t.deepEqual(request.middleware, { + request: ['normalize_messages', 'clamp_max_tokens'], + stream: ['stream_event_normalize', 'citation_indexing'], + }); +}); + +test('buildNativeRequest should preserve non-image attachment urls for native Gemini', async t => { + const { request } = await buildNativeRequest({ + model: 'gemini-2.5-flash', + messages: [ + { + role: 'user', + content: 'summarize this attachment', + attachments: ['https://example.com/a.pdf'], + params: { mimetype: 'application/pdf' }, + }, + ], + }); + + t.deepEqual(request.messages[0]?.content, [ + { type: 'text', text: 'summarize this attachment' }, + { + type: 'file', + source: { + url: 'https://example.com/a.pdf', + media_type: 'application/pdf', + }, + }, + ]); +}); + +test('buildNativeRequest should inline data url attachments for native Gemini', async t => { + const { request } = await buildNativeRequest({ + model: 'gemini-2.5-flash', + messages: [ + { + role: 'user', + content: 'read this note', + attachments: ['data:text/plain,hello%20world'], + params: { mimetype: 'text/plain' }, + }, + ], + }); + + t.deepEqual(request.messages[0]?.content, [ + { type: 'text', text: 'read this note' }, + { + type: 'file', + source: { + media_type: 'text/plain', + data: Buffer.from('hello world', 'utf8').toString('base64'), + }, + }, + ]); +}); + +test('buildNativeRequest should classify audio attachments for native Gemini', async t => { + const { request } = await buildNativeRequest({ + model: 'gemini-2.5-flash', + messages: [ + { + role: 'user', + content: 'transcribe this clip', + attachments: ['https://example.com/a.mp3'], + params: { mimetype: 'audio/mpeg' }, + }, + ], + }); + + t.deepEqual(request.messages[0]?.content, [ + { type: 'text', text: 'transcribe this clip' }, + { + type: 'audio', + source: { + url: 'https://example.com/a.mp3', + media_type: 'audio/mpeg', + }, + }, + ]); +}); + +test('buildNativeRequest should preserve bytes and file handle attachment sources', async t => { + const { request } = await buildNativeRequest({ + model: 'gemini-2.5-flash', + messages: [ + { + role: 'user', + content: 'inspect these assets', + attachments: [ + { + kind: 'bytes', + data: Buffer.from('hello', 'utf8').toString('base64'), + mimeType: 'text/plain', + fileName: 'hello.txt', + }, + { + kind: 'file_handle', + fileHandle: 'file_123', + mimeType: 'application/pdf', + fileName: 'report.pdf', + }, + ], + }, + ], + attachmentCapability: { + kinds: ['image', 'audio', 'file'], + sourceKinds: ['bytes', 'file_handle'], + }, + }); + + t.deepEqual(request.messages[0]?.content, [ + { type: 'text', text: 'inspect these assets' }, + { + type: 'file', + source: { + media_type: 'text/plain', + data: Buffer.from('hello', 'utf8').toString('base64'), + file_name: 'hello.txt', + }, + }, + { + type: 'file', + source: { + file_handle: 'file_123', + media_type: 'application/pdf', + file_name: 'report.pdf', + }, + }, + ]); +}); + +test('buildNativeRequest should reject attachments outside native admission matrix', async t => { + const error = await t.throwsAsync( + buildNativeRequest({ + model: 'gpt-4o', + messages: [ + { + role: 'user', + content: 'summarize this attachment', + attachments: ['https://example.com/a.pdf'], + params: { mimetype: 'application/pdf' }, + }, + ], + attachmentCapability: { + kinds: ['image'], + sourceKinds: ['url', 'data'], + allowRemoteUrls: true, + }, + }) + ); + + t.true(error instanceof CopilotPromptInvalid); + t.regex(error.message, /does not support file attachments/i); +}); + +test('buildNativeStructuredRequest should prefer explicit schema option', async t => { + const provider = new TestOpenAIProvider(); + const schema = z.object({ summary: z.string() }); + + await provider.structure( + { modelId: 'gpt-4.1' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one sentence.', + }, + ], + { schema } + ); + + t.deepEqual(provider.structuredRequests[0]?.schema, { + type: 'object', + properties: { summary: { type: 'string' } }, + required: ['summary'], + additionalProperties: false, + }); +}); + +test('buildNativeStructuredRequest should preserve caller strictness override', async t => { + const provider = new TestOpenAIProvider(); + + await provider.structure( + { modelId: 'gpt-4.1' }, + [ + { role: 'system', content: 'Return JSON only.' }, + { role: 'user', content: 'Summarize AFFiNE in one sentence.' }, + ], + { schema: z.object({ summary: z.string() }), strict: false } + ); + + t.is(provider.structuredRequests[0]?.strict, false); +}); + +test('NativeProviderAdapter streamText should skip citation footnotes when disabled', async t => { + const adapter = new NativeProviderAdapter(mockDispatch, {}, 3, { + nodeTextMiddleware: ['callout'], + }); + const chunks: string[] = []; + for await (const chunk of adapter.streamText({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + })) { + chunks.push(chunk); + } + + const text = chunks.join(''); + t.true(text.includes('Use [^1] now')); + t.false( + text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') + ); +}); + +test('GeminiProvider should use native path for text-only requests', async t => { + const provider = new TestGeminiProvider(); + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [{ role: 'user', content: 'hello' }], + { reasoning: true } + ); + + t.is(result, 'native'); + t.is(provider.dispatchRequests.length, 1); + t.deepEqual(provider.dispatchRequests[0]?.reasoning, { + include_thoughts: true, + thinking_budget: 12000, + }); + t.deepEqual(provider.dispatchRequests[0]?.middleware, { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }); +}); + +test('GeminiProvider should use native path for structured requests', async t => { + const provider = new TestGeminiProvider(); + + const schema = z.object({ summary: z.string() }); + const result = await provider.structure( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one short sentence.', + }, + ], + { schema } + ); + + t.is(provider.structuredRequests.length, 1); + t.deepEqual(provider.structuredRequests[0]?.schema, { + type: 'object', + properties: { + summary: { + type: 'string', + }, + }, + required: ['summary'], + additionalProperties: false, + }); + t.deepEqual(JSON.parse(result), { summary: 'AFFiNE native' }); +}); + +test('GeminiProvider should retry only reparsable structured responses', async t => { + const provider = new TestGeminiProvider(); + let attempts = 0; + provider.structuredFactory = () => { + attempts += 1; + return { + id: `structured_retry_${attempts}`, + model: 'gemini-2.5-flash', + output_text: + attempts === 1 ? '```json\n{"summary":1}\n```' : '{"summary":"ok"}', + usage: { + prompt_tokens: 4, + completion_tokens: 3, + total_tokens: 7, + }, + finish_reason: 'stop', + }; + }; + + const result = await provider.structure( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one short sentence.', + }, + ], + { schema: z.object({ summary: z.string() }), maxRetries: 2 } + ); + + t.is(attempts, 2); + t.deepEqual(JSON.parse(result), { summary: 'ok' }); +}); + +test('GeminiProvider should treat maxRetries as retry count for backend failures', async t => { + const provider = new TestGeminiProvider(); + let attempts = 0; + provider.structuredFactory = () => { + attempts += 1; + throw new Error('backend down'); + }; + + const error = await t.throwsAsync( + provider.structure( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one short sentence.', + }, + ], + { schema: z.object({ summary: z.string() }), maxRetries: 2 } + ) + ); + + t.is(attempts, 3); + t.deepEqual(provider.retryDelays, [2_000, 4_000]); + t.regex(error.message, /backend down/); +}); + +test('GeminiProvider should use native structured path for audio attachments', async t => { + const provider = new TestGeminiProvider(); + const inlineData = Buffer.from('audio-bytes', 'utf8').toString('base64'); + provider.remoteAttachmentResponses.set('https://example.com/a.mp3', { + data: inlineData, + mimeType: 'audio/mpeg', + }); + provider.structuredFactory = () => ({ + id: 'structured_audio_1', + model: 'gemini-2.5-flash', + output_text: '[{"a":"Speaker 1","s":0,"e":1,"t":"Hello"}]', + usage: { + prompt_tokens: 4, + completion_tokens: 3, + total_tokens: 7, + }, + finish_reason: 'stop', + }); + + const result = await provider.structure( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'transcribe the audio', + attachments: ['https://example.com/a.mp3'], + params: { mimetype: 'audio/mpeg' }, + }, + ], + { + schema: z.array( + z.object({ a: z.string(), s: z.number(), e: z.number(), t: z.string() }) + ), + } + ); + + t.is(provider.structuredRequests.length, 1); + t.deepEqual(provider.structuredRequests[0]?.messages[1]?.content, [ + { type: 'text', text: 'transcribe the audio' }, + { + type: 'audio', + source: { + data: inlineData, + media_type: 'audio/mpeg', + }, + }, + ]); + t.deepEqual(provider.remoteAttachmentRequests, ['https://example.com/a.mp3']); + t.deepEqual(JSON.parse(result), [{ a: 'Speaker 1', s: 0, e: 1, t: 'Hello' }]); +}); + +test('GeminiProvider should use native path for embeddings', async t => { + const provider = new TestGeminiProvider(); + + const result = await provider.embedding( + { modelId: 'gemini-embedding-001' }, + ['first', 'second'], + { dimensions: 3 } + ); + + t.deepEqual(result, [ + [0.1, 0.2], + [1.1, 1.2], + ]); + t.is(provider.embeddingRequests.length, 1); + t.deepEqual(provider.embeddingRequests[0], { + model: 'gemini-embedding-001', + inputs: ['first', 'second'], + dimensions: 3, + task_type: 'RETRIEVAL_DOCUMENT', + }); +}); + +test('GeminiProvider should use native path for non-image attachments', async t => { + const provider = new TestGeminiProvider(); + const inlineData = Buffer.from('pdf-bytes', 'utf8').toString('base64'); + provider.remoteAttachmentResponses.set('https://example.com/a.pdf', { + data: inlineData, + mimeType: 'application/pdf', + }); + const messages: PromptMessage[] = [ + { + role: 'user', + content: 'summarize this file', + attachments: ['https://example.com/a.pdf'], + params: { mimetype: 'application/pdf' }, + }, + ]; + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + messages, + {} + ); + + t.is(result, 'native'); + t.is(provider.dispatchRequests.length, 1); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'summarize this file' }, + { + type: 'file', + source: { + data: inlineData, + media_type: 'application/pdf', + }, + }, + ]); +}); + +test('GeminiProvider should inline remote image attachments for text requests', async t => { + const provider = new TestGeminiProvider(); + const inlineData = Buffer.from('image-bytes', 'utf8').toString('base64'); + provider.remoteAttachmentResponses.set('https://example.com/a.jpg', { + data: inlineData, + mimeType: 'image/jpeg', + }); + + const result = await provider.text({ modelId: 'gemini-2.5-flash' }, [ + { + role: 'user', + content: 'describe this image', + attachments: ['https://example.com/a.jpg'], + }, + ]); + + t.is(result, 'native'); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'describe this image' }, + { + type: 'image', + source: { + data: inlineData, + media_type: 'image/jpeg', + }, + }, + ]); +}); + +test('GeminiProvider should pass abort signal to remote attachment prefetch', async t => { + const provider = new TestGeminiProvider(); + provider.remoteAttachmentResponses.set('https://example.com/a.jpg', { + data: Buffer.from('image-bytes', 'utf8').toString('base64'), + mimeType: 'image/jpeg', + }); + const controller = new AbortController(); + + await provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'describe this image', + attachments: ['https://example.com/a.jpg'], + }, + ], + { signal: controller.signal } + ); + + t.deepEqual(provider.remoteAttachmentRequests, ['https://example.com/a.jpg']); + t.is(provider.remoteAttachmentSignals[0], controller.signal); +}); + +test('GeminiProvider should classify downloaded audio-only WebM attachments as audio', async t => { + const provider = new TestGeminiProvider(); + const inlineData = Buffer.from('audio-bytes', 'utf8').toString('base64'); + provider.remoteAttachmentResponses.set('https://example.com/a.webm', { + data: inlineData, + mimeType: 'audio/webm', + }); + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'transcribe this clip', + attachments: ['https://example.com/a.webm'], + }, + ], + {} + ); + + t.is(result, 'native'); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'transcribe this clip' }, + { type: 'audio', source: { data: inlineData, media_type: 'audio/webm' } }, + ]); +}); + +test('GeminiProvider should preserve Google file urls for native Gemini API', async t => { + const provider = new TestGeminiProvider(); + + await provider.text({ modelId: 'gemini-2.5-flash' }, [ + { + role: 'user', + content: 'summarize this file', + attachments: [ + 'https://generativelanguage.googleapis.com/v1beta/files/file-123', + ], + params: { mimetype: 'application/pdf' }, + }, + ]); + + t.deepEqual(provider.remoteAttachmentRequests, []); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'summarize this file' }, + { + type: 'file', + source: { + url: 'https://generativelanguage.googleapis.com/v1beta/files/file-123', + media_type: 'application/pdf', + }, + }, + ]); +}); + +test('PerplexityProvider should ignore attachments during text model matching', async t => { + const provider = new TestPerplexityProvider(); + let capturedRequest: NativeLlmRequest | undefined; + + (provider as any).getActiveProviderMiddleware = () => ({}); + (provider as any).getTools = async () => ({}); + (provider as any).createNativeAdapter = () => ({ + text: async (request: NativeLlmRequest) => { + capturedRequest = request; + return 'ok'; + }, + }); + + const result = await provider.text( + { modelId: 'sonar' }, + [ + { + role: 'user', + content: 'summarize this', + attachments: ['https://example.com/a.pdf'], + params: { mimetype: 'application/pdf' }, + }, + ], + {} + ); + + t.is(result, 'ok'); + t.deepEqual(capturedRequest?.messages[0]?.content, [ + { type: 'text', text: 'summarize this' }, + ]); +}); + +test('GeminiProvider should reject unsupported attachment schemes at input validation', async t => { + const provider = new TestGeminiProvider(); + + const error = await t.throwsAsync( + provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'read this attachment', + attachments: ['blob:https://example.com/file-id'], + params: { mimetype: 'application/pdf' }, + }, + ], + {} + ) + ); + + t.true(error instanceof CopilotPromptInvalid); + t.regex(error.message, /attachments must use https\?:\/\/, gs:\/\/ or data:/); + t.is(provider.dispatchRequests.length, 0); +}); + +test('GeminiProvider should validate malformed attachments before canonicalization', async t => { + const provider = new TestGeminiProvider(); + + const error = await t.throwsAsync( + provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'read this attachment', + attachments: [{ kind: 'url' }], + }, + ] as any, + {} + ) + ); + + t.true(error instanceof CopilotPromptInvalid); + t.regex(error.message, /attachments\[0\]/); + t.is(provider.dispatchRequests.length, 0); +}); + +test('GeminiProvider should drive tool loop on native path', async t => { + const provider = new TestGeminiProvider(); + provider.testTools = { + doc_read: { + inputSchema: z.object({ doc_id: z.string() }), + execute: async args => ({ markdown: `# ${(args as any).doc_id}` }), + }, + }; + provider.dispatchFactory = request => { + const hasToolResult = request.messages.some( + message => message.role === 'tool' + ); + if (!hasToolResult) { + return [ + { + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + }, + { type: 'done', finish_reason: 'tool_calls' }, + ]; + } + + return [ + { type: 'text_delta', text: 'after tool' }, + { type: 'done', finish_reason: 'stop' }, + ]; + }; + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [{ role: 'user', content: 'read doc a1' }], + {} + ); + + t.true(result.includes('after tool')); + t.is(provider.dispatchRequests.length, 2); + t.true( + provider.dispatchRequests[1]?.messages.some( + message => message.role === 'tool' + ) + ); +}); + +test('GeminiVertexProvider should prefetch bearer token for native config', async t => { + const provider = new TestGeminiVertexProvider(); + + const config = await provider.exposeNativeConfig(); + + t.deepEqual(config, { + base_url: 'https://vertex.example', + auth_token: 'vertex-token', + request_layer: 'gemini_vertex', + }); +}); + +test('GeminiVertexProvider should preserve remote http attachments like Vertex SDK', async t => { + const provider = new TestGeminiVertexProvider(); + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'transcribe the audio', + attachments: ['https://example.com/a.mp3'], + }, + ], + {} + ); + + t.is(result, 'vertex native'); + t.deepEqual(provider.remoteAttachmentRequests, []); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'transcribe the audio' }, + { + type: 'audio', + source: { + url: 'https://example.com/a.mp3', + media_type: 'audio/mpeg', + }, + }, + ]); +}); + +test('GeminiVertexProvider should preserve gs urls for native Vertex requests', async t => { + const provider = new TestGeminiVertexProvider(); + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'transcribe the audio', + attachments: ['gs://bucket/audio.opus'], + }, + ], + {} + ); + + t.is(result, 'vertex native'); + t.deepEqual(provider.remoteAttachmentRequests, []); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'transcribe the audio' }, + { + type: 'audio', + source: { + url: 'gs://bucket/audio.opus', + media_type: 'audio/opus', + }, + }, + ]); +}); + +test('OpenAIProvider should use native structured dispatch', async t => { + const provider = new TestOpenAIProvider(); + const schema = z.object({ summary: z.string() }); + + const result = await provider.structure( + { modelId: 'gpt-4.1' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one sentence.', + }, + ], + { schema } + ); + + t.deepEqual(JSON.parse(result), { summary: 'AFFiNE structured' }); + t.is(provider.structuredRequests.length, 1); + t.deepEqual(provider.structuredRequests[0]?.schema, { + type: 'object', + properties: { + summary: { + type: 'string', + }, + }, + required: ['summary'], + additionalProperties: false, + }); +}); + +test('OpenAIProvider should use native embedding dispatch', async t => { + const provider = new TestOpenAIProvider(); + + const result = await provider.embedding( + { modelId: 'text-embedding-3-small' }, + ['alpha', 'beta'], + { dimensions: 8 } + ); + + t.deepEqual(result, [ + [0.4, 0.5], + [0.4, 0.5], + ]); + t.is(provider.embeddingRequests.length, 1); + t.deepEqual(provider.embeddingRequests[0], { + model: 'text-embedding-3-small', + inputs: ['alpha', 'beta'], + dimensions: 8, + task_type: 'RETRIEVAL_DOCUMENT', + }); +}); + +test('OpenAIProvider should use native rerank dispatch', async t => { + const provider = new TestOpenAIProvider(); + + const scores = await provider.rerank( + { modelId: 'gpt-4.1' }, + { + query: 'programming', + candidates: [ + { id: 'react', text: 'React is a UI library.' }, + { id: 'weather', text: 'The park is sunny today.' }, + ], + } + ); + + t.deepEqual(scores, [0.8, 0.8]); + t.is(provider.rerankRequests.length, 1); + t.is(provider.rerankRequests[0]?.model, 'gpt-4.1'); + t.is(provider.rerankRequests[0]?.query, 'programming'); + t.deepEqual(provider.rerankRequests[0]?.candidates, [ + { id: 'react', text: 'React is a UI library.' }, + { id: 'weather', text: 'The park is sunny today.' }, + ]); +}); + +test('OpenAIProvider rerank should normalize native dispatch errors', async t => { + class ErroringOpenAIProvider extends TestOpenAIProvider { + protected override createNativeRerankDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async () => { + throw new Error('native rerank exploded'); + }; + } + } + + const provider = new ErroringOpenAIProvider(); + + const error = await t.throwsAsync( + provider.rerank( + { modelId: 'gpt-4.1' }, + { + query: 'programming', + candidates: [{ id: 'react', text: 'React is a UI library.' }], + } + ) + ); + + t.true(error instanceof CopilotProviderSideError); + t.regex(error.message, /native rerank exploded/i); +}); diff --git a/packages/backend/server/src/__tests__/copilot/provider-native.spec.ts b/packages/backend/server/src/__tests__/copilot/provider-native.spec.ts index 890a158a45..58e04d4f63 100644 --- a/packages/backend/server/src/__tests__/copilot/provider-native.spec.ts +++ b/packages/backend/server/src/__tests__/copilot/provider-native.spec.ts @@ -1,9 +1,13 @@ +import serverNativeModule from '@affine/server-native'; import test from 'ava'; +import type { NativeLlmRerankRequest } from '../../native'; import { ProviderMiddlewareConfig } from '../../plugins/copilot/config'; -import { normalizeOpenAIOptionsForModel } from '../../plugins/copilot/providers/openai'; +import { + normalizeOpenAIOptionsForModel, + OpenAIProvider, +} from '../../plugins/copilot/providers/openai'; import { CopilotProvider } from '../../plugins/copilot/providers/provider'; -import { normalizeRerankModel } from '../../plugins/copilot/providers/rerank'; import { CopilotProviderType, ModelInputType, @@ -46,6 +50,33 @@ class TestOpenAIProvider extends CopilotProvider<{ apiKey: string }> { } } +class NativeRerankProtocolProvider extends OpenAIProvider { + override readonly models = [ + { + id: 'gpt-5.2', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Text, ModelOutputType.Rerank], + defaultForOutputType: true, + }, + ], + }, + ]; + + override get config() { + return { + apiKey: 'test-key', + baseURL: 'https://api.openai.com/v1', + oldApiStyle: false, + }; + } + + override configured() { + return true; + } +} + function createProvider(profileMiddleware?: ProviderMiddlewareConfig) { const provider = new TestOpenAIProvider(); (provider as any).AFFiNEConfig = { @@ -126,14 +157,44 @@ test('normalizeOpenAIOptionsForModel should keep options for gpt-4.1', t => { ); }); -test('normalizeOpenAIRerankModel should keep supported rerank models', t => { - t.is(normalizeRerankModel('gpt-4.1'), 'gpt-4.1'); - t.is(normalizeRerankModel('gpt-4.1-mini'), 'gpt-4.1-mini'); - t.is(normalizeRerankModel('gpt-5.2'), 'gpt-5.2'); -}); +test('OpenAI rerank should always use chat-completions native protocol', async t => { + const provider = new NativeRerankProtocolProvider(); + let capturedProtocol: string | undefined; + let capturedRequest: NativeLlmRerankRequest | undefined; -test('normalizeOpenAIRerankModel should fall back for unsupported models', t => { - t.is(normalizeRerankModel('gpt-5-mini'), 'gpt-5.2'); - t.is(normalizeRerankModel('gemini-2.5-flash'), 'gpt-5.2'); - t.is(normalizeRerankModel(undefined), 'gpt-5.2'); + const original = (serverNativeModule as any).llmRerankDispatch; + (serverNativeModule as any).llmRerankDispatch = ( + protocol: string, + _backendConfigJson: string, + requestJson: string + ) => { + capturedProtocol = protocol; + capturedRequest = JSON.parse(requestJson) as NativeLlmRerankRequest; + return JSON.stringify({ model: 'gpt-5.2', scores: [0.9, 0.1] }); + }; + t.teardown(() => { + (serverNativeModule as any).llmRerankDispatch = original; + }); + + const scores = await provider.rerank( + { modelId: 'gpt-5.2' }, + { + query: 'programming', + candidates: [ + { id: 'react', text: 'React is a UI library.' }, + { id: 'weather', text: 'The weather is sunny today.' }, + ], + } + ); + + t.deepEqual(scores, [0.9, 0.1]); + t.is(capturedProtocol, 'openai_chat'); + t.deepEqual(capturedRequest, { + model: 'gpt-5.2', + query: 'programming', + candidates: [ + { id: 'react', text: 'React is a UI library.' }, + { id: 'weather', text: 'The weather is sunny today.' }, + ], + }); }); diff --git a/packages/backend/server/src/__tests__/copilot/tool-call-loop.spec.ts b/packages/backend/server/src/__tests__/copilot/tool-call-loop.spec.ts index a8f76dea69..846d7d96eb 100644 --- a/packages/backend/server/src/__tests__/copilot/tool-call-loop.spec.ts +++ b/packages/backend/server/src/__tests__/copilot/tool-call-loop.spec.ts @@ -34,6 +34,56 @@ test('ToolCallAccumulator should merge deltas and complete tool call', t => { id: 'call_1', name: 'doc_read', args: { doc_id: 'a1' }, + rawArgumentsText: '{"doc_id":"a1"}', + thought: undefined, + }); +}); + +test('ToolCallAccumulator should preserve invalid JSON instead of swallowing it', t => { + const accumulator = new ToolCallAccumulator(); + + accumulator.feedDelta({ + type: 'tool_call_delta', + call_id: 'call_1', + name: 'doc_read', + arguments_delta: '{"doc_id":', + }); + + const pending = accumulator.drainPending(); + + t.is(pending.length, 1); + t.deepEqual(pending[0]?.id, 'call_1'); + t.deepEqual(pending[0]?.name, 'doc_read'); + t.deepEqual(pending[0]?.args, {}); + t.is(pending[0]?.rawArgumentsText, '{"doc_id":'); + t.truthy(pending[0]?.argumentParseError); +}); + +test('ToolCallAccumulator should prefer native canonical tool arguments metadata', t => { + const accumulator = new ToolCallAccumulator(); + + accumulator.feedDelta({ + type: 'tool_call_delta', + call_id: 'call_1', + name: 'doc_read', + arguments_delta: '{"stale":true}', + }); + + const completed = accumulator.complete({ + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: {}, + arguments_text: '{"doc_id":"a1"}', + arguments_error: 'invalid json', + }); + + t.deepEqual(completed, { + id: 'call_1', + name: 'doc_read', + args: {}, + rawArgumentsText: '{"doc_id":"a1"}', + argumentParseError: 'invalid json', thought: undefined, }); }); @@ -71,6 +121,8 @@ test('ToolSchemaExtractor should convert zod schema to json schema', t => { test('ToolCallLoop should execute tool call and continue to next round', async t => { const dispatchRequests: NativeLlmRequest[] = []; + const originalMessages = [{ role: 'user', content: 'read doc' }] as const; + const signal = new AbortController().signal; const dispatch = (request: NativeLlmRequest) => { dispatchRequests.push(request); @@ -100,13 +152,17 @@ test('ToolCallLoop should execute tool call and continue to next round', async t }; let executedArgs: Record | null = null; + let executedMessages: unknown; + let executedSignal: AbortSignal | undefined; const loop = new ToolCallLoop( dispatch, { doc_read: { inputSchema: z.object({ doc_id: z.string() }), - execute: async args => { + execute: async (args, options) => { executedArgs = args; + executedMessages = options.messages; + executedSignal = options.signal; return { markdown: '# doc' }; }, }, @@ -114,6 +170,92 @@ test('ToolCallLoop should execute tool call and continue to next round', async t 4 ); + const events: NativeLlmStreamEvent[] = []; + for await (const event of loop.run( + { + model: 'gpt-5-mini', + stream: true, + messages: [ + { role: 'user', content: [{ type: 'text', text: 'read doc' }] }, + ], + }, + signal, + [...originalMessages] + )) { + events.push(event); + } + + t.deepEqual(executedArgs, { doc_id: 'a1' }); + t.deepEqual(executedMessages, originalMessages); + t.is(executedSignal, signal); + t.true( + dispatchRequests[1]?.messages.some(message => message.role === 'tool') + ); + t.deepEqual(dispatchRequests[1]?.messages[1]?.content, [ + { + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + arguments_text: '{"doc_id":"a1"}', + arguments_error: undefined, + thought: undefined, + }, + ]); + t.deepEqual(dispatchRequests[1]?.messages[2]?.content, [ + { + type: 'tool_result', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + arguments_text: '{"doc_id":"a1"}', + arguments_error: undefined, + output: { markdown: '# doc' }, + is_error: undefined, + }, + ]); + t.deepEqual( + events.map(event => event.type), + ['tool_call', 'tool_result', 'text_delta', 'done'] + ); +}); + +test('ToolCallLoop should surface invalid JSON as tool error without executing', async t => { + let executed = false; + let round = 0; + const loop = new ToolCallLoop( + request => { + round += 1; + const hasToolResult = request.messages.some( + message => message.role === 'tool' + ); + return (async function* (): AsyncIterableIterator { + if (!hasToolResult && round === 1) { + yield { + type: 'tool_call_delta', + call_id: 'call_1', + name: 'doc_read', + arguments_delta: '{"doc_id":', + }; + yield { type: 'done', finish_reason: 'tool_calls' }; + return; + } + + yield { type: 'done', finish_reason: 'stop' }; + })(); + }, + { + doc_read: { + inputSchema: z.object({ doc_id: z.string() }), + execute: async () => { + executed = true; + return { markdown: '# doc' }; + }, + }, + }, + 2 + ); + const events: NativeLlmStreamEvent[] = []; for await (const event of loop.run({ model: 'gpt-5-mini', @@ -123,12 +265,24 @@ test('ToolCallLoop should execute tool call and continue to next round', async t events.push(event); } - t.deepEqual(executedArgs, { doc_id: 'a1' }); - t.true( - dispatchRequests[1]?.messages.some(message => message.role === 'tool') - ); - t.deepEqual( - events.map(event => event.type), - ['tool_call', 'tool_result', 'text_delta', 'done'] - ); + t.false(executed); + t.true(events[0]?.type === 'tool_result'); + t.deepEqual(events[0], { + type: 'tool_result', + call_id: 'call_1', + name: 'doc_read', + arguments: {}, + arguments_text: '{"doc_id":', + arguments_error: + events[0]?.type === 'tool_result' ? events[0].arguments_error : undefined, + output: { + message: 'Invalid tool arguments JSON', + rawArguments: '{"doc_id":', + error: + events[0]?.type === 'tool_result' + ? events[0].arguments_error + : undefined, + }, + is_error: true, + }); }); diff --git a/packages/backend/server/src/__tests__/copilot/utils.spec.ts b/packages/backend/server/src/__tests__/copilot/utils.spec.ts index 12c13e5fee..94844cc176 100644 --- a/packages/backend/server/src/__tests__/copilot/utils.spec.ts +++ b/packages/backend/server/src/__tests__/copilot/utils.spec.ts @@ -1,12 +1,6 @@ import test from 'ava'; -import { z } from 'zod'; -import { - chatToGPTMessage, - CitationFootnoteFormatter, - CitationParser, - StreamPatternParser, -} from '../../plugins/copilot/providers/utils'; +import { CitationFootnoteFormatter } from '../../plugins/copilot/providers/utils'; test('CitationFootnoteFormatter should format sorted footnotes from citation events', t => { const formatter = new CitationFootnoteFormatter(); @@ -50,67 +44,3 @@ test('CitationFootnoteFormatter should overwrite duplicated index with latest ur '[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fnew"}' ); }); - -test('StreamPatternParser should keep state across chunks', t => { - const parser = new StreamPatternParser(pattern => { - if (pattern.kind === 'wrappedLink') { - return `[^${pattern.url}]`; - } - if (pattern.kind === 'index') { - return `[#${pattern.value}]`; - } - return `[${pattern.text}](${pattern.url})`; - }); - - const first = parser.write('ref ([AFFiNE](https://affine.pro'); - const second = parser.write(')) and [2]'); - - t.is(first, 'ref '); - t.is(second, '[^https://affine.pro] and [#2]'); - t.is(parser.end(), ''); -}); - -test('CitationParser should convert wrapped links to numbered footnotes', t => { - const parser = new CitationParser(); - - const output = parser.parse('Use ([AFFiNE](https://affine.pro)) now'); - t.is(output, 'Use [^1] now'); - t.regex( - parser.end(), - /\[\^1\]: \{"type":"url","url":"https%3A%2F%2Faffine.pro"\}/ - ); -}); - -test('chatToGPTMessage should not mutate input and should keep system schema', async t => { - const schema = z.object({ - query: z.string(), - }); - const messages = [ - { - role: 'system' as const, - content: 'You are helper', - params: { schema }, - }, - { - role: 'user' as const, - content: '', - attachments: ['https://example.com/a.png'], - }, - ]; - const firstRef = messages[0]; - const secondRef = messages[1]; - const [system, normalized, parsedSchema] = await chatToGPTMessage( - messages, - false - ); - - t.is(system, 'You are helper'); - t.is(parsedSchema, schema); - t.is(messages.length, 2); - t.is(messages[0], firstRef); - t.is(messages[1], secondRef); - t.deepEqual(normalized[0], { - role: 'user', - content: [{ type: 'text', text: '[no content]' }], - }); -}); diff --git a/packages/backend/server/src/__tests__/mocks/copilot.mock.ts b/packages/backend/server/src/__tests__/mocks/copilot.mock.ts index f10dd51049..b43dae79a4 100644 --- a/packages/backend/server/src/__tests__/mocks/copilot.mock.ts +++ b/packages/backend/server/src/__tests__/mocks/copilot.mock.ts @@ -33,7 +33,7 @@ export class MockCopilotProvider extends OpenAIProvider { id: 'test-image', capabilities: [ { - input: [ModelInputType.Text], + input: [ModelInputType.Text, ModelInputType.Image], output: [ModelOutputType.Image], defaultForOutputType: true, }, diff --git a/packages/backend/server/src/models/copilot-session.ts b/packages/backend/server/src/models/copilot-session.ts index 0cc1ee9b4a..b5d177e27b 100644 --- a/packages/backend/server/src/models/copilot-session.ts +++ b/packages/backend/server/src/models/copilot-session.ts @@ -10,6 +10,7 @@ import { CopilotSessionNotFound, } from '../base'; import { getTokenEncoder } from '../native'; +import type { PromptAttachment } from '../plugins/copilot/providers/types'; import { BaseModel } from './base'; export enum SessionType { @@ -24,7 +25,7 @@ type ChatPrompt = { model: string; }; -type ChatAttachment = { attachment: string; mimeType: string } | string; +type ChatAttachment = PromptAttachment; type ChatStreamObject = { type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result'; @@ -173,22 +174,105 @@ export class CopilotSessionModel extends BaseModel { } return attachments - .map(attachment => - typeof attachment === 'string' - ? (this.sanitizeString(attachment) ?? '') - : { - attachment: - this.sanitizeString(attachment.attachment) ?? - attachment.attachment, + .map(attachment => { + if (typeof attachment === 'string') { + return this.sanitizeString(attachment) ?? ''; + } + + if ('attachment' in attachment) { + return { + attachment: + this.sanitizeString(attachment.attachment) ?? + attachment.attachment, + mimeType: + this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, + }; + } + + switch (attachment.kind) { + case 'url': + return { + ...attachment, + url: this.sanitizeString(attachment.url) ?? attachment.url, mimeType: this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, - } - ) + fileName: + this.sanitizeString(attachment.fileName) ?? attachment.fileName, + providerHint: attachment.providerHint + ? { + provider: + this.sanitizeString(attachment.providerHint.provider) ?? + attachment.providerHint.provider, + kind: + this.sanitizeString(attachment.providerHint.kind) ?? + attachment.providerHint.kind, + } + : undefined, + }; + case 'data': + case 'bytes': + return { + ...attachment, + data: this.sanitizeString(attachment.data) ?? attachment.data, + mimeType: + this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, + fileName: + this.sanitizeString(attachment.fileName) ?? attachment.fileName, + providerHint: attachment.providerHint + ? { + provider: + this.sanitizeString(attachment.providerHint.provider) ?? + attachment.providerHint.provider, + kind: + this.sanitizeString(attachment.providerHint.kind) ?? + attachment.providerHint.kind, + } + : undefined, + }; + case 'file_handle': + return { + ...attachment, + fileHandle: + this.sanitizeString(attachment.fileHandle) ?? + attachment.fileHandle, + mimeType: + this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, + fileName: + this.sanitizeString(attachment.fileName) ?? attachment.fileName, + providerHint: attachment.providerHint + ? { + provider: + this.sanitizeString(attachment.providerHint.provider) ?? + attachment.providerHint.provider, + kind: + this.sanitizeString(attachment.providerHint.kind) ?? + attachment.providerHint.kind, + } + : undefined, + }; + } + + return attachment; + }) .filter(attachment => { if (typeof attachment === 'string') { return !!attachment; } - return !!attachment.attachment && !!attachment.mimeType; + if ('attachment' in attachment) { + return !!attachment.attachment && !!attachment.mimeType; + } + + switch (attachment.kind) { + case 'url': + return !!attachment.url; + case 'data': + case 'bytes': + return !!attachment.data && !!attachment.mimeType; + case 'file_handle': + return !!attachment.fileHandle; + } + + return false; }); } diff --git a/packages/backend/server/src/native.ts b/packages/backend/server/src/native.ts index 9e6514cdae..1eb481df31 100644 --- a/packages/backend/server/src/native.ts +++ b/packages/backend/server/src/native.ts @@ -65,6 +65,21 @@ type NativeLlmModule = { backendConfigJson: string, requestJson: string ) => string | Promise; + llmStructuredDispatch?: ( + protocol: string, + backendConfigJson: string, + requestJson: string + ) => string | Promise; + llmEmbeddingDispatch?: ( + protocol: string, + backendConfigJson: string, + requestJson: string + ) => string | Promise; + llmRerankDispatch?: ( + protocol: string, + backendConfigJson: string, + requestJson: string + ) => string | Promise; llmDispatchStream?: ( protocol: string, backendConfigJson: string, @@ -79,12 +94,20 @@ const nativeLlmModule = serverNativeModule as typeof serverNativeModule & export type NativeLlmProtocol = | 'openai_chat' | 'openai_responses' - | 'anthropic'; + | 'anthropic' + | 'gemini'; export type NativeLlmBackendConfig = { base_url: string; auth_token: string; - request_layer?: 'anthropic' | 'chat_completions' | 'responses' | 'vertex'; + request_layer?: + | 'anthropic' + | 'chat_completions' + | 'responses' + | 'vertex' + | 'vertex_anthropic' + | 'gemini_api' + | 'gemini_vertex'; headers?: Record; no_streaming?: boolean; timeout_ms?: number; @@ -100,6 +123,8 @@ export type NativeLlmCoreContent = call_id: string; name: string; arguments: Record; + arguments_text?: string; + arguments_error?: string; thought?: string; } | { @@ -109,8 +134,12 @@ export type NativeLlmCoreContent = is_error?: boolean; name?: string; arguments?: Record; + arguments_text?: string; + arguments_error?: string; } - | { type: 'image'; source: Record | string }; + | { type: 'image'; source: Record | string } + | { type: 'audio'; source: Record | string } + | { type: 'file'; source: Record | string }; export type NativeLlmCoreMessage = { role: NativeLlmCoreRole; @@ -133,22 +162,54 @@ export type NativeLlmRequest = { tool_choice?: 'auto' | 'none' | 'required' | { name: string }; include?: string[]; reasoning?: Record; + response_schema?: Record; middleware?: { request?: Array< 'normalize_messages' | 'clamp_max_tokens' | 'tool_schema_rewrite' >; stream?: Array<'stream_event_normalize' | 'citation_indexing'>; config?: { - no_additional_properties?: boolean; - drop_property_format?: boolean; - drop_property_min_length?: boolean; - drop_array_min_items?: boolean; - drop_array_max_items?: boolean; + additional_properties_policy?: 'preserve' | 'forbid'; + property_format_policy?: 'preserve' | 'drop'; + property_min_length_policy?: 'preserve' | 'drop'; + array_min_items_policy?: 'preserve' | 'drop'; + array_max_items_policy?: 'preserve' | 'drop'; max_tokens_cap?: number; }; }; }; +export type NativeLlmStructuredRequest = { + model: string; + messages: NativeLlmCoreMessage[]; + schema: Record; + max_tokens?: number; + temperature?: number; + reasoning?: Record; + strict?: boolean; + response_mime_type?: string; + middleware?: NativeLlmRequest['middleware']; +}; + +export type NativeLlmEmbeddingRequest = { + model: string; + inputs: string[]; + dimensions?: number; + task_type?: string; +}; + +export type NativeLlmRerankCandidate = { + id?: string; + text: string; +}; + +export type NativeLlmRerankRequest = { + model: string; + query: string; + candidates: NativeLlmRerankCandidate[]; + top_n?: number; +}; + export type NativeLlmDispatchResponse = { id: string; model: string; @@ -159,10 +220,39 @@ export type NativeLlmDispatchResponse = { total_tokens: number; cached_tokens?: number; }; - finish_reason: string; + finish_reason: + | 'stop' + | 'length' + | 'tool_calls' + | 'content_filter' + | 'error' + | string; reasoning_details?: unknown; }; +export type NativeLlmStructuredResponse = { + id: string; + model: string; + output_text: string; + usage: NativeLlmDispatchResponse['usage']; + finish_reason: NativeLlmDispatchResponse['finish_reason']; + reasoning_details?: unknown; +}; + +export type NativeLlmEmbeddingResponse = { + model: string; + embeddings: number[][]; + usage?: { + prompt_tokens: number; + total_tokens: number; + }; +}; + +export type NativeLlmRerankResponse = { + model: string; + scores: number[]; +}; + export type NativeLlmStreamEvent = | { type: 'message_start'; id?: string; model?: string } | { type: 'text_delta'; text: string } @@ -178,6 +268,8 @@ export type NativeLlmStreamEvent = call_id: string; name: string; arguments: Record; + arguments_text?: string; + arguments_error?: string; thought?: string; } | { @@ -187,6 +279,8 @@ export type NativeLlmStreamEvent = is_error?: boolean; name?: string; arguments?: Record; + arguments_text?: string; + arguments_error?: string; } | { type: 'citation'; index: number; url: string } | { @@ -200,7 +294,7 @@ export type NativeLlmStreamEvent = } | { type: 'done'; - finish_reason?: string; + finish_reason?: NativeLlmDispatchResponse['finish_reason']; usage?: { prompt_tokens: number; completion_tokens: number; @@ -228,6 +322,57 @@ export async function llmDispatch( return JSON.parse(responseText) as NativeLlmDispatchResponse; } +export async function llmStructuredDispatch( + protocol: NativeLlmProtocol, + backendConfig: NativeLlmBackendConfig, + request: NativeLlmStructuredRequest +): Promise { + if (!nativeLlmModule.llmStructuredDispatch) { + throw new Error('native llm structured dispatch is not available'); + } + const response = nativeLlmModule.llmStructuredDispatch( + protocol, + JSON.stringify(backendConfig), + JSON.stringify(request) + ); + const responseText = await Promise.resolve(response); + return JSON.parse(responseText) as NativeLlmStructuredResponse; +} + +export async function llmEmbeddingDispatch( + protocol: NativeLlmProtocol, + backendConfig: NativeLlmBackendConfig, + request: NativeLlmEmbeddingRequest +): Promise { + if (!nativeLlmModule.llmEmbeddingDispatch) { + throw new Error('native llm embedding dispatch is not available'); + } + const response = nativeLlmModule.llmEmbeddingDispatch( + protocol, + JSON.stringify(backendConfig), + JSON.stringify(request) + ); + const responseText = await Promise.resolve(response); + return JSON.parse(responseText) as NativeLlmEmbeddingResponse; +} + +export async function llmRerankDispatch( + protocol: NativeLlmProtocol, + backendConfig: NativeLlmBackendConfig, + request: NativeLlmRerankRequest +): Promise { + if (!nativeLlmModule.llmRerankDispatch) { + throw new Error('native llm rerank dispatch is not available'); + } + const response = nativeLlmModule.llmRerankDispatch( + protocol, + JSON.stringify(backendConfig), + JSON.stringify(request) + ); + const responseText = await Promise.resolve(response); + return JSON.parse(responseText) as NativeLlmRerankResponse; +} + export class NativeStreamAdapter implements AsyncIterableIterator { readonly #queue: T[] = []; readonly #waiters: ((result: IteratorResult) => void)[] = []; diff --git a/packages/backend/server/src/plugins/copilot/config.ts b/packages/backend/server/src/plugins/copilot/config.ts index 44ecf0244a..aa9e0feffe 100644 --- a/packages/backend/server/src/plugins/copilot/config.ts +++ b/packages/backend/server/src/plugins/copilot/config.ts @@ -81,7 +81,7 @@ export type CopilotProviderProfile = CopilotProviderProfileCommon & }[CopilotProviderType]; export type CopilotProviderDefaults = Partial< - Record + Record, string> > & { fallback?: string; }; @@ -184,6 +184,7 @@ const CopilotProviderDefaultsShape = z.object({ [ModelOutputType.Object]: z.string().optional(), [ModelOutputType.Embedding]: z.string().optional(), [ModelOutputType.Image]: z.string().optional(), + [ModelOutputType.Rerank]: z.string().optional(), [ModelOutputType.Structured]: z.string().optional(), fallback: z.string().optional(), }); diff --git a/packages/backend/server/src/plugins/copilot/embedding/client.ts b/packages/backend/server/src/plugins/copilot/embedding/client.ts index af69c2499c..34e5ccfa5f 100644 --- a/packages/backend/server/src/plugins/copilot/embedding/client.ts +++ b/packages/backend/server/src/plugins/copilot/embedding/client.ts @@ -1,25 +1,17 @@ import { Logger } from '@nestjs/common'; import type { ModuleRef } from '@nestjs/core'; -import { - Config, - CopilotPromptNotFound, - CopilotProviderNotSupported, -} from '../../../base'; +import { Config, CopilotProviderNotSupported } from '../../../base'; import { CopilotFailedToGenerateEmbedding } from '../../../base/error/errors.gen'; import { ChunkSimilarity, Embedding, EMBEDDING_DIMENSIONS, } from '../../../models'; -import { PromptService } from '../prompt/service'; import { CopilotProviderFactory } from '../providers/factory'; import type { CopilotProvider } from '../providers/provider'; import { - DEFAULT_RERANK_MODEL, - normalizeRerankModel, -} from '../providers/rerank'; -import { + type CopilotRerankRequest, type ModelFullConditions, ModelInputType, ModelOutputType, @@ -27,24 +19,20 @@ import { import { EmbeddingClient, type ReRankResult } from './types'; const EMBEDDING_MODEL = 'gemini-embedding-001'; -const RERANK_PROMPT = 'Rerank results'; - +const RERANK_MODEL = 'gpt-5.2'; class ProductionEmbeddingClient extends EmbeddingClient { private readonly logger = new Logger(ProductionEmbeddingClient.name); constructor( private readonly config: Config, - private readonly providerFactory: CopilotProviderFactory, - private readonly prompt: PromptService + private readonly providerFactory: CopilotProviderFactory ) { super(); } override async configured(): Promise { const embedding = await this.providerFactory.getProvider({ - modelId: this.config.copilot?.scenarios?.override_enabled - ? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL - : EMBEDDING_MODEL, + modelId: this.getEmbeddingModelId(), outputType: ModelOutputType.Embedding, }); const result = Boolean(embedding); @@ -69,9 +57,15 @@ class ProductionEmbeddingClient extends EmbeddingClient { return provider; } + private getEmbeddingModelId() { + return this.config.copilot?.scenarios?.override_enabled + ? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL + : EMBEDDING_MODEL; + } + async getEmbeddings(input: string[]): Promise { const provider = await this.getProvider({ - modelId: EMBEDDING_MODEL, + modelId: this.getEmbeddingModelId(), outputType: ModelOutputType.Embedding, }); this.logger.verbose( @@ -114,21 +108,22 @@ class ProductionEmbeddingClient extends EmbeddingClient { ): Promise { if (!embeddings.length) return []; - const prompt = await this.prompt.get(RERANK_PROMPT); - if (!prompt) { - throw new CopilotPromptNotFound({ name: RERANK_PROMPT }); - } - const rerankModel = normalizeRerankModel(prompt.model); - if (prompt.model !== rerankModel) { - this.logger.warn( - `Unsupported rerank model "${prompt.model}" configured, falling back to "${DEFAULT_RERANK_MODEL}".` - ); - } - const provider = await this.getProvider({ modelId: rerankModel }); + const provider = await this.getProvider({ + modelId: RERANK_MODEL, + outputType: ModelOutputType.Rerank, + }); + + const rerankRequest: CopilotRerankRequest = { + query, + candidates: embeddings.map((embedding, index) => ({ + id: String(index), + text: embedding.content, + })), + }; const ranks = await provider.rerank( - { modelId: rerankModel }, - embeddings.map(e => prompt.finish({ query, doc: e.content })), + { modelId: RERANK_MODEL }, + rerankRequest, { signal } ); @@ -227,9 +222,7 @@ export async function getEmbeddingClient( const providerFactory = moduleRef.get(CopilotProviderFactory, { strict: false, }); - const prompt = moduleRef.get(PromptService, { strict: false }); - - const client = new ProductionEmbeddingClient(config, providerFactory, prompt); + const client = new ProductionEmbeddingClient(config, providerFactory); if (await client.configured()) { EMBEDDING_CLIENT = client; } diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index 88532bc2ac..6696138fc6 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -418,21 +418,6 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr maxRetries: 1, }, }, - { - name: 'Rerank results', - action: 'Rerank results', - model: 'gpt-5.2', - messages: [ - { - role: 'system', - content: `Judge whether the Document meets the requirements based on the Query and the Instruct provided. The answer must be "yes" or "no".`, - }, - { - role: 'user', - content: `: Given a document search result, determine whether the result is relevant to the query.\n: {{query}}\n: {{doc}}`, - }, - ], - }, { name: 'Generate a caption', action: 'Generate a caption', diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts index 26f2052fa7..ab3e512882 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts @@ -1,5 +1,3 @@ -import type { ToolSet } from 'ai'; - import { CopilotProviderSideError, metrics, @@ -11,6 +9,7 @@ import { type NativeLlmRequest, } from '../../../../native'; import type { NodeTextMiddleware } from '../../config'; +import type { CopilotToolSet } from '../../tools'; import { buildNativeRequest, NativeProviderAdapter } from '../native'; import { CopilotProvider } from '../provider'; import type { @@ -20,7 +19,11 @@ import type { StreamObject, } from '../types'; import { CopilotProviderType, ModelOutputType } from '../types'; -import { getGoogleAuth, getVertexAnthropicBaseUrl } from '../utils'; +import { + getGoogleAuth, + getVertexAnthropicBaseUrl, + type VertexAnthropicProviderConfig, +} from '../utils'; export abstract class AnthropicProvider extends CopilotProvider { private handleError(e: any) { @@ -36,22 +39,16 @@ export abstract class AnthropicProvider extends CopilotProvider { private async createNativeConfig(): Promise { if (this.type === CopilotProviderType.AnthropicVertex) { - const auth = await getGoogleAuth(this.config as any, 'anthropic'); - const headers = auth.headers(); - const authorization = - headers.Authorization || - (headers as Record).authorization; - const token = - typeof authorization === 'string' - ? authorization.replace(/^Bearer\s+/i, '') - : ''; - const baseUrl = - getVertexAnthropicBaseUrl(this.config as any) || auth.baseUrl; + const config = this.config as VertexAnthropicProviderConfig; + const auth = await getGoogleAuth(config, 'anthropic'); + const { Authorization: authHeader } = auth.headers(); + const token = authHeader.replace(/^Bearer\s+/i, ''); + const baseUrl = getVertexAnthropicBaseUrl(config) || auth.baseUrl; return { base_url: baseUrl || '', auth_token: token, - request_layer: 'vertex', - headers, + request_layer: 'vertex_anthropic', + headers: { Authorization: authHeader }, }; } @@ -65,7 +62,7 @@ export abstract class AnthropicProvider extends CopilotProvider { private createAdapter( backendConfig: NativeLlmBackendConfig, - tools: ToolSet, + tools: CopilotToolSet, nodeTextMiddleware?: NodeTextMiddleware[] ) { return new NativeProviderAdapter( @@ -93,8 +90,12 @@ export abstract class AnthropicProvider extends CopilotProvider { options: CopilotChatOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); @@ -102,11 +103,13 @@ export abstract class AnthropicProvider extends CopilotProvider { const tools = await this.getTools(options, model.id); const middleware = this.getActiveProviderMiddleware(); const reasoning = this.getReasoning(options, model.id); + const cap = this.getAttachCapability(model, ModelOutputType.Text); const { request } = await buildNativeRequest({ model: model.id, messages, options, tools, + attachmentCapability: cap, reasoning, middleware, }); @@ -115,7 +118,7 @@ export abstract class AnthropicProvider extends CopilotProvider { tools, middleware.node?.text ); - return await adapter.text(request, options.signal); + return await adapter.text(request, options.signal, messages); } catch (e: any) { metrics.ai .counter('chat_text_errors') @@ -130,8 +133,12 @@ export abstract class AnthropicProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai @@ -140,11 +147,13 @@ export abstract class AnthropicProvider extends CopilotProvider { const backendConfig = await this.createNativeConfig(); const tools = await this.getTools(options, model.id); const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); const { request } = await buildNativeRequest({ model: model.id, messages, options, tools, + attachmentCapability: cap, reasoning: this.getReasoning(options, model.id), middleware, }); @@ -153,7 +162,11 @@ export abstract class AnthropicProvider extends CopilotProvider { tools, middleware.node?.text ); - for await (const chunk of adapter.streamText(request, options.signal)) { + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { yield chunk; } } catch (e: any) { @@ -170,8 +183,12 @@ export abstract class AnthropicProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Object }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai @@ -180,11 +197,13 @@ export abstract class AnthropicProvider extends CopilotProvider { const backendConfig = await this.createNativeConfig(); const tools = await this.getTools(options, model.id); const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Object); const { request } = await buildNativeRequest({ model: model.id, messages, options, tools, + attachmentCapability: cap, reasoning: this.getReasoning(options, model.id), middleware, }); @@ -193,7 +212,11 @@ export abstract class AnthropicProvider extends CopilotProvider { tools, middleware.node?.text ); - for await (const chunk of adapter.streamObject(request, options.signal)) { + for await (const chunk of adapter.streamObject( + request, + options.signal, + messages + )) { yield chunk; } } catch (e: any) { diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts index 7f4b476481..3252b32457 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts @@ -1,5 +1,6 @@ import z from 'zod'; +import { IMAGE_ATTACHMENT_CAPABILITY } from '../attachments'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import { AnthropicProvider } from './anthropic'; @@ -23,6 +24,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider { override readonly type = CopilotProviderType.AnthropicVertex; @@ -25,6 +21,7 @@ export class AnthropicVertexProvider extends AnthropicProvider; + isRemote: boolean; +}; + +function parseDataUrl(url: string) { + if (!url.startsWith('data:')) { + return null; + } + + const commaIndex = url.indexOf(','); + if (commaIndex === -1) { + return null; + } + + const meta = url.slice(5, commaIndex); + const payload = url.slice(commaIndex + 1); + const parts = meta.split(';'); + const mediaType = parts[0] || 'text/plain;charset=US-ASCII'; + const isBase64 = parts.includes('base64'); + + return { + mediaType, + data: isBase64 + ? payload + : Buffer.from(decodeURIComponent(payload), 'utf8').toString('base64'), + }; +} + +function attachmentTypeFromMediaType(mediaType: string): PromptAttachmentKind { + if (mediaType.startsWith('image/')) { + return 'image'; + } + if (mediaType.startsWith('audio/')) { + return 'audio'; + } + return 'file'; +} + +function attachmentKindFromHintOrMediaType( + hint: PromptAttachmentKind | undefined, + mediaType: string | undefined +): PromptAttachmentKind { + if (hint) return hint; + return attachmentTypeFromMediaType(mediaType || ''); +} + +function toBase64Data(data: string, encoding: 'base64' | 'utf8' = 'base64') { + return encoding === 'base64' + ? data + : Buffer.from(data, 'utf8').toString('base64'); +} + +function appendAttachMetadata( + source: Record, + attachment: Exclude & Record +) { + if (attachment.fileName) { + source.file_name = attachment.fileName; + } + if (attachment.providerHint) { + source.provider_hint = attachment.providerHint; + } + return source; +} + +export function promptAttachmentHasSource( + attachment: PromptAttachment +): boolean { + if (typeof attachment === 'string') { + return !!attachment.trim(); + } + + if ('attachment' in attachment) { + return !!attachment.attachment; + } + + switch (attachment.kind) { + case 'url': + return !!attachment.url; + case 'data': + case 'bytes': + return !!attachment.data; + case 'file_handle': + return !!attachment.fileHandle; + } +} + +export async function canonicalizePromptAttachment( + attachment: PromptAttachment, + message: Pick +): Promise { + const fallbackMimeType = + typeof message.params?.mimetype === 'string' + ? message.params.mimetype + : undefined; + + if (typeof attachment === 'string') { + const dataUrl = parseDataUrl(attachment); + const mediaType = + fallbackMimeType ?? + dataUrl?.mediaType ?? + (await inferMimeType(attachment)); + const kind = attachmentKindFromHintOrMediaType(undefined, mediaType); + if (dataUrl) { + return { + kind, + sourceKind: 'data', + mediaType, + isRemote: false, + source: { + media_type: mediaType || dataUrl.mediaType, + data: dataUrl.data, + }, + }; + } + + return { + kind, + sourceKind: 'url', + mediaType, + isRemote: /^https?:\/\//.test(attachment), + source: { url: attachment, media_type: mediaType }, + }; + } + + if ('attachment' in attachment) { + return await canonicalizePromptAttachment( + { + kind: 'url', + url: attachment.attachment, + mimeType: attachment.mimeType, + }, + message + ); + } + + if (attachment.kind === 'url') { + const dataUrl = parseDataUrl(attachment.url); + const mediaType = + attachment.mimeType ?? + fallbackMimeType ?? + dataUrl?.mediaType ?? + (await inferMimeType(attachment.url)); + const kind = attachmentKindFromHintOrMediaType( + attachment.providerHint?.kind, + mediaType + ); + if (dataUrl) { + return { + kind, + sourceKind: 'data', + mediaType, + isRemote: false, + source: appendAttachMetadata( + { media_type: mediaType || dataUrl.mediaType, data: dataUrl.data }, + attachment + ), + }; + } + + return { + kind, + sourceKind: 'url', + mediaType, + isRemote: /^https?:\/\//.test(attachment.url), + source: appendAttachMetadata( + { url: attachment.url, media_type: mediaType }, + attachment + ), + }; + } + + if (attachment.kind === 'data' || attachment.kind === 'bytes') { + return { + kind: attachmentKindFromHintOrMediaType( + attachment.providerHint?.kind, + attachment.mimeType + ), + sourceKind: attachment.kind, + mediaType: attachment.mimeType, + isRemote: false, + source: appendAttachMetadata( + { + media_type: attachment.mimeType, + data: toBase64Data( + attachment.data, + attachment.kind === 'data' ? attachment.encoding : 'base64' + ), + }, + attachment + ), + }; + } + + return { + kind: attachmentKindFromHintOrMediaType( + attachment.providerHint?.kind, + attachment.mimeType + ), + sourceKind: 'file_handle', + mediaType: attachment.mimeType, + isRemote: false, + source: appendAttachMetadata( + { file_handle: attachment.fileHandle, media_type: attachment.mimeType }, + attachment + ), + }; +} diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index 486d3e91fd..b6927a141d 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -19,6 +19,7 @@ import type { PromptMessage, } from './types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; +import { promptAttachmentMimeType, promptAttachmentToUrl } from './utils'; export type FalConfig = { apiKey: string; @@ -183,13 +184,14 @@ export class FalProvider extends CopilotProvider { return { model_name: options.modelName || undefined, image_url: attachments - ?.map(v => - typeof v === 'string' - ? v - : v.mimeType.startsWith('image/') - ? v.attachment - : undefined - ) + ?.map(v => { + const url = promptAttachmentToUrl(v); + const mediaType = promptAttachmentMimeType( + v, + typeof params?.mimetype === 'string' ? params.mimetype : undefined + ); + return url && mediaType?.startsWith('image/') ? url : undefined; + }) .find(v => !!v), prompt: content.trim(), loras: lora.length ? lora : undefined, diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts index 26ca7caa88..4a5743dd7c 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts @@ -1,87 +1,94 @@ -import type { - GoogleGenerativeAIProvider, - GoogleGenerativeAIProviderOptions, -} from '@ai-sdk/google'; -import type { GoogleVertexProvider } from '@ai-sdk/google-vertex'; -import { - AISDKError, - type EmbeddingModel, - embedMany, - generateObject, - generateText, - JSONParseError, - stepCountIs, - streamText, -} from 'ai'; +import { setTimeout as delay } from 'node:timers/promises'; + +import { ZodError } from 'zod'; import { - CopilotPromptInvalid, CopilotProviderSideError, metrics, + OneMB, + readResponseBufferWithLimit, + safeFetch, UserFriendlyError, } from '../../../../base'; +import { sniffMime } from '../../../../base/storage/providers/utils'; +import { + llmDispatchStream, + llmEmbeddingDispatch, + llmStructuredDispatch, + type NativeLlmBackendConfig, + type NativeLlmEmbeddingRequest, + type NativeLlmRequest, + type NativeLlmStructuredRequest, +} from '../../../../native'; +import type { NodeTextMiddleware } from '../../config'; +import type { CopilotToolSet } from '../../tools'; +import { + buildNativeEmbeddingRequest, + buildNativeRequest, + buildNativeStructuredRequest, + NativeProviderAdapter, + parseNativeStructuredOutput, + StructuredResponseParseError, +} from '../native'; import { CopilotProvider } from '../provider'; import type { CopilotChatOptions, CopilotEmbeddingOptions, CopilotImageOptions, - CopilotProviderModel, + CopilotStructuredOptions, ModelConditions, + PromptAttachment, PromptMessage, StreamObject, } from '../types'; import { ModelOutputType } from '../types'; -import { - chatToGPTMessage, - StreamObjectParser, - TextStreamParser, -} from '../utils'; +import { promptAttachmentMimeType, promptAttachmentToUrl } from '../utils'; export const DEFAULT_DIMENSIONS = 256; +const GEMINI_REMOTE_ATTACHMENT_MAX_BYTES = 64 * OneMB; +const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro']; +const GEMINI_RETRY_INITIAL_DELAY_MS = 2_000; + +function normalizeMimeType(mediaType?: string) { + return mediaType?.split(';', 1)[0]?.trim() || 'application/octet-stream'; +} + +function isYoutubeUrl(url: URL) { + const hostname = url.hostname.toLowerCase(); + if (hostname === 'youtu.be') { + return /^\/[\w-]+$/.test(url.pathname); + } + + if (hostname !== 'youtube.com' && hostname !== 'www.youtube.com') { + return false; + } + + if (url.pathname !== '/watch') { + return false; + } + + return !!url.searchParams.get('v'); +} + +function isGeminiFileUrl(url: URL, baseUrl: string) { + try { + const base = new URL(baseUrl); + const basePath = base.pathname.replace(/\/+$/, ''); + return ( + url.origin === base.origin && + url.pathname.startsWith(`${basePath}/files/`) + ); + } catch { + return false; + } +} export abstract class GeminiProvider extends CopilotProvider { - protected abstract instance: - | GoogleGenerativeAIProvider - | GoogleVertexProvider; - - private getThinkingConfig( - model: string, - options: { includeThoughts: boolean; useDynamicBudget?: boolean } - ): NonNullable { - if (this.isGemini3Model(model)) { - return { - includeThoughts: options.includeThoughts, - thinkingLevel: 'high', - }; - } - - return { - includeThoughts: options.includeThoughts, - thinkingBudget: options.useDynamicBudget ? -1 : 12000, - }; - } - - private getEmbeddingModel(model: string) { - const provider = this.instance as typeof this.instance & { - embeddingModel?: (modelId: string) => EmbeddingModel; - textEmbeddingModel?: (modelId: string) => EmbeddingModel; - }; - - return ( - provider.embeddingModel?.(model) ?? provider.textEmbeddingModel?.(model) - ); - } + protected abstract createNativeConfig(): Promise; private handleError(e: any) { if (e instanceof UserFriendlyError) { return e; - } else if (e instanceof AISDKError) { - this.logger.error('Throw error from ai sdk:', e); - return new CopilotProviderSideError({ - provider: this.type, - kind: e.name || 'unknown', - message: e.message, - }); } else { return new CopilotProviderSideError({ provider: this.type, @@ -91,37 +98,261 @@ export abstract class GeminiProvider extends CopilotProvider { } } + protected createNativeDispatch(backendConfig: NativeLlmBackendConfig) { + return (request: NativeLlmRequest, signal?: AbortSignal) => + llmDispatchStream('gemini', backendConfig, request, signal); + } + + protected createNativeStructuredDispatch( + backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmStructuredRequest) => + llmStructuredDispatch('gemini', backendConfig, request); + } + + protected createNativeEmbeddingDispatch( + backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmEmbeddingRequest) => + llmEmbeddingDispatch('gemini', backendConfig, request); + } + + protected createNativeAdapter( + backendConfig: NativeLlmBackendConfig, + tools: CopilotToolSet, + nodeTextMiddleware?: NodeTextMiddleware[] + ) { + return new NativeProviderAdapter( + this.createNativeDispatch(backendConfig), + tools, + this.MAX_STEPS, + { nodeTextMiddleware } + ); + } + + protected async fetchRemoteAttach(url: string, signal?: AbortSignal) { + const parsed = new URL(url); + const response = await safeFetch( + parsed, + { method: 'GET', signal }, + this.buildAttachFetchOptions(parsed) + ); + if (!response.ok) { + throw new Error( + `Failed to fetch attachment: ${response.status} ${response.statusText}` + ); + } + const buffer = await readResponseBufferWithLimit( + response, + GEMINI_REMOTE_ATTACHMENT_MAX_BYTES + ); + const headerMimeType = normalizeMimeType( + response.headers.get('content-type') || '' + ); + return { + data: buffer.toString('base64'), + mimeType: normalizeMimeType(sniffMime(buffer, headerMimeType)), + }; + } + + private buildAttachFetchOptions(url: URL) { + const baseOptions = { timeoutMs: 15_000, maxRedirects: 3 } as const; + if (!env.prod) { + return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) }; + } + + const trustedOrigins = new Set(); + const protocol = this.AFFiNEConfig.server.https ? 'https:' : 'http:'; + const port = this.AFFiNEConfig.server.port; + const isDefaultPort = + (protocol === 'https:' && port === 443) || + (protocol === 'http:' && port === 80); + + const addHostOrigin = (host: string) => { + if (!host) return; + try { + const parsed = new URL(`${protocol}//${host}`); + if (!parsed.port && !isDefaultPort) { + parsed.port = String(port); + } + trustedOrigins.add(parsed.origin); + } catch { + // ignore invalid host config entries + } + }; + + if (this.AFFiNEConfig.server.externalUrl) { + try { + trustedOrigins.add( + new URL(this.AFFiNEConfig.server.externalUrl).origin + ); + } catch { + // ignore invalid external URL + } + } + + addHostOrigin(this.AFFiNEConfig.server.host); + for (const host of this.AFFiNEConfig.server.hosts) { + addHostOrigin(host); + } + + const hostname = url.hostname.toLowerCase(); + const trustedByHost = TRUSTED_ATTACHMENT_HOST_SUFFIXES.some( + suffix => hostname === suffix || hostname.endsWith(`.${suffix}`) + ); + if (trustedOrigins.has(url.origin) || trustedByHost) { + return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) }; + } + + return baseOptions; + } + + private shouldInlineRemoteAttach(url: URL, config: NativeLlmBackendConfig) { + switch (config.request_layer) { + case 'gemini_api': + if (url.protocol !== 'http:' && url.protocol !== 'https:') return false; + return !(isGeminiFileUrl(url, config.base_url) || isYoutubeUrl(url)); + case 'gemini_vertex': + return false; + default: + return false; + } + } + + private toInlineAttach( + attachment: PromptAttachment, + mimeType: string, + data: string + ): PromptAttachment { + if (typeof attachment === 'string' || !('kind' in attachment)) { + return { kind: 'bytes', data, mimeType }; + } + + if (attachment.kind !== 'url') { + return attachment; + } + + return { + kind: 'bytes', + data, + mimeType, + fileName: attachment.fileName, + providerHint: attachment.providerHint, + }; + } + + protected async prepareMessages( + messages: PromptMessage[], + backendConfig: NativeLlmBackendConfig, + signal?: AbortSignal + ): Promise { + const prepared: PromptMessage[] = []; + + for (const message of messages) { + signal?.throwIfAborted(); + if (!Array.isArray(message.attachments) || !message.attachments.length) { + prepared.push(message); + continue; + } + + const attachments: PromptAttachment[] = []; + let changed = false; + for (const attachment of message.attachments) { + signal?.throwIfAborted(); + const rawUrl = promptAttachmentToUrl(attachment); + if (!rawUrl || rawUrl.startsWith('data:')) { + attachments.push(attachment); + continue; + } + + let parsed: URL; + try { + parsed = new URL(rawUrl); + } catch { + attachments.push(attachment); + continue; + } + + if (!this.shouldInlineRemoteAttach(parsed, backendConfig)) { + attachments.push(attachment); + continue; + } + + const declaredMimeType = promptAttachmentMimeType( + attachment, + typeof message.params?.mimetype === 'string' + ? message.params.mimetype + : undefined + ); + const downloaded = await this.fetchRemoteAttach(rawUrl, signal); + attachments.push( + this.toInlineAttach( + attachment, + declaredMimeType + ? normalizeMimeType(declaredMimeType) + : downloaded.mimeType, + downloaded.data + ) + ); + changed = true; + } + + prepared.push(changed ? { ...message, attachments } : message); + } + + return prepared; + } + + protected async waitForStructuredRetry( + delayMs: number, + signal?: AbortSignal + ) { + await delay(delayMs, undefined, signal ? { signal } : undefined); + } + async text( cond: ModelConditions, messages: PromptMessage[], options: CopilotChatOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - - const [system, msgs] = await chatToGPTMessage(messages); - - const modelInstance = this.instance(model.id); - const { text } = await generateText({ - model: modelInstance, - system, - messages: msgs, - abortSignal: options.signal, - providerOptions: { - google: this.getGeminiOptions(options, model.id), - }, - tools: await this.getTools(options, model.id), - stopWhen: stepCountIs(this.MAX_STEPS), + metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const msg = await this.prepareMessages( + messages, + backendConfig, + options.signal + ); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); + const { request } = await buildNativeRequest({ + model: model.id, + messages: msg, + options, + tools, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + middleware, }); - - if (!text) throw new Error('Failed to generate text'); - return text.trim(); + const adapter = this.createNativeAdapter( + backendConfig, + tools, + middleware.node?.text + ); + return await adapter.text(request, options.signal, messages); } catch (e: any) { - metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -129,55 +360,65 @@ export abstract class GeminiProvider extends CopilotProvider { override async structure( cond: ModelConditions, messages: PromptMessage[], - options: CopilotChatOptions = {} + options: CopilotStructuredOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Structured }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - - const [system, msgs, schema] = await chatToGPTMessage(messages); - if (!schema) { - throw new CopilotPromptInvalid('Schema is required'); - } - - const modelInstance = this.instance(model.id); - const { object } = await generateObject({ - model: modelInstance, - system, - messages: msgs, - schema, - providerOptions: { - google: { - thinkingConfig: this.getThinkingConfig(model.id, { - includeThoughts: false, - useDynamicBudget: true, - }), - }, - }, - abortSignal: options.signal, - maxRetries: options.maxRetries || 3, - experimental_repairText: async ({ text, error }) => { - if (error instanceof JSONParseError) { - // strange fixed response, temporarily replace it - const ret = text.replaceAll(/^ny\n/g, ' ').trim(); - if (ret.startsWith('```') || ret.endsWith('```')) { - return ret - .replace(/```[\w\s]+\n/g, '') - .replace(/\n```/g, '') - .trim(); - } - return ret; - } - return null; - }, + metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const msg = await this.prepareMessages( + messages, + backendConfig, + options.signal + ); + const structuredDispatch = + this.createNativeStructuredDispatch(backendConfig); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Structured); + const { request, schema } = await buildNativeStructuredRequest({ + model: model.id, + messages: msg, + options, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + responseSchema: options.schema, + middleware, }); - - return JSON.stringify(object); + const maxRetries = Math.max(options.maxRetries ?? 3, 0); + for (let attempt = 0; ; attempt++) { + try { + const response = await structuredDispatch(request); + const parsed = parseNativeStructuredOutput(response); + const validated = schema.parse(parsed); + return JSON.stringify(validated); + } catch (error) { + const isParsingError = + error instanceof StructuredResponseParseError || + error instanceof ZodError; + const retryableError = + isParsingError || !(error instanceof UserFriendlyError); + if (!retryableError || attempt >= maxRetries) { + throw error; + } + if (!isParsingError) { + await this.waitForStructuredRetry( + GEMINI_RETRY_INITIAL_DELAY_MS * 2 ** attempt, + options.signal + ); + } + } + } } catch (e: any) { - metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -188,29 +429,54 @@ export abstract class GeminiProvider extends CopilotProvider { options: CopilotChatOptions | CopilotImageOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id }); - const fullStream = await this.getFullStream(model, messages, options); - const parser = new TextStreamParser(); - for await (const chunk of fullStream) { - const result = parser.parse(chunk); - yield result; - if (options.signal?.aborted) { - await fullStream.cancel(); - break; - } - } - if (!options.signal?.aborted) { - const footnotes = parser.end(); - if (footnotes.length) { - yield `\n\n${footnotes}`; - } + metrics.ai + .counter('chat_text_stream_calls') + .add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const preparedMessages = await this.prepareMessages( + messages, + backendConfig, + options.signal + ); + const tools = await this.getTools( + options as CopilotChatOptions, + model.id + ); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); + const { request } = await buildNativeRequest({ + model: model.id, + messages: preparedMessages, + options: options as CopilotChatOptions, + tools, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + middleware, + }); + const adapter = this.createNativeAdapter( + backendConfig, + tools, + middleware.node?.text + ); + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { - metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_stream_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -221,29 +487,51 @@ export abstract class GeminiProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Object }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai .counter('chat_object_stream_calls') - .add(1, { model: model.id }); - const fullStream = await this.getFullStream(model, messages, options); - const parser = new StreamObjectParser(); - for await (const chunk of fullStream) { - const result = parser.parse(chunk); - if (result) { - yield result; - } - if (options.signal?.aborted) { - await fullStream.cancel(); - break; - } + .add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const msg = await this.prepareMessages( + messages, + backendConfig, + options.signal + ); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Object); + const { request } = await buildNativeRequest({ + model: model.id, + messages: msg, + options, + tools, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + middleware, + }); + const adapter = this.createNativeAdapter( + backendConfig, + tools, + middleware.node?.text + ); + for await (const chunk of adapter.streamObject( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { metrics.ai .counter('chat_object_stream_errors') - .add(1, { model: model.id }); + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -253,76 +541,53 @@ export abstract class GeminiProvider extends CopilotProvider { messages: string | string[], options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } ): Promise { - messages = Array.isArray(messages) ? messages : [messages]; + const values = Array.isArray(messages) ? messages : [messages]; const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; - await this.checkParams({ embeddings: messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + embeddings: values, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai .counter('generate_embedding_calls') - .add(1, { model: model.id }); - - const modelInstance = this.getEmbeddingModel(model.id); - if (!modelInstance) { - throw new Error(`Embedding model is not available for ${model.id}`); - } - - const embeddings = await Promise.allSettled( - messages.map(m => - embedMany({ - model: modelInstance, - values: [m], - maxRetries: 3, - providerOptions: { - google: { - outputDimensionality: options.dimensions || DEFAULT_DIMENSIONS, - taskType: 'RETRIEVAL_DOCUMENT', - }, - }, - }) - ) + .add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const response = await this.createNativeEmbeddingDispatch(backendConfig)( + buildNativeEmbeddingRequest({ + model: model.id, + inputs: values, + dimensions: options.dimensions || DEFAULT_DIMENSIONS, + taskType: 'RETRIEVAL_DOCUMENT', + }) ); - - return embeddings - .flatMap(e => (e.status === 'fulfilled' ? e.value.embeddings : null)) - .filter((v): v is number[] => !!v && Array.isArray(v)); + return response.embeddings; } catch (e: any) { metrics.ai .counter('generate_embedding_errors') - .add(1, { model: model.id }); + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } - private async getFullStream( - model: CopilotProviderModel, - messages: PromptMessage[], - options: CopilotChatOptions = {} - ) { - const [system, msgs] = await chatToGPTMessage(messages); - const { fullStream } = streamText({ - model: this.instance(model.id), - system, - messages: msgs, - abortSignal: options.signal, - providerOptions: { - google: this.getGeminiOptions(options, model.id), - }, - tools: await this.getTools(options, model.id), - stopWhen: stepCountIs(this.MAX_STEPS), - }); - return fullStream; - } - - private getGeminiOptions(options: CopilotChatOptions, model: string) { - const result: GoogleGenerativeAIProviderOptions = {}; - if (options?.reasoning && this.isReasoningModel(model)) { - result.thinkingConfig = this.getThinkingConfig(model, { - includeThoughts: true, - }); + protected getReasoning( + options: CopilotChatOptions | CopilotImageOptions, + model: string + ): Record | undefined { + if ( + options && + 'reasoning' in options && + options.reasoning && + this.isReasoningModel(model) + ) { + return this.isGemini3Model(model) + ? { include_thoughts: true, thinking_level: 'high' } + : { include_thoughts: true, thinking_budget: 12000 }; } - return result; + + return undefined; } private isGemini3Model(model: string) { diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts index e7bd955e18..d29c5a98ba 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts @@ -1,9 +1,7 @@ -import { - createGoogleGenerativeAI, - type GoogleGenerativeAIProvider, -} from '@ai-sdk/google'; import z from 'zod'; +import type { NativeLlmBackendConfig } from '../../../../native'; +import { GEMINI_ATTACHMENT_CAPABILITY } from '../attachments'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import { GeminiProvider } from './gemini'; @@ -29,12 +27,15 @@ export class GeminiGenerativeProvider extends GeminiProvider { + return { + base_url: ( + this.config.baseURL || + 'https://generativelanguage.googleapis.com/v1beta' + ).replace(/\/$/, ''), + auth_token: this.config.apiKey, + request_layer: 'gemini_api', + }; + } } diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts index 8e86c898a8..e9ef0735a2 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts @@ -1,14 +1,14 @@ -import { - createVertex, - type GoogleVertexProvider, - type GoogleVertexProviderSettings, -} from '@ai-sdk/google-vertex'; - +import type { NativeLlmBackendConfig } from '../../../../native'; +import { GEMINI_ATTACHMENT_CAPABILITY } from '../attachments'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; -import { getGoogleAuth, VertexModelListSchema } from '../utils'; +import { + getGoogleAuth, + VertexModelListSchema, + type VertexProviderConfig, +} from '../utils'; import { GeminiProvider } from './gemini'; -export type GeminiVertexConfig = GoogleVertexProviderSettings; +export type GeminiVertexConfig = VertexProviderConfig; export class GeminiVertexProvider extends GeminiProvider { override readonly type = CopilotProviderType.GeminiVertex; @@ -23,12 +23,15 @@ export class GeminiVertexProvider extends GeminiProvider { ModelInputType.Text, ModelInputType.Image, ModelInputType.Audio, + ModelInputType.File, ], output: [ ModelOutputType.Text, ModelOutputType.Object, ModelOutputType.Structured, ], + attachments: GEMINI_ATTACHMENT_CAPABILITY, + structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY, }, ], }, @@ -41,12 +44,15 @@ export class GeminiVertexProvider extends GeminiProvider { ModelInputType.Text, ModelInputType.Image, ModelInputType.Audio, + ModelInputType.File, ], output: [ ModelOutputType.Text, ModelOutputType.Object, ModelOutputType.Structured, ], + attachments: GEMINI_ATTACHMENT_CAPABILITY, + structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY, }, ], }, @@ -59,12 +65,15 @@ export class GeminiVertexProvider extends GeminiProvider { ModelInputType.Text, ModelInputType.Image, ModelInputType.Audio, + ModelInputType.File, ], output: [ ModelOutputType.Text, ModelOutputType.Object, ModelOutputType.Structured, ], + attachments: GEMINI_ATTACHMENT_CAPABILITY, + structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY, }, ], }, @@ -80,21 +89,13 @@ export class GeminiVertexProvider extends GeminiProvider { ], }, ]; - - protected instance!: GoogleVertexProvider; - override configured(): boolean { return !!this.config.location && !!this.config.googleAuthOptions; } - protected override setup() { - super.setup(); - this.instance = createVertex(this.config); - } - override async refreshOnlineModels() { try { - const { baseUrl, headers } = await getGoogleAuth(this.config, 'google'); + const { baseUrl, headers } = await this.resolveVertexAuth(); if (baseUrl && !this.onlineModelList.length) { const { publisherModels } = await fetch(`${baseUrl}/models`, { headers: headers(), @@ -109,4 +110,19 @@ export class GeminiVertexProvider extends GeminiProvider { this.logger.error('Failed to fetch available models', e); } } + + protected async resolveVertexAuth() { + return await getGoogleAuth(this.config, 'google'); + } + + protected override async createNativeConfig(): Promise { + const auth = await this.resolveVertexAuth(); + const { Authorization: authHeader } = auth.headers(); + + return { + base_url: auth.baseUrl || '', + auth_token: authHeader.replace(/^Bearer\s+/i, ''), + request_layer: 'gemini_vertex', + }; + } } diff --git a/packages/backend/server/src/plugins/copilot/providers/loop.ts b/packages/backend/server/src/plugins/copilot/providers/loop.ts index 1da0f35146..24e26522aa 100644 --- a/packages/backend/server/src/plugins/copilot/providers/loop.ts +++ b/packages/backend/server/src/plugins/copilot/providers/loop.ts @@ -1,4 +1,3 @@ -import type { ToolSet } from 'ai'; import { z } from 'zod'; import type { @@ -6,6 +5,11 @@ import type { NativeLlmStreamEvent, NativeLlmToolDefinition, } from '../../../native'; +import type { + CopilotTool, + CopilotToolExecuteOptions, + CopilotToolSet, +} from '../tools'; export type NativeDispatchFn = ( request: NativeLlmRequest, @@ -16,6 +20,8 @@ export type NativeToolCall = { id: string; name: string; args: Record; + rawArgumentsText?: string; + argumentParseError?: string; thought?: string; }; @@ -28,10 +34,18 @@ type ToolExecutionResult = { callId: string; name: string; args: Record; + rawArgumentsText?: string; + argumentParseError?: string; output: unknown; isError?: boolean; }; +type ParsedToolArguments = { + args: Record; + rawArgumentsText?: string; + argumentParseError?: string; +}; + export class ToolCallAccumulator { readonly #states = new Map(); @@ -51,12 +65,20 @@ export class ToolCallAccumulator { complete(event: Extract) { const state = this.#states.get(event.call_id); this.#states.delete(event.call_id); + const parsed = + event.arguments_text !== undefined || event.arguments_error !== undefined + ? { + args: event.arguments ?? {}, + rawArgumentsText: event.arguments_text ?? state?.argumentsText, + argumentParseError: event.arguments_error, + } + : event.arguments + ? this.parseArgs(event.arguments, state?.argumentsText) + : this.parseJson(state?.argumentsText ?? '{}'); return { id: event.call_id, name: event.name || state?.name || '', - args: this.parseArgs( - event.arguments ?? this.parseJson(state?.argumentsText ?? '{}') - ), + ...parsed, thought: event.thought, } satisfies NativeToolCall; } @@ -70,51 +92,61 @@ export class ToolCallAccumulator { pending.push({ id: callId, name: state.name, - args: this.parseArgs(this.parseJson(state.argumentsText)), + ...this.parseJson(state.argumentsText), }); } this.#states.clear(); return pending; } - private parseJson(jsonText: string): unknown { + private parseJson(jsonText: string): ParsedToolArguments { if (!jsonText.trim()) { - return {}; + return { args: {} }; } try { - return JSON.parse(jsonText); - } catch { - return {}; + return this.parseArgs(JSON.parse(jsonText), jsonText); + } catch (error) { + return { + args: {}, + rawArgumentsText: jsonText, + argumentParseError: + error instanceof Error + ? error.message + : 'Invalid tool arguments JSON', + }; } } - private parseArgs(value: unknown): Record { + private parseArgs( + value: unknown, + rawArgumentsText?: string + ): ParsedToolArguments { if (value && typeof value === 'object' && !Array.isArray(value)) { - return value as Record; + return { + args: value as Record, + rawArgumentsText, + }; } - return {}; + return { + args: {}, + rawArgumentsText, + argumentParseError: 'Tool arguments must be a JSON object', + }; } } export class ToolSchemaExtractor { - static extract(toolSet: ToolSet): NativeLlmToolDefinition[] { + static extract(toolSet: CopilotToolSet): NativeLlmToolDefinition[] { return Object.entries(toolSet).map(([name, tool]) => { - const unknownTool = tool as Record; - const inputSchema = - unknownTool.inputSchema ?? unknownTool.parameters ?? z.object({}); - return { name, - description: - typeof unknownTool.description === 'string' - ? unknownTool.description - : undefined, - parameters: this.toJsonSchema(inputSchema), + description: tool.description, + parameters: this.toJsonSchema(tool.inputSchema ?? z.object({})), }; }); } - private static toJsonSchema(schema: unknown): Record { + static toJsonSchema(schema: unknown): Record { if (!(schema instanceof z.ZodType)) { if (schema && typeof schema === 'object' && !Array.isArray(schema)) { return schema as Record; @@ -228,14 +260,45 @@ export class ToolSchemaExtractor { export class ToolCallLoop { constructor( private readonly dispatch: NativeDispatchFn, - private readonly tools: ToolSet, + private readonly tools: CopilotToolSet, private readonly maxSteps = 20 ) {} + private normalizeToolExecuteOptions( + signalOrOptions?: AbortSignal | CopilotToolExecuteOptions, + maybeMessages?: CopilotToolExecuteOptions['messages'] + ): CopilotToolExecuteOptions { + if ( + signalOrOptions && + typeof signalOrOptions === 'object' && + 'aborted' in signalOrOptions + ) { + return { + signal: signalOrOptions, + messages: maybeMessages, + }; + } + + if (!signalOrOptions) { + return maybeMessages ? { messages: maybeMessages } : {}; + } + + return { + ...signalOrOptions, + signal: signalOrOptions.signal, + messages: signalOrOptions.messages ?? maybeMessages, + }; + } + async *run( request: NativeLlmRequest, - signal?: AbortSignal + signalOrOptions?: AbortSignal | CopilotToolExecuteOptions, + maybeMessages?: CopilotToolExecuteOptions['messages'] ): AsyncIterableIterator { + const toolExecuteOptions = this.normalizeToolExecuteOptions( + signalOrOptions, + maybeMessages + ); const messages = request.messages.map(message => ({ ...message, content: [...message.content], @@ -253,7 +316,7 @@ export class ToolCallLoop { stream: true, messages, }, - signal + toolExecuteOptions.signal )) { switch (event.type) { case 'tool_call_delta': { @@ -291,7 +354,10 @@ export class ToolCallLoop { throw new Error('ToolCallLoop max steps reached'); } - const toolResults = await this.executeTools(toolCalls); + const toolResults = await this.executeTools( + toolCalls, + toolExecuteOptions + ); messages.push({ role: 'assistant', @@ -300,6 +366,8 @@ export class ToolCallLoop { call_id: call.id, name: call.name, arguments: call.args, + arguments_text: call.rawArgumentsText, + arguments_error: call.argumentParseError, thought: call.thought, })), }); @@ -311,6 +379,10 @@ export class ToolCallLoop { { type: 'tool_result', call_id: result.callId, + name: result.name, + arguments: result.args, + arguments_text: result.rawArgumentsText, + arguments_error: result.argumentParseError, output: result.output, is_error: result.isError, }, @@ -321,6 +393,8 @@ export class ToolCallLoop { call_id: result.callId, name: result.name, arguments: result.args, + arguments_text: result.rawArgumentsText, + arguments_error: result.argumentParseError, output: result.output, is_error: result.isError, }; @@ -328,24 +402,28 @@ export class ToolCallLoop { } } - private async executeTools(calls: NativeToolCall[]) { - return await Promise.all(calls.map(call => this.executeTool(call))); + private async executeTools( + calls: NativeToolCall[], + options: CopilotToolExecuteOptions + ) { + return await Promise.all( + calls.map(call => this.executeTool(call, options)) + ); } private async executeTool( - call: NativeToolCall + call: NativeToolCall, + options: CopilotToolExecuteOptions ): Promise { - const tool = this.tools[call.name] as - | { - execute?: (args: Record) => Promise; - } - | undefined; + const tool = this.tools[call.name] as CopilotTool | undefined; if (!tool?.execute) { return { callId: call.id, name: call.name, args: call.args, + rawArgumentsText: call.rawArgumentsText, + argumentParseError: call.argumentParseError, isError: true, output: { message: `Tool not found: ${call.name}`, @@ -353,12 +431,30 @@ export class ToolCallLoop { }; } - try { - const output = await tool.execute(call.args); + if (call.argumentParseError) { return { callId: call.id, name: call.name, args: call.args, + rawArgumentsText: call.rawArgumentsText, + argumentParseError: call.argumentParseError, + isError: true, + output: { + message: 'Invalid tool arguments JSON', + rawArguments: call.rawArgumentsText, + error: call.argumentParseError, + }, + }; + } + + try { + const output = await tool.execute(call.args, options); + return { + callId: call.id, + name: call.name, + args: call.args, + rawArgumentsText: call.rawArgumentsText, + argumentParseError: call.argumentParseError, output: output ?? null, }; } catch (error) { @@ -371,6 +467,8 @@ export class ToolCallLoop { callId: call.id, name: call.name, args: call.args, + rawArgumentsText: call.rawArgumentsText, + argumentParseError: call.argumentParseError, isError: true, output: { message: 'Tool execution failed', diff --git a/packages/backend/server/src/plugins/copilot/providers/morph.ts b/packages/backend/server/src/plugins/copilot/providers/morph.ts index 3d5d2aa961..a4298315a2 100644 --- a/packages/backend/server/src/plugins/copilot/providers/morph.ts +++ b/packages/backend/server/src/plugins/copilot/providers/morph.ts @@ -1,5 +1,3 @@ -import type { ToolSet } from 'ai'; - import { CopilotProviderSideError, metrics, @@ -11,6 +9,7 @@ import { type NativeLlmRequest, } from '../../../native'; import type { NodeTextMiddleware } from '../config'; +import type { CopilotToolSet } from '../tools'; import { buildNativeRequest, NativeProviderAdapter } from './native'; import { CopilotProvider } from './provider'; import type { @@ -86,7 +85,7 @@ export class MorphProvider extends CopilotProvider { } private createNativeAdapter( - tools: ToolSet, + tools: CopilotToolSet, nodeTextMiddleware?: NodeTextMiddleware[] ) { return new NativeProviderAdapter( @@ -108,12 +107,14 @@ export class MorphProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotChatOptions = {} ): Promise { - const fullCond = { - ...cond, - outputType: ModelOutputType.Text, - }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const fullCond = { ...cond, outputType: ModelOutputType.Text }; + const model = this.selectModel( + await this.checkParams({ + messages, + cond: fullCond, + options, + }) + ); try { metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); @@ -127,7 +128,7 @@ export class MorphProvider extends CopilotProvider { middleware, }); const adapter = this.createNativeAdapter(tools, middleware.node?.text); - return await adapter.text(request, options.signal); + return await adapter.text(request, options.signal, messages); } catch (e: any) { metrics.ai .counter('chat_text_errors') @@ -141,12 +142,14 @@ export class MorphProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotChatOptions = {} ): AsyncIterable { - const fullCond = { - ...cond, - outputType: ModelOutputType.Text, - }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const fullCond = { ...cond, outputType: ModelOutputType.Text }; + const model = this.selectModel( + await this.checkParams({ + messages, + cond: fullCond, + options, + }) + ); try { metrics.ai @@ -162,7 +165,11 @@ export class MorphProvider extends CopilotProvider { middleware, }); const adapter = this.createNativeAdapter(tools, middleware.node?.text); - for await (const chunk of adapter.streamText(request, options.signal)) { + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { yield chunk; } } catch (e: any) { diff --git a/packages/backend/server/src/plugins/copilot/providers/native.ts b/packages/backend/server/src/plugins/copilot/providers/native.ts index 0d79a1a7c8..355e68d735 100644 --- a/packages/backend/server/src/plugins/copilot/providers/native.ts +++ b/packages/backend/server/src/plugins/copilot/providers/native.ts @@ -1,31 +1,41 @@ -import type { ToolSet } from 'ai'; import { ZodType } from 'zod'; +import { CopilotPromptInvalid } from '../../../base'; import type { NativeLlmCoreContent, NativeLlmCoreMessage, + NativeLlmEmbeddingRequest, NativeLlmRequest, NativeLlmStreamEvent, + NativeLlmStructuredRequest, + NativeLlmStructuredResponse, } from '../../../native'; import type { NodeTextMiddleware, ProviderMiddlewareConfig } from '../config'; -import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop'; -import type { CopilotChatOptions, PromptMessage, StreamObject } from './types'; +import type { CopilotToolSet } from '../tools'; import { - CitationFootnoteFormatter, - inferMimeType, - TextStreamParser, -} from './utils'; - -const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/; + canonicalizePromptAttachment, + type CanonicalPromptAttachment, +} from './attachments'; +import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop'; +import type { + CopilotChatOptions, + CopilotStructuredOptions, + ModelAttachmentCapability, + PromptMessage, + StreamObject, +} from './types'; +import { CitationFootnoteFormatter, TextStreamParser } from './utils'; type BuildNativeRequestOptions = { model: string; messages: PromptMessage[]; - options?: CopilotChatOptions; - tools?: ToolSet; + options?: CopilotChatOptions | CopilotStructuredOptions; + tools?: CopilotToolSet; withAttachment?: boolean; + attachmentCapability?: ModelAttachmentCapability; include?: string[]; reasoning?: Record; + responseSchema?: unknown; middleware?: ProviderMiddlewareConfig; }; @@ -34,6 +44,11 @@ type BuildNativeRequestResult = { schema?: ZodType; }; +type BuildNativeStructuredRequestResult = { + request: NativeLlmStructuredRequest; + schema: ZodType; +}; + type ToolCallMeta = { name: string; args: Record; @@ -68,9 +83,121 @@ function roleToCore(role: PromptMessage['role']) { } } +function ensureAttachmentSupported( + attachment: CanonicalPromptAttachment, + attachmentCapability?: ModelAttachmentCapability +) { + if (!attachmentCapability) return; + + if (!attachmentCapability.kinds.includes(attachment.kind)) { + throw new CopilotPromptInvalid( + `Native path does not support ${attachment.kind} attachments${ + attachment.mediaType ? ` (${attachment.mediaType})` : '' + }` + ); + } + + if ( + attachmentCapability.sourceKinds?.length && + !attachmentCapability.sourceKinds.includes(attachment.sourceKind) + ) { + throw new CopilotPromptInvalid( + `Native path does not support ${attachment.sourceKind} attachment sources` + ); + } + + if (attachment.isRemote && attachmentCapability.allowRemoteUrls === false) { + throw new CopilotPromptInvalid( + 'Native path does not support remote attachment urls' + ); + } +} + +function resolveResponseSchema( + systemMessage: PromptMessage | undefined, + responseSchema?: unknown +): ZodType | undefined { + if (responseSchema instanceof ZodType) { + return responseSchema; + } + + if (systemMessage?.responseFormat?.schema instanceof ZodType) { + return systemMessage.responseFormat.schema; + } + + return systemMessage?.params?.schema instanceof ZodType + ? systemMessage.params.schema + : undefined; +} + +function resolveResponseStrict( + systemMessage: PromptMessage | undefined, + options?: CopilotStructuredOptions +) { + return options?.strict ?? systemMessage?.responseFormat?.strict ?? true; +} + +export class StructuredResponseParseError extends Error {} + +function normalizeStructuredText(text: string) { + const trimmed = text.replaceAll(/^ny\n/g, ' ').trim(); + if (trimmed.startsWith('```') || trimmed.endsWith('```')) { + return trimmed + .replace(/```[\w\s-]*\n/g, '') + .replace(/\n```/g, '') + .trim(); + } + return trimmed; +} + +export function parseNativeStructuredOutput( + response: Pick & { + output_json?: unknown; + } +) { + if (response.output_json !== undefined) { + return response.output_json; + } + + const normalized = normalizeStructuredText(response.output_text); + const candidates = [ + () => normalized, + () => { + const objectStart = normalized.indexOf('{'); + const objectEnd = normalized.lastIndexOf('}'); + return objectStart !== -1 && objectEnd > objectStart + ? normalized.slice(objectStart, objectEnd + 1) + : null; + }, + () => { + const arrayStart = normalized.indexOf('['); + const arrayEnd = normalized.lastIndexOf(']'); + return arrayStart !== -1 && arrayEnd > arrayStart + ? normalized.slice(arrayStart, arrayEnd + 1) + : null; + }, + ]; + + for (const candidate of candidates) { + try { + const candidateText = candidate(); + if (typeof candidateText === 'string') { + return JSON.parse(candidateText); + } + } catch { + continue; + } + } + + throw new StructuredResponseParseError( + `Unexpected structured response: ${normalized.slice(0, 200)}` + ); +} + async function toCoreContents( message: PromptMessage, - withAttachment: boolean + withAttachment: boolean, + attachmentCapability?: ModelAttachmentCapability ): Promise { const contents: NativeLlmCoreContent[] = []; @@ -81,24 +208,12 @@ async function toCoreContents( if (!withAttachment || !Array.isArray(message.attachments)) return contents; for (const entry of message.attachments) { - let attachmentUrl: string; - let mediaType: string; - - if (typeof entry === 'string') { - attachmentUrl = entry; - mediaType = - typeof message.params?.mimetype === 'string' - ? message.params.mimetype - : await inferMimeType(entry); - } else { - attachmentUrl = entry.attachment; - mediaType = entry.mimeType; - } - - if (!SIMPLE_IMAGE_URL_REGEX.test(attachmentUrl)) continue; - if (!mediaType.startsWith('image/')) continue; - - contents.push({ type: 'image', source: { url: attachmentUrl } }); + const normalized = await canonicalizePromptAttachment(entry, message); + ensureAttachmentSupported(normalized, attachmentCapability); + contents.push({ + type: normalized.kind, + source: normalized.source, + }); } return contents; @@ -110,8 +225,10 @@ export async function buildNativeRequest({ options = {}, tools = {}, withAttachment = true, + attachmentCapability, include, reasoning, + responseSchema, middleware, }: BuildNativeRequestOptions): Promise { const copiedMessages = messages.map(message => ({ @@ -123,10 +240,7 @@ export async function buildNativeRequest({ const systemMessage = copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined; - const schema = - systemMessage?.params?.schema instanceof ZodType - ? systemMessage.params.schema - : undefined; + const schema = resolveResponseSchema(systemMessage, responseSchema); const coreMessages: NativeLlmCoreMessage[] = []; if (systemMessage?.content?.length) { @@ -138,7 +252,11 @@ export async function buildNativeRequest({ for (const message of copiedMessages) { if (message.role === 'system') continue; - const content = await toCoreContents(message, withAttachment); + const content = await toCoreContents( + message, + withAttachment, + attachmentCapability + ); coreMessages.push({ role: roleToCore(message.role), content }); } @@ -153,6 +271,9 @@ export async function buildNativeRequest({ tool_choice: Object.keys(tools).length ? 'auto' : undefined, include, reasoning, + response_schema: schema + ? ToolSchemaExtractor.toJsonSchema(schema) + : undefined, middleware: middleware?.rust ? { request: middleware.rust.request, stream: middleware.rust.stream } : undefined, @@ -161,6 +282,90 @@ export async function buildNativeRequest({ }; } +export async function buildNativeStructuredRequest({ + model, + messages, + options = {}, + withAttachment = true, + attachmentCapability, + reasoning, + responseSchema, + middleware, +}: Omit< + BuildNativeRequestOptions, + 'tools' | 'include' +>): Promise { + const copiedMessages = messages.map(message => ({ + ...message, + attachments: message.attachments + ? [...message.attachments] + : message.attachments, + })); + + const systemMessage = + copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined; + const schema = resolveResponseSchema(systemMessage, responseSchema); + const strict = resolveResponseStrict(systemMessage, options); + + if (!schema) { + throw new CopilotPromptInvalid('Schema is required'); + } + + const coreMessages: NativeLlmCoreMessage[] = []; + if (systemMessage?.content?.length) { + coreMessages.push({ + role: 'system', + content: [{ type: 'text', text: systemMessage.content }], + }); + } + + for (const message of copiedMessages) { + if (message.role === 'system') continue; + const content = await toCoreContents( + message, + withAttachment, + attachmentCapability + ); + coreMessages.push({ role: roleToCore(message.role), content }); + } + + return { + request: { + model, + messages: coreMessages, + schema: ToolSchemaExtractor.toJsonSchema(schema), + max_tokens: options.maxTokens ?? undefined, + temperature: options.temperature ?? undefined, + reasoning, + strict, + response_mime_type: 'application/json', + middleware: middleware?.rust + ? { request: middleware.rust.request } + : undefined, + }, + schema, + }; +} + +export function buildNativeEmbeddingRequest({ + model, + inputs, + dimensions, + taskType = 'RETRIEVAL_DOCUMENT', +}: { + model: string; + inputs: string[]; + dimensions?: number; + taskType?: string; +}): NativeLlmEmbeddingRequest { + return { + model, + inputs, + dimensions, + task_type: taskType, + }; +} + function ensureToolResultMeta( event: Extract, toolCalls: Map @@ -244,7 +449,7 @@ export class NativeProviderAdapter { constructor( dispatch: NativeDispatchFn, - tools: ToolSet, + tools: CopilotToolSet, maxSteps = 20, options: NativeProviderAdapterOptions = {} ) { @@ -259,9 +464,13 @@ export class NativeProviderAdapter { enabledNodeTextMiddlewares.has('citation_footnote'); } - async text(request: NativeLlmRequest, signal?: AbortSignal) { + async text( + request: NativeLlmRequest, + signal?: AbortSignal, + messages?: PromptMessage[] + ) { let output = ''; - for await (const chunk of this.streamText(request, signal)) { + for await (const chunk of this.streamText(request, signal, messages)) { output += chunk; } return output.trim(); @@ -269,7 +478,8 @@ export class NativeProviderAdapter { async *streamText( request: NativeLlmRequest, - signal?: AbortSignal + signal?: AbortSignal, + messages?: PromptMessage[] ): AsyncIterableIterator { const textParser = this.#enableCallout ? new TextStreamParser() : null; const citationFormatter = this.#enableCitationFootnote @@ -278,7 +488,7 @@ export class NativeProviderAdapter { const toolCalls = new Map(); let streamPartId = 0; - for await (const event of this.#loop.run(request, signal)) { + for await (const event of this.#loop.run(request, signal, messages)) { switch (event.type) { case 'text_delta': { if (textParser) { @@ -364,7 +574,8 @@ export class NativeProviderAdapter { async *streamObject( request: NativeLlmRequest, - signal?: AbortSignal + signal?: AbortSignal, + messages?: PromptMessage[] ): AsyncIterableIterator { const toolCalls = new Map(); const citationFormatter = this.#enableCitationFootnote @@ -373,7 +584,7 @@ export class NativeProviderAdapter { const fallbackAttachmentFootnotes = new Map(); let hasFootnoteReference = false; - for await (const event of this.#loop.run(request, signal)) { + for await (const event of this.#loop.run(request, signal, messages)) { switch (event.type) { case 'text_delta': { if (event.text.includes('[^')) { diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 34f1c7a29e..318d0845e4 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -1,4 +1,3 @@ -import type { Tool, ToolSet } from 'ai'; import { z } from 'zod'; import { @@ -12,30 +11,41 @@ import { } from '../../../base'; import { llmDispatchStream, + llmEmbeddingDispatch, + llmRerankDispatch, + llmStructuredDispatch, type NativeLlmBackendConfig, + type NativeLlmEmbeddingRequest, type NativeLlmRequest, + type NativeLlmRerankRequest, + type NativeLlmRerankResponse, + type NativeLlmStructuredRequest, } from '../../../native'; import type { NodeTextMiddleware } from '../config'; -import { buildNativeRequest, NativeProviderAdapter } from './native'; -import { CopilotProvider } from './provider'; +import type { CopilotTool, CopilotToolSet } from '../tools'; +import { IMAGE_ATTACHMENT_CAPABILITY } from './attachments'; import { - normalizeRerankModel, - OPENAI_RERANK_MAX_COMPLETION_TOKENS, - OPENAI_RERANK_TOP_LOGPROBS_LIMIT, - usesRerankReasoning, -} from './rerank'; + buildNativeEmbeddingRequest, + buildNativeRequest, + buildNativeStructuredRequest, + NativeProviderAdapter, + parseNativeStructuredOutput, +} from './native'; +import { CopilotProvider } from './provider'; import type { CopilotChatOptions, CopilotChatTools, CopilotEmbeddingOptions, CopilotImageOptions, + CopilotRerankRequest, CopilotStructuredOptions, + ModelCapability, ModelConditions, PromptMessage, StreamObject, } from './types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; -import { chatToGPTMessage } from './utils'; +import { promptAttachmentToUrl } from './utils'; export const DEFAULT_DIMENSIONS = 256; @@ -91,19 +101,6 @@ const ImageResponseSchema = z.union([ }), }), ]); -const LogProbsSchema = z.array( - z.object({ - token: z.string(), - logprob: z.number(), - top_logprobs: z.array( - z.object({ - token: z.string(), - logprob: z.number(), - }) - ), - }) -); - const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro']; function normalizeImageFormatToMime(format?: string) { @@ -136,6 +133,34 @@ function normalizeImageResponseData( .filter((value): value is string => typeof value === 'string'); } +function buildOpenAIRerankRequest( + model: string, + request: CopilotRerankRequest +): NativeLlmRerankRequest { + return { + model, + query: request.query, + candidates: request.candidates.map(candidate => ({ + ...(candidate.id ? { id: candidate.id } : {}), + text: candidate.text, + })), + ...(request.topK ? { top_n: request.topK } : {}), + }; +} + +function createOpenAIMultimodalCapability( + output: ModelCapability['output'], + options: Pick = {} +): ModelCapability { + return { + input: [ModelInputType.Text, ModelInputType.Image], + output, + attachments: IMAGE_ATTACHMENT_CAPABILITY, + structuredAttachments: IMAGE_ATTACHMENT_CAPABILITY, + ...options, + }; +} + export class OpenAIProvider extends CopilotProvider { readonly type = CopilotProviderType.OpenAI; @@ -145,10 +170,10 @@ export class OpenAIProvider extends CopilotProvider { name: 'GPT 4o', id: 'gpt-4o', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, // FIXME(@darkskygit): deprecated @@ -156,20 +181,20 @@ export class OpenAIProvider extends CopilotProvider { name: 'GPT 4o 2024-08-06', id: 'gpt-4o-2024-08-06', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, { name: 'GPT 4o Mini', id: 'gpt-4o-mini', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, // FIXME(@darkskygit): deprecated @@ -177,181 +202,158 @@ export class OpenAIProvider extends CopilotProvider { name: 'GPT 4o Mini 2024-07-18', id: 'gpt-4o-mini-2024-07-18', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, { name: 'GPT 4.1', id: 'gpt-4.1', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ + createOpenAIMultimodalCapability( + [ ModelOutputType.Text, ModelOutputType.Object, + ModelOutputType.Rerank, ModelOutputType.Structured, ], - defaultForOutputType: true, - }, + { defaultForOutputType: true } + ), ], }, { name: 'GPT 4.1 2025-04-14', id: 'gpt-4.1-2025-04-14', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Rerank, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 4.1 Mini', id: 'gpt-4.1-mini', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Rerank, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 4.1 Nano', id: 'gpt-4.1-nano', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Rerank, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5', id: 'gpt-5', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5 2025-08-07', id: 'gpt-5-2025-08-07', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5 Mini', id: 'gpt-5-mini', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5.2', id: 'gpt-5.2', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Rerank, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5.2 2025-12-11', id: 'gpt-5.2-2025-12-11', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5 Nano', id: 'gpt-5-nano', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT O1', id: 'o1', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, { name: 'GPT O3', id: 'o3', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, { name: 'GPT O4 Mini', id: 'o4-mini', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, // Embedding models @@ -387,11 +389,9 @@ export class OpenAIProvider extends CopilotProvider { { id: 'gpt-image-1', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Image], + createOpenAIMultimodalCapability([ModelOutputType.Image], { defaultForOutputType: true, - }, + }), ], }, ]; @@ -437,7 +437,7 @@ export class OpenAIProvider extends CopilotProvider { override getProviderSpecificTools( toolName: CopilotChatTools, _model: string - ): [string, Tool?] | undefined { + ): [string, CopilotTool?] | undefined { if (toolName === 'docEdit') { return ['doc_edit', undefined]; } @@ -452,14 +452,18 @@ export class OpenAIProvider extends CopilotProvider { }; } + private getNativeProtocol() { + return this.config.oldApiStyle ? 'openai_chat' : 'openai_responses'; + } + private createNativeAdapter( - tools: ToolSet, + tools: CopilotToolSet, nodeTextMiddleware?: NodeTextMiddleware[] ) { return new NativeProviderAdapter( (request: NativeLlmRequest, signal?: AbortSignal) => llmDispatchStream( - this.config.oldApiStyle ? 'openai_chat' : 'openai_responses', + this.getNativeProtocol(), this.createNativeConfig(), request, signal @@ -470,6 +474,27 @@ export class OpenAIProvider extends CopilotProvider { ); } + protected createNativeStructuredDispatch( + backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmStructuredRequest) => + llmStructuredDispatch(this.getNativeProtocol(), backendConfig, request); + } + + protected createNativeEmbeddingDispatch( + backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmEmbeddingRequest) => + llmEmbeddingDispatch(this.getNativeProtocol(), backendConfig, request); + } + + protected createNativeRerankDispatch(backendConfig: NativeLlmBackendConfig) { + return ( + request: NativeLlmRerankRequest + ): Promise => + llmRerankDispatch('openai_chat', backendConfig, request); + } + private getReasoning( options: NonNullable, model: string @@ -486,13 +511,18 @@ export class OpenAIProvider extends CopilotProvider { options: CopilotChatOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + messages, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); const tools = await this.getTools(options, model.id); const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); const normalizedOptions = normalizeOpenAIOptionsForModel( options, model.id @@ -502,12 +532,13 @@ export class OpenAIProvider extends CopilotProvider { messages, options: normalizedOptions, tools, + attachmentCapability: cap, include: options.webSearch ? ['citations'] : undefined, reasoning: this.getReasoning(options, model.id), middleware, }); const adapter = this.createNativeAdapter(tools, middleware.node?.text); - return await adapter.text(request, options.signal); + return await adapter.text(request, options.signal, messages); } catch (e: any) { metrics.ai .counter('chat_text_errors') @@ -525,8 +556,12 @@ export class OpenAIProvider extends CopilotProvider { ...cond, outputType: ModelOutputType.Text, }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + messages, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai @@ -534,6 +569,7 @@ export class OpenAIProvider extends CopilotProvider { .add(1, this.metricLabels(model.id)); const tools = await this.getTools(options, model.id); const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); const normalizedOptions = normalizeOpenAIOptionsForModel( options, model.id @@ -543,12 +579,17 @@ export class OpenAIProvider extends CopilotProvider { messages, options: normalizedOptions, tools, + attachmentCapability: cap, include: options.webSearch ? ['citations'] : undefined, reasoning: this.getReasoning(options, model.id), middleware, }); const adapter = this.createNativeAdapter(tools, middleware.node?.text); - for await (const chunk of adapter.streamText(request, options.signal)) { + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { yield chunk; } } catch (e: any) { @@ -565,8 +606,12 @@ export class OpenAIProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Object }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai @@ -574,6 +619,7 @@ export class OpenAIProvider extends CopilotProvider { .add(1, this.metricLabels(model.id)); const tools = await this.getTools(options, model.id); const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Object); const normalizedOptions = normalizeOpenAIOptionsForModel( options, model.id @@ -583,12 +629,17 @@ export class OpenAIProvider extends CopilotProvider { messages, options: normalizedOptions, tools, + attachmentCapability: cap, include: options.webSearch ? ['citations'] : undefined, reasoning: this.getReasoning(options, model.id), middleware, }); const adapter = this.createNativeAdapter(tools, middleware.node?.text); - for await (const chunk of adapter.streamObject(request, options.signal)) { + for await (const chunk of adapter.streamObject( + request, + options.signal, + messages + )) { yield chunk; } } catch (e: any) { @@ -605,31 +656,34 @@ export class OpenAIProvider extends CopilotProvider { options: CopilotStructuredOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Structured }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + messages, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - const tools = await this.getTools(options, model.id); + const backendConfig = this.createNativeConfig(); const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Structured); const normalizedOptions = normalizeOpenAIOptionsForModel( options, model.id ); - const { request, schema } = await buildNativeRequest({ + const { request, schema } = await buildNativeStructuredRequest({ model: model.id, messages, options: normalizedOptions, - tools, + attachmentCapability: cap, reasoning: this.getReasoning(options, model.id), + responseSchema: options.schema, middleware, }); - if (!schema) { - throw new CopilotPromptInvalid('Schema is required'); - } - const adapter = this.createNativeAdapter(tools, middleware.node?.text); - const text = await adapter.text(request, options.signal); - const parsed = JSON.parse(text); + const response = + await this.createNativeStructuredDispatch(backendConfig)(request); + const parsed = parseNativeStructuredOutput(response); const validated = schema.parse(parsed); return JSON.stringify(validated); } catch (e: any) { @@ -640,71 +694,26 @@ export class OpenAIProvider extends CopilotProvider { override async rerank( cond: ModelConditions, - chunkMessages: PromptMessage[][], + request: CopilotRerankRequest, options: CopilotChatOptions = {} ): Promise { - const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ messages: [], cond: fullCond, options }); - const model = this.selectModel(fullCond); + const fullCond = { ...cond, outputType: ModelOutputType.Rerank }; + const normalizedCond = await this.checkParams({ + messages: [], + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); - const scores = await Promise.all( - chunkMessages.map(async messages => { - const [system, msgs] = await chatToGPTMessage(messages); - const rerankModel = normalizeRerankModel(model.id); - const response = await this.requestOpenAIJson( - '/chat/completions', - { - model: rerankModel, - messages: this.toOpenAIChatMessages(system, msgs), - temperature: 0, - logprobs: true, - top_logprobs: OPENAI_RERANK_TOP_LOGPROBS_LIMIT, - ...(usesRerankReasoning(rerankModel) - ? { - reasoning_effort: 'none' as const, - max_completion_tokens: OPENAI_RERANK_MAX_COMPLETION_TOKENS, - } - : { max_tokens: OPENAI_RERANK_MAX_COMPLETION_TOKENS }), - }, - options.signal - ); - - const logprobs = response?.choices?.[0]?.logprobs?.content; - if (!Array.isArray(logprobs) || logprobs.length === 0) { - return 0; - } - - const parsedLogprobs = LogProbsSchema.parse(logprobs); - const topMap = parsedLogprobs[0].top_logprobs.reduce( - (acc, { token, logprob }) => ({ ...acc, [token]: logprob }), - {} as Record - ); - - const findLogProb = (token: string): number => { - // OpenAI often includes a leading space, so try matching '.yes', '_yes', ' yes' and 'yes' - return [...'_:. "-\t,(=_“'.split('').map(c => c + token), token] - .flatMap(v => [v, v.toLowerCase(), v.toUpperCase()]) - .reduce( - (best, key) => - (topMap[key] ?? Number.NEGATIVE_INFINITY) > best - ? topMap[key] - : best, - Number.NEGATIVE_INFINITY - ); - }; - - const logYes = findLogProb('Yes'); - const logNo = findLogProb('No'); - - const pYes = Math.exp(logYes); - const pNo = Math.exp(logNo); - const prob = pYes + pNo === 0 ? 0 : pYes / (pYes + pNo); - - return prob; - }) - ); - - return scores; + try { + const backendConfig = this.createNativeConfig(); + const nativeRequest = buildOpenAIRerankRequest(model.id, request); + const response = + await this.createNativeRerankDispatch(backendConfig)(nativeRequest); + return response.scores; + } catch (e: any) { + throw this.handleError(e); + } } // ====== text to image ====== @@ -906,7 +915,8 @@ export class OpenAIProvider extends CopilotProvider { form.set('output_format', outputFormat); for (const [idx, entry] of attachments.entries()) { - const url = typeof entry === 'string' ? entry : entry.attachment; + const url = promptAttachmentToUrl(entry); + if (!url) continue; try { const attachment = await this.fetchImage(url, maxBytes, signal); if (!attachment) continue; @@ -964,8 +974,12 @@ export class OpenAIProvider extends CopilotProvider { options: CopilotImageOptions = {} ) { const fullCond = { ...cond, outputType: ModelOutputType.Image }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + messages, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); metrics.ai .counter('generate_images_stream_calls') @@ -1017,65 +1031,36 @@ export class OpenAIProvider extends CopilotProvider { messages: string | string[], options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } ): Promise { - messages = Array.isArray(messages) ? messages : [messages]; + const input = Array.isArray(messages) ? messages : [messages]; const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; - await this.checkParams({ embeddings: messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + embeddings: input, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai .counter('generate_embedding_calls') - .add(1, { model: model.id }); - const response = await this.requestOpenAIJson('/embeddings', { - model: model.id, - input: messages, - dimensions: options.dimensions || DEFAULT_DIMENSIONS, - }); - const data = Array.isArray(response?.data) ? response.data : []; - return data - .map((item: any) => item?.embedding) - .filter((embedding: unknown) => Array.isArray(embedding)) as number[][]; + .add(1, this.metricLabels(model.id)); + const backendConfig = this.createNativeConfig(); + const response = await this.createNativeEmbeddingDispatch(backendConfig)( + buildNativeEmbeddingRequest({ + model: model.id, + inputs: input, + dimensions: options.dimensions || DEFAULT_DIMENSIONS, + }) + ); + return response.embeddings; } catch (e: any) { metrics.ai .counter('generate_embedding_errors') - .add(1, { model: model.id }); + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } - private toOpenAIChatMessages( - system: string | undefined, - messages: Awaited>[1] - ) { - const result: Array<{ role: string; content: string }> = []; - if (system) { - result.push({ role: 'system', content: system }); - } - - for (const message of messages) { - if (typeof message.content === 'string') { - result.push({ role: message.role, content: message.content }); - continue; - } - - const text = message.content - .filter( - part => - part && - typeof part === 'object' && - 'type' in part && - part.type === 'text' && - 'text' in part - ) - .map(part => String((part as { text: string }).text)) - .join('\n'); - - result.push({ role: message.role, content: text || '[no content]' }); - } - - return result; - } - private async requestOpenAIJson( path: string, body: Record, diff --git a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts index f1dd4d6663..0bfd030779 100644 --- a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts +++ b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts @@ -1,5 +1,3 @@ -import type { ToolSet } from 'ai'; - import { CopilotProviderSideError, metrics } from '../../../base'; import { llmDispatchStream, @@ -7,6 +5,7 @@ import { type NativeLlmRequest, } from '../../../native'; import type { NodeTextMiddleware } from '../config'; +import type { CopilotToolSet } from '../tools'; import { buildNativeRequest, NativeProviderAdapter } from './native'; import { CopilotProvider } from './provider'; import { @@ -87,7 +86,7 @@ export class PerplexityProvider extends CopilotProvider { } private createNativeAdapter( - tools: ToolSet, + tools: CopilotToolSet, nodeTextMiddleware?: NodeTextMiddleware[] ) { return new NativeProviderAdapter( @@ -110,8 +109,13 @@ export class PerplexityProvider extends CopilotProvider { options: CopilotChatOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + withAttachment: false, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); @@ -128,7 +132,7 @@ export class PerplexityProvider extends CopilotProvider { middleware, }); const adapter = this.createNativeAdapter(tools, middleware.node?.text); - return await adapter.text(request, options.signal); + return await adapter.text(request, options.signal, messages); } catch (e: any) { metrics.ai .counter('chat_text_errors') @@ -143,8 +147,13 @@ export class PerplexityProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + withAttachment: false, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai @@ -163,7 +172,11 @@ export class PerplexityProvider extends CopilotProvider { middleware, }); const adapter = this.createNativeAdapter(tools, middleware.node?.text); - for await (const chunk of adapter.streamText(request, options.signal)) { + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { yield chunk; } } catch (e: any) { diff --git a/packages/backend/server/src/plugins/copilot/providers/provider-middleware.ts b/packages/backend/server/src/plugins/copilot/providers/provider-middleware.ts index 98dc64cefc..d43d30c5bc 100644 --- a/packages/backend/server/src/plugins/copilot/providers/provider-middleware.ts +++ b/packages/backend/server/src/plugins/copilot/providers/provider-middleware.ts @@ -51,13 +51,21 @@ const DEFAULT_MIDDLEWARE_BY_TYPE: Record< }, }, [CopilotProviderType.Gemini]: { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, node: { - text: ['callout'], + text: ['citation_footnote', 'callout'], }, }, [CopilotProviderType.GeminiVertex]: { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, node: { - text: ['callout'], + text: ['citation_footnote', 'callout'], }, }, [CopilotProviderType.FAL]: {}, diff --git a/packages/backend/server/src/plugins/copilot/providers/provider-registry.ts b/packages/backend/server/src/plugins/copilot/providers/provider-registry.ts index a877bd4cff..73d1bbdc8b 100644 --- a/packages/backend/server/src/plugins/copilot/providers/provider-registry.ts +++ b/packages/backend/server/src/plugins/copilot/providers/provider-registry.ts @@ -5,7 +5,7 @@ import type { ProviderMiddlewareConfig, } from '../config'; import { resolveProviderMiddleware } from './provider-middleware'; -import { CopilotProviderType, type ModelOutputType } from './types'; +import { CopilotProviderType, ModelOutputType } from './types'; const PROVIDER_ID_PATTERN = /^[a-zA-Z0-9-_]+$/; @@ -239,8 +239,13 @@ export function resolveModel({ }; } + const defaultProviderId = + outputType && outputType !== ModelOutputType.Rerank + ? registry.defaults[outputType] + : undefined; + const fallbackOrder = [ - ...(outputType ? [registry.defaults[outputType]] : []), + ...(defaultProviderId ? [defaultProviderId] : []), registry.defaults.fallback, ...registry.order, ].filter((id): id is string => !!id); diff --git a/packages/backend/server/src/plugins/copilot/providers/provider.ts b/packages/backend/server/src/plugins/copilot/providers/provider.ts index 8d594817ce..5ffab6818b 100644 --- a/packages/backend/server/src/plugins/copilot/providers/provider.ts +++ b/packages/backend/server/src/plugins/copilot/providers/provider.ts @@ -2,7 +2,6 @@ import { AsyncLocalStorage } from 'node:async_hooks'; import { Inject, Injectable, Logger } from '@nestjs/common'; import { ModuleRef } from '@nestjs/core'; -import { Tool, ToolSet } from 'ai'; import { z } from 'zod'; import { @@ -27,6 +26,8 @@ import { buildDocSearchGetter, buildDocUpdateHandler, buildDocUpdateMetaHandler, + type CopilotTool, + type CopilotToolSet, createBlobReadTool, createCodeArtifactTool, createConversationSummaryTool, @@ -42,6 +43,7 @@ import { createExaSearchTool, createSectionEditTool, } from '../tools'; +import { canonicalizePromptAttachment } from './attachments'; import { CopilotProviderFactory } from './factory'; import { resolveProviderMiddleware } from './provider-middleware'; import { buildProviderRegistry } from './provider-registry'; @@ -52,12 +54,17 @@ import { type CopilotImageOptions, CopilotProviderModel, CopilotProviderType, + type CopilotRerankRequest, CopilotStructuredOptions, EmbeddingMessage, + type ModelAttachmentCapability, ModelCapability, ModelConditions, ModelFullConditions, ModelInputType, + ModelOutputType, + type PromptAttachmentKind, + type PromptAttachmentSourceKind, type PromptMessage, PromptMessageSchema, StreamObject, @@ -163,6 +170,163 @@ export abstract class CopilotProvider { async refreshOnlineModels() {} + private unique(values: Iterable) { + return Array.from(new Set(values)); + } + + private attachmentKindToInputType( + kind: PromptAttachmentKind + ): ModelInputType { + switch (kind) { + case 'image': + return ModelInputType.Image; + case 'audio': + return ModelInputType.Audio; + default: + return ModelInputType.File; + } + } + + protected async inferModelConditionsFromMessages( + messages?: PromptMessage[], + withAttachment = true + ): Promise> { + if (!messages?.length || !withAttachment) return {}; + + const attachmentKinds: PromptAttachmentKind[] = []; + const attachmentSourceKinds: PromptAttachmentSourceKind[] = []; + const inputTypes: ModelInputType[] = []; + let hasRemoteAttachments = false; + + for (const message of messages) { + if (!Array.isArray(message.attachments)) continue; + + for (const attachment of message.attachments) { + const normalized = await canonicalizePromptAttachment( + attachment, + message + ); + attachmentKinds.push(normalized.kind); + inputTypes.push(this.attachmentKindToInputType(normalized.kind)); + attachmentSourceKinds.push(normalized.sourceKind); + hasRemoteAttachments = hasRemoteAttachments || normalized.isRemote; + } + } + + return { + ...(attachmentKinds.length + ? { attachmentKinds: this.unique(attachmentKinds) } + : {}), + ...(attachmentSourceKinds.length + ? { attachmentSourceKinds: this.unique(attachmentSourceKinds) } + : {}), + ...(inputTypes.length ? { inputTypes: this.unique(inputTypes) } : {}), + ...(hasRemoteAttachments ? { hasRemoteAttachments } : {}), + }; + } + + private mergeModelConditions( + cond: ModelFullConditions, + inferredCond: Partial + ): ModelFullConditions { + return { + ...inferredCond, + ...cond, + inputTypes: this.unique([ + ...(inferredCond.inputTypes ?? []), + ...(cond.inputTypes ?? []), + ]), + attachmentKinds: this.unique([ + ...(inferredCond.attachmentKinds ?? []), + ...(cond.attachmentKinds ?? []), + ]), + attachmentSourceKinds: this.unique([ + ...(inferredCond.attachmentSourceKinds ?? []), + ...(cond.attachmentSourceKinds ?? []), + ]), + hasRemoteAttachments: + cond.hasRemoteAttachments ?? inferredCond.hasRemoteAttachments, + }; + } + + protected getAttachCapability( + model: CopilotProviderModel, + outputType: ModelOutputType + ): ModelAttachmentCapability | undefined { + const capability = + model.capabilities.find(cap => cap.output.includes(outputType)) ?? + model.capabilities[0]; + if (!capability) { + return; + } + return this.resolveAttachmentCapability(capability, outputType); + } + + private resolveAttachmentCapability( + cap: ModelCapability, + outputType?: ModelOutputType + ): ModelAttachmentCapability | undefined { + if (outputType === ModelOutputType.Structured) { + return cap.structuredAttachments ?? cap.attachments; + } + return cap.attachments; + } + + private matchesAttachCapability( + cap: ModelCapability, + cond: ModelFullConditions + ) { + const { + attachmentKinds, + attachmentSourceKinds, + hasRemoteAttachments, + outputType, + } = cond; + + if ( + !attachmentKinds?.length && + !attachmentSourceKinds?.length && + !hasRemoteAttachments + ) { + return true; + } + + const attachmentCapability = this.resolveAttachmentCapability( + cap, + outputType + ); + if (!attachmentCapability) { + return !attachmentKinds?.some( + kind => !cap.input.includes(this.attachmentKindToInputType(kind)) + ); + } + + if ( + attachmentKinds?.some(kind => !attachmentCapability.kinds.includes(kind)) + ) { + return false; + } + + if ( + attachmentSourceKinds?.length && + attachmentCapability.sourceKinds?.length && + attachmentSourceKinds.some( + kind => !attachmentCapability.sourceKinds?.includes(kind) + ) + ) { + return false; + } + + if ( + hasRemoteAttachments && + attachmentCapability.allowRemoteUrls === false + ) { + return false; + } + + return true; + } + private findValidModel( cond: ModelFullConditions ): CopilotProviderModel | undefined { @@ -170,7 +334,8 @@ export abstract class CopilotProvider { const matcher = (cap: ModelCapability) => (!outputType || cap.output.includes(outputType)) && (!inputTypes?.length || - inputTypes.every(type => cap.input.includes(type))); + inputTypes.every(type => cap.input.includes(type))) && + this.matchesAttachCapability(cap, cond); if (modelId) { const hasOnlineModel = this.onlineModelList.includes(modelId); @@ -213,7 +378,7 @@ export abstract class CopilotProvider { protected getProviderSpecificTools( _toolName: CopilotChatTools, _model: string - ): [string, Tool?] | undefined { + ): [string, CopilotTool?] | undefined { return; } @@ -221,8 +386,8 @@ export abstract class CopilotProvider { protected async getTools( options: CopilotChatOptions, model: string - ): Promise { - const tools: ToolSet = {}; + ): Promise { + const tools: CopilotToolSet = {}; if (options?.tools?.length) { this.logger.debug(`getTools: ${JSON.stringify(options.tools)}`); const ac = this.moduleRef.get(AccessController, { strict: false }); @@ -377,19 +542,14 @@ export abstract class CopilotProvider { messages, embeddings, options = {}, + withAttachment = true, }: { cond: ModelFullConditions; messages?: PromptMessage[]; embeddings?: string[]; - options?: CopilotChatOptions; - }) { - const model = this.selectModel(cond); - const multimodal = model.capabilities.some(c => - [ModelInputType.Image, ModelInputType.Audio].some(t => - c.input.includes(t) - ) - ); - + options?: CopilotChatOptions | CopilotStructuredOptions; + withAttachment?: boolean; + }): Promise { if (messages) { const { requireContent = true, requireAttachment = false } = options; @@ -402,20 +562,56 @@ export abstract class CopilotProvider { }) .passthrough() .catchall(z.union([z.string(), z.number(), z.date(), z.null()])) - .refine( - m => - !(multimodal && requireAttachment && m.role === 'user') || - (m.attachments ? m.attachments.length > 0 : true), - { message: 'attachments required in multimodal mode' } - ) ) .optional(); this.handleZodError(MessageSchema.safeParse(messages)); + + const inferredCond = await this.inferModelConditionsFromMessages( + messages, + withAttachment + ); + const mergedCond = this.mergeModelConditions(cond, inferredCond); + const model = this.selectModel(mergedCond); + const multimodal = model.capabilities.some(c => + [ModelInputType.Image, ModelInputType.Audio, ModelInputType.File].some( + t => c.input.includes(t) + ) + ); + + if ( + multimodal && + requireAttachment && + !messages.some( + message => + message.role === 'user' && + Array.isArray(message.attachments) && + message.attachments.length > 0 + ) + ) { + throw new CopilotPromptInvalid( + 'attachments required in multimodal mode' + ); + } + + if (embeddings) { + this.handleZodError(EmbeddingMessage.safeParse(embeddings)); + } + + return mergedCond; } + + const inferredCond = await this.inferModelConditionsFromMessages( + messages, + withAttachment + ); + const mergedCond = this.mergeModelConditions(cond, inferredCond); + if (embeddings) { this.handleZodError(EmbeddingMessage.safeParse(embeddings)); } + + return mergedCond; } abstract text( @@ -476,7 +672,7 @@ export abstract class CopilotProvider { async rerank( _model: ModelConditions, - _messages: PromptMessage[][], + _request: CopilotRerankRequest, _options?: CopilotChatOptions ): Promise { throw new CopilotProviderNotSupported({ diff --git a/packages/backend/server/src/plugins/copilot/providers/rerank.ts b/packages/backend/server/src/plugins/copilot/providers/rerank.ts deleted file mode 100644 index 528e5066e0..0000000000 --- a/packages/backend/server/src/plugins/copilot/providers/rerank.ts +++ /dev/null @@ -1,23 +0,0 @@ -const GPT_4_RERANK_MODELS = /^(gpt-4(?:$|[.-]))/; -const GPT_5_RERANK_LOGPROBS_MODELS = /^(gpt-5\.2(?:$|-))/; - -export const DEFAULT_RERANK_MODEL = 'gpt-5.2'; -export const OPENAI_RERANK_TOP_LOGPROBS_LIMIT = 5; -export const OPENAI_RERANK_MAX_COMPLETION_TOKENS = 16; - -export function supportsRerankModel(model: string): boolean { - return ( - GPT_4_RERANK_MODELS.test(model) || GPT_5_RERANK_LOGPROBS_MODELS.test(model) - ); -} - -export function usesRerankReasoning(model: string): boolean { - return GPT_5_RERANK_LOGPROBS_MODELS.test(model); -} - -export function normalizeRerankModel(model?: string | null): string { - if (model && supportsRerankModel(model)) { - return model; - } - return DEFAULT_RERANK_MODEL; -} diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index 2a06e832f3..9da70b97a4 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -124,14 +124,97 @@ export const ChatMessageRole = Object.values(AiPromptRole) as [ 'user', ]; +const AttachmentUrlSchema = z.string().refine(value => { + if (value.startsWith('data:')) { + return true; + } + + try { + const url = new URL(value); + return ( + url.protocol === 'http:' || + url.protocol === 'https:' || + url.protocol === 'gs:' + ); + } catch { + return false; + } +}, 'attachments must use https?://, gs:// or data: urls'); + +export const PromptAttachmentSourceKindSchema = z.enum([ + 'url', + 'data', + 'bytes', + 'file_handle', +]); + +export const PromptAttachmentKindSchema = z.enum(['image', 'audio', 'file']); + +const AttachmentProviderHintSchema = z + .object({ + provider: z.nativeEnum(CopilotProviderType).optional(), + kind: PromptAttachmentKindSchema.optional(), + }) + .strict(); + +const PromptAttachmentSchema = z.discriminatedUnion('kind', [ + z + .object({ + kind: z.literal('url'), + url: AttachmentUrlSchema, + mimeType: z.string().optional(), + fileName: z.string().optional(), + providerHint: AttachmentProviderHintSchema.optional(), + }) + .strict(), + z + .object({ + kind: z.literal('data'), + data: z.string(), + mimeType: z.string(), + encoding: z.enum(['base64', 'utf8']).optional(), + fileName: z.string().optional(), + providerHint: AttachmentProviderHintSchema.optional(), + }) + .strict(), + z + .object({ + kind: z.literal('bytes'), + data: z.string(), + mimeType: z.string(), + encoding: z.literal('base64').optional(), + fileName: z.string().optional(), + providerHint: AttachmentProviderHintSchema.optional(), + }) + .strict(), + z + .object({ + kind: z.literal('file_handle'), + fileHandle: z.string().trim().min(1), + mimeType: z.string().optional(), + fileName: z.string().optional(), + providerHint: AttachmentProviderHintSchema.optional(), + }) + .strict(), +]); + export const ChatMessageAttachment = z.union([ - z.string().url(), + AttachmentUrlSchema, z.object({ - attachment: z.string(), + attachment: AttachmentUrlSchema, mimeType: z.string(), }), + PromptAttachmentSchema, ]); +export const PromptResponseFormatSchema = z + .object({ + type: z.literal('json_schema'), + schema: z.any(), + strict: z.boolean().optional(), + }) + .strict(); + export const StreamObjectSchema = z.discriminatedUnion('type', [ z.object({ type: z.literal('text-delta'), @@ -161,6 +244,7 @@ export const PureMessageSchema = z.object({ streamObjects: z.array(StreamObjectSchema).optional().nullable(), attachments: z.array(ChatMessageAttachment).optional().nullable(), params: z.record(z.any()).optional().nullable(), + responseFormat: PromptResponseFormatSchema.optional().nullable(), }); export const PromptMessageSchema = PureMessageSchema.extend({ @@ -169,6 +253,12 @@ export const PromptMessageSchema = PureMessageSchema.extend({ export type PromptMessage = z.infer; export type PromptParams = NonNullable; export type StreamObject = z.infer; +export type PromptAttachment = z.infer; +export type PromptAttachmentSourceKind = z.infer< + typeof PromptAttachmentSourceKindSchema +>; +export type PromptAttachmentKind = z.infer; +export type PromptResponseFormat = z.infer; // ========== options ========== @@ -194,7 +284,9 @@ export type CopilotChatTools = NonNullable< >[number]; export const CopilotStructuredOptionsSchema = - CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema).optional(); + CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema) + .extend({ schema: z.any().optional(), strict: z.boolean().optional() }) + .optional(); export type CopilotStructuredOptions = z.infer< typeof CopilotStructuredOptionsSchema @@ -220,10 +312,22 @@ export type CopilotEmbeddingOptions = z.infer< typeof CopilotEmbeddingOptionsSchema >; +export type CopilotRerankCandidate = { + id?: string; + text: string; +}; + +export type CopilotRerankRequest = { + query: string; + candidates: CopilotRerankCandidate[]; + topK?: number; +}; + export enum ModelInputType { Text = 'text', Image = 'image', Audio = 'audio', + File = 'file', } export enum ModelOutputType { @@ -231,12 +335,21 @@ export enum ModelOutputType { Object = 'object', Embedding = 'embedding', Image = 'image', + Rerank = 'rerank', Structured = 'structured', } +export interface ModelAttachmentCapability { + kinds: PromptAttachmentKind[]; + sourceKinds?: PromptAttachmentSourceKind[]; + allowRemoteUrls?: boolean; +} + export interface ModelCapability { input: ModelInputType[]; output: ModelOutputType[]; + attachments?: ModelAttachmentCapability; + structuredAttachments?: ModelAttachmentCapability; defaultForOutputType?: boolean; } @@ -248,6 +361,9 @@ export interface CopilotProviderModel { export type ModelConditions = { inputTypes?: ModelInputType[]; + attachmentKinds?: PromptAttachmentKind[]; + attachmentSourceKinds?: PromptAttachmentSourceKind[]; + hasRemoteAttachments?: boolean; modelId?: string; }; diff --git a/packages/backend/server/src/plugins/copilot/providers/utils.ts b/packages/backend/server/src/plugins/copilot/providers/utils.ts index 26224ff72e..c20959adf9 100644 --- a/packages/backend/server/src/plugins/copilot/providers/utils.ts +++ b/packages/backend/server/src/plugins/copilot/providers/utils.ts @@ -1,34 +1,39 @@ -import { GoogleVertexProviderSettings } from '@ai-sdk/google-vertex'; -import { GoogleVertexAnthropicProviderSettings } from '@ai-sdk/google-vertex/anthropic'; import { Logger } from '@nestjs/common'; -import { - AssistantModelMessage, - FilePart, - ImagePart, - TextPart, - TextStreamPart, - UserModelMessage, -} from 'ai'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; -import z, { ZodType } from 'zod'; +import z from 'zod'; -import { - bufferToArrayBuffer, - fetchBuffer, - OneMinute, - ResponseTooLargeError, - safeFetch, - SsrfBlockedError, -} from '../../../base'; -import { CustomAITools } from '../tools'; -import { PromptMessage, StreamObject } from './types'; +import { OneMinute, safeFetch } from '../../../base'; +import { PromptAttachment, StreamObject } from './types'; -type ChatMessage = UserModelMessage | AssistantModelMessage; +export type VertexProviderConfig = { + location?: string; + project?: string; + baseURL?: string; + googleAuthOptions?: GoogleAuthOptions; + fetch?: typeof fetch; +}; + +export type VertexAnthropicProviderConfig = VertexProviderConfig; + +type CopilotTextStreamPart = + | { type: 'text-delta'; text: string; id?: string } + | { type: 'reasoning-delta'; text: string; id?: string } + | { + type: 'tool-call'; + toolCallId: string; + toolName: string; + input: Record; + } + | { + type: 'tool-result'; + toolCallId: string; + toolName: string; + input: Record; + output: unknown; + } + | { type: 'error'; error: unknown }; -const ATTACHMENT_MAX_BYTES = 20 * 1024 * 1024; const ATTACH_HEAD_PARAMS = { timeoutMs: OneMinute / 12, maxRedirects: 3 }; - -const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/; const FORMAT_INFER_MAP: Record = { pdf: 'application/pdf', mp3: 'audio/mpeg', @@ -53,9 +58,39 @@ const FORMAT_INFER_MAP: Record = { flv: 'video/flv', }; -async function fetchArrayBuffer(url: string): Promise { - const { buffer } = await fetchBuffer(url, ATTACHMENT_MAX_BYTES); - return bufferToArrayBuffer(buffer); +function toBase64Data(data: string, encoding: 'base64' | 'utf8' = 'base64') { + return encoding === 'base64' + ? data + : Buffer.from(data, 'utf8').toString('base64'); +} + +export function promptAttachmentToUrl( + attachment: PromptAttachment +): string | undefined { + if (typeof attachment === 'string') return attachment; + if ('attachment' in attachment) return attachment.attachment; + switch (attachment.kind) { + case 'url': + return attachment.url; + case 'data': + return `data:${attachment.mimeType};base64,${toBase64Data( + attachment.data, + attachment.encoding + )}`; + case 'bytes': + return `data:${attachment.mimeType};base64,${attachment.data}`; + case 'file_handle': + return; + } +} + +export function promptAttachmentMimeType( + attachment: PromptAttachment, + fallbackMimeType?: string +): string | undefined { + if (typeof attachment === 'string') return fallbackMimeType; + if ('attachment' in attachment) return attachment.mimeType; + return attachment.mimeType ?? fallbackMimeType; } export async function inferMimeType(url: string) { @@ -69,346 +104,21 @@ export async function inferMimeType(url: string) { if (ext) { return ext; } - try { - const mimeType = await safeFetch( - url, - { method: 'HEAD' }, - ATTACH_HEAD_PARAMS - ).then(res => res.headers.get('content-type')); - if (mimeType) return mimeType; - } catch { - // ignore and fallback to default - } + } + try { + const mimeType = await safeFetch( + url, + { method: 'HEAD' }, + ATTACH_HEAD_PARAMS + ).then(res => res.headers.get('content-type')); + if (mimeType) return mimeType; + } catch { + // ignore and fallback to default } return 'application/octet-stream'; } -export async function chatToGPTMessage( - messages: PromptMessage[], - // TODO(@darkskygit): move this logic in interface refactoring - withAttachment: boolean = true, - // NOTE: some providers in vercel ai sdk are not able to handle url attachments yet - // so we need to use base64 encoded attachments instead - useBase64Attachment: boolean = false -): Promise<[string | undefined, ChatMessage[], ZodType?]> { - const hasSystem = messages[0]?.role === 'system'; - const system = hasSystem ? messages[0] : undefined; - const normalizedMessages = hasSystem ? messages.slice(1) : messages; - const schema = - system?.params?.schema && system.params.schema instanceof ZodType - ? system.params.schema - : undefined; - - // filter redundant fields - const msgs: ChatMessage[] = []; - for (let { role, content, attachments, params } of normalizedMessages.filter( - m => m.role !== 'system' - )) { - content = content.trim(); - role = role as 'user' | 'assistant'; - const mimetype = params?.mimetype; - if (Array.isArray(attachments)) { - const contents: (TextPart | ImagePart | FilePart)[] = []; - if (content.length) { - contents.push({ type: 'text', text: content }); - } - - if (withAttachment) { - for (let attachment of attachments) { - let mediaType: string; - if (typeof attachment === 'string') { - mediaType = - typeof mimetype === 'string' - ? mimetype - : await inferMimeType(attachment); - } else { - ({ attachment, mimeType: mediaType } = attachment); - } - if (SIMPLE_IMAGE_URL_REGEX.test(attachment)) { - const data = - attachment.startsWith('data:') || useBase64Attachment - ? await fetchArrayBuffer(attachment).catch(error => { - // Avoid leaking internal details for blocked URLs. - if ( - error instanceof SsrfBlockedError || - error instanceof ResponseTooLargeError - ) { - throw new Error('Attachment URL is not allowed'); - } - throw error; - }) - : new URL(attachment); - if (mediaType.startsWith('image/')) { - contents.push({ type: 'image', image: data, mediaType }); - } else { - contents.push({ type: 'file' as const, data, mediaType }); - } - } - } - } else if (!content.length) { - // temp fix for pplx - contents.push({ type: 'text', text: '[no content]' }); - } - - msgs.push({ role, content: contents } as ChatMessage); - } else { - msgs.push({ role, content }); - } - } - - return [system?.content, msgs, schema]; -} - -// pattern types the callback will receive -type Pattern = - | { kind: 'index'; value: number } // [123] - | { kind: 'link'; text: string; url: string } // [text](url) - | { kind: 'wrappedLink'; text: string; url: string }; // ([text](url)) - -type NeedMore = { kind: 'needMore' }; -type Failed = { kind: 'fail'; nextPos: number }; -type Finished = - | { kind: 'ok'; endPos: number; text: string; url: string } - | { kind: 'index'; endPos: number; value: number }; -type ParseStatus = Finished | NeedMore | Failed; - -type PatternCallback = (m: Pattern) => string; - -export class StreamPatternParser { - #buffer = ''; - - constructor(private readonly callback: PatternCallback) {} - - write(chunk: string): string { - this.#buffer += chunk; - const output: string[] = []; - let i = 0; - - while (i < this.#buffer.length) { - const ch = this.#buffer[i]; - - // [[[number]]] or [text](url) or ([text](url)) - if (ch === '[' || (ch === '(' && this.peek(i + 1) === '[')) { - const isWrapped = ch === '('; - const startPos = isWrapped ? i + 1 : i; - const res = this.tryParse(startPos); - if (res.kind === 'needMore') break; - const { output: out, nextPos } = this.handlePattern( - res, - isWrapped, - startPos, - i - ); - output.push(out); - i = nextPos; - continue; - } - output.push(ch); - i += 1; - } - - this.#buffer = this.#buffer.slice(i); - return output.join(''); - } - - end(): string { - const rest = this.#buffer; - this.#buffer = ''; - return rest; - } - - // =========== helpers =========== - - private peek(pos: number): string | undefined { - return pos < this.#buffer.length ? this.#buffer[pos] : undefined; - } - - private tryParse(pos: number): ParseStatus { - const nestedRes = this.tryParseNestedIndex(pos); - if (nestedRes) return nestedRes; - return this.tryParseBracketPattern(pos); - } - - private tryParseNestedIndex(pos: number): ParseStatus | null { - if (this.peek(pos + 1) !== '[') return null; - - let i = pos; - let bracketCount = 0; - - while (i < this.#buffer.length && this.#buffer[i] === '[') { - bracketCount++; - i++; - } - - if (bracketCount >= 2) { - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - - let content = ''; - while (i < this.#buffer.length && this.#buffer[i] !== ']') { - content += this.#buffer[i++]; - } - - let rightBracketCount = 0; - while (i < this.#buffer.length && this.#buffer[i] === ']') { - rightBracketCount++; - i++; - } - - if (i >= this.#buffer.length && rightBracketCount < bracketCount) { - return { kind: 'needMore' }; - } - - if ( - rightBracketCount === bracketCount && - content.length > 0 && - this.isNumeric(content) - ) { - if (this.peek(i) === '(') { - return { kind: 'fail', nextPos: i }; - } - return { kind: 'index', endPos: i, value: Number(content) }; - } - } - - return null; - } - - private tryParseBracketPattern(pos: number): ParseStatus { - let i = pos + 1; // skip '[' - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - - let content = ''; - while (i < this.#buffer.length && this.#buffer[i] !== ']') { - const nextChar = this.#buffer[i]; - if (nextChar === '[') { - return { kind: 'fail', nextPos: i }; - } - content += nextChar; - i += 1; - } - - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - const after = i + 1; - const afterChar = this.peek(after); - - if (content.length > 0 && this.isNumeric(content) && afterChar !== '(') { - // [number] pattern - return { kind: 'index', endPos: after, value: Number(content) }; - } else if (afterChar !== '(') { - // [text](url) pattern - return { kind: 'fail', nextPos: after }; - } - - i = after + 1; // skip '(' - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - - let url = ''; - while (i < this.#buffer.length && this.#buffer[i] !== ')') { - url += this.#buffer[i++]; - } - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - return { kind: 'ok', endPos: i + 1, text: content, url }; - } - - private isNumeric(str: string): boolean { - return !Number.isNaN(Number(str)) && str.trim() !== ''; - } - - private handlePattern( - pattern: Finished | Failed, - isWrapped: boolean, - start: number, - current: number - ): { output: string; nextPos: number } { - if (pattern.kind === 'fail') { - return { - output: this.#buffer.slice(current, pattern.nextPos), - nextPos: pattern.nextPos, - }; - } - - if (isWrapped) { - const afterLinkPos = pattern.endPos; - if (this.peek(afterLinkPos) !== ')') { - if (afterLinkPos >= this.#buffer.length) { - return { output: '', nextPos: current }; - } - return { output: '(', nextPos: start }; - } - - const out = - pattern.kind === 'index' - ? this.callback({ ...pattern, kind: 'index' }) - : this.callback({ ...pattern, kind: 'wrappedLink' }); - return { output: out, nextPos: afterLinkPos + 1 }; - } else { - const out = - pattern.kind === 'ok' - ? this.callback({ ...pattern, kind: 'link' }) - : this.callback({ ...pattern, kind: 'index' }); - return { output: out, nextPos: pattern.endPos }; - } - } -} - -export class CitationParser { - private readonly citations: string[] = []; - - private readonly parser = new StreamPatternParser(p => { - switch (p.kind) { - case 'index': { - if (p.value <= this.citations.length) { - return `[^${p.value}]`; - } - return `[${p.value}]`; - } - case 'wrappedLink': { - const index = this.citations.indexOf(p.url); - if (index === -1) { - this.citations.push(p.url); - return `[^${this.citations.length}]`; - } - return `[^${index + 1}]`; - } - case 'link': { - return `[${p.text}](${p.url})`; - } - } - }); - - public push(citation: string) { - this.citations.push(citation); - } - - public parse(content: string) { - return this.parser.write(content); - } - - public end() { - return this.parser.end() + '\n' + this.getFootnotes(); - } - - private getFootnotes() { - const footnotes = this.citations.map((citation, index) => { - return `[^${index + 1}]: {"type":"url","url":"${encodeURIComponent( - citation - )}"}`; - }); - return footnotes.join('\n'); - } -} - -export type CitationIndexedEvent = { +type CitationIndexedEvent = { type: 'citation'; index: number; url: string; @@ -436,7 +146,7 @@ export class CitationFootnoteFormatter { } } -type ChunkType = TextStreamPart['type']; +type ChunkType = CopilotTextStreamPart['type']; export function toError(error: unknown): Error { if (typeof error === 'string') { @@ -458,6 +168,14 @@ type DocEditFootnote = { intent: string; result: string; }; + +function asRecord(value: unknown): Record | null { + if (value && typeof value === 'object' && !Array.isArray(value)) { + return value as Record; + } + return null; +} + export class TextStreamParser { private readonly logger = new Logger(TextStreamParser.name); private readonly CALLOUT_PREFIX = '\n[!]\n'; @@ -468,7 +186,7 @@ export class TextStreamParser { private readonly docEditFootnotes: DocEditFootnote[] = []; - public parse(chunk: TextStreamPart) { + public parse(chunk: CopilotTextStreamPart) { let result = ''; switch (chunk.type) { case 'text-delta': { @@ -517,7 +235,7 @@ export class TextStreamParser { } case 'doc_edit': { this.docEditFootnotes.push({ - intent: chunk.input.instructions, + intent: String(chunk.input.instructions ?? ''), result: '', }); break; @@ -533,14 +251,12 @@ export class TextStreamParser { result = this.addPrefix(result); switch (chunk.toolName) { case 'doc_edit': { - const array = - chunk.output && typeof chunk.output === 'object' - ? chunk.output.result - : undefined; + const output = asRecord(chunk.output); + const array = output?.result; if (Array.isArray(array)) { result += array .map(item => { - return `\n${item.changedContent}\n`; + return `\n${String(asRecord(item)?.changedContent ?? '')}\n`; }) .join(''); this.docEditFootnotes[this.docEditFootnotes.length - 1].result = @@ -557,8 +273,11 @@ export class TextStreamParser { } else if (typeof output === 'string') { result += `\n${output}\n`; } else { + const message = asRecord(output)?.message; this.logger.warn( - `Unexpected result type for doc_semantic_search: ${output?.message || 'Unknown error'}` + `Unexpected result type for doc_semantic_search: ${ + typeof message === 'string' ? message : 'Unknown error' + }` ); } break; @@ -572,9 +291,11 @@ export class TextStreamParser { break; } case 'doc_compose': { - const output = chunk.output; - if (output && typeof output === 'object' && 'title' in output) { - result += `\nDocument "${output.title}" created successfully with ${output.wordCount} words.\n`; + const output = asRecord(chunk.output); + if (output && typeof output.title === 'string') { + result += `\nDocument "${output.title}" created successfully with ${String( + output.wordCount ?? 0 + )} words.\n`; } break; } @@ -654,7 +375,7 @@ export class TextStreamParser { } export class StreamObjectParser { - public parse(chunk: TextStreamPart) { + public parse(chunk: CopilotTextStreamPart) { switch (chunk.type) { case 'reasoning-delta': { return { type: 'reasoning' as const, textDelta: chunk.text }; @@ -747,9 +468,7 @@ function normalizeUrl(baseURL?: string) { } } -export function getVertexAnthropicBaseUrl( - options: GoogleVertexAnthropicProviderSettings -) { +export function getVertexAnthropicBaseUrl(options: VertexProviderConfig) { const normalizedBaseUrl = normalizeUrl(options.baseURL); if (normalizedBaseUrl) return normalizedBaseUrl; const { location, project } = options; @@ -758,7 +477,7 @@ export function getVertexAnthropicBaseUrl( } export async function getGoogleAuth( - options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings, + options: VertexProviderConfig, publisher: 'anthropic' | 'google' ) { function getBaseUrl() { @@ -777,7 +496,7 @@ export async function getGoogleAuth( } const auth = new GoogleAuth({ scopes: ['https://www.googleapis.com/auth/cloud-platform'], - ...(options.googleAuthOptions as GoogleAuthOptions), + ...options.googleAuthOptions, }); const client = await auth.getClient(); const token = await client.getAccessToken(); diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 36b0ea669e..38e5c33eee 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -31,6 +31,7 @@ import { SubscriptionPlan, SubscriptionStatus } from '../payment/types'; import { ChatMessageCache } from './message'; import { ChatPrompt } from './prompt/chat-prompt'; import { PromptService } from './prompt/service'; +import { promptAttachmentHasSource } from './providers/attachments'; import { CopilotProviderFactory } from './providers/factory'; import { buildProviderRegistry } from './providers/provider-registry'; import { @@ -38,6 +39,7 @@ import { type PromptMessage, type PromptParams, } from './providers/types'; +import { promptAttachmentToUrl } from './providers/utils'; import { type ChatHistory, type ChatMessage, @@ -272,11 +274,7 @@ export class ChatSession implements AsyncDisposable { lastMessage.attachments || [], ] .flat() - .filter(v => - typeof v === 'string' - ? !!v.trim() - : v && v.attachment.trim() && v.mimeType - ); + .filter(v => promptAttachmentHasSource(v)); //insert all previous user message content before first user message finished.splice(firstUserMessageIndex, 0, ...messages); @@ -466,8 +464,8 @@ export class ChatSessionService { messages: preload.concat(messages).map(m => ({ ...m, attachments: m.attachments - ?.map(a => (typeof a === 'string' ? a : a.attachment)) - .filter(a => !!a), + ?.map(a => promptAttachmentToUrl(a)) + .filter((a): a is string => !!a), })), }; } else { diff --git a/packages/backend/server/src/plugins/copilot/tools/blob-read.ts b/packages/backend/server/src/plugins/copilot/tools/blob-read.ts index b331b98a44..893e440954 100644 --- a/packages/backend/server/src/plugins/copilot/tools/blob-read.ts +++ b/packages/backend/server/src/plugins/copilot/tools/blob-read.ts @@ -1,9 +1,9 @@ import { Logger } from '@nestjs/common'; -import { tool } from 'ai'; import { z } from 'zod'; import { AccessController } from '../../../core/permission'; import { toolError } from './error'; +import { defineTool } from './tool'; import type { ContextSession, CopilotChatOptions } from './types'; const logger = new Logger('ContextBlobReadTool'); @@ -58,7 +58,7 @@ export const createBlobReadTool = ( chunk?: number ) => Promise ) => { - return tool({ + return defineTool({ description: 'Return the content and basic metadata of a single attachment identified by blobId; more inclined to use search tools rather than this tool.', inputSchema: z.object({ diff --git a/packages/backend/server/src/plugins/copilot/tools/code-artifact.ts b/packages/backend/server/src/plugins/copilot/tools/code-artifact.ts index d46b88f448..639567a84e 100644 --- a/packages/backend/server/src/plugins/copilot/tools/code-artifact.ts +++ b/packages/backend/server/src/plugins/copilot/tools/code-artifact.ts @@ -1,8 +1,8 @@ import { Logger } from '@nestjs/common'; -import { tool } from 'ai'; import { z } from 'zod'; import { toolError } from './error'; +import { defineTool } from './tool'; import type { CopilotProviderFactory, PromptService } from './types'; const logger = new Logger('CodeArtifactTool'); @@ -16,7 +16,7 @@ export const createCodeArtifactTool = ( promptService: PromptService, factory: CopilotProviderFactory ) => { - return tool({ + return defineTool({ description: 'Generate a single-file HTML snippet (with inline