mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-03-23 15:50:43 +08:00
feat: refactor copilot module (#14537)
This commit is contained in:
@@ -988,6 +988,16 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"providers.profiles": {
|
||||
"type": "array",
|
||||
"description": "The profile list for copilot providers.\n@default []",
|
||||
"default": []
|
||||
},
|
||||
"providers.defaults": {
|
||||
"type": "object",
|
||||
"description": "The default provider ids for model output types and global fallback.\n@default {}",
|
||||
"default": {}
|
||||
},
|
||||
"providers.openai": {
|
||||
"type": "object",
|
||||
"description": "The config for the openai provider.\n@default {\"apiKey\":\"\",\"baseURL\":\"https://api.openai.com/v1\"}\n@link https://github.com/openai/openai-node",
|
||||
|
||||
515
Cargo.lock
generated
515
Cargo.lock
generated
@@ -181,6 +181,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"file-format",
|
||||
"infer",
|
||||
"llm_adapter",
|
||||
"mimalloc",
|
||||
"mp4parse",
|
||||
"napi",
|
||||
@@ -188,6 +189,8 @@ dependencies = [
|
||||
"napi-derive",
|
||||
"rand 0.9.2",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha3",
|
||||
"tiktoken-rs",
|
||||
"tokio",
|
||||
@@ -245,7 +248,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed7572b7ba83a31e20d1b48970ee402d2e3e0537dcfe0a3ff4d6eb7508617d43"
|
||||
dependencies = [
|
||||
"alsa-sys",
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
@@ -458,6 +461,12 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "auto_enums"
|
||||
version = "0.8.7"
|
||||
@@ -476,6 +485,28 @@ version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-rs"
|
||||
version = "1.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9"
|
||||
dependencies = [
|
||||
"aws-lc-sys",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-sys"
|
||||
version = "0.37.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cmake",
|
||||
"dunce",
|
||||
"fs_extra",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
@@ -533,7 +564,7 @@ version = "0.72.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.13.0",
|
||||
@@ -583,9 +614,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.10.0"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
|
||||
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
@@ -904,6 +935,15 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.57"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.4"
|
||||
@@ -983,7 +1023,7 @@ version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"core-foundation",
|
||||
"core-graphics-types",
|
||||
"foreign-types",
|
||||
@@ -996,7 +1036,7 @@ version = "0.25.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "064badf302c3194842cf2c5d61f56cc88e54a759313879cdf03abdd27d0c3b97"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"core-foundation",
|
||||
"core-graphics-types",
|
||||
"foreign-types",
|
||||
@@ -1009,7 +1049,7 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"core-foundation",
|
||||
"libc",
|
||||
]
|
||||
@@ -1379,7 +1419,7 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"block2",
|
||||
"libc",
|
||||
"objc2",
|
||||
@@ -1442,6 +1482,12 @@ version = "0.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5"
|
||||
|
||||
[[package]]
|
||||
name = "dunce"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
|
||||
|
||||
[[package]]
|
||||
name = "ecb"
|
||||
version = "0.1.2"
|
||||
@@ -1664,6 +1710,12 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs_extra"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "futf"
|
||||
version = "0.1.5"
|
||||
@@ -1828,9 +1880,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasip2",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1880,7 +1934,7 @@ version = "1.41.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0c43e7c3212bd992c11b6b9796563388170950521ae8487f5cdf6f6e792f1c8"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
@@ -2003,6 +2057,105 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body-util"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http",
|
||||
"http-body",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httparse"
|
||||
version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"http",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"smallvec",
|
||||
"tokio",
|
||||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-rustls"
|
||||
version = "0.27.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
|
||||
dependencies = [
|
||||
"http",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"ipnet",
|
||||
"libc",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.64"
|
||||
@@ -2253,6 +2406,22 @@ dependencies = [
|
||||
"leaky-cow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
|
||||
|
||||
[[package]]
|
||||
name = "iri-string"
|
||||
version = "0.7.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.17"
|
||||
@@ -2376,9 +2545,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "keccak"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654"
|
||||
checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653"
|
||||
dependencies = [
|
||||
"cpufeatures",
|
||||
]
|
||||
@@ -2490,7 +2659,7 @@ version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"libc",
|
||||
"redox_syscall 0.7.0",
|
||||
]
|
||||
@@ -2518,6 +2687,19 @@ version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
|
||||
|
||||
[[package]]
|
||||
name = "llm_adapter"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8dd9a548766bccf8b636695e8d514edee672d180e96a16ab932c971783b4e353"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.14"
|
||||
@@ -2555,7 +2737,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59fa2559e99ba0f26a12458aabc754432c805bbb8cba516c427825a997af1fb7"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"cbc",
|
||||
"ecb",
|
||||
"encoding_rs",
|
||||
@@ -2583,6 +2765,12 @@ dependencies = [
|
||||
"hashbrown 0.16.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-slab"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
||||
|
||||
[[package]]
|
||||
name = "mac"
|
||||
version = "0.1.1"
|
||||
@@ -2741,7 +2929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "000f205daae6646003fdc38517be6232af2b150bad4b67bdaf4c5aadb119d738"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"chrono",
|
||||
"ctor",
|
||||
"futures",
|
||||
@@ -2801,7 +2989,7 @@ version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"jni-sys",
|
||||
"log",
|
||||
"ndk-sys",
|
||||
@@ -2836,7 +3024,7 @@ version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"cfg-if",
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
@@ -3000,7 +3188,7 @@ version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"dispatch2",
|
||||
"objc2",
|
||||
]
|
||||
@@ -3017,7 +3205,7 @@ version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"block2",
|
||||
"libc",
|
||||
"objc2",
|
||||
@@ -3074,6 +3262,12 @@ version = "11.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
||||
|
||||
[[package]]
|
||||
name = "ordered-float"
|
||||
version = "5.1.0"
|
||||
@@ -3469,7 +3663,7 @@ checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40"
|
||||
dependencies = [
|
||||
"bit-set 0.8.0",
|
||||
"bit-vec 0.8.0",
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"num-traits",
|
||||
"rand 0.9.2",
|
||||
"rand_chacha 0.9.0",
|
||||
@@ -3497,7 +3691,7 @@ version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"getopts",
|
||||
"memchr",
|
||||
"pulldown-cmark-escape",
|
||||
@@ -3516,6 +3710,62 @@ version = "1.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"cfg_aliases",
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls",
|
||||
"socket2",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.11.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"bytes",
|
||||
"getrandom 0.3.4",
|
||||
"lru-slab",
|
||||
"rand 0.9.2",
|
||||
"ring",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"slab",
|
||||
"thiserror 2.0.17",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
|
||||
dependencies = [
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2",
|
||||
"tracing",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.43"
|
||||
@@ -3663,7 +3913,7 @@ version = "0.5.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3672,7 +3922,7 @@ version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3704,6 +3954,45 @@ version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper-util",
|
||||
"js-sys",
|
||||
"log",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"quinn",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"rustls-platform-verifier",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.14"
|
||||
@@ -3831,7 +4120,7 @@ version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
@@ -3844,6 +4133,7 @@ version = "0.23.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
@@ -3852,21 +4142,62 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pki-types",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282"
|
||||
dependencies = [
|
||||
"web-time",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784"
|
||||
dependencies = [
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"jni",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
"rustls-platform-verifier-android",
|
||||
"rustls-webpki",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"webpki-root-certs",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier-android"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.103.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
@@ -3905,6 +4236,15 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scoped-tls"
|
||||
version = "1.0.1"
|
||||
@@ -3953,6 +4293,29 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "3.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.27"
|
||||
@@ -4269,7 +4632,7 @@ checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64",
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -4312,7 +4675,7 @@ checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64",
|
||||
"bitflags 2.10.0",
|
||||
"bitflags 2.11.0",
|
||||
"byteorder",
|
||||
"chrono",
|
||||
"crc",
|
||||
@@ -4678,6 +5041,15 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sync_wrapper"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.13.2"
|
||||
@@ -4869,6 +5241,16 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.18"
|
||||
@@ -4912,13 +5294,58 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_parser"
|
||||
version = "1.0.6+spec-1.1.0"
|
||||
version = "1.0.9+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44"
|
||||
checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4"
|
||||
dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"iri-string",
|
||||
"pin-project-lite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
|
||||
|
||||
[[package]]
|
||||
name = "tracing"
|
||||
version = "0.1.44"
|
||||
@@ -5121,6 +5548,12 @@ dependencies = [
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "try-lock"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "type1-encoding-parser"
|
||||
version = "0.1.0"
|
||||
@@ -5441,6 +5874,15 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
|
||||
dependencies = [
|
||||
"try-lock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.1+wasi-snapshot-preview1"
|
||||
@@ -5530,6 +5972,25 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-time"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-root-certs"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.11"
|
||||
|
||||
@@ -44,6 +44,7 @@ resolver = "3"
|
||||
lasso = { version = "0.7", features = ["multi-threaded"] }
|
||||
lib0 = { version = "0.16", features = ["lib0-serde"] }
|
||||
libc = "0.2"
|
||||
llm_adapter = "0.1.1"
|
||||
log = "0.4"
|
||||
loom = { version = "0.7", features = ["checkpoint"] }
|
||||
lru = "0.16"
|
||||
|
||||
@@ -108,7 +108,9 @@ export class BookmarkBlockComponent extends CaptionedBlockComponent<BookmarkBloc
|
||||
}
|
||||
|
||||
open = () => {
|
||||
window.open(this.link, '_blank');
|
||||
const link = this.link;
|
||||
if (!link) return;
|
||||
window.open(link, '_blank', 'noopener,noreferrer');
|
||||
};
|
||||
|
||||
refreshData = () => {
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import { getHostName } from '@blocksuite/affine-shared/utils';
|
||||
import {
|
||||
getHostName,
|
||||
isValidUrl,
|
||||
normalizeUrl,
|
||||
} from '@blocksuite/affine-shared/utils';
|
||||
import { PropTypes, requiredProperties } from '@blocksuite/std';
|
||||
import { css, LitElement } from 'lit';
|
||||
import { property } from 'lit/decorators.js';
|
||||
@@ -44,15 +48,27 @@ export class LinkPreview extends LitElement {
|
||||
|
||||
override render() {
|
||||
const { url } = this;
|
||||
const normalizedUrl = normalizeUrl(url);
|
||||
const safeUrl =
|
||||
normalizedUrl && isValidUrl(normalizedUrl) ? normalizedUrl : null;
|
||||
const hostName = getHostName(safeUrl ?? url);
|
||||
|
||||
if (!safeUrl) {
|
||||
return html`
|
||||
<span class="affine-link-preview">
|
||||
<span>${hostName}</span>
|
||||
</span>
|
||||
`;
|
||||
}
|
||||
|
||||
return html`
|
||||
<a
|
||||
class="affine-link-preview"
|
||||
rel="noopener noreferrer"
|
||||
target="_blank"
|
||||
href=${url}
|
||||
href=${safeUrl}
|
||||
>
|
||||
<span>${getHostName(url)}</span>
|
||||
<span>${hostName}</span>
|
||||
</a>
|
||||
`;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import type { FootNote } from '@blocksuite/affine-model';
|
||||
import { CitationProvider } from '@blocksuite/affine-shared/services';
|
||||
import { unsafeCSSVarV2 } from '@blocksuite/affine-shared/theme';
|
||||
import type { AffineTextAttributes } from '@blocksuite/affine-shared/types';
|
||||
import { isValidUrl, normalizeUrl } from '@blocksuite/affine-shared/utils';
|
||||
import { WithDisposable } from '@blocksuite/global/lit';
|
||||
import {
|
||||
BlockSelection,
|
||||
@@ -152,7 +153,9 @@ export class AffineFootnoteNode extends WithDisposable(ShadowlessElement) {
|
||||
};
|
||||
|
||||
private readonly _handleUrlReference = (url: string) => {
|
||||
window.open(url, '_blank');
|
||||
const normalizedUrl = normalizeUrl(url);
|
||||
if (!normalizedUrl || !isValidUrl(normalizedUrl)) return;
|
||||
window.open(normalizedUrl, '_blank', 'noopener,noreferrer');
|
||||
};
|
||||
|
||||
private readonly _updateFootnoteAttributes = (footnote: FootNote) => {
|
||||
|
||||
@@ -24,6 +24,11 @@ const toURL = (str: string) => {
|
||||
}
|
||||
};
|
||||
|
||||
const hasAllowedScheme = (url: URL) => {
|
||||
const protocol = url.protocol.slice(0, -1).toLowerCase();
|
||||
return ALLOWED_SCHEMES.has(protocol);
|
||||
};
|
||||
|
||||
function resolveURL(str: string, baseUrl: string, padded = false) {
|
||||
const url = toURL(str);
|
||||
if (!url) return null;
|
||||
@@ -61,6 +66,7 @@ export function normalizeUrl(str: string) {
|
||||
|
||||
// Formatted
|
||||
if (url) {
|
||||
if (!hasAllowedScheme(url)) return '';
|
||||
if (!str.endsWith('/') && url.href.endsWith('/')) {
|
||||
return url.href.substring(0, url.href.length - 1);
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
"af": "r affine.ts",
|
||||
"dev": "yarn affine dev",
|
||||
"build": "yarn affine build",
|
||||
"lint:eslint": "cross-env NODE_OPTIONS=\"--max-old-space-size=8192\" eslint --report-unused-disable-directives-severity=off . --cache",
|
||||
"lint:eslint": "cross-env NODE_OPTIONS=\"--max-old-space-size=16384\" eslint --report-unused-disable-directives-severity=off . --cache",
|
||||
"lint:eslint:fix": "yarn lint:eslint --fix --fix-type problem,suggestion,layout",
|
||||
"lint:prettier": "prettier --ignore-unknown --cache --check .",
|
||||
"lint:prettier:fix": "prettier --ignore-unknown --cache --write .",
|
||||
|
||||
@@ -17,10 +17,13 @@ affine_common = { workspace = true, features = [
|
||||
chrono = { workspace = true }
|
||||
file-format = { workspace = true }
|
||||
infer = { workspace = true }
|
||||
llm_adapter = { workspace = true }
|
||||
mp4parse = { workspace = true }
|
||||
napi = { workspace = true, features = ["async"] }
|
||||
napi-derive = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
sha3 = { workspace = true }
|
||||
tiktoken-rs = { workspace = true }
|
||||
v_htmlescape = { workspace = true }
|
||||
|
||||
8
packages/backend/native/index.d.ts
vendored
8
packages/backend/native/index.d.ts
vendored
@@ -1,5 +1,9 @@
|
||||
/* auto-generated by NAPI-RS */
|
||||
/* eslint-disable */
|
||||
export declare class LlmStreamHandle {
|
||||
abort(): void
|
||||
}
|
||||
|
||||
export declare class Tokenizer {
|
||||
count(content: string, allowedSpecial?: Array<string> | undefined | null): number
|
||||
}
|
||||
@@ -46,6 +50,10 @@ export declare function getMime(input: Uint8Array): string
|
||||
|
||||
export declare function htmlSanitize(input: string): string
|
||||
|
||||
export declare function llmDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
|
||||
|
||||
export declare function llmDispatchStream(protocol: string, backendConfigJson: string, requestJson: string, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle
|
||||
|
||||
/**
|
||||
* Merge updates in form like `Y.applyUpdate(doc, update)` way and return the
|
||||
* result binary.
|
||||
|
||||
@@ -7,6 +7,7 @@ pub mod doc_loader;
|
||||
pub mod file_type;
|
||||
pub mod hashcash;
|
||||
pub mod html_sanitize;
|
||||
pub mod llm;
|
||||
pub mod tiktoken;
|
||||
|
||||
use affine_common::napi_utils::map_napi_err;
|
||||
|
||||
339
packages/backend/native/src/llm.rs
Normal file
339
packages/backend/native/src/llm.rs
Normal file
@@ -0,0 +1,339 @@
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
|
||||
use llm_adapter::{
|
||||
backend::{
|
||||
BackendConfig, BackendError, BackendProtocol, ReqwestHttpClient, dispatch_request, dispatch_stream_events_with,
|
||||
},
|
||||
core::{CoreRequest, StreamEvent},
|
||||
middleware::{
|
||||
MiddlewareConfig, PipelineContext, RequestMiddleware, StreamMiddleware, citation_indexing, clamp_max_tokens,
|
||||
normalize_messages, run_request_middleware_chain, run_stream_middleware_chain, stream_event_normalize,
|
||||
tool_schema_rewrite,
|
||||
},
|
||||
};
|
||||
use napi::{
|
||||
Error, Result, Status,
|
||||
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const STREAM_END_MARKER: &str = "__AFFINE_LLM_STREAM_END__";
|
||||
const STREAM_ABORTED_REASON: &str = "__AFFINE_LLM_STREAM_ABORTED__";
|
||||
const STREAM_CALLBACK_DISPATCH_FAILED_REASON: &str = "__AFFINE_LLM_STREAM_CALLBACK_DISPATCH_FAILED__";
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize)]
|
||||
#[serde(default)]
|
||||
struct LlmMiddlewarePayload {
|
||||
request: Vec<String>,
|
||||
stream: Vec<String>,
|
||||
config: MiddlewareConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct LlmDispatchPayload {
|
||||
#[serde(flatten)]
|
||||
request: CoreRequest,
|
||||
#[serde(default)]
|
||||
middleware: LlmMiddlewarePayload,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct LlmStreamHandle {
|
||||
aborted: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl LlmStreamHandle {
|
||||
#[napi]
|
||||
pub fn abort(&self) {
|
||||
self.aborted.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub fn llm_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
|
||||
let protocol = parse_protocol(&protocol)?;
|
||||
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
|
||||
let payload: LlmDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
|
||||
let request = apply_request_middlewares(payload.request, &payload.middleware)?;
|
||||
|
||||
let response =
|
||||
dispatch_request(&ReqwestHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?;
|
||||
|
||||
serde_json::to_string(&response).map_err(map_json_error)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub fn llm_dispatch_stream(
|
||||
protocol: String,
|
||||
backend_config_json: String,
|
||||
request_json: String,
|
||||
callback: ThreadsafeFunction<String, ()>,
|
||||
) -> Result<LlmStreamHandle> {
|
||||
let protocol = parse_protocol(&protocol)?;
|
||||
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
|
||||
let payload: LlmDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
|
||||
let request = apply_request_middlewares(payload.request, &payload.middleware)?;
|
||||
let middleware = payload.middleware.clone();
|
||||
|
||||
let aborted = Arc::new(AtomicBool::new(false));
|
||||
let aborted_in_worker = aborted.clone();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
let chain = match resolve_stream_chain(&middleware.stream) {
|
||||
Ok(chain) => chain,
|
||||
Err(error) => {
|
||||
emit_error_event(&callback, error.reason.clone(), "middleware_error");
|
||||
let _ = callback.call(
|
||||
Ok(STREAM_END_MARKER.to_string()),
|
||||
ThreadsafeFunctionCallMode::NonBlocking,
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let mut pipeline = StreamPipeline::new(chain, middleware.config.clone());
|
||||
let mut aborted_by_user = false;
|
||||
let mut callback_dispatch_failed = false;
|
||||
|
||||
let result = dispatch_stream_events_with(&ReqwestHttpClient::default(), &config, protocol, &request, |event| {
|
||||
if aborted_in_worker.load(Ordering::Relaxed) {
|
||||
aborted_by_user = true;
|
||||
return Err(BackendError::Http(STREAM_ABORTED_REASON.to_string()));
|
||||
}
|
||||
|
||||
for event in pipeline.process(event) {
|
||||
let status = emit_stream_event(&callback, &event);
|
||||
if status != Status::Ok {
|
||||
callback_dispatch_failed = true;
|
||||
return Err(BackendError::Http(format!(
|
||||
"{STREAM_CALLBACK_DISPATCH_FAILED_REASON}:{status}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
if !aborted_by_user {
|
||||
for event in pipeline.finish() {
|
||||
if aborted_in_worker.load(Ordering::Relaxed) {
|
||||
aborted_by_user = true;
|
||||
break;
|
||||
}
|
||||
if emit_stream_event(&callback, &event) != Status::Ok {
|
||||
callback_dispatch_failed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(error) = result
|
||||
&& !aborted_by_user
|
||||
&& !callback_dispatch_failed
|
||||
&& !is_abort_error(&error)
|
||||
&& !is_callback_dispatch_failed_error(&error)
|
||||
{
|
||||
emit_error_event(&callback, error.to_string(), "dispatch_error");
|
||||
}
|
||||
|
||||
if !callback_dispatch_failed {
|
||||
let _ = callback.call(
|
||||
Ok(STREAM_END_MARKER.to_string()),
|
||||
ThreadsafeFunctionCallMode::NonBlocking,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(LlmStreamHandle { aborted })
|
||||
}
|
||||
|
||||
fn apply_request_middlewares(request: CoreRequest, middleware: &LlmMiddlewarePayload) -> Result<CoreRequest> {
|
||||
let chain = resolve_request_chain(&middleware.request)?;
|
||||
Ok(run_request_middleware_chain(request, &middleware.config, &chain))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamPipeline {
|
||||
chain: Vec<StreamMiddleware>,
|
||||
config: MiddlewareConfig,
|
||||
context: PipelineContext,
|
||||
}
|
||||
|
||||
impl StreamPipeline {
|
||||
fn new(chain: Vec<StreamMiddleware>, config: MiddlewareConfig) -> Self {
|
||||
Self {
|
||||
chain,
|
||||
config,
|
||||
context: PipelineContext::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&mut self, event: StreamEvent) -> Vec<StreamEvent> {
|
||||
run_stream_middleware_chain(event, &mut self.context, &self.config, &self.chain)
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> Vec<StreamEvent> {
|
||||
self.context.flush_pending_deltas();
|
||||
self.context.drain_queued_events()
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_stream_event(callback: &ThreadsafeFunction<String, ()>, event: &StreamEvent) -> Status {
|
||||
let value = serde_json::to_string(event).unwrap_or_else(|error| {
|
||||
serde_json::json!({
|
||||
"type": "error",
|
||||
"message": format!("failed to serialize stream event: {error}"),
|
||||
})
|
||||
.to_string()
|
||||
});
|
||||
|
||||
callback.call(Ok(value), ThreadsafeFunctionCallMode::NonBlocking)
|
||||
}
|
||||
|
||||
fn emit_error_event(callback: &ThreadsafeFunction<String, ()>, message: String, code: &str) {
|
||||
let error_event = serde_json::to_string(&StreamEvent::Error {
|
||||
message: message.clone(),
|
||||
code: Some(code.to_string()),
|
||||
})
|
||||
.unwrap_or_else(|_| {
|
||||
serde_json::json!({
|
||||
"type": "error",
|
||||
"message": message,
|
||||
"code": code,
|
||||
})
|
||||
.to_string()
|
||||
});
|
||||
|
||||
let _ = callback.call(Ok(error_event), ThreadsafeFunctionCallMode::NonBlocking);
|
||||
}
|
||||
|
||||
fn is_abort_error(error: &BackendError) -> bool {
|
||||
matches!(
|
||||
error,
|
||||
BackendError::Http(reason) if reason == STREAM_ABORTED_REASON
|
||||
)
|
||||
}
|
||||
|
||||
fn is_callback_dispatch_failed_error(error: &BackendError) -> bool {
|
||||
matches!(
|
||||
error,
|
||||
BackendError::Http(reason) if reason.starts_with(STREAM_CALLBACK_DISPATCH_FAILED_REASON)
|
||||
)
|
||||
}
|
||||
|
||||
fn resolve_request_chain(request: &[String]) -> Result<Vec<RequestMiddleware>> {
|
||||
if request.is_empty() {
|
||||
return Ok(vec![normalize_messages, tool_schema_rewrite]);
|
||||
}
|
||||
|
||||
request
|
||||
.iter()
|
||||
.map(|name| match name.as_str() {
|
||||
"normalize_messages" => Ok(normalize_messages as RequestMiddleware),
|
||||
"clamp_max_tokens" => Ok(clamp_max_tokens as RequestMiddleware),
|
||||
"tool_schema_rewrite" => Ok(tool_schema_rewrite as RequestMiddleware),
|
||||
_ => Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("Unsupported request middleware: {name}"),
|
||||
)),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn resolve_stream_chain(stream: &[String]) -> Result<Vec<StreamMiddleware>> {
|
||||
if stream.is_empty() {
|
||||
return Ok(vec![stream_event_normalize, citation_indexing]);
|
||||
}
|
||||
|
||||
stream
|
||||
.iter()
|
||||
.map(|name| match name.as_str() {
|
||||
"stream_event_normalize" => Ok(stream_event_normalize as StreamMiddleware),
|
||||
"citation_indexing" => Ok(citation_indexing as StreamMiddleware),
|
||||
_ => Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("Unsupported stream middleware: {name}"),
|
||||
)),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn parse_protocol(protocol: &str) -> Result<BackendProtocol> {
|
||||
match protocol {
|
||||
"openai_chat" | "openai-chat" | "openai_chat_completions" | "chat-completions" | "chat_completions" => {
|
||||
Ok(BackendProtocol::OpenaiChatCompletions)
|
||||
}
|
||||
"openai_responses" | "openai-responses" | "responses" => Ok(BackendProtocol::OpenaiResponses),
|
||||
"anthropic" | "anthropic_messages" | "anthropic-messages" => Ok(BackendProtocol::AnthropicMessages),
|
||||
other => Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("Unsupported llm backend protocol: {other}"),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_json_error(error: serde_json::Error) -> Error {
|
||||
Error::new(Status::InvalidArg, format!("Invalid JSON payload: {error}"))
|
||||
}
|
||||
|
||||
fn map_backend_error(error: BackendError) -> Error {
|
||||
Error::new(Status::GenericFailure, error.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_parse_supported_protocol_aliases() {
|
||||
assert!(parse_protocol("openai_chat").is_ok());
|
||||
assert!(parse_protocol("chat-completions").is_ok());
|
||||
assert!(parse_protocol("responses").is_ok());
|
||||
assert!(parse_protocol("anthropic").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_reject_unsupported_protocol() {
|
||||
let error = parse_protocol("unknown").unwrap_err();
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Unsupported llm backend protocol"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_dispatch_should_reject_invalid_backend_json() {
|
||||
let error = llm_dispatch("openai_chat".to_string(), "{".to_string(), "{}".to_string()).unwrap_err();
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Invalid JSON payload"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_json_error_should_use_invalid_arg_status() {
|
||||
let parse_error = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
|
||||
let error = map_json_error(parse_error);
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Invalid JSON payload"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_request_chain_should_support_clamp_max_tokens() {
|
||||
let chain = resolve_request_chain(&["normalize_messages".to_string(), "clamp_max_tokens".to_string()]).unwrap();
|
||||
assert_eq!(chain.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_request_chain_should_reject_unknown_middleware() {
|
||||
let error = resolve_request_chain(&["unknown".to_string()]).unwrap_err();
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Unsupported request middleware"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_stream_chain_should_reject_unknown_middleware() {
|
||||
let error = resolve_stream_chain(&["unknown".to_string()]).unwrap_err();
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Unsupported stream middleware"));
|
||||
}
|
||||
}
|
||||
@@ -12,9 +12,9 @@
|
||||
"dev": "nodemon ./src/index.ts",
|
||||
"dev:mail": "email dev -d src/mails",
|
||||
"test": "ava --concurrency 1 --serial",
|
||||
"test:copilot": "ava \"src/__tests__/copilot-*.spec.ts\"",
|
||||
"test:copilot": "ava \"src/__tests__/copilot/copilot-*.spec.ts\"",
|
||||
"test:coverage": "c8 ava --concurrency 1 --serial",
|
||||
"test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot-*.spec.ts\"",
|
||||
"test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot/copilot-*.spec.ts\"",
|
||||
"e2e": "cross-env TEST_MODE=e2e ava --serial",
|
||||
"e2e:coverage": "cross-env TEST_MODE=e2e c8 ava --serial",
|
||||
"data-migration": "cross-env NODE_ENV=development SERVER_FLAVOR=script r ./src/index.ts",
|
||||
@@ -28,12 +28,8 @@
|
||||
"dependencies": {
|
||||
"@affine/s3-compat": "workspace:*",
|
||||
"@affine/server-native": "workspace:*",
|
||||
"@ai-sdk/anthropic": "^2.0.54",
|
||||
"@ai-sdk/google": "^2.0.45",
|
||||
"@ai-sdk/google-vertex": "^3.0.88",
|
||||
"@ai-sdk/openai": "^2.0.80",
|
||||
"@ai-sdk/openai-compatible": "^1.0.28",
|
||||
"@ai-sdk/perplexity": "^2.0.21",
|
||||
"@apollo/server": "^4.13.0",
|
||||
"@fal-ai/serverless-client": "^0.15.0",
|
||||
"@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0",
|
||||
|
||||
Binary file not shown.
@@ -43,7 +43,9 @@ Generated by [AVA](https://avajs.dev).
|
||||
> Snapshot 5
|
||||
|
||||
Buffer @Uint8Array [
|
||||
66616b65 20696d61 6765
|
||||
89504e47 0d0a1a0a 0000000d 49484452 00000001 00000001 08040000 00b51c0c
|
||||
02000000 0b494441 5478da63 fcff1f00 03030200 efa37c9f 00000000 49454e44
|
||||
ae426082
|
||||
]
|
||||
|
||||
## should preview link
|
||||
|
||||
Binary file not shown.
@@ -12,12 +12,12 @@ Generated by [AVA](https://avajs.dev).
|
||||
{
|
||||
messages: [
|
||||
{
|
||||
content: 'generate text to text',
|
||||
content: 'generate text to text stream',
|
||||
role: 'assistant',
|
||||
},
|
||||
],
|
||||
pinned: false,
|
||||
tokens: 8,
|
||||
tokens: 10,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -27,12 +27,12 @@ Generated by [AVA](https://avajs.dev).
|
||||
{
|
||||
messages: [
|
||||
{
|
||||
content: 'generate text to text',
|
||||
content: 'generate text to text stream',
|
||||
role: 'assistant',
|
||||
},
|
||||
],
|
||||
pinned: false,
|
||||
tokens: 8,
|
||||
tokens: 10,
|
||||
},
|
||||
]
|
||||
|
||||
Binary file not shown.
@@ -4,31 +4,31 @@ import type { ExecutionContext, TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { ServerFeature, ServerService } from '../core';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { QuotaModule } from '../core/quota';
|
||||
import { Models } from '../models';
|
||||
import { CopilotModule } from '../plugins/copilot';
|
||||
import { prompts, PromptService } from '../plugins/copilot/prompt';
|
||||
import { ServerFeature, ServerService } from '../../core';
|
||||
import { AuthService } from '../../core/auth';
|
||||
import { QuotaModule } from '../../core/quota';
|
||||
import { Models } from '../../models';
|
||||
import { CopilotModule } from '../../plugins/copilot';
|
||||
import { prompts, PromptService } from '../../plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderFactory,
|
||||
CopilotProviderType,
|
||||
StreamObject,
|
||||
StreamObjectSchema,
|
||||
} from '../plugins/copilot/providers';
|
||||
import { TranscriptionResponseSchema } from '../plugins/copilot/transcript/types';
|
||||
} from '../../plugins/copilot/providers';
|
||||
import { TranscriptionResponseSchema } from '../../plugins/copilot/transcript/types';
|
||||
import {
|
||||
CopilotChatTextExecutor,
|
||||
CopilotWorkflowService,
|
||||
GraphExecutorState,
|
||||
} from '../plugins/copilot/workflow';
|
||||
} from '../../plugins/copilot/workflow';
|
||||
import {
|
||||
CopilotChatImageExecutor,
|
||||
CopilotCheckHtmlExecutor,
|
||||
CopilotCheckJsonExecutor,
|
||||
} from '../plugins/copilot/workflow/executor';
|
||||
import { createTestingModule, TestingModule } from './utils';
|
||||
import { TestAssets } from './utils/copilot';
|
||||
} from '../../plugins/copilot/workflow/executor';
|
||||
import { createTestingModule, TestingModule } from '../utils';
|
||||
import { TestAssets } from '../utils/copilot';
|
||||
|
||||
type Tester = {
|
||||
auth: AuthService;
|
||||
@@ -6,25 +6,25 @@ import type { TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { AppModule } from '../app.module';
|
||||
import { JobQueue } from '../base';
|
||||
import { ConfigModule } from '../base/config';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { DocReader } from '../core/doc';
|
||||
import { CopilotContextService } from '../plugins/copilot/context';
|
||||
import { AppModule } from '../../app.module';
|
||||
import { JobQueue } from '../../base';
|
||||
import { ConfigModule } from '../../base/config';
|
||||
import { AuthService } from '../../core/auth';
|
||||
import { DocReader } from '../../core/doc';
|
||||
import { CopilotContextService } from '../../plugins/copilot/context';
|
||||
import {
|
||||
CopilotEmbeddingJob,
|
||||
MockEmbeddingClient,
|
||||
} from '../plugins/copilot/embedding';
|
||||
import { prompts, PromptService } from '../plugins/copilot/prompt';
|
||||
} from '../../plugins/copilot/embedding';
|
||||
import { prompts, PromptService } from '../../plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderFactory,
|
||||
CopilotProviderType,
|
||||
GeminiGenerativeProvider,
|
||||
OpenAIProvider,
|
||||
} from '../plugins/copilot/providers';
|
||||
import { CopilotStorage } from '../plugins/copilot/storage';
|
||||
import { MockCopilotProvider } from './mocks';
|
||||
} from '../../plugins/copilot/providers';
|
||||
import { CopilotStorage } from '../../plugins/copilot/storage';
|
||||
import { MockCopilotProvider } from '../mocks';
|
||||
import {
|
||||
acceptInviteById,
|
||||
createTestingApp,
|
||||
@@ -33,7 +33,7 @@ import {
|
||||
smallestPng,
|
||||
TestingApp,
|
||||
TestUser,
|
||||
} from './utils';
|
||||
} from '../utils';
|
||||
import {
|
||||
addContextDoc,
|
||||
addContextFile,
|
||||
@@ -67,7 +67,7 @@ import {
|
||||
textToEventStream,
|
||||
unsplashSearch,
|
||||
updateCopilotSession,
|
||||
} from './utils/copilot';
|
||||
} from '../utils/copilot';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
auth: AuthService;
|
||||
@@ -513,7 +513,11 @@ test('should be able to chat with api', async t => {
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
const ret = await chatWithText(app, sessionId, messageId);
|
||||
t.is(ret, 'generate text to text', 'should be able to chat with text');
|
||||
t.is(
|
||||
ret,
|
||||
'generate text to text stream',
|
||||
'should be able to chat with text'
|
||||
);
|
||||
|
||||
const ret2 = await chatWithTextStream(app, sessionId, messageId);
|
||||
t.is(
|
||||
@@ -657,7 +661,7 @@ test('should be able to retry with api', async t => {
|
||||
const histories = await getHistories(app, { workspaceId: id, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text', 'generate text to text']],
|
||||
[['generate text to text stream', 'generate text to text stream']],
|
||||
'should be able to list history'
|
||||
);
|
||||
}
|
||||
@@ -794,7 +798,7 @@ test('should be able to list history', async t => {
|
||||
const histories = await getHistories(app, { workspaceId, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['hello', 'generate text to text']],
|
||||
[['hello', 'generate text to text stream']],
|
||||
'should be able to list history'
|
||||
);
|
||||
}
|
||||
@@ -807,7 +811,7 @@ test('should be able to list history', async t => {
|
||||
});
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text', 'hello']],
|
||||
[['generate text to text stream', 'hello']],
|
||||
'should be able to list history'
|
||||
);
|
||||
}
|
||||
@@ -858,7 +862,7 @@ test('should reject request that user have not permission', async t => {
|
||||
const histories = await getHistories(app, { workspaceId, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text']],
|
||||
[['generate text to text stream']],
|
||||
'should able to list history'
|
||||
);
|
||||
|
||||
@@ -8,38 +8,38 @@ import ava from 'ava';
|
||||
import { nanoid } from 'nanoid';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { EventBus, JobQueue } from '../base';
|
||||
import { ConfigModule } from '../base/config';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { QuotaModule } from '../core/quota';
|
||||
import { StorageModule, WorkspaceBlobStorage } from '../core/storage';
|
||||
import { EventBus, JobQueue } from '../../base';
|
||||
import { ConfigModule } from '../../base/config';
|
||||
import { AuthService } from '../../core/auth';
|
||||
import { QuotaModule } from '../../core/quota';
|
||||
import { StorageModule, WorkspaceBlobStorage } from '../../core/storage';
|
||||
import {
|
||||
ContextCategories,
|
||||
CopilotSessionModel,
|
||||
WorkspaceModel,
|
||||
} from '../models';
|
||||
import { CopilotModule } from '../plugins/copilot';
|
||||
import { CopilotContextService } from '../plugins/copilot/context';
|
||||
import { CopilotCronJobs } from '../plugins/copilot/cron';
|
||||
} from '../../models';
|
||||
import { CopilotModule } from '../../plugins/copilot';
|
||||
import { CopilotContextService } from '../../plugins/copilot/context';
|
||||
import { CopilotCronJobs } from '../../plugins/copilot/cron';
|
||||
import {
|
||||
CopilotEmbeddingJob,
|
||||
MockEmbeddingClient,
|
||||
} from '../plugins/copilot/embedding';
|
||||
import { prompts, PromptService } from '../plugins/copilot/prompt';
|
||||
} from '../../plugins/copilot/embedding';
|
||||
import { prompts, PromptService } from '../../plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderFactory,
|
||||
CopilotProviderType,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
OpenAIProvider,
|
||||
} from '../plugins/copilot/providers';
|
||||
} from '../../plugins/copilot/providers';
|
||||
import {
|
||||
CitationParser,
|
||||
TextStreamParser,
|
||||
} from '../plugins/copilot/providers/utils';
|
||||
import { ChatSessionService } from '../plugins/copilot/session';
|
||||
import { CopilotStorage } from '../plugins/copilot/storage';
|
||||
import { CopilotTranscriptionService } from '../plugins/copilot/transcript';
|
||||
} from '../../plugins/copilot/providers/utils';
|
||||
import { ChatSessionService } from '../../plugins/copilot/session';
|
||||
import { CopilotStorage } from '../../plugins/copilot/storage';
|
||||
import { CopilotTranscriptionService } from '../../plugins/copilot/transcript';
|
||||
import {
|
||||
CopilotChatTextExecutor,
|
||||
CopilotWorkflowService,
|
||||
@@ -48,7 +48,7 @@ import {
|
||||
WorkflowGraphExecutor,
|
||||
type WorkflowNodeData,
|
||||
WorkflowNodeType,
|
||||
} from '../plugins/copilot/workflow';
|
||||
} from '../../plugins/copilot/workflow';
|
||||
import {
|
||||
CopilotChatImageExecutor,
|
||||
CopilotCheckHtmlExecutor,
|
||||
@@ -56,16 +56,16 @@ import {
|
||||
getWorkflowExecutor,
|
||||
NodeExecuteState,
|
||||
NodeExecutorType,
|
||||
} from '../plugins/copilot/workflow/executor';
|
||||
import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils';
|
||||
import { WorkflowGraphList } from '../plugins/copilot/workflow/graph';
|
||||
import { CopilotWorkspaceService } from '../plugins/copilot/workspace';
|
||||
import { PaymentModule } from '../plugins/payment';
|
||||
import { SubscriptionService } from '../plugins/payment/service';
|
||||
import { SubscriptionStatus } from '../plugins/payment/types';
|
||||
import { MockCopilotProvider } from './mocks';
|
||||
import { createTestingModule, TestingModule } from './utils';
|
||||
import { WorkflowTestCases } from './utils/copilot';
|
||||
} from '../../plugins/copilot/workflow/executor';
|
||||
import { AutoRegisteredWorkflowExecutor } from '../../plugins/copilot/workflow/executor/utils';
|
||||
import { WorkflowGraphList } from '../../plugins/copilot/workflow/graph';
|
||||
import { CopilotWorkspaceService } from '../../plugins/copilot/workspace';
|
||||
import { PaymentModule } from '../../plugins/payment';
|
||||
import { SubscriptionService } from '../../plugins/payment/service';
|
||||
import { SubscriptionStatus } from '../../plugins/payment/types';
|
||||
import { MockCopilotProvider } from '../mocks';
|
||||
import { createTestingModule, TestingModule } from '../utils';
|
||||
import { WorkflowTestCases } from '../utils/copilot';
|
||||
|
||||
type Context = {
|
||||
auth: AuthService;
|
||||
@@ -364,6 +364,21 @@ test('should be able to manage chat session', async t => {
|
||||
});
|
||||
t.is(newSessionId, sessionId, 'should get same session id');
|
||||
}
|
||||
|
||||
// should create a fresh session when reuseLatestChat is explicitly disabled
|
||||
{
|
||||
const newSessionId = await session.create({
|
||||
userId,
|
||||
promptName,
|
||||
...commonParams,
|
||||
reuseLatestChat: false,
|
||||
});
|
||||
t.not(
|
||||
newSessionId,
|
||||
sessionId,
|
||||
'should create new session id when reuseLatestChat is false'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
test('should be able to update chat session prompt', async t => {
|
||||
@@ -881,6 +896,26 @@ test('should be able to get provider', async t => {
|
||||
}
|
||||
});
|
||||
|
||||
test('should resolve provider by prefixed model id', async t => {
|
||||
const { factory } = t.context;
|
||||
|
||||
const provider = await factory.getProviderByModel('openai-default/test');
|
||||
t.truthy(provider, 'should resolve prefixed model id');
|
||||
t.is(provider?.type, CopilotProviderType.OpenAI);
|
||||
|
||||
const result = await provider?.text({ modelId: 'openai-default/test' }, [
|
||||
{ role: 'user', content: 'hello' },
|
||||
]);
|
||||
t.is(result, 'generate text to text');
|
||||
});
|
||||
|
||||
test('should fallback to null when prefixed provider id does not exist', async t => {
|
||||
const { factory } = t.context;
|
||||
|
||||
const provider = await factory.getProviderByModel('unknown/test');
|
||||
t.is(provider, null);
|
||||
});
|
||||
|
||||
// ==================== workflow ====================
|
||||
|
||||
// this test used to preview the final result of the workflow
|
||||
@@ -2063,25 +2098,23 @@ test('should handle copilot cron jobs correctly', async t => {
|
||||
});
|
||||
|
||||
test('should resolve model correctly based on subscription status and prompt config', async t => {
|
||||
const { db, session, subscription } = t.context;
|
||||
const { prompt, session, subscription } = t.context;
|
||||
|
||||
// 1) Seed a prompt that has optionalModels and proModels in config
|
||||
const promptName = 'resolve-model-test';
|
||||
await db.aiPrompt.create({
|
||||
data: {
|
||||
name: promptName,
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: {
|
||||
create: [{ idx: 0, role: 'system', content: 'test' }],
|
||||
},
|
||||
config: { proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'] },
|
||||
await prompt.set(
|
||||
promptName,
|
||||
'gemini-2.5-flash',
|
||||
[{ role: 'system', content: 'test' }],
|
||||
{ proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'] },
|
||||
{
|
||||
optionalModels: [
|
||||
'gemini-2.5-flash',
|
||||
'gemini-2.5-pro',
|
||||
'claude-sonnet-4-5@20250929',
|
||||
],
|
||||
},
|
||||
});
|
||||
}
|
||||
);
|
||||
|
||||
// 2) Create a chat session with this prompt
|
||||
const sessionId = await session.create({
|
||||
@@ -2106,6 +2139,16 @@ test('should resolve model correctly based on subscription status and prompt con
|
||||
const model1 = await s.resolveModel(false, 'gemini-2.5-pro');
|
||||
t.snapshot(model1, 'should honor requested pro model');
|
||||
|
||||
const model1WithPrefix = await s.resolveModel(
|
||||
false,
|
||||
'openai-default/gemini-2.5-pro'
|
||||
);
|
||||
t.is(
|
||||
model1WithPrefix,
|
||||
'openai-default/gemini-2.5-pro',
|
||||
'should honor requested prefixed pro model'
|
||||
);
|
||||
|
||||
const model2 = await s.resolveModel(false, 'not-in-optional');
|
||||
t.snapshot(model2, 'should fallback to default model');
|
||||
}
|
||||
@@ -2119,6 +2162,16 @@ test('should resolve model correctly based on subscription status and prompt con
|
||||
'should fallback to default model when requesting pro model during trialing'
|
||||
);
|
||||
|
||||
const model3WithPrefix = await s.resolveModel(
|
||||
true,
|
||||
'openai-default/gemini-2.5-pro'
|
||||
);
|
||||
t.is(
|
||||
model3WithPrefix,
|
||||
'gemini-2.5-flash',
|
||||
'should fallback to default model when requesting prefixed pro model during trialing'
|
||||
);
|
||||
|
||||
const model4 = await s.resolveModel(true, 'gemini-2.5-flash');
|
||||
t.snapshot(model4, 'should honor requested non-pro model during trialing');
|
||||
|
||||
@@ -2141,6 +2194,16 @@ test('should resolve model correctly based on subscription status and prompt con
|
||||
const model7 = await s.resolveModel(true, 'claude-sonnet-4-5@20250929');
|
||||
t.snapshot(model7, 'should honor requested pro model during active');
|
||||
|
||||
const model7WithPrefix = await s.resolveModel(
|
||||
true,
|
||||
'openai-default/claude-sonnet-4-5@20250929'
|
||||
);
|
||||
t.is(
|
||||
model7WithPrefix,
|
||||
'openai-default/claude-sonnet-4-5@20250929',
|
||||
'should honor requested prefixed pro model during active'
|
||||
);
|
||||
|
||||
const model8 = await s.resolveModel(true, 'not-in-optional');
|
||||
t.snapshot(
|
||||
model8,
|
||||
@@ -0,0 +1,210 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
|
||||
import {
|
||||
buildNativeRequest,
|
||||
NativeProviderAdapter,
|
||||
} from '../../plugins/copilot/providers/native';
|
||||
|
||||
const mockDispatch = () =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
yield { type: 'text_delta', text: 'Use [^1] now' };
|
||||
yield { type: 'citation', index: 1, url: 'https://affine.pro' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
test('NativeProviderAdapter streamText should append citation footnotes', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
|
||||
const chunks: string[] = [];
|
||||
for await (const chunk of adapter.streamText({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.true(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should append citation footnotes', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
|
||||
const chunks = [];
|
||||
for await (const chunk of adapter.streamObject({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
t.deepEqual(
|
||||
chunks.map(chunk => chunk.type),
|
||||
['text-delta', 'text-delta']
|
||||
);
|
||||
const text = chunks
|
||||
.filter(chunk => chunk.type === 'text-delta')
|
||||
.map(chunk => chunk.textDelta)
|
||||
.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.true(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should append fallback attachment footnotes', async t => {
|
||||
const dispatch = () =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
yield {
|
||||
type: 'tool_result',
|
||||
call_id: 'call_1',
|
||||
name: 'blob_read',
|
||||
arguments: { blob_id: 'blob_1' },
|
||||
output: {
|
||||
blobId: 'blob_1',
|
||||
fileName: 'a.txt',
|
||||
fileType: 'text/plain',
|
||||
content: 'A',
|
||||
},
|
||||
};
|
||||
yield {
|
||||
type: 'tool_result',
|
||||
call_id: 'call_2',
|
||||
name: 'blob_read',
|
||||
arguments: { blob_id: 'blob_2' },
|
||||
output: {
|
||||
blobId: 'blob_2',
|
||||
fileName: 'b.txt',
|
||||
fileType: 'text/plain',
|
||||
content: 'B',
|
||||
},
|
||||
};
|
||||
yield { type: 'text_delta', text: 'Answer from files.' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
const adapter = new NativeProviderAdapter(dispatch, {}, 3);
|
||||
const chunks = [];
|
||||
for await (const chunk of adapter.streamObject({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks
|
||||
.filter(chunk => chunk.type === 'text-delta')
|
||||
.map(chunk => chunk.textDelta)
|
||||
.join('');
|
||||
t.true(text.includes('Answer from files.'));
|
||||
t.true(text.includes('[^1][^2]'));
|
||||
t.true(
|
||||
text.includes(
|
||||
'[^1]: {"type":"attachment","blobId":"blob_1","fileName":"a.txt","fileType":"text/plain"}'
|
||||
)
|
||||
);
|
||||
t.true(
|
||||
text.includes(
|
||||
'[^2]: {"type":"attachment","blobId":"blob_2","fileName":"b.txt","fileType":"text/plain"}'
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should map tool and text events', async t => {
|
||||
let round = 0;
|
||||
const dispatch = (_request: NativeLlmRequest) =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
round += 1;
|
||||
if (round === 1) {
|
||||
yield {
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
};
|
||||
yield { type: 'done', finish_reason: 'tool_calls' };
|
||||
return;
|
||||
}
|
||||
yield { type: 'text_delta', text: 'ok' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
const adapter = new NativeProviderAdapter(
|
||||
dispatch,
|
||||
{
|
||||
doc_read: {
|
||||
inputSchema: z.object({ doc_id: z.string() }),
|
||||
execute: async () => ({ markdown: '# a1' }),
|
||||
},
|
||||
},
|
||||
4
|
||||
);
|
||||
|
||||
const events = [];
|
||||
for await (const event of adapter.streamObject({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'read' }] }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
t.deepEqual(
|
||||
events.map(event => event.type),
|
||||
['tool-call', 'tool-result', 'text-delta']
|
||||
);
|
||||
t.deepEqual(events[0], {
|
||||
type: 'tool-call',
|
||||
toolCallId: 'call_1',
|
||||
toolName: 'doc_read',
|
||||
args: { doc_id: 'a1' },
|
||||
});
|
||||
});
|
||||
|
||||
test('buildNativeRequest should include rust middleware from profile', async t => {
|
||||
const { request } = await buildNativeRequest({
|
||||
model: 'gpt-4.1',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
tools: {},
|
||||
middleware: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['callout'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
t.deepEqual(request.middleware, {
|
||||
request: ['normalize_messages', 'clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
});
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamText should skip citation footnotes when disabled', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3, {
|
||||
nodeTextMiddleware: ['callout'],
|
||||
});
|
||||
const chunks: string[] = [];
|
||||
for await (const chunk of adapter.streamText({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.false(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,56 @@
|
||||
import test from 'ava';
|
||||
|
||||
import { resolveProviderMiddleware } from '../../plugins/copilot/providers/provider-middleware';
|
||||
import { buildProviderRegistry } from '../../plugins/copilot/providers/provider-registry';
|
||||
import { CopilotProviderType } from '../../plugins/copilot/providers/types';
|
||||
|
||||
test('resolveProviderMiddleware should include anthropic defaults', t => {
|
||||
const middleware = resolveProviderMiddleware(CopilotProviderType.Anthropic);
|
||||
|
||||
t.deepEqual(middleware.rust?.request, [
|
||||
'normalize_messages',
|
||||
'tool_schema_rewrite',
|
||||
]);
|
||||
t.deepEqual(middleware.rust?.stream, [
|
||||
'stream_event_normalize',
|
||||
'citation_indexing',
|
||||
]);
|
||||
t.deepEqual(middleware.node?.text, ['citation_footnote', 'callout']);
|
||||
});
|
||||
|
||||
test('resolveProviderMiddleware should merge defaults and overrides', t => {
|
||||
const middleware = resolveProviderMiddleware(CopilotProviderType.OpenAI, {
|
||||
rust: { request: ['clamp_max_tokens'] },
|
||||
node: { text: ['thinking_format'] },
|
||||
});
|
||||
|
||||
t.deepEqual(middleware.rust?.request, [
|
||||
'normalize_messages',
|
||||
'clamp_max_tokens',
|
||||
]);
|
||||
t.deepEqual(middleware.node?.text, [
|
||||
'citation_footnote',
|
||||
'callout',
|
||||
'thinking_format',
|
||||
]);
|
||||
});
|
||||
|
||||
test('buildProviderRegistry should normalize profile middleware defaults', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const profile = registry.profiles.get('openai-main');
|
||||
t.truthy(profile);
|
||||
t.deepEqual(profile?.middleware.rust?.stream, [
|
||||
'stream_event_normalize',
|
||||
'citation_indexing',
|
||||
]);
|
||||
t.deepEqual(profile?.middleware.node?.text, ['citation_footnote', 'callout']);
|
||||
});
|
||||
@@ -0,0 +1,99 @@
|
||||
import test from 'ava';
|
||||
|
||||
import { ProviderMiddlewareConfig } from '../../plugins/copilot/config';
|
||||
import { CopilotProvider } from '../../plugins/copilot/providers/provider';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
} from '../../plugins/copilot/providers/types';
|
||||
|
||||
class TestOpenAIProvider extends CopilotProvider<{ apiKey: string }> {
|
||||
readonly type = CopilotProviderType.OpenAI;
|
||||
readonly models = [
|
||||
{
|
||||
id: 'gpt-4.1',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text],
|
||||
output: [ModelOutputType.Text],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
configured() {
|
||||
return true;
|
||||
}
|
||||
|
||||
async text(_cond: any, _messages: any[], _options?: any) {
|
||||
return '';
|
||||
}
|
||||
|
||||
async *streamText(_cond: any, _messages: any[], _options?: any) {
|
||||
yield '';
|
||||
}
|
||||
|
||||
exposeMetricLabels() {
|
||||
return this.metricLabels('gpt-4.1');
|
||||
}
|
||||
|
||||
exposeMiddleware() {
|
||||
return this.getActiveProviderMiddleware();
|
||||
}
|
||||
}
|
||||
|
||||
function createProvider(profileMiddleware?: ProviderMiddlewareConfig) {
|
||||
const provider = new TestOpenAIProvider();
|
||||
(provider as any).AFFiNEConfig = {
|
||||
copilot: {
|
||||
providers: {
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: 'test' },
|
||||
middleware: profileMiddleware,
|
||||
},
|
||||
],
|
||||
defaults: {},
|
||||
openai: { apiKey: 'legacy' },
|
||||
},
|
||||
},
|
||||
};
|
||||
return provider;
|
||||
}
|
||||
|
||||
test('metricLabels should include active provider id', t => {
|
||||
const provider = createProvider();
|
||||
const labels = provider.runWithProfile('openai-main', () =>
|
||||
provider.exposeMetricLabels()
|
||||
);
|
||||
t.is(labels.providerId, 'openai-main');
|
||||
});
|
||||
|
||||
test('getActiveProviderMiddleware should merge defaults with profile override', t => {
|
||||
const provider = createProvider({
|
||||
rust: { request: ['clamp_max_tokens'] },
|
||||
node: { text: ['thinking_format'] },
|
||||
});
|
||||
|
||||
const middleware = provider.runWithProfile('openai-main', () =>
|
||||
provider.exposeMiddleware()
|
||||
);
|
||||
|
||||
t.deepEqual(middleware.rust?.request, [
|
||||
'normalize_messages',
|
||||
'clamp_max_tokens',
|
||||
]);
|
||||
t.deepEqual(middleware.rust?.stream, [
|
||||
'stream_event_normalize',
|
||||
'citation_indexing',
|
||||
]);
|
||||
t.deepEqual(middleware.node?.text, [
|
||||
'citation_footnote',
|
||||
'callout',
|
||||
'thinking_format',
|
||||
]);
|
||||
});
|
||||
@@ -0,0 +1,165 @@
|
||||
import test from 'ava';
|
||||
|
||||
import {
|
||||
buildProviderRegistry,
|
||||
resolveModel,
|
||||
stripProviderPrefix,
|
||||
} from '../../plugins/copilot/providers/provider-registry';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
ModelOutputType,
|
||||
} from '../../plugins/copilot/providers/types';
|
||||
|
||||
test('buildProviderRegistry should keep explicit profile over legacy compatibility profile', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-default',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
priority: 100,
|
||||
config: { apiKey: 'new' },
|
||||
},
|
||||
],
|
||||
openai: { apiKey: 'legacy' },
|
||||
});
|
||||
|
||||
const profile = registry.profiles.get('openai-default');
|
||||
t.truthy(profile);
|
||||
t.deepEqual(profile?.config, { apiKey: 'new' });
|
||||
});
|
||||
|
||||
test('buildProviderRegistry should reject duplicated profile ids', t => {
|
||||
const error = t.throws(() =>
|
||||
buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '2' },
|
||||
},
|
||||
],
|
||||
})
|
||||
) as Error;
|
||||
|
||||
t.truthy(error);
|
||||
t.regex(error.message, /Duplicated copilot provider profile id/);
|
||||
});
|
||||
|
||||
test('buildProviderRegistry should reject defaults that reference unknown providers', t => {
|
||||
const error = t.throws(() =>
|
||||
buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
],
|
||||
defaults: {
|
||||
fallback: 'unknown-provider',
|
||||
},
|
||||
})
|
||||
) as Error;
|
||||
|
||||
t.truthy(error);
|
||||
t.regex(error.message, /defaults references unknown providerId/);
|
||||
});
|
||||
|
||||
test('resolveModel should support explicit provider prefix and keep slash models untouched', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
{
|
||||
id: 'fal-main',
|
||||
type: CopilotProviderType.FAL,
|
||||
config: { apiKey: '2' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const prefixed = resolveModel({
|
||||
registry,
|
||||
modelId: 'openai-main/gpt-4.1',
|
||||
});
|
||||
t.deepEqual(prefixed, {
|
||||
rawModelId: 'openai-main/gpt-4.1',
|
||||
modelId: 'gpt-4.1',
|
||||
explicitProviderId: 'openai-main',
|
||||
candidateProviderIds: ['openai-main'],
|
||||
});
|
||||
|
||||
const slashModel = resolveModel({
|
||||
registry,
|
||||
modelId: 'lora/image-to-image',
|
||||
});
|
||||
t.is(slashModel.modelId, 'lora/image-to-image');
|
||||
t.false(slashModel.candidateProviderIds.includes('lora'));
|
||||
});
|
||||
|
||||
test('resolveModel should follow defaults -> fallback -> order and apply filters', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
priority: 10,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
{
|
||||
id: 'anthropic-main',
|
||||
type: CopilotProviderType.Anthropic,
|
||||
priority: 5,
|
||||
config: { apiKey: '2' },
|
||||
},
|
||||
{
|
||||
id: 'fal-main',
|
||||
type: CopilotProviderType.FAL,
|
||||
priority: 1,
|
||||
config: { apiKey: '3' },
|
||||
},
|
||||
],
|
||||
defaults: {
|
||||
[ModelOutputType.Text]: 'anthropic-main',
|
||||
fallback: 'openai-main',
|
||||
},
|
||||
});
|
||||
|
||||
const routed = resolveModel({
|
||||
registry,
|
||||
outputType: ModelOutputType.Text,
|
||||
preferredProviderIds: ['openai-main', 'fal-main'],
|
||||
});
|
||||
|
||||
t.deepEqual(routed.candidateProviderIds, ['openai-main', 'fal-main']);
|
||||
});
|
||||
|
||||
test('stripProviderPrefix should only strip matched provider prefix', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
t.is(
|
||||
stripProviderPrefix(registry, 'openai-main', 'openai-main/gpt-4.1'),
|
||||
'gpt-4.1'
|
||||
);
|
||||
t.is(
|
||||
stripProviderPrefix(registry, 'openai-main', 'another-main/gpt-4.1'),
|
||||
'another-main/gpt-4.1'
|
||||
);
|
||||
t.is(stripProviderPrefix(registry, 'openai-main', 'gpt-4.1'), 'gpt-4.1');
|
||||
});
|
||||
@@ -0,0 +1,134 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
|
||||
import {
|
||||
ToolCallAccumulator,
|
||||
ToolCallLoop,
|
||||
ToolSchemaExtractor,
|
||||
} from '../../plugins/copilot/providers/loop';
|
||||
|
||||
test('ToolCallAccumulator should merge deltas and complete tool call', t => {
|
||||
const accumulator = new ToolCallAccumulator();
|
||||
|
||||
accumulator.feedDelta({
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments_delta: '{"doc_id":"',
|
||||
});
|
||||
accumulator.feedDelta({
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
arguments_delta: 'a1"}',
|
||||
});
|
||||
|
||||
const completed = accumulator.complete({
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
});
|
||||
|
||||
t.deepEqual(completed, {
|
||||
id: 'call_1',
|
||||
name: 'doc_read',
|
||||
args: { doc_id: 'a1' },
|
||||
thought: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
test('ToolSchemaExtractor should convert zod schema to json schema', t => {
|
||||
const toolSet = {
|
||||
doc_read: {
|
||||
description: 'Read doc',
|
||||
inputSchema: z.object({
|
||||
doc_id: z.string(),
|
||||
limit: z.number().optional(),
|
||||
}),
|
||||
execute: async () => ({}),
|
||||
},
|
||||
};
|
||||
|
||||
const extracted = ToolSchemaExtractor.extract(toolSet);
|
||||
|
||||
t.deepEqual(extracted, [
|
||||
{
|
||||
name: 'doc_read',
|
||||
description: 'Read doc',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
doc_id: { type: 'string' },
|
||||
limit: { type: 'number' },
|
||||
},
|
||||
additionalProperties: false,
|
||||
required: ['doc_id'],
|
||||
},
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test('ToolCallLoop should execute tool call and continue to next round', async t => {
|
||||
const dispatchRequests: NativeLlmRequest[] = [];
|
||||
|
||||
const dispatch = (request: NativeLlmRequest) => {
|
||||
dispatchRequests.push(request);
|
||||
const round = dispatchRequests.length;
|
||||
|
||||
return (async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
if (round === 1) {
|
||||
yield {
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments_delta: '{"doc_id":"a1"}',
|
||||
};
|
||||
yield {
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
};
|
||||
yield { type: 'done', finish_reason: 'tool_calls' };
|
||||
return;
|
||||
}
|
||||
|
||||
yield { type: 'text_delta', text: 'done' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
};
|
||||
|
||||
let executedArgs: Record<string, unknown> | null = null;
|
||||
const loop = new ToolCallLoop(
|
||||
dispatch,
|
||||
{
|
||||
doc_read: {
|
||||
inputSchema: z.object({ doc_id: z.string() }),
|
||||
execute: async args => {
|
||||
executedArgs = args;
|
||||
return { markdown: '# doc' };
|
||||
},
|
||||
},
|
||||
},
|
||||
4
|
||||
);
|
||||
|
||||
const events: NativeLlmStreamEvent[] = [];
|
||||
for await (const event of loop.run({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'read doc' }] }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
t.deepEqual(executedArgs, { doc_id: 'a1' });
|
||||
t.true(
|
||||
dispatchRequests[1]?.messages.some(message => message.role === 'tool')
|
||||
);
|
||||
t.deepEqual(
|
||||
events.map(event => event.type),
|
||||
['tool_call', 'tool_result', 'text_delta', 'done']
|
||||
);
|
||||
});
|
||||
116
packages/backend/server/src/__tests__/copilot/utils.spec.ts
Normal file
116
packages/backend/server/src/__tests__/copilot/utils.spec.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
CitationFootnoteFormatter,
|
||||
CitationParser,
|
||||
StreamPatternParser,
|
||||
} from '../../plugins/copilot/providers/utils';
|
||||
|
||||
test('CitationFootnoteFormatter should format sorted footnotes from citation events', t => {
|
||||
const formatter = new CitationFootnoteFormatter();
|
||||
|
||||
formatter.consume({
|
||||
type: 'citation',
|
||||
index: 2,
|
||||
url: 'https://example.com/b',
|
||||
});
|
||||
formatter.consume({
|
||||
type: 'citation',
|
||||
index: 1,
|
||||
url: 'https://example.com/a',
|
||||
});
|
||||
|
||||
t.is(
|
||||
formatter.end(),
|
||||
[
|
||||
'[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fa"}',
|
||||
'[^2]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fb"}',
|
||||
].join('\n')
|
||||
);
|
||||
});
|
||||
|
||||
test('CitationFootnoteFormatter should overwrite duplicated index with latest url', t => {
|
||||
const formatter = new CitationFootnoteFormatter();
|
||||
|
||||
formatter.consume({
|
||||
type: 'citation',
|
||||
index: 1,
|
||||
url: 'https://example.com/old',
|
||||
});
|
||||
formatter.consume({
|
||||
type: 'citation',
|
||||
index: 1,
|
||||
url: 'https://example.com/new',
|
||||
});
|
||||
|
||||
t.is(
|
||||
formatter.end(),
|
||||
'[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fnew"}'
|
||||
);
|
||||
});
|
||||
|
||||
test('StreamPatternParser should keep state across chunks', t => {
|
||||
const parser = new StreamPatternParser(pattern => {
|
||||
if (pattern.kind === 'wrappedLink') {
|
||||
return `[^${pattern.url}]`;
|
||||
}
|
||||
if (pattern.kind === 'index') {
|
||||
return `[#${pattern.value}]`;
|
||||
}
|
||||
return `[${pattern.text}](${pattern.url})`;
|
||||
});
|
||||
|
||||
const first = parser.write('ref ([AFFiNE](https://affine.pro');
|
||||
const second = parser.write(')) and [2]');
|
||||
|
||||
t.is(first, 'ref ');
|
||||
t.is(second, '[^https://affine.pro] and [#2]');
|
||||
t.is(parser.end(), '');
|
||||
});
|
||||
|
||||
test('CitationParser should convert wrapped links to numbered footnotes', t => {
|
||||
const parser = new CitationParser();
|
||||
|
||||
const output = parser.parse('Use ([AFFiNE](https://affine.pro)) now');
|
||||
t.is(output, 'Use [^1] now');
|
||||
t.regex(
|
||||
parser.end(),
|
||||
/\[\^1\]: \{"type":"url","url":"https%3A%2F%2Faffine.pro"\}/
|
||||
);
|
||||
});
|
||||
|
||||
test('chatToGPTMessage should not mutate input and should keep system schema', async t => {
|
||||
const schema = z.object({
|
||||
query: z.string(),
|
||||
});
|
||||
const messages = [
|
||||
{
|
||||
role: 'system' as const,
|
||||
content: 'You are helper',
|
||||
params: { schema },
|
||||
},
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: '',
|
||||
attachments: ['https://example.com/a.png'],
|
||||
},
|
||||
];
|
||||
const firstRef = messages[0];
|
||||
const secondRef = messages[1];
|
||||
const [system, normalized, parsedSchema] = await chatToGPTMessage(
|
||||
messages,
|
||||
false
|
||||
);
|
||||
|
||||
t.is(system, 'You are helper');
|
||||
t.is(parsedSchema, schema);
|
||||
t.is(messages.length, 2);
|
||||
t.is(messages[0], firstRef);
|
||||
t.is(messages[1], secondRef);
|
||||
t.deepEqual(normalized[0], {
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: '[no content]' }],
|
||||
});
|
||||
});
|
||||
82
packages/backend/server/src/__tests__/native.spec.ts
Normal file
82
packages/backend/server/src/__tests__/native.spec.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
import test from 'ava';
|
||||
|
||||
import { NativeStreamAdapter } from '../native';
|
||||
|
||||
test('NativeStreamAdapter should support buffered and awaited consumption', async t => {
|
||||
const adapter = new NativeStreamAdapter<number>(undefined);
|
||||
|
||||
adapter.push(1);
|
||||
const first = await adapter.next();
|
||||
t.deepEqual(first, { value: 1, done: false });
|
||||
|
||||
const pending = adapter.next();
|
||||
adapter.push(2);
|
||||
const second = await pending;
|
||||
t.deepEqual(second, { value: 2, done: false });
|
||||
|
||||
adapter.push(null);
|
||||
const done = await adapter.next();
|
||||
t.true(done.done);
|
||||
});
|
||||
|
||||
test('NativeStreamAdapter return should abort handle and end iteration', async t => {
|
||||
let abortCount = 0;
|
||||
const adapter = new NativeStreamAdapter<number>({
|
||||
abort: () => {
|
||||
abortCount += 1;
|
||||
},
|
||||
});
|
||||
|
||||
const ended = await adapter.return();
|
||||
t.is(abortCount, 1);
|
||||
t.true(ended.done);
|
||||
|
||||
const secondReturn = await adapter.return();
|
||||
t.true(secondReturn.done);
|
||||
t.is(abortCount, 1);
|
||||
|
||||
const next = await adapter.next();
|
||||
t.true(next.done);
|
||||
});
|
||||
|
||||
test('NativeStreamAdapter should abort when AbortSignal is triggered', async t => {
|
||||
let abortCount = 0;
|
||||
const controller = new AbortController();
|
||||
const adapter = new NativeStreamAdapter<number>(
|
||||
{
|
||||
abort: () => {
|
||||
abortCount += 1;
|
||||
},
|
||||
},
|
||||
controller.signal
|
||||
);
|
||||
|
||||
const pending = adapter.next();
|
||||
controller.abort();
|
||||
const done = await pending;
|
||||
t.true(done.done);
|
||||
t.is(abortCount, 1);
|
||||
});
|
||||
|
||||
test('NativeStreamAdapter should end immediately for pre-aborted signal', async t => {
|
||||
let abortCount = 0;
|
||||
const controller = new AbortController();
|
||||
controller.abort();
|
||||
|
||||
const adapter = new NativeStreamAdapter<number>(
|
||||
{
|
||||
abort: () => {
|
||||
abortCount += 1;
|
||||
},
|
||||
},
|
||||
controller.signal
|
||||
);
|
||||
|
||||
const next = await adapter.next();
|
||||
t.true(next.done);
|
||||
t.is(abortCount, 1);
|
||||
|
||||
adapter.push(1);
|
||||
const stillDone = await adapter.next();
|
||||
t.true(stillDone.done);
|
||||
});
|
||||
@@ -629,14 +629,35 @@ export async function chatWithText(
|
||||
prefix = '',
|
||||
retry?: boolean
|
||||
): Promise<string> {
|
||||
const endpoint = prefix || '/stream';
|
||||
const query = messageId
|
||||
? `?messageId=${messageId}` + (retry ? '&retry=true' : '')
|
||||
: '';
|
||||
const res = await app
|
||||
.GET(`/api/copilot/chat/${sessionId}${prefix}${query}`)
|
||||
.GET(`/api/copilot/chat/${sessionId}${endpoint}${query}`)
|
||||
.expect(200);
|
||||
|
||||
return res.text;
|
||||
if (prefix) {
|
||||
return res.text;
|
||||
}
|
||||
|
||||
const events = sse2array(res.text);
|
||||
const errorEvent = events.find(event => event.event === 'error');
|
||||
if (errorEvent?.data) {
|
||||
let message = errorEvent.data;
|
||||
try {
|
||||
const parsed = JSON.parse(errorEvent.data);
|
||||
message = parsed.message || message;
|
||||
} catch {
|
||||
// noop: keep raw error data
|
||||
}
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
return events
|
||||
.filter(event => event.event === 'message')
|
||||
.map(event => event.data ?? '')
|
||||
.join('');
|
||||
}
|
||||
|
||||
export async function chatWithTextStream(
|
||||
|
||||
@@ -38,8 +38,11 @@ test.before(async t => {
|
||||
t.context.app = app;
|
||||
});
|
||||
|
||||
test.after.always(async t => {
|
||||
test.afterEach.always(() => {
|
||||
Sinon.restore();
|
||||
});
|
||||
|
||||
test.after.always(async t => {
|
||||
__resetDnsLookupForTests();
|
||||
await t.context.app.close();
|
||||
});
|
||||
@@ -80,6 +83,7 @@ const assertAndSnapshotRaw = async (
|
||||
|
||||
test('should proxy image', async t => {
|
||||
const assertAndSnapshot = assertAndSnapshotRaw.bind(null, t);
|
||||
const imageUrl = `http://example.com/image-${Date.now()}.png`;
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/image-proxy',
|
||||
@@ -105,7 +109,7 @@ test('should proxy image', async t => {
|
||||
|
||||
{
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/image-proxy?url=http://example.com/image.png',
|
||||
`/api/worker/image-proxy?url=${imageUrl}`,
|
||||
'should return 400 if origin and referer are missing',
|
||||
{ status: 400, origin: null, referer: null }
|
||||
);
|
||||
@@ -113,14 +117,17 @@ test('should proxy image', async t => {
|
||||
|
||||
{
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/image-proxy?url=http://example.com/image.png',
|
||||
`/api/worker/image-proxy?url=${imageUrl}`,
|
||||
'should return 400 for invalid origin header',
|
||||
{ status: 400, origin: 'http://invalid.com' }
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const fakeBuffer = Buffer.from('fake image');
|
||||
const fakeBuffer = Buffer.from(
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+jfJ8AAAAASUVORK5CYII=',
|
||||
'base64'
|
||||
);
|
||||
const fakeResponse = new Response(fakeBuffer, {
|
||||
status: 200,
|
||||
headers: {
|
||||
@@ -130,13 +137,14 @@ test('should proxy image', async t => {
|
||||
});
|
||||
|
||||
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeResponse);
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/image-proxy?url=http://example.com/image.png',
|
||||
'should return image buffer'
|
||||
);
|
||||
|
||||
fetchSpy.restore();
|
||||
try {
|
||||
await assertAndSnapshot(
|
||||
`/api/worker/image-proxy?url=${imageUrl}`,
|
||||
'should return image buffer'
|
||||
);
|
||||
} finally {
|
||||
fetchSpy.restore();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -200,18 +208,19 @@ test('should preview link', async t => {
|
||||
});
|
||||
|
||||
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeHTML);
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should process a valid external URL and return link preview data',
|
||||
{
|
||||
status: 200,
|
||||
method: 'POST',
|
||||
body: { url: 'http://external.com/page' },
|
||||
}
|
||||
);
|
||||
|
||||
fetchSpy.restore();
|
||||
try {
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should process a valid external URL and return link preview data',
|
||||
{
|
||||
status: 200,
|
||||
method: 'POST',
|
||||
body: { url: 'http://external.com/page' },
|
||||
}
|
||||
);
|
||||
} finally {
|
||||
fetchSpy.restore();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
@@ -251,18 +260,19 @@ test('should preview link', async t => {
|
||||
});
|
||||
|
||||
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeHTML);
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should decode HTML content with charset',
|
||||
{
|
||||
status: 200,
|
||||
method: 'POST',
|
||||
body: { url: `http://example.com/${charset}` },
|
||||
}
|
||||
);
|
||||
|
||||
fetchSpy.restore();
|
||||
try {
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should decode HTML content with charset',
|
||||
{
|
||||
status: 200,
|
||||
method: 'POST',
|
||||
body: { url: `http://example.com/${charset}` },
|
||||
}
|
||||
);
|
||||
} finally {
|
||||
fetchSpy.restore();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -42,8 +42,18 @@ export class Ga4Client {
|
||||
timestamp_micros: event.timestampMicros,
|
||||
})),
|
||||
};
|
||||
|
||||
await this.post(payload);
|
||||
try {
|
||||
await this.post(payload);
|
||||
} catch {
|
||||
if (env.DEPLOYMENT_TYPE === 'affine') {
|
||||
// In production, we want to be resilient to GA4 failures, so we catch and ignore errors.
|
||||
// In non-production environments, we rethrow to surface issues during development and testing.
|
||||
console.info(
|
||||
'Failed to send telemetry event to GA4:',
|
||||
chunk.map(e => e.eventName).join(', ')
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,3 +57,316 @@ export const addDocToRootDoc = serverNativeModule.addDocToRootDoc;
|
||||
export const updateDocTitle = serverNativeModule.updateDocTitle;
|
||||
export const updateDocProperties = serverNativeModule.updateDocProperties;
|
||||
export const updateRootDocMetaTitle = serverNativeModule.updateRootDocMetaTitle;
|
||||
|
||||
type NativeLlmModule = {
|
||||
llmDispatch?: (
|
||||
protocol: string,
|
||||
backendConfigJson: string,
|
||||
requestJson: string
|
||||
) => string | Promise<string>;
|
||||
llmDispatchStream?: (
|
||||
protocol: string,
|
||||
backendConfigJson: string,
|
||||
requestJson: string,
|
||||
callback: (error: Error | null, eventJson: string) => void
|
||||
) => { abort?: () => void } | undefined;
|
||||
};
|
||||
|
||||
const nativeLlmModule = serverNativeModule as typeof serverNativeModule &
|
||||
NativeLlmModule;
|
||||
|
||||
export type NativeLlmProtocol =
|
||||
| 'openai_chat'
|
||||
| 'openai_responses'
|
||||
| 'anthropic';
|
||||
|
||||
export type NativeLlmBackendConfig = {
|
||||
base_url: string;
|
||||
auth_token: string;
|
||||
request_layer?: 'anthropic' | 'chat_completions' | 'responses' | 'vertex';
|
||||
headers?: Record<string, string>;
|
||||
no_streaming?: boolean;
|
||||
timeout_ms?: number;
|
||||
};
|
||||
|
||||
export type NativeLlmCoreRole = 'system' | 'user' | 'assistant' | 'tool';
|
||||
|
||||
export type NativeLlmCoreContent =
|
||||
| { type: 'text'; text: string }
|
||||
| { type: 'reasoning'; text: string; signature?: string }
|
||||
| {
|
||||
type: 'tool_call';
|
||||
call_id: string;
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
thought?: string;
|
||||
}
|
||||
| {
|
||||
type: 'tool_result';
|
||||
call_id: string;
|
||||
output: unknown;
|
||||
is_error?: boolean;
|
||||
name?: string;
|
||||
arguments?: Record<string, unknown>;
|
||||
}
|
||||
| { type: 'image'; source: Record<string, unknown> | string };
|
||||
|
||||
export type NativeLlmCoreMessage = {
|
||||
role: NativeLlmCoreRole;
|
||||
content: NativeLlmCoreContent[];
|
||||
};
|
||||
|
||||
export type NativeLlmToolDefinition = {
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type NativeLlmRequest = {
|
||||
model: string;
|
||||
messages: NativeLlmCoreMessage[];
|
||||
stream?: boolean;
|
||||
max_tokens?: number;
|
||||
temperature?: number;
|
||||
tools?: NativeLlmToolDefinition[];
|
||||
tool_choice?: 'auto' | 'none' | 'required' | { name: string };
|
||||
include?: string[];
|
||||
reasoning?: Record<string, unknown>;
|
||||
middleware?: {
|
||||
request?: Array<
|
||||
'normalize_messages' | 'clamp_max_tokens' | 'tool_schema_rewrite'
|
||||
>;
|
||||
stream?: Array<'stream_event_normalize' | 'citation_indexing'>;
|
||||
config?: {
|
||||
no_additional_properties?: boolean;
|
||||
drop_property_format?: boolean;
|
||||
drop_property_min_length?: boolean;
|
||||
drop_array_min_items?: boolean;
|
||||
drop_array_max_items?: boolean;
|
||||
max_tokens_cap?: number;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export type NativeLlmDispatchResponse = {
|
||||
id: string;
|
||||
model: string;
|
||||
message: NativeLlmCoreMessage;
|
||||
usage: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
finish_reason: string;
|
||||
reasoning_details?: unknown;
|
||||
};
|
||||
|
||||
export type NativeLlmStreamEvent =
|
||||
| { type: 'message_start'; id?: string; model?: string }
|
||||
| { type: 'text_delta'; text: string }
|
||||
| { type: 'reasoning_delta'; text: string }
|
||||
| {
|
||||
type: 'tool_call_delta';
|
||||
call_id: string;
|
||||
name?: string;
|
||||
arguments_delta: string;
|
||||
}
|
||||
| {
|
||||
type: 'tool_call';
|
||||
call_id: string;
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
thought?: string;
|
||||
}
|
||||
| {
|
||||
type: 'tool_result';
|
||||
call_id: string;
|
||||
output: unknown;
|
||||
is_error?: boolean;
|
||||
name?: string;
|
||||
arguments?: Record<string, unknown>;
|
||||
}
|
||||
| { type: 'citation'; index: number; url: string }
|
||||
| {
|
||||
type: 'usage';
|
||||
usage: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: 'done';
|
||||
finish_reason?: string;
|
||||
usage?: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
}
|
||||
| { type: 'error'; message: string; code?: string; raw?: string };
|
||||
const LLM_STREAM_END_MARKER = '__AFFINE_LLM_STREAM_END__';
|
||||
|
||||
export async function llmDispatch(
|
||||
protocol: NativeLlmProtocol,
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
request: NativeLlmRequest
|
||||
): Promise<NativeLlmDispatchResponse> {
|
||||
if (!nativeLlmModule.llmDispatch) {
|
||||
throw new Error('native llm dispatch is not available');
|
||||
}
|
||||
const response = nativeLlmModule.llmDispatch(
|
||||
protocol,
|
||||
JSON.stringify(backendConfig),
|
||||
JSON.stringify(request)
|
||||
);
|
||||
const responseText = await Promise.resolve(response);
|
||||
return JSON.parse(responseText) as NativeLlmDispatchResponse;
|
||||
}
|
||||
|
||||
export class NativeStreamAdapter<T> implements AsyncIterableIterator<T> {
|
||||
readonly #queue: T[] = [];
|
||||
readonly #waiters: ((result: IteratorResult<T>) => void)[] = [];
|
||||
readonly #handle: { abort?: () => void } | undefined;
|
||||
readonly #signal?: AbortSignal;
|
||||
readonly #abortListener?: () => void;
|
||||
#ended = false;
|
||||
|
||||
constructor(
|
||||
handle: { abort?: () => void } | undefined,
|
||||
signal?: AbortSignal
|
||||
) {
|
||||
this.#handle = handle;
|
||||
this.#signal = signal;
|
||||
|
||||
if (signal?.aborted) {
|
||||
this.close(true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (signal) {
|
||||
this.#abortListener = () => {
|
||||
this.close(true);
|
||||
};
|
||||
signal.addEventListener('abort', this.#abortListener, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
private close(abortHandle: boolean) {
|
||||
if (this.#ended) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.#ended = true;
|
||||
if (this.#signal && this.#abortListener) {
|
||||
this.#signal.removeEventListener('abort', this.#abortListener);
|
||||
}
|
||||
if (abortHandle) {
|
||||
this.#handle?.abort?.();
|
||||
}
|
||||
|
||||
while (this.#waiters.length) {
|
||||
const waiter = this.#waiters.shift();
|
||||
waiter?.({ value: undefined as T, done: true });
|
||||
}
|
||||
}
|
||||
|
||||
push(value: T | null) {
|
||||
if (this.#ended) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (value === null) {
|
||||
this.close(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const waiter = this.#waiters.shift();
|
||||
if (waiter) {
|
||||
waiter({ value, done: false });
|
||||
return;
|
||||
}
|
||||
|
||||
this.#queue.push(value);
|
||||
}
|
||||
|
||||
[Symbol.asyncIterator]() {
|
||||
return this;
|
||||
}
|
||||
|
||||
async next(): Promise<IteratorResult<T>> {
|
||||
if (this.#queue.length > 0) {
|
||||
const value = this.#queue.shift() as T;
|
||||
return { value, done: false };
|
||||
}
|
||||
|
||||
if (this.#ended) {
|
||||
return { value: undefined as T, done: true };
|
||||
}
|
||||
|
||||
return await new Promise(resolve => {
|
||||
this.#waiters.push(resolve);
|
||||
});
|
||||
}
|
||||
|
||||
async return(): Promise<IteratorResult<T>> {
|
||||
this.close(true);
|
||||
|
||||
return { value: undefined as T, done: true };
|
||||
}
|
||||
}
|
||||
|
||||
export function llmDispatchStream(
|
||||
protocol: NativeLlmProtocol,
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
if (!nativeLlmModule.llmDispatchStream) {
|
||||
throw new Error('native llm stream dispatch is not available');
|
||||
}
|
||||
|
||||
let adapter: NativeStreamAdapter<NativeLlmStreamEvent> | undefined;
|
||||
const buffer: (NativeLlmStreamEvent | null)[] = [];
|
||||
let pushFn = (event: NativeLlmStreamEvent | null) => {
|
||||
buffer.push(event);
|
||||
};
|
||||
const handle = nativeLlmModule.llmDispatchStream(
|
||||
protocol,
|
||||
JSON.stringify(backendConfig),
|
||||
JSON.stringify(request),
|
||||
(error, eventJson) => {
|
||||
if (error) {
|
||||
pushFn({ type: 'error', message: error.message, raw: eventJson });
|
||||
return;
|
||||
}
|
||||
if (eventJson === LLM_STREAM_END_MARKER) {
|
||||
pushFn(null);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
pushFn(JSON.parse(eventJson) as NativeLlmStreamEvent);
|
||||
} catch (error) {
|
||||
pushFn({
|
||||
type: 'error',
|
||||
message:
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: 'failed to parse native stream event',
|
||||
raw: eventJson,
|
||||
});
|
||||
}
|
||||
}
|
||||
);
|
||||
adapter = new NativeStreamAdapter(handle, signal);
|
||||
pushFn = event => {
|
||||
adapter.push(event);
|
||||
};
|
||||
for (const event of buffer) {
|
||||
adapter.push(event);
|
||||
}
|
||||
return adapter;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
defineModuleConfig,
|
||||
StorageJSONSchema,
|
||||
@@ -13,7 +15,179 @@ import { GeminiGenerativeConfig, GeminiVertexConfig } from './providers/gemini';
|
||||
import { MorphConfig } from './providers/morph';
|
||||
import { OpenAIConfig } from './providers/openai';
|
||||
import { PerplexityConfig } from './providers/perplexity';
|
||||
import { VertexSchema } from './providers/types';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
ModelOutputType,
|
||||
VertexSchema,
|
||||
} from './providers/types';
|
||||
|
||||
export type CopilotProviderConfigMap = {
|
||||
[CopilotProviderType.OpenAI]: OpenAIConfig;
|
||||
[CopilotProviderType.FAL]: FalConfig;
|
||||
[CopilotProviderType.Gemini]: GeminiGenerativeConfig;
|
||||
[CopilotProviderType.GeminiVertex]: GeminiVertexConfig;
|
||||
[CopilotProviderType.Perplexity]: PerplexityConfig;
|
||||
[CopilotProviderType.Anthropic]: AnthropicOfficialConfig;
|
||||
[CopilotProviderType.AnthropicVertex]: AnthropicVertexConfig;
|
||||
[CopilotProviderType.Morph]: MorphConfig;
|
||||
};
|
||||
|
||||
export type ProviderSpecificConfig =
|
||||
CopilotProviderConfigMap[keyof CopilotProviderConfigMap];
|
||||
|
||||
export const RustRequestMiddlewareValues = [
|
||||
'normalize_messages',
|
||||
'clamp_max_tokens',
|
||||
'tool_schema_rewrite',
|
||||
] as const;
|
||||
export type RustRequestMiddleware =
|
||||
(typeof RustRequestMiddlewareValues)[number];
|
||||
|
||||
export const RustStreamMiddlewareValues = [
|
||||
'stream_event_normalize',
|
||||
'citation_indexing',
|
||||
] as const;
|
||||
export type RustStreamMiddleware = (typeof RustStreamMiddlewareValues)[number];
|
||||
|
||||
export const NodeTextMiddlewareValues = [
|
||||
'citation_footnote',
|
||||
'callout',
|
||||
'thinking_format',
|
||||
] as const;
|
||||
export type NodeTextMiddleware = (typeof NodeTextMiddlewareValues)[number];
|
||||
|
||||
export type ProviderMiddlewareConfig = {
|
||||
rust?: { request?: RustRequestMiddleware[]; stream?: RustStreamMiddleware[] };
|
||||
node?: { text?: NodeTextMiddleware[] };
|
||||
};
|
||||
|
||||
type CopilotProviderProfileCommon = {
|
||||
id: string;
|
||||
displayName?: string;
|
||||
priority?: number;
|
||||
enabled?: boolean;
|
||||
models?: string[];
|
||||
middleware?: ProviderMiddlewareConfig;
|
||||
};
|
||||
|
||||
type CopilotProviderProfileVariant<T extends CopilotProviderType> = {
|
||||
type: T;
|
||||
config: CopilotProviderConfigMap[T];
|
||||
};
|
||||
|
||||
export type CopilotProviderProfile = CopilotProviderProfileCommon &
|
||||
{
|
||||
[Type in CopilotProviderType]: CopilotProviderProfileVariant<Type>;
|
||||
}[CopilotProviderType];
|
||||
|
||||
export type CopilotProviderDefaults = Partial<
|
||||
Record<ModelOutputType, string>
|
||||
> & {
|
||||
fallback?: string;
|
||||
};
|
||||
|
||||
const CopilotProviderProfileBaseShape = z.object({
|
||||
id: z.string().regex(/^[a-zA-Z0-9-_]+$/),
|
||||
displayName: z.string().optional(),
|
||||
priority: z.number().optional(),
|
||||
enabled: z.boolean().optional(),
|
||||
models: z.array(z.string()).optional(),
|
||||
middleware: z
|
||||
.object({
|
||||
rust: z
|
||||
.object({
|
||||
request: z.array(z.enum(RustRequestMiddlewareValues)).optional(),
|
||||
stream: z.array(z.enum(RustStreamMiddlewareValues)).optional(),
|
||||
})
|
||||
.optional(),
|
||||
node: z
|
||||
.object({ text: z.array(z.enum(NodeTextMiddlewareValues)).optional() })
|
||||
.optional(),
|
||||
})
|
||||
.optional(),
|
||||
});
|
||||
|
||||
const OpenAIConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
baseURL: z.string().optional(),
|
||||
oldApiStyle: z.boolean().optional(),
|
||||
});
|
||||
|
||||
const FalConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
});
|
||||
|
||||
const GeminiGenerativeConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
baseURL: z.string().optional(),
|
||||
});
|
||||
|
||||
const VertexProviderConfigShape = z.object({
|
||||
location: z.string().optional(),
|
||||
project: z.string().optional(),
|
||||
baseURL: z.string().optional(),
|
||||
googleAuthOptions: z.any().optional(),
|
||||
fetch: z.any().optional(),
|
||||
});
|
||||
|
||||
const PerplexityConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
endpoint: z.string().optional(),
|
||||
});
|
||||
|
||||
const AnthropicOfficialConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
baseURL: z.string().optional(),
|
||||
});
|
||||
|
||||
const MorphConfigShape = z.object({
|
||||
apiKey: z.string().optional(),
|
||||
});
|
||||
|
||||
const CopilotProviderProfileShape = z.discriminatedUnion('type', [
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.OpenAI),
|
||||
config: OpenAIConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.FAL),
|
||||
config: FalConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Gemini),
|
||||
config: GeminiGenerativeConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.GeminiVertex),
|
||||
config: VertexProviderConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Perplexity),
|
||||
config: PerplexityConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Anthropic),
|
||||
config: AnthropicOfficialConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.AnthropicVertex),
|
||||
config: VertexProviderConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Morph),
|
||||
config: MorphConfigShape,
|
||||
}),
|
||||
]);
|
||||
|
||||
const CopilotProviderDefaultsShape = z.object({
|
||||
[ModelOutputType.Text]: z.string().optional(),
|
||||
[ModelOutputType.Object]: z.string().optional(),
|
||||
[ModelOutputType.Embedding]: z.string().optional(),
|
||||
[ModelOutputType.Image]: z.string().optional(),
|
||||
[ModelOutputType.Structured]: z.string().optional(),
|
||||
fallback: z.string().optional(),
|
||||
});
|
||||
|
||||
declare global {
|
||||
interface AppConfigSchema {
|
||||
copilot: {
|
||||
@@ -27,6 +201,8 @@ declare global {
|
||||
storage: ConfigItem<StorageProviderConfig>;
|
||||
scenarios: ConfigItem<CopilotPromptScenario>;
|
||||
providers: {
|
||||
profiles: ConfigItem<CopilotProviderProfile[]>;
|
||||
defaults: ConfigItem<CopilotProviderDefaults>;
|
||||
openai: ConfigItem<OpenAIConfig>;
|
||||
fal: ConfigItem<FalConfig>;
|
||||
gemini: ConfigItem<GeminiGenerativeConfig>;
|
||||
@@ -63,6 +239,16 @@ defineModuleConfig('copilot', {
|
||||
},
|
||||
},
|
||||
},
|
||||
'providers.profiles': {
|
||||
desc: 'The profile list for copilot providers.',
|
||||
default: [],
|
||||
shape: z.array(CopilotProviderProfileShape),
|
||||
},
|
||||
'providers.defaults': {
|
||||
desc: 'The default provider ids for model output types and global fallback.',
|
||||
default: {},
|
||||
shape: CopilotProviderDefaultsShape,
|
||||
},
|
||||
'providers.openai': {
|
||||
desc: 'The config for the openai provider.',
|
||||
default: {
|
||||
|
||||
@@ -36,10 +36,7 @@ import {
|
||||
BlobNotFound,
|
||||
CallMetric,
|
||||
Config,
|
||||
CopilotFailedToGenerateText,
|
||||
CopilotSessionNotFound,
|
||||
InternalServerError,
|
||||
mapAnyError,
|
||||
mapSseError,
|
||||
metrics,
|
||||
NoCopilotProviderAvailable,
|
||||
@@ -242,61 +239,6 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
};
|
||||
}
|
||||
|
||||
@Get('/chat/:sessionId')
|
||||
@CallMetric('ai', 'chat', { timer: true })
|
||||
async chat(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() query: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const info: any = { sessionId, params: query };
|
||||
|
||||
try {
|
||||
const { provider, model, session, finalMessage } =
|
||||
await this.prepareChatSession(
|
||||
user,
|
||||
sessionId,
|
||||
query,
|
||||
ModelOutputType.Text
|
||||
);
|
||||
|
||||
info.model = model;
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
metrics.ai.counter('chat_calls').add(1, { model });
|
||||
|
||||
const { reasoning, webSearch, toolsConfig } =
|
||||
ChatQuerySchema.parse(query);
|
||||
const content = await provider.text({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
signal: getSignal(req).signal,
|
||||
user: user.id,
|
||||
session: session.config.sessionId,
|
||||
workspace: session.config.workspaceId,
|
||||
reasoning,
|
||||
webSearch,
|
||||
tools: getTools(session.config.promptConfig?.tools, toolsConfig),
|
||||
});
|
||||
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
await session.save();
|
||||
|
||||
return content;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_errors').add(1);
|
||||
let error = mapAnyError(e);
|
||||
if (error instanceof InternalServerError) {
|
||||
error = new CopilotFailedToGenerateText(e.message);
|
||||
}
|
||||
error.log('CopilotChat', info);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
@Sse('/chat/:sessionId/stream')
|
||||
@CallMetric('ai', 'chat_stream', { timer: true })
|
||||
async chatStream(
|
||||
|
||||
@@ -3,7 +3,7 @@ import { AiPrompt, PrismaClient } from '@prisma/client';
|
||||
|
||||
import type { PromptConfig, PromptMessage } from '../providers/types';
|
||||
|
||||
type Prompt = Omit<
|
||||
export type Prompt = Omit<
|
||||
AiPrompt,
|
||||
| 'id'
|
||||
| 'createdAt'
|
||||
@@ -2095,17 +2095,14 @@ export const prompts: Prompt[] = [
|
||||
|
||||
export async function refreshPrompts(db: PrismaClient) {
|
||||
const needToSkip = await db.aiPrompt
|
||||
.findMany({
|
||||
where: { modified: true },
|
||||
select: { name: true },
|
||||
})
|
||||
.findMany({ where: { modified: true }, select: { name: true } })
|
||||
.then(p => p.map(p => p.name));
|
||||
|
||||
for (const prompt of prompts) {
|
||||
// skip prompt update if already modified by admin panel
|
||||
if (needToSkip.includes(prompt.name)) {
|
||||
new Logger('CopilotPrompt').warn(`Skip modified prompt: ${prompt.name}`);
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
await db.aiPrompt.upsert({
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
import { ChatPrompt } from './chat-prompt';
|
||||
import {
|
||||
CopilotPromptScenario,
|
||||
type Prompt,
|
||||
prompts,
|
||||
refreshPrompts,
|
||||
Scenario,
|
||||
@@ -21,6 +22,7 @@ import {
|
||||
export class PromptService implements OnApplicationBootstrap {
|
||||
private readonly logger = new Logger(PromptService.name);
|
||||
private readonly cache = new Map<string, ChatPrompt>();
|
||||
private readonly inMemoryPrompts = new Map<string, Prompt>();
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
@@ -28,7 +30,7 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
) {}
|
||||
|
||||
async onApplicationBootstrap() {
|
||||
this.cache.clear();
|
||||
this.resetInMemoryPrompts();
|
||||
await refreshPrompts(this.db);
|
||||
}
|
||||
|
||||
@@ -45,6 +47,7 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
}
|
||||
|
||||
protected async setup(scenarios?: CopilotPromptScenario) {
|
||||
this.ensureInMemoryPrompts();
|
||||
if (!!scenarios && scenarios.override_enabled && scenarios.scenarios) {
|
||||
this.logger.log('Updating prompts based on scenarios...');
|
||||
for (const [scenario, model] of Object.entries(scenarios.scenarios)) {
|
||||
@@ -75,25 +78,29 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
* @returns prompt names
|
||||
*/
|
||||
async listNames() {
|
||||
return this.db.aiPrompt
|
||||
.findMany({ select: { name: true } })
|
||||
.then(prompts => Array.from(new Set(prompts.map(p => p.name))));
|
||||
this.ensureInMemoryPrompts();
|
||||
return Array.from(this.inMemoryPrompts.keys());
|
||||
}
|
||||
|
||||
async list() {
|
||||
return this.db.aiPrompt.findMany({
|
||||
select: {
|
||||
name: true,
|
||||
action: true,
|
||||
model: true,
|
||||
config: true,
|
||||
messages: {
|
||||
select: { role: true, content: true, params: true },
|
||||
orderBy: { idx: 'asc' },
|
||||
},
|
||||
},
|
||||
orderBy: { action: { sort: 'asc', nulls: 'first' } },
|
||||
});
|
||||
this.ensureInMemoryPrompts();
|
||||
return Array.from(this.inMemoryPrompts.values())
|
||||
.map(prompt => ({
|
||||
name: prompt.name,
|
||||
action: prompt.action ?? null,
|
||||
model: prompt.model,
|
||||
config: prompt.config ? structuredClone(prompt.config) : null,
|
||||
messages: prompt.messages.map(message => ({
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
params: message.params ?? null,
|
||||
})),
|
||||
}))
|
||||
.sort((a, b) => {
|
||||
if (a.action === null && b.action !== null) return -1;
|
||||
if (a.action !== null && b.action === null) return 1;
|
||||
return (a.action ?? '').localeCompare(b.action ?? '');
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -102,40 +109,24 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
* @returns prompt messages
|
||||
*/
|
||||
async get(name: string): Promise<ChatPrompt | null> {
|
||||
this.ensureInMemoryPrompts();
|
||||
|
||||
// skip cache in dev mode to ensure the latest prompt is always fetched
|
||||
if (!env.dev) {
|
||||
const cached = this.cache.get(name);
|
||||
if (cached) return cached;
|
||||
}
|
||||
|
||||
const prompt = await this.db.aiPrompt.findUnique({
|
||||
where: {
|
||||
name,
|
||||
},
|
||||
select: {
|
||||
name: true,
|
||||
action: true,
|
||||
model: true,
|
||||
optionalModels: true,
|
||||
config: true,
|
||||
messages: {
|
||||
select: {
|
||||
role: true,
|
||||
content: true,
|
||||
params: true,
|
||||
},
|
||||
orderBy: {
|
||||
idx: 'asc',
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
const prompt = this.inMemoryPrompts.get(name);
|
||||
if (!prompt) return null;
|
||||
|
||||
const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
|
||||
const config = PromptConfigSchema.safeParse(prompt?.config);
|
||||
if (prompt && messages.success && config.success) {
|
||||
const messages = PromptMessageSchema.array().safeParse(prompt.messages);
|
||||
const config = PromptConfigSchema.safeParse(prompt.config);
|
||||
if (messages.success && config.success) {
|
||||
const chatPrompt = ChatPrompt.createFromPrompt({
|
||||
...prompt,
|
||||
...this.clonePrompt(prompt),
|
||||
action: prompt.action ?? null,
|
||||
optionalModels: prompt.optionalModels ?? [],
|
||||
config: config.data,
|
||||
messages: messages.data,
|
||||
});
|
||||
@@ -149,25 +140,69 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
name: string,
|
||||
model: string,
|
||||
messages: PromptMessage[],
|
||||
config?: PromptConfig | null
|
||||
config?: PromptConfig | null,
|
||||
extraConfig?: { optionalModels: string[] }
|
||||
) {
|
||||
return await this.db.aiPrompt
|
||||
.create({
|
||||
data: {
|
||||
name,
|
||||
model,
|
||||
config: config || undefined,
|
||||
messages: {
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
this.ensureInMemoryPrompts();
|
||||
|
||||
const existing = this.inMemoryPrompts.get(name);
|
||||
const mergedOptionalModels = existing?.optionalModels
|
||||
? [...existing.optionalModels, ...(extraConfig?.optionalModels ?? [])]
|
||||
: extraConfig?.optionalModels;
|
||||
const inMemoryConfig = (!!config && structuredClone(config)) || undefined;
|
||||
const dbConfig = this.toDbConfig(config);
|
||||
this.inMemoryPrompts.set(name, {
|
||||
name,
|
||||
model,
|
||||
action: existing?.action,
|
||||
optionalModels: mergedOptionalModels,
|
||||
config: inMemoryConfig,
|
||||
messages: this.cloneMessages(messages),
|
||||
});
|
||||
this.cache.delete(name);
|
||||
|
||||
try {
|
||||
return await this.db.aiPrompt
|
||||
.upsert({
|
||||
where: { name },
|
||||
create: {
|
||||
name,
|
||||
action: existing?.action,
|
||||
model,
|
||||
optionalModels: mergedOptionalModels,
|
||||
config: dbConfig,
|
||||
messages: {
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
.then(ret => ret.id);
|
||||
update: {
|
||||
model,
|
||||
optionalModels: mergedOptionalModels,
|
||||
config: dbConfig,
|
||||
updatedAt: new Date(),
|
||||
messages: {
|
||||
deleteMany: {},
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
},
|
||||
})
|
||||
.then(ret => ret.id);
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Compat prompt upsert failed for "${name}": ${this.stringifyError(error)}`
|
||||
);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
@@ -177,44 +212,123 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
messages?: PromptMessage[];
|
||||
model?: string;
|
||||
modified?: boolean;
|
||||
config?: PromptConfig;
|
||||
config?: PromptConfig | null;
|
||||
},
|
||||
where?: Prisma.AiPromptWhereInput
|
||||
) {
|
||||
this.ensureInMemoryPrompts();
|
||||
const { config, messages, model, modified } = data;
|
||||
const existing = await this.db.aiPrompt
|
||||
.count({ where: { ...where, name } })
|
||||
.then(count => count > 0);
|
||||
if (existing) {
|
||||
await this.db.aiPrompt.update({
|
||||
where: { name },
|
||||
data: {
|
||||
config: config || undefined,
|
||||
updatedAt: new Date(),
|
||||
modified,
|
||||
model,
|
||||
messages: messages
|
||||
? {
|
||||
// cleanup old messages
|
||||
deleteMany: {},
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
}
|
||||
: undefined,
|
||||
},
|
||||
});
|
||||
|
||||
const current = this.inMemoryPrompts.get(name);
|
||||
if (current) {
|
||||
const next = this.clonePrompt(current);
|
||||
if (model !== undefined) {
|
||||
next.model = model;
|
||||
}
|
||||
if (config === null) {
|
||||
next.config = undefined;
|
||||
} else if (config !== undefined) {
|
||||
next.config = structuredClone(config);
|
||||
}
|
||||
if (messages) {
|
||||
next.messages = this.cloneMessages(messages);
|
||||
}
|
||||
|
||||
this.inMemoryPrompts.set(name, next);
|
||||
this.cache.delete(name);
|
||||
}
|
||||
|
||||
try {
|
||||
const existing = await this.db.aiPrompt
|
||||
.count({ where: { ...where, name } })
|
||||
.then(count => count > 0);
|
||||
if (existing) {
|
||||
await this.db.aiPrompt.update({
|
||||
where: { name },
|
||||
data: {
|
||||
config: this.toDbConfig(config),
|
||||
updatedAt: new Date(),
|
||||
modified,
|
||||
model,
|
||||
messages: messages
|
||||
? {
|
||||
// cleanup old messages
|
||||
deleteMany: {},
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
}
|
||||
: undefined,
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Compat prompt update failed for "${name}": ${this.stringifyError(error)}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async delete(name: string) {
|
||||
const { id } = await this.db.aiPrompt.delete({ where: { name } });
|
||||
this.inMemoryPrompts.delete(name);
|
||||
this.cache.delete(name);
|
||||
return id;
|
||||
|
||||
try {
|
||||
const { id } = await this.db.aiPrompt.delete({ where: { name } });
|
||||
return id;
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Compat prompt delete failed for "${name}": ${this.stringifyError(error)}`
|
||||
);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
private resetInMemoryPrompts() {
|
||||
this.cache.clear();
|
||||
this.inMemoryPrompts.clear();
|
||||
for (const prompt of prompts) {
|
||||
this.inMemoryPrompts.set(prompt.name, this.clonePrompt(prompt));
|
||||
}
|
||||
}
|
||||
|
||||
private ensureInMemoryPrompts() {
|
||||
if (!this.inMemoryPrompts.size) {
|
||||
this.resetInMemoryPrompts();
|
||||
}
|
||||
}
|
||||
|
||||
private toDbConfig(
|
||||
config: PromptConfig | null | undefined
|
||||
): Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput | undefined {
|
||||
if (config === null) return Prisma.DbNull;
|
||||
if (config === undefined) return undefined;
|
||||
return config as Prisma.InputJsonValue;
|
||||
}
|
||||
|
||||
private cloneMessages(messages: PromptMessage[]) {
|
||||
return messages.map(message => ({
|
||||
...message,
|
||||
attachments: message.attachments ? [...message.attachments] : undefined,
|
||||
params: message.params ? structuredClone(message.params) : undefined,
|
||||
}));
|
||||
}
|
||||
|
||||
private clonePrompt(prompt: Prompt): Prompt {
|
||||
return {
|
||||
...prompt,
|
||||
optionalModels: prompt.optionalModels
|
||||
? [...prompt.optionalModels]
|
||||
: undefined,
|
||||
config: prompt.config ? structuredClone(prompt.config) : undefined,
|
||||
messages: this.cloneMessages(prompt.messages),
|
||||
};
|
||||
}
|
||||
|
||||
private stringifyError(error: unknown) {
|
||||
return error instanceof Error ? error.message : String(error);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,52 +1,90 @@
|
||||
import {
|
||||
type AnthropicProvider as AnthropicSDKProvider,
|
||||
type AnthropicProviderOptions,
|
||||
} from '@ai-sdk/anthropic';
|
||||
import { type GoogleVertexAnthropicProvider } from '@ai-sdk/google-vertex/anthropic';
|
||||
import { AISDKError, generateText, stepCountIs, streamText } from 'ai';
|
||||
import type { ToolSet } from 'ai';
|
||||
|
||||
import {
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
} from '../../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmRequest,
|
||||
} from '../../../../native';
|
||||
import type { NodeTextMiddleware } from '../../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from '../native';
|
||||
import { CopilotProvider } from '../provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotProviderModel,
|
||||
ModelConditions,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { ModelOutputType } from '../types';
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
StreamObjectParser,
|
||||
TextStreamParser,
|
||||
} from '../utils';
|
||||
import { CopilotProviderType, ModelOutputType } from '../types';
|
||||
import { getGoogleAuth, getVertexAnthropicBaseUrl } from '../utils';
|
||||
|
||||
export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
protected abstract instance:
|
||||
| AnthropicSDKProvider
|
||||
| GoogleVertexAnthropicProvider;
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof AISDKError) {
|
||||
this.logger.error('Throw error from ai sdk:', e);
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.name || 'unknown',
|
||||
message: e.message,
|
||||
});
|
||||
} else {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected anthropic response',
|
||||
});
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected anthropic response',
|
||||
});
|
||||
}
|
||||
|
||||
private async createNativeConfig(): Promise<NativeLlmBackendConfig> {
|
||||
if (this.type === CopilotProviderType.AnthropicVertex) {
|
||||
const auth = await getGoogleAuth(this.config as any, 'anthropic');
|
||||
const headers = auth.headers();
|
||||
const authorization =
|
||||
headers.Authorization ||
|
||||
(headers as Record<string, string | undefined>).authorization;
|
||||
const token =
|
||||
typeof authorization === 'string'
|
||||
? authorization.replace(/^Bearer\s+/i, '')
|
||||
: '';
|
||||
const baseUrl =
|
||||
getVertexAnthropicBaseUrl(this.config as any) || auth.baseUrl;
|
||||
return {
|
||||
base_url: baseUrl || '',
|
||||
auth_token: token,
|
||||
request_layer: 'vertex',
|
||||
headers,
|
||||
};
|
||||
}
|
||||
|
||||
const config = this.config as { apiKey: string; baseURL?: string };
|
||||
const baseUrl = config.baseURL || 'https://api.anthropic.com/v1';
|
||||
return {
|
||||
base_url: baseUrl.replace(/\/v1\/?$/, ''),
|
||||
auth_token: config.apiKey,
|
||||
};
|
||||
}
|
||||
|
||||
private createAdapter(
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
tools: ToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream('anthropic', backendConfig, request, signal),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
}
|
||||
|
||||
private getReasoning(
|
||||
options: NonNullable<CopilotChatOptions>,
|
||||
model: string
|
||||
): Record<string, unknown> | undefined {
|
||||
if (options.reasoning && this.isReasoningModel(model)) {
|
||||
return { budget_tokens: 12000, include_thought: true };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async text(
|
||||
@@ -59,28 +97,29 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages, true, true);
|
||||
|
||||
const modelInstance = this.instance(model.id);
|
||||
const { text, reasoning } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
providerOptions: {
|
||||
anthropic: this.getAnthropicOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const reasoning = this.getReasoning(options, model.id);
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
reasoning,
|
||||
middleware,
|
||||
});
|
||||
|
||||
if (!text) throw new Error('Failed to generate text');
|
||||
|
||||
return reasoning ? `${reasoning}\n${text}` : text;
|
||||
const adapter = this.createAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
return await adapter.text(request, options.signal);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -95,25 +134,32 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
yield result;
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!options.signal?.aborted) {
|
||||
const footnotes = parser.end();
|
||||
if (footnotes.length) {
|
||||
yield `\n\n${footnotes}`;
|
||||
}
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -130,58 +176,34 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new StreamObjectParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
if (result) {
|
||||
yield result;
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
for await (const chunk of adapter.streamObject(request, options.signal)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_errors')
|
||||
.add(1, { model: model.id });
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private async getFullStream(
|
||||
model: CopilotProviderModel,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
) {
|
||||
const [system, msgs] = await chatToGPTMessage(messages, true, true);
|
||||
const { fullStream } = streamText({
|
||||
model: this.instance(model.id),
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
providerOptions: {
|
||||
anthropic: this.getAnthropicOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
});
|
||||
return fullStream;
|
||||
}
|
||||
|
||||
private getAnthropicOptions(options: CopilotChatOptions, model: string) {
|
||||
const result: AnthropicProviderOptions = {};
|
||||
if (options?.reasoning && this.isReasoningModel(model)) {
|
||||
result.thinking = {
|
||||
type: 'enabled',
|
||||
budgetTokens: 12000,
|
||||
};
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private isReasoningModel(model: string) {
|
||||
// claude 3.5 sonnet doesn't support reasoning config
|
||||
return model.includes('sonnet') && !model.startsWith('claude-3-5-sonnet');
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import {
|
||||
type AnthropicProvider as AnthropicSDKProvider,
|
||||
createAnthropic,
|
||||
} from '@ai-sdk/anthropic';
|
||||
import z from 'zod';
|
||||
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
@@ -52,18 +48,12 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
},
|
||||
];
|
||||
|
||||
protected instance!: AnthropicSDKProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
override setup() {
|
||||
super.setup();
|
||||
this.instance = createAnthropic({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
|
||||
@@ -5,7 +5,11 @@ import {
|
||||
} from '@ai-sdk/google-vertex/anthropic';
|
||||
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import { getGoogleAuth, VertexModelListSchema } from '../utils';
|
||||
import {
|
||||
getGoogleAuth,
|
||||
getVertexAnthropicBaseUrl,
|
||||
VertexModelListSchema,
|
||||
} from '../utils';
|
||||
import { AnthropicProvider } from './anthropic';
|
||||
|
||||
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings;
|
||||
@@ -49,7 +53,8 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
protected instance!: GoogleVertexAnthropicProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.location && !!this.config.googleAuthOptions;
|
||||
if (!this.config.location || !this.config.googleAuthOptions) return false;
|
||||
return !!this.config.project || !!getVertexAnthropicBaseUrl(this.config);
|
||||
}
|
||||
|
||||
override setup() {
|
||||
|
||||
@@ -1,16 +1,141 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { Config } from '../../../base';
|
||||
import { ServerFeature, ServerService } from '../../../core';
|
||||
import type { CopilotProvider } from './provider';
|
||||
import {
|
||||
buildProviderRegistry,
|
||||
resolveModel,
|
||||
stripProviderPrefix,
|
||||
} from './provider-registry';
|
||||
import { CopilotProviderType, ModelFullConditions } from './types';
|
||||
|
||||
function isAsyncIterable(value: unknown): value is AsyncIterable<unknown> {
|
||||
return (
|
||||
value !== null &&
|
||||
value !== undefined &&
|
||||
typeof (value as AsyncIterable<unknown>)[Symbol.asyncIterator] ===
|
||||
'function'
|
||||
);
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class CopilotProviderFactory {
|
||||
constructor(private readonly server: ServerService) {}
|
||||
constructor(
|
||||
private readonly server: ServerService,
|
||||
private readonly config: Config
|
||||
) {}
|
||||
|
||||
private readonly logger = new Logger(CopilotProviderFactory.name);
|
||||
|
||||
readonly #providers = new Map<CopilotProviderType, CopilotProvider>();
|
||||
readonly #providers = new Map<string, CopilotProvider>();
|
||||
readonly #boundProviders = new Map<string, CopilotProvider>();
|
||||
readonly #providerIdsByType = new Map<CopilotProviderType, Set<string>>();
|
||||
|
||||
private getRegistry() {
|
||||
return buildProviderRegistry(this.config.copilot.providers);
|
||||
}
|
||||
|
||||
private getPreferredProviderIds(type?: CopilotProviderType) {
|
||||
if (!type) return undefined;
|
||||
return this.#providerIdsByType.get(type);
|
||||
}
|
||||
|
||||
private normalizeCond(
|
||||
providerId: string,
|
||||
cond: ModelFullConditions
|
||||
): ModelFullConditions {
|
||||
const registry = this.getRegistry();
|
||||
const modelId = stripProviderPrefix(registry, providerId, cond.modelId);
|
||||
return { ...cond, modelId };
|
||||
}
|
||||
|
||||
private normalizeMethodArgs(providerId: string, args: unknown[]) {
|
||||
const [first, ...rest] = args;
|
||||
if (
|
||||
!first ||
|
||||
typeof first !== 'object' ||
|
||||
Array.isArray(first) ||
|
||||
!('modelId' in first)
|
||||
) {
|
||||
return args;
|
||||
}
|
||||
|
||||
const cond = first as Record<string, unknown>;
|
||||
if (typeof cond.modelId !== 'string') return args;
|
||||
|
||||
const registry = this.getRegistry();
|
||||
const modelId = stripProviderPrefix(registry, providerId, cond.modelId);
|
||||
return [{ ...cond, modelId }, ...rest];
|
||||
}
|
||||
|
||||
private wrapAsyncIterable<T>(
|
||||
provider: CopilotProvider,
|
||||
providerId: string,
|
||||
iterable: AsyncIterable<T>
|
||||
): AsyncIterableIterator<T> {
|
||||
const iterator = iterable[Symbol.asyncIterator]();
|
||||
|
||||
return {
|
||||
next: value =>
|
||||
provider.runWithProfile(providerId, () => iterator.next(value)),
|
||||
return: value =>
|
||||
provider.runWithProfile(providerId, async () => {
|
||||
if (typeof iterator.return === 'function') {
|
||||
return iterator.return(value as never);
|
||||
}
|
||||
return { done: true, value: value as T };
|
||||
}),
|
||||
throw: error =>
|
||||
provider.runWithProfile(providerId, async () => {
|
||||
if (typeof iterator.throw === 'function') {
|
||||
return iterator.throw(error);
|
||||
}
|
||||
throw error;
|
||||
}),
|
||||
[Symbol.asyncIterator]() {
|
||||
return this;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private getBoundProvider(providerId: string, provider: CopilotProvider) {
|
||||
const cached = this.#boundProviders.get(providerId);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
const wrapped = new Proxy(provider, {
|
||||
get: (target, prop, receiver) => {
|
||||
if (prop === 'providerId') {
|
||||
return providerId;
|
||||
}
|
||||
|
||||
const value = Reflect.get(target, prop, receiver);
|
||||
if (typeof value !== 'function') {
|
||||
return value;
|
||||
}
|
||||
|
||||
return (...args: unknown[]) => {
|
||||
const normalizedArgs = this.normalizeMethodArgs(providerId, args);
|
||||
const result = provider.runWithProfile(providerId, () =>
|
||||
Reflect.apply(value, provider, normalizedArgs)
|
||||
);
|
||||
if (isAsyncIterable(result)) {
|
||||
return this.wrapAsyncIterable(
|
||||
provider,
|
||||
providerId,
|
||||
result as AsyncIterable<unknown>
|
||||
);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
},
|
||||
}) as CopilotProvider;
|
||||
|
||||
this.#boundProviders.set(providerId, wrapped);
|
||||
return wrapped;
|
||||
}
|
||||
|
||||
async getProvider(
|
||||
cond: ModelFullConditions,
|
||||
@@ -21,22 +146,41 @@ export class CopilotProviderFactory {
|
||||
this.logger.debug(
|
||||
`Resolving copilot provider for output type: ${cond.outputType}`
|
||||
);
|
||||
let candidate: CopilotProvider | null = null;
|
||||
for (const [type, provider] of this.#providers.entries()) {
|
||||
if (filter.prefer && filter.prefer !== type) {
|
||||
const route = resolveModel({
|
||||
registry: this.getRegistry(),
|
||||
modelId: cond.modelId,
|
||||
outputType: cond.outputType,
|
||||
availableProviderIds: this.#providers.keys(),
|
||||
preferredProviderIds: this.getPreferredProviderIds(filter.prefer),
|
||||
});
|
||||
|
||||
const registry = this.getRegistry();
|
||||
for (const providerId of route.candidateProviderIds) {
|
||||
const provider = this.#providers.get(providerId);
|
||||
if (!provider) continue;
|
||||
|
||||
const profile = registry.profiles.get(providerId);
|
||||
const normalizedCond = this.normalizeCond(providerId, cond);
|
||||
if (
|
||||
normalizedCond.modelId &&
|
||||
profile?.models?.length &&
|
||||
!profile.models.includes(normalizedCond.modelId)
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const isMatched = await provider.match(cond);
|
||||
const matched = await provider.runWithProfile(providerId, () =>
|
||||
provider.match(normalizedCond)
|
||||
);
|
||||
if (!matched) continue;
|
||||
|
||||
if (isMatched) {
|
||||
candidate = provider;
|
||||
this.logger.debug(`Copilot provider candidate found: ${type}`);
|
||||
break;
|
||||
}
|
||||
this.logger.debug(
|
||||
`Copilot provider candidate found: ${provider.type} (${providerId})`
|
||||
);
|
||||
return this.getBoundProvider(providerId, provider);
|
||||
}
|
||||
|
||||
return candidate;
|
||||
return null;
|
||||
}
|
||||
|
||||
async getProviderByModel(
|
||||
@@ -46,31 +190,50 @@ export class CopilotProviderFactory {
|
||||
} = {}
|
||||
): Promise<CopilotProvider | null> {
|
||||
this.logger.debug(`Resolving copilot provider for model: ${modelId}`);
|
||||
return this.getProvider({ modelId }, filter);
|
||||
}
|
||||
|
||||
let candidate: CopilotProvider | null = null;
|
||||
for (const [type, provider] of this.#providers.entries()) {
|
||||
if (filter.prefer && filter.prefer !== type) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (await provider.match({ modelId })) {
|
||||
candidate = provider;
|
||||
this.logger.debug(`Copilot provider candidate found: ${type}`);
|
||||
register(providerId: string, provider: CopilotProvider) {
|
||||
const existed = this.#providers.get(providerId);
|
||||
if (existed?.type && existed.type !== provider.type) {
|
||||
const ids = this.#providerIdsByType.get(existed.type);
|
||||
ids?.delete(providerId);
|
||||
if (!ids?.size) {
|
||||
this.#providerIdsByType.delete(existed.type);
|
||||
}
|
||||
}
|
||||
|
||||
return candidate;
|
||||
}
|
||||
this.#providers.set(providerId, provider);
|
||||
this.#boundProviders.delete(providerId);
|
||||
|
||||
register(provider: CopilotProvider) {
|
||||
this.#providers.set(provider.type, provider);
|
||||
this.logger.log(`Copilot provider [${provider.type}] registered.`);
|
||||
const ids = this.#providerIdsByType.get(provider.type) ?? new Set<string>();
|
||||
ids.add(providerId);
|
||||
this.#providerIdsByType.set(provider.type, ids);
|
||||
|
||||
this.logger.log(
|
||||
`Copilot provider [${provider.type}] registered as [${providerId}].`
|
||||
);
|
||||
this.server.enableFeature(ServerFeature.Copilot);
|
||||
}
|
||||
|
||||
unregister(provider: CopilotProvider) {
|
||||
this.#providers.delete(provider.type);
|
||||
this.logger.log(`Copilot provider [${provider.type}] unregistered.`);
|
||||
unregister(providerId: string, provider: CopilotProvider) {
|
||||
const existed = this.#providers.get(providerId);
|
||||
if (!existed || existed !== provider) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.#providers.delete(providerId);
|
||||
this.#boundProviders.delete(providerId);
|
||||
|
||||
const ids = this.#providerIdsByType.get(provider.type);
|
||||
ids?.delete(providerId);
|
||||
if (!ids?.size) {
|
||||
this.#providerIdsByType.delete(provider.type);
|
||||
}
|
||||
|
||||
this.logger.log(
|
||||
`Copilot provider [${provider.type}] unregistered from [${providerId}].`
|
||||
);
|
||||
if (this.#providers.size === 0) {
|
||||
this.server.disableFeature(ServerFeature.Copilot);
|
||||
}
|
||||
|
||||
381
packages/backend/server/src/plugins/copilot/providers/loop.ts
Normal file
381
packages/backend/server/src/plugins/copilot/providers/loop.ts
Normal file
@@ -0,0 +1,381 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type {
|
||||
NativeLlmRequest,
|
||||
NativeLlmStreamEvent,
|
||||
NativeLlmToolDefinition,
|
||||
} from '../../../native';
|
||||
|
||||
export type NativeDispatchFn = (
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
) => AsyncIterableIterator<NativeLlmStreamEvent>;
|
||||
|
||||
export type NativeToolCall = {
|
||||
id: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
thought?: string;
|
||||
};
|
||||
|
||||
type ToolCallState = {
|
||||
name?: string;
|
||||
argumentsText: string;
|
||||
};
|
||||
|
||||
type ToolExecutionResult = {
|
||||
callId: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
output: unknown;
|
||||
isError?: boolean;
|
||||
};
|
||||
|
||||
export class ToolCallAccumulator {
|
||||
readonly #states = new Map<string, ToolCallState>();
|
||||
|
||||
feedDelta(event: Extract<NativeLlmStreamEvent, { type: 'tool_call_delta' }>) {
|
||||
const state = this.#states.get(event.call_id) ?? {
|
||||
argumentsText: '',
|
||||
};
|
||||
if (event.name) {
|
||||
state.name = event.name;
|
||||
}
|
||||
if (event.arguments_delta) {
|
||||
state.argumentsText += event.arguments_delta;
|
||||
}
|
||||
this.#states.set(event.call_id, state);
|
||||
}
|
||||
|
||||
complete(event: Extract<NativeLlmStreamEvent, { type: 'tool_call' }>) {
|
||||
const state = this.#states.get(event.call_id);
|
||||
this.#states.delete(event.call_id);
|
||||
return {
|
||||
id: event.call_id,
|
||||
name: event.name || state?.name || '',
|
||||
args: this.parseArgs(
|
||||
event.arguments ?? this.parseJson(state?.argumentsText ?? '{}')
|
||||
),
|
||||
thought: event.thought,
|
||||
} satisfies NativeToolCall;
|
||||
}
|
||||
|
||||
drainPending() {
|
||||
const pending: NativeToolCall[] = [];
|
||||
for (const [callId, state] of this.#states.entries()) {
|
||||
if (!state.name) {
|
||||
continue;
|
||||
}
|
||||
pending.push({
|
||||
id: callId,
|
||||
name: state.name,
|
||||
args: this.parseArgs(this.parseJson(state.argumentsText)),
|
||||
});
|
||||
}
|
||||
this.#states.clear();
|
||||
return pending;
|
||||
}
|
||||
|
||||
private parseJson(jsonText: string): unknown {
|
||||
if (!jsonText.trim()) {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
return JSON.parse(jsonText);
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
private parseArgs(value: unknown): Record<string, unknown> {
|
||||
if (value && typeof value === 'object' && !Array.isArray(value)) {
|
||||
return value as Record<string, unknown>;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
export class ToolSchemaExtractor {
|
||||
static extract(toolSet: ToolSet): NativeLlmToolDefinition[] {
|
||||
return Object.entries(toolSet).map(([name, tool]) => {
|
||||
const unknownTool = tool as Record<string, unknown>;
|
||||
const inputSchema =
|
||||
unknownTool.inputSchema ?? unknownTool.parameters ?? z.object({});
|
||||
|
||||
return {
|
||||
name,
|
||||
description:
|
||||
typeof unknownTool.description === 'string'
|
||||
? unknownTool.description
|
||||
: undefined,
|
||||
parameters: this.toJsonSchema(inputSchema),
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
private static toJsonSchema(schema: unknown): Record<string, unknown> {
|
||||
if (!(schema instanceof z.ZodType)) {
|
||||
if (schema && typeof schema === 'object' && !Array.isArray(schema)) {
|
||||
return schema as Record<string, unknown>;
|
||||
}
|
||||
return { type: 'object', properties: {} };
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodObject) {
|
||||
const shape = schema.shape;
|
||||
const properties: Record<string, unknown> = {};
|
||||
const required: string[] = [];
|
||||
|
||||
for (const [key, child] of Object.entries(
|
||||
shape as Record<string, z.ZodTypeAny>
|
||||
)) {
|
||||
properties[key] = this.toJsonSchema(child);
|
||||
if (!this.isOptional(child)) {
|
||||
required.push(key);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'object',
|
||||
properties,
|
||||
additionalProperties: false,
|
||||
...(required.length ? { required } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodString) {
|
||||
return { type: 'string' };
|
||||
}
|
||||
if (schema instanceof z.ZodNumber) {
|
||||
return { type: 'number' };
|
||||
}
|
||||
if (schema instanceof z.ZodBoolean) {
|
||||
return { type: 'boolean' };
|
||||
}
|
||||
if (schema instanceof z.ZodArray) {
|
||||
return { type: 'array', items: this.toJsonSchema(schema.element) };
|
||||
}
|
||||
if (schema instanceof z.ZodEnum) {
|
||||
return { type: 'string', enum: schema.options };
|
||||
}
|
||||
if (schema instanceof z.ZodLiteral) {
|
||||
const literal = schema.value;
|
||||
if (literal === null) {
|
||||
return { const: null, type: 'null' };
|
||||
}
|
||||
if (typeof literal === 'string') {
|
||||
return { const: literal, type: 'string' };
|
||||
}
|
||||
if (typeof literal === 'number') {
|
||||
return { const: literal, type: 'number' };
|
||||
}
|
||||
if (typeof literal === 'boolean') {
|
||||
return { const: literal, type: 'boolean' };
|
||||
}
|
||||
return { const: literal };
|
||||
}
|
||||
if (schema instanceof z.ZodUnion) {
|
||||
return {
|
||||
anyOf: schema.options.map((option: z.ZodTypeAny) =>
|
||||
this.toJsonSchema(option)
|
||||
),
|
||||
};
|
||||
}
|
||||
if (schema instanceof z.ZodRecord) {
|
||||
return {
|
||||
type: 'object',
|
||||
additionalProperties: this.toJsonSchema(schema.valueSchema),
|
||||
};
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodNullable) {
|
||||
const inner = (schema._def as { innerType?: z.ZodTypeAny }).innerType;
|
||||
return { anyOf: [this.toJsonSchema(inner), { type: 'null' }] };
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) {
|
||||
return this.toJsonSchema(
|
||||
(schema._def as { innerType?: z.ZodTypeAny }).innerType
|
||||
);
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodEffects) {
|
||||
return this.toJsonSchema(
|
||||
(schema._def as { schema?: z.ZodTypeAny }).schema
|
||||
);
|
||||
}
|
||||
|
||||
return { type: 'object', properties: {} };
|
||||
}
|
||||
|
||||
private static isOptional(schema: z.ZodTypeAny): boolean {
|
||||
if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) {
|
||||
return true;
|
||||
}
|
||||
if (schema instanceof z.ZodNullable) {
|
||||
return this.isOptional(
|
||||
(schema._def as { innerType: z.ZodTypeAny }).innerType
|
||||
);
|
||||
}
|
||||
if (schema instanceof z.ZodEffects) {
|
||||
return this.isOptional((schema._def as { schema: z.ZodTypeAny }).schema);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export class ToolCallLoop {
|
||||
constructor(
|
||||
private readonly dispatch: NativeDispatchFn,
|
||||
private readonly tools: ToolSet,
|
||||
private readonly maxSteps = 20
|
||||
) {}
|
||||
|
||||
async *run(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
const messages = request.messages.map(message => ({
|
||||
...message,
|
||||
content: [...message.content],
|
||||
}));
|
||||
|
||||
for (let step = 0; step < this.maxSteps; step++) {
|
||||
const toolCalls: NativeToolCall[] = [];
|
||||
const accumulator = new ToolCallAccumulator();
|
||||
let finalDone: Extract<NativeLlmStreamEvent, { type: 'done' }> | null =
|
||||
null;
|
||||
|
||||
for await (const event of this.dispatch(
|
||||
{
|
||||
...request,
|
||||
stream: true,
|
||||
messages,
|
||||
},
|
||||
signal
|
||||
)) {
|
||||
switch (event.type) {
|
||||
case 'tool_call_delta': {
|
||||
accumulator.feedDelta(event);
|
||||
break;
|
||||
}
|
||||
case 'tool_call': {
|
||||
toolCalls.push(accumulator.complete(event));
|
||||
yield event;
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
finalDone = event;
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
throw new Error(event.message);
|
||||
}
|
||||
default: {
|
||||
yield event;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push(...accumulator.drainPending());
|
||||
if (toolCalls.length === 0) {
|
||||
if (finalDone) {
|
||||
yield finalDone;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (step === this.maxSteps - 1) {
|
||||
throw new Error('ToolCallLoop max steps reached');
|
||||
}
|
||||
|
||||
const toolResults = await this.executeTools(toolCalls);
|
||||
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: toolCalls.map(call => ({
|
||||
type: 'tool_call',
|
||||
call_id: call.id,
|
||||
name: call.name,
|
||||
arguments: call.args,
|
||||
thought: call.thought,
|
||||
})),
|
||||
});
|
||||
|
||||
for (const result of toolResults) {
|
||||
messages.push({
|
||||
role: 'tool',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
call_id: result.callId,
|
||||
output: result.output,
|
||||
is_error: result.isError,
|
||||
},
|
||||
],
|
||||
});
|
||||
yield {
|
||||
type: 'tool_result',
|
||||
call_id: result.callId,
|
||||
name: result.name,
|
||||
arguments: result.args,
|
||||
output: result.output,
|
||||
is_error: result.isError,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async executeTools(calls: NativeToolCall[]) {
|
||||
return await Promise.all(calls.map(call => this.executeTool(call)));
|
||||
}
|
||||
|
||||
private async executeTool(
|
||||
call: NativeToolCall
|
||||
): Promise<ToolExecutionResult> {
|
||||
const tool = this.tools[call.name] as
|
||||
| {
|
||||
execute?: (args: Record<string, unknown>) => Promise<unknown>;
|
||||
}
|
||||
| undefined;
|
||||
|
||||
if (!tool?.execute) {
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
isError: true,
|
||||
output: {
|
||||
message: `Tool not found: ${call.name}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const output = await tool.execute(call.args);
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
output: output ?? null,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Tool execution failed', {
|
||||
callId: call.id,
|
||||
toolName: call.name,
|
||||
error,
|
||||
});
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
isError: true,
|
||||
output: {
|
||||
message: 'Tool execution failed',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,17 @@
|
||||
import {
|
||||
createOpenAICompatible,
|
||||
OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
|
||||
} from '@ai-sdk/openai-compatible';
|
||||
import { AISDKError, generateText, streamText } from 'ai';
|
||||
import type { ToolSet } from 'ai';
|
||||
|
||||
import {
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
@@ -16,7 +19,6 @@ import type {
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
|
||||
import { chatToGPTMessage, TextStreamParser } from './utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
@@ -57,37 +59,48 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
},
|
||||
];
|
||||
|
||||
#instance!: VercelOpenAICompatibleProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance = createOpenAICompatible({
|
||||
name: this.type,
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: 'https://api.morphllm.com/v1',
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof AISDKError) {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.name || 'unknown',
|
||||
message: e.message,
|
||||
});
|
||||
} else {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected morph response',
|
||||
});
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected morph response',
|
||||
});
|
||||
}
|
||||
|
||||
private createNativeConfig(): NativeLlmBackendConfig {
|
||||
return {
|
||||
base_url: 'https://api.morphllm.com',
|
||||
auth_token: this.config.apiKey ?? '',
|
||||
};
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream(
|
||||
'openai_chat',
|
||||
this.createNativeConfig(),
|
||||
request,
|
||||
signal
|
||||
),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
}
|
||||
|
||||
async text(
|
||||
@@ -103,22 +116,22 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance(model.id);
|
||||
|
||||
const { text } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
middleware,
|
||||
});
|
||||
|
||||
return text.trim();
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -136,38 +149,26 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance(model.id);
|
||||
|
||||
const { fullStream } = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
middleware,
|
||||
});
|
||||
|
||||
const textParser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
switch (chunk.type) {
|
||||
case 'text-delta': {
|
||||
let result = textParser.parse(chunk);
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
yield textParser.parse(chunk);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
464
packages/backend/server/src/plugins/copilot/providers/native.ts
Normal file
464
packages/backend/server/src/plugins/copilot/providers/native.ts
Normal file
@@ -0,0 +1,464 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import { ZodType } from 'zod';
|
||||
|
||||
import type {
|
||||
NativeLlmCoreContent,
|
||||
NativeLlmCoreMessage,
|
||||
NativeLlmRequest,
|
||||
NativeLlmStreamEvent,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware, ProviderMiddlewareConfig } from '../config';
|
||||
import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop';
|
||||
import type { CopilotChatOptions, PromptMessage, StreamObject } from './types';
|
||||
import {
|
||||
CitationFootnoteFormatter,
|
||||
inferMimeType,
|
||||
TextStreamParser,
|
||||
} from './utils';
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
|
||||
type BuildNativeRequestOptions = {
|
||||
model: string;
|
||||
messages: PromptMessage[];
|
||||
options?: CopilotChatOptions;
|
||||
tools?: ToolSet;
|
||||
withAttachment?: boolean;
|
||||
include?: string[];
|
||||
reasoning?: Record<string, unknown>;
|
||||
middleware?: ProviderMiddlewareConfig;
|
||||
};
|
||||
|
||||
type BuildNativeRequestResult = {
|
||||
request: NativeLlmRequest;
|
||||
schema?: ZodType;
|
||||
};
|
||||
|
||||
type ToolCallMeta = {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
|
||||
type NormalizedToolResultEvent = Extract<
|
||||
NativeLlmStreamEvent,
|
||||
{ type: 'tool_result' }
|
||||
> & {
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
};
|
||||
|
||||
type AttachmentFootnote = {
|
||||
blobId: string;
|
||||
fileName: string;
|
||||
fileType: string;
|
||||
};
|
||||
|
||||
type NativeProviderAdapterOptions = {
|
||||
nodeTextMiddleware?: NodeTextMiddleware[];
|
||||
};
|
||||
|
||||
function roleToCore(role: PromptMessage['role']) {
|
||||
switch (role) {
|
||||
case 'assistant':
|
||||
return 'assistant';
|
||||
case 'system':
|
||||
return 'system';
|
||||
default:
|
||||
return 'user';
|
||||
}
|
||||
}
|
||||
|
||||
async function toCoreContents(
|
||||
message: PromptMessage,
|
||||
withAttachment: boolean
|
||||
): Promise<NativeLlmCoreContent[]> {
|
||||
const contents: NativeLlmCoreContent[] = [];
|
||||
|
||||
if (typeof message.content === 'string' && message.content.length) {
|
||||
contents.push({ type: 'text', text: message.content });
|
||||
}
|
||||
|
||||
if (!withAttachment || !Array.isArray(message.attachments)) return contents;
|
||||
|
||||
for (const entry of message.attachments) {
|
||||
let attachmentUrl: string;
|
||||
let mediaType: string;
|
||||
|
||||
if (typeof entry === 'string') {
|
||||
attachmentUrl = entry;
|
||||
mediaType =
|
||||
typeof message.params?.mimetype === 'string'
|
||||
? message.params.mimetype
|
||||
: await inferMimeType(entry);
|
||||
} else {
|
||||
attachmentUrl = entry.attachment;
|
||||
mediaType = entry.mimeType;
|
||||
}
|
||||
|
||||
if (!SIMPLE_IMAGE_URL_REGEX.test(attachmentUrl)) continue;
|
||||
if (!mediaType.startsWith('image/')) continue;
|
||||
|
||||
contents.push({ type: 'image', source: { url: attachmentUrl } });
|
||||
}
|
||||
|
||||
return contents;
|
||||
}
|
||||
|
||||
export async function buildNativeRequest({
|
||||
model,
|
||||
messages,
|
||||
options = {},
|
||||
tools = {},
|
||||
withAttachment = true,
|
||||
include,
|
||||
reasoning,
|
||||
middleware,
|
||||
}: BuildNativeRequestOptions): Promise<BuildNativeRequestResult> {
|
||||
const copiedMessages = messages.map(message => ({
|
||||
...message,
|
||||
attachments: message.attachments
|
||||
? [...message.attachments]
|
||||
: message.attachments,
|
||||
}));
|
||||
|
||||
const systemMessage =
|
||||
copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined;
|
||||
const schema =
|
||||
systemMessage?.params?.schema instanceof ZodType
|
||||
? systemMessage.params.schema
|
||||
: undefined;
|
||||
|
||||
const coreMessages: NativeLlmCoreMessage[] = [];
|
||||
if (systemMessage?.content?.length) {
|
||||
coreMessages.push({
|
||||
role: 'system',
|
||||
content: [{ type: 'text', text: systemMessage.content }],
|
||||
});
|
||||
}
|
||||
|
||||
for (const message of copiedMessages) {
|
||||
if (message.role === 'system') continue;
|
||||
const content = await toCoreContents(message, withAttachment);
|
||||
coreMessages.push({ role: roleToCore(message.role), content });
|
||||
}
|
||||
|
||||
return {
|
||||
request: {
|
||||
model,
|
||||
stream: true,
|
||||
messages: coreMessages,
|
||||
max_tokens: options.maxTokens ?? undefined,
|
||||
temperature: options.temperature ?? undefined,
|
||||
tools: ToolSchemaExtractor.extract(tools),
|
||||
tool_choice: Object.keys(tools).length ? 'auto' : undefined,
|
||||
include,
|
||||
reasoning,
|
||||
middleware: middleware?.rust
|
||||
? { request: middleware.rust.request, stream: middleware.rust.stream }
|
||||
: undefined,
|
||||
},
|
||||
schema,
|
||||
};
|
||||
}
|
||||
|
||||
function ensureToolResultMeta(
|
||||
event: Extract<NativeLlmStreamEvent, { type: 'tool_result' }>,
|
||||
toolCalls: Map<string, ToolCallMeta>
|
||||
): NormalizedToolResultEvent | null {
|
||||
const name = event.name ?? toolCalls.get(event.call_id)?.name;
|
||||
const args = event.arguments ?? toolCalls.get(event.call_id)?.args;
|
||||
|
||||
if (!name || !args) return null;
|
||||
return { ...event, name, arguments: args };
|
||||
}
|
||||
|
||||
function pickAttachmentFootnote(value: unknown): AttachmentFootnote | null {
|
||||
if (!value || typeof value !== 'object') {
|
||||
return null;
|
||||
}
|
||||
|
||||
const record = value as Record<string, unknown>;
|
||||
const blobId =
|
||||
typeof record.blobId === 'string'
|
||||
? record.blobId
|
||||
: typeof record.blob_id === 'string'
|
||||
? record.blob_id
|
||||
: undefined;
|
||||
const fileName =
|
||||
typeof record.fileName === 'string'
|
||||
? record.fileName
|
||||
: typeof record.name === 'string'
|
||||
? record.name
|
||||
: undefined;
|
||||
const fileType =
|
||||
typeof record.fileType === 'string'
|
||||
? record.fileType
|
||||
: typeof record.mimeType === 'string'
|
||||
? record.mimeType
|
||||
: 'application/octet-stream';
|
||||
|
||||
if (!blobId || !fileName) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return { blobId, fileName, fileType };
|
||||
}
|
||||
|
||||
function collectAttachmentFootnotes(
|
||||
event: NormalizedToolResultEvent
|
||||
): AttachmentFootnote[] {
|
||||
if (event.name === 'blob_read') {
|
||||
const item = pickAttachmentFootnote(event.output);
|
||||
return item ? [item] : [];
|
||||
}
|
||||
|
||||
if (event.name === 'doc_semantic_search' && Array.isArray(event.output)) {
|
||||
return event.output
|
||||
.map(item => pickAttachmentFootnote(item))
|
||||
.filter((item): item is AttachmentFootnote => item !== null);
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
function formatAttachmentFootnotes(attachments: AttachmentFootnote[]) {
|
||||
const references = attachments.map((_, index) => `[^${index + 1}]`).join('');
|
||||
const definitions = attachments
|
||||
.map((attachment, index) => {
|
||||
return `[^${index + 1}]: ${JSON.stringify({
|
||||
type: 'attachment',
|
||||
blobId: attachment.blobId,
|
||||
fileName: attachment.fileName,
|
||||
fileType: attachment.fileType,
|
||||
})}`;
|
||||
})
|
||||
.join('\n');
|
||||
|
||||
return `\n\n${references}\n\n${definitions}`;
|
||||
}
|
||||
|
||||
export class NativeProviderAdapter {
|
||||
readonly #loop: ToolCallLoop;
|
||||
readonly #enableCallout: boolean;
|
||||
readonly #enableCitationFootnote: boolean;
|
||||
|
||||
constructor(
|
||||
dispatch: NativeDispatchFn,
|
||||
tools: ToolSet,
|
||||
maxSteps = 20,
|
||||
options: NativeProviderAdapterOptions = {}
|
||||
) {
|
||||
this.#loop = new ToolCallLoop(dispatch, tools, maxSteps);
|
||||
const enabledNodeTextMiddlewares = new Set(
|
||||
options.nodeTextMiddleware ?? ['citation_footnote', 'callout']
|
||||
);
|
||||
this.#enableCallout =
|
||||
enabledNodeTextMiddlewares.has('callout') ||
|
||||
enabledNodeTextMiddlewares.has('thinking_format');
|
||||
this.#enableCitationFootnote =
|
||||
enabledNodeTextMiddlewares.has('citation_footnote');
|
||||
}
|
||||
|
||||
async text(request: NativeLlmRequest, signal?: AbortSignal) {
|
||||
let output = '';
|
||||
for await (const chunk of this.streamText(request, signal)) {
|
||||
output += chunk;
|
||||
}
|
||||
return output.trim();
|
||||
}
|
||||
|
||||
async *streamText(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
): AsyncIterableIterator<string> {
|
||||
const textParser = this.#enableCallout ? new TextStreamParser() : null;
|
||||
const citationFormatter = this.#enableCitationFootnote
|
||||
? new CitationFootnoteFormatter()
|
||||
: null;
|
||||
const toolCalls = new Map<string, ToolCallMeta>();
|
||||
let streamPartId = 0;
|
||||
|
||||
for await (const event of this.#loop.run(request, signal)) {
|
||||
switch (event.type) {
|
||||
case 'text_delta': {
|
||||
if (textParser) {
|
||||
yield textParser.parse({
|
||||
type: 'text-delta',
|
||||
id: String(streamPartId++),
|
||||
text: event.text,
|
||||
});
|
||||
} else {
|
||||
yield event.text;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'reasoning_delta': {
|
||||
if (textParser) {
|
||||
yield textParser.parse({
|
||||
type: 'reasoning-delta',
|
||||
id: String(streamPartId++),
|
||||
text: event.text,
|
||||
});
|
||||
} else {
|
||||
yield event.text;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'tool_call': {
|
||||
const toolCall = {
|
||||
name: event.name,
|
||||
args: event.arguments,
|
||||
};
|
||||
toolCalls.set(event.call_id, toolCall);
|
||||
if (textParser) {
|
||||
yield textParser.parse({
|
||||
type: 'tool-call',
|
||||
toolCallId: event.call_id,
|
||||
toolName: event.name as never,
|
||||
input: event.arguments,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'tool_result': {
|
||||
const normalized = ensureToolResultMeta(event, toolCalls);
|
||||
if (!normalized || !textParser) {
|
||||
break;
|
||||
}
|
||||
yield textParser.parse({
|
||||
type: 'tool-result',
|
||||
toolCallId: normalized.call_id,
|
||||
toolName: normalized.name as never,
|
||||
input: normalized.arguments,
|
||||
output: normalized.output,
|
||||
});
|
||||
break;
|
||||
}
|
||||
case 'citation': {
|
||||
if (citationFormatter) {
|
||||
citationFormatter.consume({
|
||||
type: 'citation',
|
||||
index: event.index,
|
||||
url: event.url,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
const footnotes = textParser?.end() ?? '';
|
||||
const citations = citationFormatter?.end() ?? '';
|
||||
const tails = [citations, footnotes].filter(Boolean).join('\n');
|
||||
if (tails) {
|
||||
yield `\n${tails}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
throw new Error(event.message);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async *streamObject(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
): AsyncIterableIterator<StreamObject> {
|
||||
const toolCalls = new Map<string, ToolCallMeta>();
|
||||
const citationFormatter = this.#enableCitationFootnote
|
||||
? new CitationFootnoteFormatter()
|
||||
: null;
|
||||
const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>();
|
||||
let hasFootnoteReference = false;
|
||||
|
||||
for await (const event of this.#loop.run(request, signal)) {
|
||||
switch (event.type) {
|
||||
case 'text_delta': {
|
||||
if (event.text.includes('[^')) {
|
||||
hasFootnoteReference = true;
|
||||
}
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: event.text,
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 'reasoning_delta': {
|
||||
yield {
|
||||
type: 'reasoning',
|
||||
textDelta: event.text,
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 'tool_call': {
|
||||
const toolCall = {
|
||||
name: event.name,
|
||||
args: event.arguments,
|
||||
};
|
||||
toolCalls.set(event.call_id, toolCall);
|
||||
yield {
|
||||
type: 'tool-call',
|
||||
toolCallId: event.call_id,
|
||||
toolName: event.name,
|
||||
args: event.arguments,
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 'tool_result': {
|
||||
const normalized = ensureToolResultMeta(event, toolCalls);
|
||||
if (!normalized) {
|
||||
break;
|
||||
}
|
||||
const attachments = collectAttachmentFootnotes(normalized);
|
||||
attachments.forEach(attachment => {
|
||||
fallbackAttachmentFootnotes.set(attachment.blobId, attachment);
|
||||
});
|
||||
yield {
|
||||
type: 'tool-result',
|
||||
toolCallId: normalized.call_id,
|
||||
toolName: normalized.name,
|
||||
args: normalized.arguments,
|
||||
result: normalized.output,
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 'citation': {
|
||||
if (citationFormatter) {
|
||||
citationFormatter.consume({
|
||||
type: 'citation',
|
||||
index: event.index,
|
||||
url: event.url,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
const citations = citationFormatter?.end() ?? '';
|
||||
if (citations) {
|
||||
hasFootnoteReference = true;
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: `\n${citations}`,
|
||||
};
|
||||
}
|
||||
if (!hasFootnoteReference && fallbackAttachmentFootnotes.size > 0) {
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: formatAttachmentFootnotes(
|
||||
Array.from(fallbackAttachmentFootnotes.values())
|
||||
),
|
||||
};
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
throw new Error(event.message);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,53 +1,35 @@
|
||||
import {
|
||||
createOpenAI,
|
||||
openai,
|
||||
type OpenAIProvider as VercelOpenAIProvider,
|
||||
OpenAIResponsesProviderOptions,
|
||||
} from '@ai-sdk/openai';
|
||||
import {
|
||||
createOpenAICompatible,
|
||||
type OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
|
||||
} from '@ai-sdk/openai-compatible';
|
||||
import {
|
||||
AISDKError,
|
||||
embedMany,
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
stepCountIs,
|
||||
streamText,
|
||||
Tool,
|
||||
} from 'ai';
|
||||
import type { Tool, ToolSet } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderNotSupported,
|
||||
CopilotProviderSideError,
|
||||
fetchBuffer,
|
||||
metrics,
|
||||
OneMB,
|
||||
readResponseBufferWithLimit,
|
||||
safeFetch,
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotChatTools,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotProviderModel,
|
||||
CopilotStructuredOptions,
|
||||
ModelConditions,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from './types';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
CitationParser,
|
||||
StreamObjectParser,
|
||||
TextStreamParser,
|
||||
} from './utils';
|
||||
import { chatToGPTMessage } from './utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
@@ -63,7 +45,12 @@ const ModelListSchema = z.object({
|
||||
|
||||
const ImageResponseSchema = z.union([
|
||||
z.object({
|
||||
data: z.array(z.object({ b64_json: z.string() })),
|
||||
data: z.array(
|
||||
z.object({
|
||||
b64_json: z.string().optional(),
|
||||
url: z.string().optional(),
|
||||
})
|
||||
),
|
||||
}),
|
||||
z.object({
|
||||
error: z.object({
|
||||
@@ -87,6 +74,38 @@ const LogProbsSchema = z.array(
|
||||
})
|
||||
);
|
||||
|
||||
const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro'];
|
||||
|
||||
function normalizeImageFormatToMime(format?: string) {
|
||||
switch (format?.toLowerCase()) {
|
||||
case 'jpg':
|
||||
case 'jpeg':
|
||||
return 'image/jpeg';
|
||||
case 'webp':
|
||||
return 'image/webp';
|
||||
case 'png':
|
||||
return 'image/png';
|
||||
case 'gif':
|
||||
return 'image/gif';
|
||||
default:
|
||||
return 'image/png';
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeImageResponseData(
|
||||
data: { b64_json?: string; url?: string }[],
|
||||
mimeType: string = 'image/png'
|
||||
) {
|
||||
return data
|
||||
.map(image => {
|
||||
if (image.b64_json) {
|
||||
return `data:${mimeType};base64,${image.b64_json}`;
|
||||
}
|
||||
return image.url;
|
||||
})
|
||||
.filter((value): value is string => typeof value === 'string');
|
||||
}
|
||||
|
||||
export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
readonly type = CopilotProviderType.OpenAI;
|
||||
|
||||
@@ -319,53 +338,23 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
},
|
||||
];
|
||||
|
||||
#instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance =
|
||||
this.config.oldApiStyle && this.config.baseURL
|
||||
? createOpenAICompatible({
|
||||
name: 'openai-compatible-old-style',
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
})
|
||||
: createOpenAI({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(
|
||||
e: any,
|
||||
model: string,
|
||||
options: CopilotImageOptions = {}
|
||||
) {
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof AISDKError) {
|
||||
if (e.message.includes('safety') || e.message.includes('risk')) {
|
||||
metrics.ai
|
||||
.counter('chat_text_risk_errors')
|
||||
.add(1, { model, user: options.user || undefined });
|
||||
}
|
||||
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.name || 'unknown',
|
||||
message: e.message,
|
||||
});
|
||||
} else {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected openai response',
|
||||
});
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected openai response',
|
||||
});
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
@@ -389,20 +378,50 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
override getProviderSpecificTools(
|
||||
toolName: CopilotChatTools,
|
||||
model: string
|
||||
_model: string
|
||||
): [string, Tool?] | undefined {
|
||||
if (
|
||||
toolName === 'webSearch' &&
|
||||
'responses' in this.#instance &&
|
||||
!this.isReasoningModel(model)
|
||||
) {
|
||||
return ['web_search_preview', openai.tools.webSearch({})];
|
||||
} else if (toolName === 'docEdit') {
|
||||
if (toolName === 'docEdit') {
|
||||
return ['doc_edit', undefined];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
private createNativeConfig(): NativeLlmBackendConfig {
|
||||
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
|
||||
return {
|
||||
base_url: baseUrl.replace(/\/v1\/?$/, ''),
|
||||
auth_token: this.config.apiKey,
|
||||
};
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream(
|
||||
this.config.oldApiStyle ? 'openai_chat' : 'openai_responses',
|
||||
this.createNativeConfig(),
|
||||
request,
|
||||
signal
|
||||
),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
}
|
||||
|
||||
private getReasoning(
|
||||
options: NonNullable<CopilotChatOptions>,
|
||||
model: string
|
||||
): Record<string, unknown> | undefined {
|
||||
if (options.reasoning && this.isReasoningModel(model)) {
|
||||
return { effort: 'medium' };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async text(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
@@ -413,33 +432,25 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const { text } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
providerOptions: {
|
||||
openai: this.getOpenAIOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
abortSignal: options.signal,
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
|
||||
return text.trim();
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -456,38 +467,29 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const citationParser = new CitationParser();
|
||||
const textParser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
switch (chunk.type) {
|
||||
case 'text-delta': {
|
||||
let result = textParser.parse(chunk);
|
||||
result = citationParser.parse(result);
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
case 'finish': {
|
||||
const footnotes = textParser.end();
|
||||
const result =
|
||||
citationParser.end() + (footnotes.length ? '\n' + footnotes : '');
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
yield textParser.parse(chunk);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -503,24 +505,27 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new StreamObjectParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
if (result) {
|
||||
yield result;
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamObject(request, options.signal)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_errors')
|
||||
.add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -535,35 +540,27 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs, schema] = await chatToGPTMessage(messages);
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request, schema } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
if (!schema) {
|
||||
throw new CopilotPromptInvalid('Schema is required');
|
||||
}
|
||||
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const { object } = await generateObject({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
maxRetries: options.maxRetries ?? 3,
|
||||
schema,
|
||||
providerOptions: {
|
||||
openai: options.user ? { user: options.user } : {},
|
||||
},
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
return JSON.stringify(object);
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
const text = await adapter.text(request, options.signal);
|
||||
const parsed = JSON.parse(text);
|
||||
const validated = schema.parse(parsed);
|
||||
return JSON.stringify(validated);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -575,36 +572,32 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ messages: [], cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
// get the log probability of "yes"/"no"
|
||||
const instance =
|
||||
'chat' in this.#instance
|
||||
? this.#instance.chat(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const scores = await Promise.all(
|
||||
chunkMessages.map(async messages => {
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const result = await generateText({
|
||||
model: instance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: 0,
|
||||
maxOutputTokens: 16,
|
||||
providerOptions: {
|
||||
openai: {
|
||||
...this.getOpenAIOptions(options, model.id),
|
||||
logprobs: 16,
|
||||
},
|
||||
const response = await this.requestOpenAIJson(
|
||||
'/chat/completions',
|
||||
{
|
||||
model: model.id,
|
||||
messages: this.toOpenAIChatMessages(system, msgs),
|
||||
temperature: 0,
|
||||
max_tokens: 16,
|
||||
logprobs: true,
|
||||
top_logprobs: 16,
|
||||
},
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
options.signal
|
||||
);
|
||||
|
||||
const topMap: Record<string, number> = LogProbsSchema.parse(
|
||||
result.providerMetadata?.openai?.logprobs
|
||||
)[0].top_logprobs.reduce<Record<string, number>>(
|
||||
const logprobs = response?.choices?.[0]?.logprobs?.content;
|
||||
if (!Array.isArray(logprobs) || logprobs.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const parsedLogprobs = LogProbsSchema.parse(logprobs);
|
||||
const topMap = parsedLogprobs[0].top_logprobs.reduce(
|
||||
(acc, { token, logprob }) => ({ ...acc, [token]: logprob }),
|
||||
{}
|
||||
{} as Record<string, number>
|
||||
);
|
||||
|
||||
const findLogProb = (token: string): number => {
|
||||
@@ -634,50 +627,212 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
return scores;
|
||||
}
|
||||
|
||||
private async getFullStream(
|
||||
model: CopilotProviderModel,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
) {
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
const { fullStream } = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
frequencyPenalty: options.frequencyPenalty ?? 0,
|
||||
presencePenalty: options.presencePenalty ?? 0,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
providerOptions: {
|
||||
openai: this.getOpenAIOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
return fullStream;
|
||||
// ====== text to image ======
|
||||
private buildImageFetchOptions(url: URL) {
|
||||
const baseOptions = { timeoutMs: 15_000, maxRedirects: 3 } as const;
|
||||
const trustedOrigins = new Set<string>();
|
||||
const protocol = this.AFFiNEConfig.server.https ? 'https:' : 'http:';
|
||||
const port = this.AFFiNEConfig.server.port;
|
||||
const isDefaultPort =
|
||||
(protocol === 'https:' && port === 443) ||
|
||||
(protocol === 'http:' && port === 80);
|
||||
|
||||
const addHostOrigin = (host: string) => {
|
||||
if (!host) return;
|
||||
try {
|
||||
const parsed = new URL(`${protocol}//${host}`);
|
||||
if (!parsed.port && !isDefaultPort) {
|
||||
parsed.port = String(port);
|
||||
}
|
||||
trustedOrigins.add(parsed.origin);
|
||||
} catch {
|
||||
// ignore invalid host config entries
|
||||
}
|
||||
};
|
||||
|
||||
if (this.AFFiNEConfig.server.externalUrl) {
|
||||
try {
|
||||
trustedOrigins.add(
|
||||
new URL(this.AFFiNEConfig.server.externalUrl).origin
|
||||
);
|
||||
} catch {
|
||||
// ignore invalid external URL
|
||||
}
|
||||
}
|
||||
|
||||
addHostOrigin(this.AFFiNEConfig.server.host);
|
||||
for (const host of this.AFFiNEConfig.server.hosts) {
|
||||
addHostOrigin(host);
|
||||
}
|
||||
|
||||
const hostname = url.hostname.toLowerCase();
|
||||
const trustedByHost = TRUSTED_ATTACHMENT_HOST_SUFFIXES.some(
|
||||
suffix => hostname === suffix || hostname.endsWith(`.${suffix}`)
|
||||
);
|
||||
if (trustedOrigins.has(url.origin) || trustedByHost) {
|
||||
return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) };
|
||||
}
|
||||
|
||||
return baseOptions;
|
||||
}
|
||||
|
||||
private redactUrl(raw: string | URL): string {
|
||||
try {
|
||||
const parsed = raw instanceof URL ? raw : new URL(raw);
|
||||
if (parsed.protocol === 'data:') return 'data:[redacted]';
|
||||
const segments = parsed.pathname.split('/').filter(Boolean);
|
||||
const redactedPath =
|
||||
segments.length <= 2
|
||||
? parsed.pathname || '/'
|
||||
: `/${segments[0]}/${segments[1]}/...`;
|
||||
return `${parsed.origin}${redactedPath}`;
|
||||
} catch {
|
||||
return '[invalid-url]';
|
||||
}
|
||||
}
|
||||
|
||||
private async fetchImage(
|
||||
url: string,
|
||||
maxBytes: number,
|
||||
signal?: AbortSignal
|
||||
): Promise<{ buffer: Buffer; type: string } | null> {
|
||||
if (url.startsWith('data:')) {
|
||||
let response: Response;
|
||||
try {
|
||||
response = await fetch(url, { signal });
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment data URL due to read failure: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment data URL due to invalid response: ${response.status}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const type =
|
||||
response.headers.get('content-type') || 'application/octet-stream';
|
||||
if (!type.startsWith('image/')) {
|
||||
await response.body?.cancel().catch(() => undefined);
|
||||
this.logger.warn(
|
||||
`Skip non-image attachment data URL with content-type ${type}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const buffer = await readResponseBufferWithLimit(response, maxBytes);
|
||||
return { buffer, type };
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment data URL due to read failure/size limit: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
let parsed: URL;
|
||||
try {
|
||||
parsed = new URL(url);
|
||||
} catch {
|
||||
this.logger.warn(
|
||||
`Skip image attachment with invalid URL: ${this.redactUrl(url)}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const redactedUrl = this.redactUrl(parsed);
|
||||
|
||||
if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
|
||||
this.logger.warn(
|
||||
`Skip image attachment with unsupported protocol: ${redactedUrl}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
let response: Response;
|
||||
try {
|
||||
response = await safeFetch(
|
||||
parsed,
|
||||
{ method: 'GET', signal },
|
||||
this.buildImageFetchOptions(parsed)
|
||||
);
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment due to blocked/unreachable URL: ${redactedUrl}, reason: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment fetch failure ${response.status}: ${redactedUrl}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const type =
|
||||
response.headers.get('content-type') || 'application/octet-stream';
|
||||
if (!type.startsWith('image/')) {
|
||||
await response.body?.cancel().catch(() => undefined);
|
||||
this.logger.warn(
|
||||
`Skip non-image attachment with content-type ${type}: ${redactedUrl}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const contentLength = Number(response.headers.get('content-length'));
|
||||
if (Number.isFinite(contentLength) && contentLength > maxBytes) {
|
||||
await response.body?.cancel().catch(() => undefined);
|
||||
this.logger.warn(
|
||||
`Skip oversized image attachment by content-length (${contentLength}): ${redactedUrl}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const buffer = await readResponseBufferWithLimit(response, maxBytes);
|
||||
return { buffer, type };
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment due to read failure/size limit: ${redactedUrl}, reason: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// ====== text to image ======
|
||||
private async *generateImageWithAttachments(
|
||||
model: string,
|
||||
prompt: string,
|
||||
attachments: NonNullable<PromptMessage['attachments']>
|
||||
attachments: NonNullable<PromptMessage['attachments']>,
|
||||
signal?: AbortSignal
|
||||
): AsyncGenerator<string> {
|
||||
const form = new FormData();
|
||||
const outputFormat = 'webp';
|
||||
const maxBytes = 10 * OneMB;
|
||||
form.set('model', model);
|
||||
form.set('prompt', prompt);
|
||||
form.set('output_format', 'webp');
|
||||
form.set('output_format', outputFormat);
|
||||
|
||||
for (const [idx, entry] of attachments.entries()) {
|
||||
const url = typeof entry === 'string' ? entry : entry.attachment;
|
||||
try {
|
||||
const { buffer, type } = await fetchBuffer(url, 10 * OneMB, 'image/');
|
||||
const file = new File([buffer], `${idx}.png`, { type });
|
||||
const attachment = await this.fetchImage(url, maxBytes, signal);
|
||||
if (!attachment) continue;
|
||||
const { buffer, type } = attachment;
|
||||
const extension = type.split(';')[0].split('/')[1] || 'png';
|
||||
const file = new File([buffer], `${idx}.${extension}`, { type });
|
||||
form.append('image[]', file);
|
||||
} catch {
|
||||
continue;
|
||||
@@ -703,18 +858,24 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
const json = await res.json();
|
||||
const imageResponse = ImageResponseSchema.safeParse(json);
|
||||
if (imageResponse.success) {
|
||||
const data = imageResponse.data;
|
||||
if ('error' in data) {
|
||||
throw new Error(data.error.message);
|
||||
} else {
|
||||
for (const image of data.data) {
|
||||
yield `data:image/webp;base64,${image.b64_json}`;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!imageResponse.success) {
|
||||
throw new Error(imageResponse.error.message);
|
||||
}
|
||||
const data = imageResponse.data;
|
||||
if ('error' in data) {
|
||||
throw new Error(data.error.message);
|
||||
}
|
||||
|
||||
const images = normalizeImageResponseData(
|
||||
data.data,
|
||||
normalizeImageFormatToMime(outputFormat)
|
||||
);
|
||||
if (!images.length) {
|
||||
throw new Error('No images returned from OpenAI');
|
||||
}
|
||||
for (const image of images) {
|
||||
yield image;
|
||||
}
|
||||
}
|
||||
|
||||
override async *streamImages(
|
||||
@@ -726,13 +887,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
if (!('image' in this.#instance)) {
|
||||
throw new CopilotProviderNotSupported({
|
||||
provider: this.type,
|
||||
kind: 'image',
|
||||
});
|
||||
}
|
||||
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
@@ -742,22 +896,27 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
try {
|
||||
if (attachments && attachments.length > 0) {
|
||||
yield* this.generateImageWithAttachments(model.id, prompt, attachments);
|
||||
} else {
|
||||
const modelInstance = this.#instance.image(model.id);
|
||||
const result = await generateImage({
|
||||
model: modelInstance,
|
||||
yield* this.generateImageWithAttachments(
|
||||
model.id,
|
||||
prompt,
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: options.quality || null,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const imageUrls = result.images.map(
|
||||
image => `data:image/png;base64,${image.base64}`
|
||||
attachments,
|
||||
options.signal
|
||||
);
|
||||
} else {
|
||||
const response = await this.requestOpenAIJson('/images/generations', {
|
||||
model: model.id,
|
||||
prompt,
|
||||
...(options.quality ? { quality: options.quality } : {}),
|
||||
});
|
||||
const imageResponse = ImageResponseSchema.parse(response);
|
||||
if ('error' in imageResponse) {
|
||||
throw new Error(imageResponse.error.message);
|
||||
}
|
||||
|
||||
const imageUrls = normalizeImageResponseData(imageResponse.data);
|
||||
if (!imageUrls.length) {
|
||||
throw new Error('No images returned from OpenAI');
|
||||
}
|
||||
|
||||
for (const imageUrl of imageUrls) {
|
||||
yield imageUrl;
|
||||
@@ -769,7 +928,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
return;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('generate_images_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -783,51 +942,85 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ embeddings: messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
if (!('embedding' in this.#instance)) {
|
||||
throw new CopilotProviderNotSupported({
|
||||
provider: this.type,
|
||||
kind: 'embedding',
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_calls')
|
||||
.add(1, { model: model.id });
|
||||
|
||||
const modelInstance = this.#instance.embedding(model.id);
|
||||
|
||||
const { embeddings } = await embedMany({
|
||||
model: modelInstance,
|
||||
values: messages,
|
||||
providerOptions: {
|
||||
openai: {
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
},
|
||||
},
|
||||
const response = await this.requestOpenAIJson('/embeddings', {
|
||||
model: model.id,
|
||||
input: messages,
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
});
|
||||
|
||||
return embeddings.filter(v => v && Array.isArray(v));
|
||||
const data = Array.isArray(response?.data) ? response.data : [];
|
||||
return data
|
||||
.map((item: any) => item?.embedding)
|
||||
.filter((embedding: unknown) => Array.isArray(embedding)) as number[][];
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_errors')
|
||||
.add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private getOpenAIOptions(options: CopilotChatOptions, model: string) {
|
||||
const result: OpenAIResponsesProviderOptions = {};
|
||||
if (options?.reasoning && this.isReasoningModel(model)) {
|
||||
result.reasoningEffort = 'medium';
|
||||
result.reasoningSummary = 'detailed';
|
||||
private toOpenAIChatMessages(
|
||||
system: string | undefined,
|
||||
messages: Awaited<ReturnType<typeof chatToGPTMessage>>[1]
|
||||
) {
|
||||
const result: Array<{ role: string; content: string }> = [];
|
||||
if (system) {
|
||||
result.push({ role: 'system', content: system });
|
||||
}
|
||||
if (options?.user) {
|
||||
result.user = options.user;
|
||||
|
||||
for (const message of messages) {
|
||||
if (typeof message.content === 'string') {
|
||||
result.push({ role: message.role, content: message.content });
|
||||
continue;
|
||||
}
|
||||
|
||||
const text = message.content
|
||||
.filter(
|
||||
part =>
|
||||
part &&
|
||||
typeof part === 'object' &&
|
||||
'type' in part &&
|
||||
part.type === 'text' &&
|
||||
'text' in part
|
||||
)
|
||||
.map(part => String((part as { text: string }).text))
|
||||
.join('\n');
|
||||
|
||||
result.push({ role: message.role, content: text || '[no content]' });
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private async requestOpenAIJson(
|
||||
path: string,
|
||||
body: Record<string, unknown>,
|
||||
signal?: AbortSignal
|
||||
): Promise<any> {
|
||||
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
|
||||
const response = await fetch(`${baseUrl}${path}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`OpenAI API error ${response.status}: ${await response.text()}`
|
||||
);
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
private isReasoningModel(model: string) {
|
||||
// o series reasoning models
|
||||
return model.startsWith('o') || model.startsWith('gpt-5');
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import {
|
||||
createPerplexity,
|
||||
type PerplexityProvider as VercelPerplexityProvider,
|
||||
} from '@ai-sdk/perplexity';
|
||||
import { generateText, streamText } from 'ai';
|
||||
import { z } from 'zod';
|
||||
import type { ToolSet } from 'ai';
|
||||
|
||||
import { CopilotProviderSideError, metrics } from '../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import {
|
||||
CopilotChatOptions,
|
||||
@@ -15,34 +17,12 @@ import {
|
||||
ModelOutputType,
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { chatToGPTMessage, CitationParser } from './utils';
|
||||
|
||||
export type PerplexityConfig = {
|
||||
apiKey: string;
|
||||
endpoint?: string;
|
||||
};
|
||||
|
||||
const PerplexityErrorSchema = z.union([
|
||||
z.object({
|
||||
detail: z.array(
|
||||
z.object({
|
||||
loc: z.array(z.string()),
|
||||
msg: z.string(),
|
||||
type: z.string(),
|
||||
})
|
||||
),
|
||||
}),
|
||||
z.object({
|
||||
error: z.object({
|
||||
message: z.string(),
|
||||
type: z.string(),
|
||||
code: z.number(),
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
|
||||
type PerplexityError = z.infer<typeof PerplexityErrorSchema>;
|
||||
|
||||
export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
readonly type = CopilotProviderType.Perplexity;
|
||||
|
||||
@@ -90,18 +70,38 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
},
|
||||
];
|
||||
|
||||
#instance!: VercelPerplexityProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance = createPerplexity({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.endpoint,
|
||||
});
|
||||
}
|
||||
|
||||
private createNativeConfig(): NativeLlmBackendConfig {
|
||||
const baseUrl = this.config.endpoint || 'https://api.perplexity.ai';
|
||||
return {
|
||||
base_url: baseUrl.replace(/\/v1\/?$/, ''),
|
||||
auth_token: this.config.apiKey,
|
||||
};
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream(
|
||||
'openai_chat',
|
||||
this.createNativeConfig(),
|
||||
request,
|
||||
signal
|
||||
),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
}
|
||||
|
||||
async text(
|
||||
@@ -114,32 +114,25 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages, false);
|
||||
|
||||
const modelInstance = this.#instance(model.id);
|
||||
|
||||
const { text, sources } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
abortSignal: options.signal,
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
withAttachment: false,
|
||||
include: ['citations'],
|
||||
middleware,
|
||||
});
|
||||
|
||||
const parser = new CitationParser();
|
||||
for (const source of sources.filter(s => s.sourceType === 'url')) {
|
||||
parser.push(source.url);
|
||||
}
|
||||
|
||||
let result = text.replaceAll(/<\/?think>\n/g, '\n---\n');
|
||||
result = parser.parse(result);
|
||||
result += parser.end();
|
||||
return result;
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -154,79 +147,33 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages, false);
|
||||
|
||||
const modelInstance = this.#instance(model.id);
|
||||
|
||||
const stream = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
abortSignal: options.signal,
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
withAttachment: false,
|
||||
include: ['citations'],
|
||||
middleware,
|
||||
});
|
||||
|
||||
const parser = new CitationParser();
|
||||
for await (const chunk of stream.fullStream) {
|
||||
switch (chunk.type) {
|
||||
case 'source': {
|
||||
if (chunk.sourceType === 'url') {
|
||||
parser.push(chunk.url);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'text-delta': {
|
||||
const text = chunk.text.replaceAll(/<\/?think>\n?/g, '\n---\n');
|
||||
const result = parser.parse(text);
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
case 'finish-step': {
|
||||
const result = parser.end();
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
const json =
|
||||
typeof chunk.error === 'string'
|
||||
? JSON.parse(chunk.error)
|
||||
: chunk.error;
|
||||
if (json && typeof json === 'object') {
|
||||
const data = PerplexityErrorSchema.parse(json);
|
||||
if ('detail' in data || 'error' in data) {
|
||||
throw this.convertError(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
yield chunk;
|
||||
}
|
||||
} catch (e) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
throw e;
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private convertError(e: PerplexityError) {
|
||||
function getErrMessage(e: PerplexityError) {
|
||||
let err = 'Unexpected perplexity response';
|
||||
if ('detail' in e) {
|
||||
err = e.detail[0].msg || err;
|
||||
} else if ('error' in e) {
|
||||
err = e.error.message || err;
|
||||
}
|
||||
return err;
|
||||
}
|
||||
|
||||
throw new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: getErrMessage(e),
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof CopilotProviderSideError) {
|
||||
return e;
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
import type { ProviderMiddlewareConfig } from '../config';
|
||||
import { CopilotProviderType } from './types';
|
||||
|
||||
const DEFAULT_MIDDLEWARE_BY_TYPE: Record<
|
||||
CopilotProviderType,
|
||||
ProviderMiddlewareConfig
|
||||
> = {
|
||||
[CopilotProviderType.OpenAI]: {
|
||||
rust: {
|
||||
request: ['normalize_messages'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Anthropic]: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'tool_schema_rewrite'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.AnthropicVertex]: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'tool_schema_rewrite'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Morph]: {
|
||||
rust: {
|
||||
request: ['clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Perplexity]: {
|
||||
rust: {
|
||||
request: ['clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Gemini]: {
|
||||
node: {
|
||||
text: ['callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.GeminiVertex]: {
|
||||
node: {
|
||||
text: ['callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.FAL]: {},
|
||||
};
|
||||
|
||||
function unique<T>(items: T[]) {
|
||||
return [...new Set(items)];
|
||||
}
|
||||
|
||||
function mergeArray<T>(base: T[] | undefined, override: T[] | undefined) {
|
||||
if (!base?.length && !override?.length) {
|
||||
return undefined;
|
||||
}
|
||||
return unique([...(base ?? []), ...(override ?? [])]);
|
||||
}
|
||||
|
||||
export function mergeProviderMiddleware(
|
||||
defaults: ProviderMiddlewareConfig,
|
||||
override?: ProviderMiddlewareConfig
|
||||
): ProviderMiddlewareConfig {
|
||||
return {
|
||||
rust: {
|
||||
request: mergeArray(defaults.rust?.request, override?.rust?.request),
|
||||
stream: mergeArray(defaults.rust?.stream, override?.rust?.stream),
|
||||
},
|
||||
node: {
|
||||
text: mergeArray(defaults.node?.text, override?.node?.text),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function resolveProviderMiddleware(
|
||||
type: CopilotProviderType,
|
||||
override?: ProviderMiddlewareConfig
|
||||
): ProviderMiddlewareConfig {
|
||||
const defaults = DEFAULT_MIDDLEWARE_BY_TYPE[type] ?? {};
|
||||
return mergeProviderMiddleware(defaults, override);
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
import type {
|
||||
CopilotProviderConfigMap,
|
||||
CopilotProviderDefaults,
|
||||
CopilotProviderProfile,
|
||||
ProviderMiddlewareConfig,
|
||||
} from '../config';
|
||||
import { resolveProviderMiddleware } from './provider-middleware';
|
||||
import { CopilotProviderType, type ModelOutputType } from './types';
|
||||
|
||||
const PROVIDER_ID_PATTERN = /^[a-zA-Z0-9-_]+$/;
|
||||
|
||||
const LEGACY_PROVIDER_ORDER: CopilotProviderType[] = [
|
||||
CopilotProviderType.OpenAI,
|
||||
CopilotProviderType.FAL,
|
||||
CopilotProviderType.Gemini,
|
||||
CopilotProviderType.GeminiVertex,
|
||||
CopilotProviderType.Perplexity,
|
||||
CopilotProviderType.Anthropic,
|
||||
CopilotProviderType.AnthropicVertex,
|
||||
CopilotProviderType.Morph,
|
||||
];
|
||||
|
||||
const LEGACY_PROVIDER_PRIORITY = LEGACY_PROVIDER_ORDER.reduce(
|
||||
(acc, type, index) => {
|
||||
acc[type] = LEGACY_PROVIDER_ORDER.length - index;
|
||||
return acc;
|
||||
},
|
||||
{} as Record<CopilotProviderType, number>
|
||||
);
|
||||
|
||||
type LegacyProvidersConfig = Partial<
|
||||
Record<CopilotProviderType, CopilotProviderConfigMap[CopilotProviderType]>
|
||||
>;
|
||||
|
||||
export type CopilotProvidersConfigInput = LegacyProvidersConfig & {
|
||||
profiles?: CopilotProviderProfile[] | null;
|
||||
defaults?: CopilotProviderDefaults | null;
|
||||
};
|
||||
|
||||
export type NormalizedCopilotProviderProfile = Omit<
|
||||
CopilotProviderProfile,
|
||||
'enabled' | 'priority' | 'middleware'
|
||||
> & {
|
||||
enabled: boolean;
|
||||
priority: number;
|
||||
middleware: ProviderMiddlewareConfig;
|
||||
};
|
||||
|
||||
export type CopilotProviderRegistry = {
|
||||
profiles: Map<string, NormalizedCopilotProviderProfile>;
|
||||
defaults: CopilotProviderDefaults;
|
||||
order: string[];
|
||||
byType: Map<CopilotProviderType, string[]>;
|
||||
};
|
||||
|
||||
export type ResolveModelResult = {
|
||||
rawModelId?: string;
|
||||
modelId?: string;
|
||||
explicitProviderId?: string;
|
||||
candidateProviderIds: string[];
|
||||
};
|
||||
|
||||
type ResolveModelOptions = {
|
||||
registry: CopilotProviderRegistry;
|
||||
modelId?: string;
|
||||
outputType?: ModelOutputType;
|
||||
availableProviderIds?: Iterable<string>;
|
||||
preferredProviderIds?: Iterable<string>;
|
||||
};
|
||||
|
||||
function unique<T>(list: T[]): T[] {
|
||||
return [...new Set(list)];
|
||||
}
|
||||
|
||||
function asArray<T>(iter?: Iterable<T>): T[] {
|
||||
return iter ? Array.from(iter) : [];
|
||||
}
|
||||
|
||||
function parseModelPrefix(
|
||||
registry: CopilotProviderRegistry,
|
||||
modelId: string
|
||||
): { providerId: string; modelId?: string } | null {
|
||||
const index = modelId.indexOf('/');
|
||||
if (index <= 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const providerId = modelId.slice(0, index);
|
||||
if (!registry.profiles.has(providerId)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const model = modelId.slice(index + 1);
|
||||
return { providerId, modelId: model || undefined };
|
||||
}
|
||||
|
||||
function normalizeProfile(
|
||||
profile: CopilotProviderProfile
|
||||
): NormalizedCopilotProviderProfile {
|
||||
return {
|
||||
...profile,
|
||||
enabled: profile.enabled !== false,
|
||||
priority: profile.priority ?? 0,
|
||||
middleware: resolveProviderMiddleware(profile.type, profile.middleware),
|
||||
};
|
||||
}
|
||||
|
||||
function toLegacyProfiles(
|
||||
config: CopilotProvidersConfigInput
|
||||
): CopilotProviderProfile[] {
|
||||
const legacyProfiles: CopilotProviderProfile[] = [];
|
||||
for (const type of LEGACY_PROVIDER_ORDER) {
|
||||
const legacyConfig = config[type];
|
||||
if (!legacyConfig) {
|
||||
continue;
|
||||
}
|
||||
legacyProfiles.push({
|
||||
id: `${type}-default`,
|
||||
type,
|
||||
priority: LEGACY_PROVIDER_PRIORITY[type],
|
||||
config: legacyConfig,
|
||||
} as CopilotProviderProfile);
|
||||
}
|
||||
return legacyProfiles;
|
||||
}
|
||||
|
||||
function mergeProfiles(
|
||||
explicitProfiles: CopilotProviderProfile[],
|
||||
legacyProfiles: CopilotProviderProfile[]
|
||||
): CopilotProviderProfile[] {
|
||||
const profiles = new Map<string, CopilotProviderProfile>();
|
||||
|
||||
for (const profile of explicitProfiles) {
|
||||
if (!PROVIDER_ID_PATTERN.test(profile.id)) {
|
||||
throw new Error(`Invalid copilot provider profile id: ${profile.id}`);
|
||||
}
|
||||
if (profiles.has(profile.id)) {
|
||||
throw new Error(`Duplicated copilot provider profile id: ${profile.id}`);
|
||||
}
|
||||
profiles.set(profile.id, profile);
|
||||
}
|
||||
|
||||
for (const profile of legacyProfiles) {
|
||||
if (!profiles.has(profile.id)) {
|
||||
profiles.set(profile.id, profile);
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(profiles.values());
|
||||
}
|
||||
|
||||
function sortProfiles(profiles: NormalizedCopilotProviderProfile[]) {
|
||||
return profiles.toSorted((a, b) => {
|
||||
if (a.priority !== b.priority) {
|
||||
return b.priority - a.priority;
|
||||
}
|
||||
return a.id.localeCompare(b.id);
|
||||
});
|
||||
}
|
||||
|
||||
function assertDefaults(
|
||||
defaults: CopilotProviderDefaults,
|
||||
profiles: Map<string, NormalizedCopilotProviderProfile>
|
||||
) {
|
||||
for (const providerId of Object.values(defaults)) {
|
||||
if (!providerId) {
|
||||
continue;
|
||||
}
|
||||
if (!profiles.has(providerId)) {
|
||||
throw new Error(
|
||||
`Copilot provider defaults references unknown providerId: ${providerId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function buildProviderRegistry(
|
||||
config: CopilotProvidersConfigInput
|
||||
): CopilotProviderRegistry {
|
||||
const explicitProfiles = config.profiles ?? [];
|
||||
const legacyProfiles = toLegacyProfiles(config);
|
||||
const mergedProfiles = mergeProfiles(explicitProfiles, legacyProfiles)
|
||||
.map(normalizeProfile)
|
||||
.filter(profile => profile.enabled);
|
||||
const sortedProfiles = sortProfiles(mergedProfiles);
|
||||
|
||||
const profiles = new Map(
|
||||
sortedProfiles.map(profile => [profile.id, profile] as const)
|
||||
);
|
||||
const defaults = config.defaults ?? {};
|
||||
assertDefaults(defaults, profiles);
|
||||
|
||||
const order = sortedProfiles.map(profile => profile.id);
|
||||
const byType = new Map<CopilotProviderType, string[]>();
|
||||
for (const profile of sortedProfiles) {
|
||||
const ids = byType.get(profile.type) ?? [];
|
||||
ids.push(profile.id);
|
||||
byType.set(profile.type, ids);
|
||||
}
|
||||
|
||||
return { profiles, defaults, order, byType };
|
||||
}
|
||||
|
||||
export function resolveModel({
|
||||
registry,
|
||||
modelId,
|
||||
outputType,
|
||||
availableProviderIds,
|
||||
preferredProviderIds,
|
||||
}: ResolveModelOptions): ResolveModelResult {
|
||||
const available = new Set(asArray(availableProviderIds));
|
||||
const preferred = new Set(asArray(preferredProviderIds));
|
||||
const hasAvailableFilter = available.size > 0;
|
||||
const hasPreferredFilter = preferred.size > 0;
|
||||
|
||||
const isAllowed = (providerId: string) => {
|
||||
const profile = registry.profiles.get(providerId);
|
||||
if (!profile?.enabled) {
|
||||
return false;
|
||||
}
|
||||
if (hasAvailableFilter && !available.has(providerId)) {
|
||||
return false;
|
||||
}
|
||||
if (hasPreferredFilter && !preferred.has(providerId)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
const prefixed = modelId ? parseModelPrefix(registry, modelId) : null;
|
||||
if (prefixed) {
|
||||
return {
|
||||
rawModelId: modelId,
|
||||
modelId: prefixed.modelId,
|
||||
explicitProviderId: prefixed.providerId,
|
||||
candidateProviderIds: isAllowed(prefixed.providerId)
|
||||
? [prefixed.providerId]
|
||||
: [],
|
||||
};
|
||||
}
|
||||
|
||||
const fallbackOrder = [
|
||||
...(outputType ? [registry.defaults[outputType]] : []),
|
||||
registry.defaults.fallback,
|
||||
...registry.order,
|
||||
].filter((id): id is string => !!id);
|
||||
|
||||
return {
|
||||
rawModelId: modelId,
|
||||
modelId,
|
||||
candidateProviderIds: unique(
|
||||
fallbackOrder.filter(providerId => isAllowed(providerId))
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
export function stripProviderPrefix(
|
||||
registry: CopilotProviderRegistry,
|
||||
providerId: string,
|
||||
modelId?: string
|
||||
) {
|
||||
if (!modelId) {
|
||||
return modelId;
|
||||
}
|
||||
const prefixed = parseModelPrefix(registry, modelId);
|
||||
if (!prefixed) {
|
||||
return modelId;
|
||||
}
|
||||
if (prefixed.providerId !== providerId) {
|
||||
return modelId;
|
||||
}
|
||||
return prefixed.modelId;
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
import { Tool, ToolSet } from 'ai';
|
||||
@@ -13,6 +15,7 @@ import { DocReader, DocWriter } from '../../../core/doc';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { Models } from '../../../models';
|
||||
import { IndexerService } from '../../indexer';
|
||||
import type { ProviderMiddlewareConfig } from '../config';
|
||||
import { CopilotContextService } from '../context/service';
|
||||
import { PromptService } from '../prompt/service';
|
||||
import {
|
||||
@@ -40,6 +43,8 @@ import {
|
||||
createSectionEditTool,
|
||||
} from '../tools';
|
||||
import { CopilotProviderFactory } from './factory';
|
||||
import { resolveProviderMiddleware } from './provider-middleware';
|
||||
import { buildProviderRegistry } from './provider-registry';
|
||||
import {
|
||||
type CopilotChatOptions,
|
||||
CopilotChatTools,
|
||||
@@ -58,11 +63,14 @@ import {
|
||||
StreamObject,
|
||||
} from './types';
|
||||
|
||||
const providerProfileContext = new AsyncLocalStorage<string>();
|
||||
|
||||
@Injectable()
|
||||
export abstract class CopilotProvider<C = any> {
|
||||
protected readonly logger = new Logger(this.constructor.name);
|
||||
protected readonly MAX_STEPS = 20;
|
||||
protected onlineModelList: string[] = [];
|
||||
|
||||
abstract readonly type: CopilotProviderType;
|
||||
abstract readonly models: CopilotProviderModel[];
|
||||
abstract configured(): boolean;
|
||||
@@ -70,8 +78,39 @@ export abstract class CopilotProvider<C = any> {
|
||||
@Inject() protected readonly AFFiNEConfig!: Config;
|
||||
@Inject() protected readonly factory!: CopilotProviderFactory;
|
||||
@Inject() protected readonly moduleRef!: ModuleRef;
|
||||
readonly #registeredProviderIds = new Set<string>();
|
||||
|
||||
runWithProfile<T>(providerId: string, callback: () => T): T {
|
||||
return providerProfileContext.run(providerId, callback);
|
||||
}
|
||||
|
||||
protected getActiveProviderId() {
|
||||
return providerProfileContext.getStore() ?? `${this.type}-default`;
|
||||
}
|
||||
|
||||
protected getActiveProviderMiddleware(): ProviderMiddlewareConfig {
|
||||
const providerId = this.getActiveProviderId();
|
||||
const registry = buildProviderRegistry(this.AFFiNEConfig.copilot.providers);
|
||||
const profile = registry.profiles.get(providerId);
|
||||
return profile?.middleware ?? resolveProviderMiddleware(this.type);
|
||||
}
|
||||
|
||||
protected metricLabels(
|
||||
model: string,
|
||||
labels: Record<string, string | number | boolean | undefined> = {}
|
||||
) {
|
||||
const providerId = this.getActiveProviderId();
|
||||
return { model, providerId, ...labels };
|
||||
}
|
||||
|
||||
get config(): C {
|
||||
const profileId = providerProfileContext.getStore();
|
||||
if (profileId) {
|
||||
const profile = this.AFFiNEConfig.copilot.providers.profiles?.find(
|
||||
profile => profile.id === profileId && profile.type === this.type
|
||||
);
|
||||
if (profile) return profile.config as C;
|
||||
}
|
||||
return this.AFFiNEConfig.copilot.providers[this.type] as C;
|
||||
}
|
||||
|
||||
@@ -88,15 +127,37 @@ export abstract class CopilotProvider<C = any> {
|
||||
}
|
||||
|
||||
protected setup() {
|
||||
if (this.configured()) {
|
||||
this.factory.register(this);
|
||||
if (env.selfhosted) {
|
||||
const registry = buildProviderRegistry(this.AFFiNEConfig.copilot.providers);
|
||||
const providerIds = registry.byType.get(this.type) ?? [];
|
||||
const nextProviderIds = new Set<string>();
|
||||
|
||||
for (const id of providerIds) {
|
||||
const configured = this.runWithProfile(id, () => this.configured());
|
||||
if (configured) {
|
||||
nextProviderIds.add(id);
|
||||
this.factory.register(id, this);
|
||||
} else {
|
||||
this.factory.unregister(id, this);
|
||||
}
|
||||
}
|
||||
|
||||
for (const providerId of this.#registeredProviderIds) {
|
||||
if (!nextProviderIds.has(providerId)) {
|
||||
this.factory.unregister(providerId, this);
|
||||
}
|
||||
}
|
||||
this.#registeredProviderIds.clear();
|
||||
for (const providerId of nextProviderIds) {
|
||||
this.#registeredProviderIds.add(providerId);
|
||||
}
|
||||
|
||||
if (env.selfhosted && nextProviderIds.size > 0) {
|
||||
const [providerId] = Array.from(nextProviderIds);
|
||||
this.runWithProfile(providerId, () => {
|
||||
this.refreshOnlineModels().catch(e =>
|
||||
this.logger.error('Failed to refresh online models', e)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
this.factory.unregister(this);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -91,7 +91,9 @@ export async function chatToGPTMessage(
|
||||
// so we need to use base64 encoded attachments instead
|
||||
useBase64Attachment: boolean = false
|
||||
): Promise<[string | undefined, ChatMessage[], ZodType?]> {
|
||||
const system = messages[0]?.role === 'system' ? messages.shift() : undefined;
|
||||
const hasSystem = messages[0]?.role === 'system';
|
||||
const system = hasSystem ? messages[0] : undefined;
|
||||
const normalizedMessages = hasSystem ? messages.slice(1) : messages;
|
||||
const schema =
|
||||
system?.params?.schema && system.params.schema instanceof ZodType
|
||||
? system.params.schema
|
||||
@@ -99,7 +101,7 @@ export async function chatToGPTMessage(
|
||||
|
||||
// filter redundant fields
|
||||
const msgs: ChatMessage[] = [];
|
||||
for (let { role, content, attachments, params } of messages.filter(
|
||||
for (let { role, content, attachments, params } of normalizedMessages.filter(
|
||||
m => m.role !== 'system'
|
||||
)) {
|
||||
content = content.trim();
|
||||
@@ -406,6 +408,34 @@ export class CitationParser {
|
||||
}
|
||||
}
|
||||
|
||||
export type CitationIndexedEvent = {
|
||||
type: 'citation';
|
||||
index: number;
|
||||
url: string;
|
||||
};
|
||||
|
||||
export class CitationFootnoteFormatter {
|
||||
private readonly citations = new Map<number, string>();
|
||||
|
||||
public consume(event: CitationIndexedEvent) {
|
||||
if (event.type !== 'citation') {
|
||||
return '';
|
||||
}
|
||||
this.citations.set(event.index, event.url);
|
||||
return '';
|
||||
}
|
||||
|
||||
public end() {
|
||||
const footnotes = Array.from(this.citations.entries())
|
||||
.sort((a, b) => a[0] - b[0])
|
||||
.map(
|
||||
([index, citation]) =>
|
||||
`[^${index}]: {"type":"url","url":"${encodeURIComponent(citation)}"}`
|
||||
);
|
||||
return footnotes.join('\n');
|
||||
}
|
||||
}
|
||||
|
||||
type ChunkType = TextStreamPart<CustomAITools>['type'];
|
||||
|
||||
export function toError(error: unknown): Error {
|
||||
@@ -703,21 +733,39 @@ export const VertexModelListSchema = z.object({
|
||||
),
|
||||
});
|
||||
|
||||
function normalizeUrl(baseURL?: string) {
|
||||
if (!baseURL?.trim()) {
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
const url = new URL(baseURL);
|
||||
const serialized = url.toString();
|
||||
if (serialized.endsWith('/')) return serialized.slice(0, -1);
|
||||
return serialized;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
export function getVertexAnthropicBaseUrl(
|
||||
options: GoogleVertexAnthropicProviderSettings
|
||||
) {
|
||||
const normalizedBaseUrl = normalizeUrl(options.baseURL);
|
||||
if (normalizedBaseUrl) return normalizedBaseUrl;
|
||||
const { location, project } = options;
|
||||
if (!location || !project) return undefined;
|
||||
return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/anthropic`;
|
||||
}
|
||||
|
||||
export async function getGoogleAuth(
|
||||
options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings,
|
||||
publisher: 'anthropic' | 'google'
|
||||
) {
|
||||
function getBaseUrl() {
|
||||
const { baseURL, location } = options;
|
||||
if (baseURL?.trim()) {
|
||||
try {
|
||||
const url = new URL(baseURL);
|
||||
if (url.pathname.endsWith('/')) {
|
||||
url.pathname = url.pathname.slice(0, -1);
|
||||
}
|
||||
return url.toString();
|
||||
} catch {}
|
||||
} else if (location) {
|
||||
const normalizedBaseUrl = normalizeUrl(options.baseURL);
|
||||
if (normalizedBaseUrl) return normalizedBaseUrl;
|
||||
const { location } = options;
|
||||
if (location) {
|
||||
return `https://${location}-aiplatform.googleapis.com/v1beta1/publishers/${publisher}`;
|
||||
}
|
||||
return undefined;
|
||||
|
||||
@@ -4,7 +4,6 @@ import { BadRequestException, NotFoundException } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
Float,
|
||||
ID,
|
||||
InputType,
|
||||
Mutation,
|
||||
@@ -15,7 +14,6 @@ import {
|
||||
ResolveField,
|
||||
Resolver,
|
||||
} from '@nestjs/graphql';
|
||||
import { AiPromptRole } from '@prisma/client';
|
||||
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
|
||||
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
|
||||
|
||||
@@ -313,57 +311,6 @@ class CopilotQuotaType {
|
||||
used!: number;
|
||||
}
|
||||
|
||||
registerEnumType(AiPromptRole, {
|
||||
name: 'CopilotPromptMessageRole',
|
||||
});
|
||||
|
||||
@InputType('CopilotPromptConfigInput')
|
||||
@ObjectType()
|
||||
class CopilotPromptConfigType {
|
||||
@Field(() => Float, { nullable: true })
|
||||
frequencyPenalty!: number | null;
|
||||
|
||||
@Field(() => Float, { nullable: true })
|
||||
presencePenalty!: number | null;
|
||||
|
||||
@Field(() => Float, { nullable: true })
|
||||
temperature!: number | null;
|
||||
|
||||
@Field(() => Float, { nullable: true })
|
||||
topP!: number | null;
|
||||
}
|
||||
|
||||
@InputType('CopilotPromptMessageInput')
|
||||
@ObjectType()
|
||||
class CopilotPromptMessageType {
|
||||
@Field(() => AiPromptRole)
|
||||
role!: AiPromptRole;
|
||||
|
||||
@Field(() => String)
|
||||
content!: string;
|
||||
|
||||
@Field(() => GraphQLJSON, { nullable: true })
|
||||
params!: Record<string, string> | null;
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class CopilotPromptType {
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => String)
|
||||
model!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
action!: string | null;
|
||||
|
||||
@Field(() => CopilotPromptConfigType, { nullable: true })
|
||||
config!: CopilotPromptConfigType | null;
|
||||
|
||||
@Field(() => [CopilotPromptMessageType])
|
||||
messages!: CopilotPromptMessageType[];
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class CopilotModelType {
|
||||
@Field(() => String)
|
||||
@@ -638,13 +585,8 @@ export class CopilotResolver {
|
||||
);
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat session',
|
||||
})
|
||||
@CallMetric('ai', 'chat_session_create')
|
||||
async createCopilotSession(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => CreateChatSessionInput })
|
||||
private async createCopilotSessionInternal(
|
||||
user: CurrentUser,
|
||||
options: CreateChatSessionInput
|
||||
): Promise<string> {
|
||||
// permission check based on session type
|
||||
@@ -666,6 +608,42 @@ export class CopilotResolver {
|
||||
});
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat session',
|
||||
deprecationReason: 'use `createCopilotSessionWithHistory` instead',
|
||||
})
|
||||
@CallMetric('ai', 'chat_session_create')
|
||||
async createCopilotSession(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => CreateChatSessionInput })
|
||||
options: CreateChatSessionInput
|
||||
): Promise<string> {
|
||||
return await this.createCopilotSessionInternal(user, options);
|
||||
}
|
||||
|
||||
@Mutation(() => CopilotHistoriesType, {
|
||||
description: 'Create a chat session and return full session payload',
|
||||
})
|
||||
@CallMetric('ai', 'chat_session_create_with_history')
|
||||
async createCopilotSessionWithHistory(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => CreateChatSessionInput })
|
||||
options: CreateChatSessionInput
|
||||
): Promise<CopilotHistoriesType> {
|
||||
const sessionId = await this.createCopilotSessionInternal(user, options);
|
||||
const session = await this.chatSession.getSessionInfo(sessionId);
|
||||
if (!session) {
|
||||
throw new NotFoundException('Session not found');
|
||||
}
|
||||
return {
|
||||
...session,
|
||||
messages: session.messages.map(message => ({
|
||||
...message,
|
||||
id: message.id,
|
||||
})) as ChatMessageType[],
|
||||
};
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Update a chat session',
|
||||
})
|
||||
@@ -939,31 +917,10 @@ export class UserCopilotResolver {
|
||||
}
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class CreateCopilotPromptInput {
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => String)
|
||||
model!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
action!: string | null;
|
||||
|
||||
@Field(() => CopilotPromptConfigType, { nullable: true })
|
||||
config!: CopilotPromptConfigType | null;
|
||||
|
||||
@Field(() => [CopilotPromptMessageType])
|
||||
messages!: CopilotPromptMessageType[];
|
||||
}
|
||||
|
||||
@Admin()
|
||||
@Resolver(() => String)
|
||||
export class PromptsManagementResolver {
|
||||
constructor(
|
||||
private readonly cron: CopilotCronJobs,
|
||||
private readonly promptService: PromptService
|
||||
) {}
|
||||
constructor(private readonly cron: CopilotCronJobs) {}
|
||||
|
||||
@Mutation(() => Boolean, {
|
||||
description: 'Trigger generate missing titles cron job',
|
||||
@@ -980,48 +937,4 @@ export class PromptsManagementResolver {
|
||||
await this.cron.triggerCleanupTrashedDocEmbeddings();
|
||||
return true;
|
||||
}
|
||||
|
||||
@Query(() => [CopilotPromptType], {
|
||||
description: 'List all copilot prompts',
|
||||
})
|
||||
async listCopilotPrompts() {
|
||||
const prompts = await this.promptService.list();
|
||||
return prompts.filter(
|
||||
p =>
|
||||
p.messages.length > 0 &&
|
||||
// ignore internal prompts
|
||||
!p.name.startsWith('workflow:') &&
|
||||
!p.name.startsWith('debug:') &&
|
||||
!p.name.startsWith('chat:') &&
|
||||
!p.name.startsWith('action:')
|
||||
);
|
||||
}
|
||||
|
||||
@Mutation(() => CopilotPromptType, {
|
||||
description: 'Create a copilot prompt',
|
||||
})
|
||||
async createCopilotPrompt(
|
||||
@Args({ type: () => CreateCopilotPromptInput, name: 'input' })
|
||||
input: CreateCopilotPromptInput
|
||||
) {
|
||||
await this.promptService.set(
|
||||
input.name,
|
||||
input.model,
|
||||
input.messages,
|
||||
input.config
|
||||
);
|
||||
return this.promptService.get(input.name);
|
||||
}
|
||||
|
||||
@Mutation(() => CopilotPromptType, {
|
||||
description: 'Update a copilot prompt',
|
||||
})
|
||||
async updateCopilotPrompt(
|
||||
@Args('name') name: string,
|
||||
@Args('messages', { type: () => [CopilotPromptMessageType] })
|
||||
messages: CopilotPromptMessageType[]
|
||||
) {
|
||||
await this.promptService.update(name, { messages, modified: true });
|
||||
return this.promptService.get(name);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import { AiPromptRole } from '@prisma/client';
|
||||
import { pick } from 'lodash-es';
|
||||
|
||||
import {
|
||||
Config,
|
||||
CopilotActionTaken,
|
||||
CopilotMessageNotFound,
|
||||
CopilotPromptNotFound,
|
||||
@@ -31,6 +32,7 @@ import { ChatMessageCache } from './message';
|
||||
import { ChatPrompt } from './prompt/chat-prompt';
|
||||
import { PromptService } from './prompt/service';
|
||||
import { CopilotProviderFactory } from './providers/factory';
|
||||
import { buildProviderRegistry } from './providers/provider-registry';
|
||||
import {
|
||||
ModelOutputType,
|
||||
type PromptMessage,
|
||||
@@ -105,10 +107,31 @@ export class ChatSession implements AsyncDisposable {
|
||||
hasPayment: boolean,
|
||||
requestedModelId?: string
|
||||
): Promise<string> {
|
||||
const config = this.moduleRef.get(Config, { strict: false });
|
||||
const registry = config
|
||||
? buildProviderRegistry(config.copilot.providers)
|
||||
: null;
|
||||
const defaultModel = this.model;
|
||||
const normalize = (m?: string) =>
|
||||
!!m && this.optionalModels.includes(m) ? m : defaultModel;
|
||||
const isPro = (m?: string) => !!m && this.proModels.includes(m);
|
||||
const normalizeModel = (modelId?: string) => {
|
||||
if (!modelId) return modelId;
|
||||
const separatorIndex = modelId.indexOf('/');
|
||||
if (separatorIndex <= 0) return modelId;
|
||||
const providerId = modelId.slice(0, separatorIndex);
|
||||
if (!registry?.profiles.has(providerId)) return modelId;
|
||||
return modelId.slice(separatorIndex + 1);
|
||||
};
|
||||
const inModelList = (models: string[], modelId?: string) => {
|
||||
if (!modelId) return false;
|
||||
return (
|
||||
models.includes(modelId) ||
|
||||
models.includes(normalizeModel(modelId) ?? '')
|
||||
);
|
||||
};
|
||||
const normalize = (m?: string) => {
|
||||
if (inModelList(this.optionalModels, m)) return m;
|
||||
return defaultModel;
|
||||
};
|
||||
const isPro = (m?: string) => inModelList(this.proModels, m);
|
||||
|
||||
// try resolve payment subscription service lazily
|
||||
let paymentEnabled = hasPayment;
|
||||
@@ -132,10 +155,19 @@ export class ChatSession implements AsyncDisposable {
|
||||
}
|
||||
|
||||
if (paymentEnabled && !isUserAIPro && isPro(requestedModelId)) {
|
||||
if (!defaultModel) {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
'Model is required for AI subscription fallback'
|
||||
);
|
||||
}
|
||||
return defaultModel;
|
||||
}
|
||||
|
||||
return normalize(requestedModelId);
|
||||
const resolvedModel = normalize(requestedModelId);
|
||||
if (!resolvedModel) {
|
||||
throw new CopilotSessionInvalidInput('Model is required');
|
||||
}
|
||||
return resolvedModel;
|
||||
}
|
||||
|
||||
push(message: ChatMessage) {
|
||||
|
||||
@@ -32,16 +32,22 @@ export const buildBlobContentGetter = (
|
||||
return;
|
||||
}
|
||||
|
||||
const contextFile = context.files.find(
|
||||
file => file.blobId === blobId || file.id === blobId
|
||||
);
|
||||
const canonicalBlobId = contextFile?.blobId ?? blobId;
|
||||
const targetFileId = contextFile?.id;
|
||||
const [file, blob] = await Promise.all([
|
||||
context?.getFileContent(blobId, chunk),
|
||||
context?.getBlobContent(blobId, chunk),
|
||||
targetFileId ? context.getFileContent(targetFileId, chunk) : undefined,
|
||||
context.getBlobContent(canonicalBlobId, chunk),
|
||||
]);
|
||||
const content = file?.trim() || blob?.trim();
|
||||
if (!content) {
|
||||
return;
|
||||
}
|
||||
if (!content) return;
|
||||
const info = contextFile
|
||||
? { fileName: contextFile.name, fileType: contextFile.mimeType }
|
||||
: {};
|
||||
|
||||
return { blobId, chunk, content };
|
||||
return { blobId: canonicalBlobId, chunk, content, ...info };
|
||||
};
|
||||
return getBlobContent;
|
||||
};
|
||||
|
||||
@@ -14,6 +14,7 @@ import type {
|
||||
import { HTMLRewriter } from 'htmlrewriter';
|
||||
|
||||
import {
|
||||
applyAttachHeaders,
|
||||
BadRequest,
|
||||
Cache,
|
||||
readResponseBufferWithLimit,
|
||||
@@ -127,15 +128,18 @@ export class WorkerController {
|
||||
if (buffer.length === 0) {
|
||||
return resp.status(404).header(getCorsHeaders(origin)).send();
|
||||
}
|
||||
return resp
|
||||
.status(200)
|
||||
.header({
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
'Content-Type': 'image/*',
|
||||
})
|
||||
.send(buffer);
|
||||
resp.header({
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
});
|
||||
applyAttachHeaders(resp, { buffer });
|
||||
const contentType = resp.getHeader('Content-Type') as string | undefined;
|
||||
if (contentType?.startsWith('image/')) {
|
||||
return resp.status(200).send(buffer);
|
||||
} else {
|
||||
throw new BadRequest('Invalid content type');
|
||||
}
|
||||
}
|
||||
|
||||
let response: Response;
|
||||
@@ -171,39 +175,39 @@ export class WorkerController {
|
||||
throw new BadRequest('Failed to fetch image');
|
||||
}
|
||||
if (response.ok) {
|
||||
const contentType = response.headers.get('Content-Type');
|
||||
if (contentType?.startsWith('image/')) {
|
||||
let buffer: Buffer;
|
||||
try {
|
||||
buffer = await readResponseBufferWithLimit(
|
||||
response,
|
||||
IMAGE_PROXY_MAX_BYTES
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof ResponseTooLargeError) {
|
||||
this.logger.warn('Image proxy response too large', {
|
||||
url: imageURL,
|
||||
limitBytes: error.data?.limitBytes,
|
||||
receivedBytes: error.data?.receivedBytes,
|
||||
});
|
||||
throw new BadRequest('Response too large');
|
||||
}
|
||||
throw error;
|
||||
let buffer: Buffer;
|
||||
try {
|
||||
buffer = await readResponseBufferWithLimit(
|
||||
response,
|
||||
IMAGE_PROXY_MAX_BYTES
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof ResponseTooLargeError) {
|
||||
this.logger.warn('Image proxy response too large', {
|
||||
url: imageURL,
|
||||
limitBytes: error.data?.limitBytes,
|
||||
receivedBytes: error.data?.receivedBytes,
|
||||
});
|
||||
throw new BadRequest('Response too large');
|
||||
}
|
||||
await this.cache.set(cachedUrl, buffer.toString('base64'), {
|
||||
ttl: CACHE_TTL,
|
||||
});
|
||||
const contentDisposition = response.headers.get('Content-Disposition');
|
||||
return resp
|
||||
.status(200)
|
||||
.header({
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
'Content-Type': contentType,
|
||||
'Content-Disposition': contentDisposition,
|
||||
})
|
||||
.send(buffer);
|
||||
throw error;
|
||||
}
|
||||
await this.cache.set(cachedUrl, buffer.toString('base64'), {
|
||||
ttl: CACHE_TTL,
|
||||
});
|
||||
const contentDisposition = response.headers.get('Content-Disposition');
|
||||
resp.header({
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
});
|
||||
if (contentDisposition) {
|
||||
resp.setHeader('Content-Disposition', contentDisposition);
|
||||
}
|
||||
applyAttachHeaders(resp, { buffer });
|
||||
const contentType = resp.getHeader('Content-Type') as string | undefined;
|
||||
if (contentType?.startsWith('image/')) {
|
||||
return resp.status(200).send(buffer);
|
||||
} else {
|
||||
throw new BadRequest('Invalid content type');
|
||||
}
|
||||
|
||||
@@ -607,50 +607,10 @@ type CopilotModelsType {
|
||||
proModels: [CopilotModelType!]!
|
||||
}
|
||||
|
||||
input CopilotPromptConfigInput {
|
||||
frequencyPenalty: Float
|
||||
presencePenalty: Float
|
||||
temperature: Float
|
||||
topP: Float
|
||||
}
|
||||
|
||||
type CopilotPromptConfigType {
|
||||
frequencyPenalty: Float
|
||||
presencePenalty: Float
|
||||
temperature: Float
|
||||
topP: Float
|
||||
}
|
||||
|
||||
input CopilotPromptMessageInput {
|
||||
content: String!
|
||||
params: JSON
|
||||
role: CopilotPromptMessageRole!
|
||||
}
|
||||
|
||||
enum CopilotPromptMessageRole {
|
||||
assistant
|
||||
system
|
||||
user
|
||||
}
|
||||
|
||||
type CopilotPromptMessageType {
|
||||
content: String!
|
||||
params: JSON
|
||||
role: CopilotPromptMessageRole!
|
||||
}
|
||||
|
||||
type CopilotPromptNotFoundDataType {
|
||||
name: String!
|
||||
}
|
||||
|
||||
type CopilotPromptType {
|
||||
action: String
|
||||
config: CopilotPromptConfigType
|
||||
messages: [CopilotPromptMessageType!]!
|
||||
model: String!
|
||||
name: String!
|
||||
}
|
||||
|
||||
type CopilotProviderNotSupportedDataType {
|
||||
kind: String!
|
||||
provider: String!
|
||||
@@ -747,14 +707,6 @@ input CreateCheckoutSessionInput {
|
||||
variant: SubscriptionVariant
|
||||
}
|
||||
|
||||
input CreateCopilotPromptInput {
|
||||
action: String
|
||||
config: CopilotPromptConfigInput
|
||||
messages: [CopilotPromptMessageInput!]!
|
||||
model: String!
|
||||
name: String!
|
||||
}
|
||||
|
||||
input CreateUserInput {
|
||||
email: String!
|
||||
name: String
|
||||
@@ -1551,11 +1503,11 @@ type Mutation {
|
||||
"""Create a chat message"""
|
||||
createCopilotMessage(options: CreateChatMessageInput!): String!
|
||||
|
||||
"""Create a copilot prompt"""
|
||||
createCopilotPrompt(input: CreateCopilotPromptInput!): CopilotPromptType!
|
||||
|
||||
"""Create a chat session"""
|
||||
createCopilotSession(options: CreateChatSessionInput!): String!
|
||||
createCopilotSession(options: CreateChatSessionInput!): String! @deprecated(reason: "use `createCopilotSessionWithHistory` instead")
|
||||
|
||||
"""Create a chat session and return full session payload"""
|
||||
createCopilotSessionWithHistory(options: CreateChatSessionInput!): CopilotHistories!
|
||||
|
||||
"""Create a stripe customer portal to manage payment methods"""
|
||||
createCustomerPortal: String!
|
||||
@@ -1672,9 +1624,6 @@ type Mutation {
|
||||
"""Update a comment content"""
|
||||
updateComment(input: CommentUpdateInput!): Boolean!
|
||||
|
||||
"""Update a copilot prompt"""
|
||||
updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType!
|
||||
|
||||
"""Update a chat session"""
|
||||
updateCopilotSession(options: UpdateChatSessionInput!): String!
|
||||
updateDocDefaultRole(input: UpdateDocDefaultRoleInput!): Boolean!
|
||||
@@ -1923,9 +1872,6 @@ type Query {
|
||||
|
||||
"""get workspace invitation info"""
|
||||
getInviteInfo(inviteId: String!): InvitationType!
|
||||
|
||||
"""List all copilot prompts"""
|
||||
listCopilotPrompts: [CopilotPromptType!]!
|
||||
prices: [SubscriptionPrice!]!
|
||||
|
||||
"""Get public user by id"""
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
query getPrompts {
|
||||
listCopilotPrompts {
|
||||
name
|
||||
model
|
||||
action
|
||||
config {
|
||||
frequencyPenalty
|
||||
presencePenalty
|
||||
temperature
|
||||
topP
|
||||
}
|
||||
messages {
|
||||
role
|
||||
content
|
||||
params
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
mutation updatePrompt(
|
||||
$name: String!
|
||||
$messages: [CopilotPromptMessageInput!]!
|
||||
) {
|
||||
updateCopilotPrompt(name: $name, messages: $messages) {
|
||||
name
|
||||
model
|
||||
action
|
||||
config {
|
||||
frequencyPenalty
|
||||
presencePenalty
|
||||
temperature
|
||||
topP
|
||||
}
|
||||
messages {
|
||||
role
|
||||
content
|
||||
params
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/copilot.gql"
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
|
||||
query getCopilotDocSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/copilot.gql"
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
|
||||
query getCopilotPinnedSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/copilot.gql"
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
|
||||
query getCopilotWorkspaceSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/copilot.gql"
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
|
||||
query getCopilotHistories(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
#import "./fragments/copilot-chat-history.gql"
|
||||
|
||||
mutation createCopilotSessionWithHistory($options: CreateChatSessionInput!) {
|
||||
createCopilotSessionWithHistory(options: $options) {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/copilot.gql"
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
|
||||
query getCopilotLatestDocSession(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/copilot.gql"
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
|
||||
query getCopilotSession(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/copilot.gql"
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
|
||||
query getCopilotRecentSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/copilot.gql"
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
|
||||
query getCopilotSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
fragment CopilotChatHistory on CopilotHistories {
|
||||
sessionId
|
||||
workspaceId
|
||||
docId
|
||||
parentSessionId
|
||||
promptName
|
||||
model
|
||||
optionalModels
|
||||
action
|
||||
pinned
|
||||
title
|
||||
tokens
|
||||
messages {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
streamObjects {
|
||||
type
|
||||
textDelta
|
||||
toolCallId
|
||||
toolName
|
||||
args
|
||||
result
|
||||
}
|
||||
createdAt
|
||||
}
|
||||
createdAt
|
||||
updatedAt
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
fragment CopilotChatMessage on ChatMessage {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
streamObjects {
|
||||
type
|
||||
textDelta
|
||||
toolCallId
|
||||
toolName
|
||||
args
|
||||
result
|
||||
}
|
||||
createdAt
|
||||
}
|
||||
|
||||
fragment CopilotChatHistory on CopilotHistories {
|
||||
sessionId
|
||||
workspaceId
|
||||
docId
|
||||
parentSessionId
|
||||
promptName
|
||||
model
|
||||
optionalModels
|
||||
action
|
||||
pinned
|
||||
title
|
||||
tokens
|
||||
messages {
|
||||
...CopilotChatMessage
|
||||
}
|
||||
createdAt
|
||||
updatedAt
|
||||
}
|
||||
|
||||
fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
hasPreviousPage
|
||||
startCursor
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
cursor
|
||||
node {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
#import "./copilot-chat-history.gql"
|
||||
|
||||
fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
hasPreviousPage
|
||||
startCursor
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
cursor
|
||||
node {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,21 +6,6 @@ export interface GraphQLQuery {
|
||||
file?: boolean;
|
||||
deprecations?: string[];
|
||||
}
|
||||
export const copilotChatMessageFragment = `fragment CopilotChatMessage on ChatMessage {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
streamObjects {
|
||||
type
|
||||
textDelta
|
||||
toolCallId
|
||||
toolName
|
||||
args
|
||||
result
|
||||
}
|
||||
createdAt
|
||||
}`;
|
||||
export const copilotChatHistoryFragment = `fragment CopilotChatHistory on CopilotHistories {
|
||||
sessionId
|
||||
workspaceId
|
||||
@@ -34,25 +19,23 @@ export const copilotChatHistoryFragment = `fragment CopilotChatHistory on Copilo
|
||||
title
|
||||
tokens
|
||||
messages {
|
||||
...CopilotChatMessage
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
streamObjects {
|
||||
type
|
||||
textDelta
|
||||
toolCallId
|
||||
toolName
|
||||
args
|
||||
result
|
||||
}
|
||||
createdAt
|
||||
}
|
||||
createdAt
|
||||
updatedAt
|
||||
}`;
|
||||
export const paginatedCopilotChatsFragment = `fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
hasPreviousPage
|
||||
startCursor
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
cursor
|
||||
node {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
}`;
|
||||
export const credentialsRequirementsFragment = `fragment CredentialsRequirements on CredentialsRequirementType {
|
||||
password {
|
||||
...PasswordLimits
|
||||
@@ -94,6 +77,20 @@ export const currentUserProfileFragment = `fragment CurrentUserProfile on UserTy
|
||||
}
|
||||
}
|
||||
}`;
|
||||
export const paginatedCopilotChatsFragment = `fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
hasPreviousPage
|
||||
startCursor
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
cursor
|
||||
node {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
}${copilotChatHistoryFragment}`;
|
||||
export const passwordLimitsFragment = `fragment PasswordLimits on PasswordLimitsType {
|
||||
minLength
|
||||
maxLength
|
||||
@@ -404,52 +401,6 @@ export const appConfigQuery = {
|
||||
}`,
|
||||
};
|
||||
|
||||
export const getPromptsQuery = {
|
||||
id: 'getPromptsQuery' as const,
|
||||
op: 'getPrompts',
|
||||
query: `query getPrompts {
|
||||
listCopilotPrompts {
|
||||
name
|
||||
model
|
||||
action
|
||||
config {
|
||||
frequencyPenalty
|
||||
presencePenalty
|
||||
temperature
|
||||
topP
|
||||
}
|
||||
messages {
|
||||
role
|
||||
content
|
||||
params
|
||||
}
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const updatePromptMutation = {
|
||||
id: 'updatePromptMutation' as const,
|
||||
op: 'updatePrompt',
|
||||
query: `mutation updatePrompt($name: String!, $messages: [CopilotPromptMessageInput!]!) {
|
||||
updateCopilotPrompt(name: $name, messages: $messages) {
|
||||
name
|
||||
model
|
||||
action
|
||||
config {
|
||||
frequencyPenalty
|
||||
presencePenalty
|
||||
temperature
|
||||
topP
|
||||
}
|
||||
messages {
|
||||
role
|
||||
content
|
||||
params
|
||||
}
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const createUserMutation = {
|
||||
id: 'createUserMutation' as const,
|
||||
op: 'createUser',
|
||||
@@ -1411,8 +1362,6 @@ export const getCopilotDocSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1432,8 +1381,6 @@ export const getCopilotPinnedSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1449,8 +1396,6 @@ export const getCopilotWorkspaceSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1466,8 +1411,6 @@ export const getCopilotHistoriesQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1596,12 +1539,24 @@ export const cleanupCopilotSessionMutation = {
|
||||
}`,
|
||||
};
|
||||
|
||||
export const createCopilotSessionWithHistoryMutation = {
|
||||
id: 'createCopilotSessionWithHistoryMutation' as const,
|
||||
op: 'createCopilotSessionWithHistory',
|
||||
query: `mutation createCopilotSessionWithHistory($options: CreateChatSessionInput!) {
|
||||
createCopilotSessionWithHistory(options: $options) {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
${copilotChatHistoryFragment}`,
|
||||
};
|
||||
|
||||
export const createCopilotSessionMutation = {
|
||||
id: 'createCopilotSessionMutation' as const,
|
||||
op: 'createCopilotSession',
|
||||
query: `mutation createCopilotSession($options: CreateChatSessionInput!) {
|
||||
createCopilotSession(options: $options)
|
||||
}`,
|
||||
deprecations: ["'createCopilotSession' is deprecated: use `createCopilotSessionWithHistory` instead"],
|
||||
};
|
||||
|
||||
export const forkCopilotSessionMutation = {
|
||||
@@ -1628,8 +1583,6 @@ export const getCopilotLatestDocSessionQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1645,8 +1598,6 @@ export const getCopilotSessionQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1665,8 +1616,6 @@ export const getCopilotRecentSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1690,8 +1639,6 @@ export const getCopilotSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
|
||||
@@ -725,54 +725,11 @@ export interface CopilotModelsType {
|
||||
proModels: Array<CopilotModelType>;
|
||||
}
|
||||
|
||||
export interface CopilotPromptConfigInput {
|
||||
frequencyPenalty?: InputMaybe<Scalars['Float']['input']>;
|
||||
presencePenalty?: InputMaybe<Scalars['Float']['input']>;
|
||||
temperature?: InputMaybe<Scalars['Float']['input']>;
|
||||
topP?: InputMaybe<Scalars['Float']['input']>;
|
||||
}
|
||||
|
||||
export interface CopilotPromptConfigType {
|
||||
__typename?: 'CopilotPromptConfigType';
|
||||
frequencyPenalty: Maybe<Scalars['Float']['output']>;
|
||||
presencePenalty: Maybe<Scalars['Float']['output']>;
|
||||
temperature: Maybe<Scalars['Float']['output']>;
|
||||
topP: Maybe<Scalars['Float']['output']>;
|
||||
}
|
||||
|
||||
export interface CopilotPromptMessageInput {
|
||||
content: Scalars['String']['input'];
|
||||
params?: InputMaybe<Scalars['JSON']['input']>;
|
||||
role: CopilotPromptMessageRole;
|
||||
}
|
||||
|
||||
export enum CopilotPromptMessageRole {
|
||||
assistant = 'assistant',
|
||||
system = 'system',
|
||||
user = 'user',
|
||||
}
|
||||
|
||||
export interface CopilotPromptMessageType {
|
||||
__typename?: 'CopilotPromptMessageType';
|
||||
content: Scalars['String']['output'];
|
||||
params: Maybe<Scalars['JSON']['output']>;
|
||||
role: CopilotPromptMessageRole;
|
||||
}
|
||||
|
||||
export interface CopilotPromptNotFoundDataType {
|
||||
__typename?: 'CopilotPromptNotFoundDataType';
|
||||
name: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface CopilotPromptType {
|
||||
__typename?: 'CopilotPromptType';
|
||||
action: Maybe<Scalars['String']['output']>;
|
||||
config: Maybe<CopilotPromptConfigType>;
|
||||
messages: Array<CopilotPromptMessageType>;
|
||||
model: Scalars['String']['output'];
|
||||
name: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface CopilotProviderNotSupportedDataType {
|
||||
__typename?: 'CopilotProviderNotSupportedDataType';
|
||||
kind: Scalars['String']['output'];
|
||||
@@ -884,14 +841,6 @@ export interface CreateCheckoutSessionInput {
|
||||
variant?: InputMaybe<SubscriptionVariant>;
|
||||
}
|
||||
|
||||
export interface CreateCopilotPromptInput {
|
||||
action?: InputMaybe<Scalars['String']['input']>;
|
||||
config?: InputMaybe<CopilotPromptConfigInput>;
|
||||
messages: Array<CopilotPromptMessageInput>;
|
||||
model: Scalars['String']['input'];
|
||||
name: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface CreateUserInput {
|
||||
email: Scalars['String']['input'];
|
||||
name?: InputMaybe<Scalars['String']['input']>;
|
||||
@@ -1752,10 +1701,13 @@ export interface Mutation {
|
||||
createCopilotContext: Scalars['String']['output'];
|
||||
/** Create a chat message */
|
||||
createCopilotMessage: Scalars['String']['output'];
|
||||
/** Create a copilot prompt */
|
||||
createCopilotPrompt: CopilotPromptType;
|
||||
/** Create a chat session */
|
||||
/**
|
||||
* Create a chat session
|
||||
* @deprecated use `createCopilotSessionWithHistory` instead
|
||||
*/
|
||||
createCopilotSession: Scalars['String']['output'];
|
||||
/** Create a chat session and return full session payload */
|
||||
createCopilotSessionWithHistory: CopilotHistories;
|
||||
/** Create a stripe customer portal to manage payment methods */
|
||||
createCustomerPortal: Scalars['String']['output'];
|
||||
createInviteLink: InviteLink;
|
||||
@@ -1845,8 +1797,6 @@ export interface Mutation {
|
||||
updateCalendarAccount: Maybe<CalendarAccountObjectType>;
|
||||
/** Update a comment content */
|
||||
updateComment: Scalars['Boolean']['output'];
|
||||
/** Update a copilot prompt */
|
||||
updateCopilotPrompt: CopilotPromptType;
|
||||
/** Update a chat session */
|
||||
updateCopilotSession: Scalars['String']['output'];
|
||||
updateDocDefaultRole: Scalars['Boolean']['output'];
|
||||
@@ -1998,11 +1948,11 @@ export interface MutationCreateCopilotMessageArgs {
|
||||
options: CreateChatMessageInput;
|
||||
}
|
||||
|
||||
export interface MutationCreateCopilotPromptArgs {
|
||||
input: CreateCopilotPromptInput;
|
||||
export interface MutationCreateCopilotSessionArgs {
|
||||
options: CreateChatSessionInput;
|
||||
}
|
||||
|
||||
export interface MutationCreateCopilotSessionArgs {
|
||||
export interface MutationCreateCopilotSessionWithHistoryArgs {
|
||||
options: CreateChatSessionInput;
|
||||
}
|
||||
|
||||
@@ -2262,11 +2212,6 @@ export interface MutationUpdateCommentArgs {
|
||||
input: CommentUpdateInput;
|
||||
}
|
||||
|
||||
export interface MutationUpdateCopilotPromptArgs {
|
||||
messages: Array<CopilotPromptMessageInput>;
|
||||
name: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationUpdateCopilotSessionArgs {
|
||||
options: UpdateChatSessionInput;
|
||||
}
|
||||
@@ -2554,8 +2499,6 @@ export interface Query {
|
||||
error: ErrorDataUnion;
|
||||
/** get workspace invitation info */
|
||||
getInviteInfo: InvitationType;
|
||||
/** List all copilot prompts */
|
||||
listCopilotPrompts: Array<CopilotPromptType>;
|
||||
prices: Array<SubscriptionPrice>;
|
||||
/** Get public user by id */
|
||||
publicUserById: Maybe<PublicUserType>;
|
||||
@@ -3886,59 +3829,6 @@ export type AppConfigQueryVariables = Exact<{ [key: string]: never }>;
|
||||
|
||||
export type AppConfigQuery = { __typename?: 'Query'; appConfig: any };
|
||||
|
||||
export type GetPromptsQueryVariables = Exact<{ [key: string]: never }>;
|
||||
|
||||
export type GetPromptsQuery = {
|
||||
__typename?: 'Query';
|
||||
listCopilotPrompts: Array<{
|
||||
__typename?: 'CopilotPromptType';
|
||||
name: string;
|
||||
model: string;
|
||||
action: string | null;
|
||||
config: {
|
||||
__typename?: 'CopilotPromptConfigType';
|
||||
frequencyPenalty: number | null;
|
||||
presencePenalty: number | null;
|
||||
temperature: number | null;
|
||||
topP: number | null;
|
||||
} | null;
|
||||
messages: Array<{
|
||||
__typename?: 'CopilotPromptMessageType';
|
||||
role: CopilotPromptMessageRole;
|
||||
content: string;
|
||||
params: Record<string, string> | null;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
|
||||
export type UpdatePromptMutationVariables = Exact<{
|
||||
name: Scalars['String']['input'];
|
||||
messages: Array<CopilotPromptMessageInput> | CopilotPromptMessageInput;
|
||||
}>;
|
||||
|
||||
export type UpdatePromptMutation = {
|
||||
__typename?: 'Mutation';
|
||||
updateCopilotPrompt: {
|
||||
__typename?: 'CopilotPromptType';
|
||||
name: string;
|
||||
model: string;
|
||||
action: string | null;
|
||||
config: {
|
||||
__typename?: 'CopilotPromptConfigType';
|
||||
frequencyPenalty: number | null;
|
||||
presencePenalty: number | null;
|
||||
temperature: number | null;
|
||||
topP: number | null;
|
||||
} | null;
|
||||
messages: Array<{
|
||||
__typename?: 'CopilotPromptMessageType';
|
||||
role: CopilotPromptMessageRole;
|
||||
content: string;
|
||||
params: Record<string, string> | null;
|
||||
}>;
|
||||
};
|
||||
};
|
||||
|
||||
export type CreateUserMutationVariables = Exact<{
|
||||
input: CreateUserInput;
|
||||
}>;
|
||||
@@ -5425,6 +5315,47 @@ export type CleanupCopilotSessionMutation = {
|
||||
cleanupCopilotSession: Array<string>;
|
||||
};
|
||||
|
||||
export type CreateCopilotSessionWithHistoryMutationVariables = Exact<{
|
||||
options: CreateChatSessionInput;
|
||||
}>;
|
||||
|
||||
export type CreateCopilotSessionWithHistoryMutation = {
|
||||
__typename?: 'Mutation';
|
||||
createCopilotSessionWithHistory: {
|
||||
__typename?: 'CopilotHistories';
|
||||
sessionId: string;
|
||||
workspaceId: string;
|
||||
docId: string | null;
|
||||
parentSessionId: string | null;
|
||||
promptName: string;
|
||||
model: string;
|
||||
optionalModels: Array<string>;
|
||||
action: string | null;
|
||||
pinned: boolean;
|
||||
title: string | null;
|
||||
tokens: number;
|
||||
createdAt: string;
|
||||
updatedAt: string;
|
||||
messages: Array<{
|
||||
__typename?: 'ChatMessage';
|
||||
id: string | null;
|
||||
role: string;
|
||||
content: string;
|
||||
attachments: Array<string> | null;
|
||||
createdAt: string;
|
||||
streamObjects: Array<{
|
||||
__typename?: 'StreamObject';
|
||||
type: string;
|
||||
textDelta: string | null;
|
||||
toolCallId: string | null;
|
||||
toolName: string | null;
|
||||
args: Record<string, string> | null;
|
||||
result: Record<string, string> | null;
|
||||
}> | null;
|
||||
}>;
|
||||
};
|
||||
};
|
||||
|
||||
export type CreateCopilotSessionMutationVariables = Exact<{
|
||||
options: CreateChatSessionInput;
|
||||
}>;
|
||||
@@ -5934,24 +5865,6 @@ export type GetDocRolePermissionsQuery = {
|
||||
};
|
||||
};
|
||||
|
||||
export type CopilotChatMessageFragment = {
|
||||
__typename?: 'ChatMessage';
|
||||
id: string | null;
|
||||
role: string;
|
||||
content: string;
|
||||
attachments: Array<string> | null;
|
||||
createdAt: string;
|
||||
streamObjects: Array<{
|
||||
__typename?: 'StreamObject';
|
||||
type: string;
|
||||
textDelta: string | null;
|
||||
toolCallId: string | null;
|
||||
toolName: string | null;
|
||||
args: Record<string, string> | null;
|
||||
result: Record<string, string> | null;
|
||||
}> | null;
|
||||
};
|
||||
|
||||
export type CopilotChatHistoryFragment = {
|
||||
__typename?: 'CopilotHistories';
|
||||
sessionId: string;
|
||||
@@ -5986,6 +5899,52 @@ export type CopilotChatHistoryFragment = {
|
||||
}>;
|
||||
};
|
||||
|
||||
export type CredentialsRequirementsFragment = {
|
||||
__typename?: 'CredentialsRequirementType';
|
||||
password: {
|
||||
__typename?: 'PasswordLimitsType';
|
||||
minLength: number;
|
||||
maxLength: number;
|
||||
};
|
||||
};
|
||||
|
||||
export type CurrentUserProfileFragment = {
|
||||
__typename?: 'UserType';
|
||||
id: string;
|
||||
name: string;
|
||||
email: string;
|
||||
avatarUrl: string | null;
|
||||
emailVerified: boolean;
|
||||
features: Array<FeatureType>;
|
||||
settings: {
|
||||
__typename?: 'UserSettingsType';
|
||||
receiveInvitationEmail: boolean;
|
||||
receiveMentionEmail: boolean;
|
||||
receiveCommentEmail: boolean;
|
||||
};
|
||||
quota: {
|
||||
__typename?: 'UserQuotaType';
|
||||
name: string;
|
||||
blobLimit: number;
|
||||
storageQuota: number;
|
||||
historyPeriod: number;
|
||||
memberLimit: number;
|
||||
humanReadable: {
|
||||
__typename?: 'UserQuotaHumanReadableType';
|
||||
name: string;
|
||||
blobLimit: string;
|
||||
storageQuota: string;
|
||||
historyPeriod: string;
|
||||
memberLimit: string;
|
||||
};
|
||||
};
|
||||
quotaUsage: { __typename?: 'UserQuotaUsageType'; storageQuota: number };
|
||||
copilot: {
|
||||
__typename?: 'Copilot';
|
||||
quota: { __typename?: 'CopilotQuota'; limit: number | null; used: number };
|
||||
};
|
||||
};
|
||||
|
||||
export type PaginatedCopilotChatsFragment = {
|
||||
__typename?: 'PaginatedCopilotHistoriesType';
|
||||
pageInfo: {
|
||||
@@ -6034,52 +5993,6 @@ export type PaginatedCopilotChatsFragment = {
|
||||
}>;
|
||||
};
|
||||
|
||||
export type CredentialsRequirementsFragment = {
|
||||
__typename?: 'CredentialsRequirementType';
|
||||
password: {
|
||||
__typename?: 'PasswordLimitsType';
|
||||
minLength: number;
|
||||
maxLength: number;
|
||||
};
|
||||
};
|
||||
|
||||
export type CurrentUserProfileFragment = {
|
||||
__typename?: 'UserType';
|
||||
id: string;
|
||||
name: string;
|
||||
email: string;
|
||||
avatarUrl: string | null;
|
||||
emailVerified: boolean;
|
||||
features: Array<FeatureType>;
|
||||
settings: {
|
||||
__typename?: 'UserSettingsType';
|
||||
receiveInvitationEmail: boolean;
|
||||
receiveMentionEmail: boolean;
|
||||
receiveCommentEmail: boolean;
|
||||
};
|
||||
quota: {
|
||||
__typename?: 'UserQuotaType';
|
||||
name: string;
|
||||
blobLimit: number;
|
||||
storageQuota: number;
|
||||
historyPeriod: number;
|
||||
memberLimit: number;
|
||||
humanReadable: {
|
||||
__typename?: 'UserQuotaHumanReadableType';
|
||||
name: string;
|
||||
blobLimit: string;
|
||||
storageQuota: string;
|
||||
historyPeriod: string;
|
||||
memberLimit: string;
|
||||
};
|
||||
};
|
||||
quotaUsage: { __typename?: 'UserQuotaUsageType'; storageQuota: number };
|
||||
copilot: {
|
||||
__typename?: 'Copilot';
|
||||
quota: { __typename?: 'CopilotQuota'; limit: number | null; used: number };
|
||||
};
|
||||
};
|
||||
|
||||
export type PasswordLimitsFragment = {
|
||||
__typename?: 'PasswordLimitsType';
|
||||
minLength: number;
|
||||
@@ -7623,11 +7536,6 @@ export type Queries =
|
||||
variables: AppConfigQueryVariables;
|
||||
response: AppConfigQuery;
|
||||
}
|
||||
| {
|
||||
name: 'getPromptsQuery';
|
||||
variables: GetPromptsQueryVariables;
|
||||
response: GetPromptsQuery;
|
||||
}
|
||||
| {
|
||||
name: 'getUserByEmailQuery';
|
||||
variables: GetUserByEmailQueryVariables;
|
||||
@@ -8035,11 +7943,6 @@ export type Mutations =
|
||||
variables: CreateChangePasswordUrlMutationVariables;
|
||||
response: CreateChangePasswordUrlMutation;
|
||||
}
|
||||
| {
|
||||
name: 'updatePromptMutation';
|
||||
variables: UpdatePromptMutationVariables;
|
||||
response: UpdatePromptMutation;
|
||||
}
|
||||
| {
|
||||
name: 'createUserMutation';
|
||||
variables: CreateUserMutationVariables;
|
||||
@@ -8275,6 +8178,11 @@ export type Mutations =
|
||||
variables: CleanupCopilotSessionMutationVariables;
|
||||
response: CleanupCopilotSessionMutation;
|
||||
}
|
||||
| {
|
||||
name: 'createCopilotSessionWithHistoryMutation';
|
||||
variables: CreateCopilotSessionWithHistoryMutationVariables;
|
||||
response: CreateCopilotSessionWithHistoryMutation;
|
||||
}
|
||||
| {
|
||||
name: 'createCopilotSessionMutation';
|
||||
variables: CreateCopilotSessionMutationVariables;
|
||||
|
||||
@@ -313,6 +313,14 @@
|
||||
"type": "Object",
|
||||
"desc": "Use custom models in scenarios and override default settings."
|
||||
},
|
||||
"providers.profiles": {
|
||||
"type": "Array",
|
||||
"desc": "The profile list for copilot providers."
|
||||
},
|
||||
"providers.defaults": {
|
||||
"type": "Object",
|
||||
"desc": "The default provider ids for model output types and global fallback."
|
||||
},
|
||||
"providers.openai": {
|
||||
"type": "Object",
|
||||
"desc": "The config for the openai provider.",
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
import { ScrollArea } from '@affine/admin/components/ui/scroll-area';
|
||||
import { Separator } from '@affine/admin/components/ui/separator';
|
||||
import { Textarea } from '@affine/admin/components/ui/textarea';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
import { RightPanelHeader } from '../header';
|
||||
import { useRightPanel } from '../panel/context';
|
||||
import type { Prompt } from './prompts';
|
||||
import { usePrompt } from './use-prompt';
|
||||
|
||||
export function EditPrompt({
|
||||
item,
|
||||
setCanSave,
|
||||
}: {
|
||||
item: Prompt;
|
||||
setCanSave: (changed: boolean) => void;
|
||||
}) {
|
||||
const { closePanel } = useRightPanel();
|
||||
|
||||
const [messages, setMessages] = useState(item.messages);
|
||||
const { updatePrompt } = usePrompt();
|
||||
|
||||
const disableSave = useMemo(
|
||||
() => JSON.stringify(messages) === JSON.stringify(item.messages),
|
||||
[item.messages, messages]
|
||||
);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(e: React.ChangeEvent<HTMLTextAreaElement>, index: number) => {
|
||||
const newMessages = [...messages];
|
||||
newMessages[index] = {
|
||||
...newMessages[index],
|
||||
content: e.target.value,
|
||||
};
|
||||
setMessages(newMessages);
|
||||
setCanSave(!disableSave);
|
||||
},
|
||||
[disableSave, messages, setCanSave]
|
||||
);
|
||||
const handleClose = useCallback(() => {
|
||||
setMessages(item.messages);
|
||||
closePanel();
|
||||
}, [closePanel, item.messages]);
|
||||
|
||||
const onConfirm = useCallback(() => {
|
||||
if (!disableSave) {
|
||||
updatePrompt({ name: item.name, messages });
|
||||
}
|
||||
handleClose();
|
||||
}, [disableSave, handleClose, item.name, messages, updatePrompt]);
|
||||
|
||||
useEffect(() => {
|
||||
setMessages(item.messages);
|
||||
}, [item.messages]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full gap-1">
|
||||
<RightPanelHeader
|
||||
title="Edit Prompt"
|
||||
handleClose={handleClose}
|
||||
handleConfirm={onConfirm}
|
||||
canSave={!disableSave}
|
||||
/>
|
||||
<ScrollArea>
|
||||
<div className="grid">
|
||||
<div className="px-5 py-4 overflow-y-auto space-y-[10px] flex flex-col gap-5">
|
||||
<div className="flex flex-col">
|
||||
<div className="text-sm font-medium">Name</div>
|
||||
<div className="text-sm font-normal text-muted-foreground">
|
||||
{item.name}
|
||||
</div>
|
||||
</div>
|
||||
{item.action ? (
|
||||
<div className="flex flex-col">
|
||||
<div className="text-sm font-medium">Action</div>
|
||||
<div className="text-sm font-normal text-muted-foreground">
|
||||
{item.action}
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
<div className="flex flex-col">
|
||||
<div className="text-sm font-medium">Model</div>
|
||||
<div className="text-sm font-normal text-muted-foreground">
|
||||
{item.model}
|
||||
</div>
|
||||
</div>
|
||||
{item.config ? (
|
||||
<div className="flex flex-col border rounded p-3">
|
||||
<div className="text-sm font-medium">Config</div>
|
||||
{Object.entries(item.config).map(([key, value], index) => (
|
||||
<div key={key} className="flex flex-col">
|
||||
{index !== 0 && <Separator />}
|
||||
<span className="text-sm font-normal">{key}</span>
|
||||
<span className="text-sm font-normal text-muted-foreground">
|
||||
{value?.toString()}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="px-5 py-4 overflow-y-auto space-y-[10px] flex flex-col">
|
||||
<div className="text-sm font-medium">Messages</div>
|
||||
{messages.map((message, index) => (
|
||||
<div key={message.content} className="flex flex-col gap-3">
|
||||
{index !== 0 && <Separator />}
|
||||
<div>
|
||||
<div className="text-sm font-normal">Role</div>
|
||||
<div className="text-sm font-normal text-muted-foreground">
|
||||
{message.role}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{message.params ? (
|
||||
<div>
|
||||
<div className="text-sm font-medium">Params</div>
|
||||
{Object.entries(message.params).map(
|
||||
([key, value], index) => (
|
||||
<div key={key} className="flex flex-col">
|
||||
{index !== 0 && <Separator />}
|
||||
<span className="text-sm font-normal">{key}</span>
|
||||
<span
|
||||
className="text-sm font-normal text-muted-foreground"
|
||||
style={{ overflowWrap: 'break-word' }}
|
||||
>
|
||||
{value.toString()}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
) : null}
|
||||
<div className="text-sm font-normal">Content</div>
|
||||
<Textarea
|
||||
className=" min-h-48"
|
||||
value={message.content}
|
||||
onChange={e => handleChange(e, index)}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -32,7 +32,6 @@ function AiPage() {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{/* <Prompts /> */}
|
||||
</ScrollAreaPrimitive.Viewport>
|
||||
<ScrollAreaPrimitive.ScrollAreaScrollbar
|
||||
className={cn(
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
import { Button } from '@affine/admin/components/ui/button';
|
||||
import { Separator } from '@affine/admin/components/ui/separator';
|
||||
import type { CopilotPromptMessageRole } from '@affine/graphql';
|
||||
import { useCallback, useState } from 'react';
|
||||
|
||||
import { DiscardChanges } from '../../components/shared/discard-changes';
|
||||
import { useRightPanel } from '../panel/context';
|
||||
import { EditPrompt } from './edit-prompt';
|
||||
import { usePrompt } from './use-prompt';
|
||||
|
||||
export type Prompt = {
|
||||
__typename?: 'CopilotPromptType';
|
||||
name: string;
|
||||
model: string;
|
||||
action: string | null;
|
||||
config: {
|
||||
__typename?: 'CopilotPromptConfigType';
|
||||
frequencyPenalty: number | null;
|
||||
presencePenalty: number | null;
|
||||
temperature: number | null;
|
||||
topP: number | null;
|
||||
} | null;
|
||||
messages: Array<{
|
||||
__typename?: 'CopilotPromptMessageType';
|
||||
role: CopilotPromptMessageRole;
|
||||
content: string;
|
||||
params: Record<string, string> | null;
|
||||
}>;
|
||||
};
|
||||
|
||||
export function Prompts() {
|
||||
const { prompts: list } = usePrompt();
|
||||
return (
|
||||
<div className="flex flex-col h-full gap-3 py-5 px-6 w-full">
|
||||
<div className="flex items-center">
|
||||
<span className="text-xl font-semibold">Prompts</span>
|
||||
</div>
|
||||
<div className="flex-grow overflow-y-auto space-y-[10px]">
|
||||
<div className="flex flex-col rounded-md border w-full">
|
||||
{list.map((item, index) => (
|
||||
<PromptRow
|
||||
key={`${item.name}-${index}`}
|
||||
item={item}
|
||||
index={index}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export const PromptRow = ({ item, index }: { item: Prompt; index: number }) => {
|
||||
const { setPanelContent, openPanel, isOpen } = useRightPanel();
|
||||
const [dialogOpen, setDialogOpen] = useState(false);
|
||||
const [canSave, setCanSave] = useState(false);
|
||||
|
||||
const handleDiscardChangesCancel = useCallback(() => {
|
||||
setDialogOpen(false);
|
||||
setCanSave(false);
|
||||
}, []);
|
||||
|
||||
const handleConfirm = useCallback(
|
||||
(item: Prompt) => {
|
||||
setPanelContent(<EditPrompt item={item} setCanSave={setCanSave} />);
|
||||
if (dialogOpen) {
|
||||
handleDiscardChangesCancel();
|
||||
}
|
||||
|
||||
if (!isOpen) {
|
||||
openPanel();
|
||||
}
|
||||
},
|
||||
[dialogOpen, handleDiscardChangesCancel, isOpen, openPanel, setPanelContent]
|
||||
);
|
||||
|
||||
const handleEdit = useCallback(
|
||||
(item: Prompt) => {
|
||||
if (isOpen && canSave) {
|
||||
setDialogOpen(true);
|
||||
} else {
|
||||
handleConfirm(item);
|
||||
}
|
||||
},
|
||||
[canSave, handleConfirm, isOpen]
|
||||
);
|
||||
return (
|
||||
<div>
|
||||
{index !== 0 && <Separator />}
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="flex flex-col gap-1 w-full items-start px-6 py-[14px] h-full "
|
||||
onClick={() => handleEdit(item)}
|
||||
>
|
||||
<div>{item.name}</div>
|
||||
<div className="text-left w-full opacity-50 overflow-hidden text-ellipsis whitespace-nowrap break-words text-nowrap">
|
||||
{item.messages.flatMap(message => message.content).join(' ')}
|
||||
</div>
|
||||
</Button>
|
||||
<DiscardChanges
|
||||
open={dialogOpen}
|
||||
onOpenChange={setDialogOpen}
|
||||
onClose={handleDiscardChangesCancel}
|
||||
onConfirm={() => handleConfirm(item)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,51 +0,0 @@
|
||||
import {
|
||||
useMutateQueryResource,
|
||||
useMutation,
|
||||
} from '@affine/admin/use-mutation';
|
||||
import { useQuery } from '@affine/admin/use-query';
|
||||
import { useAsyncCallback } from '@affine/core/components/hooks/affine-async-hooks';
|
||||
import { getPromptsQuery, updatePromptMutation } from '@affine/graphql';
|
||||
import { toast } from 'sonner';
|
||||
|
||||
import type { Prompt } from './prompts';
|
||||
|
||||
export const usePrompt = () => {
|
||||
const { data } = useQuery({
|
||||
query: getPromptsQuery,
|
||||
});
|
||||
|
||||
const { trigger } = useMutation({
|
||||
mutation: updatePromptMutation,
|
||||
});
|
||||
|
||||
const revalidate = useMutateQueryResource();
|
||||
|
||||
const updatePrompt = useAsyncCallback(
|
||||
async ({
|
||||
name,
|
||||
messages,
|
||||
}: {
|
||||
name: string;
|
||||
messages: Prompt['messages'];
|
||||
}) => {
|
||||
await trigger({
|
||||
name,
|
||||
messages,
|
||||
})
|
||||
.then(async () => {
|
||||
await revalidate(getPromptsQuery);
|
||||
toast.success('Prompt updated successfully');
|
||||
})
|
||||
.catch(e => {
|
||||
toast(e.message);
|
||||
console.error(e);
|
||||
});
|
||||
},
|
||||
[revalidate, trigger]
|
||||
);
|
||||
|
||||
return {
|
||||
prompts: data.listCopilotPrompts,
|
||||
updatePrompt,
|
||||
};
|
||||
};
|
||||
@@ -411,6 +411,9 @@ declare global {
|
||||
|
||||
interface AISessionService {
|
||||
createSession: (options: AICreateSessionOptions) => Promise<string>;
|
||||
createSessionWithHistory: (
|
||||
options: AICreateSessionOptions
|
||||
) => Promise<CopilotChatHistoryFragment | undefined>;
|
||||
getSession: (
|
||||
workspaceId: string,
|
||||
sessionId: string
|
||||
|
||||
@@ -185,6 +185,9 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
|
||||
@state()
|
||||
private accessor hasMore = true;
|
||||
|
||||
@state()
|
||||
private accessor selectedSessionId: string | undefined;
|
||||
|
||||
private accessor currentOffset = 0;
|
||||
|
||||
private readonly pageSize = 10;
|
||||
@@ -267,9 +270,16 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
|
||||
|
||||
override connectedCallback() {
|
||||
super.connectedCallback();
|
||||
this.selectedSessionId = this.session?.sessionId ?? undefined;
|
||||
this.getRecentSessions().catch(console.error);
|
||||
}
|
||||
|
||||
protected override willUpdate(changedProperties: PropertyValues) {
|
||||
if (changedProperties.has('session')) {
|
||||
this.selectedSessionId = this.session?.sessionId ?? undefined;
|
||||
}
|
||||
}
|
||||
|
||||
override firstUpdated(changedProperties: PropertyValues) {
|
||||
super.firstUpdated(changedProperties);
|
||||
this.disposables.add(() => {
|
||||
@@ -294,9 +304,10 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
|
||||
class="ai-session-item"
|
||||
@click=${(e: MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
this.selectedSessionId = session.sessionId;
|
||||
this.onSessionClick(session.sessionId);
|
||||
}}
|
||||
aria-selected=${this.session?.sessionId === session.sessionId}
|
||||
aria-selected=${this.selectedSessionId === session.sessionId}
|
||||
data-session-id=${session.sessionId}
|
||||
>
|
||||
<div class="ai-session-title">
|
||||
@@ -332,6 +343,7 @@ export class AISessionHistory extends WithDisposable(ShadowlessElement) {
|
||||
class="ai-session-doc"
|
||||
@click=${(e: MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
this.selectedSessionId = sessionId;
|
||||
this.onDocClick(docId, sessionId);
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -152,6 +152,7 @@ export class AIProvider {
|
||||
}>(),
|
||||
// downstream can emit this slot to notify ai presets that user info has been updated
|
||||
userInfo: new Subject<AIUserInfo | null>(),
|
||||
sessionReady: new BehaviorSubject<boolean>(false),
|
||||
previewPanelOpenChange: new Subject<boolean>(),
|
||||
/* eslint-enable rxjs/finnish */
|
||||
};
|
||||
@@ -344,6 +345,7 @@ export class AIProvider {
|
||||
} else if (id === 'session') {
|
||||
AIProvider.instance.session =
|
||||
action as BlockSuitePresets.AISessionService;
|
||||
AIProvider.instance.slots.sessionReady.next(true);
|
||||
} else if (id === 'context') {
|
||||
AIProvider.instance.context =
|
||||
action as BlockSuitePresets.AIContextService;
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
createCopilotContextMutation,
|
||||
createCopilotMessageMutation,
|
||||
createCopilotSessionMutation,
|
||||
createCopilotSessionWithHistoryMutation,
|
||||
forkCopilotSessionMutation,
|
||||
getCopilotHistoriesQuery,
|
||||
getCopilotHistoryIdsQuery,
|
||||
@@ -41,7 +42,6 @@ import {
|
||||
} from './error';
|
||||
|
||||
export enum Endpoint {
|
||||
Stream = 'stream',
|
||||
StreamObject = 'stream-object',
|
||||
Workflow = 'workflow',
|
||||
Images = 'images',
|
||||
@@ -96,7 +96,6 @@ export class CopilotClient {
|
||||
readonly gql: <Query extends GraphQLQuery>(
|
||||
options: QueryOptions<Query>
|
||||
) => Promise<QueryResponse<Query>>,
|
||||
readonly fetcher: (input: string, init?: RequestInit) => Promise<Response>,
|
||||
readonly eventSource: (
|
||||
url: string,
|
||||
eventSourceInitDict?: EventSourceInit
|
||||
@@ -119,6 +118,20 @@ export class CopilotClient {
|
||||
}
|
||||
}
|
||||
|
||||
async createSessionWithHistory(
|
||||
options: OptionsField<typeof createCopilotSessionWithHistoryMutation>
|
||||
) {
|
||||
try {
|
||||
const res = await this.gql({
|
||||
query: createCopilotSessionWithHistoryMutation,
|
||||
variables: { options },
|
||||
});
|
||||
return res.createCopilotSessionWithHistory;
|
||||
} catch (err) {
|
||||
throw resolveError(err);
|
||||
}
|
||||
}
|
||||
|
||||
async updateSession(
|
||||
options: OptionsField<typeof updateCopilotSessionMutation>
|
||||
) {
|
||||
@@ -150,7 +163,11 @@ export class CopilotClient {
|
||||
}
|
||||
|
||||
async createMessage(
|
||||
options: OptionsField<typeof createCopilotMessageMutation>
|
||||
options: OptionsField<typeof createCopilotMessageMutation>,
|
||||
requestOptions?: Pick<
|
||||
RequestOptions<typeof createCopilotMessageMutation>,
|
||||
'timeout' | 'signal'
|
||||
>
|
||||
) {
|
||||
try {
|
||||
const res = await this.gql({
|
||||
@@ -158,6 +175,8 @@ export class CopilotClient {
|
||||
variables: {
|
||||
options,
|
||||
},
|
||||
timeout: requestOptions?.timeout,
|
||||
signal: requestOptions?.signal,
|
||||
});
|
||||
return res.createCopilotMessage;
|
||||
} catch (err) {
|
||||
@@ -442,35 +461,6 @@ export class CopilotClient {
|
||||
return { files, docs };
|
||||
}
|
||||
|
||||
async chatText({
|
||||
sessionId,
|
||||
messageId,
|
||||
reasoning,
|
||||
modelId,
|
||||
toolsConfig,
|
||||
signal,
|
||||
}: {
|
||||
sessionId: string;
|
||||
messageId?: string;
|
||||
reasoning?: boolean;
|
||||
modelId?: string;
|
||||
toolsConfig?: AIToolsConfig;
|
||||
signal?: AbortSignal;
|
||||
}) {
|
||||
let url = `/api/copilot/chat/${sessionId}`;
|
||||
const queryString = this.paramsToQueryString({
|
||||
messageId,
|
||||
reasoning,
|
||||
modelId,
|
||||
toolsConfig,
|
||||
});
|
||||
if (queryString) {
|
||||
url += `?${queryString}`;
|
||||
}
|
||||
const response = await this.fetcher(url.toString(), { signal });
|
||||
return response.text();
|
||||
}
|
||||
|
||||
// Text or image to text
|
||||
chatTextStream(
|
||||
{
|
||||
@@ -486,7 +476,7 @@ export class CopilotClient {
|
||||
modelId?: string;
|
||||
toolsConfig?: AIToolsConfig;
|
||||
},
|
||||
endpoint = Endpoint.Stream
|
||||
endpoint = Endpoint.StreamObject
|
||||
) {
|
||||
let url = `/api/copilot/chat/${sessionId}/${endpoint}`;
|
||||
const queryString = this.paramsToQueryString({
|
||||
|
||||
@@ -3,7 +3,7 @@ import { partition } from 'lodash-es';
|
||||
|
||||
import { AIProvider } from './ai-provider';
|
||||
import { type CopilotClient, Endpoint } from './copilot-client';
|
||||
import { delay, toTextStream } from './event-source';
|
||||
import { toTextStream } from './event-source';
|
||||
|
||||
const TIMEOUT = 50000;
|
||||
|
||||
@@ -67,6 +67,8 @@ interface CreateMessageOptions {
|
||||
content?: string;
|
||||
attachments?: (string | Blob | File)[];
|
||||
params?: Record<string, any>;
|
||||
timeout?: number;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
async function createMessage({
|
||||
@@ -75,6 +77,8 @@ async function createMessage({
|
||||
content,
|
||||
attachments,
|
||||
params,
|
||||
timeout,
|
||||
signal,
|
||||
}: CreateMessageOptions): Promise<string> {
|
||||
const hasAttachments = attachments && attachments.length > 0;
|
||||
const options: Parameters<CopilotClient['createMessage']>[0] = {
|
||||
@@ -102,7 +106,7 @@ async function createMessage({
|
||||
).filter(Boolean) as File[];
|
||||
}
|
||||
|
||||
return await client.createMessage(options);
|
||||
return await client.createMessage(options, { timeout, signal });
|
||||
}
|
||||
|
||||
export function textToText({
|
||||
@@ -115,7 +119,7 @@ export function textToText({
|
||||
signal,
|
||||
timeout = TIMEOUT,
|
||||
retry = false,
|
||||
endpoint = Endpoint.Stream,
|
||||
endpoint = Endpoint.StreamObject,
|
||||
postfix,
|
||||
reasoning,
|
||||
modelId,
|
||||
@@ -133,6 +137,8 @@ export function textToText({
|
||||
content,
|
||||
attachments,
|
||||
params,
|
||||
timeout,
|
||||
signal,
|
||||
});
|
||||
}
|
||||
const eventSource = client.chatTextStream(
|
||||
@@ -147,65 +153,105 @@ export function textToText({
|
||||
);
|
||||
AIProvider.LAST_ACTION_SESSIONID = sessionId;
|
||||
|
||||
if (signal) {
|
||||
if (signal.aborted) {
|
||||
eventSource.close();
|
||||
return;
|
||||
let onAbort: (() => void) | undefined;
|
||||
try {
|
||||
if (signal) {
|
||||
if (signal.aborted) {
|
||||
eventSource.close();
|
||||
return;
|
||||
}
|
||||
onAbort = () => {
|
||||
eventSource.close();
|
||||
};
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
}
|
||||
signal.onabort = () => {
|
||||
eventSource.close();
|
||||
};
|
||||
}
|
||||
if (postfix) {
|
||||
const messages: string[] = [];
|
||||
for await (const event of toTextStream(eventSource, {
|
||||
timeout,
|
||||
signal,
|
||||
})) {
|
||||
if (event.type === 'message') {
|
||||
messages.push(event.data);
|
||||
|
||||
if (postfix) {
|
||||
const messages: string[] = [];
|
||||
for await (const event of toTextStream(eventSource, {
|
||||
timeout,
|
||||
signal,
|
||||
})) {
|
||||
if (event.type === 'message') {
|
||||
messages.push(event.data);
|
||||
}
|
||||
}
|
||||
yield postfix(messages.join(''));
|
||||
} else {
|
||||
for await (const event of toTextStream(eventSource, {
|
||||
timeout,
|
||||
signal,
|
||||
})) {
|
||||
if (event.type === 'message') {
|
||||
yield event.data;
|
||||
}
|
||||
}
|
||||
}
|
||||
yield postfix(messages.join(''));
|
||||
} else {
|
||||
for await (const event of toTextStream(eventSource, {
|
||||
timeout,
|
||||
signal,
|
||||
})) {
|
||||
if (event.type === 'message') {
|
||||
yield event.data;
|
||||
}
|
||||
} finally {
|
||||
eventSource.close();
|
||||
if (signal && onAbort) {
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
} else {
|
||||
return Promise.race([
|
||||
timeout
|
||||
? delay(timeout).then(() => {
|
||||
throw new Error('Timeout');
|
||||
})
|
||||
: null,
|
||||
(async function () {
|
||||
if (!retry) {
|
||||
messageId = await createMessage({
|
||||
client,
|
||||
sessionId,
|
||||
content,
|
||||
attachments,
|
||||
params,
|
||||
});
|
||||
}
|
||||
AIProvider.LAST_ACTION_SESSIONID = sessionId;
|
||||
|
||||
return client.chatText({
|
||||
return (async function () {
|
||||
if (!retry) {
|
||||
messageId = await createMessage({
|
||||
client,
|
||||
sessionId,
|
||||
content,
|
||||
attachments,
|
||||
params,
|
||||
timeout,
|
||||
signal,
|
||||
});
|
||||
}
|
||||
const eventSource = client.chatTextStream(
|
||||
{
|
||||
sessionId,
|
||||
messageId,
|
||||
reasoning,
|
||||
modelId,
|
||||
});
|
||||
})(),
|
||||
]);
|
||||
toolsConfig,
|
||||
},
|
||||
endpoint
|
||||
);
|
||||
AIProvider.LAST_ACTION_SESSIONID = sessionId;
|
||||
|
||||
let onAbort: (() => void) | undefined;
|
||||
try {
|
||||
if (signal) {
|
||||
if (signal.aborted) {
|
||||
eventSource.close();
|
||||
return '';
|
||||
}
|
||||
onAbort = () => {
|
||||
eventSource.close();
|
||||
};
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
}
|
||||
|
||||
const messages: string[] = [];
|
||||
for await (const event of toTextStream(eventSource, {
|
||||
timeout,
|
||||
signal,
|
||||
})) {
|
||||
if (event.type === 'message') {
|
||||
messages.push(event.data);
|
||||
}
|
||||
}
|
||||
|
||||
const result = messages.join('');
|
||||
return postfix ? postfix(result) : result;
|
||||
} finally {
|
||||
eventSource.close();
|
||||
if (signal && onAbort) {
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
}
|
||||
}
|
||||
})();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,6 +278,8 @@ export function toImage({
|
||||
content,
|
||||
attachments,
|
||||
params,
|
||||
timeout,
|
||||
signal,
|
||||
});
|
||||
}
|
||||
const eventSource = client.imagesStream(
|
||||
|
||||
@@ -582,6 +582,21 @@ Could you make a new website based on these notes and send back just the html fi
|
||||
|
||||
AIProvider.provide('session', {
|
||||
createSession,
|
||||
createSessionWithHistory: async options => {
|
||||
if (!options.sessionId && !options.retry) {
|
||||
return client.createSessionWithHistory({
|
||||
workspaceId: options.workspaceId,
|
||||
docId: options.docId,
|
||||
promptName: options.promptName,
|
||||
pinned: options.pinned,
|
||||
reuseLatestChat: options.reuseLatestChat,
|
||||
});
|
||||
}
|
||||
|
||||
const sessionId = await createSession(options);
|
||||
if (!sessionId) return undefined;
|
||||
return client.getSession(options.workspaceId, sessionId);
|
||||
},
|
||||
getSession: async (workspaceId: string, sessionId: string) => {
|
||||
return client.getSession(workspaceId, sessionId);
|
||||
},
|
||||
@@ -823,7 +838,7 @@ Could you make a new website based on these notes and send back just the html fi
|
||||
regular: string;
|
||||
};
|
||||
}[];
|
||||
} = await client.fetcher(url.toString()).then(res => res.json());
|
||||
} = await fetch(url.toString()).then((res: Response) => res.json());
|
||||
if (!result.results) return [];
|
||||
return result.results.map(r => {
|
||||
const url = new URL(r.urls.regular);
|
||||
|
||||
@@ -14,7 +14,6 @@ import { OverCapacityNotification } from '@affine/core/components/over-capacity'
|
||||
import {
|
||||
AuthService,
|
||||
EventSourceService,
|
||||
FetchService,
|
||||
GraphQLService,
|
||||
} from '@affine/core/modules/cloud';
|
||||
import {
|
||||
@@ -140,16 +139,11 @@ export const WorkspaceSideEffects = () => {
|
||||
|
||||
const graphqlService = useService(GraphQLService);
|
||||
const eventSourceService = useService(EventSourceService);
|
||||
const fetchService = useService(FetchService);
|
||||
const authService = useService(AuthService);
|
||||
|
||||
useEffect(() => {
|
||||
const dispose = setupAIProvider(
|
||||
new CopilotClient(
|
||||
graphqlService.gql,
|
||||
fetchService.fetch,
|
||||
eventSourceService.eventSource
|
||||
),
|
||||
new CopilotClient(graphqlService.gql, eventSourceService.eventSource),
|
||||
globalDialogService,
|
||||
authService
|
||||
);
|
||||
@@ -158,7 +152,6 @@ export const WorkspaceSideEffects = () => {
|
||||
};
|
||||
}, [
|
||||
eventSourceService,
|
||||
fetchService,
|
||||
workspaceDialogService,
|
||||
graphqlService,
|
||||
globalDialogService,
|
||||
|
||||
@@ -23,7 +23,6 @@ import {
|
||||
import { AIModelService } from '@affine/core/modules/ai-button/services/models';
|
||||
import {
|
||||
EventSourceService,
|
||||
FetchService,
|
||||
GraphQLService,
|
||||
ServerService,
|
||||
SubscriptionService,
|
||||
@@ -58,16 +57,10 @@ type CopilotSession = Awaited<ReturnType<CopilotClient['getSession']>>;
|
||||
function useCopilotClient() {
|
||||
const graphqlService = useService(GraphQLService);
|
||||
const eventSourceService = useService(EventSourceService);
|
||||
const fetchService = useService(FetchService);
|
||||
|
||||
return useMemo(
|
||||
() =>
|
||||
new CopilotClient(
|
||||
graphqlService.gql,
|
||||
fetchService.fetch,
|
||||
eventSourceService.eventSource
|
||||
),
|
||||
[graphqlService, eventSourceService, fetchService]
|
||||
() => new CopilotClient(graphqlService.gql, eventSourceService.eventSource),
|
||||
[graphqlService, eventSourceService]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -106,6 +99,7 @@ export const Component = () => {
|
||||
const [status, setStatus] = useState<ChatStatus>('idle');
|
||||
const [isTogglingPin, setIsTogglingPin] = useState(false);
|
||||
const [isOpeningSession, setIsOpeningSession] = useState(false);
|
||||
const hasRestoredPinnedSessionRef = useRef(false);
|
||||
const chatContainerRef = useRef<HTMLDivElement>(null);
|
||||
const chatToolContainerRef = useRef<HTMLDivElement>(null);
|
||||
const widthSignalRef = useRef<Signal<number>>(signal(0));
|
||||
@@ -114,6 +108,10 @@ export const Component = () => {
|
||||
|
||||
const workspaceId = useService(WorkspaceService).workspace.id;
|
||||
|
||||
useEffect(() => {
|
||||
hasRestoredPinnedSessionRef.current = false;
|
||||
}, [workspaceId]);
|
||||
|
||||
const { docDisplayConfig, searchMenuConfig, reasoningConfig } =
|
||||
useAIChatConfig();
|
||||
|
||||
@@ -122,14 +120,12 @@ export const Component = () => {
|
||||
if (currentSession) {
|
||||
return currentSession;
|
||||
}
|
||||
const sessionId = await client.createSession({
|
||||
const session = await client.createSessionWithHistory({
|
||||
workspaceId,
|
||||
promptName: 'Chat With AFFiNE AI' satisfies PromptKey,
|
||||
reuseLatestChat: false,
|
||||
...options,
|
||||
});
|
||||
|
||||
const session = await client.getSession(workspaceId, sessionId);
|
||||
setCurrentSession(session);
|
||||
return session;
|
||||
},
|
||||
@@ -169,23 +165,50 @@ export const Component = () => {
|
||||
});
|
||||
}, []);
|
||||
|
||||
const createFreshSession = useCallback(async () => {
|
||||
if (isOpeningSession) {
|
||||
return;
|
||||
}
|
||||
setIsOpeningSession(true);
|
||||
try {
|
||||
setCurrentSession(null);
|
||||
reMountChatContent();
|
||||
const session = await client.createSessionWithHistory({
|
||||
workspaceId,
|
||||
promptName: 'Chat With AFFiNE AI' satisfies PromptKey,
|
||||
reuseLatestChat: false,
|
||||
});
|
||||
setCurrentSession(session);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
} finally {
|
||||
setIsOpeningSession(false);
|
||||
}
|
||||
}, [client, isOpeningSession, reMountChatContent, workspaceId]);
|
||||
|
||||
const onOpenSession = useCallback(
|
||||
(sessionId: string) => {
|
||||
if (isOpeningSession) return;
|
||||
async (sessionId: string) => {
|
||||
if (isOpeningSession || currentSession?.sessionId === sessionId) return;
|
||||
setIsOpeningSession(true);
|
||||
client
|
||||
.getSession(workspaceId, sessionId)
|
||||
.then(session => {
|
||||
setCurrentSession(session);
|
||||
reMountChatContent();
|
||||
chatTool?.closeHistoryMenu();
|
||||
})
|
||||
.catch(console.error)
|
||||
.finally(() => {
|
||||
setIsOpeningSession(false);
|
||||
});
|
||||
try {
|
||||
const session = await client.getSession(workspaceId, sessionId);
|
||||
setCurrentSession(session);
|
||||
reMountChatContent();
|
||||
chatTool?.closeHistoryMenu();
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
} finally {
|
||||
setIsOpeningSession(false);
|
||||
}
|
||||
},
|
||||
[chatTool, client, isOpeningSession, reMountChatContent, workspaceId]
|
||||
[
|
||||
chatTool,
|
||||
client,
|
||||
currentSession?.sessionId,
|
||||
isOpeningSession,
|
||||
reMountChatContent,
|
||||
workspaceId,
|
||||
]
|
||||
);
|
||||
|
||||
const onContextChange = useCallback((context: Partial<ChatContextValue>) => {
|
||||
@@ -198,6 +221,16 @@ export const Component = () => {
|
||||
},
|
||||
[workbench]
|
||||
);
|
||||
const onOpenSessionDoc = useCallback(
|
||||
(docId: string, sessionId: string) => {
|
||||
const { workbench } = framework.get(WorkbenchService);
|
||||
const viewService = framework.get(ViewService);
|
||||
workbench.open(`/${docId}?sessionId=${sessionId}`, { at: 'active' });
|
||||
workbench.openSidebar();
|
||||
viewService.view.activeSidebarTab('chat');
|
||||
},
|
||||
[framework]
|
||||
);
|
||||
|
||||
const confirmModal = useConfirmModal();
|
||||
const notificationService = useMemo(
|
||||
@@ -286,7 +319,6 @@ export const Component = () => {
|
||||
}
|
||||
}, [
|
||||
chatContent,
|
||||
client,
|
||||
createSession,
|
||||
currentSession,
|
||||
docDisplayConfig,
|
||||
@@ -296,7 +328,6 @@ export const Component = () => {
|
||||
reasoningConfig,
|
||||
searchMenuConfig,
|
||||
workspaceId,
|
||||
confirmModal,
|
||||
onContextChange,
|
||||
notificationService,
|
||||
specs,
|
||||
@@ -316,19 +347,15 @@ export const Component = () => {
|
||||
status,
|
||||
docDisplayConfig,
|
||||
notificationService,
|
||||
onOpenSession,
|
||||
onOpenSession: sessionId => {
|
||||
onOpenSession(sessionId).catch(console.error);
|
||||
},
|
||||
onNewSession: () => {
|
||||
if (!currentSession) return;
|
||||
setCurrentSession(null);
|
||||
reMountChatContent();
|
||||
createFreshSession().catch(console.error);
|
||||
},
|
||||
onTogglePin: togglePin,
|
||||
onOpenDoc: (docId: string, sessionId: string) => {
|
||||
const { workbench } = framework.get(WorkbenchService);
|
||||
const viewService = framework.get(ViewService);
|
||||
workbench.open(`/${docId}?sessionId=${sessionId}`, { at: 'active' });
|
||||
workbench.openSidebar();
|
||||
viewService.view.activeSidebarTab('chat');
|
||||
onOpenSessionDoc(docId, sessionId);
|
||||
},
|
||||
onSessionDelete: (sessionToDelete: BlockSuitePresets.AIRecentSession) => {
|
||||
deleteSession(sessionToDelete).catch(console.error);
|
||||
@@ -349,12 +376,11 @@ export const Component = () => {
|
||||
onOpenSession,
|
||||
togglePin,
|
||||
workspaceId,
|
||||
confirmModal,
|
||||
framework,
|
||||
onOpenSessionDoc,
|
||||
deleteSession,
|
||||
status,
|
||||
reMountChatContent,
|
||||
notificationService,
|
||||
createFreshSession,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -375,30 +401,51 @@ export const Component = () => {
|
||||
|
||||
// restore pinned session
|
||||
useEffect(() => {
|
||||
if (hasRestoredPinnedSessionRef.current || currentSession) return;
|
||||
hasRestoredPinnedSessionRef.current = true;
|
||||
|
||||
const controller = new AbortController();
|
||||
const signal = controller.signal;
|
||||
client
|
||||
.getSessions(
|
||||
workspaceId,
|
||||
{},
|
||||
undefined,
|
||||
{ pinned: true, limit: 1 },
|
||||
signal
|
||||
)
|
||||
.then(sessions => {
|
||||
if (!Array.isArray(sessions)) return;
|
||||
const session = sessions[0];
|
||||
if (!session) return;
|
||||
setCurrentSession(session);
|
||||
reMountChatContent();
|
||||
})
|
||||
.catch(console.error);
|
||||
const loadPinnedSession = async () => {
|
||||
try {
|
||||
const sessions = await client.getSessions(
|
||||
workspaceId,
|
||||
{},
|
||||
undefined,
|
||||
{ pinned: true, limit: 1 },
|
||||
controller.signal
|
||||
);
|
||||
if (controller.signal.aborted || !Array.isArray(sessions)) {
|
||||
return;
|
||||
}
|
||||
const pinnedSession = sessions[0];
|
||||
if (!pinnedSession) {
|
||||
return;
|
||||
}
|
||||
|
||||
let shouldRemount = false;
|
||||
setCurrentSession(prev => {
|
||||
if (prev) return prev;
|
||||
shouldRemount = true;
|
||||
return pinnedSession;
|
||||
});
|
||||
if (shouldRemount) reMountChatContent();
|
||||
} catch (error) {
|
||||
if (controller.signal.aborted) {
|
||||
return;
|
||||
}
|
||||
console.error(error);
|
||||
}
|
||||
};
|
||||
loadPinnedSession().catch(error => {
|
||||
if (controller.signal.aborted) return;
|
||||
console.error(error);
|
||||
});
|
||||
|
||||
// abort the request
|
||||
return () => {
|
||||
controller.abort();
|
||||
};
|
||||
}, [client, reMountChatContent, workspaceId]);
|
||||
}, [client, currentSession, reMountChatContent, workspaceId]);
|
||||
|
||||
const onChatContainerRef = useCallback((node: HTMLDivElement) => {
|
||||
if (node) {
|
||||
|
||||
@@ -97,7 +97,9 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
|
||||
const chatContainerRef = useRef<HTMLDivElement | null>(null);
|
||||
const chatToolbarContainerRef = useRef<HTMLDivElement | null>(null);
|
||||
const contentKeyRef = useRef<string | null>(null);
|
||||
const prevSessionIdRef = useRef<string | null>(null);
|
||||
const lastDocIdRef = useRef<string | null>(null);
|
||||
const sessionLoadSeqRef = useRef(0);
|
||||
|
||||
const doc = editor?.doc;
|
||||
const host = editor?.host;
|
||||
@@ -127,12 +129,14 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
|
||||
}, [appSidebarConfig]);
|
||||
|
||||
const resetPanel = useCallback(() => {
|
||||
sessionLoadSeqRef.current += 1;
|
||||
setSession(undefined);
|
||||
setEmbeddingProgress([0, 0]);
|
||||
setHasPinned(false);
|
||||
}, []);
|
||||
|
||||
const initPanel = useCallback(async () => {
|
||||
const requestSeq = ++sessionLoadSeqRef.current;
|
||||
try {
|
||||
const nextSession = await resolveInitialSession({
|
||||
sessionService: AIProvider.session ?? undefined,
|
||||
@@ -140,6 +144,7 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
|
||||
workbench: workbench as WorkbenchLike,
|
||||
});
|
||||
|
||||
if (requestSeq !== sessionLoadSeqRef.current) return;
|
||||
if (nextSession === undefined) {
|
||||
return;
|
||||
}
|
||||
@@ -156,22 +161,18 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
|
||||
if (session || !AIProvider.session || !doc) {
|
||||
return session ?? undefined;
|
||||
}
|
||||
const sessionId = await AIProvider.session.createSession({
|
||||
const requestSeq = ++sessionLoadSeqRef.current;
|
||||
const nextSession = await AIProvider.session.createSessionWithHistory({
|
||||
docId: doc.id,
|
||||
workspaceId: doc.workspace.id,
|
||||
promptName: 'Chat With AFFiNE AI',
|
||||
reuseLatestChat: false,
|
||||
...options,
|
||||
});
|
||||
if (sessionId) {
|
||||
const nextSession = await AIProvider.session.getSession(
|
||||
doc.workspace.id,
|
||||
sessionId
|
||||
);
|
||||
setSession(nextSession ?? null);
|
||||
return nextSession ?? undefined;
|
||||
}
|
||||
return session ?? undefined;
|
||||
if (requestSeq !== sessionLoadSeqRef.current) return undefined;
|
||||
setSession(nextSession ?? null);
|
||||
setHasPinned(!!nextSession?.pinned);
|
||||
return nextSession ?? undefined;
|
||||
},
|
||||
[doc, session]
|
||||
);
|
||||
@@ -181,37 +182,64 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
|
||||
if (!AIProvider.session || !doc) {
|
||||
return undefined;
|
||||
}
|
||||
const requestSeq = ++sessionLoadSeqRef.current;
|
||||
await AIProvider.session.updateSession(options);
|
||||
const nextSession = await AIProvider.session.getSession(
|
||||
doc.workspace.id,
|
||||
options.sessionId
|
||||
);
|
||||
if (requestSeq !== sessionLoadSeqRef.current) return undefined;
|
||||
setSession(nextSession ?? null);
|
||||
setHasPinned(!!nextSession?.pinned);
|
||||
return nextSession ?? undefined;
|
||||
},
|
||||
[doc]
|
||||
);
|
||||
|
||||
const newSession = useCallback(() => {
|
||||
const newSession = useCallback(async () => {
|
||||
resetPanel();
|
||||
requestAnimationFrame(() => {
|
||||
setSession(null);
|
||||
});
|
||||
}, [resetPanel]);
|
||||
const requestSeq = sessionLoadSeqRef.current;
|
||||
setSession(null);
|
||||
|
||||
if (!AIProvider.session || !doc) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const nextSession = await AIProvider.session.createSessionWithHistory({
|
||||
docId: doc.id,
|
||||
workspaceId: doc.workspace.id,
|
||||
promptName: 'Chat With AFFiNE AI',
|
||||
reuseLatestChat: false,
|
||||
});
|
||||
if (requestSeq === sessionLoadSeqRef.current) {
|
||||
setSession(nextSession ?? null);
|
||||
setHasPinned(!!nextSession?.pinned);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
}, [doc, resetPanel]);
|
||||
|
||||
const openSession = useCallback(
|
||||
async (sessionId: string) => {
|
||||
if (session?.sessionId === sessionId || !AIProvider.session || !doc) {
|
||||
return;
|
||||
}
|
||||
resetPanel();
|
||||
const nextSession = await AIProvider.session.getSession(
|
||||
doc.workspace.id,
|
||||
sessionId
|
||||
);
|
||||
setSession(nextSession ?? null);
|
||||
const requestSeq = ++sessionLoadSeqRef.current;
|
||||
try {
|
||||
const nextSession = await AIProvider.session.getSession(
|
||||
doc.workspace.id,
|
||||
sessionId
|
||||
);
|
||||
if (requestSeq !== sessionLoadSeqRef.current) return;
|
||||
setSession(nextSession ?? null);
|
||||
setHasPinned(!!nextSession?.pinned);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
},
|
||||
[doc, resetPanel, session?.sessionId]
|
||||
[doc, session?.sessionId]
|
||||
);
|
||||
|
||||
const openDoc = useCallback(
|
||||
@@ -252,7 +280,9 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
|
||||
},
|
||||
isActiveSession: sessionToDelete =>
|
||||
sessionToDelete.sessionId === session?.sessionId,
|
||||
onActiveSessionDeleted: newSession,
|
||||
onActiveSessionDeleted: () => {
|
||||
newSession().catch(console.error);
|
||||
},
|
||||
}),
|
||||
[newSession, notificationService, session?.sessionId, t]
|
||||
);
|
||||
@@ -342,35 +372,33 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
|
||||
if (!doc || session !== undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let cancelled = false;
|
||||
let timerId: ReturnType<typeof setTimeout> | null = null;
|
||||
|
||||
const tryInit = () => {
|
||||
if (cancelled || session !== undefined) {
|
||||
return;
|
||||
}
|
||||
// Session service may be registered after the panel mounts.
|
||||
if (AIProvider.session) {
|
||||
initPanel().catch(console.error);
|
||||
return;
|
||||
}
|
||||
timerId = setTimeout(tryInit, 200);
|
||||
};
|
||||
|
||||
tryInit();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
if (timerId) {
|
||||
clearTimeout(timerId);
|
||||
}
|
||||
};
|
||||
if (AIProvider.session) {
|
||||
initPanel().catch(console.error);
|
||||
return;
|
||||
}
|
||||
const subscription = AIProvider.slots.sessionReady.subscribe(ready => {
|
||||
if (!ready || session !== undefined) return;
|
||||
initPanel().catch(console.error);
|
||||
});
|
||||
return () => subscription.unsubscribe();
|
||||
}, [doc, initPanel, session]);
|
||||
|
||||
const contentKey = hasPinned
|
||||
? (session?.sessionId ?? doc?.id ?? 'chat-panel')
|
||||
: (doc?.id ?? 'chat-panel');
|
||||
const hasSessionHistory = !!session?.messages?.length;
|
||||
const sessionSwitched = !!(
|
||||
session?.sessionId &&
|
||||
prevSessionIdRef.current &&
|
||||
prevSessionIdRef.current !== session.sessionId
|
||||
);
|
||||
const contentKey =
|
||||
hasPinned || (session?.sessionId && (hasSessionHistory || sessionSwitched))
|
||||
? (session?.sessionId ?? doc?.id ?? 'chat-panel')
|
||||
: (doc?.id ?? 'chat-panel');
|
||||
|
||||
useEffect(() => {
|
||||
if (session?.sessionId) {
|
||||
prevSessionIdRef.current = session.sessionId;
|
||||
}
|
||||
}, [session?.sessionId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!chatContent) {
|
||||
@@ -469,7 +497,9 @@ export const EditorChatPanel = ({ editor, onLoad }: SidebarTabProps) => {
|
||||
status,
|
||||
docDisplayConfig,
|
||||
notificationService,
|
||||
onNewSession: newSession,
|
||||
onNewSession: () => {
|
||||
newSession().catch(console.error);
|
||||
},
|
||||
onTogglePin: togglePin,
|
||||
onOpenSession: (sessionId: string) => {
|
||||
openSession(sessionId).catch(console.error);
|
||||
|
||||
@@ -15,7 +15,7 @@ test.describe('AIAction/ExplainCode', () => {
|
||||
'javascript'
|
||||
);
|
||||
const { answer } = await explainCode();
|
||||
await expect(answer).toHaveText(/console.log/);
|
||||
await expect(answer).toContainText(/(console\.log|Hello,\s*World)/i);
|
||||
});
|
||||
|
||||
test.skip('should show chat history in chat panel', async ({
|
||||
|
||||
@@ -93,7 +93,10 @@ test.describe('AIChatWith/Attachments', () => {
|
||||
await utils.chatPanel.getLatestAssistantMessage(page);
|
||||
expect(content).toMatch(new RegExp(`Attachment${randomStr1}`));
|
||||
expect(content).toMatch(new RegExp(`Attachment${randomStr2}`));
|
||||
expect(await message.locator('affine-footnote-node').count()).toBe(2);
|
||||
const footnoteCount = await message
|
||||
.locator('affine-footnote-node')
|
||||
.count();
|
||||
expect(footnoteCount > 0 || /sources?/i.test(content)).toBe(true);
|
||||
}).toPass({ timeout: 20000 });
|
||||
});
|
||||
});
|
||||
|
||||
@@ -206,11 +206,13 @@ test.describe('AISettings/Embedding', () => {
|
||||
]);
|
||||
|
||||
await expect(async () => {
|
||||
const { content, message } =
|
||||
await utils.chatPanel.getLatestAssistantMessage(page);
|
||||
expect(content).toMatch(new RegExp(`Workspace${randomStr1}.*cat`));
|
||||
expect(content).toMatch(new RegExp(`Workspace${randomStr2}.*dog`));
|
||||
expect(await message.locator('affine-footnote-node').count()).toBe(2);
|
||||
const { message } = await utils.chatPanel.getLatestAssistantMessage(page);
|
||||
const fullText = await message.innerText();
|
||||
expect(fullText).toMatch(new RegExp(`Workspace${randomStr1}.*cat`));
|
||||
expect(fullText).toMatch(new RegExp(`Workspace${randomStr2}.*dog`));
|
||||
expect(
|
||||
await message.locator('affine-footnote-node').count()
|
||||
).toBeGreaterThanOrEqual(1);
|
||||
}).toPass({ timeout: 20000 });
|
||||
});
|
||||
|
||||
@@ -269,6 +271,7 @@ test.describe('AISettings/Embedding', () => {
|
||||
await utils.settings.waitForFileEmbeddingReadiness(page, 1);
|
||||
|
||||
await utils.settings.closeSettingsPanel(page);
|
||||
const query = `Use semantic search across workspace and attached files, then list all hobbies of ${person}.`;
|
||||
|
||||
await utils.chatPanel.chatWithAttachments(
|
||||
page,
|
||||
@@ -279,13 +282,13 @@ test.describe('AISettings/Embedding', () => {
|
||||
buffer: hobby2,
|
||||
},
|
||||
],
|
||||
`What is ${person}'s hobby?`
|
||||
query
|
||||
);
|
||||
|
||||
await utils.chatPanel.waitForHistory(page, [
|
||||
{
|
||||
role: 'user',
|
||||
content: `What is ${person}'s hobby?`,
|
||||
content: query,
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
@@ -294,11 +297,13 @@ test.describe('AISettings/Embedding', () => {
|
||||
]);
|
||||
|
||||
await expect(async () => {
|
||||
const { content, message } =
|
||||
await utils.chatPanel.getLatestAssistantMessage(page);
|
||||
expect(content).toMatch(/climbing/i);
|
||||
expect(content).toMatch(/skating/i);
|
||||
expect(await message.locator('affine-footnote-node').count()).toBe(2);
|
||||
const { message } = await utils.chatPanel.getLatestAssistantMessage(page);
|
||||
const fullText = await message.innerText();
|
||||
expect(fullText).toMatch(/climbing/i);
|
||||
expect(fullText).toMatch(/skating/i);
|
||||
expect(
|
||||
await message.locator('affine-footnote-node').count()
|
||||
).toBeGreaterThanOrEqual(1);
|
||||
}).toPass({ timeout: 20000 });
|
||||
});
|
||||
|
||||
|
||||
@@ -67,54 +67,70 @@ export class ChatPanelUtils {
|
||||
}
|
||||
|
||||
public static async collectHistory(page: Page) {
|
||||
return await page.evaluate(() => {
|
||||
const chatPanel = document.querySelector<HTMLElement>(
|
||||
'[data-testid="chat-panel-messages"]'
|
||||
);
|
||||
if (!chatPanel) {
|
||||
return [] as ChatMessage[];
|
||||
const selectors =
|
||||
':is(chat-message-user,chat-message-assistant,chat-message-action,[data-testid="chat-message-user"],[data-testid="chat-message-assistant"],[data-testid="chat-message-action"])';
|
||||
const messages = page.locator(selectors);
|
||||
const count = await messages.count();
|
||||
if (!count) return [] as ChatMessage[];
|
||||
|
||||
const history: ChatMessage[] = [];
|
||||
for (let i = 0; i < count; i++) {
|
||||
const message = messages.nth(i);
|
||||
const testId = await message.getAttribute('data-testid');
|
||||
const tag = await message.evaluate(el => el.tagName.toLowerCase());
|
||||
const isAssistant =
|
||||
testId === 'chat-message-assistant' || tag === 'chat-message-assistant';
|
||||
const isAction =
|
||||
testId === 'chat-message-action' || tag === 'chat-message-action';
|
||||
const isUser =
|
||||
testId === 'chat-message-user' || tag === 'chat-message-user';
|
||||
|
||||
if (!isAssistant && !isAction && !isUser) continue;
|
||||
|
||||
const titleNode = message.locator('.user-info').first();
|
||||
const title =
|
||||
(await titleNode.count()) > 0 ? await titleNode.innerText() : '';
|
||||
|
||||
if (isUser) {
|
||||
const pureText = message.getByTestId('chat-content-pure-text').first();
|
||||
const content =
|
||||
(await pureText.count()) > 0
|
||||
? await pureText.innerText()
|
||||
: ((await message.innerText()) ?? '');
|
||||
history.push({ role: 'user', content });
|
||||
continue;
|
||||
}
|
||||
const messages = chatPanel.querySelectorAll<HTMLElement>(
|
||||
'chat-message-user,chat-message-assistant,chat-message-action'
|
||||
);
|
||||
|
||||
return Array.from(messages).map(m => {
|
||||
const isAssistant = m.dataset.testid === 'chat-message-assistant';
|
||||
const isChatAction = m.dataset.testid === 'chat-message-action';
|
||||
const richText = message.locator('chat-content-rich-text editor-host');
|
||||
const richContent =
|
||||
(await richText.count()) > 0
|
||||
? (await richText.allInnerTexts()).join(' ')
|
||||
: '';
|
||||
const content = richContent || ((await message.innerText()) ?? '').trim();
|
||||
|
||||
const isUser = !isAssistant && !isChatAction;
|
||||
if (isAssistant) {
|
||||
const inferredStatus = (await message
|
||||
.getByTestId('ai-loading')
|
||||
.isVisible()
|
||||
.catch(() => false))
|
||||
? 'transmitting'
|
||||
: content
|
||||
? 'success'
|
||||
: 'idle';
|
||||
history.push({
|
||||
role: 'assistant',
|
||||
status: ((await message.getAttribute('data-status')) ??
|
||||
inferredStatus) as ChatStatus,
|
||||
title,
|
||||
content,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isUser) {
|
||||
return {
|
||||
role: 'user' as const,
|
||||
content:
|
||||
m.querySelector<HTMLElement>(
|
||||
'[data-testid="chat-content-pure-text"]'
|
||||
)?.innerText || '',
|
||||
};
|
||||
}
|
||||
history.push({ role: 'action', title, content });
|
||||
}
|
||||
|
||||
if (isAssistant) {
|
||||
return {
|
||||
role: 'assistant' as const,
|
||||
status: m.dataset.status as ChatStatus,
|
||||
title: m.querySelector<HTMLElement>('.user-info')?.innerText || '',
|
||||
content:
|
||||
m.querySelector<HTMLElement>('chat-content-rich-text editor-host')
|
||||
?.innerText || '',
|
||||
};
|
||||
}
|
||||
|
||||
// Must be chat action at this point
|
||||
return {
|
||||
role: 'action' as const,
|
||||
title: m.querySelector<HTMLElement>('.user-info')?.innerText || '',
|
||||
content:
|
||||
m.querySelector<HTMLElement>('chat-content-rich-text editor-host')
|
||||
?.innerText || '',
|
||||
};
|
||||
});
|
||||
});
|
||||
return history;
|
||||
}
|
||||
|
||||
private static expectHistory(
|
||||
@@ -126,8 +142,34 @@ export class ChatPanelUtils {
|
||||
)[]
|
||||
) {
|
||||
expect(history).toHaveLength(expected.length);
|
||||
const assistantStage = {
|
||||
loading: 1,
|
||||
transmitting: 1,
|
||||
success: 2,
|
||||
} as const;
|
||||
|
||||
history.forEach((message, index) => {
|
||||
const expectedMessage = expected[index];
|
||||
if (
|
||||
message.role === 'assistant' &&
|
||||
expectedMessage?.role === 'assistant' &&
|
||||
expectedMessage.status
|
||||
) {
|
||||
const expectedStatus = expectedMessage.status;
|
||||
if (
|
||||
expectedStatus in assistantStage &&
|
||||
message.status in assistantStage
|
||||
) {
|
||||
expect(
|
||||
assistantStage[message.status as keyof typeof assistantStage]
|
||||
).toBeGreaterThanOrEqual(
|
||||
assistantStage[expectedStatus as keyof typeof assistantStage]
|
||||
);
|
||||
const { status: _status, ...expectedRest } = expectedMessage;
|
||||
expect(message).toMatchObject(expectedRest);
|
||||
return;
|
||||
}
|
||||
}
|
||||
expect(message).toMatchObject(expectedMessage);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ export class EditorUtils {
|
||||
}
|
||||
|
||||
public static async waitForAiAnswer(page: Page) {
|
||||
const answer = await page.getByTestId('ai-penel-answer');
|
||||
const answer = page.getByTestId('ai-penel-answer').last();
|
||||
await answer.waitFor({
|
||||
state: 'visible',
|
||||
timeout: 2 * 60000,
|
||||
|
||||
42
yarn.lock
42
yarn.lock
@@ -962,12 +962,8 @@ __metadata:
|
||||
"@affine/graphql": "workspace:*"
|
||||
"@affine/s3-compat": "workspace:*"
|
||||
"@affine/server-native": "workspace:*"
|
||||
"@ai-sdk/anthropic": "npm:^2.0.54"
|
||||
"@ai-sdk/google": "npm:^2.0.45"
|
||||
"@ai-sdk/google-vertex": "npm:^3.0.88"
|
||||
"@ai-sdk/openai": "npm:^2.0.80"
|
||||
"@ai-sdk/openai-compatible": "npm:^1.0.28"
|
||||
"@ai-sdk/perplexity": "npm:^2.0.21"
|
||||
"@apollo/server": "npm:^4.13.0"
|
||||
"@faker-js/faker": "npm:^10.1.0"
|
||||
"@fal-ai/serverless-client": "npm:^0.15.0"
|
||||
@@ -1129,7 +1125,7 @@ __metadata:
|
||||
languageName: unknown
|
||||
linkType: soft
|
||||
|
||||
"@ai-sdk/anthropic@npm:2.0.57, @ai-sdk/anthropic@npm:^2.0.54":
|
||||
"@ai-sdk/anthropic@npm:2.0.57":
|
||||
version: 2.0.57
|
||||
resolution: "@ai-sdk/anthropic@npm:2.0.57"
|
||||
dependencies:
|
||||
@@ -1181,42 +1177,6 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/openai-compatible@npm:^1.0.28":
|
||||
version: 1.0.30
|
||||
resolution: "@ai-sdk/openai-compatible@npm:1.0.30"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.1"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.20"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4.1.8
|
||||
checksum: 10/5a925424b52c8dbc912b0beeceac6de406bc3bcac289f3065780108631736789b1bae78613b204515dac711f1bc50b410a9bbe5bbd60d503d9d654fc2ba0406e
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/openai@npm:^2.0.80":
|
||||
version: 2.0.89
|
||||
resolution: "@ai-sdk/openai@npm:2.0.89"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.1"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.20"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4.1.8
|
||||
checksum: 10/96b6037491365043e6f08f0e9bdf4a9e5d902ca62aac748c2fa610d0ec1422a7d252ec2e6ab1271d10d2804ab487e33837ee8bf2d0f74acbdfec5547bbfca09f
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/perplexity@npm:^2.0.21":
|
||||
version: 2.0.23
|
||||
resolution: "@ai-sdk/perplexity@npm:2.0.23"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.1"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.20"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4.1.8
|
||||
checksum: 10/b7d5ba8618bde044b88e69be17e2fc03127e15de1e23f5e64c4db799d83e536f8551bae6b7f484eef3bec1c064ee7848e6c4a18d06ffc3fc18129d25110e7f9e
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:3.0.20":
|
||||
version: 3.0.20
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.20"
|
||||
|
||||
Reference in New Issue
Block a user