feat(server): migrate copilot to native (#14620)

#### PR Dependency Tree


* **PR #14620** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Native LLM workflows: structured outputs, embeddings, and reranking
plus richer multimodal attachments (images, audio, files) and improved
remote-attachment inlining.

* **Refactor**
* Tooling API unified behind a local tool-definition helper;
provider/adapters reorganized to route through native dispatch paths.

* **Chores**
* Dependency updates, removed legacy Google SDK integrations, and
increased memory allocation for the front service.

* **Tests**
* Expanded end-to-end and streaming tests exercising native provider
flows, attachments, and rerank/structured scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
DarkSky
2026-03-11 13:55:35 +08:00
committed by GitHub
parent 02744cec00
commit 29a27b561b
62 changed files with 4359 additions and 2296 deletions

View File

@@ -19,3 +19,8 @@ rustflags = [
# pthread_key_create() destructors and segfault after a DSO unloading # pthread_key_create() destructors and segfault after a DSO unloading
[target.'cfg(all(target_env = "gnu", not(target_os = "windows")))'] [target.'cfg(all(target_env = "gnu", not(target_os = "windows")))']
rustflags = ["-C", "link-args=-Wl,-z,nodelete"] rustflags = ["-C", "link-args=-Wl,-z,nodelete"]
# Temporary local llm_adapter override.
# Uncomment when verifying AFFiNE against the sibling llm_adapter workspace.
# [patch.crates-io]
# llm_adapter = { path = "../llm_adapter" }

View File

@@ -31,10 +31,10 @@ podSecurityContext:
resources: resources:
limits: limits:
cpu: '1' cpu: '1'
memory: 4Gi memory: 6Gi
requests: requests:
cpu: '1' cpu: '1'
memory: 2Gi memory: 4Gi
probe: probe:
initialDelaySeconds: 20 initialDelaySeconds: 20

485
Cargo.lock generated
View File

@@ -186,6 +186,7 @@ dependencies = [
"libwebp-sys", "libwebp-sys",
"little_exif", "little_exif",
"llm_adapter", "llm_adapter",
"matroska",
"mimalloc", "mimalloc",
"mp4parse", "mp4parse",
"napi", "napi",
@@ -480,12 +481,6 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "atomic-waker"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]] [[package]]
name = "auto_enums" name = "auto_enums"
version = "0.8.7" version = "0.8.7"
@@ -504,28 +499,6 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "aws-lc-rs"
version = "1.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf"
dependencies = [
"aws-lc-sys",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.38.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e"
dependencies = [
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.22.1" version = "0.22.1"
@@ -649,6 +622,15 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "bitstream-io"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "680575de65ce8b916b82a447458b94a48776707d9c2681a9d8da351c06886a1f"
dependencies = [
"core2",
]
[[package]] [[package]]
name = "block-buffer" name = "block-buffer"
version = "0.10.4" version = "0.10.4"
@@ -981,15 +963,6 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
[[package]]
name = "cmake"
version = "0.1.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "color_quant" name = "color_quant"
version = "1.1.0" version = "1.1.0"
@@ -1534,12 +1507,6 @@ version = "0.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5" checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5"
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]] [[package]]
name = "ecb" name = "ecb"
version = "0.1.2" version = "0.1.2"
@@ -1771,12 +1738,6 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]] [[package]]
name = "futf" name = "futf"
version = "0.1.5" version = "0.1.5"
@@ -1941,11 +1902,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys",
"libc", "libc",
"r-efi", "r-efi",
"wasip2", "wasip2",
"wasm-bindgen",
] ]
[[package]] [[package]]
@@ -2138,95 +2097,12 @@ dependencies = [
"itoa", "itoa",
] ]
[[package]]
name = "http-body"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
]
[[package]]
name = "http-body-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"pin-project-lite",
]
[[package]] [[package]]
name = "httparse" name = "httparse"
version = "1.10.1" version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "hyper"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11"
dependencies = [
"atomic-waker",
"bytes",
"futures-channel",
"futures-core",
"http",
"http-body",
"httparse",
"itoa",
"pin-project-lite",
"pin-utils",
"smallvec",
"tokio",
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0"
dependencies = [
"base64",
"bytes",
"futures-channel",
"futures-util",
"http",
"http-body",
"hyper",
"ipnet",
"libc",
"percent-encoding",
"pin-project-lite",
"socket2",
"tokio",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "iana-time-zone" name = "iana-time-zone"
version = "0.1.64" version = "0.1.64"
@@ -2505,22 +2381,6 @@ dependencies = [
"leaky-cow", "leaky-cow",
] ]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "iri-string"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
dependencies = [
"memchr",
"serde",
]
[[package]] [[package]]
name = "is-terminal" name = "is-terminal"
version = "0.4.17" version = "0.4.17"
@@ -2813,15 +2673,15 @@ dependencies = [
[[package]] [[package]]
name = "llm_adapter" name = "llm_adapter"
version = "0.1.1" version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8dd9a548766bccf8b636695e8d514edee672d180e96a16ab932c971783b4e353" checksum = "e98485dda5180cc89b993a001688bed93307be6bd8fedcde445b69bbca4f554d"
dependencies = [ dependencies = [
"base64", "base64",
"reqwest",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.17", "thiserror 2.0.17",
"ureq",
] ]
[[package]] [[package]]
@@ -2889,12 +2749,6 @@ dependencies = [
"hashbrown 0.16.1", "hashbrown 0.16.1",
] ]
[[package]]
name = "lru-slab"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]] [[package]]
name = "mac" name = "mac"
version = "0.1.1" version = "0.1.1"
@@ -2954,6 +2808,16 @@ dependencies = [
"regex-automata", "regex-automata",
] ]
[[package]]
name = "matroska"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde85cd7fb5cf875c4a46fac0cbd6567d413bea2538cef6788e3a0e52a902b45"
dependencies = [
"bitstream-io",
"phf 0.11.3",
]
[[package]] [[package]]
name = "md-5" name = "md-5"
version = "0.10.6" version = "0.10.6"
@@ -3396,12 +3260,6 @@ version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "openssl-probe"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]] [[package]]
name = "ordered-float" name = "ordered-float"
version = "5.1.0" version = "5.1.0"
@@ -3884,62 +3742,6 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "quinn"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
dependencies = [
"bytes",
"cfg_aliases",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.17",
"tokio",
"tracing",
"web-time",
]
[[package]]
name = "quinn-proto"
version = "0.11.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
dependencies = [
"aws-lc-rs",
"bytes",
"getrandom 0.3.4",
"lru-slab",
"rand 0.9.2",
"ring",
"rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.17",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.60.2",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.43" version = "1.0.43"
@@ -4128,45 +3930,6 @@ version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
[[package]]
name = "reqwest"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801"
dependencies = [
"base64",
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"js-sys",
"log",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-pki-types",
"rustls-platform-verifier",
"serde",
"serde_json",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tower",
"tower-http",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.17.14" version = "0.17.14"
@@ -4307,7 +4070,7 @@ version = "0.23.36"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
dependencies = [ dependencies = [
"aws-lc-rs", "log",
"once_cell", "once_cell",
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
@@ -4316,62 +4079,21 @@ dependencies = [
"zeroize", "zeroize",
] ]
[[package]]
name = "rustls-native-certs"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63"
dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]] [[package]]
name = "rustls-pki-types" name = "rustls-pki-types"
version = "1.13.2" version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282"
dependencies = [ dependencies = [
"web-time",
"zeroize", "zeroize",
] ]
[[package]]
name = "rustls-platform-verifier"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784"
dependencies = [
"core-foundation",
"core-foundation-sys",
"jni",
"log",
"once_cell",
"rustls",
"rustls-native-certs",
"rustls-platform-verifier-android",
"rustls-webpki",
"security-framework",
"security-framework-sys",
"webpki-root-certs",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls-platform-verifier-android"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.103.8" version = "0.103.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52"
dependencies = [ dependencies = [
"aws-lc-rs",
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
"untrusted", "untrusted",
@@ -4410,15 +4132,6 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "schannel"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
dependencies = [
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "scoped-tls" name = "scoped-tls"
version = "1.0.1" version = "1.0.1"
@@ -4467,29 +4180,6 @@ dependencies = [
"syn 2.0.114", "syn 2.0.114",
] ]
[[package]]
name = "security-framework"
version = "3.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
dependencies = [
"bitflags 2.11.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "semver" name = "semver"
version = "1.0.27" version = "1.0.27"
@@ -5215,15 +4905,6 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]] [[package]]
name = "synstructure" name = "synstructure"
version = "0.13.2" version = "0.13.2"
@@ -5415,16 +5096,6 @@ dependencies = [
"syn 2.0.114", "syn 2.0.114",
] ]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
dependencies = [
"rustls",
"tokio",
]
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.18" version = "0.1.18"
@@ -5475,51 +5146,6 @@ dependencies = [
"winnow", "winnow",
] ]
[[package]]
name = "tower"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper",
"tokio",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-http"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
dependencies = [
"bitflags 2.11.0",
"bytes",
"futures-util",
"http",
"http-body",
"iri-string",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]]
name = "tower-service"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.44" version = "0.1.44"
@@ -5722,12 +5348,6 @@ dependencies = [
"tree-sitter-language", "tree-sitter-language",
] ]
[[package]]
name = "try-lock"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]] [[package]]
name = "type1-encoding-parser" name = "type1-encoding-parser"
version = "0.1.0" version = "0.1.0"
@@ -5952,6 +5572,35 @@ version = "0.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae"
[[package]]
name = "ureq"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc"
dependencies = [
"base64",
"flate2",
"log",
"percent-encoding",
"rustls",
"rustls-pki-types",
"ureq-proto",
"utf-8",
"webpki-roots 1.0.5",
]
[[package]]
name = "ureq-proto"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
dependencies = [
"base64",
"http",
"httparse",
"log",
]
[[package]] [[package]]
name = "url" name = "url"
version = "2.5.8" version = "2.5.8"
@@ -6048,15 +5697,6 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
dependencies = [
"try-lock",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.1+wasi-snapshot-preview1" version = "0.11.1+wasi-snapshot-preview1"
@@ -6146,25 +5786,6 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-root-certs"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca"
dependencies = [
"rustls-pki-types",
]
[[package]] [[package]]
name = "webpki-roots" name = "webpki-roots"
version = "0.26.11" version = "0.26.11"

View File

@@ -53,10 +53,11 @@ resolver = "3"
libc = "0.2" libc = "0.2"
libwebp-sys = "0.14.2" libwebp-sys = "0.14.2"
little_exif = "0.6.23" little_exif = "0.6.23"
llm_adapter = "0.1.1" llm_adapter = { version = "0.1.3", default-features = false }
log = "0.4" log = "0.4"
loom = { version = "0.7", features = ["checkpoint"] } loom = { version = "0.7", features = ["checkpoint"] }
lru = "0.16" lru = "0.16"
matroska = "0.30"
memory-indexer = "0.3.0" memory-indexer = "0.3.0"
mimalloc = "0.1" mimalloc = "0.1"
mp4parse = "0.17" mp4parse = "0.17"

View File

@@ -21,7 +21,10 @@ image = { workspace = true }
infer = { workspace = true } infer = { workspace = true }
libwebp-sys = { workspace = true } libwebp-sys = { workspace = true }
little_exif = { workspace = true } little_exif = { workspace = true }
llm_adapter = { workspace = true } llm_adapter = { workspace = true, default-features = false, features = [
"ureq-client",
] }
matroska = { workspace = true }
mp4parse = { workspace = true } mp4parse = { workspace = true }
napi = { workspace = true, features = ["async"] } napi = { workspace = true, features = ["async"] }
napi-derive = { workspace = true } napi-derive = { workspace = true }

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -54,6 +54,12 @@ export declare function llmDispatch(protocol: string, backendConfigJson: string,
export declare function llmDispatchStream(protocol: string, backendConfigJson: string, requestJson: string, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle export declare function llmDispatchStream(protocol: string, backendConfigJson: string, requestJson: string, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle
export declare function llmEmbeddingDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
export declare function llmRerankDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
export declare function llmStructuredDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
/** /**
* Merge updates in form like `Y.applyUpdate(doc, update)` way and return the * Merge updates in form like `Y.applyUpdate(doc, update)` way and return the
* result binary. * result binary.

View File

@@ -1,3 +1,4 @@
use matroska::Matroska;
use mp4parse::{TrackType, read_mp4}; use mp4parse::{TrackType, read_mp4};
use napi_derive::napi; use napi_derive::napi;
@@ -8,7 +9,13 @@ pub fn get_mime(input: &[u8]) -> String {
} else { } else {
file_format::FileFormat::from_bytes(input).media_type().to_string() file_format::FileFormat::from_bytes(input).media_type().to_string()
}; };
if mimetype == "video/mp4" { if let Some(container) = matroska_container_kind(input).or(match mimetype.as_str() {
"video/webm" | "application/webm" => Some(ContainerKind::WebM),
"video/x-matroska" | "application/x-matroska" => Some(ContainerKind::Matroska),
_ => None,
}) {
detect_matroska_flavor(input, container, &mimetype)
} else if mimetype == "video/mp4" {
detect_mp4_flavor(input) detect_mp4_flavor(input)
} else { } else {
mimetype mimetype
@@ -37,3 +44,68 @@ fn detect_mp4_flavor(input: &[u8]) -> String {
Err(_) => "video/mp4".to_string(), Err(_) => "video/mp4".to_string(),
} }
} }
/// EBML-based container family detected from a byte stream, used to pick
/// the matching audio MIME type when a file holds audio but no video.
#[derive(Clone, Copy)]
enum ContainerKind {
    WebM,
    Matroska,
}

impl ContainerKind {
    /// MIME type to report for an audio-only file in this container family.
    fn audio_mime(&self) -> &'static str {
        if matches!(self, ContainerKind::WebM) {
            "audio/webm"
        } else {
            "audio/x-matroska"
        }
    }
}
/// Refines the MIME type for a Matroska/WebM byte stream.
///
/// Parses `input` with the `matroska` crate; when the container holds at
/// least one audio track and no video tracks, returns the container
/// family's audio MIME type. Any parse failure, or the presence of video,
/// keeps `fallback` (the MIME type detected earlier).
fn detect_matroska_flavor(input: &[u8], container: ContainerKind, fallback: &str) -> String {
    if let Ok(file) = Matroska::open(std::io::Cursor::new(input)) {
        let audio_only =
            file.video_tracks().next().is_none() && file.audio_tracks().next().is_some();
        if audio_only {
            return container.audio_mime().to_string();
        }
    }
    // Unparseable or video-bearing content: preserve the original detection.
    fallback.to_string()
}
/// Heuristically classifies a byte stream as WebM or Matroska by scanning
/// the first KiB for the ASCII tokens "webm" or "matroska"
/// (case-insensitively). Returns `None` when neither token is found.
///
/// NOTE(review): this is a substring scan rather than a real EBML DocType
/// parse, so it could in principle match unrelated header bytes — confirm
/// against the DocType element if stricter detection is ever required.
fn matroska_container_kind(input: &[u8]) -> Option<ContainerKind> {
    let scan_len = input.len().min(1024);
    let header = &input[..scan_len];
    let contains_token =
        |token: &[u8]| header.windows(token.len()).any(|w| w.eq_ignore_ascii_case(token));
    if contains_token(b"webm") {
        Some(ContainerKind::WebM)
    } else if contains_token(b"matroska") {
        Some(ContainerKind::Matroska)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Small media fixtures checked in next to the crate, embedded at compile
// time so the tests run without filesystem access.
const AUDIO_ONLY_WEBM: &[u8] = include_bytes!("../fixtures/audio-only.webm");
const AUDIO_VIDEO_WEBM: &[u8] = include_bytes!("../fixtures/audio-video.webm");
const AUDIO_ONLY_MATROSKA: &[u8] = include_bytes!("../fixtures/audio-only.mka");
// A WebM file with audio tracks but no video must be reported as audio.
#[test]
fn detects_audio_only_webm_as_audio() {
assert_eq!(get_mime(AUDIO_ONLY_WEBM), "audio/webm");
}
// A WebM file containing a video track keeps its video MIME type.
#[test]
fn preserves_video_webm() {
assert_eq!(get_mime(AUDIO_VIDEO_WEBM), "video/webm");
}
// An audio-only Matroska (.mka) file maps to the Matroska audio type.
#[test]
fn detects_audio_only_matroska_as_audio() {
assert_eq!(get_mime(AUDIO_ONLY_MATROSKA), "audio/x-matroska");
}
}

View File

@@ -5,9 +5,10 @@ use std::sync::{
use llm_adapter::{ use llm_adapter::{
backend::{ backend::{
BackendConfig, BackendError, BackendProtocol, ReqwestHttpClient, dispatch_request, dispatch_stream_events_with, BackendConfig, BackendError, BackendProtocol, DefaultHttpClient, dispatch_embedding_request, dispatch_request,
dispatch_rerank_request, dispatch_stream_events_with, dispatch_structured_request,
}, },
core::{CoreRequest, StreamEvent}, core::{CoreRequest, EmbeddingRequest, RerankRequest, StreamEvent, StructuredRequest},
middleware::{ middleware::{
MiddlewareConfig, PipelineContext, RequestMiddleware, StreamMiddleware, citation_indexing, clamp_max_tokens, MiddlewareConfig, PipelineContext, RequestMiddleware, StreamMiddleware, citation_indexing, clamp_max_tokens,
normalize_messages, run_request_middleware_chain, run_stream_middleware_chain, stream_event_normalize, normalize_messages, run_request_middleware_chain, run_stream_middleware_chain, stream_event_normalize,
@@ -40,6 +41,20 @@ struct LlmDispatchPayload {
middleware: LlmMiddlewarePayload, middleware: LlmMiddlewarePayload,
} }
/// JSON payload accepted by `llm_structured_dispatch`: the structured
/// request fields are flattened at the top level, with an optional
/// `middleware` object controlling the request middleware chain.
#[derive(Debug, Clone, Deserialize)]
struct LlmStructuredDispatchPayload {
// Flattened: the payload's top-level keys deserialize directly into the request.
#[serde(flatten)]
request: StructuredRequest,
// Optional; falls back to `LlmMiddlewarePayload::default()` when absent.
#[serde(default)]
middleware: LlmMiddlewarePayload,
}
/// JSON payload accepted by `llm_rerank_dispatch`: the rerank request
/// fields are flattened at the top level (no middleware support).
#[derive(Debug, Clone, Deserialize)]
struct LlmRerankDispatchPayload {
// Flattened: the payload's top-level keys deserialize directly into the request.
#[serde(flatten)]
request: RerankRequest,
}
#[napi] #[napi]
pub struct LlmStreamHandle { pub struct LlmStreamHandle {
aborted: Arc<AtomicBool>, aborted: Arc<AtomicBool>,
@@ -61,7 +76,44 @@ pub fn llm_dispatch(protocol: String, backend_config_json: String, request_json:
let request = apply_request_middlewares(payload.request, &payload.middleware)?; let request = apply_request_middlewares(payload.request, &payload.middleware)?;
let response = let response =
dispatch_request(&ReqwestHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?; dispatch_request(&DefaultHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?;
serde_json::to_string(&response).map_err(map_json_error)
}
/// N-API entry point for structured-output LLM requests.
///
/// Parses `protocol`, deserializes the backend config and the flattened
/// structured payload from JSON, runs the request middleware chain over the
/// request, dispatches it synchronously through `DefaultHttpClient`, and
/// returns the response serialized as a JSON string. Invalid JSON, an
/// unsupported protocol, or a backend failure surface as N-API errors.
#[napi(catch_unwind)]
pub fn llm_structured_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
let protocol = parse_protocol(&protocol)?;
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
let payload: LlmStructuredDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
// Middleware runs on the core-request form before dispatch.
let request = apply_structured_request_middlewares(payload.request, &payload.middleware)?;
let response = dispatch_structured_request(&DefaultHttpClient::default(), &config, protocol, &request)
.map_err(map_backend_error)?;
serde_json::to_string(&response).map_err(map_json_error)
}
/// N-API entry point for embedding requests.
///
/// Parses `protocol`, deserializes the backend config and an
/// `EmbeddingRequest` from JSON, dispatches synchronously through
/// `DefaultHttpClient`, and returns the response serialized as a JSON
/// string. No middleware is applied to embedding requests.
#[napi(catch_unwind)]
pub fn llm_embedding_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
let protocol = parse_protocol(&protocol)?;
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
let request: EmbeddingRequest = serde_json::from_str(&request_json).map_err(map_json_error)?;
let response = dispatch_embedding_request(&DefaultHttpClient::default(), &config, protocol, &request)
.map_err(map_backend_error)?;
serde_json::to_string(&response).map_err(map_json_error)
}
/// N-API entry point for rerank requests.
///
/// Parses `protocol`, deserializes the backend config and the flattened
/// rerank payload from JSON, dispatches synchronously through
/// `DefaultHttpClient`, and returns the response serialized as a JSON
/// string. No middleware is applied to rerank requests.
#[napi(catch_unwind)]
pub fn llm_rerank_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
let protocol = parse_protocol(&protocol)?;
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
let payload: LlmRerankDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
let response = dispatch_rerank_request(&DefaultHttpClient::default(), &config, protocol, &payload.request)
.map_err(map_backend_error)?;
serde_json::to_string(&response).map_err(map_json_error) serde_json::to_string(&response).map_err(map_json_error)
} }
@@ -98,7 +150,7 @@ pub fn llm_dispatch_stream(
let mut aborted_by_user = false; let mut aborted_by_user = false;
let mut callback_dispatch_failed = false; let mut callback_dispatch_failed = false;
let result = dispatch_stream_events_with(&ReqwestHttpClient::default(), &config, protocol, &request, |event| { let result = dispatch_stream_events_with(&DefaultHttpClient::default(), &config, protocol, &request, |event| {
if aborted_in_worker.load(Ordering::Relaxed) { if aborted_in_worker.load(Ordering::Relaxed) {
aborted_by_user = true; aborted_by_user = true;
return Err(BackendError::Http(STREAM_ABORTED_REASON.to_string())); return Err(BackendError::Http(STREAM_ABORTED_REASON.to_string()));
@@ -155,6 +207,27 @@ fn apply_request_middlewares(request: CoreRequest, middleware: &LlmMiddlewarePay
Ok(run_request_middleware_chain(request, &middleware.config, &chain)) Ok(run_request_middleware_chain(request, &middleware.config, &chain))
} }
/// Runs the request middleware chain over a structured request.
///
/// Converts the structured request to its `CoreRequest` form, applies the
/// configured middlewares via `apply_request_middlewares`, then rebuilds a
/// `StructuredRequest` from the processed core fields. The structured-only
/// fields (`strict`, `response_mime_type`) are carried over from the input
/// request untouched, since middleware operates on the core form only.
///
/// Returns `InvalidArg` if the middleware output has no `response_schema`,
/// because a structured request cannot be dispatched without a schema.
fn apply_structured_request_middlewares(
request: StructuredRequest,
middleware: &LlmMiddlewarePayload,
) -> Result<StructuredRequest> {
let mut core = request.as_core_request();
core = apply_request_middlewares(core, middleware)?;
Ok(StructuredRequest {
model: core.model,
messages: core.messages,
schema: core
.response_schema
.ok_or_else(|| Error::new(Status::InvalidArg, "Structured request schema is required"))?,
max_tokens: core.max_tokens,
temperature: core.temperature,
reasoning: core.reasoning,
strict: request.strict,
response_mime_type: request.response_mime_type,
})
}
#[derive(Clone)] #[derive(Clone)]
struct StreamPipeline { struct StreamPipeline {
chain: Vec<StreamMiddleware>, chain: Vec<StreamMiddleware>,
@@ -268,6 +341,7 @@ fn parse_protocol(protocol: &str) -> Result<BackendProtocol> {
} }
"openai_responses" | "openai-responses" | "responses" => Ok(BackendProtocol::OpenaiResponses), "openai_responses" | "openai-responses" | "responses" => Ok(BackendProtocol::OpenaiResponses),
"anthropic" | "anthropic_messages" | "anthropic-messages" => Ok(BackendProtocol::AnthropicMessages), "anthropic" | "anthropic_messages" | "anthropic-messages" => Ok(BackendProtocol::AnthropicMessages),
"gemini" | "gemini_generate_content" | "gemini-generate-content" => Ok(BackendProtocol::GeminiGenerateContent),
other => Err(Error::new( other => Err(Error::new(
Status::InvalidArg, Status::InvalidArg,
format!("Unsupported llm backend protocol: {other}"), format!("Unsupported llm backend protocol: {other}"),
@@ -293,6 +367,7 @@ mod tests {
assert!(parse_protocol("chat-completions").is_ok()); assert!(parse_protocol("chat-completions").is_ok());
assert!(parse_protocol("responses").is_ok()); assert!(parse_protocol("responses").is_ok());
assert!(parse_protocol("anthropic").is_ok()); assert!(parse_protocol("anthropic").is_ok());
assert!(parse_protocol("gemini").is_ok());
} }
#[test] #[test]

View File

@@ -25,8 +25,6 @@
"dependencies": { "dependencies": {
"@affine/s3-compat": "workspace:*", "@affine/s3-compat": "workspace:*",
"@affine/server-native": "workspace:*", "@affine/server-native": "workspace:*",
"@ai-sdk/google": "^3.0.46",
"@ai-sdk/google-vertex": "^4.0.83",
"@apollo/server": "^4.13.0", "@apollo/server": "^4.13.0",
"@fal-ai/serverless-client": "^0.15.0", "@fal-ai/serverless-client": "^0.15.0",
"@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0", "@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0",
@@ -66,7 +64,6 @@
"@queuedash/api": "^3.16.0", "@queuedash/api": "^3.16.0",
"@react-email/components": "^0.5.7", "@react-email/components": "^0.5.7",
"@socket.io/redis-adapter": "^8.3.0", "@socket.io/redis-adapter": "^8.3.0",
"ai": "^6.0.118",
"bullmq": "^5.40.2", "bullmq": "^5.40.2",
"cookie-parser": "^1.4.7", "cookie-parser": "^1.4.7",
"cross-env": "^10.1.0", "cross-env": "^10.1.0",

View File

@@ -225,6 +225,20 @@ const checkStreamObjects = (result: string) => {
} }
}; };
const parseStreamObjects = (result: string): StreamObject[] => {
const streamObjects = JSON.parse(result);
return z.array(StreamObjectSchema).parse(streamObjects);
};
const getStreamObjectText = (result: string) =>
parseStreamObjects(result)
.filter(
(chunk): chunk is Extract<StreamObject, { type: 'text-delta' }> =>
chunk.type === 'text-delta'
)
.map(chunk => chunk.textDelta)
.join('');
const retry = async ( const retry = async (
action: string, action: string,
t: ExecutionContext<Tester>, t: ExecutionContext<Tester>,
@@ -444,6 +458,49 @@ The term **“CRDT”** was first introduced by Marc Shapiro, Nuno Preguiça, Ca
}, },
type: 'object' as const, type: 'object' as const,
}, },
{
name: 'Gemini native text',
promptName: ['Chat With AFFiNE AI'],
messages: [
{
role: 'user' as const,
content:
'In one short sentence, explain what AFFiNE AI is and mention AFFiNE by name.',
},
],
config: { model: 'gemini-2.5-flash' },
verifier: (t: ExecutionContext<Tester>, result: string) => {
assertNotWrappedInCodeBlock(t, result);
t.assert(
result.toLowerCase().includes('affine'),
'should mention AFFiNE'
);
},
prefer: CopilotProviderType.Gemini,
type: 'text' as const,
},
{
name: 'Gemini native stream objects',
promptName: ['Chat With AFFiNE AI'],
messages: [
{
role: 'user' as const,
content:
'Respond with one short sentence about AFFiNE AI and mention AFFiNE by name.',
},
],
config: { model: 'gemini-2.5-flash' },
verifier: (t: ExecutionContext<Tester>, result: string) => {
t.truthy(checkStreamObjects(result), 'should be valid stream objects');
const assembledText = getStreamObjectText(result);
t.assert(
assembledText.toLowerCase().includes('affine'),
'should mention AFFiNE'
);
},
prefer: CopilotProviderType.Gemini,
type: 'object' as const,
},
{ {
name: 'Should transcribe short audio', name: 'Should transcribe short audio',
promptName: ['Transcript audio'], promptName: ['Transcript audio'],
@@ -716,14 +773,13 @@ for (const {
const { factory, prompt: promptService } = t.context; const { factory, prompt: promptService } = t.context;
const prompt = (await promptService.get(promptName))!; const prompt = (await promptService.get(promptName))!;
t.truthy(prompt, 'should have prompt'); t.truthy(prompt, 'should have prompt');
const provider = (await factory.getProviderByModel(prompt.model, { const finalConfig = Object.assign({}, prompt.config, config);
const modelId = finalConfig.model || prompt.model;
const provider = (await factory.getProviderByModel(modelId, {
prefer, prefer,
}))!; }))!;
t.truthy(provider, 'should have provider'); t.truthy(provider, 'should have provider');
await retry(`action: ${promptName}`, t, async t => { await retry(`action: ${promptName}`, t, async t => {
const finalConfig = Object.assign({}, prompt.config, config);
const modelId = finalConfig.model || prompt.model;
switch (type) { switch (type) {
case 'text': { case 'text': {
const result = await provider.text( const result = await provider.text(
@@ -891,7 +947,7 @@ test(
'should be able to rerank message chunks', 'should be able to rerank message chunks',
runIfCopilotConfigured, runIfCopilotConfigured,
async t => { async t => {
const { factory, prompt } = t.context; const { factory } = t.context;
await retry('rerank', t, async t => { await retry('rerank', t, async t => {
const query = 'Is this content relevant to programming?'; const query = 'Is this content relevant to programming?';
@@ -908,14 +964,18 @@ test(
'The stock market is experiencing significant fluctuations.', 'The stock market is experiencing significant fluctuations.',
]; ];
const p = (await prompt.get('Rerank results'))!; const provider = (await factory.getProviderByModel('gpt-5.2'))!;
t.assert(p, 'should have prompt for rerank');
const provider = (await factory.getProviderByModel(p.model))!;
t.assert(provider, 'should have provider for rerank'); t.assert(provider, 'should have provider for rerank');
const scores = await provider.rerank( const scores = await provider.rerank(
{ modelId: p.model }, { modelId: 'gpt-5.2' },
embeddings.map(e => p.finish({ query, doc: e })) {
query,
candidates: embeddings.map((text, index) => ({
id: String(index),
text,
})),
}
); );
t.is(scores.length, 10, 'should return scores for all chunks'); t.is(scores.length, 10, 'should return scores for all chunks');

View File

@@ -33,10 +33,7 @@ import {
ModelOutputType, ModelOutputType,
OpenAIProvider, OpenAIProvider,
} from '../../plugins/copilot/providers'; } from '../../plugins/copilot/providers';
import { import { TextStreamParser } from '../../plugins/copilot/providers/utils';
CitationParser,
TextStreamParser,
} from '../../plugins/copilot/providers/utils';
import { ChatSessionService } from '../../plugins/copilot/session'; import { ChatSessionService } from '../../plugins/copilot/session';
import { CopilotStorage } from '../../plugins/copilot/storage'; import { CopilotStorage } from '../../plugins/copilot/storage';
import { CopilotTranscriptionService } from '../../plugins/copilot/transcript'; import { CopilotTranscriptionService } from '../../plugins/copilot/transcript';
@@ -660,6 +657,55 @@ test('should be able to generate with message id', async t => {
} }
}); });
test('should preserve file handle attachments when merging user content into prompt', async t => {
const { prompt, session } = t.context;
await prompt.set(promptName, 'model', [
{ role: 'user', content: '{{content}}' },
]);
const sessionId = await session.create({
docId: 'test',
workspaceId: 'test',
userId,
promptName,
pinned: false,
});
const s = (await session.get(sessionId))!;
const message = await session.createMessage({
sessionId,
content: 'Summarize this file',
attachments: [
{
kind: 'file_handle',
fileHandle: 'file_123',
mimeType: 'application/pdf',
},
],
});
await s.pushByMessageId(message);
const finalMessages = s.finish({});
t.deepEqual(finalMessages, [
{
role: 'user',
content: 'Summarize this file',
attachments: [
{
kind: 'file_handle',
fileHandle: 'file_123',
mimeType: 'application/pdf',
},
],
params: {
content: 'Summarize this file',
},
},
]);
});
test('should save message correctly', async t => { test('should save message correctly', async t => {
const { prompt, session } = t.context; const { prompt, session } = t.context;
@@ -1225,149 +1271,6 @@ test('should be able to run image executor', async t => {
Sinon.restore(); Sinon.restore();
}); });
test('CitationParser should replace citation placeholders with URLs', t => {
const content =
'This is [a] test sentence with [citations [1]] and [[2]] and [3].';
const citations = ['https://example1.com', 'https://example2.com'];
const parser = new CitationParser();
for (const citation of citations) {
parser.push(citation);
}
const result = parser.parse(content) + parser.end();
const expected = [
'This is [a] test sentence with [citations [^1]] and [^2] and [3].',
`[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`,
`[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`,
].join('\n');
t.is(result, expected);
});
test('CitationParser should replace chunks of citation placeholders with URLs', t => {
const contents = [
'[[]]',
'This is [',
'a] test sentence ',
'with citations [1',
'] and [',
'[2]] and [[',
'3]] and [[4',
']] and [[5]',
'] and [[6]]',
' and [7',
];
const citations = [
'https://example1.com',
'https://example2.com',
'https://example3.com',
'https://example4.com',
'https://example5.com',
'https://example6.com',
'https://example7.com',
];
const parser = new CitationParser();
for (const citation of citations) {
parser.push(citation);
}
let result = contents.reduce((acc, current) => {
return acc + parser.parse(current);
}, '');
result += parser.end();
const expected = [
'[[]]This is [a] test sentence with citations [^1] and [^2] and [^3] and [^4] and [^5] and [^6] and [7',
`[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`,
`[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`,
`[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`,
`[^4]: {"type":"url","url":"${encodeURIComponent(citations[3])}"}`,
`[^5]: {"type":"url","url":"${encodeURIComponent(citations[4])}"}`,
`[^6]: {"type":"url","url":"${encodeURIComponent(citations[5])}"}`,
`[^7]: {"type":"url","url":"${encodeURIComponent(citations[6])}"}`,
].join('\n');
t.is(result, expected);
});
test('CitationParser should not replace citation already with URLs', t => {
const content =
'This is [a] test sentence with citations [1](https://example1.com) and [[2]](https://example2.com) and [[3](https://example3.com)].';
const citations = [
'https://example4.com',
'https://example5.com',
'https://example6.com',
];
const parser = new CitationParser();
for (const citation of citations) {
parser.push(citation);
}
const result = parser.parse(content) + parser.end();
const expected = [
content,
`[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`,
`[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`,
`[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`,
].join('\n');
t.is(result, expected);
});
test('CitationParser should not replace chunks of citation already with URLs', t => {
const contents = [
'This is [a] test sentence with citations [1',
'](https://example1.com) and [[2]',
'](https://example2.com) and [[3](https://example3.com)].',
];
const citations = [
'https://example4.com',
'https://example5.com',
'https://example6.com',
];
const parser = new CitationParser();
for (const citation of citations) {
parser.push(citation);
}
let result = contents.reduce((acc, current) => {
return acc + parser.parse(current);
}, '');
result += parser.end();
const expected = [
contents.join(''),
`[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`,
`[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`,
`[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`,
].join('\n');
t.is(result, expected);
});
test('CitationParser should replace openai style reference chunks', t => {
const contents = [
'This is [a] test sentence with citations ',
'([example1.com](https://example1.com))',
];
const parser = new CitationParser();
let result = contents.reduce((acc, current) => {
return acc + parser.parse(current);
}, '');
result += parser.end();
const expected = [
contents[0] + '[^1]',
`[^1]: {"type":"url","url":"${encodeURIComponent('https://example1.com')}"}`,
].join('\n');
t.is(result, expected);
});
test('TextStreamParser should format different types of chunks correctly', t => { test('TextStreamParser should format different types of chunks correctly', t => {
// Define interfaces for fixtures // Define interfaces for fixtures
interface BaseFixture { interface BaseFixture {

View File

@@ -1,210 +0,0 @@
import test from 'ava';
import { z } from 'zod';
import type { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
import {
buildNativeRequest,
NativeProviderAdapter,
} from '../../plugins/copilot/providers/native';
const mockDispatch = () =>
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
yield { type: 'text_delta', text: 'Use [^1] now' };
yield { type: 'citation', index: 1, url: 'https://affine.pro' };
yield { type: 'done', finish_reason: 'stop' };
})();
test('NativeProviderAdapter streamText should append citation footnotes', async t => {
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
const chunks: string[] = [];
for await (const chunk of adapter.streamText({
model: 'gpt-5-mini',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
})) {
chunks.push(chunk);
}
const text = chunks.join('');
t.true(text.includes('Use [^1] now'));
t.true(
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
);
});
test('NativeProviderAdapter streamObject should append citation footnotes', async t => {
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
const chunks = [];
for await (const chunk of adapter.streamObject({
model: 'gpt-5-mini',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
})) {
chunks.push(chunk);
}
t.deepEqual(
chunks.map(chunk => chunk.type),
['text-delta', 'text-delta']
);
const text = chunks
.filter(chunk => chunk.type === 'text-delta')
.map(chunk => chunk.textDelta)
.join('');
t.true(text.includes('Use [^1] now'));
t.true(
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
);
});
test('NativeProviderAdapter streamObject should append fallback attachment footnotes', async t => {
const dispatch = () =>
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
yield {
type: 'tool_result',
call_id: 'call_1',
name: 'blob_read',
arguments: { blob_id: 'blob_1' },
output: {
blobId: 'blob_1',
fileName: 'a.txt',
fileType: 'text/plain',
content: 'A',
},
};
yield {
type: 'tool_result',
call_id: 'call_2',
name: 'blob_read',
arguments: { blob_id: 'blob_2' },
output: {
blobId: 'blob_2',
fileName: 'b.txt',
fileType: 'text/plain',
content: 'B',
},
};
yield { type: 'text_delta', text: 'Answer from files.' };
yield { type: 'done', finish_reason: 'stop' };
})();
const adapter = new NativeProviderAdapter(dispatch, {}, 3);
const chunks = [];
for await (const chunk of adapter.streamObject({
model: 'gpt-5-mini',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
})) {
chunks.push(chunk);
}
const text = chunks
.filter(chunk => chunk.type === 'text-delta')
.map(chunk => chunk.textDelta)
.join('');
t.true(text.includes('Answer from files.'));
t.true(text.includes('[^1][^2]'));
t.true(
text.includes(
'[^1]: {"type":"attachment","blobId":"blob_1","fileName":"a.txt","fileType":"text/plain"}'
)
);
t.true(
text.includes(
'[^2]: {"type":"attachment","blobId":"blob_2","fileName":"b.txt","fileType":"text/plain"}'
)
);
});
test('NativeProviderAdapter streamObject should map tool and text events', async t => {
let round = 0;
const dispatch = (_request: NativeLlmRequest) =>
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
round += 1;
if (round === 1) {
yield {
type: 'tool_call',
call_id: 'call_1',
name: 'doc_read',
arguments: { doc_id: 'a1' },
};
yield { type: 'done', finish_reason: 'tool_calls' };
return;
}
yield { type: 'text_delta', text: 'ok' };
yield { type: 'done', finish_reason: 'stop' };
})();
const adapter = new NativeProviderAdapter(
dispatch,
{
doc_read: {
inputSchema: z.object({ doc_id: z.string() }),
execute: async () => ({ markdown: '# a1' }),
},
},
4
);
const events = [];
for await (const event of adapter.streamObject({
model: 'gpt-5-mini',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'read' }] }],
})) {
events.push(event);
}
t.deepEqual(
events.map(event => event.type),
['tool-call', 'tool-result', 'text-delta']
);
t.deepEqual(events[0], {
type: 'tool-call',
toolCallId: 'call_1',
toolName: 'doc_read',
args: { doc_id: 'a1' },
});
});
test('buildNativeRequest should include rust middleware from profile', async t => {
const { request } = await buildNativeRequest({
model: 'gpt-5-mini',
messages: [{ role: 'user', content: 'hello' }],
tools: {},
middleware: {
rust: {
request: ['normalize_messages', 'clamp_max_tokens'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: {
text: ['callout'],
},
},
});
t.deepEqual(request.middleware, {
request: ['normalize_messages', 'clamp_max_tokens'],
stream: ['stream_event_normalize', 'citation_indexing'],
});
});
test('NativeProviderAdapter streamText should skip citation footnotes when disabled', async t => {
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3, {
nodeTextMiddleware: ['callout'],
});
const chunks: string[] = [];
for await (const chunk of adapter.streamText({
model: 'gpt-5-mini',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
})) {
chunks.push(chunk);
}
const text = chunks.join('');
t.true(text.includes('Use [^1] now'));
t.false(
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
);
});

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,13 @@
import serverNativeModule from '@affine/server-native';
import test from 'ava'; import test from 'ava';
import type { NativeLlmRerankRequest } from '../../native';
import { ProviderMiddlewareConfig } from '../../plugins/copilot/config'; import { ProviderMiddlewareConfig } from '../../plugins/copilot/config';
import { normalizeOpenAIOptionsForModel } from '../../plugins/copilot/providers/openai'; import {
normalizeOpenAIOptionsForModel,
OpenAIProvider,
} from '../../plugins/copilot/providers/openai';
import { CopilotProvider } from '../../plugins/copilot/providers/provider'; import { CopilotProvider } from '../../plugins/copilot/providers/provider';
import { normalizeRerankModel } from '../../plugins/copilot/providers/rerank';
import { import {
CopilotProviderType, CopilotProviderType,
ModelInputType, ModelInputType,
@@ -46,6 +50,33 @@ class TestOpenAIProvider extends CopilotProvider<{ apiKey: string }> {
} }
} }
class NativeRerankProtocolProvider extends OpenAIProvider {
override readonly models = [
{
id: 'gpt-5.2',
capabilities: [
{
input: [ModelInputType.Text],
output: [ModelOutputType.Text, ModelOutputType.Rerank],
defaultForOutputType: true,
},
],
},
];
override get config() {
return {
apiKey: 'test-key',
baseURL: 'https://api.openai.com/v1',
oldApiStyle: false,
};
}
override configured() {
return true;
}
}
function createProvider(profileMiddleware?: ProviderMiddlewareConfig) { function createProvider(profileMiddleware?: ProviderMiddlewareConfig) {
const provider = new TestOpenAIProvider(); const provider = new TestOpenAIProvider();
(provider as any).AFFiNEConfig = { (provider as any).AFFiNEConfig = {
@@ -126,14 +157,44 @@ test('normalizeOpenAIOptionsForModel should keep options for gpt-4.1', t => {
); );
}); });
test('normalizeOpenAIRerankModel should keep supported rerank models', t => { test('OpenAI rerank should always use chat-completions native protocol', async t => {
t.is(normalizeRerankModel('gpt-4.1'), 'gpt-4.1'); const provider = new NativeRerankProtocolProvider();
t.is(normalizeRerankModel('gpt-4.1-mini'), 'gpt-4.1-mini'); let capturedProtocol: string | undefined;
t.is(normalizeRerankModel('gpt-5.2'), 'gpt-5.2'); let capturedRequest: NativeLlmRerankRequest | undefined;
const original = (serverNativeModule as any).llmRerankDispatch;
(serverNativeModule as any).llmRerankDispatch = (
protocol: string,
_backendConfigJson: string,
requestJson: string
) => {
capturedProtocol = protocol;
capturedRequest = JSON.parse(requestJson) as NativeLlmRerankRequest;
return JSON.stringify({ model: 'gpt-5.2', scores: [0.9, 0.1] });
};
t.teardown(() => {
(serverNativeModule as any).llmRerankDispatch = original;
}); });
test('normalizeOpenAIRerankModel should fall back for unsupported models', t => { const scores = await provider.rerank(
t.is(normalizeRerankModel('gpt-5-mini'), 'gpt-5.2'); { modelId: 'gpt-5.2' },
t.is(normalizeRerankModel('gemini-2.5-flash'), 'gpt-5.2'); {
t.is(normalizeRerankModel(undefined), 'gpt-5.2'); query: 'programming',
candidates: [
{ id: 'react', text: 'React is a UI library.' },
{ id: 'weather', text: 'The weather is sunny today.' },
],
}
);
t.deepEqual(scores, [0.9, 0.1]);
t.is(capturedProtocol, 'openai_chat');
t.deepEqual(capturedRequest, {
model: 'gpt-5.2',
query: 'programming',
candidates: [
{ id: 'react', text: 'React is a UI library.' },
{ id: 'weather', text: 'The weather is sunny today.' },
],
});
}); });

View File

@@ -34,6 +34,56 @@ test('ToolCallAccumulator should merge deltas and complete tool call', t => {
id: 'call_1', id: 'call_1',
name: 'doc_read', name: 'doc_read',
args: { doc_id: 'a1' }, args: { doc_id: 'a1' },
rawArgumentsText: '{"doc_id":"a1"}',
thought: undefined,
});
});
test('ToolCallAccumulator should preserve invalid JSON instead of swallowing it', t => {
const accumulator = new ToolCallAccumulator();
accumulator.feedDelta({
type: 'tool_call_delta',
call_id: 'call_1',
name: 'doc_read',
arguments_delta: '{"doc_id":',
});
const pending = accumulator.drainPending();
t.is(pending.length, 1);
t.deepEqual(pending[0]?.id, 'call_1');
t.deepEqual(pending[0]?.name, 'doc_read');
t.deepEqual(pending[0]?.args, {});
t.is(pending[0]?.rawArgumentsText, '{"doc_id":');
t.truthy(pending[0]?.argumentParseError);
});
test('ToolCallAccumulator should prefer native canonical tool arguments metadata', t => {
const accumulator = new ToolCallAccumulator();
accumulator.feedDelta({
type: 'tool_call_delta',
call_id: 'call_1',
name: 'doc_read',
arguments_delta: '{"stale":true}',
});
const completed = accumulator.complete({
type: 'tool_call',
call_id: 'call_1',
name: 'doc_read',
arguments: {},
arguments_text: '{"doc_id":"a1"}',
arguments_error: 'invalid json',
});
t.deepEqual(completed, {
id: 'call_1',
name: 'doc_read',
args: {},
rawArgumentsText: '{"doc_id":"a1"}',
argumentParseError: 'invalid json',
thought: undefined, thought: undefined,
}); });
}); });
@@ -71,6 +121,8 @@ test('ToolSchemaExtractor should convert zod schema to json schema', t => {
test('ToolCallLoop should execute tool call and continue to next round', async t => { test('ToolCallLoop should execute tool call and continue to next round', async t => {
const dispatchRequests: NativeLlmRequest[] = []; const dispatchRequests: NativeLlmRequest[] = [];
const originalMessages = [{ role: 'user', content: 'read doc' }] as const;
const signal = new AbortController().signal;
const dispatch = (request: NativeLlmRequest) => { const dispatch = (request: NativeLlmRequest) => {
dispatchRequests.push(request); dispatchRequests.push(request);
@@ -100,13 +152,17 @@ test('ToolCallLoop should execute tool call and continue to next round', async t
}; };
let executedArgs: Record<string, unknown> | null = null; let executedArgs: Record<string, unknown> | null = null;
let executedMessages: unknown;
let executedSignal: AbortSignal | undefined;
const loop = new ToolCallLoop( const loop = new ToolCallLoop(
dispatch, dispatch,
{ {
doc_read: { doc_read: {
inputSchema: z.object({ doc_id: z.string() }), inputSchema: z.object({ doc_id: z.string() }),
execute: async args => { execute: async (args, options) => {
executedArgs = args; executedArgs = args;
executedMessages = options.messages;
executedSignal = options.signal;
return { markdown: '# doc' }; return { markdown: '# doc' };
}, },
}, },
@@ -114,6 +170,92 @@ test('ToolCallLoop should execute tool call and continue to next round', async t
4 4
); );
const events: NativeLlmStreamEvent[] = [];
for await (const event of loop.run(
{
model: 'gpt-5-mini',
stream: true,
messages: [
{ role: 'user', content: [{ type: 'text', text: 'read doc' }] },
],
},
signal,
[...originalMessages]
)) {
events.push(event);
}
t.deepEqual(executedArgs, { doc_id: 'a1' });
t.deepEqual(executedMessages, originalMessages);
t.is(executedSignal, signal);
t.true(
dispatchRequests[1]?.messages.some(message => message.role === 'tool')
);
t.deepEqual(dispatchRequests[1]?.messages[1]?.content, [
{
type: 'tool_call',
call_id: 'call_1',
name: 'doc_read',
arguments: { doc_id: 'a1' },
arguments_text: '{"doc_id":"a1"}',
arguments_error: undefined,
thought: undefined,
},
]);
t.deepEqual(dispatchRequests[1]?.messages[2]?.content, [
{
type: 'tool_result',
call_id: 'call_1',
name: 'doc_read',
arguments: { doc_id: 'a1' },
arguments_text: '{"doc_id":"a1"}',
arguments_error: undefined,
output: { markdown: '# doc' },
is_error: undefined,
},
]);
t.deepEqual(
events.map(event => event.type),
['tool_call', 'tool_result', 'text_delta', 'done']
);
});
test('ToolCallLoop should surface invalid JSON as tool error without executing', async t => {
let executed = false;
let round = 0;
const loop = new ToolCallLoop(
request => {
round += 1;
const hasToolResult = request.messages.some(
message => message.role === 'tool'
);
return (async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
if (!hasToolResult && round === 1) {
yield {
type: 'tool_call_delta',
call_id: 'call_1',
name: 'doc_read',
arguments_delta: '{"doc_id":',
};
yield { type: 'done', finish_reason: 'tool_calls' };
return;
}
yield { type: 'done', finish_reason: 'stop' };
})();
},
{
doc_read: {
inputSchema: z.object({ doc_id: z.string() }),
execute: async () => {
executed = true;
return { markdown: '# doc' };
},
},
},
2
);
const events: NativeLlmStreamEvent[] = []; const events: NativeLlmStreamEvent[] = [];
for await (const event of loop.run({ for await (const event of loop.run({
model: 'gpt-5-mini', model: 'gpt-5-mini',
@@ -123,12 +265,24 @@ test('ToolCallLoop should execute tool call and continue to next round', async t
events.push(event); events.push(event);
} }
t.deepEqual(executedArgs, { doc_id: 'a1' }); t.false(executed);
t.true( t.true(events[0]?.type === 'tool_result');
dispatchRequests[1]?.messages.some(message => message.role === 'tool') t.deepEqual(events[0], {
); type: 'tool_result',
t.deepEqual( call_id: 'call_1',
events.map(event => event.type), name: 'doc_read',
['tool_call', 'tool_result', 'text_delta', 'done'] arguments: {},
); arguments_text: '{"doc_id":',
arguments_error:
events[0]?.type === 'tool_result' ? events[0].arguments_error : undefined,
output: {
message: 'Invalid tool arguments JSON',
rawArguments: '{"doc_id":',
error:
events[0]?.type === 'tool_result'
? events[0].arguments_error
: undefined,
},
is_error: true,
});
}); });

View File

@@ -1,12 +1,6 @@
import test from 'ava'; import test from 'ava';
import { z } from 'zod';
import { import { CitationFootnoteFormatter } from '../../plugins/copilot/providers/utils';
chatToGPTMessage,
CitationFootnoteFormatter,
CitationParser,
StreamPatternParser,
} from '../../plugins/copilot/providers/utils';
test('CitationFootnoteFormatter should format sorted footnotes from citation events', t => { test('CitationFootnoteFormatter should format sorted footnotes from citation events', t => {
const formatter = new CitationFootnoteFormatter(); const formatter = new CitationFootnoteFormatter();
@@ -50,67 +44,3 @@ test('CitationFootnoteFormatter should overwrite duplicated index with latest ur
'[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fnew"}' '[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fnew"}'
); );
}); });
test('StreamPatternParser should keep state across chunks', t => {
const parser = new StreamPatternParser(pattern => {
if (pattern.kind === 'wrappedLink') {
return `[^${pattern.url}]`;
}
if (pattern.kind === 'index') {
return `[#${pattern.value}]`;
}
return `[${pattern.text}](${pattern.url})`;
});
const first = parser.write('ref ([AFFiNE](https://affine.pro');
const second = parser.write(')) and [2]');
t.is(first, 'ref ');
t.is(second, '[^https://affine.pro] and [#2]');
t.is(parser.end(), '');
});
test('CitationParser should convert wrapped links to numbered footnotes', t => {
const parser = new CitationParser();
const output = parser.parse('Use ([AFFiNE](https://affine.pro)) now');
t.is(output, 'Use [^1] now');
t.regex(
parser.end(),
/\[\^1\]: \{"type":"url","url":"https%3A%2F%2Faffine.pro"\}/
);
});
test('chatToGPTMessage should not mutate input and should keep system schema', async t => {
const schema = z.object({
query: z.string(),
});
const messages = [
{
role: 'system' as const,
content: 'You are helper',
params: { schema },
},
{
role: 'user' as const,
content: '',
attachments: ['https://example.com/a.png'],
},
];
const firstRef = messages[0];
const secondRef = messages[1];
const [system, normalized, parsedSchema] = await chatToGPTMessage(
messages,
false
);
t.is(system, 'You are helper');
t.is(parsedSchema, schema);
t.is(messages.length, 2);
t.is(messages[0], firstRef);
t.is(messages[1], secondRef);
t.deepEqual(normalized[0], {
role: 'user',
content: [{ type: 'text', text: '[no content]' }],
});
});

View File

@@ -33,7 +33,7 @@ export class MockCopilotProvider extends OpenAIProvider {
id: 'test-image', id: 'test-image',
capabilities: [ capabilities: [
{ {
input: [ModelInputType.Text], input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Image], output: [ModelOutputType.Image],
defaultForOutputType: true, defaultForOutputType: true,
}, },

View File

@@ -10,6 +10,7 @@ import {
CopilotSessionNotFound, CopilotSessionNotFound,
} from '../base'; } from '../base';
import { getTokenEncoder } from '../native'; import { getTokenEncoder } from '../native';
import type { PromptAttachment } from '../plugins/copilot/providers/types';
import { BaseModel } from './base'; import { BaseModel } from './base';
export enum SessionType { export enum SessionType {
@@ -24,7 +25,7 @@ type ChatPrompt = {
model: string; model: string;
}; };
type ChatAttachment = { attachment: string; mimeType: string } | string; type ChatAttachment = PromptAttachment;
type ChatStreamObject = { type ChatStreamObject = {
type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result'; type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result';
@@ -173,22 +174,105 @@ export class CopilotSessionModel extends BaseModel {
} }
return attachments return attachments
.map(attachment => .map(attachment => {
typeof attachment === 'string' if (typeof attachment === 'string') {
? (this.sanitizeString(attachment) ?? '') return this.sanitizeString(attachment) ?? '';
: { }
if ('attachment' in attachment) {
return {
attachment: attachment:
this.sanitizeString(attachment.attachment) ?? this.sanitizeString(attachment.attachment) ??
attachment.attachment, attachment.attachment,
mimeType: mimeType:
this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, this.sanitizeString(attachment.mimeType) ?? attachment.mimeType,
};
} }
)
switch (attachment.kind) {
case 'url':
return {
...attachment,
url: this.sanitizeString(attachment.url) ?? attachment.url,
mimeType:
this.sanitizeString(attachment.mimeType) ?? attachment.mimeType,
fileName:
this.sanitizeString(attachment.fileName) ?? attachment.fileName,
providerHint: attachment.providerHint
? {
provider:
this.sanitizeString(attachment.providerHint.provider) ??
attachment.providerHint.provider,
kind:
this.sanitizeString(attachment.providerHint.kind) ??
attachment.providerHint.kind,
}
: undefined,
};
case 'data':
case 'bytes':
return {
...attachment,
data: this.sanitizeString(attachment.data) ?? attachment.data,
mimeType:
this.sanitizeString(attachment.mimeType) ?? attachment.mimeType,
fileName:
this.sanitizeString(attachment.fileName) ?? attachment.fileName,
providerHint: attachment.providerHint
? {
provider:
this.sanitizeString(attachment.providerHint.provider) ??
attachment.providerHint.provider,
kind:
this.sanitizeString(attachment.providerHint.kind) ??
attachment.providerHint.kind,
}
: undefined,
};
case 'file_handle':
return {
...attachment,
fileHandle:
this.sanitizeString(attachment.fileHandle) ??
attachment.fileHandle,
mimeType:
this.sanitizeString(attachment.mimeType) ?? attachment.mimeType,
fileName:
this.sanitizeString(attachment.fileName) ?? attachment.fileName,
providerHint: attachment.providerHint
? {
provider:
this.sanitizeString(attachment.providerHint.provider) ??
attachment.providerHint.provider,
kind:
this.sanitizeString(attachment.providerHint.kind) ??
attachment.providerHint.kind,
}
: undefined,
};
}
return attachment;
})
.filter(attachment => { .filter(attachment => {
if (typeof attachment === 'string') { if (typeof attachment === 'string') {
return !!attachment; return !!attachment;
} }
if ('attachment' in attachment) {
return !!attachment.attachment && !!attachment.mimeType; return !!attachment.attachment && !!attachment.mimeType;
}
switch (attachment.kind) {
case 'url':
return !!attachment.url;
case 'data':
case 'bytes':
return !!attachment.data && !!attachment.mimeType;
case 'file_handle':
return !!attachment.fileHandle;
}
return false;
}); });
} }

View File

@@ -65,6 +65,21 @@ type NativeLlmModule = {
backendConfigJson: string, backendConfigJson: string,
requestJson: string requestJson: string
) => string | Promise<string>; ) => string | Promise<string>;
llmStructuredDispatch?: (
protocol: string,
backendConfigJson: string,
requestJson: string
) => string | Promise<string>;
llmEmbeddingDispatch?: (
protocol: string,
backendConfigJson: string,
requestJson: string
) => string | Promise<string>;
llmRerankDispatch?: (
protocol: string,
backendConfigJson: string,
requestJson: string
) => string | Promise<string>;
llmDispatchStream?: ( llmDispatchStream?: (
protocol: string, protocol: string,
backendConfigJson: string, backendConfigJson: string,
@@ -79,12 +94,20 @@ const nativeLlmModule = serverNativeModule as typeof serverNativeModule &
export type NativeLlmProtocol = export type NativeLlmProtocol =
| 'openai_chat' | 'openai_chat'
| 'openai_responses' | 'openai_responses'
| 'anthropic'; | 'anthropic'
| 'gemini';
export type NativeLlmBackendConfig = { export type NativeLlmBackendConfig = {
base_url: string; base_url: string;
auth_token: string; auth_token: string;
request_layer?: 'anthropic' | 'chat_completions' | 'responses' | 'vertex'; request_layer?:
| 'anthropic'
| 'chat_completions'
| 'responses'
| 'vertex'
| 'vertex_anthropic'
| 'gemini_api'
| 'gemini_vertex';
headers?: Record<string, string>; headers?: Record<string, string>;
no_streaming?: boolean; no_streaming?: boolean;
timeout_ms?: number; timeout_ms?: number;
@@ -100,6 +123,8 @@ export type NativeLlmCoreContent =
call_id: string; call_id: string;
name: string; name: string;
arguments: Record<string, unknown>; arguments: Record<string, unknown>;
arguments_text?: string;
arguments_error?: string;
thought?: string; thought?: string;
} }
| { | {
@@ -109,8 +134,12 @@ export type NativeLlmCoreContent =
is_error?: boolean; is_error?: boolean;
name?: string; name?: string;
arguments?: Record<string, unknown>; arguments?: Record<string, unknown>;
arguments_text?: string;
arguments_error?: string;
} }
| { type: 'image'; source: Record<string, unknown> | string }; | { type: 'image'; source: Record<string, unknown> | string }
| { type: 'audio'; source: Record<string, unknown> | string }
| { type: 'file'; source: Record<string, unknown> | string };
export type NativeLlmCoreMessage = { export type NativeLlmCoreMessage = {
role: NativeLlmCoreRole; role: NativeLlmCoreRole;
@@ -133,22 +162,54 @@ export type NativeLlmRequest = {
tool_choice?: 'auto' | 'none' | 'required' | { name: string }; tool_choice?: 'auto' | 'none' | 'required' | { name: string };
include?: string[]; include?: string[];
reasoning?: Record<string, unknown>; reasoning?: Record<string, unknown>;
response_schema?: Record<string, unknown>;
middleware?: { middleware?: {
request?: Array< request?: Array<
'normalize_messages' | 'clamp_max_tokens' | 'tool_schema_rewrite' 'normalize_messages' | 'clamp_max_tokens' | 'tool_schema_rewrite'
>; >;
stream?: Array<'stream_event_normalize' | 'citation_indexing'>; stream?: Array<'stream_event_normalize' | 'citation_indexing'>;
config?: { config?: {
no_additional_properties?: boolean; additional_properties_policy?: 'preserve' | 'forbid';
drop_property_format?: boolean; property_format_policy?: 'preserve' | 'drop';
drop_property_min_length?: boolean; property_min_length_policy?: 'preserve' | 'drop';
drop_array_min_items?: boolean; array_min_items_policy?: 'preserve' | 'drop';
drop_array_max_items?: boolean; array_max_items_policy?: 'preserve' | 'drop';
max_tokens_cap?: number; max_tokens_cap?: number;
}; };
}; };
}; };
export type NativeLlmStructuredRequest = {
model: string;
messages: NativeLlmCoreMessage[];
schema: Record<string, unknown>;
max_tokens?: number;
temperature?: number;
reasoning?: Record<string, unknown>;
strict?: boolean;
response_mime_type?: string;
middleware?: NativeLlmRequest['middleware'];
};
export type NativeLlmEmbeddingRequest = {
model: string;
inputs: string[];
dimensions?: number;
task_type?: string;
};
export type NativeLlmRerankCandidate = {
id?: string;
text: string;
};
export type NativeLlmRerankRequest = {
model: string;
query: string;
candidates: NativeLlmRerankCandidate[];
top_n?: number;
};
export type NativeLlmDispatchResponse = { export type NativeLlmDispatchResponse = {
id: string; id: string;
model: string; model: string;
@@ -159,10 +220,39 @@ export type NativeLlmDispatchResponse = {
total_tokens: number; total_tokens: number;
cached_tokens?: number; cached_tokens?: number;
}; };
finish_reason: string; finish_reason:
| 'stop'
| 'length'
| 'tool_calls'
| 'content_filter'
| 'error'
| string;
reasoning_details?: unknown; reasoning_details?: unknown;
}; };
export type NativeLlmStructuredResponse = {
id: string;
model: string;
output_text: string;
usage: NativeLlmDispatchResponse['usage'];
finish_reason: NativeLlmDispatchResponse['finish_reason'];
reasoning_details?: unknown;
};
export type NativeLlmEmbeddingResponse = {
model: string;
embeddings: number[][];
usage?: {
prompt_tokens: number;
total_tokens: number;
};
};
export type NativeLlmRerankResponse = {
model: string;
scores: number[];
};
export type NativeLlmStreamEvent = export type NativeLlmStreamEvent =
| { type: 'message_start'; id?: string; model?: string } | { type: 'message_start'; id?: string; model?: string }
| { type: 'text_delta'; text: string } | { type: 'text_delta'; text: string }
@@ -178,6 +268,8 @@ export type NativeLlmStreamEvent =
call_id: string; call_id: string;
name: string; name: string;
arguments: Record<string, unknown>; arguments: Record<string, unknown>;
arguments_text?: string;
arguments_error?: string;
thought?: string; thought?: string;
} }
| { | {
@@ -187,6 +279,8 @@ export type NativeLlmStreamEvent =
is_error?: boolean; is_error?: boolean;
name?: string; name?: string;
arguments?: Record<string, unknown>; arguments?: Record<string, unknown>;
arguments_text?: string;
arguments_error?: string;
} }
| { type: 'citation'; index: number; url: string } | { type: 'citation'; index: number; url: string }
| { | {
@@ -200,7 +294,7 @@ export type NativeLlmStreamEvent =
} }
| { | {
type: 'done'; type: 'done';
finish_reason?: string; finish_reason?: NativeLlmDispatchResponse['finish_reason'];
usage?: { usage?: {
prompt_tokens: number; prompt_tokens: number;
completion_tokens: number; completion_tokens: number;
@@ -228,6 +322,57 @@ export async function llmDispatch(
return JSON.parse(responseText) as NativeLlmDispatchResponse; return JSON.parse(responseText) as NativeLlmDispatchResponse;
} }
export async function llmStructuredDispatch(
protocol: NativeLlmProtocol,
backendConfig: NativeLlmBackendConfig,
request: NativeLlmStructuredRequest
): Promise<NativeLlmStructuredResponse> {
if (!nativeLlmModule.llmStructuredDispatch) {
throw new Error('native llm structured dispatch is not available');
}
const response = nativeLlmModule.llmStructuredDispatch(
protocol,
JSON.stringify(backendConfig),
JSON.stringify(request)
);
const responseText = await Promise.resolve(response);
return JSON.parse(responseText) as NativeLlmStructuredResponse;
}
export async function llmEmbeddingDispatch(
protocol: NativeLlmProtocol,
backendConfig: NativeLlmBackendConfig,
request: NativeLlmEmbeddingRequest
): Promise<NativeLlmEmbeddingResponse> {
if (!nativeLlmModule.llmEmbeddingDispatch) {
throw new Error('native llm embedding dispatch is not available');
}
const response = nativeLlmModule.llmEmbeddingDispatch(
protocol,
JSON.stringify(backendConfig),
JSON.stringify(request)
);
const responseText = await Promise.resolve(response);
return JSON.parse(responseText) as NativeLlmEmbeddingResponse;
}
export async function llmRerankDispatch(
protocol: NativeLlmProtocol,
backendConfig: NativeLlmBackendConfig,
request: NativeLlmRerankRequest
): Promise<NativeLlmRerankResponse> {
if (!nativeLlmModule.llmRerankDispatch) {
throw new Error('native llm rerank dispatch is not available');
}
const response = nativeLlmModule.llmRerankDispatch(
protocol,
JSON.stringify(backendConfig),
JSON.stringify(request)
);
const responseText = await Promise.resolve(response);
return JSON.parse(responseText) as NativeLlmRerankResponse;
}
export class NativeStreamAdapter<T> implements AsyncIterableIterator<T> { export class NativeStreamAdapter<T> implements AsyncIterableIterator<T> {
readonly #queue: T[] = []; readonly #queue: T[] = [];
readonly #waiters: ((result: IteratorResult<T>) => void)[] = []; readonly #waiters: ((result: IteratorResult<T>) => void)[] = [];

View File

@@ -81,7 +81,7 @@ export type CopilotProviderProfile = CopilotProviderProfileCommon &
}[CopilotProviderType]; }[CopilotProviderType];
export type CopilotProviderDefaults = Partial< export type CopilotProviderDefaults = Partial<
Record<ModelOutputType, string> Record<Exclude<ModelOutputType, ModelOutputType.Rerank>, string>
> & { > & {
fallback?: string; fallback?: string;
}; };
@@ -184,6 +184,7 @@ const CopilotProviderDefaultsShape = z.object({
[ModelOutputType.Object]: z.string().optional(), [ModelOutputType.Object]: z.string().optional(),
[ModelOutputType.Embedding]: z.string().optional(), [ModelOutputType.Embedding]: z.string().optional(),
[ModelOutputType.Image]: z.string().optional(), [ModelOutputType.Image]: z.string().optional(),
[ModelOutputType.Rerank]: z.string().optional(),
[ModelOutputType.Structured]: z.string().optional(), [ModelOutputType.Structured]: z.string().optional(),
fallback: z.string().optional(), fallback: z.string().optional(),
}); });

View File

@@ -1,25 +1,17 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import type { ModuleRef } from '@nestjs/core'; import type { ModuleRef } from '@nestjs/core';
import { import { Config, CopilotProviderNotSupported } from '../../../base';
Config,
CopilotPromptNotFound,
CopilotProviderNotSupported,
} from '../../../base';
import { CopilotFailedToGenerateEmbedding } from '../../../base/error/errors.gen'; import { CopilotFailedToGenerateEmbedding } from '../../../base/error/errors.gen';
import { import {
ChunkSimilarity, ChunkSimilarity,
Embedding, Embedding,
EMBEDDING_DIMENSIONS, EMBEDDING_DIMENSIONS,
} from '../../../models'; } from '../../../models';
import { PromptService } from '../prompt/service';
import { CopilotProviderFactory } from '../providers/factory'; import { CopilotProviderFactory } from '../providers/factory';
import type { CopilotProvider } from '../providers/provider'; import type { CopilotProvider } from '../providers/provider';
import { import {
DEFAULT_RERANK_MODEL, type CopilotRerankRequest,
normalizeRerankModel,
} from '../providers/rerank';
import {
type ModelFullConditions, type ModelFullConditions,
ModelInputType, ModelInputType,
ModelOutputType, ModelOutputType,
@@ -27,24 +19,20 @@ import {
import { EmbeddingClient, type ReRankResult } from './types'; import { EmbeddingClient, type ReRankResult } from './types';
const EMBEDDING_MODEL = 'gemini-embedding-001'; const EMBEDDING_MODEL = 'gemini-embedding-001';
const RERANK_PROMPT = 'Rerank results'; const RERANK_MODEL = 'gpt-5.2';
class ProductionEmbeddingClient extends EmbeddingClient { class ProductionEmbeddingClient extends EmbeddingClient {
private readonly logger = new Logger(ProductionEmbeddingClient.name); private readonly logger = new Logger(ProductionEmbeddingClient.name);
constructor( constructor(
private readonly config: Config, private readonly config: Config,
private readonly providerFactory: CopilotProviderFactory, private readonly providerFactory: CopilotProviderFactory
private readonly prompt: PromptService
) { ) {
super(); super();
} }
override async configured(): Promise<boolean> { override async configured(): Promise<boolean> {
const embedding = await this.providerFactory.getProvider({ const embedding = await this.providerFactory.getProvider({
modelId: this.config.copilot?.scenarios?.override_enabled modelId: this.getEmbeddingModelId(),
? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL
: EMBEDDING_MODEL,
outputType: ModelOutputType.Embedding, outputType: ModelOutputType.Embedding,
}); });
const result = Boolean(embedding); const result = Boolean(embedding);
@@ -69,9 +57,15 @@ class ProductionEmbeddingClient extends EmbeddingClient {
return provider; return provider;
} }
private getEmbeddingModelId() {
return this.config.copilot?.scenarios?.override_enabled
? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL
: EMBEDDING_MODEL;
}
async getEmbeddings(input: string[]): Promise<Embedding[]> { async getEmbeddings(input: string[]): Promise<Embedding[]> {
const provider = await this.getProvider({ const provider = await this.getProvider({
modelId: EMBEDDING_MODEL, modelId: this.getEmbeddingModelId(),
outputType: ModelOutputType.Embedding, outputType: ModelOutputType.Embedding,
}); });
this.logger.verbose( this.logger.verbose(
@@ -114,21 +108,22 @@ class ProductionEmbeddingClient extends EmbeddingClient {
): Promise<ReRankResult> { ): Promise<ReRankResult> {
if (!embeddings.length) return []; if (!embeddings.length) return [];
const prompt = await this.prompt.get(RERANK_PROMPT); const provider = await this.getProvider({
if (!prompt) { modelId: RERANK_MODEL,
throw new CopilotPromptNotFound({ name: RERANK_PROMPT }); outputType: ModelOutputType.Rerank,
} });
const rerankModel = normalizeRerankModel(prompt.model);
if (prompt.model !== rerankModel) { const rerankRequest: CopilotRerankRequest = {
this.logger.warn( query,
`Unsupported rerank model "${prompt.model}" configured, falling back to "${DEFAULT_RERANK_MODEL}".` candidates: embeddings.map((embedding, index) => ({
); id: String(index),
} text: embedding.content,
const provider = await this.getProvider({ modelId: rerankModel }); })),
};
const ranks = await provider.rerank( const ranks = await provider.rerank(
{ modelId: rerankModel }, { modelId: RERANK_MODEL },
embeddings.map(e => prompt.finish({ query, doc: e.content })), rerankRequest,
{ signal } { signal }
); );
@@ -227,9 +222,7 @@ export async function getEmbeddingClient(
const providerFactory = moduleRef.get(CopilotProviderFactory, { const providerFactory = moduleRef.get(CopilotProviderFactory, {
strict: false, strict: false,
}); });
const prompt = moduleRef.get(PromptService, { strict: false }); const client = new ProductionEmbeddingClient(config, providerFactory);
const client = new ProductionEmbeddingClient(config, providerFactory, prompt);
if (await client.configured()) { if (await client.configured()) {
EMBEDDING_CLIENT = client; EMBEDDING_CLIENT = client;
} }

View File

@@ -418,21 +418,6 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr
maxRetries: 1, maxRetries: 1,
}, },
}, },
{
name: 'Rerank results',
action: 'Rerank results',
model: 'gpt-5.2',
messages: [
{
role: 'system',
content: `Judge whether the Document meets the requirements based on the Query and the Instruct provided. The answer must be "yes" or "no".`,
},
{
role: 'user',
content: `<Instruct>: Given a document search result, determine whether the result is relevant to the query.\n<Query>: {{query}}\n<Document>: {{doc}}`,
},
],
},
{ {
name: 'Generate a caption', name: 'Generate a caption',
action: 'Generate a caption', action: 'Generate a caption',

View File

@@ -1,5 +1,3 @@
import type { ToolSet } from 'ai';
import { import {
CopilotProviderSideError, CopilotProviderSideError,
metrics, metrics,
@@ -11,6 +9,7 @@ import {
type NativeLlmRequest, type NativeLlmRequest,
} from '../../../../native'; } from '../../../../native';
import type { NodeTextMiddleware } from '../../config'; import type { NodeTextMiddleware } from '../../config';
import type { CopilotToolSet } from '../../tools';
import { buildNativeRequest, NativeProviderAdapter } from '../native'; import { buildNativeRequest, NativeProviderAdapter } from '../native';
import { CopilotProvider } from '../provider'; import { CopilotProvider } from '../provider';
import type { import type {
@@ -20,7 +19,11 @@ import type {
StreamObject, StreamObject,
} from '../types'; } from '../types';
import { CopilotProviderType, ModelOutputType } from '../types'; import { CopilotProviderType, ModelOutputType } from '../types';
import { getGoogleAuth, getVertexAnthropicBaseUrl } from '../utils'; import {
getGoogleAuth,
getVertexAnthropicBaseUrl,
type VertexAnthropicProviderConfig,
} from '../utils';
export abstract class AnthropicProvider<T> extends CopilotProvider<T> { export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
private handleError(e: any) { private handleError(e: any) {
@@ -36,22 +39,16 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
private async createNativeConfig(): Promise<NativeLlmBackendConfig> { private async createNativeConfig(): Promise<NativeLlmBackendConfig> {
if (this.type === CopilotProviderType.AnthropicVertex) { if (this.type === CopilotProviderType.AnthropicVertex) {
const auth = await getGoogleAuth(this.config as any, 'anthropic'); const config = this.config as VertexAnthropicProviderConfig;
const headers = auth.headers(); const auth = await getGoogleAuth(config, 'anthropic');
const authorization = const { Authorization: authHeader } = auth.headers();
headers.Authorization || const token = authHeader.replace(/^Bearer\s+/i, '');
(headers as Record<string, string | undefined>).authorization; const baseUrl = getVertexAnthropicBaseUrl(config) || auth.baseUrl;
const token =
typeof authorization === 'string'
? authorization.replace(/^Bearer\s+/i, '')
: '';
const baseUrl =
getVertexAnthropicBaseUrl(this.config as any) || auth.baseUrl;
return { return {
base_url: baseUrl || '', base_url: baseUrl || '',
auth_token: token, auth_token: token,
request_layer: 'vertex', request_layer: 'vertex_anthropic',
headers, headers: { Authorization: authHeader },
}; };
} }
@@ -65,7 +62,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
private createAdapter( private createAdapter(
backendConfig: NativeLlmBackendConfig, backendConfig: NativeLlmBackendConfig,
tools: ToolSet, tools: CopilotToolSet,
nodeTextMiddleware?: NodeTextMiddleware[] nodeTextMiddleware?: NodeTextMiddleware[]
) { ) {
return new NativeProviderAdapter( return new NativeProviderAdapter(
@@ -93,8 +90,12 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
@@ -102,11 +103,13 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
const tools = await this.getTools(options, model.id); const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware(); const middleware = this.getActiveProviderMiddleware();
const reasoning = this.getReasoning(options, model.id); const reasoning = this.getReasoning(options, model.id);
const cap = this.getAttachCapability(model, ModelOutputType.Text);
const { request } = await buildNativeRequest({ const { request } = await buildNativeRequest({
model: model.id, model: model.id,
messages, messages,
options, options,
tools, tools,
attachmentCapability: cap,
reasoning, reasoning,
middleware, middleware,
}); });
@@ -115,7 +118,7 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
tools, tools,
middleware.node?.text middleware.node?.text
); );
return await adapter.text(request, options.signal); return await adapter.text(request, options.signal, messages);
} catch (e: any) { } catch (e: any) {
metrics.ai metrics.ai
.counter('chat_text_errors') .counter('chat_text_errors')
@@ -130,8 +133,12 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): AsyncIterable<string> { ): AsyncIterable<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai metrics.ai
@@ -140,11 +147,13 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
const backendConfig = await this.createNativeConfig(); const backendConfig = await this.createNativeConfig();
const tools = await this.getTools(options, model.id); const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware(); const middleware = this.getActiveProviderMiddleware();
const cap = this.getAttachCapability(model, ModelOutputType.Text);
const { request } = await buildNativeRequest({ const { request } = await buildNativeRequest({
model: model.id, model: model.id,
messages, messages,
options, options,
tools, tools,
attachmentCapability: cap,
reasoning: this.getReasoning(options, model.id), reasoning: this.getReasoning(options, model.id),
middleware, middleware,
}); });
@@ -153,7 +162,11 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
tools, tools,
middleware.node?.text middleware.node?.text
); );
for await (const chunk of adapter.streamText(request, options.signal)) { for await (const chunk of adapter.streamText(
request,
options.signal,
messages
)) {
yield chunk; yield chunk;
} }
} catch (e: any) { } catch (e: any) {
@@ -170,8 +183,12 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> { ): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object }; const fullCond = { ...cond, outputType: ModelOutputType.Object };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai metrics.ai
@@ -180,11 +197,13 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
const backendConfig = await this.createNativeConfig(); const backendConfig = await this.createNativeConfig();
const tools = await this.getTools(options, model.id); const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware(); const middleware = this.getActiveProviderMiddleware();
const cap = this.getAttachCapability(model, ModelOutputType.Object);
const { request } = await buildNativeRequest({ const { request } = await buildNativeRequest({
model: model.id, model: model.id,
messages, messages,
options, options,
tools, tools,
attachmentCapability: cap,
reasoning: this.getReasoning(options, model.id), reasoning: this.getReasoning(options, model.id),
middleware, middleware,
}); });
@@ -193,7 +212,11 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
tools, tools,
middleware.node?.text middleware.node?.text
); );
for await (const chunk of adapter.streamObject(request, options.signal)) { for await (const chunk of adapter.streamObject(
request,
options.signal,
messages
)) {
yield chunk; yield chunk;
} }
} catch (e: any) { } catch (e: any) {

View File

@@ -1,5 +1,6 @@
import z from 'zod'; import z from 'zod';
import { IMAGE_ATTACHMENT_CAPABILITY } from '../attachments';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import { AnthropicProvider } from './anthropic'; import { AnthropicProvider } from './anthropic';
@@ -23,6 +24,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
{ {
input: [ModelInputType.Text, ModelInputType.Image], input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Object], output: [ModelOutputType.Text, ModelOutputType.Object],
attachments: IMAGE_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -33,6 +35,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
{ {
input: [ModelInputType.Text, ModelInputType.Image], input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Object], output: [ModelOutputType.Text, ModelOutputType.Object],
attachments: IMAGE_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -43,6 +46,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
{ {
input: [ModelInputType.Text, ModelInputType.Image], input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Object], output: [ModelOutputType.Text, ModelOutputType.Object],
attachments: IMAGE_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },

View File

@@ -1,18 +1,14 @@
import { import { IMAGE_ATTACHMENT_CAPABILITY } from '../attachments';
createVertexAnthropic,
type GoogleVertexAnthropicProvider,
type GoogleVertexAnthropicProviderSettings,
} from '@ai-sdk/google-vertex/anthropic';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import { import {
getGoogleAuth, getGoogleAuth,
getVertexAnthropicBaseUrl, getVertexAnthropicBaseUrl,
VertexModelListSchema, VertexModelListSchema,
type VertexProviderConfig,
} from '../utils'; } from '../utils';
import { AnthropicProvider } from './anthropic'; import { AnthropicProvider } from './anthropic';
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings; export type AnthropicVertexConfig = VertexProviderConfig;
export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexConfig> { export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexConfig> {
override readonly type = CopilotProviderType.AnthropicVertex; override readonly type = CopilotProviderType.AnthropicVertex;
@@ -25,6 +21,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
{ {
input: [ModelInputType.Text, ModelInputType.Image], input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Object], output: [ModelOutputType.Text, ModelOutputType.Object],
attachments: IMAGE_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -35,6 +32,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
{ {
input: [ModelInputType.Text, ModelInputType.Image], input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Object], output: [ModelOutputType.Text, ModelOutputType.Object],
attachments: IMAGE_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -45,23 +43,17 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
{ {
input: [ModelInputType.Text, ModelInputType.Image], input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Text, ModelOutputType.Object], output: [ModelOutputType.Text, ModelOutputType.Object],
attachments: IMAGE_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
]; ];
protected instance!: GoogleVertexAnthropicProvider;
override configured(): boolean { override configured(): boolean {
if (!this.config.location || !this.config.googleAuthOptions) return false; if (!this.config.location || !this.config.googleAuthOptions) return false;
return !!this.config.project || !!getVertexAnthropicBaseUrl(this.config); return !!this.config.project || !!getVertexAnthropicBaseUrl(this.config);
} }
override setup() {
super.setup();
this.instance = createVertexAnthropic(this.config);
}
override async refreshOnlineModels() { override async refreshOnlineModels() {
try { try {
const { baseUrl, headers } = await getGoogleAuth( const { baseUrl, headers } = await getGoogleAuth(

View File

@@ -0,0 +1,233 @@
import type {
ModelAttachmentCapability,
PromptAttachment,
PromptAttachmentKind,
PromptAttachmentSourceKind,
PromptMessage,
} from './types';
import { inferMimeType } from './utils';
// Capability for image-only models: accepts image attachments delivered as a
// remote URL or inline base64 data; remote URLs may be passed through as-is.
export const IMAGE_ATTACHMENT_CAPABILITY: ModelAttachmentCapability = {
  kinds: ['image'],
  sourceKinds: ['url', 'data'],
  allowRemoteUrls: true,
};
// Capability for Gemini-style models: images, audio and generic files, via
// URL, inline base64 data, raw bytes, or a previously-uploaded file handle.
export const GEMINI_ATTACHMENT_CAPABILITY: ModelAttachmentCapability = {
  kinds: ['image', 'audio', 'file'],
  sourceKinds: ['url', 'data', 'bytes', 'file_handle'],
  allowRemoteUrls: true,
};
// Normalized, provider-agnostic view of a prompt attachment produced by
// canonicalization: the attachment's kind and delivery mechanism plus the
// payload record handed to the native dispatch layer.
export type CanonicalPromptAttachment = {
  kind: PromptAttachmentKind; // image | audio | file
  sourceKind: PromptAttachmentSourceKind; // url | data | bytes | file_handle
  mediaType?: string; // MIME type when known or inferable
  source: Record<string, unknown>; // payload record for the native layer
  isRemote: boolean; // true when the content still lives at a remote URL
};
// Split an RFC 2397 `data:` URL into its media type and a base64 payload.
// Returns null when the input is not a data URL or lacks a comma separator.
// Percent-encoded (non-base64) payloads are decoded and re-encoded as base64.
function parseDataUrl(url: string) {
  if (!url.startsWith('data:')) {
    return null;
  }
  const separator = url.indexOf(',');
  if (separator < 0) {
    return null;
  }
  // Everything between "data:" and the comma: "<mediatype>[;base64]".
  const header = url.slice('data:'.length, separator);
  const body = url.slice(separator + 1);
  const segments = header.split(';');
  // RFC 2397 default media type when the header omits one.
  const mediaType = segments[0] || 'text/plain;charset=US-ASCII';
  if (segments.includes('base64')) {
    return { mediaType, data: body };
  }
  return {
    mediaType,
    data: Buffer.from(decodeURIComponent(body), 'utf8').toString('base64'),
  };
}
/**
 * Classify a media type into a coarse attachment kind: `image/*` maps to
 * 'image', `audio/*` to 'audio', and everything else (including an empty
 * string) to 'file'.
 */
function attachmentTypeFromMediaType(mediaType: string): PromptAttachmentKind {
  const prefixMap: Array<[string, PromptAttachmentKind]> = [
    ['image/', 'image'],
    ['audio/', 'audio'],
  ];
  for (const [prefix, kind] of prefixMap) {
    if (mediaType.startsWith(prefix)) {
      return kind;
    }
  }
  return 'file';
}
/**
 * Resolve the attachment kind: an explicit provider hint always wins;
 * otherwise fall back to classifying by media type (empty string when none).
 */
function attachmentKindFromHintOrMediaType(
  hint: PromptAttachmentKind | undefined,
  mediaType: string | undefined
): PromptAttachmentKind {
  return hint ? hint : attachmentTypeFromMediaType(mediaType || '');
}
/**
 * Return `data` as base64: passed through unchanged when it is already
 * base64, transcoded from utf8 otherwise.
 */
function toBase64Data(data: string, encoding: 'base64' | 'utf8' = 'base64') {
  if (encoding === 'base64') {
    return data;
  }
  return Buffer.from(data, 'utf8').toString('base64');
}
/**
 * Copy the optional attachment metadata (file name, provider hint) onto a
 * snake_case source payload. Mutates and returns `source` for chaining.
 */
function appendAttachMetadata(
  source: Record<string, unknown>,
  attachment: Exclude<PromptAttachment, string> & Record<string, unknown>
) {
  const { fileName, providerHint } = attachment;
  if (fileName) {
    source.file_name = fileName;
  }
  if (providerHint) {
    source.provider_hint = providerHint;
  }
  return source;
}
/**
 * Whether an attachment actually carries a payload or reference: a non-blank
 * string, a non-empty legacy `attachment` field, or a non-empty
 * url / data / file handle depending on its kind.
 */
export function promptAttachmentHasSource(
  attachment: PromptAttachment
): boolean {
  if (typeof attachment === 'string') {
    return attachment.trim().length > 0;
  }
  if ('attachment' in attachment) {
    return Boolean(attachment.attachment);
  }
  if (attachment.kind === 'url') {
    return Boolean(attachment.url);
  }
  if (attachment.kind === 'file_handle') {
    return Boolean(attachment.fileHandle);
  }
  // 'data' and 'bytes' both carry their payload in `data`.
  return Boolean(attachment.data);
}
/**
 * Normalize any `PromptAttachment` variant — bare string, legacy
 * `{ attachment }` wrapper, or discriminated `kind` object — into a single
 * canonical descriptor with a snake_case `source` payload and `isRemote` flag.
 *
 * Media-type resolution order in the `url` branch: explicit
 * `attachment.mimeType` → message-level `params.mimetype` → the data URL's
 * declared type → sniffed via `inferMimeType`.
 */
export async function canonicalizePromptAttachment(
  attachment: PromptAttachment,
  message: Pick<PromptMessage, 'params'>
): Promise<CanonicalPromptAttachment> {
  // Message-level mimetype hint; only honored when it is actually a string.
  const fallbackMimeType =
    typeof message.params?.mimetype === 'string'
      ? message.params.mimetype
      : undefined;
  // Case 1: bare string — either an inline data: URL or a plain URL.
  if (typeof attachment === 'string') {
    const dataUrl = parseDataUrl(attachment);
    // NOTE(review): here the message-level fallback wins over the data URL's
    // own declared media type (unlike the `url` branch below, where the
    // attachment's explicit mimeType comes first) — confirm this is intended.
    const mediaType =
      fallbackMimeType ??
      dataUrl?.mediaType ??
      (await inferMimeType(attachment));
    const kind = attachmentKindFromHintOrMediaType(undefined, mediaType);
    if (dataUrl) {
      return {
        kind,
        sourceKind: 'data',
        mediaType,
        isRemote: false,
        source: {
          // `||` (not `??`) so an empty-string mediaType still falls back to
          // the data URL's declared type.
          media_type: mediaType || dataUrl.mediaType,
          data: dataUrl.data,
        },
      };
    }
    return {
      kind,
      sourceKind: 'url',
      mediaType,
      // Only http(s) URLs count as remote; any other scheme is local.
      isRemote: /^https?:\/\//.test(attachment),
      source: { url: attachment, media_type: mediaType },
    };
  }
  // Case 2: legacy `{ attachment, mimeType }` wrapper — re-enter as `url` kind.
  if ('attachment' in attachment) {
    return await canonicalizePromptAttachment(
      {
        kind: 'url',
        url: attachment.attachment,
        mimeType: attachment.mimeType,
      },
      message
    );
  }
  // Case 3: explicit `url` kind (the url itself may still be a data: URL,
  // which is then inlined as a `data` source).
  if (attachment.kind === 'url') {
    const dataUrl = parseDataUrl(attachment.url);
    const mediaType =
      attachment.mimeType ??
      fallbackMimeType ??
      dataUrl?.mediaType ??
      (await inferMimeType(attachment.url));
    const kind = attachmentKindFromHintOrMediaType(
      attachment.providerHint?.kind,
      mediaType
    );
    if (dataUrl) {
      return {
        kind,
        sourceKind: 'data',
        mediaType,
        isRemote: false,
        source: appendAttachMetadata(
          { media_type: mediaType || dataUrl.mediaType, data: dataUrl.data },
          attachment
        ),
      };
    }
    return {
      kind,
      sourceKind: 'url',
      mediaType,
      isRemote: /^https?:\/\//.test(attachment.url),
      source: appendAttachMetadata(
        { url: attachment.url, media_type: mediaType },
        attachment
      ),
    };
  }
  // Case 4: inline payloads. `data` carries an explicit encoding; `bytes` is
  // always treated as already-base64.
  if (attachment.kind === 'data' || attachment.kind === 'bytes') {
    return {
      kind: attachmentKindFromHintOrMediaType(
        attachment.providerHint?.kind,
        attachment.mimeType
      ),
      sourceKind: attachment.kind,
      mediaType: attachment.mimeType,
      isRemote: false,
      source: appendAttachMetadata(
        {
          media_type: attachment.mimeType,
          data: toBase64Data(
            attachment.data,
            attachment.kind === 'data' ? attachment.encoding : 'base64'
          ),
        },
        attachment
      ),
    };
  }
  // Case 5: provider-side file handle (e.g. a previously uploaded file).
  return {
    kind: attachmentKindFromHintOrMediaType(
      attachment.providerHint?.kind,
      attachment.mimeType
    ),
    sourceKind: 'file_handle',
    mediaType: attachment.mimeType,
    isRemote: false,
    source: appendAttachMetadata(
      { file_handle: attachment.fileHandle, media_type: attachment.mimeType },
      attachment
    ),
  };
}

View File

@@ -19,6 +19,7 @@ import type {
PromptMessage, PromptMessage,
} from './types'; } from './types';
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
import { promptAttachmentMimeType, promptAttachmentToUrl } from './utils';
export type FalConfig = { export type FalConfig = {
apiKey: string; apiKey: string;
@@ -183,13 +184,14 @@ export class FalProvider extends CopilotProvider<FalConfig> {
return { return {
model_name: options.modelName || undefined, model_name: options.modelName || undefined,
image_url: attachments image_url: attachments
?.map(v => ?.map(v => {
typeof v === 'string' const url = promptAttachmentToUrl(v);
? v const mediaType = promptAttachmentMimeType(
: v.mimeType.startsWith('image/') v,
? v.attachment typeof params?.mimetype === 'string' ? params.mimetype : undefined
: undefined );
) return url && mediaType?.startsWith('image/') ? url : undefined;
})
.find(v => !!v), .find(v => !!v),
prompt: content.trim(), prompt: content.trim(),
loras: lora.length ? lora : undefined, loras: lora.length ? lora : undefined,

View File

@@ -1,87 +1,94 @@
import type { import { setTimeout as delay } from 'node:timers/promises';
GoogleGenerativeAIProvider,
GoogleGenerativeAIProviderOptions, import { ZodError } from 'zod';
} from '@ai-sdk/google';
import type { GoogleVertexProvider } from '@ai-sdk/google-vertex';
import {
AISDKError,
type EmbeddingModel,
embedMany,
generateObject,
generateText,
JSONParseError,
stepCountIs,
streamText,
} from 'ai';
import { import {
CopilotPromptInvalid,
CopilotProviderSideError, CopilotProviderSideError,
metrics, metrics,
OneMB,
readResponseBufferWithLimit,
safeFetch,
UserFriendlyError, UserFriendlyError,
} from '../../../../base'; } from '../../../../base';
import { sniffMime } from '../../../../base/storage/providers/utils';
import {
llmDispatchStream,
llmEmbeddingDispatch,
llmStructuredDispatch,
type NativeLlmBackendConfig,
type NativeLlmEmbeddingRequest,
type NativeLlmRequest,
type NativeLlmStructuredRequest,
} from '../../../../native';
import type { NodeTextMiddleware } from '../../config';
import type { CopilotToolSet } from '../../tools';
import {
buildNativeEmbeddingRequest,
buildNativeRequest,
buildNativeStructuredRequest,
NativeProviderAdapter,
parseNativeStructuredOutput,
StructuredResponseParseError,
} from '../native';
import { CopilotProvider } from '../provider'; import { CopilotProvider } from '../provider';
import type { import type {
CopilotChatOptions, CopilotChatOptions,
CopilotEmbeddingOptions, CopilotEmbeddingOptions,
CopilotImageOptions, CopilotImageOptions,
CopilotProviderModel, CopilotStructuredOptions,
ModelConditions, ModelConditions,
PromptAttachment,
PromptMessage, PromptMessage,
StreamObject, StreamObject,
} from '../types'; } from '../types';
import { ModelOutputType } from '../types'; import { ModelOutputType } from '../types';
import { import { promptAttachmentMimeType, promptAttachmentToUrl } from '../utils';
chatToGPTMessage,
StreamObjectParser,
TextStreamParser,
} from '../utils';
export const DEFAULT_DIMENSIONS = 256; export const DEFAULT_DIMENSIONS = 256;
const GEMINI_REMOTE_ATTACHMENT_MAX_BYTES = 64 * OneMB;
const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro'];
const GEMINI_RETRY_INITIAL_DELAY_MS = 2_000;
function normalizeMimeType(mediaType?: string) {
return mediaType?.split(';', 1)[0]?.trim() || 'application/octet-stream';
}
function isYoutubeUrl(url: URL) {
const hostname = url.hostname.toLowerCase();
if (hostname === 'youtu.be') {
return /^\/[\w-]+$/.test(url.pathname);
}
if (hostname !== 'youtube.com' && hostname !== 'www.youtube.com') {
return false;
}
if (url.pathname !== '/watch') {
return false;
}
return !!url.searchParams.get('v');
}
function isGeminiFileUrl(url: URL, baseUrl: string) {
try {
const base = new URL(baseUrl);
const basePath = base.pathname.replace(/\/+$/, '');
return (
url.origin === base.origin &&
url.pathname.startsWith(`${basePath}/files/`)
);
} catch {
return false;
}
}
export abstract class GeminiProvider<T> extends CopilotProvider<T> { export abstract class GeminiProvider<T> extends CopilotProvider<T> {
protected abstract instance: protected abstract createNativeConfig(): Promise<NativeLlmBackendConfig>;
| GoogleGenerativeAIProvider
| GoogleVertexProvider;
private getThinkingConfig(
model: string,
options: { includeThoughts: boolean; useDynamicBudget?: boolean }
): NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']> {
if (this.isGemini3Model(model)) {
return {
includeThoughts: options.includeThoughts,
thinkingLevel: 'high',
};
}
return {
includeThoughts: options.includeThoughts,
thinkingBudget: options.useDynamicBudget ? -1 : 12000,
};
}
private getEmbeddingModel(model: string) {
const provider = this.instance as typeof this.instance & {
embeddingModel?: (modelId: string) => EmbeddingModel;
textEmbeddingModel?: (modelId: string) => EmbeddingModel;
};
return (
provider.embeddingModel?.(model) ?? provider.textEmbeddingModel?.(model)
);
}
private handleError(e: any) { private handleError(e: any) {
if (e instanceof UserFriendlyError) { if (e instanceof UserFriendlyError) {
return e; return e;
} else if (e instanceof AISDKError) {
this.logger.error('Throw error from ai sdk:', e);
return new CopilotProviderSideError({
provider: this.type,
kind: e.name || 'unknown',
message: e.message,
});
} else { } else {
return new CopilotProviderSideError({ return new CopilotProviderSideError({
provider: this.type, provider: this.type,
@@ -91,37 +98,261 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
} }
} }
protected createNativeDispatch(backendConfig: NativeLlmBackendConfig) {
return (request: NativeLlmRequest, signal?: AbortSignal) =>
llmDispatchStream('gemini', backendConfig, request, signal);
}
protected createNativeStructuredDispatch(
backendConfig: NativeLlmBackendConfig
) {
return (request: NativeLlmStructuredRequest) =>
llmStructuredDispatch('gemini', backendConfig, request);
}
protected createNativeEmbeddingDispatch(
backendConfig: NativeLlmBackendConfig
) {
return (request: NativeLlmEmbeddingRequest) =>
llmEmbeddingDispatch('gemini', backendConfig, request);
}
protected createNativeAdapter(
backendConfig: NativeLlmBackendConfig,
tools: CopilotToolSet,
nodeTextMiddleware?: NodeTextMiddleware[]
) {
return new NativeProviderAdapter(
this.createNativeDispatch(backendConfig),
tools,
this.MAX_STEPS,
{ nodeTextMiddleware }
);
}
protected async fetchRemoteAttach(url: string, signal?: AbortSignal) {
const parsed = new URL(url);
const response = await safeFetch(
parsed,
{ method: 'GET', signal },
this.buildAttachFetchOptions(parsed)
);
if (!response.ok) {
throw new Error(
`Failed to fetch attachment: ${response.status} ${response.statusText}`
);
}
const buffer = await readResponseBufferWithLimit(
response,
GEMINI_REMOTE_ATTACHMENT_MAX_BYTES
);
const headerMimeType = normalizeMimeType(
response.headers.get('content-type') || ''
);
return {
data: buffer.toString('base64'),
mimeType: normalizeMimeType(sniffMime(buffer, headerMimeType)),
};
}
private buildAttachFetchOptions(url: URL) {
const baseOptions = { timeoutMs: 15_000, maxRedirects: 3 } as const;
if (!env.prod) {
return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) };
}
const trustedOrigins = new Set<string>();
const protocol = this.AFFiNEConfig.server.https ? 'https:' : 'http:';
const port = this.AFFiNEConfig.server.port;
const isDefaultPort =
(protocol === 'https:' && port === 443) ||
(protocol === 'http:' && port === 80);
const addHostOrigin = (host: string) => {
if (!host) return;
try {
const parsed = new URL(`${protocol}//${host}`);
if (!parsed.port && !isDefaultPort) {
parsed.port = String(port);
}
trustedOrigins.add(parsed.origin);
} catch {
// ignore invalid host config entries
}
};
if (this.AFFiNEConfig.server.externalUrl) {
try {
trustedOrigins.add(
new URL(this.AFFiNEConfig.server.externalUrl).origin
);
} catch {
// ignore invalid external URL
}
}
addHostOrigin(this.AFFiNEConfig.server.host);
for (const host of this.AFFiNEConfig.server.hosts) {
addHostOrigin(host);
}
const hostname = url.hostname.toLowerCase();
const trustedByHost = TRUSTED_ATTACHMENT_HOST_SUFFIXES.some(
suffix => hostname === suffix || hostname.endsWith(`.${suffix}`)
);
if (trustedOrigins.has(url.origin) || trustedByHost) {
return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) };
}
return baseOptions;
}
private shouldInlineRemoteAttach(url: URL, config: NativeLlmBackendConfig) {
switch (config.request_layer) {
case 'gemini_api':
if (url.protocol !== 'http:' && url.protocol !== 'https:') return false;
return !(isGeminiFileUrl(url, config.base_url) || isYoutubeUrl(url));
case 'gemini_vertex':
return false;
default:
return false;
}
}
private toInlineAttach(
attachment: PromptAttachment,
mimeType: string,
data: string
): PromptAttachment {
if (typeof attachment === 'string' || !('kind' in attachment)) {
return { kind: 'bytes', data, mimeType };
}
if (attachment.kind !== 'url') {
return attachment;
}
return {
kind: 'bytes',
data,
mimeType,
fileName: attachment.fileName,
providerHint: attachment.providerHint,
};
}
protected async prepareMessages(
messages: PromptMessage[],
backendConfig: NativeLlmBackendConfig,
signal?: AbortSignal
): Promise<PromptMessage[]> {
const prepared: PromptMessage[] = [];
for (const message of messages) {
signal?.throwIfAborted();
if (!Array.isArray(message.attachments) || !message.attachments.length) {
prepared.push(message);
continue;
}
const attachments: PromptAttachment[] = [];
let changed = false;
for (const attachment of message.attachments) {
signal?.throwIfAborted();
const rawUrl = promptAttachmentToUrl(attachment);
if (!rawUrl || rawUrl.startsWith('data:')) {
attachments.push(attachment);
continue;
}
let parsed: URL;
try {
parsed = new URL(rawUrl);
} catch {
attachments.push(attachment);
continue;
}
if (!this.shouldInlineRemoteAttach(parsed, backendConfig)) {
attachments.push(attachment);
continue;
}
const declaredMimeType = promptAttachmentMimeType(
attachment,
typeof message.params?.mimetype === 'string'
? message.params.mimetype
: undefined
);
const downloaded = await this.fetchRemoteAttach(rawUrl, signal);
attachments.push(
this.toInlineAttach(
attachment,
declaredMimeType
? normalizeMimeType(declaredMimeType)
: downloaded.mimeType,
downloaded.data
)
);
changed = true;
}
prepared.push(changed ? { ...message, attachments } : message);
}
return prepared;
}
protected async waitForStructuredRetry(
delayMs: number,
signal?: AbortSignal
) {
await delay(delayMs, undefined, signal ? { signal } : undefined);
}
async text( async text(
cond: ModelConditions, cond: ModelConditions,
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
const backendConfig = await this.createNativeConfig();
const [system, msgs] = await chatToGPTMessage(messages); const msg = await this.prepareMessages(
messages,
const modelInstance = this.instance(model.id); backendConfig,
const { text } = await generateText({ options.signal
model: modelInstance, );
system, const tools = await this.getTools(options, model.id);
messages: msgs, const middleware = this.getActiveProviderMiddleware();
abortSignal: options.signal, const cap = this.getAttachCapability(model, ModelOutputType.Text);
providerOptions: { const { request } = await buildNativeRequest({
google: this.getGeminiOptions(options, model.id), model: model.id,
}, messages: msg,
tools: await this.getTools(options, model.id), options,
stopWhen: stepCountIs(this.MAX_STEPS), tools,
attachmentCapability: cap,
reasoning: this.getReasoning(options, model.id),
middleware,
}); });
const adapter = this.createNativeAdapter(
if (!text) throw new Error('Failed to generate text'); backendConfig,
return text.trim(); tools,
middleware.node?.text
);
return await adapter.text(request, options.signal, messages);
} catch (e: any) { } catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); metrics.ai
.counter('chat_text_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e); throw this.handleError(e);
} }
} }
@@ -129,55 +360,65 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
override async structure( override async structure(
cond: ModelConditions, cond: ModelConditions,
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotChatOptions = {} options: CopilotStructuredOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Structured }; const fullCond = { ...cond, outputType: ModelOutputType.Structured };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
const backendConfig = await this.createNativeConfig();
const [system, msgs, schema] = await chatToGPTMessage(messages); const msg = await this.prepareMessages(
if (!schema) { messages,
throw new CopilotPromptInvalid('Schema is required'); backendConfig,
} options.signal
);
const modelInstance = this.instance(model.id); const structuredDispatch =
const { object } = await generateObject({ this.createNativeStructuredDispatch(backendConfig);
model: modelInstance, const middleware = this.getActiveProviderMiddleware();
system, const cap = this.getAttachCapability(model, ModelOutputType.Structured);
messages: msgs, const { request, schema } = await buildNativeStructuredRequest({
schema, model: model.id,
providerOptions: { messages: msg,
google: { options,
thinkingConfig: this.getThinkingConfig(model.id, { attachmentCapability: cap,
includeThoughts: false, reasoning: this.getReasoning(options, model.id),
useDynamicBudget: true, responseSchema: options.schema,
}), middleware,
},
},
abortSignal: options.signal,
maxRetries: options.maxRetries || 3,
experimental_repairText: async ({ text, error }) => {
if (error instanceof JSONParseError) {
// strange fixed response, temporarily replace it
const ret = text.replaceAll(/^ny\n/g, ' ').trim();
if (ret.startsWith('```') || ret.endsWith('```')) {
return ret
.replace(/```[\w\s]+\n/g, '')
.replace(/\n```/g, '')
.trim();
}
return ret;
}
return null;
},
}); });
const maxRetries = Math.max(options.maxRetries ?? 3, 0);
return JSON.stringify(object); for (let attempt = 0; ; attempt++) {
try {
const response = await structuredDispatch(request);
const parsed = parseNativeStructuredOutput(response);
const validated = schema.parse(parsed);
return JSON.stringify(validated);
} catch (error) {
const isParsingError =
error instanceof StructuredResponseParseError ||
error instanceof ZodError;
const retryableError =
isParsingError || !(error instanceof UserFriendlyError);
if (!retryableError || attempt >= maxRetries) {
throw error;
}
if (!isParsingError) {
await this.waitForStructuredRetry(
GEMINI_RETRY_INITIAL_DELAY_MS * 2 ** attempt,
options.signal
);
}
}
}
} catch (e: any) { } catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); metrics.ai
.counter('chat_text_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e); throw this.handleError(e);
} }
} }
@@ -188,29 +429,54 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
options: CopilotChatOptions | CopilotImageOptions = {} options: CopilotChatOptions | CopilotImageOptions = {}
): AsyncIterable<string> { ): AsyncIterable<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id }); metrics.ai
const fullStream = await this.getFullStream(model, messages, options); .counter('chat_text_stream_calls')
const parser = new TextStreamParser(); .add(1, this.metricLabels(model.id));
for await (const chunk of fullStream) { const backendConfig = await this.createNativeConfig();
const result = parser.parse(chunk); const preparedMessages = await this.prepareMessages(
yield result; messages,
if (options.signal?.aborted) { backendConfig,
await fullStream.cancel(); options.signal
break; );
} const tools = await this.getTools(
} options as CopilotChatOptions,
if (!options.signal?.aborted) { model.id
const footnotes = parser.end(); );
if (footnotes.length) { const middleware = this.getActiveProviderMiddleware();
yield `\n\n${footnotes}`; const cap = this.getAttachCapability(model, ModelOutputType.Text);
} const { request } = await buildNativeRequest({
model: model.id,
messages: preparedMessages,
options: options as CopilotChatOptions,
tools,
attachmentCapability: cap,
reasoning: this.getReasoning(options, model.id),
middleware,
});
const adapter = this.createNativeAdapter(
backendConfig,
tools,
middleware.node?.text
);
for await (const chunk of adapter.streamText(
request,
options.signal,
messages
)) {
yield chunk;
} }
} catch (e: any) { } catch (e: any) {
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id }); metrics.ai
.counter('chat_text_stream_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e); throw this.handleError(e);
} }
} }
@@ -221,29 +487,51 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> { ): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object }; const fullCond = { ...cond, outputType: ModelOutputType.Object };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai metrics.ai
.counter('chat_object_stream_calls') .counter('chat_object_stream_calls')
.add(1, { model: model.id }); .add(1, this.metricLabels(model.id));
const fullStream = await this.getFullStream(model, messages, options); const backendConfig = await this.createNativeConfig();
const parser = new StreamObjectParser(); const msg = await this.prepareMessages(
for await (const chunk of fullStream) { messages,
const result = parser.parse(chunk); backendConfig,
if (result) { options.signal
yield result; );
} const tools = await this.getTools(options, model.id);
if (options.signal?.aborted) { const middleware = this.getActiveProviderMiddleware();
await fullStream.cancel(); const cap = this.getAttachCapability(model, ModelOutputType.Object);
break; const { request } = await buildNativeRequest({
} model: model.id,
messages: msg,
options,
tools,
attachmentCapability: cap,
reasoning: this.getReasoning(options, model.id),
middleware,
});
const adapter = this.createNativeAdapter(
backendConfig,
tools,
middleware.node?.text
);
for await (const chunk of adapter.streamObject(
request,
options.signal,
messages
)) {
yield chunk;
} }
} catch (e: any) { } catch (e: any) {
metrics.ai metrics.ai
.counter('chat_object_stream_errors') .counter('chat_object_stream_errors')
.add(1, { model: model.id }); .add(1, this.metricLabels(model.id));
throw this.handleError(e); throw this.handleError(e);
} }
} }
@@ -253,76 +541,53 @@ export abstract class GeminiProvider<T> extends CopilotProvider<T> {
messages: string | string[], messages: string | string[],
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> { ): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages]; const values = Array.isArray(messages) ? messages : [messages];
const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; const fullCond = { ...cond, outputType: ModelOutputType.Embedding };
await this.checkParams({ embeddings: messages, cond: fullCond, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); embeddings: values,
cond: fullCond,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai metrics.ai
.counter('generate_embedding_calls') .counter('generate_embedding_calls')
.add(1, { model: model.id }); .add(1, this.metricLabels(model.id));
const backendConfig = await this.createNativeConfig();
const modelInstance = this.getEmbeddingModel(model.id); const response = await this.createNativeEmbeddingDispatch(backendConfig)(
if (!modelInstance) { buildNativeEmbeddingRequest({
throw new Error(`Embedding model is not available for ${model.id}`); model: model.id,
} inputs: values,
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
const embeddings = await Promise.allSettled(
messages.map(m =>
embedMany({
model: modelInstance,
values: [m],
maxRetries: 3,
providerOptions: {
google: {
outputDimensionality: options.dimensions || DEFAULT_DIMENSIONS,
taskType: 'RETRIEVAL_DOCUMENT', taskType: 'RETRIEVAL_DOCUMENT',
},
},
}) })
)
); );
return response.embeddings;
return embeddings
.flatMap(e => (e.status === 'fulfilled' ? e.value.embeddings : null))
.filter((v): v is number[] => !!v && Array.isArray(v));
} catch (e: any) { } catch (e: any) {
metrics.ai metrics.ai
.counter('generate_embedding_errors') .counter('generate_embedding_errors')
.add(1, { model: model.id }); .add(1, this.metricLabels(model.id));
throw this.handleError(e); throw this.handleError(e);
} }
} }
private async getFullStream( protected getReasoning(
model: CopilotProviderModel, options: CopilotChatOptions | CopilotImageOptions,
messages: PromptMessage[], model: string
options: CopilotChatOptions = {} ): Record<string, unknown> | undefined {
if (
options &&
'reasoning' in options &&
options.reasoning &&
this.isReasoningModel(model)
) { ) {
const [system, msgs] = await chatToGPTMessage(messages); return this.isGemini3Model(model)
const { fullStream } = streamText({ ? { include_thoughts: true, thinking_level: 'high' }
model: this.instance(model.id), : { include_thoughts: true, thinking_budget: 12000 };
system,
messages: msgs,
abortSignal: options.signal,
providerOptions: {
google: this.getGeminiOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
stopWhen: stepCountIs(this.MAX_STEPS),
});
return fullStream;
} }
private getGeminiOptions(options: CopilotChatOptions, model: string) { return undefined;
const result: GoogleGenerativeAIProviderOptions = {};
if (options?.reasoning && this.isReasoningModel(model)) {
result.thinkingConfig = this.getThinkingConfig(model, {
includeThoughts: true,
});
}
return result;
} }
private isGemini3Model(model: string) { private isGemini3Model(model: string) {

View File

@@ -1,9 +1,7 @@
import {
createGoogleGenerativeAI,
type GoogleGenerativeAIProvider,
} from '@ai-sdk/google';
import z from 'zod'; import z from 'zod';
import type { NativeLlmBackendConfig } from '../../../../native';
import { GEMINI_ATTACHMENT_CAPABILITY } from '../attachments';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import { GeminiProvider } from './gemini'; import { GeminiProvider } from './gemini';
@@ -29,12 +27,15 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
ModelInputType.Text, ModelInputType.Text,
ModelInputType.Image, ModelInputType.Image,
ModelInputType.Audio, ModelInputType.Audio,
ModelInputType.File,
], ],
output: [ output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ],
attachments: GEMINI_ATTACHMENT_CAPABILITY,
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -47,12 +48,15 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
ModelInputType.Text, ModelInputType.Text,
ModelInputType.Image, ModelInputType.Image,
ModelInputType.Audio, ModelInputType.Audio,
ModelInputType.File,
], ],
output: [ output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ],
attachments: GEMINI_ATTACHMENT_CAPABILITY,
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -65,12 +69,15 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
ModelInputType.Text, ModelInputType.Text,
ModelInputType.Image, ModelInputType.Image,
ModelInputType.Audio, ModelInputType.Audio,
ModelInputType.File,
], ],
output: [ output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ],
attachments: GEMINI_ATTACHMENT_CAPABILITY,
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -86,21 +93,10 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
], ],
}, },
]; ];
protected instance!: GoogleGenerativeAIProvider;
override configured(): boolean { override configured(): boolean {
return !!this.config.apiKey; return !!this.config.apiKey;
} }
protected override setup() {
super.setup();
this.instance = createGoogleGenerativeAI({
apiKey: this.config.apiKey,
baseURL: this.config.baseURL,
});
}
override async refreshOnlineModels() { override async refreshOnlineModels() {
try { try {
const baseUrl = const baseUrl =
@@ -120,4 +116,15 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
this.logger.error('Failed to fetch available models', e); this.logger.error('Failed to fetch available models', e);
} }
} }
protected override async createNativeConfig(): Promise<NativeLlmBackendConfig> {
return {
base_url: (
this.config.baseURL ||
'https://generativelanguage.googleapis.com/v1beta'
).replace(/\/$/, ''),
auth_token: this.config.apiKey,
request_layer: 'gemini_api',
};
}
} }

View File

@@ -1,14 +1,14 @@
import { import type { NativeLlmBackendConfig } from '../../../../native';
createVertex, import { GEMINI_ATTACHMENT_CAPABILITY } from '../attachments';
type GoogleVertexProvider,
type GoogleVertexProviderSettings,
} from '@ai-sdk/google-vertex';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import { getGoogleAuth, VertexModelListSchema } from '../utils'; import {
getGoogleAuth,
VertexModelListSchema,
type VertexProviderConfig,
} from '../utils';
import { GeminiProvider } from './gemini'; import { GeminiProvider } from './gemini';
export type GeminiVertexConfig = GoogleVertexProviderSettings; export type GeminiVertexConfig = VertexProviderConfig;
export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> { export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
override readonly type = CopilotProviderType.GeminiVertex; override readonly type = CopilotProviderType.GeminiVertex;
@@ -23,12 +23,15 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
ModelInputType.Text, ModelInputType.Text,
ModelInputType.Image, ModelInputType.Image,
ModelInputType.Audio, ModelInputType.Audio,
ModelInputType.File,
], ],
output: [ output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ],
attachments: GEMINI_ATTACHMENT_CAPABILITY,
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -41,12 +44,15 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
ModelInputType.Text, ModelInputType.Text,
ModelInputType.Image, ModelInputType.Image,
ModelInputType.Audio, ModelInputType.Audio,
ModelInputType.File,
], ],
output: [ output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ],
attachments: GEMINI_ATTACHMENT_CAPABILITY,
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -59,12 +65,15 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
ModelInputType.Text, ModelInputType.Text,
ModelInputType.Image, ModelInputType.Image,
ModelInputType.Audio, ModelInputType.Audio,
ModelInputType.File,
], ],
output: [ output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ],
attachments: GEMINI_ATTACHMENT_CAPABILITY,
structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY,
}, },
], ],
}, },
@@ -80,21 +89,13 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
], ],
}, },
]; ];
protected instance!: GoogleVertexProvider;
override configured(): boolean { override configured(): boolean {
return !!this.config.location && !!this.config.googleAuthOptions; return !!this.config.location && !!this.config.googleAuthOptions;
} }
protected override setup() {
super.setup();
this.instance = createVertex(this.config);
}
override async refreshOnlineModels() { override async refreshOnlineModels() {
try { try {
const { baseUrl, headers } = await getGoogleAuth(this.config, 'google'); const { baseUrl, headers } = await this.resolveVertexAuth();
if (baseUrl && !this.onlineModelList.length) { if (baseUrl && !this.onlineModelList.length) {
const { publisherModels } = await fetch(`${baseUrl}/models`, { const { publisherModels } = await fetch(`${baseUrl}/models`, {
headers: headers(), headers: headers(),
@@ -109,4 +110,19 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
this.logger.error('Failed to fetch available models', e); this.logger.error('Failed to fetch available models', e);
} }
} }
protected async resolveVertexAuth() {
return await getGoogleAuth(this.config, 'google');
}
protected override async createNativeConfig(): Promise<NativeLlmBackendConfig> {
const auth = await this.resolveVertexAuth();
const { Authorization: authHeader } = auth.headers();
return {
base_url: auth.baseUrl || '',
auth_token: authHeader.replace(/^Bearer\s+/i, ''),
request_layer: 'gemini_vertex',
};
}
} }

View File

@@ -1,4 +1,3 @@
import type { ToolSet } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import type { import type {
@@ -6,6 +5,11 @@ import type {
NativeLlmStreamEvent, NativeLlmStreamEvent,
NativeLlmToolDefinition, NativeLlmToolDefinition,
} from '../../../native'; } from '../../../native';
import type {
CopilotTool,
CopilotToolExecuteOptions,
CopilotToolSet,
} from '../tools';
export type NativeDispatchFn = ( export type NativeDispatchFn = (
request: NativeLlmRequest, request: NativeLlmRequest,
@@ -16,6 +20,8 @@ export type NativeToolCall = {
id: string; id: string;
name: string; name: string;
args: Record<string, unknown>; args: Record<string, unknown>;
rawArgumentsText?: string;
argumentParseError?: string;
thought?: string; thought?: string;
}; };
@@ -28,10 +34,18 @@ type ToolExecutionResult = {
callId: string; callId: string;
name: string; name: string;
args: Record<string, unknown>; args: Record<string, unknown>;
rawArgumentsText?: string;
argumentParseError?: string;
output: unknown; output: unknown;
isError?: boolean; isError?: boolean;
}; };
type ParsedToolArguments = {
args: Record<string, unknown>;
rawArgumentsText?: string;
argumentParseError?: string;
};
export class ToolCallAccumulator { export class ToolCallAccumulator {
readonly #states = new Map<string, ToolCallState>(); readonly #states = new Map<string, ToolCallState>();
@@ -51,12 +65,20 @@ export class ToolCallAccumulator {
complete(event: Extract<NativeLlmStreamEvent, { type: 'tool_call' }>) { complete(event: Extract<NativeLlmStreamEvent, { type: 'tool_call' }>) {
const state = this.#states.get(event.call_id); const state = this.#states.get(event.call_id);
this.#states.delete(event.call_id); this.#states.delete(event.call_id);
const parsed =
event.arguments_text !== undefined || event.arguments_error !== undefined
? {
args: event.arguments ?? {},
rawArgumentsText: event.arguments_text ?? state?.argumentsText,
argumentParseError: event.arguments_error,
}
: event.arguments
? this.parseArgs(event.arguments, state?.argumentsText)
: this.parseJson(state?.argumentsText ?? '{}');
return { return {
id: event.call_id, id: event.call_id,
name: event.name || state?.name || '', name: event.name || state?.name || '',
args: this.parseArgs( ...parsed,
event.arguments ?? this.parseJson(state?.argumentsText ?? '{}')
),
thought: event.thought, thought: event.thought,
} satisfies NativeToolCall; } satisfies NativeToolCall;
} }
@@ -70,51 +92,61 @@ export class ToolCallAccumulator {
pending.push({ pending.push({
id: callId, id: callId,
name: state.name, name: state.name,
args: this.parseArgs(this.parseJson(state.argumentsText)), ...this.parseJson(state.argumentsText),
}); });
} }
this.#states.clear(); this.#states.clear();
return pending; return pending;
} }
private parseJson(jsonText: string): unknown { private parseJson(jsonText: string): ParsedToolArguments {
if (!jsonText.trim()) { if (!jsonText.trim()) {
return {}; return { args: {} };
} }
try { try {
return JSON.parse(jsonText); return this.parseArgs(JSON.parse(jsonText), jsonText);
} catch { } catch (error) {
return {}; return {
args: {},
rawArgumentsText: jsonText,
argumentParseError:
error instanceof Error
? error.message
: 'Invalid tool arguments JSON',
};
} }
} }
private parseArgs(value: unknown): Record<string, unknown> { private parseArgs(
value: unknown,
rawArgumentsText?: string
): ParsedToolArguments {
if (value && typeof value === 'object' && !Array.isArray(value)) { if (value && typeof value === 'object' && !Array.isArray(value)) {
return value as Record<string, unknown>; return {
args: value as Record<string, unknown>,
rawArgumentsText,
};
} }
return {}; return {
args: {},
rawArgumentsText,
argumentParseError: 'Tool arguments must be a JSON object',
};
} }
} }
export class ToolSchemaExtractor { export class ToolSchemaExtractor {
static extract(toolSet: ToolSet): NativeLlmToolDefinition[] { static extract(toolSet: CopilotToolSet): NativeLlmToolDefinition[] {
return Object.entries(toolSet).map(([name, tool]) => { return Object.entries(toolSet).map(([name, tool]) => {
const unknownTool = tool as Record<string, unknown>;
const inputSchema =
unknownTool.inputSchema ?? unknownTool.parameters ?? z.object({});
return { return {
name, name,
description: description: tool.description,
typeof unknownTool.description === 'string' parameters: this.toJsonSchema(tool.inputSchema ?? z.object({})),
? unknownTool.description
: undefined,
parameters: this.toJsonSchema(inputSchema),
}; };
}); });
} }
private static toJsonSchema(schema: unknown): Record<string, unknown> { static toJsonSchema(schema: unknown): Record<string, unknown> {
if (!(schema instanceof z.ZodType)) { if (!(schema instanceof z.ZodType)) {
if (schema && typeof schema === 'object' && !Array.isArray(schema)) { if (schema && typeof schema === 'object' && !Array.isArray(schema)) {
return schema as Record<string, unknown>; return schema as Record<string, unknown>;
@@ -228,14 +260,45 @@ export class ToolSchemaExtractor {
export class ToolCallLoop { export class ToolCallLoop {
constructor( constructor(
private readonly dispatch: NativeDispatchFn, private readonly dispatch: NativeDispatchFn,
private readonly tools: ToolSet, private readonly tools: CopilotToolSet,
private readonly maxSteps = 20 private readonly maxSteps = 20
) {} ) {}
private normalizeToolExecuteOptions(
signalOrOptions?: AbortSignal | CopilotToolExecuteOptions,
maybeMessages?: CopilotToolExecuteOptions['messages']
): CopilotToolExecuteOptions {
if (
signalOrOptions &&
typeof signalOrOptions === 'object' &&
'aborted' in signalOrOptions
) {
return {
signal: signalOrOptions,
messages: maybeMessages,
};
}
if (!signalOrOptions) {
return maybeMessages ? { messages: maybeMessages } : {};
}
return {
...signalOrOptions,
signal: signalOrOptions.signal,
messages: signalOrOptions.messages ?? maybeMessages,
};
}
async *run( async *run(
request: NativeLlmRequest, request: NativeLlmRequest,
signal?: AbortSignal signalOrOptions?: AbortSignal | CopilotToolExecuteOptions,
maybeMessages?: CopilotToolExecuteOptions['messages']
): AsyncIterableIterator<NativeLlmStreamEvent> { ): AsyncIterableIterator<NativeLlmStreamEvent> {
const toolExecuteOptions = this.normalizeToolExecuteOptions(
signalOrOptions,
maybeMessages
);
const messages = request.messages.map(message => ({ const messages = request.messages.map(message => ({
...message, ...message,
content: [...message.content], content: [...message.content],
@@ -253,7 +316,7 @@ export class ToolCallLoop {
stream: true, stream: true,
messages, messages,
}, },
signal toolExecuteOptions.signal
)) { )) {
switch (event.type) { switch (event.type) {
case 'tool_call_delta': { case 'tool_call_delta': {
@@ -291,7 +354,10 @@ export class ToolCallLoop {
throw new Error('ToolCallLoop max steps reached'); throw new Error('ToolCallLoop max steps reached');
} }
const toolResults = await this.executeTools(toolCalls); const toolResults = await this.executeTools(
toolCalls,
toolExecuteOptions
);
messages.push({ messages.push({
role: 'assistant', role: 'assistant',
@@ -300,6 +366,8 @@ export class ToolCallLoop {
call_id: call.id, call_id: call.id,
name: call.name, name: call.name,
arguments: call.args, arguments: call.args,
arguments_text: call.rawArgumentsText,
arguments_error: call.argumentParseError,
thought: call.thought, thought: call.thought,
})), })),
}); });
@@ -311,6 +379,10 @@ export class ToolCallLoop {
{ {
type: 'tool_result', type: 'tool_result',
call_id: result.callId, call_id: result.callId,
name: result.name,
arguments: result.args,
arguments_text: result.rawArgumentsText,
arguments_error: result.argumentParseError,
output: result.output, output: result.output,
is_error: result.isError, is_error: result.isError,
}, },
@@ -321,6 +393,8 @@ export class ToolCallLoop {
call_id: result.callId, call_id: result.callId,
name: result.name, name: result.name,
arguments: result.args, arguments: result.args,
arguments_text: result.rawArgumentsText,
arguments_error: result.argumentParseError,
output: result.output, output: result.output,
is_error: result.isError, is_error: result.isError,
}; };
@@ -328,24 +402,28 @@ export class ToolCallLoop {
} }
} }
private async executeTools(calls: NativeToolCall[]) { private async executeTools(
return await Promise.all(calls.map(call => this.executeTool(call))); calls: NativeToolCall[],
options: CopilotToolExecuteOptions
) {
return await Promise.all(
calls.map(call => this.executeTool(call, options))
);
} }
private async executeTool( private async executeTool(
call: NativeToolCall call: NativeToolCall,
options: CopilotToolExecuteOptions
): Promise<ToolExecutionResult> { ): Promise<ToolExecutionResult> {
const tool = this.tools[call.name] as const tool = this.tools[call.name] as CopilotTool | undefined;
| {
execute?: (args: Record<string, unknown>) => Promise<unknown>;
}
| undefined;
if (!tool?.execute) { if (!tool?.execute) {
return { return {
callId: call.id, callId: call.id,
name: call.name, name: call.name,
args: call.args, args: call.args,
rawArgumentsText: call.rawArgumentsText,
argumentParseError: call.argumentParseError,
isError: true, isError: true,
output: { output: {
message: `Tool not found: ${call.name}`, message: `Tool not found: ${call.name}`,
@@ -353,12 +431,30 @@ export class ToolCallLoop {
}; };
} }
try { if (call.argumentParseError) {
const output = await tool.execute(call.args);
return { return {
callId: call.id, callId: call.id,
name: call.name, name: call.name,
args: call.args, args: call.args,
rawArgumentsText: call.rawArgumentsText,
argumentParseError: call.argumentParseError,
isError: true,
output: {
message: 'Invalid tool arguments JSON',
rawArguments: call.rawArgumentsText,
error: call.argumentParseError,
},
};
}
try {
const output = await tool.execute(call.args, options);
return {
callId: call.id,
name: call.name,
args: call.args,
rawArgumentsText: call.rawArgumentsText,
argumentParseError: call.argumentParseError,
output: output ?? null, output: output ?? null,
}; };
} catch (error) { } catch (error) {
@@ -371,6 +467,8 @@ export class ToolCallLoop {
callId: call.id, callId: call.id,
name: call.name, name: call.name,
args: call.args, args: call.args,
rawArgumentsText: call.rawArgumentsText,
argumentParseError: call.argumentParseError,
isError: true, isError: true,
output: { output: {
message: 'Tool execution failed', message: 'Tool execution failed',

View File

@@ -1,5 +1,3 @@
import type { ToolSet } from 'ai';
import { import {
CopilotProviderSideError, CopilotProviderSideError,
metrics, metrics,
@@ -11,6 +9,7 @@ import {
type NativeLlmRequest, type NativeLlmRequest,
} from '../../../native'; } from '../../../native';
import type { NodeTextMiddleware } from '../config'; import type { NodeTextMiddleware } from '../config';
import type { CopilotToolSet } from '../tools';
import { buildNativeRequest, NativeProviderAdapter } from './native'; import { buildNativeRequest, NativeProviderAdapter } from './native';
import { CopilotProvider } from './provider'; import { CopilotProvider } from './provider';
import type { import type {
@@ -86,7 +85,7 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
} }
private createNativeAdapter( private createNativeAdapter(
tools: ToolSet, tools: CopilotToolSet,
nodeTextMiddleware?: NodeTextMiddleware[] nodeTextMiddleware?: NodeTextMiddleware[]
) { ) {
return new NativeProviderAdapter( return new NativeProviderAdapter(
@@ -108,12 +107,14 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { const fullCond = { ...cond, outputType: ModelOutputType.Text };
...cond, const model = this.selectModel(
outputType: ModelOutputType.Text, await this.checkParams({
}; messages,
await this.checkParams({ messages, cond: fullCond, options }); cond: fullCond,
const model = this.selectModel(fullCond); options,
})
);
try { try {
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
@@ -127,7 +128,7 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
middleware, middleware,
}); });
const adapter = this.createNativeAdapter(tools, middleware.node?.text); const adapter = this.createNativeAdapter(tools, middleware.node?.text);
return await adapter.text(request, options.signal); return await adapter.text(request, options.signal, messages);
} catch (e: any) { } catch (e: any) {
metrics.ai metrics.ai
.counter('chat_text_errors') .counter('chat_text_errors')
@@ -141,12 +142,14 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
messages: PromptMessage[], messages: PromptMessage[],
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): AsyncIterable<string> { ): AsyncIterable<string> {
const fullCond = { const fullCond = { ...cond, outputType: ModelOutputType.Text };
...cond, const model = this.selectModel(
outputType: ModelOutputType.Text, await this.checkParams({
}; messages,
await this.checkParams({ messages, cond: fullCond, options }); cond: fullCond,
const model = this.selectModel(fullCond); options,
})
);
try { try {
metrics.ai metrics.ai
@@ -162,7 +165,11 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
middleware, middleware,
}); });
const adapter = this.createNativeAdapter(tools, middleware.node?.text); const adapter = this.createNativeAdapter(tools, middleware.node?.text);
for await (const chunk of adapter.streamText(request, options.signal)) { for await (const chunk of adapter.streamText(
request,
options.signal,
messages
)) {
yield chunk; yield chunk;
} }
} catch (e: any) { } catch (e: any) {

View File

@@ -1,31 +1,41 @@
import type { ToolSet } from 'ai';
import { ZodType } from 'zod'; import { ZodType } from 'zod';
import { CopilotPromptInvalid } from '../../../base';
import type { import type {
NativeLlmCoreContent, NativeLlmCoreContent,
NativeLlmCoreMessage, NativeLlmCoreMessage,
NativeLlmEmbeddingRequest,
NativeLlmRequest, NativeLlmRequest,
NativeLlmStreamEvent, NativeLlmStreamEvent,
NativeLlmStructuredRequest,
NativeLlmStructuredResponse,
} from '../../../native'; } from '../../../native';
import type { NodeTextMiddleware, ProviderMiddlewareConfig } from '../config'; import type { NodeTextMiddleware, ProviderMiddlewareConfig } from '../config';
import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop'; import type { CopilotToolSet } from '../tools';
import type { CopilotChatOptions, PromptMessage, StreamObject } from './types';
import { import {
CitationFootnoteFormatter, canonicalizePromptAttachment,
inferMimeType, type CanonicalPromptAttachment,
TextStreamParser, } from './attachments';
} from './utils'; import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop';
import type {
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/; CopilotChatOptions,
CopilotStructuredOptions,
ModelAttachmentCapability,
PromptMessage,
StreamObject,
} from './types';
import { CitationFootnoteFormatter, TextStreamParser } from './utils';
type BuildNativeRequestOptions = { type BuildNativeRequestOptions = {
model: string; model: string;
messages: PromptMessage[]; messages: PromptMessage[];
options?: CopilotChatOptions; options?: CopilotChatOptions | CopilotStructuredOptions;
tools?: ToolSet; tools?: CopilotToolSet;
withAttachment?: boolean; withAttachment?: boolean;
attachmentCapability?: ModelAttachmentCapability;
include?: string[]; include?: string[];
reasoning?: Record<string, unknown>; reasoning?: Record<string, unknown>;
responseSchema?: unknown;
middleware?: ProviderMiddlewareConfig; middleware?: ProviderMiddlewareConfig;
}; };
@@ -34,6 +44,11 @@ type BuildNativeRequestResult = {
schema?: ZodType; schema?: ZodType;
}; };
type BuildNativeStructuredRequestResult = {
request: NativeLlmStructuredRequest;
schema: ZodType;
};
type ToolCallMeta = { type ToolCallMeta = {
name: string; name: string;
args: Record<string, unknown>; args: Record<string, unknown>;
@@ -68,9 +83,121 @@ function roleToCore(role: PromptMessage['role']) {
} }
} }
function ensureAttachmentSupported(
attachment: CanonicalPromptAttachment,
attachmentCapability?: ModelAttachmentCapability
) {
if (!attachmentCapability) return;
if (!attachmentCapability.kinds.includes(attachment.kind)) {
throw new CopilotPromptInvalid(
`Native path does not support ${attachment.kind} attachments${
attachment.mediaType ? ` (${attachment.mediaType})` : ''
}`
);
}
if (
attachmentCapability.sourceKinds?.length &&
!attachmentCapability.sourceKinds.includes(attachment.sourceKind)
) {
throw new CopilotPromptInvalid(
`Native path does not support ${attachment.sourceKind} attachment sources`
);
}
if (attachment.isRemote && attachmentCapability.allowRemoteUrls === false) {
throw new CopilotPromptInvalid(
'Native path does not support remote attachment urls'
);
}
}
function resolveResponseSchema(
systemMessage: PromptMessage | undefined,
responseSchema?: unknown
): ZodType | undefined {
if (responseSchema instanceof ZodType) {
return responseSchema;
}
if (systemMessage?.responseFormat?.schema instanceof ZodType) {
return systemMessage.responseFormat.schema;
}
return systemMessage?.params?.schema instanceof ZodType
? systemMessage.params.schema
: undefined;
}
function resolveResponseStrict(
systemMessage: PromptMessage | undefined,
options?: CopilotStructuredOptions
) {
return options?.strict ?? systemMessage?.responseFormat?.strict ?? true;
}
export class StructuredResponseParseError extends Error {}
function normalizeStructuredText(text: string) {
const trimmed = text.replaceAll(/^ny\n/g, ' ').trim();
if (trimmed.startsWith('```') || trimmed.endsWith('```')) {
return trimmed
.replace(/```[\w\s-]*\n/g, '')
.replace(/\n```/g, '')
.trim();
}
return trimmed;
}
export function parseNativeStructuredOutput(
response: Pick<NativeLlmStructuredResponse, 'output_text'> & {
output_json?: unknown;
}
) {
if (response.output_json !== undefined) {
return response.output_json;
}
const normalized = normalizeStructuredText(response.output_text);
const candidates = [
() => normalized,
() => {
const objectStart = normalized.indexOf('{');
const objectEnd = normalized.lastIndexOf('}');
return objectStart !== -1 && objectEnd > objectStart
? normalized.slice(objectStart, objectEnd + 1)
: null;
},
() => {
const arrayStart = normalized.indexOf('[');
const arrayEnd = normalized.lastIndexOf(']');
return arrayStart !== -1 && arrayEnd > arrayStart
? normalized.slice(arrayStart, arrayEnd + 1)
: null;
},
];
for (const candidate of candidates) {
try {
const candidateText = candidate();
if (typeof candidateText === 'string') {
return JSON.parse(candidateText);
}
} catch {
continue;
}
}
throw new StructuredResponseParseError(
`Unexpected structured response: ${normalized.slice(0, 200)}`
);
}
async function toCoreContents( async function toCoreContents(
message: PromptMessage, message: PromptMessage,
withAttachment: boolean withAttachment: boolean,
attachmentCapability?: ModelAttachmentCapability
): Promise<NativeLlmCoreContent[]> { ): Promise<NativeLlmCoreContent[]> {
const contents: NativeLlmCoreContent[] = []; const contents: NativeLlmCoreContent[] = [];
@@ -81,24 +208,12 @@ async function toCoreContents(
if (!withAttachment || !Array.isArray(message.attachments)) return contents; if (!withAttachment || !Array.isArray(message.attachments)) return contents;
for (const entry of message.attachments) { for (const entry of message.attachments) {
let attachmentUrl: string; const normalized = await canonicalizePromptAttachment(entry, message);
let mediaType: string; ensureAttachmentSupported(normalized, attachmentCapability);
contents.push({
if (typeof entry === 'string') { type: normalized.kind,
attachmentUrl = entry; source: normalized.source,
mediaType = });
typeof message.params?.mimetype === 'string'
? message.params.mimetype
: await inferMimeType(entry);
} else {
attachmentUrl = entry.attachment;
mediaType = entry.mimeType;
}
if (!SIMPLE_IMAGE_URL_REGEX.test(attachmentUrl)) continue;
if (!mediaType.startsWith('image/')) continue;
contents.push({ type: 'image', source: { url: attachmentUrl } });
} }
return contents; return contents;
@@ -110,8 +225,10 @@ export async function buildNativeRequest({
options = {}, options = {},
tools = {}, tools = {},
withAttachment = true, withAttachment = true,
attachmentCapability,
include, include,
reasoning, reasoning,
responseSchema,
middleware, middleware,
}: BuildNativeRequestOptions): Promise<BuildNativeRequestResult> { }: BuildNativeRequestOptions): Promise<BuildNativeRequestResult> {
const copiedMessages = messages.map(message => ({ const copiedMessages = messages.map(message => ({
@@ -123,10 +240,7 @@ export async function buildNativeRequest({
const systemMessage = const systemMessage =
copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined; copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined;
const schema = const schema = resolveResponseSchema(systemMessage, responseSchema);
systemMessage?.params?.schema instanceof ZodType
? systemMessage.params.schema
: undefined;
const coreMessages: NativeLlmCoreMessage[] = []; const coreMessages: NativeLlmCoreMessage[] = [];
if (systemMessage?.content?.length) { if (systemMessage?.content?.length) {
@@ -138,7 +252,11 @@ export async function buildNativeRequest({
for (const message of copiedMessages) { for (const message of copiedMessages) {
if (message.role === 'system') continue; if (message.role === 'system') continue;
const content = await toCoreContents(message, withAttachment); const content = await toCoreContents(
message,
withAttachment,
attachmentCapability
);
coreMessages.push({ role: roleToCore(message.role), content }); coreMessages.push({ role: roleToCore(message.role), content });
} }
@@ -153,6 +271,9 @@ export async function buildNativeRequest({
tool_choice: Object.keys(tools).length ? 'auto' : undefined, tool_choice: Object.keys(tools).length ? 'auto' : undefined,
include, include,
reasoning, reasoning,
response_schema: schema
? ToolSchemaExtractor.toJsonSchema(schema)
: undefined,
middleware: middleware?.rust middleware: middleware?.rust
? { request: middleware.rust.request, stream: middleware.rust.stream } ? { request: middleware.rust.request, stream: middleware.rust.stream }
: undefined, : undefined,
@@ -161,6 +282,90 @@ export async function buildNativeRequest({
}; };
} }
export async function buildNativeStructuredRequest({
model,
messages,
options = {},
withAttachment = true,
attachmentCapability,
reasoning,
responseSchema,
middleware,
}: Omit<
BuildNativeRequestOptions,
'tools' | 'include'
>): Promise<BuildNativeStructuredRequestResult> {
const copiedMessages = messages.map(message => ({
...message,
attachments: message.attachments
? [...message.attachments]
: message.attachments,
}));
const systemMessage =
copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined;
const schema = resolveResponseSchema(systemMessage, responseSchema);
const strict = resolveResponseStrict(systemMessage, options);
if (!schema) {
throw new CopilotPromptInvalid('Schema is required');
}
const coreMessages: NativeLlmCoreMessage[] = [];
if (systemMessage?.content?.length) {
coreMessages.push({
role: 'system',
content: [{ type: 'text', text: systemMessage.content }],
});
}
for (const message of copiedMessages) {
if (message.role === 'system') continue;
const content = await toCoreContents(
message,
withAttachment,
attachmentCapability
);
coreMessages.push({ role: roleToCore(message.role), content });
}
return {
request: {
model,
messages: coreMessages,
schema: ToolSchemaExtractor.toJsonSchema(schema),
max_tokens: options.maxTokens ?? undefined,
temperature: options.temperature ?? undefined,
reasoning,
strict,
response_mime_type: 'application/json',
middleware: middleware?.rust
? { request: middleware.rust.request }
: undefined,
},
schema,
};
}
export function buildNativeEmbeddingRequest({
model,
inputs,
dimensions,
taskType = 'RETRIEVAL_DOCUMENT',
}: {
model: string;
inputs: string[];
dimensions?: number;
taskType?: string;
}): NativeLlmEmbeddingRequest {
return {
model,
inputs,
dimensions,
task_type: taskType,
};
}
function ensureToolResultMeta( function ensureToolResultMeta(
event: Extract<NativeLlmStreamEvent, { type: 'tool_result' }>, event: Extract<NativeLlmStreamEvent, { type: 'tool_result' }>,
toolCalls: Map<string, ToolCallMeta> toolCalls: Map<string, ToolCallMeta>
@@ -244,7 +449,7 @@ export class NativeProviderAdapter {
constructor( constructor(
dispatch: NativeDispatchFn, dispatch: NativeDispatchFn,
tools: ToolSet, tools: CopilotToolSet,
maxSteps = 20, maxSteps = 20,
options: NativeProviderAdapterOptions = {} options: NativeProviderAdapterOptions = {}
) { ) {
@@ -259,9 +464,13 @@ export class NativeProviderAdapter {
enabledNodeTextMiddlewares.has('citation_footnote'); enabledNodeTextMiddlewares.has('citation_footnote');
} }
async text(request: NativeLlmRequest, signal?: AbortSignal) { async text(
request: NativeLlmRequest,
signal?: AbortSignal,
messages?: PromptMessage[]
) {
let output = ''; let output = '';
for await (const chunk of this.streamText(request, signal)) { for await (const chunk of this.streamText(request, signal, messages)) {
output += chunk; output += chunk;
} }
return output.trim(); return output.trim();
@@ -269,7 +478,8 @@ export class NativeProviderAdapter {
async *streamText( async *streamText(
request: NativeLlmRequest, request: NativeLlmRequest,
signal?: AbortSignal signal?: AbortSignal,
messages?: PromptMessage[]
): AsyncIterableIterator<string> { ): AsyncIterableIterator<string> {
const textParser = this.#enableCallout ? new TextStreamParser() : null; const textParser = this.#enableCallout ? new TextStreamParser() : null;
const citationFormatter = this.#enableCitationFootnote const citationFormatter = this.#enableCitationFootnote
@@ -278,7 +488,7 @@ export class NativeProviderAdapter {
const toolCalls = new Map<string, ToolCallMeta>(); const toolCalls = new Map<string, ToolCallMeta>();
let streamPartId = 0; let streamPartId = 0;
for await (const event of this.#loop.run(request, signal)) { for await (const event of this.#loop.run(request, signal, messages)) {
switch (event.type) { switch (event.type) {
case 'text_delta': { case 'text_delta': {
if (textParser) { if (textParser) {
@@ -364,7 +574,8 @@ export class NativeProviderAdapter {
async *streamObject( async *streamObject(
request: NativeLlmRequest, request: NativeLlmRequest,
signal?: AbortSignal signal?: AbortSignal,
messages?: PromptMessage[]
): AsyncIterableIterator<StreamObject> { ): AsyncIterableIterator<StreamObject> {
const toolCalls = new Map<string, ToolCallMeta>(); const toolCalls = new Map<string, ToolCallMeta>();
const citationFormatter = this.#enableCitationFootnote const citationFormatter = this.#enableCitationFootnote
@@ -373,7 +584,7 @@ export class NativeProviderAdapter {
const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>(); const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>();
let hasFootnoteReference = false; let hasFootnoteReference = false;
for await (const event of this.#loop.run(request, signal)) { for await (const event of this.#loop.run(request, signal, messages)) {
switch (event.type) { switch (event.type) {
case 'text_delta': { case 'text_delta': {
if (event.text.includes('[^')) { if (event.text.includes('[^')) {

View File

@@ -1,4 +1,3 @@
import type { Tool, ToolSet } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { import {
@@ -12,30 +11,41 @@ import {
} from '../../../base'; } from '../../../base';
import { import {
llmDispatchStream, llmDispatchStream,
llmEmbeddingDispatch,
llmRerankDispatch,
llmStructuredDispatch,
type NativeLlmBackendConfig, type NativeLlmBackendConfig,
type NativeLlmEmbeddingRequest,
type NativeLlmRequest, type NativeLlmRequest,
type NativeLlmRerankRequest,
type NativeLlmRerankResponse,
type NativeLlmStructuredRequest,
} from '../../../native'; } from '../../../native';
import type { NodeTextMiddleware } from '../config'; import type { NodeTextMiddleware } from '../config';
import { buildNativeRequest, NativeProviderAdapter } from './native'; import type { CopilotTool, CopilotToolSet } from '../tools';
import { CopilotProvider } from './provider'; import { IMAGE_ATTACHMENT_CAPABILITY } from './attachments';
import { import {
normalizeRerankModel, buildNativeEmbeddingRequest,
OPENAI_RERANK_MAX_COMPLETION_TOKENS, buildNativeRequest,
OPENAI_RERANK_TOP_LOGPROBS_LIMIT, buildNativeStructuredRequest,
usesRerankReasoning, NativeProviderAdapter,
} from './rerank'; parseNativeStructuredOutput,
} from './native';
import { CopilotProvider } from './provider';
import type { import type {
CopilotChatOptions, CopilotChatOptions,
CopilotChatTools, CopilotChatTools,
CopilotEmbeddingOptions, CopilotEmbeddingOptions,
CopilotImageOptions, CopilotImageOptions,
CopilotRerankRequest,
CopilotStructuredOptions, CopilotStructuredOptions,
ModelCapability,
ModelConditions, ModelConditions,
PromptMessage, PromptMessage,
StreamObject, StreamObject,
} from './types'; } from './types';
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
import { chatToGPTMessage } from './utils'; import { promptAttachmentToUrl } from './utils';
export const DEFAULT_DIMENSIONS = 256; export const DEFAULT_DIMENSIONS = 256;
@@ -91,19 +101,6 @@ const ImageResponseSchema = z.union([
}), }),
}), }),
]); ]);
const LogProbsSchema = z.array(
z.object({
token: z.string(),
logprob: z.number(),
top_logprobs: z.array(
z.object({
token: z.string(),
logprob: z.number(),
})
),
})
);
const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro']; const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro'];
function normalizeImageFormatToMime(format?: string) { function normalizeImageFormatToMime(format?: string) {
@@ -136,6 +133,34 @@ function normalizeImageResponseData(
.filter((value): value is string => typeof value === 'string'); .filter((value): value is string => typeof value === 'string');
} }
function buildOpenAIRerankRequest(
model: string,
request: CopilotRerankRequest
): NativeLlmRerankRequest {
return {
model,
query: request.query,
candidates: request.candidates.map(candidate => ({
...(candidate.id ? { id: candidate.id } : {}),
text: candidate.text,
})),
...(request.topK ? { top_n: request.topK } : {}),
};
}
function createOpenAIMultimodalCapability(
output: ModelCapability['output'],
options: Pick<ModelCapability, 'defaultForOutputType'> = {}
): ModelCapability {
return {
input: [ModelInputType.Text, ModelInputType.Image],
output,
attachments: IMAGE_ATTACHMENT_CAPABILITY,
structuredAttachments: IMAGE_ATTACHMENT_CAPABILITY,
...options,
};
}
export class OpenAIProvider extends CopilotProvider<OpenAIConfig> { export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
readonly type = CopilotProviderType.OpenAI; readonly type = CopilotProviderType.OpenAI;
@@ -145,10 +170,10 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
name: 'GPT 4o', name: 'GPT 4o',
id: 'gpt-4o', id: 'gpt-4o',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image], ModelOutputType.Text,
output: [ModelOutputType.Text, ModelOutputType.Object], ModelOutputType.Object,
}, ]),
], ],
}, },
// FIXME(@darkskygit): deprecated // FIXME(@darkskygit): deprecated
@@ -156,20 +181,20 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
name: 'GPT 4o 2024-08-06', name: 'GPT 4o 2024-08-06',
id: 'gpt-4o-2024-08-06', id: 'gpt-4o-2024-08-06',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image], ModelOutputType.Text,
output: [ModelOutputType.Text, ModelOutputType.Object], ModelOutputType.Object,
}, ]),
], ],
}, },
{ {
name: 'GPT 4o Mini', name: 'GPT 4o Mini',
id: 'gpt-4o-mini', id: 'gpt-4o-mini',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image], ModelOutputType.Text,
output: [ModelOutputType.Text, ModelOutputType.Object], ModelOutputType.Object,
}, ]),
], ],
}, },
// FIXME(@darkskygit): deprecated // FIXME(@darkskygit): deprecated
@@ -177,181 +202,158 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
name: 'GPT 4o Mini 2024-07-18', name: 'GPT 4o Mini 2024-07-18',
id: 'gpt-4o-mini-2024-07-18', id: 'gpt-4o-mini-2024-07-18',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image], ModelOutputType.Text,
output: [ModelOutputType.Text, ModelOutputType.Object], ModelOutputType.Object,
}, ]),
], ],
}, },
{ {
name: 'GPT 4.1', name: 'GPT 4.1',
id: 'gpt-4.1', id: 'gpt-4.1',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability(
input: [ModelInputType.Text, ModelInputType.Image], [
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Rerank,
ModelOutputType.Structured, ModelOutputType.Structured,
], ],
defaultForOutputType: true, { defaultForOutputType: true }
}, ),
], ],
}, },
{ {
name: 'GPT 4.1 2025-04-14', name: 'GPT 4.1 2025-04-14',
id: 'gpt-4.1-2025-04-14', id: 'gpt-4.1-2025-04-14',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Rerank,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT 4.1 Mini', name: 'GPT 4.1 Mini',
id: 'gpt-4.1-mini', id: 'gpt-4.1-mini',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Rerank,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT 4.1 Nano', name: 'GPT 4.1 Nano',
id: 'gpt-4.1-nano', id: 'gpt-4.1-nano',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Rerank,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT 5', name: 'GPT 5',
id: 'gpt-5', id: 'gpt-5',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT 5 2025-08-07', name: 'GPT 5 2025-08-07',
id: 'gpt-5-2025-08-07', id: 'gpt-5-2025-08-07',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT 5 Mini', name: 'GPT 5 Mini',
id: 'gpt-5-mini', id: 'gpt-5-mini',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT 5.2', name: 'GPT 5.2',
id: 'gpt-5.2', id: 'gpt-5.2',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Rerank,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT 5.2 2025-12-11', name: 'GPT 5.2 2025-12-11',
id: 'gpt-5.2-2025-12-11', id: 'gpt-5.2-2025-12-11',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT 5 Nano', name: 'GPT 5 Nano',
id: 'gpt-5-nano', id: 'gpt-5-nano',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image],
output: [
ModelOutputType.Text, ModelOutputType.Text,
ModelOutputType.Object, ModelOutputType.Object,
ModelOutputType.Structured, ModelOutputType.Structured,
], ]),
},
], ],
}, },
{ {
name: 'GPT O1', name: 'GPT O1',
id: 'o1', id: 'o1',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image], ModelOutputType.Text,
output: [ModelOutputType.Text, ModelOutputType.Object], ModelOutputType.Object,
}, ]),
], ],
}, },
{ {
name: 'GPT O3', name: 'GPT O3',
id: 'o3', id: 'o3',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image], ModelOutputType.Text,
output: [ModelOutputType.Text, ModelOutputType.Object], ModelOutputType.Object,
}, ]),
], ],
}, },
{ {
name: 'GPT O4 Mini', name: 'GPT O4 Mini',
id: 'o4-mini', id: 'o4-mini',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([
input: [ModelInputType.Text, ModelInputType.Image], ModelOutputType.Text,
output: [ModelOutputType.Text, ModelOutputType.Object], ModelOutputType.Object,
}, ]),
], ],
}, },
// Embedding models // Embedding models
@@ -387,11 +389,9 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
{ {
id: 'gpt-image-1', id: 'gpt-image-1',
capabilities: [ capabilities: [
{ createOpenAIMultimodalCapability([ModelOutputType.Image], {
input: [ModelInputType.Text, ModelInputType.Image],
output: [ModelOutputType.Image],
defaultForOutputType: true, defaultForOutputType: true,
}, }),
], ],
}, },
]; ];
@@ -437,7 +437,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
override getProviderSpecificTools( override getProviderSpecificTools(
toolName: CopilotChatTools, toolName: CopilotChatTools,
_model: string _model: string
): [string, Tool?] | undefined { ): [string, CopilotTool?] | undefined {
if (toolName === 'docEdit') { if (toolName === 'docEdit') {
return ['doc_edit', undefined]; return ['doc_edit', undefined];
} }
@@ -452,14 +452,18 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
}; };
} }
private getNativeProtocol() {
return this.config.oldApiStyle ? 'openai_chat' : 'openai_responses';
}
private createNativeAdapter( private createNativeAdapter(
tools: ToolSet, tools: CopilotToolSet,
nodeTextMiddleware?: NodeTextMiddleware[] nodeTextMiddleware?: NodeTextMiddleware[]
) { ) {
return new NativeProviderAdapter( return new NativeProviderAdapter(
(request: NativeLlmRequest, signal?: AbortSignal) => (request: NativeLlmRequest, signal?: AbortSignal) =>
llmDispatchStream( llmDispatchStream(
this.config.oldApiStyle ? 'openai_chat' : 'openai_responses', this.getNativeProtocol(),
this.createNativeConfig(), this.createNativeConfig(),
request, request,
signal signal
@@ -470,6 +474,27 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
); );
} }
protected createNativeStructuredDispatch(
backendConfig: NativeLlmBackendConfig
) {
return (request: NativeLlmStructuredRequest) =>
llmStructuredDispatch(this.getNativeProtocol(), backendConfig, request);
}
protected createNativeEmbeddingDispatch(
backendConfig: NativeLlmBackendConfig
) {
return (request: NativeLlmEmbeddingRequest) =>
llmEmbeddingDispatch(this.getNativeProtocol(), backendConfig, request);
}
protected createNativeRerankDispatch(backendConfig: NativeLlmBackendConfig) {
return (
request: NativeLlmRerankRequest
): Promise<NativeLlmRerankResponse> =>
llmRerankDispatch('openai_chat', backendConfig, request);
}
private getReasoning( private getReasoning(
options: NonNullable<CopilotChatOptions>, options: NonNullable<CopilotChatOptions>,
model: string model: string
@@ -486,13 +511,18 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ messages, cond: fullCond, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); messages,
cond: fullCond,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
const tools = await this.getTools(options, model.id); const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware(); const middleware = this.getActiveProviderMiddleware();
const cap = this.getAttachCapability(model, ModelOutputType.Text);
const normalizedOptions = normalizeOpenAIOptionsForModel( const normalizedOptions = normalizeOpenAIOptionsForModel(
options, options,
model.id model.id
@@ -502,12 +532,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages, messages,
options: normalizedOptions, options: normalizedOptions,
tools, tools,
attachmentCapability: cap,
include: options.webSearch ? ['citations'] : undefined, include: options.webSearch ? ['citations'] : undefined,
reasoning: this.getReasoning(options, model.id), reasoning: this.getReasoning(options, model.id),
middleware, middleware,
}); });
const adapter = this.createNativeAdapter(tools, middleware.node?.text); const adapter = this.createNativeAdapter(tools, middleware.node?.text);
return await adapter.text(request, options.signal); return await adapter.text(request, options.signal, messages);
} catch (e: any) { } catch (e: any) {
metrics.ai metrics.ai
.counter('chat_text_errors') .counter('chat_text_errors')
@@ -525,8 +556,12 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
...cond, ...cond,
outputType: ModelOutputType.Text, outputType: ModelOutputType.Text,
}; };
await this.checkParams({ messages, cond: fullCond, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); messages,
cond: fullCond,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai metrics.ai
@@ -534,6 +569,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
.add(1, this.metricLabels(model.id)); .add(1, this.metricLabels(model.id));
const tools = await this.getTools(options, model.id); const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware(); const middleware = this.getActiveProviderMiddleware();
const cap = this.getAttachCapability(model, ModelOutputType.Text);
const normalizedOptions = normalizeOpenAIOptionsForModel( const normalizedOptions = normalizeOpenAIOptionsForModel(
options, options,
model.id model.id
@@ -543,12 +579,17 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages, messages,
options: normalizedOptions, options: normalizedOptions,
tools, tools,
attachmentCapability: cap,
include: options.webSearch ? ['citations'] : undefined, include: options.webSearch ? ['citations'] : undefined,
reasoning: this.getReasoning(options, model.id), reasoning: this.getReasoning(options, model.id),
middleware, middleware,
}); });
const adapter = this.createNativeAdapter(tools, middleware.node?.text); const adapter = this.createNativeAdapter(tools, middleware.node?.text);
for await (const chunk of adapter.streamText(request, options.signal)) { for await (const chunk of adapter.streamText(
request,
options.signal,
messages
)) {
yield chunk; yield chunk;
} }
} catch (e: any) { } catch (e: any) {
@@ -565,8 +606,12 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): AsyncIterable<StreamObject> { ): AsyncIterable<StreamObject> {
const fullCond = { ...cond, outputType: ModelOutputType.Object }; const fullCond = { ...cond, outputType: ModelOutputType.Object };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai metrics.ai
@@ -574,6 +619,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
.add(1, this.metricLabels(model.id)); .add(1, this.metricLabels(model.id));
const tools = await this.getTools(options, model.id); const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware(); const middleware = this.getActiveProviderMiddleware();
const cap = this.getAttachCapability(model, ModelOutputType.Object);
const normalizedOptions = normalizeOpenAIOptionsForModel( const normalizedOptions = normalizeOpenAIOptionsForModel(
options, options,
model.id model.id
@@ -583,12 +629,17 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages, messages,
options: normalizedOptions, options: normalizedOptions,
tools, tools,
attachmentCapability: cap,
include: options.webSearch ? ['citations'] : undefined, include: options.webSearch ? ['citations'] : undefined,
reasoning: this.getReasoning(options, model.id), reasoning: this.getReasoning(options, model.id),
middleware, middleware,
}); });
const adapter = this.createNativeAdapter(tools, middleware.node?.text); const adapter = this.createNativeAdapter(tools, middleware.node?.text);
for await (const chunk of adapter.streamObject(request, options.signal)) { for await (const chunk of adapter.streamObject(
request,
options.signal,
messages
)) {
yield chunk; yield chunk;
} }
} catch (e: any) { } catch (e: any) {
@@ -605,31 +656,34 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
options: CopilotStructuredOptions = {} options: CopilotStructuredOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Structured }; const fullCond = { ...cond, outputType: ModelOutputType.Structured };
await this.checkParams({ messages, cond: fullCond, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); messages,
cond: fullCond,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
const tools = await this.getTools(options, model.id); const backendConfig = this.createNativeConfig();
const middleware = this.getActiveProviderMiddleware(); const middleware = this.getActiveProviderMiddleware();
const cap = this.getAttachCapability(model, ModelOutputType.Structured);
const normalizedOptions = normalizeOpenAIOptionsForModel( const normalizedOptions = normalizeOpenAIOptionsForModel(
options, options,
model.id model.id
); );
const { request, schema } = await buildNativeRequest({ const { request, schema } = await buildNativeStructuredRequest({
model: model.id, model: model.id,
messages, messages,
options: normalizedOptions, options: normalizedOptions,
tools, attachmentCapability: cap,
reasoning: this.getReasoning(options, model.id), reasoning: this.getReasoning(options, model.id),
responseSchema: options.schema,
middleware, middleware,
}); });
if (!schema) { const response =
throw new CopilotPromptInvalid('Schema is required'); await this.createNativeStructuredDispatch(backendConfig)(request);
} const parsed = parseNativeStructuredOutput(response);
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
const text = await adapter.text(request, options.signal);
const parsed = JSON.parse(text);
const validated = schema.parse(parsed); const validated = schema.parse(parsed);
return JSON.stringify(validated); return JSON.stringify(validated);
} catch (e: any) { } catch (e: any) {
@@ -640,71 +694,26 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
override async rerank( override async rerank(
cond: ModelConditions, cond: ModelConditions,
chunkMessages: PromptMessage[][], request: CopilotRerankRequest,
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<number[]> { ): Promise<number[]> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = { ...cond, outputType: ModelOutputType.Rerank };
await this.checkParams({ messages: [], cond: fullCond, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); messages: [],
cond: fullCond,
options,
});
const model = this.selectModel(normalizedCond);
const scores = await Promise.all( try {
chunkMessages.map(async messages => { const backendConfig = this.createNativeConfig();
const [system, msgs] = await chatToGPTMessage(messages); const nativeRequest = buildOpenAIRerankRequest(model.id, request);
const rerankModel = normalizeRerankModel(model.id); const response =
const response = await this.requestOpenAIJson( await this.createNativeRerankDispatch(backendConfig)(nativeRequest);
'/chat/completions', return response.scores;
{ } catch (e: any) {
model: rerankModel, throw this.handleError(e);
messages: this.toOpenAIChatMessages(system, msgs),
temperature: 0,
logprobs: true,
top_logprobs: OPENAI_RERANK_TOP_LOGPROBS_LIMIT,
...(usesRerankReasoning(rerankModel)
? {
reasoning_effort: 'none' as const,
max_completion_tokens: OPENAI_RERANK_MAX_COMPLETION_TOKENS,
} }
: { max_tokens: OPENAI_RERANK_MAX_COMPLETION_TOKENS }),
},
options.signal
);
const logprobs = response?.choices?.[0]?.logprobs?.content;
if (!Array.isArray(logprobs) || logprobs.length === 0) {
return 0;
}
const parsedLogprobs = LogProbsSchema.parse(logprobs);
const topMap = parsedLogprobs[0].top_logprobs.reduce(
(acc, { token, logprob }) => ({ ...acc, [token]: logprob }),
{} as Record<string, number>
);
const findLogProb = (token: string): number => {
// OpenAI often includes a leading space, so try matching '.yes', '_yes', ' yes' and 'yes'
return [...'_:. "-\t,(=_“'.split('').map(c => c + token), token]
.flatMap(v => [v, v.toLowerCase(), v.toUpperCase()])
.reduce<number>(
(best, key) =>
(topMap[key] ?? Number.NEGATIVE_INFINITY) > best
? topMap[key]
: best,
Number.NEGATIVE_INFINITY
);
};
const logYes = findLogProb('Yes');
const logNo = findLogProb('No');
const pYes = Math.exp(logYes);
const pNo = Math.exp(logNo);
const prob = pYes + pNo === 0 ? 0 : pYes / (pYes + pNo);
return prob;
})
);
return scores;
} }
// ====== text to image ====== // ====== text to image ======
@@ -906,7 +915,8 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
form.set('output_format', outputFormat); form.set('output_format', outputFormat);
for (const [idx, entry] of attachments.entries()) { for (const [idx, entry] of attachments.entries()) {
const url = typeof entry === 'string' ? entry : entry.attachment; const url = promptAttachmentToUrl(entry);
if (!url) continue;
try { try {
const attachment = await this.fetchImage(url, maxBytes, signal); const attachment = await this.fetchImage(url, maxBytes, signal);
if (!attachment) continue; if (!attachment) continue;
@@ -964,8 +974,12 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
options: CopilotImageOptions = {} options: CopilotImageOptions = {}
) { ) {
const fullCond = { ...cond, outputType: ModelOutputType.Image }; const fullCond = { ...cond, outputType: ModelOutputType.Image };
await this.checkParams({ messages, cond: fullCond, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); messages,
cond: fullCond,
options,
});
const model = this.selectModel(normalizedCond);
metrics.ai metrics.ai
.counter('generate_images_stream_calls') .counter('generate_images_stream_calls')
@@ -1017,65 +1031,36 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
messages: string | string[], messages: string | string[],
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> { ): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages]; const input = Array.isArray(messages) ? messages : [messages];
const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; const fullCond = { ...cond, outputType: ModelOutputType.Embedding };
await this.checkParams({ embeddings: messages, cond: fullCond, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); embeddings: input,
cond: fullCond,
options,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai metrics.ai
.counter('generate_embedding_calls') .counter('generate_embedding_calls')
.add(1, { model: model.id }); .add(1, this.metricLabels(model.id));
const response = await this.requestOpenAIJson('/embeddings', { const backendConfig = this.createNativeConfig();
const response = await this.createNativeEmbeddingDispatch(backendConfig)(
buildNativeEmbeddingRequest({
model: model.id, model: model.id,
input: messages, inputs: input,
dimensions: options.dimensions || DEFAULT_DIMENSIONS, dimensions: options.dimensions || DEFAULT_DIMENSIONS,
}); })
const data = Array.isArray(response?.data) ? response.data : []; );
return data return response.embeddings;
.map((item: any) => item?.embedding)
.filter((embedding: unknown) => Array.isArray(embedding)) as number[][];
} catch (e: any) { } catch (e: any) {
metrics.ai metrics.ai
.counter('generate_embedding_errors') .counter('generate_embedding_errors')
.add(1, { model: model.id }); .add(1, this.metricLabels(model.id));
throw this.handleError(e); throw this.handleError(e);
} }
} }
private toOpenAIChatMessages(
system: string | undefined,
messages: Awaited<ReturnType<typeof chatToGPTMessage>>[1]
) {
const result: Array<{ role: string; content: string }> = [];
if (system) {
result.push({ role: 'system', content: system });
}
for (const message of messages) {
if (typeof message.content === 'string') {
result.push({ role: message.role, content: message.content });
continue;
}
const text = message.content
.filter(
part =>
part &&
typeof part === 'object' &&
'type' in part &&
part.type === 'text' &&
'text' in part
)
.map(part => String((part as { text: string }).text))
.join('\n');
result.push({ role: message.role, content: text || '[no content]' });
}
return result;
}
private async requestOpenAIJson( private async requestOpenAIJson(
path: string, path: string,
body: Record<string, unknown>, body: Record<string, unknown>,

View File

@@ -1,5 +1,3 @@
import type { ToolSet } from 'ai';
import { CopilotProviderSideError, metrics } from '../../../base'; import { CopilotProviderSideError, metrics } from '../../../base';
import { import {
llmDispatchStream, llmDispatchStream,
@@ -7,6 +5,7 @@ import {
type NativeLlmRequest, type NativeLlmRequest,
} from '../../../native'; } from '../../../native';
import type { NodeTextMiddleware } from '../config'; import type { NodeTextMiddleware } from '../config';
import type { CopilotToolSet } from '../tools';
import { buildNativeRequest, NativeProviderAdapter } from './native'; import { buildNativeRequest, NativeProviderAdapter } from './native';
import { CopilotProvider } from './provider'; import { CopilotProvider } from './provider';
import { import {
@@ -87,7 +86,7 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
} }
private createNativeAdapter( private createNativeAdapter(
tools: ToolSet, tools: CopilotToolSet,
nodeTextMiddleware?: NodeTextMiddleware[] nodeTextMiddleware?: NodeTextMiddleware[]
) { ) {
return new NativeProviderAdapter( return new NativeProviderAdapter(
@@ -110,8 +109,13 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<string> { ): Promise<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
withAttachment: false,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
@@ -128,7 +132,7 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
middleware, middleware,
}); });
const adapter = this.createNativeAdapter(tools, middleware.node?.text); const adapter = this.createNativeAdapter(tools, middleware.node?.text);
return await adapter.text(request, options.signal); return await adapter.text(request, options.signal, messages);
} catch (e: any) { } catch (e: any) {
metrics.ai metrics.ai
.counter('chat_text_errors') .counter('chat_text_errors')
@@ -143,8 +147,13 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): AsyncIterable<string> { ): AsyncIterable<string> {
const fullCond = { ...cond, outputType: ModelOutputType.Text }; const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ cond: fullCond, messages, options }); const normalizedCond = await this.checkParams({
const model = this.selectModel(fullCond); cond: fullCond,
messages,
options,
withAttachment: false,
});
const model = this.selectModel(normalizedCond);
try { try {
metrics.ai metrics.ai
@@ -163,7 +172,11 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
middleware, middleware,
}); });
const adapter = this.createNativeAdapter(tools, middleware.node?.text); const adapter = this.createNativeAdapter(tools, middleware.node?.text);
for await (const chunk of adapter.streamText(request, options.signal)) { for await (const chunk of adapter.streamText(
request,
options.signal,
messages
)) {
yield chunk; yield chunk;
} }
} catch (e: any) { } catch (e: any) {

View File

@@ -51,13 +51,21 @@ const DEFAULT_MIDDLEWARE_BY_TYPE: Record<
}, },
}, },
[CopilotProviderType.Gemini]: { [CopilotProviderType.Gemini]: {
rust: {
request: ['normalize_messages', 'tool_schema_rewrite'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: { node: {
text: ['callout'], text: ['citation_footnote', 'callout'],
}, },
}, },
[CopilotProviderType.GeminiVertex]: { [CopilotProviderType.GeminiVertex]: {
rust: {
request: ['normalize_messages', 'tool_schema_rewrite'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: { node: {
text: ['callout'], text: ['citation_footnote', 'callout'],
}, },
}, },
[CopilotProviderType.FAL]: {}, [CopilotProviderType.FAL]: {},

View File

@@ -5,7 +5,7 @@ import type {
ProviderMiddlewareConfig, ProviderMiddlewareConfig,
} from '../config'; } from '../config';
import { resolveProviderMiddleware } from './provider-middleware'; import { resolveProviderMiddleware } from './provider-middleware';
import { CopilotProviderType, type ModelOutputType } from './types'; import { CopilotProviderType, ModelOutputType } from './types';
const PROVIDER_ID_PATTERN = /^[a-zA-Z0-9-_]+$/; const PROVIDER_ID_PATTERN = /^[a-zA-Z0-9-_]+$/;
@@ -239,8 +239,13 @@ export function resolveModel({
}; };
} }
const defaultProviderId =
outputType && outputType !== ModelOutputType.Rerank
? registry.defaults[outputType]
: undefined;
const fallbackOrder = [ const fallbackOrder = [
...(outputType ? [registry.defaults[outputType]] : []), ...(defaultProviderId ? [defaultProviderId] : []),
registry.defaults.fallback, registry.defaults.fallback,
...registry.order, ...registry.order,
].filter((id): id is string => !!id); ].filter((id): id is string => !!id);

View File

@@ -2,7 +2,6 @@ import { AsyncLocalStorage } from 'node:async_hooks';
import { Inject, Injectable, Logger } from '@nestjs/common'; import { Inject, Injectable, Logger } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core'; import { ModuleRef } from '@nestjs/core';
import { Tool, ToolSet } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { import {
@@ -27,6 +26,8 @@ import {
buildDocSearchGetter, buildDocSearchGetter,
buildDocUpdateHandler, buildDocUpdateHandler,
buildDocUpdateMetaHandler, buildDocUpdateMetaHandler,
type CopilotTool,
type CopilotToolSet,
createBlobReadTool, createBlobReadTool,
createCodeArtifactTool, createCodeArtifactTool,
createConversationSummaryTool, createConversationSummaryTool,
@@ -42,6 +43,7 @@ import {
createExaSearchTool, createExaSearchTool,
createSectionEditTool, createSectionEditTool,
} from '../tools'; } from '../tools';
import { canonicalizePromptAttachment } from './attachments';
import { CopilotProviderFactory } from './factory'; import { CopilotProviderFactory } from './factory';
import { resolveProviderMiddleware } from './provider-middleware'; import { resolveProviderMiddleware } from './provider-middleware';
import { buildProviderRegistry } from './provider-registry'; import { buildProviderRegistry } from './provider-registry';
@@ -52,12 +54,17 @@ import {
type CopilotImageOptions, type CopilotImageOptions,
CopilotProviderModel, CopilotProviderModel,
CopilotProviderType, CopilotProviderType,
type CopilotRerankRequest,
CopilotStructuredOptions, CopilotStructuredOptions,
EmbeddingMessage, EmbeddingMessage,
type ModelAttachmentCapability,
ModelCapability, ModelCapability,
ModelConditions, ModelConditions,
ModelFullConditions, ModelFullConditions,
ModelInputType, ModelInputType,
ModelOutputType,
type PromptAttachmentKind,
type PromptAttachmentSourceKind,
type PromptMessage, type PromptMessage,
PromptMessageSchema, PromptMessageSchema,
StreamObject, StreamObject,
@@ -163,6 +170,163 @@ export abstract class CopilotProvider<C = any> {
async refreshOnlineModels() {} async refreshOnlineModels() {}
private unique<T>(values: Iterable<T>) {
return Array.from(new Set(values));
}
private attachmentKindToInputType(
kind: PromptAttachmentKind
): ModelInputType {
switch (kind) {
case 'image':
return ModelInputType.Image;
case 'audio':
return ModelInputType.Audio;
default:
return ModelInputType.File;
}
}
protected async inferModelConditionsFromMessages(
messages?: PromptMessage[],
withAttachment = true
): Promise<Partial<ModelFullConditions>> {
if (!messages?.length || !withAttachment) return {};
const attachmentKinds: PromptAttachmentKind[] = [];
const attachmentSourceKinds: PromptAttachmentSourceKind[] = [];
const inputTypes: ModelInputType[] = [];
let hasRemoteAttachments = false;
for (const message of messages) {
if (!Array.isArray(message.attachments)) continue;
for (const attachment of message.attachments) {
const normalized = await canonicalizePromptAttachment(
attachment,
message
);
attachmentKinds.push(normalized.kind);
inputTypes.push(this.attachmentKindToInputType(normalized.kind));
attachmentSourceKinds.push(normalized.sourceKind);
hasRemoteAttachments = hasRemoteAttachments || normalized.isRemote;
}
}
return {
...(attachmentKinds.length
? { attachmentKinds: this.unique(attachmentKinds) }
: {}),
...(attachmentSourceKinds.length
? { attachmentSourceKinds: this.unique(attachmentSourceKinds) }
: {}),
...(inputTypes.length ? { inputTypes: this.unique(inputTypes) } : {}),
...(hasRemoteAttachments ? { hasRemoteAttachments } : {}),
};
}
private mergeModelConditions(
cond: ModelFullConditions,
inferredCond: Partial<ModelFullConditions>
): ModelFullConditions {
return {
...inferredCond,
...cond,
inputTypes: this.unique([
...(inferredCond.inputTypes ?? []),
...(cond.inputTypes ?? []),
]),
attachmentKinds: this.unique([
...(inferredCond.attachmentKinds ?? []),
...(cond.attachmentKinds ?? []),
]),
attachmentSourceKinds: this.unique([
...(inferredCond.attachmentSourceKinds ?? []),
...(cond.attachmentSourceKinds ?? []),
]),
hasRemoteAttachments:
cond.hasRemoteAttachments ?? inferredCond.hasRemoteAttachments,
};
}
protected getAttachCapability(
model: CopilotProviderModel,
outputType: ModelOutputType
): ModelAttachmentCapability | undefined {
const capability =
model.capabilities.find(cap => cap.output.includes(outputType)) ??
model.capabilities[0];
if (!capability) {
return;
}
return this.resolveAttachmentCapability(capability, outputType);
}
private resolveAttachmentCapability(
cap: ModelCapability,
outputType?: ModelOutputType
): ModelAttachmentCapability | undefined {
if (outputType === ModelOutputType.Structured) {
return cap.structuredAttachments ?? cap.attachments;
}
return cap.attachments;
}
private matchesAttachCapability(
cap: ModelCapability,
cond: ModelFullConditions
) {
const {
attachmentKinds,
attachmentSourceKinds,
hasRemoteAttachments,
outputType,
} = cond;
if (
!attachmentKinds?.length &&
!attachmentSourceKinds?.length &&
!hasRemoteAttachments
) {
return true;
}
const attachmentCapability = this.resolveAttachmentCapability(
cap,
outputType
);
if (!attachmentCapability) {
return !attachmentKinds?.some(
kind => !cap.input.includes(this.attachmentKindToInputType(kind))
);
}
if (
attachmentKinds?.some(kind => !attachmentCapability.kinds.includes(kind))
) {
return false;
}
if (
attachmentSourceKinds?.length &&
attachmentCapability.sourceKinds?.length &&
attachmentSourceKinds.some(
kind => !attachmentCapability.sourceKinds?.includes(kind)
)
) {
return false;
}
if (
hasRemoteAttachments &&
attachmentCapability.allowRemoteUrls === false
) {
return false;
}
return true;
}
private findValidModel( private findValidModel(
cond: ModelFullConditions cond: ModelFullConditions
): CopilotProviderModel | undefined { ): CopilotProviderModel | undefined {
@@ -170,7 +334,8 @@ export abstract class CopilotProvider<C = any> {
const matcher = (cap: ModelCapability) => const matcher = (cap: ModelCapability) =>
(!outputType || cap.output.includes(outputType)) && (!outputType || cap.output.includes(outputType)) &&
(!inputTypes?.length || (!inputTypes?.length ||
inputTypes.every(type => cap.input.includes(type))); inputTypes.every(type => cap.input.includes(type))) &&
this.matchesAttachCapability(cap, cond);
if (modelId) { if (modelId) {
const hasOnlineModel = this.onlineModelList.includes(modelId); const hasOnlineModel = this.onlineModelList.includes(modelId);
@@ -213,7 +378,7 @@ export abstract class CopilotProvider<C = any> {
protected getProviderSpecificTools( protected getProviderSpecificTools(
_toolName: CopilotChatTools, _toolName: CopilotChatTools,
_model: string _model: string
): [string, Tool?] | undefined { ): [string, CopilotTool?] | undefined {
return; return;
} }
@@ -221,8 +386,8 @@ export abstract class CopilotProvider<C = any> {
protected async getTools( protected async getTools(
options: CopilotChatOptions, options: CopilotChatOptions,
model: string model: string
): Promise<ToolSet> { ): Promise<CopilotToolSet> {
const tools: ToolSet = {}; const tools: CopilotToolSet = {};
if (options?.tools?.length) { if (options?.tools?.length) {
this.logger.debug(`getTools: ${JSON.stringify(options.tools)}`); this.logger.debug(`getTools: ${JSON.stringify(options.tools)}`);
const ac = this.moduleRef.get(AccessController, { strict: false }); const ac = this.moduleRef.get(AccessController, { strict: false });
@@ -377,19 +542,14 @@ export abstract class CopilotProvider<C = any> {
messages, messages,
embeddings, embeddings,
options = {}, options = {},
withAttachment = true,
}: { }: {
cond: ModelFullConditions; cond: ModelFullConditions;
messages?: PromptMessage[]; messages?: PromptMessage[];
embeddings?: string[]; embeddings?: string[];
options?: CopilotChatOptions; options?: CopilotChatOptions | CopilotStructuredOptions;
}) { withAttachment?: boolean;
const model = this.selectModel(cond); }): Promise<ModelFullConditions> {
const multimodal = model.capabilities.some(c =>
[ModelInputType.Image, ModelInputType.Audio].some(t =>
c.input.includes(t)
)
);
if (messages) { if (messages) {
const { requireContent = true, requireAttachment = false } = options; const { requireContent = true, requireAttachment = false } = options;
@@ -402,20 +562,56 @@ export abstract class CopilotProvider<C = any> {
}) })
.passthrough() .passthrough()
.catchall(z.union([z.string(), z.number(), z.date(), z.null()])) .catchall(z.union([z.string(), z.number(), z.date(), z.null()]))
.refine(
m =>
!(multimodal && requireAttachment && m.role === 'user') ||
(m.attachments ? m.attachments.length > 0 : true),
{ message: 'attachments required in multimodal mode' }
)
) )
.optional(); .optional();
this.handleZodError(MessageSchema.safeParse(messages)); this.handleZodError(MessageSchema.safeParse(messages));
const inferredCond = await this.inferModelConditionsFromMessages(
messages,
withAttachment
);
const mergedCond = this.mergeModelConditions(cond, inferredCond);
const model = this.selectModel(mergedCond);
const multimodal = model.capabilities.some(c =>
[ModelInputType.Image, ModelInputType.Audio, ModelInputType.File].some(
t => c.input.includes(t)
)
);
if (
multimodal &&
requireAttachment &&
!messages.some(
message =>
message.role === 'user' &&
Array.isArray(message.attachments) &&
message.attachments.length > 0
)
) {
throw new CopilotPromptInvalid(
'attachments required in multimodal mode'
);
} }
if (embeddings) { if (embeddings) {
this.handleZodError(EmbeddingMessage.safeParse(embeddings)); this.handleZodError(EmbeddingMessage.safeParse(embeddings));
} }
return mergedCond;
}
const inferredCond = await this.inferModelConditionsFromMessages(
messages,
withAttachment
);
const mergedCond = this.mergeModelConditions(cond, inferredCond);
if (embeddings) {
this.handleZodError(EmbeddingMessage.safeParse(embeddings));
}
return mergedCond;
} }
abstract text( abstract text(
@@ -476,7 +672,7 @@ export abstract class CopilotProvider<C = any> {
async rerank( async rerank(
_model: ModelConditions, _model: ModelConditions,
_messages: PromptMessage[][], _request: CopilotRerankRequest,
_options?: CopilotChatOptions _options?: CopilotChatOptions
): Promise<number[]> { ): Promise<number[]> {
throw new CopilotProviderNotSupported({ throw new CopilotProviderNotSupported({

View File

@@ -1,23 +0,0 @@
const GPT_4_RERANK_MODELS = /^(gpt-4(?:$|[.-]))/;
const GPT_5_RERANK_LOGPROBS_MODELS = /^(gpt-5\.2(?:$|-))/;
export const DEFAULT_RERANK_MODEL = 'gpt-5.2';
export const OPENAI_RERANK_TOP_LOGPROBS_LIMIT = 5;
export const OPENAI_RERANK_MAX_COMPLETION_TOKENS = 16;
export function supportsRerankModel(model: string): boolean {
return (
GPT_4_RERANK_MODELS.test(model) || GPT_5_RERANK_LOGPROBS_MODELS.test(model)
);
}
export function usesRerankReasoning(model: string): boolean {
return GPT_5_RERANK_LOGPROBS_MODELS.test(model);
}
export function normalizeRerankModel(model?: string | null): string {
if (model && supportsRerankModel(model)) {
return model;
}
return DEFAULT_RERANK_MODEL;
}

View File

@@ -124,14 +124,97 @@ export const ChatMessageRole = Object.values(AiPromptRole) as [
'user', 'user',
]; ];
const AttachmentUrlSchema = z.string().refine(value => {
if (value.startsWith('data:')) {
return true;
}
try {
const url = new URL(value);
return (
url.protocol === 'http:' ||
url.protocol === 'https:' ||
url.protocol === 'gs:'
);
} catch {
return false;
}
}, 'attachments must use https?://, gs:// or data: urls');
export const PromptAttachmentSourceKindSchema = z.enum([
'url',
'data',
'bytes',
'file_handle',
]);
export const PromptAttachmentKindSchema = z.enum(['image', 'audio', 'file']);
const AttachmentProviderHintSchema = z
.object({
provider: z.nativeEnum(CopilotProviderType).optional(),
kind: PromptAttachmentKindSchema.optional(),
})
.strict();
const PromptAttachmentSchema = z.discriminatedUnion('kind', [
z
.object({
kind: z.literal('url'),
url: AttachmentUrlSchema,
mimeType: z.string().optional(),
fileName: z.string().optional(),
providerHint: AttachmentProviderHintSchema.optional(),
})
.strict(),
z
.object({
kind: z.literal('data'),
data: z.string(),
mimeType: z.string(),
encoding: z.enum(['base64', 'utf8']).optional(),
fileName: z.string().optional(),
providerHint: AttachmentProviderHintSchema.optional(),
})
.strict(),
z
.object({
kind: z.literal('bytes'),
data: z.string(),
mimeType: z.string(),
encoding: z.literal('base64').optional(),
fileName: z.string().optional(),
providerHint: AttachmentProviderHintSchema.optional(),
})
.strict(),
z
.object({
kind: z.literal('file_handle'),
fileHandle: z.string().trim().min(1),
mimeType: z.string().optional(),
fileName: z.string().optional(),
providerHint: AttachmentProviderHintSchema.optional(),
})
.strict(),
]);
export const ChatMessageAttachment = z.union([ export const ChatMessageAttachment = z.union([
z.string().url(), AttachmentUrlSchema,
z.object({ z.object({
attachment: z.string(), attachment: AttachmentUrlSchema,
mimeType: z.string(), mimeType: z.string(),
}), }),
PromptAttachmentSchema,
]); ]);
export const PromptResponseFormatSchema = z
.object({
type: z.literal('json_schema'),
schema: z.any(),
strict: z.boolean().optional(),
})
.strict();
export const StreamObjectSchema = z.discriminatedUnion('type', [ export const StreamObjectSchema = z.discriminatedUnion('type', [
z.object({ z.object({
type: z.literal('text-delta'), type: z.literal('text-delta'),
@@ -161,6 +244,7 @@ export const PureMessageSchema = z.object({
streamObjects: z.array(StreamObjectSchema).optional().nullable(), streamObjects: z.array(StreamObjectSchema).optional().nullable(),
attachments: z.array(ChatMessageAttachment).optional().nullable(), attachments: z.array(ChatMessageAttachment).optional().nullable(),
params: z.record(z.any()).optional().nullable(), params: z.record(z.any()).optional().nullable(),
responseFormat: PromptResponseFormatSchema.optional().nullable(),
}); });
export const PromptMessageSchema = PureMessageSchema.extend({ export const PromptMessageSchema = PureMessageSchema.extend({
@@ -169,6 +253,12 @@ export const PromptMessageSchema = PureMessageSchema.extend({
export type PromptMessage = z.infer<typeof PromptMessageSchema>; export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>; export type PromptParams = NonNullable<PromptMessage['params']>;
export type StreamObject = z.infer<typeof StreamObjectSchema>; export type StreamObject = z.infer<typeof StreamObjectSchema>;
export type PromptAttachment = z.infer<typeof ChatMessageAttachment>;
export type PromptAttachmentSourceKind = z.infer<
typeof PromptAttachmentSourceKindSchema
>;
export type PromptAttachmentKind = z.infer<typeof PromptAttachmentKindSchema>;
export type PromptResponseFormat = z.infer<typeof PromptResponseFormatSchema>;
// ========== options ========== // ========== options ==========
@@ -194,7 +284,9 @@ export type CopilotChatTools = NonNullable<
>[number]; >[number];
export const CopilotStructuredOptionsSchema = export const CopilotStructuredOptionsSchema =
CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema).optional(); CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema)
.extend({ schema: z.any().optional(), strict: z.boolean().optional() })
.optional();
export type CopilotStructuredOptions = z.infer< export type CopilotStructuredOptions = z.infer<
typeof CopilotStructuredOptionsSchema typeof CopilotStructuredOptionsSchema
@@ -220,10 +312,22 @@ export type CopilotEmbeddingOptions = z.infer<
typeof CopilotEmbeddingOptionsSchema typeof CopilotEmbeddingOptionsSchema
>; >;
export type CopilotRerankCandidate = {
id?: string;
text: string;
};
export type CopilotRerankRequest = {
query: string;
candidates: CopilotRerankCandidate[];
topK?: number;
};
export enum ModelInputType { export enum ModelInputType {
Text = 'text', Text = 'text',
Image = 'image', Image = 'image',
Audio = 'audio', Audio = 'audio',
File = 'file',
} }
export enum ModelOutputType { export enum ModelOutputType {
@@ -231,12 +335,21 @@ export enum ModelOutputType {
Object = 'object', Object = 'object',
Embedding = 'embedding', Embedding = 'embedding',
Image = 'image', Image = 'image',
Rerank = 'rerank',
Structured = 'structured', Structured = 'structured',
} }
export interface ModelAttachmentCapability {
kinds: PromptAttachmentKind[];
sourceKinds?: PromptAttachmentSourceKind[];
allowRemoteUrls?: boolean;
}
export interface ModelCapability { export interface ModelCapability {
input: ModelInputType[]; input: ModelInputType[];
output: ModelOutputType[]; output: ModelOutputType[];
attachments?: ModelAttachmentCapability;
structuredAttachments?: ModelAttachmentCapability;
defaultForOutputType?: boolean; defaultForOutputType?: boolean;
} }
@@ -248,6 +361,9 @@ export interface CopilotProviderModel {
export type ModelConditions = { export type ModelConditions = {
inputTypes?: ModelInputType[]; inputTypes?: ModelInputType[];
attachmentKinds?: PromptAttachmentKind[];
attachmentSourceKinds?: PromptAttachmentSourceKind[];
hasRemoteAttachments?: boolean;
modelId?: string; modelId?: string;
}; };

View File

@@ -1,34 +1,39 @@
import { GoogleVertexProviderSettings } from '@ai-sdk/google-vertex';
import { GoogleVertexAnthropicProviderSettings } from '@ai-sdk/google-vertex/anthropic';
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import {
AssistantModelMessage,
FilePart,
ImagePart,
TextPart,
TextStreamPart,
UserModelMessage,
} from 'ai';
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
import z, { ZodType } from 'zod'; import z from 'zod';
import { import { OneMinute, safeFetch } from '../../../base';
bufferToArrayBuffer, import { PromptAttachment, StreamObject } from './types';
fetchBuffer,
OneMinute,
ResponseTooLargeError,
safeFetch,
SsrfBlockedError,
} from '../../../base';
import { CustomAITools } from '../tools';
import { PromptMessage, StreamObject } from './types';
type ChatMessage = UserModelMessage | AssistantModelMessage; export type VertexProviderConfig = {
location?: string;
project?: string;
baseURL?: string;
googleAuthOptions?: GoogleAuthOptions;
fetch?: typeof fetch;
};
export type VertexAnthropicProviderConfig = VertexProviderConfig;
type CopilotTextStreamPart =
| { type: 'text-delta'; text: string; id?: string }
| { type: 'reasoning-delta'; text: string; id?: string }
| {
type: 'tool-call';
toolCallId: string;
toolName: string;
input: Record<string, unknown>;
}
| {
type: 'tool-result';
toolCallId: string;
toolName: string;
input: Record<string, unknown>;
output: unknown;
}
| { type: 'error'; error: unknown };
const ATTACHMENT_MAX_BYTES = 20 * 1024 * 1024;
const ATTACH_HEAD_PARAMS = { timeoutMs: OneMinute / 12, maxRedirects: 3 }; const ATTACH_HEAD_PARAMS = { timeoutMs: OneMinute / 12, maxRedirects: 3 };
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
const FORMAT_INFER_MAP: Record<string, string> = { const FORMAT_INFER_MAP: Record<string, string> = {
pdf: 'application/pdf', pdf: 'application/pdf',
mp3: 'audio/mpeg', mp3: 'audio/mpeg',
@@ -53,9 +58,39 @@ const FORMAT_INFER_MAP: Record<string, string> = {
flv: 'video/flv', flv: 'video/flv',
}; };
async function fetchArrayBuffer(url: string): Promise<ArrayBuffer> { function toBase64Data(data: string, encoding: 'base64' | 'utf8' = 'base64') {
const { buffer } = await fetchBuffer(url, ATTACHMENT_MAX_BYTES); return encoding === 'base64'
return bufferToArrayBuffer(buffer); ? data
: Buffer.from(data, 'utf8').toString('base64');
}
export function promptAttachmentToUrl(
attachment: PromptAttachment
): string | undefined {
if (typeof attachment === 'string') return attachment;
if ('attachment' in attachment) return attachment.attachment;
switch (attachment.kind) {
case 'url':
return attachment.url;
case 'data':
return `data:${attachment.mimeType};base64,${toBase64Data(
attachment.data,
attachment.encoding
)}`;
case 'bytes':
return `data:${attachment.mimeType};base64,${attachment.data}`;
case 'file_handle':
return;
}
}
export function promptAttachmentMimeType(
attachment: PromptAttachment,
fallbackMimeType?: string
): string | undefined {
if (typeof attachment === 'string') return fallbackMimeType;
if ('attachment' in attachment) return attachment.mimeType;
return attachment.mimeType ?? fallbackMimeType;
} }
export async function inferMimeType(url: string) { export async function inferMimeType(url: string) {
@@ -69,6 +104,7 @@ export async function inferMimeType(url: string) {
if (ext) { if (ext) {
return ext; return ext;
} }
}
try { try {
const mimeType = await safeFetch( const mimeType = await safeFetch(
url, url,
@@ -79,336 +115,10 @@ export async function inferMimeType(url: string) {
} catch { } catch {
// ignore and fallback to default // ignore and fallback to default
} }
}
return 'application/octet-stream'; return 'application/octet-stream';
} }
export async function chatToGPTMessage( type CitationIndexedEvent = {
messages: PromptMessage[],
// TODO(@darkskygit): move this logic in interface refactoring
withAttachment: boolean = true,
// NOTE: some providers in vercel ai sdk are not able to handle url attachments yet
// so we need to use base64 encoded attachments instead
useBase64Attachment: boolean = false
): Promise<[string | undefined, ChatMessage[], ZodType?]> {
const hasSystem = messages[0]?.role === 'system';
const system = hasSystem ? messages[0] : undefined;
const normalizedMessages = hasSystem ? messages.slice(1) : messages;
const schema =
system?.params?.schema && system.params.schema instanceof ZodType
? system.params.schema
: undefined;
// filter redundant fields
const msgs: ChatMessage[] = [];
for (let { role, content, attachments, params } of normalizedMessages.filter(
m => m.role !== 'system'
)) {
content = content.trim();
role = role as 'user' | 'assistant';
const mimetype = params?.mimetype;
if (Array.isArray(attachments)) {
const contents: (TextPart | ImagePart | FilePart)[] = [];
if (content.length) {
contents.push({ type: 'text', text: content });
}
if (withAttachment) {
for (let attachment of attachments) {
let mediaType: string;
if (typeof attachment === 'string') {
mediaType =
typeof mimetype === 'string'
? mimetype
: await inferMimeType(attachment);
} else {
({ attachment, mimeType: mediaType } = attachment);
}
if (SIMPLE_IMAGE_URL_REGEX.test(attachment)) {
const data =
attachment.startsWith('data:') || useBase64Attachment
? await fetchArrayBuffer(attachment).catch(error => {
// Avoid leaking internal details for blocked URLs.
if (
error instanceof SsrfBlockedError ||
error instanceof ResponseTooLargeError
) {
throw new Error('Attachment URL is not allowed');
}
throw error;
})
: new URL(attachment);
if (mediaType.startsWith('image/')) {
contents.push({ type: 'image', image: data, mediaType });
} else {
contents.push({ type: 'file' as const, data, mediaType });
}
}
}
} else if (!content.length) {
// temp fix for pplx
contents.push({ type: 'text', text: '[no content]' });
}
msgs.push({ role, content: contents } as ChatMessage);
} else {
msgs.push({ role, content });
}
}
return [system?.content, msgs, schema];
}
// pattern types the callback will receive
type Pattern =
| { kind: 'index'; value: number } // [123]
| { kind: 'link'; text: string; url: string } // [text](url)
| { kind: 'wrappedLink'; text: string; url: string }; // ([text](url))
type NeedMore = { kind: 'needMore' };
type Failed = { kind: 'fail'; nextPos: number };
type Finished =
| { kind: 'ok'; endPos: number; text: string; url: string }
| { kind: 'index'; endPos: number; value: number };
type ParseStatus = Finished | NeedMore | Failed;
type PatternCallback = (m: Pattern) => string;
export class StreamPatternParser {
#buffer = '';
constructor(private readonly callback: PatternCallback) {}
write(chunk: string): string {
this.#buffer += chunk;
const output: string[] = [];
let i = 0;
while (i < this.#buffer.length) {
const ch = this.#buffer[i];
// [[[number]]] or [text](url) or ([text](url))
if (ch === '[' || (ch === '(' && this.peek(i + 1) === '[')) {
const isWrapped = ch === '(';
const startPos = isWrapped ? i + 1 : i;
const res = this.tryParse(startPos);
if (res.kind === 'needMore') break;
const { output: out, nextPos } = this.handlePattern(
res,
isWrapped,
startPos,
i
);
output.push(out);
i = nextPos;
continue;
}
output.push(ch);
i += 1;
}
this.#buffer = this.#buffer.slice(i);
return output.join('');
}
end(): string {
const rest = this.#buffer;
this.#buffer = '';
return rest;
}
// =========== helpers ===========
private peek(pos: number): string | undefined {
return pos < this.#buffer.length ? this.#buffer[pos] : undefined;
}
private tryParse(pos: number): ParseStatus {
const nestedRes = this.tryParseNestedIndex(pos);
if (nestedRes) return nestedRes;
return this.tryParseBracketPattern(pos);
}
private tryParseNestedIndex(pos: number): ParseStatus | null {
if (this.peek(pos + 1) !== '[') return null;
let i = pos;
let bracketCount = 0;
while (i < this.#buffer.length && this.#buffer[i] === '[') {
bracketCount++;
i++;
}
if (bracketCount >= 2) {
if (i >= this.#buffer.length) {
return { kind: 'needMore' };
}
let content = '';
while (i < this.#buffer.length && this.#buffer[i] !== ']') {
content += this.#buffer[i++];
}
let rightBracketCount = 0;
while (i < this.#buffer.length && this.#buffer[i] === ']') {
rightBracketCount++;
i++;
}
if (i >= this.#buffer.length && rightBracketCount < bracketCount) {
return { kind: 'needMore' };
}
if (
rightBracketCount === bracketCount &&
content.length > 0 &&
this.isNumeric(content)
) {
if (this.peek(i) === '(') {
return { kind: 'fail', nextPos: i };
}
return { kind: 'index', endPos: i, value: Number(content) };
}
}
return null;
}
private tryParseBracketPattern(pos: number): ParseStatus {
let i = pos + 1; // skip '['
if (i >= this.#buffer.length) {
return { kind: 'needMore' };
}
let content = '';
while (i < this.#buffer.length && this.#buffer[i] !== ']') {
const nextChar = this.#buffer[i];
if (nextChar === '[') {
return { kind: 'fail', nextPos: i };
}
content += nextChar;
i += 1;
}
if (i >= this.#buffer.length) {
return { kind: 'needMore' };
}
const after = i + 1;
const afterChar = this.peek(after);
if (content.length > 0 && this.isNumeric(content) && afterChar !== '(') {
// [number] pattern
return { kind: 'index', endPos: after, value: Number(content) };
} else if (afterChar !== '(') {
// [text](url) pattern
return { kind: 'fail', nextPos: after };
}
i = after + 1; // skip '('
if (i >= this.#buffer.length) {
return { kind: 'needMore' };
}
let url = '';
while (i < this.#buffer.length && this.#buffer[i] !== ')') {
url += this.#buffer[i++];
}
if (i >= this.#buffer.length) {
return { kind: 'needMore' };
}
return { kind: 'ok', endPos: i + 1, text: content, url };
}
private isNumeric(str: string): boolean {
return !Number.isNaN(Number(str)) && str.trim() !== '';
}
private handlePattern(
pattern: Finished | Failed,
isWrapped: boolean,
start: number,
current: number
): { output: string; nextPos: number } {
if (pattern.kind === 'fail') {
return {
output: this.#buffer.slice(current, pattern.nextPos),
nextPos: pattern.nextPos,
};
}
if (isWrapped) {
const afterLinkPos = pattern.endPos;
if (this.peek(afterLinkPos) !== ')') {
if (afterLinkPos >= this.#buffer.length) {
return { output: '', nextPos: current };
}
return { output: '(', nextPos: start };
}
const out =
pattern.kind === 'index'
? this.callback({ ...pattern, kind: 'index' })
: this.callback({ ...pattern, kind: 'wrappedLink' });
return { output: out, nextPos: afterLinkPos + 1 };
} else {
const out =
pattern.kind === 'ok'
? this.callback({ ...pattern, kind: 'link' })
: this.callback({ ...pattern, kind: 'index' });
return { output: out, nextPos: pattern.endPos };
}
}
}
export class CitationParser {
private readonly citations: string[] = [];
private readonly parser = new StreamPatternParser(p => {
switch (p.kind) {
case 'index': {
if (p.value <= this.citations.length) {
return `[^${p.value}]`;
}
return `[${p.value}]`;
}
case 'wrappedLink': {
const index = this.citations.indexOf(p.url);
if (index === -1) {
this.citations.push(p.url);
return `[^${this.citations.length}]`;
}
return `[^${index + 1}]`;
}
case 'link': {
return `[${p.text}](${p.url})`;
}
}
});
public push(citation: string) {
this.citations.push(citation);
}
public parse(content: string) {
return this.parser.write(content);
}
public end() {
return this.parser.end() + '\n' + this.getFootnotes();
}
private getFootnotes() {
const footnotes = this.citations.map((citation, index) => {
return `[^${index + 1}]: {"type":"url","url":"${encodeURIComponent(
citation
)}"}`;
});
return footnotes.join('\n');
}
}
export type CitationIndexedEvent = {
type: 'citation'; type: 'citation';
index: number; index: number;
url: string; url: string;
@@ -436,7 +146,7 @@ export class CitationFootnoteFormatter {
} }
} }
type ChunkType = TextStreamPart<CustomAITools>['type']; type ChunkType = CopilotTextStreamPart['type'];
export function toError(error: unknown): Error { export function toError(error: unknown): Error {
if (typeof error === 'string') { if (typeof error === 'string') {
@@ -458,6 +168,14 @@ type DocEditFootnote = {
intent: string; intent: string;
result: string; result: string;
}; };
function asRecord(value: unknown): Record<string, unknown> | null {
if (value && typeof value === 'object' && !Array.isArray(value)) {
return value as Record<string, unknown>;
}
return null;
}
export class TextStreamParser { export class TextStreamParser {
private readonly logger = new Logger(TextStreamParser.name); private readonly logger = new Logger(TextStreamParser.name);
private readonly CALLOUT_PREFIX = '\n[!]\n'; private readonly CALLOUT_PREFIX = '\n[!]\n';
@@ -468,7 +186,7 @@ export class TextStreamParser {
private readonly docEditFootnotes: DocEditFootnote[] = []; private readonly docEditFootnotes: DocEditFootnote[] = [];
public parse(chunk: TextStreamPart<CustomAITools>) { public parse(chunk: CopilotTextStreamPart) {
let result = ''; let result = '';
switch (chunk.type) { switch (chunk.type) {
case 'text-delta': { case 'text-delta': {
@@ -517,7 +235,7 @@ export class TextStreamParser {
} }
case 'doc_edit': { case 'doc_edit': {
this.docEditFootnotes.push({ this.docEditFootnotes.push({
intent: chunk.input.instructions, intent: String(chunk.input.instructions ?? ''),
result: '', result: '',
}); });
break; break;
@@ -533,14 +251,12 @@ export class TextStreamParser {
result = this.addPrefix(result); result = this.addPrefix(result);
switch (chunk.toolName) { switch (chunk.toolName) {
case 'doc_edit': { case 'doc_edit': {
const array = const output = asRecord(chunk.output);
chunk.output && typeof chunk.output === 'object' const array = output?.result;
? chunk.output.result
: undefined;
if (Array.isArray(array)) { if (Array.isArray(array)) {
result += array result += array
.map(item => { .map(item => {
return `\n${item.changedContent}\n`; return `\n${String(asRecord(item)?.changedContent ?? '')}\n`;
}) })
.join(''); .join('');
this.docEditFootnotes[this.docEditFootnotes.length - 1].result = this.docEditFootnotes[this.docEditFootnotes.length - 1].result =
@@ -557,8 +273,11 @@ export class TextStreamParser {
} else if (typeof output === 'string') { } else if (typeof output === 'string') {
result += `\n${output}\n`; result += `\n${output}\n`;
} else { } else {
const message = asRecord(output)?.message;
this.logger.warn( this.logger.warn(
`Unexpected result type for doc_semantic_search: ${output?.message || 'Unknown error'}` `Unexpected result type for doc_semantic_search: ${
typeof message === 'string' ? message : 'Unknown error'
}`
); );
} }
break; break;
@@ -572,9 +291,11 @@ export class TextStreamParser {
break; break;
} }
case 'doc_compose': { case 'doc_compose': {
const output = chunk.output; const output = asRecord(chunk.output);
if (output && typeof output === 'object' && 'title' in output) { if (output && typeof output.title === 'string') {
result += `\nDocument "${output.title}" created successfully with ${output.wordCount} words.\n`; result += `\nDocument "${output.title}" created successfully with ${String(
output.wordCount ?? 0
)} words.\n`;
} }
break; break;
} }
@@ -654,7 +375,7 @@ export class TextStreamParser {
} }
export class StreamObjectParser { export class StreamObjectParser {
public parse(chunk: TextStreamPart<CustomAITools>) { public parse(chunk: CopilotTextStreamPart) {
switch (chunk.type) { switch (chunk.type) {
case 'reasoning-delta': { case 'reasoning-delta': {
return { type: 'reasoning' as const, textDelta: chunk.text }; return { type: 'reasoning' as const, textDelta: chunk.text };
@@ -747,9 +468,7 @@ function normalizeUrl(baseURL?: string) {
} }
} }
export function getVertexAnthropicBaseUrl( export function getVertexAnthropicBaseUrl(options: VertexProviderConfig) {
options: GoogleVertexAnthropicProviderSettings
) {
const normalizedBaseUrl = normalizeUrl(options.baseURL); const normalizedBaseUrl = normalizeUrl(options.baseURL);
if (normalizedBaseUrl) return normalizedBaseUrl; if (normalizedBaseUrl) return normalizedBaseUrl;
const { location, project } = options; const { location, project } = options;
@@ -758,7 +477,7 @@ export function getVertexAnthropicBaseUrl(
} }
export async function getGoogleAuth( export async function getGoogleAuth(
options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings, options: VertexProviderConfig,
publisher: 'anthropic' | 'google' publisher: 'anthropic' | 'google'
) { ) {
function getBaseUrl() { function getBaseUrl() {
@@ -777,7 +496,7 @@ export async function getGoogleAuth(
} }
const auth = new GoogleAuth({ const auth = new GoogleAuth({
scopes: ['https://www.googleapis.com/auth/cloud-platform'], scopes: ['https://www.googleapis.com/auth/cloud-platform'],
...(options.googleAuthOptions as GoogleAuthOptions), ...options.googleAuthOptions,
}); });
const client = await auth.getClient(); const client = await auth.getClient();
const token = await client.getAccessToken(); const token = await client.getAccessToken();

View File

@@ -31,6 +31,7 @@ import { SubscriptionPlan, SubscriptionStatus } from '../payment/types';
import { ChatMessageCache } from './message'; import { ChatMessageCache } from './message';
import { ChatPrompt } from './prompt/chat-prompt'; import { ChatPrompt } from './prompt/chat-prompt';
import { PromptService } from './prompt/service'; import { PromptService } from './prompt/service';
import { promptAttachmentHasSource } from './providers/attachments';
import { CopilotProviderFactory } from './providers/factory'; import { CopilotProviderFactory } from './providers/factory';
import { buildProviderRegistry } from './providers/provider-registry'; import { buildProviderRegistry } from './providers/provider-registry';
import { import {
@@ -38,6 +39,7 @@ import {
type PromptMessage, type PromptMessage,
type PromptParams, type PromptParams,
} from './providers/types'; } from './providers/types';
import { promptAttachmentToUrl } from './providers/utils';
import { import {
type ChatHistory, type ChatHistory,
type ChatMessage, type ChatMessage,
@@ -272,11 +274,7 @@ export class ChatSession implements AsyncDisposable {
lastMessage.attachments || [], lastMessage.attachments || [],
] ]
.flat() .flat()
.filter(v => .filter(v => promptAttachmentHasSource(v));
typeof v === 'string'
? !!v.trim()
: v && v.attachment.trim() && v.mimeType
);
//insert all previous user message content before first user message //insert all previous user message content before first user message
finished.splice(firstUserMessageIndex, 0, ...messages); finished.splice(firstUserMessageIndex, 0, ...messages);
@@ -466,8 +464,8 @@ export class ChatSessionService {
messages: preload.concat(messages).map(m => ({ messages: preload.concat(messages).map(m => ({
...m, ...m,
attachments: m.attachments attachments: m.attachments
?.map(a => (typeof a === 'string' ? a : a.attachment)) ?.map(a => promptAttachmentToUrl(a))
.filter(a => !!a), .filter((a): a is string => !!a),
})), })),
}; };
} else { } else {

View File

@@ -1,9 +1,9 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { AccessController } from '../../../core/permission'; import { AccessController } from '../../../core/permission';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { ContextSession, CopilotChatOptions } from './types'; import type { ContextSession, CopilotChatOptions } from './types';
const logger = new Logger('ContextBlobReadTool'); const logger = new Logger('ContextBlobReadTool');
@@ -58,7 +58,7 @@ export const createBlobReadTool = (
chunk?: number chunk?: number
) => Promise<object | undefined> ) => Promise<object | undefined>
) => { ) => {
return tool({ return defineTool({
description: description:
'Return the content and basic metadata of a single attachment identified by blobId; more inclined to use search tools rather than this tool.', 'Return the content and basic metadata of a single attachment identified by blobId; more inclined to use search tools rather than this tool.',
inputSchema: z.object({ inputSchema: z.object({

View File

@@ -1,8 +1,8 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { CopilotProviderFactory, PromptService } from './types'; import type { CopilotProviderFactory, PromptService } from './types';
const logger = new Logger('CodeArtifactTool'); const logger = new Logger('CodeArtifactTool');
@@ -16,7 +16,7 @@ export const createCodeArtifactTool = (
promptService: PromptService, promptService: PromptService,
factory: CopilotProviderFactory factory: CopilotProviderFactory
) => { ) => {
return tool({ return defineTool({
description: description:
'Generate a single-file HTML snippet (with inline <style> and <script>) that accomplishes the requested functionality. The final HTML should be runnable when saved as an .html file and opened in a browser. Do NOT reference external resources (CSS, JS, images) except through data URIs.', 'Generate a single-file HTML snippet (with inline <style> and <script>) that accomplishes the requested functionality. The final HTML should be runnable when saved as an .html file and opened in a browser. Do NOT reference external resources (CSS, JS, images) except through data URIs.',
inputSchema: z.object({ inputSchema: z.object({

View File

@@ -1,8 +1,8 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { CopilotProviderFactory, PromptService } from './types'; import type { CopilotProviderFactory, PromptService } from './types';
const logger = new Logger('ConversationSummaryTool'); const logger = new Logger('ConversationSummaryTool');
@@ -12,7 +12,7 @@ export const createConversationSummaryTool = (
promptService: PromptService, promptService: PromptService,
factory: CopilotProviderFactory factory: CopilotProviderFactory
) => { ) => {
return tool({ return defineTool({
description: description:
'Create a concise, AI-generated summary of the conversation so far—capturing key topics, decisions, and critical details. Use this tool whenever the context becomes lengthy to preserve essential information that might otherwise be lost to truncation in future turns.', 'Create a concise, AI-generated summary of the conversation so far—capturing key topics, decisions, and critical details. Use this tool whenever the context becomes lengthy to preserve essential information that might otherwise be lost to truncation in future turns.',
inputSchema: z.object({ inputSchema: z.object({

View File

@@ -1,8 +1,8 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { CopilotProviderFactory, PromptService } from './types'; import type { CopilotProviderFactory, PromptService } from './types';
const logger = new Logger('DocComposeTool'); const logger = new Logger('DocComposeTool');
@@ -11,7 +11,7 @@ export const createDocComposeTool = (
promptService: PromptService, promptService: PromptService,
factory: CopilotProviderFactory factory: CopilotProviderFactory
) => { ) => {
return tool({ return defineTool({
description: description:
'Write a new document with markdown content. This tool creates structured markdown content for documents including titles, sections, and formatting.', 'Write a new document with markdown content. This tool creates structured markdown content for documents including titles, sections, and formatting.',
inputSchema: z.object({ inputSchema: z.object({

View File

@@ -1,8 +1,8 @@
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { DocReader } from '../../../core/doc'; import { DocReader } from '../../../core/doc';
import { AccessController } from '../../../core/permission'; import { AccessController } from '../../../core/permission';
import { defineTool } from './tool';
import type { import type {
CopilotChatOptions, CopilotChatOptions,
CopilotProviderFactory, CopilotProviderFactory,
@@ -50,7 +50,7 @@ export const createDocEditTool = (
prompt: PromptService, prompt: PromptService,
getContent: (targetId?: string) => Promise<string | undefined> getContent: (targetId?: string) => Promise<string | undefined>
) => { ) => {
return tool({ return defineTool({
description: ` description: `
Use this tool to propose an edit to a structured Markdown document with identifiable blocks. Use this tool to propose an edit to a structured Markdown document with identifiable blocks.
Each block begins with a comment like <!-- block_id=... -->, and represents a unit of editable content such as a heading, paragraph, list, or code snippet. Each block begins with a comment like <!-- block_id=... -->, and represents a unit of editable content such as a heading, paragraph, list, or code snippet.

View File

@@ -1,9 +1,9 @@
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import type { AccessController } from '../../../core/permission'; import type { AccessController } from '../../../core/permission';
import type { IndexerService, SearchDoc } from '../../indexer'; import type { IndexerService, SearchDoc } from '../../indexer';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { CopilotChatOptions } from './types'; import type { CopilotChatOptions } from './types';
export const buildDocKeywordSearchGetter = ( export const buildDocKeywordSearchGetter = (
@@ -37,7 +37,7 @@ export const buildDocKeywordSearchGetter = (
export const createDocKeywordSearchTool = ( export const createDocKeywordSearchTool = (
searchDocs: (query: string) => Promise<SearchDoc[] | undefined> searchDocs: (query: string) => Promise<SearchDoc[] | undefined>
) => { ) => {
return tool({ return defineTool({
description: description:
'Fuzzy search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based or keyword-base lookup is sufficient.', 'Fuzzy search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based or keyword-base lookup is sufficient.',
inputSchema: z.object({ inputSchema: z.object({

View File

@@ -1,11 +1,11 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { DocReader } from '../../../core/doc'; import { DocReader } from '../../../core/doc';
import { AccessController } from '../../../core/permission'; import { AccessController } from '../../../core/permission';
import { Models, publicUserSelect } from '../../../models'; import { Models, publicUserSelect } from '../../../models';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { CopilotChatOptions } from './types'; import type { CopilotChatOptions } from './types';
const logger = new Logger('DocReadTool'); const logger = new Logger('DocReadTool');
@@ -72,7 +72,7 @@ export const buildDocContentGetter = (
export const createDocReadTool = ( export const createDocReadTool = (
getDoc: (targetId?: string) => Promise<object | undefined> getDoc: (targetId?: string) => Promise<object | undefined>
) => { ) => {
return tool({ return defineTool({
description: description:
'Return the complete text and basic metadata of a single document identified by docId; use this when the user needs the full content of a specific file rather than a search result.', 'Return the complete text and basic metadata of a single document identified by docId; use this when the user needs the full content of a specific file rather than a search result.',
inputSchema: z.object({ inputSchema: z.object({

View File

@@ -1,4 +1,3 @@
import { tool } from 'ai';
import { omit } from 'lodash-es'; import { omit } from 'lodash-es';
import { z } from 'zod'; import { z } from 'zod';
@@ -9,6 +8,7 @@ import {
type Models, type Models,
} from '../../../models'; } from '../../../models';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { import type {
ContextSession, ContextSession,
CopilotChatOptions, CopilotChatOptions,
@@ -24,7 +24,7 @@ export const buildDocSearchGetter = (
const searchDocs = async ( const searchDocs = async (
options: CopilotChatOptions, options: CopilotChatOptions,
query?: string, query?: string,
abortSignal?: AbortSignal signal?: AbortSignal
) => { ) => {
if (!options || !query?.trim() || !options.user || !options.workspace) { if (!options || !query?.trim() || !options.user || !options.workspace) {
return `Invalid search parameters.`; return `Invalid search parameters.`;
@@ -36,8 +36,8 @@ export const buildDocSearchGetter = (
if (!canAccess) if (!canAccess)
return 'You do not have permission to access this workspace.'; return 'You do not have permission to access this workspace.';
const [chunks, contextChunks] = await Promise.all([ const [chunks, contextChunks] = await Promise.all([
context.matchWorkspaceAll(options.workspace, query, 10, abortSignal), context.matchWorkspaceAll(options.workspace, query, 10, signal),
docContext?.matchFiles(query, 10, abortSignal) ?? [], docContext?.matchFiles(query, 10, signal) ?? [],
]); ]);
const docChunks = await ac const docChunks = await ac
@@ -100,10 +100,10 @@ export const buildDocSearchGetter = (
export const createDocSemanticSearchTool = ( export const createDocSemanticSearchTool = (
searchDocs: ( searchDocs: (
query: string, query: string,
abortSignal?: AbortSignal signal?: AbortSignal
) => Promise<ChunkSimilarity[] | string | undefined> ) => Promise<ChunkSimilarity[] | string | undefined>
) => { ) => {
return tool({ return defineTool({
description: description:
'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; use this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts, recent documents).', 'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; use this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts, recent documents).',
inputSchema: z.object({ inputSchema: z.object({
@@ -115,7 +115,7 @@ export const createDocSemanticSearchTool = (
}), }),
execute: async ({ query }, options) => { execute: async ({ query }, options) => {
try { try {
return await searchDocs(query, options.abortSignal); return await searchDocs(query, options.signal);
} catch (e: any) { } catch (e: any) {
return toolError('Doc Semantic Search Failed', e.message); return toolError('Doc Semantic Search Failed', e.message);
} }

View File

@@ -1,10 +1,10 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { DocWriter } from '../../../core/doc'; import { DocWriter } from '../../../core/doc';
import { AccessController } from '../../../core/permission'; import { AccessController } from '../../../core/permission';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { CopilotChatOptions } from './types'; import type { CopilotChatOptions } from './types';
const logger = new Logger('DocWriteTool'); const logger = new Logger('DocWriteTool');
@@ -141,7 +141,7 @@ export const buildDocUpdateMetaHandler = (
export const createDocCreateTool = ( export const createDocCreateTool = (
createDoc: (title: string, content: string) => Promise<object> createDoc: (title: string, content: string) => Promise<object>
) => { ) => {
return tool({ return defineTool({
description: description:
'Create a new document in the workspace with the given title and markdown content. Returns the ID of the created document. This tool not support insert or update database block and image yet.', 'Create a new document in the workspace with the given title and markdown content. Returns the ID of the created document. This tool not support insert or update database block and image yet.',
inputSchema: z.object({ inputSchema: z.object({
@@ -164,7 +164,7 @@ export const createDocCreateTool = (
export const createDocUpdateTool = ( export const createDocUpdateTool = (
updateDoc: (docId: string, content: string) => Promise<object> updateDoc: (docId: string, content: string) => Promise<object>
) => { ) => {
return tool({ return defineTool({
description: description:
'Update an existing document with new markdown content (body only). Uses structural diffing to apply minimal changes. This does NOT update the document title. This tool not support insert or update database block and image yet.', 'Update an existing document with new markdown content (body only). Uses structural diffing to apply minimal changes. This does NOT update the document title. This tool not support insert or update database block and image yet.',
inputSchema: z.object({ inputSchema: z.object({
@@ -189,7 +189,7 @@ export const createDocUpdateTool = (
export const createDocUpdateMetaTool = ( export const createDocUpdateMetaTool = (
updateDocMeta: (docId: string, title: string) => Promise<object> updateDocMeta: (docId: string, title: string) => Promise<object>
) => { ) => {
return tool({ return defineTool({
description: 'Update document metadata (currently title only).', description: 'Update document metadata (currently title only).',
inputSchema: z.object({ inputSchema: z.object({
doc_id: z.string().describe('The ID of the document to update'), doc_id: z.string().describe('The ID of the document to update'),

View File

@@ -1,12 +1,12 @@
import { tool } from 'ai';
import Exa from 'exa-js'; import Exa from 'exa-js';
import { z } from 'zod'; import { z } from 'zod';
import { Config } from '../../../base'; import { Config } from '../../../base';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
export const createExaCrawlTool = (config: Config) => { export const createExaCrawlTool = (config: Config) => {
return tool({ return defineTool({
description: 'Crawl the web url for information', description: 'Crawl the web url for information',
inputSchema: z.object({ inputSchema: z.object({
url: z url: z

View File

@@ -1,12 +1,12 @@
import { tool } from 'ai';
import Exa from 'exa-js'; import Exa from 'exa-js';
import { z } from 'zod'; import { z } from 'zod';
import { Config } from '../../../base'; import { Config } from '../../../base';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
export const createExaSearchTool = (config: Config) => { export const createExaSearchTool = (config: Config) => {
return tool({ return defineTool({
description: 'Search the web for information', description: 'Search the web for information',
inputSchema: z.object({ inputSchema: z.object({
query: z.string().describe('The query to search the web for.'), query: z.string().describe('The query to search the web for.'),

View File

@@ -1,39 +1,3 @@
import { ToolSet } from 'ai';
import { createBlobReadTool } from './blob-read';
import { createCodeArtifactTool } from './code-artifact';
import { createConversationSummaryTool } from './conversation-summary';
import { createDocComposeTool } from './doc-compose';
import { createDocEditTool } from './doc-edit';
import { createDocKeywordSearchTool } from './doc-keyword-search';
import { createDocReadTool } from './doc-read';
import { createDocSemanticSearchTool } from './doc-semantic-search';
import {
createDocCreateTool,
createDocUpdateMetaTool,
createDocUpdateTool,
} from './doc-write';
import { createExaCrawlTool } from './exa-crawl';
import { createExaSearchTool } from './exa-search';
import { createSectionEditTool } from './section-edit';
export interface CustomAITools extends ToolSet {
blob_read: ReturnType<typeof createBlobReadTool>;
code_artifact: ReturnType<typeof createCodeArtifactTool>;
conversation_summary: ReturnType<typeof createConversationSummaryTool>;
doc_edit: ReturnType<typeof createDocEditTool>;
doc_semantic_search: ReturnType<typeof createDocSemanticSearchTool>;
doc_keyword_search: ReturnType<typeof createDocKeywordSearchTool>;
doc_read: ReturnType<typeof createDocReadTool>;
doc_create: ReturnType<typeof createDocCreateTool>;
doc_update: ReturnType<typeof createDocUpdateTool>;
doc_update_meta: ReturnType<typeof createDocUpdateMetaTool>;
doc_compose: ReturnType<typeof createDocComposeTool>;
section_edit: ReturnType<typeof createSectionEditTool>;
web_search_exa: ReturnType<typeof createExaSearchTool>;
web_crawl_exa: ReturnType<typeof createExaCrawlTool>;
}
export * from './blob-read'; export * from './blob-read';
export * from './code-artifact'; export * from './code-artifact';
export * from './conversation-summary'; export * from './conversation-summary';
@@ -47,3 +11,4 @@ export * from './error';
export * from './exa-crawl'; export * from './exa-crawl';
export * from './exa-search'; export * from './exa-search';
export * from './section-edit'; export * from './section-edit';
export * from './tool';

View File

@@ -1,8 +1,8 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { tool } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { toolError } from './error'; import { toolError } from './error';
import { defineTool } from './tool';
import type { CopilotProviderFactory, PromptService } from './types'; import type { CopilotProviderFactory, PromptService } from './types';
const logger = new Logger('SectionEditTool'); const logger = new Logger('SectionEditTool');
@@ -11,7 +11,7 @@ export const createSectionEditTool = (
promptService: PromptService, promptService: PromptService,
factory: CopilotProviderFactory factory: CopilotProviderFactory
) => { ) => {
return tool({ return defineTool({
description: description:
'Intelligently edit and modify a specific section of a document based on user instructions, with full document context awareness. This tool can refine, rewrite, translate, restructure, or enhance any part of markdown content while preserving formatting, maintaining contextual coherence, and ensuring consistency with the entire document. Perfect for targeted improvements that consider the broader document context.', 'Intelligently edit and modify a specific section of a document based on user instructions, with full document context awareness. This tool can refine, rewrite, translate, restructure, or enhance any part of markdown content while preserving formatting, maintaining contextual coherence, and ensuring consistency with the entire document. Perfect for targeted improvements that consider the broader document context.',
inputSchema: z.object({ inputSchema: z.object({

View File

@@ -0,0 +1,33 @@
import type { ZodTypeAny } from 'zod';
import { z } from 'zod';
import type { PromptMessage } from '../providers/types';
export type CopilotToolExecuteOptions = {
signal?: AbortSignal;
messages?: PromptMessage[];
};
export type CopilotTool = {
description?: string;
inputSchema?: ZodTypeAny | Record<string, unknown>;
execute?: {
bivarianceHack: (
args: Record<string, unknown>,
options: CopilotToolExecuteOptions
) => Promise<unknown> | unknown;
}['bivarianceHack'];
};
export type CopilotToolSet = Record<string, CopilotTool>;
export function defineTool<TSchema extends ZodTypeAny, TResult>(tool: {
description?: string;
inputSchema: TSchema;
execute: (
args: z.infer<TSchema>,
options: CopilotToolExecuteOptions
) => Promise<TResult> | TResult;
}): CopilotTool {
return tool;
}

View File

@@ -224,11 +224,10 @@ export class CopilotTranscriptionService {
const config = Object.assign({}, prompt.config); const config = Object.assign({}, prompt.config);
if (schema) { if (schema) {
const provider = await this.getProvider(prompt.model, true, prefer); const provider = await this.getProvider(prompt.model, true, prefer);
return provider.structure( return provider.structure(cond, [...prompt.finish({}), msg], {
cond, ...config,
[...prompt.finish({ schema }), msg], schema,
config });
);
} else { } else {
const provider = await this.getProvider(prompt.model, false); const provider = await this.getProvider(prompt.model, false);
return provider.text(cond, [...prompt.finish({}), msg], config); return provider.text(cond, [...prompt.finish({}), msg], config);

View File

@@ -189,20 +189,13 @@ test.describe('AISettings/Embedding', () => {
await utils.settings.closeSettingsPanel(page); await utils.settings.closeSettingsPanel(page);
await utils.chatPanel.makeChat( const query = `Use semantic search across workspace and attached files, then tell me whether Workspace${randomStr1} is a cat or dog and whether Workspace${randomStr2} is a cat or dog. Answer with citations.`;
page,
`What is Workspace${randomStr1}? What is Workspace${randomStr2}?` await utils.chatPanel.makeChat(page, query);
);
await utils.chatPanel.waitForHistory(page, [ await utils.chatPanel.waitForHistory(page, [
{ { role: 'user', content: query },
role: 'user', { role: 'assistant', status: 'success' },
content: `What is Workspace${randomStr1}? What is Workspace${randomStr2}?`,
},
{
role: 'assistant',
status: 'success',
},
]); ]);
await expect(async () => { await expect(async () => {

147
yarn.lock
View File

@@ -962,8 +962,6 @@ __metadata:
"@affine/graphql": "workspace:*" "@affine/graphql": "workspace:*"
"@affine/s3-compat": "workspace:*" "@affine/s3-compat": "workspace:*"
"@affine/server-native": "workspace:*" "@affine/server-native": "workspace:*"
"@ai-sdk/google": "npm:^3.0.46"
"@ai-sdk/google-vertex": "npm:^4.0.83"
"@apollo/server": "npm:^4.13.0" "@apollo/server": "npm:^4.13.0"
"@faker-js/faker": "npm:^10.1.0" "@faker-js/faker": "npm:^10.1.0"
"@fal-ai/serverless-client": "npm:^0.15.0" "@fal-ai/serverless-client": "npm:^0.15.0"
@@ -1022,7 +1020,6 @@ __metadata:
"@types/semver": "npm:^7.5.8" "@types/semver": "npm:^7.5.8"
"@types/sinon": "npm:^21.0.0" "@types/sinon": "npm:^21.0.0"
"@types/supertest": "npm:^7.0.0" "@types/supertest": "npm:^7.0.0"
ai: "npm:^6.0.118"
ava: "npm:^7.0.0" ava: "npm:^7.0.0"
bullmq: "npm:^5.40.2" bullmq: "npm:^5.40.2"
c8: "npm:^10.1.3" c8: "npm:^10.1.3"
@@ -1122,80 +1119,6 @@ __metadata:
languageName: unknown languageName: unknown
linkType: soft linkType: soft
"@ai-sdk/anthropic@npm:3.0.59":
version: 3.0.59
resolution: "@ai-sdk/anthropic@npm:3.0.59"
dependencies:
"@ai-sdk/provider": "npm:3.0.8"
"@ai-sdk/provider-utils": "npm:4.0.20"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/b7504dc845f2cd487a4a18db9dbf9e2231fbe3b0a5a22ea12bedb7d4f276463cdd4fd39493efd40f39c78b9023af5f00a4e603981265b66ea43701ad699da5c9
languageName: node
linkType: hard
"@ai-sdk/gateway@npm:3.0.68":
version: 3.0.68
resolution: "@ai-sdk/gateway@npm:3.0.68"
dependencies:
"@ai-sdk/provider": "npm:3.0.8"
"@ai-sdk/provider-utils": "npm:4.0.20"
"@vercel/oidc": "npm:3.1.0"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/4a6923a6e610472de0ea18f25077df0b394d31b287285f512d6e16ee8b0b90421faf284399f37e11045ea370794a6060686a3b69b68fda04b6dc13562cd8fd8e
languageName: node
linkType: hard
"@ai-sdk/google-vertex@npm:^4.0.83":
version: 4.0.83
resolution: "@ai-sdk/google-vertex@npm:4.0.83"
dependencies:
"@ai-sdk/anthropic": "npm:3.0.59"
"@ai-sdk/google": "npm:3.0.46"
"@ai-sdk/provider": "npm:3.0.8"
"@ai-sdk/provider-utils": "npm:4.0.20"
google-auth-library: "npm:^10.5.0"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/330ed81cac6779d81e904fe6668cd366bda9d91f911bf318ea3f4c5f9c246ff9f89523fedd268e27fe77aece89ced8e7f49a18e179977f25a9c3251d07df358c
languageName: node
linkType: hard
"@ai-sdk/google@npm:3.0.46, @ai-sdk/google@npm:^3.0.46":
version: 3.0.46
resolution: "@ai-sdk/google@npm:3.0.46"
dependencies:
"@ai-sdk/provider": "npm:3.0.8"
"@ai-sdk/provider-utils": "npm:4.0.20"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/295e9f53c6e14e836164a6755d2c50b2840c7a9542919c2684b916c3b8155cf4fabcd30c431a61e12d658e3dceb6af23a284daa63d5311850a03f9d3346038f9
languageName: node
linkType: hard
"@ai-sdk/provider-utils@npm:4.0.20":
version: 4.0.20
resolution: "@ai-sdk/provider-utils@npm:4.0.20"
dependencies:
"@ai-sdk/provider": "npm:3.0.8"
"@standard-schema/spec": "npm:^1.1.0"
eventsource-parser: "npm:^3.0.6"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/1a2d5adc262582cfff9b86afae37ba6291fae5b9155250f02ee8fdd119a7cc1351960ed20181e6f671c28153daf8d69e864e883dee06b96d36c486e2a1a32be9
languageName: node
linkType: hard
"@ai-sdk/provider@npm:3.0.8":
version: 3.0.8
resolution: "@ai-sdk/provider@npm:3.0.8"
dependencies:
json-schema: "npm:^0.4.0"
checksum: 10/85fb7b9c7cd9ea1aa9840aa57a9517a7ecec8c25a33a31e4615f4eceede9fe61f072b2a2915e4713f2b78c8b94a8c25a79ddbcf998f0d537c02ba47442402542
languageName: node
linkType: hard
"@alloc/quick-lru@npm:^5.2.0": "@alloc/quick-lru@npm:^5.2.0":
version: 5.2.0 version: 5.2.0
resolution: "@alloc/quick-lru@npm:5.2.0" resolution: "@alloc/quick-lru@npm:5.2.0"
@@ -9760,7 +9683,7 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@opentelemetry/api@npm:1.9.0, @opentelemetry/api@npm:^1.3.0, @opentelemetry/api@npm:^1.9.0": "@opentelemetry/api@npm:^1.3.0, @opentelemetry/api@npm:^1.9.0":
version: 1.9.0 version: 1.9.0
resolution: "@opentelemetry/api@npm:1.9.0" resolution: "@opentelemetry/api@npm:1.9.0"
checksum: 10/a607f0eef971893c4f2ee2a4c2069aade6ec3e84e2a1f5c2aac19f65c5d9eeea41aa72db917c1029faafdd71789a1a040bdc18f40d63690e22ccae5d7070f194 checksum: 10/a607f0eef971893c4f2ee2a4c2069aade6ec3e84e2a1f5c2aac19f65c5d9eeea41aa72db917c1029faafdd71789a1a040bdc18f40d63690e22ccae5d7070f194
@@ -15268,7 +15191,7 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@standard-schema/spec@npm:^1.0.0, @standard-schema/spec@npm:^1.1.0": "@standard-schema/spec@npm:^1.0.0":
version: 1.1.0 version: 1.1.0
resolution: "@standard-schema/spec@npm:1.1.0" resolution: "@standard-schema/spec@npm:1.1.0"
checksum: 10/a209615c9e8b2ea535d7db0a5f6aa0f962fd4ab73ee86a46c100fb78116964af1f55a27c1794d4801e534a196794223daa25ff5135021e03c7828aa3d95e1763 checksum: 10/a209615c9e8b2ea535d7db0a5f6aa0f962fd4ab73ee86a46c100fb78116964af1f55a27c1794d4801e534a196794223daa25ff5135021e03c7828aa3d95e1763
@@ -15949,13 +15872,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@trysound/sax@npm:0.2.0":
version: 0.2.0
resolution: "@trysound/sax@npm:0.2.0"
checksum: 10/7379713eca480ac0d9b6c7b063e06b00a7eac57092354556c81027066eb65b61ea141a69d0cc2e15d32e05b2834d4c9c2184793a5e36bbf5daf05ee5676af18c
languageName: node
linkType: hard
"@tweakpane/core@npm:^2.0.4": "@tweakpane/core@npm:^2.0.4":
version: 2.0.5 version: 2.0.5
resolution: "@tweakpane/core@npm:2.0.5" resolution: "@tweakpane/core@npm:2.0.5"
@@ -17695,13 +17611,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@vercel/oidc@npm:3.1.0":
version: 3.1.0
resolution: "@vercel/oidc@npm:3.1.0"
checksum: 10/2e7fe962a441bbc8b305639f8ab1830fb3c2bb51affa90ae84431af65a29c98343aa089d84dff3730013f0b3fb8dc67ad10fad97c4ce7fdf584510d79fa3919c
languageName: node
linkType: hard
"@vitejs/plugin-react-swc@npm:^4.0.0": "@vitejs/plugin-react-swc@npm:^4.0.0":
version: 4.2.3 version: 4.2.3
resolution: "@vitejs/plugin-react-swc@npm:4.2.3" resolution: "@vitejs/plugin-react-swc@npm:4.2.3"
@@ -18284,20 +18193,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"ai@npm:^6.0.118":
version: 6.0.118
resolution: "ai@npm:6.0.118"
dependencies:
"@ai-sdk/gateway": "npm:3.0.68"
"@ai-sdk/provider": "npm:3.0.8"
"@ai-sdk/provider-utils": "npm:4.0.20"
"@opentelemetry/api": "npm:1.9.0"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/ec77fe34a4cfe0e4ac283133fd9e838eea741ed1569598b02a95a54c60113153d88638c95e3f67eed8bb1f307c2cdc8310b5338f50fad76c58e8fb0d2dc457eb
languageName: node
linkType: hard
"ajv-formats@npm:^2.1.1": "ajv-formats@npm:^2.1.1":
version: 2.1.1 version: 2.1.1
resolution: "ajv-formats@npm:2.1.1" resolution: "ajv-formats@npm:2.1.1"
@@ -23259,13 +23154,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"eventsource-parser@npm:^3.0.6":
version: 3.0.6
resolution: "eventsource-parser@npm:3.0.6"
checksum: 10/febf7058b9c2168ecbb33e92711a1646e06bd1568f60b6eb6a01a8bf9f8fcd29cc8320d57247059cacf657a296280159f21306d2e3ff33309a9552b2ef889387
languageName: node
linkType: hard
"exa-js@npm:^2.4.0": "exa-js@npm:^2.4.0":
version: 2.4.0 version: 2.4.0
resolution: "exa-js@npm:2.4.0" resolution: "exa-js@npm:2.4.0"
@@ -24617,7 +24505,7 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"google-auth-library@npm:^10.1.0, google-auth-library@npm:^10.2.0, google-auth-library@npm:^10.5.0": "google-auth-library@npm:^10.1.0, google-auth-library@npm:^10.2.0":
version: 10.5.0 version: 10.5.0
resolution: "google-auth-library@npm:10.5.0" resolution: "google-auth-library@npm:10.5.0"
dependencies: dependencies:
@@ -26588,13 +26476,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"json-schema@npm:^0.4.0":
version: 0.4.0
resolution: "json-schema@npm:0.4.0"
checksum: 10/8b3b64eff4a807dc2a3045b104ed1b9335cd8d57aa74c58718f07f0f48b8baa3293b00af4dcfbdc9144c3aafea1e97982cc27cc8e150fc5d93c540649507a458
languageName: node
linkType: hard
"json-stable-stringify-without-jsonify@npm:^1.0.1": "json-stable-stringify-without-jsonify@npm:^1.0.1":
version: 1.0.1 version: 1.0.1
resolution: "json-stable-stringify-without-jsonify@npm:1.0.1" resolution: "json-stable-stringify-without-jsonify@npm:1.0.1"
@@ -33056,10 +32937,10 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"sax@npm:>=0.6.0, sax@npm:^1.2.4, sax@npm:^1.4.1": "sax@npm:>=0.6.0, sax@npm:^1.2.4, sax@npm:^1.4.1, sax@npm:^1.5.0":
version: 1.4.1 version: 1.5.0
resolution: "sax@npm:1.4.1" resolution: "sax@npm:1.5.0"
checksum: 10/b1c784b545019187b53a0c28edb4f6314951c971e2963a69739c6ce222bfbc767e54d320e689352daba79b7d5e06d22b5d7113b99336219d6e93718e2f99d335 checksum: 10/9012ff37dda7a7ac5da45db2143b04036103e8bef8d586c3023afd5df6caf0ebd7f38017eee344ad2e2247eded7d38e9c42cf291d8dd91781352900ac0fd2d9f
languageName: node languageName: node
linkType: hard linkType: hard
@@ -34375,19 +34256,19 @@ __metadata:
linkType: hard linkType: hard
"svgo@npm:^3.3.2": "svgo@npm:^3.3.2":
version: 3.3.2 version: 3.3.3
resolution: "svgo@npm:3.3.2" resolution: "svgo@npm:3.3.3"
dependencies: dependencies:
"@trysound/sax": "npm:0.2.0"
commander: "npm:^7.2.0" commander: "npm:^7.2.0"
css-select: "npm:^5.1.0" css-select: "npm:^5.1.0"
css-tree: "npm:^2.3.1" css-tree: "npm:^2.3.1"
css-what: "npm:^6.1.0" css-what: "npm:^6.1.0"
csso: "npm:^5.0.5" csso: "npm:^5.0.5"
picocolors: "npm:^1.0.0" picocolors: "npm:^1.0.0"
sax: "npm:^1.5.0"
bin: bin:
svgo: ./bin/svgo svgo: ./bin/svgo
checksum: 10/82fdea9b938884d808506104228e4d3af0050d643d5b46ff7abc903ff47a91bbf6561373394868aaf07a28f006c4057b8fbf14bbd666298abdd7cc590d4f7700 checksum: 10/f3c1b4d05d1704483e53515d5995af5f06a2718df85e3a8320f57bb256b8dc926b84c87a1a9b98e9d3ca1224314cc0676a803bdd03163508292f2d45c7077096
languageName: node languageName: node
linkType: hard linkType: hard
@@ -35311,9 +35192,9 @@ __metadata:
linkType: hard linkType: hard
"underscore@npm:^1.13.1": "underscore@npm:^1.13.1":
version: 1.13.7 version: 1.13.8
resolution: "underscore@npm:1.13.7" resolution: "underscore@npm:1.13.8"
checksum: 10/1ce3368dbe73d1e99678fa5d341a9682bd27316032ad2de7883901918f0f5d50e80320ccc543f53c1862ab057a818abc560462b5f83578afe2dd8dd7f779766c checksum: 10/b50ac5806d059cc180b1bd9adea6f7ed500021f4dc782dfc75d66a90337f6f0506623c1b37863f4a9bf64ffbeb5769b638a54b7f2f5966816189955815953139
languageName: node languageName: node
linkType: hard linkType: hard