mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-07-04 19:15:33 +08:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6557e5d01d |
@@ -988,16 +988,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"providers.profiles": {
|
||||
"type": "array",
|
||||
"description": "The profile list for copilot providers.\n@default []",
|
||||
"default": []
|
||||
},
|
||||
"providers.defaults": {
|
||||
"type": "object",
|
||||
"description": "The default provider ids for model output types and global fallback.\n@default {}",
|
||||
"default": {}
|
||||
},
|
||||
"providers.openai": {
|
||||
"type": "object",
|
||||
"description": "The config for the openai provider.\n@default {\"apiKey\":\"\",\"baseURL\":\"https://api.openai.com/v1\"}\n@link https://github.com/openai/openai-node",
|
||||
|
||||
@@ -93,7 +93,7 @@ runs:
|
||||
run: node -e "const p = $(yarn config cacheFolder --json).effective; console.log('yarn_global_cache=' + p)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache non-full yarn cache on Linux
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
if: ${{ inputs.full-cache != 'true' && runner.os == 'Linux' }}
|
||||
with:
|
||||
path: |
|
||||
@@ -105,7 +105,7 @@ runs:
|
||||
# and the decompression performance on Windows is very terrible
|
||||
# so we reduce the number of cached files on non-Linux systems by remove node_modules from cache path.
|
||||
- name: Cache non-full yarn cache on non-Linux
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
if: ${{ inputs.full-cache != 'true' && runner.os != 'Linux' }}
|
||||
with:
|
||||
path: |
|
||||
@@ -113,7 +113,7 @@ runs:
|
||||
key: node_modules-cache-${{ github.job }}-${{ runner.os }}-${{ runner.arch }}-${{ steps.system-info.outputs.name }}-${{ steps.system-info.outputs.release }}-${{ steps.system-info.outputs.version }}
|
||||
|
||||
- name: Cache full yarn cache on Linux
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
if: ${{ inputs.full-cache == 'true' && runner.os == 'Linux' }}
|
||||
with:
|
||||
path: |
|
||||
@@ -122,7 +122,7 @@ runs:
|
||||
key: node_modules-cache-full-${{ runner.os }}-${{ runner.arch }}-${{ steps.system-info.outputs.name }}-${{ steps.system-info.outputs.release }}-${{ steps.system-info.outputs.version }}
|
||||
|
||||
- name: Cache full yarn cache on non-Linux
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
if: ${{ inputs.full-cache == 'true' && runner.os != 'Linux' }}
|
||||
with:
|
||||
path: |
|
||||
@@ -154,7 +154,7 @@ runs:
|
||||
# Note: Playwright's cache directory is hard coded because that's what it
|
||||
# says to do in the docs. There doesn't appear to be a command that prints
|
||||
# it out for us.
|
||||
- uses: actions/cache@v5
|
||||
- uses: actions/cache@v4
|
||||
id: playwright-cache
|
||||
if: ${{ inputs.playwright-install == 'true' }}
|
||||
with:
|
||||
@@ -189,7 +189,7 @@ runs:
|
||||
run: |
|
||||
echo "version=$(yarn why --json electron | grep -h 'workspace:.' | jq --raw-output '.children[].locator' | sed -e 's/@playwright\/test@.*://' | head -n 1)" >> $GITHUB_OUTPUT
|
||||
|
||||
- uses: actions/cache@v5
|
||||
- uses: actions/cache@v4
|
||||
id: electron-cache
|
||||
if: ${{ inputs.electron-install == 'true' }}
|
||||
with:
|
||||
|
||||
@@ -13,5 +13,5 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/labeler@v6
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/labeler@v5
|
||||
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ${{ inputs.build-type }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -57,7 +57,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ${{ inputs.build-type }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -89,7 +89,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ${{ inputs.build-type }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -132,7 +132,7 @@ jobs:
|
||||
file: server-native.armv7.node
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -166,7 +166,7 @@ jobs:
|
||||
needs:
|
||||
- build-server-native
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -202,7 +202,7 @@ jobs:
|
||||
- build-mobile
|
||||
- build-admin
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Download server dist
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
@@ -67,9 +67,9 @@ jobs:
|
||||
name: Lint
|
||||
runs-on: ubuntu-24.04-arm
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Go (for actionlint)
|
||||
uses: actions/setup-go@v6
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 'stable'
|
||||
- name: Install actionlint
|
||||
@@ -111,7 +111,7 @@ jobs:
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=14384
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -138,7 +138,7 @@ jobs:
|
||||
outputs:
|
||||
run-rust: ${{ steps.rust-filter.outputs.rust }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: rust-filter
|
||||
@@ -159,7 +159,7 @@ jobs:
|
||||
needs:
|
||||
- rust-test-filter
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- uses: ./.github/actions/build-rust
|
||||
with:
|
||||
target: x86_64-unknown-linux-gnu
|
||||
@@ -182,7 +182,7 @@ jobs:
|
||||
needs:
|
||||
- build-server-native
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -212,7 +212,7 @@ jobs:
|
||||
name: Check yarn binary
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Run check
|
||||
run: |
|
||||
set -euo pipefail
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
matrix:
|
||||
shard: [1, 2]
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -256,7 +256,7 @@ jobs:
|
||||
name: E2E BlockSuite Cross Browser Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -282,6 +282,52 @@ jobs:
|
||||
path: ./test-results
|
||||
if-no-files-found: ignore
|
||||
|
||||
bundler-matrix:
|
||||
name: Bundler Matrix (${{ matrix.bundler }})
|
||||
runs-on: ubuntu-24.04-arm
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
bundler: [webpack, rspack]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
playwright-install: false
|
||||
electron-install: false
|
||||
full-cache: true
|
||||
|
||||
- name: Run frontend build matrix
|
||||
env:
|
||||
AFFINE_BUNDLER: ${{ matrix.bundler }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
packages=(
|
||||
"@affine/web"
|
||||
"@affine/mobile"
|
||||
"@affine/ios"
|
||||
"@affine/android"
|
||||
"@affine/admin"
|
||||
"@affine/electron-renderer"
|
||||
)
|
||||
summary="test-results-bundler-${AFFINE_BUNDLER}.txt"
|
||||
: > "$summary"
|
||||
for pkg in "${packages[@]}"; do
|
||||
start=$(date +%s)
|
||||
yarn affine "$pkg" build
|
||||
end=$(date +%s)
|
||||
echo "${pkg},$((end-start))" >> "$summary"
|
||||
done
|
||||
|
||||
- name: Upload bundler timing
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-results-bundler-${{ matrix.bundler }}
|
||||
path: ./test-results-bundler-${{ matrix.bundler }}.txt
|
||||
if-no-files-found: ignore
|
||||
|
||||
e2e-test:
|
||||
name: E2E Test
|
||||
runs-on: ubuntu-24.04-arm
|
||||
@@ -294,7 +340,7 @@ jobs:
|
||||
matrix:
|
||||
shard: [1, 2, 3, 4, 5]
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -326,7 +372,7 @@ jobs:
|
||||
matrix:
|
||||
shard: [1, 2]
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -358,7 +404,7 @@ jobs:
|
||||
matrix:
|
||||
shard: [1, 2, 3]
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -391,7 +437,7 @@ jobs:
|
||||
env:
|
||||
CARGO_PROFILE_RELEASE_DEBUG: '1'
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -430,7 +476,7 @@ jobs:
|
||||
- { os: macos-latest, target: aarch64-apple-darwin }
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -471,7 +517,7 @@ jobs:
|
||||
- { os: windows-latest, target: aarch64-pc-windows-msvc }
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- uses: samypr100/setup-dev-drive@v3
|
||||
with:
|
||||
workspace-copy: true
|
||||
@@ -511,7 +557,7 @@ jobs:
|
||||
env:
|
||||
CARGO_PROFILE_RELEASE_DEBUG: '1'
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -534,7 +580,7 @@ jobs:
|
||||
name: Build @affine/electron renderer
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -561,7 +607,7 @@ jobs:
|
||||
needs:
|
||||
- build-native-linux
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -615,7 +661,7 @@ jobs:
|
||||
ports:
|
||||
- 9308:9308
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
@@ -696,7 +742,7 @@ jobs:
|
||||
stack-version: 9.0.1
|
||||
security-enabled: false
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
@@ -759,7 +805,7 @@ jobs:
|
||||
ports:
|
||||
- 9308:9308
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
@@ -800,7 +846,7 @@ jobs:
|
||||
CARGO_TERM_COLOR: always
|
||||
MIRIFLAGS: -Zmiri-backtrace=full -Zmiri-tree-borrows
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
@@ -828,7 +874,7 @@ jobs:
|
||||
RUST_BACKTRACE: full
|
||||
CARGO_TERM_COLOR: always
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
@@ -852,7 +898,7 @@ jobs:
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
@@ -891,7 +937,7 @@ jobs:
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Rust
|
||||
uses: ./.github/actions/build-rust
|
||||
with:
|
||||
@@ -914,7 +960,7 @@ jobs:
|
||||
run-api: ${{ steps.decision.outputs.run_api }}
|
||||
run-e2e: ${{ steps.decision.outputs.run_e2e }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: copilot-filter
|
||||
@@ -983,7 +1029,7 @@ jobs:
|
||||
ports:
|
||||
- 9308:9308
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
@@ -1056,7 +1102,7 @@ jobs:
|
||||
ports:
|
||||
- 9308:9308
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
@@ -1139,7 +1185,7 @@ jobs:
|
||||
ports:
|
||||
- 9308:9308
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
@@ -1220,7 +1266,7 @@ jobs:
|
||||
test: true,
|
||||
}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
timeout-minutes: 10
|
||||
|
||||
@@ -10,7 +10,7 @@ jobs:
|
||||
env:
|
||||
CARGO_PROFILE_RELEASE_DEBUG: '1'
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
with:
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
ports:
|
||||
- 9308:9308
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
ports:
|
||||
- 9308:9308
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: ./.github/actions/setup-node
|
||||
@@ -167,7 +167,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
name: Post test result message
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Node.js
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.action != 'edited' || github.event.changes.title != null }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
|
||||
@@ -35,7 +35,7 @@ jobs:
|
||||
- build-images
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Deploy to ${{ inputs.build-type }}
|
||||
uses: ./.github/actions/deploy
|
||||
with:
|
||||
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
SENTRY_DSN: ${{ secrets.SENTRY_DSN }}
|
||||
SENTRY_RELEASE: ${{ inputs.app_version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
@@ -178,14 +178,14 @@ jobs:
|
||||
mv packages/frontend/apps/electron/out/*/make/deb/${{ inputs.arch }}/*.deb ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-linux-${{ inputs.arch }}.deb
|
||||
mv packages/frontend/apps/electron/out/*/make/flatpak/*/*.flatpak ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-linux-${{ inputs.arch }}.flatpak
|
||||
|
||||
- uses: actions/attest-build-provenance@v4
|
||||
- uses: actions/attest-build-provenance@v2
|
||||
if: ${{ inputs.platform == 'darwin' }}
|
||||
with:
|
||||
subject-path: |
|
||||
./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-macos-${{ inputs.arch }}.zip
|
||||
./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-macos-${{ inputs.arch }}.dmg
|
||||
|
||||
- uses: actions/attest-build-provenance@v4
|
||||
- uses: actions/attest-build-provenance@v2
|
||||
if: ${{ inputs.platform == 'linux' }}
|
||||
with:
|
||||
subject-path: |
|
||||
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ${{ inputs.build-type }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
FILES_TO_BE_SIGNED_x64: ${{ steps.get_files_to_be_signed.outputs.FILES_TO_BE_SIGNED_x64 }}
|
||||
FILES_TO_BE_SIGNED_arm64: ${{ steps.get_files_to_be_signed.outputs.FILES_TO_BE_SIGNED_arm64 }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -344,7 +344,7 @@ jobs:
|
||||
mv packages/frontend/apps/electron/out/*/make/squirrel.windows/${{ matrix.spec.arch }}/*.exe ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-windows-${{ matrix.spec.arch }}.exe
|
||||
mv packages/frontend/apps/electron/out/*/make/nsis.windows/${{ matrix.spec.arch }}/*.exe ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-windows-${{ matrix.spec.arch }}.nsis.exe
|
||||
|
||||
- uses: actions/attest-build-provenance@v4
|
||||
- uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-path: |
|
||||
./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-windows-${{ matrix.spec.arch }}.zip
|
||||
@@ -369,7 +369,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Download Artifacts (macos-x64)
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ${{ inputs.build-type }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -54,7 +54,7 @@ jobs:
|
||||
build-android-web:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
needs:
|
||||
- build-ios-web
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
@@ -147,7 +147,7 @@ jobs:
|
||||
needs:
|
||||
- build-android-web
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Version
|
||||
uses: ./.github/actions/setup-version
|
||||
with:
|
||||
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
GIT_SHORT_HASH: ${{ steps.prepare.outputs.GIT_SHORT_HASH }}
|
||||
BUILD_TYPE: ${{ steps.prepare.outputs.BUILD_TYPE }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Prepare Release
|
||||
id: prepare
|
||||
uses: ./.github/actions/prepare-release
|
||||
@@ -72,7 +72,7 @@ jobs:
|
||||
steps:
|
||||
- name: Decide whether to release
|
||||
id: decide
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const buildType = '${{ needs.prepare.outputs.BUILD_TYPE }}'
|
||||
|
||||
Generated
+29
-488
@@ -135,10 +135,12 @@ dependencies = [
|
||||
"napi-derive",
|
||||
"once_cell",
|
||||
"serde_json",
|
||||
"sha3",
|
||||
"sqlx",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"uuid",
|
||||
"y-octo",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -181,7 +183,6 @@ dependencies = [
|
||||
"chrono",
|
||||
"file-format",
|
||||
"infer",
|
||||
"llm_adapter",
|
||||
"mimalloc",
|
||||
"mp4parse",
|
||||
"napi",
|
||||
@@ -189,8 +190,6 @@ dependencies = [
|
||||
"napi-derive",
|
||||
"rand 0.9.2",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha3",
|
||||
"tiktoken-rs",
|
||||
"tokio",
|
||||
@@ -248,7 +247,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed7572b7ba83a31e20d1b48970ee402d2e3e0537dcfe0a3ff4d6eb7508617d43"
|
||||
dependencies = [
|
||||
"alsa-sys",
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
@@ -461,12 +460,6 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "auto_enums"
|
||||
version = "0.8.7"
|
||||
@@ -485,28 +478,6 @@ version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-rs"
|
||||
version = "1.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9"
|
||||
dependencies = [
|
||||
"aws-lc-sys",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-sys"
|
||||
version = "0.37.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cmake",
|
||||
"dunce",
|
||||
"fs_extra",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
@@ -564,7 +535,7 @@ version = "0.72.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.13.0",
|
||||
@@ -614,9 +585,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.11.0"
|
||||
version = "2.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
|
||||
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
@@ -935,15 +906,6 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.57"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.4"
|
||||
@@ -1023,7 +985,7 @@ version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"core-foundation",
|
||||
"core-graphics-types",
|
||||
"foreign-types",
|
||||
@@ -1036,7 +998,7 @@ version = "0.25.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "064badf302c3194842cf2c5d61f56cc88e54a759313879cdf03abdd27d0c3b97"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"core-foundation",
|
||||
"core-graphics-types",
|
||||
"foreign-types",
|
||||
@@ -1049,7 +1011,7 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"core-foundation",
|
||||
"libc",
|
||||
]
|
||||
@@ -1419,7 +1381,7 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"block2",
|
||||
"libc",
|
||||
"objc2",
|
||||
@@ -1482,12 +1444,6 @@ version = "0.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5"
|
||||
|
||||
[[package]]
|
||||
name = "dunce"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
|
||||
|
||||
[[package]]
|
||||
name = "ecb"
|
||||
version = "0.1.2"
|
||||
@@ -1710,12 +1666,6 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs_extra"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "futf"
|
||||
version = "0.1.5"
|
||||
@@ -1880,11 +1830,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasip2",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1934,7 +1882,7 @@ version = "1.41.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0c43e7c3212bd992c11b6b9796563388170950521ae8487f5cdf6f6e792f1c8"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
@@ -2057,105 +2005,6 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body-util"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http",
|
||||
"http-body",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httparse"
|
||||
version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"http",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"smallvec",
|
||||
"tokio",
|
||||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-rustls"
|
||||
version = "0.27.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
|
||||
dependencies = [
|
||||
"http",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"ipnet",
|
||||
"libc",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.64"
|
||||
@@ -2406,22 +2255,6 @@ dependencies = [
|
||||
"leaky-cow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
|
||||
|
||||
[[package]]
|
||||
name = "iri-string"
|
||||
version = "0.7.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.17"
|
||||
@@ -2545,9 +2378,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "keccak"
|
||||
version = "0.1.6"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653"
|
||||
checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654"
|
||||
dependencies = [
|
||||
"cpufeatures",
|
||||
]
|
||||
@@ -2659,7 +2492,7 @@ version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"libc",
|
||||
"redox_syscall 0.7.0",
|
||||
]
|
||||
@@ -2687,19 +2520,6 @@ version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
|
||||
|
||||
[[package]]
|
||||
name = "llm_adapter"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8dd9a548766bccf8b636695e8d514edee672d180e96a16ab932c971783b4e353"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.14"
|
||||
@@ -2737,7 +2557,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59fa2559e99ba0f26a12458aabc754432c805bbb8cba516c427825a997af1fb7"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"cbc",
|
||||
"ecb",
|
||||
"encoding_rs",
|
||||
@@ -2765,12 +2585,6 @@ dependencies = [
|
||||
"hashbrown 0.16.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-slab"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
||||
|
||||
[[package]]
|
||||
name = "mac"
|
||||
version = "0.1.1"
|
||||
@@ -2929,7 +2743,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "000f205daae6646003fdc38517be6232af2b150bad4b67bdaf4c5aadb119d738"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"chrono",
|
||||
"ctor",
|
||||
"futures",
|
||||
@@ -2989,7 +2803,7 @@ version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"jni-sys",
|
||||
"log",
|
||||
"ndk-sys",
|
||||
@@ -3024,7 +2838,7 @@ version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"cfg-if",
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
@@ -3188,7 +3002,7 @@ version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"dispatch2",
|
||||
"objc2",
|
||||
]
|
||||
@@ -3205,7 +3019,7 @@ version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"block2",
|
||||
"libc",
|
||||
"objc2",
|
||||
@@ -3262,12 +3076,6 @@ version = "11.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
||||
|
||||
[[package]]
|
||||
name = "ordered-float"
|
||||
version = "5.1.0"
|
||||
@@ -3663,7 +3471,7 @@ checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40"
|
||||
dependencies = [
|
||||
"bit-set 0.8.0",
|
||||
"bit-vec 0.8.0",
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"num-traits",
|
||||
"rand 0.9.2",
|
||||
"rand_chacha 0.9.0",
|
||||
@@ -3691,7 +3499,7 @@ version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"getopts",
|
||||
"memchr",
|
||||
"pulldown-cmark-escape",
|
||||
@@ -3710,62 +3518,6 @@ version = "1.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"cfg_aliases",
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls",
|
||||
"socket2",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.11.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"bytes",
|
||||
"getrandom 0.3.4",
|
||||
"lru-slab",
|
||||
"rand 0.9.2",
|
||||
"ring",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"slab",
|
||||
"thiserror 2.0.17",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
|
||||
dependencies = [
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2",
|
||||
"tracing",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.43"
|
||||
@@ -3913,7 +3665,7 @@ version = "0.5.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3922,7 +3674,7 @@ version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3954,45 +3706,6 @@ version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper-util",
|
||||
"js-sys",
|
||||
"log",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"quinn",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"rustls-platform-verifier",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.14"
|
||||
@@ -4120,7 +3833,7 @@ version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
@@ -4133,7 +3846,6 @@ version = "0.23.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
@@ -4142,62 +3854,21 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pki-types",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282"
|
||||
dependencies = [
|
||||
"web-time",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784"
|
||||
dependencies = [
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"jni",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
"rustls-platform-verifier-android",
|
||||
"rustls-webpki",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"webpki-root-certs",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-platform-verifier-android"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.103.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
@@ -4236,15 +3907,6 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scoped-tls"
|
||||
version = "1.0.1"
|
||||
@@ -4293,29 +3955,6 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "3.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.27"
|
||||
@@ -4632,7 +4271,7 @@ checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64",
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -4675,7 +4314,7 @@ checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64",
|
||||
"bitflags 2.11.0",
|
||||
"bitflags 2.10.0",
|
||||
"byteorder",
|
||||
"chrono",
|
||||
"crc",
|
||||
@@ -5041,15 +4680,6 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sync_wrapper"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.13.2"
|
||||
@@ -5241,16 +4871,6 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.18"
|
||||
@@ -5294,58 +4914,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_parser"
|
||||
version = "1.0.9+spec-1.1.0"
|
||||
version = "1.0.6+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4"
|
||||
checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44"
|
||||
dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"iri-string",
|
||||
"pin-project-lite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
|
||||
|
||||
[[package]]
|
||||
name = "tracing"
|
||||
version = "0.1.44"
|
||||
@@ -5548,12 +5123,6 @@ dependencies = [
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "try-lock"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "type1-encoding-parser"
|
||||
version = "0.1.0"
|
||||
@@ -5874,15 +5443,6 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
|
||||
dependencies = [
|
||||
"try-lock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.1+wasi-snapshot-preview1"
|
||||
@@ -5972,25 +5532,6 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-time"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-root-certs"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.11"
|
||||
|
||||
@@ -44,7 +44,6 @@ resolver = "3"
|
||||
lasso = { version = "0.7", features = ["multi-threaded"] }
|
||||
lib0 = { version = "0.16", features = ["lib0-serde"] }
|
||||
libc = "0.2"
|
||||
llm_adapter = "0.1.1"
|
||||
log = "0.4"
|
||||
loom = { version = "0.7", features = ["checkpoint"] }
|
||||
lru = "0.16"
|
||||
|
||||
@@ -108,9 +108,7 @@ export class BookmarkBlockComponent extends CaptionedBlockComponent<BookmarkBloc
|
||||
}
|
||||
|
||||
open = () => {
|
||||
const link = this.link;
|
||||
if (!link) return;
|
||||
window.open(link, '_blank', 'noopener,noreferrer');
|
||||
window.open(this.link, '_blank');
|
||||
};
|
||||
|
||||
refreshData = () => {
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
"@blocksuite/sync": "workspace:*",
|
||||
"@floating-ui/dom": "^1.6.13",
|
||||
"@lit/context": "^1.1.2",
|
||||
"@lottiefiles/dotlottie-wc": "^0.9.4",
|
||||
"@lottiefiles/dotlottie-wc": "^0.5.0",
|
||||
"@preact/signals-core": "^1.8.0",
|
||||
"@toeverything/theme": "^1.1.23",
|
||||
"@types/hast": "^3.0.4",
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
import {
|
||||
getHostName,
|
||||
isValidUrl,
|
||||
normalizeUrl,
|
||||
} from '@blocksuite/affine-shared/utils';
|
||||
import { getHostName } from '@blocksuite/affine-shared/utils';
|
||||
import { PropTypes, requiredProperties } from '@blocksuite/std';
|
||||
import { css, LitElement } from 'lit';
|
||||
import { property } from 'lit/decorators.js';
|
||||
@@ -48,27 +44,15 @@ export class LinkPreview extends LitElement {
|
||||
|
||||
override render() {
|
||||
const { url } = this;
|
||||
const normalizedUrl = normalizeUrl(url);
|
||||
const safeUrl =
|
||||
normalizedUrl && isValidUrl(normalizedUrl) ? normalizedUrl : null;
|
||||
const hostName = getHostName(safeUrl ?? url);
|
||||
|
||||
if (!safeUrl) {
|
||||
return html`
|
||||
<span class="affine-link-preview">
|
||||
<span>${hostName}</span>
|
||||
</span>
|
||||
`;
|
||||
}
|
||||
|
||||
return html`
|
||||
<a
|
||||
class="affine-link-preview"
|
||||
rel="noopener noreferrer"
|
||||
target="_blank"
|
||||
href=${safeUrl}
|
||||
href=${url}
|
||||
>
|
||||
<span>${hostName}</span>
|
||||
<span>${getHostName(url)}</span>
|
||||
</a>
|
||||
`;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import type { FootNote } from '@blocksuite/affine-model';
|
||||
import { CitationProvider } from '@blocksuite/affine-shared/services';
|
||||
import { unsafeCSSVarV2 } from '@blocksuite/affine-shared/theme';
|
||||
import type { AffineTextAttributes } from '@blocksuite/affine-shared/types';
|
||||
import { isValidUrl, normalizeUrl } from '@blocksuite/affine-shared/utils';
|
||||
import { WithDisposable } from '@blocksuite/global/lit';
|
||||
import {
|
||||
BlockSelection,
|
||||
@@ -153,9 +152,7 @@ export class AffineFootnoteNode extends WithDisposable(ShadowlessElement) {
|
||||
};
|
||||
|
||||
private readonly _handleUrlReference = (url: string) => {
|
||||
const normalizedUrl = normalizeUrl(url);
|
||||
if (!normalizedUrl || !isValidUrl(normalizedUrl)) return;
|
||||
window.open(normalizedUrl, '_blank', 'noopener,noreferrer');
|
||||
window.open(url, '_blank');
|
||||
};
|
||||
|
||||
private readonly _updateFootnoteAttributes = (footnote: FootNote) => {
|
||||
|
||||
@@ -24,11 +24,6 @@ const toURL = (str: string) => {
|
||||
}
|
||||
};
|
||||
|
||||
const hasAllowedScheme = (url: URL) => {
|
||||
const protocol = url.protocol.slice(0, -1).toLowerCase();
|
||||
return ALLOWED_SCHEMES.has(protocol);
|
||||
};
|
||||
|
||||
function resolveURL(str: string, baseUrl: string, padded = false) {
|
||||
const url = toURL(str);
|
||||
if (!url) return null;
|
||||
@@ -66,7 +61,6 @@ export function normalizeUrl(str: string) {
|
||||
|
||||
// Formatted
|
||||
if (url) {
|
||||
if (!hasAllowedScheme(url)) return '';
|
||||
if (!str.endsWith('/') && url.href.endsWith('/')) {
|
||||
return url.href.substring(0, url.href.length - 1);
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"@blocksuite/icons": "^2.2.17",
|
||||
"@floating-ui/dom": "^1.6.13",
|
||||
"@lit/context": "^1.1.3",
|
||||
"@lottiefiles/dotlottie-wc": "^0.9.4",
|
||||
"@lottiefiles/dotlottie-wc": "^0.5.0",
|
||||
"@preact/signals-core": "^1.8.0",
|
||||
"@toeverything/theme": "^1.1.23",
|
||||
"@vanilla-extract/css": "^1.17.0",
|
||||
|
||||
+2
-2
@@ -22,7 +22,7 @@
|
||||
"af": "r affine.ts",
|
||||
"dev": "yarn affine dev",
|
||||
"build": "yarn affine build",
|
||||
"lint:eslint": "cross-env NODE_OPTIONS=\"--max-old-space-size=16384\" eslint --report-unused-disable-directives-severity=off . --cache",
|
||||
"lint:eslint": "cross-env NODE_OPTIONS=\"--max-old-space-size=8192\" eslint --report-unused-disable-directives-severity=off . --cache",
|
||||
"lint:eslint:fix": "yarn lint:eslint --fix --fix-type problem,suggestion,layout",
|
||||
"lint:prettier": "prettier --ignore-unknown --cache --check .",
|
||||
"lint:prettier:fix": "prettier --ignore-unknown --cache --write .",
|
||||
@@ -56,7 +56,7 @@
|
||||
"@faker-js/faker": "^10.1.0",
|
||||
"@istanbuljs/schema": "^0.1.3",
|
||||
"@magic-works/i18n-codegen": "^0.6.1",
|
||||
"@playwright/test": "=1.58.2",
|
||||
"@playwright/test": "=1.52.0",
|
||||
"@smarttools/eslint-plugin-rxjs": "^1.0.8",
|
||||
"@taplo/cli": "^0.7.0",
|
||||
"@toeverything/infra": "workspace:*",
|
||||
|
||||
@@ -17,13 +17,10 @@ affine_common = { workspace = true, features = [
|
||||
chrono = { workspace = true }
|
||||
file-format = { workspace = true }
|
||||
infer = { workspace = true }
|
||||
llm_adapter = { workspace = true }
|
||||
mp4parse = { workspace = true }
|
||||
napi = { workspace = true, features = ["async"] }
|
||||
napi-derive = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
sha3 = { workspace = true }
|
||||
tiktoken-rs = { workspace = true }
|
||||
v_htmlescape = { workspace = true }
|
||||
|
||||
Vendored
-8
@@ -1,9 +1,5 @@
|
||||
/* auto-generated by NAPI-RS */
|
||||
/* eslint-disable */
|
||||
export declare class LlmStreamHandle {
|
||||
abort(): void
|
||||
}
|
||||
|
||||
export declare class Tokenizer {
|
||||
count(content: string, allowedSpecial?: Array<string> | undefined | null): number
|
||||
}
|
||||
@@ -50,10 +46,6 @@ export declare function getMime(input: Uint8Array): string
|
||||
|
||||
export declare function htmlSanitize(input: string): string
|
||||
|
||||
export declare function llmDispatch(protocol: string, backendConfigJson: string, requestJson: string): string
|
||||
|
||||
export declare function llmDispatchStream(protocol: string, backendConfigJson: string, requestJson: string, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle
|
||||
|
||||
/**
|
||||
* Merge updates in form like `Y.applyUpdate(doc, update)` way and return the
|
||||
* result binary.
|
||||
|
||||
@@ -7,7 +7,6 @@ pub mod doc_loader;
|
||||
pub mod file_type;
|
||||
pub mod hashcash;
|
||||
pub mod html_sanitize;
|
||||
pub mod llm;
|
||||
pub mod tiktoken;
|
||||
|
||||
use affine_common::napi_utils::map_napi_err;
|
||||
|
||||
@@ -1,339 +0,0 @@
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
|
||||
use llm_adapter::{
|
||||
backend::{
|
||||
BackendConfig, BackendError, BackendProtocol, ReqwestHttpClient, dispatch_request, dispatch_stream_events_with,
|
||||
},
|
||||
core::{CoreRequest, StreamEvent},
|
||||
middleware::{
|
||||
MiddlewareConfig, PipelineContext, RequestMiddleware, StreamMiddleware, citation_indexing, clamp_max_tokens,
|
||||
normalize_messages, run_request_middleware_chain, run_stream_middleware_chain, stream_event_normalize,
|
||||
tool_schema_rewrite,
|
||||
},
|
||||
};
|
||||
use napi::{
|
||||
Error, Result, Status,
|
||||
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const STREAM_END_MARKER: &str = "__AFFINE_LLM_STREAM_END__";
|
||||
const STREAM_ABORTED_REASON: &str = "__AFFINE_LLM_STREAM_ABORTED__";
|
||||
const STREAM_CALLBACK_DISPATCH_FAILED_REASON: &str = "__AFFINE_LLM_STREAM_CALLBACK_DISPATCH_FAILED__";
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize)]
|
||||
#[serde(default)]
|
||||
struct LlmMiddlewarePayload {
|
||||
request: Vec<String>,
|
||||
stream: Vec<String>,
|
||||
config: MiddlewareConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct LlmDispatchPayload {
|
||||
#[serde(flatten)]
|
||||
request: CoreRequest,
|
||||
#[serde(default)]
|
||||
middleware: LlmMiddlewarePayload,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct LlmStreamHandle {
|
||||
aborted: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl LlmStreamHandle {
|
||||
#[napi]
|
||||
pub fn abort(&self) {
|
||||
self.aborted.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub fn llm_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result<String> {
|
||||
let protocol = parse_protocol(&protocol)?;
|
||||
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
|
||||
let payload: LlmDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
|
||||
let request = apply_request_middlewares(payload.request, &payload.middleware)?;
|
||||
|
||||
let response =
|
||||
dispatch_request(&ReqwestHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?;
|
||||
|
||||
serde_json::to_string(&response).map_err(map_json_error)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub fn llm_dispatch_stream(
|
||||
protocol: String,
|
||||
backend_config_json: String,
|
||||
request_json: String,
|
||||
callback: ThreadsafeFunction<String, ()>,
|
||||
) -> Result<LlmStreamHandle> {
|
||||
let protocol = parse_protocol(&protocol)?;
|
||||
let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?;
|
||||
let payload: LlmDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?;
|
||||
let request = apply_request_middlewares(payload.request, &payload.middleware)?;
|
||||
let middleware = payload.middleware.clone();
|
||||
|
||||
let aborted = Arc::new(AtomicBool::new(false));
|
||||
let aborted_in_worker = aborted.clone();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
let chain = match resolve_stream_chain(&middleware.stream) {
|
||||
Ok(chain) => chain,
|
||||
Err(error) => {
|
||||
emit_error_event(&callback, error.reason.clone(), "middleware_error");
|
||||
let _ = callback.call(
|
||||
Ok(STREAM_END_MARKER.to_string()),
|
||||
ThreadsafeFunctionCallMode::NonBlocking,
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let mut pipeline = StreamPipeline::new(chain, middleware.config.clone());
|
||||
let mut aborted_by_user = false;
|
||||
let mut callback_dispatch_failed = false;
|
||||
|
||||
let result = dispatch_stream_events_with(&ReqwestHttpClient::default(), &config, protocol, &request, |event| {
|
||||
if aborted_in_worker.load(Ordering::Relaxed) {
|
||||
aborted_by_user = true;
|
||||
return Err(BackendError::Http(STREAM_ABORTED_REASON.to_string()));
|
||||
}
|
||||
|
||||
for event in pipeline.process(event) {
|
||||
let status = emit_stream_event(&callback, &event);
|
||||
if status != Status::Ok {
|
||||
callback_dispatch_failed = true;
|
||||
return Err(BackendError::Http(format!(
|
||||
"{STREAM_CALLBACK_DISPATCH_FAILED_REASON}:{status}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
if !aborted_by_user {
|
||||
for event in pipeline.finish() {
|
||||
if aborted_in_worker.load(Ordering::Relaxed) {
|
||||
aborted_by_user = true;
|
||||
break;
|
||||
}
|
||||
if emit_stream_event(&callback, &event) != Status::Ok {
|
||||
callback_dispatch_failed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(error) = result
|
||||
&& !aborted_by_user
|
||||
&& !callback_dispatch_failed
|
||||
&& !is_abort_error(&error)
|
||||
&& !is_callback_dispatch_failed_error(&error)
|
||||
{
|
||||
emit_error_event(&callback, error.to_string(), "dispatch_error");
|
||||
}
|
||||
|
||||
if !callback_dispatch_failed {
|
||||
let _ = callback.call(
|
||||
Ok(STREAM_END_MARKER.to_string()),
|
||||
ThreadsafeFunctionCallMode::NonBlocking,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(LlmStreamHandle { aborted })
|
||||
}
|
||||
|
||||
fn apply_request_middlewares(request: CoreRequest, middleware: &LlmMiddlewarePayload) -> Result<CoreRequest> {
|
||||
let chain = resolve_request_chain(&middleware.request)?;
|
||||
Ok(run_request_middleware_chain(request, &middleware.config, &chain))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamPipeline {
|
||||
chain: Vec<StreamMiddleware>,
|
||||
config: MiddlewareConfig,
|
||||
context: PipelineContext,
|
||||
}
|
||||
|
||||
impl StreamPipeline {
|
||||
fn new(chain: Vec<StreamMiddleware>, config: MiddlewareConfig) -> Self {
|
||||
Self {
|
||||
chain,
|
||||
config,
|
||||
context: PipelineContext::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&mut self, event: StreamEvent) -> Vec<StreamEvent> {
|
||||
run_stream_middleware_chain(event, &mut self.context, &self.config, &self.chain)
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> Vec<StreamEvent> {
|
||||
self.context.flush_pending_deltas();
|
||||
self.context.drain_queued_events()
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_stream_event(callback: &ThreadsafeFunction<String, ()>, event: &StreamEvent) -> Status {
|
||||
let value = serde_json::to_string(event).unwrap_or_else(|error| {
|
||||
serde_json::json!({
|
||||
"type": "error",
|
||||
"message": format!("failed to serialize stream event: {error}"),
|
||||
})
|
||||
.to_string()
|
||||
});
|
||||
|
||||
callback.call(Ok(value), ThreadsafeFunctionCallMode::NonBlocking)
|
||||
}
|
||||
|
||||
fn emit_error_event(callback: &ThreadsafeFunction<String, ()>, message: String, code: &str) {
|
||||
let error_event = serde_json::to_string(&StreamEvent::Error {
|
||||
message: message.clone(),
|
||||
code: Some(code.to_string()),
|
||||
})
|
||||
.unwrap_or_else(|_| {
|
||||
serde_json::json!({
|
||||
"type": "error",
|
||||
"message": message,
|
||||
"code": code,
|
||||
})
|
||||
.to_string()
|
||||
});
|
||||
|
||||
let _ = callback.call(Ok(error_event), ThreadsafeFunctionCallMode::NonBlocking);
|
||||
}
|
||||
|
||||
fn is_abort_error(error: &BackendError) -> bool {
|
||||
matches!(
|
||||
error,
|
||||
BackendError::Http(reason) if reason == STREAM_ABORTED_REASON
|
||||
)
|
||||
}
|
||||
|
||||
fn is_callback_dispatch_failed_error(error: &BackendError) -> bool {
|
||||
matches!(
|
||||
error,
|
||||
BackendError::Http(reason) if reason.starts_with(STREAM_CALLBACK_DISPATCH_FAILED_REASON)
|
||||
)
|
||||
}
|
||||
|
||||
fn resolve_request_chain(request: &[String]) -> Result<Vec<RequestMiddleware>> {
|
||||
if request.is_empty() {
|
||||
return Ok(vec![normalize_messages, tool_schema_rewrite]);
|
||||
}
|
||||
|
||||
request
|
||||
.iter()
|
||||
.map(|name| match name.as_str() {
|
||||
"normalize_messages" => Ok(normalize_messages as RequestMiddleware),
|
||||
"clamp_max_tokens" => Ok(clamp_max_tokens as RequestMiddleware),
|
||||
"tool_schema_rewrite" => Ok(tool_schema_rewrite as RequestMiddleware),
|
||||
_ => Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("Unsupported request middleware: {name}"),
|
||||
)),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn resolve_stream_chain(stream: &[String]) -> Result<Vec<StreamMiddleware>> {
|
||||
if stream.is_empty() {
|
||||
return Ok(vec![stream_event_normalize, citation_indexing]);
|
||||
}
|
||||
|
||||
stream
|
||||
.iter()
|
||||
.map(|name| match name.as_str() {
|
||||
"stream_event_normalize" => Ok(stream_event_normalize as StreamMiddleware),
|
||||
"citation_indexing" => Ok(citation_indexing as StreamMiddleware),
|
||||
_ => Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("Unsupported stream middleware: {name}"),
|
||||
)),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn parse_protocol(protocol: &str) -> Result<BackendProtocol> {
|
||||
match protocol {
|
||||
"openai_chat" | "openai-chat" | "openai_chat_completions" | "chat-completions" | "chat_completions" => {
|
||||
Ok(BackendProtocol::OpenaiChatCompletions)
|
||||
}
|
||||
"openai_responses" | "openai-responses" | "responses" => Ok(BackendProtocol::OpenaiResponses),
|
||||
"anthropic" | "anthropic_messages" | "anthropic-messages" => Ok(BackendProtocol::AnthropicMessages),
|
||||
other => Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("Unsupported llm backend protocol: {other}"),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_json_error(error: serde_json::Error) -> Error {
|
||||
Error::new(Status::InvalidArg, format!("Invalid JSON payload: {error}"))
|
||||
}
|
||||
|
||||
fn map_backend_error(error: BackendError) -> Error {
|
||||
Error::new(Status::GenericFailure, error.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_parse_supported_protocol_aliases() {
|
||||
assert!(parse_protocol("openai_chat").is_ok());
|
||||
assert!(parse_protocol("chat-completions").is_ok());
|
||||
assert!(parse_protocol("responses").is_ok());
|
||||
assert!(parse_protocol("anthropic").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_reject_unsupported_protocol() {
|
||||
let error = parse_protocol("unknown").unwrap_err();
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Unsupported llm backend protocol"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_dispatch_should_reject_invalid_backend_json() {
|
||||
let error = llm_dispatch("openai_chat".to_string(), "{".to_string(), "{}".to_string()).unwrap_err();
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Invalid JSON payload"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_json_error_should_use_invalid_arg_status() {
|
||||
let parse_error = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
|
||||
let error = map_json_error(parse_error);
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Invalid JSON payload"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_request_chain_should_support_clamp_max_tokens() {
|
||||
let chain = resolve_request_chain(&["normalize_messages".to_string(), "clamp_max_tokens".to_string()]).unwrap();
|
||||
assert_eq!(chain.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_request_chain_should_reject_unknown_middleware() {
|
||||
let error = resolve_request_chain(&["unknown".to_string()]).unwrap_err();
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Unsupported request middleware"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_stream_chain_should_reject_unknown_middleware() {
|
||||
let error = resolve_stream_chain(&["unknown".to_string()]).unwrap_err();
|
||||
assert_eq!(error.status, Status::InvalidArg);
|
||||
assert!(error.reason.contains("Unsupported stream middleware"));
|
||||
}
|
||||
}
|
||||
@@ -12,9 +12,9 @@
|
||||
"dev": "nodemon ./src/index.ts",
|
||||
"dev:mail": "email dev -d src/mails",
|
||||
"test": "ava --concurrency 1 --serial",
|
||||
"test:copilot": "ava \"src/__tests__/copilot/copilot-*.spec.ts\"",
|
||||
"test:copilot": "ava \"src/__tests__/copilot-*.spec.ts\"",
|
||||
"test:coverage": "c8 ava --concurrency 1 --serial",
|
||||
"test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot/copilot-*.spec.ts\"",
|
||||
"test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot-*.spec.ts\"",
|
||||
"e2e": "cross-env TEST_MODE=e2e ava --serial",
|
||||
"e2e:coverage": "cross-env TEST_MODE=e2e c8 ava --serial",
|
||||
"data-migration": "cross-env NODE_ENV=development SERVER_FLAVOR=script r ./src/index.ts",
|
||||
@@ -28,8 +28,12 @@
|
||||
"dependencies": {
|
||||
"@affine/s3-compat": "workspace:*",
|
||||
"@affine/server-native": "workspace:*",
|
||||
"@ai-sdk/anthropic": "^2.0.54",
|
||||
"@ai-sdk/google": "^2.0.45",
|
||||
"@ai-sdk/google-vertex": "^3.0.88",
|
||||
"@ai-sdk/openai": "^2.0.80",
|
||||
"@ai-sdk/openai-compatible": "^1.0.28",
|
||||
"@ai-sdk/perplexity": "^2.0.21",
|
||||
"@apollo/server": "^4.13.0",
|
||||
"@fal-ai/serverless-client": "^0.15.0",
|
||||
"@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0",
|
||||
@@ -51,18 +55,18 @@
|
||||
"@node-rs/crc32": "^1.10.6",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/core": "^2.2.0",
|
||||
"@opentelemetry/exporter-prometheus": "^0.212.0",
|
||||
"@opentelemetry/exporter-prometheus": "^0.211.0",
|
||||
"@opentelemetry/exporter-zipkin": "^2.2.0",
|
||||
"@opentelemetry/host-metrics": "^0.38.0",
|
||||
"@opentelemetry/instrumentation": "^0.212.0",
|
||||
"@opentelemetry/instrumentation-graphql": "^0.60.0",
|
||||
"@opentelemetry/instrumentation-http": "^0.212.0",
|
||||
"@opentelemetry/instrumentation-ioredis": "^0.60.0",
|
||||
"@opentelemetry/instrumentation-nestjs-core": "^0.58.0",
|
||||
"@opentelemetry/instrumentation-socket.io": "^0.59.0",
|
||||
"@opentelemetry/instrumentation": "^0.211.0",
|
||||
"@opentelemetry/instrumentation-graphql": "^0.58.0",
|
||||
"@opentelemetry/instrumentation-http": "^0.211.0",
|
||||
"@opentelemetry/instrumentation-ioredis": "^0.59.0",
|
||||
"@opentelemetry/instrumentation-nestjs-core": "^0.57.0",
|
||||
"@opentelemetry/instrumentation-socket.io": "^0.57.0",
|
||||
"@opentelemetry/resources": "^2.2.0",
|
||||
"@opentelemetry/sdk-metrics": "^2.2.0",
|
||||
"@opentelemetry/sdk-node": "^0.212.0",
|
||||
"@opentelemetry/sdk-node": "^0.211.0",
|
||||
"@opentelemetry/sdk-trace-node": "^2.2.0",
|
||||
"@opentelemetry/semantic-conventions": "^1.38.0",
|
||||
"@prisma/client": "^6.6.0",
|
||||
@@ -122,6 +126,7 @@
|
||||
"@faker-js/faker": "^10.1.0",
|
||||
"@nestjs/swagger": "^11.2.0",
|
||||
"@nestjs/testing": "patch:@nestjs/testing@npm%3A10.4.15#~/.yarn/patches/@nestjs-testing-npm-10.4.15-d591a1705a.patch",
|
||||
"@react-email/preview-server": "^4.3.2",
|
||||
"@types/cookie-parser": "^1.4.8",
|
||||
"@types/express": "^5.0.1",
|
||||
"@types/express-serve-static-core": "^5.0.6",
|
||||
@@ -137,7 +142,7 @@
|
||||
"@types/react-dom": "^19.0.2",
|
||||
"@types/semver": "^7.5.8",
|
||||
"@types/sinon": "^21.0.0",
|
||||
"@types/supertest": "^7.0.0",
|
||||
"@types/supertest": "^6.0.2",
|
||||
"ava": "^6.4.0",
|
||||
"c8": "^10.1.3",
|
||||
"nodemon": "^3.1.14",
|
||||
|
||||
+4
-4
@@ -12,12 +12,12 @@ Generated by [AVA](https://avajs.dev).
|
||||
{
|
||||
messages: [
|
||||
{
|
||||
content: 'generate text to text stream',
|
||||
content: 'generate text to text',
|
||||
role: 'assistant',
|
||||
},
|
||||
],
|
||||
pinned: false,
|
||||
tokens: 10,
|
||||
tokens: 8,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -27,12 +27,12 @@ Generated by [AVA](https://avajs.dev).
|
||||
{
|
||||
messages: [
|
||||
{
|
||||
content: 'generate text to text stream',
|
||||
content: 'generate text to text',
|
||||
role: 'assistant',
|
||||
},
|
||||
],
|
||||
pinned: false,
|
||||
tokens: 10,
|
||||
tokens: 8,
|
||||
},
|
||||
]
|
||||
|
||||
Binary file not shown.
@@ -43,9 +43,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
> Snapshot 5
|
||||
|
||||
Buffer @Uint8Array [
|
||||
89504e47 0d0a1a0a 0000000d 49484452 00000001 00000001 08040000 00b51c0c
|
||||
02000000 0b494441 5478da63 fcff1f00 03030200 efa37c9f 00000000 49454e44
|
||||
ae426082
|
||||
66616b65 20696d61 6765
|
||||
]
|
||||
|
||||
## should preview link
|
||||
|
||||
Binary file not shown.
+12
-12
@@ -4,31 +4,31 @@ import type { ExecutionContext, TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { ServerFeature, ServerService } from '../../core';
|
||||
import { AuthService } from '../../core/auth';
|
||||
import { QuotaModule } from '../../core/quota';
|
||||
import { Models } from '../../models';
|
||||
import { CopilotModule } from '../../plugins/copilot';
|
||||
import { prompts, PromptService } from '../../plugins/copilot/prompt';
|
||||
import { ServerFeature, ServerService } from '../core';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { QuotaModule } from '../core/quota';
|
||||
import { Models } from '../models';
|
||||
import { CopilotModule } from '../plugins/copilot';
|
||||
import { prompts, PromptService } from '../plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderFactory,
|
||||
CopilotProviderType,
|
||||
StreamObject,
|
||||
StreamObjectSchema,
|
||||
} from '../../plugins/copilot/providers';
|
||||
import { TranscriptionResponseSchema } from '../../plugins/copilot/transcript/types';
|
||||
} from '../plugins/copilot/providers';
|
||||
import { TranscriptionResponseSchema } from '../plugins/copilot/transcript/types';
|
||||
import {
|
||||
CopilotChatTextExecutor,
|
||||
CopilotWorkflowService,
|
||||
GraphExecutorState,
|
||||
} from '../../plugins/copilot/workflow';
|
||||
} from '../plugins/copilot/workflow';
|
||||
import {
|
||||
CopilotChatImageExecutor,
|
||||
CopilotCheckHtmlExecutor,
|
||||
CopilotCheckJsonExecutor,
|
||||
} from '../../plugins/copilot/workflow/executor';
|
||||
import { createTestingModule, TestingModule } from '../utils';
|
||||
import { TestAssets } from '../utils/copilot';
|
||||
} from '../plugins/copilot/workflow/executor';
|
||||
import { createTestingModule, TestingModule } from './utils';
|
||||
import { TestAssets } from './utils/copilot';
|
||||
|
||||
type Tester = {
|
||||
auth: AuthService;
|
||||
+18
-22
@@ -6,25 +6,25 @@ import type { TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { AppModule } from '../../app.module';
|
||||
import { JobQueue } from '../../base';
|
||||
import { ConfigModule } from '../../base/config';
|
||||
import { AuthService } from '../../core/auth';
|
||||
import { DocReader } from '../../core/doc';
|
||||
import { CopilotContextService } from '../../plugins/copilot/context';
|
||||
import { AppModule } from '../app.module';
|
||||
import { JobQueue } from '../base';
|
||||
import { ConfigModule } from '../base/config';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { DocReader } from '../core/doc';
|
||||
import { CopilotContextService } from '../plugins/copilot/context';
|
||||
import {
|
||||
CopilotEmbeddingJob,
|
||||
MockEmbeddingClient,
|
||||
} from '../../plugins/copilot/embedding';
|
||||
import { prompts, PromptService } from '../../plugins/copilot/prompt';
|
||||
} from '../plugins/copilot/embedding';
|
||||
import { prompts, PromptService } from '../plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderFactory,
|
||||
CopilotProviderType,
|
||||
GeminiGenerativeProvider,
|
||||
OpenAIProvider,
|
||||
} from '../../plugins/copilot/providers';
|
||||
import { CopilotStorage } from '../../plugins/copilot/storage';
|
||||
import { MockCopilotProvider } from '../mocks';
|
||||
} from '../plugins/copilot/providers';
|
||||
import { CopilotStorage } from '../plugins/copilot/storage';
|
||||
import { MockCopilotProvider } from './mocks';
|
||||
import {
|
||||
acceptInviteById,
|
||||
createTestingApp,
|
||||
@@ -33,7 +33,7 @@ import {
|
||||
smallestPng,
|
||||
TestingApp,
|
||||
TestUser,
|
||||
} from '../utils';
|
||||
} from './utils';
|
||||
import {
|
||||
addContextDoc,
|
||||
addContextFile,
|
||||
@@ -67,7 +67,7 @@ import {
|
||||
textToEventStream,
|
||||
unsplashSearch,
|
||||
updateCopilotSession,
|
||||
} from '../utils/copilot';
|
||||
} from './utils/copilot';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
auth: AuthService;
|
||||
@@ -513,11 +513,7 @@ test('should be able to chat with api', async t => {
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, sessionId);
|
||||
const ret = await chatWithText(app, sessionId, messageId);
|
||||
t.is(
|
||||
ret,
|
||||
'generate text to text stream',
|
||||
'should be able to chat with text'
|
||||
);
|
||||
t.is(ret, 'generate text to text', 'should be able to chat with text');
|
||||
|
||||
const ret2 = await chatWithTextStream(app, sessionId, messageId);
|
||||
t.is(
|
||||
@@ -661,7 +657,7 @@ test('should be able to retry with api', async t => {
|
||||
const histories = await getHistories(app, { workspaceId: id, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text stream', 'generate text to text stream']],
|
||||
[['generate text to text', 'generate text to text']],
|
||||
'should be able to list history'
|
||||
);
|
||||
}
|
||||
@@ -798,7 +794,7 @@ test('should be able to list history', async t => {
|
||||
const histories = await getHistories(app, { workspaceId, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['hello', 'generate text to text stream']],
|
||||
[['hello', 'generate text to text']],
|
||||
'should be able to list history'
|
||||
);
|
||||
}
|
||||
@@ -811,7 +807,7 @@ test('should be able to list history', async t => {
|
||||
});
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text stream', 'hello']],
|
||||
[['generate text to text', 'hello']],
|
||||
'should be able to list history'
|
||||
);
|
||||
}
|
||||
@@ -862,7 +858,7 @@ test('should reject request that user have not permission', async t => {
|
||||
const histories = await getHistories(app, { workspaceId, docId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text stream']],
|
||||
[['generate text to text']],
|
||||
'should able to list history'
|
||||
);
|
||||
|
||||
+38
-101
@@ -8,38 +8,38 @@ import ava from 'ava';
|
||||
import { nanoid } from 'nanoid';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { EventBus, JobQueue } from '../../base';
|
||||
import { ConfigModule } from '../../base/config';
|
||||
import { AuthService } from '../../core/auth';
|
||||
import { QuotaModule } from '../../core/quota';
|
||||
import { StorageModule, WorkspaceBlobStorage } from '../../core/storage';
|
||||
import { EventBus, JobQueue } from '../base';
|
||||
import { ConfigModule } from '../base/config';
|
||||
import { AuthService } from '../core/auth';
|
||||
import { QuotaModule } from '../core/quota';
|
||||
import { StorageModule, WorkspaceBlobStorage } from '../core/storage';
|
||||
import {
|
||||
ContextCategories,
|
||||
CopilotSessionModel,
|
||||
WorkspaceModel,
|
||||
} from '../../models';
|
||||
import { CopilotModule } from '../../plugins/copilot';
|
||||
import { CopilotContextService } from '../../plugins/copilot/context';
|
||||
import { CopilotCronJobs } from '../../plugins/copilot/cron';
|
||||
} from '../models';
|
||||
import { CopilotModule } from '../plugins/copilot';
|
||||
import { CopilotContextService } from '../plugins/copilot/context';
|
||||
import { CopilotCronJobs } from '../plugins/copilot/cron';
|
||||
import {
|
||||
CopilotEmbeddingJob,
|
||||
MockEmbeddingClient,
|
||||
} from '../../plugins/copilot/embedding';
|
||||
import { prompts, PromptService } from '../../plugins/copilot/prompt';
|
||||
} from '../plugins/copilot/embedding';
|
||||
import { prompts, PromptService } from '../plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderFactory,
|
||||
CopilotProviderType,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
OpenAIProvider,
|
||||
} from '../../plugins/copilot/providers';
|
||||
} from '../plugins/copilot/providers';
|
||||
import {
|
||||
CitationParser,
|
||||
TextStreamParser,
|
||||
} from '../../plugins/copilot/providers/utils';
|
||||
import { ChatSessionService } from '../../plugins/copilot/session';
|
||||
import { CopilotStorage } from '../../plugins/copilot/storage';
|
||||
import { CopilotTranscriptionService } from '../../plugins/copilot/transcript';
|
||||
} from '../plugins/copilot/providers/utils';
|
||||
import { ChatSessionService } from '../plugins/copilot/session';
|
||||
import { CopilotStorage } from '../plugins/copilot/storage';
|
||||
import { CopilotTranscriptionService } from '../plugins/copilot/transcript';
|
||||
import {
|
||||
CopilotChatTextExecutor,
|
||||
CopilotWorkflowService,
|
||||
@@ -48,7 +48,7 @@ import {
|
||||
WorkflowGraphExecutor,
|
||||
type WorkflowNodeData,
|
||||
WorkflowNodeType,
|
||||
} from '../../plugins/copilot/workflow';
|
||||
} from '../plugins/copilot/workflow';
|
||||
import {
|
||||
CopilotChatImageExecutor,
|
||||
CopilotCheckHtmlExecutor,
|
||||
@@ -56,16 +56,16 @@ import {
|
||||
getWorkflowExecutor,
|
||||
NodeExecuteState,
|
||||
NodeExecutorType,
|
||||
} from '../../plugins/copilot/workflow/executor';
|
||||
import { AutoRegisteredWorkflowExecutor } from '../../plugins/copilot/workflow/executor/utils';
|
||||
import { WorkflowGraphList } from '../../plugins/copilot/workflow/graph';
|
||||
import { CopilotWorkspaceService } from '../../plugins/copilot/workspace';
|
||||
import { PaymentModule } from '../../plugins/payment';
|
||||
import { SubscriptionService } from '../../plugins/payment/service';
|
||||
import { SubscriptionStatus } from '../../plugins/payment/types';
|
||||
import { MockCopilotProvider } from '../mocks';
|
||||
import { createTestingModule, TestingModule } from '../utils';
|
||||
import { WorkflowTestCases } from '../utils/copilot';
|
||||
} from '../plugins/copilot/workflow/executor';
|
||||
import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils';
|
||||
import { WorkflowGraphList } from '../plugins/copilot/workflow/graph';
|
||||
import { CopilotWorkspaceService } from '../plugins/copilot/workspace';
|
||||
import { PaymentModule } from '../plugins/payment';
|
||||
import { SubscriptionService } from '../plugins/payment/service';
|
||||
import { SubscriptionStatus } from '../plugins/payment/types';
|
||||
import { MockCopilotProvider } from './mocks';
|
||||
import { createTestingModule, TestingModule } from './utils';
|
||||
import { WorkflowTestCases } from './utils/copilot';
|
||||
|
||||
type Context = {
|
||||
auth: AuthService;
|
||||
@@ -364,21 +364,6 @@ test('should be able to manage chat session', async t => {
|
||||
});
|
||||
t.is(newSessionId, sessionId, 'should get same session id');
|
||||
}
|
||||
|
||||
// should create a fresh session when reuseLatestChat is explicitly disabled
|
||||
{
|
||||
const newSessionId = await session.create({
|
||||
userId,
|
||||
promptName,
|
||||
...commonParams,
|
||||
reuseLatestChat: false,
|
||||
});
|
||||
t.not(
|
||||
newSessionId,
|
||||
sessionId,
|
||||
'should create new session id when reuseLatestChat is false'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
test('should be able to update chat session prompt', async t => {
|
||||
@@ -896,26 +881,6 @@ test('should be able to get provider', async t => {
|
||||
}
|
||||
});
|
||||
|
||||
test('should resolve provider by prefixed model id', async t => {
|
||||
const { factory } = t.context;
|
||||
|
||||
const provider = await factory.getProviderByModel('openai-default/test');
|
||||
t.truthy(provider, 'should resolve prefixed model id');
|
||||
t.is(provider?.type, CopilotProviderType.OpenAI);
|
||||
|
||||
const result = await provider?.text({ modelId: 'openai-default/test' }, [
|
||||
{ role: 'user', content: 'hello' },
|
||||
]);
|
||||
t.is(result, 'generate text to text');
|
||||
});
|
||||
|
||||
test('should fallback to null when prefixed provider id does not exist', async t => {
|
||||
const { factory } = t.context;
|
||||
|
||||
const provider = await factory.getProviderByModel('unknown/test');
|
||||
t.is(provider, null);
|
||||
});
|
||||
|
||||
// ==================== workflow ====================
|
||||
|
||||
// this test used to preview the final result of the workflow
|
||||
@@ -2098,23 +2063,25 @@ test('should handle copilot cron jobs correctly', async t => {
|
||||
});
|
||||
|
||||
test('should resolve model correctly based on subscription status and prompt config', async t => {
|
||||
const { prompt, session, subscription } = t.context;
|
||||
const { db, session, subscription } = t.context;
|
||||
|
||||
// 1) Seed a prompt that has optionalModels and proModels in config
|
||||
const promptName = 'resolve-model-test';
|
||||
await prompt.set(
|
||||
promptName,
|
||||
'gemini-2.5-flash',
|
||||
[{ role: 'system', content: 'test' }],
|
||||
{ proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'] },
|
||||
{
|
||||
await db.aiPrompt.create({
|
||||
data: {
|
||||
name: promptName,
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: {
|
||||
create: [{ idx: 0, role: 'system', content: 'test' }],
|
||||
},
|
||||
config: { proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'] },
|
||||
optionalModels: [
|
||||
'gemini-2.5-flash',
|
||||
'gemini-2.5-pro',
|
||||
'claude-sonnet-4-5@20250929',
|
||||
],
|
||||
}
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
// 2) Create a chat session with this prompt
|
||||
const sessionId = await session.create({
|
||||
@@ -2139,16 +2106,6 @@ test('should resolve model correctly based on subscription status and prompt con
|
||||
const model1 = await s.resolveModel(false, 'gemini-2.5-pro');
|
||||
t.snapshot(model1, 'should honor requested pro model');
|
||||
|
||||
const model1WithPrefix = await s.resolveModel(
|
||||
false,
|
||||
'openai-default/gemini-2.5-pro'
|
||||
);
|
||||
t.is(
|
||||
model1WithPrefix,
|
||||
'openai-default/gemini-2.5-pro',
|
||||
'should honor requested prefixed pro model'
|
||||
);
|
||||
|
||||
const model2 = await s.resolveModel(false, 'not-in-optional');
|
||||
t.snapshot(model2, 'should fallback to default model');
|
||||
}
|
||||
@@ -2162,16 +2119,6 @@ test('should resolve model correctly based on subscription status and prompt con
|
||||
'should fallback to default model when requesting pro model during trialing'
|
||||
);
|
||||
|
||||
const model3WithPrefix = await s.resolveModel(
|
||||
true,
|
||||
'openai-default/gemini-2.5-pro'
|
||||
);
|
||||
t.is(
|
||||
model3WithPrefix,
|
||||
'gemini-2.5-flash',
|
||||
'should fallback to default model when requesting prefixed pro model during trialing'
|
||||
);
|
||||
|
||||
const model4 = await s.resolveModel(true, 'gemini-2.5-flash');
|
||||
t.snapshot(model4, 'should honor requested non-pro model during trialing');
|
||||
|
||||
@@ -2194,16 +2141,6 @@ test('should resolve model correctly based on subscription status and prompt con
|
||||
const model7 = await s.resolveModel(true, 'claude-sonnet-4-5@20250929');
|
||||
t.snapshot(model7, 'should honor requested pro model during active');
|
||||
|
||||
const model7WithPrefix = await s.resolveModel(
|
||||
true,
|
||||
'openai-default/claude-sonnet-4-5@20250929'
|
||||
);
|
||||
t.is(
|
||||
model7WithPrefix,
|
||||
'openai-default/claude-sonnet-4-5@20250929',
|
||||
'should honor requested prefixed pro model during active'
|
||||
);
|
||||
|
||||
const model8 = await s.resolveModel(true, 'not-in-optional');
|
||||
t.snapshot(
|
||||
model8,
|
||||
Binary file not shown.
@@ -1,210 +0,0 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
|
||||
import {
|
||||
buildNativeRequest,
|
||||
NativeProviderAdapter,
|
||||
} from '../../plugins/copilot/providers/native';
|
||||
|
||||
const mockDispatch = () =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
yield { type: 'text_delta', text: 'Use [^1] now' };
|
||||
yield { type: 'citation', index: 1, url: 'https://affine.pro' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
test('NativeProviderAdapter streamText should append citation footnotes', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
|
||||
const chunks: string[] = [];
|
||||
for await (const chunk of adapter.streamText({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.true(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should append citation footnotes', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3);
|
||||
const chunks = [];
|
||||
for await (const chunk of adapter.streamObject({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
t.deepEqual(
|
||||
chunks.map(chunk => chunk.type),
|
||||
['text-delta', 'text-delta']
|
||||
);
|
||||
const text = chunks
|
||||
.filter(chunk => chunk.type === 'text-delta')
|
||||
.map(chunk => chunk.textDelta)
|
||||
.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.true(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should append fallback attachment footnotes', async t => {
|
||||
const dispatch = () =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
yield {
|
||||
type: 'tool_result',
|
||||
call_id: 'call_1',
|
||||
name: 'blob_read',
|
||||
arguments: { blob_id: 'blob_1' },
|
||||
output: {
|
||||
blobId: 'blob_1',
|
||||
fileName: 'a.txt',
|
||||
fileType: 'text/plain',
|
||||
content: 'A',
|
||||
},
|
||||
};
|
||||
yield {
|
||||
type: 'tool_result',
|
||||
call_id: 'call_2',
|
||||
name: 'blob_read',
|
||||
arguments: { blob_id: 'blob_2' },
|
||||
output: {
|
||||
blobId: 'blob_2',
|
||||
fileName: 'b.txt',
|
||||
fileType: 'text/plain',
|
||||
content: 'B',
|
||||
},
|
||||
};
|
||||
yield { type: 'text_delta', text: 'Answer from files.' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
const adapter = new NativeProviderAdapter(dispatch, {}, 3);
|
||||
const chunks = [];
|
||||
for await (const chunk of adapter.streamObject({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks
|
||||
.filter(chunk => chunk.type === 'text-delta')
|
||||
.map(chunk => chunk.textDelta)
|
||||
.join('');
|
||||
t.true(text.includes('Answer from files.'));
|
||||
t.true(text.includes('[^1][^2]'));
|
||||
t.true(
|
||||
text.includes(
|
||||
'[^1]: {"type":"attachment","blobId":"blob_1","fileName":"a.txt","fileType":"text/plain"}'
|
||||
)
|
||||
);
|
||||
t.true(
|
||||
text.includes(
|
||||
'[^2]: {"type":"attachment","blobId":"blob_2","fileName":"b.txt","fileType":"text/plain"}'
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should map tool and text events', async t => {
|
||||
let round = 0;
|
||||
const dispatch = (_request: NativeLlmRequest) =>
|
||||
(async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
round += 1;
|
||||
if (round === 1) {
|
||||
yield {
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
};
|
||||
yield { type: 'done', finish_reason: 'tool_calls' };
|
||||
return;
|
||||
}
|
||||
yield { type: 'text_delta', text: 'ok' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
|
||||
const adapter = new NativeProviderAdapter(
|
||||
dispatch,
|
||||
{
|
||||
doc_read: {
|
||||
inputSchema: z.object({ doc_id: z.string() }),
|
||||
execute: async () => ({ markdown: '# a1' }),
|
||||
},
|
||||
},
|
||||
4
|
||||
);
|
||||
|
||||
const events = [];
|
||||
for await (const event of adapter.streamObject({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'read' }] }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
t.deepEqual(
|
||||
events.map(event => event.type),
|
||||
['tool-call', 'tool-result', 'text-delta']
|
||||
);
|
||||
t.deepEqual(events[0], {
|
||||
type: 'tool-call',
|
||||
toolCallId: 'call_1',
|
||||
toolName: 'doc_read',
|
||||
args: { doc_id: 'a1' },
|
||||
});
|
||||
});
|
||||
|
||||
test('buildNativeRequest should include rust middleware from profile', async t => {
|
||||
const { request } = await buildNativeRequest({
|
||||
model: 'gpt-4.1',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
tools: {},
|
||||
middleware: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['callout'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
t.deepEqual(request.middleware, {
|
||||
request: ['normalize_messages', 'clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
});
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamText should skip citation footnotes when disabled', async t => {
|
||||
const adapter = new NativeProviderAdapter(mockDispatch, {}, 3, {
|
||||
nodeTextMiddleware: ['callout'],
|
||||
});
|
||||
const chunks: string[] = [];
|
||||
for await (const chunk of adapter.streamText({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
|
||||
})) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
const text = chunks.join('');
|
||||
t.true(text.includes('Use [^1] now'));
|
||||
t.false(
|
||||
text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}')
|
||||
);
|
||||
});
|
||||
@@ -1,56 +0,0 @@
|
||||
import test from 'ava';
|
||||
|
||||
import { resolveProviderMiddleware } from '../../plugins/copilot/providers/provider-middleware';
|
||||
import { buildProviderRegistry } from '../../plugins/copilot/providers/provider-registry';
|
||||
import { CopilotProviderType } from '../../plugins/copilot/providers/types';
|
||||
|
||||
test('resolveProviderMiddleware should include anthropic defaults', t => {
|
||||
const middleware = resolveProviderMiddleware(CopilotProviderType.Anthropic);
|
||||
|
||||
t.deepEqual(middleware.rust?.request, [
|
||||
'normalize_messages',
|
||||
'tool_schema_rewrite',
|
||||
]);
|
||||
t.deepEqual(middleware.rust?.stream, [
|
||||
'stream_event_normalize',
|
||||
'citation_indexing',
|
||||
]);
|
||||
t.deepEqual(middleware.node?.text, ['citation_footnote', 'callout']);
|
||||
});
|
||||
|
||||
test('resolveProviderMiddleware should merge defaults and overrides', t => {
|
||||
const middleware = resolveProviderMiddleware(CopilotProviderType.OpenAI, {
|
||||
rust: { request: ['clamp_max_tokens'] },
|
||||
node: { text: ['thinking_format'] },
|
||||
});
|
||||
|
||||
t.deepEqual(middleware.rust?.request, [
|
||||
'normalize_messages',
|
||||
'clamp_max_tokens',
|
||||
]);
|
||||
t.deepEqual(middleware.node?.text, [
|
||||
'citation_footnote',
|
||||
'callout',
|
||||
'thinking_format',
|
||||
]);
|
||||
});
|
||||
|
||||
test('buildProviderRegistry should normalize profile middleware defaults', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const profile = registry.profiles.get('openai-main');
|
||||
t.truthy(profile);
|
||||
t.deepEqual(profile?.middleware.rust?.stream, [
|
||||
'stream_event_normalize',
|
||||
'citation_indexing',
|
||||
]);
|
||||
t.deepEqual(profile?.middleware.node?.text, ['citation_footnote', 'callout']);
|
||||
});
|
||||
@@ -1,99 +0,0 @@
|
||||
import test from 'ava';
|
||||
|
||||
import { ProviderMiddlewareConfig } from '../../plugins/copilot/config';
|
||||
import { CopilotProvider } from '../../plugins/copilot/providers/provider';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
ModelInputType,
|
||||
ModelOutputType,
|
||||
} from '../../plugins/copilot/providers/types';
|
||||
|
||||
class TestOpenAIProvider extends CopilotProvider<{ apiKey: string }> {
|
||||
readonly type = CopilotProviderType.OpenAI;
|
||||
readonly models = [
|
||||
{
|
||||
id: 'gpt-4.1',
|
||||
capabilities: [
|
||||
{
|
||||
input: [ModelInputType.Text],
|
||||
output: [ModelOutputType.Text],
|
||||
defaultForOutputType: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
configured() {
|
||||
return true;
|
||||
}
|
||||
|
||||
async text(_cond: any, _messages: any[], _options?: any) {
|
||||
return '';
|
||||
}
|
||||
|
||||
async *streamText(_cond: any, _messages: any[], _options?: any) {
|
||||
yield '';
|
||||
}
|
||||
|
||||
exposeMetricLabels() {
|
||||
return this.metricLabels('gpt-4.1');
|
||||
}
|
||||
|
||||
exposeMiddleware() {
|
||||
return this.getActiveProviderMiddleware();
|
||||
}
|
||||
}
|
||||
|
||||
function createProvider(profileMiddleware?: ProviderMiddlewareConfig) {
|
||||
const provider = new TestOpenAIProvider();
|
||||
(provider as any).AFFiNEConfig = {
|
||||
copilot: {
|
||||
providers: {
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: 'test' },
|
||||
middleware: profileMiddleware,
|
||||
},
|
||||
],
|
||||
defaults: {},
|
||||
openai: { apiKey: 'legacy' },
|
||||
},
|
||||
},
|
||||
};
|
||||
return provider;
|
||||
}
|
||||
|
||||
test('metricLabels should include active provider id', t => {
|
||||
const provider = createProvider();
|
||||
const labels = provider.runWithProfile('openai-main', () =>
|
||||
provider.exposeMetricLabels()
|
||||
);
|
||||
t.is(labels.providerId, 'openai-main');
|
||||
});
|
||||
|
||||
test('getActiveProviderMiddleware should merge defaults with profile override', t => {
|
||||
const provider = createProvider({
|
||||
rust: { request: ['clamp_max_tokens'] },
|
||||
node: { text: ['thinking_format'] },
|
||||
});
|
||||
|
||||
const middleware = provider.runWithProfile('openai-main', () =>
|
||||
provider.exposeMiddleware()
|
||||
);
|
||||
|
||||
t.deepEqual(middleware.rust?.request, [
|
||||
'normalize_messages',
|
||||
'clamp_max_tokens',
|
||||
]);
|
||||
t.deepEqual(middleware.rust?.stream, [
|
||||
'stream_event_normalize',
|
||||
'citation_indexing',
|
||||
]);
|
||||
t.deepEqual(middleware.node?.text, [
|
||||
'citation_footnote',
|
||||
'callout',
|
||||
'thinking_format',
|
||||
]);
|
||||
});
|
||||
@@ -1,165 +0,0 @@
|
||||
import test from 'ava';
|
||||
|
||||
import {
|
||||
buildProviderRegistry,
|
||||
resolveModel,
|
||||
stripProviderPrefix,
|
||||
} from '../../plugins/copilot/providers/provider-registry';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
ModelOutputType,
|
||||
} from '../../plugins/copilot/providers/types';
|
||||
|
||||
test('buildProviderRegistry should keep explicit profile over legacy compatibility profile', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-default',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
priority: 100,
|
||||
config: { apiKey: 'new' },
|
||||
},
|
||||
],
|
||||
openai: { apiKey: 'legacy' },
|
||||
});
|
||||
|
||||
const profile = registry.profiles.get('openai-default');
|
||||
t.truthy(profile);
|
||||
t.deepEqual(profile?.config, { apiKey: 'new' });
|
||||
});
|
||||
|
||||
test('buildProviderRegistry should reject duplicated profile ids', t => {
|
||||
const error = t.throws(() =>
|
||||
buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '2' },
|
||||
},
|
||||
],
|
||||
})
|
||||
) as Error;
|
||||
|
||||
t.truthy(error);
|
||||
t.regex(error.message, /Duplicated copilot provider profile id/);
|
||||
});
|
||||
|
||||
test('buildProviderRegistry should reject defaults that reference unknown providers', t => {
|
||||
const error = t.throws(() =>
|
||||
buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
],
|
||||
defaults: {
|
||||
fallback: 'unknown-provider',
|
||||
},
|
||||
})
|
||||
) as Error;
|
||||
|
||||
t.truthy(error);
|
||||
t.regex(error.message, /defaults references unknown providerId/);
|
||||
});
|
||||
|
||||
test('resolveModel should support explicit provider prefix and keep slash models untouched', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
{
|
||||
id: 'fal-main',
|
||||
type: CopilotProviderType.FAL,
|
||||
config: { apiKey: '2' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const prefixed = resolveModel({
|
||||
registry,
|
||||
modelId: 'openai-main/gpt-4.1',
|
||||
});
|
||||
t.deepEqual(prefixed, {
|
||||
rawModelId: 'openai-main/gpt-4.1',
|
||||
modelId: 'gpt-4.1',
|
||||
explicitProviderId: 'openai-main',
|
||||
candidateProviderIds: ['openai-main'],
|
||||
});
|
||||
|
||||
const slashModel = resolveModel({
|
||||
registry,
|
||||
modelId: 'lora/image-to-image',
|
||||
});
|
||||
t.is(slashModel.modelId, 'lora/image-to-image');
|
||||
t.false(slashModel.candidateProviderIds.includes('lora'));
|
||||
});
|
||||
|
||||
test('resolveModel should follow defaults -> fallback -> order and apply filters', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
priority: 10,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
{
|
||||
id: 'anthropic-main',
|
||||
type: CopilotProviderType.Anthropic,
|
||||
priority: 5,
|
||||
config: { apiKey: '2' },
|
||||
},
|
||||
{
|
||||
id: 'fal-main',
|
||||
type: CopilotProviderType.FAL,
|
||||
priority: 1,
|
||||
config: { apiKey: '3' },
|
||||
},
|
||||
],
|
||||
defaults: {
|
||||
[ModelOutputType.Text]: 'anthropic-main',
|
||||
fallback: 'openai-main',
|
||||
},
|
||||
});
|
||||
|
||||
const routed = resolveModel({
|
||||
registry,
|
||||
outputType: ModelOutputType.Text,
|
||||
preferredProviderIds: ['openai-main', 'fal-main'],
|
||||
});
|
||||
|
||||
t.deepEqual(routed.candidateProviderIds, ['openai-main', 'fal-main']);
|
||||
});
|
||||
|
||||
test('stripProviderPrefix should only strip matched provider prefix', t => {
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
config: { apiKey: '1' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
t.is(
|
||||
stripProviderPrefix(registry, 'openai-main', 'openai-main/gpt-4.1'),
|
||||
'gpt-4.1'
|
||||
);
|
||||
t.is(
|
||||
stripProviderPrefix(registry, 'openai-main', 'another-main/gpt-4.1'),
|
||||
'another-main/gpt-4.1'
|
||||
);
|
||||
t.is(stripProviderPrefix(registry, 'openai-main', 'gpt-4.1'), 'gpt-4.1');
|
||||
});
|
||||
@@ -1,134 +0,0 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { NativeLlmRequest, NativeLlmStreamEvent } from '../../native';
|
||||
import {
|
||||
ToolCallAccumulator,
|
||||
ToolCallLoop,
|
||||
ToolSchemaExtractor,
|
||||
} from '../../plugins/copilot/providers/loop';
|
||||
|
||||
test('ToolCallAccumulator should merge deltas and complete tool call', t => {
|
||||
const accumulator = new ToolCallAccumulator();
|
||||
|
||||
accumulator.feedDelta({
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments_delta: '{"doc_id":"',
|
||||
});
|
||||
accumulator.feedDelta({
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
arguments_delta: 'a1"}',
|
||||
});
|
||||
|
||||
const completed = accumulator.complete({
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
});
|
||||
|
||||
t.deepEqual(completed, {
|
||||
id: 'call_1',
|
||||
name: 'doc_read',
|
||||
args: { doc_id: 'a1' },
|
||||
thought: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
test('ToolSchemaExtractor should convert zod schema to json schema', t => {
|
||||
const toolSet = {
|
||||
doc_read: {
|
||||
description: 'Read doc',
|
||||
inputSchema: z.object({
|
||||
doc_id: z.string(),
|
||||
limit: z.number().optional(),
|
||||
}),
|
||||
execute: async () => ({}),
|
||||
},
|
||||
};
|
||||
|
||||
const extracted = ToolSchemaExtractor.extract(toolSet);
|
||||
|
||||
t.deepEqual(extracted, [
|
||||
{
|
||||
name: 'doc_read',
|
||||
description: 'Read doc',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
doc_id: { type: 'string' },
|
||||
limit: { type: 'number' },
|
||||
},
|
||||
additionalProperties: false,
|
||||
required: ['doc_id'],
|
||||
},
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test('ToolCallLoop should execute tool call and continue to next round', async t => {
|
||||
const dispatchRequests: NativeLlmRequest[] = [];
|
||||
|
||||
const dispatch = (request: NativeLlmRequest) => {
|
||||
dispatchRequests.push(request);
|
||||
const round = dispatchRequests.length;
|
||||
|
||||
return (async function* (): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
if (round === 1) {
|
||||
yield {
|
||||
type: 'tool_call_delta',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments_delta: '{"doc_id":"a1"}',
|
||||
};
|
||||
yield {
|
||||
type: 'tool_call',
|
||||
call_id: 'call_1',
|
||||
name: 'doc_read',
|
||||
arguments: { doc_id: 'a1' },
|
||||
};
|
||||
yield { type: 'done', finish_reason: 'tool_calls' };
|
||||
return;
|
||||
}
|
||||
|
||||
yield { type: 'text_delta', text: 'done' };
|
||||
yield { type: 'done', finish_reason: 'stop' };
|
||||
})();
|
||||
};
|
||||
|
||||
let executedArgs: Record<string, unknown> | null = null;
|
||||
const loop = new ToolCallLoop(
|
||||
dispatch,
|
||||
{
|
||||
doc_read: {
|
||||
inputSchema: z.object({ doc_id: z.string() }),
|
||||
execute: async args => {
|
||||
executedArgs = args;
|
||||
return { markdown: '# doc' };
|
||||
},
|
||||
},
|
||||
},
|
||||
4
|
||||
);
|
||||
|
||||
const events: NativeLlmStreamEvent[] = [];
|
||||
for await (const event of loop.run({
|
||||
model: 'gpt-4.1',
|
||||
stream: true,
|
||||
messages: [{ role: 'user', content: [{ type: 'text', text: 'read doc' }] }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
t.deepEqual(executedArgs, { doc_id: 'a1' });
|
||||
t.true(
|
||||
dispatchRequests[1]?.messages.some(message => message.role === 'tool')
|
||||
);
|
||||
t.deepEqual(
|
||||
events.map(event => event.type),
|
||||
['tool_call', 'tool_result', 'text_delta', 'done']
|
||||
);
|
||||
});
|
||||
@@ -1,116 +0,0 @@
|
||||
import test from 'ava';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
CitationFootnoteFormatter,
|
||||
CitationParser,
|
||||
StreamPatternParser,
|
||||
} from '../../plugins/copilot/providers/utils';
|
||||
|
||||
test('CitationFootnoteFormatter should format sorted footnotes from citation events', t => {
|
||||
const formatter = new CitationFootnoteFormatter();
|
||||
|
||||
formatter.consume({
|
||||
type: 'citation',
|
||||
index: 2,
|
||||
url: 'https://example.com/b',
|
||||
});
|
||||
formatter.consume({
|
||||
type: 'citation',
|
||||
index: 1,
|
||||
url: 'https://example.com/a',
|
||||
});
|
||||
|
||||
t.is(
|
||||
formatter.end(),
|
||||
[
|
||||
'[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fa"}',
|
||||
'[^2]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fb"}',
|
||||
].join('\n')
|
||||
);
|
||||
});
|
||||
|
||||
test('CitationFootnoteFormatter should overwrite duplicated index with latest url', t => {
|
||||
const formatter = new CitationFootnoteFormatter();
|
||||
|
||||
formatter.consume({
|
||||
type: 'citation',
|
||||
index: 1,
|
||||
url: 'https://example.com/old',
|
||||
});
|
||||
formatter.consume({
|
||||
type: 'citation',
|
||||
index: 1,
|
||||
url: 'https://example.com/new',
|
||||
});
|
||||
|
||||
t.is(
|
||||
formatter.end(),
|
||||
'[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fnew"}'
|
||||
);
|
||||
});
|
||||
|
||||
test('StreamPatternParser should keep state across chunks', t => {
|
||||
const parser = new StreamPatternParser(pattern => {
|
||||
if (pattern.kind === 'wrappedLink') {
|
||||
return `[^${pattern.url}]`;
|
||||
}
|
||||
if (pattern.kind === 'index') {
|
||||
return `[#${pattern.value}]`;
|
||||
}
|
||||
return `[${pattern.text}](${pattern.url})`;
|
||||
});
|
||||
|
||||
const first = parser.write('ref ([AFFiNE](https://affine.pro');
|
||||
const second = parser.write(')) and [2]');
|
||||
|
||||
t.is(first, 'ref ');
|
||||
t.is(second, '[^https://affine.pro] and [#2]');
|
||||
t.is(parser.end(), '');
|
||||
});
|
||||
|
||||
test('CitationParser should convert wrapped links to numbered footnotes', t => {
|
||||
const parser = new CitationParser();
|
||||
|
||||
const output = parser.parse('Use ([AFFiNE](https://affine.pro)) now');
|
||||
t.is(output, 'Use [^1] now');
|
||||
t.regex(
|
||||
parser.end(),
|
||||
/\[\^1\]: \{"type":"url","url":"https%3A%2F%2Faffine.pro"\}/
|
||||
);
|
||||
});
|
||||
|
||||
test('chatToGPTMessage should not mutate input and should keep system schema', async t => {
|
||||
const schema = z.object({
|
||||
query: z.string(),
|
||||
});
|
||||
const messages = [
|
||||
{
|
||||
role: 'system' as const,
|
||||
content: 'You are helper',
|
||||
params: { schema },
|
||||
},
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: '',
|
||||
attachments: ['https://example.com/a.png'],
|
||||
},
|
||||
];
|
||||
const firstRef = messages[0];
|
||||
const secondRef = messages[1];
|
||||
const [system, normalized, parsedSchema] = await chatToGPTMessage(
|
||||
messages,
|
||||
false
|
||||
);
|
||||
|
||||
t.is(system, 'You are helper');
|
||||
t.is(parsedSchema, schema);
|
||||
t.is(messages.length, 2);
|
||||
t.is(messages[0], firstRef);
|
||||
t.is(messages[1], secondRef);
|
||||
t.deepEqual(normalized[0], {
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: '[no content]' }],
|
||||
});
|
||||
});
|
||||
@@ -8,7 +8,6 @@ export class MockEventBus {
|
||||
|
||||
emit = this.stub.emitAsync;
|
||||
emitAsync = this.stub.emitAsync;
|
||||
emitDetached = this.stub.emitAsync;
|
||||
broadcast = this.stub.broadcast;
|
||||
|
||||
last<Event extends EventName>(
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
import test from 'ava';
|
||||
|
||||
import { NativeStreamAdapter } from '../native';
|
||||
|
||||
test('NativeStreamAdapter should support buffered and awaited consumption', async t => {
|
||||
const adapter = new NativeStreamAdapter<number>(undefined);
|
||||
|
||||
adapter.push(1);
|
||||
const first = await adapter.next();
|
||||
t.deepEqual(first, { value: 1, done: false });
|
||||
|
||||
const pending = adapter.next();
|
||||
adapter.push(2);
|
||||
const second = await pending;
|
||||
t.deepEqual(second, { value: 2, done: false });
|
||||
|
||||
adapter.push(null);
|
||||
const done = await adapter.next();
|
||||
t.true(done.done);
|
||||
});
|
||||
|
||||
test('NativeStreamAdapter return should abort handle and end iteration', async t => {
|
||||
let abortCount = 0;
|
||||
const adapter = new NativeStreamAdapter<number>({
|
||||
abort: () => {
|
||||
abortCount += 1;
|
||||
},
|
||||
});
|
||||
|
||||
const ended = await adapter.return();
|
||||
t.is(abortCount, 1);
|
||||
t.true(ended.done);
|
||||
|
||||
const secondReturn = await adapter.return();
|
||||
t.true(secondReturn.done);
|
||||
t.is(abortCount, 1);
|
||||
|
||||
const next = await adapter.next();
|
||||
t.true(next.done);
|
||||
});
|
||||
|
||||
test('NativeStreamAdapter should abort when AbortSignal is triggered', async t => {
|
||||
let abortCount = 0;
|
||||
const controller = new AbortController();
|
||||
const adapter = new NativeStreamAdapter<number>(
|
||||
{
|
||||
abort: () => {
|
||||
abortCount += 1;
|
||||
},
|
||||
},
|
||||
controller.signal
|
||||
);
|
||||
|
||||
const pending = adapter.next();
|
||||
controller.abort();
|
||||
const done = await pending;
|
||||
t.true(done.done);
|
||||
t.is(abortCount, 1);
|
||||
});
|
||||
|
||||
test('NativeStreamAdapter should end immediately for pre-aborted signal', async t => {
|
||||
let abortCount = 0;
|
||||
const controller = new AbortController();
|
||||
controller.abort();
|
||||
|
||||
const adapter = new NativeStreamAdapter<number>(
|
||||
{
|
||||
abort: () => {
|
||||
abortCount += 1;
|
||||
},
|
||||
},
|
||||
controller.signal
|
||||
);
|
||||
|
||||
const next = await adapter.next();
|
||||
t.true(next.done);
|
||||
t.is(abortCount, 1);
|
||||
|
||||
adapter.push(1);
|
||||
const stillDone = await adapter.next();
|
||||
t.true(stillDone.done);
|
||||
});
|
||||
@@ -629,35 +629,14 @@ export async function chatWithText(
|
||||
prefix = '',
|
||||
retry?: boolean
|
||||
): Promise<string> {
|
||||
const endpoint = prefix || '/stream';
|
||||
const query = messageId
|
||||
? `?messageId=${messageId}` + (retry ? '&retry=true' : '')
|
||||
: '';
|
||||
const res = await app
|
||||
.GET(`/api/copilot/chat/${sessionId}${endpoint}${query}`)
|
||||
.GET(`/api/copilot/chat/${sessionId}${prefix}${query}`)
|
||||
.expect(200);
|
||||
|
||||
if (prefix) {
|
||||
return res.text;
|
||||
}
|
||||
|
||||
const events = sse2array(res.text);
|
||||
const errorEvent = events.find(event => event.event === 'error');
|
||||
if (errorEvent?.data) {
|
||||
let message = errorEvent.data;
|
||||
try {
|
||||
const parsed = JSON.parse(errorEvent.data);
|
||||
message = parsed.message || message;
|
||||
} catch {
|
||||
// noop: keep raw error data
|
||||
}
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
return events
|
||||
.filter(event => event.event === 'message')
|
||||
.map(event => event.data ?? '')
|
||||
.join('');
|
||||
return res.text;
|
||||
}
|
||||
|
||||
export async function chatWithTextStream(
|
||||
|
||||
@@ -38,11 +38,8 @@ test.before(async t => {
|
||||
t.context.app = app;
|
||||
});
|
||||
|
||||
test.afterEach.always(() => {
|
||||
Sinon.restore();
|
||||
});
|
||||
|
||||
test.after.always(async t => {
|
||||
Sinon.restore();
|
||||
__resetDnsLookupForTests();
|
||||
await t.context.app.close();
|
||||
});
|
||||
@@ -83,7 +80,6 @@ const assertAndSnapshotRaw = async (
|
||||
|
||||
test('should proxy image', async t => {
|
||||
const assertAndSnapshot = assertAndSnapshotRaw.bind(null, t);
|
||||
const imageUrl = `http://example.com/image-${Date.now()}.png`;
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/image-proxy',
|
||||
@@ -109,7 +105,7 @@ test('should proxy image', async t => {
|
||||
|
||||
{
|
||||
await assertAndSnapshot(
|
||||
`/api/worker/image-proxy?url=${imageUrl}`,
|
||||
'/api/worker/image-proxy?url=http://example.com/image.png',
|
||||
'should return 400 if origin and referer are missing',
|
||||
{ status: 400, origin: null, referer: null }
|
||||
);
|
||||
@@ -117,17 +113,14 @@ test('should proxy image', async t => {
|
||||
|
||||
{
|
||||
await assertAndSnapshot(
|
||||
`/api/worker/image-proxy?url=${imageUrl}`,
|
||||
'/api/worker/image-proxy?url=http://example.com/image.png',
|
||||
'should return 400 for invalid origin header',
|
||||
{ status: 400, origin: 'http://invalid.com' }
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const fakeBuffer = Buffer.from(
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+jfJ8AAAAASUVORK5CYII=',
|
||||
'base64'
|
||||
);
|
||||
const fakeBuffer = Buffer.from('fake image');
|
||||
const fakeResponse = new Response(fakeBuffer, {
|
||||
status: 200,
|
||||
headers: {
|
||||
@@ -137,14 +130,13 @@ test('should proxy image', async t => {
|
||||
});
|
||||
|
||||
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeResponse);
|
||||
try {
|
||||
await assertAndSnapshot(
|
||||
`/api/worker/image-proxy?url=${imageUrl}`,
|
||||
'should return image buffer'
|
||||
);
|
||||
} finally {
|
||||
fetchSpy.restore();
|
||||
}
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/image-proxy?url=http://example.com/image.png',
|
||||
'should return image buffer'
|
||||
);
|
||||
|
||||
fetchSpy.restore();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -208,19 +200,18 @@ test('should preview link', async t => {
|
||||
});
|
||||
|
||||
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeHTML);
|
||||
try {
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should process a valid external URL and return link preview data',
|
||||
{
|
||||
status: 200,
|
||||
method: 'POST',
|
||||
body: { url: 'http://external.com/page' },
|
||||
}
|
||||
);
|
||||
} finally {
|
||||
fetchSpy.restore();
|
||||
}
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should process a valid external URL and return link preview data',
|
||||
{
|
||||
status: 200,
|
||||
method: 'POST',
|
||||
body: { url: 'http://external.com/page' },
|
||||
}
|
||||
);
|
||||
|
||||
fetchSpy.restore();
|
||||
}
|
||||
|
||||
{
|
||||
@@ -260,19 +251,18 @@ test('should preview link', async t => {
|
||||
});
|
||||
|
||||
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeHTML);
|
||||
try {
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should decode HTML content with charset',
|
||||
{
|
||||
status: 200,
|
||||
method: 'POST',
|
||||
body: { url: `http://example.com/${charset}` },
|
||||
}
|
||||
);
|
||||
} finally {
|
||||
fetchSpy.restore();
|
||||
}
|
||||
|
||||
await assertAndSnapshot(
|
||||
'/api/worker/link-preview',
|
||||
'should decode HTML content with charset',
|
||||
{
|
||||
status: 200,
|
||||
method: 'POST',
|
||||
body: { url: `http://example.com/${charset}` },
|
||||
}
|
||||
);
|
||||
|
||||
fetchSpy.restore();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -88,21 +88,12 @@ export class EventBus
|
||||
emit<T extends EventName>(event: T, payload: Events[T]) {
|
||||
this.logger.debug(`Dispatch event: ${event}`);
|
||||
|
||||
this.dispatchAsync(event, payload);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit event in detached cls context to avoid inheriting current transaction.
|
||||
*/
|
||||
emitDetached<T extends EventName>(event: T, payload: Events[T]) {
|
||||
this.logger.debug(`Dispatch event: ${event} (detached)`);
|
||||
|
||||
const requestId = this.cls.getId();
|
||||
this.cls.run({ ifNested: 'override' }, () => {
|
||||
this.cls.set(CLS_ID, requestId ?? genRequestId('event'));
|
||||
this.dispatchAsync(event, payload);
|
||||
// NOTE(@forehalo):
|
||||
// Because all event handlers are wrapped in promisified metrics and cls context, they will always run in standalone tick.
|
||||
// In which way, if handler throws, an unhandled rejection will be triggered and end up with process exiting.
|
||||
// So we catch it here with `emitAsync`
|
||||
this.emitter.emitAsync(event, payload).catch(e => {
|
||||
this.emitter.emit('error', { event, payload, error: e });
|
||||
});
|
||||
|
||||
return true;
|
||||
@@ -175,16 +166,6 @@ export class EventBus
|
||||
return this.emitter.waitFor(name, timeout);
|
||||
}
|
||||
|
||||
private dispatchAsync<T extends EventName>(event: T, payload: Events[T]) {
|
||||
// NOTE:
|
||||
// Because all event handlers are wrapped in promisified metrics and cls context, they will always run in standalone tick.
|
||||
// In which way, if handler throws, an unhandled rejection will be triggered and end up with process exiting.
|
||||
// So we catch it here with `emitAsync`
|
||||
this.emitter.emitAsync(event, payload).catch(e => {
|
||||
this.emitter.emit('error', { event, payload, error: e });
|
||||
});
|
||||
}
|
||||
|
||||
private readonly bindEventHandlers = once(() => {
|
||||
this.scanner.scan().forEach(({ event, handler, opts }) => {
|
||||
this.on(event, handler, opts);
|
||||
|
||||
@@ -68,7 +68,7 @@ test('should update doc content to database when doc is updated', async t => {
|
||||
|
||||
const docId = randomUUID();
|
||||
await adapter.pushDocUpdates(workspace.id, docId, updates);
|
||||
await adapter.getDocBinNative(workspace.id, docId);
|
||||
await adapter.getDoc(workspace.id, docId);
|
||||
|
||||
mock.method(docReader, 'parseDocContent', () => {
|
||||
return {
|
||||
@@ -181,22 +181,3 @@ test('should ignore update workspace content to database when parse workspace co
|
||||
t.is(content!.name, null);
|
||||
t.is(content!.avatarKey, null);
|
||||
});
|
||||
|
||||
test('should ignore stale workspace when updating doc meta from snapshot event', async t => {
|
||||
const { docReader, listener, models } = t.context;
|
||||
const docId = randomUUID();
|
||||
mock.method(docReader, 'parseDocContent', () => ({
|
||||
title: 'test title',
|
||||
summary: 'test summary',
|
||||
}));
|
||||
|
||||
await models.workspace.delete(workspace.id);
|
||||
|
||||
await t.notThrowsAsync(async () => {
|
||||
await listener.markDocContentCacheStale({
|
||||
workspaceId: workspace.id,
|
||||
docId,
|
||||
blob: Buffer.from([0x01]),
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -110,7 +110,7 @@ export class PgWorkspaceDocStorageAdapter extends DocStorageAdapter {
|
||||
});
|
||||
|
||||
if (isNewDoc) {
|
||||
this.event.emitDetached('doc.created', {
|
||||
this.event.emit('doc.created', {
|
||||
workspaceId,
|
||||
docId,
|
||||
editor: editorId,
|
||||
@@ -334,7 +334,7 @@ export class PgWorkspaceDocStorageAdapter extends DocStorageAdapter {
|
||||
});
|
||||
|
||||
if (updatedSnapshot) {
|
||||
this.event.emitDetached('doc.snapshot.updated', {
|
||||
this.event.emit('doc.snapshot.updated', {
|
||||
workspaceId: snapshot.spaceId,
|
||||
docId: snapshot.docId,
|
||||
blob,
|
||||
|
||||
@@ -1,29 +1,12 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { Prisma } from '@prisma/client';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { OnEvent } from '../../base';
|
||||
import { Models } from '../../models';
|
||||
import { PgWorkspaceDocStorageAdapter } from './adapters/workspace';
|
||||
import { DocReader } from './reader';
|
||||
|
||||
const IGNORED_PRISMA_CODES = new Set(['P2003', 'P2025', 'P2028']);
|
||||
|
||||
function isIgnorableDocEventError(error: unknown) {
|
||||
if (error instanceof Prisma.PrismaClientKnownRequestError) {
|
||||
return IGNORED_PRISMA_CODES.has(error.code);
|
||||
}
|
||||
if (error instanceof Prisma.PrismaClientUnknownRequestError) {
|
||||
return /transaction is aborted|transaction already closed/i.test(
|
||||
error.message
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class DocEventsListener {
|
||||
private readonly logger = new Logger(DocEventsListener.name);
|
||||
|
||||
constructor(
|
||||
private readonly docReader: DocReader,
|
||||
private readonly models: Models,
|
||||
@@ -37,39 +20,21 @@ export class DocEventsListener {
|
||||
blob,
|
||||
}: Events['doc.snapshot.updated']) {
|
||||
await this.docReader.markDocContentCacheStale(workspaceId, docId);
|
||||
const workspace = await this.models.workspace.get(workspaceId);
|
||||
if (!workspace) {
|
||||
this.logger.warn(
|
||||
`Skip stale doc snapshot event for missing workspace ${workspaceId}/${docId}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
const isDoc = workspaceId !== docId;
|
||||
// update doc content to database
|
||||
try {
|
||||
if (isDoc) {
|
||||
const content = this.docReader.parseDocContent(blob);
|
||||
if (!content) {
|
||||
return;
|
||||
}
|
||||
await this.models.doc.upsertMeta(workspaceId, docId, content);
|
||||
} else {
|
||||
// update workspace content to database
|
||||
const content = this.docReader.parseWorkspaceContent(blob);
|
||||
if (!content) {
|
||||
return;
|
||||
}
|
||||
await this.models.workspace.update(workspaceId, content);
|
||||
}
|
||||
} catch (error) {
|
||||
if (isIgnorableDocEventError(error)) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
this.logger.warn(
|
||||
`Ignore stale doc snapshot event for ${workspaceId}/${docId}: ${message}`
|
||||
);
|
||||
if (isDoc) {
|
||||
const content = this.docReader.parseDocContent(blob);
|
||||
if (!content) {
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
await this.models.doc.upsertMeta(workspaceId, docId, content);
|
||||
} else {
|
||||
// update workspace content to database
|
||||
const content = this.docReader.parseWorkspaceContent(blob);
|
||||
if (!content) {
|
||||
return;
|
||||
}
|
||||
await this.models.workspace.update(workspaceId, content);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import ava, { TestFn } from 'ava';
|
||||
|
||||
import {
|
||||
createTestingModule,
|
||||
type TestingModule,
|
||||
} from '../../../__tests__/utils';
|
||||
import { DocRole, Models, User, Workspace } from '../../../models';
|
||||
import { EventsListener } from '../event';
|
||||
import { PermissionModule } from '../index';
|
||||
|
||||
interface Context {
|
||||
module: TestingModule;
|
||||
models: Models;
|
||||
listener: EventsListener;
|
||||
}
|
||||
|
||||
const test = ava as TestFn<Context>;
|
||||
|
||||
let owner: User;
|
||||
let workspace: Workspace;
|
||||
|
||||
test.before(async t => {
|
||||
const module = await createTestingModule({ imports: [PermissionModule] });
|
||||
t.context.module = module;
|
||||
t.context.models = module.get(Models);
|
||||
t.context.listener = module.get(EventsListener);
|
||||
});
|
||||
|
||||
test.beforeEach(async t => {
|
||||
await t.context.module.initTestingDB();
|
||||
owner = await t.context.models.user.create({
|
||||
email: `${randomUUID()}@affine.pro`,
|
||||
});
|
||||
workspace = await t.context.models.workspace.create(owner.id);
|
||||
});
|
||||
|
||||
test.after.always(async t => {
|
||||
await t.context.module.close();
|
||||
});
|
||||
|
||||
test('should ignore default owner event when workspace does not exist', async t => {
|
||||
await t.notThrowsAsync(async () => {
|
||||
await t.context.listener.setDefaultPageOwner({
|
||||
workspaceId: randomUUID(),
|
||||
docId: randomUUID(),
|
||||
editor: owner.id,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test('should ignore default owner event when editor does not exist', async t => {
|
||||
await t.notThrowsAsync(async () => {
|
||||
await t.context.listener.setDefaultPageOwner({
|
||||
workspaceId: workspace.id,
|
||||
docId: randomUUID(),
|
||||
editor: randomUUID(),
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test('should set owner when workspace and editor exist', async t => {
|
||||
const docId = randomUUID();
|
||||
await t.context.listener.setDefaultPageOwner({
|
||||
workspaceId: workspace.id,
|
||||
docId,
|
||||
editor: owner.id,
|
||||
});
|
||||
|
||||
const role = await t.context.models.docUser.get(
|
||||
workspace.id,
|
||||
docId,
|
||||
owner.id
|
||||
);
|
||||
t.is(role?.type, DocRole.Owner);
|
||||
});
|
||||
@@ -1,27 +1,10 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { Prisma } from '@prisma/client';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { OnEvent } from '../../base';
|
||||
import { Models } from '../../models';
|
||||
|
||||
const IGNORED_PRISMA_CODES = new Set(['P2003', 'P2025', 'P2028']);
|
||||
|
||||
function isIgnorablePermissionEventError(error: unknown) {
|
||||
if (error instanceof Prisma.PrismaClientKnownRequestError) {
|
||||
return IGNORED_PRISMA_CODES.has(error.code);
|
||||
}
|
||||
if (error instanceof Prisma.PrismaClientUnknownRequestError) {
|
||||
return /transaction is aborted|transaction already closed/i.test(
|
||||
error.message
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class EventsListener {
|
||||
private readonly logger = new Logger(EventsListener.name);
|
||||
|
||||
constructor(private readonly models: Models) {}
|
||||
|
||||
@OnEvent('doc.created')
|
||||
@@ -32,33 +15,6 @@ export class EventsListener {
|
||||
return;
|
||||
}
|
||||
|
||||
const workspace = await this.models.workspace.get(workspaceId);
|
||||
if (!workspace) {
|
||||
this.logger.warn(
|
||||
`Skip default doc owner event for missing workspace ${workspaceId}/${docId}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const user = await this.models.user.get(editor);
|
||||
if (!user) {
|
||||
this.logger.warn(
|
||||
`Skip default doc owner event for missing editor ${workspaceId}/${docId}/${editor}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await this.models.docUser.setOwner(workspaceId, docId, editor);
|
||||
} catch (error) {
|
||||
if (isIgnorablePermissionEventError(error)) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
this.logger.warn(
|
||||
`Ignore stale doc owner event for ${workspaceId}/${docId}/${editor}: ${message}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
await this.models.docUser.setOwner(workspaceId, docId, editor);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,18 +42,8 @@ export class Ga4Client {
|
||||
timestamp_micros: event.timestampMicros,
|
||||
})),
|
||||
};
|
||||
try {
|
||||
await this.post(payload);
|
||||
} catch {
|
||||
if (env.DEPLOYMENT_TYPE === 'affine') {
|
||||
// In production, we want to be resilient to GA4 failures, so we catch and ignore errors.
|
||||
// In non-production environments, we rethrow to surface issues during development and testing.
|
||||
console.info(
|
||||
'Failed to send telemetry event to GA4:',
|
||||
chunk.map(e => e.eventName).join(', ')
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
await this.post(payload);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ import assert from 'node:assert';
|
||||
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
import type { TransactionalAdapterPrisma } from '@nestjs-cls/transactional-adapter-prisma';
|
||||
import { WorkspaceDocUserRole } from '@prisma/client';
|
||||
|
||||
import { CanNotBatchGrantDocOwnerPermissions, PaginationInput } from '../base';
|
||||
@@ -15,20 +14,31 @@ export class DocUserModel extends BaseModel {
|
||||
* Set or update the [Owner] of a doc.
|
||||
* The old [Owner] will be changed to [Manager] if there is already an [Owner].
|
||||
*/
|
||||
@Transactional<TransactionalAdapterPrisma>({ timeout: 15000 })
|
||||
@Transactional()
|
||||
async setOwner(workspaceId: string, docId: string, userId: string) {
|
||||
await this.db.workspaceDocUserRole.updateMany({
|
||||
const oldOwner = await this.db.workspaceDocUserRole.findFirst({
|
||||
where: {
|
||||
workspaceId,
|
||||
docId,
|
||||
type: DocRole.Owner,
|
||||
userId: { not: userId },
|
||||
},
|
||||
data: {
|
||||
type: DocRole.Manager,
|
||||
},
|
||||
});
|
||||
|
||||
if (oldOwner) {
|
||||
await this.db.workspaceDocUserRole.update({
|
||||
where: {
|
||||
workspaceId_docId_userId: {
|
||||
workspaceId,
|
||||
docId,
|
||||
userId: oldOwner.userId,
|
||||
},
|
||||
},
|
||||
data: {
|
||||
type: DocRole.Manager,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
await this.db.workspaceDocUserRole.upsert({
|
||||
where: {
|
||||
workspaceId_docId_userId: {
|
||||
@@ -47,9 +57,16 @@ export class DocUserModel extends BaseModel {
|
||||
type: DocRole.Owner,
|
||||
},
|
||||
});
|
||||
this.logger.log(
|
||||
`Set doc owner of [${workspaceId}/${docId}] to [${userId}]`
|
||||
);
|
||||
|
||||
if (oldOwner) {
|
||||
this.logger.log(
|
||||
`Transfer doc owner of [${workspaceId}/${docId}] from [${oldOwner.userId}] to [${userId}]`
|
||||
);
|
||||
} else {
|
||||
this.logger.log(
|
||||
`Set doc owner of [${workspaceId}/${docId}] to [${userId}]`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -57,316 +57,3 @@ export const addDocToRootDoc = serverNativeModule.addDocToRootDoc;
|
||||
export const updateDocTitle = serverNativeModule.updateDocTitle;
|
||||
export const updateDocProperties = serverNativeModule.updateDocProperties;
|
||||
export const updateRootDocMetaTitle = serverNativeModule.updateRootDocMetaTitle;
|
||||
|
||||
type NativeLlmModule = {
|
||||
llmDispatch?: (
|
||||
protocol: string,
|
||||
backendConfigJson: string,
|
||||
requestJson: string
|
||||
) => string | Promise<string>;
|
||||
llmDispatchStream?: (
|
||||
protocol: string,
|
||||
backendConfigJson: string,
|
||||
requestJson: string,
|
||||
callback: (error: Error | null, eventJson: string) => void
|
||||
) => { abort?: () => void } | undefined;
|
||||
};
|
||||
|
||||
const nativeLlmModule = serverNativeModule as typeof serverNativeModule &
|
||||
NativeLlmModule;
|
||||
|
||||
export type NativeLlmProtocol =
|
||||
| 'openai_chat'
|
||||
| 'openai_responses'
|
||||
| 'anthropic';
|
||||
|
||||
export type NativeLlmBackendConfig = {
|
||||
base_url: string;
|
||||
auth_token: string;
|
||||
request_layer?: 'anthropic' | 'chat_completions' | 'responses' | 'vertex';
|
||||
headers?: Record<string, string>;
|
||||
no_streaming?: boolean;
|
||||
timeout_ms?: number;
|
||||
};
|
||||
|
||||
export type NativeLlmCoreRole = 'system' | 'user' | 'assistant' | 'tool';
|
||||
|
||||
export type NativeLlmCoreContent =
|
||||
| { type: 'text'; text: string }
|
||||
| { type: 'reasoning'; text: string; signature?: string }
|
||||
| {
|
||||
type: 'tool_call';
|
||||
call_id: string;
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
thought?: string;
|
||||
}
|
||||
| {
|
||||
type: 'tool_result';
|
||||
call_id: string;
|
||||
output: unknown;
|
||||
is_error?: boolean;
|
||||
name?: string;
|
||||
arguments?: Record<string, unknown>;
|
||||
}
|
||||
| { type: 'image'; source: Record<string, unknown> | string };
|
||||
|
||||
export type NativeLlmCoreMessage = {
|
||||
role: NativeLlmCoreRole;
|
||||
content: NativeLlmCoreContent[];
|
||||
};
|
||||
|
||||
export type NativeLlmToolDefinition = {
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type NativeLlmRequest = {
|
||||
model: string;
|
||||
messages: NativeLlmCoreMessage[];
|
||||
stream?: boolean;
|
||||
max_tokens?: number;
|
||||
temperature?: number;
|
||||
tools?: NativeLlmToolDefinition[];
|
||||
tool_choice?: 'auto' | 'none' | 'required' | { name: string };
|
||||
include?: string[];
|
||||
reasoning?: Record<string, unknown>;
|
||||
middleware?: {
|
||||
request?: Array<
|
||||
'normalize_messages' | 'clamp_max_tokens' | 'tool_schema_rewrite'
|
||||
>;
|
||||
stream?: Array<'stream_event_normalize' | 'citation_indexing'>;
|
||||
config?: {
|
||||
no_additional_properties?: boolean;
|
||||
drop_property_format?: boolean;
|
||||
drop_property_min_length?: boolean;
|
||||
drop_array_min_items?: boolean;
|
||||
drop_array_max_items?: boolean;
|
||||
max_tokens_cap?: number;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export type NativeLlmDispatchResponse = {
|
||||
id: string;
|
||||
model: string;
|
||||
message: NativeLlmCoreMessage;
|
||||
usage: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
finish_reason: string;
|
||||
reasoning_details?: unknown;
|
||||
};
|
||||
|
||||
export type NativeLlmStreamEvent =
|
||||
| { type: 'message_start'; id?: string; model?: string }
|
||||
| { type: 'text_delta'; text: string }
|
||||
| { type: 'reasoning_delta'; text: string }
|
||||
| {
|
||||
type: 'tool_call_delta';
|
||||
call_id: string;
|
||||
name?: string;
|
||||
arguments_delta: string;
|
||||
}
|
||||
| {
|
||||
type: 'tool_call';
|
||||
call_id: string;
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
thought?: string;
|
||||
}
|
||||
| {
|
||||
type: 'tool_result';
|
||||
call_id: string;
|
||||
output: unknown;
|
||||
is_error?: boolean;
|
||||
name?: string;
|
||||
arguments?: Record<string, unknown>;
|
||||
}
|
||||
| { type: 'citation'; index: number; url: string }
|
||||
| {
|
||||
type: 'usage';
|
||||
usage: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: 'done';
|
||||
finish_reason?: string;
|
||||
usage?: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
}
|
||||
| { type: 'error'; message: string; code?: string; raw?: string };
|
||||
const LLM_STREAM_END_MARKER = '__AFFINE_LLM_STREAM_END__';
|
||||
|
||||
export async function llmDispatch(
|
||||
protocol: NativeLlmProtocol,
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
request: NativeLlmRequest
|
||||
): Promise<NativeLlmDispatchResponse> {
|
||||
if (!nativeLlmModule.llmDispatch) {
|
||||
throw new Error('native llm dispatch is not available');
|
||||
}
|
||||
const response = nativeLlmModule.llmDispatch(
|
||||
protocol,
|
||||
JSON.stringify(backendConfig),
|
||||
JSON.stringify(request)
|
||||
);
|
||||
const responseText = await Promise.resolve(response);
|
||||
return JSON.parse(responseText) as NativeLlmDispatchResponse;
|
||||
}
|
||||
|
||||
export class NativeStreamAdapter<T> implements AsyncIterableIterator<T> {
|
||||
readonly #queue: T[] = [];
|
||||
readonly #waiters: ((result: IteratorResult<T>) => void)[] = [];
|
||||
readonly #handle: { abort?: () => void } | undefined;
|
||||
readonly #signal?: AbortSignal;
|
||||
readonly #abortListener?: () => void;
|
||||
#ended = false;
|
||||
|
||||
constructor(
|
||||
handle: { abort?: () => void } | undefined,
|
||||
signal?: AbortSignal
|
||||
) {
|
||||
this.#handle = handle;
|
||||
this.#signal = signal;
|
||||
|
||||
if (signal?.aborted) {
|
||||
this.close(true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (signal) {
|
||||
this.#abortListener = () => {
|
||||
this.close(true);
|
||||
};
|
||||
signal.addEventListener('abort', this.#abortListener, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
private close(abortHandle: boolean) {
|
||||
if (this.#ended) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.#ended = true;
|
||||
if (this.#signal && this.#abortListener) {
|
||||
this.#signal.removeEventListener('abort', this.#abortListener);
|
||||
}
|
||||
if (abortHandle) {
|
||||
this.#handle?.abort?.();
|
||||
}
|
||||
|
||||
while (this.#waiters.length) {
|
||||
const waiter = this.#waiters.shift();
|
||||
waiter?.({ value: undefined as T, done: true });
|
||||
}
|
||||
}
|
||||
|
||||
push(value: T | null) {
|
||||
if (this.#ended) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (value === null) {
|
||||
this.close(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const waiter = this.#waiters.shift();
|
||||
if (waiter) {
|
||||
waiter({ value, done: false });
|
||||
return;
|
||||
}
|
||||
|
||||
this.#queue.push(value);
|
||||
}
|
||||
|
||||
[Symbol.asyncIterator]() {
|
||||
return this;
|
||||
}
|
||||
|
||||
async next(): Promise<IteratorResult<T>> {
|
||||
if (this.#queue.length > 0) {
|
||||
const value = this.#queue.shift() as T;
|
||||
return { value, done: false };
|
||||
}
|
||||
|
||||
if (this.#ended) {
|
||||
return { value: undefined as T, done: true };
|
||||
}
|
||||
|
||||
return await new Promise(resolve => {
|
||||
this.#waiters.push(resolve);
|
||||
});
|
||||
}
|
||||
|
||||
async return(): Promise<IteratorResult<T>> {
|
||||
this.close(true);
|
||||
|
||||
return { value: undefined as T, done: true };
|
||||
}
|
||||
}
|
||||
|
||||
export function llmDispatchStream(
|
||||
protocol: NativeLlmProtocol,
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
if (!nativeLlmModule.llmDispatchStream) {
|
||||
throw new Error('native llm stream dispatch is not available');
|
||||
}
|
||||
|
||||
let adapter: NativeStreamAdapter<NativeLlmStreamEvent> | undefined;
|
||||
const buffer: (NativeLlmStreamEvent | null)[] = [];
|
||||
let pushFn = (event: NativeLlmStreamEvent | null) => {
|
||||
buffer.push(event);
|
||||
};
|
||||
const handle = nativeLlmModule.llmDispatchStream(
|
||||
protocol,
|
||||
JSON.stringify(backendConfig),
|
||||
JSON.stringify(request),
|
||||
(error, eventJson) => {
|
||||
if (error) {
|
||||
pushFn({ type: 'error', message: error.message, raw: eventJson });
|
||||
return;
|
||||
}
|
||||
if (eventJson === LLM_STREAM_END_MARKER) {
|
||||
pushFn(null);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
pushFn(JSON.parse(eventJson) as NativeLlmStreamEvent);
|
||||
} catch (error) {
|
||||
pushFn({
|
||||
type: 'error',
|
||||
message:
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: 'failed to parse native stream event',
|
||||
raw: eventJson,
|
||||
});
|
||||
}
|
||||
}
|
||||
);
|
||||
adapter = new NativeStreamAdapter(handle, signal);
|
||||
pushFn = event => {
|
||||
adapter.push(event);
|
||||
};
|
||||
for (const event of buffer) {
|
||||
adapter.push(event);
|
||||
}
|
||||
return adapter;
|
||||
}
|
||||
|
||||
@@ -154,8 +154,8 @@ export abstract class CalendarProvider {
|
||||
|
||||
protected async fetchJson<T>(url: string, init?: RequestInit) {
|
||||
const response = await fetch(url, {
|
||||
headers: { Accept: 'application/json', ...init?.headers },
|
||||
...init,
|
||||
headers: { ...init?.headers, Accept: 'application/json' },
|
||||
});
|
||||
const body = await response.text();
|
||||
if (!response.ok) {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
defineModuleConfig,
|
||||
StorageJSONSchema,
|
||||
@@ -15,179 +13,7 @@ import { GeminiGenerativeConfig, GeminiVertexConfig } from './providers/gemini';
|
||||
import { MorphConfig } from './providers/morph';
|
||||
import { OpenAIConfig } from './providers/openai';
|
||||
import { PerplexityConfig } from './providers/perplexity';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
ModelOutputType,
|
||||
VertexSchema,
|
||||
} from './providers/types';
|
||||
|
||||
export type CopilotProviderConfigMap = {
|
||||
[CopilotProviderType.OpenAI]: OpenAIConfig;
|
||||
[CopilotProviderType.FAL]: FalConfig;
|
||||
[CopilotProviderType.Gemini]: GeminiGenerativeConfig;
|
||||
[CopilotProviderType.GeminiVertex]: GeminiVertexConfig;
|
||||
[CopilotProviderType.Perplexity]: PerplexityConfig;
|
||||
[CopilotProviderType.Anthropic]: AnthropicOfficialConfig;
|
||||
[CopilotProviderType.AnthropicVertex]: AnthropicVertexConfig;
|
||||
[CopilotProviderType.Morph]: MorphConfig;
|
||||
};
|
||||
|
||||
export type ProviderSpecificConfig =
|
||||
CopilotProviderConfigMap[keyof CopilotProviderConfigMap];
|
||||
|
||||
export const RustRequestMiddlewareValues = [
|
||||
'normalize_messages',
|
||||
'clamp_max_tokens',
|
||||
'tool_schema_rewrite',
|
||||
] as const;
|
||||
export type RustRequestMiddleware =
|
||||
(typeof RustRequestMiddlewareValues)[number];
|
||||
|
||||
export const RustStreamMiddlewareValues = [
|
||||
'stream_event_normalize',
|
||||
'citation_indexing',
|
||||
] as const;
|
||||
export type RustStreamMiddleware = (typeof RustStreamMiddlewareValues)[number];
|
||||
|
||||
export const NodeTextMiddlewareValues = [
|
||||
'citation_footnote',
|
||||
'callout',
|
||||
'thinking_format',
|
||||
] as const;
|
||||
export type NodeTextMiddleware = (typeof NodeTextMiddlewareValues)[number];
|
||||
|
||||
export type ProviderMiddlewareConfig = {
|
||||
rust?: { request?: RustRequestMiddleware[]; stream?: RustStreamMiddleware[] };
|
||||
node?: { text?: NodeTextMiddleware[] };
|
||||
};
|
||||
|
||||
type CopilotProviderProfileCommon = {
|
||||
id: string;
|
||||
displayName?: string;
|
||||
priority?: number;
|
||||
enabled?: boolean;
|
||||
models?: string[];
|
||||
middleware?: ProviderMiddlewareConfig;
|
||||
};
|
||||
|
||||
type CopilotProviderProfileVariant<T extends CopilotProviderType> = {
|
||||
type: T;
|
||||
config: CopilotProviderConfigMap[T];
|
||||
};
|
||||
|
||||
export type CopilotProviderProfile = CopilotProviderProfileCommon &
|
||||
{
|
||||
[Type in CopilotProviderType]: CopilotProviderProfileVariant<Type>;
|
||||
}[CopilotProviderType];
|
||||
|
||||
export type CopilotProviderDefaults = Partial<
|
||||
Record<ModelOutputType, string>
|
||||
> & {
|
||||
fallback?: string;
|
||||
};
|
||||
|
||||
const CopilotProviderProfileBaseShape = z.object({
|
||||
id: z.string().regex(/^[a-zA-Z0-9-_]+$/),
|
||||
displayName: z.string().optional(),
|
||||
priority: z.number().optional(),
|
||||
enabled: z.boolean().optional(),
|
||||
models: z.array(z.string()).optional(),
|
||||
middleware: z
|
||||
.object({
|
||||
rust: z
|
||||
.object({
|
||||
request: z.array(z.enum(RustRequestMiddlewareValues)).optional(),
|
||||
stream: z.array(z.enum(RustStreamMiddlewareValues)).optional(),
|
||||
})
|
||||
.optional(),
|
||||
node: z
|
||||
.object({ text: z.array(z.enum(NodeTextMiddlewareValues)).optional() })
|
||||
.optional(),
|
||||
})
|
||||
.optional(),
|
||||
});
|
||||
|
||||
const OpenAIConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
baseURL: z.string().optional(),
|
||||
oldApiStyle: z.boolean().optional(),
|
||||
});
|
||||
|
||||
const FalConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
});
|
||||
|
||||
const GeminiGenerativeConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
baseURL: z.string().optional(),
|
||||
});
|
||||
|
||||
const VertexProviderConfigShape = z.object({
|
||||
location: z.string().optional(),
|
||||
project: z.string().optional(),
|
||||
baseURL: z.string().optional(),
|
||||
googleAuthOptions: z.any().optional(),
|
||||
fetch: z.any().optional(),
|
||||
});
|
||||
|
||||
const PerplexityConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
endpoint: z.string().optional(),
|
||||
});
|
||||
|
||||
const AnthropicOfficialConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
baseURL: z.string().optional(),
|
||||
});
|
||||
|
||||
const MorphConfigShape = z.object({
|
||||
apiKey: z.string().optional(),
|
||||
});
|
||||
|
||||
const CopilotProviderProfileShape = z.discriminatedUnion('type', [
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.OpenAI),
|
||||
config: OpenAIConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.FAL),
|
||||
config: FalConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Gemini),
|
||||
config: GeminiGenerativeConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.GeminiVertex),
|
||||
config: VertexProviderConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Perplexity),
|
||||
config: PerplexityConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Anthropic),
|
||||
config: AnthropicOfficialConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.AnthropicVertex),
|
||||
config: VertexProviderConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Morph),
|
||||
config: MorphConfigShape,
|
||||
}),
|
||||
]);
|
||||
|
||||
const CopilotProviderDefaultsShape = z.object({
|
||||
[ModelOutputType.Text]: z.string().optional(),
|
||||
[ModelOutputType.Object]: z.string().optional(),
|
||||
[ModelOutputType.Embedding]: z.string().optional(),
|
||||
[ModelOutputType.Image]: z.string().optional(),
|
||||
[ModelOutputType.Structured]: z.string().optional(),
|
||||
fallback: z.string().optional(),
|
||||
});
|
||||
|
||||
import { VertexSchema } from './providers/types';
|
||||
declare global {
|
||||
interface AppConfigSchema {
|
||||
copilot: {
|
||||
@@ -201,8 +27,6 @@ declare global {
|
||||
storage: ConfigItem<StorageProviderConfig>;
|
||||
scenarios: ConfigItem<CopilotPromptScenario>;
|
||||
providers: {
|
||||
profiles: ConfigItem<CopilotProviderProfile[]>;
|
||||
defaults: ConfigItem<CopilotProviderDefaults>;
|
||||
openai: ConfigItem<OpenAIConfig>;
|
||||
fal: ConfigItem<FalConfig>;
|
||||
gemini: ConfigItem<GeminiGenerativeConfig>;
|
||||
@@ -239,16 +63,6 @@ defineModuleConfig('copilot', {
|
||||
},
|
||||
},
|
||||
},
|
||||
'providers.profiles': {
|
||||
desc: 'The profile list for copilot providers.',
|
||||
default: [],
|
||||
shape: z.array(CopilotProviderProfileShape),
|
||||
},
|
||||
'providers.defaults': {
|
||||
desc: 'The default provider ids for model output types and global fallback.',
|
||||
default: {},
|
||||
shape: CopilotProviderDefaultsShape,
|
||||
},
|
||||
'providers.openai': {
|
||||
desc: 'The config for the openai provider.',
|
||||
default: {
|
||||
|
||||
@@ -36,7 +36,10 @@ import {
|
||||
BlobNotFound,
|
||||
CallMetric,
|
||||
Config,
|
||||
CopilotFailedToGenerateText,
|
||||
CopilotSessionNotFound,
|
||||
InternalServerError,
|
||||
mapAnyError,
|
||||
mapSseError,
|
||||
metrics,
|
||||
NoCopilotProviderAvailable,
|
||||
@@ -239,6 +242,61 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
};
|
||||
}
|
||||
|
||||
@Get('/chat/:sessionId')
|
||||
@CallMetric('ai', 'chat', { timer: true })
|
||||
async chat(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query() query: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const info: any = { sessionId, params: query };
|
||||
|
||||
try {
|
||||
const { provider, model, session, finalMessage } =
|
||||
await this.prepareChatSession(
|
||||
user,
|
||||
sessionId,
|
||||
query,
|
||||
ModelOutputType.Text
|
||||
);
|
||||
|
||||
info.model = model;
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
metrics.ai.counter('chat_calls').add(1, { model });
|
||||
|
||||
const { reasoning, webSearch, toolsConfig } =
|
||||
ChatQuerySchema.parse(query);
|
||||
const content = await provider.text({ modelId: model }, finalMessage, {
|
||||
...session.config.promptConfig,
|
||||
signal: getSignal(req).signal,
|
||||
user: user.id,
|
||||
session: session.config.sessionId,
|
||||
workspace: session.config.workspaceId,
|
||||
reasoning,
|
||||
webSearch,
|
||||
tools: getTools(session.config.promptConfig?.tools, toolsConfig),
|
||||
});
|
||||
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
await session.save();
|
||||
|
||||
return content;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_errors').add(1);
|
||||
let error = mapAnyError(e);
|
||||
if (error instanceof InternalServerError) {
|
||||
error = new CopilotFailedToGenerateText(e.message);
|
||||
}
|
||||
error.log('CopilotChat', info);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
@Sse('/chat/:sessionId/stream')
|
||||
@CallMetric('ai', 'chat_stream', { timer: true })
|
||||
async chatStream(
|
||||
|
||||
@@ -3,7 +3,7 @@ import { AiPrompt, PrismaClient } from '@prisma/client';
|
||||
|
||||
import type { PromptConfig, PromptMessage } from '../providers/types';
|
||||
|
||||
export type Prompt = Omit<
|
||||
type Prompt = Omit<
|
||||
AiPrompt,
|
||||
| 'id'
|
||||
| 'createdAt'
|
||||
@@ -2095,14 +2095,17 @@ export const prompts: Prompt[] = [
|
||||
|
||||
export async function refreshPrompts(db: PrismaClient) {
|
||||
const needToSkip = await db.aiPrompt
|
||||
.findMany({ where: { modified: true }, select: { name: true } })
|
||||
.findMany({
|
||||
where: { modified: true },
|
||||
select: { name: true },
|
||||
})
|
||||
.then(p => p.map(p => p.name));
|
||||
|
||||
for (const prompt of prompts) {
|
||||
// skip prompt update if already modified by admin panel
|
||||
if (needToSkip.includes(prompt.name)) {
|
||||
new Logger('CopilotPrompt').warn(`Skip modified prompt: ${prompt.name}`);
|
||||
continue;
|
||||
return;
|
||||
}
|
||||
|
||||
await db.aiPrompt.upsert({
|
||||
|
||||
@@ -12,7 +12,6 @@ import {
|
||||
import { ChatPrompt } from './chat-prompt';
|
||||
import {
|
||||
CopilotPromptScenario,
|
||||
type Prompt,
|
||||
prompts,
|
||||
refreshPrompts,
|
||||
Scenario,
|
||||
@@ -22,7 +21,6 @@ import {
|
||||
export class PromptService implements OnApplicationBootstrap {
|
||||
private readonly logger = new Logger(PromptService.name);
|
||||
private readonly cache = new Map<string, ChatPrompt>();
|
||||
private readonly inMemoryPrompts = new Map<string, Prompt>();
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
@@ -30,7 +28,7 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
) {}
|
||||
|
||||
async onApplicationBootstrap() {
|
||||
this.resetInMemoryPrompts();
|
||||
this.cache.clear();
|
||||
await refreshPrompts(this.db);
|
||||
}
|
||||
|
||||
@@ -47,7 +45,6 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
}
|
||||
|
||||
protected async setup(scenarios?: CopilotPromptScenario) {
|
||||
this.ensureInMemoryPrompts();
|
||||
if (!!scenarios && scenarios.override_enabled && scenarios.scenarios) {
|
||||
this.logger.log('Updating prompts based on scenarios...');
|
||||
for (const [scenario, model] of Object.entries(scenarios.scenarios)) {
|
||||
@@ -78,29 +75,25 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
* @returns prompt names
|
||||
*/
|
||||
async listNames() {
|
||||
this.ensureInMemoryPrompts();
|
||||
return Array.from(this.inMemoryPrompts.keys());
|
||||
return this.db.aiPrompt
|
||||
.findMany({ select: { name: true } })
|
||||
.then(prompts => Array.from(new Set(prompts.map(p => p.name))));
|
||||
}
|
||||
|
||||
async list() {
|
||||
this.ensureInMemoryPrompts();
|
||||
return Array.from(this.inMemoryPrompts.values())
|
||||
.map(prompt => ({
|
||||
name: prompt.name,
|
||||
action: prompt.action ?? null,
|
||||
model: prompt.model,
|
||||
config: prompt.config ? structuredClone(prompt.config) : null,
|
||||
messages: prompt.messages.map(message => ({
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
params: message.params ?? null,
|
||||
})),
|
||||
}))
|
||||
.sort((a, b) => {
|
||||
if (a.action === null && b.action !== null) return -1;
|
||||
if (a.action !== null && b.action === null) return 1;
|
||||
return (a.action ?? '').localeCompare(b.action ?? '');
|
||||
});
|
||||
return this.db.aiPrompt.findMany({
|
||||
select: {
|
||||
name: true,
|
||||
action: true,
|
||||
model: true,
|
||||
config: true,
|
||||
messages: {
|
||||
select: { role: true, content: true, params: true },
|
||||
orderBy: { idx: 'asc' },
|
||||
},
|
||||
},
|
||||
orderBy: { action: { sort: 'asc', nulls: 'first' } },
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -109,24 +102,40 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
* @returns prompt messages
|
||||
*/
|
||||
async get(name: string): Promise<ChatPrompt | null> {
|
||||
this.ensureInMemoryPrompts();
|
||||
|
||||
// skip cache in dev mode to ensure the latest prompt is always fetched
|
||||
if (!env.dev) {
|
||||
const cached = this.cache.get(name);
|
||||
if (cached) return cached;
|
||||
}
|
||||
|
||||
const prompt = this.inMemoryPrompts.get(name);
|
||||
if (!prompt) return null;
|
||||
const prompt = await this.db.aiPrompt.findUnique({
|
||||
where: {
|
||||
name,
|
||||
},
|
||||
select: {
|
||||
name: true,
|
||||
action: true,
|
||||
model: true,
|
||||
optionalModels: true,
|
||||
config: true,
|
||||
messages: {
|
||||
select: {
|
||||
role: true,
|
||||
content: true,
|
||||
params: true,
|
||||
},
|
||||
orderBy: {
|
||||
idx: 'asc',
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const messages = PromptMessageSchema.array().safeParse(prompt.messages);
|
||||
const config = PromptConfigSchema.safeParse(prompt.config);
|
||||
if (messages.success && config.success) {
|
||||
const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
|
||||
const config = PromptConfigSchema.safeParse(prompt?.config);
|
||||
if (prompt && messages.success && config.success) {
|
||||
const chatPrompt = ChatPrompt.createFromPrompt({
|
||||
...this.clonePrompt(prompt),
|
||||
action: prompt.action ?? null,
|
||||
optionalModels: prompt.optionalModels ?? [],
|
||||
...prompt,
|
||||
config: config.data,
|
||||
messages: messages.data,
|
||||
});
|
||||
@@ -140,69 +149,25 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
name: string,
|
||||
model: string,
|
||||
messages: PromptMessage[],
|
||||
config?: PromptConfig | null,
|
||||
extraConfig?: { optionalModels: string[] }
|
||||
config?: PromptConfig | null
|
||||
) {
|
||||
this.ensureInMemoryPrompts();
|
||||
|
||||
const existing = this.inMemoryPrompts.get(name);
|
||||
const mergedOptionalModels = existing?.optionalModels
|
||||
? [...existing.optionalModels, ...(extraConfig?.optionalModels ?? [])]
|
||||
: extraConfig?.optionalModels;
|
||||
const inMemoryConfig = (!!config && structuredClone(config)) || undefined;
|
||||
const dbConfig = this.toDbConfig(config);
|
||||
this.inMemoryPrompts.set(name, {
|
||||
name,
|
||||
model,
|
||||
action: existing?.action,
|
||||
optionalModels: mergedOptionalModels,
|
||||
config: inMemoryConfig,
|
||||
messages: this.cloneMessages(messages),
|
||||
});
|
||||
this.cache.delete(name);
|
||||
|
||||
try {
|
||||
return await this.db.aiPrompt
|
||||
.upsert({
|
||||
where: { name },
|
||||
create: {
|
||||
name,
|
||||
action: existing?.action,
|
||||
model,
|
||||
optionalModels: mergedOptionalModels,
|
||||
config: dbConfig,
|
||||
messages: {
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
return await this.db.aiPrompt
|
||||
.create({
|
||||
data: {
|
||||
name,
|
||||
model,
|
||||
config: config || undefined,
|
||||
messages: {
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
update: {
|
||||
model,
|
||||
optionalModels: mergedOptionalModels,
|
||||
config: dbConfig,
|
||||
updatedAt: new Date(),
|
||||
messages: {
|
||||
deleteMany: {},
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
},
|
||||
},
|
||||
})
|
||||
.then(ret => ret.id);
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Compat prompt upsert failed for "${name}": ${this.stringifyError(error)}`
|
||||
);
|
||||
return -1;
|
||||
}
|
||||
},
|
||||
})
|
||||
.then(ret => ret.id);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
@@ -212,123 +177,44 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
messages?: PromptMessage[];
|
||||
model?: string;
|
||||
modified?: boolean;
|
||||
config?: PromptConfig | null;
|
||||
config?: PromptConfig;
|
||||
},
|
||||
where?: Prisma.AiPromptWhereInput
|
||||
) {
|
||||
this.ensureInMemoryPrompts();
|
||||
const { config, messages, model, modified } = data;
|
||||
const existing = await this.db.aiPrompt
|
||||
.count({ where: { ...where, name } })
|
||||
.then(count => count > 0);
|
||||
if (existing) {
|
||||
await this.db.aiPrompt.update({
|
||||
where: { name },
|
||||
data: {
|
||||
config: config || undefined,
|
||||
updatedAt: new Date(),
|
||||
modified,
|
||||
model,
|
||||
messages: messages
|
||||
? {
|
||||
// cleanup old messages
|
||||
deleteMany: {},
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
}
|
||||
: undefined,
|
||||
},
|
||||
});
|
||||
|
||||
const current = this.inMemoryPrompts.get(name);
|
||||
if (current) {
|
||||
const next = this.clonePrompt(current);
|
||||
if (model !== undefined) {
|
||||
next.model = model;
|
||||
}
|
||||
if (config === null) {
|
||||
next.config = undefined;
|
||||
} else if (config !== undefined) {
|
||||
next.config = structuredClone(config);
|
||||
}
|
||||
if (messages) {
|
||||
next.messages = this.cloneMessages(messages);
|
||||
}
|
||||
|
||||
this.inMemoryPrompts.set(name, next);
|
||||
this.cache.delete(name);
|
||||
}
|
||||
|
||||
try {
|
||||
const existing = await this.db.aiPrompt
|
||||
.count({ where: { ...where, name } })
|
||||
.then(count => count > 0);
|
||||
if (existing) {
|
||||
await this.db.aiPrompt.update({
|
||||
where: { name },
|
||||
data: {
|
||||
config: this.toDbConfig(config),
|
||||
updatedAt: new Date(),
|
||||
modified,
|
||||
model,
|
||||
messages: messages
|
||||
? {
|
||||
// cleanup old messages
|
||||
deleteMany: {},
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: m.params || undefined,
|
||||
})),
|
||||
}
|
||||
: undefined,
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Compat prompt update failed for "${name}": ${this.stringifyError(error)}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async delete(name: string) {
|
||||
this.inMemoryPrompts.delete(name);
|
||||
const { id } = await this.db.aiPrompt.delete({ where: { name } });
|
||||
this.cache.delete(name);
|
||||
|
||||
try {
|
||||
const { id } = await this.db.aiPrompt.delete({ where: { name } });
|
||||
return id;
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Compat prompt delete failed for "${name}": ${this.stringifyError(error)}`
|
||||
);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
private resetInMemoryPrompts() {
|
||||
this.cache.clear();
|
||||
this.inMemoryPrompts.clear();
|
||||
for (const prompt of prompts) {
|
||||
this.inMemoryPrompts.set(prompt.name, this.clonePrompt(prompt));
|
||||
}
|
||||
}
|
||||
|
||||
private ensureInMemoryPrompts() {
|
||||
if (!this.inMemoryPrompts.size) {
|
||||
this.resetInMemoryPrompts();
|
||||
}
|
||||
}
|
||||
|
||||
private toDbConfig(
|
||||
config: PromptConfig | null | undefined
|
||||
): Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput | undefined {
|
||||
if (config === null) return Prisma.DbNull;
|
||||
if (config === undefined) return undefined;
|
||||
return config as Prisma.InputJsonValue;
|
||||
}
|
||||
|
||||
private cloneMessages(messages: PromptMessage[]) {
|
||||
return messages.map(message => ({
|
||||
...message,
|
||||
attachments: message.attachments ? [...message.attachments] : undefined,
|
||||
params: message.params ? structuredClone(message.params) : undefined,
|
||||
}));
|
||||
}
|
||||
|
||||
private clonePrompt(prompt: Prompt): Prompt {
|
||||
return {
|
||||
...prompt,
|
||||
optionalModels: prompt.optionalModels
|
||||
? [...prompt.optionalModels]
|
||||
: undefined,
|
||||
config: prompt.config ? structuredClone(prompt.config) : undefined,
|
||||
messages: this.cloneMessages(prompt.messages),
|
||||
};
|
||||
}
|
||||
|
||||
private stringifyError(error: unknown) {
|
||||
return error instanceof Error ? error.message : String(error);
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,90 +1,52 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import {
|
||||
type AnthropicProvider as AnthropicSDKProvider,
|
||||
type AnthropicProviderOptions,
|
||||
} from '@ai-sdk/anthropic';
|
||||
import { type GoogleVertexAnthropicProvider } from '@ai-sdk/google-vertex/anthropic';
|
||||
import { AISDKError, generateText, stepCountIs, streamText } from 'ai';
|
||||
|
||||
import {
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
} from '../../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmRequest,
|
||||
} from '../../../../native';
|
||||
import type { NodeTextMiddleware } from '../../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from '../native';
|
||||
import { CopilotProvider } from '../provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotProviderModel,
|
||||
ModelConditions,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from '../types';
|
||||
import { CopilotProviderType, ModelOutputType } from '../types';
|
||||
import { getGoogleAuth, getVertexAnthropicBaseUrl } from '../utils';
|
||||
import { ModelOutputType } from '../types';
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
StreamObjectParser,
|
||||
TextStreamParser,
|
||||
} from '../utils';
|
||||
|
||||
export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
protected abstract instance:
|
||||
| AnthropicSDKProvider
|
||||
| GoogleVertexAnthropicProvider;
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof AISDKError) {
|
||||
this.logger.error('Throw error from ai sdk:', e);
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.name || 'unknown',
|
||||
message: e.message,
|
||||
});
|
||||
} else {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected anthropic response',
|
||||
});
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected anthropic response',
|
||||
});
|
||||
}
|
||||
|
||||
private async createNativeConfig(): Promise<NativeLlmBackendConfig> {
|
||||
if (this.type === CopilotProviderType.AnthropicVertex) {
|
||||
const auth = await getGoogleAuth(this.config as any, 'anthropic');
|
||||
const headers = auth.headers();
|
||||
const authorization =
|
||||
headers.Authorization ||
|
||||
(headers as Record<string, string | undefined>).authorization;
|
||||
const token =
|
||||
typeof authorization === 'string'
|
||||
? authorization.replace(/^Bearer\s+/i, '')
|
||||
: '';
|
||||
const baseUrl =
|
||||
getVertexAnthropicBaseUrl(this.config as any) || auth.baseUrl;
|
||||
return {
|
||||
base_url: baseUrl || '',
|
||||
auth_token: token,
|
||||
request_layer: 'vertex',
|
||||
headers,
|
||||
};
|
||||
}
|
||||
|
||||
const config = this.config as { apiKey: string; baseURL?: string };
|
||||
const baseUrl = config.baseURL || 'https://api.anthropic.com/v1';
|
||||
return {
|
||||
base_url: baseUrl.replace(/\/v1\/?$/, ''),
|
||||
auth_token: config.apiKey,
|
||||
};
|
||||
}
|
||||
|
||||
private createAdapter(
|
||||
backendConfig: NativeLlmBackendConfig,
|
||||
tools: ToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream('anthropic', backendConfig, request, signal),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
}
|
||||
|
||||
private getReasoning(
|
||||
options: NonNullable<CopilotChatOptions>,
|
||||
model: string
|
||||
): Record<string, unknown> | undefined {
|
||||
if (options.reasoning && this.isReasoningModel(model)) {
|
||||
return { budget_tokens: 12000, include_thought: true };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async text(
|
||||
@@ -97,29 +59,28 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const reasoning = this.getReasoning(options, model.id);
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
reasoning,
|
||||
middleware,
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages, true, true);
|
||||
|
||||
const modelInstance = this.instance(model.id);
|
||||
const { text, reasoning } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
providerOptions: {
|
||||
anthropic: this.getAnthropicOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
});
|
||||
const adapter = this.createAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
return await adapter.text(request, options.signal);
|
||||
|
||||
if (!text) throw new Error('Failed to generate text');
|
||||
|
||||
return reasoning ? `${reasoning}\n${text}` : text;
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -134,32 +95,25 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
yield chunk;
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
yield result;
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!options.signal?.aborted) {
|
||||
const footnotes = parser.end();
|
||||
if (footnotes.length) {
|
||||
yield `\n\n${footnotes}`;
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -176,34 +130,58 @@ export abstract class AnthropicProvider<T> extends CopilotProvider<T> {
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const backendConfig = await this.createNativeConfig();
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createAdapter(
|
||||
backendConfig,
|
||||
tools,
|
||||
middleware.node?.text
|
||||
);
|
||||
for await (const chunk of adapter.streamObject(request, options.signal)) {
|
||||
yield chunk;
|
||||
.add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new StreamObjectParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
if (result) {
|
||||
yield result;
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
.add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private async getFullStream(
|
||||
model: CopilotProviderModel,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
) {
|
||||
const [system, msgs] = await chatToGPTMessage(messages, true, true);
|
||||
const { fullStream } = streamText({
|
||||
model: this.instance(model.id),
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
providerOptions: {
|
||||
anthropic: this.getAnthropicOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
});
|
||||
return fullStream;
|
||||
}
|
||||
|
||||
private getAnthropicOptions(options: CopilotChatOptions, model: string) {
|
||||
const result: AnthropicProviderOptions = {};
|
||||
if (options?.reasoning && this.isReasoningModel(model)) {
|
||||
result.thinking = {
|
||||
type: 'enabled',
|
||||
budgetTokens: 12000,
|
||||
};
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private isReasoningModel(model: string) {
|
||||
// claude 3.5 sonnet doesn't support reasoning config
|
||||
return model.includes('sonnet') && !model.startsWith('claude-3-5-sonnet');
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
import {
|
||||
type AnthropicProvider as AnthropicSDKProvider,
|
||||
createAnthropic,
|
||||
} from '@ai-sdk/anthropic';
|
||||
import z from 'zod';
|
||||
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
@@ -48,12 +52,18 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
|
||||
},
|
||||
];
|
||||
|
||||
protected instance!: AnthropicSDKProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
override setup() {
|
||||
super.setup();
|
||||
this.instance = createAnthropic({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
|
||||
@@ -5,11 +5,7 @@ import {
|
||||
} from '@ai-sdk/google-vertex/anthropic';
|
||||
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
|
||||
import {
|
||||
getGoogleAuth,
|
||||
getVertexAnthropicBaseUrl,
|
||||
VertexModelListSchema,
|
||||
} from '../utils';
|
||||
import { getGoogleAuth, VertexModelListSchema } from '../utils';
|
||||
import { AnthropicProvider } from './anthropic';
|
||||
|
||||
export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings;
|
||||
@@ -53,8 +49,7 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
|
||||
protected instance!: GoogleVertexAnthropicProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
if (!this.config.location || !this.config.googleAuthOptions) return false;
|
||||
return !!this.config.project || !!getVertexAnthropicBaseUrl(this.config);
|
||||
return !!this.config.location && !!this.config.googleAuthOptions;
|
||||
}
|
||||
|
||||
override setup() {
|
||||
|
||||
@@ -1,141 +1,16 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { Config } from '../../../base';
|
||||
import { ServerFeature, ServerService } from '../../../core';
|
||||
import type { CopilotProvider } from './provider';
|
||||
import {
|
||||
buildProviderRegistry,
|
||||
resolveModel,
|
||||
stripProviderPrefix,
|
||||
} from './provider-registry';
|
||||
import { CopilotProviderType, ModelFullConditions } from './types';
|
||||
|
||||
function isAsyncIterable(value: unknown): value is AsyncIterable<unknown> {
|
||||
return (
|
||||
value !== null &&
|
||||
value !== undefined &&
|
||||
typeof (value as AsyncIterable<unknown>)[Symbol.asyncIterator] ===
|
||||
'function'
|
||||
);
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class CopilotProviderFactory {
|
||||
constructor(
|
||||
private readonly server: ServerService,
|
||||
private readonly config: Config
|
||||
) {}
|
||||
constructor(private readonly server: ServerService) {}
|
||||
|
||||
private readonly logger = new Logger(CopilotProviderFactory.name);
|
||||
|
||||
readonly #providers = new Map<string, CopilotProvider>();
|
||||
readonly #boundProviders = new Map<string, CopilotProvider>();
|
||||
readonly #providerIdsByType = new Map<CopilotProviderType, Set<string>>();
|
||||
|
||||
private getRegistry() {
|
||||
return buildProviderRegistry(this.config.copilot.providers);
|
||||
}
|
||||
|
||||
private getPreferredProviderIds(type?: CopilotProviderType) {
|
||||
if (!type) return undefined;
|
||||
return this.#providerIdsByType.get(type);
|
||||
}
|
||||
|
||||
private normalizeCond(
|
||||
providerId: string,
|
||||
cond: ModelFullConditions
|
||||
): ModelFullConditions {
|
||||
const registry = this.getRegistry();
|
||||
const modelId = stripProviderPrefix(registry, providerId, cond.modelId);
|
||||
return { ...cond, modelId };
|
||||
}
|
||||
|
||||
private normalizeMethodArgs(providerId: string, args: unknown[]) {
|
||||
const [first, ...rest] = args;
|
||||
if (
|
||||
!first ||
|
||||
typeof first !== 'object' ||
|
||||
Array.isArray(first) ||
|
||||
!('modelId' in first)
|
||||
) {
|
||||
return args;
|
||||
}
|
||||
|
||||
const cond = first as Record<string, unknown>;
|
||||
if (typeof cond.modelId !== 'string') return args;
|
||||
|
||||
const registry = this.getRegistry();
|
||||
const modelId = stripProviderPrefix(registry, providerId, cond.modelId);
|
||||
return [{ ...cond, modelId }, ...rest];
|
||||
}
|
||||
|
||||
private wrapAsyncIterable<T>(
|
||||
provider: CopilotProvider,
|
||||
providerId: string,
|
||||
iterable: AsyncIterable<T>
|
||||
): AsyncIterableIterator<T> {
|
||||
const iterator = iterable[Symbol.asyncIterator]();
|
||||
|
||||
return {
|
||||
next: value =>
|
||||
provider.runWithProfile(providerId, () => iterator.next(value)),
|
||||
return: value =>
|
||||
provider.runWithProfile(providerId, async () => {
|
||||
if (typeof iterator.return === 'function') {
|
||||
return iterator.return(value as never);
|
||||
}
|
||||
return { done: true, value: value as T };
|
||||
}),
|
||||
throw: error =>
|
||||
provider.runWithProfile(providerId, async () => {
|
||||
if (typeof iterator.throw === 'function') {
|
||||
return iterator.throw(error);
|
||||
}
|
||||
throw error;
|
||||
}),
|
||||
[Symbol.asyncIterator]() {
|
||||
return this;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private getBoundProvider(providerId: string, provider: CopilotProvider) {
|
||||
const cached = this.#boundProviders.get(providerId);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
const wrapped = new Proxy(provider, {
|
||||
get: (target, prop, receiver) => {
|
||||
if (prop === 'providerId') {
|
||||
return providerId;
|
||||
}
|
||||
|
||||
const value = Reflect.get(target, prop, receiver);
|
||||
if (typeof value !== 'function') {
|
||||
return value;
|
||||
}
|
||||
|
||||
return (...args: unknown[]) => {
|
||||
const normalizedArgs = this.normalizeMethodArgs(providerId, args);
|
||||
const result = provider.runWithProfile(providerId, () =>
|
||||
Reflect.apply(value, provider, normalizedArgs)
|
||||
);
|
||||
if (isAsyncIterable(result)) {
|
||||
return this.wrapAsyncIterable(
|
||||
provider,
|
||||
providerId,
|
||||
result as AsyncIterable<unknown>
|
||||
);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
},
|
||||
}) as CopilotProvider;
|
||||
|
||||
this.#boundProviders.set(providerId, wrapped);
|
||||
return wrapped;
|
||||
}
|
||||
readonly #providers = new Map<CopilotProviderType, CopilotProvider>();
|
||||
|
||||
async getProvider(
|
||||
cond: ModelFullConditions,
|
||||
@@ -146,41 +21,22 @@ export class CopilotProviderFactory {
|
||||
this.logger.debug(
|
||||
`Resolving copilot provider for output type: ${cond.outputType}`
|
||||
);
|
||||
const route = resolveModel({
|
||||
registry: this.getRegistry(),
|
||||
modelId: cond.modelId,
|
||||
outputType: cond.outputType,
|
||||
availableProviderIds: this.#providers.keys(),
|
||||
preferredProviderIds: this.getPreferredProviderIds(filter.prefer),
|
||||
});
|
||||
|
||||
const registry = this.getRegistry();
|
||||
for (const providerId of route.candidateProviderIds) {
|
||||
const provider = this.#providers.get(providerId);
|
||||
if (!provider) continue;
|
||||
|
||||
const profile = registry.profiles.get(providerId);
|
||||
const normalizedCond = this.normalizeCond(providerId, cond);
|
||||
if (
|
||||
normalizedCond.modelId &&
|
||||
profile?.models?.length &&
|
||||
!profile.models.includes(normalizedCond.modelId)
|
||||
) {
|
||||
let candidate: CopilotProvider | null = null;
|
||||
for (const [type, provider] of this.#providers.entries()) {
|
||||
if (filter.prefer && filter.prefer !== type) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const matched = await provider.runWithProfile(providerId, () =>
|
||||
provider.match(normalizedCond)
|
||||
);
|
||||
if (!matched) continue;
|
||||
const isMatched = await provider.match(cond);
|
||||
|
||||
this.logger.debug(
|
||||
`Copilot provider candidate found: ${provider.type} (${providerId})`
|
||||
);
|
||||
return this.getBoundProvider(providerId, provider);
|
||||
if (isMatched) {
|
||||
candidate = provider;
|
||||
this.logger.debug(`Copilot provider candidate found: ${type}`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
return candidate;
|
||||
}
|
||||
|
||||
async getProviderByModel(
|
||||
@@ -190,50 +46,31 @@ export class CopilotProviderFactory {
|
||||
} = {}
|
||||
): Promise<CopilotProvider | null> {
|
||||
this.logger.debug(`Resolving copilot provider for model: ${modelId}`);
|
||||
return this.getProvider({ modelId }, filter);
|
||||
}
|
||||
|
||||
register(providerId: string, provider: CopilotProvider) {
|
||||
const existed = this.#providers.get(providerId);
|
||||
if (existed?.type && existed.type !== provider.type) {
|
||||
const ids = this.#providerIdsByType.get(existed.type);
|
||||
ids?.delete(providerId);
|
||||
if (!ids?.size) {
|
||||
this.#providerIdsByType.delete(existed.type);
|
||||
let candidate: CopilotProvider | null = null;
|
||||
for (const [type, provider] of this.#providers.entries()) {
|
||||
if (filter.prefer && filter.prefer !== type) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (await provider.match({ modelId })) {
|
||||
candidate = provider;
|
||||
this.logger.debug(`Copilot provider candidate found: ${type}`);
|
||||
}
|
||||
}
|
||||
|
||||
this.#providers.set(providerId, provider);
|
||||
this.#boundProviders.delete(providerId);
|
||||
return candidate;
|
||||
}
|
||||
|
||||
const ids = this.#providerIdsByType.get(provider.type) ?? new Set<string>();
|
||||
ids.add(providerId);
|
||||
this.#providerIdsByType.set(provider.type, ids);
|
||||
|
||||
this.logger.log(
|
||||
`Copilot provider [${provider.type}] registered as [${providerId}].`
|
||||
);
|
||||
register(provider: CopilotProvider) {
|
||||
this.#providers.set(provider.type, provider);
|
||||
this.logger.log(`Copilot provider [${provider.type}] registered.`);
|
||||
this.server.enableFeature(ServerFeature.Copilot);
|
||||
}
|
||||
|
||||
unregister(providerId: string, provider: CopilotProvider) {
|
||||
const existed = this.#providers.get(providerId);
|
||||
if (!existed || existed !== provider) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.#providers.delete(providerId);
|
||||
this.#boundProviders.delete(providerId);
|
||||
|
||||
const ids = this.#providerIdsByType.get(provider.type);
|
||||
ids?.delete(providerId);
|
||||
if (!ids?.size) {
|
||||
this.#providerIdsByType.delete(provider.type);
|
||||
}
|
||||
|
||||
this.logger.log(
|
||||
`Copilot provider [${provider.type}] unregistered from [${providerId}].`
|
||||
);
|
||||
unregister(provider: CopilotProvider) {
|
||||
this.#providers.delete(provider.type);
|
||||
this.logger.log(`Copilot provider [${provider.type}] unregistered.`);
|
||||
if (this.#providers.size === 0) {
|
||||
this.server.disableFeature(ServerFeature.Copilot);
|
||||
}
|
||||
|
||||
@@ -1,381 +0,0 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type {
|
||||
NativeLlmRequest,
|
||||
NativeLlmStreamEvent,
|
||||
NativeLlmToolDefinition,
|
||||
} from '../../../native';
|
||||
|
||||
export type NativeDispatchFn = (
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
) => AsyncIterableIterator<NativeLlmStreamEvent>;
|
||||
|
||||
export type NativeToolCall = {
|
||||
id: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
thought?: string;
|
||||
};
|
||||
|
||||
type ToolCallState = {
|
||||
name?: string;
|
||||
argumentsText: string;
|
||||
};
|
||||
|
||||
type ToolExecutionResult = {
|
||||
callId: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
output: unknown;
|
||||
isError?: boolean;
|
||||
};
|
||||
|
||||
export class ToolCallAccumulator {
|
||||
readonly #states = new Map<string, ToolCallState>();
|
||||
|
||||
feedDelta(event: Extract<NativeLlmStreamEvent, { type: 'tool_call_delta' }>) {
|
||||
const state = this.#states.get(event.call_id) ?? {
|
||||
argumentsText: '',
|
||||
};
|
||||
if (event.name) {
|
||||
state.name = event.name;
|
||||
}
|
||||
if (event.arguments_delta) {
|
||||
state.argumentsText += event.arguments_delta;
|
||||
}
|
||||
this.#states.set(event.call_id, state);
|
||||
}
|
||||
|
||||
complete(event: Extract<NativeLlmStreamEvent, { type: 'tool_call' }>) {
|
||||
const state = this.#states.get(event.call_id);
|
||||
this.#states.delete(event.call_id);
|
||||
return {
|
||||
id: event.call_id,
|
||||
name: event.name || state?.name || '',
|
||||
args: this.parseArgs(
|
||||
event.arguments ?? this.parseJson(state?.argumentsText ?? '{}')
|
||||
),
|
||||
thought: event.thought,
|
||||
} satisfies NativeToolCall;
|
||||
}
|
||||
|
||||
drainPending() {
|
||||
const pending: NativeToolCall[] = [];
|
||||
for (const [callId, state] of this.#states.entries()) {
|
||||
if (!state.name) {
|
||||
continue;
|
||||
}
|
||||
pending.push({
|
||||
id: callId,
|
||||
name: state.name,
|
||||
args: this.parseArgs(this.parseJson(state.argumentsText)),
|
||||
});
|
||||
}
|
||||
this.#states.clear();
|
||||
return pending;
|
||||
}
|
||||
|
||||
private parseJson(jsonText: string): unknown {
|
||||
if (!jsonText.trim()) {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
return JSON.parse(jsonText);
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
private parseArgs(value: unknown): Record<string, unknown> {
|
||||
if (value && typeof value === 'object' && !Array.isArray(value)) {
|
||||
return value as Record<string, unknown>;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
export class ToolSchemaExtractor {
|
||||
static extract(toolSet: ToolSet): NativeLlmToolDefinition[] {
|
||||
return Object.entries(toolSet).map(([name, tool]) => {
|
||||
const unknownTool = tool as Record<string, unknown>;
|
||||
const inputSchema =
|
||||
unknownTool.inputSchema ?? unknownTool.parameters ?? z.object({});
|
||||
|
||||
return {
|
||||
name,
|
||||
description:
|
||||
typeof unknownTool.description === 'string'
|
||||
? unknownTool.description
|
||||
: undefined,
|
||||
parameters: this.toJsonSchema(inputSchema),
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
private static toJsonSchema(schema: unknown): Record<string, unknown> {
|
||||
if (!(schema instanceof z.ZodType)) {
|
||||
if (schema && typeof schema === 'object' && !Array.isArray(schema)) {
|
||||
return schema as Record<string, unknown>;
|
||||
}
|
||||
return { type: 'object', properties: {} };
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodObject) {
|
||||
const shape = schema.shape;
|
||||
const properties: Record<string, unknown> = {};
|
||||
const required: string[] = [];
|
||||
|
||||
for (const [key, child] of Object.entries(
|
||||
shape as Record<string, z.ZodTypeAny>
|
||||
)) {
|
||||
properties[key] = this.toJsonSchema(child);
|
||||
if (!this.isOptional(child)) {
|
||||
required.push(key);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'object',
|
||||
properties,
|
||||
additionalProperties: false,
|
||||
...(required.length ? { required } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodString) {
|
||||
return { type: 'string' };
|
||||
}
|
||||
if (schema instanceof z.ZodNumber) {
|
||||
return { type: 'number' };
|
||||
}
|
||||
if (schema instanceof z.ZodBoolean) {
|
||||
return { type: 'boolean' };
|
||||
}
|
||||
if (schema instanceof z.ZodArray) {
|
||||
return { type: 'array', items: this.toJsonSchema(schema.element) };
|
||||
}
|
||||
if (schema instanceof z.ZodEnum) {
|
||||
return { type: 'string', enum: schema.options };
|
||||
}
|
||||
if (schema instanceof z.ZodLiteral) {
|
||||
const literal = schema.value;
|
||||
if (literal === null) {
|
||||
return { const: null, type: 'null' };
|
||||
}
|
||||
if (typeof literal === 'string') {
|
||||
return { const: literal, type: 'string' };
|
||||
}
|
||||
if (typeof literal === 'number') {
|
||||
return { const: literal, type: 'number' };
|
||||
}
|
||||
if (typeof literal === 'boolean') {
|
||||
return { const: literal, type: 'boolean' };
|
||||
}
|
||||
return { const: literal };
|
||||
}
|
||||
if (schema instanceof z.ZodUnion) {
|
||||
return {
|
||||
anyOf: schema.options.map((option: z.ZodTypeAny) =>
|
||||
this.toJsonSchema(option)
|
||||
),
|
||||
};
|
||||
}
|
||||
if (schema instanceof z.ZodRecord) {
|
||||
return {
|
||||
type: 'object',
|
||||
additionalProperties: this.toJsonSchema(schema.valueSchema),
|
||||
};
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodNullable) {
|
||||
const inner = (schema._def as { innerType?: z.ZodTypeAny }).innerType;
|
||||
return { anyOf: [this.toJsonSchema(inner), { type: 'null' }] };
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) {
|
||||
return this.toJsonSchema(
|
||||
(schema._def as { innerType?: z.ZodTypeAny }).innerType
|
||||
);
|
||||
}
|
||||
|
||||
if (schema instanceof z.ZodEffects) {
|
||||
return this.toJsonSchema(
|
||||
(schema._def as { schema?: z.ZodTypeAny }).schema
|
||||
);
|
||||
}
|
||||
|
||||
return { type: 'object', properties: {} };
|
||||
}
|
||||
|
||||
private static isOptional(schema: z.ZodTypeAny): boolean {
|
||||
if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) {
|
||||
return true;
|
||||
}
|
||||
if (schema instanceof z.ZodNullable) {
|
||||
return this.isOptional(
|
||||
(schema._def as { innerType: z.ZodTypeAny }).innerType
|
||||
);
|
||||
}
|
||||
if (schema instanceof z.ZodEffects) {
|
||||
return this.isOptional((schema._def as { schema: z.ZodTypeAny }).schema);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export class ToolCallLoop {
|
||||
constructor(
|
||||
private readonly dispatch: NativeDispatchFn,
|
||||
private readonly tools: ToolSet,
|
||||
private readonly maxSteps = 20
|
||||
) {}
|
||||
|
||||
async *run(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
): AsyncIterableIterator<NativeLlmStreamEvent> {
|
||||
const messages = request.messages.map(message => ({
|
||||
...message,
|
||||
content: [...message.content],
|
||||
}));
|
||||
|
||||
for (let step = 0; step < this.maxSteps; step++) {
|
||||
const toolCalls: NativeToolCall[] = [];
|
||||
const accumulator = new ToolCallAccumulator();
|
||||
let finalDone: Extract<NativeLlmStreamEvent, { type: 'done' }> | null =
|
||||
null;
|
||||
|
||||
for await (const event of this.dispatch(
|
||||
{
|
||||
...request,
|
||||
stream: true,
|
||||
messages,
|
||||
},
|
||||
signal
|
||||
)) {
|
||||
switch (event.type) {
|
||||
case 'tool_call_delta': {
|
||||
accumulator.feedDelta(event);
|
||||
break;
|
||||
}
|
||||
case 'tool_call': {
|
||||
toolCalls.push(accumulator.complete(event));
|
||||
yield event;
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
finalDone = event;
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
throw new Error(event.message);
|
||||
}
|
||||
default: {
|
||||
yield event;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push(...accumulator.drainPending());
|
||||
if (toolCalls.length === 0) {
|
||||
if (finalDone) {
|
||||
yield finalDone;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (step === this.maxSteps - 1) {
|
||||
throw new Error('ToolCallLoop max steps reached');
|
||||
}
|
||||
|
||||
const toolResults = await this.executeTools(toolCalls);
|
||||
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: toolCalls.map(call => ({
|
||||
type: 'tool_call',
|
||||
call_id: call.id,
|
||||
name: call.name,
|
||||
arguments: call.args,
|
||||
thought: call.thought,
|
||||
})),
|
||||
});
|
||||
|
||||
for (const result of toolResults) {
|
||||
messages.push({
|
||||
role: 'tool',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
call_id: result.callId,
|
||||
output: result.output,
|
||||
is_error: result.isError,
|
||||
},
|
||||
],
|
||||
});
|
||||
yield {
|
||||
type: 'tool_result',
|
||||
call_id: result.callId,
|
||||
name: result.name,
|
||||
arguments: result.args,
|
||||
output: result.output,
|
||||
is_error: result.isError,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async executeTools(calls: NativeToolCall[]) {
|
||||
return await Promise.all(calls.map(call => this.executeTool(call)));
|
||||
}
|
||||
|
||||
private async executeTool(
|
||||
call: NativeToolCall
|
||||
): Promise<ToolExecutionResult> {
|
||||
const tool = this.tools[call.name] as
|
||||
| {
|
||||
execute?: (args: Record<string, unknown>) => Promise<unknown>;
|
||||
}
|
||||
| undefined;
|
||||
|
||||
if (!tool?.execute) {
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
isError: true,
|
||||
output: {
|
||||
message: `Tool not found: ${call.name}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const output = await tool.execute(call.args);
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
output: output ?? null,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Tool execution failed', {
|
||||
callId: call.id,
|
||||
toolName: call.name,
|
||||
error,
|
||||
});
|
||||
return {
|
||||
callId: call.id,
|
||||
name: call.name,
|
||||
args: call.args,
|
||||
isError: true,
|
||||
output: {
|
||||
message: 'Tool execution failed',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,14 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import {
|
||||
createOpenAICompatible,
|
||||
OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
|
||||
} from '@ai-sdk/openai-compatible';
|
||||
import { AISDKError, generateText, streamText } from 'ai';
|
||||
|
||||
import {
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
@@ -19,6 +16,7 @@ import type {
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
|
||||
import { chatToGPTMessage, TextStreamParser } from './utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
@@ -59,48 +57,37 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
},
|
||||
];
|
||||
|
||||
#instance!: VercelOpenAICompatibleProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance = createOpenAICompatible({
|
||||
name: this.type,
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: 'https://api.morphllm.com/v1',
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof AISDKError) {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.name || 'unknown',
|
||||
message: e.message,
|
||||
});
|
||||
} else {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected morph response',
|
||||
});
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected morph response',
|
||||
});
|
||||
}
|
||||
|
||||
private createNativeConfig(): NativeLlmBackendConfig {
|
||||
return {
|
||||
base_url: 'https://api.morphllm.com',
|
||||
auth_token: this.config.apiKey ?? '',
|
||||
};
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream(
|
||||
'openai_chat',
|
||||
this.createNativeConfig(),
|
||||
request,
|
||||
signal
|
||||
),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
}
|
||||
|
||||
async text(
|
||||
@@ -116,22 +103,22 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
middleware,
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance(model.id);
|
||||
|
||||
const { text } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
|
||||
return text.trim();
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -149,26 +136,38 @@ export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
middleware,
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance = this.#instance(model.id);
|
||||
|
||||
const { fullStream } = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
yield chunk;
|
||||
|
||||
const textParser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
switch (chunk.type) {
|
||||
case 'text-delta': {
|
||||
let result = textParser.parse(chunk);
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
yield textParser.parse(chunk);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,464 +0,0 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import { ZodType } from 'zod';
|
||||
|
||||
import type {
|
||||
NativeLlmCoreContent,
|
||||
NativeLlmCoreMessage,
|
||||
NativeLlmRequest,
|
||||
NativeLlmStreamEvent,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware, ProviderMiddlewareConfig } from '../config';
|
||||
import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop';
|
||||
import type { CopilotChatOptions, PromptMessage, StreamObject } from './types';
|
||||
import {
|
||||
CitationFootnoteFormatter,
|
||||
inferMimeType,
|
||||
TextStreamParser,
|
||||
} from './utils';
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
|
||||
type BuildNativeRequestOptions = {
|
||||
model: string;
|
||||
messages: PromptMessage[];
|
||||
options?: CopilotChatOptions;
|
||||
tools?: ToolSet;
|
||||
withAttachment?: boolean;
|
||||
include?: string[];
|
||||
reasoning?: Record<string, unknown>;
|
||||
middleware?: ProviderMiddlewareConfig;
|
||||
};
|
||||
|
||||
type BuildNativeRequestResult = {
|
||||
request: NativeLlmRequest;
|
||||
schema?: ZodType;
|
||||
};
|
||||
|
||||
type ToolCallMeta = {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
|
||||
type NormalizedToolResultEvent = Extract<
|
||||
NativeLlmStreamEvent,
|
||||
{ type: 'tool_result' }
|
||||
> & {
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
};
|
||||
|
||||
type AttachmentFootnote = {
|
||||
blobId: string;
|
||||
fileName: string;
|
||||
fileType: string;
|
||||
};
|
||||
|
||||
type NativeProviderAdapterOptions = {
|
||||
nodeTextMiddleware?: NodeTextMiddleware[];
|
||||
};
|
||||
|
||||
function roleToCore(role: PromptMessage['role']) {
|
||||
switch (role) {
|
||||
case 'assistant':
|
||||
return 'assistant';
|
||||
case 'system':
|
||||
return 'system';
|
||||
default:
|
||||
return 'user';
|
||||
}
|
||||
}
|
||||
|
||||
async function toCoreContents(
|
||||
message: PromptMessage,
|
||||
withAttachment: boolean
|
||||
): Promise<NativeLlmCoreContent[]> {
|
||||
const contents: NativeLlmCoreContent[] = [];
|
||||
|
||||
if (typeof message.content === 'string' && message.content.length) {
|
||||
contents.push({ type: 'text', text: message.content });
|
||||
}
|
||||
|
||||
if (!withAttachment || !Array.isArray(message.attachments)) return contents;
|
||||
|
||||
for (const entry of message.attachments) {
|
||||
let attachmentUrl: string;
|
||||
let mediaType: string;
|
||||
|
||||
if (typeof entry === 'string') {
|
||||
attachmentUrl = entry;
|
||||
mediaType =
|
||||
typeof message.params?.mimetype === 'string'
|
||||
? message.params.mimetype
|
||||
: await inferMimeType(entry);
|
||||
} else {
|
||||
attachmentUrl = entry.attachment;
|
||||
mediaType = entry.mimeType;
|
||||
}
|
||||
|
||||
if (!SIMPLE_IMAGE_URL_REGEX.test(attachmentUrl)) continue;
|
||||
if (!mediaType.startsWith('image/')) continue;
|
||||
|
||||
contents.push({ type: 'image', source: { url: attachmentUrl } });
|
||||
}
|
||||
|
||||
return contents;
|
||||
}
|
||||
|
||||
export async function buildNativeRequest({
|
||||
model,
|
||||
messages,
|
||||
options = {},
|
||||
tools = {},
|
||||
withAttachment = true,
|
||||
include,
|
||||
reasoning,
|
||||
middleware,
|
||||
}: BuildNativeRequestOptions): Promise<BuildNativeRequestResult> {
|
||||
const copiedMessages = messages.map(message => ({
|
||||
...message,
|
||||
attachments: message.attachments
|
||||
? [...message.attachments]
|
||||
: message.attachments,
|
||||
}));
|
||||
|
||||
const systemMessage =
|
||||
copiedMessages[0]?.role === 'system' ? copiedMessages.shift() : undefined;
|
||||
const schema =
|
||||
systemMessage?.params?.schema instanceof ZodType
|
||||
? systemMessage.params.schema
|
||||
: undefined;
|
||||
|
||||
const coreMessages: NativeLlmCoreMessage[] = [];
|
||||
if (systemMessage?.content?.length) {
|
||||
coreMessages.push({
|
||||
role: 'system',
|
||||
content: [{ type: 'text', text: systemMessage.content }],
|
||||
});
|
||||
}
|
||||
|
||||
for (const message of copiedMessages) {
|
||||
if (message.role === 'system') continue;
|
||||
const content = await toCoreContents(message, withAttachment);
|
||||
coreMessages.push({ role: roleToCore(message.role), content });
|
||||
}
|
||||
|
||||
return {
|
||||
request: {
|
||||
model,
|
||||
stream: true,
|
||||
messages: coreMessages,
|
||||
max_tokens: options.maxTokens ?? undefined,
|
||||
temperature: options.temperature ?? undefined,
|
||||
tools: ToolSchemaExtractor.extract(tools),
|
||||
tool_choice: Object.keys(tools).length ? 'auto' : undefined,
|
||||
include,
|
||||
reasoning,
|
||||
middleware: middleware?.rust
|
||||
? { request: middleware.rust.request, stream: middleware.rust.stream }
|
||||
: undefined,
|
||||
},
|
||||
schema,
|
||||
};
|
||||
}
|
||||
|
||||
function ensureToolResultMeta(
|
||||
event: Extract<NativeLlmStreamEvent, { type: 'tool_result' }>,
|
||||
toolCalls: Map<string, ToolCallMeta>
|
||||
): NormalizedToolResultEvent | null {
|
||||
const name = event.name ?? toolCalls.get(event.call_id)?.name;
|
||||
const args = event.arguments ?? toolCalls.get(event.call_id)?.args;
|
||||
|
||||
if (!name || !args) return null;
|
||||
return { ...event, name, arguments: args };
|
||||
}
|
||||
|
||||
function pickAttachmentFootnote(value: unknown): AttachmentFootnote | null {
|
||||
if (!value || typeof value !== 'object') {
|
||||
return null;
|
||||
}
|
||||
|
||||
const record = value as Record<string, unknown>;
|
||||
const blobId =
|
||||
typeof record.blobId === 'string'
|
||||
? record.blobId
|
||||
: typeof record.blob_id === 'string'
|
||||
? record.blob_id
|
||||
: undefined;
|
||||
const fileName =
|
||||
typeof record.fileName === 'string'
|
||||
? record.fileName
|
||||
: typeof record.name === 'string'
|
||||
? record.name
|
||||
: undefined;
|
||||
const fileType =
|
||||
typeof record.fileType === 'string'
|
||||
? record.fileType
|
||||
: typeof record.mimeType === 'string'
|
||||
? record.mimeType
|
||||
: 'application/octet-stream';
|
||||
|
||||
if (!blobId || !fileName) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return { blobId, fileName, fileType };
|
||||
}
|
||||
|
||||
function collectAttachmentFootnotes(
|
||||
event: NormalizedToolResultEvent
|
||||
): AttachmentFootnote[] {
|
||||
if (event.name === 'blob_read') {
|
||||
const item = pickAttachmentFootnote(event.output);
|
||||
return item ? [item] : [];
|
||||
}
|
||||
|
||||
if (event.name === 'doc_semantic_search' && Array.isArray(event.output)) {
|
||||
return event.output
|
||||
.map(item => pickAttachmentFootnote(item))
|
||||
.filter((item): item is AttachmentFootnote => item !== null);
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
function formatAttachmentFootnotes(attachments: AttachmentFootnote[]) {
|
||||
const references = attachments.map((_, index) => `[^${index + 1}]`).join('');
|
||||
const definitions = attachments
|
||||
.map((attachment, index) => {
|
||||
return `[^${index + 1}]: ${JSON.stringify({
|
||||
type: 'attachment',
|
||||
blobId: attachment.blobId,
|
||||
fileName: attachment.fileName,
|
||||
fileType: attachment.fileType,
|
||||
})}`;
|
||||
})
|
||||
.join('\n');
|
||||
|
||||
return `\n\n${references}\n\n${definitions}`;
|
||||
}
|
||||
|
||||
export class NativeProviderAdapter {
|
||||
readonly #loop: ToolCallLoop;
|
||||
readonly #enableCallout: boolean;
|
||||
readonly #enableCitationFootnote: boolean;
|
||||
|
||||
constructor(
|
||||
dispatch: NativeDispatchFn,
|
||||
tools: ToolSet,
|
||||
maxSteps = 20,
|
||||
options: NativeProviderAdapterOptions = {}
|
||||
) {
|
||||
this.#loop = new ToolCallLoop(dispatch, tools, maxSteps);
|
||||
const enabledNodeTextMiddlewares = new Set(
|
||||
options.nodeTextMiddleware ?? ['citation_footnote', 'callout']
|
||||
);
|
||||
this.#enableCallout =
|
||||
enabledNodeTextMiddlewares.has('callout') ||
|
||||
enabledNodeTextMiddlewares.has('thinking_format');
|
||||
this.#enableCitationFootnote =
|
||||
enabledNodeTextMiddlewares.has('citation_footnote');
|
||||
}
|
||||
|
||||
async text(request: NativeLlmRequest, signal?: AbortSignal) {
|
||||
let output = '';
|
||||
for await (const chunk of this.streamText(request, signal)) {
|
||||
output += chunk;
|
||||
}
|
||||
return output.trim();
|
||||
}
|
||||
|
||||
async *streamText(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
): AsyncIterableIterator<string> {
|
||||
const textParser = this.#enableCallout ? new TextStreamParser() : null;
|
||||
const citationFormatter = this.#enableCitationFootnote
|
||||
? new CitationFootnoteFormatter()
|
||||
: null;
|
||||
const toolCalls = new Map<string, ToolCallMeta>();
|
||||
let streamPartId = 0;
|
||||
|
||||
for await (const event of this.#loop.run(request, signal)) {
|
||||
switch (event.type) {
|
||||
case 'text_delta': {
|
||||
if (textParser) {
|
||||
yield textParser.parse({
|
||||
type: 'text-delta',
|
||||
id: String(streamPartId++),
|
||||
text: event.text,
|
||||
});
|
||||
} else {
|
||||
yield event.text;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'reasoning_delta': {
|
||||
if (textParser) {
|
||||
yield textParser.parse({
|
||||
type: 'reasoning-delta',
|
||||
id: String(streamPartId++),
|
||||
text: event.text,
|
||||
});
|
||||
} else {
|
||||
yield event.text;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'tool_call': {
|
||||
const toolCall = {
|
||||
name: event.name,
|
||||
args: event.arguments,
|
||||
};
|
||||
toolCalls.set(event.call_id, toolCall);
|
||||
if (textParser) {
|
||||
yield textParser.parse({
|
||||
type: 'tool-call',
|
||||
toolCallId: event.call_id,
|
||||
toolName: event.name as never,
|
||||
input: event.arguments,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'tool_result': {
|
||||
const normalized = ensureToolResultMeta(event, toolCalls);
|
||||
if (!normalized || !textParser) {
|
||||
break;
|
||||
}
|
||||
yield textParser.parse({
|
||||
type: 'tool-result',
|
||||
toolCallId: normalized.call_id,
|
||||
toolName: normalized.name as never,
|
||||
input: normalized.arguments,
|
||||
output: normalized.output,
|
||||
});
|
||||
break;
|
||||
}
|
||||
case 'citation': {
|
||||
if (citationFormatter) {
|
||||
citationFormatter.consume({
|
||||
type: 'citation',
|
||||
index: event.index,
|
||||
url: event.url,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
const footnotes = textParser?.end() ?? '';
|
||||
const citations = citationFormatter?.end() ?? '';
|
||||
const tails = [citations, footnotes].filter(Boolean).join('\n');
|
||||
if (tails) {
|
||||
yield `\n${tails}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
throw new Error(event.message);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async *streamObject(
|
||||
request: NativeLlmRequest,
|
||||
signal?: AbortSignal
|
||||
): AsyncIterableIterator<StreamObject> {
|
||||
const toolCalls = new Map<string, ToolCallMeta>();
|
||||
const citationFormatter = this.#enableCitationFootnote
|
||||
? new CitationFootnoteFormatter()
|
||||
: null;
|
||||
const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>();
|
||||
let hasFootnoteReference = false;
|
||||
|
||||
for await (const event of this.#loop.run(request, signal)) {
|
||||
switch (event.type) {
|
||||
case 'text_delta': {
|
||||
if (event.text.includes('[^')) {
|
||||
hasFootnoteReference = true;
|
||||
}
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: event.text,
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 'reasoning_delta': {
|
||||
yield {
|
||||
type: 'reasoning',
|
||||
textDelta: event.text,
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 'tool_call': {
|
||||
const toolCall = {
|
||||
name: event.name,
|
||||
args: event.arguments,
|
||||
};
|
||||
toolCalls.set(event.call_id, toolCall);
|
||||
yield {
|
||||
type: 'tool-call',
|
||||
toolCallId: event.call_id,
|
||||
toolName: event.name,
|
||||
args: event.arguments,
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 'tool_result': {
|
||||
const normalized = ensureToolResultMeta(event, toolCalls);
|
||||
if (!normalized) {
|
||||
break;
|
||||
}
|
||||
const attachments = collectAttachmentFootnotes(normalized);
|
||||
attachments.forEach(attachment => {
|
||||
fallbackAttachmentFootnotes.set(attachment.blobId, attachment);
|
||||
});
|
||||
yield {
|
||||
type: 'tool-result',
|
||||
toolCallId: normalized.call_id,
|
||||
toolName: normalized.name,
|
||||
args: normalized.arguments,
|
||||
result: normalized.output,
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 'citation': {
|
||||
if (citationFormatter) {
|
||||
citationFormatter.consume({
|
||||
type: 'citation',
|
||||
index: event.index,
|
||||
url: event.url,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
const citations = citationFormatter?.end() ?? '';
|
||||
if (citations) {
|
||||
hasFootnoteReference = true;
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: `\n${citations}`,
|
||||
};
|
||||
}
|
||||
if (!hasFootnoteReference && fallbackAttachmentFootnotes.size > 0) {
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: formatAttachmentFootnotes(
|
||||
Array.from(fallbackAttachmentFootnotes.values())
|
||||
),
|
||||
};
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
throw new Error(event.message);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,35 +1,53 @@
|
||||
import type { Tool, ToolSet } from 'ai';
|
||||
import {
|
||||
createOpenAI,
|
||||
openai,
|
||||
type OpenAIProvider as VercelOpenAIProvider,
|
||||
OpenAIResponsesProviderOptions,
|
||||
} from '@ai-sdk/openai';
|
||||
import {
|
||||
createOpenAICompatible,
|
||||
type OpenAICompatibleProvider as VercelOpenAICompatibleProvider,
|
||||
} from '@ai-sdk/openai-compatible';
|
||||
import {
|
||||
AISDKError,
|
||||
embedMany,
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
stepCountIs,
|
||||
streamText,
|
||||
Tool,
|
||||
} from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderNotSupported,
|
||||
CopilotProviderSideError,
|
||||
fetchBuffer,
|
||||
metrics,
|
||||
OneMB,
|
||||
readResponseBufferWithLimit,
|
||||
safeFetch,
|
||||
UserFriendlyError,
|
||||
} from '../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import type {
|
||||
CopilotChatOptions,
|
||||
CopilotChatTools,
|
||||
CopilotEmbeddingOptions,
|
||||
CopilotImageOptions,
|
||||
CopilotProviderModel,
|
||||
CopilotStructuredOptions,
|
||||
ModelConditions,
|
||||
PromptMessage,
|
||||
StreamObject,
|
||||
} from './types';
|
||||
import { CopilotProviderType, ModelInputType, ModelOutputType } from './types';
|
||||
import { chatToGPTMessage } from './utils';
|
||||
import {
|
||||
chatToGPTMessage,
|
||||
CitationParser,
|
||||
StreamObjectParser,
|
||||
TextStreamParser,
|
||||
} from './utils';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
@@ -45,12 +63,7 @@ const ModelListSchema = z.object({
|
||||
|
||||
const ImageResponseSchema = z.union([
|
||||
z.object({
|
||||
data: z.array(
|
||||
z.object({
|
||||
b64_json: z.string().optional(),
|
||||
url: z.string().optional(),
|
||||
})
|
||||
),
|
||||
data: z.array(z.object({ b64_json: z.string() })),
|
||||
}),
|
||||
z.object({
|
||||
error: z.object({
|
||||
@@ -74,38 +87,6 @@ const LogProbsSchema = z.array(
|
||||
})
|
||||
);
|
||||
|
||||
const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro'];
|
||||
|
||||
function normalizeImageFormatToMime(format?: string) {
|
||||
switch (format?.toLowerCase()) {
|
||||
case 'jpg':
|
||||
case 'jpeg':
|
||||
return 'image/jpeg';
|
||||
case 'webp':
|
||||
return 'image/webp';
|
||||
case 'png':
|
||||
return 'image/png';
|
||||
case 'gif':
|
||||
return 'image/gif';
|
||||
default:
|
||||
return 'image/png';
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeImageResponseData(
|
||||
data: { b64_json?: string; url?: string }[],
|
||||
mimeType: string = 'image/png'
|
||||
) {
|
||||
return data
|
||||
.map(image => {
|
||||
if (image.b64_json) {
|
||||
return `data:${mimeType};base64,${image.b64_json}`;
|
||||
}
|
||||
return image.url;
|
||||
})
|
||||
.filter((value): value is string => typeof value === 'string');
|
||||
}
|
||||
|
||||
export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
readonly type = CopilotProviderType.OpenAI;
|
||||
|
||||
@@ -338,23 +319,53 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
},
|
||||
];
|
||||
|
||||
#instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
this.#instance =
|
||||
this.config.oldApiStyle && this.config.baseURL
|
||||
? createOpenAICompatible({
|
||||
name: 'openai-compatible-old-style',
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
})
|
||||
: createOpenAI({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.baseURL,
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
private handleError(
|
||||
e: any,
|
||||
model: string,
|
||||
options: CopilotImageOptions = {}
|
||||
) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
} else if (e instanceof AISDKError) {
|
||||
if (e.message.includes('safety') || e.message.includes('risk')) {
|
||||
metrics.ai
|
||||
.counter('chat_text_risk_errors')
|
||||
.add(1, { model, user: options.user || undefined });
|
||||
}
|
||||
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: e.name || 'unknown',
|
||||
message: e.message,
|
||||
});
|
||||
} else {
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected openai response',
|
||||
});
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected openai response',
|
||||
});
|
||||
}
|
||||
|
||||
override async refreshOnlineModels() {
|
||||
@@ -378,50 +389,20 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
override getProviderSpecificTools(
|
||||
toolName: CopilotChatTools,
|
||||
_model: string
|
||||
model: string
|
||||
): [string, Tool?] | undefined {
|
||||
if (toolName === 'docEdit') {
|
||||
if (
|
||||
toolName === 'webSearch' &&
|
||||
'responses' in this.#instance &&
|
||||
!this.isReasoningModel(model)
|
||||
) {
|
||||
return ['web_search_preview', openai.tools.webSearch({})];
|
||||
} else if (toolName === 'docEdit') {
|
||||
return ['doc_edit', undefined];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
private createNativeConfig(): NativeLlmBackendConfig {
|
||||
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
|
||||
return {
|
||||
base_url: baseUrl.replace(/\/v1\/?$/, ''),
|
||||
auth_token: this.config.apiKey,
|
||||
};
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream(
|
||||
this.config.oldApiStyle ? 'openai_chat' : 'openai_responses',
|
||||
this.createNativeConfig(),
|
||||
request,
|
||||
signal
|
||||
),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
}
|
||||
|
||||
private getReasoning(
|
||||
options: NonNullable<CopilotChatOptions>,
|
||||
model: string
|
||||
): Record<string, unknown> | undefined {
|
||||
if (options.reasoning && this.isReasoningModel(model)) {
|
||||
return { effort: 'medium' };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async text(
|
||||
cond: ModelConditions,
|
||||
messages: PromptMessage[],
|
||||
@@ -432,25 +413,33 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const { text } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
providerOptions: {
|
||||
openai: this.getOpenAIOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
|
||||
return text.trim();
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -467,29 +456,38 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
yield chunk;
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const citationParser = new CitationParser();
|
||||
const textParser = new TextStreamParser();
|
||||
for await (const chunk of fullStream) {
|
||||
switch (chunk.type) {
|
||||
case 'text-delta': {
|
||||
let result = textParser.parse(chunk);
|
||||
result = citationParser.parse(result);
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
case 'finish': {
|
||||
const footnotes = textParser.end();
|
||||
const result =
|
||||
citationParser.end() + (footnotes.length ? '\n' + footnotes : '');
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
yield textParser.parse(chunk);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -505,27 +503,24 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
include: options.webSearch ? ['citations'] : undefined,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamObject(request, options.signal)) {
|
||||
yield chunk;
|
||||
.add(1, { model: model.id });
|
||||
const fullStream = await this.getFullStream(model, messages, options);
|
||||
const parser = new StreamObjectParser();
|
||||
for await (const chunk of fullStream) {
|
||||
const result = parser.parse(chunk);
|
||||
if (result) {
|
||||
yield result;
|
||||
}
|
||||
if (options.signal?.aborted) {
|
||||
await fullStream.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_object_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
.add(1, { model: model.id });
|
||||
throw this.handleError(e, model.id, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -540,27 +535,35 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request, schema } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
reasoning: this.getReasoning(options, model.id),
|
||||
middleware,
|
||||
});
|
||||
|
||||
const [system, msgs, schema] = await chatToGPTMessage(messages);
|
||||
if (!schema) {
|
||||
throw new CopilotPromptInvalid('Schema is required');
|
||||
}
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
const text = await adapter.text(request, options.signal);
|
||||
const parsed = JSON.parse(text);
|
||||
const validated = schema.parse(parsed);
|
||||
return JSON.stringify(validated);
|
||||
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const { object } = await generateObject({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
maxRetries: options.maxRetries ?? 3,
|
||||
schema,
|
||||
providerOptions: {
|
||||
openai: options.user ? { user: options.user } : {},
|
||||
},
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
return JSON.stringify(object);
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
throw this.handleError(e, model.id, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -572,32 +575,36 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
const fullCond = { ...cond, outputType: ModelOutputType.Text };
|
||||
await this.checkParams({ messages: [], cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
// get the log probability of "yes"/"no"
|
||||
const instance =
|
||||
'chat' in this.#instance
|
||||
? this.#instance.chat(model.id)
|
||||
: this.#instance(model.id);
|
||||
|
||||
const scores = await Promise.all(
|
||||
chunkMessages.map(async messages => {
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
const response = await this.requestOpenAIJson(
|
||||
'/chat/completions',
|
||||
{
|
||||
model: model.id,
|
||||
messages: this.toOpenAIChatMessages(system, msgs),
|
||||
temperature: 0,
|
||||
max_tokens: 16,
|
||||
logprobs: true,
|
||||
top_logprobs: 16,
|
||||
|
||||
const result = await generateText({
|
||||
model: instance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: 0,
|
||||
maxOutputTokens: 16,
|
||||
providerOptions: {
|
||||
openai: {
|
||||
...this.getOpenAIOptions(options, model.id),
|
||||
logprobs: 16,
|
||||
},
|
||||
},
|
||||
options.signal
|
||||
);
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
const logprobs = response?.choices?.[0]?.logprobs?.content;
|
||||
if (!Array.isArray(logprobs) || logprobs.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const parsedLogprobs = LogProbsSchema.parse(logprobs);
|
||||
const topMap = parsedLogprobs[0].top_logprobs.reduce(
|
||||
const topMap: Record<string, number> = LogProbsSchema.parse(
|
||||
result.providerMetadata?.openai?.logprobs
|
||||
)[0].top_logprobs.reduce<Record<string, number>>(
|
||||
(acc, { token, logprob }) => ({ ...acc, [token]: logprob }),
|
||||
{} as Record<string, number>
|
||||
{}
|
||||
);
|
||||
|
||||
const findLogProb = (token: string): number => {
|
||||
@@ -627,212 +634,50 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
return scores;
|
||||
}
|
||||
|
||||
private async getFullStream(
|
||||
model: CopilotProviderModel,
|
||||
messages: PromptMessage[],
|
||||
options: CopilotChatOptions = {}
|
||||
) {
|
||||
const [system, msgs] = await chatToGPTMessage(messages);
|
||||
const modelInstance =
|
||||
'responses' in this.#instance
|
||||
? this.#instance.responses(model.id)
|
||||
: this.#instance(model.id);
|
||||
const { fullStream } = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
frequencyPenalty: options.frequencyPenalty ?? 0,
|
||||
presencePenalty: options.presencePenalty ?? 0,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
providerOptions: {
|
||||
openai: this.getOpenAIOptions(options, model.id),
|
||||
},
|
||||
tools: await this.getTools(options, model.id),
|
||||
stopWhen: stepCountIs(this.MAX_STEPS),
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
return fullStream;
|
||||
}
|
||||
|
||||
// ====== text to image ======
|
||||
private buildImageFetchOptions(url: URL) {
|
||||
const baseOptions = { timeoutMs: 15_000, maxRedirects: 3 } as const;
|
||||
const trustedOrigins = new Set<string>();
|
||||
const protocol = this.AFFiNEConfig.server.https ? 'https:' : 'http:';
|
||||
const port = this.AFFiNEConfig.server.port;
|
||||
const isDefaultPort =
|
||||
(protocol === 'https:' && port === 443) ||
|
||||
(protocol === 'http:' && port === 80);
|
||||
|
||||
const addHostOrigin = (host: string) => {
|
||||
if (!host) return;
|
||||
try {
|
||||
const parsed = new URL(`${protocol}//${host}`);
|
||||
if (!parsed.port && !isDefaultPort) {
|
||||
parsed.port = String(port);
|
||||
}
|
||||
trustedOrigins.add(parsed.origin);
|
||||
} catch {
|
||||
// ignore invalid host config entries
|
||||
}
|
||||
};
|
||||
|
||||
if (this.AFFiNEConfig.server.externalUrl) {
|
||||
try {
|
||||
trustedOrigins.add(
|
||||
new URL(this.AFFiNEConfig.server.externalUrl).origin
|
||||
);
|
||||
} catch {
|
||||
// ignore invalid external URL
|
||||
}
|
||||
}
|
||||
|
||||
addHostOrigin(this.AFFiNEConfig.server.host);
|
||||
for (const host of this.AFFiNEConfig.server.hosts) {
|
||||
addHostOrigin(host);
|
||||
}
|
||||
|
||||
const hostname = url.hostname.toLowerCase();
|
||||
const trustedByHost = TRUSTED_ATTACHMENT_HOST_SUFFIXES.some(
|
||||
suffix => hostname === suffix || hostname.endsWith(`.${suffix}`)
|
||||
);
|
||||
if (trustedOrigins.has(url.origin) || trustedByHost) {
|
||||
return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) };
|
||||
}
|
||||
|
||||
return baseOptions;
|
||||
}
|
||||
|
||||
private redactUrl(raw: string | URL): string {
|
||||
try {
|
||||
const parsed = raw instanceof URL ? raw : new URL(raw);
|
||||
if (parsed.protocol === 'data:') return 'data:[redacted]';
|
||||
const segments = parsed.pathname.split('/').filter(Boolean);
|
||||
const redactedPath =
|
||||
segments.length <= 2
|
||||
? parsed.pathname || '/'
|
||||
: `/${segments[0]}/${segments[1]}/...`;
|
||||
return `${parsed.origin}${redactedPath}`;
|
||||
} catch {
|
||||
return '[invalid-url]';
|
||||
}
|
||||
}
|
||||
|
||||
private async fetchImage(
|
||||
url: string,
|
||||
maxBytes: number,
|
||||
signal?: AbortSignal
|
||||
): Promise<{ buffer: Buffer; type: string } | null> {
|
||||
if (url.startsWith('data:')) {
|
||||
let response: Response;
|
||||
try {
|
||||
response = await fetch(url, { signal });
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment data URL due to read failure: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment data URL due to invalid response: ${response.status}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const type =
|
||||
response.headers.get('content-type') || 'application/octet-stream';
|
||||
if (!type.startsWith('image/')) {
|
||||
await response.body?.cancel().catch(() => undefined);
|
||||
this.logger.warn(
|
||||
`Skip non-image attachment data URL with content-type ${type}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const buffer = await readResponseBufferWithLimit(response, maxBytes);
|
||||
return { buffer, type };
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment data URL due to read failure/size limit: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
let parsed: URL;
|
||||
try {
|
||||
parsed = new URL(url);
|
||||
} catch {
|
||||
this.logger.warn(
|
||||
`Skip image attachment with invalid URL: ${this.redactUrl(url)}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const redactedUrl = this.redactUrl(parsed);
|
||||
|
||||
if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
|
||||
this.logger.warn(
|
||||
`Skip image attachment with unsupported protocol: ${redactedUrl}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
let response: Response;
|
||||
try {
|
||||
response = await safeFetch(
|
||||
parsed,
|
||||
{ method: 'GET', signal },
|
||||
this.buildImageFetchOptions(parsed)
|
||||
);
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment due to blocked/unreachable URL: ${redactedUrl}, reason: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment fetch failure ${response.status}: ${redactedUrl}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const type =
|
||||
response.headers.get('content-type') || 'application/octet-stream';
|
||||
if (!type.startsWith('image/')) {
|
||||
await response.body?.cancel().catch(() => undefined);
|
||||
this.logger.warn(
|
||||
`Skip non-image attachment with content-type ${type}: ${redactedUrl}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const contentLength = Number(response.headers.get('content-length'));
|
||||
if (Number.isFinite(contentLength) && contentLength > maxBytes) {
|
||||
await response.body?.cancel().catch(() => undefined);
|
||||
this.logger.warn(
|
||||
`Skip oversized image attachment by content-length (${contentLength}): ${redactedUrl}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const buffer = await readResponseBufferWithLimit(response, maxBytes);
|
||||
return { buffer, type };
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Skip image attachment due to read failure/size limit: ${redactedUrl}, reason: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private async *generateImageWithAttachments(
|
||||
model: string,
|
||||
prompt: string,
|
||||
attachments: NonNullable<PromptMessage['attachments']>,
|
||||
signal?: AbortSignal
|
||||
attachments: NonNullable<PromptMessage['attachments']>
|
||||
): AsyncGenerator<string> {
|
||||
const form = new FormData();
|
||||
const outputFormat = 'webp';
|
||||
const maxBytes = 10 * OneMB;
|
||||
form.set('model', model);
|
||||
form.set('prompt', prompt);
|
||||
form.set('output_format', outputFormat);
|
||||
form.set('output_format', 'webp');
|
||||
|
||||
for (const [idx, entry] of attachments.entries()) {
|
||||
const url = typeof entry === 'string' ? entry : entry.attachment;
|
||||
try {
|
||||
const attachment = await this.fetchImage(url, maxBytes, signal);
|
||||
if (!attachment) continue;
|
||||
const { buffer, type } = attachment;
|
||||
const extension = type.split(';')[0].split('/')[1] || 'png';
|
||||
const file = new File([buffer], `${idx}.${extension}`, { type });
|
||||
const { buffer, type } = await fetchBuffer(url, 10 * OneMB, 'image/');
|
||||
const file = new File([buffer], `${idx}.png`, { type });
|
||||
form.append('image[]', file);
|
||||
} catch {
|
||||
continue;
|
||||
@@ -858,24 +703,18 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
const json = await res.json();
|
||||
const imageResponse = ImageResponseSchema.safeParse(json);
|
||||
if (!imageResponse.success) {
|
||||
if (imageResponse.success) {
|
||||
const data = imageResponse.data;
|
||||
if ('error' in data) {
|
||||
throw new Error(data.error.message);
|
||||
} else {
|
||||
for (const image of data.data) {
|
||||
yield `data:image/webp;base64,${image.b64_json}`;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw new Error(imageResponse.error.message);
|
||||
}
|
||||
const data = imageResponse.data;
|
||||
if ('error' in data) {
|
||||
throw new Error(data.error.message);
|
||||
}
|
||||
|
||||
const images = normalizeImageResponseData(
|
||||
data.data,
|
||||
normalizeImageFormatToMime(outputFormat)
|
||||
);
|
||||
if (!images.length) {
|
||||
throw new Error('No images returned from OpenAI');
|
||||
}
|
||||
for (const image of images) {
|
||||
yield image;
|
||||
}
|
||||
}
|
||||
|
||||
override async *streamImages(
|
||||
@@ -887,6 +726,13 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
if (!('image' in this.#instance)) {
|
||||
throw new CopilotProviderNotSupported({
|
||||
provider: this.type,
|
||||
kind: 'image',
|
||||
});
|
||||
}
|
||||
|
||||
metrics.ai
|
||||
.counter('generate_images_stream_calls')
|
||||
.add(1, { model: model.id });
|
||||
@@ -896,27 +742,22 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
|
||||
try {
|
||||
if (attachments && attachments.length > 0) {
|
||||
yield* this.generateImageWithAttachments(
|
||||
model.id,
|
||||
prompt,
|
||||
attachments,
|
||||
options.signal
|
||||
);
|
||||
yield* this.generateImageWithAttachments(model.id, prompt, attachments);
|
||||
} else {
|
||||
const response = await this.requestOpenAIJson('/images/generations', {
|
||||
model: model.id,
|
||||
const modelInstance = this.#instance.image(model.id);
|
||||
const result = await generateImage({
|
||||
model: modelInstance,
|
||||
prompt,
|
||||
...(options.quality ? { quality: options.quality } : {}),
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: options.quality || null,
|
||||
},
|
||||
},
|
||||
});
|
||||
const imageResponse = ImageResponseSchema.parse(response);
|
||||
if ('error' in imageResponse) {
|
||||
throw new Error(imageResponse.error.message);
|
||||
}
|
||||
|
||||
const imageUrls = normalizeImageResponseData(imageResponse.data);
|
||||
if (!imageUrls.length) {
|
||||
throw new Error('No images returned from OpenAI');
|
||||
}
|
||||
const imageUrls = result.images.map(
|
||||
image => `data:image/png;base64,${image.base64}`
|
||||
);
|
||||
|
||||
for (const imageUrl of imageUrls) {
|
||||
yield imageUrl;
|
||||
@@ -928,7 +769,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
return;
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('generate_images_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
throw this.handleError(e, model.id, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -942,85 +783,51 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
await this.checkParams({ embeddings: messages, cond: fullCond, options });
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
if (!('embedding' in this.#instance)) {
|
||||
throw new CopilotProviderNotSupported({
|
||||
provider: this.type,
|
||||
kind: 'embedding',
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_calls')
|
||||
.add(1, { model: model.id });
|
||||
const response = await this.requestOpenAIJson('/embeddings', {
|
||||
model: model.id,
|
||||
input: messages,
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
|
||||
const modelInstance = this.#instance.embedding(model.id);
|
||||
|
||||
const { embeddings } = await embedMany({
|
||||
model: modelInstance,
|
||||
values: messages,
|
||||
providerOptions: {
|
||||
openai: {
|
||||
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
|
||||
},
|
||||
},
|
||||
});
|
||||
const data = Array.isArray(response?.data) ? response.data : [];
|
||||
return data
|
||||
.map((item: any) => item?.embedding)
|
||||
.filter((embedding: unknown) => Array.isArray(embedding)) as number[][];
|
||||
|
||||
return embeddings.filter(v => v && Array.isArray(v));
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('generate_embedding_errors')
|
||||
.add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
throw this.handleError(e, model.id, options);
|
||||
}
|
||||
}
|
||||
|
||||
private toOpenAIChatMessages(
|
||||
system: string | undefined,
|
||||
messages: Awaited<ReturnType<typeof chatToGPTMessage>>[1]
|
||||
) {
|
||||
const result: Array<{ role: string; content: string }> = [];
|
||||
if (system) {
|
||||
result.push({ role: 'system', content: system });
|
||||
private getOpenAIOptions(options: CopilotChatOptions, model: string) {
|
||||
const result: OpenAIResponsesProviderOptions = {};
|
||||
if (options?.reasoning && this.isReasoningModel(model)) {
|
||||
result.reasoningEffort = 'medium';
|
||||
result.reasoningSummary = 'detailed';
|
||||
}
|
||||
|
||||
for (const message of messages) {
|
||||
if (typeof message.content === 'string') {
|
||||
result.push({ role: message.role, content: message.content });
|
||||
continue;
|
||||
}
|
||||
|
||||
const text = message.content
|
||||
.filter(
|
||||
part =>
|
||||
part &&
|
||||
typeof part === 'object' &&
|
||||
'type' in part &&
|
||||
part.type === 'text' &&
|
||||
'text' in part
|
||||
)
|
||||
.map(part => String((part as { text: string }).text))
|
||||
.join('\n');
|
||||
|
||||
result.push({ role: message.role, content: text || '[no content]' });
|
||||
if (options?.user) {
|
||||
result.user = options.user;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private async requestOpenAIJson(
|
||||
path: string,
|
||||
body: Record<string, unknown>,
|
||||
signal?: AbortSignal
|
||||
): Promise<any> {
|
||||
const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
|
||||
const response = await fetch(`${baseUrl}${path}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`OpenAI API error ${response.status}: ${await response.text()}`
|
||||
);
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
private isReasoningModel(model: string) {
|
||||
// o series reasoning models
|
||||
return model.startsWith('o') || model.startsWith('gpt-5');
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import type { ToolSet } from 'ai';
|
||||
import {
|
||||
createPerplexity,
|
||||
type PerplexityProvider as VercelPerplexityProvider,
|
||||
} from '@ai-sdk/perplexity';
|
||||
import { generateText, streamText } from 'ai';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { CopilotProviderSideError, metrics } from '../../../base';
|
||||
import {
|
||||
llmDispatchStream,
|
||||
type NativeLlmBackendConfig,
|
||||
type NativeLlmRequest,
|
||||
} from '../../../native';
|
||||
import type { NodeTextMiddleware } from '../config';
|
||||
import { buildNativeRequest, NativeProviderAdapter } from './native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import {
|
||||
CopilotChatOptions,
|
||||
@@ -17,12 +15,34 @@ import {
|
||||
ModelOutputType,
|
||||
PromptMessage,
|
||||
} from './types';
|
||||
import { chatToGPTMessage, CitationParser } from './utils';
|
||||
|
||||
export type PerplexityConfig = {
|
||||
apiKey: string;
|
||||
endpoint?: string;
|
||||
};
|
||||
|
||||
const PerplexityErrorSchema = z.union([
|
||||
z.object({
|
||||
detail: z.array(
|
||||
z.object({
|
||||
loc: z.array(z.string()),
|
||||
msg: z.string(),
|
||||
type: z.string(),
|
||||
})
|
||||
),
|
||||
}),
|
||||
z.object({
|
||||
error: z.object({
|
||||
message: z.string(),
|
||||
type: z.string(),
|
||||
code: z.number(),
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
|
||||
type PerplexityError = z.infer<typeof PerplexityErrorSchema>;
|
||||
|
||||
export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
readonly type = CopilotProviderType.Perplexity;
|
||||
|
||||
@@ -70,38 +90,18 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
},
|
||||
];
|
||||
|
||||
#instance!: VercelPerplexityProvider;
|
||||
|
||||
override configured(): boolean {
|
||||
return !!this.config.apiKey;
|
||||
}
|
||||
|
||||
protected override setup() {
|
||||
super.setup();
|
||||
}
|
||||
|
||||
private createNativeConfig(): NativeLlmBackendConfig {
|
||||
const baseUrl = this.config.endpoint || 'https://api.perplexity.ai';
|
||||
return {
|
||||
base_url: baseUrl.replace(/\/v1\/?$/, ''),
|
||||
auth_token: this.config.apiKey,
|
||||
};
|
||||
}
|
||||
|
||||
private createNativeAdapter(
|
||||
tools: ToolSet,
|
||||
nodeTextMiddleware?: NodeTextMiddleware[]
|
||||
) {
|
||||
return new NativeProviderAdapter(
|
||||
(request: NativeLlmRequest, signal?: AbortSignal) =>
|
||||
llmDispatchStream(
|
||||
'openai_chat',
|
||||
this.createNativeConfig(),
|
||||
request,
|
||||
signal
|
||||
),
|
||||
tools,
|
||||
this.MAX_STEPS,
|
||||
{ nodeTextMiddleware }
|
||||
);
|
||||
this.#instance = createPerplexity({
|
||||
apiKey: this.config.apiKey,
|
||||
baseURL: this.config.endpoint,
|
||||
});
|
||||
}
|
||||
|
||||
async text(
|
||||
@@ -114,25 +114,32 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id));
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model: model.id });
|
||||
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
withAttachment: false,
|
||||
include: ['citations'],
|
||||
middleware,
|
||||
const [system, msgs] = await chatToGPTMessage(messages, false);
|
||||
|
||||
const modelInstance = this.#instance(model.id);
|
||||
|
||||
const { text, sources } = await generateText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
return await adapter.text(request, options.signal);
|
||||
|
||||
const parser = new CitationParser();
|
||||
for (const source of sources.filter(s => s.sourceType === 'url')) {
|
||||
parser.push(source.url);
|
||||
}
|
||||
|
||||
let result = text.replaceAll(/<\/?think>\n/g, '\n---\n');
|
||||
result = parser.parse(result);
|
||||
result += parser.end();
|
||||
return result;
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model: model.id });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
@@ -147,33 +154,79 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
const model = this.selectModel(fullCond);
|
||||
|
||||
try {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_calls')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id });
|
||||
|
||||
const tools = await this.getTools(options, model.id);
|
||||
const middleware = this.getActiveProviderMiddleware();
|
||||
const { request } = await buildNativeRequest({
|
||||
model: model.id,
|
||||
messages,
|
||||
options,
|
||||
tools,
|
||||
withAttachment: false,
|
||||
include: ['citations'],
|
||||
middleware,
|
||||
const [system, msgs] = await chatToGPTMessage(messages, false);
|
||||
|
||||
const modelInstance = this.#instance(model.id);
|
||||
|
||||
const stream = streamText({
|
||||
model: modelInstance,
|
||||
system,
|
||||
messages: msgs,
|
||||
temperature: options.temperature ?? 0,
|
||||
maxOutputTokens: options.maxTokens ?? 4096,
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
const adapter = this.createNativeAdapter(tools, middleware.node?.text);
|
||||
for await (const chunk of adapter.streamText(request, options.signal)) {
|
||||
yield chunk;
|
||||
|
||||
const parser = new CitationParser();
|
||||
for await (const chunk of stream.fullStream) {
|
||||
switch (chunk.type) {
|
||||
case 'source': {
|
||||
if (chunk.sourceType === 'url') {
|
||||
parser.push(chunk.url);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'text-delta': {
|
||||
const text = chunk.text.replaceAll(/<\/?think>\n?/g, '\n---\n');
|
||||
const result = parser.parse(text);
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
case 'finish-step': {
|
||||
const result = parser.end();
|
||||
yield result;
|
||||
break;
|
||||
}
|
||||
case 'error': {
|
||||
const json =
|
||||
typeof chunk.error === 'string'
|
||||
? JSON.parse(chunk.error)
|
||||
: chunk.error;
|
||||
if (json && typeof json === 'object') {
|
||||
const data = PerplexityErrorSchema.parse(json);
|
||||
if ('detail' in data || 'error' in data) {
|
||||
throw this.convertError(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai
|
||||
.counter('chat_text_stream_errors')
|
||||
.add(1, this.metricLabels(model.id));
|
||||
throw this.handleError(e);
|
||||
} catch (e) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id });
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private convertError(e: PerplexityError) {
|
||||
function getErrMessage(e: PerplexityError) {
|
||||
let err = 'Unexpected perplexity response';
|
||||
if ('detail' in e) {
|
||||
err = e.detail[0].msg || err;
|
||||
} else if ('error' in e) {
|
||||
err = e.error.message || err;
|
||||
}
|
||||
return err;
|
||||
}
|
||||
|
||||
throw new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: getErrMessage(e),
|
||||
});
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof CopilotProviderSideError) {
|
||||
return e;
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
import type { ProviderMiddlewareConfig } from '../config';
|
||||
import { CopilotProviderType } from './types';
|
||||
|
||||
const DEFAULT_MIDDLEWARE_BY_TYPE: Record<
|
||||
CopilotProviderType,
|
||||
ProviderMiddlewareConfig
|
||||
> = {
|
||||
[CopilotProviderType.OpenAI]: {
|
||||
rust: {
|
||||
request: ['normalize_messages'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Anthropic]: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'tool_schema_rewrite'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.AnthropicVertex]: {
|
||||
rust: {
|
||||
request: ['normalize_messages', 'tool_schema_rewrite'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Morph]: {
|
||||
rust: {
|
||||
request: ['clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Perplexity]: {
|
||||
rust: {
|
||||
request: ['clamp_max_tokens'],
|
||||
stream: ['stream_event_normalize', 'citation_indexing'],
|
||||
},
|
||||
node: {
|
||||
text: ['citation_footnote', 'callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.Gemini]: {
|
||||
node: {
|
||||
text: ['callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.GeminiVertex]: {
|
||||
node: {
|
||||
text: ['callout'],
|
||||
},
|
||||
},
|
||||
[CopilotProviderType.FAL]: {},
|
||||
};
|
||||
|
||||
function unique<T>(items: T[]) {
|
||||
return [...new Set(items)];
|
||||
}
|
||||
|
||||
function mergeArray<T>(base: T[] | undefined, override: T[] | undefined) {
|
||||
if (!base?.length && !override?.length) {
|
||||
return undefined;
|
||||
}
|
||||
return unique([...(base ?? []), ...(override ?? [])]);
|
||||
}
|
||||
|
||||
export function mergeProviderMiddleware(
|
||||
defaults: ProviderMiddlewareConfig,
|
||||
override?: ProviderMiddlewareConfig
|
||||
): ProviderMiddlewareConfig {
|
||||
return {
|
||||
rust: {
|
||||
request: mergeArray(defaults.rust?.request, override?.rust?.request),
|
||||
stream: mergeArray(defaults.rust?.stream, override?.rust?.stream),
|
||||
},
|
||||
node: {
|
||||
text: mergeArray(defaults.node?.text, override?.node?.text),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function resolveProviderMiddleware(
|
||||
type: CopilotProviderType,
|
||||
override?: ProviderMiddlewareConfig
|
||||
): ProviderMiddlewareConfig {
|
||||
const defaults = DEFAULT_MIDDLEWARE_BY_TYPE[type] ?? {};
|
||||
return mergeProviderMiddleware(defaults, override);
|
||||
}
|
||||
@@ -1,273 +0,0 @@
|
||||
import type {
|
||||
CopilotProviderConfigMap,
|
||||
CopilotProviderDefaults,
|
||||
CopilotProviderProfile,
|
||||
ProviderMiddlewareConfig,
|
||||
} from '../config';
|
||||
import { resolveProviderMiddleware } from './provider-middleware';
|
||||
import { CopilotProviderType, type ModelOutputType } from './types';
|
||||
|
||||
const PROVIDER_ID_PATTERN = /^[a-zA-Z0-9-_]+$/;
|
||||
|
||||
const LEGACY_PROVIDER_ORDER: CopilotProviderType[] = [
|
||||
CopilotProviderType.OpenAI,
|
||||
CopilotProviderType.FAL,
|
||||
CopilotProviderType.Gemini,
|
||||
CopilotProviderType.GeminiVertex,
|
||||
CopilotProviderType.Perplexity,
|
||||
CopilotProviderType.Anthropic,
|
||||
CopilotProviderType.AnthropicVertex,
|
||||
CopilotProviderType.Morph,
|
||||
];
|
||||
|
||||
const LEGACY_PROVIDER_PRIORITY = LEGACY_PROVIDER_ORDER.reduce(
|
||||
(acc, type, index) => {
|
||||
acc[type] = LEGACY_PROVIDER_ORDER.length - index;
|
||||
return acc;
|
||||
},
|
||||
{} as Record<CopilotProviderType, number>
|
||||
);
|
||||
|
||||
type LegacyProvidersConfig = Partial<
|
||||
Record<CopilotProviderType, CopilotProviderConfigMap[CopilotProviderType]>
|
||||
>;
|
||||
|
||||
export type CopilotProvidersConfigInput = LegacyProvidersConfig & {
|
||||
profiles?: CopilotProviderProfile[] | null;
|
||||
defaults?: CopilotProviderDefaults | null;
|
||||
};
|
||||
|
||||
export type NormalizedCopilotProviderProfile = Omit<
|
||||
CopilotProviderProfile,
|
||||
'enabled' | 'priority' | 'middleware'
|
||||
> & {
|
||||
enabled: boolean;
|
||||
priority: number;
|
||||
middleware: ProviderMiddlewareConfig;
|
||||
};
|
||||
|
||||
export type CopilotProviderRegistry = {
|
||||
profiles: Map<string, NormalizedCopilotProviderProfile>;
|
||||
defaults: CopilotProviderDefaults;
|
||||
order: string[];
|
||||
byType: Map<CopilotProviderType, string[]>;
|
||||
};
|
||||
|
||||
export type ResolveModelResult = {
|
||||
rawModelId?: string;
|
||||
modelId?: string;
|
||||
explicitProviderId?: string;
|
||||
candidateProviderIds: string[];
|
||||
};
|
||||
|
||||
type ResolveModelOptions = {
|
||||
registry: CopilotProviderRegistry;
|
||||
modelId?: string;
|
||||
outputType?: ModelOutputType;
|
||||
availableProviderIds?: Iterable<string>;
|
||||
preferredProviderIds?: Iterable<string>;
|
||||
};
|
||||
|
||||
function unique<T>(list: T[]): T[] {
|
||||
return [...new Set(list)];
|
||||
}
|
||||
|
||||
function asArray<T>(iter?: Iterable<T>): T[] {
|
||||
return iter ? Array.from(iter) : [];
|
||||
}
|
||||
|
||||
function parseModelPrefix(
|
||||
registry: CopilotProviderRegistry,
|
||||
modelId: string
|
||||
): { providerId: string; modelId?: string } | null {
|
||||
const index = modelId.indexOf('/');
|
||||
if (index <= 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const providerId = modelId.slice(0, index);
|
||||
if (!registry.profiles.has(providerId)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const model = modelId.slice(index + 1);
|
||||
return { providerId, modelId: model || undefined };
|
||||
}
|
||||
|
||||
function normalizeProfile(
|
||||
profile: CopilotProviderProfile
|
||||
): NormalizedCopilotProviderProfile {
|
||||
return {
|
||||
...profile,
|
||||
enabled: profile.enabled !== false,
|
||||
priority: profile.priority ?? 0,
|
||||
middleware: resolveProviderMiddleware(profile.type, profile.middleware),
|
||||
};
|
||||
}
|
||||
|
||||
function toLegacyProfiles(
|
||||
config: CopilotProvidersConfigInput
|
||||
): CopilotProviderProfile[] {
|
||||
const legacyProfiles: CopilotProviderProfile[] = [];
|
||||
for (const type of LEGACY_PROVIDER_ORDER) {
|
||||
const legacyConfig = config[type];
|
||||
if (!legacyConfig) {
|
||||
continue;
|
||||
}
|
||||
legacyProfiles.push({
|
||||
id: `${type}-default`,
|
||||
type,
|
||||
priority: LEGACY_PROVIDER_PRIORITY[type],
|
||||
config: legacyConfig,
|
||||
} as CopilotProviderProfile);
|
||||
}
|
||||
return legacyProfiles;
|
||||
}
|
||||
|
||||
function mergeProfiles(
|
||||
explicitProfiles: CopilotProviderProfile[],
|
||||
legacyProfiles: CopilotProviderProfile[]
|
||||
): CopilotProviderProfile[] {
|
||||
const profiles = new Map<string, CopilotProviderProfile>();
|
||||
|
||||
for (const profile of explicitProfiles) {
|
||||
if (!PROVIDER_ID_PATTERN.test(profile.id)) {
|
||||
throw new Error(`Invalid copilot provider profile id: ${profile.id}`);
|
||||
}
|
||||
if (profiles.has(profile.id)) {
|
||||
throw new Error(`Duplicated copilot provider profile id: ${profile.id}`);
|
||||
}
|
||||
profiles.set(profile.id, profile);
|
||||
}
|
||||
|
||||
for (const profile of legacyProfiles) {
|
||||
if (!profiles.has(profile.id)) {
|
||||
profiles.set(profile.id, profile);
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(profiles.values());
|
||||
}
|
||||
|
||||
function sortProfiles(profiles: NormalizedCopilotProviderProfile[]) {
|
||||
return profiles.toSorted((a, b) => {
|
||||
if (a.priority !== b.priority) {
|
||||
return b.priority - a.priority;
|
||||
}
|
||||
return a.id.localeCompare(b.id);
|
||||
});
|
||||
}
|
||||
|
||||
function assertDefaults(
|
||||
defaults: CopilotProviderDefaults,
|
||||
profiles: Map<string, NormalizedCopilotProviderProfile>
|
||||
) {
|
||||
for (const providerId of Object.values(defaults)) {
|
||||
if (!providerId) {
|
||||
continue;
|
||||
}
|
||||
if (!profiles.has(providerId)) {
|
||||
throw new Error(
|
||||
`Copilot provider defaults references unknown providerId: ${providerId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function buildProviderRegistry(
|
||||
config: CopilotProvidersConfigInput
|
||||
): CopilotProviderRegistry {
|
||||
const explicitProfiles = config.profiles ?? [];
|
||||
const legacyProfiles = toLegacyProfiles(config);
|
||||
const mergedProfiles = mergeProfiles(explicitProfiles, legacyProfiles)
|
||||
.map(normalizeProfile)
|
||||
.filter(profile => profile.enabled);
|
||||
const sortedProfiles = sortProfiles(mergedProfiles);
|
||||
|
||||
const profiles = new Map(
|
||||
sortedProfiles.map(profile => [profile.id, profile] as const)
|
||||
);
|
||||
const defaults = config.defaults ?? {};
|
||||
assertDefaults(defaults, profiles);
|
||||
|
||||
const order = sortedProfiles.map(profile => profile.id);
|
||||
const byType = new Map<CopilotProviderType, string[]>();
|
||||
for (const profile of sortedProfiles) {
|
||||
const ids = byType.get(profile.type) ?? [];
|
||||
ids.push(profile.id);
|
||||
byType.set(profile.type, ids);
|
||||
}
|
||||
|
||||
return { profiles, defaults, order, byType };
|
||||
}
|
||||
|
||||
export function resolveModel({
|
||||
registry,
|
||||
modelId,
|
||||
outputType,
|
||||
availableProviderIds,
|
||||
preferredProviderIds,
|
||||
}: ResolveModelOptions): ResolveModelResult {
|
||||
const available = new Set(asArray(availableProviderIds));
|
||||
const preferred = new Set(asArray(preferredProviderIds));
|
||||
const hasAvailableFilter = available.size > 0;
|
||||
const hasPreferredFilter = preferred.size > 0;
|
||||
|
||||
const isAllowed = (providerId: string) => {
|
||||
const profile = registry.profiles.get(providerId);
|
||||
if (!profile?.enabled) {
|
||||
return false;
|
||||
}
|
||||
if (hasAvailableFilter && !available.has(providerId)) {
|
||||
return false;
|
||||
}
|
||||
if (hasPreferredFilter && !preferred.has(providerId)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
const prefixed = modelId ? parseModelPrefix(registry, modelId) : null;
|
||||
if (prefixed) {
|
||||
return {
|
||||
rawModelId: modelId,
|
||||
modelId: prefixed.modelId,
|
||||
explicitProviderId: prefixed.providerId,
|
||||
candidateProviderIds: isAllowed(prefixed.providerId)
|
||||
? [prefixed.providerId]
|
||||
: [],
|
||||
};
|
||||
}
|
||||
|
||||
const fallbackOrder = [
|
||||
...(outputType ? [registry.defaults[outputType]] : []),
|
||||
registry.defaults.fallback,
|
||||
...registry.order,
|
||||
].filter((id): id is string => !!id);
|
||||
|
||||
return {
|
||||
rawModelId: modelId,
|
||||
modelId,
|
||||
candidateProviderIds: unique(
|
||||
fallbackOrder.filter(providerId => isAllowed(providerId))
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
export function stripProviderPrefix(
|
||||
registry: CopilotProviderRegistry,
|
||||
providerId: string,
|
||||
modelId?: string
|
||||
) {
|
||||
if (!modelId) {
|
||||
return modelId;
|
||||
}
|
||||
const prefixed = parseModelPrefix(registry, modelId);
|
||||
if (!prefixed) {
|
||||
return modelId;
|
||||
}
|
||||
if (prefixed.providerId !== providerId) {
|
||||
return modelId;
|
||||
}
|
||||
return prefixed.modelId;
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import { ModuleRef } from '@nestjs/core';
|
||||
import { Tool, ToolSet } from 'ai';
|
||||
@@ -15,7 +13,6 @@ import { DocReader, DocWriter } from '../../../core/doc';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { Models } from '../../../models';
|
||||
import { IndexerService } from '../../indexer';
|
||||
import type { ProviderMiddlewareConfig } from '../config';
|
||||
import { CopilotContextService } from '../context/service';
|
||||
import { PromptService } from '../prompt/service';
|
||||
import {
|
||||
@@ -43,8 +40,6 @@ import {
|
||||
createSectionEditTool,
|
||||
} from '../tools';
|
||||
import { CopilotProviderFactory } from './factory';
|
||||
import { resolveProviderMiddleware } from './provider-middleware';
|
||||
import { buildProviderRegistry } from './provider-registry';
|
||||
import {
|
||||
type CopilotChatOptions,
|
||||
CopilotChatTools,
|
||||
@@ -63,14 +58,11 @@ import {
|
||||
StreamObject,
|
||||
} from './types';
|
||||
|
||||
const providerProfileContext = new AsyncLocalStorage<string>();
|
||||
|
||||
@Injectable()
|
||||
export abstract class CopilotProvider<C = any> {
|
||||
protected readonly logger = new Logger(this.constructor.name);
|
||||
protected readonly MAX_STEPS = 20;
|
||||
protected onlineModelList: string[] = [];
|
||||
|
||||
abstract readonly type: CopilotProviderType;
|
||||
abstract readonly models: CopilotProviderModel[];
|
||||
abstract configured(): boolean;
|
||||
@@ -78,39 +70,8 @@ export abstract class CopilotProvider<C = any> {
|
||||
@Inject() protected readonly AFFiNEConfig!: Config;
|
||||
@Inject() protected readonly factory!: CopilotProviderFactory;
|
||||
@Inject() protected readonly moduleRef!: ModuleRef;
|
||||
readonly #registeredProviderIds = new Set<string>();
|
||||
|
||||
runWithProfile<T>(providerId: string, callback: () => T): T {
|
||||
return providerProfileContext.run(providerId, callback);
|
||||
}
|
||||
|
||||
protected getActiveProviderId() {
|
||||
return providerProfileContext.getStore() ?? `${this.type}-default`;
|
||||
}
|
||||
|
||||
protected getActiveProviderMiddleware(): ProviderMiddlewareConfig {
|
||||
const providerId = this.getActiveProviderId();
|
||||
const registry = buildProviderRegistry(this.AFFiNEConfig.copilot.providers);
|
||||
const profile = registry.profiles.get(providerId);
|
||||
return profile?.middleware ?? resolveProviderMiddleware(this.type);
|
||||
}
|
||||
|
||||
protected metricLabels(
|
||||
model: string,
|
||||
labels: Record<string, string | number | boolean | undefined> = {}
|
||||
) {
|
||||
const providerId = this.getActiveProviderId();
|
||||
return { model, providerId, ...labels };
|
||||
}
|
||||
|
||||
get config(): C {
|
||||
const profileId = providerProfileContext.getStore();
|
||||
if (profileId) {
|
||||
const profile = this.AFFiNEConfig.copilot.providers.profiles?.find(
|
||||
profile => profile.id === profileId && profile.type === this.type
|
||||
);
|
||||
if (profile) return profile.config as C;
|
||||
}
|
||||
return this.AFFiNEConfig.copilot.providers[this.type] as C;
|
||||
}
|
||||
|
||||
@@ -127,37 +88,15 @@ export abstract class CopilotProvider<C = any> {
|
||||
}
|
||||
|
||||
protected setup() {
|
||||
const registry = buildProviderRegistry(this.AFFiNEConfig.copilot.providers);
|
||||
const providerIds = registry.byType.get(this.type) ?? [];
|
||||
const nextProviderIds = new Set<string>();
|
||||
|
||||
for (const id of providerIds) {
|
||||
const configured = this.runWithProfile(id, () => this.configured());
|
||||
if (configured) {
|
||||
nextProviderIds.add(id);
|
||||
this.factory.register(id, this);
|
||||
} else {
|
||||
this.factory.unregister(id, this);
|
||||
}
|
||||
}
|
||||
|
||||
for (const providerId of this.#registeredProviderIds) {
|
||||
if (!nextProviderIds.has(providerId)) {
|
||||
this.factory.unregister(providerId, this);
|
||||
}
|
||||
}
|
||||
this.#registeredProviderIds.clear();
|
||||
for (const providerId of nextProviderIds) {
|
||||
this.#registeredProviderIds.add(providerId);
|
||||
}
|
||||
|
||||
if (env.selfhosted && nextProviderIds.size > 0) {
|
||||
const [providerId] = Array.from(nextProviderIds);
|
||||
this.runWithProfile(providerId, () => {
|
||||
if (this.configured()) {
|
||||
this.factory.register(this);
|
||||
if (env.selfhosted) {
|
||||
this.refreshOnlineModels().catch(e =>
|
||||
this.logger.error('Failed to refresh online models', e)
|
||||
);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
this.factory.unregister(this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -91,9 +91,7 @@ export async function chatToGPTMessage(
|
||||
// so we need to use base64 encoded attachments instead
|
||||
useBase64Attachment: boolean = false
|
||||
): Promise<[string | undefined, ChatMessage[], ZodType?]> {
|
||||
const hasSystem = messages[0]?.role === 'system';
|
||||
const system = hasSystem ? messages[0] : undefined;
|
||||
const normalizedMessages = hasSystem ? messages.slice(1) : messages;
|
||||
const system = messages[0]?.role === 'system' ? messages.shift() : undefined;
|
||||
const schema =
|
||||
system?.params?.schema && system.params.schema instanceof ZodType
|
||||
? system.params.schema
|
||||
@@ -101,7 +99,7 @@ export async function chatToGPTMessage(
|
||||
|
||||
// filter redundant fields
|
||||
const msgs: ChatMessage[] = [];
|
||||
for (let { role, content, attachments, params } of normalizedMessages.filter(
|
||||
for (let { role, content, attachments, params } of messages.filter(
|
||||
m => m.role !== 'system'
|
||||
)) {
|
||||
content = content.trim();
|
||||
@@ -408,34 +406,6 @@ export class CitationParser {
|
||||
}
|
||||
}
|
||||
|
||||
export type CitationIndexedEvent = {
|
||||
type: 'citation';
|
||||
index: number;
|
||||
url: string;
|
||||
};
|
||||
|
||||
export class CitationFootnoteFormatter {
|
||||
private readonly citations = new Map<number, string>();
|
||||
|
||||
public consume(event: CitationIndexedEvent) {
|
||||
if (event.type !== 'citation') {
|
||||
return '';
|
||||
}
|
||||
this.citations.set(event.index, event.url);
|
||||
return '';
|
||||
}
|
||||
|
||||
public end() {
|
||||
const footnotes = Array.from(this.citations.entries())
|
||||
.sort((a, b) => a[0] - b[0])
|
||||
.map(
|
||||
([index, citation]) =>
|
||||
`[^${index}]: {"type":"url","url":"${encodeURIComponent(citation)}"}`
|
||||
);
|
||||
return footnotes.join('\n');
|
||||
}
|
||||
}
|
||||
|
||||
type ChunkType = TextStreamPart<CustomAITools>['type'];
|
||||
|
||||
export function toError(error: unknown): Error {
|
||||
@@ -733,39 +703,21 @@ export const VertexModelListSchema = z.object({
|
||||
),
|
||||
});
|
||||
|
||||
function normalizeUrl(baseURL?: string) {
|
||||
if (!baseURL?.trim()) {
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
const url = new URL(baseURL);
|
||||
const serialized = url.toString();
|
||||
if (serialized.endsWith('/')) return serialized.slice(0, -1);
|
||||
return serialized;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
export function getVertexAnthropicBaseUrl(
|
||||
options: GoogleVertexAnthropicProviderSettings
|
||||
) {
|
||||
const normalizedBaseUrl = normalizeUrl(options.baseURL);
|
||||
if (normalizedBaseUrl) return normalizedBaseUrl;
|
||||
const { location, project } = options;
|
||||
if (!location || !project) return undefined;
|
||||
return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/anthropic`;
|
||||
}
|
||||
|
||||
export async function getGoogleAuth(
|
||||
options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings,
|
||||
publisher: 'anthropic' | 'google'
|
||||
) {
|
||||
function getBaseUrl() {
|
||||
const normalizedBaseUrl = normalizeUrl(options.baseURL);
|
||||
if (normalizedBaseUrl) return normalizedBaseUrl;
|
||||
const { location } = options;
|
||||
if (location) {
|
||||
const { baseURL, location } = options;
|
||||
if (baseURL?.trim()) {
|
||||
try {
|
||||
const url = new URL(baseURL);
|
||||
if (url.pathname.endsWith('/')) {
|
||||
url.pathname = url.pathname.slice(0, -1);
|
||||
}
|
||||
return url.toString();
|
||||
} catch {}
|
||||
} else if (location) {
|
||||
return `https://${location}-aiplatform.googleapis.com/v1beta1/publishers/${publisher}`;
|
||||
}
|
||||
return undefined;
|
||||
|
||||
@@ -4,6 +4,7 @@ import { BadRequestException, NotFoundException } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
Float,
|
||||
ID,
|
||||
InputType,
|
||||
Mutation,
|
||||
@@ -14,6 +15,7 @@ import {
|
||||
ResolveField,
|
||||
Resolver,
|
||||
} from '@nestjs/graphql';
|
||||
import { AiPromptRole } from '@prisma/client';
|
||||
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
|
||||
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
|
||||
|
||||
@@ -311,6 +313,57 @@ class CopilotQuotaType {
|
||||
used!: number;
|
||||
}
|
||||
|
||||
registerEnumType(AiPromptRole, {
|
||||
name: 'CopilotPromptMessageRole',
|
||||
});
|
||||
|
||||
@InputType('CopilotPromptConfigInput')
|
||||
@ObjectType()
|
||||
class CopilotPromptConfigType {
|
||||
@Field(() => Float, { nullable: true })
|
||||
frequencyPenalty!: number | null;
|
||||
|
||||
@Field(() => Float, { nullable: true })
|
||||
presencePenalty!: number | null;
|
||||
|
||||
@Field(() => Float, { nullable: true })
|
||||
temperature!: number | null;
|
||||
|
||||
@Field(() => Float, { nullable: true })
|
||||
topP!: number | null;
|
||||
}
|
||||
|
||||
@InputType('CopilotPromptMessageInput')
|
||||
@ObjectType()
|
||||
class CopilotPromptMessageType {
|
||||
@Field(() => AiPromptRole)
|
||||
role!: AiPromptRole;
|
||||
|
||||
@Field(() => String)
|
||||
content!: string;
|
||||
|
||||
@Field(() => GraphQLJSON, { nullable: true })
|
||||
params!: Record<string, string> | null;
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class CopilotPromptType {
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => String)
|
||||
model!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
action!: string | null;
|
||||
|
||||
@Field(() => CopilotPromptConfigType, { nullable: true })
|
||||
config!: CopilotPromptConfigType | null;
|
||||
|
||||
@Field(() => [CopilotPromptMessageType])
|
||||
messages!: CopilotPromptMessageType[];
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class CopilotModelType {
|
||||
@Field(() => String)
|
||||
@@ -585,8 +638,13 @@ export class CopilotResolver {
|
||||
);
|
||||
}
|
||||
|
||||
private async createCopilotSessionInternal(
|
||||
user: CurrentUser,
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat session',
|
||||
})
|
||||
@CallMetric('ai', 'chat_session_create')
|
||||
async createCopilotSession(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => CreateChatSessionInput })
|
||||
options: CreateChatSessionInput
|
||||
): Promise<string> {
|
||||
// permission check based on session type
|
||||
@@ -608,42 +666,6 @@ export class CopilotResolver {
|
||||
});
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat session',
|
||||
deprecationReason: 'use `createCopilotSessionWithHistory` instead',
|
||||
})
|
||||
@CallMetric('ai', 'chat_session_create')
|
||||
async createCopilotSession(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => CreateChatSessionInput })
|
||||
options: CreateChatSessionInput
|
||||
): Promise<string> {
|
||||
return await this.createCopilotSessionInternal(user, options);
|
||||
}
|
||||
|
||||
@Mutation(() => CopilotHistoriesType, {
|
||||
description: 'Create a chat session and return full session payload',
|
||||
})
|
||||
@CallMetric('ai', 'chat_session_create_with_history')
|
||||
async createCopilotSessionWithHistory(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => CreateChatSessionInput })
|
||||
options: CreateChatSessionInput
|
||||
): Promise<CopilotHistoriesType> {
|
||||
const sessionId = await this.createCopilotSessionInternal(user, options);
|
||||
const session = await this.chatSession.getSessionInfo(sessionId);
|
||||
if (!session) {
|
||||
throw new NotFoundException('Session not found');
|
||||
}
|
||||
return {
|
||||
...session,
|
||||
messages: session.messages.map(message => ({
|
||||
...message,
|
||||
id: message.id,
|
||||
})) as ChatMessageType[],
|
||||
};
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Update a chat session',
|
||||
})
|
||||
@@ -917,10 +939,31 @@ export class UserCopilotResolver {
|
||||
}
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class CreateCopilotPromptInput {
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => String)
|
||||
model!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
action!: string | null;
|
||||
|
||||
@Field(() => CopilotPromptConfigType, { nullable: true })
|
||||
config!: CopilotPromptConfigType | null;
|
||||
|
||||
@Field(() => [CopilotPromptMessageType])
|
||||
messages!: CopilotPromptMessageType[];
|
||||
}
|
||||
|
||||
@Admin()
|
||||
@Resolver(() => String)
|
||||
export class PromptsManagementResolver {
|
||||
constructor(private readonly cron: CopilotCronJobs) {}
|
||||
constructor(
|
||||
private readonly cron: CopilotCronJobs,
|
||||
private readonly promptService: PromptService
|
||||
) {}
|
||||
|
||||
@Mutation(() => Boolean, {
|
||||
description: 'Trigger generate missing titles cron job',
|
||||
@@ -937,4 +980,48 @@ export class PromptsManagementResolver {
|
||||
await this.cron.triggerCleanupTrashedDocEmbeddings();
|
||||
return true;
|
||||
}
|
||||
|
||||
@Query(() => [CopilotPromptType], {
|
||||
description: 'List all copilot prompts',
|
||||
})
|
||||
async listCopilotPrompts() {
|
||||
const prompts = await this.promptService.list();
|
||||
return prompts.filter(
|
||||
p =>
|
||||
p.messages.length > 0 &&
|
||||
// ignore internal prompts
|
||||
!p.name.startsWith('workflow:') &&
|
||||
!p.name.startsWith('debug:') &&
|
||||
!p.name.startsWith('chat:') &&
|
||||
!p.name.startsWith('action:')
|
||||
);
|
||||
}
|
||||
|
||||
@Mutation(() => CopilotPromptType, {
|
||||
description: 'Create a copilot prompt',
|
||||
})
|
||||
async createCopilotPrompt(
|
||||
@Args({ type: () => CreateCopilotPromptInput, name: 'input' })
|
||||
input: CreateCopilotPromptInput
|
||||
) {
|
||||
await this.promptService.set(
|
||||
input.name,
|
||||
input.model,
|
||||
input.messages,
|
||||
input.config
|
||||
);
|
||||
return this.promptService.get(input.name);
|
||||
}
|
||||
|
||||
@Mutation(() => CopilotPromptType, {
|
||||
description: 'Update a copilot prompt',
|
||||
})
|
||||
async updateCopilotPrompt(
|
||||
@Args('name') name: string,
|
||||
@Args('messages', { type: () => [CopilotPromptMessageType] })
|
||||
messages: CopilotPromptMessageType[]
|
||||
) {
|
||||
await this.promptService.update(name, { messages, modified: true });
|
||||
return this.promptService.get(name);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import { AiPromptRole } from '@prisma/client';
|
||||
import { pick } from 'lodash-es';
|
||||
|
||||
import {
|
||||
Config,
|
||||
CopilotActionTaken,
|
||||
CopilotMessageNotFound,
|
||||
CopilotPromptNotFound,
|
||||
@@ -32,7 +31,6 @@ import { ChatMessageCache } from './message';
|
||||
import { ChatPrompt } from './prompt/chat-prompt';
|
||||
import { PromptService } from './prompt/service';
|
||||
import { CopilotProviderFactory } from './providers/factory';
|
||||
import { buildProviderRegistry } from './providers/provider-registry';
|
||||
import {
|
||||
ModelOutputType,
|
||||
type PromptMessage,
|
||||
@@ -107,31 +105,10 @@ export class ChatSession implements AsyncDisposable {
|
||||
hasPayment: boolean,
|
||||
requestedModelId?: string
|
||||
): Promise<string> {
|
||||
const config = this.moduleRef.get(Config, { strict: false });
|
||||
const registry = config
|
||||
? buildProviderRegistry(config.copilot.providers)
|
||||
: null;
|
||||
const defaultModel = this.model;
|
||||
const normalizeModel = (modelId?: string) => {
|
||||
if (!modelId) return modelId;
|
||||
const separatorIndex = modelId.indexOf('/');
|
||||
if (separatorIndex <= 0) return modelId;
|
||||
const providerId = modelId.slice(0, separatorIndex);
|
||||
if (!registry?.profiles.has(providerId)) return modelId;
|
||||
return modelId.slice(separatorIndex + 1);
|
||||
};
|
||||
const inModelList = (models: string[], modelId?: string) => {
|
||||
if (!modelId) return false;
|
||||
return (
|
||||
models.includes(modelId) ||
|
||||
models.includes(normalizeModel(modelId) ?? '')
|
||||
);
|
||||
};
|
||||
const normalize = (m?: string) => {
|
||||
if (inModelList(this.optionalModels, m)) return m;
|
||||
return defaultModel;
|
||||
};
|
||||
const isPro = (m?: string) => inModelList(this.proModels, m);
|
||||
const normalize = (m?: string) =>
|
||||
!!m && this.optionalModels.includes(m) ? m : defaultModel;
|
||||
const isPro = (m?: string) => !!m && this.proModels.includes(m);
|
||||
|
||||
// try resolve payment subscription service lazily
|
||||
let paymentEnabled = hasPayment;
|
||||
@@ -155,19 +132,10 @@ export class ChatSession implements AsyncDisposable {
|
||||
}
|
||||
|
||||
if (paymentEnabled && !isUserAIPro && isPro(requestedModelId)) {
|
||||
if (!defaultModel) {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
'Model is required for AI subscription fallback'
|
||||
);
|
||||
}
|
||||
return defaultModel;
|
||||
}
|
||||
|
||||
const resolvedModel = normalize(requestedModelId);
|
||||
if (!resolvedModel) {
|
||||
throw new CopilotSessionInvalidInput('Model is required');
|
||||
}
|
||||
return resolvedModel;
|
||||
return normalize(requestedModelId);
|
||||
}
|
||||
|
||||
push(message: ChatMessage) {
|
||||
|
||||
@@ -32,22 +32,16 @@ export const buildBlobContentGetter = (
|
||||
return;
|
||||
}
|
||||
|
||||
const contextFile = context.files.find(
|
||||
file => file.blobId === blobId || file.id === blobId
|
||||
);
|
||||
const canonicalBlobId = contextFile?.blobId ?? blobId;
|
||||
const targetFileId = contextFile?.id;
|
||||
const [file, blob] = await Promise.all([
|
||||
targetFileId ? context.getFileContent(targetFileId, chunk) : undefined,
|
||||
context.getBlobContent(canonicalBlobId, chunk),
|
||||
context?.getFileContent(blobId, chunk),
|
||||
context?.getBlobContent(blobId, chunk),
|
||||
]);
|
||||
const content = file?.trim() || blob?.trim();
|
||||
if (!content) return;
|
||||
const info = contextFile
|
||||
? { fileName: contextFile.name, fileType: contextFile.mimeType }
|
||||
: {};
|
||||
if (!content) {
|
||||
return;
|
||||
}
|
||||
|
||||
return { blobId: canonicalBlobId, chunk, content, ...info };
|
||||
return { blobId, chunk, content };
|
||||
};
|
||||
return getBlobContent;
|
||||
};
|
||||
|
||||
@@ -84,8 +84,8 @@ export abstract class OAuthProvider {
|
||||
options?: { treatServerErrorAsInvalid?: boolean }
|
||||
) {
|
||||
const response = await fetch(url, {
|
||||
headers: { Accept: 'application/json', ...init?.headers },
|
||||
...init,
|
||||
headers: { ...init?.headers, Accept: 'application/json' },
|
||||
});
|
||||
|
||||
const body = await response.text();
|
||||
|
||||
@@ -14,7 +14,6 @@ import type {
|
||||
import { HTMLRewriter } from 'htmlrewriter';
|
||||
|
||||
import {
|
||||
applyAttachHeaders,
|
||||
BadRequest,
|
||||
Cache,
|
||||
readResponseBufferWithLimit,
|
||||
@@ -128,18 +127,15 @@ export class WorkerController {
|
||||
if (buffer.length === 0) {
|
||||
return resp.status(404).header(getCorsHeaders(origin)).send();
|
||||
}
|
||||
resp.header({
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
});
|
||||
applyAttachHeaders(resp, { buffer });
|
||||
const contentType = resp.getHeader('Content-Type') as string | undefined;
|
||||
if (contentType?.startsWith('image/')) {
|
||||
return resp.status(200).send(buffer);
|
||||
} else {
|
||||
throw new BadRequest('Invalid content type');
|
||||
}
|
||||
return resp
|
||||
.status(200)
|
||||
.header({
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
'Content-Type': 'image/*',
|
||||
})
|
||||
.send(buffer);
|
||||
}
|
||||
|
||||
let response: Response;
|
||||
@@ -175,39 +171,39 @@ export class WorkerController {
|
||||
throw new BadRequest('Failed to fetch image');
|
||||
}
|
||||
if (response.ok) {
|
||||
let buffer: Buffer;
|
||||
try {
|
||||
buffer = await readResponseBufferWithLimit(
|
||||
response,
|
||||
IMAGE_PROXY_MAX_BYTES
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof ResponseTooLargeError) {
|
||||
this.logger.warn('Image proxy response too large', {
|
||||
url: imageURL,
|
||||
limitBytes: error.data?.limitBytes,
|
||||
receivedBytes: error.data?.receivedBytes,
|
||||
});
|
||||
throw new BadRequest('Response too large');
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
await this.cache.set(cachedUrl, buffer.toString('base64'), {
|
||||
ttl: CACHE_TTL,
|
||||
});
|
||||
const contentDisposition = response.headers.get('Content-Disposition');
|
||||
resp.header({
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
});
|
||||
if (contentDisposition) {
|
||||
resp.setHeader('Content-Disposition', contentDisposition);
|
||||
}
|
||||
applyAttachHeaders(resp, { buffer });
|
||||
const contentType = resp.getHeader('Content-Type') as string | undefined;
|
||||
const contentType = response.headers.get('Content-Type');
|
||||
if (contentType?.startsWith('image/')) {
|
||||
return resp.status(200).send(buffer);
|
||||
let buffer: Buffer;
|
||||
try {
|
||||
buffer = await readResponseBufferWithLimit(
|
||||
response,
|
||||
IMAGE_PROXY_MAX_BYTES
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof ResponseTooLargeError) {
|
||||
this.logger.warn('Image proxy response too large', {
|
||||
url: imageURL,
|
||||
limitBytes: error.data?.limitBytes,
|
||||
receivedBytes: error.data?.receivedBytes,
|
||||
});
|
||||
throw new BadRequest('Response too large');
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
await this.cache.set(cachedUrl, buffer.toString('base64'), {
|
||||
ttl: CACHE_TTL,
|
||||
});
|
||||
const contentDisposition = response.headers.get('Content-Disposition');
|
||||
return resp
|
||||
.status(200)
|
||||
.header({
|
||||
...getCorsHeaders(origin),
|
||||
...(origin ? { Vary: 'Origin' } : {}),
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
'Content-Type': contentType,
|
||||
'Content-Disposition': contentDisposition,
|
||||
})
|
||||
.send(buffer);
|
||||
} else {
|
||||
throw new BadRequest('Invalid content type');
|
||||
}
|
||||
|
||||
@@ -607,10 +607,50 @@ type CopilotModelsType {
|
||||
proModels: [CopilotModelType!]!
|
||||
}
|
||||
|
||||
input CopilotPromptConfigInput {
|
||||
frequencyPenalty: Float
|
||||
presencePenalty: Float
|
||||
temperature: Float
|
||||
topP: Float
|
||||
}
|
||||
|
||||
type CopilotPromptConfigType {
|
||||
frequencyPenalty: Float
|
||||
presencePenalty: Float
|
||||
temperature: Float
|
||||
topP: Float
|
||||
}
|
||||
|
||||
input CopilotPromptMessageInput {
|
||||
content: String!
|
||||
params: JSON
|
||||
role: CopilotPromptMessageRole!
|
||||
}
|
||||
|
||||
enum CopilotPromptMessageRole {
|
||||
assistant
|
||||
system
|
||||
user
|
||||
}
|
||||
|
||||
type CopilotPromptMessageType {
|
||||
content: String!
|
||||
params: JSON
|
||||
role: CopilotPromptMessageRole!
|
||||
}
|
||||
|
||||
type CopilotPromptNotFoundDataType {
|
||||
name: String!
|
||||
}
|
||||
|
||||
type CopilotPromptType {
|
||||
action: String
|
||||
config: CopilotPromptConfigType
|
||||
messages: [CopilotPromptMessageType!]!
|
||||
model: String!
|
||||
name: String!
|
||||
}
|
||||
|
||||
type CopilotProviderNotSupportedDataType {
|
||||
kind: String!
|
||||
provider: String!
|
||||
@@ -707,6 +747,14 @@ input CreateCheckoutSessionInput {
|
||||
variant: SubscriptionVariant
|
||||
}
|
||||
|
||||
input CreateCopilotPromptInput {
|
||||
action: String
|
||||
config: CopilotPromptConfigInput
|
||||
messages: [CopilotPromptMessageInput!]!
|
||||
model: String!
|
||||
name: String!
|
||||
}
|
||||
|
||||
input CreateUserInput {
|
||||
email: String!
|
||||
name: String
|
||||
@@ -1503,11 +1551,11 @@ type Mutation {
|
||||
"""Create a chat message"""
|
||||
createCopilotMessage(options: CreateChatMessageInput!): String!
|
||||
|
||||
"""Create a chat session"""
|
||||
createCopilotSession(options: CreateChatSessionInput!): String! @deprecated(reason: "use `createCopilotSessionWithHistory` instead")
|
||||
"""Create a copilot prompt"""
|
||||
createCopilotPrompt(input: CreateCopilotPromptInput!): CopilotPromptType!
|
||||
|
||||
"""Create a chat session and return full session payload"""
|
||||
createCopilotSessionWithHistory(options: CreateChatSessionInput!): CopilotHistories!
|
||||
"""Create a chat session"""
|
||||
createCopilotSession(options: CreateChatSessionInput!): String!
|
||||
|
||||
"""Create a stripe customer portal to manage payment methods"""
|
||||
createCustomerPortal: String!
|
||||
@@ -1624,6 +1672,9 @@ type Mutation {
|
||||
"""Update a comment content"""
|
||||
updateComment(input: CommentUpdateInput!): Boolean!
|
||||
|
||||
"""Update a copilot prompt"""
|
||||
updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType!
|
||||
|
||||
"""Update a chat session"""
|
||||
updateCopilotSession(options: UpdateChatSessionInput!): String!
|
||||
updateDocDefaultRole(input: UpdateDocDefaultRoleInput!): Boolean!
|
||||
@@ -1872,6 +1923,9 @@ type Query {
|
||||
|
||||
"""get workspace invitation info"""
|
||||
getInviteInfo(inviteId: String!): InvitationType!
|
||||
|
||||
"""List all copilot prompts"""
|
||||
listCopilotPrompts: [CopilotPromptType!]!
|
||||
prices: [SubscriptionPrice!]!
|
||||
|
||||
"""Get public user by id"""
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
query getPrompts {
|
||||
listCopilotPrompts {
|
||||
name
|
||||
model
|
||||
action
|
||||
config {
|
||||
frequencyPenalty
|
||||
presencePenalty
|
||||
temperature
|
||||
topP
|
||||
}
|
||||
messages {
|
||||
role
|
||||
content
|
||||
params
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
mutation updatePrompt(
|
||||
$name: String!
|
||||
$messages: [CopilotPromptMessageInput!]!
|
||||
) {
|
||||
updateCopilotPrompt(name: $name, messages: $messages) {
|
||||
name
|
||||
model
|
||||
action
|
||||
config {
|
||||
frequencyPenalty
|
||||
presencePenalty
|
||||
temperature
|
||||
topP
|
||||
}
|
||||
messages {
|
||||
role
|
||||
content
|
||||
params
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
#import "./fragments/copilot.gql"
|
||||
|
||||
query getCopilotDocSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
#import "./fragments/copilot.gql"
|
||||
|
||||
query getCopilotPinnedSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
#import "./fragments/copilot.gql"
|
||||
|
||||
query getCopilotWorkspaceSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
#import "./fragments/copilot.gql"
|
||||
|
||||
query getCopilotHistories(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
#import "./fragments/copilot-chat-history.gql"
|
||||
|
||||
mutation createCopilotSessionWithHistory($options: CreateChatSessionInput!) {
|
||||
createCopilotSessionWithHistory(options: $options) {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
#import "./fragments/copilot.gql"
|
||||
|
||||
query getCopilotLatestDocSession(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
#import "./fragments/copilot.gql"
|
||||
|
||||
query getCopilotSession(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
#import "./fragments/copilot.gql"
|
||||
|
||||
query getCopilotRecentSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#import "./fragments/paginated-copilot-chats.gql"
|
||||
#import "./fragments/copilot.gql"
|
||||
|
||||
query getCopilotSessions(
|
||||
$workspaceId: String!
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
fragment CopilotChatHistory on CopilotHistories {
|
||||
sessionId
|
||||
workspaceId
|
||||
docId
|
||||
parentSessionId
|
||||
promptName
|
||||
model
|
||||
optionalModels
|
||||
action
|
||||
pinned
|
||||
title
|
||||
tokens
|
||||
messages {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
streamObjects {
|
||||
type
|
||||
textDelta
|
||||
toolCallId
|
||||
toolName
|
||||
args
|
||||
result
|
||||
}
|
||||
createdAt
|
||||
}
|
||||
createdAt
|
||||
updatedAt
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
fragment CopilotChatMessage on ChatMessage {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
streamObjects {
|
||||
type
|
||||
textDelta
|
||||
toolCallId
|
||||
toolName
|
||||
args
|
||||
result
|
||||
}
|
||||
createdAt
|
||||
}
|
||||
|
||||
fragment CopilotChatHistory on CopilotHistories {
|
||||
sessionId
|
||||
workspaceId
|
||||
docId
|
||||
parentSessionId
|
||||
promptName
|
||||
model
|
||||
optionalModels
|
||||
action
|
||||
pinned
|
||||
title
|
||||
tokens
|
||||
messages {
|
||||
...CopilotChatMessage
|
||||
}
|
||||
createdAt
|
||||
updatedAt
|
||||
}
|
||||
|
||||
fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
hasPreviousPage
|
||||
startCursor
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
cursor
|
||||
node {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
#import "./copilot-chat-history.gql"
|
||||
|
||||
fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
hasPreviousPage
|
||||
startCursor
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
cursor
|
||||
node {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,21 @@ export interface GraphQLQuery {
|
||||
file?: boolean;
|
||||
deprecations?: string[];
|
||||
}
|
||||
export const copilotChatMessageFragment = `fragment CopilotChatMessage on ChatMessage {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
streamObjects {
|
||||
type
|
||||
textDelta
|
||||
toolCallId
|
||||
toolName
|
||||
args
|
||||
result
|
||||
}
|
||||
createdAt
|
||||
}`;
|
||||
export const copilotChatHistoryFragment = `fragment CopilotChatHistory on CopilotHistories {
|
||||
sessionId
|
||||
workspaceId
|
||||
@@ -19,23 +34,25 @@ export const copilotChatHistoryFragment = `fragment CopilotChatHistory on Copilo
|
||||
title
|
||||
tokens
|
||||
messages {
|
||||
id
|
||||
role
|
||||
content
|
||||
attachments
|
||||
streamObjects {
|
||||
type
|
||||
textDelta
|
||||
toolCallId
|
||||
toolName
|
||||
args
|
||||
result
|
||||
}
|
||||
createdAt
|
||||
...CopilotChatMessage
|
||||
}
|
||||
createdAt
|
||||
updatedAt
|
||||
}`;
|
||||
export const paginatedCopilotChatsFragment = `fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
hasPreviousPage
|
||||
startCursor
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
cursor
|
||||
node {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
}`;
|
||||
export const credentialsRequirementsFragment = `fragment CredentialsRequirements on CredentialsRequirementType {
|
||||
password {
|
||||
...PasswordLimits
|
||||
@@ -77,20 +94,6 @@ export const currentUserProfileFragment = `fragment CurrentUserProfile on UserTy
|
||||
}
|
||||
}
|
||||
}`;
|
||||
export const paginatedCopilotChatsFragment = `fragment PaginatedCopilotChats on PaginatedCopilotHistoriesType {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
hasPreviousPage
|
||||
startCursor
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
cursor
|
||||
node {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
}${copilotChatHistoryFragment}`;
|
||||
export const passwordLimitsFragment = `fragment PasswordLimits on PasswordLimitsType {
|
||||
minLength
|
||||
maxLength
|
||||
@@ -401,6 +404,52 @@ export const appConfigQuery = {
|
||||
}`,
|
||||
};
|
||||
|
||||
export const getPromptsQuery = {
|
||||
id: 'getPromptsQuery' as const,
|
||||
op: 'getPrompts',
|
||||
query: `query getPrompts {
|
||||
listCopilotPrompts {
|
||||
name
|
||||
model
|
||||
action
|
||||
config {
|
||||
frequencyPenalty
|
||||
presencePenalty
|
||||
temperature
|
||||
topP
|
||||
}
|
||||
messages {
|
||||
role
|
||||
content
|
||||
params
|
||||
}
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const updatePromptMutation = {
|
||||
id: 'updatePromptMutation' as const,
|
||||
op: 'updatePrompt',
|
||||
query: `mutation updatePrompt($name: String!, $messages: [CopilotPromptMessageInput!]!) {
|
||||
updateCopilotPrompt(name: $name, messages: $messages) {
|
||||
name
|
||||
model
|
||||
action
|
||||
config {
|
||||
frequencyPenalty
|
||||
presencePenalty
|
||||
temperature
|
||||
topP
|
||||
}
|
||||
messages {
|
||||
role
|
||||
content
|
||||
params
|
||||
}
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const createUserMutation = {
|
||||
id: 'createUserMutation' as const,
|
||||
op: 'createUser',
|
||||
@@ -1362,6 +1411,8 @@ export const getCopilotDocSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1381,6 +1432,8 @@ export const getCopilotPinnedSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1396,6 +1449,8 @@ export const getCopilotWorkspaceSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1411,6 +1466,8 @@ export const getCopilotHistoriesQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1539,24 +1596,12 @@ export const cleanupCopilotSessionMutation = {
|
||||
}`,
|
||||
};
|
||||
|
||||
export const createCopilotSessionWithHistoryMutation = {
|
||||
id: 'createCopilotSessionWithHistoryMutation' as const,
|
||||
op: 'createCopilotSessionWithHistory',
|
||||
query: `mutation createCopilotSessionWithHistory($options: CreateChatSessionInput!) {
|
||||
createCopilotSessionWithHistory(options: $options) {
|
||||
...CopilotChatHistory
|
||||
}
|
||||
}
|
||||
${copilotChatHistoryFragment}`,
|
||||
};
|
||||
|
||||
export const createCopilotSessionMutation = {
|
||||
id: 'createCopilotSessionMutation' as const,
|
||||
op: 'createCopilotSession',
|
||||
query: `mutation createCopilotSession($options: CreateChatSessionInput!) {
|
||||
createCopilotSession(options: $options)
|
||||
}`,
|
||||
deprecations: ["'createCopilotSession' is deprecated: use `createCopilotSessionWithHistory` instead"],
|
||||
};
|
||||
|
||||
export const forkCopilotSessionMutation = {
|
||||
@@ -1583,6 +1628,8 @@ export const getCopilotLatestDocSessionQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1598,6 +1645,8 @@ export const getCopilotSessionQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1616,6 +1665,8 @@ export const getCopilotRecentSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
@@ -1639,6 +1690,8 @@ export const getCopilotSessionsQuery = {
|
||||
}
|
||||
}
|
||||
}
|
||||
${copilotChatMessageFragment}
|
||||
${copilotChatHistoryFragment}
|
||||
${paginatedCopilotChatsFragment}`,
|
||||
};
|
||||
|
||||
|
||||
@@ -725,11 +725,54 @@ export interface CopilotModelsType {
|
||||
proModels: Array<CopilotModelType>;
|
||||
}
|
||||
|
||||
export interface CopilotPromptConfigInput {
|
||||
frequencyPenalty?: InputMaybe<Scalars['Float']['input']>;
|
||||
presencePenalty?: InputMaybe<Scalars['Float']['input']>;
|
||||
temperature?: InputMaybe<Scalars['Float']['input']>;
|
||||
topP?: InputMaybe<Scalars['Float']['input']>;
|
||||
}
|
||||
|
||||
export interface CopilotPromptConfigType {
|
||||
__typename?: 'CopilotPromptConfigType';
|
||||
frequencyPenalty: Maybe<Scalars['Float']['output']>;
|
||||
presencePenalty: Maybe<Scalars['Float']['output']>;
|
||||
temperature: Maybe<Scalars['Float']['output']>;
|
||||
topP: Maybe<Scalars['Float']['output']>;
|
||||
}
|
||||
|
||||
export interface CopilotPromptMessageInput {
|
||||
content: Scalars['String']['input'];
|
||||
params?: InputMaybe<Scalars['JSON']['input']>;
|
||||
role: CopilotPromptMessageRole;
|
||||
}
|
||||
|
||||
export enum CopilotPromptMessageRole {
|
||||
assistant = 'assistant',
|
||||
system = 'system',
|
||||
user = 'user',
|
||||
}
|
||||
|
||||
export interface CopilotPromptMessageType {
|
||||
__typename?: 'CopilotPromptMessageType';
|
||||
content: Scalars['String']['output'];
|
||||
params: Maybe<Scalars['JSON']['output']>;
|
||||
role: CopilotPromptMessageRole;
|
||||
}
|
||||
|
||||
export interface CopilotPromptNotFoundDataType {
|
||||
__typename?: 'CopilotPromptNotFoundDataType';
|
||||
name: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface CopilotPromptType {
|
||||
__typename?: 'CopilotPromptType';
|
||||
action: Maybe<Scalars['String']['output']>;
|
||||
config: Maybe<CopilotPromptConfigType>;
|
||||
messages: Array<CopilotPromptMessageType>;
|
||||
model: Scalars['String']['output'];
|
||||
name: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface CopilotProviderNotSupportedDataType {
|
||||
__typename?: 'CopilotProviderNotSupportedDataType';
|
||||
kind: Scalars['String']['output'];
|
||||
@@ -841,6 +884,14 @@ export interface CreateCheckoutSessionInput {
|
||||
variant?: InputMaybe<SubscriptionVariant>;
|
||||
}
|
||||
|
||||
export interface CreateCopilotPromptInput {
|
||||
action?: InputMaybe<Scalars['String']['input']>;
|
||||
config?: InputMaybe<CopilotPromptConfigInput>;
|
||||
messages: Array<CopilotPromptMessageInput>;
|
||||
model: Scalars['String']['input'];
|
||||
name: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface CreateUserInput {
|
||||
email: Scalars['String']['input'];
|
||||
name?: InputMaybe<Scalars['String']['input']>;
|
||||
@@ -1701,13 +1752,10 @@ export interface Mutation {
|
||||
createCopilotContext: Scalars['String']['output'];
|
||||
/** Create a chat message */
|
||||
createCopilotMessage: Scalars['String']['output'];
|
||||
/**
|
||||
* Create a chat session
|
||||
* @deprecated use `createCopilotSessionWithHistory` instead
|
||||
*/
|
||||
/** Create a copilot prompt */
|
||||
createCopilotPrompt: CopilotPromptType;
|
||||
/** Create a chat session */
|
||||
createCopilotSession: Scalars['String']['output'];
|
||||
/** Create a chat session and return full session payload */
|
||||
createCopilotSessionWithHistory: CopilotHistories;
|
||||
/** Create a stripe customer portal to manage payment methods */
|
||||
createCustomerPortal: Scalars['String']['output'];
|
||||
createInviteLink: InviteLink;
|
||||
@@ -1797,6 +1845,8 @@ export interface Mutation {
|
||||
updateCalendarAccount: Maybe<CalendarAccountObjectType>;
|
||||
/** Update a comment content */
|
||||
updateComment: Scalars['Boolean']['output'];
|
||||
/** Update a copilot prompt */
|
||||
updateCopilotPrompt: CopilotPromptType;
|
||||
/** Update a chat session */
|
||||
updateCopilotSession: Scalars['String']['output'];
|
||||
updateDocDefaultRole: Scalars['Boolean']['output'];
|
||||
@@ -1948,11 +1998,11 @@ export interface MutationCreateCopilotMessageArgs {
|
||||
options: CreateChatMessageInput;
|
||||
}
|
||||
|
||||
export interface MutationCreateCopilotSessionArgs {
|
||||
options: CreateChatSessionInput;
|
||||
export interface MutationCreateCopilotPromptArgs {
|
||||
input: CreateCopilotPromptInput;
|
||||
}
|
||||
|
||||
export interface MutationCreateCopilotSessionWithHistoryArgs {
|
||||
export interface MutationCreateCopilotSessionArgs {
|
||||
options: CreateChatSessionInput;
|
||||
}
|
||||
|
||||
@@ -2212,6 +2262,11 @@ export interface MutationUpdateCommentArgs {
|
||||
input: CommentUpdateInput;
|
||||
}
|
||||
|
||||
export interface MutationUpdateCopilotPromptArgs {
|
||||
messages: Array<CopilotPromptMessageInput>;
|
||||
name: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationUpdateCopilotSessionArgs {
|
||||
options: UpdateChatSessionInput;
|
||||
}
|
||||
@@ -2499,6 +2554,8 @@ export interface Query {
|
||||
error: ErrorDataUnion;
|
||||
/** get workspace invitation info */
|
||||
getInviteInfo: InvitationType;
|
||||
/** List all copilot prompts */
|
||||
listCopilotPrompts: Array<CopilotPromptType>;
|
||||
prices: Array<SubscriptionPrice>;
|
||||
/** Get public user by id */
|
||||
publicUserById: Maybe<PublicUserType>;
|
||||
@@ -3829,6 +3886,59 @@ export type AppConfigQueryVariables = Exact<{ [key: string]: never }>;
|
||||
|
||||
export type AppConfigQuery = { __typename?: 'Query'; appConfig: any };
|
||||
|
||||
export type GetPromptsQueryVariables = Exact<{ [key: string]: never }>;
|
||||
|
||||
export type GetPromptsQuery = {
|
||||
__typename?: 'Query';
|
||||
listCopilotPrompts: Array<{
|
||||
__typename?: 'CopilotPromptType';
|
||||
name: string;
|
||||
model: string;
|
||||
action: string | null;
|
||||
config: {
|
||||
__typename?: 'CopilotPromptConfigType';
|
||||
frequencyPenalty: number | null;
|
||||
presencePenalty: number | null;
|
||||
temperature: number | null;
|
||||
topP: number | null;
|
||||
} | null;
|
||||
messages: Array<{
|
||||
__typename?: 'CopilotPromptMessageType';
|
||||
role: CopilotPromptMessageRole;
|
||||
content: string;
|
||||
params: Record<string, string> | null;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
|
||||
export type UpdatePromptMutationVariables = Exact<{
|
||||
name: Scalars['String']['input'];
|
||||
messages: Array<CopilotPromptMessageInput> | CopilotPromptMessageInput;
|
||||
}>;
|
||||
|
||||
export type UpdatePromptMutation = {
|
||||
__typename?: 'Mutation';
|
||||
updateCopilotPrompt: {
|
||||
__typename?: 'CopilotPromptType';
|
||||
name: string;
|
||||
model: string;
|
||||
action: string | null;
|
||||
config: {
|
||||
__typename?: 'CopilotPromptConfigType';
|
||||
frequencyPenalty: number | null;
|
||||
presencePenalty: number | null;
|
||||
temperature: number | null;
|
||||
topP: number | null;
|
||||
} | null;
|
||||
messages: Array<{
|
||||
__typename?: 'CopilotPromptMessageType';
|
||||
role: CopilotPromptMessageRole;
|
||||
content: string;
|
||||
params: Record<string, string> | null;
|
||||
}>;
|
||||
};
|
||||
};
|
||||
|
||||
export type CreateUserMutationVariables = Exact<{
|
||||
input: CreateUserInput;
|
||||
}>;
|
||||
@@ -5315,47 +5425,6 @@ export type CleanupCopilotSessionMutation = {
|
||||
cleanupCopilotSession: Array<string>;
|
||||
};
|
||||
|
||||
export type CreateCopilotSessionWithHistoryMutationVariables = Exact<{
|
||||
options: CreateChatSessionInput;
|
||||
}>;
|
||||
|
||||
export type CreateCopilotSessionWithHistoryMutation = {
|
||||
__typename?: 'Mutation';
|
||||
createCopilotSessionWithHistory: {
|
||||
__typename?: 'CopilotHistories';
|
||||
sessionId: string;
|
||||
workspaceId: string;
|
||||
docId: string | null;
|
||||
parentSessionId: string | null;
|
||||
promptName: string;
|
||||
model: string;
|
||||
optionalModels: Array<string>;
|
||||
action: string | null;
|
||||
pinned: boolean;
|
||||
title: string | null;
|
||||
tokens: number;
|
||||
createdAt: string;
|
||||
updatedAt: string;
|
||||
messages: Array<{
|
||||
__typename?: 'ChatMessage';
|
||||
id: string | null;
|
||||
role: string;
|
||||
content: string;
|
||||
attachments: Array<string> | null;
|
||||
createdAt: string;
|
||||
streamObjects: Array<{
|
||||
__typename?: 'StreamObject';
|
||||
type: string;
|
||||
textDelta: string | null;
|
||||
toolCallId: string | null;
|
||||
toolName: string | null;
|
||||
args: Record<string, string> | null;
|
||||
result: Record<string, string> | null;
|
||||
}> | null;
|
||||
}>;
|
||||
};
|
||||
};
|
||||
|
||||
export type CreateCopilotSessionMutationVariables = Exact<{
|
||||
options: CreateChatSessionInput;
|
||||
}>;
|
||||
@@ -5865,6 +5934,24 @@ export type GetDocRolePermissionsQuery = {
|
||||
};
|
||||
};
|
||||
|
||||
export type CopilotChatMessageFragment = {
|
||||
__typename?: 'ChatMessage';
|
||||
id: string | null;
|
||||
role: string;
|
||||
content: string;
|
||||
attachments: Array<string> | null;
|
||||
createdAt: string;
|
||||
streamObjects: Array<{
|
||||
__typename?: 'StreamObject';
|
||||
type: string;
|
||||
textDelta: string | null;
|
||||
toolCallId: string | null;
|
||||
toolName: string | null;
|
||||
args: Record<string, string> | null;
|
||||
result: Record<string, string> | null;
|
||||
}> | null;
|
||||
};
|
||||
|
||||
export type CopilotChatHistoryFragment = {
|
||||
__typename?: 'CopilotHistories';
|
||||
sessionId: string;
|
||||
@@ -5899,52 +5986,6 @@ export type CopilotChatHistoryFragment = {
|
||||
}>;
|
||||
};
|
||||
|
||||
export type CredentialsRequirementsFragment = {
|
||||
__typename?: 'CredentialsRequirementType';
|
||||
password: {
|
||||
__typename?: 'PasswordLimitsType';
|
||||
minLength: number;
|
||||
maxLength: number;
|
||||
};
|
||||
};
|
||||
|
||||
export type CurrentUserProfileFragment = {
|
||||
__typename?: 'UserType';
|
||||
id: string;
|
||||
name: string;
|
||||
email: string;
|
||||
avatarUrl: string | null;
|
||||
emailVerified: boolean;
|
||||
features: Array<FeatureType>;
|
||||
settings: {
|
||||
__typename?: 'UserSettingsType';
|
||||
receiveInvitationEmail: boolean;
|
||||
receiveMentionEmail: boolean;
|
||||
receiveCommentEmail: boolean;
|
||||
};
|
||||
quota: {
|
||||
__typename?: 'UserQuotaType';
|
||||
name: string;
|
||||
blobLimit: number;
|
||||
storageQuota: number;
|
||||
historyPeriod: number;
|
||||
memberLimit: number;
|
||||
humanReadable: {
|
||||
__typename?: 'UserQuotaHumanReadableType';
|
||||
name: string;
|
||||
blobLimit: string;
|
||||
storageQuota: string;
|
||||
historyPeriod: string;
|
||||
memberLimit: string;
|
||||
};
|
||||
};
|
||||
quotaUsage: { __typename?: 'UserQuotaUsageType'; storageQuota: number };
|
||||
copilot: {
|
||||
__typename?: 'Copilot';
|
||||
quota: { __typename?: 'CopilotQuota'; limit: number | null; used: number };
|
||||
};
|
||||
};
|
||||
|
||||
export type PaginatedCopilotChatsFragment = {
|
||||
__typename?: 'PaginatedCopilotHistoriesType';
|
||||
pageInfo: {
|
||||
@@ -5993,6 +6034,52 @@ export type PaginatedCopilotChatsFragment = {
|
||||
}>;
|
||||
};
|
||||
|
||||
export type CredentialsRequirementsFragment = {
|
||||
__typename?: 'CredentialsRequirementType';
|
||||
password: {
|
||||
__typename?: 'PasswordLimitsType';
|
||||
minLength: number;
|
||||
maxLength: number;
|
||||
};
|
||||
};
|
||||
|
||||
export type CurrentUserProfileFragment = {
|
||||
__typename?: 'UserType';
|
||||
id: string;
|
||||
name: string;
|
||||
email: string;
|
||||
avatarUrl: string | null;
|
||||
emailVerified: boolean;
|
||||
features: Array<FeatureType>;
|
||||
settings: {
|
||||
__typename?: 'UserSettingsType';
|
||||
receiveInvitationEmail: boolean;
|
||||
receiveMentionEmail: boolean;
|
||||
receiveCommentEmail: boolean;
|
||||
};
|
||||
quota: {
|
||||
__typename?: 'UserQuotaType';
|
||||
name: string;
|
||||
blobLimit: number;
|
||||
storageQuota: number;
|
||||
historyPeriod: number;
|
||||
memberLimit: number;
|
||||
humanReadable: {
|
||||
__typename?: 'UserQuotaHumanReadableType';
|
||||
name: string;
|
||||
blobLimit: string;
|
||||
storageQuota: string;
|
||||
historyPeriod: string;
|
||||
memberLimit: string;
|
||||
};
|
||||
};
|
||||
quotaUsage: { __typename?: 'UserQuotaUsageType'; storageQuota: number };
|
||||
copilot: {
|
||||
__typename?: 'Copilot';
|
||||
quota: { __typename?: 'CopilotQuota'; limit: number | null; used: number };
|
||||
};
|
||||
};
|
||||
|
||||
export type PasswordLimitsFragment = {
|
||||
__typename?: 'PasswordLimitsType';
|
||||
minLength: number;
|
||||
@@ -7536,6 +7623,11 @@ export type Queries =
|
||||
variables: AppConfigQueryVariables;
|
||||
response: AppConfigQuery;
|
||||
}
|
||||
| {
|
||||
name: 'getPromptsQuery';
|
||||
variables: GetPromptsQueryVariables;
|
||||
response: GetPromptsQuery;
|
||||
}
|
||||
| {
|
||||
name: 'getUserByEmailQuery';
|
||||
variables: GetUserByEmailQueryVariables;
|
||||
@@ -7943,6 +8035,11 @@ export type Mutations =
|
||||
variables: CreateChangePasswordUrlMutationVariables;
|
||||
response: CreateChangePasswordUrlMutation;
|
||||
}
|
||||
| {
|
||||
name: 'updatePromptMutation';
|
||||
variables: UpdatePromptMutationVariables;
|
||||
response: UpdatePromptMutation;
|
||||
}
|
||||
| {
|
||||
name: 'createUserMutation';
|
||||
variables: CreateUserMutationVariables;
|
||||
@@ -8178,11 +8275,6 @@ export type Mutations =
|
||||
variables: CleanupCopilotSessionMutationVariables;
|
||||
response: CleanupCopilotSessionMutation;
|
||||
}
|
||||
| {
|
||||
name: 'createCopilotSessionWithHistoryMutation';
|
||||
variables: CreateCopilotSessionWithHistoryMutationVariables;
|
||||
response: CreateCopilotSessionWithHistoryMutation;
|
||||
}
|
||||
| {
|
||||
name: 'createCopilotSessionMutation';
|
||||
variables: CreateCopilotSessionMutationVariables;
|
||||
|
||||
@@ -430,7 +430,9 @@ fn parse_markdown_inner(markdown: &str) -> Result<MarkdownDocument, ParseError>
|
||||
table_handled = true;
|
||||
}
|
||||
Event::Html(html) | Event::InlineHtml(html) => {
|
||||
if let Some(text) = extract_wrapped_html_text(html) {
|
||||
if is_html_comment(html) || is_iframe_end_tag(html) {
|
||||
// Ignore HTML comments and iframe end tags inside table cells.
|
||||
} else if let Some(text) = extract_wrapped_html_text(html) {
|
||||
state.push_text(&text);
|
||||
} else if is_html_line_break(html) {
|
||||
state.push_text("\n");
|
||||
@@ -621,6 +623,9 @@ fn parse_markdown_inner(markdown: &str) -> Result<MarkdownDocument, ParseError>
|
||||
}
|
||||
}
|
||||
Event::Html(html) | Event::InlineHtml(html) => {
|
||||
if is_html_comment(&html) || is_iframe_end_tag(&html) {
|
||||
continue;
|
||||
}
|
||||
if is_ai_editable_comment(&html) {
|
||||
continue;
|
||||
}
|
||||
@@ -773,6 +778,9 @@ fn validate_markdown_inner(markdown: &str) -> Result<(), ParseError> {
|
||||
match event {
|
||||
Event::Start(tag) => ensure_supported_tag(&tag)?,
|
||||
Event::Html(html) | Event::InlineHtml(html) => {
|
||||
if is_html_comment(&html) || is_iframe_end_tag(&html) {
|
||||
continue;
|
||||
}
|
||||
if is_ai_editable_comment(&html) {
|
||||
continue;
|
||||
}
|
||||
@@ -936,6 +944,15 @@ fn is_ai_editable_comment(html: &str) -> bool {
|
||||
body.contains("block_id=") && body.contains("flavour=")
|
||||
}
|
||||
|
||||
fn is_html_comment(html: &str) -> bool {
|
||||
let trimmed = html.trim();
|
||||
trimmed.starts_with("<!--") && trimmed.ends_with("-->")
|
||||
}
|
||||
|
||||
fn is_iframe_end_tag(html: &str) -> bool {
|
||||
parse_html_tag(html).is_some_and(|tag| tag.closing && tag.name == "iframe")
|
||||
}
|
||||
|
||||
fn is_html_line_break(html: &str) -> bool {
|
||||
let trimmed = html.trim();
|
||||
if !trimmed.starts_with('<') || !trimmed.ends_with('>') {
|
||||
@@ -1716,6 +1733,13 @@ mod tests {
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_markdown_allows_html_comment() {
|
||||
let markdown = "# Title\n\n<!-- omit from toc -->\n\nContent.";
|
||||
let result = validate_markdown(markdown);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_markdown_rejects_html() {
|
||||
let markdown = "# Title\n\n<div>HTML</div>";
|
||||
|
||||
@@ -282,6 +282,9 @@ pub fn parse_doc_to_markdown(
|
||||
0
|
||||
};
|
||||
let ai_block = ai_editable && block_level == 2;
|
||||
let ai_preserve_block = ai_block
|
||||
&& (matches!(flavour.as_str(), "affine:database" | "affine:callout")
|
||||
|| BlockFlavour::from_str(flavour.as_str()).is_none());
|
||||
|
||||
let mut block_markdown = String::new();
|
||||
|
||||
@@ -308,7 +311,9 @@ pub fn parse_doc_to_markdown(
|
||||
};
|
||||
renderer.write_block(&mut block_markdown, &spec, list_depth);
|
||||
} else {
|
||||
return Err(ParseError::ParserError(format!("unsupported_block_flavour:{flavour}")));
|
||||
block_markdown.push_str(&format!(
|
||||
"<!-- unsupported_block_flavour:{flavour} block_id={block_id} -->\n\n"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -317,6 +322,9 @@ pub fn parse_doc_to_markdown(
|
||||
markdown.push_str(&format!("<!-- block_id={block_id} flavour={flavour} -->\n"));
|
||||
}
|
||||
markdown.push_str(&block_markdown);
|
||||
if ai_preserve_block {
|
||||
markdown.push_str(&format!("<!-- block_id={block_id} flavour={flavour} end -->\n"));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(MarkdownResult {
|
||||
@@ -792,4 +800,59 @@ mod tests {
|
||||
assert!(md.contains("|A|B|"));
|
||||
assert!(md.contains("|---|---|"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_doc_to_markdown_skips_unsupported_block_flavour() {
|
||||
let doc_id = "unsupported-doc".to_string();
|
||||
let doc = DocOptions::new().with_guid(doc_id.clone()).build();
|
||||
let mut blocks = doc.get_or_create_map("blocks").unwrap();
|
||||
|
||||
let mut page = doc.create_map().unwrap();
|
||||
page.insert("sys:id".into(), "page").unwrap();
|
||||
page.insert("sys:flavour".into(), "affine:page").unwrap();
|
||||
let mut page_children = doc.create_array().unwrap();
|
||||
page_children.push("note").unwrap();
|
||||
page.insert("sys:children".into(), Value::Array(page_children)).unwrap();
|
||||
let mut page_title = doc.create_text().unwrap();
|
||||
page_title.insert(0, "Page").unwrap();
|
||||
page.insert("prop:title".into(), Value::Text(page_title)).unwrap();
|
||||
blocks.insert("page".into(), Value::Map(page)).unwrap();
|
||||
|
||||
let mut note = doc.create_map().unwrap();
|
||||
note.insert("sys:id".into(), "note").unwrap();
|
||||
note.insert("sys:flavour".into(), "affine:note").unwrap();
|
||||
let mut note_children = doc.create_array().unwrap();
|
||||
note_children.push("latex").unwrap();
|
||||
note_children.push("paragraph").unwrap();
|
||||
note.insert("sys:children".into(), Value::Array(note_children)).unwrap();
|
||||
note.insert("prop:displayMode".into(), "page").unwrap();
|
||||
blocks.insert("note".into(), Value::Map(note)).unwrap();
|
||||
|
||||
let mut unsupported = doc.create_map().unwrap();
|
||||
unsupported.insert("sys:id".into(), "latex").unwrap();
|
||||
unsupported.insert("sys:flavour".into(), "affine:latex").unwrap();
|
||||
unsupported
|
||||
.insert("sys:children".into(), Value::Array(doc.create_array().unwrap()))
|
||||
.unwrap();
|
||||
blocks.insert("latex".into(), Value::Map(unsupported)).unwrap();
|
||||
|
||||
let mut paragraph = doc.create_map().unwrap();
|
||||
paragraph.insert("sys:id".into(), "paragraph").unwrap();
|
||||
paragraph.insert("sys:flavour".into(), "affine:paragraph").unwrap();
|
||||
paragraph
|
||||
.insert("sys:children".into(), Value::Array(doc.create_array().unwrap()))
|
||||
.unwrap();
|
||||
let mut paragraph_text = doc.create_text().unwrap();
|
||||
paragraph_text.insert(0, "After unsupported block").unwrap();
|
||||
paragraph
|
||||
.insert("prop:text".into(), Value::Text(paragraph_text))
|
||||
.unwrap();
|
||||
blocks.insert("paragraph".into(), Value::Map(paragraph)).unwrap();
|
||||
|
||||
let doc_bin = doc.encode_update_v1().unwrap();
|
||||
let result = parse_doc_to_markdown(doc_bin, doc_id, false, None).expect("parse doc");
|
||||
|
||||
assert!(result.markdown.contains("unsupported_block_flavour:affine:latex"));
|
||||
assert!(result.markdown.contains("After unsupported block"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! Converts markdown content into AFFiNE-compatible y-octo document binary
|
||||
//! format.
|
||||
|
||||
use y_octo::DocOptions;
|
||||
use y_octo::{DocOptions, StateVector};
|
||||
|
||||
use super::{
|
||||
super::{
|
||||
@@ -73,7 +73,7 @@ fn build_doc_update(doc_id: &str, title: &str, blocks: &[BlockNode]) -> Result<V
|
||||
note_map.insert(PROP_HIDDEN.to_string(), Any::False)?;
|
||||
note_map.insert(PROP_DISPLAY_MODE.to_string(), Any::String("both".to_string()))?;
|
||||
|
||||
Ok(doc.encode_update_v1()?)
|
||||
Ok(doc.encode_state_as_update_v1(&StateVector::default())?)
|
||||
}
|
||||
|
||||
fn insert_block_trees(doc: &Doc, blocks_map: &mut Map, blocks: &[BlockNode]) -> Result<Vec<String>, ParseError> {
|
||||
|
||||
@@ -8,19 +8,37 @@ use std::collections::HashMap;
|
||||
use super::{
|
||||
super::{
|
||||
block_spec::{TreeNode, count_tree_nodes, text_delta_eq},
|
||||
blocksuite::{collect_child_ids, find_child_id_by_flavour},
|
||||
blocksuite::{collect_child_ids, find_child_id_by_flavour, get_string},
|
||||
markdown::{MAX_BLOCKS, parse_markdown_blocks},
|
||||
schema::{PROP_BACKGROUND, PROP_DISPLAY_MODE, PROP_ELEMENTS, PROP_HIDDEN, PROP_INDEX, PROP_XYWH, SURFACE_FLAVOUR},
|
||||
},
|
||||
builder::{
|
||||
ApplyBlockOptions, BOXED_NATIVE_TYPE, NOTE_BG_DARK, NOTE_BG_LIGHT, apply_block_spec, boxed_empty_map,
|
||||
insert_block_map, insert_block_tree, insert_children, insert_sys_fields, insert_text, note_background_map,
|
||||
text_ops_from_plain,
|
||||
},
|
||||
builder::{ApplyBlockOptions, apply_block_spec, insert_block_tree, insert_children},
|
||||
*,
|
||||
};
|
||||
|
||||
const MAX_LCS_CELLS: usize = 2_000_000;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum NodeSpec {
|
||||
Supported(BlockSpec),
|
||||
/// A block flavour we don't support for markdown diffing/updating (e.g.
|
||||
/// `affine:database`).
|
||||
///
|
||||
/// These nodes are treated as opaque: we preserve them and never modify their
|
||||
/// properties/children.
|
||||
Opaque {
|
||||
flavour: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct StoredNode {
|
||||
id: String,
|
||||
spec: BlockSpec,
|
||||
spec: NodeSpec,
|
||||
children: Vec<StoredNode>,
|
||||
}
|
||||
|
||||
@@ -30,6 +48,20 @@ impl TreeNode for StoredNode {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TargetNode {
|
||||
/// Optional block id marker from exported markdown (AI-editable markers).
|
||||
id_hint: Option<String>,
|
||||
spec: NodeSpec,
|
||||
children: Vec<TargetNode>,
|
||||
}
|
||||
|
||||
impl TreeNode for TargetNode {
|
||||
fn children(&self) -> &[TargetNode] {
|
||||
&self.children
|
||||
}
|
||||
}
|
||||
|
||||
struct DocState {
|
||||
doc: Doc,
|
||||
note_id: String,
|
||||
@@ -59,8 +91,24 @@ enum PatchOp {
|
||||
/// # Returns
|
||||
/// A binary vector representing only the delta (changes) to apply
|
||||
pub fn update_doc(existing_binary: &[u8], new_markdown: &str, doc_id: &str) -> Result<Vec<u8>, ParseError> {
|
||||
let mut new_nodes = parse_markdown_blocks(new_markdown)?;
|
||||
let state = load_doc_state(existing_binary, doc_id)?;
|
||||
let state = match load_doc_state(existing_binary, doc_id) {
|
||||
Ok(state) => state,
|
||||
Err(ParseError::ParserError(msg))
|
||||
if matches!(
|
||||
msg.as_str(),
|
||||
"blocks map is empty" | "page block not found" | "note block not found"
|
||||
) =>
|
||||
{
|
||||
// The existing doc may be a stub/partial document (e.g. created by references)
|
||||
// and doesn't contain the canonical page/note structure yet. In that
|
||||
// case, initialize the doc from the markdown instead of failing hard.
|
||||
let new_nodes = parse_markdown_blocks(new_markdown)?;
|
||||
return init_doc_from_markdown(existing_binary, new_markdown, doc_id, &new_nodes);
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
let mut new_nodes = parse_markdown_targets(new_markdown)?;
|
||||
|
||||
check_limits(&state.blocks, &new_nodes)?;
|
||||
|
||||
@@ -74,6 +122,315 @@ pub fn update_doc(existing_binary: &[u8], new_markdown: &str, doc_id: &str) -> R
|
||||
Ok(state.doc.encode_state_as_update_v1(&state_before)?)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct BlockMarker {
|
||||
id: String,
|
||||
flavour: String,
|
||||
end: bool,
|
||||
}
|
||||
|
||||
fn parse_block_marker_line(line: &str) -> Option<BlockMarker> {
|
||||
let trimmed = line.trim();
|
||||
if !trimmed.starts_with("<!--") || !trimmed.ends_with("-->") {
|
||||
return None;
|
||||
}
|
||||
let body = trimmed.trim_start_matches("<!--").trim_end_matches("-->").trim();
|
||||
if !body.contains("block_id=") || !body.contains("flavour=") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut id: Option<String> = None;
|
||||
let mut flavour: Option<String> = None;
|
||||
let mut end = false;
|
||||
|
||||
for token in body.split_whitespace() {
|
||||
if token == "end" || token == "type=end" || token == "end=true" {
|
||||
end = true;
|
||||
continue;
|
||||
}
|
||||
if let Some(value) = token.strip_prefix("block_id=") {
|
||||
if !value.is_empty() {
|
||||
id = Some(value.to_string());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if let Some(value) = token.strip_prefix("flavour=") {
|
||||
if !value.is_empty() {
|
||||
flavour = Some(value.to_string());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Some(BlockMarker {
|
||||
id: id?,
|
||||
flavour: flavour?,
|
||||
end,
|
||||
})
|
||||
}
|
||||
|
||||
fn should_preserve_marker_flavour(flavour: &str) -> bool {
|
||||
matches!(flavour, "affine:database" | "affine:callout")
|
||||
}
|
||||
|
||||
fn parse_markdown_targets(markdown: &str) -> Result<Vec<TargetNode>, ParseError> {
|
||||
// Fast path: no markers, behave like the original implementation.
|
||||
if !markdown.contains("block_id=") || !markdown.contains("flavour=") {
|
||||
let blocks = parse_markdown_blocks(markdown)?;
|
||||
return Ok(blocks.into_iter().map(|b| target_from_block_node(b, None)).collect());
|
||||
}
|
||||
|
||||
// Split the markdown by marker comments. For most blocks, a marker indicates
|
||||
// the start of a block. For preserved blocks (e.g. database), an optional end
|
||||
// marker can be emitted so users can append new content after the preserved
|
||||
// section without needing to add markers manually.
|
||||
let mut segments: Vec<(Option<BlockMarker>, String)> = Vec::new();
|
||||
let mut current_marker: Option<BlockMarker> = None;
|
||||
let mut current_body = String::new();
|
||||
let mut saw_marker = false;
|
||||
|
||||
for line in markdown.lines() {
|
||||
if let Some(marker) = parse_block_marker_line(line) {
|
||||
saw_marker = true;
|
||||
if marker.end {
|
||||
if current_marker.is_some() || !current_body.is_empty() {
|
||||
segments.push((current_marker.take(), std::mem::take(&mut current_body)));
|
||||
}
|
||||
// Close the marker scope; subsequent lines belong to an unmarked segment.
|
||||
current_marker = None;
|
||||
continue;
|
||||
}
|
||||
|
||||
if current_marker.is_some() || !current_body.is_empty() {
|
||||
segments.push((current_marker.take(), std::mem::take(&mut current_body)));
|
||||
}
|
||||
current_marker = Some(marker);
|
||||
continue;
|
||||
}
|
||||
|
||||
current_body.push_str(line);
|
||||
current_body.push('\n');
|
||||
}
|
||||
|
||||
if current_marker.is_some() || !current_body.is_empty() {
|
||||
segments.push((current_marker.take(), current_body));
|
||||
}
|
||||
|
||||
if !saw_marker {
|
||||
let blocks = parse_markdown_blocks(markdown)?;
|
||||
return Ok(blocks.into_iter().map(|b| target_from_block_node(b, None)).collect());
|
||||
}
|
||||
|
||||
let mut out: Vec<TargetNode> = Vec::new();
|
||||
for (marker, body) in segments {
|
||||
if let Some(marker) = marker {
|
||||
let preserve =
|
||||
should_preserve_marker_flavour(&marker.flavour) || BlockFlavour::from_str(&marker.flavour).is_none();
|
||||
if preserve {
|
||||
out.push(TargetNode {
|
||||
id_hint: Some(marker.id),
|
||||
spec: NodeSpec::Opaque {
|
||||
flavour: marker.flavour,
|
||||
},
|
||||
children: Vec::new(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
let blocks = parse_markdown_blocks(&body)?;
|
||||
for (idx, block) in blocks.into_iter().enumerate() {
|
||||
let id_hint = if idx == 0 { Some(marker.id.clone()) } else { None };
|
||||
out.push(target_from_block_node(block, id_hint));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let trimmed = body.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let blocks = parse_markdown_blocks(&body)?;
|
||||
for block in blocks {
|
||||
out.push(target_from_block_node(block, None));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn target_from_block_node(node: BlockNode, id_hint: Option<String>) -> TargetNode {
|
||||
TargetNode {
|
||||
id_hint,
|
||||
spec: NodeSpec::Supported(node.spec),
|
||||
children: node
|
||||
.children
|
||||
.into_iter()
|
||||
.map(|child| target_from_block_node(child, None))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn target_node_to_block_node(node: &TargetNode) -> Result<BlockNode, ParseError> {
|
||||
let NodeSpec::Supported(spec) = &node.spec else {
|
||||
return Err(ParseError::ParserError("cannot_insert_opaque_block".into()));
|
||||
};
|
||||
Ok(BlockNode {
|
||||
spec: spec.clone(),
|
||||
children: node
|
||||
.children
|
||||
.iter()
|
||||
.map(target_node_to_block_node)
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn init_doc_from_markdown(
|
||||
existing_binary: &[u8],
|
||||
new_markdown: &str,
|
||||
doc_id: &str,
|
||||
blocks: &[BlockNode],
|
||||
) -> Result<Vec<u8>, ParseError> {
|
||||
let doc = load_doc(existing_binary, Some(doc_id))?;
|
||||
let state_before = doc.get_state_vector();
|
||||
let mut blocks_map = doc.get_or_create_map("blocks")?;
|
||||
|
||||
let title = derive_title_from_markdown(new_markdown).unwrap_or_else(|| "Untitled".to_string());
|
||||
// Prefer reusing an existing page block if the doc already has one (but is
|
||||
// missing surface/note). This avoids creating multiple page roots when
|
||||
// recovering from partial documents.
|
||||
if !blocks_map.is_empty() {
|
||||
let index = build_block_index(&blocks_map);
|
||||
if let Some(page_id) = find_block_id_by_flavour(&index.block_pool, PAGE_FLAVOUR) {
|
||||
insert_page_children(&doc, &mut blocks_map, &page_id, &title, blocks)?;
|
||||
return Ok(doc.encode_state_as_update_v1(&state_before)?);
|
||||
}
|
||||
}
|
||||
|
||||
insert_page_doc(&doc, &mut blocks_map, &title, blocks)?;
|
||||
|
||||
Ok(doc.encode_state_as_update_v1(&state_before)?)
|
||||
}
|
||||
|
||||
fn derive_title_from_markdown(markdown: &str) -> Option<String> {
|
||||
for line in markdown.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Some(rest) = trimmed.strip_prefix("# ") {
|
||||
let title = rest.trim();
|
||||
if !title.is_empty() {
|
||||
return Some(title.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn insert_page_doc(doc: &Doc, blocks_map: &mut Map, title: &str, blocks: &[BlockNode]) -> Result<(), ParseError> {
|
||||
let page_id = nanoid::nanoid!();
|
||||
let surface_id = nanoid::nanoid!();
|
||||
let note_id = nanoid::nanoid!();
|
||||
|
||||
// Insert root blocks first to establish stable IDs.
|
||||
let mut page_map = insert_block_map(doc, blocks_map, &page_id)?;
|
||||
let mut surface_map = insert_block_map(doc, blocks_map, &surface_id)?;
|
||||
let mut note_map = insert_block_map(doc, blocks_map, ¬e_id)?;
|
||||
|
||||
// Create content blocks under note.
|
||||
let content_ids = insert_block_trees(doc, blocks_map, blocks)?;
|
||||
|
||||
// Page block.
|
||||
insert_sys_fields(&mut page_map, &page_id, PAGE_FLAVOUR)?;
|
||||
insert_children(doc, &mut page_map, &[surface_id.clone(), note_id.clone()])?;
|
||||
insert_text(doc, &mut page_map, PROP_TITLE, &text_ops_from_plain(title))?;
|
||||
|
||||
// Surface block.
|
||||
insert_sys_fields(&mut surface_map, &surface_id, SURFACE_FLAVOUR)?;
|
||||
insert_children(doc, &mut surface_map, &[])?;
|
||||
let mut boxed = boxed_empty_map(doc)?;
|
||||
surface_map.insert(PROP_ELEMENTS.to_string(), Value::Map(boxed.clone()))?;
|
||||
boxed.insert("type".to_string(), Any::String(BOXED_NATIVE_TYPE.to_string()))?;
|
||||
let value = doc.create_map()?;
|
||||
boxed.insert("value".to_string(), Value::Map(value))?;
|
||||
|
||||
// Note block.
|
||||
insert_sys_fields(&mut note_map, ¬e_id, NOTE_FLAVOUR)?;
|
||||
insert_children(doc, &mut note_map, &content_ids)?;
|
||||
let mut background = note_background_map(doc)?;
|
||||
note_map.insert(PROP_BACKGROUND.to_string(), Value::Map(background.clone()))?;
|
||||
background.insert("light".to_string(), Any::String(NOTE_BG_LIGHT.to_string()))?;
|
||||
background.insert("dark".to_string(), Any::String(NOTE_BG_DARK.to_string()))?;
|
||||
note_map.insert(PROP_XYWH.to_string(), Any::String("[0,0,800,95]".to_string()))?;
|
||||
note_map.insert(PROP_INDEX.to_string(), Any::String("a0".to_string()))?;
|
||||
note_map.insert(PROP_HIDDEN.to_string(), Any::False)?;
|
||||
note_map.insert(PROP_DISPLAY_MODE.to_string(), Any::String("both".to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_page_children(
|
||||
doc: &Doc,
|
||||
blocks_map: &mut Map,
|
||||
page_id: &str,
|
||||
title: &str,
|
||||
blocks: &[BlockNode],
|
||||
) -> Result<(), ParseError> {
|
||||
let surface_id = nanoid::nanoid!();
|
||||
let note_id = nanoid::nanoid!();
|
||||
|
||||
// Insert root blocks first to establish stable IDs.
|
||||
let mut surface_map = insert_block_map(doc, blocks_map, &surface_id)?;
|
||||
let mut note_map = insert_block_map(doc, blocks_map, ¬e_id)?;
|
||||
|
||||
// Create content blocks under note.
|
||||
let content_ids = insert_block_trees(doc, blocks_map, blocks)?;
|
||||
|
||||
let Some(mut page_map) = blocks_map.get(page_id).and_then(|v| v.to_map()) else {
|
||||
return Err(ParseError::ParserError("page block not found".into()));
|
||||
};
|
||||
|
||||
// Page block.
|
||||
insert_sys_fields(&mut page_map, page_id, PAGE_FLAVOUR)?;
|
||||
insert_children(doc, &mut page_map, &[surface_id.clone(), note_id.clone()])?;
|
||||
if page_map.get(PROP_TITLE).is_none() {
|
||||
insert_text(doc, &mut page_map, PROP_TITLE, &text_ops_from_plain(title))?;
|
||||
}
|
||||
|
||||
// Surface block.
|
||||
insert_sys_fields(&mut surface_map, &surface_id, SURFACE_FLAVOUR)?;
|
||||
insert_children(doc, &mut surface_map, &[])?;
|
||||
let mut boxed = boxed_empty_map(doc)?;
|
||||
surface_map.insert(PROP_ELEMENTS.to_string(), Value::Map(boxed.clone()))?;
|
||||
boxed.insert("type".to_string(), Any::String(BOXED_NATIVE_TYPE.to_string()))?;
|
||||
let value = doc.create_map()?;
|
||||
boxed.insert("value".to_string(), Value::Map(value))?;
|
||||
|
||||
// Note block.
|
||||
insert_sys_fields(&mut note_map, ¬e_id, NOTE_FLAVOUR)?;
|
||||
insert_children(doc, &mut note_map, &content_ids)?;
|
||||
let mut background = note_background_map(doc)?;
|
||||
note_map.insert(PROP_BACKGROUND.to_string(), Value::Map(background.clone()))?;
|
||||
background.insert("light".to_string(), Any::String(NOTE_BG_LIGHT.to_string()))?;
|
||||
background.insert("dark".to_string(), Any::String(NOTE_BG_DARK.to_string()))?;
|
||||
note_map.insert(PROP_XYWH.to_string(), Any::String("[0,0,800,95]".to_string()))?;
|
||||
note_map.insert(PROP_INDEX.to_string(), Any::String("a0".to_string()))?;
|
||||
note_map.insert(PROP_HIDDEN.to_string(), Any::False)?;
|
||||
note_map.insert(PROP_DISPLAY_MODE.to_string(), Any::String("both".to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_block_trees(doc: &Doc, blocks_map: &mut Map, blocks: &[BlockNode]) -> Result<Vec<String>, ParseError> {
|
||||
let mut ids = Vec::with_capacity(blocks.len());
|
||||
for block in blocks {
|
||||
let id = insert_block_tree(doc, blocks_map, block)?;
|
||||
ids.push(id);
|
||||
}
|
||||
Ok(ids)
|
||||
}
|
||||
|
||||
fn load_doc_state(binary: &[u8], doc_id: &str) -> Result<DocState, ParseError> {
|
||||
let doc = load_doc(binary, Some(doc_id))?;
|
||||
|
||||
@@ -110,14 +467,31 @@ fn load_doc_state(binary: &[u8], doc_id: &str) -> Result<DocState, ParseError> {
|
||||
}
|
||||
|
||||
fn build_stored_tree(block_id: &str, block: &Map, pool: &HashMap<String, Map>) -> Result<StoredNode, ParseError> {
|
||||
let spec = BlockSpec::from_block_map(block)?;
|
||||
|
||||
let child_ids = collect_child_ids(block);
|
||||
let flavour = get_string(block, "sys:flavour").unwrap_or_default();
|
||||
|
||||
let spec = match BlockSpec::from_block_map(block) {
|
||||
Ok(spec) => spec,
|
||||
Err(ParseError::ParserError(msg)) if msg.starts_with("unsupported block flavour:") => {
|
||||
return Ok(StoredNode {
|
||||
id: block_id.to_string(),
|
||||
spec: NodeSpec::Opaque { flavour },
|
||||
children: Vec::new(),
|
||||
});
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
// Only list/callout are supported as containers for markdown diffing.
|
||||
// For any other block with children, treat as opaque so we never corrupt it.
|
||||
if !child_ids.is_empty() && !matches!(spec.flavour, BlockFlavour::List | BlockFlavour::Callout) {
|
||||
return Err(ParseError::ParserError(format!(
|
||||
"unsupported children on block: {block_id}"
|
||||
)));
|
||||
return Ok(StoredNode {
|
||||
id: block_id.to_string(),
|
||||
spec: NodeSpec::Opaque { flavour },
|
||||
children: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut children = Vec::new();
|
||||
for child_id in child_ids {
|
||||
let child_block = pool
|
||||
@@ -128,7 +502,7 @@ fn build_stored_tree(block_id: &str, block: &Map, pool: &HashMap<String, Map>) -
|
||||
|
||||
Ok(StoredNode {
|
||||
id: block_id.to_string(),
|
||||
spec,
|
||||
spec: NodeSpec::Supported(spec),
|
||||
children,
|
||||
})
|
||||
}
|
||||
@@ -137,7 +511,7 @@ fn sync_nodes(
|
||||
doc: &Doc,
|
||||
blocks_map: &mut Map,
|
||||
current: &[StoredNode],
|
||||
target: &mut [BlockNode],
|
||||
target: &mut [TargetNode],
|
||||
) -> Result<Vec<String>, ParseError> {
|
||||
let ops = diff_blocks(current, target);
|
||||
let mut new_children = Vec::new();
|
||||
@@ -148,29 +522,47 @@ fn sync_nodes(
|
||||
PatchOp::Keep(old_idx, new_idx) => {
|
||||
let old_node = ¤t[old_idx];
|
||||
let new_node = &target[new_idx];
|
||||
update_block_props(doc, blocks_map, old_node, &new_node.spec, true)?;
|
||||
let child_ids = sync_nodes(doc, blocks_map, &old_node.children, &mut new_node.children.clone())?;
|
||||
sync_children(doc, blocks_map, &old_node.id, &child_ids)?;
|
||||
if let (NodeSpec::Supported(old_spec), NodeSpec::Supported(new_spec)) = (&old_node.spec, &new_node.spec) {
|
||||
update_block_props(doc, blocks_map, &old_node.id, old_spec, new_spec, true)?;
|
||||
let child_ids = sync_nodes(doc, blocks_map, &old_node.children, &mut new_node.children.clone())?;
|
||||
sync_children(doc, blocks_map, &old_node.id, &child_ids)?;
|
||||
} else {
|
||||
// Preserve opaque blocks (and any mismatched marker blocks) as-is.
|
||||
// Don't touch their properties or children ordering.
|
||||
}
|
||||
new_children.push(old_node.id.clone());
|
||||
}
|
||||
PatchOp::Update(old_idx, new_idx) => {
|
||||
let old_node = ¤t[old_idx];
|
||||
let new_node = &target[new_idx];
|
||||
update_block_props(doc, blocks_map, old_node, &new_node.spec, false)?;
|
||||
let child_ids = sync_nodes(doc, blocks_map, &old_node.children, &mut new_node.children.clone())?;
|
||||
sync_children(doc, blocks_map, &old_node.id, &child_ids)?;
|
||||
if let (NodeSpec::Supported(old_spec), NodeSpec::Supported(new_spec)) = (&old_node.spec, &new_node.spec) {
|
||||
update_block_props(doc, blocks_map, &old_node.id, old_spec, new_spec, false)?;
|
||||
let child_ids = sync_nodes(doc, blocks_map, &old_node.children, &mut new_node.children.clone())?;
|
||||
sync_children(doc, blocks_map, &old_node.id, &child_ids)?;
|
||||
} else {
|
||||
// Opaque blocks are never updated from markdown.
|
||||
}
|
||||
new_children.push(old_node.id.clone());
|
||||
}
|
||||
PatchOp::Insert(new_idx) => {
|
||||
let new_id = insert_block_tree(doc, blocks_map, &target[new_idx])?;
|
||||
new_children.push(new_id);
|
||||
if let Ok(node) = target_node_to_block_node(&target[new_idx]) {
|
||||
let new_id = insert_block_tree(doc, blocks_map, &node)?;
|
||||
new_children.push(new_id);
|
||||
}
|
||||
}
|
||||
PatchOp::Delete(old_idx) => {
|
||||
let node = ¤t[old_idx];
|
||||
if node.spec.flavour == BlockFlavour::Callout {
|
||||
new_children.push(node.id.clone());
|
||||
} else {
|
||||
collect_tree_ids(node, &mut to_remove);
|
||||
match &node.spec {
|
||||
NodeSpec::Opaque { .. } => {
|
||||
// Never delete opaque blocks when syncing from markdown. They might contain
|
||||
// rich data that can't be represented in markdown, so keeping them
|
||||
// avoids data loss.
|
||||
new_children.push(node.id.clone());
|
||||
}
|
||||
NodeSpec::Supported(spec) if spec.flavour == BlockFlavour::Callout => {
|
||||
new_children.push(node.id.clone());
|
||||
}
|
||||
NodeSpec::Supported(_) => collect_tree_ids(node, &mut to_remove),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -183,7 +575,7 @@ fn sync_nodes(
|
||||
Ok(new_children)
|
||||
}
|
||||
|
||||
fn diff_blocks(current: &[StoredNode], target: &[BlockNode]) -> Vec<PatchOp> {
|
||||
fn diff_blocks(current: &[StoredNode], target: &[TargetNode]) -> Vec<PatchOp> {
|
||||
let old_len = current.len();
|
||||
let new_len = target.len();
|
||||
|
||||
@@ -198,10 +590,10 @@ fn diff_blocks(current: &[StoredNode], target: &[BlockNode]) -> Vec<PatchOp> {
|
||||
|
||||
for i in 1..=old_len {
|
||||
for j in 1..=new_len {
|
||||
let old_spec = ¤t[i - 1].spec;
|
||||
let new_spec = &target[j - 1].spec;
|
||||
let old_node = ¤t[i - 1];
|
||||
let new_node = &target[j - 1];
|
||||
|
||||
if old_spec.is_exact(new_spec) {
|
||||
if nodes_align(old_node, new_node) {
|
||||
lcs[i][j] = lcs[i - 1][j - 1] + 1;
|
||||
} else {
|
||||
lcs[i][j] = std::cmp::max(lcs[i - 1][j], lcs[i][j - 1]);
|
||||
@@ -215,14 +607,18 @@ fn diff_blocks(current: &[StoredNode], target: &[BlockNode]) -> Vec<PatchOp> {
|
||||
|
||||
while i > 0 || j > 0 {
|
||||
if i > 0 && j > 0 {
|
||||
let old_spec = ¤t[i - 1].spec;
|
||||
let new_spec = &target[j - 1].spec;
|
||||
let old_node = ¤t[i - 1];
|
||||
let new_node = &target[j - 1];
|
||||
|
||||
if old_spec.is_exact(new_spec) {
|
||||
ops.push(PatchOp::Keep(i - 1, j - 1));
|
||||
if nodes_align(old_node, new_node) {
|
||||
if nodes_should_update(old_node, new_node) {
|
||||
ops.push(PatchOp::Update(i - 1, j - 1));
|
||||
} else {
|
||||
ops.push(PatchOp::Keep(i - 1, j - 1));
|
||||
}
|
||||
i -= 1;
|
||||
j -= 1;
|
||||
} else if old_spec.is_similar(new_spec)
|
||||
} else if nodes_similar(old_node, new_node)
|
||||
&& lcs[i - 1][j - 1] >= lcs[i - 1][j]
|
||||
&& lcs[i - 1][j - 1] >= lcs[i][j - 1]
|
||||
{
|
||||
@@ -249,15 +645,60 @@ fn diff_blocks(current: &[StoredNode], target: &[BlockNode]) -> Vec<PatchOp> {
|
||||
ops
|
||||
}
|
||||
|
||||
fn nodes_align(old_node: &StoredNode, new_node: &TargetNode) -> bool {
|
||||
if marker_matches(old_node, new_node) {
|
||||
return true;
|
||||
}
|
||||
match (&old_node.spec, &new_node.spec) {
|
||||
(NodeSpec::Supported(old_spec), NodeSpec::Supported(new_spec)) => old_spec.is_exact(new_spec),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn nodes_should_update(old_node: &StoredNode, new_node: &TargetNode) -> bool {
|
||||
if marker_matches(old_node, new_node) {
|
||||
return match (&old_node.spec, &new_node.spec) {
|
||||
(NodeSpec::Supported(old_spec), NodeSpec::Supported(new_spec)) => !old_spec.is_exact(new_spec),
|
||||
_ => false,
|
||||
};
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn nodes_similar(old_node: &StoredNode, new_node: &TargetNode) -> bool {
|
||||
match (&old_node.spec, &new_node.spec) {
|
||||
(NodeSpec::Supported(old_spec), NodeSpec::Supported(new_spec)) => old_spec.is_similar(new_spec),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn marker_matches(old_node: &StoredNode, new_node: &TargetNode) -> bool {
|
||||
let Some(id) = new_node.id_hint.as_deref() else {
|
||||
return false;
|
||||
};
|
||||
if id != old_node.id.as_str() {
|
||||
return false;
|
||||
}
|
||||
node_flavour_str(&old_node.spec) == node_flavour_str(&new_node.spec)
|
||||
}
|
||||
|
||||
fn node_flavour_str(spec: &NodeSpec) -> &str {
|
||||
match spec {
|
||||
NodeSpec::Supported(spec) => spec.flavour.as_str(),
|
||||
NodeSpec::Opaque { flavour } => flavour.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_block_props(
|
||||
doc: &Doc,
|
||||
blocks_map: &mut Map,
|
||||
node: &StoredNode,
|
||||
node_id: &str,
|
||||
current: &BlockSpec,
|
||||
target: &BlockSpec,
|
||||
preserve_text: bool,
|
||||
) -> Result<(), ParseError> {
|
||||
let Some(mut block) = blocks_map.get(&node.id).and_then(|v| v.to_map()) else {
|
||||
return Err(ParseError::ParserError(format!("Block {} not found", node.id)));
|
||||
let Some(mut block) = blocks_map.get(node_id).and_then(|v| v.to_map()) else {
|
||||
return Err(ParseError::ParserError(format!("Block {} not found", node_id)));
|
||||
};
|
||||
|
||||
let preserve = match target.flavour {
|
||||
@@ -266,7 +707,7 @@ fn update_block_props(
|
||||
| BlockFlavour::Bookmark
|
||||
| BlockFlavour::EmbedYoutube
|
||||
| BlockFlavour::EmbedIframe => preserve_text,
|
||||
_ => preserve_text || text_delta_eq(&node.spec.text, &target.text),
|
||||
_ => preserve_text || text_delta_eq(¤t.text, &target.text),
|
||||
};
|
||||
|
||||
apply_block_spec(
|
||||
@@ -302,7 +743,7 @@ fn collect_tree_ids(node: &StoredNode, output: &mut Vec<String>) {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_limits(current: &[StoredNode], target: &[BlockNode]) -> Result<(), ParseError> {
|
||||
fn check_limits(current: &[StoredNode], target: &[TargetNode]) -> Result<(), ParseError> {
|
||||
let current_count = count_tree_nodes(current);
|
||||
let target_count = count_tree_nodes(target);
|
||||
|
||||
@@ -319,7 +760,7 @@ fn check_limits(current: &[StoredNode], target: &[BlockNode]) -> Result<(), Pars
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use y_octo::{Any, DocOptions, TextDeltaOp, TextInsert};
|
||||
use y_octo::{Any, DocOptions, StateVector, TextDeltaOp, TextInsert};
|
||||
|
||||
use super::{super::builder::text_ops_from_plain, *};
|
||||
use crate::doc_parser::{
|
||||
@@ -647,6 +1088,233 @@ mod tests {
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_ydoc_fallback_when_blocks_empty() {
|
||||
let doc_id = "stub-empty-blocks";
|
||||
let markdown = "# From Markdown\n\nHello from markdown.";
|
||||
|
||||
// Build a valid ydoc update that results in an empty `blocks` map.
|
||||
// NOTE: yjs/y-octo may encode a completely empty doc as `[0,0]`, which we treat
|
||||
// as empty/invalid. We intentionally insert + remove a temp key so the
|
||||
// update is non-empty while the final map is empty.
|
||||
let doc = DocOptions::new().with_guid(doc_id.to_string()).build();
|
||||
let mut blocks = doc.get_or_create_map("blocks").expect("create blocks map");
|
||||
blocks
|
||||
.insert("tmp".to_string(), Any::String("1".to_string()))
|
||||
.expect("insert temp");
|
||||
blocks.remove("tmp");
|
||||
let stub_bin = doc
|
||||
.encode_state_as_update_v1(&StateVector::default())
|
||||
.expect("encode stub update");
|
||||
assert!(
|
||||
!stub_bin.is_empty() && stub_bin.as_slice() != [0, 0],
|
||||
"stub update should not be empty update"
|
||||
);
|
||||
|
||||
let delta = update_doc(&stub_bin, markdown, doc_id).expect("fallback delta");
|
||||
assert!(!delta.is_empty(), "delta should contain changes");
|
||||
|
||||
let mut updated = DocOptions::new().with_guid(doc_id.to_string()).build();
|
||||
updated
|
||||
.apply_update_from_binary_v1(&stub_bin)
|
||||
.expect("apply stub update");
|
||||
updated
|
||||
.apply_update_from_binary_v1(&delta)
|
||||
.expect("apply fallback delta");
|
||||
|
||||
let blocks_map = updated.get_map("blocks").expect("blocks map exists");
|
||||
|
||||
let mut page: Option<Map> = None;
|
||||
for (_, value) in blocks_map.iter() {
|
||||
if let Some(block_map) = value.to_map()
|
||||
&& get_string(&block_map, "sys:flavour").as_deref() == Some(PAGE_FLAVOUR)
|
||||
{
|
||||
page = Some(block_map);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let page = page.expect("page block created");
|
||||
assert_eq!(
|
||||
get_string(&page, "prop:title").as_deref(),
|
||||
Some("From Markdown"),
|
||||
"page title should be derived from markdown H1"
|
||||
);
|
||||
|
||||
let index = build_block_index(&blocks_map);
|
||||
let note_id = find_child_id_by_flavour(&page, &index.block_pool, NOTE_FLAVOUR).expect("note child exists");
|
||||
|
||||
let note = index.block_pool.get(¬e_id).expect("note block exists").clone();
|
||||
assert!(
|
||||
!collect_child_ids(¬e).is_empty(),
|
||||
"note should contain imported content blocks"
|
||||
);
|
||||
|
||||
let full_bin = updated
|
||||
.encode_state_as_update_v1(&StateVector::default())
|
||||
.expect("encode full doc");
|
||||
let md = parse_doc_to_markdown(full_bin, doc_id.to_string(), false, None).expect("render markdown");
|
||||
assert!(md.markdown.contains("Hello from markdown."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_ydoc_fallback_when_page_missing() {
|
||||
let doc_id = "stub-page-missing";
|
||||
let markdown = "# Title\n\nUpdated content.";
|
||||
|
||||
// Build a stub doc that has some blocks, but no `affine:page` root.
|
||||
let doc = DocOptions::new().with_guid(doc_id.to_string()).build();
|
||||
let mut blocks_map = doc.get_or_create_map("blocks").expect("create blocks map");
|
||||
let para_id = "para-1";
|
||||
let mut para = insert_block_map(&doc, &mut blocks_map, para_id).expect("insert para");
|
||||
insert_sys_fields(&mut para, para_id, "affine:paragraph").expect("sys fields");
|
||||
insert_children(&doc, &mut para, &[]).expect("children");
|
||||
|
||||
let stub_bin = doc
|
||||
.encode_state_as_update_v1(&StateVector::default())
|
||||
.expect("encode stub update");
|
||||
assert!(!stub_bin.is_empty(), "stub update should not be empty");
|
||||
|
||||
let delta = update_doc(&stub_bin, markdown, doc_id).expect("fallback delta");
|
||||
assert!(!delta.is_empty(), "delta should contain changes");
|
||||
|
||||
let mut updated = DocOptions::new().with_guid(doc_id.to_string()).build();
|
||||
updated
|
||||
.apply_update_from_binary_v1(&stub_bin)
|
||||
.expect("apply stub update");
|
||||
updated
|
||||
.apply_update_from_binary_v1(&delta)
|
||||
.expect("apply fallback delta");
|
||||
|
||||
let blocks_map = updated.get_map("blocks").expect("blocks map exists");
|
||||
let index = build_block_index(&blocks_map);
|
||||
let page_id = find_block_id_by_flavour(&index.block_pool, PAGE_FLAVOUR).expect("page block exists");
|
||||
let page = index.block_pool.get(&page_id).expect("page map exists").clone();
|
||||
|
||||
let note_id = find_child_id_by_flavour(&page, &index.block_pool, NOTE_FLAVOUR).expect("note child exists");
|
||||
let note = index.block_pool.get(¬e_id).expect("note block exists").clone();
|
||||
assert!(
|
||||
!collect_child_ids(¬e).is_empty(),
|
||||
"note should contain imported content blocks"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_ydoc_fallback_when_note_missing() {
|
||||
let doc_id = "stub-note-missing";
|
||||
let markdown = "# Title\n\nUpdated content.";
|
||||
|
||||
// Build a stub doc that has an `affine:page` block but doesn't contain a note
|
||||
// child.
|
||||
let doc = DocOptions::new().with_guid(doc_id.to_string()).build();
|
||||
let mut blocks_map = doc.get_or_create_map("blocks").expect("create blocks map");
|
||||
let page_id = "page-1";
|
||||
let mut page = insert_block_map(&doc, &mut blocks_map, page_id).expect("insert page");
|
||||
insert_sys_fields(&mut page, page_id, PAGE_FLAVOUR).expect("sys fields");
|
||||
insert_children(&doc, &mut page, &[]).expect("children");
|
||||
|
||||
let stub_bin = doc
|
||||
.encode_state_as_update_v1(&StateVector::default())
|
||||
.expect("encode stub update");
|
||||
assert!(!stub_bin.is_empty(), "stub update should not be empty");
|
||||
|
||||
let delta = update_doc(&stub_bin, markdown, doc_id).expect("fallback delta");
|
||||
assert!(!delta.is_empty(), "delta should contain changes");
|
||||
|
||||
let mut updated = DocOptions::new().with_guid(doc_id.to_string()).build();
|
||||
updated
|
||||
.apply_update_from_binary_v1(&stub_bin)
|
||||
.expect("apply stub update");
|
||||
updated
|
||||
.apply_update_from_binary_v1(&delta)
|
||||
.expect("apply fallback delta");
|
||||
|
||||
let blocks_map = updated.get_map("blocks").expect("blocks map exists");
|
||||
let index = build_block_index(&blocks_map);
|
||||
let page_id = find_block_id_by_flavour(&index.block_pool, PAGE_FLAVOUR).expect("page block exists");
|
||||
let page = index.block_pool.get(&page_id).expect("page map exists").clone();
|
||||
|
||||
let note_id = find_child_id_by_flavour(&page, &index.block_pool, NOTE_FLAVOUR).expect("note child exists");
|
||||
let note = index.block_pool.get(¬e_id).expect("note block exists").clone();
|
||||
assert!(
|
||||
!collect_child_ids(¬e).is_empty(),
|
||||
"note should contain imported content blocks"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_ydoc_preserves_opaque_blocks_when_unsupported_block_flavour() {
|
||||
let doc_id = "unsupported-flavour-replace";
|
||||
|
||||
// Build a doc with canonical page/note structure, but add an unsupported block
|
||||
// flavour under note. This simulates real-world docs that contain blocks we
|
||||
// don't support for structural diffing.
|
||||
let doc = DocOptions::new().with_guid(doc_id.to_string()).build();
|
||||
let mut blocks_map = doc.get_or_create_map("blocks").expect("create blocks map");
|
||||
|
||||
let page_id = "page-1";
|
||||
let surface_id = "surface-1";
|
||||
let note_id = "note-1";
|
||||
let db_id = "db-1";
|
||||
|
||||
let mut page = insert_block_map(&doc, &mut blocks_map, page_id).expect("insert page");
|
||||
let mut surface = insert_block_map(&doc, &mut blocks_map, surface_id).expect("insert surface");
|
||||
let mut note = insert_block_map(&doc, &mut blocks_map, note_id).expect("insert note");
|
||||
let mut db = insert_block_map(&doc, &mut blocks_map, db_id).expect("insert db");
|
||||
|
||||
insert_sys_fields(&mut page, page_id, PAGE_FLAVOUR).expect("page sys fields");
|
||||
insert_children(&doc, &mut page, &[surface_id.to_string(), note_id.to_string()]).expect("page children");
|
||||
insert_text(&doc, &mut page, PROP_TITLE, &text_ops_from_plain("Title")).expect("page title");
|
||||
|
||||
insert_sys_fields(&mut surface, surface_id, SURFACE_FLAVOUR).expect("surface sys fields");
|
||||
insert_children(&doc, &mut surface, &[]).expect("surface children");
|
||||
let mut boxed = boxed_empty_map(&doc).expect("boxed map");
|
||||
surface
|
||||
.insert(PROP_ELEMENTS.to_string(), Value::Map(boxed.clone()))
|
||||
.expect("surface elements");
|
||||
boxed
|
||||
.insert("type".to_string(), Any::String(BOXED_NATIVE_TYPE.to_string()))
|
||||
.expect("boxed type");
|
||||
let value = doc.create_map().expect("boxed value map");
|
||||
boxed
|
||||
.insert("value".to_string(), Value::Map(value))
|
||||
.expect("boxed value");
|
||||
|
||||
insert_sys_fields(&mut note, note_id, NOTE_FLAVOUR).expect("note sys fields");
|
||||
insert_children(&doc, &mut note, &[db_id.to_string()]).expect("note children");
|
||||
|
||||
// Unsupported flavour.
|
||||
insert_sys_fields(&mut db, db_id, "affine:database").expect("db sys fields");
|
||||
insert_children(&doc, &mut db, &[]).expect("db children");
|
||||
|
||||
let initial_bin = doc
|
||||
.encode_state_as_update_v1(&StateVector::default())
|
||||
.expect("encode initial");
|
||||
|
||||
// Updating should succeed and preserve the opaque block rather than deleting
|
||||
// it.
|
||||
let updated_md = "# New Title\n\nHello.";
|
||||
let delta = update_doc(&initial_bin, updated_md, doc_id).expect("delta");
|
||||
assert!(!delta.is_empty(), "delta should contain changes");
|
||||
|
||||
let mut updated_doc = DocOptions::new().with_guid(doc_id.to_string()).build();
|
||||
updated_doc
|
||||
.apply_update_from_binary_v1(&initial_bin)
|
||||
.expect("apply initial");
|
||||
updated_doc.apply_update_from_binary_v1(&delta).expect("apply delta");
|
||||
|
||||
let blocks_map = updated_doc.get_map("blocks").expect("blocks map");
|
||||
assert!(
|
||||
blocks_map.get(db_id).is_some(),
|
||||
"opaque block should be preserved when syncing from markdown"
|
||||
);
|
||||
|
||||
let md = parse_doc_to_markdown(updated_doc.encode_update_v1().unwrap(), doc_id.to_string(), false, None)
|
||||
.expect("render markdown")
|
||||
.markdown;
|
||||
assert!(md.contains("Hello."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_ydoc_markdown_too_large() {
|
||||
let initial_md = "# Title\n\nContent.";
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
"./broadcast-channel": "./src/impls/broadcast-channel/index.ts",
|
||||
"./idb/v1": "./src/impls/idb/v1/index.ts",
|
||||
"./cloud": "./src/impls/cloud/index.ts",
|
||||
"./disk": "./src/impls/disk/index.ts",
|
||||
"./sqlite": "./src/impls/sqlite/index.ts",
|
||||
"./sqlite/v1": "./src/impls/sqlite/v1/index.ts",
|
||||
"./sync": "./src/sync/index.ts",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user