feat: refactor copilot module (#14537)

This commit is contained in:
DarkSky
2026-03-02 13:57:55 +08:00
committed by GitHub
parent 60acd81d4b
commit c5d622531c
92 changed files with 5759 additions and 2170 deletions

View File

@@ -988,6 +988,16 @@
}
}
},
"providers.profiles": {
"type": "array",
"description": "The profile list for copilot providers.\n@default []",
"default": []
},
"providers.defaults": {
"type": "object",
"description": "The default provider ids for model output types and global fallback.\n@default {}",
"default": {}
},
"providers.openai": {
"type": "object",
"description": "The config for the openai provider.\n@default {\"apiKey\":\"\",\"baseURL\":\"https://api.openai.com/v1\"}\n@link https://github.com/openai/openai-node",

515
Cargo.lock generated
View File

@@ -181,6 +181,7 @@ dependencies = [
"chrono",
"file-format",
"infer",
"llm_adapter",
"mimalloc",
"mp4parse",
"napi",
@@ -188,6 +189,8 @@ dependencies = [
"napi-derive",
"rand 0.9.2",
"rayon",
"serde",
"serde_json",
"sha3",
"tiktoken-rs",
"tokio",
@@ -245,7 +248,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed7572b7ba83a31e20d1b48970ee402d2e3e0537dcfe0a3ff4d6eb7508617d43"
dependencies = [
"alsa-sys",
"bitflags 2.10.0",
"bitflags 2.11.0",
"cfg-if",
"libc",
]
@@ -458,6 +461,12 @@ dependencies = [
"num-traits",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "auto_enums"
version = "0.8.7"
@@ -476,6 +485,28 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "aws-lc-rs"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9"
dependencies = [
"aws-lc-sys",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.37.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549"
dependencies = [
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]]
name = "base64"
version = "0.22.1"
@@ -533,7 +564,7 @@ version = "0.72.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"cexpr",
"clang-sys",
"itertools 0.13.0",
@@ -583,9 +614,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.10.0"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
dependencies = [
"serde_core",
]
@@ -904,6 +935,15 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
[[package]]
name = "cmake"
version = "0.1.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.4"
@@ -983,7 +1023,7 @@ version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"core-foundation",
"core-graphics-types",
"foreign-types",
@@ -996,7 +1036,7 @@ version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "064badf302c3194842cf2c5d61f56cc88e54a759313879cdf03abdd27d0c3b97"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"core-foundation",
"core-graphics-types",
"foreign-types",
@@ -1009,7 +1049,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"core-foundation",
"libc",
]
@@ -1379,7 +1419,7 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"block2",
"libc",
"objc2",
@@ -1442,6 +1482,12 @@ version = "0.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5"
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "ecb"
version = "0.1.2"
@@ -1664,6 +1710,12 @@ dependencies = [
"autocfg",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futf"
version = "0.1.5"
@@ -1828,9 +1880,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"r-efi",
"wasip2",
"wasm-bindgen",
]
[[package]]
@@ -1880,7 +1934,7 @@ version = "1.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0c43e7c3212bd992c11b6b9796563388170950521ae8487f5cdf6f6e792f1c8"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"proc-macro2",
"quote",
"syn 1.0.109",
@@ -2003,6 +2057,105 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "http"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a"
dependencies = [
"bytes",
"itoa",
]
[[package]]
name = "http-body"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
]
[[package]]
name = "http-body-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"pin-project-lite",
]
[[package]]
name = "httparse"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "hyper"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11"
dependencies = [
"atomic-waker",
"bytes",
"futures-channel",
"futures-core",
"http",
"http-body",
"httparse",
"itoa",
"pin-project-lite",
"pin-utils",
"smallvec",
"tokio",
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0"
dependencies = [
"base64",
"bytes",
"futures-channel",
"futures-util",
"http",
"http-body",
"hyper",
"ipnet",
"libc",
"percent-encoding",
"pin-project-lite",
"socket2",
"tokio",
"tower-service",
"tracing",
]
[[package]]
name = "iana-time-zone"
version = "0.1.64"
@@ -2253,6 +2406,22 @@ dependencies = [
"leaky-cow",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "iri-string"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
dependencies = [
"memchr",
"serde",
]
[[package]]
name = "is-terminal"
version = "0.4.17"
@@ -2376,9 +2545,9 @@ dependencies = [
[[package]]
name = "keccak"
version = "0.1.5"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654"
checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653"
dependencies = [
"cpufeatures",
]
@@ -2490,7 +2659,7 @@ version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"libc",
"redox_syscall 0.7.0",
]
@@ -2518,6 +2687,19 @@ version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
[[package]]
name = "llm_adapter"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8dd9a548766bccf8b636695e8d514edee672d180e96a16ab932c971783b4e353"
dependencies = [
"base64",
"reqwest",
"serde",
"serde_json",
"thiserror 2.0.17",
]
[[package]]
name = "lock_api"
version = "0.4.14"
@@ -2555,7 +2737,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fa2559e99ba0f26a12458aabc754432c805bbb8cba516c427825a997af1fb7"
dependencies = [
"aes",
"bitflags 2.10.0",
"bitflags 2.11.0",
"cbc",
"ecb",
"encoding_rs",
@@ -2583,6 +2765,12 @@ dependencies = [
"hashbrown 0.16.1",
]
[[package]]
name = "lru-slab"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "mac"
version = "0.1.1"
@@ -2741,7 +2929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "000f205daae6646003fdc38517be6232af2b150bad4b67bdaf4c5aadb119d738"
dependencies = [
"anyhow",
"bitflags 2.10.0",
"bitflags 2.11.0",
"chrono",
"ctor",
"futures",
@@ -2801,7 +2989,7 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"jni-sys",
"log",
"ndk-sys",
@@ -2836,7 +3024,7 @@ version = "0.30.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"cfg-if",
"cfg_aliases",
"libc",
@@ -3000,7 +3188,7 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"dispatch2",
"objc2",
]
@@ -3017,7 +3205,7 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"block2",
"libc",
"objc2",
@@ -3074,6 +3262,12 @@ version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "openssl-probe"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]]
name = "ordered-float"
version = "5.1.0"
@@ -3469,7 +3663,7 @@ checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40"
dependencies = [
"bit-set 0.8.0",
"bit-vec 0.8.0",
"bitflags 2.10.0",
"bitflags 2.11.0",
"num-traits",
"rand 0.9.2",
"rand_chacha 0.9.0",
@@ -3497,7 +3691,7 @@ version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"getopts",
"memchr",
"pulldown-cmark-escape",
@@ -3516,6 +3710,62 @@ version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quinn"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
dependencies = [
"bytes",
"cfg_aliases",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.17",
"tokio",
"tracing",
"web-time",
]
[[package]]
name = "quinn-proto"
version = "0.11.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
dependencies = [
"aws-lc-rs",
"bytes",
"getrandom 0.3.4",
"lru-slab",
"rand 0.9.2",
"ring",
"rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.17",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.60.2",
]
[[package]]
name = "quote"
version = "1.0.43"
@@ -3663,7 +3913,7 @@ version = "0.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
]
[[package]]
@@ -3672,7 +3922,7 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
]
[[package]]
@@ -3704,6 +3954,45 @@ version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
[[package]]
name = "reqwest"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801"
dependencies = [
"base64",
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"js-sys",
"log",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-pki-types",
"rustls-platform-verifier",
"serde",
"serde_json",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tower",
"tower-http",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "ring"
version = "0.17.14"
@@ -3831,7 +4120,7 @@ version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
dependencies = [
"bitflags 2.10.0",
"bitflags 2.11.0",
"errno",
"libc",
"linux-raw-sys",
@@ -3844,6 +4133,7 @@ version = "0.23.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
dependencies = [
"aws-lc-rs",
"once_cell",
"ring",
"rustls-pki-types",
@@ -3852,21 +4142,62 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63"
dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pki-types"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282"
dependencies = [
"web-time",
"zeroize",
]
[[package]]
name = "rustls-platform-verifier"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784"
dependencies = [
"core-foundation",
"core-foundation-sys",
"jni",
"log",
"once_cell",
"rustls",
"rustls-native-certs",
"rustls-platform-verifier-android",
"rustls-webpki",
"security-framework",
"security-framework-sys",
"webpki-root-certs",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls-platform-verifier-android"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]]
name = "rustls-webpki"
version = "0.103.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52"
dependencies = [
"aws-lc-rs",
"ring",
"rustls-pki-types",
"untrusted",
@@ -3905,6 +4236,15 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "scoped-tls"
version = "1.0.1"
@@ -3953,6 +4293,29 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "security-framework"
version = "3.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
dependencies = [
"bitflags 2.11.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "semver"
version = "1.0.27"
@@ -4269,7 +4632,7 @@ checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526"
dependencies = [
"atoi",
"base64",
"bitflags 2.10.0",
"bitflags 2.11.0",
"byteorder",
"bytes",
"chrono",
@@ -4312,7 +4675,7 @@ checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46"
dependencies = [
"atoi",
"base64",
"bitflags 2.10.0",
"bitflags 2.11.0",
"byteorder",
"chrono",
"crc",
@@ -4678,6 +5041,15 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
version = "0.13.2"
@@ -4869,6 +5241,16 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
dependencies = [
"rustls",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.18"
@@ -4912,13 +5294,58 @@ dependencies = [
[[package]]
name = "toml_parser"
version = "1.0.6+spec-1.1.0"
version = "1.0.9+spec-1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44"
checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4"
dependencies = [
"winnow",
]
[[package]]
name = "tower"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper",
"tokio",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-http"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
dependencies = [
"bitflags 2.11.0",
"bytes",
"futures-util",
"http",
"http-body",
"iri-string",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]]
name = "tower-service"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracing"
version = "0.1.44"
@@ -5121,6 +5548,12 @@ dependencies = [
"tree-sitter-language",
]
[[package]]
name = "try-lock"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "type1-encoding-parser"
version = "0.1.0"
@@ -5441,6 +5874,15 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
dependencies = [
"try-lock",
]
[[package]]
name = "wasi"
version = "0.11.1+wasi-snapshot-preview1"
@@ -5530,6 +5972,25 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-root-certs"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "webpki-roots"
version = "0.26.11"

View File

@@ -44,6 +44,7 @@ resolver = "3"
lasso = { version = "0.7", features = ["multi-threaded"] }
lib0 = { version = "0.16", features = ["lib0-serde"] }
libc = "0.2"
llm_adapter = "0.1.1"
log = "0.4"
loom = { version = "0.7", features = ["checkpoint"] }
lru = "0.16"

View File

@@ -108,7 +108,9 @@ export class BookmarkBlockComponent extends CaptionedBlockComponent<BookmarkBloc
}
open = () => {
window.open(this.link, '_blank');
const link = this.link;
if (!link) return;
window.open(link, '_blank', 'noopener,noreferrer');
};
refreshData = () => {

View File

@@ -1,4 +1,8 @@
import { getHostName } from '@blocksuite/affine-shared/utils';
import {
getHostName,
isValidUrl,
normalizeUrl,
} from '@blocksuite/affine-shared/utils';
import { PropTypes, requiredProperties } from '@blocksuite/std';
import { css, LitElement } from 'lit';
import { property } from 'lit/decorators.js';
@@ -44,15 +48,27 @@ export class LinkPreview extends LitElement {
override render() {
const { url } = this;
const normalizedUrl = normalizeUrl(url);
const safeUrl =
normalizedUrl && isValidUrl(normalizedUrl) ? normalizedUrl : null;
const hostName = getHostName(safeUrl ?? url);
if (!safeUrl) {
return html`
<span class="affine-link-preview">
<span>${hostName}</span>
</span>
`;
}
return html`
<a
class="affine-link-preview"
rel="noopener noreferrer"
target="_blank"
href=${url}
href=${safeUrl}
>
<span>${getHostName(url)}</span>
<span>${hostName}</span>
</a>
`;
}

View File

@@ -4,6 +4,7 @@ import type { FootNote } from '@blocksuite/affine-model';
import { CitationProvider } from '@blocksuite/affine-shared/services';
import { unsafeCSSVarV2 } from '@blocksuite/affine-shared/theme';
import type { AffineTextAttributes } from '@blocksuite/affine-shared/types';
import { isValidUrl, normalizeUrl } from '@blocksuite/affine-shared/utils';
import { WithDisposable } from '@blocksuite/global/lit';
import {
BlockSelection,
@@ -152,7 +153,9 @@ export class AffineFootnoteNode extends WithDisposable(ShadowlessElement) {
};
private readonly _handleUrlReference = (url: string) => {
window.open(url, '_blank');
const normalizedUrl = normalizeUrl(url);
if (!normalizedUrl || !isValidUrl(normalizedUrl)) return;
window.open(normalizedUrl, '_blank', 'noopener,noreferrer');
};
private readonly _updateFootnoteAttributes = (footnote: FootNote) => {

View File

@@ -24,6 +24,11 @@ const toURL = (str: string) => {
}
};
const hasAllowedScheme = (url: URL) => {
const protocol = url.protocol.slice(0, -1).toLowerCase();
return ALLOWED_SCHEMES.has(protocol);
};
function resolveURL(str: string, baseUrl: string, padded = false) {
const url = toURL(str);
if (!url) return null;
@@ -61,6 +66,7 @@ export function normalizeUrl(str: string) {
// Formatted
if (url) {
if (!hasAllowedScheme(url)) return '';
if (!str.endsWith('/') && url.href.endsWith('/')) {
return url.href.substring(0, url.href.length - 1);
}

View File

@@ -22,7 +22,7 @@
"af": "r affine.ts",
"dev": "yarn affine dev",
"build": "yarn affine build",
"lint:eslint": "cross-env NODE_OPTIONS=\"--max-old-space-size=8192\" eslint --report-unused-disable-directives-severity=off . --cache",
"lint:eslint": "cross-env NODE_OPTIONS=\"--max-old-space-size=16384\" eslint --report-unused-disable-directives-severity=off . --cache",
"lint:eslint:fix": "yarn lint:eslint --fix --fix-type problem,suggestion,layout",
"lint:prettier": "prettier --ignore-unknown --cache --check .",
"lint:prettier:fix": "prettier --ignore-unknown --cache --write .",

View File

@@ -17,10 +17,13 @@ affine_common = { workspace = true, features = [
chrono = { workspace = true }
file-format = { workspace = true }
infer = { workspace = true }
llm_adapter = { workspace = true }
mp4parse = { workspace = true }
napi = { workspace = true, features = ["async"] }
napi-derive = { workspace = true }
rand = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
sha3 = { workspace = true }
tiktoken-rs = { workspace = true }
v_htmlescape = { workspace = true }

View File

@@ -1,5 +1,9 @@
/* auto-generated by NAPI-RS */
/* eslint-disable */
export declare class LlmStreamHandle {
abort(): void
}
export declare class Tokenizer {
count(content: string, allowedSpecial?: Array<string> | undefined | null): number
}
@@ -46,6 +50,10 @@ export declare function getMime(input: Uint8Array): string
export declare function htmlSanitize(input: string): string
export declare function llmDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
export declare function llmDispatchStream(protocol: string, backendConfigJson: string, requestJson: string, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle
/**
* Merge updates in form like `Y.applyUpdate(doc, update)` way and return the
* result binary.

View File

@@ -7,6 +7,7 @@ pub mod doc_loader;
pub mod file_type;
pub mod hashcash;
pub mod html_sanitize;
pub mod llm;
pub mod tiktoken;
use affine_common::napi_utils::map_napi_err;

View File

@@ -0,0 +1,339 @@
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use llm_adapter::{
backend::{
BackendConfig, BackendError, BackendProtocol, ReqwestHttpClient, dispatch_request, dispatch_stream_events_with,
},
core::{CoreRequest, StreamEvent},
middleware::{
MiddlewareConfig, PipelineContext, RequestMiddleware, StreamMiddleware, citation_indexing, clamp_max_tokens,
normalize_messages, run_request_middleware_chain, run_stream_middleware_chain, stream_event_normalize,
tool_schema_rewrite,
},
};
use napi::{
Error, Result, Status,
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
};
use serde::Deserialize;
pub const STREAM_END_MARKER: &str = "__AFFINE_LLM_STREAM_END__";
const STREAM_ABORTED_REASON: &str = "__AFFINE_LLM_STREAM_ABORTED__";
const STREAM_CALLBACK_DISPATCH_FAILED_REASON: &str = "__AFFINE_LLM_STREAM_CALLBACK_DISPATCH_FAILED__";
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
struct LlmMiddlewarePayload {
request: Vec<String>,
stream: Vec<String>,
config: MiddlewareConfig,
}
#[derive(Debug, Clone, Deserialize)]
struct LlmDispatchPayload {
#[serde(flatten)]
request: CoreRequest,
#[serde(default)]
middleware: LlmMiddlewarePayload,
}
#[napi]
pub struct LlmStreamHandle {
aborted: Arc<AtomicBool>,
}
#[napi]
impl LlmStreamHandle {
#[napi]
pub fn abort(&self) {
self.aborted.store(true, Ordering::SeqCst);
}
}
#[napi(catch_unwind)]
pub fn llm_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
let protocol = parse_protocol(&protocol)?;
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
let payload: LlmDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
let request = apply_request_middlewares(payload.request, &payload.middleware)?;
let response =
dispatch_request(&ReqwestHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?;
serde_json::to_string(&response).map_err(map_json_error)
}
#[napi(catch_unwind)]
pub fn llm_dispatch_stream(
protocol: String,
backend_config_json: String,
request_json: String,
callback: ThreadsafeFunction<String, ()>,
) -> Result<LlmStreamHandle> {
let protocol = parse_protocol(&protocol)?;
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
let payload: LlmDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
let request = apply_request_middlewares(payload.request, &payload.middleware)?;
let middleware = payload.middleware.clone();
let aborted = Arc::new(AtomicBool::new(false));
let aborted_in_worker = aborted.clone();
std::thread::spawn(move || {
let chain = match resolve_stream_chain(&middleware.stream) {
Ok(chain) => chain,
Err(error) => {
emit_error_event(&callback, error.reason.clone(), "middleware_error");
let _ = callback.call(
Ok(STREAM_END_MARKER.to_string()),
ThreadsafeFunctionCallMode::NonBlocking,
);
return;
}
};
let mut pipeline = StreamPipeline::new(chain, middleware.config.clone());
let mut aborted_by_user = false;
let mut callback_dispatch_failed = false;
let result = dispatch_stream_events_with(&ReqwestHttpClient::default(), &config, protocol, &request, |event| {
if aborted_in_worker.load(Ordering::Relaxed) {
aborted_by_user = true;
return Err(BackendError::Http(STREAM_ABORTED_REASON.to_string()));
}
for event in pipeline.process(event) {
let status = emit_stream_event(&callback, &event);
if status != Status::Ok {
callback_dispatch_failed = true;
return Err(BackendError::Http(format!(
"{STREAM_CALLBACK_DISPATCH_FAILED_REASON}:{status}"
)));
}
}
Ok(())
});
if !aborted_by_user {
for event in pipeline.finish() {
if aborted_in_worker.load(Ordering::Relaxed) {
aborted_by_user = true;
break;
}
if emit_stream_event(&callback, &event) != Status::Ok {
callback_dispatch_failed = true;
break;
}
}
}
if let Err(error) = result
&& !aborted_by_user
&& !callback_dispatch_failed
&& !is_abort_error(&error)
&& !is_callback_dispatch_failed_error(&error)
{
emit_error_event(&callback, error.to_string(), "dispatch_error");
}
if !callback_dispatch_failed {
let _ = callback.call(
Ok(STREAM_END_MARKER.to_string()),
ThreadsafeFunctionCallMode::NonBlocking,
);
}
});
Ok(LlmStreamHandle { aborted })
}
fn apply_request_middlewares(request: CoreRequest, middleware: &LlmMiddlewarePayload) -> Result<CoreRequest> {
let chain = resolve_request_chain(&middleware.request)?;
Ok(run_request_middleware_chain(request, &middleware.config, &chain))
}
#[derive(Clone)]
struct StreamPipeline {
chain: Vec<StreamMiddleware>,
config: MiddlewareConfig,
context: PipelineContext,
}
impl StreamPipeline {
fn new(chain: Vec<StreamMiddleware>, config: MiddlewareConfig) -> Self {
Self {
chain,
config,
context: PipelineContext::default(),
}
}
fn process(&mut self, event: StreamEvent) -> Vec<StreamEvent> {
run_stream_middleware_chain(event, &mut self.context, &self.config, &self.chain)
}
fn finish(&mut self) -> Vec<StreamEvent> {
self.context.flush_pending_deltas();
self.context.drain_queued_events()
}
}
fn emit_stream_event(callback: &ThreadsafeFunction<String, ()>, event: &StreamEvent) -> Status {
let value = serde_json::to_string(event).unwrap_or_else(|error| {
serde_json::json!({
"type": "error",
"message": format!("failed to serialize stream event: {error}"),
})
.to_string()
});
callback.call(Ok(value), ThreadsafeFunctionCallMode::NonBlocking)
}
fn emit_error_event(callback: &ThreadsafeFunction<String, ()>, message: String, code: &str) {
let error_event = serde_json::to_string(&StreamEvent::Error {
message: message.clone(),
code: Some(code.to_string()),
})
.unwrap_or_else(|_| {
serde_json::json!({
"type": "error",
"message": message,
"code": code,
})
.to_string()
});
let _ = callback.call(Ok(error_event), ThreadsafeFunctionCallMode::NonBlocking);
}
fn is_abort_error(error: &BackendError) -> bool {
matches!(
error,
BackendError::Http(reason) if reason == STREAM_ABORTED_REASON
)
}
fn is_callback_dispatch_failed_error(error: &BackendError) -> bool {
matches!(
error,
BackendError::Http(reason) if reason.starts_with(STREAM_CALLBACK_DISPATCH_FAILED_REASON)
)
}
fn resolve_request_chain(request: &[String]) -> Result<Vec<RequestMiddleware>> {
if request.is_empty() {
return Ok(vec![normalize_messages, tool_schema_rewrite]);
}
request
.iter()
.map(|name| match name.as_str() {
"normalize_messages" => Ok(normalize_messages as RequestMiddleware),
"clamp_max_tokens" => Ok(clamp_max_tokens as RequestMiddleware),
"tool_schema_rewrite" => Ok(tool_schema_rewrite as RequestMiddleware),
_ => Err(Error::new(
Status::InvalidArg,
format!("Unsupported request middleware: {name}"),
)),
})
.collect()
}
fn resolve_stream_chain(stream: &[String]) -> Result<Vec<StreamMiddleware>> {
if stream.is_empty() {
return Ok(vec![stream_event_normalize, citation_indexing]);
}
stream
.iter()
.map(|name| match name.as_str() {
"stream_event_normalize" => Ok(stream_event_normalize as StreamMiddleware),
"citation_indexing" => Ok(citation_indexing as StreamMiddleware),
_ => Err(Error::new(
Status::InvalidArg,
format!("Unsupported stream middleware: {name}"),
)),
})
.collect()
}
fn parse_protocol(protocol: &str) -> Result<BackendProtocol> {
match protocol {
"openai_chat" | "openai-chat" | "openai_chat_completions" | "chat-completions" | "chat_completions" => {
Ok(BackendProtocol::OpenaiChatCompletions)
}
"openai_responses" | "openai-responses" | "responses" => Ok(BackendProtocol::OpenaiResponses),
"anthropic" | "anthropic_messages" | "anthropic-messages" => Ok(BackendProtocol::AnthropicMessages),
other => Err(Error::new(
Status::InvalidArg,
format!("Unsupported llm backend protocol: {other}"),
)),
}
}
fn map_json_error(error: serde_json::Error) -> Error {
Error::new(Status::InvalidArg, format!("Invalid JSON payload: {error}"))
}
fn map_backend_error(error: BackendError) -> Error {
Error::new(Status::GenericFailure, error.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn should_parse_supported_protocol_aliases() {
assert!(parse_protocol("openai_chat").is_ok());
assert!(parse_protocol("chat-completions").is_ok());
assert!(parse_protocol("responses").is_ok());
assert!(parse_protocol("anthropic").is_ok());
}
#[test]
fn should_reject_unsupported_protocol() {
let error = parse_protocol("unknown").unwrap_err();
assert_eq!(error.status, Status::InvalidArg);
assert!(error.reason.contains("Unsupported llm backend protocol"));
}
#[test]
fn llm_dispatch_should_reject_invalid_backend_json() {
let error = llm_dispatch("openai_chat".to_string(), "{".to_string(), "{}".to_string()).unwrap_err();
assert_eq!(error.status, Status::InvalidArg);
assert!(error.reason.contains("Invalid JSON payload"));
}
#[test]
fn map_json_error_should_use_invalid_arg_status() {
let parse_error = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
let error = map_json_error(parse_error);
assert_eq!(error.status, Status::InvalidArg);
assert!(error.reason.contains("Invalid JSON payload"));
}
#[test]
fn resolve_request_chain_should_support_clamp_max_tokens() {
let chain = resolve_request_chain(&["normalize_messages".to_string(), "clamp_max_tokens".to_string()]).unwrap();
assert_eq!(chain.len(), 2);
}
#[test]
fn resolve_request_chain_should_reject_unknown_middleware() {
let error = resolve_request_chain(&["unknown".to_string()]).unwrap_err();
assert_eq!(error.status, Status::InvalidArg);
assert!(error.reason.contains("Unsupported request middleware"));
}
#[test]
fn resolve_stream_chain_should_reject_unknown_middleware() {
let error = resolve_stream_chain(&["unknown".to_string()]).unwrap_err();
assert_eq!(error.status, Status::InvalidArg);
assert!(error.reason.contains("Unsupported stream middleware"));
}
}

View File

@@ -12,9 +12,9 @@
"dev": "nodemon ./src/index.ts",
"dev:mail": "email dev -d src/mails",
"test": "ava --concurrency 1 --serial",
"test:copilot": "ava \"src/__tests__/copilot-*.spec.ts\"",
"test:copilot": "ava \"src/__tests__/copilot/copilot-*.spec.ts\"",
"test:coverage": "c8 ava --concurrency 1 --serial",
"test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot-*.spec.ts\"",
"test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot/copilot-*.spec.ts\"",
"e2e": "cross-env TEST_MODE=e2e ava --serial",
"e2e:coverage": "cross-env TEST_MODE=e2e c8 ava --serial",
"data-migration": "cross-env NODE_ENV=development SERVER_FLAVOR=script r ./src/index.ts",
@@ -28,12 +28,8 @@
"dependencies": {
"@affine/s3-compat": "workspace:*",
"@affine/server-native": "workspace:*",
"@ai-sdk/anthropic": "^2.0.54",
"@ai-sdk/google": "^2.0.45",
"@ai-sdk/google-vertex": "^3.0.88",
"@ai-sdk/openai": "^2.0.80",
"@ai-sdk/openai-compatible": "^1.0.28",
"@ai-sdk/perplexity": "^2.0.21",
"@apollo/server": "^4.13.0",
"@fal-ai/serverless-client": "^0.15.0",
"@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0",

View File

@@ -43,7 +43,9 @@ Generated by [AVA](https://avajs.dev).
> Snapshot 5
Buffer @Uint8Array [
66616b65 20696d61 6765
89504e47 0d0a1a0a 0000000d 49484452 00000001 00000001 08040000 00b51c0c
02000000 0b494441 5478da63 fcff1f00 03030200 efa37c9f 00000000 49454e44
ae426082
]
## should preview link

View File

@@ -12,12 +12,12 @@ Generated by [AVA](https://avajs.dev).
{
messages: [
{
content: 'generate text to text',
content: 'generate text to text stream',
role: 'assistant',
},
],
pinned: false,
tokens: 8,
tokens: 10,
},
]
@@ -27,12 +27,12 @@ Generated by [AVA](https://avajs.dev).
{
messages: [
{
content: 'generate text to text',
content: 'generate text to text stream',
role: 'assistant',
},
],
pinned: false,
tokens: 8,
tokens: 10,
},
]

View File

@@ -4,31 +4,31 @@ import type { ExecutionContext, TestFn } from 'ava';
import ava from 'ava';
import { z } from 'zod';
import { ServerFeature, ServerService } from '../core';
import { AuthService } from '../core/auth';
import { QuotaModule } from '../core/quota';
import { Models } from '../models';
import { CopilotModule } from '../plugins/copilot';
import { prompts, PromptService } from '../plugins/copilot/prompt';
import { ServerFeature, ServerService } from '../../core';
import { AuthService } from '../../core/auth';
import { QuotaModule } from '../../core/quota';
import { Models } from '../../models';
import { CopilotModule } from '../../plugins/copilot';
import { prompts, PromptService } from '../../plugins/copilot/prompt';
import {
CopilotProviderFactory,
CopilotProviderType,
StreamObject,
StreamObjectSchema,
} from '../plugins/copilot/providers';
import { TranscriptionResponseSchema } from '../plugins/copilot/transcript/types';
} from '../../plugins/copilot/providers';
import { TranscriptionResponseSchema } from '../../plugins/copilot/transcript/types';
import {
CopilotChatTextExecutor,
CopilotWorkflowService,
GraphExecutorState,
} from '../plugins/copilot/workflow';
} from '../../plugins/copilot/workflow';
import {
CopilotChatImageExecutor,
CopilotCheckHtmlExecutor,
CopilotCheckJsonExecutor,
} from '../plugins/copilot/workflow/executor';
import { createTestingModule, TestingModule } from './utils';
import { TestAssets } from './utils/copilot';
} from '../../plugins/copilot/workflow/executor';
import { createTestingModule, TestingModule } from '../utils';
import { TestAssets } from '../utils/copilot';
type Tester = {
auth: AuthService;

View File

@@ -6,25 +6,25 @@ import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
import { AppModule } from '../app.module';
import { JobQueue } from '../base';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { DocReader } from '../core/doc';
import { CopilotContextService } from '../plugins/copilot/context';
import { AppModule } from '../../app.module';
import { JobQueue } from '../../base';
import { ConfigModule } from '../../base/config';
import { AuthService } from '../../core/auth';
import { DocReader } from '../../core/doc';
import { CopilotContextService } from '../../plugins/copilot/context';
import {
CopilotEmbeddingJob,
MockEmbeddingClient,
} from '../plugins/copilot/embedding';
import { prompts, PromptService } from '../plugins/copilot/prompt';
} from '../../plugins/copilot/embedding';
import { prompts, PromptService } from '../../plugins/copilot/prompt';
import {
CopilotProviderFactory,
CopilotProviderType,
GeminiGenerativeProvider,
OpenAIProvider,
} from '../plugins/copilot/providers';
import { CopilotStorage } from '../plugins/copilot/storage';
import { MockCopilotProvider } from './mocks';
} from '../../plugins/copilot/providers';
import { CopilotStorage } from '../../plugins/copilot/storage';
import { MockCopilotProvider } from '../mocks';
import {
acceptInviteById,
createTestingApp,
@@ -33,7 +33,7 @@ import {
smallestPng,
TestingApp,
TestUser,
} from './utils';
} from '../utils';
import {
addContextDoc,
addContextFile,
@@ -67,7 +67,7 @@ import {
textToEventStream,
unsplashSearch,
updateCopilotSession,
} from './utils/copilot';
} from '../utils/copilot';
const test = ava as TestFn<{
auth: AuthService;
@@ -513,7 +513,11 @@ test('should be able to chat with api', async t => {
);
const messageId = await createCopilotMessage(app, sessionId);
const ret = await chatWithText(app, sessionId, messageId);
t.is(ret, 'generate text to text', 'should be able to chat with text');
t.is(
ret,
'generate text to text stream',
'should be able to chat with text'
);
const ret2 = await chatWithTextStream(app, sessionId, messageId);
t.is(
@@ -657,7 +661,7 @@ test('should be able to retry with api', async t => {
const histories = await getHistories(app, { workspaceId: id, docId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text', 'generate text to text']],
[['generate text to text stream', 'generate text to text stream']],
'should be able to list history'
);
}
@@ -794,7 +798,7 @@ test('should be able to list history', async t => {
const histories = await getHistories(app, { workspaceId, docId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['hello', 'generate text to text']],
[['hello', 'generate text to text stream']],
'should be able to list history'
);
}
@@ -807,7 +811,7 @@ test('should be able to list history', async t => {
});
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text', 'hello']],
[['generate text to text stream', 'hello']],
'should be able to list history'
);
}
@@ -858,7 +862,7 @@ test('should reject request that user have not permission', async t => {
const histories = await getHistories(app, { workspaceId, docId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text']],
[['generate text to text stream']],
'should able to list history'
);

View File

@@ -8,38 +8,38 @@ import ava from 'ava';
import { nanoid } from 'nanoid';
import Sinon from 'sinon';
import { EventBus, JobQueue } from '../base';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
import { QuotaModule } from '../core/quota';
import { StorageModule, WorkspaceBlobStorage } from '../core/storage';
import { EventBus, JobQueue } from '../../base';
import { ConfigModule } from '../../base/config';
import { AuthService } from '../../core/auth';
import { QuotaModule } from '../../core/quota';
import { StorageModule, WorkspaceBlobStorage } from '../../core/storage';
import {
ContextCategories,
CopilotSessionModel,
WorkspaceModel,
} from '../models';
import { CopilotModule } from '../plugins/copilot';
import { CopilotContextService } from '../plugins/copilot/context';
import { CopilotCronJobs } from '../plugins/copilot/cron';
} from '../../models';
import { CopilotModule } from '../../plugins/copilot';
import { CopilotContextService } from '../../plugins/copilot/context';
import { CopilotCronJobs } from '../../plugins/copilot/cron';
import {
CopilotEmbeddingJob,
MockEmbeddingClient,
} from '../plugins/copilot/embedding';
import { prompts, PromptService } from '../plugins/copilot/prompt';
} from '../../plugins/copilot/embedding';
import { prompts, PromptService } from '../../plugins/copilot/prompt';
import {
CopilotProviderFactory,
CopilotProviderType,
ModelInputType,
ModelOutputType,
OpenAIProvider,
} from '../plugins/copilot/providers';
} from '../../plugins/copilot/providers';
import {
CitationParser,
TextStreamParser,
} from '../plugins/copilot/providers/utils';
import { ChatSessionService } from '../plugins/copilot/session';
import { CopilotStorage } from '../plugins/copilot/storage';
import { CopilotTranscriptionService } from '../plugins/copilot/transcript';
} from '../../plugins/copilot/providers/utils';
import { ChatSessionService } from '../../plugins/copilot/session';
import { CopilotStorage } from '../../plugins/copilot/storage';
import { CopilotTranscriptionService } from '../../plugins/copilot/transcript';
import {
CopilotChatTextExecutor,
CopilotWorkflowService,
@@ -48,7 +48,7 @@ import {
WorkflowGraphExecutor,
type WorkflowNodeData,
WorkflowNodeType,
} from '../plugins/copilot/workflow';
} from '../../plugins/copilot/workflow';
import {
CopilotChatImageExecutor,
CopilotCheckHtmlExecutor,
@@ -56,16 +56,16 @@ import {
getWorkflowExecutor,
NodeExecuteState,
NodeExecutorType,
} from '../plugins/copilot/workflow/executor';
import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils';
import { WorkflowGraphList } from '../plugins/copilot/workflow/graph';
import { CopilotWorkspaceService } from '../plugins/copilot/workspace';
import { PaymentModule } from '../plugins/payment';
import { SubscriptionService } from '../plugins/payment/service';
import { SubscriptionStatus } from '../plugins/payment/types';
import { MockCopilotProvider } from './mocks';
import { createTestingModule, TestingModule } from './utils';
import { WorkflowTestCases } from './utils/copilot';
} from '../../plugins/copilot/workflow/executor';
import { AutoRegisteredWorkflowExecutor } from '../../plugins/copilot/workflow/executor/utils';
import { WorkflowGraphList } from '../../plugins/copilot/workflow/graph';
import { CopilotWorkspaceService } from '../../plugins/copilot/workspace';
import { PaymentModule } from '../../plugins/payment';
import { SubscriptionService } from '../../plugins/payment/service';
import { SubscriptionStatus } from '../../plugins/payment/types';
import { MockCopilotProvider } from '../mocks';
import { createTestingModule, TestingModule } from '../utils';
import { WorkflowTestCases } from '../utils/copilot';
type Context = {
auth: AuthService;
@@ -364,6 +364,21 @@ test('should be able to manage chat session', async t => {
});
t.is(newSessionId, sessionId, 'should get same session id');
}
// should create a fresh session when reuseLatestChat is explicitly disabled
{
const newSessionId = await session.create({
userId,
promptName,
...commonParams,
reuseLatestChat: false,
});
t.not(
newSessionId,
sessionId,
'should create new session id when reuseLatestChat is false'
);
}
});
test('should be able to update chat session prompt', async t => {
@@ -881,6 +896,26 @@ test('should be able to get provider', async t => {
}
});
test('should resolve provider by prefixed model id', async t => {
const { factory } = t.context;
const provider = await factory.getProviderByModel('openai-default/test');
t.truthy(provider, 'should resolve prefixed model id');
t.is(provider?.type, CopilotProviderType.OpenAI);
const result = await provider?.text({ modelId: 'openai-default/test' }, [
{ role: 'user', content: 'hello' },
]);
t.is(result, 'generate text to text');
});
test('should fallback to null when prefixed provider id does not exist', async t => {
const { factory } = t.context;
const provider = await factory.getProviderByModel('unknown/test');
t.is(provider, null);
});
// ==================== workflow ====================
// this test used to preview the final result of the workflow
@@ -2063,25 +2098,23 @@ test('should handle copilot cron jobs correctly', async t => {
});
test('should resolve model correctly based on subscription status and prompt config', async t => {
const { db, session, subscription } = t.context;
const { prompt, session, subscription } = t.context;
// 1) Seed a prompt that has optionalModels and proModels in config
const promptName = 'resolve-model-test';
await db.aiPrompt.create({
data: {
name: promptName,
model: 'gemini-2.5-flash',
messages: {
create: [{ idx: 0, role: 'system', content: 'test' }],
},
config: { proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'] },
await prompt.set(
promptName,
'gemini-2.5-flash',
[{ role: 'system', content: 'test' }],
{ proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'] },
{
optionalModels: [
'gemini-2.5-flash',
'gemini-2.5-pro',
'claude-sonnet-4-5@20250929',
],
},
});
}
);
// 2) Create a chat session with this prompt
const sessionId = await session.create({
@@ -2106,6 +2139,16 @@ test('should resolve model correctly based on subscription status and prompt con
const model1 = await s.resolveModel(false, 'gemini-2.5-pro');
t.snapshot(model1, 'should honor requested pro model');
const model1WithPrefix = await s.resolveModel(
false,
'openai-default/gemini-2.5-pro'
);
t.is(
model1WithPrefix,
'openai-default/gemini-2.5-pro',
'should honor requested prefixed pro model'
);
const model2 = await s.resolveModel(false, 'not-in-optional');
t.snapshot(model2, 'should fallback to default model');
}
@@ -2119,6 +2162,16 @@ test('should resolve model correctly based on subscription status and prompt con
'should fallback to default model when requesting pro model during trialing'
);
const model3WithPrefix = await s.resolveModel(
true,
'openai-default/gemini-2.5-pro'
);
t.is(
model3WithPrefix,
'gemini-2.5-flash',
'should fallback to default model when requesting prefixed pro model during trialing'
);
const model4 = await s.resolveModel(true, 'gemini-2.5-flash');
t.snapshot(model4, 'should honor requested non-pro model during trialing');
@@ -2141,6 +2194,16 @@ test('should resolve model correctly based on subscription status and prompt con
const model7 = await s.resolveModel(true, 'claude-sonnet-4-5@20250929');
t.snapshot(model7, 'should honor requested pro model during active');
const model7WithPrefix = await s.resolveModel(
true,
'openai-default/claude-sonnet-4-5@20250929'
);
t.is(
model7WithPrefix,
'openai-default/claude-sonnet-4-5@20250929',
'should honor requested prefixed pro model during active'
);
const model8 = await s.resolveModel(true, 'not-in-optional');
t.snapshot(
model8,

View File

@@ -0,0 +1,210 @@
import test from 'ava';
import { z } from 'zod';
import type { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
import {
buildNativeRequest,
NativeProviderAdapter,
} from '../../plugins/copilot/providers/native';
const mockDispatch = () =>
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
yield { type: 'text_delta', text: 'Use [^1] now' };
yield { type: 'citation', index: 1, url: 'https://affine.pro' };
yield { type: 'done', finish_reason: 'stop' };
})();
test('NativeProviderAdapter streamText should append citation footnotes', async t => {
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
const chunks: string[] = [];
for await (const chunk of adapter.streamText({
model: 'gpt-4.1',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
})) {
chunks.push(chunk);
}
const text = chunks.join('');
t.true(text.includes('Use [^1] now'));
t.true(
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
);
});
test('NativeProviderAdapter streamObject should append citation footnotes', async t => {
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
const chunks = [];
for await (const chunk of adapter.streamObject({
model: 'gpt-4.1',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
})) {
chunks.push(chunk);
}
t.deepEqual(
chunks.map(chunk => chunk.type),
['text-delta', 'text-delta']
);
const text = chunks
.filter(chunk => chunk.type === 'text-delta')
.map(chunk => chunk.textDelta)
.join('');
t.true(text.includes('Use [^1] now'));
t.true(
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
);
});
test('NativeProviderAdapter streamObject should append fallback attachment footnotes', async t => {
const dispatch = () =>
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
yield {
type: 'tool_result',
call_id: 'call_1',
name: 'blob_read',
arguments: { blob_id: 'blob_1' },
output: {
blobId: 'blob_1',
fileName: 'a.txt',
fileType: 'text/plain',
content: 'A',
},
};
yield {
type: 'tool_result',
call_id: 'call_2',
name: 'blob_read',
arguments: { blob_id: 'blob_2' },
output: {
blobId: 'blob_2',
fileName: 'b.txt',
fileType: 'text/plain',
content: 'B',
},
};
yield { type: 'text_delta', text: 'Answer from files.' };
yield { type: 'done', finish_reason: 'stop' };
})();
const adapter = new NativeProviderAdapter(dispatch, {}, 3);
const chunks = [];
for await (const chunk of adapter.streamObject({
model: 'gpt-4.1',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
})) {
chunks.push(chunk);
}
const text = chunks
.filter(chunk => chunk.type === 'text-delta')
.map(chunk => chunk.textDelta)
.join('');
t.true(text.includes('Answer from files.'));
t.true(text.includes('[^1][^2]'));
t.true(
text.includes(
'[^1]: {"type":"attachment","blobId":"blob_1","fileName":"a.txt","fileType":"text/plain"}'
)
);
t.true(
text.includes(
'[^2]: {"type":"attachment","blobId":"blob_2","fileName":"b.txt","fileType":"text/plain"}'
)
);
});
test('NativeProviderAdapter streamObject should map tool and text events', async t => {
let round = 0;
const dispatch = (_request: NativeLlmRequest) =>
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
round += 1;
if (round === 1) {
yield {
type: 'tool_call',
call_id: 'call_1',
name: 'doc_read',
arguments: { doc_id: 'a1' },
};
yield { type: 'done', finish_reason: 'tool_calls' };
return;
}
yield { type: 'text_delta', text: 'ok' };
yield { type: 'done', finish_reason: 'stop' };
})();
const adapter = new NativeProviderAdapter(
dispatch,
{
doc_read: {
inputSchema: z.object({ doc_id: z.string() }),
execute: async () => ({ markdown: '# a1' }),
},
},
4
);
const events = [];
for await (const event of adapter.streamObject({
model: 'gpt-4.1',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'read' }] }],
})) {
events.push(event);
}
t.deepEqual(
events.map(event => event.type),
['tool-call', 'tool-result', 'text-delta']
);
t.deepEqual(events[0], {
type: 'tool-call',
toolCallId: 'call_1',
toolName: 'doc_read',
args: { doc_id: 'a1' },
});
});
test('buildNativeRequest should include rust middleware from profile', async t => {
const { request } = await buildNativeRequest({
model: 'gpt-4.1',
messages: [{ role: 'user', content: 'hello' }],
tools: {},
middleware: {
rust: {
request: ['normalize_messages', 'clamp_max_tokens'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: {
text: ['callout'],
},
},
});
t.deepEqual(request.middleware, {
request: ['normalize_messages', 'clamp_max_tokens'],
stream: ['stream_event_normalize', 'citation_indexing'],
});
});
test('NativeProviderAdapter streamText should skip citation footnotes when disabled', async t => {
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3, {
nodeTextMiddleware: ['callout'],
});
const chunks: string[] = [];
for await (const chunk of adapter.streamText({
model: 'gpt-4.1',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
})) {
chunks.push(chunk);
}
const text = chunks.join('');
t.true(text.includes('Use [^1] now'));
t.false(
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
);
});

View File

@@ -0,0 +1,56 @@
import test from 'ava';
import { resolveProviderMiddleware } from '../../plugins/copilot/providers/provider-middleware';
import { buildProviderRegistry } from '../../plugins/copilot/providers/provider-registry';
import { CopilotProviderType } from '../../plugins/copilot/providers/types';
test('resolveProviderMiddleware should include anthropic defaults', t => {
const middleware = resolveProviderMiddleware(CopilotProviderType.Anthropic);
t.deepEqual(middleware.rust?.request, [
'normalize_messages',
'tool_schema_rewrite',
]);
t.deepEqual(middleware.rust?.stream, [
'stream_event_normalize',
'citation_indexing',
]);
t.deepEqual(middleware.node?.text, ['citation_footnote', 'callout']);
});
test('resolveProviderMiddleware should merge defaults and overrides', t => {
const middleware = resolveProviderMiddleware(CopilotProviderType.OpenAI, {
rust: { request: ['clamp_max_tokens'] },
node: { text: ['thinking_format'] },
});
t.deepEqual(middleware.rust?.request, [
'normalize_messages',
'clamp_max_tokens',
]);
t.deepEqual(middleware.node?.text, [
'citation_footnote',
'callout',
'thinking_format',
]);
});
test('buildProviderRegistry should normalize profile middleware defaults', t => {
const registry = buildProviderRegistry({
profiles: [
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
config: { apiKey: '1' },
},
],
});
const profile = registry.profiles.get('openai-main');
t.truthy(profile);
t.deepEqual(profile?.middleware.rust?.stream, [
'stream_event_normalize',
'citation_indexing',
]);
t.deepEqual(profile?.middleware.node?.text, ['citation_footnote', 'callout']);
});

View File

@@ -0,0 +1,99 @@
import test from 'ava';
import { ProviderMiddlewareConfig } from '../../plugins/copilot/config';
import { CopilotProvider } from '../../plugins/copilot/providers/provider';
import {
CopilotProviderType,
ModelInputType,
ModelOutputType,
} from '../../plugins/copilot/providers/types';
class TestOpenAIProvider extends CopilotProvider<{ apiKey: string }> {
readonly type = CopilotProviderType.OpenAI;
readonly models = [
{
id: 'gpt-4.1',
capabilities: [
{
input: [ModelInputType.Text],
output: [ModelOutputType.Text],
defaultForOutputType: true,
},
],
},
];
configured() {
return true;
}
async text(_cond: any, _messages: any[], _options?: any) {
return '';
}
async *streamText(_cond: any, _messages: any[], _options?: any) {
yield '';
}
exposeMetricLabels() {
return this.metricLabels('gpt-4.1');
}
exposeMiddleware() {
return this.getActiveProviderMiddleware();
}
}
function createProvider(profileMiddleware?: ProviderMiddlewareConfig) {
const provider = new TestOpenAIProvider();
(provider as any).AFFiNEConfig = {
copilot: {
providers: {
profiles: [
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
config: { apiKey: 'test' },
middleware: profileMiddleware,
},
],
defaults: {},
openai: { apiKey: 'legacy' },
},
},
};
return provider;
}
test('metricLabels should include active provider id', t => {
const provider = createProvider();
const labels = provider.runWithProfile('openai-main', () =>
provider.exposeMetricLabels()
);
t.is(labels.providerId, 'openai-main');
});
test('getActiveProviderMiddleware should merge defaults with profile override', t => {
const provider = createProvider({
rust: { request: ['clamp_max_tokens'] },
node: { text: ['thinking_format'] },
});
const middleware = provider.runWithProfile('openai-main', () =>
provider.exposeMiddleware()
);
t.deepEqual(middleware.rust?.request, [
'normalize_messages',
'clamp_max_tokens',
]);
t.deepEqual(middleware.rust?.stream, [
'stream_event_normalize',
'citation_indexing',
]);
t.deepEqual(middleware.node?.text, [
'citation_footnote',
'callout',
'thinking_format',
]);
});

View File

@@ -0,0 +1,165 @@
import test from 'ava';
import {
buildProviderRegistry,
resolveModel,
stripProviderPrefix,
} from '../../plugins/copilot/providers/provider-registry';
import {
CopilotProviderType,
ModelOutputType,
} from '../../plugins/copilot/providers/types';
test('buildProviderRegistry should keep explicit profile over legacy compatibility profile', t => {
const registry = buildProviderRegistry({
profiles: [
{
id: 'openai-default',
type: CopilotProviderType.OpenAI,
priority: 100,
config: { apiKey: 'new' },
},
],
openai: { apiKey: 'legacy' },
});
const profile = registry.profiles.get('openai-default');
t.truthy(profile);
t.deepEqual(profile?.config, { apiKey: 'new' });
});
test('buildProviderRegistry should reject duplicated profile ids', t => {
const error = t.throws(() =>
buildProviderRegistry({
profiles: [
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
config: { apiKey: '1' },
},
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
config: { apiKey: '2' },
},
],
})
) as Error;
t.truthy(error);
t.regex(error.message, /Duplicated copilot provider profile id/);
});
test('buildProviderRegistry should reject defaults that reference unknown providers', t => {
const error = t.throws(() =>
buildProviderRegistry({
profiles: [
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
config: { apiKey: '1' },
},
],
defaults: {
fallback: 'unknown-provider',
},
})
) as Error;
t.truthy(error);
t.regex(error.message, /defaults references unknown providerId/);
});
test('resolveModel should support explicit provider prefix and keep slash models untouched', t => {
const registry = buildProviderRegistry({
profiles: [
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
config: { apiKey: '1' },
},
{
id: 'fal-main',
type: CopilotProviderType.FAL,
config: { apiKey: '2' },
},
],
});
const prefixed = resolveModel({
registry,
modelId: 'openai-main/gpt-4.1',
});
t.deepEqual(prefixed, {
rawModelId: 'openai-main/gpt-4.1',
modelId: 'gpt-4.1',
explicitProviderId: 'openai-main',
candidateProviderIds: ['openai-main'],
});
const slashModel = resolveModel({
registry,
modelId: 'lora/image-to-image',
});
t.is(slashModel.modelId, 'lora/image-to-image');
t.false(slashModel.candidateProviderIds.includes('lora'));
});
test('resolveModel should follow defaults -> fallback -> order and apply filters', t => {
const registry = buildProviderRegistry({
profiles: [
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
priority: 10,
config: { apiKey: '1' },
},
{
id: 'anthropic-main',
type: CopilotProviderType.Anthropic,
priority: 5,
config: { apiKey: '2' },
},
{
id: 'fal-main',
type: CopilotProviderType.FAL,
priority: 1,
config: { apiKey: '3' },
},
],
defaults: {
[ModelOutputType.Text]: 'anthropic-main',
fallback: 'openai-main',
},
});
const routed = resolveModel({
registry,
outputType: ModelOutputType.Text,
preferredProviderIds: ['openai-main', 'fal-main'],
});
t.deepEqual(routed.candidateProviderIds, ['openai-main', 'fal-main']);
});
test('stripProviderPrefix should only strip matched provider prefix', t => {
const registry = buildProviderRegistry({
profiles: [
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
config: { apiKey: '1' },
},
],
});
t.is(
stripProviderPrefix(registry, 'openai-main', 'openai-main/gpt-4.1'),
'gpt-4.1'
);
t.is(
stripProviderPrefix(registry, 'openai-main', 'another-main/gpt-4.1'),
'another-main/gpt-4.1'
);
t.is(stripProviderPrefix(registry, 'openai-main', 'gpt-4.1'), 'gpt-4.1');
});

View File

@@ -0,0 +1,134 @@
import test from 'ava';
import { z } from 'zod';
import { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
import {
ToolCallAccumulator,
ToolCallLoop,
ToolSchemaExtractor,
} from '../../plugins/copilot/providers/loop';
test('ToolCallAccumulator should merge deltas and complete tool call', t => {
const accumulator = new ToolCallAccumulator();
accumulator.feedDelta({
type: 'tool_call_delta',
call_id: 'call_1',
name: 'doc_read',
arguments_delta: '{"doc_id":"',
});
accumulator.feedDelta({
type: 'tool_call_delta',
call_id: 'call_1',
arguments_delta: 'a1"}',
});
const completed = accumulator.complete({
type: 'tool_call',
call_id: 'call_1',
name: 'doc_read',
arguments: { doc_id: 'a1' },
});
t.deepEqual(completed, {
id: 'call_1',
name: 'doc_read',
args: { doc_id: 'a1' },
thought: undefined,
});
});
test('ToolSchemaExtractor should convert zod schema to json schema', t => {
const toolSet = {
doc_read: {
description: 'Read doc',
inputSchema: z.object({
doc_id: z.string(),
limit: z.number().optional(),
}),
execute: async () => ({}),
},
};
const extracted = ToolSchemaExtractor.extract(toolSet);
t.deepEqual(extracted, [
{
name: 'doc_read',
description: 'Read doc',
parameters: {
type: 'object',
properties: {
doc_id: { type: 'string' },
limit: { type: 'number' },
},
additionalProperties: false,
required: ['doc_id'],
},
},
]);
});
test('ToolCallLoop should execute tool call and continue to next round', async t => {
const dispatchRequests: NativeLlmRequest[] = [];
const dispatch = (request: NativeLlmRequest) => {
dispatchRequests.push(request);
const round = dispatchRequests.length;
return (async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
if (round === 1) {
yield {
type: 'tool_call_delta',
call_id: 'call_1',
name: 'doc_read',
arguments_delta: '{"doc_id":"a1"}',
};
yield {
type: 'tool_call',
call_id: 'call_1',
name: 'doc_read',
arguments: { doc_id: 'a1' },
};
yield { type: 'done', finish_reason: 'tool_calls' };
return;
}
yield { type: 'text_delta', text: 'done' };
yield { type: 'done', finish_reason: 'stop' };
})();
};
let executedArgs: Record<string, unknown> | null = null;
const loop = new ToolCallLoop(
dispatch,
{
doc_read: {
inputSchema: z.object({ doc_id: z.string() }),
execute: async args => {
executedArgs = args;
return { markdown: '# doc' };
},
},
},
4
);
const events: NativeLlmStreamEvent[] = [];
for await (const event of loop.run({
model: 'gpt-4.1',
stream: true,
messages: [{ role: 'user', content: [{ type: 'text', text: 'read doc' }] }],
})) {
events.push(event);
}
t.deepEqual(executedArgs, { doc_id: 'a1' });
t.true(
dispatchRequests[1]?.messages.some(message => message.role === 'tool')
);
t.deepEqual(
events.map(event => event.type),
['tool_call', 'tool_result', 'text_delta', 'done']
);
});

View File

@@ -0,0 +1,116 @@
import test from 'ava';
import { z } from 'zod';
import {
chatToGPTMessage,
CitationFootnoteFormatter,
CitationParser,
StreamPatternParser,
} from '../../plugins/copilot/providers/utils';
test('CitationFootnoteFormatter should format sorted footnotes from citation events', t => {
const formatter = new CitationFootnoteFormatter();
formatter.consume({
type: 'citation',
index: 2,
url: 'https://example.com/b',
});
formatter.consume({
type: 'citation',
index: 1,
url: 'https://example.com/a',
});
t.is(
formatter.end(),
[
'[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fa"}',
'[^2]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fb"}',
].join('\n')
);
});
test('CitationFootnoteFormatter should overwrite duplicated index with latest url', t => {
const formatter = new CitationFootnoteFormatter();
formatter.consume({
type: 'citation',
index: 1,
url: 'https://example.com/old',
});
formatter.consume({
type: 'citation',
index: 1,
url: 'https://example.com/new',
});
t.is(
formatter.end(),
'[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fnew"}'
);
});
test('StreamPatternParser should keep state across chunks', t => {
const parser = new StreamPatternParser(pattern => {
if (pattern.kind === 'wrappedLink') {
return `[^${pattern.url}]`;
}
if (pattern.kind === 'index') {
return `[#${pattern.value}]`;
}
return `[${pattern.text}](${pattern.url})`;
});
const first = parser.write('ref ([AFFiNE](https://affine.pro');
const second = parser.write(')) and [2]');
t.is(first, 'ref ');
t.is(second, '[^https://affine.pro] and [#2]');
t.is(parser.end(), '');
});
test('CitationParser should convert wrapped links to numbered footnotes', t => {
const parser = new CitationParser();
const output = parser.parse('Use ([AFFiNE](https://affine.pro)) now');
t.is(output, 'Use [^1] now');
t.regex(
parser.end(),
/\[\^1\]: \{"type":"url","url":"https%3A%2F%2Faffine.pro"\}/
);
});
test('chatToGPTMessage should not mutate input and should keep system schema', async t => {
const schema = z.object({
query: z.string(),
});
const messages = [
{
role: 'system' as const,
content: 'You are helper',
params: { schema },
},
{
role: 'user' as const,
content: '',
attachments: ['https://example.com/a.png'],
},
];
const firstRef = messages[0];
const secondRef = messages[1];
const [system, normalized, parsedSchema] = await chatToGPTMessage(
messages,
false
);
t.is(system, 'You are helper');
t.is(parsedSchema, schema);
t.is(messages.length, 2);
t.is(messages[0], firstRef);
t.is(messages[1], secondRef);
t.deepEqual(normalized[0], {
role: 'user',
content: [{ type: 'text', text: '[no content]' }],
});
});

View File

@@ -0,0 +1,82 @@
import test from 'ava';
import { NativeStreamAdapter } from '../native';
test('NativeStreamAdapter should support buffered and awaited consumption', async t => {
const adapter = new NativeStreamAdapter<number>(undefined);
adapter.push(1);
const first = await adapter.next();
t.deepEqual(first, { value: 1, done: false });
const pending = adapter.next();
adapter.push(2);
const second = await pending;
t.deepEqual(second, { value: 2, done: false });
adapter.push(null);
const done = await adapter.next();
t.true(done.done);
});
test('NativeStreamAdapter return should abort handle and end iteration', async t => {
let abortCount = 0;
const adapter = new NativeStreamAdapter<number>({
abort: () => {
abortCount += 1;
},
});
const ended = await adapter.return();
t.is(abortCount, 1);
t.true(ended.done);
const secondReturn = await adapter.return();
t.true(secondReturn.done);
t.is(abortCount, 1);
const next = await adapter.next();
t.true(next.done);
});
test('NativeStreamAdapter should abort when AbortSignal is triggered', async t => {
let abortCount = 0;
const controller = new AbortController();
const adapter = new NativeStreamAdapter<number>(
{
abort: () => {
abortCount += 1;
},
},
controller.signal
);
const pending = adapter.next();
controller.abort();
const done = await pending;
t.true(done.done);
t.is(abortCount, 1);
});
test('NativeStreamAdapter should end immediately for pre-aborted signal', async t => {
let abortCount = 0;
const controller = new AbortController();
controller.abort();
const adapter = new NativeStreamAdapter<number>(
{
abort: () => {
abortCount += 1;
},
},
controller.signal
);
const next = await adapter.next();
t.true(next.done);
t.is(abortCount, 1);
adapter.push(1);
const stillDone = await adapter.next();
t.true(stillDone.done);
});

View File

@@ -629,14 +629,35 @@ export async function chatWithText(
prefix = '',
retry?: boolean
): Promise<string> {
const endpoint = prefix || '/stream';
const query = messageId
? `?messageId=${messageId}` + (retry ? '&retry=true' : '')
: '';
const res = await app
.GET(`/api/copilot/chat/${sessionId}${prefix}${query}`)
.GET(`/api/copilot/chat/${sessionId}${endpoint}${query}`)
.expect(200);
return res.text;
if (prefix) {
return res.text;
}
const events = sse2array(res.text);
const errorEvent = events.find(event => event.event === 'error');
if (errorEvent?.data) {
let message = errorEvent.data;
try {
const parsed = JSON.parse(errorEvent.data);
message = parsed.message || message;
} catch {
// noop: keep raw error data
}
throw new Error(message);
}
return events
.filter(event => event.event === 'message')
.map(event => event.data ?? '')
.join('');
}
export async function chatWithTextStream(

View File

@@ -38,8 +38,11 @@ test.before(async t => {
t.context.app = app;
});
test.after.always(async t => {
test.afterEach.always(() => {
Sinon.restore();
});
test.after.always(async t => {
__resetDnsLookupForTests();
await t.context.app.close();
});
@@ -80,6 +83,7 @@ const assertAndSnapshotRaw = async (
test('should proxy image', async t => {
const assertAndSnapshot = assertAndSnapshotRaw.bind(null, t);
const imageUrl = `http://example.com/image-${Date.now()}.png`;
await assertAndSnapshot(
'/api/worker/image-proxy',
@@ -105,7 +109,7 @@ test('should proxy image', async t => {
{
await assertAndSnapshot(
'/api/worker/image-proxy?url=http://example.com/image.png',
`/api/worker/image-proxy?url=${imageUrl}`,
'should return 400 if origin and referer are missing',
{ status: 400, origin: null, referer: null }
);
@@ -113,14 +117,17 @@ test('should proxy image', async t => {
{
await assertAndSnapshot(
'/api/worker/image-proxy?url=http://example.com/image.png',
`/api/worker/image-proxy?url=${imageUrl}`,
'should return 400 for invalid origin header',
{ status: 400, origin: 'http://invalid.com' }
);
}
{
const fakeBuffer = Buffer.from('fake image');
const fakeBuffer = Buffer.from(
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+jfJ8AAAAASUVORK5CYII=',
'base64'
);
const fakeResponse = new Response(fakeBuffer, {
status: 200,
headers: {
@@ -130,13 +137,14 @@ test('should proxy image', async t => {
});
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeResponse);
await assertAndSnapshot(
'/api/worker/image-proxy?url=http://example.com/image.png',
'should return image buffer'
);
fetchSpy.restore();
try {
await assertAndSnapshot(
`/api/worker/image-proxy?url=${imageUrl}`,
'should return image buffer'
);
} finally {
fetchSpy.restore();
}
}
});
@@ -200,18 +208,19 @@ test('should preview link', async t => {
});
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeHTML);
await assertAndSnapshot(
'/api/worker/link-preview',
'should process a valid external URL and return link preview data',
{
status: 200,
method: 'POST',
body: { url: 'http://external.com/page' },
}
);
fetchSpy.restore();
try {
await assertAndSnapshot(
'/api/worker/link-preview',
'should process a valid external URL and return link preview data',
{
status: 200,
method: 'POST',
body: { url: 'http://external.com/page' },
}
);
} finally {
fetchSpy.restore();
}
}
{
@@ -251,18 +260,19 @@ test('should preview link', async t => {
});
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeHTML);
await assertAndSnapshot(
'/api/worker/link-preview',
'should decode HTML content with charset',
{
status: 200,
method: 'POST',
body: { url: `http://example.com/${charset}` },
}
);
fetchSpy.restore();
try {
await assertAndSnapshot(
'/api/worker/link-preview',
'should decode HTML content with charset',
{
status: 200,
method: 'POST',
body: { url: `http://example.com/${charset}` },
}
);
} finally {
fetchSpy.restore();
}
}
}
});

View File

@@ -42,8 +42,18 @@ export class Ga4Client {
timestamp_micros: event.timestampMicros,
})),
};
await this.post(payload);
try {
await this.post(payload);
} catch {
if (env.DEPLOYMENT_TYPE === 'affine') {
// In production, we want to be resilient to GA4 failures, so we catch and ignore errors.
// In non-production environments, we rethrow to surface issues during development and testing.
console.info(
'Failed to send telemetry event to GA4:',
chunk.map(e => e.eventName).join(', ')
);
}
}
}
}
}

View File

@@ -57,3 +57,316 @@ export const addDocToRootDoc = serverNativeModule.addDocToRootDoc;
export const updateDocTitle = serverNativeModule.updateDocTitle;
export const updateDocProperties = serverNativeModule.updateDocProperties;
export const updateRootDocMetaTitle = serverNativeModule.updateRootDocMetaTitle;
type NativeLlmModule = {
llmDispatch?: (
protocol: string,
backendConfigJson: string,
requestJson: string
) => string | Promise<string>;
llmDispatchStream?: (
protocol: string,
backendConfigJson: string,
requestJson: string,
callback: (error: Error | null, eventJson: string) => void
) => { abort?: () => void } | undefined;
};
const nativeLlmModule = serverNativeModule as typeof serverNativeModule &
NativeLlmModule;
export type NativeLlmProtocol =
| 'openai_chat'
| 'openai_responses'
| 'anthropic';
export type NativeLlmBackendConfig = {
base_url: string;
auth_token: string;
request_layer?: 'anthropic' | 'chat_completions' | 'responses' | 'vertex';
headers?: Record<string, string>;
no_streaming?: boolean;
timeout_ms?: number;
};
export type NativeLlmCoreRole = 'system' | 'user' | 'assistant' | 'tool';
export type NativeLlmCoreContent =
| { type: 'text'; text: string }
| { type: 'reasoning'; text: string; signature?: string }
| {
type: 'tool_call';
call_id: string;
name: string;
arguments: Record<string, unknown>;
thought?: string;
}
| {
type: 'tool_result';
call_id: string;
output: unknown;
is_error?: boolean;
name?: string;
arguments?: Record<string, unknown>;
}
| { type: 'image'; source: Record<string, unknown> | string };
export type NativeLlmCoreMessage = {
role: NativeLlmCoreRole;
content: NativeLlmCoreContent[];
};
export type NativeLlmToolDefinition = {
name: string;
description?: string;
parameters: Record<string, unknown>;
};
export type NativeLlmRequest = {
model: string;
messages: NativeLlmCoreMessage[];
stream?: boolean;
max_tokens?: number;
temperature?: number;
tools?: NativeLlmToolDefinition[];
tool_choice?: 'auto' | 'none' | 'required' | { name: string };
include?: string[];
reasoning?: Record<string, unknown>;
middleware?: {
request?: Array<
'normalize_messages' | 'clamp_max_tokens' | 'tool_schema_rewrite'
>;
stream?: Array<'stream_event_normalize' | 'citation_indexing'>;
config?: {
no_additional_properties?: boolean;
drop_property_format?: boolean;
drop_property_min_length?: boolean;
drop_array_min_items?: boolean;
drop_array_max_items?: boolean;
max_tokens_cap?: number;
};
};
};
export type NativeLlmDispatchResponse = {
id: string;
model: string;
message: NativeLlmCoreMessage;
usage: {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cached_tokens?: number;
};
finish_reason: string;
reasoning_details?: unknown;
};
export type NativeLlmStreamEvent =
| { type: 'message_start'; id?: string; model?: string }
| { type: 'text_delta'; text: string }
| { type: 'reasoning_delta'; text: string }
| {
type: 'tool_call_delta';
call_id: string;
name?: string;
arguments_delta: string;
}
| {
type: 'tool_call';
call_id: string;
name: string;
arguments: Record<string, unknown>;
thought?: string;
}
| {
type: 'tool_result';
call_id: string;
output: unknown;
is_error?: boolean;
name?: string;
arguments?: Record<string, unknown>;
}
| { type: 'citation'; index: number; url: string }
| {
type: 'usage';
usage: {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cached_tokens?: number;
};
}
| {
type: 'done';
finish_reason?: string;
usage?: {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cached_tokens?: number;
};
}
| { type: 'error'; message: string; code?: string; raw?: string };
const LLM_STREAM_END_MARKER = '__AFFINE_LLM_STREAM_END__';
export async function llmDispatch(
protocol: NativeLlmProtocol,
backendConfig: NativeLlmBackendConfig,
request: NativeLlmRequest
): Promise<NativeLlmDispatchResponse> {
if (!nativeLlmModule.llmDispatch) {
throw new Error('native llm dispatch is not available');
}
const response = nativeLlmModule.llmDispatch(
protocol,
JSON.stringify(backendConfig),
JSON.stringify(request)
);
const responseText = await Promise.resolve(response);
return JSON.parse(responseText) as NativeLlmDispatchResponse;
}
export class NativeStreamAdapter<T> implements AsyncIterableIterator<T> {
readonly #queue: T[] = [];
readonly #waiters: ((result: IteratorResult<T>) => void)[] = [];
readonly #handle: { abort?: () => void } | undefined;
readonly #signal?: AbortSignal;
readonly #abortListener?: () => void;
#ended = false;
constructor(
handle: { abort?: () => void } | undefined,
signal?: AbortSignal
) {
this.#handle = handle;
this.#signal = signal;
if (signal?.aborted) {
this.close(true);
return;
}
if (signal) {
this.#abortListener = () => {
this.close(true);
};
signal.addEventListener('abort', this.#abortListener, { once: true });
}
}
private close(abortHandle: boolean) {
if (this.#ended) {
return;
}
this.#ended = true;
if (this.#signal && this.#abortListener) {
this.#signal.removeEventListener('abort', this.#abortListener);
}
if (abortHandle) {
this.#handle?.abort?.();
}
while (this.#waiters.length) {
const waiter = this.#waiters.shift();
waiter?.({ value: undefined as T, done: true });
}
}
push(value: T | null) {
if (this.#ended) {
return;
}
if (value === null) {
this.close(false);
return;
}
const waiter = this.#waiters.shift();
if (waiter) {
waiter({ value, done: false });
return;
}
this.#queue.push(value);
}
[Symbol.asyncIterator]() {
return this;
}
async next(): Promise<IteratorResult<T>> {
if (this.#queue.length > 0) {
const value = this.#queue.shift() as T;
return { value, done: false };
}
if (this.#ended) {
return { value: undefined as T, done: true };
}
return await new Promise(resolve => {
this.#waiters.push(resolve);
});
}
async return(): Promise<IteratorResult<T>> {
this.close(true);
return { value: undefined as T, done: true };
}
}
export function llmDispatchStream(
protocol: NativeLlmProtocol,
backendConfig: NativeLlmBackendConfig,
request: NativeLlmRequest,
signal?: AbortSignal
): AsyncIterableIterator<NativeLlmStreamEvent> {
if (!nativeLlmModule.llmDispatchStream) {
throw new Error('native llm stream dispatch is not available');
}
let adapter: NativeStreamAdapter<NativeLlmStreamEvent> | undefined;
const buffer: (NativeLlmStreamEvent | null)[] = [];
let pushFn = (event: NativeLlmStreamEvent | null) => {
buffer.push(event);
};
const handle = nativeLlmModule.llmDispatchStream(
protocol,
JSON.stringify(backendConfig),
JSON.stringify(request),
(error, eventJson) => {
if (error) {
pushFn({ type: 'error', message: error.message, raw: eventJson });
return;
}
if (eventJson === LLM_STREAM_END_MARKER) {
pushFn(null);
return;
}
try {
pushFn(JSON.parse(eventJson) as NativeLlmStreamEvent);
} catch (error) {
pushFn({
type: 'error',
message:
error instanceof Error
? error.message
: 'failed to parse native stream event',
raw: eventJson,
});
}
}
);
adapter = new NativeStreamAdapter(handle, signal);
pushFn = event => {
adapter.push(event);
};
for (const event of buffer) {
adapter.push(event);
}
return adapter;
}

View File

@@ -1,3 +1,5 @@
import { z } from 'zod';
import {
defineModuleConfig,
StorageJSONSchema,
@@ -13,7 +15,179 @@ import { GeminiGenerativeConfig, GeminiVertexConfig } from './providers/gemini';
import { MorphConfig } from './providers/morph';
import { OpenAIConfig } from './providers/openai';
import { PerplexityConfig } from './providers/perplexity';
import { VertexSchema } from './providers/types';
import {
CopilotProviderType,
ModelOutputType,
VertexSchema,
} from './providers/types';
export type CopilotProviderConfigMap = {
[CopilotProviderType.OpenAI]: OpenAIConfig;
[CopilotProviderType.FAL]: FalConfig;
[CopilotProviderType.Gemini]: GeminiGenerativeConfig;
[CopilotProviderType.GeminiVertex]: GeminiVertexConfig;
[CopilotProviderType.Perplexity]: PerplexityConfig;
[CopilotProviderType.Anthropic]: AnthropicOfficialConfig;
[CopilotProviderType.AnthropicVertex]: AnthropicVertexConfig;
[CopilotProviderType.Morph]: MorphConfig;
};
export type ProviderSpecificConfig =
CopilotProviderConfigMap[keyof CopilotProviderConfigMap];
export const RustRequestMiddlewareValues = [
'normalize_messages',
'clamp_max_tokens',
'tool_schema_rewrite',
] as const;
export type RustRequestMiddleware =
(typeof RustRequestMiddlewareValues)[number];
export const RustStreamMiddlewareValues = [
'stream_event_normalize',
'citation_indexing',
] as const;
export type RustStreamMiddleware = (typeof RustStreamMiddlewareValues)[number];
export const NodeTextMiddlewareValues = [
'citation_footnote',
'callout',
'thinking_format',
] as const;
export type NodeTextMiddleware = (typeof NodeTextMiddlewareValues)[number];
export type ProviderMiddlewareConfig = {
rust?: { request?: RustRequestMiddleware[]; stream?: RustStreamMiddleware[] };
node?: { text?: NodeTextMiddleware[] };
};
type CopilotProviderProfileCommon = {
id: string;
displayName?: string;
priority?: number;
enabled?: boolean;
models?: string[];
middleware?: ProviderMiddlewareConfig;
};
type CopilotProviderProfileVariant<T extends CopilotProviderType> = {
type: T;
config: CopilotProviderConfigMap[T];
};
export type CopilotProviderProfile = CopilotProviderProfileCommon &
{
[Type in CopilotProviderType]: CopilotProviderProfileVariant<Type>;
}[CopilotProviderType];
export type CopilotProviderDefaults = Partial<
Record<ModelOutputType, string>
> & {
fallback?: string;
};
const CopilotProviderProfileBaseShape = z.object({
id: z.string().regex(/^[a-zA-Z0-9-_]+$/),
displayName: z.string().optional(),
priority: z.number().optional(),
enabled: z.boolean().optional(),
models: z.array(z.string()).optional(),
middleware: z
.object({
rust: z
.object({
request: z.array(z.enum(RustRequestMiddlewareValues)).optional(),
stream: z.array(z.enum(RustStreamMiddlewareValues)).optional(),
})
.optional(),
node: z
.object({ text: z.array(z.enum(NodeTextMiddlewareValues)).optional() })
.optional(),
})
.optional(),
});
const OpenAIConfigShape = z.object({
apiKey: z.string(),
baseURL: z.string().optional(),
oldApiStyle: z.boolean().optional(),
});
const FalConfigShape = z.object({
apiKey: z.string(),
});
const GeminiGenerativeConfigShape = z.object({
apiKey: z.string(),
baseURL: z.string().optional(),
});
const VertexProviderConfigShape = z.object({
location: z.string().optional(),
project: z.string().optional(),
baseURL: z.string().optional(),
googleAuthOptions: z.any().optional(),
fetch: z.any().optional(),
});
const PerplexityConfigShape = z.object({
apiKey: z.string(),
endpoint: z.string().optional(),
});
const AnthropicOfficialConfigShape = z.object({
apiKey: z.string(),
baseURL: z.string().optional(),
});
const MorphConfigShape = z.object({
apiKey: z.string().optional(),
});
const CopilotProviderProfileShape = z.discriminatedUnion('type', [
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.OpenAI),
config: OpenAIConfigShape,
}),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.FAL),
config: FalConfigShape,
}),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.Gemini),
config: GeminiGenerativeConfigShape,
}),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.GeminiVertex),
config: VertexProviderConfigShape,
}),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.Perplexity),
config: PerplexityConfigShape,
}),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.Anthropic),
config: AnthropicOfficialConfigShape,
}),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.AnthropicVertex),
config: VertexProviderConfigShape,
}),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.Morph),
config: MorphConfigShape,
}),
]);
const CopilotProviderDefaultsShape = z.object({
[ModelOutputType.Text]: z.string().optional(),
[ModelOutputType.Object]: z.string().optional(),
[ModelOutputType.Embedding]: z.string().optional(),
[ModelOutputType.Image]: z.string().optional(),
[ModelOutputType.Structured]: z.string().optional(),
fallback: z.string().optional(),
});
declare global {
interface AppConfigSchema {
copilot: {
@@ -27,6 +201,8 @@ declare global {
storage: ConfigItem<StorageProviderConfig>;
scenarios: ConfigItem<CopilotPromptScenario>;
providers: {
profiles: ConfigItem<CopilotProviderProfile[]>;
defaults: ConfigItem<CopilotProviderDefaults>;
openai: ConfigItem<OpenAIConfig>;
fal: ConfigItem<FalConfig>;
gemini: ConfigItem<GeminiGenerativeConfig>;
@@ -63,6 +239,16 @@ defineModuleConfig('copilot', {
},
},
},
'providers.profiles': {
desc: 'The profile list for copilot providers.',
default: [],
shape: z.array(CopilotProviderProfileShape),
},
'providers.defaults': {
desc: 'The default provider ids for model output types and global fallback.',
default: {},
shape: CopilotProviderDefaultsShape,
},
'providers.openai': {
desc: 'The config for the openai provider.',
default: {

View File

@@ -36,10 +36,7 @@ import {
BlobNotFound,
CallMetric,
Config,
CopilotFailedToGenerateText,
CopilotSessionNotFound,
InternalServerError,
mapAnyError,
mapSseError,
metrics,
NoCopilotProviderAvailable,
@@ -242,61 +239,6 @@ export class CopilotController implements BeforeApplicationShutdown {
};
}
@Get('/chat/:sessionId')
@CallMetric('ai', 'chat', { timer: true })
async chat(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query() query: Record<string, string | string[]>
): Promise<string> {
const info: any = { sessionId, params: query };
try {
const { provider, model, session, finalMessage } =
await this.prepareChatSession(
user,
sessionId,
query,
ModelOutputType.Text
);
info.model = model;
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
metrics.ai.counter('chat_calls').add(1, { model });
const { reasoning, webSearch, toolsConfig } =
ChatQuerySchema.parse(query);
const content = await provider.text({ modelId: model }, finalMessage, {
...session.config.promptConfig,
signal: getSignal(req).signal,
user: user.id,
session: session.config.sessionId,
workspace: session.config.workspaceId,
reasoning,
webSearch,
tools: getTools(session.config.promptConfig?.tools, toolsConfig),
});
session.push({
role: 'assistant',
content,
createdAt: new Date(),
});
await session.save();
return content;
} catch (e: any) {
metrics.ai.counter('chat_errors').add(1);
let error = mapAnyError(e);
if (error instanceof InternalServerError) {
error = new CopilotFailedToGenerateText(e.message);
}
error.log('CopilotChat', info);
throw error;
}
}
@Sse('/chat/:sessionId/stream')
@CallMetric('ai', 'chat_stream', { timer: true })
async chatStream(

View File

@@ -3,7 +3,7 @@ import { AiPrompt, PrismaClient } from '@prisma/client';
import type { PromptConfig, PromptMessage } from '../providers/types';
type Prompt = Omit<
export type Prompt = Omit<
AiPrompt,
| 'id'
| 'createdAt'
@@ -2095,17 +2095,14 @@ export const prompts: Prompt[] = [
export async function refreshPrompts(db: PrismaClient) {
const needToSkip = await db.aiPrompt
.findMany({
where: { modified: true },
select: { name: true },
})
.findMany({ where: { modified: true }, select: { name: true } })
.then(p => p.map(p => p.name));
for (const prompt of prompts) {
// skip prompt update if already modified by admin panel
if (needToSkip.includes(prompt.name)) {
new Logger('CopilotPrompt').warn(`Skip modified prompt: ${prompt.name}`);
return;
continue;
}
await db.aiPrompt.upsert({

View File

@@ -12,6 +12,7 @@ import {
import { ChatPrompt } from './chat-prompt';
import {
CopilotPromptScenario,
type Prompt,
prompts,
refreshPrompts,
Scenario,
@@ -21,6 +22,7 @@ import {
export class PromptService implements OnApplicationBootstrap {
private readonly logger = new Logger(PromptService.name);
private readonly cache = new Map<string, ChatPrompt>();
private readonly inMemoryPrompts = new Map<string, Prompt>();
constructor(
private readonly config: Config,
@@ -28,7 +30,7 @@ export class PromptService implements OnApplicationBootstrap {
) {}
async onApplicationBootstrap() {
this.cache.clear();
this.resetInMemoryPrompts();
await refreshPrompts(this.db);
}
@@ -45,6 +47,7 @@ export class PromptService implements OnApplicationBootstrap {
}
protected async setup(scenarios?: CopilotPromptScenario) {
this.ensureInMemoryPrompts();
if (!!scenarios && scenarios.override_enabled && scenarios.scenarios) {
this.logger.log('Updating prompts based on scenarios...');
for (const [scenario, model] of Object.entries(scenarios.scenarios)) {
@@ -75,25 +78,29 @@ export class PromptService implements OnApplicationBootstrap {
* @returns prompt names
*/
async listNames() {
return this.db.aiPrompt
.findMany({ select: { name: true } })
.then(prompts => Array.from(new Set(prompts.map(p => p.name))));
this.ensureInMemoryPrompts();
return Array.from(this.inMemoryPrompts.keys());
}
async list() {
return this.db.aiPrompt.findMany({
select: {
name: true,
action: true,
model: true,
config: true,
messages: {
select: { role: true, content: true, params: true },
orderBy: { idx: 'asc' },
},
},
orderBy: { action: { sort: 'asc', nulls: 'first' } },
});
this.ensureInMemoryPrompts();
return Array.from(this.inMemoryPrompts.values())
.map(prompt => ({
name: prompt.name,
action: prompt.action ?? null,
model: prompt.model,
config: prompt.config ? structuredClone(prompt.config) : null,
messages: prompt.messages.map(message => ({
role: message.role,
content: message.content,
params: message.params ?? null,
})),
}))
.sort((a, b) => {
if (a.action === null && b.action !== null) return -1;
if (a.action !== null && b.action === null) return 1;
return (a.action ?? '').localeCompare(b.action ?? '');
});
}
/**
@@ -102,40 +109,24 @@ export class PromptService implements OnApplicationBootstrap {
* @returns prompt messages
*/
async get(name: string): Promise<ChatPrompt | null> {
this.ensureInMemoryPrompts();
// skip cache in dev mode to ensure the latest prompt is always fetched
if (!env.dev) {
const cached = this.cache.get(name);
if (cached) return cached;
}
const prompt = await this.db.aiPrompt.findUnique({
where: {
name,
},
select: {
name: true,
action: true,
model: true,
optionalModels: true,
config: true,
messages: {
select: {
role: true,
content: true,
params: true,
},
orderBy: {
idx: 'asc',
},
},
},
});
const prompt = this.inMemoryPrompts.get(name);
if (!prompt) return null;
const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
const config = PromptConfigSchema.safeParse(prompt?.config);
if (prompt && messages.success && config.success) {
const messages = PromptMessageSchema.array().safeParse(prompt.messages);
const config = PromptConfigSchema.safeParse(prompt.config);
if (messages.success && config.success) {
const chatPrompt = ChatPrompt.createFromPrompt({
...prompt,
...this.clonePrompt(prompt),
action: prompt.action ?? null,
optionalModels: prompt.optionalModels ?? [],
config: config.data,
messages: messages.data,
});
@@ -149,25 +140,69 @@ export class PromptService implements OnApplicationBootstrap {
name: string,
model: string,
messages: PromptMessage[],
config?: PromptConfig | null
config?: PromptConfig | null,
extraConfig?: { optionalModels: string[] }
) {
return await this.db.aiPrompt
.create({
data: {
name,
model,
config: config || undefined,
messages: {
create: messages.map((m, idx) => ({
idx,
...m,
attachments: m.attachments || undefined,
params: m.params || undefined,
})),
this.ensureInMemoryPrompts();
const existing = this.inMemoryPrompts.get(name);
const mergedOptionalModels = existing?.optionalModels
? [...existing.optionalModels, ...(extraConfig?.optionalModels ?? [])]
: extraConfig?.optionalModels;
const inMemoryConfig = (!!config && structuredClone(config)) || undefined;
const dbConfig = this.toDbConfig(config);
this.inMemoryPrompts.set(name, {
name,
model,
action: existing?.action,
optionalModels: mergedOptionalModels,
config: inMemoryConfig,
messages: this.cloneMessages(messages),
});
this.cache.delete(name);
try {
return await this.db.aiPrompt
.upsert({
where: { name },
create: {
name,
action: existing?.action,
model,
optionalModels: mergedOptionalModels,
config: dbConfig,
messages: {
create: messages.map((m, idx) => ({
idx,
...m,
attachments: m.attachments || undefined,
params: m.params || undefined,
})),
},
},
},
})
.then(ret => ret.id);
update: {
model,
optionalModels: mergedOptionalModels,
config: dbConfig,
updatedAt: new Date(),
messages: {
deleteMany: {},
create: messages.map((m, idx) => ({
idx,
...m,
attachments: m.attachments || undefined,
params: m.params || undefined,
})),
},
},
})
.then(ret => ret.id);
} catch (error) {
this.logger.warn(
`Compat prompt upsert failed for "${name}": ${this.stringifyError(error)}`
);
return -1;
}
}
@Transactional()
@@ -177,44 +212,123 @@ export class PromptService implements OnApplicationBootstrap {
messages?: PromptMessage[];
model?: string;
modified?: boolean;
config?: PromptConfig;
config?: PromptConfig | null;
},
where?: Prisma.AiPromptWhereInput
) {
this.ensureInMemoryPrompts();
const { config, messages, model, modified } = data;
const existing = await this.db.aiPrompt
.count({ where: { ...where, name } })
.then(count => count > 0);
if (existing) {
await this.db.aiPrompt.update({
where: { name },
data: {
config: config || undefined,
updatedAt: new Date(),
modified,
model,
messages: messages
? {
// cleanup old messages
deleteMany: {},
create: messages.map((m, idx) => ({
idx,
...m,
attachments: m.attachments || undefined,
params: m.params || undefined,
})),
}
: undefined,
},
});
const current = this.inMemoryPrompts.get(name);
if (current) {
const next = this.clonePrompt(current);
if (model !== undefined) {
next.model = model;
}
if (config === null) {
next.config = undefined;
} else if (config !== undefined) {
next.config = structuredClone(config);
}
if (messages) {
next.messages = this.cloneMessages(messages);
}
this.inMemoryPrompts.set(name, next);
this.cache.delete(name);
}
try {
const existing = await this.db.aiPrompt
.count({ where: { ...where, name } })
.then(count => count > 0);
if (existing) {
await this.db.aiPrompt.update({
where: { name },
data: {
config: this.toDbConfig(config),
updatedAt: new Date(),
modified,
model,
messages: messages
? {
// cleanup old messages
deleteMany: {},
create: messages.map((m, idx) => ({
idx,
...m,
attachments: m.attachments || undefined,
params: m.params || undefined,
})),
}
: undefined,
},
});
}
} catch (error) {
this.logger.warn(
`Compat prompt update failed for "${name}": ${this.stringifyError(error)}`
);
}
}
async delete(name: string) {
const { id } = await this.db.aiPrompt.delete({ where: { name } });
this.inMemoryPrompts.delete(name);
this.cache.delete(name);
return id;
try {
const { id } = await this.db.aiPrompt.delete({ where: { name } });
return id;
} catch (error) {
this.logger.warn(
`Compat prompt delete failed for "${name}": ${this.stringifyError(error)}`
);
return -1;
}
}
private resetInMemoryPrompts() {
this.cache.clear();
this.inMemoryPrompts.clear();
for (const prompt of prompts) {
this.inMemoryPrompts.set(prompt.name, this.clonePrompt(prompt));
}
}
private ensureInMemoryPrompts() {
if (!this.inMemoryPrompts.size) {
this.resetInMemoryPrompts();
}
}
private toDbConfig(
config: PromptConfig | null | undefined
): Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput | undefined {
if (config === null) return Prisma.DbNull;
if (config === undefined) return undefined;
return config as Prisma.InputJsonValue;
}
private cloneMessages(messages: PromptMessage[]) {
return messages.map(message => ({
...message,
attachments: message.attachments ? [...message.attachments] : undefined,
params: message.params ? structuredClone(message.params) : undefined,
}));
}
private clonePrompt(prompt: Prompt): Prompt {
return {
...prompt,
optionalModels: prompt.optionalModels
? [...prompt.optionalModels]
: undefined,
config: prompt.config ? structuredClone(prompt.config) : undefined,
messages: this.cloneMessages(prompt.messages),
};
}
private stringifyError(error: unknown) {
return error instanceof Error ? error.message : String(error);
}
}

View File

@@ -1,52 +1,90 @@
import {
type AnthropicProvider as AnthropicSDKProvider,
type AnthropicProviderOptions,
} from '@ai-sdk/anthropic';
import { type GoogleVertexAnthropicProvider } from '@ai-sdk/google-vertex/anthropic';
import { AISDKError, generateText, stepCountIs, streamText } from 'ai';
import type { ToolSet } from 'ai';
import {
CopilotProviderSideError,
metrics,
UserFriendlyError,
} from '../../../../base';
import {
llmDispatchStream,
type NativeLlmBackendConfig,
type NativeLlmRequest,
} from '../../../../native';
import type { NodeTextMiddleware } from '../../config';
import { buildNativeRequest, NativeProviderAdapter } from '../native';
import { CopilotProvider } from '../provider';
import type {
CopilotChatOptions,
CopilotProviderModel,
ModelConditions,
PromptMessage,
StreamObject,
} from '../types';
import { ModelOutputType } from '../types';
import {
chatToGPTMessage,
StreamObjectParser,
TextStreamParser,
} from '../utils';
import { CopilotProviderType, ModelOutputType } from '../types';
import { getGoogleAuth, getVertexAnthropicBaseUrl } from '../utils';
export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
protected abstract instance:
| AnthropicSDKProvider
| GoogleVertexAnthropicProvider;
private handleError(e: any) {
if (e instanceof UserFriendlyError) {
return e;
} else if (e instanceof AISDKError) {
this.logger.error('Throw error from ai sdk:', e);
return new CopilotProviderSideError({
provider: this.type,
kind: e.name || 'unknown',
message: e.message,
});
} else {
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected anthropic response',
});
}
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected anthropic response',
});
}
private async createNativeConfig(): Promise<NativeLlmBackendConfig> {
if (this.type === CopilotProviderType.AnthropicVertex) {
const auth = await getGoogleAuth(this.config as any, 'anthropic');
const headers = auth.headers();
const authorization =
headers.Authorization ||
(headers as Record<string, string | undefined>).authorization;
const token =
typeof authorization === 'string'
? authorization.replace(/^Bearer\s+/i, '')
: '';
const baseUrl =
getVertexAnthropicBaseUrl(this.config as any) || auth.baseUrl;
return {
base_url: baseUrl || '',
auth_token: token,
request_layer: 'vertex',
headers,
};
}
const config = this.config as { apiKey: string; baseURL?: string };
const baseUrl = config.baseURL || 'https://api.anthropic.com/v1';
return {
base_url: baseUrl.replace(/\/v1\/?$/, ''),
auth_token: config.apiKey,
};
}
private createAdapter(
backendConfig: NativeLlmBackendConfig,
tools: ToolSet,
nodeTextMiddleware?: NodeTextMiddleware[]
) {
return new NativeProviderAdapter(
(request: NativeLlmRequest, signal?: AbortSignal) =>
llmDispatchStream('anthropic', backendConfig, request, signal),
tools,
this.MAX_STEPS,
{ nodeTextMiddleware }
);
}
private getReasoning(
options: NonNullable<CopilotChatOptions>,
model: string
): Record<string, unknown> | undefined {
if (options.reasoning && this.isReasoningModel(model)) {
return { budget_tokens: 12000, include_thought: true };
}
return undefined;
}
async text(
@@ -59,28 +97,29 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
const model = this.selectModel(fullCond);
try {
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
const [system, msgs] = await chatToGPTMessage(messages, true, true);
const modelInstance = this.instance(model.id);
const { text, reasoning } = await generateText({
model: modelInstance,
system,
messages: msgs,
abortSignal: options.signal,
providerOptions: {
anthropic: this.getAnthropicOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
stopWhen: stepCountIs(this.MAX_STEPS),
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
const backendConfig = await this.createNativeConfig();
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const reasoning = this.getReasoning(options, model.id);
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
reasoning,
middleware,
});
if (!text) throw new Error('Failed to generate text');
return reasoning ? `${reasoning}\n${text}` : text;
const adapter = this.createAdapter(
backendConfig,
tools,
middleware.node?.text
);
return await adapter.text(request, options.signal);
} catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
metrics.ai
.counter('chat_text_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
@@ -95,25 +134,32 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
const model = this.selectModel(fullCond);
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
const fullStream = await this.getFullStream(model, messages, options);
const parser = new TextStreamParser();
for await (const chunk of fullStream) {
const result = parser.parse(chunk);
yield result;
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
}
if (!options.signal?.aborted) {
const footnotes = parser.end();
if (footnotes.length) {
yield `\n\n${footnotes}`;
}
metrics.ai
.counter('chat_text_stream_calls')
.add(1, this.metricLabels(model.id));
const backendConfig = await this.createNativeConfig();
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
reasoning: this.getReasoning(options, model.id),
middleware,
});
const adapter = this.createAdapter(
backendConfig,
tools,
middleware.node?.text
);
for await (const chunk of adapter.streamText(request, options.signal)) {
yield chunk;
}
} catch (e: any) {
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
metrics.ai
.counter('chat_text_stream_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
@@ -130,58 +176,34 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
try {
metrics.ai
.counter('chat_object_stream_calls')
.add(1, { model: model.id });
const fullStream = await this.getFullStream(model, messages, options);
const parser = new StreamObjectParser();
for await (const chunk of fullStream) {
const result = parser.parse(chunk);
if (result) {
yield result;
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
.add(1, this.metricLabels(model.id));
const backendConfig = await this.createNativeConfig();
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
reasoning: this.getReasoning(options, model.id),
middleware,
});
const adapter = this.createAdapter(
backendConfig,
tools,
middleware.node?.text
);
for await (const chunk of adapter.streamObject(request, options.signal)) {
yield chunk;
}
} catch (e: any) {
metrics.ai
.counter('chat_object_stream_errors')
.add(1, { model: model.id });
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
private async getFullStream(
model: CopilotProviderModel,
messages: PromptMessage[],
options: CopilotChatOptions = {}
) {
const [system, msgs] = await chatToGPTMessage(messages, true, true);
const { fullStream } = streamText({
model: this.instance(model.id),
system,
messages: msgs,
abortSignal: options.signal,
providerOptions: {
anthropic: this.getAnthropicOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
stopWhen: stepCountIs(this.MAX_STEPS),
});
return fullStream;
}
private getAnthropicOptions(options: CopilotChatOptions, model: string) {
const result: AnthropicProviderOptions = {};
if (options?.reasoning && this.isReasoningModel(model)) {
result.thinking = {
type: 'enabled',
budgetTokens: 12000,
};
}
return result;
}
private isReasoningModel(model: string) {
// claude 3.5 sonnet doesn't support reasoning config
return model.includes('sonnet') && !model.startsWith('claude-3-5-sonnet');

View File

@@ -1,7 +1,3 @@
import {
type AnthropicProvider as AnthropicSDKProvider,
createAnthropic,
} from '@ai-sdk/anthropic';
import z from 'zod';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
@@ -52,18 +48,12 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
},
];
protected instance!: AnthropicSDKProvider;
override configured(): boolean {
return !!this.config.apiKey;
}
override setup() {
super.setup();
this.instance = createAnthropic({
apiKey: this.config.apiKey,
baseURL: this.config.baseURL,
});
}
override async refreshOnlineModels() {

View File

@@ -5,7 +5,11 @@ import {
} from '@ai-sdk/google-vertex/anthropic';
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
import { getGoogleAuth, VertexModelListSchema } from '../utils';
import {
getGoogleAuth,
getVertexAnthropicBaseUrl,
VertexModelListSchema,
} from '../utils';
import { AnthropicProvider } from './anthropic';
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings;
@@ -49,7 +53,8 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
protected instance!: GoogleVertexAnthropicProvider;
override configured(): boolean {
return !!this.config.location && !!this.config.googleAuthOptions;
if (!this.config.location || !this.config.googleAuthOptions) return false;
return !!this.config.project || !!getVertexAnthropicBaseUrl(this.config);
}
override setup() {

View File

@@ -1,16 +1,141 @@
import { Injectable, Logger } from '@nestjs/common';
import { Config } from '../../../base';
import { ServerFeature, ServerService } from '../../../core';
import type { CopilotProvider } from './provider';
import {
buildProviderRegistry,
resolveModel,
stripProviderPrefix,
} from './provider-registry';
import { CopilotProviderType, ModelFullConditions } from './types';
function isAsyncIterable(value: unknown): value is AsyncIterable<unknown> {
return (
value !== null &&
value !== undefined &&
typeof (value as AsyncIterable<unknown>)[Symbol.asyncIterator] ===
'function'
);
}
@Injectable()
export class CopilotProviderFactory {
constructor(private readonly server: ServerService) {}
constructor(
private readonly server: ServerService,
private readonly config: Config
) {}
private readonly logger = new Logger(CopilotProviderFactory.name);
readonly #providers = new Map<CopilotProviderType, CopilotProvider>();
readonly #providers = new Map<string, CopilotProvider>();
readonly #boundProviders = new Map<string, CopilotProvider>();
readonly #providerIdsByType = new Map<CopilotProviderType, Set<string>>();
private getRegistry() {
return buildProviderRegistry(this.config.copilot.providers);
}
private getPreferredProviderIds(type?: CopilotProviderType) {
if (!type) return undefined;
return this.#providerIdsByType.get(type);
}
private normalizeCond(
providerId: string,
cond: ModelFullConditions
): ModelFullConditions {
const registry = this.getRegistry();
const modelId = stripProviderPrefix(registry, providerId, cond.modelId);
return { ...cond, modelId };
}
private normalizeMethodArgs(providerId: string, args: unknown[]) {
const [first, ...rest] = args;
if (
!first ||
typeof first !== 'object' ||
Array.isArray(first) ||
!('modelId' in first)
) {
return args;
}
const cond = first as Record<string, unknown>;
if (typeof cond.modelId !== 'string') return args;
const registry = this.getRegistry();
const modelId = stripProviderPrefix(registry, providerId, cond.modelId);
return [{ ...cond, modelId }, ...rest];
}
private wrapAsyncIterable<T>(
provider: CopilotProvider,
providerId: string,
iterable: AsyncIterable<T>
): AsyncIterableIterator<T> {
const iterator = iterable[Symbol.asyncIterator]();
return {
next: value =>
provider.runWithProfile(providerId, () => iterator.next(value)),
return: value =>
provider.runWithProfile(providerId, async () => {
if (typeof iterator.return === 'function') {
return iterator.return(value as never);
}
return { done: true, value: value as T };
}),
throw: error =>
provider.runWithProfile(providerId, async () => {
if (typeof iterator.throw === 'function') {
return iterator.throw(error);
}
throw error;
}),
[Symbol.asyncIterator]() {
return this;
},
};
}
private getBoundProvider(providerId: string, provider: CopilotProvider) {
const cached = this.#boundProviders.get(providerId);
if (cached) {
return cached;
}
const wrapped = new Proxy(provider, {
get: (target, prop, receiver) => {
if (prop === 'providerId') {
return providerId;
}
const value = Reflect.get(target, prop, receiver);
if (typeof value !== 'function') {
return value;
}
return (...args: unknown[]) => {
const normalizedArgs = this.normalizeMethodArgs(providerId, args);
const result = provider.runWithProfile(providerId, () =>
Reflect.apply(value, provider, normalizedArgs)
);
if (isAsyncIterable(result)) {
return this.wrapAsyncIterable(
provider,
providerId,
result as AsyncIterable<unknown>
);
}
return result;
};
},
}) as CopilotProvider;
this.#boundProviders.set(providerId, wrapped);
return wrapped;
}
async getProvider(
cond: ModelFullConditions,
@@ -21,22 +146,41 @@ export class CopilotProviderFactory {
this.logger.debug(
`Resolving copilot provider for output type: ${cond.outputType}`
);
let candidate: CopilotProvider | null = null;
for (const [type, provider] of this.#providers.entries()) {
if (filter.prefer && filter.prefer !== type) {
const route = resolveModel({
registry: this.getRegistry(),
modelId: cond.modelId,
outputType: cond.outputType,
availableProviderIds: this.#providers.keys(),
preferredProviderIds: this.getPreferredProviderIds(filter.prefer),
});
const registry = this.getRegistry();
for (const providerId of route.candidateProviderIds) {
const provider = this.#providers.get(providerId);
if (!provider) continue;
const profile = registry.profiles.get(providerId);
const normalizedCond = this.normalizeCond(providerId, cond);
if (
normalizedCond.modelId &&
profile?.models?.length &&
!profile.models.includes(normalizedCond.modelId)
) {
continue;
}
const isMatched = await provider.match(cond);
const matched = await provider.runWithProfile(providerId, () =>
provider.match(normalizedCond)
);
if (!matched) continue;
if (isMatched) {
candidate = provider;
this.logger.debug(`Copilot provider candidate found: ${type}`);
break;
}
this.logger.debug(
`Copilot provider candidate found: ${provider.type} (${providerId})`
);
return this.getBoundProvider(providerId, provider);
}
return candidate;
return null;
}
async getProviderByModel(
@@ -46,31 +190,50 @@ export class CopilotProviderFactory {
} = {}
): Promise<CopilotProvider | null> {
this.logger.debug(`Resolving copilot provider for model: ${modelId}`);
return this.getProvider({ modelId }, filter);
}
let candidate: CopilotProvider | null = null;
for (const [type, provider] of this.#providers.entries()) {
if (filter.prefer && filter.prefer !== type) {
continue;
}
if (await provider.match({ modelId })) {
candidate = provider;
this.logger.debug(`Copilot provider candidate found: ${type}`);
register(providerId: string, provider: CopilotProvider) {
const existed = this.#providers.get(providerId);
if (existed?.type && existed.type !== provider.type) {
const ids = this.#providerIdsByType.get(existed.type);
ids?.delete(providerId);
if (!ids?.size) {
this.#providerIdsByType.delete(existed.type);
}
}
return candidate;
}
this.#providers.set(providerId, provider);
this.#boundProviders.delete(providerId);
register(provider: CopilotProvider) {
this.#providers.set(provider.type, provider);
this.logger.log(`Copilot provider [${provider.type}] registered.`);
const ids = this.#providerIdsByType.get(provider.type) ?? new Set<string>();
ids.add(providerId);
this.#providerIdsByType.set(provider.type, ids);
this.logger.log(
`Copilot provider [${provider.type}] registered as [${providerId}].`
);
this.server.enableFeature(ServerFeature.Copilot);
}
unregister(provider: CopilotProvider) {
this.#providers.delete(provider.type);
this.logger.log(`Copilot provider [${provider.type}] unregistered.`);
unregister(providerId: string, provider: CopilotProvider) {
const existed = this.#providers.get(providerId);
if (!existed || existed !== provider) {
return;
}
this.#providers.delete(providerId);
this.#boundProviders.delete(providerId);
const ids = this.#providerIdsByType.get(provider.type);
ids?.delete(providerId);
if (!ids?.size) {
this.#providerIdsByType.delete(provider.type);
}
this.logger.log(
`Copilot provider [${provider.type}] unregistered from [${providerId}].`
);
if (this.#providers.size === 0) {
this.server.disableFeature(ServerFeature.Copilot);
}

View File

@@ -0,0 +1,381 @@
import type { ToolSet } from 'ai';
import { z } from 'zod';
import type {
NativeLlmRequest,
NativeLlmStreamEvent,
NativeLlmToolDefinition,
} from '../../../native';
export type NativeDispatchFn = (
request: NativeLlmRequest,
signal?: AbortSignal
) => AsyncIterableIterator<NativeLlmStreamEvent>;
export type NativeToolCall = {
id: string;
name: string;
args: Record<string, unknown>;
thought?: string;
};
type ToolCallState = {
name?: string;
argumentsText: string;
};
type ToolExecutionResult = {
callId: string;
name: string;
args: Record<string, unknown>;
output: unknown;
isError?: boolean;
};
export class ToolCallAccumulator {
readonly #states = new Map<string, ToolCallState>();
feedDelta(event: Extract<NativeLlmStreamEvent, { type: 'tool_call_delta' }>) {
const state = this.#states.get(event.call_id) ?? {
argumentsText: '',
};
if (event.name) {
state.name = event.name;
}
if (event.arguments_delta) {
state.argumentsText += event.arguments_delta;
}
this.#states.set(event.call_id, state);
}
complete(event: Extract<NativeLlmStreamEvent, { type: 'tool_call' }>) {
const state = this.#states.get(event.call_id);
this.#states.delete(event.call_id);
return {
id: event.call_id,
name: event.name || state?.name || '',
args: this.parseArgs(
event.arguments ?? this.parseJson(state?.argumentsText ?? '{}')
),
thought: event.thought,
} satisfies NativeToolCall;
}
drainPending() {
const pending: NativeToolCall[] = [];
for (const [callId, state] of this.#states.entries()) {
if (!state.name) {
continue;
}
pending.push({
id: callId,
name: state.name,
args: this.parseArgs(this.parseJson(state.argumentsText)),
});
}
this.#states.clear();
return pending;
}
private parseJson(jsonText: string): unknown {
if (!jsonText.trim()) {
return {};
}
try {
return JSON.parse(jsonText);
} catch {
return {};
}
}
private parseArgs(value: unknown): Record<string, unknown> {
if (value && typeof value === 'object' && !Array.isArray(value)) {
return value as Record<string, unknown>;
}
return {};
}
}
export class ToolSchemaExtractor {
static extract(toolSet: ToolSet): NativeLlmToolDefinition[] {
return Object.entries(toolSet).map(([name, tool]) => {
const unknownTool = tool as Record<string, unknown>;
const inputSchema =
unknownTool.inputSchema ?? unknownTool.parameters ?? z.object({});
return {
name,
description:
typeof unknownTool.description === 'string'
? unknownTool.description
: undefined,
parameters: this.toJsonSchema(inputSchema),
};
});
}
private static toJsonSchema(schema: unknown): Record<string, unknown> {
if (!(schema instanceof z.ZodType)) {
if (schema && typeof schema === 'object' && !Array.isArray(schema)) {
return schema as Record<string, unknown>;
}
return { type: 'object', properties: {} };
}
if (schema instanceof z.ZodObject) {
const shape = schema.shape;
const properties: Record<string, unknown> = {};
const required: string[] = [];
for (const [key, child] of Object.entries(
shape as Record<string, z.ZodTypeAny>
)) {
properties[key] = this.toJsonSchema(child);
if (!this.isOptional(child)) {
required.push(key);
}
}
return {
type: 'object',
properties,
additionalProperties: false,
...(required.length ? { required } : {}),
};
}
if (schema instanceof z.ZodString) {
return { type: 'string' };
}
if (schema instanceof z.ZodNumber) {
return { type: 'number' };
}
if (schema instanceof z.ZodBoolean) {
return { type: 'boolean' };
}
if (schema instanceof z.ZodArray) {
return { type: 'array', items: this.toJsonSchema(schema.element) };
}
if (schema instanceof z.ZodEnum) {
return { type: 'string', enum: schema.options };
}
if (schema instanceof z.ZodLiteral) {
const literal = schema.value;
if (literal === null) {
return { const: null, type: 'null' };
}
if (typeof literal === 'string') {
return { const: literal, type: 'string' };
}
if (typeof literal === 'number') {
return { const: literal, type: 'number' };
}
if (typeof literal === 'boolean') {
return { const: literal, type: 'boolean' };
}
return { const: literal };
}
if (schema instanceof z.ZodUnion) {
return {
anyOf: schema.options.map((option: z.ZodTypeAny) =>
this.toJsonSchema(option)
),
};
}
if (schema instanceof z.ZodRecord) {
return {
type: 'object',
additionalProperties: this.toJsonSchema(schema.valueSchema),
};
}
if (schema instanceof z.ZodNullable) {
const inner = (schema._def as { innerType?: z.ZodTypeAny }).innerType;
return { anyOf: [this.toJsonSchema(inner), { type: 'null' }] };
}
if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) {
return this.toJsonSchema(
(schema._def as { innerType?: z.ZodTypeAny }).innerType
);
}
if (schema instanceof z.ZodEffects) {
return this.toJsonSchema(
(schema._def as { schema?: z.ZodTypeAny }).schema
);
}
return { type: 'object', properties: {} };
}
private static isOptional(schema: z.ZodTypeAny): boolean {
if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) {
return true;
}
if (schema instanceof z.ZodNullable) {
return this.isOptional(
(schema._def as { innerType: z.ZodTypeAny }).innerType
);
}
if (schema instanceof z.ZodEffects) {
return this.isOptional((schema._def as { schema: z.ZodTypeAny }).schema);
}
return false;
}
}
export class ToolCallLoop {
constructor(
private readonly dispatch: NativeDispatchFn,
private readonly tools: ToolSet,
private readonly maxSteps = 20
) {}
async *run(
request: NativeLlmRequest,
signal?: AbortSignal
): AsyncIterableIterator<NativeLlmStreamEvent> {
const messages = request.messages.map(message => ({
...message,
content: [...message.content],
}));
for (let step = 0; step < this.maxSteps; step++) {
const toolCalls: NativeToolCall[] = [];
const accumulator = new ToolCallAccumulator();
let finalDone: Extract<NativeLlmStreamEvent, { type: 'done' }> | null =
null;
for await (const event of this.dispatch(
{
...request,
stream: true,
messages,
},
signal
)) {
switch (event.type) {
case 'tool_call_delta': {
accumulator.feedDelta(event);
break;
}
case 'tool_call': {
toolCalls.push(accumulator.complete(event));
yield event;
break;
}
case 'done': {
finalDone = event;
break;
}
case 'error': {
throw new Error(event.message);
}
default: {
yield event;
break;
}
}
}
toolCalls.push(...accumulator.drainPending());
if (toolCalls.length === 0) {
if (finalDone) {
yield finalDone;
}
break;
}
if (step === this.maxSteps - 1) {
throw new Error('ToolCallLoop max steps reached');
}
const toolResults = await this.executeTools(toolCalls);
messages.push({
role: 'assistant',
content: toolCalls.map(call => ({
type: 'tool_call',
call_id: call.id,
name: call.name,
arguments: call.args,
thought: call.thought,
})),
});
for (const result of toolResults) {
messages.push({
role: 'tool',
content: [
{
type: 'tool_result',
call_id: result.callId,
output: result.output,
is_error: result.isError,
},
],
});
yield {
type: 'tool_result',
call_id: result.callId,
name: result.name,
arguments: result.args,
output: result.output,
is_error: result.isError,
};
}
}
}
private async executeTools(calls: NativeToolCall[]) {
return await Promise.all(calls.map(call => this.executeTool(call)));
}
private async executeTool(
call: NativeToolCall
): Promise<ToolExecutionResult> {
const tool = this.tools[call.name] as
| {
execute?: (args: Record<string, unknown>) => Promise<unknown>;
}
| undefined;
if (!tool?.execute) {
return {
callId: call.id,
name: call.name,
args: call.args,
isError: true,
output: {
message: `Tool not found: ${call.name}`,
},
};
}
try {
const output = await tool.execute(call.args);
return {
callId: call.id,
name: call.name,
args: call.args,
output: output ?? null,
};
} catch (error) {
console.error('Tool execution failed', {
callId: call.id,
toolName: call.name,
error,
});
return {
callId: call.id,
name: call.name,
args: call.args,
isError: true,
output: {
message: 'Tool execution failed',
},
};
}
}
}

View File

@@ -1,14 +1,17 @@
import {
createOpenAICompatible,
OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
} from '@ai-sdk/openai-compatible';
import { AISDKError, generateText, streamText } from 'ai';
import type { ToolSet } from 'ai';
import {
CopilotProviderSideError,
metrics,
UserFriendlyError,
} from '../../../base';
import {
llmDispatchStream,
type NativeLlmBackendConfig,
type NativeLlmRequest,
} from '../../../native';
import type { NodeTextMiddleware } from '../config';
import { buildNativeRequest, NativeProviderAdapter } from './native';
import { CopilotProvider } from './provider';
import type {
CopilotChatOptions,
@@ -16,7 +19,6 @@ import type {
PromptMessage,
} from './types';
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
import { chatToGPTMessage, TextStreamParser } from './utils';
export const DEFAULT_DIMENSIONS = 256;
@@ -57,37 +59,48 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
},
];
#instance!: VercelOpenAICompatibleProvider;
override configured(): boolean {
return !!this.config.apiKey;
}
protected override setup() {
super.setup();
this.#instance = createOpenAICompatible({
name: this.type,
apiKey: this.config.apiKey,
baseURL: 'https://api.morphllm.com/v1',
});
}
private handleError(e: any) {
if (e instanceof UserFriendlyError) {
return e;
} else if (e instanceof AISDKError) {
return new CopilotProviderSideError({
provider: this.type,
kind: e.name || 'unknown',
message: e.message,
});
} else {
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected morph response',
});
}
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected morph response',
});
}
private createNativeConfig(): NativeLlmBackendConfig {
return {
base_url: 'https://api.morphllm.com',
auth_token: this.config.apiKey ?? '',
};
}
private createNativeAdapter(
tools: ToolSet,
nodeTextMiddleware?: NodeTextMiddleware[]
) {
return new NativeProviderAdapter(
(request: NativeLlmRequest, signal?: AbortSignal) =>
llmDispatchStream(
'openai_chat',
this.createNativeConfig(),
request,
signal
),
tools,
this.MAX_STEPS,
{ nodeTextMiddleware }
);
}
async text(
@@ -103,22 +116,22 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
const model = this.selectModel(fullCond);
try {
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
const [system, msgs] = await chatToGPTMessage(messages);
const modelInstance = this.#instance(model.id);
const { text } = await generateText({
model: modelInstance,
system,
messages: msgs,
abortSignal: options.signal,
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
middleware,
});
return text.trim();
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
return await adapter.text(request, options.signal);
} catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
metrics.ai
.counter('chat_text_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
@@ -136,38 +149,26 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
const model = this.selectModel(fullCond);
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
const [system, msgs] = await chatToGPTMessage(messages);
const modelInstance = this.#instance(model.id);
const { fullStream } = streamText({
model: modelInstance,
system,
messages: msgs,
abortSignal: options.signal,
metrics.ai
.counter('chat_text_stream_calls')
.add(1, this.metricLabels(model.id));
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
middleware,
});
const textParser = new TextStreamParser();
for await (const chunk of fullStream) {
switch (chunk.type) {
case 'text-delta': {
let result = textParser.parse(chunk);
yield result;
break;
}
default: {
yield textParser.parse(chunk);
break;
}
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
for await (const chunk of adapter.streamText(request, options.signal)) {
yield chunk;
}
} catch (e: any) {
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
metrics.ai
.counter('chat_text_stream_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}

View File

@@ -0,0 +1,464 @@
import type { ToolSet } from 'ai';
import { ZodType } from 'zod';
import type {
NativeLlmCoreContent,
NativeLlmCoreMessage,
NativeLlmRequest,
NativeLlmStreamEvent,
} from '../../../native';
import type { NodeTextMiddleware, ProviderMiddlewareConfig } from '../config';
import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop';
import type { CopilotChatOptions, PromptMessage, StreamObject } from './types';
import {
CitationFootnoteFormatter,
inferMimeType,
TextStreamParser,
} from './utils';
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
type BuildNativeRequestOptions = {
model: string;
messages: PromptMessage[];
options?: CopilotChatOptions;
tools?: ToolSet;
withAttachment?: boolean;
include?: string[];
reasoning?: Record<string, unknown>;
middleware?: ProviderMiddlewareConfig;
};
type BuildNativeRequestResult = {
request: NativeLlmRequest;
schema?: ZodType;
};
type ToolCallMeta = {
name: string;
args: Record<string, unknown>;
};
type NormalizedToolResultEvent = Extract<
NativeLlmStreamEvent,
{ type: 'tool_result' }
> & {
name: string;
arguments: Record<string, unknown>;
};
type AttachmentFootnote = {
blobId: string;
fileName: string;
fileType: string;
};
type NativeProviderAdapterOptions = {
nodeTextMiddleware?: NodeTextMiddleware[];
};
function roleToCore(role: PromptMessage['role']) {
switch (role) {
case 'assistant':
return 'assistant';
case 'system':
return 'system';
default:
return 'user';
}
}
async function toCoreContents(
message: PromptMessage,
withAttachment: boolean
): Promise<NativeLlmCoreContent[]> {
const contents: NativeLlmCoreContent[] = [];
if (typeof message.content === 'string' && message.content.length) {
contents.push({ type: 'text', text: message.content });
}
if (!withAttachment || !Array.isArray(message.attachments)) return contents;
for (const entry of message.attachments) {
let attachmentUrl: string;
let mediaType: string;
if (typeof entry === 'string') {
attachmentUrl = entry;
mediaType =
typeof message.params?.mimetype === 'string'
? message.params.mimetype
: await inferMimeType(entry);
} else {
attachmentUrl = entry.attachment;
mediaType = entry.mimeType;
}
if (!SIMPLE_IMAGE_URL_REGEX.test(attachmentUrl)) continue;
if (!mediaType.startsWith('image/')) continue;
contents.push({ type: 'image', source: { url: attachmentUrl } });
}
return contents;
}
export async function buildNativeRequest({
model,
messages,
options = {},
tools = {},
withAttachment = true,
include,
reasoning,
middleware,
}: BuildNativeRequestOptions): Promise<BuildNativeRequestResult> {
const copiedMessages = messages.map(message => ({
...message,
attachments: message.attachments
? [...message.attachments]
: message.attachments,
}));
const systemMessage =
copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined;
const schema =
systemMessage?.params?.schema instanceof ZodType
? systemMessage.params.schema
: undefined;
const coreMessages: NativeLlmCoreMessage[] = [];
if (systemMessage?.content?.length) {
coreMessages.push({
role: 'system',
content: [{ type: 'text', text: systemMessage.content }],
});
}
for (const message of copiedMessages) {
if (message.role === 'system') continue;
const content = await toCoreContents(message, withAttachment);
coreMessages.push({ role: roleToCore(message.role), content });
}
return {
request: {
model,
stream: true,
messages: coreMessages,
max_tokens: options.maxTokens ?? undefined,
temperature: options.temperature ?? undefined,
tools: ToolSchemaExtractor.extract(tools),
tool_choice: Object.keys(tools).length ? 'auto' : undefined,
include,
reasoning,
middleware: middleware?.rust
? { request: middleware.rust.request, stream: middleware.rust.stream }
: undefined,
},
schema,
};
}
function ensureToolResultMeta(
event: Extract<NativeLlmStreamEvent, { type: 'tool_result' }>,
toolCalls: Map<string, ToolCallMeta>
): NormalizedToolResultEvent | null {
const name = event.name ?? toolCalls.get(event.call_id)?.name;
const args = event.arguments ?? toolCalls.get(event.call_id)?.args;
if (!name || !args) return null;
return { ...event, name, arguments: args };
}
function pickAttachmentFootnote(value: unknown): AttachmentFootnote | null {
if (!value || typeof value !== 'object') {
return null;
}
const record = value as Record<string, unknown>;
const blobId =
typeof record.blobId === 'string'
? record.blobId
: typeof record.blob_id === 'string'
? record.blob_id
: undefined;
const fileName =
typeof record.fileName === 'string'
? record.fileName
: typeof record.name === 'string'
? record.name
: undefined;
const fileType =
typeof record.fileType === 'string'
? record.fileType
: typeof record.mimeType === 'string'
? record.mimeType
: 'application/octet-stream';
if (!blobId || !fileName) {
return null;
}
return { blobId, fileName, fileType };
}
function collectAttachmentFootnotes(
event: NormalizedToolResultEvent
): AttachmentFootnote[] {
if (event.name === 'blob_read') {
const item = pickAttachmentFootnote(event.output);
return item ? [item] : [];
}
if (event.name === 'doc_semantic_search' && Array.isArray(event.output)) {
return event.output
.map(item => pickAttachmentFootnote(item))
.filter((item): item is AttachmentFootnote => item !== null);
}
return [];
}
function formatAttachmentFootnotes(attachments: AttachmentFootnote[]) {
const references = attachments.map((_, index) => `[^${index + 1}]`).join('');
const definitions = attachments
.map((attachment, index) => {
return `[^${index + 1}]: ${JSON.stringify({
type: 'attachment',
blobId: attachment.blobId,
fileName: attachment.fileName,
fileType: attachment.fileType,
})}`;
})
.join('\n');
return `\n\n${references}\n\n${definitions}`;
}
export class NativeProviderAdapter {
readonly #loop: ToolCallLoop;
readonly #enableCallout: boolean;
readonly #enableCitationFootnote: boolean;
constructor(
dispatch: NativeDispatchFn,
tools: ToolSet,
maxSteps = 20,
options: NativeProviderAdapterOptions = {}
) {
this.#loop = new ToolCallLoop(dispatch, tools, maxSteps);
const enabledNodeTextMiddlewares = new Set(
options.nodeTextMiddleware ?? ['citation_footnote', 'callout']
);
this.#enableCallout =
enabledNodeTextMiddlewares.has('callout') ||
enabledNodeTextMiddlewares.has('thinking_format');
this.#enableCitationFootnote =
enabledNodeTextMiddlewares.has('citation_footnote');
}
async text(request: NativeLlmRequest, signal?: AbortSignal) {
let output = '';
for await (const chunk of this.streamText(request, signal)) {
output += chunk;
}
return output.trim();
}
async *streamText(
request: NativeLlmRequest,
signal?: AbortSignal
): AsyncIterableIterator<string> {
const textParser = this.#enableCallout ? new TextStreamParser() : null;
const citationFormatter = this.#enableCitationFootnote
? new CitationFootnoteFormatter()
: null;
const toolCalls = new Map<string, ToolCallMeta>();
let streamPartId = 0;
for await (const event of this.#loop.run(request, signal)) {
switch (event.type) {
case 'text_delta': {
if (textParser) {
yield textParser.parse({
type: 'text-delta',
id: String(streamPartId++),
text: event.text,
});
} else {
yield event.text;
}
break;
}
case 'reasoning_delta': {
if (textParser) {
yield textParser.parse({
type: 'reasoning-delta',
id: String(streamPartId++),
text: event.text,
});
} else {
yield event.text;
}
break;
}
case 'tool_call': {
const toolCall = {
name: event.name,
args: event.arguments,
};
toolCalls.set(event.call_id, toolCall);
if (textParser) {
yield textParser.parse({
type: 'tool-call',
toolCallId: event.call_id,
toolName: event.name as never,
input: event.arguments,
});
}
break;
}
case 'tool_result': {
const normalized = ensureToolResultMeta(event, toolCalls);
if (!normalized || !textParser) {
break;
}
yield textParser.parse({
type: 'tool-result',
toolCallId: normalized.call_id,
toolName: normalized.name as never,
input: normalized.arguments,
output: normalized.output,
});
break;
}
case 'citation': {
if (citationFormatter) {
citationFormatter.consume({
type: 'citation',
index: event.index,
url: event.url,
});
}
break;
}
case 'done': {
const footnotes = textParser?.end() ?? '';
const citations = citationFormatter?.end() ?? '';
const tails = [citations, footnotes].filter(Boolean).join('\n');
if (tails) {
yield `\n${tails}`;
}
break;
}
case 'error': {
throw new Error(event.message);
}
default:
break;
}
}
}
async *streamObject(
request: NativeLlmRequest,
signal?: AbortSignal
): AsyncIterableIterator<StreamObject> {
const toolCalls = new Map<string, ToolCallMeta>();
const citationFormatter = this.#enableCitationFootnote
? new CitationFootnoteFormatter()
: null;
const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>();
let hasFootnoteReference = false;
for await (const event of this.#loop.run(request, signal)) {
switch (event.type) {
case 'text_delta': {
if (event.text.includes('[^')) {
hasFootnoteReference = true;
}
yield {
type: 'text-delta',
textDelta: event.text,
};
break;
}
case 'reasoning_delta': {
yield {
type: 'reasoning',
textDelta: event.text,
};
break;
}
case 'tool_call': {
const toolCall = {
name: event.name,
args: event.arguments,
};
toolCalls.set(event.call_id, toolCall);
yield {
type: 'tool-call',
toolCallId: event.call_id,
toolName: event.name,
args: event.arguments,
};
break;
}
case 'tool_result': {
const normalized = ensureToolResultMeta(event, toolCalls);
if (!normalized) {
break;
}
const attachments = collectAttachmentFootnotes(normalized);
attachments.forEach(attachment => {
fallbackAttachmentFootnotes.set(attachment.blobId, attachment);
});
yield {
type: 'tool-result',
toolCallId: normalized.call_id,
toolName: normalized.name,
args: normalized.arguments,
result: normalized.output,
};
break;
}
case 'citation': {
if (citationFormatter) {
citationFormatter.consume({
type: 'citation',
index: event.index,
url: event.url,
});
}
break;
}
case 'done': {
const citations = citationFormatter?.end() ?? '';
if (citations) {
hasFootnoteReference = true;
yield {
type: 'text-delta',
textDelta: `\n${citations}`,
};
}
if (!hasFootnoteReference && fallbackAttachmentFootnotes.size > 0) {
yield {
type: 'text-delta',
textDelta: formatAttachmentFootnotes(
Array.from(fallbackAttachmentFootnotes.values())
),
};
}
break;
}
case 'error': {
throw new Error(event.message);
}
default:
break;
}
}
}
}

View File

@@ -1,53 +1,35 @@
import {
createOpenAI,
openai,
type OpenAIProvider as VercelOpenAIProvider,
OpenAIResponsesProviderOptions,
} from '@ai-sdk/openai';
import {
createOpenAICompatible,
type OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
} from '@ai-sdk/openai-compatible';
import {
AISDKError,
embedMany,
experimental_generateImage as generateImage,
generateObject,
generateText,
stepCountIs,
streamText,
Tool,
} from 'ai';
import type { Tool, ToolSet } from 'ai';
import { z } from 'zod';
import {
CopilotPromptInvalid,
CopilotProviderNotSupported,
CopilotProviderSideError,
fetchBuffer,
metrics,
OneMB,
readResponseBufferWithLimit,
safeFetch,
UserFriendlyError,
} from '../../../base';
import {
llmDispatchStream,
type NativeLlmBackendConfig,
type NativeLlmRequest,
} from '../../../native';
import type { NodeTextMiddleware } from '../config';
import { buildNativeRequest, NativeProviderAdapter } from './native';
import { CopilotProvider } from './provider';
import type {
CopilotChatOptions,
CopilotChatTools,
CopilotEmbeddingOptions,
CopilotImageOptions,
CopilotProviderModel,
CopilotStructuredOptions,
ModelConditions,
PromptMessage,
StreamObject,
} from './types';
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
import {
chatToGPTMessage,
CitationParser,
StreamObjectParser,
TextStreamParser,
} from './utils';
import { chatToGPTMessage } from './utils';
export const DEFAULT_DIMENSIONS = 256;
@@ -63,7 +45,12 @@ const ModelListSchema = z.object({
const ImageResponseSchema = z.union([
z.object({
data: z.array(z.object({ b64_json: z.string() })),
data: z.array(
z.object({
b64_json: z.string().optional(),
url: z.string().optional(),
})
),
}),
z.object({
error: z.object({
@@ -87,6 +74,38 @@ const LogProbsSchema = z.array(
})
);
const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro'];
function normalizeImageFormatToMime(format?: string) {
switch (format?.toLowerCase()) {
case 'jpg':
case 'jpeg':
return 'image/jpeg';
case 'webp':
return 'image/webp';
case 'png':
return 'image/png';
case 'gif':
return 'image/gif';
default:
return 'image/png';
}
}
function normalizeImageResponseData(
data: { b64_json?: string; url?: string }[],
mimeType: string = 'image/png'
) {
return data
.map(image => {
if (image.b64_json) {
return `data:${mimeType};base64,${image.b64_json}`;
}
return image.url;
})
.filter((value): value is string => typeof value === 'string');
}
export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
readonly type = CopilotProviderType.OpenAI;
@@ -319,53 +338,23 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
},
];
#instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider;
override configured(): boolean {
return !!this.config.apiKey;
}
protected override setup() {
super.setup();
this.#instance =
this.config.oldApiStyle && this.config.baseURL
? createOpenAICompatible({
name: 'openai-compatible-old-style',
apiKey: this.config.apiKey,
baseURL: this.config.baseURL,
})
: createOpenAI({
apiKey: this.config.apiKey,
baseURL: this.config.baseURL,
});
}
private handleError(
e: any,
model: string,
options: CopilotImageOptions = {}
) {
private handleError(e: any) {
if (e instanceof UserFriendlyError) {
return e;
} else if (e instanceof AISDKError) {
if (e.message.includes('safety') || e.message.includes('risk')) {
metrics.ai
.counter('chat_text_risk_errors')
.add(1, { model, user: options.user || undefined });
}
return new CopilotProviderSideError({
provider: this.type,
kind: e.name || 'unknown',
message: e.message,
});
} else {
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected openai response',
});
}
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected openai response',
});
}
override async refreshOnlineModels() {
@@ -389,20 +378,50 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
override getProviderSpecificTools(
toolName: CopilotChatTools,
model: string
_model: string
): [string, Tool?] | undefined {
if (
toolName === 'webSearch' &&
'responses' in this.#instance &&
!this.isReasoningModel(model)
) {
return ['web_search_preview', openai.tools.webSearch({})];
} else if (toolName === 'docEdit') {
if (toolName === 'docEdit') {
return ['doc_edit', undefined];
}
return;
}
private createNativeConfig(): NativeLlmBackendConfig {
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
return {
base_url: baseUrl.replace(/\/v1\/?$/, ''),
auth_token: this.config.apiKey,
};
}
private createNativeAdapter(
tools: ToolSet,
nodeTextMiddleware?: NodeTextMiddleware[]
) {
return new NativeProviderAdapter(
(request: NativeLlmRequest, signal?: AbortSignal) =>
llmDispatchStream(
this.config.oldApiStyle ? 'openai_chat' : 'openai_responses',
this.createNativeConfig(),
request,
signal
),
tools,
this.MAX_STEPS,
{ nodeTextMiddleware }
);
}
private getReasoning(
options: NonNullable<CopilotChatOptions>,
model: string
): Record<string, unknown> | undefined {
if (options.reasoning && this.isReasoningModel(model)) {
return { effort: 'medium' };
}
return undefined;
}
async text(
cond: ModelConditions,
messages: PromptMessage[],
@@ -413,33 +432,25 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const model = this.selectModel(fullCond);
try {
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
const [system, msgs] = await chatToGPTMessage(messages);
const modelInstance =
'responses' in this.#instance
? this.#instance.responses(model.id)
: this.#instance(model.id);
const { text } = await generateText({
model: modelInstance,
system,
messages: msgs,
temperature: options.temperature ?? 0,
maxOutputTokens: options.maxTokens ?? 4096,
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
stopWhen: stepCountIs(this.MAX_STEPS),
abortSignal: options.signal,
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
include: options.webSearch ? ['citations'] : undefined,
reasoning: this.getReasoning(options, model.id),
middleware,
});
return text.trim();
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
return await adapter.text(request, options.signal);
} catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
throw this.handleError(e, model.id, options);
metrics.ai
.counter('chat_text_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
@@ -456,38 +467,29 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const model = this.selectModel(fullCond);
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
const fullStream = await this.getFullStream(model, messages, options);
const citationParser = new CitationParser();
const textParser = new TextStreamParser();
for await (const chunk of fullStream) {
switch (chunk.type) {
case 'text-delta': {
let result = textParser.parse(chunk);
result = citationParser.parse(result);
yield result;
break;
}
case 'finish': {
const footnotes = textParser.end();
const result =
citationParser.end() + (footnotes.length ? '\n' + footnotes : '');
yield result;
break;
}
default: {
yield textParser.parse(chunk);
break;
}
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
metrics.ai
.counter('chat_text_stream_calls')
.add(1, this.metricLabels(model.id));
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
include: options.webSearch ? ['citations'] : undefined,
reasoning: this.getReasoning(options, model.id),
middleware,
});
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
for await (const chunk of adapter.streamText(request, options.signal)) {
yield chunk;
}
} catch (e: any) {
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
throw this.handleError(e, model.id, options);
metrics.ai
.counter('chat_text_stream_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
@@ -503,24 +505,27 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
try {
metrics.ai
.counter('chat_object_stream_calls')
.add(1, { model: model.id });
const fullStream = await this.getFullStream(model, messages, options);
const parser = new StreamObjectParser();
for await (const chunk of fullStream) {
const result = parser.parse(chunk);
if (result) {
yield result;
}
if (options.signal?.aborted) {
await fullStream.cancel();
break;
}
.add(1, this.metricLabels(model.id));
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
include: options.webSearch ? ['citations'] : undefined,
reasoning: this.getReasoning(options, model.id),
middleware,
});
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
for await (const chunk of adapter.streamObject(request, options.signal)) {
yield chunk;
}
} catch (e: any) {
metrics.ai
.counter('chat_object_stream_errors')
.add(1, { model: model.id });
throw this.handleError(e, model.id, options);
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
@@ -535,35 +540,27 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
try {
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
const [system, msgs, schema] = await chatToGPTMessage(messages);
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request, schema } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
reasoning: this.getReasoning(options, model.id),
middleware,
});
if (!schema) {
throw new CopilotPromptInvalid('Schema is required');
}
const modelInstance =
'responses' in this.#instance
? this.#instance.responses(model.id)
: this.#instance(model.id);
const { object } = await generateObject({
model: modelInstance,
system,
messages: msgs,
temperature: options.temperature ?? 0,
maxOutputTokens: options.maxTokens ?? 4096,
maxRetries: options.maxRetries ?? 3,
schema,
providerOptions: {
openai: options.user ? { user: options.user } : {},
},
abortSignal: options.signal,
});
return JSON.stringify(object);
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
const text = await adapter.text(request, options.signal);
const parsed = JSON.parse(text);
const validated = schema.parse(parsed);
return JSON.stringify(validated);
} catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
throw this.handleError(e, model.id, options);
throw this.handleError(e);
}
}
@@ -575,36 +572,32 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const fullCond = { ...cond, outputType: ModelOutputType.Text };
await this.checkParams({ messages: [], cond: fullCond, options });
const model = this.selectModel(fullCond);
// get the log probability of "yes"/"no"
const instance =
'chat' in this.#instance
? this.#instance.chat(model.id)
: this.#instance(model.id);
const scores = await Promise.all(
chunkMessages.map(async messages => {
const [system, msgs] = await chatToGPTMessage(messages);
const result = await generateText({
model: instance,
system,
messages: msgs,
temperature: 0,
maxOutputTokens: 16,
providerOptions: {
openai: {
...this.getOpenAIOptions(options, model.id),
logprobs: 16,
},
const response = await this.requestOpenAIJson(
'/chat/completions',
{
model: model.id,
messages: this.toOpenAIChatMessages(system, msgs),
temperature: 0,
max_tokens: 16,
logprobs: true,
top_logprobs: 16,
},
abortSignal: options.signal,
});
options.signal
);
const topMap: Record<string, number> = LogProbsSchema.parse(
result.providerMetadata?.openai?.logprobs
)[0].top_logprobs.reduce<Record<string, number>>(
const logprobs = response?.choices?.[0]?.logprobs?.content;
if (!Array.isArray(logprobs) || logprobs.length === 0) {
return 0;
}
const parsedLogprobs = LogProbsSchema.parse(logprobs);
const topMap = parsedLogprobs[0].top_logprobs.reduce(
(acc, { token, logprob }) => ({ ...acc, [token]: logprob }),
{}
{} as Record<string, number>
);
const findLogProb = (token: string): number => {
@@ -634,50 +627,212 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
return scores;
}
private async getFullStream(
model: CopilotProviderModel,
messages: PromptMessage[],
options: CopilotChatOptions = {}
) {
const [system, msgs] = await chatToGPTMessage(messages);
const modelInstance =
'responses' in this.#instance
? this.#instance.responses(model.id)
: this.#instance(model.id);
const { fullStream } = streamText({
model: modelInstance,
system,
messages: msgs,
frequencyPenalty: options.frequencyPenalty ?? 0,
presencePenalty: options.presencePenalty ?? 0,
temperature: options.temperature ?? 0,
maxOutputTokens: options.maxTokens ?? 4096,
providerOptions: {
openai: this.getOpenAIOptions(options, model.id),
},
tools: await this.getTools(options, model.id),
stopWhen: stepCountIs(this.MAX_STEPS),
abortSignal: options.signal,
});
return fullStream;
// ====== text to image ======
private buildImageFetchOptions(url: URL) {
const baseOptions = { timeoutMs: 15_000, maxRedirects: 3 } as const;
const trustedOrigins = new Set<string>();
const protocol = this.AFFiNEConfig.server.https ? 'https:' : 'http:';
const port = this.AFFiNEConfig.server.port;
const isDefaultPort =
(protocol === 'https:' && port === 443) ||
(protocol === 'http:' && port === 80);
const addHostOrigin = (host: string) => {
if (!host) return;
try {
const parsed = new URL(`${protocol}//${host}`);
if (!parsed.port && !isDefaultPort) {
parsed.port = String(port);
}
trustedOrigins.add(parsed.origin);
} catch {
// ignore invalid host config entries
}
};
if (this.AFFiNEConfig.server.externalUrl) {
try {
trustedOrigins.add(
new URL(this.AFFiNEConfig.server.externalUrl).origin
);
} catch {
// ignore invalid external URL
}
}
addHostOrigin(this.AFFiNEConfig.server.host);
for (const host of this.AFFiNEConfig.server.hosts) {
addHostOrigin(host);
}
const hostname = url.hostname.toLowerCase();
const trustedByHost = TRUSTED_ATTACHMENT_HOST_SUFFIXES.some(
suffix => hostname === suffix || hostname.endsWith(`.${suffix}`)
);
if (trustedOrigins.has(url.origin) || trustedByHost) {
return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) };
}
return baseOptions;
}
private redactUrl(raw: string | URL): string {
try {
const parsed = raw instanceof URL ? raw : new URL(raw);
if (parsed.protocol === 'data:') return 'data:[redacted]';
const segments = parsed.pathname.split('/').filter(Boolean);
const redactedPath =
segments.length <= 2
? parsed.pathname || '/'
: `/${segments[0]}/${segments[1]}/...`;
return `${parsed.origin}${redactedPath}`;
} catch {
return '[invalid-url]';
}
}
private async fetchImage(
url: string,
maxBytes: number,
signal?: AbortSignal
): Promise<{ buffer: Buffer; type: string } | null> {
if (url.startsWith('data:')) {
let response: Response;
try {
response = await fetch(url, { signal });
} catch (error) {
this.logger.warn(
`Skip image attachment data URL due to read failure: ${
error instanceof Error ? error.message : String(error)
}`
);
return null;
}
if (!response.ok) {
this.logger.warn(
`Skip image attachment data URL due to invalid response: ${response.status}`
);
return null;
}
const type =
response.headers.get('content-type') || 'application/octet-stream';
if (!type.startsWith('image/')) {
await response.body?.cancel().catch(() => undefined);
this.logger.warn(
`Skip non-image attachment data URL with content-type ${type}`
);
return null;
}
try {
const buffer = await readResponseBufferWithLimit(response, maxBytes);
return { buffer, type };
} catch (error) {
this.logger.warn(
`Skip image attachment data URL due to read failure/size limit: ${
error instanceof Error ? error.message : String(error)
}`
);
return null;
}
}
let parsed: URL;
try {
parsed = new URL(url);
} catch {
this.logger.warn(
`Skip image attachment with invalid URL: ${this.redactUrl(url)}`
);
return null;
}
const redactedUrl = this.redactUrl(parsed);
if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
this.logger.warn(
`Skip image attachment with unsupported protocol: ${redactedUrl}`
);
return null;
}
let response: Response;
try {
response = await safeFetch(
parsed,
{ method: 'GET', signal },
this.buildImageFetchOptions(parsed)
);
} catch (error) {
this.logger.warn(
`Skip image attachment due to blocked/unreachable URL: ${redactedUrl}, reason: ${
error instanceof Error ? error.message : String(error)
}`
);
return null;
}
if (!response.ok) {
this.logger.warn(
`Skip image attachment fetch failure ${response.status}: ${redactedUrl}`
);
return null;
}
const type =
response.headers.get('content-type') || 'application/octet-stream';
if (!type.startsWith('image/')) {
await response.body?.cancel().catch(() => undefined);
this.logger.warn(
`Skip non-image attachment with content-type ${type}: ${redactedUrl}`
);
return null;
}
const contentLength = Number(response.headers.get('content-length'));
if (Number.isFinite(contentLength) && contentLength > maxBytes) {
await response.body?.cancel().catch(() => undefined);
this.logger.warn(
`Skip oversized image attachment by content-length (${contentLength}): ${redactedUrl}`
);
return null;
}
try {
const buffer = await readResponseBufferWithLimit(response, maxBytes);
return { buffer, type };
} catch (error) {
this.logger.warn(
`Skip image attachment due to read failure/size limit: ${redactedUrl}, reason: ${
error instanceof Error ? error.message : String(error)
}`
);
return null;
}
}
// ====== text to image ======
private async *generateImageWithAttachments(
model: string,
prompt: string,
attachments: NonNullable<PromptMessage['attachments']>
attachments: NonNullable<PromptMessage['attachments']>,
signal?: AbortSignal
): AsyncGenerator<string> {
const form = new FormData();
const outputFormat = 'webp';
const maxBytes = 10 * OneMB;
form.set('model', model);
form.set('prompt', prompt);
form.set('output_format', 'webp');
form.set('output_format', outputFormat);
for (const [idx, entry] of attachments.entries()) {
const url = typeof entry === 'string' ? entry : entry.attachment;
try {
const { buffer, type } = await fetchBuffer(url, 10 * OneMB, 'image/');
const file = new File([buffer], `${idx}.png`, { type });
const attachment = await this.fetchImage(url, maxBytes, signal);
if (!attachment) continue;
const { buffer, type } = attachment;
const extension = type.split(';')[0].split('/')[1] || 'png';
const file = new File([buffer], `${idx}.${extension}`, { type });
form.append('image[]', file);
} catch {
continue;
@@ -703,18 +858,24 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
const json = await res.json();
const imageResponse = ImageResponseSchema.safeParse(json);
if (imageResponse.success) {
const data = imageResponse.data;
if ('error' in data) {
throw new Error(data.error.message);
} else {
for (const image of data.data) {
yield `data:image/webp;base64,${image.b64_json}`;
}
}
} else {
if (!imageResponse.success) {
throw new Error(imageResponse.error.message);
}
const data = imageResponse.data;
if ('error' in data) {
throw new Error(data.error.message);
}
const images = normalizeImageResponseData(
data.data,
normalizeImageFormatToMime(outputFormat)
);
if (!images.length) {
throw new Error('No images returned from OpenAI');
}
for (const image of images) {
yield image;
}
}
override async *streamImages(
@@ -726,13 +887,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
await this.checkParams({ messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
if (!('image' in this.#instance)) {
throw new CopilotProviderNotSupported({
provider: this.type,
kind: 'image',
});
}
metrics.ai
.counter('generate_images_stream_calls')
.add(1, { model: model.id });
@@ -742,22 +896,27 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
try {
if (attachments && attachments.length > 0) {
yield* this.generateImageWithAttachments(model.id, prompt, attachments);
} else {
const modelInstance = this.#instance.image(model.id);
const result = await generateImage({
model: modelInstance,
yield* this.generateImageWithAttachments(
model.id,
prompt,
providerOptions: {
openai: {
quality: options.quality || null,
},
},
});
const imageUrls = result.images.map(
image => `data:image/png;base64,${image.base64}`
attachments,
options.signal
);
} else {
const response = await this.requestOpenAIJson('/images/generations', {
model: model.id,
prompt,
...(options.quality ? { quality: options.quality } : {}),
});
const imageResponse = ImageResponseSchema.parse(response);
if ('error' in imageResponse) {
throw new Error(imageResponse.error.message);
}
const imageUrls = normalizeImageResponseData(imageResponse.data);
if (!imageUrls.length) {
throw new Error('No images returned from OpenAI');
}
for (const imageUrl of imageUrls) {
yield imageUrl;
@@ -769,7 +928,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
return;
} catch (e: any) {
metrics.ai.counter('generate_images_errors').add(1, { model: model.id });
throw this.handleError(e, model.id, options);
throw this.handleError(e);
}
}
@@ -783,51 +942,85 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
await this.checkParams({ embeddings: messages, cond: fullCond, options });
const model = this.selectModel(fullCond);
if (!('embedding' in this.#instance)) {
throw new CopilotProviderNotSupported({
provider: this.type,
kind: 'embedding',
});
}
try {
metrics.ai
.counter('generate_embedding_calls')
.add(1, { model: model.id });
const modelInstance = this.#instance.embedding(model.id);
const { embeddings } = await embedMany({
model: modelInstance,
values: messages,
providerOptions: {
openai: {
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
},
},
const response = await this.requestOpenAIJson('/embeddings', {
model: model.id,
input: messages,
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
});
return embeddings.filter(v => v && Array.isArray(v));
const data = Array.isArray(response?.data) ? response.data : [];
return data
.map((item: any) => item?.embedding)
.filter((embedding: unknown) => Array.isArray(embedding)) as number[][];
} catch (e: any) {
metrics.ai
.counter('generate_embedding_errors')
.add(1, { model: model.id });
throw this.handleError(e, model.id, options);
throw this.handleError(e);
}
}
private getOpenAIOptions(options: CopilotChatOptions, model: string) {
const result: OpenAIResponsesProviderOptions = {};
if (options?.reasoning && this.isReasoningModel(model)) {
result.reasoningEffort = 'medium';
result.reasoningSummary = 'detailed';
private toOpenAIChatMessages(
system: string | undefined,
messages: Awaited<ReturnType<typeof chatToGPTMessage>>[1]
) {
const result: Array<{ role: string; content: string }> = [];
if (system) {
result.push({ role: 'system', content: system });
}
if (options?.user) {
result.user = options.user;
for (const message of messages) {
if (typeof message.content === 'string') {
result.push({ role: message.role, content: message.content });
continue;
}
const text = message.content
.filter(
part =>
part &&
typeof part === 'object' &&
'type' in part &&
part.type === 'text' &&
'text' in part
)
.map(part => String((part as { text: string }).text))
.join('\n');
result.push({ role: message.role, content: text || '[no content]' });
}
return result;
}
private async requestOpenAIJson(
path: string,
body: Record<string, unknown>,
signal?: AbortSignal
): Promise<any> {
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
const response = await fetch(`${baseUrl}${path}`, {
method: 'POST',
headers: {
Authorization: `Bearer ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
signal,
});
if (!response.ok) {
throw new Error(
`OpenAI API error ${response.status}: ${await response.text()}`
);
}
return await response.json();
}
private isReasoningModel(model: string) {
// o series reasoning models
return model.startsWith('o') || model.startsWith('gpt-5');

View File

@@ -1,11 +1,13 @@
import {
createPerplexity,
type PerplexityProvider as VercelPerplexityProvider,
} from '@ai-sdk/perplexity';
import { generateText, streamText } from 'ai';
import { z } from 'zod';
import type { ToolSet } from 'ai';
import { CopilotProviderSideError, metrics } from '../../../base';
import {
llmDispatchStream,
type NativeLlmBackendConfig,
type NativeLlmRequest,
} from '../../../native';
import type { NodeTextMiddleware } from '../config';
import { buildNativeRequest, NativeProviderAdapter } from './native';
import { CopilotProvider } from './provider';
import {
CopilotChatOptions,
@@ -15,34 +17,12 @@ import {
ModelOutputType,
PromptMessage,
} from './types';
import { chatToGPTMessage, CitationParser } from './utils';
export type PerplexityConfig = {
apiKey: string;
endpoint?: string;
};
const PerplexityErrorSchema = z.union([
z.object({
detail: z.array(
z.object({
loc: z.array(z.string()),
msg: z.string(),
type: z.string(),
})
),
}),
z.object({
error: z.object({
message: z.string(),
type: z.string(),
code: z.number(),
}),
}),
]);
type PerplexityError = z.infer<typeof PerplexityErrorSchema>;
export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
readonly type = CopilotProviderType.Perplexity;
@@ -90,18 +70,38 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
},
];
#instance!: VercelPerplexityProvider;
override configured(): boolean {
return !!this.config.apiKey;
}
protected override setup() {
super.setup();
this.#instance = createPerplexity({
apiKey: this.config.apiKey,
baseURL: this.config.endpoint,
});
}
private createNativeConfig(): NativeLlmBackendConfig {
const baseUrl = this.config.endpoint || 'https://api.perplexity.ai';
return {
base_url: baseUrl.replace(/\/v1\/?$/, ''),
auth_token: this.config.apiKey,
};
}
private createNativeAdapter(
tools: ToolSet,
nodeTextMiddleware?: NodeTextMiddleware[]
) {
return new NativeProviderAdapter(
(request: NativeLlmRequest, signal?: AbortSignal) =>
llmDispatchStream(
'openai_chat',
this.createNativeConfig(),
request,
signal
),
tools,
this.MAX_STEPS,
{ nodeTextMiddleware }
);
}
async text(
@@ -114,32 +114,25 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
const model = this.selectModel(fullCond);
try {
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
const [system, msgs] = await chatToGPTMessage(messages, false);
const modelInstance = this.#instance(model.id);
const { text, sources } = await generateText({
model: modelInstance,
system,
messages: msgs,
temperature: options.temperature ?? 0,
maxOutputTokens: options.maxTokens ?? 4096,
abortSignal: options.signal,
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
withAttachment: false,
include: ['citations'],
middleware,
});
const parser = new CitationParser();
for (const source of sources.filter(s => s.sourceType === 'url')) {
parser.push(source.url);
}
let result = text.replaceAll(/<\/?think>\n/g, '\n---\n');
result = parser.parse(result);
result += parser.end();
return result;
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
return await adapter.text(request, options.signal);
} catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
metrics.ai
.counter('chat_text_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
@@ -154,79 +147,33 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
const model = this.selectModel(fullCond);
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
metrics.ai
.counter('chat_text_stream_calls')
.add(1, this.metricLabels(model.id));
const [system, msgs] = await chatToGPTMessage(messages, false);
const modelInstance = this.#instance(model.id);
const stream = streamText({
model: modelInstance,
system,
messages: msgs,
temperature: options.temperature ?? 0,
maxOutputTokens: options.maxTokens ?? 4096,
abortSignal: options.signal,
const tools = await this.getTools(options, model.id);
const middleware = this.getActiveProviderMiddleware();
const { request } = await buildNativeRequest({
model: model.id,
messages,
options,
tools,
withAttachment: false,
include: ['citations'],
middleware,
});
const parser = new CitationParser();
for await (const chunk of stream.fullStream) {
switch (chunk.type) {
case 'source': {
if (chunk.sourceType === 'url') {
parser.push(chunk.url);
}
break;
}
case 'text-delta': {
const text = chunk.text.replaceAll(/<\/?think>\n?/g, '\n---\n');
const result = parser.parse(text);
yield result;
break;
}
case 'finish-step': {
const result = parser.end();
yield result;
break;
}
case 'error': {
const json =
typeof chunk.error === 'string'
? JSON.parse(chunk.error)
: chunk.error;
if (json && typeof json === 'object') {
const data = PerplexityErrorSchema.parse(json);
if ('detail' in data || 'error' in data) {
throw this.convertError(data);
}
}
}
}
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
for await (const chunk of adapter.streamText(request, options.signal)) {
yield chunk;
}
} catch (e) {
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
throw e;
} catch (e: any) {
metrics.ai
.counter('chat_text_stream_errors')
.add(1, this.metricLabels(model.id));
throw this.handleError(e);
}
}
private convertError(e: PerplexityError) {
function getErrMessage(e: PerplexityError) {
let err = 'Unexpected perplexity response';
if ('detail' in e) {
err = e.detail[0].msg || err;
} else if ('error' in e) {
err = e.error.message || err;
}
return err;
}
throw new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: getErrMessage(e),
});
}
private handleError(e: any) {
if (e instanceof CopilotProviderSideError) {
return e;

View File

@@ -0,0 +1,98 @@
import type { ProviderMiddlewareConfig } from '../config';
import { CopilotProviderType } from './types';
const DEFAULT_MIDDLEWARE_BY_TYPE: Record<
CopilotProviderType,
ProviderMiddlewareConfig
> = {
[CopilotProviderType.OpenAI]: {
rust: {
request: ['normalize_messages'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: {
text: ['citation_footnote', 'callout'],
},
},
[CopilotProviderType.Anthropic]: {
rust: {
request: ['normalize_messages', 'tool_schema_rewrite'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: {
text: ['citation_footnote', 'callout'],
},
},
[CopilotProviderType.AnthropicVertex]: {
rust: {
request: ['normalize_messages', 'tool_schema_rewrite'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: {
text: ['citation_footnote', 'callout'],
},
},
[CopilotProviderType.Morph]: {
rust: {
request: ['clamp_max_tokens'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: {
text: ['citation_footnote', 'callout'],
},
},
[CopilotProviderType.Perplexity]: {
rust: {
request: ['clamp_max_tokens'],
stream: ['stream_event_normalize', 'citation_indexing'],
},
node: {
text: ['citation_footnote', 'callout'],
},
},
[CopilotProviderType.Gemini]: {
node: {
text: ['callout'],
},
},
[CopilotProviderType.GeminiVertex]: {
node: {
text: ['callout'],
},
},
[CopilotProviderType.FAL]: {},
};
function unique<T>(items: T[]) {
return [...new Set(items)];
}
function mergeArray<T>(base: T[] | undefined, override: T[] | undefined) {
if (!base?.length && !override?.length) {
return undefined;
}
return unique([...(base ?? []), ...(override ?? [])]);
}
export function mergeProviderMiddleware(
defaults: ProviderMiddlewareConfig,
override?: ProviderMiddlewareConfig
): ProviderMiddlewareConfig {
return {
rust: {
request: mergeArray(defaults.rust?.request, override?.rust?.request),
stream: mergeArray(defaults.rust?.stream, override?.rust?.stream),
},
node: {
text: mergeArray(defaults.node?.text, override?.node?.text),
},
};
}
export function resolveProviderMiddleware(
type: CopilotProviderType,
override?: ProviderMiddlewareConfig
): ProviderMiddlewareConfig {
const defaults = DEFAULT_MIDDLEWARE_BY_TYPE[type] ?? {};
return mergeProviderMiddleware(defaults, override);
}

View File

@@ -0,0 +1,273 @@
import type {
CopilotProviderConfigMap,
CopilotProviderDefaults,
CopilotProviderProfile,
ProviderMiddlewareConfig,
} from '../config';
import { resolveProviderMiddleware } from './provider-middleware';
import { CopilotProviderType, type ModelOutputType } from './types';
const PROVIDER_ID_PATTERN = /^[a-zA-Z0-9-_]+$/;
const LEGACY_PROVIDER_ORDER: CopilotProviderType[] = [
CopilotProviderType.OpenAI,
CopilotProviderType.FAL,
CopilotProviderType.Gemini,
CopilotProviderType.GeminiVertex,
CopilotProviderType.Perplexity,
CopilotProviderType.Anthropic,
CopilotProviderType.AnthropicVertex,
CopilotProviderType.Morph,
];
const LEGACY_PROVIDER_PRIORITY = LEGACY_PROVIDER_ORDER.reduce(
(acc, type, index) => {
acc[type] = LEGACY_PROVIDER_ORDER.length - index;
return acc;
},
{} as Record<CopilotProviderType, number>
);
type LegacyProvidersConfig = Partial<
Record<CopilotProviderType, CopilotProviderConfigMap[CopilotProviderType]>
>;
export type CopilotProvidersConfigInput = LegacyProvidersConfig & {
profiles?: CopilotProviderProfile[] | null;
defaults?: CopilotProviderDefaults | null;
};
export type NormalizedCopilotProviderProfile = Omit<
CopilotProviderProfile,
'enabled' | 'priority' | 'middleware'
> & {
enabled: boolean;
priority: number;
middleware: ProviderMiddlewareConfig;
};
export type CopilotProviderRegistry = {
profiles: Map<string, NormalizedCopilotProviderProfile>;
defaults: CopilotProviderDefaults;
order: string[];
byType: Map<CopilotProviderType, string[]>;
};
export type ResolveModelResult = {
rawModelId?: string;
modelId?: string;
explicitProviderId?: string;
candidateProviderIds: string[];
};
type ResolveModelOptions = {
registry: CopilotProviderRegistry;
modelId?: string;
outputType?: ModelOutputType;
availableProviderIds?: Iterable<string>;
preferredProviderIds?: Iterable<string>;
};
function unique<T>(list: T[]): T[] {
return [...new Set(list)];
}
function asArray<T>(iter?: Iterable<T>): T[] {
return iter ? Array.from(iter) : [];
}
function parseModelPrefix(
registry: CopilotProviderRegistry,
modelId: string
): { providerId: string; modelId?: string } | null {
const index = modelId.indexOf('/');
if (index <= 0) {
return null;
}
const providerId = modelId.slice(0, index);
if (!registry.profiles.has(providerId)) {
return null;
}
const model = modelId.slice(index + 1);
return { providerId, modelId: model || undefined };
}
function normalizeProfile(
profile: CopilotProviderProfile
): NormalizedCopilotProviderProfile {
return {
...profile,
enabled: profile.enabled !== false,
priority: profile.priority ?? 0,
middleware: resolveProviderMiddleware(profile.type, profile.middleware),
};
}
function toLegacyProfiles(
config: CopilotProvidersConfigInput
): CopilotProviderProfile[] {
const legacyProfiles: CopilotProviderProfile[] = [];
for (const type of LEGACY_PROVIDER_ORDER) {
const legacyConfig = config[type];
if (!legacyConfig) {
continue;
}
legacyProfiles.push({
id: `${type}-default`,
type,
priority: LEGACY_PROVIDER_PRIORITY[type],
config: legacyConfig,
} as CopilotProviderProfile);
}
return legacyProfiles;
}
function mergeProfiles(
explicitProfiles: CopilotProviderProfile[],
legacyProfiles: CopilotProviderProfile[]
): CopilotProviderProfile[] {
const profiles = new Map<string, CopilotProviderProfile>();
for (const profile of explicitProfiles) {
if (!PROVIDER_ID_PATTERN.test(profile.id)) {
throw new Error(`Invalid copilot provider profile id: ${profile.id}`);
}
if (profiles.has(profile.id)) {
throw new Error(`Duplicated copilot provider profile id: ${profile.id}`);
}
profiles.set(profile.id, profile);
}
for (const profile of legacyProfiles) {
if (!profiles.has(profile.id)) {
profiles.set(profile.id, profile);
}
}
return Array.from(profiles.values());
}
function sortProfiles(profiles: NormalizedCopilotProviderProfile[]) {
return profiles.toSorted((a, b) => {
if (a.priority !== b.priority) {
return b.priority - a.priority;
}
return a.id.localeCompare(b.id);
});
}
function assertDefaults(
defaults: CopilotProviderDefaults,
profiles: Map<string, NormalizedCopilotProviderProfile>
) {
for (const providerId of Object.values(defaults)) {
if (!providerId) {
continue;
}
if (!profiles.has(providerId)) {
throw new Error(
`Copilot provider defaults references unknown providerId: ${providerId}`
);
}
}
}
export function buildProviderRegistry(
config: CopilotProvidersConfigInput
): CopilotProviderRegistry {
const explicitProfiles = config.profiles ?? [];
const legacyProfiles = toLegacyProfiles(config);
const mergedProfiles = mergeProfiles(explicitProfiles, legacyProfiles)
.map(normalizeProfile)
.filter(profile => profile.enabled);
const sortedProfiles = sortProfiles(mergedProfiles);
const profiles = new Map(
sortedProfiles.map(profile => [profile.id, profile] as const)
);
const defaults = config.defaults ?? {};
assertDefaults(defaults, profiles);
const order = sortedProfiles.map(profile => profile.id);
const byType = new Map<CopilotProviderType, string[]>();
for (const profile of sortedProfiles) {
const ids = byType.get(profile.type) ?? [];
ids.push(profile.id);
byType.set(profile.type, ids);
}
return { profiles, defaults, order, byType };
}
export function resolveModel({
registry,
modelId,
outputType,
availableProviderIds,
preferredProviderIds,
}: ResolveModelOptions): ResolveModelResult {
const available = new Set(asArray(availableProviderIds));
const preferred = new Set(asArray(preferredProviderIds));
const hasAvailableFilter = available.size > 0;
const hasPreferredFilter = preferred.size > 0;
const isAllowed = (providerId: string) => {
const profile = registry.profiles.get(providerId);
if (!profile?.enabled) {
return false;
}
if (hasAvailableFilter && !available.has(providerId)) {
return false;
}
if (hasPreferredFilter && !preferred.has(providerId)) {
return false;
}
return true;
};
const prefixed = modelId ? parseModelPrefix(registry, modelId) : null;
if (prefixed) {
return {
rawModelId: modelId,
modelId: prefixed.modelId,
explicitProviderId: prefixed.providerId,
candidateProviderIds: isAllowed(prefixed.providerId)
? [prefixed.providerId]
: [],
};
}
const fallbackOrder = [
...(outputType ? [registry.defaults[outputType]] : []),
registry.defaults.fallback,
...registry.order,
].filter((id): id is string => !!id);
return {
rawModelId: modelId,
modelId,
candidateProviderIds: unique(
fallbackOrder.filter(providerId => isAllowed(providerId))
),
};
}
export function stripProviderPrefix(
registry: CopilotProviderRegistry,
providerId: string,
modelId?: string
) {
if (!modelId) {
return modelId;
}
const prefixed = parseModelPrefix(registry, modelId);
if (!prefixed) {
return modelId;
}
if (prefixed.providerId !== providerId) {
return modelId;
}
return prefixed.modelId;
}

View File

@@ -1,3 +1,5 @@
import { AsyncLocalStorage } from 'node:async_hooks';
import { Inject, Injectable, Logger } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { Tool, ToolSet } from 'ai';
@@ -13,6 +15,7 @@ import { DocReader, DocWriter } from '../../../core/doc';
import { AccessController } from '../../../core/permission';
import { Models } from '../../../models';
import { IndexerService } from '../../indexer';
import type { ProviderMiddlewareConfig } from '../config';
import { CopilotContextService } from '../context/service';
import { PromptService } from '../prompt/service';
import {
@@ -40,6 +43,8 @@ import {
createSectionEditTool,
} from '../tools';
import { CopilotProviderFactory } from './factory';
import { resolveProviderMiddleware } from './provider-middleware';
import { buildProviderRegistry } from './provider-registry';
import {
type CopilotChatOptions,
CopilotChatTools,
@@ -58,11 +63,14 @@ import {
StreamObject,
} from './types';
const providerProfileContext = new AsyncLocalStorage<string>();
@Injectable()
export abstract class CopilotProvider<C = any> {
protected readonly logger = new Logger(this.constructor.name);
protected readonly MAX_STEPS = 20;
protected onlineModelList: string[] = [];
abstract readonly type: CopilotProviderType;
abstract readonly models: CopilotProviderModel[];
abstract configured(): boolean;
@@ -70,8 +78,39 @@ export abstract class CopilotProvider<C = any> {
@Inject() protected readonly AFFiNEConfig!: Config;
@Inject() protected readonly factory!: CopilotProviderFactory;
@Inject() protected readonly moduleRef!: ModuleRef;
readonly #registeredProviderIds = new Set<string>();
runWithProfile<T>(providerId: string, callback: () => T): T {
return providerProfileContext.run(providerId, callback);
}
protected getActiveProviderId() {
return providerProfileContext.getStore() ?? `${this.type}-default`;
}
protected getActiveProviderMiddleware(): ProviderMiddlewareConfig {
const providerId = this.getActiveProviderId();
const registry = buildProviderRegistry(this.AFFiNEConfig.copilot.providers);
const profile = registry.profiles.get(providerId);
return profile?.middleware ?? resolveProviderMiddleware(this.type);
}
protected metricLabels(
model: string,
labels: Record<string, string | number | boolean | undefined> = {}
) {
const providerId = this.getActiveProviderId();
return { model, providerId, ...labels };
}
get config(): C {
const profileId = providerProfileContext.getStore();
if (profileId) {
const profile = this.AFFiNEConfig.copilot.providers.profiles?.find(
profile => profile.id === profileId && profile.type === this.type
);
if (profile) return profile.config as C;
}
return this.AFFiNEConfig.copilot.providers[this.type] as C;
}
@@ -88,15 +127,37 @@ export abstract class CopilotProvider<C = any> {
}
protected setup() {
if (this.configured()) {
this.factory.register(this);
if (env.selfhosted) {
const registry = buildProviderRegistry(this.AFFiNEConfig.copilot.providers);
const providerIds = registry.byType.get(this.type) ?? [];
const nextProviderIds = new Set<string>();
for (const id of providerIds) {
const configured = this.runWithProfile(id, () => this.configured());
if (configured) {
nextProviderIds.add(id);
this.factory.register(id, this);
} else {
this.factory.unregister(id, this);
}
}
for (const providerId of this.#registeredProviderIds) {
if (!nextProviderIds.has(providerId)) {
this.factory.unregister(providerId, this);
}
}
this.#registeredProviderIds.clear();
for (const providerId of nextProviderIds) {
this.#registeredProviderIds.add(providerId);
}
if (env.selfhosted && nextProviderIds.size > 0) {
const [providerId] = Array.from(nextProviderIds);
this.runWithProfile(providerId, () => {
this.refreshOnlineModels().catch(e =>
this.logger.error('Failed to refresh online models', e)
);
}
} else {
this.factory.unregister(this);
});
}
}

View File

@@ -91,7 +91,9 @@ export async function chatToGPTMessage(
// so we need to use base64 encoded attachments instead
useBase64Attachment: boolean = false
): Promise<[string | undefined, ChatMessage[], ZodType?]> {
const system = messages[0]?.role === 'system' ? messages.shift() : undefined;
const hasSystem = messages[0]?.role === 'system';
const system = hasSystem ? messages[0] : undefined;
const normalizedMessages = hasSystem ? messages.slice(1) : messages;
const schema =
system?.params?.schema && system.params.schema instanceof ZodType
? system.params.schema
@@ -99,7 +101,7 @@ export async function chatToGPTMessage(
// filter redundant fields
const msgs: ChatMessage[] = [];
for (let { role, content, attachments, params } of messages.filter(
for (let { role, content, attachments, params } of normalizedMessages.filter(
m => m.role !== 'system'
)) {
content = content.trim();
@@ -406,6 +408,34 @@ export class CitationParser {
}
}
export type CitationIndexedEvent = {
type: 'citation';
index: number;
url: string;
};
export class CitationFootnoteFormatter {
private readonly citations = new Map<number, string>();
public consume(event: CitationIndexedEvent) {
if (event.type !== 'citation') {
return '';
}
this.citations.set(event.index, event.url);
return '';
}
public end() {
const footnotes = Array.from(this.citations.entries())
.sort((a, b) => a[0] - b[0])
.map(
([index, citation]) =>
`[^${index}]: {"type":"url","url":"${encodeURIComponent(citation)}"}`
);
return footnotes.join('\n');
}
}
type ChunkType = TextStreamPart<CustomAITools>['type'];
export function toError(error: unknown): Error {
@@ -703,21 +733,39 @@ export const VertexModelListSchema = z.object({
),
});
function normalizeUrl(baseURL?: string) {
if (!baseURL?.trim()) {
return undefined;
}
try {
const url = new URL(baseURL);
const serialized = url.toString();
if (serialized.endsWith('/')) return serialized.slice(0, -1);
return serialized;
} catch {
return undefined;
}
}
export function getVertexAnthropicBaseUrl(
options: GoogleVertexAnthropicProviderSettings
) {
const normalizedBaseUrl = normalizeUrl(options.baseURL);
if (normalizedBaseUrl) return normalizedBaseUrl;
const { location, project } = options;
if (!location || !project) return undefined;
return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/anthropic`;
}
export async function getGoogleAuth(
options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings,
publisher: 'anthropic' | 'google'
) {
function getBaseUrl() {
const { baseURL, location } = options;
if (baseURL?.trim()) {
try {
const url = new URL(baseURL);
if (url.pathname.endsWith('/')) {
url.pathname = url.pathname.slice(0, -1);
}
return url.toString();
} catch {}
} else if (location) {
const normalizedBaseUrl = normalizeUrl(options.baseURL);
if (normalizedBaseUrl) return normalizedBaseUrl;
const { location } = options;
if (location) {
return `https://${location}-aiplatform.googleapis.com/v1beta1/publishers/${publisher}`;
}
return undefined;

View File

@@ -4,7 +4,6 @@ import { BadRequestException, NotFoundException } from '@nestjs/common';
import {
Args,
Field,
Float,
ID,
InputType,
Mutation,
@@ -15,7 +14,6 @@ import {
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { AiPromptRole } from '@prisma/client';
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
@@ -313,57 +311,6 @@ class CopilotQuotaType {
used!: number;
}
registerEnumType(AiPromptRole, {
name: 'CopilotPromptMessageRole',
});
@InputType('CopilotPromptConfigInput')
@ObjectType()
class CopilotPromptConfigType {
@Field(() => Float, { nullable: true })
frequencyPenalty!: number | null;
@Field(() => Float, { nullable: true })
presencePenalty!: number | null;
@Field(() => Float, { nullable: true })
temperature!: number | null;
@Field(() => Float, { nullable: true })
topP!: number | null;
}
@InputType('CopilotPromptMessageInput')
@ObjectType()
class CopilotPromptMessageType {
@Field(() => AiPromptRole)
role!: AiPromptRole;
@Field(() => String)
content!: string;
@Field(() => GraphQLJSON, { nullable: true })
params!: Record<string, string> | null;
}
@ObjectType()
class CopilotPromptType {
@Field(() => String)
name!: string;
@Field(() => String)
model!: string;
@Field(() => String, { nullable: true })
action!: string | null;
@Field(() => CopilotPromptConfigType, { nullable: true })
config!: CopilotPromptConfigType | null;
@Field(() => [CopilotPromptMessageType])
messages!: CopilotPromptMessageType[];
}
@ObjectType()
class CopilotModelType {
@Field(() => String)
@@ -638,13 +585,8 @@ export class CopilotResolver {
);
}
@Mutation(() => String, {
description: 'Create a chat session',
})
@CallMetric('ai', 'chat_session_create')
async createCopilotSession(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => CreateChatSessionInput })
private async createCopilotSessionInternal(
user: CurrentUser,
options: CreateChatSessionInput
): Promise<string> {
// permission check based on session type
@@ -666,6 +608,42 @@ export class CopilotResolver {
});
}
@Mutation(() => String, {
description: 'Create a chat session',
deprecationReason: 'use `createCopilotSessionWithHistory` instead',
})
@CallMetric('ai', 'chat_session_create')
async createCopilotSession(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => CreateChatSessionInput })
options: CreateChatSessionInput
): Promise<string> {
return await this.createCopilotSessionInternal(user, options);
}
@Mutation(() => CopilotHistoriesType, {
description: 'Create a chat session and return full session payload',
})
@CallMetric('ai', 'chat_session_create_with_history')
async createCopilotSessionWithHistory(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => CreateChatSessionInput })
options: CreateChatSessionInput
): Promise<CopilotHistoriesType> {
const sessionId = await this.createCopilotSessionInternal(user, options);
const session = await this.chatSession.getSessionInfo(sessionId);
if (!session) {
throw new NotFoundException('Session not found');
}
return {
...session,
messages: session.messages.map(message => ({
...message,
id: message.id,
})) as ChatMessageType[],
};
}
@Mutation(() => String, {
description: 'Update a chat session',
})
@@ -939,31 +917,10 @@ export class UserCopilotResolver {
}
}
@InputType()
class CreateCopilotPromptInput {
@Field(() => String)
name!: string;
@Field(() => String)
model!: string;
@Field(() => String, { nullable: true })
action!: string | null;
@Field(() => CopilotPromptConfigType, { nullable: true })
config!: CopilotPromptConfigType | null;
@Field(() => [CopilotPromptMessageType])
messages!: CopilotPromptMessageType[];
}
@Admin()
@Resolver(() => String)
export class PromptsManagementResolver {
constructor(
private readonly cron: CopilotCronJobs,
private readonly promptService: PromptService
) {}
constructor(private readonly cron: CopilotCronJobs) {}
@Mutation(() => Boolean, {
description: 'Trigger generate missing titles cron job',
@@ -980,48 +937,4 @@ export class PromptsManagementResolver {
await this.cron.triggerCleanupTrashedDocEmbeddings();
return true;
}
@Query(() => [CopilotPromptType], {
description: 'List all copilot prompts',
})
async listCopilotPrompts() {
const prompts = await this.promptService.list();
return prompts.filter(
p =>
p.messages.length > 0 &&
// ignore internal prompts
!p.name.startsWith('workflow:') &&
!p.name.startsWith('debug:') &&
!p.name.startsWith('chat:') &&
!p.name.startsWith('action:')
);
}
@Mutation(() => CopilotPromptType, {
description: 'Create a copilot prompt',
})
async createCopilotPrompt(
@Args({ type: () => CreateCopilotPromptInput, name: 'input' })
input: CreateCopilotPromptInput
) {
await this.promptService.set(
input.name,
input.model,
input.messages,
input.config
);
return this.promptService.get(input.name);
}
@Mutation(() => CopilotPromptType, {
description: 'Update a copilot prompt',
})
async updateCopilotPrompt(
@Args('name') name: string,
@Args('messages', { type: () => [CopilotPromptMessageType] })
messages: CopilotPromptMessageType[]
) {
await this.promptService.update(name, { messages, modified: true });
return this.promptService.get(name);
}
}

View File

@@ -7,6 +7,7 @@ import { AiPromptRole } from '@prisma/client';
import { pick } from 'lodash-es';
import {
Config,
CopilotActionTaken,
CopilotMessageNotFound,
CopilotPromptNotFound,
@@ -31,6 +32,7 @@ import { ChatMessageCache } from './message';
import { ChatPrompt } from './prompt/chat-prompt';
import { PromptService } from './prompt/service';
import { CopilotProviderFactory } from './providers/factory';
import { buildProviderRegistry } from './providers/provider-registry';
import {
ModelOutputType,
type PromptMessage,
@@ -105,10 +107,31 @@ export class ChatSession implements AsyncDisposable {
hasPayment: boolean,
requestedModelId?: string
): Promise<string> {
const config = this.moduleRef.get(Config, { strict: false });
const registry = config
? buildProviderRegistry(config.copilot.providers)
: null;
const defaultModel = this.model;
const normalize = (m?: string) =>
!!m && this.optionalModels.includes(m) ? m : defaultModel;
const isPro = (m?: string) => !!m && this.proModels.includes(m);
const normalizeModel = (modelId?: string) => {
if (!modelId) return modelId;
const separatorIndex = modelId.indexOf('/');
if (separatorIndex <= 0) return modelId;
const providerId = modelId.slice(0, separatorIndex);
if (!registry?.profiles.has(providerId)) return modelId;
return modelId.slice(separatorIndex + 1);
};
const inModelList = (models: string[], modelId?: string) => {
if (!modelId) return false;
return (
models.includes(modelId) ||
models.includes(normalizeModel(modelId) ?? '')
);
};
const normalize = (m?: string) => {
if (inModelList(this.optionalModels, m)) return m;
return defaultModel;
};
const isPro = (m?: string) => inModelList(this.proModels, m);
// try resolve payment subscription service lazily
let paymentEnabled = hasPayment;
@@ -132,10 +155,19 @@ export class ChatSession implements AsyncDisposable {
}
if (paymentEnabled && !isUserAIPro && isPro(requestedModelId)) {
if (!defaultModel) {
throw new CopilotSessionInvalidInput(
'Model is required for AI subscription fallback'
);
}
return defaultModel;
}
return normalize(requestedModelId);
const resolvedModel = normalize(requestedModelId);
if (!resolvedModel) {
throw new CopilotSessionInvalidInput('Model is required');
}
return resolvedModel;
}
push(message: ChatMessage) {

View File

@@ -32,16 +32,22 @@ export const buildBlobContentGetter = (
return;
}
const contextFile = context.files.find(
file => file.blobId === blobId || file.id === blobId
);
const canonicalBlobId = contextFile?.blobId ?? blobId;
const targetFileId = contextFile?.id;
const [file, blob] = await Promise.all([
context?.getFileContent(blobId, chunk),
context?.getBlobContent(blobId, chunk),
targetFileId ? context.getFileContent(targetFileId, chunk) : undefined,
context.getBlobContent(canonicalBlobId, chunk),
]);
const content = file?.trim() || blob?.trim();
if (!content) {
return;
}
if (!content) return;
const info = contextFile
? { fileName: contextFile.name, fileType: contextFile.mimeType }
: {};
return { blobId, chunk, content };
return { blobId: canonicalBlobId, chunk, content, ...info };
};
return getBlobContent;
};

View File

@@ -14,6 +14,7 @@ import type {
import { HTMLRewriter } from 'htmlrewriter';
import {
applyAttachHeaders,
BadRequest,
Cache,
readResponseBufferWithLimit,
@@ -127,15 +128,18 @@ export class WorkerController {
if (buffer.length === 0) {
return resp.status(404).header(getCorsHeaders(origin)).send();
}
return resp
.status(200)
.header({
...getCorsHeaders(origin),
...(origin ? { Vary: 'Origin' } : {}),
'Access-Control-Allow-Methods': 'GET',
'Content-Type': 'image/*',
})
.send(buffer);
resp.header({
...getCorsHeaders(origin),
...(origin ? { Vary: 'Origin' } : {}),
'Access-Control-Allow-Methods': 'GET',
});
applyAttachHeaders(resp, { buffer });
const contentType = resp.getHeader('Content-Type') as string | undefined;
if (contentType?.startsWith('image/')) {
return resp.status(200).send(buffer);
} else {
throw new BadRequest('Invalid content type');
}
}
let response: Response;
@@ -171,39 +175,39 @@ export class WorkerController {
throw new BadRequest('Failed to fetch image');
}
if (response.ok) {
const contentType = response.headers.get('Content-Type');
if (contentType?.startsWith('image/')) {
let buffer: Buffer;
try {
buffer = await readResponseBufferWithLimit(
response,
IMAGE_PROXY_MAX_BYTES
);
} catch (error) {
if (error instanceof ResponseTooLargeError) {
this.logger.warn('Image proxy response too large', {
url: imageURL,
limitBytes: error.data?.limitBytes,
receivedBytes: error.data?.receivedBytes,
});
throw new BadRequest('Response too large');
}
throw error;
let buffer: Buffer;
try {
buffer = await readResponseBufferWithLimit(
response,
IMAGE_PROXY_MAX_BYTES
);
} catch (error) {
if (error instanceof ResponseTooLargeError) {
this.logger.warn('Image proxy response too large', {
url: imageURL,
limitBytes: error.data?.limitBytes,
receivedBytes: error.data?.receivedBytes,
});
throw new BadRequest('Response too large');
}
await this.cache.set(cachedUrl, buffer.toString('base64'), {
ttl: CACHE_TTL,
});
const contentDisposition = response.headers.get('Content-Disposition');
return resp
.status(200)
.header({
...getCorsHeaders(origin),
...(origin ? { Vary: 'Origin' } : {}),
'Access-Control-Allow-Methods': 'GET',
'Content-Type': contentType,
'Content-Disposition': contentDisposition,
})
.send(buffer);
throw error;
}
await this.cache.set(cachedUrl, buffer.toString('base64'), {
ttl: CACHE_TTL,
});
const contentDisposition = response.headers.get('Content-Disposition');
resp.header({
...getCorsHeaders(origin),
...(origin ? { Vary: 'Origin' } : {}),
'Access-Control-Allow-Methods': 'GET',
});
if (contentDisposition) {
resp.setHeader('Content-Disposition', contentDisposition);
}
applyAttachHeaders(resp, { buffer });
const contentType = resp.getHeader('Content-Type') as string | undefined;
if (contentType?.startsWith('image/')) {
return resp.status(200).send(buffer);
} else {
throw new BadRequest('Invalid content type');
}

View File

@@ -607,50 +607,10 @@ type CopilotModelsType {
proModels: [CopilotModelType!]!
}
input CopilotPromptConfigInput {
frequencyPenalty: Float
presencePenalty: Float
temperature: Float
topP: Float
}
type CopilotPromptConfigType {
frequencyPenalty: Float
presencePenalty: Float
temperature: Float
topP: Float
}
input CopilotPromptMessageInput {
content: String!
params: JSON
role: CopilotPromptMessageRole!
}
enum CopilotPromptMessageRole {
assistant
system
user
}
type CopilotPromptMessageType {
content: String!
params: JSON
role: CopilotPromptMessageRole!
}
type CopilotPromptNotFoundDataType {
name: String!
}
type CopilotPromptType {
action: String
config: CopilotPromptConfigType
messages: [CopilotPromptMessageType!]!
model: String!
name: String!
}
type CopilotProviderNotSupportedDataType {
kind: String!
provider: String!
@@ -747,14 +707,6 @@ input CreateCheckoutSessionInput {
variant: SubscriptionVariant
}
input CreateCopilotPromptInput {
action: String
config: CopilotPromptConfigInput
messages: [CopilotPromptMessageInput!]!
model: String!
name: String!
}
input CreateUserInput {
email: String!
name: String
@@ -1551,11 +1503,11 @@ type Mutation {
"""Create a chat message"""
createCopilotMessage(options: CreateChatMessageInput!): String!
"""Create a copilot prompt"""
createCopilotPrompt(input: CreateCopilotPromptInput!): CopilotPromptType!
"""Create a chat session"""
createCopilotSession(options: CreateChatSessionInput!): String!
createCopilotSession(options: CreateChatSessionInput!): String! @deprecated(reason: "use `createCopilotSessionWithHistory` instead")
"""Create a chat session and return full session payload"""
createCopilotSessionWithHistory(options: CreateChatSessionInput!): CopilotHistories!
"""Create a stripe customer portal to manage payment methods"""
createCustomerPortal: String!
@@ -1672,9 +1624,6 @@ type Mutation {
"""Update a comment content"""
updateComment(input: CommentUpdateInput!): Boolean!
"""Update a copilot prompt"""
updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType!
"""Update a chat session"""
updateCopilotSession(options: UpdateChatSessionInput!): String!
updateDocDefaultRole(input: UpdateDocDefaultRoleInput!): Boolean!
@@ -1923,9 +1872,6 @@ type Query {
"""get workspace invitation info"""
getInviteInfo(inviteId: String!): InvitationType!
"""List all copilot prompts"""
listCopilotPrompts: [CopilotPromptType!]!
prices: [SubscriptionPrice!]!
"""Get public user by id"""

View File

@@ -1,18 +0,0 @@
query getPrompts {
listCopilotPrompts {
name
model
action
config {
frequencyPenalty
presencePenalty
temperature
topP
}
messages {
role
content
params
}
}
}

View File

@@ -1,21 +0,0 @@
mutation updatePrompt(
$name: String!
$messages: [CopilotPromptMessageInput!]!
) {
updateCopilotPrompt(name: $name, messages: $messages) {
name
model
action
config {
frequencyPenalty
presencePenalty
temperature
topP
}
messages {
role
content
params
}
}
}

View File

@@ -1,4 +1,4 @@
#import "./fragments/copilot.gql"
#import "./fragments/paginated-copilot-chats.gql"
query getCopilotDocSessions(
$workspaceId: String!

View File

@@ -1,4 +1,4 @@
#import "./fragments/copilot.gql"
#import "./fragments/paginated-copilot-chats.gql"
query getCopilotPinnedSessions(
$workspaceId: String!

View File

@@ -1,4 +1,4 @@
#import "./fragments/copilot.gql"
#import "./fragments/paginated-copilot-chats.gql"
query getCopilotWorkspaceSessions(
$workspaceId: String!

View File

@@ -1,4 +1,4 @@
#import "./fragments/copilot.gql"
#import "./fragments/paginated-copilot-chats.gql"
query getCopilotHistories(
$workspaceId: String!

View File

@@ -0,0 +1,7 @@
#import "./fragments/copilot-chat-history.gql"
mutation createCopilotSessionWithHistory($options: CreateChatSessionInput!) {
createCopilotSessionWithHistory(options: $options) {
...CopilotChatHistory
}
}

View File

@@ -1,4 +1,4 @@
#import "./fragments/copilot.gql"
#import "./fragments/paginated-copilot-chats.gql"
query getCopilotLatestDocSession(
$workspaceId: String!

View File

@@ -1,4 +1,4 @@
#import "./fragments/copilot.gql"
#import "./fragments/paginated-copilot-chats.gql"
query getCopilotSession(
$workspaceId: String!

View File

@@ -1,4 +1,4 @@
#import "./fragments/copilot.gql"
#import "./fragments/paginated-copilot-chats.gql"
query getCopilotRecentSessions(
$workspaceId: String!

View File

@@ -1,4 +1,4 @@
#import "./fragments/copilot.gql"
#import "./fragments/paginated-copilot-chats.gql"
query getCopilotSessions(
$workspaceId: String!

View File

@@ -0,0 +1,30 @@
fragment CopilotChatHistory on CopilotHistories {
sessionId
workspaceId
docId
parentSessionId
promptName
model
optionalModels
action
pinned
title
tokens
messages {
id
role
content
attachments
streamObjects {
type
textDelta
toolCallId
toolName
args
result
}
createdAt
}
createdAt
updatedAt
}

View File

@@ -1,49 +0,0 @@
fragment CopilotChatMessage on ChatMessage {
id
role
content
attachments
streamObjects {
type
textDelta
toolCallId
toolName
args
result
}
createdAt
}
fragment CopilotChatHistory on CopilotHistories {
sessionId
workspaceId
docId
parentSessionId
promptName
model
optionalModels
action
pinned
title
tokens
messages {
...CopilotChatMessage
}
createdAt
updatedAt
}
fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
pageInfo {
hasNextPage
hasPreviousPage
startCursor
endCursor
}
edges {
cursor
node {
...CopilotChatHistory
}
}
}

View File

@@ -0,0 +1,16 @@
#import "./copilot-chat-history.gql"
fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
pageInfo {
hasNextPage
hasPreviousPage
startCursor
endCursor
}
edges {
cursor
node {
...CopilotChatHistory
}
}
}

View File

@@ -6,21 +6,6 @@ export interface GraphQLQuery {
file?: boolean;
deprecations?: string[];
}
export const copilotChatMessageFragment = `fragment CopilotChatMessage on ChatMessage {
id
role
content
attachments
streamObjects {
type
textDelta
toolCallId
toolName
args
result
}
createdAt
}`;
export const copilotChatHistoryFragment = `fragment CopilotChatHistory on CopilotHistories {
sessionId
workspaceId
@@ -34,25 +19,23 @@ export const copilotChatHistoryFragment = `fragment CopilotChatHistory on Copilo
title
tokens
messages {
...CopilotChatMessage
id
role
content
attachments
streamObjects {
type
textDelta
toolCallId
toolName
args
result
}
createdAt
}
createdAt
updatedAt
}`;
export const paginatedCopilotChatsFragment = `fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
pageInfo {
hasNextPage
hasPreviousPage
startCursor
endCursor
}
edges {
cursor
node {
...CopilotChatHistory
}
}
}`;
export const credentialsRequirementsFragment = `fragment CredentialsRequirements on CredentialsRequirementType {
password {
...PasswordLimits
@@ -94,6 +77,20 @@ export const currentUserProfileFragment = `fragment CurrentUserProfile on UserTy
}
}
}`;
export const paginatedCopilotChatsFragment = `fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
pageInfo {
hasNextPage
hasPreviousPage
startCursor
endCursor
}
edges {
cursor
node {
...CopilotChatHistory
}
}
}${copilotChatHistoryFragment}`;
export const passwordLimitsFragment = `fragment PasswordLimits on PasswordLimitsType {
minLength
maxLength
@@ -404,52 +401,6 @@ export const appConfigQuery = {
}`,
};
export const getPromptsQuery = {
id: 'getPromptsQuery' as const,
op: 'getPrompts',
query: `query getPrompts {
listCopilotPrompts {
name
model
action
config {
frequencyPenalty
presencePenalty
temperature
topP
}
messages {
role
content
params
}
}
}`,
};
export const updatePromptMutation = {
id: 'updatePromptMutation' as const,
op: 'updatePrompt',
query: `mutation updatePrompt($name: String!, $messages: [CopilotPromptMessageInput!]!) {
updateCopilotPrompt(name: $name, messages: $messages) {
name
model
action
config {
frequencyPenalty
presencePenalty
temperature
topP
}
messages {
role
content
params
}
}
}`,
};
export const createUserMutation = {
id: 'createUserMutation' as const,
op: 'createUser',
@@ -1411,8 +1362,6 @@ export const getCopilotDocSessionsQuery = {
}
}
}
${copilotChatMessageFragment}
${copilotChatHistoryFragment}
${paginatedCopilotChatsFragment}`,
};
@@ -1432,8 +1381,6 @@ export const getCopilotPinnedSessionsQuery = {
}
}
}
${copilotChatMessageFragment}
${copilotChatHistoryFragment}
${paginatedCopilotChatsFragment}`,
};
@@ -1449,8 +1396,6 @@ export const getCopilotWorkspaceSessionsQuery = {
}
}
}
${copilotChatMessageFragment}
${copilotChatHistoryFragment}
${paginatedCopilotChatsFragment}`,
};
@@ -1466,8 +1411,6 @@ export const getCopilotHistoriesQuery = {
}
}
}
${copilotChatMessageFragment}
${copilotChatHistoryFragment}
${paginatedCopilotChatsFragment}`,
};
@@ -1596,12 +1539,24 @@ export const cleanupCopilotSessionMutation = {
}`,
};
export const createCopilotSessionWithHistoryMutation = {
id: 'createCopilotSessionWithHistoryMutation' as const,
op: 'createCopilotSessionWithHistory',
query: `mutation createCopilotSessionWithHistory($options: CreateChatSessionInput!) {
createCopilotSessionWithHistory(options: $options) {
...CopilotChatHistory
}
}
${copilotChatHistoryFragment}`,
};
export const createCopilotSessionMutation = {
id: 'createCopilotSessionMutation' as const,
op: 'createCopilotSession',
query: `mutation createCopilotSession($options: CreateChatSessionInput!) {
createCopilotSession(options: $options)
}`,
deprecations: ["'createCopilotSession' is deprecated: use `createCopilotSessionWithHistory` instead"],
};
export const forkCopilotSessionMutation = {
@@ -1628,8 +1583,6 @@ export const getCopilotLatestDocSessionQuery = {
}
}
}
${copilotChatMessageFragment}
${copilotChatHistoryFragment}
${paginatedCopilotChatsFragment}`,
};
@@ -1645,8 +1598,6 @@ export const getCopilotSessionQuery = {
}
}
}
${copilotChatMessageFragment}
${copilotChatHistoryFragment}
${paginatedCopilotChatsFragment}`,
};
@@ -1665,8 +1616,6 @@ export const getCopilotRecentSessionsQuery = {
}
}
}
${copilotChatMessageFragment}
${copilotChatHistoryFragment}
${paginatedCopilotChatsFragment}`,
};
@@ -1690,8 +1639,6 @@ export const getCopilotSessionsQuery = {
}
}
}
${copilotChatMessageFragment}
${copilotChatHistoryFragment}
${paginatedCopilotChatsFragment}`,
};

View File

@@ -725,54 +725,11 @@ export interface CopilotModelsType {
proModels: Array<CopilotModelType>;
}
export interface CopilotPromptConfigInput {
frequencyPenalty?: InputMaybe<Scalars['Float']['input']>;
presencePenalty?: InputMaybe<Scalars['Float']['input']>;
temperature?: InputMaybe<Scalars['Float']['input']>;
topP?: InputMaybe<Scalars['Float']['input']>;
}
export interface CopilotPromptConfigType {
__typename?: 'CopilotPromptConfigType';
frequencyPenalty: Maybe<Scalars['Float']['output']>;
presencePenalty: Maybe<Scalars['Float']['output']>;
temperature: Maybe<Scalars['Float']['output']>;
topP: Maybe<Scalars['Float']['output']>;
}
export interface CopilotPromptMessageInput {
content: Scalars['String']['input'];
params?: InputMaybe<Scalars['JSON']['input']>;
role: CopilotPromptMessageRole;
}
export enum CopilotPromptMessageRole {
assistant = 'assistant',
system = 'system',
user = 'user',
}
export interface CopilotPromptMessageType {
__typename?: 'CopilotPromptMessageType';
content: Scalars['String']['output'];
params: Maybe<Scalars['JSON']['output']>;
role: CopilotPromptMessageRole;
}
export interface CopilotPromptNotFoundDataType {
__typename?: 'CopilotPromptNotFoundDataType';
name: Scalars['String']['output'];
}
export interface CopilotPromptType {
__typename?: 'CopilotPromptType';
action: Maybe<Scalars['String']['output']>;
config: Maybe<CopilotPromptConfigType>;
messages: Array<CopilotPromptMessageType>;
model: Scalars['String']['output'];
name: Scalars['String']['output'];
}
export interface CopilotProviderNotSupportedDataType {
__typename?: 'CopilotProviderNotSupportedDataType';
kind: Scalars['String']['output'];
@@ -884,14 +841,6 @@ export interface CreateCheckoutSessionInput {
variant?: InputMaybe<SubscriptionVariant>;
}
export interface CreateCopilotPromptInput {
action?: InputMaybe<Scalars['String']['input']>;
config?: InputMaybe<CopilotPromptConfigInput>;
messages: Array<CopilotPromptMessageInput>;
model: Scalars['String']['input'];
name: Scalars['String']['input'];
}
export interface CreateUserInput {
email: Scalars['String']['input'];
name?: InputMaybe<Scalars['String']['input']>;
@@ -1752,10 +1701,13 @@ export interface Mutation {
createCopilotContext: Scalars['String']['output'];
/** Create a chat message */
createCopilotMessage: Scalars['String']['output'];
/** Create a copilot prompt */
createCopilotPrompt: CopilotPromptType;
/** Create a chat session */
/**
* Create a chat session
* @deprecated use `createCopilotSessionWithHistory` instead
*/
createCopilotSession: Scalars['String']['output'];
/** Create a chat session and return full session payload */
createCopilotSessionWithHistory: CopilotHistories;
/** Create a stripe customer portal to manage payment methods */
createCustomerPortal: Scalars['String']['output'];
createInviteLink: InviteLink;
@@ -1845,8 +1797,6 @@ export interface Mutation {
updateCalendarAccount: Maybe<CalendarAccountObjectType>;
/** Update a comment content */
updateComment: Scalars['Boolean']['output'];
/** Update a copilot prompt */
updateCopilotPrompt: CopilotPromptType;
/** Update a chat session */
updateCopilotSession: Scalars['String']['output'];
updateDocDefaultRole: Scalars['Boolean']['output'];
@@ -1998,11 +1948,11 @@ export interface MutationCreateCopilotMessageArgs {
options: CreateChatMessageInput;
}
export interface MutationCreateCopilotPromptArgs {
input: CreateCopilotPromptInput;
export interface MutationCreateCopilotSessionArgs {
options: CreateChatSessionInput;
}
export interface MutationCreateCopilotSessionArgs {
export interface MutationCreateCopilotSessionWithHistoryArgs {
options: CreateChatSessionInput;
}
@@ -2262,11 +2212,6 @@ export interface MutationUpdateCommentArgs {
input: CommentUpdateInput;
}
export interface MutationUpdateCopilotPromptArgs {
messages: Array<CopilotPromptMessageInput>;
name: Scalars['String']['input'];
}
export interface MutationUpdateCopilotSessionArgs {
options: UpdateChatSessionInput;
}
@@ -2554,8 +2499,6 @@ export interface Query {
error: ErrorDataUnion;
/** get workspace invitation info */
getInviteInfo: InvitationType;
/** List all copilot prompts */
listCopilotPrompts: Array<CopilotPromptType>;
prices: Array<SubscriptionPrice>;
/** Get public user by id */
publicUserById: Maybe<PublicUserType>;
@@ -3886,59 +3829,6 @@ export type AppConfigQueryVariables = Exact<{ [key: string]: never }>;
export type AppConfigQuery = { __typename?: 'Query'; appConfig: any };
export type GetPromptsQueryVariables = Exact<{ [key: string]: never }>;
export type GetPromptsQuery = {
__typename?: 'Query';
listCopilotPrompts: Array<{
__typename?: 'CopilotPromptType';
name: string;
model: string;
action: string | null;
config: {
__typename?: 'CopilotPromptConfigType';
frequencyPenalty: number | null;
presencePenalty: number | null;
temperature: number | null;
topP: number | null;
} | null;
messages: Array<{
__typename?: 'CopilotPromptMessageType';
role: CopilotPromptMessageRole;
content: string;
params: Record<string, string> | null;
}>;
}>;
};
export type UpdatePromptMutationVariables = Exact<{
name: Scalars['String']['input'];
messages: Array<CopilotPromptMessageInput> | CopilotPromptMessageInput;
}>;
export type UpdatePromptMutation = {
__typename?: 'Mutation';
updateCopilotPrompt: {
__typename?: 'CopilotPromptType';
name: string;
model: string;
action: string | null;
config: {
__typename?: 'CopilotPromptConfigType';
frequencyPenalty: number | null;
presencePenalty: number | null;
temperature: number | null;
topP: number | null;
} | null;
messages: Array<{
__typename?: 'CopilotPromptMessageType';
role: CopilotPromptMessageRole;
content: string;
params: Record<string, string> | null;
}>;
};
};
export type CreateUserMutationVariables = Exact<{
input: CreateUserInput;
}>;
@@ -5425,6 +5315,47 @@ export type CleanupCopilotSessionMutation = {
cleanupCopilotSession: Array<string>;
};
export type CreateCopilotSessionWithHistoryMutationVariables = Exact<{
options: CreateChatSessionInput;
}>;
export type CreateCopilotSessionWithHistoryMutation = {
__typename?: 'Mutation';
createCopilotSessionWithHistory: {
__typename?: 'CopilotHistories';
sessionId: string;
workspaceId: string;
docId: string | null;
parentSessionId: string | null;
promptName: string;
model: string;
optionalModels: Array<string>;
action: string | null;
pinned: boolean;
title: string | null;
tokens: number;
createdAt: string;
updatedAt: string;
messages: Array<{
__typename?: 'ChatMessage';
id: string | null;
role: string;
content: string;
attachments: Array<string> | null;
createdAt: string;
streamObjects: Array<{
__typename?: 'StreamObject';
type: string;
textDelta: string | null;
toolCallId: string | null;
toolName: string | null;
args: Record<string, string> | null;
result: Record<string, string> | null;
}> | null;
}>;
};
};
export type CreateCopilotSessionMutationVariables = Exact<{
options: CreateChatSessionInput;
}>;
@@ -5934,24 +5865,6 @@ export type GetDocRolePermissionsQuery = {
};
};
export type CopilotChatMessageFragment = {
__typename?: 'ChatMessage';
id: string | null;
role: string;
content: string;
attachments: Array<string> | null;
createdAt: string;
streamObjects: Array<{
__typename?: 'StreamObject';
type: string;
textDelta: string | null;
toolCallId: string | null;
toolName: string | null;
args: Record<string, string> | null;
result: Record<string, string> | null;
}> | null;
};
export type CopilotChatHistoryFragment = {
__typename?: 'CopilotHistories';
sessionId: string;
@@ -5986,6 +5899,52 @@ export type CopilotChatHistoryFragment = {
}>;
};
export type CredentialsRequirementsFragment = {
__typename?: 'CredentialsRequirementType';
password: {
__typename?: 'PasswordLimitsType';
minLength: number;
maxLength: number;
};
};
export type CurrentUserProfileFragment = {
__typename?: 'UserType';
id: string;
name: string;
email: string;
avatarUrl: string | null;
emailVerified: boolean;
features: Array<FeatureType>;
settings: {
__typename?: 'UserSettingsType';
receiveInvitationEmail: boolean;
receiveMentionEmail: boolean;
receiveCommentEmail: boolean;
};
quota: {
__typename?: 'UserQuotaType';
name: string;
blobLimit: number;
storageQuota: number;
historyPeriod: number;
memberLimit: number;
humanReadable: {
__typename?: 'UserQuotaHumanReadableType';
name: string;
blobLimit: string;
storageQuota: string;
historyPeriod: string;
memberLimit: string;
};
};
quotaUsage: { __typename?: 'UserQuotaUsageType'; storageQuota: number };
copilot: {
__typename?: 'Copilot';
quota: { __typename?: 'CopilotQuota'; limit: number | null; used: number };
};
};
export type PaginatedCopilotChatsFragment = {
__typename?: 'PaginatedCopilotHistoriesType';
pageInfo: {
@@ -6034,52 +5993,6 @@ export type PaginatedCopilotChatsFragment = {
}>;
};
export type CredentialsRequirementsFragment = {
__typename?: 'CredentialsRequirementType';
password: {
__typename?: 'PasswordLimitsType';
minLength: number;
maxLength: number;
};
};
export type CurrentUserProfileFragment = {
__typename?: 'UserType';
id: string;
name: string;
email: string;
avatarUrl: string | null;
emailVerified: boolean;
features: Array<FeatureType>;
settings: {
__typename?: 'UserSettingsType';
receiveInvitationEmail: boolean;
receiveMentionEmail: boolean;
receiveCommentEmail: boolean;
};
quota: {
__typename?: 'UserQuotaType';
name: string;
blobLimit: number;
storageQuota: number;
historyPeriod: number;
memberLimit: number;
humanReadable: {
__typename?: 'UserQuotaHumanReadableType';
name: string;
blobLimit: string;
storageQuota: string;
historyPeriod: string;
memberLimit: string;
};
};
quotaUsage: { __typename?: 'UserQuotaUsageType'; storageQuota: number };
copilot: {
__typename?: 'Copilot';
quota: { __typename?: 'CopilotQuota'; limit: number | null; used: number };
};
};
export type PasswordLimitsFragment = {
__typename?: 'PasswordLimitsType';
minLength: number;
@@ -7623,11 +7536,6 @@ export type Queries =
variables: AppConfigQueryVariables;
response: AppConfigQuery;
}
| {
name: 'getPromptsQuery';
variables: GetPromptsQueryVariables;
response: GetPromptsQuery;
}
| {
name: 'getUserByEmailQuery';
variables: GetUserByEmailQueryVariables;
@@ -8035,11 +7943,6 @@ export type Mutations =
variables: CreateChangePasswordUrlMutationVariables;
response: CreateChangePasswordUrlMutation;
}
| {
name: 'updatePromptMutation';
variables: UpdatePromptMutationVariables;
response: UpdatePromptMutation;
}
| {
name: 'createUserMutation';
variables: CreateUserMutationVariables;
@@ -8275,6 +8178,11 @@ export type Mutations =
variables: CleanupCopilotSessionMutationVariables;
response: CleanupCopilotSessionMutation;
}
| {
name: 'createCopilotSessionWithHistoryMutation';
variables: CreateCopilotSessionWithHistoryMutationVariables;
response: CreateCopilotSessionWithHistoryMutation;
}
| {
name: 'createCopilotSessionMutation';
variables: CreateCopilotSessionMutationVariables;

View File

@@ -313,6 +313,14 @@
"type": "Object",
"desc": "Use custom models in scenarios and override default settings."
},
"providers.profiles": {
"type": "Array",
"desc": "The profile list for copilot providers."
},
"providers.defaults": {
"type": "Object",
"desc": "The default provider ids for model output types and global fallback."
},
"providers.openai": {
"type": "Object",
"desc": "The config for the openai provider.",

View File

@@ -1,146 +0,0 @@
import { ScrollArea } from '@affine/admin/components/ui/scroll-area';
import { Separator } from '@affine/admin/components/ui/separator';
import { Textarea } from '@affine/admin/components/ui/textarea';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { RightPanelHeader } from '../header';
import { useRightPanel } from '../panel/context';
import type { Prompt } from './prompts';
import { usePrompt } from './use-prompt';
export function EditPrompt({
item,
setCanSave,
}: {
item: Prompt;
setCanSave: (changed: boolean) => void;
}) {
const { closePanel } = useRightPanel();
const [messages, setMessages] = useState(item.messages);
const { updatePrompt } = usePrompt();
const disableSave = useMemo(
() => JSON.stringify(messages) === JSON.stringify(item.messages),
[item.messages, messages]
);
const handleChange = useCallback(
(e: React.ChangeEvent<HTMLTextAreaElement>, index: number) => {
const newMessages = [...messages];
newMessages[index] = {
...newMessages[index],
content: e.target.value,
};
setMessages(newMessages);
setCanSave(!disableSave);
},
[disableSave, messages, setCanSave]
);
const handleClose = useCallback(() => {
setMessages(item.messages);
closePanel();
}, [closePanel, item.messages]);
const onConfirm = useCallback(() => {
if (!disableSave) {
updatePrompt({ name: item.name, messages });
}
handleClose();
}, [disableSave, handleClose, item.name, messages, updatePrompt]);
useEffect(() => {
setMessages(item.messages);
}, [item.messages]);
return (
<div className="flex flex-col h-full gap-1">
<RightPanelHeader
title="Edit Prompt"
handleClose={handleClose}
handleConfirm={onConfirm}
canSave={!disableSave}
/>
<ScrollArea>
<div className="grid">
<div className="px-5 py-4 overflow-y-auto space-y-[10px] flex flex-col gap-5">
<div className="flex flex-col">
<div className="text-sm font-medium">Name</div>
<div className="text-sm font-normal text-muted-foreground">
{item.name}
</div>
</div>
{item.action ? (
<div className="flex flex-col">
<div className="text-sm font-medium">Action</div>
<div className="text-sm font-normal text-muted-foreground">
{item.action}
</div>
</div>
) : null}
<div className="flex flex-col">
<div className="text-sm font-medium">Model</div>
<div className="text-sm font-normal text-muted-foreground">
{item.model}
</div>
</div>
{item.config ? (
<div className="flex flex-col border rounded p-3">
<div className="text-sm font-medium">Config</div>
{Object.entries(item.config).map(([key, value], index) => (
<div key={key} className="flex flex-col">
{index !== 0 && <Separator />}
<span className="text-sm font-normal">{key}</span>
<span className="text-sm font-normal text-muted-foreground">
{value?.toString()}
</span>
</div>
))}
</div>
) : null}
</div>
<div className="px-5 py-4 overflow-y-auto space-y-[10px] flex flex-col">
<div className="text-sm font-medium">Messages</div>
{messages.map((message, index) => (
<div key={message.content} className="flex flex-col gap-3">
{index !== 0 && <Separator />}
<div>
<div className="text-sm font-normal">Role</div>
<div className="text-sm font-normal text-muted-foreground">
{message.role}
</div>
</div>
{message.params ? (
<div>
<div className="text-sm font-medium">Params</div>
{Object.entries(message.params).map(
([key, value], index) => (
<div key={key} className="flex flex-col">
{index !== 0 && <Separator />}
<span className="text-sm font-normal">{key}</span>
<span
className="text-sm font-normal text-muted-foreground"
style={{ overflowWrap: 'break-word' }}
>
{value.toString()}
</span>
</div>
)
)}
</div>
) : null}
<div className="text-sm font-normal">Content</div>
<Textarea
className=" min-h-48"
value={message.content}
onChange={e => handleChange(e, index)}
/>
</div>
))}
</div>
</div>
</ScrollArea>
</div>
);
}

View File

@@ -32,7 +32,6 @@ function AiPage() {
/>
</div>
</div>
{/* <Prompts /> */}
</ScrollAreaPrimitive.Viewport>
<ScrollAreaPrimitive.ScrollAreaScrollbar
className={cn(

View File

@@ -1,108 +0,0 @@
import { Button } from '@affine/admin/components/ui/button';
import { Separator } from '@affine/admin/components/ui/separator';
import type { CopilotPromptMessageRole } from '@affine/graphql';
import { useCallback, useState } from 'react';
import { DiscardChanges } from '../../components/shared/discard-changes';
import { useRightPanel } from '../panel/context';
import { EditPrompt } from './edit-prompt';
import { usePrompt } from './use-prompt';
export type Prompt = {
__typename?: 'CopilotPromptType';
name: string;
model: string;
action: string | null;
config: {
__typename?: 'CopilotPromptConfigType';
frequencyPenalty: number | null;
presencePenalty: number | null;
temperature: number | null;
topP: number | null;
} | null;
messages: Array<{
__typename?: 'CopilotPromptMessageType';
role: CopilotPromptMessageRole;
content: string;
params: Record<string, string> | null;
}>;
};
export function Prompts() {
const { prompts: list } = usePrompt();
return (
<div className="flex flex-col h-full gap-3 py-5 px-6 w-full">
<div className="flex items-center">
<span className="text-xl font-semibold">Prompts</span>
</div>
<div className="flex-grow overflow-y-auto space-y-[10px]">
<div className="flex flex-col rounded-md border w-full">
{list.map((item, index) => (
<PromptRow
key={`${item.name}-${index}`}
item={item}
index={index}
/>
))}
</div>
</div>
</div>
);
}
export const PromptRow = ({ item, index }: { item: Prompt; index: number }) => {
const { setPanelContent, openPanel, isOpen } = useRightPanel();
const [dialogOpen, setDialogOpen] = useState(false);
const [canSave, setCanSave] = useState(false);
const handleDiscardChangesCancel = useCallback(() => {
setDialogOpen(false);
setCanSave(false);
}, []);
const handleConfirm = useCallback(
(item: Prompt) => {
setPanelContent(<EditPrompt item={item} setCanSave={setCanSave} />);
if (dialogOpen) {
handleDiscardChangesCancel();
}
if (!isOpen) {
openPanel();
}
},
[dialogOpen, handleDiscardChangesCancel, isOpen, openPanel, setPanelContent]
);
const handleEdit = useCallback(
(item: Prompt) => {
if (isOpen && canSave) {
setDialogOpen(true);
} else {
handleConfirm(item);
}
},
[canSave, handleConfirm, isOpen]
);
return (
<div>
{index !== 0 && <Separator />}
<Button
variant="ghost"
className="flex flex-col gap-1 w-full items-start px-6 py-[14px] h-full "
onClick={() => handleEdit(item)}
>
<div>{item.name}</div>
<div className="text-left w-full opacity-50 overflow-hidden text-ellipsis whitespace-nowrap break-words text-nowrap">
{item.messages.flatMap(message => message.content).join(' ')}
</div>
</Button>
<DiscardChanges
open={dialogOpen}
onOpenChange={setDialogOpen}
onClose={handleDiscardChangesCancel}
onConfirm={() => handleConfirm(item)}
/>
</div>
);
};

View File

@@ -1,51 +0,0 @@
import {
useMutateQueryResource,
useMutation,
} from '@affine/admin/use-mutation';
import { useQuery } from '@affine/admin/use-query';
import { useAsyncCallback } from '@affine/core/components/hooks/affine-async-hooks';
import { getPromptsQuery, updatePromptMutation } from '@affine/graphql';
import { toast } from 'sonner';
import type { Prompt } from './prompts';
export const usePrompt = () => {
const { data } = useQuery({
query: getPromptsQuery,
});
const { trigger } = useMutation({
mutation: updatePromptMutation,
});
const revalidate = useMutateQueryResource();
const updatePrompt = useAsyncCallback(
async ({
name,
messages,
}: {
name: string;
messages: Prompt['messages'];
}) => {
await trigger({
name,
messages,
})
.then(async () => {
await revalidate(getPromptsQuery);
toast.success('Prompt updated successfully');
})
.catch(e => {
toast(e.message);
console.error(e);
});
},
[revalidate, trigger]
);
return {
prompts: data.listCopilotPrompts,
updatePrompt,
};
};

View File

@@ -411,6 +411,9 @@ declare global {
interface AISessionService {
createSession: (options: AICreateSessionOptions) => Promise<string>;
createSessionWithHistory: (
options: AICreateSessionOptions
) => Promise<CopilotChatHistoryFragment | undefined>;
getSession: (
workspaceId: string,
sessionId: string

View File

@@ -185,6 +185,9 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
@state()
private accessor hasMore = true;
@state()
private accessor selectedSessionId: string | undefined;
private accessor currentOffset = 0;
private readonly pageSize = 10;
@@ -267,9 +270,16 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
override connectedCallback() {
super.connectedCallback();
this.selectedSessionId = this.session?.sessionId ?? undefined;
this.getRecentSessions().catch(console.error);
}
protected override willUpdate(changedProperties: PropertyValues) {
if (changedProperties.has('session')) {
this.selectedSessionId = this.session?.sessionId ?? undefined;
}
}
override firstUpdated(changedProperties: PropertyValues) {
super.firstUpdated(changedProperties);
this.disposables.add(() => {
@@ -294,9 +304,10 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
class="ai-session-item"
@click=${(e: MouseEvent) => {
e.stopPropagation();
this.selectedSessionId = session.sessionId;
this.onSessionClick(session.sessionId);
}}
aria-selected=${this.session?.sessionId === session.sessionId}
aria-selected=${this.selectedSessionId === session.sessionId}
data-session-id=${session.sessionId}
>
<div class="ai-session-title">
@@ -332,6 +343,7 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
class="ai-session-doc"
@click=${(e: MouseEvent) => {
e.stopPropagation();
this.selectedSessionId = sessionId;
this.onDocClick(docId, sessionId);
}}
>

View File

@@ -152,6 +152,7 @@ export class AIProvider {
}>(),
// downstream can emit this slot to notify ai presets that user info has been updated
userInfo: new Subject<AIUserInfo | null>(),
sessionReady: new BehaviorSubject<boolean>(false),
previewPanelOpenChange: new Subject<boolean>(),
/* eslint-enable rxjs/finnish */
};
@@ -344,6 +345,7 @@ export class AIProvider {
} else if (id === 'session') {
AIProvider.instance.session =
action as BlockSuitePresets.AISessionService;
AIProvider.instance.slots.sessionReady.next(true);
} else if (id === 'context') {
AIProvider.instance.context =
action as BlockSuitePresets.AIContextService;

View File

@@ -11,6 +11,7 @@ import {
createCopilotContextMutation,
createCopilotMessageMutation,
createCopilotSessionMutation,
createCopilotSessionWithHistoryMutation,
forkCopilotSessionMutation,
getCopilotHistoriesQuery,
getCopilotHistoryIdsQuery,
@@ -41,7 +42,6 @@ import {
} from './error';
export enum Endpoint {
Stream = 'stream',
StreamObject = 'stream-object',
Workflow = 'workflow',
Images = 'images',
@@ -96,7 +96,6 @@ export class CopilotClient {
readonly gql: <Query extends GraphQLQuery>(
options: QueryOptions<Query>
) => Promise<QueryResponse<Query>>,
readonly fetcher: (input: string, init?: RequestInit) => Promise<Response>,
readonly eventSource: (
url: string,
eventSourceInitDict?: EventSourceInit
@@ -119,6 +118,20 @@ export class CopilotClient {
}
}
async createSessionWithHistory(
options: OptionsField<typeof createCopilotSessionWithHistoryMutation>
) {
try {
const res = await this.gql({
query: createCopilotSessionWithHistoryMutation,
variables: { options },
});
return res.createCopilotSessionWithHistory;
} catch (err) {
throw resolveError(err);
}
}
async updateSession(
options: OptionsField<typeof updateCopilotSessionMutation>
) {
@@ -150,7 +163,11 @@ export class CopilotClient {
}
async createMessage(
options: OptionsField<typeof createCopilotMessageMutation>
options: OptionsField<typeof createCopilotMessageMutation>,
requestOptions?: Pick<
RequestOptions<typeof createCopilotMessageMutation>,
'timeout' | 'signal'
>
) {
try {
const res = await this.gql({
@@ -158,6 +175,8 @@ export class CopilotClient {
variables: {
options,
},
timeout: requestOptions?.timeout,
signal: requestOptions?.signal,
});
return res.createCopilotMessage;
} catch (err) {
@@ -442,35 +461,6 @@ export class CopilotClient {
return { files, docs };
}
async chatText({
sessionId,
messageId,
reasoning,
modelId,
toolsConfig,
signal,
}: {
sessionId: string;
messageId?: string;
reasoning?: boolean;
modelId?: string;
toolsConfig?: AIToolsConfig;
signal?: AbortSignal;
}) {
let url = `/api/copilot/chat/${sessionId}`;
const queryString = this.paramsToQueryString({
messageId,
reasoning,
modelId,
toolsConfig,
});
if (queryString) {
url += `?${queryString}`;
}
const response = await this.fetcher(url.toString(), { signal });
return response.text();
}
// Text or image to text
chatTextStream(
{
@@ -486,7 +476,7 @@ export class CopilotClient {
modelId?: string;
toolsConfig?: AIToolsConfig;
},
endpoint = Endpoint.Stream
endpoint = Endpoint.StreamObject
) {
let url = `/api/copilot/chat/${sessionId}/${endpoint}`;
const queryString = this.paramsToQueryString({

View File

@@ -3,7 +3,7 @@ import { partition } from 'lodash-es';
import { AIProvider } from './ai-provider';
import { type CopilotClient, Endpoint } from './copilot-client';
import { delay, toTextStream } from './event-source';
import { toTextStream } from './event-source';
const TIMEOUT = 50000;
@@ -67,6 +67,8 @@ interface CreateMessageOptions {
content?: string;
attachments?: (string | Blob | File)[];
params?: Record<string, any>;
timeout?: number;
signal?: AbortSignal;
}
async function createMessage({
@@ -75,6 +77,8 @@ async function createMessage({
content,
attachments,
params,
timeout,
signal,
}: CreateMessageOptions): Promise<string> {
const hasAttachments = attachments && attachments.length > 0;
const options: Parameters<CopilotClient['createMessage']>[0] = {
@@ -102,7 +106,7 @@ async function createMessage({
).filter(Boolean) as File[];
}
return await client.createMessage(options);
return await client.createMessage(options, { timeout, signal });
}
export function textToText({
@@ -115,7 +119,7 @@ export function textToText({
signal,
timeout = TIMEOUT,
retry = false,
endpoint = Endpoint.Stream,
endpoint = Endpoint.StreamObject,
postfix,
reasoning,
modelId,
@@ -133,6 +137,8 @@ export function textToText({
content,
attachments,
params,
timeout,
signal,
});
}
const eventSource = client.chatTextStream(
@@ -147,65 +153,105 @@ export function textToText({
);
AIProvider.LAST_ACTION_SESSIONID = sessionId;
if (signal) {
if (signal.aborted) {
eventSource.close();
return;
let onAbort: (() => void) | undefined;
try {
if (signal) {
if (signal.aborted) {
eventSource.close();
return;
}
onAbort = () => {
eventSource.close();
};
signal.addEventListener('abort', onAbort, { once: true });
}
signal.onabort = () => {
eventSource.close();
};
}
if (postfix) {
const messages: string[] = [];
for await (const event of toTextStream(eventSource, {
timeout,
signal,
})) {
if (event.type === 'message') {
messages.push(event.data);
if (postfix) {
const messages: string[] = [];
for await (const event of toTextStream(eventSource, {
timeout,
signal,
})) {
if (event.type === 'message') {
messages.push(event.data);
}
}
yield postfix(messages.join(''));
} else {
for await (const event of toTextStream(eventSource, {
timeout,
signal,
})) {
if (event.type === 'message') {
yield event.data;
}
}
}
yield postfix(messages.join(''));
} else {
for await (const event of toTextStream(eventSource, {
timeout,
signal,
})) {
if (event.type === 'message') {
yield event.data;
}
} finally {
eventSource.close();
if (signal && onAbort) {
signal.removeEventListener('abort', onAbort);
}
}
},
};
} else {
return Promise.race([
timeout
? delay(timeout).then(() => {
throw new Error('Timeout');
})
: null,
(async function () {
if (!retry) {
messageId = await createMessage({
client,
sessionId,
content,
attachments,
params,
});
}
AIProvider.LAST_ACTION_SESSIONID = sessionId;
return client.chatText({
return (async function () {
if (!retry) {
messageId = await createMessage({
client,
sessionId,
content,
attachments,
params,
timeout,
signal,
});
}
const eventSource = client.chatTextStream(
{
sessionId,
messageId,
reasoning,
modelId,
});
})(),
]);
toolsConfig,
},
endpoint
);
AIProvider.LAST_ACTION_SESSIONID = sessionId;
let onAbort: (() => void) | undefined;
try {
if (signal) {
if (signal.aborted) {
eventSource.close();
return '';
}
onAbort = () => {
eventSource.close();
};
signal.addEventListener('abort', onAbort, { once: true });
}
const messages: string[] = [];
for await (const event of toTextStream(eventSource, {
timeout,
signal,
})) {
if (event.type === 'message') {
messages.push(event.data);
}
}
const result = messages.join('');
return postfix ? postfix(result) : result;
} finally {
eventSource.close();
if (signal && onAbort) {
signal.removeEventListener('abort', onAbort);
}
}
})();
}
}
@@ -232,6 +278,8 @@ export function toImage({
content,
attachments,
params,
timeout,
signal,
});
}
const eventSource = client.imagesStream(

View File

@@ -582,6 +582,21 @@ Could you make a new website based on these notes and send back just the html fi
AIProvider.provide('session', {
createSession,
createSessionWithHistory: async options => {
if (!options.sessionId && !options.retry) {
return client.createSessionWithHistory({
workspaceId: options.workspaceId,
docId: options.docId,
promptName: options.promptName,
pinned: options.pinned,
reuseLatestChat: options.reuseLatestChat,
});
}
const sessionId = await createSession(options);
if (!sessionId) return undefined;
return client.getSession(options.workspaceId, sessionId);
},
getSession: async (workspaceId: string, sessionId: string) => {
return client.getSession(workspaceId, sessionId);
},
@@ -823,7 +838,7 @@ Could you make a new website based on these notes and send back just the html fi
regular: string;
};
}[];
} = await client.fetcher(url.toString()).then(res => res.json());
} = await fetch(url.toString()).then((res: Response) => res.json());
if (!result.results) return [];
return result.results.map(r => {
const url = new URL(r.urls.regular);

View File

@@ -14,7 +14,6 @@ import { OverCapacityNotification } from '@affine/core/components/over-capacity'
import {
AuthService,
EventSourceService,
FetchService,
GraphQLService,
} from '@affine/core/modules/cloud';
import {
@@ -140,16 +139,11 @@ export const WorkspaceSideEffects = () => {
const graphqlService = useService(GraphQLService);
const eventSourceService = useService(EventSourceService);
const fetchService = useService(FetchService);
const authService = useService(AuthService);
useEffect(() => {
const dispose = setupAIProvider(
new CopilotClient(
graphqlService.gql,
fetchService.fetch,
eventSourceService.eventSource
),
new CopilotClient(graphqlService.gql, eventSourceService.eventSource),
globalDialogService,
authService
);
@@ -158,7 +152,6 @@ export const WorkspaceSideEffects = () => {
};
}, [
eventSourceService,
fetchService,
workspaceDialogService,
graphqlService,
globalDialogService,

View File

@@ -23,7 +23,6 @@ import {
import { AIModelService } from '@affine/core/modules/ai-button/services/models';
import {
EventSourceService,
FetchService,
GraphQLService,
ServerService,
SubscriptionService,
@@ -58,16 +57,10 @@ type CopilotSession = Awaited<ReturnType<CopilotClient['getSession']>>;
function useCopilotClient() {
const graphqlService = useService(GraphQLService);
const eventSourceService = useService(EventSourceService);
const fetchService = useService(FetchService);
return useMemo(
() =>
new CopilotClient(
graphqlService.gql,
fetchService.fetch,
eventSourceService.eventSource
),
[graphqlService, eventSourceService, fetchService]
() => new CopilotClient(graphqlService.gql, eventSourceService.eventSource),
[graphqlService, eventSourceService]
);
}
@@ -106,6 +99,7 @@ export const Component = () => {
const [status, setStatus] = useState<ChatStatus>('idle');
const [isTogglingPin, setIsTogglingPin] = useState(false);
const [isOpeningSession, setIsOpeningSession] = useState(false);
const hasRestoredPinnedSessionRef = useRef(false);
const chatContainerRef = useRef<HTMLDivElement>(null);
const chatToolContainerRef = useRef<HTMLDivElement>(null);
const widthSignalRef = useRef<Signal<number>>(signal(0));
@@ -114,6 +108,10 @@ export const Component = () => {
const workspaceId = useService(WorkspaceService).workspace.id;
useEffect(() => {
hasRestoredPinnedSessionRef.current = false;
}, [workspaceId]);
const { docDisplayConfig, searchMenuConfig, reasoningConfig } =
useAIChatConfig();
@@ -122,14 +120,12 @@ export const Component = () => {
if (currentSession) {
return currentSession;
}
const sessionId = await client.createSession({
const session = await client.createSessionWithHistory({
workspaceId,
promptName: 'Chat With AFFiNE AI' satisfies PromptKey,
reuseLatestChat: false,
...options,
});
const session = await client.getSession(workspaceId, sessionId);
setCurrentSession(session);
return session;
},
@@ -169,23 +165,50 @@ export const Component = () => {
});
}, []);
const createFreshSession = useCallback(async () => {
if (isOpeningSession) {
return;
}
setIsOpeningSession(true);
try {
setCurrentSession(null);
reMountChatContent();
const session = await client.createSessionWithHistory({
workspaceId,
promptName: 'Chat With AFFiNE AI' satisfies PromptKey,
reuseLatestChat: false,
});
setCurrentSession(session);
} catch (error) {
console.error(error);
} finally {
setIsOpeningSession(false);
}
}, [client, isOpeningSession, reMountChatContent, workspaceId]);
const onOpenSession = useCallback(
(sessionId: string) => {
if (isOpeningSession) return;
async (sessionId: string) => {
if (isOpeningSession || currentSession?.sessionId === sessionId) return;
setIsOpeningSession(true);
client
.getSession(workspaceId, sessionId)
.then(session => {
setCurrentSession(session);
reMountChatContent();
chatTool?.closeHistoryMenu();
})
.catch(console.error)
.finally(() => {
setIsOpeningSession(false);
});
try {
const session = await client.getSession(workspaceId, sessionId);
setCurrentSession(session);
reMountChatContent();
chatTool?.closeHistoryMenu();
} catch (error) {
console.error(error);
} finally {
setIsOpeningSession(false);
}
},
[chatTool, client, isOpeningSession, reMountChatContent, workspaceId]
[
chatTool,
client,
currentSession?.sessionId,
isOpeningSession,
reMountChatContent,
workspaceId,
]
);
const onContextChange = useCallback((context: Partial<ChatContextValue>) => {
@@ -198,6 +221,16 @@ export const Component = () => {
},
[workbench]
);
const onOpenSessionDoc = useCallback(
(docId: string, sessionId: string) => {
const { workbench } = framework.get(WorkbenchService);
const viewService = framework.get(ViewService);
workbench.open(`/${docId}?sessionId=${sessionId}`, { at: 'active' });
workbench.openSidebar();
viewService.view.activeSidebarTab('chat');
},
[framework]
);
const confirmModal = useConfirmModal();
const notificationService = useMemo(
@@ -286,7 +319,6 @@ export const Component = () => {
}
}, [
chatContent,
client,
createSession,
currentSession,
docDisplayConfig,
@@ -296,7 +328,6 @@ export const Component = () => {
reasoningConfig,
searchMenuConfig,
workspaceId,
confirmModal,
onContextChange,
notificationService,
specs,
@@ -316,19 +347,15 @@ export const Component = () => {
status,
docDisplayConfig,
notificationService,
onOpenSession,
onOpenSession: sessionId => {
onOpenSession(sessionId).catch(console.error);
},
onNewSession: () => {
if (!currentSession) return;
setCurrentSession(null);
reMountChatContent();
createFreshSession().catch(console.error);
},
onTogglePin: togglePin,
onOpenDoc: (docId: string, sessionId: string) => {
const { workbench } = framework.get(WorkbenchService);
const viewService = framework.get(ViewService);
workbench.open(`/${docId}?sessionId=${sessionId}`, { at: 'active' });
workbench.openSidebar();
viewService.view.activeSidebarTab('chat');
onOpenSessionDoc(docId, sessionId);
},
onSessionDelete: (sessionToDelete: BlockSuitePresets.AIRecentSession) => {
deleteSession(sessionToDelete).catch(console.error);
@@ -349,12 +376,11 @@ export const Component = () => {
onOpenSession,
togglePin,
workspaceId,
confirmModal,
framework,
onOpenSessionDoc,
deleteSession,
status,
reMountChatContent,
notificationService,
createFreshSession,
]);
useEffect(() => {
@@ -375,30 +401,51 @@ export const Component = () => {
// restore pinned session
useEffect(() => {
if (hasRestoredPinnedSessionRef.current || currentSession) return;
hasRestoredPinnedSessionRef.current = true;
const controller = new AbortController();
const signal = controller.signal;
client
.getSessions(
workspaceId,
{},
undefined,
{ pinned: true, limit: 1 },
signal
)
.then(sessions => {
if (!Array.isArray(sessions)) return;
const session = sessions[0];
if (!session) return;
setCurrentSession(session);
reMountChatContent();
})
.catch(console.error);
const loadPinnedSession = async () => {
try {
const sessions = await client.getSessions(
workspaceId,
{},
undefined,
{ pinned: true, limit: 1 },
controller.signal
);
if (controller.signal.aborted || !Array.isArray(sessions)) {
return;
}
const pinnedSession = sessions[0];
if (!pinnedSession) {
return;
}
let shouldRemount = false;
setCurrentSession(prev => {
if (prev) return prev;
shouldRemount = true;
return pinnedSession;
});
if (shouldRemount) reMountChatContent();
} catch (error) {
if (controller.signal.aborted) {
return;
}
console.error(error);
}
};
loadPinnedSession().catch(error => {
if (controller.signal.aborted) return;
console.error(error);
});
// abort the request
return () => {
controller.abort();
};
}, [client, reMountChatContent, workspaceId]);
}, [client, currentSession, reMountChatContent, workspaceId]);
const onChatContainerRef = useCallback((node: HTMLDivElement) => {
if (node) {

View File

@@ -97,7 +97,9 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
const chatContainerRef = useRef<HTMLDivElement | null>(null);
const chatToolbarContainerRef = useRef<HTMLDivElement | null>(null);
const contentKeyRef = useRef<string | null>(null);
const prevSessionIdRef = useRef<string | null>(null);
const lastDocIdRef = useRef<string | null>(null);
const sessionLoadSeqRef = useRef(0);
const doc = editor?.doc;
const host = editor?.host;
@@ -127,12 +129,14 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
}, [appSidebarConfig]);
const resetPanel = useCallback(() => {
sessionLoadSeqRef.current += 1;
setSession(undefined);
setEmbeddingProgress([0, 0]);
setHasPinned(false);
}, []);
const initPanel = useCallback(async () => {
const requestSeq = ++sessionLoadSeqRef.current;
try {
const nextSession = await resolveInitialSession({
sessionService: AIProvider.session ?? undefined,
@@ -140,6 +144,7 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
workbench: workbench as WorkbenchLike,
});
if (requestSeq !== sessionLoadSeqRef.current) return;
if (nextSession === undefined) {
return;
}
@@ -156,22 +161,18 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
if (session || !AIProvider.session || !doc) {
return session ?? undefined;
}
const sessionId = await AIProvider.session.createSession({
const requestSeq = ++sessionLoadSeqRef.current;
const nextSession = await AIProvider.session.createSessionWithHistory({
docId: doc.id,
workspaceId: doc.workspace.id,
promptName: 'Chat With AFFiNE AI',
reuseLatestChat: false,
...options,
});
if (sessionId) {
const nextSession = await AIProvider.session.getSession(
doc.workspace.id,
sessionId
);
setSession(nextSession ?? null);
return nextSession ?? undefined;
}
return session ?? undefined;
if (requestSeq !== sessionLoadSeqRef.current) return undefined;
setSession(nextSession ?? null);
setHasPinned(!!nextSession?.pinned);
return nextSession ?? undefined;
},
[doc, session]
);
@@ -181,37 +182,64 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
if (!AIProvider.session || !doc) {
return undefined;
}
const requestSeq = ++sessionLoadSeqRef.current;
await AIProvider.session.updateSession(options);
const nextSession = await AIProvider.session.getSession(
doc.workspace.id,
options.sessionId
);
if (requestSeq !== sessionLoadSeqRef.current) return undefined;
setSession(nextSession ?? null);
setHasPinned(!!nextSession?.pinned);
return nextSession ?? undefined;
},
[doc]
);
const newSession = useCallback(() => {
const newSession = useCallback(async () => {
resetPanel();
requestAnimationFrame(() => {
setSession(null);
});
}, [resetPanel]);
const requestSeq = sessionLoadSeqRef.current;
setSession(null);
if (!AIProvider.session || !doc) {
return;
}
try {
const nextSession = await AIProvider.session.createSessionWithHistory({
docId: doc.id,
workspaceId: doc.workspace.id,
promptName: 'Chat With AFFiNE AI',
reuseLatestChat: false,
});
if (requestSeq === sessionLoadSeqRef.current) {
setSession(nextSession ?? null);
setHasPinned(!!nextSession?.pinned);
}
} catch (error) {
console.error(error);
}
}, [doc, resetPanel]);
const openSession = useCallback(
async (sessionId: string) => {
if (session?.sessionId === sessionId || !AIProvider.session || !doc) {
return;
}
resetPanel();
const nextSession = await AIProvider.session.getSession(
doc.workspace.id,
sessionId
);
setSession(nextSession ?? null);
const requestSeq = ++sessionLoadSeqRef.current;
try {
const nextSession = await AIProvider.session.getSession(
doc.workspace.id,
sessionId
);
if (requestSeq !== sessionLoadSeqRef.current) return;
setSession(nextSession ?? null);
setHasPinned(!!nextSession?.pinned);
} catch (error) {
console.error(error);
}
},
[doc, resetPanel, session?.sessionId]
[doc, session?.sessionId]
);
const openDoc = useCallback(
@@ -252,7 +280,9 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
},
isActiveSession: sessionToDelete =>
sessionToDelete.sessionId === session?.sessionId,
onActiveSessionDeleted: newSession,
onActiveSessionDeleted: () => {
newSession().catch(console.error);
},
}),
[newSession, notificationService, session?.sessionId, t]
);
@@ -342,35 +372,33 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
if (!doc || session !== undefined) {
return;
}
let cancelled = false;
let timerId: ReturnType<typeof setTimeout> | null = null;
const tryInit = () => {
if (cancelled || session !== undefined) {
return;
}
// Session service may be registered after the panel mounts.
if (AIProvider.session) {
initPanel().catch(console.error);
return;
}
timerId = setTimeout(tryInit, 200);
};
tryInit();
return () => {
cancelled = true;
if (timerId) {
clearTimeout(timerId);
}
};
if (AIProvider.session) {
initPanel().catch(console.error);
return;
}
const subscription = AIProvider.slots.sessionReady.subscribe(ready => {
if (!ready || session !== undefined) return;
initPanel().catch(console.error);
});
return () => subscription.unsubscribe();
}, [doc, initPanel, session]);
const contentKey = hasPinned
? (session?.sessionId ?? doc?.id ?? 'chat-panel')
: (doc?.id ?? 'chat-panel');
const hasSessionHistory = !!session?.messages?.length;
const sessionSwitched = !!(
session?.sessionId &&
prevSessionIdRef.current &&
prevSessionIdRef.current !== session.sessionId
);
const contentKey =
hasPinned || (session?.sessionId && (hasSessionHistory || sessionSwitched))
? (session?.sessionId ?? doc?.id ?? 'chat-panel')
: (doc?.id ?? 'chat-panel');
useEffect(() => {
if (session?.sessionId) {
prevSessionIdRef.current = session.sessionId;
}
}, [session?.sessionId]);
useEffect(() => {
if (!chatContent) {
@@ -469,7 +497,9 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
status,
docDisplayConfig,
notificationService,
onNewSession: newSession,
onNewSession: () => {
newSession().catch(console.error);
},
onTogglePin: togglePin,
onOpenSession: (sessionId: string) => {
openSession(sessionId).catch(console.error);

View File

@@ -15,7 +15,7 @@ test.describe('AIAction/ExplainCode', () => {
'javascript'
);
const { answer } = await explainCode();
await expect(answer).toHaveText(/console.log/);
await expect(answer).toContainText(/(console\.log|Hello,\s*World)/i);
});
test.skip('should show chat history in chat panel', async ({

View File

@@ -93,7 +93,10 @@ test.describe('AIChatWith/Attachments', () => {
await utils.chatPanel.getLatestAssistantMessage(page);
expect(content).toMatch(new RegExp(`Attachment${randomStr1}`));
expect(content).toMatch(new RegExp(`Attachment${randomStr2}`));
expect(await message.locator('affine-footnote-node').count()).toBe(2);
const footnoteCount = await message
.locator('affine-footnote-node')
.count();
expect(footnoteCount > 0 || /sources?/i.test(content)).toBe(true);
}).toPass({ timeout: 20000 });
});
});

View File

@@ -206,11 +206,13 @@ test.describe('AISettings/Embedding', () => {
]);
await expect(async () => {
const { content, message } =
await utils.chatPanel.getLatestAssistantMessage(page);
expect(content).toMatch(new RegExp(`Workspace${randomStr1}.*cat`));
expect(content).toMatch(new RegExp(`Workspace${randomStr2}.*dog`));
expect(await message.locator('affine-footnote-node').count()).toBe(2);
const { message } = await utils.chatPanel.getLatestAssistantMessage(page);
const fullText = await message.innerText();
expect(fullText).toMatch(new RegExp(`Workspace${randomStr1}.*cat`));
expect(fullText).toMatch(new RegExp(`Workspace${randomStr2}.*dog`));
expect(
await message.locator('affine-footnote-node').count()
).toBeGreaterThanOrEqual(1);
}).toPass({ timeout: 20000 });
});
@@ -269,6 +271,7 @@ test.describe('AISettings/Embedding', () => {
await utils.settings.waitForFileEmbeddingReadiness(page, 1);
await utils.settings.closeSettingsPanel(page);
const query = `Use semantic search across workspace and attached files, then list all hobbies of ${person}.`;
await utils.chatPanel.chatWithAttachments(
page,
@@ -279,13 +282,13 @@ test.describe('AISettings/Embedding', () => {
buffer: hobby2,
},
],
`What is ${person}'s hobby?`
query
);
await utils.chatPanel.waitForHistory(page, [
{
role: 'user',
content: `What is ${person}'s hobby?`,
content: query,
},
{
role: 'assistant',
@@ -294,11 +297,13 @@ test.describe('AISettings/Embedding', () => {
]);
await expect(async () => {
const { content, message } =
await utils.chatPanel.getLatestAssistantMessage(page);
expect(content).toMatch(/climbing/i);
expect(content).toMatch(/skating/i);
expect(await message.locator('affine-footnote-node').count()).toBe(2);
const { message } = await utils.chatPanel.getLatestAssistantMessage(page);
const fullText = await message.innerText();
expect(fullText).toMatch(/climbing/i);
expect(fullText).toMatch(/skating/i);
expect(
await message.locator('affine-footnote-node').count()
).toBeGreaterThanOrEqual(1);
}).toPass({ timeout: 20000 });
});

View File

@@ -67,54 +67,70 @@ export class ChatPanelUtils {
}
public static async collectHistory(page: Page) {
return await page.evaluate(() => {
const chatPanel = document.querySelector<HTMLElement>(
'[data-testid="chat-panel-messages"]'
);
if (!chatPanel) {
return [] as ChatMessage[];
const selectors =
':is(chat-message-user,chat-message-assistant,chat-message-action,[data-testid="chat-message-user"],[data-testid="chat-message-assistant"],[data-testid="chat-message-action"])';
const messages = page.locator(selectors);
const count = await messages.count();
if (!count) return [] as ChatMessage[];
const history: ChatMessage[] = [];
for (let i = 0; i < count; i++) {
const message = messages.nth(i);
const testId = await message.getAttribute('data-testid');
const tag = await message.evaluate(el => el.tagName.toLowerCase());
const isAssistant =
testId === 'chat-message-assistant' || tag === 'chat-message-assistant';
const isAction =
testId === 'chat-message-action' || tag === 'chat-message-action';
const isUser =
testId === 'chat-message-user' || tag === 'chat-message-user';
if (!isAssistant && !isAction && !isUser) continue;
const titleNode = message.locator('.user-info').first();
const title =
(await titleNode.count()) > 0 ? await titleNode.innerText() : '';
if (isUser) {
const pureText = message.getByTestId('chat-content-pure-text').first();
const content =
(await pureText.count()) > 0
? await pureText.innerText()
: ((await message.innerText()) ?? '');
history.push({ role: 'user', content });
continue;
}
const messages = chatPanel.querySelectorAll<HTMLElement>(
'chat-message-user,chat-message-assistant,chat-message-action'
);
return Array.from(messages).map(m => {
const isAssistant = m.dataset.testid === 'chat-message-assistant';
const isChatAction = m.dataset.testid === 'chat-message-action';
const richText = message.locator('chat-content-rich-text editor-host');
const richContent =
(await richText.count()) > 0
? (await richText.allInnerTexts()).join(' ')
: '';
const content = richContent || ((await message.innerText()) ?? '').trim();
const isUser = !isAssistant && !isChatAction;
if (isAssistant) {
const inferredStatus = (await message
.getByTestId('ai-loading')
.isVisible()
.catch(() => false))
? 'transmitting'
: content
? 'success'
: 'idle';
history.push({
role: 'assistant',
status: ((await message.getAttribute('data-status')) ??
inferredStatus) as ChatStatus,
title,
content,
});
continue;
}
if (isUser) {
return {
role: 'user' as const,
content:
m.querySelector<HTMLElement>(
'[data-testid="chat-content-pure-text"]'
)?.innerText || '',
};
}
history.push({ role: 'action', title, content });
}
if (isAssistant) {
return {
role: 'assistant' as const,
status: m.dataset.status as ChatStatus,
title: m.querySelector<HTMLElement>('.user-info')?.innerText || '',
content:
m.querySelector<HTMLElement>('chat-content-rich-text editor-host')
?.innerText || '',
};
}
// Must be chat action at this point
return {
role: 'action' as const,
title: m.querySelector<HTMLElement>('.user-info')?.innerText || '',
content:
m.querySelector<HTMLElement>('chat-content-rich-text editor-host')
?.innerText || '',
};
});
});
return history;
}
private static expectHistory(
@@ -126,8 +142,34 @@ export class ChatPanelUtils {
)[]
) {
expect(history).toHaveLength(expected.length);
const assistantStage = {
loading: 1,
transmitting: 1,
success: 2,
} as const;
history.forEach((message, index) => {
const expectedMessage = expected[index];
if (
message.role === 'assistant' &&
expectedMessage?.role === 'assistant' &&
expectedMessage.status
) {
const expectedStatus = expectedMessage.status;
if (
expectedStatus in assistantStage &&
message.status in assistantStage
) {
expect(
assistantStage[message.status as keyof typeof assistantStage]
).toBeGreaterThanOrEqual(
assistantStage[expectedStatus as keyof typeof assistantStage]
);
const { status: _status, ...expectedRest } = expectedMessage;
expect(message).toMatchObject(expectedRest);
return;
}
}
expect(message).toMatchObject(expectedMessage);
});
}

View File

@@ -82,7 +82,7 @@ export class EditorUtils {
}
public static async waitForAiAnswer(page: Page) {
const answer = await page.getByTestId('ai-penel-answer');
const answer = page.getByTestId('ai-penel-answer').last();
await answer.waitFor({
state: 'visible',
timeout: 2 * 60000,

View File

@@ -962,12 +962,8 @@ __metadata:
"@affine/graphql": "workspace:*"
"@affine/s3-compat": "workspace:*"
"@affine/server-native": "workspace:*"
"@ai-sdk/anthropic": "npm:^2.0.54"
"@ai-sdk/google": "npm:^2.0.45"
"@ai-sdk/google-vertex": "npm:^3.0.88"
"@ai-sdk/openai": "npm:^2.0.80"
"@ai-sdk/openai-compatible": "npm:^1.0.28"
"@ai-sdk/perplexity": "npm:^2.0.21"
"@apollo/server": "npm:^4.13.0"
"@faker-js/faker": "npm:^10.1.0"
"@fal-ai/serverless-client": "npm:^0.15.0"
@@ -1129,7 +1125,7 @@ __metadata:
languageName: unknown
linkType: soft
"@ai-sdk/anthropic@npm:2.0.57, @ai-sdk/anthropic@npm:^2.0.54":
"@ai-sdk/anthropic@npm:2.0.57":
version: 2.0.57
resolution: "@ai-sdk/anthropic@npm:2.0.57"
dependencies:
@@ -1181,42 +1177,6 @@ __metadata:
languageName: node
linkType: hard
"@ai-sdk/openai-compatible@npm:^1.0.28":
version: 1.0.30
resolution: "@ai-sdk/openai-compatible@npm:1.0.30"
dependencies:
"@ai-sdk/provider": "npm:2.0.1"
"@ai-sdk/provider-utils": "npm:3.0.20"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/5a925424b52c8dbc912b0beeceac6de406bc3bcac289f3065780108631736789b1bae78613b204515dac711f1bc50b410a9bbe5bbd60d503d9d654fc2ba0406e
languageName: node
linkType: hard
"@ai-sdk/openai@npm:^2.0.80":
version: 2.0.89
resolution: "@ai-sdk/openai@npm:2.0.89"
dependencies:
"@ai-sdk/provider": "npm:2.0.1"
"@ai-sdk/provider-utils": "npm:3.0.20"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/96b6037491365043e6f08f0e9bdf4a9e5d902ca62aac748c2fa610d0ec1422a7d252ec2e6ab1271d10d2804ab487e33837ee8bf2d0f74acbdfec5547bbfca09f
languageName: node
linkType: hard
"@ai-sdk/perplexity@npm:^2.0.21":
version: 2.0.23
resolution: "@ai-sdk/perplexity@npm:2.0.23"
dependencies:
"@ai-sdk/provider": "npm:2.0.1"
"@ai-sdk/provider-utils": "npm:3.0.20"
peerDependencies:
zod: ^3.25.76 || ^4.1.8
checksum: 10/b7d5ba8618bde044b88e69be17e2fc03127e15de1e23f5e64c4db799d83e536f8551bae6b7f484eef3bec1c064ee7848e6c4a18d06ffc3fc18129d25110e7f9e
languageName: node
linkType: hard
"@ai-sdk/provider-utils@npm:3.0.20":
version: 3.0.20
resolution: "@ai-sdk/provider-utils@npm:3.0.20"